|
1 | | -"""Mock transport and network for testing.""" |
| 1 | +"""Pytest fixtures for the test suite.""" |
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | import asyncio |
6 | | -import random |
7 | | -from collections.abc import Callable |
8 | 6 |
|
9 | 7 | import pytest |
10 | 8 |
|
11 | | -from pycyphal import Closable, Instant, Priority, TransportArrival |
12 | | - |
13 | | -# A small prime modulus suitable for testing. |
14 | | -DEFAULT_MODULUS = 122743 |
15 | | - |
16 | | - |
17 | | -# ===================================================================================================================== |
18 | | -# MockSubjectWriter |
19 | | -# ===================================================================================================================== |
20 | | - |
21 | | - |
22 | | -class MockSubjectWriter(Closable): |
23 | | - def __init__(self, transport: MockTransport, subject_id: int) -> None: |
24 | | - self.transport = transport |
25 | | - self.subject_id = subject_id |
26 | | - self.closed = False |
27 | | - self.send_count = 0 |
28 | | - self.fail_next = False |
29 | | - |
30 | | - async def __call__(self, deadline: Instant, priority: Priority, message: bytes | memoryview) -> None: |
31 | | - if self.closed: |
32 | | - raise RuntimeError("Writer closed") |
33 | | - if self.fail_next: |
34 | | - self.fail_next = False |
35 | | - raise RuntimeError("Simulated send failure") |
36 | | - self.send_count += 1 |
37 | | - msg_bytes = bytes(message) |
38 | | - arrival = TransportArrival( |
39 | | - timestamp=Instant.now(), |
40 | | - priority=priority, |
41 | | - remote_id=self.transport.node_id, |
42 | | - message=msg_bytes, |
43 | | - ) |
44 | | - # Deliver to all listeners on this subject across the network |
45 | | - if self.transport.network is not None: |
46 | | - self.transport.network.deliver_subject(self.subject_id, arrival, sender=self.transport) |
47 | | - else: |
48 | | - # Local-only delivery |
49 | | - for handler in self.transport._subject_handlers.get(self.subject_id, []): |
50 | | - handler(arrival) |
51 | | - |
52 | | - def close(self) -> None: |
53 | | - self.closed = True |
54 | | - |
55 | | - |
56 | | -# ===================================================================================================================== |
57 | | -# MockSubjectListener |
58 | | -# ===================================================================================================================== |
59 | | - |
60 | | - |
61 | | -class MockSubjectListener(Closable): |
62 | | - def __init__(self, transport: MockTransport, subject_id: int, handler: Callable[[TransportArrival], None]) -> None: |
63 | | - self.transport = transport |
64 | | - self.subject_id = subject_id |
65 | | - self.handler = handler |
66 | | - self.closed = False |
67 | | - |
68 | | - def close(self) -> None: |
69 | | - self.closed = True |
70 | | - handlers = self.transport._subject_handlers.get(self.subject_id, []) |
71 | | - if self.handler in handlers: |
72 | | - handlers.remove(self.handler) |
73 | | - if not handlers: |
74 | | - self.transport._subject_handlers.pop(self.subject_id, None) |
75 | | - |
76 | | - |
77 | | -# ===================================================================================================================== |
78 | | -# MockTransport |
79 | | -# ===================================================================================================================== |
80 | | - |
81 | | - |
82 | | -class MockTransport(Closable): |
83 | | - def __init__(self, node_id: int = 0, modulus: int = DEFAULT_MODULUS, network: MockNetwork | None = None) -> None: |
84 | | - self.node_id = node_id |
85 | | - self._modulus = modulus |
86 | | - self.network = network |
87 | | - self._subject_handlers: dict[int, list[Callable[[TransportArrival], None]]] = {} |
88 | | - self._unicast_handler: Callable[[TransportArrival], None] | None = None |
89 | | - self._writers: dict[int, MockSubjectWriter] = {} |
90 | | - self.unicast_log: list[tuple[int, bytes]] = [] |
91 | | - self.closed = False |
92 | | - self.fail_unicast = False |
93 | | - |
94 | | - if network is not None: |
95 | | - network.add_transport(self) |
96 | | - |
97 | | - @property |
98 | | - def subject_id_modulus(self) -> int: |
99 | | - return self._modulus |
100 | | - |
101 | | - def subject_listen(self, subject_id: int, handler: Callable[[TransportArrival], None]) -> Closable: |
102 | | - if subject_id not in self._subject_handlers: |
103 | | - self._subject_handlers[subject_id] = [] |
104 | | - self._subject_handlers[subject_id].append(handler) |
105 | | - return MockSubjectListener(self, subject_id, handler) |
106 | | - |
107 | | - def subject_advertise(self, subject_id: int) -> MockSubjectWriter: |
108 | | - writer = MockSubjectWriter(self, subject_id) |
109 | | - self._writers[subject_id] = writer |
110 | | - return writer |
111 | | - |
112 | | - def unicast_listen(self, handler: Callable[[TransportArrival], None]) -> None: |
113 | | - self._unicast_handler = handler |
114 | | - |
115 | | - async def unicast(self, deadline: Instant, priority: Priority, remote_id: int, message: bytes | memoryview) -> None: |
116 | | - if self.closed: |
117 | | - raise RuntimeError("Transport closed") |
118 | | - if self.fail_unicast: |
119 | | - raise RuntimeError("Simulated unicast failure") |
120 | | - msg_bytes = bytes(message) |
121 | | - self.unicast_log.append((remote_id, msg_bytes)) |
122 | | - arrival = TransportArrival( |
123 | | - timestamp=Instant.now(), |
124 | | - priority=priority, |
125 | | - remote_id=self.node_id, |
126 | | - message=msg_bytes, |
127 | | - ) |
128 | | - if self.network is not None: |
129 | | - self.network.deliver_unicast(remote_id, arrival) |
130 | | - else: |
131 | | - # Local unicast: deliver to own handler |
132 | | - if self._unicast_handler is not None: |
133 | | - self._unicast_handler(arrival) |
134 | | - |
135 | | - def close(self) -> None: |
136 | | - self.closed = True |
137 | | - |
138 | | - def deliver_subject(self, subject_id: int, arrival: TransportArrival) -> None: |
139 | | - """Deliver a subject message to local handlers.""" |
140 | | - for handler in self._subject_handlers.get(subject_id, []): |
141 | | - handler(arrival) |
142 | | - |
143 | | - def deliver_unicast(self, arrival: TransportArrival) -> None: |
144 | | - """Deliver a unicast message to local handler.""" |
145 | | - if self._unicast_handler is not None: |
146 | | - self._unicast_handler(arrival) |
147 | | - |
148 | | - |
149 | | -# ===================================================================================================================== |
150 | | -# MockNetwork |
151 | | -# ===================================================================================================================== |
152 | | - |
153 | | - |
154 | | -class MockNetwork: |
155 | | - """Simulates a network connecting multiple MockTransport instances.""" |
156 | | - |
157 | | - def __init__(self, *, delay: float = 0.0, drop_rate: float = 0.0) -> None: |
158 | | - self.transports: dict[int, MockTransport] = {} |
159 | | - self.delay = delay |
160 | | - self.drop_rate = drop_rate |
161 | | - self.message_log: list[tuple[str, int, bytes]] = [] |
162 | | - |
163 | | - def add_transport(self, transport: MockTransport) -> None: |
164 | | - self.transports[transport.node_id] = transport |
165 | | - |
166 | | - def deliver_subject(self, subject_id: int, arrival: TransportArrival, sender: MockTransport) -> None: |
167 | | - """Deliver subject message to all transports (including sender for loopback).""" |
168 | | - for tid, transport in self.transports.items(): |
169 | | - if random.random() < self.drop_rate: |
170 | | - continue |
171 | | - transport.deliver_subject(subject_id, arrival) |
172 | | - |
173 | | - def deliver_unicast(self, remote_id: int, arrival: TransportArrival) -> None: |
174 | | - """Deliver unicast message to specific transport.""" |
175 | | - transport = self.transports.get(remote_id) |
176 | | - if transport is not None: |
177 | | - if random.random() >= self.drop_rate: |
178 | | - transport.deliver_unicast(arrival) |
179 | | - |
180 | | - |
181 | | -# ===================================================================================================================== |
182 | | -# Fixtures |
183 | | -# ===================================================================================================================== |
| 9 | +from tests.mock_transport import MockTransport, MockNetwork |
184 | 10 |
|
185 | 11 |
|
186 | 12 | @pytest.fixture |
187 | | -def mock_network(): |
| 13 | +def mock_network() -> MockNetwork: |
188 | 14 | return MockNetwork() |
189 | 15 |
|
190 | 16 |
|
191 | 17 | @pytest.fixture |
192 | | -def mock_transport(): |
| 18 | +def mock_transport() -> MockTransport: |
193 | 19 | return MockTransport(node_id=1) |
194 | 20 |
|
195 | 21 |
|
196 | 22 | @pytest.fixture |
197 | | -def event_loop(): |
| 23 | +def event_loop(): # type: ignore[no-untyped-def] |
198 | 24 | loop = asyncio.new_event_loop() |
199 | 25 | yield loop |
200 | 26 | loop.close() |
0 commit comments