diff --git a/proxy/core/connection/server.py b/proxy/core/connection/server.py index 31233049f7..c7bd4faa05 100644 --- a/proxy/core/connection/server.py +++ b/proxy/core/connection/server.py @@ -25,6 +25,17 @@ def __init__(self, host: str, port: int) -> None: self._conn: Optional[TcpOrTlsSocket] = None self.addr: HostPort = (host, port) self.closed = True + self._proxy = False + + def is_secure(self) -> bool: + return isinstance(self._conn, ssl.SSLSocket) + + def mark_as_proxy(self) -> None: + self._proxy = True + + @property + def is_proxy(self) -> bool: + return self._proxy @property def connection(self) -> TcpOrTlsSocket: diff --git a/proxy/http/proxy/plugin.py b/proxy/http/proxy/plugin.py index a9c10e88f3..c1dfd34fa4 100644 --- a/proxy/http/proxy/plugin.py +++ b/proxy/http/proxy/plugin.py @@ -17,6 +17,7 @@ from ...core.event import EventQueue from ..descriptors import DescriptorsHandlerMixin from ...common.utils import tls_interception_enabled +from ...core.connection import TcpServerConnection if TYPE_CHECKING: # pragma: no cover @@ -69,6 +70,12 @@ def resolve_dns(self, host: str, port: int) -> Tuple[Optional[str], Optional['Ho """ return None, None + def upstream_connection( + self, + request: HttpParser, + ) -> Optional[TcpServerConnection]: + return None + # No longer abstract since 2.4.0 # # @abstractmethod diff --git a/proxy/http/proxy/server.py b/proxy/http/proxy/server.py index 70d3369ec4..fcaf3e7306 100644 --- a/proxy/http/proxy/server.py +++ b/proxy/http/proxy/server.py @@ -21,7 +21,7 @@ import logging import threading import subprocess -from typing import Any, Dict, List, Union, Optional, cast +from typing import Any, Dict, List, Union, Optional from .plugin import HttpProxyBasePlugin from ..parser import HttpParser, httpParserTypes, httpParserStates @@ -487,6 +487,14 @@ def on_request_complete(self) -> Union[socket.socket, bool]: # Connect to upstream if do_connect: self.connect_upstream() + else: + # If a plugin asked us not to connect to upstream + # check if any plugin is managing an upstream connection. + for plugin in self.plugins.values(): + up = plugin.upstream_connection(self.request) + if up is not None: + self.upstream = up + break # Invoke plugin.handle_client_request for plugin in self.plugins.values(): @@ -756,13 +764,28 @@ def intercept(self) -> Union[socket.socket, bool]: return self.client.connection def wrap_server(self) -> bool: - assert self.upstream is not None - assert isinstance(self.upstream.connection, socket.socket) + assert self.upstream is not None and self.request.host + return self._wrap_server( + self.upstream, + host=self.request.host, + ca_file=self.flags.ca_file, + ) + + @staticmethod + def _wrap_server( + upstream: TcpServerConnection, + host: bytes, + ca_file: Optional[str] = None, + ) -> bool: + assert isinstance(upstream.connection, socket.socket) do_close = False + if upstream.is_proxy: + # Don't wrap upstream if its part of proxy chain + return do_close try: - self.upstream.wrap( - text_(self.request.host), - self.flags.ca_file, + upstream.wrap( + text_(host), + ca_file, as_non_blocking=True, ) except ssl.SSLCertVerificationError: # Server raised certificate verification error @@ -770,40 +793,68 @@ def wrap_server(self) -> bool: # we will cache such upstream hosts and avoid intercepting them for future # requests. logger.warning( - 'ssl.SSLCertVerificationError: ' + - 'Server raised cert verification error for upstream: {0}'.format( - self.upstream.addr[0], + 'ssl.SSLCertVerificationError: ' + + 'Server raised cert verification error for upstream: {0}'.format( + upstream.addr[0], ), ) do_close = True except ssl.SSLError as e: if e.reason == 'SSLV3_ALERT_HANDSHAKE_FAILURE': logger.warning( - '{0}: '.format(e.reason) + - 'Server raised handshake alert failure for upstream: {0}'.format( - self.upstream.addr[0], + '{0}: '.format(e.reason) + + 'Server raised handshake alert failure for upstream: {0}'.format( + upstream.addr[0], ), ) else: logger.exception( 'SSLError when wrapping client for upstream: {0}'.format( - self.upstream.addr[0], - ), exc_info=e, + upstream.addr[0], + ), + exc_info=e, ) do_close = True if not do_close: - assert isinstance(self.upstream.connection, ssl.SSLSocket) + assert isinstance(upstream.connection, ssl.SSLSocket) return do_close def wrap_client(self) -> bool: assert self.upstream is not None and self.flags.ca_signing_key_file is not None - assert isinstance(self.upstream.connection, ssl.SSLSocket) + certificate: Optional[Dict[str, Any]] = None + if isinstance(self.upstream.connection, ssl.SSLSocket): + certificate = self.upstream.connection.getpeercert() + else: + assert self.upstream.is_proxy and self.request.host and self.request.port + if self.flags.enable_conn_pool: + assert self.upstream_conn_pool + with self.lock: + _, upstream = self.upstream_conn_pool.acquire( + (text_(self.request.host), self.request.port), + ) + else: + _, upstream = True, TcpServerConnection( + text_(self.request.host), + self.request.port, + ) + # Connect with overridden upstream IP and source address + # if any of the plugin returned a non-null value. + upstream.connect() + upstream.connection.setblocking(False) + do_close = self._wrap_server( + upstream, + host=self.request.host, + ca_file=self.flags.ca_file, + ) + if do_close: + return do_close + assert isinstance(upstream.connection, ssl.SSLSocket) + certificate = upstream.connection.getpeercert() + assert certificate do_close = False try: # TODO: Perform async certificate generation - generated_cert = self.generate_upstream_certificate( - cast(Dict[str, Any], self.upstream.connection.getpeercert()), - ) + generated_cert = self.generate_upstream_certificate(certificate) self.client.wrap(self.flags.ca_signing_key_file, generated_cert) except subprocess.TimeoutExpired as e: # Popen communicate timeout logger.exception( diff --git a/proxy/plugin/proxy_pool.py b/proxy/plugin/proxy_pool.py index c244adeda1..eecac398a2 100644 --- a/proxy/plugin/proxy_pool.py +++ b/proxy/plugin/proxy_pool.py @@ -14,6 +14,7 @@ import ipaddress from typing import Any, Dict, List, Optional +from proxy.core.connection import TcpServerConnection from ..http import Url, httpHeaders, httpMethods from ..core.base import TcpUpstreamConnectionHandler from ..http.proxy import HttpProxyBasePlugin @@ -78,6 +79,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: def handle_upstream_data(self, raw: memoryview) -> None: self.client.queue(raw) + def upstream_connection(self, request: HttpParser) -> Optional[TcpServerConnection]: + return self.upstream + def before_upstream_connection( self, request: HttpParser, ) -> Optional[HttpParser]: @@ -107,6 +111,7 @@ def before_upstream_connection( logger.debug('Using endpoint: {0}:{1}'.format(*endpoint_tuple)) self.initialize_upstream(*endpoint_tuple) assert self.upstream + self.upstream.mark_as_proxy() try: self.upstream.connect() except TimeoutError: