/rust/registry/src/index.crates.io-6f17d22bba15001f/tokio-rustls-0.26.1/src/client.rs
Line | Count | Source (jump to first uncovered line) |
1 | | use std::io; |
2 | | #[cfg(unix)] |
3 | | use std::os::unix::io::{AsRawFd, RawFd}; |
4 | | #[cfg(windows)] |
5 | | use std::os::windows::io::{AsRawSocket, RawSocket}; |
6 | | use std::pin::Pin; |
7 | | #[cfg(feature = "early-data")] |
8 | | use std::task::Waker; |
9 | | use std::task::{Context, Poll}; |
10 | | |
11 | | use rustls::ClientConnection; |
12 | | use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; |
13 | | |
14 | | use crate::common::{IoSession, Stream, TlsState}; |
15 | | |
16 | | /// A wrapper around an underlying raw stream which implements the TLS or SSL |
17 | | /// protocol. |
18 | | #[derive(Debug)] |
19 | | pub struct TlsStream<IO> { |
20 | | pub(crate) io: IO, |
21 | | pub(crate) session: ClientConnection, |
22 | | pub(crate) state: TlsState, |
23 | | |
24 | | #[cfg(feature = "early-data")] |
25 | | pub(crate) early_waker: Option<Waker>, |
26 | | } |
27 | | |
28 | | impl<IO> TlsStream<IO> { |
29 | | #[inline] |
30 | 0 | pub fn get_ref(&self) -> (&IO, &ClientConnection) { |
31 | 0 | (&self.io, &self.session) |
32 | 0 | } Unexecuted instantiation: <tokio_rustls::client::TlsStream<linkerd_io::scoped::ScopedIo<tokio::net::tcp::stream::TcpStream>>>::get_ref Unexecuted instantiation: <tokio_rustls::client::TlsStream<_>>::get_ref |
33 | | |
34 | | #[inline] |
35 | 0 | pub fn get_mut(&mut self) -> (&mut IO, &mut ClientConnection) { |
36 | 0 | (&mut self.io, &mut self.session) |
37 | 0 | } |
38 | | |
39 | | #[inline] |
40 | 0 | pub fn into_inner(self) -> (IO, ClientConnection) { |
41 | 0 | (self.io, self.session) |
42 | 0 | } |
43 | | } |
44 | | |
45 | | #[cfg(unix)] |
46 | | impl<S> AsRawFd for TlsStream<S> |
47 | | where |
48 | | S: AsRawFd, |
49 | | { |
50 | 0 | fn as_raw_fd(&self) -> RawFd { |
51 | 0 | self.get_ref().0.as_raw_fd() |
52 | 0 | } |
53 | | } |
54 | | |
55 | | #[cfg(windows)] |
56 | | impl<S> AsRawSocket for TlsStream<S> |
57 | | where |
58 | | S: AsRawSocket, |
59 | | { |
60 | | fn as_raw_socket(&self) -> RawSocket { |
61 | | self.get_ref().0.as_raw_socket() |
62 | | } |
63 | | } |
64 | | |
65 | | impl<IO> IoSession for TlsStream<IO> { |
66 | | type Io = IO; |
67 | | type Session = ClientConnection; |
68 | | |
69 | | #[inline] |
70 | 0 | fn skip_handshake(&self) -> bool { |
71 | 0 | self.state.is_early_data() |
72 | 0 | } Unexecuted instantiation: <tokio_rustls::client::TlsStream<linkerd_io::scoped::ScopedIo<tokio::net::tcp::stream::TcpStream>> as tokio_rustls::common::handshake::IoSession>::skip_handshake Unexecuted instantiation: <tokio_rustls::client::TlsStream<_> as tokio_rustls::common::handshake::IoSession>::skip_handshake |
73 | | |
74 | | #[inline] |
75 | 0 | fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) { |
76 | 0 | (&mut self.state, &mut self.io, &mut self.session) |
77 | 0 | } Unexecuted instantiation: <tokio_rustls::client::TlsStream<linkerd_io::scoped::ScopedIo<tokio::net::tcp::stream::TcpStream>> as tokio_rustls::common::handshake::IoSession>::get_mut Unexecuted instantiation: <tokio_rustls::client::TlsStream<_> as tokio_rustls::common::handshake::IoSession>::get_mut |
78 | | |
79 | | #[inline] |
80 | 0 | fn into_io(self) -> Self::Io { |
81 | 0 | self.io |
82 | 0 | } Unexecuted instantiation: <tokio_rustls::client::TlsStream<linkerd_io::scoped::ScopedIo<tokio::net::tcp::stream::TcpStream>> as tokio_rustls::common::handshake::IoSession>::into_io Unexecuted instantiation: <tokio_rustls::client::TlsStream<_> as tokio_rustls::common::handshake::IoSession>::into_io |
83 | | } |
84 | | |
85 | | impl<IO> AsyncRead for TlsStream<IO> |
86 | | where |
87 | | IO: AsyncRead + AsyncWrite + Unpin, |
88 | | { |
89 | 0 | fn poll_read( |
90 | 0 | self: Pin<&mut Self>, |
91 | 0 | cx: &mut Context<'_>, |
92 | 0 | buf: &mut ReadBuf<'_>, |
93 | 0 | ) -> Poll<io::Result<()>> { |
94 | 0 | match self.state { |
95 | | #[cfg(feature = "early-data")] |
96 | | TlsState::EarlyData(..) => { |
97 | | let this = self.get_mut(); |
98 | | |
99 | | // In the EarlyData state, we have not really established a Tls connection. |
100 | | // Before writing data through `AsyncWrite` and completing the tls handshake, |
101 | | // we ignore read readiness and return to pending. |
102 | | // |
103 | | // In order to avoid event loss, |
104 | | // we need to register a waker and wake it up after tls is connected. |
105 | | if this |
106 | | .early_waker |
107 | | .as_ref() |
108 | | .filter(|waker| cx.waker().will_wake(waker)) |
109 | | .is_none() |
110 | | { |
111 | | this.early_waker = Some(cx.waker().clone()); |
112 | | } |
113 | | |
114 | | Poll::Pending |
115 | | } |
116 | | TlsState::Stream | TlsState::WriteShutdown => { |
117 | 0 | let this = self.get_mut(); |
118 | 0 | let mut stream = |
119 | 0 | Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); |
120 | 0 | let prev = buf.remaining(); |
121 | 0 |
|
122 | 0 | match stream.as_mut_pin().poll_read(cx, buf) { |
123 | | Poll::Ready(Ok(())) => { |
124 | 0 | if prev == buf.remaining() || stream.eof { |
125 | 0 | this.state.shutdown_read(); |
126 | 0 | } |
127 | | |
128 | 0 | Poll::Ready(Ok(())) |
129 | | } |
130 | 0 | Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => { |
131 | 0 | this.state.shutdown_read(); |
132 | 0 | Poll::Ready(Err(err)) |
133 | | } |
134 | 0 | output => output, |
135 | | } |
136 | | } |
137 | 0 | TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())), |
138 | | } |
139 | 0 | } Unexecuted instantiation: <tokio_rustls::client::TlsStream<linkerd_io::scoped::ScopedIo<tokio::net::tcp::stream::TcpStream>> as tokio::io::async_read::AsyncRead>::poll_read Unexecuted instantiation: <tokio_rustls::client::TlsStream<_> as tokio::io::async_read::AsyncRead>::poll_read |
140 | | } |
141 | | |
142 | | impl<IO> AsyncWrite for TlsStream<IO> |
143 | | where |
144 | | IO: AsyncRead + AsyncWrite + Unpin, |
145 | | { |
146 | | /// Note: that it does not guarantee the final data to be sent. |
147 | | /// To be cautious, you must manually call `flush`. |
148 | 0 | fn poll_write( |
149 | 0 | self: Pin<&mut Self>, |
150 | 0 | cx: &mut Context<'_>, |
151 | 0 | buf: &[u8], |
152 | 0 | ) -> Poll<io::Result<usize>> { |
153 | 0 | let this = self.get_mut(); |
154 | 0 | let mut stream = |
155 | 0 | Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); |
156 | 0 |
|
157 | 0 | #[cfg(feature = "early-data")] |
158 | 0 | { |
159 | 0 | let bufs = [io::IoSlice::new(buf)]; |
160 | 0 | let written = ready!(poll_handle_early_data( |
161 | 0 | &mut this.state, |
162 | 0 | &mut stream, |
163 | 0 | &mut this.early_waker, |
164 | 0 | cx, |
165 | 0 | &bufs |
166 | 0 | ))?; |
167 | 0 | if written != 0 { |
168 | 0 | return Poll::Ready(Ok(written)); |
169 | 0 | } |
170 | 0 | } |
171 | 0 |
|
172 | 0 | stream.as_mut_pin().poll_write(cx, buf) |
173 | 0 | } Unexecuted instantiation: <tokio_rustls::client::TlsStream<linkerd_io::scoped::ScopedIo<tokio::net::tcp::stream::TcpStream>> as tokio::io::async_write::AsyncWrite>::poll_write Unexecuted instantiation: <tokio_rustls::client::TlsStream<_> as tokio::io::async_write::AsyncWrite>::poll_write |
174 | | |
175 | | /// Note: that it does not guarantee the final data to be sent. |
176 | | /// To be cautious, you must manually call `flush`. |
177 | 0 | fn poll_write_vectored( |
178 | 0 | self: Pin<&mut Self>, |
179 | 0 | cx: &mut Context<'_>, |
180 | 0 | bufs: &[io::IoSlice<'_>], |
181 | 0 | ) -> Poll<io::Result<usize>> { |
182 | 0 | let this = self.get_mut(); |
183 | 0 | let mut stream = |
184 | 0 | Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); |
185 | 0 |
|
186 | 0 | #[cfg(feature = "early-data")] |
187 | 0 | { |
188 | 0 | let written = ready!(poll_handle_early_data( |
189 | 0 | &mut this.state, |
190 | 0 | &mut stream, |
191 | 0 | &mut this.early_waker, |
192 | 0 | cx, |
193 | 0 | bufs |
194 | 0 | ))?; |
195 | 0 | if written != 0 { |
196 | 0 | return Poll::Ready(Ok(written)); |
197 | 0 | } |
198 | 0 | } |
199 | 0 |
|
200 | 0 | stream.as_mut_pin().poll_write_vectored(cx, bufs) |
201 | 0 | } Unexecuted instantiation: <tokio_rustls::client::TlsStream<linkerd_io::scoped::ScopedIo<tokio::net::tcp::stream::TcpStream>> as tokio::io::async_write::AsyncWrite>::poll_write_vectored Unexecuted instantiation: <tokio_rustls::client::TlsStream<_> as tokio::io::async_write::AsyncWrite>::poll_write_vectored |
202 | | |
203 | | #[inline] |
204 | 0 | fn is_write_vectored(&self) -> bool { |
205 | 0 | true |
206 | 0 | } Unexecuted instantiation: <tokio_rustls::client::TlsStream<linkerd_io::scoped::ScopedIo<tokio::net::tcp::stream::TcpStream>> as tokio::io::async_write::AsyncWrite>::is_write_vectored Unexecuted instantiation: <tokio_rustls::client::TlsStream<_> as tokio::io::async_write::AsyncWrite>::is_write_vectored |
207 | | |
208 | 0 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
209 | 0 | let this = self.get_mut(); |
210 | 0 | let mut stream = |
211 | 0 | Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); |
212 | 0 |
|
213 | 0 | #[cfg(feature = "early-data")] |
214 | 0 | ready!(poll_handle_early_data( |
215 | 0 | &mut this.state, |
216 | 0 | &mut stream, |
217 | 0 | &mut this.early_waker, |
218 | 0 | cx, |
219 | 0 | &[] |
220 | 0 | ))?; |
221 | 0 |
|
222 | 0 | stream.as_mut_pin().poll_flush(cx) |
223 | 0 | } Unexecuted instantiation: <tokio_rustls::client::TlsStream<linkerd_io::scoped::ScopedIo<tokio::net::tcp::stream::TcpStream>> as tokio::io::async_write::AsyncWrite>::poll_flush Unexecuted instantiation: <tokio_rustls::client::TlsStream<_> as tokio::io::async_write::AsyncWrite>::poll_flush |
224 | | |
225 | 0 | fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
226 | 0 | #[cfg(feature = "early-data")] |
227 | 0 | { |
228 | 0 | // complete handshake |
229 | 0 | if matches!(self.state, TlsState::EarlyData(..)) { |
230 | 0 | ready!(self.as_mut().poll_flush(cx))?; |
231 | 0 | } |
232 | 0 | } |
233 | 0 |
|
234 | 0 | if self.state.writeable() { |
235 | 0 | self.session.send_close_notify(); |
236 | 0 | self.state.shutdown_write(); |
237 | 0 | } |
238 | | |
239 | 0 | let this = self.get_mut(); |
240 | 0 | let mut stream = |
241 | 0 | Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); |
242 | 0 | stream.as_mut_pin().poll_shutdown(cx) |
243 | 0 | } Unexecuted instantiation: <tokio_rustls::client::TlsStream<linkerd_io::scoped::ScopedIo<tokio::net::tcp::stream::TcpStream>> as tokio::io::async_write::AsyncWrite>::poll_shutdown Unexecuted instantiation: <tokio_rustls::client::TlsStream<_> as tokio::io::async_write::AsyncWrite>::poll_shutdown |
244 | | } |
245 | | |
246 | | #[cfg(feature = "early-data")] |
247 | | fn poll_handle_early_data<IO>( |
248 | | state: &mut TlsState, |
249 | | stream: &mut Stream<IO, ClientConnection>, |
250 | | early_waker: &mut Option<Waker>, |
251 | | cx: &mut Context<'_>, |
252 | | bufs: &[io::IoSlice<'_>], |
253 | | ) -> Poll<io::Result<usize>> |
254 | | where |
255 | | IO: AsyncRead + AsyncWrite + Unpin, |
256 | | { |
257 | | if let TlsState::EarlyData(pos, data) = state { |
258 | | use std::io::Write; |
259 | | |
260 | | // write early data |
261 | | if let Some(mut early_data) = stream.session.early_data() { |
262 | | let mut written = 0; |
263 | | |
264 | | for buf in bufs { |
265 | | if buf.is_empty() { |
266 | | continue; |
267 | | } |
268 | | |
269 | | let len = match early_data.write(buf) { |
270 | | Ok(0) => break, |
271 | | Ok(n) => n, |
272 | | Err(err) => return Poll::Ready(Err(err)), |
273 | | }; |
274 | | |
275 | | written += len; |
276 | | data.extend_from_slice(&buf[..len]); |
277 | | |
278 | | if len < buf.len() { |
279 | | break; |
280 | | } |
281 | | } |
282 | | |
283 | | if written != 0 { |
284 | | return Poll::Ready(Ok(written)); |
285 | | } |
286 | | } |
287 | | |
288 | | // complete handshake |
289 | | while stream.session.is_handshaking() { |
290 | | ready!(stream.handshake(cx))?; |
291 | | } |
292 | | |
293 | | // write early data (fallback) |
294 | | if !stream.session.is_early_data_accepted() { |
295 | | while *pos < data.len() { |
296 | | let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; |
297 | | *pos += len; |
298 | | } |
299 | | } |
300 | | |
301 | | // end |
302 | | *state = TlsState::Stream; |
303 | | |
304 | | if let Some(waker) = early_waker.take() { |
305 | | waker.wake(); |
306 | | } |
307 | | } |
308 | | |
309 | | Poll::Ready(Ok(0)) |
310 | | } |