Skip to content

Commit 1e8c2cc

Browse files
committed
Run mitmproxy only once per test case.
This reduces the run time of the test suite by 40%, from 6.7s to 4.1s.
1 parent 7bb226a commit 1e8c2cc

File tree

3 files changed

+205
-223
lines changed

3 files changed

+205
-223
lines changed

tests/asyncio/test_client.py

Lines changed: 65 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
)
2323
from websockets.extensions.permessage_deflate import PerMessageDeflate
2424

25-
from ..proxy import async_proxy
25+
from ..proxy import ProxyMixin
2626
from ..utils import (
2727
CLIENT_CONTEXT,
2828
MS,
@@ -388,7 +388,7 @@ def remove_accept_header(self, request, response):
388388

389389
async def test_timeout_during_handshake(self):
390390
"""Client times out before receiving handshake response from server."""
391-
# Replace the WebSocket server with a TCP server that does't respond.
391+
# Replace the WebSocket server with a TCP server that doesn't respond.
392392
with socket.create_server(("localhost", 0)) as sock:
393393
host, port = sock.getsockname()
394394
with self.assertRaises(TimeoutError) as raised:
@@ -508,7 +508,7 @@ async def test_reject_invalid_server_certificate(self):
508508
"""Client rejects certificate where server certificate isn't trusted."""
509509
async with serve(*args, ssl=SERVER_CONTEXT) as server:
510510
with self.assertRaises(ssl.SSLCertVerificationError) as raised:
511-
# The test certificate isn't trusted system-wide.
511+
# The test certificate is self-signed.
512512
async with connect(get_uri(server)):
513513
self.fail("did not raise")
514514
self.assertIn(
@@ -566,126 +566,104 @@ def redirect(connection, request):
566566

567567

568568
@unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed")
569-
class ProxyClientTests(unittest.IsolatedAsyncioTestCase):
570-
@contextlib.asynccontextmanager
571-
async def socks_proxy(self, auth=None):
572-
if auth:
573-
proxyauth = "hello:iloveyou"
574-
proxy_uri = "http://hello:iloveyou@localhost:51080"
575-
else:
576-
proxyauth = None
577-
proxy_uri = "http://localhost:51080"
578-
async with async_proxy(
579-
mode=["socks5@51080"],
580-
proxyauth=proxyauth,
581-
) as record_flows:
582-
with patch_environ({"socks_proxy": proxy_uri}):
583-
yield record_flows
569+
class SocksProxyClientTests(ProxyMixin, unittest.IsolatedAsyncioTestCase):
570+
proxy_mode = "socks5@51080"
584571

585572
async def test_socks_proxy(self):
586573
"""Client connects to server through a SOCKS5 proxy."""
587-
async with self.socks_proxy() as proxy:
574+
with patch_environ({"socks_proxy": "http://localhost:51080"}):
588575
async with serve(*args) as server:
589576
async with connect(get_uri(server)) as client:
590577
self.assertEqual(client.protocol.state.name, "OPEN")
591-
self.assertEqual(len(proxy.get_flows()), 1)
578+
self.assertNumFlows(1)
592579

593580
async def test_secure_socks_proxy(self):
594581
"""Client connects to server securely through a SOCKS5 proxy."""
595-
async with self.socks_proxy() as proxy:
582+
with patch_environ({"socks_proxy": "http://localhost:51080"}):
596583
async with serve(*args, ssl=SERVER_CONTEXT) as server:
597584
async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client:
598585
self.assertEqual(client.protocol.state.name, "OPEN")
599-
self.assertEqual(len(proxy.get_flows()), 1)
586+
self.assertNumFlows(1)
600587

601588
async def test_authenticated_socks_proxy(self):
602589
"""Client connects to server through an authenticated SOCKS5 proxy."""
603-
async with self.socks_proxy(auth=True) as proxy:
604-
async with serve(*args) as server:
605-
async with connect(get_uri(server)) as client:
606-
self.assertEqual(client.protocol.state.name, "OPEN")
607-
self.assertEqual(len(proxy.get_flows()), 1)
590+
try:
591+
self.proxy_options.update(proxyauth="hello:iloveyou")
592+
with patch_environ(
593+
{"socks_proxy": "http://hello:iloveyou@localhost:51080"}
594+
):
595+
async with serve(*args) as server:
596+
async with connect(get_uri(server)) as client:
597+
self.assertEqual(client.protocol.state.name, "OPEN")
598+
finally:
599+
self.proxy_options.update(proxyauth=None)
600+
self.assertNumFlows(1)
608601

609-
async def test_socks_proxy_connection_error(self):
610-
"""Client receives an error when connecting to the SOCKS5 proxy."""
602+
async def test_authenticated_socks_proxy_error(self):
603+
"""Client fails to authenticate to the SOCKS5 proxy."""
611604
from python_socks import ProxyError as SocksProxyError
612605

613-
async with self.socks_proxy(auth=True) as proxy:
614-
with self.assertRaises(ProxyError) as raised:
615-
async with connect(
616-
"ws://example.com/",
617-
proxy="socks5h://localhost:51080", # remove credentials
618-
):
619-
self.fail("did not raise")
606+
try:
607+
self.proxy_options.update(proxyauth="any")
608+
with patch_environ({"socks_proxy": "http://localhost:51080"}):
609+
with self.assertRaises(ProxyError) as raised:
610+
async with connect("ws://example.com/"):
611+
self.fail("did not raise")
612+
finally:
613+
self.proxy_options.update(proxyauth=None)
620614
self.assertEqual(
621615
str(raised.exception),
622616
"failed to connect to SOCKS proxy",
623617
)
624618
self.assertIsInstance(raised.exception.__cause__, SocksProxyError)
625-
self.assertEqual(len(proxy.get_flows()), 0)
619+
self.assertNumFlows(0)
626620

627-
async def test_socks_proxy_connection_fails(self):
621+
async def test_socks_proxy_connection_failure(self):
628622
"""Client fails to connect to the SOCKS5 proxy."""
629623
from python_socks import ProxyConnectionError as SocksProxyConnectionError
630624

631-
with self.assertRaises(OSError) as raised:
632-
async with connect(
633-
"ws://example.com/",
634-
proxy="socks5h://localhost:51080", # nothing at this address
635-
):
636-
self.fail("did not raise")
625+
with patch_environ({"socks_proxy": "http://localhost:61080"}): # bad port
626+
with self.assertRaises(OSError) as raised:
627+
async with connect("ws://example.com/"):
628+
self.fail("did not raise")
637629
# Don't test str(raised.exception) because we don't control it.
638630
self.assertIsInstance(raised.exception, SocksProxyConnectionError)
631+
self.assertNumFlows(0)
639632

640633
async def test_socks_proxy_connection_timeout(self):
641634
"""Client times out while connecting to the SOCKS5 proxy."""
642-
# Replace the proxy with a TCP server that does't respond.
635+
# Replace the proxy with a TCP server that doesn't respond.
643636
with socket.create_server(("localhost", 0)) as sock:
644637
host, port = sock.getsockname()
645-
with self.assertRaises(TimeoutError) as raised:
646-
async with connect(
647-
"ws://example.com/",
648-
proxy=f"socks5h://{host}:{port}/",
649-
open_timeout=MS,
650-
):
651-
self.fail("did not raise")
638+
with patch_environ({"socks_proxy": f"http://{host}:{port}"}):
639+
with self.assertRaises(TimeoutError) as raised:
640+
async with connect("ws://example.com/", open_timeout=MS):
641+
self.fail("did not raise")
652642
self.assertEqual(
653643
str(raised.exception),
654644
"timed out during handshake",
655645
)
646+
self.assertNumFlows(0)
656647

657-
async def test_explicit_proxy(self):
658-
"""Client connects to server through a proxy set explicitly."""
659-
async with async_proxy(mode=["socks5@51080"]) as proxy:
660-
async with serve(*args) as server:
661-
async with connect(
662-
get_uri(server),
663-
# Take this opportunity to test socks5 instead of socks5h.
664-
proxy="socks5://localhost:51080",
665-
) as client:
666-
self.assertEqual(client.protocol.state.name, "OPEN")
667-
self.assertEqual(len(proxy.get_flows()), 1)
648+
async def test_explicit_socks_proxy(self):
649+
"""Client connects to server through a SOCKS5 proxy set explicitly."""
650+
async with serve(*args) as server:
651+
async with connect(
652+
get_uri(server),
653+
# Take this opportunity to test socks5 instead of socks5h.
654+
proxy="socks5://localhost:51080",
655+
) as client:
656+
self.assertEqual(client.protocol.state.name, "OPEN")
657+
self.assertNumFlows(1)
668658

669659
async def test_ignore_proxy_with_existing_socket(self):
670660
"""Client connects using a pre-existing socket."""
671-
async with self.socks_proxy() as proxy:
661+
with patch_environ({"socks_proxy": "http://localhost:51080"}):
672662
async with serve(*args) as server:
673663
with socket.create_connection(get_host_port(server)) as sock:
674664
# Use a non-existing domain to ensure we connect to sock.
675665
async with connect("ws://invalid/", sock=sock) as client:
676666
self.assertEqual(client.protocol.state.name, "OPEN")
677-
self.assertEqual(len(proxy.get_flows()), 0)
678-
679-
async def test_unsupported_proxy(self):
680-
"""Client connects to server through an unsupported proxy."""
681-
with patch_environ({"ws_proxy": "other://localhost:51080"}):
682-
with self.assertRaises(InvalidProxy) as raised:
683-
async with connect("ws://example.com/"):
684-
self.fail("did not raise")
685-
self.assertEqual(
686-
str(raised.exception),
687-
"other://localhost:51080 isn't a valid proxy: scheme other isn't supported",
688-
)
689667

690668

691669
@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets")
@@ -724,10 +702,7 @@ def redirect(connection, request):
724702
"cannot follow cross-origin redirect to ws://other/ with a Unix socket",
725703
)
726704

727-
728-
@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets")
729-
class SecureUnixClientTests(unittest.IsolatedAsyncioTestCase):
730-
async def test_connection(self):
705+
async def test_secure_connection(self):
731706
"""Client connects to server securely over a Unix socket."""
732707
with temp_unix_socket_path() as path:
733708
async with unix_serve(handler, path, ssl=SERVER_CONTEXT):
@@ -769,6 +744,16 @@ async def test_secure_uri_without_ssl(self):
769744
"ssl=None is incompatible with a wss:// URI",
770745
)
771746

747+
async def test_unsupported_proxy(self):
748+
"""Client rejects unsupported proxy."""
749+
with self.assertRaises(InvalidProxy) as raised:
750+
async with connect("ws://example.com/", proxy="other://localhost:51080"):
751+
self.fail("did not raise")
752+
self.assertEqual(
753+
str(raised.exception),
754+
"other://localhost:51080 isn't a valid proxy: scheme other isn't supported",
755+
)
756+
772757
async def test_unix_without_path_or_sock(self):
773758
"""Unix client requires path when sock isn't provided."""
774759
with self.assertRaises(ValueError) as raised:

tests/proxy.py

Lines changed: 73 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
import asyncio
2-
import contextlib
32
import pathlib
43
import threading
54
import warnings
65

76

8-
warnings.filterwarnings("ignore", category=DeprecationWarning, module="mitmproxy")
9-
warnings.filterwarnings("ignore", category=DeprecationWarning, module="passlib")
10-
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pyasn1")
11-
127
try:
8+
# Ignore deprecation warnings raised by mitmproxy dependencies at import time.
9+
warnings.filterwarnings("ignore", category=DeprecationWarning, module="passlib")
10+
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pyasn1")
11+
1312
from mitmproxy.addons import core, next_layer, proxyauth, proxyserver, tlsconfig
1413
from mitmproxy.master import Master
1514
from mitmproxy.options import Options
@@ -18,72 +17,87 @@
1817

1918

2019
class RecordFlows:
21-
def __init__(self):
22-
self.ready = asyncio.get_running_loop().create_future()
20+
def __init__(self, on_running):
21+
self.running = on_running
2322
self.flows = []
2423

25-
def running(self):
26-
self.ready.set_result(None)
27-
2824
def websocket_start(self, flow):
2925
self.flows.append(flow)
3026

3127
def get_flows(self):
3228
flows, self.flows[:] = self.flows[:], []
3329
return flows
3430

31+
def reset_flows(self):
32+
self.flows = []
33+
34+
35+
class ProxyMixin:
36+
"""
37+
Run mitmproxy in a background thread.
38+
39+
While it's uncommon to run two event loops in two threads, tests for the
40+
asyncio implementation rely on this class too because it starts an event
41+
loop for mitm proxy once, then a new event loop for each test.
42+
"""
43+
44+
proxy_mode = None
45+
46+
@classmethod
47+
async def run_proxy(cls):
48+
cls.proxy_loop = loop = asyncio.get_event_loop()
49+
cls.proxy_stop = stop = loop.create_future()
50+
51+
cls.proxy_options = options = Options(mode=[cls.proxy_mode])
52+
cls.proxy_master = master = Master(options)
53+
master.addons.add(
54+
core.Core(),
55+
proxyauth.ProxyAuth(),
56+
proxyserver.Proxyserver(),
57+
next_layer.NextLayer(),
58+
tlsconfig.TlsConfig(),
59+
RecordFlows(on_running=cls.proxy_ready.set),
60+
)
61+
options.update(
62+
# Use test certificate for TLS between client and proxy.
63+
certs=[str(pathlib.Path(__file__).with_name("test_localhost.pem"))],
64+
# Disable TLS verification between proxy and upstream.
65+
ssl_insecure=True,
66+
)
67+
68+
task = loop.create_task(cls.proxy_master.run())
69+
await stop
3570

36-
@contextlib.asynccontextmanager
37-
async def async_proxy(mode, **config):
38-
options = Options(mode=mode)
39-
master = Master(options)
40-
record_flows = RecordFlows()
41-
master.addons.add(
42-
core.Core(),
43-
proxyauth.ProxyAuth(),
44-
proxyserver.Proxyserver(),
45-
next_layer.NextLayer(),
46-
tlsconfig.TlsConfig(),
47-
record_flows,
48-
)
49-
config.update(
50-
# Use our test certificate for TLS between client and proxy
51-
# and disable TLS verification between proxy and upstream.
52-
certs=[str(pathlib.Path(__file__).with_name("test_localhost.pem"))],
53-
ssl_insecure=True,
54-
)
55-
options.update(**config)
56-
57-
asyncio.create_task(master.run())
58-
try:
59-
await record_flows.ready
60-
yield record_flows
61-
finally:
6271
for server in master.addons.get("proxyserver").servers:
6372
await server.stop()
6473
master.shutdown()
74+
await task
75+
76+
@classmethod
77+
def setUpClass(cls):
78+
super().setUpClass()
79+
80+
# Ignore deprecation warnings raised by mitmproxy at run time.
81+
warnings.filterwarnings(
82+
"ignore", category=DeprecationWarning, module="mitmproxy"
83+
)
84+
85+
cls.proxy_ready = threading.Event()
86+
cls.proxy_thread = threading.Thread(target=asyncio.run, args=(cls.run_proxy(),))
87+
cls.proxy_thread.start()
88+
cls.proxy_ready.wait()
89+
90+
def assertNumFlows(self, num_flows):
91+
record_flows = self.proxy_master.addons.get("recordflows")
92+
self.assertEqual(len(record_flows.get_flows()), num_flows)
6593

94+
def tearDown(self):
95+
record_flows = self.proxy_master.addons.get("recordflows")
96+
record_flows.reset_flows()
97+
super().tearDown()
6698

67-
@contextlib.contextmanager
68-
def sync_proxy(mode, **config):
69-
loop = None
70-
test_done = None
71-
proxy_ready = threading.Event()
72-
record_flows = None
73-
74-
async def proxy_coroutine():
75-
nonlocal loop, test_done, proxy_ready, record_flows
76-
loop = asyncio.get_running_loop()
77-
test_done = loop.create_future()
78-
async with async_proxy(mode, **config) as record_flows:
79-
proxy_ready.set()
80-
await test_done
81-
82-
proxy_thread = threading.Thread(target=asyncio.run, args=(proxy_coroutine(),))
83-
proxy_thread.start()
84-
try:
85-
proxy_ready.wait()
86-
yield record_flows
87-
finally:
88-
loop.call_soon_threadsafe(test_done.set_result, None)
89-
proxy_thread.join()
99+
@classmethod
100+
def tearDownClass(cls):
101+
cls.proxy_loop.call_soon_threadsafe(cls.proxy_stop.set_result, None)
102+
cls.proxy_thread.join()
103+
super().tearDownClass()

0 commit comments

Comments
 (0)