Coverage Report

Created: 2025-05-07 06:59

/rust/registry/src/index.crates.io-6f17d22bba15001f/rustls-0.23.26/src/msgs/fragmenter.rs
Line
Count
Source (jump to first uncovered line)
1
use crate::Error;
2
use crate::enums::{ContentType, ProtocolVersion};
3
use crate::msgs::message::{OutboundChunks, OutboundPlainMessage, PlainMessage};
4
pub(crate) const MAX_FRAGMENT_LEN: usize = 16384;
5
pub(crate) const PACKET_OVERHEAD: usize = 1 + 2 + 2;
6
pub(crate) const MAX_FRAGMENT_SIZE: usize = MAX_FRAGMENT_LEN + PACKET_OVERHEAD;
7
8
pub struct MessageFragmenter {
9
    max_frag: usize,
10
}
11
12
impl Default for MessageFragmenter {
13
0
    fn default() -> Self {
14
0
        Self {
15
0
            max_frag: MAX_FRAGMENT_LEN,
16
0
        }
17
0
    }
18
}
19
20
impl MessageFragmenter {
21
    /// Take `msg` and fragment it into new messages with the same type and version.
22
    ///
23
    /// Each returned message size is no more than `max_frag`.
24
    ///
25
    /// Return an iterator across those messages.
26
    ///
27
    /// Payloads are borrowed from `msg`.
28
0
    pub fn fragment_message<'a>(
29
0
        &self,
30
0
        msg: &'a PlainMessage,
31
0
    ) -> impl Iterator<Item = OutboundPlainMessage<'a>> + 'a {
32
0
        self.fragment_payload(msg.typ, msg.version, msg.payload.bytes().into())
33
0
    }
34
35
    /// Take `payload` and fragment it into new messages with given type and version.
36
    ///
37
    /// Each returned message size is no more than `max_frag`.
38
    ///
39
    /// Return an iterator across those messages.
40
    ///
41
    /// Payloads are borrowed from `payload`.
42
0
    pub(crate) fn fragment_payload<'a>(
43
0
        &self,
44
0
        typ: ContentType,
45
0
        version: ProtocolVersion,
46
0
        payload: OutboundChunks<'a>,
47
0
    ) -> impl ExactSizeIterator<Item = OutboundPlainMessage<'a>> {
48
0
        Chunker::new(payload, self.max_frag).map(move |payload| OutboundPlainMessage {
49
0
            typ,
50
0
            version,
51
0
            payload,
52
0
        })
53
0
    }
54
55
    /// Set the maximum fragment size that will be produced.
56
    ///
57
    /// This includes overhead. A `max_fragment_size` of 10 will produce TLS fragments
58
    /// up to 10 bytes long.
59
    ///
60
    /// A `max_fragment_size` of `None` sets the highest allowable fragment size.
61
    ///
62
    /// Returns BadMaxFragmentSize if the size is smaller than 32 or larger than 16389.
63
0
    pub fn set_max_fragment_size(&mut self, max_fragment_size: Option<usize>) -> Result<(), Error> {
64
0
        self.max_frag = match max_fragment_size {
65
0
            Some(sz @ 32..=MAX_FRAGMENT_SIZE) => sz - PACKET_OVERHEAD,
66
0
            None => MAX_FRAGMENT_LEN,
67
0
            _ => return Err(Error::BadMaxFragmentSize),
68
        };
69
0
        Ok(())
70
0
    }
71
}
72
73
/// An iterator over borrowed fragments of a payload
74
struct Chunker<'a> {
75
    payload: OutboundChunks<'a>,
76
    limit: usize,
77
}
78
79
impl<'a> Chunker<'a> {
80
0
    fn new(payload: OutboundChunks<'a>, limit: usize) -> Self {
81
0
        Self { payload, limit }
82
0
    }
83
}
84
85
impl<'a> Iterator for Chunker<'a> {
86
    type Item = OutboundChunks<'a>;
87
88
0
    fn next(&mut self) -> Option<Self::Item> {
89
0
        if self.payload.is_empty() {
90
0
            return None;
91
0
        }
92
0
93
0
        let (before, after) = self.payload.split_at(self.limit);
94
0
        self.payload = after;
95
0
        Some(before)
96
0
    }
97
}
98
99
impl ExactSizeIterator for Chunker<'_> {
100
0
    fn len(&self) -> usize {
101
0
        (self.payload.len() + self.limit - 1) / self.limit
102
0
    }
103
}
104
105
#[cfg(test)]
106
mod tests {
107
    use std::prelude::v1::*;
108
    use std::vec;
109
110
    use super::{MessageFragmenter, PACKET_OVERHEAD};
111
    use crate::enums::{ContentType, ProtocolVersion};
112
    use crate::msgs::base::Payload;
113
    use crate::msgs::message::{OutboundChunks, OutboundPlainMessage, PlainMessage};
114
115
    fn msg_eq(
116
        m: &OutboundPlainMessage<'_>,
117
        total_len: usize,
118
        typ: &ContentType,
119
        version: &ProtocolVersion,
120
        bytes: &[u8],
121
    ) {
122
        assert_eq!(&m.typ, typ);
123
        assert_eq!(&m.version, version);
124
        assert_eq!(m.payload.to_vec(), bytes);
125
126
        let buf = m.to_unencrypted_opaque().encode();
127
128
        assert_eq!(total_len, buf.len());
129
    }
130
131
    #[test]
132
    fn smoke() {
133
        let typ = ContentType::Handshake;
134
        let version = ProtocolVersion::TLSv1_2;
135
        let data: Vec<u8> = (1..70u8).collect();
136
        let m = PlainMessage {
137
            typ,
138
            version,
139
            payload: Payload::new(data),
140
        };
141
142
        let mut frag = MessageFragmenter::default();
143
        frag.set_max_fragment_size(Some(32))
144
            .unwrap();
145
        let q = frag
146
            .fragment_message(&m)
147
            .collect::<Vec<_>>();
148
        assert_eq!(q.len(), 3);
149
        msg_eq(
150
            &q[0],
151
            32,
152
            &typ,
153
            &version,
154
            &[
155
                1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
156
                24, 25, 26, 27,
157
            ],
158
        );
159
        msg_eq(
160
            &q[1],
161
            32,
162
            &typ,
163
            &version,
164
            &[
165
                28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
166
                49, 50, 51, 52, 53, 54,
167
            ],
168
        );
169
        msg_eq(
170
            &q[2],
171
            20,
172
            &typ,
173
            &version,
174
            &[55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69],
175
        );
176
    }
177
178
    #[test]
179
    fn non_fragment() {
180
        let m = PlainMessage {
181
            typ: ContentType::Handshake,
182
            version: ProtocolVersion::TLSv1_2,
183
            payload: Payload::new(b"\x01\x02\x03\x04\x05\x06\x07\x08".to_vec()),
184
        };
185
186
        let mut frag = MessageFragmenter::default();
187
        frag.set_max_fragment_size(Some(32))
188
            .unwrap();
189
        let q = frag
190
            .fragment_message(&m)
191
            .collect::<Vec<_>>();
192
        assert_eq!(q.len(), 1);
193
        msg_eq(
194
            &q[0],
195
            PACKET_OVERHEAD + 8,
196
            &ContentType::Handshake,
197
            &ProtocolVersion::TLSv1_2,
198
            b"\x01\x02\x03\x04\x05\x06\x07\x08",
199
        );
200
    }
201
202
    #[test]
203
    fn fragment_multiple_slices() {
204
        let typ = ContentType::Handshake;
205
        let version = ProtocolVersion::TLSv1_2;
206
        let payload_owner: Vec<&[u8]> = vec![&[b'a'; 8], &[b'b'; 12], &[b'c'; 32], &[b'd'; 20]];
207
        let borrowed_payload = OutboundChunks::new(&payload_owner);
208
        let mut frag = MessageFragmenter::default();
209
        frag.set_max_fragment_size(Some(37)) // 32 + packet overhead
210
            .unwrap();
211
212
        let fragments = frag
213
            .fragment_payload(typ, version, borrowed_payload)
214
            .collect::<Vec<_>>();
215
        assert_eq!(fragments.len(), 3);
216
        msg_eq(
217
            &fragments[0],
218
            37,
219
            &typ,
220
            &version,
221
            b"aaaaaaaabbbbbbbbbbbbcccccccccccc",
222
        );
223
        msg_eq(
224
            &fragments[1],
225
            37,
226
            &typ,
227
            &version,
228
            b"ccccccccccccccccccccdddddddddddd",
229
        );
230
        msg_eq(&fragments[2], 13, &typ, &version, b"dddddddd");
231
    }
232
}