"""Base implementation of ``start_connection`` (sequential or Happy Eyeballs connects)."""
    2 
    3import asyncio 
    4import collections 
    5import contextlib 
    6import functools 
    7import itertools 
    8import socket 
    9from typing import List, Optional, Sequence, Set, Union 
    10 
    11from . import _staggered 
    12from .types import AddrInfoType, SocketFactoryType 
    13 
    14 
    15async def start_connection( 
    16    addr_infos: Sequence[AddrInfoType], 
    17    *, 
    18    local_addr_infos: Optional[Sequence[AddrInfoType]] = None, 
    19    happy_eyeballs_delay: Optional[float] = None, 
    20    interleave: Optional[int] = None, 
    21    loop: Optional[asyncio.AbstractEventLoop] = None, 
    22    socket_factory: Optional[SocketFactoryType] = None, 
    23) -> socket.socket: 
    24    """ 
    25    Connect to a TCP server. 
    26 
    27    Create a socket connection to a specified destination.  The 
    28    destination is specified as a list of AddrInfoType tuples as 
    29    returned from getaddrinfo(). 
    30 
    31    The arguments are, in order: 
    32 
    33    * ``family``: the address family, e.g. ``socket.AF_INET`` or 
    34        ``socket.AF_INET6``. 
    35    * ``type``: the socket type, e.g. ``socket.SOCK_STREAM`` or 
    36        ``socket.SOCK_DGRAM``. 
    37    * ``proto``: the protocol, e.g. ``socket.IPPROTO_TCP`` or 
    38        ``socket.IPPROTO_UDP``. 
    39    * ``canonname``: the canonical name of the address, e.g. 
    40        ``"www.python.org"``. 
    41    * ``sockaddr``: the socket address 
    42 
    43    This method is a coroutine which will try to establish the connection 
    44    in the background. When successful, the coroutine returns a 
    45    socket. 
    46 
    47    The expected use case is to use this method in conjunction with 
    48    loop.create_connection() to establish a connection to a server:: 
    49 
    50            socket = await start_connection(addr_infos) 
    51            transport, protocol = await loop.create_connection( 
    52                MyProtocol, sock=socket, ...) 
    53    """ 
    54    if not (current_loop := loop): 
    55        current_loop = asyncio.get_running_loop() 
    56 
    57    single_addr_info = len(addr_infos) == 1 
    58 
    59    if happy_eyeballs_delay is not None and interleave is None: 
    60        # If using happy eyeballs, default to interleave addresses by family 
    61        interleave = 1 
    62 
    63    if interleave and not single_addr_info: 
    64        addr_infos = _interleave_addrinfos(addr_infos, interleave) 
    65 
    66    sock: Optional[socket.socket] = None 
    67    # uvloop can raise RuntimeError instead of OSError 
    68    exceptions: List[List[Union[OSError, RuntimeError]]] = [] 
    69    if happy_eyeballs_delay is None or single_addr_info: 
    70        # not using happy eyeballs 
    71        for addrinfo in addr_infos: 
    72            try: 
    73                sock = await _connect_sock( 
    74                    current_loop, 
    75                    exceptions, 
    76                    addrinfo, 
    77                    local_addr_infos, 
    78                    None, 
    79                    socket_factory, 
    80                ) 
    81                break 
    82            except (RuntimeError, OSError): 
    83                continue 
    84    else:  # using happy eyeballs 
    85        open_sockets: Set[socket.socket] = set() 
    86        try: 
    87            sock, _, _ = await _staggered.staggered_race( 
    88                ( 
    89                    functools.partial( 
    90                        _connect_sock, 
    91                        current_loop, 
    92                        exceptions, 
    93                        addrinfo, 
    94                        local_addr_infos, 
    95                        open_sockets, 
    96                        socket_factory, 
    97                    ) 
    98                    for addrinfo in addr_infos 
    99                ), 
    100                happy_eyeballs_delay, 
    101            ) 
    102        finally: 
    103            # If we have a winner, staggered_race will 
    104            # cancel the other tasks, however there is a 
    105            # small race window where any of the other tasks 
    106            # can be done before they are cancelled which 
    107            # will leave the socket open. To avoid this problem 
    108            # we pass a set to _connect_sock to keep track of 
    109            # the open sockets and close them here if there 
    110            # are any "runner up" sockets. 
    111            for s in open_sockets: 
    112                if s is not sock: 
    113                    with contextlib.suppress(OSError): 
    114                        s.close() 
    115            open_sockets = None  # type: ignore[assignment] 
    116 
    117    if sock is None: 
    118        all_exceptions = [exc for sub in exceptions for exc in sub] 
    119        try: 
    120            first_exception = all_exceptions[0] 
    121            if len(all_exceptions) == 1: 
    122                raise first_exception 
    123            else: 
    124                # If they all have the same str(), raise one. 
    125                model = str(first_exception) 
    126                if all(str(exc) == model for exc in all_exceptions): 
    127                    raise first_exception 
    128                # Raise a combined exception so the user can see all 
    129                # the various error messages. 
    130                msg = "Multiple exceptions: {}".format( 
    131                    ", ".join(str(exc) for exc in all_exceptions) 
    132                ) 
    133                # If the errno is the same for all exceptions, raise 
    134                # an OSError with that errno. 
    135                if isinstance(first_exception, OSError): 
    136                    first_errno = first_exception.errno 
    137                    if all( 
    138                        isinstance(exc, OSError) and exc.errno == first_errno 
    139                        for exc in all_exceptions 
    140                    ): 
    141                        raise OSError(first_errno, msg) 
    142                elif isinstance(first_exception, RuntimeError) and all( 
    143                    isinstance(exc, RuntimeError) for exc in all_exceptions 
    144                ): 
    145                    raise RuntimeError(msg) 
    146                # We have a mix of OSError and RuntimeError 
    147                # so we have to pick which one to raise. 
    148                # and we raise OSError for compatibility 
    149                raise OSError(msg) 
    150        finally: 
    151            all_exceptions = None  # type: ignore[assignment] 
    152            exceptions = None  # type: ignore[assignment] 
    153 
    154    return sock 
    155 
    156 
    157async def _connect_sock( 
    158    loop: asyncio.AbstractEventLoop, 
    159    exceptions: List[List[Union[OSError, RuntimeError]]], 
    160    addr_info: AddrInfoType, 
    161    local_addr_infos: Optional[Sequence[AddrInfoType]] = None, 
    162    open_sockets: Optional[Set[socket.socket]] = None, 
    163    socket_factory: Optional[SocketFactoryType] = None, 
    164) -> socket.socket: 
    165    """ 
    166    Create, bind and connect one socket. 
    167 
    168    If open_sockets is passed, add the socket to the set of open sockets. 
    169    Any failure caught here will remove the socket from the set and close it. 
    170 
    171    Callers can use this set to close any sockets that are not the winner 
    172    of all staggered tasks in the result there are runner up sockets aka 
    173    multiple winners. 
    174    """ 
    175    my_exceptions: List[Union[OSError, RuntimeError]] = [] 
    176    exceptions.append(my_exceptions) 
    177    family, type_, proto, _, address = addr_info 
    178    sock = None 
    179    try: 
    180        if socket_factory is not None: 
    181            sock = socket_factory(addr_info) 
    182        else: 
    183            sock = socket.socket(family=family, type=type_, proto=proto) 
    184        if open_sockets is not None: 
    185            open_sockets.add(sock) 
    186        sock.setblocking(False) 
    187        if local_addr_infos is not None: 
    188            for lfamily, _, _, _, laddr in local_addr_infos: 
    189                # skip local addresses of different family 
    190                if lfamily != family: 
    191                    continue 
    192                try: 
    193                    sock.bind(laddr) 
    194                    break 
    195                except OSError as exc: 
    196                    msg = ( 
    197                        f"error while attempting to bind on " 
    198                        f"address {laddr!r}: " 
    199                        f"{(exc.strerror or '').lower()}" 
    200                    ) 
    201                    exc = OSError(exc.errno, msg) 
    202                    my_exceptions.append(exc) 
    203            else:  # all bind attempts failed 
    204                if my_exceptions: 
    205                    raise my_exceptions.pop() 
    206                else: 
    207                    raise OSError(f"no matching local address with {family=} found") 
    208        await loop.sock_connect(sock, address) 
    209        return sock 
    210    except (RuntimeError, OSError) as exc: 
    211        my_exceptions.append(exc) 
    212        if sock is not None: 
    213            if open_sockets is not None: 
    214                open_sockets.remove(sock) 
    215            try: 
    216                sock.close() 
    217            except OSError as e: 
    218                my_exceptions.append(e) 
    219                raise 
    220        raise 
    221    except: 
    222        if sock is not None: 
    223            if open_sockets is not None: 
    224                open_sockets.remove(sock) 
    225            try: 
    226                sock.close() 
    227            except OSError as e: 
    228                my_exceptions.append(e) 
    229                raise 
    230        raise 
    231    finally: 
    232        exceptions = my_exceptions = None  # type: ignore[assignment] 
    233 
    234 
    235def _interleave_addrinfos( 
    236    addrinfos: Sequence[AddrInfoType], first_address_family_count: int = 1 
    237) -> List[AddrInfoType]: 
    238    """Interleave list of addrinfo tuples by family.""" 
    239    # Group addresses by family 
    240    addrinfos_by_family: collections.OrderedDict[int, List[AddrInfoType]] = ( 
    241        collections.OrderedDict() 
    242    ) 
    243    for addr in addrinfos: 
    244        family = addr[0] 
    245        if family not in addrinfos_by_family: 
    246            addrinfos_by_family[family] = [] 
    247        addrinfos_by_family[family].append(addr) 
    248    addrinfos_lists = list(addrinfos_by_family.values()) 
    249 
    250    reordered: List[AddrInfoType] = [] 
    251    if first_address_family_count > 1: 
    252        reordered.extend(addrinfos_lists[0][: first_address_family_count - 1]) 
    253        del addrinfos_lists[0][: first_address_family_count - 1] 
    254    reordered.extend( 
    255        a 
    256        for a in itertools.chain.from_iterable(itertools.zip_longest(*addrinfos_lists)) 
    257        if a is not None 
    258    ) 
    259    return reordered