1import ssl
2import typing
3
4import trio
5
6from .._exceptions import (
7 ConnectError,
8 ConnectTimeout,
9 ExceptionMapping,
10 ReadError,
11 ReadTimeout,
12 WriteError,
13 WriteTimeout,
14 map_exceptions,
15)
16from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
17
18
class TrioStream(AsyncNetworkStream):
    """`AsyncNetworkStream` implementation that wraps a `trio.abc.Stream`.

    trio-specific exceptions raised by the wrapped stream are translated into
    httpcore exception types through `map_exceptions`.
    """

    def __init__(self, stream: trio.abc.Stream) -> None:
        # The wrapped transport. It is either a `trio.SocketStream` (as
        # produced by `TrioBackend`) or a `trio.SSLStream` layered over one
        # after `start_tls`.
        self._stream = stream

    async def read(
        self, max_bytes: int, timeout: typing.Optional[float] = None
    ) -> bytes:
        """Receive up to `max_bytes` bytes from the stream.

        Args:
            max_bytes: Upper bound on the number of bytes returned.
            timeout: Seconds to wait for data; `None` waits indefinitely.

        Returns:
            The received bytes. Per trio's `receive_some` contract, `b""`
            signals end-of-file.

        Raises:
            ReadTimeout: `timeout` elapsed before any data arrived.
            ReadError: The stream is broken or has been closed.
        """
        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
            trio.TooSlowError: ReadTimeout,
            trio.BrokenResourceError: ReadError,
            trio.ClosedResourceError: ReadError,
        }
        with map_exceptions(exc_map):
            with trio.fail_after(timeout_or_inf):
                data: bytes = await self._stream.receive_some(max_bytes=max_bytes)
        return data

    async def write(
        self, buffer: bytes, timeout: typing.Optional[float] = None
    ) -> None:
        """Send all of `buffer`. An empty buffer is a no-op.

        Args:
            buffer: Bytes to send.
            timeout: Seconds allowed for the whole send; `None` waits
                indefinitely.

        Raises:
            WriteTimeout: `timeout` elapsed before the send completed.
            WriteError: The stream is broken or has been closed.
        """
        if not buffer:
            return

        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
            trio.TooSlowError: WriteTimeout,
            trio.BrokenResourceError: WriteError,
            trio.ClosedResourceError: WriteError,
        }
        with map_exceptions(exc_map):
            with trio.fail_after(timeout_or_inf):
                await self._stream.send_all(data=buffer)

    async def aclose(self) -> None:
        """Close the wrapped stream."""
        await self._stream.aclose()

    async def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ) -> AsyncNetworkStream:
        """Perform a client-side TLS handshake over this stream.

        Args:
            ssl_context: Context used to configure the TLS session.
            server_hostname: Hostname passed to `trio.SSLStream`.
            timeout: Seconds allowed for the handshake; `None` waits
                indefinitely.

        Returns:
            A new `TrioStream` that wraps the resulting `trio.SSLStream`.

        Raises:
            ConnectTimeout: The handshake did not finish within `timeout`.
            ConnectError: The underlying stream broke or was closed during
                the handshake.

        If the handshake fails, this stream is closed before the error
        propagates.
        """
        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
            trio.TooSlowError: ConnectTimeout,
            trio.BrokenResourceError: ConnectError,
            # Matches read()/write(): a stream closed underneath us is
            # reported as an httpcore error, not a raw trio exception.
            trio.ClosedResourceError: ConnectError,
        }
        ssl_stream = trio.SSLStream(
            self._stream,
            ssl_context=ssl_context,
            server_hostname=server_hostname,
            https_compatible=True,
            server_side=False,
        )
        with map_exceptions(exc_map):
            try:
                with trio.fail_after(timeout_or_inf):
                    await ssl_stream.do_handshake()
            except Exception:  # pragma: nocover
                # Don't leak the underlying connection when the handshake fails.
                await self.aclose()
                raise
        return TrioStream(ssl_stream)

    def get_extra_info(self, info: str) -> typing.Any:
        """Return transport details selected by `info`.

        Supported keys:
            "ssl_object": The SSL object, only when the stream is TLS-wrapped.
            "client_addr": `getsockname()` of the underlying socket.
            "server_addr": `getpeername()` of the underlying socket.
            "socket": The underlying trio socket.
            "is_readable": Result of the socket's `is_readable()`.

        Any other key, or "ssl_object" on a plain stream, returns `None`.
        """
        if info == "ssl_object" and isinstance(self._stream, trio.SSLStream):
            # Type checkers cannot see `_ssl_object` attribute because trio._ssl.SSLStream uses __getattr__/__setattr__.
            # Tracked at https://github.com/python-trio/trio/issues/542
            return self._stream._ssl_object  # type: ignore[attr-defined]
        if info == "client_addr":
            return self._get_socket_stream().socket.getsockname()
        if info == "server_addr":
            return self._get_socket_stream().socket.getpeername()
        if info == "socket":
            return self._get_socket_stream().socket
        if info == "is_readable":
            return self._get_socket_stream().socket.is_readable()
        return None

    def _get_socket_stream(self) -> trio.SocketStream:
        """Unwrap any `trio.SSLStream` layers and return the innermost socket stream."""
        stream = self._stream
        while isinstance(stream, trio.SSLStream):
            stream = stream.transport_stream
        assert isinstance(stream, trio.SocketStream)
        return stream
109
110
class TrioBackend(AsyncNetworkBackend):
    """`AsyncNetworkBackend` that opens connections with trio."""

    async def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: typing.Optional[float] = None,
        local_address: typing.Optional[str] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
    ) -> AsyncNetworkStream:
        """Open a TCP connection to `host:port`.

        Args:
            host: Remote host name or address.
            port: Remote port.
            timeout: Seconds allowed for connecting; `None` waits indefinitely.
            local_address: Optional local address to bind before connecting.
            socket_options: Options to apply to the connected socket. Each
                option is unpacked as positional arguments to `setsockopt`.

        Returns:
            A `TrioStream` that wraps the connected socket stream.

        Raises:
            ConnectTimeout: The connection was not established within `timeout`.
            ConnectError: The connection failed, or a socket option was
                rejected. In the latter case the socket is closed first.
        """
        # By default for TCP sockets, trio enables TCP_NODELAY.
        # https://trio.readthedocs.io/en/stable/reference-io.html#trio.SocketStream
        if socket_options is None:
            socket_options = []  # pragma: no cover
        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
            trio.TooSlowError: ConnectTimeout,
            trio.BrokenResourceError: ConnectError,
            OSError: ConnectError,
        }
        with map_exceptions(exc_map):
            with trio.fail_after(timeout_or_inf):
                stream = await trio.open_tcp_stream(
                    host=host, port=port, local_address=local_address
                )
            # Inside map_exceptions so a rejected option surfaces as ConnectError.
            await self._apply_socket_options(stream, socket_options)
        return TrioStream(stream)

    async def connect_unix_socket(
        self,
        path: str,
        timeout: typing.Optional[float] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
    ) -> AsyncNetworkStream:  # pragma: nocover
        """Open a connection to the Unix domain socket at `path`.

        Args:
            path: Filesystem path of the socket.
            timeout: Seconds allowed for connecting; `None` waits indefinitely.
            socket_options: Options to apply to the connected socket. Each
                option is unpacked as positional arguments to `setsockopt`.

        Returns:
            A `TrioStream` that wraps the connected socket stream.

        Raises:
            ConnectTimeout: The connection was not established within `timeout`.
            ConnectError: The connection failed, or a socket option was
                rejected. In the latter case the socket is closed first.
        """
        if socket_options is None:
            socket_options = []
        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
            trio.TooSlowError: ConnectTimeout,
            trio.BrokenResourceError: ConnectError,
            OSError: ConnectError,
        }
        with map_exceptions(exc_map):
            with trio.fail_after(timeout_or_inf):
                stream = await trio.open_unix_socket(path)
            # Inside map_exceptions so a rejected option surfaces as ConnectError.
            await self._apply_socket_options(stream, socket_options)
        return TrioStream(stream)

    async def sleep(self, seconds: float) -> None:
        """Suspend the current task for `seconds` seconds."""
        await trio.sleep(seconds)  # pragma: nocover

    @staticmethod
    async def _apply_socket_options(
        stream: trio.SocketStream,
        socket_options: typing.Iterable[SOCKET_OPTION],
    ) -> None:
        """Apply each option to `stream` via `setsockopt`; close `stream` if one fails."""
        try:
            for option in socket_options:
                stream.setsockopt(*option)  # pragma: no cover
        except Exception:
            # The connection is already open; release it rather than leaking
            # the socket when the caller cannot receive the stream.
            await stream.aclose()
            raise