1from __future__ import annotations 
    2 
    3import functools 
    4import socket 
    5import ssl 
    6import sys 
    7import typing 
    8 
    9from .._exceptions import ( 
    10    ConnectError, 
    11    ConnectTimeout, 
    12    ExceptionMapping, 
    13    ReadError, 
    14    ReadTimeout, 
    15    WriteError, 
    16    WriteTimeout, 
    17    map_exceptions, 
    18) 
    19from .._utils import is_socket_readable 
    20from .base import SOCKET_OPTION, NetworkBackend, NetworkStream 
    21 
    22 
    23class TLSinTLSStream(NetworkStream):  # pragma: no cover 
    24    """ 
    25    Because the standard `SSLContext.wrap_socket` method does 
    26    not work for `SSLSocket` objects, we need this class 
    27    to implement TLS stream using an underlying `SSLObject` 
    28    instance in order to support TLS on top of TLS. 
    29    """ 
    30 
    31    # Defined in RFC 8449 
    32    TLS_RECORD_SIZE = 16384 
    33 
    34    def __init__( 
    35        self, 
    36        sock: socket.socket, 
    37        ssl_context: ssl.SSLContext, 
    38        server_hostname: str | None = None, 
    39        timeout: float | None = None, 
    40    ): 
    41        self._sock = sock 
    42        self._incoming = ssl.MemoryBIO() 
    43        self._outgoing = ssl.MemoryBIO() 
    44 
    45        self.ssl_obj = ssl_context.wrap_bio( 
    46            incoming=self._incoming, 
    47            outgoing=self._outgoing, 
    48            server_hostname=server_hostname, 
    49        ) 
    50 
    51        self._sock.settimeout(timeout) 
    52        self._perform_io(self.ssl_obj.do_handshake) 
    53 
    54    def _perform_io( 
    55        self, 
    56        func: typing.Callable[..., typing.Any], 
    57    ) -> typing.Any: 
    58        ret = None 
    59 
    60        while True: 
    61            errno = None 
    62            try: 
    63                ret = func() 
    64            except (ssl.SSLWantReadError, ssl.SSLWantWriteError) as e: 
    65                errno = e.errno 
    66 
    67            self._sock.sendall(self._outgoing.read()) 
    68 
    69            if errno == ssl.SSL_ERROR_WANT_READ: 
    70                buf = self._sock.recv(self.TLS_RECORD_SIZE) 
    71 
    72                if buf: 
    73                    self._incoming.write(buf) 
    74                else: 
    75                    self._incoming.write_eof() 
    76            if errno is None: 
    77                return ret 
    78 
    79    def read(self, max_bytes: int, timeout: float | None = None) -> bytes: 
    80        exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError} 
    81        with map_exceptions(exc_map): 
    82            self._sock.settimeout(timeout) 
    83            return typing.cast( 
    84                bytes, self._perform_io(functools.partial(self.ssl_obj.read, max_bytes)) 
    85            ) 
    86 
    87    def write(self, buffer: bytes, timeout: float | None = None) -> None: 
    88        exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError} 
    89        with map_exceptions(exc_map): 
    90            self._sock.settimeout(timeout) 
    91            while buffer: 
    92                nsent = self._perform_io(functools.partial(self.ssl_obj.write, buffer)) 
    93                buffer = buffer[nsent:] 
    94 
    95    def close(self) -> None: 
    96        self._sock.close() 
    97 
    98    def start_tls( 
    99        self, 
    100        ssl_context: ssl.SSLContext, 
    101        server_hostname: str | None = None, 
    102        timeout: float | None = None, 
    103    ) -> NetworkStream: 
    104        raise NotImplementedError() 
    105 
    106    def get_extra_info(self, info: str) -> typing.Any: 
    107        if info == "ssl_object": 
    108            return self.ssl_obj 
    109        if info == "client_addr": 
    110            return self._sock.getsockname() 
    111        if info == "server_addr": 
    112            return self._sock.getpeername() 
    113        if info == "socket": 
    114            return self._sock 
    115        if info == "is_readable": 
    116            return is_socket_readable(self._sock) 
    117        return None 
    118 
    119 
    120class SyncStream(NetworkStream): 
    121    def __init__(self, sock: socket.socket) -> None: 
    122        self._sock = sock 
    123 
    124    def read(self, max_bytes: int, timeout: float | None = None) -> bytes: 
    125        exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError} 
    126        with map_exceptions(exc_map): 
    127            self._sock.settimeout(timeout) 
    128            return self._sock.recv(max_bytes) 
    129 
    130    def write(self, buffer: bytes, timeout: float | None = None) -> None: 
    131        if not buffer: 
    132            return 
    133 
    134        exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError} 
    135        with map_exceptions(exc_map): 
    136            while buffer: 
    137                self._sock.settimeout(timeout) 
    138                n = self._sock.send(buffer) 
    139                buffer = buffer[n:] 
    140 
    141    def close(self) -> None: 
    142        self._sock.close() 
    143 
    144    def start_tls( 
    145        self, 
    146        ssl_context: ssl.SSLContext, 
    147        server_hostname: str | None = None, 
    148        timeout: float | None = None, 
    149    ) -> NetworkStream: 
    150        exc_map: ExceptionMapping = { 
    151            socket.timeout: ConnectTimeout, 
    152            OSError: ConnectError, 
    153        } 
    154        with map_exceptions(exc_map): 
    155            try: 
    156                if isinstance(self._sock, ssl.SSLSocket):  # pragma: no cover 
    157                    # If the underlying socket has already been upgraded 
    158                    # to the TLS layer (i.e. is an instance of SSLSocket), 
    159                    # we need some additional smarts to support TLS-in-TLS. 
    160                    return TLSinTLSStream( 
    161                        self._sock, ssl_context, server_hostname, timeout 
    162                    ) 
    163                else: 
    164                    self._sock.settimeout(timeout) 
    165                    sock = ssl_context.wrap_socket( 
    166                        self._sock, server_hostname=server_hostname 
    167                    ) 
    168            except Exception as exc:  # pragma: nocover 
    169                self.close() 
    170                raise exc 
    171        return SyncStream(sock) 
    172 
    173    def get_extra_info(self, info: str) -> typing.Any: 
    174        if info == "ssl_object" and isinstance(self._sock, ssl.SSLSocket): 
    175            return self._sock._sslobj  # type: ignore 
    176        if info == "client_addr": 
    177            return self._sock.getsockname() 
    178        if info == "server_addr": 
    179            return self._sock.getpeername() 
    180        if info == "socket": 
    181            return self._sock 
    182        if info == "is_readable": 
    183            return is_socket_readable(self._sock) 
    184        return None 
    185 
    186 
    187class SyncBackend(NetworkBackend): 
    188    def connect_tcp( 
    189        self, 
    190        host: str, 
    191        port: int, 
    192        timeout: float | None = None, 
    193        local_address: str | None = None, 
    194        socket_options: typing.Iterable[SOCKET_OPTION] | None = None, 
    195    ) -> NetworkStream: 
    196        # Note that we automatically include `TCP_NODELAY` 
    197        # in addition to any other custom socket options. 
    198        if socket_options is None: 
    199            socket_options = []  # pragma: no cover 
    200        address = (host, port) 
    201        source_address = None if local_address is None else (local_address, 0) 
    202        exc_map: ExceptionMapping = { 
    203            socket.timeout: ConnectTimeout, 
    204            OSError: ConnectError, 
    205        } 
    206 
    207        with map_exceptions(exc_map): 
    208            sock = socket.create_connection( 
    209                address, 
    210                timeout, 
    211                source_address=source_address, 
    212            ) 
    213            for option in socket_options: 
    214                sock.setsockopt(*option)  # pragma: no cover 
    215            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) 
    216        return SyncStream(sock) 
    217 
    218    def connect_unix_socket( 
    219        self, 
    220        path: str, 
    221        timeout: float | None = None, 
    222        socket_options: typing.Iterable[SOCKET_OPTION] | None = None, 
    223    ) -> NetworkStream:  # pragma: nocover 
    224        if sys.platform == "win32": 
    225            raise RuntimeError( 
    226                "Attempted to connect to a UNIX socket on a Windows system." 
    227            ) 
    228        if socket_options is None: 
    229            socket_options = [] 
    230 
    231        exc_map: ExceptionMapping = { 
    232            socket.timeout: ConnectTimeout, 
    233            OSError: ConnectError, 
    234        } 
    235        with map_exceptions(exc_map): 
    236            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) 
    237            for option in socket_options: 
    238                sock.setsockopt(*option) 
    239            sock.settimeout(timeout) 
    240            sock.connect(path) 
    241        return SyncStream(sock)