Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/anyio/streams/tls.py: 39%
139 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:38 +0000
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:38 +0000
1from __future__ import annotations
3import logging
4import re
5import ssl
6from collections.abc import Callable, Mapping
7from dataclasses import dataclass
8from functools import wraps
9from typing import Any, Tuple, TypeVar
11from .. import (
12 BrokenResourceError,
13 EndOfStream,
14 aclose_forcefully,
15 get_cancelled_exc_class,
16)
17from .._core._typedattr import TypedAttributeSet, typed_attribute
18from ..abc import AnyByteStream, ByteStream, Listener, TaskGroup
# Return type of the SSL object method passed to
# TLSStream._call_sslobject_method(); lets that helper hand the result back typed.
T_Retval = TypeVar("T_Retval")
# Tuple of (str, str) pairs; used in the peer_certificate attribute annotation.
_PCTRTT = Tuple[Tuple[str, str], ...]
# Tuple of _PCTRTT tuples; also used in the peer_certificate attribute annotation.
_PCTRTTT = Tuple[_PCTRTT, ...]
class TLSAttribute(TypedAttributeSet):
    """Contains Transport Layer Security related attributes."""

    #: the selected ALPN protocol
    alpn_protocol: str | None = typed_attribute()
    #: the channel binding for type ``tls-unique``
    channel_binding_tls_unique: bytes = typed_attribute()
    #: the selected cipher
    cipher: tuple[str, str, int] = typed_attribute()
    #: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert`
    #: for more information)
    peer_certificate: None | (dict[str, str | _PCTRTTT | _PCTRTT]) = typed_attribute()
    #: the peer certificate in binary form
    peer_certificate_binary: bytes | None = typed_attribute()
    #: ``True`` if this is the server side of the connection
    server_side: bool = typed_attribute()
    #: ciphers shared by the client during the TLS handshake (``None`` if this is the
    #: client side)
    shared_ciphers: list[tuple[str, str, int]] | None = typed_attribute()
    #: the :class:`~ssl.SSLObject` used for encryption
    ssl_object: ssl.SSLObject = typed_attribute()
    #: ``True`` if this stream does (and expects) a closing TLS handshake when the
    #: stream is being closed
    standard_compatible: bool = typed_attribute()
    #: the TLS protocol version (e.g. ``TLSv1.2``)
    tls_version: str = typed_attribute()
@dataclass(eq=False)
class TLSStream(ByteStream):
    """
    A stream wrapper that encrypts all sent data and decrypts received data.

    This class has no public initializer; use :meth:`wrap` instead.
    All extra attributes from :class:`~TLSAttribute` are supported.

    :var AnyByteStream transport_stream: the wrapped stream

    """

    transport_stream: AnyByteStream
    standard_compatible: bool
    _ssl_object: ssl.SSLObject
    _read_bio: ssl.MemoryBIO
    _write_bio: ssl.MemoryBIO

    @classmethod
    async def wrap(
        cls,
        transport_stream: AnyByteStream,
        *,
        server_side: bool | None = None,
        hostname: str | None = None,
        ssl_context: ssl.SSLContext | None = None,
        standard_compatible: bool = True,
    ) -> TLSStream:
        """
        Wrap an existing stream with Transport Layer Security.

        This performs a TLS handshake with the peer.

        :param transport_stream: a bytes-transporting stream to wrap
        :param server_side: ``True`` if this is the server side of the connection,
            ``False`` if this is the client side (if omitted, will be set to ``False``
            if ``hostname`` has been provided, ``True`` otherwise). Also used to pick
            the purpose of the default context when an explicit context has not been
            provided.
        :param hostname: host name of the peer (if host name checking is desired)
        :param ssl_context: the SSLContext object to use (if not provided, a secure
            default will be created)
        :param standard_compatible: if ``False``, skip the closing handshake when
            closing the connection, and don't raise an exception if the peer does the
            same
        :raises ~ssl.SSLError: if the TLS handshake fails

        """
        if server_side is None:
            # No host name to verify implies we are the accepting (server) side
            server_side = not hostname

        if not ssl_context:
            purpose = (
                ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH
            )
            ssl_context = ssl.create_default_context(purpose)

            # Re-enable detection of unexpected EOFs if it was disabled by Python
            if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
                ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF

        # The SSL object never touches a socket: it reads ciphertext from bio_in and
        # writes ciphertext to bio_out, and this class moves those bytes to and from
        # the transport stream
        bio_in = ssl.MemoryBIO()
        bio_out = ssl.MemoryBIO()
        ssl_object = ssl_context.wrap_bio(
            bio_in, bio_out, server_side=server_side, server_hostname=hostname
        )
        wrapper = cls(
            transport_stream=transport_stream,
            standard_compatible=standard_compatible,
            _ssl_object=ssl_object,
            _read_bio=bio_in,
            _write_bio=bio_out,
        )
        await wrapper._call_sslobject_method(ssl_object.do_handshake)
        return wrapper

    async def _call_sslobject_method(
        self, func: Callable[..., T_Retval], *args: object
    ) -> T_Retval:
        """
        Call a method of the SSL object, doing transport I/O until it completes.

        The call is retried in a loop: whenever the SSL object asks for more input or
        needs its output drained, data is moved between the memory BIOs and the
        transport stream and the call is attempted again.

        :param func: the SSL object method to call
        :param args: positional arguments passed to ``func``
        :return: the return value of ``func``
        :raises BrokenResourceError: if the transport raises :exc:`OSError`, if an
            :exc:`ssl.SSLSyscallError` occurs, or if the peer closed the connection
            unexpectedly while ``standard_compatible`` is ``True``
        :raises EndOfStream: if the peer closed the connection unexpectedly while
            ``standard_compatible`` is ``False``
        :raises ~ssl.SSLError: for any other TLS level error

        """
        while True:
            try:
                result = func(*args)
            except ssl.SSLWantReadError:
                try:
                    # Flush any pending writes first
                    if self._write_bio.pending:
                        await self.transport_stream.send(self._write_bio.read())

                    data = await self.transport_stream.receive()
                except EndOfStream:
                    # Let the SSL object see the EOF on the next attempt
                    self._read_bio.write_eof()
                except OSError as exc:
                    self._read_bio.write_eof()
                    self._write_bio.write_eof()
                    raise BrokenResourceError from exc
                else:
                    self._read_bio.write(data)
            except ssl.SSLWantWriteError:
                await self.transport_stream.send(self._write_bio.read())
            except ssl.SSLSyscallError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                raise BrokenResourceError from exc
            except ssl.SSLError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                # strerror is None when the error was created from a single message
                # argument, so it must be checked before the substring test; otherwise
                # a TypeError would mask the real SSL error
                unexpected_eof = isinstance(exc, ssl.SSLEOFError) or (
                    exc.strerror is not None
                    and "UNEXPECTED_EOF_WHILE_READING" in exc.strerror
                )
                if unexpected_eof:
                    if self.standard_compatible:
                        raise BrokenResourceError from exc
                    else:
                        raise EndOfStream from None

                raise
            else:
                # Flush any pending writes first
                if self._write_bio.pending:
                    await self.transport_stream.send(self._write_bio.read())

                return result

    async def unwrap(self) -> tuple[AnyByteStream, bytes]:
        """
        Does the TLS closing handshake.

        :return: a tuple of (wrapped byte stream, bytes left in the read buffer)

        """
        await self._call_sslobject_method(self._ssl_object.unwrap)
        self._read_bio.write_eof()
        self._write_bio.write_eof()
        return self.transport_stream, self._read_bio.read()

    async def aclose(self) -> None:
        """
        Close the stream.

        If ``standard_compatible`` is ``True``, the TLS closing handshake is done
        first; if that fails, the transport stream is closed forcefully and the
        exception is reraised. Finally, the transport stream is closed.

        """
        if self.standard_compatible:
            try:
                await self.unwrap()
            except BaseException:
                await aclose_forcefully(self.transport_stream)
                raise

        await self.transport_stream.aclose()

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Receive and decrypt data from the peer.

        :param max_bytes: maximum number of bytes passed to the SSL object's
            ``read()``
        :return: the decrypted bytes
        :raises EndOfStream: if the SSL object returned no data

        """
        data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
        if not data:
            raise EndOfStream

        return data

    async def send(self, item: bytes) -> None:
        """
        Encrypt the given bytes and send them to the peer.

        :param item: the bytes to send

        """
        await self._call_sslobject_method(self._ssl_object.write, item)

    async def send_eof(self) -> None:
        """
        Not implemented for TLS streams.

        :raises NotImplementedError: always; the message says that at least TLSv1.3
            is required when the session uses an older protocol version

        """
        tls_version = self.extra(TLSAttribute.tls_version)
        match = re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version)
        if match:
            major, minor = int(match.group(1)), int(match.group(2) or 0)
            if (major, minor) < (1, 3):
                raise NotImplementedError(
                    f"send_eof() requires at least TLSv1.3; current "
                    f"session uses {tls_version}"
                )

        raise NotImplementedError(
            "send_eof() has not yet been implemented for TLS streams"
        )

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """
        Map each typed attribute to a callable that returns its current value.

        Starts from the transport stream's attributes and adds every
        :class:`TLSAttribute` member on top of them.

        """
        return {
            **self.transport_stream.extra_attributes,
            TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
            TLSAttribute.channel_binding_tls_unique: (
                self._ssl_object.get_channel_binding
            ),
            TLSAttribute.cipher: self._ssl_object.cipher,
            TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False),
            TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert(
                True
            ),
            TLSAttribute.server_side: lambda: self._ssl_object.server_side,
            # Only queried on the server side; the client side reports None
            TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers()
            if self._ssl_object.server_side
            else None,
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
            TLSAttribute.ssl_object: lambda: self._ssl_object,
            TLSAttribute.tls_version: self._ssl_object.version,
        }
@dataclass(eq=False)
class TLSListener(Listener[TLSStream]):
    """
    A convenience listener that wraps another listener and auto-negotiates a TLS session
    on every accepted connection.

    If the TLS handshake times out or raises an exception,
    :meth:`handle_handshake_error` is called to do whatever post-mortem processing is
    deemed necessary.

    Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute.

    :param Listener listener: the listener to wrap
    :param ssl_context: the SSL context object
    :param standard_compatible: a flag passed through to :meth:`TLSStream.wrap`
    :param handshake_timeout: time limit for the TLS handshake
        (passed to :func:`~anyio.fail_after`)
    """

    listener: Listener[Any]
    ssl_context: ssl.SSLContext
    standard_compatible: bool = True
    handshake_timeout: float = 30

    @staticmethod
    async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None:
        """
        Handle an exception raised during the TLS handshake.

        This method does 3 things:

        #. Forcefully closes the original stream
        #. Logs the exception (unless it was a cancellation exception) using the
           ``anyio.streams.tls`` logger
        #. Reraises the exception if it was a base exception or a cancellation exception

        :param exc: the exception
        :param stream: the original stream

        """
        await aclose_forcefully(stream)

        # Log all except cancellation exceptions
        if not isinstance(exc, get_cancelled_exc_class()):
            # CPython (as of 3.11.5) returns incorrect `sys.exc_info()` here when using
            # any asyncio implementation, so we explicitly pass the exception to log
            # (https://github.com/python/cpython/issues/108668). Trio does not have this
            # issue because it works around the CPython bug.
            logging.getLogger(__name__).exception(
                "Error during TLS handshake", exc_info=exc
            )

        # Only reraise base exceptions and cancellation exceptions
        if not isinstance(exc, Exception) or isinstance(exc, get_cancelled_exc_class()):
            # Raise the exception object explicitly rather than using a bare
            # ``raise``: the interpreter's "currently handled exception" is not
            # reliable at this point (see the note above), and a bare ``raise`` would
            # fail with RuntimeError if this method were called outside an ``except``
            # block
            raise exc

    async def serve(
        self,
        handler: Callable[[TLSStream], Any],
        task_group: TaskGroup | None = None,
    ) -> None:
        """
        Serve connections from the wrapped listener, doing a TLS handshake on each.

        :param handler: a callable awaited with the :class:`TLSStream` of every
            connection whose handshake succeeded
        :param task_group: passed through to the wrapped listener's ``serve()``

        """

        @wraps(handler)
        async def handler_wrapper(stream: AnyByteStream) -> None:
            # Imported here rather than at module level, as in the original code
            from .. import fail_after

            try:
                with fail_after(self.handshake_timeout):
                    # No hostname is given, so wrap() defaults to the server side
                    wrapped_stream = await TLSStream.wrap(
                        stream,
                        ssl_context=self.ssl_context,
                        standard_compatible=self.standard_compatible,
                    )
            except BaseException as exc:
                await self.handle_handshake_error(exc, stream)
            else:
                await handler(wrapped_stream)

        await self.listener.serve(handler_wrapper, task_group)

    async def aclose(self) -> None:
        """Close the wrapped listener."""
        await self.listener.aclose()

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """Map the ``standard_compatible`` typed attribute to its getter."""
        return {
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
        }