1from __future__ import annotations
2
3import errno
4import os
5import socket
6import ssl
7import stat
8import sys
9from collections.abc import Awaitable
10from dataclasses import dataclass
11from ipaddress import IPv4Address, IPv6Address, ip_address
12from os import PathLike, chmod
13from socket import AddressFamily, SocketKind
14from typing import TYPE_CHECKING, Any, Literal, cast, overload
15
16from .. import ConnectionFailed, to_thread
17from ..abc import (
18 ByteStreamConnectable,
19 ConnectedUDPSocket,
20 ConnectedUNIXDatagramSocket,
21 IPAddressType,
22 IPSockAddrType,
23 SocketListener,
24 SocketStream,
25 UDPSocket,
26 UNIXDatagramSocket,
27 UNIXSocketStream,
28)
29from ..streams.stapled import MultiListener
30from ..streams.tls import TLSConnectable, TLSStream
31from ._eventloop import get_async_backend
32from ._resources import aclose_forcefully
33from ._synchronization import Event
34from ._tasks import create_task_group, move_on_after
35
36if TYPE_CHECKING:
37 from _typeshed import FileDescriptorLike
38else:
39 FileDescriptorLike = object
40
41if sys.version_info < (3, 11):
42 from exceptiongroup import ExceptionGroup
43
44if sys.version_info >= (3, 12):
45 from typing import override
46else:
47 from typing_extensions import override
48
49if sys.version_info < (3, 13):
50 from typing_extensions import deprecated
51else:
52 from warnings import deprecated
53
# ``socket.IPPROTO_IPV6`` is not exposed on every platform/Python build (see the bug
# report linked below), so fall back to 41, the IANA protocol number for IPv6.
IPPROTO_IPV6 = getattr(socket, "IPPROTO_IPV6", 41)  # https://bugs.python.org/issue29515

# Type of ``family`` arguments that may also be AF_UNSPEC, meaning "not specified"
AnyIPAddressFamily = Literal[
    AddressFamily.AF_UNSPEC, AddressFamily.AF_INET, AddressFamily.AF_INET6
]
# Type of ``family`` values restricted to a concrete IP family (IPv4 or IPv6)
IPAddressFamily = Literal[AddressFamily.AF_INET, AddressFamily.AF_INET6]
60
61
# The overloads below only exist for type checkers: they select the return type of
# connect_tcp() from the TLS-related arguments. Passing tls_hostname, ssl_context or
# tls=True yields a TLSStream; tls=False or no TLS arguments yields a SocketStream.


# tls_hostname given
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = ...,
    ssl_context: ssl.SSLContext | None = ...,
    tls_standard_compatible: bool = ...,
    tls_hostname: str,
    happy_eyeballs_delay: float = ...,
) -> TLSStream: ...


# ssl_context given
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = ...,
    ssl_context: ssl.SSLContext,
    tls_standard_compatible: bool = ...,
    tls_hostname: str | None = ...,
    happy_eyeballs_delay: float = ...,
) -> TLSStream: ...


# tls=True
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = ...,
    tls: Literal[True],
    ssl_context: ssl.SSLContext | None = ...,
    tls_standard_compatible: bool = ...,
    tls_hostname: str | None = ...,
    happy_eyeballs_delay: float = ...,
) -> TLSStream: ...


# tls=False
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = ...,
    tls: Literal[False],
    ssl_context: ssl.SSLContext | None = ...,
    tls_standard_compatible: bool = ...,
    tls_hostname: str | None = ...,
    happy_eyeballs_delay: float = ...,
) -> SocketStream: ...


# No TLS arguments
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = ...,
    happy_eyeballs_delay: float = ...,
) -> SocketStream: ...
129
130
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = None,
    tls: bool = False,
    ssl_context: ssl.SSLContext | None = None,
    tls_standard_compatible: bool = True,
    tls_hostname: str | None = None,
    happy_eyeballs_delay: float = 0.25,
) -> SocketStream | TLSStream:
    """
    Connect to a host using the TCP protocol.

    This function implements the stateless version of the Happy Eyeballs algorithm (RFC
    6555). If ``remote_host`` is a host name that resolves to multiple IP addresses,
    each one is tried until one connection attempt succeeds. If the first attempt does
    not connect within ``happy_eyeballs_delay`` seconds (250 milliseconds by default),
    a second attempt is started using the next address in the list, and so on. An IPv6
    address (if one was resolved) is tried first, followed by an IPv4 address (if one
    was resolved); the remaining addresses are tried in the order the resolver
    returned them.

    When the connection has been established, a TLS handshake will be done if either
    ``ssl_context`` or ``tls_hostname`` is not ``None``, or if ``tls`` is ``True``.

    :param remote_host: the IP address or host name to connect to
    :param remote_port: port on the target host to connect to
    :param local_host: the interface address or name to bind the socket to before
        connecting
    :param tls: ``True`` to do a TLS handshake with the connected stream and return a
        :class:`~anyio.streams.tls.TLSStream` instead
    :param ssl_context: the SSL context object to use (if omitted, a default context is
        created)
    :param tls_standard_compatible: If ``True``, performs the TLS shutdown handshake
        before closing the stream and requires that the server does this as well.
        Otherwise, :exc:`~ssl.SSLEOFError` may be raised during reads from the stream.
        Some protocols, such as HTTP, require this option to be ``False``.
        See :meth:`~ssl.SSLContext.wrap_socket` for details.
    :param tls_hostname: host name to check the server certificate against (defaults to
        the value of ``remote_host``)
    :param happy_eyeballs_delay: delay (in seconds) before starting the next connection
        attempt
    :return: a socket stream object if no TLS handshake was done, otherwise a TLS stream
    :raises OSError: if name resolution fails, if it yields no usable address, or if
        every connection attempt fails (the individual failures are chained as the
        exception's ``__cause__``)

    """
    # Placed here due to https://github.com/python/mypy/issues/7057
    connected_stream: SocketStream | None = None

    async def try_connect(remote_host: str, event: Event) -> None:
        nonlocal connected_stream
        try:
            stream = await asynclib.connect_tcp(remote_host, remote_port, local_address)
        except OSError as exc:
            oserrors.append(exc)
            return
        else:
            if connected_stream is None:
                # First successful attempt wins; cancel all the other attempts
                connected_stream = stream
                tg.cancel_scope.cancel()
            else:
                # Another attempt already won the race, so discard this connection
                await stream.aclose()
        finally:
            # Wake up the launcher loop so the next attempt can start immediately
            event.set()

    asynclib = get_async_backend()
    local_address: IPSockAddrType | None = None
    family = socket.AF_UNSPEC
    if local_host:
        gai_res = await getaddrinfo(str(local_host), None)
        family, *_, local_address = gai_res[0]

    target_host = str(remote_host)
    try:
        addr_obj = ip_address(remote_host)
    except ValueError:
        addr_obj = None

    if addr_obj is not None:
        if isinstance(addr_obj, IPv6Address):
            target_addrs = [(socket.AF_INET6, addr_obj.compressed)]
        else:
            target_addrs = [(socket.AF_INET, addr_obj.compressed)]
    else:
        # getaddrinfo() will raise an exception if name resolution fails
        gai_res = await getaddrinfo(
            target_host, remote_port, family=family, type=socket.SOCK_STREAM
        )

        # Organize the list so that the first address is an IPv6 address (if available)
        # and the second one is an IPv4 address (if available). All other addresses
        # keep the relative order in which the resolver returned them.
        resolved = [(af, sa[0]) for af, *_, sa in gai_res]
        promoted: list[int] = []
        v6_index = next(
            (i for i, (af, _) in enumerate(resolved) if af == socket.AF_INET6), None
        )
        if v6_index is not None:
            promoted.append(v6_index)
            v4_index = next(
                (i for i, (af, _) in enumerate(resolved) if af == socket.AF_INET), None
            )
            if v4_index is not None:
                promoted.append(v4_index)

        target_addrs = [resolved[i] for i in promoted]
        target_addrs.extend(
            addr for i, addr in enumerate(resolved) if i not in promoted
        )

    if not target_addrs:
        # getaddrinfo() may filter out every result (e.g. IPv6-only results when IPv6
        # support is unavailable); ExceptionGroup() would reject an empty error list
        raise OSError(f"No usable addresses found for {target_host!r}")

    oserrors: list[OSError] = []
    try:
        async with create_task_group() as tg:
            for _af, addr in target_addrs:
                event = Event()
                tg.start_soon(try_connect, addr, event)
                with move_on_after(happy_eyeballs_delay):
                    await event.wait()

        if connected_stream is None:
            cause = (
                oserrors[0]
                if len(oserrors) == 1
                else ExceptionGroup("multiple connection attempts failed", oserrors)
            )
            raise OSError("All connection attempts failed") from cause
    finally:
        # Drop references to the collected exceptions to avoid reference cycles
        oserrors.clear()

    if tls or tls_hostname or ssl_context:
        try:
            return await TLSStream.wrap(
                connected_stream,
                server_side=False,
                hostname=tls_hostname or str(remote_host),
                ssl_context=ssl_context,
                standard_compatible=tls_standard_compatible,
            )
        except BaseException:
            await aclose_forcefully(connected_stream)
            raise

    return connected_stream
266
267
async def connect_unix(path: str | bytes | PathLike[Any]) -> UNIXSocketStream:
    """
    Open a stream connection to a UNIX domain socket.

    Not available on Windows.

    :param path: file system path of the socket (``str``, ``bytes`` or path-like)
    :return: a socket stream object
    :raises ConnectionFailed: if the connection fails

    """
    # Normalize path-like objects to str/bytes before handing them to the backend
    fs_path = os.fspath(path)
    backend = get_async_backend()
    return await backend.connect_unix(fs_path)
281
282
async def create_tcp_listener(
    *,
    local_host: IPAddressType | None = None,
    local_port: int = 0,
    family: AnyIPAddressFamily = socket.AddressFamily.AF_UNSPEC,
    backlog: int = 65536,
    reuse_port: bool = False,
) -> MultiListener[SocketStream]:
    """
    Create a TCP socket listener.

    One listening socket is created for every stream address that ``local_host``
    resolves to. If ``local_port`` is 0, the ephemeral port assigned to the first
    socket is reused for the remaining ones so that all listeners share one port. If
    that port turns out to be in use on another interface, the whole attempt is
    retried with the address list rotated (at most once per address).

    :param local_port: port number to listen on (0 to pick an ephemeral port)
    :param local_host: IP address of the interface to listen on. If omitted, listen on
        all IPv4 and IPv6 interfaces. To listen on all interfaces on a specific address
        family, use ``0.0.0.0`` for IPv4 or ``::`` for IPv6.
    :param family: address family (used if ``local_host`` was omitted)
    :param backlog: maximum number of queued incoming connections (up to a maximum of
        2**16, or 65536)
    :param reuse_port: ``True`` to allow multiple sockets to bind to the same
        address/port (not supported on Windows)
    :return: a multi-listener object containing one or more socket listeners
    :raises OSError: if there's an error creating a socket, or binding to one or more
        interfaces failed

    """
    asynclib = get_async_backend()
    backlog = min(backlog, 65536)
    local_host = str(local_host) if local_host is not None else None

    def setup_raw_socket(
        fam: AddressFamily,
        bind_addr: tuple[str, int] | tuple[str, int, int, int],
        *,
        v6only: bool = True,
    ) -> socket.socket:
        # Create, configure, bind and start listening on a non-blocking socket,
        # closing it again if any of those steps fail
        sock = socket.socket(fam)
        try:
            sock.setblocking(False)

            if fam == AddressFamily.AF_INET6:
                sock.setsockopt(IPPROTO_IPV6, socket.IPV6_V6ONLY, v6only)

            # For Windows, enable exclusive address use. For others, enable address
            # reuse.
            if sys.platform == "win32":
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1)
            else:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

            if reuse_port:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)

            # Workaround for #554
            if fam == socket.AF_INET6 and "%" in bind_addr[0]:
                addr, scope_id = bind_addr[0].split("%", 1)
                bind_addr = (addr, bind_addr[1], 0, int(scope_id))

            sock.bind(bind_addr)
            sock.listen(backlog)
        except BaseException:
            sock.close()
            raise

        return sock

    def wrap_listener(raw_socket: socket.socket) -> SocketListener:
        # Hand the listening socket over to the backend, closing it if that fails so
        # the file descriptor does not leak
        try:
            return asynclib.create_tcp_listener(raw_socket)
        except BaseException:
            raw_socket.close()
            raise

    # We pass type=0 on non-Windows platforms as a workaround for a uvloop bug
    # where we don't get the correct scope ID for IPv6 link-local addresses when passing
    # type=socket.SOCK_STREAM to getaddrinfo():
    # https://github.com/MagicStack/uvloop/issues/539
    gai_res = await getaddrinfo(
        local_host,
        local_port,
        family=family,
        type=socket.SOCK_STREAM if sys.platform == "win32" else 0,
        flags=socket.AI_PASSIVE | socket.AI_ADDRCONFIG,
    )

    # The set comprehension is here to work around a glibc bug:
    # https://sourceware.org/bugzilla/show_bug.cgi?id=14969
    sockaddrs = sorted({res for res in gai_res if res[1] == SocketKind.SOCK_STREAM})

    # Special case for dual-stack binding on the "any" interface
    if (
        local_host is None
        and family == AddressFamily.AF_UNSPEC
        and socket.has_dualstack_ipv6()
        and any(fam == AddressFamily.AF_INET6 for fam, *_ in gai_res)
    ):
        raw_socket = setup_raw_socket(
            AddressFamily.AF_INET6, ("::", local_port), v6only=False
        )
        return MultiListener([wrap_listener(raw_socket)])

    errors: list[OSError] = []
    try:
        for _ in range(len(sockaddrs)):
            listeners: list[SocketListener] = []
            bound_ephemeral_port = local_port
            try:
                for fam, *_, sockaddr in sockaddrs:
                    sockaddr = sockaddr[0], bound_ephemeral_port, *sockaddr[2:]
                    raw_socket = setup_raw_socket(fam, sockaddr)

                    # Store the assigned port if an ephemeral port was requested, so
                    # we'll bind to the same port on all interfaces
                    if local_port == 0 and len(gai_res) > 1:
                        bound_ephemeral_port = raw_socket.getsockname()[1]

                    listeners.append(wrap_listener(raw_socket))
            except BaseException as exc:
                for listener in listeners:
                    await listener.aclose()

                # If an ephemeral port was requested but binding the assigned port
                # failed for another interface, rotate the address list and try again
                if (
                    isinstance(exc, OSError)
                    and exc.errno == errno.EADDRINUSE
                    and local_port == 0
                    and bound_ephemeral_port
                ):
                    errors.append(exc)
                    sockaddrs.append(sockaddrs.pop(0))
                    continue

                raise

            return MultiListener(listeners)

        raise OSError(
            f"Could not create {len(sockaddrs)} listeners with a consistent port"
        ) from ExceptionGroup("Several bind attempts failed", errors)
    finally:
        del errors  # Prevent reference cycles
418
419
async def create_unix_listener(
    path: str | bytes | PathLike[Any],
    *,
    mode: int | None = None,
    backlog: int = 65536,
) -> SocketListener:
    """
    Create a listener bound to a UNIX domain stream socket.

    Not available on Windows.

    :param path: file system path to bind the socket to
    :param mode: permissions to apply to the socket file (left untouched if ``None``)
    :param backlog: maximum number of queued incoming connections (capped at 2**16,
        i.e. 65536)
    :return: a listener object

    .. versionchanged:: 3.0
       If a socket already exists on the file system in the given path, it will be
       removed first.

    """
    capped_backlog = min(backlog, 65536)
    listening_socket = await setup_unix_local_socket(path, mode, socket.SOCK_STREAM)
    try:
        listening_socket.listen(capped_backlog)
        return get_async_backend().create_unix_listener(listening_socket)
    except BaseException:
        # Don't leak the bound socket if listening or wrapping it fails
        listening_socket.close()
        raise
450
451
async def create_udp_socket(
    family: AnyIPAddressFamily = AddressFamily.AF_UNSPEC,
    *,
    local_host: IPAddressType | None = None,
    local_port: int = 0,
    reuse_port: bool = False,
) -> UDPSocket:
    """
    Create a UDP socket.

    If ``local_port`` has been given, the socket will be bound to this port on the
    local machine, making this socket suitable for providing UDP based services. When
    ``local_host`` is omitted, the socket is bound to the wildcard address of the given
    ``family`` (``0.0.0.0`` or ``::``).

    :param family: address family (``AF_INET`` or ``AF_INET6``) – automatically
        determined from ``local_host`` if omitted
    :param local_host: IP address or host name of the local interface to bind to
    :param local_port: local port to bind to
    :param reuse_port: ``True`` to allow multiple sockets to bind to the same
        address/port (not supported on Windows)
    :return: a UDP socket
    :raises ValueError: if neither ``family`` nor ``local_host`` was given

    """
    # Compare by value so that plain integers equal to the enum members also work
    if family == AddressFamily.AF_UNSPEC and not local_host:
        raise ValueError('Either "family" or "local_host" must be given')

    if local_host:
        gai_res = await getaddrinfo(
            str(local_host),
            local_port,
            family=family,
            type=socket.SOCK_DGRAM,
            flags=socket.AI_PASSIVE | socket.AI_ADDRCONFIG,
        )
        family = cast(AnyIPAddressFamily, gai_res[0][0])
        local_address = gai_res[0][-1]
    elif family == AddressFamily.AF_INET6:
        # No interface given: bind to the IPv6 wildcard address on the requested port
        local_address = ("::", local_port)
    else:
        # No interface given: bind to the IPv4 wildcard address on the requested port
        local_address = ("0.0.0.0", local_port)

    sock = await get_async_backend().create_udp_socket(
        family, local_address, None, reuse_port
    )
    return cast(UDPSocket, sock)
496
497
async def create_connected_udp_socket(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    family: AnyIPAddressFamily = AddressFamily.AF_UNSPEC,
    local_host: IPAddressType | None = None,
    local_port: int = 0,
    reuse_port: bool = False,
) -> ConnectedUDPSocket:
    """
    Create a connected UDP socket.

    Connected UDP sockets can only communicate with the specified remote host/port, and
    any packets sent from other sources are dropped.

    If ``local_port`` is given without ``local_host``, the socket is bound to that port
    on the wildcard address (``0.0.0.0`` or ``::``) of the remote address's family.

    :param remote_host: remote host to set as the default target
    :param remote_port: port on the remote host to set as the default target
    :param family: address family (``AF_INET`` or ``AF_INET6``) – automatically
        determined from ``local_host`` or ``remote_host`` if omitted
    :param local_host: IP address or host name of the local interface to bind to
    :param local_port: local port to bind to
    :param reuse_port: ``True`` to allow multiple sockets to bind to the same
        address/port (not supported on Windows)
    :return: a connected UDP socket

    """
    local_address: IPSockAddrType | None = None
    if local_host:
        gai_res = await getaddrinfo(
            str(local_host),
            local_port,
            family=family,
            type=socket.SOCK_DGRAM,
            flags=socket.AI_PASSIVE | socket.AI_ADDRCONFIG,
        )
        # The local address dictates the family used to resolve the remote host
        family = cast(AnyIPAddressFamily, gai_res[0][0])
        local_address = gai_res[0][-1]

    gai_res = await getaddrinfo(
        str(remote_host), remote_port, family=family, type=socket.SOCK_DGRAM
    )
    family = cast(AnyIPAddressFamily, gai_res[0][0])
    remote_address = gai_res[0][-1]

    if local_address is None and local_port:
        # A local port was requested without an interface: honor it by binding to
        # the wildcard address of the remote address's family
        wildcard = "::" if family == AddressFamily.AF_INET6 else "0.0.0.0"
        local_address = (wildcard, local_port)

    sock = await get_async_backend().create_udp_socket(
        family, local_address, remote_address, reuse_port
    )
    return cast(ConnectedUDPSocket, sock)
546
547
async def create_unix_datagram_socket(
    *,
    local_path: None | str | bytes | PathLike[Any] = None,
    local_mode: int | None = None,
) -> UNIXDatagramSocket:
    """
    Create a UNIX datagram socket.

    Not available on Windows.

    If ``local_path`` has been given, the socket will be bound to this path, making this
    socket suitable for receiving datagrams from other processes. Other processes can
    send datagrams to this socket only if ``local_path`` is set.

    If a socket already exists on the file system in the ``local_path``, it will be
    removed first.

    :param local_path: the path on which to bind to
    :param local_mode: permissions to set on the local socket
    :return: a UNIX datagram socket

    """
    raw_socket = await setup_unix_local_socket(
        local_path, local_mode, socket.SOCK_DGRAM
    )
    try:
        return await get_async_backend().create_unix_datagram_socket(raw_socket, None)
    except BaseException:
        # Nothing else owns the raw socket yet, so close it rather than leak it
        raw_socket.close()
        raise
574
575
async def create_connected_unix_datagram_socket(
    remote_path: str | bytes | PathLike[Any],
    *,
    local_path: None | str | bytes | PathLike[Any] = None,
    local_mode: int | None = None,
) -> ConnectedUNIXDatagramSocket:
    """
    Create a connected UNIX datagram socket.

    Connected datagram sockets can only communicate with the specified remote path.

    If ``local_path`` has been given, the socket will be bound to this path, making
    this socket suitable for receiving datagrams from other processes. Other processes
    can send datagrams to this socket only if ``local_path`` is set.

    If a socket already exists on the file system in the ``local_path``, it will be
    removed first.

    :param remote_path: the path to set as the default target
    :param local_path: the path on which to bind to
    :param local_mode: permissions to set on the local socket
    :return: a connected UNIX datagram socket

    """
    remote_path = os.fspath(remote_path)
    raw_socket = await setup_unix_local_socket(
        local_path, local_mode, socket.SOCK_DGRAM
    )
    try:
        return await get_async_backend().create_unix_datagram_socket(
            raw_socket, remote_path
        )
    except BaseException:
        # Nothing else owns the raw socket yet, so close it rather than leak it
        raw_socket.close()
        raise
607
608
async def getaddrinfo(
    host: bytes | str | None,
    port: str | int | None,
    *,
    family: int | AddressFamily = 0,
    type: int | SocketKind = 0,
    proto: int = 0,
    flags: int = 0,
) -> list[tuple[AddressFamily, SocketKind, int, str, tuple[str, int]]]:
    """
    Look up a numeric IP address given a host name.

    Internationalized domain names are translated according to the (non-transitional)
    IDNA 2008 standard.

    .. note:: 4-tuple IPv6 socket addresses are automatically converted to 2-tuples of
        (host, port), unlike what :func:`socket.getaddrinfo` does. A nonzero scope ID
        is kept by appending it to the host as ``%<scope_id>``.

    :param host: host name (``str``, already encoded ``bytes``, or ``None``)
    :param port: port number (as an integer or a string), or ``None``
    :param family: socket family (``AF_INET``, ...)
    :param type: socket type (``SOCK_STREAM``, ...)
    :param proto: protocol number
    :param flags: flags to pass to upstream ``getaddrinfo()``
    :return: list of tuples containing (family, type, proto, canonname, sockaddr)

    .. seealso:: :func:`socket.getaddrinfo`

    """
    # Handle unicode hostnames: plain ASCII names are encoded directly, anything else
    # goes through IDNA encoding
    if isinstance(host, str):
        try:
            encoded_host: bytes | None = host.encode("ascii")
        except UnicodeEncodeError:
            import idna

            encoded_host = idna.encode(host, uts46=True)
    else:
        encoded_host = host

    gai_res = await get_async_backend().getaddrinfo(
        encoded_host, port, family=family, type=type, proto=proto, flags=flags
    )
    # Distinct names avoid shadowing the family/type/proto parameters above
    return [
        (res_family, res_type, res_proto, canonname, convert_ipv6_sockaddr(sockaddr))
        for res_family, res_type, res_proto, canonname, sockaddr in gai_res
        # filter out IPv6 results when IPv6 is disabled
        if not isinstance(sockaddr[0], int)
    ]
658
659
def getnameinfo(sockaddr: IPSockAddrType, flags: int = 0) -> Awaitable[tuple[str, str]]:
    """
    Resolve an IP socket address back to a host name and service name.

    :param sockaddr: the socket address to look up (e.g. ``(ipaddress, port)`` for
        IPv4)
    :param flags: flags forwarded to the upstream ``getnameinfo()``
    :return: an awaitable producing a ``(host name, service name)`` tuple

    .. seealso:: :func:`socket.getnameinfo`

    """
    backend = get_async_backend()
    return backend.getnameinfo(sockaddr, flags)
672
673
@deprecated("This function is deprecated; use `wait_readable` instead")
def wait_socket_readable(sock: socket.socket) -> Awaitable[None]:
    """
    .. deprecated:: 4.7.0
       Use :func:`wait_readable` instead.

    Block until there is data available to read from the given socket.

    .. warning:: Only use this on raw sockets that have not been wrapped by any higher
        level constructs like socket streams!

    :param sock: a socket object
    :raises ~anyio.ClosedResourceError: if the socket was closed while waiting for the
        socket to become readable
    :raises ~anyio.BusyResourceError: if another task is already waiting for the socket
        to become readable

    """
    backend = get_async_backend()
    # The backend works on file descriptors, not socket objects
    return backend.wait_readable(sock.fileno())
693
694
@deprecated("This function is deprecated; use `wait_writable` instead")
def wait_socket_writable(sock: socket.socket) -> Awaitable[None]:
    """
    .. deprecated:: 4.7.0
       Use :func:`wait_writable` instead.

    Block until the given socket is ready to accept more data.

    This does **NOT** work on Windows when using the asyncio backend with a proactor
    event loop (default on py3.8+).

    .. warning:: Only use this on raw sockets that have not been wrapped by any higher
        level constructs like socket streams!

    :param sock: a socket object
    :raises ~anyio.ClosedResourceError: if the socket was closed while waiting for the
        socket to become writable
    :raises ~anyio.BusyResourceError: if another task is already waiting for the socket
        to become writable

    """
    backend = get_async_backend()
    # The backend works on file descriptors, not socket objects
    return backend.wait_writable(sock.fileno())
717
718
def wait_readable(obj: FileDescriptorLike) -> Awaitable[None]:
    """
    Block until the given object has data available to be read.

    On Unix systems, ``obj`` must either be an integer file descriptor, or else an
    object with a ``.fileno()`` method which returns an integer file descriptor. Any
    kind of file descriptor can be passed, though the exact semantics will depend on
    your kernel. For example, this probably won't do anything useful for on-disk files.

    On Windows systems, ``obj`` must either be an integer ``SOCKET`` handle, or else an
    object with a ``.fileno()`` method which returns an integer ``SOCKET`` handle. File
    descriptors aren't supported, and neither are handles that refer to anything besides
    a ``SOCKET``.

    On backends where this functionality is not natively provided (asyncio
    ``ProactorEventLoop`` on Windows), it is provided using a separate selector thread
    which is set to shut down when the interpreter shuts down.

    .. warning:: Don't use this on raw sockets that have been wrapped by any higher
        level constructs like socket streams!

    :param obj: an object with a ``.fileno()`` method or an integer handle
    :raises ~anyio.ClosedResourceError: if the object was closed while waiting for the
        object to become readable
    :raises ~anyio.BusyResourceError: if another task is already waiting for the object
        to become readable

    """
    backend = get_async_backend()
    return backend.wait_readable(obj)
748
749
def wait_writable(obj: FileDescriptorLike) -> Awaitable[None]:
    """
    Block until the given object is ready to be written to.

    :param obj: an object with a ``.fileno()`` method or an integer handle
    :raises ~anyio.ClosedResourceError: if the object was closed while waiting for the
        object to become writable
    :raises ~anyio.BusyResourceError: if another task is already waiting for the object
        to become writable

    .. seealso:: See the documentation of :func:`wait_readable` for the definition of
        ``obj`` and notes on backend compatibility.

    .. warning:: Don't use this on raw sockets that have been wrapped by any higher
        level constructs like socket streams!

    """
    backend = get_async_backend()
    return backend.wait_writable(obj)
768
769
def notify_closing(obj: FileDescriptorLike) -> None:
    """
    Call this before closing a file descriptor (on Unix) or socket (on
    Windows). This will cause any `wait_readable` or `wait_writable`
    calls on the given object to immediately wake up and raise
    `~anyio.ClosedResourceError`.

    This doesn't actually close the object – you still have to do that
    yourself afterwards. Also, you want to be careful to make sure no
    new tasks start waiting on the object in between when you call this
    and when it's actually closed. So to close something properly, you
    usually want to do these steps in order:

    1. Explicitly mark the object as closed, so that any new attempts
       to use it will abort before they start.
    2. Call `notify_closing` to wake up any already-existing users.
    3. Actually close the object.

    It's also possible to do them in a different order if that's more
    convenient, *but only if* you make sure not to have any checkpoints in
    between the steps. This way they all happen in a single atomic
    step, so other tasks won't be able to tell what order they happened
    in anyway.

    :param obj: an object with a ``.fileno()`` method or an integer handle

    """
    backend = get_async_backend()
    backend.notify_closing(obj)
798
799
800#
801# Private API
802#
803
804
def convert_ipv6_sockaddr(
    sockaddr: tuple[str, int, int, int] | tuple[str, int],
) -> tuple[str, int]:
    """
    Collapse a 4-tuple IPv6 socket address into a ``(address, port)`` 2-tuple.

    A nonzero scope ID is preserved by appending it to the address after a ``%``;
    otherwise the flow info and scope ID are simply dropped. Socket addresses that are
    not 4-tuples are returned unchanged.

    :param sockaddr: the result of :meth:`~socket.socket.getsockname`
    :return: the converted socket address

    """
    # Anything that isn't an IPv6 4-tuple passes through as-is
    if not (isinstance(sockaddr, tuple) and len(sockaddr) == 4):
        return sockaddr

    address, port_number, _flowinfo, scope = sockaddr
    if not scope:
        return address, port_number

    # PyPy (as of v7.3.11) leaves the interface name in the address, so drop it and
    # rebuild the suffix from the numeric scope ID
    # (https://foss.heptapod.net/pypy/pypy/-/issues/3938)
    bare_address = address.partition("%")[0]
    return f"{bare_address}%{scope}", port_number
834
835
async def setup_unix_local_socket(
    path: None | str | bytes | PathLike[Any],
    mode: int | None,
    socktype: int,
) -> socket.socket:
    """
    Create a non-blocking UNIX domain socket, optionally bound to ``path``.

    If ``path`` refers to an existing socket file, that file is removed before binding.

    Not available on Windows.

    :param path: path of the socket, or ``None`` to leave the socket unbound
    :param mode: permissions to set on the socket file after binding (if not ``None``)
    :param socktype: socket.SOCK_STREAM or socket.SOCK_DGRAM
    :return: the new socket object

    """
    bind_target: str | None = None
    if path is not None:
        bind_target = os.fsdecode(path)

        # Linux abstract namespace sockets (leading NUL byte) aren't backed by a
        # concrete file, so there is nothing to stat or remove
        if not bind_target.startswith("\0"):
            # Errors tolerated here mirror the ones pathlib ignores
            try:
                existing = os.stat(path)
            except OSError as exc:
                if exc.errno not in (
                    errno.ENOENT,
                    errno.ENOTDIR,
                    errno.EBADF,
                    errno.ELOOP,
                ):
                    raise
            else:
                # Only remove stale sockets, never regular files or directories
                if stat.S_ISSOCK(existing.st_mode):
                    os.unlink(path)

    sock = socket.socket(socket.AF_UNIX, socktype)
    sock.setblocking(False)
    if bind_target is None:
        return sock

    try:
        await to_thread.run_sync(sock.bind, bind_target, abandon_on_cancel=True)
        if mode is not None:
            await to_thread.run_sync(chmod, bind_target, mode, abandon_on_cancel=True)
    except BaseException:
        sock.close()
        raise

    return sock
888
889
@dataclass
class TCPConnectable(ByteStreamConnectable):
    """
    Byte stream connectable that opens a TCP connection to a fixed host and port.

    :param host: host name or IP address of the server
    :param port: TCP port number of the server (1–65535)
    """

    host: str | IPv4Address | IPv6Address
    port: int

    def __post_init__(self) -> None:
        # Reject port numbers outside the valid TCP range up front
        if self.port > 65535 or self.port < 1:
            raise ValueError("TCP port number out of range")

    @override
    async def connect(self) -> SocketStream:
        try:
            stream = await connect_tcp(self.host, self.port)
        except OSError as exc:
            # Re-raise as ConnectionFailed, keeping the original error as the cause
            raise ConnectionFailed(
                f"error connecting to {self.host}:{self.port}: {exc}"
            ) from exc

        return stream
914
915
@dataclass
class UNIXConnectable(ByteStreamConnectable):
    """
    Byte stream connectable that opens a connection to a UNIX domain socket.

    :param path: the file system path of the socket
    """

    path: str | bytes | PathLike[str] | PathLike[bytes]

    @override
    async def connect(self) -> UNIXSocketStream:
        try:
            stream = await connect_unix(self.path)
        except OSError as exc:
            # Re-raise as ConnectionFailed, keeping the original error as the cause
            raise ConnectionFailed(f"error connecting to {self.path!r}: {exc}") from exc

        return stream
932
933
def as_connectable(
    remote: ByteStreamConnectable
    | tuple[str | IPv4Address | IPv6Address, int]
    | str
    | bytes
    | PathLike[str]
    | PathLike[bytes],
    /,
    *,
    tls: bool = False,
    ssl_context: ssl.SSLContext | None = None,
    tls_hostname: str | None = None,
    tls_standard_compatible: bool = True,
) -> ByteStreamConnectable:
    """
    Return a byte stream connectable from the given object.

    If a bytestream connectable is given, it is returned unchanged (the TLS options are
    ignored in that case).
    If a tuple of (host, port) is given, a TCP connectable is returned.
    If a string, bytes or path-like path is given, a UNIX connectable is returned.

    If ``tls=True``, the connectable will be wrapped in a
    :class:`~.streams.tls.TLSConnectable`.

    :param remote: a connectable, a tuple of (host, port) or a path to a UNIX socket
    :param tls: if ``True``, wrap the plaintext connectable in a
        :class:`~.streams.tls.TLSConnectable`, using the provided TLS settings
    :param ssl_context: if ``tls=True``, the SSLContext object to use (if not provided,
        a secure default will be created)
    :param tls_hostname: if ``tls=True``, host name of the server to use for checking
        the server certificate (defaults to the host portion of the address for TCP
        connectables)
    :param tls_standard_compatible: if ``False`` and ``tls=True``, makes the TLS stream
        skip the closing handshake when closing the connection, so it won't raise an
        exception if the server does the same
    :return: a byte stream connectable
    :raises TypeError: if ``remote`` cannot be converted to a connectable

    """
    connectable: TCPConnectable | UNIXConnectable | TLSConnectable
    if isinstance(remote, ByteStreamConnectable):
        return remote
    elif isinstance(remote, tuple) and len(remote) == 2:
        connectable = TCPConnectable(*remote)
    elif isinstance(remote, (str, bytes, PathLike)):
        connectable = UNIXConnectable(remote)
    else:
        raise TypeError(f"cannot convert {remote!r} to a connectable")

    if tls:
        # Default the certificate host name to the TCP host when none was given
        if not tls_hostname and isinstance(connectable, TCPConnectable):
            tls_hostname = str(connectable.host)

        connectable = TLSConnectable(
            connectable,
            ssl_context=ssl_context,
            hostname=tls_hostname,
            standard_compatible=tls_standard_compatible,
        )

    return connectable