Skip to content
This repository was archived by the owner on Jun 21, 2025. It is now read-only.

Commit c213746

Browse files
feat: NestedAsyncIO contextmanager
This comes from erdewit/nest_asyncio#88 implemented by @CharlieJiangXXX before NestedAsyncIO was archived. It should help with some issues where the patched eventloop causes problems with some other libraries, like Discord.py, that do some "odd" things with eventloops
1 parent 89e1421 commit c213746

File tree

3 files changed

+414
-4
lines changed

3 files changed

+414
-4
lines changed

pydantic_aioredis/model.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@
1010
from typing import Tuple
1111
from typing import Union
1212

13-
import nest_asyncio
1413
from pydantic_aioredis.abstract import _AbstractModel
15-
from pydantic_aioredis.utils import bytes_to_string
14+
from pydantic_aioredis.utils import bytes_to_string, NestedAsyncIO
1615

1716

1817
class Model(_AbstractModel):
@@ -69,8 +68,8 @@ def __save_from_sync(self):
6968
# Use nest_asyncio so we can call the async save
7069
except RuntimeError:
7170
io_loop = asyncio.new_event_loop()
72-
nest_asyncio.apply()
73-
io_loop.run_until_complete(self.save())
71+
with NestedAsyncIO():
72+
io_loop.run_until_complete(self.save())
7473

7574
@asynccontextmanager
7675
async def update(self):

pydantic_aioredis/utils.py

Lines changed: 307 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,313 @@
11
"""Module containing common utilities"""
22

3+
import asyncio
4+
import asyncio.events as events
5+
import os
6+
import sys
7+
import threading
8+
from contextlib import contextmanager, suppress
9+
from heapq import heappop
10+
311

412
def bytes_to_string(data: bytes):
513
"""Converts data to string"""
614
return str(data, "utf-8") if isinstance(data, bytes) else data
15+
16+
17+
class NestedAsyncIO:
18+
"""Patch asyncio to allow nested event loops."""
19+
20+
__slots__ = [
21+
"_loop",
22+
"orig_run",
23+
"orig_tasks",
24+
"orig_futures",
25+
"orig_loop_attrs",
26+
"policy_get_loop",
27+
"orig_get_loops",
28+
"orig_tc",
29+
"patched",
30+
]
31+
_instance = None
32+
_initialized = False
33+
34+
def __new__(cls, *args, **kwargs):
35+
if not cls._instance:
36+
cls._instance = super().__new__(cls)
37+
return cls._instance
38+
39+
def __init__(self, loop=None):
40+
if not self._initialized:
41+
self._loop = loop
42+
self.orig_run = None
43+
self.orig_tasks = []
44+
self.orig_futures = []
45+
self.orig_loop_attrs = {}
46+
self.policy_get_loop = None
47+
self.orig_get_loops = {}
48+
self.orig_tc = None
49+
self.patched = False
50+
self.__class__._initialized = True
51+
52+
def __enter__(self):
53+
self.apply(self._loop)
54+
return self
55+
56+
def __exit__(self, exc_type, exc_val, exc_tb):
57+
self.revert()
58+
59+
def apply(self, loop=None):
60+
"""Patch asyncio to make its event loop reentrant."""
61+
if not self.patched:
62+
self.patch_asyncio()
63+
self.patch_policy()
64+
self.patch_tornado()
65+
66+
loop = loop or asyncio.get_event_loop()
67+
self.patch_loop(loop)
68+
self.patched = True
69+
70+
def revert(self):
71+
if self.patched:
72+
for loop in self.orig_loop_attrs:
73+
self.unpatch_loop(loop)
74+
self.unpatch_tornado()
75+
self.unpatch_policy()
76+
self.unpatch_asyncio()
77+
self.patched = False
78+
79+
def patch_asyncio(self):
80+
"""Patch asyncio module to use pure Python tasks and futures."""
81+
82+
def run(main, *, debug=False):
83+
loop = asyncio.get_event_loop()
84+
loop.set_debug(debug)
85+
task = asyncio.ensure_future(main)
86+
try:
87+
return loop.run_until_complete(task)
88+
finally:
89+
if not task.done():
90+
task.cancel()
91+
with suppress(asyncio.CancelledError):
92+
loop.run_until_complete(task)
93+
94+
def _get_event_loop(stacklevel=3):
95+
return events._get_running_loop() or events.get_event_loop_policy().get_event_loop()
96+
97+
# Use module level _current_tasks, all_tasks and patch run method.
98+
if getattr(asyncio, "_nest_patched", False):
99+
return
100+
self.orig_tasks = [asyncio.Task, asyncio.tasks._CTask, asyncio.tasks.Task]
101+
asyncio.Task = asyncio.tasks._CTask = asyncio.tasks.Task = asyncio.tasks._PyTask
102+
self.orig_futures = [asyncio.Future, asyncio.futures._CFuture, asyncio.futures.Future]
103+
asyncio.Future = asyncio.futures._CFuture = asyncio.futures.Future = asyncio.futures._PyFuture
104+
if sys.version_info < (3, 7, 0):
105+
asyncio.tasks._current_tasks = asyncio.tasks.Task._current_tasks
106+
asyncio.all_tasks = asyncio.tasks.Task.all_tasks
107+
elif sys.version_info >= (3, 9, 0):
108+
self.orig_get_loops = {
109+
"events__get_event_loop": events._get_event_loop,
110+
"events_get_event_loop": events.get_event_loop,
111+
"asyncio_get_event_loop": asyncio.get_event_loop,
112+
}
113+
events._get_event_loop = events.get_event_loop = asyncio.get_event_loop = _get_event_loop
114+
self.orig_run = asyncio.run
115+
asyncio.run = run
116+
asyncio._nest_patched = True
117+
118+
def unpatch_asyncio(self):
119+
if self.orig_run:
120+
asyncio.run = self.orig_run
121+
asyncio._nest_patched = False
122+
(asyncio.Task, asyncio.tasks._CTask, asyncio.tasks.Task) = self.orig_tasks
123+
(asyncio.Future, asyncio.futures._CFuture, asyncio.futures.Future) = self.orig_futures
124+
if sys.version_info >= (3, 9, 0):
125+
for key, value in self.orig_get_loops.items():
126+
setattr(asyncio if key.startswith("asyncio") else events, key.split("_")[-1], value)
127+
128+
def patch_policy(self):
129+
"""Patch the policy to always return a patched loop."""
130+
131+
def get_event_loop(this):
132+
if this._local._loop is None:
133+
loop = this.new_event_loop()
134+
self.patch_loop(loop)
135+
this.set_event_loop(loop)
136+
return this._local._loop
137+
138+
cls = events.get_event_loop_policy().__class__
139+
self.policy_get_loop = cls.get_event_loop
140+
cls.get_event_loop = get_event_loop
141+
142+
def unpatch_policy(self):
143+
cls = events.get_event_loop_policy().__class__
144+
orig = self.policy_get_loop
145+
if orig:
146+
cls.get_event_loop = orig
147+
148+
def patch_loop(self, loop):
149+
"""Patch loop to make it reentrant."""
150+
151+
def run_forever(this):
152+
with manage_run(this), manage_asyncgens(this):
153+
while True:
154+
this._run_once()
155+
if this._stopping:
156+
break
157+
this._stopping = False
158+
159+
def run_until_complete(this, future):
160+
with manage_run(this):
161+
f = asyncio.ensure_future(future, loop=this)
162+
if f is not future:
163+
f._log_destroy_pending = False
164+
while not f.done():
165+
this._run_once()
166+
if this._stopping:
167+
break
168+
if not f.done():
169+
raise RuntimeError("Event loop stopped before Future completed.")
170+
return f.result()
171+
172+
def _run_once(this):
173+
"""
174+
Simplified re-implementation of asyncio's _run_once that
175+
runs handles as they become ready.
176+
"""
177+
ready = this._ready
178+
scheduled = this._scheduled
179+
while scheduled and scheduled[0]._cancelled:
180+
heappop(scheduled)
181+
182+
timeout = (
183+
0
184+
if ready or this._stopping
185+
else min(max(scheduled[0]._when - this.time(), 0), 86400)
186+
if scheduled
187+
else None
188+
)
189+
event_list = this._selector.select(timeout)
190+
this._process_events(event_list)
191+
192+
end_time = this.time() + this._clock_resolution
193+
while scheduled and scheduled[0]._when < end_time:
194+
handle = heappop(scheduled)
195+
ready.append(handle)
196+
197+
for _ in range(len(ready)):
198+
if not ready:
199+
break
200+
handle = ready.popleft()
201+
if not handle._cancelled:
202+
# preempt the current task so that that checks in
203+
# Task.__step do not raise
204+
curr_task = curr_tasks.pop(this, None)
205+
206+
try:
207+
handle._run()
208+
finally:
209+
# restore the current task
210+
if curr_task is not None:
211+
curr_tasks[this] = curr_task
212+
213+
handle = None
214+
215+
@contextmanager
216+
def manage_run(this):
217+
"""Set up the loop for running."""
218+
this._check_closed()
219+
old_thread_id = this._thread_id
220+
old_running_loop = events._get_running_loop()
221+
try:
222+
this._thread_id = threading.get_ident()
223+
events._set_running_loop(this)
224+
this._num_runs_pending += 1
225+
if this._is_proactorloop:
226+
if this._self_reading_future is None:
227+
this.call_soon(this._loop_self_reading)
228+
yield
229+
finally:
230+
this._thread_id = old_thread_id
231+
events._set_running_loop(old_running_loop)
232+
this._num_runs_pending -= 1
233+
if this._is_proactorloop:
234+
if this._num_runs_pending == 0 and this._self_reading_future is not None:
235+
ov = this._self_reading_future._ov
236+
this._self_reading_future.cancel()
237+
if ov is not None:
238+
this._proactor._unregister(ov)
239+
this._self_reading_future = None
240+
241+
@contextmanager
242+
def manage_asyncgens(this):
243+
if not hasattr(sys, "get_asyncgen_hooks"):
244+
# Python version is too old.
245+
return
246+
old_agen_hooks = sys.get_asyncgen_hooks()
247+
try:
248+
this._set_coroutine_origin_tracking(this._debug)
249+
if this._asyncgens is not None:
250+
sys.set_asyncgen_hooks(
251+
firstiter=this._asyncgen_firstiter_hook, finalizer=this._asyncgen_finalizer_hook
252+
)
253+
yield
254+
finally:
255+
this._set_coroutine_origin_tracking(False)
256+
if this._asyncgens is not None:
257+
sys.set_asyncgen_hooks(*old_agen_hooks)
258+
259+
def _check_running(this):
260+
"""Do not throw exception if loop is already running."""
261+
pass
262+
263+
if getattr(loop, "_nest_patched", False):
264+
return
265+
if not isinstance(loop, asyncio.BaseEventLoop):
266+
raise ValueError("Can't patch loop of type %s" % type(loop))
267+
cls = loop.__class__
268+
self.orig_loop_attrs[cls] = {}
269+
self.orig_loop_attrs[cls]["run_forever"] = cls.run_forever
270+
cls.run_forever = run_forever
271+
self.orig_loop_attrs[cls]["run_until_complete"] = cls.run_until_complete
272+
cls.run_until_complete = run_until_complete
273+
self.orig_loop_attrs[cls]["_run_once"] = cls._run_once
274+
cls._run_once = _run_once
275+
self.orig_loop_attrs[cls]["_check_running"] = cls._check_running
276+
cls._check_running = _check_running
277+
self.orig_loop_attrs[cls]["_check_runnung"] = cls._check_running
278+
cls._check_runnung = _check_running # typo in Python 3.7 source
279+
cls._num_runs_pending = 1 if loop.is_running() else 0
280+
cls._is_proactorloop = os.name == "nt" and issubclass(cls, asyncio.ProactorEventLoop)
281+
if sys.version_info < (3, 7, 0):
282+
cls._set_coroutine_origin_tracking = cls._set_coroutine_wrapper
283+
curr_tasks = asyncio.tasks._current_tasks if sys.version_info >= (3, 7, 0) else asyncio.Task._current_tasks
284+
cls._nest_patched = True
285+
286+
def unpatch_loop(self, loop):
287+
loop._nest_patched = False
288+
if self.orig_loop_attrs[loop]:
289+
for key, value in self.orig_loop_attrs[loop].items():
290+
setattr(loop, key, value)
291+
292+
for attr in ["_num_runs_pending", "_is_proactorloop"]:
293+
if hasattr(loop, attr):
294+
delattr(loop, attr)
295+
296+
def patch_tornado(self):
297+
"""
298+
If tornado is imported before nest_asyncio, make tornado aware of
299+
the pure-Python asyncio Future.
300+
"""
301+
if "tornado" in sys.modules:
302+
import tornado.concurrent as tc # type: ignore
303+
304+
self.orig_tc = tc.Future
305+
tc.Future = asyncio.Future
306+
if asyncio.Future not in tc.FUTURES:
307+
tc.FUTURES += (asyncio.Future,)
308+
309+
def unpatch_tornado(self):
310+
if self.orig_tc:
311+
import tornado.concurrent as tc # noqa
312+
313+
tc.Future = self.orig_tc

0 commit comments

Comments
 (0)