Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/anyio/_core/_sockets.py: 24%
182 statements
« prev ^ index » next coverage.py v7.2.2, created at 2023-03-26 06:12 +0000
« prev ^ index » next coverage.py v7.2.2, created at 2023-03-26 06:12 +0000
1import socket
2import ssl
3import sys
4from ipaddress import IPv6Address, ip_address
5from os import PathLike, chmod
6from pathlib import Path
7from socket import AddressFamily, SocketKind
8from typing import Awaitable, List, Optional, Tuple, Union, cast, overload
10from .. import to_thread
11from ..abc import (
12 ConnectedUDPSocket,
13 IPAddressType,
14 IPSockAddrType,
15 SocketListener,
16 SocketStream,
17 UDPSocket,
18 UNIXSocketStream,
19)
20from ..streams.stapled import MultiListener
21from ..streams.tls import TLSStream
22from ._eventloop import get_asynclib
23from ._resources import aclose_forcefully
24from ._synchronization import Event
25from ._tasks import create_task_group, move_on_after
27if sys.version_info >= (3, 8):
28 from typing import Literal
29else:
30 from typing_extensions import Literal
# Some platforms' socket modules do not expose IPPROTO_IPV6
# (https://bugs.python.org/issue29515); fall back to the numeric value 41 there.
try:
    IPPROTO_IPV6 = socket.IPPROTO_IPV6
except AttributeError:
    IPPROTO_IPV6 = 41

# Entries returned by this module's getaddrinfo():
# (family, type, proto, canonname, sockaddr) with sockaddr typed as (host, port).
GetAddrInfoReturnType = List[Tuple[AddressFamily, SocketKind, int, str, Tuple[str, int]]]

# Families accepted where "any IP family" is allowed; AF_UNSPEC leaves the choice to resolution.
AnyIPAddressFamily = Literal[
    AddressFamily.AF_UNSPEC,
    AddressFamily.AF_INET,
    AddressFamily.AF_INET6,
]

# Concrete IP families only.
IPAddressFamily = Literal[
    AddressFamily.AF_INET,
    AddressFamily.AF_INET6,
]
# The overloads below exist only for static type checkers: they describe which stream type
# connect_tcp() returns for each combination of TLS-related keyword arguments. Type checkers
# try overload variants in declaration order, so keep the more specific variants first.


# tls_hostname given: the implementation performs a TLS handshake, so a TLSStream is returned
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: Optional[IPAddressType] = ...,
    ssl_context: Optional[ssl.SSLContext] = ...,
    tls_standard_compatible: bool = ...,
    tls_hostname: str,
    happy_eyeballs_delay: float = ...,
) -> TLSStream:
    ...


# ssl_context given: the implementation performs a TLS handshake, so a TLSStream is returned
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: Optional[IPAddressType] = ...,
    ssl_context: ssl.SSLContext,
    tls_standard_compatible: bool = ...,
    tls_hostname: Optional[str] = ...,
    happy_eyeballs_delay: float = ...,
) -> TLSStream:
    ...


# tls=True: TLS explicitly requested, so a TLSStream is returned
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: Optional[IPAddressType] = ...,
    tls: Literal[True],
    ssl_context: Optional[ssl.SSLContext] = ...,
    tls_standard_compatible: bool = ...,
    tls_hostname: Optional[str] = ...,
    happy_eyeballs_delay: float = ...,
) -> TLSStream:
    ...


# tls=False: typed as returning a plain SocketStream.
# NOTE(review): the implementation does TLS whenever ``tls or tls_hostname or ssl_context``
# is truthy, so with tls=False plus a tls_hostname/ssl_context the runtime result is a
# TLSStream even though this variant says SocketStream - confirm which behavior is intended.
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: Optional[IPAddressType] = ...,
    tls: Literal[False],
    ssl_context: Optional[ssl.SSLContext] = ...,
    tls_standard_compatible: bool = ...,
    tls_hostname: Optional[str] = ...,
    happy_eyeballs_delay: float = ...,
) -> SocketStream:
    ...


# No TLS arguments: plain TCP connection, so a SocketStream is returned
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: Optional[IPAddressType] = ...,
    happy_eyeballs_delay: float = ...,
) -> SocketStream:
    ...
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: Optional[IPAddressType] = None,
    tls: bool = False,
    ssl_context: Optional[ssl.SSLContext] = None,
    tls_standard_compatible: bool = True,
    tls_hostname: Optional[str] = None,
    happy_eyeballs_delay: float = 0.25,
) -> Union[SocketStream, TLSStream]:
    """
    Connect to a host using the TCP protocol.

    This function implements the stateless version of the Happy Eyeballs algorithm (RFC 6555).
    If ``remote_host`` is a host name that resolves to multiple IP addresses, each one is tried
    until one connection attempt succeeds. If an attempt does not finish within
    ``happy_eyeballs_delay`` seconds (250 milliseconds by default), the next attempt is started
    using the next address in the list, and so on, without cancelling the earlier attempts.
    When resolution yields IPv6 addresses, an IPv6 address is tried first and an IPv4 address
    (if any) second.

    When the connection has been established, a TLS handshake will be done if either
    ``ssl_context`` or ``tls_hostname`` is given (truthy), or if ``tls`` is ``True``.

    :param remote_host: the IP address or host name to connect to
    :param remote_port: port on the target host to connect to
    :param local_host: the interface address or name to bind the socket to before connecting;
        the address family it resolves to also restricts the lookup of ``remote_host``
    :param tls: ``True`` to do a TLS handshake with the connected stream and return a
        :class:`~anyio.streams.tls.TLSStream` instead
    :param ssl_context: the SSL context object to use (if omitted, a default context is created)
    :param tls_standard_compatible: If ``True``, performs the TLS shutdown handshake before closing
        the stream and requires that the server does this as well. Otherwise,
        :exc:`~ssl.SSLEOFError` may be raised during reads from the stream.
        Some protocols, such as HTTP, require this option to be ``False``.
        See :meth:`~ssl.SSLContext.wrap_socket` for details.
    :param tls_hostname: host name to check the server certificate against (defaults to the value
        of ``remote_host``)
    :param happy_eyeballs_delay: delay (in seconds) before starting the next connection attempt
    :return: a socket stream object if no TLS handshake was done, otherwise a TLS stream
    :raises OSError: if every connection attempt fails (the individual error, or a group of all
        of them, is attached as the cause)

    """
    # Placed here due to https://github.com/python/mypy/issues/7057
    connected_stream: Optional[SocketStream] = None

    # Runs one connection attempt. The first successful attempt keeps its stream and cancels
    # the task group (stopping the remaining attempts); a later success that slipped through
    # is closed. The event is always set so the main loop can move on immediately.
    async def try_connect(remote_host: str, event: Event) -> None:
        nonlocal connected_stream
        try:
            stream = await asynclib.connect_tcp(remote_host, remote_port, local_address)
        except OSError as exc:
            oserrors.append(exc)
            return
        else:
            if connected_stream is None:
                connected_stream = stream
                tg.cancel_scope.cancel()
            else:
                await stream.aclose()
        finally:
            event.set()

    asynclib = get_asynclib()
    local_address: Optional[IPSockAddrType] = None
    family = socket.AF_UNSPEC
    if local_host:
        # Resolve the local bind address; its family is reused for the remote lookup below
        gai_res = await getaddrinfo(str(local_host), None)
        family, *_, local_address = gai_res[0]

    target_host = str(remote_host)
    try:
        addr_obj = ip_address(remote_host)
    except ValueError:
        # getaddrinfo() will raise an exception if name resolution fails
        gai_res = await getaddrinfo(
            target_host, remote_port, family=family, type=socket.SOCK_STREAM
        )

        # Organize the list so that the first address is an IPv6 address (if available) and the
        # second one is an IPv4 address. The rest can be in whatever order.
        v6_found = v4_found = False
        target_addrs: List[Tuple[socket.AddressFamily, str]] = []
        for af, *rest, sa in gai_res:
            if af == socket.AF_INET6 and not v6_found:
                v6_found = True
                target_addrs.insert(0, (af, sa[0]))
            elif af == socket.AF_INET and not v4_found and v6_found:
                v4_found = True
                target_addrs.insert(1, (af, sa[0]))
            else:
                target_addrs.append((af, sa[0]))
    else:
        # remote_host is already an IP address literal: connect to it directly
        if isinstance(addr_obj, IPv6Address):
            target_addrs = [(socket.AF_INET6, addr_obj.compressed)]
        else:
            target_addrs = [(socket.AF_INET, addr_obj.compressed)]

    oserrors: List[OSError] = []
    async with create_task_group() as tg:
        for i, (af, addr) in enumerate(target_addrs):
            event = Event()
            tg.start_soon(try_connect, addr, event)
            # Wait for this attempt to finish, but at most happy_eyeballs_delay seconds
            # before starting the next one concurrently
            with move_on_after(happy_eyeballs_delay):
                await event.wait()

    if connected_stream is None:
        # Chain a single failure directly; group multiple failures together
        cause = oserrors[0] if len(oserrors) == 1 else asynclib.ExceptionGroup(oserrors)
        raise OSError("All connection attempts failed") from cause

    if tls or tls_hostname or ssl_context:
        try:
            return await TLSStream.wrap(
                connected_stream,
                server_side=False,
                hostname=tls_hostname or str(remote_host),
                ssl_context=ssl_context,
                standard_compatible=tls_standard_compatible,
            )
        except BaseException:
            # Handshake failed or was cancelled: don't leak the underlying TCP stream
            await aclose_forcefully(connected_stream)
            raise

    return connected_stream
async def connect_unix(path: Union[str, "PathLike[str]"]) -> UNIXSocketStream:
    """
    Open a stream connection to a UNIX domain socket.

    Not available on Windows.

    :param path: file system path of the socket to connect to
    :return: a socket stream object

    """
    # Normalize str / PathLike input into a plain string path for the backend
    socket_path = str(Path(path))
    backend = get_asynclib()
    return await backend.connect_unix(socket_path)
async def create_tcp_listener(
    *,
    local_host: Optional[IPAddressType] = None,
    local_port: int = 0,
    family: AnyIPAddressFamily = socket.AddressFamily.AF_UNSPEC,
    backlog: int = 65536,
    reuse_port: bool = False,
) -> MultiListener[SocketStream]:
    """
    Create a TCP socket listener.

    One listening socket is created for every distinct address returned by resolving
    ``local_host`` (or the wildcard address, if omitted), and the resulting listeners are
    combined into a single :class:`~anyio.streams.stapled.MultiListener`.

    :param local_port: port number to listen on
    :param local_host: IP address of the interface to listen on. If omitted, listen on all IPv4
        and IPv6 interfaces. To listen on all interfaces on a specific address family, use
        ``0.0.0.0`` for IPv4 or ``::`` for IPv6.
    :param family: address family to restrict the address lookup to (``AF_UNSPEC`` allows both
        IPv4 and IPv6); mostly useful when ``local_host`` was omitted
    :param backlog: maximum number of queued incoming connections (up to a maximum of 2**16, or
        65536)
    :param reuse_port: ``True`` to allow multiple sockets to bind to the same address/port
        (not supported on Windows)
    :return: a multi-listener combining the listeners of all bound sockets
    :raises OSError: if resolving, binding or listening fails; any sockets and listeners created
        so far are closed before the error propagates

    """
    asynclib = get_asynclib()
    backlog = min(backlog, 65536)
    local_host = str(local_host) if local_host is not None else None
    gai_res = await getaddrinfo(
        local_host,  # type: ignore[arg-type]
        local_port,
        family=family,
        type=socket.SOCK_STREAM,
        flags=socket.AI_PASSIVE | socket.AI_ADDRCONFIG,
    )
    listeners: List[SocketListener] = []
    try:
        # The set() is here to work around a glibc bug:
        # https://sourceware.org/bugzilla/show_bug.cgi?id=14969
        for fam, *_, sockaddr in sorted(set(gai_res)):
            raw_socket = socket.socket(fam)
            try:
                raw_socket.setblocking(False)

                # For Windows, enable exclusive address use. For others, enable address reuse.
                if sys.platform == "win32":
                    raw_socket.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1)
                else:
                    raw_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

                if reuse_port:
                    raw_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)

                # Make IPv6 sockets IPv6-only (no dual stack), so that they do not conflict
                # with an IPv4 socket bound to the same port for another resolved address
                if fam == socket.AF_INET6:
                    raw_socket.setsockopt(IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)

                raw_socket.bind(sockaddr)
                raw_socket.listen(backlog)
                listener = asynclib.TCPSocketListener(raw_socket)
            except BaseException:
                # The socket is not owned by a listener yet, so the outer cleanup would not
                # close it (e.g. when bind() fails with EADDRINUSE); close it here
                raw_socket.close()
                raise

            listeners.append(listener)
    except BaseException:
        for listener in listeners:
            await listener.aclose()

        raise

    return MultiListener(listeners)
async def create_unix_listener(
    path: Union[str, "PathLike[str]"],
    *,
    mode: Optional[int] = None,
    backlog: int = 65536,
) -> SocketListener:
    """
    Create a UNIX socket listener.

    Not available on Windows.

    :param path: path of the socket
    :param mode: permissions to set on the socket
    :param backlog: maximum number of queued incoming connections (up to a maximum of 2**16, or
        65536)
    :return: a listener object

    .. versionchanged:: 3.0
        If a socket already exists on the file system in the given path, it will be removed first.

    """
    # Convert via Path (like connect_unix does) so that any os.PathLike yields its real file
    # system path; str() on an arbitrary PathLike may return something else entirely
    path = Path(path)
    path_str = str(path)
    if path.is_socket():
        path.unlink()

    backlog = min(backlog, 65536)
    raw_socket = socket.socket(socket.AF_UNIX)
    try:
        raw_socket.setblocking(False)
        await to_thread.run_sync(raw_socket.bind, path_str, cancellable=True)
        if mode is not None:
            await to_thread.run_sync(chmod, path_str, mode, cancellable=True)

        raw_socket.listen(backlog)
        return get_asynclib().UNIXSocketListener(raw_socket)
    except BaseException:
        # Don't leak the socket if binding/listening fails or the operation is cancelled
        raw_socket.close()
        raise
async def create_udp_socket(
    family: AnyIPAddressFamily = AddressFamily.AF_UNSPEC,
    *,
    local_host: Optional[IPAddressType] = None,
    local_port: int = 0,
    reuse_port: bool = False,
) -> UDPSocket:
    """
    Create a UDP socket.

    If ``local_port`` has been given, the socket will be bound to this port on the local
    machine, making this socket suitable for providing UDP based services.

    :param family: address family (``AF_INET`` or ``AF_INET6``) – automatically determined from
        ``local_host`` if omitted
    :param local_host: IP address or host name of the local interface to bind to; if omitted,
        the socket is bound to the wildcard address (``::`` or ``0.0.0.0``) of ``family``
    :param local_port: local port to bind to
    :param reuse_port: ``True`` to allow multiple sockets to bind to the same address/port
        (not supported on Windows)
    :return: a UDP socket
    :raises ValueError: if neither ``family`` nor ``local_host`` is given

    """
    if family == AddressFamily.AF_UNSPEC and not local_host:
        raise ValueError('Either "family" or "local_host" must be given')

    if local_host:
        gai_res = await getaddrinfo(
            str(local_host),
            local_port,
            family=family,
            type=socket.SOCK_DGRAM,
            flags=socket.AI_PASSIVE | socket.AI_ADDRCONFIG,
        )
        family = cast(AnyIPAddressFamily, gai_res[0][0])
        local_address = gai_res[0][-1]
    elif family == AddressFamily.AF_INET6:
        # Previously port 0 was hard-coded here, silently ignoring local_port
        local_address = ("::", local_port)
    else:
        local_address = ("0.0.0.0", local_port)

    return await get_asynclib().create_udp_socket(
        family, local_address, None, reuse_port
    )
async def create_connected_udp_socket(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    family: AnyIPAddressFamily = AddressFamily.AF_UNSPEC,
    local_host: Optional[IPAddressType] = None,
    local_port: int = 0,
    reuse_port: bool = False,
) -> ConnectedUDPSocket:
    """
    Create a connected UDP socket.

    Connected UDP sockets can only communicate with the specified remote host/port, and any packets
    sent from other sources are dropped.

    :param remote_host: remote host to set as the default target
    :param remote_port: port on the remote host to set as the default target
    :param family: address family (``AF_INET`` or ``AF_INET6``) – automatically determined from
        ``local_host`` or ``remote_host`` if omitted
    :param local_host: IP address or host name of the local interface to bind to; if omitted
        but ``local_port`` is nonzero, the wildcard address of the resolved family is used
    :param local_port: local port to bind to
    :param reuse_port: ``True`` to allow multiple sockets to bind to the same address/port
        (not supported on Windows)
    :return: a connected UDP socket

    """
    local_address = None
    if local_host:
        gai_res = await getaddrinfo(
            str(local_host),
            local_port,
            family=family,
            type=socket.SOCK_DGRAM,
            flags=socket.AI_PASSIVE | socket.AI_ADDRCONFIG,
        )
        family = cast(AnyIPAddressFamily, gai_res[0][0])
        local_address = gai_res[0][-1]

    # The remote lookup is restricted to the local address' family (if one was resolved)
    gai_res = await getaddrinfo(
        str(remote_host), remote_port, family=family, type=socket.SOCK_DGRAM
    )
    family = cast(AnyIPAddressFamily, gai_res[0][0])
    remote_address = gai_res[0][-1]

    if local_address is None and local_port:
        # A local port without a local host: bind to the wildcard address of the family the
        # remote address resolved to, instead of silently ignoring local_port
        wildcard_host = "::" if family == AddressFamily.AF_INET6 else "0.0.0.0"
        local_address = (wildcard_host, local_port)

    return await get_asynclib().create_udp_socket(
        family, local_address, remote_address, reuse_port
    )
async def getaddrinfo(
    host: Union[bytearray, bytes, str],
    port: Union[str, int, None],
    *,
    family: Union[int, AddressFamily] = 0,
    type: Union[int, SocketKind] = 0,
    proto: int = 0,
    flags: int = 0,
) -> GetAddrInfoReturnType:
    """
    Look up a numeric IP address given a host name.

    Internationalized domain names are translated according to the (non-transitional) IDNA 2008
    standard.

    .. note:: 4-tuple IPv6 socket addresses are automatically converted to 2-tuples of
        (host, port), unlike what :func:`socket.getaddrinfo` does. A nonzero scope ID is kept
        by appending it to the host as ``host%scope_id``.

    :param host: host name
    :param port: port number, service name, or ``None``
    :param family: socket family (``AF_INET``, ...)
    :param type: socket type (``SOCK_STREAM``, ...)
    :param proto: protocol number
    :param flags: flags to pass to upstream ``getaddrinfo()``
    :return: list of tuples containing (family, type, proto, canonname, sockaddr)

    .. seealso:: :func:`socket.getaddrinfo`

    """
    # Handle unicode hostnames: plain ASCII names pass through, others are IDNA-encoded
    if isinstance(host, str):
        try:
            encoded_host = host.encode("ascii")
        except UnicodeEncodeError:
            import idna

            encoded_host = idna.encode(host, uts46=True)
    else:
        encoded_host = host

    gai_res = await get_asynclib().getaddrinfo(
        encoded_host, port, family=family, type=type, proto=proto, flags=flags
    )
    # Distinct names so the loop variables don't read as the function's own parameters
    return [
        (res_family, res_type, res_proto, canonname, convert_ipv6_sockaddr(sockaddr))
        for res_family, res_type, res_proto, canonname, sockaddr in gai_res
    ]
def getnameinfo(sockaddr: IPSockAddrType, flags: int = 0) -> Awaitable[Tuple[str, str]]:
    """
    Resolve an IP socket address back into a host name.

    :param sockaddr: socket address to look up (e.g. (ipaddress, port) for IPv4)
    :param flags: flags forwarded to the upstream ``getnameinfo()``
    :return: an awaitable yielding a (host name, service name) tuple

    .. seealso:: :func:`socket.getnameinfo`

    """
    backend = get_asynclib()
    return backend.getnameinfo(sockaddr, flags)
def wait_socket_readable(sock: socket.socket) -> Awaitable[None]:
    """
    Return an awaitable that completes once ``sock`` has data available to read.

    This does **NOT** work on Windows when using the asyncio backend with a proactor event loop
    (default on py3.8+).

    .. warning:: Only use this on raw sockets that have not been wrapped by any higher level
        constructs like socket streams!

    :param sock: the raw socket to wait on
    :raises ~anyio.ClosedResourceError: if the socket was closed while waiting for it to become
        readable
    :raises ~anyio.BusyResourceError: if another task is already waiting for the socket to
        become readable

    """
    backend = get_asynclib()
    return backend.wait_socket_readable(sock)
def wait_socket_writable(sock: socket.socket) -> Awaitable[None]:
    """
    Return an awaitable that completes once ``sock`` can be written to.

    This does **NOT** work on Windows when using the asyncio backend with a proactor event loop
    (default on py3.8+).

    .. warning:: Only use this on raw sockets that have not been wrapped by any higher level
        constructs like socket streams!

    :param sock: the raw socket to wait on
    :raises ~anyio.ClosedResourceError: if the socket was closed while waiting for it to become
        writable
    :raises ~anyio.BusyResourceError: if another task is already waiting for the socket to
        become writable

    """
    backend = get_asynclib()
    return backend.wait_socket_writable(sock)
559#
560# Private API
561#
def convert_ipv6_sockaddr(
    sockaddr: Union[Tuple[str, int, int, int], Tuple[str, int]]
) -> Tuple[str, int]:
    """
    Flatten a 4-tuple IPv6 socket address into a ``(host, port)`` 2-tuple.

    A nonzero scope ID is preserved by appending it to the host as ``host%scope_id``; the flow
    info (and a zero scope ID) are dropped. Anything that is not a 4-tuple is returned as-is.

    :param sockaddr: the result of :meth:`~socket.socket.getsockname`
    :return: the converted socket address

    """
    is_ipv6_quad = isinstance(sockaddr, tuple) and len(sockaddr) == 4
    if not is_ipv6_quad:
        return cast(Tuple[str, int], sockaddr)

    # cast() only informs MyPy; the runtime check above already established the shape
    address, port_number, _flowinfo, scope = cast(Tuple[str, int, int, int], sockaddr)
    if not scope:
        return address, port_number

    return f"{address}%{scope}", port_number