Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/aiohappyeyeballs/impl.py: 11%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

119 statements  

1"""Base implementation.""" 

2 

3import asyncio 

4import collections 

5import contextlib 

6import functools 

7import itertools 

8import socket 

9from typing import List, Optional, Sequence, Set, Union 

10 

11from . import _staggered 

12from .types import AddrInfoType, SocketFactoryType 

13 

14 

async def start_connection(
    addr_infos: Sequence[AddrInfoType],
    *,
    local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
    happy_eyeballs_delay: Optional[float] = None,
    interleave: Optional[int] = None,
    loop: Optional[asyncio.AbstractEventLoop] = None,
    socket_factory: Optional[SocketFactoryType] = None,
) -> socket.socket:
    """
    Connect to a TCP server.

    Create a socket connection to a specified destination. The
    destination is specified as a list of AddrInfoType tuples as
    returned from getaddrinfo(). Each tuple contains, in order:

    * ``family``: the address family, e.g. ``socket.AF_INET`` or
      ``socket.AF_INET6``.
    * ``type``: the socket type, e.g. ``socket.SOCK_STREAM`` or
      ``socket.SOCK_DGRAM``.
    * ``proto``: the protocol, e.g. ``socket.IPPROTO_TCP`` or
      ``socket.IPPROTO_UDP``.
    * ``canonname``: the canonical name of the address, e.g.
      ``"www.python.org"``.
    * ``sockaddr``: the socket address

    Keyword-only arguments:

    * ``local_addr_infos``: local addresses to bind each socket to before
      connecting; only entries whose family matches the remote address
      are tried.
    * ``happy_eyeballs_delay``: if set and more than one address is given,
      the attempts are raced with ``_staggered.staggered_race`` using this
      delay. If ``None``, addresses are tried one after another in order.
    * ``interleave``: if truthy and more than one address is given, the
      addresses are reordered by family with ``_interleave_addrinfos``;
      defaults to 1 when ``happy_eyeballs_delay`` is set.
    * ``loop``: the event loop to use; defaults to the running loop.
    * ``socket_factory``: optional callable that receives an AddrInfoType
      and returns the socket to connect.

    This method is a coroutine which will try to establish the connection
    in the background. When successful, the coroutine returns a
    socket.

    If no connection can be made, the recorded errors are combined: a
    single error (or several errors with identical messages) is re-raised
    as-is; otherwise an ``OSError`` is raised whose message lists all of
    them (carrying the shared errno when every error is an ``OSError``
    with the same errno), or a ``RuntimeError`` when every error is a
    ``RuntimeError``. An ``OSError`` is also raised when no error was
    recorded at all, e.g. because ``addr_infos`` was empty.

    The expected use case is to use this method in conjunction with
    loop.create_connection() to establish a connection to a server::

        sock = await start_connection(addr_infos)
        transport, protocol = await loop.create_connection(
            MyProtocol, sock=sock, ...)
    """
    if not (current_loop := loop):
        current_loop = asyncio.get_running_loop()

    single_addr_info = len(addr_infos) == 1

    if happy_eyeballs_delay is not None and interleave is None:
        # If using happy eyeballs, default to interleave addresses by family
        interleave = 1

    # Interleaving zero or one address is a no-op, so only reorder when
    # there are at least two addresses.
    if interleave and len(addr_infos) > 1:
        addr_infos = _interleave_addrinfos(addr_infos, interleave)

    sock: Optional[socket.socket] = None
    # One inner list per connection attempt; _connect_sock fills it.
    # uvloop can raise RuntimeError instead of OSError
    exceptions: List[List[Union[OSError, RuntimeError]]] = []
    if happy_eyeballs_delay is None or single_addr_info:
        # not using happy eyeballs: try each address sequentially
        for addrinfo in addr_infos:
            try:
                sock = await _connect_sock(
                    current_loop,
                    exceptions,
                    addrinfo,
                    local_addr_infos,
                    None,
                    socket_factory,
                )
                break
            except (RuntimeError, OSError):
                continue
    else:  # using happy eyeballs
        open_sockets: Set[socket.socket] = set()
        try:
            sock, _, _ = await _staggered.staggered_race(
                (
                    functools.partial(
                        _connect_sock,
                        current_loop,
                        exceptions,
                        addrinfo,
                        local_addr_infos,
                        open_sockets,
                        socket_factory,
                    )
                    for addrinfo in addr_infos
                ),
                happy_eyeballs_delay,
            )
        finally:
            # If we have a winner, staggered_race will
            # cancel the other tasks, however there is a
            # small race window where any of the other tasks
            # can be done before they are cancelled which
            # will leave the socket open. To avoid this problem
            # we pass a set to _connect_sock to keep track of
            # the open sockets and close them here if there
            # are any "runner up" sockets.
            for s in open_sockets:
                if s is not sock:
                    with contextlib.suppress(OSError):
                        s.close()
            open_sockets = None  # type: ignore[assignment]

    if sock is None:
        all_exceptions = [exc for sub in exceptions for exc in sub]
        try:
            if not all_exceptions:
                # Nothing was recorded: addr_infos was empty, or the
                # attempts failed with exceptions _connect_sock does not
                # record. Raise an OSError (which callers already handle)
                # instead of an IndexError from indexing the empty list.
                raise OSError("Unable to connect; no connection errors were recorded")
            first_exception = all_exceptions[0]
            if len(all_exceptions) == 1:
                raise first_exception
            else:
                # If they all have the same str(), raise one.
                model = str(first_exception)
                if all(str(exc) == model for exc in all_exceptions):
                    raise first_exception
                # Raise a combined exception so the user can see all
                # the various error messages.
                msg = "Multiple exceptions: {}".format(
                    ", ".join(str(exc) for exc in all_exceptions)
                )
                # If the errno is the same for all exceptions, raise
                # an OSError with that errno.
                if isinstance(first_exception, OSError):
                    first_errno = first_exception.errno
                    if all(
                        isinstance(exc, OSError) and exc.errno == first_errno
                        for exc in all_exceptions
                    ):
                        raise OSError(first_errno, msg)
                elif isinstance(first_exception, RuntimeError) and all(
                    isinstance(exc, RuntimeError) for exc in all_exceptions
                ):
                    raise RuntimeError(msg)
                # We have a mix of OSError and RuntimeError
                # so we have to pick which one to raise.
                # and we raise OSError for compatibility
                raise OSError(msg)
        finally:
            # Drop references to the collected exceptions; presumably this
            # breaks exception -> traceback -> frame -> list cycles.
            all_exceptions = None  # type: ignore[assignment]
            exceptions = None  # type: ignore[assignment]

    return sock

155 

156 

async def _connect_sock(
    loop: asyncio.AbstractEventLoop,
    exceptions: List[List[Union[OSError, RuntimeError]]],
    addr_info: AddrInfoType,
    local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
    open_sockets: Optional[Set[socket.socket]] = None,
    socket_factory: Optional[SocketFactoryType] = None,
) -> socket.socket:
    """
    Create, bind and connect one socket.

    A fresh list is appended to ``exceptions`` for this attempt. Every
    OSError/RuntimeError raised while creating, binding, connecting or
    closing the socket is recorded in it.

    If ``open_sockets`` is passed, the socket is added to it as soon as it
    is created and removed again (and closed) if the attempt fails. Callers
    can use this set to close any sockets that are not the winner of all
    staggered tasks, i.e. "runner up" sockets that connected anyway.

    Args:
        loop: event loop whose ``sock_connect`` performs the connect.
        exceptions: shared list; one per-attempt list is appended to it.
        addr_info: ``(family, type, proto, canonname, sockaddr)`` tuple.
        local_addr_infos: optional local addresses; the first one with a
            matching family that binds successfully is used.
        open_sockets: optional set tracking sockets created by attempts.
        socket_factory: optional callable building the socket from
            ``addr_info`` instead of ``socket.socket(...)``.

    Returns:
        The connected socket, switched to non-blocking mode.

    Raises:
        Whatever the failing step raised. Errors from closing the socket
        during cleanup are recorded in ``exceptions`` but never replace the
        original exception (so e.g. a cancellation stays a cancellation).
    """
    my_exceptions: List[Union[OSError, RuntimeError]] = []
    exceptions.append(my_exceptions)
    family, type_, proto, _, address = addr_info
    sock = None
    try:
        if socket_factory is not None:
            sock = socket_factory(addr_info)
        else:
            sock = socket.socket(family=family, type=type_, proto=proto)
        if open_sockets is not None:
            open_sockets.add(sock)
        sock.setblocking(False)
        if local_addr_infos is not None:
            for lfamily, _, _, _, laddr in local_addr_infos:
                # skip local addresses of different family
                if lfamily != family:
                    continue
                try:
                    sock.bind(laddr)
                    break
                except OSError as bind_exc:
                    msg = (
                        f"error while attempting to bind on "
                        f"address {laddr!r}: "
                        f"{(bind_exc.strerror or '').lower()}"
                    )
                    my_exceptions.append(OSError(bind_exc.errno, msg))
            else:  # all bind attempts failed
                if my_exceptions:
                    # Re-raised below; the outer handler records it again.
                    raise my_exceptions.pop()
                else:
                    raise OSError(f"no matching local address with {family=} found")
        await loop.sock_connect(sock, address)
        return sock
    except BaseException as exc:
        # Only connection-type errors are recorded; anything else
        # (e.g. cancellation) is just cleaned up after and re-raised.
        if isinstance(exc, (RuntimeError, OSError)):
            my_exceptions.append(exc)
        if sock is not None:
            if open_sockets is not None:
                open_sockets.discard(sock)
            try:
                sock.close()
            except OSError as close_exc:
                # Record the close failure but keep propagating the
                # original exception rather than the close error.
                my_exceptions.append(close_exc)
        raise
    finally:
        exceptions = my_exceptions = None  # type: ignore[assignment]

233 

234 

def _interleave_addrinfos(
    addrinfos: Sequence[AddrInfoType], first_address_family_count: int = 1
) -> List[AddrInfoType]:
    """
    Interleave list of addrinfo tuples by family.

    Addresses are grouped by family (``addr[0]``); families keep the order
    in which they first appear, and addresses keep their relative order
    within a family. The first ``first_address_family_count - 1`` addresses
    of the first family are emitted up front, then one address from each
    family is taken in turn until all are used. Values of
    ``first_address_family_count`` <= 1 mean plain round-robin.

    Returns a new list; the input sequence is not modified. An empty input
    yields an empty list.
    """
    if not addrinfos:
        # Nothing to interleave; also avoids indexing an empty group list
        # below when first_address_family_count > 1.
        return []

    # Group addresses by family, preserving first-seen family order.
    addrinfos_by_family: collections.OrderedDict[int, List[AddrInfoType]] = (
        collections.OrderedDict()
    )
    for addr in addrinfos:
        addrinfos_by_family.setdefault(addr[0], []).append(addr)
    addrinfos_lists = list(addrinfos_by_family.values())

    reordered: List[AddrInfoType] = []
    if first_address_family_count > 1:
        head = first_address_family_count - 1
        reordered.extend(addrinfos_lists[0][:head])
        del addrinfos_lists[0][:head]
    # zip_longest pads shorter families with None; drop the padding.
    reordered.extend(
        a
        for a in itertools.chain.from_iterable(itertools.zip_longest(*addrinfos_lists))
        if a is not None
    )
    return reordered