Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.10/site-packages/urllib3/util/ssltransport.py: 29%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

157 statements  

1from __future__ import annotations 

2 

3import io 

4import socket 

5import ssl 

6import typing 

7 

8from ..exceptions import ProxySchemeUnsupported 

9 

10if typing.TYPE_CHECKING: 

11 from typing_extensions import Self 

12 

13 from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT 

14 

15 

# Writable buffer types that SSLTransport.recv_into() accepts.
_WriteBuffer = typing.Union[bytearray, memoryview]
# Generic result type of the callable driven by SSLTransport._ssl_io_loop().
_ReturnValue = typing.TypeVar("_ReturnValue")

# Number of bytes requested from the underlying socket per recv() while
# feeding ciphertext to the TLS engine (16 KiB).
SSL_BLOCKSIZE = 16 * 1024

20 

21 

class SSLTransport:
    """
    The SSLTransport wraps an existing socket and establishes an SSL connection.

    Contrary to Python's implementation of SSLSocket, it allows you to chain
    multiple TLS connections together. It's particularly useful if you need to
    implement TLS within TLS.

    The class supports most of the socket API operations.

    Internally the TLS state machine is an ``ssl.SSLObject`` obtained from
    ``SSLContext.wrap_bio()``. It consumes ciphertext from the ``incoming``
    memory BIO and produces ciphertext into the ``outgoing`` memory BIO;
    :meth:`_ssl_io_loop` moves those bytes to and from ``self.socket``.
    """

    @staticmethod
    def _validate_ssl_context_for_tls_in_tls(ssl_context: ssl.SSLContext) -> None:
        """
        Raises a ProxySchemeUnsupported if the provided ssl_context can't be used
        for TLS in TLS.

        The only requirement is that the ssl_context provides the 'wrap_bio'
        method.
        """
        if not hasattr(ssl_context, "wrap_bio"):
            raise ProxySchemeUnsupported(
                "TLS in TLS requires SSLContext.wrap_bio() which isn't "
                "available on non-native SSLContext"
            )

    def __init__(
        self,
        socket: socket.socket,
        ssl_context: ssl.SSLContext,
        server_hostname: str | None = None,
        suppress_ragged_eofs: bool = True,
    ) -> None:
        """
        Create an SSLTransport around socket using the provided ssl_context.

        The TLS handshake is run before the constructor returns, so handshake
        errors are raised from here.

        :param socket: Connected socket that carries the encrypted records.
        :param ssl_context: Context whose ``wrap_bio()`` creates the TLS object.
        :param server_hostname: Forwarded to ``wrap_bio()``.
        :param suppress_ragged_eofs: If true, an ``SSL_ERROR_EOF`` raised while
            reading is reported as an empty read (``b""``, or ``0`` when a
            buffer was supplied) instead of being re-raised.
        """
        self.incoming = ssl.MemoryBIO()
        self.outgoing = ssl.MemoryBIO()

        self.suppress_ragged_eofs = suppress_ragged_eofs
        self.socket = socket

        self.sslobj = ssl_context.wrap_bio(
            self.incoming, self.outgoing, server_hostname=server_hostname
        )

        # Perform initial handshake.
        self._ssl_io_loop(self.sslobj.do_handshake)

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *_: typing.Any) -> None:
        self.close()

    def fileno(self) -> int:
        """Return the file descriptor of the wrapped socket."""
        return self.socket.fileno()

    def read(self, len: int = 1024, buffer: typing.Any | None = None) -> int | bytes:
        """
        Read up to ``len`` bytes of decrypted data.

        Returns the bytes read or, when ``buffer`` is given, the number of
        bytes written into it. A suppressed ragged EOF yields ``b""`` or ``0``
        respectively.
        """
        return self._wrap_ssl_read(len, buffer)

    def recv(self, buflen: int = 1024, flags: int = 0) -> int | bytes:
        """Receive up to ``buflen`` decrypted bytes. ``flags`` must be 0."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv")
        return self._wrap_ssl_read(buflen)

    def recv_into(
        self,
        buffer: _WriteBuffer,
        nbytes: int | None = None,
        flags: int = 0,
    ) -> None | int | bytes:
        """
        Receive decrypted data into ``buffer`` and return the byte count.

        ``nbytes`` defaults to ``len(buffer)``; ``flags`` must be 0.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv_into")
        if nbytes is None:
            nbytes = len(buffer)
        return self.read(nbytes, buffer)

    def sendall(self, data: bytes, flags: int = 0) -> None:
        """
        Encrypt and send all of ``data``. ``flags`` must be 0.

        ``data`` is viewed as a flat run of unsigned bytes (``cast("B")``) so
        that the running offset is a byte count for any buffer format.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to sendall")
        count = 0
        with memoryview(data) as view, view.cast("B") as byte_view:
            amount = len(byte_view)
            # send() reports how many bytes it consumed; repeat until the
            # whole buffer has been handed to the TLS layer.
            while count < amount:
                v = self.send(byte_view[count:])
                count += v

    def send(self, data: bytes, flags: int = 0) -> int:
        """
        Encrypt ``data``, flush the resulting records to the socket and return
        the number of bytes accepted by ``SSLObject.write()``. ``flags`` must
        be 0.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to send")
        return self._ssl_io_loop(self.sslobj.write, data)

    def makefile(
        self,
        mode: str,
        buffering: int | None = None,
        *,
        encoding: str | None = None,
        errors: str | None = None,
        newline: str | None = None,
    ) -> typing.BinaryIO | typing.TextIO | socket.SocketIO:
        """
        Python's httpclient uses makefile and buffered io when reading HTTP
        messages and we need to support it.

        This is unfortunately a copy and paste of socket.py makefile with small
        changes to point to the socket directly.
        """
        if not set(mode) <= {"r", "w", "b"}:
            raise ValueError(f"invalid mode {mode!r} (only r, w, b allowed)")

        writing = "w" in mode
        reading = "r" in mode or not writing
        assert reading or writing
        binary = "b" in mode
        rawmode = ""
        if reading:
            rawmode += "r"
        if writing:
            rawmode += "w"
        # SocketIO wraps *this* object so reads/writes go through TLS.
        raw = socket.SocketIO(self, rawmode)  # type: ignore[arg-type]
        # Balanced by _decref_socketios() when the SocketIO is closed.
        self.socket._io_refs += 1  # type: ignore[attr-defined]
        if buffering is None:
            buffering = -1
        if buffering < 0:
            buffering = io.DEFAULT_BUFFER_SIZE
        if buffering == 0:
            if not binary:
                raise ValueError("unbuffered streams must be binary")
            return raw
        buffer: typing.BinaryIO
        if reading and writing:
            buffer = io.BufferedRWPair(raw, raw, buffering)  # type: ignore[assignment]
        elif reading:
            buffer = io.BufferedReader(raw, buffering)
        else:
            assert writing
            buffer = io.BufferedWriter(raw, buffering)
        if binary:
            return buffer
        text = io.TextIOWrapper(buffer, encoding, errors, newline)
        text.mode = mode  # type: ignore[misc]
        return text

    def unwrap(self) -> None:
        """Run ``SSLObject.unwrap()`` (TLS shutdown) over the wrapped socket."""
        self._ssl_io_loop(self.sslobj.unwrap)

    def close(self) -> None:
        """Close the wrapped socket; no TLS shutdown is performed here."""
        self.socket.close()

    @typing.overload
    def getpeercert(
        self, binary_form: typing.Literal[False] = ...
    ) -> _TYPE_PEER_CERT_RET_DICT | None:
        ...

    @typing.overload
    def getpeercert(self, binary_form: typing.Literal[True]) -> bytes | None:
        ...

    def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET:
        """Delegate to ``SSLObject.getpeercert()``."""
        return self.sslobj.getpeercert(binary_form)  # type: ignore[return-value]

    def version(self) -> str | None:
        """Delegate to ``SSLObject.version()``."""
        return self.sslobj.version()

    def cipher(self) -> tuple[str, str, int] | None:
        """Delegate to ``SSLObject.cipher()``."""
        return self.sslobj.cipher()

    def selected_alpn_protocol(self) -> str | None:
        """Delegate to ``SSLObject.selected_alpn_protocol()``."""
        return self.sslobj.selected_alpn_protocol()

    def shared_ciphers(self) -> list[tuple[str, str, int]] | None:
        """Delegate to ``SSLObject.shared_ciphers()``."""
        return self.sslobj.shared_ciphers()

    def compression(self) -> str | None:
        """Delegate to ``SSLObject.compression()``."""
        return self.sslobj.compression()

    def settimeout(self, value: float | None) -> None:
        """Set the timeout of the wrapped socket."""
        self.socket.settimeout(value)

    def gettimeout(self) -> float | None:
        """Return the timeout of the wrapped socket."""
        return self.socket.gettimeout()

    def _decref_socketios(self) -> None:
        # Forwarded to the real socket to undo the _io_refs increment made
        # in makefile().
        self.socket._decref_socketios()  # type: ignore[attr-defined]

    def _wrap_ssl_read(self, len: int, buffer: bytearray | None = None) -> int | bytes:
        """
        Read through the TLS layer, optionally converting a ragged EOF into an
        empty read.

        Mirrors ``ssl.SSLSocket.read()``: with ``suppress_ragged_eofs`` set, an
        ``SSL_ERROR_EOF`` yields ``0`` when reading into ``buffer`` and ``b""``
        otherwise, so ``recv()``/``read()`` always return bytes.
        """
        try:
            return self._ssl_io_loop(self.sslobj.read, len, buffer)
        except ssl.SSLError as e:
            if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs:
                # EOF: report an empty read of the kind the caller asked for.
                return 0 if buffer is not None else b""
            raise

    # func is sslobj.do_handshake or sslobj.unwrap
    @typing.overload
    def _ssl_io_loop(self, func: typing.Callable[[], None]) -> None:
        ...

    # func is sslobj.write, arg1 is data
    @typing.overload
    def _ssl_io_loop(self, func: typing.Callable[[bytes], int], arg1: bytes) -> int:
        ...

    # func is sslobj.read, arg1 is len, arg2 is buffer
    @typing.overload
    def _ssl_io_loop(
        self,
        func: typing.Callable[[int, bytearray | None], bytes],
        arg1: int,
        arg2: bytearray | None,
    ) -> bytes:
        ...

    def _ssl_io_loop(
        self,
        func: typing.Callable[..., _ReturnValue],
        arg1: None | bytes | int = None,
        arg2: bytearray | None = None,
    ) -> _ReturnValue:
        """
        Performs an I/O loop between incoming/outgoing and the socket.

        ``func`` is called with the non-None arguments among ``arg1``/``arg2``.
        After every attempt the ciphertext queued in ``outgoing`` is sent on
        the socket. On ``SSL_ERROR_WANT_READ`` up to ``SSL_BLOCKSIZE`` bytes
        are received into ``incoming`` and ``func`` is retried; on
        ``SSL_ERROR_WANT_WRITE`` it is retried after the flush. Any other
        ``ssl.SSLError`` propagates. Returns ``func``'s result.
        """
        should_loop = True
        ret = None

        while should_loop:
            errno = None
            try:
                if arg1 is None and arg2 is None:
                    ret = func()
                elif arg2 is None:
                    ret = func(arg1)
                else:
                    ret = func(arg1, arg2)
            except ssl.SSLError as e:
                if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
                    # WANT_READ, and WANT_WRITE are expected, others are not.
                    raise
                errno = e.errno

            # Always flush: the call may have produced records (handshake
            # messages, application data, close_notify) the peer must see.
            buf = self.outgoing.read()
            self.socket.sendall(buf)

            if errno is None:
                should_loop = False
            elif errno == ssl.SSL_ERROR_WANT_READ:
                buf = self.socket.recv(SSL_BLOCKSIZE)
                if buf:
                    self.incoming.write(buf)
                else:
                    # Socket EOF: mark the BIO so the retried call reports EOF
                    # instead of asking for more input again.
                    self.incoming.write_eof()
        return typing.cast(_ReturnValue, ret)