Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ jobs:
- "3.12"
- "3.13"
- "3.14"
faststream-version:
- "<0.6.0"
- ">=0.6.0"
steps:
- uses: actions/checkout@v5
- uses: extractions/setup-just@v3
Expand All @@ -41,7 +44,7 @@ jobs:
cache-dependency-glob: "**/pyproject.toml"
- run: uv python install ${{ matrix.python-version }}
- run: just install
- run: just test . --cov=. --cov-report xml
- run: uv run --with "faststream${{ matrix.faststream-version }}" pytest . --cov=. --cov-report xml
- uses: codecov/[email protected]
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ fastapi = [
"fastapi",
]
faststream = [
"faststream<0.6.0"
"faststream"
]

[project.urls]
Expand Down
15 changes: 13 additions & 2 deletions tests/integrations/faststream/test_faststream_di_pass_message.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import typing

from faststream import BaseMiddleware, Context, Depends
from faststream.broker.message import StreamMessage
from faststream.nats import NatsBroker, TestNatsBroker
from faststream.nats.message import NatsMessage
from packaging.version import Version

from that_depends import BaseContainer, container_context, fetch_context_item, providers
from that_depends.integrations.faststream import _FASTSTREAM_VERSION


if Version(_FASTSTREAM_VERSION) >= Version("0.6.0"): # pragma: no cover
from faststream.message import StreamMessage
else: # pragma: no cover
from faststream.broker.message import StreamMessage # type: ignore[import-not-found, no-redef]


class ContextMiddleware(BaseMiddleware):
Expand All @@ -18,7 +25,11 @@ async def consume_scope(
return await call_next(msg)


broker = NatsBroker(middlewares=(ContextMiddleware,), validate=False)
if Version(_FASTSTREAM_VERSION) >= Version("0.6.0"): # pragma: no cover
broker = NatsBroker(middlewares=(ContextMiddleware,))

else: # pragma: no cover
broker = NatsBroker(middlewares=(ContextMiddleware,), validate=False) # type: ignore[call-arg]

TEST_SUBJECT = "test"

Expand Down
176 changes: 127 additions & 49 deletions that_depends/integrations/faststream.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,136 @@
import typing
from importlib.metadata import version
from types import TracebackType
from typing import Any, Optional
from typing import Any, Final, Optional

from faststream import BaseMiddleware
from typing_extensions import override
from packaging.version import Version
from typing_extensions import deprecated, override

from that_depends import container_context
from that_depends.providers.context_resources import ContextScope, SupportsContext
from that_depends.utils import UNSET, Unset, is_set


class DIContextMiddleware(BaseMiddleware):
"""Initializes the container context for faststream brokers."""

def __init__(
self,
*context_items: SupportsContext[Any],
global_context: dict[str, Any] | Unset = UNSET,
scope: ContextScope | Unset = UNSET,
) -> None:
"""Initialize the container context middleware.

Args:
*context_items (SupportsContext[Any]): Context items to initialize.
global_context (dict[str, Any] | Unset): Global context to initialize the container.
scope (ContextScope | Unset): Context scope to initialize the container.

"""
super().__init__()
self._context: container_context | None = None
self._context_items = set(context_items)
self._global_context = global_context
self._scope = scope

@override
async def on_receive(self) -> None:
self._context = container_context(
*self._context_items,
scope=self._scope if is_set(self._scope) else None,
global_context=self._global_context if is_set(self._global_context) else None,
)
await self._context.__aenter__()

@override
async def after_processed(
self,
exc_type: type[BaseException] | None = None,
exc_val: BaseException | None = None,
exc_tb: Optional["TracebackType"] = None,
) -> bool | None:
if self._context is not None:
await self._context.__aexit__(exc_type, exc_val, exc_tb)
return None

def __call__(self, *args: typing.Any, **kwargs: typing.Any) -> "DIContextMiddleware": # noqa: ARG002, ANN401
"""Create an instance of DIContextMiddleware."""
return DIContextMiddleware(*self._context_items, scope=self._scope, global_context=self._global_context)
_FASTSTREAM_MODULE_NAME: Final[str] = "faststream"
_FASTSTREAM_VERSION: Final[str] = version(_FASTSTREAM_MODULE_NAME)
if Version(_FASTSTREAM_VERSION) >= Version("0.6.0"): # pragma: no cover
from faststream import BaseMiddleware, ContextRepo
from faststream._internal.types import AnyMsg

class DIContextMiddleware(BaseMiddleware):
"""Initializes the container context for faststream brokers."""

def __init__(
self,
*context_items: SupportsContext[Any],
msg: AnyMsg | None = None,
context: Optional["ContextRepo"] = None,
global_context: dict[str, Any] | Unset = UNSET,
scope: ContextScope | Unset = UNSET,
) -> None:
"""Initialize the container context middleware.

Args:
*context_items (SupportsContext[Any]): Context items to initialize.
msg (Any): Message object.
context (ContextRepo): Context repository.
global_context (dict[str, Any] | Unset): Global context to initialize the container.
scope (ContextScope | Unset): Context scope to initialize the container.

"""
super().__init__(msg, context=context) # type: ignore[arg-type]
self._context: container_context | None = None
self._context_items = set(context_items)
self._global_context = global_context
self._scope = scope

@override
async def on_receive(self) -> None:
self._context = container_context(
*self._context_items,
scope=self._scope if is_set(self._scope) else None,
global_context=self._global_context if is_set(self._global_context) else None,
)
await self._context.__aenter__()

@override
async def after_processed(
self,
exc_type: type[BaseException] | None = None,
exc_val: BaseException | None = None,
exc_tb: Optional["TracebackType"] = None,
) -> bool | None:
if self._context is not None:
await self._context.__aexit__(exc_type, exc_val, exc_tb)
return None

def __call__(self, msg: Any = None, **kwargs: Any) -> "DIContextMiddleware": # noqa: ANN401
"""Create an instance of DIContextMiddleware.

Args:
msg (Any): Message object.
**kwargs: Additional keyword arguments.

Returns:
DIContextMiddleware: A new instance of DIContextMiddleware.

"""
context = kwargs.get("context")

return DIContextMiddleware(
*self._context_items,
msg=msg,
context=context,
scope=self._scope,
global_context=self._global_context,
)
else: # pragma: no cover
from faststream import BaseMiddleware

@deprecated("Will be removed with faststream v1")
class DIContextMiddleware(BaseMiddleware): # type: ignore[no-redef]
"""Initializes the container context for faststream brokers."""

def __init__(
self,
*context_items: SupportsContext[Any],
global_context: dict[str, Any] | Unset = UNSET,
scope: ContextScope | Unset = UNSET,
) -> None:
"""Initialize the container context middleware.

Args:
*context_items (SupportsContext[Any]): Context items to initialize.
global_context (dict[str, Any] | Unset): Global context to initialize the container.
scope (ContextScope | Unset): Context scope to initialize the container.

"""
super().__init__() # type: ignore[call-arg]
self._context: container_context | None = None
self._context_items = set(context_items)
self._global_context = global_context
self._scope = scope

@override
async def on_receive(self) -> None:
self._context = container_context(
*self._context_items,
scope=self._scope if is_set(self._scope) else None,
global_context=self._global_context if is_set(self._global_context) else None,
)
await self._context.__aenter__()

@override
async def after_processed(
self,
exc_type: type[BaseException] | None = None,
exc_val: BaseException | None = None,
exc_tb: Optional["TracebackType"] = None,
) -> bool | None:
if self._context is not None:
await self._context.__aexit__(exc_type, exc_val, exc_tb)
return None

def __call__(self, *args: typing.Any, **kwargs: typing.Any) -> "DIContextMiddleware": # noqa: ARG002, ANN401
"""Create an instance of DIContextMiddleware."""
return DIContextMiddleware(*self._context_items, scope=self._scope, global_context=self._global_context)