Skip to content

Add support for async generator injections #900

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 13 additions & 18 deletions src/dependency_injector/_cwiring.pyi
Original file line number Diff line number Diff line change
@@ -1,23 +1,18 @@
from typing import Any, Awaitable, Callable, Dict, Tuple, TypeVar
from typing import Any, Dict

from .providers import Provider

T = TypeVar("T")
class DependencyResolver:
def __init__(
self,
kwargs: Dict[str, Any],
injections: Dict[str, Provider[Any]],
closings: Dict[str, Provider[Any]],
/,
) -> None: ...
def __enter__(self) -> Dict[str, Any]: ...
def __exit__(self, *exc_info: Any) -> None: ...
async def __aenter__(self) -> Dict[str, Any]: ...
async def __aexit__(self, *exc_info: Any) -> None: ...

def _sync_inject(
fn: Callable[..., T],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
injections: Dict[str, Provider[Any]],
closings: Dict[str, Provider[Any]],
/,
) -> T: ...
async def _async_inject(
fn: Callable[..., Awaitable[T]],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
injections: Dict[str, Provider[Any]],
closings: Dict[str, Provider[Any]],
/,
) -> T: ...
def _isawaitable(instance: Any) -> bool: ...
157 changes: 92 additions & 65 deletions src/dependency_injector/_cwiring.pyx
Original file line number Diff line number Diff line change
@@ -1,83 +1,110 @@
"""Wiring optimizations module."""

import asyncio
import collections.abc
import inspect
import types
from asyncio import gather
from collections.abc import Awaitable
from inspect import CO_ITERABLE_COROUTINE
from types import CoroutineType, GeneratorType

from .providers cimport Provider, Resource, NULL_AWAITABLE
from .wiring import _Marker

from .providers cimport Provider, Resource
cimport cython


def _sync_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /):
cdef object result
@cython.internal
@cython.no_gc
cdef class KWPair:
cdef str name
cdef object value

def __cinit__(self, str name, object value, /):
self.name = name
self.value = value


cdef inline bint _is_injectable(dict kwargs, str name):
return name not in kwargs or isinstance(kwargs[name], _Marker)


cdef class DependencyResolver:
cdef dict kwargs
cdef dict to_inject
cdef object arg_key
cdef Provider provider
cdef dict injections
cdef dict closings

to_inject = kwargs.copy()
for arg_key, provider in injections.items():
if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker):
to_inject[arg_key] = provider()
def __init__(self, dict kwargs, dict injections, dict closings, /):
self.kwargs = kwargs
self.to_inject = kwargs.copy()
self.injections = injections
self.closings = closings

result = fn(*args, **to_inject)
async def _await_injection(self, kw_pair: KWPair, /) -> None:
self.to_inject[kw_pair.name] = await kw_pair.value

if closings:
for arg_key, provider in closings.items():
if arg_key in kwargs and not isinstance(kwargs[arg_key], _Marker):
continue
if not isinstance(provider, Resource):
continue
provider.shutdown()
cdef object _await_injections(self, to_await: list):
return gather(*map(self._await_injection, to_await))

return result
cdef void _handle_injections_sync(self):
cdef Provider provider

for name, provider in self.injections.items():
if _is_injectable(self.kwargs, name):
self.to_inject[name] = provider()

async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /):
cdef object result
cdef dict to_inject
cdef list to_inject_await = []
cdef list to_close_await = []
cdef object arg_key
cdef Provider provider

to_inject = kwargs.copy()
for arg_key, provider in injections.items():
if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker):
provide = provider()
if provider.is_async_mode_enabled():
to_inject_await.append((arg_key, provide))
elif _isawaitable(provide):
to_inject_await.append((arg_key, provide))
else:
to_inject[arg_key] = provide

if to_inject_await:
async_to_inject = await asyncio.gather(*(provide for _, provide in to_inject_await))
for provide, (injection, _) in zip(async_to_inject, to_inject_await):
to_inject[injection] = provide

result = await fn(*args, **to_inject)

if closings:
for arg_key, provider in closings.items():
if arg_key in kwargs and isinstance(kwargs[arg_key], _Marker):
continue
if not isinstance(provider, Resource):
continue
shutdown = provider.shutdown()
if _isawaitable(shutdown):
to_close_await.append(shutdown)

await asyncio.gather(*to_close_await)

return result
cdef list _handle_injections_async(self):
cdef list to_await = []
cdef Provider provider

for name, provider in self.injections.items():
if _is_injectable(self.kwargs, name):
provide = provider()

if provider.is_async_mode_enabled() or _isawaitable(provide):
to_await.append(KWPair(name, provide))
else:
self.to_inject[name] = provide

return to_await

cdef void _handle_closings_sync(self):
cdef Provider provider

for name, provider in self.closings.items():
if _is_injectable(self.kwargs, name) and isinstance(provider, Resource):
provider.shutdown()

cdef list _handle_closings_async(self):
cdef list to_await = []
cdef Provider provider

for name, provider in self.closings.items():
if _is_injectable(self.kwargs, name) and isinstance(provider, Resource):
if _isawaitable(shutdown := provider.shutdown()):
to_await.append(shutdown)

return to_await

def __enter__(self):
self._handle_injections_sync()
return self.to_inject

def __exit__(self, *_):
self._handle_closings_sync()

async def __aenter__(self):
if to_await := self._handle_injections_async():
await self._await_injections(to_await)
return self.to_inject

def __aexit__(self, *_):
if to_await := self._handle_closings_async():
return gather(*to_await)
return NULL_AWAITABLE


cdef bint _isawaitable(object instance):
"""Return true if object can be passed to an ``await`` expression."""
return (isinstance(instance, types.CoroutineType) or
isinstance(instance, types.GeneratorType) and
bool(instance.gi_code.co_flags & inspect.CO_ITERABLE_COROUTINE) or
isinstance(instance, collections.abc.Awaitable))
return (isinstance(instance, CoroutineType) or
isinstance(instance, GeneratorType) and
bool(instance.gi_code.co_flags & CO_ITERABLE_COROUTINE) or
isinstance(instance, Awaitable))
44 changes: 26 additions & 18 deletions src/dependency_injector/wiring.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
Dict,
Iterable,
Expand Down Expand Up @@ -720,6 +721,8 @@ def _get_patched(

if inspect.iscoroutinefunction(fn):
patched = _get_async_patched(fn, patched_object)
elif inspect.isasyncgenfunction(fn):
patched = _get_async_gen_patched(fn, patched_object)
else:
patched = _get_sync_patched(fn, patched_object)

Expand Down Expand Up @@ -1035,36 +1038,41 @@ def is_loader_installed() -> bool:
_loader = AutoLoader()

# Optimizations
from ._cwiring import _async_inject # noqa
from ._cwiring import _sync_inject # noqa
from ._cwiring import DependencyResolver # noqa: E402


# Wiring uses the following Python wrapper because there is
# no possibility to compile a first-type citizen coroutine in Cython.
def _get_async_patched(fn: F, patched: PatchedCallable) -> F:
@functools.wraps(fn)
async def _patched(*args, **kwargs):
return await _async_inject(
fn,
args,
kwargs,
patched.injections,
patched.closing,
)
async def _patched(*args: Any, **raw_kwargs: Any) -> Any:
resolver = DependencyResolver(raw_kwargs, patched.injections, patched.closing)

async with resolver as kwargs:
return await fn(*args, **kwargs)

return cast(F, _patched)


def _get_async_gen_patched(fn: F, patched: PatchedCallable) -> F:
@functools.wraps(fn)
async def _patched(*args: Any, **raw_kwargs: Any) -> AsyncIterator[Any]:
resolver = DependencyResolver(raw_kwargs, patched.injections, patched.closing)

async with resolver as kwargs:
async for obj in fn(*args, **kwargs):
yield obj

return cast(F, _patched)


def _get_sync_patched(fn: F, patched: PatchedCallable) -> F:
@functools.wraps(fn)
def _patched(*args, **kwargs):
return _sync_inject(
fn,
args,
kwargs,
patched.injections,
patched.closing,
)
def _patched(*args: Any, **raw_kwargs: Any) -> Any:
resolver = DependencyResolver(raw_kwargs, patched.injections, patched.closing)

with resolver as kwargs:
return fn(*args, **kwargs)

return cast(F, _patched)

Expand Down
13 changes: 12 additions & 1 deletion tests/unit/samples/wiring/asyncinjections.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asyncio

from typing_extensions import Annotated

from dependency_injector import containers, providers
from dependency_injector.wiring import inject, Provide, Closing
from dependency_injector.wiring import Closing, Provide, inject


class TestResource:
Expand Down Expand Up @@ -42,6 +44,15 @@ async def async_injection(
return resource1, resource2


@inject
async def async_generator_injection(
resource1: object = Provide[Container.resource1],
resource2: object = Closing[Provide[Container.resource2]],
):
yield resource1
yield resource2


@inject
async def async_injection_with_closing(
resource1: object = Closing[Provide[Container.resource1]],
Expand Down
17 changes: 17 additions & 0 deletions tests/unit/wiring/provider_ids/test_async_injections_py36.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,23 @@ async def test_async_injections():
assert asyncinjections.resource2.shutdown_counter == 0


@mark.asyncio
async def test_async_generator_injections() -> None:
resources = []

async for resource in asyncinjections.async_generator_injection():
resources.append(resource)

assert len(resources) == 2
assert resources[0] is asyncinjections.resource1
assert asyncinjections.resource1.init_counter == 1
assert asyncinjections.resource1.shutdown_counter == 0

assert resources[1] is asyncinjections.resource2
assert asyncinjections.resource2.init_counter == 1
assert asyncinjections.resource2.shutdown_counter == 1


@mark.asyncio
async def test_async_injections_with_closing():
resource1, resource2 = await asyncinjections.async_injection_with_closing()
Expand Down