/rust/registry/src/index.crates.io-1949cf8c6b5b557f/tokio-1.48.0/src/sync/barrier.rs
Line | Count | Source |
1 | | use crate::loom::sync::Mutex; |
2 | | use crate::sync::watch; |
3 | | #[cfg(all(tokio_unstable, feature = "tracing"))] |
4 | | use crate::util::trace; |
5 | | |
6 | | /// A barrier enables multiple tasks to synchronize the beginning of some computation. |
7 | | /// |
8 | | /// ``` |
9 | | /// # #[tokio::main(flavor = "current_thread")] |
10 | | /// # async fn main() { |
11 | | /// use tokio::sync::Barrier; |
12 | | /// use std::sync::Arc; |
13 | | /// |
14 | | /// let mut handles = Vec::with_capacity(10); |
15 | | /// let barrier = Arc::new(Barrier::new(10)); |
16 | | /// for _ in 0..10 { |
17 | | /// let c = barrier.clone(); |
18 | | /// // The same messages will be printed together. |
19 | | /// // You will NOT see any interleaving. |
20 | | /// handles.push(tokio::spawn(async move { |
21 | | /// println!("before wait"); |
22 | | /// let wait_result = c.wait().await; |
23 | | /// println!("after wait"); |
24 | | /// wait_result |
25 | | /// })); |
26 | | /// } |
27 | | /// |
28 | | /// // Will not resolve until all "after wait" messages have been printed |
29 | | /// let mut num_leaders = 0; |
30 | | /// for handle in handles { |
31 | | /// let wait_result = handle.await.unwrap(); |
32 | | /// if wait_result.is_leader() { |
33 | | /// num_leaders += 1; |
34 | | /// } |
35 | | /// } |
36 | | /// |
37 | | /// // Exactly one barrier will resolve as the "leader" |
38 | | /// assert_eq!(num_leaders, 1); |
39 | | /// # } |
40 | | /// ``` |
41 | | #[derive(Debug)] |
42 | | pub struct Barrier { |
43 | | state: Mutex<BarrierState>, |
44 | | wait: watch::Receiver<usize>, |
45 | | n: usize, |
46 | | #[cfg(all(tokio_unstable, feature = "tracing"))] |
47 | | resource_span: tracing::Span, |
48 | | } |
49 | | |
50 | | #[derive(Debug)] |
51 | | struct BarrierState { |
52 | | waker: watch::Sender<usize>, |
53 | | arrived: usize, |
54 | | generation: usize, |
55 | | } |
56 | | |
57 | | impl Barrier { |
58 | | /// Creates a new barrier that can block a given number of tasks. |
59 | | /// |
60 | | /// A barrier will block `n`-1 tasks which call [`Barrier::wait`] and then wake up all |
61 | | /// tasks at once when the `n`th task calls `wait`. |
62 | | #[track_caller] |
63 | 0 | pub fn new(mut n: usize) -> Barrier { |
64 | 0 | let (waker, wait) = crate::sync::watch::channel(0); |
65 | | |
66 | 0 | if n == 0 { |
67 | 0 | // if n is 0, it's not clear what behavior the user wants. |
68 | 0 | // in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every |
69 | 0 | // .wait() immediately unblocks, so we adopt that here as well. |
70 | 0 | n = 1; |
71 | 0 | } |
72 | | |
73 | | #[cfg(all(tokio_unstable, feature = "tracing"))] |
74 | | let resource_span = { |
75 | | let location = std::panic::Location::caller(); |
76 | | let resource_span = tracing::trace_span!( |
77 | | parent: None, |
78 | | "runtime.resource", |
79 | | concrete_type = "Barrier", |
80 | | kind = "Sync", |
81 | | loc.file = location.file(), |
82 | | loc.line = location.line(), |
83 | | loc.col = location.column(), |
84 | | ); |
85 | | |
86 | | resource_span.in_scope(|| { |
87 | | tracing::trace!( |
88 | | target: "runtime::resource::state_update", |
89 | | size = n, |
90 | | ); |
91 | | |
92 | | tracing::trace!( |
93 | | target: "runtime::resource::state_update", |
94 | | arrived = 0, |
95 | | ) |
96 | | }); |
97 | | resource_span |
98 | | }; |
99 | | |
100 | 0 | Barrier { |
101 | 0 | state: Mutex::new(BarrierState { |
102 | 0 | waker, |
103 | 0 | arrived: 0, |
104 | 0 | generation: 1, |
105 | 0 | }), |
106 | 0 | n, |
107 | 0 | wait, |
108 | 0 | #[cfg(all(tokio_unstable, feature = "tracing"))] |
109 | 0 | resource_span, |
110 | 0 | } |
111 | 0 | } |
112 | | |
113 | | /// Does not resolve until all tasks have rendezvoused here. |
114 | | /// |
115 | | /// Barriers are re-usable after all tasks have rendezvoused once, and can |
116 | | /// be used continuously. |
117 | | /// |
118 | | /// A single (arbitrary) future will receive a [`BarrierWaitResult`] that returns `true` from |
119 | | /// [`BarrierWaitResult::is_leader`] when returning from this function, and all other tasks |
120 | | /// will receive a result that will return `false` from `is_leader`. |
121 | | /// |
122 | | /// # Cancel safety |
123 | | /// |
124 | | /// This method is not cancel safe. |
125 | 0 | pub async fn wait(&self) -> BarrierWaitResult { |
126 | | #[cfg(all(tokio_unstable, feature = "tracing"))] |
127 | | return trace::async_op( |
128 | | || self.wait_internal(), |
129 | | self.resource_span.clone(), |
130 | | "Barrier::wait", |
131 | | "poll", |
132 | | false, |
133 | | ) |
134 | | .await; |
135 | | |
136 | | #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] |
137 | 0 | return self.wait_internal().await; |
138 | 0 | } |
139 | 0 | async fn wait_internal(&self) -> BarrierWaitResult { |
140 | 0 | crate::trace::async_trace_leaf().await; |
141 | | |
142 | | // NOTE: we are taking a _synchronous_ lock here. |
143 | | // It is okay to do so because the critical section is fast and never yields, so it cannot |
144 | | // deadlock even if another future is concurrently holding the lock. |
145 | | // It is _desirable_ to do so as synchronous Mutexes are, at least in theory, faster than |
146 | | // the asynchronous counter-parts, so we should use them where possible [citation needed]. |
147 | | // NOTE: the extra scope here is so that the compiler doesn't think `state` is held across |
148 | | // a yield point, and thus marks the returned future as !Send. |
149 | 0 | let generation = { |
150 | 0 | let mut state = self.state.lock(); |
151 | 0 | let generation = state.generation; |
152 | 0 | state.arrived += 1; |
153 | | #[cfg(all(tokio_unstable, feature = "tracing"))] |
154 | | tracing::trace!( |
155 | | target: "runtime::resource::state_update", |
156 | | arrived = 1, |
157 | | arrived.op = "add", |
158 | | ); |
159 | | #[cfg(all(tokio_unstable, feature = "tracing"))] |
160 | | tracing::trace!( |
161 | | target: "runtime::resource::async_op::state_update", |
162 | | arrived = true, |
163 | | ); |
164 | 0 | if state.arrived == self.n { |
165 | | #[cfg(all(tokio_unstable, feature = "tracing"))] |
166 | | tracing::trace!( |
167 | | target: "runtime::resource::async_op::state_update", |
168 | | is_leader = true, |
169 | | ); |
170 | | // we are the leader for this generation |
171 | | // wake everyone, increment the generation, and return |
172 | 0 | state |
173 | 0 | .waker |
174 | 0 | .send(state.generation) |
175 | 0 | .expect("there is at least one receiver"); |
176 | 0 | state.arrived = 0; |
177 | 0 | state.generation += 1; |
178 | 0 | return BarrierWaitResult(true); |
179 | 0 | } |
180 | | |
181 | 0 | generation |
182 | | }; |
183 | | |
184 | | // we're going to have to wait for the last of the generation to arrive |
185 | 0 | let mut wait = self.wait.clone(); |
186 | | |
187 | | loop { |
188 | 0 | let _ = wait.changed().await; |
189 | | |
190 | | // note that the first time through the loop, this _will_ yield a generation |
191 | | // immediately, since we cloned a receiver that has never seen any values. |
192 | 0 | if *wait.borrow() >= generation { |
193 | 0 | break; |
194 | 0 | } |
195 | | } |
196 | | |
197 | 0 | BarrierWaitResult(false) |
198 | 0 | } |
199 | | } |
200 | | |
/// The value produced by `wait` once all tasks of a `Barrier` have rendezvoused.
#[derive(Debug, Clone)]
pub struct BarrierWaitResult(bool);

impl BarrierWaitResult {
    /// Tells whether the task holding this result is the "leader task" of its rendezvous.
    ///
    /// Exactly one task gets `true` here; every other task gets `false`.
    pub fn is_leader(&self) -> bool {
        let BarrierWaitResult(leader) = self;
        *leader
    }
}