Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/urllib3/util/ssltransport.py: 30%
160 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-20 06:09 +0000
1from __future__ import annotations
3import io
4import socket
5import ssl
6import typing
8from ..exceptions import ProxySchemeUnsupported
10if typing.TYPE_CHECKING:
11 from typing import Literal
13 from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT
# Type variable so ``__enter__`` is typed as returning the concrete subclass.
_SelfT = typing.TypeVar("_SelfT", bound="SSLTransport")
# Return type threaded through ``SSLTransport._ssl_io_loop``.
_ReturnValue = typing.TypeVar("_ReturnValue")
# Writable buffer types accepted by ``SSLTransport.recv_into``.
_WriteBuffer = typing.Union[bytearray, memoryview]

# Maximum number of bytes requested from the wrapped socket per recv() call
# while the TLS object is waiting for more input.
SSL_BLOCKSIZE = 16 * 1024
class SSLTransport:
    """
    The SSLTransport wraps an existing socket and establishes an SSL connection.

    Contrary to Python's implementation of SSLSocket, it allows you to chain
    multiple TLS connections together. It's particularly useful if you need to
    implement TLS within TLS.

    The class supports most of the socket API operations.

    TLS state lives in an ``ssl.SSLObject`` created by
    ``SSLContext.wrap_bio()`` over two in-memory BIOs: ``incoming`` holds bytes
    received from ``socket`` that the TLS object has not consumed yet, and
    ``outgoing`` holds bytes the TLS object produced that still have to be
    sent. ``_ssl_io_loop`` moves data between those buffers and ``socket``.
    """

    @staticmethod
    def _validate_ssl_context_for_tls_in_tls(ssl_context: ssl.SSLContext) -> None:
        """
        Raise ProxySchemeUnsupported if ``ssl_context`` can't be used for
        TLS in TLS.

        The only requirement is that the ssl_context provides the 'wrap_bio'
        method, which SSLTransport uses to run TLS over memory buffers.
        """
        if not hasattr(ssl_context, "wrap_bio"):
            raise ProxySchemeUnsupported(
                "TLS in TLS requires SSLContext.wrap_bio() which isn't "
                "available on non-native SSLContext"
            )

    def __init__(
        self,
        socket: socket.socket,
        ssl_context: ssl.SSLContext,
        server_hostname: str | None = None,
        suppress_ragged_eofs: bool = True,
    ) -> None:
        """
        Create an SSLTransport around socket using the provided ssl_context.

        The TLS handshake is performed immediately, exchanging data over
        ``socket`` before this constructor returns.

        :param socket: Socket (or socket-like object providing ``sendall``,
            ``recv``, ``close``, ...) that carries the TLS records.
        :param ssl_context: Context whose ``wrap_bio()`` creates the TLS object.
        :param server_hostname: Forwarded to ``wrap_bio()``.
        :param suppress_ragged_eofs: When true, a TLS EOF error raised while
            reading is reported as end-of-stream instead of being raised.
        :raises ssl.SSLError: If the handshake fails with anything other than
            a want-read / want-write condition.
        """
        self.incoming = ssl.MemoryBIO()
        self.outgoing = ssl.MemoryBIO()

        self.suppress_ragged_eofs = suppress_ragged_eofs
        self.socket = socket

        self.sslobj = ssl_context.wrap_bio(
            self.incoming, self.outgoing, server_hostname=server_hostname
        )

        # Perform initial handshake.
        self._ssl_io_loop(self.sslobj.do_handshake)

    def __enter__(self: _SelfT) -> _SelfT:
        return self

    def __exit__(self, *_: typing.Any) -> None:
        self.close()

    def fileno(self) -> int:
        """Return the file descriptor of the wrapped socket."""
        return self.socket.fileno()

    def read(self, len: int = 1024, buffer: typing.Any | None = None) -> int | bytes:
        """
        Read up to ``len`` bytes of decrypted application data.

        Without ``buffer`` the data is returned as bytes (``b""`` at a
        suppressed ragged EOF). With ``buffer`` the data is read into it and
        the number of bytes read is returned (``0`` at a suppressed ragged EOF).
        """
        return self._wrap_ssl_read(len, buffer)

    def recv(self, buflen: int = 1024, flags: int = 0) -> int | bytes:
        """Socket-style ``read(buflen)``; ``flags`` must be 0."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv")
        return self._wrap_ssl_read(buflen)

    def recv_into(
        self,
        buffer: _WriteBuffer,
        nbytes: int | None = None,
        flags: int = 0,
    ) -> None | int | bytes:
        """
        Read decrypted data into ``buffer`` and return the number of bytes read.

        ``nbytes`` defaults to ``len(buffer)``; ``flags`` must be 0.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv_into")
        if nbytes is None:
            nbytes = len(buffer)
        return self.read(nbytes, buffer)

    def sendall(self, data: bytes, flags: int = 0) -> None:
        """Encrypt and send all of ``data``; ``flags`` must be 0."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to sendall")
        count = 0
        # A flat unsigned-byte view makes len() and slicing count bytes even
        # when ``data`` exposes a multi-byte buffer format.
        with memoryview(data) as view, view.cast("B") as byte_view:
            amount = len(byte_view)
            while count < amount:
                v = self.send(byte_view[count:])
                count += v

    def send(self, data: bytes, flags: int = 0) -> int:
        """
        Encrypt ``data``, flush the resulting records to the socket and return
        the number of plaintext bytes accepted; ``flags`` must be 0.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to send")
        return self._ssl_io_loop(self.sslobj.write, data)

    def makefile(
        self,
        mode: str,
        buffering: int | None = None,
        *,
        encoding: str | None = None,
        errors: str | None = None,
        newline: str | None = None,
    ) -> typing.BinaryIO | typing.TextIO | socket.SocketIO:
        """
        Python's httpclient uses makefile and buffered io when reading HTTP
        messages and we need to support it.

        This is unfortunately a copy and paste of socket.py makefile with small
        changes to point to the socket directly.

        :raises ValueError: For a mode containing characters other than
            ``r``/``w``/``b``, or for an unbuffered (``buffering=0``) text mode.
        """
        if not set(mode) <= {"r", "w", "b"}:
            raise ValueError(f"invalid mode {mode!r} (only r, w, b allowed)")

        writing = "w" in mode
        reading = "r" in mode or not writing
        assert reading or writing
        binary = "b" in mode
        rawmode = ""
        if reading:
            rawmode += "r"
        if writing:
            rawmode += "w"
        raw = socket.SocketIO(self, rawmode)  # type: ignore[arg-type]
        # Mirror socket.makefile's reference counting on the wrapped socket;
        # SocketIO.close() releases it again through _decref_socketios().
        self.socket._io_refs += 1  # type: ignore[attr-defined]
        if buffering is None:
            buffering = -1
        if buffering < 0:
            buffering = io.DEFAULT_BUFFER_SIZE
        if buffering == 0:
            if not binary:
                raise ValueError("unbuffered streams must be binary")
            return raw
        buffer: typing.BinaryIO
        if reading and writing:
            buffer = io.BufferedRWPair(raw, raw, buffering)  # type: ignore[assignment]
        elif reading:
            buffer = io.BufferedReader(raw, buffering)
        else:
            assert writing
            buffer = io.BufferedWriter(raw, buffering)
        if binary:
            return buffer
        text = io.TextIOWrapper(buffer, encoding, errors, newline)
        text.mode = mode  # type: ignore[misc]
        return text

    def unwrap(self) -> None:
        """Run the TLS shutdown (``SSLObject.unwrap``) over the wrapped socket."""
        self._ssl_io_loop(self.sslobj.unwrap)

    def close(self) -> None:
        """
        Close the wrapped socket.

        No TLS shutdown is performed here; call ``unwrap()`` first for that.
        """
        self.socket.close()

    @typing.overload
    def getpeercert(
        self, binary_form: Literal[False] = ...
    ) -> _TYPE_PEER_CERT_RET_DICT | None:
        ...

    @typing.overload
    def getpeercert(self, binary_form: Literal[True]) -> bytes | None:
        ...

    def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET:
        return self.sslobj.getpeercert(binary_form)  # type: ignore[return-value]

    # The accessors below delegate directly to the underlying SSLObject.

    def version(self) -> str | None:
        return self.sslobj.version()

    def cipher(self) -> tuple[str, str, int] | None:
        return self.sslobj.cipher()

    def selected_alpn_protocol(self) -> str | None:
        return self.sslobj.selected_alpn_protocol()

    def selected_npn_protocol(self) -> str | None:
        return self.sslobj.selected_npn_protocol()

    def shared_ciphers(self) -> list[tuple[str, str, int]] | None:
        return self.sslobj.shared_ciphers()

    def compression(self) -> str | None:
        return self.sslobj.compression()

    # Timeouts and file-object bookkeeping delegate to the wrapped socket.

    def settimeout(self, value: float | None) -> None:
        self.socket.settimeout(value)

    def gettimeout(self) -> float | None:
        return self.socket.gettimeout()

    def _decref_socketios(self) -> None:
        self.socket._decref_socketios()  # type: ignore[attr-defined]

    def _wrap_ssl_read(self, len: int, buffer: bytearray | None = None) -> int | bytes:
        """
        Read via ``sslobj.read`` and map a suppressed ragged EOF to end-of-stream.

        :returns: The bytes read, or a byte count when ``buffer`` is given.
        :raises ssl.SSLError: Any TLS error other than a suppressed EOF.
        """
        try:
            return self._ssl_io_loop(self.sslobj.read, len, buffer)
        except ssl.SSLError as e:
            if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs:
                # Report EOF the same way ssl.SSLSocket.read does: a count of
                # 0 when reading into a buffer, otherwise an empty bytes object.
                return 0 if buffer is not None else b""
            raise

    # func is sslobj.do_handshake or sslobj.unwrap
    @typing.overload
    def _ssl_io_loop(self, func: typing.Callable[[], None]) -> None:
        ...

    # func is sslobj.write, arg1 is data
    @typing.overload
    def _ssl_io_loop(self, func: typing.Callable[[bytes], int], arg1: bytes) -> int:
        ...

    # func is sslobj.read, arg1 is len, arg2 is buffer
    @typing.overload
    def _ssl_io_loop(
        self,
        func: typing.Callable[[int, bytearray | None], bytes],
        arg1: int,
        arg2: bytearray | None,
    ) -> bytes:
        ...

    def _ssl_io_loop(
        self,
        func: typing.Callable[..., _ReturnValue],
        arg1: None | bytes | int = None,
        arg2: bytearray | None = None,
    ) -> _ReturnValue:
        """
        Performs an I/O loop between incoming/outgoing and the socket.

        ``func`` is called with whichever of ``arg1``/``arg2`` are not None.
        After every attempt, any bytes queued in ``outgoing`` are sent on the
        socket. A want-read error pulls up to SSL_BLOCKSIZE bytes from the
        socket into ``incoming`` (an empty recv marks EOF on the BIO). A
        want-write error simply retries. Any other ``ssl.SSLError`` propagates.

        :returns: The value returned by the successful ``func`` call.
        """
        should_loop = True
        ret = None

        while should_loop:
            errno = None
            try:
                if arg1 is None and arg2 is None:
                    ret = func()
                elif arg2 is None:
                    ret = func(arg1)
                else:
                    ret = func(arg1, arg2)
            except ssl.SSLError as e:
                if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
                    # WANT_READ, and WANT_WRITE are expected, others are not.
                    raise
                errno = e.errno

            # Flush whatever the TLS object produced (handshake messages,
            # records, alerts); skip the socket call when nothing is pending.
            buf = self.outgoing.read()
            if buf:
                self.socket.sendall(buf)

            if errno is None:
                should_loop = False
            elif errno == ssl.SSL_ERROR_WANT_READ:
                buf = self.socket.recv(SSL_BLOCKSIZE)
                if buf:
                    self.incoming.write(buf)
                else:
                    self.incoming.write_eof()
        # String target keeps the cast purely static (no runtime name lookup).
        return typing.cast("_ReturnValue", ret)