Skip to content

Commit a621204

Browse files
committed
refactor: use a background task to process incoming responses
This makes sure that we keep getting events from the simulator in the background, even when there's on outstanding request.
1 parent 4aeea23 commit a621204

File tree

3 files changed

+135
-21
lines changed

3 files changed

+135
-21
lines changed

src/wokwi_client/event_queue.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# SPDX-FileCopyrightText: 2025-present CodeMagic LTD
2+
#
3+
# SPDX-License-Identifier: MIT
4+
5+
"""
6+
Queue-based event subscription helper for Transport.
7+
8+
Usage:
9+
with EventQueue(transport, "serial-monitor:data") as queue:
10+
# Use the queue to get events, calling get() or get_nowait()
11+
event = await queue.get()
12+
# do something with the event
13+
...
14+
"""
15+
16+
import asyncio
17+
18+
from .protocol_types import EventMessage
19+
from .transport import Transport
20+
21+
22+
class EventQueue:
23+
"""A queue for events from a specific event type."""
24+
25+
def __init__(self, transport: Transport, event_type: str) -> None:
26+
self._queue: asyncio.Queue[EventMessage] = asyncio.Queue()
27+
self._transport = transport
28+
self._event_type = event_type
29+
30+
def listener(event: EventMessage) -> None:
31+
self._queue.put_nowait(event)
32+
33+
self._listener = listener
34+
self._transport.add_event_listener(self._event_type, self._listener)
35+
36+
def close(self) -> None:
37+
"""Close the queue. This is useful when you want to stop listening for events."""
38+
self._transport.remove_event_listener(self._event_type, self._listener)
39+
40+
async def get(self) -> EventMessage:
41+
"""Get an event from the queue. Blocks until an event is available."""
42+
return await self._queue.get()
43+
44+
def get_nowait(self) -> EventMessage:
45+
"""Get an event from the queue. Raises QueueEmpty if no event is available."""
46+
return self._queue.get_nowait()
47+
48+
def flush(self) -> None:
49+
"""Flush the queue. This is useful when you want to wait for all events to be processed."""
50+
while not self._queue.empty():
51+
self._queue.get_nowait()
52+
53+
def __enter__(self) -> "EventQueue":
54+
return self
55+
56+
def __exit__(self, exc_type: type, exc_value: Exception, traceback: object) -> None:
57+
self.close()

src/wokwi_client/serial.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,14 @@
33
# SPDX-License-Identifier: MIT
44

55
from collections.abc import AsyncGenerator
6-
from typing import cast
76

8-
from .protocol_types import EventMessage
7+
from .event_queue import EventQueue
98
from .transport import Transport
109

1110

1211
async def monitor_lines(transport: Transport) -> AsyncGenerator[bytes, None]:
1312
await transport.request("serial-monitor:listen", {})
14-
while True:
15-
msg = await transport.recv()
16-
if msg["type"] == "event":
17-
event_msg = cast(EventMessage, msg)
18-
if event_msg["event"] == "serial-monitor:data":
19-
yield bytes(event_msg["payload"]["bytes"])
13+
with EventQueue(transport, "serial-monitor:data") as queue:
14+
while True:
15+
event_msg = await queue.get()
16+
yield bytes(event_msg["payload"]["bytes"])

src/wokwi_client/transport.py

Lines changed: 73 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,24 @@
22
#
33
# SPDX-License-Identifier: MIT
44

5+
import asyncio
56
import json
67
import os
78
import warnings
8-
from typing import Any, Optional, cast
9+
from typing import Any, Callable, Optional, cast
910

1011
import websockets
1112

1213
from .__version__ import get_version
1314
from .constants import (
1415
DEFAULT_WS_URL,
16+
MSG_TYPE_EVENT,
1517
MSG_TYPE_HELLO,
1618
MSG_TYPE_RESPONSE,
1719
PROTOCOL_VERSION,
1820
)
1921
from .exceptions import ProtocolError, ServerError, WokwiError
20-
from .protocol_types import HelloMessage, IncomingMessage, ResponseMessage
22+
from .protocol_types import EventMessage, HelloMessage, IncomingMessage, ResponseMessage
2123

2224
TRANSPORT_DEFAULT_WS_URL = os.getenv("WOKWI_CLI_SERVER", DEFAULT_WS_URL)
2325

@@ -28,6 +30,10 @@ def __init__(self, token: str, url: str = TRANSPORT_DEFAULT_WS_URL):
2830
self._url = url
2931
self._next_id = 1
3032
self._ws: Optional[websockets.WebSocketClientProtocol] = None
33+
self._event_listeners: dict[str, list[Callable[[EventMessage], Any]]] = {}
34+
self._response_futures: dict[str, asyncio.Future[ResponseMessage]] = {}
35+
self._recv_task: Optional[asyncio.Task[None]] = None
36+
self._closed = False
3137

3238
async def connect(self) -> dict[str, Any]:
3339
self._ws = await websockets.connect(
@@ -41,28 +47,85 @@ async def connect(self) -> dict[str, Any]:
4147
if hello["type"] != MSG_TYPE_HELLO or hello.get("protocolVersion") != PROTOCOL_VERSION:
4248
raise ProtocolError(f"Unsupported protocol handshake: {hello}")
4349
hello_msg = cast(HelloMessage, hello)
50+
self._closed = False
51+
# Start background message processor
52+
self._recv_task = asyncio.create_task(self._background_recv())
4453
return {"version": hello_msg["appVersion"]}
4554

4655
async def close(self) -> None:
56+
self._closed = True
57+
if self._recv_task:
58+
self._recv_task.cancel()
59+
try:
60+
await self._recv_task
61+
except asyncio.CancelledError:
62+
pass
4763
if self._ws:
4864
await self._ws.close()
4965

66+
def add_event_listener(self, event_type: str, listener: Callable[[EventMessage], Any]) -> None:
67+
"""Register a listener for a specific event type."""
68+
if event_type not in self._event_listeners:
69+
self._event_listeners[event_type] = []
70+
self._event_listeners[event_type].append(listener)
71+
72+
def remove_event_listener(
73+
self, event_type: str, listener: Callable[[EventMessage], Any]
74+
) -> None:
75+
"""Remove a previously registered listener for a specific event type."""
76+
if event_type in self._event_listeners:
77+
self._event_listeners[event_type] = [
78+
registered_listener
79+
for registered_listener in self._event_listeners[event_type]
80+
if registered_listener != listener
81+
]
82+
if not self._event_listeners[event_type]:
83+
del self._event_listeners[event_type]
84+
85+
async def _dispatch_event(self, event_msg: EventMessage) -> None:
86+
listeners = self._event_listeners.get(event_msg["event"], [])
87+
for listener in listeners:
88+
result = listener(event_msg)
89+
if hasattr(result, "__await__"):
90+
await result
91+
5092
async def request(self, command: str, params: dict[str, Any]) -> ResponseMessage:
5193
msg_id = str(self._next_id)
5294
self._next_id += 1
5395
if self._ws is None:
5496
raise WokwiError("Not connected")
97+
loop = asyncio.get_running_loop()
98+
future: asyncio.Future[ResponseMessage] = loop.create_future()
99+
self._response_futures[msg_id] = future
55100
await self._ws.send(
56101
json.dumps({"type": "command", "command": command, "params": params, "id": msg_id})
57102
)
58-
while True:
59-
msg: IncomingMessage = await self._recv()
60-
if msg["type"] == MSG_TYPE_RESPONSE and msg.get("id") == msg_id:
61-
resp_msg = cast(ResponseMessage, msg)
62-
if resp_msg.get("error"):
63-
result = resp_msg["result"]
64-
raise ServerError(result["message"])
65-
return resp_msg
103+
try:
104+
resp_msg_resp = await future
105+
if resp_msg_resp.get("error"):
106+
result = resp_msg_resp["result"]
107+
raise ServerError(result["message"])
108+
return resp_msg_resp
109+
finally:
110+
del self._response_futures[msg_id]
111+
112+
async def _background_recv(self) -> None:
113+
try:
114+
while not self._closed and self._ws is not None:
115+
msg: IncomingMessage = await self._recv()
116+
if msg["type"] == MSG_TYPE_EVENT:
117+
resp_msg_event = cast(EventMessage, msg)
118+
await self._dispatch_event(resp_msg_event)
119+
elif msg["type"] == MSG_TYPE_RESPONSE:
120+
resp_msg_resp = cast(ResponseMessage, msg)
121+
future = self._response_futures.get(resp_msg_resp["id"])
122+
if future is None or future.done():
123+
continue
124+
future.set_result(resp_msg_resp)
125+
except (websockets.ConnectionClosed, asyncio.CancelledError):
126+
pass
127+
except Exception as e:
128+
warnings.warn(f"Background recv error: {e}", RuntimeWarning)
66129

67130
async def _recv(self) -> IncomingMessage:
68131
if self._ws is None:
@@ -87,6 +150,3 @@ async def _recv(self) -> IncomingMessage:
87150
)
88151
raise WokwiError(f"Server error {result['code']}: {result['message']}")
89152
return cast(IncomingMessage, message)
90-
91-
async def recv(self) -> IncomingMessage:
92-
return await self._recv()

0 commit comments

Comments
 (0)