1from __future__ import annotations
2
3import ssl
4import typing
5
6import trio
7
8from .._exceptions import (
9 ConnectError,
10 ConnectTimeout,
11 ExceptionMapping,
12 ReadError,
13 ReadTimeout,
14 WriteError,
15 WriteTimeout,
16 map_exceptions,
17)
18from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
19
20
class TrioStream(AsyncNetworkStream):
    """`AsyncNetworkStream` implementation that wraps a `trio.abc.Stream`.

    Trio exceptions raised by the wrapped stream are translated into the
    httpcore exception types (`ReadError`, `WriteTimeout`, `ConnectError`, ...)
    via `map_exceptions`.
    """

    def __init__(self, stream: trio.abc.Stream) -> None:
        # Either a plain socket stream or a `trio.SSLStream` layered on top
        # of one (see `start_tls`).
        self._stream = stream

    async def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
        """Receive up to `max_bytes` bytes from the stream.

        Args:
            max_bytes: Upper bound on the number of bytes returned.
            timeout: Seconds to wait; `None` means wait indefinitely.

        Returns:
            The received bytes (trio returns an empty result at end-of-file).

        Raises:
            ReadTimeout: If `timeout` elapses before any data arrives.
            ReadError: If the stream is broken or has already been closed.
        """
        # trio.fail_after requires a number; +inf disables the deadline.
        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
            trio.TooSlowError: ReadTimeout,
            trio.BrokenResourceError: ReadError,
            trio.ClosedResourceError: ReadError,
        }
        with map_exceptions(exc_map):
            with trio.fail_after(timeout_or_inf):
                data: bytes = await self._stream.receive_some(max_bytes=max_bytes)
        return data

    async def write(self, buffer: bytes, timeout: float | None = None) -> None:
        """Send all of `buffer` over the stream.

        An empty buffer is a no-op and does not touch the stream.

        Args:
            buffer: Bytes to send.
            timeout: Seconds to wait; `None` means wait indefinitely.

        Raises:
            WriteTimeout: If `timeout` elapses before the send completes.
            WriteError: If the stream is broken or has already been closed.
        """
        if not buffer:
            return

        # trio.fail_after requires a number; +inf disables the deadline.
        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
            trio.TooSlowError: WriteTimeout,
            trio.BrokenResourceError: WriteError,
            trio.ClosedResourceError: WriteError,
        }
        with map_exceptions(exc_map):
            with trio.fail_after(timeout_or_inf):
                await self._stream.send_all(data=buffer)

    async def aclose(self) -> None:
        """Close the wrapped stream."""
        await self._stream.aclose()

    async def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: str | None = None,
        timeout: float | None = None,
    ) -> AsyncNetworkStream:
        """Perform a client-side TLS handshake and return the wrapped stream.

        If the handshake raises an `Exception`, this (plaintext) stream is
        closed before the error propagates.

        Args:
            ssl_context: Context used to configure the TLS connection.
            server_hostname: Hostname passed to `trio.SSLStream` for the
                handshake.
            timeout: Seconds allowed for the handshake; `None` means no limit.

        Returns:
            A new `TrioStream` wrapping the `trio.SSLStream`.

        Raises:
            ConnectTimeout: If the handshake does not finish within `timeout`.
            ConnectError: If the stream is broken or closed during the handshake.
        """
        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
            trio.TooSlowError: ConnectTimeout,
            trio.BrokenResourceError: ConnectError,
            # Consistent with read()/write(): a stream closed mid-handshake is
            # reported as a connection failure instead of leaking a raw trio error.
            trio.ClosedResourceError: ConnectError,
        }
        ssl_stream = trio.SSLStream(
            self._stream,
            ssl_context=ssl_context,
            server_hostname=server_hostname,
            https_compatible=True,
            server_side=False,
        )
        with map_exceptions(exc_map):
            try:
                with trio.fail_after(timeout_or_inf):
                    await ssl_stream.do_handshake()
            except Exception:  # pragma: nocover
                await self.aclose()
                raise
        return TrioStream(ssl_stream)

    def get_extra_info(self, info: str) -> typing.Any:
        """Return connection details for the key `info`.

        Recognised keys are "ssl_object" (only when the stream is a
        `trio.SSLStream`), "client_addr", "server_addr", "socket" and
        "is_readable". Any other key returns `None`.
        """
        if info == "ssl_object" and isinstance(self._stream, trio.SSLStream):
            # Type checkers cannot see `_ssl_object` attribute because trio._ssl.SSLStream uses __getattr__/__setattr__.
            # Tracked at https://github.com/python-trio/trio/issues/542
            return self._stream._ssl_object  # type: ignore[attr-defined]
        if info == "client_addr":
            return self._get_socket_stream().socket.getsockname()
        if info == "server_addr":
            return self._get_socket_stream().socket.getpeername()
        if info == "socket":
            return self._get_socket_stream().socket
        if info == "is_readable":
            return self._get_socket_stream().socket.is_readable()
        return None

    def _get_socket_stream(self) -> trio.SocketStream:
        """Return the innermost `trio.SocketStream`, unwrapping any TLS layers."""
        stream = self._stream
        while isinstance(stream, trio.SSLStream):
            stream = stream.transport_stream
        assert isinstance(stream, trio.SocketStream)
        return stream
107
108
class TrioBackend(AsyncNetworkBackend):
    """`AsyncNetworkBackend` implementation that opens connections with trio."""

    async def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: float | None = None,
        local_address: str | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> AsyncNetworkStream:
        """Open a TCP connection to `host`:`port`.

        Args:
            host: Remote host name or address.
            port: Remote port.
            timeout: Seconds allowed for connecting and applying socket
                options; `None` means no limit.
            local_address: Optional local address passed to
                `trio.open_tcp_stream`.
            socket_options: `setsockopt` argument tuples applied after connecting.

        Raises:
            ConnectTimeout: If the connection is not established within `timeout`.
            ConnectError: On `OSError` or a broken resource while connecting.
        """
        # By default for TCP sockets, trio enables TCP_NODELAY.
        # https://trio.readthedocs.io/en/stable/reference-io.html#trio.SocketStream
        if socket_options is None:
            socket_options = []  # pragma: no cover
        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
            trio.TooSlowError: ConnectTimeout,
            trio.BrokenResourceError: ConnectError,
            OSError: ConnectError,
        }
        with map_exceptions(exc_map):
            with trio.fail_after(timeout_or_inf):
                stream: trio.abc.Stream = await trio.open_tcp_stream(
                    host=host, port=port, local_address=local_address
                )
                await self._apply_socket_options(stream, socket_options)
        return TrioStream(stream)

    async def connect_unix_socket(
        self,
        path: str,
        timeout: float | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> AsyncNetworkStream:  # pragma: nocover
        """Open a connection to the Unix domain socket at `path`.

        Args:
            path: Filesystem path of the socket.
            timeout: Seconds allowed for connecting and applying socket
                options; `None` means no limit.
            socket_options: `setsockopt` argument tuples applied after connecting.

        Raises:
            ConnectTimeout: If the connection is not established within `timeout`.
            ConnectError: On `OSError` or a broken resource while connecting.
        """
        if socket_options is None:
            socket_options = []
        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
            trio.TooSlowError: ConnectTimeout,
            trio.BrokenResourceError: ConnectError,
            OSError: ConnectError,
        }
        with map_exceptions(exc_map):
            with trio.fail_after(timeout_or_inf):
                stream: trio.abc.Stream = await trio.open_unix_socket(path)
                await self._apply_socket_options(stream, socket_options)
        return TrioStream(stream)

    @staticmethod
    async def _apply_socket_options(
        stream: trio.abc.Stream,
        socket_options: typing.Iterable[SOCKET_OPTION],
    ) -> None:
        """Call `stream.setsockopt(*option)` for each option, closing `stream` on failure."""
        try:
            for option in socket_options:
                stream.setsockopt(*option)  # type: ignore[attr-defined] # pragma: no cover
        except BaseException:
            # Without this, a rejected option (e.g. OSError from setsockopt)
            # would propagate while the freshly opened socket stays open.
            await trio.aclose_forcefully(stream)
            raise

    async def sleep(self, seconds: float) -> None:
        """Suspend the current task for `seconds` seconds."""
        await trio.sleep(seconds)  # pragma: nocover