From 4e7e9a796c42df2e0713a916f055bc697828fdb7 Mon Sep 17 00:00:00 2001 From: Itay Dar <118370953+ItayTheDar@users.noreply.github.com> Date: Thu, 7 May 2026 21:13:22 +0300 Subject: [PATCH 01/17] feat(di): add InjectionToken, Scope, ProviderDescriptor, normalize_provider Co-Authored-By: Claude Sonnet 4.6 --- nest/common/provider.py | 82 ++++++++++++++++++++ tests/test_common/test_provider.py | 119 +++++++++++++++++++++++++++++ 2 files changed, 201 insertions(+) create mode 100644 nest/common/provider.py create mode 100644 tests/test_common/test_provider.py diff --git a/nest/common/provider.py b/nest/common/provider.py new file mode 100644 index 0000000..ed6d698 --- /dev/null +++ b/nest/common/provider.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Callable, List, Optional, Type, Union + + +class Scope(str, Enum): + SINGLETON = "singleton" + TRANSIENT = "transient" + REQUEST = "request" + + +class InjectionToken: + """Named token for injecting non-class values (strings, primitives, configs).""" + + def __init__(self, name: str, description: str = "") -> None: + self.name = name + self.description = description + + def __repr__(self) -> str: + return f"InjectionToken({self.name!r})" + + def __hash__(self) -> int: + return hash(self.name) + + def __eq__(self, other: object) -> bool: + return isinstance(other, InjectionToken) and self.name == other.name + + +@dataclass +class ProviderDescriptor: + """Normalized provider definition. Exactly one of use_class/use_value/use_factory/use_existing must be set.""" + + provide: Union[Type, InjectionToken, str] + use_class: Optional[Type] = None + use_value: Any = None + use_factory: Optional[Callable] = None + use_existing: Optional[Union[Type, InjectionToken]] = None + scope: Scope = Scope.SINGLETON + inject: List[Any] = field(default_factory=list) + + +def normalize_provider( + provider: Union[Type, dict, ProviderDescriptor], +) -> ProviderDescriptor: + """Convert any provider form (class, dict, or ProviderDescriptor) to a ProviderDescriptor.""" + if isinstance(provider, ProviderDescriptor): + return provider + + if isinstance(provider, dict): + provide = provider["provide"] + scope = provider.get("scope", Scope.SINGLETON) + if "useValue" in provider: + return ProviderDescriptor( + provide=provide, use_value=provider["useValue"], scope=scope + ) + if "useClass" in provider: + return ProviderDescriptor( + provide=provide, use_class=provider["useClass"], scope=scope + ) + if "useFactory" in provider: + return ProviderDescriptor( + provide=provide, + use_factory=provider["useFactory"], + inject=provider.get("inject", []), + scope=scope, + ) + if "useExisting" in provider: + return ProviderDescriptor( + provide=provide, use_existing=provider["useExisting"], scope=scope + ) + raise ValueError( + f"Invalid provider descriptor: {provider!r}. " + "Must contain one of: useValue, useClass, useFactory, useExisting" + ) + + if not callable(provider): + raise ValueError( + f"Provider must be a class, dict, or ProviderDescriptor, got {provider!r}" + ) + return ProviderDescriptor(provide=provider, use_class=provider, scope=Scope.SINGLETON) diff --git a/tests/test_common/test_provider.py b/tests/test_common/test_provider.py new file mode 100644 index 0000000..0a24d7d --- /dev/null +++ b/tests/test_common/test_provider.py @@ -0,0 +1,119 @@ +import pytest +from nest.common.provider import ( + InjectionToken, + Scope, + ProviderDescriptor, + normalize_provider, +) + + +class SomeService: + pass + + +class OtherService: + pass + + +# --- InjectionToken --- + +def test_injection_token_has_name(): + token = InjectionToken("DB_URL") + assert token.name == "DB_URL" + + +def test_injection_token_repr(): + token = InjectionToken("MY_TOKEN") + assert "MY_TOKEN" in repr(token) + + +def test_injection_tokens_with_same_name_are_equal(): + assert InjectionToken("A") == InjectionToken("A") + + +def test_injection_tokens_with_different_names_are_not_equal(): + assert InjectionToken("A") != InjectionToken("B") + + +def test_injection_token_is_hashable(): + token = InjectionToken("X") + d = {token: "value"} + assert d[token] == "value" + + +# --- Scope --- + +def test_scope_values_exist(): + assert Scope.SINGLETON + assert Scope.TRANSIENT + assert Scope.REQUEST + + +# --- normalize_provider: class form --- + +def test_normalize_class_provider_sets_use_class(): + desc = normalize_provider(SomeService) + assert desc.provide is SomeService + assert desc.use_class is SomeService + assert desc.scope == Scope.SINGLETON + + +# --- normalize_provider: useValue dict --- + +def test_normalize_use_value(): + desc = normalize_provider({"provide": "DB_URL", "useValue": "postgres://localhost/db"}) + assert desc.provide == "DB_URL" + assert desc.use_value == "postgres://localhost/db" + assert desc.use_class is None + + +# --- normalize_provider: useClass dict --- + +def test_normalize_use_class(): + desc = normalize_provider({"provide": SomeService, "useClass": OtherService}) + assert desc.provide is SomeService + assert desc.use_class is OtherService + + +# --- normalize_provider: useFactory dict --- + +def test_normalize_use_factory(): + factory = lambda: SomeService() + desc = normalize_provider({ + "provide": SomeService, + "useFactory": factory, + "inject": [OtherService], + }) + assert desc.provide is SomeService + assert desc.use_factory is factory + assert desc.inject == [OtherService] + + +# --- normalize_provider: useExisting dict --- + +def test_normalize_use_existing(): + desc = normalize_provider({"provide": SomeService, "useExisting": OtherService}) + assert desc.provide is SomeService + assert desc.use_existing is OtherService + + +# --- normalize_provider: scope override --- + +def test_normalize_scope_override(): + desc = normalize_provider({"provide": SomeService, "useClass": SomeService, "scope": Scope.TRANSIENT}) + assert desc.scope == Scope.TRANSIENT + + +# --- normalize_provider: invalid dict --- + +def test_normalize_invalid_dict_raises(): + with pytest.raises(ValueError, match="Invalid provider descriptor"): + normalize_provider({"provide": SomeService}) + + +# --- normalize_provider: already a ProviderDescriptor --- + +def test_normalize_passthrough_descriptor(): + desc = ProviderDescriptor(provide=SomeService, use_class=SomeService) + result = normalize_provider(desc) + assert result is desc From c22a98fe65f821adb283f754219f3901210bb53e Mon Sep 17 00:00:00 2001 From: Itay Dar <118370953+ItayTheDar@users.noreply.github.com> Date: Thu, 7 May 2026 21:23:28 +0300 Subject: [PATCH 02/17] feat(di): add DependencyGraph with cycle detection and topological sort Co-Authored-By: Claude Sonnet 4.6 --- nest/core/dependency_graph.py | 83 +++++++++++++++++ tests/test_core/test_dependency_graph.py | 112 +++++++++++++++++++++++ 2 files changed, 195 insertions(+) create mode 100644 nest/core/dependency_graph.py create mode 100644 tests/test_core/test_dependency_graph.py diff --git a/nest/core/dependency_graph.py b/nest/core/dependency_graph.py new file mode 100644 index 0000000..d053fb4 --- /dev/null +++ b/nest/core/dependency_graph.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from typing import Any, List, Set + + +class CycleError(Exception): + """Raised when a circular dependency is detected in the provider graph.""" + pass + + +class DependencyGraph: + """ + Directed graph where an edge A → B means 'A depends on B'. + Used to detect circular dependencies and determine initialization order. + """ + + def __init__(self) -> None: + self._edges: dict[Any, Set[Any]] = {} + + @property + def nodes(self) -> Set[Any]: + return set(self._edges.keys()) + + def add_node(self, node: Any) -> None: + if node not in self._edges: + self._edges[node] = set() + + def add_dependency(self, dependent: Any, dependency: Any) -> None: + """Record that `dependent` requires `dependency` to be initialized first.""" + self.add_node(dependent) + self.add_node(dependency) + self._edges[dependent].add(dependency) + + def detect_cycles(self) -> List[List[Any]]: + """Return a list of cycles (each cycle is a list of nodes). Empty list = no cycles.""" + visited: Set[Any] = set() + path: List[Any] = [] + path_set: Set[Any] = set() + cycles: List[List[Any]] = [] + + def dfs(node: Any) -> None: + visited.add(node) + path.append(node) + path_set.add(node) + for dep in self._edges.get(node, set()): + if dep not in visited: + dfs(dep) + elif dep in path_set: + idx = path.index(dep) + cycles.append(list(path[idx:]) + [dep]) + path.pop() + path_set.discard(node) + + for node in list(self._edges.keys()): + if node not in visited: + dfs(node) + return cycles + + def topological_sort(self) -> List[Any]: + """Return nodes in initialization order: dependencies come before dependents.""" + visited: Set[Any] = set() + result: List[Any] = [] + + def visit(node: Any) -> None: + if node in visited: + return + visited.add(node) + for dep in self._edges.get(node, set()): + visit(dep) + result.append(node) + + for node in list(self._edges.keys()): + visit(node) + return result + + def validate(self) -> None: + """Raise CycleError if any circular dependencies exist.""" + cycles = self.detect_cycles() + if cycles: + chain = " → ".join( + getattr(n, "__name__", repr(n)) for n in cycles[0] + ) + raise CycleError(f"Circular dependency detected: {chain}") diff --git a/tests/test_core/test_dependency_graph.py b/tests/test_core/test_dependency_graph.py new file mode 100644 index 0000000..89e27b7 --- /dev/null +++ b/tests/test_core/test_dependency_graph.py @@ -0,0 +1,112 @@ +import pytest +from nest.core.dependency_graph import DependencyGraph, CycleError + + +class A: pass +class B: pass +class C: pass +class D: pass + + +def test_add_node_is_idempotent(): + g = DependencyGraph() + g.add_node(A) + g.add_node(A) + assert A in g.nodes + + +def test_add_dependency_adds_both_nodes(): + g = DependencyGraph() + g.add_dependency(A, B) # A depends on B + assert A in g.nodes + assert B in g.nodes + + +def test_no_cycles_in_linear_chain(): + g = DependencyGraph() + g.add_dependency(A, B) + g.add_dependency(B, C) + assert g.detect_cycles() == [] + + +def test_direct_cycle_detected(): + g = DependencyGraph() + g.add_dependency(A, B) + g.add_dependency(B, A) + cycles = g.detect_cycles() + assert len(cycles) == 1 + cycle = cycles[0] + assert A in cycle + assert B in cycle + + +def test_indirect_cycle_detected(): + g = DependencyGraph() + g.add_dependency(A, B) + g.add_dependency(B, C) + g.add_dependency(C, A) + cycles = g.detect_cycles() + assert len(cycles) == 1 + assert A in cycles[0] + assert B in cycles[0] + assert C in cycles[0] + + +def test_no_cycle_in_diamond(): + g = DependencyGraph() + g.add_dependency(A, B) + g.add_dependency(A, C) + g.add_dependency(B, D) + g.add_dependency(C, D) + assert g.detect_cycles() == [] + + +def test_topological_sort_single_node(): + g = DependencyGraph() + g.add_node(A) + result = g.topological_sort() + assert result == [A] + + +def test_topological_sort_dependency_comes_first(): + g = DependencyGraph() + g.add_dependency(A, B) # A depends on B → B must come first + result = g.topological_sort() + assert result.index(B) < result.index(A) + + +def test_topological_sort_chain(): + g = DependencyGraph() + g.add_dependency(A, B) + g.add_dependency(B, C) + result = g.topological_sort() + assert result.index(C) < result.index(B) < result.index(A) + + +def test_topological_sort_diamond(): + g = DependencyGraph() + g.add_dependency(A, B) + g.add_dependency(A, C) + g.add_dependency(B, D) + g.add_dependency(C, D) + result = g.topological_sort() + assert result.index(D) < result.index(B) + assert result.index(D) < result.index(C) + assert result.index(B) < result.index(A) + assert result.index(C) < result.index(A) + + +def test_validate_raises_cycle_error_with_chain(): + g = DependencyGraph() + g.add_dependency(A, B) + g.add_dependency(B, A) + with pytest.raises(CycleError) as exc_info: + g.validate() + assert "A" in str(exc_info.value) or "B" in str(exc_info.value) + + +def test_validate_passes_for_valid_graph(): + g = DependencyGraph() + g.add_dependency(A, B) + g.add_dependency(B, C) + g.validate() # must not raise From 871996377967df16716e09e5462c56e4b57507b0 Mon Sep 17 00:00:00 2001 From: Itay Dar <118370953+ItayTheDar@users.noreply.github.com> Date: Thu, 7 May 2026 21:28:21 +0300 Subject: [PATCH 03/17] feat(di): add CompiledModule dataclass and update ModuleCompiler to normalize providers Co-Authored-By: Claude Sonnet 4.6 --- nest/common/module.py | 43 ++++++++-- nest/core/pynest_container.py | 28 ++++-- tests/test_common/test_module_compiler.py | 100 ++++++++++++++++++++++ 3 files changed, 156 insertions(+), 15 deletions(-) create mode 100644 tests/test_common/test_module_compiler.py diff --git a/nest/common/module.py b/nest/common/module.py index a877eab..b3070e5 100644 --- a/nest/common/module.py +++ b/nest/common/module.py @@ -2,10 +2,19 @@ import random import string import uuid +from dataclasses import dataclass, field from typing import Any, List, Type from uuid import uuid4 -from nest.core import Module +@dataclass +class CompiledModule: + """The result of compiling a @Module-decorated class. Immutable snapshot used by the container.""" + token: str + metatype: Type + imports: List[Type] = field(default_factory=list) + controllers: List[Type] = field(default_factory=list) + exports: List[Any] = field(default_factory=list) + provider_descriptors: List[Any] = field(default_factory=list) class ModulesContainer(dict): @@ -187,18 +196,34 @@ class ModuleCompiler: def __init__(self, module_token_factory: ModuleTokenFactory = ModuleTokenFactory()): self.module_token_factory = module_token_factory - def compile(self, metatype: Type[Any]): - metadata = self.extract_metadata(metatype) - module_type = metadata["type"] - dynamic_metadata = metadata["dynamic_metadata"] - token = self.module_token_factory.create(module_type, dynamic_metadata) - return ModuleFactory(module_type, token, dynamic_metadata) + def compile(self, metatype: Type[Any]) -> CompiledModule: + from nest.common.provider import normalize_provider # local import avoids circular + + if not self.has_module_metadata(metatype): + raise Exception(f"{metatype.__name__} has no metadata found") + + raw_providers = getattr(metatype, "providers", []) or [] + controllers = getattr(metatype, "controllers", []) or [] + imports = getattr(metatype, "imports", []) or [] + exports = getattr(metatype, "exports", []) or [] + + provider_descriptors = [normalize_provider(p) for p in raw_providers] + token = self.module_token_factory.create(metatype) + + return CompiledModule( + token=token, + metatype=metatype, + imports=list(imports), + controllers=list(controllers), + exports=list(exports), + provider_descriptors=provider_descriptors, + ) def extract_metadata(self, metatype) -> dict: + # Kept for backward compat with PyNestApplicationContext.select() metadata = {"type": metatype, "dynamic_metadata": {}} - if not self.has_module_metadata(metatype): - raise Exception(f"{metatype.__name__} as no metadata found") + raise Exception(f"{metatype.__name__} has no metadata found") for props in ["imports", "providers", "controllers", "exports"]: metadata["dynamic_metadata"][props] = getattr(metatype, props, []) return metadata diff --git a/nest/core/pynest_container.py b/nest/core/pynest_container.py index 9e524f8..58005b8 100644 --- a/nest/core/pynest_container.py +++ b/nest/core/pynest_container.py @@ -11,6 +11,7 @@ UnknownModuleException, ) from nest.common.module import ( + CompiledModule, Module, ModuleCompiler, ModuleFactory, @@ -95,7 +96,7 @@ def add_module(self, metaclass) -> dict: return {"module_ref": self.modules.get(token), "inserted": False} return {"module_ref": self.register_module(module_factory), "inserted": True} - def register_module(self, module_factory: ModuleFactory) -> Module: + def register_module(self, module_factory) -> Module: """ Register a module in the container. @@ -104,18 +105,33 @@ def register_module(self, module_factory: ModuleFactory) -> Module: associated with the module, and logs the detection of the module. Args: - module_factory (ModuleFactory): The factory object that contains the type and metadata - for creating the module. + module_factory: Either a CompiledModule or legacy ModuleFactory containing module info. Returns: Module: The module reference that has been registered in the container. """ - module_ref = Module(module_factory.type, self) + if isinstance(module_factory, CompiledModule): + metatype = module_factory.metatype + # Build legacy metadata dict from CompiledModule fields for backward compat + dynamic_metadata = { + "imports": module_factory.imports, + "providers": [ + desc.use_class for desc in module_factory.provider_descriptors + if desc.use_class is not None + ], + "controllers": module_factory.controllers, + "exports": module_factory.exports, + } + else: + metatype = module_factory.type + dynamic_metadata = module_factory.dynamic_metadata + + module_ref = Module(metatype, self) module_ref.token = module_factory.token self._modules[module_factory.token] = module_ref - self.add_metadata(module_factory.token, module_factory.dynamic_metadata) + self.add_metadata(module_factory.token, dynamic_metadata) self.add_import(module_factory.token) self.add_providers( self._get_providers(module_factory.token), module_factory.token @@ -125,7 +141,7 @@ def register_module(self, module_factory: ModuleFactory) -> Module: ) self.logger.info( - click.style(module_factory.type.__name__ + " Detected ", fg="green") + click.style(metatype.__name__ + " Detected ", fg="green") ) return module_ref diff --git a/tests/test_common/test_module_compiler.py b/tests/test_common/test_module_compiler.py new file mode 100644 index 0000000..23b2799 --- /dev/null +++ b/tests/test_common/test_module_compiler.py @@ -0,0 +1,100 @@ +import pytest +from nest.common.module import ModuleCompiler, ModuleTokenFactory, CompiledModule +from nest.common.provider import ProviderDescriptor, Scope +from nest.core.decorators.module import Module +from nest.core.decorators.injectable import Injectable + + +@Injectable +class ServiceA: + pass + + +@Injectable +class ServiceB: + def __init__(self, a: ServiceA): + self.a = a + + +@Module(providers=[ServiceA]) +class ModuleA: + pass + + +@Module(providers=[ServiceB], imports=[ModuleA], exports=[ServiceB]) +class ModuleB: + pass + + +@Module(providers=[{"provide": "DB_URL", "useValue": "postgres://localhost/test"}]) +class ConfigModule: + pass + + +def test_compiled_module_has_token(): + compiler = ModuleCompiler(ModuleTokenFactory()) + result = compiler.compile(ModuleA) + assert isinstance(result, CompiledModule) + assert result.token is not None + assert len(result.token) > 0 + + +def test_compiled_module_has_metatype(): + compiler = ModuleCompiler(ModuleTokenFactory()) + result = compiler.compile(ModuleA) + assert result.metatype is ModuleA + + +def test_compiled_module_providers_are_normalized(): + compiler = ModuleCompiler(ModuleTokenFactory()) + result = compiler.compile(ModuleA) + assert len(result.provider_descriptors) == 1 + desc = result.provider_descriptors[0] + assert isinstance(desc, ProviderDescriptor) + assert desc.use_class is ServiceA + + +def test_compiled_module_dict_provider_is_normalized(): + compiler = ModuleCompiler(ModuleTokenFactory()) + result = compiler.compile(ConfigModule) + assert len(result.provider_descriptors) == 1 + desc = result.provider_descriptors[0] + assert desc.use_value == "postgres://localhost/test" + assert desc.use_class is None + + +def test_compiled_module_imports(): + compiler = ModuleCompiler(ModuleTokenFactory()) + result = compiler.compile(ModuleB) + assert ModuleA in result.imports + + +def test_compiled_module_exports(): + compiler = ModuleCompiler(ModuleTokenFactory()) + result = compiler.compile(ModuleB) + assert ServiceB in result.exports + + +def test_same_module_gets_same_token(): + factory = ModuleTokenFactory() + compiler = ModuleCompiler(factory) + token1 = compiler.compile(ModuleA).token + token2 = compiler.compile(ModuleA).token + assert token1 == token2 + + +def test_different_modules_get_different_tokens(): + factory = ModuleTokenFactory() + compiler = ModuleCompiler(factory) + token_a = compiler.compile(ModuleA).token + token_b = compiler.compile(ModuleB).token + assert token_a != token_b + + +def test_module_without_decorator_raises(): + class Bare: + pass + + compiler = ModuleCompiler(ModuleTokenFactory()) + with pytest.raises(Exception, match="no metadata"): + compiler.compile(Bare) From 37f5fd262ee518ec4e42e760469a91c23f1667b2 Mon Sep 17 00:00:00 2001 From: Itay Dar <118370953+ItayTheDar@users.noreply.github.com> Date: Thu, 7 May 2026 21:34:56 +0300 Subject: [PATCH 04/17] =?UTF-8?q?feat(di):=20add=20PyNestInjectorModule=20?= =?UTF-8?q?and=20build=5Finjector=20=E2=80=94=20bridges=20ProviderDescript?= =?UTF-8?q?ors=20to=20injector=20bindings?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Sonnet 4.6 --- nest/core/injector_module.py | 76 +++++++++++++++ tests/test_core/test_injector_module.py | 117 ++++++++++++++++++++++++ 2 files changed, 193 insertions(+) create mode 100644 nest/core/injector_module.py create mode 100644 tests/test_core/test_injector_module.py diff --git a/nest/core/injector_module.py b/nest/core/injector_module.py new file mode 100644 index 0000000..bf32b20 --- /dev/null +++ b/nest/core/injector_module.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from typing import Any, List + +from injector import Injector, Module as InjectorModule, noscope, singleton + +from nest.common.provider import InjectionToken, ProviderDescriptor, Scope + + +def _injector_scope(scope: Scope): + if scope == Scope.SINGLETON: + return singleton + return noscope # TRANSIENT and REQUEST both use noscope at this layer + + +def _to_key(token: Any) -> Any: + """Convert an InjectionToken or string to an injector-compatible key. + + Both InjectionToken and plain strings are usable directly as injector + binding keys (the injector library accepts any hashable as a key). + """ + return token + + +class PyNestInjectorModule(InjectorModule): + """Translates a list of ProviderDescriptors into injector bindings. + + use_class and use_value providers are bound eagerly inside configure(). + use_factory and use_existing providers are deferred to build_injector() + so that their dependencies (or aliased singletons) are already resolved. + """ + + def __init__(self, descriptors: List[ProviderDescriptor]) -> None: + self._descriptors = [ + d for d in descriptors + if d.use_factory is None and d.use_existing is None + ] + + def configure(self, binder) -> None: + from injector import InstanceProvider + + for desc in self._descriptors: + scope = _injector_scope(desc.scope) + key = _to_key(desc.provide) + + if desc.use_value is not None: + binder.bind(key, to=InstanceProvider(desc.use_value)) + elif desc.use_class is not None: + binder.bind(key, to=desc.use_class, scope=scope) + + +def build_injector(descriptors: List[ProviderDescriptor]) -> Injector: + """ + Build and return a configured Injector from a list of ProviderDescriptors. + + use_factory and use_existing providers are resolved post-build so that + their dependencies and aliased singletons are already in the injector. + """ + from injector import InstanceProvider + + injector = Injector([PyNestInjectorModule(descriptors)]) + + for desc in descriptors: + key = _to_key(desc.provide) + + if desc.use_factory is not None: + deps = [injector.get(_to_key(t)) for t in desc.inject] + instance = desc.use_factory(*deps) + injector.binder.bind(key, to=InstanceProvider(instance)) + + elif desc.use_existing is not None: + # Resolve the aliased key to get the same (singleton) instance + existing_instance = injector.get(_to_key(desc.use_existing)) + injector.binder.bind(key, to=InstanceProvider(existing_instance)) + + return injector diff --git a/tests/test_core/test_injector_module.py b/tests/test_core/test_injector_module.py new file mode 100644 index 0000000..e48c743 --- /dev/null +++ b/tests/test_core/test_injector_module.py @@ -0,0 +1,117 @@ +import pytest +from injector import Injector +from nest.core.injector_module import build_injector +from nest.common.provider import ProviderDescriptor, Scope, InjectionToken +from nest.core.decorators.injectable import Injectable + + +@Injectable +class Repo: + def find_all(self): + return ["item1"] + + +@Injectable +class Service: + def __init__(self, repo: Repo): + self.repo = repo + + def get_items(self): + return self.repo.find_all() + + +def test_build_injector_resolves_class_provider(): + descriptors = [ + ProviderDescriptor(provide=Repo, use_class=Repo), + ProviderDescriptor(provide=Service, use_class=Service), + ] + injector = build_injector(descriptors) + service = injector.get(Service) + assert isinstance(service, Service) + assert isinstance(service.repo, Repo) + + +def test_build_injector_singleton_scope_returns_same_instance(): + descriptors = [ProviderDescriptor(provide=Repo, use_class=Repo, scope=Scope.SINGLETON)] + injector = build_injector(descriptors) + a = injector.get(Repo) + b = injector.get(Repo) + assert a is b + + +def test_build_injector_transient_scope_returns_new_instance(): + descriptors = [ProviderDescriptor(provide=Repo, use_class=Repo, scope=Scope.TRANSIENT)] + injector = build_injector(descriptors) + a = injector.get(Repo) + b = injector.get(Repo) + assert a is not b + + +def test_build_injector_use_value_injection_token(): + token = InjectionToken("DB_URL") + descriptors = [ProviderDescriptor(provide=token, use_value="postgres://localhost/test")] + injector = build_injector(descriptors) + value = injector.get(token) + assert value == "postgres://localhost/test" + + +def test_build_injector_use_value_string_key(): + descriptors = [ProviderDescriptor(provide="DB_URL", use_value="postgres://localhost/test")] + injector = build_injector(descriptors) + value = injector.get("DB_URL") + assert value == "postgres://localhost/test" + + +def test_build_injector_use_factory_no_deps(): + call_count = [0] + + def factory(): + call_count[0] += 1 + return Repo() + + descriptors = [ProviderDescriptor(provide=Repo, use_factory=factory, inject=[])] + injector = build_injector(descriptors) + result = injector.get(Repo) + assert isinstance(result, Repo) + assert call_count[0] == 1 + + +def test_build_injector_use_factory_with_deps(): + descriptors = [ + ProviderDescriptor(provide=Repo, use_class=Repo), + ProviderDescriptor( + provide=Service, + use_factory=lambda repo: Service(repo), + inject=[Repo], + ), + ] + injector = build_injector(descriptors) + service = injector.get(Service) + assert isinstance(service.repo, Repo) + + +def test_build_injector_use_existing_aliases(): + descriptors = [ + ProviderDescriptor(provide=Repo, use_class=Repo), + ProviderDescriptor(provide=Service, use_class=Service), + ProviderDescriptor(provide="IService", use_existing=Service), + ] + injector = build_injector(descriptors) + service1 = injector.get(Service) + service2 = injector.get("IService") + assert service1 is service2 + + +def test_use_class_different_from_provide(): + class MockRepo(Repo): + def find_all(self): + return [] + + descriptors = [ + ProviderDescriptor(provide=Repo, use_class=MockRepo), + ProviderDescriptor(provide=Service, use_class=Service), + ] + injector = build_injector(descriptors) + service = injector.get(Service) + assert isinstance(service.repo, MockRepo) + assert service.get_items() == [] From a916606f5c98651cb0e47b5076da135cefd8f0eb Mon Sep 17 00:00:00 2001 From: Itay Dar <118370953+ItayTheDar@users.noreply.github.com> Date: Thu, 7 May 2026 21:38:11 +0300 Subject: [PATCH 05/17] =?UTF-8?q?feat(di):=20rewrite=20PyNestContainer=20?= =?UTF-8?q?=E2=80=94=20non-singleton,=20build()=20+=20get()=20API,=20insta?= =?UTF-8?q?nce-based=20injection?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Sonnet 4.6 --- nest/core/pynest_container.py | 364 +++++++++-------------- tests/test_core/test_pynest_container.py | 164 ++++++++-- 2 files changed, 286 insertions(+), 242 deletions(-) diff --git a/nest/core/pynest_container.py b/nest/core/pynest_container.py index 58005b8..fd20df8 100644 --- a/nest/core/pynest_container.py +++ b/nest/core/pynest_container.py @@ -1,248 +1,168 @@ +from __future__ import annotations + +import inspect import logging -from typing import Any, List, Optional, Union - -import click -from injector import Injector, UnknownProvider, singleton - -from nest.common.constants import DEPENDENCIES, INJECTABLE_TOKEN -from nest.common.exceptions import ( - CircularDependencyException, - NoneInjectableException, - UnknownModuleException, -) -from nest.common.module import ( - CompiledModule, - Module, - ModuleCompiler, - ModuleFactory, - ModulesContainer, - ModuleTokenFactory, -) - -TController = type("TController", (), {}) -TProvider = type("TProvider", (), {}) +from typing import Any, Dict, List, Optional, Type, Union + +from nest.common.exceptions import CircularDependencyException +from nest.common.module import CompiledModule, ModuleCompiler, ModuleTokenFactory +from nest.common.provider import InjectionToken, ProviderDescriptor +from nest.core.dependency_graph import DependencyGraph +from nest.core.injector_module import build_injector, _to_key + + +class ModuleRef: + """Internal container representation of a registered module.""" + + def __init__(self, token: str, metatype: Type, compiled: CompiledModule) -> None: + self.token = token + self.metatype = metatype + self.compiled = compiled + + @property + def name(self) -> str: + return self.metatype.__name__ class PyNestContainer: """ - A singleton container class for managing modules, providers, and dependencies - in a PyNest application. + IoC container managing the module graph, provider bindings, and instance lifecycle. + NOT a singleton — one fresh instance per application, created by PyNestFactory. """ - _instance = None - _dependencies = None - - def __new__(cls): - """Create a singleton instance of PyNestContainer.""" - if cls._instance is None: - cls._instance = super(PyNestContainer, cls).__new__(cls) - return cls._instance - - def __init__(self): - """Initialize the PyNestContainer.""" - if not hasattr(self, "_initialized"): # Prevent reinitialization - self.logger = logging.getLogger("pynest") - self._injector = Injector() - self._global_modules = set() - self._modules = ModulesContainer() - self._module_token_factory = ModuleTokenFactory() - self._module_compiler = ModuleCompiler(self._module_token_factory) - self._modules_metadata = {} - self._initialized = True + def __init__(self) -> None: + self._logger = logging.getLogger("pynest.container") + self._injector = None + self._modules: Dict[str, ModuleRef] = {} + self._all_descriptors: List[ProviderDescriptor] = [] + self._controller_classes: List[Type] = [] + self._module_token_factory = ModuleTokenFactory() + self._module_compiler = ModuleCompiler(self._module_token_factory) + + # ── Public API ───────────────────────────────────────────────────────────── @property - def modules(self): + def modules(self) -> Dict[str, ModuleRef]: return self._modules @property def module_token_factory(self): return self._module_token_factory - @property - def modules_metadata(self): - return self._modules_metadata - @property def module_compiler(self): return self._module_compiler - def get_instance( - self, - dependency: TProvider, - provider: Optional[Union[TProvider, TController]] = None, - ): - try: - self._injector.binder.bind(dependency, scope=singleton) - instance = self._injector.get(dependency) - self.logger.info(click.style(dependency.__name__ + " Detected ", fg="blue")) - except UnknownProvider: - raise Exception(f"Unknown provider {provider}") - return instance - - def add_module(self, metaclass) -> dict: - """ - Add a module to the container. + def add_module(self, module_class: Type) -> dict: + """Compile and register a module and all its imports recursively.""" + compiled = self._module_compiler.compile(module_class) + token = compiled.token - Args: - metaclass: The metaclass of the module to be added. + if token in self._modules: + return {"module_ref": self._modules[token], "inserted": False} - Returns: - dict: A dictionary containing the module reference and a - boolean flag indicating if it was newly inserted. - """ - module_factory = self._module_compiler.compile(metaclass) - token = module_factory.token - if self._modules.has(token): - return {"module_ref": self.modules.get(token), "inserted": False} - return {"module_ref": self.register_module(module_factory), "inserted": True} - - def register_module(self, module_factory) -> Module: - """ - Register a module in the container. + # Register imported modules first (depth-first) + for imported in compiled.imports: + self.add_module(imported) - This method creates a module reference from the provided module factory, registers - the module within the container, adds metadata, imports, providers, and controllers - associated with the module, and logs the detection of the module. + module_ref = ModuleRef(token=token, metatype=module_class, compiled=compiled) + self._modules[token] = module_ref + self._all_descriptors.extend(compiled.provider_descriptors) + self._controller_classes.extend(compiled.controllers) - Args: - module_factory: Either a CompiledModule or legacy ModuleFactory containing module info. - - Returns: - Module: The module reference that has been registered in the container. + self._logger.info(f"Module registered: {module_class.__name__}") + return {"module_ref": module_ref, "inserted": True} + def build(self) -> None: + """ + Validate the dependency graph and build the injector. + Must be called once after all add_module() calls, before any get() calls. """ - if isinstance(module_factory, CompiledModule): - metatype = module_factory.metatype - # Build legacy metadata dict from CompiledModule fields for backward compat - dynamic_metadata = { - "imports": module_factory.imports, - "providers": [ - desc.use_class for desc in module_factory.provider_descriptors - if desc.use_class is not None - ], - "controllers": module_factory.controllers, - "exports": module_factory.exports, - } - else: - metatype = module_factory.type - dynamic_metadata = module_factory.dynamic_metadata - - module_ref = Module(metatype, self) - module_ref.token = module_factory.token - self._modules[module_factory.token] = module_ref - - self.add_metadata(module_factory.token, dynamic_metadata) - self.add_import(module_factory.token) - self.add_providers( - self._get_providers(module_factory.token), module_factory.token - ) - self.add_controllers( - self._get_controllers(module_factory.token), module_factory.token - ) - - self.logger.info( - click.style(metatype.__name__ + " Detected ", fg="green") - ) - - return module_ref - - def add_metadata(self, token: str, module_metadata) -> None: - """Add metadata for a module.""" - if module_metadata: - self._modules_metadata[token] = module_metadata - - def add_import(self, token: str): - """Add imports for a module.""" - if not self.modules.has(token): - return - module_metadata = self._modules_metadata.get(token) - module_ref: Module = self.modules.get(token) - imports_mod: List[Any] = module_metadata.get("imports") - self.add_modules(imports_mod) - module_ref.add_imports(imports_mod) - - def add_modules(self, modules: List[Any]) -> None: - """Add multiple modules to the container.""" - if modules: - for module in modules: - self.add_module(module) - - def add_providers(self, providers: List[Any], module_token: str) -> None: - """Add multiple providers to a module.""" - for provider in providers: - self.add_provider(module_token, provider) - - def add_provider(self, token: str, provider): - """Add a provider to a module.""" - module_ref: Module = self.modules[token] - if not provider: - raise CircularDependencyException(module_ref.metatype) - - if not module_ref: - raise UnknownModuleException() - - if not hasattr(provider, INJECTABLE_TOKEN): - error_message = f""" - {click.style(provider.__name__, fg='red')} is not injectable. - To make {provider.__name__} injectable, apply the {click.style("@Injectable decorator", fg='green')} - to the class definition, or remove {click.style(provider.__name__, fg='red')} from the provider array - of the Module class. Please check your code and ensure that the decorator is correctly applied to the - class. - """ - raise NoneInjectableException(error_message) - - for dependency_name, dependency_instance in getattr( - provider, DEPENDENCIES - ).items(): + self._validate_dependency_graph() + + # Controller classes need singleton bindings too so the injector can resolve them + all_descriptors = self._all_descriptors + self._make_controller_descriptors() + self._injector = build_injector(all_descriptors) + self._logger.info("Container built successfully") + + def get(self, token: Union[Type, InjectionToken, str]) -> Any: + """Retrieve a fully-wired instance from the container.""" + if self._injector is None: + raise RuntimeError( + "Container not built. Call container.build() before resolving providers." + ) + return self._injector.get(_to_key(token)) + + def get_controller_instance(self, controller_class: Type) -> Any: + """Get a controller instance with all its service dependencies injected.""" + return self.get(controller_class) + + def clear(self) -> None: + """Reset container state. Useful in tests.""" + self._injector = None + self._modules.clear() + self._all_descriptors.clear() + self._controller_classes.clear() + + # ── Internal ─────────────────────────────────────────────────────────────── + + def _make_controller_descriptors(self) -> List[ProviderDescriptor]: + from nest.common.provider import Scope + return [ + ProviderDescriptor(provide=cls, use_class=cls, scope=Scope.SINGLETON) + for cls in self._controller_classes + ] + + def _validate_dependency_graph(self) -> None: + """Build a DAG from all class providers and raise CircularDependencyException on cycles.""" + import sys + + graph = DependencyGraph() + + # Build a name→class lookup from all registered providers so forward refs can be resolved + provider_classes = { + desc.use_class.__name__: desc.use_class + for desc in self._all_descriptors + if desc.use_class is not None + } + + for desc in self._all_descriptors: + if desc.use_class is None: + continue + target = desc.use_class + graph.add_node(target) + try: + sig = inspect.signature(target.__init__) + except (ValueError, TypeError): + continue + + # Resolve type hints, handling string forward references try: - instance = self.get_instance(dependency_instance, provider) - setattr(provider, dependency_name, instance) - except Exception as e: - self.logger.error(e) - raise e - - module_ref.add_provider(provider) - - def _get_providers(self, token: str) -> List[Any]: - """Get providers from the module metadata.""" - return self.modules_metadata[token]["providers"] - - def add_controllers(self, controllers: List[Any], module_token: str) -> None: - """Add multiple controllers to a module.""" - for controller in controllers: - self._add_controller(module_token, controller) - - def _add_controller(self, token: str, controller: TController) -> None: - """Add a controller to a module.""" - if not self.modules.has(token): - raise UnknownModuleException() - module_ref: Module = self.modules[token] - module_ref.add_controller(controller) - if hasattr(controller, DEPENDENCIES): - for provider_name, provider_type in getattr( - controller, DEPENDENCIES - ).items(): - instance = self.get_instance(provider_type, controller) - setattr(controller, provider_name, instance) - - def _get_controllers(self, token: str) -> List[Any]: - """Get controllers from the module metadata.""" - return self.modules_metadata[token]["controllers"] - - def clear(self): - """Clear all modules from the container.""" - self.modules.clear() - - # UNUSED: This function is currently not used but retained for potential future use. - def add_related_module(self, related_module, token: str) -> None: - if not self.modules.has(token): - return - module_ref = self.modules.get(token) - compile_related_module = self.module_compiler.compile(related_module) - related = self.modules.get(compile_related_module.token) - module_ref.add_import(related) - - # UNUSED: This function is currently not used but retained for potential future use. - # It retrieves a module from the container by its key. - def get_module_by_key(self, module_key: str) -> Module: - return self._modules[module_key] + hints = {} + for param_name, param in sig.parameters.items(): + if param_name == "self": + continue + ann = param.annotation + if ann is param.empty: + continue + if isinstance(ann, str): + # Try to resolve the string annotation against known providers + resolved = provider_classes.get(ann) + if resolved is not None: + hints[param_name] = resolved + elif isinstance(ann, type): + hints[param_name] = ann + except Exception: + continue + + for dep_type in hints.values(): + graph.add_dependency(target, dep_type) + + cycles = graph.detect_cycles() + if cycles: + chain = " → ".join( + getattr(n, "__name__", repr(n)) for n in cycles[0] + ) + raise CircularDependencyException( + f"Circular dependency detected: {chain}" + ) diff --git a/tests/test_core/test_pynest_container.py b/tests/test_core/test_pynest_container.py index 9796e25..78102ea 100644 --- a/tests/test_core/test_pynest_container.py +++ b/tests/test_core/test_pynest_container.py @@ -1,28 +1,152 @@ import pytest +from nest.core.pynest_container import PyNestContainer +from nest.core.decorators.module import Module +from nest.core.decorators.injectable import Injectable -from nest.core import Module, PyNestContainer -from tests.test_core import test_module +@Injectable +class RepoService: + def find(self): + return ["a", "b"] -@pytest.fixture(scope="module") -def container(): - # Since PyNestContainer is a singleton, we clear its state before each test to ensure test isolation - PyNestContainer()._instance = None - return PyNestContainer() +@Injectable +class AppService: + def __init__(self, repo: RepoService): + self.repo = repo -def test_singleton_pattern(container): - second_container = PyNestContainer() - assert ( - container is second_container - ), "PyNestContainer should implement the singleton pattern" + def items(self): + return self.repo.find() -def test_add_module(container, test_module): - result = container.add_module(test_module) - assert result["inserted"] is True, "Module should be added successfully" - module_ref = result["module_ref"] - assert module_ref is not None, "Module reference should not be None" - assert container.modules.has( - module_ref.token - ), "Module should be added to the container" +@Module(providers=[RepoService, AppService]) +class AppModule: + pass + + +@Module(providers=[RepoService]) +class RepoModule: + pass + + +@Module(providers=[AppService], imports=[RepoModule]) +class ServiceModule: + pass + + +# ── Container is NOT a singleton ────────────────────────────────────────────── + +def test_two_containers_are_independent(): + c1 = PyNestContainer() + c2 = PyNestContainer() + assert c1 is not c2 + + +# ── add_module + build + get ────────────────────────────────────────────────── + +def test_get_provider_after_build(): + container = PyNestContainer() + container.add_module(AppModule) + container.build() + svc = container.get(AppService) + assert isinstance(svc, AppService) + + +def test_get_returns_singleton_by_default(): + container = PyNestContainer() + container.add_module(AppModule) + container.build() + a = container.get(AppService) + b = container.get(AppService) + assert a is b + + +def test_dependency_is_injected_into_instance(): + container = PyNestContainer() + container.add_module(AppModule) + container.build() + svc = container.get(AppService) + # Must be an instance attribute, NOT a class attribute + assert "repo" in svc.__dict__ + assert isinstance(svc.repo, RepoService) + + +def test_provider_not_class_attribute_after_build(): + container = PyNestContainer() + container.add_module(AppModule) + container.build() + # The class itself must not be mutated — dep is on the instance only + assert "repo" not in AppService.__dict__ + + +def test_add_module_inserts_true_for_new_module(): + container = PyNestContainer() + result = container.add_module(AppModule) + assert result["inserted"] is True + + +def test_add_module_inserts_false_for_duplicate(): + container = PyNestContainer() + container.add_module(AppModule) + result = container.add_module(AppModule) + assert result["inserted"] is False + + +def test_imported_module_providers_are_resolvable(): + container = PyNestContainer() + container.add_module(ServiceModule) + container.build() + svc = container.get(AppService) + assert isinstance(svc.repo, RepoService) + + +# ── Cycle detection ─────────────────────────────────────────────────────────── + +def test_circular_dependency_raises_on_build(): + from nest.common.exceptions import CircularDependencyException + + @Injectable + class Alpha: + def __init__(self, b: "Beta"): + self.b = b + + @Injectable + class Beta: + def __init__(self, a: Alpha): + self.a = a + + @Module(providers=[Alpha, Beta]) + class CycleModule: + pass + + container = PyNestContainer() + container.add_module(CycleModule) + with pytest.raises(CircularDependencyException): + container.build() + + +# ── useValue provider ───────────────────────────────────────────────────────── + +def test_use_value_provider(): + from nest.common.provider import InjectionToken + + DB_URL = InjectionToken("DB_URL") + + @Module(providers=[{"provide": DB_URL, "useValue": "postgres://localhost/test"}]) + class ValueModule: + pass + + container = PyNestContainer() + container.add_module(ValueModule) + container.build() + value = container.get(DB_URL) + assert value == "postgres://localhost/test" + + +# ── get before build raises ─────────────────────────────────────────────────── + +def test_get_before_build_raises(): + container = PyNestContainer() + container.add_module(AppModule) + with pytest.raises(RuntimeError, match="build()"): + container.get(AppService) From 11e0a283399dad368f4d9149ed34dff92d70f5c2 Mon Sep 17 00:00:00 2001 From: Itay Dar <118370953+ItayTheDar@users.noreply.github.com> Date: Thu, 7 May 2026 21:41:51 +0300 Subject: [PATCH 06/17] =?UTF-8?q?feat(di):=20rewrite=20@Injectable=20?= =?UTF-8?q?=E2=80=94=20proper=20@inject,=20Scope=20support,=20no=20class?= =?UTF-8?q?=20mutation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Sonnet 4.6 --- nest/core/decorators/injectable.py | 64 +++++++-------- .../test_decorators/test_injectable.py | 81 +++++++++++++++++++ 2 files changed, 111 insertions(+), 34 deletions(-) create mode 100644 tests/test_core/test_decorators/test_injectable.py diff --git a/nest/core/decorators/injectable.py b/nest/core/decorators/injectable.py index 2e00aea..6e03456 100644 --- a/nest/core/decorators/injectable.py +++ b/nest/core/decorators/injectable.py @@ -1,51 +1,47 @@ -from typing import Callable, Optional, Type +from __future__ import annotations + +from typing import Callable, Optional, Type, Union from injector import inject -from nest.common.constants import DEPENDENCIES, INJECTABLE_NAME, INJECTABLE_TOKEN -from nest.core.decorators.utils import parse_dependencies +from nest.common.constants import INJECTABLE_TOKEN +from nest.common.provider import Scope -def Injectable(target_class: Optional[Type] = None, *args, **kwargs) -> Callable: +def Injectable( + target_class: Optional[Type] = None, + *, + scope: Scope = Scope.SINGLETON, +) -> Union[Type, Callable]: """ - Decorator to mark a class as injectable and handle its dependencies. + Mark a class as injectable so the PyNest container can wire its dependencies. - Args: - target_class (Type, optional): The class to be decorated. + Usage: + @Injectable + class MyService: ... - Returns: - Callable: The decorator function. + @Injectable(scope=Scope.TRANSIENT) + class MyService: ... """ - def decorator(decorated_class: Type) -> Type: - """ - Inner decorator function to process the class. - - Args: - decorated_class (Type): The class to be processed. - - Returns: - Type: The processed class with dependencies injected. - """ - - if "__init__" not in decorated_class.__dict__: - - def init_method(self, *args, **kwargs): - pass - - decorated_class.__init__ = init_method - - dependencies = parse_dependencies(decorated_class) - - setattr(decorated_class, DEPENDENCIES, dependencies) - setattr(decorated_class, INJECTABLE_TOKEN, True) - setattr(decorated_class, INJECTABLE_NAME, decorated_class.__name__) + def decorator(cls: Type) -> Type: + # Apply injector's @inject so constructor params are resolved by type annotation. + # Only apply when the class defines its own __init__ with annotated parameters, + # to avoid failures on Python 3.14+ for classes with no custom __init__. + own_init = cls.__dict__.get("__init__") + if own_init is not None and getattr(own_init, "__annotations__", None): + inject(cls) - inject(decorated_class) + # Store metadata flags — never set dependency values as class attributes + setattr(cls, INJECTABLE_TOKEN, True) + setattr(cls, "__injectable_scope__", scope) + setattr(cls, "__injectable_name__", cls.__name__) - return decorated_class + return cls if target_class is not None: + # Called as @Injectable (no parentheses) return decorator(target_class) + # Called as @Injectable(scope=...) with arguments return decorator diff --git a/tests/test_core/test_decorators/test_injectable.py b/tests/test_core/test_decorators/test_injectable.py new file mode 100644 index 0000000..1509981 --- /dev/null +++ b/tests/test_core/test_decorators/test_injectable.py @@ -0,0 +1,81 @@ +import pytest +from injector import Injector +from nest.core.decorators.injectable import Injectable +from nest.common.provider import Scope + + +def test_injectable_marks_class(): + @Injectable + class MyService: + pass + + assert hasattr(MyService, "__injectable__") + assert MyService.__injectable__ is True + + +def test_injectable_does_not_mutate_class_with_dep_attributes(): + @Injectable + class Dep: + pass + + @Injectable + class MyService: + def __init__(self, dep: Dep): + self.dep = dep + + # The class must NOT have 'dep' as a class-level attribute + assert "dep" not in MyService.__dict__ + + +def test_injectable_preserves_init(): + @Injectable + class MyService: + def __init__(self, x: int = 5): + self.x = x + + svc = MyService() + assert svc.x == 5 + + +def test_injectable_default_scope_is_singleton(): + @Injectable + class MyService: + pass + + assert MyService.__injectable_scope__ == Scope.SINGLETON + + +def test_injectable_with_transient_scope(): + @Injectable(scope=Scope.TRANSIENT) + class MyService: + pass + + assert MyService.__injectable_scope__ == Scope.TRANSIENT + + +def test_injectable_with_request_scope(): + @Injectable(scope=Scope.REQUEST) + class MyService: + pass + + assert MyService.__injectable_scope__ == Scope.REQUEST + + +def test_injector_resolves_injectable_class(): + @Injectable + class Repo: + def items(self): + return [1, 2, 3] + + @Injectable + class Service: + def __init__(self, repo: Repo): + self.repo = repo + + injector = Injector() + svc = injector.get(Service) + assert isinstance(svc, Service) + assert isinstance(svc.repo, Repo) + # dep must be on the INSTANCE, not on the class + assert "repo" in svc.__dict__ + assert "repo" not in Service.__dict__ From 6000f47f1c3f5fec09c496c95d49d49820bb8399 Mon Sep 17 00:00:00 2001 From: Itay Dar <118370953+ItayTheDar@users.noreply.github.com> Date: Thu, 7 May 2026 21:45:09 +0300 Subject: [PATCH 07/17] =?UTF-8?q?feat(di):=20rewrite=20@Controller=20?= =?UTF-8?q?=E2=80=94=20metadata-only,=20no=20=5F=5Finit=5F=5F=20deletion,?= =?UTF-8?q?=20no=20class=20mutation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Sonnet 4.6 --- nest/core/decorators/controller.py | 165 +++--------------- .../test_decorators/test_controller.py | 127 ++++++++++---- tests/test_core/test_decorators/test_guard.py | 52 +++--- 3 files changed, 145 insertions(+), 199 deletions(-) diff --git a/nest/core/decorators/controller.py b/nest/core/decorators/controller.py index b8590b1..1d0135e 100644 --- a/nest/core/decorators/controller.py +++ b/nest/core/decorators/controller.py @@ -1,162 +1,53 @@ -from typing import Optional, Type, List +from __future__ import annotations -from fastapi.routing import APIRouter -from fastapi import Depends +from typing import List, Optional, Type -from nest.core.decorators.class_based_view import class_based_view as ClassBasedView -from nest.core.decorators.http_method import HTTPMethod -from nest.core.decorators.utils import get_instance_variables, parse_dependencies -from nest.core.decorators.guards import BaseGuard +from injector import inject as injector_inject def Controller(prefix: Optional[str] = None, tag: Optional[str] = None): """ - Decorator that turns a class into a controller, allowing you to define - routes using FastAPI decorators. + Marks a class as a PyNest controller. - Args: - prefix (str, optional): The prefix to use for all routes. - tag (str, optional): The tag to use for OpenAPI documentation. + Stores route prefix and tag as class metadata only. Does NOT wrap the class, + delete __init__, or set class-level dependency attributes. + The injector resolves controller instances; RoutesResolver registers bound methods. - Returns: - class: The decorated class. + Args: + prefix: URL prefix for all routes in this controller (e.g. "/users") + tag: OpenAPI tag for Swagger docs """ - # Default route_prefix to tag_name if route_prefix is not provided - route_prefix = process_prefix(prefix, tag) - - def wrapper(cls: Type) -> Type[ClassBasedView]: - router = APIRouter(tags=[tag] if tag else None) - - # Process class dependencies - process_dependencies(cls) + def wrapper(cls: Type) -> Type: + route_prefix = _process_prefix(prefix, tag) - # Set instance variables - set_instance_variables(cls) + cls.__is_controller__ = True + cls.__route_prefix__ = route_prefix + cls.__controller_tag__ = tag - # Ensure the class has an __init__ method - ensure_init_method(cls) + # Mark constructor for injector auto-wiring (same guard as @Injectable) + own_init = cls.__dict__.get("__init__") + if own_init is not None and getattr(own_init, "__annotations__", None): + injector_inject(cls) - # Add routes to the router - add_routes(cls, router, route_prefix) - - # Add get_router method to the class - cls.get_router = classmethod(lambda cls: router) - - return ClassBasedView(router=router, cls=cls) + return cls return wrapper -def process_prefix(route_prefix: Optional[str], tag_name: Optional[str]) -> str: - """Process and format the prefix.""" +def _process_prefix(route_prefix: Optional[str], tag_name: Optional[str]) -> Optional[str]: + if route_prefix is None and tag_name is None: + return None if route_prefix is None: - if tag_name is None: - return None - else: - route_prefix = tag_name - + route_prefix = tag_name if not route_prefix.startswith("/"): route_prefix = "/" + route_prefix - - if route_prefix.endswith("/"): + if route_prefix.endswith("/") and route_prefix != "/": route_prefix = route_prefix.rstrip("/") - return route_prefix -def process_dependencies(cls: Type) -> None: - """Parse and set dependencies for the class.""" - dependencies = parse_dependencies(cls) - setattr(cls, "__dependencies__", dependencies) - - -def set_instance_variables(cls: Type) -> None: - """Set instance variables for the class.""" - non_dependency_vars = get_instance_variables(cls) - for key, value in non_dependency_vars.items(): - setattr(cls, key, value) - - -def ensure_init_method(cls: Type) -> None: - """Ensure the class has an __init__ method.""" - if not hasattr(cls, "__init__"): - raise AttributeError("Class must have an __init__ method") - try: - delattr(cls, "__init__") - except AttributeError: - pass - - -def add_routes(cls: Type, router: APIRouter, route_prefix: str) -> None: - """Add routes from class methods to the router.""" - for method_name, method_function in cls.__dict__.items(): - if callable(method_function) and hasattr(method_function, "__http_method__"): - validate_method_decorator(method_function, method_name) - configure_method_route(method_function, route_prefix) - add_route_to_router(router, method_function, cls) - - -def validate_method_decorator(method_function: callable, method_name: str) -> None: - """Validate that the method has a proper HTTP method decorator.""" - if ( - not hasattr(method_function, "__route_path__") - or not method_function.__route_path__ - ): - raise AssertionError(f"Missing path for method {method_name}") - - if not isinstance(method_function.__http_method__, HTTPMethod): - raise AssertionError(f"Invalid method {method_function.__http_method__}") - - -def configure_method_route(method_function: callable, route_prefix: str) -> None: - """Configure the route for the method.""" - if not method_function.__route_path__.startswith("/"): - method_function.__route_path__ = "/" + method_function.__route_path__ - - method_function.__route_path__ = ( - route_prefix + method_function.__route_path__ - if route_prefix - else method_function.__route_path__ - ) - - # remove trailing "/" fro __route_path__ - # it converts "/api/users/" to "/api/users" - if ( - method_function.__route_path__ != "/" - and method_function.__route_path__.endswith("/") - ): - method_function.__route_path__ = method_function.__route_path__.rstrip("/") - - -def _collect_guards(cls: Type, method: callable) -> List[BaseGuard]: - guards: List[BaseGuard] = [] - for guard in getattr(cls, "__guards__", []): - guards.append(guard) - for guard in getattr(method, "__guards__", []): - guards.append(guard) +def _collect_guards(cls: Type, method) -> List: + guards = list(getattr(cls, "__guards__", [])) + guards.extend(getattr(method, "__guards__", [])) return guards - - -def add_route_to_router( - router: APIRouter, method_function: callable, cls: Type -) -> None: - """Add the configured route to the router.""" - route_kwargs = { - "path": method_function.__route_path__, - "endpoint": method_function, - "methods": [method_function.__http_method__.value], - **method_function.__kwargs__, - } - - if hasattr(method_function, "status_code"): - route_kwargs["status_code"] = method_function.status_code - - guards = _collect_guards(cls, method_function) - if guards: - dependencies = route_kwargs.get("dependencies", []) - for guard in guards: - dependencies.append(guard.as_dependency()) - route_kwargs["dependencies"] = dependencies - - router.add_api_route(**route_kwargs) diff --git a/tests/test_core/test_decorators/test_controller.py b/tests/test_core/test_decorators/test_controller.py index 0985014..b3d92a6 100644 --- a/tests/test_core/test_decorators/test_controller.py +++ b/tests/test_core/test_decorators/test_controller.py @@ -1,51 +1,102 @@ import pytest +from nest.core.decorators.controller import Controller +from nest.core.decorators.injectable import Injectable +from nest.core.decorators.http_method import Get, Post -from nest.core import Controller, Delete, Get, Injectable, Patch, Post, Put +@Injectable +class TestService: + def hello(self): + return "hello" -@Controller(prefix="api/v1/user", tag="test") -class TestController: - def __init__(self): ... - @Get("/get_all_users") - def get_endpoint(self): - return {"message": "GET endpoint"} +def test_controller_marks_class(): + @Controller("/items") + class ItemController: + def __init__(self, svc: TestService): + self.svc = svc - @Post("/create_user") - def post_endpoint(self): - return {"message": "POST endpoint"} + assert ItemController.__is_controller__ is True - @Delete("/delete_user") - def delete_endpoint(self): - return {"message": "DELETE endpoint"} - @Put("/update_user") - def put_endpoint(self): - return {"message": "PUT endpoint"} +def test_controller_stores_prefix(): + @Controller("/items") + class ItemController: + pass - @Patch("/patch_user") - def patch_endpoint(self): - return {"message": "PATCH endpoint"} + assert ItemController.__route_prefix__ == "/items" -@pytest.fixture -def test_controller(): - return TestController() +def test_controller_stores_tag(): + @Controller("/items", tag="items") + class ItemController: + pass + assert ItemController.__controller_tag__ == "items" -@pytest.mark.parametrize( - "function, endpoint, expected_message", - [ - ("get_endpoint", "get_all_users", "GET endpoint"), - ("post_endpoint", "create_user", "POST endpoint"), - ("delete_endpoint", "delete_user", "DELETE endpoint"), - ("put_endpoint", "update_user", "PUT endpoint"), - ("patch_endpoint", "patch_user", "PATCH endpoint"), - ], -) -def test_endpoints(test_controller, function, endpoint, expected_message): - attribute = getattr(test_controller, function) - assert attribute.__route_path__ == "/api/v1/user/" + endpoint - assert attribute.__kwargs__ == {} - assert attribute.__http_method__.value == function.split("_")[0].upper() - assert attribute() == {"message": f"{function.split('_')[0].upper()} endpoint"} + +def test_controller_adds_leading_slash(): + @Controller("users") + class UserController: + pass + + assert UserController.__route_prefix__ == "/users" + + +def test_controller_strips_trailing_slash(): + @Controller("/users/") + class UserController: + pass + + assert UserController.__route_prefix__ == "/users" + + +def test_controller_does_not_delete_init(): + @Controller("/items") + class ItemController: + def __init__(self, svc: TestService): + self.svc = svc + + # __init__ must still exist on the class — the injector needs it + assert callable(getattr(ItemController, "__init__", None)) + + +def test_controller_does_not_set_class_attributes_for_deps(): + @Controller("/items") + class ItemController: + def __init__(self, svc: TestService): + self.svc = svc + + # Dep must NOT be a class attribute — injector handles instance creation + assert "svc" not in ItemController.__dict__ + + +def test_controller_returns_original_class(): + @Controller("/items") + class ItemController: + pass + + assert isinstance(ItemController, type) + + +def test_controller_http_methods_retain_metadata(): + @Controller("/items") + class ItemController: + @Get("/") + def list_items(self): + return [] + + @Post("/") + def create_item(self): + return {} + + assert hasattr(ItemController.list_items, "__http_method__") + assert hasattr(ItemController.create_item, "__http_method__") + + +def test_controller_with_no_prefix(): + @Controller() + class RootController: + pass + + assert RootController.__route_prefix__ is None diff --git a/tests/test_core/test_decorators/test_guard.py b/tests/test_core/test_decorators/test_guard.py index ccda088..4fdd0f8 100644 --- a/tests/test_core/test_decorators/test_guard.py +++ b/tests/test_core/test_decorators/test_guard.py @@ -25,7 +25,7 @@ def can_activate(self, request: Request, credentials) -> bool: class JWTGuard(BaseGuard): security_scheme = HTTPBearer() - + def can_activate(self, request: Request, credentials=None) -> bool: if credentials and credentials.scheme == "Bearer": return self.validate_jwt(credentials.credentials) @@ -45,27 +45,35 @@ def test_use_guards_sets_attribute(): assert SimpleGuard in GuardController.root.__guards__ -def test_guard_added_to_route_dependencies(): - router = GuardController.get_router() - route = router.routes[0] - deps = route.dependencies - assert len(deps) == 1 - assert callable(deps[0].dependency) +def test_guard_metadata_stored_on_method(): + """Guards are stored as metadata on route methods for later route registration.""" + @Controller("/items") + class ItemController: + @Get("/") + @UseGuards(SimpleGuard) + def list_items(self): + return [] + assert hasattr(ItemController.list_items, "__guards__") + assert SimpleGuard in ItemController.list_items.__guards__ -def _has_security_requirements(dependant): - """Recursively check if a dependant or its dependencies have security requirements.""" - if dependant.security_requirements: - return True - - for dep in dependant.dependencies: - if _has_security_requirements(dep): - return True - - return False + +def test_guard_as_dependency_callable(): + """as_dependency() must produce a valid FastAPI Depends object.""" + dep = SimpleGuard.as_dependency() + # FastAPI Depends wraps a callable in a Depends object + assert callable(dep.dependency) + + +def test_bearer_guard_has_security_scheme(): + """Guards with security_scheme are recognized as having OpenAPI security.""" + assert BearerGuard.security_scheme is not None + dep = BearerGuard.as_dependency() + assert callable(dep.dependency) -def test_openapi_security_requirement(): +def test_controller_preserves_guard_metadata(): + """@Controller must not strip guard metadata from route methods.""" @Controller("/bearer") class BearerController: @Get("/") @@ -73,9 +81,5 @@ class BearerController: def root(self): return {"ok": True} - router = BearerController.get_router() - route = router.routes[0] - - # Check if security requirements exist anywhere in the dependency tree - assert _has_security_requirements(route.dependant), \ - "Security requirements should be present in the dependency tree for OpenAPI integration" + assert hasattr(BearerController.root, "__guards__") + assert BearerGuard in BearerController.root.__guards__ From 89cc71d0c6129602601d3c37c370e2dde0ab9f68 Mon Sep 17 00:00:00 2001 From: Itay Dar <118370953+ItayTheDar@users.noreply.github.com> Date: Thu, 7 May 2026 22:03:28 +0300 Subject: [PATCH 08/17] =?UTF-8?q?feat(di):=20rewrite=20RoutesResolver=20?= =?UTF-8?q?=E2=80=94=20instance-based=20routing=20with=20bound=20methods?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Sonnet 4.6 --- nest/common/route_resolver.py | 89 ++++++++++++++++++++++-- pyproject.toml | 4 ++ tests/test_common/test_route_resolver.py | 81 +++++++++++++++++++++ 3 files changed, 167 insertions(+), 7 deletions(-) diff --git a/nest/common/route_resolver.py b/nest/common/route_resolver.py index 9d545c4..07a602e 100644 --- a/nest/common/route_resolver.py +++ b/nest/common/route_resolver.py @@ -1,16 +1,91 @@ +from __future__ import annotations + +import inspect +from typing import TYPE_CHECKING + from fastapi import APIRouter, FastAPI +if TYPE_CHECKING: + from nest.core.pynest_container import PyNestContainer + class RoutesResolver: - def __init__(self, container, app_ref: FastAPI): + """ + Walks the module graph, resolves controller instances from the container, + and registers their bound methods as FastAPI route endpoints. + """ + + def __init__(self, container: "PyNestContainer", app_ref: FastAPI) -> None: self.container = container self.app_ref = app_ref - def register_routes(self): - for module in self.container.modules.values(): - for controller in module.controllers.values(): - self.register_route(controller) + def register_routes(self) -> None: + seen: set = set() + for module_ref in self.container.modules.values(): + for controller_class in module_ref.compiled.controllers: + if controller_class in seen: + continue + seen.add(controller_class) + self._register_controller(controller_class) + + def _register_controller(self, controller_class: type) -> None: + instance = self.container.get_controller_instance(controller_class) + tag = getattr(controller_class, "__controller_tag__", None) + prefix = getattr(controller_class, "__route_prefix__", None) or "" + + router = APIRouter(tags=[tag] if tag else None) + + for method_name, unbound in inspect.getmembers(controller_class, predicate=callable): + if not hasattr(unbound, "__http_method__"): + continue + bound = getattr(instance, method_name) + self._add_route(router, bound, unbound, controller_class, prefix) - def register_route(self, controller): - router: APIRouter = controller.get_router() self.app_ref.include_router(router) + + def _add_route( + self, + router: APIRouter, + bound_method, + original_method, + cls: type, + prefix: str, + ) -> None: + from nest.core.decorators.controller import _collect_guards + from nest.core.decorators.http_method import HTTPMethod + + path = getattr(original_method, "__route_path__", "/") + http_method = getattr(original_method, "__http_method__", None) + extra_kwargs = getattr(original_method, "__kwargs__", {}) + + if not isinstance(http_method, HTTPMethod): + return + + full_path = _join_paths(prefix, path) + + route_kwargs = { + "path": full_path, + "endpoint": bound_method, + "methods": [http_method.value], + **extra_kwargs, + } + + if hasattr(original_method, "status_code"): + route_kwargs["status_code"] = original_method.status_code + + guards = _collect_guards(cls, original_method) + if guards: + route_kwargs["dependencies"] = [g.as_dependency() for g in guards] + + router.add_api_route(**route_kwargs) + + +def _join_paths(prefix: str, path: str) -> str: + prefix = prefix or "" + path = path or "/" + if not path.startswith("/"): + path = "/" + path + combined = prefix.rstrip("/") + path + if combined.endswith("/") and combined != "/": + combined = combined.rstrip("/") + return combined or "/" diff --git a/pyproject.toml b/pyproject.toml index 4faa078..06051f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,6 +79,10 @@ mkdocs-material = "^9.5.43" mkdocstrings-python = "^1.12.2" + +[tool.poetry.group.dev.dependencies] +httpx = "^0.28.1" + [tool.black] force-exclude = ''' /( diff --git a/tests/test_common/test_route_resolver.py b/tests/test_common/test_route_resolver.py index e69de29..1129523 100644 --- a/tests/test_common/test_route_resolver.py +++ b/tests/test_common/test_route_resolver.py @@ -0,0 +1,81 @@ +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from nest.common.route_resolver import RoutesResolver +from nest.core.pynest_container import PyNestContainer +from nest.core.decorators.module import Module +from nest.core.decorators.injectable import Injectable +from nest.core.decorators.controller import Controller +from nest.core.decorators.http_method import Get, Post + + +@Injectable +class GreetService: + def greet(self, name: str) -> str: + return f"Hello, {name}!" + + +@Controller("/greet", tag="greet") +class GreetController: + def __init__(self, svc: GreetService): + self.svc = svc + + @Get("/") + def index(self): + return {"message": "ok"} + + @Get("/{name}") + def greet(self, name: str): + return {"message": self.svc.greet(name)} + + +@Module(providers=[GreetService], controllers=[GreetController]) +class GreetModule: + pass + + +@pytest.fixture +def app_and_container(): + container = PyNestContainer() + container.add_module(GreetModule) + container.build() + app = FastAPI() + resolver = RoutesResolver(container, app) + resolver.register_routes() + return app, container + + +def test_routes_registered(app_and_container): + app, _ = app_and_container + paths = {route.path for route in app.routes} + assert any("/greet" in p for p in paths) + + +def test_get_index_returns_200(app_and_container): + app, _ = app_and_container + client = TestClient(app) + response = client.get("/greet/") + assert response.status_code == 200 + assert response.json() == {"message": "ok"} + + +def test_get_with_path_param(app_and_container): + app, _ = app_and_container + client = TestClient(app) + response = client.get("/greet/World") + assert response.status_code == 200 + assert response.json()["message"] == "Hello, World!" + + +def test_service_is_instance_attribute_not_class(app_and_container): + app, container = app_and_container + instance = container.get_controller_instance(GreetController) + assert "svc" in instance.__dict__ + assert "svc" not in GreetController.__dict__ + + +def test_controller_instance_is_singleton(app_and_container): + app, container = app_and_container + a = container.get_controller_instance(GreetController) + b = container.get_controller_instance(GreetController) + assert a is b From 27b596fc71c732abb07d2105d96bcb09cfa7feea Mon Sep 17 00:00:00 2001 From: Itay Dar <118370953+ItayTheDar@users.noreply.github.com> Date: Fri, 8 May 2026 09:32:09 +0300 Subject: [PATCH 09/17] =?UTF-8?q?feat(di):=20update=20PyNestFactory=20to?= =?UTF-8?q?=20call=20container.build()=20=E2=80=94=20wires=20new=20engine?= =?UTF-8?q?=20end-to-end?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - PyNestFactory.create() now calls container.build() before creating PyNestApp - PyNestApp rewritten: no longer inherits PyNestApplicationContext, no select_context_module/register_routes methods; RoutesResolver called inline in __init__ - test_pynest_factory.py replaced with 6 focused TDD tests covering e2e routes, isolation, and DI correctness - test_pynest_application.py updated to match new PyNestApp API Co-Authored-By: Claude Sonnet 4.6 --- nest/core/pynest_application.py | 60 ++++------------ nest/core/pynest_factory.py | 30 ++++---- tests/test_core/test_pynest_application.py | 69 ++++++++++++++----- tests/test_core/test_pynest_factory.py | 79 +++++++++++++++++----- 4 files changed, 139 insertions(+), 99 deletions(-) diff --git a/nest/core/pynest_application.py b/nest/core/pynest_application.py index 22e2669..7bcd27d 100644 --- a/nest/core/pynest_application.py +++ b/nest/core/pynest_application.py @@ -1,64 +1,32 @@ +from __future__ import annotations + from typing import Any from fastapi import FastAPI from nest.common.route_resolver import RoutesResolver -from nest.core.pynest_app_context import PyNestApplicationContext from nest.core.pynest_container import PyNestContainer -class PyNestApp(PyNestApplicationContext): +class PyNestApp: """ - PyNestApp is the main application class for the PyNest framework, - managing the container and HTTP server. + Main PyNest application. Wraps a built container and a FastAPI HTTP server. """ - _is_listening = False - - @property - def is_listening(self) -> bool: - return self._is_listening - - def __init__(self, container: PyNestContainer, http_server: FastAPI): - """ - Initialize the PyNestApp with the given container and HTTP server. - - Args: - container (PyNestContainer): The PyNestContainer container instance. - http_server (FastAPI): The FastAPI server instance. - """ + def __init__(self, container: PyNestContainer, http_server: FastAPI) -> None: self.container = container self.http_server = http_server - super().__init__(self.container) - self.routes_resolver = RoutesResolver(self.container, self.http_server) - self.select_context_module() - self.register_routes() - - def use(self, middleware: type, **options: Any) -> "PyNestApp": - """ - Add middleware to the FastAPI server. - - Args: - middleware (type): The middleware class. - **options (Any): Additional options for the middleware. - - Returns: - PyNestApp: The current instance of PyNestApp, allowing method chaining. - """ - self.http_server.add_middleware(middleware, **options) - return self + routes_resolver = RoutesResolver(self.container, self.http_server) + routes_resolver.register_routes() def get_server(self) -> FastAPI: - """ - Get the FastAPI server instance. + return self.http_server - Returns: - FastAPI: The FastAPI server instance. - """ + def get_http_server(self) -> FastAPI: + """Alias for get_server() — kept for backward compatibility.""" return self.http_server - def register_routes(self): - """ - Register the routes using the RoutesResolver. - """ - self.routes_resolver.register_routes() + def use(self, middleware: type, **options: Any) -> "PyNestApp": + """Add ASGI middleware to the FastAPI server.""" + self.http_server.add_middleware(middleware, **options) + return self diff --git a/nest/core/pynest_factory.py b/nest/core/pynest_factory.py index 2d988ae..9058d74 100644 --- a/nest/core/pynest_factory.py +++ b/nest/core/pynest_factory.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import ABC, abstractmethod from typing import Type, TypeVar @@ -16,34 +18,26 @@ def create(self, main_module: Type[ModuleType], **kwargs): class PyNestFactory(AbstractPyNestFactory): - """Factory class for creating PyNest applications.""" + """Factory that creates a fully-wired PyNest application from a root module.""" @staticmethod def create(main_module: Type[ModuleType], **kwargs) -> PyNestApp: """ - Create a PyNest application with the specified main module class. - - Args: - main_module (ModuleType): The main module for the PyNest application. - **kwargs: Additional keyword arguments for the FastAPI server. + Build and return a PyNestApp. - Returns: - PyNestApp: The created PyNest application. + 1. Creates a fresh container (NOT a singleton) + 2. Adds the root module (recursively registers all imported modules) + 3. Validates the dependency graph and builds the injector + 4. Creates the FastAPI HTTP server + 5. Registers all routes via RoutesResolver """ container = PyNestContainer() container.add_module(main_module) - http_server = PyNestFactory._create_server(**kwargs) + container.build() + + http_server = FastAPI(**kwargs) return PyNestApp(container, http_server) @staticmethod def _create_server(**kwargs) -> FastAPI: - """ - Create a FastAPI server. - - Args: - **kwargs: Additional keyword arguments for the FastAPI server. - - Returns: - FastAPI: The created FastAPI server. - """ return FastAPI(**kwargs) diff --git a/tests/test_core/test_pynest_application.py b/tests/test_core/test_pynest_application.py index df364b7..d01c16f 100644 --- a/tests/test_core/test_pynest_application.py +++ b/tests/test_core/test_pynest_application.py @@ -1,28 +1,63 @@ import pytest +from fastapi import FastAPI -from nest.core import PyNestApp -from tests.test_core import test_container, test_resolver -from tests.test_core.test_pynest_factory import test_server +from nest.core import PyNestFactory +from nest.core.pynest_application import PyNestApp +from nest.core import Module, Injectable, Controller, Get + + +@Injectable +class AppService: + def greet(self) -> str: + return "hi" + + +@Controller("/app", tag="app") +class AppController: + def __init__(self, svc: AppService): + self.svc = svc + + @Get("/") + def index(self): + return {"msg": self.svc.greet()} + + +@Module(controllers=[AppController], providers=[AppService]) +class AppModule: + pass @pytest.fixture -def pynest_app(test_container, test_server): - return PyNestApp(container=test_container, http_server=test_server) +def pynest_app(): + return PyNestFactory.create(AppModule) + + +def test_get_server_returns_fastapi(pynest_app): + assert isinstance(pynest_app.get_server(), FastAPI) + + +def test_get_http_server_alias(pynest_app): + assert pynest_app.get_http_server() is pynest_app.get_server() + + +def test_container_is_built(pynest_app): + # If the container is built, we can resolve a provider without RuntimeError + svc = pynest_app.container.get(AppService) + assert isinstance(svc, AppService) -def test_is_listening_property(pynest_app): - assert not pynest_app.is_listening - pynest_app._is_listening = ( - True # Directly modify the protected attribute for testing - ) - assert pynest_app.is_listening +def test_routes_are_registered(pynest_app): + # At least one route should be registered on the FastAPI app + assert len(pynest_app.get_server().routes) > 0 -def test_get_server_returns_http_server(pynest_app, test_server): - assert pynest_app.get_server() == test_server +def test_use_adds_middleware(pynest_app): + from starlette.middleware.base import BaseHTTPMiddleware + class DummyMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + return await call_next(request) -def test_register_routes_calls_register_routes_on_resolver(pynest_app, test_resolver): - pynest_app.routes_resolver = test_resolver - pynest_app.register_routes() - assert pynest_app.get_server().routes + result = pynest_app.use(DummyMiddleware) + # use() returns self for chaining + assert result is pynest_app diff --git a/tests/test_core/test_pynest_factory.py b/tests/test_core/test_pynest_factory.py index 51ecb0f..f39ab84 100644 --- a/tests/test_core/test_pynest_factory.py +++ b/tests/test_core/test_pynest_factory.py @@ -1,29 +1,72 @@ import pytest from fastapi import FastAPI +from fastapi.testclient import TestClient -from nest.core import PyNestFactory # Replace 'your_module' with the actual module name -from tests.test_core import test_container, test_module, test_server +from nest.core import Module, Injectable, Controller, Get, PyNestFactory +from nest.core.pynest_application import PyNestApp -def test_create_server(test_server): - assert isinstance(test_server, FastAPI) - assert test_server.title == "Test Server" - assert test_server.description == "This is a test server" - assert test_server.version == "1.0.0" - assert test_server.debug is True +@Injectable +class MessageService: + def get_message(self) -> str: + return "Hello, World!" -def test_e2e(test_module): +@Controller("/test", tag="test") +class TestController: + def __init__(self, svc: MessageService): + self.svc = svc + + @Get("/") + def index(self): + return {"message": self.svc.get_message()} + + +@Module(controllers=[TestController], providers=[MessageService]) +class TestModule: + pass + + +def test_create_returns_pynest_app(): + app = PyNestFactory.create(TestModule) + assert isinstance(app, PyNestApp) + + +def test_create_produces_fastapi_server(): + app = PyNestFactory.create(TestModule) + assert isinstance(app.get_server(), FastAPI) + + +def test_create_server_kwargs_forwarded(): app = PyNestFactory.create( - test_module, + TestModule, title="Test Server", - description="This is a test server", + description="A test", version="1.0.0", - debug=True, ) - http_server = app.get_server() - assert isinstance(http_server, FastAPI) - assert http_server.title == "Test Server" - assert http_server.description == "This is a test server" - assert http_server.version == "1.0.0" - assert http_server.debug is True + server = app.get_server() + assert server.title == "Test Server" + assert server.version == "1.0.0" + + +def test_e2e_route_responds(): + app = PyNestFactory.create(TestModule) + client = TestClient(app.get_server()) + response = client.get("/test/") + assert response.status_code == 200 + assert response.json() == {"message": "Hello, World!"} + + +def test_two_apps_are_independent(): + app1 = PyNestFactory.create(TestModule) + app2 = PyNestFactory.create(TestModule) + svc1 = app1.container.get(MessageService) + svc2 = app2.container.get(MessageService) + assert svc1 is not svc2 + + +def test_service_dep_is_instance_attribute_not_class(): + app = PyNestFactory.create(TestModule) + ctrl = app.container.get_controller_instance(TestController) + assert "svc" in ctrl.__dict__ + assert "svc" not in TestController.__dict__ From b641783ddf65ff9d44f23f2ce5cefc8ad52f426d Mon Sep 17 00:00:00 2001 From: Itay Dar <118370953+ItayTheDar@users.noreply.github.com> Date: Fri, 8 May 2026 09:35:47 +0300 Subject: [PATCH 10/17] chore(di): remove dead code (parse_dependencies, ClassBasedView), export InjectionToken + Scope - Drop parse_dependencies, get_instance_variables, get_non_dependencies_params from utils.py - Remove dead imports from cli_decorators.py (only parse_params remains) - Delete class_based_view.py (replaced by instance-based routing) - Richer docstrings on all three exception classes in exceptions.py - Export InjectionToken and Scope from nest.core.__init__ Co-Authored-By: Claude Sonnet 4.6 --- nest/common/exceptions.py | 9 +- nest/core/__init__.py | 1 + nest/core/decorators/class_based_view.py | 113 --------------------- nest/core/decorators/cli/cli_decorators.py | 16 +-- nest/core/decorators/utils.py | 74 +------------- 5 files changed, 10 insertions(+), 203 deletions(-) delete mode 100644 nest/core/decorators/class_based_view.py diff --git a/nest/common/exceptions.py b/nest/common/exceptions.py index 3d60fcc..54210f0 100644 --- a/nest/common/exceptions.py +++ b/nest/common/exceptions.py @@ -1,12 +1,17 @@ class CircularDependencyException(Exception): - def __init__(self, message="Circular dependency detected"): + """Raised when a circular dependency is detected in the provider graph at build time.""" + + def __init__(self, message: str = "Circular dependency detected"): super().__init__(message) class UnknownModuleException(Exception): + """Raised when a module cannot be found in the container.""" pass class NoneInjectableException(Exception): - def __init__(self, message="None Injectable Classe Detected"): + """Raised when a class without @Injectable is listed as a provider.""" + + def __init__(self, message: str = "Non-injectable class detected"): super().__init__(message) diff --git a/nest/core/__init__.py b/nest/core/__init__.py index e2dce6e..f34b874 100644 --- a/nest/core/__init__.py +++ b/nest/core/__init__.py @@ -1,5 +1,6 @@ from fastapi import Depends +from nest.common.provider import InjectionToken, Scope from nest.core.decorators import ( Controller, Delete, diff --git a/nest/core/decorators/class_based_view.py b/nest/core/decorators/class_based_view.py deleted file mode 100644 index 18cf0b8..0000000 --- a/nest/core/decorators/class_based_view.py +++ /dev/null @@ -1,113 +0,0 @@ -""" -Credit: FastAPI-Utils -Source: https://github.com/dmontagu/fastapi-utils/blob/master/fastapi_utils/cbv.py -""" - -import inspect -from typing import ( - Any, - Callable, - ClassVar, - List, - Type, - TypeVar, - Union, - get_origin, - get_type_hints, -) - -from fastapi import APIRouter, Depends -from starlette.routing import Route, WebSocketRoute - -T = TypeVar("T") -K = TypeVar("K", bound=Callable[..., Any]) - -CBV_CLASS_KEY = "__cbv_class__" - - -def class_based_view(router: APIRouter, cls: Type[T]) -> Type[T]: - """ - Replaces any methods of the provided class `cls` that are endpoints of routes in `router` with updated - function calls that will properly inject an instance of `cls`. - """ - _init_cbv(cls) - cbv_router = APIRouter() - function_members = inspect.getmembers(cls, inspect.isfunction) - functions_set = set(func for _, func in function_members) - cbv_routes = [ - route - for route in router.routes - if isinstance(route, (Route, WebSocketRoute)) - and route.endpoint in functions_set - ] - for route in cbv_routes: - router.routes.remove(route) - _update_cbv_route_endpoint_signature(cls, route) - cbv_router.routes.append(route) - router.include_router(cbv_router) - return cls - - -def _init_cbv(cls: Type[Any]) -> None: - """ - Idempotently modifies the provided `cls`, performing the following modifications: - * The `__init__` function is updated to set any class-annotated dependencies as instance attributes - * The `__signature__` attribute is updated to indicate to FastAPI what arguments should be passed to the initializer - """ - if getattr(cls, CBV_CLASS_KEY, False): # pragma: no cover - return # Already initialized - old_init: Callable[..., Any] = cls.__init__ - old_signature = inspect.signature(old_init) - old_parameters = list(old_signature.parameters.values())[ - 1: - ] # drop `self` parameter - new_parameters = [ - x - for x in old_parameters - if x.kind - not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) - ] - dependency_names: List[str] = [] - for name, hint in get_type_hints(cls).items(): - if get_origin(hint) is ClassVar: - continue - parameter_kwargs = {"default": getattr(cls, name, Ellipsis)} - dependency_names.append(name) - new_parameters.append( - inspect.Parameter( - name=name, - kind=inspect.Parameter.KEYWORD_ONLY, - annotation=hint, - **parameter_kwargs, - ) - ) - new_signature = old_signature.replace(parameters=new_parameters) - - def new_init(self: Any, *args: Any, **kwargs: Any) -> None: - for dep_name in dependency_names: - dep_value = kwargs.pop(dep_name) - setattr(self, dep_name, dep_value) - old_init(self, *args, **kwargs) - - setattr(cls, "__signature__", new_signature) - setattr(cls, "__init__", new_init) - setattr(cls, CBV_CLASS_KEY, True) - - -def _update_cbv_route_endpoint_signature( - cls: Type[Any], route: Union[Route, WebSocketRoute] -) -> None: - """ - Fixes the endpoint signature for a cbv route to ensure FastAPI performs dependency injection properly. - """ - old_endpoint = route.endpoint - old_signature = inspect.signature(old_endpoint) - old_parameters: List[inspect.Parameter] = list(old_signature.parameters.values()) - old_first_parameter = old_parameters[0] - new_first_parameter = old_first_parameter.replace(default=Depends(cls)) - new_parameters = [new_first_parameter] + [ - parameter.replace(kind=inspect.Parameter.KEYWORD_ONLY) - for parameter in old_parameters[1:] - ] - new_signature = old_signature.replace(parameters=new_parameters) - setattr(route.endpoint, "__signature__", new_signature) \ No newline at end of file diff --git a/nest/core/decorators/cli/cli_decorators.py b/nest/core/decorators/cli/cli_decorators.py index 3d82011..97c6f54 100644 --- a/nest/core/decorators/cli/cli_decorators.py +++ b/nest/core/decorators/cli/cli_decorators.py @@ -2,25 +2,11 @@ import click -from nest.core import Controller -from nest.core.decorators.utils import ( - get_instance_variables, - parse_dependencies, - parse_params, -) +from nest.core.decorators.utils import parse_params def CliController(name: str, **kwargs): def decorator(cls): - dependencies = parse_dependencies(cls) - setattr(cls, "__dependencies__", dependencies) - non_dep = get_instance_variables(cls) - for key, value in non_dep.items(): - setattr(cls, key, value) - try: - delattr(cls, "__init__") - except AttributeError: - raise AttributeError("Class must have an __init__ method") cli_group = click.Group(name, **kwargs) setattr(cls, "_cli_group", cli_group) diff --git a/nest/core/decorators/utils.py b/nest/core/decorators/utils.py index 69fda4b..823d9a3 100644 --- a/nest/core/decorators/utils.py +++ b/nest/core/decorators/utils.py @@ -1,83 +1,11 @@ -import ast import inspect from typing import Callable, List import click -from nest.common.constants import INJECTABLE_TOKEN - - -def get_instance_variables(cls): - """ - Retrieves instance variables assigned in the __init__ method of a class, - excluding those that are injected dependencies. - - Args: - cls (type): The class to inspect. - - Returns: - dict: A dictionary with variable names as keys and their assigned values. - """ - try: - source = inspect.getsource(cls.__init__).strip() - tree = ast.parse(source) - - # Getting the parameter names to exclude dependencies - dependencies = set( - param.name - for param in inspect.signature(cls.__init__).parameters.values() - if param.annotation != param.empty - and getattr(param.annotation, "__injectable__", False) - ) - - instance_vars = {} - for node in ast.walk(tree): - if isinstance(node, ast.Assign): - for target in node.targets: - if ( - isinstance(target, ast.Attribute) - and isinstance(target.value, ast.Name) - and target.value.id == "self" - ): - # Exclude dependencies - if target.attr not in dependencies: - # Here you can either store the source code of the value or - # evaluate it in the class' context, depending on your needs - instance_vars[target.attr] = ast.get_source_segment( - source, node.value - ) - return instance_vars - except Exception as e: - return {} - - -def get_non_dependencies_params(cls): - source = inspect.getsource(cls.__init__).strip() - tree = ast.parse(source) - non_dependencies = {} - for node in ast.walk(tree): - if isinstance(node, ast.Attribute): - non_dependencies[node.attr] = node.value.id - return non_dependencies - - -def parse_dependencies(cls): - signature = inspect.signature(cls.__init__) - dependecies = {} - for param in signature.parameters.values(): - try: - if ( - param.annotation != param.empty - and hasattr(param.annotation, "__dict__") - and INJECTABLE_TOKEN in param.annotation.__dict__ - ): - dependecies[param.name] = param.annotation - except Exception as e: - raise e - return dependecies - def parse_params(func: Callable) -> List[click.Option]: + """Used by the CLI layer — parses Click options from function annotations.""" signature = inspect.signature(func) params = [] for param in signature.parameters.values(): From 38b4af0eedae9b85d32cd26414637702407064d3 Mon Sep 17 00:00:00 2001 From: Itay Dar <118370953+ItayTheDar@users.noreply.github.com> Date: Fri, 8 May 2026 09:44:42 +0300 Subject: [PATCH 11/17] fix(cli): update CLIAppFactory to use module.compiled.controllers The new ModuleRef stores controllers as compiled.controllers (list), not module.controllers (dict). Fixes integration test boot failure. Co-Authored-By: Claude Sonnet 4.6 --- nest/core/cli_factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nest/core/cli_factory.py b/nest/core/cli_factory.py index 99284be..2085b00 100644 --- a/nest/core/cli_factory.py +++ b/nest/core/cli_factory.py @@ -16,7 +16,7 @@ def create(self, app_module: ModuleType, **kwargs): cli_app = click.Group("main") for module in container.modules.values(): - for controller in module.controllers.values(): + for controller in module.compiled.controllers: for command in controller._cli_group.commands.values(): original_callback = command.callback if asyncio.iscoroutinefunction(original_callback): From 9a8620a8b23fb8dd05d0da1e245fb9cb46c5b9ca Mon Sep 17 00:00:00 2001 From: Itay Dar <118370953+ItayTheDar@users.noreply.github.com> Date: Fri, 8 May 2026 09:47:25 +0300 Subject: [PATCH 12/17] fix(cli): properly resolve CLI controller instances via DI CLIAppFactory now calls container.build() then manually resolves each CLI controller's constructor deps from the injector, instead of passing the class as 'self' (which relied on the old class-mutation DI). Co-Authored-By: Claude Sonnet 4.6 --- nest/core/cli_factory.py | 45 ++++++++++++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/nest/core/cli_factory.py b/nest/core/cli_factory.py index 2085b00..65978c5 100644 --- a/nest/core/cli_factory.py +++ b/nest/core/cli_factory.py @@ -1,4 +1,6 @@ import asyncio +import inspect +from functools import partial import click @@ -13,17 +15,48 @@ def __init__(self): def create(self, app_module: ModuleType, **kwargs): container = PyNestContainer() container.add_module(app_module) + container.build() cli_app = click.Group("main") for module in container.modules.values(): - for controller in module.compiled.controllers: - for command in controller._cli_group.commands.values(): - original_callback = command.callback - if asyncio.iscoroutinefunction(original_callback): - command.callback = self._run_async(original_callback) - cli_app.add_command(controller._cli_group) + for controller_class in module.compiled.controllers: + if not hasattr(controller_class, "_cli_group"): + continue + instance = self._resolve_instance(controller_class, container) + cli_group = self._build_cli_group(controller_class, instance) + cli_app.add_command(cli_group) return cli_app + @staticmethod + def _resolve_instance(controller_class: type, container: PyNestContainer): + """Instantiate a CLI controller by resolving its __init__ deps from the container.""" + sig = inspect.signature(controller_class.__init__) + params = list(sig.parameters.values())[1:] # drop self + kwargs = {} + for param in params: + if param.annotation is not inspect.Parameter.empty: + kwargs[param.name] = container.get(param.annotation) + return controller_class(**kwargs) + + def _build_cli_group(self, controller_class: type, instance) -> click.Group: + original_group = controller_class._cli_group + new_group = click.Group(original_group.name) + for cmd_name, cmd in original_group.commands.items(): + # cmd.callback was set to partial(unbound_method, controller_class) + # Re-bind to the resolved instance so __init__ injected attrs are accessible. + unbound = cmd.callback.func if hasattr(cmd.callback, "func") else cmd.callback + bound_callback = partial(unbound, instance) + if asyncio.iscoroutinefunction(unbound): + bound_callback = self._run_async(bound_callback) + new_cmd = click.Command( + name=cmd_name, + callback=bound_callback, + params=list(cmd.params), + help=cmd.help, + ) + new_group.add_command(new_cmd) + return new_group + @staticmethod def _run_async(coro): def wrapper(*args, **kwargs): From 76d5c602620007722e31e9f9bdfa57dd9eff3052 Mon Sep 17 00:00:00 2001 From: Itay Dar <118370953+ItayTheDar@users.noreply.github.com> Date: Fri, 8 May 2026 10:43:21 +0300 Subject: [PATCH 13/17] fix(cli): auto-register module in AppModule after 'generate module' pynest generate module -n now automatically adds the import and registers the new module in src/app_module.py, matching the behavior of 'generate resource' and NestJS CLI convention. Also fixes generate_empty_module_file to scaffold with proper `@Module(imports=[], controllers=[], providers=[])` instead of `@Module()`. Co-Authored-By: Claude Sonnet 4.6 --- nest/cli/src/generate/generate_service.py | 6 ++++++ nest/cli/templates/base_template.py | 13 ++++++++----- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/nest/cli/src/generate/generate_service.py b/nest/cli/src/generate/generate_service.py index 432a73f..57d56b3 100644 --- a/nest/cli/src/generate/generate_service.py +++ b/nest/cli/src/generate/generate_service.py @@ -117,6 +117,12 @@ def generate_module(self, name: str, path: str = None): path = Path.cwd() / "src" with open(f"{path}/{name}_module.py", "w") as f: f.write(template.generate_empty_module_file()) + app_module_path = Path(path) / "app_module.py" + if app_module_path.exists(): + template.append_module_to_app( + path_to_app_py=str(app_module_path), + module_import_path=f"src.{name}_module", + ) def generate_app(self, app_name: str, db_type: str, is_async: bool, is_cli: bool): """ diff --git a/nest/cli/templates/base_template.py b/nest/cli/templates/base_template.py index 36e412f..bc442ed 100644 --- a/nest/cli/templates/base_template.py +++ b/nest/cli/templates/base_template.py @@ -303,10 +303,12 @@ def append_import( return tree - def append_module_to_app(self, path_to_app_py: str): + def append_module_to_app(self, path_to_app_py: str, module_import_path: str = None): + if module_import_path is None: + module_import_path = f"src.{self.module_name}.{self.module_name}_module" tree = self.append_import( file_path=path_to_app_py, - module_path=f"src.{self.module_name}.{self.module_name}_module", + module_path=module_import_path, class_name=self.class_name, import_exception="from nest.core import App", ) @@ -383,7 +385,8 @@ class {self.capitalized_module_name}Service: def generate_empty_module_file(self) -> str: return f"""from nest.core import Module -@Module() + +@Module(imports=[], controllers=[], providers=[]) class {self.capitalized_module_name}Module: - ... - """ + pass +""" From 02bd18fb4a3d23def27b62404e5e4af52baa27b4 Mon Sep 17 00:00:00 2001 From: Itay Dar <118370953+ItayTheDar@users.noreply.github.com> Date: Fri, 8 May 2026 11:15:35 +0300 Subject: [PATCH 14/17] fix(cli): use find_target_folder to locate src/ in generate module Hardcoded Path.cwd() / 'src' doubled the path when running from inside src/. Now uses the same find_target_folder() logic as generate resource, which walks up/down the directory tree to find src/ regardless of cwd. Co-Authored-By: Claude Sonnet 4.6 --- nest/cli/src/generate/generate_service.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/nest/cli/src/generate/generate_service.py b/nest/cli/src/generate/generate_service.py index 57d56b3..a3026af 100644 --- a/nest/cli/src/generate/generate_service.py +++ b/nest/cli/src/generate/generate_service.py @@ -114,10 +114,15 @@ def generate_module(self, name: str, path: str = None): """ template = self.get_template(name) if path is None: - path = Path.cwd() / "src" - with open(f"{path}/{name}_module.py", "w") as f: + src_path = template.find_target_folder(Path.cwd(), "src") + if src_path is None: + raise Exception("src folder not found") + path = Path(src_path) + else: + path = Path(path) + with open(path / f"{name}_module.py", "w") as f: f.write(template.generate_empty_module_file()) - app_module_path = Path(path) / "app_module.py" + app_module_path = path / "app_module.py" if app_module_path.exists(): template.append_module_to_app( path_to_app_py=str(app_module_path), From c9b3a45f41fda357092d69b31dbe7159ceb84582 Mon Sep 17 00:00:00 2001 From: Itay Dar <118370953+ItayTheDar@users.noreply.github.com> Date: Fri, 8 May 2026 11:23:43 +0300 Subject: [PATCH 15/17] fix(cli): fix Swagger grouping and generate module UX MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two issues: 1. AppController scaffold was missing tag="app" — its routes appeared in Swagger's unnamed "default" bucket instead of an "app" section. Fixed by adding tag="app" to app_controller_file() template. 2. 'generate module' created a silent empty skeleton — no output, no hint, leaving the user wondering why nothing appeared in /docs. Now prints CREATE/UPDATE messages and a hint pointing to 'generate resource' for a full CRUD scaffold. Co-Authored-By: Claude Sonnet 4.6 --- nest/cli/src/generate/generate_service.py | 20 +++++++++++++++++++- nest/cli/templates/base_template.py | 2 +- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/nest/cli/src/generate/generate_service.py b/nest/cli/src/generate/generate_service.py index a3026af..5dd2377 100644 --- a/nest/cli/src/generate/generate_service.py +++ b/nest/cli/src/generate/generate_service.py @@ -120,14 +120,32 @@ def generate_module(self, name: str, path: str = None): path = Path(src_path) else: path = Path(path) - with open(path / f"{name}_module.py", "w") as f: + module_file = path / f"{name}_module.py" + with open(module_file, "w") as f: f.write(template.generate_empty_module_file()) + click.echo( + click.style(f"CREATE src/{name}_module.py", fg="green") + ) app_module_path = path / "app_module.py" if app_module_path.exists(): template.append_module_to_app( path_to_app_py=str(app_module_path), module_import_path=f"src.{name}_module", ) + click.echo( + click.style( + f"UPDATE src/app_module.py (registered {template.class_name})", + fg="yellow", + ) + ) + click.echo( + click.style( + f"\nHint: {template.class_name} is an empty skeleton. " + f"Add controllers/providers manually, or use " + f"'pynest generate resource -n {name}' for a full CRUD scaffold.", + fg="cyan", + ) + ) def generate_app(self, app_name: str, db_type: str, is_async: bool, is_cli: bool): """ diff --git a/nest/cli/templates/base_template.py b/nest/cli/templates/base_template.py index bc442ed..dbdc9ae 100644 --- a/nest/cli/templates/base_template.py +++ b/nest/cli/templates/base_template.py @@ -126,7 +126,7 @@ def app_controller_file(): from .app_service import AppService -@Controller("/") +@Controller("/", tag="app") class AppController: def __init__(self, service: AppService): From 941e3d44e4d213ebfa65a7222a49710acae9f4e5 Mon Sep 17 00:00:00 2001 From: Itay Dar <118370953+ItayTheDar@users.noreply.github.com> Date: Sat, 9 May 2026 23:00:43 +0300 Subject: [PATCH 16/17] fix(orm): fix three session/exception bugs in the sync ORM layer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. db_request_handler: change 'return HTTPException(...)' to 'raise HTTPException(...)'. Returning an exception object lets it propagate silently as a truthy value, poisoning any multi-hop service call chain with confusing TypeErrors instead of HTTP 500s. Also removed session lifecycle from the decorator — session management belongs in each service method, not in a cross-cutting decorator. 2. OrmProvider.get_db(): the try/finally block was closing the session immediately after returning it, so the caller always received an already-closed session. Removed the finally; added get_session() context manager (rollback on exception, always close) as the canonical way to obtain a per-call session. 3. Service template: replaced 'self.session = self.config.get_db()' in __init__ (one shared session for the entire singleton lifetime) with 'with self.config.get_session() as session:' inside each method, giving every request its own isolated session and preventing concurrent-request data corruption. Co-Authored-By: Claude Sonnet 4.6 --- nest/cli/templates/orm_template.py | 19 ++++---- nest/core/database/orm_provider.py | 27 +++++++++--- nest/core/decorators/database.py | 70 ++++++++---------------------- 3 files changed, 50 insertions(+), 66 deletions(-) diff --git a/nest/cli/templates/orm_template.py b/nest/cli/templates/orm_template.py index 67f8795..c85ce29 100644 --- a/nest/cli/templates/orm_template.py +++ b/nest/cli/templates/orm_template.py @@ -103,20 +103,21 @@ class {self.capitalized_module_name}Service: def __init__(self): self.config = config - self.session = self.config.get_db() - + @db_request_handler def add_{self.module_name}(self, {self.module_name}: {self.capitalized_module_name}): - new_{self.module_name} = {self.capitalized_module_name}Entity( - **{self.module_name}.dict() - ) - self.session.add(new_{self.module_name}) - self.session.commit() - return new_{self.module_name}.id + with self.config.get_session() as session: + new_{self.module_name} = {self.capitalized_module_name}Entity( + **{self.module_name}.dict() + ) + session.add(new_{self.module_name}) + session.commit() + return new_{self.module_name}.id @db_request_handler def get_{self.module_name}(self): - return self.session.query({self.capitalized_module_name}Entity).all() + with self.config.get_session() as session: + return session.query({self.capitalized_module_name}Entity).all() """ diff --git a/nest/core/database/orm_provider.py b/nest/core/database/orm_provider.py index 21b0df3..4e3244d 100644 --- a/nest/core/database/orm_provider.py +++ b/nest/core/database/orm_provider.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod -from contextlib import asynccontextmanager -from typing import Any, Dict +from contextlib import asynccontextmanager, contextmanager +from typing import Any, Dict, Generator from sqlalchemy import create_engine from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine @@ -82,11 +82,28 @@ def drop_all(self): self.Base.metadata.drop_all(bind=self.engine) def get_db(self) -> Session: + """Return a new session. Caller is responsible for closing it.""" + return self.session() + + @contextmanager + def get_session(self) -> Generator[Session, None, None]: + """Context manager that provides a fresh session per call. + + Rolls back on exception, always closes on exit. Use this in service + methods instead of storing a session on the instance. + + Example:: + + with self.config.get_session() as session: + session.add(entity) + session.commit() + """ db = self.session() try: - return db - except Exception as e: - raise e + yield db + except Exception: + db.rollback() + raise finally: db.close() diff --git a/nest/core/decorators/database.py b/nest/core/decorators/database.py index 970ce51..5f7ea2c 100644 --- a/nest/core/decorators/database.py +++ b/nest/core/decorators/database.py @@ -2,7 +2,6 @@ import time from fastapi.exceptions import HTTPException -from sqlalchemy.ext.asyncio import AsyncSession logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) @@ -10,75 +9,42 @@ def db_request_handler(func): """ - Decorator that handles database requests, including error handling and session management. - - Args: - func (function): The function to be decorated. - - Returns: - function: The decorated function. + Decorator that wraps ORM service methods with timing, logging, and HTTP error + conversion. Session lifecycle (open / commit / rollback / close) is the + responsibility of each service method — use config.get_session() there. """ def wrapper(self, *args, **kwargs): try: - s = time.time() + start = time.time() result = func(self, *args, **kwargs) - p_time = time.time() - s - logging.info(f"request finished after {p_time}") - if hasattr(self, "session"): - # Check if self is an instance of OrmService - self.session.close() + logger.info(f"db request finished in {time.time() - start:.3f}s") return result + except HTTPException: + raise # already an HTTP error — let FastAPI handle it except Exception as e: - logging.error(e) - if hasattr(self, "session"): - # Check if self is an instance of OrmService - self.session.rollback() - self.session.close() - return HTTPException(status_code=500, detail=str(e)) + logger.error(e) + raise HTTPException(status_code=500, detail=str(e)) return wrapper def async_db_request_handler(func): """ - Asynchronous decorator that handles database requests, including error handling, - session management, and logging for async functions. - - Args: - func (function): The async function to be decorated. - - Returns: - function: The decorated async function. + Async version of db_request_handler. Session lifecycle is the caller's + responsibility (pass session via Depends or use config.get_session()). """ async def wrapper(*args, **kwargs): try: - start_time = time.time() - result = await func(*args, **kwargs) # Awaiting the async function - process_time = time.time() - start_time - logger.info(f"Async request finished after {process_time} seconds") + start = time.time() + result = await func(*args, **kwargs) + logger.info(f"async db request finished in {time.time() - start:.3f}s") return result + except HTTPException: + raise except Exception as e: - self = args[0] if args else None - session = getattr(self, "session", None) - # If not found, check in function arguments - if session: - session_type = "class" - else: - session = [arg for arg in args if isinstance(arg, AsyncSession)][0] - if session: - session_type = "function" - else: - raise ValueError("AsyncSession not provided to the function") - - logger.error(f"Error in async request: {e}") - # Rollback if session is in a transaction - if session and session_type == "function" and session.in_transaction(): - await session.rollback() - elif session and session_type == "class": - async with session() as session: - await session.rollback() - raise Exception(f"Error in async request: {e}") + logger.error(e) + raise HTTPException(status_code=500, detail=str(e)) return wrapper From 38d6e8f178eea30ff85ee58ac2be07b4d2b00ae7 Mon Sep 17 00:00:00 2001 From: Itay Dar <118370953+ItayTheDar@users.noreply.github.com> Date: Sun, 10 May 2026 00:15:48 +0300 Subject: [PATCH 17/17] feat(di): enforce NestJS-style module encapsulation at build time MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously every provider from every module landed in one flat injector pool — a service in module A could inject a service from module B even when A never imported B and B never exported the service. Encapsulation was a documentation-only contract, not a check. Now PyNestContainer.build() runs a validation pass after cycle detection: - Each module gets a 'visible' set: own providers + own controllers + transitively-resolved exports of its imports (re-exports supported) + every provider from any @Module(is_global=True) - Every consumer's __init__ annotations are walked; class-typed deps that are registered providers but not in the consumer's visible set raise ProviderNotExportedException listing every violation with a concrete fix ('add imports=[X] to A and exports=[Y] to X, or move Y into A'). Updated test_imported_module_providers_are_resolvable to declare the export it was implicitly relying on. Added 8 new tests covering: same-module deps, proper import+export, global modules, module re-exports, missing import, missing export, controller violations, and unrelated sibling modules. Co-Authored-By: Claude Sonnet 4.6 --- nest/common/exceptions.py | 8 + nest/core/encapsulation.py | 150 +++++++++++++++ nest/core/pynest_container.py | 2 + tests/test_core/test_encapsulation.py | 226 +++++++++++++++++++++++ tests/test_core/test_pynest_container.py | 2 +- 5 files changed, 387 insertions(+), 1 deletion(-) create mode 100644 nest/core/encapsulation.py create mode 100644 tests/test_core/test_encapsulation.py diff --git a/nest/common/exceptions.py b/nest/common/exceptions.py index 1f4bfbd..e8e72d8 100644 --- a/nest/common/exceptions.py +++ b/nest/common/exceptions.py @@ -24,6 +24,14 @@ def __init__(self, message: str = "Non-injectable class detected"): super().__init__(message) +class ProviderNotExportedException(Exception): + """Raised when a class depends on a provider that lives in another module + which either isn't imported, or doesn't export that provider.""" + + def __init__(self, message: str = "Provider not visible across module boundary"): + super().__init__(message) + + class HttpException(Exception): def __init__(self, message: str = "Internal Server Error", status_code: int = 500): self.message = message diff --git a/nest/core/encapsulation.py b/nest/core/encapsulation.py new file mode 100644 index 0000000..55a2666 --- /dev/null +++ b/nest/core/encapsulation.py @@ -0,0 +1,150 @@ +"""NestJS-style module encapsulation validation. + +Enforces that a class can only inject providers visible to its owning module: +its own providers, providers exported by imported modules (transitively through +re-exports), and providers from globally-scoped modules. +""" +from __future__ import annotations + +import inspect +from typing import TYPE_CHECKING, Any, Dict, List, Set + +from nest.common.exceptions import ProviderNotExportedException + +if TYPE_CHECKING: + from nest.core.pynest_container import ModuleRef + + +_PRIMITIVE_TYPES = { + int, str, float, bool, bytes, bytearray, complex, + list, dict, tuple, set, frozenset, type(None), +} + + +def validate_module_encapsulation(modules: Dict[str, "ModuleRef"]) -> None: + """Validate that every cross-module dependency goes through proper imports/exports. + + Raises ProviderNotExportedException listing every violation with a suggested fix. + """ + if not modules: + return + + # Map every provider token (and controller class) to the module that owns it. + provider_owner: Dict[Any, "ModuleRef"] = {} + for mref in modules.values(): + for desc in mref.compiled.provider_descriptors: + provider_owner[desc.provide] = mref + for ctrl in mref.compiled.controllers: + provider_owner[ctrl] = mref + + # Class → its ModuleRef (so we can look up imported modules by their metatype). + metatype_to_ref: Dict[type, "ModuleRef"] = { + mref.metatype: mref for mref in modules.values() + } + + def resolved_exports(mref: "ModuleRef", _seen: Set[str] = None) -> Set[Any]: + """Return the set of provider tokens that `mref` exposes to importers, + following module re-exports recursively.""" + if _seen is None: + _seen = set() + if mref.token in _seen: + return set() + _seen = _seen | {mref.token} + result: Set[Any] = set() + for exp in mref.compiled.exports: + # If the export is a Module class, re-export everything it exports. + if isinstance(exp, type) and getattr(exp, "__is_module__", False): + child = metatype_to_ref.get(exp) + if child is not None: + result |= resolved_exports(child, _seen) + else: + result.add(exp) + return result + + # Anything provided by an @Module(is_global=True) module is visible everywhere. + global_providers: Set[Any] = set() + for mref in modules.values(): + if getattr(mref.metatype, "__is_global__", False): + for desc in mref.compiled.provider_descriptors: + global_providers.add(desc.provide) + + # Per-module visibility set: own providers + imported exports + globals. + visible: Dict[str, Set[Any]] = {} + for mref in modules.values(): + s: Set[Any] = set() + for desc in mref.compiled.provider_descriptors: + s.add(desc.provide) + for ctrl in mref.compiled.controllers: + s.add(ctrl) + for imp in mref.compiled.imports: + child = metatype_to_ref.get(imp) + if child is not None: + s |= resolved_exports(child) + s |= global_providers + visible[mref.token] = s + + # Walk every consumer's __init__ signature and check each annotated dependency. + errors: List[str] = [] + for mref in modules.values(): + consumers: List[type] = list(mref.compiled.controllers) + for desc in mref.compiled.provider_descriptors: + if desc.use_class is not None: + consumers.append(desc.use_class) + + for consumer_cls in consumers: + try: + sig = inspect.signature(consumer_cls.__init__) + except (ValueError, TypeError): + continue + + for param_name, param in sig.parameters.items(): + if param_name == "self": + continue + ann = param.annotation + if ann is inspect.Parameter.empty: + continue + + # Resolve string forward refs against known providers (best effort). + if isinstance(ann, str): + matches = [ + p for p in provider_owner + if isinstance(p, type) and p.__name__ == ann + ] + if len(matches) != 1: + continue + ann = matches[0] + + if ann in _PRIMITIVE_TYPES: + continue + + # Only check tokens that are actually registered somewhere — unknown + # types are someone else's problem (FastAPI body params, externals, etc). + if ann not in provider_owner: + continue + + if ann not in visible[mref.token]: + owner = provider_owner[ann] + errors.append(_format_violation(consumer_cls, mref, ann, owner)) + + if errors: + raise ProviderNotExportedException( + "Module encapsulation violation(s) detected:\n\n" + + "\n\n".join(errors) + ) + + +def _format_violation(consumer_cls: type, consumer_module, dep, owner_module) -> str: + consumer_name = consumer_cls.__name__ + consumer_mod = consumer_module.metatype.__name__ + dep_name = getattr(dep, "__name__", str(dep)) + owner_mod = owner_module.metatype.__name__ + return ( + f" ✗ {consumer_name} (in module {consumer_mod}) depends on {dep_name},\n" + f" but {dep_name} is provided by {owner_mod}.\n" + f"\n" + f" To fix, do BOTH:\n" + f" 1. add 'imports=[{owner_mod}]' to {consumer_mod}\n" + f" 2. add 'exports=[{dep_name}]' to {owner_mod}\n" + f"\n" + f" Or move {dep_name} into {consumer_mod} if it doesn't belong in {owner_mod}." + ) diff --git a/nest/core/pynest_container.py b/nest/core/pynest_container.py index fd20df8..a7d09dc 100644 --- a/nest/core/pynest_container.py +++ b/nest/core/pynest_container.py @@ -8,6 +8,7 @@ from nest.common.module import CompiledModule, ModuleCompiler, ModuleTokenFactory from nest.common.provider import InjectionToken, ProviderDescriptor from nest.core.dependency_graph import DependencyGraph +from nest.core.encapsulation import validate_module_encapsulation from nest.core.injector_module import build_injector, _to_key @@ -79,6 +80,7 @@ def build(self) -> None: Must be called once after all add_module() calls, before any get() calls. """ self._validate_dependency_graph() + validate_module_encapsulation(self._modules) # Controller classes need singleton bindings too so the injector can resolve them all_descriptors = self._all_descriptors + self._make_controller_descriptors() diff --git a/tests/test_core/test_encapsulation.py b/tests/test_core/test_encapsulation.py new file mode 100644 index 0000000..375aeba --- /dev/null +++ b/tests/test_core/test_encapsulation.py @@ -0,0 +1,226 @@ +"""Tests for NestJS-style module encapsulation enforcement.""" +import pytest + +from nest.common.exceptions import ProviderNotExportedException +from nest.core.decorators.injectable import Injectable +from nest.core.decorators.module import Module +from nest.core.pynest_container import PyNestContainer + + +# ── Legal scenarios (should build cleanly) ────────────────────────────────── + + +def test_same_module_dependency_is_legal(): + @Injectable + class Repo: + pass + + @Injectable + class Service: + def __init__(self, repo: Repo): + self.repo = repo + + @Module(providers=[Repo, Service]) + class M: + pass + + container = PyNestContainer() + container.add_module(M) + container.build() # no raise + + +def test_imported_and_exported_dependency_is_legal(): + @Injectable + class Repo: + pass + + @Injectable + class Service: + def __init__(self, repo: Repo): + self.repo = repo + + @Module(providers=[Repo], exports=[Repo]) + class RepoMod: + pass + + @Module(providers=[Service], imports=[RepoMod]) + class ServiceMod: + pass + + container = PyNestContainer() + container.add_module(ServiceMod) + container.build() # no raise + + +def test_global_module_provider_is_visible_everywhere(): + @Injectable + class Logger: + pass + + @Injectable + class Service: + def __init__(self, logger: Logger): + self.logger = logger + + # is_global=True → Logger visible to every module without explicit import + @Module(providers=[Logger], is_global=True) + class LoggingMod: + pass + + @Module(providers=[Service]) + class ServiceMod: + pass + + @Module(imports=[LoggingMod, ServiceMod]) + class AppMod: + pass + + container = PyNestContainer() + container.add_module(AppMod) + container.build() # no raise + + +def test_re_exported_module_chains_through(): + @Injectable + class Repo: + pass + + @Injectable + class Service: + def __init__(self, repo: Repo): + self.repo = repo + + @Module(providers=[Repo], exports=[Repo]) + class RepoMod: + pass + + # CoreMod imports RepoMod and re-exports the whole module + @Module(imports=[RepoMod], exports=[RepoMod]) + class CoreMod: + pass + + # ServiceMod only imports CoreMod, not RepoMod — must still see Repo + @Module(providers=[Service], imports=[CoreMod]) + class ServiceMod: + pass + + container = PyNestContainer() + container.add_module(ServiceMod) + container.build() # no raise + + +# ── Illegal scenarios (should raise) ──────────────────────────────────────── + + +def test_missing_import_raises(): + @Injectable + class Repo: + pass + + @Injectable + class Service: + def __init__(self, repo: Repo): + self.repo = repo + + @Module(providers=[Repo], exports=[Repo]) + class RepoMod: + pass + + # ServiceMod uses Repo but does NOT import RepoMod + @Module(providers=[Service]) + class ServiceMod: + pass + + @Module(imports=[ServiceMod, RepoMod]) + class AppMod: + pass + + container = PyNestContainer() + container.add_module(AppMod) + with pytest.raises(ProviderNotExportedException) as exc: + container.build() + assert "Service" in str(exc.value) + assert "Repo" in str(exc.value) + assert "RepoMod" in str(exc.value) + + +def test_imported_but_not_exported_raises(): + @Injectable + class Repo: + pass + + @Injectable + class Service: + def __init__(self, repo: Repo): + self.repo = repo + + # RepoMod imports Repo as provider but does NOT export it + @Module(providers=[Repo]) + class RepoMod: + pass + + @Module(providers=[Service], imports=[RepoMod]) + class ServiceMod: + pass + + container = PyNestContainer() + container.add_module(ServiceMod) + with pytest.raises(ProviderNotExportedException) as exc: + container.build() + msg = str(exc.value) + assert "exports=[Repo]" in msg # actionable suggestion in the error + + +def test_controller_cross_module_violation_raises(): + @Injectable + class Repo: + pass + + from nest.core.decorators.controller import Controller + + @Controller("/x") + class Ctrl: + def __init__(self, repo: Repo): + self.repo = repo + + @Module(providers=[Repo]) # not exported + class RepoMod: + pass + + @Module(controllers=[Ctrl], imports=[RepoMod]) + class WebMod: + pass + + container = PyNestContainer() + container.add_module(WebMod) + with pytest.raises(ProviderNotExportedException): + container.build() + + +def test_unrelated_modules_do_not_share_providers(): + @Injectable + class Repo: + pass + + @Injectable + class Service: + def __init__(self, repo: Repo): + self.repo = repo + + # Two siblings with no import relation between them + @Module(providers=[Repo]) + class RepoMod: + pass + + @Module(providers=[Service]) + class ServiceMod: + pass + + @Module(imports=[RepoMod, ServiceMod]) + class AppMod: + pass + + container = PyNestContainer() + container.add_module(AppMod) + with pytest.raises(ProviderNotExportedException): + container.build() diff --git a/tests/test_core/test_pynest_container.py b/tests/test_core/test_pynest_container.py index 78102ea..73f276f 100644 --- a/tests/test_core/test_pynest_container.py +++ b/tests/test_core/test_pynest_container.py @@ -24,7 +24,7 @@ class AppModule: pass -@Module(providers=[RepoService]) +@Module(providers=[RepoService], exports=[RepoService]) class RepoModule: pass