Coverage Report

Created: 2025-12-28 06:31

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/ztunnel/src/proxy.rs
Line
Count
Source
1
// Copyright Istio Authors
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use std::fmt::Debug;
16
use std::fs::File;
17
use std::io::Read;
18
use std::net::{IpAddr, SocketAddr};
19
use std::sync::Arc;
20
use std::time::Duration;
21
use std::{fmt, io};
22
23
use hickory_proto::ProtoError;
24
25
use crate::strng::Strng;
26
use rand::Rng;
27
use socket2::TcpKeepalive;
28
use tokio::net::{TcpListener, TcpSocket, TcpStream};
29
use tokio::time::timeout;
30
use tracing::{Instrument, debug, trace, warn};
31
32
use inbound::Inbound;
33
pub use metrics::*;
34
35
use crate::identity::{Identity, SecretManager};
36
37
use crate::dns::resolver::Resolver;
38
use crate::drain::DrainWatcher;
39
use crate::proxy::connection_manager::{ConnectionManager, PolicyWatcher};
40
use crate::proxy::inbound_passthrough::InboundPassthrough;
41
use crate::proxy::outbound::Outbound;
42
use crate::proxy::socks5::Socks5;
43
use crate::rbac::Connection;
44
use crate::state::service::{Service, ServiceDescription};
45
use crate::state::workload::address::Address;
46
use crate::state::workload::{GatewayAddress, Workload};
47
use crate::state::{DemandProxyState, WorkloadInfo};
48
use crate::{config, identity, socket, tls};
49
50
pub mod connection_manager;
51
pub mod inbound;
52
53
mod h2;
54
mod inbound_passthrough;
55
#[allow(non_camel_case_types)]
56
pub mod metrics;
57
mod outbound;
58
pub mod pool;
59
mod socks5;
60
pub mod util;
61
62
pub trait SocketFactory {
63
    fn new_tcp_v4(&self) -> std::io::Result<TcpSocket>;
64
65
    fn new_tcp_v6(&self) -> std::io::Result<TcpSocket>;
66
67
    fn tcp_bind(&self, addr: SocketAddr) -> std::io::Result<socket::Listener>;
68
69
    fn udp_bind(&self, addr: SocketAddr) -> std::io::Result<tokio::net::UdpSocket>;
70
71
    fn ipv6_enabled_localhost(&self) -> std::io::Result<bool>;
72
}
73
74
#[derive(Clone, Copy, Default)]
75
pub struct DefaultSocketFactory(pub config::SocketConfig);
76
77
impl SocketFactory for DefaultSocketFactory {
78
0
    fn new_tcp_v4(&self) -> std::io::Result<TcpSocket> {
79
0
        TcpSocket::new_v4().and_then(|s| {
80
0
            self.setup_socket(&s)?;
81
0
            Ok(s)
82
0
        })
83
0
    }
84
85
0
    fn new_tcp_v6(&self) -> std::io::Result<TcpSocket> {
86
0
        TcpSocket::new_v6().and_then(|s| {
87
0
            self.setup_socket(&s)?;
88
0
            Ok(s)
89
0
        })
90
0
    }
91
92
0
    fn tcp_bind(&self, addr: SocketAddr) -> std::io::Result<socket::Listener> {
93
0
        let std_sock = std::net::TcpListener::bind(addr)?;
94
0
        std_sock.set_nonblocking(true)?;
95
0
        TcpListener::from_std(std_sock).map(socket::Listener::new)
96
0
    }
97
98
0
    fn udp_bind(&self, addr: SocketAddr) -> std::io::Result<tokio::net::UdpSocket> {
99
0
        let std_sock = std::net::UdpSocket::bind(addr)?;
100
0
        std_sock.set_nonblocking(true)?;
101
0
        tokio::net::UdpSocket::from_std(std_sock)
102
0
    }
103
104
0
    fn ipv6_enabled_localhost(&self) -> io::Result<bool> {
105
0
        ipv6_enabled_on_localhost()
106
0
    }
107
}
108
109
impl DefaultSocketFactory {
110
0
    fn setup_socket(&self, s: &TcpSocket) -> io::Result<()> {
111
0
        s.set_nodelay(true)?;
112
0
        let cfg = self.0;
113
0
        if cfg.keepalive_enabled {
114
0
            let ka = TcpKeepalive::new()
115
0
                .with_time(cfg.keepalive_time)
116
0
                .with_retries(cfg.keepalive_retries)
117
0
                .with_interval(cfg.keepalive_interval);
118
0
            let res = socket2::SockRef::from(&s).set_tcp_keepalive(&ka);
119
0
            tracing::trace!("set keepalive: {:?}", res);
120
0
        }
121
0
        if cfg.user_timeout_enabled {
122
            // https://blog.cloudflare.com/when-tcp-sockets-refuse-to-die/
123
            // TCP_USER_TIMEOUT = TCP_KEEPIDLE + TCP_KEEPINTVL * TCP_KEEPCNT.
124
0
            let ut = cfg.keepalive_time + cfg.keepalive_retries * cfg.keepalive_interval;
125
0
            let res = socket2::SockRef::from(&s).set_tcp_user_timeout(Some(ut));
126
0
            tracing::trace!("set user timeout: {:?}", res);
127
0
        }
128
0
        Ok(())
129
0
    }
130
}
131
132
pub struct MarkSocketFactory {
133
    pub inner: DefaultSocketFactory,
134
    pub mark: u32,
135
}
136
137
impl SocketFactory for MarkSocketFactory {
138
0
    fn new_tcp_v4(&self) -> io::Result<TcpSocket> {
139
0
        self.inner.new_tcp_v4().and_then(|s| {
140
0
            socket::set_mark(&s, self.mark)?;
141
0
            Ok(s)
142
0
        })
143
0
    }
144
145
0
    fn new_tcp_v6(&self) -> io::Result<TcpSocket> {
146
0
        self.inner.new_tcp_v6().and_then(|s| {
147
0
            socket::set_mark(&s, self.mark)?;
148
0
            Ok(s)
149
0
        })
150
0
    }
151
152
0
    fn tcp_bind(&self, addr: SocketAddr) -> io::Result<socket::Listener> {
153
0
        self.inner.tcp_bind(addr)
154
0
    }
155
156
0
    fn udp_bind(&self, addr: SocketAddr) -> io::Result<tokio::net::UdpSocket> {
157
0
        self.inner.udp_bind(addr)
158
0
    }
159
160
0
    fn ipv6_enabled_localhost(&self) -> io::Result<bool> {
161
0
        self.inner.ipv6_enabled_localhost()
162
0
    }
163
}
164
165
pub struct Proxy {
166
    inbound: Inbound,
167
    inbound_passthrough: InboundPassthrough,
168
    outbound: Outbound,
169
    socks5: Option<Socks5>,
170
    policy_watcher: PolicyWatcher,
171
}
172
173
pub struct LocalWorkloadInformation {
174
    wi: Arc<WorkloadInfo>,
175
    state: DemandProxyState,
176
    // full_cert_manager gives access to the full SecretManager. This MUST only be given restricted
177
    // access to the appropriate certificates
178
    full_cert_manager: Arc<SecretManager>,
179
}
180
181
impl LocalWorkloadInformation {
182
0
    pub fn new(
183
0
        wi: Arc<WorkloadInfo>,
184
0
        state: DemandProxyState,
185
0
        cert_manager: Arc<SecretManager>,
186
0
    ) -> LocalWorkloadInformation {
187
0
        LocalWorkloadInformation {
188
0
            wi,
189
0
            state,
190
0
            full_cert_manager: cert_manager,
191
0
        }
192
0
    }
193
194
0
    pub async fn get_workload(&self) -> Result<Arc<Workload>, Error> {
195
0
        get_workload(&self.state, self.wi.clone()).await
196
0
    }
197
198
0
    pub async fn fetch_certificate(
199
0
        &self,
200
0
    ) -> Result<Arc<tls::WorkloadCertificate>, identity::Error> {
201
        // We don't know the trust domain until we get the workload from XDS, so fetch that
202
0
        let wl = self
203
0
            .get_workload()
204
0
            .await
205
0
            .map_err(|_| identity::Error::UnknownWorkload(self.workload_info()))?;
206
0
        let id = &Identity::Spiffe {
207
0
            trust_domain: wl.trust_domain.clone(),
208
0
            namespace: (&self.wi.namespace).into(),
209
0
            service_account: (&self.wi.service_account).into(),
210
0
        };
211
0
        self.full_cert_manager.fetch_certificate(id).await
212
0
    }
213
214
0
    pub fn workload_info(&self) -> Arc<WorkloadInfo> {
215
0
        self.wi.clone()
216
0
    }
217
218
0
    pub fn as_fetcher(&self) -> Arc<LocalWorkloadFetcher> {
219
0
        LocalWorkloadFetcher::new(self.wi.clone(), self.state.clone())
220
0
    }
221
}
222
223
/// LocalWorkloadFetcher is essentially LocalWorkloadInformation without CA access.
224
/// This is used to down-scope the LocalWorkloadInformation for components who should not have access
225
/// to certificates.
226
pub struct LocalWorkloadFetcher {
227
    wi: Arc<WorkloadInfo>,
228
    state: DemandProxyState,
229
}
230
231
impl LocalWorkloadFetcher {
232
0
    pub fn new(wi: Arc<WorkloadInfo>, state: DemandProxyState) -> Arc<Self> {
233
0
        Arc::new(LocalWorkloadFetcher { wi, state })
234
0
    }
235
0
    pub async fn get_workload(&self) -> Result<Arc<Workload>, Error> {
236
0
        get_workload(&self.state, self.wi.clone()).await
237
0
    }
238
}
239
240
0
async fn get_workload(
241
0
    state: &DemandProxyState,
242
0
    wi: Arc<WorkloadInfo>,
243
0
) -> Result<Arc<Workload>, Error> {
244
0
    state
245
0
        .wait_for_workload(&wi, Duration::from_secs(5))
246
0
        .await
247
0
        .ok_or_else(|| Error::UnknownSourceWorkload(wi.clone()))
248
0
}
249
250
#[derive(Clone)]
251
pub(super) struct ProxyInputs {
252
    cfg: Arc<config::Config>,
253
    connection_manager: ConnectionManager,
254
    pub state: DemandProxyState,
255
    metrics: Arc<Metrics>,
256
    socket_factory: Arc<dyn SocketFactory + Send + Sync>,
257
    local_workload_information: Arc<LocalWorkloadInformation>,
258
    resolver: Option<Arc<dyn Resolver + Send + Sync>>,
259
    // If true, inbound connections created with these inputs will not attempt to preserve the original source IP.
260
    pub disable_inbound_freebind: bool,
261
}
262
263
#[allow(clippy::too_many_arguments)]
264
impl ProxyInputs {
265
0
    pub fn new(
266
0
        cfg: Arc<config::Config>,
267
0
        connection_manager: ConnectionManager,
268
0
        state: DemandProxyState,
269
0
        metrics: Arc<Metrics>,
270
0
        socket_factory: Arc<dyn SocketFactory + Send + Sync>,
271
0
        resolver: Option<Arc<dyn Resolver + Send + Sync>>,
272
0
        local_workload_information: Arc<LocalWorkloadInformation>,
273
0
        disable_inbound_freebind: bool,
274
0
    ) -> Arc<Self> {
275
0
        Arc::new(Self {
276
0
            cfg,
277
0
            state,
278
0
            metrics,
279
0
            connection_manager,
280
0
            socket_factory,
281
0
            local_workload_information,
282
0
            resolver,
283
0
            disable_inbound_freebind,
284
0
        })
285
0
    }
286
}
287
288
impl Proxy {
289
    #[allow(unused_mut)]
290
0
    pub(super) async fn from_inputs(
291
0
        mut pi: Arc<ProxyInputs>,
292
0
        drain: DrainWatcher,
293
0
    ) -> Result<Self, Error> {
294
        // We setup all the listeners first so we can capture any errors that should block startup
295
0
        let inbound = Inbound::new(pi.clone(), drain.clone()).await?;
296
297
        // This exists for `direct` integ tests, no other reason
298
        #[cfg(any(test, feature = "testing"))]
299
        if pi.cfg.fake_self_inbound {
300
            warn!("TEST FAKE - overriding inbound address for test");
301
            let mut old_cfg = (*pi.cfg).clone();
302
            old_cfg.inbound_addr = inbound.address();
303
            let mut new_pi = (*pi).clone();
304
            new_pi.cfg = Arc::new(old_cfg);
305
            pi = Arc::new(new_pi);
306
            warn!("TEST FAKE: new address is {:?}", pi.cfg.inbound_addr);
307
        }
308
309
0
        let inbound_passthrough = InboundPassthrough::new(pi.clone(), drain.clone()).await?;
310
0
        let outbound = Outbound::new(pi.clone(), drain.clone()).await?;
311
0
        let socks5 = if pi.cfg.socks5_addr.is_some() {
312
0
            let socks5 = Socks5::new(pi.clone(), drain.clone()).await?;
313
0
            Some(socks5)
314
        } else {
315
0
            None
316
        };
317
0
        let policy_watcher =
318
0
            PolicyWatcher::new(pi.state.clone(), drain, pi.connection_manager.clone());
319
320
0
        Ok(Proxy {
321
0
            inbound,
322
0
            inbound_passthrough,
323
0
            outbound,
324
0
            socks5,
325
0
            policy_watcher,
326
0
        })
327
0
    }
328
329
0
    pub async fn run(self) {
330
0
        let mut tasks = vec![
331
0
            tokio::spawn(self.inbound_passthrough.run().in_current_span()),
332
0
            tokio::spawn(self.policy_watcher.run().in_current_span()),
333
0
            tokio::spawn(self.inbound.run().in_current_span()),
334
0
            tokio::spawn(self.outbound.run().in_current_span()),
335
        ];
336
337
0
        if let Some(socks5) = self.socks5 {
338
0
            tasks.push(tokio::spawn(socks5.run().in_current_span()));
339
0
        };
340
341
0
        futures::future::join_all(tasks).await;
342
0
    }
343
344
0
    pub fn addresses(&self) -> Addresses {
345
        Addresses {
346
0
            outbound: self.outbound.address(),
347
0
            inbound: self.inbound.address(),
348
0
            socks5: self.socks5.as_ref().map(|s| s.address()),
349
        }
350
0
    }
351
}
352
353
/// Addresses the proxy's listeners are bound to.
#[derive(Copy, Clone)]
pub struct Addresses {
    /// Address of the outbound listener.
    pub outbound: SocketAddr,
    /// Address of the inbound listener.
    pub inbound: SocketAddr,
    /// Address of the SOCKS5 listener; `None` when no SOCKS5 listener exists.
    pub socks5: Option<SocketAddr>,
}
359
360
#[derive(Debug, PartialEq, Eq)]
361
pub enum AuthorizationRejectionError {
362
    NoWorkload,
363
    WorkloadMismatch,
364
    ExplicitlyDenied(Strng, Strng),
365
    NotAllowed,
366
}
367
impl fmt::Display for AuthorizationRejectionError {
368
0
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
369
0
        match self {
370
0
            Self::NoWorkload => write!(fmt, "workload not found"),
371
0
            Self::WorkloadMismatch => write!(fmt, "workload mismatch"),
372
0
            Self::ExplicitlyDenied(a, b) => write!(fmt, "explicitly denied by: {a}/{b}"),
373
0
            Self::NotAllowed => write!(fmt, "allow policies exist, but none allowed"),
374
        }
375
0
    }
376
}
377
378
#[derive(thiserror::Error, Debug)]
379
pub enum Error {
380
    #[error("failed to bind to address {0}: {1}")]
381
    Bind(SocketAddr, io::Error),
382
383
    #[error("io error: {0}")]
384
    Io(#[from] io::Error),
385
386
    #[error("while closing connection: {0}")]
387
    ShutdownError(Box<Error>),
388
389
    #[error("connection timed out, maybe a NetworkPolicy is blocking HBONE port 15008: {0}")]
390
    MaybeHBONENetworkPolicyError(io::Error),
391
392
    #[error("destination disconnected before all data was written")]
393
    BackendDisconnected,
394
    #[error("receive: {0}")]
395
    ReceiveError(Box<Error>),
396
397
    #[error("client disconnected before all data was written")]
398
    ClientDisconnected,
399
    #[error("send: {0}")]
400
    SendError(Box<Error>),
401
402
    #[error("connection failed: {0}")]
403
    ConnectionFailed(io::Error),
404
405
    #[error("connection tracking failed")]
406
    ConnectionTrackingFailed,
407
408
    #[error("connection closed due to policy change")]
409
    AuthorizationPolicyLateRejection,
410
411
    #[error("connection closed due to policy rejection: {0}")]
412
    AuthorizationPolicyRejection(AuthorizationRejectionError),
413
414
    #[error("pool draining")]
415
    WorkloadHBONEPoolDraining,
416
417
    #[error("{0}")]
418
    Generic(Box<dyn std::error::Error + Send + Sync>),
419
420
    #[error("{0}")]
421
    Anyhow(anyhow::Error),
422
423
    #[error("http2 handshake failed: {0}")]
424
    Http2Handshake(#[source] ::h2::Error),
425
426
    #[error("h2 failed: {0}")]
427
    H2(#[from] ::h2::Error),
428
429
    #[error("http status: {0}")]
430
    HttpStatus(http::StatusCode),
431
432
    #[error("expected method CONNECT, got {0}")]
433
    NonConnectMethod(String),
434
435
    #[error("invalid CONNECT address {0}")]
436
    ConnectAddress(String),
437
438
    #[error("tls error: {0}")]
439
    Tls(#[from] tls::Error),
440
441
    #[error("identity error: {0}")]
442
    Identity(#[from] identity::Error),
443
444
    #[error("failed to fetch information about local workload: {0}")]
445
    UnknownSourceWorkload(Arc<WorkloadInfo>),
446
447
    #[error("unknown waypoint: {0}")]
448
    UnknownWaypoint(String),
449
450
    #[error("unknown network gateway: {0}")]
451
    UnknownNetworkGateway(String),
452
453
    #[error("no service or workload for hostname: {0}")]
454
    NoHostname(String),
455
456
    #[error("no endpoints for workload: {0}")]
457
    NoWorkloadEndpoints(String),
458
459
    #[error("no valid authority pseudo header: {0}")]
460
    NoValidAuthority(String),
461
462
    #[error("no valid service port in authority header: {0}")]
463
    NoValidServicePort(String, u16),
464
465
    #[error("no valid target port for workload: {0}")]
466
    NoValidTargetPort(String, u16),
467
468
    #[error("no valid routing destination for workload: {0}")]
469
    NoValidDestination(Box<Workload>),
470
471
    #[error("no healthy upstream: {0}")]
472
    NoHealthyUpstream(SocketAddr),
473
474
    #[error("no ip addresses were resolved for workload: {0}")]
475
    NoResolvedAddresses(String),
476
477
    #[error("requested service {0}:{1} found, but cannot resolve port")]
478
    NoPortForServices(String, u16),
479
480
    #[error("requested service {0} found, but has no IP addresses")]
481
    NoIPForService(String),
482
483
    #[error("no service for target address: {0}")]
484
    NoService(SocketAddr),
485
486
    #[error(
487
        "ip addresses were resolved for workload {0}, but valid dns response had no A/AAAA records"
488
    )]
489
    EmptyResolvedAddresses(String),
490
491
    #[error("attempted recursive call to ourselves")]
492
    SelfCall,
493
494
    #[error("no gateway address: {0}")]
495
    NoGatewayAddress(Box<Workload>),
496
497
    #[error("unsupported feature: {0}")]
498
    UnsupportedFeature(String),
499
500
    #[error("ip mismatch: {0} != {1}")]
501
    IPMismatch(IpAddr, IpAddr),
502
503
    #[error("connection failed to drain within the timeout")]
504
    DrainTimeOut,
505
    #[error("connection closed due to connection drain")]
506
    ClosedFromDrain,
507
508
    #[error("dns: {0}")]
509
    Dns(#[from] ProtoError),
510
    #[error("dns lookup: {0}")]
511
    DnsLookup(#[from] hickory_server::authority::LookupError),
512
    #[error("dns response had no valid IP addresses")]
513
    DnsEmpty,
514
}
515
516
// Custom TLV for proxy protocol for the identity of the source
517
const PROXY_PROTOCOL_AUTHORITY_TLV: u8 = 0xD0;
518
519
0
pub async fn write_proxy_protocol<T>(
520
0
    stream: &mut TcpStream,
521
0
    addresses: T,
522
0
    src_id: Option<Identity>,
523
0
) -> io::Result<()>
524
0
where
525
0
    T: Into<ppp::v2::Addresses> + std::fmt::Debug,
526
0
{
527
    use ppp::v2::{Builder, Command, Protocol, Version};
528
    use tokio::io::AsyncWriteExt;
529
530
    // When the hbone_addr populated from the authority header contains a svc hostname, the address included
531
    // with respect to the hbone_addr is the SocketAddr <dst svc IP>:<original dst port>.
532
    // This is done since addresses doesn't support hostnames.
533
    // See ref https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt
534
0
    debug!("writing proxy protocol addresses: {:?}", addresses);
535
0
    let mut builder =
536
0
        Builder::with_addresses(Version::Two | Command::Proxy, Protocol::Stream, addresses);
537
538
0
    if let Some(id) = src_id {
539
0
        builder = builder.write_tlv(PROXY_PROTOCOL_AUTHORITY_TLV, id.to_string().as_bytes())?;
540
0
    }
541
542
0
    let header = builder.build()?;
543
0
    stream.write_all(&header).await
544
0
}
545
546
/// Represents a traceparent, as defined by https://www.w3.org/TR/trace-context/
547
#[derive(Eq, PartialEq)]
548
pub struct TraceParent {
549
    version: u8,
550
    trace_id: u128,
551
    parent_id: u64,
552
    flags: u8,
553
}
554
555
pub const BAGGAGE_HEADER: &str = "baggage";
556
pub const TRACEPARENT_HEADER: &str = "traceparent";
557
558
impl TraceParent {
559
0
    pub fn header(&self) -> hyper::header::HeaderValue {
560
0
        hyper::header::HeaderValue::from_bytes(format!("{self:?}").as_bytes()).unwrap()
561
0
    }
562
}
563
impl TraceParent {
564
0
    fn new() -> Self {
565
0
        let mut rng = rand::rng();
566
0
        Self {
567
0
            version: 0,
568
0
            trace_id: rng.random(),
569
0
            parent_id: rng.random(),
570
0
            flags: 0,
571
0
        }
572
0
    }
573
}
574
575
impl fmt::Debug for TraceParent {
576
0
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
577
0
        write!(
578
0
            f,
579
0
            "{:02x}-{:032x}-{:016x}-{:02x}",
580
            self.version, self.trace_id, self.parent_id, self.flags
581
        )
582
0
    }
583
}
584
585
impl fmt::Display for TraceParent {
586
0
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
587
0
        write!(f, "{:032x}", self.trace_id,)
588
0
    }
589
}
590
591
impl TryFrom<&str> for TraceParent {
592
    type Error = anyhow::Error;
593
594
0
    fn try_from(value: &str) -> Result<Self, Self::Error> {
595
0
        if value.len() != 55 {
596
0
            anyhow::bail!("traceparent malformed length was {}", value.len())
597
0
        }
598
599
0
        let segs: Vec<&str> = value.split('-').collect();
600
601
        Ok(Self {
602
0
            version: u8::from_str_radix(segs[0], 16)?,
603
0
            trace_id: u128::from_str_radix(segs[1], 16)?,
604
0
            parent_id: u64::from_str_radix(segs[2], 16)?,
605
0
            flags: u8::from_str_radix(segs[3], 16)?,
606
        })
607
0
    }
608
}
609
610
0
pub(super) fn maybe_set_transparent(
611
0
    pi: &ProxyInputs,
612
0
    listener: &socket::Listener,
613
0
) -> Result<bool, Error> {
614
0
    Ok(match pi.cfg.require_original_source {
615
        Some(true) => {
616
            // Explicitly enabled. Return error if we cannot set it.
617
0
            listener.set_transparent()?;
618
0
            true
619
        }
620
        Some(false) => {
621
            // Explicitly disabled, don't even attempt to set it.
622
0
            false
623
        }
624
        None => {
625
            // Best effort
626
0
            listener.set_transparent().is_ok()
627
        }
628
    })
629
0
}
630
631
0
pub fn get_original_src_from_stream(stream: &TcpStream) -> Option<IpAddr> {
632
0
    stream
633
0
        .peer_addr()
634
0
        .map_or(None, |sa| Some(socket::to_canonical(sa).ip()))
635
0
}
636
637
const CONNECTION_TIMEOUT: Duration = Duration::from_secs(10);
638
639
0
pub async fn freebind_connect(
640
0
    local: Option<IpAddr>,
641
0
    addr: SocketAddr,
642
0
    socket_factory: &(dyn SocketFactory + Send + Sync),
643
0
) -> io::Result<TcpStream> {
644
0
    async fn connect(
645
0
        local: Option<IpAddr>,
646
0
        addr: SocketAddr,
647
0
        socket_factory: &(dyn SocketFactory + Send + Sync),
648
0
    ) -> io::Result<TcpStream> {
649
0
        let create_socket = |is_ipv4: bool| {
650
0
            if is_ipv4 {
651
0
                socket_factory.new_tcp_v4()
652
            } else {
653
0
                socket_factory.new_tcp_v6()
654
            }
655
0
        };
656
657
0
        match local {
658
            None => {
659
0
                let socket = create_socket(addr.is_ipv4())?;
660
0
                trace!(dest=%addr, "no local address, connect directly");
661
0
                Ok(socket.connect(addr).await?)
662
            }
663
            // TODO: Need figure out how to handle case of loadbalancing to itself.
664
            //       We use ztunnel addr instead, otherwise app side will be confused.
665
0
            Some(src) if src == socket::to_canonical(addr).ip() => {
666
0
                let socket = create_socket(addr.is_ipv4())?;
667
0
                trace!(%src, dest=%addr, "dest and source are the same, connect directly");
668
0
                Ok(socket.connect(addr).await?)
669
            }
670
0
            Some(src) => {
671
0
                let socket = create_socket(src.is_ipv4())?;
672
0
                let local_addr = SocketAddr::new(src, 0);
673
0
                match socket::set_freebind_and_transparent(&socket) {
674
0
                    Err(err) => warn!("failed to set freebind: {:?}", err),
675
                    _ => {
676
0
                        if let Err(err) = socket.bind(local_addr) {
677
0
                            warn!("failed to bind local addr: {:?}", err)
678
0
                        }
679
                    }
680
                };
681
0
                trace!(%src, dest=%addr, "connect with source IP");
682
0
                Ok(socket.connect(addr).await?)
683
            }
684
        }
685
0
    }
686
    // Wrap the entire connect function in a timeout
687
0
    timeout(CONNECTION_TIMEOUT, connect(local, addr, socket_factory))
688
0
        .await
689
0
        .map_err(|e| io::Error::new(io::ErrorKind::TimedOut, e))?
690
0
}
691
692
// guess_inbound_service selects an upstream service for inbound metrics.
693
// There may be many services for a single workload. We find the the first one with an applicable port
694
// as a best guess.
695
0
pub fn guess_inbound_service(
696
0
    conn: &Connection,
697
0
    for_host_header: &Option<String>,
698
0
    upstream_service: Vec<Arc<Service>>,
699
0
    dest: &Workload,
700
0
) -> Option<ServiceDescription> {
701
    // First, if the client told us what Service they were reaching, look for that
702
    // Note: the set of Services we look for is bounded, so we won't blindly trust bogus info.
703
0
    if let Some(found) = upstream_service
704
0
        .iter()
705
0
        .find(|s| for_host_header.as_deref() == Some(s.hostname.as_ref()))
706
0
        .map(|s| ServiceDescription::from(s.as_ref()))
707
    {
708
0
        return Some(found);
709
0
    }
710
0
    let dport = conn.dst.port();
711
0
    upstream_service
712
0
        .iter()
713
0
        .find(|s| {
714
0
            for (sport, tport) in s.ports.iter() {
715
0
                if tport == &dport {
716
                    // TargetPort directly matches
717
0
                    return true;
718
0
                }
719
                // The service itself didn't have a explicit TargetPort match, but an endpoint might.
720
                // This happens when there is a named port (in Kubernetes, anyways).
721
0
                if s.endpoints.get(&dest.uid).and_then(|e| e.port.get(sport)) == Some(&dport) {
722
                    // Named port matched
723
0
                    return true;
724
0
                }
725
                // no match
726
            }
727
0
            false
728
0
        })
729
0
        .map(|s| ServiceDescription::from(s.as_ref()))
730
0
}
731
732
// Checks that the source identiy and address match the upstream's waypoint
733
0
async fn check_from_waypoint(
734
0
    state: &DemandProxyState,
735
0
    upstream: &Workload,
736
0
    src_identity: Option<&Identity>,
737
0
    src_ip: &IpAddr,
738
0
) -> bool {
739
0
    let is_waypoint = |wl: &Workload| {
740
0
        Some(wl.identity()).as_ref() == src_identity && wl.workload_ips.contains(src_ip)
741
0
    };
742
0
    check_gateway_address(state, upstream.waypoint.as_ref(), is_waypoint).await
743
0
}
744
745
// Check if the source's identity matches any workloads that make up the given gateway
746
// TODO: This can be made more accurate by also checking addresses.
747
0
async fn check_gateway_address<F>(
748
0
    state: &DemandProxyState,
749
0
    gateway_address: Option<&GatewayAddress>,
750
0
    predicate: F,
751
0
) -> bool
752
0
where
753
0
    F: Fn(&Workload) -> bool,
754
0
{
755
0
    let Some(gateway_address) = gateway_address else {
756
0
        return false;
757
    };
758
759
0
    match state.fetch_destination(&gateway_address.destination).await {
760
0
        Some(Address::Workload(wl)) => return predicate(wl.as_ref()),
761
0
        Some(Address::Service(svc)) => {
762
0
            for ep in svc.endpoints.iter() {
763
                // fetch workloads by workload UID since we may not have an IP for an endpoint (e.g., endpoint is just a hostname)
764
0
                let wl = state.fetch_workload_by_uid(&ep.workload_uid).await;
765
0
                if wl.as_ref().is_some_and(|wl| predicate(wl.as_ref())) {
766
0
                    return true;
767
0
                }
768
            }
769
        }
770
0
        None => {}
771
    };
772
773
0
    false
774
0
}
775
776
const IPV6_DISABLED_LO: &str = "/proc/sys/net/ipv6/conf/lo/disable_ipv6";

/// Reads the file at `key` (a sysctl path) and returns its contents with surrounding
/// whitespace trimmed.
fn read_sysctl(key: &str) -> io::Result<String> {
    let mut contents = String::new();
    File::open(key)?.read_to_string(&mut contents)?;
    Ok(contents.trim().to_owned())
}

/// Reports whether IPv6 is enabled on the loopback interface.
///
/// Reads `IPV6_DISABLED_LO`; any value other than `"1"` counts as enabled. I/O errors (e.g. the
/// file not existing) are returned unchanged.
pub fn ipv6_enabled_on_localhost() -> io::Result<bool> {
    let disabled = read_sysctl(IPV6_DISABLED_LO)?;
    Ok(disabled != "1")
}
788
789
1.23k
/// Extracts the `host` parameter from a `Forwarded` header value (RFC 7239).
///
/// * Non-ASCII input is rejected (`None`).
/// * The header may contain several comma-separated forwarded-elements, each made of
///   `;`-separated `name=value` pairs. The first `host` pair found, scanning left to right, is
///   used.
/// * Parameter names are matched case-insensitively, per RFC 7239 §4.
/// * A surrounding pair of double quotes is stripped. An empty value yields `None`.
pub fn parse_forwarded_host(input: &str) -> Option<String> {
    if !input.is_ascii() {
        return None;
    }
    input
        .split(',')
        .flat_map(|element| element.split(';'))
        .filter_map(|pair| pair.trim().split_once('='))
        .find(|(name, _)| name.eq_ignore_ascii_case("host"))
        .map(|(_, value)| {
            let value = value.strip_prefix('"').unwrap_or(value);
            let value = value.strip_suffix('"').unwrap_or(value);
            value.to_string()
        })
        .filter(|host| !host.is_empty())
}
806
807
#[derive(Debug, Clone, PartialEq)]
808
pub enum HboneAddress {
809
    SocketAddr(SocketAddr),
810
    SvcHostname(Strng, u16),
811
}
812
813
impl HboneAddress {
814
0
    pub fn port(&self) -> u16 {
815
0
        match self {
816
0
            HboneAddress::SocketAddr(s) => s.port(),
817
0
            HboneAddress::SvcHostname(_, p) => *p,
818
        }
819
0
    }
820
821
0
    pub fn ip(&self) -> Option<IpAddr> {
822
0
        match self {
823
0
            HboneAddress::SocketAddr(s) => Some(s.ip()),
824
0
            HboneAddress::SvcHostname(_, _) => None,
825
        }
826
0
    }
827
828
0
    pub fn svc_hostname(&self) -> Option<Strng> {
829
0
        match self {
830
0
            HboneAddress::SocketAddr(_) => None,
831
0
            HboneAddress::SvcHostname(s, _) => Some(s.into()),
832
        }
833
0
    }
834
835
0
    pub fn hostname_addr(&self) -> Option<Strng> {
836
0
        match self {
837
0
            HboneAddress::SocketAddr(_) => None,
838
0
            HboneAddress::SvcHostname(_, _) => Some(Strng::from(self.to_string())),
839
        }
840
0
    }
841
}
842
843
impl std::fmt::Display for HboneAddress {
844
0
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
845
0
        match self {
846
0
            HboneAddress::SocketAddr(addr) => write!(f, "{addr}"),
847
0
            HboneAddress::SvcHostname(host, port) => write!(f, "{host}:{port}"),
848
        }
849
0
    }
850
}
851
852
impl From<SocketAddr> for HboneAddress {
853
0
    fn from(socket_addr: SocketAddr) -> Self {
854
0
        HboneAddress::SocketAddr(socket_addr)
855
0
    }
856
}
857
858
impl From<(Strng, u16)> for HboneAddress {
859
0
    fn from(svc_hostname: (Strng, u16)) -> Self {
860
0
        HboneAddress::SvcHostname(svc_hostname.0, svc_hostname.1)
861
0
    }
862
}
863
864
impl TryFrom<&http::Uri> for HboneAddress {
865
    type Error = Error;
866
867
0
    fn try_from(value: &http::Uri) -> Result<Self, Self::Error> {
868
0
        match value.to_string().parse::<SocketAddr>() {
869
0
            Ok(addr) => Ok(HboneAddress::SocketAddr(addr)),
870
            Err(_) => {
871
0
                let hbone_host = value
872
0
                    .host()
873
0
                    .ok_or_else(|| Error::NoValidAuthority(value.to_string()))?;
874
0
                let hbone_port = value
875
0
                    .port_u16()
876
0
                    .ok_or_else(|| Error::NoValidAuthority(value.to_string()))?;
877
0
                Ok(HboneAddress::SvcHostname(hbone_host.into(), hbone_port))
878
            }
879
        }
880
0
    }
881
}
882
883
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_forwarded_host() {
        // (header, expected host)
        let cases: &[(&str, Option<&str>)] = &[
            (
                "by=identifier;for=identifier;host=example.com;proto=https",
                Some("example.com"),
            ),
            (
                "by=identifier;for=identifier;host=\"example.com\";proto=https",
                Some("example.com"),
            ),
            ("by=identifier;for=identifier;proto=https", None),
            ("by=identifier;for=identifier;host=;proto=https", None),
            (r#"for=for;by=by;host=host;proto="pröto""#, None),
        ];
        for &(header, expected) in cases {
            assert_eq!(
                parse_forwarded_host(header),
                expected.map(str::to_string),
                "header: {header}"
            );
        }
    }
}