Coverage Report

Created: 2025-10-29 07:05

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/rust/registry/src/index.crates.io-1949cf8c6b5b557f/hyper-util-0.1.10/src/server/graceful.rs
Line
Count
Source
1
//! Utility to gracefully shutdown a server.
2
//!
3
//! This module provides a [`GracefulShutdown`] type,
4
//! which can be used to gracefully shutdown a server.
5
//!
6
//! See <https://github.com/hyperium/hyper-util/blob/master/examples/server_graceful.rs>
7
//! for an example of how to use this.
8
9
use std::{
10
    fmt::{self, Debug},
11
    future::Future,
12
    pin::Pin,
13
    task::{self, Poll},
14
};
15
16
use pin_project_lite::pin_project;
17
use tokio::sync::watch;
18
19
/// A graceful shutdown utility
20
pub struct GracefulShutdown {
21
    tx: watch::Sender<()>,
22
}
23
24
impl GracefulShutdown {
25
    /// Create a new graceful shutdown helper.
26
0
    pub fn new() -> Self {
27
0
        let (tx, _) = watch::channel(());
28
0
        Self { tx }
29
0
    }
30
31
    /// Wrap a future for graceful shutdown watching.
32
0
    pub fn watch<C: GracefulConnection>(&self, conn: C) -> impl Future<Output = C::Output> {
33
0
        let mut rx = self.tx.subscribe();
34
0
        GracefulConnectionFuture::new(conn, async move {
35
0
            let _ = rx.changed().await;
36
            // hold onto the rx until the watched future is completed
37
0
            rx
38
0
        })
39
0
    }
40
41
    /// Signal shutdown for all watched connections.
42
    ///
43
    /// This returns a `Future` which will complete once all watched
44
    /// connections have shutdown.
45
0
    pub async fn shutdown(self) {
46
0
        let Self { tx } = self;
47
48
        // signal all the watched futures about the change
49
0
        let _ = tx.send(());
50
        // and then wait for all of them to complete
51
0
        tx.closed().await;
52
0
    }
53
}
54
55
impl Debug for GracefulShutdown {
56
0
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57
0
        f.debug_struct("GracefulShutdown").finish()
58
0
    }
59
}
60
61
impl Default for GracefulShutdown {
62
0
    fn default() -> Self {
63
0
        Self::new()
64
0
    }
65
}
66
67
pin_project! {
68
    struct GracefulConnectionFuture<C, F: Future> {
69
        #[pin]
70
        conn: C,
71
        #[pin]
72
        cancel: F,
73
        #[pin]
74
        // If cancelled, this is held until the inner conn is done.
75
        cancelled_guard: Option<F::Output>,
76
    }
77
}
78
79
impl<C, F: Future> GracefulConnectionFuture<C, F> {
80
0
    fn new(conn: C, cancel: F) -> Self {
81
0
        Self {
82
0
            conn,
83
0
            cancel,
84
0
            cancelled_guard: None,
85
0
        }
86
0
    }
87
}
88
89
impl<C, F: Future> Debug for GracefulConnectionFuture<C, F> {
90
0
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
91
0
        f.debug_struct("GracefulConnectionFuture").finish()
92
0
    }
93
}
94
95
impl<C, F> Future for GracefulConnectionFuture<C, F>
96
where
97
    C: GracefulConnection,
98
    F: Future,
99
{
100
    type Output = C::Output;
101
102
0
    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
103
0
        let mut this = self.project();
104
0
        if this.cancelled_guard.is_none() {
105
0
            if let Poll::Ready(guard) = this.cancel.poll(cx) {
106
0
                this.cancelled_guard.set(Some(guard));
107
0
                this.conn.as_mut().graceful_shutdown();
108
0
            }
109
0
        }
110
0
        this.conn.poll(cx)
111
0
    }
112
}
113
114
/// An internal utility trait as an umbrella target for all (hyper) connection
115
/// types that the [`GracefulShutdown`] can watch.
116
pub trait GracefulConnection: Future<Output = Result<(), Self::Error>> + private::Sealed {
117
    /// The error type returned by the connection when used as a future.
118
    type Error;
119
120
    /// Start a graceful shutdown process for this connection.
121
    fn graceful_shutdown(self: Pin<&mut Self>);
122
}
123
124
// Each impl below forwards to the connection type's own inherent
// `graceful_shutdown`. The explicit `<Type>::method` path resolves to the
// inherent method (inherent items take precedence over trait items), so the
// trait method never calls itself.

#[cfg(feature = "http1")]
impl<I, B, S> GracefulConnection for hyper::server::conn::http1::Connection<I, S>
where
    I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
    S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    type Error = hyper::Error;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        <hyper::server::conn::http1::Connection<I, S>>::graceful_shutdown(self)
    }
}

#[cfg(feature = "http2")]
impl<I, B, S, E> GracefulConnection for hyper::server::conn::http2::Connection<I, S, E>
where
    I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
    S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
{
    type Error = hyper::Error;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        <hyper::server::conn::http2::Connection<I, S, E>>::graceful_shutdown(self)
    }
}

#[cfg(feature = "server-auto")]
impl<'a, I, B, S, E> GracefulConnection for crate::server::conn::auto::Connection<'a, I, S, E>
where
    I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
    S: hyper::service::Service<http::Request<hyper::body::Incoming>, Response = http::Response<B>>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    S::Future: 'static,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
{
    type Error = Box<dyn std::error::Error + Send + Sync>;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        <crate::server::conn::auto::Connection<'a, I, S, E>>::graceful_shutdown(self)
    }
}

#[cfg(feature = "server-auto")]
impl<'a, I, B, S, E> GracefulConnection
    for crate::server::conn::auto::UpgradeableConnection<'a, I, S, E>
where
    // Upgrades move the IO into another task, hence the extra `Send` bound.
    I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
    S: hyper::service::Service<http::Request<hyper::body::Incoming>, Response = http::Response<B>>,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    S::Future: 'static,
    B: hyper::body::Body + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
{
    type Error = Box<dyn std::error::Error + Send + Sync>;

    fn graceful_shutdown(self: Pin<&mut Self>) {
        <crate::server::conn::auto::UpgradeableConnection<'a, I, S, E>>::graceful_shutdown(self)
    }
}
193
194
mod private {
    /// Seals `GracefulConnection`: only the connection types listed in this
    /// module can implement it.
    pub trait Sealed {}

    #[cfg(feature = "http1")]
    impl<I, B, S> Sealed for hyper::server::conn::http1::Connection<I, S>
    where
        I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
        S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
    }

    #[cfg(feature = "http1")]
    impl<I, B, S> Sealed for hyper::server::conn::http1::UpgradeableConnection<I, S>
    where
        I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
        S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
    }

    #[cfg(feature = "http2")]
    impl<I, B, S, E> Sealed for hyper::server::conn::http2::Connection<I, S, E>
    where
        I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
        S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
    {
    }

    #[cfg(feature = "server-auto")]
    impl<'a, I, B, S, E> Sealed for crate::server::conn::auto::Connection<'a, I, S, E>
    where
        I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
        S: hyper::service::Service<
            http::Request<hyper::body::Incoming>,
            Response = http::Response<B>,
        >,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        S::Future: 'static,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
    {
    }

    #[cfg(feature = "server-auto")]
    impl<'a, I, B, S, E> Sealed for crate::server::conn::auto::UpgradeableConnection<'a, I, S, E>
    where
        I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
        S: hyper::service::Service<
            http::Request<hyper::body::Incoming>,
            Response = http::Response<B>,
        >,
        S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        S::Future: 'static,
        B: hyper::body::Body + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
    {
    }
}
263
264
#[cfg(test)]
mod test {
    use super::*;
    use pin_project_lite::pin_project;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    pin_project! {
        // Test double: finishes when `future` finishes and counts how many
        // times a graceful shutdown was requested.
        #[derive(Debug)]
        struct DummyConnection<F> {
            #[pin]
            future: F,
            shutdown_counter: Arc<AtomicUsize>,
        }
    }

    impl<F> private::Sealed for DummyConnection<F> {}

    impl<F: Future> GracefulConnection for DummyConnection<F> {
        type Error = ();

        fn graceful_shutdown(self: Pin<&mut Self>) {
            // Only record the request; the wrapped future decides when to finish.
            self.shutdown_counter.fetch_add(1, Ordering::SeqCst);
        }
    }

    impl<F: Future> Future for DummyConnection<F> {
        type Output = Result<(), ()>;

        fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
            // The inner output is discarded; the dummy connection always succeeds.
            self.project().future.poll(cx).map(|_| Ok(()))
        }
    }

    /// Wraps `future` in a `DummyConnection`, registers it with `graceful`, and
    /// drives it to completion on a spawned task (panicking if it errors).
    #[cfg(not(miri))]
    fn spawn_watched<F>(graceful: &GracefulShutdown, future: F, counter: Arc<AtomicUsize>)
    where
        F: Future + Send + 'static,
    {
        let conn = graceful.watch(DummyConnection {
            future,
            shutdown_counter: counter,
        });
        tokio::spawn(async move {
            conn.await.unwrap();
        });
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_graceful_shutdown_ok() {
        let graceful = GracefulShutdown::new();
        let shutdown_counter = Arc::new(AtomicUsize::new(0));
        let (dummy_tx, _) = tokio::sync::broadcast::channel(1);

        for i in 1..=3 {
            let mut dummy_rx = dummy_tx.subscribe();
            spawn_watched(
                &graceful,
                async move {
                    tokio::time::sleep(Duration::from_millis(i * 10)).await;
                    let _ = dummy_rx.recv().await;
                },
                Arc::clone(&shutdown_counter),
            );
        }

        assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);
        let _ = dummy_tx.send(());

        tokio::select! {
            _ = tokio::time::sleep(Duration::from_millis(100)) => {
                panic!("timeout")
            },
            _ = graceful.shutdown() => {
                assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3);
            }
        }
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_graceful_shutdown_delayed_ok() {
        let graceful = GracefulShutdown::new();
        let shutdown_counter = Arc::new(AtomicUsize::new(0));

        for i in 1..=3 {
            spawn_watched(
                &graceful,
                async move {
                    tokio::time::sleep(Duration::from_millis(i * 50)).await;
                },
                Arc::clone(&shutdown_counter),
            );
        }

        assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);

        tokio::select! {
            _ = tokio::time::sleep(Duration::from_millis(200)) => {
                panic!("timeout")
            },
            _ = graceful.shutdown() => {
                assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3);
            }
        }
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_graceful_shutdown_multi_per_watcher_ok() {
        let graceful = GracefulShutdown::new();
        let shutdown_counter = Arc::new(AtomicUsize::new(0));

        // Task `i` drives `i` watched connections, six connections in total.
        for i in 1..=3 {
            let futures: Vec<_> = (1..=i)
                .map(|u| {
                    graceful.watch(DummyConnection {
                        future: tokio::time::sleep(Duration::from_millis(u * 50)),
                        shutdown_counter: Arc::clone(&shutdown_counter),
                    })
                })
                .collect();
            tokio::spawn(async move {
                futures_util::future::join_all(futures).await;
            });
        }

        assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);

        tokio::select! {
            _ = tokio::time::sleep(Duration::from_millis(200)) => {
                panic!("timeout")
            },
            _ = graceful.shutdown() => {
                assert_eq!(shutdown_counter.load(Ordering::SeqCst), 6);
            }
        }
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_graceful_shutdown_timeout() {
        let graceful = GracefulShutdown::new();
        let shutdown_counter = Arc::new(AtomicUsize::new(0));

        // The first connection never finishes, so `shutdown` must not complete.
        for i in 1..=3 {
            spawn_watched(
                &graceful,
                async move {
                    if i == 1 {
                        std::future::pending::<()>().await
                    } else {
                        std::future::ready(()).await
                    }
                },
                Arc::clone(&shutdown_counter),
            );
        }

        assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0);

        tokio::select! {
            _ = tokio::time::sleep(Duration::from_millis(100)) => {
                assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3);
            },
            _ = graceful.shutdown() => {
                panic!("shutdown should not be completed: as not all our conns finish")
            }
        }
    }
}