/src/h2/tests/h2-support/src/future_ext.rs
Line | Count | Source |
1 | | use futures::{FutureExt, TryFuture}; |
2 | | use std::future::Future; |
3 | | use std::pin::Pin; |
4 | | use std::sync::atomic::AtomicBool; |
5 | | use std::sync::Arc; |
6 | | use std::task::{Context, Poll, Wake, Waker}; |
7 | | |
8 | | /// Future extension helpers that are useful for tests |
9 | | pub trait TestFuture: Future { |
10 | | /// Drive `other` by polling `self`. |
11 | | /// |
12 | | /// `self` must not resolve before `other` does. |
13 | 0 | fn drive<T>(&mut self, other: T) -> Drive<'_, Self, T> |
14 | 0 | where |
15 | 0 | T: Future, |
16 | 0 | Self: Future + Sized, |
17 | | { |
18 | 0 | Drive { |
19 | 0 | driver: self, |
20 | 0 | future: other.wakened(), |
21 | 0 | } |
22 | 0 | } |
23 | | |
24 | 0 | fn wakened(self) -> Wakened<Self> |
25 | 0 | where |
26 | 0 | Self: Sized, |
27 | | { |
28 | 0 | Wakened { |
29 | 0 | future: Box::pin(self), |
30 | 0 | woken: Arc::new(AtomicBool::new(true)), |
31 | 0 | } |
32 | 0 | } |
33 | | } |
34 | | |
35 | | /// Wraps futures::future::join to ensure that the futures are only polled if they are woken. |
36 | 0 | pub fn join<Fut1, Fut2>( |
37 | 0 | future1: Fut1, |
38 | 0 | future2: Fut2, |
39 | 0 | ) -> futures::future::Join<Wakened<Fut1>, Wakened<Fut2>> |
40 | 0 | where |
41 | 0 | Fut1: Future, |
42 | 0 | Fut2: Future, |
43 | | { |
44 | 0 | futures::future::join(future1.wakened(), future2.wakened()) |
45 | 0 | } |
46 | | |
47 | | /// Wraps futures::future::join3 to ensure that the futures are only polled if they are woken. |
48 | 0 | pub fn join3<Fut1, Fut2, Fut3>( |
49 | 0 | future1: Fut1, |
50 | 0 | future2: Fut2, |
51 | 0 | future3: Fut3, |
52 | 0 | ) -> futures::future::Join3<Wakened<Fut1>, Wakened<Fut2>, Wakened<Fut3>> |
53 | 0 | where |
54 | 0 | Fut1: Future, |
55 | 0 | Fut2: Future, |
56 | 0 | Fut3: Future, |
57 | | { |
58 | 0 | futures::future::join3(future1.wakened(), future2.wakened(), future3.wakened()) |
59 | 0 | } |
60 | | |
61 | | /// Wraps futures::future::join4 to ensure that the futures are only polled if they are woken. |
62 | 0 | pub fn join4<Fut1, Fut2, Fut3, Fut4>( |
63 | 0 | future1: Fut1, |
64 | 0 | future2: Fut2, |
65 | 0 | future3: Fut3, |
66 | 0 | future4: Fut4, |
67 | 0 | ) -> futures::future::Join4<Wakened<Fut1>, Wakened<Fut2>, Wakened<Fut3>, Wakened<Fut4>> |
68 | 0 | where |
69 | 0 | Fut1: Future, |
70 | 0 | Fut2: Future, |
71 | 0 | Fut3: Future, |
72 | 0 | Fut4: Future, |
73 | | { |
74 | 0 | futures::future::join4( |
75 | 0 | future1.wakened(), |
76 | 0 | future2.wakened(), |
77 | 0 | future3.wakened(), |
78 | 0 | future4.wakened(), |
79 | | ) |
80 | 0 | } |
81 | | |
82 | | /// Wraps futures::future::try_join to ensure that the futures are only polled if they are woken. |
83 | 0 | pub fn try_join<Fut1, Fut2>( |
84 | 0 | future1: Fut1, |
85 | 0 | future2: Fut2, |
86 | 0 | ) -> futures::future::TryJoin<Wakened<Fut1>, Wakened<Fut2>> |
87 | 0 | where |
88 | 0 | Fut1: futures::future::TryFuture + Future, |
89 | 0 | Fut2: Future, |
90 | 0 | Wakened<Fut1>: futures::future::TryFuture, |
91 | 0 | Wakened<Fut2>: futures::future::TryFuture<Error = <Wakened<Fut1> as TryFuture>::Error>, |
92 | | { |
93 | 0 | futures::future::try_join(future1.wakened(), future2.wakened()) |
94 | 0 | } |
95 | | |
96 | | /// Wraps futures::future::select to ensure that the futures are only polled if they are woken. |
97 | 0 | pub fn select<A, B>(future1: A, future2: B) -> futures::future::Select<Wakened<A>, Wakened<B>> |
98 | 0 | where |
99 | 0 | A: Future + Unpin, |
100 | 0 | B: Future + Unpin, |
101 | | { |
102 | 0 | futures::future::select(future1.wakened(), future2.wakened()) |
103 | 0 | } |
104 | | |
105 | | /// Wraps futures::future::join_all to ensure that the futures are only polled if they are woken. |
106 | 0 | pub fn join_all<I>(iter: I) -> futures::future::JoinAll<Wakened<I::Item>> |
107 | 0 | where |
108 | 0 | I: IntoIterator, |
109 | 0 | I::Item: Future, |
110 | | { |
111 | 0 | futures::future::join_all(iter.into_iter().map(|f| f.wakened())) |
112 | 0 | } |
113 | | |
/// Adapter future that forwards a poll to its inner future only when that inner future
/// has been woken since the last forwarded poll (the initial poll is always forwarded).
pub struct Wakened<T> {
    /// The wrapped future. Boxing it also makes `Wakened` `Unpin` regardless of `T`.
    future: Pin<Box<T>>,
    /// Shared wake-up flag: set by the waker given to the inner future, and cleared each
    /// time a poll is forwarded.
    woken: Arc<AtomicBool>,
}
119 | | |
120 | | /// A future that only polls the inner future if it has been woken (after the initial poll). |
121 | | impl<T> Future for Wakened<T> |
122 | | where |
123 | | T: Future, |
124 | | { |
125 | | type Output = T::Output; |
126 | | |
127 | 0 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
128 | 0 | let this = self.get_mut(); |
129 | 0 | if !this.woken.load(std::sync::atomic::Ordering::SeqCst) { |
130 | 0 | return Poll::Pending; |
131 | 0 | } |
132 | 0 | this.woken.store(false, std::sync::atomic::Ordering::SeqCst); |
133 | 0 | let my_waker = IfWokenWaker { |
134 | 0 | inner: cx.waker().clone(), |
135 | 0 | wakened: this.woken.clone(), |
136 | 0 | }; |
137 | 0 | let my_waker = Arc::new(my_waker).into(); |
138 | 0 | let mut cx = Context::from_waker(&my_waker); |
139 | 0 | this.future.as_mut().poll(&mut cx) |
140 | 0 | } |
141 | | } |
142 | | |
/// Waking an `IfWokenWaker` records the wake-up in the shared flag and then forwards it
/// to the wrapped task waker.
impl Wake for IfWokenWaker {
    fn wake(self: Arc<Self>) {
        Wake::wake_by_ref(&self);
    }

    /// Overridden so that waking by reference does not clone and drop the `Arc`
    /// (the default `Wake::wake_by_ref` is `self.clone().wake()`).
    fn wake_by_ref(self: &Arc<Self>) {
        // Record the wake-up before forwarding it, so a re-poll triggered by the
        // forwarded wake sees the flag set.
        self.wakened
            .store(true, std::sync::atomic::Ordering::SeqCst);
        self.inner.wake_by_ref();
    }
}

/// Waker state used by [`Wakened`] when polling its inner future.
struct IfWokenWaker {
    /// The waker from the `Context` that polled the `Wakened` adapter.
    inner: Waker,
    /// Flag shared with the owning `Wakened`; set to `true` whenever this waker fires.
    wakened: Arc<AtomicBool>,
}
155 | | |
156 | | impl<T: Future> TestFuture for T {} |
157 | | |
158 | | // ===== Drive ====== |
159 | | |
160 | | /// Drive a future to completion while also polling the driver |
161 | | /// |
162 | | /// This is useful for H2 futures that also require the connection to be polled. |
163 | | pub struct Drive<'a, T, U> { |
164 | | driver: &'a mut T, |
165 | | future: Wakened<U>, |
166 | | } |
167 | | |
168 | | impl<'a, T, U> Future for Drive<'a, T, U> |
169 | | where |
170 | | T: Future + Unpin, |
171 | | U: Future, |
172 | | { |
173 | | type Output = U::Output; |
174 | | |
175 | 0 | fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
176 | 0 | let mut looped = false; |
177 | | loop { |
178 | 0 | match self.future.poll_unpin(cx) { |
179 | 0 | Poll::Ready(val) => return Poll::Ready(val), |
180 | 0 | Poll::Pending => {} |
181 | | } |
182 | | |
183 | 0 | match self.driver.poll_unpin(cx) { |
184 | | Poll::Ready(_) => { |
185 | 0 | if looped { |
186 | | // Try polling the future one last time |
187 | 0 | panic!("driver resolved before future") |
188 | | } else { |
189 | 0 | looped = true; |
190 | 0 | continue; |
191 | | } |
192 | | } |
193 | 0 | Poll::Pending => {} |
194 | | } |
195 | | |
196 | 0 | return Poll::Pending; |
197 | | } |
198 | 0 | } |
199 | | } |