1from __future__ import annotations 
    2 
    3import io 
    4import socket 
    5import ssl 
    6import typing 
    7 
    8from ..exceptions import ProxySchemeUnsupported 
    9 
    10if typing.TYPE_CHECKING: 
    11    from typing import Literal 
    12 
    13    from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT 
    14 
    15 
    16_SelfT = typing.TypeVar("_SelfT", bound="SSLTransport") 
    17_WriteBuffer = typing.Union[bytearray, memoryview] 
    18_ReturnValue = typing.TypeVar("_ReturnValue") 
    19 
    20SSL_BLOCKSIZE = 16384 
    21 
    22 
    23class SSLTransport: 
    24    """ 
    25    The SSLTransport wraps an existing socket and establishes an SSL connection. 
    26 
    27    Contrary to Python's implementation of SSLSocket, it allows you to chain 
    28    multiple TLS connections together. It's particularly useful if you need to 
    29    implement TLS within TLS. 
    30 
    31    The class supports most of the socket API operations. 
    32    """ 
    33 
    34    @staticmethod 
    35    def _validate_ssl_context_for_tls_in_tls(ssl_context: ssl.SSLContext) -> None: 
    36        """ 
    37        Raises a ProxySchemeUnsupported if the provided ssl_context can't be used 
    38        for TLS in TLS. 
    39 
    40        The only requirement is that the ssl_context provides the 'wrap_bio' 
    41        methods. 
    42        """ 
    43 
    44        if not hasattr(ssl_context, "wrap_bio"): 
    45            raise ProxySchemeUnsupported( 
    46                "TLS in TLS requires SSLContext.wrap_bio() which isn't " 
    47                "available on non-native SSLContext" 
    48            ) 
    49 
    50    def __init__( 
    51        self, 
    52        socket: socket.socket, 
    53        ssl_context: ssl.SSLContext, 
    54        server_hostname: str | None = None, 
    55        suppress_ragged_eofs: bool = True, 
    56    ) -> None: 
    57        """ 
    58        Create an SSLTransport around socket using the provided ssl_context. 
    59        """ 
    60        self.incoming = ssl.MemoryBIO() 
    61        self.outgoing = ssl.MemoryBIO() 
    62 
    63        self.suppress_ragged_eofs = suppress_ragged_eofs 
    64        self.socket = socket 
    65 
    66        self.sslobj = ssl_context.wrap_bio( 
    67            self.incoming, self.outgoing, server_hostname=server_hostname 
    68        ) 
    69 
    70        # Perform initial handshake. 
    71        self._ssl_io_loop(self.sslobj.do_handshake) 
    72 
    73    def __enter__(self: _SelfT) -> _SelfT: 
    74        return self 
    75 
    76    def __exit__(self, *_: typing.Any) -> None: 
    77        self.close() 
    78 
    79    def fileno(self) -> int: 
    80        return self.socket.fileno() 
    81 
    82    def read(self, len: int = 1024, buffer: typing.Any | None = None) -> int | bytes: 
    83        return self._wrap_ssl_read(len, buffer) 
    84 
    85    def recv(self, buflen: int = 1024, flags: int = 0) -> int | bytes: 
    86        if flags != 0: 
    87            raise ValueError("non-zero flags not allowed in calls to recv") 
    88        return self._wrap_ssl_read(buflen) 
    89 
    90    def recv_into( 
    91        self, 
    92        buffer: _WriteBuffer, 
    93        nbytes: int | None = None, 
    94        flags: int = 0, 
    95    ) -> None | int | bytes: 
    96        if flags != 0: 
    97            raise ValueError("non-zero flags not allowed in calls to recv_into") 
    98        if nbytes is None: 
    99            nbytes = len(buffer) 
    100        return self.read(nbytes, buffer) 
    101 
    102    def sendall(self, data: bytes, flags: int = 0) -> None: 
    103        if flags != 0: 
    104            raise ValueError("non-zero flags not allowed in calls to sendall") 
    105        count = 0 
    106        with memoryview(data) as view, view.cast("B") as byte_view: 
    107            amount = len(byte_view) 
    108            while count < amount: 
    109                v = self.send(byte_view[count:]) 
    110                count += v 
    111 
    112    def send(self, data: bytes, flags: int = 0) -> int: 
    113        if flags != 0: 
    114            raise ValueError("non-zero flags not allowed in calls to send") 
    115        return self._ssl_io_loop(self.sslobj.write, data) 
    116 
    117    def makefile( 
    118        self, 
    119        mode: str, 
    120        buffering: int | None = None, 
    121        *, 
    122        encoding: str | None = None, 
    123        errors: str | None = None, 
    124        newline: str | None = None, 
    125    ) -> typing.BinaryIO | typing.TextIO | socket.SocketIO: 
    126        """ 
    127        Python's httpclient uses makefile and buffered io when reading HTTP 
    128        messages and we need to support it. 
    129 
    130        This is unfortunately a copy and paste of socket.py makefile with small 
    131        changes to point to the socket directly. 
    132        """ 
    133        if not set(mode) <= {"r", "w", "b"}: 
    134            raise ValueError(f"invalid mode {mode!r} (only r, w, b allowed)") 
    135 
    136        writing = "w" in mode 
    137        reading = "r" in mode or not writing 
    138        assert reading or writing 
    139        binary = "b" in mode 
    140        rawmode = "" 
    141        if reading: 
    142            rawmode += "r" 
    143        if writing: 
    144            rawmode += "w" 
    145        raw = socket.SocketIO(self, rawmode)  # type: ignore[arg-type] 
    146        self.socket._io_refs += 1  # type: ignore[attr-defined] 
    147        if buffering is None: 
    148            buffering = -1 
    149        if buffering < 0: 
    150            buffering = io.DEFAULT_BUFFER_SIZE 
    151        if buffering == 0: 
    152            if not binary: 
    153                raise ValueError("unbuffered streams must be binary") 
    154            return raw 
    155        buffer: typing.BinaryIO 
    156        if reading and writing: 
    157            buffer = io.BufferedRWPair(raw, raw, buffering)  # type: ignore[assignment] 
    158        elif reading: 
    159            buffer = io.BufferedReader(raw, buffering) 
    160        else: 
    161            assert writing 
    162            buffer = io.BufferedWriter(raw, buffering) 
    163        if binary: 
    164            return buffer 
    165        text = io.TextIOWrapper(buffer, encoding, errors, newline) 
    166        text.mode = mode  # type: ignore[misc] 
    167        return text 
    168 
    169    def unwrap(self) -> None: 
    170        self._ssl_io_loop(self.sslobj.unwrap) 
    171 
    172    def close(self) -> None: 
    173        self.socket.close() 
    174 
    175    @typing.overload 
    176    def getpeercert( 
    177        self, binary_form: Literal[False] = ... 
    178    ) -> _TYPE_PEER_CERT_RET_DICT | None: 
    179        ... 
    180 
    181    @typing.overload 
    182    def getpeercert(self, binary_form: Literal[True]) -> bytes | None: 
    183        ... 
    184 
    185    def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET: 
    186        return self.sslobj.getpeercert(binary_form)  # type: ignore[return-value] 
    187 
    188    def version(self) -> str | None: 
    189        return self.sslobj.version() 
    190 
    191    def cipher(self) -> tuple[str, str, int] | None: 
    192        return self.sslobj.cipher() 
    193 
    194    def selected_alpn_protocol(self) -> str | None: 
    195        return self.sslobj.selected_alpn_protocol() 
    196 
    197    def selected_npn_protocol(self) -> str | None: 
    198        return self.sslobj.selected_npn_protocol() 
    199 
    200    def shared_ciphers(self) -> list[tuple[str, str, int]] | None: 
    201        return self.sslobj.shared_ciphers() 
    202 
    203    def compression(self) -> str | None: 
    204        return self.sslobj.compression() 
    205 
    206    def settimeout(self, value: float | None) -> None: 
    207        self.socket.settimeout(value) 
    208 
    209    def gettimeout(self) -> float | None: 
    210        return self.socket.gettimeout() 
    211 
    212    def _decref_socketios(self) -> None: 
    213        self.socket._decref_socketios()  # type: ignore[attr-defined] 
    214 
    215    def _wrap_ssl_read(self, len: int, buffer: bytearray | None = None) -> int | bytes: 
    216        try: 
    217            return self._ssl_io_loop(self.sslobj.read, len, buffer) 
    218        except ssl.SSLError as e: 
    219            if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs: 
    220                return 0  # eof, return 0. 
    221            else: 
    222                raise 
    223 
    224    # func is sslobj.do_handshake or sslobj.unwrap 
    225    @typing.overload 
    226    def _ssl_io_loop(self, func: typing.Callable[[], None]) -> None: 
    227        ... 
    228 
    229    # func is sslobj.write, arg1 is data 
    230    @typing.overload 
    231    def _ssl_io_loop(self, func: typing.Callable[[bytes], int], arg1: bytes) -> int: 
    232        ... 
    233 
    234    # func is sslobj.read, arg1 is len, arg2 is buffer 
    235    @typing.overload 
    236    def _ssl_io_loop( 
    237        self, 
    238        func: typing.Callable[[int, bytearray | None], bytes], 
    239        arg1: int, 
    240        arg2: bytearray | None, 
    241    ) -> bytes: 
    242        ... 
    243 
    244    def _ssl_io_loop( 
    245        self, 
    246        func: typing.Callable[..., _ReturnValue], 
    247        arg1: None | bytes | int = None, 
    248        arg2: bytearray | None = None, 
    249    ) -> _ReturnValue: 
    250        """Performs an I/O loop between incoming/outgoing and the socket.""" 
    251        should_loop = True 
    252        ret = None 
    253 
    254        while should_loop: 
    255            errno = None 
    256            try: 
    257                if arg1 is None and arg2 is None: 
    258                    ret = func() 
    259                elif arg2 is None: 
    260                    ret = func(arg1) 
    261                else: 
    262                    ret = func(arg1, arg2) 
    263            except ssl.SSLError as e: 
    264                if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): 
    265                    # WANT_READ, and WANT_WRITE are expected, others are not. 
    266                    raise e 
    267                errno = e.errno 
    268 
    269            buf = self.outgoing.read() 
    270            self.socket.sendall(buf) 
    271 
    272            if errno is None: 
    273                should_loop = False 
    274            elif errno == ssl.SSL_ERROR_WANT_READ: 
    275                buf = self.socket.recv(SSL_BLOCKSIZE) 
    276                if buf: 
    277                    self.incoming.write(buf) 
    278                else: 
    279                    self.incoming.write_eof() 
    280        return typing.cast(_ReturnValue, ret)