1from __future__ import annotations
2
3import itertools
4import logging
5import ssl
6import types
7import typing
8
9from .._backends.sync import SyncBackend
10from .._backends.base import SOCKET_OPTION, NetworkBackend, NetworkStream
11from .._exceptions import ConnectError, ConnectTimeout
12from .._models import Origin, Request, Response
13from .._ssl import default_ssl_context
14from .._synchronization import Lock
15from .._trace import Trace
16from .http11 import HTTP11Connection
17from .interfaces import ConnectionInterface
18
# Factor handed to `exponential_backoff` when retrying a failed connect.
# The resulting sleeps between attempts are: 0s, 0.5s, 1s, 2s, 4s, etc.
RETRIES_BACKOFF_FACTOR = 0.5

logger = logging.getLogger("httpcore.connection")
23
24
def exponential_backoff(factor: float) -> typing.Iterator[float]:
    """
    Yield an endless series of delays: `0`, then `factor` doubling each step.

    The n-th value after the leading zero is `factor * 2**n`, i.e. the series
    is `0, factor, 2*factor, 4*factor, ...`. For example:
    - `factor = 2`: `0, 2, 4, 8, 16, 32, 64, ...`
    - `factor = 3`: `0, 3, 6, 12, 24, 48, 96, ...`
    """
    # The leading zero makes the first retry immediate.
    yield 0
    yield from (factor * 2**exponent for exponent in itertools.count())
36
37
class HTTPConnection(ConnectionInterface):
    """
    A connection to a single origin that picks its HTTP version lazily.

    No network I/O happens until the first `handle_request()` call. At that
    point the stream is opened (with TLS for `https`/`wss` origins). The
    protocol is chosen from the negotiated ALPN protocol, or forced to
    HTTP/2 when only `http2` is enabled. All requests are then delegated to
    the resulting HTTP/1.1 or HTTP/2 connection object.
    """

    def __init__(
        self,
        origin: Origin,
        ssl_context: ssl.SSLContext | None = None,
        keepalive_expiry: float | None = None,
        http1: bool = True,
        http2: bool = False,
        retries: int = 0,
        local_address: str | None = None,
        uds: str | None = None,
        network_backend: NetworkBackend | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> None:
        """
        Args:
            origin: The only origin this connection will accept requests for.
            ssl_context: Context used for TLS origins. When omitted,
                `default_ssl_context()` is used. Its ALPN protocols are
                overwritten via `set_alpn_protocols` each time a TLS
                connection is made.
            keepalive_expiry: Passed through to the protocol connection.
            http1: Whether HTTP/1.1 may be used.
            http2: Whether HTTP/2 may be used. On TLS it is offered via ALPN.
            retries: Number of extra attempts after a `ConnectError` or
                `ConnectTimeout` while connecting.
            local_address: Passed to the backend's `connect_tcp`.
            uds: If set, connect to this Unix socket path instead of TCP.
            network_backend: Backend used for I/O. Defaults to `SyncBackend()`.
            socket_options: Passed to the backend's connect call.
        """
        self._origin = origin
        self._ssl_context = ssl_context
        self._keepalive_expiry = keepalive_expiry
        self._http1 = http1
        self._http2 = http2
        self._retries = retries
        self._local_address = local_address
        self._uds = uds

        self._network_backend: NetworkBackend = (
            SyncBackend() if network_backend is None else network_backend
        )
        # Protocol-level connection; stays None until the first request connects.
        self._connection: ConnectionInterface | None = None
        # Set when establishing the connection raised; the state queries
        # below then report this connection as idle/expired/closed.
        self._connect_failed: bool = False
        self._request_lock = Lock()
        self._socket_options = socket_options

    def handle_request(self, request: Request) -> Response:
        """
        Send `request` over this connection, connecting first if needed.

        Raises:
            RuntimeError: If the request targets a different origin.
            Any exception raised while establishing the connection is
            re-raised unchanged after `_connect_failed` is set.
        """
        if not self.can_handle_request(request.url.origin):
            raise RuntimeError(
                f"Attempted to send request to {request.url.origin} on connection to {self._origin}"
            )

        try:
            # Connection setup happens while holding `_request_lock`
            # (presumably a mutex) so only the first caller connects.
            with self._request_lock:
                if self._connection is None:
                    stream = self._connect(request)
                    self._connection = self._create_connection(stream)
        except BaseException:
            # Mark the failure, then re-raise with the original traceback.
            self._connect_failed = True
            raise

        return self._connection.handle_request(request)

    def _create_connection(self, stream: NetworkStream) -> ConnectionInterface:
        """Wrap `stream` in an HTTP/2 or HTTP/1.1 protocol connection."""
        ssl_object = stream.get_extra_info("ssl_object")
        http2_negotiated = (
            ssl_object is not None
            and ssl_object.selected_alpn_protocol() == "h2"
        )
        if http2_negotiated or (self._http2 and not self._http1):
            # Local import, presumably so HTTP/2 support is only loaded
            # when it is actually used.
            from .http2 import HTTP2Connection

            return HTTP2Connection(
                origin=self._origin,
                stream=stream,
                keepalive_expiry=self._keepalive_expiry,
            )
        return HTTP11Connection(
            origin=self._origin,
            stream=stream,
            keepalive_expiry=self._keepalive_expiry,
        )

    def _connect(self, request: Request) -> NetworkStream:
        """
        Open the network stream for this origin, upgrading to TLS if required.

        Connects over TCP, or over the configured Unix socket. For `https`
        and `wss` origins it then performs `start_tls`. A `ConnectError` or
        `ConnectTimeout` is retried up to `self._retries` times. Between
        attempts it sleeps for delays drawn from `exponential_backoff`, so
        the first retry is immediate.
        """
        timeouts = request.extensions.get("timeout", {})
        sni_hostname = request.extensions.get("sni_hostname", None)
        timeout = timeouts.get("connect", None)

        retries_left = self._retries
        delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)

        while True:
            try:
                if self._uds is None:
                    kwargs = {
                        "host": self._origin.host.decode("ascii"),
                        "port": self._origin.port,
                        "local_address": self._local_address,
                        "timeout": timeout,
                        "socket_options": self._socket_options,
                    }
                    with Trace("connect_tcp", logger, request, kwargs) as trace:
                        stream = self._network_backend.connect_tcp(**kwargs)
                        trace.return_value = stream
                else:
                    kwargs = {
                        "path": self._uds,
                        "timeout": timeout,
                        "socket_options": self._socket_options,
                    }
                    with Trace(
                        "connect_unix_socket", logger, request, kwargs
                    ) as trace:
                        stream = self._network_backend.connect_unix_socket(
                            **kwargs
                        )
                        trace.return_value = stream

                if self._origin.scheme in (b"https", b"wss"):
                    ssl_context = (
                        default_ssl_context()
                        if self._ssl_context is None
                        else self._ssl_context
                    )
                    # Only offer h2 when HTTP/2 is enabled. This mutates the
                    # context, including a caller-supplied one.
                    alpn_protocols = (
                        ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
                    )
                    ssl_context.set_alpn_protocols(alpn_protocols)

                    kwargs = {
                        "ssl_context": ssl_context,
                        "server_hostname": sni_hostname
                        or self._origin.host.decode("ascii"),
                        "timeout": timeout,
                    }
                    with Trace("start_tls", logger, request, kwargs) as trace:
                        stream = stream.start_tls(**kwargs)
                        trace.return_value = stream
                return stream
            except (ConnectError, ConnectTimeout):
                if retries_left <= 0:
                    raise
                retries_left -= 1
                delay = next(delays)
                # `kwargs` holds the arguments of the most recent step
                # attempted (connect or start_tls).
                with Trace("retry", logger, request, kwargs):
                    self._network_backend.sleep(delay)

    def can_handle_request(self, origin: Origin) -> bool:
        """Return True only for this connection's own origin."""
        return origin == self._origin

    def close(self) -> None:
        """Close the protocol connection; a no-op if none was ever made."""
        if self._connection is not None:
            with Trace("close", logger, None, {}):
                self._connection.close()

    def is_available(self) -> bool:
        """
        Report whether this connection can take another request.

        Before connecting, this is True only when the connection might end
        up as HTTP/2 and no connect attempt has failed.
        """
        if self._connection is None:
            # If HTTP/2 support is enabled, and the resulting connection could
            # end up as HTTP/2 then we should indicate the connection as being
            # available to service multiple requests.
            # NOTE(review): `_connect` also uses TLS (offering h2 via ALPN)
            # for `wss`, but only `https` is considered here. Confirm this is
            # intentional.
            return (
                self._http2
                and (self._origin.scheme == b"https" or not self._http1)
                and not self._connect_failed
            )
        return self._connection.is_available()

    def has_expired(self) -> bool:
        """Delegate to the protocol connection; before connecting, True iff connect failed."""
        if self._connection is None:
            return self._connect_failed
        return self._connection.has_expired()

    def is_idle(self) -> bool:
        """Delegate to the protocol connection; before connecting, True iff connect failed."""
        if self._connection is None:
            return self._connect_failed
        return self._connection.is_idle()

    def is_closed(self) -> bool:
        """Delegate to the protocol connection; before connecting, True iff connect failed."""
        if self._connection is None:
            return self._connect_failed
        return self._connection.is_closed()

    def info(self) -> str:
        """Return a short human-readable status string."""
        if self._connection is None:
            return "CONNECTION FAILED" if self._connect_failed else "CONNECTING"
        return self._connection.info()

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} [{self.info()}]>"

    # These context managers are not used in the standard flow, but are
    # useful for testing or working with connection instances directly.

    def __enter__(self) -> HTTPConnection:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> None:
        self.close()