1from __future__ import annotations
2
3import errno
4import os
5import socket
6import ssl
7import stat
8import sys
9from collections.abc import Awaitable
10from ipaddress import IPv6Address, ip_address
11from os import PathLike, chmod
12from socket import AddressFamily, SocketKind
13from typing import Any, Literal, cast, overload
14
15from .. import to_thread
16from ..abc import (
17 ConnectedUDPSocket,
18 ConnectedUNIXDatagramSocket,
19 IPAddressType,
20 IPSockAddrType,
21 SocketListener,
22 SocketStream,
23 UDPSocket,
24 UNIXDatagramSocket,
25 UNIXSocketStream,
26)
27from ..streams.stapled import MultiListener
28from ..streams.tls import TLSStream
29from ._eventloop import get_async_backend
30from ._resources import aclose_forcefully
31from ._synchronization import Event
32from ._tasks import create_task_group, move_on_after
33
34if sys.version_info < (3, 11):
35 from exceptiongroup import ExceptionGroup
36
# Protocol level used for IPv6 socket options. Falls back to the literal value 41 when
# the socket module does not expose the constant.
IPPROTO_IPV6 = getattr(socket, "IPPROTO_IPV6", 41)  # https://bugs.python.org/issue29515

# Address families accepted by APIs where the family may be left unspecified
AnyIPAddressFamily = Literal[
    AddressFamily.AF_UNSPEC, AddressFamily.AF_INET, AddressFamily.AF_INET6
]
# Concrete IP address families only (no AF_UNSPEC)
IPAddressFamily = Literal[AddressFamily.AF_INET, AddressFamily.AF_INET6]
43
44
# Typing overloads for connect_tcp(): the declared return type depends on which of the
# TLS related keyword arguments the caller passes.

# tls_hostname given (required keyword) -> TLSStream
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = ...,
    ssl_context: ssl.SSLContext | None = ...,
    tls_standard_compatible: bool = ...,
    tls_hostname: str,
    happy_eyeballs_delay: float = ...,
) -> TLSStream:
    ...


# ssl_context given (required keyword) -> TLSStream
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = ...,
    ssl_context: ssl.SSLContext,
    tls_standard_compatible: bool = ...,
    tls_hostname: str | None = ...,
    happy_eyeballs_delay: float = ...,
) -> TLSStream:
    ...


# tls=True -> TLSStream
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = ...,
    tls: Literal[True],
    ssl_context: ssl.SSLContext | None = ...,
    tls_standard_compatible: bool = ...,
    tls_hostname: str | None = ...,
    happy_eyeballs_delay: float = ...,
) -> TLSStream:
    ...


# tls=False -> SocketStream
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = ...,
    tls: Literal[False],
    ssl_context: ssl.SSLContext | None = ...,
    tls_standard_compatible: bool = ...,
    tls_hostname: str | None = ...,
    happy_eyeballs_delay: float = ...,
) -> SocketStream:
    ...


# No TLS arguments at all -> SocketStream
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = ...,
    happy_eyeballs_delay: float = ...,
) -> SocketStream:
    ...
117
118
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = None,
    tls: bool = False,
    ssl_context: ssl.SSLContext | None = None,
    tls_standard_compatible: bool = True,
    tls_hostname: str | None = None,
    happy_eyeballs_delay: float = 0.25,
) -> SocketStream | TLSStream:
    """
    Connect to a host using the TCP protocol.

    This function implements the stateless version of the Happy Eyeballs algorithm (RFC
    6555). If ``remote_host`` is a host name that resolves to multiple IP addresses,
    each one is tried until one connection attempt succeeds. If an attempt does not
    connect within ``happy_eyeballs_delay`` seconds (250 milliseconds by default), the
    next attempt is started using the next address in the list, and so on. On IPv6
    enabled systems, an IPv6 address (if available) is tried first.

    When the connection has been established, a TLS handshake will be done if either
    ``ssl_context`` or ``tls_hostname`` is not ``None``, or if ``tls`` is ``True``.

    :param remote_host: the IP address or host name to connect to
    :param remote_port: port on the target host to connect to
    :param local_host: the interface address or name to bind the socket to before
        connecting
    :param tls: ``True`` to do a TLS handshake with the connected stream and return a
        :class:`~anyio.streams.tls.TLSStream` instead
    :param ssl_context: the SSL context object to use (if omitted, a default context is
        created)
    :param tls_standard_compatible: If ``True``, performs the TLS shutdown handshake
        before closing the stream and requires that the server does this as well.
        Otherwise, :exc:`~ssl.SSLEOFError` may be raised during reads from the stream.
        Some protocols, such as HTTP, require this option to be ``False``.
        See :meth:`~ssl.SSLContext.wrap_socket` for details.
    :param tls_hostname: host name to check the server certificate against (defaults to
        the value of ``remote_host``)
    :param happy_eyeballs_delay: delay (in seconds) before starting the next connection
        attempt
    :return: a socket stream object if no TLS handshake was done, otherwise a TLS stream
    :raises OSError: if the connection attempt fails

    """
    # Declared before the closure due to https://github.com/python/mypy/issues/7057
    connected_stream: SocketStream | None = None

    async def attempt(address: str, finished: Event) -> None:
        # One connection attempt; the first one to succeed wins and cancels the rest
        nonlocal connected_stream
        try:
            stream = await asynclib.connect_tcp(address, remote_port, local_address)
        except OSError as exc:
            oserrors.append(exc)
            return
        else:
            if connected_stream is not None:
                # Another attempt already won the race, so discard this stream
                await stream.aclose()
            else:
                connected_stream = stream
                tg.cancel_scope.cancel()
        finally:
            # Let the scheduling loop move on to the next address right away
            finished.set()

    asynclib = get_async_backend()
    family = socket.AF_UNSPEC
    local_address: IPSockAddrType | None = None
    if local_host:
        local_info = await getaddrinfo(str(local_host), None)
        family, *_, local_address = local_info[0]

    target_host = str(remote_host)
    try:
        addr_obj = ip_address(remote_host)
    except ValueError:
        # Not an IP address literal, so resolve it as a host name.
        # getaddrinfo() will raise an exception if name resolution fails.
        resolved = await getaddrinfo(
            target_host, remote_port, family=family, type=socket.SOCK_STREAM
        )

        # Arrange the candidates so that the first IPv6 address (if any) comes first,
        # followed by the first IPv4 address seen after it. The rest keep their order.
        have_v6 = have_v4 = False
        target_addrs: list[tuple[socket.AddressFamily, str]] = []
        for addr_family, *_, sockaddr in resolved:
            candidate = (addr_family, sockaddr[0])
            if addr_family == socket.AF_INET6 and not have_v6:
                have_v6 = True
                target_addrs.insert(0, candidate)
            elif addr_family == socket.AF_INET and have_v6 and not have_v4:
                have_v4 = True
                target_addrs.insert(1, candidate)
            else:
                target_addrs.append(candidate)
    else:
        # An IP address literal yields exactly one candidate
        literal_family = (
            socket.AF_INET6 if isinstance(addr_obj, IPv6Address) else socket.AF_INET
        )
        target_addrs = [(literal_family, addr_obj.compressed)]

    oserrors: list[OSError] = []
    async with create_task_group() as tg:
        for _, address in target_addrs:
            finished = Event()
            tg.start_soon(attempt, address, finished)
            # Give this attempt a head start before launching the next one
            with move_on_after(happy_eyeballs_delay):
                await finished.wait()

    if connected_stream is None:
        if len(oserrors) == 1:
            cause: BaseException = oserrors[0]
        else:
            cause = ExceptionGroup("multiple connection attempts failed", oserrors)

        raise OSError("All connection attempts failed") from cause

    if not (tls or tls_hostname or ssl_context):
        return connected_stream

    try:
        return await TLSStream.wrap(
            connected_stream,
            server_side=False,
            hostname=tls_hostname or str(remote_host),
            ssl_context=ssl_context,
            standard_compatible=tls_standard_compatible,
        )
    except BaseException:
        # Don't leave the plain TCP connection open if the handshake fails
        await aclose_forcefully(connected_stream)
        raise
248
249
async def connect_unix(path: str | bytes | PathLike[Any]) -> UNIXSocketStream:
    """
    Open a stream connection to a UNIX socket.

    Not available on Windows.

    :param path: file system path of the socket to connect to
    :return: a socket stream object

    """
    socket_path = os.fspath(path)
    backend = get_async_backend()
    return await backend.connect_unix(socket_path)
262
263
async def create_tcp_listener(
    *,
    local_host: IPAddressType | None = None,
    local_port: int = 0,
    family: AnyIPAddressFamily = socket.AddressFamily.AF_UNSPEC,
    backlog: int = 65536,
    reuse_port: bool = False,
) -> MultiListener[SocketStream]:
    """
    Create a TCP socket listener.

    :param local_port: port number to listen on
    :param local_host: IP address of the interface to listen on. If omitted, listen on
        all IPv4 and IPv6 interfaces. To listen on all interfaces on a specific address
        family, use ``0.0.0.0`` for IPv4 or ``::`` for IPv6.
    :param family: address family (used if ``local_host`` was omitted)
    :param backlog: maximum number of queued incoming connections (up to a maximum of
        2**16, or 65536)
    :param reuse_port: ``True`` to allow multiple sockets to bind to the same
        address/port (not supported on Windows)
    :return: a multi-listener object containing one socket listener per resolved
        local address

    """
    asynclib = get_async_backend()
    backlog = min(backlog, 65536)
    local_host = str(local_host) if local_host is not None else None
    gai_res = await getaddrinfo(
        local_host,
        local_port,
        family=family,
        type=socket.SocketKind.SOCK_STREAM if sys.platform == "win32" else 0,
        flags=socket.AI_PASSIVE | socket.AI_ADDRCONFIG,
    )
    listeners: list[SocketListener] = []
    try:
        # The set() is here to work around a glibc bug:
        # https://sourceware.org/bugzilla/show_bug.cgi?id=14969
        sockaddr: tuple[str, int] | tuple[str, int, int, int]
        for fam, kind, *_, sockaddr in sorted(set(gai_res)):
            # Workaround for an uvloop bug where we don't get the correct scope ID for
            # IPv6 link-local addresses when passing type=socket.SOCK_STREAM to
            # getaddrinfo(): https://github.com/MagicStack/uvloop/issues/539
            if sys.platform != "win32" and kind is not SocketKind.SOCK_STREAM:
                continue

            raw_socket = socket.socket(fam)
            try:
                raw_socket.setblocking(False)

                # For Windows, enable exclusive address use. For others, enable
                # address reuse.
                if sys.platform == "win32":
                    raw_socket.setsockopt(
                        socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1
                    )
                else:
                    raw_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

                if reuse_port:
                    raw_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)

                # If only IPv6 was requested, disable dual stack operation
                if fam == socket.AF_INET6:
                    raw_socket.setsockopt(IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)

                # Workaround for #554
                if "%" in sockaddr[0]:
                    addr, scope_id = sockaddr[0].split("%", 1)
                    sockaddr = (addr, sockaddr[1], 0, int(scope_id))

                raw_socket.bind(sockaddr)
                raw_socket.listen(backlog)
                listener = asynclib.create_tcp_listener(raw_socket)
            except BaseException:
                # This socket is not owned by any listener yet, so it has to be
                # closed here or it would leak (e.g. when bind() fails)
                raw_socket.close()
                raise

            listeners.append(listener)
    except BaseException:
        # Tear down the listeners that were successfully created before the failure
        for listener in listeners:
            await listener.aclose()

        raise

    return MultiListener(listeners)
342
343
async def create_unix_listener(
    path: str | bytes | PathLike[Any],
    *,
    mode: int | None = None,
    backlog: int = 65536,
) -> SocketListener:
    """
    Create a listener on a UNIX socket.

    Not available on Windows.

    :param path: path of the socket
    :param mode: permissions to set on the socket
    :param backlog: maximum number of queued incoming connections (up to a maximum of
        2**16, or 65536)
    :return: a listener object

    .. versionchanged:: 3.0
        If a socket already exists on the file system in the given path, it will be
        removed first.

    """
    effective_backlog = min(backlog, 65536)
    raw_socket = await setup_unix_local_socket(path, mode, socket.SOCK_STREAM)
    try:
        raw_socket.listen(effective_backlog)
        listener = get_async_backend().create_unix_listener(raw_socket)
    except BaseException:
        # Nothing else owns the socket yet, so close it before propagating the error
        raw_socket.close()
        raise
    else:
        return listener
374
375
async def create_udp_socket(
    family: AnyIPAddressFamily = AddressFamily.AF_UNSPEC,
    *,
    local_host: IPAddressType | None = None,
    local_port: int = 0,
    reuse_port: bool = False,
) -> UDPSocket:
    """
    Create a UDP socket.

    If ``local_port`` has been given, the socket will be bound to this port on the
    local machine, making this socket suitable for providing UDP based services. This
    applies whether or not ``local_host`` was given; without ``local_host`` the socket
    is bound to the wildcard address of the chosen address family.

    :param family: address family (``AF_INET`` or ``AF_INET6``) – automatically
        determined from ``local_host`` if omitted
    :param local_host: IP address or host name of the local interface to bind to
    :param local_port: local port to bind to
    :param reuse_port: ``True`` to allow multiple sockets to bind to the same
        address/port (not supported on Windows)
    :return: a UDP socket
    :raises ValueError: if neither ``family`` nor ``local_host`` was given

    """
    if family is AddressFamily.AF_UNSPEC and not local_host:
        raise ValueError('Either "family" or "local_host" must be given')

    if local_host:
        gai_res = await getaddrinfo(
            str(local_host),
            local_port,
            family=family,
            type=socket.SOCK_DGRAM,
            flags=socket.AI_PASSIVE | socket.AI_ADDRCONFIG,
        )
        # Use the first resolved address and adopt its address family
        family = cast(AnyIPAddressFamily, gai_res[0][0])
        local_address = gai_res[0][-1]
    elif family is AddressFamily.AF_INET6:
        # Wildcard address, but still honor the requested port
        local_address = ("::", local_port)
    else:
        local_address = ("0.0.0.0", local_port)

    sock = await get_async_backend().create_udp_socket(
        family, local_address, None, reuse_port
    )
    return cast(UDPSocket, sock)
420
421
async def create_connected_udp_socket(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    family: AnyIPAddressFamily = AddressFamily.AF_UNSPEC,
    local_host: IPAddressType | None = None,
    local_port: int = 0,
    reuse_port: bool = False,
) -> ConnectedUDPSocket:
    """
    Create a connected UDP socket.

    Connected UDP sockets can only communicate with the specified remote host/port, and
    any packets sent from other sources are dropped.

    :param remote_host: remote host to set as the default target
    :param remote_port: port on the remote host to set as the default target
    :param family: address family (``AF_INET`` or ``AF_INET6``) – automatically
        determined from ``local_host`` or ``remote_host`` if omitted
    :param local_host: IP address or host name of the local interface to bind to
    :param local_port: local port to bind to
    :param reuse_port: ``True`` to allow multiple sockets to bind to the same
        address/port (not supported on Windows)
    :return: a connected UDP socket

    """
    local_address = None
    if local_host:
        local_info = await getaddrinfo(
            str(local_host),
            local_port,
            family=family,
            type=socket.SOCK_DGRAM,
            flags=socket.AI_PASSIVE | socket.AI_ADDRCONFIG,
        )
        # The first local result decides the family used for the remote lookup
        first_local = local_info[0]
        family = cast(AnyIPAddressFamily, first_local[0])
        local_address = first_local[-1]

    remote_info = await getaddrinfo(
        str(remote_host), remote_port, family=family, type=socket.SOCK_DGRAM
    )
    first_remote = remote_info[0]
    family = cast(AnyIPAddressFamily, first_remote[0])
    remote_address = first_remote[-1]

    backend = get_async_backend()
    sock = await backend.create_udp_socket(
        family, local_address, remote_address, reuse_port
    )
    return cast(ConnectedUDPSocket, sock)
470
471
async def create_unix_datagram_socket(
    *,
    local_path: None | str | bytes | PathLike[Any] = None,
    local_mode: int | None = None,
) -> UNIXDatagramSocket:
    """
    Create an unconnected UNIX datagram socket.

    Not available on Windows.

    If ``local_path`` has been given, the socket will be bound to this path, making this
    socket suitable for receiving datagrams from other processes. Other processes can
    send datagrams to this socket only if ``local_path`` is set.

    If a socket already exists on the file system in the ``local_path``, it will be
    removed first.

    :param local_path: the path on which to bind to
    :param local_mode: permissions to set on the local socket
    :return: a UNIX datagram socket

    """
    raw_socket = await setup_unix_local_socket(
        local_path, local_mode, socket.SOCK_DGRAM
    )
    backend = get_async_backend()
    # No remote path: the socket is left unconnected
    return await backend.create_unix_datagram_socket(raw_socket, None)
498
499
async def create_connected_unix_datagram_socket(
    remote_path: str | bytes | PathLike[Any],
    *,
    local_path: None | str | bytes | PathLike[Any] = None,
    local_mode: int | None = None,
) -> ConnectedUNIXDatagramSocket:
    """
    Create a UNIX datagram socket connected to ``remote_path``.

    Connected datagram sockets can only communicate with the specified remote path.

    If ``local_path`` has been given, the socket will be bound to this path, making
    this socket suitable for receiving datagrams from other processes. Other processes
    can send datagrams to this socket only if ``local_path`` is set.

    If a socket already exists on the file system in the ``local_path``, it will be
    removed first.

    :param remote_path: the path to set as the default target
    :param local_path: the path on which to bind to
    :param local_mode: permissions to set on the local socket
    :return: a connected UNIX datagram socket

    """
    target_path = os.fspath(remote_path)
    raw_socket = await setup_unix_local_socket(
        local_path, local_mode, socket.SOCK_DGRAM
    )
    backend = get_async_backend()
    return await backend.create_unix_datagram_socket(raw_socket, target_path)
531
532
async def getaddrinfo(
    host: bytes | str | None,
    port: str | int | None,
    *,
    family: int | AddressFamily = 0,
    type: int | SocketKind = 0,
    proto: int = 0,
    flags: int = 0,
) -> list[tuple[AddressFamily, SocketKind, int, str, tuple[str, int]]]:
    """
    Look up a numeric IP address given a host name.

    Internationalized domain names are translated according to the (non-transitional)
    IDNA 2008 standard.

    .. note:: 4-tuple IPv6 socket addresses are automatically converted to 2-tuples of
        (host, port), unlike what :func:`socket.getaddrinfo` does.

    :param host: host name
    :param port: port number
    :param family: socket family (``AF_INET``, ...)
    :param type: socket type (``SOCK_STREAM``, ...)
    :param proto: protocol number
    :param flags: flags to pass to upstream ``getaddrinfo()``
    :return: list of tuples containing (family, type, proto, canonname, sockaddr)

    .. seealso:: :func:`socket.getaddrinfo`

    """
    encoded_host: bytes | None
    if not isinstance(host, str):
        # bytes and None are passed through untouched
        encoded_host = host
    else:
        try:
            encoded_host = host.encode("ascii")
        except UnicodeEncodeError:
            # Non-ASCII host name: fall back to IDNA encoding (imported lazily)
            import idna

            encoded_host = idna.encode(host, uts46=True)

    backend = get_async_backend()
    raw_results = await backend.getaddrinfo(
        encoded_host, port, family=family, type=type, proto=proto, flags=flags
    )
    # Normalize every socket address to the 2-tuple (host, port) form
    return [
        (fam, kind, protocol, canonname, convert_ipv6_sockaddr(sockaddr))
        for fam, kind, protocol, canonname, sockaddr in raw_results
    ]
580
581
def getnameinfo(sockaddr: IPSockAddrType, flags: int = 0) -> Awaitable[tuple[str, str]]:
    """
    Resolve a socket address to a host name and service name.

    :param sockaddr: socket address (e.g. (ipaddress, port) for IPv4)
    :param flags: flags to pass to upstream ``getnameinfo()``
    :return: an awaitable resolving to a tuple of (host name, service name)

    .. seealso:: :func:`socket.getnameinfo`

    """
    backend = get_async_backend()
    return backend.getnameinfo(sockaddr, flags)
594
595
def wait_socket_readable(sock: socket.socket) -> Awaitable[None]:
    """
    Wait until the given socket has data to be read.

    This does **NOT** work on Windows when using the asyncio backend with a proactor
    event loop (default on py3.8+).

    .. warning:: Only use this on raw sockets that have not been wrapped by any higher
        level constructs like socket streams!

    :param sock: a socket object
    :raises ~anyio.ClosedResourceError: if the socket was closed while waiting for the
        socket to become readable
    :raises ~anyio.BusyResourceError: if another task is already waiting for the socket
        to become readable

    """
    backend = get_async_backend()
    return backend.wait_socket_readable(sock)
614
615
def wait_socket_writable(sock: socket.socket) -> Awaitable[None]:
    """
    Wait until the given socket can be written to.

    This does **NOT** work on Windows when using the asyncio backend with a proactor
    event loop (default on py3.8+).

    .. warning:: Only use this on raw sockets that have not been wrapped by any higher
        level constructs like socket streams!

    :param sock: a socket object
    :raises ~anyio.ClosedResourceError: if the socket was closed while waiting for the
        socket to become writable
    :raises ~anyio.BusyResourceError: if another task is already waiting for the socket
        to become writable

    """
    backend = get_async_backend()
    return backend.wait_socket_writable(sock)
634
635
636#
637# Private API
638#
639
640
641def convert_ipv6_sockaddr(
642 sockaddr: tuple[str, int, int, int] | tuple[str, int],
643) -> tuple[str, int]:
644 """
645 Convert a 4-tuple IPv6 socket address to a 2-tuple (address, port) format.
646
647 If the scope ID is nonzero, it is added to the address, separated with ``%``.
648 Otherwise the flow id and scope id are simply cut off from the tuple.
649 Any other kinds of socket addresses are returned as-is.
650
651 :param sockaddr: the result of :meth:`~socket.socket.getsockname`
652 :return: the converted socket address
653
654 """
655 # This is more complicated than it should be because of MyPy
656 if isinstance(sockaddr, tuple) and len(sockaddr) == 4:
657 host, port, flowinfo, scope_id = sockaddr
658 if scope_id:
659 # PyPy (as of v7.3.11) leaves the interface name in the result, so
660 # we discard it and only get the scope ID from the end
661 # (https://foss.heptapod.net/pypy/pypy/-/issues/3938)
662 host = host.split("%")[0]
663
664 # Add scope_id to the address
665 return f"{host}%{scope_id}", port
666 else:
667 return host, port
668 else:
669 return sockaddr
670
671
async def setup_unix_local_socket(
    path: None | str | bytes | PathLike[Any],
    mode: int | None,
    socktype: int,
) -> socket.socket:
    """
    Create a UNIX local socket object, deleting the socket at the given path if it
    exists.

    Not available on Windows.

    Paths that start with a null character (Linux abstract namespace sockets) are not
    backed by a file, so the check for (and removal of) a pre-existing socket is
    skipped for them.

    :param path: path of the socket, or ``None`` to leave the socket unbound
    :param mode: permissions to set on the socket
    :param socktype: socket.SOCK_STREAM or socket.SOCK_DGRAM
    :return: a non-blocking socket object, bound to ``path`` if one was given

    """
    path_str: str | bytes | None
    if path is not None:
        path_str = os.fspath(path)

        # os.stat() raises ValueError (not OSError) for paths with an embedded null,
        # so abstract namespace paths must never reach it
        if isinstance(path_str, bytes):
            is_abstract = path_str.startswith(b"\0")
        else:
            is_abstract = path_str.startswith("\0")

        if not is_abstract:
            # Copied from pathlib...
            try:
                stat_result = os.stat(path)
            except OSError as e:
                if e.errno not in (
                    errno.ENOENT,
                    errno.ENOTDIR,
                    errno.EBADF,
                    errno.ELOOP,
                ):
                    raise
            else:
                # Only ever remove an existing socket, never a regular file
                if stat.S_ISSOCK(stat_result.st_mode):
                    os.unlink(path)
    else:
        path_str = None

    raw_socket = socket.socket(socket.AF_UNIX, socktype)
    raw_socket.setblocking(False)

    if path_str is not None:
        try:
            # bind() and chmod() touch the file system, so run them in worker threads
            await to_thread.run_sync(raw_socket.bind, path_str, abandon_on_cancel=True)
            if mode is not None:
                await to_thread.run_sync(chmod, path_str, mode, abandon_on_cancel=True)
        except BaseException:
            raw_socket.close()
            raise

    return raw_socket