1from __future__ import annotations 
    2 
    3import ssl 
    4import typing 
    5 
    6import anyio 
    7 
    8from .._exceptions import ( 
    9    ConnectError, 
    10    ConnectTimeout, 
    11    ReadError, 
    12    ReadTimeout, 
    13    WriteError, 
    14    WriteTimeout, 
    15    map_exceptions, 
    16) 
    17from .._utils import is_socket_readable 
    18from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream 
    19 
    20 
    21class AnyIOStream(AsyncNetworkStream): 
    22    def __init__(self, stream: anyio.abc.ByteStream) -> None: 
    23        self._stream = stream 
    24 
    25    async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: 
    26        exc_map = { 
    27            TimeoutError: ReadTimeout, 
    28            anyio.BrokenResourceError: ReadError, 
    29            anyio.ClosedResourceError: ReadError, 
    30            anyio.EndOfStream: ReadError, 
    31        } 
    32        with map_exceptions(exc_map): 
    33            with anyio.fail_after(timeout): 
    34                try: 
    35                    return await self._stream.receive(max_bytes=max_bytes) 
    36                except anyio.EndOfStream:  # pragma: nocover 
    37                    return b"" 
    38 
    39    async def write(self, buffer: bytes, timeout: float | None = None) -> None: 
    40        if not buffer: 
    41            return 
    42 
    43        exc_map = { 
    44            TimeoutError: WriteTimeout, 
    45            anyio.BrokenResourceError: WriteError, 
    46            anyio.ClosedResourceError: WriteError, 
    47        } 
    48        with map_exceptions(exc_map): 
    49            with anyio.fail_after(timeout): 
    50                await self._stream.send(item=buffer) 
    51 
    52    async def aclose(self) -> None: 
    53        await self._stream.aclose() 
    54 
    55    async def start_tls( 
    56        self, 
    57        ssl_context: ssl.SSLContext, 
    58        server_hostname: str | None = None, 
    59        timeout: float | None = None, 
    60    ) -> AsyncNetworkStream: 
    61        exc_map = { 
    62            TimeoutError: ConnectTimeout, 
    63            anyio.BrokenResourceError: ConnectError, 
    64            anyio.EndOfStream: ConnectError, 
    65            ssl.SSLError: ConnectError, 
    66        } 
    67        with map_exceptions(exc_map): 
    68            try: 
    69                with anyio.fail_after(timeout): 
    70                    ssl_stream = await anyio.streams.tls.TLSStream.wrap( 
    71                        self._stream, 
    72                        ssl_context=ssl_context, 
    73                        hostname=server_hostname, 
    74                        standard_compatible=False, 
    75                        server_side=False, 
    76                    ) 
    77            except Exception as exc:  # pragma: nocover 
    78                await self.aclose() 
    79                raise exc 
    80        return AnyIOStream(ssl_stream) 
    81 
    82    def get_extra_info(self, info: str) -> typing.Any: 
    83        if info == "ssl_object": 
    84            return self._stream.extra(anyio.streams.tls.TLSAttribute.ssl_object, None) 
    85        if info == "client_addr": 
    86            return self._stream.extra(anyio.abc.SocketAttribute.local_address, None) 
    87        if info == "server_addr": 
    88            return self._stream.extra(anyio.abc.SocketAttribute.remote_address, None) 
    89        if info == "socket": 
    90            return self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None) 
    91        if info == "is_readable": 
    92            sock = self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None) 
    93            return is_socket_readable(sock) 
    94        return None 
    95 
    96 
    97class AnyIOBackend(AsyncNetworkBackend): 
    98    async def connect_tcp( 
    99        self, 
    100        host: str, 
    101        port: int, 
    102        timeout: float | None = None, 
    103        local_address: str | None = None, 
    104        socket_options: typing.Iterable[SOCKET_OPTION] | None = None, 
    105    ) -> AsyncNetworkStream:  # pragma: nocover 
    106        if socket_options is None: 
    107            socket_options = [] 
    108        exc_map = { 
    109            TimeoutError: ConnectTimeout, 
    110            OSError: ConnectError, 
    111            anyio.BrokenResourceError: ConnectError, 
    112        } 
    113        with map_exceptions(exc_map): 
    114            with anyio.fail_after(timeout): 
    115                stream: anyio.abc.ByteStream = await anyio.connect_tcp( 
    116                    remote_host=host, 
    117                    remote_port=port, 
    118                    local_host=local_address, 
    119                ) 
    120                # By default TCP sockets opened in `asyncio` include TCP_NODELAY. 
    121                for option in socket_options: 
    122                    stream._raw_socket.setsockopt(*option)  # type: ignore[attr-defined] # pragma: no cover 
    123        return AnyIOStream(stream) 
    124 
    125    async def connect_unix_socket( 
    126        self, 
    127        path: str, 
    128        timeout: float | None = None, 
    129        socket_options: typing.Iterable[SOCKET_OPTION] | None = None, 
    130    ) -> AsyncNetworkStream:  # pragma: nocover 
    131        if socket_options is None: 
    132            socket_options = [] 
    133        exc_map = { 
    134            TimeoutError: ConnectTimeout, 
    135            OSError: ConnectError, 
    136            anyio.BrokenResourceError: ConnectError, 
    137        } 
    138        with map_exceptions(exc_map): 
    139            with anyio.fail_after(timeout): 
    140                stream: anyio.abc.ByteStream = await anyio.connect_unix(path) 
    141                for option in socket_options: 
    142                    stream._raw_socket.setsockopt(*option)  # type: ignore[attr-defined] # pragma: no cover 
    143        return AnyIOStream(stream) 
    144 
    145    async def sleep(self, seconds: float) -> None: 
    146        await anyio.sleep(seconds)  # pragma: nocover