Coverage Report

Created: 2025-02-25 06:39

/src/ztunnel/src/drain.rs
Line
Count
Source (jump to first uncovered line)
1
// Copyright Istio Authors
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use std::future::Future;
16
use std::time::Duration;
17
use tokio::sync::watch;
18
use tracing::{debug, info, warn};
19
20
pub use internal::DrainMode;
21
pub use internal::ReleaseShutdown as DrainBlocker;
22
pub use internal::Signal as DrainTrigger;
23
pub use internal::Watch as DrainWatcher;
24
25
/// New constructs a new pair for draining
26
/// * DrainTrigger can be used to start a draining sequence and wait for it to complete.
27
/// * DrainWatcher should be held by anything that wants to participate in the draining. This can be cloned,
28
///   and a drain will not complete until all outstanding DrainWatchers are dropped.
29
0
pub fn new() -> (DrainTrigger, DrainWatcher) {
30
0
    let (tx, rx) = internal::channel();
31
0
    (tx, rx)
32
0
}
33
34
/// run_with_drain provides a wrapper to run a future with graceful shutdown/draining support.
35
/// A caller should construct a future with takes two arguments:
36
/// * drain: while holding onto this, the future is marked as active, which will block the server from shutting down.
37
///   Additionally, it can be watched (with drain.signaled()) to see when to start a graceful shutdown.
38
/// * force_shutdown: when this is triggered, the future must forcefully shutdown any ongoing work ASAP.
39
///   This means the graceful drain exceeded the hard deadline, and all work must terminate now.
40
///   This is only required for spawned() tasks; otherwise, the future is dropped entirely, canceling all work.
41
0
pub async fn run_with_drain<F, Fut, O>(
42
0
    component: String,
43
0
    drain: DrainWatcher,
44
0
    deadline: Duration,
45
0
    make_future: F,
46
0
) where
47
0
    F: FnOnce(DrainWatcher, watch::Receiver<()>) -> Fut,
48
0
    Fut: Future<Output = O>,
49
0
    O: Send + 'static,
50
0
{
Unexecuted instantiation: ztunnel::drain::run_with_drain::<<ztunnel::proxy::inbound_passthrough::InboundPassthrough>::run::{closure#0}::{closure#0}, tracing::instrument::Instrumented<<ztunnel::proxy::inbound_passthrough::InboundPassthrough>::run::{closure#0}::{closure#0}::{closure#0}>, ()>
Unexecuted instantiation: ztunnel::drain::run_with_drain::<<ztunnel::proxy::socks5::Socks5>::run::{closure#0}::{closure#0}, <ztunnel::proxy::socks5::Socks5>::run::{closure#0}::{closure#0}::{closure#0}, ()>
Unexecuted instantiation: ztunnel::drain::run_with_drain::<<ztunnel::proxy::inbound::Inbound>::run::{closure#0}::{closure#0}, tracing::instrument::Instrumented<<ztunnel::proxy::inbound::Inbound>::run::{closure#0}::{closure#0}::{closure#0}>, ()>
Unexecuted instantiation: ztunnel::drain::run_with_drain::<<ztunnel::proxy::outbound::Outbound>::run::{closure#0}::{closure#0}, tracing::instrument::Instrumented<<ztunnel::proxy::outbound::Outbound>::run::{closure#0}::{closure#0}::{closure#0}>, ()>
51
0
    let (sub_drain_signal, sub_drain) = new();
52
0
    let (trigger_force_shutdown, force_shutdown) = watch::channel(());
53
0
    // Stop accepting once we drain.
54
0
    // We will then allow connections up to `deadline` to terminate on their own.
55
0
    // After that, they will be forcefully terminated.
56
0
    let fut = make_future(sub_drain, force_shutdown);
57
0
    tokio::select! {
58
0
        _res = fut => {}
59
0
        res = drain.wait_for_drain() => {
60
0
            if res.mode() == DrainMode::Graceful {
61
0
                debug!(component, "drain started, waiting {:?} for any connections to complete", deadline);
62
0
                if tokio::time::timeout(deadline, sub_drain_signal.start_drain_and_wait(DrainMode::Graceful)).await.is_err() {
63
                    // Not all connections completed within time, we will force shut them down
64
0
                    warn!(component, "drain duration expired with pending connections, forcefully shutting down");
65
0
                }
66
            } else {
67
0
                debug!(component, "terminating");
68
            }
69
            // Trigger force shutdown. In theory, this is only needed in the timeout case. However,
70
            // it doesn't hurt to always trigger it.
71
0
            let _ = trigger_force_shutdown.send(());
72
0
73
0
            info!(component, "shutdown complete");
74
0
            drop(res);
75
        }
76
    };
77
0
}
Unexecuted instantiation: ztunnel::drain::run_with_drain::<<ztunnel::proxy::inbound_passthrough::InboundPassthrough>::run::{closure#0}::{closure#0}, tracing::instrument::Instrumented<<ztunnel::proxy::inbound_passthrough::InboundPassthrough>::run::{closure#0}::{closure#0}::{closure#0}>, ()>::{closure#0}
Unexecuted instantiation: ztunnel::drain::run_with_drain::<<ztunnel::proxy::socks5::Socks5>::run::{closure#0}::{closure#0}, <ztunnel::proxy::socks5::Socks5>::run::{closure#0}::{closure#0}::{closure#0}, ()>::{closure#0}
Unexecuted instantiation: ztunnel::drain::run_with_drain::<<ztunnel::proxy::inbound::Inbound>::run::{closure#0}::{closure#0}, tracing::instrument::Instrumented<<ztunnel::proxy::inbound::Inbound>::run::{closure#0}::{closure#0}::{closure#0}>, ()>::{closure#0}
Unexecuted instantiation: ztunnel::drain::run_with_drain::<<ztunnel::proxy::outbound::Outbound>::run::{closure#0}::{closure#0}, tracing::instrument::Instrumented<<ztunnel::proxy::outbound::Outbound>::run::{closure#0}::{closure#0}::{closure#0}>, ()>::{closure#0}
78
79
mod internal {
80
    use tokio::sync::{mpsc, watch};
81
82
    /// Creates a drain channel.
83
    ///
84
    /// The `Signal` is used to start a drain, and the `Watch` will be notified
85
    /// when a drain is signaled.
86
0
    pub fn channel() -> (Signal, Watch) {
87
0
        let (signal_tx, signal_rx) = watch::channel(None);
88
0
        let (drained_tx, drained_rx) = mpsc::channel(1);
89
0
90
0
        let signal = Signal {
91
0
            drained_rx,
92
0
            signal_tx,
93
0
        };
94
0
        let watch = Watch {
95
0
            drained_tx,
96
0
            signal_rx,
97
0
        };
98
0
        (signal, watch)
99
0
    }
100
101
    enum Never {}
102
103
    #[derive(Debug, Clone, Copy, PartialEq)]
104
    pub enum DrainMode {
105
        Immediate,
106
        Graceful,
107
    }
108
109
    /// Send a drain command to all watchers.
110
    pub struct Signal {
111
        drained_rx: mpsc::Receiver<Never>,
112
        signal_tx: watch::Sender<Option<DrainMode>>,
113
    }
114
115
    /// Watch for a drain command.
116
    ///
117
    /// All `Watch` instances must be dropped for a `Signal::signal` call to
118
    /// complete.
119
    #[derive(Clone)]
120
    pub struct Watch {
121
        drained_tx: mpsc::Sender<Never>,
122
        signal_rx: watch::Receiver<Option<DrainMode>>,
123
    }
124
125
    #[must_use = "ReleaseShutdown should be dropped explicitly to release the runtime"]
126
    #[derive(Clone)]
127
    #[allow(dead_code)]
128
    pub struct ReleaseShutdown(mpsc::Sender<Never>, DrainMode);
129
130
    impl ReleaseShutdown {
131
0
        pub fn mode(&self) -> DrainMode {
132
0
            self.1
133
0
        }
134
    }
135
136
    impl Signal {
137
        /// Waits for all [`Watch`] instances to be dropped.
138
0
        pub async fn closed(&mut self) {
139
0
            self.signal_tx.closed().await;
140
0
        }
141
142
        /// Asynchronously signals all watchers to begin draining gracefully and waits for all
143
        /// handles to be dropped.
144
0
        pub async fn start_drain_and_wait(mut self, mode: DrainMode) {
145
0
            // Update the state of the signal watch so that all watchers are observe
146
0
            // the change.
147
0
            let _ = self.signal_tx.send(Some(mode));
148
0
149
0
            // Wait for all watchers to release their drain handle.
150
0
            match self.drained_rx.recv().await {
151
0
                None => {}
152
0
                Some(n) => match n {},
153
0
            }
154
0
        }
155
    }
156
157
    impl Watch {
158
        /// Returns a `ReleaseShutdown` handle after the drain has been signaled. The
159
        /// handle must be dropped when a shutdown action has been completed to
160
        /// unblock graceful shutdown.
161
0
        pub async fn wait_for_drain(mut self) -> ReleaseShutdown {
162
            // This future completes once `Signal::signal` has been invoked so that
163
            // the channel's state is updated.
164
0
            let mode = self
165
0
                .signal_rx
166
0
                .wait_for(Option::is_some)
167
0
                .await
168
0
                .map(|mode| mode.expect("already asserted it is_some"))
169
0
                // If we got an error, then the signal was dropped entirely. Presumably this means a graceful shutdown is not required.
170
0
                .unwrap_or(DrainMode::Immediate);
171
0
172
0
            // Return a handle that holds the drain channel, so that the signal task
173
0
            // is only notified when all handles have been dropped.
174
0
            ReleaseShutdown(self.drained_tx, mode)
175
0
        }
176
    }
177
178
    impl std::fmt::Debug for Signal {
179
0
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
180
0
            f.debug_struct("Signal").finish_non_exhaustive()
181
0
        }
182
    }
183
184
    impl std::fmt::Debug for Watch {
185
0
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
186
0
            f.debug_struct("Watch").finish_non_exhaustive()
187
0
        }
188
    }
189
190
    impl std::fmt::Debug for ReleaseShutdown {
191
0
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192
0
            f.debug_struct("ReleaseShutdown").finish_non_exhaustive()
193
0
        }
194
    }
195
}