1from __future__ import annotations
2
3import ssl
4import typing
5
6import anyio
7
8from .._exceptions import (
9 ConnectError,
10 ConnectTimeout,
11 ReadError,
12 ReadTimeout,
13 WriteError,
14 WriteTimeout,
15 map_exceptions,
16)
17from .._utils import is_socket_readable
18from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
19
20
class AnyIOStream(AsyncNetworkStream):
    """An ``AsyncNetworkStream`` that wraps an ``anyio`` byte stream.

    Each I/O method enforces an optional per-call timeout with
    ``anyio.fail_after``. It also translates ``anyio`` / ``ssl`` exceptions
    into this package's exceptions via ``map_exceptions``.
    """

    def __init__(self, stream: anyio.abc.ByteStream) -> None:
        self._stream = stream

    async def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
        """Receive up to ``max_bytes`` bytes from the underlying stream.

        Returns ``b""`` when ``anyio`` signals end-of-stream.

        Raises:
            ReadTimeout: nothing was received within ``timeout`` seconds.
            ReadError: the stream is broken or has already been closed.
        """
        exc_map = {
            TimeoutError: ReadTimeout,
            anyio.BrokenResourceError: ReadError,
            anyio.ClosedResourceError: ReadError,
            anyio.EndOfStream: ReadError,
        }
        with map_exceptions(exc_map):
            with anyio.fail_after(timeout):
                try:
                    return await self._stream.receive(max_bytes=max_bytes)
                except anyio.EndOfStream:  # pragma: nocover
                    # End-of-stream is reported as an empty read, not an error.
                    return b""

    async def write(self, buffer: bytes, timeout: float | None = None) -> None:
        """Send all of ``buffer``; an empty buffer is a no-op.

        Raises:
            WriteTimeout: the send did not complete within ``timeout`` seconds.
            WriteError: the stream is broken or has already been closed.
        """
        if not buffer:
            return

        exc_map = {
            TimeoutError: WriteTimeout,
            anyio.BrokenResourceError: WriteError,
            anyio.ClosedResourceError: WriteError,
        }
        with map_exceptions(exc_map):
            with anyio.fail_after(timeout):
                await self._stream.send(item=buffer)

    async def aclose(self) -> None:
        """Close the underlying ``anyio`` stream."""
        await self._stream.aclose()

    async def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: str | None = None,
        timeout: float | None = None,
    ) -> AsyncNetworkStream:
        """Perform a client-side TLS handshake over this stream.

        Returns a new ``AnyIOStream`` that wraps the resulting ``TLSStream``.
        If the handshake fails for any reason, this (plain) stream is closed
        before the error propagates.

        Raises:
            ConnectTimeout: the handshake did not finish within ``timeout``.
            ConnectError: the handshake failed (SSL error, broken/closed
                stream, or premature end-of-stream).
        """
        exc_map = {
            TimeoutError: ConnectTimeout,
            anyio.BrokenResourceError: ConnectError,
            # Mapped here too (as in read/write) so a handshake on an
            # already-closed stream surfaces as ConnectError, not a raw anyio error.
            anyio.ClosedResourceError: ConnectError,
            anyio.EndOfStream: ConnectError,
            ssl.SSLError: ConnectError,
        }
        with map_exceptions(exc_map):
            try:
                with anyio.fail_after(timeout):
                    ssl_stream = await anyio.streams.tls.TLSStream.wrap(
                        self._stream,
                        ssl_context=ssl_context,
                        hostname=server_hostname,
                        standard_compatible=False,
                        server_side=False,
                    )
            except Exception:  # pragma: nocover
                # Don't leave the transport open after a failed handshake.
                await self.aclose()
                raise
        return AnyIOStream(ssl_stream)

    def get_extra_info(self, info: str) -> typing.Any:
        """Look up transport details by name; unknown names return ``None``.

        Supported keys: ``"ssl_object"``, ``"client_addr"`` (local address),
        ``"server_addr"`` (remote address), ``"socket"`` (raw socket) and
        ``"is_readable"`` (result of ``is_socket_readable`` on the raw socket).
        """
        if info == "ssl_object":
            return self._stream.extra(anyio.streams.tls.TLSAttribute.ssl_object, None)
        if info == "client_addr":
            return self._stream.extra(anyio.abc.SocketAttribute.local_address, None)
        if info == "server_addr":
            return self._stream.extra(anyio.abc.SocketAttribute.remote_address, None)
        if info == "socket":
            return self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None)
        if info == "is_readable":
            sock = self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None)
            return is_socket_readable(sock)
        return None
95
96
class AnyIOBackend(AsyncNetworkBackend):
    """Network backend that opens TCP / UNIX-socket connections via ``anyio``.

    Connection failures are translated with ``map_exceptions``:
    ``TimeoutError`` becomes ``ConnectTimeout``, while ``OSError`` and
    ``anyio.BrokenResourceError`` become ``ConnectError``.
    """

    @staticmethod
    async def _apply_socket_options(
        stream: anyio.abc.ByteStream,
        socket_options: typing.Iterable[SOCKET_OPTION],
    ) -> None:
        """Call ``setsockopt(*option)`` on the stream's raw socket for each option.

        If applying any option fails, the newly opened stream is closed before
        the exception is re-raised, so a failed connect does not leak a socket.
        """
        try:
            for option in socket_options:
                # Public typed-attribute API instead of the private `_raw_socket`.
                raw_socket = stream.extra(anyio.abc.SocketAttribute.raw_socket)
                raw_socket.setsockopt(*option)  # pragma: no cover
        except Exception:
            await stream.aclose()
            raise

    async def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: float | None = None,
        local_address: str | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> AsyncNetworkStream:  # pragma: nocover
        """Open a TCP connection to ``host:port`` and wrap it in ``AnyIOStream``.

        ``local_address`` is passed to ``anyio.connect_tcp`` as ``local_host``.
        ``timeout`` bounds only the connection attempt.
        """
        if socket_options is None:
            socket_options = []
        exc_map = {
            TimeoutError: ConnectTimeout,
            OSError: ConnectError,
            anyio.BrokenResourceError: ConnectError,
        }
        with map_exceptions(exc_map):
            with anyio.fail_after(timeout):
                stream: anyio.abc.ByteStream = await anyio.connect_tcp(
                    remote_host=host,
                    remote_port=port,
                    local_host=local_address,
                )
            # By default TCP sockets opened in `asyncio` include TCP_NODELAY.
            await self._apply_socket_options(stream, socket_options)
        return AnyIOStream(stream)

    async def connect_unix_socket(
        self,
        path: str,
        timeout: float | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> AsyncNetworkStream:  # pragma: nocover
        """Connect to the UNIX domain socket at ``path`` and wrap it in ``AnyIOStream``.

        ``timeout`` bounds only the connection attempt.
        """
        if socket_options is None:
            socket_options = []
        exc_map = {
            TimeoutError: ConnectTimeout,
            OSError: ConnectError,
            anyio.BrokenResourceError: ConnectError,
        }
        with map_exceptions(exc_map):
            with anyio.fail_after(timeout):
                stream: anyio.abc.ByteStream = await anyio.connect_unix(path)
            await self._apply_socket_options(stream, socket_options)
        return AnyIOStream(stream)

    async def sleep(self, seconds: float) -> None:
        """Suspend the current task for ``seconds`` via ``anyio.sleep``."""
        await anyio.sleep(seconds)  # pragma: nocover