1from __future__ import annotations
2
3import logging
4import re
5import ssl
6import sys
7from collections.abc import Callable, Mapping
8from dataclasses import dataclass
9from functools import wraps
10from typing import Any, Tuple, TypeVar
11
12from .. import (
13 BrokenResourceError,
14 EndOfStream,
15 aclose_forcefully,
16 get_cancelled_exc_class,
17)
18from .._core._typedattr import TypedAttributeSet, typed_attribute
19from ..abc import AnyByteStream, ByteStream, Listener, TaskGroup
20
21if sys.version_info >= (3, 11):
22 from typing import TypeVarTuple, Unpack
23else:
24 from typing_extensions import TypeVarTuple, Unpack
25
# Return type of the callable handed to TLSStream._call_sslobject_method()
T_Retval = TypeVar("T_Retval")
# Types of the positional arguments forwarded to that callable
PosArgsT = TypeVarTuple("PosArgsT")
# Pieces of the TLSAttribute.peer_certificate annotation: a tuple of (str, str)
# pairs, and a tuple of such tuples
_PCTRTT = Tuple[Tuple[str, str], ...]
_PCTRTTT = Tuple[_PCTRTT, ...]
30
31
class TLSAttribute(TypedAttributeSet):
    """Contains Transport Layer Security related attributes."""

    #: the selected ALPN protocol
    alpn_protocol: str | None = typed_attribute()
    #: the channel binding for type ``tls-unique``
    channel_binding_tls_unique: bytes = typed_attribute()
    #: the selected cipher
    cipher: tuple[str, str, int] = typed_attribute()
    #: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert`
    #: for more information)
    peer_certificate: None | (dict[str, str | _PCTRTTT | _PCTRTT]) = typed_attribute()
    #: the peer certificate in binary form
    peer_certificate_binary: bytes | None = typed_attribute()
    #: ``True`` if this is the server side of the connection
    server_side: bool = typed_attribute()
    #: ciphers shared by the client during the TLS handshake (``None`` if this is the
    #: client side)
    shared_ciphers: list[tuple[str, str, int]] | None = typed_attribute()
    #: the :class:`~ssl.SSLObject` used for encryption
    ssl_object: ssl.SSLObject = typed_attribute()
    #: ``True`` if this stream does (and expects) a closing TLS handshake when the
    #: stream is being closed
    standard_compatible: bool = typed_attribute()
    #: the TLS protocol version (e.g. ``TLSv1.2``)
    tls_version: str = typed_attribute()
58
59
@dataclass(eq=False)
class TLSStream(ByteStream):
    """
    A stream wrapper that encrypts all sent data and decrypts received data.

    This class has no public initializer; use :meth:`wrap` instead.
    All extra attributes from :class:`~TLSAttribute` are supported.

    :var AnyByteStream transport_stream: the wrapped stream

    """

    transport_stream: AnyByteStream
    standard_compatible: bool
    _ssl_object: ssl.SSLObject
    _read_bio: ssl.MemoryBIO
    _write_bio: ssl.MemoryBIO

    @classmethod
    async def wrap(
        cls,
        transport_stream: AnyByteStream,
        *,
        server_side: bool | None = None,
        hostname: str | None = None,
        ssl_context: ssl.SSLContext | None = None,
        standard_compatible: bool = True,
    ) -> TLSStream:
        """
        Wrap an existing stream with Transport Layer Security.

        This performs a TLS handshake with the peer.

        :param transport_stream: a bytes-transporting stream to wrap
        :param server_side: ``True`` if this is the server side of the connection,
            ``False`` if this is the client side (if omitted, will be set to ``False``
            if ``hostname`` has been provided, ``True`` otherwise). Passed on to
            :meth:`ssl.SSLContext.wrap_bio`, and also selects the purpose of the
            default context when an explicit context has not been provided.
        :param hostname: host name of the peer (if host name checking is desired)
        :param ssl_context: the SSLContext object to use (if not provided, a secure
            default will be created)
        :param standard_compatible: if ``False``, skip the closing handshake when
            closing the connection, and don't raise an exception if the peer does the
            same
        :raises ~ssl.SSLError: if the TLS handshake fails

        """
        if server_side is None:
            server_side = not hostname

        if not ssl_context:
            purpose = (
                ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH
            )
            ssl_context = ssl.create_default_context(purpose)

            # Re-enable detection of unexpected EOFs if it was disabled by Python
            if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
                ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF

        # The SSL object never touches the network itself: it reads ciphertext from
        # bio_in and writes ciphertext to bio_out, and _call_sslobject_method()
        # moves those bytes to and from the transport stream.
        bio_in = ssl.MemoryBIO()
        bio_out = ssl.MemoryBIO()
        ssl_object = ssl_context.wrap_bio(
            bio_in, bio_out, server_side=server_side, server_hostname=hostname
        )
        wrapper = cls(
            transport_stream=transport_stream,
            standard_compatible=standard_compatible,
            _ssl_object=ssl_object,
            _read_bio=bio_in,
            _write_bio=bio_out,
        )
        await wrapper._call_sslobject_method(ssl_object.do_handshake)
        return wrapper

    async def _call_sslobject_method(
        self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
    ) -> T_Retval:
        """
        Call ``func(*args)`` until it completes, pumping data between the memory
        BIOs and the transport stream whenever the SSL object asks for I/O.

        :param func: the callable to invoke (a method of the SSL object)
        :param args: positional arguments passed to ``func``
        :return: the return value of ``func``
        :raises ~anyio.BrokenResourceError: if the transport raises :exc:`OSError`
            while receiving, if :exc:`ssl.SSLSyscallError` is raised, or on an
            unexpected EOF when ``standard_compatible`` is ``True``
        :raises ~anyio.EndOfStream: on an unexpected EOF when
            ``standard_compatible`` is ``False``
        :raises ~ssl.SSLError: for any other SSL error

        """
        while True:
            try:
                result = func(*args)
            except ssl.SSLWantReadError:
                try:
                    # Flush any pending writes first
                    if self._write_bio.pending:
                        await self.transport_stream.send(self._write_bio.read())

                    data = await self.transport_stream.receive()
                except EndOfStream:
                    # Signal EOF to the SSL object; the retried call decides how
                    # to report it
                    self._read_bio.write_eof()
                except OSError as exc:
                    self._read_bio.write_eof()
                    self._write_bio.write_eof()
                    raise BrokenResourceError from exc
                else:
                    self._read_bio.write(data)
            except ssl.SSLWantWriteError:
                await self.transport_stream.send(self._write_bio.read())
            except ssl.SSLSyscallError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                raise BrokenResourceError from exc
            except ssl.SSLError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                # An SSLError whose message mentions UNEXPECTED_EOF_WHILE_READING is
                # handled like SSLEOFError. ``strerror`` can be None, so it must be
                # checked before the substring test to avoid masking the real error
                # with a TypeError.
                unexpected_eof = isinstance(exc, ssl.SSLEOFError) or (
                    exc.strerror is not None
                    and "UNEXPECTED_EOF_WHILE_READING" in exc.strerror
                )
                if unexpected_eof:
                    if self.standard_compatible:
                        raise BrokenResourceError from exc
                    else:
                        raise EndOfStream from None

                raise
            else:
                # Flush any pending writes first
                if self._write_bio.pending:
                    await self.transport_stream.send(self._write_bio.read())

                return result

    async def unwrap(self) -> tuple[AnyByteStream, bytes]:
        """
        Does the TLS closing handshake.

        :return: a tuple of (wrapped byte stream, bytes left in the read buffer)

        """
        await self._call_sslobject_method(self._ssl_object.unwrap)
        self._read_bio.write_eof()
        self._write_bio.write_eof()
        return self.transport_stream, self._read_bio.read()

    async def aclose(self) -> None:
        """
        Close the stream.

        If ``standard_compatible`` is ``True``, the closing handshake is performed
        first via :meth:`unwrap`; if that fails, the transport stream is closed
        forcefully and the exception is re-raised. Otherwise, or after a successful
        closing handshake, the transport stream is closed normally.

        """
        if self.standard_compatible:
            try:
                await self.unwrap()
            except BaseException:
                await aclose_forcefully(self.transport_stream)
                raise

        await self.transport_stream.aclose()

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Read and return up to ``max_bytes`` bytes of decrypted data.

        :param max_bytes: maximum number of bytes to read from the SSL object
        :raises ~anyio.EndOfStream: if the SSL object returned no data

        """
        data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
        if not data:
            raise EndOfStream

        return data

    async def send(self, item: bytes) -> None:
        """
        Encrypt ``item`` and send it through the transport stream.

        :param item: the plaintext bytes to send

        """
        await self._call_sslobject_method(self._ssl_object.write, item)

    async def send_eof(self) -> None:
        """
        Not implemented for TLS streams.

        :raises NotImplementedError: always; the message names the negotiated
            protocol version when it is older than TLSv1.3

        """
        tls_version = self.extra(TLSAttribute.tls_version)
        match = re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version)
        if match:
            major, minor = int(match.group(1)), int(match.group(2) or 0)
            if (major, minor) < (1, 3):
                raise NotImplementedError(
                    f"send_eof() requires at least TLSv1.3; current "
                    f"session uses {tls_version}"
                )

        raise NotImplementedError(
            "send_eof() has not yet been implemented for TLS streams"
        )

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """
        Return the transport stream's extra attributes plus a getter for every
        :class:`TLSAttribute`; TLS entries override same-keyed transport entries.
        """
        return {
            **self.transport_stream.extra_attributes,
            TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
            TLSAttribute.channel_binding_tls_unique: (
                self._ssl_object.get_channel_binding
            ),
            TLSAttribute.cipher: self._ssl_object.cipher,
            TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False),
            TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert(
                True
            ),
            TLSAttribute.server_side: lambda: self._ssl_object.server_side,
            TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers()
            if self._ssl_object.server_side
            else None,
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
            TLSAttribute.ssl_object: lambda: self._ssl_object,
            TLSAttribute.tls_version: self._ssl_object.version,
        }
250
251
@dataclass(eq=False)
class TLSListener(Listener[TLSStream]):
    """
    A convenience listener that wraps another listener and auto-negotiates a TLS session
    on every accepted connection.

    If the TLS handshake times out or raises an exception,
    :meth:`handle_handshake_error` is called to do whatever post-mortem processing is
    deemed necessary.

    Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute.

    :param Listener listener: the listener to wrap
    :param ssl_context: the SSL context object
    :param standard_compatible: a flag passed through to :meth:`TLSStream.wrap`
    :param handshake_timeout: time limit for the TLS handshake
        (passed to :func:`~anyio.fail_after`)
    """

    listener: Listener[Any]
    ssl_context: ssl.SSLContext
    standard_compatible: bool = True
    handshake_timeout: float = 30

    @staticmethod
    async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None:
        """
        Handle an exception raised during the TLS handshake.

        This method does 3 things:

        #. Forcefully closes the original stream
        #. Logs the exception (unless it was a cancellation exception) using the
           ``anyio.streams.tls`` logger
        #. Reraises the exception if it was a base exception or a cancellation exception

        :param exc: the exception
        :param stream: the original stream

        """
        await aclose_forcefully(stream)

        is_cancellation = isinstance(exc, get_cancelled_exc_class())

        # Log all except cancellation exceptions
        if not is_cancellation:
            # CPython (as of 3.11.5) returns incorrect `sys.exc_info()` here when using
            # any asyncio implementation, so we explicitly pass the exception to log
            # (https://github.com/python/cpython/issues/108668). Trio does not have this
            # issue because it works around the CPython bug.
            logging.getLogger(__name__).exception(
                "Error during TLS handshake", exc_info=exc
            )

        # Only reraise base exceptions and cancellation exceptions.
        # Raise the passed-in exception explicitly: a bare ``raise`` depends on the
        # ambient exception context, which is unreliable here for the reason given
        # above and absent if this hook is called outside an ``except`` block.
        if is_cancellation or not isinstance(exc, Exception):
            raise exc

    async def serve(
        self,
        handler: Callable[[TLSStream], Any],
        task_group: TaskGroup | None = None,
    ) -> None:
        """
        Serve connections from the wrapped listener, performing the TLS handshake on
        each accepted stream before handing it to ``handler``.

        The handshake is limited to ``handshake_timeout`` seconds. If it fails,
        :meth:`handle_handshake_error` is called and ``handler`` is not invoked for
        that connection.

        :param handler: callable awaited with each successfully wrapped
            :class:`TLSStream`
        :param task_group: passed through to the wrapped listener's ``serve()``

        """

        @wraps(handler)
        async def handler_wrapper(stream: AnyByteStream) -> None:
            from .. import fail_after

            try:
                with fail_after(self.handshake_timeout):
                    wrapped_stream = await TLSStream.wrap(
                        stream,
                        ssl_context=self.ssl_context,
                        standard_compatible=self.standard_compatible,
                    )
            except BaseException as exc:
                await self.handle_handshake_error(exc, stream)
            else:
                await handler(wrapped_stream)

        await self.listener.serve(handler_wrapper, task_group)

    async def aclose(self) -> None:
        """Close the wrapped listener."""
        await self.listener.aclose()

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """Return a getter for :attr:`~TLSAttribute.standard_compatible` only."""
        return {
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
        }