Coverage Report

Created: 2025-12-31 06:22

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/rust/registry/src/index.crates.io-1949cf8c6b5b557f/mea-0.5.2/src/waitgroup/mod.rs
Line
Count
Source
1
// Copyright 2024 tison <wander4096@gmail.com>
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
//! A synchronization primitive for waiting on multiple tasks to complete.
16
//!
17
//! Similar to Go's WaitGroup, this type allows a task to wait for multiple other
18
//! tasks to finish. Each task holds a handle to the WaitGroup, and the main task
19
//! can wait for all handles to be dropped before proceeding.
20
//!
21
//! A WaitGroup waits for a collection of tasks to finish. The main task calls
22
//! [`clone()`] to create a new worker handle for each task, and can then wait
23
//! for all tasks to complete by calling `.await` on the WaitGroup.
24
//!
25
//! # Examples
26
//!
27
//! ```
28
//! # #[tokio::main]
29
//! # async fn main() {
30
//! use std::time::Duration;
31
//!
32
//! use mea::waitgroup::WaitGroup;
33
//! let wg = WaitGroup::new();
34
//!
35
//! for i in 0..3 {
36
//!     let wg = wg.clone();
37
//!     tokio::spawn(async move {
38
//!         println!("Task {} starting", i);
39
//!         tokio::time::sleep(Duration::from_millis(100)).await;
40
//!         // wg is automatically decremented when dropped
41
//!         drop(wg);
42
//!     });
43
//! }
44
//!
45
//! // Wait for all tasks to complete
46
//! wg.await;
47
//! println!("All tasks completed");
48
//! # }
49
//! ```
50
//!
51
//! [`clone()`]: WaitGroup::clone
52
53
use std::fmt;
54
use std::future::Future;
55
use std::future::IntoFuture;
56
use std::pin::Pin;
57
use std::sync::Arc;
58
use std::task::Context;
59
use std::task::Poll;
60
61
use crate::internal::CountdownState;
62
63
#[cfg(test)]
64
mod tests;
65
66
/// A synchronization primitive for waiting on multiple tasks to complete.
67
///
68
/// See the [module level documentation](self) for more.
69
pub struct WaitGroup {
70
    state: Arc<CountdownState>,
71
}
72
73
impl fmt::Debug for WaitGroup {
74
0
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75
0
        f.debug_struct("WaitGroup").finish_non_exhaustive()
76
0
    }
77
}
78
79
impl Default for WaitGroup {
80
0
    fn default() -> Self {
81
0
        Self::new()
82
0
    }
83
}
84
85
impl WaitGroup {
86
    /// Creates a new `WaitGroup`.
87
    ///
88
    /// # Examples
89
    ///
90
    /// ```
91
    /// use mea::waitgroup::WaitGroup;
92
    ///
93
    /// let wg = WaitGroup::new();
94
    /// ```
95
0
    pub fn new() -> Self {
96
0
        Self {
97
0
            state: Arc::new(CountdownState::new(1)),
98
0
        }
99
0
    }
100
}
101
102
impl Clone for WaitGroup {
103
    /// Creates a new worker handle for the WaitGroup.
104
    ///
105
    /// This increments the WaitGroup counter. The counter will be decremented
106
    /// when the new handle is dropped.
107
0
    fn clone(&self) -> Self {
108
0
        let sync = self.state.clone();
109
0
        let mut cnt = sync.state();
110
        loop {
111
0
            let new_cnt = cnt.saturating_add(1);
112
0
            match sync.cas_state(cnt, new_cnt) {
113
0
                Ok(_) => return Self { state: sync },
114
0
                Err(x) => cnt = x,
115
            }
116
        }
117
0
    }
118
}
119
120
impl Drop for WaitGroup {
121
0
    fn drop(&mut self) {
122
0
        if self.state.decrement(1) {
123
0
            self.state.wake_all();
124
0
        }
125
0
    }
126
}
127
128
impl IntoFuture for WaitGroup {
129
    type Output = ();
130
    type IntoFuture = Wait;
131
132
    /// Converts the WaitGroup into a future that completes when all tasks finish. This decreases
133
    /// the WaitGroup counter.
134
0
    fn into_future(self) -> Self::IntoFuture {
135
0
        let state = self.state.clone();
136
0
        drop(self);
137
0
        Wait { idx: None, state }
138
0
    }
139
}
140
141
/// A future that completes when all tasks in a WaitGroup have finished.
142
///
143
/// This type is created by either: (1) calling `.await` on a `WaitGroup`, or (2) cloning
144
/// itself, which does not increase the WaitGroup counter, but creates a new future that
145
/// will complete when the WaitGroup counter reaches zero.
146
#[must_use = "futures do nothing unless you `.await` or poll them"]
147
pub struct Wait {
148
    idx: Option<usize>,
149
    state: Arc<CountdownState>,
150
}
151
152
impl Clone for Wait {
153
    /// Creates a new future that also completes when the WaitGroup counter reaches zero.
154
    ///
155
    /// This does not increment the WaitGroup counter.
156
0
    fn clone(&self) -> Self {
157
0
        Wait {
158
0
            idx: None,
159
0
            state: self.state.clone(),
160
0
        }
161
0
    }
162
}
163
164
impl fmt::Debug for Wait {
165
0
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166
0
        f.debug_struct("Wait").finish_non_exhaustive()
167
0
    }
168
}
169
170
impl Future for Wait {
171
    type Output = ();
172
173
0
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
174
0
        let Self { idx, state } = self.get_mut();
175
176
        // register waker if the counter is not zero
177
0
        if state.spin_wait(16).is_err() {
178
0
            state.register_waker(idx, cx);
179
            // double check after register waker, to catch the update between two steps
180
0
            if state.spin_wait(0).is_err() {
181
0
                return Poll::Pending;
182
0
            }
183
0
        }
184
185
0
        Poll::Ready(())
186
0
    }
187
}