1from __future__ import annotations
2
3import logging
4import re
5import ssl
6import sys
7from collections.abc import Callable, Mapping
8from dataclasses import dataclass
9from functools import wraps
10from typing import Any, TypeVar
11
12from .. import (
13 BrokenResourceError,
14 EndOfStream,
15 aclose_forcefully,
16 get_cancelled_exc_class,
17 to_thread,
18)
19from .._core._typedattr import TypedAttributeSet, typed_attribute
20from ..abc import AnyByteStream, ByteStream, Listener, TaskGroup
21
22if sys.version_info >= (3, 11):
23 from typing import TypeVarTuple, Unpack
24else:
25 from typing_extensions import TypeVarTuple, Unpack
26
# Return type of the callable handed to TLSStream._call_sslobject_method()
T_Retval = TypeVar("T_Retval")
# Positional argument types forwarded to that callable
PosArgsT = TypeVarTuple("PosArgsT")
# Nested tuple shapes that appear in the annotation of
# TLSAttribute.peer_certificate (presumably mirroring the return type of
# ssl.SSLSocket.getpeercert() -- confirm against the stdlib type stubs)
_PCTRTT = tuple[tuple[str, str], ...]
_PCTRTTT = tuple[_PCTRTT, ...]
31
32
class TLSAttribute(TypedAttributeSet):
    """
    Contains Transport Layer Security related attributes.

    Each member is a typed attribute key; :class:`TLSStream` maps every one of them
    to a getter in its ``extra_attributes`` property.
    """

    #: the selected ALPN protocol
    alpn_protocol: str | None = typed_attribute()
    #: the channel binding for type ``tls-unique``
    channel_binding_tls_unique: bytes = typed_attribute()
    #: the selected cipher
    cipher: tuple[str, str, int] = typed_attribute()
    #: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert`
    #: for more information)
    peer_certificate: None | (dict[str, str | _PCTRTTT | _PCTRTT]) = typed_attribute()
    #: the peer certificate in binary form
    peer_certificate_binary: bytes | None = typed_attribute()
    #: ``True`` if this is the server side of the connection
    server_side: bool = typed_attribute()
    #: ciphers shared by the client during the TLS handshake (``None`` if this is the
    #: client side)
    shared_ciphers: list[tuple[str, str, int]] | None = typed_attribute()
    #: the :class:`~ssl.SSLObject` used for encryption
    ssl_object: ssl.SSLObject = typed_attribute()
    #: ``True`` if this stream does (and expects) a closing TLS handshake when the
    #: stream is being closed
    standard_compatible: bool = typed_attribute()
    #: the TLS protocol version (e.g. ``TLSv1.2``)
    tls_version: str = typed_attribute()
59
60
@dataclass(eq=False)
class TLSStream(ByteStream):
    """
    A stream wrapper that encrypts all sent data and decrypts received data.

    This class has no public initializer; use :meth:`wrap` instead.
    All extra attributes from :class:`~TLSAttribute` are supported.

    :var AnyByteStream transport_stream: the wrapped stream

    """

    transport_stream: AnyByteStream
    standard_compatible: bool
    _ssl_object: ssl.SSLObject
    # Incoming BIO of the SSL object: data received from the transport stream is
    # written here
    _read_bio: ssl.MemoryBIO
    # Outgoing BIO of the SSL object: data read from here is sent to the transport
    # stream
    _write_bio: ssl.MemoryBIO

    @classmethod
    async def wrap(
        cls,
        transport_stream: AnyByteStream,
        *,
        server_side: bool | None = None,
        hostname: str | None = None,
        ssl_context: ssl.SSLContext | None = None,
        standard_compatible: bool = True,
    ) -> TLSStream:
        """
        Wrap an existing stream with Transport Layer Security.

        This performs a TLS handshake with the peer.

        :param transport_stream: a bytes-transporting stream to wrap
        :param server_side: ``True`` if this is the server side of the connection,
            ``False`` if this is the client side (if omitted, will be set to ``False``
            if ``hostname`` has been provided, ``True`` otherwise). Passed on to
            ``wrap_bio()`` and also used to pick the purpose of the default context
            when an explicit context has not been provided.
        :param hostname: host name of the peer (if host name checking is desired)
        :param ssl_context: the SSLContext object to use (if not provided, a secure
            default will be created)
        :param standard_compatible: if ``False``, skip the closing handshake when
            closing the connection, and don't raise an exception if the peer does the
            same
        :return: the wrapped stream, with the TLS handshake already completed
        :raises ~ssl.SSLError: if the TLS handshake fails

        """
        if server_side is None:
            server_side = not hostname

        if not ssl_context:
            purpose = (
                ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH
            )
            ssl_context = ssl.create_default_context(purpose)

            # Re-enable detection of unexpected EOFs if it was disabled by Python
            if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
                ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF

        bio_in = ssl.MemoryBIO()
        bio_out = ssl.MemoryBIO()

        # External SSLContext implementations may do blocking I/O in wrap_bio(),
        # but the standard library implementation won't
        if type(ssl_context) is ssl.SSLContext:
            ssl_object = ssl_context.wrap_bio(
                bio_in, bio_out, server_side=server_side, server_hostname=hostname
            )
        else:
            # Arguments are passed positionally here: incoming BIO, outgoing BIO,
            # server_side, server_hostname and a final ``None``
            ssl_object = await to_thread.run_sync(
                ssl_context.wrap_bio,
                bio_in,
                bio_out,
                server_side,
                hostname,
                None,
            )

        wrapper = cls(
            transport_stream=transport_stream,
            standard_compatible=standard_compatible,
            _ssl_object=ssl_object,
            _read_bio=bio_in,
            _write_bio=bio_out,
        )
        await wrapper._call_sslobject_method(ssl_object.do_handshake)
        return wrapper

    async def _call_sslobject_method(
        self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
    ) -> T_Retval:
        """
        Call ``func(*args)`` repeatedly until it completes, moving data between the
        memory BIOs and the transport stream whenever the SSL object asks for it.

        :param func: the callable to invoke (the callers in this class pass bound
            methods of the SSL object)
        :param args: positional arguments passed to ``func``
        :return: whatever ``func`` returned
        :raises BrokenResourceError: if receiving from the transport stream raises
            :exc:`OSError`, if ``func`` raises :exc:`ssl.SSLSyscallError`, or if an
            unexpected EOF is detected while ``standard_compatible`` is ``True``
        :raises EndOfStream: if an unexpected EOF is detected while
            ``standard_compatible`` is ``False``
        :raises ~ssl.SSLError: any other SSL error raised by ``func``

        """
        # NOTE: the order of the except clauses matters. SSLWantReadError,
        # SSLWantWriteError and SSLSyscallError are all subclasses of SSLError, so
        # they must be matched before the generic SSLError handler.
        while True:
            try:
                result = func(*args)
            except ssl.SSLWantReadError:
                try:
                    # Flush any pending writes first
                    if self._write_bio.pending:
                        await self.transport_stream.send(self._write_bio.read())

                    data = await self.transport_stream.receive()
                except EndOfStream:
                    # Signal EOF to the SSL object and retry func on the next loop
                    # iteration
                    self._read_bio.write_eof()
                except OSError as exc:
                    self._read_bio.write_eof()
                    self._write_bio.write_eof()
                    raise BrokenResourceError from exc
                else:
                    self._read_bio.write(data)
            except ssl.SSLWantWriteError:
                await self.transport_stream.send(self._write_bio.read())
            except ssl.SSLSyscallError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                raise BrokenResourceError from exc
            except ssl.SSLError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                # An unexpected EOF is reported either as SSLEOFError or as a plain
                # SSLError whose message contains UNEXPECTED_EOF_WHILE_READING
                if isinstance(exc, ssl.SSLEOFError) or (
                    exc.strerror and "UNEXPECTED_EOF_WHILE_READING" in exc.strerror
                ):
                    if self.standard_compatible:
                        raise BrokenResourceError from exc
                    else:
                        raise EndOfStream from None

                raise
            else:
                # Flush any pending writes first
                if self._write_bio.pending:
                    await self.transport_stream.send(self._write_bio.read())

                return result

    async def unwrap(self) -> tuple[AnyByteStream, bytes]:
        """
        Does the TLS closing handshake.

        Both memory BIOs are marked as finished afterwards.

        :return: a tuple of (wrapped byte stream, bytes left in the read buffer)

        """
        await self._call_sslobject_method(self._ssl_object.unwrap)
        self._read_bio.write_eof()
        self._write_bio.write_eof()
        return self.transport_stream, self._read_bio.read()

    async def aclose(self) -> None:
        """
        Close the stream.

        If ``standard_compatible`` is ``True``, the TLS closing handshake is attempted
        first; if that raises anything, the transport stream is closed forcefully and
        the exception is re-raised. Otherwise the transport stream is closed normally.

        """
        if self.standard_compatible:
            try:
                await self.unwrap()
            except BaseException:
                await aclose_forcefully(self.transport_stream)
                raise

        await self.transport_stream.aclose()

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Read and decrypt data from the stream.

        :param max_bytes: passed to the SSL object's ``read()`` as the read size
        :return: the decrypted bytes
        :raises EndOfStream: if the SSL object returned no data

        """
        data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
        if not data:
            raise EndOfStream

        return data

    async def send(self, item: bytes) -> None:
        """
        Encrypt ``item`` and send it through the transport stream.

        :param item: the plaintext bytes to send

        """
        await self._call_sslobject_method(self._ssl_object.write, item)

    async def send_eof(self) -> None:
        """
        Not implemented; always raises :exc:`NotImplementedError`.

        The error message differs depending on whether the negotiated TLS version is
        older than TLSv1.3.

        :raises NotImplementedError: always

        """
        tls_version = self.extra(TLSAttribute.tls_version)
        # The minor version is optional in the pattern and defaults to 0
        match = re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version)
        if match:
            major, minor = int(match.group(1)), int(match.group(2) or 0)
            if (major, minor) < (1, 3):
                raise NotImplementedError(
                    f"send_eof() requires at least TLSv1.3; current "
                    f"session uses {tls_version}"
                )

        raise NotImplementedError(
            "send_eof() has not yet been implemented for TLS streams"
        )

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """
        Map every :class:`TLSAttribute` key to a getter backed by the SSL object, on
        top of the extra attributes of the wrapped transport stream.
        """
        return {
            **self.transport_stream.extra_attributes,
            TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
            TLSAttribute.channel_binding_tls_unique: (
                self._ssl_object.get_channel_binding
            ),
            TLSAttribute.cipher: self._ssl_object.cipher,
            TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False),
            TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert(
                True
            ),
            TLSAttribute.server_side: lambda: self._ssl_object.server_side,
            # Only queried on the server side; the client side reports None
            TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers()
            if self._ssl_object.server_side
            else None,
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
            TLSAttribute.ssl_object: lambda: self._ssl_object,
            TLSAttribute.tls_version: self._ssl_object.version,
        }
264
265
@dataclass(eq=False)
class TLSListener(Listener[TLSStream]):
    """
    A convenience listener that wraps another listener and auto-negotiates a TLS session
    on every accepted connection.

    If the TLS handshake times out or raises an exception,
    :meth:`handle_handshake_error` is called to do whatever post-mortem processing is
    deemed necessary.

    Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute.

    :param Listener listener: the listener to wrap
    :param ssl_context: the SSL context object
    :param standard_compatible: a flag passed through to :meth:`TLSStream.wrap`
    :param handshake_timeout: time limit for the TLS handshake
        (passed to :func:`~anyio.fail_after`)
    """

    listener: Listener[Any]
    ssl_context: ssl.SSLContext
    standard_compatible: bool = True
    handshake_timeout: float = 30

    @staticmethod
    async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None:
        """
        Handle an exception raised during the TLS handshake.

        This method does 3 things:

        #. Forcefully closes the original stream
        #. Logs the exception (unless it was a cancellation exception) using the
           ``anyio.streams.tls`` logger
        #. Reraises the exception if it was a base exception or a cancellation exception

        :param exc: the exception
        :param stream: the original stream

        """
        await aclose_forcefully(stream)

        is_cancellation = isinstance(exc, get_cancelled_exc_class())

        # Log all except cancellation exceptions
        if not is_cancellation:
            # CPython (as of 3.11.5) returns incorrect `sys.exc_info()` here when using
            # any asyncio implementation, so we explicitly pass the exception to log
            # (https://github.com/python/cpython/issues/108668). Trio does not have this
            # issue because it works around the CPython bug.
            logging.getLogger(__name__).exception(
                "Error during TLS handshake", exc_info=exc
            )

        # Only reraise base exceptions and cancellation exceptions
        if is_cancellation or not isinstance(exc, Exception):
            # Raise the object we were given rather than using a bare ``raise``: a
            # bare ``raise`` depends on the interpreter's "currently handled
            # exception", which is unreliable here for the reason given above and is
            # absent altogether if this method is called outside an ``except`` block.
            raise exc

    async def serve(
        self,
        handler: Callable[[TLSStream], Any],
        task_group: TaskGroup | None = None,
    ) -> None:
        """
        Serve connections from the wrapped listener, negotiating TLS on each one.

        Every accepted stream is wrapped with :meth:`TLSStream.wrap` under a
        ``handshake_timeout`` time limit. On success ``handler`` is awaited with the
        resulting :class:`TLSStream`; on any failure :meth:`handle_handshake_error`
        is awaited instead and ``handler`` is not called.

        :param handler: a callable awaited with each successfully wrapped stream
        :param task_group: passed through to the wrapped listener's ``serve()``

        """

        @wraps(handler)
        async def handler_wrapper(stream: AnyByteStream) -> None:
            # Imported here rather than at module level, presumably to avoid a
            # circular import -- confirm before moving it
            from .. import fail_after

            try:
                with fail_after(self.handshake_timeout):
                    wrapped_stream = await TLSStream.wrap(
                        stream,
                        ssl_context=self.ssl_context,
                        standard_compatible=self.standard_compatible,
                    )
            except BaseException as exc:
                await self.handle_handshake_error(exc, stream)
            else:
                await handler(wrapped_stream)

        await self.listener.serve(handler_wrapper, task_group)

    async def aclose(self) -> None:
        """Close the wrapped listener."""
        await self.listener.aclose()

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """Expose only :attr:`~TLSAttribute.standard_compatible`."""
        return {
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
        }