|
1 | 1 | import typing |
| 2 | +from importlib.metadata import version |
2 | 3 | from types import TracebackType |
3 | | -from typing import Any, Optional |
| 4 | +from typing import Any, Final, Optional |
4 | 5 |
|
5 | | -from faststream import BaseMiddleware |
6 | | -from typing_extensions import override |
| 6 | +from packaging.version import Version |
| 7 | +from typing_extensions import deprecated, override |
7 | 8 |
|
8 | 9 | from that_depends import container_context |
9 | 10 | from that_depends.providers.context_resources import ContextScope, SupportsContext |
10 | 11 | from that_depends.utils import UNSET, Unset, is_set |
11 | 12 |
|
12 | 13 |
|
13 | | -class DIContextMiddleware(BaseMiddleware): |
14 | | - """Initializes the container context for faststream brokers.""" |
15 | | - |
16 | | - def __init__( |
17 | | - self, |
18 | | - *context_items: SupportsContext[Any], |
19 | | - global_context: dict[str, Any] | Unset = UNSET, |
20 | | - scope: ContextScope | Unset = UNSET, |
21 | | - ) -> None: |
22 | | - """Initialize the container context middleware. |
23 | | -
|
24 | | - Args: |
25 | | - *context_items (SupportsContext[Any]): Context items to initialize. |
26 | | - global_context (dict[str, Any] | Unset): Global context to initialize the container. |
27 | | - scope (ContextScope | Unset): Context scope to initialize the container. |
28 | | -
|
29 | | - """ |
30 | | - super().__init__() |
31 | | - self._context: container_context | None = None |
32 | | - self._context_items = set(context_items) |
33 | | - self._global_context = global_context |
34 | | - self._scope = scope |
35 | | - |
36 | | - @override |
37 | | - async def on_receive(self) -> None: |
38 | | - self._context = container_context( |
39 | | - *self._context_items, |
40 | | - scope=self._scope if is_set(self._scope) else None, |
41 | | - global_context=self._global_context if is_set(self._global_context) else None, |
42 | | - ) |
43 | | - await self._context.__aenter__() |
44 | | - |
45 | | - @override |
46 | | - async def after_processed( |
47 | | - self, |
48 | | - exc_type: type[BaseException] | None = None, |
49 | | - exc_val: BaseException | None = None, |
50 | | - exc_tb: Optional["TracebackType"] = None, |
51 | | - ) -> bool | None: |
52 | | - if self._context is not None: |
53 | | - await self._context.__aexit__(exc_type, exc_val, exc_tb) |
54 | | - return None |
55 | | - |
56 | | - def __call__(self, *args: typing.Any, **kwargs: typing.Any) -> "DIContextMiddleware": # noqa: ARG002, ANN401 |
57 | | - """Create an instance of DIContextMiddleware.""" |
58 | | - return DIContextMiddleware(*self._context_items, scope=self._scope, global_context=self._global_context) |
| 14 | +_FASTSTREAM_MODULE_NAME: Final[str] = "faststream" |
| 15 | +_FASTSTREAM_VERSION: Final[str] = version(_FASTSTREAM_MODULE_NAME) |
| 16 | +if Version(_FASTSTREAM_VERSION) >= Version("0.6.0"): # pragma: no cover |
| 17 | + from faststream import BaseMiddleware, ContextRepo |
| 18 | + from faststream._internal.types import AnyMsg |
| 19 | + |
| 20 | + class DIContextMiddleware(BaseMiddleware): |
| 21 | + """Initializes the container context for faststream brokers.""" |
| 22 | + |
| 23 | + def __init__( |
| 24 | + self, |
| 25 | + *context_items: SupportsContext[Any], |
| 26 | + msg: AnyMsg | None = None, |
| 27 | + context: Optional["ContextRepo"] = None, |
| 28 | + global_context: dict[str, Any] | Unset = UNSET, |
| 29 | + scope: ContextScope | Unset = UNSET, |
| 30 | + ) -> None: |
| 31 | + """Initialize the container context middleware. |
| 32 | +
|
| 33 | + Args: |
| 34 | + *context_items (SupportsContext[Any]): Context items to initialize. |
| 35 | + msg (Any): Message object. |
| 36 | + context (ContextRepo): Context repository. |
| 37 | + global_context (dict[str, Any] | Unset): Global context to initialize the container. |
| 38 | + scope (ContextScope | Unset): Context scope to initialize the container. |
| 39 | +
|
| 40 | + """ |
| 41 | + super().__init__(msg, context=context) # type: ignore[arg-type] |
| 42 | + self._context: container_context | None = None |
| 43 | + self._context_items = set(context_items) |
| 44 | + self._global_context = global_context |
| 45 | + self._scope = scope |
| 46 | + |
| 47 | + @override |
| 48 | + async def on_receive(self) -> None: |
| 49 | + self._context = container_context( |
| 50 | + *self._context_items, |
| 51 | + scope=self._scope if is_set(self._scope) else None, |
| 52 | + global_context=self._global_context if is_set(self._global_context) else None, |
| 53 | + ) |
| 54 | + await self._context.__aenter__() |
| 55 | + |
| 56 | + @override |
| 57 | + async def after_processed( |
| 58 | + self, |
| 59 | + exc_type: type[BaseException] | None = None, |
| 60 | + exc_val: BaseException | None = None, |
| 61 | + exc_tb: Optional["TracebackType"] = None, |
| 62 | + ) -> bool | None: |
| 63 | + if self._context is not None: |
| 64 | + await self._context.__aexit__(exc_type, exc_val, exc_tb) |
| 65 | + return None |
| 66 | + |
| 67 | + def __call__(self, msg: Any = None, **kwargs: Any) -> "DIContextMiddleware": # noqa: ANN401 |
| 68 | + """Create an instance of DIContextMiddleware. |
| 69 | +
|
| 70 | + Args: |
| 71 | + msg (Any): Message object. |
| 72 | + **kwargs: Additional keyword arguments. |
| 73 | +
|
| 74 | + Returns: |
| 75 | + DIContextMiddleware: A new instance of DIContextMiddleware. |
| 76 | +
|
| 77 | + """ |
| 78 | + context = kwargs.get("context") |
| 79 | + |
| 80 | + return DIContextMiddleware( |
| 81 | + *self._context_items, |
| 82 | + msg=msg, |
| 83 | + context=context, |
| 84 | + scope=self._scope, |
| 85 | + global_context=self._global_context, |
| 86 | + ) |
| 87 | +else: # pragma: no cover |
| 88 | + from faststream import BaseMiddleware |
| 89 | + |
| 90 | + @deprecated("Will be removed with faststream v1") |
| 91 | + class DIContextMiddleware(BaseMiddleware): # type: ignore[no-redef] |
| 92 | + """Initializes the container context for faststream brokers.""" |
| 93 | + |
| 94 | + def __init__( |
| 95 | + self, |
| 96 | + *context_items: SupportsContext[Any], |
| 97 | + global_context: dict[str, Any] | Unset = UNSET, |
| 98 | + scope: ContextScope | Unset = UNSET, |
| 99 | + ) -> None: |
| 100 | + """Initialize the container context middleware. |
| 101 | +
|
| 102 | + Args: |
| 103 | + *context_items (SupportsContext[Any]): Context items to initialize. |
| 104 | + global_context (dict[str, Any] | Unset): Global context to initialize the container. |
| 105 | + scope (ContextScope | Unset): Context scope to initialize the container. |
| 106 | +
|
| 107 | + """ |
| 108 | + super().__init__() # type: ignore[call-arg] |
| 109 | + self._context: container_context | None = None |
| 110 | + self._context_items = set(context_items) |
| 111 | + self._global_context = global_context |
| 112 | + self._scope = scope |
| 113 | + |
| 114 | + @override |
| 115 | + async def on_receive(self) -> None: |
| 116 | + self._context = container_context( |
| 117 | + *self._context_items, |
| 118 | + scope=self._scope if is_set(self._scope) else None, |
| 119 | + global_context=self._global_context if is_set(self._global_context) else None, |
| 120 | + ) |
| 121 | + await self._context.__aenter__() |
| 122 | + |
| 123 | + @override |
| 124 | + async def after_processed( |
| 125 | + self, |
| 126 | + exc_type: type[BaseException] | None = None, |
| 127 | + exc_val: BaseException | None = None, |
| 128 | + exc_tb: Optional["TracebackType"] = None, |
| 129 | + ) -> bool | None: |
| 130 | + if self._context is not None: |
| 131 | + await self._context.__aexit__(exc_type, exc_val, exc_tb) |
| 132 | + return None |
| 133 | + |
| 134 | + def __call__(self, *args: typing.Any, **kwargs: typing.Any) -> "DIContextMiddleware": # noqa: ARG002, ANN401 |
| 135 | + """Create an instance of DIContextMiddleware.""" |
| 136 | + return DIContextMiddleware(*self._context_items, scope=self._scope, global_context=self._global_context) |
0 commit comments