Skip to content

Commit 513073b

Browse files
committed
Add option to always include port in build_host helper.
1 parent 321be89 commit 513073b

File tree

2 files changed

+33
-20
lines changed

2 files changed

+33
-20
lines changed

src/websockets/headers.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,13 @@
3636
T = TypeVar("T")
3737

3838

39-
def build_host(host: str, port: int, secure: bool) -> str:
39+
def build_host(
40+
host: str,
41+
port: int,
42+
secure: bool,
43+
*,
44+
always_include_port: bool = False,
45+
) -> str:
4046
"""
4147
Build a ``Host`` header.
4248
@@ -53,7 +59,7 @@ def build_host(host: str, port: int, secure: bool) -> str:
5359
if address.version == 6:
5460
host = f"[{host}]"
5561

56-
if port != (443 if secure else 80):
62+
if always_include_port or port != (443 if secure else 80):
5763
host = f"{host}:{port}"
5864

5965
return host

tests/test_headers.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,33 @@
66

77
class HeadersTests(unittest.TestCase):
88
def test_build_host(self):
9-
for (host, port, secure), result in [
10-
(("localhost", 80, False), "localhost"),
11-
(("localhost", 8000, False), "localhost:8000"),
12-
(("localhost", 443, True), "localhost"),
13-
(("localhost", 8443, True), "localhost:8443"),
14-
(("example.com", 80, False), "example.com"),
15-
(("example.com", 8000, False), "example.com:8000"),
16-
(("example.com", 443, True), "example.com"),
17-
(("example.com", 8443, True), "example.com:8443"),
18-
(("127.0.0.1", 80, False), "127.0.0.1"),
19-
(("127.0.0.1", 8000, False), "127.0.0.1:8000"),
20-
(("127.0.0.1", 443, True), "127.0.0.1"),
21-
(("127.0.0.1", 8443, True), "127.0.0.1:8443"),
22-
(("::1", 80, False), "[::1]"),
23-
(("::1", 8000, False), "[::1]:8000"),
24-
(("::1", 443, True), "[::1]"),
25-
(("::1", 8443, True), "[::1]:8443"),
9+
for (host, port, secure), (result, result_with_port) in [
10+
(("localhost", 80, False), ("localhost", "localhost:80")),
11+
(("localhost", 8000, False), ("localhost:8000", "localhost:8000")),
12+
(("localhost", 443, True), ("localhost", "localhost:443")),
13+
(("localhost", 8443, True), ("localhost:8443", "localhost:8443")),
14+
(("example.com", 80, False), ("example.com", "example.com:80")),
15+
(("example.com", 8000, False), ("example.com:8000", "example.com:8000")),
16+
(("example.com", 443, True), ("example.com", "example.com:443")),
17+
(("example.com", 8443, True), ("example.com:8443", "example.com:8443")),
18+
(("127.0.0.1", 80, False), ("127.0.0.1", "127.0.0.1:80")),
19+
(("127.0.0.1", 8000, False), ("127.0.0.1:8000", "127.0.0.1:8000")),
20+
(("127.0.0.1", 443, True), ("127.0.0.1", "127.0.0.1:443")),
21+
(("127.0.0.1", 8443, True), ("127.0.0.1:8443", "127.0.0.1:8443")),
22+
(("::1", 80, False), ("[::1]", "[::1]:80")),
23+
(("::1", 8000, False), ("[::1]:8000", "[::1]:8000")),
24+
(("::1", 443, True), ("[::1]", "[::1]:443")),
25+
(("::1", 8443, True), ("[::1]:8443", "[::1]:8443")),
2626
]:
2727
with self.subTest(host=host, port=port, secure=secure):
28-
self.assertEqual(build_host(host, port, secure), result)
28+
self.assertEqual(
29+
build_host(host, port, secure),
30+
result,
31+
)
32+
self.assertEqual(
33+
build_host(host, port, secure, always_include_port=True),
34+
result_with_port,
35+
)
2936

3037
def test_parse_connection(self):
3138
for header, parsed in [

0 commit comments

Comments
 (0)