Skip to content

Commit c305cc5

Browse files
filipnestakFilip Nešťákdavidparsson
authored
Allow annotated in classes (#263)
* Allow annotated in classes --------- Co-authored-by: Filip Nešťák <[email protected]> Co-authored-by: David Pärsson <[email protected]>
1 parent 7b0d446 commit c305cc5

File tree

2 files changed

+58
-1
lines changed

2 files changed

+58
-1
lines changed

injector/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1208,6 +1208,11 @@ def _is_new_union_type(instance: Any) -> bool:
12081208
new_union_type = getattr(types, 'UnionType', None)
12091209
return new_union_type is not None and isinstance(instance, new_union_type)
12101210

1211+
def _is_injection_annotation(annotation: Any) -> bool:
1212+
return _is_specialization(annotation, Annotated) and (
1213+
_inject_marker in annotation.__metadata__ or _noinject_marker in annotation.__metadata__
1214+
)
1215+
12111216
spec = inspect.getfullargspec(callable)
12121217

12131218
try:
@@ -1238,7 +1243,8 @@ def _is_new_union_type(instance: Any) -> bool:
12381243
bindings.pop(spec.varkw, None)
12391244

12401245
for k, v in list(bindings.items()):
1241-
if _is_specialization(v, Annotated):
1246+
# extract metadata only from Inject and NonInject
1247+
if _is_injection_annotation(v):
12421248
v, metadata = v.__origin__, v.__metadata__
12431249
bindings[k] = v
12441250
else:

injector_test.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"""Functional tests for the "Injector" dependency injection framework."""
1212

1313
from contextlib import contextmanager
14+
from dataclasses import dataclass
1415
from typing import Any, NewType, Optional, Union
1516
import abc
1617
import sys
@@ -1754,3 +1755,53 @@ def configure(binder):
17541755
injector = Injector([configure])
17551756
assert injector.get(foo) == 123
17561757
assert injector.get(bar) == 456
1758+
1759+
1760+
def test_annotated_integration_with_annotated():
1761+
UserID = Annotated[int, 'user_id']
1762+
1763+
@inject
1764+
class TestClass:
1765+
def __init__(self, user_id: UserID):
1766+
self.user_id = user_id
1767+
1768+
def configure(binder):
1769+
binder.bind(UserID, to=123)
1770+
1771+
injector = Injector([configure])
1772+
1773+
test_class = injector.get(TestClass)
1774+
assert test_class.user_id == 123
1775+
1776+
1777+
def test_newtype_integration_with_annotated():
1778+
UserID = NewType('UserID', int)
1779+
1780+
@inject
1781+
class TestClass:
1782+
def __init__(self, user_id: UserID):
1783+
self.user_id = user_id
1784+
1785+
def configure(binder):
1786+
binder.bind(UserID, to=123)
1787+
1788+
injector = Injector([configure])
1789+
1790+
test_class = injector.get(TestClass)
1791+
assert test_class.user_id == 123
1792+
1793+
1794+
def test_dataclass_annotated_parameter():
1795+
Foo = Annotated[int, object()]
1796+
1797+
def configure(binder):
1798+
binder.bind(Foo, to=123)
1799+
1800+
@inject
1801+
@dataclass
1802+
class MyClass:
1803+
foo: Foo
1804+
1805+
injector = Injector([configure])
1806+
instance = injector.get(MyClass)
1807+
assert instance.foo == 123

0 commit comments

Comments
 (0)