Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/anyio/streams/tls.py: 38%
137 statements
« prev ^ index » next coverage.py v7.2.2, created at 2023-03-26 06:12 +0000
1import logging
2import re
3import ssl
4from dataclasses import dataclass
5from functools import wraps
6from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, TypeVar, Union
8from .. import (
9 BrokenResourceError,
10 EndOfStream,
11 aclose_forcefully,
12 get_cancelled_exc_class,
13)
14from .._core._typedattr import TypedAttributeSet, typed_attribute
15from ..abc import AnyByteStream, ByteStream, Listener, TaskGroup
# Return type of whatever callable TLSStream._call_sslobject_method drives
T_Retval = TypeVar("T_Retval")

# Tuple shapes that appear as values in the dict form of a peer certificate
# (see ssl.SSLSocket.getpeercert()): a run of (str, str) pairs, and a run of such runs.
_PCTRTT = Tuple[
    Tuple[str, str],
    ...,
]
_PCTRTTT = Tuple[
    _PCTRTT,
    ...,
]
class TLSAttribute(TypedAttributeSet):
    """Contains Transport Layer Security related attributes."""

    #: the selected ALPN protocol
    alpn_protocol: Optional[str] = typed_attribute()
    #: the channel binding for type ``tls-unique``
    channel_binding_tls_unique: bytes = typed_attribute()
    #: the selected cipher
    cipher: Tuple[str, str, int] = typed_attribute()
    #: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert` for more
    #: information)
    peer_certificate: Optional[
        Dict[str, Union[str, _PCTRTTT, _PCTRTT]]
    ] = typed_attribute()
    #: the peer certificate in binary form
    peer_certificate_binary: Optional[bytes] = typed_attribute()
    #: ``True`` if this is the server side of the connection
    server_side: bool = typed_attribute()
    #: ciphers shared between both ends of the TLS connection; may be ``None`` (for example
    #: on the client side -- see :meth:`ssl.SSLSocket.shared_ciphers`)
    shared_ciphers: Optional[List[Tuple[str, str, int]]] = typed_attribute()
    #: the :class:`~ssl.SSLObject` used for encryption
    ssl_object: ssl.SSLObject = typed_attribute()
    #: ``True`` if this stream does (and expects) a closing TLS handshake when the stream is being
    #: closed
    standard_compatible: bool = typed_attribute()
    #: the TLS protocol version (e.g. ``TLSv1.2``)
    tls_version: str = typed_attribute()
@dataclass(eq=False)
class TLSStream(ByteStream):
    """
    A stream wrapper that encrypts all sent data and decrypts received data.

    This class has no public initializer; use :meth:`wrap` instead.
    All extra attributes from :class:`~TLSAttribute` are supported.

    :var AnyByteStream transport_stream: the wrapped stream

    """

    transport_stream: AnyByteStream
    standard_compatible: bool
    _ssl_object: ssl.SSLObject
    # incoming ciphertext is written here and consumed by _ssl_object
    _read_bio: ssl.MemoryBIO
    # outgoing ciphertext produced by _ssl_object is read from here
    _write_bio: ssl.MemoryBIO

    @classmethod
    async def wrap(
        cls,
        transport_stream: AnyByteStream,
        *,
        server_side: Optional[bool] = None,
        hostname: Optional[str] = None,
        ssl_context: Optional[ssl.SSLContext] = None,
        standard_compatible: bool = True,
    ) -> "TLSStream":
        """
        Wrap an existing stream with Transport Layer Security.

        This performs a TLS handshake with the peer.

        :param transport_stream: a bytes-transporting stream to wrap
        :param server_side: ``True`` if this is the server side of the connection, ``False`` if
            this is the client side (if omitted, will be set to ``False`` if ``hostname`` has been
            provided, ``True`` otherwise). Determines which side of the handshake is performed
            and, when no explicit context has been provided, the purpose of the default context.
        :param hostname: host name of the peer (if host name checking is desired)
        :param ssl_context: the SSLContext object to use (if not provided, a secure default will be
            created)
        :param standard_compatible: if ``False``, skip the closing handshake when closing the
            connection, and don't raise an exception if the peer does the same
        :raises ~ssl.SSLError: if the TLS handshake fails

        """
        if server_side is None:
            server_side = not hostname

        if not ssl_context:
            purpose = (
                ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH
            )
            ssl_context = ssl.create_default_context(purpose)

            # Re-enable detection of unexpected EOFs if it was disabled by Python.
            # The flag must be *cleared* ("&= ~"), not toggled ("^="): toggling would
            # turn it ON for Python builds whose default context leaves it unset,
            # which is the opposite of what is intended here.
            if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
                ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF  # type: ignore[attr-defined]

        # The SSL object reads ciphertext from bio_in and writes ciphertext to bio_out;
        # _call_sslobject_method moves those bytes over transport_stream.
        bio_in = ssl.MemoryBIO()
        bio_out = ssl.MemoryBIO()
        ssl_object = ssl_context.wrap_bio(
            bio_in, bio_out, server_side=server_side, server_hostname=hostname
        )
        wrapper = cls(
            transport_stream=transport_stream,
            standard_compatible=standard_compatible,
            _ssl_object=ssl_object,
            _read_bio=bio_in,
            _write_bio=bio_out,
        )
        await wrapper._call_sslobject_method(ssl_object.do_handshake)
        return wrapper

    async def _call_sslobject_method(
        self, func: Callable[..., T_Retval], *args: object
    ) -> T_Retval:
        """
        Call ``func(*args)`` (a method of the SSL object), shuttling ciphertext between the
        memory BIOs and the transport stream until the call completes.

        :return: whatever ``func`` returns
        :raises BrokenResourceError: if the transport raises :exc:`OSError`, on an
            :exc:`ssl.SSLSyscallError`, or on an unexpected EOF when ``standard_compatible``
            is ``True``
        :raises EndOfStream: on an unexpected EOF when ``standard_compatible`` is ``False``
        :raises ~ssl.SSLError: for any other TLS error

        """
        while True:
            try:
                result = func(*args)
            except ssl.SSLWantReadError:
                # The SSL object needs more data from the peer before it can proceed
                try:
                    # Flush any pending writes first
                    if self._write_bio.pending:
                        await self.transport_stream.send(self._write_bio.read())

                    data = await self.transport_stream.receive()
                except EndOfStream:
                    # Let the SSL object observe the EOF on the next loop iteration
                    self._read_bio.write_eof()
                except OSError as exc:
                    self._read_bio.write_eof()
                    self._write_bio.write_eof()
                    raise BrokenResourceError from exc
                else:
                    self._read_bio.write(data)
            except ssl.SSLWantWriteError:
                await self.transport_stream.send(self._write_bio.read())
            except ssl.SSLSyscallError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                raise BrokenResourceError from exc
            except ssl.SSLError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                # strerror can be None (OSError subclasses only fill it in for the
                # (errno, strerror) constructor form), so guard the substring test
                # instead of letting it raise TypeError and mask the real error.
                if isinstance(exc, ssl.SSLEOFError) or (
                    exc.strerror and "UNEXPECTED_EOF_WHILE_READING" in exc.strerror
                ):
                    if self.standard_compatible:
                        raise BrokenResourceError from exc
                    else:
                        raise EndOfStream from None

                raise
            else:
                # Flush any pending writes first
                if self._write_bio.pending:
                    await self.transport_stream.send(self._write_bio.read())

                return result

    async def unwrap(self) -> Tuple[AnyByteStream, bytes]:
        """
        Does the TLS closing handshake.

        :return: a tuple of (wrapped byte stream, bytes left in the read buffer)

        """
        await self._call_sslobject_method(self._ssl_object.unwrap)
        self._read_bio.write_eof()
        self._write_bio.write_eof()
        return self.transport_stream, self._read_bio.read()

    async def aclose(self) -> None:
        # Only standard-compatible streams perform the closing TLS handshake; if it fails
        # for any reason, the transport is closed forcefully and the error propagates.
        if self.standard_compatible:
            try:
                await self.unwrap()
            except BaseException:
                await aclose_forcefully(self.transport_stream)
                raise

        await self.transport_stream.aclose()

    async def receive(self, max_bytes: int = 65536) -> bytes:
        data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
        # An empty read means the TLS session has been closed by the peer
        if not data:
            raise EndOfStream

        return data

    async def send(self, item: bytes) -> None:
        await self._call_sslobject_method(self._ssl_object.write, item)

    async def send_eof(self) -> None:
        tls_version = self.extra(TLSAttribute.tls_version)
        match = re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version)
        if match:
            major, minor = int(match.group(1)), int(match.group(2) or 0)
            if (major, minor) < (1, 3):
                raise NotImplementedError(
                    f"send_eof() requires at least TLSv1.3; current "
                    f"session uses {tls_version}"
                )

        raise NotImplementedError(
            "send_eof() has not yet been implemented for TLS streams"
        )

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        # TLS attributes are layered on top of (and override) the transport's attributes
        return {
            **self.transport_stream.extra_attributes,
            TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
            TLSAttribute.channel_binding_tls_unique: self._ssl_object.get_channel_binding,
            TLSAttribute.cipher: self._ssl_object.cipher,
            TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False),
            TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert(
                True
            ),
            TLSAttribute.server_side: lambda: self._ssl_object.server_side,
            TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers(),
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
            TLSAttribute.ssl_object: lambda: self._ssl_object,
            TLSAttribute.tls_version: self._ssl_object.version,
        }
@dataclass(eq=False)
class TLSListener(Listener[TLSStream]):
    """
    A convenience listener that wraps another listener and auto-negotiates a TLS session on every
    accepted connection.

    If the TLS handshake times out or raises an exception, :meth:`handle_handshake_error` is
    called to do whatever post-mortem processing is deemed necessary.

    Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute.

    :param Listener listener: the listener to wrap
    :param ssl_context: the SSL context object
    :param standard_compatible: a flag passed through to :meth:`TLSStream.wrap`
    :param handshake_timeout: time limit for the TLS handshake
        (passed to :func:`~anyio.fail_after`)
    """

    listener: Listener[Any]
    ssl_context: ssl.SSLContext
    standard_compatible: bool = True
    handshake_timeout: float = 30

    @staticmethod
    async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None:
        # NOTE: this must be a plain string literal; an f-string is not stored as a docstring.
        """
        Handle an exception raised during the TLS handshake.

        This method does 3 things:

        #. Forcefully closes the original stream
        #. Logs the exception (unless it was a cancellation exception) using the logger named
           after this module (``anyio.streams.tls``)
        #. Reraises the exception if it was a base exception or a cancellation exception

        :param exc: the exception
        :param stream: the original stream

        """
        await aclose_forcefully(stream)

        # Log all except cancellation exceptions. The exception is passed explicitly
        # instead of relying on sys.exc_info(), which only reflects ``exc`` while the
        # caller's except block is the innermost active handler.
        if not isinstance(exc, get_cancelled_exc_class()):
            logging.getLogger(__name__).exception(
                "Error during TLS handshake", exc_info=exc
            )

        # Only reraise base exceptions and cancellation exceptions. The bare ``raise``
        # re-raises the exception currently being handled by the caller, so this method
        # is expected to be awaited from inside the except block that caught ``exc``.
        if not isinstance(exc, Exception) or isinstance(exc, get_cancelled_exc_class()):
            raise

    async def serve(
        self,
        handler: Callable[[TLSStream], Any],
        task_group: Optional[TaskGroup] = None,
    ) -> None:
        @wraps(handler)
        async def handler_wrapper(stream: AnyByteStream) -> None:
            from .. import fail_after

            try:
                # No hostname is passed, so TLSStream.wrap() treats this as the server side
                with fail_after(self.handshake_timeout):
                    wrapped_stream = await TLSStream.wrap(
                        stream,
                        ssl_context=self.ssl_context,
                        standard_compatible=self.standard_compatible,
                    )
            except BaseException as exc:
                await self.handle_handshake_error(exc, stream)
            else:
                await handler(wrapped_stream)

        await self.listener.serve(handler_wrapper, task_group)

    async def aclose(self) -> None:
        await self.listener.aclose()

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        return {
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
        }