/rust/registry/src/index.crates.io-1949cf8c6b5b557f/mea-0.5.2/src/waitgroup/mod.rs
Line | Count | Source |
1 | | // Copyright 2024 tison <wander4096@gmail.com> |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | //! A synchronization primitive for waiting on multiple tasks to complete. |
16 | | //! |
17 | | //! Similar to Go's WaitGroup, this type allows a task to wait for multiple other |
18 | | //! tasks to finish. Each task holds a handle to the WaitGroup, and the main task |
19 | | //! can wait for all handles to be dropped before proceeding. |
20 | | //! |
21 | | //! A WaitGroup waits for a collection of tasks to finish. The main task calls |
22 | | //! [`clone()`] to create a new worker handle for each task, and can then wait |
23 | | //! for all tasks to complete by calling `.await` on the WaitGroup. |
24 | | //! |
25 | | //! # Examples |
26 | | //! |
27 | | //! ``` |
28 | | //! # #[tokio::main] |
29 | | //! # async fn main() { |
30 | | //! use std::time::Duration; |
31 | | //! |
32 | | //! use mea::waitgroup::WaitGroup; |
33 | | //! let wg = WaitGroup::new(); |
34 | | //! |
35 | | //! for i in 0..3 { |
36 | | //! let wg = wg.clone(); |
37 | | //! tokio::spawn(async move { |
38 | | //! println!("Task {} starting", i); |
39 | | //! tokio::time::sleep(Duration::from_millis(100)).await; |
40 | | //! // the WaitGroup counter is decremented automatically when `wg` is dropped |
41 | | //! drop(wg); |
42 | | //! }); |
43 | | //! } |
44 | | //! |
45 | | //! // Wait for all tasks to complete |
46 | | //! wg.await; |
47 | | //! println!("All tasks completed"); |
48 | | //! # } |
49 | | //! ``` |
50 | | //! |
51 | | //! [`clone()`]: WaitGroup::clone |
52 | | |
53 | | use std::fmt; |
54 | | use std::future::Future; |
55 | | use std::future::IntoFuture; |
56 | | use std::pin::Pin; |
57 | | use std::sync::Arc; |
58 | | use std::task::Context; |
59 | | use std::task::Poll; |
60 | | |
61 | | use crate::internal::CountdownState; |
62 | | |
63 | | #[cfg(test)] |
64 | | mod tests; |
65 | | |
66 | | /// A synchronization primitive for waiting on multiple tasks to complete. |
67 | | /// |
68 | | /// See the [module level documentation](self) for more. |
69 | | pub struct WaitGroup { |
70 | | state: Arc<CountdownState>, |
71 | | } |
72 | | |
73 | | impl fmt::Debug for WaitGroup { |
74 | 0 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
75 | 0 | f.debug_struct("WaitGroup").finish_non_exhaustive() |
76 | 0 | } |
77 | | } |
78 | | |
79 | | impl Default for WaitGroup { |
80 | 0 | fn default() -> Self { |
81 | 0 | Self::new() |
82 | 0 | } |
83 | | } |
84 | | |
85 | | impl WaitGroup { |
86 | | /// Creates a new `WaitGroup`. |
87 | | /// |
88 | | /// # Examples |
89 | | /// |
90 | | /// ``` |
91 | | /// use mea::waitgroup::WaitGroup; |
92 | | /// |
93 | | /// let wg = WaitGroup::new(); |
94 | | /// ``` |
95 | 0 | pub fn new() -> Self { |
96 | 0 | Self { |
97 | 0 | state: Arc::new(CountdownState::new(1)), |
98 | 0 | } |
99 | 0 | } |
100 | | } |
101 | | |
102 | | impl Clone for WaitGroup { |
103 | | /// Creates a new worker handle for the WaitGroup. |
104 | | /// |
105 | | /// This increments the WaitGroup counter. The counter will be decremented |
106 | | /// when the new handle is dropped. |
107 | 0 | fn clone(&self) -> Self { |
108 | 0 | let sync = self.state.clone(); |
109 | 0 | let mut cnt = sync.state(); |
110 | | loop { |
111 | 0 | let new_cnt = cnt.saturating_add(1); |
112 | 0 | match sync.cas_state(cnt, new_cnt) { |
113 | 0 | Ok(_) => return Self { state: sync }, |
114 | 0 | Err(x) => cnt = x, |
115 | | } |
116 | | } |
117 | 0 | } |
118 | | } |
119 | | |
120 | | impl Drop for WaitGroup { |
121 | 0 | fn drop(&mut self) { |
122 | 0 | if self.state.decrement(1) { |
123 | 0 | self.state.wake_all(); |
124 | 0 | } |
125 | 0 | } |
126 | | } |
127 | | |
128 | | impl IntoFuture for WaitGroup { |
129 | | type Output = (); |
130 | | type IntoFuture = Wait; |
131 | | |
132 | | /// Converts the WaitGroup into a future that completes when all tasks finish. This decreases |
133 | | /// the WaitGroup counter. |
134 | 0 | fn into_future(self) -> Self::IntoFuture { |
135 | 0 | let state = self.state.clone(); |
136 | 0 | drop(self); |
137 | 0 | Wait { idx: None, state } |
138 | 0 | } |
139 | | } |
140 | | |
141 | | /// A future that completes when all tasks in a WaitGroup have finished. |
142 | | /// |
143 | | /// This type is created by either: (1) calling `.await` on a `WaitGroup`, or (2) cloning |
144 | | /// itself, which does not increase the WaitGroup counter, but creates a new future that |
145 | | /// will complete when the WaitGroup counter reaches zero. |
146 | | #[must_use = "futures do nothing unless you `.await` or poll them"] |
147 | | pub struct Wait { |
148 | | idx: Option<usize>, |
149 | | state: Arc<CountdownState>, |
150 | | } |
151 | | |
152 | | impl Clone for Wait { |
153 | | /// Creates a new future that also completes when the WaitGroup counter reaches zero. |
154 | | /// |
155 | | /// This does not increment the WaitGroup counter. |
156 | 0 | fn clone(&self) -> Self { |
157 | 0 | Wait { |
158 | 0 | idx: None, |
159 | 0 | state: self.state.clone(), |
160 | 0 | } |
161 | 0 | } |
162 | | } |
163 | | |
164 | | impl fmt::Debug for Wait { |
165 | 0 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
166 | 0 | f.debug_struct("Wait").finish_non_exhaustive() |
167 | 0 | } |
168 | | } |
169 | | |
170 | | impl Future for Wait { |
171 | | type Output = (); |
172 | | |
173 | 0 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
174 | 0 | let Self { idx, state } = self.get_mut(); |
175 | | |
176 | | // register waker if the counter is not zero |
177 | 0 | if state.spin_wait(16).is_err() { |
178 | 0 | state.register_waker(idx, cx); |
179 | | // double check after register waker, to catch the update between two steps |
180 | 0 | if state.spin_wait(0).is_err() { |
181 | 0 | return Poll::Pending; |
182 | 0 | } |
183 | 0 | } |
184 | | |
185 | 0 | Poll::Ready(()) |
186 | 0 | } |
187 | | } |