Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/anyio/streams/tls.py: 39%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

147 statements  

1from __future__ import annotations 

2 

3import logging 

4import re 

5import ssl 

6import sys 

7from collections.abc import Callable, Mapping 

8from dataclasses import dataclass 

9from functools import wraps 

10from typing import Any, TypeVar 

11 

12from .. import ( 

13 BrokenResourceError, 

14 EndOfStream, 

15 aclose_forcefully, 

16 get_cancelled_exc_class, 

17 to_thread, 

18) 

19from .._core._typedattr import TypedAttributeSet, typed_attribute 

20from ..abc import AnyByteStream, ByteStream, Listener, TaskGroup 

21 

22if sys.version_info >= (3, 11): 

23 from typing import TypeVarTuple, Unpack 

24else: 

25 from typing_extensions import TypeVarTuple, Unpack 

26 

27T_Retval = TypeVar("T_Retval") 

28PosArgsT = TypeVarTuple("PosArgsT") 

29_PCTRTT = tuple[tuple[str, str], ...] 

30_PCTRTTT = tuple[_PCTRTT, ...] 

31 

32 

class TLSAttribute(TypedAttributeSet):
    """Contains Transport Layer Security related attributes."""

    #: the selected ALPN protocol
    alpn_protocol: str | None = typed_attribute()
    #: the channel binding for type ``tls-unique``
    channel_binding_tls_unique: bytes = typed_attribute()
    #: the selected cipher
    cipher: tuple[str, str, int] = typed_attribute()
    #: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert`
    #: for more information)
    peer_certificate: None | (dict[str, str | _PCTRTTT | _PCTRTT]) = typed_attribute()
    #: the peer certificate in binary form
    peer_certificate_binary: bytes | None = typed_attribute()
    #: ``True`` if this is the server side of the connection
    server_side: bool = typed_attribute()
    #: ciphers shared by the client during the TLS handshake (``None`` if this is the
    #: client side)
    shared_ciphers: list[tuple[str, str, int]] | None = typed_attribute()
    #: the :class:`~ssl.SSLObject` used for encryption
    ssl_object: ssl.SSLObject = typed_attribute()
    #: ``True`` if this stream does (and expects) a closing TLS handshake when the
    #: stream is being closed
    standard_compatible: bool = typed_attribute()
    #: the TLS protocol version (e.g. ``TLSv1.2``)
    tls_version: str = typed_attribute()

59 

60 

@dataclass(eq=False)
class TLSStream(ByteStream):
    """
    A stream wrapper that encrypts all sent data and decrypts received data.

    This class has no public initializer; use :meth:`wrap` instead.
    All extra attributes from :class:`~TLSAttribute` are supported.

    :var AnyByteStream transport_stream: the wrapped stream

    """

    transport_stream: AnyByteStream
    standard_compatible: bool
    _ssl_object: ssl.SSLObject
    # incoming ciphertext received from the transport is written here
    _read_bio: ssl.MemoryBIO
    # outgoing ciphertext produced by the SSL object is read from here and sent
    _write_bio: ssl.MemoryBIO

    @classmethod
    async def wrap(
        cls,
        transport_stream: AnyByteStream,
        *,
        server_side: bool | None = None,
        hostname: str | None = None,
        ssl_context: ssl.SSLContext | None = None,
        standard_compatible: bool = True,
    ) -> TLSStream:
        """
        Wrap an existing stream with Transport Layer Security.

        This performs a TLS handshake with the peer.

        :param transport_stream: a bytes-transporting stream to wrap
        :param server_side: ``True`` if this is the server side of the connection,
            ``False`` if this is the client side (if omitted, will be set to ``False``
            if ``hostname`` has been provided, ``True`` otherwise). Used to pick the
            purpose of the default context (when an explicit context has not been
            provided) and passed on to :meth:`~ssl.SSLContext.wrap_bio`.
        :param hostname: host name of the peer (if host name checking is desired)
        :param ssl_context: the SSLContext object to use (if not provided, a secure
            default will be created)
        :param standard_compatible: if ``False``, skip the closing handshake when
            closing the connection, and don't raise an exception if the peer does the
            same
        :raises ~ssl.SSLError: if the TLS handshake fails

        """
        if server_side is None:
            # Only the connecting (client) side supplies the peer's host name
            server_side = not hostname

        if ssl_context is None:
            purpose = (
                ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH
            )
            ssl_context = ssl.create_default_context(purpose)

            # Re-enable detection of unexpected EOFs if it was disabled by Python
            if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
                ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF

        bio_in = ssl.MemoryBIO()
        bio_out = ssl.MemoryBIO()

        # External SSLContext implementations may do blocking I/O in wrap_bio(),
        # but the standard library implementation won't
        if type(ssl_context) is ssl.SSLContext:
            ssl_object = ssl_context.wrap_bio(
                bio_in, bio_out, server_side=server_side, server_hostname=hostname
            )
        else:
            ssl_object = await to_thread.run_sync(
                ssl_context.wrap_bio,
                bio_in,
                bio_out,
                server_side,
                hostname,
                None,
            )

        wrapper = cls(
            transport_stream=transport_stream,
            standard_compatible=standard_compatible,
            _ssl_object=ssl_object,
            _read_bio=bio_in,
            _write_bio=bio_out,
        )
        await wrapper._call_sslobject_method(ssl_object.do_handshake)
        return wrapper

    async def _call_sslobject_method(
        self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
    ) -> T_Retval:
        """
        Call an :class:`~ssl.SSLObject` method, shuttling ciphertext between the
        memory BIOs and the transport stream until the call completes.

        :param func: the SSLObject method to call
        :param args: positional arguments passed to ``func``
        :return: the return value of ``func`` once it no longer needs more I/O
        :raises BrokenResourceError: if the transport raises :exc:`OSError`, on
            :exc:`ssl.SSLSyscallError`, or on an unexpected EOF from the peer when
            ``standard_compatible`` is ``True``
        :raises EndOfStream: on an unexpected EOF from the peer when
            ``standard_compatible`` is ``False``
        :raises ~ssl.SSLError: for any other TLS error

        """
        while True:
            try:
                result = func(*args)
            except ssl.SSLWantReadError:
                try:
                    # Flush any pending writes first
                    if self._write_bio.pending:
                        await self.transport_stream.send(self._write_bio.read())

                    data = await self.transport_stream.receive()
                except EndOfStream:
                    # Feed the EOF to the SSL object; the next loop iteration lets
                    # it decide whether that EOF was expected
                    self._read_bio.write_eof()
                except OSError as exc:
                    self._read_bio.write_eof()
                    self._write_bio.write_eof()
                    raise BrokenResourceError from exc
                else:
                    self._read_bio.write(data)
            except ssl.SSLWantWriteError:
                await self.transport_stream.send(self._write_bio.read())
            except ssl.SSLSyscallError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                raise BrokenResourceError from exc
            except ssl.SSLError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                # The peer closed the connection without a closing handshake
                if isinstance(exc, ssl.SSLEOFError) or (
                    exc.strerror and "UNEXPECTED_EOF_WHILE_READING" in exc.strerror
                ):
                    if self.standard_compatible:
                        raise BrokenResourceError from exc
                    else:
                        raise EndOfStream from None

                raise
            else:
                # Flush any pending writes first
                if self._write_bio.pending:
                    await self.transport_stream.send(self._write_bio.read())

                return result

    async def unwrap(self) -> tuple[AnyByteStream, bytes]:
        """
        Does the TLS closing handshake.

        :return: a tuple of (wrapped byte stream, bytes left in the read buffer)

        """
        await self._call_sslobject_method(self._ssl_object.unwrap)
        self._read_bio.write_eof()
        self._write_bio.write_eof()
        return self.transport_stream, self._read_bio.read()

    async def aclose(self) -> None:
        """
        Close the stream, doing the TLS closing handshake first if
        ``standard_compatible`` is ``True``.

        If the closing handshake fails, the transport stream is closed forcefully and
        the exception is re-raised.

        """
        if self.standard_compatible:
            try:
                await self.unwrap()
            except BaseException:
                await aclose_forcefully(self.transport_stream)
                raise

        await self.transport_stream.aclose()

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Receive and decrypt at most ``max_bytes`` bytes of application data.

        :raises EndOfStream: if the SSL object returns no data

        """
        data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
        if not data:
            raise EndOfStream

        return data

    async def send(self, item: bytes) -> None:
        """Encrypt ``item`` and send it over the transport stream."""
        await self._call_sslobject_method(self._ssl_object.write, item)

    async def send_eof(self) -> None:
        """
        Not supported for TLS streams.

        :raises NotImplementedError: always; the message mentions the required
            TLS version when the session uses a version older than TLSv1.3

        """
        tls_version = self.extra(TLSAttribute.tls_version)
        match = re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version)
        if match:
            major, minor = int(match.group(1)), int(match.group(2) or 0)
            if (major, minor) < (1, 3):
                raise NotImplementedError(
                    f"send_eof() requires at least TLSv1.3; current "
                    f"session uses {tls_version}"
                )

        raise NotImplementedError(
            "send_eof() has not yet been implemented for TLS streams"
        )

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """
        The transport stream's extra attributes, extended (and overridden where the
        keys collide) with the :class:`TLSAttribute` attributes of this stream.
        """
        return {
            **self.transport_stream.extra_attributes,
            TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
            TLSAttribute.channel_binding_tls_unique: (
                self._ssl_object.get_channel_binding
            ),
            TLSAttribute.cipher: self._ssl_object.cipher,
            TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False),
            TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert(
                True
            ),
            TLSAttribute.server_side: lambda: self._ssl_object.server_side,
            # Only the server side reports the ciphers offered by the client
            TLSAttribute.shared_ciphers: lambda: (
                self._ssl_object.shared_ciphers()
                if self._ssl_object.server_side
                else None
            ),
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
            TLSAttribute.ssl_object: lambda: self._ssl_object,
            TLSAttribute.tls_version: self._ssl_object.version,
        }

264 

265 

@dataclass(eq=False)
class TLSListener(Listener[TLSStream]):
    """
    A convenience listener that wraps another listener and auto-negotiates a TLS session
    on every accepted connection.

    If the TLS handshake times out or raises an exception,
    :meth:`handle_handshake_error` is called to do whatever post-mortem processing is
    deemed necessary.

    Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute.

    :param Listener listener: the listener to wrap
    :param ssl_context: the SSL context object
    :param standard_compatible: a flag passed through to :meth:`TLSStream.wrap`
    :param handshake_timeout: time limit for the TLS handshake
        (passed to :func:`~anyio.fail_after`)
    """

    listener: Listener[Any]
    ssl_context: ssl.SSLContext
    standard_compatible: bool = True
    handshake_timeout: float = 30

    @staticmethod
    async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None:
        """
        Handle an exception raised during the TLS handshake.

        This method does 3 things:

        #. Forcefully closes the original stream
        #. Logs the exception (unless it was a cancellation exception) using the
           ``anyio.streams.tls`` logger
        #. Reraises the exception if it was a base exception or a cancellation exception

        :param exc: the exception
        :param stream: the original stream

        """
        await aclose_forcefully(stream)

        is_cancellation = isinstance(exc, get_cancelled_exc_class())

        # Log all except cancellation exceptions
        if not is_cancellation:
            # CPython (as of 3.11.5) returns incorrect `sys.exc_info()` here when using
            # any asyncio implementation, so we explicitly pass the exception to log
            # (https://github.com/python/cpython/issues/108668). Trio does not have this
            # issue because it works around the CPython bug.
            logging.getLogger(__name__).exception(
                "Error during TLS handshake", exc_info=exc
            )

        # Only reraise base exceptions and cancellation exceptions. ``exc`` is raised
        # explicitly instead of using a bare ``raise``: a bare ``raise`` depends on the
        # same "currently handled exception" state that the note above describes as
        # unreliable, and it fails outright when this method is called outside of an
        # ``except`` block.
        if is_cancellation or not isinstance(exc, Exception):
            raise exc

    async def serve(
        self,
        handler: Callable[[TLSStream], Any],
        task_group: TaskGroup | None = None,
    ) -> None:
        """
        Serve connections from the wrapped listener, doing the TLS handshake on each
        accepted stream before passing the resulting :class:`TLSStream` to
        ``handler``.

        :param handler: called with each successfully wrapped stream (and awaited)
        :param task_group: passed through to the wrapped listener's ``serve()``

        """

        @wraps(handler)
        async def handler_wrapper(stream: AnyByteStream) -> None:
            from .. import fail_after

            try:
                with fail_after(self.handshake_timeout):
                    wrapped_stream = await TLSStream.wrap(
                        stream,
                        ssl_context=self.ssl_context,
                        standard_compatible=self.standard_compatible,
                    )
            except BaseException as exc:
                await self.handle_handshake_error(exc, stream)
            else:
                await handler(wrapped_stream)

        await self.listener.serve(handler_wrapper, task_group)

    async def aclose(self) -> None:
        """Close the wrapped listener."""
        await self.listener.aclose()

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """Only :attr:`TLSAttribute.standard_compatible` is provided."""
        return {
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
        }