Coverage Report

Created: 2025-10-29 07:05

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/ztunnel/src/proxy/h2.rs
Line
Count
Source
1
// Copyright Istio Authors
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
use crate::copy;
16
use bytes::Bytes;
17
use futures_core::ready;
18
use h2::Reason;
19
use std::io::Error;
20
use std::pin::Pin;
21
use std::sync::Arc;
22
use std::sync::atomic::{AtomicBool, AtomicU16, Ordering};
23
use std::task::{Context, Poll};
24
use std::time::Duration;
25
use tokio::sync::oneshot;
26
use tracing::trace;
27
28
pub mod client;
29
pub mod server;
30
31
0
async fn do_ping_pong(
32
0
    mut ping_pong: h2::PingPong,
33
0
    tx: oneshot::Sender<()>,
34
0
    dropped: Arc<AtomicBool>,
35
0
) {
36
    const PING_INTERVAL: Duration = Duration::from_secs(10);
37
    const PING_TIMEOUT: Duration = Duration::from_secs(20);
38
    // delay before sending the first ping, no need to race with the first request
39
0
    tokio::time::sleep(PING_INTERVAL).await;
40
    loop {
41
0
        if dropped.load(Ordering::Relaxed) {
42
0
            return;
43
0
        }
44
0
        let ping_fut = ping_pong.ping(h2::Ping::opaque());
45
0
        log::trace!("ping sent");
46
0
        match tokio::time::timeout(PING_TIMEOUT, ping_fut).await {
47
            Err(_) => {
48
                // We will log this again up in drive_connection, so don't worry about a high log level
49
0
                log::trace!("ping timeout");
50
0
                let _ = tx.send(());
51
0
                return;
52
            }
53
0
            Ok(r) => match r {
54
                Ok(_) => {
55
0
                    log::trace!("pong received");
56
0
                    tokio::time::sleep(PING_INTERVAL).await;
57
                }
58
0
                Err(e) => {
59
0
                    if dropped.load(Ordering::Relaxed) {
60
                        // drive_connection() exits first, no need to error again
61
0
                        return;
62
0
                    }
63
0
                    log::error!("ping error: {e}");
64
0
                    let _ = tx.send(());
65
0
                    return;
66
                }
67
            },
68
        }
69
    }
70
0
}
71
72
// H2Stream represents an active HTTP2 stream. Consumers can only Read/Write
73
pub struct H2Stream {
74
    read: H2StreamReadHalf,
75
    write: H2StreamWriteHalf,
76
}
77
78
pub struct H2StreamReadHalf {
79
    recv_stream: h2::RecvStream,
80
    _dropped: Option<DropCounter>,
81
}
82
83
pub struct H2StreamWriteHalf {
84
    send_stream: h2::SendStream<Bytes>,
85
    _dropped: Option<DropCounter>,
86
}
87
88
pub struct TokioH2Stream {
89
    stream: H2Stream,
90
    buf: Bytes,
91
}
92
93
/// Guard held by each half of one h2 stream. Together the pair decrements a shared
/// active-stream counter once.
struct DropCounter {
    // Marker shared by the two halves of the same stream. Its strong count lets a dropping
    // half tell whether its partner is still alive, so only one of the pair decrements
    // `active_count` and the stream is never double counted.
    half_dropped: Arc<()>,
    // Active-stream counter shared with whoever created the pair via `DropCounter::new`.
    active_count: Arc<AtomicU16>,
}
99
100
impl DropCounter {
101
0
    pub fn new(active_count: Arc<AtomicU16>) -> (Option<DropCounter>, Option<DropCounter>) {
102
0
        let half_dropped = Arc::new(());
103
0
        let d1 = DropCounter {
104
0
            half_dropped: half_dropped.clone(),
105
0
            active_count: active_count.clone(),
106
0
        };
107
0
        let d2 = DropCounter {
108
0
            half_dropped,
109
0
            active_count,
110
0
        };
111
0
        (Some(d1), Some(d2))
112
0
    }
113
}
114
115
impl crate::copy::BufferedSplitter for H2Stream {
116
    type R = H2StreamReadHalf;
117
    type W = H2StreamWriteHalf;
118
0
    fn split_into_buffered_reader(self) -> (H2StreamReadHalf, H2StreamWriteHalf) {
119
0
        let H2Stream { read, write } = self;
120
0
        (read, write)
121
0
    }
122
}
123
124
impl H2StreamWriteHalf {
125
0
    fn write_slice(&mut self, buf: Bytes, end_of_stream: bool) -> Result<(), std::io::Error> {
126
0
        self.send_stream
127
0
            .send_data(buf, end_of_stream)
128
0
            .map_err(h2_to_io_error)
129
0
    }
130
}
131
132
impl Drop for DropCounter {
133
0
    fn drop(&mut self) {
134
0
        let mut half_dropped = Arc::new(());
135
0
        std::mem::swap(&mut self.half_dropped, &mut half_dropped);
136
0
        if Arc::into_inner(half_dropped).is_none() {
137
            // other half already dropped
138
0
            let left = self.active_count.fetch_sub(1, Ordering::SeqCst);
139
0
            trace!("dropping H2Stream, has {} active streams left", left - 1);
140
        } else {
141
0
            trace!("dropping H2Stream, other half remains");
142
        }
143
0
    }
144
}
145
146
// We can't directly implement tokio::io::{AsyncRead, AsyncWrite} for H2Stream because
147
// then the specific implementation will conflict with the generic one.
148
impl TokioH2Stream {
149
0
    pub fn new(stream: H2Stream) -> Self {
150
0
        Self {
151
0
            stream,
152
0
            buf: Bytes::new(),
153
0
        }
154
0
    }
155
}
156
157
impl tokio::io::AsyncRead for TokioH2Stream {
158
0
    fn poll_read(
159
0
        mut self: Pin<&mut Self>,
160
0
        cx: &mut Context<'_>,
161
0
        buf: &mut tokio::io::ReadBuf<'_>,
162
0
    ) -> Poll<std::io::Result<()>> {
163
        // Just return the bytes we have left over and don't poll the stream because
164
        // its unclear what to do if there are bytes left over from the previous read, and when we
165
        // poll, we get an error.
166
0
        if self.buf.is_empty() {
167
            // If we have no unread bytes, we can poll the stream
168
            // and fill self.buf with the bytes we read.
169
0
            let pinned = std::pin::Pin::new(&mut self.stream.read);
170
0
            let res = ready!(copy::ResizeBufRead::poll_bytes(pinned, cx))?;
171
0
            self.buf = res;
172
0
        }
173
        // Copy as many bytes as we can from self.buf.
174
0
        let cnt = Ord::min(buf.remaining(), self.buf.len());
175
0
        buf.put_slice(&self.buf[..cnt]);
176
0
        self.buf = self.buf.split_off(cnt);
177
0
        Poll::Ready(Ok(()))
178
0
    }
179
}
180
181
impl tokio::io::AsyncWrite for TokioH2Stream {
182
0
    fn poll_write(
183
0
        mut self: Pin<&mut Self>,
184
0
        cx: &mut Context<'_>,
185
0
        buf: &[u8],
186
0
    ) -> Poll<Result<usize, tokio::io::Error>> {
187
0
        let pinned = std::pin::Pin::new(&mut self.stream.write);
188
0
        let buf = Bytes::copy_from_slice(buf);
189
0
        copy::AsyncWriteBuf::poll_write_buf(pinned, cx, buf)
190
0
    }
191
192
0
    fn poll_flush(
193
0
        mut self: Pin<&mut Self>,
194
0
        cx: &mut Context<'_>,
195
0
    ) -> Poll<Result<(), std::io::Error>> {
196
0
        let pinned = std::pin::Pin::new(&mut self.stream.write);
197
0
        copy::AsyncWriteBuf::poll_flush(pinned, cx)
198
0
    }
199
200
0
    fn poll_shutdown(
201
0
        mut self: Pin<&mut Self>,
202
0
        cx: &mut Context<'_>,
203
0
    ) -> Poll<Result<(), std::io::Error>> {
204
0
        let pinned = std::pin::Pin::new(&mut self.stream.write);
205
0
        copy::AsyncWriteBuf::poll_shutdown(pinned, cx)
206
0
    }
207
}
208
209
impl copy::ResizeBufRead for H2StreamReadHalf {
210
0
    fn poll_bytes(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<Bytes>> {
211
0
        let this = self.get_mut();
212
        loop {
213
0
            match ready!(this.recv_stream.poll_data(cx)) {
214
0
                None => return Poll::Ready(Ok(Bytes::new())),
215
0
                Some(Ok(buf)) if buf.is_empty() && !this.recv_stream.is_end_stream() => continue,
216
0
                Some(Ok(buf)) => {
217
                    // TODO: Hyper and Go make their pinging data aware and don't send pings when data is received
218
                    // Pingora, and our implementation, currently don't do this.
219
                    // We may want to; if so, modify here.
220
                    // this.ping.record_data(buf.len());
221
0
                    let _ = this.recv_stream.flow_control().release_capacity(buf.len());
222
0
                    return Poll::Ready(Ok(buf));
223
                }
224
0
                Some(Err(e)) => {
225
0
                    return Poll::Ready(match e.reason() {
226
                        Some(Reason::NO_ERROR) | Some(Reason::CANCEL) => {
227
0
                            return Poll::Ready(Ok(Bytes::new()));
228
                        }
229
                        Some(Reason::STREAM_CLOSED) => {
230
0
                            Err(Error::new(std::io::ErrorKind::BrokenPipe, e))
231
                        }
232
0
                        _ => Err(h2_to_io_error(e)),
233
                    });
234
                }
235
            }
236
        }
237
0
    }
238
239
0
    fn resize(self: Pin<&mut Self>, _new_size: usize) {
240
        // NOP, we don't need to resize as we are abstracting the h2 buffer
241
0
    }
242
}
243
244
impl copy::AsyncWriteBuf for H2StreamWriteHalf {
245
0
    fn poll_write_buf(
246
0
        mut self: Pin<&mut Self>,
247
0
        cx: &mut Context<'_>,
248
0
        buf: Bytes,
249
0
    ) -> Poll<std::io::Result<usize>> {
250
0
        if buf.is_empty() {
251
0
            return Poll::Ready(Ok(0));
252
0
        }
253
0
        self.send_stream.reserve_capacity(buf.len());
254
255
        // We ignore all errors returned by `poll_capacity` and `write`, as we
256
        // will get the correct from `poll_reset` anyway.
257
0
        let cnt = match ready!(self.send_stream.poll_capacity(cx)) {
258
0
            None => Some(0),
259
0
            Some(Ok(cnt)) => self.write_slice(buf.slice(..cnt), false).ok().map(|()| cnt),
260
0
            Some(Err(_)) => None,
261
        };
262
263
0
        if let Some(cnt) = cnt {
264
0
            return Poll::Ready(Ok(cnt));
265
0
        }
266
267
0
        Poll::Ready(Err(h2_to_io_error(
268
0
            match ready!(self.send_stream.poll_reset(cx)) {
269
                Ok(Reason::NO_ERROR) | Ok(Reason::CANCEL) | Ok(Reason::STREAM_CLOSED) => {
270
0
                    return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
271
                }
272
0
                Ok(reason) => reason.into(),
273
0
                Err(e) => e,
274
            },
275
        )))
276
0
    }
277
278
0
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
279
0
        Poll::Ready(Ok(()))
280
0
    }
281
282
0
    fn poll_shutdown(
283
0
        mut self: Pin<&mut Self>,
284
0
        cx: &mut Context<'_>,
285
0
    ) -> Poll<Result<(), std::io::Error>> {
286
0
        let r = self.write_slice(Bytes::new(), true);
287
0
        if r.is_ok() {
288
0
            return Poll::Ready(Ok(()));
289
0
        }
290
291
0
        Poll::Ready(Err(h2_to_io_error(
292
0
            match ready!(self.send_stream.poll_reset(cx)) {
293
0
                Ok(Reason::NO_ERROR) => return Poll::Ready(Ok(())),
294
                Ok(Reason::CANCEL) | Ok(Reason::STREAM_CLOSED) => {
295
0
                    return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
296
                }
297
0
                Ok(reason) => reason.into(),
298
0
                Err(e) => e,
299
            },
300
        )))
301
0
    }
302
}
303
304
0
fn h2_to_io_error(e: h2::Error) -> std::io::Error {
305
0
    if e.is_io() {
306
0
        e.into_io().unwrap()
307
    } else {
308
0
        std::io::Error::other(e)
309
    }
310
0
}