1from __future__ import annotations
2
3import io
4import socket
5import ssl
6import typing
7
8from ..exceptions import ProxySchemeUnsupported
9
10if typing.TYPE_CHECKING:
11 from typing_extensions import Self
12
13 from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT
14
15
# Number of bytes requested from the underlying socket each time the TLS
# object needs more ciphertext (16 KiB).
SSL_BLOCKSIZE = 16 * 1024

# Writable buffer types accepted by ``SSLTransport.recv_into``.
_WriteBuffer = typing.Union[bytearray, memoryview]
# Generic return type threaded through ``SSLTransport._ssl_io_loop``.
_ReturnValue = typing.TypeVar("_ReturnValue")
20
21
22class SSLTransport:
23 """
24 The SSLTransport wraps an existing socket and establishes an SSL connection.
25
26 Contrary to Python's implementation of SSLSocket, it allows you to chain
27 multiple TLS connections together. It's particularly useful if you need to
28 implement TLS within TLS.
29
30 The class supports most of the socket API operations.
31 """
32
33 @staticmethod
34 def _validate_ssl_context_for_tls_in_tls(ssl_context: ssl.SSLContext) -> None:
35 """
36 Raises a ProxySchemeUnsupported if the provided ssl_context can't be used
37 for TLS in TLS.
38
39 The only requirement is that the ssl_context provides the 'wrap_bio'
40 methods.
41 """
42
43 if not hasattr(ssl_context, "wrap_bio"):
44 raise ProxySchemeUnsupported(
45 "TLS in TLS requires SSLContext.wrap_bio() which isn't "
46 "available on non-native SSLContext"
47 )
48
49 def __init__(
50 self,
51 socket: socket.socket,
52 ssl_context: ssl.SSLContext,
53 server_hostname: str | None = None,
54 suppress_ragged_eofs: bool = True,
55 ) -> None:
56 """
57 Create an SSLTransport around socket using the provided ssl_context.
58 """
59 self.incoming = ssl.MemoryBIO()
60 self.outgoing = ssl.MemoryBIO()
61
62 self.suppress_ragged_eofs = suppress_ragged_eofs
63 self.socket = socket
64
65 self.sslobj = ssl_context.wrap_bio(
66 self.incoming, self.outgoing, server_hostname=server_hostname
67 )
68
69 # Perform initial handshake.
70 self._ssl_io_loop(self.sslobj.do_handshake)
71
72 def __enter__(self) -> Self:
73 return self
74
75 def __exit__(self, *_: typing.Any) -> None:
76 self.close()
77
78 def fileno(self) -> int:
79 return self.socket.fileno()
80
81 def read(self, len: int = 1024, buffer: typing.Any | None = None) -> int | bytes:
82 return self._wrap_ssl_read(len, buffer)
83
84 def recv(self, buflen: int = 1024, flags: int = 0) -> int | bytes:
85 if flags != 0:
86 raise ValueError("non-zero flags not allowed in calls to recv")
87 return self._wrap_ssl_read(buflen)
88
89 def recv_into(
90 self,
91 buffer: _WriteBuffer,
92 nbytes: int | None = None,
93 flags: int = 0,
94 ) -> None | int | bytes:
95 if flags != 0:
96 raise ValueError("non-zero flags not allowed in calls to recv_into")
97 if nbytes is None:
98 nbytes = len(buffer)
99 return self.read(nbytes, buffer)
100
101 def sendall(self, data: bytes, flags: int = 0) -> None:
102 if flags != 0:
103 raise ValueError("non-zero flags not allowed in calls to sendall")
104 count = 0
105 with memoryview(data) as view, view.cast("B") as byte_view:
106 amount = len(byte_view)
107 while count < amount:
108 v = self.send(byte_view[count:])
109 count += v
110
111 def send(self, data: bytes, flags: int = 0) -> int:
112 if flags != 0:
113 raise ValueError("non-zero flags not allowed in calls to send")
114 return self._ssl_io_loop(self.sslobj.write, data)
115
116 def makefile(
117 self,
118 mode: str,
119 buffering: int | None = None,
120 *,
121 encoding: str | None = None,
122 errors: str | None = None,
123 newline: str | None = None,
124 ) -> typing.BinaryIO | typing.TextIO | socket.SocketIO:
125 """
126 Python's httpclient uses makefile and buffered io when reading HTTP
127 messages and we need to support it.
128
129 This is unfortunately a copy and paste of socket.py makefile with small
130 changes to point to the socket directly.
131 """
132 if not set(mode) <= {"r", "w", "b"}:
133 raise ValueError(f"invalid mode {mode!r} (only r, w, b allowed)")
134
135 writing = "w" in mode
136 reading = "r" in mode or not writing
137 assert reading or writing
138 binary = "b" in mode
139 rawmode = ""
140 if reading:
141 rawmode += "r"
142 if writing:
143 rawmode += "w"
144 raw = socket.SocketIO(self, rawmode) # type: ignore[arg-type]
145 self.socket._io_refs += 1 # type: ignore[attr-defined]
146 if buffering is None:
147 buffering = -1
148 if buffering < 0:
149 buffering = io.DEFAULT_BUFFER_SIZE
150 if buffering == 0:
151 if not binary:
152 raise ValueError("unbuffered streams must be binary")
153 return raw
154 buffer: typing.BinaryIO
155 if reading and writing:
156 buffer = io.BufferedRWPair(raw, raw, buffering) # type: ignore[assignment]
157 elif reading:
158 buffer = io.BufferedReader(raw, buffering)
159 else:
160 assert writing
161 buffer = io.BufferedWriter(raw, buffering)
162 if binary:
163 return buffer
164 text = io.TextIOWrapper(buffer, encoding, errors, newline)
165 text.mode = mode # type: ignore[misc]
166 return text
167
168 def unwrap(self) -> None:
169 self._ssl_io_loop(self.sslobj.unwrap)
170
171 def close(self) -> None:
172 self.socket.close()
173
174 @typing.overload
175 def getpeercert(
176 self, binary_form: typing.Literal[False] = ...
177 ) -> _TYPE_PEER_CERT_RET_DICT | None:
178 ...
179
180 @typing.overload
181 def getpeercert(self, binary_form: typing.Literal[True]) -> bytes | None:
182 ...
183
184 def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET:
185 return self.sslobj.getpeercert(binary_form) # type: ignore[return-value]
186
187 def version(self) -> str | None:
188 return self.sslobj.version()
189
190 def cipher(self) -> tuple[str, str, int] | None:
191 return self.sslobj.cipher()
192
193 def selected_alpn_protocol(self) -> str | None:
194 return self.sslobj.selected_alpn_protocol()
195
196 def selected_npn_protocol(self) -> str | None:
197 return self.sslobj.selected_npn_protocol()
198
199 def shared_ciphers(self) -> list[tuple[str, str, int]] | None:
200 return self.sslobj.shared_ciphers()
201
202 def compression(self) -> str | None:
203 return self.sslobj.compression()
204
205 def settimeout(self, value: float | None) -> None:
206 self.socket.settimeout(value)
207
208 def gettimeout(self) -> float | None:
209 return self.socket.gettimeout()
210
211 def _decref_socketios(self) -> None:
212 self.socket._decref_socketios() # type: ignore[attr-defined]
213
214 def _wrap_ssl_read(self, len: int, buffer: bytearray | None = None) -> int | bytes:
215 try:
216 return self._ssl_io_loop(self.sslobj.read, len, buffer)
217 except ssl.SSLError as e:
218 if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs:
219 return 0 # eof, return 0.
220 else:
221 raise
222
223 # func is sslobj.do_handshake or sslobj.unwrap
224 @typing.overload
225 def _ssl_io_loop(self, func: typing.Callable[[], None]) -> None:
226 ...
227
228 # func is sslobj.write, arg1 is data
229 @typing.overload
230 def _ssl_io_loop(self, func: typing.Callable[[bytes], int], arg1: bytes) -> int:
231 ...
232
233 # func is sslobj.read, arg1 is len, arg2 is buffer
234 @typing.overload
235 def _ssl_io_loop(
236 self,
237 func: typing.Callable[[int, bytearray | None], bytes],
238 arg1: int,
239 arg2: bytearray | None,
240 ) -> bytes:
241 ...
242
243 def _ssl_io_loop(
244 self,
245 func: typing.Callable[..., _ReturnValue],
246 arg1: None | bytes | int = None,
247 arg2: bytearray | None = None,
248 ) -> _ReturnValue:
249 """Performs an I/O loop between incoming/outgoing and the socket."""
250 should_loop = True
251 ret = None
252
253 while should_loop:
254 errno = None
255 try:
256 if arg1 is None and arg2 is None:
257 ret = func()
258 elif arg2 is None:
259 ret = func(arg1)
260 else:
261 ret = func(arg1, arg2)
262 except ssl.SSLError as e:
263 if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
264 # WANT_READ, and WANT_WRITE are expected, others are not.
265 raise e
266 errno = e.errno
267
268 buf = self.outgoing.read()
269 self.socket.sendall(buf)
270
271 if errno is None:
272 should_loop = False
273 elif errno == ssl.SSL_ERROR_WANT_READ:
274 buf = self.socket.recv(SSL_BLOCKSIZE)
275 if buf:
276 self.incoming.write(buf)
277 else:
278 self.incoming.write_eof()
279 return typing.cast(_ReturnValue, ret)