Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/urllib3/util/ssltransport.py: 30%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

152 statements  

1from __future__ import annotations 

2 

3import io 

4import socket 

5import ssl 

6import typing 

7 

8from ..exceptions import ProxySchemeUnsupported 

9 

10if typing.TYPE_CHECKING: 

11 from typing_extensions import Self 

12 

13 from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT 

14 

15 

# Writable buffer types accepted by ``SSLTransport.recv_into``.
_WriteBuffer = typing.Union[bytearray, memoryview]
# Generic result type threaded through ``SSLTransport._ssl_io_loop`` so each
# wrapped SSL-object call keeps its own return type.
_ReturnValue = typing.TypeVar("_ReturnValue")

# Maximum number of bytes requested from the underlying socket each time the
# SSL object asks for more input (16 KiB, equal to 2**14).
SSL_BLOCKSIZE = 16 * 1024

20 

21 

class SSLTransport:
    """
    The SSLTransport wraps an existing socket and establishes an SSL connection.

    Contrary to Python's implementation of SSLSocket, it allows you to chain
    multiple TLS connections together. It's particularly useful if you need to
    implement TLS within TLS.

    The class supports most of the socket API operations.

    Encrypted TLS records are exchanged through a pair of ``ssl.MemoryBIO``
    buffers (``incoming`` / ``outgoing``). ``_ssl_io_loop`` moves bytes
    between those buffers and the wrapped ``socket``.
    """

    @staticmethod
    def _validate_ssl_context_for_tls_in_tls(ssl_context: ssl.SSLContext) -> None:
        """
        Raises a ProxySchemeUnsupported if the provided ssl_context can't be used
        for TLS in TLS.

        The only requirement is that the ssl_context provides the 'wrap_bio'
        method.
        """

        if not hasattr(ssl_context, "wrap_bio"):
            raise ProxySchemeUnsupported(
                "TLS in TLS requires SSLContext.wrap_bio() which isn't "
                "available on non-native SSLContext"
            )

    def __init__(
        self,
        socket: socket.socket,
        ssl_context: ssl.SSLContext,
        server_hostname: str | None = None,
        suppress_ragged_eofs: bool = True,
    ) -> None:
        """
        Create an SSLTransport around socket using the provided ssl_context.

        :param socket: The underlying socket that carries the encrypted records.
        :param ssl_context: Context whose ``wrap_bio`` creates the SSL object.
        :param server_hostname: Passed through to ``wrap_bio``.
        :param suppress_ragged_eofs: If true, an ``SSL_ERROR_EOF`` during a read
            is reported as end-of-stream instead of being raised.

        The TLS handshake is performed immediately by the constructor.
        """
        self.incoming = ssl.MemoryBIO()
        self.outgoing = ssl.MemoryBIO()

        self.suppress_ragged_eofs = suppress_ragged_eofs
        self.socket = socket

        self.sslobj = ssl_context.wrap_bio(
            self.incoming, self.outgoing, server_hostname=server_hostname
        )

        # Perform initial handshake.
        self._ssl_io_loop(self.sslobj.do_handshake)

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *_: typing.Any) -> None:
        self.close()

    def fileno(self) -> int:
        """Return the file descriptor of the underlying socket."""
        return self.socket.fileno()

    def read(self, len: int = 1024, buffer: typing.Any | None = None) -> int | bytes:
        """
        Read up to ``len`` bytes of decrypted data.

        When ``buffer`` is given, the data is read into it and the number of
        bytes read is returned; otherwise the data is returned as ``bytes``.
        """
        return self._wrap_ssl_read(len, buffer)

    def recv(self, buflen: int = 1024, flags: int = 0) -> int | bytes:
        """
        Receive up to ``buflen`` bytes of decrypted data.

        Only ``flags=0`` is supported.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv")
        return self._wrap_ssl_read(buflen)

    def recv_into(
        self,
        buffer: _WriteBuffer,
        nbytes: int | None = None,
        flags: int = 0,
    ) -> None | int | bytes:
        """
        Receive decrypted data into ``buffer`` and return the byte count.

        ``nbytes`` defaults to the whole buffer. Only ``flags=0`` is supported.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv_into")
        if nbytes is None:
            nbytes = len(buffer)
        return self.read(nbytes, buffer)

    def sendall(self, data: bytes, flags: int = 0) -> None:
        """
        Send all of ``data``, calling ``send`` until every byte is consumed.

        Only ``flags=0`` is supported.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to sendall")
        count = 0
        # Cast to unsigned bytes so slicing and len() work in bytes even for
        # multi-byte-itemsize buffers.
        with memoryview(data) as view, view.cast("B") as byte_view:
            amount = len(byte_view)
            while count < amount:
                v = self.send(byte_view[count:])
                count += v

    def send(self, data: bytes, flags: int = 0) -> int:
        """
        Encrypt and send ``data``, returning the number of bytes written.

        Only ``flags=0`` is supported.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to send")
        return self._ssl_io_loop(self.sslobj.write, data)

    def makefile(
        self,
        mode: str,
        buffering: int | None = None,
        *,
        encoding: str | None = None,
        errors: str | None = None,
        newline: str | None = None,
    ) -> typing.BinaryIO | typing.TextIO | socket.SocketIO:
        """
        Python's http.client uses makefile and buffered io when reading HTTP
        messages and we need to support it.

        This is unfortunately a copy and paste of socket.py makefile with small
        changes to point to the socket directly.

        :param mode: Any combination of ``"r"``, ``"w"`` and ``"b"``.
        :param buffering: ``0`` for an unbuffered (binary-only) stream; ``None``
            or a negative value selects ``io.DEFAULT_BUFFER_SIZE``.
        :raises ValueError: for an unsupported mode, or an unbuffered text
            stream.
        """
        if not set(mode) <= {"r", "w", "b"}:
            raise ValueError(f"invalid mode {mode!r} (only r, w, b allowed)")

        writing = "w" in mode
        reading = "r" in mode or not writing
        assert reading or writing
        binary = "b" in mode
        rawmode = ""
        if reading:
            rawmode += "r"
        if writing:
            rawmode += "w"
        raw = socket.SocketIO(self, rawmode)  # type: ignore[arg-type]
        # SocketIO.close() calls back into _decref_socketios, so the
        # underlying socket's reference count must be bumped to match.
        self.socket._io_refs += 1  # type: ignore[attr-defined]
        if buffering is None:
            buffering = -1
        if buffering < 0:
            buffering = io.DEFAULT_BUFFER_SIZE
        if buffering == 0:
            if not binary:
                raise ValueError("unbuffered streams must be binary")
            return raw
        buffer: typing.BinaryIO
        if reading and writing:
            buffer = io.BufferedRWPair(raw, raw, buffering)  # type: ignore[assignment]
        elif reading:
            buffer = io.BufferedReader(raw, buffering)
        else:
            assert writing
            buffer = io.BufferedWriter(raw, buffering)
        if binary:
            return buffer
        text = io.TextIOWrapper(buffer, encoding, errors, newline)
        text.mode = mode  # type: ignore[misc]
        return text

    def unwrap(self) -> None:
        """Perform the TLS shutdown on the SSL object, driving the needed I/O."""
        self._ssl_io_loop(self.sslobj.unwrap)

    def close(self) -> None:
        """Close the underlying socket."""
        self.socket.close()

    @typing.overload
    def getpeercert(
        self, binary_form: typing.Literal[False] = ...
    ) -> _TYPE_PEER_CERT_RET_DICT | None: ...

    @typing.overload
    def getpeercert(self, binary_form: typing.Literal[True]) -> bytes | None: ...

    def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET:
        """Return the peer certificate as reported by the SSL object."""
        return self.sslobj.getpeercert(binary_form)  # type: ignore[return-value]

    def version(self) -> str | None:
        """Return the negotiated protocol version from the SSL object."""
        return self.sslobj.version()

    def cipher(self) -> tuple[str, str, int] | None:
        """Return the negotiated cipher from the SSL object."""
        return self.sslobj.cipher()

    def selected_alpn_protocol(self) -> str | None:
        """Return the ALPN protocol selected during the handshake, if any."""
        return self.sslobj.selected_alpn_protocol()

    def shared_ciphers(self) -> list[tuple[str, str, int]] | None:
        """Return the ciphers shared with the peer, as reported by the SSL object."""
        return self.sslobj.shared_ciphers()

    def compression(self) -> str | None:
        """Return the compression algorithm in use, as reported by the SSL object."""
        return self.sslobj.compression()

    def settimeout(self, value: float | None) -> None:
        """Set the timeout on the underlying socket."""
        self.socket.settimeout(value)

    def gettimeout(self) -> float | None:
        """Return the timeout of the underlying socket."""
        return self.socket.gettimeout()

    def _decref_socketios(self) -> None:
        """Forward SocketIO close notifications to the underlying socket."""
        self.socket._decref_socketios()  # type: ignore[attr-defined]

    def _wrap_ssl_read(self, len: int, buffer: bytearray | None = None) -> int | bytes:
        """
        Read decrypted data, translating a suppressed ragged EOF.

        An ``SSLError`` with ``SSL_ERROR_EOF`` (typically the peer closing the
        connection without a TLS close_notify) is reported as end-of-stream
        when ``suppress_ragged_eofs`` is set. The return value has the same
        shape as a successful read: ``0`` when reading into ``buffer``, and
        ``b""`` otherwise (mirroring ``ssl.SSLSocket.read``). Any other
        ``SSLError`` is re-raised.
        """
        try:
            return self._ssl_io_loop(self.sslobj.read, len, buffer)
        except ssl.SSLError as e:
            if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs:
                # Return a byte count for buffered reads and empty bytes for
                # plain reads, so recv()/read() callers always get bytes.
                return 0 if buffer is not None else b""
            raise

    # func is sslobj.do_handshake or sslobj.unwrap
    @typing.overload
    def _ssl_io_loop(self, func: typing.Callable[[], None]) -> None: ...

    # func is sslobj.write, arg1 is data
    @typing.overload
    def _ssl_io_loop(self, func: typing.Callable[[bytes], int], arg1: bytes) -> int: ...

    # func is sslobj.read, arg1 is len, arg2 is buffer
    @typing.overload
    def _ssl_io_loop(
        self,
        func: typing.Callable[[int, bytearray | None], bytes],
        arg1: int,
        arg2: bytearray | None,
    ) -> bytes: ...

    def _ssl_io_loop(
        self,
        func: typing.Callable[..., _ReturnValue],
        arg1: None | bytes | int = None,
        arg2: bytearray | None = None,
    ) -> _ReturnValue:
        """
        Performs an I/O loop between incoming/outgoing and the socket.

        ``func`` is called with whichever of ``arg1``/``arg2`` are not None and
        is retried until it finishes without raising WANT_READ/WANT_WRITE.
        After every attempt, records the SSL object queued in ``outgoing`` are
        flushed to the socket. On WANT_READ, up to ``SSL_BLOCKSIZE`` bytes are
        read from the socket into ``incoming``; an empty read marks
        ``incoming`` as EOF. Any other ``SSLError`` propagates.
        """
        should_loop = True
        ret = None

        while should_loop:
            errno = None
            try:
                if arg1 is None and arg2 is None:
                    ret = func()
                elif arg2 is None:
                    ret = func(arg1)
                else:
                    ret = func(arg1, arg2)
            except ssl.SSLError as e:
                if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
                    # WANT_READ and WANT_WRITE are expected; others are not.
                    raise
                errno = e.errno

            # Flush whatever the SSL object produced (handshake messages,
            # encrypted application data, alerts) before deciding what next.
            buf = self.outgoing.read()
            self.socket.sendall(buf)

            if errno is None:
                should_loop = False
            elif errno == ssl.SSL_ERROR_WANT_READ:
                buf = self.socket.recv(SSL_BLOCKSIZE)
                if buf:
                    self.incoming.write(buf)
                else:
                    # Socket closed: tell the SSL object no more input comes.
                    self.incoming.write_eof()
        return typing.cast(_ReturnValue, ret)