/rust/registry/src/index.crates.io-1949cf8c6b5b557f/mea-0.6.0/src/mpsc/bounded.rs
Line | Count | Source |
1 | | // Copyright 2024 tison <wander4096@gmail.com> |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | //! A bounded multi-producer, single-consumer queue for sending values between asynchronous |
16 | | //! tasks with backpressure control. |
17 | | |
18 | | use std::fmt; |
19 | | use std::future::Future; |
20 | | use std::future::poll_fn; |
21 | | use std::pin::pin; |
22 | | use std::sync::Arc; |
23 | | use std::sync::atomic::AtomicUsize; |
24 | | use std::sync::atomic::Ordering; |
25 | | use std::task::Context; |
26 | | use std::task::Poll; |
27 | | use std::task::Waker; |
28 | | |
29 | | use crate::atomicbox::AtomicOptionBox; |
30 | | use crate::internal::Acquire; |
31 | | use crate::internal::Semaphore; |
32 | | use crate::mpsc::RecvError; |
33 | | use crate::mpsc::SendError; |
34 | | use crate::mpsc::TryRecvError; |
35 | | use crate::mpsc::error::TrySendError; |
36 | | |
37 | | /// Creates a bounded mpsc channel for communicating between asynchronous |
38 | | /// tasks with backpressure. |
39 | | /// |
40 | | /// A `send` on this channel will wait if the buffer of the channel is full until a |
41 | | /// `recv` is called on the receiver, which will consume the message and |
42 | | /// free up space in the buffer. |
43 | | #[track_caller] |
44 | 0 | pub fn bounded<T>(buffer: usize) -> (BoundedSender<T>, BoundedReceiver<T>) { |
45 | 0 | assert!(buffer > 0, "mpsc bounded channel requires buffer > 0"); |
46 | 0 | let state = Arc::new(BoundedState { |
47 | 0 | senders: AtomicUsize::new(1), |
48 | 0 | tx_permits: Semaphore::new(0), |
49 | 0 | rx_task: AtomicOptionBox::none(), |
50 | 0 | }); |
51 | 0 | let (sender, receiver) = std::sync::mpsc::sync_channel(buffer); |
52 | 0 | let sender = BoundedSender { |
53 | 0 | state: state.clone(), |
54 | 0 | sender: Some(sender), |
55 | 0 | }; |
56 | 0 | let receiver = BoundedReceiver { |
57 | 0 | state: state.clone(), |
58 | 0 | receiver: Some(receiver), |
59 | 0 | }; |
60 | 0 | (sender, receiver) |
61 | 0 | } |
62 | | |
63 | | struct BoundedState { |
64 | | senders: AtomicUsize, |
65 | | tx_permits: Semaphore, |
66 | | rx_task: AtomicOptionBox<Waker>, |
67 | | } |
68 | | |
69 | | /// Send values to the associated [`BoundedReceiver`]. |
70 | | /// |
71 | | /// Instances are created by the [`bounded`] function. |
72 | | pub struct BoundedSender<T> { |
73 | | state: Arc<BoundedState>, |
74 | | sender: Option<std::sync::mpsc::SyncSender<T>>, |
75 | | } |
76 | | |
77 | | impl<T> Clone for BoundedSender<T> { |
78 | 0 | fn clone(&self) -> Self { |
79 | 0 | self.state.senders.fetch_add(1, Ordering::Release); |
80 | 0 | BoundedSender { |
81 | 0 | state: self.state.clone(), |
82 | 0 | sender: self.sender.clone(), |
83 | 0 | } |
84 | 0 | } |
85 | | } |
86 | | |
87 | | impl<T> fmt::Debug for BoundedSender<T> { |
88 | 0 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { |
89 | 0 | fmt.debug_struct("BoundedSender").finish_non_exhaustive() |
90 | 0 | } |
91 | | } |
92 | | |
93 | | impl<T> Drop for BoundedSender<T> { |
94 | 0 | fn drop(&mut self) { |
95 | | // drop the sender; this closes the channel if it is the last sender |
96 | 0 | drop(self.sender.take()); |
97 | | |
98 | 0 | match self.state.senders.fetch_sub(1, Ordering::AcqRel) { |
99 | | 1 => { |
100 | | // If this is the last sender, we need to wake up the receiver so it can |
101 | | // observe the disconnected state. |
102 | 0 | if let Some(waker) = self.state.rx_task.take() { |
103 | 0 | waker.wake(); |
104 | 0 | } |
105 | | } |
106 | 0 | _ => { |
107 | 0 | // there are still other senders left, do nothing |
108 | 0 | } |
109 | | } |
110 | 0 | } |
111 | | } |
112 | | |
113 | | impl<T> BoundedSender<T> { |
114 | | /// Attempts to send a message to the associated receiver. |
115 | | /// |
116 | | /// This method will wait if the buffer of the channel is full until a `recv` is called on the |
117 | | /// receiver, which will consume the message and free up space in the buffer. |
118 | | /// |
119 | | /// If the receiver has been dropped, this function returns an error. The error includes |
120 | | /// the value passed to `send`. |
121 | 0 | pub async fn send(&self, value: T) -> Result<(), SendError<T>> { |
122 | 0 | let value = match self.try_send(value) { |
123 | 0 | Ok(()) => return Ok(()), |
124 | 0 | Err(TrySendError::Disconnected(value)) => return Err(SendError::new(value)), |
125 | 0 | Err(TrySendError::Full(value)) => value, |
126 | | }; |
127 | | |
128 | | struct SendState<'a, T> { |
129 | | sender: &'a BoundedSender<T>, |
130 | | value: Option<T>, |
131 | | acquire: Acquire<'a>, |
132 | | } |
133 | | |
134 | | impl<T> SendState<'_, T> { |
135 | 0 | fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), SendError<T>>> { |
136 | 0 | let mut value = match self.value.take() { |
137 | 0 | Some(value) => value, |
138 | 0 | None => return Poll::Ready(Ok(())), |
139 | | }; |
140 | | |
141 | | loop { |
142 | 0 | let poll = pin!(&mut self.acquire).poll(cx); |
143 | | |
144 | 0 | value = match self.sender.try_send(value) { |
145 | 0 | Ok(()) => return Poll::Ready(Ok(())), |
146 | 0 | Err(TrySendError::Disconnected(value)) => { |
147 | 0 | return Poll::Ready(Err(SendError::new(value))); |
148 | | } |
149 | 0 | Err(TrySendError::Full(value)) => value, |
150 | | }; |
151 | | |
152 | 0 | if poll.is_ready() { |
153 | 0 | self.acquire = self.sender.state.tx_permits.poll_acquire(1); |
154 | 0 | } else { |
155 | 0 | self.value = Some(value); |
156 | 0 | return Poll::Pending; |
157 | | } |
158 | | } |
159 | 0 | } |
160 | | } |
161 | | |
162 | 0 | let acquire = self.state.tx_permits.poll_acquire(1); |
163 | 0 | let mut send = SendState { |
164 | 0 | sender: self, |
165 | 0 | value: Some(value), |
166 | 0 | acquire, |
167 | 0 | }; |
168 | 0 | poll_fn(|cx| send.poll_send(cx)).await |
169 | 0 | } |
170 | | |
171 | | /// Attempts to send a message to the associated receiver without waiting. |
172 | | /// |
173 | | /// This method returns the [`Full`] error if the buffer of the channel is full. |
174 | | /// |
175 | | /// This method returns the [`Disconnected`] error if the channel is currently empty, and there |
176 | | /// are no outstanding [receivers]. |
177 | | /// |
178 | | /// [`Full`]: TrySendError::Full |
179 | | /// [`Disconnected`]: TrySendError::Disconnected |
180 | | /// [receivers]: BoundedReceiver |
181 | | /// |
182 | | /// # Examples |
183 | | /// |
184 | | /// ``` |
185 | | /// # #[tokio::main] |
186 | | /// # async fn main() { |
187 | | /// use mea::mpsc::TrySendError; |
188 | | /// use mea::mpsc::bounded; |
189 | | /// let (tx, mut rx) = bounded::<i32>(1); |
190 | | /// |
191 | | /// tx.try_send(1).unwrap(); |
192 | | /// assert_eq!(tx.try_send(2), Err(TrySendError::Full(2))); |
193 | | /// |
194 | | /// drop(rx); |
195 | | /// assert_eq!(tx.try_send(3), Err(TrySendError::Disconnected(3))); |
196 | | /// # } |
197 | | /// ``` |
198 | 0 | pub fn try_send(&self, value: T) -> Result<(), TrySendError<T>> { |
199 | | // SAFETY: The sender is guaranteed to be non-null before dropped. |
200 | 0 | let sender = self.sender.as_ref().unwrap(); |
201 | 0 | match sender.try_send(value) { |
202 | | Ok(()) => { |
203 | 0 | if let Some(waker) = self.state.rx_task.take() { |
204 | 0 | waker.wake(); |
205 | 0 | } |
206 | | |
207 | 0 | Ok(()) |
208 | | } |
209 | 0 | Err(std::sync::mpsc::TrySendError::Full(value)) => Err(TrySendError::Full(value)), |
210 | 0 | Err(std::sync::mpsc::TrySendError::Disconnected(value)) => { |
211 | 0 | Err(TrySendError::Disconnected(value)) |
212 | | } |
213 | | } |
214 | 0 | } |
215 | | } |
216 | | |
217 | | /// Receives values from the associated [`BoundedSender`]. |
218 | | /// |
219 | | /// Instances are created by the [`bounded`] function. |
220 | | pub struct BoundedReceiver<T> { |
221 | | state: Arc<BoundedState>, |
222 | | receiver: Option<std::sync::mpsc::Receiver<T>>, |
223 | | } |
224 | | |
225 | | /// The only `!Sync` field `receiver` is protected by `&mut self` in `recv` and `try_recv`. |
226 | | /// That is, `BoundedReceiver` can only be accessed by one thread at a time. |
227 | | unsafe impl<T: Send> Sync for BoundedReceiver<T> {} |
228 | | |
229 | | impl<T> fmt::Debug for BoundedReceiver<T> { |
230 | 0 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { |
231 | 0 | fmt.debug_struct("BoundedReceiver").finish_non_exhaustive() |
232 | 0 | } |
233 | | } |
234 | | |
235 | | impl<T> Drop for BoundedReceiver<T> { |
236 | 0 | fn drop(&mut self) { |
237 | 0 | drop(self.receiver.take()); |
238 | 0 | self.state.tx_permits.notify_all(); |
239 | 0 | } |
240 | | } |
241 | | |
242 | | impl<T> BoundedReceiver<T> { |
243 | | /// Tries to receive the next value for this receiver and frees up a space in the buffer if |
244 | | /// successful. |
245 | | /// |
246 | | /// This method returns the [`Empty`] error if the channel is currently |
247 | | /// empty, but there are still outstanding [senders]. |
248 | | /// |
249 | | /// This method returns the [`Disconnected`] error if the channel is |
250 | | /// currently empty, and there are no outstanding [senders]. |
251 | | /// |
252 | | /// [`Empty`]: TryRecvError::Empty |
253 | | /// [`Disconnected`]: TryRecvError::Disconnected |
254 | | /// [senders]: BoundedSender |
255 | | /// |
256 | | /// # Examples |
257 | | /// |
258 | | /// ``` |
259 | | /// # #[tokio::main] |
260 | | /// # async fn main() { |
261 | | /// use mea::mpsc; |
262 | | /// use mea::mpsc::TryRecvError; |
263 | | /// let (tx, mut rx) = mpsc::bounded(2); |
264 | | /// |
265 | | /// tx.send("hello").await.unwrap(); |
266 | | /// |
267 | | /// assert_eq!(Ok("hello"), rx.try_recv()); |
268 | | /// assert_eq!(Err(TryRecvError::Empty), rx.try_recv()); |
269 | | /// |
270 | | /// tx.send("hello").await.unwrap(); |
271 | | /// drop(tx); |
272 | | /// |
273 | | /// assert_eq!(Ok("hello"), rx.try_recv()); |
274 | | /// assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv()); |
275 | | /// # } |
276 | | /// ``` |
277 | 0 | pub fn try_recv(&mut self) -> Result<T, TryRecvError> { |
278 | | // SAFETY: The receiver is guaranteed to be non-null before dropped. |
279 | 0 | let receiver = self.receiver.as_ref().unwrap(); |
280 | 0 | match receiver.try_recv() { |
281 | 0 | Ok(v) => { |
282 | 0 | self.state.tx_permits.release_if_nonempty(1); |
283 | 0 | Ok(v) |
284 | | } |
285 | 0 | Err(std::sync::mpsc::TryRecvError::Disconnected) => Err(TryRecvError::Disconnected), |
286 | 0 | Err(std::sync::mpsc::TryRecvError::Empty) => Err(TryRecvError::Empty), |
287 | | } |
288 | 0 | } |
289 | | |
290 | | /// Receives the next value for this receiver and frees up a space in the buffer if successful. |
291 | | /// |
292 | | /// This method returns `Err(RecvError::Disconnected)` if the channel has been closed and there |
293 | | /// are no remaining messages in the channel's buffer. This indicates that no further values |
294 | | /// can ever be received from this `Receiver`. The channel is closed when all senders have been |
295 | | /// dropped. |
296 | | /// |
297 | | /// If there are no messages in the channel's buffer, but the channel has not yet been closed, |
298 | | /// this method will sleep until a message is sent or the channel is closed. |
299 | | /// |
300 | | /// # Cancel safety |
301 | | /// |
302 | | /// This method is cancel safe. If `recv` is used as the event in a `select` statement |
303 | | /// and some other branch completes first, it is guaranteed that no messages were received |
304 | | /// on this channel. |
305 | | /// |
306 | | /// # Examples |
307 | | /// |
308 | | /// ``` |
309 | | /// # #[tokio::main] |
310 | | /// # async fn main() { |
311 | | /// use mea::mpsc; |
312 | | /// let (tx, mut rx) = mpsc::bounded(1); |
313 | | /// |
314 | | /// tokio::spawn(async move { |
315 | | /// tx.send("hello").await.unwrap(); |
316 | | /// }); |
317 | | /// |
318 | | /// assert_eq!(Ok("hello"), rx.recv().await); |
319 | | /// assert_eq!(Err(mpsc::RecvError::Disconnected), rx.recv().await); |
320 | | /// # } |
321 | | /// ``` |
322 | | /// |
323 | | /// Values are buffered if the channel has enough capacity: |
324 | | /// |
325 | | /// ``` |
326 | | /// # #[tokio::main] |
327 | | /// # async fn main() { |
328 | | /// use mea::mpsc; |
329 | | /// let (tx, mut rx) = mpsc::bounded(2); |
330 | | /// |
331 | | /// tx.send("hello").await.unwrap(); |
332 | | /// tx.send("world").await.unwrap(); |
333 | | /// |
334 | | /// assert_eq!(Ok("hello"), rx.recv().await); |
335 | | /// assert_eq!(Ok("world"), rx.recv().await); |
336 | | /// # } |
337 | | /// ``` |
338 | 0 | pub async fn recv(&mut self) -> Result<T, RecvError> { |
339 | 0 | poll_fn(|cx| self.poll_recv(cx)).await |
340 | 0 | } |
341 | | |
342 | 0 | fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> { |
343 | 0 | match self.try_recv() { |
344 | 0 | Ok(v) => Poll::Ready(Ok(v)), |
345 | 0 | Err(TryRecvError::Disconnected) => Poll::Ready(Err(RecvError::Disconnected)), |
346 | | Err(TryRecvError::Empty) => { |
347 | 0 | let waker = Some(Box::new(cx.waker().clone())); |
348 | 0 | self.state.rx_task.store(waker); |
349 | | |
350 | 0 | match self.try_recv() { |
351 | 0 | Ok(v) => Poll::Ready(Ok(v)), |
352 | 0 | Err(TryRecvError::Disconnected) => Poll::Ready(Err(RecvError::Disconnected)), |
353 | 0 | Err(TryRecvError::Empty) => Poll::Pending, |
354 | | } |
355 | | } |
356 | | } |
357 | 0 | } |
358 | | } |