1import logging
2import ssl
3from base64 import b64encode
4from typing import Iterable, List, Mapping, Optional, Sequence, Tuple, Union
5
6from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend
7from .._exceptions import ProxyError
8from .._models import (
9 URL,
10 Origin,
11 Request,
12 Response,
13 enforce_bytes,
14 enforce_headers,
15 enforce_url,
16)
17from .._ssl import default_ssl_context
18from .._synchronization import AsyncLock
19from .._trace import Trace
20from .connection import AsyncHTTPConnection
21from .connection_pool import AsyncConnectionPool
22from .http11 import AsyncHTTP11Connection
23from .interfaces import AsyncConnectionInterface
24
# Header collections accepted by the public constructors below, before they are
# normalised with `enforce_headers`: either a sequence of (name, value) pairs or
# a mapping of name -> value, where names and values may be bytes or str.
HeadersAsSequence = Sequence[Tuple[Union[bytes, str], Union[bytes, str]]]
HeadersAsMapping = Mapping[Union[bytes, str], Union[bytes, str]]


# Logger handed to `Trace` (e.g. for the "start_tls" step of proxy tunnelling).
logger = logging.getLogger("httpcore.proxy")
30
31
def merge_headers(
    default_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
    override_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
) -> List[Tuple[bytes, bytes]]:
    """
    Combine two header lists, letting `override_headers` win on name clashes.

    Every entry of `default_headers` whose name also appears in
    `override_headers` (compared case-insensitively) is discarded. The
    remaining defaults come first, in their original order, followed by all
    of the override headers. `None` for either argument means "no headers".
    """
    defaults = [] if default_headers is None else list(default_headers)
    overrides = [] if override_headers is None else list(override_headers)

    # Lower-cased names that the overrides claim; defaults using them are dropped.
    overridden_names = {name.lower() for name, _ in overrides}

    surviving_defaults = [
        (name, value)
        for name, value in defaults
        if name.lower() not in overridden_names
    ]
    return surviving_defaults + overrides
49
50
def build_auth_header(username: bytes, password: bytes) -> bytes:
    """
    Return a Basic authentication credential value.

    `username` and `password` are joined with a colon, base64-encoded, and
    prefixed with `Basic `, e.g. `b"Basic dXNlcjpwYXNz"` for `user`/`pass`.
    """
    credentials = b":".join((username, password))
    return b"".join((b"Basic ", b64encode(credentials)))
54
55
class AsyncHTTPProxy(AsyncConnectionPool):
    """
    A connection pool that sends requests via an HTTP proxy.

    Requests to `http` origins are sent through an `AsyncForwardHTTPConnection`;
    requests to any other scheme go through an `AsyncTunnelHTTPConnection`
    (a CONNECT tunnel to the remote origin).
    """

    def __init__(
        self,
        proxy_url: Union[URL, bytes, str],
        proxy_auth: Optional[Tuple[Union[bytes, str], Union[bytes, str]]] = None,
        proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None,
        ssl_context: Optional[ssl.SSLContext] = None,
        proxy_ssl_context: Optional[ssl.SSLContext] = None,
        max_connections: Optional[int] = 10,
        max_keepalive_connections: Optional[int] = None,
        keepalive_expiry: Optional[float] = None,
        http1: bool = True,
        http2: bool = False,
        retries: int = 0,
        local_address: Optional[str] = None,
        uds: Optional[str] = None,
        network_backend: Optional[AsyncNetworkBackend] = None,
        socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
    ) -> None:
        """
        A connection pool for making HTTP requests via an HTTP proxy.

        Parameters:
            proxy_url: The URL to use when connecting to the proxy server.
                For example `"http://127.0.0.1:8080/"`.
            proxy_auth: Any proxy authentication as a two-tuple of
                (username, password). May be either bytes or ascii-only str.
                It is sent as a `Proxy-Authorization: Basic ...` header placed
                ahead of any `proxy_headers`.
            proxy_headers: Any HTTP headers to use for the proxy requests.
                For example `{"Proxy-Authorization": "Basic <credentials>"}`,
                where `<credentials>` is the base64 encoding of
                `<username>:<password>`.
            ssl_context: An SSL context to use for verifying connections to
                the remote origin. If not specified, the default
                `httpcore.default_ssl_context()` will be used.
            proxy_ssl_context: The same as `ssl_context`, but for a proxy server
                rather than a remote origin. Not allowed when `proxy_url` uses
                the `http` scheme.
            max_connections: The maximum number of concurrent HTTP connections that
                the pool should allow. Any attempt to send a request on a pool that
                would exceed this amount will block until a connection is available.
            max_keepalive_connections: The maximum number of idle HTTP connections
                that will be maintained in the pool.
            keepalive_expiry: The duration in seconds that an idle HTTP connection
                may be maintained for before being expired from the pool.
            http1: A boolean indicating if HTTP/1.1 requests should be supported
                by the connection pool. Defaults to True.
            http2: A boolean indicating if HTTP/2 requests should be supported by
                the connection pool. Defaults to False.
            retries: The maximum number of retries when trying to establish
                a connection. Passed to the base pool only; `create_connection()`
                does not forward it to the proxy connections.
            local_address: Local address to connect from. Can also be used to
                connect using a particular address family. Using
                `local_address="0.0.0.0"` will connect using an `AF_INET` address
                (IPv4), while using `local_address="::"` will connect using an
                `AF_INET6` address (IPv6). Passed to the base pool only;
                `create_connection()` does not forward it to the proxy connections.
            uds: Path to a Unix Domain Socket to use instead of TCP sockets.
                Passed to the base pool only; `create_connection()` does not
                forward it to the proxy connections.
            network_backend: A backend instance to use for handling network I/O.
            socket_options: Socket options forwarded to the connections that
                are opened to the proxy server.

        Raises:
            RuntimeError: If `proxy_ssl_context` is given while `proxy_url`
                uses the `http` scheme.
        """
        super().__init__(
            ssl_context=ssl_context,
            max_connections=max_connections,
            max_keepalive_connections=max_keepalive_connections,
            keepalive_expiry=keepalive_expiry,
            http1=http1,
            http2=http2,
            network_backend=network_backend,
            retries=retries,
            local_address=local_address,
            uds=uds,
            socket_options=socket_options,
        )

        self._proxy_url = enforce_url(proxy_url, name="proxy_url")
        if (
            self._proxy_url.scheme == b"http" and proxy_ssl_context is not None
        ):  # pragma: no cover
            raise RuntimeError(
                "The `proxy_ssl_context` argument is not allowed for the http scheme"
            )

        self._ssl_context = ssl_context
        self._proxy_ssl_context = proxy_ssl_context
        # Stored here so create_connection() can forward the options to every
        # proxy connection; previously they were silently dropped.
        self._socket_options = socket_options
        self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
        if proxy_auth is not None:
            username = enforce_bytes(proxy_auth[0], name="proxy_auth")
            password = enforce_bytes(proxy_auth[1], name="proxy_auth")
            authorization = build_auth_header(username, password)
            self._proxy_headers = [
                (b"Proxy-Authorization", authorization)
            ] + self._proxy_headers

    def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
        """
        Create the proxy connection used to reach `origin`.

        `http` origins get an `AsyncForwardHTTPConnection`; every other scheme
        gets an `AsyncTunnelHTTPConnection`.
        """
        if origin.scheme == b"http":
            return AsyncForwardHTTPConnection(
                proxy_origin=self._proxy_url.origin,
                proxy_headers=self._proxy_headers,
                remote_origin=origin,
                keepalive_expiry=self._keepalive_expiry,
                network_backend=self._network_backend,
                socket_options=self._socket_options,
                proxy_ssl_context=self._proxy_ssl_context,
            )
        return AsyncTunnelHTTPConnection(
            proxy_origin=self._proxy_url.origin,
            proxy_headers=self._proxy_headers,
            remote_origin=origin,
            ssl_context=self._ssl_context,
            proxy_ssl_context=self._proxy_ssl_context,
            keepalive_expiry=self._keepalive_expiry,
            http1=self._http1,
            http2=self._http2,
            network_backend=self._network_backend,
            socket_options=self._socket_options,
        )
168
169
class AsyncForwardHTTPConnection(AsyncConnectionInterface):
    """
    Relays requests for a plain-`http` origin through a proxy.

    Each request is re-addressed to the proxy origin, using the full original
    URL (`bytes(request.url)`) as the request target. The configured proxy
    headers are merged in underneath the request's own headers. Connection
    lifecycle and state queries are delegated to a single
    `AsyncHTTPConnection` opened to the proxy.
    """

    def __init__(
        self,
        proxy_origin: Origin,
        remote_origin: Origin,
        proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None,
        keepalive_expiry: Optional[float] = None,
        network_backend: Optional[AsyncNetworkBackend] = None,
        socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
        proxy_ssl_context: Optional[ssl.SSLContext] = None,
    ) -> None:
        # The only real connection is the one to the proxy server itself.
        proxy_connection = AsyncHTTPConnection(
            origin=proxy_origin,
            keepalive_expiry=keepalive_expiry,
            network_backend=network_backend,
            socket_options=socket_options,
            ssl_context=proxy_ssl_context,
        )
        self._connection = proxy_connection
        self._proxy_origin = proxy_origin
        self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
        self._remote_origin = remote_origin

    async def handle_async_request(self, request: Request) -> Response:
        """Send `request` to the proxy, addressed with its full original URL."""
        # Headers on the request itself take precedence over proxy defaults.
        merged_headers = merge_headers(self._proxy_headers, request.headers)
        proxy = self._proxy_origin
        proxied_url = URL(
            scheme=proxy.scheme,
            host=proxy.host,
            port=proxy.port,
            target=bytes(request.url),
        )
        forwarded = Request(
            method=request.method,
            url=proxied_url,
            headers=merged_headers,
            content=request.stream,
            extensions=request.extensions,
        )
        response = await self._connection.handle_async_request(forwarded)
        return response

    def can_handle_request(self, origin: Origin) -> bool:
        """Only requests for the remote origin this connection was built for."""
        return origin == self._remote_origin

    async def aclose(self) -> None:
        """Close the underlying connection to the proxy."""
        await self._connection.aclose()

    # The state queries below simply report on the proxy connection.

    def info(self) -> str:
        return self._connection.info()

    def is_available(self) -> bool:
        return self._connection.is_available()

    def has_expired(self) -> bool:
        return self._connection.has_expired()

    def is_idle(self) -> bool:
        return self._connection.is_idle()

    def is_closed(self) -> bool:
        return self._connection.is_closed()

    def __repr__(self) -> str:
        class_name = type(self).__name__
        return f"<{class_name} [{self.info()}]>"
232
233
class AsyncTunnelHTTPConnection(AsyncConnectionInterface):
    """
    Sends requests to a remote origin through a CONNECT tunnel via a proxy.

    On the first request, a `CONNECT host:port` request is sent to the proxy.
    The returned network stream is upgraded to TLS for the remote origin, and
    the proxy connection is then replaced by an HTTP/1.1 or HTTP/2 connection
    running over that stream. Later requests use that connection directly.
    """

    def __init__(
        self,
        proxy_origin: Origin,
        remote_origin: Origin,
        ssl_context: Optional[ssl.SSLContext] = None,
        proxy_ssl_context: Optional[ssl.SSLContext] = None,
        proxy_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
        keepalive_expiry: Optional[float] = None,
        http1: bool = True,
        http2: bool = False,
        network_backend: Optional[AsyncNetworkBackend] = None,
        socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
    ) -> None:
        # Until the tunnel is established this is a connection to the proxy
        # itself; handle_async_request() swaps it for the HTTP/1.1 or HTTP/2
        # connection running over the tunnelled stream.
        self._connection: AsyncConnectionInterface = AsyncHTTPConnection(
            origin=proxy_origin,
            keepalive_expiry=keepalive_expiry,
            network_backend=network_backend,
            socket_options=socket_options,
            ssl_context=proxy_ssl_context,
        )
        self._proxy_origin = proxy_origin
        self._remote_origin = remote_origin
        self._ssl_context = ssl_context
        self._proxy_ssl_context = proxy_ssl_context
        self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
        self._keepalive_expiry = keepalive_expiry
        self._http1 = http1
        self._http2 = http2
        # Held while checking/establishing the tunnel so concurrent requests
        # do not race to set it up.
        self._connect_lock = AsyncLock()
        self._connected = False

    def _connect_target(self) -> bytes:
        """
        Return the `host:port` authority for the CONNECT request and Host header.

        IPv6 literals are wrapped in square brackets (RFC 3986 section 3.2.2).
        Without them the address's colons are indistinguishable from the port
        separator, e.g. `::1:443` instead of `[::1]:443`.
        """
        host = self._remote_origin.host
        if b":" in host and not host.startswith(b"["):
            host = b"[" + host + b"]"
        return b"%b:%d" % (host, self._remote_origin.port)

    async def handle_async_request(self, request: Request) -> Response:
        """
        Send `request` to the remote origin, establishing the tunnel first if needed.

        The "connect" entry of the request's "timeout" extension (if any) is
        used as the timeout for the TLS upgrade.

        Raises:
            ProxyError: If the proxy answers the CONNECT request with a
                status outside 200-299. The proxy connection is closed first.
        """
        timeouts = request.extensions.get("timeout", {})
        timeout = timeouts.get("connect", None)

        async with self._connect_lock:
            if not self._connected:
                target = self._connect_target()

                connect_url = URL(
                    scheme=self._proxy_origin.scheme,
                    host=self._proxy_origin.host,
                    port=self._proxy_origin.port,
                    target=target,
                )
                # Configured proxy headers override these defaults.
                connect_headers = merge_headers(
                    [(b"Host", target), (b"Accept", b"*/*")], self._proxy_headers
                )
                connect_request = Request(
                    method=b"CONNECT",
                    url=connect_url,
                    headers=connect_headers,
                    extensions=request.extensions,
                )
                connect_response = await self._connection.handle_async_request(
                    connect_request
                )

                if connect_response.status < 200 or connect_response.status > 299:
                    reason_bytes = connect_response.extensions.get("reason_phrase", b"")
                    reason_str = reason_bytes.decode("ascii", errors="ignore")
                    msg = "%d %s" % (connect_response.status, reason_str)
                    await self._connection.aclose()
                    raise ProxyError(msg)

                stream = connect_response.extensions["network_stream"]

                # Upgrade the stream to SSL
                ssl_context = (
                    default_ssl_context()
                    if self._ssl_context is None
                    else self._ssl_context
                )
                # NOTE: ALPN is set on the context in place, so a caller-supplied
                # ssl_context is modified by this call.
                alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
                ssl_context.set_alpn_protocols(alpn_protocols)

                kwargs = {
                    "ssl_context": ssl_context,
                    # The bare host is used here (no IPv6 brackets), unlike
                    # the CONNECT target above.
                    "server_hostname": self._remote_origin.host.decode("ascii"),
                    "timeout": timeout,
                }
                async with Trace("start_tls", logger, request, kwargs) as trace:
                    stream = await stream.start_tls(**kwargs)
                    trace.return_value = stream

                # Determine if we should be using HTTP/1.1 or HTTP/2
                ssl_object = stream.get_extra_info("ssl_object")
                http2_negotiated = (
                    ssl_object is not None
                    and ssl_object.selected_alpn_protocol() == "h2"
                )

                # Create the HTTP/1.1 or HTTP/2 connection. HTTP/2 is used when
                # ALPN selected "h2", or when HTTP/2 is the only enabled protocol.
                if http2_negotiated or (self._http2 and not self._http1):
                    from .http2 import AsyncHTTP2Connection

                    self._connection = AsyncHTTP2Connection(
                        origin=self._remote_origin,
                        stream=stream,
                        keepalive_expiry=self._keepalive_expiry,
                    )
                else:
                    self._connection = AsyncHTTP11Connection(
                        origin=self._remote_origin,
                        stream=stream,
                        keepalive_expiry=self._keepalive_expiry,
                    )

                self._connected = True
        return await self._connection.handle_async_request(request)

    def can_handle_request(self, origin: Origin) -> bool:
        """Return True only for the remote origin this tunnel was created for."""
        return origin == self._remote_origin

    async def aclose(self) -> None:
        """Close the current connection (the proxy or the tunnelled one)."""
        await self._connection.aclose()

    # State queries delegate to whichever connection is current: the proxy
    # connection before tunnelling, the tunnelled connection afterwards.

    def info(self) -> str:
        return self._connection.info()

    def is_available(self) -> bool:
        return self._connection.is_available()

    def has_expired(self) -> bool:
        return self._connection.has_expired()

    def is_idle(self) -> bool:
        return self._connection.is_idle()

    def is_closed(self) -> bool:
        return self._connection.is_closed()

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} [{self.info()}]>"