Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/httpcore/_sync/http_proxy.py: 33%

126 statements  

« prev     ^ index     » next       coverage.py v7.2.2, created at 2023-03-26 06:12 +0000

1import ssl 

2from base64 import b64encode 

3from typing import List, Mapping, Optional, Sequence, Tuple, Union 

4 

5from .._exceptions import ProxyError 

6from .._models import ( 

7 URL, 

8 Origin, 

9 Request, 

10 Response, 

11 enforce_bytes, 

12 enforce_headers, 

13 enforce_url, 

14) 

15from .._ssl import default_ssl_context 

16from .._synchronization import Lock 

17from .._trace import Trace 

18from ..backends.base import NetworkBackend 

19from .connection import HTTPConnection 

20from .connection_pool import ConnectionPool 

21from .http11 import HTTP11Connection 

22from .interfaces import ConnectionInterface 

23 

# A header name or value may be supplied either as bytes or as str.
_BytesOrStr = Union[bytes, str]

# Headers given as an ordered sequence of (name, value) two-tuples.
HeadersAsSequence = Sequence[Tuple[_BytesOrStr, _BytesOrStr]]
# Headers given as a name -> value mapping.
HeadersAsMapping = Mapping[_BytesOrStr, _BytesOrStr]

26 

27 

def merge_headers(
    default_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
    override_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
) -> List[Tuple[bytes, bytes]]:
    """
    Combine two header lists, letting the overrides win on name clashes.

    Header names are compared case-insensitively. A default header whose
    name also occurs among the overrides is dropped; the remaining defaults
    come first, followed by every override header in its given order.
    Passing `None` for either argument is the same as passing no headers.
    """
    defaults: List[Tuple[bytes, bytes]] = []
    if default_headers is not None:
        defaults = list(default_headers)

    overrides: List[Tuple[bytes, bytes]] = []
    if override_headers is not None:
        overrides = list(override_headers)

    # Lower-cased names of every header the caller explicitly overrides.
    overridden_names = {name.lower() for name, _ in overrides}

    surviving_defaults = [
        (name, value)
        for name, value in defaults
        if name.lower() not in overridden_names
    ]
    return surviving_defaults + overrides

45 

46 

def build_auth_header(username: bytes, password: bytes) -> bytes:
    """
    Return a `Basic` authorization header value for the given credentials.

    The value is `b"Basic "` followed by the base64 encoding of
    `username:password`.
    """
    credentials = b":".join((username, password))
    return b"Basic " + b64encode(credentials)

50 

51 

class HTTPProxy(ConnectionPool):
    """
    A connection pool that sends requests via an HTTP proxy.

    Connections for `http` origins are `ForwardHTTPConnection` instances;
    origins with any other scheme get a `TunnelHTTPConnection`.
    """

    def __init__(
        self,
        proxy_url: Union[URL, bytes, str],
        proxy_auth: Optional[Tuple[Union[bytes, str], Union[bytes, str]]] = None,
        proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None,
        ssl_context: Optional[ssl.SSLContext] = None,
        max_connections: Optional[int] = 10,
        max_keepalive_connections: Optional[int] = None,
        keepalive_expiry: Optional[float] = None,
        http1: bool = True,
        http2: bool = False,
        retries: int = 0,
        local_address: Optional[str] = None,
        uds: Optional[str] = None,
        network_backend: Optional[NetworkBackend] = None,
    ) -> None:
        """
        A connection pool for making HTTP requests through a proxy.

        Parameters:
            proxy_url: The URL to use when connecting to the proxy server.
                For example `"http://127.0.0.1:8080/"`.
            proxy_auth: Optional proxy credentials as a two-tuple of
                (username, password). Each item may be bytes or an
                ascii-only str. When given, a `Proxy-Authorization` header
                with a `Basic` value is placed ahead of `proxy_headers`.
            proxy_headers: Any HTTP headers to use for the proxy requests,
                as a mapping or as a sequence of (name, value) two-tuples.
            ssl_context: An SSL context to use for verifying connections.
                If not specified, the default
                `httpcore.default_ssl_context()` will be used.
            max_connections: The maximum number of concurrent HTTP
                connections that the pool should allow. Any attempt to send
                a request on a pool that would exceed this amount will block
                until a connection is available.
            max_keepalive_connections: The maximum number of idle HTTP
                connections that will be maintained in the pool.
            keepalive_expiry: The duration in seconds that an idle HTTP
                connection may be maintained for before being expired from
                the pool.
            http1: Whether HTTP/1.1 requests should be supported by the
                connection pool. Defaults to True.
            http2: Whether HTTP/2 requests should be supported by the
                connection pool. Defaults to False.
            retries: The maximum number of retries when trying to establish
                a connection.
            local_address: Local address to connect from. Can also be used
                to connect using a particular address family. Using
                `local_address="0.0.0.0"` will connect using an `AF_INET`
                address (IPv4), while using `local_address="::"` will
                connect using an `AF_INET6` address (IPv6).
            uds: Path to a Unix Domain Socket to use instead of TCP sockets.
            network_backend: A backend instance to use for handling
                network I/O.
        """
        super().__init__(
            ssl_context=ssl_context,
            max_connections=max_connections,
            max_keepalive_connections=max_keepalive_connections,
            keepalive_expiry=keepalive_expiry,
            http1=http1,
            http2=http2,
            retries=retries,
            local_address=local_address,
            uds=uds,
            network_backend=network_backend,
        )
        self._ssl_context = ssl_context
        self._proxy_url = enforce_url(proxy_url, name="proxy_url")
        self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")

        if proxy_auth is None:
            return

        # Credentials were supplied: put the Basic auth header in front of
        # whatever proxy headers the caller passed.
        username = enforce_bytes(proxy_auth[0], name="proxy_auth")
        password = enforce_bytes(proxy_auth[1], name="proxy_auth")
        auth_entry = (
            b"Proxy-Authorization",
            build_auth_header(username, password),
        )
        self._proxy_headers = [auth_entry] + self._proxy_headers

    def create_connection(self, origin: Origin) -> ConnectionInterface:
        """
        Return a new proxied connection for `origin`.

        Plain `http` origins are served by a `ForwardHTTPConnection`; every
        other scheme is served by a `TunnelHTTPConnection`.
        """
        needs_tunnel = origin.scheme != b"http"
        proxy_origin = self._proxy_url.origin

        if needs_tunnel:
            return TunnelHTTPConnection(
                proxy_origin=proxy_origin,
                proxy_headers=self._proxy_headers,
                remote_origin=origin,
                ssl_context=self._ssl_context,
                keepalive_expiry=self._keepalive_expiry,
                http1=self._http1,
                http2=self._http2,
                network_backend=self._network_backend,
            )

        return ForwardHTTPConnection(
            proxy_origin=proxy_origin,
            proxy_headers=self._proxy_headers,
            remote_origin=origin,
            keepalive_expiry=self._keepalive_expiry,
            network_backend=self._network_backend,
        )

149 

150 

class ForwardHTTPConnection(ConnectionInterface):
    """
    A connection that forwards requests for one remote origin to a proxy.

    Each request is re-addressed to the proxy origin, with the original
    request URL used as the request target, and is sent over a wrapped
    `HTTPConnection` to the proxy.
    """

    def __init__(
        self,
        proxy_origin: Origin,
        remote_origin: Origin,
        proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None,
        keepalive_expiry: Optional[float] = None,
        network_backend: Optional[NetworkBackend] = None,
    ) -> None:
        """
        Parameters:
            proxy_origin: Origin of the proxy server to connect to.
            remote_origin: The origin this connection handles requests for.
            proxy_headers: Headers merged into every forwarded request.
            keepalive_expiry: Passed through to the wrapped `HTTPConnection`.
            network_backend: Passed through to the wrapped `HTTPConnection`.
        """
        self._connection = HTTPConnection(
            origin=proxy_origin,
            keepalive_expiry=keepalive_expiry,
            network_backend=network_backend,
        )
        self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
        self._proxy_origin = proxy_origin
        self._remote_origin = remote_origin

    def handle_request(self, request: Request) -> Response:
        """
        Send `request` via the proxy and return the proxy's response.

        Proxy headers act as defaults; headers on the request itself take
        precedence when the same name appears in both.
        """
        merged_headers = merge_headers(self._proxy_headers, request.headers)
        proxy = self._proxy_origin
        # Address the proxy, carrying the full original URL as the target.
        forward_url = URL(
            scheme=proxy.scheme,
            host=proxy.host,
            port=proxy.port,
            target=bytes(request.url),
        )
        forwarded = Request(
            method=request.method,
            url=forward_url,
            headers=merged_headers,
            content=request.stream,
            extensions=request.extensions,
        )
        return self._connection.handle_request(forwarded)

    def can_handle_request(self, origin: Origin) -> bool:
        """Return True when `origin` equals this connection's remote origin."""
        return origin == self._remote_origin

    def close(self) -> None:
        """Close the wrapped connection."""
        self._connection.close()

    def info(self) -> str:
        """Return the wrapped connection's `info()` string."""
        return self._connection.info()

    def is_available(self) -> bool:
        """Delegate to the wrapped connection's `is_available()`."""
        return self._connection.is_available()

    def has_expired(self) -> bool:
        """Delegate to the wrapped connection's `has_expired()`."""
        return self._connection.has_expired()

    def is_idle(self) -> bool:
        """Delegate to the wrapped connection's `is_idle()`."""
        return self._connection.is_idle()

    def is_closed(self) -> bool:
        """Delegate to the wrapped connection's `is_closed()`."""
        return self._connection.is_closed()

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return f"<{class_name} [{self.info()}]>"

209 

210 

class TunnelHTTPConnection(ConnectionInterface):
    """
    A connection that reaches one remote origin through a `CONNECT` tunnel.

    The first request triggers a `CONNECT` to the proxy, upgrades the
    resulting network stream to TLS, and swaps the wrapped connection for an
    HTTP/1.1 or HTTP/2 connection to the remote origin. Later requests go
    straight to that connection.
    """

    def __init__(
        self,
        proxy_origin: Origin,
        remote_origin: Origin,
        ssl_context: Optional[ssl.SSLContext] = None,
        proxy_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
        keepalive_expiry: Optional[float] = None,
        http1: bool = True,
        http2: bool = False,
        network_backend: Optional[NetworkBackend] = None,
    ) -> None:
        """
        Parameters:
            proxy_origin: Origin of the proxy server to connect to.
            remote_origin: The origin this connection handles requests for.
            ssl_context: SSL context for the tunnelled TLS stream; when
                None, `default_ssl_context()` is used at connect time.
            proxy_headers: Headers sent with the `CONNECT` request.
            keepalive_expiry: Passed to the underlying connections.
            http1: Whether HTTP/1.1 may be used over the tunnel.
            http2: Whether HTTP/2 may be used over the tunnel.
            network_backend: Passed to the initial `HTTPConnection`.
        """
        # Starts as a connection to the proxy; replaced once the tunnel is up.
        self._connection: ConnectionInterface = HTTPConnection(
            origin=proxy_origin,
            keepalive_expiry=keepalive_expiry,
            network_backend=network_backend,
        )
        self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
        self._connect_lock = Lock()
        self._connected = False

        self._proxy_origin = proxy_origin
        self._remote_origin = remote_origin
        self._ssl_context = ssl_context
        self._keepalive_expiry = keepalive_expiry
        self._http1 = http1
        self._http2 = http2

    def handle_request(self, request: Request) -> Response:
        """
        Send `request` to the remote origin, opening the tunnel if needed.

        Raises:
            ProxyError: If the proxy answers the `CONNECT` request with a
                status outside the 200-299 range.
        """
        timeout_config = request.extensions.get("timeout", {})
        connect_timeout = timeout_config.get("connect", None)

        with self._connect_lock:
            if not self._connected:
                self._open_tunnel(request, connect_timeout)
                self._connected = True

        return self._connection.handle_request(request)

    def _open_tunnel(self, request: Request, timeout: Optional[float]) -> None:
        """
        Issue `CONNECT`, start TLS on the tunnel and install the connection.

        Must be called with `self._connect_lock` held.
        """
        remote = self._remote_origin
        proxy = self._proxy_origin

        # The CONNECT target is "host:port" of the remote origin.
        authority = b"%b:%d" % (remote.host, remote.port)

        tunnel_url = URL(
            scheme=proxy.scheme,
            host=proxy.host,
            port=proxy.port,
            target=authority,
        )
        tunnel_headers = merge_headers(
            [(b"Host", authority), (b"Accept", b"*/*")], self._proxy_headers
        )
        tunnel_request = Request(
            method=b"CONNECT",
            url=tunnel_url,
            headers=tunnel_headers,
            extensions=request.extensions,
        )
        tunnel_response = self._connection.handle_request(tunnel_request)

        status = tunnel_response.status
        if not 200 <= status <= 299:
            # The proxy refused the tunnel: report its status line and
            # drop the connection to the proxy.
            reason_bytes = tunnel_response.extensions.get("reason_phrase", b"")
            reason_str = reason_bytes.decode("ascii", errors="ignore")
            msg = "%d %s" % (tunnel_response.status, reason_str)
            self._connection.close()
            raise ProxyError(msg)

        stream = tunnel_response.extensions["network_stream"]

        # Upgrade the stream to SSL
        tls_context = self._ssl_context
        if tls_context is None:
            tls_context = default_ssl_context()

        offered_protocols = ["http/1.1"]
        if self._http2:
            offered_protocols.append("h2")
        tls_context.set_alpn_protocols(offered_protocols)

        tls_kwargs = {
            "ssl_context": tls_context,
            "server_hostname": remote.host.decode("ascii"),
            "timeout": timeout,
        }
        with Trace("connection.start_tls", request, tls_kwargs) as trace:
            stream = stream.start_tls(**tls_kwargs)
            trace.return_value = stream

        # Determine if we should be using HTTP/1.1 or HTTP/2
        ssl_object = stream.get_extra_info("ssl_object")
        negotiated_h2 = (
            ssl_object is not None
            and ssl_object.selected_alpn_protocol() == "h2"
        )

        # Create the HTTP/1.1 or HTTP/2 connection
        if negotiated_h2 or (self._http2 and not self._http1):
            from .http2 import HTTP2Connection

            self._connection = HTTP2Connection(
                origin=remote,
                stream=stream,
                keepalive_expiry=self._keepalive_expiry,
            )
        else:
            self._connection = HTTP11Connection(
                origin=remote,
                stream=stream,
                keepalive_expiry=self._keepalive_expiry,
            )

    def can_handle_request(self, origin: Origin) -> bool:
        """Return True when `origin` equals this connection's remote origin."""
        return origin == self._remote_origin

    def close(self) -> None:
        """Close the current underlying connection."""
        self._connection.close()

    def info(self) -> str:
        """Return the current underlying connection's `info()` string."""
        return self._connection.info()

    def is_available(self) -> bool:
        """Delegate to the underlying connection's `is_available()`."""
        return self._connection.is_available()

    def has_expired(self) -> bool:
        """Delegate to the underlying connection's `has_expired()`."""
        return self._connection.has_expired()

    def is_idle(self) -> bool:
        """Delegate to the underlying connection's `is_idle()`."""
        return self._connection.is_idle()

    def is_closed(self) -> bool:
        """Delegate to the underlying connection's `is_closed()`."""
        return self._connection.is_closed()

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return f"<{class_name} [{self.info()}]>"