1from __future__ import annotations 
    2 
    3import logging 
    4import ssl 
    5 
    6import socksio 
    7 
    8from .._backends.auto import AutoBackend 
    9from .._backends.base import AsyncNetworkBackend, AsyncNetworkStream 
    10from .._exceptions import ConnectionNotAvailable, ProxyError 
    11from .._models import URL, Origin, Request, Response, enforce_bytes, enforce_url 
    12from .._ssl import default_ssl_context 
    13from .._synchronization import AsyncLock 
    14from .._trace import Trace 
    15from .connection_pool import AsyncConnectionPool 
    16from .http11 import AsyncHTTP11Connection 
    17from .interfaces import AsyncConnectionInterface 
    18 
    19logger = logging.getLogger("httpcore.socks") 
    20 
    21 
    22AUTH_METHODS = { 
    23    b"\x00": "NO AUTHENTICATION REQUIRED", 
    24    b"\x01": "GSSAPI", 
    25    b"\x02": "USERNAME/PASSWORD", 
    26    b"\xff": "NO ACCEPTABLE METHODS", 
    27} 
    28 
    29REPLY_CODES = { 
    30    b"\x00": "Succeeded", 
    31    b"\x01": "General SOCKS server failure", 
    32    b"\x02": "Connection not allowed by ruleset", 
    33    b"\x03": "Network unreachable", 
    34    b"\x04": "Host unreachable", 
    35    b"\x05": "Connection refused", 
    36    b"\x06": "TTL expired", 
    37    b"\x07": "Command not supported", 
    38    b"\x08": "Address type not supported", 
    39} 
    40 
    41 
    42async def _init_socks5_connection( 
    43    stream: AsyncNetworkStream, 
    44    *, 
    45    host: bytes, 
    46    port: int, 
    47    auth: tuple[bytes, bytes] | None = None, 
    48) -> None: 
    49    conn = socksio.socks5.SOCKS5Connection() 
    50 
    51    # Auth method request 
    52    auth_method = ( 
    53        socksio.socks5.SOCKS5AuthMethod.NO_AUTH_REQUIRED 
    54        if auth is None 
    55        else socksio.socks5.SOCKS5AuthMethod.USERNAME_PASSWORD 
    56    ) 
    57    conn.send(socksio.socks5.SOCKS5AuthMethodsRequest([auth_method])) 
    58    outgoing_bytes = conn.data_to_send() 
    59    await stream.write(outgoing_bytes) 
    60 
    61    # Auth method response 
    62    incoming_bytes = await stream.read(max_bytes=4096) 
    63    response = conn.receive_data(incoming_bytes) 
    64    assert isinstance(response, socksio.socks5.SOCKS5AuthReply) 
    65    if response.method != auth_method: 
    66        requested = AUTH_METHODS.get(auth_method, "UNKNOWN") 
    67        responded = AUTH_METHODS.get(response.method, "UNKNOWN") 
    68        raise ProxyError( 
    69            f"Requested {requested} from proxy server, but got {responded}." 
    70        ) 
    71 
    72    if response.method == socksio.socks5.SOCKS5AuthMethod.USERNAME_PASSWORD: 
    73        # Username/password request 
    74        assert auth is not None 
    75        username, password = auth 
    76        conn.send(socksio.socks5.SOCKS5UsernamePasswordRequest(username, password)) 
    77        outgoing_bytes = conn.data_to_send() 
    78        await stream.write(outgoing_bytes) 
    79 
    80        # Username/password response 
    81        incoming_bytes = await stream.read(max_bytes=4096) 
    82        response = conn.receive_data(incoming_bytes) 
    83        assert isinstance(response, socksio.socks5.SOCKS5UsernamePasswordReply) 
    84        if not response.success: 
    85            raise ProxyError("Invalid username/password") 
    86 
    87    # Connect request 
    88    conn.send( 
    89        socksio.socks5.SOCKS5CommandRequest.from_address( 
    90            socksio.socks5.SOCKS5Command.CONNECT, (host, port) 
    91        ) 
    92    ) 
    93    outgoing_bytes = conn.data_to_send() 
    94    await stream.write(outgoing_bytes) 
    95 
    96    # Connect response 
    97    incoming_bytes = await stream.read(max_bytes=4096) 
    98    response = conn.receive_data(incoming_bytes) 
    99    assert isinstance(response, socksio.socks5.SOCKS5Reply) 
    100    if response.reply_code != socksio.socks5.SOCKS5ReplyCode.SUCCEEDED: 
    101        reply_code = REPLY_CODES.get(response.reply_code, "UNKOWN") 
    102        raise ProxyError(f"Proxy Server could not connect: {reply_code}.") 
    103 
    104 
    105class AsyncSOCKSProxy(AsyncConnectionPool):  # pragma: nocover 
    106    """ 
    107    A connection pool that sends requests via an HTTP proxy. 
    108    """ 
    109 
    110    def __init__( 
    111        self, 
    112        proxy_url: URL | bytes | str, 
    113        proxy_auth: tuple[bytes | str, bytes | str] | None = None, 
    114        ssl_context: ssl.SSLContext | None = None, 
    115        max_connections: int | None = 10, 
    116        max_keepalive_connections: int | None = None, 
    117        keepalive_expiry: float | None = None, 
    118        http1: bool = True, 
    119        http2: bool = False, 
    120        retries: int = 0, 
    121        network_backend: AsyncNetworkBackend | None = None, 
    122    ) -> None: 
    123        """ 
    124        A connection pool for making HTTP requests. 
    125 
    126        Parameters: 
    127            proxy_url: The URL to use when connecting to the proxy server. 
    128                For example `"http://127.0.0.1:8080/"`. 
    129            ssl_context: An SSL context to use for verifying connections. 
    130                If not specified, the default `httpcore.default_ssl_context()` 
    131                will be used. 
    132            max_connections: The maximum number of concurrent HTTP connections that 
    133                the pool should allow. Any attempt to send a request on a pool that 
    134                would exceed this amount will block until a connection is available. 
    135            max_keepalive_connections: The maximum number of idle HTTP connections 
    136                that will be maintained in the pool. 
    137            keepalive_expiry: The duration in seconds that an idle HTTP connection 
    138                may be maintained for before being expired from the pool. 
    139            http1: A boolean indicating if HTTP/1.1 requests should be supported 
    140                by the connection pool. Defaults to True. 
    141            http2: A boolean indicating if HTTP/2 requests should be supported by 
    142                the connection pool. Defaults to False. 
    143            retries: The maximum number of retries when trying to establish 
    144                a connection. 
    145            local_address: Local address to connect from. Can also be used to 
    146                connect using a particular address family. Using 
    147                `local_address="0.0.0.0"` will connect using an `AF_INET` address 
    148                (IPv4), while using `local_address="::"` will connect using an 
    149                `AF_INET6` address (IPv6). 
    150            uds: Path to a Unix Domain Socket to use instead of TCP sockets. 
    151            network_backend: A backend instance to use for handling network I/O. 
    152        """ 
    153        super().__init__( 
    154            ssl_context=ssl_context, 
    155            max_connections=max_connections, 
    156            max_keepalive_connections=max_keepalive_connections, 
    157            keepalive_expiry=keepalive_expiry, 
    158            http1=http1, 
    159            http2=http2, 
    160            network_backend=network_backend, 
    161            retries=retries, 
    162        ) 
    163        self._ssl_context = ssl_context 
    164        self._proxy_url = enforce_url(proxy_url, name="proxy_url") 
    165        if proxy_auth is not None: 
    166            username, password = proxy_auth 
    167            username_bytes = enforce_bytes(username, name="proxy_auth") 
    168            password_bytes = enforce_bytes(password, name="proxy_auth") 
    169            self._proxy_auth: tuple[bytes, bytes] | None = ( 
    170                username_bytes, 
    171                password_bytes, 
    172            ) 
    173        else: 
    174            self._proxy_auth = None 
    175 
    176    def create_connection(self, origin: Origin) -> AsyncConnectionInterface: 
    177        return AsyncSocks5Connection( 
    178            proxy_origin=self._proxy_url.origin, 
    179            remote_origin=origin, 
    180            proxy_auth=self._proxy_auth, 
    181            ssl_context=self._ssl_context, 
    182            keepalive_expiry=self._keepalive_expiry, 
    183            http1=self._http1, 
    184            http2=self._http2, 
    185            network_backend=self._network_backend, 
    186        ) 
    187 
    188 
    189class AsyncSocks5Connection(AsyncConnectionInterface): 
    190    def __init__( 
    191        self, 
    192        proxy_origin: Origin, 
    193        remote_origin: Origin, 
    194        proxy_auth: tuple[bytes, bytes] | None = None, 
    195        ssl_context: ssl.SSLContext | None = None, 
    196        keepalive_expiry: float | None = None, 
    197        http1: bool = True, 
    198        http2: bool = False, 
    199        network_backend: AsyncNetworkBackend | None = None, 
    200    ) -> None: 
    201        self._proxy_origin = proxy_origin 
    202        self._remote_origin = remote_origin 
    203        self._proxy_auth = proxy_auth 
    204        self._ssl_context = ssl_context 
    205        self._keepalive_expiry = keepalive_expiry 
    206        self._http1 = http1 
    207        self._http2 = http2 
    208 
    209        self._network_backend: AsyncNetworkBackend = ( 
    210            AutoBackend() if network_backend is None else network_backend 
    211        ) 
    212        self._connect_lock = AsyncLock() 
    213        self._connection: AsyncConnectionInterface | None = None 
    214        self._connect_failed = False 
    215 
    216    async def handle_async_request(self, request: Request) -> Response: 
    217        timeouts = request.extensions.get("timeout", {}) 
    218        sni_hostname = request.extensions.get("sni_hostname", None) 
    219        timeout = timeouts.get("connect", None) 
    220 
    221        async with self._connect_lock: 
    222            if self._connection is None: 
    223                try: 
    224                    # Connect to the proxy 
    225                    kwargs = { 
    226                        "host": self._proxy_origin.host.decode("ascii"), 
    227                        "port": self._proxy_origin.port, 
    228                        "timeout": timeout, 
    229                    } 
    230                    async with Trace("connect_tcp", logger, request, kwargs) as trace: 
    231                        stream = await self._network_backend.connect_tcp(**kwargs) 
    232                        trace.return_value = stream 
    233 
    234                    # Connect to the remote host using socks5 
    235                    kwargs = { 
    236                        "stream": stream, 
    237                        "host": self._remote_origin.host.decode("ascii"), 
    238                        "port": self._remote_origin.port, 
    239                        "auth": self._proxy_auth, 
    240                    } 
    241                    async with Trace( 
    242                        "setup_socks5_connection", logger, request, kwargs 
    243                    ) as trace: 
    244                        await _init_socks5_connection(**kwargs) 
    245                        trace.return_value = stream 
    246 
    247                    # Upgrade the stream to SSL 
    248                    if self._remote_origin.scheme == b"https": 
    249                        ssl_context = ( 
    250                            default_ssl_context() 
    251                            if self._ssl_context is None 
    252                            else self._ssl_context 
    253                        ) 
    254                        alpn_protocols = ( 
    255                            ["http/1.1", "h2"] if self._http2 else ["http/1.1"] 
    256                        ) 
    257                        ssl_context.set_alpn_protocols(alpn_protocols) 
    258 
    259                        kwargs = { 
    260                            "ssl_context": ssl_context, 
    261                            "server_hostname": sni_hostname 
    262                            or self._remote_origin.host.decode("ascii"), 
    263                            "timeout": timeout, 
    264                        } 
    265                        async with Trace("start_tls", logger, request, kwargs) as trace: 
    266                            stream = await stream.start_tls(**kwargs) 
    267                            trace.return_value = stream 
    268 
    269                    # Determine if we should be using HTTP/1.1 or HTTP/2 
    270                    ssl_object = stream.get_extra_info("ssl_object") 
    271                    http2_negotiated = ( 
    272                        ssl_object is not None 
    273                        and ssl_object.selected_alpn_protocol() == "h2" 
    274                    ) 
    275 
    276                    # Create the HTTP/1.1 or HTTP/2 connection 
    277                    if http2_negotiated or ( 
    278                        self._http2 and not self._http1 
    279                    ):  # pragma: nocover 
    280                        from .http2 import AsyncHTTP2Connection 
    281 
    282                        self._connection = AsyncHTTP2Connection( 
    283                            origin=self._remote_origin, 
    284                            stream=stream, 
    285                            keepalive_expiry=self._keepalive_expiry, 
    286                        ) 
    287                    else: 
    288                        self._connection = AsyncHTTP11Connection( 
    289                            origin=self._remote_origin, 
    290                            stream=stream, 
    291                            keepalive_expiry=self._keepalive_expiry, 
    292                        ) 
    293                except Exception as exc: 
    294                    self._connect_failed = True 
    295                    raise exc 
    296            elif not self._connection.is_available():  # pragma: nocover 
    297                raise ConnectionNotAvailable() 
    298 
    299        return await self._connection.handle_async_request(request) 
    300 
    301    def can_handle_request(self, origin: Origin) -> bool: 
    302        return origin == self._remote_origin 
    303 
    304    async def aclose(self) -> None: 
    305        if self._connection is not None: 
    306            await self._connection.aclose() 
    307 
    308    def is_available(self) -> bool: 
    309        if self._connection is None:  # pragma: nocover 
    310            # If HTTP/2 support is enabled, and the resulting connection could 
    311            # end up as HTTP/2 then we should indicate the connection as being 
    312            # available to service multiple requests. 
    313            return ( 
    314                self._http2 
    315                and (self._remote_origin.scheme == b"https" or not self._http1) 
    316                and not self._connect_failed 
    317            ) 
    318        return self._connection.is_available() 
    319 
    320    def has_expired(self) -> bool: 
    321        if self._connection is None:  # pragma: nocover 
    322            return self._connect_failed 
    323        return self._connection.has_expired() 
    324 
    325    def is_idle(self) -> bool: 
    326        if self._connection is None:  # pragma: nocover 
    327            return self._connect_failed 
    328        return self._connection.is_idle() 
    329 
    330    def is_closed(self) -> bool: 
    331        if self._connection is None:  # pragma: nocover 
    332            return self._connect_failed 
    333        return self._connection.is_closed() 
    334 
    335    def info(self) -> str: 
    336        if self._connection is None:  # pragma: nocover 
    337            return "CONNECTION FAILED" if self._connect_failed else "CONNECTING" 
    338        return self._connection.info() 
    339 
    340    def __repr__(self) -> str: 
    341        return f"<{self.__class__.__name__} [{self.info()}]>"