1from __future__ import annotations 
    2 
    3import base64 
    4import logging 
    5import ssl 
    6import typing 
    7 
    8from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend 
    9from .._exceptions import ProxyError 
    10from .._models import ( 
    11    URL, 
    12    Origin, 
    13    Request, 
    14    Response, 
    15    enforce_bytes, 
    16    enforce_headers, 
    17    enforce_url, 
    18) 
    19from .._ssl import default_ssl_context 
    20from .._synchronization import AsyncLock 
    21from .._trace import Trace 
    22from .connection import AsyncHTTPConnection 
    23from .connection_pool import AsyncConnectionPool 
    24from .http11 import AsyncHTTP11Connection 
    25from .interfaces import AsyncConnectionInterface 
    26 
# A header name or value may be supplied as either bytes or str.
ByteOrStr = typing.Union[bytes, str]
# Headers supplied as a sequence of (name, value) pairs.
HeadersAsSequence = typing.Sequence[typing.Tuple[ByteOrStr, ByteOrStr]]
# Headers supplied as a name -> value mapping.
HeadersAsMapping = typing.Mapping[ByteOrStr, ByteOrStr]


# Module logger; it is handed to `Trace` for the "start_tls" event that is
# emitted while upgrading a proxy tunnel to TLS.
logger = logging.getLogger("httpcore.proxy")
    33 
    34 
    35def merge_headers( 
    36    default_headers: typing.Sequence[tuple[bytes, bytes]] | None = None, 
    37    override_headers: typing.Sequence[tuple[bytes, bytes]] | None = None, 
    38) -> list[tuple[bytes, bytes]]: 
    39    """ 
    40    Append default_headers and override_headers, de-duplicating if a key exists 
    41    in both cases. 
    42    """ 
    43    default_headers = [] if default_headers is None else list(default_headers) 
    44    override_headers = [] if override_headers is None else list(override_headers) 
    45    has_override = set(key.lower() for key, value in override_headers) 
    46    default_headers = [ 
    47        (key, value) 
    48        for key, value in default_headers 
    49        if key.lower() not in has_override 
    50    ] 
    51    return default_headers + override_headers 
    52 
    53 
    54class AsyncHTTPProxy(AsyncConnectionPool):  # pragma: nocover 
    55    """ 
    56    A connection pool that sends requests via an HTTP proxy. 
    57    """ 
    58 
    59    def __init__( 
    60        self, 
    61        proxy_url: URL | bytes | str, 
    62        proxy_auth: tuple[bytes | str, bytes | str] | None = None, 
    63        proxy_headers: HeadersAsMapping | HeadersAsSequence | None = None, 
    64        ssl_context: ssl.SSLContext | None = None, 
    65        proxy_ssl_context: ssl.SSLContext | None = None, 
    66        max_connections: int | None = 10, 
    67        max_keepalive_connections: int | None = None, 
    68        keepalive_expiry: float | None = None, 
    69        http1: bool = True, 
    70        http2: bool = False, 
    71        retries: int = 0, 
    72        local_address: str | None = None, 
    73        uds: str | None = None, 
    74        network_backend: AsyncNetworkBackend | None = None, 
    75        socket_options: typing.Iterable[SOCKET_OPTION] | None = None, 
    76    ) -> None: 
    77        """ 
    78        A connection pool for making HTTP requests. 
    79 
    80        Parameters: 
    81            proxy_url: The URL to use when connecting to the proxy server. 
    82                For example `"http://127.0.0.1:8080/"`. 
    83            proxy_auth: Any proxy authentication as a two-tuple of 
    84                (username, password). May be either bytes or ascii-only str. 
    85            proxy_headers: Any HTTP headers to use for the proxy requests. 
    86                For example `{"Proxy-Authorization": "Basic <username>:<password>"}`. 
    87            ssl_context: An SSL context to use for verifying connections. 
    88                If not specified, the default `httpcore.default_ssl_context()` 
    89                will be used. 
    90            proxy_ssl_context: The same as `ssl_context`, but for a proxy server rather than a remote origin. 
    91            max_connections: The maximum number of concurrent HTTP connections that 
    92                the pool should allow. Any attempt to send a request on a pool that 
    93                would exceed this amount will block until a connection is available. 
    94            max_keepalive_connections: The maximum number of idle HTTP connections 
    95                that will be maintained in the pool. 
    96            keepalive_expiry: The duration in seconds that an idle HTTP connection 
    97                may be maintained for before being expired from the pool. 
    98            http1: A boolean indicating if HTTP/1.1 requests should be supported 
    99                by the connection pool. Defaults to True. 
    100            http2: A boolean indicating if HTTP/2 requests should be supported by 
    101                the connection pool. Defaults to False. 
    102            retries: The maximum number of retries when trying to establish 
    103                a connection. 
    104            local_address: Local address to connect from. Can also be used to 
    105                connect using a particular address family. Using 
    106                `local_address="0.0.0.0"` will connect using an `AF_INET` address 
    107                (IPv4), while using `local_address="::"` will connect using an 
    108                `AF_INET6` address (IPv6). 
    109            uds: Path to a Unix Domain Socket to use instead of TCP sockets. 
    110            network_backend: A backend instance to use for handling network I/O. 
    111        """ 
    112        super().__init__( 
    113            ssl_context=ssl_context, 
    114            max_connections=max_connections, 
    115            max_keepalive_connections=max_keepalive_connections, 
    116            keepalive_expiry=keepalive_expiry, 
    117            http1=http1, 
    118            http2=http2, 
    119            network_backend=network_backend, 
    120            retries=retries, 
    121            local_address=local_address, 
    122            uds=uds, 
    123            socket_options=socket_options, 
    124        ) 
    125 
    126        self._proxy_url = enforce_url(proxy_url, name="proxy_url") 
    127        if ( 
    128            self._proxy_url.scheme == b"http" and proxy_ssl_context is not None 
    129        ):  # pragma: no cover 
    130            raise RuntimeError( 
    131                "The `proxy_ssl_context` argument is not allowed for the http scheme" 
    132            ) 
    133 
    134        self._ssl_context = ssl_context 
    135        self._proxy_ssl_context = proxy_ssl_context 
    136        self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers") 
    137        if proxy_auth is not None: 
    138            username = enforce_bytes(proxy_auth[0], name="proxy_auth") 
    139            password = enforce_bytes(proxy_auth[1], name="proxy_auth") 
    140            userpass = username + b":" + password 
    141            authorization = b"Basic " + base64.b64encode(userpass) 
    142            self._proxy_headers = [ 
    143                (b"Proxy-Authorization", authorization) 
    144            ] + self._proxy_headers 
    145 
    146    def create_connection(self, origin: Origin) -> AsyncConnectionInterface: 
    147        if origin.scheme == b"http": 
    148            return AsyncForwardHTTPConnection( 
    149                proxy_origin=self._proxy_url.origin, 
    150                proxy_headers=self._proxy_headers, 
    151                remote_origin=origin, 
    152                keepalive_expiry=self._keepalive_expiry, 
    153                network_backend=self._network_backend, 
    154                proxy_ssl_context=self._proxy_ssl_context, 
    155            ) 
    156        return AsyncTunnelHTTPConnection( 
    157            proxy_origin=self._proxy_url.origin, 
    158            proxy_headers=self._proxy_headers, 
    159            remote_origin=origin, 
    160            ssl_context=self._ssl_context, 
    161            proxy_ssl_context=self._proxy_ssl_context, 
    162            keepalive_expiry=self._keepalive_expiry, 
    163            http1=self._http1, 
    164            http2=self._http2, 
    165            network_backend=self._network_backend, 
    166        ) 
    167 
    168 
    169class AsyncForwardHTTPConnection(AsyncConnectionInterface): 
    170    def __init__( 
    171        self, 
    172        proxy_origin: Origin, 
    173        remote_origin: Origin, 
    174        proxy_headers: HeadersAsMapping | HeadersAsSequence | None = None, 
    175        keepalive_expiry: float | None = None, 
    176        network_backend: AsyncNetworkBackend | None = None, 
    177        socket_options: typing.Iterable[SOCKET_OPTION] | None = None, 
    178        proxy_ssl_context: ssl.SSLContext | None = None, 
    179    ) -> None: 
    180        self._connection = AsyncHTTPConnection( 
    181            origin=proxy_origin, 
    182            keepalive_expiry=keepalive_expiry, 
    183            network_backend=network_backend, 
    184            socket_options=socket_options, 
    185            ssl_context=proxy_ssl_context, 
    186        ) 
    187        self._proxy_origin = proxy_origin 
    188        self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers") 
    189        self._remote_origin = remote_origin 
    190 
    191    async def handle_async_request(self, request: Request) -> Response: 
    192        headers = merge_headers(self._proxy_headers, request.headers) 
    193        url = URL( 
    194            scheme=self._proxy_origin.scheme, 
    195            host=self._proxy_origin.host, 
    196            port=self._proxy_origin.port, 
    197            target=bytes(request.url), 
    198        ) 
    199        proxy_request = Request( 
    200            method=request.method, 
    201            url=url, 
    202            headers=headers, 
    203            content=request.stream, 
    204            extensions=request.extensions, 
    205        ) 
    206        return await self._connection.handle_async_request(proxy_request) 
    207 
    208    def can_handle_request(self, origin: Origin) -> bool: 
    209        return origin == self._remote_origin 
    210 
    211    async def aclose(self) -> None: 
    212        await self._connection.aclose() 
    213 
    214    def info(self) -> str: 
    215        return self._connection.info() 
    216 
    217    def is_available(self) -> bool: 
    218        return self._connection.is_available() 
    219 
    220    def has_expired(self) -> bool: 
    221        return self._connection.has_expired() 
    222 
    223    def is_idle(self) -> bool: 
    224        return self._connection.is_idle() 
    225 
    226    def is_closed(self) -> bool: 
    227        return self._connection.is_closed() 
    228 
    229    def __repr__(self) -> str: 
    230        return f"<{self.__class__.__name__} [{self.info()}]>" 
    231 
    232 
    233class AsyncTunnelHTTPConnection(AsyncConnectionInterface): 
    234    def __init__( 
    235        self, 
    236        proxy_origin: Origin, 
    237        remote_origin: Origin, 
    238        ssl_context: ssl.SSLContext | None = None, 
    239        proxy_ssl_context: ssl.SSLContext | None = None, 
    240        proxy_headers: typing.Sequence[tuple[bytes, bytes]] | None = None, 
    241        keepalive_expiry: float | None = None, 
    242        http1: bool = True, 
    243        http2: bool = False, 
    244        network_backend: AsyncNetworkBackend | None = None, 
    245        socket_options: typing.Iterable[SOCKET_OPTION] | None = None, 
    246    ) -> None: 
    247        self._connection: AsyncConnectionInterface = AsyncHTTPConnection( 
    248            origin=proxy_origin, 
    249            keepalive_expiry=keepalive_expiry, 
    250            network_backend=network_backend, 
    251            socket_options=socket_options, 
    252            ssl_context=proxy_ssl_context, 
    253        ) 
    254        self._proxy_origin = proxy_origin 
    255        self._remote_origin = remote_origin 
    256        self._ssl_context = ssl_context 
    257        self._proxy_ssl_context = proxy_ssl_context 
    258        self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers") 
    259        self._keepalive_expiry = keepalive_expiry 
    260        self._http1 = http1 
    261        self._http2 = http2 
    262        self._connect_lock = AsyncLock() 
    263        self._connected = False 
    264 
    265    async def handle_async_request(self, request: Request) -> Response: 
    266        timeouts = request.extensions.get("timeout", {}) 
    267        timeout = timeouts.get("connect", None) 
    268 
    269        async with self._connect_lock: 
    270            if not self._connected: 
    271                target = b"%b:%d" % (self._remote_origin.host, self._remote_origin.port) 
    272 
    273                connect_url = URL( 
    274                    scheme=self._proxy_origin.scheme, 
    275                    host=self._proxy_origin.host, 
    276                    port=self._proxy_origin.port, 
    277                    target=target, 
    278                ) 
    279                connect_headers = merge_headers( 
    280                    [(b"Host", target), (b"Accept", b"*/*")], self._proxy_headers 
    281                ) 
    282                connect_request = Request( 
    283                    method=b"CONNECT", 
    284                    url=connect_url, 
    285                    headers=connect_headers, 
    286                    extensions=request.extensions, 
    287                ) 
    288                connect_response = await self._connection.handle_async_request( 
    289                    connect_request 
    290                ) 
    291 
    292                if connect_response.status < 200 or connect_response.status > 299: 
    293                    reason_bytes = connect_response.extensions.get("reason_phrase", b"") 
    294                    reason_str = reason_bytes.decode("ascii", errors="ignore") 
    295                    msg = "%d %s" % (connect_response.status, reason_str) 
    296                    await self._connection.aclose() 
    297                    raise ProxyError(msg) 
    298 
    299                stream = connect_response.extensions["network_stream"] 
    300 
    301                # Upgrade the stream to SSL 
    302                ssl_context = ( 
    303                    default_ssl_context() 
    304                    if self._ssl_context is None 
    305                    else self._ssl_context 
    306                ) 
    307                alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"] 
    308                ssl_context.set_alpn_protocols(alpn_protocols) 
    309 
    310                kwargs = { 
    311                    "ssl_context": ssl_context, 
    312                    "server_hostname": self._remote_origin.host.decode("ascii"), 
    313                    "timeout": timeout, 
    314                } 
    315                async with Trace("start_tls", logger, request, kwargs) as trace: 
    316                    stream = await stream.start_tls(**kwargs) 
    317                    trace.return_value = stream 
    318 
    319                # Determine if we should be using HTTP/1.1 or HTTP/2 
    320                ssl_object = stream.get_extra_info("ssl_object") 
    321                http2_negotiated = ( 
    322                    ssl_object is not None 
    323                    and ssl_object.selected_alpn_protocol() == "h2" 
    324                ) 
    325 
    326                # Create the HTTP/1.1 or HTTP/2 connection 
    327                if http2_negotiated or (self._http2 and not self._http1): 
    328                    from .http2 import AsyncHTTP2Connection 
    329 
    330                    self._connection = AsyncHTTP2Connection( 
    331                        origin=self._remote_origin, 
    332                        stream=stream, 
    333                        keepalive_expiry=self._keepalive_expiry, 
    334                    ) 
    335                else: 
    336                    self._connection = AsyncHTTP11Connection( 
    337                        origin=self._remote_origin, 
    338                        stream=stream, 
    339                        keepalive_expiry=self._keepalive_expiry, 
    340                    ) 
    341 
    342                self._connected = True 
    343        return await self._connection.handle_async_request(request) 
    344 
    345    def can_handle_request(self, origin: Origin) -> bool: 
    346        return origin == self._remote_origin 
    347 
    348    async def aclose(self) -> None: 
    349        await self._connection.aclose() 
    350 
    351    def info(self) -> str: 
    352        return self._connection.info() 
    353 
    354    def is_available(self) -> bool: 
    355        return self._connection.is_available() 
    356 
    357    def has_expired(self) -> bool: 
    358        return self._connection.has_expired() 
    359 
    360    def is_idle(self) -> bool: 
    361        return self._connection.is_idle() 
    362 
    363    def is_closed(self) -> bool: 
    364        return self._connection.is_closed() 
    365 
    366    def __repr__(self) -> str: 
    367        return f"<{self.__class__.__name__} [{self.info()}]>"