Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions modern_di_faststream/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@

faststream_message_provider = providers.ContextProvider(scope=Scope.REQUEST, context_type=faststream.StreamMessage)

# Keys under which the containers live in FastStream's ``ContextRepo``. Each is
# written in one place and read in another; naming them keeps writer and reader
# in provable agreement instead of relying on two matching string literals.
_ROOT_CONTAINER_KEY = "di_container"
_REQUEST_CONTAINER_KEY = "request_container"


class _DIMiddlewareFactory:
__slots__ = ("di_container",)
Expand All @@ -41,7 +47,7 @@ async def consume_scope(
scope=modern_di.Scope.REQUEST, context={faststream.StreamMessage: msg}
)
try:
with self.context.scope("request_container", request_container):
with self.context.scope(_REQUEST_CONTAINER_KEY, request_container):
return typing.cast(
typing.AsyncIterator[DecodedMessage],
await call_next(msg),
Expand All @@ -51,7 +57,7 @@ async def consume_scope(


def fetch_di_container(app_: faststream.FastStream | AsgiFastStream) -> Container:
return typing.cast(Container, app_.context.get("di_container"))
return typing.cast(Container, app_.context.get(_ROOT_CONTAINER_KEY))


def setup_di(
Expand All @@ -63,7 +69,7 @@ def setup_di(
raise RuntimeError(msg)

container.providers_registry.add_providers(faststream_message_provider)
app.context.set_global("di_container", container)
app.context.set_global(_ROOT_CONTAINER_KEY, container)
# FastStream's lifecycle is callback-based, so the root container can't be
# wrapped in ``async with``. Reopen it on startup (before the broker consumes)
# to pair with the shutdown close, so a broker restart works instead of
Expand All @@ -80,7 +86,7 @@ class Dependency(typing.Generic[T_co]):
dependency: providers.AbstractProvider[T_co] | type[T_co]

async def __call__(self, context: faststream.ContextRepo) -> T_co:
request_container: Container = context.get("request_container")
request_container: Container = context.get(_REQUEST_CONTAINER_KEY)
if isinstance(self.dependency, providers.AbstractProvider):
return request_container.resolve_provider(self.dependency)
return request_container.resolve(dependency_type=self.dependency)
Expand Down