1from __future__ import annotations 
    2 
    3import io 
    4import socket 
    5import ssl 
    6import typing 
    7 
    8from ..exceptions import ProxySchemeUnsupported 
    9 
    10if typing.TYPE_CHECKING: 
    11    from typing_extensions import Self 
    12 
    13    from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT 
    14 
    15 
# Buffer types accepted as the destination of SSLTransport.recv_into().
_WriteBuffer = typing.Union[bytearray, memoryview]
# Generic return type of the callable driven by SSLTransport._ssl_io_loop().
_ReturnValue = typing.TypeVar("_ReturnValue")

# Maximum number of bytes requested from the wrapped socket per recv() call
# inside SSLTransport._ssl_io_loop().
# NOTE(review): 16384 presumably matches the maximum TLS record payload size
# (2**14 bytes) -- confirm.
SSL_BLOCKSIZE = 16384
    20 
    21 
    22class SSLTransport: 
    23    """ 
    24    The SSLTransport wraps an existing socket and establishes an SSL connection. 
    25 
    26    Contrary to Python's implementation of SSLSocket, it allows you to chain 
    27    multiple TLS connections together. It's particularly useful if you need to 
    28    implement TLS within TLS. 
    29 
    30    The class supports most of the socket API operations. 
    31    """ 
    32 
    33    @staticmethod 
    34    def _validate_ssl_context_for_tls_in_tls(ssl_context: ssl.SSLContext) -> None: 
    35        """ 
    36        Raises a ProxySchemeUnsupported if the provided ssl_context can't be used 
    37        for TLS in TLS. 
    38 
    39        The only requirement is that the ssl_context provides the 'wrap_bio' 
    40        methods. 
    41        """ 
    42 
    43        if not hasattr(ssl_context, "wrap_bio"): 
    44            raise ProxySchemeUnsupported( 
    45                "TLS in TLS requires SSLContext.wrap_bio() which isn't " 
    46                "available on non-native SSLContext" 
    47            ) 
    48 
    49    def __init__( 
    50        self, 
    51        socket: socket.socket, 
    52        ssl_context: ssl.SSLContext, 
    53        server_hostname: str | None = None, 
    54        suppress_ragged_eofs: bool = True, 
    55    ) -> None: 
    56        """ 
    57        Create an SSLTransport around socket using the provided ssl_context. 
    58        """ 
    59        self.incoming = ssl.MemoryBIO() 
    60        self.outgoing = ssl.MemoryBIO() 
    61 
    62        self.suppress_ragged_eofs = suppress_ragged_eofs 
    63        self.socket = socket 
    64 
    65        self.sslobj = ssl_context.wrap_bio( 
    66            self.incoming, self.outgoing, server_hostname=server_hostname 
    67        ) 
    68 
    69        # Perform initial handshake. 
    70        self._ssl_io_loop(self.sslobj.do_handshake) 
    71 
    72    def __enter__(self) -> Self: 
    73        return self 
    74 
    75    def __exit__(self, *_: typing.Any) -> None: 
    76        self.close() 
    77 
    78    def fileno(self) -> int: 
    79        return self.socket.fileno() 
    80 
    81    def read(self, len: int = 1024, buffer: typing.Any | None = None) -> int | bytes: 
    82        return self._wrap_ssl_read(len, buffer) 
    83 
    84    def recv(self, buflen: int = 1024, flags: int = 0) -> int | bytes: 
    85        if flags != 0: 
    86            raise ValueError("non-zero flags not allowed in calls to recv") 
    87        return self._wrap_ssl_read(buflen) 
    88 
    89    def recv_into( 
    90        self, 
    91        buffer: _WriteBuffer, 
    92        nbytes: int | None = None, 
    93        flags: int = 0, 
    94    ) -> None | int | bytes: 
    95        if flags != 0: 
    96            raise ValueError("non-zero flags not allowed in calls to recv_into") 
    97        if nbytes is None: 
    98            nbytes = len(buffer) 
    99        return self.read(nbytes, buffer) 
    100 
    101    def sendall(self, data: bytes, flags: int = 0) -> None: 
    102        if flags != 0: 
    103            raise ValueError("non-zero flags not allowed in calls to sendall") 
    104        count = 0 
    105        with memoryview(data) as view, view.cast("B") as byte_view: 
    106            amount = len(byte_view) 
    107            while count < amount: 
    108                v = self.send(byte_view[count:]) 
    109                count += v 
    110 
    111    def send(self, data: bytes, flags: int = 0) -> int: 
    112        if flags != 0: 
    113            raise ValueError("non-zero flags not allowed in calls to send") 
    114        return self._ssl_io_loop(self.sslobj.write, data) 
    115 
    116    def makefile( 
    117        self, 
    118        mode: str, 
    119        buffering: int | None = None, 
    120        *, 
    121        encoding: str | None = None, 
    122        errors: str | None = None, 
    123        newline: str | None = None, 
    124    ) -> typing.BinaryIO | typing.TextIO | socket.SocketIO: 
    125        """ 
    126        Python's httpclient uses makefile and buffered io when reading HTTP 
    127        messages and we need to support it. 
    128 
    129        This is unfortunately a copy and paste of socket.py makefile with small 
    130        changes to point to the socket directly. 
    131        """ 
    132        if not set(mode) <= {"r", "w", "b"}: 
    133            raise ValueError(f"invalid mode {mode!r} (only r, w, b allowed)") 
    134 
    135        writing = "w" in mode 
    136        reading = "r" in mode or not writing 
    137        assert reading or writing 
    138        binary = "b" in mode 
    139        rawmode = "" 
    140        if reading: 
    141            rawmode += "r" 
    142        if writing: 
    143            rawmode += "w" 
    144        raw = socket.SocketIO(self, rawmode)  # type: ignore[arg-type] 
    145        self.socket._io_refs += 1  # type: ignore[attr-defined] 
    146        if buffering is None: 
    147            buffering = -1 
    148        if buffering < 0: 
    149            buffering = io.DEFAULT_BUFFER_SIZE 
    150        if buffering == 0: 
    151            if not binary: 
    152                raise ValueError("unbuffered streams must be binary") 
    153            return raw 
    154        buffer: typing.BinaryIO 
    155        if reading and writing: 
    156            buffer = io.BufferedRWPair(raw, raw, buffering)  # type: ignore[assignment] 
    157        elif reading: 
    158            buffer = io.BufferedReader(raw, buffering) 
    159        else: 
    160            assert writing 
    161            buffer = io.BufferedWriter(raw, buffering) 
    162        if binary: 
    163            return buffer 
    164        text = io.TextIOWrapper(buffer, encoding, errors, newline) 
    165        text.mode = mode  # type: ignore[misc] 
    166        return text 
    167 
    168    def unwrap(self) -> None: 
    169        self._ssl_io_loop(self.sslobj.unwrap) 
    170 
    171    def close(self) -> None: 
    172        self.socket.close() 
    173 
    174    @typing.overload 
    175    def getpeercert( 
    176        self, binary_form: typing.Literal[False] = ... 
    177    ) -> _TYPE_PEER_CERT_RET_DICT | None: ... 
    178 
    179    @typing.overload 
    180    def getpeercert(self, binary_form: typing.Literal[True]) -> bytes | None: ... 
    181 
    182    def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET: 
    183        return self.sslobj.getpeercert(binary_form)  # type: ignore[return-value] 
    184 
    185    def version(self) -> str | None: 
    186        return self.sslobj.version() 
    187 
    188    def cipher(self) -> tuple[str, str, int] | None: 
    189        return self.sslobj.cipher() 
    190 
    191    def selected_alpn_protocol(self) -> str | None: 
    192        return self.sslobj.selected_alpn_protocol() 
    193 
    194    def shared_ciphers(self) -> list[tuple[str, str, int]] | None: 
    195        return self.sslobj.shared_ciphers() 
    196 
    197    def compression(self) -> str | None: 
    198        return self.sslobj.compression() 
    199 
    200    def settimeout(self, value: float | None) -> None: 
    201        self.socket.settimeout(value) 
    202 
    203    def gettimeout(self) -> float | None: 
    204        return self.socket.gettimeout() 
    205 
    206    def _decref_socketios(self) -> None: 
    207        self.socket._decref_socketios()  # type: ignore[attr-defined] 
    208 
    209    def _wrap_ssl_read(self, len: int, buffer: bytearray | None = None) -> int | bytes: 
    210        try: 
    211            return self._ssl_io_loop(self.sslobj.read, len, buffer) 
    212        except ssl.SSLError as e: 
    213            if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs: 
    214                return 0  # eof, return 0. 
    215            else: 
    216                raise 
    217 
    218    # func is sslobj.do_handshake or sslobj.unwrap 
    219    @typing.overload 
    220    def _ssl_io_loop(self, func: typing.Callable[[], None]) -> None: ... 
    221 
    222    # func is sslobj.write, arg1 is data 
    223    @typing.overload 
    224    def _ssl_io_loop(self, func: typing.Callable[[bytes], int], arg1: bytes) -> int: ... 
    225 
    226    # func is sslobj.read, arg1 is len, arg2 is buffer 
    227    @typing.overload 
    228    def _ssl_io_loop( 
    229        self, 
    230        func: typing.Callable[[int, bytearray | None], bytes], 
    231        arg1: int, 
    232        arg2: bytearray | None, 
    233    ) -> bytes: ... 
    234 
    235    def _ssl_io_loop( 
    236        self, 
    237        func: typing.Callable[..., _ReturnValue], 
    238        arg1: None | bytes | int = None, 
    239        arg2: bytearray | None = None, 
    240    ) -> _ReturnValue: 
    241        """Performs an I/O loop between incoming/outgoing and the socket.""" 
    242        should_loop = True 
    243        ret = None 
    244 
    245        while should_loop: 
    246            errno = None 
    247            try: 
    248                if arg1 is None and arg2 is None: 
    249                    ret = func() 
    250                elif arg2 is None: 
    251                    ret = func(arg1) 
    252                else: 
    253                    ret = func(arg1, arg2) 
    254            except ssl.SSLError as e: 
    255                if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): 
    256                    # WANT_READ, and WANT_WRITE are expected, others are not. 
    257                    raise e 
    258                errno = e.errno 
    259 
    260            buf = self.outgoing.read() 
    261            self.socket.sendall(buf) 
    262 
    263            if errno is None: 
    264                should_loop = False 
    265            elif errno == ssl.SSL_ERROR_WANT_READ: 
    266                buf = self.socket.recv(SSL_BLOCKSIZE) 
    267                if buf: 
    268                    self.incoming.write(buf) 
    269                else: 
    270                    self.incoming.write_eof() 
    271        return typing.cast(_ReturnValue, ret)