Coverage Report

Created: 2025-07-18 06:13

/src/tungstenite-rs/src/handshake/mod.rs
Line
Count
Source (jump to first uncovered line)
1
//! WebSocket handshake control.
2
3
pub mod client;
4
pub mod headers;
5
pub mod machine;
6
pub mod server;
7
8
use std::{
9
    error::Error as ErrorTrait,
10
    fmt,
11
    io::{Read, Write},
12
};
13
14
use sha1::{Digest, Sha1};
15
16
use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse};
17
use crate::error::Error;
18
19
/// A WebSocket handshake.
20
#[derive(Debug)]
21
pub struct MidHandshake<Role: HandshakeRole> {
22
    role: Role,
23
    machine: HandshakeMachine<Role::InternalStream>,
24
}
25
26
impl<Role: HandshakeRole> MidHandshake<Role> {
27
    /// Allow access to machine
28
0
    pub fn get_ref(&self) -> &HandshakeMachine<Role::InternalStream> {
29
0
        &self.machine
30
0
    }
31
32
    /// Allow mutable access to machine
33
0
    pub fn get_mut(&mut self) -> &mut HandshakeMachine<Role::InternalStream> {
34
0
        &mut self.machine
35
0
    }
36
37
    /// Restarts the handshake process.
38
0
    pub fn handshake(mut self) -> Result<Role::FinalResult, HandshakeError<Role>> {
39
0
        let mut mach = self.machine;
40
        loop {
41
0
            mach = match mach.single_round()? {
42
0
                RoundResult::WouldBlock(m) => {
43
0
                    return Err(HandshakeError::Interrupted(MidHandshake { machine: m, ..self }))
44
                }
45
0
                RoundResult::Incomplete(m) => m,
46
0
                RoundResult::StageFinished(s) => match self.role.stage_finished(s)? {
47
0
                    ProcessingResult::Continue(m) => m,
48
0
                    ProcessingResult::Done(result) => return Ok(result),
49
                },
50
            }
51
        }
52
0
    }
53
}
54
55
/// A handshake result.
56
pub enum HandshakeError<Role: HandshakeRole> {
57
    /// Handshake was interrupted (would block).
58
    Interrupted(MidHandshake<Role>),
59
    /// Handshake failed.
60
    Failure(Error),
61
}
62
63
impl<Role: HandshakeRole> fmt::Debug for HandshakeError<Role> {
64
0
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
65
0
        match *self {
66
0
            HandshakeError::Interrupted(_) => write!(f, "HandshakeError::Interrupted(...)"),
67
0
            HandshakeError::Failure(ref e) => write!(f, "HandshakeError::Failure({e:?})"),
68
        }
69
0
    }
70
}
71
72
impl<Role: HandshakeRole> fmt::Display for HandshakeError<Role> {
73
0
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
74
0
        match *self {
75
0
            HandshakeError::Interrupted(_) => write!(f, "Interrupted handshake (WouldBlock)"),
76
0
            HandshakeError::Failure(ref e) => write!(f, "{e}"),
77
        }
78
0
    }
79
}
80
81
impl<Role: HandshakeRole> ErrorTrait for HandshakeError<Role> {}
82
83
impl<Role: HandshakeRole> From<Error> for HandshakeError<Role> {
84
0
    fn from(err: Error) -> Self {
85
0
        HandshakeError::Failure(err)
86
0
    }
87
}
88
89
/// Handshake role.
90
pub trait HandshakeRole {
91
    #[doc(hidden)]
92
    type IncomingData: TryParse;
93
    #[doc(hidden)]
94
    type InternalStream: Read + Write;
95
    #[doc(hidden)]
96
    type FinalResult;
97
    #[doc(hidden)]
98
    fn stage_finished(
99
        &mut self,
100
        finish: StageResult<Self::IncomingData, Self::InternalStream>,
101
    ) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>, Error>;
102
}
103
104
/// Stage processing result.
105
#[doc(hidden)]
106
#[derive(Debug)]
107
pub enum ProcessingResult<Stream, FinalResult> {
108
    Continue(HandshakeMachine<Stream>),
109
    Done(FinalResult),
110
}
111
112
/// Derive the `Sec-WebSocket-Accept` response header from a `Sec-WebSocket-Key` request header.
113
///
114
/// This function can be used to perform a handshake before passing a raw TCP stream to
115
/// [`WebSocket::from_raw_socket`][crate::protocol::WebSocket::from_raw_socket].
116
0
pub fn derive_accept_key(request_key: &[u8]) -> String {
117
    // ... field is constructed by concatenating /key/ ...
118
    // ... with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455)
119
    const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
120
0
    let mut sha1 = Sha1::default();
121
0
    sha1.update(request_key);
122
0
    sha1.update(WS_GUID);
123
0
    data_encoding::BASE64.encode(&sha1.finalize())
124
0
}
125
126
#[cfg(test)]
mod tests {
    use super::derive_accept_key;

    /// Checks the worked example from RFC 6455's opening-handshake description.
    #[test]
    fn key_conversion() {
        let client_key = b"dGhlIHNhbXBsZSBub25jZQ==";
        let expected_accept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";
        assert_eq!(derive_accept_key(client_key), expected_accept);
    }

    /// The accept value is always the padded Base64 form of a 20-byte SHA-1 digest.
    ///
    /// That form is 28 characters ending in exactly one `=`, whatever the input key
    /// (including empty and very short keys).
    #[test]
    fn accept_key_shape() {
        let keys: [&[u8]; 3] = [b"", b"x", b"dGhlIHNhbXBsZSBub25jZQ=="];
        for key in keys {
            let accept = derive_accept_key(key);
            assert_eq!(accept.len(), 28, "unexpected length for key {key:?}");
            assert!(accept.ends_with('='), "missing padding for key {key:?}");
            assert!(!accept.ends_with("=="), "over-padded output for key {key:?}");
        }
        // Distinct keys must not collapse to the same accept value.
        assert_ne!(derive_accept_key(b""), derive_accept_key(b"x"));
    }
}