1import itertools
2import logging
3import ssl
4from types import TracebackType
5from typing import Iterable, Iterator, Optional, Type
6
7from .._backends.sync import SyncBackend
8from .._backends.base import SOCKET_OPTION, NetworkBackend, NetworkStream
9from .._exceptions import ConnectError, ConnectTimeout
10from .._models import Origin, Request, Response
11from .._ssl import default_ssl_context
12from .._synchronization import Lock
13from .._trace import Trace
14from .http11 import HTTP11Connection
15from .interfaces import ConnectionInterface
16
# Base delay, in seconds, fed to `exponential_backoff` when retrying a failed
# connection attempt; the resulting waits are 0s, 0.5s, 1s, 2s, 4s, ...
RETRIES_BACKOFF_FACTOR: float = 0.5


# Logger handed to the `Trace` helpers used by the connection code below.
logger = logging.getLogger("httpcore.connection")
21
22
def exponential_backoff(factor: float) -> Iterator[float]:
    """
    Produce an infinite series of back-off delays.

    The first value is always ``0``; after that the series is ``factor``
    multiplied by successive powers of two (``factor * 2**0``,
    ``factor * 2**1``, ...).

    For example:
    - `factor = 2`: `0, 2, 4, 8, 16, 32, 64, ...`
    - `factor = 3`: `0, 3, 6, 12, 24, 48, 96, ...`
    """
    # An immediate first retry, then geometric growth with ratio 2.
    yield 0
    yield from (factor * 2**exponent for exponent in itertools.count())
34
35
class HTTPConnection(ConnectionInterface):
    """
    A connection to a single origin whose network stream is opened lazily.

    The first call to `handle_request()` opens the stream (TCP, or a Unix
    domain socket when `uds` is given, upgraded with TLS for "https"/"wss"
    origins) and wraps it in either an `HTTP2Connection` or an
    `HTTP11Connection`. All later requests and state queries are delegated to
    that protocol connection. Until it exists, the state-reporting methods
    answer from whether the connection attempt has failed.
    """

    def __init__(
        self,
        origin: Origin,
        ssl_context: Optional[ssl.SSLContext] = None,
        keepalive_expiry: Optional[float] = None,
        http1: bool = True,
        http2: bool = False,
        retries: int = 0,
        local_address: Optional[str] = None,
        uds: Optional[str] = None,
        network_backend: Optional[NetworkBackend] = None,
        socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
    ) -> None:
        """
        Args:
            origin: The only origin this connection will accept requests for.
            ssl_context: Context used for TLS origins; `default_ssl_context()`
                is used when None. Note that its ALPN protocols are
                overwritten each time a TLS connection is set up.
            keepalive_expiry: Passed through to the HTTP/1.1 or HTTP/2
                connection that wraps the stream.
            http1: Whether HTTP/1.1 is enabled.
            http2: Whether HTTP/2 is enabled (also advertised via ALPN).
            retries: How many extra connection attempts to make after a
                `ConnectError` / `ConnectTimeout`.
            local_address: Passed to the backend's `connect_tcp()`.
            uds: Path of a Unix domain socket to connect to instead of TCP.
            network_backend: Backend used to open streams and to sleep
                between retries; defaults to `SyncBackend()`.
            socket_options: Passed to the backend's connect call.
        """
        self._origin = origin
        self._ssl_context = ssl_context
        self._keepalive_expiry = keepalive_expiry
        self._http1 = http1
        self._http2 = http2
        self._retries = retries
        self._local_address = local_address
        self._uds = uds

        self._network_backend: NetworkBackend = (
            SyncBackend() if network_backend is None else network_backend
        )
        # The protocol-level connection; None until the first request connects.
        self._connection: Optional[ConnectionInterface] = None
        # Set when establishing the connection raised an exception.
        self._connect_failed: bool = False
        # Serialises the lazy connection setup in `handle_request()`.
        self._request_lock = Lock()
        self._socket_options = socket_options

    def handle_request(self, request: Request) -> Response:
        """
        Send `request` on this connection, establishing it first if needed.

        Raises:
            RuntimeError: if the request's origin is not this connection's
                origin.

        Any exception raised while connecting marks the connection as failed
        (see `has_expired()` / `is_closed()`) and is re-raised unchanged.
        """
        if not self.can_handle_request(request.url.origin):
            raise RuntimeError(
                f"Attempted to send request to {request.url.origin} on connection to {self._origin}"
            )

        try:
            with self._request_lock:
                if self._connection is None:
                    stream = self._connect(request)
                    try:
                        self._connection = self._wrap_stream(stream)
                    except BaseException:
                        # Nothing else references `stream` yet (and `close()`
                        # only closes `self._connection`), so close it here
                        # rather than leak the socket.
                        stream.close()
                        raise
        except BaseException:
            self._connect_failed = True
            raise

        return self._connection.handle_request(request)

    def _wrap_stream(self, stream: NetworkStream) -> ConnectionInterface:
        """
        Wrap a freshly opened `stream` in the appropriate protocol connection.

        HTTP/2 is chosen when TLS ALPN negotiated "h2", or when HTTP/2 is
        enabled and HTTP/1.1 is not; otherwise HTTP/1.1 is used.
        """
        ssl_object = stream.get_extra_info("ssl_object")
        http2_negotiated = (
            ssl_object is not None
            and ssl_object.selected_alpn_protocol() == "h2"
        )
        if http2_negotiated or (self._http2 and not self._http1):
            # Imported lazily; presumably so the HTTP/2 dependencies are only
            # required when HTTP/2 is actually used — TODO confirm.
            from .http2 import HTTP2Connection

            return HTTP2Connection(
                origin=self._origin,
                stream=stream,
                keepalive_expiry=self._keepalive_expiry,
            )
        return HTTP11Connection(
            origin=self._origin,
            stream=stream,
            keepalive_expiry=self._keepalive_expiry,
        )

    def _connect(self, request: Request) -> NetworkStream:
        """
        Open the network stream for this connection, retrying on failure.

        Opens a TCP connection (or a Unix domain socket when `uds` was given)
        and, for "https"/"wss" origins, upgrades it with TLS. If a
        `ConnectError` or `ConnectTimeout` is raised, the whole attempt is
        repeated up to `retries` more times. Before each retry it sleeps for
        the next delay from `exponential_backoff`. When no retries remain, the
        error is re-raised.

        Reads the request extensions "timeout" (its "connect" entry) and
        "sni_hostname".
        """
        timeouts = request.extensions.get("timeout", {})
        sni_hostname = request.extensions.get("sni_hostname", None)
        timeout = timeouts.get("connect", None)

        retries_left = self._retries
        delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)

        while True:
            try:
                if self._uds is None:
                    kwargs = {
                        "host": self._origin.host.decode("ascii"),
                        "port": self._origin.port,
                        "local_address": self._local_address,
                        "timeout": timeout,
                        "socket_options": self._socket_options,
                    }
                    with Trace("connect_tcp", logger, request, kwargs) as trace:
                        stream = self._network_backend.connect_tcp(**kwargs)
                        trace.return_value = stream
                else:
                    kwargs = {
                        "path": self._uds,
                        "timeout": timeout,
                        "socket_options": self._socket_options,
                    }
                    with Trace(
                        "connect_unix_socket", logger, request, kwargs
                    ) as trace:
                        stream = self._network_backend.connect_unix_socket(
                            **kwargs
                        )
                        trace.return_value = stream

                if self._origin.scheme in (b"https", b"wss"):
                    try:
                        ssl_context = (
                            default_ssl_context()
                            if self._ssl_context is None
                            else self._ssl_context
                        )
                        # Only advertise "h2" when HTTP/2 is enabled.
                        alpn_protocols = (
                            ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
                        )
                        ssl_context.set_alpn_protocols(alpn_protocols)

                        kwargs = {
                            "ssl_context": ssl_context,
                            "server_hostname": sni_hostname
                            or self._origin.host.decode("ascii"),
                            "timeout": timeout,
                        }
                        with Trace("start_tls", logger, request, kwargs) as trace:
                            stream = stream.start_tls(**kwargs)
                            trace.return_value = stream
                    except BaseException:
                        # Don't leak the already-open plain stream when the TLS
                        # stage fails (including before `start_tls` is reached).
                        stream.close()
                        raise
                return stream
            except (ConnectError, ConnectTimeout):
                if retries_left <= 0:
                    raise
                retries_left -= 1
                delay = next(delays)
                with Trace("retry", logger, request, kwargs):
                    self._network_backend.sleep(delay)

    def can_handle_request(self, origin: Origin) -> bool:
        """Return True if `origin` equals the origin this connection serves."""
        return origin == self._origin

    def close(self) -> None:
        """Close the protocol connection, if one has been established."""
        if self._connection is not None:
            with Trace("close", logger, None, {}):
                self._connection.close()

    def is_available(self) -> bool:
        """Return whether this connection can accept another request."""
        if self._connection is None:
            # If HTTP/2 support is enabled, and the resulting connection could
            # end up as HTTP/2 then we should indicate the connection as being
            # available to service multiple requests.
            return (
                self._http2
                and (self._origin.scheme == b"https" or not self._http1)
                and not self._connect_failed
            )
        return self._connection.is_available()

    def has_expired(self) -> bool:
        """Delegate to the protocol connection; before it exists, True only if connecting failed."""
        if self._connection is None:
            return self._connect_failed
        return self._connection.has_expired()

    def is_idle(self) -> bool:
        """Delegate to the protocol connection; before it exists, True only if connecting failed."""
        if self._connection is None:
            return self._connect_failed
        return self._connection.is_idle()

    def is_closed(self) -> bool:
        """Delegate to the protocol connection; before it exists, True only if connecting failed."""
        if self._connection is None:
            return self._connect_failed
        return self._connection.is_closed()

    def info(self) -> str:
        """Return a short human-readable description of the connection state."""
        if self._connection is None:
            return "CONNECTION FAILED" if self._connect_failed else "CONNECTING"
        return self._connection.info()

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} [{self.info()}]>"

    # These context managers are not used in the standard flow, but are
    # useful for testing or working with connection instances directly.

    def __enter__(self) -> "HTTPConnection":
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        """Close the connection on leaving the `with` block."""
        self.close()