/src/ztunnel/src/proxy/h2.rs
Line | Count | Source |
1 | | // Copyright Istio Authors |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | use crate::copy; |
16 | | use bytes::Bytes; |
17 | | use futures_core::ready; |
18 | | use h2::Reason; |
19 | | use std::io::Error; |
20 | | use std::pin::Pin; |
21 | | use std::sync::Arc; |
22 | | use std::sync::atomic::{AtomicBool, AtomicU16, Ordering}; |
23 | | use std::task::{Context, Poll}; |
24 | | use std::time::Duration; |
25 | | use tokio::sync::oneshot; |
26 | | use tracing::trace; |
27 | | |
28 | | pub mod client; |
29 | | pub mod server; |
30 | | |
31 | 0 | async fn do_ping_pong( |
32 | 0 | mut ping_pong: h2::PingPong, |
33 | 0 | tx: oneshot::Sender<()>, |
34 | 0 | dropped: Arc<AtomicBool>, |
35 | 0 | ) { |
36 | | const PING_INTERVAL: Duration = Duration::from_secs(10); |
37 | | const PING_TIMEOUT: Duration = Duration::from_secs(20); |
38 | | // delay before sending the first ping, no need to race with the first request |
39 | 0 | tokio::time::sleep(PING_INTERVAL).await; |
40 | | loop { |
41 | 0 | if dropped.load(Ordering::Relaxed) { |
42 | 0 | return; |
43 | 0 | } |
44 | 0 | let ping_fut = ping_pong.ping(h2::Ping::opaque()); |
45 | 0 | log::trace!("ping sent"); |
46 | 0 | match tokio::time::timeout(PING_TIMEOUT, ping_fut).await { |
47 | | Err(_) => { |
48 | | // We will log this again up in drive_connection, so don't worry about a high log level |
49 | 0 | log::trace!("ping timeout"); |
50 | 0 | let _ = tx.send(()); |
51 | 0 | return; |
52 | | } |
53 | 0 | Ok(r) => match r { |
54 | | Ok(_) => { |
55 | 0 | log::trace!("pong received"); |
56 | 0 | tokio::time::sleep(PING_INTERVAL).await; |
57 | | } |
58 | 0 | Err(e) => { |
59 | 0 | if dropped.load(Ordering::Relaxed) { |
60 | | // drive_connection() exits first, no need to error again |
61 | 0 | return; |
62 | 0 | } |
63 | 0 | log::error!("ping error: {e}"); |
64 | 0 | let _ = tx.send(()); |
65 | 0 | return; |
66 | | } |
67 | | }, |
68 | | } |
69 | | } |
70 | 0 | } |
71 | | |
/// H2Stream represents an active HTTP2 stream. Consumers can only Read/Write.
///
/// The stream is stored as a separate read half and write half, which is what
/// `split_into_buffered_reader` hands back to the caller.
pub struct H2Stream {
    /// Receiving side of the stream.
    read: H2StreamReadHalf,
    /// Sending side of the stream.
    write: H2StreamWriteHalf,
}
77 | | |
/// Read half of an [`H2Stream`]; data is pulled from it through `copy::ResizeBufRead`.
pub struct H2StreamReadHalf {
    /// Underlying h2 receive stream that data frames are polled from.
    recv_stream: h2::RecvStream,
    /// Optional drop guard. It is not read in this file; it is kept so that it is
    /// dropped together with this half.
    _dropped: Option<DropCounter>,
}
82 | | |
/// Write half of an [`H2Stream`]; data is pushed into it through `copy::AsyncWriteBuf`.
pub struct H2StreamWriteHalf {
    /// Underlying h2 send stream that data frames are written to.
    send_stream: h2::SendStream<Bytes>,
    /// Optional drop guard. It is not read in this file; it is kept so that it is
    /// dropped together with this half.
    _dropped: Option<DropCounter>,
}
87 | | |
/// Adapter that exposes an [`H2Stream`] through `tokio::io::AsyncRead` / `AsyncWrite`.
pub struct TokioH2Stream {
    /// The wrapped stream (both halves).
    stream: H2Stream,
    /// Bytes already received from the stream but not yet copied out to a reader.
    buf: Bytes,
}
92 | | |
/// One half of a paired guard attached to a shared stream counter.
///
/// [`DropCounter::new`] creates two of these at once. Both halves point at the
/// same `half_dropped` marker and the same `active_count`.
struct DropCounter {
    // Marker shared by the two halves of a pair. When a half is dropped it checks
    // whether other references to this marker remain, which lets the pair
    // decrement `active_count` once rather than twice.
    half_dropped: Arc<()>,
    // Counter shared with whoever created the pair.
    active_count: Arc<AtomicU16>,
}

impl DropCounter {
    /// Builds the two linked halves of a guard for `active_count`.
    ///
    /// Both returned values are always `Some`; the `Option` wrapper matches the
    /// field type the halves are stored in.
    pub fn new(active_count: Arc<AtomicU16>) -> (Option<DropCounter>, Option<DropCounter>) {
        let marker = Arc::new(());
        let first = DropCounter {
            half_dropped: Arc::clone(&marker),
            active_count: Arc::clone(&active_count),
        };
        let second = DropCounter {
            half_dropped: marker,
            active_count,
        };
        (Some(first), Some(second))
    }
}
114 | | |
115 | | impl crate::copy::BufferedSplitter for H2Stream { |
116 | | type R = H2StreamReadHalf; |
117 | | type W = H2StreamWriteHalf; |
118 | 0 | fn split_into_buffered_reader(self) -> (H2StreamReadHalf, H2StreamWriteHalf) { |
119 | 0 | let H2Stream { read, write } = self; |
120 | 0 | (read, write) |
121 | 0 | } |
122 | | } |
123 | | |
124 | | impl H2StreamWriteHalf { |
125 | 0 | fn write_slice(&mut self, buf: Bytes, end_of_stream: bool) -> Result<(), std::io::Error> { |
126 | 0 | self.send_stream |
127 | 0 | .send_data(buf, end_of_stream) |
128 | 0 | .map_err(h2_to_io_error) |
129 | 0 | } |
130 | | } |
131 | | |
132 | | impl Drop for DropCounter { |
133 | 0 | fn drop(&mut self) { |
134 | 0 | let mut half_dropped = Arc::new(()); |
135 | 0 | std::mem::swap(&mut self.half_dropped, &mut half_dropped); |
136 | 0 | if Arc::into_inner(half_dropped).is_none() { |
137 | | // other half already dropped |
138 | 0 | let left = self.active_count.fetch_sub(1, Ordering::SeqCst); |
139 | 0 | trace!("dropping H2Stream, has {} active streams left", left - 1); |
140 | | } else { |
141 | 0 | trace!("dropping H2Stream, other half remains"); |
142 | | } |
143 | 0 | } |
144 | | } |
145 | | |
146 | | // We can't directly implement tokio::io::{AsyncRead, AsyncWrite} for H2Stream because |
147 | | // then the specific implementation will conflict with the generic one. |
148 | | impl TokioH2Stream { |
149 | 0 | pub fn new(stream: H2Stream) -> Self { |
150 | 0 | Self { |
151 | 0 | stream, |
152 | 0 | buf: Bytes::new(), |
153 | 0 | } |
154 | 0 | } |
155 | | } |
156 | | |
157 | | impl tokio::io::AsyncRead for TokioH2Stream { |
158 | 0 | fn poll_read( |
159 | 0 | mut self: Pin<&mut Self>, |
160 | 0 | cx: &mut Context<'_>, |
161 | 0 | buf: &mut tokio::io::ReadBuf<'_>, |
162 | 0 | ) -> Poll<std::io::Result<()>> { |
163 | | // Just return the bytes we have left over and don't poll the stream because |
164 | | // its unclear what to do if there are bytes left over from the previous read, and when we |
165 | | // poll, we get an error. |
166 | 0 | if self.buf.is_empty() { |
167 | | // If we have no unread bytes, we can poll the stream |
168 | | // and fill self.buf with the bytes we read. |
169 | 0 | let pinned = std::pin::Pin::new(&mut self.stream.read); |
170 | 0 | let res = ready!(copy::ResizeBufRead::poll_bytes(pinned, cx))?; |
171 | 0 | self.buf = res; |
172 | 0 | } |
173 | | // Copy as many bytes as we can from self.buf. |
174 | 0 | let cnt = Ord::min(buf.remaining(), self.buf.len()); |
175 | 0 | buf.put_slice(&self.buf[..cnt]); |
176 | 0 | self.buf = self.buf.split_off(cnt); |
177 | 0 | Poll::Ready(Ok(())) |
178 | 0 | } |
179 | | } |
180 | | |
181 | | impl tokio::io::AsyncWrite for TokioH2Stream { |
182 | 0 | fn poll_write( |
183 | 0 | mut self: Pin<&mut Self>, |
184 | 0 | cx: &mut Context<'_>, |
185 | 0 | buf: &[u8], |
186 | 0 | ) -> Poll<Result<usize, tokio::io::Error>> { |
187 | 0 | let pinned = std::pin::Pin::new(&mut self.stream.write); |
188 | 0 | let buf = Bytes::copy_from_slice(buf); |
189 | 0 | copy::AsyncWriteBuf::poll_write_buf(pinned, cx, buf) |
190 | 0 | } |
191 | | |
192 | 0 | fn poll_flush( |
193 | 0 | mut self: Pin<&mut Self>, |
194 | 0 | cx: &mut Context<'_>, |
195 | 0 | ) -> Poll<Result<(), std::io::Error>> { |
196 | 0 | let pinned = std::pin::Pin::new(&mut self.stream.write); |
197 | 0 | copy::AsyncWriteBuf::poll_flush(pinned, cx) |
198 | 0 | } |
199 | | |
200 | 0 | fn poll_shutdown( |
201 | 0 | mut self: Pin<&mut Self>, |
202 | 0 | cx: &mut Context<'_>, |
203 | 0 | ) -> Poll<Result<(), std::io::Error>> { |
204 | 0 | let pinned = std::pin::Pin::new(&mut self.stream.write); |
205 | 0 | copy::AsyncWriteBuf::poll_shutdown(pinned, cx) |
206 | 0 | } |
207 | | } |
208 | | |
209 | | impl copy::ResizeBufRead for H2StreamReadHalf { |
210 | 0 | fn poll_bytes(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<Bytes>> { |
211 | 0 | let this = self.get_mut(); |
212 | | loop { |
213 | 0 | match ready!(this.recv_stream.poll_data(cx)) { |
214 | 0 | None => return Poll::Ready(Ok(Bytes::new())), |
215 | 0 | Some(Ok(buf)) if buf.is_empty() && !this.recv_stream.is_end_stream() => continue, |
216 | 0 | Some(Ok(buf)) => { |
217 | | // TODO: Hyper and Go make their pinging data aware and don't send pings when data is received |
218 | | // Pingora, and our implementation, currently don't do this. |
219 | | // We may want to; if so, modify here. |
220 | | // this.ping.record_data(buf.len()); |
221 | 0 | let _ = this.recv_stream.flow_control().release_capacity(buf.len()); |
222 | 0 | return Poll::Ready(Ok(buf)); |
223 | | } |
224 | 0 | Some(Err(e)) => { |
225 | 0 | return Poll::Ready(match e.reason() { |
226 | | Some(Reason::NO_ERROR) | Some(Reason::CANCEL) => { |
227 | 0 | return Poll::Ready(Ok(Bytes::new())); |
228 | | } |
229 | | Some(Reason::STREAM_CLOSED) => { |
230 | 0 | Err(Error::new(std::io::ErrorKind::BrokenPipe, e)) |
231 | | } |
232 | 0 | _ => Err(h2_to_io_error(e)), |
233 | | }); |
234 | | } |
235 | | } |
236 | | } |
237 | 0 | } |
238 | | |
239 | 0 | fn resize(self: Pin<&mut Self>, _new_size: usize) { |
240 | | // NOP, we don't need to resize as we are abstracting the h2 buffer |
241 | 0 | } |
242 | | } |
243 | | |
244 | | impl copy::AsyncWriteBuf for H2StreamWriteHalf { |
245 | 0 | fn poll_write_buf( |
246 | 0 | mut self: Pin<&mut Self>, |
247 | 0 | cx: &mut Context<'_>, |
248 | 0 | buf: Bytes, |
249 | 0 | ) -> Poll<std::io::Result<usize>> { |
250 | 0 | if buf.is_empty() { |
251 | 0 | return Poll::Ready(Ok(0)); |
252 | 0 | } |
253 | 0 | self.send_stream.reserve_capacity(buf.len()); |
254 | | |
255 | | // We ignore all errors returned by `poll_capacity` and `write`, as we |
256 | | // will get the correct from `poll_reset` anyway. |
257 | 0 | let cnt = match ready!(self.send_stream.poll_capacity(cx)) { |
258 | 0 | None => Some(0), |
259 | 0 | Some(Ok(cnt)) => self.write_slice(buf.slice(..cnt), false).ok().map(|()| cnt), |
260 | 0 | Some(Err(_)) => None, |
261 | | }; |
262 | | |
263 | 0 | if let Some(cnt) = cnt { |
264 | 0 | return Poll::Ready(Ok(cnt)); |
265 | 0 | } |
266 | | |
267 | 0 | Poll::Ready(Err(h2_to_io_error( |
268 | 0 | match ready!(self.send_stream.poll_reset(cx)) { |
269 | | Ok(Reason::NO_ERROR) | Ok(Reason::CANCEL) | Ok(Reason::STREAM_CLOSED) => { |
270 | 0 | return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())); |
271 | | } |
272 | 0 | Ok(reason) => reason.into(), |
273 | 0 | Err(e) => e, |
274 | | }, |
275 | | ))) |
276 | 0 | } |
277 | | |
278 | 0 | fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> { |
279 | 0 | Poll::Ready(Ok(())) |
280 | 0 | } |
281 | | |
282 | 0 | fn poll_shutdown( |
283 | 0 | mut self: Pin<&mut Self>, |
284 | 0 | cx: &mut Context<'_>, |
285 | 0 | ) -> Poll<Result<(), std::io::Error>> { |
286 | 0 | let r = self.write_slice(Bytes::new(), true); |
287 | 0 | if r.is_ok() { |
288 | 0 | return Poll::Ready(Ok(())); |
289 | 0 | } |
290 | | |
291 | 0 | Poll::Ready(Err(h2_to_io_error( |
292 | 0 | match ready!(self.send_stream.poll_reset(cx)) { |
293 | 0 | Ok(Reason::NO_ERROR) => return Poll::Ready(Ok(())), |
294 | | Ok(Reason::CANCEL) | Ok(Reason::STREAM_CLOSED) => { |
295 | 0 | return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())); |
296 | | } |
297 | 0 | Ok(reason) => reason.into(), |
298 | 0 | Err(e) => e, |
299 | | }, |
300 | | ))) |
301 | 0 | } |
302 | | } |
303 | | |
304 | 0 | fn h2_to_io_error(e: h2::Error) -> std::io::Error { |
305 | 0 | if e.is_io() { |
306 | 0 | e.into_io().unwrap() |
307 | | } else { |
308 | 0 | std::io::Error::other(e) |
309 | | } |
310 | 0 | } |