1from __future__ import annotations
2
3import io
4import socket
5import ssl
6import typing
7
8from ..exceptions import ProxySchemeUnsupported
9
10if typing.TYPE_CHECKING:
11 from typing_extensions import Self
12
13 from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT
14
15
# Writable buffer types that SSLTransport.recv_into() can fill in place.
_WriteBuffer = typing.Union[bytearray, memoryview]
# Return type of whichever SSLObject method SSLTransport._ssl_io_loop() drives.
_ReturnValue = typing.TypeVar("_ReturnValue")

# Maximum number of bytes requested from the underlying socket per recv()
# call inside SSLTransport._ssl_io_loop() (16 KiB; presumably chosen to match
# the TLS record size limit — confirm before changing).
SSL_BLOCKSIZE = 16 * 1024
20
21
class SSLTransport:
    """
    The SSLTransport wraps an existing socket and establishes an SSL connection.

    Contrary to Python's implementation of SSLSocket, it allows you to chain
    multiple TLS connections together. It's particularly useful if you need to
    implement TLS within TLS.

    The class supports most of the socket API operations.

    All TLS state lives in an ``ssl.SSLObject`` created with
    ``SSLContext.wrap_bio()``. Encrypted bytes are shuttled between that
    object's in-memory BIOs and the wrapped ``socket`` by ``_ssl_io_loop``.
    """

    @staticmethod
    def _validate_ssl_context_for_tls_in_tls(ssl_context: ssl.SSLContext) -> None:
        """
        Raises a ProxySchemeUnsupported if the provided ssl_context can't be used
        for TLS in TLS.

        The only requirement is that the ssl_context provides the 'wrap_bio'
        methods.
        """

        if not hasattr(ssl_context, "wrap_bio"):
            raise ProxySchemeUnsupported(
                "TLS in TLS requires SSLContext.wrap_bio() which isn't "
                "available on non-native SSLContext"
            )

    def __init__(
        self,
        socket: socket.socket,
        ssl_context: ssl.SSLContext,
        server_hostname: str | None = None,
        suppress_ragged_eofs: bool = True,
    ) -> None:
        """
        Create an SSLTransport around socket using the provided ssl_context.

        :param socket: the transport to run TLS over. It only needs
            ``sendall``/``recv`` (plus whatever other methods are delegated
            below), so it may itself be another SSLTransport.
        :param ssl_context: context whose ``wrap_bio()`` builds the TLS object.
        :param server_hostname: passed to ``wrap_bio()`` (SNI / hostname
            checking).
        :param suppress_ragged_eofs: when true, an ``SSL_ERROR_EOF`` during a
            read is reported as end-of-stream instead of raising.

        The TLS handshake is performed before this constructor returns.
        """
        self.incoming = ssl.MemoryBIO()
        self.outgoing = ssl.MemoryBIO()

        self.suppress_ragged_eofs = suppress_ragged_eofs
        self.socket = socket

        self.sslobj = ssl_context.wrap_bio(
            self.incoming, self.outgoing, server_hostname=server_hostname
        )

        # Perform initial handshake.
        self._ssl_io_loop(self.sslobj.do_handshake)

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *_: typing.Any) -> None:
        self.close()

    def fileno(self) -> int:
        """Return the file descriptor of the wrapped socket."""
        return self.socket.fileno()

    def read(self, len: int = 1024, buffer: typing.Any | None = None) -> int | bytes:
        """
        Read up to ``len`` bytes of decrypted data.

        Returns the data as ``bytes`` when ``buffer`` is None. Otherwise the
        data is written into ``buffer`` and the number of bytes written is
        returned.
        """
        return self._wrap_ssl_read(len, buffer)

    def recv(self, buflen: int = 1024, flags: int = 0) -> int | bytes:
        """Socket-style read of up to ``buflen`` bytes; only ``flags=0`` is supported."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv")
        return self._wrap_ssl_read(buflen)

    def recv_into(
        self,
        buffer: _WriteBuffer,
        nbytes: int | None = None,
        flags: int = 0,
    ) -> None | int | bytes:
        """
        Read decrypted data into ``buffer`` and return the number of bytes read.

        ``nbytes`` defaults to ``len(buffer)``. Only ``flags=0`` is supported.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv_into")
        if nbytes is None:
            nbytes = len(buffer)
        return self.read(nbytes, buffer)

    def sendall(self, data: bytes, flags: int = 0) -> None:
        """Send all of ``data``, calling :meth:`send` until every byte is consumed."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to sendall")
        count = 0
        with memoryview(data) as view, view.cast("B") as byte_view:
            amount = len(byte_view)
            while count < amount:
                v = self.send(byte_view[count:])
                count += v

    def send(self, data: bytes, flags: int = 0) -> int:
        """
        Encrypt and send ``data``.

        Returns the count reported by ``SSLObject.write``. Only ``flags=0`` is
        supported.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to send")
        return self._ssl_io_loop(self.sslobj.write, data)

    def makefile(
        self,
        mode: str,
        buffering: int | None = None,
        *,
        encoding: str | None = None,
        errors: str | None = None,
        newline: str | None = None,
    ) -> typing.BinaryIO | typing.TextIO | socket.SocketIO:
        """
        Python's http.client uses makefile and buffered io when reading HTTP
        messages and we need to support it.

        This is unfortunately a copy and paste of socket.py makefile with small
        changes to point to the socket directly.

        The raw ``socket.SocketIO`` wraps this transport. The reference counter
        bumped is ``_io_refs`` on the *wrapped* socket; ``_decref_socketios``
        below forwards the matching decrement to it.
        """
        if not set(mode) <= {"r", "w", "b"}:
            raise ValueError(f"invalid mode {mode!r} (only r, w, b allowed)")

        writing = "w" in mode
        reading = "r" in mode or not writing
        assert reading or writing
        binary = "b" in mode
        rawmode = ""
        if reading:
            rawmode += "r"
        if writing:
            rawmode += "w"
        raw = socket.SocketIO(self, rawmode)  # type: ignore[arg-type]
        self.socket._io_refs += 1  # type: ignore[attr-defined]
        if buffering is None:
            buffering = -1
        if buffering < 0:
            buffering = io.DEFAULT_BUFFER_SIZE
        if buffering == 0:
            if not binary:
                raise ValueError("unbuffered streams must be binary")
            return raw
        buffer: typing.BinaryIO
        if reading and writing:
            buffer = io.BufferedRWPair(raw, raw, buffering)  # type: ignore[assignment]
        elif reading:
            buffer = io.BufferedReader(raw, buffering)
        else:
            assert writing
            buffer = io.BufferedWriter(raw, buffering)
        if binary:
            return buffer
        text = io.TextIOWrapper(buffer, encoding, errors, newline)
        text.mode = mode  # type: ignore[misc]
        return text

    def unwrap(self) -> None:
        """Perform the TLS shutdown handshake (``SSLObject.unwrap``) over the socket."""
        self._ssl_io_loop(self.sslobj.unwrap)

    def close(self) -> None:
        """
        Close the wrapped socket.

        No TLS shutdown is performed here; call :meth:`unwrap` first for a
        clean close.
        """
        self.socket.close()

    @typing.overload
    def getpeercert(
        self, binary_form: typing.Literal[False] = ...
    ) -> _TYPE_PEER_CERT_RET_DICT | None: ...

    @typing.overload
    def getpeercert(self, binary_form: typing.Literal[True]) -> bytes | None: ...

    def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET:
        """Return the peer certificate (delegates to ``SSLObject.getpeercert``)."""
        return self.sslobj.getpeercert(binary_form)  # type: ignore[return-value]

    def version(self) -> str | None:
        """Return the negotiated protocol version (``SSLObject.version``)."""
        return self.sslobj.version()

    def cipher(self) -> tuple[str, str, int] | None:
        """Return the negotiated cipher (``SSLObject.cipher``)."""
        return self.sslobj.cipher()

    def selected_alpn_protocol(self) -> str | None:
        """Return the ALPN protocol selected during the handshake, if any."""
        return self.sslobj.selected_alpn_protocol()

    def shared_ciphers(self) -> list[tuple[str, str, int]] | None:
        """Return the ciphers shared with the peer (``SSLObject.shared_ciphers``)."""
        return self.sslobj.shared_ciphers()

    def compression(self) -> str | None:
        """Return the compression algorithm in use (``SSLObject.compression``)."""
        return self.sslobj.compression()

    def settimeout(self, value: float | None) -> None:
        """Set the timeout of the wrapped socket."""
        self.socket.settimeout(value)

    def gettimeout(self) -> float | None:
        """Return the timeout of the wrapped socket."""
        return self.socket.gettimeout()

    def _decref_socketios(self) -> None:
        # Forwarded to the wrapped socket, whose _io_refs makefile() bumped.
        self.socket._decref_socketios()  # type: ignore[attr-defined]

    def _wrap_ssl_read(self, len: int, buffer: bytearray | None = None) -> int | bytes:
        """
        Read from the TLS layer, mapping a suppressed ragged EOF to end-of-stream.

        Returns decrypted ``bytes`` when ``buffer`` is None. Otherwise returns
        the number of bytes written into ``buffer``.

        If the read fails with ``SSL_ERROR_EOF`` and ``suppress_ragged_eofs``
        is set, it returns ``b""`` when no buffer was given and ``0`` when one
        was, mirroring ``ssl.SSLSocket.read``. Any other ``ssl.SSLError`` is
        re-raised.
        """
        try:
            return self._ssl_io_loop(self.sslobj.read, len, buffer)
        except ssl.SSLError as e:
            if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs:
                # Bytes-returning reads signal EOF with b"" (as recv() does);
                # buffer-filling reads signal it with a byte count of 0.
                return 0 if buffer is not None else b""
            raise

    # func is sslobj.do_handshake or sslobj.unwrap
    @typing.overload
    def _ssl_io_loop(self, func: typing.Callable[[], None]) -> None: ...

    # func is sslobj.write, arg1 is data
    @typing.overload
    def _ssl_io_loop(self, func: typing.Callable[[bytes], int], arg1: bytes) -> int: ...

    # func is sslobj.read, arg1 is len, arg2 is buffer
    @typing.overload
    def _ssl_io_loop(
        self,
        func: typing.Callable[[int, bytearray | None], bytes],
        arg1: int,
        arg2: bytearray | None,
    ) -> bytes: ...

    def _ssl_io_loop(
        self,
        func: typing.Callable[..., _ReturnValue],
        arg1: None | bytes | int = None,
        arg2: bytearray | None = None,
    ) -> _ReturnValue:
        """
        Performs an I/O loop between incoming/outgoing and the socket.

        ``func`` is called with whichever of ``arg1``/``arg2`` are not None.
        The call is repeated until it completes without raising
        ``SSL_ERROR_WANT_READ`` or ``SSL_ERROR_WANT_WRITE``, and its result is
        returned.

        After every attempt, any bytes queued in ``outgoing`` are sent on the
        socket. On WANT_READ, up to ``SSL_BLOCKSIZE`` bytes are read from the
        socket into ``incoming``; an empty read marks ``incoming`` as EOF.
        Every other ``ssl.SSLError`` propagates unchanged.
        """
        while True:
            errno = None
            try:
                if arg1 is None and arg2 is None:
                    ret = func()
                elif arg2 is None:
                    ret = func(arg1)
                else:
                    ret = func(arg1, arg2)
            except ssl.SSLError as e:
                if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
                    # WANT_READ and WANT_WRITE are expected, others are not.
                    raise
                errno = e.errno

            # Flush whatever the TLS object queued for the peer. Skip the call
            # when nothing is pending: a plain socket's sendall(b"") still
            # issues a zero-length send().
            buf = self.outgoing.read()
            if buf:
                self.socket.sendall(buf)

            if errno is None:
                return ret
            if errno == ssl.SSL_ERROR_WANT_READ:
                buf = self.socket.recv(SSL_BLOCKSIZE)
                if buf:
                    self.incoming.write(buf)
                else:
                    # The socket reported EOF; let the TLS object see it.
                    self.incoming.write_eof()