1#
2# Copyright 2014 Facebook
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may
5# not use this file except in compliance with the License. You may obtain
6# a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations
14# under the License.
15
16"""A non-blocking TCP connection factory.
17"""
18
19import functools
20import socket
21import numbers
22import datetime
23import ssl
24import typing
25
26from tornado.concurrent import Future, future_add_done_callback
27from tornado.ioloop import IOLoop
28from tornado.iostream import IOStream
29from tornado import gen
30from tornado.netutil import Resolver
31from tornado.gen import TimeoutError
32
33from typing import Any, Union, Dict, Tuple, List, Callable, Iterator, Optional
34
35if typing.TYPE_CHECKING:
36 from typing import Set # noqa(F401)
37
# Default delay, in seconds, used by ``_Connector.start``: if the first
# connection attempt (primary address family) has neither succeeded nor
# failed within this time, an attempt to the secondary address family is
# started in parallel ("Happy Eyeballs", RFC 6555).
_INITIAL_CONNECT_TIMEOUT = 0.3
39
40
class _Connector:
    """A single-use implementation of the "Happy Eyeballs" algorithm.

    "Happy Eyeballs" is documented in RFC6555 as the recommended practice
    for when both IPv4 and IPv6 addresses are available.

    In this implementation, we partition the addresses by family, and
    make the first connection attempt to whichever address was
    returned first by ``getaddrinfo``. If that connection fails or
    times out, we begin a connection in parallel to the first address
    of the other family. If there are additional failures we retry
    with other addresses, keeping one connection attempt per family
    in flight at a time.

    Each instance handles exactly one connection; nothing about which
    address family succeeded is remembered between instances.

    http://tools.ietf.org/html/rfc6555

    """

    def __init__(
        self,
        addrinfo: List[Tuple],
        connect: Callable[
            [socket.AddressFamily, Tuple], Tuple[IOStream, "Future[IOStream]"]
        ],
    ) -> None:
        """Prepare (but do not start) a connection attempt.

        :arg addrinfo: list of ``(family, address)`` pairs, in resolver order.
        :arg connect: callable taking ``(family, address)`` and returning
            ``(stream, future)``; ``future`` resolves once the stream is
            connected.
        """
        self.io_loop = IOLoop.current()
        self.connect = connect

        self.future = (
            Future()
        )  # type: Future[Tuple[socket.AddressFamily, Any, IOStream]]
        # Handle of the pending "try the secondary family" timer, if any.
        self.timeout = None  # type: Optional[object]
        # Handle of the pending overall-deadline timer, if any.
        self.connect_timeout = None  # type: Optional[object]
        # Most recent attempt failure; reported if every address fails.
        self.last_error = None  # type: Optional[Exception]
        # Number of addresses whose connection attempt has not finished yet
        # (including addresses not yet tried).
        self.remaining = len(addrinfo)
        self.primary_addrs, self.secondary_addrs = self.split(addrinfo)
        # Every stream created so far, so that losing attempts can be closed.
        self.streams = set()  # type: Set[IOStream]

    @staticmethod
    def split(
        addrinfo: List[Tuple],
    ) -> Tuple[
        List[Tuple[socket.AddressFamily, Tuple]],
        List[Tuple[socket.AddressFamily, Tuple]],
    ]:
        """Partition the ``addrinfo`` list by address family.

        Returns two lists. The first list contains the first entry from
        ``addrinfo`` and all others with the same family, and the
        second list contains all other addresses (normally one list will
        be AF_INET and the other AF_INET6, although non-standard resolvers
        may return additional families). An empty ``addrinfo`` yields two
        empty lists.
        """
        primary = []  # type: List[Tuple[socket.AddressFamily, Tuple]]
        secondary = []  # type: List[Tuple[socket.AddressFamily, Tuple]]
        if not addrinfo:
            # Nothing to partition. With no addresses, ``start`` fails the
            # future with a "connection failed" error rather than this
            # method raising IndexError.
            return primary, secondary
        primary_af = addrinfo[0][0]
        for af, addr in addrinfo:
            if af == primary_af:
                primary.append((af, addr))
            else:
                secondary.append((af, addr))
        return primary, secondary

    def start(
        self,
        timeout: float = _INITIAL_CONNECT_TIMEOUT,
        connect_timeout: Optional[Union[float, datetime.timedelta]] = None,
    ) -> "Future[Tuple[socket.AddressFamily, Any, IOStream]]":
        """Begin connecting and return a future for the winning connection.

        :arg timeout: seconds to wait on the primary family before also
            trying the secondary family.
        :arg connect_timeout: optional deadline, in any form accepted by
            `.IOLoop.add_timeout`; when it passes, the future fails with
            `TimeoutError`.
        :returns: a future resolving to ``(family, address, stream)``.
        """
        self.try_connect(iter(self.primary_addrs))
        if self.future.done():
            # The outcome was already decided synchronously (e.g. every
            # address failed immediately); arming timers now would only keep
            # this object alive until they fire.
            return self.future
        self.set_timeout(timeout)
        if connect_timeout is not None:
            self.set_connect_timeout(connect_timeout)
        return self.future

    def try_connect(self, addrs: Iterator[Tuple[socket.AddressFamily, Tuple]]) -> None:
        """Start a connection attempt to the next address in ``addrs``.

        If ``addrs`` is exhausted and no other attempt is outstanding, fail
        the overall future with the last recorded error.
        """
        try:
            af, addr = next(addrs)
        except StopIteration:
            # We've reached the end of our queue, but the other queue
            # might still be working. Send a final error on the future
            # only when both queues are finished.
            if self.remaining == 0 and not self.future.done():
                # Nothing else can happen now; drop any pending timers.
                self.clear_timeouts()
                self.future.set_exception(
                    self.last_error or IOError("connection failed")
                )
            return
        # NOTE(review): if ``self.connect`` raises synchronously while this is
        # running from ``on_connect_done`` (a future done-callback), the
        # exception propagates out of that callback and ``self.future`` is
        # only resolved if a connect timeout was set. Confirm this is
        # acceptable.
        stream, future = self.connect(af, addr)
        self.streams.add(stream)
        future_add_done_callback(
            future, functools.partial(self.on_connect_done, addrs, af, addr)
        )

    def on_connect_done(
        self,
        addrs: Iterator[Tuple[socket.AddressFamily, Tuple]],
        af: socket.AddressFamily,
        addr: Tuple,
        future: "Future[IOStream]",
    ) -> None:
        """Handle the completion of the connection attempt to ``addr``.

        On failure, remember the error and try the next address from
        ``addrs``; if the secondary-family timer is still pending, fire it
        right away. On success, resolve the overall future (unless it is
        already resolved, in which case the late stream is closed) and close
        every other stream.
        """
        self.remaining -= 1
        try:
            stream = future.result()
        except Exception as e:
            if self.future.done():
                return
            # Error: try again (but remember what happened so we have an
            # error to raise in the end)
            self.last_error = e
            self.try_connect(addrs)
            if self.timeout is not None:
                # If the first attempt failed, don't wait for the
                # timeout to try an address from the secondary queue.
                self.io_loop.remove_timeout(self.timeout)
                self.on_timeout()
            return
        self.clear_timeouts()
        if self.future.done():
            # This is a late arrival; just drop it.
            stream.close()
        else:
            self.streams.discard(stream)
            self.future.set_result((af, addr, stream))
            self.close_streams()

    def set_timeout(self, timeout: float) -> None:
        """Schedule ``on_timeout`` to run ``timeout`` seconds from now."""
        self.timeout = self.io_loop.add_timeout(
            self.io_loop.time() + timeout, self.on_timeout
        )

    def on_timeout(self) -> None:
        """Start trying the secondary address family, if still unresolved."""
        self.timeout = None
        if not self.future.done():
            self.try_connect(iter(self.secondary_addrs))

    def clear_timeout(self) -> None:
        """Cancel the pending secondary-family timer, if any."""
        if self.timeout is not None:
            self.io_loop.remove_timeout(self.timeout)
            self.timeout = None

    def set_connect_timeout(
        self, connect_timeout: Union[float, datetime.timedelta]
    ) -> None:
        """Schedule the overall deadline (any form `.IOLoop.add_timeout` accepts)."""
        self.connect_timeout = self.io_loop.add_timeout(
            connect_timeout, self.on_connect_timeout
        )

    def on_connect_timeout(self) -> None:
        """Fail the future with `TimeoutError` and close all attempts."""
        # This timer has just fired; forget its handle.
        self.connect_timeout = None
        # The secondary-family timer can no longer do anything useful.
        self.clear_timeout()
        if not self.future.done():
            self.future.set_exception(TimeoutError())
        self.close_streams()

    def clear_timeouts(self) -> None:
        """Cancel both the secondary-family timer and the overall deadline."""
        self.clear_timeout()
        if self.connect_timeout is not None:
            self.io_loop.remove_timeout(self.connect_timeout)
            self.connect_timeout = None

    def close_streams(self) -> None:
        """Close every stream still tracked (the winner is removed beforehand)."""
        for stream in self.streams:
            stream.close()
200
201
class TCPClient:
    """A non-blocking TCP connection factory.

    .. versionchanged:: 5.0
       The ``io_loop`` argument (deprecated since version 4.1) has been removed.
    """

    def __init__(self, resolver: Optional[Resolver] = None) -> None:
        """Create a client that resolves host names with ``resolver``.

        When ``resolver`` is omitted, a new `.Resolver` is created and owned
        by this client (and closed by `close`); a caller-supplied resolver is
        never closed here.
        """
        if resolver is not None:
            self.resolver = resolver
            self._own_resolver = False
        else:
            self.resolver = Resolver()
            self._own_resolver = True

    def close(self) -> None:
        """Close the resolver if this client created it."""
        if self._own_resolver:
            self.resolver.close()

    async def connect(
        self,
        host: str,
        port: int,
        af: socket.AddressFamily = socket.AF_UNSPEC,
        ssl_options: Optional[Union[Dict[str, Any], ssl.SSLContext]] = None,
        max_buffer_size: Optional[int] = None,
        source_ip: Optional[str] = None,
        source_port: Optional[int] = None,
        timeout: Optional[Union[float, datetime.timedelta]] = None,
    ) -> IOStream:
        """Connect to the given host and port.

        Asynchronously returns an `.IOStream` (or `.SSLIOStream` if
        ``ssl_options`` is not None).

        Using the ``source_ip`` kwarg, one can specify the source
        IP address to use when establishing the connection.
        In case the user needs to resolve and
        use a specific interface, it has to be handled outside
        of Tornado as this depends very much on the platform.

        Raises `TimeoutError` if the connection (name resolution, the TCP
        connection and, when ``ssl_options`` is given, the TLS handshake)
        does not complete before ``timeout``. ``timeout`` is either a number
        of seconds measured from the current `.IOLoop.time` or a
        `datetime.timedelta`; any other type raises `TypeError`.

        Similarly, when the user requires a certain source port, it can
        be specified using the ``source_port`` arg.

        .. versionchanged:: 4.5
           Added the ``source_ip`` and ``source_port`` arguments.

        .. versionchanged:: 5.0
           Added the ``timeout`` argument.
        """
        if timeout is not None:
            # Convert the relative timeout into one absolute deadline shared
            # by every stage below.
            if isinstance(timeout, numbers.Real):
                timeout = IOLoop.current().time() + timeout
            elif isinstance(timeout, datetime.timedelta):
                timeout = IOLoop.current().time() + timeout.total_seconds()
            else:
                raise TypeError("Unsupported timeout %r" % (timeout,))
        resolving = self.resolver.resolve(host, port, af)
        if timeout is not None:
            addrinfo = await gen.with_timeout(timeout, resolving)
        else:
            addrinfo = await resolving
        connector = _Connector(
            addrinfo,
            functools.partial(
                self._create_stream,
                max_buffer_size,
                source_ip=source_ip,
                source_port=source_port,
            ),
        )
        af, addr, stream = await connector.start(connect_timeout=timeout)
        # TODO: For better performance we could cache the (af, addr)
        # information here and re-use it on subsequent connections to
        # the same host. (http://tools.ietf.org/html/rfc6555#section-4.2)
        if ssl_options is not None:
            handshake = stream.start_tls(
                False, ssl_options=ssl_options, server_hostname=host
            )
            if timeout is not None:
                stream = await gen.with_timeout(timeout, handshake)
            else:
                stream = await handshake
        return stream

    def _create_stream(
        self,
        max_buffer_size: Optional[int],
        af: socket.AddressFamily,
        addr: Tuple,
        source_ip: Optional[str] = None,
        source_port: Optional[int] = None,
    ) -> Tuple[IOStream, "Future[IOStream]"]:
        """Create a socket of family ``af`` and start connecting it to ``addr``.

        Returns ``(stream, future)`` where ``future`` comes from
        ``stream.connect(addr)``. If a source IP and/or port is requested,
        the socket is bound first. If binding or wrapping the socket in an
        `.IOStream` fails with `OSError`, the socket is closed and the error
        is re-raised.
        """
        # Always connect in plaintext; we'll convert to ssl if necessary
        # after one connection has completed.
        source_port_bind = source_port if isinstance(source_port, int) else 0
        source_ip_bind = source_ip
        if source_port_bind and not source_ip:
            # User required a specific port, but did not specify
            # a certain source IP, will bind to the default loopback,
            # using the same address family as the requested af socket:
            # - 127.0.0.1 for IPv4
            # - ::1 for IPv6
            source_ip_bind = "::1" if af == socket.AF_INET6 else "127.0.0.1"
        socket_obj = socket.socket(af)
        if source_port_bind or source_ip_bind:
            # If the user requires binding also to a specific IP/port.
            try:
                socket_obj.bind((source_ip_bind, source_port_bind))
            except OSError:
                socket_obj.close()
                # Fail loudly if unable to use the IP/port.
                raise
        try:
            stream = IOStream(socket_obj, max_buffer_size=max_buffer_size)
        except OSError:
            # No stream exists to return alongside a failed future, so
            # (as with a bind failure) release the socket and fail loudly.
            socket_obj.close()
            raise
        return stream, stream.connect(addr)