Coverage Report

Created: 2025-12-28 06:31

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/ztunnel/src/socket.rs
Line
Count
Source
1
// Copyright Istio Authors
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use std::io::Error;
16
use std::net::SocketAddr;
17
18
use tokio::io;
19
20
use tokio::net::TcpSocket;
21
use tokio::net::{TcpListener, TcpStream};
22
23
use crate::config::SocketConfig;
24
use socket2::{SockRef, TcpKeepalive};
25
26
#[cfg(target_os = "linux")]
27
use {socket2::Domain, std::io::ErrorKind, tracing::warn};
28
29
#[cfg(target_os = "linux")]
30
0
pub fn set_freebind_and_transparent(socket: &TcpSocket) -> io::Result<()> {
31
0
    let socket = SockRef::from(socket);
32
0
    match socket.domain()? {
33
        Domain::IPV4 => {
34
0
            socket.set_ip_transparent_v4(true)?;
35
0
            socket.set_freebind_v4(true)?;
36
        }
37
        Domain::IPV6 => {
38
0
            linux::set_ipv6_transparent(&socket)?;
39
0
            socket.set_freebind_v6(true)?
40
        }
41
0
        _ => return Err(Error::new(ErrorKind::Unsupported, "unsupported domain")),
42
    };
43
0
    Ok(())
44
0
}
45
46
0
/// Normalizes `addr` so an IPv4-mapped IPv6 address (`[::ffff:a.b.c.d]:port`) becomes the
/// plain IPv4 form (`a.b.c.d:port`); all other addresses keep their IP unchanged.
///
/// The port is preserved. The result is rebuilt from IP and port only, so IPv6 flow info
/// and scope id are not carried over.
pub fn to_canonical(addr: SocketAddr) -> SocketAddr {
    SocketAddr::new(addr.ip().to_canonical(), addr.port())
}
51
52
0
pub fn orig_dst_addr_or_default(stream: &tokio::net::TcpStream) -> SocketAddr {
53
0
    to_canonical(match orig_dst_addr(stream) {
54
0
        Ok(addr) => addr,
55
0
        _ => stream.local_addr().expect("must get local address"),
56
    })
57
0
}
58
59
#[cfg(target_os = "linux")]
60
0
fn orig_dst_addr(stream: &tokio::net::TcpStream) -> io::Result<SocketAddr> {
61
0
    let sock = SockRef::from(stream);
62
    // Dual-stack IPv4/IPv6 sockets require us to check both options.
63
0
    match linux::original_dst(&sock) {
64
0
        Ok(addr) => Ok(addr.as_socket().expect("failed to convert to SocketAddr")),
65
0
        Err(e4) => match linux::original_dst_ipv6(&sock) {
66
0
            Ok(addr) => Ok(addr.as_socket().expect("failed to convert to SocketAddr")),
67
0
            Err(e6) => {
68
0
                if !sock.ip_transparent_v4().unwrap_or(false) {
69
                    // In TPROXY mode, this is normal, so don't bother logging
70
0
                    warn!(
71
0
                        peer=?stream.peer_addr().unwrap(),
72
0
                        local=?stream.local_addr().unwrap(),
73
0
                        "failed to read SO_ORIGINAL_DST: {e4:?}, {e6:?}"
74
                    );
75
0
                }
76
0
                Err(e6)
77
            }
78
        },
79
    }
80
0
}
81
82
#[cfg(not(target_os = "linux"))]
/// Fallback for platforms without `SO_ORIGINAL_DST`: always returns an `ErrorKind::Other` error.
fn orig_dst_addr(_: &tokio::net::TcpStream) -> io::Result<SocketAddr> {
    Err(io::Error::other(
        "SO_ORIGINAL_DST not supported on this operating system",
    ))
}
89
90
#[cfg(not(target_os = "linux"))]
/// Fallback for platforms without `IP_TRANSPARENT`/`IP_FREEBIND`: always returns an
/// `ErrorKind::Other` error.
pub fn set_freebind_and_transparent(_: &TcpSocket) -> io::Result<()> {
    Err(io::Error::other(
        "IP_TRANSPARENT and IP_FREEBIND are not supported on this operating system",
    ))
}
97
98
#[cfg(target_os = "linux")]
99
0
pub fn set_mark<S: std::os::unix::io::AsFd>(socket: &S, mark: u32) -> io::Result<()> {
100
0
    let socket = SockRef::from(socket);
101
0
    socket.set_mark(mark)
102
0
}
Unexecuted instantiation: ztunnel::socket::set_mark::<std::net::tcp::TcpListener>
Unexecuted instantiation: ztunnel::socket::set_mark::<std::net::udp::UdpSocket>
Unexecuted instantiation: ztunnel::socket::set_mark::<tokio::net::tcp::socket::TcpSocket>
Unexecuted instantiation: ztunnel::socket::set_mark::<std::os::fd::owned::OwnedFd>
103
104
#[cfg(not(target_os = "linux"))]
/// Fallback for platforms without `SO_MARK`: always returns an `ErrorKind::Other` error.
///
/// Generic over the socket type so it accepts the same arguments as the Linux version, which
/// is called with listeners, UDP sockets, `TcpSocket` and raw owned fds. Previously only
/// `&TcpSocket` was accepted, so those call sites did not compile off Linux.
pub fn set_mark<S: ?Sized>(_socket: &S, _mark: u32) -> io::Result<()> {
    Err(io::Error::new(
        io::ErrorKind::Other,
        "SO_MARK not supported on this operating system",
    ))
}
111
112
#[cfg(target_os = "linux")]
113
#[allow(unsafe_code)]
114
mod linux {
115
    use std::os::unix::io::AsRawFd;
116
117
    use socket2::{SockAddr, SockRef};
118
    use tokio::io;
119
120
0
    pub fn set_ipv6_transparent(sock: &SockRef) -> io::Result<()> {
121
        unsafe {
122
0
            let optval: libc::c_int = 1;
123
0
            let ret = libc::setsockopt(
124
0
                sock.as_raw_fd(),
125
                libc::IPPROTO_IPV6,
126
                libc::IPV6_TRANSPARENT,
127
0
                &optval as *const _ as *const libc::c_void,
128
0
                std::mem::size_of_val(&optval) as libc::socklen_t,
129
            );
130
0
            if ret != 0 {
131
0
                return Err(io::Error::last_os_error());
132
0
            }
133
        }
134
0
        Ok(())
135
0
    }
136
137
0
    pub fn original_dst(sock: &SockRef) -> io::Result<SockAddr> {
138
0
        sock.original_dst_v4()
139
0
    }
140
141
0
    pub fn original_dst_ipv6(sock: &SockRef) -> io::Result<SockAddr> {
142
0
        sock.original_dst_v6()
143
0
    }
144
}
145
146
/// Listener is a wrapper For TCPListener with sane defaults. Notably, setting NODELAY
147
/// You can also pass it additional socket options to set on accepted connections.
148
pub struct Listener {
149
    listener: TcpListener,
150
    cfg: Option<SocketConfig>,
151
}
152
153
impl Listener {
154
0
    pub fn new(l: TcpListener) -> Self {
155
0
        Listener {
156
0
            listener: l,
157
0
            cfg: None,
158
0
        }
159
0
    }
160
0
    pub fn local_addr(&self) -> SocketAddr {
161
0
        self.listener.local_addr().expect("local_addr is available")
162
0
    }
163
0
    pub fn inner(self) -> TcpListener {
164
0
        self.listener
165
0
    }
166
0
    pub fn set_socket_options(&mut self, cfg: Option<SocketConfig>) {
167
0
        self.cfg = cfg;
168
0
    }
169
0
    pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
170
0
        let (stream, remote) = self.listener.accept().await?;
171
0
        stream.set_nodelay(true)?;
172
0
        if let Some(cfg) = self.cfg {
173
0
            if cfg.keepalive_enabled {
174
0
                let ka = TcpKeepalive::new()
175
0
                    .with_time(cfg.keepalive_time)
176
0
                    .with_retries(cfg.keepalive_retries)
177
0
                    .with_interval(cfg.keepalive_interval);
178
0
                let res = SockRef::from(&stream).set_tcp_keepalive(&ka);
179
0
                tracing::trace!("set keepalive: {:?}", res);
180
0
            }
181
0
            if cfg.user_timeout_enabled {
182
0
                let ut = cfg.keepalive_time + cfg.keepalive_retries * cfg.keepalive_interval;
183
0
                let res = SockRef::from(&stream).set_tcp_user_timeout(Some(ut));
184
0
                tracing::trace!("set user timeout: {:?}", res);
185
0
            }
186
0
        }
187
0
        Ok((stream, remote))
188
0
    }
189
}
190
191
#[cfg(target_os = "linux")]
192
impl Listener {
193
0
    pub fn set_transparent(&self) -> io::Result<()> {
194
0
        SockRef::from(&self.listener).set_ip_transparent_v4(true)
195
0
    }
196
}
197
198
#[cfg(not(target_os = "linux"))]
impl Listener {
    /// Fallback for platforms without `IP_TRANSPARENT`: always returns an `ErrorKind::Other`
    /// error.
    pub fn set_transparent(&self) -> io::Result<()> {
        Err(io::Error::other("IP_TRANSPARENT not supported on this operating system"))
    }
}