Skip to content

Commit bac6353

Browse files
committed
Add trio message assembler.
1 parent 25c5c07 commit bac6353

File tree

4 files changed

+1008
-38
lines changed

4 files changed

+1008
-38
lines changed

src/websockets/asyncio/messages.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,7 @@ class Assembler:
8181
8282
"""
8383

84-
# coverage reports incorrectly: "line NN didn't jump to the function exit"
85-
def __init__( # pragma: no cover
84+
def __init__(
8685
self,
8786
high: int | None = None,
8887
low: int | None = None,

src/websockets/trio/messages.py

Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
from __future__ import annotations
2+
3+
import codecs
4+
import math
5+
from collections.abc import AsyncIterator
6+
from typing import Any, Callable, Literal, TypeVar, overload
7+
8+
import trio
9+
10+
from ..exceptions import ConcurrencyError
11+
from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
12+
from ..typing import Data
13+
14+
15+
__all__ = ["Assembler"]
16+
17+
UTF8Decoder = codecs.getincrementaldecoder("utf-8")
18+
19+
T = TypeVar("T")
20+
21+
22+
class Assembler:
23+
"""
24+
Assemble messages from frames.
25+
26+
:class:`Assembler` expects only data frames. The stream of frames must
27+
respect the protocol; if it doesn't, the behavior is undefined.
28+
29+
Args:
30+
pause: Called when the buffer of frames goes above the high water mark;
31+
should pause reading from the network.
32+
resume: Called when the buffer of frames goes below the low water mark;
33+
should resume reading from the network.
34+
35+
"""
36+
37+
def __init__(
38+
self,
39+
high: int | None = None,
40+
low: int | None = None,
41+
pause: Callable[[], Any] = lambda: None,
42+
resume: Callable[[], Any] = lambda: None,
43+
) -> None:
44+
# Queue of incoming frames.
45+
self.send_frames: trio.MemorySendChannel[Frame]
46+
self.recv_frames: trio.MemoryReceiveChannel[Frame]
47+
self.send_frames, self.recv_frames = trio.open_memory_channel(math.inf)
48+
49+
# We cannot put a hard limit on the size of the queue because a single
50+
# call to Protocol.data_received() could produce thousands of frames,
51+
# which must be buffered. Instead, we pause reading when the buffer goes
52+
# above the high limit and we resume when it goes under the low limit.
53+
if high is not None and low is None:
54+
low = high // 4
55+
if high is None and low is not None:
56+
high = low * 4
57+
if high is not None and low is not None:
58+
if low < 0:
59+
raise ValueError("low must be positive or equal to zero")
60+
if high < low:
61+
raise ValueError("high must be greater than or equal to low")
62+
self.high, self.low = high, low
63+
self.pause = pause
64+
self.resume = resume
65+
self.paused = False
66+
67+
# This flag prevents concurrent calls to get() by user code.
68+
self.get_in_progress = False
69+
70+
# This flag marks the end of the connection.
71+
self.closed = False
72+
73+
@overload
74+
async def get(self, decode: Literal[True]) -> str: ...
75+
76+
@overload
77+
async def get(self, decode: Literal[False]) -> bytes: ...
78+
79+
@overload
80+
async def get(self, decode: bool | None = None) -> Data: ...
81+
82+
async def get(self, decode: bool | None = None) -> Data:
83+
"""
84+
Read the next message.
85+
86+
:meth:`get` returns a single :class:`str` or :class:`bytes`.
87+
88+
If the message is fragmented, :meth:`get` waits until the last frame is
89+
received, then it reassembles the message and returns it. To receive
90+
messages frame by frame, use :meth:`get_iter` instead.
91+
92+
Args:
93+
decode: :obj:`False` disables UTF-8 decoding of text frames and
94+
returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of
95+
binary frames and returns :class:`str`.
96+
97+
Raises:
98+
trio.EndOfChannel: If the stream of frames has ended.
99+
UnicodeDecodeError: If a text frame contains invalid UTF-8.
100+
ConcurrencyError: If two coroutines run :meth:`get` or
101+
:meth:`get_iter` concurrently.
102+
103+
"""
104+
if self.get_in_progress:
105+
raise ConcurrencyError("get() or get_iter() is already running")
106+
self.get_in_progress = True
107+
108+
# Locking with get_in_progress prevents concurrent execution
109+
# until get() fetches a complete message or is canceled.
110+
111+
try:
112+
# First frame
113+
frame = await self.recv_frames.receive()
114+
self.maybe_resume()
115+
assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
116+
if decode is None:
117+
decode = frame.opcode is OP_TEXT
118+
frames = [frame]
119+
120+
# Following frames, for fragmented messages
121+
while not frame.fin:
122+
try:
123+
frame = await self.recv_frames.receive()
124+
except trio.Cancelled:
125+
# Put frames already received back into the queue
126+
# so that future calls to get() can return them.
127+
assert not self.send_frames._state.receive_tasks, (
128+
"no task should be waiting on receive()"
129+
)
130+
assert not self.send_frames._state.data, "queue should be empty"
131+
for frame in frames:
132+
self.send_frames.send_nowait(frame)
133+
raise
134+
self.maybe_resume()
135+
assert frame.opcode is OP_CONT
136+
frames.append(frame)
137+
138+
finally:
139+
self.get_in_progress = False
140+
141+
data = b"".join(frame.data for frame in frames)
142+
if decode:
143+
return data.decode()
144+
else:
145+
return data
146+
147+
@overload
148+
def get_iter(self, decode: Literal[True]) -> AsyncIterator[str]: ...
149+
150+
@overload
151+
def get_iter(self, decode: Literal[False]) -> AsyncIterator[bytes]: ...
152+
153+
@overload
154+
def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: ...
155+
156+
async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
157+
"""
158+
Stream the next message.
159+
160+
Iterating the return value of :meth:`get_iter` asynchronously yields a
161+
:class:`str` or :class:`bytes` for each frame in the message.
162+
163+
The iterator must be fully consumed before calling :meth:`get_iter` or
164+
:meth:`get` again. Else, :exc:`ConcurrencyError` is raised.
165+
166+
This method only makes sense for fragmented messages. If messages aren't
167+
fragmented, use :meth:`get` instead.
168+
169+
Args:
170+
decode: :obj:`False` disables UTF-8 decoding of text frames and
171+
returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of
172+
binary frames and returns :class:`str`.
173+
174+
Raises:
175+
trio.EndOfChannel: If the stream of frames has ended.
176+
UnicodeDecodeError: If a text frame contains invalid UTF-8.
177+
ConcurrencyError: If two coroutines run :meth:`get` or
178+
:meth:`get_iter` concurrently.
179+
180+
"""
181+
if self.get_in_progress:
182+
raise ConcurrencyError("get() or get_iter() is already running")
183+
self.get_in_progress = True
184+
185+
# Locking with get_in_progress prevents concurrent execution
186+
# until get_iter() fetches a complete message or is canceled.
187+
188+
# If get_iter() raises an exception e.g. in decoder.decode(),
189+
# get_in_progress remains set and the connection becomes unusable.
190+
191+
# First frame
192+
try:
193+
frame = await self.recv_frames.receive()
194+
except trio.Cancelled:
195+
self.get_in_progress = False
196+
raise
197+
self.maybe_resume()
198+
assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
199+
if decode is None:
200+
decode = frame.opcode is OP_TEXT
201+
if decode:
202+
decoder = UTF8Decoder()
203+
yield decoder.decode(frame.data, frame.fin)
204+
else:
205+
yield frame.data
206+
207+
# Following frames, for fragmented messages
208+
while not frame.fin:
209+
# We cannot handle trio.Cancelled because we don't buffer
210+
# previous fragments — we're streaming them. Canceling get_iter()
211+
# here will leave the assembler in a stuck state. Future calls to
212+
# get() or get_iter() will raise ConcurrencyError.
213+
frame = await self.recv_frames.receive()
214+
self.maybe_resume()
215+
assert frame.opcode is OP_CONT
216+
if decode:
217+
yield decoder.decode(frame.data, frame.fin)
218+
else:
219+
yield frame.data
220+
221+
self.get_in_progress = False
222+
223+
def put(self, frame: Frame) -> None:
224+
"""
225+
Add ``frame`` to the next message.
226+
227+
Raises:
228+
trio.EndOfChannel: If the stream of frames has ended.
229+
230+
"""
231+
if self.closed:
232+
raise trio.EndOfChannel("stream of frames ended")
233+
234+
self.send_frames.send_nowait(frame)
235+
self.maybe_pause()
236+
237+
def maybe_pause(self) -> None:
238+
"""Pause the writer if queue is above the high water mark."""
239+
# Skip if flow control is disabled
240+
if self.high is None:
241+
return
242+
243+
# Bypass the statistics() method for performance reasons.
244+
# Check for "> high" to support high = 0
245+
if len(self.send_frames._state.data) > self.high and not self.paused:
246+
self.paused = True
247+
self.pause()
248+
249+
def maybe_resume(self) -> None:
250+
"""Resume the writer if queue is below the low water mark."""
251+
# Skip if flow control is disabled
252+
if self.low is None:
253+
return
254+
255+
# Bypass the statistics() method for performance reasons.
256+
# Check for "<= low" to support low = 0
257+
if len(self.send_frames._state.data) <= self.low and self.paused:
258+
self.paused = False
259+
self.resume()
260+
261+
def close(self) -> None:
262+
"""
263+
End the stream of frames.
264+
265+
Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`,
266+
or :meth:`put` is safe. They will raise :exc:`trio.EndOfChannel`.
267+
268+
"""
269+
if self.closed:
270+
return
271+
272+
self.closed = True
273+
274+
# Unblock get() or get_iter().
275+
self.send_frames.close()

0 commit comments

Comments
 (0)