Coverage Report

Created: 2026-04-14 06:46

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/ztunnel/src/state.rs
Line
Count
Source
1
// Copyright Istio Authors
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use crate::authpol_log;
16
use crate::identity::{Identity, SecretManager};
17
use crate::proxy::{Error, OnDemandDnsLabels};
18
use crate::rbac::Authorization;
19
use crate::state::policy::PolicyStore;
20
use crate::state::service::{
21
    Endpoint, IpFamily, LoadBalancerMode, LoadBalancerScopes, ServiceStore,
22
};
23
use crate::state::service::{Service, ServiceDescription};
24
use crate::state::workload::{
25
    GatewayAddress, NamespacedHostname, NetworkAddress, Workload, WorkloadStore, address::Address,
26
    gatewayaddress::Destination, network_addr,
27
};
28
use crate::strng::Strng;
29
use crate::tls;
30
use crate::xds::istio::security::Authorization as XdsAuthorization;
31
use crate::xds::istio::workload::Address as XdsAddress;
32
use crate::xds::{AdsClient, Demander, LocalClient, ProxyStateUpdater};
33
use crate::{cert_fetcher, config, rbac, xds};
34
use crate::{proxy, strng};
35
use educe::Educe;
36
use futures_util::FutureExt;
37
use hickory_resolver::TokioResolver;
38
use hickory_resolver::config::*;
39
use hickory_resolver::name_server::TokioConnectionProvider;
40
use itertools::Itertools;
41
use rand::prelude::IteratorRandom;
42
use rand::seq::IndexedRandom;
43
use serde::Serializer;
44
use std::collections::HashMap;
45
use std::convert::Into;
46
use std::default::Default;
47
use std::fmt;
48
use std::net::{IpAddr, SocketAddr};
49
use std::str::FromStr;
50
use std::sync::{Arc, RwLock, RwLockReadGuard};
51
use std::time::Duration;
52
use tracing::{debug, trace, warn};
53
54
use self::workload::ApplicationTunnel;
55
56
pub mod policy;
57
pub mod service;
58
pub mod workload;
59
60
#[derive(Debug, Eq, PartialEq, Clone)]
61
pub struct Upstream {
62
    /// Workload is the workload we are connecting to
63
    pub workload: Arc<Workload>,
64
    /// selected_workload_ip defines the IP address we should actually use to connect to this workload
65
    /// This handles multiple IPs (dual stack) or Hostname destinations (DNS resolution)
66
    /// The workload IP might be empty if we have to go through a network gateway.
67
    pub selected_workload_ip: Option<IpAddr>,
68
    /// Port is the port we should connect to
69
    pub port: u16,
70
    /// Service SANs defines SANs defined at the service level *only*. A complete view of things requires
71
    /// looking at workload.identity() as well.
72
    pub service_sans: Vec<Strng>,
73
    /// If this was from a service, the service info.
74
    pub destination_service: Option<ServiceDescription>,
75
}
76
77
#[derive(Clone, Debug, Eq, PartialEq)]
78
enum UpstreamDestination {
79
    UpstreamParts(Arc<Workload>, u16, Option<Arc<Service>>),
80
    OriginalDestination,
81
}
82
83
impl Upstream {
84
0
    pub fn workload_socket_addr(&self) -> Option<SocketAddr> {
85
0
        self.selected_workload_ip
86
0
            .map(|ip| SocketAddr::new(ip, self.port))
87
0
    }
88
0
    pub fn workload_and_services_san(&self) -> Vec<Identity> {
89
0
        self.service_sans
90
0
            .iter()
91
0
            .flat_map(|san| match Identity::from_str(san) {
92
0
                Ok(id) => Some(id),
93
0
                Err(err) => {
94
0
                    warn!("ignoring invalid SAN {}: {}", san, err);
95
0
                    None
96
                }
97
0
            })
98
0
            .chain(std::iter::once(self.workload.identity()))
99
0
            .collect()
100
0
    }
101
102
0
    pub fn service_sans(&self) -> Vec<Identity> {
103
0
        self.service_sans
104
0
            .iter()
105
0
            .flat_map(|san| match Identity::from_str(san) {
106
0
                Ok(id) => Some(id),
107
0
                Err(err) => {
108
0
                    warn!("ignoring invalid SAN {}: {}", san, err);
109
0
                    None
110
                }
111
0
            })
112
0
            .collect()
113
0
    }
114
}
115
116
// Workload information that a specific proxy instance represents. This is used to cross check
117
// with the workload fetched using destination address when making RBAC decisions.
118
#[derive(
119
    Debug, Clone, Eq, Hash, Ord, PartialEq, PartialOrd, serde::Serialize, serde::Deserialize,
120
)]
121
#[serde(rename_all = "camelCase")]
122
pub struct WorkloadInfo {
123
    pub name: String,
124
    pub namespace: String,
125
    pub service_account: String,
126
}
127
128
impl fmt::Display for WorkloadInfo {
129
0
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130
0
        write!(
131
0
            f,
132
0
            "{}.{} ({})",
133
            self.service_account, self.namespace, self.name
134
        )
135
0
    }
136
}
137
138
impl WorkloadInfo {
139
0
    pub fn new(name: String, namespace: String, service_account: String) -> Self {
140
0
        Self {
141
0
            name,
142
0
            namespace,
143
0
            service_account,
144
0
        }
145
0
    }
146
147
0
    pub fn matches(&self, w: &Workload) -> bool {
148
0
        self.name == w.name
149
0
            && self.namespace == w.namespace
150
0
            && self.service_account == w.service_account
151
0
    }
152
}
153
154
#[derive(Educe, Debug, Clone, Eq, serde::Serialize)]
155
#[educe(PartialEq, Hash)]
156
pub struct ProxyRbacContext {
157
    pub conn: rbac::Connection,
158
    #[educe(Hash(ignore), PartialEq(ignore))]
159
    pub dest_workload: Arc<Workload>,
160
}
161
162
impl ProxyRbacContext {
163
0
    pub fn into_conn(self) -> rbac::Connection {
164
0
        self.conn
165
0
    }
166
}
167
168
impl fmt::Display for ProxyRbacContext {
169
0
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
170
0
        write!(f, "{} ({})", self.conn, self.dest_workload.uid)?;
171
0
        Ok(())
172
0
    }
173
}
174
/// The current state information for this proxy.
175
#[derive(Debug)]
176
pub struct ProxyState {
177
    pub workloads: WorkloadStore,
178
179
    pub services: ServiceStore,
180
181
    pub policies: PolicyStore,
182
}
183
184
#[derive(serde::Serialize, Debug)]
185
#[serde(rename_all = "camelCase")]
186
struct ProxyStateSerialization<'a> {
187
    workloads: Vec<Arc<Workload>>,
188
    services: Vec<Arc<Service>>,
189
    policies: Vec<Authorization>,
190
    staged_services: &'a HashMap<NamespacedHostname, HashMap<Strng, Endpoint>>,
191
}
192
193
impl serde::Serialize for ProxyState {
194
0
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
195
0
    where
196
0
        S: Serializer,
197
    {
198
        // Services all have hostname, so use that as the key
199
0
        let services: Vec<_> = self
200
0
            .services
201
0
            .by_host
202
0
            .iter()
203
0
            .sorted_by_key(|k| k.0)
204
0
            .flat_map(|k| k.1)
205
0
            .cloned()
206
0
            .collect();
207
        // Workloads all have a UID, so use that as the key
208
0
        let workloads: Vec<_> = self
209
0
            .workloads
210
0
            .by_uid
211
0
            .iter()
212
0
            .sorted_by_key(|k| k.0)
213
0
            .map(|k| k.1)
214
0
            .cloned()
215
0
            .collect();
216
0
        let policies: Vec<_> = self
217
0
            .policies
218
0
            .by_key
219
0
            .iter()
220
0
            .sorted_by_key(|k| k.0)
221
0
            .map(|k| k.1)
222
0
            .cloned()
223
0
            .collect();
224
0
        let serializable = ProxyStateSerialization {
225
0
            workloads,
226
0
            services,
227
0
            policies,
228
0
            staged_services: &self.services.staged_services,
229
0
        };
230
0
        serializable.serialize(serializer)
231
0
    }
232
}
233
234
impl ProxyState {
235
0
    pub fn new(local_node: Option<Strng>) -> ProxyState {
236
0
        ProxyState {
237
0
            workloads: WorkloadStore::new(local_node),
238
0
            services: Default::default(),
239
0
            policies: Default::default(),
240
0
        }
241
0
    }
242
243
    /// Find either a workload or service by the destination.
244
    /// If `ns` is provided, prefer a service in that namespace when multiple share the same VIP.
245
0
    pub fn find_destination(&self, dest: &Destination, ns: Option<&Strng>) -> Option<Address> {
246
0
        match dest {
247
0
            Destination::Address(addr) => self.find_address(addr, ns),
248
0
            Destination::Hostname(hostname) => self.find_hostname(hostname),
249
        }
250
0
    }
251
252
    /// Find either a workload or a service by address.
253
    /// If `ns` is provided, prefer a service in that namespace when multiple share the same VIP.
254
0
    pub fn find_address(
255
0
        &self,
256
0
        network_addr: &NetworkAddress,
257
0
        ns: Option<&Strng>,
258
0
    ) -> Option<Address> {
259
        // 1. handle workload ip, if workload not found fallback to service.
260
0
        match self.workloads.find_address(network_addr) {
261
            None => {
262
                // 2. handle service
263
0
                if let Some(svc) = self.services.get_best_by_vip(network_addr, ns) {
264
0
                    return Some(Address::Service(svc));
265
0
                }
266
0
                None
267
            }
268
0
            Some(wl) => Some(Address::Workload(wl)),
269
        }
270
0
    }
271
272
    /// Find either a workload or a service by hostname.
273
0
    pub fn find_hostname(&self, name: &NamespacedHostname) -> Option<Address> {
274
        // Hostnames for services are more common, so lookup service first and fallback to workload.
275
0
        self.services
276
0
            .get_by_namespaced_host(name)
277
0
            .map(Address::Service)
278
0
            .or_else(|| {
279
                // Slow path: lookup workload by O(n) lookup. This is an uncommon path, so probably not worth
280
                // the memory cost to index currently
281
0
                self.workloads
282
0
                    .by_uid
283
0
                    .values()
284
0
                    .find(|w| w.hostname == name.hostname && w.namespace == name.namespace)
285
0
                    .cloned()
286
0
                    .map(Address::Workload)
287
0
            })
288
0
    }
289
290
    /// Find services by hostname.
291
0
    pub fn find_service_by_hostname(
292
0
        &self,
293
0
        hostname: &Strng,
294
0
        namespace: &Strng,
295
0
    ) -> Result<Arc<Service>, Error> {
296
        // Hostnames for services are more common, so lookup service first and fallback to workload.
297
0
        self.services
298
0
            .get_best_by_host(hostname, Some(namespace))
299
0
            .ok_or_else(|| Error::NoHostname(hostname.to_string()))
300
0
    }
301
302
0
    fn find_upstream(
303
0
        &self,
304
0
        network: Strng,
305
0
        source_workload: &Workload,
306
0
        addr: SocketAddr,
307
0
        resolution_mode: ServiceResolutionMode,
308
0
    ) -> Option<UpstreamDestination> {
309
0
        if let Some(svc) = self.services.get_best_by_vip(
310
0
            &network_addr(network.clone(), addr.ip()),
311
0
            Some(&source_workload.namespace),
312
0
        ) {
313
0
            if let Some(lb) = &svc.load_balancer
314
0
                && lb.mode == LoadBalancerMode::Passthrough
315
            {
316
0
                return Some(UpstreamDestination::OriginalDestination);
317
0
            }
318
0
            return self.find_upstream_from_service(
319
0
                source_workload,
320
0
                addr.port(),
321
0
                resolution_mode,
322
0
                svc,
323
            );
324
0
        }
325
0
        if let Some(wl) = self
326
0
            .workloads
327
0
            .find_address(&network_addr(network, addr.ip()))
328
        {
329
0
            return Some(UpstreamDestination::UpstreamParts(wl, addr.port(), None));
330
0
        }
331
0
        None
332
0
    }
333
334
0
    fn find_upstream_from_service(
335
0
        &self,
336
0
        source_workload: &Workload,
337
0
        svc_port: u16,
338
0
        resolution_mode: ServiceResolutionMode,
339
0
        svc: Arc<Service>,
340
0
    ) -> Option<UpstreamDestination> {
341
        // Randomly pick an upstream
342
        // TODO: do this more efficiently, and not just randomly
343
0
        let Some((ep, wl)) = self.load_balance(source_workload, &svc, svc_port, resolution_mode)
344
        else {
345
0
            debug!("Service {} has no healthy endpoints", svc.hostname);
346
0
            return None;
347
        };
348
349
0
        let svc_target_port = svc.ports.get(&svc_port).copied().unwrap_or_default();
350
0
        let target_port = if let Some(&ep_target_port) = ep.port.get(&svc_port) {
351
            // prefer endpoint port mapping
352
0
            ep_target_port
353
0
        } else if svc_target_port > 0 {
354
            // otherwise, see if the service has this port
355
0
            svc_target_port
356
0
        } else if let Some(ApplicationTunnel { port: Some(_), .. }) = &wl.application_tunnel {
357
            // when using app tunnel, we don't require the port to be found on the service
358
0
            svc_port
359
        } else {
360
            // no app tunnel or port mapping, error
361
0
            debug!(
362
0
                "found service {}, but port {} was unknown",
363
0
                svc.hostname, svc_port
364
            );
365
0
            return None;
366
        };
367
368
0
        Some(UpstreamDestination::UpstreamParts(
369
0
            wl,
370
0
            target_port,
371
0
            Some(svc),
372
0
        ))
373
0
    }
374
375
0
    fn load_balance<'a>(
376
0
        &self,
377
0
        src: &Workload,
378
0
        svc: &'a Service,
379
0
        svc_port: u16,
380
0
        resolution_mode: ServiceResolutionMode,
381
0
    ) -> Option<(&'a Endpoint, Arc<Workload>)> {
382
0
        let target_port = svc.ports.get(&svc_port).copied();
383
384
0
        if resolution_mode == ServiceResolutionMode::Standard && target_port.is_none() {
385
            // Port doesn't exist on the service at all, this is invalid
386
0
            debug!("service {} does not have port {}", svc.hostname, svc_port);
387
0
            return None;
388
0
        };
389
390
0
        let endpoints = svc.endpoints.iter().filter_map(|ep| {
391
0
            let Some(wl) = self.workloads.find_uid(&ep.workload_uid) else {
392
0
                debug!("failed to fetch workload for {}", ep.workload_uid);
393
0
                return None;
394
            };
395
396
0
            let in_network = wl.network == src.network;
397
0
            let has_network_gateway = wl.network_gateway.is_some();
398
0
            let has_address = !wl.workload_ips.is_empty() || !wl.hostname.is_empty();
399
0
            if !has_address {
400
                // Workload has no IP. We can only reach it via a network gateway
401
                // WDS is client-agnostic, so we will get a network gateway for a workload
402
                // even if it's in the same network; we should never use it.
403
0
                if in_network || !has_network_gateway {
404
0
                    return None;
405
0
                }
406
0
            }
407
408
0
            match resolution_mode {
409
                ServiceResolutionMode::Standard => {
410
0
                    if target_port.unwrap_or_default() == 0 && !ep.port.contains_key(&svc_port) {
411
                        // Filter workload out, it doesn't have a matching port
412
0
                        trace!(
413
0
                            "filter endpoint {}, it does not have service port {}",
414
                            ep.workload_uid, svc_port
415
                        );
416
0
                        return None;
417
0
                    }
418
                }
419
                ServiceResolutionMode::Waypoint => {
420
0
                    if target_port.is_none() && wl.application_tunnel.is_none() {
421
                        // We ignore this for app_tunnel; in this case, the port does not need to be on the service.
422
                        // This is only valid for waypoints, which are not explicitly addressed by users.
423
                        // We do happen to do a lookup by `waypoint-svc:15008`, this is not a literal call on that service;
424
                        // the port is not required at all if they have application tunnel, as it will be handled by ztunnel on the other end.
425
0
                        trace!(
426
0
                            "filter waypoint endpoint {}, target port is not defined",
427
                            ep.workload_uid
428
                        );
429
0
                        return None;
430
0
                    }
431
                }
432
            }
433
0
            Some((ep, wl))
434
0
        });
435
436
0
        let options = match svc.load_balancer {
437
0
            Some(ref lb) if lb.mode != LoadBalancerMode::Standard => {
438
0
                let ranks = endpoints
439
0
                    .filter_map(|(ep, wl)| {
440
                        // Load balancer will define N targets we want to match
441
                        // Consider [network, region, zone]
442
                        // Rank = 3 means we match all of them
443
                        // Rank = 2 means network and region match
444
                        // Rank = 0 means none match
445
0
                        let mut rank = 0;
446
0
                        for target in &lb.routing_preferences {
447
0
                            let matches = match target {
448
                                LoadBalancerScopes::Region => {
449
0
                                    src.locality.region == wl.locality.region
450
                                }
451
0
                                LoadBalancerScopes::Zone => src.locality.zone == wl.locality.zone,
452
                                LoadBalancerScopes::Subzone => {
453
0
                                    src.locality.subzone == wl.locality.subzone
454
                                }
455
0
                                LoadBalancerScopes::Node => src.node == wl.node,
456
0
                                LoadBalancerScopes::Cluster => src.cluster_id == wl.cluster_id,
457
0
                                LoadBalancerScopes::Network => src.network == wl.network,
458
                            };
459
0
                            if matches {
460
0
                                rank += 1;
461
0
                            } else {
462
0
                                break;
463
                            }
464
                        }
465
                        // Doesn't match all, and required to. Do not select this endpoint
466
0
                        if lb.mode == LoadBalancerMode::Strict
467
0
                            && rank != lb.routing_preferences.len()
468
                        {
469
0
                            return None;
470
0
                        }
471
0
                        Some((rank, ep, wl))
472
0
                    })
473
0
                    .collect::<Vec<_>>();
474
0
                let max = *ranks.iter().map(|(rank, _ep, _wl)| rank).max()?;
475
0
                let options: Vec<_> = ranks
476
0
                    .into_iter()
477
0
                    .filter(|(rank, _ep, _wl)| *rank == max)
478
0
                    .map(|(_, ep, wl)| (ep, wl))
479
0
                    .collect();
480
0
                options
481
            }
482
0
            _ => endpoints.collect(),
483
        };
484
0
        options
485
0
            .choose_weighted(&mut rand::rng(), |(_, wl)| wl.capacity as u64)
486
            // This can fail if there are no weights, the sum is zero (not possible in our API), or if it overflows
487
            // The API has u32 but we sum into an u64, so it would take ~4 billion entries of max weight to overflow
488
0
            .ok()
489
0
            .cloned()
490
0
    }
491
}
492
493
/// Wrapper around [ProxyState] that provides additional methods for requesting information
494
/// on-demand.
495
#[derive(serde::Serialize, Clone)]
496
pub struct DemandProxyState {
497
    #[serde(flatten)]
498
    state: Arc<RwLock<ProxyState>>,
499
500
    /// If present, used to request on-demand updates for workloads.
501
    #[serde(skip_serializing)]
502
    demand: Option<Demander>,
503
504
    #[serde(skip_serializing)]
505
    metrics: Arc<proxy::Metrics>,
506
507
    #[serde(skip_serializing)]
508
    dns_resolver: TokioResolver,
509
}
510
511
impl DemandProxyState {
512
0
    pub(crate) fn get_services_by_workload(&self, wl: &Workload) -> Vec<Arc<Service>> {
513
0
        self.state
514
0
            .read()
515
0
            .expect("mutex")
516
0
            .services
517
0
            .get_by_workload(wl)
518
0
    }
519
}
520
521
impl DemandProxyState {
522
0
    pub fn new(
523
0
        state: Arc<RwLock<ProxyState>>,
524
0
        demand: Option<Demander>,
525
0
        dns_resolver_cfg: ResolverConfig,
526
0
        dns_resolver_opts: ResolverOpts,
527
0
        metrics: Arc<proxy::Metrics>,
528
0
    ) -> Self {
529
0
        let mut rb = hickory_resolver::Resolver::builder_with_config(
530
0
            dns_resolver_cfg,
531
0
            TokioConnectionProvider::default(),
532
        );
533
0
        *rb.options_mut() = dns_resolver_opts;
534
0
        let dns_resolver = rb.build();
535
0
        Self {
536
0
            state,
537
0
            demand,
538
0
            dns_resolver,
539
0
            metrics,
540
0
        }
541
0
    }
542
543
0
    /// Acquires a shared read lock on the underlying [`ProxyState`].
    ///
    /// Keep the returned guard short-lived: it blocks writers while held.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned, i.e. a writer panicked while holding it.
    pub fn read(&self) -> RwLockReadGuard<'_, ProxyState> {
        // Same message as the other lock acquisition in this type, for consistency.
        self.state.read().expect("mutex")
    }
546
547
0
    pub async fn assert_rbac(
548
0
        &self,
549
0
        ctx: &ProxyRbacContext,
550
0
    ) -> Result<(), proxy::AuthorizationRejectionError> {
551
0
        let wl = &ctx.dest_workload;
552
0
        let conn = &ctx.conn;
553
0
        let state = self.read();
554
555
        // We can get policies from namespace, global, and workload...
556
0
        let ns = state.policies.get_by_namespace(&wl.namespace);
557
0
        let global = state.policies.get_by_namespace(&crate::strng::EMPTY);
558
0
        let workload = wl.authorization_policies.iter();
559
560
        // Aggregate all of them based on type
561
0
        let (all_allow, all_deny): (Vec<_>, Vec<_>) = ns
562
0
            .iter()
563
0
            .chain(global.iter())
564
0
            .chain(workload)
565
0
            .filter_map(|k| {
566
0
                let pol = state.policies.get(k);
567
                // Policy not found. This is probably transition state where the policy hasn't been sent
568
                // by the control plane, or it was just removed.
569
0
                if pol.is_none() {
570
0
                    warn!("skipping unknown policy {k}");
571
0
                }
572
0
                pol
573
0
            })
574
0
            .partition(|p| p.action == rbac::RbacAction::Allow);
575
576
0
        let (deny, deny_dry_run): (Vec<&Authorization>, Vec<&Authorization>) =
577
0
            all_deny.iter().partition(|p| !p.dry_run);
578
0
        let (allow, allow_dry_run): (Vec<&Authorization>, Vec<&Authorization>) =
579
0
            all_allow.iter().partition(|p| !p.dry_run);
580
581
0
        trace!(
582
0
            allow = allow.len(),
583
0
            deny = deny.len(),
584
0
            "checking connection"
585
        );
586
587
        // Allow and deny logic follows https://istio.io/latest/docs/reference/config/security/authorization-policy/
588
589
0
        for pol in deny_dry_run.iter() {
590
0
            if pol.matches(conn) {
591
0
                authpol_log!(policy = pol.to_key().as_str(), "dry-run: deny policy match");
592
0
            }
593
        }
594
        // "If there are any DENY policies that match the request, deny the request."
595
0
        for pol in deny.iter() {
596
0
            if pol.matches(conn) {
597
0
                authpol_log!(policy = pol.to_key().as_str(), "deny policy match");
598
0
                return Err(proxy::AuthorizationRejectionError::ExplicitlyDenied(
599
0
                    pol.namespace.to_owned(),
600
0
                    pol.name.to_owned(),
601
0
                ));
602
            } else {
603
0
                trace!(policy = pol.to_key().as_str(), "deny policy does not match");
604
            }
605
        }
606
0
        let mut dry_run_allow_matched = false;
607
0
        for pol in allow_dry_run.iter() {
608
0
            if pol.matches(conn) {
609
0
                dry_run_allow_matched = true;
610
0
                authpol_log!(
611
0
                    policy = pol.to_key().as_str(),
612
0
                    "dry-run: allow policy match"
613
                );
614
0
            }
615
        }
616
0
        if allow.is_empty() && !allow_dry_run.is_empty() && !dry_run_allow_matched {
617
            // this is going to be an allow, but the conn would be denied if dry-run policies
618
            // became enforced because none matched
619
0
            authpol_log!("dry-run: no allow policies match");
620
0
        }
621
        // "If there are no ALLOW policies for the workload, allow the request."
622
0
        if allow.is_empty() {
623
0
            authpol_log!("no allow policies, allow");
624
0
            return Ok(());
625
0
        }
626
        // "If any of the ALLOW policies match the request, allow the request."
627
0
        for pol in allow.iter() {
628
0
            if pol.matches(conn) {
629
0
                authpol_log!(policy = pol.to_key().as_str(), "allow policy match");
630
0
                return Ok(());
631
            } else {
632
0
                trace!(
633
0
                    policy = pol.to_key().as_str(),
634
0
                    "allow policy does not match"
635
                );
636
            }
637
        }
638
        // "Deny the request."
639
0
        authpol_log!("no allow policies matched");
640
0
        Err(proxy::AuthorizationRejectionError::NotAllowed)
641
0
    }
642
643
    // Select a workload IP, with DNS resolution if needed
644
0
    async fn pick_workload_destination_or_resolve(
645
0
        &self,
646
0
        dst_workload: &Workload,
647
0
        src_workload: &Workload,
648
0
        original_target_address: SocketAddr,
649
0
        ip_family_restriction: Option<IpFamily>,
650
0
    ) -> Result<Option<IpAddr>, Error> {
651
        // If the user requested the pod by a specific IP, use that directly.
652
0
        if dst_workload
653
0
            .workload_ips
654
0
            .contains(&original_target_address.ip())
655
        {
656
0
            return Ok(Some(original_target_address.ip()));
657
0
        }
658
        // They may have 1 or 2 IPs (single/dual stack)
659
        // Ensure we are meeting the Service family restriction (if any is defined).
660
        // Otherwise, prefer the same IP family as the original request.
661
0
        if let Some(ip) = dst_workload
662
0
            .workload_ips
663
0
            .iter()
664
0
            .filter(|ip| {
665
0
                ip_family_restriction
666
0
                    .map(|f| f.accepts_ip(**ip))
667
0
                    .unwrap_or(true)
668
0
            })
669
0
            .find_or_first(|ip| ip.is_ipv6() == original_target_address.is_ipv6())
670
        {
671
0
            return Ok(Some(*ip));
672
0
        }
673
0
        if dst_workload.hostname.is_empty() {
674
0
            if dst_workload.network_gateway.is_none() {
675
0
                debug!(
676
0
                    "workload {} has no suitable workload IPs for routing",
677
                    dst_workload.name
678
                );
679
0
                return Err(Error::NoValidDestination(Box::new(dst_workload.clone())));
680
            } else {
681
                // We can route through network gateway
682
0
                return Ok(None);
683
            }
684
0
        }
685
0
        let ip = Box::pin(self.resolve_workload_address(
686
0
            dst_workload,
687
0
            src_workload,
688
0
            original_target_address,
689
0
        ))
690
0
        .await?;
691
0
        Ok(Some(ip))
692
0
    }
693
694
0
    async fn resolve_workload_address(
695
0
        &self,
696
0
        workload: &Workload,
697
0
        src_workload: &Workload,
698
0
        original_target_address: SocketAddr,
699
0
    ) -> Result<IpAddr, Error> {
700
0
        let labels = OnDemandDnsLabels::new()
701
0
            .with_destination(workload)
702
0
            .with_source(src_workload);
703
0
        self.metrics
704
0
            .as_ref()
705
0
            .on_demand_dns
706
0
            .get_or_create(&labels)
707
0
            .inc();
708
0
        self.resolve_on_demand_dns(workload, original_target_address)
709
0
            .await
710
0
    }
711
712
0
    async fn resolve_on_demand_dns(
713
0
        &self,
714
0
        workload: &Workload,
715
0
        original_target_address: SocketAddr,
716
0
    ) -> Result<IpAddr, Error> {
717
0
        let workload_uid = workload.uid.clone();
718
0
        let hostname = workload.hostname.clone();
719
0
        trace!(%hostname, "starting DNS lookup");
720
721
0
        let resp = match self.dns_resolver.lookup_ip(hostname.as_str()).await {
722
0
            Err(err) => {
723
0
                warn!(?err,%hostname,"dns lookup failed");
724
0
                return Err(Error::NoResolvedAddresses(workload_uid.to_string()));
725
            }
726
0
            Ok(resp) => resp,
727
        };
728
0
        trace!(%hostname, "dns lookup complete {resp:?}");
729
730
0
        let (matching, unmatching): (Vec<_>, Vec<_>) = resp
731
0
            .as_lookup()
732
0
            .record_iter()
733
0
            .filter_map(|record| record.data().ip_addr())
734
0
            .partition(|record| record.is_ipv6() == original_target_address.is_ipv6());
735
        // Randomly pick an IP, prefer to match the IP family of the downstream request.
736
        // Without this, we run into trouble in pure v4 or pure v6 environments.
737
0
        matching
738
0
            .into_iter()
739
0
            .choose(&mut rand::rng())
740
0
            .or_else(|| unmatching.into_iter().choose(&mut rand::rng()))
741
0
            .ok_or_else(|| Error::EmptyResolvedAddresses(workload_uid.to_string()))
742
0
    }
743
744
    // same as fetch_workload, but if the caller knows the workload is enroute already,
745
    // will retry on cache miss for a configured amount of time - returning the workload
746
    // when we get it, or nothing if the timeout is exceeded, whichever happens first
747
0
    pub async fn wait_for_workload(
748
0
        &self,
749
0
        wl: &WorkloadInfo,
750
0
        deadline: Duration,
751
0
    ) -> Option<Arc<Workload>> {
752
0
        debug!(%wl, "wait for workload");
753
754
        // Take a watch listener *before* checking state (so we don't miss anything)
755
0
        let mut wl_sub = self.read().workloads.new_subscriber();
756
757
0
        debug!(%wl, "got sub, waiting for workload");
758
759
0
        if let Some(wl) = self.find_by_info(wl) {
760
0
            return Some(wl);
761
0
        }
762
763
        // We didn't find the workload we expected, so
764
        // loop until the subscriber wakes us on new workload,
765
        // or we hit the deadline timeout and give up
766
0
        let timeout = tokio::time::sleep(deadline);
767
0
        tokio::pin!(timeout);
768
        loop {
769
0
            tokio::select! {
770
0
                _ = &mut timeout => {
771
0
                    warn!("timed out waiting for workload '{wl}' from xds");
772
0
                    break None;
773
                },
774
0
                _ = wl_sub.changed() => {
775
0
                    if let Some(wl) = self.find_by_info(wl) {
776
0
                        break Some(wl);
777
0
                    }
778
                }
779
            }
780
        }
781
0
    }
782
783
    /// Finds the workload by workload information, as an arc.
784
    /// Note: this does not currently support on-demand.
785
0
    fn find_by_info(&self, wl: &WorkloadInfo) -> Option<Arc<Workload>> {
786
0
        self.read().workloads.find_by_info(wl)
787
0
    }
788
789
    // fetch_workload_by_address looks up a Workload by address.
790
    // Note this should never be used to lookup the local workload we are running, only the peer.
791
    // Since the peer connection may come through gateways, NAT, etc, this should only ever be treated
792
    // as a best-effort.
793
0
    pub async fn fetch_workload_by_address(&self, addr: &NetworkAddress) -> Option<Arc<Workload>> {
794
        // Wait for it on-demand, *if* needed
795
0
        debug!(%addr, "fetch workload");
796
0
        if let Some(wl) = self.read().workloads.find_address(addr) {
797
0
            return Some(wl);
798
0
        }
799
0
        if !self.supports_on_demand() {
800
0
            return None;
801
0
        }
802
0
        self.fetch_on_demand(addr.to_string().into()).await;
803
0
        self.read().workloads.find_address(addr)
804
0
    }
805
806
    // only support workload
807
0
    pub async fn fetch_workload_by_uid(&self, uid: &Strng) -> Option<Arc<Workload>> {
808
        // Wait for it on-demand, *if* needed
809
0
        debug!(%uid, "fetch workload");
810
0
        if let Some(wl) = self.read().workloads.find_uid(uid) {
811
0
            return Some(wl);
812
0
        }
813
0
        if !self.supports_on_demand() {
814
0
            return None;
815
0
        }
816
0
        self.fetch_on_demand(uid.clone()).await;
817
0
        self.read().workloads.find_uid(uid)
818
0
    }
819
820
0
    pub async fn fetch_upstream(
821
0
        &self,
822
0
        network: Strng,
823
0
        source_workload: &Workload,
824
0
        addr: SocketAddr,
825
0
        resolution_mode: ServiceResolutionMode,
826
0
    ) -> Result<Option<Upstream>, Error> {
827
0
        self.fetch_address(
828
0
            &network_addr(network.clone(), addr.ip()),
829
0
            Some(&source_workload.namespace),
830
0
        )
831
0
        .await;
832
0
        let upstream = {
833
0
            self.read()
834
0
                .find_upstream(network, source_workload, addr, resolution_mode)
835
            // Drop the lock
836
        };
837
0
        tracing::trace!(%addr, ?upstream, "fetch_upstream");
838
0
        self.finalize_upstream(source_workload, addr, upstream)
839
0
            .await
840
0
    }
841
842
0
    async fn finalize_upstream(
843
0
        &self,
844
0
        source_workload: &Workload,
845
0
        original_target_address: SocketAddr,
846
0
        upstream: Option<UpstreamDestination>,
847
0
    ) -> Result<Option<Upstream>, Error> {
848
0
        let (wl, port, svc) = match upstream {
849
0
            Some(UpstreamDestination::UpstreamParts(wl, port, svc)) => (wl, port, svc),
850
0
            None | Some(UpstreamDestination::OriginalDestination) => return Ok(None),
851
        };
852
0
        let svc_desc = svc.clone().map(|s| ServiceDescription::from(s.as_ref()));
853
0
        let ip_family_restriction = svc.as_ref().and_then(|s| s.ip_families);
854
0
        let selected_workload_ip = self
855
0
            .pick_workload_destination_or_resolve(
856
0
                &wl,
857
0
                source_workload,
858
0
                original_target_address,
859
0
                ip_family_restriction,
860
0
            )
861
0
            .await?; // if we can't load balance just return the error
862
0
        let res = Upstream {
863
0
            workload: wl,
864
0
            selected_workload_ip,
865
0
            port,
866
0
            service_sans: svc.map(|s| s.subject_alt_names.clone()).unwrap_or_default(),
867
0
            destination_service: svc_desc,
868
        };
869
0
        tracing::trace!(?res, "finalize_upstream");
870
0
        Ok(Some(res))
871
0
    }
872
873
    /// Returns destination address, upstream sans, and final sans, for
874
    /// connecting to a remote workload through a gateway.
875
    /// Would be nice to return this as an Upstream, but gateways don't necessarily
876
    /// have workloads. That is, they could just be IPs without a corresponding workload.
877
0
    pub async fn fetch_network_gateway(
878
0
        &self,
879
0
        gw_address: &GatewayAddress,
880
0
        source_workload: &Workload,
881
0
        original_destination_address: SocketAddr,
882
0
    ) -> Result<Upstream, Error> {
883
0
        let (res, target_address) = match &gw_address.destination {
884
0
            Destination::Address(ip) => {
885
0
                let addr = SocketAddr::new(ip.address, gw_address.hbone_mtls_port);
886
0
                let us = self.state.read().unwrap().find_upstream(
887
0
                    ip.network.clone(),
888
0
                    source_workload,
889
0
                    addr,
890
0
                    ServiceResolutionMode::Standard,
891
0
                );
892
                // If the workload references a network gateway by IP, use that IP as the destination.
893
                // Note this means that an IPv6 call may be translated to IPv4 if the network
894
                // gateway is specified as an IPv4 address.
895
                // For this reason, the Hostname method is preferred which can adapt to the callers IP family.
896
0
                (us, addr)
897
            }
898
0
            Destination::Hostname(host) => {
899
0
                let state = self.read();
900
0
                match state.find_hostname(host) {
901
0
                    Some(Address::Service(s)) => {
902
0
                        let us = state.find_upstream_from_service(
903
0
                            source_workload,
904
0
                            gw_address.hbone_mtls_port,
905
0
                            ServiceResolutionMode::Standard,
906
0
                            s,
907
0
                        );
908
                        // For hostname, use the original_destination_address as the target so we can
909
                        // adapt to the callers IP family.
910
0
                        (us, original_destination_address)
911
                    }
912
0
                    Some(Address::Workload(w)) => {
913
0
                        let us = Some(UpstreamDestination::UpstreamParts(
914
0
                            w,
915
0
                            gw_address.hbone_mtls_port,
916
0
                            None,
917
0
                        ));
918
0
                        (us, original_destination_address)
919
                    }
920
                    None => {
921
0
                        return Err(Error::UnknownNetworkGateway(format!(
922
0
                            "network gateway {} not found",
923
0
                            host.hostname
924
0
                        )));
925
                    }
926
                }
927
            }
928
        };
929
0
        self.finalize_upstream(source_workload, target_address, res)
930
0
            .await?
931
0
            .ok_or_else(|| {
932
0
                Error::UnknownNetworkGateway(format!("network gateway {gw_address:?} not found"))
933
0
            })
934
0
    }
935
936
0
    async fn fetch_waypoint(
937
0
        &self,
938
0
        gw_address: &GatewayAddress,
939
0
        source_workload: &Workload,
940
0
        original_destination_address: SocketAddr,
941
0
    ) -> Result<Upstream, Error> {
942
        // Waypoint can be referred to by an IP or Hostname.
943
        // Hostname is preferred as it is a more stable identifier.
944
0
        let (res, target_address) = match &gw_address.destination {
945
0
            Destination::Address(ip) => {
946
0
                let addr = SocketAddr::new(ip.address, gw_address.hbone_mtls_port);
947
0
                let us = self.read().find_upstream(
948
0
                    ip.network.clone(),
949
0
                    source_workload,
950
0
                    addr,
951
0
                    ServiceResolutionMode::Waypoint,
952
0
                );
953
                // If they referenced a waypoint by IP, use that IP as the destination.
954
                // Note this means that an IPv6 call may be translated to IPv4 if the waypoint is specified
955
                // as an IPv4 address.
956
                // For this reason, the Hostname method is preferred which can adapt to the callers IP family.
957
0
                (us, addr)
958
            }
959
0
            Destination::Hostname(host) => {
960
0
                let state = self.read();
961
0
                match state.find_hostname(host) {
962
0
                    Some(Address::Service(s)) => {
963
0
                        let us = state.find_upstream_from_service(
964
0
                            source_workload,
965
0
                            gw_address.hbone_mtls_port,
966
0
                            ServiceResolutionMode::Waypoint,
967
0
                            s,
968
0
                        );
969
                        // For hostname, use the original_destination_address as the target so we can
970
                        // adapt to the callers IP family.
971
0
                        (us, original_destination_address)
972
                    }
973
0
                    Some(Address::Workload(w)) => {
974
0
                        let us = Some(UpstreamDestination::UpstreamParts(
975
0
                            w,
976
0
                            gw_address.hbone_mtls_port,
977
0
                            None,
978
0
                        ));
979
0
                        (us, original_destination_address)
980
                    }
981
                    None => {
982
0
                        return Err(Error::UnknownWaypoint(format!(
983
0
                            "waypoint {} not found",
984
0
                            host.hostname
985
0
                        )));
986
                    }
987
                }
988
            }
989
        };
990
0
        self.finalize_upstream(source_workload, target_address, res)
991
0
            .await?
992
0
            .ok_or_else(|| Error::UnknownWaypoint(format!("waypoint {gw_address:?} not found")))
993
0
    }
994
995
0
    pub async fn fetch_service_waypoint(
996
0
        &self,
997
0
        service: &Service,
998
0
        source_workload: &Workload,
999
0
        original_destination_address: SocketAddr,
1000
0
    ) -> Result<Option<Upstream>, Error> {
1001
0
        let Some(gw_address) = &service.waypoint else {
1002
            // no waypoint
1003
0
            return Ok(None);
1004
        };
1005
0
        self.fetch_waypoint(gw_address, source_workload, original_destination_address)
1006
0
            .await
1007
0
            .map(Some)
1008
0
    }
1009
1010
0
    pub async fn fetch_workload_waypoint(
1011
0
        &self,
1012
0
        wl: &Workload,
1013
0
        source_workload: &Workload,
1014
0
        original_destination_address: SocketAddr,
1015
0
    ) -> Result<Option<Upstream>, Error> {
1016
0
        let Some(gw_address) = &wl.waypoint else {
1017
            // no waypoint
1018
0
            return Ok(None);
1019
        };
1020
0
        self.fetch_waypoint(gw_address, source_workload, original_destination_address)
1021
0
            .await
1022
0
            .map(Some)
1023
0
    }
1024
1025
    /// Looks for either a workload or service by the destination. If not found locally,
1026
    /// attempts to fetch on-demand.
1027
    /// If `ns` is provided, prefer a service in that namespace when multiple share the same VIP.
1028
0
    pub async fn fetch_destination(
1029
0
        &self,
1030
0
        dest: &Destination,
1031
0
        ns: Option<&Strng>,
1032
0
    ) -> Option<Address> {
1033
0
        match dest {
1034
0
            Destination::Address(addr) => self.fetch_address(addr, ns).await,
1035
0
            Destination::Hostname(hostname) => self.fetch_hostname(hostname).await,
1036
        }
1037
0
    }
1038
1039
    /// Looks for the given address to find either a workload or service by IP. If not found
1040
    /// locally, attempts to fetch on-demand.
1041
    /// If `ns` is provided, prefer a service in that namespace when multiple share the same VIP.
1042
0
    pub async fn fetch_address(
1043
0
        &self,
1044
0
        network_addr: &NetworkAddress,
1045
0
        ns: Option<&Strng>,
1046
0
    ) -> Option<Address> {
1047
        // Wait for it on-demand, *if* needed
1048
0
        debug!(%network_addr.address, "fetch address");
1049
0
        if let Some(address) = self.read().find_address(network_addr, ns) {
1050
0
            return Some(address);
1051
0
        }
1052
0
        if !self.supports_on_demand() {
1053
0
            return None;
1054
0
        }
1055
        // if both cache not found, start on demand fetch
1056
0
        self.fetch_on_demand(network_addr.to_string().into()).await;
1057
0
        self.read().find_address(network_addr, ns)
1058
0
    }
1059
1060
    /// Looks for the given hostname to find either a workload or service by IP. If not found
1061
    /// locally, attempts to fetch on-demand.
1062
0
    async fn fetch_hostname(&self, hostname: &NamespacedHostname) -> Option<Address> {
1063
        // Wait for it on-demand, *if* needed
1064
0
        debug!(%hostname, "fetch hostname");
1065
0
        if let Some(address) = self.read().find_hostname(hostname) {
1066
0
            return Some(address);
1067
0
        }
1068
0
        if !self.supports_on_demand() {
1069
0
            return None;
1070
0
        }
1071
        // if both cache not found, start on demand fetch
1072
0
        self.fetch_on_demand(hostname.to_string().into()).await;
1073
0
        self.read().find_hostname(hostname)
1074
0
    }
1075
1076
0
    pub fn supports_on_demand(&self) -> bool {
1077
0
        self.demand.is_some()
1078
0
    }
1079
1080
    /// fetch_on_demand looks up the provided key on-demand and waits for it to return
1081
0
    pub async fn fetch_on_demand(&self, key: Strng) {
1082
0
        if let Some(demand) = &self.demand {
1083
0
            debug!(%key, "sending demand request");
1084
0
            Box::pin(
1085
0
                demand
1086
0
                    .demand(xds::ADDRESS_TYPE, key.clone())
1087
0
                    .then(|o| o.recv()),
1088
            )
1089
0
            .await;
1090
0
            debug!(%key, "on demand ready");
1091
0
        }
1092
0
    }
1093
}
1094
1095
/// Tells upstream resolution what kind of destination is being resolved.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServiceResolutionMode {
    /// Resolving an ordinary service.
    Standard,
    /// Resolving a waypoint proxy.
    Waypoint,
}
1102
1103
#[derive(serde::Serialize)]
1104
pub struct ProxyStateManager {
1105
    #[serde(flatten)]
1106
    state: DemandProxyState,
1107
1108
    #[serde(skip_serializing)]
1109
    xds_client: Option<AdsClient>,
1110
}
1111
1112
impl ProxyStateManager {
1113
0
    pub async fn new(
1114
0
        config: Arc<config::Config>,
1115
0
        xds_metrics: xds::Metrics,
1116
0
        proxy_metrics: Arc<proxy::Metrics>,
1117
0
        awaiting_ready: tokio::sync::watch::Sender<()>,
1118
0
        cert_manager: Arc<SecretManager>,
1119
0
    ) -> anyhow::Result<ProxyStateManager> {
1120
0
        let cert_fetcher = cert_fetcher::new(&config, cert_manager);
1121
0
        let state: Arc<RwLock<ProxyState>> = Arc::new(RwLock::new(ProxyState::new(
1122
0
            config.local_node.as_ref().map(strng::new),
1123
        )));
1124
0
        let xds_client = if config.xds_address.is_some() {
1125
0
            let updater = ProxyStateUpdater::new(state.clone(), cert_fetcher.clone());
1126
0
            let tls_client_fetcher = Box::new(tls::ControlPlaneAuthentication::RootCert(
1127
0
                config.xds_root_cert.clone(),
1128
0
            ));
1129
0
            Some(
1130
0
                xds::Config::new(config.clone(), tls_client_fetcher)
1131
0
                    .with_watched_handler::<XdsAddress>(xds::ADDRESS_TYPE, updater.clone())
1132
0
                    .with_watched_handler::<XdsAuthorization>(xds::AUTHORIZATION_TYPE, updater)
1133
0
                    .build(xds_metrics, awaiting_ready),
1134
0
            )
1135
        } else {
1136
0
            None
1137
        };
1138
0
        if let Some(cfg) = &config.local_xds_config {
1139
0
            let local_client = LocalClient {
1140
0
                local_node: config.local_node.as_ref().map(strng::new),
1141
0
                cfg: cfg.clone(),
1142
0
                state: state.clone(),
1143
0
                cert_fetcher,
1144
0
            };
1145
0
            local_client.run().await?;
1146
0
        }
1147
0
        let demand = xds_client.as_ref().and_then(AdsClient::demander);
1148
0
        Ok(ProxyStateManager {
1149
0
            xds_client,
1150
0
            state: DemandProxyState::new(
1151
0
                state,
1152
0
                demand,
1153
0
                config.dns_resolver_cfg.clone(),
1154
0
                config.dns_resolver_opts.clone(),
1155
0
                proxy_metrics,
1156
0
            ),
1157
0
        })
1158
0
    }
1159
1160
0
    pub fn state(&self) -> DemandProxyState {
1161
0
        self.state.clone()
1162
0
    }
1163
1164
0
    pub async fn run(self) -> anyhow::Result<()> {
1165
0
        match self.xds_client {
1166
0
            Some(xds) => xds.run().await.map_err(|e| anyhow::anyhow!(e)),
1167
0
            None => Ok(()),
1168
        }
1169
0
    }
1170
}
1171
1172
#[cfg(test)]
1173
mod tests {
1174
    use crate::state::service::{EndpointSet, LoadBalancer, LoadBalancerHealthPolicy};
1175
    use crate::state::workload::{HealthStatus, Locality};
1176
    use prometheus_client::registry::Registry;
1177
    use rbac::StringMatch;
1178
    use std::{net::Ipv4Addr, net::SocketAddrV4, time::Duration};
1179
1180
    use self::workload::{ApplicationTunnel, application_tunnel::Protocol as AppProtocol};
1181
1182
    use super::*;
1183
    use crate::test_helpers::helpers::initialize_telemetry;
1184
1185
    use crate::{strng, test_helpers};
1186
    use test_case::test_case;
1187
1188
    #[tokio::test]
1189
    async fn test_wait_for_workload() {
1190
        let mut state = ProxyState::new(None);
1191
        let delayed_wl = Arc::new(test_helpers::test_default_workload());
1192
        state.workloads.insert(delayed_wl.clone());
1193
1194
        let mut registry = Registry::default();
1195
        let metrics = Arc::new(crate::proxy::Metrics::new(&mut registry));
1196
        let mock_proxy_state = DemandProxyState::new(
1197
            Arc::new(RwLock::new(state)),
1198
            None,
1199
            ResolverConfig::default(),
1200
            ResolverOpts::default(),
1201
            metrics,
1202
        );
1203
1204
        let want = WorkloadInfo {
1205
            name: delayed_wl.name.to_string(),
1206
            namespace: delayed_wl.namespace.to_string(),
1207
            service_account: delayed_wl.service_account.to_string(),
1208
        };
1209
1210
        test_helpers::assert_eventually(
1211
            Duration::from_secs(1),
1212
            || mock_proxy_state.wait_for_workload(&want, Duration::from_millis(50)),
1213
            Some(delayed_wl),
1214
        )
1215
        .await;
1216
    }
1217
1218
    #[tokio::test]
1219
    async fn test_wait_for_workload_delay_fails() {
1220
        let state = ProxyState::new(None);
1221
1222
        let mut registry = Registry::default();
1223
        let metrics = Arc::new(crate::proxy::Metrics::new(&mut registry));
1224
        let mock_proxy_state = DemandProxyState::new(
1225
            Arc::new(RwLock::new(state)),
1226
            None,
1227
            ResolverConfig::default(),
1228
            ResolverOpts::default(),
1229
            metrics,
1230
        );
1231
1232
        let want = WorkloadInfo {
1233
            name: "fake".to_string(),
1234
            namespace: "fake".to_string(),
1235
            service_account: "fake".to_string(),
1236
        };
1237
1238
        test_helpers::assert_eventually(
1239
            Duration::from_millis(10),
1240
            || mock_proxy_state.wait_for_workload(&want, Duration::from_millis(5)),
1241
            None,
1242
        )
1243
        .await;
1244
    }
1245
1246
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1247
    async fn test_wait_for_workload_eventually() {
1248
        initialize_telemetry();
1249
        let state = ProxyState::new(None);
1250
        let wrap_state = Arc::new(RwLock::new(state));
1251
        let not_delayed_wl = Arc::new(Workload {
1252
            workload_ips: vec!["1.2.3.4".parse().unwrap()],
1253
            uid: "uid".into(),
1254
            name: "n".into(),
1255
            namespace: "ns".into(),
1256
            ..test_helpers::test_default_workload()
1257
        });
1258
        let delayed_wl = Arc::new(test_helpers::test_default_workload());
1259
1260
        let mut registry = Registry::default();
1261
        let metrics = Arc::new(crate::proxy::Metrics::new(&mut registry));
1262
        let mock_proxy_state = DemandProxyState::new(
1263
            wrap_state.clone(),
1264
            None,
1265
            ResolverConfig::default(),
1266
            ResolverOpts::default(),
1267
            metrics,
1268
        );
1269
1270
        // Some from Address
1271
        let want = WorkloadInfo {
1272
            name: delayed_wl.name.to_string(),
1273
            namespace: delayed_wl.namespace.to_string(),
1274
            service_account: delayed_wl.service_account.to_string(),
1275
        };
1276
1277
        let expected_wl = delayed_wl.clone();
1278
        let t = tokio::spawn(async move {
1279
            test_helpers::assert_eventually(
1280
                Duration::from_millis(500),
1281
                || mock_proxy_state.wait_for_workload(&want, Duration::from_millis(250)),
1282
                Some(expected_wl),
1283
            )
1284
            .await;
1285
        });
1286
        // Send the wrong workload through
1287
        wrap_state.write().unwrap().workloads.insert(not_delayed_wl);
1288
        tokio::time::sleep(Duration::from_millis(100)).await;
1289
        // Send the correct workload through
1290
        wrap_state.write().unwrap().workloads.insert(delayed_wl);
1291
        t.await.expect("should not fail");
1292
    }
1293
1294
    #[tokio::test]
1295
    async fn lookup_address() {
1296
        let mut state = ProxyState::new(None);
1297
        state
1298
            .workloads
1299
            .insert(Arc::new(test_helpers::test_default_workload()));
1300
        state.services.insert(test_helpers::mock_default_service());
1301
1302
        let mut registry = Registry::default();
1303
        let metrics = Arc::new(crate::proxy::Metrics::new(&mut registry));
1304
        let mock_proxy_state = DemandProxyState::new(
1305
            Arc::new(RwLock::new(state)),
1306
            None,
1307
            ResolverConfig::default(),
1308
            ResolverOpts::default(),
1309
            metrics,
1310
        );
1311
1312
        // Some from Address
1313
        let dst = Destination::Address(NetworkAddress {
1314
            network: strng::EMPTY,
1315
            address: IpAddr::V4(Ipv4Addr::LOCALHOST),
1316
        });
1317
        test_helpers::assert_eventually(
1318
            Duration::from_secs(5),
1319
            || mock_proxy_state.fetch_destination(&dst, None),
1320
            Some(Address::Workload(Arc::new(
1321
                test_helpers::test_default_workload(),
1322
            ))),
1323
        )
1324
        .await;
1325
1326
        // Some from Hostname
1327
        let dst = Destination::Hostname(NamespacedHostname {
1328
            namespace: "default".into(),
1329
            hostname: "defaulthost".into(),
1330
        });
1331
        test_helpers::assert_eventually(
1332
            Duration::from_secs(5),
1333
            || mock_proxy_state.fetch_destination(&dst, None),
1334
            Some(Address::Service(Arc::new(
1335
                test_helpers::mock_default_service(),
1336
            ))),
1337
        )
1338
        .await;
1339
1340
        // None from Address
1341
        let dst = Destination::Address(NetworkAddress {
1342
            network: "".into(),
1343
            address: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)),
1344
        });
1345
        test_helpers::assert_eventually(
1346
            Duration::from_secs(5),
1347
            || mock_proxy_state.fetch_destination(&dst, None),
1348
            None,
1349
        )
1350
        .await;
1351
1352
        // None from Hostname
1353
        let dst = Destination::Hostname(NamespacedHostname {
1354
            namespace: "default".into(),
1355
            hostname: "nothost".into(),
1356
        });
1357
        test_helpers::assert_eventually(
1358
            Duration::from_secs(5),
1359
            || mock_proxy_state.fetch_destination(&dst, None),
1360
            None,
1361
        )
1362
        .await;
1363
    }
1364
1365
    /// Which port-mapping mechanism a `find_upstream_port_mappings` case exercises.
    enum PortMappingTestCase {
        /// Port is remapped by the endpoint.
        EndpointMapping,
        /// Port is remapped by the service.
        ServiceMapping,
        /// Workload uses an application tunnel.
        AppTunnel,
    }
1370
1371
    impl PortMappingTestCase {
1372
        fn service_mapping(&self) -> HashMap<u16, u16> {
1373
            if let PortMappingTestCase::ServiceMapping = self {
1374
                return HashMap::from([(80, 8080)]);
1375
            }
1376
            HashMap::from([(80, 0)])
1377
        }
1378
1379
        fn endpoint_mapping(&self) -> HashMap<u16, u16> {
1380
            if let PortMappingTestCase::EndpointMapping = self {
1381
                return HashMap::from([(80, 9090)]);
1382
            }
1383
            HashMap::from([])
1384
        }
1385
1386
        fn app_tunnel(&self) -> Option<ApplicationTunnel> {
1387
            if let PortMappingTestCase::AppTunnel = self {
1388
                return Some(ApplicationTunnel {
1389
                    protocol: AppProtocol::PROXY,
1390
                    port: Some(15088),
1391
                });
1392
            }
1393
            None
1394
        }
1395
1396
        fn expected_port(&self) -> u16 {
1397
            match self {
1398
                PortMappingTestCase::ServiceMapping => 8080,
1399
                PortMappingTestCase::EndpointMapping => 9090,
1400
                _ => 80,
1401
            }
1402
        }
1403
    }
1404
1405
    #[test_case(PortMappingTestCase::EndpointMapping; "ep mapping")]
1406
    #[test_case(PortMappingTestCase::ServiceMapping; "svc mapping")]
1407
    #[test_case(PortMappingTestCase::AppTunnel; "app tunnel")]
1408
    #[tokio::test]
1409
    async fn find_upstream_port_mappings(tc: PortMappingTestCase) {
1410
        initialize_telemetry();
1411
        let wl = Workload {
1412
            uid: "cluster1//v1/Pod/default/ep_no_port_mapping".into(),
1413
            name: "ep_no_port_mapping".into(),
1414
            namespace: "default".into(),
1415
            workload_ips: vec![IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1))],
1416
            application_tunnel: tc.app_tunnel(),
1417
            ..test_helpers::test_default_workload()
1418
        };
1419
        let svc = Service {
1420
            name: "test-svc".into(),
1421
            hostname: "example.com".into(),
1422
            namespace: "default".into(),
1423
            vips: vec![NetworkAddress {
1424
                address: "10.0.0.1".parse().unwrap(),
1425
                network: "".into(),
1426
            }],
1427
            endpoints: EndpointSet::from_list([Endpoint {
1428
                workload_uid: "cluster1//v1/Pod/default/ep_no_port_mapping".into(),
1429
                port: tc.endpoint_mapping(),
1430
                status: HealthStatus::Healthy,
1431
            }]),
1432
            ports: tc.service_mapping(),
1433
            ..test_helpers::mock_default_service()
1434
        };
1435
1436
        let mut state = ProxyState::new(None);
1437
        state.workloads.insert(wl.clone().into());
1438
        state.services.insert(svc);
1439
1440
        let mode = match tc {
1441
            PortMappingTestCase::AppTunnel => ServiceResolutionMode::Waypoint,
1442
            _ => ServiceResolutionMode::Standard,
1443
        };
1444
1445
        let port = match state.find_upstream("".into(), &wl, "10.0.0.1:80".parse().unwrap(), mode) {
1446
            Some(UpstreamDestination::UpstreamParts(_, port, _)) => port,
1447
            _ => panic!("upstream to be found"),
1448
        };
1449
1450
        assert_eq!(port, tc.expected_port());
1451
    }
1452
1453
    fn create_workload(dest_uid: u8) -> Workload {
1454
        Workload {
1455
            name: "test".into(),
1456
            namespace: format!("ns{dest_uid}").into(),
1457
            trust_domain: "cluster.local".into(),
1458
            service_account: "defaultacct".into(),
1459
            workload_ips: vec![IpAddr::V4(Ipv4Addr::new(192, 168, 0, dest_uid))],
1460
            uid: format!("{dest_uid}").into(),
1461
            ..test_helpers::test_default_workload()
1462
        }
1463
    }
1464
1465
    /// Returns the workload stored under UID `dest_uid`; panics if it is missing.
    fn get_workload(state: &DemandProxyState, dest_uid: u8) -> Arc<Workload> {
        let uid: Strng = dest_uid.to_string().into();
        let guard = state.read();
        Arc::clone(&guard.workloads.by_uid[&uid])
    }
1469
1470
    fn get_rbac_context(
1471
        state: &DemandProxyState,
1472
        dest_uid: u8,
1473
        src_svc_acct: &str,
1474
    ) -> crate::state::ProxyRbacContext {
1475
        let key: Strng = format!("{dest_uid}").into();
1476
        let workload = &state.read().workloads.by_uid[&key];
1477
        crate::state::ProxyRbacContext {
1478
            conn: rbac::Connection {
1479
                src_identity: Some(Identity::Spiffe {
1480
                    trust_domain: "cluster.local".into(),
1481
                    namespace: "default".into(),
1482
                    service_account: src_svc_acct.to_string().into(),
1483
                }),
1484
                src: std::net::SocketAddr::V4(SocketAddrV4::new(
1485
                    Ipv4Addr::new(192, 168, 1, 1),
1486
                    1234,
1487
                )),
1488
                dst_network: "".into(),
1489
                dst: SocketAddr::new(workload.workload_ips[0], 8080),
1490
            },
1491
            dest_workload: get_workload(state, dest_uid),
1492
        }
1493
    }
1494
    fn create_state(state: ProxyState) -> DemandProxyState {
1495
        let mut registry = Registry::default();
1496
        let metrics = Arc::new(crate::proxy::Metrics::new(&mut registry));
1497
        DemandProxyState::new(
1498
            Arc::new(RwLock::new(state)),
1499
            None,
1500
            ResolverConfig::default(),
1501
            ResolverOpts::default(),
1502
            metrics,
1503
        )
1504
    }
1505
1506
    /// Builds a dry-run, namespace-scoped policy named "wildcard" in `ns1` with the given
    /// `action`. It has a single empty rule, intended as a wildcard match.
    fn create_dry_run_wildcard_rbac_policy(action: rbac::RbacAction) -> rbac::Authorization {
        let empty_rule = vec![];
        rbac::Authorization {
            name: "wildcard".into(),
            namespace: "ns1".into(),
            scope: rbac::RbacScope::Namespace,
            action,
            rules: vec![empty_rule],
            dry_run: true,
        }
    }
1516
1517
    // test that we confirm with https://istio.io/latest/docs/reference/config/security/authorization-policy/.
1518
    // We don't test #1 as ztunnel doesn't support custom policies.
1519
    // 1. If there are any CUSTOM policies that match the request, evaluate and deny the request if the evaluation result is deny.
1520
    // 2. If there are any DENY policies that match the request, deny the request.
1521
    // 3. If there are no ALLOW policies for the workload, allow the request.
1522
    // 4. If any of the ALLOW policies match the request, allow the request.
1523
    // 5. Deny the request.
1524
    #[tokio::test]
1525
    async fn assert_rbac_logic_deny_allow() {
1526
        let mut state = ProxyState::new(None);
1527
        state.workloads.insert(Arc::new(create_workload(1)));
1528
        state.workloads.insert(Arc::new(create_workload(2)));
1529
        // Dry run policies should have no effect.
1530
        state.policies.insert(
1531
            "wildcard-allow".into(),
1532
            create_dry_run_wildcard_rbac_policy(rbac::RbacAction::Allow),
1533
        );
1534
        state.policies.insert(
1535
            "wildcard-deny".into(),
1536
            create_dry_run_wildcard_rbac_policy(rbac::RbacAction::Deny),
1537
        );
1538
        state.policies.insert(
1539
            "allow".into(),
1540
            rbac::Authorization {
1541
                action: rbac::RbacAction::Allow,
1542
                namespace: "ns1".into(),
1543
                name: "foo".into(),
1544
                rules: vec![
1545
                    // rule1:
1546
                    vec![
1547
                        // from:
1548
                        vec![rbac::RbacMatch {
1549
                            principals: vec![StringMatch::Exact(
1550
                                "cluster.local/ns/default/sa/defaultacct".into(),
1551
                            )],
1552
                            ..Default::default()
1553
                        }],
1554
                    ],
1555
                ],
1556
                scope: rbac::RbacScope::Namespace,
1557
                dry_run: false,
1558
            },
1559
        );
1560
        state.policies.insert(
1561
            "deny".into(),
1562
            rbac::Authorization {
1563
                action: rbac::RbacAction::Deny,
1564
                namespace: "ns1".into(),
1565
                name: "deny".into(),
1566
                rules: vec![
1567
                    // rule1:
1568
                    vec![
1569
                        // from:
1570
                        vec![rbac::RbacMatch {
1571
                            principals: vec![StringMatch::Exact(
1572
                                "cluster.local/ns/default/sa/denyacct".into(),
1573
                            )],
1574
                            ..Default::default()
1575
                        }],
1576
                    ],
1577
                ],
1578
                scope: rbac::RbacScope::Namespace,
1579
                dry_run: false,
1580
            },
1581
        );
1582
1583
        let mock_proxy_state = create_state(state);
1584
1585
        // test workload in ns2. this should work as ns2 doesn't have any policies. this tests:
1586
        // 3. If there are no ALLOW policies for the workload, allow the request.
1587
        assert!(
1588
            mock_proxy_state
1589
                .assert_rbac(&get_rbac_context(&mock_proxy_state, 2, "not-defaultacct"))
1590
                .await
1591
                .is_ok()
1592
        );
1593
1594
        let ctx = get_rbac_context(&mock_proxy_state, 1, "defaultacct");
1595
        // 4. if any allow policies match, allow
1596
        assert!(mock_proxy_state.assert_rbac(&ctx).await.is_ok());
1597
1598
        {
1599
            // test a src workload with unknown svc account. this should fail as we have allow policies,
1600
            // but they don't match.
1601
            // 5. deny the request
1602
            let mut ctx = ctx.clone();
1603
            ctx.conn.src_identity = Some(Identity::Spiffe {
1604
                trust_domain: "cluster.local".into(),
1605
                namespace: "default".into(),
1606
                service_account: "not-defaultacct".into(),
1607
            });
1608
1609
            assert_eq!(
1610
                mock_proxy_state.assert_rbac(&ctx).await.err().unwrap(),
1611
                proxy::AuthorizationRejectionError::NotAllowed
1612
            );
1613
        }
1614
        {
1615
            let mut ctx = ctx.clone();
1616
            ctx.conn.src_identity = Some(Identity::Spiffe {
1617
                trust_domain: "cluster.local".into(),
1618
                namespace: "default".into(),
1619
                service_account: "denyacct".into(),
1620
            });
1621
1622
            // 2. If there are any DENY policies that match the request, deny the request.
1623
            assert_eq!(
1624
                mock_proxy_state.assert_rbac(&ctx).await.err().unwrap(),
1625
                proxy::AuthorizationRejectionError::ExplicitlyDenied("ns1".into(), "deny".into())
1626
            );
1627
        }
1628
    }
1629
1630
    #[tokio::test]
1631
    async fn assert_rbac_with_dest_workload_info() {
1632
        let mut state = ProxyState::new(None);
1633
        state.workloads.insert(Arc::new(create_workload(1)));
1634
1635
        let mock_proxy_state = create_state(state);
1636
1637
        let ctx = get_rbac_context(&mock_proxy_state, 1, "defaultacct");
1638
        assert!(mock_proxy_state.assert_rbac(&ctx).await.is_ok());
1639
    }
1640
1641
    #[tokio::test]
1642
    async fn assert_rbac_dry_run_with_real_policies() {
1643
        initialize_telemetry();
1644
        crate::telemetry::set_level(true, "debug").ok();
1645
1646
        let mut state = ProxyState::new(None);
1647
        state.workloads.insert(Arc::new(create_workload(1)));
1648
1649
        // Real deny policy that matches denyacct
1650
        state.policies.insert(
1651
            "real-deny".into(),
1652
            rbac::Authorization {
1653
                action: rbac::RbacAction::Deny,
1654
                namespace: "ns1".into(),
1655
                name: "real-deny".into(),
1656
                rules: vec![vec![vec![rbac::RbacMatch {
1657
                    principals: vec![StringMatch::Exact(
1658
                        "cluster.local/ns/default/sa/denyacct".into(),
1659
                    )],
1660
                    ..Default::default()
1661
                }]]],
1662
                scope: rbac::RbacScope::Namespace,
1663
                dry_run: false,
1664
            },
1665
        );
1666
1667
        // Dry-run deny policy that matches both defaultacct and denyacct
1668
        state.policies.insert(
1669
            "dry-run-deny".into(),
1670
            rbac::Authorization {
1671
                action: rbac::RbacAction::Deny,
1672
                namespace: "ns1".into(),
1673
                name: "dry-run-deny".into(),
1674
                rules: vec![
1675
                    vec![vec![rbac::RbacMatch {
1676
                        principals: vec![StringMatch::Exact(
1677
                            "cluster.local/ns/default/sa/defaultacct".into(),
1678
                        )],
1679
                        ..Default::default()
1680
                    }]],
1681
                    vec![vec![rbac::RbacMatch {
1682
                        principals: vec![StringMatch::Exact(
1683
                            "cluster.local/ns/default/sa/denyacct".into(),
1684
                        )],
1685
                        ..Default::default()
1686
                    }]],
1687
                ],
1688
                scope: rbac::RbacScope::Namespace,
1689
                dry_run: true,
1690
            },
1691
        );
1692
1693
        // Real allow policy that matches defaultacct
1694
        state.policies.insert(
1695
            "real-allow".into(),
1696
            rbac::Authorization {
1697
                action: rbac::RbacAction::Allow,
1698
                namespace: "ns1".into(),
1699
                name: "real-allow".into(),
1700
                rules: vec![vec![vec![rbac::RbacMatch {
1701
                    principals: vec![StringMatch::Exact(
1702
                        "cluster.local/ns/default/sa/defaultacct".into(),
1703
                    )],
1704
                    ..Default::default()
1705
                }]]],
1706
                scope: rbac::RbacScope::Namespace,
1707
                dry_run: false,
1708
            },
1709
        );
1710
1711
        // Dry-run allow policy that matches defaultacct
1712
        state.policies.insert(
1713
            "dry-run-allow".into(),
1714
            rbac::Authorization {
1715
                action: rbac::RbacAction::Allow,
1716
                namespace: "ns1".into(),
1717
                name: "dry-run-allow".into(),
1718
                rules: vec![vec![vec![rbac::RbacMatch {
1719
                    principals: vec![StringMatch::Exact(
1720
                        "cluster.local/ns/default/sa/defaultacct".into(),
1721
                    )],
1722
                    ..Default::default()
1723
                }]]],
1724
                scope: rbac::RbacScope::Namespace,
1725
                dry_run: true,
1726
            },
1727
        );
1728
1729
        let mock_proxy_state = create_state(state);
1730
1731
        let ctx = get_rbac_context(&mock_proxy_state, 1, "defaultacct");
1732
        assert!(mock_proxy_state.assert_rbac(&ctx).await.is_ok());
1733
1734
        crate::telemetry::testing::assert_contains(std::collections::HashMap::from([
1735
            ("policy", "ns1/dry-run-deny"),
1736
            ("message", "dry-run: deny policy match"),
1737
        ]));
1738
        crate::telemetry::testing::assert_contains(std::collections::HashMap::from([
1739
            ("policy", "ns1/dry-run-allow"),
1740
            ("message", "dry-run: allow policy match"),
1741
        ]));
1742
    }
1743
1744
    #[tokio::test]
1745
    async fn test_load_balance() {
1746
        initialize_telemetry();
1747
        let mut state = ProxyState::new(None);
1748
        let wl_no_locality = Workload {
1749
            uid: "cluster1//v1/Pod/default/wl_no_locality".into(),
1750
            name: "wl_no_locality".into(),
1751
            namespace: "default".into(),
1752
            trust_domain: "cluster.local".into(),
1753
            service_account: "default".into(),
1754
            workload_ips: vec![IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1))],
1755
            ..test_helpers::test_default_workload()
1756
        };
1757
        let wl_match = Workload {
1758
            uid: "cluster1//v1/Pod/default/wl_match".into(),
1759
            name: "wl_match".into(),
1760
            namespace: "default".into(),
1761
            trust_domain: "cluster.local".into(),
1762
            service_account: "default".into(),
1763
            workload_ips: vec![IpAddr::V4(Ipv4Addr::new(192, 168, 0, 2))],
1764
            network: "network".into(),
1765
            locality: Locality {
1766
                region: "reg".into(),
1767
                zone: "zone".into(),
1768
                subzone: "".into(),
1769
            },
1770
            ..test_helpers::test_default_workload()
1771
        };
1772
        let wl_almost = Workload {
1773
            uid: "cluster1//v1/Pod/default/wl_almost".into(),
1774
            name: "wl_almost".into(),
1775
            namespace: "default".into(),
1776
            trust_domain: "cluster.local".into(),
1777
            service_account: "default".into(),
1778
            workload_ips: vec![IpAddr::V4(Ipv4Addr::new(192, 168, 0, 3))],
1779
            network: "network".into(),
1780
            locality: Locality {
1781
                region: "reg".into(),
1782
                zone: "not-zone".into(),
1783
                subzone: "".into(),
1784
            },
1785
            ..test_helpers::test_default_workload()
1786
        };
1787
        let wl_empty_ip = Workload {
1788
            uid: "cluster1//v1/Pod/default/wl_empty_ip".into(),
1789
            name: "wl_empty_ip".into(),
1790
            namespace: "default".into(),
1791
            trust_domain: "cluster.local".into(),
1792
            service_account: "default".into(),
1793
            workload_ips: vec![], // none!
1794
            network: "network".into(),
1795
            locality: Locality {
1796
                region: "reg".into(),
1797
                zone: "zone".into(),
1798
                subzone: "".into(),
1799
            },
1800
            ..test_helpers::test_default_workload()
1801
        };
1802
1803
        let _ep_almost = Workload {
1804
            uid: "cluster1//v1/Pod/default/ep_almost".into(),
1805
            name: "wl_almost".into(),
1806
            namespace: "default".into(),
1807
            trust_domain: "cluster.local".into(),
1808
            service_account: "default".into(),
1809
            workload_ips: vec![IpAddr::V4(Ipv4Addr::new(192, 168, 0, 4))],
1810
            network: "network".into(),
1811
            locality: Locality {
1812
                region: "reg".into(),
1813
                zone: "other-not-zone".into(),
1814
                subzone: "".into(),
1815
            },
1816
            ..test_helpers::test_default_workload()
1817
        };
1818
        let _ep_no_match = Workload {
1819
            uid: "cluster1//v1/Pod/default/ep_no_match".into(),
1820
            name: "wl_almost".into(),
1821
            namespace: "default".into(),
1822
            trust_domain: "cluster.local".into(),
1823
            service_account: "default".into(),
1824
            workload_ips: vec![IpAddr::V4(Ipv4Addr::new(192, 168, 0, 5))],
1825
            network: "not-network".into(),
1826
            locality: Locality {
1827
                region: "not-reg".into(),
1828
                zone: "unmatched-zone".into(),
1829
                subzone: "".into(),
1830
            },
1831
            ..test_helpers::test_default_workload()
1832
        };
1833
        let endpoints = EndpointSet::from_list([
1834
            Endpoint {
1835
                workload_uid: "cluster1//v1/Pod/default/ep_almost".into(),
1836
                port: HashMap::from([(80u16, 80u16)]),
1837
                status: HealthStatus::Healthy,
1838
            },
1839
            Endpoint {
1840
                workload_uid: "cluster1//v1/Pod/default/ep_no_match".into(),
1841
                port: HashMap::from([(80u16, 80u16)]),
1842
                status: HealthStatus::Healthy,
1843
            },
1844
            Endpoint {
1845
                workload_uid: "cluster1//v1/Pod/default/wl_match".into(),
1846
                port: HashMap::from([(80u16, 80u16)]),
1847
                status: HealthStatus::Healthy,
1848
            },
1849
            Endpoint {
1850
                workload_uid: "cluster1//v1/Pod/default/wl_empty_ip".into(),
1851
                port: HashMap::from([(80u16, 80u16)]),
1852
                status: HealthStatus::Healthy,
1853
            },
1854
        ]);
1855
        let strict_svc = Service {
1856
            endpoints: endpoints.clone(),
1857
            load_balancer: Some(LoadBalancer {
1858
                mode: LoadBalancerMode::Strict,
1859
                routing_preferences: vec![
1860
                    LoadBalancerScopes::Network,
1861
                    LoadBalancerScopes::Region,
1862
                    LoadBalancerScopes::Zone,
1863
                ],
1864
                health_policy: LoadBalancerHealthPolicy::OnlyHealthy,
1865
            }),
1866
            ports: HashMap::from([(80u16, 80u16)]),
1867
            ..test_helpers::mock_default_service()
1868
        };
1869
        let failover_svc = Service {
1870
            endpoints,
1871
            load_balancer: Some(LoadBalancer {
1872
                mode: LoadBalancerMode::Failover,
1873
                routing_preferences: vec![
1874
                    LoadBalancerScopes::Network,
1875
                    LoadBalancerScopes::Region,
1876
                    LoadBalancerScopes::Zone,
1877
                ],
1878
                health_policy: LoadBalancerHealthPolicy::OnlyHealthy,
1879
            }),
1880
            ports: HashMap::from([(80u16, 80u16)]),
1881
            ..test_helpers::mock_default_service()
1882
        };
1883
        state.workloads.insert(Arc::new(wl_no_locality.clone()));
1884
        state.workloads.insert(Arc::new(wl_match.clone()));
1885
        state.workloads.insert(Arc::new(wl_almost.clone()));
1886
        state.workloads.insert(Arc::new(wl_empty_ip.clone()));
1887
        state.services.insert(strict_svc.clone());
1888
        state.services.insert(failover_svc.clone());
1889
1890
        let assert_endpoint = |src: &Workload, svc: &Service, workloads: Vec<&str>, desc: &str| {
1891
            let got = state
1892
                .load_balance(src, svc, 80, ServiceResolutionMode::Standard)
1893
                .map(|(ep, _)| ep.workload_uid.to_string());
1894
            if workloads.is_empty() {
1895
                assert!(got.is_none(), "{}", desc);
1896
            } else {
1897
                let want: Vec<String> = workloads.iter().map(ToString::to_string).collect();
1898
                assert!(want.contains(&got.unwrap()), "{}", desc);
1899
            }
1900
        };
1901
        let assert_not_endpoint =
1902
            |src: &Workload, svc: &Service, uid: &str, tries: usize, desc: &str| {
1903
                for _ in 0..tries {
1904
                    let got = state
1905
                        .load_balance(src, svc, 80, ServiceResolutionMode::Standard)
1906
                        .map(|(ep, _)| ep.workload_uid.as_str());
1907
                    assert!(got != Some(uid), "{}", desc);
1908
                }
1909
            };
1910
1911
        assert_endpoint(
1912
            &wl_no_locality,
1913
            &strict_svc,
1914
            vec![],
1915
            "strict no match should not select",
1916
        );
1917
        assert_endpoint(
1918
            &wl_almost,
1919
            &strict_svc,
1920
            vec![],
1921
            "strict no match should not select",
1922
        );
1923
        assert_endpoint(
1924
            &wl_match,
1925
            &strict_svc,
1926
            vec!["cluster1//v1/Pod/default/wl_match"],
1927
            "strict match",
1928
        );
1929
1930
        assert_endpoint(
1931
            &wl_no_locality,
1932
            &failover_svc,
1933
            vec![
1934
                "cluster1//v1/Pod/default/ep_almost",
1935
                "cluster1//v1/Pod/default/ep_no_match",
1936
                "cluster1//v1/Pod/default/wl_match",
1937
            ],
1938
            "failover no match can select any endpoint",
1939
        );
1940
        assert_endpoint(
1941
            &wl_almost,
1942
            &failover_svc,
1943
            vec![
1944
                "cluster1//v1/Pod/default/ep_almost",
1945
                "cluster1//v1/Pod/default/wl_match",
1946
            ],
1947
            "failover almost match can select any close matches",
1948
        );
1949
        assert_endpoint(
1950
            &wl_match,
1951
            &failover_svc,
1952
            vec!["cluster1//v1/Pod/default/wl_match"],
1953
            "failover full match selects closest match",
1954
        );
1955
        assert_not_endpoint(
1956
            &wl_no_locality,
1957
            &failover_svc,
1958
            "cluster1//v1/Pod/default/wl_empty_ip",
1959
            10,
1960
            "failover no match can select any endpoint",
1961
        );
1962
    }
1963
}