/rust/registry/src/index.crates.io-1949cf8c6b5b557f/tokio-rustls-0.26.4/src/server.rs
Line | Count | Source |
1 | | use std::future::Future; |
2 | | use std::io::{self, BufRead as _}; |
3 | | #[cfg(unix)] |
4 | | use std::os::unix::io::{AsRawFd, RawFd}; |
5 | | #[cfg(windows)] |
6 | | use std::os::windows::io::{AsRawSocket, RawSocket}; |
7 | | use std::pin::Pin; |
8 | | use std::sync::Arc; |
9 | | use std::task::{Context, Poll}; |
10 | | |
11 | | use rustls::server::AcceptedAlert; |
12 | | use rustls::{ServerConfig, ServerConnection}; |
13 | | use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; |
14 | | |
15 | | use crate::common::{IoSession, MidHandshake, Stream, SyncReadAdapter, SyncWriteAdapter, TlsState}; |
16 | | |
17 | | /// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method. |
18 | | #[derive(Clone)] |
19 | | pub struct TlsAcceptor { |
20 | | inner: Arc<ServerConfig>, |
21 | | } |
22 | | |
23 | | impl From<Arc<ServerConfig>> for TlsAcceptor { |
24 | 0 | fn from(inner: Arc<ServerConfig>) -> Self { |
25 | 0 | Self { inner } |
26 | 0 | } |
27 | | } |
28 | | |
29 | | impl TlsAcceptor { |
30 | | #[inline] |
31 | 0 | pub fn accept<IO>(&self, stream: IO) -> Accept<IO> |
32 | 0 | where |
33 | 0 | IO: AsyncRead + AsyncWrite + Unpin, |
34 | | { |
35 | 0 | self.accept_with(stream, |_| ()) |
36 | 0 | } |
37 | | |
38 | 0 | pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO> |
39 | 0 | where |
40 | 0 | IO: AsyncRead + AsyncWrite + Unpin, |
41 | 0 | F: FnOnce(&mut ServerConnection), |
42 | | { |
43 | 0 | let mut session = match ServerConnection::new(self.inner.clone()) { |
44 | 0 | Ok(session) => session, |
45 | 0 | Err(error) => { |
46 | 0 | return Accept(MidHandshake::Error { |
47 | 0 | io: stream, |
48 | 0 | // TODO(eliza): should this really return an `io::Error`? |
49 | 0 | // Probably not... |
50 | 0 | error: io::Error::new(io::ErrorKind::Other, error), |
51 | 0 | }); |
52 | | } |
53 | | }; |
54 | 0 | f(&mut session); |
55 | | |
56 | 0 | Accept(MidHandshake::Handshaking(TlsStream { |
57 | 0 | session, |
58 | 0 | io: stream, |
59 | 0 | state: TlsState::Stream, |
60 | 0 | need_flush: false, |
61 | 0 | })) |
62 | 0 | } |
63 | | |
64 | | /// Get a read-only reference to underlying config |
65 | 0 | pub fn config(&self) -> &Arc<ServerConfig> { |
66 | 0 | &self.inner |
67 | 0 | } |
68 | | } |
69 | | |
70 | | pub struct LazyConfigAcceptor<IO> { |
71 | | acceptor: rustls::server::Acceptor, |
72 | | io: Option<IO>, |
73 | | alert: Option<(rustls::Error, AcceptedAlert)>, |
74 | | } |
75 | | |
76 | | impl<IO> LazyConfigAcceptor<IO> |
77 | | where |
78 | | IO: AsyncRead + AsyncWrite + Unpin, |
79 | | { |
80 | | #[inline] |
81 | 0 | pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self { |
82 | 0 | Self { |
83 | 0 | acceptor, |
84 | 0 | io: Some(io), |
85 | 0 | alert: None, |
86 | 0 | } |
87 | 0 | } |
88 | | |
89 | | /// Takes back the client connection. Will return `None` if called more than once or if the |
90 | | /// connection has been accepted. |
91 | | /// |
92 | | /// # Example |
93 | | /// |
94 | | /// ```no_run |
95 | | /// # fn choose_server_config( |
96 | | /// # _: rustls::server::ClientHello, |
97 | | /// # ) -> std::sync::Arc<rustls::ServerConfig> { |
98 | | /// # unimplemented!(); |
99 | | /// # } |
100 | | /// # #[allow(unused_variables)] |
101 | | /// # async fn listen() { |
102 | | /// use tokio::io::AsyncWriteExt; |
103 | | /// let listener = tokio::net::TcpListener::bind("127.0.0.1:4443").await.unwrap(); |
104 | | /// let (stream, _) = listener.accept().await.unwrap(); |
105 | | /// |
106 | | /// let acceptor = tokio_rustls::LazyConfigAcceptor::new(rustls::server::Acceptor::default(), stream); |
107 | | /// tokio::pin!(acceptor); |
108 | | /// |
109 | | /// match acceptor.as_mut().await { |
110 | | /// Ok(start) => { |
111 | | /// let clientHello = start.client_hello(); |
112 | | /// let config = choose_server_config(clientHello); |
113 | | /// let stream = start.into_stream(config).await.unwrap(); |
114 | | /// // Proceed with handling the ServerConnection... |
115 | | /// } |
116 | | /// Err(err) => { |
117 | | /// if let Some(mut stream) = acceptor.take_io() { |
118 | | /// stream |
119 | | /// .write_all( |
120 | | /// format!("HTTP/1.1 400 Invalid Input\r\n\r\n\r\n{:?}\n", err) |
121 | | /// .as_bytes() |
122 | | /// ) |
123 | | /// .await |
124 | | /// .unwrap(); |
125 | | /// } |
126 | | /// } |
127 | | /// } |
128 | | /// # } |
129 | | /// ``` |
130 | 0 | pub fn take_io(&mut self) -> Option<IO> { |
131 | 0 | self.io.take() |
132 | 0 | } |
133 | | } |
134 | | |
135 | | impl<IO> Future for LazyConfigAcceptor<IO> |
136 | | where |
137 | | IO: AsyncRead + AsyncWrite + Unpin, |
138 | | { |
139 | | type Output = Result<StartHandshake<IO>, io::Error>; |
140 | | |
141 | 0 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
142 | 0 | let this = self.get_mut(); |
143 | | loop { |
144 | 0 | let io = match this.io.as_mut() { |
145 | 0 | Some(io) => io, |
146 | | None => { |
147 | 0 | return Poll::Ready(Err(io::Error::new( |
148 | 0 | io::ErrorKind::Other, |
149 | 0 | "acceptor cannot be polled after acceptance", |
150 | 0 | ))) |
151 | | } |
152 | | }; |
153 | | |
154 | 0 | if let Some((err, mut alert)) = this.alert.take() { |
155 | 0 | match alert.write(&mut SyncWriteAdapter { io, cx }) { |
156 | 0 | Err(e) if e.kind() == io::ErrorKind::WouldBlock => { |
157 | 0 | this.alert = Some((err, alert)); |
158 | 0 | return Poll::Pending; |
159 | | } |
160 | | Ok(0) | Err(_) => { |
161 | 0 | return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, err))) |
162 | | } |
163 | | Ok(_) => { |
164 | 0 | this.alert = Some((err, alert)); |
165 | 0 | continue; |
166 | | } |
167 | | }; |
168 | 0 | } |
169 | | |
170 | 0 | let mut reader = SyncReadAdapter { io, cx }; |
171 | 0 | match this.acceptor.read_tls(&mut reader) { |
172 | 0 | Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()).into(), |
173 | 0 | Ok(_) => {} |
174 | 0 | Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Poll::Pending, |
175 | 0 | Err(e) => return Err(e).into(), |
176 | | } |
177 | | |
178 | 0 | match this.acceptor.accept() { |
179 | 0 | Ok(Some(accepted)) => { |
180 | 0 | let io = this.io.take().unwrap(); |
181 | 0 | return Poll::Ready(Ok(StartHandshake { accepted, io })); |
182 | | } |
183 | 0 | Ok(None) => {} |
184 | 0 | Err((err, alert)) => { |
185 | 0 | this.alert = Some((err, alert)); |
186 | 0 | } |
187 | | } |
188 | | } |
189 | 0 | } |
190 | | } |
191 | | |
192 | | /// An incoming connection received through [`LazyConfigAcceptor`]. |
193 | | /// |
194 | | /// This contains the generic `IO` asynchronous transport, |
195 | | /// [`ClientHello`](rustls::server::ClientHello) data, |
196 | | /// and all the state required to continue the TLS handshake (e.g. via |
197 | | /// [`StartHandshake::into_stream`]). |
198 | | #[non_exhaustive] |
199 | | #[derive(Debug)] |
200 | | pub struct StartHandshake<IO> { |
201 | | pub accepted: rustls::server::Accepted, |
202 | | pub io: IO, |
203 | | } |
204 | | |
205 | | impl<IO> StartHandshake<IO> |
206 | | where |
207 | | IO: AsyncRead + AsyncWrite + Unpin, |
208 | | { |
209 | | /// Create a new object from an `IO` transport and prior TLS metadata. |
210 | 0 | pub fn from_parts(accepted: rustls::server::Accepted, transport: IO) -> Self { |
211 | 0 | Self { |
212 | 0 | accepted, |
213 | 0 | io: transport, |
214 | 0 | } |
215 | 0 | } |
216 | | |
217 | 0 | pub fn client_hello(&self) -> rustls::server::ClientHello<'_> { |
218 | 0 | self.accepted.client_hello() |
219 | 0 | } |
220 | | |
221 | 0 | pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> { |
222 | 0 | self.into_stream_with(config, |_| ()) |
223 | 0 | } |
224 | | |
225 | 0 | pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO> |
226 | 0 | where |
227 | 0 | F: FnOnce(&mut ServerConnection), |
228 | | { |
229 | 0 | let mut conn = match self.accepted.into_connection(config) { |
230 | 0 | Ok(conn) => conn, |
231 | 0 | Err((error, alert)) => { |
232 | 0 | return Accept(MidHandshake::SendAlert { |
233 | 0 | io: self.io, |
234 | 0 | alert, |
235 | 0 | // TODO(eliza): should this really return an `io::Error`? |
236 | 0 | // Probably not... |
237 | 0 | error: io::Error::new(io::ErrorKind::InvalidData, error), |
238 | 0 | }); |
239 | | } |
240 | | }; |
241 | 0 | f(&mut conn); |
242 | | |
243 | 0 | Accept(MidHandshake::Handshaking(TlsStream { |
244 | 0 | session: conn, |
245 | 0 | io: self.io, |
246 | 0 | state: TlsState::Stream, |
247 | 0 | need_flush: false, |
248 | 0 | })) |
249 | 0 | } |
250 | | } |
251 | | |
252 | | /// Future returned from `TlsAcceptor::accept` which will resolve |
253 | | /// once the accept handshake has finished. |
254 | | pub struct Accept<IO>(MidHandshake<TlsStream<IO>>); |
255 | | |
256 | | impl<IO> Accept<IO> { |
257 | | #[inline] |
258 | 0 | pub fn into_fallible(self) -> FallibleAccept<IO> { |
259 | 0 | FallibleAccept(self.0) |
260 | 0 | } |
261 | | |
262 | 0 | pub fn get_ref(&self) -> Option<&IO> { |
263 | 0 | match &self.0 { |
264 | 0 | MidHandshake::Handshaking(sess) => Some(sess.get_ref().0), |
265 | 0 | MidHandshake::SendAlert { io, .. } => Some(io), |
266 | 0 | MidHandshake::Error { io, .. } => Some(io), |
267 | 0 | MidHandshake::End => None, |
268 | | } |
269 | 0 | } |
270 | | |
271 | 0 | pub fn get_mut(&mut self) -> Option<&mut IO> { |
272 | 0 | match &mut self.0 { |
273 | 0 | MidHandshake::Handshaking(sess) => Some(sess.get_mut().0), |
274 | 0 | MidHandshake::SendAlert { io, .. } => Some(io), |
275 | 0 | MidHandshake::Error { io, .. } => Some(io), |
276 | 0 | MidHandshake::End => None, |
277 | | } |
278 | 0 | } |
279 | | } |
280 | | |
281 | | impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> { |
282 | | type Output = io::Result<TlsStream<IO>>; |
283 | | |
284 | | #[inline] |
285 | 0 | fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
286 | 0 | Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err) |
287 | 0 | } |
288 | | } |
289 | | |
290 | | /// Like [Accept], but returns `IO` on failure. |
291 | | pub struct FallibleAccept<IO>(MidHandshake<TlsStream<IO>>); |
292 | | |
293 | | impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleAccept<IO> { |
294 | | type Output = Result<TlsStream<IO>, (io::Error, IO)>; |
295 | | |
296 | | #[inline] |
297 | 0 | fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
298 | 0 | Pin::new(&mut self.0).poll(cx) |
299 | 0 | } |
300 | | } |
301 | | |
302 | | /// A wrapper around an underlying raw stream which implements the TLS or SSL |
303 | | /// protocol. |
304 | | #[derive(Debug)] |
305 | | pub struct TlsStream<IO> { |
306 | | pub(crate) io: IO, |
307 | | pub(crate) session: ServerConnection, |
308 | | pub(crate) state: TlsState, |
309 | | pub(crate) need_flush: bool, |
310 | | } |
311 | | |
312 | | impl<IO> TlsStream<IO> { |
313 | | #[inline] |
314 | 0 | pub fn get_ref(&self) -> (&IO, &ServerConnection) { |
315 | 0 | (&self.io, &self.session) |
316 | 0 | } |
317 | | |
318 | | #[inline] |
319 | 0 | pub fn get_mut(&mut self) -> (&mut IO, &mut ServerConnection) { |
320 | 0 | (&mut self.io, &mut self.session) |
321 | 0 | } |
322 | | |
323 | | #[inline] |
324 | 0 | pub fn into_inner(self) -> (IO, ServerConnection) { |
325 | 0 | (self.io, self.session) |
326 | 0 | } |
327 | | } |
328 | | |
329 | | impl<IO> IoSession for TlsStream<IO> { |
330 | | type Io = IO; |
331 | | type Session = ServerConnection; |
332 | | |
333 | | #[inline] |
334 | 0 | fn skip_handshake(&self) -> bool { |
335 | 0 | false |
336 | 0 | } |
337 | | |
338 | | #[inline] |
339 | 0 | fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session, &mut bool) { |
340 | 0 | ( |
341 | 0 | &mut self.state, |
342 | 0 | &mut self.io, |
343 | 0 | &mut self.session, |
344 | 0 | &mut self.need_flush, |
345 | 0 | ) |
346 | 0 | } |
347 | | |
348 | | #[inline] |
349 | 0 | fn into_io(self) -> Self::Io { |
350 | 0 | self.io |
351 | 0 | } |
352 | | } |
353 | | |
354 | | impl<IO> AsyncRead for TlsStream<IO> |
355 | | where |
356 | | IO: AsyncRead + AsyncWrite + Unpin, |
357 | | { |
358 | 0 | fn poll_read( |
359 | 0 | mut self: Pin<&mut Self>, |
360 | 0 | cx: &mut Context<'_>, |
361 | 0 | buf: &mut ReadBuf<'_>, |
362 | 0 | ) -> Poll<io::Result<()>> { |
363 | 0 | let data = ready!(self.as_mut().poll_fill_buf(cx))?; |
364 | 0 | let len = data.len().min(buf.remaining()); |
365 | 0 | buf.put_slice(&data[..len]); |
366 | 0 | self.consume(len); |
367 | 0 | Poll::Ready(Ok(())) |
368 | 0 | } |
369 | | } |
370 | | |
371 | | impl<IO> AsyncBufRead for TlsStream<IO> |
372 | | where |
373 | | IO: AsyncRead + AsyncWrite + Unpin, |
374 | | { |
375 | 0 | fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { |
376 | 0 | match self.state { |
377 | | TlsState::Stream | TlsState::WriteShutdown => { |
378 | 0 | let this = self.get_mut(); |
379 | 0 | let stream = |
380 | 0 | Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); |
381 | | |
382 | 0 | match stream.poll_fill_buf(cx) { |
383 | 0 | Poll::Ready(Ok(buf)) => { |
384 | 0 | if buf.is_empty() { |
385 | 0 | this.state.shutdown_read(); |
386 | 0 | } |
387 | | |
388 | 0 | Poll::Ready(Ok(buf)) |
389 | | } |
390 | 0 | Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => { |
391 | 0 | this.state.shutdown_read(); |
392 | 0 | Poll::Ready(Err(err)) |
393 | | } |
394 | 0 | output => output, |
395 | | } |
396 | | } |
397 | 0 | TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(&[])), |
398 | | #[cfg(feature = "early-data")] |
399 | | ref s => unreachable!("server TLS can not hit this state: {:?}", s), |
400 | | } |
401 | 0 | } |
402 | | |
403 | 0 | fn consume(mut self: Pin<&mut Self>, amt: usize) { |
404 | 0 | self.session.reader().consume(amt); |
405 | 0 | } |
406 | | } |
407 | | |
408 | | impl<IO> AsyncWrite for TlsStream<IO> |
409 | | where |
410 | | IO: AsyncRead + AsyncWrite + Unpin, |
411 | | { |
412 | | /// Note: that it does not guarantee the final data to be sent. |
413 | | /// To be cautious, you must manually call `flush`. |
414 | 0 | fn poll_write( |
415 | 0 | self: Pin<&mut Self>, |
416 | 0 | cx: &mut Context<'_>, |
417 | 0 | buf: &[u8], |
418 | 0 | ) -> Poll<io::Result<usize>> { |
419 | 0 | let this = self.get_mut(); |
420 | 0 | let mut stream = |
421 | 0 | Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); |
422 | 0 | stream.as_mut_pin().poll_write(cx, buf) |
423 | 0 | } |
424 | | |
425 | | /// Note: that it does not guarantee the final data to be sent. |
426 | | /// To be cautious, you must manually call `flush`. |
427 | 0 | fn poll_write_vectored( |
428 | 0 | self: Pin<&mut Self>, |
429 | 0 | cx: &mut Context<'_>, |
430 | 0 | bufs: &[io::IoSlice<'_>], |
431 | 0 | ) -> Poll<io::Result<usize>> { |
432 | 0 | let this = self.get_mut(); |
433 | 0 | let mut stream = |
434 | 0 | Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); |
435 | 0 | stream.as_mut_pin().poll_write_vectored(cx, bufs) |
436 | 0 | } |
437 | | |
438 | | #[inline] |
439 | 0 | fn is_write_vectored(&self) -> bool { |
440 | 0 | true |
441 | 0 | } |
442 | | |
443 | 0 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
444 | 0 | let this = self.get_mut(); |
445 | 0 | let mut stream = |
446 | 0 | Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); |
447 | 0 | stream.as_mut_pin().poll_flush(cx) |
448 | 0 | } |
449 | | |
450 | 0 | fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
451 | 0 | if self.state.writeable() { |
452 | 0 | self.session.send_close_notify(); |
453 | 0 | self.state.shutdown_write(); |
454 | 0 | } |
455 | | |
456 | 0 | let this = self.get_mut(); |
457 | 0 | let mut stream = |
458 | 0 | Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); |
459 | 0 | stream.as_mut_pin().poll_shutdown(cx) |
460 | 0 | } |
461 | | } |
462 | | |
463 | | #[cfg(unix)] |
464 | | impl<IO> AsRawFd for TlsStream<IO> |
465 | | where |
466 | | IO: AsRawFd, |
467 | | { |
468 | 0 | fn as_raw_fd(&self) -> RawFd { |
469 | 0 | self.get_ref().0.as_raw_fd() |
470 | 0 | } |
471 | | } |
472 | | |
473 | | #[cfg(windows)] |
474 | | impl<IO> AsRawSocket for TlsStream<IO> |
475 | | where |
476 | | IO: AsRawSocket, |
477 | | { |
478 | | fn as_raw_socket(&self) -> RawSocket { |
479 | | self.get_ref().0.as_raw_socket() |
480 | | } |
481 | | } |