Coverage Report

Created: 2026-01-30 06:08

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/rust/registry/src/index.crates.io-1949cf8c6b5b557f/mea-0.6.0/src/mpsc/bounded.rs
Line
Count
Source
1
// Copyright 2024 tison <wander4096@gmail.com>
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
//! A bounded multi-producer, single-consumer queue for sending values between asynchronous
16
//! tasks with backpressure control.
17
18
use std::fmt;
19
use std::future::Future;
20
use std::future::poll_fn;
21
use std::pin::pin;
22
use std::sync::Arc;
23
use std::sync::atomic::AtomicUsize;
24
use std::sync::atomic::Ordering;
25
use std::task::Context;
26
use std::task::Poll;
27
use std::task::Waker;
28
29
use crate::atomicbox::AtomicOptionBox;
30
use crate::internal::Acquire;
31
use crate::internal::Semaphore;
32
use crate::mpsc::RecvError;
33
use crate::mpsc::SendError;
34
use crate::mpsc::TryRecvError;
35
use crate::mpsc::error::TrySendError;
36
37
/// Creates a bounded mpsc channel for communicating between asynchronous
38
/// tasks with backpressure.
39
///
40
/// A `send` on this channel will wait if the buffer of the channel is full until a
41
/// `recv` is called on the receiver, which will consume the message and
42
/// free up space in the buffer.
43
#[track_caller]
44
0
pub fn bounded<T>(buffer: usize) -> (BoundedSender<T>, BoundedReceiver<T>) {
45
0
    assert!(buffer > 0, "mpsc bounded channel requires buffer > 0");
46
0
    let state = Arc::new(BoundedState {
47
0
        senders: AtomicUsize::new(1),
48
0
        tx_permits: Semaphore::new(0),
49
0
        rx_task: AtomicOptionBox::none(),
50
0
    });
51
0
    let (sender, receiver) = std::sync::mpsc::sync_channel(buffer);
52
0
    let sender = BoundedSender {
53
0
        state: state.clone(),
54
0
        sender: Some(sender),
55
0
    };
56
0
    let receiver = BoundedReceiver {
57
0
        state: state.clone(),
58
0
        receiver: Some(receiver),
59
0
    };
60
0
    (sender, receiver)
61
0
}
62
63
struct BoundedState {
64
    senders: AtomicUsize,
65
    tx_permits: Semaphore,
66
    rx_task: AtomicOptionBox<Waker>,
67
}
68
69
/// Send values to the associated [`BoundedReceiver`].
70
///
71
/// Instances are created by the [`bounded`] function.
72
pub struct BoundedSender<T> {
73
    state: Arc<BoundedState>,
74
    sender: Option<std::sync::mpsc::SyncSender<T>>,
75
}
76
77
impl<T> Clone for BoundedSender<T> {
78
0
    fn clone(&self) -> Self {
79
0
        self.state.senders.fetch_add(1, Ordering::Release);
80
0
        BoundedSender {
81
0
            state: self.state.clone(),
82
0
            sender: self.sender.clone(),
83
0
        }
84
0
    }
85
}
86
87
impl<T> fmt::Debug for BoundedSender<T> {
88
0
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
89
0
        fmt.debug_struct("BoundedSender").finish_non_exhaustive()
90
0
    }
91
}
92
93
impl<T> Drop for BoundedSender<T> {
94
0
    fn drop(&mut self) {
95
        // drop the sender; this closes the channel if it is the last sender
96
0
        drop(self.sender.take());
97
98
0
        match self.state.senders.fetch_sub(1, Ordering::AcqRel) {
99
            1 => {
100
                // If this is the last sender, we need to wake up the receiver so it can
101
                // observe the disconnected state.
102
0
                if let Some(waker) = self.state.rx_task.take() {
103
0
                    waker.wake();
104
0
                }
105
            }
106
0
            _ => {
107
0
                // there are still other senders left, do nothing
108
0
            }
109
        }
110
0
    }
111
}
112
113
impl<T> BoundedSender<T> {
114
    /// Attempts to send a message to the associated receiver.
115
    ///
116
    /// This method will wait if the buffer of the channel is full until a `recv` is called on the
117
    /// receiver, which will consume the message and free up space in the buffer.
118
    ///
119
    /// If the receiver has been dropped, this function returns an error. The error includes
120
    /// the value passed to `send`.
121
0
    pub async fn send(&self, value: T) -> Result<(), SendError<T>> {
122
0
        let value = match self.try_send(value) {
123
0
            Ok(()) => return Ok(()),
124
0
            Err(TrySendError::Disconnected(value)) => return Err(SendError::new(value)),
125
0
            Err(TrySendError::Full(value)) => value,
126
        };
127
128
        struct SendState<'a, T> {
129
            sender: &'a BoundedSender<T>,
130
            value: Option<T>,
131
            acquire: Acquire<'a>,
132
        }
133
134
        impl<T> SendState<'_, T> {
135
0
            fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), SendError<T>>> {
136
0
                let mut value = match self.value.take() {
137
0
                    Some(value) => value,
138
0
                    None => return Poll::Ready(Ok(())),
139
                };
140
141
                loop {
142
0
                    let poll = pin!(&mut self.acquire).poll(cx);
143
144
0
                    value = match self.sender.try_send(value) {
145
0
                        Ok(()) => return Poll::Ready(Ok(())),
146
0
                        Err(TrySendError::Disconnected(value)) => {
147
0
                            return Poll::Ready(Err(SendError::new(value)));
148
                        }
149
0
                        Err(TrySendError::Full(value)) => value,
150
                    };
151
152
0
                    if poll.is_ready() {
153
0
                        self.acquire = self.sender.state.tx_permits.poll_acquire(1);
154
0
                    } else {
155
0
                        self.value = Some(value);
156
0
                        return Poll::Pending;
157
                    }
158
                }
159
0
            }
160
        }
161
162
0
        let acquire = self.state.tx_permits.poll_acquire(1);
163
0
        let mut send = SendState {
164
0
            sender: self,
165
0
            value: Some(value),
166
0
            acquire,
167
0
        };
168
0
        poll_fn(|cx| send.poll_send(cx)).await
169
0
    }
170
171
    /// Attempts to send a message to the associated receiver without waiting.
172
    ///
173
    /// This method returns the [`Full`] error if the buffer of the channel is full.
174
    ///
175
    /// This method returns the [`Disconnected`] error if the channel is currently empty, and there
176
    /// are no outstanding [receivers].
177
    ///
178
    /// [`Full`]: TrySendError::Full
179
    /// [`Disconnected`]: TrySendError::Disconnected
180
    /// [receivers]: BoundedReceiver
181
    ///
182
    /// # Examples
183
    ///
184
    /// ```
185
    /// # #[tokio::main]
186
    /// # async fn main() {
187
    /// use mea::mpsc::TrySendError;
188
    /// use mea::mpsc::bounded;
189
    /// let (tx, mut rx) = bounded::<i32>(1);
190
    ///
191
    /// tx.try_send(1).unwrap();
192
    /// assert_eq!(tx.try_send(2), Err(TrySendError::Full(2)));
193
    ///
194
    /// drop(rx);
195
    /// assert_eq!(tx.try_send(3), Err(TrySendError::Disconnected(3)));
196
    /// # }
197
    /// ```
198
0
    pub fn try_send(&self, value: T) -> Result<(), TrySendError<T>> {
199
        // SAFETY: The sender is guaranteed to be non-null before dropped.
200
0
        let sender = self.sender.as_ref().unwrap();
201
0
        match sender.try_send(value) {
202
            Ok(()) => {
203
0
                if let Some(waker) = self.state.rx_task.take() {
204
0
                    waker.wake();
205
0
                }
206
207
0
                Ok(())
208
            }
209
0
            Err(std::sync::mpsc::TrySendError::Full(value)) => Err(TrySendError::Full(value)),
210
0
            Err(std::sync::mpsc::TrySendError::Disconnected(value)) => {
211
0
                Err(TrySendError::Disconnected(value))
212
            }
213
        }
214
0
    }
215
}
216
217
/// Receives values from the associated [`BoundedSender`].
218
///
219
/// Instances are created by the [`bounded`] function.
220
pub struct BoundedReceiver<T> {
221
    state: Arc<BoundedState>,
222
    receiver: Option<std::sync::mpsc::Receiver<T>>,
223
}
224
225
/// The only `!Sync` field `receiver` is protected by `&mut self` in `recv` and `try_recv`.
226
/// That is, `BoundedReceiver` can only be accessed by one thread at a time.
227
unsafe impl<T: Send> Sync for BoundedReceiver<T> {}
228
229
impl<T> fmt::Debug for BoundedReceiver<T> {
230
0
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
231
0
        fmt.debug_struct("BoundedReceiver").finish_non_exhaustive()
232
0
    }
233
}
234
235
impl<T> Drop for BoundedReceiver<T> {
236
0
    fn drop(&mut self) {
237
0
        drop(self.receiver.take());
238
0
        self.state.tx_permits.notify_all();
239
0
    }
240
}
241
242
impl<T> BoundedReceiver<T> {
243
    /// Tries to receive the next value for this receiver and frees up a space in the buffer if
244
    /// successful.
245
    ///
246
    /// This method returns the [`Empty`] error if the channel is currently
247
    /// empty, but there are still outstanding [senders].
248
    ///
249
    /// This method returns the [`Disconnected`] error if the channel is
250
    /// currently empty, and there are no outstanding [senders].
251
    ///
252
    /// [`Empty`]: TryRecvError::Empty
253
    /// [`Disconnected`]: TryRecvError::Disconnected
254
    /// [senders]: BoundedSender
255
    ///
256
    /// # Examples
257
    ///
258
    /// ```
259
    /// # #[tokio::main]
260
    /// # async fn main() {
261
    /// use mea::mpsc;
262
    /// use mea::mpsc::TryRecvError;
263
    /// let (tx, mut rx) = mpsc::bounded(2);
264
    ///
265
    /// tx.send("hello").await.unwrap();
266
    ///
267
    /// assert_eq!(Ok("hello"), rx.try_recv());
268
    /// assert_eq!(Err(TryRecvError::Empty), rx.try_recv());
269
    ///
270
    /// tx.send("hello").await.unwrap();
271
    /// drop(tx);
272
    ///
273
    /// assert_eq!(Ok("hello"), rx.try_recv());
274
    /// assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv());
275
    /// # }
276
    /// ```
277
0
    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
278
        // SAFETY: The receiver is guaranteed to be non-null before dropped.
279
0
        let receiver = self.receiver.as_ref().unwrap();
280
0
        match receiver.try_recv() {
281
0
            Ok(v) => {
282
0
                self.state.tx_permits.release_if_nonempty(1);
283
0
                Ok(v)
284
            }
285
0
            Err(std::sync::mpsc::TryRecvError::Disconnected) => Err(TryRecvError::Disconnected),
286
0
            Err(std::sync::mpsc::TryRecvError::Empty) => Err(TryRecvError::Empty),
287
        }
288
0
    }
289
290
    /// Receives the next value for this receiver and frees up a space in the buffer if successful.
291
    ///
292
    /// This method returns `Err(RecvError::Disconnected)` if the channel has been closed and there
293
    /// are no remaining messages in the channel's buffer. This indicates that no further values
294
    /// can ever be received from this `Receiver`. The channel is closed when all senders have been
295
    /// dropped.
296
    ///
297
    /// If there are no messages in the channel's buffer, but the channel has not yet been closed,
298
    /// this method will sleep until a message is sent or the channel is closed.
299
    ///
300
    /// # Cancel safety
301
    ///
302
    /// This method is cancel safe. If `recv` is used as the event in a `select` statement
303
    /// and some other branch completes first, it is guaranteed that no messages were received
304
    /// on this channel.
305
    ///
306
    /// # Examples
307
    ///
308
    /// ```
309
    /// # #[tokio::main]
310
    /// # async fn main() {
311
    /// use mea::mpsc;
312
    /// let (tx, mut rx) = mpsc::bounded(1);
313
    ///
314
    /// tokio::spawn(async move {
315
    ///     tx.send("hello").await.unwrap();
316
    /// });
317
    ///
318
    /// assert_eq!(Ok("hello"), rx.recv().await);
319
    /// assert_eq!(Err(mpsc::RecvError::Disconnected), rx.recv().await);
320
    /// # }
321
    /// ```
322
    ///
323
    /// Values are buffered if the channel has enough capacity:
324
    ///
325
    /// ```
326
    /// # #[tokio::main]
327
    /// # async fn main() {
328
    /// use mea::mpsc;
329
    /// let (tx, mut rx) = mpsc::bounded(2);
330
    ///
331
    /// tx.send("hello").await.unwrap();
332
    /// tx.send("world").await.unwrap();
333
    ///
334
    /// assert_eq!(Ok("hello"), rx.recv().await);
335
    /// assert_eq!(Ok("world"), rx.recv().await);
336
    /// # }
337
    /// ```
338
0
    pub async fn recv(&mut self) -> Result<T, RecvError> {
339
0
        poll_fn(|cx| self.poll_recv(cx)).await
340
0
    }
341
342
0
    fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
343
0
        match self.try_recv() {
344
0
            Ok(v) => Poll::Ready(Ok(v)),
345
0
            Err(TryRecvError::Disconnected) => Poll::Ready(Err(RecvError::Disconnected)),
346
            Err(TryRecvError::Empty) => {
347
0
                let waker = Some(Box::new(cx.waker().clone()));
348
0
                self.state.rx_task.store(waker);
349
350
0
                match self.try_recv() {
351
0
                    Ok(v) => Poll::Ready(Ok(v)),
352
0
                    Err(TryRecvError::Disconnected) => Poll::Ready(Err(RecvError::Disconnected)),
353
0
                    Err(TryRecvError::Empty) => Poll::Pending,
354
                }
355
            }
356
        }
357
0
    }
358
}