Coverage Report

Created: 2026-02-14 06:16

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/ztunnel/src/proxy/socks5.rs
Line
Count
Source
1
// Copyright Istio Authors
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use anyhow::Result;
16
use byteorder::{BigEndian, ByteOrder};
17
18
use crate::dns::resolver::Resolver;
19
use hickory_proto::op::{Message, MessageType, Query};
20
use hickory_proto::rr::{Name, RecordType};
21
use hickory_proto::serialize::binary::BinDecodable;
22
use hickory_proto::xfer::Protocol;
23
use hickory_server::authority::MessageRequest;
24
use hickory_server::server::Request;
25
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
26
use std::sync::Arc;
27
use std::time::Instant;
28
use tokio::io::AsyncReadExt;
29
use tokio::io::AsyncWriteExt;
30
use tokio::net::TcpStream;
31
use tokio::sync::watch;
32
use tracing::{Instrument, debug, error, info, info_span, warn};
33
34
use crate::drain::DrainWatcher;
35
use crate::drain::run_with_drain;
36
use crate::proxy::outbound::OutboundConnection;
37
use crate::proxy::{Error, ProxyInputs, TraceParent, util};
38
use crate::{assertions, socket};
39
40
/// SOCKS5 listener.
///
/// Accepts TCP connections on `cfg.socks5_addr`, runs a minimal SOCKS5 handshake on each one
/// (see `negotiate_socks_connection`) and then hands the stream to an [`OutboundConnection`]
/// to be proxied to the requested target.
pub(super) struct Socks5 {
    // Shared proxy inputs: config, socket factory, metrics, optional resolver, ...
    pi: Arc<ProxyInputs>,
    // The bound listening socket.
    listener: socket::Listener,
    // Passed to `run_with_drain`; clones are held by in-flight connections so draining can
    // wait for them to finish.
    drain: DrainWatcher,
}

impl Socks5 {
    /// Binds the SOCKS5 listener on `pi.cfg.socks5_addr` and applies the transparent-socket
    /// setting via `maybe_set_transparent`.
    ///
    /// # Panics
    /// Panics if `pi.cfg.socks5_addr` is unset (it is `unwrap`ped).
    /// NOTE(review): presumably callers only construct this when SOCKS5 is configured — confirm.
    ///
    /// # Errors
    /// Returns `Error::Bind` if the address cannot be bound, or the error from
    /// `maybe_set_transparent`.
    pub(super) async fn new(pi: Arc<ProxyInputs>, drain: DrainWatcher) -> Result<Socks5, Error> {
        let listener = pi
            .socket_factory
            .tcp_bind(pi.cfg.socks5_addr.unwrap())
            .map_err(|e| Error::Bind(pi.cfg.socks5_addr.unwrap(), e))?;

        let transparent = super::maybe_set_transparent(&pi, &listener)?;

        info!(
            address=%listener.local_addr(),
            component="socks5",
            transparent,
            "listener established",
        );

        Ok(Socks5 {
            pi,
            listener,
            drain,
        })
    }

    /// Returns the local address the listener is bound to.
    pub(super) fn address(&self) -> SocketAddr {
        self.listener.local_addr()
    }

    /// Runs the accept loop under `run_with_drain` until the listener is drained.
    ///
    /// Every accepted connection is served on its own spawned task. That task ends when the
    /// SOCKS connection completes or when a forced shutdown is signalled, whichever is first.
    pub async fn run(self) {
        // Kept outside the closure: `self.pi` is moved into `accept`, but the termination
        // deadline is still needed for `run_with_drain` below.
        let pi = self.pi.clone();
        // One HBONE pool for this listener; each connection receives a clone of it.
        let pool = crate::proxy::pool::WorkloadHBONEPool::new(
            self.pi.cfg.clone(),
            self.pi.socket_factory.clone(),
            self.pi.local_workload_information.clone(),
        );
        let accept = async move |drain: DrainWatcher, force_shutdown: watch::Receiver<()>| {
            loop {
                // Asynchronously wait for an inbound socket.
                let socket = self.listener.accept().await;
                let start = Instant::now();
                // Per-connection handles: the drain clone is dropped when the connection is
                // done, and the shutdown receiver lets the task be cut short.
                let drain = drain.clone();
                let mut force_shutdown = force_shutdown.clone();
                match socket {
                    Ok((stream, _remote)) => {
                        let socket_labels = crate::proxy::metrics::SocketLabels {
                            reporter: crate::proxy::metrics::Reporter::source,
                        };
                        self.pi.metrics.record_socket_open(&socket_labels);

                        let oc = OutboundConnection {
                            pi: self.pi.clone(),
                            id: TraceParent::new(),
                            pool: pool.clone(),
                            hbone_port: self.pi.cfg.inbound_addr.port(),
                        };
                        let span = info_span!("socks5", id=%oc.id);
                        let metrics_for_socket_close = self.pi.metrics.clone();
                        let serve = (async move {
                            // Guard is held for the lifetime of the task; presumably it records
                            // the socket-close metric on drop — confirm in proxy::metrics.
                            let _socket_guard = crate::proxy::metrics::SocketCloseGuard::new(
                                metrics_for_socket_close,
                                crate::proxy::metrics::Reporter::source,
                            );
                            debug!(component="socks5", "connection started");
                            // Since this task is spawned, make sure we are guaranteed to terminate
                            tokio::select! {
                                _ = force_shutdown.changed() => {
                                    debug!(component="socks5", "connection forcefully terminated");
                                }
                                _ = handle_socks_connection(oc, stream) => {}
                            }
                            // Mark we are done with the connection, so drain can complete
                            drop(drain);
                            debug!(component="socks5", dur=?start.elapsed(), "connection completed");
                        }).instrument(span);

                        // NOTE(review): presumably asserts the spawned future's size stays within
                        // [1000, 2000] bytes; adding captured state to `serve` can trip it.
                        assertions::size_between_ref(1000, 2000, &serve);
                        tokio::spawn(serve);
                    }
                    Err(e) => {
                        if util::is_runtime_shutdown(&e) {
                            return;
                        }
                        // NOTE(review): the loop retries immediately with no backoff; a persistent
                        // accept error would spin here — confirm this is acceptable.
                        error!("Failed TCP handshake {}", e);
                    }
                }
            }
        };

        run_with_drain(
            "socks5".to_string(),
            self.drain,
            pi.cfg.self_termination_deadline,
            accept,
        )
        .await
    }
}
142
0
async fn handle_socks_connection(mut oc: OutboundConnection, mut stream: TcpStream) {
143
0
    match negotiate_socks_connection(&oc.pi, &mut stream).await {
144
0
        Ok(target) => {
145
            // TODO: ideally, we send the success after we connect. This allows us to actually give a
146
            // success only when we really succeeded, rather than if we completed the SOCKS handshake.
147
            // Additionally, it would allow us to get a proper address to send back/
148
0
            let dummy_addr = SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0);
149
0
            if let Err(err) = send_success(&mut stream, dummy_addr).await {
150
0
                warn!("failed to send socks success response: {err}");
151
0
                return;
152
0
            }
153
0
            let remote_addr =
154
0
                socket::to_canonical(stream.peer_addr().expect("must receive peer addr"));
155
0
            oc.proxy_to(stream, remote_addr, target).await
156
        }
157
0
        Err(e) => {
158
0
            warn!("failed to negotiate socks connection: {e}");
159
0
            send_error(&e, &mut stream).await;
160
        }
161
    }
162
0
}
163
164
// negotiate_socks_connection will handle the negotiation of a SOCKS5 connection.
165
// This ultimately outputs the target socket address, if the handshake is successful.
166
// This supports a minimal subset of the protocol, sufficient to integrate with common clients:
167
// - only unauthenticated requests
168
// - only CONNECT, with IPv4/IPv6/Hostname
169
0
async fn negotiate_socks_connection(
170
0
    pi: &ProxyInputs,
171
0
    stream: &mut TcpStream,
172
0
) -> Result<SocketAddr, SocksError> {
173
0
    let remote_addr = socket::to_canonical(stream.peer_addr().expect("must receive peer addr"));
174
175
    // Version(5), Number of auth methods
176
0
    let mut version = [0u8; 2];
177
0
    stream.read_exact(&mut version).await?;
178
179
0
    if version[0] != 0x05 {
180
0
        return Err(SocksError::invalid_protocol(format!(
181
0
            "unsupported version {}",
182
0
            version[0]
183
0
        )));
184
0
    }
185
186
0
    let nmethods = version[1];
187
188
0
    if nmethods == 0 {
189
0
        return Err(SocksError::invalid_protocol(format!(
190
0
            "methods cannot be zero {}",
191
0
            version[0]
192
0
        )));
193
0
    }
194
195
    // List of supported auth methods
196
0
    let mut methods = vec![0u8; nmethods as usize];
197
0
    stream.read_exact(&mut methods).await?;
198
199
    // Client must include 'unauthenticated' (0).
200
0
    if !methods.into_iter().any(|x| x == 0) {
201
0
        return Err(SocksError::invalid_protocol(
202
0
            "only unauthenticated is supported".to_string(),
203
0
        ));
204
0
    }
205
206
    // Select 'unauthenticated' (0).
207
0
    stream.write_all(&[0x05, 0x00]).await?;
208
209
    // Version(5), Command - only support CONNECT (1)
210
0
    let mut version_command = [0u8; 2];
211
0
    stream.read_exact(&mut version_command).await?;
212
0
    let version = version_command[0];
213
214
0
    if version != 0x05 {
215
0
        return Err(SocksError::invalid_protocol(format!(
216
0
            "unsupported version {version}",
217
0
        )));
218
0
    }
219
220
0
    if version_command[1] != 1 {
221
0
        return Err(SocksError::invalid_protocol(format!(
222
0
            "unsupported command {}",
223
0
            version_command[1]
224
0
        )));
225
0
    }
226
227
    // Skip RSV
228
0
    stream.read_exact(&mut [0]).await?;
229
230
    // Address type
231
0
    let mut atyp = [0u8];
232
0
    stream.read_exact(&mut atyp).await?;
233
234
0
    let ip = match atyp[0] {
235
        0x01 => {
236
0
            let mut hostb = [0u8; 4];
237
0
            stream.read_exact(&mut hostb).await?;
238
0
            IpAddr::V4(hostb.into())
239
        }
240
        0x04 => {
241
0
            let mut hostb = [0u8; 16];
242
0
            stream.read_exact(&mut hostb).await?;
243
0
            IpAddr::V6(hostb.into())
244
        }
245
        0x03 => {
246
0
            let mut domain_length = [0u8];
247
0
            stream.read_exact(&mut domain_length).await?;
248
0
            let mut domain = vec![0u8; domain_length[0] as usize];
249
0
            stream.read_exact(&mut domain).await?;
250
251
0
            let Ok(ds) = std::str::from_utf8(&domain) else {
252
0
                return Err(SocksError::invalid_protocol(format!(
253
0
                    "domain is not a valid utf8 string: {domain:?}"
254
0
                )));
255
            };
256
0
            let Some(resolver) = &pi.resolver else {
257
0
                return Err(SocksError::invalid_protocol(
258
0
                    "unsupported hostname lookup, requires DNS enabled".to_string(),
259
0
                ));
260
            };
261
262
0
            match dns_lookup(resolver.clone(), remote_addr, ds).await {
263
0
                Ok(ip) => ip,
264
0
                Err(e) => {
265
0
                    return Err(SocksError::HostUnreachable(e));
266
                }
267
            }
268
        }
269
0
        n => {
270
0
            return Err(SocksError::invalid_protocol(format!(
271
0
                "unsupported address type {n}",
272
0
            )));
273
        }
274
    };
275
276
0
    let mut port = [0u8; 2];
277
0
    stream.read_exact(&mut port).await?;
278
0
    let port = BigEndian::read_u16(&port);
279
280
0
    let host = SocketAddr::new(ip, port);
281
282
0
    Ok(host)
283
0
}
284
285
0
async fn dns_lookup(
286
0
    resolver: Arc<dyn Resolver + Send + Sync>,
287
0
    client_addr: SocketAddr,
288
0
    hostname: &str,
289
0
) -> Result<IpAddr, Error> {
290
0
    fn new_message(name: Name, rr_type: RecordType) -> Message {
291
0
        let mut msg = Message::new();
292
0
        msg.set_id(rand::random());
293
0
        msg.set_message_type(MessageType::Query);
294
0
        msg.set_recursion_desired(true);
295
0
        msg.add_query(Query::query(name, rr_type));
296
0
        msg
297
0
    }
298
    /// Converts the given [Message] into a server-side [Request] with dummy values for
299
    /// the client IP and protocol.
300
0
    fn server_request(msg: &Message, client_addr: SocketAddr, protocol: Protocol) -> Request {
301
0
        let wire_bytes = msg.to_vec().unwrap();
302
0
        let msg_request = MessageRequest::from_bytes(&wire_bytes).unwrap();
303
0
        Request::new(msg_request, client_addr, protocol)
304
0
    }
305
306
    /// Creates a A-record [Request] for the given name.
307
0
    fn a_request(name: Name, client_addr: SocketAddr, protocol: Protocol) -> Request {
308
0
        server_request(&new_message(name, RecordType::A), client_addr, protocol)
309
0
    }
310
311
    /// Creates a AAAA-record [Request] for the given name.
312
0
    fn aaaa_request(name: Name, client_addr: SocketAddr, protocol: Protocol) -> Request {
313
0
        server_request(&new_message(name, RecordType::AAAA), client_addr, protocol)
314
0
    }
315
316
    // TODO: do we need to do the search?
317
0
    let name = Name::from_utf8(hostname)?;
318
319
    // TODO: we probably want to race them or something. Is there something higher level that can handle this for us?
320
0
    let req = if client_addr.is_ipv4() {
321
0
        a_request(name, client_addr, Protocol::Udp)
322
    } else {
323
0
        aaaa_request(name, client_addr, Protocol::Udp)
324
    };
325
0
    let answer = resolver.lookup(&req).await?;
326
0
    let response = answer
327
0
        .record_iter()
328
0
        .filter_map(|rec| rec.data().ip_addr())
329
0
        .next() // TODO: do not always use the first result
330
0
        .ok_or_else(|| Error::DnsEmpty)?;
331
332
0
    Ok(response)
333
0
}
334
335
/// send_error sends an error back to the SOCKS client
336
/// This may fail, but since there is nothing a caller can do about it, failures are simply logged and
337
/// not returned.
338
0
pub async fn send_error(err: &SocksError, source: &mut TcpStream) {
339
    // SOCKS response requires us to send a 'server bound address'.
340
    // It's supposed to be the local address we have bound to.
341
    // In many cases, when we are fail we don't have this.
342
0
    let dummy_addr = SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0);
343
0
    if let Err(e) = send_response(Some(err), source, dummy_addr).await {
344
0
        warn!("failed to send socks error: {e}")
345
0
    }
346
0
}
347
348
/// send_success sends a success back to the SOCKS client.
349
0
pub async fn send_success(source: &mut TcpStream, local_addr: SocketAddr) -> Result<(), Error> {
350
0
    send_response(None, source, local_addr).await
351
0
}
352
353
0
async fn send_response(
354
0
    err: Option<&SocksError>,
355
0
    source: &mut TcpStream,
356
0
    local_addr: SocketAddr,
357
0
) -> Result<(), Error> {
358
    // https://www.rfc-editor.org/rfc/rfc1928#section-6
359
0
    let mut buf: Vec<u8> = Vec::with_capacity(10);
360
0
    buf.push(0x05); // version
361
    // Status
362
0
    buf.push(match err {
363
0
        None => 0,
364
0
        Some(SocksError::General(_)) => 1,
365
0
        Some(SocksError::NotAllowed(_)) => 2,
366
0
        Some(SocksError::NetworkUnreachable(_)) => 3,
367
0
        Some(SocksError::HostUnreachable(_)) => 4,
368
0
        Some(SocksError::ConnectionRefused(_)) => 5,
369
0
        Some(SocksError::CommandNotSupported(_)) => 7,
370
    });
371
0
    buf.push(0); // RSV
372
0
    match local_addr {
373
0
        SocketAddr::V4(addr_v4) => {
374
0
            buf.push(0x01); // IPv4 address type
375
0
            buf.extend_from_slice(&addr_v4.ip().octets());
376
0
        }
377
0
        SocketAddr::V6(addr_v6) => {
378
0
            buf.push(0x04); // IPv6 address type
379
0
            buf.extend_from_slice(&addr_v6.ip().octets());
380
0
        }
381
    }
382
    // Add port in network byte order (big-endian)
383
0
    buf.extend_from_slice(&local_addr.port().to_be_bytes());
384
0
    source.write_all(&buf).await?;
385
0
    Ok(())
386
0
}
387
388
/// OutboundProxyError maps outbound errors to SOCKS5 protocol errors
389
/// See https://datatracker.ietf.org/doc/html/rfc1928#section-6.
390
/// While the socks protocol only allows the int error, we record the full error
391
/// for our own logging purposes.
392
#[derive(thiserror::Error, Debug)]
393
#[allow(dead_code)]
394
pub enum SocksError {
395
    #[error("General: {0}")]
396
    General(Error),
397
    #[error("NotAllowed: {0}")]
398
    NotAllowed(Error),
399
    #[error("NetworkUnreachable: {0}")]
400
    NetworkUnreachable(Error),
401
    #[error("HostUnreachable: {0}")]
402
    HostUnreachable(Error),
403
    #[error("ConnectionRefused: {0}")]
404
    ConnectionRefused(Error),
405
    #[error("CommandNotSupported: {0}")]
406
    CommandNotSupported(Error),
407
}
408
409
impl SocksError {
410
0
    pub fn invalid_protocol(reason: String) -> SocksError {
411
0
        SocksError::CommandNotSupported(Error::Anyhow(anyhow::anyhow!(reason)))
412
0
    }
413
}
414
415
impl From<std::io::Error> for SocksError {
416
0
    fn from(value: std::io::Error) -> Self {
417
0
        SocksError::General(Error::Io(value))
418
0
    }
419
}