1import asyncio
2import functools
3import random
4import sys
5import traceback
6import warnings
7from collections import defaultdict, deque
8from contextlib import suppress
9from http import HTTPStatus
10from http.cookies import SimpleCookie
11from itertools import cycle, islice
12from time import monotonic
13from types import TracebackType
14from typing import (
15 TYPE_CHECKING,
16 Any,
17 Awaitable,
18 Callable,
19 DefaultDict,
20 Dict,
21 Iterator,
22 List,
23 Literal,
24 Optional,
25 Set,
26 Tuple,
27 Type,
28 Union,
29 cast,
30)
31
32import attr
33
34from . import hdrs, helpers
35from .abc import AbstractResolver
36from .client_exceptions import (
37 ClientConnectionError,
38 ClientConnectorCertificateError,
39 ClientConnectorError,
40 ClientConnectorSSLError,
41 ClientHttpProxyError,
42 ClientProxyConnectionError,
43 ServerFingerprintMismatch,
44 UnixClientConnectorError,
45 cert_errors,
46 ssl_errors,
47)
48from .client_proto import ResponseHandler
49from .client_reqrep import ClientRequest, Fingerprint, _merge_ssl_params
50from .helpers import ceil_timeout, get_running_loop, is_ip_address, noop, sentinel
51from .locks import EventResultOrError
52from .resolver import DefaultResolver
53
try:
    import ssl

    SSLContext = ssl.SSLContext
except ImportError:  # pragma: no cover
    # Python can be built without ssl support; degrade to sentinels so
    # the rest of this module still imports (SSL paths raise at runtime).
    ssl = None  # type: ignore[assignment]
    SSLContext = object  # type: ignore[misc,assignment]


__all__ = ("BaseConnector", "TCPConnector", "UnixConnector", "NamedPipeConnector")


if TYPE_CHECKING:
    # Type-checking-only imports; kept out of runtime to avoid import cycles.
    from .client import ClientTimeout
    from .client_reqrep import ConnectionKey
    from .tracing import Trace
71
class _DeprecationWaiter:
    """Wrap an awaitable and warn if it is finalized without being awaited.

    Used so that calling ``connector.close()`` without ``await`` emits a
    ``DeprecationWarning`` instead of failing silently.
    """

    __slots__ = ("_awaitable", "_awaited")

    def __init__(self, awaitable: Awaitable[Any]) -> None:
        self._awaitable = awaitable
        self._awaited = False

    def __await__(self) -> Any:
        # Record that the wrapper was properly awaited before delegating.
        self._awaited = True
        return self._awaitable.__await__()

    def __del__(self) -> None:
        if self._awaited:
            return
        warnings.warn(
            "Connector.close() is a coroutine, "
            "please use await connector.close()",
            DeprecationWarning,
        )
90
91
class Connection:
    """A single client connection leased out by a ``BaseConnector``.

    Closing or releasing the connection hands the underlying protocol
    back to the owning connector; a connection collected by the GC
    without being released is reported via the loop exception handler.
    """

    _source_traceback = None
    _transport = None

    def __init__(
        self,
        connector: "BaseConnector",
        key: "ConnectionKey",
        protocol: "ResponseHandler",
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._key = key
        self._connector = connector
        self._loop = loop
        self._protocol: Optional["ResponseHandler"] = protocol
        self._callbacks: List[Callable[[], None]] = []

        # In asyncio debug mode remember where the connection was created
        # so unclosed connections can be traced back to their origin.
        if loop.get_debug():
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

    def __repr__(self) -> str:
        return "Connection<{}>".format(self._key)

    def __del__(self, _warnings: Any = warnings) -> None:
        if self._protocol is None:
            return
        _warnings.warn(
            f"Unclosed connection {self!r}", ResourceWarning, source=self
        )
        if self._loop.is_closed():
            return

        # Hand the protocol back so the connector's accounting stays sane.
        self._connector._release(self._key, self._protocol, should_close=True)

        context = {"client_connection": self, "message": "Unclosed connection"}
        if self._source_traceback is not None:
            context["source_traceback"] = self._source_traceback
        self._loop.call_exception_handler(context)

    def __bool__(self) -> Literal[True]:
        """Force subclasses to not be falsy, to make checks simpler."""
        return True

    @property
    def loop(self) -> asyncio.AbstractEventLoop:
        warnings.warn(
            "connector.loop property is deprecated", DeprecationWarning, stacklevel=2
        )
        return self._loop

    @property
    def transport(self) -> Optional[asyncio.Transport]:
        proto = self._protocol
        return None if proto is None else proto.transport

    @property
    def protocol(self) -> Optional["ResponseHandler"]:
        return self._protocol

    def add_callback(self, callback: Callable[[], None]) -> None:
        """Register *callback* to run once when the connection is released."""
        if callback is None:
            return
        self._callbacks.append(callback)

    def _notify_release(self) -> None:
        # Swap out the callback list first so re-entrant registration
        # during a callback does not re-fire in this pass.
        pending, self._callbacks = self._callbacks[:], []

        for callback in pending:
            with suppress(Exception):
                callback()

    def close(self) -> None:
        """Release the protocol back to the connector and force-close it."""
        self._notify_release()

        if self._protocol is None:
            return
        self._connector._release(self._key, self._protocol, should_close=True)
        self._protocol = None

    def release(self) -> None:
        """Return the protocol to the pool; it is closed only if it must be."""
        self._notify_release()

        if self._protocol is None:
            return
        self._connector._release(
            self._key, self._protocol, should_close=self._protocol.should_close
        )
        self._protocol = None

    @property
    def closed(self) -> bool:
        proto = self._protocol
        return proto is None or not proto.is_connected()
181
182
class _TransportPlaceholder:
    """Stand-in used by ``BaseConnector.connect`` while a real connection
    is being established; it only has to support ``close()``."""

    def close(self) -> None:
        """No-op: there is no real transport behind the placeholder."""
188
189
class BaseConnector:
    """Base connector class.

    Manages a pool of keep-alive connections keyed by ``ConnectionKey``
    and enforces total / per-host connection limits.

    keepalive_timeout - (optional) Keep-alive timeout.
    force_close - Set to True to force close and do reconnect
        after each request (and between redirects).
    limit - The total number of simultaneous connections.
    limit_per_host - Number of simultaneous connections to one host.
    enable_cleanup_closed - Enables clean-up closed ssl transports.
                            Disabled by default.
    timeout_ceil_threshold - Trigger ceiling of timeout values when
                             it's above timeout_ceil_threshold.
    loop - Optional event loop.
    """

    _closed = True  # prevent AttributeError in __del__ if ctor was failed
    _source_traceback = None

    # abort transport after 2 seconds (cleanup broken connections)
    _cleanup_closed_period = 2.0

    def __init__(
        self,
        *,
        keepalive_timeout: Union[object, None, float] = sentinel,
        force_close: bool = False,
        limit: int = 100,
        limit_per_host: int = 0,
        enable_cleanup_closed: bool = False,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        timeout_ceil_threshold: float = 5,
    ) -> None:

        # keepalive_timeout and force_close are mutually exclusive: a
        # force-close connector never keeps idle connections around.
        if force_close:
            if keepalive_timeout is not None and keepalive_timeout is not sentinel:
                raise ValueError(
                    "keepalive_timeout cannot " "be set if force_close is True"
                )
        else:
            if keepalive_timeout is sentinel:
                keepalive_timeout = 15.0

        loop = get_running_loop(loop)
        self._timeout_ceil_threshold = timeout_ceil_threshold

        self._closed = False
        if loop.get_debug():
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

        # Idle (keep-alive) connections: key -> [(protocol, last_use_time)].
        self._conns: Dict[ConnectionKey, List[Tuple[ResponseHandler, float]]] = {}
        self._limit = limit
        self._limit_per_host = limit_per_host
        # Protocols currently handed out to callers (including placeholders
        # for connections that are still being established).
        self._acquired: Set[ResponseHandler] = set()
        self._acquired_per_host: DefaultDict[
            ConnectionKey, Set[ResponseHandler]
        ] = defaultdict(set)
        self._keepalive_timeout = cast(float, keepalive_timeout)
        self._force_close = force_close

        # {host_key: FIFO list of waiters}
        self._waiters = defaultdict(deque)  # type: ignore[var-annotated]

        self._loop = loop
        self._factory = functools.partial(ResponseHandler, loop=loop)

        self.cookies = SimpleCookie()

        # start keep-alive connection cleanup task
        self._cleanup_handle: Optional[asyncio.TimerHandle] = None

        # start cleanup closed transports task
        self._cleanup_closed_handle: Optional[asyncio.TimerHandle] = None
        self._cleanup_closed_disabled = not enable_cleanup_closed
        self._cleanup_closed_transports: List[Optional[asyncio.Transport]] = []
        self._cleanup_closed()

    def __del__(self, _warnings: Any = warnings) -> None:
        # Best-effort cleanup when the connector is garbage collected
        # without having been closed; reports the leak via the loop's
        # exception handler.
        if self._closed:
            return
        if not self._conns:
            return

        # Capture the reprs before _close() clears the pool.
        conns = [repr(c) for c in self._conns.values()]

        self._close()

        kwargs = {"source": self}
        _warnings.warn(f"Unclosed connector {self!r}", ResourceWarning, **kwargs)
        context = {
            "connector": self,
            "connections": conns,
            "message": "Unclosed connector",
        }
        if self._source_traceback is not None:
            context["source_traceback"] = self._source_traceback
        self._loop.call_exception_handler(context)

    def __enter__(self) -> "BaseConnector":
        warnings.warn(
            '"with Connector():" is deprecated, '
            'use "async with Connector():" instead',
            DeprecationWarning,
        )
        return self

    def __exit__(self, *exc: Any) -> None:
        self._close()

    async def __aenter__(self) -> "BaseConnector":
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        exc_traceback: Optional[TracebackType] = None,
    ) -> None:
        await self.close()

    @property
    def force_close(self) -> bool:
        """Ultimately close connection on releasing if True."""
        return self._force_close

    @property
    def limit(self) -> int:
        """The total number for simultaneous connections.

        If limit is 0 the connector has no limit.
        The default limit size is 100.
        """
        return self._limit

    @property
    def limit_per_host(self) -> int:
        """The limit for simultaneous connections to the same endpoint.

        Endpoints are the same if they are have equal
        (host, port, is_ssl) triple.
        """
        return self._limit_per_host

    def _cleanup(self) -> None:
        """Cleanup unused transports.

        Drops idle connections that are disconnected or have outlived the
        keep-alive timeout, then reschedules itself while any remain.
        """
        if self._cleanup_handle:
            self._cleanup_handle.cancel()
            # _cleanup_handle should be unset, otherwise _release() will not
            # recreate it ever!
            self._cleanup_handle = None

        now = self._loop.time()
        timeout = self._keepalive_timeout

        if self._conns:
            connections = {}
            deadline = now - timeout
            for key, conns in self._conns.items():
                alive = []
                for proto, use_time in conns:
                    if proto.is_connected():
                        # use_time < deadline means the keep-alive window
                        # has elapsed since the connection was last used.
                        if use_time - deadline < 0:
                            transport = proto.transport
                            proto.close()
                            if key.is_ssl and not self._cleanup_closed_disabled:
                                self._cleanup_closed_transports.append(transport)
                        else:
                            alive.append((proto, use_time))
                    else:
                        # Peer already dropped the connection.
                        transport = proto.transport
                        proto.close()
                        if key.is_ssl and not self._cleanup_closed_disabled:
                            self._cleanup_closed_transports.append(transport)

                if alive:
                    connections[key] = alive

            self._conns = connections

        if self._conns:
            self._cleanup_handle = helpers.weakref_handle(
                self,
                "_cleanup",
                timeout,
                self._loop,
                timeout_ceil_threshold=self._timeout_ceil_threshold,
            )

    def _drop_acquired_per_host(
        self, key: "ConnectionKey", val: ResponseHandler
    ) -> None:
        # Remove val from the per-host acquired set, dropping the key
        # entirely once its set becomes empty.
        acquired_per_host = self._acquired_per_host
        if key not in acquired_per_host:
            return
        conns = acquired_per_host[key]
        conns.remove(val)
        if not conns:
            del self._acquired_per_host[key]

    def _cleanup_closed(self) -> None:
        """Double confirmation for transport close.

        Some broken ssl servers may leave socket open without proper close.
        Periodically aborts transports queued in
        ``_cleanup_closed_transports`` and reschedules itself unless
        disabled.
        """
        if self._cleanup_closed_handle:
            self._cleanup_closed_handle.cancel()

        for transport in self._cleanup_closed_transports:
            if transport is not None:
                transport.abort()

        self._cleanup_closed_transports = []

        if not self._cleanup_closed_disabled:
            self._cleanup_closed_handle = helpers.weakref_handle(
                self,
                "_cleanup_closed",
                self._cleanup_closed_period,
                self._loop,
                timeout_ceil_threshold=self._timeout_ceil_threshold,
            )

    def close(self) -> Awaitable[None]:
        """Close all opened transports."""
        self._close()
        # Returns a waiter (rather than being a coroutine) so that legacy
        # non-await callers still work but get a DeprecationWarning.
        return _DeprecationWaiter(noop())

    def _close(self) -> None:
        # Synchronous teardown of every pooled and acquired connection.
        if self._closed:
            return

        self._closed = True

        try:
            if self._loop.is_closed():
                return

            # cancel cleanup task
            if self._cleanup_handle:
                self._cleanup_handle.cancel()

            # cancel cleanup close task
            if self._cleanup_closed_handle:
                self._cleanup_closed_handle.cancel()

            for data in self._conns.values():
                for proto, t0 in data:
                    proto.close()

            for proto in self._acquired:
                proto.close()

            for transport in self._cleanup_closed_transports:
                if transport is not None:
                    transport.abort()

        finally:
            # Always reset the bookkeeping, even if the loop was closed.
            self._conns.clear()
            self._acquired.clear()
            self._waiters.clear()
            self._cleanup_handle = None
            self._cleanup_closed_transports.clear()
            self._cleanup_closed_handle = None

    @property
    def closed(self) -> bool:
        """Is connector closed.

        A readonly property.
        """
        return self._closed

    def _available_connections(self, key: "ConnectionKey") -> int:
        """
        Return number of available connections.

        The limit, limit_per_host and the connection key are taken into account.

        If it returns less than 1 means that there are no connections
        available.

        NOTE(review): when both limits are active and the host already has
        acquired connections, the per-host figure replaces (rather than is
        combined with) the total-limit figure — confirm this is intended.
        """
        if self._limit:
            # total calc available connections
            available = self._limit - len(self._acquired)

            # check limit per host
            if (
                self._limit_per_host
                and available > 0
                and key in self._acquired_per_host
            ):
                acquired = self._acquired_per_host.get(key)
                assert acquired is not None
                available = self._limit_per_host - len(acquired)

        elif self._limit_per_host and key in self._acquired_per_host:
            # check limit per host
            acquired = self._acquired_per_host.get(key)
            assert acquired is not None
            available = self._limit_per_host - len(acquired)
        else:
            # No limits configured (or no acquisitions for this host yet).
            available = 1

        return available

    async def connect(
        self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
    ) -> Connection:
        """Get from pool or create new connection."""
        key = req.connection_key
        available = self._available_connections(key)

        # Wait if there are no available connections or if there are/were
        # waiters (i.e. don't steal connection from a waiter about to wake up)
        if available <= 0 or key in self._waiters:
            fut = self._loop.create_future()

            # This connection will now count towards the limit.
            self._waiters[key].append(fut)

            if traces:
                for trace in traces:
                    await trace.send_connection_queued_start()

            try:
                await fut
            except BaseException as e:
                if key in self._waiters:
                    # remove a waiter even if it was cancelled, normally it's
                    # removed when it's notified
                    try:
                        self._waiters[key].remove(fut)
                    except ValueError:  # fut may no longer be in list
                        pass

                raise e
            finally:
                # Drop the key once its waiter queue is empty so that
                # `key in self._waiters` above stays meaningful.
                if key in self._waiters and not self._waiters[key]:
                    del self._waiters[key]

            if traces:
                for trace in traces:
                    await trace.send_connection_queued_end()

        proto = self._get(key)
        if proto is None:
            # Reserve a slot with a placeholder so the limits account for
            # this connection while it is still being established.
            placeholder = cast(ResponseHandler, _TransportPlaceholder())
            self._acquired.add(placeholder)
            self._acquired_per_host[key].add(placeholder)

            if traces:
                for trace in traces:
                    await trace.send_connection_create_start()

            try:
                proto = await self._create_connection(req, traces, timeout)
                if self._closed:
                    proto.close()
                    raise ClientConnectionError("Connector is closed.")
            except BaseException:
                if not self._closed:
                    # Give the reserved slot back and wake up a waiter.
                    self._acquired.remove(placeholder)
                    self._drop_acquired_per_host(key, placeholder)
                    self._release_waiter()
                raise
            else:
                if not self._closed:
                    self._acquired.remove(placeholder)
                    self._drop_acquired_per_host(key, placeholder)

                if traces:
                    for trace in traces:
                        await trace.send_connection_create_end()
        else:
            if traces:
                # Acquire the connection to prevent race conditions with limits
                placeholder = cast(ResponseHandler, _TransportPlaceholder())
                self._acquired.add(placeholder)
                self._acquired_per_host[key].add(placeholder)
                for trace in traces:
                    await trace.send_connection_reuseconn()
                self._acquired.remove(placeholder)
                self._drop_acquired_per_host(key, placeholder)

        self._acquired.add(proto)
        self._acquired_per_host[key].add(proto)
        return Connection(self, key, proto, self._loop)

    def _get(self, key: "ConnectionKey") -> Optional[ResponseHandler]:
        """Pop a healthy idle connection for *key*, or None.

        Stale or disconnected protocols encountered on the way are closed
        (and queued for abort when SSL cleanup is enabled).
        """
        try:
            conns = self._conns[key]
        except KeyError:
            return None

        t1 = self._loop.time()
        while conns:
            proto, t0 = conns.pop()
            if proto.is_connected():
                if t1 - t0 > self._keepalive_timeout:
                    # Idle for too long: discard and keep scanning.
                    transport = proto.transport
                    proto.close()
                    # only for SSL transports
                    if key.is_ssl and not self._cleanup_closed_disabled:
                        self._cleanup_closed_transports.append(transport)
                else:
                    if not conns:
                        # The very last connection was reclaimed: drop the key
                        del self._conns[key]
                    return proto
            else:
                transport = proto.transport
                proto.close()
                if key.is_ssl and not self._cleanup_closed_disabled:
                    self._cleanup_closed_transports.append(transport)

        # No more connections: drop the key
        del self._conns[key]
        return None

    def _release_waiter(self) -> None:
        """
        Iterates over all waiters until one to be released is found.

        The one to be released is not finished and
        belongs to a host that has available connections.
        """
        if not self._waiters:
            return

        # Having the dict keys ordered this avoids to iterate
        # at the same order at each call.
        queues = list(self._waiters.keys())
        random.shuffle(queues)

        for key in queues:
            if self._available_connections(key) < 1:
                continue

            waiters = self._waiters[key]
            while waiters:
                waiter = waiters.popleft()
                if not waiter.done():
                    waiter.set_result(None)
                    return

    def _release_acquired(self, key: "ConnectionKey", proto: ResponseHandler) -> None:
        """Remove *proto* from the acquired sets and wake one waiter."""
        if self._closed:
            # acquired connection is already released on connector closing
            return

        try:
            self._acquired.remove(proto)
            self._drop_acquired_per_host(key, proto)
        except KeyError:  # pragma: no cover
            # this may be result of undetermenistic order of objects
            # finalization due garbage collection.
            pass
        else:
            self._release_waiter()

    def _release(
        self,
        key: "ConnectionKey",
        protocol: ResponseHandler,
        *,
        should_close: bool = False,
    ) -> None:
        """Return *protocol* to the pool, or close it if required.

        Kept connections are timestamped and the keep-alive cleanup timer
        is (re)armed.
        """
        if self._closed:
            # acquired connection is already released on connector closing
            return

        self._release_acquired(key, protocol)

        if self._force_close:
            should_close = True

        if should_close or protocol.should_close:
            transport = protocol.transport
            protocol.close()

            if key.is_ssl and not self._cleanup_closed_disabled:
                self._cleanup_closed_transports.append(transport)
        else:
            conns = self._conns.get(key)
            if conns is None:
                conns = self._conns[key] = []
            conns.append((protocol, self._loop.time()))

            if self._cleanup_handle is None:
                self._cleanup_handle = helpers.weakref_handle(
                    self,
                    "_cleanup",
                    self._keepalive_timeout,
                    self._loop,
                    timeout_ceil_threshold=self._timeout_ceil_threshold,
                )

    async def _create_connection(
        self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
    ) -> ResponseHandler:
        # Subclasses implement the actual transport establishment.
        raise NotImplementedError()
690
691
class _DNSCacheTable:
    """Round-robin DNS cache with an optional per-entry TTL."""

    def __init__(self, ttl: Optional[float] = None) -> None:
        # key -> (infinite cycle over resolved addrs, number of addrs)
        self._addrs_rr: Dict[Tuple[str, int], Tuple[Iterator[Dict[str, Any]], int]] = {}
        # key -> monotonic() stamp of when the entry was stored
        self._timestamps: Dict[Tuple[str, int], float] = {}
        self._ttl = ttl

    def __contains__(self, host: object) -> bool:
        return host in self._addrs_rr

    def add(self, key: Tuple[str, int], addrs: List[Dict[str, Any]]) -> None:
        """Store (or refresh) the resolved addresses for *key*."""
        self._addrs_rr[key] = (cycle(addrs), len(addrs))

        if self._ttl is not None:
            self._timestamps[key] = monotonic()

    def remove(self, key: Tuple[str, int]) -> None:
        """Drop a single cache entry; a missing key is ignored."""
        self._addrs_rr.pop(key, None)

        if self._ttl is not None:
            self._timestamps.pop(key, None)

    def clear(self) -> None:
        """Forget every cached entry."""
        self._addrs_rr.clear()
        self._timestamps.clear()

    def next_addrs(self, key: Tuple[str, int]) -> List[Dict[str, Any]]:
        """Return all addresses for *key*, rotated by one on every call."""
        rr, length = self._addrs_rr[key]
        addrs = list(islice(rr, length))
        # Consume one more element to shift internal state of `cycle`
        next(rr)
        return addrs

    def expired(self, key: Tuple[str, int]) -> bool:
        """True when the entry for *key* is older than the configured TTL."""
        if self._ttl is None:
            return False

        return monotonic() > self._timestamps[key] + self._ttl
729
730
731class TCPConnector(BaseConnector):
732 """TCP connector.
733
734 verify_ssl - Set to True to check ssl certifications.
735 fingerprint - Pass the binary sha256
736 digest of the expected certificate in DER format to verify
737 that the certificate the server presents matches. See also
738 https://en.wikipedia.org/wiki/Transport_Layer_Security#Certificate_pinning
739 resolver - Enable DNS lookups and use this
740 resolver
741 use_dns_cache - Use memory cache for DNS lookups.
742 ttl_dns_cache - Max seconds having cached a DNS entry, None forever.
743 family - socket address family
744 local_addr - local tuple of (host, port) to bind socket to
745
746 keepalive_timeout - (optional) Keep-alive timeout.
747 force_close - Set to True to force close and do reconnect
748 after each request (and between redirects).
749 limit - The total number of simultaneous connections.
750 limit_per_host - Number of simultaneous connections to one host.
751 enable_cleanup_closed - Enables clean-up closed ssl transports.
752 Disabled by default.
753 loop - Optional event loop.
754 """
755
    def __init__(
        self,
        *,
        verify_ssl: bool = True,
        fingerprint: Optional[bytes] = None,
        use_dns_cache: bool = True,
        ttl_dns_cache: Optional[int] = 10,
        family: int = 0,
        ssl_context: Optional[SSLContext] = None,
        ssl: Union[bool, Fingerprint, SSLContext] = True,
        local_addr: Optional[Tuple[str, int]] = None,
        resolver: Optional[AbstractResolver] = None,
        keepalive_timeout: Union[None, float, object] = sentinel,
        force_close: bool = False,
        limit: int = 100,
        limit_per_host: int = 0,
        enable_cleanup_closed: bool = False,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        timeout_ceil_threshold: float = 5,
    ):
        super().__init__(
            keepalive_timeout=keepalive_timeout,
            force_close=force_close,
            limit=limit,
            limit_per_host=limit_per_host,
            enable_cleanup_closed=enable_cleanup_closed,
            loop=loop,
            timeout_ceil_threshold=timeout_ceil_threshold,
        )

        # Collapse the separate verify_ssl / ssl_context / fingerprint
        # arguments and the ``ssl`` argument into one setting.
        self._ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint)
        if resolver is None:
            resolver = DefaultResolver(loop=self._loop)
        self._resolver = resolver

        self._use_dns_cache = use_dns_cache
        self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache)
        # In-flight DNS lookups, keyed by (host, port); used to coalesce
        # concurrent resolutions of the same endpoint.
        self._throttle_dns_events: Dict[Tuple[str, int], EventResultOrError] = {}
        self._family = family
        self._local_addr = local_addr
796
797 def close(self) -> Awaitable[None]:
798 """Close all ongoing DNS calls."""
799 for ev in self._throttle_dns_events.values():
800 ev.cancel()
801
802 return super().close()
803
    @property
    def family(self) -> int:
        """Socket family like AF_INET.

        The value is forwarded to the resolver; 0 is the default.
        """
        return self._family
808
    @property
    def use_dns_cache(self) -> bool:
        """True if local DNS caching is enabled.

        A readonly property; set via the constructor's ``use_dns_cache``.
        """
        return self._use_dns_cache
813
814 def clear_dns_cache(
815 self, host: Optional[str] = None, port: Optional[int] = None
816 ) -> None:
817 """Remove specified host/port or clear all dns local cache."""
818 if host is not None and port is not None:
819 self._cached_hosts.remove((host, port))
820 elif host is not None or port is not None:
821 raise ValueError("either both host and port " "or none of them are allowed")
822 else:
823 self._cached_hosts.clear()
824
    async def _resolve_host(
        self, host: str, port: int, traces: Optional[List["Trace"]] = None
    ) -> List[Dict[str, Any]]:
        """Resolve host and return list of addresses.

        Serves literal IPs immediately, consults the DNS cache when
        enabled, and otherwise performs a (possibly throttled) lookup.
        """
        if is_ip_address(host):
            # Literal IP address: no DNS lookup needed.
            return [
                {
                    "hostname": host,
                    "host": host,
                    "port": port,
                    "family": self._family,
                    "proto": 0,
                    "flags": 0,
                }
            ]

        if not self._use_dns_cache:
            # Caching disabled: resolve directly every time.

            if traces:
                for trace in traces:
                    await trace.send_dns_resolvehost_start(host)

            res = await self._resolver.resolve(host, port, family=self._family)

            if traces:
                for trace in traces:
                    await trace.send_dns_resolvehost_end(host)

            return res

        key = (host, port)
        if key in self._cached_hosts and not self._cached_hosts.expired(key):
            # get result early, before any await (#4014)
            result = self._cached_hosts.next_addrs(key)

            if traces:
                for trace in traces:
                    await trace.send_dns_cache_hit(host)
            return result

        #
        # If multiple connectors are resolving the same host, we wait
        # for the first one to resolve and then use the result for all of them.
        # We use a throttle event to ensure that we only resolve the host once
        # and then use the result for all the waiters.
        #
        # In this case we need to create a task to ensure that we can shield
        # the task from cancellation as cancelling this lookup should not cancel
        # the underlying lookup or else the cancel event will get broadcast to
        # all the waiters across all connections.
        #
        resolved_host_task = asyncio.create_task(
            self._resolve_host_with_throttle(key, host, port, traces)
        )
        try:
            return await asyncio.shield(resolved_host_task)
        except asyncio.CancelledError:

            def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
                # Consume the shielded task's outcome so the event loop
                # does not log "exception was never retrieved".
                with suppress(Exception, asyncio.CancelledError):
                    fut.result()

            resolved_host_task.add_done_callback(drop_exception)
            raise
889
    async def _resolve_host_with_throttle(
        self,
        key: Tuple[str, int],
        host: str,
        port: int,
        traces: Optional[List["Trace"]],
    ) -> List[Dict[str, Any]]:
        """Resolve host with a dns events throttle.

        The first caller for *key* performs the real lookup and caches the
        result; concurrent callers wait on the shared event and then read
        the freshly cached addresses.
        """
        if key in self._throttle_dns_events:
            # get event early, before any await (#4014)
            event = self._throttle_dns_events[key]
            if traces:
                for trace in traces:
                    await trace.send_dns_cache_hit(host)
            # Wait for the in-flight lookup; this re-raises the lookup's
            # exception if it failed.
            await event.wait()
        else:
            # update dict early, before any await (#4014)
            self._throttle_dns_events[key] = EventResultOrError(self._loop)
            if traces:
                for trace in traces:
                    await trace.send_dns_cache_miss(host)
            try:

                if traces:
                    for trace in traces:
                        await trace.send_dns_resolvehost_start(host)

                addrs = await self._resolver.resolve(host, port, family=self._family)
                if traces:
                    for trace in traces:
                        await trace.send_dns_resolvehost_end(host)

                # Cache first, then wake the waiters so they find the entry.
                self._cached_hosts.add(key, addrs)
                self._throttle_dns_events[key].set()
            except BaseException as e:
                # any DNS exception, independently of the implementation
                # is set for the waiters to raise the same exception.
                self._throttle_dns_events[key].set(exc=e)
                raise
            finally:
                # The lookup is finished either way; drop the throttle
                # entry so the next resolution starts fresh.
                self._throttle_dns_events.pop(key)

        return self._cached_hosts.next_addrs(key)
933
934 async def _create_connection(
935 self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
936 ) -> ResponseHandler:
937 """Create connection.
938
939 Has same keyword arguments as BaseEventLoop.create_connection.
940 """
941 if req.proxy:
942 _, proto = await self._create_proxy_connection(req, traces, timeout)
943 else:
944 _, proto = await self._create_direct_connection(req, traces, timeout)
945
946 return proto
947
948 @staticmethod
949 @functools.lru_cache(None)
950 def _make_ssl_context(verified: bool) -> SSLContext:
951 if verified:
952 return ssl.create_default_context()
953 else:
954 sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
955 sslcontext.options |= ssl.OP_NO_SSLv2
956 sslcontext.options |= ssl.OP_NO_SSLv3
957 sslcontext.check_hostname = False
958 sslcontext.verify_mode = ssl.CERT_NONE
959 try:
960 sslcontext.options |= ssl.OP_NO_COMPRESSION
961 except AttributeError as attr_err:
962 warnings.warn(
963 "{!s}: The Python interpreter is compiled "
964 "against OpenSSL < 1.0.0. Ref: "
965 "https://docs.python.org/3/library/ssl.html"
966 "#ssl.OP_NO_COMPRESSION".format(attr_err),
967 )
968 sslcontext.set_default_verify_paths()
969 return sslcontext
970
971 def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]:
972 """Logic to get the correct SSL context
973
974 0. if req.ssl is false, return None
975
976 1. if ssl_context is specified in req, use it
977 2. if _ssl_context is specified in self, use it
978 3. otherwise:
979 1. if verify_ssl is not specified in req, use self.ssl_context
980 (will generate a default context according to self.verify_ssl)
981 2. if verify_ssl is True in req, generate a default SSL context
982 3. if verify_ssl is False in req, generate a SSL context that
983 won't verify
984 """
985 if req.is_ssl():
986 if ssl is None: # pragma: no cover
987 raise RuntimeError("SSL is not supported.")
988 sslcontext = req.ssl
989 if isinstance(sslcontext, ssl.SSLContext):
990 return sslcontext
991 if sslcontext is not True:
992 # not verified or fingerprinted
993 return self._make_ssl_context(False)
994 sslcontext = self._ssl
995 if isinstance(sslcontext, ssl.SSLContext):
996 return sslcontext
997 if sslcontext is not True:
998 # not verified or fingerprinted
999 return self._make_ssl_context(False)
1000 return self._make_ssl_context(True)
1001 else:
1002 return None
1003
1004 def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]:
1005 ret = req.ssl
1006 if isinstance(ret, Fingerprint):
1007 return ret
1008 ret = self._ssl
1009 if isinstance(ret, Fingerprint):
1010 return ret
1011 return None
1012
    async def _wrap_create_connection(
        self,
        *args: Any,
        req: ClientRequest,
        timeout: "ClientTimeout",
        client_error: Type[Exception] = ClientConnectorError,
        **kwargs: Any,
    ) -> Tuple[asyncio.Transport, ResponseHandler]:
        """Open a transport via ``loop.create_connection`` under the
        socket-connect timeout, translating low-level failures into
        aiohttp client exceptions.

        Raises ClientConnectorCertificateError / ClientConnectorSSLError
        for certificate and SSL errors, and *client_error* for other
        OS-level failures.
        """
        try:
            async with ceil_timeout(
                timeout.sock_connect, ceil_threshold=timeout.ceil_threshold
            ):
                return await self._loop.create_connection(*args, **kwargs)
        except cert_errors as exc:
            raise ClientConnectorCertificateError(req.connection_key, exc) from exc
        except ssl_errors as exc:
            raise ClientConnectorSSLError(req.connection_key, exc) from exc
        except OSError as exc:
            # NOTE(review): presumably this catches asyncio.TimeoutError on
            # runtimes where it subclasses OSError (errno is None there), so
            # timeouts propagate instead of becoming connector errors —
            # confirm against the supported Python versions.
            if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
                raise
            raise client_error(req.connection_key, exc) from exc
1034
1035 def _fail_on_no_start_tls(self, req: "ClientRequest") -> None:
1036 """Raise a :py:exc:`RuntimeError` on missing ``start_tls()``.
1037
1038 It is necessary for TLS-in-TLS so that it is possible to
1039 send HTTPS queries through HTTPS proxies.
1040
1041 This doesn't affect regular HTTP requests, though.
1042 """
1043 if not req.is_ssl():
1044 return
1045
1046 proxy_url = req.proxy
1047 assert proxy_url is not None
1048 if proxy_url.scheme != "https":
1049 return
1050
1051 self._check_loop_for_start_tls()
1052
    def _check_loop_for_start_tls(self) -> None:
        """Raise ``RuntimeError`` if the event loop has no ``start_tls()``.

        Probing the attribute is sufficient: its absence means the runtime
        cannot upgrade an existing transport to TLS (TLS-in-TLS).
        """
        try:
            self._loop.start_tls
        except AttributeError as attr_exc:
            raise RuntimeError(
                "An HTTPS request is being sent through an HTTPS proxy. "
                "This needs support for TLS in TLS but it is not implemented "
                "in your runtime for the stdlib asyncio.\n\n"
                "Please upgrade to Python 3.11 or higher. For more details, "
                "please see:\n"
                "* https://bugs.python.org/issue37179\n"
                "* https://github.com/python/cpython/pull/28073\n"
                "* https://docs.aiohttp.org/en/stable/"
                "client_advanced.html#proxy-support\n"
                "* https://github.com/aio-libs/aiohttp/discussions/6044\n",
            ) from attr_exc
1069
1070 def _loop_supports_start_tls(self) -> bool:
1071 try:
1072 self._check_loop_for_start_tls()
1073 except RuntimeError:
1074 return False
1075 else:
1076 return True
1077
    def _warn_about_tls_in_tls(
        self,
        underlying_transport: asyncio.Transport,
        req: ClientRequest,
    ) -> None:
        """Issue a warning if the requested URL has HTTPS scheme.

        Emits a RuntimeWarning when the transport does not advertise
        ``_start_tls_compatible`` support, i.e. TLS-in-TLS will likely fail.
        """
        if req.request_info.url.scheme != "https":
            return

        # Private flag set by compatible transports; absence means the
        # stdlib asyncio cannot wrap an already-TLS transport again.
        asyncio_supports_tls_in_tls = getattr(
            underlying_transport,
            "_start_tls_compatible",
            False,
        )

        if asyncio_supports_tls_in_tls:
            return

        warnings.warn(
            "An HTTPS request is being sent through an HTTPS proxy. "
            "This support for TLS in TLS is known to be disabled "
            "in the stdlib asyncio (Python <3.11). This is why you'll probably see "
            "an error in the log below.\n\n"
            "It is possible to enable it via monkeypatching. "
            "For more details, see:\n"
            "* https://bugs.python.org/issue37179\n"
            "* https://github.com/python/cpython/pull/28073\n\n"
            "You can temporarily patch this as follows:\n"
            "* https://docs.aiohttp.org/en/stable/client_advanced.html#proxy-support\n"
            "* https://github.com/aio-libs/aiohttp/discussions/6044\n",
            RuntimeWarning,
            source=self,
            # Point the warning at the caller outside this class: at least
            # 3 frames of the stack originate from methods in this class.
            stacklevel=3,
        )
1114
    async def _start_tls_connection(
        self,
        underlying_transport: asyncio.Transport,
        req: ClientRequest,
        timeout: "ClientTimeout",
        client_error: Type[Exception] = ClientConnectorError,
    ) -> Tuple[asyncio.BaseTransport, ResponseHandler]:
        """Wrap the raw TCP transport with TLS.

        Used on the tunnel established through an HTTPS proxy: the
        CONNECT-ed transport is upgraded in place via ``loop.start_tls()``
        and a fresh protocol is attached to the resulting TLS transport.

        :param underlying_transport: the established tunnel transport
        :param req: the original client request being proxied
        :param timeout: client timeouts applied to the handshake
        :param client_error: exception type raised for generic OS errors
        :return: ``(tls_transport, tls_protocol)`` pair
        """
        tls_proto = self._factory()  # Create a brand new proto for TLS

        # Safety of the `cast()` call here is based on the fact that
        # internally `_get_ssl_context()` only returns `None` when
        # `req.is_ssl()` evaluates to `False` which is never gonna happen
        # in this code path. Of course, it's rather fragile
        # maintainability-wise but this is to be solved separately.
        sslcontext = cast(ssl.SSLContext, self._get_ssl_context(req))

        try:
            async with ceil_timeout(
                timeout.sock_connect, ceil_threshold=timeout.ceil_threshold
            ):
                try:
                    tls_transport = await self._loop.start_tls(
                        underlying_transport,
                        tls_proto,
                        sslcontext,
                        server_hostname=req.server_hostname or req.host,
                        # NOTE(review): the handshake deadline reuses the
                        # *total* timeout, not sock_connect — confirm
                        # this is intentional.
                        ssl_handshake_timeout=timeout.total,
                    )
                except BaseException:
                    # We need to close the underlying transport since
                    # `start_tls()` probably failed before it had a
                    # chance to do this:
                    underlying_transport.close()
                    raise
        except cert_errors as exc:
            raise ClientConnectorCertificateError(req.connection_key, exc) from exc
        except ssl_errors as exc:
            raise ClientConnectorSSLError(req.connection_key, exc) from exc
        except OSError as exc:
            # Timeouts with no errno are genuine cancellation-style
            # timeouts and must propagate unchanged.
            if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
                raise
            raise client_error(req.connection_key, exc) from exc
        except TypeError as type_err:
            # Example cause looks like this:
            # TypeError: transport <asyncio.sslproto._SSLProtocolTransport
            # object at 0x7f760615e460> is not supported by start_tls()

            raise ClientConnectionError(
                "Cannot initialize a TLS-in-TLS connection to host "
                f"{req.host!s}:{req.port:d} through an underlying connection "
                f"to an HTTPS proxy {req.proxy!s} ssl:{req.ssl or 'default'} "
                f"[{type_err!s}]"
            ) from type_err
        else:
            # start_tls() may return None (e.g. when the transport was
            # closed mid-handshake); surface that as a connector error.
            if tls_transport is None:
                msg = "Failed to start TLS (possibly caused by closing transport)"
                raise client_error(req.connection_key, OSError(msg))
            tls_proto.connection_made(
                tls_transport
            )  # Kick the state machine of the new TLS protocol

        return tls_transport, tls_proto
1178
    async def _create_direct_connection(
        self,
        req: ClientRequest,
        traces: List["Trace"],
        timeout: "ClientTimeout",
        *,
        client_error: Type[Exception] = ClientConnectorError,
    ) -> Tuple[asyncio.Transport, ResponseHandler]:
        """Resolve the target host and connect to the first address that works.

        Each resolved address is attempted in order; connection and
        fingerprint failures are remembered and the last one is re-raised
        if every candidate fails.

        :param req: request whose URL determines host/port/TLS settings
        :param traces: tracing hooks forwarded to DNS resolution
        :param timeout: client timeouts for the socket connect
        :param client_error: exception type used for generic OS errors
        :return: ``(transport, protocol)`` of the established connection
        """
        sslcontext = self._get_ssl_context(req)
        fingerprint = self._get_fingerprint(req)

        host = req.url.raw_host
        assert host is not None
        # Replace multiple trailing dots with a single one.
        # A trailing dot is only present for fully-qualified domain names.
        # See https://github.com/aio-libs/aiohttp/pull/7364.
        if host.endswith(".."):
            host = host.rstrip(".") + "."
        port = req.port
        assert port is not None
        try:
            # Cancelling this lookup should not cancel the underlying lookup
            # or else the cancel event will get broadcast to all the waiters
            # across all connections.
            hosts = await self._resolve_host(host, port, traces=traces)
        except OSError as exc:
            # Timeouts without an errno must propagate as timeouts.
            if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
                raise
            # in case of proxy it is not ClientProxyConnectionError
            # it is problem of resolving proxy ip itself
            raise ClientConnectorError(req.connection_key, exc) from exc

        # Remember the most recent failure so it can be raised when no
        # address succeeds.
        last_exc: Optional[Exception] = None

        for hinfo in hosts:
            host = hinfo["host"]
            port = hinfo["port"]

            # Strip trailing dots, certificates contain FQDN without dots.
            # See https://github.com/aio-libs/aiohttp/issues/3636
            server_hostname = (
                (req.server_hostname or hinfo["hostname"]).rstrip(".")
                if sslcontext
                else None
            )

            try:
                transp, proto = await self._wrap_create_connection(
                    self._factory,
                    host,
                    port,
                    timeout=timeout,
                    ssl=sslcontext,
                    family=hinfo["family"],
                    proto=hinfo["proto"],
                    flags=hinfo["flags"],
                    server_hostname=server_hostname,
                    local_addr=self._local_addr,
                    req=req,
                    client_error=client_error,
                )
            except ClientConnectorError as exc:
                # Try the next resolved address.
                last_exc = exc
                continue

            if req.is_ssl() and fingerprint:
                try:
                    fingerprint.check(transp)
                except ServerFingerprintMismatch as exc:
                    # Mismatched certificate: drop this connection and
                    # keep trying the remaining addresses.
                    transp.close()
                    if not self._cleanup_closed_disabled:
                        self._cleanup_closed_transports.append(transp)
                    last_exc = exc
                    continue

            return transp, proto
        else:
            # Loop exhausted without returning: every address failed.
            assert last_exc is not None
            raise last_exc
1258
    async def _create_proxy_connection(
        self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
    ) -> Tuple[asyncio.BaseTransport, ResponseHandler]:
        """Connect to the target through the request's configured proxy.

        Plain HTTP requests are simply sent over the proxy connection.
        HTTPS requests first issue a CONNECT to open a tunnel, then
        upgrade it to TLS — via ``loop.start_tls()`` when available,
        otherwise by duplicating the raw socket and reconnecting over it.

        :return: ``(transport, protocol)`` ready for the actual request
        """
        # Fail fast when TLS-in-TLS is needed but unsupported.
        self._fail_on_no_start_tls(req)
        runtime_has_start_tls = self._loop_supports_start_tls()

        headers: Dict[str, str] = {}
        if req.proxy_headers is not None:
            headers = req.proxy_headers  # type: ignore[assignment]
        # The proxy request targets the origin's Host, not the proxy's.
        headers[hdrs.HOST] = req.headers[hdrs.HOST]

        url = req.proxy
        assert url is not None
        proxy_req = ClientRequest(
            hdrs.METH_GET,
            url,
            headers=headers,
            auth=req.proxy_auth,
            loop=self._loop,
            ssl=req.ssl,
        )

        # create connection to proxy server
        transport, proto = await self._create_direct_connection(
            proxy_req, [], timeout, client_error=ClientProxyConnectionError
        )

        # Many HTTP proxies has buggy keepalive support. Let's not
        # reuse connection but close it after processing every
        # response.
        proto.force_close()

        # Move proxy credentials into the appropriate header: for plain
        # HTTP they ride on the tunnelled request; for HTTPS they must go
        # on the CONNECT request itself.
        auth = proxy_req.headers.pop(hdrs.AUTHORIZATION, None)
        if auth is not None:
            if not req.is_ssl():
                req.headers[hdrs.PROXY_AUTHORIZATION] = auth
            else:
                proxy_req.headers[hdrs.PROXY_AUTHORIZATION] = auth

        if req.is_ssl():
            if runtime_has_start_tls:
                self._warn_about_tls_in_tls(transport, req)

            # For HTTPS requests over HTTP proxy
            # we must notify proxy to tunnel connection
            # so we send CONNECT command:
            # CONNECT www.python.org:443 HTTP/1.1
            # Host: www.python.org
            #
            # next we must do TLS handshake and so on
            # to do this we must wrap raw socket into secure one
            # asyncio handles this perfectly
            proxy_req.method = hdrs.METH_CONNECT
            proxy_req.url = req.url
            # The tunnel connection is keyed without proxy details so it
            # matches the origin server's connection key.
            key = attr.evolve(
                req.connection_key, proxy=None, proxy_auth=None, proxy_headers_hash=None
            )
            conn = Connection(self, key, proto, self._loop)
            proxy_resp = await proxy_req.send(conn)
            try:
                protocol = conn._protocol
                assert protocol is not None

                # read_until_eof=True will ensure the connection isn't closed
                # once the response is received and processed allowing
                # START_TLS to work on the connection below.
                protocol.set_response_params(
                    read_until_eof=runtime_has_start_tls,
                    timeout_ceil_threshold=self._timeout_ceil_threshold,
                )
                resp = await proxy_resp.start(conn)
            except BaseException:
                proxy_resp.close()
                conn.close()
                raise
            else:
                # Detach the protocol/transport from `conn` so closing the
                # Connection object later cannot tear down the tunnel.
                conn._protocol = None
                conn._transport = None
                try:
                    if resp.status != 200:
                        message = resp.reason
                        if message is None:
                            message = HTTPStatus(resp.status).phrase
                        raise ClientHttpProxyError(
                            proxy_resp.request_info,
                            resp.history,
                            status=resp.status,
                            message=message,
                            headers=resp.headers,
                        )
                    if not runtime_has_start_tls:
                        rawsock = transport.get_extra_info("socket", default=None)
                        if rawsock is None:
                            raise RuntimeError(
                                "Transport does not expose socket instance"
                            )
                        # Duplicate the socket, so now we can close proxy transport
                        rawsock = rawsock.dup()
                except BaseException:
                    # It shouldn't be closed in `finally` because it's fed to
                    # `loop.start_tls()` and the docs say not to touch it after
                    # passing there.
                    transport.close()
                    raise
                finally:
                    # Without start_tls() the duplicated socket replaces the
                    # proxy transport, which is no longer needed.
                    if not runtime_has_start_tls:
                        transport.close()

                if not runtime_has_start_tls:
                    # HTTP proxy with support for upgrade to HTTPS
                    sslcontext = self._get_ssl_context(req)
                    return await self._wrap_create_connection(
                        self._factory,
                        timeout=timeout,
                        ssl=sslcontext,
                        sock=rawsock,
                        server_hostname=req.host,
                        req=req,
                    )

                return await self._start_tls_connection(
                    # Access the old transport for the last time before it's
                    # closed and forgotten forever:
                    transport,
                    req=req,
                    timeout=timeout,
                )
            finally:
                proxy_resp.close()

        # Plain HTTP through the proxy: reuse the proxy connection as-is.
        return transport, proto
1390
1391
class UnixConnector(BaseConnector):
    """Connector that talks HTTP over a Unix domain socket.

    path - Unix socket path.
    keepalive_timeout - (optional) Keep-alive timeout.
    force_close - Set to True to force close and do reconnect
        after each request (and between redirects).
    limit - The total number of simultaneous connections.
    limit_per_host - Number of simultaneous connections to one host.
    loop - Optional event loop.
    """

    def __init__(
        self,
        path: str,
        force_close: bool = False,
        keepalive_timeout: Union[object, float, None] = sentinel,
        limit: int = 100,
        limit_per_host: int = 0,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        super().__init__(
            loop=loop,
            limit=limit,
            limit_per_host=limit_per_host,
            keepalive_timeout=keepalive_timeout,
            force_close=force_close,
        )
        self._path = path

    @property
    def path(self) -> str:
        """Path to unix socket."""
        return self._path

    async def _create_connection(
        self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
    ) -> ResponseHandler:
        """Open a connection over the configured Unix socket."""
        try:
            async with ceil_timeout(
                timeout.sock_connect, ceil_threshold=timeout.ceil_threshold
            ):
                _, protocol = await self._loop.create_unix_connection(
                    self._factory, self._path
                )
        except OSError as os_err:
            # An errno-less asyncio.TimeoutError is a genuine timeout
            # (it can be an OSError subclass on newer Pythons) and must
            # propagate unchanged.
            if os_err.errno is None and isinstance(os_err, asyncio.TimeoutError):
                raise
            raise UnixClientConnectorError(
                self.path, req.connection_key, os_err
            ) from os_err
        else:
            return protocol
1443
1444
class NamedPipeConnector(BaseConnector):
    """Connector for Windows named pipes.

    Only supported by the proactor event loop.
    See also: https://docs.python.org/3/library/asyncio-eventloop.html

    path - Windows named pipe path.
    keepalive_timeout - (optional) Keep-alive timeout.
    force_close - Set to True to force close and do reconnect
        after each request (and between redirects).
    limit - The total number of simultaneous connections.
    limit_per_host - Number of simultaneous connections to one host.
    loop - Optional event loop.
    """

    def __init__(
        self,
        path: str,
        force_close: bool = False,
        keepalive_timeout: Union[object, float, None] = sentinel,
        limit: int = 100,
        limit_per_host: int = 0,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        super().__init__(
            loop=loop,
            limit=limit,
            limit_per_host=limit_per_host,
            keepalive_timeout=keepalive_timeout,
            force_close=force_close,
        )
        # Named pipes require the (Windows-only) proactor loop.
        if not isinstance(
            self._loop, asyncio.ProactorEventLoop  # type: ignore[attr-defined]
        ):
            raise RuntimeError(
                "Named Pipes only available in proactor loop under windows"
            )
        self._path = path

    @property
    def path(self) -> str:
        """Path to the named pipe."""
        return self._path

    async def _create_connection(
        self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
    ) -> ResponseHandler:
        """Connect over the named pipe and return the response handler."""
        try:
            async with ceil_timeout(
                timeout.sock_connect, ceil_threshold=timeout.ceil_threshold
            ):
                _, protocol = await self._loop.create_pipe_connection(  # type: ignore[attr-defined]
                    self._factory, self._path
                )
                # Yield to the event loop once so connection_made() runs
                # and the protocol's transport is populated before the
                # `assert conn.transport is not None` in client.py's
                # _request method; the alternative would be setting
                # `proto.transport` manually.
                await asyncio.sleep(0)
        except OSError as os_err:
            # Errno-less asyncio timeouts propagate as-is.
            if os_err.errno is None and isinstance(os_err, asyncio.TimeoutError):
                raise
            raise ClientConnectorError(req.connection_key, os_err) from os_err

        return cast(ResponseHandler, protocol)