|
21 | 21 | import logging
|
22 | 22 | import threading
|
23 | 23 | import subprocess
|
24 |
| -from typing import Any, Dict, List, Union, Optional, cast |
| 24 | +from typing import Any, Dict, List, Union, Optional |
25 | 25 |
|
26 | 26 | from .plugin import HttpProxyBasePlugin
|
27 | 27 | from ..parser import HttpParser, httpParserTypes, httpParserStates
|
@@ -487,6 +487,14 @@ def on_request_complete(self) -> Union[socket.socket, bool]:
|
487 | 487 | # Connect to upstream
|
488 | 488 | if do_connect:
|
489 | 489 | self.connect_upstream()
|
| 490 | + else: |
| 491 | + # If a plugin asked us not to connect to upstream |
| 492 | + # check if any plugin is managing an upstream connection. |
| 493 | + for plugin in self.plugins.values(): |
| 494 | + up = plugin.upstream_connection(self.request) |
| 495 | + if up is not None: |
| 496 | + self.upstream = up |
| 497 | + break |
490 | 498 |
|
491 | 499 | # Invoke plugin.handle_client_request
|
492 | 500 | for plugin in self.plugins.values():
|
@@ -756,54 +764,97 @@ def intercept(self) -> Union[socket.socket, bool]:
|
756 | 764 | return self.client.connection
|
757 | 765 |
|
758 | 766 | def wrap_server(self) -> bool:
|
759 |
| - assert self.upstream is not None |
760 |
| - assert isinstance(self.upstream.connection, socket.socket) |
| 767 | + assert self.upstream is not None and self.request.host |
| 768 | + return self._wrap_server( |
| 769 | + self.upstream, |
| 770 | + host=self.request.host, |
| 771 | + ca_file=self.flags.ca_file, |
| 772 | + ) |
| 773 | + |
| 774 | + @staticmethod |
| 775 | + def _wrap_server( |
| 776 | + upstream: TcpServerConnection, |
| 777 | + host: bytes, |
| 778 | + ca_file: Optional[str] = None, |
| 779 | + ) -> bool: |
| 780 | + assert isinstance(upstream.connection, socket.socket) |
761 | 781 | do_close = False
|
| 782 | + if upstream.is_proxy: |
| 783 | + # Don't wrap upstream if its part of proxy chain |
| 784 | + return do_close |
762 | 785 | try:
|
763 |
| - self.upstream.wrap( |
764 |
| - text_(self.request.host), |
765 |
| - self.flags.ca_file, |
| 786 | + upstream.wrap( |
| 787 | + text_(host), |
| 788 | + ca_file, |
766 | 789 | as_non_blocking=True,
|
767 | 790 | )
|
768 | 791 | except ssl.SSLCertVerificationError: # Server raised certificate verification error
|
769 | 792 | # When --disable-interception-on-ssl-cert-verification-error flag is on,
|
770 | 793 | # we will cache such upstream hosts and avoid intercepting them for future
|
771 | 794 | # requests.
|
772 | 795 | logger.warning(
|
773 |
| - 'ssl.SSLCertVerificationError: ' + |
774 |
| - 'Server raised cert verification error for upstream: {0}'.format( |
775 |
| - self.upstream.addr[0], |
| 796 | + "ssl.SSLCertVerificationError: " |
| 797 | + + "Server raised cert verification error for upstream: {0}".format( |
| 798 | + upstream.addr[0], |
776 | 799 | ),
|
777 | 800 | )
|
778 | 801 | do_close = True
|
779 | 802 | except ssl.SSLError as e:
|
780 | 803 | if e.reason == 'SSLV3_ALERT_HANDSHAKE_FAILURE':
|
781 | 804 | logger.warning(
|
782 |
| - '{0}: '.format(e.reason) + |
783 |
| - 'Server raised handshake alert failure for upstream: {0}'.format( |
784 |
| - self.upstream.addr[0], |
| 805 | + "{0}: ".format(e.reason) |
| 806 | + + "Server raised handshake alert failure for upstream: {0}".format( |
| 807 | + upstream.addr[0], |
785 | 808 | ),
|
786 | 809 | )
|
787 | 810 | else:
|
788 | 811 | logger.exception(
|
789 |
| - 'SSLError when wrapping client for upstream: {0}'.format( |
790 |
| - self.upstream.addr[0], |
791 |
| - ), exc_info=e, |
| 812 | + "SSLError when wrapping client for upstream: {0}".format( |
| 813 | + upstream.addr[0], |
| 814 | + ), |
| 815 | + exc_info=e, |
792 | 816 | )
|
793 | 817 | do_close = True
|
794 | 818 | if not do_close:
|
795 |
| - assert isinstance(self.upstream.connection, ssl.SSLSocket) |
| 819 | + assert isinstance(upstream.connection, ssl.SSLSocket) |
796 | 820 | return do_close
|
797 | 821 |
|
798 | 822 | def wrap_client(self) -> bool:
|
799 | 823 | assert self.upstream is not None and self.flags.ca_signing_key_file is not None
|
800 |
| - assert isinstance(self.upstream.connection, ssl.SSLSocket) |
| 824 | + certificate: Optional[Dict[str, Any]] = None |
| 825 | + if isinstance(self.upstream.connection, ssl.SSLSocket): |
| 826 | + certificate = self.upstream.connection.getpeercert() |
| 827 | + else: |
| 828 | + assert self.upstream.is_proxy and self.request.host and self.request.port |
| 829 | + if self.flags.enable_conn_pool: |
| 830 | + assert self.upstream_conn_pool |
| 831 | + with self.lock: |
| 832 | + _, upstream = self.upstream_conn_pool.acquire( |
| 833 | + (text_(self.request.host), self.request.port), |
| 834 | + ) |
| 835 | + else: |
| 836 | + _, upstream = True, TcpServerConnection( |
| 837 | + text_(self.request.host), |
| 838 | + self.request.port, |
| 839 | + ) |
| 840 | + # Connect with overridden upstream IP and source address |
| 841 | + # if any of the plugin returned a non-null value. |
| 842 | + upstream.connect() |
| 843 | + upstream.connection.setblocking(False) |
| 844 | + do_close = self._wrap_server( |
| 845 | + upstream, |
| 846 | + host=self.request.host, |
| 847 | + ca_file=self.flags.ca_file, |
| 848 | + ) |
| 849 | + if do_close: |
| 850 | + return do_close |
| 851 | + assert isinstance(upstream.connection, ssl.SSLSocket) |
| 852 | + certificate = upstream.connection.getpeercert() |
| 853 | + assert certificate |
801 | 854 | do_close = False
|
802 | 855 | try:
|
803 | 856 | # TODO: Perform async certificate generation
|
804 |
| - generated_cert = self.generate_upstream_certificate( |
805 |
| - cast(Dict[str, Any], self.upstream.connection.getpeercert()), |
806 |
| - ) |
| 857 | + generated_cert = self.generate_upstream_certificate(certificate) |
807 | 858 | self.client.wrap(self.flags.ca_signing_key_file, generated_cert)
|
808 | 859 | except subprocess.TimeoutExpired as e: # Popen communicate timeout
|
809 | 860 | logger.exception(
|
|
0 commit comments