Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/urllib3/util/ssltransport.py: 30%

160 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-08 06:40 +0000

1from __future__ import annotations 

2 

3import io 

4import socket 

5import ssl 

6import typing 

7 

8from ..exceptions import ProxySchemeUnsupported 

9 

10if typing.TYPE_CHECKING: 

11 from typing import Literal 

12 

13 from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT 

14 

15 

# Typing helpers for SSLTransport's signatures.
_ReturnValue = typing.TypeVar("_ReturnValue")
_SelfT = typing.TypeVar("_SelfT", bound="SSLTransport")
_WriteBuffer = typing.Union[bytearray, memoryview]

# Number of bytes requested from the wrapped socket each time OpenSSL
# reports that it needs more input (16 KiB).
SSL_BLOCKSIZE = 16 * 1024

21 

22 

class SSLTransport:
    """
    The SSLTransport wraps an existing socket and establishes an SSL connection.

    Contrary to Python's implementation of SSLSocket, it allows you to chain
    multiple TLS connections together. It's particularly useful if you need to
    implement TLS within TLS.

    The class supports most of the socket API operations.

    TLS data is produced and consumed through a pair of in-memory BIOs
    (``incoming`` / ``outgoing``); ``_ssl_io_loop`` shuttles bytes between
    those buffers and the wrapped ``socket``.
    """

    @staticmethod
    def _validate_ssl_context_for_tls_in_tls(ssl_context: ssl.SSLContext) -> None:
        """
        Raises a ProxySchemeUnsupported if the provided ssl_context can't be used
        for TLS in TLS.

        The only requirement is that the ssl_context provides the 'wrap_bio'
        methods.
        """

        if not hasattr(ssl_context, "wrap_bio"):
            raise ProxySchemeUnsupported(
                "TLS in TLS requires SSLContext.wrap_bio() which isn't "
                "available on non-native SSLContext"
            )

    def __init__(
        self,
        socket: socket.socket,
        ssl_context: ssl.SSLContext,
        server_hostname: str | None = None,
        suppress_ragged_eofs: bool = True,
    ) -> None:
        """
        Create an SSLTransport around socket using the provided ssl_context.

        The TLS handshake is performed before the constructor returns; any
        error raised by the handshake propagates to the caller.
        """
        self.incoming = ssl.MemoryBIO()
        self.outgoing = ssl.MemoryBIO()

        self.suppress_ragged_eofs = suppress_ragged_eofs
        self.socket = socket

        self.sslobj = ssl_context.wrap_bio(
            self.incoming, self.outgoing, server_hostname=server_hostname
        )

        # Perform initial handshake.
        self._ssl_io_loop(self.sslobj.do_handshake)

    def __enter__(self: _SelfT) -> _SelfT:
        return self

    def __exit__(self, *_: typing.Any) -> None:
        self.close()

    def fileno(self) -> int:
        """Return the file descriptor of the wrapped socket."""
        return self.socket.fileno()

    def read(self, len: int = 1024, buffer: typing.Any | None = None) -> int | bytes:
        """
        Read up to ``len`` decrypted bytes.

        Returns the data as ``bytes`` when ``buffer`` is None, otherwise writes
        into ``buffer`` and returns the number of bytes written.
        """
        return self._wrap_ssl_read(len, buffer)

    def recv(self, buflen: int = 1024, flags: int = 0) -> int | bytes:
        """Receive up to ``buflen`` decrypted bytes; ``flags`` must be 0."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv")
        return self._wrap_ssl_read(buflen)

    def recv_into(
        self,
        buffer: _WriteBuffer,
        nbytes: int | None = None,
        flags: int = 0,
    ) -> None | int | bytes:
        """
        Receive decrypted bytes into ``buffer`` and return the count written.

        ``nbytes`` defaults to the size of ``buffer``; ``flags`` must be 0.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv_into")
        if nbytes is None:
            nbytes = len(buffer)
        return self.read(nbytes, buffer)

    def sendall(self, data: bytes, flags: int = 0) -> None:
        """Encrypt and send all of ``data``, calling ``send`` until done."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to sendall")
        count = 0
        # Cast to a flat byte view so slicing/length are in bytes even if
        # ``data`` is a multi-byte-itemsize buffer.
        with memoryview(data) as view, view.cast("B") as byte_view:
            amount = len(byte_view)
            while count < amount:
                v = self.send(byte_view[count:])
                count += v

    def send(self, data: bytes, flags: int = 0) -> int:
        """Encrypt ``data`` and return the number of plaintext bytes consumed."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to send")
        return self._ssl_io_loop(self.sslobj.write, data)

    def makefile(
        self,
        mode: str,
        buffering: int | None = None,
        *,
        encoding: str | None = None,
        errors: str | None = None,
        newline: str | None = None,
    ) -> typing.BinaryIO | typing.TextIO | socket.SocketIO:
        """
        Python's httpclient uses makefile and buffered io when reading HTTP
        messages and we need to support it.

        This is unfortunately a copy and paste of socket.py makefile with small
        changes to point to the socket directly.
        """
        if not set(mode) <= {"r", "w", "b"}:
            raise ValueError(f"invalid mode {mode!r} (only r, w, b allowed)")

        writing = "w" in mode
        reading = "r" in mode or not writing
        assert reading or writing
        binary = "b" in mode
        rawmode = ""
        if reading:
            rawmode += "r"
        if writing:
            rawmode += "w"
        raw = socket.SocketIO(self, rawmode)  # type: ignore[arg-type]
        # SocketIO.close() calls back into _decref_socketios(), which forwards
        # to the wrapped socket, so the reference count is kept there.
        self.socket._io_refs += 1  # type: ignore[attr-defined]
        if buffering is None:
            buffering = -1
        if buffering < 0:
            buffering = io.DEFAULT_BUFFER_SIZE
        if buffering == 0:
            if not binary:
                raise ValueError("unbuffered streams must be binary")
            return raw
        buffer: typing.BinaryIO
        if reading and writing:
            buffer = io.BufferedRWPair(raw, raw, buffering)  # type: ignore[assignment]
        elif reading:
            buffer = io.BufferedReader(raw, buffering)
        else:
            assert writing
            buffer = io.BufferedWriter(raw, buffering)
        if binary:
            return buffer
        text = io.TextIOWrapper(buffer, encoding, errors, newline)
        text.mode = mode  # type: ignore[misc]
        return text

    def unwrap(self) -> None:
        """Perform the TLS shutdown handshake over the wrapped socket."""
        self._ssl_io_loop(self.sslobj.unwrap)

    def close(self) -> None:
        """Close the wrapped socket."""
        self.socket.close()

    @typing.overload
    def getpeercert(
        self, binary_form: Literal[False] = ...
    ) -> _TYPE_PEER_CERT_RET_DICT | None:
        ...

    @typing.overload
    def getpeercert(self, binary_form: Literal[True]) -> bytes | None:
        ...

    def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET:
        return self.sslobj.getpeercert(binary_form)  # type: ignore[return-value]

    def version(self) -> str | None:
        return self.sslobj.version()

    def cipher(self) -> tuple[str, str, int] | None:
        return self.sslobj.cipher()

    def selected_alpn_protocol(self) -> str | None:
        return self.sslobj.selected_alpn_protocol()

    def selected_npn_protocol(self) -> str | None:
        return self.sslobj.selected_npn_protocol()

    def shared_ciphers(self) -> list[tuple[str, str, int]] | None:
        return self.sslobj.shared_ciphers()

    def compression(self) -> str | None:
        return self.sslobj.compression()

    def settimeout(self, value: float | None) -> None:
        self.socket.settimeout(value)

    def gettimeout(self) -> float | None:
        return self.socket.gettimeout()

    def _decref_socketios(self) -> None:
        self.socket._decref_socketios()  # type: ignore[attr-defined]

    def _wrap_ssl_read(self, len: int, buffer: bytearray | None = None) -> int | bytes:
        """
        Read up to ``len`` decrypted bytes, optionally into ``buffer``.

        Returns the number of bytes written when ``buffer`` is given, otherwise
        the bytes read. When ``suppress_ragged_eofs`` is set, an
        ``SSL_ERROR_EOF`` is reported as end-of-stream: ``0`` for buffer reads
        and ``b""`` otherwise, matching ``ssl.SSLSocket.read``.
        """
        try:
            return self._ssl_io_loop(self.sslobj.read, len, buffer)
        except ssl.SSLError as e:
            if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs:
                # recv()-style callers expect bytes; recv_into()-style callers
                # expect a byte count. Returning int 0 to recv() broke the
                # bytes contract.
                return 0 if buffer is not None else b""
            raise

    # func is sslobj.do_handshake or sslobj.unwrap
    @typing.overload
    def _ssl_io_loop(self, func: typing.Callable[[], None]) -> None:
        ...

    # func is sslobj.write, arg1 is data
    @typing.overload
    def _ssl_io_loop(self, func: typing.Callable[[bytes], int], arg1: bytes) -> int:
        ...

    # func is sslobj.read, arg1 is len, arg2 is buffer
    @typing.overload
    def _ssl_io_loop(
        self,
        func: typing.Callable[[int, bytearray | None], bytes],
        arg1: int,
        arg2: bytearray | None,
    ) -> bytes:
        ...

    def _ssl_io_loop(
        self,
        func: typing.Callable[..., _ReturnValue],
        arg1: None | bytes | int = None,
        arg2: bytearray | None = None,
    ) -> _ReturnValue:
        """
        Performs an I/O loop between incoming/outgoing and the socket.

        ``func`` is called repeatedly. After each attempt, any pending TLS
        output in ``outgoing`` is flushed to the socket. On
        ``SSL_ERROR_WANT_READ``, up to ``SSL_BLOCKSIZE`` bytes are received from
        the socket into ``incoming``; an empty receive is signalled to the BIO
        as EOF. The loop then retries. On ``SSL_ERROR_WANT_WRITE`` it retries
        after the flush. Any other ``ssl.SSLError`` propagates.
        """
        should_loop = True
        ret = None

        while should_loop:
            errno = None
            try:
                if arg1 is None and arg2 is None:
                    ret = func()
                elif arg2 is None:
                    ret = func(arg1)
                else:
                    ret = func(arg1, arg2)
            except ssl.SSLError as e:
                if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
                    # WANT_READ, and WANT_WRITE are expected, others are not.
                    raise
                errno = e.errno

            buf = self.outgoing.read()
            self.socket.sendall(buf)

            if errno is None:
                should_loop = False
            elif errno == ssl.SSL_ERROR_WANT_READ:
                buf = self.socket.recv(SSL_BLOCKSIZE)
                if buf:
                    self.incoming.write(buf)
                else:
                    self.incoming.write_eof()
        return typing.cast(_ReturnValue, ret)