Coverage Report

Created: 2026-01-30 06:08

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/rust/registry/src/index.crates.io-1949cf8c6b5b557f/tokio-rustls-0.26.4/src/server.rs
Line
Count
Source
1
use std::future::Future;
2
use std::io::{self, BufRead as _};
3
#[cfg(unix)]
4
use std::os::unix::io::{AsRawFd, RawFd};
5
#[cfg(windows)]
6
use std::os::windows::io::{AsRawSocket, RawSocket};
7
use std::pin::Pin;
8
use std::sync::Arc;
9
use std::task::{Context, Poll};
10
11
use rustls::server::AcceptedAlert;
12
use rustls::{ServerConfig, ServerConnection};
13
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
14
15
use crate::common::{IoSession, MidHandshake, Stream, SyncReadAdapter, SyncWriteAdapter, TlsState};
16
17
/// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method.
18
#[derive(Clone)]
19
pub struct TlsAcceptor {
20
    inner: Arc<ServerConfig>,
21
}
22
23
impl From<Arc<ServerConfig>> for TlsAcceptor {
24
0
    fn from(inner: Arc<ServerConfig>) -> Self {
25
0
        Self { inner }
26
0
    }
27
}
28
29
impl TlsAcceptor {
30
    #[inline]
31
0
    pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
32
0
    where
33
0
        IO: AsyncRead + AsyncWrite + Unpin,
34
    {
35
0
        self.accept_with(stream, |_| ())
36
0
    }
37
38
0
    pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
39
0
    where
40
0
        IO: AsyncRead + AsyncWrite + Unpin,
41
0
        F: FnOnce(&mut ServerConnection),
42
    {
43
0
        let mut session = match ServerConnection::new(self.inner.clone()) {
44
0
            Ok(session) => session,
45
0
            Err(error) => {
46
0
                return Accept(MidHandshake::Error {
47
0
                    io: stream,
48
0
                    // TODO(eliza): should this really return an `io::Error`?
49
0
                    // Probably not...
50
0
                    error: io::Error::new(io::ErrorKind::Other, error),
51
0
                });
52
            }
53
        };
54
0
        f(&mut session);
55
56
0
        Accept(MidHandshake::Handshaking(TlsStream {
57
0
            session,
58
0
            io: stream,
59
0
            state: TlsState::Stream,
60
0
            need_flush: false,
61
0
        }))
62
0
    }
63
64
    /// Get a read-only reference to underlying config
65
0
    pub fn config(&self) -> &Arc<ServerConfig> {
66
0
        &self.inner
67
0
    }
68
}
69
70
pub struct LazyConfigAcceptor<IO> {
71
    acceptor: rustls::server::Acceptor,
72
    io: Option<IO>,
73
    alert: Option<(rustls::Error, AcceptedAlert)>,
74
}
75
76
impl<IO> LazyConfigAcceptor<IO>
77
where
78
    IO: AsyncRead + AsyncWrite + Unpin,
79
{
80
    #[inline]
81
0
    pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self {
82
0
        Self {
83
0
            acceptor,
84
0
            io: Some(io),
85
0
            alert: None,
86
0
        }
87
0
    }
88
89
    /// Takes back the client connection. Will return `None` if called more than once or if the
90
    /// connection has been accepted.
91
    ///
92
    /// # Example
93
    ///
94
    /// ```no_run
95
    /// # fn choose_server_config(
96
    /// #     _: rustls::server::ClientHello,
97
    /// # ) -> std::sync::Arc<rustls::ServerConfig> {
98
    /// #     unimplemented!();
99
    /// # }
100
    /// # #[allow(unused_variables)]
101
    /// # async fn listen() {
102
    /// use tokio::io::AsyncWriteExt;
103
    /// let listener = tokio::net::TcpListener::bind("127.0.0.1:4443").await.unwrap();
104
    /// let (stream, _) = listener.accept().await.unwrap();
105
    ///
106
    /// let acceptor = tokio_rustls::LazyConfigAcceptor::new(rustls::server::Acceptor::default(), stream);
107
    /// tokio::pin!(acceptor);
108
    ///
109
    /// match acceptor.as_mut().await {
110
    ///     Ok(start) => {
111
    ///         let clientHello = start.client_hello();
112
    ///         let config = choose_server_config(clientHello);
113
    ///         let stream = start.into_stream(config).await.unwrap();
114
    ///         // Proceed with handling the ServerConnection...
115
    ///     }
116
    ///     Err(err) => {
117
    ///         if let Some(mut stream) = acceptor.take_io() {
118
    ///             stream
119
    ///                 .write_all(
120
    ///                     format!("HTTP/1.1 400 Invalid Input\r\n\r\n\r\n{:?}\n", err)
121
    ///                         .as_bytes()
122
    ///                 )
123
    ///                 .await
124
    ///                 .unwrap();
125
    ///         }
126
    ///     }
127
    /// }
128
    /// # }
129
    /// ```
130
0
    pub fn take_io(&mut self) -> Option<IO> {
131
0
        self.io.take()
132
0
    }
133
}
134
135
impl<IO> Future for LazyConfigAcceptor<IO>
136
where
137
    IO: AsyncRead + AsyncWrite + Unpin,
138
{
139
    type Output = Result<StartHandshake<IO>, io::Error>;
140
141
0
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
142
0
        let this = self.get_mut();
143
        loop {
144
0
            let io = match this.io.as_mut() {
145
0
                Some(io) => io,
146
                None => {
147
0
                    return Poll::Ready(Err(io::Error::new(
148
0
                        io::ErrorKind::Other,
149
0
                        "acceptor cannot be polled after acceptance",
150
0
                    )))
151
                }
152
            };
153
154
0
            if let Some((err, mut alert)) = this.alert.take() {
155
0
                match alert.write(&mut SyncWriteAdapter { io, cx }) {
156
0
                    Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
157
0
                        this.alert = Some((err, alert));
158
0
                        return Poll::Pending;
159
                    }
160
                    Ok(0) | Err(_) => {
161
0
                        return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, err)))
162
                    }
163
                    Ok(_) => {
164
0
                        this.alert = Some((err, alert));
165
0
                        continue;
166
                    }
167
                };
168
0
            }
169
170
0
            let mut reader = SyncReadAdapter { io, cx };
171
0
            match this.acceptor.read_tls(&mut reader) {
172
0
                Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()).into(),
173
0
                Ok(_) => {}
174
0
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
175
0
                Err(e) => return Err(e).into(),
176
            }
177
178
0
            match this.acceptor.accept() {
179
0
                Ok(Some(accepted)) => {
180
0
                    let io = this.io.take().unwrap();
181
0
                    return Poll::Ready(Ok(StartHandshake { accepted, io }));
182
                }
183
0
                Ok(None) => {}
184
0
                Err((err, alert)) => {
185
0
                    this.alert = Some((err, alert));
186
0
                }
187
            }
188
        }
189
0
    }
190
}
191
192
/// An incoming connection received through [`LazyConfigAcceptor`].
193
///
194
/// This contains the generic `IO` asynchronous transport,
195
/// [`ClientHello`](rustls::server::ClientHello) data,
196
/// and all the state required to continue the TLS handshake (e.g. via
197
/// [`StartHandshake::into_stream`]).
198
#[non_exhaustive]
199
#[derive(Debug)]
200
pub struct StartHandshake<IO> {
201
    pub accepted: rustls::server::Accepted,
202
    pub io: IO,
203
}
204
205
impl<IO> StartHandshake<IO>
206
where
207
    IO: AsyncRead + AsyncWrite + Unpin,
208
{
209
    /// Create a new object from an `IO` transport and prior TLS metadata.
210
0
    pub fn from_parts(accepted: rustls::server::Accepted, transport: IO) -> Self {
211
0
        Self {
212
0
            accepted,
213
0
            io: transport,
214
0
        }
215
0
    }
216
217
0
    pub fn client_hello(&self) -> rustls::server::ClientHello<'_> {
218
0
        self.accepted.client_hello()
219
0
    }
220
221
0
    pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> {
222
0
        self.into_stream_with(config, |_| ())
223
0
    }
224
225
0
    pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO>
226
0
    where
227
0
        F: FnOnce(&mut ServerConnection),
228
    {
229
0
        let mut conn = match self.accepted.into_connection(config) {
230
0
            Ok(conn) => conn,
231
0
            Err((error, alert)) => {
232
0
                return Accept(MidHandshake::SendAlert {
233
0
                    io: self.io,
234
0
                    alert,
235
0
                    // TODO(eliza): should this really return an `io::Error`?
236
0
                    // Probably not...
237
0
                    error: io::Error::new(io::ErrorKind::InvalidData, error),
238
0
                });
239
            }
240
        };
241
0
        f(&mut conn);
242
243
0
        Accept(MidHandshake::Handshaking(TlsStream {
244
0
            session: conn,
245
0
            io: self.io,
246
0
            state: TlsState::Stream,
247
0
            need_flush: false,
248
0
        }))
249
0
    }
250
}
251
252
/// Future returned from `TlsAcceptor::accept` which will resolve
253
/// once the accept handshake has finished.
254
pub struct Accept<IO>(MidHandshake<TlsStream<IO>>);
255
256
impl<IO> Accept<IO> {
257
    #[inline]
258
0
    pub fn into_fallible(self) -> FallibleAccept<IO> {
259
0
        FallibleAccept(self.0)
260
0
    }
261
262
0
    pub fn get_ref(&self) -> Option<&IO> {
263
0
        match &self.0 {
264
0
            MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
265
0
            MidHandshake::SendAlert { io, .. } => Some(io),
266
0
            MidHandshake::Error { io, .. } => Some(io),
267
0
            MidHandshake::End => None,
268
        }
269
0
    }
270
271
0
    pub fn get_mut(&mut self) -> Option<&mut IO> {
272
0
        match &mut self.0 {
273
0
            MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
274
0
            MidHandshake::SendAlert { io, .. } => Some(io),
275
0
            MidHandshake::Error { io, .. } => Some(io),
276
0
            MidHandshake::End => None,
277
        }
278
0
    }
279
}
280
281
impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
282
    type Output = io::Result<TlsStream<IO>>;
283
284
    #[inline]
285
0
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
286
0
        Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
287
0
    }
288
}
289
290
/// Like [Accept], but returns `IO` on failure.
291
pub struct FallibleAccept<IO>(MidHandshake<TlsStream<IO>>);
292
293
impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleAccept<IO> {
294
    type Output = Result<TlsStream<IO>, (io::Error, IO)>;
295
296
    #[inline]
297
0
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
298
0
        Pin::new(&mut self.0).poll(cx)
299
0
    }
300
}
301
302
/// A wrapper around an underlying raw stream which implements the TLS or SSL
303
/// protocol.
304
#[derive(Debug)]
305
pub struct TlsStream<IO> {
306
    pub(crate) io: IO,
307
    pub(crate) session: ServerConnection,
308
    pub(crate) state: TlsState,
309
    pub(crate) need_flush: bool,
310
}
311
312
impl<IO> TlsStream<IO> {
313
    #[inline]
314
0
    pub fn get_ref(&self) -> (&IO, &ServerConnection) {
315
0
        (&self.io, &self.session)
316
0
    }
317
318
    #[inline]
319
0
    pub fn get_mut(&mut self) -> (&mut IO, &mut ServerConnection) {
320
0
        (&mut self.io, &mut self.session)
321
0
    }
322
323
    #[inline]
324
0
    pub fn into_inner(self) -> (IO, ServerConnection) {
325
0
        (self.io, self.session)
326
0
    }
327
}
328
329
impl<IO> IoSession for TlsStream<IO> {
330
    type Io = IO;
331
    type Session = ServerConnection;
332
333
    #[inline]
334
0
    fn skip_handshake(&self) -> bool {
335
0
        false
336
0
    }
337
338
    #[inline]
339
0
    fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session, &mut bool) {
340
0
        (
341
0
            &mut self.state,
342
0
            &mut self.io,
343
0
            &mut self.session,
344
0
            &mut self.need_flush,
345
0
        )
346
0
    }
347
348
    #[inline]
349
0
    fn into_io(self) -> Self::Io {
350
0
        self.io
351
0
    }
352
}
353
354
impl<IO> AsyncRead for TlsStream<IO>
355
where
356
    IO: AsyncRead + AsyncWrite + Unpin,
357
{
358
0
    fn poll_read(
359
0
        mut self: Pin<&mut Self>,
360
0
        cx: &mut Context<'_>,
361
0
        buf: &mut ReadBuf<'_>,
362
0
    ) -> Poll<io::Result<()>> {
363
0
        let data = ready!(self.as_mut().poll_fill_buf(cx))?;
364
0
        let len = data.len().min(buf.remaining());
365
0
        buf.put_slice(&data[..len]);
366
0
        self.consume(len);
367
0
        Poll::Ready(Ok(()))
368
0
    }
369
}
370
371
impl<IO> AsyncBufRead for TlsStream<IO>
372
where
373
    IO: AsyncRead + AsyncWrite + Unpin,
374
{
375
0
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
376
0
        match self.state {
377
            TlsState::Stream | TlsState::WriteShutdown => {
378
0
                let this = self.get_mut();
379
0
                let stream =
380
0
                    Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
381
382
0
                match stream.poll_fill_buf(cx) {
383
0
                    Poll::Ready(Ok(buf)) => {
384
0
                        if buf.is_empty() {
385
0
                            this.state.shutdown_read();
386
0
                        }
387
388
0
                        Poll::Ready(Ok(buf))
389
                    }
390
0
                    Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
391
0
                        this.state.shutdown_read();
392
0
                        Poll::Ready(Err(err))
393
                    }
394
0
                    output => output,
395
                }
396
            }
397
0
            TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(&[])),
398
            #[cfg(feature = "early-data")]
399
            ref s => unreachable!("server TLS can not hit this state: {:?}", s),
400
        }
401
0
    }
402
403
0
    fn consume(mut self: Pin<&mut Self>, amt: usize) {
404
0
        self.session.reader().consume(amt);
405
0
    }
406
}
407
408
impl<IO> AsyncWrite for TlsStream<IO>
409
where
410
    IO: AsyncRead + AsyncWrite + Unpin,
411
{
412
    /// Note: that it does not guarantee the final data to be sent.
413
    /// To be cautious, you must manually call `flush`.
414
0
    fn poll_write(
415
0
        self: Pin<&mut Self>,
416
0
        cx: &mut Context<'_>,
417
0
        buf: &[u8],
418
0
    ) -> Poll<io::Result<usize>> {
419
0
        let this = self.get_mut();
420
0
        let mut stream =
421
0
            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
422
0
        stream.as_mut_pin().poll_write(cx, buf)
423
0
    }
424
425
    /// Note: that it does not guarantee the final data to be sent.
426
    /// To be cautious, you must manually call `flush`.
427
0
    fn poll_write_vectored(
428
0
        self: Pin<&mut Self>,
429
0
        cx: &mut Context<'_>,
430
0
        bufs: &[io::IoSlice<'_>],
431
0
    ) -> Poll<io::Result<usize>> {
432
0
        let this = self.get_mut();
433
0
        let mut stream =
434
0
            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
435
0
        stream.as_mut_pin().poll_write_vectored(cx, bufs)
436
0
    }
437
438
    #[inline]
439
0
    fn is_write_vectored(&self) -> bool {
440
0
        true
441
0
    }
442
443
0
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
444
0
        let this = self.get_mut();
445
0
        let mut stream =
446
0
            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
447
0
        stream.as_mut_pin().poll_flush(cx)
448
0
    }
449
450
0
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
451
0
        if self.state.writeable() {
452
0
            self.session.send_close_notify();
453
0
            self.state.shutdown_write();
454
0
        }
455
456
0
        let this = self.get_mut();
457
0
        let mut stream =
458
0
            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
459
0
        stream.as_mut_pin().poll_shutdown(cx)
460
0
    }
461
}
462
463
#[cfg(unix)]
464
impl<IO> AsRawFd for TlsStream<IO>
465
where
466
    IO: AsRawFd,
467
{
468
0
    fn as_raw_fd(&self) -> RawFd {
469
0
        self.get_ref().0.as_raw_fd()
470
0
    }
471
}
472
473
#[cfg(windows)]
474
impl<IO> AsRawSocket for TlsStream<IO>
475
where
476
    IO: AsRawSocket,
477
{
478
    fn as_raw_socket(&self) -> RawSocket {
479
        self.get_ref().0.as_raw_socket()
480
    }
481
}