/rust/registry/src/index.crates.io-1949cf8c6b5b557f/tokio-util-0.7.13/src/sync/mpsc.rs
Line | Count | Source |
1 | | use futures_sink::Sink; |
2 | | use std::pin::Pin; |
3 | | use std::task::{Context, Poll}; |
4 | | use std::{fmt, mem}; |
5 | | use tokio::sync::mpsc::OwnedPermit; |
6 | | use tokio::sync::mpsc::Sender; |
7 | | |
8 | | use super::ReusableBoxFuture; |
9 | | |
/// Error produced by `PollSender` once the channel has been closed.
///
/// The wrapped `Option` carries the rejected item when there is one.
#[derive(Debug)]
pub struct PollSendError<T>(Option<T>);

impl<T> PollSendError<T> {
    /// Takes the stored value out of the error, if there is one.
    ///
    /// When the error came from `start_send`/`send_item`, this is the item the caller tried to
    /// send. In every other case it is `None`.
    pub fn into_inner(self) -> Option<T> {
        let PollSendError(item) = self;
        item
    }
}

impl<T> fmt::Display for PollSendError<T> {
    /// Writes the fixed message `channel closed`; the stored item is never printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("channel closed")
    }
}

impl<T: fmt::Debug> std::error::Error for PollSendError<T> {}
31 | | |
32 | | #[derive(Debug)] |
33 | | enum State<T> { |
34 | | Idle(Sender<T>), |
35 | | Acquiring, |
36 | | ReadyToSend(OwnedPermit<T>), |
37 | | Closed, |
38 | | } |
39 | | |
40 | | /// A wrapper around [`mpsc::Sender`] that can be polled. |
41 | | /// |
42 | | /// [`mpsc::Sender`]: tokio::sync::mpsc::Sender |
43 | | #[derive(Debug)] |
44 | | pub struct PollSender<T> { |
45 | | sender: Option<Sender<T>>, |
46 | | state: State<T>, |
47 | | acquire: PollSenderFuture<T>, |
48 | | } |
49 | | |
50 | | // Creates a future for acquiring a permit from the underlying channel. This is used to ensure |
51 | | // there's capacity for a send to complete. |
52 | | // |
53 | | // By reusing the same async fn for both `Some` and `None`, we make sure every future passed to |
54 | | // ReusableBoxFuture has the same underlying type, and hence the same size and alignment. |
55 | 0 | async fn make_acquire_future<T>( |
56 | 0 | data: Option<Sender<T>>, |
57 | 0 | ) -> Result<OwnedPermit<T>, PollSendError<T>> { |
58 | 0 | match data { |
59 | 0 | Some(sender) => sender |
60 | 0 | .reserve_owned() |
61 | 0 | .await |
62 | 0 | .map_err(|_| PollSendError(None)), |
63 | 0 | None => unreachable!("this future should not be pollable in this state"), |
64 | | } |
65 | 0 | } |
66 | | |
67 | | type InnerFuture<'a, T> = ReusableBoxFuture<'a, Result<OwnedPermit<T>, PollSendError<T>>>; |
68 | | |
69 | | #[derive(Debug)] |
70 | | // TODO: This should be replace with a type_alias_impl_trait to eliminate `'static` and all the transmutes |
71 | | struct PollSenderFuture<T>(InnerFuture<'static, T>); |
72 | | |
73 | | impl<T> PollSenderFuture<T> { |
74 | | /// Create with an empty inner future with no `Send` bound. |
75 | 0 | fn empty() -> Self { |
76 | | // We don't use `make_acquire_future` here because our relaxed bounds on `T` are not |
77 | | // compatible with the transitive bounds required by `Sender<T>`. |
78 | 0 | Self(ReusableBoxFuture::new(async { unreachable!() })) |
79 | 0 | } |
80 | | } |
81 | | |
82 | | impl<T: Send> PollSenderFuture<T> { |
83 | | /// Create with an empty inner future. |
84 | 0 | fn new() -> Self { |
85 | 0 | let v = InnerFuture::new(make_acquire_future(None)); |
86 | | // This is safe because `make_acquire_future(None)` is actually `'static` |
87 | 0 | Self(unsafe { mem::transmute::<InnerFuture<'_, T>, InnerFuture<'static, T>>(v) }) |
88 | 0 | } |
89 | | |
90 | | /// Poll the inner future. |
91 | 0 | fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<OwnedPermit<T>, PollSendError<T>>> { |
92 | 0 | self.0.poll(cx) |
93 | 0 | } |
94 | | |
95 | | /// Replace the inner future. |
96 | 0 | fn set(&mut self, sender: Option<Sender<T>>) { |
97 | 0 | let inner: *mut InnerFuture<'static, T> = &mut self.0; |
98 | 0 | let inner: *mut InnerFuture<'_, T> = inner.cast(); |
99 | | // SAFETY: The `make_acquire_future(sender)` future must not exist after the type `T` |
100 | | // becomes invalid, and this casts away the type-level lifetime check for that. However, the |
101 | | // inner future is never moved out of this `PollSenderFuture<T>`, so the future will not |
102 | | // live longer than the `PollSenderFuture<T>` lives. A `PollSenderFuture<T>` is guaranteed |
103 | | // to not exist after the type `T` becomes invalid, because it is annotated with a `T`, so |
104 | | // this is ok. |
105 | 0 | let inner = unsafe { &mut *inner }; |
106 | 0 | inner.set(make_acquire_future(sender)); |
107 | 0 | } |
108 | | } |
109 | | |
110 | | impl<T: Send> PollSender<T> { |
111 | | /// Creates a new `PollSender`. |
112 | 0 | pub fn new(sender: Sender<T>) -> Self { |
113 | 0 | Self { |
114 | 0 | sender: Some(sender.clone()), |
115 | 0 | state: State::Idle(sender), |
116 | 0 | acquire: PollSenderFuture::new(), |
117 | 0 | } |
118 | 0 | } |
119 | | |
120 | 0 | fn take_state(&mut self) -> State<T> { |
121 | 0 | mem::replace(&mut self.state, State::Closed) |
122 | 0 | } |
123 | | |
124 | | /// Attempts to prepare the sender to receive a value. |
125 | | /// |
126 | | /// This method must be called and return `Poll::Ready(Ok(()))` prior to each call to |
127 | | /// `send_item`. |
128 | | /// |
129 | | /// This method returns `Poll::Ready` once the underlying channel is ready to receive a value, |
130 | | /// by reserving a slot in the channel for the item to be sent. If this method returns |
131 | | /// `Poll::Pending`, the current task is registered to be notified (via |
132 | | /// `cx.waker().wake_by_ref()`) when `poll_reserve` should be called again. |
133 | | /// |
134 | | /// # Errors |
135 | | /// |
136 | | /// If the channel is closed, an error will be returned. This is a permanent state. |
137 | 0 | pub fn poll_reserve(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), PollSendError<T>>> { |
138 | | loop { |
139 | 0 | let (result, next_state) = match self.take_state() { |
140 | 0 | State::Idle(sender) => { |
141 | | // Start trying to acquire a permit to reserve a slot for our send, and |
142 | | // immediately loop back around to poll it the first time. |
143 | 0 | self.acquire.set(Some(sender)); |
144 | 0 | (None, State::Acquiring) |
145 | | } |
146 | 0 | State::Acquiring => match self.acquire.poll(cx) { |
147 | | // Channel has capacity. |
148 | 0 | Poll::Ready(Ok(permit)) => { |
149 | 0 | (Some(Poll::Ready(Ok(()))), State::ReadyToSend(permit)) |
150 | | } |
151 | | // Channel is closed. |
152 | 0 | Poll::Ready(Err(e)) => (Some(Poll::Ready(Err(e))), State::Closed), |
153 | | // Channel doesn't have capacity yet, so we need to wait. |
154 | 0 | Poll::Pending => (Some(Poll::Pending), State::Acquiring), |
155 | | }, |
156 | | // We're closed, either by choice or because the underlying sender was closed. |
157 | 0 | s @ State::Closed => (Some(Poll::Ready(Err(PollSendError(None)))), s), |
158 | | // We're already ready to send an item. |
159 | 0 | s @ State::ReadyToSend(_) => (Some(Poll::Ready(Ok(()))), s), |
160 | | }; |
161 | | |
162 | 0 | self.state = next_state; |
163 | 0 | if let Some(result) = result { |
164 | 0 | return result; |
165 | 0 | } |
166 | | } |
167 | 0 | } |
168 | | |
169 | | /// Sends an item to the channel. |
170 | | /// |
171 | | /// Before calling `send_item`, `poll_reserve` must be called with a successful return |
172 | | /// value of `Poll::Ready(Ok(()))`. |
173 | | /// |
174 | | /// # Errors |
175 | | /// |
176 | | /// If the channel is closed, an error will be returned. This is a permanent state. |
177 | | /// |
178 | | /// # Panics |
179 | | /// |
180 | | /// If `poll_reserve` was not successfully called prior to calling `send_item`, then this method |
181 | | /// will panic. |
182 | | #[track_caller] |
183 | 0 | pub fn send_item(&mut self, value: T) -> Result<(), PollSendError<T>> { |
184 | 0 | let (result, next_state) = match self.take_state() { |
185 | | State::Idle(_) | State::Acquiring => { |
186 | 0 | panic!("`send_item` called without first calling `poll_reserve`") |
187 | | } |
188 | | // We have a permit to send our item, so go ahead, which gets us our sender back. |
189 | 0 | State::ReadyToSend(permit) => (Ok(()), State::Idle(permit.send(value))), |
190 | | // We're closed, either by choice or because the underlying sender was closed. |
191 | 0 | State::Closed => (Err(PollSendError(Some(value))), State::Closed), |
192 | | }; |
193 | | |
194 | | // Handle deferred closing if `close` was called between `poll_reserve` and `send_item`. |
195 | 0 | self.state = if self.sender.is_some() { |
196 | 0 | next_state |
197 | | } else { |
198 | 0 | State::Closed |
199 | | }; |
200 | 0 | result |
201 | 0 | } |
202 | | |
203 | | /// Checks whether this sender is been closed. |
204 | | /// |
205 | | /// The underlying channel that this sender was wrapping may still be open. |
206 | 0 | pub fn is_closed(&self) -> bool { |
207 | 0 | matches!(self.state, State::Closed) || self.sender.is_none() |
208 | 0 | } |
209 | | |
210 | | /// Gets a reference to the `Sender` of the underlying channel. |
211 | | /// |
212 | | /// If `PollSender` has been closed, `None` is returned. The underlying channel that this sender |
213 | | /// was wrapping may still be open. |
214 | 0 | pub fn get_ref(&self) -> Option<&Sender<T>> { |
215 | 0 | self.sender.as_ref() |
216 | 0 | } |
217 | | |
218 | | /// Closes this sender. |
219 | | /// |
220 | | /// No more messages will be able to be sent from this sender, but the underlying channel will |
221 | | /// remain open until all senders have dropped, or until the [`Receiver`] closes the channel. |
222 | | /// |
223 | | /// If a slot was previously reserved by calling `poll_reserve`, then a final call can be made |
224 | | /// to `send_item` in order to consume the reserved slot. After that, no further sends will be |
225 | | /// possible. If you do not intend to send another item, you can release the reserved slot back |
226 | | /// to the underlying sender by calling [`abort_send`]. |
227 | | /// |
228 | | /// [`abort_send`]: crate::sync::PollSender::abort_send |
229 | | /// [`Receiver`]: tokio::sync::mpsc::Receiver |
230 | 0 | pub fn close(&mut self) { |
231 | | // Mark ourselves officially closed by dropping our main sender. |
232 | 0 | self.sender = None; |
233 | | |
234 | | // If we're already idle, closed, or we haven't yet reserved a slot, we can quickly |
235 | | // transition to the closed state. Otherwise, leave the existing permit in place for the |
236 | | // caller if they want to complete the send. |
237 | 0 | match self.state { |
238 | 0 | State::Idle(_) => self.state = State::Closed, |
239 | 0 | State::Acquiring => { |
240 | 0 | self.acquire.set(None); |
241 | 0 | self.state = State::Closed; |
242 | 0 | } |
243 | 0 | _ => {} |
244 | | } |
245 | 0 | } |
246 | | |
247 | | /// Aborts the current in-progress send, if any. |
248 | | /// |
249 | | /// Returns `true` if a send was aborted. If the sender was closed prior to calling |
250 | | /// `abort_send`, then the sender will remain in the closed state, otherwise the sender will be |
251 | | /// ready to attempt another send. |
252 | 0 | pub fn abort_send(&mut self) -> bool { |
253 | | // We may have been closed in the meantime, after a call to `poll_reserve` already |
254 | | // succeeded. We'll check if `self.sender` is `None` to see if we should transition to the |
255 | | // closed state when we actually abort a send, rather than resetting ourselves back to idle. |
256 | | |
257 | 0 | let (result, next_state) = match self.take_state() { |
258 | | // We're currently trying to reserve a slot to send into. |
259 | | State::Acquiring => { |
260 | | // Replacing the future drops the in-flight one. |
261 | 0 | self.acquire.set(None); |
262 | | |
263 | | // If we haven't closed yet, we have to clone our stored sender since we have no way |
264 | | // to get it back from the acquire future we just dropped. |
265 | 0 | let state = match self.sender.clone() { |
266 | 0 | Some(sender) => State::Idle(sender), |
267 | 0 | None => State::Closed, |
268 | | }; |
269 | 0 | (true, state) |
270 | | } |
271 | | // We got the permit. If we haven't closed yet, get the sender back. |
272 | 0 | State::ReadyToSend(permit) => { |
273 | 0 | let state = if self.sender.is_some() { |
274 | 0 | State::Idle(permit.release()) |
275 | | } else { |
276 | 0 | State::Closed |
277 | | }; |
278 | 0 | (true, state) |
279 | | } |
280 | 0 | s => (false, s), |
281 | | }; |
282 | | |
283 | 0 | self.state = next_state; |
284 | 0 | result |
285 | 0 | } |
286 | | } |
287 | | |
288 | | impl<T> Clone for PollSender<T> { |
289 | | /// Clones this `PollSender`. |
290 | | /// |
291 | | /// The resulting `PollSender` will have an initial state identical to calling `PollSender::new`. |
292 | 0 | fn clone(&self) -> PollSender<T> { |
293 | 0 | let (sender, state) = match self.sender.clone() { |
294 | 0 | Some(sender) => (Some(sender.clone()), State::Idle(sender)), |
295 | 0 | None => (None, State::Closed), |
296 | | }; |
297 | | |
298 | 0 | Self { |
299 | 0 | sender, |
300 | 0 | state, |
301 | 0 | acquire: PollSenderFuture::empty(), |
302 | 0 | } |
303 | 0 | } |
304 | | } |
305 | | |
306 | | impl<T: Send> Sink<T> for PollSender<T> { |
307 | | type Error = PollSendError<T>; |
308 | | |
309 | 0 | fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
310 | 0 | Pin::into_inner(self).poll_reserve(cx) |
311 | 0 | } |
312 | | |
313 | 0 | fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
314 | 0 | Poll::Ready(Ok(())) |
315 | 0 | } |
316 | | |
317 | 0 | fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { |
318 | 0 | Pin::into_inner(self).send_item(item) |
319 | 0 | } |
320 | | |
321 | 0 | fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
322 | 0 | Pin::into_inner(self).close(); |
323 | 0 | Poll::Ready(Ok(())) |
324 | 0 | } |
325 | | } |