Skip to content

Commit c767c59

Browse files
bluetechpgjones
authored andcommitted
Enable mypy strict-equality check, clean up types in test_permessage_deflate.py
Enabling strict-equality triggers one warning in test_permessage_deflate.py. But in order to fix it, some larger cleanups are required. Remove the type: ignores, avoid redefinitions and fix the type errors.
1 parent adae9fc commit c767c59

File tree

2 files changed

+65
-51
lines changed

2 files changed

+65
-51
lines changed

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ disallow_untyped_calls = True
4545
disallow_untyped_defs = True
4646
; implicit_reexport = False
4747
no_implicit_optional = True
48-
# strict_equality = True
48+
strict_equality = True
4949
strict_optional = False
5050
warn_redundant_casts = True
5151
# warn_return_any = True

test/test_permessage_deflate.py

Lines changed: 64 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,26 @@
11
# -*- coding: utf-8 -*-
22

33
import zlib
4-
from typing import cast, Dict, Optional, Union
4+
from typing import cast, Dict, Optional, Sequence, TYPE_CHECKING, Union
55

66
import pytest
77
from _pytest.monkeypatch import MonkeyPatch
88

99
import wsproto.extensions as wpext
1010
import wsproto.frame_protocol as fp
1111

12-
Params = Dict[str, Union[bool, Optional[int]]]
12+
if TYPE_CHECKING:
13+
from mypy_extensions import TypedDict
14+
15+
class Params(TypedDict, total=False):
16+
client_no_context_takeover: bool
17+
client_max_window_bits: Optional[int]
18+
server_no_context_takeover: bool
19+
server_max_window_bits: Optional[int]
20+
21+
22+
else:
23+
Params = dict
1324

1425

1526
class TestPerMessageDeflate:
@@ -42,7 +53,7 @@ class TestPerMessageDeflate:
4253
{"server_no_context_takeover": True, "client_max_window_bits": 8},
4354
{"client_max_window_bits": None, "server_max_window_bits": None},
4455
{},
45-
]
56+
] # type: Sequence[Params]
4657

4758
def make_offer_string(self, params: Params) -> str:
4859
offer = ["permessage-deflate"]
@@ -90,7 +101,7 @@ def compare_params_to_string(
90101

91102
@pytest.mark.parametrize("params", parameter_sets)
92103
def test_offer(self, params: Params) -> None:
93-
ext = wpext.PerMessageDeflate(**params) # type: ignore
104+
ext = wpext.PerMessageDeflate(**params)
94105
offer = ext.offer()
95106
offer = cast(str, offer)
96107

@@ -101,7 +112,6 @@ def test_finalize(self, params: Params) -> None:
101112
ext = wpext.PerMessageDeflate()
102113
assert not ext.enabled()
103114

104-
params = dict(params)
105115
if "client_max_window_bits" in params:
106116
if params["client_max_window_bits"] is None:
107117
del params["client_max_window_bits"]
@@ -178,7 +188,8 @@ def test_inbound_uncompressed_control_frame(self) -> None:
178188
result = ext.frame_inbound_header(
179189
proto, fp.Opcode.PING, fp.RsvBits(False, False, False), len(payload)
180190
)
181-
assert result.rsv1 # type: ignore
191+
assert isinstance(result, fp.RsvBits)
192+
assert result.rsv1
182193

183194
data = ext.frame_inbound_payload_data(proto, payload)
184195
assert data == payload
@@ -219,7 +230,8 @@ def test_inbound_uncompressed_data_frame(self) -> None:
219230
result = ext.frame_inbound_header(
220231
proto, fp.Opcode.BINARY, fp.RsvBits(False, False, False), len(payload)
221232
)
222-
assert result.rsv1 # type: ignore
233+
assert isinstance(result, fp.RsvBits)
234+
assert result.rsv1
223235

224236
data = ext.frame_inbound_payload_data(proto, payload)
225237
assert data == payload
@@ -241,11 +253,14 @@ def test_client_inbound_compressed_single_data_frame(self, client: bool) -> None
241253
fp.RsvBits(True, False, False),
242254
len(compressed_payload),
243255
)
244-
assert result.rsv1 # type: ignore
256+
assert isinstance(result, fp.RsvBits)
257+
assert result.rsv1
245258

246259
data = ext.frame_inbound_payload_data(proto, compressed_payload)
247-
data += ext.frame_inbound_complete(proto, True) # type: ignore
248-
assert data == payload
260+
assert isinstance(data, bytes)
261+
data2 = ext.frame_inbound_complete(proto, True)
262+
assert isinstance(data2, bytes)
263+
assert data + data2 == payload
249264

250265
@pytest.mark.parametrize("client", [True, False])
251266
def test_client_inbound_compressed_multiple_data_frames(self, client: bool) -> None:
@@ -261,30 +276,28 @@ def test_client_inbound_compressed_multiple_data_frames(self, client: bool) -> N
261276
result = ext.frame_inbound_header(
262277
proto, fp.Opcode.BINARY, fp.RsvBits(True, False, False), split
263278
)
264-
assert result.rsv1 # type: ignore
265-
result = ext.frame_inbound_payload_data( # type: ignore
266-
proto, compressed_payload[:split]
267-
)
268-
assert not isinstance(result, fp.CloseReason)
269-
data += result # type: ignore
279+
assert isinstance(result, fp.RsvBits)
280+
assert result.rsv1
281+
result2 = ext.frame_inbound_payload_data(proto, compressed_payload[:split])
282+
assert not isinstance(result2, fp.CloseReason)
283+
data += result2
270284
assert ext.frame_inbound_complete(proto, False) is None
271285

272-
result = ext.frame_inbound_header(
286+
result3 = ext.frame_inbound_header(
273287
proto,
274288
fp.Opcode.CONTINUATION,
275289
fp.RsvBits(False, False, False),
276290
len(compressed_payload) - split,
277291
)
278-
assert result.rsv1 # type: ignore
279-
result = ext.frame_inbound_payload_data( # type: ignore
280-
proto, compressed_payload[split:]
281-
)
282-
assert not isinstance(result, fp.CloseReason)
283-
data += result # type: ignore
292+
assert isinstance(result3, fp.RsvBits)
293+
assert result3.rsv1
294+
result4 = ext.frame_inbound_payload_data(proto, compressed_payload[split:])
295+
assert not isinstance(result4, fp.CloseReason)
296+
data += result4
284297

285-
result = ext.frame_inbound_complete(proto, True) # type: ignore
286-
assert not isinstance(result, fp.CloseReason)
287-
data += result # type: ignore
298+
result5 = ext.frame_inbound_complete(proto, True)
299+
assert not isinstance(result5, fp.CloseReason)
300+
data += result5
288301

289302
assert data == payload
290303

@@ -298,28 +311,27 @@ def test_client_decompress_after_uncompressible_frame(self, client: bool) -> Non
298311
result = ext.frame_inbound_header(
299312
proto, fp.Opcode.PING, fp.RsvBits(False, False, False), 0
300313
)
301-
result = ext.frame_inbound_payload_data(proto, b"") # type: ignore
302-
assert not isinstance(result, fp.CloseReason)
314+
result2 = ext.frame_inbound_payload_data(proto, b"")
315+
assert not isinstance(result2, fp.CloseReason)
303316
assert ext.frame_inbound_complete(proto, True) is None
304317

305318
# A compressed TEXT frame
306319
payload = b"x" * 23
307320
compressed_payload = b"\xaa\xa8\xc0\n\x00\x00"
308321

309-
result = ext.frame_inbound_header(
322+
result3 = ext.frame_inbound_header(
310323
proto,
311324
fp.Opcode.TEXT,
312325
fp.RsvBits(True, False, False),
313326
len(compressed_payload),
314327
)
315-
assert result.rsv1 # type: ignore
316-
result = ext.frame_inbound_payload_data( # type: ignore
317-
proto, compressed_payload
318-
)
319-
assert result == payload
328+
assert isinstance(result3, fp.RsvBits)
329+
assert result3.rsv1
330+
result4 = ext.frame_inbound_payload_data(proto, compressed_payload)
331+
assert result4 == payload
320332

321-
result = ext.frame_inbound_complete(proto, True) # type: ignore
322-
assert not isinstance(result, fp.CloseReason)
333+
result5 = ext.frame_inbound_complete(proto, True)
334+
assert not isinstance(result5, fp.CloseReason)
323335

324336
def test_inbound_bad_zlib_payload(self) -> None:
325337
compressed_payload = b"x" * 23
@@ -334,11 +346,10 @@ def test_inbound_bad_zlib_payload(self) -> None:
334346
fp.RsvBits(True, False, False),
335347
len(compressed_payload),
336348
)
337-
assert result.rsv1 # type: ignore
338-
result = ext.frame_inbound_payload_data( # type: ignore
339-
proto, compressed_payload
340-
)
341-
assert result is fp.CloseReason.INVALID_FRAME_PAYLOAD_DATA
349+
assert isinstance(result, fp.RsvBits)
350+
assert result.rsv1
351+
result2 = ext.frame_inbound_payload_data(proto, compressed_payload)
352+
assert result2 is fp.CloseReason.INVALID_FRAME_PAYLOAD_DATA
342353

343354
def test_inbound_bad_zlib_decoder_end_state(self, monkeypatch: MonkeyPatch) -> None:
344355
compressed_payload = b"x" * 23
@@ -353,7 +364,8 @@ def test_inbound_bad_zlib_decoder_end_state(self, monkeypatch: MonkeyPatch) -> N
353364
fp.RsvBits(True, False, False),
354365
len(compressed_payload),
355366
)
356-
assert result.rsv1 # type: ignore
367+
assert isinstance(result, fp.RsvBits)
368+
assert result.rsv1
357369

358370
class FailDecompressor:
359371
def decompress(self, data: bytes) -> bytes:
@@ -364,8 +376,8 @@ def flush(self) -> None:
364376

365377
monkeypatch.setattr(ext, "_decompressor", FailDecompressor())
366378

367-
result = ext.frame_inbound_complete(proto, True) # type: ignore
368-
assert result is fp.CloseReason.INVALID_FRAME_PAYLOAD_DATA
379+
result2 = ext.frame_inbound_complete(proto, True)
380+
assert result2 is fp.CloseReason.INVALID_FRAME_PAYLOAD_DATA
369381

370382
@pytest.mark.parametrize(
371383
"client,no_context_takeover",
@@ -383,22 +395,24 @@ def test_decompressor_reset(self, client: bool, no_context_takeover: bool) -> No
383395
result = ext.frame_inbound_header(
384396
proto, fp.Opcode.BINARY, fp.RsvBits(True, False, False), 0
385397
)
386-
assert result.rsv1 # type: ignore
398+
assert isinstance(result, fp.RsvBits)
399+
assert result.rsv1
387400

388401
assert ext._decompressor is not None
389402

390-
result = ext.frame_inbound_complete(proto, True) # type: ignore
391-
assert not isinstance(result, fp.CloseReason)
403+
result2 = ext.frame_inbound_complete(proto, True)
404+
assert not isinstance(result2, fp.CloseReason)
392405

393406
if no_context_takeover:
394407
assert ext._decompressor is None
395408
else:
396409
assert ext._decompressor is not None
397410

398-
result = ext.frame_inbound_header(
411+
result3 = ext.frame_inbound_header(
399412
proto, fp.Opcode.BINARY, fp.RsvBits(True, False, False), 0
400413
)
401-
assert result.rsv1 # type: ignore
414+
assert isinstance(result3, fp.RsvBits)
415+
assert result3.rsv1
402416

403417
assert ext._decompressor is not None
404418

@@ -486,5 +500,5 @@ def test_compressor_reset(self, client: bool, no_context_takeover: bool) -> None
486500

487501
@pytest.mark.parametrize("params", parameter_sets)
488502
def test_repr(self, params: Params) -> None:
489-
ext = wpext.PerMessageDeflate(**params) # type: ignore
503+
ext = wpext.PerMessageDeflate(**params)
490504
self.compare_params_to_string(params, ext, repr(ext))

0 commit comments

Comments
 (0)