1"""Base implementation."""
2
3import asyncio
4import collections
5import contextlib
6import functools
7import itertools
8import socket
9from collections.abc import Sequence
10
11from . import _staggered
12from .types import AddrInfoType, SocketFactoryType
13
14
15async def start_connection(
16 addr_infos: Sequence[AddrInfoType],
17 *,
18 local_addr_infos: Sequence[AddrInfoType] | None = None,
19 happy_eyeballs_delay: float | None = None,
20 interleave: int | None = None,
21 loop: asyncio.AbstractEventLoop | None = None,
22 socket_factory: SocketFactoryType | None = None,
23) -> socket.socket:
24 """
25 Connect to a TCP server.
26
27 Create a socket connection to a specified destination. The
28 destination is specified as a list of AddrInfoType tuples as
29 returned from getaddrinfo().
30
31 The arguments are, in order:
32
33 * ``family``: the address family, e.g. ``socket.AF_INET`` or
34 ``socket.AF_INET6``.
35 * ``type``: the socket type, e.g. ``socket.SOCK_STREAM`` or
36 ``socket.SOCK_DGRAM``.
37 * ``proto``: the protocol, e.g. ``socket.IPPROTO_TCP`` or
38 ``socket.IPPROTO_UDP``.
39 * ``canonname``: the canonical name of the address, e.g.
40 ``"www.python.org"``.
41 * ``sockaddr``: the socket address
42
43 This method is a coroutine which will try to establish the connection
44 in the background. When successful, the coroutine returns a
45 socket.
46
47 The expected use case is to use this method in conjunction with
48 loop.create_connection() to establish a connection to a server::
49
50 socket = await start_connection(addr_infos)
51 transport, protocol = await loop.create_connection(
52 MyProtocol, sock=socket, ...)
53 """
54 if not addr_infos:
55 raise ValueError("addr_infos must not be empty")
56
57 current_loop = loop or asyncio.get_running_loop()
58
59 single_addr_info = len(addr_infos) == 1
60
61 if happy_eyeballs_delay is not None and interleave is None:
62 # If using happy eyeballs, default to interleave addresses by family
63 interleave = 1
64
65 if interleave and not single_addr_info:
66 addr_infos = _interleave_addrinfos(addr_infos, interleave)
67
68 sock: socket.socket | None = None
69 # uvloop can raise RuntimeError instead of OSError
70 exceptions: list[list[OSError | RuntimeError]] = []
71 if happy_eyeballs_delay is None or single_addr_info:
72 # not using happy eyeballs
73 for addrinfo in addr_infos:
74 try:
75 sock = await _connect_sock(
76 current_loop,
77 exceptions,
78 addrinfo,
79 local_addr_infos,
80 None,
81 socket_factory,
82 )
83 break
84 except (RuntimeError, OSError):
85 continue
86 else: # using happy eyeballs
87 open_sockets: set[socket.socket] = set()
88 try:
89 sock, _, _ = await _staggered.staggered_race(
90 (
91 functools.partial(
92 _connect_sock,
93 current_loop,
94 exceptions,
95 addrinfo,
96 local_addr_infos,
97 open_sockets,
98 socket_factory,
99 )
100 for addrinfo in addr_infos
101 ),
102 happy_eyeballs_delay,
103 )
104 finally:
105 # If we have a winner, staggered_race will
106 # cancel the other tasks, however there is a
107 # small race window where any of the other tasks
108 # can be done before they are cancelled which
109 # will leave the socket open. To avoid this problem
110 # we pass a set to _connect_sock to keep track of
111 # the open sockets and close them here if there
112 # are any "runner up" sockets.
113 for s in open_sockets:
114 if s is not sock:
115 with contextlib.suppress(OSError):
116 s.close()
117 open_sockets = None # type: ignore[assignment]
118
119 if sock is None:
120 all_exceptions = [exc for sub in exceptions for exc in sub]
121 try:
122 first_exception = all_exceptions[0]
123 if len(all_exceptions) == 1:
124 raise first_exception
125 else:
126 # If they all have the same str(), raise one.
127 model = str(first_exception)
128 if all(str(exc) == model for exc in all_exceptions):
129 raise first_exception
130 # Raise a combined exception so the user can see all
131 # the various error messages.
132 msg = "Multiple exceptions: {}".format(
133 ", ".join(str(exc) for exc in all_exceptions)
134 )
135 # If the errno is the same for all exceptions, raise
136 # an OSError with that errno.
137 if isinstance(first_exception, OSError):
138 first_errno = first_exception.errno
139 if all(
140 isinstance(exc, OSError) and exc.errno == first_errno
141 for exc in all_exceptions
142 ):
143 raise OSError(first_errno, msg)
144 elif isinstance(first_exception, RuntimeError) and all(
145 isinstance(exc, RuntimeError) for exc in all_exceptions
146 ):
147 raise RuntimeError(msg)
148 # We have a mix of OSError and RuntimeError
149 # so we have to pick which one to raise.
150 # and we raise OSError for compatibility
151 raise OSError(msg)
152 finally:
153 all_exceptions = None # type: ignore[assignment]
154 exceptions = None # type: ignore[assignment]
155
156 return sock
157
158
159async def _connect_sock(
160 loop: asyncio.AbstractEventLoop,
161 exceptions: list[list[OSError | RuntimeError]],
162 addr_info: AddrInfoType,
163 local_addr_infos: Sequence[AddrInfoType] | None = None,
164 open_sockets: set[socket.socket] | None = None,
165 socket_factory: SocketFactoryType | None = None,
166) -> socket.socket:
167 """
168 Create, bind and connect one socket.
169
170 If open_sockets is passed, add the socket to the set of open sockets.
171 Any failure caught here will remove the socket from the set and close it.
172
173 Callers can use this set to close any sockets that are not the winner
174 of all staggered tasks in the result there are runner up sockets aka
175 multiple winners.
176 """
177 my_exceptions: list[OSError | RuntimeError] = []
178 exceptions.append(my_exceptions)
179 family, type_, proto, _, address = addr_info
180 sock = None
181 try:
182 if socket_factory is not None:
183 sock = socket_factory(addr_info)
184 else:
185 sock = socket.socket(family=family, type=type_, proto=proto)
186 if open_sockets is not None:
187 open_sockets.add(sock)
188 sock.setblocking(False)
189 if local_addr_infos is not None:
190 for lfamily, _, _, _, laddr in local_addr_infos:
191 # skip local addresses of different family
192 if lfamily != family:
193 continue
194 try:
195 sock.bind(laddr)
196 break
197 except OSError as exc:
198 msg = (
199 f"error while attempting to bind on "
200 f"address {laddr!r}: "
201 f"{(exc.strerror or '').lower()}"
202 )
203 exc = OSError(exc.errno, msg)
204 my_exceptions.append(exc)
205 else: # all bind attempts failed
206 if my_exceptions:
207 raise my_exceptions.pop()
208 else:
209 raise OSError(f"no matching local address with {family=} found")
210 await loop.sock_connect(sock, address)
211 return sock
212 except (RuntimeError, OSError) as exc:
213 my_exceptions.append(exc)
214 if sock is not None:
215 if open_sockets is not None:
216 open_sockets.remove(sock)
217 try:
218 sock.close()
219 except OSError as e:
220 my_exceptions.append(e)
221 raise
222 raise
223 except:
224 if sock is not None:
225 if open_sockets is not None:
226 open_sockets.remove(sock)
227 try:
228 sock.close()
229 except OSError as e:
230 my_exceptions.append(e)
231 raise
232 raise
233 finally:
234 exceptions = my_exceptions = None # type: ignore[assignment]
235
236
def _interleave_addrinfos(
    addrinfos: Sequence[AddrInfoType], first_address_family_count: int = 1
) -> list[AddrInfoType]:
    """
    Interleave a list of addrinfo tuples by address family.

    Addresses are grouped by family (the first element of each tuple).
    Families keep the order in which they first appear, and addresses keep
    their relative order within a family. The first family contributes up
    to ``first_address_family_count`` addresses before the first address of
    any other family. After that, families take turns one address at a
    time until every group is exhausted.

    For example, with families A and B and ``first_address_family_count=1``,
    ``[A1, A2, B1, B2]`` becomes ``[A1, B1, A2, B2]``; with ``2`` it stays
    ``[A1, A2, B1, B2]``.

    :param addrinfos: the tuples to reorder; may be empty.
    :param first_address_family_count: number of first-family addresses to
        try before switching families; values below ``2`` start alternating
        immediately.
    :return: a new list holding every input tuple exactly once.
    """
    # Group addresses by family. defaultdict keeps insertion order, so the
    # family seen first remains first.
    addrinfos_by_family: collections.defaultdict[int, list[AddrInfoType]] = (
        collections.defaultdict(list)
    )
    for addr in addrinfos:
        addrinfos_by_family[addr[0]].append(addr)
    addrinfos_lists = list(addrinfos_by_family.values())

    reordered: list[AddrInfoType] = []
    # Guard on addrinfos_lists: an empty input has no "first family", and
    # indexing it would raise IndexError.
    if first_address_family_count > 1 and addrinfos_lists:
        head = first_address_family_count - 1
        reordered.extend(addrinfos_lists[0][:head])
        del addrinfos_lists[0][:head]
    # zip_longest pads exhausted groups with a private sentinel, which is
    # filtered out again; unlike None it can never match a real entry.
    padding = object()
    reordered.extend(
        addr
        for addr in itertools.chain.from_iterable(
            itertools.zip_longest(*addrinfos_lists, fillvalue=padding)
        )
        if addr is not padding
    )
    return reordered