1import ssl
2import typing
3
4import anyio
5
6from .._exceptions import (
7 ConnectError,
8 ConnectTimeout,
9 ReadError,
10 ReadTimeout,
11 WriteError,
12 WriteTimeout,
13 map_exceptions,
14)
15from .._utils import is_socket_readable
16from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
17
18
class AnyIOStream(AsyncNetworkStream):
    """``AsyncNetworkStream`` implementation backed by an anyio byte stream.

    Each operation translates the anyio / ssl exceptions it can encounter into
    httpcore exception types via a local ``exc_map`` passed to
    ``map_exceptions``.
    """

    def __init__(self, stream: anyio.abc.ByteStream) -> None:
        # The wrapped anyio stream: the connected socket stream, or the
        # ``TLSStream`` produced by ``start_tls``.
        self._stream = stream

    async def read(
        self, max_bytes: int, timeout: typing.Optional[float] = None
    ) -> bytes:
        """Receive at most ``max_bytes`` bytes from the stream.

        Args:
            max_bytes: upper bound on the number of bytes returned.
            timeout: seconds allowed for the read, passed to
                ``anyio.fail_after`` (``None`` disables the limit).

        Returns:
            The received bytes, or ``b""`` when anyio reports end-of-stream.

        Raises:
            ReadTimeout: the timeout elapsed.
            ReadError: the stream is broken or already closed.
        """
        exc_map = {
            TimeoutError: ReadTimeout,
            anyio.BrokenResourceError: ReadError,
            anyio.ClosedResourceError: ReadError,
            anyio.EndOfStream: ReadError,
        }
        with map_exceptions(exc_map):
            with anyio.fail_after(timeout):
                try:
                    return await self._stream.receive(max_bytes=max_bytes)
                except anyio.EndOfStream:  # pragma: nocover
                    # End-of-stream is reported to callers as an empty read,
                    # not as an error.
                    return b""

    async def write(
        self, buffer: bytes, timeout: typing.Optional[float] = None
    ) -> None:
        """Send all of ``buffer`` over the stream; an empty buffer is a no-op.

        Args:
            buffer: bytes to send.
            timeout: seconds allowed for the send, passed to
                ``anyio.fail_after`` (``None`` disables the limit).

        Raises:
            WriteTimeout: the timeout elapsed.
            WriteError: the stream is broken or already closed.
        """
        if not buffer:
            return

        exc_map = {
            TimeoutError: WriteTimeout,
            anyio.BrokenResourceError: WriteError,
            anyio.ClosedResourceError: WriteError,
        }
        with map_exceptions(exc_map):
            with anyio.fail_after(timeout):
                await self._stream.send(item=buffer)

    async def aclose(self) -> None:
        """Close the wrapped anyio stream."""
        await self._stream.aclose()

    async def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ) -> AsyncNetworkStream:
        """Run a client-side TLS handshake over this stream.

        Args:
            ssl_context: context used to configure the TLS connection.
            server_hostname: hostname passed to anyio's ``TLSStream.wrap``.
            timeout: seconds allowed for the handshake, passed to
                ``anyio.fail_after``.

        Returns:
            A new ``AnyIOStream`` wrapping the resulting ``TLSStream``.

        Raises:
            ConnectTimeout: the handshake did not finish within ``timeout``.
            ConnectError: the handshake failed (broken stream, premature EOF,
                or an ``ssl.SSLError`` such as a certificate-verification
                failure).

        On any failure this (plain) stream is closed before the error
        propagates.
        """
        exc_map = {
            TimeoutError: ConnectTimeout,
            anyio.BrokenResourceError: ConnectError,
            anyio.EndOfStream: ConnectError,
            # Handshake failures like certificate-verification errors surface
            # from TLSStream.wrap as plain ssl.SSLError; map them so callers
            # only need to handle httpcore exception types.
            ssl.SSLError: ConnectError,
        }
        with map_exceptions(exc_map):
            try:
                with anyio.fail_after(timeout):
                    ssl_stream = await anyio.streams.tls.TLSStream.wrap(
                        self._stream,
                        ssl_context=ssl_context,
                        hostname=server_hostname,
                        # Don't require a TLS closing handshake from the peer
                        # (see anyio TLSStream documentation).
                        standard_compatible=False,
                        server_side=False,
                    )
            except Exception:  # pragma: nocover
                # Don't leak the underlying connection if the handshake fails.
                await self.aclose()
                raise
        return AnyIOStream(ssl_stream)

    def get_extra_info(self, info: str) -> typing.Any:
        """Return transport details selected by ``info``.

        Recognised keys: ``"ssl_object"``, ``"client_addr"`` (local address),
        ``"server_addr"`` (remote address), ``"socket"`` (raw socket) and
        ``"is_readable"``. Unknown keys, and attributes the stream does not
        provide, yield ``None``.
        """
        if info == "ssl_object":
            return self._stream.extra(anyio.streams.tls.TLSAttribute.ssl_object, None)
        if info == "client_addr":
            return self._stream.extra(anyio.abc.SocketAttribute.local_address, None)
        if info == "server_addr":
            return self._stream.extra(anyio.abc.SocketAttribute.remote_address, None)
        if info == "socket":
            return self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None)
        if info == "is_readable":
            sock = self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None)
            return is_socket_readable(sock)
        return None
96
97
class AnyIOBackend(AsyncNetworkBackend):
    """Async network backend that opens TCP / Unix-socket connections with anyio."""

    async def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: typing.Optional[float] = None,
        local_address: typing.Optional[str] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
    ) -> AsyncNetworkStream:
        """Open a TCP connection to ``host``:``port``.

        Args:
            host: remote host to connect to.
            port: remote port to connect to.
            timeout: seconds allowed for connecting (and applying socket
                options), passed to ``anyio.fail_after``.
            local_address: optional local host to bind, passed to anyio as
                ``local_host``.
            socket_options: tuples unpacked into ``socket.setsockopt`` on the
                connected socket.

        Returns:
            An ``AnyIOStream`` wrapping the connected stream.

        Raises:
            ConnectTimeout: the timeout elapsed.
            ConnectError: connecting or applying a socket option failed with
                ``OSError``, or the resource was broken.
        """
        if socket_options is None:
            socket_options = []  # pragma: no cover
        exc_map = {
            TimeoutError: ConnectTimeout,
            OSError: ConnectError,
            anyio.BrokenResourceError: ConnectError,
        }
        with map_exceptions(exc_map):
            with anyio.fail_after(timeout):
                stream: anyio.abc.ByteStream = await anyio.connect_tcp(
                    remote_host=host,
                    remote_port=port,
                    local_host=local_address,
                )
                # By default TCP sockets opened in `asyncio` include TCP_NODELAY.
                await self._apply_socket_options(stream, socket_options)
        return AnyIOStream(stream)

    async def connect_unix_socket(
        self,
        path: str,
        timeout: typing.Optional[float] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
    ) -> AsyncNetworkStream:  # pragma: nocover
        """Open a connection to the Unix domain socket at ``path``.

        Args:
            path: filesystem path of the socket.
            timeout: seconds allowed for connecting (and applying socket
                options), passed to ``anyio.fail_after``.
            socket_options: tuples unpacked into ``socket.setsockopt`` on the
                connected socket.

        Returns:
            An ``AnyIOStream`` wrapping the connected stream.

        Raises:
            ConnectTimeout: the timeout elapsed.
            ConnectError: connecting or applying a socket option failed with
                ``OSError``, or the resource was broken.
        """
        if socket_options is None:
            socket_options = []
        exc_map = {
            TimeoutError: ConnectTimeout,
            OSError: ConnectError,
            anyio.BrokenResourceError: ConnectError,
        }
        with map_exceptions(exc_map):
            with anyio.fail_after(timeout):
                stream: anyio.abc.ByteStream = await anyio.connect_unix(path)
                await self._apply_socket_options(stream, socket_options)
        return AnyIOStream(stream)

    @staticmethod
    async def _apply_socket_options(
        stream: anyio.abc.ByteStream,
        socket_options: typing.Iterable[SOCKET_OPTION],
    ) -> None:
        """Apply each option via ``setsockopt`` on ``stream``'s raw socket.

        Uses the public ``SocketAttribute.raw_socket`` typed attribute rather
        than a private attribute of the stream. If any option (or iterating
        the options) fails, the freshly connected stream is closed before the
        error propagates, so the connection is not leaked.
        """
        try:
            options = list(socket_options)
            if options:
                raw_socket = stream.extra(anyio.abc.SocketAttribute.raw_socket)
                for option in options:
                    raw_socket.setsockopt(*option)
        except BaseException:
            await stream.aclose()
            raise

    async def sleep(self, seconds: float) -> None:
        """Suspend the current task for ``seconds`` using ``anyio.sleep``."""
        await anyio.sleep(seconds)  # pragma: nocover