Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/anyio/streams/tls.py: 39%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

170 statements  

1from __future__ import annotations 

2 

3import logging 

4import re 

5import ssl 

6import sys 

7from collections.abc import Callable, Mapping 

8from dataclasses import dataclass 

9from functools import wraps 

10from ssl import SSLContext 

11from typing import Any, TypeVar 

12 

13from .. import ( 

14 BrokenResourceError, 

15 EndOfStream, 

16 aclose_forcefully, 

17 get_cancelled_exc_class, 

18 to_thread, 

19) 

20from .._core._typedattr import TypedAttributeSet, typed_attribute 

21from ..abc import ( 

22 AnyByteStream, 

23 AnyByteStreamConnectable, 

24 ByteStream, 

25 ByteStreamConnectable, 

26 Listener, 

27 TaskGroup, 

28) 

29 

30if sys.version_info >= (3, 10): 

31 from typing import TypeAlias 

32else: 

33 from typing_extensions import TypeAlias 

34 

35if sys.version_info >= (3, 11): 

36 from typing import TypeVarTuple, Unpack 

37else: 

38 from typing_extensions import TypeVarTuple, Unpack 

39 

40if sys.version_info >= (3, 12): 

41 from typing import override 

42else: 

43 from typing_extensions import override 

44 

45T_Retval = TypeVar("T_Retval") 

46PosArgsT = TypeVarTuple("PosArgsT") 

47_PCTRTT: TypeAlias = tuple[tuple[str, str], ...] 

48_PCTRTTT: TypeAlias = tuple[_PCTRTT, ...] 

49 

50 

51class TLSAttribute(TypedAttributeSet): 

52 """Contains Transport Layer Security related attributes.""" 

53 

54 #: the selected ALPN protocol 

55 alpn_protocol: str | None = typed_attribute() 

56 #: the channel binding for type ``tls-unique`` 

57 channel_binding_tls_unique: bytes = typed_attribute() 

58 #: the selected cipher 

59 cipher: tuple[str, str, int] = typed_attribute() 

60 #: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert` 

61 # for more information) 

62 peer_certificate: None | (dict[str, str | _PCTRTTT | _PCTRTT]) = typed_attribute() 

63 #: the peer certificate in binary form 

64 peer_certificate_binary: bytes | None = typed_attribute() 

65 #: ``True`` if this is the server side of the connection 

66 server_side: bool = typed_attribute() 

67 #: ciphers shared by the client during the TLS handshake (``None`` if this is the 

68 #: client side) 

69 shared_ciphers: list[tuple[str, str, int]] | None = typed_attribute() 

70 #: the :class:`~ssl.SSLObject` used for encryption 

71 ssl_object: ssl.SSLObject = typed_attribute() 

72 #: ``True`` if this stream does (and expects) a closing TLS handshake when the 

73 #: stream is being closed 

74 standard_compatible: bool = typed_attribute() 

75 #: the TLS protocol version (e.g. ``TLSv1.2``) 

76 tls_version: str = typed_attribute() 

77 

78 

79@dataclass(eq=False) 

80class TLSStream(ByteStream): 

81 """ 

82 A stream wrapper that encrypts all sent data and decrypts received data. 

83 

84 This class has no public initializer; use :meth:`wrap` instead. 

85 All extra attributes from :class:`~TLSAttribute` are supported. 

86 

87 :var AnyByteStream transport_stream: the wrapped stream 

88 

89 """ 

90 

91 transport_stream: AnyByteStream 

92 standard_compatible: bool 

93 _ssl_object: ssl.SSLObject 

94 _read_bio: ssl.MemoryBIO 

95 _write_bio: ssl.MemoryBIO 

96 

97 @classmethod 

98 async def wrap( 

99 cls, 

100 transport_stream: AnyByteStream, 

101 *, 

102 server_side: bool | None = None, 

103 hostname: str | None = None, 

104 ssl_context: ssl.SSLContext | None = None, 

105 standard_compatible: bool = True, 

106 ) -> TLSStream: 

107 """ 

108 Wrap an existing stream with Transport Layer Security. 

109 

110 This performs a TLS handshake with the peer. 

111 

112 :param transport_stream: a bytes-transporting stream to wrap 

113 :param server_side: ``True`` if this is the server side of the connection, 

114 ``False`` if this is the client side (if omitted, will be set to ``False`` 

115 if ``hostname`` has been provided, ``False`` otherwise). Used only to create 

116 a default context when an explicit context has not been provided. 

117 :param hostname: host name of the peer (if host name checking is desired) 

118 :param ssl_context: the SSLContext object to use (if not provided, a secure 

119 default will be created) 

120 :param standard_compatible: if ``False``, skip the closing handshake when 

121 closing the connection, and don't raise an exception if the peer does the 

122 same 

123 :raises ~ssl.SSLError: if the TLS handshake fails 

124 

125 """ 

126 if server_side is None: 

127 server_side = not hostname 

128 

129 if not ssl_context: 

130 purpose = ( 

131 ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH 

132 ) 

133 ssl_context = ssl.create_default_context(purpose) 

134 

135 # Re-enable detection of unexpected EOFs if it was disabled by Python 

136 if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"): 

137 ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF 

138 

139 bio_in = ssl.MemoryBIO() 

140 bio_out = ssl.MemoryBIO() 

141 

142 # External SSLContext implementations may do blocking I/O in wrap_bio(), 

143 # but the standard library implementation won't 

144 if type(ssl_context) is ssl.SSLContext: 

145 ssl_object = ssl_context.wrap_bio( 

146 bio_in, bio_out, server_side=server_side, server_hostname=hostname 

147 ) 

148 else: 

149 ssl_object = await to_thread.run_sync( 

150 ssl_context.wrap_bio, 

151 bio_in, 

152 bio_out, 

153 server_side, 

154 hostname, 

155 None, 

156 ) 

157 

158 wrapper = cls( 

159 transport_stream=transport_stream, 

160 standard_compatible=standard_compatible, 

161 _ssl_object=ssl_object, 

162 _read_bio=bio_in, 

163 _write_bio=bio_out, 

164 ) 

165 await wrapper._call_sslobject_method(ssl_object.do_handshake) 

166 return wrapper 

167 

168 async def _call_sslobject_method( 

169 self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT] 

170 ) -> T_Retval: 

171 while True: 

172 try: 

173 result = func(*args) 

174 except ssl.SSLWantReadError: 

175 try: 

176 # Flush any pending writes first 

177 if self._write_bio.pending: 

178 await self.transport_stream.send(self._write_bio.read()) 

179 

180 data = await self.transport_stream.receive() 

181 except EndOfStream: 

182 self._read_bio.write_eof() 

183 except OSError as exc: 

184 self._read_bio.write_eof() 

185 self._write_bio.write_eof() 

186 raise BrokenResourceError from exc 

187 else: 

188 self._read_bio.write(data) 

189 except ssl.SSLWantWriteError: 

190 await self.transport_stream.send(self._write_bio.read()) 

191 except ssl.SSLSyscallError as exc: 

192 self._read_bio.write_eof() 

193 self._write_bio.write_eof() 

194 raise BrokenResourceError from exc 

195 except ssl.SSLError as exc: 

196 self._read_bio.write_eof() 

197 self._write_bio.write_eof() 

198 if isinstance(exc, ssl.SSLEOFError) or ( 

199 exc.strerror and "UNEXPECTED_EOF_WHILE_READING" in exc.strerror 

200 ): 

201 if self.standard_compatible: 

202 raise BrokenResourceError from exc 

203 else: 

204 raise EndOfStream from None 

205 

206 raise 

207 else: 

208 # Flush any pending writes first 

209 if self._write_bio.pending: 

210 await self.transport_stream.send(self._write_bio.read()) 

211 

212 return result 

213 

214 async def unwrap(self) -> tuple[AnyByteStream, bytes]: 

215 """ 

216 Does the TLS closing handshake. 

217 

218 :return: a tuple of (wrapped byte stream, bytes left in the read buffer) 

219 

220 """ 

221 await self._call_sslobject_method(self._ssl_object.unwrap) 

222 self._read_bio.write_eof() 

223 self._write_bio.write_eof() 

224 return self.transport_stream, self._read_bio.read() 

225 

226 async def aclose(self) -> None: 

227 if self.standard_compatible: 

228 try: 

229 await self.unwrap() 

230 except BaseException: 

231 await aclose_forcefully(self.transport_stream) 

232 raise 

233 

234 await self.transport_stream.aclose() 

235 

236 async def receive(self, max_bytes: int = 65536) -> bytes: 

237 data = await self._call_sslobject_method(self._ssl_object.read, max_bytes) 

238 if not data: 

239 raise EndOfStream 

240 

241 return data 

242 

243 async def send(self, item: bytes) -> None: 

244 await self._call_sslobject_method(self._ssl_object.write, item) 

245 

246 async def send_eof(self) -> None: 

247 tls_version = self.extra(TLSAttribute.tls_version) 

248 match = re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version) 

249 if match: 

250 major, minor = int(match.group(1)), int(match.group(2) or 0) 

251 if (major, minor) < (1, 3): 

252 raise NotImplementedError( 

253 f"send_eof() requires at least TLSv1.3; current " 

254 f"session uses {tls_version}" 

255 ) 

256 

257 raise NotImplementedError( 

258 "send_eof() has not yet been implemented for TLS streams" 

259 ) 

260 

261 @property 

262 def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: 

263 return { 

264 **self.transport_stream.extra_attributes, 

265 TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol, 

266 TLSAttribute.channel_binding_tls_unique: ( 

267 self._ssl_object.get_channel_binding 

268 ), 

269 TLSAttribute.cipher: self._ssl_object.cipher, 

270 TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False), 

271 TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert( 

272 True 

273 ), 

274 TLSAttribute.server_side: lambda: self._ssl_object.server_side, 

275 TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers() 

276 if self._ssl_object.server_side 

277 else None, 

278 TLSAttribute.standard_compatible: lambda: self.standard_compatible, 

279 TLSAttribute.ssl_object: lambda: self._ssl_object, 

280 TLSAttribute.tls_version: self._ssl_object.version, 

281 } 

282 

283 

284@dataclass(eq=False) 

285class TLSListener(Listener[TLSStream]): 

286 """ 

287 A convenience listener that wraps another listener and auto-negotiates a TLS session 

288 on every accepted connection. 

289 

290 If the TLS handshake times out or raises an exception, 

291 :meth:`handle_handshake_error` is called to do whatever post-mortem processing is 

292 deemed necessary. 

293 

294 Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute. 

295 

296 :param Listener listener: the listener to wrap 

297 :param ssl_context: the SSL context object 

298 :param standard_compatible: a flag passed through to :meth:`TLSStream.wrap` 

299 :param handshake_timeout: time limit for the TLS handshake 

300 (passed to :func:`~anyio.fail_after`) 

301 """ 

302 

303 listener: Listener[Any] 

304 ssl_context: ssl.SSLContext 

305 standard_compatible: bool = True 

306 handshake_timeout: float = 30 

307 

308 @staticmethod 

309 async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None: 

310 """ 

311 Handle an exception raised during the TLS handshake. 

312 

313 This method does 3 things: 

314 

315 #. Forcefully closes the original stream 

316 #. Logs the exception (unless it was a cancellation exception) using the 

317 ``anyio.streams.tls`` logger 

318 #. Reraises the exception if it was a base exception or a cancellation exception 

319 

320 :param exc: the exception 

321 :param stream: the original stream 

322 

323 """ 

324 await aclose_forcefully(stream) 

325 

326 # Log all except cancellation exceptions 

327 if not isinstance(exc, get_cancelled_exc_class()): 

328 # CPython (as of 3.11.5) returns incorrect `sys.exc_info()` here when using 

329 # any asyncio implementation, so we explicitly pass the exception to log 

330 # (https://github.com/python/cpython/issues/108668). Trio does not have this 

331 # issue because it works around the CPython bug. 

332 logging.getLogger(__name__).exception( 

333 "Error during TLS handshake", exc_info=exc 

334 ) 

335 

336 # Only reraise base exceptions and cancellation exceptions 

337 if not isinstance(exc, Exception) or isinstance(exc, get_cancelled_exc_class()): 

338 raise 

339 

340 async def serve( 

341 self, 

342 handler: Callable[[TLSStream], Any], 

343 task_group: TaskGroup | None = None, 

344 ) -> None: 

345 @wraps(handler) 

346 async def handler_wrapper(stream: AnyByteStream) -> None: 

347 from .. import fail_after 

348 

349 try: 

350 with fail_after(self.handshake_timeout): 

351 wrapped_stream = await TLSStream.wrap( 

352 stream, 

353 ssl_context=self.ssl_context, 

354 standard_compatible=self.standard_compatible, 

355 ) 

356 except BaseException as exc: 

357 await self.handle_handshake_error(exc, stream) 

358 else: 

359 await handler(wrapped_stream) 

360 

361 await self.listener.serve(handler_wrapper, task_group) 

362 

363 async def aclose(self) -> None: 

364 await self.listener.aclose() 

365 

366 @property 

367 def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: 

368 return { 

369 TLSAttribute.standard_compatible: lambda: self.standard_compatible, 

370 } 

371 

372 

373class TLSConnectable(ByteStreamConnectable): 

374 """ 

375 Wraps another connectable and does TLS negotiation after a successful connection. 

376 

377 :param connectable: the connectable to wrap 

378 :param hostname: host name of the server (if host name checking is desired) 

379 :param ssl_context: the SSLContext object to use (if not provided, a secure default 

380 will be created) 

381 :param standard_compatible: if ``False``, skip the closing handshake when closing 

382 the connection, and don't raise an exception if the server does the same 

383 """ 

384 

385 def __init__( 

386 self, 

387 connectable: AnyByteStreamConnectable, 

388 *, 

389 hostname: str | None = None, 

390 ssl_context: ssl.SSLContext | None = None, 

391 standard_compatible: bool = True, 

392 ) -> None: 

393 self.connectable = connectable 

394 self.ssl_context: SSLContext = ssl_context or ssl.create_default_context( 

395 ssl.Purpose.SERVER_AUTH 

396 ) 

397 if not isinstance(self.ssl_context, ssl.SSLContext): 

398 raise TypeError( 

399 "ssl_context must be an instance of ssl.SSLContext, not " 

400 f"{type(self.ssl_context).__name__}" 

401 ) 

402 self.hostname = hostname 

403 self.standard_compatible = standard_compatible 

404 

405 @override 

406 async def connect(self) -> TLSStream: 

407 stream = await self.connectable.connect() 

408 try: 

409 return await TLSStream.wrap( 

410 stream, 

411 hostname=self.hostname, 

412 ssl_context=self.ssl_context, 

413 standard_compatible=self.standard_compatible, 

414 ) 

415 except BaseException: 

416 await aclose_forcefully(stream) 

417 raise