1import logging
2import ssl
3import typing
4
5from socksio import socks5
6
7from .._backends.auto import AutoBackend
8from .._backends.base import AsyncNetworkBackend, AsyncNetworkStream
9from .._exceptions import ConnectionNotAvailable, ProxyError
10from .._models import URL, Origin, Request, Response, enforce_bytes, enforce_url
11from .._ssl import default_ssl_context
12from .._synchronization import AsyncLock
13from .._trace import Trace
14from .connection_pool import AsyncConnectionPool
15from .http11 import AsyncHTTP11Connection
16from .interfaces import AsyncConnectionInterface
17
# Module-level logger; handed to the `Trace` calls in this module.
logger = logging.getLogger("httpcore.socks")


# Human-readable names for SOCKS5 authentication methods, keyed by the raw
# one-byte method identifier (names as listed in RFC 1928).
# Only used by `_init_socks5_connection` to build the `ProxyError` message
# when the proxy selects a different method from the one that was offered.
AUTH_METHODS = {
    b"\x00": "NO AUTHENTICATION REQUIRED",
    b"\x01": "GSSAPI",
    b"\x02": "USERNAME/PASSWORD",
    b"\xff": "NO ACCEPTABLE METHODS",
}
27
# Human-readable descriptions of SOCKS5 reply codes, keyed by the raw
# one-byte code (descriptions as listed in RFC 1928).
# Only used by `_init_socks5_connection` to build the `ProxyError` message
# when the proxy answers a CONNECT request with anything but success.
REPLY_CODES = {
    b"\x00": "Succeeded",
    b"\x01": "General SOCKS server failure",
    b"\x02": "Connection not allowed by ruleset",
    b"\x03": "Network unreachable",
    b"\x04": "Host unreachable",
    b"\x05": "Connection refused",
    b"\x06": "TTL expired",
    b"\x07": "Command not supported",
    b"\x08": "Address type not supported",
}
39
40
async def _init_socks5_connection(
    stream: AsyncNetworkStream,
    *,
    host: typing.Union[bytes, str],
    port: int,
    auth: typing.Optional[typing.Tuple[bytes, bytes]] = None,
) -> None:
    """
    Perform the SOCKS5 handshake on an already-connected proxy stream.

    The handshake runs in up to three request/response round trips:

    1. Offer a single authentication method: username/password when `auth`
       is given, otherwise "no authentication required".
    2. If username/password was selected, send the credentials.
    3. Send a CONNECT command for `(host, port)`.

    Parameters:
        stream: The network stream connected to the proxy server.
        host: The remote host to ask the proxy to connect to. The caller in
            this module passes an ASCII-decoded `str`.
        port: The remote port to ask the proxy to connect to.
        auth: An optional `(username, password)` pair of bytes.

    Raises:
        ProxyError: If the proxy selects a different authentication method
            from the one offered, rejects the credentials, or replies to the
            CONNECT command with a non-success code.
    """
    conn = socks5.SOCKS5Connection()

    # Auth method request. Exactly one method is offered, chosen by whether
    # credentials were supplied.
    auth_method = (
        socks5.SOCKS5AuthMethod.NO_AUTH_REQUIRED
        if auth is None
        else socks5.SOCKS5AuthMethod.USERNAME_PASSWORD
    )
    conn.send(socks5.SOCKS5AuthMethodsRequest([auth_method]))
    outgoing_bytes = conn.data_to_send()
    await stream.write(outgoing_bytes)

    # Auth method response
    incoming_bytes = await stream.read(max_bytes=4096)
    response = conn.receive_data(incoming_bytes)
    # Narrows the type of the parsed event for the attribute access below.
    assert isinstance(response, socks5.SOCKS5AuthReply)
    if response.method != auth_method:
        requested = AUTH_METHODS.get(auth_method, "UNKNOWN")
        responded = AUTH_METHODS.get(response.method, "UNKNOWN")
        raise ProxyError(
            f"Requested {requested} from proxy server, but got {responded}."
        )

    if response.method == socks5.SOCKS5AuthMethod.USERNAME_PASSWORD:
        # Username/password request
        assert auth is not None
        username, password = auth
        conn.send(socks5.SOCKS5UsernamePasswordRequest(username, password))
        outgoing_bytes = conn.data_to_send()
        await stream.write(outgoing_bytes)

        # Username/password response
        incoming_bytes = await stream.read(max_bytes=4096)
        response = conn.receive_data(incoming_bytes)
        assert isinstance(response, socks5.SOCKS5UsernamePasswordReply)
        if not response.success:
            raise ProxyError("Invalid username/password")

    # Connect request
    conn.send(
        socks5.SOCKS5CommandRequest.from_address(
            socks5.SOCKS5Command.CONNECT, (host, port)
        )
    )
    outgoing_bytes = conn.data_to_send()
    await stream.write(outgoing_bytes)

    # Connect response
    incoming_bytes = await stream.read(max_bytes=4096)
    response = conn.receive_data(incoming_bytes)
    assert isinstance(response, socks5.SOCKS5Reply)
    if response.reply_code != socks5.SOCKS5ReplyCode.SUCCEEDED:
        # Fallback label fixed from the misspelled "UNKOWN" so it matches the
        # label used for unrecognised auth methods above.
        reply_code = REPLY_CODES.get(response.reply_code, "UNKNOWN")
        raise ProxyError(f"Proxy Server could not connect: {reply_code}.")
102
103
class AsyncSOCKSProxy(AsyncConnectionPool):
    """
    A connection pool that sends requests via a SOCKS5 proxy.
    """

    def __init__(
        self,
        proxy_url: typing.Union[URL, bytes, str],
        proxy_auth: typing.Optional[
            typing.Tuple[typing.Union[bytes, str], typing.Union[bytes, str]]
        ] = None,
        ssl_context: typing.Optional[ssl.SSLContext] = None,
        max_connections: typing.Optional[int] = 10,
        max_keepalive_connections: typing.Optional[int] = None,
        keepalive_expiry: typing.Optional[float] = None,
        http1: bool = True,
        http2: bool = False,
        retries: int = 0,
        network_backend: typing.Optional[AsyncNetworkBackend] = None,
    ) -> None:
        """
        A connection pool for making HTTP requests through a SOCKS5 proxy.

        Parameters:
            proxy_url: The URL to use when connecting to the proxy server.
                For example `"socks5://127.0.0.1:1080/"`.
            proxy_auth: An optional `(username, password)` pair used to
                authenticate with the proxy. Each element may be given as
                `bytes` or `str`.
            ssl_context: An SSL context to use for verifying connections.
                If not specified, the default `httpcore.default_ssl_context()`
                will be used.
            max_connections: The maximum number of concurrent HTTP connections
                that the pool should allow. Any attempt to send a request on a
                pool that would exceed this amount will block until a
                connection is available.
            max_keepalive_connections: The maximum number of idle HTTP
                connections that will be maintained in the pool.
            keepalive_expiry: The duration in seconds that an idle HTTP
                connection may be maintained for before being expired from the
                pool.
            http1: A boolean indicating if HTTP/1.1 requests should be
                supported by the connection pool. Defaults to True.
            http2: A boolean indicating if HTTP/2 requests should be supported
                by the connection pool. Defaults to False.
            retries: The maximum number of retries when trying to establish
                a connection. Forwarded to the parent connection pool.
            network_backend: A backend instance to use for handling network
                I/O.
        """
        super().__init__(
            ssl_context=ssl_context,
            max_connections=max_connections,
            max_keepalive_connections=max_keepalive_connections,
            keepalive_expiry=keepalive_expiry,
            http1=http1,
            http2=http2,
            retries=retries,
            network_backend=network_backend,
        )
        self._ssl_context = ssl_context
        self._proxy_url = enforce_url(proxy_url, name="proxy_url")

        # Normalise the credentials to a pair of bytes, or leave them unset.
        credentials: typing.Optional[typing.Tuple[bytes, bytes]] = None
        if proxy_auth is not None:
            raw_username, raw_password = proxy_auth
            credentials = (
                enforce_bytes(raw_username, name="proxy_auth"),
                enforce_bytes(raw_password, name="proxy_auth"),
            )
        self._proxy_auth: typing.Optional[typing.Tuple[bytes, bytes]] = credentials

    def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
        """
        Build a new, not-yet-connected SOCKS5 connection for `origin`.

        The connection targets `origin` via this pool's proxy, carrying over
        the pool's credentials, SSL context, keep-alive expiry, HTTP version
        flags and network backend.
        """
        return AsyncSocks5Connection(
            remote_origin=origin,
            proxy_origin=self._proxy_url.origin,
            proxy_auth=self._proxy_auth,
            ssl_context=self._ssl_context,
            network_backend=self._network_backend,
            keepalive_expiry=self._keepalive_expiry,
            http1=self._http1,
            http2=self._http2,
        )
188
189
class AsyncSocks5Connection(AsyncConnectionInterface):
    """
    A single connection to a remote origin, tunnelled through a SOCKS5 proxy.

    The underlying HTTP/1.1 or HTTP/2 connection is created lazily, on the
    first call to `handle_async_request`. Until then, the state-reporting
    methods answer from the `_connect_failed` flag.
    """

    def __init__(
        self,
        proxy_origin: Origin,
        remote_origin: Origin,
        proxy_auth: typing.Optional[typing.Tuple[bytes, bytes]] = None,
        ssl_context: typing.Optional[ssl.SSLContext] = None,
        keepalive_expiry: typing.Optional[float] = None,
        http1: bool = True,
        http2: bool = False,
        network_backend: typing.Optional[AsyncNetworkBackend] = None,
    ) -> None:
        """
        Parameters:
            proxy_origin: The origin of the SOCKS5 proxy server.
            remote_origin: The origin that requests are ultimately sent to.
            proxy_auth: An optional `(username, password)` pair of bytes used
                to authenticate with the proxy.
            ssl_context: An SSL context for `https` remote origins. If not
                given, `default_ssl_context()` is used at connect time.
            keepalive_expiry: Passed through to the HTTP connection.
            http1: Whether HTTP/1.1 is enabled.
            http2: Whether HTTP/2 is enabled.
            network_backend: The backend used for network I/O. Defaults to
                `AutoBackend()`.
        """
        self._proxy_origin = proxy_origin
        self._remote_origin = remote_origin
        self._proxy_auth = proxy_auth
        self._ssl_context = ssl_context
        self._keepalive_expiry = keepalive_expiry
        self._http1 = http1
        self._http2 = http2

        self._network_backend: AsyncNetworkBackend = (
            AutoBackend() if network_backend is None else network_backend
        )
        # Serialises connection setup across concurrent requests.
        self._connect_lock = AsyncLock()
        self._connection: typing.Optional[AsyncConnectionInterface] = None
        self._connect_failed = False

    async def handle_async_request(self, request: Request) -> Response:
        """
        Send `request`, establishing the proxied connection first if needed.

        On the first call this connects to the proxy over TCP, performs the
        SOCKS5 handshake for the remote origin, upgrades to TLS for `https`
        origins, and wraps the stream in an HTTP/1.1 or HTTP/2 connection.

        The `"timeout"` request extension's `"connect"` value is applied to
        the TCP connect and TLS upgrade. The `"sni_hostname"` extension, if
        set, overrides the TLS server hostname.

        Raises:
            ConnectionNotAvailable: If a connection already exists but is not
                currently available.
            Exception: Any error raised during setup is re-raised after
                marking the connection as failed and closing the stream.
        """
        timeouts = request.extensions.get("timeout", {})
        sni_hostname = request.extensions.get("sni_hostname", None)
        timeout = timeouts.get("connect", None)

        async with self._connect_lock:
            if self._connection is None:
                # Declared before the `try` so the failure path can tell
                # whether a stream was opened and needs closing.
                stream: typing.Optional[AsyncNetworkStream] = None
                try:
                    # Connect to the proxy
                    kwargs = {
                        "host": self._proxy_origin.host.decode("ascii"),
                        "port": self._proxy_origin.port,
                        "timeout": timeout,
                    }
                    async with Trace("connect_tcp", logger, request, kwargs) as trace:
                        stream = await self._network_backend.connect_tcp(**kwargs)
                        trace.return_value = stream

                    # Connect to the remote host using socks5
                    kwargs = {
                        "stream": stream,
                        "host": self._remote_origin.host.decode("ascii"),
                        "port": self._remote_origin.port,
                        "auth": self._proxy_auth,
                    }
                    async with Trace(
                        "setup_socks5_connection", logger, request, kwargs
                    ) as trace:
                        await _init_socks5_connection(**kwargs)
                        trace.return_value = stream

                    # Upgrade the stream to SSL
                    if self._remote_origin.scheme == b"https":
                        ssl_context = (
                            default_ssl_context()
                            if self._ssl_context is None
                            else self._ssl_context
                        )
                        alpn_protocols = (
                            ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
                        )
                        ssl_context.set_alpn_protocols(alpn_protocols)

                        kwargs = {
                            "ssl_context": ssl_context,
                            "server_hostname": sni_hostname
                            or self._remote_origin.host.decode("ascii"),
                            "timeout": timeout,
                        }
                        async with Trace("start_tls", logger, request, kwargs) as trace:
                            stream = await stream.start_tls(**kwargs)
                            trace.return_value = stream

                    # Determine if we should be using HTTP/1.1 or HTTP/2
                    ssl_object = stream.get_extra_info("ssl_object")
                    http2_negotiated = (
                        ssl_object is not None
                        and ssl_object.selected_alpn_protocol() == "h2"
                    )

                    # Create the HTTP/1.1 or HTTP/2 connection
                    if http2_negotiated or (
                        self._http2 and not self._http1
                    ):  # pragma: nocover
                        from .http2 import AsyncHTTP2Connection

                        self._connection = AsyncHTTP2Connection(
                            origin=self._remote_origin,
                            stream=stream,
                            keepalive_expiry=self._keepalive_expiry,
                        )
                    else:
                        self._connection = AsyncHTTP11Connection(
                            origin=self._remote_origin,
                            stream=stream,
                            keepalive_expiry=self._keepalive_expiry,
                        )
                except Exception:
                    self._connect_failed = True
                    if stream is not None:
                        # Setup failed after the TCP connection was opened
                        # (e.g. the proxy refused the handshake). Close the
                        # stream so the socket is not leaked. This is
                        # best-effort: a failure to close must not mask the
                        # original error.
                        try:
                            await stream.aclose()
                        except Exception:
                            pass
                    raise
            elif not self._connection.is_available():  # pragma: nocover
                raise ConnectionNotAvailable()

        return await self._connection.handle_async_request(request)

    def can_handle_request(self, origin: Origin) -> bool:
        """Return True if `origin` is the remote origin of this connection."""
        return origin == self._remote_origin

    async def aclose(self) -> None:
        """Close the underlying HTTP connection, if one was established."""
        if self._connection is not None:
            await self._connection.aclose()

    def is_available(self) -> bool:
        """Return True if the connection can accept another request."""
        if self._connection is None:  # pragma: nocover
            # If HTTP/2 support is enabled, and the resulting connection could
            # end up as HTTP/2 then we should indicate the connection as being
            # available to service multiple requests.
            return (
                self._http2
                and (self._remote_origin.scheme == b"https" or not self._http1)
                and not self._connect_failed
            )
        return self._connection.is_available()

    def has_expired(self) -> bool:
        """
        Return whether the connection has expired.

        Before a connection is established this is True only if connecting
        failed.
        """
        if self._connection is None:  # pragma: nocover
            return self._connect_failed
        return self._connection.has_expired()

    def is_idle(self) -> bool:
        """
        Return whether the connection is idle.

        Before a connection is established this is True only if connecting
        failed.
        """
        if self._connection is None:  # pragma: nocover
            return self._connect_failed
        return self._connection.is_idle()

    def is_closed(self) -> bool:
        """
        Return whether the connection is closed.

        Before a connection is established this is True only if connecting
        failed.
        """
        if self._connection is None:  # pragma: nocover
            return self._connect_failed
        return self._connection.is_closed()

    def info(self) -> str:
        """Return a short, human-readable description of the connection state."""
        if self._connection is None:  # pragma: nocover
            return "CONNECTION FAILED" if self._connect_failed else "CONNECTING"
        return self._connection.info()

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} [{self.info()}]>"