2
2
#
3
3
# SPDX-License-Identifier: MIT
4
4
5
+ import asyncio
5
6
import json
6
7
import os
7
8
import warnings
8
- from typing import Any , Optional , cast
9
+ from typing import Any , Callable , Optional , cast
9
10
10
11
import websockets
11
12
12
13
from .__version__ import get_version
13
14
from .constants import (
14
15
DEFAULT_WS_URL ,
16
+ MSG_TYPE_EVENT ,
15
17
MSG_TYPE_HELLO ,
16
18
MSG_TYPE_RESPONSE ,
17
19
PROTOCOL_VERSION ,
18
20
)
19
21
from .exceptions import ProtocolError , ServerError , WokwiError
20
- from .protocol_types import HelloMessage , IncomingMessage , ResponseMessage
22
+ from .protocol_types import EventMessage , HelloMessage , IncomingMessage , ResponseMessage
21
23
22
24
TRANSPORT_DEFAULT_WS_URL = os .getenv ("WOKWI_CLI_SERVER" , DEFAULT_WS_URL )
23
25
@@ -28,6 +30,10 @@ def __init__(self, token: str, url: str = TRANSPORT_DEFAULT_WS_URL):
28
30
self ._url = url
29
31
self ._next_id = 1
30
32
self ._ws : Optional [websockets .WebSocketClientProtocol ] = None
33
+ self ._event_listeners : dict [str , list [Callable [[EventMessage ], Any ]]] = {}
34
+ self ._response_futures : dict [str , asyncio .Future [ResponseMessage ]] = {}
35
+ self ._recv_task : Optional [asyncio .Task [None ]] = None
36
+ self ._closed = False
31
37
32
38
async def connect (self ) -> dict [str , Any ]:
33
39
self ._ws = await websockets .connect (
@@ -41,28 +47,85 @@ async def connect(self) -> dict[str, Any]:
41
47
if hello ["type" ] != MSG_TYPE_HELLO or hello .get ("protocolVersion" ) != PROTOCOL_VERSION :
42
48
raise ProtocolError (f"Unsupported protocol handshake: { hello } " )
43
49
hello_msg = cast (HelloMessage , hello )
50
+ self ._closed = False
51
+ # Start background message processor
52
+ self ._recv_task = asyncio .create_task (self ._background_recv ())
44
53
return {"version" : hello_msg ["appVersion" ]}
45
54
46
55
async def close (self ) -> None :
56
+ self ._closed = True
57
+ if self ._recv_task :
58
+ self ._recv_task .cancel ()
59
+ try :
60
+ await self ._recv_task
61
+ except asyncio .CancelledError :
62
+ pass
47
63
if self ._ws :
48
64
await self ._ws .close ()
49
65
66
+ def add_event_listener (self , event_type : str , listener : Callable [[EventMessage ], Any ]) -> None :
67
+ """Register a listener for a specific event type."""
68
+ if event_type not in self ._event_listeners :
69
+ self ._event_listeners [event_type ] = []
70
+ self ._event_listeners [event_type ].append (listener )
71
+
72
+ def remove_event_listener (
73
+ self , event_type : str , listener : Callable [[EventMessage ], Any ]
74
+ ) -> None :
75
+ """Remove a previously registered listener for a specific event type."""
76
+ if event_type in self ._event_listeners :
77
+ self ._event_listeners [event_type ] = [
78
+ registered_listener
79
+ for registered_listener in self ._event_listeners [event_type ]
80
+ if registered_listener != listener
81
+ ]
82
+ if not self ._event_listeners [event_type ]:
83
+ del self ._event_listeners [event_type ]
84
+
85
+ async def _dispatch_event (self , event_msg : EventMessage ) -> None :
86
+ listeners = self ._event_listeners .get (event_msg ["event" ], [])
87
+ for listener in listeners :
88
+ result = listener (event_msg )
89
+ if hasattr (result , "__await__" ):
90
+ await result
91
+
50
92
async def request (self , command : str , params : dict [str , Any ]) -> ResponseMessage :
51
93
msg_id = str (self ._next_id )
52
94
self ._next_id += 1
53
95
if self ._ws is None :
54
96
raise WokwiError ("Not connected" )
97
+ loop = asyncio .get_running_loop ()
98
+ future : asyncio .Future [ResponseMessage ] = loop .create_future ()
99
+ self ._response_futures [msg_id ] = future
55
100
await self ._ws .send (
56
101
json .dumps ({"type" : "command" , "command" : command , "params" : params , "id" : msg_id })
57
102
)
58
- while True :
59
- msg : IncomingMessage = await self ._recv ()
60
- if msg ["type" ] == MSG_TYPE_RESPONSE and msg .get ("id" ) == msg_id :
61
- resp_msg = cast (ResponseMessage , msg )
62
- if resp_msg .get ("error" ):
63
- result = resp_msg ["result" ]
64
- raise ServerError (result ["message" ])
65
- return resp_msg
103
+ try :
104
+ resp_msg_resp = await future
105
+ if resp_msg_resp .get ("error" ):
106
+ result = resp_msg_resp ["result" ]
107
+ raise ServerError (result ["message" ])
108
+ return resp_msg_resp
109
+ finally :
110
+ del self ._response_futures [msg_id ]
111
+
112
+ async def _background_recv (self ) -> None :
113
+ try :
114
+ while not self ._closed and self ._ws is not None :
115
+ msg : IncomingMessage = await self ._recv ()
116
+ if msg ["type" ] == MSG_TYPE_EVENT :
117
+ resp_msg_event = cast (EventMessage , msg )
118
+ await self ._dispatch_event (resp_msg_event )
119
+ elif msg ["type" ] == MSG_TYPE_RESPONSE :
120
+ resp_msg_resp = cast (ResponseMessage , msg )
121
+ future = self ._response_futures .get (resp_msg_resp ["id" ])
122
+ if future is None or future .done ():
123
+ continue
124
+ future .set_result (resp_msg_resp )
125
+ except (websockets .ConnectionClosed , asyncio .CancelledError ):
126
+ pass
127
+ except Exception as e :
128
+ warnings .warn (f"Background recv error: { e } " , RuntimeWarning )
66
129
67
130
async def _recv (self ) -> IncomingMessage :
68
131
if self ._ws is None :
@@ -87,6 +150,3 @@ async def _recv(self) -> IncomingMessage:
87
150
)
88
151
raise WokwiError (f"Server error { result ['code' ]} : { result ['message' ]} " )
89
152
return cast (IncomingMessage , message )
90
-
91
- async def recv (self ) -> IncomingMessage :
92
- return await self ._recv ()
0 commit comments