Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/anyio/streams/tls.py: 39%

139 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-25 06:38 +0000

1from __future__ import annotations 

2 

3import logging 

4import re 

5import ssl 

6from collections.abc import Callable, Mapping 

7from dataclasses import dataclass 

8from functools import wraps 

9from typing import Any, Tuple, TypeVar 

10 

11from .. import ( 

12 BrokenResourceError, 

13 EndOfStream, 

14 aclose_forcefully, 

15 get_cancelled_exc_class, 

16) 

17from .._core._typedattr import TypedAttributeSet, typed_attribute 

18from ..abc import AnyByteStream, ByteStream, Listener, TaskGroup 

19 

# Type variable for the value returned by whichever ssl.SSLObject method
# TLSStream._call_sslobject_method is asked to invoke; the helper returns that
# value unchanged once the call succeeds.
T_Retval = TypeVar("T_Retval")

# Aliases used to spell out the dictionary form of a peer certificate
# (TLSAttribute.peer_certificate):
#   _StrPair  -- a single (str, str) pair
#   _PCTRTT   -- a variable-length tuple of such pairs
#   _PCTRTTT  -- a variable-length tuple of _PCTRTT tuples
_StrPair = Tuple[str, str]
_PCTRTT = Tuple[_StrPair, ...]
_PCTRTTT = Tuple[_PCTRTT, ...]

23 

24 

class TLSAttribute(TypedAttributeSet):
    """
    Contains Transport Layer Security related attributes.

    Each member is a typed attribute key; a TLS stream exposes a value for it through
    its ``extra_attributes`` mapping.
    """

    #: the selected ALPN protocol
    alpn_protocol: str | None = typed_attribute()
    #: the channel binding for type ``tls-unique``
    channel_binding_tls_unique: bytes = typed_attribute()
    #: the selected cipher
    cipher: tuple[str, str, int] = typed_attribute()
    #: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert`
    #: for more information)
    peer_certificate: None | (dict[str, str | _PCTRTTT | _PCTRTT]) = typed_attribute()
    #: the peer certificate in binary form
    peer_certificate_binary: bytes | None = typed_attribute()
    #: ``True`` if this is the server side of the connection
    server_side: bool = typed_attribute()
    #: ciphers shared by the client during the TLS handshake (``None`` if this is the
    #: client side)
    shared_ciphers: list[tuple[str, str, int]] | None = typed_attribute()
    #: the :class:`~ssl.SSLObject` used for encryption
    ssl_object: ssl.SSLObject = typed_attribute()
    #: ``True`` if this stream does (and expects) a closing TLS handshake when the
    #: stream is being closed
    standard_compatible: bool = typed_attribute()
    #: the TLS protocol version (e.g. ``TLSv1.2``)
    tls_version: str = typed_attribute()

51 

52 

@dataclass(eq=False)
class TLSStream(ByteStream):
    """
    A stream wrapper that encrypts all sent data and decrypts received data.

    This class has no public initializer; use :meth:`wrap` instead.
    All extra attributes from :class:`~TLSAttribute` are supported.

    :var AnyByteStream transport_stream: the wrapped stream

    """

    transport_stream: AnyByteStream
    standard_compatible: bool
    _ssl_object: ssl.SSLObject
    # incoming ciphertext received from the transport is written here
    _read_bio: ssl.MemoryBIO
    # outgoing ciphertext produced by the SSL object is read from here
    _write_bio: ssl.MemoryBIO

    @classmethod
    async def wrap(
        cls,
        transport_stream: AnyByteStream,
        *,
        server_side: bool | None = None,
        hostname: str | None = None,
        ssl_context: ssl.SSLContext | None = None,
        standard_compatible: bool = True,
    ) -> TLSStream:
        """
        Wrap an existing stream with Transport Layer Security.

        This performs a TLS handshake with the peer.

        :param transport_stream: a bytes-transporting stream to wrap
        :param server_side: ``True`` if this is the server side of the connection,
            ``False`` if this is the client side (if omitted, will be set to ``False``
            if ``hostname`` has been provided, ``True`` otherwise). Always passed to
            :meth:`~ssl.SSLContext.wrap_bio`; when no explicit context is given, it
            also selects the purpose of the default context that gets created.
        :param hostname: host name of the peer (if host name checking is desired)
        :param ssl_context: the SSLContext object to use (if not provided, a secure
            default will be created)
        :param standard_compatible: if ``False``, skip the closing handshake when
            closing the connection, and don't raise an exception if the peer does the
            same
        :raises ~ssl.SSLError: if the TLS handshake fails

        """
        if server_side is None:
            server_side = not hostname

        if ssl_context is None:
            purpose = (
                ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH
            )
            ssl_context = ssl.create_default_context(purpose)

            # Re-enable detection of unexpected EOFs if it was disabled by Python
            if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
                ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF

        bio_in = ssl.MemoryBIO()
        bio_out = ssl.MemoryBIO()
        ssl_object = ssl_context.wrap_bio(
            bio_in, bio_out, server_side=server_side, server_hostname=hostname
        )
        wrapper = cls(
            transport_stream=transport_stream,
            standard_compatible=standard_compatible,
            _ssl_object=ssl_object,
            _read_bio=bio_in,
            _write_bio=bio_out,
        )
        await wrapper._call_sslobject_method(ssl_object.do_handshake)
        return wrapper

    async def _call_sslobject_method(
        self, func: Callable[..., T_Retval], *args: object
    ) -> T_Retval:
        """
        Call ``func(*args)`` (an SSL object method), shuttling data between the memory
        BIOs and the transport stream until the call succeeds.

        * On :exc:`ssl.SSLWantReadError`, pending outgoing data is flushed, then more
          data is received from the transport and fed to the read BIO (or EOF is
          signalled to the read BIO if the transport has ended).
        * On :exc:`ssl.SSLWantWriteError`, outgoing data is sent to the transport.
        * On success, any pending outgoing data is flushed and the result returned.

        :raises BrokenResourceError: on a transport :exc:`OSError`, an
            :exc:`ssl.SSLSyscallError`, or an EOF-type SSL error when
            ``standard_compatible`` is ``True``
        :raises EndOfStream: on an EOF-type SSL error when ``standard_compatible`` is
            ``False``
        :raises ~ssl.SSLError: any other SSL error, re-raised as is

        """
        while True:
            try:
                result = func(*args)
            except ssl.SSLWantReadError:
                try:
                    # Flush any pending writes first
                    if self._write_bio.pending:
                        await self.transport_stream.send(self._write_bio.read())

                    data = await self.transport_stream.receive()
                except EndOfStream:
                    self._read_bio.write_eof()
                except OSError as exc:
                    self._read_bio.write_eof()
                    self._write_bio.write_eof()
                    raise BrokenResourceError from exc
                else:
                    self._read_bio.write(data)
            except ssl.SSLWantWriteError:
                await self.transport_stream.send(self._write_bio.read())
            except ssl.SSLSyscallError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                raise BrokenResourceError from exc
            except ssl.SSLError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                # ``strerror`` may be None (e.g. for an SSLError constructed with a
                # single argument); guard it so a TypeError does not mask the error
                if isinstance(exc, ssl.SSLEOFError) or (
                    exc.strerror and "UNEXPECTED_EOF_WHILE_READING" in exc.strerror
                ):
                    if self.standard_compatible:
                        raise BrokenResourceError from exc
                    else:
                        raise EndOfStream from None

                raise
            else:
                # Flush any pending writes first
                if self._write_bio.pending:
                    await self.transport_stream.send(self._write_bio.read())

                return result

    async def unwrap(self) -> tuple[AnyByteStream, bytes]:
        """
        Does the TLS closing handshake.

        :return: a tuple of (wrapped byte stream, bytes left in the read buffer)

        """
        await self._call_sslobject_method(self._ssl_object.unwrap)
        self._read_bio.write_eof()
        self._write_bio.write_eof()
        return self.transport_stream, self._read_bio.read()

    async def aclose(self) -> None:
        """
        Close the stream.

        If ``standard_compatible`` is ``True``, the TLS closing handshake is performed
        first; should that fail for any reason, the transport stream is closed
        forcefully and the exception is re-raised.

        """
        if self.standard_compatible:
            try:
                await self.unwrap()
            except BaseException:
                await aclose_forcefully(self.transport_stream)
                raise

        await self.transport_stream.aclose()

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Receive up to ``max_bytes`` bytes of decrypted data.

        :raises EndOfStream: if the SSL object returns no data

        """
        data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
        if not data:
            raise EndOfStream

        return data

    async def send(self, item: bytes) -> None:
        """Encrypt ``item`` and send the resulting data to the transport stream."""
        await self._call_sslobject_method(self._ssl_object.write, item)

    async def send_eof(self) -> None:
        """
        Not supported for TLS streams; always raises :exc:`NotImplementedError`.

        The error message points out when the negotiated protocol is older than TLSv1.3.

        """
        tls_version = self.extra(TLSAttribute.tls_version)
        match = re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version)
        if match:
            major, minor = int(match.group(1)), int(match.group(2) or 0)
            if (major, minor) < (1, 3):
                raise NotImplementedError(
                    f"send_eof() requires at least TLSv1.3; current "
                    f"session uses {tls_version}"
                )

        raise NotImplementedError(
            "send_eof() has not yet been implemented for TLS streams"
        )

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """
        The transport stream's extra attributes, extended (and overridden where keys
        collide) with the :class:`TLSAttribute` values taken from the SSL object.
        """
        return {
            **self.transport_stream.extra_attributes,
            TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
            TLSAttribute.channel_binding_tls_unique: (
                self._ssl_object.get_channel_binding
            ),
            TLSAttribute.cipher: self._ssl_object.cipher,
            TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False),
            TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert(
                True
            ),
            TLSAttribute.server_side: lambda: self._ssl_object.server_side,
            # shared ciphers are only reported on the server side
            TLSAttribute.shared_ciphers: lambda: (
                self._ssl_object.shared_ciphers()
                if self._ssl_object.server_side
                else None
            ),
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
            TLSAttribute.ssl_object: lambda: self._ssl_object,
            TLSAttribute.tls_version: self._ssl_object.version,
        }

243 

244 

@dataclass(eq=False)
class TLSListener(Listener[TLSStream]):
    """
    A convenience listener that wraps another listener and auto-negotiates a TLS session
    on every accepted connection.

    If the TLS handshake times out or raises an exception,
    :meth:`handle_handshake_error` is called to do whatever post-mortem processing is
    deemed necessary.

    Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute.

    :param Listener listener: the listener to wrap
    :param ssl_context: the SSL context object
    :param standard_compatible: a flag passed through to :meth:`TLSStream.wrap`
    :param handshake_timeout: time limit for the TLS handshake
        (passed to :func:`~anyio.fail_after`)
    """

    listener: Listener[Any]
    ssl_context: ssl.SSLContext
    standard_compatible: bool = True
    handshake_timeout: float = 30

    @staticmethod
    async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None:
        """
        Handle an exception raised during the TLS handshake.

        This method does 3 things:

        #. Forcefully closes the original stream
        #. Logs the exception (unless it was a cancellation exception) using the
           ``anyio.streams.tls`` logger
        #. Reraises the exception if it was a base exception or a cancellation exception

        :param exc: the exception
        :param stream: the original stream

        """
        await aclose_forcefully(stream)

        cancelled_exc_class = get_cancelled_exc_class()

        # Log all except cancellation exceptions
        if not isinstance(exc, cancelled_exc_class):
            # CPython (as of 3.11.5) returns incorrect `sys.exc_info()` here when using
            # any asyncio implementation, so we explicitly pass the exception to log
            # (https://github.com/python/cpython/issues/108668). Trio does not have this
            # issue because it works around the CPython bug.
            logging.getLogger(__name__).exception(
                "Error during TLS handshake", exc_info=exc
            )

        # Only reraise base exceptions and cancellation exceptions. The exception is
        # re-raised explicitly rather than with a bare ``raise``: a bare ``raise``
        # relies on the interpreter's "currently handled exception" state (possibly
        # subject to the same CPython issue noted above), and fails outright if this
        # method is invoked outside an ``except`` block (e.g. from an override).
        if not isinstance(exc, Exception) or isinstance(exc, cancelled_exc_class):
            raise exc

    async def serve(
        self,
        handler: Callable[[TLSStream], Any],
        task_group: TaskGroup | None = None,
    ) -> None:
        """
        Serve connections from the wrapped listener, performing a TLS handshake (bounded
        by ``handshake_timeout``) on each one before passing the resulting
        :class:`TLSStream` to ``handler``. Handshake failures are passed to
        :meth:`handle_handshake_error` and ``handler`` is not called for them.
        """

        @wraps(handler)
        async def handler_wrapper(stream: AnyByteStream) -> None:
            from .. import fail_after

            try:
                with fail_after(self.handshake_timeout):
                    wrapped_stream = await TLSStream.wrap(
                        stream,
                        ssl_context=self.ssl_context,
                        standard_compatible=self.standard_compatible,
                    )
            except BaseException as exc:
                await self.handle_handshake_error(exc, stream)
            else:
                await handler(wrapped_stream)

        await self.listener.serve(handler_wrapper, task_group)

    async def aclose(self) -> None:
        """Close the wrapped listener."""
        await self.listener.aclose()

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """Only :attr:`TLSAttribute.standard_compatible` is provided."""
        return {
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
        }