/src/ztunnel/src/proxy/socks5.rs
Line | Count | Source |
1 | | // Copyright Istio Authors |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | use anyhow::Result; |
16 | | use byteorder::{BigEndian, ByteOrder}; |
17 | | |
18 | | use crate::dns::resolver::Resolver; |
19 | | use hickory_proto::op::{Message, MessageType, Query}; |
20 | | use hickory_proto::rr::{Name, RecordType}; |
21 | | use hickory_proto::serialize::binary::BinDecodable; |
22 | | use hickory_proto::xfer::Protocol; |
23 | | use hickory_server::authority::MessageRequest; |
24 | | use hickory_server::server::Request; |
25 | | use std::net::{IpAddr, Ipv4Addr, SocketAddr}; |
26 | | use std::sync::Arc; |
27 | | use std::time::Instant; |
28 | | use tokio::io::AsyncReadExt; |
29 | | use tokio::io::AsyncWriteExt; |
30 | | use tokio::net::TcpStream; |
31 | | use tokio::sync::watch; |
32 | | use tracing::{Instrument, debug, error, info, info_span, warn}; |
33 | | |
34 | | use crate::drain::DrainWatcher; |
35 | | use crate::drain::run_with_drain; |
36 | | use crate::proxy::outbound::OutboundConnection; |
37 | | use crate::proxy::{Error, ProxyInputs, TraceParent, util}; |
38 | | use crate::{assertions, socket}; |
39 | | |
/// SOCKS5 proxy listener.
///
/// Bundles the shared proxy inputs, the bound listening socket and the drain
/// watcher. Built by [`Socks5::new`] and consumed by [`Socks5::run`].
pub(super) struct Socks5 {
    // Shared proxy state: config (`cfg`), `socket_factory`, `metrics`, `resolver`, etc.
    pi: Arc<ProxyInputs>,
    // Listening socket bound to `pi.cfg.socks5_addr` in `Socks5::new`.
    listener: socket::Listener,
    // Drain watcher handed to `run_with_drain` by `Socks5::run`.
    drain: DrainWatcher,
}
45 | | |
impl Socks5 {
    /// Binds the SOCKS5 listener on `pi.cfg.socks5_addr` and logs the bound address.
    ///
    /// # Panics
    /// Panics if `pi.cfg.socks5_addr` is `None`; presumably callers only construct the
    /// SOCKS5 proxy when that address is configured (TODO confirm at call site).
    ///
    /// # Errors
    /// Returns [`Error::Bind`] if binding fails, or the error from
    /// `maybe_set_transparent`.
    pub(super) async fn new(pi: Arc<ProxyInputs>, drain: DrainWatcher) -> Result<Socks5, Error> {
        let listener = pi
            .socket_factory
            .tcp_bind(pi.cfg.socks5_addr.unwrap())
            .map_err(|e| Error::Bind(pi.cfg.socks5_addr.unwrap(), e))?;

        // In this function the result is only used for the log line below.
        let transparent = super::maybe_set_transparent(&pi, &listener)?;

        info!(
            address=%listener.local_addr(),
            component="socks5",
            transparent,
            "listener established",
        );

        Ok(Socks5 {
            pi,
            listener,
            drain,
        })
    }

    /// Returns the local address the SOCKS5 listener is bound to.
    pub(super) fn address(&self) -> SocketAddr {
        self.listener.local_addr()
    }

    /// Runs the SOCKS5 accept loop under `run_with_drain`.
    ///
    /// Every accepted connection is served on its own spawned task, which ends either
    /// when `handle_socks_connection` completes or when `force_shutdown` fires. The loop
    /// returns when an accept error indicates runtime shutdown.
    pub async fn run(self) {
        // `self.pi` is moved into the `accept` closure below; keep a separate handle so
        // the termination deadline can still be read when calling `run_with_drain`.
        let pi = self.pi.clone();
        // One HBONE connection pool shared (via clone) by all connections on this listener.
        let pool = crate::proxy::pool::WorkloadHBONEPool::new(
            self.pi.cfg.clone(),
            self.pi.socket_factory.clone(),
            self.pi.local_workload_information.clone(),
        );
        let accept = async move |drain: DrainWatcher, force_shutdown: watch::Receiver<()>| {
            loop {
                // Asynchronously wait for an inbound socket.
                let socket = self.listener.accept().await;
                let start = Instant::now();
                // Each connection holds its own drain handle until it finishes (see the
                // explicit `drop(drain)` in the spawned task).
                let drain = drain.clone();
                let mut force_shutdown = force_shutdown.clone();
                match socket {
                    Ok((stream, _remote)) => {
                        let socket_labels = crate::proxy::metrics::SocketLabels {
                            reporter: crate::proxy::metrics::Reporter::source,
                        };
                        self.pi.metrics.record_socket_open(&socket_labels);

                        // Per-connection outbound state with a fresh trace id.
                        let oc = OutboundConnection {
                            pi: self.pi.clone(),
                            id: TraceParent::new(),
                            pool: pool.clone(),
                            hbone_port: self.pi.cfg.inbound_addr.port(),
                        };
                        let span = info_span!("socks5", id=%oc.id);
                        let metrics_for_socket_close = self.pi.metrics.clone();
                        let serve = (async move {
                            // Presumably records the socket-close metric when dropped at the
                            // end of this task (TODO confirm in proxy::metrics).
                            let _socket_guard = crate::proxy::metrics::SocketCloseGuard::new(
                                metrics_for_socket_close,
                                crate::proxy::metrics::Reporter::source,
                            );
                            debug!(component="socks5", "connection started");
                            // Since this task is spawned, make sure we are guaranteed to terminate
                            tokio::select! {
                                _ = force_shutdown.changed() => {
                                    debug!(component="socks5", "connection forcefully terminated");
                                }
                                _ = handle_socks_connection(oc, stream) => {}
                            }
                            // Mark we are done with the connection, so drain can complete
                            drop(drain);
                            debug!(component="socks5", dur=?start.elapsed(), "connection completed");
                        }).instrument(span);

                        // Presumably asserts the spawned future's size stays within
                        // 1000..2000 bytes (TODO confirm in crate::assertions).
                        assertions::size_between_ref(1000, 2000, &serve);
                        tokio::spawn(serve);
                    }
                    Err(e) => {
                        // Accept failing because the runtime is shutting down ends the loop.
                        if util::is_runtime_shutdown(&e) {
                            return;
                        }
                        // NOTE(review): other accept errors are logged and the loop retries
                        // immediately with no backoff; a persistent error (e.g. fd
                        // exhaustion) could spin here — confirm this is acceptable.
                        error!("Failed TCP handshake {}", e);
                    }
                }
            }
        };

        // Presumably runs `accept` until drain is signaled, forcing shutdown after
        // `self_termination_deadline` (TODO confirm in crate::drain).
        run_with_drain(
            "socks5".to_string(),
            self.drain,
            pi.cfg.self_termination_deadline,
            accept,
        )
        .await
    }
}
142 | 0 | async fn handle_socks_connection(mut oc: OutboundConnection, mut stream: TcpStream) { |
143 | 0 | match negotiate_socks_connection(&oc.pi, &mut stream).await { |
144 | 0 | Ok(target) => { |
145 | | // TODO: ideally, we send the success after we connect. This allows us to actually give a |
146 | | // success only when we really succeeded, rather than if we completed the SOCKS handshake. |
147 | | // Additionally, it would allow us to get a proper address to send back/ |
148 | 0 | let dummy_addr = SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0); |
149 | 0 | if let Err(err) = send_success(&mut stream, dummy_addr).await { |
150 | 0 | warn!("failed to send socks success response: {err}"); |
151 | 0 | return; |
152 | 0 | } |
153 | 0 | let remote_addr = |
154 | 0 | socket::to_canonical(stream.peer_addr().expect("must receive peer addr")); |
155 | 0 | oc.proxy_to(stream, remote_addr, target).await |
156 | | } |
157 | 0 | Err(e) => { |
158 | 0 | warn!("failed to negotiate socks connection: {e}"); |
159 | 0 | send_error(&e, &mut stream).await; |
160 | | } |
161 | | } |
162 | 0 | } |
163 | | |
164 | | // negotiate_socks_connection will handle the negotiation of a SOCKS5 connection. |
165 | | // This ultimately outputs the target socket address, if the handshake is successful. |
166 | | // This supports a minimal subset of the protocol, sufficient to integrate with common clients: |
167 | | // - only unauthenticated requests |
168 | | // - only CONNECT, with IPv4/IPv6/Hostname |
169 | 0 | async fn negotiate_socks_connection( |
170 | 0 | pi: &ProxyInputs, |
171 | 0 | stream: &mut TcpStream, |
172 | 0 | ) -> Result<SocketAddr, SocksError> { |
173 | 0 | let remote_addr = socket::to_canonical(stream.peer_addr().expect("must receive peer addr")); |
174 | | |
175 | | // Version(5), Number of auth methods |
176 | 0 | let mut version = [0u8; 2]; |
177 | 0 | stream.read_exact(&mut version).await?; |
178 | | |
179 | 0 | if version[0] != 0x05 { |
180 | 0 | return Err(SocksError::invalid_protocol(format!( |
181 | 0 | "unsupported version {}", |
182 | 0 | version[0] |
183 | 0 | ))); |
184 | 0 | } |
185 | | |
186 | 0 | let nmethods = version[1]; |
187 | | |
188 | 0 | if nmethods == 0 { |
189 | 0 | return Err(SocksError::invalid_protocol(format!( |
190 | 0 | "methods cannot be zero {}", |
191 | 0 | version[0] |
192 | 0 | ))); |
193 | 0 | } |
194 | | |
195 | | // List of supported auth methods |
196 | 0 | let mut methods = vec![0u8; nmethods as usize]; |
197 | 0 | stream.read_exact(&mut methods).await?; |
198 | | |
199 | | // Client must include 'unauthenticated' (0). |
200 | 0 | if !methods.into_iter().any(|x| x == 0) { |
201 | 0 | return Err(SocksError::invalid_protocol( |
202 | 0 | "only unauthenticated is supported".to_string(), |
203 | 0 | )); |
204 | 0 | } |
205 | | |
206 | | // Select 'unauthenticated' (0). |
207 | 0 | stream.write_all(&[0x05, 0x00]).await?; |
208 | | |
209 | | // Version(5), Command - only support CONNECT (1) |
210 | 0 | let mut version_command = [0u8; 2]; |
211 | 0 | stream.read_exact(&mut version_command).await?; |
212 | 0 | let version = version_command[0]; |
213 | | |
214 | 0 | if version != 0x05 { |
215 | 0 | return Err(SocksError::invalid_protocol(format!( |
216 | 0 | "unsupported version {version}", |
217 | 0 | ))); |
218 | 0 | } |
219 | | |
220 | 0 | if version_command[1] != 1 { |
221 | 0 | return Err(SocksError::invalid_protocol(format!( |
222 | 0 | "unsupported command {}", |
223 | 0 | version_command[1] |
224 | 0 | ))); |
225 | 0 | } |
226 | | |
227 | | // Skip RSV |
228 | 0 | stream.read_exact(&mut [0]).await?; |
229 | | |
230 | | // Address type |
231 | 0 | let mut atyp = [0u8]; |
232 | 0 | stream.read_exact(&mut atyp).await?; |
233 | | |
234 | 0 | let ip = match atyp[0] { |
235 | | 0x01 => { |
236 | 0 | let mut hostb = [0u8; 4]; |
237 | 0 | stream.read_exact(&mut hostb).await?; |
238 | 0 | IpAddr::V4(hostb.into()) |
239 | | } |
240 | | 0x04 => { |
241 | 0 | let mut hostb = [0u8; 16]; |
242 | 0 | stream.read_exact(&mut hostb).await?; |
243 | 0 | IpAddr::V6(hostb.into()) |
244 | | } |
245 | | 0x03 => { |
246 | 0 | let mut domain_length = [0u8]; |
247 | 0 | stream.read_exact(&mut domain_length).await?; |
248 | 0 | let mut domain = vec![0u8; domain_length[0] as usize]; |
249 | 0 | stream.read_exact(&mut domain).await?; |
250 | | |
251 | 0 | let Ok(ds) = std::str::from_utf8(&domain) else { |
252 | 0 | return Err(SocksError::invalid_protocol(format!( |
253 | 0 | "domain is not a valid utf8 string: {domain:?}" |
254 | 0 | ))); |
255 | | }; |
256 | 0 | let Some(resolver) = &pi.resolver else { |
257 | 0 | return Err(SocksError::invalid_protocol( |
258 | 0 | "unsupported hostname lookup, requires DNS enabled".to_string(), |
259 | 0 | )); |
260 | | }; |
261 | | |
262 | 0 | match dns_lookup(resolver.clone(), remote_addr, ds).await { |
263 | 0 | Ok(ip) => ip, |
264 | 0 | Err(e) => { |
265 | 0 | return Err(SocksError::HostUnreachable(e)); |
266 | | } |
267 | | } |
268 | | } |
269 | 0 | n => { |
270 | 0 | return Err(SocksError::invalid_protocol(format!( |
271 | 0 | "unsupported address type {n}", |
272 | 0 | ))); |
273 | | } |
274 | | }; |
275 | | |
276 | 0 | let mut port = [0u8; 2]; |
277 | 0 | stream.read_exact(&mut port).await?; |
278 | 0 | let port = BigEndian::read_u16(&port); |
279 | | |
280 | 0 | let host = SocketAddr::new(ip, port); |
281 | | |
282 | 0 | Ok(host) |
283 | 0 | } |
284 | | |
285 | 0 | async fn dns_lookup( |
286 | 0 | resolver: Arc<dyn Resolver + Send + Sync>, |
287 | 0 | client_addr: SocketAddr, |
288 | 0 | hostname: &str, |
289 | 0 | ) -> Result<IpAddr, Error> { |
290 | 0 | fn new_message(name: Name, rr_type: RecordType) -> Message { |
291 | 0 | let mut msg = Message::new(); |
292 | 0 | msg.set_id(rand::random()); |
293 | 0 | msg.set_message_type(MessageType::Query); |
294 | 0 | msg.set_recursion_desired(true); |
295 | 0 | msg.add_query(Query::query(name, rr_type)); |
296 | 0 | msg |
297 | 0 | } |
298 | | /// Converts the given [Message] into a server-side [Request] with dummy values for |
299 | | /// the client IP and protocol. |
300 | 0 | fn server_request(msg: &Message, client_addr: SocketAddr, protocol: Protocol) -> Request { |
301 | 0 | let wire_bytes = msg.to_vec().unwrap(); |
302 | 0 | let msg_request = MessageRequest::from_bytes(&wire_bytes).unwrap(); |
303 | 0 | Request::new(msg_request, client_addr, protocol) |
304 | 0 | } |
305 | | |
306 | | /// Creates a A-record [Request] for the given name. |
307 | 0 | fn a_request(name: Name, client_addr: SocketAddr, protocol: Protocol) -> Request { |
308 | 0 | server_request(&new_message(name, RecordType::A), client_addr, protocol) |
309 | 0 | } |
310 | | |
311 | | /// Creates a AAAA-record [Request] for the given name. |
312 | 0 | fn aaaa_request(name: Name, client_addr: SocketAddr, protocol: Protocol) -> Request { |
313 | 0 | server_request(&new_message(name, RecordType::AAAA), client_addr, protocol) |
314 | 0 | } |
315 | | |
316 | | // TODO: do we need to do the search? |
317 | 0 | let name = Name::from_utf8(hostname)?; |
318 | | |
319 | | // TODO: we probably want to race them or something. Is there something higher level that can handle this for us? |
320 | 0 | let req = if client_addr.is_ipv4() { |
321 | 0 | a_request(name, client_addr, Protocol::Udp) |
322 | | } else { |
323 | 0 | aaaa_request(name, client_addr, Protocol::Udp) |
324 | | }; |
325 | 0 | let answer = resolver.lookup(&req).await?; |
326 | 0 | let response = answer |
327 | 0 | .record_iter() |
328 | 0 | .filter_map(|rec| rec.data().ip_addr()) |
329 | 0 | .next() // TODO: do not always use the first result |
330 | 0 | .ok_or_else(|| Error::DnsEmpty)?; |
331 | | |
332 | 0 | Ok(response) |
333 | 0 | } |
334 | | |
335 | | /// send_error sends an error back to the SOCKS client |
336 | | /// This may fail, but since there is nothing a caller can do about it, failures are simply logged and |
337 | | /// not returned. |
338 | 0 | pub async fn send_error(err: &SocksError, source: &mut TcpStream) { |
339 | | // SOCKS response requires us to send a 'server bound address'. |
340 | | // It's supposed to be the local address we have bound to. |
341 | | // In many cases, when we are fail we don't have this. |
342 | 0 | let dummy_addr = SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0); |
343 | 0 | if let Err(e) = send_response(Some(err), source, dummy_addr).await { |
344 | 0 | warn!("failed to send socks error: {e}") |
345 | 0 | } |
346 | 0 | } |
347 | | |
348 | | /// send_success sends a success back to the SOCKS client. |
349 | 0 | pub async fn send_success(source: &mut TcpStream, local_addr: SocketAddr) -> Result<(), Error> { |
350 | 0 | send_response(None, source, local_addr).await |
351 | 0 | } |
352 | | |
353 | 0 | async fn send_response( |
354 | 0 | err: Option<&SocksError>, |
355 | 0 | source: &mut TcpStream, |
356 | 0 | local_addr: SocketAddr, |
357 | 0 | ) -> Result<(), Error> { |
358 | | // https://www.rfc-editor.org/rfc/rfc1928#section-6 |
359 | 0 | let mut buf: Vec<u8> = Vec::with_capacity(10); |
360 | 0 | buf.push(0x05); // version |
361 | | // Status |
362 | 0 | buf.push(match err { |
363 | 0 | None => 0, |
364 | 0 | Some(SocksError::General(_)) => 1, |
365 | 0 | Some(SocksError::NotAllowed(_)) => 2, |
366 | 0 | Some(SocksError::NetworkUnreachable(_)) => 3, |
367 | 0 | Some(SocksError::HostUnreachable(_)) => 4, |
368 | 0 | Some(SocksError::ConnectionRefused(_)) => 5, |
369 | 0 | Some(SocksError::CommandNotSupported(_)) => 7, |
370 | | }); |
371 | 0 | buf.push(0); // RSV |
372 | 0 | match local_addr { |
373 | 0 | SocketAddr::V4(addr_v4) => { |
374 | 0 | buf.push(0x01); // IPv4 address type |
375 | 0 | buf.extend_from_slice(&addr_v4.ip().octets()); |
376 | 0 | } |
377 | 0 | SocketAddr::V6(addr_v6) => { |
378 | 0 | buf.push(0x04); // IPv6 address type |
379 | 0 | buf.extend_from_slice(&addr_v6.ip().octets()); |
380 | 0 | } |
381 | | } |
382 | | // Add port in network byte order (big-endian) |
383 | 0 | buf.extend_from_slice(&local_addr.port().to_be_bytes()); |
384 | 0 | source.write_all(&buf).await?; |
385 | 0 | Ok(()) |
386 | 0 | } |
387 | | |
/// SocksError maps outbound errors to SOCKS5 protocol errors.
/// See https://datatracker.ietf.org/doc/html/rfc1928#section-6.
/// While the socks protocol only allows the int error, we record the full error
/// for our own logging purposes.
///
/// Each variant corresponds to one SOCKS5 reply code (the REP field written by
/// `send_response`); the code is noted on each variant.
#[derive(thiserror::Error, Debug)]
// Presumably allowed because not every variant is constructed yet: in this file only
// `General`, `HostUnreachable` and `CommandNotSupported` are built (TODO confirm
// other modules don't construct the rest before removing this).
#[allow(dead_code)]
pub enum SocksError {
    /// General SOCKS server failure (REP = 0x01).
    #[error("General: {0}")]
    General(Error),
    /// Connection not allowed by ruleset (REP = 0x02).
    #[error("NotAllowed: {0}")]
    NotAllowed(Error),
    /// Network unreachable (REP = 0x03).
    #[error("NetworkUnreachable: {0}")]
    NetworkUnreachable(Error),
    /// Host unreachable (REP = 0x04).
    #[error("HostUnreachable: {0}")]
    HostUnreachable(Error),
    /// Connection refused (REP = 0x05).
    #[error("ConnectionRefused: {0}")]
    ConnectionRefused(Error),
    /// Command not supported (REP = 0x07).
    #[error("CommandNotSupported: {0}")]
    CommandNotSupported(Error),
}
408 | | |
409 | | impl SocksError { |
410 | 0 | pub fn invalid_protocol(reason: String) -> SocksError { |
411 | 0 | SocksError::CommandNotSupported(Error::Anyhow(anyhow::anyhow!(reason))) |
412 | 0 | } |
413 | | } |
414 | | |
415 | | impl From<std::io::Error> for SocksError { |
416 | 0 | fn from(value: std::io::Error) -> Self { |
417 | 0 | SocksError::General(Error::Io(value)) |
418 | 0 | } |
419 | | } |