Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/anyio/_core/_sockets.py: 23%
189 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 07:19 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 07:19 +0000
1from __future__ import annotations
3import socket
4import ssl
5import sys
6from ipaddress import IPv6Address, ip_address
7from os import PathLike, chmod
8from pathlib import Path
9from socket import AddressFamily, SocketKind
10from typing import Awaitable, List, Tuple, cast, overload
12from .. import to_thread
13from ..abc import (
14 ConnectedUDPSocket,
15 IPAddressType,
16 IPSockAddrType,
17 SocketListener,
18 SocketStream,
19 UDPSocket,
20 UNIXSocketStream,
21)
22from ..streams.stapled import MultiListener
23from ..streams.tls import TLSStream
24from ._eventloop import get_asynclib
25from ._resources import aclose_forcefully
26from ._synchronization import Event
27from ._tasks import create_task_group, move_on_after
29if sys.version_info >= (3, 8):
30 from typing import Literal
31else:
32 from typing_extensions import Literal
# Numeric value of IPPROTO_IPV6; falls back to 41 when the ``socket`` module does
# not expose the constant on this platform (see the linked CPython issue).
IPPROTO_IPV6 = getattr(socket, "IPPROTO_IPV6", 41)  # https://bugs.python.org/issue29515

# Return type of getaddrinfo() in this module: a list of
# (family, type, proto, canonname, (host, port)) tuples. The sockaddr is always a
# 2-tuple because getaddrinfo() passes every result through convert_ipv6_sockaddr().
GetAddrInfoReturnType = List[
    Tuple[AddressFamily, SocketKind, int, str, Tuple[str, int]]
]
# Address families accepted by functions where "unspecified" is also allowed.
AnyIPAddressFamily = Literal[
    AddressFamily.AF_UNSPEC, AddressFamily.AF_INET, AddressFamily.AF_INET6
]
# Concrete IP address families only (IPv4 or IPv6).
IPAddressFamily = Literal[AddressFamily.AF_INET, AddressFamily.AF_INET6]
# The @overload stubs below exist for static type checkers only. They describe
# which stream type the real connect_tcp() implementation returns:
# TLSStream whenever TLS is requested (tls_hostname, ssl_context or tls=True),
# SocketStream otherwise.


# tls_hostname given -> a TLS handshake is performed
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = ...,
    ssl_context: ssl.SSLContext | None = ...,
    tls_standard_compatible: bool = ...,
    tls_hostname: str,
    happy_eyeballs_delay: float = ...,
) -> TLSStream:
    ...


# ssl_context given -> a TLS handshake is performed
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = ...,
    ssl_context: ssl.SSLContext,
    tls_standard_compatible: bool = ...,
    tls_hostname: str | None = ...,
    happy_eyeballs_delay: float = ...,
) -> TLSStream:
    ...


# tls=True -> a TLS handshake is performed
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = ...,
    tls: Literal[True],
    ssl_context: ssl.SSLContext | None = ...,
    tls_standard_compatible: bool = ...,
    tls_hostname: str | None = ...,
    happy_eyeballs_delay: float = ...,
) -> TLSStream:
    ...


# tls=False -> a plain socket stream is returned
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = ...,
    tls: Literal[False],
    ssl_context: ssl.SSLContext | None = ...,
    tls_standard_compatible: bool = ...,
    tls_hostname: str | None = ...,
    happy_eyeballs_delay: float = ...,
) -> SocketStream:
    ...


# No TLS arguments -> a plain socket stream is returned
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = ...,
    happy_eyeballs_delay: float = ...,
) -> SocketStream:
    ...
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = None,
    tls: bool = False,
    ssl_context: ssl.SSLContext | None = None,
    tls_standard_compatible: bool = True,
    tls_hostname: str | None = None,
    happy_eyeballs_delay: float = 0.25,
) -> SocketStream | TLSStream:
    """
    Connect to a host using the TCP protocol.

    This function implements the stateless version of the Happy Eyeballs algorithm (RFC
    6555). If ``remote_host`` is a host name that resolves to multiple IP addresses,
    each one is tried until one connection attempt succeeds. If the first attempt does
    not connect within ``happy_eyeballs_delay`` seconds (250 milliseconds by default),
    a second attempt is started using the next address in the list, and so on. On IPv6
    enabled systems, an IPv6 address (if available) is tried first.

    When the connection has been established, a TLS handshake will be done if either
    ``ssl_context`` or ``tls_hostname`` is given, or if ``tls`` is ``True``.

    :param remote_host: the IP address or host name to connect to
    :param remote_port: port on the target host to connect to
    :param local_host: the interface address or name to bind the socket to before connecting
    :param tls: ``True`` to do a TLS handshake with the connected stream and return a
        :class:`~anyio.streams.tls.TLSStream` instead
    :param ssl_context: the SSL context object to use (if omitted, a default context is created)
    :param tls_standard_compatible: If ``True``, performs the TLS shutdown handshake before closing
        the stream and requires that the server does this as well. Otherwise,
        :exc:`~ssl.SSLEOFError` may be raised during reads from the stream.
        Some protocols, such as HTTP, require this option to be ``False``.
        See :meth:`~ssl.SSLContext.wrap_socket` for details.
    :param tls_hostname: host name to check the server certificate against (defaults to the value
        of ``remote_host``)
    :param happy_eyeballs_delay: delay (in seconds) before starting the next connection attempt
    :return: a socket stream object if no TLS handshake was done, otherwise a TLS stream
    :raises OSError: if the connection attempt fails

    """
    # Placed here due to https://github.com/python/mypy/issues/7057
    connected_stream: SocketStream | None = None

    async def try_connect(remote_host: str, event: Event) -> None:
        # Make one connection attempt. The first attempt to succeed keeps its
        # stream and cancels all other attempts; any later winner is closed.
        nonlocal connected_stream
        try:
            stream = await asynclib.connect_tcp(remote_host, remote_port, local_address)
        except OSError as exc:
            oserrors.append(exc)
            return
        else:
            if connected_stream is None:
                connected_stream = stream
                tg.cancel_scope.cancel()
            else:
                await stream.aclose()
        finally:
            # Wake the scheduling loop so the next attempt can start right away
            event.set()

    asynclib = get_asynclib()
    local_address: IPSockAddrType | None = None
    family = socket.AF_UNSPEC
    if local_host:
        gai_res = await getaddrinfo(str(local_host), None)
        family, *_, local_address = gai_res[0]

    target_host = str(remote_host)
    try:
        addr_obj = ip_address(remote_host)
    except ValueError:
        # getaddrinfo() will raise an exception if name resolution fails
        gai_res = await getaddrinfo(
            target_host, remote_port, family=family, type=socket.SOCK_STREAM
        )

        # Organize the list so that the first address is an IPv6 address (if available)
        # and the second one is an IPv4 address. The rest can be in whatever order.
        v6_found = v4_found = False
        target_addrs: list[tuple[socket.AddressFamily, str]] = []
        for af, *_, sa in gai_res:
            if af == socket.AF_INET6 and not v6_found:
                v6_found = True
                target_addrs.insert(0, (af, sa[0]))
            elif af == socket.AF_INET and not v4_found and v6_found:
                v4_found = True
                target_addrs.insert(1, (af, sa[0]))
            else:
                target_addrs.append((af, sa[0]))
    else:
        if isinstance(addr_obj, IPv6Address):
            target_addrs = [(socket.AF_INET6, addr_obj.compressed)]
        else:
            target_addrs = [(socket.AF_INET, addr_obj.compressed)]

    oserrors: list[OSError] = []
    async with create_task_group() as tg:
        for _, addr in target_addrs:
            event = Event()
            tg.start_soon(try_connect, addr, event)
            # Start the next attempt as soon as this one finishes, or after the
            # happy eyeballs delay, whichever comes first
            with move_on_after(happy_eyeballs_delay):
                await event.wait()

    if connected_stream is None:
        if not oserrors:
            # Nothing was attempted (no target addresses); there is no cause to chain
            raise OSError("All connection attempts failed")

        cause = oserrors[0] if len(oserrors) == 1 else asynclib.ExceptionGroup(oserrors)
        raise OSError("All connection attempts failed") from cause

    if tls or tls_hostname or ssl_context:
        try:
            return await TLSStream.wrap(
                connected_stream,
                server_side=False,
                hostname=tls_hostname or str(remote_host),
                ssl_context=ssl_context,
                standard_compatible=tls_standard_compatible,
            )
        except BaseException:
            await aclose_forcefully(connected_stream)
            raise

    return connected_stream
async def connect_unix(path: str | PathLike[str]) -> UNIXSocketStream:
    """
    Open a stream connection to the UNIX domain socket at ``path``.

    Not available on Windows.

    :param path: path to the socket
    :return: a socket stream object

    """
    # Normalize through pathlib before handing a plain string to the backend
    normalized_path = str(Path(path))
    backend = get_asynclib()
    return await backend.connect_unix(normalized_path)
async def create_tcp_listener(
    *,
    local_host: IPAddressType | None = None,
    local_port: int = 0,
    family: AnyIPAddressFamily = socket.AddressFamily.AF_UNSPEC,
    backlog: int = 65536,
    reuse_port: bool = False,
) -> MultiListener[SocketStream]:
    """
    Create a TCP socket listener.

    :param local_port: port number to listen on
    :param local_host: IP address of the interface to listen on. If omitted, listen on
        all IPv4 and IPv6 interfaces. To listen on all interfaces on a specific address
        family, use ``0.0.0.0`` for IPv4 or ``::`` for IPv6.
    :param family: address family (used if ``local_host`` was omitted)
    :param backlog: maximum number of queued incoming connections (up to a maximum of
        2**16, or 65536)
    :param reuse_port: ``True`` to allow multiple sockets to bind to the same
        address/port (not supported on Windows)
    :return: a multi-listener wrapping one listener per resolved local address

    """
    asynclib = get_asynclib()
    backlog = min(backlog, 65536)
    local_host = str(local_host) if local_host is not None else None
    gai_res = await getaddrinfo(
        local_host,  # type: ignore[arg-type]
        local_port,
        family=family,
        type=socket.SocketKind.SOCK_STREAM if sys.platform == "win32" else 0,
        flags=socket.AI_PASSIVE | socket.AI_ADDRCONFIG,
    )
    listeners: list[SocketListener] = []
    try:
        # The set() is here to work around a glibc bug:
        # https://sourceware.org/bugzilla/show_bug.cgi?id=14969
        sockaddr: tuple[str, int] | tuple[str, int, int, int]
        for fam, kind, *_, sockaddr in sorted(set(gai_res)):
            # Workaround for an uvloop bug where we don't get the correct scope ID for
            # IPv6 link-local addresses when passing type=socket.SOCK_STREAM to
            # getaddrinfo(): https://github.com/MagicStack/uvloop/issues/539
            # (equality rather than identity, so plain-int kinds are also matched)
            if sys.platform != "win32" and kind != SocketKind.SOCK_STREAM:
                continue

            raw_socket = socket.socket(fam)
            try:
                raw_socket.setblocking(False)

                # For Windows, enable exclusive address use. For others, enable
                # address reuse.
                if sys.platform == "win32":
                    raw_socket.setsockopt(
                        socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1
                    )
                else:
                    raw_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

                if reuse_port:
                    raw_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)

                # Make IPv6 sockets IPv6-only: IPv4 addresses get their own socket
                # from this same loop, and a dual-stack IPv6 socket could clash
                # with it on the same port
                if fam == socket.AF_INET6:
                    raw_socket.setsockopt(IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)

                # Workaround for #554
                if "%" in sockaddr[0]:
                    addr, scope_id = sockaddr[0].split("%", 1)
                    sockaddr = (addr, sockaddr[1], 0, int(scope_id))

                raw_socket.bind(sockaddr)
                raw_socket.listen(backlog)
                listener = asynclib.TCPSocketListener(raw_socket)
            except BaseException:
                # Not yet owned by any listener, so the outer cleanup would miss it
                raw_socket.close()
                raise

            listeners.append(listener)
    except BaseException:
        for listener in listeners:
            await listener.aclose()

        raise

    return MultiListener(listeners)
async def create_unix_listener(
    path: str | PathLike[str],
    *,
    mode: int | None = None,
    backlog: int = 65536,
) -> SocketListener:
    """
    Create a listener bound to a UNIX domain socket.

    Not available on Windows.

    :param path: path of the socket
    :param mode: permissions to set on the socket
    :param backlog: maximum number of queued incoming connections (up to a maximum of 2**16, or
        65536)
    :return: a listener object

    .. versionchanged:: 3.0
        If a socket already exists on the file system in the given path, it will be removed first.

    """
    path_str = str(path)
    socket_path = Path(path)
    # Remove a stale socket file left behind at this location
    if socket_path.is_socket():
        socket_path.unlink()

    capped_backlog = min(backlog, 65536)
    sock = socket.socket(socket.AF_UNIX)
    sock.setblocking(False)
    try:
        # bind() and chmod() touch the file system, so run them in worker threads
        await to_thread.run_sync(sock.bind, path_str, cancellable=True)
        if mode is not None:
            await to_thread.run_sync(chmod, path_str, mode, cancellable=True)

        sock.listen(capped_backlog)
        return get_asynclib().UNIXSocketListener(sock)
    except BaseException:
        sock.close()
        raise
async def create_udp_socket(
    family: AnyIPAddressFamily = AddressFamily.AF_UNSPEC,
    *,
    local_host: IPAddressType | None = None,
    local_port: int = 0,
    reuse_port: bool = False,
) -> UDPSocket:
    """
    Create a UDP socket.

    If ``local_port`` has been given, the socket will be bound to this port on the local
    machine, making this socket suitable for providing UDP based services.

    :param family: address family (``AF_INET`` or ``AF_INET6``) – automatically determined from
        ``local_host`` if omitted
    :param local_host: IP address or host name of the local interface to bind to (if
        omitted, the socket is bound to the wildcard address of ``family``)
    :param local_port: local port to bind to
    :param reuse_port: ``True`` to allow multiple sockets to bind to the same address/port
        (not supported on Windows)
    :return: a UDP socket
    :raises ValueError: if neither ``family`` nor ``local_host`` was given

    """
    if family is AddressFamily.AF_UNSPEC and not local_host:
        raise ValueError('Either "family" or "local_host" must be given')

    if local_host:
        gai_res = await getaddrinfo(
            str(local_host),
            local_port,
            family=family,
            type=socket.SOCK_DGRAM,
            flags=socket.AI_PASSIVE | socket.AI_ADDRCONFIG,
        )
        family = cast(AnyIPAddressFamily, gai_res[0][0])
        local_address = gai_res[0][-1]
    elif family is AddressFamily.AF_INET6:
        # No interface given: bind to the IPv6 wildcard address, honouring local_port
        local_address = ("::", local_port)
    else:
        # No interface given: bind to the IPv4 wildcard address, honouring local_port
        local_address = ("0.0.0.0", local_port)

    return await get_asynclib().create_udp_socket(
        family, local_address, None, reuse_port
    )
async def create_connected_udp_socket(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    family: AnyIPAddressFamily = AddressFamily.AF_UNSPEC,
    local_host: IPAddressType | None = None,
    local_port: int = 0,
    reuse_port: bool = False,
) -> ConnectedUDPSocket:
    """
    Create a connected UDP socket.

    Connected UDP sockets can only communicate with the specified remote host/port, and any packets
    sent from other sources are dropped.

    :param remote_host: remote host to set as the default target
    :param remote_port: port on the remote host to set as the default target
    :param family: address family (``AF_INET`` or ``AF_INET6``) – automatically determined from
        ``local_host`` or ``remote_host`` if omitted
    :param local_host: IP address or host name of the local interface to bind to (if
        omitted but ``local_port`` is given, the wildcard address of the resolved
        family is used)
    :param local_port: local port to bind to
    :param reuse_port: ``True`` to allow multiple sockets to bind to the same address/port
        (not supported on Windows)
    :return: a connected UDP socket

    """
    local_address: IPSockAddrType | None = None
    if local_host:
        gai_res = await getaddrinfo(
            str(local_host),
            local_port,
            family=family,
            type=socket.SOCK_DGRAM,
            flags=socket.AI_PASSIVE | socket.AI_ADDRCONFIG,
        )
        family = cast(AnyIPAddressFamily, gai_res[0][0])
        local_address = gai_res[0][-1]

    gai_res = await getaddrinfo(
        str(remote_host), remote_port, family=family, type=socket.SOCK_DGRAM
    )
    family = cast(AnyIPAddressFamily, gai_res[0][0])
    remote_address = gai_res[0][-1]

    if local_address is None and local_port:
        # A local port was requested without an interface: bind to the wildcard
        # address of the family the remote host resolved to, instead of silently
        # ignoring local_port
        wildcard = "::" if family == AddressFamily.AF_INET6 else "0.0.0.0"
        local_address = (wildcard, local_port)

    return await get_asynclib().create_udp_socket(
        family, local_address, remote_address, reuse_port
    )
async def getaddrinfo(
    host: bytearray | bytes | str,
    port: str | int | None,
    *,
    family: int | AddressFamily = 0,
    type: int | SocketKind = 0,
    proto: int = 0,
    flags: int = 0,
) -> GetAddrInfoReturnType:
    """
    Look up a numeric IP address given a host name.

    Internationalized domain names are translated according to the (non-transitional) IDNA 2008
    standard.

    .. note:: 4-tuple IPv6 socket addresses are automatically converted to 2-tuples of
        (host, port), unlike what :func:`socket.getaddrinfo` does.

    :param host: host name
    :param port: port number
    :param family: socket family (``AF_INET``, ...)
    :param type: socket type (``SOCK_STREAM``, ...)
    :param proto: protocol number
    :param flags: flags to pass to upstream ``getaddrinfo()``
    :return: list of tuples containing (family, type, proto, canonname, sockaddr)

    .. seealso:: :func:`socket.getaddrinfo`

    """
    # Handle unicode hostnames: pure-ASCII names are passed on as bytes, anything
    # else is IDNA-encoded first; bytes/bytearray hosts are passed through as-is
    if isinstance(host, str):
        try:
            encoded_host = host.encode("ascii")
        except UnicodeEncodeError:
            import idna

            encoded_host = idna.encode(host, uts46=True)
    else:
        encoded_host = host

    gai_res = await get_asynclib().getaddrinfo(
        encoded_host, port, family=family, type=type, proto=proto, flags=flags
    )
    # Distinct names keep the per-result fields from reading as this function's
    # parameters of the same name
    return [
        (res_family, res_type, res_proto, canonname, convert_ipv6_sockaddr(sockaddr))
        for res_family, res_type, res_proto, canonname, sockaddr in gai_res
    ]
def getnameinfo(sockaddr: IPSockAddrType, flags: int = 0) -> Awaitable[tuple[str, str]]:
    """
    Resolve an IP socket address back into a host name and a service name.

    :param sockaddr: socket address (e.g. (ipaddress, port) for IPv4)
    :param flags: flags to pass to upstream ``getnameinfo()``
    :return: an awaitable yielding a tuple of (host name, service name)

    .. seealso:: :func:`socket.getnameinfo`

    """
    backend = get_asynclib()
    return backend.getnameinfo(sockaddr, flags)
def wait_socket_readable(sock: socket.socket) -> Awaitable[None]:
    """
    Block until ``sock`` has data available to read.

    This does **NOT** work on Windows when using the asyncio backend with a proactor event loop
    (default on py3.8+).

    .. warning:: Only use this on raw sockets that have not been wrapped by any higher level
        constructs like socket streams!

    :param sock: a socket object
    :raises ~anyio.ClosedResourceError: if the socket was closed while waiting for the
        socket to become readable
    :raises ~anyio.BusyResourceError: if another task is already waiting for the socket
        to become readable

    """
    backend = get_asynclib()
    return backend.wait_socket_readable(sock)
def wait_socket_writable(sock: socket.socket) -> Awaitable[None]:
    """
    Block until ``sock`` is ready to accept outgoing data.

    This does **NOT** work on Windows when using the asyncio backend with a proactor event loop
    (default on py3.8+).

    .. warning:: Only use this on raw sockets that have not been wrapped by any higher level
        constructs like socket streams!

    :param sock: a socket object
    :raises ~anyio.ClosedResourceError: if the socket was closed while waiting for the
        socket to become writable
    :raises ~anyio.BusyResourceError: if another task is already waiting for the socket
        to become writable

    """
    backend = get_asynclib()
    return backend.wait_socket_writable(sock)
574#
575# Private API
576#
def convert_ipv6_sockaddr(
    sockaddr: tuple[str, int, int, int] | tuple[str, int]
) -> tuple[str, int]:
    """
    Collapse an IPv6 4-tuple socket address into a ``(host, port)`` 2-tuple.

    The flow info is always dropped. A nonzero scope ID is preserved by appending it
    to the host as ``host%scope_id``; a zero scope ID is simply dropped. Anything that
    is not a 4-tuple is returned unchanged.

    :param sockaddr: the result of :meth:`~socket.socket.getsockname`
    :return: the converted socket address

    """
    # Only 4-tuples (IPv6 addresses) need converting
    if not (isinstance(sockaddr, tuple) and len(sockaddr) == 4):
        return cast(Tuple[str, int], sockaddr)

    # The explicit cast is needed to keep MyPy happy about the tuple length
    host, port, _flowinfo, scope_id = cast(Tuple[str, int, int, int], sockaddr)
    if not scope_id:
        return host, port

    # PyPy (as of v7.3.11) leaves the interface name in the result, so
    # we discard it and only get the scope ID from the end
    # (https://foss.heptapod.net/pypy/pypy/-/issues/3938)
    bare_host = host.partition("%")[0]
    return f"{bare_host}%{scope_id}", port