1from __future__ import annotations
2
3import logging
4import re
5import ssl
6import sys
7from collections.abc import Callable, Mapping
8from dataclasses import dataclass
9from functools import wraps
10from ssl import SSLContext
11from typing import Any, TypeVar
12
13from .. import (
14 BrokenResourceError,
15 EndOfStream,
16 aclose_forcefully,
17 get_cancelled_exc_class,
18 to_thread,
19)
20from .._core._typedattr import TypedAttributeSet, typed_attribute
21from ..abc import (
22 AnyByteStream,
23 AnyByteStreamConnectable,
24 ByteStream,
25 ByteStreamConnectable,
26 Listener,
27 TaskGroup,
28)
29
30if sys.version_info >= (3, 10):
31 from typing import TypeAlias
32else:
33 from typing_extensions import TypeAlias
34
35if sys.version_info >= (3, 11):
36 from typing import TypeVarTuple, Unpack
37else:
38 from typing_extensions import TypeVarTuple, Unpack
39
40if sys.version_info >= (3, 12):
41 from typing import override
42else:
43 from typing_extensions import override
44
# Return type of the callable forwarded through TLSStream._call_sslobject_method()
T_Retval = TypeVar("T_Retval")
# Positional argument types of that same forwarded callable
PosArgsT = TypeVarTuple("PosArgsT")
# Helper aliases spelling out the nested tuple values that may appear in the
# dictionary form of a peer certificate (see TLSAttribute.peer_certificate)
_PCTRTT: TypeAlias = tuple[tuple[str, str], ...]
_PCTRTTT: TypeAlias = tuple[_PCTRTT, ...]
49
50
class TLSAttribute(TypedAttributeSet):
    """
    Contains Transport Layer Security related attributes.

    Look these up on a TLS stream with ``stream.extra(TLSAttribute.<name>)``.
    """

    #: the selected ALPN protocol
    alpn_protocol: str | None = typed_attribute()
    #: the channel binding for type ``tls-unique``
    channel_binding_tls_unique: bytes = typed_attribute()
    #: the selected cipher
    cipher: tuple[str, str, int] = typed_attribute()
    #: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert`
    #: for more information)
    peer_certificate: None | (dict[str, str | _PCTRTTT | _PCTRTT]) = typed_attribute()
    #: the peer certificate in binary form
    peer_certificate_binary: bytes | None = typed_attribute()
    #: ``True`` if this is the server side of the connection
    server_side: bool = typed_attribute()
    #: ciphers shared by the client during the TLS handshake (``None`` if this is the
    #: client side)
    shared_ciphers: list[tuple[str, str, int]] | None = typed_attribute()
    #: the :class:`~ssl.SSLObject` used for encryption
    ssl_object: ssl.SSLObject = typed_attribute()
    #: ``True`` if this stream does (and expects) a closing TLS handshake when the
    #: stream is being closed
    standard_compatible: bool = typed_attribute()
    #: the TLS protocol version (e.g. ``TLSv1.2``)
    tls_version: str = typed_attribute()
77
78
@dataclass(eq=False)
class TLSStream(ByteStream):
    """
    A stream wrapper that encrypts all sent data and decrypts received data.

    This class has no public initializer; use :meth:`wrap` instead.
    All extra attributes from :class:`~TLSAttribute` are supported.

    :var AnyByteStream transport_stream: the wrapped stream

    """

    transport_stream: AnyByteStream
    standard_compatible: bool
    # TLS state machine created by SSLContext.wrap_bio() over the two BIOs below
    _ssl_object: ssl.SSLObject
    # incoming TLS records received from transport_stream, consumed by _ssl_object
    _read_bio: ssl.MemoryBIO
    # outgoing TLS records produced by _ssl_object, flushed to transport_stream
    _write_bio: ssl.MemoryBIO

    @classmethod
    async def wrap(
        cls,
        transport_stream: AnyByteStream,
        *,
        server_side: bool | None = None,
        hostname: str | None = None,
        ssl_context: ssl.SSLContext | None = None,
        standard_compatible: bool = True,
    ) -> TLSStream:
        """
        Wrap an existing stream with Transport Layer Security.

        This performs a TLS handshake with the peer.

        :param transport_stream: a bytes-transporting stream to wrap
        :param server_side: ``True`` if this is the server side of the connection,
            ``False`` if this is the client side (if omitted, will be set to ``False``
            if ``hostname`` has been provided, ``True`` otherwise). Always passed to
            :meth:`~ssl.SSLContext.wrap_bio`; also selects the purpose of the default
            context when an explicit context has not been provided.
        :param hostname: host name of the peer (if host name checking is desired)
        :param ssl_context: the SSLContext object to use (if not provided, a secure
            default will be created)
        :param standard_compatible: if ``False``, skip the closing handshake when
            closing the connection, and don't raise an exception if the peer does the
            same
        :raises ~ssl.SSLError: if the TLS handshake fails

        """
        if server_side is None:
            # Without a peer host name to verify, assume we are the accepting side
            server_side = not hostname

        if ssl_context is None:
            purpose = (
                ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH
            )
            ssl_context = ssl.create_default_context(purpose)

            # Re-enable detection of unexpected EOFs if it was disabled by Python.
            # Only the context created here is adjusted, so a caller-supplied (and
            # possibly shared) context is never mutated behind its owner's back.
            if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
                ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF

        bio_in = ssl.MemoryBIO()
        bio_out = ssl.MemoryBIO()

        # External SSLContext implementations may do blocking I/O in wrap_bio(),
        # but the standard library implementation won't. The exact type check is
        # deliberate: subclasses are treated as potentially blocking.
        if type(ssl_context) is ssl.SSLContext:
            ssl_object = ssl_context.wrap_bio(
                bio_in, bio_out, server_side=server_side, server_hostname=hostname
            )
        else:
            # Positional order: incoming, outgoing, server_side, server_hostname,
            # session
            ssl_object = await to_thread.run_sync(
                ssl_context.wrap_bio,
                bio_in,
                bio_out,
                server_side,
                hostname,
                None,
            )

        wrapper = cls(
            transport_stream=transport_stream,
            standard_compatible=standard_compatible,
            _ssl_object=ssl_object,
            _read_bio=bio_in,
            _write_bio=bio_out,
        )
        await wrapper._call_sslobject_method(ssl_object.do_handshake)
        return wrapper

    async def _call_sslobject_method(
        self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
    ) -> T_Retval:
        """
        Call ``func(*args)``, shuttling TLS records between the memory BIOs and the
        transport stream until the call completes.

        :raises BrokenResourceError: if flushing to / receiving from the transport
            raises :exc:`OSError` while more data is needed, on an
            :exc:`ssl.SSLSyscallError`, or if the peer closes the connection without a
            closing handshake while ``standard_compatible`` is ``True``
        :raises EndOfStream: if the peer closes the connection without a closing
            handshake while ``standard_compatible`` is ``False``
        :raises ~ssl.SSLError: on any other TLS error

        """
        while True:
            try:
                result = func(*args)
            except ssl.SSLWantReadError:
                try:
                    # Flush any pending writes first (the peer may be waiting on them)
                    if self._write_bio.pending:
                        await self.transport_stream.send(self._write_bio.read())

                    data = await self.transport_stream.receive()
                except EndOfStream:
                    # Let the SSL object observe the EOF on the next loop iteration
                    self._read_bio.write_eof()
                except OSError as exc:
                    self._read_bio.write_eof()
                    self._write_bio.write_eof()
                    raise BrokenResourceError from exc
                else:
                    self._read_bio.write(data)
            except ssl.SSLWantWriteError:
                await self.transport_stream.send(self._write_bio.read())
            except ssl.SSLSyscallError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                raise BrokenResourceError from exc
            except ssl.SSLError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                # Peer went away without a closing handshake
                if isinstance(exc, ssl.SSLEOFError) or (
                    exc.strerror and "UNEXPECTED_EOF_WHILE_READING" in exc.strerror
                ):
                    if self.standard_compatible:
                        raise BrokenResourceError from exc
                    else:
                        raise EndOfStream from None

                raise
            else:
                # Flush any pending writes first
                if self._write_bio.pending:
                    await self.transport_stream.send(self._write_bio.read())

                return result

    async def unwrap(self) -> tuple[AnyByteStream, bytes]:
        """
        Does the TLS closing handshake.

        :return: a tuple of (wrapped byte stream, bytes left in the read buffer)

        """
        await self._call_sslobject_method(self._ssl_object.unwrap)
        self._read_bio.write_eof()
        self._write_bio.write_eof()
        return self.transport_stream, self._read_bio.read()

    async def aclose(self) -> None:
        """
        Close the stream.

        When ``standard_compatible`` is ``True``, the TLS closing handshake is done
        first; if that fails (or is interrupted), the transport stream is closed
        forcefully and the exception is re-raised.
        """
        if self.standard_compatible:
            try:
                await self.unwrap()
            except BaseException:
                await aclose_forcefully(self.transport_stream)
                raise

        await self.transport_stream.aclose()

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Receive up to ``max_bytes`` of decrypted data.

        :raises EndOfStream: if the SSL object returns no data (the TLS session was
            closed)
        """
        data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
        if not data:
            raise EndOfStream

        return data

    async def send(self, item: bytes) -> None:
        """Encrypt ``item`` and send it over the transport stream."""
        await self._call_sslobject_method(self._ssl_object.write, item)

    async def send_eof(self) -> None:
        """
        Not implemented for TLS streams.

        :raises NotImplementedError: always; the message names the negotiated
            protocol when it is older than TLSv1.3
        """
        tls_version = self.extra(TLSAttribute.tls_version)
        # SSLObject.version() may return None when no connection is established;
        # don't let re.match() blow up with TypeError in that case
        match = (
            re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version) if tls_version else None
        )
        if match:
            major, minor = int(match.group(1)), int(match.group(2) or 0)
            if (major, minor) < (1, 3):
                raise NotImplementedError(
                    f"send_eof() requires at least TLSv1.3; current "
                    f"session uses {tls_version}"
                )

        raise NotImplementedError(
            "send_eof() has not yet been implemented for TLS streams"
        )

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """Transport stream attributes merged with the TLS-specific ones."""
        return {
            **self.transport_stream.extra_attributes,
            TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
            TLSAttribute.channel_binding_tls_unique: (
                self._ssl_object.get_channel_binding
            ),
            TLSAttribute.cipher: self._ssl_object.cipher,
            TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False),
            TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert(
                True
            ),
            TLSAttribute.server_side: lambda: self._ssl_object.server_side,
            # shared_ciphers() is only meaningful on the server side
            TLSAttribute.shared_ciphers: lambda: (
                self._ssl_object.shared_ciphers()
                if self._ssl_object.server_side
                else None
            ),
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
            TLSAttribute.ssl_object: lambda: self._ssl_object,
            TLSAttribute.tls_version: self._ssl_object.version,
        }
282
283
@dataclass(eq=False)
class TLSListener(Listener[TLSStream]):
    """
    A convenience listener that wraps another listener and auto-negotiates a TLS session
    on every accepted connection.

    If the TLS handshake times out or raises an exception,
    :meth:`handle_handshake_error` is called to do whatever post-mortem processing is
    deemed necessary.

    Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute.

    :param Listener listener: the listener to wrap
    :param ssl_context: the SSL context object
    :param standard_compatible: a flag passed through to :meth:`TLSStream.wrap`
    :param handshake_timeout: time limit for the TLS handshake
        (passed to :func:`~anyio.fail_after`)
    """

    listener: Listener[Any]
    ssl_context: ssl.SSLContext
    standard_compatible: bool = True
    handshake_timeout: float = 30

    @staticmethod
    async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None:
        """
        Handle an exception raised during the TLS handshake.

        This method does 3 things:

        #. Forcefully closes the original stream
        #. Logs the exception (unless it was a cancellation exception) using the
           ``anyio.streams.tls`` logger
        #. Reraises the exception if it was a base exception or a cancellation exception

        :param exc: the exception
        :param stream: the original stream

        """
        await aclose_forcefully(stream)

        # Log all except cancellation exceptions
        if not isinstance(exc, get_cancelled_exc_class()):
            # CPython (as of 3.11.5) returns incorrect `sys.exc_info()` here when using
            # any asyncio implementation, so we explicitly pass the exception to log
            # (https://github.com/python/cpython/issues/108668). Trio does not have this
            # issue because it works around the CPython bug.
            logging.getLogger(__name__).exception(
                "Error during TLS handshake", exc_info=exc
            )

        # Only reraise base exceptions and cancellation exceptions. ``exc`` is raised
        # explicitly rather than with a bare ``raise``: a bare ``raise`` reads the same
        # "currently handled exception" state as sys.exc_info() (unreliable here, see
        # above) and fails outright if this method is called outside an except block.
        if not isinstance(exc, Exception) or isinstance(exc, get_cancelled_exc_class()):
            raise exc

    async def serve(
        self,
        handler: Callable[[TLSStream], Any],
        task_group: TaskGroup | None = None,
    ) -> None:
        """
        Serve the wrapped listener, doing a TLS handshake on each accepted stream.

        Handshake failures and timeouts are passed to :meth:`handle_handshake_error`;
        successfully wrapped streams are passed to ``handler``.
        """

        @wraps(handler)
        async def handler_wrapper(stream: AnyByteStream) -> None:
            from .. import fail_after

            try:
                with fail_after(self.handshake_timeout):
                    # No hostname is given, so TLSStream.wrap() defaults to the
                    # server side
                    wrapped_stream = await TLSStream.wrap(
                        stream,
                        ssl_context=self.ssl_context,
                        standard_compatible=self.standard_compatible,
                    )
            except BaseException as exc:
                await self.handle_handshake_error(exc, stream)
            else:
                await handler(wrapped_stream)

        await self.listener.serve(handler_wrapper, task_group)

    async def aclose(self) -> None:
        """Close the wrapped listener."""
        await self.listener.aclose()

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """Only :attr:`TLSAttribute.standard_compatible` is exposed."""
        return {
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
        }
371
372
class TLSConnectable(ByteStreamConnectable):
    """
    Byte stream connectable that layers TLS over another connectable.

    Once the wrapped connectable has produced a connection, a TLS handshake is run
    on it via :meth:`TLSStream.wrap`.

    :param connectable: the connectable to wrap
    :param hostname: host name of the server (if host name checking is desired)
    :param ssl_context: the SSLContext object to use (if not provided, a secure default
        will be created)
    :param standard_compatible: if ``False``, skip the closing handshake when closing
        the connection, and don't raise an exception if the server does the same
    """

    def __init__(
        self,
        connectable: AnyByteStreamConnectable,
        *,
        hostname: str | None = None,
        ssl_context: ssl.SSLContext | None = None,
        standard_compatible: bool = True,
    ) -> None:
        self.connectable = connectable
        # A falsy/missing context is replaced by a client-side (SERVER_AUTH) default
        self.ssl_context: SSLContext = (
            ssl_context
            if ssl_context
            else ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        )
        if not isinstance(self.ssl_context, ssl.SSLContext):
            bad_type = type(self.ssl_context).__name__
            raise TypeError(
                "ssl_context must be an instance of ssl.SSLContext, not "
                f"{bad_type}"
            )

        self.hostname = hostname
        self.standard_compatible = standard_compatible

    @override
    async def connect(self) -> TLSStream:
        """
        Connect through the wrapped connectable and negotiate TLS on the result.

        If the handshake fails or is interrupted, the freshly connected stream is
        closed forcefully before the exception propagates.
        """
        raw_stream = await self.connectable.connect()
        try:
            tls_stream = await TLSStream.wrap(
                raw_stream,
                hostname=self.hostname,
                ssl_context=self.ssl_context,
                standard_compatible=self.standard_compatible,
            )
        except BaseException:
            await aclose_forcefully(raw_stream)
            raise

        return tls_stream