diff --git a/nest/cli/src/generate/generate_service.py b/nest/cli/src/generate/generate_service.py index 432a73f..5dd2377 100644 --- a/nest/cli/src/generate/generate_service.py +++ b/nest/cli/src/generate/generate_service.py @@ -114,9 +114,38 @@ def generate_module(self, name: str, path: str = None): """ template = self.get_template(name) if path is None: - path = Path.cwd() / "src" - with open(f"{path}/{name}_module.py", "w") as f: + src_path = template.find_target_folder(Path.cwd(), "src") + if src_path is None: + raise Exception("src folder not found") + path = Path(src_path) + else: + path = Path(path) + module_file = path / f"{name}_module.py" + with open(module_file, "w") as f: f.write(template.generate_empty_module_file()) + click.echo( + click.style(f"CREATE src/{name}_module.py", fg="green") + ) + app_module_path = path / "app_module.py" + if app_module_path.exists(): + template.append_module_to_app( + path_to_app_py=str(app_module_path), + module_import_path=f"src.{name}_module", + ) + click.echo( + click.style( + f"UPDATE src/app_module.py (registered {template.class_name})", + fg="yellow", + ) + ) + click.echo( + click.style( + f"\nHint: {template.class_name} is an empty skeleton. " + f"Add controllers/providers manually, or use " + f"'pynest generate resource -n {name}' for a full CRUD scaffold.", + fg="cyan", + ) + ) def generate_app(self, app_name: str, db_type: str, is_async: bool, is_cli: bool): """ diff --git a/nest/cli/templates/base_template.py b/nest/cli/templates/base_template.py index 36e412f..dbdc9ae 100644 --- a/nest/cli/templates/base_template.py +++ b/nest/cli/templates/base_template.py @@ -126,7 +126,7 @@ def app_controller_file(): from .app_service import AppService -@Controller("/") +@Controller("/", tag="app") class AppController: def __init__(self, service: AppService): @@ -303,10 +303,12 @@ def append_import( return tree - def append_module_to_app(self, path_to_app_py: str): + def append_module_to_app(self, path_to_app_py: str, module_import_path: str = None): + if module_import_path is None: + module_import_path = f"src.{self.module_name}.{self.module_name}_module" tree = self.append_import( file_path=path_to_app_py, - module_path=f"src.{self.module_name}.{self.module_name}_module", + module_path=module_import_path, class_name=self.class_name, import_exception="from nest.core import App", ) @@ -383,7 +385,8 @@ class {self.capitalized_module_name}Service: def generate_empty_module_file(self) -> str: return f"""from nest.core import Module -@Module() + +@Module(imports=[], controllers=[], providers=[]) class {self.capitalized_module_name}Module: - ... - """ + pass +""" diff --git a/nest/cli/templates/orm_template.py b/nest/cli/templates/orm_template.py index 67f8795..c85ce29 100644 --- a/nest/cli/templates/orm_template.py +++ b/nest/cli/templates/orm_template.py @@ -103,20 +103,21 @@ class {self.capitalized_module_name}Service: def __init__(self): self.config = config - self.session = self.config.get_db() - + @db_request_handler def add_{self.module_name}(self, {self.module_name}: {self.capitalized_module_name}): - new_{self.module_name} = {self.capitalized_module_name}Entity( - **{self.module_name}.dict() - ) - self.session.add(new_{self.module_name}) - self.session.commit() - return new_{self.module_name}.id + with self.config.get_session() as session: + new_{self.module_name} = {self.capitalized_module_name}Entity( + **{self.module_name}.dict() + ) + session.add(new_{self.module_name}) + session.commit() + return new_{self.module_name}.id @db_request_handler def get_{self.module_name}(self): - return self.session.query({self.capitalized_module_name}Entity).all() + with self.config.get_session() as session: + return session.query({self.capitalized_module_name}Entity).all() """ diff --git a/nest/common/exceptions.py b/nest/common/exceptions.py index c78e24d..e8e72d8 100644 --- a/nest/common/exceptions.py +++ b/nest/common/exceptions.py @@ -6,16 +6,29 @@ class CircularDependencyException(Exception): - def __init__(self, message="Circular dependency detected"): + """Raised when a circular dependency is detected in the provider graph at build time.""" + + def __init__(self, message: str = "Circular dependency detected"): super().__init__(message) class UnknownModuleException(Exception): + """Raised when a module cannot be found in the container.""" pass class NoneInjectableException(Exception): - def __init__(self, message="None Injectable Classe Detected"): + """Raised when a class without @Injectable is listed as a provider.""" + + def __init__(self, message: str = "Non-injectable class detected"): + super().__init__(message) + + +class ProviderNotExportedException(Exception): + """Raised when a class depends on a provider that lives in another module + which either isn't imported, or doesn't export that provider.""" + + def __init__(self, message: str = "Provider not visible across module boundary"): super().__init__(message) diff --git a/nest/common/module.py b/nest/common/module.py index a877eab..6b0fc86 100644 --- a/nest/common/module.py +++ b/nest/common/module.py @@ -1,11 +1,22 @@ +from __future__ import annotations + import hashlib import random import string import uuid +from dataclasses import dataclass, field from typing import Any, List, Type from uuid import uuid4 -from nest.core import Module +@dataclass +class CompiledModule: + """The result of compiling a @Module-decorated class. Immutable snapshot used by the container.""" + token: str + metatype: Type + imports: List[Type] = field(default_factory=list) + controllers: List[Type] = field(default_factory=list) + exports: List[Any] = field(default_factory=list) + provider_descriptors: List[Any] = field(default_factory=list) class ModulesContainer(dict): @@ -187,18 +198,34 @@ class ModuleCompiler: def __init__(self, module_token_factory: ModuleTokenFactory = ModuleTokenFactory()): self.module_token_factory = module_token_factory - def compile(self, metatype: Type[Any]): - metadata = self.extract_metadata(metatype) - module_type = metadata["type"] - dynamic_metadata = metadata["dynamic_metadata"] - token = self.module_token_factory.create(module_type, dynamic_metadata) - return ModuleFactory(module_type, token, dynamic_metadata) + def compile(self, metatype: Type[Any]) -> CompiledModule: + from nest.common.provider import normalize_provider # local import avoids circular + + if not self.has_module_metadata(metatype): + raise Exception(f"{metatype.__name__} has no metadata found") + + raw_providers = getattr(metatype, "providers", []) or [] + controllers = getattr(metatype, "controllers", []) or [] + imports = getattr(metatype, "imports", []) or [] + exports = getattr(metatype, "exports", []) or [] + + provider_descriptors = [normalize_provider(p) for p in raw_providers] + token = self.module_token_factory.create(metatype) + + return CompiledModule( + token=token, + metatype=metatype, + imports=list(imports), + controllers=list(controllers), + exports=list(exports), + provider_descriptors=provider_descriptors, + ) def extract_metadata(self, metatype) -> dict: + # Kept for backward compat with PyNestApplicationContext.select() metadata = {"type": metatype, "dynamic_metadata": {}} - if not self.has_module_metadata(metatype): - raise Exception(f"{metatype.__name__} as no metadata found") + raise Exception(f"{metatype.__name__} has no metadata found") for props in ["imports", "providers", "controllers", "exports"]: metadata["dynamic_metadata"][props] = getattr(metatype, props, []) return metadata diff --git a/nest/common/provider.py b/nest/common/provider.py new file mode 100644 index 0000000..ed6d698 --- /dev/null +++ b/nest/common/provider.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Callable, List, Optional, Type, Union + + +class Scope(str, Enum): + SINGLETON = "singleton" + TRANSIENT = "transient" + REQUEST = "request" + + +class InjectionToken: + """Named token for injecting non-class values (strings, primitives, configs).""" + + def __init__(self, name: str, description: str = "") -> None: + self.name = name + self.description = description + + def __repr__(self) -> str: + return f"InjectionToken({self.name!r})" + + def __hash__(self) -> int: + return hash(self.name) + + def __eq__(self, other: object) -> bool: + return isinstance(other, InjectionToken) and self.name == other.name + + +@dataclass +class ProviderDescriptor: + """Normalized provider definition. Exactly one of use_class/use_value/use_factory/use_existing must be set.""" + + provide: Union[Type, InjectionToken, str] + use_class: Optional[Type] = None + use_value: Any = None + use_factory: Optional[Callable] = None + use_existing: Optional[Union[Type, InjectionToken]] = None + scope: Scope = Scope.SINGLETON + inject: List[Any] = field(default_factory=list) + + +def normalize_provider( + provider: Union[Type, dict, ProviderDescriptor], +) -> ProviderDescriptor: + """Convert any provider form (class, dict, or ProviderDescriptor) to a ProviderDescriptor.""" + if isinstance(provider, ProviderDescriptor): + return provider + + if isinstance(provider, dict): + provide = provider["provide"] + scope = provider.get("scope", Scope.SINGLETON) + if "useValue" in provider: + return ProviderDescriptor( + provide=provide, use_value=provider["useValue"], scope=scope + ) + if "useClass" in provider: + return ProviderDescriptor( + provide=provide, use_class=provider["useClass"], scope=scope + ) + if "useFactory" in provider: + return ProviderDescriptor( + provide=provide, + use_factory=provider["useFactory"], + inject=provider.get("inject", []), + scope=scope, + ) + if "useExisting" in provider: + return ProviderDescriptor( + provide=provide, use_existing=provider["useExisting"], scope=scope + ) + raise ValueError( + f"Invalid provider descriptor: {provider!r}. " + "Must contain one of: useValue, useClass, useFactory, useExisting" + ) + + if not callable(provider): + raise ValueError( + f"Provider must be a class, dict, or ProviderDescriptor, got {provider!r}" + ) + return ProviderDescriptor(provide=provider, use_class=provider, scope=Scope.SINGLETON) diff --git a/nest/common/route_resolver.py b/nest/common/route_resolver.py index 9d545c4..c3dc509 100644 --- a/nest/common/route_resolver.py +++ b/nest/common/route_resolver.py @@ -1,16 +1,143 @@ -from fastapi import APIRouter, FastAPI +from __future__ import annotations + +import inspect +from typing import TYPE_CHECKING + +from fastapi import APIRouter, FastAPI, Request + +if TYPE_CHECKING: + from nest.core.pynest_container import PyNestContainer class RoutesResolver: - def __init__(self, container, app_ref: FastAPI): + """ + Walks the module graph, resolves controller instances from the container, + and registers their bound methods as FastAPI route endpoints. + """ + + def __init__(self, container: "PyNestContainer", app_ref: FastAPI) -> None: self.container = container self.app_ref = app_ref - def register_routes(self): - for module in self.container.modules.values(): - for controller in module.controllers.values(): - self.register_route(controller) + def register_routes(self) -> None: + seen: set = set() + for module_ref in self.container.modules.values(): + for controller_class in module_ref.compiled.controllers: + if controller_class in seen: + continue + seen.add(controller_class) + self._register_controller(controller_class) + + def _register_controller(self, controller_class: type) -> None: + instance = self.container.get_controller_instance(controller_class) + tag = getattr(controller_class, "__controller_tag__", None) + prefix = getattr(controller_class, "__route_prefix__", None) or "" + + router = APIRouter(tags=[tag] if tag else None) + + for method_name, unbound in inspect.getmembers(controller_class, predicate=callable): + if not hasattr(unbound, "__http_method__"): + continue + bound = getattr(instance, method_name) + self._add_route(router, bound, unbound, controller_class, prefix) - def register_route(self, controller): - router: APIRouter = controller.get_router() self.app_ref.include_router(router) + + def _add_route( + self, + router: APIRouter, + bound_method, + original_method, + cls: type, + prefix: str, + ) -> None: + from nest.core.decorators.controller import _collect_guards + from nest.core.decorators.http_method import HTTPMethod + + path = getattr(original_method, "__route_path__", "/") + http_method = getattr(original_method, "__http_method__", None) + extra_kwargs = getattr(original_method, "__kwargs__", {}) + + if not isinstance(http_method, HTTPMethod): + return + + full_path = _join_paths(prefix, path) + + route_kwargs = { + "path": full_path, + "endpoint": bound_method, + "methods": [http_method.value], + **extra_kwargs, + } + + if hasattr(original_method, "status_code"): + route_kwargs["status_code"] = original_method.status_code + + guards = _collect_guards(cls, original_method) + if guards: + route_kwargs["dependencies"] = [g.as_dependency() for g in guards] + + route_filters = list(getattr(original_method, "__filters__", [])) + controller_filters = list(getattr(cls, "__filters__", [])) + if route_filters or controller_filters: + route_kwargs["endpoint"] = _wrap_with_filters( + bound_method, route_filters + controller_filters + ) + + router.add_api_route(**route_kwargs) + + +def _wrap_with_filters(endpoint, filters) -> callable: + """Wrap a bound-method endpoint with exception filter logic.""" + from nest.common.exceptions import ArgumentsHost + + original_sig = inspect.signature(endpoint) + existing_params = list(original_sig.parameters.values()) + has_request = any(p.name == "request" for p in existing_params) + + if not has_request: + request_param = inspect.Parameter( + "request", + inspect.Parameter.KEYWORD_ONLY, + annotation=Request, + ) + wrapper_sig = original_sig.replace(parameters=existing_params + [request_param]) + else: + wrapper_sig = original_sig + + orig_param_names = {p.name for p in existing_params} + + async def filter_wrapper(*args, **kwargs): + request = kwargs.get("request") + call_kwargs = {k: v for k, v in kwargs.items() if k in orig_param_names} + try: + result = endpoint(*args, **call_kwargs) + if inspect.isawaitable(result): + result = await result + return result + except Exception as exc: + host = ArgumentsHost(request=request) + for raw_filter in filters: + f = raw_filter() if isinstance(raw_filter, type) else raw_filter + caught = getattr(f, "__caught_exceptions__", ()) + if not caught or isinstance(exc, caught): + result = f.catch(exc, host) + if inspect.isawaitable(result): + return await result + return result + raise + + filter_wrapper.__name__ = getattr(endpoint, "__name__", "filter_wrapper") + filter_wrapper.__signature__ = wrapper_sig + return filter_wrapper + + +def _join_paths(prefix: str, path: str) -> str: + prefix = prefix or "" + path = path or "/" + if not path.startswith("/"): + path = "/" + path + combined = prefix.rstrip("/") + path + if combined.endswith("/") and combined != "/": + combined = combined.rstrip("/") + return combined or "/" diff --git a/nest/core/__init__.py b/nest/core/__init__.py index 8e1b620..5690b23 100644 --- a/nest/core/__init__.py +++ b/nest/core/__init__.py @@ -1,5 +1,6 @@ from fastapi import Depends +from nest.common.provider import InjectionToken, Scope from nest.core.decorators import ( Catch, Controller, diff --git a/nest/core/cli_factory.py b/nest/core/cli_factory.py index 99284be..65978c5 100644 --- a/nest/core/cli_factory.py +++ b/nest/core/cli_factory.py @@ -1,4 +1,6 @@ import asyncio +import inspect +from functools import partial import click @@ -13,17 +15,48 @@ def __init__(self): def create(self, app_module: ModuleType, **kwargs): container = PyNestContainer() container.add_module(app_module) + container.build() cli_app = click.Group("main") for module in container.modules.values(): - for controller in module.controllers.values(): - for command in controller._cli_group.commands.values(): - original_callback = command.callback - if asyncio.iscoroutinefunction(original_callback): - command.callback = self._run_async(original_callback) - cli_app.add_command(controller._cli_group) + for controller_class in module.compiled.controllers: + if not hasattr(controller_class, "_cli_group"): + continue + instance = self._resolve_instance(controller_class, container) + cli_group = self._build_cli_group(controller_class, instance) + cli_app.add_command(cli_group) return cli_app + @staticmethod + def _resolve_instance(controller_class: type, container: PyNestContainer): + """Instantiate a CLI controller by resolving its __init__ deps from the container.""" + sig = inspect.signature(controller_class.__init__) + params = list(sig.parameters.values())[1:] # drop self + kwargs = {} + for param in params: + if param.annotation is not inspect.Parameter.empty: + kwargs[param.name] = container.get(param.annotation) + return controller_class(**kwargs) + + def _build_cli_group(self, controller_class: type, instance) -> click.Group: + original_group = controller_class._cli_group + new_group = click.Group(original_group.name) + for cmd_name, cmd in original_group.commands.items(): + # cmd.callback was set to partial(unbound_method, controller_class) + # Re-bind to the resolved instance so __init__ injected attrs are accessible. + unbound = cmd.callback.func if hasattr(cmd.callback, "func") else cmd.callback + bound_callback = partial(unbound, instance) + if asyncio.iscoroutinefunction(unbound): + bound_callback = self._run_async(bound_callback) + new_cmd = click.Command( + name=cmd_name, + callback=bound_callback, + params=list(cmd.params), + help=cmd.help, + ) + new_group.add_command(new_cmd) + return new_group + @staticmethod def _run_async(coro): def wrapper(*args, **kwargs): diff --git a/nest/core/database/orm_provider.py b/nest/core/database/orm_provider.py index 21b0df3..4e3244d 100644 --- a/nest/core/database/orm_provider.py +++ b/nest/core/database/orm_provider.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod -from contextlib import asynccontextmanager -from typing import Any, Dict +from contextlib import asynccontextmanager, contextmanager +from typing import Any, Dict, Generator from sqlalchemy import create_engine from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine @@ -82,11 +82,28 @@ def drop_all(self): self.Base.metadata.drop_all(bind=self.engine) def get_db(self) -> Session: + """Return a new session. Caller is responsible for closing it.""" + return self.session() + + @contextmanager + def get_session(self) -> Generator[Session, None, None]: + """Context manager that provides a fresh session per call. + + Rolls back on exception, always closes on exit. Use this in service + methods instead of storing a session on the instance. + + Example:: + + with self.config.get_session() as session: + session.add(entity) + session.commit() + """ db = self.session() try: - return db - except Exception as e: - raise e + yield db + except Exception: + db.rollback() + raise finally: db.close() diff --git a/nest/core/decorators/class_based_view.py b/nest/core/decorators/class_based_view.py deleted file mode 100644 index 725fa90..0000000 --- a/nest/core/decorators/class_based_view.py +++ /dev/null @@ -1,170 +0,0 @@ -""" -Credit: FastAPI-Utils -Source: https://github.com/dmontagu/fastapi-utils/blob/master/fastapi_utils/cbv.py -""" - -import inspect -from typing import ( - Any, - Callable, - ClassVar, - List, - Type, - TypeVar, - Union, - get_origin, - get_type_hints, -) - -from fastapi import APIRouter, Depends, Request -from starlette.routing import Route, WebSocketRoute - -T = TypeVar("T") -K = TypeVar("K", bound=Callable[..., Any]) - -CBV_CLASS_KEY = "__cbv_class__" - - -def class_based_view(router: APIRouter, cls: Type[T]) -> Type[T]: - """ - Replaces any methods of the provided class `cls` that are endpoints of routes in `router` with updated - function calls that will properly inject an instance of `cls`. - """ - _init_cbv(cls) - cbv_router = APIRouter() - function_members = inspect.getmembers(cls, inspect.isfunction) - functions_set = set(func for _, func in function_members) - cbv_routes = [ - route - for route in router.routes - if isinstance(route, (Route, WebSocketRoute)) - and route.endpoint in functions_set - ] - for route in cbv_routes: - router.routes.remove(route) - _update_cbv_route_endpoint_signature(cls, route) - _wrap_route_with_filters(cls, route) - cbv_router.routes.append(route) - router.include_router(cbv_router) - return cls - - -def _init_cbv(cls: Type[Any]) -> None: - """ - Idempotently modifies the provided `cls`, performing the following modifications: - * The `__init__` function is updated to set any class-annotated dependencies as instance attributes - * The `__signature__` attribute is updated to indicate to FastAPI what arguments should be passed to the initializer - """ - if getattr(cls, CBV_CLASS_KEY, False): # pragma: no cover - return # Already initialized - old_init: Callable[..., Any] = cls.__init__ - old_signature = inspect.signature(old_init) - old_parameters = list(old_signature.parameters.values())[ - 1: - ] # drop `self` parameter - new_parameters = [ - x - for x in old_parameters - if x.kind - not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) - ] - dependency_names: List[str] = [] - for name, hint in get_type_hints(cls).items(): - if get_origin(hint) is ClassVar: - continue - parameter_kwargs = {"default": getattr(cls, name, Ellipsis)} - dependency_names.append(name) - new_parameters.append( - inspect.Parameter( - name=name, - kind=inspect.Parameter.KEYWORD_ONLY, - annotation=hint, - **parameter_kwargs, - ) - ) - new_signature = old_signature.replace(parameters=new_parameters) - - def new_init(self: Any, *args: Any, **kwargs: Any) -> None: - for dep_name in dependency_names: - dep_value = kwargs.pop(dep_name) - setattr(self, dep_name, dep_value) - old_init(self, *args, **kwargs) - - setattr(cls, "__signature__", new_signature) - setattr(cls, "__init__", new_init) - setattr(cls, CBV_CLASS_KEY, True) - - -def _update_cbv_route_endpoint_signature( - cls: Type[Any], route: Union[Route, WebSocketRoute] -) -> None: - """ - Fixes the endpoint signature for a cbv route to ensure FastAPI performs dependency injection properly. - """ - old_endpoint = route.endpoint - old_signature = inspect.signature(old_endpoint) - old_parameters: List[inspect.Parameter] = list(old_signature.parameters.values()) - old_first_parameter = old_parameters[0] - new_first_parameter = old_first_parameter.replace(default=Depends(cls)) - new_parameters = [new_first_parameter] + [ - parameter.replace(kind=inspect.Parameter.KEYWORD_ONLY) - for parameter in old_parameters[1:] - ] - new_signature = old_signature.replace(parameters=new_parameters) - setattr(route.endpoint, "__signature__", new_signature) - - -def _wrap_route_with_filters(cls: Type[Any], route: Union[Route, WebSocketRoute]) -> None: - """Wrap a route endpoint with controller/route-level exception filter logic. - - Called after _update_cbv_route_endpoint_signature so the wrapper inherits - the correct CBV __signature__ (with self as Depends(cls)). - """ - from nest.common.exceptions import ArgumentsHost - - route_filters = list(getattr(route.endpoint, "__filters__", [])) - controller_filters = list(getattr(cls, "__filters__", [])) - if not route_filters and not controller_filters: - return - - original_endpoint = route.endpoint - cbv_signature = getattr(original_endpoint, "__signature__", inspect.signature(original_endpoint)) - - # Inject `request: Request` into the wrapper signature so FastAPI provides it. - existing_params = list(cbv_signature.parameters.values()) - has_request = any(p.name == "request" for p in existing_params) - if not has_request: - request_param = inspect.Parameter( - "request", - inspect.Parameter.KEYWORD_ONLY, - annotation=Request, - ) - wrapper_signature = cbv_signature.replace(parameters=existing_params + [request_param]) - else: - wrapper_signature = cbv_signature - - orig_param_names = {p.name for p in existing_params} - - async def filter_wrapper(*args, **kwargs): - request = kwargs.get("request") - call_kwargs = {k: v for k, v in kwargs.items() if k in orig_param_names} - try: - result = original_endpoint(*args, **call_kwargs) - if inspect.isawaitable(result): - result = await result - return result - except Exception as exc: - host = ArgumentsHost(request=request) - for raw_filter in route_filters + controller_filters: - f = raw_filter() if isinstance(raw_filter, type) else raw_filter - caught = getattr(f, "__caught_exceptions__", ()) - if not caught or isinstance(exc, caught): - result = f.catch(exc, host) - if inspect.isawaitable(result): - return await result - return result - raise - - filter_wrapper.__name__ = getattr(original_endpoint, "__name__", "filter_wrapper") - filter_wrapper.__signature__ = wrapper_signature - route.endpoint = filter_wrapper diff --git a/nest/core/decorators/cli/cli_decorators.py b/nest/core/decorators/cli/cli_decorators.py index 3d82011..97c6f54 100644 --- a/nest/core/decorators/cli/cli_decorators.py +++ b/nest/core/decorators/cli/cli_decorators.py @@ -2,25 +2,11 @@ import click -from nest.core import Controller -from nest.core.decorators.utils import ( - get_instance_variables, - parse_dependencies, - parse_params, -) +from nest.core.decorators.utils import parse_params def CliController(name: str, **kwargs): def decorator(cls): - dependencies = parse_dependencies(cls) - setattr(cls, "__dependencies__", dependencies) - non_dep = get_instance_variables(cls) - for key, value in non_dep.items(): - setattr(cls, key, value) - try: - delattr(cls, "__init__") - except AttributeError: - raise AttributeError("Class must have an __init__ method") cli_group = click.Group(name, **kwargs) setattr(cls, "_cli_group", cli_group) diff --git a/nest/core/decorators/controller.py b/nest/core/decorators/controller.py index b8590b1..1d0135e 100644 --- a/nest/core/decorators/controller.py +++ b/nest/core/decorators/controller.py @@ -1,162 +1,53 @@ -from typing import Optional, Type, List +from __future__ import annotations -from fastapi.routing import APIRouter -from fastapi import Depends +from typing import List, Optional, Type -from nest.core.decorators.class_based_view import class_based_view as ClassBasedView -from nest.core.decorators.http_method import HTTPMethod -from nest.core.decorators.utils import get_instance_variables, parse_dependencies -from nest.core.decorators.guards import BaseGuard +from injector import inject as injector_inject def Controller(prefix: Optional[str] = None, tag: Optional[str] = None): """ - Decorator that turns a class into a controller, allowing you to define - routes using FastAPI decorators. + Marks a class as a PyNest controller. - Args: - prefix (str, optional): The prefix to use for all routes. - tag (str, optional): The tag to use for OpenAPI documentation. + Stores route prefix and tag as class metadata only. Does NOT wrap the class, + delete __init__, or set class-level dependency attributes. + The injector resolves controller instances; RoutesResolver registers bound methods. - Returns: - class: The decorated class. + Args: + prefix: URL prefix for all routes in this controller (e.g. "/users") + tag: OpenAPI tag for Swagger docs """ - # Default route_prefix to tag_name if route_prefix is not provided - route_prefix = process_prefix(prefix, tag) - - def wrapper(cls: Type) -> Type[ClassBasedView]: - router = APIRouter(tags=[tag] if tag else None) - - # Process class dependencies - process_dependencies(cls) + def wrapper(cls: Type) -> Type: + route_prefix = _process_prefix(prefix, tag) - # Set instance variables - set_instance_variables(cls) + cls.__is_controller__ = True + cls.__route_prefix__ = route_prefix + cls.__controller_tag__ = tag - # Ensure the class has an __init__ method - ensure_init_method(cls) + # Mark constructor for injector auto-wiring (same guard as @Injectable) + own_init = cls.__dict__.get("__init__") + if own_init is not None and getattr(own_init, "__annotations__", None): + injector_inject(cls) - # Add routes to the router - add_routes(cls, router, route_prefix) - - # Add get_router method to the class - cls.get_router = classmethod(lambda cls: router) - - return ClassBasedView(router=router, cls=cls) + return cls return wrapper -def process_prefix(route_prefix: Optional[str], tag_name: Optional[str]) -> str: - """Process and format the prefix.""" +def _process_prefix(route_prefix: Optional[str], tag_name: Optional[str]) -> Optional[str]: + if route_prefix is None and tag_name is None: + return None if route_prefix is None: - if tag_name is None: - return None - else: - route_prefix = tag_name - + route_prefix = tag_name if not route_prefix.startswith("/"): route_prefix = "/" + route_prefix - - if route_prefix.endswith("/"): + if route_prefix.endswith("/") and route_prefix != "/": route_prefix = route_prefix.rstrip("/") - return route_prefix -def process_dependencies(cls: Type) -> None: - """Parse and set dependencies for the class.""" - dependencies = parse_dependencies(cls) - setattr(cls, "__dependencies__", dependencies) - - -def set_instance_variables(cls: Type) -> None: - """Set instance variables for the class.""" - non_dependency_vars = get_instance_variables(cls) - for key, value in non_dependency_vars.items(): - setattr(cls, key, value) - - -def ensure_init_method(cls: Type) -> None: - """Ensure the class has an __init__ method.""" - if not hasattr(cls, "__init__"): - raise AttributeError("Class must have an __init__ method") - try: - delattr(cls, "__init__") - except AttributeError: - pass - - -def add_routes(cls: Type, router: APIRouter, route_prefix: str) -> None: - """Add routes from class methods to the router.""" - for method_name, method_function in cls.__dict__.items(): - if callable(method_function) and hasattr(method_function, "__http_method__"): - validate_method_decorator(method_function, method_name) - configure_method_route(method_function, route_prefix) - add_route_to_router(router, method_function, cls) - - -def validate_method_decorator(method_function: callable, method_name: str) -> None: - """Validate that the method has a proper HTTP method decorator.""" - if ( - not hasattr(method_function, "__route_path__") - or not method_function.__route_path__ - ): - raise AssertionError(f"Missing path for method {method_name}") - - if not isinstance(method_function.__http_method__, HTTPMethod): - raise AssertionError(f"Invalid method {method_function.__http_method__}") - - -def configure_method_route(method_function: callable, route_prefix: str) -> None: - """Configure the route for the method.""" - if not method_function.__route_path__.startswith("/"): - method_function.__route_path__ = "/" + method_function.__route_path__ - - method_function.__route_path__ = ( - route_prefix + method_function.__route_path__ - if route_prefix - else method_function.__route_path__ - ) - - # remove trailing "/" fro __route_path__ - # it converts "/api/users/" to "/api/users" - if ( - method_function.__route_path__ != "/" - and method_function.__route_path__.endswith("/") - ): - method_function.__route_path__ = method_function.__route_path__.rstrip("/") - - -def _collect_guards(cls: Type, method: callable) -> List[BaseGuard]: - guards: List[BaseGuard] = [] - for guard in getattr(cls, "__guards__", []): - guards.append(guard) - for guard in getattr(method, "__guards__", []): - guards.append(guard) +def _collect_guards(cls: Type, method) -> List: + guards = list(getattr(cls, "__guards__", [])) + guards.extend(getattr(method, "__guards__", [])) return guards - - -def add_route_to_router( - router: APIRouter, method_function: callable, cls: Type -) -> None: - """Add the configured route to the router.""" - route_kwargs = { - "path": method_function.__route_path__, - "endpoint": method_function, - "methods": [method_function.__http_method__.value], - **method_function.__kwargs__, - } - - if hasattr(method_function, "status_code"): - route_kwargs["status_code"] = method_function.status_code - - guards = _collect_guards(cls, method_function) - if guards: - dependencies = route_kwargs.get("dependencies", []) - for guard in guards: - dependencies.append(guard.as_dependency()) - route_kwargs["dependencies"] = dependencies - - router.add_api_route(**route_kwargs) diff --git a/nest/core/decorators/database.py b/nest/core/decorators/database.py index 970ce51..5f7ea2c 100644 --- a/nest/core/decorators/database.py +++ b/nest/core/decorators/database.py @@ -2,7 +2,6 @@ import time from fastapi.exceptions import HTTPException -from sqlalchemy.ext.asyncio import AsyncSession logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) @@ -10,75 +9,42 @@ def db_request_handler(func): """ - Decorator that handles database requests, including error handling and session management. - - Args: - func (function): The function to be decorated. - - Returns: - function: The decorated function. + Decorator that wraps ORM service methods with timing, logging, and HTTP error + conversion. Session lifecycle (open / commit / rollback / close) is the + responsibility of each service method — use config.get_session() there. """ def wrapper(self, *args, **kwargs): try: - s = time.time() + start = time.time() result = func(self, *args, **kwargs) - p_time = time.time() - s - logging.info(f"request finished after {p_time}") - if hasattr(self, "session"): - # Check if self is an instance of OrmService - self.session.close() + logger.info(f"db request finished in {time.time() - start:.3f}s") return result + except HTTPException: + raise # already an HTTP error — let FastAPI handle it except Exception as e: - logging.error(e) - if hasattr(self, "session"): - # Check if self is an instance of OrmService - self.session.rollback() - self.session.close() - return HTTPException(status_code=500, detail=str(e)) + logger.error(e) + raise HTTPException(status_code=500, detail=str(e)) return wrapper def async_db_request_handler(func): """ - Asynchronous decorator that handles database requests, including error handling, - session management, and logging for async functions. - - Args: - func (function): The async function to be decorated. - - Returns: - function: The decorated async function. + Async version of db_request_handler. Session lifecycle is the caller's + responsibility (pass session via Depends or use config.get_session()). """ async def wrapper(*args, **kwargs): try: - start_time = time.time() - result = await func(*args, **kwargs) # Awaiting the async function - process_time = time.time() - start_time - logger.info(f"Async request finished after {process_time} seconds") + start = time.time() + result = await func(*args, **kwargs) + logger.info(f"async db request finished in {time.time() - start:.3f}s") return result + except HTTPException: + raise except Exception as e: - self = args[0] if args else None - session = getattr(self, "session", None) - # If not found, check in function arguments - if session: - session_type = "class" - else: - session = [arg for arg in args if isinstance(arg, AsyncSession)][0] - if session: - session_type = "function" - else: - raise ValueError("AsyncSession not provided to the function") - - logger.error(f"Error in async request: {e}") - # Rollback if session is in a transaction - if session and session_type == "function" and session.in_transaction(): - await session.rollback() - elif session and session_type == "class": - async with session() as session: - await session.rollback() - raise Exception(f"Error in async request: {e}") + logger.error(e) + raise HTTPException(status_code=500, detail=str(e)) return wrapper diff --git a/nest/core/decorators/injectable.py b/nest/core/decorators/injectable.py index 2e00aea..6e03456 100644 --- a/nest/core/decorators/injectable.py +++ b/nest/core/decorators/injectable.py @@ -1,51 +1,47 @@ -from typing import Callable, Optional, Type +from __future__ import annotations + +from typing import Callable, Optional, Type, Union from injector import inject -from nest.common.constants import DEPENDENCIES, INJECTABLE_NAME, INJECTABLE_TOKEN -from nest.core.decorators.utils import parse_dependencies +from nest.common.constants import INJECTABLE_TOKEN +from nest.common.provider import Scope -def Injectable(target_class: Optional[Type] = None, *args, **kwargs) -> Callable: +def Injectable( + target_class: Optional[Type] = None, + *, + scope: Scope = Scope.SINGLETON, +) -> Union[Type, Callable]: """ - Decorator to mark a class as injectable and handle its dependencies. + Mark a class as injectable so the PyNest container can wire its dependencies. - Args: - target_class (Type, optional): The class to be decorated. + Usage: + @Injectable + class MyService: ... - Returns: - Callable: The decorator function. + @Injectable(scope=Scope.TRANSIENT) + class MyService: ... """ - def decorator(decorated_class: Type) -> Type: - """ - Inner decorator function to process the class. - - Args: - decorated_class (Type): The class to be processed. - - Returns: - Type: The processed class with dependencies injected. - """ - - if "__init__" not in decorated_class.__dict__: - - def init_method(self, *args, **kwargs): - pass - - decorated_class.__init__ = init_method - - dependencies = parse_dependencies(decorated_class) - - setattr(decorated_class, DEPENDENCIES, dependencies) - setattr(decorated_class, INJECTABLE_TOKEN, True) - setattr(decorated_class, INJECTABLE_NAME, decorated_class.__name__) + def decorator(cls: Type) -> Type: + # Apply injector's @inject so constructor params are resolved by type annotation. + # Only apply when the class defines its own __init__ with annotated parameters, + # to avoid failures on Python 3.14+ for classes with no custom __init__. + own_init = cls.__dict__.get("__init__") + if own_init is not None and getattr(own_init, "__annotations__", None): + inject(cls) - inject(decorated_class) + # Store metadata flags — never set dependency values as class attributes + setattr(cls, INJECTABLE_TOKEN, True) + setattr(cls, "__injectable_scope__", scope) + setattr(cls, "__injectable_name__", cls.__name__) - return decorated_class + return cls if target_class is not None: + # Called as @Injectable (no parentheses) return decorator(target_class) + # Called as @Injectable(scope=...) with arguments return decorator diff --git a/nest/core/decorators/utils.py b/nest/core/decorators/utils.py index 69fda4b..823d9a3 100644 --- a/nest/core/decorators/utils.py +++ b/nest/core/decorators/utils.py @@ -1,83 +1,11 @@ -import ast import inspect from typing import Callable, List import click -from nest.common.constants import INJECTABLE_TOKEN - - -def get_instance_variables(cls): - """ - Retrieves instance variables assigned in the __init__ method of a class, - excluding those that are injected dependencies. - - Args: - cls (type): The class to inspect. - - Returns: - dict: A dictionary with variable names as keys and their assigned values. - """ - try: - source = inspect.getsource(cls.__init__).strip() - tree = ast.parse(source) - - # Getting the parameter names to exclude dependencies - dependencies = set( - param.name - for param in inspect.signature(cls.__init__).parameters.values() - if param.annotation != param.empty - and getattr(param.annotation, "__injectable__", False) - ) - - instance_vars = {} - for node in ast.walk(tree): - if isinstance(node, ast.Assign): - for target in node.targets: - if ( - isinstance(target, ast.Attribute) - and isinstance(target.value, ast.Name) - and target.value.id == "self" - ): - # Exclude dependencies - if target.attr not in dependencies: - # Here you can either store the source code of the value or - # evaluate it in the class' context, depending on your needs - instance_vars[target.attr] = ast.get_source_segment( - source, node.value - ) - return instance_vars - except Exception as e: - return {} - - -def get_non_dependencies_params(cls): - source = inspect.getsource(cls.__init__).strip() - tree = ast.parse(source) - non_dependencies = {} - for node in ast.walk(tree): - if isinstance(node, ast.Attribute): - non_dependencies[node.attr] = node.value.id - return non_dependencies - - -def parse_dependencies(cls): - signature = inspect.signature(cls.__init__) - dependecies = {} - for param in signature.parameters.values(): - try: - if ( - param.annotation != param.empty - and hasattr(param.annotation, "__dict__") - and INJECTABLE_TOKEN in param.annotation.__dict__ - ): - dependecies[param.name] = param.annotation - except Exception as e: - raise e - return dependecies - def parse_params(func: Callable) -> List[click.Option]: + """Used by the CLI layer — parses Click options from function annotations.""" signature = inspect.signature(func) params = [] for param in signature.parameters.values(): diff --git a/nest/core/dependency_graph.py b/nest/core/dependency_graph.py new file mode 100644 index 0000000..d053fb4 --- /dev/null +++ b/nest/core/dependency_graph.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from typing import Any, List, Set + + +class CycleError(Exception): + """Raised when a circular dependency is detected in the provider graph.""" + pass + + +class DependencyGraph: + """ + Directed graph where an edge A → B means 'A depends on B'. + Used to detect circular dependencies and determine initialization order. + """ + + def __init__(self) -> None: + self._edges: dict[Any, Set[Any]] = {} + + @property + def nodes(self) -> Set[Any]: + return set(self._edges.keys()) + + def add_node(self, node: Any) -> None: + if node not in self._edges: + self._edges[node] = set() + + def add_dependency(self, dependent: Any, dependency: Any) -> None: + """Record that `dependent` requires `dependency` to be initialized first.""" + self.add_node(dependent) + self.add_node(dependency) + self._edges[dependent].add(dependency) + + def detect_cycles(self) -> List[List[Any]]: + """Return a list of cycles (each cycle is a list of nodes). Empty list = no cycles.""" + visited: Set[Any] = set() + path: List[Any] = [] + path_set: Set[Any] = set() + cycles: List[List[Any]] = [] + + def dfs(node: Any) -> None: + visited.add(node) + path.append(node) + path_set.add(node) + for dep in self._edges.get(node, set()): + if dep not in visited: + dfs(dep) + elif dep in path_set: + idx = path.index(dep) + cycles.append(list(path[idx:]) + [dep]) + path.pop() + path_set.discard(node) + + for node in list(self._edges.keys()): + if node not in visited: + dfs(node) + return cycles + + def topological_sort(self) -> List[Any]: + """Return nodes in initialization order: dependencies come before dependents.""" + visited: Set[Any] = set() + result: List[Any] = [] + + def visit(node: Any) -> None: + if node in visited: + return + visited.add(node) + for dep in self._edges.get(node, set()): + visit(dep) + result.append(node) + + for node in list(self._edges.keys()): + visit(node) + return result + + def validate(self) -> None: + """Raise CycleError if any circular dependencies exist.""" + cycles = self.detect_cycles() + if cycles: + chain = " → ".join( + getattr(n, "__name__", repr(n)) for n in cycles[0] + ) + raise CycleError(f"Circular dependency detected: {chain}") diff --git a/nest/core/encapsulation.py b/nest/core/encapsulation.py new file mode 100644 index 0000000..55a2666 --- /dev/null +++ b/nest/core/encapsulation.py @@ -0,0 +1,150 @@ +"""NestJS-style module encapsulation validation. + +Enforces that a class can only inject providers visible to its owning module: +its own providers, providers exported by imported modules (transitively through +re-exports), and providers from globally-scoped modules. +""" +from __future__ import annotations + +import inspect +from typing import TYPE_CHECKING, Any, Dict, List, Set + +from nest.common.exceptions import ProviderNotExportedException + +if TYPE_CHECKING: + from nest.core.pynest_container import ModuleRef + + +_PRIMITIVE_TYPES = { + int, str, float, bool, bytes, bytearray, complex, + list, dict, tuple, set, frozenset, type(None), +} + + +def validate_module_encapsulation(modules: Dict[str, "ModuleRef"]) -> None: + """Validate that every cross-module dependency goes through proper imports/exports. + + Raises ProviderNotExportedException listing every violation with a suggested fix. + """ + if not modules: + return + + # Map every provider token (and controller class) to the module that owns it. + provider_owner: Dict[Any, "ModuleRef"] = {} + for mref in modules.values(): + for desc in mref.compiled.provider_descriptors: + provider_owner[desc.provide] = mref + for ctrl in mref.compiled.controllers: + provider_owner[ctrl] = mref + + # Class → its ModuleRef (so we can look up imported modules by their metatype). + metatype_to_ref: Dict[type, "ModuleRef"] = { + mref.metatype: mref for mref in modules.values() + } + + def resolved_exports(mref: "ModuleRef", _seen: Set[str] = None) -> Set[Any]: + """Return the set of provider tokens that `mref` exposes to importers, + following module re-exports recursively.""" + if _seen is None: + _seen = set() + if mref.token in _seen: + return set() + _seen = _seen | {mref.token} + result: Set[Any] = set() + for exp in mref.compiled.exports: + # If the export is a Module class, re-export everything it exports. + if isinstance(exp, type) and getattr(exp, "__is_module__", False): + child = metatype_to_ref.get(exp) + if child is not None: + result |= resolved_exports(child, _seen) + else: + result.add(exp) + return result + + # Anything provided by an @Module(is_global=True) module is visible everywhere. + global_providers: Set[Any] = set() + for mref in modules.values(): + if getattr(mref.metatype, "__is_global__", False): + for desc in mref.compiled.provider_descriptors: + global_providers.add(desc.provide) + + # Per-module visibility set: own providers + imported exports + globals. + visible: Dict[str, Set[Any]] = {} + for mref in modules.values(): + s: Set[Any] = set() + for desc in mref.compiled.provider_descriptors: + s.add(desc.provide) + for ctrl in mref.compiled.controllers: + s.add(ctrl) + for imp in mref.compiled.imports: + child = metatype_to_ref.get(imp) + if child is not None: + s |= resolved_exports(child) + s |= global_providers + visible[mref.token] = s + + # Walk every consumer's __init__ signature and check each annotated dependency. + errors: List[str] = [] + for mref in modules.values(): + consumers: List[type] = list(mref.compiled.controllers) + for desc in mref.compiled.provider_descriptors: + if desc.use_class is not None: + consumers.append(desc.use_class) + + for consumer_cls in consumers: + try: + sig = inspect.signature(consumer_cls.__init__) + except (ValueError, TypeError): + continue + + for param_name, param in sig.parameters.items(): + if param_name == "self": + continue + ann = param.annotation + if ann is inspect.Parameter.empty: + continue + + # Resolve string forward refs against known providers (best effort). + if isinstance(ann, str): + matches = [ + p for p in provider_owner + if isinstance(p, type) and p.__name__ == ann + ] + if len(matches) != 1: + continue + ann = matches[0] + + if ann in _PRIMITIVE_TYPES: + continue + + # Only check tokens that are actually registered somewhere — unknown + # types are someone else's problem (FastAPI body params, externals, etc). + if ann not in provider_owner: + continue + + if ann not in visible[mref.token]: + owner = provider_owner[ann] + errors.append(_format_violation(consumer_cls, mref, ann, owner)) + + if errors: + raise ProviderNotExportedException( + "Module encapsulation violation(s) detected:\n\n" + + "\n\n".join(errors) + ) + + +def _format_violation(consumer_cls: type, consumer_module, dep, owner_module) -> str: + consumer_name = consumer_cls.__name__ + consumer_mod = consumer_module.metatype.__name__ + dep_name = getattr(dep, "__name__", str(dep)) + owner_mod = owner_module.metatype.__name__ + return ( + f" ✗ {consumer_name} (in module {consumer_mod}) depends on {dep_name},\n" + f" but {dep_name} is provided by {owner_mod}.\n" + f"\n" + f" To fix, do BOTH:\n" + f" 1. add 'imports=[{owner_mod}]' to {consumer_mod}\n" + f" 2. add 'exports=[{dep_name}]' to {owner_mod}\n" + f"\n" + f" Or move {dep_name} into {consumer_mod} if it doesn't belong in {owner_mod}." + ) diff --git a/nest/core/injector_module.py b/nest/core/injector_module.py new file mode 100644 index 0000000..bf32b20 --- /dev/null +++ b/nest/core/injector_module.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from typing import Any, List + +from injector import Injector, Module as InjectorModule, noscope, singleton + +from nest.common.provider import InjectionToken, ProviderDescriptor, Scope + + +def _injector_scope(scope: Scope): + if scope == Scope.SINGLETON: + return singleton + return noscope # TRANSIENT and REQUEST both use noscope at this layer + + +def _to_key(token: Any) -> Any: + """Convert an InjectionToken or string to an injector-compatible key. + + Both InjectionToken and plain strings are usable directly as injector + binding keys (the injector library accepts any hashable as a key). + """ + return token + + +class PyNestInjectorModule(InjectorModule): + """Translates a list of ProviderDescriptors into injector bindings. + + use_class and use_value providers are bound eagerly inside configure(). + use_factory and use_existing providers are deferred to build_injector() + so that their dependencies (or aliased singletons) are already resolved. + """ + + def __init__(self, descriptors: List[ProviderDescriptor]) -> None: + self._descriptors = [ + d for d in descriptors + if d.use_factory is None and d.use_existing is None + ] + + def configure(self, binder) -> None: + from injector import InstanceProvider + + for desc in self._descriptors: + scope = _injector_scope(desc.scope) + key = _to_key(desc.provide) + + if desc.use_value is not None: + binder.bind(key, to=InstanceProvider(desc.use_value)) + elif desc.use_class is not None: + binder.bind(key, to=desc.use_class, scope=scope) + + +def build_injector(descriptors: List[ProviderDescriptor]) -> Injector: + """ + Build and return a configured Injector from a list of ProviderDescriptors. + + use_factory and use_existing providers are resolved post-build so that + their dependencies and aliased singletons are already in the injector. + """ + from injector import InstanceProvider + + injector = Injector([PyNestInjectorModule(descriptors)]) + + for desc in descriptors: + key = _to_key(desc.provide) + + if desc.use_factory is not None: + deps = [injector.get(_to_key(t)) for t in desc.inject] + instance = desc.use_factory(*deps) + injector.binder.bind(key, to=InstanceProvider(instance)) + + elif desc.use_existing is not None: + # Resolve the aliased key to get the same (singleton) instance + existing_instance = injector.get(_to_key(desc.use_existing)) + injector.binder.bind(key, to=InstanceProvider(existing_instance)) + + return injector diff --git a/nest/core/pynest_application.py b/nest/core/pynest_application.py index 1533ae6..707ec88 100644 --- a/nest/core/pynest_application.py +++ b/nest/core/pynest_application.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import inspect from typing import Any @@ -5,71 +7,37 @@ from fastapi.responses import JSONResponse from nest.common.route_resolver import RoutesResolver -from nest.core.pynest_app_context import PyNestApplicationContext from nest.core.pynest_container import PyNestContainer -class PyNestApp(PyNestApplicationContext): +class PyNestApp: """ - PyNestApp is the main application class for the PyNest framework, - managing the container and HTTP server. + Main PyNest application. Wraps a built container and a FastAPI HTTP server. """ - _is_listening = False - - @property - def is_listening(self) -> bool: - return self._is_listening - - def __init__(self, container: PyNestContainer, http_server: FastAPI): - """ - Initialize the PyNestApp with the given container and HTTP server. - - Args: - container (PyNestContainer): The PyNestContainer container instance. - http_server (FastAPI): The FastAPI server instance. - """ + def __init__(self, container: PyNestContainer, http_server: FastAPI) -> None: self.container = container self.http_server = http_server - super().__init__(self.container) - self.routes_resolver = RoutesResolver(self.container, self.http_server) - self.select_context_module() - self.register_routes() - - def use(self, middleware: type, **options: Any) -> "PyNestApp": - """ - Add middleware to the FastAPI server. - - Args: - middleware (type): The middleware class. - **options (Any): Additional options for the middleware. - - Returns: - PyNestApp: The current instance of PyNestApp, allowing method chaining. - """ - self.http_server.add_middleware(middleware, **options) - return self + routes_resolver = RoutesResolver(self.container, self.http_server) + routes_resolver.register_routes() def get_server(self) -> FastAPI: - """ - Get the FastAPI server instance. + return self.http_server - Returns: - FastAPI: The FastAPI server instance. - """ + def get_http_server(self) -> FastAPI: + """Alias for get_server() — kept for backward compatibility.""" return self.http_server - def register_routes(self): - """ - Register the routes using the RoutesResolver. - """ - self.routes_resolver.register_routes() + def use(self, middleware: type, **options: Any) -> "PyNestApp": + """Add ASGI middleware to the FastAPI server.""" + self.http_server.add_middleware(middleware, **options) + return self def use_global_filters(self, *filters) -> "PyNestApp": """Register one or more exception filters that apply to every route. - Filters are tried in the order provided. Each filter must be an - *instance* of an ExceptionFilter subclass decorated with @Catch. + Filters are tried in the order provided. Each filter must be an + instance of an ExceptionFilter subclass decorated with @Catch. Args: *filters: ExceptionFilter instances to register globally. diff --git a/nest/core/pynest_container.py b/nest/core/pynest_container.py index 9e524f8..a7d09dc 100644 --- a/nest/core/pynest_container.py +++ b/nest/core/pynest_container.py @@ -1,232 +1,170 @@ +from __future__ import annotations + +import inspect import logging -from typing import Any, List, Optional, Union +from typing import Any, Dict, List, Optional, Type, Union + +from nest.common.exceptions import CircularDependencyException +from nest.common.module import CompiledModule, ModuleCompiler, ModuleTokenFactory +from nest.common.provider import InjectionToken, ProviderDescriptor +from nest.core.dependency_graph import DependencyGraph +from nest.core.encapsulation import validate_module_encapsulation +from nest.core.injector_module import build_injector, _to_key -import click -from injector import Injector, UnknownProvider, singleton -from nest.common.constants import DEPENDENCIES, INJECTABLE_TOKEN -from nest.common.exceptions import ( - CircularDependencyException, - NoneInjectableException, - UnknownModuleException, -) -from nest.common.module import ( - Module, - ModuleCompiler, - ModuleFactory, - ModulesContainer, - ModuleTokenFactory, -) +class ModuleRef: + """Internal container representation of a registered module.""" -TController = type("TController", (), {}) -TProvider = type("TProvider", (), {}) + def __init__(self, token: str, metatype: Type, compiled: CompiledModule) -> None: + self.token = token + self.metatype = metatype + self.compiled = compiled + + @property + def name(self) -> str: + return self.metatype.__name__ class PyNestContainer: """ - A singleton container class for managing modules, providers, and dependencies - in a PyNest application. + IoC container managing the module graph, provider bindings, and instance lifecycle. + NOT a singleton — one fresh instance per application, created by PyNestFactory. """ - _instance = None - _dependencies = None - - def __new__(cls): - """Create a singleton instance of PyNestContainer.""" - if cls._instance is None: - cls._instance = super(PyNestContainer, cls).__new__(cls) - return cls._instance - - def __init__(self): - """Initialize the PyNestContainer.""" - if not hasattr(self, "_initialized"): # Prevent reinitialization - self.logger = logging.getLogger("pynest") - self._injector = Injector() - self._global_modules = set() - self._modules = ModulesContainer() - self._module_token_factory = ModuleTokenFactory() - self._module_compiler = ModuleCompiler(self._module_token_factory) - self._modules_metadata = {} - self._initialized = True + def __init__(self) -> None: + self._logger = logging.getLogger("pynest.container") + self._injector = None + self._modules: Dict[str, ModuleRef] = {} + self._all_descriptors: List[ProviderDescriptor] = [] + self._controller_classes: List[Type] = [] + self._module_token_factory = ModuleTokenFactory() + self._module_compiler = ModuleCompiler(self._module_token_factory) + + # ── Public API ───────────────────────────────────────────────────────────── @property - def modules(self): + def modules(self) -> Dict[str, ModuleRef]: return self._modules @property def module_token_factory(self): return self._module_token_factory - @property - def modules_metadata(self): - return self._modules_metadata - @property def module_compiler(self): return self._module_compiler - def get_instance( - self, - dependency: TProvider, - provider: Optional[Union[TProvider, TController]] = None, - ): - try: - self._injector.binder.bind(dependency, scope=singleton) - instance = self._injector.get(dependency) - self.logger.info(click.style(dependency.__name__ + " Detected ", fg="blue")) - except UnknownProvider: - raise Exception(f"Unknown provider {provider}") - return instance - - def add_module(self, metaclass) -> dict: - """ - Add a module to the container. + def add_module(self, module_class: Type) -> dict: + """Compile and register a module and all its imports recursively.""" + compiled = self._module_compiler.compile(module_class) + token = compiled.token - Args: - metaclass: The metaclass of the module to be added. + if token in self._modules: + return {"module_ref": self._modules[token], "inserted": False} - Returns: - dict: A dictionary containing the module reference and a - boolean flag indicating if it was newly inserted. - """ - module_factory = self._module_compiler.compile(metaclass) - token = module_factory.token - if self._modules.has(token): - return {"module_ref": self.modules.get(token), "inserted": False} - return {"module_ref": self.register_module(module_factory), "inserted": True} - - def register_module(self, module_factory: ModuleFactory) -> Module: - """ - Register a module in the container. - - This method creates a module reference from the provided module factory, registers - the module within the container, adds metadata, imports, providers, and controllers - associated with the module, and logs the detection of the module. + # Register imported modules first (depth-first) + for imported in compiled.imports: + self.add_module(imported) - Args: - module_factory (ModuleFactory): The factory object that contains the type and metadata - for creating the module. + module_ref = ModuleRef(token=token, metatype=module_class, compiled=compiled) + self._modules[token] = module_ref + self._all_descriptors.extend(compiled.provider_descriptors) + self._controller_classes.extend(compiled.controllers) - Returns: - Module: The module reference that has been registered in the container. + self._logger.info(f"Module registered: {module_class.__name__}") + return {"module_ref": module_ref, "inserted": True} + def build(self) -> None: """ - module_ref = Module(module_factory.type, self) - module_ref.token = module_factory.token - self._modules[module_factory.token] = module_ref - - self.add_metadata(module_factory.token, module_factory.dynamic_metadata) - self.add_import(module_factory.token) - self.add_providers( - self._get_providers(module_factory.token), module_factory.token - ) - self.add_controllers( - self._get_controllers(module_factory.token), module_factory.token - ) - - self.logger.info( - click.style(module_factory.type.__name__ + " Detected ", fg="green") - ) - - return module_ref - - def add_metadata(self, token: str, module_metadata) -> None: - """Add metadata for a module.""" - if module_metadata: - self._modules_metadata[token] = module_metadata - - def add_import(self, token: str): - """Add imports for a module.""" - if not self.modules.has(token): - return - module_metadata = self._modules_metadata.get(token) - module_ref: Module = self.modules.get(token) - imports_mod: List[Any] = module_metadata.get("imports") - self.add_modules(imports_mod) - module_ref.add_imports(imports_mod) - - def add_modules(self, modules: List[Any]) -> None: - """Add multiple modules to the container.""" - if modules: - for module in modules: - self.add_module(module) - - def add_providers(self, providers: List[Any], module_token: str) -> None: - """Add multiple providers to a module.""" - for provider in providers: - self.add_provider(module_token, provider) - - def add_provider(self, token: str, provider): - """Add a provider to a module.""" - module_ref: Module = self.modules[token] - if not provider: - raise CircularDependencyException(module_ref.metatype) - - if not module_ref: - raise UnknownModuleException() - - if not hasattr(provider, INJECTABLE_TOKEN): - error_message = f""" - {click.style(provider.__name__, fg='red')} is not injectable. - To make {provider.__name__} injectable, apply the {click.style("@Injectable decorator", fg='green')} - to the class definition, or remove {click.style(provider.__name__, fg='red')} from the provider array - of the Module class. Please check your code and ensure that the decorator is correctly applied to the - class. - """ - raise NoneInjectableException(error_message) - - for dependency_name, dependency_instance in getattr( - provider, DEPENDENCIES - ).items(): + Validate the dependency graph and build the injector. + Must be called once after all add_module() calls, before any get() calls. + """ + self._validate_dependency_graph() + validate_module_encapsulation(self._modules) + + # Controller classes need singleton bindings too so the injector can resolve them + all_descriptors = self._all_descriptors + self._make_controller_descriptors() + self._injector = build_injector(all_descriptors) + self._logger.info("Container built successfully") + + def get(self, token: Union[Type, InjectionToken, str]) -> Any: + """Retrieve a fully-wired instance from the container.""" + if self._injector is None: + raise RuntimeError( + "Container not built. Call container.build() before resolving providers." + ) + return self._injector.get(_to_key(token)) + + def get_controller_instance(self, controller_class: Type) -> Any: + """Get a controller instance with all its service dependencies injected.""" + return self.get(controller_class) + + def clear(self) -> None: + """Reset container state. Useful in tests.""" + self._injector = None + self._modules.clear() + self._all_descriptors.clear() + self._controller_classes.clear() + + # ── Internal ─────────────────────────────────────────────────────────────── + + def _make_controller_descriptors(self) -> List[ProviderDescriptor]: + from nest.common.provider import Scope + return [ + ProviderDescriptor(provide=cls, use_class=cls, scope=Scope.SINGLETON) + for cls in self._controller_classes + ] + + def _validate_dependency_graph(self) -> None: + """Build a DAG from all class providers and raise CircularDependencyException on cycles.""" + import sys + + graph = DependencyGraph() + + # Build a name→class lookup from all registered providers so forward refs can be resolved + provider_classes = { + desc.use_class.__name__: desc.use_class + for desc in self._all_descriptors + if desc.use_class is not None + } + + for desc in self._all_descriptors: + if desc.use_class is None: + continue + target = desc.use_class + graph.add_node(target) + try: + sig = inspect.signature(target.__init__) + except (ValueError, TypeError): + continue + + # Resolve type hints, handling string forward references try: - instance = self.get_instance(dependency_instance, provider) - setattr(provider, dependency_name, instance) - except Exception as e: - self.logger.error(e) - raise e - - module_ref.add_provider(provider) - - def _get_providers(self, token: str) -> List[Any]: - """Get providers from the module metadata.""" - return self.modules_metadata[token]["providers"] - - def add_controllers(self, controllers: List[Any], module_token: str) -> None: - """Add multiple controllers to a module.""" - for controller in controllers: - self._add_controller(module_token, controller) - - def _add_controller(self, token: str, controller: TController) -> None: - """Add a controller to a module.""" - if not self.modules.has(token): - raise UnknownModuleException() - module_ref: Module = self.modules[token] - module_ref.add_controller(controller) - if hasattr(controller, DEPENDENCIES): - for provider_name, provider_type in getattr( - controller, DEPENDENCIES - ).items(): - instance = self.get_instance(provider_type, controller) - setattr(controller, provider_name, instance) - - def _get_controllers(self, token: str) -> List[Any]: - """Get controllers from the module metadata.""" - return self.modules_metadata[token]["controllers"] - - def clear(self): - """Clear all modules from the container.""" - self.modules.clear() - - # UNUSED: This function is currently not used but retained for potential future use. - def add_related_module(self, related_module, token: str) -> None: - if not self.modules.has(token): - return - module_ref = self.modules.get(token) - compile_related_module = self.module_compiler.compile(related_module) - related = self.modules.get(compile_related_module.token) - module_ref.add_import(related) - - # UNUSED: This function is currently not used but retained for potential future use. - # It retrieves a module from the container by its key. - def get_module_by_key(self, module_key: str) -> Module: - return self._modules[module_key] + hints = {} + for param_name, param in sig.parameters.items(): + if param_name == "self": + continue + ann = param.annotation + if ann is param.empty: + continue + if isinstance(ann, str): + # Try to resolve the string annotation against known providers + resolved = provider_classes.get(ann) + if resolved is not None: + hints[param_name] = resolved + elif isinstance(ann, type): + hints[param_name] = ann + except Exception: + continue + + for dep_type in hints.values(): + graph.add_dependency(target, dep_type) + + cycles = graph.detect_cycles() + if cycles: + chain = " → ".join( + getattr(n, "__name__", repr(n)) for n in cycles[0] + ) + raise CircularDependencyException( + f"Circular dependency detected: {chain}" + ) diff --git a/nest/core/pynest_factory.py b/nest/core/pynest_factory.py index 2d988ae..9058d74 100644 --- a/nest/core/pynest_factory.py +++ b/nest/core/pynest_factory.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import ABC, abstractmethod from typing import Type, TypeVar @@ -16,34 +18,26 @@ def create(self, main_module: Type[ModuleType], **kwargs): class PyNestFactory(AbstractPyNestFactory): - """Factory class for creating PyNest applications.""" + """Factory that creates a fully-wired PyNest application from a root module.""" @staticmethod def create(main_module: Type[ModuleType], **kwargs) -> PyNestApp: """ - Create a PyNest application with the specified main module class. - - Args: - main_module (ModuleType): The main module for the PyNest application. - **kwargs: Additional keyword arguments for the FastAPI server. + Build and return a PyNestApp. - Returns: - PyNestApp: The created PyNest application. + 1. Creates a fresh container (NOT a singleton) + 2. Adds the root module (recursively registers all imported modules) + 3. Validates the dependency graph and builds the injector + 4. Creates the FastAPI HTTP server + 5. Registers all routes via RoutesResolver """ container = PyNestContainer() container.add_module(main_module) - http_server = PyNestFactory._create_server(**kwargs) + container.build() + + http_server = FastAPI(**kwargs) return PyNestApp(container, http_server) @staticmethod def _create_server(**kwargs) -> FastAPI: - """ - Create a FastAPI server. - - Args: - **kwargs: Additional keyword arguments for the FastAPI server. - - Returns: - FastAPI: The created FastAPI server. - """ return FastAPI(**kwargs) diff --git a/pyproject.toml b/pyproject.toml index fd90a77..423a345 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,10 @@ include = [ "pyproject.toml", ] + +[tool.poetry.group.dev.dependencies] +httpx = "^0.28.1" + [tool.black] force-exclude = ''' /( diff --git a/tests/test_common/test_module_compiler.py b/tests/test_common/test_module_compiler.py new file mode 100644 index 0000000..23b2799 --- /dev/null +++ b/tests/test_common/test_module_compiler.py @@ -0,0 +1,100 @@ +import pytest +from nest.common.module import ModuleCompiler, ModuleTokenFactory, CompiledModule +from nest.common.provider import ProviderDescriptor, Scope +from nest.core.decorators.module import Module +from nest.core.decorators.injectable import Injectable + + +@Injectable +class ServiceA: + pass + + +@Injectable +class ServiceB: + def __init__(self, a: ServiceA): + self.a = a + + +@Module(providers=[ServiceA]) +class ModuleA: + pass + + +@Module(providers=[ServiceB], imports=[ModuleA], exports=[ServiceB]) +class ModuleB: + pass + + +@Module(providers=[{"provide": "DB_URL", "useValue": "postgres://localhost/test"}]) +class ConfigModule: + pass + + +def test_compiled_module_has_token(): + compiler = ModuleCompiler(ModuleTokenFactory()) + result = compiler.compile(ModuleA) + assert isinstance(result, CompiledModule) + assert result.token is not None + assert len(result.token) > 0 + + +def test_compiled_module_has_metatype(): + compiler = ModuleCompiler(ModuleTokenFactory()) + result = compiler.compile(ModuleA) + assert result.metatype is ModuleA + + +def test_compiled_module_providers_are_normalized(): + compiler = ModuleCompiler(ModuleTokenFactory()) + result = compiler.compile(ModuleA) + assert len(result.provider_descriptors) == 1 + desc = result.provider_descriptors[0] + assert isinstance(desc, ProviderDescriptor) + assert desc.use_class is ServiceA + + +def test_compiled_module_dict_provider_is_normalized(): + compiler = ModuleCompiler(ModuleTokenFactory()) + result = compiler.compile(ConfigModule) + assert len(result.provider_descriptors) == 1 + desc = result.provider_descriptors[0] + assert desc.use_value == "postgres://localhost/test" + assert desc.use_class is None + + +def test_compiled_module_imports(): + compiler = ModuleCompiler(ModuleTokenFactory()) + result = compiler.compile(ModuleB) + assert ModuleA in result.imports + + +def test_compiled_module_exports(): + compiler = ModuleCompiler(ModuleTokenFactory()) + result = compiler.compile(ModuleB) + assert ServiceB in result.exports + + +def test_same_module_gets_same_token(): + factory = ModuleTokenFactory() + compiler = ModuleCompiler(factory) + token1 = compiler.compile(ModuleA).token + token2 = compiler.compile(ModuleA).token + assert token1 == token2 + + +def test_different_modules_get_different_tokens(): + factory = ModuleTokenFactory() + compiler = ModuleCompiler(factory) + token_a = compiler.compile(ModuleA).token + token_b = compiler.compile(ModuleB).token + assert token_a != token_b + + +def test_module_without_decorator_raises(): + class Bare: + pass + + compiler = ModuleCompiler(ModuleTokenFactory()) + with pytest.raises(Exception, match="no metadata"): + compiler.compile(Bare) diff --git a/tests/test_common/test_provider.py b/tests/test_common/test_provider.py new file mode 100644 index 0000000..0a24d7d --- /dev/null +++ b/tests/test_common/test_provider.py @@ -0,0 +1,119 @@ +import pytest +from nest.common.provider import ( + InjectionToken, + Scope, + ProviderDescriptor, + normalize_provider, +) + + +class SomeService: + pass + + +class OtherService: + pass + + +# --- InjectionToken --- + +def test_injection_token_has_name(): + token = InjectionToken("DB_URL") + assert token.name == "DB_URL" + + +def test_injection_token_repr(): + token = InjectionToken("MY_TOKEN") + assert "MY_TOKEN" in repr(token) + + +def test_injection_tokens_with_same_name_are_equal(): + assert InjectionToken("A") == InjectionToken("A") + + +def test_injection_tokens_with_different_names_are_not_equal(): + assert InjectionToken("A") != InjectionToken("B") + + +def test_injection_token_is_hashable(): + token = InjectionToken("X") + d = {token: "value"} + assert d[token] == "value" + + +# --- Scope --- + +def test_scope_values_exist(): + assert Scope.SINGLETON + assert Scope.TRANSIENT + assert Scope.REQUEST + + +# --- normalize_provider: class form --- + +def test_normalize_class_provider_sets_use_class(): + desc = normalize_provider(SomeService) + assert desc.provide is SomeService + assert desc.use_class is SomeService + assert desc.scope == Scope.SINGLETON + + +# --- normalize_provider: useValue dict --- + +def test_normalize_use_value(): + desc = normalize_provider({"provide": "DB_URL", "useValue": "postgres://localhost/db"}) + assert desc.provide == "DB_URL" + assert desc.use_value == "postgres://localhost/db" + assert desc.use_class is None + + +# --- normalize_provider: useClass dict --- + +def test_normalize_use_class(): + desc = normalize_provider({"provide": SomeService, "useClass": OtherService}) + assert desc.provide is SomeService + assert desc.use_class is OtherService + + +# --- normalize_provider: useFactory dict --- + +def test_normalize_use_factory(): + factory = lambda: SomeService() + desc = normalize_provider({ + "provide": SomeService, + "useFactory": factory, + "inject": [OtherService], + }) + assert desc.provide is SomeService + assert desc.use_factory is factory + assert desc.inject == [OtherService] + + +# --- normalize_provider: useExisting dict --- + +def test_normalize_use_existing(): + desc = normalize_provider({"provide": SomeService, "useExisting": OtherService}) + assert desc.provide is SomeService + assert desc.use_existing is OtherService + + +# --- normalize_provider: scope override --- + +def test_normalize_scope_override(): + desc = normalize_provider({"provide": SomeService, "useClass": SomeService, "scope": Scope.TRANSIENT}) + assert desc.scope == Scope.TRANSIENT + + +# --- normalize_provider: invalid dict --- + +def test_normalize_invalid_dict_raises(): + with pytest.raises(ValueError, match="Invalid provider descriptor"): + normalize_provider({"provide": SomeService}) + + +# --- normalize_provider: already a ProviderDescriptor --- + +def test_normalize_passthrough_descriptor(): + desc = ProviderDescriptor(provide=SomeService, use_class=SomeService) + result = normalize_provider(desc) + assert result is desc diff --git a/tests/test_common/test_route_resolver.py b/tests/test_common/test_route_resolver.py index e69de29..1129523 100644 --- a/tests/test_common/test_route_resolver.py +++ b/tests/test_common/test_route_resolver.py @@ -0,0 +1,81 @@ +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from nest.common.route_resolver import RoutesResolver +from nest.core.pynest_container import PyNestContainer +from nest.core.decorators.module import Module +from nest.core.decorators.injectable import Injectable +from nest.core.decorators.controller import Controller +from nest.core.decorators.http_method import Get, Post + + +@Injectable +class GreetService: + def greet(self, name: str) -> str: + return f"Hello, {name}!" + + +@Controller("/greet", tag="greet") +class GreetController: + def __init__(self, svc: GreetService): + self.svc = svc + + @Get("/") + def index(self): + return {"message": "ok"} + + @Get("/{name}") + def greet(self, name: str): + return {"message": self.svc.greet(name)} + + +@Module(providers=[GreetService], controllers=[GreetController]) +class GreetModule: + pass + + +@pytest.fixture +def app_and_container(): + container = PyNestContainer() + container.add_module(GreetModule) + container.build() + app = FastAPI() + resolver = RoutesResolver(container, app) + resolver.register_routes() + return app, container + + +def test_routes_registered(app_and_container): + app, _ = app_and_container + paths = {route.path for route in app.routes} + assert any("/greet" in p for p in paths) + + +def test_get_index_returns_200(app_and_container): + app, _ = app_and_container + client = TestClient(app) + response = client.get("/greet/") + assert response.status_code == 200 + assert response.json() == {"message": "ok"} + + +def test_get_with_path_param(app_and_container): + app, _ = app_and_container + client = TestClient(app) + response = client.get("/greet/World") + assert response.status_code == 200 + assert response.json()["message"] == "Hello, World!" + + +def test_service_is_instance_attribute_not_class(app_and_container): + app, container = app_and_container + instance = container.get_controller_instance(GreetController) + assert "svc" in instance.__dict__ + assert "svc" not in GreetController.__dict__ + + +def test_controller_instance_is_singleton(app_and_container): + app, container = app_and_container + a = container.get_controller_instance(GreetController) + b = container.get_controller_instance(GreetController) + assert a is b diff --git a/tests/test_core/test_decorators/test_controller.py b/tests/test_core/test_decorators/test_controller.py index 0985014..b3d92a6 100644 --- a/tests/test_core/test_decorators/test_controller.py +++ b/tests/test_core/test_decorators/test_controller.py @@ -1,51 +1,102 @@ import pytest +from nest.core.decorators.controller import Controller +from nest.core.decorators.injectable import Injectable +from nest.core.decorators.http_method import Get, Post -from nest.core import Controller, Delete, Get, Injectable, Patch, Post, Put +@Injectable +class TestService: + def hello(self): + return "hello" -@Controller(prefix="api/v1/user", tag="test") -class TestController: - def __init__(self): ... - @Get("/get_all_users") - def get_endpoint(self): - return {"message": "GET endpoint"} +def test_controller_marks_class(): + @Controller("/items") + class ItemController: + def __init__(self, svc: TestService): + self.svc = svc - @Post("/create_user") - def post_endpoint(self): - return {"message": "POST endpoint"} + assert ItemController.__is_controller__ is True - @Delete("/delete_user") - def delete_endpoint(self): - return {"message": "DELETE endpoint"} - @Put("/update_user") - def put_endpoint(self): - return {"message": "PUT endpoint"} +def test_controller_stores_prefix(): + @Controller("/items") + class ItemController: + pass - @Patch("/patch_user") - def patch_endpoint(self): - return {"message": "PATCH endpoint"} + assert ItemController.__route_prefix__ == "/items" -@pytest.fixture -def test_controller(): - return TestController() +def test_controller_stores_tag(): + @Controller("/items", tag="items") + class ItemController: + pass + assert ItemController.__controller_tag__ == "items" -@pytest.mark.parametrize( - "function, endpoint, expected_message", - [ - ("get_endpoint", "get_all_users", "GET endpoint"), - ("post_endpoint", "create_user", "POST endpoint"), - ("delete_endpoint", "delete_user", "DELETE endpoint"), - ("put_endpoint", "update_user", "PUT endpoint"), - ("patch_endpoint", "patch_user", "PATCH endpoint"), - ], -) -def test_endpoints(test_controller, function, endpoint, expected_message): - attribute = getattr(test_controller, function) - assert attribute.__route_path__ == "/api/v1/user/" + endpoint - assert attribute.__kwargs__ == {} - assert attribute.__http_method__.value == function.split("_")[0].upper() - assert attribute() == {"message": f"{function.split('_')[0].upper()} endpoint"} + +def test_controller_adds_leading_slash(): + @Controller("users") + class UserController: + pass + + assert UserController.__route_prefix__ == "/users" + + +def test_controller_strips_trailing_slash(): + @Controller("/users/") + class UserController: + pass + + assert UserController.__route_prefix__ == "/users" + + +def test_controller_does_not_delete_init(): + @Controller("/items") + class ItemController: + def __init__(self, svc: TestService): + self.svc = svc + + # __init__ must still exist on the class — the injector needs it + assert callable(getattr(ItemController, "__init__", None)) + + +def test_controller_does_not_set_class_attributes_for_deps(): + @Controller("/items") + class ItemController: + def __init__(self, svc: TestService): + self.svc = svc + + # Dep must NOT be a class attribute — injector handles instance creation + assert "svc" not in ItemController.__dict__ + + +def test_controller_returns_original_class(): + @Controller("/items") + class ItemController: + pass + + assert isinstance(ItemController, type) + + +def test_controller_http_methods_retain_metadata(): + @Controller("/items") + class ItemController: + @Get("/") + def list_items(self): + return [] + + @Post("/") + def create_item(self): + return {} + + assert hasattr(ItemController.list_items, "__http_method__") + assert hasattr(ItemController.create_item, "__http_method__") + + +def test_controller_with_no_prefix(): + @Controller() + class RootController: + pass + + assert RootController.__route_prefix__ is None diff --git a/tests/test_core/test_decorators/test_guard.py b/tests/test_core/test_decorators/test_guard.py index ccda088..4fdd0f8 100644 --- a/tests/test_core/test_decorators/test_guard.py +++ b/tests/test_core/test_decorators/test_guard.py @@ -25,7 +25,7 @@ def can_activate(self, request: Request, credentials) -> bool: class JWTGuard(BaseGuard): security_scheme = HTTPBearer() - + def can_activate(self, request: Request, credentials=None) -> bool: if credentials and credentials.scheme == "Bearer": return self.validate_jwt(credentials.credentials) @@ -45,27 +45,35 @@ def test_use_guards_sets_attribute(): assert SimpleGuard in GuardController.root.__guards__ -def test_guard_added_to_route_dependencies(): - router = GuardController.get_router() - route = router.routes[0] - deps = route.dependencies - assert len(deps) == 1 - assert callable(deps[0].dependency) +def test_guard_metadata_stored_on_method(): + """Guards are stored as metadata on route methods for later route registration.""" + @Controller("/items") + class ItemController: + @Get("/") + @UseGuards(SimpleGuard) + def list_items(self): + return [] + assert hasattr(ItemController.list_items, "__guards__") + assert SimpleGuard in ItemController.list_items.__guards__ -def _has_security_requirements(dependant): - """Recursively check if a dependant or its dependencies have security requirements.""" - if dependant.security_requirements: - return True - - for dep in dependant.dependencies: - if _has_security_requirements(dep): - return True - - return False + +def test_guard_as_dependency_callable(): + """as_dependency() must produce a valid FastAPI Depends object.""" + dep = SimpleGuard.as_dependency() + # FastAPI Depends wraps a callable in a Depends object + assert callable(dep.dependency) + + +def test_bearer_guard_has_security_scheme(): + """Guards with security_scheme are recognized as having OpenAPI security.""" + assert BearerGuard.security_scheme is not None + dep = BearerGuard.as_dependency() + assert callable(dep.dependency) -def test_openapi_security_requirement(): +def test_controller_preserves_guard_metadata(): + """@Controller must not strip guard metadata from route methods.""" @Controller("/bearer") class BearerController: @Get("/") @@ -73,9 +81,5 @@ class BearerController: def root(self): return {"ok": True} - router = BearerController.get_router() - route = router.routes[0] - - # Check if security requirements exist anywhere in the dependency tree - assert _has_security_requirements(route.dependant), \ - "Security requirements should be present in the dependency tree for OpenAPI integration" + assert hasattr(BearerController.root, "__guards__") + assert BearerGuard in BearerController.root.__guards__ diff --git a/tests/test_core/test_decorators/test_injectable.py b/tests/test_core/test_decorators/test_injectable.py new file mode 100644 index 0000000..1509981 --- /dev/null +++ b/tests/test_core/test_decorators/test_injectable.py @@ -0,0 +1,81 @@ +import pytest +from injector import Injector +from nest.core.decorators.injectable import Injectable +from nest.common.provider import Scope + + +def test_injectable_marks_class(): + @Injectable + class MyService: + pass + + assert hasattr(MyService, "__injectable__") + assert MyService.__injectable__ is True + + +def test_injectable_does_not_mutate_class_with_dep_attributes(): + @Injectable + class Dep: + pass + + @Injectable + class MyService: + def __init__(self, dep: Dep): + self.dep = dep + + # The class must NOT have 'dep' as a class-level attribute + assert "dep" not in MyService.__dict__ + + +def test_injectable_preserves_init(): + @Injectable + class MyService: + def __init__(self, x: int = 5): + self.x = x + + svc = MyService() + assert svc.x == 5 + + +def test_injectable_default_scope_is_singleton(): + @Injectable + class MyService: + pass + + assert MyService.__injectable_scope__ == Scope.SINGLETON + + +def test_injectable_with_transient_scope(): + @Injectable(scope=Scope.TRANSIENT) + class MyService: + pass + + assert MyService.__injectable_scope__ == Scope.TRANSIENT + + +def test_injectable_with_request_scope(): + @Injectable(scope=Scope.REQUEST) + class MyService: + pass + + assert MyService.__injectable_scope__ == Scope.REQUEST + + +def test_injector_resolves_injectable_class(): + @Injectable + class Repo: + def items(self): + return [1, 2, 3] + + @Injectable + class Service: + def __init__(self, repo: Repo): + self.repo = repo + + injector = Injector() + svc = injector.get(Service) + assert isinstance(svc, Service) + assert isinstance(svc.repo, Repo) + # dep must be on the INSTANCE, not on the class + assert "repo" in svc.__dict__ + assert "repo" not in Service.__dict__ diff --git a/tests/test_core/test_dependency_graph.py b/tests/test_core/test_dependency_graph.py new file mode 100644 index 0000000..89e27b7 --- /dev/null +++ b/tests/test_core/test_dependency_graph.py @@ -0,0 +1,112 @@ +import pytest +from nest.core.dependency_graph import DependencyGraph, CycleError + + +class A: pass +class B: pass +class C: pass +class D: pass + + +def test_add_node_is_idempotent(): + g = DependencyGraph() + g.add_node(A) + g.add_node(A) + assert A in g.nodes + + +def test_add_dependency_adds_both_nodes(): + g = DependencyGraph() + g.add_dependency(A, B) # A depends on B + assert A in g.nodes + assert B in g.nodes + + +def test_no_cycles_in_linear_chain(): + g = DependencyGraph() + g.add_dependency(A, B) + g.add_dependency(B, C) + assert g.detect_cycles() == [] + + +def test_direct_cycle_detected(): + g = DependencyGraph() + g.add_dependency(A, B) + g.add_dependency(B, A) + cycles = g.detect_cycles() + assert len(cycles) == 1 + cycle = cycles[0] + assert A in cycle + assert B in cycle + + +def test_indirect_cycle_detected(): + g = DependencyGraph() + g.add_dependency(A, B) + g.add_dependency(B, C) + g.add_dependency(C, A) + cycles = g.detect_cycles() + assert len(cycles) == 1 + assert A in cycles[0] + assert B in cycles[0] + assert C in cycles[0] + + +def test_no_cycle_in_diamond(): + g = DependencyGraph() + g.add_dependency(A, B) + g.add_dependency(A, C) + g.add_dependency(B, D) + g.add_dependency(C, D) + assert g.detect_cycles() == [] + + +def test_topological_sort_single_node(): + g = DependencyGraph() + g.add_node(A) + result = g.topological_sort() + assert result == [A] + + +def test_topological_sort_dependency_comes_first(): + g = DependencyGraph() + g.add_dependency(A, B) # A depends on B → B must come first + result = g.topological_sort() + assert result.index(B) < result.index(A) + + +def test_topological_sort_chain(): + g = DependencyGraph() + g.add_dependency(A, B) + g.add_dependency(B, C) + result = g.topological_sort() + assert result.index(C) < result.index(B) < result.index(A) + + +def test_topological_sort_diamond(): + g = DependencyGraph() + g.add_dependency(A, B) + g.add_dependency(A, C) + g.add_dependency(B, D) + g.add_dependency(C, D) + result = g.topological_sort() + assert result.index(D) < result.index(B) + assert result.index(D) < result.index(C) + assert result.index(B) < result.index(A) + assert result.index(C) < result.index(A) + + +def test_validate_raises_cycle_error_with_chain(): + g = DependencyGraph() + g.add_dependency(A, B) + g.add_dependency(B, A) + with pytest.raises(CycleError) as exc_info: + g.validate() + assert "A" in str(exc_info.value) or "B" in str(exc_info.value) + + +def test_validate_passes_for_valid_graph(): + g = DependencyGraph() + g.add_dependency(A, B) + g.add_dependency(B, C) + g.validate() # must not raise diff --git a/tests/test_core/test_encapsulation.py b/tests/test_core/test_encapsulation.py new file mode 100644 index 0000000..375aeba --- /dev/null +++ b/tests/test_core/test_encapsulation.py @@ -0,0 +1,226 @@ +"""Tests for NestJS-style module encapsulation enforcement.""" +import pytest + +from nest.common.exceptions import ProviderNotExportedException +from nest.core.decorators.injectable import Injectable +from nest.core.decorators.module import Module +from nest.core.pynest_container import PyNestContainer + + +# ── Legal scenarios (should build cleanly) ────────────────────────────────── + + +def test_same_module_dependency_is_legal(): + @Injectable + class Repo: + pass + + @Injectable + class Service: + def __init__(self, repo: Repo): + self.repo = repo + + @Module(providers=[Repo, Service]) + class M: + pass + + container = PyNestContainer() + container.add_module(M) + container.build() # no raise + + +def test_imported_and_exported_dependency_is_legal(): + @Injectable + class Repo: + pass + + @Injectable + class Service: + def __init__(self, repo: Repo): + self.repo = repo + + @Module(providers=[Repo], exports=[Repo]) + class RepoMod: + pass + + @Module(providers=[Service], imports=[RepoMod]) + class ServiceMod: + pass + + container = PyNestContainer() + container.add_module(ServiceMod) + container.build() # no raise + + +def test_global_module_provider_is_visible_everywhere(): + @Injectable + class Logger: + pass + + @Injectable + class Service: + def __init__(self, logger: Logger): + self.logger = logger + + # is_global=True → Logger visible to every module without explicit import + @Module(providers=[Logger], is_global=True) + class LoggingMod: + pass + + @Module(providers=[Service]) + class ServiceMod: + pass + + @Module(imports=[LoggingMod, ServiceMod]) + class AppMod: + pass + + container = PyNestContainer() + container.add_module(AppMod) + container.build() # no raise + + +def test_re_exported_module_chains_through(): + @Injectable + class Repo: + pass + + @Injectable + class Service: + def __init__(self, repo: Repo): + self.repo = repo + + @Module(providers=[Repo], exports=[Repo]) + class RepoMod: + pass + + # CoreMod imports RepoMod and re-exports the whole module + @Module(imports=[RepoMod], exports=[RepoMod]) + class CoreMod: + pass + + # ServiceMod only imports CoreMod, not RepoMod — must still see Repo + @Module(providers=[Service], imports=[CoreMod]) + class ServiceMod: + pass + + container = PyNestContainer() + container.add_module(ServiceMod) + container.build() # no raise + + +# ── Illegal scenarios (should raise) ──────────────────────────────────────── + + +def test_missing_import_raises(): + @Injectable + class Repo: + pass + + @Injectable + class Service: + def __init__(self, repo: Repo): + self.repo = repo + + @Module(providers=[Repo], exports=[Repo]) + class RepoMod: + pass + + # ServiceMod uses Repo but does NOT import RepoMod + @Module(providers=[Service]) + class ServiceMod: + pass + + @Module(imports=[ServiceMod, RepoMod]) + class AppMod: + pass + + container = PyNestContainer() + container.add_module(AppMod) + with pytest.raises(ProviderNotExportedException) as exc: + container.build() + assert "Service" in str(exc.value) + assert "Repo" in str(exc.value) + assert "RepoMod" in str(exc.value) + + +def test_imported_but_not_exported_raises(): + @Injectable + class Repo: + pass + + @Injectable + class Service: + def __init__(self, repo: Repo): + self.repo = repo + + # RepoMod imports Repo as provider but does NOT export it + @Module(providers=[Repo]) + class RepoMod: + pass + + @Module(providers=[Service], imports=[RepoMod]) + class ServiceMod: + pass + + container = PyNestContainer() + container.add_module(ServiceMod) + with pytest.raises(ProviderNotExportedException) as exc: + container.build() + msg = str(exc.value) + assert "exports=[Repo]" in msg # actionable suggestion in the error + + +def test_controller_cross_module_violation_raises(): + @Injectable + class Repo: + pass + + from nest.core.decorators.controller import Controller + + @Controller("/x") + class Ctrl: + def __init__(self, repo: Repo): + self.repo = repo + + @Module(providers=[Repo]) # not exported + class RepoMod: + pass + + @Module(controllers=[Ctrl], imports=[RepoMod]) + class WebMod: + pass + + container = PyNestContainer() + container.add_module(WebMod) + with pytest.raises(ProviderNotExportedException): + container.build() + + +def test_unrelated_modules_do_not_share_providers(): + @Injectable + class Repo: + pass + + @Injectable + class Service: + def __init__(self, repo: Repo): + self.repo = repo + + # Two siblings with no import relation between them + @Module(providers=[Repo]) + class RepoMod: + pass + + @Module(providers=[Service]) + class ServiceMod: + pass + + @Module(imports=[RepoMod, ServiceMod]) + class AppMod: + pass + + container = PyNestContainer() + container.add_module(AppMod) + with pytest.raises(ProviderNotExportedException): + container.build() diff --git a/tests/test_core/test_injector_module.py b/tests/test_core/test_injector_module.py new file mode 100644 index 0000000..e48c743 --- /dev/null +++ b/tests/test_core/test_injector_module.py @@ -0,0 +1,117 @@ +import pytest +from injector import Injector +from nest.core.injector_module import build_injector +from nest.common.provider import ProviderDescriptor, Scope, InjectionToken +from nest.core.decorators.injectable import Injectable + + +@Injectable +class Repo: + def find_all(self): + return ["item1"] + + +@Injectable +class Service: + def __init__(self, repo: Repo): + self.repo = repo + + def get_items(self): + return self.repo.find_all() + + +def test_build_injector_resolves_class_provider(): + descriptors = [ + ProviderDescriptor(provide=Repo, use_class=Repo), + ProviderDescriptor(provide=Service, use_class=Service), + ] + injector = build_injector(descriptors) + service = injector.get(Service) + assert isinstance(service, Service) + assert isinstance(service.repo, Repo) + + +def test_build_injector_singleton_scope_returns_same_instance(): + descriptors = [ProviderDescriptor(provide=Repo, use_class=Repo, scope=Scope.SINGLETON)] + injector = build_injector(descriptors) + a = injector.get(Repo) + b = injector.get(Repo) + assert a is b + + +def test_build_injector_transient_scope_returns_new_instance(): + descriptors = [ProviderDescriptor(provide=Repo, use_class=Repo, scope=Scope.TRANSIENT)] + injector = build_injector(descriptors) + a = injector.get(Repo) + b = injector.get(Repo) + assert a is not b + + +def test_build_injector_use_value_injection_token(): + token = InjectionToken("DB_URL") + descriptors = [ProviderDescriptor(provide=token, use_value="postgres://localhost/test")] + injector = build_injector(descriptors) + value = injector.get(token) + assert value == "postgres://localhost/test" + + +def test_build_injector_use_value_string_key(): + descriptors = [ProviderDescriptor(provide="DB_URL", use_value="postgres://localhost/test")] + injector = build_injector(descriptors) + value = injector.get("DB_URL") + assert value == "postgres://localhost/test" + + +def test_build_injector_use_factory_no_deps(): + call_count = [0] + + def factory(): + call_count[0] += 1 + return Repo() + + descriptors = [ProviderDescriptor(provide=Repo, use_factory=factory, inject=[])] + injector = build_injector(descriptors) + result = injector.get(Repo) + assert isinstance(result, Repo) + assert call_count[0] == 1 + + +def test_build_injector_use_factory_with_deps(): + descriptors = [ + ProviderDescriptor(provide=Repo, use_class=Repo), + ProviderDescriptor( + provide=Service, + use_factory=lambda repo: Service(repo), + inject=[Repo], + ), + ] + injector = build_injector(descriptors) + service = injector.get(Service) + assert isinstance(service.repo, Repo) + + +def test_build_injector_use_existing_aliases(): + descriptors = [ + ProviderDescriptor(provide=Repo, use_class=Repo), + ProviderDescriptor(provide=Service, use_class=Service), + ProviderDescriptor(provide="IService", use_existing=Service), + ] + injector = build_injector(descriptors) + service1 = injector.get(Service) + service2 = injector.get("IService") + assert service1 is service2 + + +def test_use_class_different_from_provide(): + class MockRepo(Repo): + def find_all(self): + return [] + + descriptors = [ + ProviderDescriptor(provide=Repo, use_class=MockRepo), + ProviderDescriptor(provide=Service, use_class=Service), + ] + injector = build_injector(descriptors) + service = injector.get(Service) + assert isinstance(service.repo, MockRepo) + assert service.get_items() == [] diff --git a/tests/test_core/test_pynest_application.py b/tests/test_core/test_pynest_application.py index df364b7..d01c16f 100644 --- a/tests/test_core/test_pynest_application.py +++ b/tests/test_core/test_pynest_application.py @@ -1,28 +1,63 @@ import pytest +from fastapi import FastAPI -from nest.core import PyNestApp -from tests.test_core import test_container, test_resolver -from tests.test_core.test_pynest_factory import test_server +from nest.core import PyNestFactory +from nest.core.pynest_application import PyNestApp +from nest.core import Module, Injectable, Controller, Get + + +@Injectable +class AppService: + def greet(self) -> str: + return "hi" + + +@Controller("/app", tag="app") +class AppController: + def __init__(self, svc: AppService): + self.svc = svc + + @Get("/") + def index(self): + return {"msg": self.svc.greet()} + + +@Module(controllers=[AppController], providers=[AppService]) +class AppModule: + pass @pytest.fixture -def pynest_app(test_container, test_server): - return PyNestApp(container=test_container, http_server=test_server) +def pynest_app(): + return PyNestFactory.create(AppModule) + + +def test_get_server_returns_fastapi(pynest_app): + assert isinstance(pynest_app.get_server(), FastAPI) + + +def test_get_http_server_alias(pynest_app): + assert pynest_app.get_http_server() is pynest_app.get_server() + + +def test_container_is_built(pynest_app): + # If the container is built, we can resolve a provider without RuntimeError + svc = pynest_app.container.get(AppService) + assert isinstance(svc, AppService) -def test_is_listening_property(pynest_app): - assert not pynest_app.is_listening - pynest_app._is_listening = ( - True # Directly modify the protected attribute for testing - ) - assert pynest_app.is_listening +def test_routes_are_registered(pynest_app): + # At least one route should be registered on the FastAPI app + assert len(pynest_app.get_server().routes) > 0 -def test_get_server_returns_http_server(pynest_app, test_server): - assert pynest_app.get_server() == test_server +def test_use_adds_middleware(pynest_app): + from starlette.middleware.base import BaseHTTPMiddleware + class DummyMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + return await call_next(request) -def test_register_routes_calls_register_routes_on_resolver(pynest_app, test_resolver): - pynest_app.routes_resolver = test_resolver - pynest_app.register_routes() - assert pynest_app.get_server().routes + result = pynest_app.use(DummyMiddleware) + # use() returns self for chaining + assert result is pynest_app diff --git a/tests/test_core/test_pynest_container.py b/tests/test_core/test_pynest_container.py index 9796e25..73f276f 100644 --- a/tests/test_core/test_pynest_container.py +++ b/tests/test_core/test_pynest_container.py @@ -1,28 +1,152 @@ import pytest +from nest.core.pynest_container import PyNestContainer +from nest.core.decorators.module import Module +from nest.core.decorators.injectable import Injectable -from nest.core import Module, PyNestContainer -from tests.test_core import test_module +@Injectable +class RepoService: + def find(self): + return ["a", "b"] -@pytest.fixture(scope="module") -def container(): - # Since PyNestContainer is a singleton, we clear its state before each test to ensure test isolation - PyNestContainer()._instance = None - return PyNestContainer() +@Injectable +class AppService: + def __init__(self, repo: RepoService): + self.repo = repo -def test_singleton_pattern(container): - second_container = PyNestContainer() - assert ( - container is second_container - ), "PyNestContainer should implement the singleton pattern" + def items(self): + return self.repo.find() -def test_add_module(container, test_module): - result = container.add_module(test_module) - assert result["inserted"] is True, "Module should be added successfully" - module_ref = result["module_ref"] - assert module_ref is not None, "Module reference should not be None" - assert container.modules.has( - module_ref.token - ), "Module should be added to the container" +@Module(providers=[RepoService, AppService]) +class AppModule: + pass + + +@Module(providers=[RepoService], exports=[RepoService]) +class RepoModule: + pass + + +@Module(providers=[AppService], imports=[RepoModule]) +class ServiceModule: + pass + + +# ── Container is NOT a singleton ────────────────────────────────────────────── + +def test_two_containers_are_independent(): + c1 = PyNestContainer() + c2 = PyNestContainer() + assert c1 is not c2 + + +# ── add_module + build + get ────────────────────────────────────────────────── + +def test_get_provider_after_build(): + container = PyNestContainer() + container.add_module(AppModule) + container.build() + svc = container.get(AppService) + assert isinstance(svc, AppService) + + +def test_get_returns_singleton_by_default(): + container = PyNestContainer() + container.add_module(AppModule) + container.build() + a = container.get(AppService) + b = container.get(AppService) + assert a is b + + +def test_dependency_is_injected_into_instance(): + container = PyNestContainer() + container.add_module(AppModule) + container.build() + svc = container.get(AppService) + # Must be an instance attribute, NOT a class attribute + assert "repo" in svc.__dict__ + assert isinstance(svc.repo, RepoService) + + +def test_provider_not_class_attribute_after_build(): + container = PyNestContainer() + container.add_module(AppModule) + container.build() + # The class itself must not be mutated — dep is on the instance only + assert "repo" not in AppService.__dict__ + + +def test_add_module_inserts_true_for_new_module(): + container = PyNestContainer() + result = container.add_module(AppModule) + assert result["inserted"] is True + + +def test_add_module_inserts_false_for_duplicate(): + container = PyNestContainer() + container.add_module(AppModule) + result = container.add_module(AppModule) + assert result["inserted"] is False + + +def test_imported_module_providers_are_resolvable(): + container = PyNestContainer() + container.add_module(ServiceModule) + container.build() + svc = container.get(AppService) + assert isinstance(svc.repo, RepoService) + + +# ── Cycle detection ─────────────────────────────────────────────────────────── + +def test_circular_dependency_raises_on_build(): + from nest.common.exceptions import CircularDependencyException + + @Injectable + class Alpha: + def __init__(self, b: "Beta"): + self.b = b + + @Injectable + class Beta: + def __init__(self, a: Alpha): + self.a = a + + @Module(providers=[Alpha, Beta]) + class CycleModule: + pass + + container = PyNestContainer() + container.add_module(CycleModule) + with pytest.raises(CircularDependencyException): + container.build() + + +# ── useValue provider ───────────────────────────────────────────────────────── + +def test_use_value_provider(): + from nest.common.provider import InjectionToken + + DB_URL = InjectionToken("DB_URL") + + @Module(providers=[{"provide": DB_URL, "useValue": "postgres://localhost/test"}]) + class ValueModule: + pass + + container = PyNestContainer() + container.add_module(ValueModule) + container.build() + value = container.get(DB_URL) + assert value == "postgres://localhost/test" + + +# ── get before build raises ─────────────────────────────────────────────────── + +def test_get_before_build_raises(): + container = PyNestContainer() + container.add_module(AppModule) + with pytest.raises(RuntimeError, match="build()"): + container.get(AppService) diff --git a/tests/test_core/test_pynest_factory.py b/tests/test_core/test_pynest_factory.py index 51ecb0f..f39ab84 100644 --- a/tests/test_core/test_pynest_factory.py +++ b/tests/test_core/test_pynest_factory.py @@ -1,29 +1,72 @@ import pytest from fastapi import FastAPI +from fastapi.testclient import TestClient -from nest.core import PyNestFactory # Replace 'your_module' with the actual module name -from tests.test_core import test_container, test_module, test_server +from nest.core import Module, Injectable, Controller, Get, PyNestFactory +from nest.core.pynest_application import PyNestApp -def test_create_server(test_server): - assert isinstance(test_server, FastAPI) - assert test_server.title == "Test Server" - assert test_server.description == "This is a test server" - assert test_server.version == "1.0.0" - assert test_server.debug is True +@Injectable +class MessageService: + def get_message(self) -> str: + return "Hello, World!" -def test_e2e(test_module): +@Controller("/test", tag="test") +class TestController: + def __init__(self, svc: MessageService): + self.svc = svc + + @Get("/") + def index(self): + return {"message": self.svc.get_message()} + + +@Module(controllers=[TestController], providers=[MessageService]) +class TestModule: + pass + + +def test_create_returns_pynest_app(): + app = PyNestFactory.create(TestModule) + assert isinstance(app, PyNestApp) + + +def test_create_produces_fastapi_server(): + app = PyNestFactory.create(TestModule) + assert isinstance(app.get_server(), FastAPI) + + +def test_create_server_kwargs_forwarded(): app = PyNestFactory.create( - test_module, + TestModule, title="Test Server", - description="This is a test server", + description="A test", version="1.0.0", - debug=True, ) - http_server = app.get_server() - assert isinstance(http_server, FastAPI) - assert http_server.title == "Test Server" - assert http_server.description == "This is a test server" - assert http_server.version == "1.0.0" - assert http_server.debug is True + server = app.get_server() + assert server.title == "Test Server" + assert server.version == "1.0.0" + + +def test_e2e_route_responds(): + app = PyNestFactory.create(TestModule) + client = TestClient(app.get_server()) + response = client.get("/test/") + assert response.status_code == 200 + assert response.json() == {"message": "Hello, World!"} + + +def test_two_apps_are_independent(): + app1 = PyNestFactory.create(TestModule) + app2 = PyNestFactory.create(TestModule) + svc1 = app1.container.get(MessageService) + svc2 = app2.container.get(MessageService) + assert svc1 is not svc2 + + +def test_service_dep_is_instance_attribute_not_class(): + app = PyNestFactory.create(TestModule) + ctrl = app.container.get_controller_instance(TestController) + assert "svc" in ctrl.__dict__ + assert "svc" not in TestController.__dict__