Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/anyio/streams/tls.py: 40%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

144 statements  

1from __future__ import annotations 

2 

3import logging 

4import re 

5import ssl 

6import sys 

7from collections.abc import Callable, Mapping 

8from dataclasses import dataclass 

9from functools import wraps 

10from typing import Any, Tuple, TypeVar 

11 

12from .. import ( 

13 BrokenResourceError, 

14 EndOfStream, 

15 aclose_forcefully, 

16 get_cancelled_exc_class, 

17) 

18from .._core._typedattr import TypedAttributeSet, typed_attribute 

19from ..abc import AnyByteStream, ByteStream, Listener, TaskGroup 

20 

21if sys.version_info >= (3, 11): 

22 from typing import TypeVarTuple, Unpack 

23else: 

24 from typing_extensions import TypeVarTuple, Unpack 

25 

26T_Retval = TypeVar("T_Retval") 

27PosArgsT = TypeVarTuple("PosArgsT") 

28_PCTRTT = Tuple[Tuple[str, str], ...] 

29_PCTRTTT = Tuple[_PCTRTT, ...] 

30 

31 

32class TLSAttribute(TypedAttributeSet): 

33 """Contains Transport Layer Security related attributes.""" 

34 

35 #: the selected ALPN protocol 

36 alpn_protocol: str | None = typed_attribute() 

37 #: the channel binding for type ``tls-unique`` 

38 channel_binding_tls_unique: bytes = typed_attribute() 

39 #: the selected cipher 

40 cipher: tuple[str, str, int] = typed_attribute() 

41 #: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert` 

42 # for more information) 

43 peer_certificate: None | (dict[str, str | _PCTRTTT | _PCTRTT]) = typed_attribute() 

44 #: the peer certificate in binary form 

45 peer_certificate_binary: bytes | None = typed_attribute() 

46 #: ``True`` if this is the server side of the connection 

47 server_side: bool = typed_attribute() 

48 #: ciphers shared by the client during the TLS handshake (``None`` if this is the 

49 #: client side) 

50 shared_ciphers: list[tuple[str, str, int]] | None = typed_attribute() 

51 #: the :class:`~ssl.SSLObject` used for encryption 

52 ssl_object: ssl.SSLObject = typed_attribute() 

53 #: ``True`` if this stream does (and expects) a closing TLS handshake when the 

54 #: stream is being closed 

55 standard_compatible: bool = typed_attribute() 

56 #: the TLS protocol version (e.g. ``TLSv1.2``) 

57 tls_version: str = typed_attribute() 

58 

59 

60@dataclass(eq=False) 

61class TLSStream(ByteStream): 

62 """ 

63 A stream wrapper that encrypts all sent data and decrypts received data. 

64 

65 This class has no public initializer; use :meth:`wrap` instead. 

66 All extra attributes from :class:`~TLSAttribute` are supported. 

67 

68 :var AnyByteStream transport_stream: the wrapped stream 

69 

70 """ 

71 

72 transport_stream: AnyByteStream 

73 standard_compatible: bool 

74 _ssl_object: ssl.SSLObject 

75 _read_bio: ssl.MemoryBIO 

76 _write_bio: ssl.MemoryBIO 

77 

78 @classmethod 

79 async def wrap( 

80 cls, 

81 transport_stream: AnyByteStream, 

82 *, 

83 server_side: bool | None = None, 

84 hostname: str | None = None, 

85 ssl_context: ssl.SSLContext | None = None, 

86 standard_compatible: bool = True, 

87 ) -> TLSStream: 

88 """ 

89 Wrap an existing stream with Transport Layer Security. 

90 

91 This performs a TLS handshake with the peer. 

92 

93 :param transport_stream: a bytes-transporting stream to wrap 

94 :param server_side: ``True`` if this is the server side of the connection, 

95 ``False`` if this is the client side (if omitted, will be set to ``False`` 

96 if ``hostname`` has been provided, ``False`` otherwise). Used only to create 

97 a default context when an explicit context has not been provided. 

98 :param hostname: host name of the peer (if host name checking is desired) 

99 :param ssl_context: the SSLContext object to use (if not provided, a secure 

100 default will be created) 

101 :param standard_compatible: if ``False``, skip the closing handshake when 

102 closing the connection, and don't raise an exception if the peer does the 

103 same 

104 :raises ~ssl.SSLError: if the TLS handshake fails 

105 

106 """ 

107 if server_side is None: 

108 server_side = not hostname 

109 

110 if not ssl_context: 

111 purpose = ( 

112 ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH 

113 ) 

114 ssl_context = ssl.create_default_context(purpose) 

115 

116 # Re-enable detection of unexpected EOFs if it was disabled by Python 

117 if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"): 

118 ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF 

119 

120 bio_in = ssl.MemoryBIO() 

121 bio_out = ssl.MemoryBIO() 

122 ssl_object = ssl_context.wrap_bio( 

123 bio_in, bio_out, server_side=server_side, server_hostname=hostname 

124 ) 

125 wrapper = cls( 

126 transport_stream=transport_stream, 

127 standard_compatible=standard_compatible, 

128 _ssl_object=ssl_object, 

129 _read_bio=bio_in, 

130 _write_bio=bio_out, 

131 ) 

132 await wrapper._call_sslobject_method(ssl_object.do_handshake) 

133 return wrapper 

134 

135 async def _call_sslobject_method( 

136 self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT] 

137 ) -> T_Retval: 

138 while True: 

139 try: 

140 result = func(*args) 

141 except ssl.SSLWantReadError: 

142 try: 

143 # Flush any pending writes first 

144 if self._write_bio.pending: 

145 await self.transport_stream.send(self._write_bio.read()) 

146 

147 data = await self.transport_stream.receive() 

148 except EndOfStream: 

149 self._read_bio.write_eof() 

150 except OSError as exc: 

151 self._read_bio.write_eof() 

152 self._write_bio.write_eof() 

153 raise BrokenResourceError from exc 

154 else: 

155 self._read_bio.write(data) 

156 except ssl.SSLWantWriteError: 

157 await self.transport_stream.send(self._write_bio.read()) 

158 except ssl.SSLSyscallError as exc: 

159 self._read_bio.write_eof() 

160 self._write_bio.write_eof() 

161 raise BrokenResourceError from exc 

162 except ssl.SSLError as exc: 

163 self._read_bio.write_eof() 

164 self._write_bio.write_eof() 

165 if ( 

166 isinstance(exc, ssl.SSLEOFError) 

167 or "UNEXPECTED_EOF_WHILE_READING" in exc.strerror 

168 ): 

169 if self.standard_compatible: 

170 raise BrokenResourceError from exc 

171 else: 

172 raise EndOfStream from None 

173 

174 raise 

175 else: 

176 # Flush any pending writes first 

177 if self._write_bio.pending: 

178 await self.transport_stream.send(self._write_bio.read()) 

179 

180 return result 

181 

182 async def unwrap(self) -> tuple[AnyByteStream, bytes]: 

183 """ 

184 Does the TLS closing handshake. 

185 

186 :return: a tuple of (wrapped byte stream, bytes left in the read buffer) 

187 

188 """ 

189 await self._call_sslobject_method(self._ssl_object.unwrap) 

190 self._read_bio.write_eof() 

191 self._write_bio.write_eof() 

192 return self.transport_stream, self._read_bio.read() 

193 

194 async def aclose(self) -> None: 

195 if self.standard_compatible: 

196 try: 

197 await self.unwrap() 

198 except BaseException: 

199 await aclose_forcefully(self.transport_stream) 

200 raise 

201 

202 await self.transport_stream.aclose() 

203 

204 async def receive(self, max_bytes: int = 65536) -> bytes: 

205 data = await self._call_sslobject_method(self._ssl_object.read, max_bytes) 

206 if not data: 

207 raise EndOfStream 

208 

209 return data 

210 

211 async def send(self, item: bytes) -> None: 

212 await self._call_sslobject_method(self._ssl_object.write, item) 

213 

214 async def send_eof(self) -> None: 

215 tls_version = self.extra(TLSAttribute.tls_version) 

216 match = re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version) 

217 if match: 

218 major, minor = int(match.group(1)), int(match.group(2) or 0) 

219 if (major, minor) < (1, 3): 

220 raise NotImplementedError( 

221 f"send_eof() requires at least TLSv1.3; current " 

222 f"session uses {tls_version}" 

223 ) 

224 

225 raise NotImplementedError( 

226 "send_eof() has not yet been implemented for TLS streams" 

227 ) 

228 

229 @property 

230 def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: 

231 return { 

232 **self.transport_stream.extra_attributes, 

233 TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol, 

234 TLSAttribute.channel_binding_tls_unique: ( 

235 self._ssl_object.get_channel_binding 

236 ), 

237 TLSAttribute.cipher: self._ssl_object.cipher, 

238 TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False), 

239 TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert( 

240 True 

241 ), 

242 TLSAttribute.server_side: lambda: self._ssl_object.server_side, 

243 TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers() 

244 if self._ssl_object.server_side 

245 else None, 

246 TLSAttribute.standard_compatible: lambda: self.standard_compatible, 

247 TLSAttribute.ssl_object: lambda: self._ssl_object, 

248 TLSAttribute.tls_version: self._ssl_object.version, 

249 } 

250 

251 

252@dataclass(eq=False) 

253class TLSListener(Listener[TLSStream]): 

254 """ 

255 A convenience listener that wraps another listener and auto-negotiates a TLS session 

256 on every accepted connection. 

257 

258 If the TLS handshake times out or raises an exception, 

259 :meth:`handle_handshake_error` is called to do whatever post-mortem processing is 

260 deemed necessary. 

261 

262 Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute. 

263 

264 :param Listener listener: the listener to wrap 

265 :param ssl_context: the SSL context object 

266 :param standard_compatible: a flag passed through to :meth:`TLSStream.wrap` 

267 :param handshake_timeout: time limit for the TLS handshake 

268 (passed to :func:`~anyio.fail_after`) 

269 """ 

270 

271 listener: Listener[Any] 

272 ssl_context: ssl.SSLContext 

273 standard_compatible: bool = True 

274 handshake_timeout: float = 30 

275 

276 @staticmethod 

277 async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None: 

278 """ 

279 Handle an exception raised during the TLS handshake. 

280 

281 This method does 3 things: 

282 

283 #. Forcefully closes the original stream 

284 #. Logs the exception (unless it was a cancellation exception) using the 

285 ``anyio.streams.tls`` logger 

286 #. Reraises the exception if it was a base exception or a cancellation exception 

287 

288 :param exc: the exception 

289 :param stream: the original stream 

290 

291 """ 

292 await aclose_forcefully(stream) 

293 

294 # Log all except cancellation exceptions 

295 if not isinstance(exc, get_cancelled_exc_class()): 

296 # CPython (as of 3.11.5) returns incorrect `sys.exc_info()` here when using 

297 # any asyncio implementation, so we explicitly pass the exception to log 

298 # (https://github.com/python/cpython/issues/108668). Trio does not have this 

299 # issue because it works around the CPython bug. 

300 logging.getLogger(__name__).exception( 

301 "Error during TLS handshake", exc_info=exc 

302 ) 

303 

304 # Only reraise base exceptions and cancellation exceptions 

305 if not isinstance(exc, Exception) or isinstance(exc, get_cancelled_exc_class()): 

306 raise 

307 

308 async def serve( 

309 self, 

310 handler: Callable[[TLSStream], Any], 

311 task_group: TaskGroup | None = None, 

312 ) -> None: 

313 @wraps(handler) 

314 async def handler_wrapper(stream: AnyByteStream) -> None: 

315 from .. import fail_after 

316 

317 try: 

318 with fail_after(self.handshake_timeout): 

319 wrapped_stream = await TLSStream.wrap( 

320 stream, 

321 ssl_context=self.ssl_context, 

322 standard_compatible=self.standard_compatible, 

323 ) 

324 except BaseException as exc: 

325 await self.handle_handshake_error(exc, stream) 

326 else: 

327 await handler(wrapped_stream) 

328 

329 await self.listener.serve(handler_wrapper, task_group) 

330 

331 async def aclose(self) -> None: 

332 await self.listener.aclose() 

333 

334 @property 

335 def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: 

336 return { 

337 TLSAttribute.standard_compatible: lambda: self.standard_compatible, 

338 }