Skip to content

Commit 5a6ae2e

Browse files
committed
Replaced threading.Semaphore with asyncio.Semaphore, improved tests and cleaned everything up.
1 parent 02d29ea commit 5a6ae2e

File tree

8 files changed

+164
-97
lines changed

8 files changed

+164
-97
lines changed
Lines changed: 11 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import concurrent.futures
23
import contextlib
34
import contextvars
@@ -6,70 +7,13 @@
67

78
import asgiref.sync
89

9-
10-
_reentrant_patch_lock = threading.Lock()
11-
12-
13-
@contextlib.contextmanager
14-
def reentrant_patch(obj, attr, value):
15-
"""
16-
Makes time-aware patch on the attribute of the object without locking like in `unittest.mock.patch`, the context
17-
will leak system-wide.
18-
However, if no `await` happens after obtaining the context, and no threads are getting the same attribute,
19-
it guarantees that the attribute will have the desired value.
20-
Effectively guarantees to restore original value after all contexts are destroyed.
21-
No protection from interleaving foreign code doing same.
22-
"""
23-
24-
with _reentrant_patch_lock:
25-
contexts = getattr(obj, f"__{attr}__contexts__", {})
26-
if not contexts:
27-
contexts[1] = getattr(obj, attr)
28-
setattr(obj, f"__{attr}__contexts__", contexts)
29-
context_id = len(contexts) + 1
30-
contexts[context_id] = value
31-
setattr(obj, attr, value)
32-
33-
yield
34-
35-
with _reentrant_patch_lock:
36-
last_context_id = next(reversed(contexts))
37-
del contexts[context_id]
38-
if last_context_id == context_id:
39-
setattr(obj, attr, next(reversed(contexts.values())))
40-
if len(contexts) == 1:
41-
delattr(obj, f"__{attr}__contexts__")
42-
43-
44-
_one_time_patch_lock = threading.Lock()
45-
46-
47-
@contextlib.contextmanager
48-
def one_time_patch(obj, attr, value):
49-
"""
50-
More lightweight implementation, only sets the attribute once — in outer context.
51-
Effectively guarantees to restore original value after all contexts are destroyed.
52-
"""
53-
54-
with _one_time_patch_lock:
55-
if not hasattr(obj, f"__{attr}__value__"):
56-
setattr(obj, f"__{attr}__value__", getattr(obj, attr))
57-
setattr(obj, attr, value)
58-
setattr(obj, f"__{attr}__count__", getattr(obj, f"__{attr}__count__", 0) + 1)
59-
60-
yield
61-
62-
with _one_time_patch_lock:
63-
count = getattr(obj, f"__{attr}__count__") - 1
64-
setattr(obj, f"__{attr}__count__", count)
65-
if not count:
66-
setattr(obj, attr, getattr(obj, f"__{attr}__value__"))
67-
delattr(obj, f"__{attr}__value__")
68-
delattr(obj, f"__{attr}__count__")
10+
import django_threaded_sync_to_async.patch
6911

7012

7113
_current_executor = contextvars.ContextVar("current_executor", default=None)
7214
_max_tasks_semaphore = contextvars.ContextVar("max_tasks_semaphore", default=None)
15+
_shared_executors = {}
16+
_shared_executors_lock = threading.Lock()
7317

7418

7519
async def _sync_to_async_call(self, orig, *args, **kwargs):
@@ -78,14 +22,10 @@ async def _sync_to_async_call(self, orig, *args, **kwargs):
7822

7923
else:
8024
"""
81-
The task is called outside of executor's scope (or in different context).
25+
The task hit the call outside of executor's scope (or in different context).
8226
"""
8327

84-
if _max_tasks_semaphore.get() is not None:
85-
with _max_tasks_semaphore.get():
86-
return await orig(self, *args, **kwargs)
87-
88-
else:
28+
async with _max_tasks_semaphore.get() or contextlib.nullcontext(): # Python 3.10+.
8929
return await orig(self, *args, **kwargs)
9030

9131

@@ -98,11 +38,10 @@ def _set_context_variable(variable, value):
9838

9939
@contextlib.contextmanager
10040
def _use_executor(executor):
101-
with _set_context_variable(_current_executor, executor):
102-
# Can be replaced by a single call to `setattr(asgiref.sync.SyncToAsync, "__call__", new_call)`
103-
# if we don't care about restoring everything back.
104-
new_call = functools.partialmethod(_sync_to_async_call, asgiref.sync.SyncToAsync.__call__)
105-
with one_time_patch(asgiref.sync.SyncToAsync, "__call__", new_call):
41+
# `patch.one_time()` can be replaced with `patch.permanent()` if we don't care about restoring everything back.
42+
new_call = functools.partialmethod(_sync_to_async_call, asgiref.sync.SyncToAsync.__call__)
43+
with django_threaded_sync_to_async.patch.one_time(asgiref.sync.SyncToAsync, "__call__", new_call):
44+
with _set_context_variable(_current_executor, executor):
10645
yield executor
10746

10847

@@ -113,10 +52,6 @@ async def Executor(*args, **kwargs):
11352
yield executor
11453

11554

116-
_shared_executors = {}
117-
_shared_executors_lock = threading.Lock()
118-
119-
12055
@contextlib.asynccontextmanager
12156
async def SharedExecutor(name, *args, max_tasks=None, **kwargs):
12257
with _shared_executors_lock:
@@ -128,6 +63,6 @@ async def SharedExecutor(name, *args, max_tasks=None, **kwargs):
12863
kwargs.setdefault("thread_name_prefix", name)
12964
executor = _shared_executors[name] = concurrent.futures.ThreadPoolExecutor(*args, **kwargs)
13065

131-
with _set_context_variable(_max_tasks_semaphore, max_tasks and threading.Semaphore(max_tasks)):
66+
with _set_context_variable(_max_tasks_semaphore, max_tasks and asyncio.Semaphore(max_tasks)):
13267
with _use_executor(executor):
13368
yield executor
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import contextlib
2+
import threading
3+
4+
5+
_reentrant_lock = threading.Lock()
6+
_one_time_lock = threading.Lock()
7+
8+
9+
@contextlib.contextmanager
10+
def reentrant(obj, attr, value):
11+
"""
12+
Makes time-aware patch on the attribute of the object without locking like in `unittest.mock.patch`, the context
13+
will leak system-wide.
14+
However, if no `await` happens after obtaining the context, and no threads are getting the same attribute,
15+
it guarantees that the attribute will have the desired value.
16+
Effectively guarantees to restore original value after all contexts are destroyed.
17+
No protection from interleaving foreign code doing same.
18+
"""
19+
20+
with _reentrant_lock:
21+
contexts = getattr(obj, f"__{attr}__contexts__", {})
22+
if not contexts:
23+
contexts[1] = getattr(obj, attr)
24+
setattr(obj, f"__{attr}__contexts__", contexts)
25+
context_id = len(contexts) + 1
26+
contexts[context_id] = value
27+
setattr(obj, attr, value)
28+
29+
yield
30+
31+
with _reentrant_lock:
32+
last_context_id = next(reversed(contexts))
33+
del contexts[context_id]
34+
if last_context_id == context_id:
35+
setattr(obj, attr, next(reversed(contexts.values())))
36+
if len(contexts) == 1:
37+
delattr(obj, f"__{attr}__contexts__")
38+
39+
40+
@contextlib.contextmanager
41+
def one_time(obj, attr, value):
42+
"""
43+
More lightweight implementation, only sets the attribute once — in outer context.
44+
Effectively guarantees to restore original value after all contexts are destroyed.
45+
"""
46+
47+
with _one_time_lock:
48+
if not hasattr(obj, f"__{attr}__value__"):
49+
setattr(obj, f"__{attr}__value__", getattr(obj, attr))
50+
setattr(obj, attr, value)
51+
setattr(obj, f"__{attr}__count__", getattr(obj, f"__{attr}__count__", 0) + 1)
52+
53+
yield
54+
55+
with _one_time_lock:
56+
count = getattr(obj, f"__{attr}__count__") - 1
57+
setattr(obj, f"__{attr}__count__", count)
58+
if not count:
59+
setattr(obj, attr, getattr(obj, f"__{attr}__value__"))
60+
delattr(obj, f"__{attr}__value__")
61+
delattr(obj, f"__{attr}__count__")
62+
63+
64+
@contextlib.contextmanager
65+
def permanent(obj, attr, value):
66+
"""
67+
Most lightweight implementation.
68+
"""
69+
70+
setattr(obj, attr, value)
71+
72+
yield

example.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ async def test():
3838
await four_calls()
3939

4040

41-
# N.B. `contextlib.asynccontextmanager` only works as decorator since Python 3.10.
4241
@django_threaded_sync_to_async.Executor(thread_name_prefix="thread", max_workers=3)
4342
async def test2():
4443
await four_calls()

setup.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,6 @@
2121
"License :: OSI Approved :: MIT License",
2222
"Operating System :: OS Independent",
2323
"Programming Language :: Python :: 3 :: Only",
24-
"Programming Language :: Python :: 3.6",
25-
"Programming Language :: Python :: 3.7",
26-
"Programming Language :: Python :: 3.8",
27-
"Programming Language :: Python :: 3.9",
2824
"Programming Language :: Python :: 3.10",
2925
"Programming Language :: Python :: 3.11",
3026
"Programming Language :: Python :: 3.12",

tests/test_executor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ async def empty(**kwargs):
3939
with self.subTest(parallel=parallel, decorated=decorated):
4040
cv = threading.Condition()
4141
threads = set()
42+
4243
async with context(max_workers=workers):
4344
tt = [asyncio.create_task(function(cv, threads)) for _ in range(workers)]
4445
try:
@@ -52,3 +53,5 @@ async def empty(**kwargs):
5253
else:
5354
if not parallel:
5455
self.assertEqual("No", "exception")
56+
57+
self.assertEqual(len(threads), workers if parallel else 1)

tests/test_one_time_patch.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import unittest
22

3-
import django_threaded_sync_to_async
3+
import django_threaded_sync_to_async.patch
44

55

66
class TestOneTimePatch(unittest.TestCase):
@@ -13,10 +13,10 @@ class Dummy:
1313

1414
with self.subTest(inside=[]):
1515
self.assertEqual(o.x, 1)
16-
with django_threaded_sync_to_async.one_time_patch(o, "x", 2):
16+
with django_threaded_sync_to_async.patch.one_time(o, "x", 2):
1717
with self.subTest(inside=["c1"]):
1818
self.assertEqual(o.x, 2)
19-
with django_threaded_sync_to_async.one_time_patch(o, "x", 3):
19+
with django_threaded_sync_to_async.patch.one_time(o, "x", 3):
2020
with self.subTest(inside=["c1", "c2"]):
2121
self.assertEqual(o.x, 2)
2222
with self.subTest(inside=["c1"]):
@@ -34,8 +34,8 @@ class Dummy:
3434
o = Dummy()
3535
fields = dir(o)
3636

37-
c1 = django_threaded_sync_to_async.one_time_patch(o, "x", 2)
38-
c2 = django_threaded_sync_to_async.one_time_patch(o, "x", 3)
37+
c1 = django_threaded_sync_to_async.patch.one_time(o, "x", 2)
38+
c2 = django_threaded_sync_to_async.patch.one_time(o, "x", 3)
3939

4040
with self.subTest(inside=[]):
4141
self.assertEqual(o.x, 1)

tests/test_reentrant_patch.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import unittest
22

3-
import django_threaded_sync_to_async
3+
import django_threaded_sync_to_async.patch
44

55

66
class TestReentrantPatch(unittest.TestCase):
@@ -13,10 +13,10 @@ class Dummy:
1313

1414
with self.subTest(inside=[]):
1515
self.assertEqual(o.x, 1)
16-
with django_threaded_sync_to_async.reentrant_patch(o, "x", 2):
16+
with django_threaded_sync_to_async.patch.reentrant(o, "x", 2):
1717
with self.subTest(inside=["c1"]):
1818
self.assertEqual(o.x, 2)
19-
with django_threaded_sync_to_async.reentrant_patch(o, "x", 3):
19+
with django_threaded_sync_to_async.patch.reentrant(o, "x", 3):
2020
with self.subTest(inside=["c1", "c2"]):
2121
self.assertEqual(o.x, 3)
2222
with self.subTest(inside=["c1"]):
@@ -34,8 +34,8 @@ class Dummy:
3434
o = Dummy()
3535
fields = dir(o)
3636

37-
c1 = django_threaded_sync_to_async.reentrant_patch(o, "x", 2)
38-
c2 = django_threaded_sync_to_async.reentrant_patch(o, "x", 3)
37+
c1 = django_threaded_sync_to_async.patch.reentrant(o, "x", 2)
38+
c2 = django_threaded_sync_to_async.patch.reentrant(o, "x", 3)
3939

4040
with self.subTest(inside=[]):
4141
self.assertEqual(o.x, 1)

tests/test_shared_executor.py

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
import asyncio
2+
import contextlib
3+
import functools
4+
import threading
15
import unittest
26

37
import asgiref.sync
@@ -7,18 +11,76 @@
711

812
class TestSharedExecutor(unittest.IsolatedAsyncioTestCase):
913
async def testSimple(self):
10-
async with django_threaded_sync_to_async.SharedExecutor("common") as executor:
14+
async with django_threaded_sync_to_async.SharedExecutor("simple_common") as executor:
1115
pass
1216

1317
with self.subTest(same_name=True):
14-
async with django_threaded_sync_to_async.SharedExecutor("common") as another_executor:
18+
async with django_threaded_sync_to_async.SharedExecutor("simple_common") as another_executor:
1519
self.assertIs(executor, another_executor)
1620

1721
with self.subTest(same_name=False):
18-
async with django_threaded_sync_to_async.SharedExecutor("specific") as specific_executor:
22+
async with django_threaded_sync_to_async.SharedExecutor("simple_specific") as specific_executor:
1923
self.assertIsNot(executor, specific_executor)
2024

2125
async def testMaxTasks(self):
22-
async with django_threaded_sync_to_async.SharedExecutor("common", max_tasks=2, max_workers=3) as executor:
23-
# TODO Make proper test on `max_tasks`.
24-
self.assertEqual(await asgiref.sync.sync_to_async(lambda: 42)(), 42)
26+
workers = 10
27+
timeout = 0.05
28+
29+
def long_call(cv, threads, another_threads):
30+
threads.add(threading.current_thread().name)
31+
32+
def notify_cv(predicate):
33+
with cv:
34+
if predicate():
35+
cv.notify_all()
36+
37+
with cv:
38+
notify = threading.Thread(target=notify_cv, args=(lambda: len(threads) == workers,))
39+
notify.start()
40+
cv.wait_for(lambda: len(threads) == workers, timeout)
41+
42+
result = len(threads)
43+
another_threads.add(threading.current_thread().name)
44+
45+
with cv:
46+
notify = threading.Thread(target=notify_cv, args=(lambda: len(another_threads) == workers,))
47+
notify.start()
48+
cv.wait_for(lambda: len(another_threads) == workers, timeout)
49+
50+
threads.discard(threading.current_thread().name)
51+
another_threads.discard(threading.current_thread().name)
52+
return result
53+
54+
@asgiref.sync.sync_to_async
55+
def decorated_long_call(*args):
56+
return long_call(*args)
57+
58+
@contextlib.asynccontextmanager
59+
async def empty(name, **kwargs):
60+
yield
61+
62+
for parallel, context in ((False, empty), (True, django_threaded_sync_to_async.SharedExecutor)):
63+
for decorated, function in ((False, asgiref.sync.sync_to_async(long_call)), (True, decorated_long_call)):
64+
for tasks in (workers, workers - 1):
65+
# One or two passes are allowed — `tasks` must not be less than `workers/2`.
66+
with self.subTest(parallel=parallel, decorated=decorated, tasks=tasks):
67+
cv = threading.Condition()
68+
threads = set()
69+
another_threads = set()
70+
71+
async with context(
72+
f"max_tasks_{parallel}_{decorated}_{tasks}", max_workers=workers, max_tasks=tasks
73+
):
74+
tt = [asyncio.create_task(function(cv, threads, another_threads)) for _ in range(workers)]
75+
try:
76+
total_timeout = timeout * (2.5 if tasks == workers else 4.5)
77+
for c in asyncio.as_completed(tt, timeout=total_timeout):
78+
self.assertIn(await c, (tasks, workers - tasks) if parallel else (1,))
79+
except asyncio.TimeoutError:
80+
for t in tt:
81+
t.cancel()
82+
if parallel:
83+
self.assertEqual("Exception", "occurred")
84+
else:
85+
if not parallel:
86+
self.assertEqual("No", "exception")

0 commit comments

Comments
 (0)