1from __future__ import annotations
2
3import socket
4from abc import abstractmethod
5from collections.abc import Callable, Collection, Mapping
6from contextlib import AsyncExitStack
7from io import IOBase
8from ipaddress import IPv4Address, IPv6Address
9from socket import AddressFamily
10from types import TracebackType
11from typing import Any, Tuple, TypeVar, Union
12
13from .._core._typedattr import (
14 TypedAttributeProvider,
15 TypedAttributeSet,
16 typed_attribute,
17)
18from ._streams import ByteStream, Listener, UnreliableObjectStream
19from ._tasks import TaskGroup
20
#: An IP address given either as a string or as an :mod:`ipaddress` object
IPAddressType = Union[str, IPv4Address, IPv6Address]
#: An IP socket address as a ``(host, port)`` tuple
IPSockAddrType = Tuple[str, int]
#: A socket address: an IP ``(host, port)`` tuple, or a string (presumably a
#: UNIX socket path — NOTE(review): confirm against the concrete backends)
SockAddrType = Union[IPSockAddrType, str]
#: A UDP datagram together with its peer address: ``(payload, (host, port))``
UDPPacketType = Tuple[bytes, IPSockAddrType]
#: A UNIX datagram together with its peer socket path: ``(payload, path)``
UNIXDatagramPacketType = Tuple[bytes, str]
#: Type variable for a generic return value
T_Retval = TypeVar("T_Retval")
27
28
29class _NullAsyncContextManager:
30 async def __aenter__(self) -> None:
31 pass
32
33 async def __aexit__(
34 self,
35 exc_type: type[BaseException] | None,
36 exc_val: BaseException | None,
37 exc_tb: TracebackType | None,
38 ) -> bool | None:
39 return None
40
41
class SocketAttribute(TypedAttributeSet):
    """
    Typed attributes describing the socket behind a stream, listener or datagram
    socket.

    Not every attribute is available on every socket. The remote attributes only
    exist for sockets that report a peer. The port attributes only exist for IP
    (``AF_INET``/``AF_INET6``) sockets.
    """

    #: the address family of the underlying socket
    family: AddressFamily = typed_attribute()
    #: the local socket address the underlying socket is bound to
    local_address: SockAddrType = typed_attribute()
    #: for IP sockets, the local port the underlying socket is bound to
    local_port: int = typed_attribute()
    #: the underlying stdlib socket object
    raw_socket: socket.socket = typed_attribute()
    #: the remote address the underlying socket is connected to
    remote_address: SockAddrType = typed_attribute()
    #: for IP sockets, the remote port the underlying socket is connected to
    remote_port: int = typed_attribute()
55
56
class _SocketProvider(TypedAttributeProvider):
    """
    Mixin that derives the :class:`SocketAttribute` extra attributes from the raw
    socket a subclass exposes through :attr:`_raw_socket`.
    """

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """
        Return the socket-related typed attributes mapped to getter callables.

        ``family``, ``local_address`` and ``raw_socket`` are always present.
        ``remote_address`` is added only when a peer address can be obtained.
        For ``AF_INET``/``AF_INET6`` sockets, ``local_port`` is added too, plus
        ``remote_port`` when a peer is known.

        The peer address is captured once, when this mapping is built. The
        local address and port are re-read from the socket on every getter call.
        """
        from .._core._sockets import convert_ipv6_sockaddr

        def get_family() -> AddressFamily:
            return self._raw_socket.family

        def get_local_address() -> SockAddrType:
            return convert_ipv6_sockaddr(self._raw_socket.getsockname())

        def get_raw_socket() -> socket.socket:
            return self._raw_socket

        attributes: dict[Any, Callable[[], Any]] = {
            SocketAttribute.family: get_family,
            SocketAttribute.local_address: get_local_address,
            SocketAttribute.raw_socket: get_raw_socket,
        }

        # An OSError here (e.g. the socket is not connected) means there is no
        # peer address to report.
        # NOTE(review): the annotation only describes IP peers; other address
        # families return differently-shaped addresses from getpeername().
        try:
            peer: tuple[str, int] | None = convert_ipv6_sockaddr(
                self._raw_socket.getpeername()
            )
        except OSError:
            peer = None

        if peer is not None:

            def get_remote_address() -> Any:
                return peer

            attributes[SocketAttribute.remote_address] = get_remote_address

        # Port numbers only make sense for IP-based sockets
        if self._raw_socket.family not in (
            AddressFamily.AF_INET,
            AddressFamily.AF_INET6,
        ):
            return attributes

        def get_local_port() -> int:
            return self._raw_socket.getsockname()[1]

        attributes[SocketAttribute.local_port] = get_local_port
        if peer is not None:
            peer_port = peer[1]

            def get_remote_port() -> Any:
                return peer_port

            attributes[SocketAttribute.remote_port] = get_remote_port

        return attributes

    @property
    @abstractmethod
    def _raw_socket(self) -> socket.socket:
        """The underlying stdlib socket; must be provided by subclasses."""
93
94
class SocketStream(ByteStream, _SocketProvider):
    """
    Transports bytes over a socket.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    Concrete implementations supply the underlying socket via ``_raw_socket``.
    """
101
102
class UNIXSocketStream(SocketStream):
    """
    A :class:`SocketStream` that can also exchange file descriptors with its peer.
    """

    @abstractmethod
    async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
        """
        Send file descriptors along with a message to the peer.

        :param message: a non-empty bytestring
        :param fds: a collection of files (either numeric file descriptors or open
            file or socket objects)
        """

    @abstractmethod
    async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
        """
        Receive file descriptors along with a message from the peer.

        :param msglen: length of the message to expect from the peer
        :param maxfds: maximum number of file descriptors to expect from the peer
        :return: a tuple of (message, file descriptors)
        """
123
124
class SocketListener(Listener[SocketStream], _SocketProvider):
    """
    Listens to incoming socket connections.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    """

    @abstractmethod
    async def accept(self) -> SocketStream:
        """Accept an incoming connection."""

    async def serve(
        self,
        handler: Callable[[SocketStream], Any],
        task_group: TaskGroup | None = None,
    ) -> None:
        """
        Accept connections in a loop and spawn ``handler(stream)`` for each one.

        The loop never exits on its own. It ends only when ``accept()`` raises or
        the call is cancelled.

        :param handler: callable passed to ``start_soon`` with each accepted stream
        :param task_group: task group to spawn the handlers in; if omitted, a new
            task group is created and kept open only for the duration of this call
        """
        from .. import create_task_group

        async with AsyncExitStack() as exit_stack:
            group = task_group
            if group is None:
                # Tie the lifetime of the self-created task group to this call
                group = await exit_stack.enter_async_context(create_task_group())

            while True:
                accepted = await self.accept()
                group.start_soon(handler, accepted)
150
151
class UDPSocket(UnreliableObjectStream[UDPPacketType], _SocketProvider):
    """
    Represents an unconnected UDP socket.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    """

    async def sendto(self, data: bytes, host: str, port: int) -> None:
        """
        Alias for :meth:`~.UnreliableObjectSendStream.send` ((data, (host, port))).

        :param data: the datagram payload
        :param host: destination host
        :param port: destination port
        """
        destination = (host, port)
        packet = (data, destination)
        return await self.send(packet)
165
166
class ConnectedUDPSocket(UnreliableObjectStream[bytes], _SocketProvider):
    """
    Represents a connected UDP socket.

    Datagrams are sent and received as plain ``bytes``, without a peer address.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    """
173
174
class UNIXDatagramSocket(
    UnreliableObjectStream[UNIXDatagramPacketType], _SocketProvider
):
    """
    Represents an unconnected Unix datagram socket.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    """

    async def sendto(self, data: bytes, path: str) -> None:
        """
        Alias for :meth:`~.UnreliableObjectSendStream.send` ((data, path)).

        :param data: the datagram payload
        :param path: the destination socket path
        """
        packet = (data, path)
        return await self.send(packet)
187
188
class ConnectedUNIXDatagramSocket(UnreliableObjectStream[bytes], _SocketProvider):
    """
    Represents a connected Unix datagram socket.

    Datagrams are sent and received as plain ``bytes``, without a peer path.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    """