Skip to content

Commit d578ae7

Browse files
committed
Enable TLS interception support even when proxy pool plugin is enabled
1 parent 7026c13 commit d578ae7

File tree

4 files changed

+95
-21
lines changed

4 files changed

+95
-21
lines changed

proxy/core/connection/server.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,17 @@ def __init__(self, host: str, port: int) -> None:
2525
self._conn: Optional[TcpOrTlsSocket] = None
2626
self.addr: HostPort = (host, port)
2727
self.closed = True
28+
self._proxy = False
29+
30+
def is_secure(self) -> bool:
31+
return isinstance(self._conn, ssl.SSLSocket)
32+
33+
def mark_as_proxy(self) -> None:
34+
self._proxy = True
35+
36+
@property
37+
def is_proxy(self) -> bool:
38+
return self._proxy
2839

2940
@property
3041
def connection(self) -> TcpOrTlsSocket:

proxy/http/proxy/plugin.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from ...core.event import EventQueue
1818
from ..descriptors import DescriptorsHandlerMixin
1919
from ...common.utils import tls_interception_enabled
20-
20+
from ...core.connection import TcpServerConnection
2121

2222
if TYPE_CHECKING: # pragma: no cover
2323
from ...common.types import HostPort
@@ -69,6 +69,12 @@ def resolve_dns(self, host: str, port: int) -> Tuple[Optional[str], Optional['Ho
6969
"""
7070
return None, None
7171

72+
def upstream_connection(
73+
self,
74+
request: HttpParser,
75+
) -> Optional[TcpServerConnection]:
76+
return None
77+
7278
# No longer abstract since 2.4.0
7379
#
7480
# @abstractmethod

proxy/http/proxy/server.py

Lines changed: 71 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import logging
2222
import threading
2323
import subprocess
24-
from typing import Any, Dict, List, Union, Optional, cast
24+
from typing import Any, Dict, List, Union, Optional
2525

2626
from .plugin import HttpProxyBasePlugin
2727
from ..parser import HttpParser, httpParserTypes, httpParserStates
@@ -487,6 +487,14 @@ def on_request_complete(self) -> Union[socket.socket, bool]:
487487
# Connect to upstream
488488
if do_connect:
489489
self.connect_upstream()
490+
else:
491+
# If a plugin asked us not to connect to upstream
492+
# check if any plugin is managing an upstream connection.
493+
for plugin in self.plugins.values():
494+
up = plugin.upstream_connection(self.request)
495+
if up is not None:
496+
self.upstream = up
497+
break
490498

491499
# Invoke plugin.handle_client_request
492500
for plugin in self.plugins.values():
@@ -756,54 +764,97 @@ def intercept(self) -> Union[socket.socket, bool]:
756764
return self.client.connection
757765

758766
def wrap_server(self) -> bool:
759-
assert self.upstream is not None
760-
assert isinstance(self.upstream.connection, socket.socket)
767+
assert self.upstream is not None and self.request.host
768+
return self._wrap_server(
769+
self.upstream,
770+
host=self.request.host,
771+
ca_file=self.flags.ca_file,
772+
)
773+
774+
@staticmethod
775+
def _wrap_server(
776+
upstream: TcpServerConnection,
777+
host: bytes,
778+
ca_file: Optional[str] = None,
779+
) -> bool:
780+
assert isinstance(upstream.connection, socket.socket)
761781
do_close = False
782+
if upstream.is_proxy:
783+
# Don't wrap upstream if its part of proxy chain
784+
return do_close
762785
try:
763-
self.upstream.wrap(
764-
text_(self.request.host),
765-
self.flags.ca_file,
786+
upstream.wrap(
787+
text_(host),
788+
ca_file,
766789
as_non_blocking=True,
767790
)
768791
except ssl.SSLCertVerificationError: # Server raised certificate verification error
769792
# When --disable-interception-on-ssl-cert-verification-error flag is on,
770793
# we will cache such upstream hosts and avoid intercepting them for future
771794
# requests.
772795
logger.warning(
773-
'ssl.SSLCertVerificationError: ' +
774-
'Server raised cert verification error for upstream: {0}'.format(
775-
self.upstream.addr[0],
796+
"ssl.SSLCertVerificationError: "
797+
+ "Server raised cert verification error for upstream: {0}".format(
798+
upstream.addr[0],
776799
),
777800
)
778801
do_close = True
779802
except ssl.SSLError as e:
780803
if e.reason == 'SSLV3_ALERT_HANDSHAKE_FAILURE':
781804
logger.warning(
782-
'{0}: '.format(e.reason) +
783-
'Server raised handshake alert failure for upstream: {0}'.format(
784-
self.upstream.addr[0],
805+
"{0}: ".format(e.reason)
806+
+ "Server raised handshake alert failure for upstream: {0}".format(
807+
upstream.addr[0],
785808
),
786809
)
787810
else:
788811
logger.exception(
789-
'SSLError when wrapping client for upstream: {0}'.format(
790-
self.upstream.addr[0],
791-
), exc_info=e,
812+
"SSLError when wrapping client for upstream: {0}".format(
813+
upstream.addr[0],
814+
),
815+
exc_info=e,
792816
)
793817
do_close = True
794818
if not do_close:
795-
assert isinstance(self.upstream.connection, ssl.SSLSocket)
819+
assert isinstance(upstream.connection, ssl.SSLSocket)
796820
return do_close
797821

798822
def wrap_client(self) -> bool:
799823
assert self.upstream is not None and self.flags.ca_signing_key_file is not None
800-
assert isinstance(self.upstream.connection, ssl.SSLSocket)
824+
certificate: Optional[Dict[str, Any]] = None
825+
if isinstance(self.upstream.connection, ssl.SSLSocket):
826+
certificate = self.upstream.connection.getpeercert()
827+
else:
828+
assert self.upstream.is_proxy and self.request.host and self.request.port
829+
if self.flags.enable_conn_pool:
830+
assert self.upstream_conn_pool
831+
with self.lock:
832+
_, upstream = self.upstream_conn_pool.acquire(
833+
(text_(self.request.host), self.request.port),
834+
)
835+
else:
836+
_, upstream = True, TcpServerConnection(
837+
text_(self.request.host),
838+
self.request.port,
839+
)
840+
# Connect with overridden upstream IP and source address
841+
# if any of the plugin returned a non-null value.
842+
upstream.connect()
843+
upstream.connection.setblocking(False)
844+
do_close = self._wrap_server(
845+
upstream,
846+
host=self.request.host,
847+
ca_file=self.flags.ca_file,
848+
)
849+
if do_close:
850+
return do_close
851+
assert isinstance(upstream.connection, ssl.SSLSocket)
852+
certificate = upstream.connection.getpeercert()
853+
assert certificate
801854
do_close = False
802855
try:
803856
# TODO: Perform async certificate generation
804-
generated_cert = self.generate_upstream_certificate(
805-
cast(Dict[str, Any], self.upstream.connection.getpeercert()),
806-
)
857+
generated_cert = self.generate_upstream_certificate(certificate)
807858
self.client.wrap(self.flags.ca_signing_key_file, generated_cert)
808859
except subprocess.TimeoutExpired as e: # Popen communicate timeout
809860
logger.exception(

proxy/plugin/proxy_pool.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import ipaddress
1515
from typing import Any, Dict, List, Optional
1616

17+
from proxy.core.connection import TcpServerConnection
18+
1719
from ..http import Url, httpHeaders, httpMethods
1820
from ..core.base import TcpUpstreamConnectionHandler
1921
from ..http.proxy import HttpProxyBasePlugin
@@ -78,6 +80,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
7880
def handle_upstream_data(self, raw: memoryview) -> None:
7981
self.client.queue(raw)
8082

83+
def upstream_connection(self, request: HttpParser) -> Optional[TcpServerConnection]:
84+
return self.upstream
85+
8186
def before_upstream_connection(
8287
self, request: HttpParser,
8388
) -> Optional[HttpParser]:
@@ -107,6 +112,7 @@ def before_upstream_connection(
107112
logger.debug('Using endpoint: {0}:{1}'.format(*endpoint_tuple))
108113
self.initialize_upstream(*endpoint_tuple)
109114
assert self.upstream
115+
self.upstream.mark_as_proxy()
110116
try:
111117
self.upstream.connect()
112118
except TimeoutError:

0 commit comments

Comments
 (0)