Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/aiohappyeyeballs/impl.py: 11%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

120 statements  

1"""Base implementation.""" 

2 

3import asyncio 

4import collections 

5import contextlib 

6import functools 

7import itertools 

8import socket 

9from collections.abc import Sequence 

10 

11from . import _staggered 

12from .types import AddrInfoType, SocketFactoryType 

13 

14 

15async def start_connection( 

16 addr_infos: Sequence[AddrInfoType], 

17 *, 

18 local_addr_infos: Sequence[AddrInfoType] | None = None, 

19 happy_eyeballs_delay: float | None = None, 

20 interleave: int | None = None, 

21 loop: asyncio.AbstractEventLoop | None = None, 

22 socket_factory: SocketFactoryType | None = None, 

23) -> socket.socket: 

24 """ 

25 Connect to a TCP server. 

26 

27 Create a socket connection to a specified destination. The 

28 destination is specified as a list of AddrInfoType tuples as 

29 returned from getaddrinfo(). 

30 

31 The arguments are, in order: 

32 

33 * ``family``: the address family, e.g. ``socket.AF_INET`` or 

34 ``socket.AF_INET6``. 

35 * ``type``: the socket type, e.g. ``socket.SOCK_STREAM`` or 

36 ``socket.SOCK_DGRAM``. 

37 * ``proto``: the protocol, e.g. ``socket.IPPROTO_TCP`` or 

38 ``socket.IPPROTO_UDP``. 

39 * ``canonname``: the canonical name of the address, e.g. 

40 ``"www.python.org"``. 

41 * ``sockaddr``: the socket address 

42 

43 This method is a coroutine which will try to establish the connection 

44 in the background. When successful, the coroutine returns a 

45 socket. 

46 

47 The expected use case is to use this method in conjunction with 

48 loop.create_connection() to establish a connection to a server:: 

49 

50 socket = await start_connection(addr_infos) 

51 transport, protocol = await loop.create_connection( 

52 MyProtocol, sock=socket, ...) 

53 """ 

54 if not addr_infos: 

55 raise ValueError("addr_infos must not be empty") 

56 

57 current_loop = loop or asyncio.get_running_loop() 

58 

59 single_addr_info = len(addr_infos) == 1 

60 

61 if happy_eyeballs_delay is not None and interleave is None: 

62 # If using happy eyeballs, default to interleave addresses by family 

63 interleave = 1 

64 

65 if interleave and not single_addr_info: 

66 addr_infos = _interleave_addrinfos(addr_infos, interleave) 

67 

68 sock: socket.socket | None = None 

69 # uvloop can raise RuntimeError instead of OSError 

70 exceptions: list[list[OSError | RuntimeError]] = [] 

71 if happy_eyeballs_delay is None or single_addr_info: 

72 # not using happy eyeballs 

73 for addrinfo in addr_infos: 

74 try: 

75 sock = await _connect_sock( 

76 current_loop, 

77 exceptions, 

78 addrinfo, 

79 local_addr_infos, 

80 None, 

81 socket_factory, 

82 ) 

83 break 

84 except (RuntimeError, OSError): 

85 continue 

86 else: # using happy eyeballs 

87 open_sockets: set[socket.socket] = set() 

88 try: 

89 sock, _, _ = await _staggered.staggered_race( 

90 ( 

91 functools.partial( 

92 _connect_sock, 

93 current_loop, 

94 exceptions, 

95 addrinfo, 

96 local_addr_infos, 

97 open_sockets, 

98 socket_factory, 

99 ) 

100 for addrinfo in addr_infos 

101 ), 

102 happy_eyeballs_delay, 

103 ) 

104 finally: 

105 # If we have a winner, staggered_race will 

106 # cancel the other tasks, however there is a 

107 # small race window where any of the other tasks 

108 # can be done before they are cancelled which 

109 # will leave the socket open. To avoid this problem 

110 # we pass a set to _connect_sock to keep track of 

111 # the open sockets and close them here if there 

112 # are any "runner up" sockets. 

113 for s in open_sockets: 

114 if s is not sock: 

115 with contextlib.suppress(OSError): 

116 s.close() 

117 open_sockets = None # type: ignore[assignment] 

118 

119 if sock is None: 

120 all_exceptions = [exc for sub in exceptions for exc in sub] 

121 try: 

122 first_exception = all_exceptions[0] 

123 if len(all_exceptions) == 1: 

124 raise first_exception 

125 else: 

126 # If they all have the same str(), raise one. 

127 model = str(first_exception) 

128 if all(str(exc) == model for exc in all_exceptions): 

129 raise first_exception 

130 # Raise a combined exception so the user can see all 

131 # the various error messages. 

132 msg = "Multiple exceptions: {}".format( 

133 ", ".join(str(exc) for exc in all_exceptions) 

134 ) 

135 # If the errno is the same for all exceptions, raise 

136 # an OSError with that errno. 

137 if isinstance(first_exception, OSError): 

138 first_errno = first_exception.errno 

139 if all( 

140 isinstance(exc, OSError) and exc.errno == first_errno 

141 for exc in all_exceptions 

142 ): 

143 raise OSError(first_errno, msg) 

144 elif isinstance(first_exception, RuntimeError) and all( 

145 isinstance(exc, RuntimeError) for exc in all_exceptions 

146 ): 

147 raise RuntimeError(msg) 

148 # We have a mix of OSError and RuntimeError 

149 # so we have to pick which one to raise. 

150 # and we raise OSError for compatibility 

151 raise OSError(msg) 

152 finally: 

153 all_exceptions = None # type: ignore[assignment] 

154 exceptions = None # type: ignore[assignment] 

155 

156 return sock 

157 

158 

159async def _connect_sock( 

160 loop: asyncio.AbstractEventLoop, 

161 exceptions: list[list[OSError | RuntimeError]], 

162 addr_info: AddrInfoType, 

163 local_addr_infos: Sequence[AddrInfoType] | None = None, 

164 open_sockets: set[socket.socket] | None = None, 

165 socket_factory: SocketFactoryType | None = None, 

166) -> socket.socket: 

167 """ 

168 Create, bind and connect one socket. 

169 

170 If open_sockets is passed, add the socket to the set of open sockets. 

171 Any failure caught here will remove the socket from the set and close it. 

172 

173 Callers can use this set to close any sockets that are not the winner 

174 of all staggered tasks in the result there are runner up sockets aka 

175 multiple winners. 

176 """ 

177 my_exceptions: list[OSError | RuntimeError] = [] 

178 exceptions.append(my_exceptions) 

179 family, type_, proto, _, address = addr_info 

180 sock = None 

181 try: 

182 if socket_factory is not None: 

183 sock = socket_factory(addr_info) 

184 else: 

185 sock = socket.socket(family=family, type=type_, proto=proto) 

186 if open_sockets is not None: 

187 open_sockets.add(sock) 

188 sock.setblocking(False) 

189 if local_addr_infos is not None: 

190 for lfamily, _, _, _, laddr in local_addr_infos: 

191 # skip local addresses of different family 

192 if lfamily != family: 

193 continue 

194 try: 

195 sock.bind(laddr) 

196 break 

197 except OSError as exc: 

198 msg = ( 

199 f"error while attempting to bind on " 

200 f"address {laddr!r}: " 

201 f"{(exc.strerror or '').lower()}" 

202 ) 

203 exc = OSError(exc.errno, msg) 

204 my_exceptions.append(exc) 

205 else: # all bind attempts failed 

206 if my_exceptions: 

207 raise my_exceptions.pop() 

208 else: 

209 raise OSError(f"no matching local address with {family=} found") 

210 await loop.sock_connect(sock, address) 

211 return sock 

212 except (RuntimeError, OSError) as exc: 

213 my_exceptions.append(exc) 

214 if sock is not None: 

215 if open_sockets is not None: 

216 open_sockets.remove(sock) 

217 try: 

218 sock.close() 

219 except OSError as e: 

220 my_exceptions.append(e) 

221 raise 

222 raise 

223 except: 

224 if sock is not None: 

225 if open_sockets is not None: 

226 open_sockets.remove(sock) 

227 try: 

228 sock.close() 

229 except OSError as e: 

230 my_exceptions.append(e) 

231 raise 

232 raise 

233 finally: 

234 exceptions = my_exceptions = None # type: ignore[assignment] 

235 

236 

def _interleave_addrinfos(
    addrinfos: Sequence[AddrInfoType], first_address_family_count: int = 1
) -> list[AddrInfoType]:
    """
    Interleave list of addrinfo tuples by family.

    Addresses are grouped by family (element 0 of each tuple), keeping the
    order in which families first appear. The first
    ``first_address_family_count - 1`` entries of the first family are
    emitted as-is. After that the groups are merged round-robin, one entry
    per family at a time. Relative order within each family is preserved.

    An empty input returns an empty list.
    """
    if not addrinfos:
        # Without this guard, addrinfos_lists[0] below raises IndexError
        # whenever first_address_family_count > 1.
        return []

    # Group addresses by family; dicts (and defaultdict) preserve insertion
    # order, so families keep the order they are first seen in.
    addrinfos_by_family: collections.defaultdict[int, list[AddrInfoType]] = (
        collections.defaultdict(list)
    )
    for addr in addrinfos:
        addrinfos_by_family[addr[0]].append(addr)
    addrinfos_lists = list(addrinfos_by_family.values())

    reordered: list[AddrInfoType] = []
    if first_address_family_count > 1:
        head = first_address_family_count - 1
        reordered.extend(addrinfos_lists[0][:head])
        del addrinfos_lists[0][:head]
    # zip_longest pads shorter families with None; addrinfo tuples are
    # never None, so filtering None drops only the padding.
    reordered.extend(
        a
        for a in itertools.chain.from_iterable(itertools.zip_longest(*addrinfos_lists))
        if a is not None
    )
    return reordered