1from __future__ import annotations
2
3import errno
4import os
5import socket
6import ssl
7import stat
8import sys
9from collections.abc import Awaitable
10from ipaddress import IPv6Address, ip_address
11from os import PathLike, chmod
12from socket import AddressFamily, SocketKind
13from typing import Any, Literal, cast, overload
14
15from .. import to_thread
16from ..abc import (
17 ConnectedUDPSocket,
18 ConnectedUNIXDatagramSocket,
19 IPAddressType,
20 IPSockAddrType,
21 SocketListener,
22 SocketStream,
23 UDPSocket,
24 UNIXDatagramSocket,
25 UNIXSocketStream,
26)
27from ..streams.stapled import MultiListener
28from ..streams.tls import TLSStream
29from ._eventloop import get_async_backend
30from ._resources import aclose_forcefully
31from ._synchronization import Event
32from ._tasks import create_task_group, move_on_after
33
34if sys.version_info < (3, 11):
35 from exceptiongroup import ExceptionGroup
36
# The socket module does not expose IPPROTO_IPV6 on every platform
# (https://bugs.python.org/issue29515), so fall back to the numeric value 41
try:
    IPPROTO_IPV6 = socket.IPPROTO_IPV6
except AttributeError:
    IPPROTO_IPV6 = 41

# Any IP address family, including AF_UNSPEC (no specific family requested)
AnyIPAddressFamily = Literal[
    AddressFamily.AF_UNSPEC,
    AddressFamily.AF_INET,
    AddressFamily.AF_INET6,
]
# A concrete IP address family (IPv4 or IPv6)
IPAddressFamily = Literal[AddressFamily.AF_INET, AddressFamily.AF_INET6]
43
44
# Typing overloads for connect_tcp(). They select the return type from the TLS-related
# keyword arguments: calls that pass tls_hostname, pass ssl_context or set tls=True are
# typed as returning a TLSStream; tls=False or no TLS arguments at all give a plain
# SocketStream. Keep the declaration order stable, since type checkers try overloads
# in the order they are declared.
# NOTE(review): at runtime (see the implementation below) TLS is also performed when
# tls=False is combined with ssl_context or tls_hostname, yet such calls match the
# tls=False overload and are typed as SocketStream -- confirm this is intended.


# tls_hostname given (as a required keyword) -> TLS stream
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = ...,
    ssl_context: ssl.SSLContext | None = ...,
    tls_standard_compatible: bool = ...,
    tls_hostname: str,
    happy_eyeballs_delay: float = ...,
) -> TLSStream: ...


# ssl_context given (as a required keyword) -> TLS stream
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = ...,
    ssl_context: ssl.SSLContext,
    tls_standard_compatible: bool = ...,
    tls_hostname: str | None = ...,
    happy_eyeballs_delay: float = ...,
) -> TLSStream: ...


# tls=True passed explicitly -> TLS stream
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = ...,
    tls: Literal[True],
    ssl_context: ssl.SSLContext | None = ...,
    tls_standard_compatible: bool = ...,
    tls_hostname: str | None = ...,
    happy_eyeballs_delay: float = ...,
) -> TLSStream: ...


# tls=False passed explicitly -> plain socket stream
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = ...,
    tls: Literal[False],
    ssl_context: ssl.SSLContext | None = ...,
    tls_standard_compatible: bool = ...,
    tls_hostname: str | None = ...,
    happy_eyeballs_delay: float = ...,
) -> SocketStream: ...


# No TLS arguments at all -> plain socket stream
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = ...,
    happy_eyeballs_delay: float = ...,
) -> SocketStream: ...
112
113
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = None,
    tls: bool = False,
    ssl_context: ssl.SSLContext | None = None,
    tls_standard_compatible: bool = True,
    tls_hostname: str | None = None,
    happy_eyeballs_delay: float = 0.25,
) -> SocketStream | TLSStream:
    """
    Connect to a host using the TCP protocol.

    This function implements the stateless version of the Happy Eyeballs algorithm (RFC
    6555). If ``remote_host`` is a host name that resolves to multiple IP addresses,
    each one is tried until one connection attempt succeeds. If the first attempt does
    not connect within ``happy_eyeballs_delay`` seconds (250 milliseconds by default),
    a second attempt is started using the next address in the list, and so on. On IPv6
    enabled systems, an IPv6 address (if available) is tried first.

    When the connection has been established, a TLS handshake will be done if either
    ``ssl_context`` or ``tls_hostname`` is not ``None``, or if ``tls`` is ``True``.

    :param remote_host: the IP address or host name to connect to
    :param remote_port: port on the target host to connect to
    :param local_host: the interface address or name to bind the socket to before
        connecting
    :param tls: ``True`` to do a TLS handshake with the connected stream and return a
        :class:`~anyio.streams.tls.TLSStream` instead
    :param ssl_context: the SSL context object to use (if omitted, a default context is
        created)
    :param tls_standard_compatible: If ``True``, performs the TLS shutdown handshake
        before closing the stream and requires that the server does this as well.
        Otherwise, :exc:`~ssl.SSLEOFError` may be raised during reads from the stream.
        Some protocols, such as HTTP, require this option to be ``False``.
        See :meth:`~ssl.SSLContext.wrap_socket` for details.
    :param tls_hostname: host name to check the server certificate against (defaults to
        the value of ``remote_host``)
    :param happy_eyeballs_delay: delay (in seconds) before starting the next connection
        attempt
    :return: a socket stream object if no TLS handshake was done, otherwise a TLS stream
    :raises OSError: if the connection attempt fails (the underlying error, or an
        exception group of all of them, is attached as ``__cause__``)

    """
    # Placed here due to https://github.com/python/mypy/issues/7057
    connected_stream: SocketStream | None = None

    async def try_connect(remote_host: str, event: Event) -> None:
        nonlocal connected_stream
        try:
            stream = await asynclib.connect_tcp(remote_host, remote_port, local_address)
        except OSError as exc:
            oserrors.append(exc)
            return
        else:
            if connected_stream is None:
                # First successful attempt wins: cancel all the other attempts
                connected_stream = stream
                tg.cancel_scope.cancel()
            else:
                # Another attempt already won the race; discard this connection
                await stream.aclose()
        finally:
            # Wake up the main loop so it can start the next attempt right away
            event.set()

    asynclib = get_async_backend()
    local_address: IPSockAddrType | None = None
    family = socket.AF_UNSPEC
    if local_host:
        # Bind to the first resolved local address and restrict the remote lookup to
        # the same address family
        gai_res = await getaddrinfo(str(local_host), None)
        family, *_, local_address = gai_res[0]

    target_host = str(remote_host)
    try:
        addr_obj = ip_address(remote_host)
    except ValueError:
        # Not an IP address, so resolve it as a host name.
        # getaddrinfo() will raise an exception if name resolution fails
        gai_res = await getaddrinfo(
            target_host, remote_port, family=family, type=socket.SOCK_STREAM
        )

        # Organize the list so that the first address is an IPv6 address (if available)
        # and the second one is an IPv4 address. The rest can be in whatever order.
        v6_found = v4_found = False
        target_addrs: list[tuple[socket.AddressFamily, str]] = []
        for af, *rest, sa in gai_res:
            if af == socket.AF_INET6 and not v6_found:
                v6_found = True
                target_addrs.insert(0, (af, sa[0]))
            elif af == socket.AF_INET and not v4_found and v6_found:
                v4_found = True
                target_addrs.insert(1, (af, sa[0]))
            else:
                target_addrs.append((af, sa[0]))
    else:
        if isinstance(addr_obj, IPv6Address):
            target_addrs = [(socket.AF_INET6, addr_obj.compressed)]
        else:
            target_addrs = [(socket.AF_INET, addr_obj.compressed)]

    oserrors: list[OSError] = []
    async with create_task_group() as tg:
        for _af, addr in target_addrs:
            event = Event()
            tg.start_soon(try_connect, addr, event)
            # Start the next attempt when this one finishes or the delay elapses,
            # whichever comes first
            with move_on_after(happy_eyeballs_delay):
                await event.wait()

    # The collected exceptions' tracebacks reference the try_connect() frames, whose
    # closure cell refers back to the oserrors list. Deleting the (cell) variable
    # breaks that reference cycle once the cause has been built.
    if connected_stream is None:
        cause: BaseException | None
        if not oserrors:
            # Only possible if there were no target addresses to try; ExceptionGroup
            # would reject an empty list, so raise without a cause
            cause = None
        elif len(oserrors) == 1:
            cause = oserrors[0]
        else:
            cause = ExceptionGroup("multiple connection attempts failed", oserrors)

        del oserrors
        raise OSError("All connection attempts failed") from cause

    del oserrors

    if tls or tls_hostname or ssl_context:
        try:
            return await TLSStream.wrap(
                connected_stream,
                server_side=False,
                hostname=tls_hostname or str(remote_host),
                ssl_context=ssl_context,
                standard_compatible=tls_standard_compatible,
            )
        except BaseException:
            await aclose_forcefully(connected_stream)
            raise

    return connected_stream
243
244
async def connect_unix(path: str | bytes | PathLike[Any]) -> UNIXSocketStream:
    """
    Open a stream connection to a UNIX domain socket.

    Not available on Windows.

    :param path: file system path of the socket to connect to
    :return: a UNIX socket stream connected to ``path``

    """
    # Normalize path-like objects before touching the backend
    socket_path = os.fspath(path)
    backend = get_async_backend()
    return await backend.connect_unix(socket_path)
257
258
async def create_tcp_listener(
    *,
    local_host: IPAddressType | None = None,
    local_port: int = 0,
    family: AnyIPAddressFamily = socket.AddressFamily.AF_UNSPEC,
    backlog: int = 65536,
    reuse_port: bool = False,
) -> MultiListener[SocketStream]:
    """
    Create a TCP socket listener.

    :param local_port: port number to listen on
    :param local_host: IP address of the interface to listen on. If omitted, listen on
        all IPv4 and IPv6 interfaces. To listen on all interfaces on a specific address
        family, use ``0.0.0.0`` for IPv4 or ``::`` for IPv6.
    :param family: address family (used if ``local_host`` was omitted)
    :param backlog: maximum number of queued incoming connections (up to a maximum of
        2**16, or 65536)
    :param reuse_port: ``True`` to allow multiple sockets to bind to the same
        address/port (not supported on Windows)
    :return: a :class:`~anyio.streams.stapled.MultiListener` combining one listener per
        bound address

    """
    asynclib = get_async_backend()
    backlog = min(backlog, 65536)
    local_host = str(local_host) if local_host is not None else None
    gai_res = await getaddrinfo(
        local_host,
        local_port,
        family=family,
        type=socket.SocketKind.SOCK_STREAM if sys.platform == "win32" else 0,
        flags=socket.AI_PASSIVE | socket.AI_ADDRCONFIG,
    )
    listeners: list[SocketListener] = []
    try:
        # The set() is here to work around a glibc bug:
        # https://sourceware.org/bugzilla/show_bug.cgi?id=14969
        sockaddr: tuple[str, int] | tuple[str, int, int, int]
        for fam, kind, *_, sockaddr in sorted(set(gai_res)):
            # Workaround for an uvloop bug where we don't get the correct scope ID for
            # IPv6 link-local addresses when passing type=socket.SOCK_STREAM to
            # getaddrinfo(): https://github.com/MagicStack/uvloop/issues/539
            if sys.platform != "win32" and kind is not SocketKind.SOCK_STREAM:
                continue

            raw_socket = socket.socket(fam)
            try:
                raw_socket.setblocking(False)

                # For Windows, enable exclusive address use. For others, enable
                # address reuse.
                if sys.platform == "win32":
                    raw_socket.setsockopt(
                        socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1
                    )
                else:
                    raw_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

                if reuse_port:
                    raw_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)

                if fam == socket.AF_INET6:
                    # Always make IPv6 sockets IPv6-only: IPv4 addresses returned by
                    # getaddrinfo() get their own socket, and a dual-stack IPv6
                    # socket could conflict with it on the same port
                    raw_socket.setsockopt(IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)

                    # Workaround for #554: convert "addr%scope" back into a 4-tuple
                    # sockaddr carrying the numeric scope ID
                    if "%" in sockaddr[0]:
                        addr, scope_id = sockaddr[0].split("%", 1)
                        sockaddr = (addr, sockaddr[1], 0, int(scope_id))

                raw_socket.bind(sockaddr)
                raw_socket.listen(backlog)
                listener = asynclib.create_tcp_listener(raw_socket)
            except BaseException:
                # The socket is not owned by any listener yet, so close it here
                # instead of leaking it (e.g. when bind() fails with EADDRINUSE)
                raw_socket.close()
                raise

            listeners.append(listener)
    except BaseException:
        # Tear down every listener created before the failure
        for listener in listeners:
            await listener.aclose()

        raise

    return MultiListener(listeners)
337
338
async def create_unix_listener(
    path: str | bytes | PathLike[Any],
    *,
    mode: int | None = None,
    backlog: int = 65536,
) -> SocketListener:
    """
    Start listening for stream connections on a UNIX domain socket.

    Not available on Windows.

    :param path: file system path at which to create the socket
    :param mode: file permissions to apply to the socket (left unchanged if ``None``)
    :param backlog: maximum number of queued incoming connections (capped at
        2**16, or 65536)
    :return: a listener object

    .. versionchanged:: 3.0
        If a socket already exists on the file system in the given path, it will be
        removed first.

    """
    max_backlog = 65536
    backlog = min(backlog, max_backlog)
    sock = await setup_unix_local_socket(path, mode, socket.SOCK_STREAM)
    try:
        sock.listen(backlog)
        backend = get_async_backend()
        listener = backend.create_unix_listener(sock)
    except BaseException:
        # Nothing owns the socket yet, so release it before propagating
        sock.close()
        raise

    return listener
369
370
async def create_udp_socket(
    family: AnyIPAddressFamily = AddressFamily.AF_UNSPEC,
    *,
    local_host: IPAddressType | None = None,
    local_port: int = 0,
    reuse_port: bool = False,
) -> UDPSocket:
    """
    Create a UDP socket.

    If ``local_port`` has been given, the socket will be bound to this port on the local
    machine, making this socket suitable for providing UDP based services.

    :param family: address family (``AF_INET`` or ``AF_INET6``) – automatically
        determined from ``local_host`` if omitted
    :param local_host: IP address or host name of the local interface to bind to (if
        omitted, the socket is bound to all interfaces of the given ``family``)
    :param local_port: local port to bind to
    :param reuse_port: ``True`` to allow multiple sockets to bind to the same
        address/port (not supported on Windows)
    :return: a UDP socket
    :raises ValueError: if neither ``family`` nor ``local_host`` was given

    """
    if family == AddressFamily.AF_UNSPEC and not local_host:
        raise ValueError('Either "family" or "local_host" must be given')

    if local_host:
        gai_res = await getaddrinfo(
            str(local_host),
            local_port,
            family=family,
            type=socket.SOCK_DGRAM,
            flags=socket.AI_PASSIVE | socket.AI_ADDRCONFIG,
        )
        family = cast(AnyIPAddressFamily, gai_res[0][0])
        local_address = gai_res[0][-1]
    elif family == AddressFamily.AF_INET6:
        # Bind to all IPv6 interfaces; honor the requested port (previously this
        # always used port 0, silently ignoring local_port)
        local_address = ("::", local_port)
    else:
        # Bind to all IPv4 interfaces on the requested port
        local_address = ("0.0.0.0", local_port)

    sock = await get_async_backend().create_udp_socket(
        family, local_address, None, reuse_port
    )
    return cast(UDPSocket, sock)
415
416
async def create_connected_udp_socket(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    family: AnyIPAddressFamily = AddressFamily.AF_UNSPEC,
    local_host: IPAddressType | None = None,
    local_port: int = 0,
    reuse_port: bool = False,
) -> ConnectedUDPSocket:
    """
    Create a connected UDP socket.

    Connected UDP sockets can only communicate with the specified remote host/port, and
    any packets sent from other sources are dropped.

    :param remote_host: remote host to set as the default target
    :param remote_port: port on the remote host to set as the default target
    :param family: address family (``AF_INET`` or ``AF_INET6``) – automatically
        determined from ``local_host`` or ``remote_host`` if omitted
    :param local_host: IP address or host name of the local interface to bind to
    :param local_port: local port to bind to (if given without ``local_host``, the
        socket is bound to all interfaces of the resolved address family)
    :param reuse_port: ``True`` to allow multiple sockets to bind to the same
        address/port (not supported on Windows)
    :return: a connected UDP socket

    """
    local_address: IPSockAddrType | None = None
    if local_host:
        gai_res = await getaddrinfo(
            str(local_host),
            local_port,
            family=family,
            type=socket.SOCK_DGRAM,
            flags=socket.AI_PASSIVE | socket.AI_ADDRCONFIG,
        )
        family = cast(AnyIPAddressFamily, gai_res[0][0])
        local_address = gai_res[0][-1]

    # The remote lookup is restricted to the local address's family when one was given
    gai_res = await getaddrinfo(
        str(remote_host), remote_port, family=family, type=socket.SOCK_DGRAM
    )
    family = cast(AnyIPAddressFamily, gai_res[0][0])
    remote_address = gai_res[0][-1]

    if local_address is None and local_port:
        # A local port was requested without a local host: bind to the wildcard
        # address of the family resolved for the remote end instead of silently
        # ignoring local_port
        wildcard = "::" if family == AddressFamily.AF_INET6 else "0.0.0.0"
        local_address = (wildcard, local_port)

    sock = await get_async_backend().create_udp_socket(
        family, local_address, remote_address, reuse_port
    )
    return cast(ConnectedUDPSocket, sock)
465
466
async def create_unix_datagram_socket(
    *,
    local_path: None | str | bytes | PathLike[Any] = None,
    local_mode: int | None = None,
) -> UNIXDatagramSocket:
    """
    Create a UNIX datagram socket.

    Not available on Windows.

    If ``local_path`` has been given, the socket will be bound to this path, making this
    socket suitable for receiving datagrams from other processes. Other processes can
    send datagrams to this socket only if ``local_path`` is set.

    If a socket already exists on the file system in the ``local_path``, it will be
    removed first.

    :param local_path: the path on which to bind to
    :param local_mode: permissions to set on the local socket
    :return: a UNIX datagram socket

    """
    raw_socket = await setup_unix_local_socket(
        local_path, local_mode, socket.SOCK_DGRAM
    )
    try:
        return await get_async_backend().create_unix_datagram_socket(raw_socket, None)
    except BaseException:
        # Don't leak the socket if the backend fails (or we are cancelled) before it
        # takes ownership of it; mirrors create_unix_listener()
        raw_socket.close()
        raise
493
494
async def create_connected_unix_datagram_socket(
    remote_path: str | bytes | PathLike[Any],
    *,
    local_path: None | str | bytes | PathLike[Any] = None,
    local_mode: int | None = None,
) -> ConnectedUNIXDatagramSocket:
    """
    Create a connected UNIX datagram socket.

    Connected datagram sockets can only communicate with the specified remote path.

    If ``local_path`` has been given, the socket will be bound to this path, making
    this socket suitable for receiving datagrams from other processes. Other processes
    can send datagrams to this socket only if ``local_path`` is set.

    If a socket already exists on the file system in the ``local_path``, it will be
    removed first.

    :param remote_path: the path to set as the default target
    :param local_path: the path on which to bind to
    :param local_mode: permissions to set on the local socket
    :return: a connected UNIX datagram socket

    """
    remote_path = os.fspath(remote_path)
    raw_socket = await setup_unix_local_socket(
        local_path, local_mode, socket.SOCK_DGRAM
    )
    try:
        return await get_async_backend().create_unix_datagram_socket(
            raw_socket, remote_path
        )
    except BaseException:
        # Don't leak the socket if connecting fails or we are cancelled before the
        # backend takes ownership of it; mirrors create_unix_listener()
        raw_socket.close()
        raise
526
527
async def getaddrinfo(
    host: bytes | str | None,
    port: str | int | None,
    *,
    family: int | AddressFamily = 0,
    type: int | SocketKind = 0,
    proto: int = 0,
    flags: int = 0,
) -> list[tuple[AddressFamily, SocketKind, int, str, tuple[str, int]]]:
    """
    Look up a numeric IP address given a host name.

    Internationalized domain names are translated according to the (non-transitional)
    IDNA 2008 standard.

    .. note:: 4-tuple IPv6 socket addresses are automatically converted to 2-tuples of
        (host, port), unlike what :func:`socket.getaddrinfo` does. A nonzero scope ID
        is appended to the host as ``%<scope_id>``.

    :param host: host name
    :param port: port number
    :param family: socket family (``AF_INET``, ...)
    :param type: socket type (``SOCK_STREAM``, ...)
    :param proto: protocol number
    :param flags: flags to pass to upstream ``getaddrinfo()``
    :return: list of tuples containing (family, type, proto, canonname, sockaddr)

    .. seealso:: :func:`socket.getaddrinfo`

    """
    # Handle unicode hostnames: pure-ASCII names are passed through unchanged, while
    # anything else is IDNA-encoded (idna is only imported when actually needed)
    if isinstance(host, str):
        try:
            encoded_host: bytes | None = host.encode("ascii")
        except UnicodeEncodeError:
            import idna

            encoded_host = idna.encode(host, uts46=True)
    else:
        encoded_host = host

    gai_res = await get_async_backend().getaddrinfo(
        encoded_host, port, family=family, type=type, proto=proto, flags=flags
    )
    # Distinct loop names keep the family/type/proto parameters from being shadowed
    return [
        (res_family, res_type, res_proto, canonname, convert_ipv6_sockaddr(sockaddr))
        for res_family, res_type, res_proto, canonname, sockaddr in gai_res
    ]
575
576
def getnameinfo(sockaddr: IPSockAddrType, flags: int = 0) -> Awaitable[tuple[str, str]]:
    """
    Resolve an IP socket address into a host name and a service name.

    :param sockaddr: the socket address to resolve (e.g. ``(ipaddress, port)`` for
        IPv4)
    :param flags: flags passed through to the upstream ``getnameinfo()``
    :return: an awaitable yielding a ``(host name, service name)`` tuple

    .. seealso:: :func:`socket.getnameinfo`

    """
    backend = get_async_backend()
    return backend.getnameinfo(sockaddr, flags)
589
590
def wait_socket_readable(sock: socket.socket) -> Awaitable[None]:
    """
    Block the calling task until ``sock`` has data available for reading.

    This does **NOT** work on Windows when using the asyncio backend with a proactor
    event loop (default on py3.8+).

    .. warning:: Only use this on raw sockets that have not been wrapped by any higher
        level constructs like socket streams!

    :param sock: the raw socket object to watch
    :raises ~anyio.ClosedResourceError: if the socket was closed while waiting for the
        socket to become readable
    :raises ~anyio.BusyResourceError: if another task is already waiting for the socket
        to become readable

    """
    backend = get_async_backend()
    return backend.wait_socket_readable(sock)
609
610
def wait_socket_writable(sock: socket.socket) -> Awaitable[None]:
    """
    Block the calling task until ``sock`` is ready to accept more outgoing data.

    This does **NOT** work on Windows when using the asyncio backend with a proactor
    event loop (default on py3.8+).

    .. warning:: Only use this on raw sockets that have not been wrapped by any higher
        level constructs like socket streams!

    :param sock: the raw socket object to watch
    :raises ~anyio.ClosedResourceError: if the socket was closed while waiting for the
        socket to become writable
    :raises ~anyio.BusyResourceError: if another task is already waiting for the socket
        to become writable

    """
    backend = get_async_backend()
    return backend.wait_socket_writable(sock)
629
630
631#
632# Private API
633#
634
635
def convert_ipv6_sockaddr(
    sockaddr: tuple[str, int, int, int] | tuple[str, int],
) -> tuple[str, int]:
    """
    Collapse a 4-tuple IPv6 socket address into a 2-tuple ``(address, port)``.

    A nonzero scope ID is preserved by appending it to the address after a ``%``;
    otherwise the flow info and scope ID are simply dropped. Anything that is not a
    4-tuple is handed back unchanged.

    :param sockaddr: the result of :meth:`~socket.socket.getsockname`
    :return: the converted socket address

    """
    # Guard clause: only 4-tuples (IPv6 addresses) need converting
    if not isinstance(sockaddr, tuple) or len(sockaddr) != 4:
        return sockaddr

    host, port, _flowinfo, scope_id = sockaddr
    if not scope_id:
        return host, port

    # PyPy (as of v7.3.11) keeps the interface name in the host part, so strip any
    # existing "%..." suffix and rebuild it from the numeric scope ID
    # (https://foss.heptapod.net/pypy/pypy/-/issues/3938)
    bare_host = host.split("%")[0]
    return f"{bare_host}%{scope_id}", port
665
666
async def setup_unix_local_socket(
    path: None | str | bytes | PathLike[Any],
    mode: int | None,
    socktype: int,
) -> socket.socket:
    """
    Create a non-blocking ``AF_UNIX`` socket, optionally bound to ``path``.

    If ``path`` is given and a socket file already exists there, it is unlinked before
    binding. Other kinds of files at that path are left alone.

    Not available on Windows.

    :param path: path of the socket (``None`` to create an unbound socket)
    :param mode: permissions to set on the socket file after binding (skipped if
        ``None``)
    :param socktype: socket.SOCK_STREAM or socket.SOCK_DGRAM
    :return: the new socket object

    """
    path_str: str | bytes | None = None if path is None else os.fspath(path)

    if path_str is not None:
        # Remove a leftover socket at this path; lookup errors that just mean
        # "nothing usable is there" are ignored (this mirrors pathlib's handling)
        try:
            st_mode = os.stat(path_str).st_mode
        except OSError as exc:
            ignorable = (errno.ENOENT, errno.ENOTDIR, errno.EBADF, errno.ELOOP)
            if exc.errno not in ignorable:
                raise
        else:
            if stat.S_ISSOCK(st_mode):
                os.unlink(path_str)

    sock = socket.socket(socket.AF_UNIX, socktype)
    sock.setblocking(False)

    if path_str is None:
        return sock

    # bind() and chmod() touch the file system, so run them in worker threads
    try:
        await to_thread.run_sync(sock.bind, path_str, abandon_on_cancel=True)
        if mode is not None:
            await to_thread.run_sync(chmod, path_str, mode, abandon_on_cancel=True)
    except BaseException:
        sock.close()
        raise

    return sock