Coverage Report

Created: 2025-11-16 06:37

/src/ztunnel/src/proxy/pool.rs
All instrumented lines in this file report an execution count of 0 (uncovered).
// Copyright Istio Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#![warn(clippy::cast_lossless)]
use super::{Error, SocketFactory};
use super::{LocalWorkloadInformation, h2};
use std::time::Duration;

use std::collections::hash_map::DefaultHasher;

use std::hash::{Hash, Hasher};

use std::sync::Arc;
use std::sync::atomic::{AtomicI32, Ordering};

use tokio::sync::watch;

use tokio::sync::Mutex;
use tracing::{Instrument, debug, trace};

use crate::config;

use flurry;

use crate::proxy::h2::H2Stream;
use crate::proxy::h2::client::{H2ConnectClient, WorkloadKey};
use pingora_pool;
use tokio::io;

// A relatively nonstandard HTTP/2 connection pool designed to allow multiplexing proxied workload connections
// over a (smaller) number of HTTP/2 mTLS tunnels.
//
// The following invariants apply to this pool:
// - Every workload (inpod mode) gets its own connpool.
// - Every unique src/dest key gets its own dedicated connections inside the pool.
// - Every unique src/dest key gets 1-N dedicated connections, where N is (currently) unbounded but practically limited
//   by flow control throttling.
#[derive(Clone)]
pub struct WorkloadHBONEPool {
    state: Arc<PoolState>,
    pool_watcher: watch::Receiver<bool>,
}
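
// Illustrative sketch (not part of the original file): how a src/dest WorkloadKey
// becomes the u64 lookup key used against the inner pingora pool. This mirrors the
// hashing done in `connect` below; the deep-equality guard in `enforce_key_integrity`
// exists precisely because two distinct keys could hash to the same u64.
#[allow(dead_code)]
fn example_pool_hash(key: &WorkloadKey) -> u64 {
    let mut hasher = DefaultHasher::new();
    key.hash(&mut hasher);
    hasher.finish()
}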

// PoolState is effectively the gnarly inner state stuff that needs thread/task sync, and should be wrapped in a Mutex.
struct PoolState {
    pool_notifier: watch::Sender<bool>, // This is already impl clone? rustc complains that it isn't, tho
    timeout_tx: watch::Sender<bool>, // This is already impl clone? rustc complains that it isn't, tho
    // this is effectively just a convenience data type - a rwlocked hashmap with keying and LRU drops
    // and has no actual hyper/http/connection logic.
    connected_pool: Arc<pingora_pool::ConnectionPool<H2ConnectClient>>,
    // this must be an atomic/concurrent-safe list-of-locks, so we can lock per-key, not globally, and avoid holding up all conn attempts
    established_conn_writelock: flurry::HashMap<u64, Option<Arc<Mutex<()>>>>,
    pool_unused_release_timeout: Duration,
    // This is merely a counter to track the overall number of conns this pool spawns
    // to ensure we get unique poolkeys-per-new-conn, it is not a limit
    pool_global_conn_count: AtomicI32,
    spawner: ConnSpawner,
}

struct ConnSpawner {
    cfg: Arc<config::Config>,
    socket_factory: Arc<dyn SocketFactory + Send + Sync>,
    local_workload: Arc<LocalWorkloadInformation>,
    timeout_rx: watch::Receiver<bool>,
}

// Does nothing but spawn new conns when asked
impl ConnSpawner {
    async fn new_pool_conn(&self, key: WorkloadKey) -> Result<H2ConnectClient, Error> {
        debug!("spawning new pool conn for {}", key);

        let cert = self.local_workload.fetch_certificate().await?;
        let connector = cert.outbound_connector(key.dst_id.clone())?;
        let tcp_stream = super::freebind_connect(None, key.dst, self.socket_factory.as_ref())
            .await
            .map_err(|e: io::Error| match e.kind() {
                io::ErrorKind::TimedOut => Error::MaybeHBONENetworkPolicyError(e),
                _ => e.into(),
            })?;

        let tls_stream = connector.connect(tcp_stream).await?;
        trace!("connector connected, handshaking");
        let sender = h2::client::spawn_connection(
            self.cfg.clone(),
            tls_stream,
            self.timeout_rx.clone(),
            key,
        )
        .await?;
        Ok(sender)
    }
}

impl PoolState {
    // This simply puts the connection back into the inner pool,
    // and sets up a timed popper, which will resolve
    // - when this reference is popped back out of the inner pool (doing nothing)
    // - when this reference is evicted from the inner pool (doing nothing)
    // - when the timeout_idler is drained (will pop)
    // - when the timeout is hit (will pop)
    //
    // Idle poppers are safe to invoke if the conn they are popping is already gone
    // from the inner queue, so we will start one for every insert, let them run or terminate on their own,
    // and poll them to completion on shutdown - any duplicates from repeated checkouts/checkins of the same conn
    // will simply resolve as a no-op in order.
    //
    // Note that "idle" in the context of this pool means "no one has asked for it or dropped it in X time, so prune it".
    //
    // Pruning the idle connection from the pool does not close it - it simply ensures the pool stops holding a ref.
    // hyper self-closes client conns when all refs are dropped and streamcount is 0, so pool consumers must
    // drop their checked out conns and/or terminate their streams as well.
    //
    // Note that this simply removes the client ref from this pool - if other things hold client/streamrefs refs,
    // they must also drop those before the underlying connection is fully closed.
    fn maybe_checkin_conn(&self, conn: H2ConnectClient, pool_key: pingora_pool::ConnectionMeta) {
        if conn.will_be_at_max_streamcount() {
            debug!(
                "checked out connection for {:?} is now at max streamcount; removing from pool",
                pool_key
            );
            return;
        }
        let (evict, pickup) = self.connected_pool.put(&pool_key, conn);
        let rx = self.spawner.timeout_rx.clone();
        let pool_ref = self.connected_pool.clone();
        let pool_key_ref = pool_key.clone();
        let release_timeout = self.pool_unused_release_timeout;
        tokio::spawn(
            async move {
                debug!("starting an idle timeout for connection {:?}", pool_key_ref);
                pool_ref
                    .idle_timeout(&pool_key_ref, release_timeout, evict, rx, pickup)
                    .await;
                debug!(
                    "connection {:?} was removed/checked out/timed out of the pool",
                    pool_key_ref
                )
            }
            .in_current_span(),
        );
        let _ = self.pool_notifier.send(true);
    }
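
    // Illustrative sketch (not part of the original file): the shape of the idle
    // popper spawned above, reduced to plain tokio primitives. Hypothetical and
    // simplified - the real resolution logic (evict/pickup handles, no-op on
    // already-popped conns) lives inside pingora's `idle_timeout`.
    #[allow(dead_code)]
    async fn example_idle_popper(release_timeout: Duration, mut drain: watch::Receiver<bool>) {
        tokio::select! {
            // Timeout hit: the pool would stop holding its ref (the conn is not closed).
            _ = tokio::time::sleep(release_timeout) => {}
            // Pool is draining: resolve early instead of waiting out the timeout.
            _ = drain.changed() => {}
        }
    }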

    // Since we are using a hash key to do lookup on the inner pingora pool, do a get guard
    // to make sure what we pull out actually deep-equals the workload_key, to avoid *sigh* crossing the streams.
    fn guarded_get(
        &self,
        hash_key: &u64,
        workload_key: &WorkloadKey,
    ) -> Result<Option<H2ConnectClient>, Error> {
        match self.connected_pool.get(hash_key) {
            None => Ok(None),
            Some(conn) => match Self::enforce_key_integrity(conn, workload_key) {
                Err(e) => Err(e),
                Ok(conn) => Ok(Some(conn)),
            },
        }
    }

    // Just for safety's sake, since we are using a hash thanks to pingora NOT supporting arbitrary Eq, Hash
    // types, do a deep equality test before returning the conn, returning an error if the conn's key does
    // not equal the provided key
    //
    // this is a final safety check for collisions, we will throw up our hands and refuse to return the conn
    fn enforce_key_integrity(
        conn: H2ConnectClient,
        expected_key: &WorkloadKey,
    ) -> Result<H2ConnectClient, Error> {
        match conn.is_for_workload(expected_key) {
            Ok(()) => Ok(conn),
            Err(e) => Err(e),
        }
    }

    // 1. Tries to get a writelock.
    // 2. If successful, hold it, spawn a new connection, check it in, return a clone of it.
    // 3. If not successful, return nothing.
    //
    // This is useful if we want to race someone else to the writelock to spawn a connection,
    // and expect the losers to queue up and wait for the (singular) winner of the writelock
    //
    // This function should ALWAYS return a connection if it wins the writelock for the provided key.
    // This function should NEVER return a connection if it does not win the writelock for the provided key.
    // This function should ALWAYS propagate Error results to the caller
    //
    // It is important that the *initial* check here is authoritative, hence the locks, as
    // we must know if this is a connection for a key *nobody* has tried to start yet
    // (i.e. no writelock for our key in the outer map)
    // or if other things have already established conns for this key (writelock for our key in the outer map).
    //
    // This is so we can backpressure correctly if 1000 tasks all demand a new connection
    // to the same key at once, and not eagerly open 1000 tunnel connections.
    async fn start_conn_if_win_writelock(
        &self,
        workload_key: &WorkloadKey,
        pool_key: &pingora_pool::ConnectionMeta,
    ) -> Result<Option<H2ConnectClient>, Error> {
        let inner_conn_lock = {
            trace!("getting keyed lock out of lockmap");
            let guard = self.established_conn_writelock.guard();

            let exist_conn_lock = self
                .established_conn_writelock
                .get(&pool_key.key, &guard)
                .unwrap();
            trace!("got keyed lock out of lockmap");
            exist_conn_lock.as_ref().unwrap().clone()
        };

        trace!("attempting to win connlock for {}", workload_key);

        let inner_lock = inner_conn_lock.try_lock();
        match inner_lock {
            Ok(_guard) => {
                // BEGIN take inner writelock
                debug!("nothing else is creating a conn and we won the lock, make one");
                let client = self.spawner.new_pool_conn(workload_key.clone()).await?;

                debug!(
                    "checking in new conn for {} with pk {:?}",
                    workload_key, pool_key
                );
                self.maybe_checkin_conn(client.clone(), pool_key.clone());
                Ok(Some(client))
                // END take inner writelock
            }
            Err(_) => {
                debug!(
                    "did not win connlock for {}, something else has it",
                    workload_key
                );
                Ok(None)
            }
        }
    }

    // Does an initial, naive check to see if we have a writelock inserted into the map for this key
    //
    // If we do, take the writelock for that key, clone (or create) a connection, check it back in,
    // and return a cloned ref, then drop the writelock.
    //
    // Otherwise, return None.
    //
    // This function should ALWAYS return a connection if a writelock exists for the provided key.
    // This function should NEVER return a connection if no writelock exists for the provided key.
    // This function should ALWAYS propagate Error results to the caller
    //
    // It is important that the *initial* check here is authoritative, hence the locks, as
    // we must know if this is a connection for a key *nobody* has tried to start yet
    // (i.e. no writelock for our key in the outer map)
    // or if other things have already established conns for this key (writelock for our key in the outer map).
    //
    // This is so we can backpressure correctly if 1000 tasks all demand a new connection
    // to the same key at once, and not eagerly open 1000 tunnel connections.
    async fn checkout_conn_under_writelock(
        &self,
        workload_key: &WorkloadKey,
        pool_key: &pingora_pool::ConnectionMeta,
    ) -> Result<Option<H2ConnectClient>, Error> {
        let found_conn = {
            trace!("pool connect outer map - take guard");
            let guard = self.established_conn_writelock.guard();

            trace!("pool connect outer map - check for keyed mutex");
            let exist_conn_lock = self.established_conn_writelock.get(&pool_key.key, &guard);
            exist_conn_lock.and_then(|e_conn_lock| e_conn_lock.clone())
        };
        let Some(exist_conn_lock) = found_conn else {
            return Ok(None);
        };
        debug!(
            "checkout - found mutex for pool key {:?}, waiting for writelock",
            pool_key
        );
        let _conn_lock = exist_conn_lock.as_ref().lock().await;

        trace!(
            "checkout - got writelock for conn with key {} and hash {:?}",
            workload_key, pool_key.key
        );
        let returned_connection = loop {
            match self.guarded_get(&pool_key.key, workload_key)? {
                Some(mut existing) => {
                    if !existing.ready_to_use() {
                        // We checked this out, and will not check it back in
                        // Loop again to find another/make a new one
                        debug!(
                            "checked out broken connection for {}, dropping it",
                            workload_key
                        );
                        continue;
                    }
                    debug!("re-using connection for {}", workload_key);
                    break existing;
                }
                None => {
                    debug!("new connection needed for {}", workload_key);
                    break self.spawner.new_pool_conn(workload_key.clone()).await?;
                }
            };
        };

        // For any connection, we will check in a copy and return the other unless it's already maxed out
        // TODO: in the future, we can keep track of these and start to use them once they finish some streams.
        self.maybe_checkin_conn(returned_connection.clone(), pool_key.clone());
        Ok(Some(returned_connection))
    }
}
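
// Illustrative sketch (not part of the original file): the "single creator, many
// waiters" election that `start_conn_if_win_writelock` and
// `checkout_conn_under_writelock` implement above, reduced to its core. Simplified:
// the real pool keys one Mutex per WorkloadKey hash and announces check-ins over a
// watch channel rather than returning a value directly.
#[allow(dead_code)]
fn example_race_to_create(keyed_lock: &Arc<Mutex<()>>) -> &'static str {
    match keyed_lock.try_lock() {
        // Winner: hold the keyed lock while spawning the one new connection.
        Ok(_guard) => "spawn the single new connection",
        // Losers: back off and wait on the pool watcher for the winner's check-in.
        Err(_) => "wait for the pool notifier",
    }
}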

// When the Arc-wrapped PoolState is finally dropped, trigger the drain,
// which will terminate all connection driver spawns, as well as cancel all outstanding eviction timeout spawns
impl Drop for PoolState {
    fn drop(&mut self) {
        debug!(
            "poolstate dropping, stopping all connection drivers and cancelling all outstanding eviction timeout spawns"
        );
        let _ = self.timeout_tx.send(true);
    }
}
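
// Illustrative sketch (not part of the original file): the watch-based shutdown
// fanout the Drop impl above relies on. Every driver/timeout task holds a clone of
// the `watch::Receiver`; a single `send(true)` wakes them all.
#[allow(dead_code)]
async fn example_shutdown_fanout() {
    let (tx, mut rx) = watch::channel(false);
    let parked = tokio::spawn(async move {
        // Each pool task parks here until the pool is dropped...
        let _ = rx.changed().await;
    });
    let _ = tx.send(true); // ...and this is what PoolState::drop does.
    let _ = parked.await;
}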

impl WorkloadHBONEPool {
    // Creates a new pool instance, which should be owned by a single proxied workload.
    // The pool will watch the provided drain signal and drain itself when notified.
    // Callers should then be safe to drop() the pool instance.
    pub fn new(
        cfg: Arc<crate::config::Config>,
        socket_factory: Arc<dyn SocketFactory + Send + Sync>,
        local_workload: Arc<LocalWorkloadInformation>,
    ) -> WorkloadHBONEPool {
        let (timeout_tx, timeout_rx) = watch::channel(false);
        let (timeout_send, timeout_recv) = watch::channel(false);
        let pool_duration = cfg.pool_unused_release_timeout;

        let spawner = ConnSpawner {
            cfg,
            socket_factory,
            local_workload,
            timeout_rx: timeout_recv.clone(),
        };

        Self {
            state: Arc::new(PoolState {
                pool_notifier: timeout_tx,
                timeout_tx: timeout_send,
                // timeout_rx: timeout_recv,
                // the number here is simply the number of unique src/dest keys
                // the pool is expected to track before the inner hashmap resizes.
                connected_pool: Arc::new(pingora_pool::ConnectionPool::new(500)),
                established_conn_writelock: flurry::HashMap::new(),
                pool_unused_release_timeout: pool_duration,
                pool_global_conn_count: AtomicI32::new(0),
                spawner,
            }),
            pool_watcher: timeout_rx,
        }
    }

    pub async fn send_request_pooled(
        &mut self,
        workload_key: &WorkloadKey,
        request: http::Request<()>,
    ) -> Result<H2Stream, Error> {
        let mut connection = self.connect(workload_key).await?;

        connection.send_request(request).await
    }

    // Obtain a pooled connection. Will prefer to retrieve an existing conn from the pool, but
    // if none exist, or the existing conn is maxed out on streamcount, will spawn a new one,
    // even if it is to the same dest+port.
    //
    // If many `connects` request a connection to the same dest at once, all will wait until exactly
    // one connection is created, before deciding if they should create more or just use that one.
    async fn connect(&mut self, workload_key: &WorkloadKey) -> Result<H2ConnectClient, Error> {
        trace!("pool connect START");
        // TODO BML this may not be collision resistant, or a fast hash. It should be resistant enough for workloads tho.
        // We are doing a deep-equals check at the end to mitigate any collisions, will see about bumping Pingora
        let mut s = DefaultHasher::new();
        workload_key.hash(&mut s);
        let hash_key = s.finish();
        let pool_key = pingora_pool::ConnectionMeta::new(
            hash_key,
            self.state
                .pool_global_conn_count
                .fetch_add(1, Ordering::SeqCst),
        );
        // First, see if we can naively take an inner lock for our specific key, and get a connection.
        // This should be the common case, except for the first establishment of a new connection/key.
        // This will be done under outer readlock (nonexclusive)/inner keyed writelock (exclusive).
        let existing_conn = self
            .state
            .checkout_conn_under_writelock(workload_key, &pool_key)
            .await?;

        // Early return, no need to do anything else
        if let Some(e) = existing_conn {
            debug!("initial attempt - found existing conn, done");
            return Ok(e);
        }

        // We couldn't get a writelock for this key. This means nobody has tried to establish any conns for this key yet,
        // So, we will take a nonexclusive readlock on the outer lockmap, and attempt to insert one.
        //
        // (if multiple threads try to insert one, only one will succeed.)
        {
            debug!(
                "didn't find a connection for key {:?}, making sure lockmap has entry",
                hash_key
            );
            let guard = self.state.established_conn_writelock.guard();
            match self.state.established_conn_writelock.try_insert(
                hash_key,
                Some(Arc::new(Mutex::new(()))),
                &guard,
            ) {
                Ok(_) => {
                    debug!("inserting conn mutex for key {:?} into lockmap", hash_key);
                }
                Err(_) => {
                    debug!("already have conn for key {:?} in lockmap", hash_key);
                }
            }
        }

        // If we get here, it means the following are true:
        // 1. We have a guaranteed sharded mutex in the outer map for our current key
        // 2. We can now, under readlock(nonexclusive) in the outer map, attempt to
        // take the inner writelock for our specific key (exclusive).
        //
        // This doesn't block other tasks spawning connections against other keys, but DOES block other
        // tasks spawning connections against THIS key - which is what we want.

        // NOTE: The inner, key-specific mutex is a tokio::async::Mutex, and not a stdlib sync mutex.
        // these differ from the stdlib sync mutex in that they are (slightly) slower
        // (they effectively sleep the current task) and they can be held over an await.
        // The tokio docs (rightly) advise you to not use these,
        // because holding a lock over an await is a great way to create deadlocks if the await you
        // hold it over does not resolve.
        //
        // HOWEVER. Here we know this connection will either establish or timeout (or fail with error)
        // and we WANT other tasks to go back to sleep if a task is already trying to create a new connection for this key.
        //
        // So the downsides are actually useful (we WANT task contention -
        // to block other parallel tasks from trying to spawn a connection for this key if we are already doing so)
        trace!("fallback attempt - trying to win connlock");
        let res = match self
            .state
            .start_conn_if_win_writelock(workload_key, &pool_key)
            .await?
        {
            Some(client) => client,
            None => {
                debug!("we didn't win the lock, something else is creating a conn, wait for it");
                // If we get here, it means the following are true:
                // 1. We have a writelock in the outer map for this key (either we inserted, or someone beat us to it - but it's there)
                // 2. We could not get the exclusive inner writelock to add a new conn for this key.
                // 3. Someone else got the exclusive inner writelock, and is adding a new conn for this key.
                //
                // So, loop and wait for the pool_watcher to tell us a new conn was enpooled,
                // so we can pull it out and check it.
                loop {
                    match self.pool_watcher.changed().await {
                        Ok(_) => {
                            trace!(
                                "notified a new conn was enpooled, checking for hash {:?}",
                                hash_key
                            );
                            // Notifier fired, try and get a conn out for our key.
                            let existing_conn = self
                                .state
                                .checkout_conn_under_writelock(workload_key, &pool_key)
                                .await?;
                            match existing_conn {
                                None => {
                                    trace!(
                                        "woke up on pool notification, but didn't find a conn for {:?} yet",
                                        hash_key
                                    );
                                    continue;
                                }
                                Some(e_conn) => {
                                    debug!("found existing conn after waiting");
                                    break e_conn;
                                }
                            }
                        }
                        Err(_) => {
                            return Err(Error::WorkloadHBONEPoolDraining);
                        }
                    }
                }
            }
        };
        Ok(res)
    }
}
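
// Illustrative sketch (not part of the original file): a minimal caller flow,
// assuming an already-constructed pool and WorkloadKey. As the tests below show,
// `send_request_pooled` expects an HTTP/2 CONNECT request addressed to the
// destination; the URI here is a hypothetical placeholder.
#[allow(dead_code)]
async fn example_open_stream(
    pool: &mut WorkloadHBONEPool,
    key: &WorkloadKey,
) -> Result<H2Stream, Error> {
    let req = http::Request::builder()
        .uri("127.0.0.1:15008") // hypothetical destination address
        .method(http::Method::CONNECT)
        .version(http::Version::HTTP_2)
        .body(())
        .expect("static request must build");
    pool.send_request_pooled(key, req).await
}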

#[cfg(test)]
mod test {
    use std::convert::Infallible;
    use std::net::IpAddr;
    use std::net::SocketAddr;
    use std::time::Instant;

    use crate::{drain, identity, proxy};

    use futures_util::{StreamExt, future};
    use hyper::body::Incoming;

    use hickory_resolver::config::{ResolverConfig, ResolverOpts};
    use hyper::service::service_fn;
    use hyper::{Request, Response};
    use prometheus_client::registry::Registry;
    use std::sync::RwLock;
    use std::sync::atomic::AtomicU32;
    use std::time::Duration;
    use tokio::io::AsyncReadExt;
    use tokio::io::AsyncWriteExt;
    use tokio::net::TcpListener;

    use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
    use tokio::sync::oneshot;

    use tracing::{Instrument, error};

    use crate::test_helpers::helpers::initialize_telemetry;

    use crate::identity::Identity;

    use self::h2::TokioH2Stream;

    use super::*;
    use crate::drain::DrainWatcher;
    use crate::state::workload;
    use crate::state::{DemandProxyState, ProxyState, WorkloadInfo};
    use crate::test_helpers::test_default_workload;
    use crate::test_helpers::*;

    macro_rules! assert_opens_drops {
        ($srv:expr_2021, $open:expr_2021, $drops:expr_2021) => {
            assert_eq!(
                $srv.conn_counter.load(Ordering::Relaxed),
                $open,
                "total connections opened, wanted {}",
                $open
            );
            #[allow(clippy::reversed_empty_ranges)]
            for want in 0..$drops {
                tokio::time::timeout(Duration::from_secs(2), $srv.drop_rx.recv())
                    .await
                    .expect(&format!(
                        "wanted {} drops, but timed out after getting {}",
                        $drops, want
                    ))
                    .expect("wanted drop");
            }
            assert!(
                $srv.drop_rx.is_empty(),
                "after {} drops, we shouldn't have more, but got {}",
                $drops,
                $srv.drop_rx.len()
            )
        };
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn connections_reused() {
        let (pool, mut srv) = setup_test(3).await;

        let key = key(&srv, 2);

        // Pool allows 3. When we spawn 2 concurrently, we should open a single connection and keep it alive
        spawn_clients_concurrently(pool.clone(), key.clone(), srv.addr, 2).await;
        assert_opens_drops!(srv, 1, 0);

        // Since the last two closed, we are free to re-use the same connection
        spawn_clients_concurrently(pool.clone(), key.clone(), srv.addr, 2).await;
        assert_opens_drops!(srv, 1, 0);

        // Once we drop the pool, we should drop the connections as well
        drop(pool);
        assert_opens_drops!(srv, 1, 1);
    }

    /// This is really a test for TokioH2Stream, but it's nicer here because we have access to
    /// streams.
    /// Most important, we make sure there are no panics.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn read_buffering() {
        let (mut pool, srv) = setup_test(3).await;

        let key = key(&srv, 2);
        let req = || {
            http::Request::builder()
                .uri(srv.addr.to_string())
                .method(http::Method::CONNECT)
                .version(http::Version::HTTP_2)
                .body(())
                .unwrap()
        };

        let c = pool.send_request_pooled(&key.clone(), req()).await.unwrap();
        let mut c = TokioH2Stream::new(c);
        c.write_all(b"abcde").await.unwrap();
        let mut b = [0u8; 100];
        // Properly buffer reads and don't error
        assert_eq!(c.read(&mut b).await.unwrap(), 8);
        assert_eq!(&b[..8], b"poolsrv\n"); // this prefix is written by the test server on upgrade
        assert_eq!(c.read(&mut b[..1]).await.unwrap(), 1);
        assert_eq!(&b[..1], b"a");
        assert_eq!(c.read(&mut b[..1]).await.unwrap(), 1);
        assert_eq!(&b[..1], b"b");
        assert_eq!(c.read(&mut b[..1]).await.unwrap(), 1);
        assert_eq!(&b[..1], b"c");
        assert_eq!(c.read(&mut b).await.unwrap(), 2); // there are only two bytes left
        assert_eq!(&b[..2], b"de");

        // Once we drop the pool, we should still retain the buffered data,
        // but then we should error.
        c.write_all(b"abcde").await.unwrap();
        assert_eq!(c.read(&mut b[..3]).await.unwrap(), 3);
        assert_eq!(&b[..3], b"abc");
        drop(pool);
        assert_eq!(c.read(&mut b[..2]).await.unwrap(), 2);
        assert_eq!(&b[..2], b"de");
        assert!(c.read(&mut b).await.is_err());
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn unique_keys_have_unique_connections() {
        let (pool, mut srv) = setup_test(3).await;

        let key1 = key(&srv, 1);
        let key2 = key(&srv, 2);

        test_client(pool.clone(), key1, srv.addr).await;
        test_client(pool.clone(), key2, srv.addr).await;
        assert_opens_drops!(srv, 2, 0);
        // Once we drop the pool, we should drop the connections as well
        drop(pool);
        assert_opens_drops!(srv, 2, 2);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn connection_limits() {
        let (pool, mut srv) = setup_test(2).await;

        let key = key(&srv, 1);

        // Pool allows 2 streams per conn. When we spawn 4 concurrently, we need 2 connections
        spawn_clients_concurrently(pool.clone(), key.clone(), srv.addr, 4).await;
        assert_opens_drops!(srv, 2, 2);

        // This should require 3 connections (2 already opened, 1 new). However, due to an inefficiency
        // in our pool, we don't properly reuse connections that hit their max streamcount.
        // The first batch of 4 started 2 connections, each of which maxed out, so neither was
        // returned to the pool.
        spawn_clients_concurrently(pool.clone(), key.clone(), srv.addr, 5).await;
        assert_opens_drops!(srv, 5, 2);

        // Once we drop the pool, the one connection still held by the pool should drop as well
        drop(pool);
        assert_opens_drops!(srv, 5, 1);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn server_goaway() {
        let (pool, mut srv) = setup_test(2).await;

        let key = key(&srv, 1);

        // Establish one connection, it will be pooled
        spawn_clients_concurrently(pool.clone(), key.clone(), srv.addr, 1).await;
        assert_opens_drops!(srv, 1, 0);

        // Trigger server GOAWAY. Wait for the server to finish
        srv.goaway_tx.send(()).unwrap();
        assert_opens_drops!(srv, 1, 1);

        // Open a new connection. We should create a new one, since the last one is busted
        spawn_clients_concurrently(pool.clone(), key.clone(), srv.addr, 1).await;
        assert_opens_drops!(srv, 2, 0);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn single_pool() {
        // Test an edge case of a pool size of 1. Probably users shouldn't have pool size 1, and if
        // they do, we should just disable the pool. For now, we don't do that, so make sure it works.
        let (pool, mut srv) = setup_test(1).await;

        let key = key(&srv, 1);

        spawn_clients_concurrently(pool.clone(), key.clone(), srv.addr, 2).await;
        assert_opens_drops!(srv, 2, 2);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn stress_test_single_source() {
        let (pool, mut srv) = setup_test(101).await;

        let key = key(&srv, 1);

        // Spin up 100 requests, they should all work
        spawn_clients_concurrently(pool.clone(), key.clone(), srv.addr, 100).await;
        assert_opens_drops!(srv, 1, 0);

        // Once we drop the pool, we should drop the connections as well
        drop(pool);
        assert_opens_drops!(srv, 1, 1);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn stress_test_multiple_source() {
        let (pool, mut srv) = setup_test(100).await;

        // Spin up 100 requests each from their own source, they should all work
        let mut tasks = vec![];
        for count in 0..100 {
            let key = key(&srv, count);
            tasks.push(test_client(pool.clone(), key.clone(), srv.addr));
        }
        future::join_all(tasks).await;

        drop(pool);
        assert_opens_drops!(srv, 100, 100);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn stress_test_many_client_many_sources() {
        let (pool, mut srv) = setup_test(100).await;

        // Spin up 300 requests each from 3 different sources, they should all work
        let mut tasks = vec![];
        for count in 0..300u16 {
            let key = key(&srv, (count % 3) as u8);
            tasks.push(test_client(pool.clone(), key.clone(), srv.addr));
        }
        future::join_all(tasks).await;
        drop(pool);
        assert_opens_drops!(srv, 3, 3);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn idle_eviction() {
        let (pool, mut srv) = setup_test_with_idle(3, Duration::from_millis(100)).await;

        let key = key(&srv, 1);

        // Pool allows 3. When we spawn 2 concurrently, we should open a single connection and keep it alive
        spawn_clients_concurrently(pool.clone(), key.clone(), srv.addr, 2).await;
        // After 100ms, we should drop everything
        assert_opens_drops!(srv, 1, 1);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn idle_eviction_with_persistent() {
        let (pool, mut srv) = setup_test_with_idle(4, Duration::from_millis(100)).await;

        let key = key(&srv, 1);
        let (client_stop_signal, client_stop) = drain::new();
        // Spin up 1 connection
        spawn_persistent_client(pool.clone(), key.clone(), srv.addr, client_stop).await;
        spawn_clients_concurrently(pool.clone(), key.clone(), srv.addr, 2).await;
        // We shouldn't drop anything yet
        assert_opens_drops!(srv, 1, 0);
        // This should spill over into a new connection, which should drop
        spawn_clients_concurrently(pool.clone(), key.clone(), srv.addr, 4).await;
        assert_opens_drops!(srv, 2, 1);

        // Trigger the persistent client to stop, we should evict that connection as well
        client_stop_signal
            .start_drain_and_wait(drain::DrainMode::Immediate)
            .await;
        assert_opens_drops!(srv, 2, 1);
    }

    async fn spawn_clients_concurrently(
        mut pool: WorkloadHBONEPool,
        key: WorkloadKey,
        remote_addr: SocketAddr,
        req_count: u32,
    ) {
        let (shutdown_send, _shutdown_recv) = tokio::sync::broadcast::channel::<()>(1);

        let mut tasks = vec![];
        for req_num in 0..req_count {
            let req = || {
                hyper::Request::builder()
                    .uri(format!("{remote_addr}"))
                    .method(hyper::Method::CONNECT)
                    .version(hyper::Version::HTTP_2)
                    .body(())
                    .unwrap()
            };

            let start = Instant::now();

            let c1 = pool
                .send_request_pooled(&key.clone(), req())
                .instrument(tracing::debug_span!("client", request = req_num))
                .await
                .expect("connect should succeed");
            debug!(
                "client spent {}ms waiting for conn",
                start.elapsed().as_millis()
            );
            let mut shutdown_recv = shutdown_send.subscribe();
            tasks.push(tokio::spawn(async move {
                let _ = shutdown_recv.recv().await;
                drop(c1);
                debug!("dropped stream");
            }));
        }
        drop(shutdown_send);
        future::join_all(tasks).await;
    }

    async fn test_client(mut pool: WorkloadHBONEPool, key: WorkloadKey, remote_addr: SocketAddr) {
        let req = || {
            hyper::Request::builder()
                .uri(format!("{remote_addr}"))
                .method(hyper::Method::CONNECT)
                .version(hyper::Version::HTTP_2)
                .body(())
                .unwrap()
        };

        let start = Instant::now();

        let _c1 = pool
            .send_request_pooled(&key.clone(), req())
            .await
            .expect("connect should succeed");
        debug!(
            "client spent {}ms waiting for conn",
            start.elapsed().as_millis()
        );
    }

    async fn spawn_persistent_client(
        mut pool: WorkloadHBONEPool,
        key: WorkloadKey,
        remote_addr: SocketAddr,
        stop: DrainWatcher,
    ) {
        let req = || {
            http::Request::builder()
                .uri(format!("{remote_addr}"))
                .method(http::Method::CONNECT)
                .version(http::Version::HTTP_2)
                .body(())
                .unwrap()
        };

        let start = Instant::now();

        let c1 = pool.send_request_pooled(&key.clone(), req()).await.unwrap();
        debug!(
            "client spent {}ms waiting for conn",
            start.elapsed().as_millis()
        );
        tokio::spawn(async move {
            let _ = stop.wait_for_drain().await;
            debug!("persistent client stop");
            // Close our connection
            drop(c1);
        });
    }

    async fn spawn_server(
        conn_count: Arc<AtomicU32>,
        drop_tx: UnboundedSender<()>,
        goaway: oneshot::Receiver<()>,
    ) -> SocketAddr {
        use http_body_util::Empty;
        // We'll bind to 127.0.0.1 on an ephemeral port (port 0)
        let addr = SocketAddr::from(([127, 0, 0, 1], 0));
        let test_cfg = test_config();
        async fn hello_world(
            req: Request<Incoming>,
        ) -> Result<Response<Empty<bytes::Bytes>>, Infallible> {
            debug!("hello world: received request");
            tokio::task::spawn(async move {
                match hyper::upgrade::on(req).await {
                    Ok(upgraded) => {
                        let mut io = hyper_util::rt::TokioIo::new(upgraded);
                        io.write_all(b"poolsrv\n").await.unwrap();
                        tcp::handle_stream(tcp::Mode::ReadWrite, &mut io).await;
                    }
                    Err(e) => panic!("No upgrade {e}"),
                }
                debug!("hello world: completed request");
            });
            Ok::<_, Infallible>(Response::new(http_body_util::Empty::<bytes::Bytes>::new()))
        }

        // We create a TcpListener and bind it to the ephemeral address above
        let listener = TcpListener::bind(addr).await.unwrap();
        let bound_addr = listener.local_addr().unwrap();

        let certs = crate::tls::mock::generate_test_certs(
            &Identity::default().into(),
            Duration::from_secs(0),
            Duration::from_secs(100),
        );
        let acceptor = crate::tls::mock::MockServerCertProvider::new(certs);
        let mut tls_stream = crate::hyper_util::tls_server(acceptor, listener);

        let mut goaway = Some(goaway);
        tokio::spawn(async move {
            // We start a loop to continuously accept incoming connections
            // and also count them
            let conn_count = conn_count.clone();
            let drop_tx = drop_tx.clone();
            let accept = async move {
                loop {
                    let goaway_rx = goaway.take();
                    let stream = tls_stream.next().await.unwrap();
                    conn_count.fetch_add(1, Ordering::SeqCst);
                    debug!("server stream started");
                    let drop_tx = drop_tx.clone();

                    let server = crate::hyper_util::http2_server()
                        .initial_stream_window_size(test_cfg.window_size)
                        .initial_connection_window_size(test_cfg.connection_window_size)
                        .max_frame_size(test_cfg.frame_size)
                        .max_header_list_size(65536)
                        .serve_connection(
                            hyper_util::rt::TokioIo::new(stream),
                            service_fn(hello_world),
                        );

                    // Spawn a tokio task to serve multiple connections concurrently
                    tokio::task::spawn(async move {
                        let recv = async move {
                            match goaway_rx {
                                Some(rx) => {
                                    let _ = rx.await;
                                }
                                None => futures_util::future::pending::<()>().await,
                            };
                        };
                        let res = match futures_util::future::select(Box::pin(recv), server).await {
                            futures_util::future::Either::Left((_shutdown, mut server)) => {
                                debug!("server drain starting... {_shutdown:?}");
                                let drain = std::pin::Pin::new(&mut server);
                                drain.graceful_shutdown();
                                let _res = server.await;
                                debug!("server drain done");
                                Ok(())
                            }
                            // Serving finished, just return the result.
                            futures_util::future::Either::Right((res, _shutdown)) => {
                                debug!("inbound serve done {:?}", res);
                                res
                            }
                        };
                        if let Err(err) = res {
                            error!("server failed: {err:?}");
                        }
                        let _ = drop_tx.send(());
                    });
                }
            };
            accept.await;
        });

        bound_addr
    }

    async fn setup_test(max_conns: u16) -> (WorkloadHBONEPool, TestServer) {
        setup_test_with_idle(max_conns, Duration::from_secs(100)).await
    }

    async fn setup_test_with_idle(
        max_conns: u16,
        idle: Duration,
    ) -> (WorkloadHBONEPool, TestServer) {
        initialize_telemetry();
        let conn_counter: Arc<AtomicU32> = Arc::new(AtomicU32::new(0));
        let (drop_tx, drop_rx) = tokio::sync::mpsc::unbounded_channel::<()>();
        let (goaway_tx, goaway_rx) = oneshot::channel::<()>();
        let addr = spawn_server(conn_counter.clone(), drop_tx, goaway_rx).await;

        let cfg = crate::config::Config {
            pool_max_streams_per_conn: max_conns,
            pool_unused_release_timeout: idle,
            ..crate::config::parse_config().unwrap()
        };
        let sock_fact = Arc::new(crate::proxy::DefaultSocketFactory::default());

        let mut state = ProxyState::new(None);
        let wl = Arc::new(workload::Workload {
            uid: "uid".into(),
            name: "source-workload".into(),
            namespace: "ns".into(),
            service_account: "default".into(),
            ..test_default_workload()
        });
        state.workloads.insert(wl.clone());
        let mut registry = Registry::default();
        let metrics = Arc::new(crate::proxy::Metrics::new(&mut registry));
        let mock_proxy_state = DemandProxyState::new(
            Arc::new(RwLock::new(state)),
            None,
            ResolverConfig::default(),
            ResolverOpts::default(),
            metrics,
        );
        let local_workload = Arc::new(proxy::LocalWorkloadInformation::new(
            Arc::new(WorkloadInfo {
                name: wl.name.to_string(),
                namespace: wl.namespace.to_string(),
                service_account: wl.service_account.to_string(),
            }),
            mock_proxy_state,
            identity::mock::new_secret_manager(Duration::from_secs(10)),
        ));
        let pool = WorkloadHBONEPool::new(Arc::new(cfg), sock_fact, local_workload);
        let server = TestServer {
            conn_counter,
            drop_rx,
            goaway_tx,
            addr,
        };
        (pool, server)
    }

    struct TestServer {
        conn_counter: Arc<AtomicU32>,
        drop_rx: UnboundedReceiver<()>,
        goaway_tx: oneshot::Sender<()>,
        addr: SocketAddr,
    }

    fn key(srv: &TestServer, ip: u8) -> WorkloadKey {
        WorkloadKey {
            src_id: Identity::default(),
            dst_id: vec![Identity::default()],
            src: IpAddr::from([127, 0, 0, ip]),
            dst: srv.addr,
        }
    }
}