Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/httpcore/_async/http_proxy.py: 34%

130 statements  

import logging
import ssl
from base64 import b64encode
from typing import Iterable, List, Mapping, Optional, Sequence, Tuple, Union

from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend
from .._exceptions import ProxyError
from .._models import (
    URL,
    Origin,
    Request,
    Response,
    enforce_bytes,
    enforce_headers,
    enforce_url,
)
from .._ssl import default_ssl_context
from .._synchronization import AsyncLock
from .._trace import Trace
from .connection import AsyncHTTPConnection
from .connection_pool import AsyncConnectionPool
from .http11 import AsyncHTTP11Connection
from .interfaces import AsyncConnectionInterface

HeadersAsSequence = Sequence[Tuple[Union[bytes, str], Union[bytes, str]]]
HeadersAsMapping = Mapping[Union[bytes, str], Union[bytes, str]]


logger = logging.getLogger("httpcore.proxy")


def merge_headers(
    default_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
    override_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
) -> List[Tuple[bytes, bytes]]:
    """
    Append default_headers and override_headers, de-duplicating if a key exists
    in both cases.
    """
    default_headers = [] if default_headers is None else list(default_headers)
    override_headers = [] if override_headers is None else list(override_headers)
    has_override = set(key.lower() for key, value in override_headers)
    default_headers = [
        (key, value)
        for key, value in default_headers
        if key.lower() not in has_override
    ]
    return default_headers + override_headers
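
# A quick illustration of the merge rule above (the header values here are made
# up for the example): keys are compared case-insensitively, and a default entry
# is dropped when the same key is overridden, so
#
#     merge_headers(
#         [(b"Accept", b"*/*"), (b"Host", b"proxy.local")],
#         [(b"accept", b"text/html")],
#     )
#
# evaluates to [(b"Host", b"proxy.local"), (b"accept", b"text/html")].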



def build_auth_header(username: bytes, password: bytes) -> bytes:
    userpass = username + b":" + password
    return b"Basic " + b64encode(userpass)
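
# For example, build_auth_header(b"user", b"pass") returns b"Basic dXNlcjpwYXNz"
# (the base64 encoding of b"user:pass"); AsyncHTTPProxy below attaches this value
# as a Proxy-Authorization header when proxy_auth is supplied.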



class AsyncHTTPProxy(AsyncConnectionPool):
    """
    A connection pool that sends requests via an HTTP proxy.
    """

    def __init__(
        self,
        proxy_url: Union[URL, bytes, str],
        proxy_auth: Optional[Tuple[Union[bytes, str], Union[bytes, str]]] = None,
        proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None,
        ssl_context: Optional[ssl.SSLContext] = None,
        proxy_ssl_context: Optional[ssl.SSLContext] = None,
        max_connections: Optional[int] = 10,
        max_keepalive_connections: Optional[int] = None,
        keepalive_expiry: Optional[float] = None,
        http1: bool = True,
        http2: bool = False,
        retries: int = 0,
        local_address: Optional[str] = None,
        uds: Optional[str] = None,
        network_backend: Optional[AsyncNetworkBackend] = None,
        socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
    ) -> None:
        """
        A connection pool for making HTTP requests.

        Parameters:
            proxy_url: The URL to use when connecting to the proxy server.
                For example `"http://127.0.0.1:8080/"`.
            proxy_auth: Any proxy authentication as a two-tuple of
                (username, password). May be either bytes or ascii-only str.
            proxy_headers: Any HTTP headers to use for the proxy requests.
                For example `{"Proxy-Authorization": "Basic <username>:<password>"}`.
            ssl_context: An SSL context to use for verifying connections.
                If not specified, the default `httpcore.default_ssl_context()`
                will be used.
            proxy_ssl_context: The same as `ssl_context`, but for a proxy server rather than a remote origin.
            max_connections: The maximum number of concurrent HTTP connections that
                the pool should allow. Any attempt to send a request on a pool that
                would exceed this amount will block until a connection is available.
            max_keepalive_connections: The maximum number of idle HTTP connections
                that will be maintained in the pool.
            keepalive_expiry: The duration in seconds that an idle HTTP connection
                may be maintained for before being expired from the pool.
            http1: A boolean indicating if HTTP/1.1 requests should be supported
                by the connection pool. Defaults to True.
            http2: A boolean indicating if HTTP/2 requests should be supported by
                the connection pool. Defaults to False.
            retries: The maximum number of retries when trying to establish
                a connection.
            local_address: Local address to connect from. Can also be used to
                connect using a particular address family. Using
                `local_address="0.0.0.0"` will connect using an `AF_INET` address
                (IPv4), while using `local_address="::"` will connect using an
                `AF_INET6` address (IPv6).
            uds: Path to a Unix Domain Socket to use instead of TCP sockets.
            network_backend: A backend instance to use for handling network I/O.
        """
        super().__init__(
            ssl_context=ssl_context,
            max_connections=max_connections,
            max_keepalive_connections=max_keepalive_connections,
            keepalive_expiry=keepalive_expiry,
            http1=http1,
            http2=http2,
            network_backend=network_backend,
            retries=retries,
            local_address=local_address,
            uds=uds,
            socket_options=socket_options,
        )

        self._proxy_url = enforce_url(proxy_url, name="proxy_url")
        if (
            self._proxy_url.scheme == b"http" and proxy_ssl_context is not None
        ):  # pragma: no cover
            raise RuntimeError(
                "The `proxy_ssl_context` argument is not allowed for the http scheme"
            )

        self._ssl_context = ssl_context
        self._proxy_ssl_context = proxy_ssl_context
        self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
        if proxy_auth is not None:
            username = enforce_bytes(proxy_auth[0], name="proxy_auth")
            password = enforce_bytes(proxy_auth[1], name="proxy_auth")
            authorization = build_auth_header(username, password)
            self._proxy_headers = [
                (b"Proxy-Authorization", authorization)
            ] + self._proxy_headers

    def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
        if origin.scheme == b"http":
            return AsyncForwardHTTPConnection(
                proxy_origin=self._proxy_url.origin,
                proxy_headers=self._proxy_headers,
                remote_origin=origin,
                keepalive_expiry=self._keepalive_expiry,
                network_backend=self._network_backend,
                proxy_ssl_context=self._proxy_ssl_context,
            )
        return AsyncTunnelHTTPConnection(
            proxy_origin=self._proxy_url.origin,
            proxy_headers=self._proxy_headers,
            remote_origin=origin,
            ssl_context=self._ssl_context,
            proxy_ssl_context=self._proxy_ssl_context,
            keepalive_expiry=self._keepalive_expiry,
            http1=self._http1,
            http2=self._http2,
            network_backend=self._network_backend,
        )
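
# Illustrative usage sketch (assumes the public httpcore API; the proxy URL,
# credentials, and target URL are placeholders, and `request()` comes from the
# shared request interface rather than this module):
#
#     import httpcore
#
#     async def fetch_via_proxy() -> None:
#         async with httpcore.AsyncHTTPProxy(
#             proxy_url="http://127.0.0.1:8080/",
#             proxy_auth=("user", "pass"),
#         ) as proxy:
#             # "http://" origins are forwarded, "https://" origins are
#             # tunnelled via CONNECT (see the two connection classes below).
#             response = await proxy.request("GET", "https://www.example.com/")
#             print(response.status)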



class AsyncForwardHTTPConnection(AsyncConnectionInterface):
    def __init__(
        self,
        proxy_origin: Origin,
        remote_origin: Origin,
        proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None,
        keepalive_expiry: Optional[float] = None,
        network_backend: Optional[AsyncNetworkBackend] = None,
        socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
        proxy_ssl_context: Optional[ssl.SSLContext] = None,
    ) -> None:
        self._connection = AsyncHTTPConnection(
            origin=proxy_origin,
            keepalive_expiry=keepalive_expiry,
            network_backend=network_backend,
            socket_options=socket_options,
            ssl_context=proxy_ssl_context,
        )
        self._proxy_origin = proxy_origin
        self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
        self._remote_origin = remote_origin

    async def handle_async_request(self, request: Request) -> Response:
        headers = merge_headers(self._proxy_headers, request.headers)
        url = URL(
            scheme=self._proxy_origin.scheme,
            host=self._proxy_origin.host,
            port=self._proxy_origin.port,
            target=bytes(request.url),
        )
        proxy_request = Request(
            method=request.method,
            url=url,
            headers=headers,
            content=request.stream,
            extensions=request.extensions,
        )
        return await self._connection.handle_async_request(proxy_request)
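
    # Roughly: a request for "http://example.com/path" is re-issued against the
    # proxy origin with the absolute URL as its request target, so the proxy sees
    # something like "GET http://example.com/path HTTP/1.1", with any configured
    # proxy headers merged in (per-request headers take precedence).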


    def can_handle_request(self, origin: Origin) -> bool:
        return origin == self._remote_origin

    async def aclose(self) -> None:
        await self._connection.aclose()

    def info(self) -> str:
        return self._connection.info()

    def is_available(self) -> bool:
        return self._connection.is_available()

    def has_expired(self) -> bool:
        return self._connection.has_expired()

    def is_idle(self) -> bool:
        return self._connection.is_idle()

    def is_closed(self) -> bool:
        return self._connection.is_closed()

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} [{self.info()}]>"


class AsyncTunnelHTTPConnection(AsyncConnectionInterface):
    def __init__(
        self,
        proxy_origin: Origin,
        remote_origin: Origin,
        ssl_context: Optional[ssl.SSLContext] = None,
        proxy_ssl_context: Optional[ssl.SSLContext] = None,
        proxy_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
        keepalive_expiry: Optional[float] = None,
        http1: bool = True,
        http2: bool = False,
        network_backend: Optional[AsyncNetworkBackend] = None,
        socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
    ) -> None:
        self._connection: AsyncConnectionInterface = AsyncHTTPConnection(
            origin=proxy_origin,
            keepalive_expiry=keepalive_expiry,
            network_backend=network_backend,
            socket_options=socket_options,
            ssl_context=proxy_ssl_context,
        )
        self._proxy_origin = proxy_origin
        self._remote_origin = remote_origin
        self._ssl_context = ssl_context
        self._proxy_ssl_context = proxy_ssl_context
        self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
        self._keepalive_expiry = keepalive_expiry
        self._http1 = http1
        self._http2 = http2
        self._connect_lock = AsyncLock()
        self._connected = False

    async def handle_async_request(self, request: Request) -> Response:
        timeouts = request.extensions.get("timeout", {})
        timeout = timeouts.get("connect", None)

        async with self._connect_lock:
            if not self._connected:
                target = b"%b:%d" % (self._remote_origin.host, self._remote_origin.port)

                connect_url = URL(
                    scheme=self._proxy_origin.scheme,
                    host=self._proxy_origin.host,
                    port=self._proxy_origin.port,
                    target=target,
                )
                connect_headers = merge_headers(
                    [(b"Host", target), (b"Accept", b"*/*")], self._proxy_headers
                )
                connect_request = Request(
                    method=b"CONNECT",
                    url=connect_url,
                    headers=connect_headers,
                    extensions=request.extensions,
                )
                connect_response = await self._connection.handle_async_request(
                    connect_request
                )

                if connect_response.status < 200 or connect_response.status > 299:
                    reason_bytes = connect_response.extensions.get("reason_phrase", b"")
                    reason_str = reason_bytes.decode("ascii", errors="ignore")
                    msg = "%d %s" % (connect_response.status, reason_str)
                    await self._connection.aclose()
                    raise ProxyError(msg)

                stream = connect_response.extensions["network_stream"]

                # Upgrade the stream to SSL
                ssl_context = (
                    default_ssl_context()
                    if self._ssl_context is None
                    else self._ssl_context
                )
                alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
                ssl_context.set_alpn_protocols(alpn_protocols)

                kwargs = {
                    "ssl_context": ssl_context,
                    "server_hostname": self._remote_origin.host.decode("ascii"),
                    "timeout": timeout,
                }
                async with Trace("start_tls", logger, request, kwargs) as trace:
                    stream = await stream.start_tls(**kwargs)
                    trace.return_value = stream

                # Determine if we should be using HTTP/1.1 or HTTP/2
                ssl_object = stream.get_extra_info("ssl_object")
                http2_negotiated = (
                    ssl_object is not None
                    and ssl_object.selected_alpn_protocol() == "h2"
                )

                # Create the HTTP/1.1 or HTTP/2 connection
                if http2_negotiated or (self._http2 and not self._http1):
                    from .http2 import AsyncHTTP2Connection

                    self._connection = AsyncHTTP2Connection(
                        origin=self._remote_origin,
                        stream=stream,
                        keepalive_expiry=self._keepalive_expiry,
                    )
                else:
                    self._connection = AsyncHTTP11Connection(
                        origin=self._remote_origin,
                        stream=stream,
                        keepalive_expiry=self._keepalive_expiry,
                    )

                self._connected = True
        return await self._connection.handle_async_request(request)
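
    # Note: the CONNECT handshake and TLS upgrade above run at most once per
    # connection; `_connect_lock` serialises concurrent callers and `_connected`
    # is only set once a tunnelled HTTP/1.1 or HTTP/2 connection is in place, so
    # subsequent requests reuse it directly.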


    def can_handle_request(self, origin: Origin) -> bool:
        return origin == self._remote_origin

    async def aclose(self) -> None:
        await self._connection.aclose()

    def info(self) -> str:
        return self._connection.info()

    def is_available(self) -> bool:
        return self._connection.is_available()

    def has_expired(self) -> bool:
        return self._connection.has_expired()

    def is_idle(self) -> bool:
        return self._connection.is_idle()

    def is_closed(self) -> bool:
        return self._connection.is_closed()

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} [{self.info()}]>"