1import ssl
2import typing
3
4import anyio
5
6from .._exceptions import (
7 ConnectError,
8 ConnectTimeout,
9 ReadError,
10 ReadTimeout,
11 WriteError,
12 WriteTimeout,
13 map_exceptions,
14)
15from .._utils import is_socket_readable
16from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
17
18
class AnyIOStream(AsyncNetworkStream):
    """An ``AsyncNetworkStream`` that wraps an ``anyio.abc.ByteStream``.

    Each I/O method runs under ``anyio.fail_after(timeout)`` and translates
    anyio exceptions into this package's exception types via
    ``map_exceptions``.
    """

    def __init__(self, stream: anyio.abc.ByteStream) -> None:
        # The underlying anyio stream; all I/O is delegated to it.
        self._stream = stream

    async def read(
        self, max_bytes: int, timeout: typing.Optional[float] = None
    ) -> bytes:
        """Receive up to ``max_bytes`` bytes from the stream.

        Args:
            max_bytes: Upper bound on the number of bytes returned.
            timeout: Seconds to wait before giving up; ``None`` means no limit.

        Returns:
            The received bytes, or ``b""`` if the peer ended the stream.

        Raises:
            ReadTimeout: If the ``timeout`` deadline expires.
            ReadError: If the stream is broken or has already been closed.
        """
        exc_map = {
            TimeoutError: ReadTimeout,
            anyio.BrokenResourceError: ReadError,
            anyio.ClosedResourceError: ReadError,
            anyio.EndOfStream: ReadError,
        }
        with map_exceptions(exc_map):
            with anyio.fail_after(timeout):
                try:
                    return await self._stream.receive(max_bytes=max_bytes)
                except anyio.EndOfStream:  # pragma: nocover
                    # A clean end-of-stream is reported as an empty read,
                    # not as an error.
                    return b""

    async def write(
        self, buffer: bytes, timeout: typing.Optional[float] = None
    ) -> None:
        """Send ``buffer`` to the peer; an empty buffer is a no-op.

        Args:
            buffer: The bytes to send.
            timeout: Seconds to wait before giving up; ``None`` means no limit.

        Raises:
            WriteTimeout: If the ``timeout`` deadline expires.
            WriteError: If the stream is broken or has already been closed.
        """
        if not buffer:
            return

        exc_map = {
            TimeoutError: WriteTimeout,
            anyio.BrokenResourceError: WriteError,
            anyio.ClosedResourceError: WriteError,
        }
        with map_exceptions(exc_map):
            with anyio.fail_after(timeout):
                await self._stream.send(item=buffer)

    async def aclose(self) -> None:
        """Close the underlying anyio stream."""
        await self._stream.aclose()

    async def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ) -> AsyncNetworkStream:
        """Perform a client-side TLS handshake over this connection.

        Args:
            ssl_context: The SSL context used for the handshake.
            server_hostname: Passed to the handshake as ``hostname``.
            timeout: Seconds allowed for the handshake; ``None`` means no limit.

        Returns:
            A new ``AnyIOStream`` wrapping the TLS stream.

        Raises:
            ConnectTimeout: If the handshake does not finish in time.
            ConnectError: On an ``ssl.SSLError``, a broken resource, or an
                end-of-stream during the handshake.

        If the handshake fails for any reason, this stream is closed before
        the error propagates.
        """
        exc_map = {
            TimeoutError: ConnectTimeout,
            anyio.BrokenResourceError: ConnectError,
            anyio.EndOfStream: ConnectError,
            ssl.SSLError: ConnectError,
        }
        with map_exceptions(exc_map):
            try:
                with anyio.fail_after(timeout):
                    ssl_stream = await anyio.streams.tls.TLSStream.wrap(
                        self._stream,
                        ssl_context=ssl_context,
                        hostname=server_hostname,
                        # NOTE(review): per anyio's TLSStream docs, False skips
                        # the TLS closing handshake on close; confirm intent.
                        standard_compatible=False,
                        server_side=False,
                    )
            except Exception:  # pragma: nocover
                # Don't leak the underlying connection when the handshake fails.
                await self.aclose()
                # Bare ``raise`` re-raises the active exception with its
                # original traceback, without an extra entry for this line.
                raise
        return AnyIOStream(ssl_stream)

    def get_extra_info(self, info: str) -> typing.Any:
        """Return connection metadata for ``info``.

        Supported keys:
            ``"ssl_object"``: the TLS ``ssl_object`` attribute, or ``None``.
            ``"client_addr"``: the local socket address, or ``None``.
            ``"server_addr"``: the remote socket address, or ``None``.
            ``"socket"``: the raw socket object, or ``None``.
            ``"is_readable"``: the result of ``is_socket_readable`` applied to
                the raw socket (which may be ``None``).

        Any other key returns ``None``.
        """
        if info == "ssl_object":
            return self._stream.extra(anyio.streams.tls.TLSAttribute.ssl_object, None)
        if info == "client_addr":
            return self._stream.extra(anyio.abc.SocketAttribute.local_address, None)
        if info == "server_addr":
            return self._stream.extra(anyio.abc.SocketAttribute.remote_address, None)
        if info == "socket":
            return self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None)
        if info == "is_readable":
            sock = self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None)
            return is_socket_readable(sock)
        return None
97
98
class AnyIOBackend(AsyncNetworkBackend):
    """Network backend that opens connections with ``anyio``.

    TCP connections use ``anyio.connect_tcp`` and Unix-domain connections use
    ``anyio.connect_unix``. Both are returned wrapped in ``AnyIOStream``.
    """

    @staticmethod
    async def _apply_socket_options(
        stream: anyio.abc.ByteStream,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]],
    ) -> None:
        """Call ``setsockopt(*option)`` for each option; close ``stream`` on failure."""
        if socket_options is None:
            return
        try:
            # Public typed attribute, rather than the private ``_raw_socket``.
            sock = stream.extra(anyio.abc.SocketAttribute.raw_socket)
            for option in socket_options:
                sock.setsockopt(*option)
        except Exception:
            # The connection is already open; close it so a failed option
            # does not leak the socket, then propagate the original error.
            await anyio.aclose_forcefully(stream)
            raise

    async def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: typing.Optional[float] = None,
        local_address: typing.Optional[str] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
    ) -> AsyncNetworkStream:  # pragma: nocover
        """Open a TCP connection to ``host:port``.

        Args:
            host: Remote host to connect to.
            port: Remote port to connect to.
            timeout: Seconds allowed for connecting; ``None`` means no limit.
            local_address: Optional local host to bind the outgoing socket to.
            socket_options: Option tuples, each applied as
                ``setsockopt(*option)`` on the connected socket.

        Returns:
            The connected stream wrapped in ``AnyIOStream``.

        Raises:
            ConnectTimeout: If the connection is not made within ``timeout``.
            ConnectError: On an ``OSError`` or broken resource while connecting
                or while applying socket options.
        """
        # TimeoutError is a subclass of OSError, so it is listed first.
        # NOTE(review): this relies on map_exceptions checking entries in
        # insertion order; confirm.
        exc_map = {
            TimeoutError: ConnectTimeout,
            OSError: ConnectError,
            anyio.BrokenResourceError: ConnectError,
        }
        with map_exceptions(exc_map):
            with anyio.fail_after(timeout):
                stream: anyio.abc.ByteStream = await anyio.connect_tcp(
                    remote_host=host,
                    remote_port=port,
                    local_host=local_address,
                )
                # By default TCP sockets opened in `asyncio` include TCP_NODELAY.
                await self._apply_socket_options(stream, socket_options)
        return AnyIOStream(stream)

    async def connect_unix_socket(
        self,
        path: str,
        timeout: typing.Optional[float] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
    ) -> AsyncNetworkStream:  # pragma: nocover
        """Open a Unix-domain socket connection to ``path``.

        Args:
            path: Filesystem path of the socket to connect to.
            timeout: Seconds allowed for connecting; ``None`` means no limit.
            socket_options: Option tuples, each applied as
                ``setsockopt(*option)`` on the connected socket.

        Returns:
            The connected stream wrapped in ``AnyIOStream``.

        Raises:
            ConnectTimeout: If the connection is not made within ``timeout``.
            ConnectError: On an ``OSError`` or broken resource while connecting
                or while applying socket options.
        """
        # TimeoutError is a subclass of OSError, so it is listed first.
        exc_map = {
            TimeoutError: ConnectTimeout,
            OSError: ConnectError,
            anyio.BrokenResourceError: ConnectError,
        }
        with map_exceptions(exc_map):
            with anyio.fail_after(timeout):
                stream: anyio.abc.ByteStream = await anyio.connect_unix(path)
                await self._apply_socket_options(stream, socket_options)
        return AnyIOStream(stream)

    async def sleep(self, seconds: float) -> None:
        """Suspend the current task for ``seconds`` seconds via ``anyio.sleep``."""
        await anyio.sleep(seconds)  # pragma: nocover