Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/anyio/streams/tls.py: 39%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

168 statements  

1from __future__ import annotations 

2 

# Public names exported by this module (used by ``from ... import *``).
__all__ = ("TLSAttribute", "TLSConnectable", "TLSListener", "TLSStream")

9 

10import logging 

11import re 

12import ssl 

13import sys 

14from collections.abc import Callable, Mapping 

15from dataclasses import dataclass 

16from functools import wraps 

17from ssl import SSLContext 

18from typing import Any, TypeAlias, TypeVar 

19 

20from .. import ( 

21 BrokenResourceError, 

22 EndOfStream, 

23 aclose_forcefully, 

24 get_cancelled_exc_class, 

25 to_thread, 

26) 

27from .._core._typedattr import TypedAttributeSet, typed_attribute 

28from ..abc import ( 

29 AnyByteStream, 

30 AnyByteStreamConnectable, 

31 ByteStream, 

32 ByteStreamConnectable, 

33 Listener, 

34 TaskGroup, 

35) 

36 

37if sys.version_info >= (3, 11): 

38 from typing import TypeVarTuple, Unpack 

39else: 

40 from typing_extensions import TypeVarTuple, Unpack 

41 

42if sys.version_info >= (3, 12): 

43 from typing import override 

44else: 

45 from typing_extensions import override 

46 

47T_Retval = TypeVar("T_Retval") 

48PosArgsT = TypeVarTuple("PosArgsT") 

49_PCTRTT: TypeAlias = tuple[tuple[str, str], ...] 

50_PCTRTTT: TypeAlias = tuple[_PCTRTT, ...] 

51 

52 

class TLSAttribute(TypedAttributeSet):
    """
    Contains Transport Layer Security related attributes.

    The members are typed attribute keys. :class:`TLSStream` provides a value for
    every one of them through its ``extra_attributes`` mapping, while
    :class:`TLSListener` provides only :attr:`standard_compatible`.
    """

    #: the selected ALPN protocol
    alpn_protocol: str | None = typed_attribute()
    #: the channel binding for type ``tls-unique``
    channel_binding_tls_unique: bytes = typed_attribute()
    #: the selected cipher
    cipher: tuple[str, str, int] = typed_attribute()
    #: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert`
    #: for more information)
    peer_certificate: None | (dict[str, str | _PCTRTTT | _PCTRTT]) = typed_attribute()
    #: the peer certificate in binary form
    peer_certificate_binary: bytes | None = typed_attribute()
    #: ``True`` if this is the server side of the connection
    server_side: bool = typed_attribute()
    #: ciphers shared by the client during the TLS handshake (``None`` if this is the
    #: client side)
    shared_ciphers: list[tuple[str, str, int]] | None = typed_attribute()
    #: the :class:`~ssl.SSLObject` used for encryption
    ssl_object: ssl.SSLObject = typed_attribute()
    #: ``True`` if this stream does (and expects) a closing TLS handshake when the
    #: stream is being closed
    standard_compatible: bool = typed_attribute()
    #: the TLS protocol version (e.g. ``TLSv1.2``)
    tls_version: str = typed_attribute()

79 

80 

@dataclass(eq=False)
class TLSStream(ByteStream):
    """
    A stream wrapper that encrypts all sent data and decrypts received data.

    This class has no public initializer; use :meth:`wrap` instead.
    All extra attributes from :class:`~TLSAttribute` are supported.

    The TLS work is done by an :class:`ssl.SSLObject` operating on two
    :class:`ssl.MemoryBIO` buffers. Ciphertext received from the transport stream is
    written into the read BIO. Ciphertext produced by the SSL object is drained from
    the write BIO and sent over the transport stream.

    :var AnyByteStream transport_stream: the wrapped stream

    """

    transport_stream: AnyByteStream
    standard_compatible: bool
    # the SSL object doing the TLS work, bound to the two BIOs below
    _ssl_object: ssl.SSLObject
    # incoming ciphertext received from transport_stream
    _read_bio: ssl.MemoryBIO
    # outgoing ciphertext waiting to be sent over transport_stream
    _write_bio: ssl.MemoryBIO

    @classmethod
    async def wrap(
        cls,
        transport_stream: AnyByteStream,
        *,
        server_side: bool | None = None,
        hostname: str | None = None,
        ssl_context: ssl.SSLContext | None = None,
        standard_compatible: bool = True,
    ) -> TLSStream:
        """
        Wrap an existing stream with Transport Layer Security.

        This performs a TLS handshake with the peer.

        :param transport_stream: a bytes-transporting stream to wrap
        :param server_side: ``True`` if this is the server side of the connection,
            ``False`` if this is the client side (if omitted, will be set to ``False``
            if ``hostname`` has been provided, ``True`` otherwise). It selects the
            purpose of the default context when an explicit context has not been
            provided, and it is always passed to :meth:`~ssl.SSLContext.wrap_bio` to
            decide which side of the handshake this stream performs.
        :param hostname: host name of the peer (if host name checking is desired)
        :param ssl_context: the SSLContext object to use (if not provided, a secure
            default will be created)
        :param standard_compatible: if ``False``, skip the closing handshake when
            closing the connection, and don't raise an exception if the peer does the
            same
        :return: the wrapped stream, with the handshake already completed
        :raises ~ssl.SSLError: if the TLS handshake fails

        """
        if server_side is None:
            server_side = not hostname

        if not ssl_context:
            purpose = (
                ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH
            )
            ssl_context = ssl.create_default_context(purpose)

            # Re-enable detection of unexpected EOFs if it was disabled by Python
            if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
                ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF

        bio_in = ssl.MemoryBIO()
        bio_out = ssl.MemoryBIO()

        # External SSLContext implementations may do blocking I/O in wrap_bio(),
        # but the standard library implementation won't
        if type(ssl_context) is ssl.SSLContext:
            ssl_object = ssl_context.wrap_bio(
                bio_in, bio_out, server_side=server_side, server_hostname=hostname
            )
        else:
            # positional order: incoming, outgoing, server_side, server_hostname,
            # session
            ssl_object = await to_thread.run_sync(
                ssl_context.wrap_bio,
                bio_in,
                bio_out,
                server_side,
                hostname,
                None,
            )

        wrapper = cls(
            transport_stream=transport_stream,
            standard_compatible=standard_compatible,
            _ssl_object=ssl_object,
            _read_bio=bio_in,
            _write_bio=bio_out,
        )
        await wrapper._call_sslobject_method(ssl_object.do_handshake)
        return wrapper

    async def _call_sslobject_method(
        self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
    ) -> T_Retval:
        """
        Call an SSL object method, moving ciphertext to and from the transport stream.

        ``func(*args)`` is retried until it stops raising
        :exc:`ssl.SSLWantReadError` / :exc:`ssl.SSLWantWriteError`.

        - When more input is needed, pending output is flushed first. Then data
          received from the transport stream is fed into the read BIO, or EOF is
          written to the read BIO if the transport stream has ended.
        - When the call finally succeeds, any pending output is flushed before the
          result is returned.

        :param func: the SSL object method to call
        :param args: positional arguments for ``func``
        :return: the return value of ``func``
        :raises BrokenResourceError: if the transport stream raises :exc:`OSError`
            while receiving or flushing, on :exc:`ssl.SSLSyscallError`, or on an
            unexpected EOF when ``standard_compatible`` is ``True``
        :raises EndOfStream: on an unexpected EOF when ``standard_compatible`` is
            ``False``
        :raises ~ssl.SSLError: for any other SSL error

        """
        while True:
            try:
                result = func(*args)
            except ssl.SSLWantReadError:
                try:
                    # Flush any pending writes first
                    if self._write_bio.pending:
                        await self.transport_stream.send(self._write_bio.read())

                    data = await self.transport_stream.receive()
                except EndOfStream:
                    # Let the SSL object observe the transport EOF on the next retry
                    self._read_bio.write_eof()
                except OSError as exc:
                    self._read_bio.write_eof()
                    self._write_bio.write_eof()
                    raise BrokenResourceError from exc
                else:
                    self._read_bio.write(data)
            except ssl.SSLWantWriteError:
                await self.transport_stream.send(self._write_bio.read())
            except ssl.SSLSyscallError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                raise BrokenResourceError from exc
            except ssl.SSLError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                # Unexpected EOF (no close_notify): an error in standard-compatible
                # mode, a normal end of stream otherwise
                if isinstance(exc, ssl.SSLEOFError) or (
                    exc.strerror and "UNEXPECTED_EOF_WHILE_READING" in exc.strerror
                ):
                    if self.standard_compatible:
                        raise BrokenResourceError from exc
                    else:
                        raise EndOfStream from None

                raise
            else:
                # Flush any pending writes first
                if self._write_bio.pending:
                    await self.transport_stream.send(self._write_bio.read())

                return result

    async def unwrap(self) -> tuple[AnyByteStream, bytes]:
        """
        Does the TLS closing handshake.

        After the handshake completes, EOF is written to both memory BIOs. Whatever
        is still buffered in the read BIO is returned unprocessed.

        :return: a tuple of (wrapped byte stream, bytes left in the read buffer)

        """
        await self._call_sslobject_method(self._ssl_object.unwrap)
        self._read_bio.write_eof()
        self._write_bio.write_eof()
        return self.transport_stream, self._read_bio.read()

    async def aclose(self) -> None:
        """
        Close the stream and the underlying transport stream.

        When ``standard_compatible`` is true, the TLS closing handshake is performed
        first. If it raises (including cancellation), the transport stream is closed
        forcefully and the exception is re-raised.

        """
        if self.standard_compatible:
            try:
                await self.unwrap()
            except BaseException:
                await aclose_forcefully(self.transport_stream)
                raise

        await self.transport_stream.aclose()

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Receive and decrypt up to ``max_bytes`` bytes of application data.

        :raises EndOfStream: if the SSL object returned no data

        """
        data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
        if not data:
            raise EndOfStream

        return data

    async def send(self, item: bytes) -> None:
        """Encrypt ``item`` and send the resulting ciphertext over the transport."""
        await self._call_sslobject_method(self._ssl_object.write, item)

    async def send_eof(self) -> None:
        """
        Not supported for TLS streams; always raises :exc:`NotImplementedError`.

        The error message mentions the session's TLS version when it is older than
        TLSv1.3.

        """
        tls_version = self.extra(TLSAttribute.tls_version)
        match = re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version)
        if match:
            major, minor = int(match.group(1)), int(match.group(2) or 0)
            if (major, minor) < (1, 3):
                raise NotImplementedError(
                    f"send_eof() requires at least TLSv1.3; current "
                    f"session uses {tls_version}"
                )

        raise NotImplementedError(
            "send_eof() has not yet been implemented for TLS streams"
        )

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        # The transport stream's attributes are unpacked first, so the TLS entries
        # below win for any key they share. Every value is a zero-argument callable.
        return {
            **self.transport_stream.extra_attributes,
            TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
            TLSAttribute.channel_binding_tls_unique: (
                self._ssl_object.get_channel_binding
            ),
            TLSAttribute.cipher: self._ssl_object.cipher,
            TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False),
            TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert(
                True
            ),
            TLSAttribute.server_side: lambda: self._ssl_object.server_side,
            TLSAttribute.shared_ciphers: lambda: (
                self._ssl_object.shared_ciphers()
                if self._ssl_object.server_side
                else None
            ),
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
            TLSAttribute.ssl_object: lambda: self._ssl_object,
            TLSAttribute.tls_version: self._ssl_object.version,
        }

286 

287 

@dataclass(eq=False)
class TLSListener(Listener[TLSStream]):
    """
    A convenience listener that wraps another listener and auto-negotiates a TLS session
    on every accepted connection.

    If the TLS handshake times out or raises an exception,
    :meth:`handle_handshake_error` is called to do whatever post-mortem processing is
    deemed necessary.

    Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute.

    :param Listener listener: the listener to wrap
    :param ssl_context: the SSL context object
    :param standard_compatible: a flag passed through to :meth:`TLSStream.wrap`
    :param handshake_timeout: time limit for the TLS handshake
        (passed to :func:`~anyio.fail_after`)
    """

    listener: Listener[Any]
    ssl_context: ssl.SSLContext
    standard_compatible: bool = True
    handshake_timeout: float = 30

    @staticmethod
    async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None:
        """
        Handle an exception raised during the TLS handshake.

        This method does 3 things:

        #. Forcefully closes the original stream
        #. Logs the exception (unless it was a cancellation exception) using the
           ``anyio.streams.tls`` logger
        #. Reraises the exception if it was a base exception or a cancellation exception

        It does not rely on being called from inside an ``except`` block.

        :param exc: the exception
        :param stream: the original stream

        """
        await aclose_forcefully(stream)

        # Log all except cancellation exceptions
        if not isinstance(exc, get_cancelled_exc_class()):
            # CPython (as of 3.11.5) returns incorrect `sys.exc_info()` here when using
            # any asyncio implementation, so we explicitly pass the exception to log
            # (https://github.com/python/cpython/issues/108668). Trio does not have this
            # issue because it works around the CPython bug.
            logging.getLogger(__name__).exception(
                "Error during TLS handshake", exc_info=exc
            )

        # Only reraise base exceptions and cancellation exceptions. Raise ``exc``
        # explicitly instead of using a bare ``raise``, for two reasons:
        # - a bare ``raise`` depends on the same "currently handled exception" state
        #   that the comment above describes as unreliable here;
        # - it fails with RuntimeError when this method is called outside an
        #   ``except`` block.
        if not isinstance(exc, Exception) or isinstance(exc, get_cancelled_exc_class()):
            raise exc

    async def serve(
        self,
        handler: Callable[[TLSStream], Any],
        task_group: TaskGroup | None = None,
    ) -> None:
        """
        Serve connections from the wrapped listener, wrapping each one in TLS.

        Each accepted stream is wrapped with :meth:`TLSStream.wrap` as the server side
        of the connection, under :func:`~anyio.fail_after` with
        ``handshake_timeout``.

        - On success, ``handler`` is called with the resulting :class:`TLSStream`.
        - On any exception, :meth:`handle_handshake_error` is called with that
          exception and the original stream.

        :param handler: callable invoked with each successfully wrapped stream
        :param task_group: passed through to the wrapped listener's ``serve()``

        """

        @wraps(handler)
        async def handler_wrapper(stream: AnyByteStream) -> None:
            # imported here rather than at module level, as in the original code
            from .. import fail_after

            try:
                with fail_after(self.handshake_timeout):
                    wrapped_stream = await TLSStream.wrap(
                        stream,
                        server_side=True,
                        ssl_context=self.ssl_context,
                        standard_compatible=self.standard_compatible,
                    )
            except BaseException as exc:
                await self.handle_handshake_error(exc, stream)
            else:
                await handler(wrapped_stream)

        await self.listener.serve(handler_wrapper, task_group)

    async def aclose(self) -> None:
        """Close the wrapped listener."""
        await self.listener.aclose()

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        # Only standard_compatible is known before a connection is accepted
        return {
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
        }

375 

376 

class TLSConnectable(ByteStreamConnectable):
    """
    Wraps another connectable and does TLS negotiation after a successful connection.

    The TLS handshake is always performed as the client side of the connection.

    :param connectable: the connectable to wrap
    :param hostname: host name of the server (if host name checking is desired)
    :param ssl_context: the SSLContext object to use (if not provided, a secure default
        will be created)
    :param standard_compatible: if ``False``, skip the closing handshake when closing
        the connection, and don't raise an exception if the server does the same
    :raises TypeError: if ``ssl_context`` is not an :class:`ssl.SSLContext` instance
    """

    def __init__(
        self,
        connectable: AnyByteStreamConnectable,
        *,
        hostname: str | None = None,
        ssl_context: ssl.SSLContext | None = None,
        standard_compatible: bool = True,
    ) -> None:
        self.connectable = connectable
        # Default to a context for authenticating the server we are connecting to
        self.ssl_context: SSLContext = ssl_context or ssl.create_default_context(
            ssl.Purpose.SERVER_AUTH
        )
        if not isinstance(self.ssl_context, ssl.SSLContext):
            raise TypeError(
                "ssl_context must be an instance of ssl.SSLContext, not "
                f"{type(self.ssl_context).__name__}"
            )
        self.hostname = hostname
        self.standard_compatible = standard_compatible

    @override
    async def connect(self) -> TLSStream:
        """
        Connect using the wrapped connectable and perform the TLS handshake.

        If the handshake fails or is cancelled, the freshly connected stream is closed
        forcefully and the exception is re-raised.

        :return: the connected TLS stream

        """
        stream = await self.connectable.connect()
        try:
            # Always act as the TLS client. Without an explicit server_side,
            # TLSStream.wrap() would pick the *server* side whenever hostname is None.
            return await TLSStream.wrap(
                stream,
                server_side=False,
                hostname=self.hostname,
                ssl_context=self.ssl_context,
                standard_compatible=self.standard_compatible,
            )
        except BaseException:
            await aclose_forcefully(stream)
            raise