Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/anyio/streams/tls.py: 38%
138 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 07:19 +0000
1from __future__ import annotations
3import logging
4import re
5import ssl
6from dataclasses import dataclass
7from functools import wraps
8from typing import Any, Callable, Mapping, Tuple, TypeVar
10from .. import (
11 BrokenResourceError,
12 EndOfStream,
13 aclose_forcefully,
14 get_cancelled_exc_class,
15)
16from .._core._typedattr import TypedAttributeSet, typed_attribute
17from ..abc import AnyByteStream, ByteStream, Listener, TaskGroup
# Type variable for the value returned by the callable that
# TLSStream._call_sslobject_method drives to completion.
T_Retval = TypeVar("T_Retval")

# Building blocks for the ``TLSAttribute.peer_certificate`` annotation:
# a tuple of (str, str) pairs, and a tuple of such tuples.
_StrPair = Tuple[str, str]
_PCTRTT = Tuple[_StrPair, ...]
_PCTRTTT = Tuple[_PCTRTT, ...]
class TLSAttribute(TypedAttributeSet):
    """Typed attributes that describe a Transport Layer Security session."""

    #: ALPN protocol chosen during the handshake (``None`` when none was selected)
    alpn_protocol: str | None = typed_attribute()
    #: channel binding data of type ``tls-unique``
    channel_binding_tls_unique: bytes = typed_attribute()
    #: the cipher in use, as reported by :meth:`ssl.SSLObject.cipher`
    cipher: tuple[str, str, int] = typed_attribute()
    #: the peer's certificate as a dictionary (refer to :meth:`ssl.SSLSocket.getpeercert`
    #: for the layout)
    peer_certificate: dict[str, str | _PCTRTTT | _PCTRTT] | None = typed_attribute()
    #: the peer's certificate in its binary (DER) form
    peer_certificate_binary: bytes | None = typed_attribute()
    #: ``True`` on the server end of the connection
    server_side: bool = typed_attribute()
    #: ciphers the client offered during the TLS handshake (``None`` on the client
    #: end)
    shared_ciphers: list[tuple[str, str, int]] | None = typed_attribute()
    #: the :class:`~ssl.SSLObject` that performs encryption and decryption
    ssl_object: ssl.SSLObject = typed_attribute()
    #: ``True`` when the stream performs (and expects) a closing TLS handshake on
    #: close
    standard_compatible: bool = typed_attribute()
    #: the negotiated TLS protocol version (for example ``TLSv1.2``)
    tls_version: str = typed_attribute()
@dataclass(eq=False)
class TLSStream(ByteStream):
    """
    A stream wrapper that encrypts all sent data and decrypts received data.

    This class has no public initializer; use :meth:`wrap` instead.
    All extra attributes from :class:`~TLSAttribute` are supported.

    :var AnyByteStream transport_stream: the wrapped stream

    """

    transport_stream: AnyByteStream
    standard_compatible: bool
    _ssl_object: ssl.SSLObject
    _read_bio: ssl.MemoryBIO
    _write_bio: ssl.MemoryBIO

    @classmethod
    async def wrap(
        cls,
        transport_stream: AnyByteStream,
        *,
        server_side: bool | None = None,
        hostname: str | None = None,
        ssl_context: ssl.SSLContext | None = None,
        standard_compatible: bool = True,
    ) -> TLSStream:
        """
        Wrap an existing stream with Transport Layer Security.

        This performs a TLS handshake with the peer.

        :param transport_stream: a bytes-transporting stream to wrap
        :param server_side: ``True`` if this is the server side of the connection, ``False`` if
            this is the client side (if omitted, will be set to ``False`` if ``hostname`` has been
            provided, ``True`` otherwise). Used only to create a default context when an explicit
            context has not been provided.
        :param hostname: host name of the peer (if host name checking is desired)
        :param ssl_context: the SSLContext object to use (if not provided, a secure default will be
            created)
        :param standard_compatible: if ``False``, skip the closing handshake when closing the
            connection, and don't raise an exception if the peer does the same
        :raises ~ssl.SSLError: if the TLS handshake fails

        """
        if server_side is None:
            server_side = not hostname

        if ssl_context is None:
            purpose = (
                ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH
            )
            ssl_context = ssl.create_default_context(purpose)

            # Re-enable detection of unexpected EOFs if it was disabled by Python.
            # Only the context created here is touched, so a caller-supplied
            # context is never mutated behind its owner's back.
            if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
                ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF

        # Memory BIOs decouple the SSL object from any real socket: ciphertext
        # read from the transport goes into bio_in, ciphertext produced by the
        # SSL object is drained from bio_out and sent over the transport.
        bio_in = ssl.MemoryBIO()
        bio_out = ssl.MemoryBIO()
        ssl_object = ssl_context.wrap_bio(
            bio_in, bio_out, server_side=server_side, server_hostname=hostname
        )
        wrapper = cls(
            transport_stream=transport_stream,
            standard_compatible=standard_compatible,
            _ssl_object=ssl_object,
            _read_bio=bio_in,
            _write_bio=bio_out,
        )
        await wrapper._call_sslobject_method(ssl_object.do_handshake)
        return wrapper

    async def _call_sslobject_method(
        self, func: Callable[..., T_Retval], *args: object
    ) -> T_Retval:
        """
        Call ``func(*args)`` repeatedly, moving data between the memory BIOs and the
        transport stream, until it completes without requesting more I/O.

        :param func: a bound method of the SSL object (e.g. ``do_handshake``, ``read``)
        :param args: positional arguments passed to ``func``
        :return: whatever ``func`` returns once it succeeds
        :raises BrokenResourceError: if the transport fails with :exc:`OSError`, on
            :exc:`ssl.SSLSyscallError`, or on an unexpected EOF while
            ``standard_compatible`` is ``True``
        :raises EndOfStream: on an unexpected EOF while ``standard_compatible`` is
            ``False``

        """
        while True:
            try:
                result = func(*args)
            except ssl.SSLWantReadError:
                try:
                    # Flush any pending writes first
                    if self._write_bio.pending:
                        await self.transport_stream.send(self._write_bio.read())

                    data = await self.transport_stream.receive()
                except EndOfStream:
                    # Let the SSL object see the EOF on its next attempt
                    self._read_bio.write_eof()
                except OSError as exc:
                    self._read_bio.write_eof()
                    self._write_bio.write_eof()
                    raise BrokenResourceError from exc
                else:
                    self._read_bio.write(data)
            except ssl.SSLWantWriteError:
                await self.transport_stream.send(self._write_bio.read())
            except ssl.SSLSyscallError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                raise BrokenResourceError from exc
            except ssl.SSLError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                # ``strerror`` is not guaranteed to be set on every SSLError, so
                # guard the substring check against ``None``.
                if isinstance(exc, ssl.SSLEOFError) or (
                    exc.strerror and "UNEXPECTED_EOF_WHILE_READING" in exc.strerror
                ):
                    if self.standard_compatible:
                        raise BrokenResourceError from exc
                    else:
                        raise EndOfStream from None

                raise
            else:
                # Flush any pending writes first
                if self._write_bio.pending:
                    await self.transport_stream.send(self._write_bio.read())

                return result

    async def unwrap(self) -> tuple[AnyByteStream, bytes]:
        """
        Does the TLS closing handshake.

        After the handshake both memory BIOs are marked as EOF.

        :return: a tuple of (wrapped byte stream, bytes left in the read buffer)

        """
        await self._call_sslobject_method(self._ssl_object.unwrap)
        self._read_bio.write_eof()
        self._write_bio.write_eof()
        return self.transport_stream, self._read_bio.read()

    async def aclose(self) -> None:
        """
        Close the TLS stream and the underlying transport stream.

        When ``standard_compatible`` is set, the closing handshake is attempted
        first; if it fails for any reason the transport is closed forcefully and the
        exception propagates.

        """
        if self.standard_compatible:
            try:
                await self.unwrap()
            except BaseException:
                await aclose_forcefully(self.transport_stream)
                raise

        await self.transport_stream.aclose()

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Receive and decrypt up to ``max_bytes`` bytes.

        :raises EndOfStream: if the SSL object returned no data

        """
        data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
        if not data:
            raise EndOfStream

        return data

    async def send(self, item: bytes) -> None:
        """Encrypt ``item`` and send it over the transport stream."""
        await self._call_sslobject_method(self._ssl_object.write, item)

    async def send_eof(self) -> None:
        """
        Not supported for TLS streams.

        :raises NotImplementedError: always; the message mentions the negotiated
            version when it is older than TLS 1.3

        """
        tls_version = self.extra(TLSAttribute.tls_version)
        match = re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version)
        if match:
            major, minor = int(match.group(1)), int(match.group(2) or 0)
            if (major, minor) < (1, 3):
                raise NotImplementedError(
                    f"send_eof() requires at least TLSv1.3; current "
                    f"session uses {tls_version}"
                )

        raise NotImplementedError(
            "send_eof() has not yet been implemented for TLS streams"
        )

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """Transport stream attributes, overlaid with the :class:`TLSAttribute` set."""
        return {
            **self.transport_stream.extra_attributes,
            TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
            TLSAttribute.channel_binding_tls_unique: self._ssl_object.get_channel_binding,
            TLSAttribute.cipher: self._ssl_object.cipher,
            TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False),
            TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert(
                True
            ),
            TLSAttribute.server_side: lambda: self._ssl_object.server_side,
            # Only meaningful on the server side; the client side reports None
            TLSAttribute.shared_ciphers: lambda: (
                self._ssl_object.shared_ciphers()
                if self._ssl_object.server_side
                else None
            ),
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
            TLSAttribute.ssl_object: lambda: self._ssl_object,
            TLSAttribute.tls_version: self._ssl_object.version,
        }
@dataclass(eq=False)
class TLSListener(Listener[TLSStream]):
    """
    A convenience listener that wraps another listener and auto-negotiates a TLS session on every
    accepted connection.

    If the TLS handshake times out or raises an exception, :meth:`handle_handshake_error` is
    called to do whatever post-mortem processing is deemed necessary.

    Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute.

    :param Listener listener: the listener to wrap
    :param ssl_context: the SSL context object
    :param standard_compatible: a flag passed through to :meth:`TLSStream.wrap`
    :param handshake_timeout: time limit for the TLS handshake
        (passed to :func:`~anyio.fail_after`)
    """

    listener: Listener[Any]
    ssl_context: ssl.SSLContext
    standard_compatible: bool = True
    handshake_timeout: float = 30

    @staticmethod
    async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None:
        """
        Handle an exception raised during the TLS handshake.

        This method does 3 things:

        #. Forcefully closes the original stream
        #. Logs the exception (unless it was a cancellation exception) using this
           module's logger (``logging.getLogger(__name__)``)
        #. Reraises the exception if it was a base exception or a cancellation exception

        The re-raise uses a bare ``raise``, so this method must be called while the
        exception is being handled (as :meth:`serve` does from its ``except`` block).

        :param exc: the exception
        :param stream: the original stream

        """
        await aclose_forcefully(stream)

        # Log all except cancellation exceptions
        if not isinstance(exc, get_cancelled_exc_class()):
            # Pass the exception explicitly instead of relying on sys.exc_info()
            logging.getLogger(__name__).exception(
                "Error during TLS handshake", exc_info=exc
            )

        # Only reraise base exceptions and cancellation exceptions
        if not isinstance(exc, Exception) or isinstance(exc, get_cancelled_exc_class()):
            raise

    async def serve(
        self,
        handler: Callable[[TLSStream], Any],
        task_group: TaskGroup | None = None,
    ) -> None:
        """
        Serve the wrapped listener, performing a TLS handshake on each accepted stream.

        Each handshake runs under ``fail_after(self.handshake_timeout)``. On failure the
        exception and raw stream are passed to :meth:`handle_handshake_error`; on success
        ``handler`` is called with the resulting :class:`TLSStream`.

        :param handler: callable invoked with each successfully wrapped stream
        :param task_group: passed through to the wrapped listener's ``serve``

        """

        @wraps(handler)
        async def handler_wrapper(stream: AnyByteStream) -> None:
            from .. import fail_after

            try:
                with fail_after(self.handshake_timeout):
                    wrapped_stream = await TLSStream.wrap(
                        stream,
                        ssl_context=self.ssl_context,
                        standard_compatible=self.standard_compatible,
                    )
            except BaseException as exc:
                await self.handle_handshake_error(exc, stream)
            else:
                await handler(wrapped_stream)

        await self.listener.serve(handler_wrapper, task_group)

    async def aclose(self) -> None:
        """Close the wrapped listener."""
        await self.listener.aclose()

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """Expose only :attr:`TLSAttribute.standard_compatible`."""
        return {
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
        }