Coverage Report

Created: 2025-12-28 06:31

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/ztunnel/src/state.rs
Line
Count
Source
1
// Copyright Istio Authors
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use crate::identity::{Identity, SecretManager};
16
use crate::proxy::{Error, OnDemandDnsLabels};
17
use crate::rbac::Authorization;
18
use crate::state::policy::PolicyStore;
19
use crate::state::service::{
20
    Endpoint, IpFamily, LoadBalancerMode, LoadBalancerScopes, ServiceStore,
21
};
22
use crate::state::service::{Service, ServiceDescription};
23
use crate::state::workload::{
24
    GatewayAddress, NamespacedHostname, NetworkAddress, Workload, WorkloadStore, address::Address,
25
    gatewayaddress::Destination, network_addr,
26
};
27
use crate::strng::Strng;
28
use crate::tls;
29
use crate::xds::istio::security::Authorization as XdsAuthorization;
30
use crate::xds::istio::workload::Address as XdsAddress;
31
use crate::xds::{AdsClient, Demander, LocalClient, ProxyStateUpdater};
32
use crate::{cert_fetcher, config, rbac, xds};
33
use crate::{proxy, strng};
34
use educe::Educe;
35
use futures_util::FutureExt;
36
use hickory_resolver::TokioResolver;
37
use hickory_resolver::config::*;
38
use hickory_resolver::name_server::TokioConnectionProvider;
39
use itertools::Itertools;
40
use rand::prelude::IteratorRandom;
41
use rand::seq::IndexedRandom;
42
use serde::Serializer;
43
use std::collections::HashMap;
44
use std::convert::Into;
45
use std::default::Default;
46
use std::fmt;
47
use std::net::{IpAddr, SocketAddr};
48
use std::str::FromStr;
49
use std::sync::{Arc, RwLock, RwLockReadGuard};
50
use std::time::Duration;
51
use tracing::{debug, trace, warn};
52
53
use self::workload::ApplicationTunnel;
54
55
pub mod policy;
56
pub mod service;
57
pub mod workload;
58
59
#[derive(Debug, Eq, PartialEq, Clone)]
60
pub struct Upstream {
61
    /// Workload is the workload we are connecting to
62
    pub workload: Arc<Workload>,
63
    /// selected_workload_ip defines the IP address we should actually use to connect to this workload
64
    /// This handles multiple IPs (dual stack) or Hostname destinations (DNS resolution)
65
    /// The workload IP might be empty if we have to go through a network gateway.
66
    pub selected_workload_ip: Option<IpAddr>,
67
    /// Port is the port we should connect to
68
    pub port: u16,
69
    /// Service SANs defines SANs defined at the service level *only*. A complete view of things requires
70
    /// looking at workload.identity() as well.
71
    pub service_sans: Vec<Strng>,
72
    /// If this was from a service, the service info.
73
    pub destination_service: Option<ServiceDescription>,
74
}
75
76
#[derive(Clone, Debug, Eq, PartialEq)]
77
enum UpstreamDestination {
78
    UpstreamParts(Arc<Workload>, u16, Option<Arc<Service>>),
79
    OriginalDestination,
80
}
81
82
impl Upstream {
83
0
    pub fn workload_socket_addr(&self) -> Option<SocketAddr> {
84
0
        self.selected_workload_ip
85
0
            .map(|ip| SocketAddr::new(ip, self.port))
86
0
    }
87
0
    pub fn workload_and_services_san(&self) -> Vec<Identity> {
88
0
        self.service_sans
89
0
            .iter()
90
0
            .flat_map(|san| match Identity::from_str(san) {
91
0
                Ok(id) => Some(id),
92
0
                Err(err) => {
93
0
                    warn!("ignoring invalid SAN {}: {}", san, err);
94
0
                    None
95
                }
96
0
            })
97
0
            .chain(std::iter::once(self.workload.identity()))
98
0
            .collect()
99
0
    }
100
101
0
    pub fn service_sans(&self) -> Vec<Identity> {
102
0
        self.service_sans
103
0
            .iter()
104
0
            .flat_map(|san| match Identity::from_str(san) {
105
0
                Ok(id) => Some(id),
106
0
                Err(err) => {
107
0
                    warn!("ignoring invalid SAN {}: {}", san, err);
108
0
                    None
109
                }
110
0
            })
111
0
            .collect()
112
0
    }
113
}
114
115
// Workload information that a specific proxy instance represents. This is used to cross check
116
// with the workload fetched using destination address when making RBAC decisions.
117
#[derive(
118
    Debug, Clone, Eq, Hash, Ord, PartialEq, PartialOrd, serde::Serialize, serde::Deserialize,
119
)]
120
#[serde(rename_all = "camelCase")]
121
pub struct WorkloadInfo {
122
    pub name: String,
123
    pub namespace: String,
124
    pub service_account: String,
125
}
126
127
impl fmt::Display for WorkloadInfo {
128
0
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129
0
        write!(
130
0
            f,
131
0
            "{}.{} ({})",
132
            self.service_account, self.namespace, self.name
133
        )
134
0
    }
135
}
136
137
impl WorkloadInfo {
138
0
    pub fn new(name: String, namespace: String, service_account: String) -> Self {
139
0
        Self {
140
0
            name,
141
0
            namespace,
142
0
            service_account,
143
0
        }
144
0
    }
145
146
0
    pub fn matches(&self, w: &Workload) -> bool {
147
0
        self.name == w.name
148
0
            && self.namespace == w.namespace
149
0
            && self.service_account == w.service_account
150
0
    }
151
}
152
153
#[derive(Educe, Debug, Clone, Eq, serde::Serialize)]
154
#[educe(PartialEq, Hash)]
155
pub struct ProxyRbacContext {
156
    pub conn: rbac::Connection,
157
    #[educe(Hash(ignore), PartialEq(ignore))]
158
    pub dest_workload: Arc<Workload>,
159
}
160
161
impl ProxyRbacContext {
162
0
    pub fn into_conn(self) -> rbac::Connection {
163
0
        self.conn
164
0
    }
165
}
166
167
impl fmt::Display for ProxyRbacContext {
168
0
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
169
0
        write!(f, "{} ({})", self.conn, self.dest_workload.uid)?;
170
0
        Ok(())
171
0
    }
172
}
173
/// The current state information for this proxy.
174
#[derive(Debug)]
175
pub struct ProxyState {
176
    pub workloads: WorkloadStore,
177
178
    pub services: ServiceStore,
179
180
    pub policies: PolicyStore,
181
}
182
183
#[derive(serde::Serialize, Debug)]
184
#[serde(rename_all = "camelCase")]
185
struct ProxyStateSerialization<'a> {
186
    workloads: Vec<Arc<Workload>>,
187
    services: Vec<Arc<Service>>,
188
    policies: Vec<Authorization>,
189
    staged_services: &'a HashMap<NamespacedHostname, HashMap<Strng, Endpoint>>,
190
}
191
192
impl serde::Serialize for ProxyState {
193
0
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
194
0
    where
195
0
        S: Serializer,
196
    {
197
        // Services all have hostname, so use that as the key
198
0
        let services: Vec<_> = self
199
0
            .services
200
0
            .by_host
201
0
            .iter()
202
0
            .sorted_by_key(|k| k.0)
203
0
            .flat_map(|k| k.1)
204
0
            .cloned()
205
0
            .collect();
206
        // Workloads all have a UID, so use that as the key
207
0
        let workloads: Vec<_> = self
208
0
            .workloads
209
0
            .by_uid
210
0
            .iter()
211
0
            .sorted_by_key(|k| k.0)
212
0
            .map(|k| k.1)
213
0
            .cloned()
214
0
            .collect();
215
0
        let policies: Vec<_> = self
216
0
            .policies
217
0
            .by_key
218
0
            .iter()
219
0
            .sorted_by_key(|k| k.0)
220
0
            .map(|k| k.1)
221
0
            .cloned()
222
0
            .collect();
223
0
        let serializable = ProxyStateSerialization {
224
0
            workloads,
225
0
            services,
226
0
            policies,
227
0
            staged_services: &self.services.staged_services,
228
0
        };
229
0
        serializable.serialize(serializer)
230
0
    }
231
}
232
233
impl ProxyState {
234
0
    pub fn new(local_node: Option<Strng>) -> ProxyState {
235
0
        ProxyState {
236
0
            workloads: WorkloadStore::new(local_node),
237
0
            services: Default::default(),
238
0
            policies: Default::default(),
239
0
        }
240
0
    }
241
242
    /// Find either a workload or service by the destination.
243
0
    pub fn find_destination(&self, dest: &Destination) -> Option<Address> {
244
0
        match dest {
245
0
            Destination::Address(addr) => self.find_address(addr),
246
0
            Destination::Hostname(hostname) => self.find_hostname(hostname),
247
        }
248
0
    }
249
250
    /// Find either a workload or a service by address.
251
0
    pub fn find_address(&self, network_addr: &NetworkAddress) -> Option<Address> {
252
        // 1. handle workload ip, if workload not found fallback to service.
253
0
        match self.workloads.find_address(network_addr) {
254
            None => {
255
                // 2. handle service
256
0
                if let Some(svc) = self.services.get_by_vip(network_addr) {
257
0
                    return Some(Address::Service(svc));
258
0
                }
259
0
                None
260
            }
261
0
            Some(wl) => Some(Address::Workload(wl)),
262
        }
263
0
    }
264
265
    /// Find either a workload or a service by hostname.
266
0
    pub fn find_hostname(&self, name: &NamespacedHostname) -> Option<Address> {
267
        // Hostnames for services are more common, so lookup service first and fallback to workload.
268
0
        self.services
269
0
            .get_by_namespaced_host(name)
270
0
            .map(Address::Service)
271
0
            .or_else(|| {
272
                // Slow path: lookup workload by O(n) lookup. This is an uncommon path, so probably not worth
273
                // the memory cost to index currently
274
0
                self.workloads
275
0
                    .by_uid
276
0
                    .values()
277
0
                    .find(|w| w.hostname == name.hostname && w.namespace == name.namespace)
278
0
                    .cloned()
279
0
                    .map(Address::Workload)
280
0
            })
281
0
    }
282
283
    /// Find services by hostname.
284
0
    pub fn find_service_by_hostname(&self, hostname: &Strng) -> Result<Vec<Arc<Service>>, Error> {
285
        // Hostnames for services are more common, so lookup service first and fallback to workload.
286
0
        self.services
287
0
            .get_by_host(hostname)
288
0
            .ok_or_else(|| Error::NoHostname(hostname.to_string()))
289
0
    }
290
291
0
    fn find_upstream(
292
0
        &self,
293
0
        network: Strng,
294
0
        source_workload: &Workload,
295
0
        addr: SocketAddr,
296
0
        resolution_mode: ServiceResolutionMode,
297
0
    ) -> Option<UpstreamDestination> {
298
0
        if let Some(svc) = self
299
0
            .services
300
0
            .get_by_vip(&network_addr(network.clone(), addr.ip()))
301
        {
302
0
            if let Some(lb) = &svc.load_balancer
303
0
                && lb.mode == LoadBalancerMode::Passthrough
304
            {
305
0
                return Some(UpstreamDestination::OriginalDestination);
306
0
            }
307
0
            return self.find_upstream_from_service(
308
0
                source_workload,
309
0
                addr.port(),
310
0
                resolution_mode,
311
0
                svc,
312
            );
313
0
        }
314
0
        if let Some(wl) = self
315
0
            .workloads
316
0
            .find_address(&network_addr(network, addr.ip()))
317
        {
318
0
            return Some(UpstreamDestination::UpstreamParts(wl, addr.port(), None));
319
0
        }
320
0
        None
321
0
    }
322
323
0
    fn find_upstream_from_service(
324
0
        &self,
325
0
        source_workload: &Workload,
326
0
        svc_port: u16,
327
0
        resolution_mode: ServiceResolutionMode,
328
0
        svc: Arc<Service>,
329
0
    ) -> Option<UpstreamDestination> {
330
        // Randomly pick an upstream
331
        // TODO: do this more efficiently, and not just randomly
332
0
        let Some((ep, wl)) = self.load_balance(source_workload, &svc, svc_port, resolution_mode)
333
        else {
334
0
            debug!("Service {} has no healthy endpoints", svc.hostname);
335
0
            return None;
336
        };
337
338
0
        let svc_target_port = svc.ports.get(&svc_port).copied().unwrap_or_default();
339
0
        let target_port = if let Some(&ep_target_port) = ep.port.get(&svc_port) {
340
            // prefer endpoint port mapping
341
0
            ep_target_port
342
0
        } else if svc_target_port > 0 {
343
            // otherwise, see if the service has this port
344
0
            svc_target_port
345
0
        } else if let Some(ApplicationTunnel { port: Some(_), .. }) = &wl.application_tunnel {
346
            // when using app tunnel, we don't require the port to be found on the service
347
0
            svc_port
348
        } else {
349
            // no app tunnel or port mapping, error
350
0
            debug!(
351
0
                "found service {}, but port {} was unknown",
352
0
                svc.hostname, svc_port
353
            );
354
0
            return None;
355
        };
356
357
0
        Some(UpstreamDestination::UpstreamParts(
358
0
            wl,
359
0
            target_port,
360
0
            Some(svc),
361
0
        ))
362
0
    }
363
364
0
    fn load_balance<'a>(
365
0
        &self,
366
0
        src: &Workload,
367
0
        svc: &'a Service,
368
0
        svc_port: u16,
369
0
        resolution_mode: ServiceResolutionMode,
370
0
    ) -> Option<(&'a Endpoint, Arc<Workload>)> {
371
0
        let target_port = svc.ports.get(&svc_port).copied();
372
373
0
        if resolution_mode == ServiceResolutionMode::Standard && target_port.is_none() {
374
            // Port doesn't exist on the service at all, this is invalid
375
0
            debug!("service {} does not have port {}", svc.hostname, svc_port);
376
0
            return None;
377
0
        };
378
379
0
        let endpoints = svc.endpoints.iter().filter_map(|ep| {
380
0
            let Some(wl) = self.workloads.find_uid(&ep.workload_uid) else {
381
0
                debug!("failed to fetch workload for {}", ep.workload_uid);
382
0
                return None;
383
            };
384
385
0
            let in_network = wl.network == src.network;
386
0
            let has_network_gateway = wl.network_gateway.is_some();
387
0
            let has_address = !wl.workload_ips.is_empty() || !wl.hostname.is_empty();
388
0
            if !has_address {
389
                // Workload has no IP. We can only reach it via a network gateway
390
                // WDS is client-agnostic, so we will get a network gateway for a workload
391
                // even if it's in the same network; we should never use it.
392
0
                if in_network || !has_network_gateway {
393
0
                    return None;
394
0
                }
395
0
            }
396
397
0
            match resolution_mode {
398
                ServiceResolutionMode::Standard => {
399
0
                    if target_port.unwrap_or_default() == 0 && !ep.port.contains_key(&svc_port) {
400
                        // Filter workload out, it doesn't have a matching port
401
0
                        trace!(
402
0
                            "filter endpoint {}, it does not have service port {}",
403
                            ep.workload_uid, svc_port
404
                        );
405
0
                        return None;
406
0
                    }
407
                }
408
                ServiceResolutionMode::Waypoint => {
409
0
                    if target_port.is_none() && wl.application_tunnel.is_none() {
410
                        // We ignore this for app_tunnel; in this case, the port does not need to be on the service.
411
                        // This is only valid for waypoints, which are not explicitly addressed by users.
412
                        // We do happen to do a lookup by `waypoint-svc:15008`, this is not a literal call on that service;
413
                        // the port is not required at all if they have application tunnel, as it will be handled by ztunnel on the other end.
414
0
                        trace!(
415
0
                            "filter waypoint endpoint {}, target port is not defined",
416
                            ep.workload_uid
417
                        );
418
0
                        return None;
419
0
                    }
420
                }
421
            }
422
0
            Some((ep, wl))
423
0
        });
424
425
0
        let options = match svc.load_balancer {
426
0
            Some(ref lb) if lb.mode != LoadBalancerMode::Standard => {
427
0
                let ranks = endpoints
428
0
                    .filter_map(|(ep, wl)| {
429
                        // Load balancer will define N targets we want to match
430
                        // Consider [network, region, zone]
431
                        // Rank = 3 means we match all of them
432
                        // Rank = 2 means network and region match
433
                        // Rank = 0 means none match
434
0
                        let mut rank = 0;
435
0
                        for target in &lb.routing_preferences {
436
0
                            let matches = match target {
437
                                LoadBalancerScopes::Region => {
438
0
                                    src.locality.region == wl.locality.region
439
                                }
440
0
                                LoadBalancerScopes::Zone => src.locality.zone == wl.locality.zone,
441
                                LoadBalancerScopes::Subzone => {
442
0
                                    src.locality.subzone == wl.locality.subzone
443
                                }
444
0
                                LoadBalancerScopes::Node => src.node == wl.node,
445
0
                                LoadBalancerScopes::Cluster => src.cluster_id == wl.cluster_id,
446
0
                                LoadBalancerScopes::Network => src.network == wl.network,
447
                            };
448
0
                            if matches {
449
0
                                rank += 1;
450
0
                            } else {
451
0
                                break;
452
                            }
453
                        }
454
                        // Doesn't match all, and required to. Do not select this endpoint
455
0
                        if lb.mode == LoadBalancerMode::Strict
456
0
                            && rank != lb.routing_preferences.len()
457
                        {
458
0
                            return None;
459
0
                        }
460
0
                        Some((rank, ep, wl))
461
0
                    })
462
0
                    .collect::<Vec<_>>();
463
0
                let max = *ranks.iter().map(|(rank, _ep, _wl)| rank).max()?;
464
0
                let options: Vec<_> = ranks
465
0
                    .into_iter()
466
0
                    .filter(|(rank, _ep, _wl)| *rank == max)
467
0
                    .map(|(_, ep, wl)| (ep, wl))
468
0
                    .collect();
469
0
                options
470
            }
471
0
            _ => endpoints.collect(),
472
        };
473
0
        options
474
0
            .choose_weighted(&mut rand::rng(), |(_, wl)| wl.capacity as u64)
475
            // This can fail if there are no weights, the sum is zero (not possible in our API), or if it overflows
476
            // The API has u32 but we sum into an u64, so it would take ~4 billion entries of max weight to overflow
477
0
            .ok()
478
0
            .cloned()
479
0
    }
480
}
481
482
/// Wrapper around [ProxyState] that provides additional methods for requesting information
483
/// on-demand.
484
#[derive(serde::Serialize, Clone)]
485
pub struct DemandProxyState {
486
    #[serde(flatten)]
487
    state: Arc<RwLock<ProxyState>>,
488
489
    /// If present, used to request on-demand updates for workloads.
490
    #[serde(skip_serializing)]
491
    demand: Option<Demander>,
492
493
    #[serde(skip_serializing)]
494
    metrics: Arc<proxy::Metrics>,
495
496
    #[serde(skip_serializing)]
497
    dns_resolver: TokioResolver,
498
}
499
500
impl DemandProxyState {
501
0
    pub(crate) fn get_services_by_workload(&self, wl: &Workload) -> Vec<Arc<Service>> {
502
0
        self.state
503
0
            .read()
504
0
            .expect("mutex")
505
0
            .services
506
0
            .get_by_workload(wl)
507
0
    }
508
}
509
510
impl DemandProxyState {
511
0
    pub fn new(
512
0
        state: Arc<RwLock<ProxyState>>,
513
0
        demand: Option<Demander>,
514
0
        dns_resolver_cfg: ResolverConfig,
515
0
        dns_resolver_opts: ResolverOpts,
516
0
        metrics: Arc<proxy::Metrics>,
517
0
    ) -> Self {
518
0
        let mut rb = hickory_resolver::Resolver::builder_with_config(
519
0
            dns_resolver_cfg,
520
0
            TokioConnectionProvider::default(),
521
        );
522
0
        *rb.options_mut() = dns_resolver_opts;
523
0
        let dns_resolver = rb.build();
524
0
        Self {
525
0
            state,
526
0
            demand,
527
0
            dns_resolver,
528
0
            metrics,
529
0
        }
530
0
    }
531
532
0
    /// Acquires a read guard on the underlying [`ProxyState`].
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    pub fn read(&self) -> RwLockReadGuard<'_, ProxyState> {
        let guard = self.state.read();
        guard.unwrap()
    }
535
536
0
    pub async fn assert_rbac(
537
0
        &self,
538
0
        ctx: &ProxyRbacContext,
539
0
    ) -> Result<(), proxy::AuthorizationRejectionError> {
540
0
        let wl = &ctx.dest_workload;
541
0
        let conn = &ctx.conn;
542
0
        let state = self.read();
543
544
        // We can get policies from namespace, global, and workload...
545
0
        let ns = state.policies.get_by_namespace(&wl.namespace);
546
0
        let global = state.policies.get_by_namespace(&crate::strng::EMPTY);
547
0
        let workload = wl.authorization_policies.iter();
548
549
        // Aggregate all of them based on type
550
0
        let (allow, deny): (Vec<_>, Vec<_>) = ns
551
0
            .iter()
552
0
            .chain(global.iter())
553
0
            .chain(workload)
554
0
            .filter_map(|k| {
555
0
                let pol = state.policies.get(k);
556
                // Policy not found. This is probably transition state where the policy hasn't been sent
557
                // by the control plane, or it was just removed.
558
0
                if pol.is_none() {
559
0
                    warn!("skipping unknown policy {k}");
560
0
                }
561
0
                pol
562
0
            })
563
0
            .partition(|p| p.action == rbac::RbacAction::Allow);
564
565
0
        trace!(
566
0
            allow = allow.len(),
567
0
            deny = deny.len(),
568
0
            "checking connection"
569
        );
570
571
        // Allow and deny logic follows https://istio.io/latest/docs/reference/config/security/authorization-policy/
572
573
        // "If there are any DENY policies that match the request, deny the request."
574
0
        for pol in deny.iter() {
575
0
            if pol.matches(conn) {
576
0
                debug!(policy = pol.to_key().as_str(), "deny policy match");
577
0
                return Err(proxy::AuthorizationRejectionError::ExplicitlyDenied(
578
0
                    pol.namespace.to_owned(),
579
0
                    pol.name.to_owned(),
580
0
                ));
581
            } else {
582
0
                trace!(policy = pol.to_key().as_str(), "deny policy does not match");
583
            }
584
        }
585
        // "If there are no ALLOW policies for the workload, allow the request."
586
0
        if allow.is_empty() {
587
0
            debug!("no allow policies, allow");
588
0
            return Ok(());
589
0
        }
590
        // "If any of the ALLOW policies match the request, allow the request."
591
0
        for pol in allow.iter() {
592
0
            if pol.matches(conn) {
593
0
                debug!(policy = pol.to_key().as_str(), "allow policy match");
594
0
                return Ok(());
595
            } else {
596
0
                trace!(
597
0
                    policy = pol.to_key().as_str(),
598
0
                    "allow policy does not match"
599
                );
600
            }
601
        }
602
        // "Deny the request."
603
0
        debug!("no allow policies matched");
604
0
        Err(proxy::AuthorizationRejectionError::NotAllowed)
605
0
    }
606
607
    // Select a workload IP, with DNS resolution if needed
608
0
    async fn pick_workload_destination_or_resolve(
609
0
        &self,
610
0
        dst_workload: &Workload,
611
0
        src_workload: &Workload,
612
0
        original_target_address: SocketAddr,
613
0
        ip_family_restriction: Option<IpFamily>,
614
0
    ) -> Result<Option<IpAddr>, Error> {
615
        // If the user requested the pod by a specific IP, use that directly.
616
0
        if dst_workload
617
0
            .workload_ips
618
0
            .contains(&original_target_address.ip())
619
        {
620
0
            return Ok(Some(original_target_address.ip()));
621
0
        }
622
        // They may have 1 or 2 IPs (single/dual stack)
623
        // Ensure we are meeting the Service family restriction (if any is defined).
624
        // Otherwise, prefer the same IP family as the original request.
625
0
        if let Some(ip) = dst_workload
626
0
            .workload_ips
627
0
            .iter()
628
0
            .filter(|ip| {
629
0
                ip_family_restriction
630
0
                    .map(|f| f.accepts_ip(**ip))
631
0
                    .unwrap_or(true)
632
0
            })
633
0
            .find_or_first(|ip| ip.is_ipv6() == original_target_address.is_ipv6())
634
        {
635
0
            return Ok(Some(*ip));
636
0
        }
637
0
        if dst_workload.hostname.is_empty() {
638
0
            if dst_workload.network_gateway.is_none() {
639
0
                debug!(
640
0
                    "workload {} has no suitable workload IPs for routing",
641
                    dst_workload.name
642
                );
643
0
                return Err(Error::NoValidDestination(Box::new(dst_workload.clone())));
644
            } else {
645
                // We can route through network gateway
646
0
                return Ok(None);
647
            }
648
0
        }
649
0
        let ip = Box::pin(self.resolve_workload_address(
650
0
            dst_workload,
651
0
            src_workload,
652
0
            original_target_address,
653
0
        ))
654
0
        .await?;
655
0
        Ok(Some(ip))
656
0
    }
657
658
0
    async fn resolve_workload_address(
659
0
        &self,
660
0
        workload: &Workload,
661
0
        src_workload: &Workload,
662
0
        original_target_address: SocketAddr,
663
0
    ) -> Result<IpAddr, Error> {
664
0
        let labels = OnDemandDnsLabels::new()
665
0
            .with_destination(workload)
666
0
            .with_source(src_workload);
667
0
        self.metrics
668
0
            .as_ref()
669
0
            .on_demand_dns
670
0
            .get_or_create(&labels)
671
0
            .inc();
672
0
        self.resolve_on_demand_dns(workload, original_target_address)
673
0
            .await
674
0
    }
675
676
0
    async fn resolve_on_demand_dns(
677
0
        &self,
678
0
        workload: &Workload,
679
0
        original_target_address: SocketAddr,
680
0
    ) -> Result<IpAddr, Error> {
681
0
        let workload_uid = workload.uid.clone();
682
0
        let hostname = workload.hostname.clone();
683
0
        trace!(%hostname, "starting DNS lookup");
684
685
0
        let resp = match self.dns_resolver.lookup_ip(hostname.as_str()).await {
686
0
            Err(err) => {
687
0
                warn!(?err,%hostname,"dns lookup failed");
688
0
                return Err(Error::NoResolvedAddresses(workload_uid.to_string()));
689
            }
690
0
            Ok(resp) => resp,
691
        };
692
0
        trace!(%hostname, "dns lookup complete {resp:?}");
693
694
0
        let (matching, unmatching): (Vec<_>, Vec<_>) = resp
695
0
            .as_lookup()
696
0
            .record_iter()
697
0
            .filter_map(|record| record.data().ip_addr())
698
0
            .partition(|record| record.is_ipv6() == original_target_address.is_ipv6());
699
        // Randomly pick an IP, prefer to match the IP family of the downstream request.
700
        // Without this, we run into trouble in pure v4 or pure v6 environments.
701
0
        matching
702
0
            .into_iter()
703
0
            .choose(&mut rand::rng())
704
0
            .or_else(|| unmatching.into_iter().choose(&mut rand::rng()))
705
0
            .ok_or_else(|| Error::EmptyResolvedAddresses(workload_uid.to_string()))
706
0
    }
707
708
    // same as fetch_workload, but if the caller knows the workload is enroute already,
709
    // will retry on cache miss for a configured amount of time - returning the workload
710
    // when we get it, or nothing if the timeout is exceeded, whichever happens first
711
0
    pub async fn wait_for_workload(
712
0
        &self,
713
0
        wl: &WorkloadInfo,
714
0
        deadline: Duration,
715
0
    ) -> Option<Arc<Workload>> {
716
0
        debug!(%wl, "wait for workload");
717
718
        // Take a watch listener *before* checking state (so we don't miss anything)
719
0
        let mut wl_sub = self.read().workloads.new_subscriber();
720
721
0
        debug!(%wl, "got sub, waiting for workload");
722
723
0
        if let Some(wl) = self.find_by_info(wl) {
724
0
            return Some(wl);
725
0
        }
726
727
        // We didn't find the workload we expected, so
728
        // loop until the subscriber wakes us on new workload,
729
        // or we hit the deadline timeout and give up
730
0
        let timeout = tokio::time::sleep(deadline);
731
0
        tokio::pin!(timeout);
732
        loop {
733
0
            tokio::select! {
734
0
                _ = &mut timeout => {
735
0
                    warn!("timed out waiting for workload '{wl}' from xds");
736
0
                    break None;
737
                },
738
0
                _ = wl_sub.changed() => {
739
0
                    if let Some(wl) = self.find_by_info(wl) {
740
0
                        break Some(wl);
741
0
                    }
742
                }
743
            }
744
        }
745
0
    }
746
747
    /// Finds the workload by workload information, as an arc.
748
    /// Note: this does not currently support on-demand.
749
0
    fn find_by_info(&self, wl: &WorkloadInfo) -> Option<Arc<Workload>> {
750
0
        self.read().workloads.find_by_info(wl)
751
0
    }
752
753
    // fetch_workload_by_address looks up a Workload by address.
754
    // Note this should never be used to lookup the local workload we are running, only the peer.
755
    // Since the peer connection may come through gateways, NAT, etc, this should only ever be treated
756
    // as a best-effort.
757
0
    pub async fn fetch_workload_by_address(&self, addr: &NetworkAddress) -> Option<Arc<Workload>> {
758
        // Wait for it on-demand, *if* needed
759
0
        debug!(%addr, "fetch workload");
760
0
        if let Some(wl) = self.read().workloads.find_address(addr) {
761
0
            return Some(wl);
762
0
        }
763
0
        if !self.supports_on_demand() {
764
0
            return None;
765
0
        }
766
0
        self.fetch_on_demand(addr.to_string().into()).await;
767
0
        self.read().workloads.find_address(addr)
768
0
    }
769
770
    // only support workload
771
0
    pub async fn fetch_workload_by_uid(&self, uid: &Strng) -> Option<Arc<Workload>> {
772
        // Wait for it on-demand, *if* needed
773
0
        debug!(%uid, "fetch workload");
774
0
        if let Some(wl) = self.read().workloads.find_uid(uid) {
775
0
            return Some(wl);
776
0
        }
777
0
        if !self.supports_on_demand() {
778
0
            return None;
779
0
        }
780
0
        self.fetch_on_demand(uid.clone()).await;
781
0
        self.read().workloads.find_uid(uid)
782
0
    }
783
784
0
    pub async fn fetch_upstream(
785
0
        &self,
786
0
        network: Strng,
787
0
        source_workload: &Workload,
788
0
        addr: SocketAddr,
789
0
        resolution_mode: ServiceResolutionMode,
790
0
    ) -> Result<Option<Upstream>, Error> {
791
0
        self.fetch_address(&network_addr(network.clone(), addr.ip()))
792
0
            .await;
793
0
        let upstream = {
794
0
            self.read()
795
0
                .find_upstream(network, source_workload, addr, resolution_mode)
796
            // Drop the lock
797
        };
798
0
        tracing::trace!(%addr, ?upstream, "fetch_upstream");
799
0
        self.finalize_upstream(source_workload, addr, upstream)
800
0
            .await
801
0
    }
802
803
0
    async fn finalize_upstream(
804
0
        &self,
805
0
        source_workload: &Workload,
806
0
        original_target_address: SocketAddr,
807
0
        upstream: Option<UpstreamDestination>,
808
0
    ) -> Result<Option<Upstream>, Error> {
809
0
        let (wl, port, svc) = match upstream {
810
0
            Some(UpstreamDestination::UpstreamParts(wl, port, svc)) => (wl, port, svc),
811
0
            None | Some(UpstreamDestination::OriginalDestination) => return Ok(None),
812
        };
813
0
        let svc_desc = svc.clone().map(|s| ServiceDescription::from(s.as_ref()));
814
0
        let ip_family_restriction = svc.as_ref().and_then(|s| s.ip_families);
815
0
        let selected_workload_ip = self
816
0
            .pick_workload_destination_or_resolve(
817
0
                &wl,
818
0
                source_workload,
819
0
                original_target_address,
820
0
                ip_family_restriction,
821
0
            )
822
0
            .await?; // if we can't load balance just return the error
823
0
        let res = Upstream {
824
0
            workload: wl,
825
0
            selected_workload_ip,
826
0
            port,
827
0
            service_sans: svc.map(|s| s.subject_alt_names.clone()).unwrap_or_default(),
828
0
            destination_service: svc_desc,
829
        };
830
0
        tracing::trace!(?res, "finalize_upstream");
831
0
        Ok(Some(res))
832
0
    }
833
834
    /// Returns destination address, upstream sans, and final sans, for
835
    /// connecting to a remote workload through a gateway.
836
    /// Would be nice to return this as an Upstream, but gateways don't necessarily
837
    /// have workloads. That is, they could just be IPs without a corresponding workload.
838
0
    pub async fn fetch_network_gateway(
839
0
        &self,
840
0
        gw_address: &GatewayAddress,
841
0
        source_workload: &Workload,
842
0
        original_destination_address: SocketAddr,
843
0
    ) -> Result<Upstream, Error> {
844
0
        let (res, target_address) = match &gw_address.destination {
845
0
            Destination::Address(ip) => {
846
0
                let addr = SocketAddr::new(ip.address, gw_address.hbone_mtls_port);
847
0
                let us = self.state.read().unwrap().find_upstream(
848
0
                    ip.network.clone(),
849
0
                    source_workload,
850
0
                    addr,
851
0
                    ServiceResolutionMode::Standard,
852
0
                );
853
                // If the workload references a network gateway by IP, use that IP as the destination.
854
                // Note this means that an IPv6 call may be translated to IPv4 if the network
855
                // gateway is specified as an IPv4 address.
856
                // For this reason, the Hostname method is preferred which can adapt to the callers IP family.
857
0
                (us, addr)
858
            }
859
0
            Destination::Hostname(host) => {
860
0
                let state = self.read();
861
0
                match state.find_hostname(host) {
862
0
                    Some(Address::Service(s)) => {
863
0
                        let us = state.find_upstream_from_service(
864
0
                            source_workload,
865
0
                            gw_address.hbone_mtls_port,
866
0
                            ServiceResolutionMode::Standard,
867
0
                            s,
868
0
                        );
869
                        // For hostname, use the original_destination_address as the target so we can
870
                        // adapt to the callers IP family.
871
0
                        (us, original_destination_address)
872
                    }
873
0
                    Some(Address::Workload(w)) => {
874
0
                        let us = Some(UpstreamDestination::UpstreamParts(
875
0
                            w,
876
0
                            gw_address.hbone_mtls_port,
877
0
                            None,
878
0
                        ));
879
0
                        (us, original_destination_address)
880
                    }
881
                    None => {
882
0
                        return Err(Error::UnknownNetworkGateway(format!(
883
0
                            "network gateway {} not found",
884
0
                            host.hostname
885
0
                        )));
886
                    }
887
                }
888
            }
889
        };
890
0
        self.finalize_upstream(source_workload, target_address, res)
891
0
            .await?
892
0
            .ok_or_else(|| {
893
0
                Error::UnknownNetworkGateway(format!("network gateway {gw_address:?} not found"))
894
0
            })
895
0
    }
896
897
0
    async fn fetch_waypoint(
898
0
        &self,
899
0
        gw_address: &GatewayAddress,
900
0
        source_workload: &Workload,
901
0
        original_destination_address: SocketAddr,
902
0
    ) -> Result<Upstream, Error> {
903
        // Waypoint can be referred to by an IP or Hostname.
904
        // Hostname is preferred as it is a more stable identifier.
905
0
        let (res, target_address) = match &gw_address.destination {
906
0
            Destination::Address(ip) => {
907
0
                let addr = SocketAddr::new(ip.address, gw_address.hbone_mtls_port);
908
0
                let us = self.read().find_upstream(
909
0
                    ip.network.clone(),
910
0
                    source_workload,
911
0
                    addr,
912
0
                    ServiceResolutionMode::Waypoint,
913
0
                );
914
                // If they referenced a waypoint by IP, use that IP as the destination.
915
                // Note this means that an IPv6 call may be translated to IPv4 if the waypoint is specified
916
                // as an IPv4 address.
917
                // For this reason, the Hostname method is preferred which can adapt to the callers IP family.
918
0
                (us, addr)
919
            }
920
0
            Destination::Hostname(host) => {
921
0
                let state = self.read();
922
0
                match state.find_hostname(host) {
923
0
                    Some(Address::Service(s)) => {
924
0
                        let us = state.find_upstream_from_service(
925
0
                            source_workload,
926
0
                            gw_address.hbone_mtls_port,
927
0
                            ServiceResolutionMode::Waypoint,
928
0
                            s,
929
0
                        );
930
                        // For hostname, use the original_destination_address as the target so we can
931
                        // adapt to the callers IP family.
932
0
                        (us, original_destination_address)
933
                    }
934
0
                    Some(Address::Workload(w)) => {
935
0
                        let us = Some(UpstreamDestination::UpstreamParts(
936
0
                            w,
937
0
                            gw_address.hbone_mtls_port,
938
0
                            None,
939
0
                        ));
940
0
                        (us, original_destination_address)
941
                    }
942
                    None => {
943
0
                        return Err(Error::UnknownWaypoint(format!(
944
0
                            "waypoint {} not found",
945
0
                            host.hostname
946
0
                        )));
947
                    }
948
                }
949
            }
950
        };
951
0
        self.finalize_upstream(source_workload, target_address, res)
952
0
            .await?
953
0
            .ok_or_else(|| Error::UnknownWaypoint(format!("waypoint {gw_address:?} not found")))
954
0
    }
955
956
0
    pub async fn fetch_service_waypoint(
957
0
        &self,
958
0
        service: &Service,
959
0
        source_workload: &Workload,
960
0
        original_destination_address: SocketAddr,
961
0
    ) -> Result<Option<Upstream>, Error> {
962
0
        let Some(gw_address) = &service.waypoint else {
963
            // no waypoint
964
0
            return Ok(None);
965
        };
966
0
        self.fetch_waypoint(gw_address, source_workload, original_destination_address)
967
0
            .await
968
0
            .map(Some)
969
0
    }
970
971
0
    pub async fn fetch_workload_waypoint(
972
0
        &self,
973
0
        wl: &Workload,
974
0
        source_workload: &Workload,
975
0
        original_destination_address: SocketAddr,
976
0
    ) -> Result<Option<Upstream>, Error> {
977
0
        let Some(gw_address) = &wl.waypoint else {
978
            // no waypoint
979
0
            return Ok(None);
980
        };
981
0
        self.fetch_waypoint(gw_address, source_workload, original_destination_address)
982
0
            .await
983
0
            .map(Some)
984
0
    }
985
986
    /// Looks for either a workload or service by the destination. If not found locally,
987
    /// attempts to fetch on-demand.
988
0
    pub async fn fetch_destination(&self, dest: &Destination) -> Option<Address> {
989
0
        match dest {
990
0
            Destination::Address(addr) => self.fetch_address(addr).await,
991
0
            Destination::Hostname(hostname) => self.fetch_hostname(hostname).await,
992
        }
993
0
    }
994
995
    /// Looks for the given address to find either a workload or service by IP. If not found
996
    /// locally, attempts to fetch on-demand.
997
0
    pub async fn fetch_address(&self, network_addr: &NetworkAddress) -> Option<Address> {
998
        // Wait for it on-demand, *if* needed
999
0
        debug!(%network_addr.address, "fetch address");
1000
0
        if let Some(address) = self.read().find_address(network_addr) {
1001
0
            return Some(address);
1002
0
        }
1003
0
        if !self.supports_on_demand() {
1004
0
            return None;
1005
0
        }
1006
        // if both cache not found, start on demand fetch
1007
0
        self.fetch_on_demand(network_addr.to_string().into()).await;
1008
0
        self.read().find_address(network_addr)
1009
0
    }
1010
1011
    /// Looks for the given hostname to find either a workload or service by IP. If not found
1012
    /// locally, attempts to fetch on-demand.
1013
0
    async fn fetch_hostname(&self, hostname: &NamespacedHostname) -> Option<Address> {
1014
        // Wait for it on-demand, *if* needed
1015
0
        debug!(%hostname, "fetch hostname");
1016
0
        if let Some(address) = self.read().find_hostname(hostname) {
1017
0
            return Some(address);
1018
0
        }
1019
0
        if !self.supports_on_demand() {
1020
0
            return None;
1021
0
        }
1022
        // if both cache not found, start on demand fetch
1023
0
        self.fetch_on_demand(hostname.to_string().into()).await;
1024
0
        self.read().find_hostname(hostname)
1025
0
    }
1026
1027
0
    pub fn supports_on_demand(&self) -> bool {
1028
0
        self.demand.is_some()
1029
0
    }
1030
1031
    /// fetch_on_demand looks up the provided key on-demand and waits for it to return
1032
0
    pub async fn fetch_on_demand(&self, key: Strng) {
1033
0
        if let Some(demand) = &self.demand {
1034
0
            debug!(%key, "sending demand request");
1035
0
            Box::pin(
1036
0
                demand
1037
0
                    .demand(xds::ADDRESS_TYPE, key.clone())
1038
0
                    .then(|o| o.recv()),
1039
            )
1040
0
            .await;
1041
0
            debug!(%key, "on demand ready");
1042
0
        }
1043
0
    }
1044
}
1045
1046
/// Tells a lookup what kind of target is being resolved.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServiceResolutionMode {
    /// We are resolving a normal service
    Standard,
    /// We are resolving a waypoint proxy
    Waypoint,
}
1053
1054
#[derive(serde::Serialize)]
1055
pub struct ProxyStateManager {
1056
    #[serde(flatten)]
1057
    state: DemandProxyState,
1058
1059
    #[serde(skip_serializing)]
1060
    xds_client: Option<AdsClient>,
1061
}
1062
1063
impl ProxyStateManager {
1064
0
    pub async fn new(
1065
0
        config: Arc<config::Config>,
1066
0
        xds_metrics: xds::Metrics,
1067
0
        proxy_metrics: Arc<proxy::Metrics>,
1068
0
        awaiting_ready: tokio::sync::watch::Sender<()>,
1069
0
        cert_manager: Arc<SecretManager>,
1070
0
    ) -> anyhow::Result<ProxyStateManager> {
1071
0
        let cert_fetcher = cert_fetcher::new(&config, cert_manager);
1072
0
        let state: Arc<RwLock<ProxyState>> = Arc::new(RwLock::new(ProxyState::new(
1073
0
            config.local_node.as_ref().map(strng::new),
1074
        )));
1075
0
        let xds_client = if config.xds_address.is_some() {
1076
0
            let updater = ProxyStateUpdater::new(state.clone(), cert_fetcher.clone());
1077
0
            let tls_client_fetcher = Box::new(tls::ControlPlaneAuthentication::RootCert(
1078
0
                config.xds_root_cert.clone(),
1079
0
            ));
1080
0
            Some(
1081
0
                xds::Config::new(config.clone(), tls_client_fetcher)
1082
0
                    .with_watched_handler::<XdsAddress>(xds::ADDRESS_TYPE, updater.clone())
1083
0
                    .with_watched_handler::<XdsAuthorization>(xds::AUTHORIZATION_TYPE, updater)
1084
0
                    .build(xds_metrics, awaiting_ready),
1085
0
            )
1086
        } else {
1087
0
            None
1088
        };
1089
0
        if let Some(cfg) = &config.local_xds_config {
1090
0
            let local_client = LocalClient {
1091
0
                local_node: config.local_node.as_ref().map(strng::new),
1092
0
                cfg: cfg.clone(),
1093
0
                state: state.clone(),
1094
0
                cert_fetcher,
1095
0
            };
1096
0
            local_client.run().await?;
1097
0
        }
1098
0
        let demand = xds_client.as_ref().and_then(AdsClient::demander);
1099
0
        Ok(ProxyStateManager {
1100
0
            xds_client,
1101
0
            state: DemandProxyState::new(
1102
0
                state,
1103
0
                demand,
1104
0
                config.dns_resolver_cfg.clone(),
1105
0
                config.dns_resolver_opts.clone(),
1106
0
                proxy_metrics,
1107
0
            ),
1108
0
        })
1109
0
    }
1110
1111
0
    pub fn state(&self) -> DemandProxyState {
1112
0
        self.state.clone()
1113
0
    }
1114
1115
0
    pub async fn run(self) -> anyhow::Result<()> {
1116
0
        match self.xds_client {
1117
0
            Some(xds) => xds.run().await.map_err(|e| anyhow::anyhow!(e)),
1118
0
            None => Ok(()),
1119
        }
1120
0
    }
1121
}
1122
1123
#[cfg(test)]
1124
mod tests {
1125
    use crate::state::service::{EndpointSet, LoadBalancer, LoadBalancerHealthPolicy};
1126
    use crate::state::workload::{HealthStatus, Locality};
1127
    use prometheus_client::registry::Registry;
1128
    use rbac::StringMatch;
1129
    use std::{net::Ipv4Addr, net::SocketAddrV4, time::Duration};
1130
1131
    use self::workload::{ApplicationTunnel, application_tunnel::Protocol as AppProtocol};
1132
1133
    use super::*;
1134
    use crate::test_helpers::helpers::initialize_telemetry;
1135
1136
    use crate::{strng, test_helpers};
1137
    use test_case::test_case;
1138
1139
    #[tokio::test]
1140
    async fn test_wait_for_workload() {
1141
        let mut state = ProxyState::new(None);
1142
        let delayed_wl = Arc::new(test_helpers::test_default_workload());
1143
        state.workloads.insert(delayed_wl.clone());
1144
1145
        let mut registry = Registry::default();
1146
        let metrics = Arc::new(crate::proxy::Metrics::new(&mut registry));
1147
        let mock_proxy_state = DemandProxyState::new(
1148
            Arc::new(RwLock::new(state)),
1149
            None,
1150
            ResolverConfig::default(),
1151
            ResolverOpts::default(),
1152
            metrics,
1153
        );
1154
1155
        let want = WorkloadInfo {
1156
            name: delayed_wl.name.to_string(),
1157
            namespace: delayed_wl.namespace.to_string(),
1158
            service_account: delayed_wl.service_account.to_string(),
1159
        };
1160
1161
        test_helpers::assert_eventually(
1162
            Duration::from_secs(1),
1163
            || mock_proxy_state.wait_for_workload(&want, Duration::from_millis(50)),
1164
            Some(delayed_wl),
1165
        )
1166
        .await;
1167
    }
1168
1169
    #[tokio::test]
1170
    async fn test_wait_for_workload_delay_fails() {
1171
        let state = ProxyState::new(None);
1172
1173
        let mut registry = Registry::default();
1174
        let metrics = Arc::new(crate::proxy::Metrics::new(&mut registry));
1175
        let mock_proxy_state = DemandProxyState::new(
1176
            Arc::new(RwLock::new(state)),
1177
            None,
1178
            ResolverConfig::default(),
1179
            ResolverOpts::default(),
1180
            metrics,
1181
        );
1182
1183
        let want = WorkloadInfo {
1184
            name: "fake".to_string(),
1185
            namespace: "fake".to_string(),
1186
            service_account: "fake".to_string(),
1187
        };
1188
1189
        test_helpers::assert_eventually(
1190
            Duration::from_millis(10),
1191
            || mock_proxy_state.wait_for_workload(&want, Duration::from_millis(5)),
1192
            None,
1193
        )
1194
        .await;
1195
    }
1196
1197
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1198
    async fn test_wait_for_workload_eventually() {
1199
        initialize_telemetry();
1200
        let state = ProxyState::new(None);
1201
        let wrap_state = Arc::new(RwLock::new(state));
1202
        let not_delayed_wl = Arc::new(Workload {
1203
            workload_ips: vec!["1.2.3.4".parse().unwrap()],
1204
            uid: "uid".into(),
1205
            name: "n".into(),
1206
            namespace: "ns".into(),
1207
            ..test_helpers::test_default_workload()
1208
        });
1209
        let delayed_wl = Arc::new(test_helpers::test_default_workload());
1210
1211
        let mut registry = Registry::default();
1212
        let metrics = Arc::new(crate::proxy::Metrics::new(&mut registry));
1213
        let mock_proxy_state = DemandProxyState::new(
1214
            wrap_state.clone(),
1215
            None,
1216
            ResolverConfig::default(),
1217
            ResolverOpts::default(),
1218
            metrics,
1219
        );
1220
1221
        // Some from Address
1222
        let want = WorkloadInfo {
1223
            name: delayed_wl.name.to_string(),
1224
            namespace: delayed_wl.namespace.to_string(),
1225
            service_account: delayed_wl.service_account.to_string(),
1226
        };
1227
1228
        let expected_wl = delayed_wl.clone();
1229
        let t = tokio::spawn(async move {
1230
            test_helpers::assert_eventually(
1231
                Duration::from_millis(500),
1232
                || mock_proxy_state.wait_for_workload(&want, Duration::from_millis(250)),
1233
                Some(expected_wl),
1234
            )
1235
            .await;
1236
        });
1237
        // Send the wrong workload through
1238
        wrap_state.write().unwrap().workloads.insert(not_delayed_wl);
1239
        tokio::time::sleep(Duration::from_millis(100)).await;
1240
        // Send the correct workload through
1241
        wrap_state.write().unwrap().workloads.insert(delayed_wl);
1242
        t.await.expect("should not fail");
1243
    }
1244
1245
    #[tokio::test]
1246
    async fn lookup_address() {
1247
        let mut state = ProxyState::new(None);
1248
        state
1249
            .workloads
1250
            .insert(Arc::new(test_helpers::test_default_workload()));
1251
        state.services.insert(test_helpers::mock_default_service());
1252
1253
        let mut registry = Registry::default();
1254
        let metrics = Arc::new(crate::proxy::Metrics::new(&mut registry));
1255
        let mock_proxy_state = DemandProxyState::new(
1256
            Arc::new(RwLock::new(state)),
1257
            None,
1258
            ResolverConfig::default(),
1259
            ResolverOpts::default(),
1260
            metrics,
1261
        );
1262
1263
        // Some from Address
1264
        let dst = Destination::Address(NetworkAddress {
1265
            network: strng::EMPTY,
1266
            address: IpAddr::V4(Ipv4Addr::LOCALHOST),
1267
        });
1268
        test_helpers::assert_eventually(
1269
            Duration::from_secs(5),
1270
            || mock_proxy_state.fetch_destination(&dst),
1271
            Some(Address::Workload(Arc::new(
1272
                test_helpers::test_default_workload(),
1273
            ))),
1274
        )
1275
        .await;
1276
1277
        // Some from Hostname
1278
        let dst = Destination::Hostname(NamespacedHostname {
1279
            namespace: "default".into(),
1280
            hostname: "defaulthost".into(),
1281
        });
1282
        test_helpers::assert_eventually(
1283
            Duration::from_secs(5),
1284
            || mock_proxy_state.fetch_destination(&dst),
1285
            Some(Address::Service(Arc::new(
1286
                test_helpers::mock_default_service(),
1287
            ))),
1288
        )
1289
        .await;
1290
1291
        // None from Address
1292
        let dst = Destination::Address(NetworkAddress {
1293
            network: "".into(),
1294
            address: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)),
1295
        });
1296
        test_helpers::assert_eventually(
1297
            Duration::from_secs(5),
1298
            || mock_proxy_state.fetch_destination(&dst),
1299
            None,
1300
        )
1301
        .await;
1302
1303
        // None from Hostname
1304
        let dst = Destination::Hostname(NamespacedHostname {
1305
            namespace: "default".into(),
1306
            hostname: "nothost".into(),
1307
        });
1308
        test_helpers::assert_eventually(
1309
            Duration::from_secs(5),
1310
            || mock_proxy_state.fetch_destination(&dst),
1311
            None,
1312
        )
1313
        .await;
1314
    }
1315
1316
    /// Scenarios covered by the `find_upstream_port_mappings` test.
    enum PortMappingTestCase {
        /// The endpoint carries its own port mapping.
        EndpointMapping,
        /// The service maps the port to a non-zero target port.
        ServiceMapping,
        /// The workload has an application tunnel configured.
        AppTunnel,
    }
1321
1322
    impl PortMappingTestCase {
1323
        fn service_mapping(&self) -> HashMap<u16, u16> {
1324
            if let PortMappingTestCase::ServiceMapping = self {
1325
                return HashMap::from([(80, 8080)]);
1326
            }
1327
            HashMap::from([(80, 0)])
1328
        }
1329
1330
        fn endpoint_mapping(&self) -> HashMap<u16, u16> {
1331
            if let PortMappingTestCase::EndpointMapping = self {
1332
                return HashMap::from([(80, 9090)]);
1333
            }
1334
            HashMap::from([])
1335
        }
1336
1337
        fn app_tunnel(&self) -> Option<ApplicationTunnel> {
1338
            if let PortMappingTestCase::AppTunnel = self {
1339
                return Some(ApplicationTunnel {
1340
                    protocol: AppProtocol::PROXY,
1341
                    port: Some(15088),
1342
                });
1343
            }
1344
            None
1345
        }
1346
1347
        fn expected_port(&self) -> u16 {
1348
            match self {
1349
                PortMappingTestCase::ServiceMapping => 8080,
1350
                PortMappingTestCase::EndpointMapping => 9090,
1351
                _ => 80,
1352
            }
1353
        }
1354
    }
1355
1356
    #[test_case(PortMappingTestCase::EndpointMapping; "ep mapping")]
1357
    #[test_case(PortMappingTestCase::ServiceMapping; "svc mapping")]
1358
    #[test_case(PortMappingTestCase::AppTunnel; "app tunnel")]
1359
    #[tokio::test]
1360
    async fn find_upstream_port_mappings(tc: PortMappingTestCase) {
1361
        initialize_telemetry();
1362
        let wl = Workload {
1363
            uid: "cluster1//v1/Pod/default/ep_no_port_mapping".into(),
1364
            name: "ep_no_port_mapping".into(),
1365
            namespace: "default".into(),
1366
            workload_ips: vec![IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1))],
1367
            application_tunnel: tc.app_tunnel(),
1368
            ..test_helpers::test_default_workload()
1369
        };
1370
        let svc = Service {
1371
            name: "test-svc".into(),
1372
            hostname: "example.com".into(),
1373
            namespace: "default".into(),
1374
            vips: vec![NetworkAddress {
1375
                address: "10.0.0.1".parse().unwrap(),
1376
                network: "".into(),
1377
            }],
1378
            endpoints: EndpointSet::from_list([Endpoint {
1379
                workload_uid: "cluster1//v1/Pod/default/ep_no_port_mapping".into(),
1380
                port: tc.endpoint_mapping(),
1381
                status: HealthStatus::Healthy,
1382
            }]),
1383
            ports: tc.service_mapping(),
1384
            ..test_helpers::mock_default_service()
1385
        };
1386
1387
        let mut state = ProxyState::new(None);
1388
        state.workloads.insert(wl.clone().into());
1389
        state.services.insert(svc);
1390
1391
        let mode = match tc {
1392
            PortMappingTestCase::AppTunnel => ServiceResolutionMode::Waypoint,
1393
            _ => ServiceResolutionMode::Standard,
1394
        };
1395
1396
        let port = match state.find_upstream("".into(), &wl, "10.0.0.1:80".parse().unwrap(), mode) {
1397
            Some(UpstreamDestination::UpstreamParts(_, port, _)) => port,
1398
            _ => panic!("upstream to be found"),
1399
        };
1400
1401
        assert_eq!(port, tc.expected_port());
1402
    }
1403
1404
    fn create_workload(dest_uid: u8) -> Workload {
1405
        Workload {
1406
            name: "test".into(),
1407
            namespace: format!("ns{dest_uid}").into(),
1408
            trust_domain: "cluster.local".into(),
1409
            service_account: "defaultacct".into(),
1410
            workload_ips: vec![IpAddr::V4(Ipv4Addr::new(192, 168, 0, dest_uid))],
1411
            uid: format!("{dest_uid}").into(),
1412
            ..test_helpers::test_default_workload()
1413
        }
1414
    }
1415
1416
    /// Fetches the workload whose uid is the decimal form of `dest_uid`.
    /// Panics (via map indexing) when no such workload is in the store.
    fn get_workload(state: &DemandProxyState, dest_uid: u8) -> Arc<Workload> {
        let uid: Strng = dest_uid.to_string().into();
        let guard = state.read();
        guard.workloads.by_uid[&uid].clone()
    }
1420
1421
    fn get_rbac_context(
1422
        state: &DemandProxyState,
1423
        dest_uid: u8,
1424
        src_svc_acct: &str,
1425
    ) -> crate::state::ProxyRbacContext {
1426
        let key: Strng = format!("{dest_uid}").into();
1427
        let workload = &state.read().workloads.by_uid[&key];
1428
        crate::state::ProxyRbacContext {
1429
            conn: rbac::Connection {
1430
                src_identity: Some(Identity::Spiffe {
1431
                    trust_domain: "cluster.local".into(),
1432
                    namespace: "default".into(),
1433
                    service_account: src_svc_acct.to_string().into(),
1434
                }),
1435
                src: std::net::SocketAddr::V4(SocketAddrV4::new(
1436
                    Ipv4Addr::new(192, 168, 1, 1),
1437
                    1234,
1438
                )),
1439
                dst_network: "".into(),
1440
                dst: SocketAddr::new(workload.workload_ips[0], 8080),
1441
            },
1442
            dest_workload: get_workload(state, dest_uid),
1443
        }
1444
    }
1445
    fn create_state(state: ProxyState) -> DemandProxyState {
1446
        let mut registry = Registry::default();
1447
        let metrics = Arc::new(crate::proxy::Metrics::new(&mut registry));
1448
        DemandProxyState::new(
1449
            Arc::new(RwLock::new(state)),
1450
            None,
1451
            ResolverConfig::default(),
1452
            ResolverOpts::default(),
1453
            metrics,
1454
        )
1455
    }
1456
1457
    // test that we conform to https://istio.io/latest/docs/reference/config/security/authorization-policy/.
1458
    // We don't test #1 as ztunnel doesn't support custom policies.
1459
    // 1. If there are any CUSTOM policies that match the request, evaluate and deny the request if the evaluation result is deny.
1460
    // 2. If there are any DENY policies that match the request, deny the request.
1461
    // 3. If there are no ALLOW policies for the workload, allow the request.
1462
    // 4. If any of the ALLOW policies match the request, allow the request.
1463
    // 5. Deny the request.
1464
    #[tokio::test]
1465
    async fn assert_rbac_logic_deny_allow() {
1466
        let mut state = ProxyState::new(None);
1467
        state.workloads.insert(Arc::new(create_workload(1)));
1468
        state.workloads.insert(Arc::new(create_workload(2)));
1469
        state.policies.insert(
1470
            "allow".into(),
1471
            rbac::Authorization {
1472
                action: rbac::RbacAction::Allow,
1473
                namespace: "ns1".into(),
1474
                name: "foo".into(),
1475
                rules: vec![
1476
                    // rule1:
1477
                    vec![
1478
                        // from:
1479
                        vec![rbac::RbacMatch {
1480
                            principals: vec![StringMatch::Exact(
1481
                                "cluster.local/ns/default/sa/defaultacct".into(),
1482
                            )],
1483
                            ..Default::default()
1484
                        }],
1485
                    ],
1486
                ],
1487
                scope: rbac::RbacScope::Namespace,
1488
            },
1489
        );
1490
        state.policies.insert(
1491
            "deny".into(),
1492
            rbac::Authorization {
1493
                action: rbac::RbacAction::Deny,
1494
                namespace: "ns1".into(),
1495
                name: "deny".into(),
1496
                rules: vec![
1497
                    // rule1:
1498
                    vec![
1499
                        // from:
1500
                        vec![rbac::RbacMatch {
1501
                            principals: vec![StringMatch::Exact(
1502
                                "cluster.local/ns/default/sa/denyacct".into(),
1503
                            )],
1504
                            ..Default::default()
1505
                        }],
1506
                    ],
1507
                ],
1508
                scope: rbac::RbacScope::Namespace,
1509
            },
1510
        );
1511
1512
        let mock_proxy_state = create_state(state);
1513
1514
        // test workload in ns2. this should work as ns2 doesn't have any policies. this tests:
1515
        // 3. If there are no ALLOW policies for the workload, allow the request.
1516
        assert!(
1517
            mock_proxy_state
1518
                .assert_rbac(&get_rbac_context(&mock_proxy_state, 2, "not-defaultacct"))
1519
                .await
1520
                .is_ok()
1521
        );
1522
1523
        let ctx = get_rbac_context(&mock_proxy_state, 1, "defaultacct");
1524
        // 4. if any allow policies match, allow
1525
        assert!(mock_proxy_state.assert_rbac(&ctx).await.is_ok());
1526
1527
        {
1528
            // test a src workload with unknown svc account. this should fail as we have allow policies,
1529
            // but they don't match.
1530
            // 5. deny the request
1531
            let mut ctx = ctx.clone();
1532
            ctx.conn.src_identity = Some(Identity::Spiffe {
1533
                trust_domain: "cluster.local".into(),
1534
                namespace: "default".into(),
1535
                service_account: "not-defaultacct".into(),
1536
            });
1537
1538
            assert_eq!(
1539
                mock_proxy_state.assert_rbac(&ctx).await.err().unwrap(),
1540
                proxy::AuthorizationRejectionError::NotAllowed
1541
            );
1542
        }
1543
        {
1544
            let mut ctx = ctx.clone();
1545
            ctx.conn.src_identity = Some(Identity::Spiffe {
1546
                trust_domain: "cluster.local".into(),
1547
                namespace: "default".into(),
1548
                service_account: "denyacct".into(),
1549
            });
1550
1551
            // 2. If there are any DENY policies that match the request, deny the request.
1552
            assert_eq!(
1553
                mock_proxy_state.assert_rbac(&ctx).await.err().unwrap(),
1554
                proxy::AuthorizationRejectionError::ExplicitlyDenied("ns1".into(), "deny".into())
1555
            );
1556
        }
1557
    }
1558
1559
    #[tokio::test]
1560
    async fn assert_rbac_with_dest_workload_info() {
1561
        let mut state = ProxyState::new(None);
1562
        state.workloads.insert(Arc::new(create_workload(1)));
1563
1564
        let mock_proxy_state = create_state(state);
1565
1566
        let ctx = get_rbac_context(&mock_proxy_state, 1, "defaultacct");
1567
        assert!(mock_proxy_state.assert_rbac(&ctx).await.is_ok());
1568
    }
1569
1570
    #[tokio::test]
1571
    async fn test_load_balance() {
1572
        initialize_telemetry();
1573
        let mut state = ProxyState::new(None);
1574
        let wl_no_locality = Workload {
1575
            uid: "cluster1//v1/Pod/default/wl_no_locality".into(),
1576
            name: "wl_no_locality".into(),
1577
            namespace: "default".into(),
1578
            trust_domain: "cluster.local".into(),
1579
            service_account: "default".into(),
1580
            workload_ips: vec![IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1))],
1581
            ..test_helpers::test_default_workload()
1582
        };
1583
        let wl_match = Workload {
1584
            uid: "cluster1//v1/Pod/default/wl_match".into(),
1585
            name: "wl_match".into(),
1586
            namespace: "default".into(),
1587
            trust_domain: "cluster.local".into(),
1588
            service_account: "default".into(),
1589
            workload_ips: vec![IpAddr::V4(Ipv4Addr::new(192, 168, 0, 2))],
1590
            network: "network".into(),
1591
            locality: Locality {
1592
                region: "reg".into(),
1593
                zone: "zone".into(),
1594
                subzone: "".into(),
1595
            },
1596
            ..test_helpers::test_default_workload()
1597
        };
1598
        let wl_almost = Workload {
1599
            uid: "cluster1//v1/Pod/default/wl_almost".into(),
1600
            name: "wl_almost".into(),
1601
            namespace: "default".into(),
1602
            trust_domain: "cluster.local".into(),
1603
            service_account: "default".into(),
1604
            workload_ips: vec![IpAddr::V4(Ipv4Addr::new(192, 168, 0, 3))],
1605
            network: "network".into(),
1606
            locality: Locality {
1607
                region: "reg".into(),
1608
                zone: "not-zone".into(),
1609
                subzone: "".into(),
1610
            },
1611
            ..test_helpers::test_default_workload()
1612
        };
1613
        let wl_empty_ip = Workload {
1614
            uid: "cluster1//v1/Pod/default/wl_empty_ip".into(),
1615
            name: "wl_empty_ip".into(),
1616
            namespace: "default".into(),
1617
            trust_domain: "cluster.local".into(),
1618
            service_account: "default".into(),
1619
            workload_ips: vec![], // none!
1620
            network: "network".into(),
1621
            locality: Locality {
1622
                region: "reg".into(),
1623
                zone: "zone".into(),
1624
                subzone: "".into(),
1625
            },
1626
            ..test_helpers::test_default_workload()
1627
        };
1628
1629
        let _ep_almost = Workload {
1630
            uid: "cluster1//v1/Pod/default/ep_almost".into(),
1631
            name: "wl_almost".into(),
1632
            namespace: "default".into(),
1633
            trust_domain: "cluster.local".into(),
1634
            service_account: "default".into(),
1635
            workload_ips: vec![IpAddr::V4(Ipv4Addr::new(192, 168, 0, 4))],
1636
            network: "network".into(),
1637
            locality: Locality {
1638
                region: "reg".into(),
1639
                zone: "other-not-zone".into(),
1640
                subzone: "".into(),
1641
            },
1642
            ..test_helpers::test_default_workload()
1643
        };
1644
        let _ep_no_match = Workload {
1645
            uid: "cluster1//v1/Pod/default/ep_no_match".into(),
1646
            name: "wl_almost".into(),
1647
            namespace: "default".into(),
1648
            trust_domain: "cluster.local".into(),
1649
            service_account: "default".into(),
1650
            workload_ips: vec![IpAddr::V4(Ipv4Addr::new(192, 168, 0, 5))],
1651
            network: "not-network".into(),
1652
            locality: Locality {
1653
                region: "not-reg".into(),
1654
                zone: "unmatched-zone".into(),
1655
                subzone: "".into(),
1656
            },
1657
            ..test_helpers::test_default_workload()
1658
        };
1659
        let endpoints = EndpointSet::from_list([
1660
            Endpoint {
1661
                workload_uid: "cluster1//v1/Pod/default/ep_almost".into(),
1662
                port: HashMap::from([(80u16, 80u16)]),
1663
                status: HealthStatus::Healthy,
1664
            },
1665
            Endpoint {
1666
                workload_uid: "cluster1//v1/Pod/default/ep_no_match".into(),
1667
                port: HashMap::from([(80u16, 80u16)]),
1668
                status: HealthStatus::Healthy,
1669
            },
1670
            Endpoint {
1671
                workload_uid: "cluster1//v1/Pod/default/wl_match".into(),
1672
                port: HashMap::from([(80u16, 80u16)]),
1673
                status: HealthStatus::Healthy,
1674
            },
1675
            Endpoint {
1676
                workload_uid: "cluster1//v1/Pod/default/wl_empty_ip".into(),
1677
                port: HashMap::from([(80u16, 80u16)]),
1678
                status: HealthStatus::Healthy,
1679
            },
1680
        ]);
1681
        let strict_svc = Service {
1682
            endpoints: endpoints.clone(),
1683
            load_balancer: Some(LoadBalancer {
1684
                mode: LoadBalancerMode::Strict,
1685
                routing_preferences: vec![
1686
                    LoadBalancerScopes::Network,
1687
                    LoadBalancerScopes::Region,
1688
                    LoadBalancerScopes::Zone,
1689
                ],
1690
                health_policy: LoadBalancerHealthPolicy::OnlyHealthy,
1691
            }),
1692
            ports: HashMap::from([(80u16, 80u16)]),
1693
            ..test_helpers::mock_default_service()
1694
        };
1695
        let failover_svc = Service {
1696
            endpoints,
1697
            load_balancer: Some(LoadBalancer {
1698
                mode: LoadBalancerMode::Failover,
1699
                routing_preferences: vec![
1700
                    LoadBalancerScopes::Network,
1701
                    LoadBalancerScopes::Region,
1702
                    LoadBalancerScopes::Zone,
1703
                ],
1704
                health_policy: LoadBalancerHealthPolicy::OnlyHealthy,
1705
            }),
1706
            ports: HashMap::from([(80u16, 80u16)]),
1707
            ..test_helpers::mock_default_service()
1708
        };
1709
        state.workloads.insert(Arc::new(wl_no_locality.clone()));
1710
        state.workloads.insert(Arc::new(wl_match.clone()));
1711
        state.workloads.insert(Arc::new(wl_almost.clone()));
1712
        state.workloads.insert(Arc::new(wl_empty_ip.clone()));
1713
        state.services.insert(strict_svc.clone());
1714
        state.services.insert(failover_svc.clone());
1715
1716
        let assert_endpoint = |src: &Workload, svc: &Service, workloads: Vec<&str>, desc: &str| {
1717
            let got = state
1718
                .load_balance(src, svc, 80, ServiceResolutionMode::Standard)
1719
                .map(|(ep, _)| ep.workload_uid.to_string());
1720
            if workloads.is_empty() {
1721
                assert!(got.is_none(), "{}", desc);
1722
            } else {
1723
                let want: Vec<String> = workloads.iter().map(ToString::to_string).collect();
1724
                assert!(want.contains(&got.unwrap()), "{}", desc);
1725
            }
1726
        };
1727
        let assert_not_endpoint =
1728
            |src: &Workload, svc: &Service, uid: &str, tries: usize, desc: &str| {
1729
                for _ in 0..tries {
1730
                    let got = state
1731
                        .load_balance(src, svc, 80, ServiceResolutionMode::Standard)
1732
                        .map(|(ep, _)| ep.workload_uid.as_str());
1733
                    assert!(got != Some(uid), "{}", desc);
1734
                }
1735
            };
1736
1737
        assert_endpoint(
1738
            &wl_no_locality,
1739
            &strict_svc,
1740
            vec![],
1741
            "strict no match should not select",
1742
        );
1743
        assert_endpoint(
1744
            &wl_almost,
1745
            &strict_svc,
1746
            vec![],
1747
            "strict no match should not select",
1748
        );
1749
        assert_endpoint(
1750
            &wl_match,
1751
            &strict_svc,
1752
            vec!["cluster1//v1/Pod/default/wl_match"],
1753
            "strict match",
1754
        );
1755
1756
        assert_endpoint(
1757
            &wl_no_locality,
1758
            &failover_svc,
1759
            vec![
1760
                "cluster1//v1/Pod/default/ep_almost",
1761
                "cluster1//v1/Pod/default/ep_no_match",
1762
                "cluster1//v1/Pod/default/wl_match",
1763
            ],
1764
            "failover no match can select any endpoint",
1765
        );
1766
        assert_endpoint(
1767
            &wl_almost,
1768
            &failover_svc,
1769
            vec![
1770
                "cluster1//v1/Pod/default/ep_almost",
1771
                "cluster1//v1/Pod/default/wl_match",
1772
            ],
1773
            "failover almost match can select any close matches",
1774
        );
1775
        assert_endpoint(
1776
            &wl_match,
1777
            &failover_svc,
1778
            vec!["cluster1//v1/Pod/default/wl_match"],
1779
            "failover full match selects closest match",
1780
        );
1781
        assert_not_endpoint(
1782
            &wl_no_locality,
1783
            &failover_svc,
1784
            "cluster1//v1/Pod/default/wl_empty_ip",
1785
            10,
1786
            "failover no match can select any endpoint",
1787
        );
1788
    }
1789
}