Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/anyio/streams/tls.py: 40%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

171 statements  

from __future__ import annotations

# Names exported as the public API of this module.
__all__ = (
    "TLSAttribute",
    "TLSConnectable",
    "TLSListener",
    "TLSStream",
)

import logging
import re
import ssl
import sys
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from functools import wraps
from ssl import SSLContext
from typing import Any, TypeVar

from .. import (
    BrokenResourceError,
    EndOfStream,
    aclose_forcefully,
    get_cancelled_exc_class,
    to_thread,
)
from .._core._typedattr import TypedAttributeSet, typed_attribute
from ..abc import (
    AnyByteStream,
    AnyByteStreamConnectable,
    ByteStream,
    ByteStreamConnectable,
    Listener,
    TaskGroup,
)

# Typing helpers that only exist in the standard library from a given Python
# version onwards; older interpreters take them from typing_extensions.
if sys.version_info >= (3, 10):
    from typing import TypeAlias
else:
    from typing_extensions import TypeAlias

if sys.version_info >= (3, 11):
    from typing import TypeVarTuple, Unpack
else:
    from typing_extensions import TypeVarTuple, Unpack

if sys.version_info >= (3, 12):
    from typing import override
else:
    from typing_extensions import override

# Return type / positional argument types of the callable forwarded through
# TLSStream._call_sslobject_method().
T_Retval = TypeVar("T_Retval")
PosArgsT = TypeVarTuple("PosArgsT")

# Nested tuple shapes used in the annotation of TLSAttribute.peer_certificate.
_PCTRTT: TypeAlias = tuple[tuple[str, str], ...]
_PCTRTTT: TypeAlias = tuple[_PCTRTT, ...]

class TLSAttribute(TypedAttributeSet):
    """Contains Transport Layer Security related attributes."""

    #: the selected ALPN protocol
    alpn_protocol: str | None = typed_attribute()
    #: the channel binding for type ``tls-unique``
    channel_binding_tls_unique: bytes = typed_attribute()
    #: the selected cipher
    cipher: tuple[str, str, int] = typed_attribute()
    #: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert`
    #: for more information)
    peer_certificate: None | (dict[str, str | _PCTRTTT | _PCTRTT]) = typed_attribute()
    #: the peer certificate in binary form
    peer_certificate_binary: bytes | None = typed_attribute()
    #: ``True`` if this is the server side of the connection
    server_side: bool = typed_attribute()
    #: ciphers shared by the client during the TLS handshake (``None`` if this is the
    #: client side)
    shared_ciphers: list[tuple[str, str, int]] | None = typed_attribute()
    #: the :class:`~ssl.SSLObject` used for encryption
    ssl_object: ssl.SSLObject = typed_attribute()
    #: ``True`` if this stream does (and expects) a closing TLS handshake when the
    #: stream is being closed
    standard_compatible: bool = typed_attribute()
    #: the TLS protocol version (e.g. ``TLSv1.2``)
    tls_version: str = typed_attribute()

@dataclass(eq=False)
class TLSStream(ByteStream):
    """
    A stream wrapper that encrypts all sent data and decrypts received data.

    This class has no public initializer; use :meth:`wrap` instead.
    All extra attributes from :class:`~TLSAttribute` are supported.

    :var AnyByteStream transport_stream: the wrapped stream

    """

    transport_stream: AnyByteStream
    standard_compatible: bool
    # The SSL object and the two in-memory BIOs it was created with in wrap():
    # bytes received from the transport are written into ``_read_bio``, and bytes
    # to be sent over the transport are read from ``_write_bio``.
    _ssl_object: ssl.SSLObject
    _read_bio: ssl.MemoryBIO
    _write_bio: ssl.MemoryBIO

    @classmethod
    async def wrap(
        cls,
        transport_stream: AnyByteStream,
        *,
        server_side: bool | None = None,
        hostname: str | None = None,
        ssl_context: ssl.SSLContext | None = None,
        standard_compatible: bool = True,
    ) -> TLSStream:
        """
        Wrap an existing stream with Transport Layer Security.

        This performs a TLS handshake with the peer.

        :param transport_stream: a bytes-transporting stream to wrap
        :param server_side: ``True`` if this is the server side of the connection,
            ``False`` if this is the client side (if omitted, will be set to ``False``
            if ``hostname`` has been provided, ``True`` otherwise). Selects the
            purpose of the default context when an explicit context has not been
            provided, and is passed on to ``wrap_bio()``.
        :param hostname: host name of the peer (if host name checking is desired)
        :param ssl_context: the SSLContext object to use (if not provided, a secure
            default will be created)
        :param standard_compatible: if ``False``, skip the closing handshake when
            closing the connection, and don't raise an exception if the peer does the
            same
        :return: the wrapped stream, after the handshake has completed
        :raises ~ssl.SSLError: if the TLS handshake fails

        """
        if server_side is None:
            # No host name to verify => assume we are the server
            server_side = not hostname

        if not ssl_context:
            purpose = (
                ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH
            )
            ssl_context = ssl.create_default_context(purpose)

            # Re-enable detection of unexpected EOFs if it was disabled by Python
            if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
                ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF

        bio_in = ssl.MemoryBIO()
        bio_out = ssl.MemoryBIO()

        # External SSLContext implementations may do blocking I/O in wrap_bio(),
        # but the standard library implementation won't
        if type(ssl_context) is ssl.SSLContext:
            ssl_object = ssl_context.wrap_bio(
                bio_in, bio_out, server_side=server_side, server_hostname=hostname
            )
        else:
            # Arguments are given positionally, in wrap_bio()'s parameter order:
            # incoming, outgoing, server_side, server_hostname, session
            ssl_object = await to_thread.run_sync(
                ssl_context.wrap_bio,
                bio_in,
                bio_out,
                server_side,
                hostname,
                None,
            )

        wrapper = cls(
            transport_stream=transport_stream,
            standard_compatible=standard_compatible,
            _ssl_object=ssl_object,
            _read_bio=bio_in,
            _write_bio=bio_out,
        )
        await wrapper._call_sslobject_method(ssl_object.do_handshake)
        return wrapper

    async def _call_sslobject_method(
        self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
    ) -> T_Retval:
        """
        Call ``func(*args)`` repeatedly until it completes, shuttling bytes between
        the memory BIOs and the transport stream whenever it asks for I/O.

        :param func: the callable to invoke (the callers in this class pass bound
            methods of the SSL object)
        :param args: positional arguments passed to ``func`` on every attempt
        :return: the return value of ``func``
        :raises BrokenResourceError: if the transport raises :exc:`OSError` while
            receiving, if ``func`` raises :exc:`ssl.SSLSyscallError`, or on an
            unexpected EOF when ``standard_compatible`` is true
        :raises EndOfStream: on an unexpected EOF when ``standard_compatible`` is
            false
        :raises ~ssl.SSLError: for any other SSL error (re-raised unchanged)

        """
        while True:
            try:
                result = func(*args)
            except ssl.SSLWantReadError:
                try:
                    # Flush any pending writes first
                    if self._write_bio.pending:
                        await self.transport_stream.send(self._write_bio.read())

                    data = await self.transport_stream.receive()
                except EndOfStream:
                    # Signal the EOF to the SSL object and retry func
                    self._read_bio.write_eof()
                except OSError as exc:
                    self._read_bio.write_eof()
                    self._write_bio.write_eof()
                    raise BrokenResourceError from exc
                else:
                    self._read_bio.write(data)
            except ssl.SSLWantWriteError:
                await self.transport_stream.send(self._write_bio.read())
            except ssl.SSLSyscallError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                raise BrokenResourceError from exc
            except ssl.SSLError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                # An unexpected EOF is reported either as SSLEOFError or as a
                # generic SSLError whose message names the condition
                if isinstance(exc, ssl.SSLEOFError) or (
                    exc.strerror and "UNEXPECTED_EOF_WHILE_READING" in exc.strerror
                ):
                    if self.standard_compatible:
                        raise BrokenResourceError from exc
                    else:
                        raise EndOfStream from None

                raise
            else:
                # Flush any pending writes first
                if self._write_bio.pending:
                    await self.transport_stream.send(self._write_bio.read())

                return result

    async def unwrap(self) -> tuple[AnyByteStream, bytes]:
        """
        Does the TLS closing handshake.

        :return: a tuple of (wrapped byte stream, bytes left in the read buffer)

        """
        await self._call_sslobject_method(self._ssl_object.unwrap)
        self._read_bio.write_eof()
        self._write_bio.write_eof()
        return self.transport_stream, self._read_bio.read()

    async def aclose(self) -> None:
        """
        Close the stream.

        If ``standard_compatible`` is true, the closing handshake is done first via
        :meth:`unwrap`; if that raises, the transport stream is closed forcefully
        and the exception is re-raised. Otherwise the transport stream is simply
        closed.

        """
        if self.standard_compatible:
            try:
                await self.unwrap()
            except BaseException:
                await aclose_forcefully(self.transport_stream)
                raise

        await self.transport_stream.aclose()

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Read and return decrypted data from the SSL object.

        :param max_bytes: passed to the SSL object's ``read()``
        :raises EndOfStream: if the read returns no data

        """
        data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
        if not data:
            raise EndOfStream

        return data

    async def send(self, item: bytes) -> None:
        """
        Write ``item`` to the SSL object and flush the result to the transport.

        :param item: the bytes to send

        """
        await self._call_sslobject_method(self._ssl_object.write, item)

    async def send_eof(self) -> None:
        """
        Not implemented; always raises.

        :raises NotImplementedError: always; the message names the negotiated
            version when it parses as older than TLSv1.3

        """
        tls_version = self.extra(TLSAttribute.tls_version)
        match = re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version)
        if match:
            # A missing minor component (e.g. "TLSv1") counts as minor version 0
            major, minor = int(match.group(1)), int(match.group(2) or 0)
            if (major, minor) < (1, 3):
                raise NotImplementedError(
                    f"send_eof() requires at least TLSv1.3; current "
                    f"session uses {tls_version}"
                )

        raise NotImplementedError(
            "send_eof() has not yet been implemented for TLS streams"
        )

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """
        Return the transport stream's extra attributes overlaid with getters for
        the :class:`TLSAttribute` attributes of this stream.
        """
        return {
            **self.transport_stream.extra_attributes,
            TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
            TLSAttribute.channel_binding_tls_unique: (
                self._ssl_object.get_channel_binding
            ),
            TLSAttribute.cipher: self._ssl_object.cipher,
            TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False),
            TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert(
                True
            ),
            TLSAttribute.server_side: lambda: self._ssl_object.server_side,
            # The whole conditional expression is the body of the lambda
            TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers()
            if self._ssl_object.server_side
            else None,
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
            TLSAttribute.ssl_object: lambda: self._ssl_object,
            TLSAttribute.tls_version: self._ssl_object.version,
        }

@dataclass(eq=False)
class TLSListener(Listener[TLSStream]):
    """
    A convenience listener that wraps another listener and auto-negotiates a TLS session
    on every accepted connection.

    If the TLS handshake times out or raises an exception,
    :meth:`handle_handshake_error` is called to do whatever post-mortem processing is
    deemed necessary.

    Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute.

    :param Listener listener: the listener to wrap
    :param ssl_context: the SSL context object
    :param standard_compatible: a flag passed through to :meth:`TLSStream.wrap`
    :param handshake_timeout: time limit for the TLS handshake
        (passed to :func:`~anyio.fail_after`)
    """

    listener: Listener[Any]
    ssl_context: ssl.SSLContext
    standard_compatible: bool = True
    handshake_timeout: float = 30

    @staticmethod
    async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None:
        """
        Handle an exception raised during the TLS handshake.

        This method does 3 things:

        #. Forcefully closes the original stream
        #. Logs the exception (unless it was a cancellation exception) using the
           ``anyio.streams.tls`` logger
        #. Reraises the exception if it was a base exception or a cancellation exception

        :param exc: the exception
        :param stream: the original stream

        """
        await aclose_forcefully(stream)

        is_cancellation = isinstance(exc, get_cancelled_exc_class())

        # Log all except cancellation exceptions
        if not is_cancellation:
            # CPython (as of 3.11.5) returns incorrect `sys.exc_info()` here when using
            # any asyncio implementation, so we explicitly pass the exception to log
            # (https://github.com/python/cpython/issues/108668). Trio does not have this
            # issue because it works around the CPython bug.
            logging.getLogger(__name__).exception(
                "Error during TLS handshake", exc_info=exc
            )

        # Only reraise base exceptions and cancellation exceptions. The exception
        # object is raised explicitly (rather than with a bare ``raise``) so that
        # this does not depend on ``exc`` still being the active exception here.
        if is_cancellation or not isinstance(exc, Exception):
            raise exc

    async def serve(
        self,
        handler: Callable[[TLSStream], Any],
        task_group: TaskGroup | None = None,
    ) -> None:
        """
        Serve connections from the wrapped listener, doing a TLS handshake on each
        accepted stream before handing it to ``handler``.

        :param handler: called with the :class:`TLSStream` of every connection
            whose handshake succeeded
        :param task_group: passed through to the wrapped listener's ``serve()``

        """

        @wraps(handler)
        async def handler_wrapper(stream: AnyByteStream) -> None:
            # Imported here rather than at module level
            # (presumably to avoid a circular import -- TODO confirm)
            from .. import fail_after

            try:
                with fail_after(self.handshake_timeout):
                    wrapped_stream = await TLSStream.wrap(
                        stream,
                        ssl_context=self.ssl_context,
                        standard_compatible=self.standard_compatible,
                    )
            except BaseException as exc:
                await self.handle_handshake_error(exc, stream)
            else:
                await handler(wrapped_stream)

        await self.listener.serve(handler_wrapper, task_group)

    async def aclose(self) -> None:
        """Close the wrapped listener."""
        await self.listener.aclose()

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """Return a getter for :attr:`~TLSAttribute.standard_compatible` only."""
        return {
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
        }

class TLSConnectable(ByteStreamConnectable):
    """
    Wraps another connectable and does TLS negotiation after a successful connection.

    :param connectable: the connectable to wrap
    :param hostname: host name of the server (if host name checking is desired)
    :param ssl_context: the SSLContext object to use (if not provided, a secure default
        will be created)
    :param standard_compatible: if ``False``, skip the closing handshake when closing
        the connection, and don't raise an exception if the server does the same
    :raises TypeError: if ``ssl_context`` is given but is not an
        :class:`~ssl.SSLContext` instance
    """

    def __init__(
        self,
        connectable: AnyByteStreamConnectable,
        *,
        hostname: str | None = None,
        ssl_context: ssl.SSLContext | None = None,
        standard_compatible: bool = True,
    ) -> None:
        self.connectable = connectable
        # Fall back to a default client-side context (one meant for
        # authenticating servers) when none was supplied
        self.ssl_context: SSLContext = ssl_context or ssl.create_default_context(
            ssl.Purpose.SERVER_AUTH
        )
        if not isinstance(self.ssl_context, ssl.SSLContext):
            raise TypeError(
                "ssl_context must be an instance of ssl.SSLContext, not "
                f"{type(self.ssl_context).__name__}"
            )
        self.hostname = hostname
        self.standard_compatible = standard_compatible

    @override
    async def connect(self) -> TLSStream:
        """
        Connect through the wrapped connectable and do the TLS handshake on the
        resulting stream.

        If the handshake raises, the underlying stream is closed forcefully and the
        exception is re-raised.

        :return: the TLS-wrapped stream

        """
        stream = await self.connectable.connect()
        try:
            return await TLSStream.wrap(
                stream,
                hostname=self.hostname,
                ssl_context=self.ssl_context,
                standard_compatible=self.standard_compatible,
            )
        except BaseException:
            await aclose_forcefully(stream)
            raise

418 hostname=self.hostname, 

419 ssl_context=self.ssl_context, 

420 standard_compatible=self.standard_compatible, 

421 ) 

422 except BaseException: 

423 await aclose_forcefully(stream) 

424 raise