diff --git a/docs/testing.md b/docs/testing.md new file mode 100644 index 0000000..f8b1e6a --- /dev/null +++ b/docs/testing.md @@ -0,0 +1,297 @@ +# Testing + +PyNest ships a built-in testing toolkit — `nest.testing` — so you can test your +application through the real dependency-injection container instead of +hand-wiring objects or monkey-patching. It is modeled after +[`@nestjs/testing`](https://docs.nestjs.com/fundamentals/testing) and +[Suites](https://docs.nestjs.com/recipes/suites). + +```bash +pip install "pynest-api[testing]" # adds httpx for the in-process HTTP client +``` + +Everything below works with plain `pytest`; async examples use +[`pytest-asyncio`](https://pytest-asyncio.readthedocs.io/). + +## Quick start + +Given a typical module: + +```python +from nest.core import Controller, Get, Injectable, Module + +@Injectable +class UserRepository: + def find_all(self): + ... # hits a real database + +@Injectable +class UserService: + def __init__(self, repo: UserRepository): + self.repo = repo + + def list_users(self): + return self.repo.find_all() + +@Controller("/users") +class UserController: + def __init__(self, service: UserService): + self.service = service + + @Get("/") + def list_users(self): + return {"users": self.service.list_users()} + +@Module(controllers=[UserController], providers=[UserService, UserRepository]) +class UserModule: + pass +``` + +Build a real container for it in one line: + +```python +from nest.testing import PyNestTestingModule + +def test_list_users(): + module = PyNestTestingModule.create_testing_module(imports=[UserModule]).compile() + service = module.get(UserService) + assert service.list_users() == [...] +``` + +`create_testing_module()` accepts the same metadata as `@Module` — +`imports`, `controllers`, `providers`, `exports` — and returns a fluent +`TestingModuleBuilder`. `compile()` validates the dependency graph, builds the +injector, and returns a `TestingModule`. It works both synchronously and with +`await`: + +```python +module = builder.compile() # sync tests +module = await builder.compile() # async tests +``` + +## Overriding providers + +Swap any provider for a test double — the rest of the graph stays real: + +```python +class FakeUserRepository: + def find_all(self): + return ["alice"] + +def test_service_with_fake_repo(): + module = ( + PyNestTestingModule.create_testing_module(imports=[UserModule]) + .override_provider(UserRepository) + .use_value(FakeUserRepository()) + .compile() + ) + assert module.get(UserService).list_users() == ["alice"] +``` + +Three override strategies are available, mirroring provider definitions: + +```python +builder.override_provider(UserRepository).use_value(FakeUserRepository()) +builder.override_provider(UserRepository).use_class(InMemoryUserRepository) +builder.override_provider(ConfigService).use_factory(lambda: FakeConfig({"db": "sqlite://"})) +``` + +Overriding a token that is not registered anywhere in the module graph raises a +`ValueError` listing the known providers. + +## Auto-mocking + +Replace every class-based provider with a spec'd mock +(`unittest.mock.create_autospec`) in one call. Async methods automatically +become `AsyncMock`s: + +```python +def test_controller_with_all_services_mocked(): + module = ( + PyNestTestingModule.create_testing_module(imports=[UserModule]) + .use_auto_mock(exclude=[UserService]) # keep UserService real + .compile() + ) + repo = module.get(UserRepository) # this is a mock + repo.find_all.return_value = ["mocked"] + assert module.get(UserService).list_users() == ["mocked"] +``` + +Explicit `override_provider()` calls always win over auto-mocking, and +controllers are never mocked. Pass `mock_factory=` to control how mocks are +created. + +## HTTP testing without a server + +`create_http_client()` returns an `httpx.AsyncClient` wired straight into the +FastAPI app over ASGI — no sockets, no server process: + +```python +import pytest + +@pytest.mark.asyncio +async def test_list_users_http(): + module = PyNestTestingModule.create_testing_module(imports=[UserModule]).compile() + async with module.create_http_client() as client: + response = await client.get("/users") + assert response.status_code == 200 +``` + +For synchronous tests, `create_test_client()` returns FastAPI's `TestClient`: + +```python +def test_list_users_sync(): + module = PyNestTestingModule.create_testing_module(imports=[UserModule]).compile() + client = module.create_test_client() + assert client.get("/users").status_code == 200 +``` + +Need middleware or global filters? `create_nest_application(**fastapi_kwargs)` +returns the full `PyNestApp` first. + +## Overriding guards + +Bypass (or tighten) authentication in HTTP tests without touching headers: + +```python +class AlwaysPassGuard(BaseGuard): + def can_activate(self, request, credentials=None) -> bool: + return True + +@pytest.mark.asyncio +async def test_protected_route(): + module = ( + PyNestTestingModule.create_testing_module(imports=[UserModule]) + .override_guard(AuthGuard) + .use_class(AlwaysPassGuard) # or .use_value(AlwaysPassGuard()) + .compile() + ) + async with module.create_http_client() as client: + assert (await client.get("/users/me")).status_code == 200 +``` + +Guard overrides apply to both controller-level and route-level `@UseGuards`, +and the original guard classes are restored automatically once the test app is +built — nothing leaks between tests. + +## Lifecycle hooks + +`compile()` does **not** run lifecycle hooks, so you stay in control: + +```python +@pytest.mark.asyncio +async def test_with_lifecycle(): + module = PyNestTestingModule.create_testing_module(imports=[DbModule]).compile() + await module.init() # on_module_init / on_application_bootstrap + ... + await module.close() # shutdown hooks, in reverse module order +``` + +Or let the async context manager do it: + +```python +async with PyNestTestingModule.create_testing_module(imports=[DbModule]).compile() as module: + service = module.get(DbService) # init() already ran +# close() ran here +``` + +## pytest fixture patterns + +A reusable module-per-test fixture: + +```python +import pytest_asyncio +from nest.testing import PyNestTestingModule + +@pytest_asyncio.fixture +async def user_module(): + module = ( + PyNestTestingModule.create_testing_module(imports=[UserModule]) + .override_provider(UserRepository) + .use_class(InMemoryUserRepository) + .compile() + ) + await module.init() + yield module + await module.close() + + +@pytest.mark.asyncio +async def test_list_users(user_module): + assert user_module.get(UserService).list_users() == [] + + +@pytest.mark.asyncio +async def test_create_user_http(user_module): + async with user_module.create_http_client() as client: + response = await client.post("/users", json={"name": "Bob"}) + assert response.status_code == 200 +``` + +Every compiled `TestingModule` owns a fresh container, so tests are fully +isolated — no shared singletons between them. + +## TestBed: Suites-style unit tests + +For pure unit tests, skip modules entirely. `TestBed.solitary()` builds the +unit under test with **every constructor dependency auto-mocked**: + +```python +from nest.testing import TestBed + +def test_get_users_solitary(): + unit, unit_ref = TestBed.solitary(UserService).compile() + + unit_ref.get(UserRepository).find_all.return_value = ["alice"] + + assert unit.get_users() == ["alice"] + unit_ref.get(UserRepository).find_all.assert_called_once() +``` + +- `unit` is a real `UserService`; `unit_ref.get(Dep)` returns the mock that was + injected for `Dep`. +- Async dependency methods become `AsyncMock`s, so `assert_awaited_once_with` + works out of the box. + +Customize mocks inline with the fluent chain: + +```python +unit, unit_ref = ( + TestBed.solitary(UserService) + .mock(UserRepository).using(find_all=Mock(return_value=["bob"])) + .mock(CacheService).final(InMemoryCache()) # use a real object instead + .compile() +) +``` + +`TestBed.sociable()` keeps chosen collaborators real while mocking everything +deeper in the graph — ideal for testing a service together with its closest +collaborator: + +```python +def test_service_with_real_repo(): + unit, unit_ref = ( + TestBed.sociable(UserService) + .expose(UserRepository) # real UserRepository + .compile() + ) + unit_ref.get(Database).query.return_value = ["row"] # repo's own dep is mocked + assert unit.get_users() == ["row"] +``` + +## API summary + +| API | Purpose | +|---|---| +| `PyNestTestingModule.create_testing_module(...)` | Start a testing-module builder (alias: `Test`) | +| `.override_provider(token).use_value/.use_class/.use_factory` | Replace a provider | +| `.override_guard(Guard).use_value/.use_class` | Replace a guard | +| `.use_auto_mock(exclude=..., mock_factory=...)` | Mock all class providers | +| `.compile()` | Build the container (sync or `await`) | +| `module.get(token)` | Resolve a provider or controller | +| `module.init()` / `module.close()` | Run lifecycle hooks | +| `module.create_http_client()` | In-process `httpx.AsyncClient` | +| `module.create_test_client()` | Synchronous `TestClient` | +| `module.create_nest_application()` | Full `PyNestApp` | +| `TestBed.solitary(Unit)` / `TestBed.sociable(Unit).expose(...)` | Suites-style unit tests | +| `.mock(Dep).using(**attrs)` / `.mock(Dep).final(obj)` | Customize TestBed mocks | diff --git a/mkdocs.yml b/mkdocs.yml index 202ee6d..5d7060b 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -60,6 +60,7 @@ nav: - Guards: guards.md - Exception Filters: exception_filters.md - WebSockets: websockets.md + - Testing: testing.md - Dependency Injection: dependency_injection.md - Deployment: - Docker: docker.md diff --git a/nest/core/pynest_container.py b/nest/core/pynest_container.py index 120dc2c..06df663 100644 --- a/nest/core/pynest_container.py +++ b/nest/core/pynest_container.py @@ -117,6 +117,30 @@ def get_controller_instance(self, controller_class: Type) -> Any: """Get a controller instance with all its service dependencies injected.""" return self.get(controller_class) + @property + def provider_descriptors(self) -> List[ProviderDescriptor]: + """Snapshot of every registered provider descriptor.""" + return list(self._all_descriptors) + + def replace_provider( + self, token: Union[Type, InjectionToken, str], descriptor: ProviderDescriptor + ) -> int: + """ + Replace every registered descriptor bound to `token` with `descriptor`. + Must be called after add_module() and before build(). + Returns the number of descriptors replaced (0 if the token is unknown). + """ + if self._injector is not None: + raise RuntimeError( + "Container already built. replace_provider() must be called before build()." + ) + count = 0 + for index, desc in enumerate(self._all_descriptors): + if desc.provide == token: + self._all_descriptors[index] = descriptor + count += 1 + return count + def clear(self) -> None: """Reset container state. Useful in tests.""" self._injector = None diff --git a/nest/testing/__init__.py b/nest/testing/__init__.py new file mode 100644 index 0000000..4fbbbd9 --- /dev/null +++ b/nest/testing/__init__.py @@ -0,0 +1,39 @@ +""" +PyNest testing utilities. + +Build real DI containers for tests without booting an HTTP server, override +providers and guards with mocks, and drive in-process HTTP requests — +inspired by @nestjs/testing and Suites. +""" + +from nest.testing.test_bed import ( + MockChain, + TestBed, + TestBedBuilder, + UnitRef, + UnitTestBed, +) +from nest.testing.testing_module import ( + GuardOverrideBy, + OverrideBy, + PyNestTestingModule, + Test, + TestingModule, + TestingModuleBuilder, + create_auto_mock, +) + +__all__ = [ + "PyNestTestingModule", + "Test", + "TestingModule", + "TestingModuleBuilder", + "OverrideBy", + "GuardOverrideBy", + "create_auto_mock", + "TestBed", + "TestBedBuilder", + "MockChain", + "UnitRef", + "UnitTestBed", +] diff --git a/nest/testing/test_bed.py b/nest/testing/test_bed.py new file mode 100644 index 0000000..7f7c901 --- /dev/null +++ b/nest/testing/test_bed.py @@ -0,0 +1,182 @@ +""" +Suites-style unit-testing harness (https://suites.dev). + +TestBed builds the unit under test with every constructor dependency replaced +by a spec'd mock — no module graph, no container, no manual wiring:: + + unit, unit_ref = TestBed.solitary(UserService).compile() + unit_ref.get(UserRepository).find_all.return_value = ["alice"] + assert unit.get_users() == ["alice"] + +Sociable mode keeps selected collaborators real while mocking the rest:: + + unit, unit_ref = TestBed.sociable(UserService).expose(UserRepository).compile() +""" + +from __future__ import annotations + +import inspect +import typing +from dataclasses import dataclass +from typing import Any, Dict, Iterator, List, Set, Type +from unittest.mock import create_autospec + + +class UnitRef: + """Read-only registry of the dependencies built for the unit under test.""" + + def __init__(self, instances: Dict[Type, Any]) -> None: + self._instances = instances + + def get(self, token: Type) -> Any: + if token not in self._instances: + known = ", ".join( + sorted(getattr(t, "__name__", repr(t)) for t in self._instances) + ) + raise KeyError( + f"{getattr(token, '__name__', token)!r} was not injected into the " + f"unit under test. Known dependencies: {known or '(none)'}" + ) + return self._instances[token] + + +@dataclass +class UnitTestBed: + """Result of TestBed.compile(): the real unit and a ref to its dependencies.""" + + unit: Any + unit_ref: UnitRef + + def __iter__(self) -> Iterator[Any]: + yield self.unit + yield self.unit_ref + + +class MockChain: + """Second half of the .mock(Dep) fluent chain.""" + + def __init__(self, builder: "TestBedBuilder", dep: Type) -> None: + self._builder = builder + self._dep = dep + + def using(self, **attributes: Any) -> "TestBedBuilder": + """Configure attributes on the auto-generated mock (configure_mock style).""" + self._builder._mock_attrs.setdefault(self._dep, {}).update(attributes) + return self._builder + + def final(self, instance: Any) -> "TestBedBuilder": + """Use `instance` as the dependency instead of generating a mock.""" + self._builder._finals[self._dep] = instance + return self._builder + + +class TestBedBuilder: + def __init__(self, unit_class: Type, sociable: bool) -> None: + self._unit_class = unit_class + self._sociable = sociable + self._exposed: List[Type] = [] + self._finals: Dict[Type, Any] = {} + self._mock_attrs: Dict[Type, Dict[str, Any]] = {} + + def mock(self, dep: Type) -> MockChain: + """Start a mock-customization chain for one dependency.""" + return MockChain(self, dep) + + def expose(self, *deps: Type) -> "TestBedBuilder": + """Keep these dependencies real (sociable mode only).""" + if not self._sociable: + raise RuntimeError( + "expose() is only available on TestBed.sociable(). " + "Solitary units mock every dependency." + ) + self._exposed.extend(deps) + return self + + def compile(self) -> UnitTestBed: + registry: Dict[Type, Any] = {} + building: Set[Type] = set() + unit = self._build_real(self._unit_class, registry, building) + return UnitTestBed(unit=unit, unit_ref=UnitRef(registry)) + + # ── internal ─────────────────────────────────────────────────────────────── + + def _resolve(self, dep_type: Type, registry: Dict, building: Set) -> Any: + if dep_type in registry: + return registry[dep_type] + + if dep_type in self._finals: + instance = self._finals[dep_type] + elif dep_type in self._exposed: + instance = self._build_real(dep_type, registry, building) + else: + instance = create_autospec(dep_type, instance=True) + if dep_type in self._mock_attrs: + instance.configure_mock(**self._mock_attrs[dep_type]) + + registry[dep_type] = instance + return instance + + def _build_real(self, cls: Type, registry: Dict, building: Set) -> Any: + if cls in building: + chain = " → ".join(c.__name__ for c in building) + raise RuntimeError( + f"Circular dependency while building {cls.__name__} ({chain})" + ) + building.add(cls) + try: + kwargs = { + name: self._resolve(dep_type, registry, building) + for name, dep_type in _constructor_dependencies(cls).items() + } + return cls(**kwargs) + finally: + building.discard(cls) + + +def _constructor_dependencies(cls: Type) -> Dict[str, Type]: + """Map constructor parameter names to their annotated types.""" + try: + signature = inspect.signature(cls.__init__) + except (TypeError, ValueError): + return {} + + try: + hints = typing.get_type_hints(cls.__init__) + except Exception: + hints = getattr(cls.__init__, "__annotations__", {}) or {} + + dependencies: Dict[str, Type] = {} + for name, param in signature.parameters.items(): + if name == "self" or param.kind in ( + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + ): + continue + annotation = hints.get(name, param.annotation) + if annotation is inspect.Parameter.empty or not isinstance(annotation, type): + if param.default is inspect.Parameter.empty: + raise TypeError( + f"Cannot build {cls.__name__}: constructor parameter " + f"{name!r} has no usable type annotation and no default." + ) + continue + dependencies[name] = annotation + return dependencies + + +class TestBed: + """ + Entry point for Suites-style unit tests. + + - `TestBed.solitary(Unit)` — the unit is real, every dependency is mocked. + - `TestBed.sociable(Unit).expose(Collaborator)` — exposed collaborators stay + real (their own dependencies are mocked); everything else is mocked. + """ + + @staticmethod + def solitary(unit_class: Type) -> TestBedBuilder: + return TestBedBuilder(unit_class, sociable=False) + + @staticmethod + def sociable(unit_class: Type) -> TestBedBuilder: + return TestBedBuilder(unit_class, sociable=True) diff --git a/nest/testing/testing_module.py b/nest/testing/testing_module.py new file mode 100644 index 0000000..d0e0223 --- /dev/null +++ b/nest/testing/testing_module.py @@ -0,0 +1,392 @@ +from __future__ import annotations + +from typing import Any, Callable, Dict, List, Optional, Type, Union +from unittest.mock import create_autospec + +from nest.common.provider import InjectionToken, ProviderDescriptor, Scope +from nest.core.decorators.guards import BaseGuard +from nest.core.pynest_container import PyNestContainer + +ProviderToken = Union[Type, InjectionToken, str] + +_MISSING = object() + + +def _token_name(token: Any) -> str: + return ( + getattr(token, "__name__", None) + or getattr(token, "name", None) + or repr(token) + ) + + +def create_auto_mock(cls: Type) -> Any: + """Create a spec'd mock instance of `cls`. Async methods become AsyncMocks.""" + return create_autospec(cls, instance=True) + + +class OverrideBy: + """Second half of the override_provider() fluent chain.""" + + def __init__(self, builder: "TestingModuleBuilder", token: ProviderToken) -> None: + self._builder = builder + self._token = token + + def use_value(self, value: Any) -> "TestingModuleBuilder": + """Replace the provider with a pre-built instance.""" + return self._builder._add_provider_override( + self._token, ProviderDescriptor(provide=self._token, use_value=value) + ) + + def use_class(self, cls: Type) -> "TestingModuleBuilder": + """Replace the provider with a different class (constructor-injected).""" + return self._builder._add_provider_override( + self._token, + ProviderDescriptor(provide=self._token, use_class=cls, scope=Scope.SINGLETON), + ) + + def use_factory( + self, factory: Callable, inject: Optional[List[Any]] = None + ) -> "TestingModuleBuilder": + """Replace the provider with the result of a factory function.""" + return self._builder._add_provider_override( + self._token, + ProviderDescriptor( + provide=self._token, use_factory=factory, inject=list(inject or []) + ), + ) + + +class GuardOverrideBy: + """Second half of the override_guard() fluent chain.""" + + def __init__(self, builder: "TestingModuleBuilder", guard: Type) -> None: + self._builder = builder + self._guard = guard + + def use_value(self, value: Any) -> "TestingModuleBuilder": + """Replace the guard with a pre-built guard instance.""" + return self._builder._add_guard_override( + self._guard, _guard_class_from_value(value) + ) + + def use_class(self, cls: Type) -> "TestingModuleBuilder": + """Replace the guard with a different guard class.""" + return self._builder._add_guard_override(self._guard, cls) + + +def _guard_class_from_value(value: Any) -> Type: + """Wrap a guard instance in a class so it fits the class-based guard pipeline.""" + + class _ValueGuard(BaseGuard): + security_scheme = getattr(value, "security_scheme", None) + + def __new__(cls): + return value + + _ValueGuard.__name__ = f"Override({type(value).__name__})" + return _ValueGuard + + +class TestingModuleBuilder: + """ + Fluent builder for a test DI container, mirroring NestJS's Test.createTestingModule. + + Example:: + + module = ( + PyNestTestingModule.create_testing_module(imports=[UserModule]) + .override_provider(UserRepository).use_value(FakeUserRepository()) + .compile() + ) + """ + + def __init__( + self, + imports: Optional[List[Type]] = None, + controllers: Optional[List[Type]] = None, + providers: Optional[List[Any]] = None, + exports: Optional[List[Any]] = None, + ) -> None: + self._imports = list(imports or []) + self._controllers = list(controllers or []) + self._providers = list(providers or []) + self._exports = list(exports or []) + self._provider_overrides: List[tuple] = [] + self._guard_overrides: Dict[Type, Type] = {} + self._auto_mock = False + self._auto_mock_exclude: List[Any] = [] + self._auto_mock_factory: Optional[Callable[[Type], Any]] = None + + # ── overrides ────────────────────────────────────────────────────────────── + + def override_provider(self, token: ProviderToken) -> OverrideBy: + """Start an override chain for a provider token.""" + return OverrideBy(self, token) + + def override_guard(self, guard: Type) -> GuardOverrideBy: + """Start an override chain for a guard class applied via @UseGuards.""" + return GuardOverrideBy(self, guard) + + def use_auto_mock( + self, + exclude: Optional[List[Any]] = None, + mock_factory: Optional[Callable[[Type], Any]] = None, + ) -> "TestingModuleBuilder": + """ + Replace every class-based provider with a spec'd mock + (unittest.mock.create_autospec; async methods become AsyncMocks). + + Explicit override_provider() calls and tokens in `exclude` keep + their real implementation. Controllers are never mocked. + """ + self._auto_mock = True + self._auto_mock_exclude = list(exclude or []) + self._auto_mock_factory = mock_factory + return self + + def _add_provider_override( + self, token: ProviderToken, descriptor: ProviderDescriptor + ) -> "TestingModuleBuilder": + self._provider_overrides.append((token, descriptor)) + return self + + def _add_guard_override(self, guard: Type, replacement: Type) -> "TestingModuleBuilder": + self._guard_overrides[guard] = replacement + return self + + # ── compilation ──────────────────────────────────────────────────────────── + + def compile(self) -> "TestingModule": + """ + Build the container and return a TestingModule. + + Works both synchronously and with `await` (TestingModule is awaitable):: + + module = builder.compile() + module = await builder.compile() + """ + container = PyNestContainer() + container.add_module(self._build_root_module()) + + for token, descriptor in self._provider_overrides: + replaced = container.replace_provider(token, descriptor) + if not replaced: + known = ", ".join( + sorted(_token_name(d.provide) for d in container.provider_descriptors) + ) + raise ValueError( + f"Cannot override provider {_token_name(token)!r}: it is not " + f"registered in the testing module. Known providers: {known}" + ) + + if self._auto_mock: + self._apply_auto_mock(container) + + container.build() + return TestingModule(container, dict(self._guard_overrides)) + + def _build_root_module(self) -> Type: + return type( + "PyNestTestingRootModule", + (), + { + "imports": list(self._imports), + "controllers": list(self._controllers), + "providers": list(self._providers), + "exports": list(self._exports), + "__is_module__": True, + "__is_global__": False, + }, + ) + + def _apply_auto_mock(self, container: PyNestContainer) -> None: + overridden = [token for token, _ in self._provider_overrides] + make_mock = self._auto_mock_factory or create_auto_mock + + for desc in container.provider_descriptors: + if desc.use_class is None: + continue + if desc.provide in overridden: + continue + if desc.provide in self._auto_mock_exclude or ( + desc.use_class in self._auto_mock_exclude + ): + continue + if hasattr(desc.use_class, "__websocket_gateway__"): + continue + container.replace_provider( + desc.provide, + ProviderDescriptor(provide=desc.provide, use_value=make_mock(desc.use_class)), + ) + + +class TestingModule: + """ + A compiled test container. Resolve providers with get(), drive lifecycle + with init()/close(), and test HTTP behavior with create_http_client(). + + Awaitable (so `await builder.compile()` works) and usable as an async + context manager (init() on enter, close() on exit). + """ + + def __init__( + self, container: PyNestContainer, guard_overrides: Dict[Type, Type] + ) -> None: + self._container = container + self._guard_overrides = guard_overrides + self._app = None + self._guard_patches: List[tuple] = [] + + @property + def container(self) -> PyNestContainer: + return self._container + + def get(self, token: ProviderToken) -> Any: + """Retrieve a fully-wired instance (provider or controller) by token.""" + return self._container.get(token) + + async def init(self) -> "TestingModule": + """Run on_module_init / on_application_bootstrap lifecycle hooks.""" + await self._container.initialize_lifecycle() + return self + + async def close(self, signal: Optional[str] = None) -> None: + """Run shutdown lifecycle hooks and undo guard patches.""" + try: + await self._container.shutdown_lifecycle(signal) + finally: + self._restore_guards() + + def __await__(self): + async def _identity() -> "TestingModule": + return self + + return _identity().__await__() + + async def __aenter__(self) -> "TestingModule": + return await self.init() + + async def __aexit__(self, exc_type, exc, tb) -> None: + await self.close() + + # ── HTTP testing ─────────────────────────────────────────────────────────── + + def create_nest_application(self, **fastapi_kwargs: Any): + """ + Create (and cache) a full PyNestApp with all routes registered. + Guard overrides are applied before route registration. + """ + if self._app is None: + from fastapi import FastAPI + + from nest.core.pynest_application import PyNestApp + + # Guards are captured into route dependencies during PyNestApp + # construction, so the __guards__ patches can be undone right after — + # no global state leaks between testing modules. + self._apply_guard_overrides() + try: + self._app = PyNestApp(self._container, FastAPI(**fastapi_kwargs)) + finally: + self._restore_guards() + return self._app + + def create_http_client(self, **client_kwargs: Any): + """ + In-process async HTTP client (httpx.AsyncClient over ASGITransport). + No network, no server process. Requires the `testing` extra: + pip install "pynest-api[testing]". + """ + httpx = _import_httpx() + app = self.create_nest_application().get_server() + client_kwargs.setdefault("base_url", "http://testserver") + return httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), **client_kwargs + ) + + def create_test_client(self, **client_kwargs: Any): + """Synchronous in-process client (fastapi.testclient.TestClient).""" + from fastapi.testclient import TestClient + + return TestClient(self.create_nest_application().get_server(), **client_kwargs) + + # ── guard patching ───────────────────────────────────────────────────────── + + def _apply_guard_overrides(self) -> None: + if not self._guard_overrides: + return + for module_ref in self._container.modules.values(): + for controller in module_ref.compiled.controllers: + self._patch_guards(controller) + for member in vars(controller).values(): + if callable(member) and hasattr(member, "__guards__"): + self._patch_guards(member) + + def _patch_guards(self, owner: Any) -> None: + guards = getattr(owner, "__guards__", None) + if not guards: + return + patched = [self._guard_overrides.get(guard, guard) for guard in guards] + if patched == list(guards): + return + original = vars(owner).get("__guards__", _MISSING) + self._guard_patches.append((owner, original)) + setattr(owner, "__guards__", patched) + + def _restore_guards(self) -> None: + for owner, original in reversed(self._guard_patches): + if original is _MISSING: + try: + delattr(owner, "__guards__") + except AttributeError: + pass + else: + setattr(owner, "__guards__", original) + self._guard_patches.clear() + + +class PyNestTestingModule: + """ + Entry point for PyNest's testing utilities, mirroring NestJS's `Test` class. + + Example:: + + from nest.testing import PyNestTestingModule + + module = ( + PyNestTestingModule.create_testing_module(imports=[AppModule]) + .override_provider(Database).use_value(FakeDatabase()) + .compile() + ) + service = module.get(UserService) + """ + + @staticmethod + def create_testing_module( + imports: Optional[List[Type]] = None, + controllers: Optional[List[Type]] = None, + providers: Optional[List[Any]] = None, + exports: Optional[List[Any]] = None, + ) -> TestingModuleBuilder: + return TestingModuleBuilder( + imports=imports, + controllers=controllers, + providers=providers, + exports=exports, + ) + + +# NestJS-style alias: Test.create_testing_module(...) +Test = PyNestTestingModule + + +def _import_httpx(): + try: + import httpx + except ImportError as exc: # pragma: no cover + raise ImportError( + "httpx is required for create_http_client(). " + 'Install it with: pip install "pynest-api[testing]"' + ) from exc + return httpx diff --git a/pyproject.toml b/pyproject.toml index af6005f..5e98e3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,10 +50,14 @@ mongo = [ test = [ "pytest>=7.0.1,<8.0.0", ] +testing = [ + "httpx>=0.27.0,<1.0.0", +] [dependency-groups] test = [ "pytest>=7.0.1,<8.0.0", + "pytest-asyncio>=0.21.1,<1.0.0", "httpx>=0.27.0,<1.0.0", "sqlalchemy>=2.0.36,<3.0.0", "motor>=3.2.0,<4.0.0", diff --git a/tests/test_testing/__init__.py b/tests/test_testing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_testing/test_http_testing.py b/tests/test_testing/test_http_testing.py new file mode 100644 index 0000000..85ffafb --- /dev/null +++ b/tests/test_testing/test_http_testing.py @@ -0,0 +1,203 @@ +import pytest +from fastapi import Request + +from nest.core import ( + BaseGuard, + Body, + Controller, + Get, + Injectable, + Module, + Post, + UseGuards, +) +from nest.testing import PyNestTestingModule + + +@Injectable +class TodoService: + def __init__(self): + self.todos = [] + + def add(self, title: str): + todo = {"id": len(self.todos) + 1, "title": title} + self.todos.append(todo) + return todo + + def list_all(self): + return self.todos + + +class SecretHeaderGuard(BaseGuard): + def can_activate(self, request: Request, credentials=None) -> bool: + return request.headers.get("X-Secret") == "letmein" + + +class AlwaysPassGuard(BaseGuard): + def can_activate(self, request: Request, credentials=None) -> bool: + return True + + +class AlwaysDenyGuard(BaseGuard): + def can_activate(self, request: Request, credentials=None) -> bool: + return False + + +@Controller("/todos") +class TodoController: + def __init__(self, service: TodoService): + self.service = service + + @Get("/") + def list_todos(self): + return {"todos": self.service.list_all()} + + @Post("/") + def create_todo(self, title: str = Body("title")): + return self.service.add(title) + + @Get("/secret") + @UseGuards(SecretHeaderGuard) + def secret(self): + return {"secret": True} + + +@Module(controllers=[TodoController], providers=[TodoService]) +class TodoModule: + pass + + +# ── async http client ────────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_http_client_get(): + module = PyNestTestingModule.create_testing_module(imports=[TodoModule]).compile() + async with module.create_http_client() as client: + response = await client.get("/todos") + assert response.status_code == 200 + assert response.json() == {"todos": []} + + +@pytest.mark.asyncio +async def test_http_client_post_hits_real_service(): + module = PyNestTestingModule.create_testing_module(imports=[TodoModule]).compile() + async with module.create_http_client() as client: + response = await client.post("/todos", json={"title": "write tests"}) + assert response.status_code == 200 + assert response.json() == {"id": 1, "title": "write tests"} + assert module.get(TodoService).list_all() == [{"id": 1, "title": "write tests"}] + + +@pytest.mark.asyncio +async def test_http_client_with_overridden_provider(): + class FakeTodoService: + def list_all(self): + return [{"id": 99, "title": "faked"}] + + module = ( + PyNestTestingModule.create_testing_module(imports=[TodoModule]) + .override_provider(TodoService) + .use_value(FakeTodoService()) + .compile() + ) + async with module.create_http_client() as client: + response = await client.get("/todos") + assert response.json() == {"todos": [{"id": 99, "title": "faked"}]} + + +# ── sync test client ─────────────────────────────────────────────────────────── + + +def test_sync_test_client(): + module = PyNestTestingModule.create_testing_module(imports=[TodoModule]).compile() + client = module.create_test_client() + response = client.get("/todos") + assert response.status_code == 200 + assert response.json() == {"todos": []} + + +# ── guards ───────────────────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_guard_blocks_without_override(): + module = PyNestTestingModule.create_testing_module(imports=[TodoModule]).compile() + async with module.create_http_client() as client: + response = await client.get("/todos/secret") + assert response.status_code == 403 + + +@pytest.mark.asyncio +async def test_override_guard_use_class_bypasses_guard(): + module = ( + PyNestTestingModule.create_testing_module(imports=[TodoModule]) + .override_guard(SecretHeaderGuard) + .use_class(AlwaysPassGuard) + .compile() + ) + async with module.create_http_client() as client: + response = await client.get("/todos/secret") + assert response.status_code == 200 + assert response.json() == {"secret": True} + + +@pytest.mark.asyncio +async def test_override_guard_use_value(): + module = ( + PyNestTestingModule.create_testing_module(imports=[TodoModule]) + .override_guard(SecretHeaderGuard) + .use_value(AlwaysPassGuard()) + .compile() + ) + async with module.create_http_client() as client: + response = await client.get("/todos/secret") + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_guard_override_restored_after_close(): + module = ( + PyNestTestingModule.create_testing_module(imports=[TodoModule]) + .override_guard(SecretHeaderGuard) + .use_class(AlwaysPassGuard) + .compile() + ) + async with module.create_http_client() as client: + assert (await client.get("/todos/secret")).status_code == 200 + await module.close() + + # A fresh testing module must see the original guard again. + fresh = PyNestTestingModule.create_testing_module(imports=[TodoModule]).compile() + async with fresh.create_http_client() as client: + assert (await fresh_client_get(client)).status_code == 403 + + +async def fresh_client_get(client): + return await client.get("/todos/secret") + + +@pytest.mark.asyncio +async def test_controller_level_guard_override(): + @Controller("/locked") + @UseGuards(AlwaysDenyGuard) + class LockedController: + @Get("/") + def index(self): + return {"open": True} + + @Module(controllers=[LockedController], providers=[]) + class LockedModule: + pass + + module = ( + PyNestTestingModule.create_testing_module(imports=[LockedModule]) + .override_guard(AlwaysDenyGuard) + .use_class(AlwaysPassGuard) + .compile() + ) + async with module.create_http_client() as client: + response = await client.get("/locked") + assert response.status_code == 200 + assert response.json() == {"open": True} + await module.close() diff --git a/tests/test_testing/test_test_bed.py b/tests/test_testing/test_test_bed.py new file mode 100644 index 0000000..64ed2ac --- /dev/null +++ b/tests/test_testing/test_test_bed.py @@ -0,0 +1,158 @@ +import pytest +from unittest.mock import AsyncMock, Mock + +from nest.core import Injectable +from nest.testing import TestBed + + +@Injectable +class Database: + def query(self, sql: str): + raise RuntimeError("real database should never be hit in unit tests") + + +@Injectable +class CacheService: + def get(self, key: str): + return None + + def set(self, key: str, value): + pass + + +@Injectable +class UserRepository: + def __init__(self, db: Database): + self.db = db + + def find_all(self): + return self.db.query("select * from users") + + +@Injectable +class UserService: + def __init__(self, repo: UserRepository, cache: CacheService): + self.repo = repo + self.cache = cache + + def get_users(self): + cached = self.cache.get("users") + if cached is not None: + return cached + users = self.repo.find_all() + self.cache.set("users", users) + return users + + +@Injectable +class AsyncNotifier: + async def send(self, message: str): + raise RuntimeError("real notifier should never be hit") + + +@Injectable +class SignupService: + def __init__(self, notifier: AsyncNotifier): + self.notifier = notifier + + async def signup(self, email: str): + await self.notifier.send(f"welcome {email}") + return {"email": email} + + +# ── solitary ─────────────────────────────────────────────────────────────────── + + +def test_solitary_unit_is_real_and_deps_are_mocked(): + unit, unit_ref = TestBed.solitary(UserService).compile() + assert isinstance(unit, UserService) + + unit_ref.get(UserRepository).find_all.return_value = ["alice"] + unit_ref.get(CacheService).get.return_value = None + + assert unit.get_users() == ["alice"] + unit_ref.get(CacheService).set.assert_called_once_with("users", ["alice"]) + + +def test_unit_ref_returns_the_injected_mock(): + unit, unit_ref = TestBed.solitary(UserService).compile() + assert unit.repo is unit_ref.get(UserRepository) + assert unit.cache is unit_ref.get(CacheService) + + +def test_compile_result_supports_attribute_access(): + result = TestBed.solitary(UserService).compile() + assert isinstance(result.unit, UserService) + assert result.unit_ref.get(UserRepository) is result.unit.repo + + +@pytest.mark.asyncio +async def test_async_dependency_methods_become_async_mocks(): + unit, unit_ref = TestBed.solitary(SignupService).compile() + assert await unit.signup("a@b.com") == {"email": "a@b.com"} + unit_ref.get(AsyncNotifier).send.assert_awaited_once_with("welcome a@b.com") + + +def test_mock_using_configures_the_mock(): + unit, unit_ref = ( + TestBed.solitary(UserService) + .mock(UserRepository) + .using(find_all=Mock(return_value=["bob"])) + .mock(CacheService) + .using(get=Mock(return_value=None)) + .compile() + ) + assert unit.get_users() == ["bob"] + + +def test_mock_final_replaces_dependency_entirely(): + class InMemoryRepo: + def find_all(self): + return ["from-memory"] + + fake = InMemoryRepo() + unit, unit_ref = ( + TestBed.solitary(UserService) + .mock(UserRepository) + .final(fake) + .mock(CacheService) + .using(get=Mock(return_value=None)) + .compile() + ) + assert unit.repo is fake + assert unit_ref.get(UserRepository) is fake + assert unit.get_users() == ["from-memory"] + + +def test_unit_with_no_dependencies(): + unit, unit_ref = TestBed.solitary(Database).compile() + assert isinstance(unit, Database) + + +def test_unit_ref_unknown_dependency_raises(): + _, unit_ref = TestBed.solitary(UserService).compile() + with pytest.raises(KeyError, match="Database"): + unit_ref.get(Database) + + +# ── sociable ─────────────────────────────────────────────────────────────────── + + +def test_sociable_exposes_real_dependency(): + unit, unit_ref = ( + TestBed.sociable(UserService).expose(UserRepository).compile() + ) + # UserRepository is real; its own Database dependency is mocked. + assert isinstance(unit.repo, UserRepository) + unit_ref.get(Database).query.return_value = ["from-db"] + unit_ref.get(CacheService).get.return_value = None + + assert unit.get_users() == ["from-db"] + unit_ref.get(Database).query.assert_called_once_with("select * from users") + + +def test_sociable_exposed_dependency_accessible_via_unit_ref(): + unit, unit_ref = ( + TestBed.sociable(UserService).expose(UserRepository).compile() + ) + assert unit_ref.get(UserRepository) is unit.repo diff --git a/tests/test_testing/test_testing_module.py b/tests/test_testing/test_testing_module.py new file mode 100644 index 0000000..1cd086e --- /dev/null +++ b/tests/test_testing/test_testing_module.py @@ -0,0 +1,243 @@ +import pytest +from unittest.mock import MagicMock + +from nest.core import Controller, Get, Injectable, Module +from nest.testing import PyNestTestingModule + + +@Injectable +class UserRepository: + def find_all(self): + return ["alice", "bob"] + + +@Injectable +class UserService: + def __init__(self, repo: UserRepository): + self.repo = repo + + def list_users(self): + return self.repo.find_all() + + +@Controller("/users") +class UserController: + def __init__(self, service: UserService): + self.service = service + + @Get("/") + def list_users(self): + return {"users": self.service.list_users()} + + +@Module(controllers=[UserController], providers=[UserService, UserRepository]) +class UserModule: + pass + + +class FakeUserRepository: + def find_all(self): + return ["fake"] + + +# ── compile & get ────────────────────────────────────────────────────────────── + + +def test_compile_resolves_wired_provider(): + module = PyNestTestingModule.create_testing_module( + providers=[UserService, UserRepository] + ).compile() + service = module.get(UserService) + assert isinstance(service, UserService) + assert service.list_users() == ["alice", "bob"] + + +@pytest.mark.asyncio +async def test_compile_is_awaitable(): + module = await PyNestTestingModule.create_testing_module( + providers=[UserService, UserRepository] + ).compile() + assert isinstance(module.get(UserService), UserService) + + +def test_compile_from_imports_resolves_module_graph(): + module = PyNestTestingModule.create_testing_module(imports=[UserModule]).compile() + assert module.get(UserService).list_users() == ["alice", "bob"] + + +def test_get_controller_instance(): + module = PyNestTestingModule.create_testing_module(imports=[UserModule]).compile() + controller = module.get(UserController) + assert isinstance(controller, UserController) + assert controller.list_users() == {"users": ["alice", "bob"]} + + +def test_singleton_providers_return_same_instance(): + module = PyNestTestingModule.create_testing_module(imports=[UserModule]).compile() + assert module.get(UserService) is module.get(UserService) + + +def test_two_testing_modules_are_isolated(): + first = PyNestTestingModule.create_testing_module(imports=[UserModule]).compile() + second = PyNestTestingModule.create_testing_module(imports=[UserModule]).compile() + assert first.get(UserService) is not second.get(UserService) + + +# ── provider overrides ───────────────────────────────────────────────────────── + + +def test_override_provider_use_value(): + module = ( + PyNestTestingModule.create_testing_module(imports=[UserModule]) + .override_provider(UserRepository) + .use_value(FakeUserRepository()) + .compile() + ) + assert module.get(UserService).list_users() == ["fake"] + + +def test_override_provider_use_class(): + module = ( + PyNestTestingModule.create_testing_module(imports=[UserModule]) + .override_provider(UserRepository) + .use_class(FakeUserRepository) + .compile() + ) + assert module.get(UserService).list_users() == ["fake"] + + +def test_override_provider_use_factory(): + module = ( + PyNestTestingModule.create_testing_module(imports=[UserModule]) + .override_provider(UserRepository) + .use_factory(lambda: FakeUserRepository()) + .compile() + ) + assert module.get(UserService).list_users() == ["fake"] + + +def test_override_unknown_provider_raises(): + class NotRegistered: + pass + + builder = ( + PyNestTestingModule.create_testing_module(imports=[UserModule]) + .override_provider(NotRegistered) + .use_value(MagicMock()) + ) + with pytest.raises(ValueError, match="NotRegistered"): + builder.compile() + + +# ── auto-mock ────────────────────────────────────────────────────────────────── + + +def test_use_auto_mock_replaces_providers_with_mocks(): + module = ( + PyNestTestingModule.create_testing_module(imports=[UserModule]) + .use_auto_mock() + .compile() + ) + repo = module.get(UserRepository) + repo.find_all.return_value = ["mocked"] + assert repo.find_all() == ["mocked"] + + +def test_use_auto_mock_respects_exclude(): + module = ( + PyNestTestingModule.create_testing_module(imports=[UserModule]) + .use_auto_mock(exclude=[UserService]) + .compile() + ) + service = module.get(UserService) + assert isinstance(service, UserService) + service.repo.find_all.return_value = ["mocked"] + assert service.list_users() == ["mocked"] + + +def test_explicit_override_wins_over_auto_mock(): + module = ( + PyNestTestingModule.create_testing_module(imports=[UserModule]) + .use_auto_mock() + .override_provider(UserRepository) + .use_value(FakeUserRepository()) + .compile() + ) + assert module.get(UserRepository).find_all() == ["fake"] + + +@pytest.mark.asyncio +async def test_auto_mock_creates_async_mocks_for_async_methods(): + @Injectable + class AsyncRepository: + async def find_all(self): + return ["real"] + + @Module(providers=[AsyncRepository]) + class AsyncModule: + pass + + module = ( + PyNestTestingModule.create_testing_module(imports=[AsyncModule]) + .use_auto_mock() + .compile() + ) + repo = module.get(AsyncRepository) + repo.find_all.return_value = ["mocked"] + assert await repo.find_all() == ["mocked"] + + +# ── lifecycle ────────────────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_init_and_close_run_lifecycle_hooks(): + events = [] + + @Injectable + class LifecycleService: + def on_module_init(self): + events.append("init") + + def on_module_destroy(self): + events.append("destroy") + + @Module(providers=[LifecycleService]) + class LifecycleModule: + pass + + module = PyNestTestingModule.create_testing_module( + imports=[LifecycleModule] + ).compile() + assert events == [] + + await module.init() + assert events == ["init"] + + await module.close() + assert events == ["init", "destroy"] + + +@pytest.mark.asyncio +async def test_async_context_manager_runs_init_and_close(): + events = [] + + @Injectable + class CtxService: + def on_module_init(self): + events.append("init") + + def on_module_destroy(self): + events.append("destroy") + + @Module(providers=[CtxService]) + class CtxModule: + pass + + async with PyNestTestingModule.create_testing_module( + imports=[CtxModule] + ).compile() as module: + assert isinstance(module.get(CtxService), CtxService) + assert events == ["init"] + + assert events == ["init", "destroy"]