Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/anyio/streams/tls.py: 39%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from __future__ import annotations
# Names exported by ``from anyio.streams.tls import *`` (alphabetical order)
__all__ = ("TLSAttribute", "TLSConnectable", "TLSListener", "TLSStream")
10import logging
11import re
12import ssl
13import sys
14from collections.abc import Callable, Mapping
15from dataclasses import dataclass
16from functools import wraps
17from ssl import SSLContext
18from typing import Any, TypeAlias, TypeVar
20from .. import (
21 BrokenResourceError,
22 EndOfStream,
23 aclose_forcefully,
24 get_cancelled_exc_class,
25 to_thread,
26)
27from .._core._typedattr import TypedAttributeSet, typed_attribute
28from ..abc import (
29 AnyByteStream,
30 AnyByteStreamConnectable,
31 ByteStream,
32 ByteStreamConnectable,
33 Listener,
34 TaskGroup,
35)
37if sys.version_info >= (3, 11):
38 from typing import TypeVarTuple, Unpack
39else:
40 from typing_extensions import TypeVarTuple, Unpack
42if sys.version_info >= (3, 12):
43 from typing import override
44else:
45 from typing_extensions import override
47T_Retval = TypeVar("T_Retval")
48PosArgsT = TypeVarTuple("PosArgsT")
49_PCTRTT: TypeAlias = tuple[tuple[str, str], ...]
50_PCTRTTT: TypeAlias = tuple[_PCTRTT, ...]
class TLSAttribute(TypedAttributeSet):
    """Contains Transport Layer Security related attributes."""

    #: the selected ALPN protocol
    alpn_protocol: str | None = typed_attribute()
    #: the channel binding for type ``tls-unique``
    channel_binding_tls_unique: bytes = typed_attribute()
    #: the selected cipher
    cipher: tuple[str, str, int] = typed_attribute()
    #: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert`
    #: for more information)
    peer_certificate: None | (dict[str, str | _PCTRTTT | _PCTRTT]) = typed_attribute()
    #: the peer certificate in binary form
    peer_certificate_binary: bytes | None = typed_attribute()
    #: ``True`` if this is the server side of the connection
    server_side: bool = typed_attribute()
    #: ciphers shared by the client during the TLS handshake (``None`` if this is the
    #: client side)
    shared_ciphers: list[tuple[str, str, int]] | None = typed_attribute()
    #: the :class:`~ssl.SSLObject` used for encryption
    ssl_object: ssl.SSLObject = typed_attribute()
    #: ``True`` if this stream does (and expects) a closing TLS handshake when the
    #: stream is being closed
    standard_compatible: bool = typed_attribute()
    #: the TLS protocol version (e.g. ``TLSv1.2``)
    tls_version: str = typed_attribute()
@dataclass(eq=False)
class TLSStream(ByteStream):
    """
    A stream wrapper that encrypts all sent data and decrypts received data.

    This class has no public initializer; use :meth:`wrap` instead.
    All extra attributes from :class:`~TLSAttribute` are supported.

    :var AnyByteStream transport_stream: the wrapped stream

    """

    transport_stream: AnyByteStream
    standard_compatible: bool
    # The TLS state machine. It is created with ``wrap_bio()`` and therefore only
    # reads/writes the two memory BIOs below; all network I/O is done by this class
    # through ``transport_stream``.
    _ssl_object: ssl.SSLObject
    # Ciphertext received from the transport, waiting to be consumed by _ssl_object
    _read_bio: ssl.MemoryBIO
    # Ciphertext produced by _ssl_object, waiting to be sent over the transport
    _write_bio: ssl.MemoryBIO

    @classmethod
    async def wrap(
        cls,
        transport_stream: AnyByteStream,
        *,
        server_side: bool | None = None,
        hostname: str | None = None,
        ssl_context: ssl.SSLContext | None = None,
        standard_compatible: bool = True,
    ) -> TLSStream:
        """
        Wrap an existing stream with Transport Layer Security.

        This performs a TLS handshake with the peer.

        :param transport_stream: a bytes-transporting stream to wrap
        :param server_side: ``True`` if this is the server side of the connection,
            ``False`` if this is the client side (if omitted, will be set to ``False``
            if ``hostname`` has been provided, ``True`` otherwise). Used to choose the
            purpose of the default context (when an explicit context has not been
            provided) and passed on to :meth:`~ssl.SSLContext.wrap_bio`.
        :param hostname: host name of the peer (if host name checking is desired)
        :param ssl_context: the SSLContext object to use (if not provided, a secure
            default will be created)
        :param standard_compatible: if ``False``, skip the closing handshake when
            closing the connection, and don't raise an exception if the peer does the
            same
        :raises ~ssl.SSLError: if the TLS handshake fails

        """
        if server_side is None:
            # Without a peer host name to verify, assume we are the server side
            server_side = not hostname

        if not ssl_context:
            purpose = (
                ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH
            )
            ssl_context = ssl.create_default_context(purpose)

            # Re-enable detection of unexpected EOFs if it was disabled by Python
            if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
                ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF

        bio_in = ssl.MemoryBIO()
        bio_out = ssl.MemoryBIO()

        # External SSLContext implementations may do blocking I/O in wrap_bio(),
        # but the standard library implementation won't
        if type(ssl_context) is ssl.SSLContext:
            ssl_object = ssl_context.wrap_bio(
                bio_in, bio_out, server_side=server_side, server_hostname=hostname
            )
        else:
            ssl_object = await to_thread.run_sync(
                ssl_context.wrap_bio,
                bio_in,
                bio_out,
                server_side,
                hostname,
                None,
            )

        wrapper = cls(
            transport_stream=transport_stream,
            standard_compatible=standard_compatible,
            _ssl_object=ssl_object,
            _read_bio=bio_in,
            _write_bio=bio_out,
        )
        await wrapper._call_sslobject_method(ssl_object.do_handshake)
        return wrapper

    async def _call_sslobject_method(
        self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
    ) -> T_Retval:
        """
        Call an :class:`~ssl.SSLObject` method, pumping ciphertext as needed.

        The call is retried until it succeeds. On each retry, data is moved between
        the memory BIOs and ``transport_stream``.

        :raises BrokenResourceError: if the transport fails with an OS error, on an
            ``SSLSyscallError``, or on an unexpected EOF while
            ``standard_compatible`` is ``True``
        :raises EndOfStream: on an unexpected EOF while ``standard_compatible`` is
            ``False``

        """
        while True:
            try:
                result = func(*args)
            except ssl.SSLWantReadError:
                # The SSL object needs more incoming ciphertext before it can proceed
                try:
                    # Flush any pending writes first
                    if self._write_bio.pending:
                        await self.transport_stream.send(self._write_bio.read())

                    data = await self.transport_stream.receive()
                except EndOfStream:
                    # Let the SSL object see the EOF on the next iteration
                    self._read_bio.write_eof()
                except OSError as exc:
                    self._read_bio.write_eof()
                    self._write_bio.write_eof()
                    raise BrokenResourceError from exc
                else:
                    self._read_bio.write(data)
            except ssl.SSLWantWriteError:
                await self.transport_stream.send(self._write_bio.read())
            except ssl.SSLSyscallError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                raise BrokenResourceError from exc
            except ssl.SSLError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                if isinstance(exc, ssl.SSLEOFError) or (
                    exc.strerror and "UNEXPECTED_EOF_WHILE_READING" in exc.strerror
                ):
                    # The peer closed without a closing handshake
                    if self.standard_compatible:
                        raise BrokenResourceError from exc
                    else:
                        raise EndOfStream from None

                raise
            else:
                # Flush any pending writes first
                if self._write_bio.pending:
                    await self.transport_stream.send(self._write_bio.read())

                return result

    async def unwrap(self) -> tuple[AnyByteStream, bytes]:
        """
        Does the TLS closing handshake.

        :return: a tuple of (wrapped byte stream, bytes left in the read buffer)

        """
        await self._call_sslobject_method(self._ssl_object.unwrap)
        self._read_bio.write_eof()
        self._write_bio.write_eof()
        return self.transport_stream, self._read_bio.read()

    async def aclose(self) -> None:
        """
        Close the stream and the transport stream under it.

        If ``standard_compatible`` is set, the closing handshake is done first. If
        that fails, the transport is closed forcefully and the error is reraised.

        """
        if self.standard_compatible:
            try:
                await self.unwrap()
            except BaseException:
                await aclose_forcefully(self.transport_stream)
                raise

        await self.transport_stream.aclose()

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Receive and decrypt up to ``max_bytes`` bytes of application data.

        :raises EndOfStream: if the TLS read returned no data

        """
        data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
        if not data:
            raise EndOfStream

        return data

    async def send(self, item: bytes) -> None:
        """Encrypt ``item`` and send it over the transport stream."""
        await self._call_sslobject_method(self._ssl_object.write, item)

    async def send_eof(self) -> None:
        """
        Not supported for TLS streams.

        :raises NotImplementedError: always; the message says so explicitly when
            the negotiated protocol is older than TLS 1.3

        """
        tls_version = self.extra(TLSAttribute.tls_version)
        # SSLObject.version() may return None (e.g. before a handshake has
        # completed); re.match() would then raise TypeError instead of the
        # NotImplementedError this method promises
        match = (
            re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version) if tls_version else None
        )
        if match:
            major, minor = int(match.group(1)), int(match.group(2) or 0)
            if (major, minor) < (1, 3):
                raise NotImplementedError(
                    f"send_eof() requires at least TLSv1.3; current "
                    f"session uses {tls_version}"
                )

        raise NotImplementedError(
            "send_eof() has not yet been implemented for TLS streams"
        )

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        # The transport's attributes come first so the TLS-level ones take priority
        return {
            **self.transport_stream.extra_attributes,
            TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
            TLSAttribute.channel_binding_tls_unique: (
                self._ssl_object.get_channel_binding
            ),
            TLSAttribute.cipher: self._ssl_object.cipher,
            TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False),
            TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert(
                True
            ),
            TLSAttribute.server_side: lambda: self._ssl_object.server_side,
            TLSAttribute.shared_ciphers: lambda: (
                self._ssl_object.shared_ciphers()
                if self._ssl_object.server_side
                else None
            ),
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
            TLSAttribute.ssl_object: lambda: self._ssl_object,
            TLSAttribute.tls_version: self._ssl_object.version,
        }
@dataclass(eq=False)
class TLSListener(Listener[TLSStream]):
    """
    A convenience listener that wraps another listener and auto-negotiates a TLS session
    on every accepted connection.

    If the TLS handshake times out or raises an exception,
    :meth:`handle_handshake_error` is called to do whatever post-mortem processing is
    deemed necessary.

    Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute.

    :param Listener listener: the listener to wrap
    :param ssl_context: the SSL context object
    :param standard_compatible: a flag passed through to :meth:`TLSStream.wrap`
    :param handshake_timeout: time limit for the TLS handshake
        (passed to :func:`~anyio.fail_after`)
    """

    listener: Listener[Any]
    ssl_context: ssl.SSLContext
    standard_compatible: bool = True
    handshake_timeout: float = 30

    @staticmethod
    async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None:
        """
        Handle an exception raised during the TLS handshake.

        This method does 3 things:

        #. Forcefully closes the original stream
        #. Logs the exception (unless it was a cancellation exception) using the
           ``anyio.streams.tls`` logger
        #. Reraises the exception if it was a base exception or a cancellation exception

        :param exc: the exception
        :param stream: the original stream

        """
        await aclose_forcefully(stream)

        # Log all except cancellation exceptions
        if not isinstance(exc, get_cancelled_exc_class()):
            # CPython (as of 3.11.5) returns incorrect `sys.exc_info()` here when using
            # any asyncio implementation, so we explicitly pass the exception to log
            # (https://github.com/python/cpython/issues/108668). Trio does not have this
            # issue because it works around the CPython bug.
            logging.getLogger(__name__).exception(
                "Error during TLS handshake", exc_info=exc
            )

        # Only reraise base exceptions and cancellation exceptions.
        # Raise the given exception object explicitly: a bare ``raise`` reads the
        # same "currently handled exception" state as ``sys.exc_info()`` (unreliable
        # here, see above), and it fails outright when this method is called outside
        # an ``except`` block.
        if not isinstance(exc, Exception) or isinstance(exc, get_cancelled_exc_class()):
            raise exc

    async def serve(
        self,
        handler: Callable[[TLSStream], Any],
        task_group: TaskGroup | None = None,
    ) -> None:
        @wraps(handler)
        async def handler_wrapper(stream: AnyByteStream) -> None:
            from .. import fail_after

            try:
                with fail_after(self.handshake_timeout):
                    wrapped_stream = await TLSStream.wrap(
                        stream,
                        # Accepted connections are always the TLS server side
                        server_side=True,
                        ssl_context=self.ssl_context,
                        standard_compatible=self.standard_compatible,
                    )
            except BaseException as exc:
                await self.handle_handshake_error(exc, stream)
            else:
                await handler(wrapped_stream)

        await self.listener.serve(handler_wrapper, task_group)

    async def aclose(self) -> None:
        await self.listener.aclose()

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        return {
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
        }
class TLSConnectable(ByteStreamConnectable):
    """
    Wraps another connectable and does TLS negotiation after a successful connection.

    :param connectable: the connectable to wrap
    :param hostname: host name of the server (if host name checking is desired)
    :param ssl_context: the SSLContext object to use (if not provided, a secure default
        will be created)
    :param standard_compatible: if ``False``, skip the closing handshake when closing
        the connection, and don't raise an exception if the server does the same
    :raises TypeError: if ``ssl_context`` is not an :class:`ssl.SSLContext`
    """

    def __init__(
        self,
        connectable: AnyByteStreamConnectable,
        *,
        hostname: str | None = None,
        ssl_context: ssl.SSLContext | None = None,
        standard_compatible: bool = True,
    ) -> None:
        self.connectable = connectable
        self.ssl_context: SSLContext = ssl_context or ssl.create_default_context(
            ssl.Purpose.SERVER_AUTH
        )
        if not isinstance(self.ssl_context, ssl.SSLContext):
            raise TypeError(
                "ssl_context must be an instance of ssl.SSLContext, not "
                f"{type(self.ssl_context).__name__}"
            )
        self.hostname = hostname
        self.standard_compatible = standard_compatible

    @override
    async def connect(self) -> TLSStream:
        """
        Connect with the wrapped connectable and do the TLS handshake as the client.

        The underlying stream is closed forcefully if the handshake fails.

        """
        stream = await self.connectable.connect()
        try:
            return await TLSStream.wrap(
                stream,
                # We initiated the connection, so we are always the TLS client.
                # If this were omitted, TLSStream.wrap() would infer
                # server_side=True whenever hostname is None.
                server_side=False,
                hostname=self.hostname,
                ssl_context=self.ssl_context,
                standard_compatible=self.standard_compatible,
            )
        except BaseException:
            await aclose_forcefully(stream)
            raise