diff --git a/modern_di_faststream/main.py b/modern_di_faststream/main.py index 72c4f82..5fa3a2e 100644 --- a/modern_di_faststream/main.py +++ b/modern_di_faststream/main.py @@ -15,6 +15,12 @@ faststream_message_provider = providers.ContextProvider(scope=Scope.REQUEST, context_type=faststream.StreamMessage) +# Keys under which the containers live in FastStream's ``ContextRepo``. Each is +# written in one place and read in another; naming them keeps writer and reader +# in provable agreement instead of relying on two matching string literals. +_ROOT_CONTAINER_KEY = "di_container" +_REQUEST_CONTAINER_KEY = "request_container" + class _DIMiddlewareFactory: __slots__ = ("di_container",) @@ -41,7 +47,7 @@ async def consume_scope( scope=modern_di.Scope.REQUEST, context={faststream.StreamMessage: msg} ) try: - with self.context.scope("request_container", request_container): + with self.context.scope(_REQUEST_CONTAINER_KEY, request_container): return typing.cast( typing.AsyncIterator[DecodedMessage], await call_next(msg), @@ -51,7 +57,7 @@ async def consume_scope( def fetch_di_container(app_: faststream.FastStream | AsgiFastStream) -> Container: - return typing.cast(Container, app_.context.get("di_container")) + return typing.cast(Container, app_.context.get(_ROOT_CONTAINER_KEY)) def setup_di( @@ -63,7 +69,7 @@ def setup_di( raise RuntimeError(msg) container.providers_registry.add_providers(faststream_message_provider) - app.context.set_global("di_container", container) + app.context.set_global(_ROOT_CONTAINER_KEY, container) # FastStream's lifecycle is callback-based, so the root container can't be # wrapped in ``async with``. Reopen it on startup (before the broker consumes) # to pair with the shutdown close, so a broker restart works instead of @@ -80,7 +86,7 @@ class Dependency(typing.Generic[T_co]): dependency: providers.AbstractProvider[T_co] | type[T_co] async def __call__(self, context: faststream.ContextRepo) -> T_co: - request_container: Container = context.get("request_container") + request_container: Container = context.get(_REQUEST_CONTAINER_KEY) if isinstance(self.dependency, providers.AbstractProvider): return request_container.resolve_provider(self.dependency) return request_container.resolve(dependency_type=self.dependency)