1from __future__ import annotations
2
3import base64
4import logging
5import ssl
6import typing
7
8from .._backends.base import SOCKET_OPTION, NetworkBackend
9from .._exceptions import ProxyError
10from .._models import (
11 URL,
12 Origin,
13 Request,
14 Response,
15 enforce_bytes,
16 enforce_headers,
17 enforce_url,
18)
19from .._ssl import default_ssl_context
20from .._synchronization import Lock
21from .._trace import Trace
22from .connection import HTTPConnection
23from .connection_pool import ConnectionPool
24from .http11 import HTTP11Connection
25from .interfaces import ConnectionInterface
26
# Type aliases for caller-supplied header collections: header names and values
# may be bytes or str, given either as a sequence of (name, value) pairs or as
# a mapping. Such values are passed through `enforce_headers` before use
# (presumably normalising them to bytes — confirm in `.._models`).
ByteOrStr = typing.Union[bytes, str]
HeadersAsSequence = typing.Sequence[typing.Tuple[ByteOrStr, ByteOrStr]]
HeadersAsMapping = typing.Mapping[ByteOrStr, ByteOrStr]


# Module logger named "httpcore.proxy"; handed to `Trace` around the TLS
# upgrade performed inside CONNECT tunnels.
logger = logging.getLogger("httpcore.proxy")
33
34
def merge_headers(
    default_headers: typing.Sequence[tuple[bytes, bytes]] | None = None,
    override_headers: typing.Sequence[tuple[bytes, bytes]] | None = None,
) -> list[tuple[bytes, bytes]]:
    """
    Combine two header lists, letting ``override_headers`` take precedence.

    A default header is dropped when a header with the same name (compared
    case-insensitively) appears anywhere in ``override_headers``. The
    surviving defaults come first, in their original order, followed by all
    override headers. ``None`` for either argument means "no headers".
    Duplicates within a single list are left untouched.
    """
    defaults = list(default_headers) if default_headers is not None else []
    overrides = list(override_headers) if override_headers is not None else []

    # Lower-cased names of every header the caller wants to override.
    overridden_names = {name.lower() for name, _ in overrides}

    kept_defaults = [
        (name, value)
        for name, value in defaults
        if name.lower() not in overridden_names
    ]
    return kept_defaults + overrides
52
53
class HTTPProxy(ConnectionPool):  # pragma: nocover
    """
    A connection pool that sends requests via an HTTP proxy.

    Requests to `http://` origins are forwarded to the proxy using the full
    request URL as the target (`ForwardHTTPConnection`); requests to any other
    scheme go through a `CONNECT` tunnel (`TunnelHTTPConnection`).
    """

    def __init__(
        self,
        proxy_url: URL | bytes | str,
        proxy_auth: tuple[bytes | str, bytes | str] | None = None,
        proxy_headers: HeadersAsMapping | HeadersAsSequence | None = None,
        ssl_context: ssl.SSLContext | None = None,
        proxy_ssl_context: ssl.SSLContext | None = None,
        max_connections: int | None = 10,
        max_keepalive_connections: int | None = None,
        keepalive_expiry: float | None = None,
        http1: bool = True,
        http2: bool = False,
        retries: int = 0,
        local_address: str | None = None,
        uds: str | None = None,
        network_backend: NetworkBackend | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> None:
        """
        A connection pool that routes HTTP requests through a proxy server.

        Parameters:
            proxy_url: The URL to use when connecting to the proxy server.
                For example `"http://127.0.0.1:8080/"`.
            proxy_auth: Any proxy authentication as a two-tuple of
                (username, password). May be either bytes or ascii-only str.
                Sent as a `Proxy-Authorization: Basic ...` header.
            proxy_headers: Any HTTP headers to use for the proxy requests.
                For example
                `{"Proxy-Authorization": "Basic <base64(username:password)>"}`.
            ssl_context: An SSL context to use for verifying connections.
                If not specified, the default `httpcore.default_ssl_context()`
                will be used.
            proxy_ssl_context: The same as `ssl_context`, but for a proxy server
                rather than a remote origin. Not allowed for `http://` proxy URLs.
            max_connections: The maximum number of concurrent HTTP connections that
                the pool should allow. Any attempt to send a request on a pool that
                would exceed this amount will block until a connection is available.
            max_keepalive_connections: The maximum number of idle HTTP connections
                that will be maintained in the pool.
            keepalive_expiry: The duration in seconds that an idle HTTP connection
                may be maintained for before being expired from the pool.
            http1: A boolean indicating if HTTP/1.1 requests should be supported
                by the connection pool. Defaults to True.
            http2: A boolean indicating if HTTP/2 requests should be supported by
                the connection pool. Defaults to False.
            retries: The maximum number of retries when trying to establish
                a connection.
            local_address: Local address to connect from. Can also be used to
                connect using a particular address family. Using
                `local_address="0.0.0.0"` will connect using an `AF_INET` address
                (IPv4), while using `local_address="::"` will connect using an
                `AF_INET6` address (IPv6).
            uds: Path to a Unix Domain Socket to use instead of TCP sockets.
            network_backend: A backend instance to use for handling network I/O.
            socket_options: Socket options passed on to every connection this
                pool opens to the proxy server.

        Note:
            `retries`, `local_address` and `uds` are handed to the base
            `ConnectionPool`, but the proxy connections built by
            `create_connection` do not receive them.

        Raises:
            RuntimeError: If `proxy_ssl_context` is given for an `http://` proxy.
        """
        super().__init__(
            ssl_context=ssl_context,
            max_connections=max_connections,
            max_keepalive_connections=max_keepalive_connections,
            keepalive_expiry=keepalive_expiry,
            http1=http1,
            http2=http2,
            network_backend=network_backend,
            retries=retries,
            local_address=local_address,
            uds=uds,
            socket_options=socket_options,
        )

        self._proxy_url = enforce_url(proxy_url, name="proxy_url")
        if (
            self._proxy_url.scheme == b"http" and proxy_ssl_context is not None
        ):  # pragma: no cover
            raise RuntimeError(
                "The `proxy_ssl_context` argument is not allowed for the http scheme"
            )

        self._ssl_context = ssl_context
        self._proxy_ssl_context = proxy_ssl_context
        # `create_connection` is overridden below, so the base pool's handling
        # of `socket_options` never applies; keep them to forward explicitly.
        self._proxy_socket_options = socket_options
        self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
        if proxy_auth is not None:
            username = enforce_bytes(proxy_auth[0], name="proxy_auth")
            password = enforce_bytes(proxy_auth[1], name="proxy_auth")
            userpass = username + b":" + password
            authorization = b"Basic " + base64.b64encode(userpass)
            # Placed ahead of any user-supplied proxy headers.
            self._proxy_headers = [
                (b"Proxy-Authorization", authorization)
            ] + self._proxy_headers

    def create_connection(self, origin: Origin) -> ConnectionInterface:
        """
        Build a proxied connection for `origin`.

        `http` origins use a forwarding connection; all other schemes use a
        CONNECT tunnel that performs TLS with the remote origin.
        """
        if origin.scheme == b"http":
            return ForwardHTTPConnection(
                proxy_origin=self._proxy_url.origin,
                proxy_headers=self._proxy_headers,
                remote_origin=origin,
                keepalive_expiry=self._keepalive_expiry,
                network_backend=self._network_backend,
                socket_options=self._proxy_socket_options,
                proxy_ssl_context=self._proxy_ssl_context,
            )
        return TunnelHTTPConnection(
            proxy_origin=self._proxy_url.origin,
            proxy_headers=self._proxy_headers,
            remote_origin=origin,
            ssl_context=self._ssl_context,
            proxy_ssl_context=self._proxy_ssl_context,
            keepalive_expiry=self._keepalive_expiry,
            http1=self._http1,
            http2=self._http2,
            network_backend=self._network_backend,
            socket_options=self._proxy_socket_options,
        )
167
168
class ForwardHTTPConnection(ConnectionInterface):
    """
    A connection that forwards plain-HTTP requests through a proxy.

    Each request is re-sent to the proxy with the complete original URL as its
    request target, and with the proxy headers merged in (request headers win on
    name clashes). Connection state queries are delegated to the underlying
    connection to the proxy.
    """

    def __init__(
        self,
        proxy_origin: Origin,
        remote_origin: Origin,
        proxy_headers: HeadersAsMapping | HeadersAsSequence | None = None,
        keepalive_expiry: float | None = None,
        network_backend: NetworkBackend | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
        proxy_ssl_context: ssl.SSLContext | None = None,
    ) -> None:
        # The only network connection this object owns is the one to the proxy.
        proxy_connection = HTTPConnection(
            origin=proxy_origin,
            keepalive_expiry=keepalive_expiry,
            network_backend=network_backend,
            socket_options=socket_options,
            ssl_context=proxy_ssl_context,
        )
        self._connection = proxy_connection
        self._proxy_origin = proxy_origin
        self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
        self._remote_origin = remote_origin

    def handle_request(self, request: Request) -> Response:
        """Send `request` to the proxy, addressed by its absolute URL."""
        merged_headers = merge_headers(self._proxy_headers, request.headers)
        proxy = self._proxy_origin
        # Connect to the proxy, but ask for the full original URL.
        absolute_target_url = URL(
            scheme=proxy.scheme,
            host=proxy.host,
            port=proxy.port,
            target=bytes(request.url),
        )
        forwarded_request = Request(
            method=request.method,
            url=absolute_target_url,
            headers=merged_headers,
            content=request.stream,
            extensions=request.extensions,
        )
        return self._connection.handle_request(forwarded_request)

    def can_handle_request(self, origin: Origin) -> bool:
        """Only requests for the origin this connection was created for."""
        return self._remote_origin == origin if False else origin == self._remote_origin

    def close(self) -> None:
        """Close the connection to the proxy."""
        self._connection.close()

    def info(self) -> str:
        """Describe the underlying proxy connection."""
        return self._connection.info()

    def is_available(self) -> bool:
        """Delegates to the proxy connection."""
        return self._connection.is_available()

    def has_expired(self) -> bool:
        """Delegates to the proxy connection."""
        return self._connection.has_expired()

    def is_idle(self) -> bool:
        """Delegates to the proxy connection."""
        return self._connection.is_idle()

    def is_closed(self) -> bool:
        """Delegates to the proxy connection."""
        return self._connection.is_closed()

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return f"<{class_name} [{self.info()}]>"
231
232
class TunnelHTTPConnection(ConnectionInterface):
    """
    A connection that reaches a remote origin through an HTTP proxy `CONNECT` tunnel.

    The first request issues `CONNECT host:port` to the proxy. It then upgrades
    the resulting stream to TLS for the remote origin and replaces the proxy
    connection with an HTTP/1.1 or HTTP/2 connection running over that tunnel.
    Later requests go straight to the tunnelled connection.
    """

    def __init__(
        self,
        proxy_origin: Origin,
        remote_origin: Origin,
        ssl_context: ssl.SSLContext | None = None,
        proxy_ssl_context: ssl.SSLContext | None = None,
        proxy_headers: HeadersAsMapping | HeadersAsSequence | None = None,
        keepalive_expiry: float | None = None,
        http1: bool = True,
        http2: bool = False,
        network_backend: NetworkBackend | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> None:
        """
        Parameters:
            proxy_origin: Scheme/host/port of the proxy server.
            remote_origin: The origin the tunnel is opened to; only requests
                for exactly this origin are handled.
            ssl_context: Context for the TLS upgrade to the remote origin.
                If None, `default_ssl_context()` is used.
            proxy_ssl_context: Context for the connection to the proxy itself.
            proxy_headers: Extra headers sent on the CONNECT request; they
                override the default `Host` / `Accept` headers of the same name.
            keepalive_expiry: Idle expiry passed to the underlying connections.
            http1: Whether HTTP/1.1 may be used inside the tunnel.
            http2: Whether HTTP/2 may be offered/used inside the tunnel.
            network_backend: Backend used to open the connection to the proxy.
            socket_options: Socket options for the connection to the proxy.
        """
        # Until the tunnel is established this is the connection to the proxy;
        # afterwards it is replaced by the HTTP/1.1 or HTTP/2 tunnel connection.
        self._connection: ConnectionInterface = HTTPConnection(
            origin=proxy_origin,
            keepalive_expiry=keepalive_expiry,
            network_backend=network_backend,
            socket_options=socket_options,
            ssl_context=proxy_ssl_context,
        )
        self._proxy_origin = proxy_origin
        self._remote_origin = remote_origin
        self._ssl_context = ssl_context
        self._proxy_ssl_context = proxy_ssl_context
        self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
        self._keepalive_expiry = keepalive_expiry
        self._http1 = http1
        self._http2 = http2
        self._connect_lock = Lock()
        self._connected = False

    @staticmethod
    def _authority_form(host: bytes, port: int) -> bytes:
        """
        Return the `host:port` authority used as the CONNECT target and Host header.

        IPv6 literals must be enclosed in square brackets (RFC 3986 section
        3.2.2). Otherwise the colons inside the address cannot be told apart
        from the port separator, e.g. `::1` becomes `[::1]:443`, not `::1:443`.
        Hosts that are already bracketed are left unchanged.
        """
        if b":" in host and not host.startswith(b"["):
            host = b"[" + host + b"]"
        return b"%b:%d" % (host, port)

    def handle_request(self, request: Request) -> Response:
        """
        Send `request` to the remote origin, establishing the tunnel first if needed.

        Raises:
            ProxyError: If the proxy answers the CONNECT with a non-2xx status;
                the proxy connection is closed before raising.
        """
        timeouts = request.extensions.get("timeout", {})
        timeout = timeouts.get("connect", None)

        with self._connect_lock:
            # Only the first request sets up the tunnel; the lock stops
            # concurrent callers from issuing more than one CONNECT.
            if not self._connected:
                target = self._authority_form(
                    self._remote_origin.host, self._remote_origin.port
                )

                connect_url = URL(
                    scheme=self._proxy_origin.scheme,
                    host=self._proxy_origin.host,
                    port=self._proxy_origin.port,
                    target=target,
                )
                connect_headers = merge_headers(
                    [(b"Host", target), (b"Accept", b"*/*")], self._proxy_headers
                )
                connect_request = Request(
                    method=b"CONNECT",
                    url=connect_url,
                    headers=connect_headers,
                    extensions=request.extensions,
                )
                connect_response = self._connection.handle_request(
                    connect_request
                )

                if connect_response.status < 200 or connect_response.status > 299:
                    reason_bytes = connect_response.extensions.get("reason_phrase", b"")
                    reason_str = reason_bytes.decode("ascii", errors="ignore")
                    msg = "%d %s" % (connect_response.status, reason_str)
                    self._connection.close()
                    raise ProxyError(msg)

                stream = connect_response.extensions["network_stream"]

                # Upgrade the stream to SSL.
                # NOTE: when a caller-supplied `ssl_context` is used, its ALPN
                # protocols are overwritten here.
                ssl_context = (
                    default_ssl_context()
                    if self._ssl_context is None
                    else self._ssl_context
                )
                alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
                ssl_context.set_alpn_protocols(alpn_protocols)

                kwargs = {
                    "ssl_context": ssl_context,
                    "server_hostname": self._remote_origin.host.decode("ascii"),
                    "timeout": timeout,
                }
                with Trace("start_tls", logger, request, kwargs) as trace:
                    stream = stream.start_tls(**kwargs)
                    trace.return_value = stream

                # Determine if we should be using HTTP/1.1 or HTTP/2.
                ssl_object = stream.get_extra_info("ssl_object")
                http2_negotiated = (
                    ssl_object is not None
                    and ssl_object.selected_alpn_protocol() == "h2"
                )

                # Create the HTTP/1.1 or HTTP/2 connection over the tunnel.
                if http2_negotiated or (self._http2 and not self._http1):
                    from .http2 import HTTP2Connection

                    self._connection = HTTP2Connection(
                        origin=self._remote_origin,
                        stream=stream,
                        keepalive_expiry=self._keepalive_expiry,
                    )
                else:
                    self._connection = HTTP11Connection(
                        origin=self._remote_origin,
                        stream=stream,
                        keepalive_expiry=self._keepalive_expiry,
                    )

                self._connected = True
        return self._connection.handle_request(request)

    def can_handle_request(self, origin: Origin) -> bool:
        return origin == self._remote_origin

    def close(self) -> None:
        self._connection.close()

    def info(self) -> str:
        return self._connection.info()

    def is_available(self) -> bool:
        return self._connection.is_available()

    def has_expired(self) -> bool:
        return self._connection.has_expired()

    def is_idle(self) -> bool:
        return self._connection.is_idle()

    def is_closed(self) -> bool:
        return self._connection.is_closed()

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} [{self.info()}]>"