diff --git a/architecture/container-lifecycle.md b/architecture/container-lifecycle.md index 7bc8d8c..999c723 100644 --- a/architecture/container-lifecycle.md +++ b/architecture/container-lifecycle.md @@ -14,17 +14,19 @@ entry point an application calls at startup. It does three things: 2. Registers the two *context providers* (`fastapi_request_provider`, `fastapi_websocket_provider`) on the container's `providers_registry`, so the live `Request` / `WebSocket` can be resolved. -3. Chains an internal lifespan manager onto the app's existing - `lifespan_context` via `fastapi.routing._merge_lifespan_context`, preserving - any lifespan the app already had. +3. Composes the container's open/close onto the app's existing + `lifespan_context` via `_compose_lifespan`, preserving any lifespan the app + already had (its startup/shutdown still run and its yielded state passes + through). The composition is our own — no dependency on FastAPI internals. It returns the same container for convenience. The application owns container construction (groups, overrides); `setup_di` only wires it in. ## Lifespan — open/close across cycles -The chained `_lifespan_manager` runs `async with fetch_di_container(app):` — the -root container's `__aenter__` opens it on startup and `__aexit__` closes it on +The composed lifespan keeps the original as the outer context and opens the +container inside it with `async with fetch_di_container(app):` — the root +container's `__aenter__` opens it on startup and `__aexit__` closes it on shutdown. Using `async with` (rather than a one-shot open) means a **second lifespan cycle against the same container reopens it** instead of raising `ContainerClosedError`. This is what lets an app be started, stopped, and diff --git a/modern_di_fastapi/main.py b/modern_di_fastapi/main.py index 0f3d0b4..3c04c5b 100644 --- a/modern_di_fastapi/main.py +++ b/modern_di_fastapi/main.py @@ -3,9 +3,9 @@ import typing import fastapi -from fastapi.routing import _merge_lifespan_context from modern_di import Container, Scope, providers from starlette.requests import HTTPConnection +from starlette.types import Lifespan T_co = typing.TypeVar("T_co", covariant=True) @@ -25,23 +25,29 @@ def fetch_di_container(app_: fastapi.FastAPI) -> Container: return typing.cast(Container, app_.state.di_container) -@contextlib.asynccontextmanager -async def _lifespan_manager(app_: fastapi.FastAPI) -> typing.AsyncIterator[None]: - # ``async with`` reopens the root container on each startup (``__aenter__``) - # and closes it on shutdown, so a second lifespan cycle against the same - # container works instead of raising ContainerClosedError. - async with fetch_di_container(app_): - yield +def _compose_lifespan(original: Lifespan[fastapi.FastAPI]) -> Lifespan[fastapi.FastAPI]: + """Wrap ``original`` so the root container opens/closes around it. + + The original lifespan stays the outer context and its yielded state passes + straight through; the container is opened inside it. ``async with`` reopens the + container on each startup and closes it on shutdown, so a second lifespan cycle + against the same container works instead of raising ``ContainerClosedError``. + """ + + @contextlib.asynccontextmanager + async def composed(app_: fastapi.FastAPI) -> typing.AsyncIterator[typing.Mapping[str, typing.Any] | None]: + async with original(app_) as state, fetch_di_container(app_): + yield state + + # ``Lifespan`` is a union of CM[None] | CM[Mapping]; it can't express our + # CM[Mapping | None], though that is exactly what a lifespan may yield. + return typing.cast(Lifespan[fastapi.FastAPI], composed) def setup_di(app: fastapi.FastAPI, container: Container) -> Container: app.state.di_container = container container.providers_registry.add_providers(*_CONNECTION_PROVIDERS) - old_lifespan_manager = app.router.lifespan_context - app.router.lifespan_context = _merge_lifespan_context( - old_lifespan_manager, - _lifespan_manager, - ) + app.router.lifespan_context = _compose_lifespan(app.router.lifespan_context) return container diff --git a/tests/test_lifespan.py b/tests/test_lifespan.py index 4080470..d24f1a8 100644 --- a/tests/test_lifespan.py +++ b/tests/test_lifespan.py @@ -1,9 +1,12 @@ +import contextlib import typing import fastapi +import modern_di from starlette import status from starlette.testclient import TestClient +import modern_di_fastapi from modern_di_fastapi import FromDI, fetch_di_container from tests.dependencies import Dependencies, SimpleCreator @@ -23,3 +26,31 @@ async def read_root(instance: typing.Annotated[SimpleCreator, FromDI(Dependencie # Second cycle must reopen the same container instead of raising ContainerClosedError. with TestClient(app=app) as client: assert client.get("/").status_code == status.HTTP_200_OK + + +def test_setup_di_composes_with_existing_lifespan() -> None: + events: list[str] = [] + + @contextlib.asynccontextmanager + async def user_lifespan(app_: fastapi.FastAPI) -> typing.AsyncIterator[dict[str, str]]: + assert isinstance(app_, fastapi.FastAPI) + events.append("startup") + yield {"marker": "from-user-lifespan"} + events.append("shutdown") + + app = fastapi.FastAPI(lifespan=user_lifespan) + container = modern_di.Container(groups=[Dependencies]) + modern_di_fastapi.setup_di(app, container) + + @app.get("/") + async def read_marker(request: fastapi.Request) -> str: + return typing.cast(str, request.state.marker) + + with TestClient(app=app) as client: + # original lifespan started; our container opened; its yielded state passes through + assert events == ["startup"] + assert not container.closed + assert client.get("/").json() == "from-user-lifespan" + # shutdown ran the original lifespan and closed our container + assert events == ["startup", "shutdown"] + assert container.closed