Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/urllib3/util/ssltransport.py: 30%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

159 statements  

1from __future__ import annotations 

2 

3import io 

4import socket 

5import ssl 

6import typing 

7 

8from ..exceptions import ProxySchemeUnsupported 

9 

10if typing.TYPE_CHECKING: 

11 from typing_extensions import Self 

12 

13 from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT 

14 

15 

# Writable buffer types accepted by SSLTransport.recv_into().
_WriteBuffer = typing.Union[bytearray, memoryview]
# Generic result type threaded through SSLTransport._ssl_io_loop().
_ReturnValue = typing.TypeVar("_ReturnValue")

# Upper bound on bytes pulled from the underlying socket per recv() in the
# TLS I/O loop (16 KiB; presumably chosen to match the maximum TLS record size).
SSL_BLOCKSIZE = 16 * 1024

20 

21 

22class SSLTransport: 

23 """ 

24 The SSLTransport wraps an existing socket and establishes an SSL connection. 

25 

26 Contrary to Python's implementation of SSLSocket, it allows you to chain 

27 multiple TLS connections together. It's particularly useful if you need to 

28 implement TLS within TLS. 

29 

30 The class supports most of the socket API operations. 

31 """ 

32 

33 @staticmethod 

34 def _validate_ssl_context_for_tls_in_tls(ssl_context: ssl.SSLContext) -> None: 

35 """ 

36 Raises a ProxySchemeUnsupported if the provided ssl_context can't be used 

37 for TLS in TLS. 

38 

39 The only requirement is that the ssl_context provides the 'wrap_bio' 

40 methods. 

41 """ 

42 

43 if not hasattr(ssl_context, "wrap_bio"): 

44 raise ProxySchemeUnsupported( 

45 "TLS in TLS requires SSLContext.wrap_bio() which isn't " 

46 "available on non-native SSLContext" 

47 ) 

48 

49 def __init__( 

50 self, 

51 socket: socket.socket, 

52 ssl_context: ssl.SSLContext, 

53 server_hostname: str | None = None, 

54 suppress_ragged_eofs: bool = True, 

55 ) -> None: 

56 """ 

57 Create an SSLTransport around socket using the provided ssl_context. 

58 """ 

59 self.incoming = ssl.MemoryBIO() 

60 self.outgoing = ssl.MemoryBIO() 

61 

62 self.suppress_ragged_eofs = suppress_ragged_eofs 

63 self.socket = socket 

64 

65 self.sslobj = ssl_context.wrap_bio( 

66 self.incoming, self.outgoing, server_hostname=server_hostname 

67 ) 

68 

69 # Perform initial handshake. 

70 self._ssl_io_loop(self.sslobj.do_handshake) 

71 

72 def __enter__(self) -> Self: 

73 return self 

74 

75 def __exit__(self, *_: typing.Any) -> None: 

76 self.close() 

77 

78 def fileno(self) -> int: 

79 return self.socket.fileno() 

80 

81 def read(self, len: int = 1024, buffer: typing.Any | None = None) -> int | bytes: 

82 return self._wrap_ssl_read(len, buffer) 

83 

84 def recv(self, buflen: int = 1024, flags: int = 0) -> int | bytes: 

85 if flags != 0: 

86 raise ValueError("non-zero flags not allowed in calls to recv") 

87 return self._wrap_ssl_read(buflen) 

88 

89 def recv_into( 

90 self, 

91 buffer: _WriteBuffer, 

92 nbytes: int | None = None, 

93 flags: int = 0, 

94 ) -> None | int | bytes: 

95 if flags != 0: 

96 raise ValueError("non-zero flags not allowed in calls to recv_into") 

97 if nbytes is None: 

98 nbytes = len(buffer) 

99 return self.read(nbytes, buffer) 

100 

101 def sendall(self, data: bytes, flags: int = 0) -> None: 

102 if flags != 0: 

103 raise ValueError("non-zero flags not allowed in calls to sendall") 

104 count = 0 

105 with memoryview(data) as view, view.cast("B") as byte_view: 

106 amount = len(byte_view) 

107 while count < amount: 

108 v = self.send(byte_view[count:]) 

109 count += v 

110 

111 def send(self, data: bytes, flags: int = 0) -> int: 

112 if flags != 0: 

113 raise ValueError("non-zero flags not allowed in calls to send") 

114 return self._ssl_io_loop(self.sslobj.write, data) 

115 

116 def makefile( 

117 self, 

118 mode: str, 

119 buffering: int | None = None, 

120 *, 

121 encoding: str | None = None, 

122 errors: str | None = None, 

123 newline: str | None = None, 

124 ) -> typing.BinaryIO | typing.TextIO | socket.SocketIO: 

125 """ 

126 Python's httpclient uses makefile and buffered io when reading HTTP 

127 messages and we need to support it. 

128 

129 This is unfortunately a copy and paste of socket.py makefile with small 

130 changes to point to the socket directly. 

131 """ 

132 if not set(mode) <= {"r", "w", "b"}: 

133 raise ValueError(f"invalid mode {mode!r} (only r, w, b allowed)") 

134 

135 writing = "w" in mode 

136 reading = "r" in mode or not writing 

137 assert reading or writing 

138 binary = "b" in mode 

139 rawmode = "" 

140 if reading: 

141 rawmode += "r" 

142 if writing: 

143 rawmode += "w" 

144 raw = socket.SocketIO(self, rawmode) # type: ignore[arg-type] 

145 self.socket._io_refs += 1 # type: ignore[attr-defined] 

146 if buffering is None: 

147 buffering = -1 

148 if buffering < 0: 

149 buffering = io.DEFAULT_BUFFER_SIZE 

150 if buffering == 0: 

151 if not binary: 

152 raise ValueError("unbuffered streams must be binary") 

153 return raw 

154 buffer: typing.BinaryIO 

155 if reading and writing: 

156 buffer = io.BufferedRWPair(raw, raw, buffering) # type: ignore[assignment] 

157 elif reading: 

158 buffer = io.BufferedReader(raw, buffering) 

159 else: 

160 assert writing 

161 buffer = io.BufferedWriter(raw, buffering) 

162 if binary: 

163 return buffer 

164 text = io.TextIOWrapper(buffer, encoding, errors, newline) 

165 text.mode = mode # type: ignore[misc] 

166 return text 

167 

168 def unwrap(self) -> None: 

169 self._ssl_io_loop(self.sslobj.unwrap) 

170 

171 def close(self) -> None: 

172 self.socket.close() 

173 

174 @typing.overload 

175 def getpeercert( 

176 self, binary_form: typing.Literal[False] = ... 

177 ) -> _TYPE_PEER_CERT_RET_DICT | None: 

178 ... 

179 

180 @typing.overload 

181 def getpeercert(self, binary_form: typing.Literal[True]) -> bytes | None: 

182 ... 

183 

184 def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET: 

185 return self.sslobj.getpeercert(binary_form) # type: ignore[return-value] 

186 

187 def version(self) -> str | None: 

188 return self.sslobj.version() 

189 

190 def cipher(self) -> tuple[str, str, int] | None: 

191 return self.sslobj.cipher() 

192 

193 def selected_alpn_protocol(self) -> str | None: 

194 return self.sslobj.selected_alpn_protocol() 

195 

196 def selected_npn_protocol(self) -> str | None: 

197 return self.sslobj.selected_npn_protocol() 

198 

199 def shared_ciphers(self) -> list[tuple[str, str, int]] | None: 

200 return self.sslobj.shared_ciphers() 

201 

202 def compression(self) -> str | None: 

203 return self.sslobj.compression() 

204 

205 def settimeout(self, value: float | None) -> None: 

206 self.socket.settimeout(value) 

207 

208 def gettimeout(self) -> float | None: 

209 return self.socket.gettimeout() 

210 

211 def _decref_socketios(self) -> None: 

212 self.socket._decref_socketios() # type: ignore[attr-defined] 

213 

214 def _wrap_ssl_read(self, len: int, buffer: bytearray | None = None) -> int | bytes: 

215 try: 

216 return self._ssl_io_loop(self.sslobj.read, len, buffer) 

217 except ssl.SSLError as e: 

218 if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs: 

219 return 0 # eof, return 0. 

220 else: 

221 raise 

222 

223 # func is sslobj.do_handshake or sslobj.unwrap 

224 @typing.overload 

225 def _ssl_io_loop(self, func: typing.Callable[[], None]) -> None: 

226 ... 

227 

228 # func is sslobj.write, arg1 is data 

229 @typing.overload 

230 def _ssl_io_loop(self, func: typing.Callable[[bytes], int], arg1: bytes) -> int: 

231 ... 

232 

233 # func is sslobj.read, arg1 is len, arg2 is buffer 

234 @typing.overload 

235 def _ssl_io_loop( 

236 self, 

237 func: typing.Callable[[int, bytearray | None], bytes], 

238 arg1: int, 

239 arg2: bytearray | None, 

240 ) -> bytes: 

241 ... 

242 

243 def _ssl_io_loop( 

244 self, 

245 func: typing.Callable[..., _ReturnValue], 

246 arg1: None | bytes | int = None, 

247 arg2: bytearray | None = None, 

248 ) -> _ReturnValue: 

249 """Performs an I/O loop between incoming/outgoing and the socket.""" 

250 should_loop = True 

251 ret = None 

252 

253 while should_loop: 

254 errno = None 

255 try: 

256 if arg1 is None and arg2 is None: 

257 ret = func() 

258 elif arg2 is None: 

259 ret = func(arg1) 

260 else: 

261 ret = func(arg1, arg2) 

262 except ssl.SSLError as e: 

263 if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): 

264 # WANT_READ, and WANT_WRITE are expected, others are not. 

265 raise e 

266 errno = e.errno 

267 

268 buf = self.outgoing.read() 

269 self.socket.sendall(buf) 

270 

271 if errno is None: 

272 should_loop = False 

273 elif errno == ssl.SSL_ERROR_WANT_READ: 

274 buf = self.socket.recv(SSL_BLOCKSIZE) 

275 if buf: 

276 self.incoming.write(buf) 

277 else: 

278 self.incoming.write_eof() 

279 return typing.cast(_ReturnValue, ret)