Coverage Report

Created: 2024-12-17 06:15

/rust/registry/src/index.crates.io-6f17d22bba15001f/tokio-rustls-0.26.1/src/client.rs
Line
Count
Source (jump to first uncovered line)
1
use std::io;
2
#[cfg(unix)]
3
use std::os::unix::io::{AsRawFd, RawFd};
4
#[cfg(windows)]
5
use std::os::windows::io::{AsRawSocket, RawSocket};
6
use std::pin::Pin;
7
#[cfg(feature = "early-data")]
8
use std::task::Waker;
9
use std::task::{Context, Poll};
10
11
use rustls::ClientConnection;
12
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13
14
use crate::common::{IoSession, Stream, TlsState};
15
16
/// A wrapper around an underlying raw stream which implements the TLS or SSL
17
/// protocol.
18
#[derive(Debug)]
19
pub struct TlsStream<IO> {
20
    pub(crate) io: IO,
21
    pub(crate) session: ClientConnection,
22
    pub(crate) state: TlsState,
23
24
    #[cfg(feature = "early-data")]
25
    pub(crate) early_waker: Option<Waker>,
26
}
27
28
impl<IO> TlsStream<IO> {
29
    #[inline]
30
0
    pub fn get_ref(&self) -> (&IO, &ClientConnection) {
31
0
        (&self.io, &self.session)
32
0
    }
Unexecuted instantiation: <tokio_rustls::client::TlsStream<linkerd_io::scoped::ScopedIo<tokio::net::tcp::stream::TcpStream>>>::get_ref
Unexecuted instantiation: <tokio_rustls::client::TlsStream<_>>::get_ref
33
34
    #[inline]
35
0
    pub fn get_mut(&mut self) -> (&mut IO, &mut ClientConnection) {
36
0
        (&mut self.io, &mut self.session)
37
0
    }
38
39
    #[inline]
40
0
    pub fn into_inner(self) -> (IO, ClientConnection) {
41
0
        (self.io, self.session)
42
0
    }
43
}
44
45
#[cfg(unix)]
46
impl<S> AsRawFd for TlsStream<S>
47
where
48
    S: AsRawFd,
49
{
50
0
    fn as_raw_fd(&self) -> RawFd {
51
0
        self.get_ref().0.as_raw_fd()
52
0
    }
53
}
54
55
#[cfg(windows)]
56
impl<S> AsRawSocket for TlsStream<S>
57
where
58
    S: AsRawSocket,
59
{
60
    fn as_raw_socket(&self) -> RawSocket {
61
        self.get_ref().0.as_raw_socket()
62
    }
63
}
64
65
impl<IO> IoSession for TlsStream<IO> {
66
    type Io = IO;
67
    type Session = ClientConnection;
68
69
    #[inline]
70
0
    fn skip_handshake(&self) -> bool {
71
0
        self.state.is_early_data()
72
0
    }
Unexecuted instantiation: <tokio_rustls::client::TlsStream<linkerd_io::scoped::ScopedIo<tokio::net::tcp::stream::TcpStream>> as tokio_rustls::common::handshake::IoSession>::skip_handshake
Unexecuted instantiation: <tokio_rustls::client::TlsStream<_> as tokio_rustls::common::handshake::IoSession>::skip_handshake
73
74
    #[inline]
75
0
    fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
76
0
        (&mut self.state, &mut self.io, &mut self.session)
77
0
    }
Unexecuted instantiation: <tokio_rustls::client::TlsStream<linkerd_io::scoped::ScopedIo<tokio::net::tcp::stream::TcpStream>> as tokio_rustls::common::handshake::IoSession>::get_mut
Unexecuted instantiation: <tokio_rustls::client::TlsStream<_> as tokio_rustls::common::handshake::IoSession>::get_mut
78
79
    #[inline]
80
0
    fn into_io(self) -> Self::Io {
81
0
        self.io
82
0
    }
Unexecuted instantiation: <tokio_rustls::client::TlsStream<linkerd_io::scoped::ScopedIo<tokio::net::tcp::stream::TcpStream>> as tokio_rustls::common::handshake::IoSession>::into_io
Unexecuted instantiation: <tokio_rustls::client::TlsStream<_> as tokio_rustls::common::handshake::IoSession>::into_io
83
}
84
85
impl<IO> AsyncRead for TlsStream<IO>
86
where
87
    IO: AsyncRead + AsyncWrite + Unpin,
88
{
89
0
    fn poll_read(
90
0
        self: Pin<&mut Self>,
91
0
        cx: &mut Context<'_>,
92
0
        buf: &mut ReadBuf<'_>,
93
0
    ) -> Poll<io::Result<()>> {
94
0
        match self.state {
95
            #[cfg(feature = "early-data")]
96
            TlsState::EarlyData(..) => {
97
                let this = self.get_mut();
98
99
                // In the EarlyData state, we have not really established a Tls connection.
100
                // Before writing data through `AsyncWrite` and completing the tls handshake,
101
                // we ignore read readiness and return to pending.
102
                //
103
                // In order to avoid event loss,
104
                // we need to register a waker and wake it up after tls is connected.
105
                if this
106
                    .early_waker
107
                    .as_ref()
108
                    .filter(|waker| cx.waker().will_wake(waker))
109
                    .is_none()
110
                {
111
                    this.early_waker = Some(cx.waker().clone());
112
                }
113
114
                Poll::Pending
115
            }
116
            TlsState::Stream | TlsState::WriteShutdown => {
117
0
                let this = self.get_mut();
118
0
                let mut stream =
119
0
                    Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
120
0
                let prev = buf.remaining();
121
0
122
0
                match stream.as_mut_pin().poll_read(cx, buf) {
123
                    Poll::Ready(Ok(())) => {
124
0
                        if prev == buf.remaining() || stream.eof {
125
0
                            this.state.shutdown_read();
126
0
                        }
127
128
0
                        Poll::Ready(Ok(()))
129
                    }
130
0
                    Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
131
0
                        this.state.shutdown_read();
132
0
                        Poll::Ready(Err(err))
133
                    }
134
0
                    output => output,
135
                }
136
            }
137
0
            TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())),
138
        }
139
0
    }
Unexecuted instantiation: <tokio_rustls::client::TlsStream<linkerd_io::scoped::ScopedIo<tokio::net::tcp::stream::TcpStream>> as tokio::io::async_read::AsyncRead>::poll_read
Unexecuted instantiation: <tokio_rustls::client::TlsStream<_> as tokio::io::async_read::AsyncRead>::poll_read
140
}
141
142
impl<IO> AsyncWrite for TlsStream<IO>
143
where
144
    IO: AsyncRead + AsyncWrite + Unpin,
145
{
146
    /// Note: that it does not guarantee the final data to be sent.
147
    /// To be cautious, you must manually call `flush`.
148
0
    fn poll_write(
149
0
        self: Pin<&mut Self>,
150
0
        cx: &mut Context<'_>,
151
0
        buf: &[u8],
152
0
    ) -> Poll<io::Result<usize>> {
153
0
        let this = self.get_mut();
154
0
        let mut stream =
155
0
            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
156
0
157
0
        #[cfg(feature = "early-data")]
158
0
        {
159
0
            let bufs = [io::IoSlice::new(buf)];
160
0
            let written = ready!(poll_handle_early_data(
161
0
                &mut this.state,
162
0
                &mut stream,
163
0
                &mut this.early_waker,
164
0
                cx,
165
0
                &bufs
166
0
            ))?;
167
0
            if written != 0 {
168
0
                return Poll::Ready(Ok(written));
169
0
            }
170
0
        }
171
0
172
0
        stream.as_mut_pin().poll_write(cx, buf)
173
0
    }
Unexecuted instantiation: <tokio_rustls::client::TlsStream<linkerd_io::scoped::ScopedIo<tokio::net::tcp::stream::TcpStream>> as tokio::io::async_write::AsyncWrite>::poll_write
Unexecuted instantiation: <tokio_rustls::client::TlsStream<_> as tokio::io::async_write::AsyncWrite>::poll_write
174
175
    /// Note: that it does not guarantee the final data to be sent.
176
    /// To be cautious, you must manually call `flush`.
177
0
    fn poll_write_vectored(
178
0
        self: Pin<&mut Self>,
179
0
        cx: &mut Context<'_>,
180
0
        bufs: &[io::IoSlice<'_>],
181
0
    ) -> Poll<io::Result<usize>> {
182
0
        let this = self.get_mut();
183
0
        let mut stream =
184
0
            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
185
0
186
0
        #[cfg(feature = "early-data")]
187
0
        {
188
0
            let written = ready!(poll_handle_early_data(
189
0
                &mut this.state,
190
0
                &mut stream,
191
0
                &mut this.early_waker,
192
0
                cx,
193
0
                bufs
194
0
            ))?;
195
0
            if written != 0 {
196
0
                return Poll::Ready(Ok(written));
197
0
            }
198
0
        }
199
0
200
0
        stream.as_mut_pin().poll_write_vectored(cx, bufs)
201
0
    }
Unexecuted instantiation: <tokio_rustls::client::TlsStream<linkerd_io::scoped::ScopedIo<tokio::net::tcp::stream::TcpStream>> as tokio::io::async_write::AsyncWrite>::poll_write_vectored
Unexecuted instantiation: <tokio_rustls::client::TlsStream<_> as tokio::io::async_write::AsyncWrite>::poll_write_vectored
202
203
    #[inline]
204
0
    fn is_write_vectored(&self) -> bool {
205
0
        true
206
0
    }
Unexecuted instantiation: <tokio_rustls::client::TlsStream<linkerd_io::scoped::ScopedIo<tokio::net::tcp::stream::TcpStream>> as tokio::io::async_write::AsyncWrite>::is_write_vectored
Unexecuted instantiation: <tokio_rustls::client::TlsStream<_> as tokio::io::async_write::AsyncWrite>::is_write_vectored
207
208
0
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
209
0
        let this = self.get_mut();
210
0
        let mut stream =
211
0
            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
212
0
213
0
        #[cfg(feature = "early-data")]
214
0
        ready!(poll_handle_early_data(
215
0
            &mut this.state,
216
0
            &mut stream,
217
0
            &mut this.early_waker,
218
0
            cx,
219
0
            &[]
220
0
        ))?;
221
0
222
0
        stream.as_mut_pin().poll_flush(cx)
223
0
    }
Unexecuted instantiation: <tokio_rustls::client::TlsStream<linkerd_io::scoped::ScopedIo<tokio::net::tcp::stream::TcpStream>> as tokio::io::async_write::AsyncWrite>::poll_flush
Unexecuted instantiation: <tokio_rustls::client::TlsStream<_> as tokio::io::async_write::AsyncWrite>::poll_flush
224
225
0
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
226
0
        #[cfg(feature = "early-data")]
227
0
        {
228
0
            // complete handshake
229
0
            if matches!(self.state, TlsState::EarlyData(..)) {
230
0
                ready!(self.as_mut().poll_flush(cx))?;
231
0
            }
232
0
        }
233
0
234
0
        if self.state.writeable() {
235
0
            self.session.send_close_notify();
236
0
            self.state.shutdown_write();
237
0
        }
238
239
0
        let this = self.get_mut();
240
0
        let mut stream =
241
0
            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
242
0
        stream.as_mut_pin().poll_shutdown(cx)
243
0
    }
Unexecuted instantiation: <tokio_rustls::client::TlsStream<linkerd_io::scoped::ScopedIo<tokio::net::tcp::stream::TcpStream>> as tokio::io::async_write::AsyncWrite>::poll_shutdown
Unexecuted instantiation: <tokio_rustls::client::TlsStream<_> as tokio::io::async_write::AsyncWrite>::poll_shutdown
244
}
245
246
#[cfg(feature = "early-data")]
247
fn poll_handle_early_data<IO>(
248
    state: &mut TlsState,
249
    stream: &mut Stream<IO, ClientConnection>,
250
    early_waker: &mut Option<Waker>,
251
    cx: &mut Context<'_>,
252
    bufs: &[io::IoSlice<'_>],
253
) -> Poll<io::Result<usize>>
254
where
255
    IO: AsyncRead + AsyncWrite + Unpin,
256
{
257
    if let TlsState::EarlyData(pos, data) = state {
258
        use std::io::Write;
259
260
        // write early data
261
        if let Some(mut early_data) = stream.session.early_data() {
262
            let mut written = 0;
263
264
            for buf in bufs {
265
                if buf.is_empty() {
266
                    continue;
267
                }
268
269
                let len = match early_data.write(buf) {
270
                    Ok(0) => break,
271
                    Ok(n) => n,
272
                    Err(err) => return Poll::Ready(Err(err)),
273
                };
274
275
                written += len;
276
                data.extend_from_slice(&buf[..len]);
277
278
                if len < buf.len() {
279
                    break;
280
                }
281
            }
282
283
            if written != 0 {
284
                return Poll::Ready(Ok(written));
285
            }
286
        }
287
288
        // complete handshake
289
        while stream.session.is_handshaking() {
290
            ready!(stream.handshake(cx))?;
291
        }
292
293
        // write early data (fallback)
294
        if !stream.session.is_early_data_accepted() {
295
            while *pos < data.len() {
296
                let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
297
                *pos += len;
298
            }
299
        }
300
301
        // end
302
        *state = TlsState::Stream;
303
304
        if let Some(waker) = early_waker.take() {
305
            waker.wake();
306
        }
307
    }
308
309
    Poll::Ready(Ok(0))
310
}