1from __future__ import annotations
2
3import errno
4import socket
5import sys
6from abc import abstractmethod
7from collections.abc import Callable, Collection, Mapping
8from contextlib import AsyncExitStack
9from io import IOBase
10from ipaddress import IPv4Address, IPv6Address
11from socket import AddressFamily
12from typing import Any, TypeVar, Union
13
14from .._core._eventloop import get_async_backend
15from .._core._typedattr import (
16 TypedAttributeProvider,
17 TypedAttributeSet,
18 typed_attribute,
19)
20from ._streams import ByteStream, Listener, UnreliableObjectStream
21from ._tasks import TaskGroup
22
23if sys.version_info >= (3, 10):
24 from typing import TypeAlias
25else:
26 from typing_extensions import TypeAlias
27
# An IP address, given either as a string or as an ``ipaddress`` object
IPAddressType: TypeAlias = Union[str, IPv4Address, IPv6Address]
# An IP socket address as a (host, port) pair
IPSockAddrType: TypeAlias = tuple[str, int]
# Any socket address: an IP (host, port) pair, or a string for non-IP sockets
# (e.g. a UNIX socket path)
SockAddrType: TypeAlias = Union[IPSockAddrType, str]
# A UDP datagram together with its peer address: (payload, (host, port))
UDPPacketType: TypeAlias = tuple[bytes, IPSockAddrType]
# A UNIX datagram together with its peer path: (payload, path)
UNIXDatagramPacketType: TypeAlias = tuple[bytes, str]
# Generic return-value type variable; not used in this part of the module
# (NOTE(review): presumably used or re-exported elsewhere — confirm)
T_Retval = TypeVar("T_Retval")
34
35
36def _validate_socket(
37 sock_or_fd: socket.socket | int,
38 sock_type: socket.SocketKind,
39 addr_family: socket.AddressFamily = socket.AF_UNSPEC,
40 *,
41 require_connected: bool = False,
42 require_bound: bool = False,
43) -> socket.socket:
44 if isinstance(sock_or_fd, int):
45 try:
46 sock = socket.socket(fileno=sock_or_fd)
47 except OSError as exc:
48 if exc.errno == errno.ENOTSOCK:
49 raise ValueError(
50 "the file descriptor does not refer to a socket"
51 ) from exc
52 elif require_connected:
53 raise ValueError("the socket must be connected") from exc
54 elif require_bound:
55 raise ValueError("the socket must be bound to a local address") from exc
56 else:
57 raise
58 elif isinstance(sock_or_fd, socket.socket):
59 sock = sock_or_fd
60 else:
61 raise TypeError(
62 f"expected an int or socket, got {type(sock_or_fd).__qualname__} instead"
63 )
64
65 try:
66 if require_connected:
67 try:
68 sock.getpeername()
69 except OSError as exc:
70 raise ValueError("the socket must be connected") from exc
71
72 if require_bound:
73 try:
74 if sock.family in (socket.AF_INET, socket.AF_INET6):
75 bound_addr = sock.getsockname()[1]
76 else:
77 bound_addr = sock.getsockname()
78 except OSError:
79 bound_addr = None
80
81 if not bound_addr:
82 raise ValueError("the socket must be bound to a local address")
83
84 if addr_family != socket.AF_UNSPEC and sock.family != addr_family:
85 raise ValueError(
86 f"address family mismatch: expected {addr_family.name}, got "
87 f"{sock.family.name}"
88 )
89
90 if sock.type != sock_type:
91 raise ValueError(
92 f"socket type mismatch: expected {sock_type.name}, got {sock.type.name}"
93 )
94 except BaseException:
95 # Avoid ResourceWarning from the locally constructed socket object
96 if isinstance(sock_or_fd, int):
97 sock.detach()
98
99 raise
100
101 sock.setblocking(False)
102 return sock
103
104
class SocketAttribute(TypedAttributeSet):
    """
    Typed attributes that socket-based objects in this module may provide.

    .. attribute:: family
        :type: socket.AddressFamily

        the address family of the underlying socket

    .. attribute:: local_address
        :type: tuple[str, int] | str

        the local address the underlying socket is bound to

    .. attribute:: local_port
        :type: int

        for IP based sockets, the local port the underlying socket is bound to

    .. attribute:: raw_socket
        :type: socket.socket

        the underlying stdlib socket object

    .. attribute:: remote_address
        :type: tuple[str, int] | str

        the remote address the underlying socket is connected to

    .. attribute:: remote_port
        :type: int

        for IP based sockets, the remote port the underlying socket is connected to
    """

    family: AddressFamily = typed_attribute()
    local_address: SockAddrType = typed_attribute()
    local_port: int = typed_attribute()
    raw_socket: socket.socket = typed_attribute()
    remote_address: SockAddrType = typed_attribute()
    remote_port: int = typed_attribute()
144
145
class _SocketProvider(TypedAttributeProvider):
    """
    Mixin exposing :class:`SocketAttribute` extra attributes for any class that
    provides its underlying stdlib socket through :attr:`_raw_socket`.
    """

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """
        Build the mapping of available socket attributes to their getters.

        ``family``, ``local_address`` and ``raw_socket`` are always present, and
        ``local_port`` is added for ``AF_INET``/``AF_INET6`` sockets. The peer
        address is looked up once, when this property is read: ``remote_address``
        (and, for IP sockets, ``remote_port``) are only included if that lookup
        succeeds, and their getters return that captured value.
        """
        from .._core._sockets import convert_ipv6_sockaddr as convert

        attributes: dict[Any, Callable[[], Any]] = {
            SocketAttribute.family: lambda: self._raw_socket.family,
            # Looked up lazily, each time the attribute is requested
            SocketAttribute.local_address: lambda: convert(
                self._raw_socket.getsockname()
            ),
            SocketAttribute.raw_socket: lambda: self._raw_socket,
        }
        try:
            # Non-IP (e.g. UNIX) sockets have string addresses, so this is a
            # SockAddrType rather than a (host, port) tuple
            # NOTE(review): assumes convert() passes non-IP addresses through
            # unchanged — confirm in _core._sockets
            peername: SockAddrType | None = convert(self._raw_socket.getpeername())
        except OSError:
            # No peer address available (e.g. the socket is not connected)
            peername = None

        # Provide the remote address for connected sockets
        if peername is not None:
            attributes[SocketAttribute.remote_address] = lambda: peername

        # Provide local and remote ports for IP based sockets
        if self._raw_socket.family in (AddressFamily.AF_INET, AddressFamily.AF_INET6):
            attributes[SocketAttribute.local_port] = (
                lambda: self._raw_socket.getsockname()[1]
            )
            if peername is not None:
                remote_port = peername[1]
                attributes[SocketAttribute.remote_port] = lambda: remote_port

        return attributes

    @property
    @abstractmethod
    def _raw_socket(self) -> socket.socket:
        """The stdlib socket object wrapped by the implementing class."""
182
183
class SocketStream(ByteStream, _SocketProvider):
    """
    A byte stream whose data travels over a socket.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    """

    @classmethod
    async def from_socket(cls, sock_or_fd: socket.socket | int) -> SocketStream:
        """
        Wrap an existing socket object or file descriptor as a socket stream.

        The socket must be a stream socket that is already connected. Ownership of
        it passes to the returned stream.

        :param sock_or_fd: a socket object or file descriptor
        :return: a socket stream

        """
        connected_sock = _validate_socket(
            sock_or_fd, sock_type=socket.SOCK_STREAM, require_connected=True
        )
        backend = get_async_backend()
        return await backend.wrap_stream_socket(connected_sock)
205
206
class UNIXSocketStream(SocketStream):
    """
    A socket stream over a UNIX domain socket, which additionally supports
    sending and receiving file descriptors.
    """

    @classmethod
    async def from_socket(cls, sock_or_fd: socket.socket | int) -> UNIXSocketStream:
        """
        Wrap an existing socket object or file descriptor as a UNIX socket stream.

        The socket must be an ``AF_UNIX`` stream socket that is already connected.
        Ownership of it passes to the returned stream.

        :param sock_or_fd: a socket object or file descriptor
        :return: a UNIX socket stream

        """
        unix_sock = _validate_socket(
            sock_or_fd,
            socket.SOCK_STREAM,
            addr_family=socket.AF_UNIX,
            require_connected=True,
        )
        backend = get_async_backend()
        return await backend.wrap_unix_stream_socket(unix_sock)

    @abstractmethod
    async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
        """
        Send a message to the peer together with a set of file descriptors.

        :param message: a non-empty bytestring
        :param fds: the files to send: numeric file descriptors and/or open file
            or socket objects
        """

    @abstractmethod
    async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
        """
        Receive a message from the peer together with any file descriptors sent
        along with it.

        :param msglen: length of the message to expect from the peer
        :param maxfds: the most file descriptors to expect from the peer
        :return: a ``(message, file descriptors)`` tuple
        """
244
245
class SocketListener(Listener[SocketStream], _SocketProvider):
    """
    Accepts incoming connections on a listening socket.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    """

    @classmethod
    async def from_socket(
        cls,
        sock_or_fd: socket.socket | int,
    ) -> SocketListener:
        """
        Wrap an existing socket object or file descriptor as a socket listener.

        The socket must be a stream socket bound to a local address. Ownership of
        it passes to the returned listener.

        :param sock_or_fd: a socket object or file descriptor
        :return: a socket listener

        """
        listening_sock = _validate_socket(
            sock_or_fd, sock_type=socket.SOCK_STREAM, require_bound=True
        )
        backend = get_async_backend()
        return await backend.wrap_listener_socket(listening_sock)

    @abstractmethod
    async def accept(self) -> SocketStream:
        """Accept the next incoming connection and return it as a socket stream."""

    async def serve(
        self,
        handler: Callable[[SocketStream], Any],
        task_group: TaskGroup | None = None,
    ) -> None:
        """
        Accept connections in a loop, starting ``handler`` for each in a new task.

        This never returns normally; it only exits when :meth:`accept` (or starting
        a handler task) raises.

        :param handler: passed to the task group's ``start_soon`` together with
            each accepted stream
        :param task_group: the task group to start the handler tasks in; if
            omitted, a new task group is created and kept open for the duration of
            this call
        """
        from .. import create_task_group

        async with AsyncExitStack() as exit_stack:
            group = task_group
            if group is None:
                group = await exit_stack.enter_async_context(create_task_group())

            while True:
                client = await self.accept()
                group.start_soon(handler, client)
288
289
class UDPSocket(UnreliableObjectStream[UDPPacketType], _SocketProvider):
    """
    An unconnected UDP socket that exchanges ``(data, (host, port))`` packets.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    """

    @classmethod
    async def from_socket(cls, sock_or_fd: socket.socket | int) -> UDPSocket:
        """
        Wrap an existing socket object or file descriptor as a UDP socket.

        The socket must be a datagram socket bound to a local address. Ownership of
        it passes to the returned wrapper.

        :param sock_or_fd: a socket object or file descriptor
        :return: a UDP socket

        """
        bound_sock = _validate_socket(
            sock_or_fd, sock_type=socket.SOCK_DGRAM, require_bound=True
        )
        backend = get_async_backend()
        return await backend.wrap_udp_socket(bound_sock)

    async def sendto(self, data: bytes, host: str, port: int) -> None:
        """
        Shortcut for :meth:`~.UnreliableObjectSendStream.send` ((data, (host, port))).

        """
        packet: UDPPacketType = (data, (host, port))
        return await self.send(packet)
318
319
class ConnectedUDPSocket(UnreliableObjectStream[bytes], _SocketProvider):
    """
    A connected UDP socket that exchanges raw ``bytes`` payloads with its peer.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    """

    @classmethod
    async def from_socket(cls, sock_or_fd: socket.socket | int) -> ConnectedUDPSocket:
        """
        Wrap an existing socket object or file descriptor as a connected UDP socket.

        The socket must be a datagram socket that is already connected. Ownership
        of it passes to the returned wrapper.

        :param sock_or_fd: a socket object or file descriptor
        :return: a connected UDP socket

        """
        connected_sock = _validate_socket(
            sock_or_fd, sock_type=socket.SOCK_DGRAM, require_connected=True
        )
        backend = get_async_backend()
        return await backend.wrap_connected_udp_socket(connected_sock)
345
346
class UNIXDatagramSocket(
    UnreliableObjectStream[UNIXDatagramPacketType], _SocketProvider
):
    """
    An unconnected UNIX datagram socket that exchanges ``(data, path)`` packets.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    """

    @classmethod
    async def from_socket(
        cls,
        sock_or_fd: socket.socket | int,
    ) -> UNIXDatagramSocket:
        """
        Wrap an existing socket object or file descriptor as a UNIX datagram
        socket.

        The socket must be an ``AF_UNIX`` datagram socket. Ownership of it passes
        to the returned wrapper.

        :param sock_or_fd: a socket object or file descriptor
        :return: a UNIX datagram socket

        """
        unix_sock = _validate_socket(
            sock_or_fd, sock_type=socket.SOCK_DGRAM, addr_family=socket.AF_UNIX
        )
        backend = get_async_backend()
        return await backend.wrap_unix_datagram_socket(unix_sock)

    async def sendto(self, data: bytes, path: str) -> None:
        """Shortcut for :meth:`~.UnreliableObjectSendStream.send` ((data, path))."""
        packet: UNIXDatagramPacketType = (data, path)
        return await self.send(packet)
377
378
class ConnectedUNIXDatagramSocket(UnreliableObjectStream[bytes], _SocketProvider):
    """
    A connected UNIX datagram socket that exchanges raw ``bytes`` payloads with
    its peer.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    """

    @classmethod
    async def from_socket(
        cls,
        sock_or_fd: socket.socket | int,
    ) -> ConnectedUNIXDatagramSocket:
        """
        Wrap an existing socket object or file descriptor as a connected UNIX datagram
        socket.

        The socket must be an ``AF_UNIX`` datagram socket that is already connected.
        Ownership of it passes to the returned wrapper.

        :param sock_or_fd: a socket object or file descriptor
        :return: a connected UNIX datagram socket

        """
        connected_sock = _validate_socket(
            sock_or_fd,
            socket.SOCK_DGRAM,
            addr_family=socket.AF_UNIX,
            require_connected=True,
        )
        backend = get_async_backend()
        return await backend.wrap_connected_unix_datagram_socket(connected_sock)