Skip to content

Commit bf1df27

Browse files
committed
Added SharedExecutor.
1 parent de9aa4c commit bf1df27

File tree

3 files changed

+87
-16
lines changed

3 files changed

+87
-16
lines changed

README.md

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1-
### `django_threaded_sync_to_async`
1+
## `django_threaded_sync_to_async`
22

3-
Under executor context, replaces `sync_to_async` calls to `sync_to_async(thread_sensitive=None, executor=...)`, effectively allowing Django to make calls to database concurrently:
3+
FIXME add description
4+
5+
### `Executor`
6+
7+
Under executor context, `Executor` replaces `sync_to_async` calls to `sync_to_async(thread_sensitive=None, executor=...)`, effectively allowing Django to make calls to database concurrently:
48

59
```python3
610
async with django_threaded_sync_to_async.Executor(thread_name_prefix="thread", max_workers=3) as executor:
@@ -10,3 +14,17 @@ async with django_threaded_sync_to_async.Executor(thread_name_prefix="thread", m
1014
d = asgiref.sync.sync_to_async(long_call)(4)
1115
await asyncio.gather(a, b, c, d)
1216
```
17+
18+
### `SharedExecutor`
19+
20+
Maintains global dictionary of executors (`concurrent.futures.ThreadPoolExecutor`) accessed by name and allows to limit utilization of executor for a single context.
21+
22+
```python3
23+
@django_threaded_sync_to_async.SharedExecutor("common", max_workers=3, max_tasks=2)
24+
def operations():
25+
a = asgiref.sync.sync_to_async(long_call)(1)
26+
b = asgiref.sync.sync_to_async(long_call)(2)
27+
c = asgiref.sync.sync_to_async(long_call)(3)
28+
d = asgiref.sync.sync_to_async(long_call)(4)
29+
await asyncio.gather(a, b, c, d)
30+
```

django_threaded_sync_to_async/__init__.py

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -68,37 +68,66 @@ def one_time_patch(obj, attr, value):
6868
delattr(obj, f"__{attr}__count__")
6969

7070

71+
_current_executor = contextvars.ContextVar("current_executor", default=None)
72+
_max_tasks_semaphore = contextvars.ContextVar("max_tasks_semaphore", default=None)
73+
74+
7175
async def _sync_to_async_call(self, orig, *args, **kwargs):
72-
if (executor := _get_current_executor()) is not None:
76+
if (executor := _current_executor.get()) is not None:
7377
self = asgiref.sync.SyncToAsync(self.func, thread_sensitive=False, executor=executor)
7478

7579
else:
7680
"""
7781
The task is called outside of executor's scope (or in different context).
7882
"""
7983

80-
return await orig(self, *args, **kwargs)
84+
if _max_tasks_semaphore.get() is not None:
85+
with _max_tasks_semaphore.get():
86+
return await orig(self, *args, **kwargs)
8187

82-
83-
_current_executor = contextvars.ContextVar("current_executor", default=None)
88+
else:
89+
return await orig(self, *args, **kwargs)
8490

8591

8692
@contextlib.contextmanager
87-
def _set_current_executor(value):
88-
token = _current_executor.set(value)
93+
def _set_context_variable(variable, value):
94+
token = variable.set(value)
8995
yield
90-
_current_executor.reset(token)
96+
variable.reset(token)
9197

9298

93-
def _get_current_executor():
94-
return _current_executor.get()
99+
@contextlib.contextmanager
100+
def _use_executor(executor):
101+
with _set_context_variable(_current_executor, executor):
102+
# Can be replaced by a single call to `setattr(asgiref.sync.SyncToAsync, "__call__", new_call)`
103+
# if we don't care about restoring everything back.
104+
new_call = functools.partialmethod(_sync_to_async_call, asgiref.sync.SyncToAsync.__call__)
105+
with one_time_patch(asgiref.sync.SyncToAsync, "__call__", new_call):
106+
yield executor
95107

96108

97109
@contextlib.asynccontextmanager
98110
async def Executor(*args, **kwargs):
99111
with concurrent.futures.ThreadPoolExecutor(*args, **kwargs) as executor:
100-
with _set_current_executor(executor):
101-
# It can be replaced by a single call to `setattr(obj, attr, value)` if we don't care about restoring everything back.
102-
new_call = functools.partialmethod(_sync_to_async_call, asgiref.sync.SyncToAsync.__call__)
103-
with one_time_patch(asgiref.sync.SyncToAsync, "__call__", new_call):
104-
yield executor
112+
with _use_executor(executor):
113+
yield executor
114+
115+
116+
_shared_executors = {}
117+
_shared_executors_lock = threading.Lock()
118+
119+
120+
@contextlib.asynccontextmanager
121+
async def SharedExecutor(name, *args, max_tasks=None, **kwargs):
122+
with _shared_executors_lock:
123+
if name in _shared_executors:
124+
executor = _shared_executors[name]
125+
if "max_workers" in kwargs:
126+
executor._max_workers = max(kwargs["max_workers"], executor._max_workers)
127+
else:
128+
kwargs.setdefault("thread_name_prefix", name)
129+
executor = _shared_executors[name] = concurrent.futures.ThreadPoolExecutor(*args, **kwargs)
130+
131+
with _set_context_variable(_max_tasks_semaphore, max_tasks and threading.Semaphore(max_tasks)):
132+
with _use_executor(executor):
133+
yield executor

tests/test_shared_executor.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import unittest
2+
3+
import asgiref.sync
4+
5+
import django_threaded_sync_to_async
6+
7+
8+
class TestSharedExecutor(unittest.IsolatedAsyncioTestCase):
9+
async def testSimple(self):
10+
async with django_threaded_sync_to_async.SharedExecutor("common") as executor:
11+
pass
12+
13+
with self.subTest(same_name=True):
14+
async with django_threaded_sync_to_async.SharedExecutor("common") as another_executor:
15+
self.assertIs(executor, another_executor)
16+
17+
with self.subTest(same_name=False):
18+
async with django_threaded_sync_to_async.SharedExecutor("specific") as specific_executor:
19+
self.assertIsNot(executor, specific_executor)
20+
21+
async def testMaxTasks(self):
22+
async with django_threaded_sync_to_async.SharedExecutor("common", max_tasks=2, max_workers=3) as executor:
23+
# TODO Make proper test on `max_tasks`.
24+
self.assertEqual(await asgiref.sync.sync_to_async(lambda: 42)(), 42)

0 commit comments

Comments
 (0)