Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/anyio/streams/tls.py: 40%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from __future__ import annotations
# Public names of this module (what ``from ... import *`` exports).
__all__ = (
    "TLSAttribute",
    "TLSConnectable",
    "TLSListener",
    "TLSStream",
)
10import logging
11import re
12import ssl
13import sys
14from collections.abc import Callable, Mapping
15from dataclasses import dataclass
16from functools import wraps
17from ssl import SSLContext
18from typing import Any, TypeVar
20from .. import (
21 BrokenResourceError,
22 EndOfStream,
23 aclose_forcefully,
24 get_cancelled_exc_class,
25 to_thread,
26)
27from .._core._typedattr import TypedAttributeSet, typed_attribute
28from ..abc import (
29 AnyByteStream,
30 AnyByteStreamConnectable,
31 ByteStream,
32 ByteStreamConnectable,
33 Listener,
34 TaskGroup,
35)
37if sys.version_info >= (3, 10):
38 from typing import TypeAlias
39else:
40 from typing_extensions import TypeAlias
42if sys.version_info >= (3, 11):
43 from typing import TypeVarTuple, Unpack
44else:
45 from typing_extensions import TypeVarTuple, Unpack
47if sys.version_info >= (3, 12):
48 from typing import override
49else:
50 from typing_extensions import override
# Return type of the callable passed to ``TLSStream._call_sslobject_method()``.
T_Retval = TypeVar("T_Retval")
# Types of the positional arguments forwarded to that callable.
PosArgsT = TypeVarTuple("PosArgsT")
# Nested tuple shapes used in the annotation of ``TLSAttribute.peer_certificate``:
# a tuple of ``(str, str)`` pairs, and a tuple of such tuples.
_PCTRTT: TypeAlias = tuple[tuple[str, str], ...]
_PCTRTTT: TypeAlias = tuple[_PCTRTT, ...]
class TLSAttribute(TypedAttributeSet):
    """Contains Transport Layer Security related attributes."""

    #: the selected ALPN protocol
    alpn_protocol: str | None = typed_attribute()
    #: the channel binding for type ``tls-unique``
    channel_binding_tls_unique: bytes = typed_attribute()
    #: the selected cipher
    cipher: tuple[str, str, int] = typed_attribute()
    #: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert`
    #: for more information)
    peer_certificate: None | (dict[str, str | _PCTRTTT | _PCTRTT]) = typed_attribute()
    #: the peer certificate in binary form
    peer_certificate_binary: bytes | None = typed_attribute()
    #: ``True`` if this is the server side of the connection
    server_side: bool = typed_attribute()
    #: ciphers shared by the client during the TLS handshake (``None`` if this is the
    #: client side)
    shared_ciphers: list[tuple[str, str, int]] | None = typed_attribute()
    #: the :class:`~ssl.SSLObject` used for encryption
    ssl_object: ssl.SSLObject = typed_attribute()
    #: ``True`` if this stream does (and expects) a closing TLS handshake when the
    #: stream is being closed
    standard_compatible: bool = typed_attribute()
    #: the TLS protocol version (e.g. ``TLSv1.2``)
    tls_version: str = typed_attribute()
@dataclass(eq=False)
class TLSStream(ByteStream):
    """
    A stream wrapper that encrypts all sent data and decrypts received data.

    This class has no public initializer; use :meth:`wrap` instead.
    All extra attributes from :class:`~TLSAttribute` are supported.

    :var AnyByteStream transport_stream: the wrapped stream

    """

    # The underlying stream that carries the encrypted bytes
    transport_stream: AnyByteStream
    # Whether a closing TLS handshake is performed (and expected) on close
    standard_compatible: bool
    # The SSL object doing the actual encryption/decryption
    _ssl_object: ssl.SSLObject
    # Encrypted bytes received from the transport are written here
    _read_bio: ssl.MemoryBIO
    # Encrypted bytes to be sent over the transport are read from here
    _write_bio: ssl.MemoryBIO

    @classmethod
    async def wrap(
        cls,
        transport_stream: AnyByteStream,
        *,
        server_side: bool | None = None,
        hostname: str | None = None,
        ssl_context: ssl.SSLContext | None = None,
        standard_compatible: bool = True,
    ) -> TLSStream:
        """
        Wrap an existing stream with Transport Layer Security.

        This performs a TLS handshake with the peer.

        :param transport_stream: a bytes-transporting stream to wrap
        :param server_side: ``True`` if this is the server side of the connection,
            ``False`` if this is the client side (if omitted, will be set to ``False``
            if ``hostname`` has been provided, ``True`` otherwise). Passed on to
            ``wrap_bio()`` of the SSL context, and also selects the purpose of the
            default context when an explicit context has not been provided.
        :param hostname: host name of the peer (if host name checking is desired)
        :param ssl_context: the SSLContext object to use (if not provided, a secure
            default will be created)
        :param standard_compatible: if ``False``, skip the closing handshake when
            closing the connection, and don't raise an exception if the peer does the
            same
        :raises ~ssl.SSLError: if the TLS handshake fails

        """
        if server_side is None:
            server_side = not hostname

        if not ssl_context:
            purpose = (
                ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH
            )
            ssl_context = ssl.create_default_context(purpose)

            # Re-enable detection of unexpected EOFs if it was disabled by Python
            if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
                ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF

        bio_in = ssl.MemoryBIO()
        bio_out = ssl.MemoryBIO()

        # External SSLContext implementations may do blocking I/O in wrap_bio(),
        # but the standard library implementation won't
        if type(ssl_context) is ssl.SSLContext:
            ssl_object = ssl_context.wrap_bio(
                bio_in, bio_out, server_side=server_side, server_hostname=hostname
            )
        else:
            # Positional arguments: incoming, outgoing, server_side,
            # server_hostname, session
            ssl_object = await to_thread.run_sync(
                ssl_context.wrap_bio,
                bio_in,
                bio_out,
                server_side,
                hostname,
                None,
            )

        wrapper = cls(
            transport_stream=transport_stream,
            standard_compatible=standard_compatible,
            _ssl_object=ssl_object,
            _read_bio=bio_in,
            _write_bio=bio_out,
        )
        await wrapper._call_sslobject_method(ssl_object.do_handshake)
        return wrapper

    async def _call_sslobject_method(
        self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
    ) -> T_Retval:
        """
        Call an SSL object method, shuttling bytes between the memory BIOs and the
        transport stream until the call completes.

        The call is retried in a loop: when it wants more input, pending output is
        sent and more data is received from the transport stream into the read BIO;
        when it wants to write, the contents of the write BIO are sent.

        :param func: the method to call
        :param args: positional arguments to call the method with
        :return: the return value of the method
        :raises BrokenResourceError: if the transport stream raises
            :exc:`OSError` while waiting for input, if the call raises
            :exc:`ssl.SSLSyscallError`, or if the peer closed the connection
            unexpectedly and ``standard_compatible`` is ``True``
        :raises EndOfStream: if the peer closed the connection unexpectedly and
            ``standard_compatible`` is ``False``
        :raises ~ssl.SSLError: on any other TLS error

        """
        while True:
            try:
                result = func(*args)
            except ssl.SSLWantReadError:
                try:
                    # Flush any pending writes first
                    if self._write_bio.pending:
                        await self.transport_stream.send(self._write_bio.read())

                    data = await self.transport_stream.receive()
                except EndOfStream:
                    # Signal the EOF to the SSL object and retry the call
                    self._read_bio.write_eof()
                except OSError as exc:
                    self._read_bio.write_eof()
                    self._write_bio.write_eof()
                    raise BrokenResourceError from exc
                else:
                    self._read_bio.write(data)
            except ssl.SSLWantWriteError:
                await self.transport_stream.send(self._write_bio.read())
            except ssl.SSLSyscallError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                raise BrokenResourceError from exc
            except ssl.SSLError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                # An unexpected EOF is reported either as SSLEOFError or as a plain
                # SSLError whose message names UNEXPECTED_EOF_WHILE_READING
                if isinstance(exc, ssl.SSLEOFError) or (
                    exc.strerror and "UNEXPECTED_EOF_WHILE_READING" in exc.strerror
                ):
                    if self.standard_compatible:
                        raise BrokenResourceError from exc
                    else:
                        raise EndOfStream from None

                raise
            else:
                # Flush any pending writes first
                if self._write_bio.pending:
                    await self.transport_stream.send(self._write_bio.read())

                return result

    async def unwrap(self) -> tuple[AnyByteStream, bytes]:
        """
        Does the TLS closing handshake.

        :return: a tuple of (wrapped byte stream, bytes left in the read buffer)

        """
        await self._call_sslobject_method(self._ssl_object.unwrap)
        self._read_bio.write_eof()
        self._write_bio.write_eof()
        return self.transport_stream, self._read_bio.read()

    async def aclose(self) -> None:
        """
        Close the stream.

        If ``standard_compatible`` is ``True``, the closing handshake is done first
        (see :meth:`unwrap`); if that fails, the transport stream is closed
        forcefully and the exception is propagated. Otherwise, or after a successful
        closing handshake, the transport stream is closed normally.
        """
        if self.standard_compatible:
            try:
                await self.unwrap()
            except BaseException:
                await aclose_forcefully(self.transport_stream)
                raise

        await self.transport_stream.aclose()

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Receive and decrypt data from the peer.

        :param max_bytes: maximum number of bytes passed to the SSL object's
            ``read()`` method
        :return: the decrypted bytes
        :raises EndOfStream: if the read returned no data

        """
        data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
        if not data:
            raise EndOfStream

        return data

    async def send(self, item: bytes) -> None:
        """
        Encrypt the given bytes and send them to the peer.

        :param item: the bytes to send

        """
        await self._call_sslobject_method(self._ssl_object.write, item)

    async def send_eof(self) -> None:
        """
        Not implemented for TLS streams.

        :raises NotImplementedError: always; the message depends on whether the
            negotiated TLS version is older than TLSv1.3

        """
        tls_version = self.extra(TLSAttribute.tls_version)
        match = re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version)
        if match:
            major, minor = int(match.group(1)), int(match.group(2) or 0)
            if (major, minor) < (1, 3):
                raise NotImplementedError(
                    f"send_eof() requires at least TLSv1.3; current "
                    f"session uses {tls_version}"
                )

        raise NotImplementedError(
            "send_eof() has not yet been implemented for TLS streams"
        )

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """
        The extra attributes of the transport stream, plus those in
        :class:`TLSAttribute` (which take precedence on key collisions).
        """
        return {
            **self.transport_stream.extra_attributes,
            TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
            TLSAttribute.channel_binding_tls_unique: (
                self._ssl_object.get_channel_binding
            ),
            TLSAttribute.cipher: self._ssl_object.cipher,
            TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False),
            TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert(
                True
            ),
            TLSAttribute.server_side: lambda: self._ssl_object.server_side,
            TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers()
            if self._ssl_object.server_side
            else None,
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
            TLSAttribute.ssl_object: lambda: self._ssl_object,
            TLSAttribute.tls_version: self._ssl_object.version,
        }
@dataclass(eq=False)
class TLSListener(Listener[TLSStream]):
    """
    A convenience listener that wraps another listener and auto-negotiates a TLS session
    on every accepted connection.

    If the TLS handshake times out or raises an exception,
    :meth:`handle_handshake_error` is called to do whatever post-mortem processing is
    deemed necessary.

    Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute.

    :param Listener listener: the listener to wrap
    :param ssl_context: the SSL context object
    :param standard_compatible: a flag passed through to :meth:`TLSStream.wrap`
    :param handshake_timeout: time limit for the TLS handshake
        (passed to :func:`~anyio.fail_after`)
    """

    listener: Listener[Any]
    ssl_context: ssl.SSLContext
    standard_compatible: bool = True
    handshake_timeout: float = 30

    @staticmethod
    async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None:
        """
        Handle an exception raised during the TLS handshake.

        This method does 3 things:

        #. Forcefully closes the original stream
        #. Logs the exception (unless it was a cancellation exception) using the
           ``anyio.streams.tls`` logger
        #. Reraises the exception if it was a base exception or a cancellation exception

        :param exc: the exception
        :param stream: the original stream

        """
        await aclose_forcefully(stream)

        # Log all except cancellation exceptions
        if not isinstance(exc, get_cancelled_exc_class()):
            # CPython (as of 3.11.5) returns incorrect `sys.exc_info()` here when using
            # any asyncio implementation, so we explicitly pass the exception to log
            # (https://github.com/python/cpython/issues/108668). Trio does not have this
            # issue because it works around the CPython bug.
            logging.getLogger(__name__).exception(
                "Error during TLS handshake", exc_info=exc
            )

        # Only reraise base exceptions and cancellation exceptions
        if not isinstance(exc, Exception) or isinstance(exc, get_cancelled_exc_class()):
            # Raise the passed exception explicitly rather than with a bare ``raise``:
            # a bare ``raise`` relies on the interpreter's "currently handled
            # exception" state (see the note above), and would fail with RuntimeError
            # if this method were called outside of an ``except`` block
            raise exc

    async def serve(
        self,
        handler: Callable[[TLSStream], Any],
        task_group: TaskGroup | None = None,
    ) -> None:
        """
        Serve connections from the wrapped listener, doing the TLS handshake on each
        accepted stream before passing it to ``handler``.

        The handshake is limited to ``handshake_timeout`` seconds. If it fails for
        any reason, :meth:`handle_handshake_error` is called and ``handler`` is not
        invoked for that connection.

        :param handler: a callable that is awaited with each successfully wrapped
            :class:`TLSStream`
        :param task_group: passed through to the wrapped listener's ``serve()``

        """

        @wraps(handler)
        async def handler_wrapper(stream: AnyByteStream) -> None:
            from .. import fail_after

            try:
                with fail_after(self.handshake_timeout):
                    wrapped_stream = await TLSStream.wrap(
                        stream,
                        ssl_context=self.ssl_context,
                        standard_compatible=self.standard_compatible,
                    )
            except BaseException as exc:
                await self.handle_handshake_error(exc, stream)
            else:
                await handler(wrapped_stream)

        await self.listener.serve(handler_wrapper, task_group)

    async def aclose(self) -> None:
        """Close the wrapped listener."""
        await self.listener.aclose()

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """Expose only :attr:`~TLSAttribute.standard_compatible`."""
        return {
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
        }
class TLSConnectable(ByteStreamConnectable):
    """
    Wraps another connectable and does TLS negotiation after a successful connection.

    :param connectable: the connectable to wrap
    :param hostname: host name of the server (if host name checking is desired)
    :param ssl_context: the SSLContext object to use (if not provided, a secure default
        will be created)
    :param standard_compatible: if ``False``, skip the closing handshake when closing
        the connection, and don't raise an exception if the server does the same
    :raises TypeError: if ``ssl_context`` is not an :class:`ssl.SSLContext` instance
    """

    def __init__(
        self,
        connectable: AnyByteStreamConnectable,
        *,
        hostname: str | None = None,
        ssl_context: ssl.SSLContext | None = None,
        standard_compatible: bool = True,
    ) -> None:
        self.connectable = connectable
        if not ssl_context:
            ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)

            # Re-enable detection of unexpected EOFs if it was disabled by Python.
            # TLSStream.wrap() only does this for a context it creates itself, and
            # the context created here is passed to it explicitly, so it has to be
            # done here for the default context to behave the same way.
            if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
                ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF

        self.ssl_context: SSLContext = ssl_context
        if not isinstance(self.ssl_context, ssl.SSLContext):
            raise TypeError(
                "ssl_context must be an instance of ssl.SSLContext, not "
                f"{type(self.ssl_context).__name__}"
            )
        self.hostname = hostname
        self.standard_compatible = standard_compatible

    @override
    async def connect(self) -> TLSStream:
        """
        Connect using the wrapped connectable and do the TLS handshake as a client.

        If the handshake fails for any reason, the freshly connected stream is
        closed forcefully and the exception is propagated.

        :return: the TLS-wrapped stream

        """
        stream = await self.connectable.connect()
        try:
            return await TLSStream.wrap(
                stream,
                hostname=self.hostname,
                ssl_context=self.ssl_context,
                standard_compatible=self.standard_compatible,
            )
        except BaseException:
            await aclose_forcefully(stream)
            raise