1
1
# -*- coding: utf-8 -*-
2
2
3
3
import zlib
4
- from typing import cast , Dict , Optional , Union
4
+ from typing import cast , Dict , Optional , Sequence , TYPE_CHECKING , Union
5
5
6
6
import pytest
7
7
from _pytest .monkeypatch import MonkeyPatch
8
8
9
9
import wsproto .extensions as wpext
10
10
import wsproto .frame_protocol as fp
11
11
12
- Params = Dict [str , Union [bool , Optional [int ]]]
12
+ if TYPE_CHECKING :
13
+ from mypy_extensions import TypedDict
14
+
15
+ class Params (TypedDict , total = False ):
16
+ client_no_context_takeover : bool
17
+ client_max_window_bits : Optional [int ]
18
+ server_no_context_takeover : bool
19
+ server_max_window_bits : Optional [int ]
20
+
21
+
22
+ else :
23
+ Params = dict
13
24
14
25
15
26
class TestPerMessageDeflate :
@@ -42,7 +53,7 @@ class TestPerMessageDeflate:
42
53
{"server_no_context_takeover" : True , "client_max_window_bits" : 8 },
43
54
{"client_max_window_bits" : None , "server_max_window_bits" : None },
44
55
{},
45
- ]
56
+ ] # type: Sequence[Params]
46
57
47
58
def make_offer_string (self , params : Params ) -> str :
48
59
offer = ["permessage-deflate" ]
@@ -90,7 +101,7 @@ def compare_params_to_string(
90
101
91
102
@pytest .mark .parametrize ("params" , parameter_sets )
92
103
def test_offer (self , params : Params ) -> None :
93
- ext = wpext .PerMessageDeflate (** params ) # type: ignore
104
+ ext = wpext .PerMessageDeflate (** params )
94
105
offer = ext .offer ()
95
106
offer = cast (str , offer )
96
107
@@ -101,7 +112,6 @@ def test_finalize(self, params: Params) -> None:
101
112
ext = wpext .PerMessageDeflate ()
102
113
assert not ext .enabled ()
103
114
104
- params = dict (params )
105
115
if "client_max_window_bits" in params :
106
116
if params ["client_max_window_bits" ] is None :
107
117
del params ["client_max_window_bits" ]
@@ -178,7 +188,8 @@ def test_inbound_uncompressed_control_frame(self) -> None:
178
188
result = ext .frame_inbound_header (
179
189
proto , fp .Opcode .PING , fp .RsvBits (False , False , False ), len (payload )
180
190
)
181
- assert result .rsv1 # type: ignore
191
+ assert isinstance (result , fp .RsvBits )
192
+ assert result .rsv1
182
193
183
194
data = ext .frame_inbound_payload_data (proto , payload )
184
195
assert data == payload
@@ -219,7 +230,8 @@ def test_inbound_uncompressed_data_frame(self) -> None:
219
230
result = ext .frame_inbound_header (
220
231
proto , fp .Opcode .BINARY , fp .RsvBits (False , False , False ), len (payload )
221
232
)
222
- assert result .rsv1 # type: ignore
233
+ assert isinstance (result , fp .RsvBits )
234
+ assert result .rsv1
223
235
224
236
data = ext .frame_inbound_payload_data (proto , payload )
225
237
assert data == payload
@@ -241,11 +253,14 @@ def test_client_inbound_compressed_single_data_frame(self, client: bool) -> None
241
253
fp .RsvBits (True , False , False ),
242
254
len (compressed_payload ),
243
255
)
244
- assert result .rsv1 # type: ignore
256
+ assert isinstance (result , fp .RsvBits )
257
+ assert result .rsv1
245
258
246
259
data = ext .frame_inbound_payload_data (proto , compressed_payload )
247
- data += ext .frame_inbound_complete (proto , True ) # type: ignore
248
- assert data == payload
260
+ assert isinstance (data , bytes )
261
+ data2 = ext .frame_inbound_complete (proto , True )
262
+ assert isinstance (data2 , bytes )
263
+ assert data + data2 == payload
249
264
250
265
@pytest .mark .parametrize ("client" , [True , False ])
251
266
def test_client_inbound_compressed_multiple_data_frames (self , client : bool ) -> None :
@@ -261,30 +276,28 @@ def test_client_inbound_compressed_multiple_data_frames(self, client: bool) -> N
261
276
result = ext .frame_inbound_header (
262
277
proto , fp .Opcode .BINARY , fp .RsvBits (True , False , False ), split
263
278
)
264
- assert result .rsv1 # type: ignore
265
- result = ext .frame_inbound_payload_data ( # type: ignore
266
- proto , compressed_payload [:split ]
267
- )
268
- assert not isinstance (result , fp .CloseReason )
269
- data += result # type: ignore
279
+ assert isinstance (result , fp .RsvBits )
280
+ assert result .rsv1
281
+ result2 = ext .frame_inbound_payload_data (proto , compressed_payload [:split ])
282
+ assert not isinstance (result2 , fp .CloseReason )
283
+ data += result2
270
284
assert ext .frame_inbound_complete (proto , False ) is None
271
285
272
- result = ext .frame_inbound_header (
286
+ result3 = ext .frame_inbound_header (
273
287
proto ,
274
288
fp .Opcode .CONTINUATION ,
275
289
fp .RsvBits (False , False , False ),
276
290
len (compressed_payload ) - split ,
277
291
)
278
- assert result .rsv1 # type: ignore
279
- result = ext .frame_inbound_payload_data ( # type: ignore
280
- proto , compressed_payload [split :]
281
- )
282
- assert not isinstance (result , fp .CloseReason )
283
- data += result # type: ignore
292
+ assert isinstance (result3 , fp .RsvBits )
293
+ assert result3 .rsv1
294
+ result4 = ext .frame_inbound_payload_data (proto , compressed_payload [split :])
295
+ assert not isinstance (result4 , fp .CloseReason )
296
+ data += result4
284
297
285
- result = ext .frame_inbound_complete (proto , True ) # type: ignore
286
- assert not isinstance (result , fp .CloseReason )
287
- data += result # type: ignore
298
+ result5 = ext .frame_inbound_complete (proto , True )
299
+ assert not isinstance (result5 , fp .CloseReason )
300
+ data += result5
288
301
289
302
assert data == payload
290
303
@@ -298,28 +311,27 @@ def test_client_decompress_after_uncompressible_frame(self, client: bool) -> Non
298
311
result = ext .frame_inbound_header (
299
312
proto , fp .Opcode .PING , fp .RsvBits (False , False , False ), 0
300
313
)
301
- result = ext .frame_inbound_payload_data (proto , b"" ) # type: ignore
302
- assert not isinstance (result , fp .CloseReason )
314
+ result2 = ext .frame_inbound_payload_data (proto , b"" )
315
+ assert not isinstance (result2 , fp .CloseReason )
303
316
assert ext .frame_inbound_complete (proto , True ) is None
304
317
305
318
# A compressed TEXT frame
306
319
payload = b"x" * 23
307
320
compressed_payload = b"\xaa \xa8 \xc0 \n \x00 \x00 "
308
321
309
- result = ext .frame_inbound_header (
322
+ result3 = ext .frame_inbound_header (
310
323
proto ,
311
324
fp .Opcode .TEXT ,
312
325
fp .RsvBits (True , False , False ),
313
326
len (compressed_payload ),
314
327
)
315
- assert result .rsv1 # type: ignore
316
- result = ext .frame_inbound_payload_data ( # type: ignore
317
- proto , compressed_payload
318
- )
319
- assert result == payload
328
+ assert isinstance (result3 , fp .RsvBits )
329
+ assert result3 .rsv1
330
+ result4 = ext .frame_inbound_payload_data (proto , compressed_payload )
331
+ assert result4 == payload
320
332
321
- result = ext .frame_inbound_complete (proto , True ) # type: ignore
322
- assert not isinstance (result , fp .CloseReason )
333
+ result5 = ext .frame_inbound_complete (proto , True )
334
+ assert not isinstance (result5 , fp .CloseReason )
323
335
324
336
def test_inbound_bad_zlib_payload (self ) -> None :
325
337
compressed_payload = b"x" * 23
@@ -334,11 +346,10 @@ def test_inbound_bad_zlib_payload(self) -> None:
334
346
fp .RsvBits (True , False , False ),
335
347
len (compressed_payload ),
336
348
)
337
- assert result .rsv1 # type: ignore
338
- result = ext .frame_inbound_payload_data ( # type: ignore
339
- proto , compressed_payload
340
- )
341
- assert result is fp .CloseReason .INVALID_FRAME_PAYLOAD_DATA
349
+ assert isinstance (result , fp .RsvBits )
350
+ assert result .rsv1
351
+ result2 = ext .frame_inbound_payload_data (proto , compressed_payload )
352
+ assert result2 is fp .CloseReason .INVALID_FRAME_PAYLOAD_DATA
342
353
343
354
def test_inbound_bad_zlib_decoder_end_state (self , monkeypatch : MonkeyPatch ) -> None :
344
355
compressed_payload = b"x" * 23
@@ -353,7 +364,8 @@ def test_inbound_bad_zlib_decoder_end_state(self, monkeypatch: MonkeyPatch) -> N
353
364
fp .RsvBits (True , False , False ),
354
365
len (compressed_payload ),
355
366
)
356
- assert result .rsv1 # type: ignore
367
+ assert isinstance (result , fp .RsvBits )
368
+ assert result .rsv1
357
369
358
370
class FailDecompressor :
359
371
def decompress (self , data : bytes ) -> bytes :
@@ -364,8 +376,8 @@ def flush(self) -> None:
364
376
365
377
monkeypatch .setattr (ext , "_decompressor" , FailDecompressor ())
366
378
367
- result = ext .frame_inbound_complete (proto , True ) # type: ignore
368
- assert result is fp .CloseReason .INVALID_FRAME_PAYLOAD_DATA
379
+ result2 = ext .frame_inbound_complete (proto , True )
380
+ assert result2 is fp .CloseReason .INVALID_FRAME_PAYLOAD_DATA
369
381
370
382
@pytest .mark .parametrize (
371
383
"client,no_context_takeover" ,
@@ -383,22 +395,24 @@ def test_decompressor_reset(self, client: bool, no_context_takeover: bool) -> No
383
395
result = ext .frame_inbound_header (
384
396
proto , fp .Opcode .BINARY , fp .RsvBits (True , False , False ), 0
385
397
)
386
- assert result .rsv1 # type: ignore
398
+ assert isinstance (result , fp .RsvBits )
399
+ assert result .rsv1
387
400
388
401
assert ext ._decompressor is not None
389
402
390
- result = ext .frame_inbound_complete (proto , True ) # type: ignore
391
- assert not isinstance (result , fp .CloseReason )
403
+ result2 = ext .frame_inbound_complete (proto , True )
404
+ assert not isinstance (result2 , fp .CloseReason )
392
405
393
406
if no_context_takeover :
394
407
assert ext ._decompressor is None
395
408
else :
396
409
assert ext ._decompressor is not None
397
410
398
- result = ext .frame_inbound_header (
411
+ result3 = ext .frame_inbound_header (
399
412
proto , fp .Opcode .BINARY , fp .RsvBits (True , False , False ), 0
400
413
)
401
- assert result .rsv1 # type: ignore
414
+ assert isinstance (result3 , fp .RsvBits )
415
+ assert result3 .rsv1
402
416
403
417
assert ext ._decompressor is not None
404
418
@@ -486,5 +500,5 @@ def test_compressor_reset(self, client: bool, no_context_takeover: bool) -> None
486
500
487
501
@pytest .mark .parametrize ("params" , parameter_sets )
488
502
def test_repr (self , params : Params ) -> None :
489
- ext = wpext .PerMessageDeflate (** params ) # type: ignore
503
+ ext = wpext .PerMessageDeflate (** params )
490
504
self .compare_params_to_string (params , ext , repr (ext ))
0 commit comments