Coverage Report

Created: 2026-04-01 06:56

/src/regex/regex-automata/src/util/wire.rs
Line
Count
Source
1
/*!
2
Types and routines that support the wire format of finite automata.
3
4
Currently, this module just exports a few error types and some small helpers
5
for deserializing [dense DFAs](crate::dfa::dense::DFA) using correct alignment.
6
*/
7
8
/*
9
A collection of helper functions, types and traits for serializing automata.
10
11
This crate defines its own bespoke serialization mechanism for some structures
12
provided in the public API, namely, DFAs. A bespoke mechanism was developed
13
primarily because structures like automata demand a specific binary format.
14
Attempting to encode their rich structure in an existing serialization
15
format is just not feasible. Moreover, the format for each structure is
16
generally designed such that deserialization is cheap. More specifically,
17
deserialization can be done in constant time. (The idea is that you can
18
embed it into your binary or mmap it, and then use it immediately.)
19
20
In order to achieve this, the dense and sparse DFAs in this crate use an
21
in-memory representation that very closely corresponds to their binary serialized
22
form. This pervades and complicates everything, and in some cases, requires
23
dealing with alignment and reasoning about safety.
24
25
This technique does have major advantages. In particular, it permits doing
26
the potentially costly work of compiling a finite state machine in an offline
27
manner, and then loading it at runtime not only without having to re-compile
28
the regex, but even without the code required to do the compilation. This, for
29
example, permits one to use a pre-compiled DFA not only in environments without
30
Rust's standard library, but also in environments without a heap.
31
32
In the code below, whenever we insert some kind of padding, it's to enforce a
33
4-byte alignment, unless otherwise noted. This is because u32 is the only state ID type
34
supported. (In a previous version of this library, DFAs were generic over the
35
state ID representation.)
36
37
Also, serialization generally requires the caller to specify endianness,
38
whereas deserialization always assumes native endianness (otherwise cheap
39
deserialization would be impossible). This implies that serializing a structure
40
generally requires serializing both its big-endian and little-endian variants,
41
and then loading the correct one based on the target's endianness.
42
*/
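A minimal sketch of the "serialize both byte orders, load the native one" approach described above. The helpers `serialize_u32`, `pick_native` and `load_u32` are hypothetical stand-ins for the crate's real serialization APIs; the selection itself relies only on the standard `cfg!(target_endian = "...")` macro.

```rust
/// Hypothetical stand-in for a serializer that takes the caller's choice of
/// byte order.
fn serialize_u32(n: u32, big_endian: bool) -> Vec<u8> {
    if big_endian {
        n.to_be_bytes().to_vec()
    } else {
        n.to_le_bytes().to_vec()
    }
}

/// Picks whichever serialized variant matches the byte order of the target
/// this code was compiled for.
fn pick_native<'a>(be: &'a [u8], le: &'a [u8]) -> &'a [u8] {
    if cfg!(target_endian = "big") {
        be
    } else {
        le
    }
}

/// "Deserializing" the chosen variant is then a plain native-endian read.
fn load_u32(bytes: &[u8]) -> u32 {
    let mut buf = [0u8; 4];
    buf.copy_from_slice(&bytes[..4]);
    u32::from_ne_bytes(buf)
}
```

Because the chosen variant already matches the target's byte order, loading it needs no per-value byte swapping, which is what keeps deserialization cheap.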
43
44
use core::{cmp, mem::size_of};
45
46
#[cfg(feature = "alloc")]
47
use alloc::{vec, vec::Vec};
48
49
use crate::util::{
50
    int::Pointer,
51
    primitives::{PatternID, PatternIDError, StateID, StateIDError},
52
};
53
54
/// A hack to align a smaller type `B` with a bigger type `T`.
55
///
56
/// The usual use of this is with `B = [u8]` and `T = u32`. That is,
57
/// it permits aligning a sequence of bytes on a 4-byte boundary. This
58
/// is useful in contexts where one wants to embed a serialized [dense
59
/// DFA](crate::dfa::dense::DFA) into a Rust program while guaranteeing the
60
/// alignment required for the DFA.
61
///
62
/// See [`dense::DFA::from_bytes`](crate::dfa::dense::DFA::from_bytes) for an
63
/// example of how to use this type.
64
#[repr(C)]
65
#[derive(Debug)]
66
pub struct AlignAs<B: ?Sized, T> {
67
    /// A zero-sized field indicating the alignment we want.
68
    pub _align: [T; 0],
69
    /// A possibly non-sized field containing a sequence of bytes.
70
    pub bytes: B,
71
}
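A sketch of embedding bytes at a guaranteed 4-byte boundary with `AlignAs`. The struct is copied locally so the snippet compiles on its own, and the eight placeholder bytes stand in for a serialized DFA, which in practice would more likely come from `*include_bytes!(...)`.

```rust
// Local copy of the `AlignAs` type above, so this snippet is self-contained.
#[repr(C)]
struct AlignAs<B: ?Sized, T> {
    _align: [T; 0],
    bytes: B,
}

// Placeholder bytes; a real program would embed a serialized DFA here.
static ALIGNED: &AlignAs<[u8], u32> = &AlignAs {
    _align: [],
    bytes: [1u8, 2, 3, 4, 5, 6, 7, 8],
};

/// Returns the embedded bytes as an ordinary slice.
fn aligned_bytes() -> &'static [u8] {
    &ALIGNED.bytes
}
```

The zero-length `[u32; 0]` field takes up no space but raises the struct's alignment to that of `u32`, and since `#[repr(C)]` places `bytes` at offset 0, the byte slice inherits that alignment.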
72
73
/// An error that occurs when serializing an object from this crate.
74
///
75
/// Serialization, as used in this crate, universally refers to the process
76
/// of transforming a structure (like a DFA) into a custom binary format
77
/// represented by `&[u8]`. To this end, serialization is generally infallible.
78
/// However, it can fail when caller-provided buffer sizes are too small. When
79
/// that occurs, a serialization error is reported.
80
///
81
/// A `SerializeError` provides no introspection capabilities. Its only
82
/// supported operation is conversion to a human-readable error message.
83
///
84
/// This error type implements the `std::error::Error` trait only when the
85
/// `std` feature is enabled. Otherwise, this type is defined in all
86
/// configurations.
87
#[derive(Debug)]
88
pub struct SerializeError {
89
    /// The name of the thing that a buffer is too small for.
90
    ///
91
    /// Currently, the only kind of serialization error is one that is
92
    /// committed by a caller: providing a destination buffer that is too
93
    /// small to fit the serialized object. This makes sense conceptually,
94
    /// since every valid inhabitant of a type should be serializable.
95
    ///
96
    /// This is somewhat exposed in the public API of this crate. For example,
97
    /// the `to_bytes_{big,little}_endian` APIs return a `Vec<u8>` and are
98
    /// guaranteed to never panic or error. This is only possible because the
99
    /// implementation guarantees that it will allocate a `Vec<u8>` that is
100
    /// big enough.
101
    ///
102
    /// In summary, if a new serialization error kind needs to be added, then
103
    /// it will need careful consideration.
104
    what: &'static str,
105
}
106
107
impl SerializeError {
108
0
    pub(crate) fn buffer_too_small(what: &'static str) -> SerializeError {
109
0
        SerializeError { what }
110
0
    }
111
}
112
113
impl core::fmt::Display for SerializeError {
114
0
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
115
0
        write!(f, "destination buffer is too small to write {}", self.what)
116
0
    }
117
}
118
119
#[cfg(feature = "std")]
120
impl std::error::Error for SerializeError {}
121
122
/// An error that occurs when deserializing an object defined in this crate.
123
///
124
/// Serialization, as used in this crate, universally refers to the process
125
/// of transforming a structure (like a DFA) into a custom binary format
126
/// represented by `&[u8]`. Deserialization, then, refers to the process of
127
/// cheaply converting this binary format back to the object's in-memory
128
/// representation as defined in this crate. To the extent possible,
129
/// deserialization will report this error whenever this process fails.
130
///
131
/// A `DeserializeError` provides no introspection capabilities. Its only
132
/// supported operation is conversion to a human-readable error message.
133
///
134
/// This error type implements the `std::error::Error` trait only when the
135
/// `std` feature is enabled. Otherwise, this type is defined in all
136
/// configurations.
137
#[derive(Debug)]
138
pub struct DeserializeError(DeserializeErrorKind);
139
140
#[derive(Debug)]
141
enum DeserializeErrorKind {
142
    Generic { msg: &'static str },
143
    BufferTooSmall { what: &'static str },
144
    InvalidUsize { what: &'static str },
145
    VersionMismatch { expected: u32, found: u32 },
146
    EndianMismatch { expected: u32, found: u32 },
147
    AlignmentMismatch { alignment: usize, address: usize },
148
    LabelMismatch { expected: &'static str },
149
    ArithmeticOverflow { what: &'static str },
150
    PatternID { err: PatternIDError, what: &'static str },
151
    StateID { err: StateIDError, what: &'static str },
152
}
153
154
impl DeserializeError {
155
1.86k
    pub(crate) fn generic(msg: &'static str) -> DeserializeError {
156
1.86k
        DeserializeError(DeserializeErrorKind::Generic { msg })
157
1.86k
    }
158
159
1.07k
    pub(crate) fn buffer_too_small(what: &'static str) -> DeserializeError {
160
1.07k
        DeserializeError(DeserializeErrorKind::BufferTooSmall { what })
161
1.07k
    }
162
163
0
    fn invalid_usize(what: &'static str) -> DeserializeError {
164
0
        DeserializeError(DeserializeErrorKind::InvalidUsize { what })
165
0
    }
166
167
0
    fn version_mismatch(expected: u32, found: u32) -> DeserializeError {
168
0
        DeserializeError(DeserializeErrorKind::VersionMismatch {
169
0
            expected,
170
0
            found,
171
0
        })
172
0
    }
173
174
0
    fn endian_mismatch(expected: u32, found: u32) -> DeserializeError {
175
0
        DeserializeError(DeserializeErrorKind::EndianMismatch {
176
0
            expected,
177
0
            found,
178
0
        })
179
0
    }
180
181
0
    fn alignment_mismatch(
182
0
        alignment: usize,
183
0
        address: usize,
184
0
    ) -> DeserializeError {
185
0
        DeserializeError(DeserializeErrorKind::AlignmentMismatch {
186
0
            alignment,
187
0
            address,
188
0
        })
189
0
    }
190
191
0
    fn label_mismatch(expected: &'static str) -> DeserializeError {
192
0
        DeserializeError(DeserializeErrorKind::LabelMismatch { expected })
193
0
    }
194
195
0
    fn arithmetic_overflow(what: &'static str) -> DeserializeError {
196
0
        DeserializeError(DeserializeErrorKind::ArithmeticOverflow { what })
197
0
    }
198
199
6
    fn pattern_id_error(
200
6
        err: PatternIDError,
201
6
        what: &'static str,
202
6
    ) -> DeserializeError {
203
6
        DeserializeError(DeserializeErrorKind::PatternID { err, what })
204
6
    }
205
206
207
    pub(crate) fn state_id_error(
207
207
        err: StateIDError,
208
207
        what: &'static str,
209
207
    ) -> DeserializeError {
210
207
        DeserializeError(DeserializeErrorKind::StateID { err, what })
211
207
    }
212
}
213
214
#[cfg(feature = "std")]
215
impl std::error::Error for DeserializeError {}
216
217
impl core::fmt::Display for DeserializeError {
218
0
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
219
        use self::DeserializeErrorKind::*;
220
221
0
        match self.0 {
222
0
            Generic { msg } => write!(f, "{msg}"),
223
0
            BufferTooSmall { what } => {
224
0
                write!(f, "buffer is too small to read {what}")
225
            }
226
0
            InvalidUsize { what } => {
227
0
                write!(f, "{what} is too big to fit in a usize")
228
            }
229
0
            VersionMismatch { expected, found } => write!(
230
0
                f,
231
0
                "unsupported version: \
232
0
                 expected version {expected} but found version {found}",
233
            ),
234
0
            EndianMismatch { expected, found } => write!(
235
0
                f,
236
0
                "endianness mismatch: expected 0x{expected:X} but \
237
0
                 got 0x{found:X}. (Are you trying to load an object \
238
0
                 serialized with a different endianness?)",
239
            ),
240
0
            AlignmentMismatch { alignment, address } => write!(
241
0
                f,
242
0
                "alignment mismatch: slice starts at address 0x{address:X}, \
243
0
                 which is not aligned to a {alignment} byte boundary",
244
            ),
245
0
            LabelMismatch { expected } => write!(
246
0
                f,
247
0
                "label mismatch: start of serialized object should \
248
0
                 contain a NUL terminated {expected:?} label, but a different \
249
0
                 label was found",
250
            ),
251
0
            ArithmeticOverflow { what } => {
252
0
                write!(f, "arithmetic overflow for {what}")
253
            }
254
0
            PatternID { ref err, what } => {
255
0
                write!(f, "failed to read pattern ID for {what}: {err}")
256
            }
257
0
            StateID { ref err, what } => {
258
0
                write!(f, "failed to read state ID for {what}: {err}")
259
            }
260
        }
261
0
    }
262
}
263
264
/// Safely converts a `&[u32]` to `&[StateID]` with zero cost.
265
#[cfg_attr(feature = "perf-inline", inline(always))]
266
95.2M
pub(crate) fn u32s_to_state_ids(slice: &[u32]) -> &[StateID] {
267
    // SAFETY: This is safe because StateID is defined to have the same memory
268
    // representation as a u32 (it is repr(transparent)). While not every u32
269
    // is a "valid" StateID, callers are not permitted to rely on the validity
270
    // of StateIDs for memory safety. It can only lead to logical errors. (This
271
    // is why StateID::new_unchecked is safe.)
272
    unsafe {
273
95.2M
        core::slice::from_raw_parts(
274
95.2M
            slice.as_ptr().cast::<StateID>(),
275
95.2M
            slice.len(),
276
95.2M
        )
277
    }
278
95.2M
}
279
280
/// Safely converts a `&mut [u32]` to `&mut [StateID]` with zero cost.
281
1.43M
pub(crate) fn u32s_to_state_ids_mut(slice: &mut [u32]) -> &mut [StateID] {
282
    // SAFETY: This is safe because StateID is defined to have the same memory
283
    // representation as a u32 (it is repr(transparent)). While not every u32
284
    // is a "valid" StateID, callers are not permitted to rely on the validity
285
    // of StateIDs for memory safety. It can only lead to logical errors. (This
286
    // is why StateID::new_unchecked is safe.)
287
1.43M
    unsafe {
288
1.43M
        core::slice::from_raw_parts_mut(
289
1.43M
            slice.as_mut_ptr().cast::<StateID>(),
290
1.43M
            slice.len(),
291
1.43M
        )
292
1.43M
    }
293
1.43M
}
294
295
/// Safely converts a `&[u32]` to `&[PatternID]` with zero cost.
296
#[cfg_attr(feature = "perf-inline", inline(always))]
297
3.01M
pub(crate) fn u32s_to_pattern_ids(slice: &[u32]) -> &[PatternID] {
298
    // SAFETY: This is safe because PatternID is defined to have the same
299
    // memory representation as a u32 (it is repr(transparent)). While not
300
    // every u32 is a "valid" PatternID, callers are not permitted to rely
301
    // on the validity of PatternIDs for memory safety. It can only lead to
302
    // logical errors. (This is why PatternID::new_unchecked is safe.)
303
    unsafe {
304
3.01M
        core::slice::from_raw_parts(
305
3.01M
            slice.as_ptr().cast::<PatternID>(),
306
3.01M
            slice.len(),
307
3.01M
        )
308
    }
309
3.01M
}
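The three conversion functions above rely on the same `repr(transparent)` reinterpretation. A standalone sketch of that technique, using a hypothetical newtype `Id` in place of `StateID` and `PatternID`:

```rust
/// Hypothetical stand-in for `StateID`/`PatternID`: a transparent newtype
/// over `u32`.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct Id(u32);

/// Reinterprets `&[u32]` as `&[Id]` without copying.
fn u32s_to_ids(slice: &[u32]) -> &[Id] {
    // SAFETY: `Id` is `repr(transparent)` over `u32`, so both types have the
    // same size and alignment, and every `u32` bit pattern is a valid `Id`.
    // The element count and the borrow's lifetime are carried over unchanged.
    unsafe { core::slice::from_raw_parts(slice.as_ptr().cast::<Id>(), slice.len()) }
}
```

The returned slice aliases the input's memory, so the conversion costs nothing beyond building a new slice reference.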
310
311
/// Checks that the given slice has an alignment that matches `T`.
312
///
313
/// This is useful for checking that a slice has an appropriate alignment
314
/// before casting it to a `&[T]`. Note, though, that alignment alone is not
315
/// sufficient to make the cast valid for an arbitrary `T`.
316
17.5k
pub(crate) fn check_alignment<T>(
317
17.5k
    slice: &[u8],
318
17.5k
) -> Result<(), DeserializeError> {
319
17.5k
    let alignment = core::mem::align_of::<T>();
320
17.5k
    let address = slice.as_ptr().as_usize();
321
17.5k
    if address % alignment == 0 {
322
17.5k
        return Ok(());
323
0
    }
324
0
    Err(DeserializeError::alignment_mismatch(alignment, address))
325
17.5k
}
regex_automata::util::wire::check_alignment::<regex_automata::util::primitives::StateID>
Line
Count
Source
316
9.74k
pub(crate) fn check_alignment<T>(
317
9.74k
    slice: &[u8],
318
9.74k
) -> Result<(), DeserializeError> {
319
9.74k
    let alignment = core::mem::align_of::<T>();
320
9.74k
    let address = slice.as_ptr().as_usize();
321
9.74k
    if address % alignment == 0 {
322
9.74k
        return Ok(());
323
0
    }
324
0
    Err(DeserializeError::alignment_mismatch(alignment, address))
325
9.74k
}
regex_automata::util::wire::check_alignment::<regex_automata::util::primitives::PatternID>
Line
Count
Source
316
5.64k
pub(crate) fn check_alignment<T>(
317
5.64k
    slice: &[u8],
318
5.64k
) -> Result<(), DeserializeError> {
319
5.64k
    let alignment = core::mem::align_of::<T>();
320
5.64k
    let address = slice.as_ptr().as_usize();
321
5.64k
    if address % alignment == 0 {
322
5.64k
        return Ok(());
323
0
    }
324
0
    Err(DeserializeError::alignment_mismatch(alignment, address))
325
5.64k
}
regex_automata::util::wire::check_alignment::<u32>
Line
Count
Source
316
2.17k
pub(crate) fn check_alignment<T>(
317
2.17k
    slice: &[u8],
318
2.17k
) -> Result<(), DeserializeError> {
319
2.17k
    let alignment = core::mem::align_of::<T>();
320
2.17k
    let address = slice.as_ptr().as_usize();
321
2.17k
    if address % alignment == 0 {
322
2.17k
        return Ok(());
323
0
    }
324
0
    Err(DeserializeError::alignment_mismatch(alignment, address))
325
2.17k
}
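To see the alignment check in action, the sketch below reimplements it with a plain `(alignment, address)` tuple as the error type (an assumption made so the snippet is self-contained) and shows that shifting a `u32`-aligned buffer by one byte breaks 4-byte alignment while leaving 1-byte alignment intact.

```rust
use core::mem::align_of;

/// Mirrors `check_alignment` above, but returns `(alignment, address)` on
/// failure instead of a `DeserializeError`.
fn check_alignment<T>(slice: &[u8]) -> Result<(), (usize, usize)> {
    let alignment = align_of::<T>();
    let address = slice.as_ptr() as usize;
    if address % alignment == 0 {
        Ok(())
    } else {
        Err((alignment, address))
    }
}

/// Views a `[u32]` as bytes, so the result starts on a `u32` boundary.
fn as_bytes(words: &[u32]) -> &[u8] {
    // SAFETY: initialized u32s can always be read as initialized bytes, and
    // u8 has alignment 1, so this view is valid for the same lifetime.
    unsafe {
        core::slice::from_raw_parts(words.as_ptr().cast::<u8>(), core::mem::size_of_val(words))
    }
}
```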
326
327
/// Reads a possibly empty amount of padding, up to 7 bytes, from the beginning
328
/// of the given slice. All padding bytes must be NUL bytes.
329
///
330
/// This is useful because it can be theoretically necessary to pad the
331
/// beginning of a serialized object with NUL bytes to ensure that it starts
332
/// at a correctly aligned address. These padding bytes should come immediately
333
/// before the label.
334
///
335
/// This returns the number of bytes read from the given slice.
336
3.56k
pub(crate) fn skip_initial_padding(slice: &[u8]) -> usize {
337
3.56k
    let mut nread = 0;
338
3.56k
    while nread < 7 && nread < slice.len() && slice[nread] == 0 {
339
0
        nread += 1;
340
0
    }
341
3.56k
    nread
342
3.56k
}
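A usage sketch for the padding skipper. The function body is copied from above so the snippet stands alone; `after_padding` is a hypothetical wrapper that returns the part of the input where a label would be expected to begin.

```rust
/// Copied from `skip_initial_padding` above: skips at most 7 leading NUL
/// bytes and returns how many were skipped.
fn skip_initial_padding(slice: &[u8]) -> usize {
    let mut nread = 0;
    while nread < 7 && nread < slice.len() && slice[nread] == 0 {
        nread += 1;
    }
    nread
}

/// Hypothetical helper: the remainder of `slice` after any leading padding.
fn after_padding(slice: &[u8]) -> &[u8] {
    &slice[skip_initial_padding(slice)..]
}
```

Because of the 7-byte cap, a longer run of NUL bytes is only partially consumed, which is enough for any alignment up to 8.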
343
344
/// Allocate a byte buffer of the given size, along with some initial padding
345
/// such that `buf[padding..]` has the same alignment as `T`, where the
346
/// alignment of `T` must be at most `8`. In particular, callers should treat
347
/// the first N bytes (second return value) as padding bytes that must not be
348
/// overwritten. In all cases, the following identity holds:
349
///
350
/// ```ignore
351
/// let (buf, padding) = alloc_aligned_buffer::<StateID>(SIZE);
352
/// assert_eq!(SIZE, buf[padding..].len());
353
/// ```
354
///
355
/// In practice, padding is often zero.
356
///
357
/// The requirement for `8` as a maximum here is somewhat arbitrary. In
358
/// practice, we never need anything bigger in this crate, and so this function
359
/// does some sanity asserts under the assumption of a max alignment of `8`.
360
#[cfg(feature = "alloc")]
361
pub(crate) fn alloc_aligned_buffer<T>(size: usize) -> (Vec<u8>, usize) {
362
    // NOTE: This is a kludge because there's no easy way to allocate a Vec<u8>
363
    // with an alignment guaranteed to be greater than 1. We could create a
364
    // Vec<u32>, but this cannot be safely transmuted to a Vec<u8> without
365
    // concern, since reallocing or dropping the Vec<u8> is UB (different
366
    // alignment than the initial allocation). We could define a wrapper type
367
    // to manage this for us, but it seems like more machinery than it's worth.
368
    let buf = vec![0; size];
369
    let align = core::mem::align_of::<T>();
370
    let address = buf.as_ptr().as_usize();
371
    if address % align == 0 {
372
        return (buf, 0);
373
    }
374
    // Let's try this again. We have to create a totally new alloc with
375
    // the maximum amount of bytes we might need. We can't just extend our
376
    // pre-existing 'buf' because that might create a new alloc with a
377
    // different alignment.
378
    let extra = align - 1;
379
    let mut buf = vec![0; size + extra];
380
    let address = buf.as_ptr().as_usize();
381
    // The code below handles the case where 'address' is aligned to T, so if
382
    // we got lucky and 'address' is now aligned to T (when it previously
383
    // wasn't), then we're done.
384
    if address % align == 0 {
385
        buf.truncate(size);
386
        return (buf, 0);
387
    }
388
    let padding = ((address & !(align - 1)).checked_add(align).unwrap())
389
        .checked_sub(address)
390
        .unwrap();
391
    assert!(padding <= 7, "padding of {padding} is bigger than 7");
392
    assert!(
393
        padding <= extra,
394
        "padding of {padding} is bigger than extra {extra} bytes",
395
    );
396
    buf.truncate(size + padding);
397
    assert_eq!(size + padding, buf.len());
398
    assert_eq!(
399
        0,
400
        buf[padding..].as_ptr().as_usize() % align,
401
        "expected end of initial padding to be aligned to {align}",
402
    );
403
    (buf, padding)
404
}
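A worked version of the padding computation used above, pulled out into a hypothetical helper `padding_for` so it can be checked on concrete addresses. For example, at address 0x1001 with 4-byte alignment, rounding down gives 0x1000, the next boundary is 0x1004, and so 3 bytes of padding are needed, which never exceeds `align - 1` (the `extra` bytes allocated).

```rust
/// Padding needed so that `address + padding` is a multiple of `align`,
/// using the same formula as `alloc_aligned_buffer`. `align` must be a
/// power of two.
fn padding_for(address: usize, align: usize) -> usize {
    assert!(align.is_power_of_two());
    if address % align == 0 {
        return 0;
    }
    // Round `address` down to a multiple of `align`, step to the next
    // multiple, then measure how far that is from `address`.
    ((address & !(align - 1)) + align) - address
}
```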
405
406
/// Reads a NUL terminated label starting at the beginning of the given slice.
407
///
408
/// If a NUL terminated label could not be found, then an error is returned.
409
/// Similarly, if a label is found but doesn't match the expected label, then
410
/// an error is returned.
411
///
412
/// Upon success, the total number of bytes read (including padding bytes) is
413
/// returned.
414
6.77k
pub(crate) fn read_label(
415
6.77k
    slice: &[u8],
416
6.77k
    expected_label: &'static str,
417
6.77k
) -> Result<usize, DeserializeError> {
418
    // Set an upper bound on how many bytes we scan for a NUL. Since no label
419
    // in this crate is longer than 255 bytes, its NUL terminator must appear
420
    // within the first 256 bytes; if it doesn't, the data is corrupted.
421
6.77k
    let first_nul =
422
206k
        slice[..cmp::min(slice.len(), 256)].iter().position(|&b| b == 0);
423
6.77k
    let first_nul = match first_nul {
424
6.77k
        Some(first_nul) => first_nul,
425
        None => {
426
0
            return Err(DeserializeError::generic(
427
0
                "could not find NUL terminated label \
428
0
                 at start of serialized object",
429
0
            ));
430
        }
431
    };
432
6.77k
    let len = first_nul + padding_len(first_nul);
433
6.77k
    if slice.len() < len {
434
0
        return Err(DeserializeError::generic(
435
0
            "could not find properly sized label at start of serialized object"
436
0
        ));
437
6.77k
    }
438
6.77k
    if expected_label.as_bytes() != &slice[..first_nul] {
439
0
        return Err(DeserializeError::label_mismatch(expected_label));
440
6.77k
    }
441
6.77k
    Ok(len)
442
6.77k
}
443
444
/// Writes the given label to the buffer as a NUL terminated string. The label
445
/// given must not contain NUL, otherwise this will panic. Similarly, the label
446
/// must not be longer than 255 bytes, otherwise this will panic.
447
///
448
/// Additional NUL bytes are written as necessary to ensure that the number of
449
/// bytes written is always a multiple of 4.
450
///
451
/// Upon success, the total number of bytes written (including padding) is
452
/// returned.
453
0
pub(crate) fn write_label(
454
0
    label: &str,
455
0
    dst: &mut [u8],
456
0
) -> Result<usize, SerializeError> {
457
0
    let nwrite = write_label_len(label);
458
0
    if dst.len() < nwrite {
459
0
        return Err(SerializeError::buffer_too_small("label"));
460
0
    }
461
0
    dst[..label.len()].copy_from_slice(label.as_bytes());
462
0
    for i in 0..(nwrite - label.len()) {
463
0
        dst[label.len() + i] = 0;
464
0
    }
465
0
    assert_eq!(nwrite % 4, 0);
466
0
    Ok(nwrite)
467
0
}
468
469
/// Returns the total number of bytes (including padding) that would be written
470
/// for the given label. This panics if the given label contains a NUL byte or
471
/// is longer than 255 bytes. (The size restriction exists so that searching
472
/// for a label during deserialization can be done in small bounded space.)
473
0
pub(crate) fn write_label_len(label: &str) -> usize {
474
0
    assert!(label.len() <= 255, "label must not be longer than 255 bytes");
475
0
    assert!(label.bytes().all(|b| b != 0), "label must not contain NUL bytes");
476
0
    let label_len = label.len() + 1; // +1 for the NUL terminator
477
0
    label_len + padding_len(label_len)
478
0
}
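A round-trip sketch of the label encoding: the label bytes, a NUL terminator, then NUL padding up to a multiple of 4. The helper `pad4` and the `Option` return types are assumptions made to keep the snippet standalone (the crate's own `padding_len` is not shown here). In this sketch the reader counts the terminator before rounding up, so the length it reports always equals the length the writer produced.

```rust
/// Hypothetical helper: bytes needed to round `n` up to a multiple of 4.
fn pad4(n: usize) -> usize {
    (4 - n % 4) % 4
}

/// Total bytes occupied by `label`, including its NUL and padding.
fn label_len(label: &str) -> usize {
    assert!(label.len() <= 255 && !label.bytes().any(|b| b == 0));
    let n = label.len() + 1; // +1 for the NUL terminator
    n + pad4(n)
}

fn write_label(label: &str, dst: &mut [u8]) -> Option<usize> {
    let nwrite = label_len(label);
    if dst.len() < nwrite {
        return None; // destination buffer too small
    }
    dst[..label.len()].copy_from_slice(label.as_bytes());
    dst[label.len()..nwrite].fill(0); // terminator plus padding
    Some(nwrite)
}

fn read_label(slice: &[u8], expected: &str) -> Option<usize> {
    // Only scan a bounded prefix for the terminator.
    let end = slice.len().min(256);
    let first_nul = slice[..end].iter().position(|&b| b == 0)?;
    let n = first_nul + 1; // include the NUL before rounding
    let len = n + pad4(n);
    if slice.len() < len || &slice[..first_nul] != expected.as_bytes() {
        return None;
    }
    Some(len)
}
```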
479
480
/// Reads the endianness check from the beginning of the given slice and
481
/// confirms that the endianness of the serialized object matches the expected
482
/// endianness. If the slice is too small or if the endianness check fails,
483
/// this returns an error.
484
///
485
/// Upon success, the total number of bytes read is returned.
486
6.77k
pub(crate) fn read_endianness_check(
487
6.77k
    slice: &[u8],
488
6.77k
) -> Result<usize, DeserializeError> {
489
6.77k
    let (n, nr) = try_read_u32(slice, "endianness check")?;
490
6.77k
    assert_eq!(nr, write_endianness_check_len());
491
6.77k
    if n != 0xFEFF {
492
0
        return Err(DeserializeError::endian_mismatch(0xFEFF, n));
493
6.77k
    }
494
6.77k
    Ok(nr)
495
6.77k
}
496
497
/// Writes 0xFEFF as an integer using the given endianness.
498
///
499
/// This is useful for writing into the header of a serialized object. It can
500
/// be read during deserialization as a sanity check to ensure the proper
501
/// endianness is used.
502
///
503
/// Upon success, the total number of bytes written is returned.
504
pub(crate) fn write_endianness_check<E: Endian>(
505
    dst: &mut [u8],
506
) -> Result<usize, SerializeError> {
507
    let nwrite = write_endianness_check_len();
508
    if dst.len() < nwrite {
509
        return Err(SerializeError::buffer_too_small("endianness check"));
510
    }
511
    E::write_u32(0xFEFF, dst);
512
    Ok(nwrite)
513
}
514
515
/// Returns the number of bytes written by the endianness check.
516
6.77k
pub(crate) fn write_endianness_check_len() -> usize {
517
6.77k
    size_of::<u32>()
518
6.77k
}
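0xFEFF works as a sentinel because its byte-swapped form, 0xFFFE0000, is a different value, so bytes written in the wrong order are always detected. A sketch of writing and validating the check, with a `bool` flag standing in for the crate's `Endian` type parameter (an assumption for the sake of a standalone example):

```rust
const ENDIANNESS_CHECK: u32 = 0xFEFF;

/// Writes the check value in the requested byte order.
fn write_check(big_endian: bool, dst: &mut [u8; 4]) {
    *dst = if big_endian {
        ENDIANNESS_CHECK.to_be_bytes()
    } else {
        ENDIANNESS_CHECK.to_le_bytes()
    };
}

/// Reads the check value in native byte order. On success returns the
/// number of bytes read; on mismatch returns the value actually found.
fn read_check(src: &[u8; 4]) -> Result<usize, u32> {
    let n = u32::from_ne_bytes(*src);
    if n != ENDIANNESS_CHECK {
        return Err(n);
    }
    Ok(4)
}
```

On either kind of target, writing with the non-native byte order and then reading natively yields 0xFFFE0000, which `read_check` rejects.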
519
520
/// Reads a version number from the beginning of the given slice and confirms
521
/// that it matches the expected version number given. If the slice is too
522
/// small or if the version numbers aren't equivalent, this returns an error.
523
///
524
/// Upon success, the total number of bytes read is returned.
525
///
526
/// N.B. Currently, we require that the version number is exactly equivalent.
527
/// In the future, if we bump the version number without a semver bump, then
528
/// we'll need to relax this a bit and support older versions.
529
6.77k
pub(crate) fn read_version(
530
6.77k
    slice: &[u8],
531
6.77k
    expected_version: u32,
532
6.77k
) -> Result<usize, DeserializeError> {
533
6.77k
    let (n, nr) = try_read_u32(slice, "version")?;
534
6.77k
    assert_eq!(nr, write_version_len());
535
6.77k
    if n != expected_version {
536
0
        return Err(DeserializeError::version_mismatch(expected_version, n));
537
6.77k
    }
538
6.77k
    Ok(nr)
539
6.77k
}
540
541
/// Writes the given version number to the beginning of the given slice.
542
///
543
/// This is useful for writing into the header of a serialized object. It can
544
/// be read during deserialization as a sanity check to ensure that the library
545
/// code supports the format of the serialized object.
546
///
547
/// Upon success, the total number of bytes written is returned.
548
pub(crate) fn write_version<E: Endian>(
549
    version: u32,
550
    dst: &mut [u8],
551
) -> Result<usize, SerializeError> {
552
    let nwrite = write_version_len();
553
    if dst.len() < nwrite {
554
        return Err(SerializeError::buffer_too_small("version number"));
555
    }
556
    E::write_u32(version, dst);
557
    Ok(nwrite)
558
}
559
560
/// Returns the number of bytes written by writing the version number.
561
6.77k
pub(crate) fn write_version_len() -> usize {
562
6.77k
    size_of::<u32>()
563
6.77k
}
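Each read function returns how many bytes it consumed so that callers can advance through a header one field at a time. A sketch of that pattern, assuming a hypothetical two-field header (endianness check, then version) and `String` errors in place of `DeserializeError`; the crate's actual header layouts are not shown here.

```rust
/// Reads a native-endian u32, failing if fewer than 4 bytes remain.
fn read_u32(slice: &[u8], what: &'static str) -> Result<(u32, usize), String> {
    if slice.len() < 4 {
        return Err(format!("buffer is too small to read {what}"));
    }
    let mut bytes = [0u8; 4];
    bytes.copy_from_slice(&slice[..4]);
    Ok((u32::from_ne_bytes(bytes), 4))
}

/// Reads the hypothetical header and returns the total bytes consumed.
fn read_header(slice: &[u8], expected_version: u32) -> Result<usize, String> {
    let mut nr = 0;
    let (check, n) = read_u32(&slice[nr..], "endianness check")?;
    if check != 0xFEFF {
        return Err(format!("endianness mismatch: found 0x{check:X}"));
    }
    nr += n;
    let (version, n) = read_u32(&slice[nr..], "version")?;
    if version != expected_version {
        return Err(format!(
            "unsupported version: expected {expected_version} but found {version}"
        ));
    }
    nr += n;
    Ok(nr)
}
```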
564
565
/// Reads a pattern ID from the given slice. If the slice has insufficient
566
/// length, then this panics. If the deserialized integer exceeds the pattern
567
/// ID limit for the current target, then this returns an error.
568
///
569
/// Upon success, this also returns the number of bytes read.
570
84.7k
pub(crate) fn read_pattern_id(
571
84.7k
    slice: &[u8],
572
84.7k
    what: &'static str,
573
84.7k
) -> Result<(PatternID, usize), DeserializeError> {
574
84.7k
    let bytes: [u8; PatternID::SIZE] =
575
84.7k
        slice[..PatternID::SIZE].try_into().unwrap();
576
84.7k
    let pid = PatternID::from_ne_bytes(bytes)
577
84.7k
        .map_err(|err| DeserializeError::pattern_id_error(err, what))?;
578
84.7k
    Ok((pid, PatternID::SIZE))
579
84.7k
}
580
581
/// Reads a pattern ID from the given slice. If the slice has insufficient
582
/// length, then this panics. Otherwise, the deserialized integer is assumed
583
/// to be a valid pattern ID.
584
///
585
/// This also returns the number of bytes read.
586
20.4k
pub(crate) fn read_pattern_id_unchecked(slice: &[u8]) -> (PatternID, usize) {
587
20.4k
    let pid = PatternID::from_ne_bytes_unchecked(
588
20.4k
        slice[..PatternID::SIZE].try_into().unwrap(),
589
    );
590
20.4k
    (pid, PatternID::SIZE)
591
20.4k
}
592
593
/// Writes the given pattern ID to the beginning of the given slice of bytes
594
/// using the specified endianness. The given slice must have length at least
595
/// `PatternID::SIZE`, or else this panics. Upon success, the total number of
596
/// bytes written is returned.
597
0
pub(crate) fn write_pattern_id<E: Endian>(
598
0
    pid: PatternID,
599
0
    dst: &mut [u8],
600
0
) -> usize {
601
0
    E::write_u32(pid.as_u32(), dst);
602
0
    PatternID::SIZE
603
0
}
604
605
/// Attempts to read a state ID from the given slice. If the slice has an
606
/// insufficient number of bytes or if the state ID exceeds the limit for
607
/// the current target, then this returns an error.
608
///
609
/// Upon success, this also returns the number of bytes read.
610
42.4k
pub(crate) fn try_read_state_id(
611
42.4k
    slice: &[u8],
612
42.4k
    what: &'static str,
613
42.4k
) -> Result<(StateID, usize), DeserializeError> {
614
42.4k
    if slice.len() < StateID::SIZE {
615
0
        return Err(DeserializeError::buffer_too_small(what));
616
42.4k
    }
617
42.4k
    read_state_id(slice, what)
618
42.4k
}
619
620
/// Reads a state ID from the given slice. If the slice has insufficient
621
/// length, then this panics. If the deserialized integer exceeds the state ID
622
/// limit for the current target, then this returns an error.
623
///
624
/// Upon success, this also returns the number of bytes read.
625
121k
pub(crate) fn read_state_id(
626
121k
    slice: &[u8],
627
121k
    what: &'static str,
628
121k
) -> Result<(StateID, usize), DeserializeError> {
629
121k
    let bytes: [u8; StateID::SIZE] =
630
121k
        slice[..StateID::SIZE].try_into().unwrap();
631
121k
    let sid = StateID::from_ne_bytes(bytes)
632
121k
        .map_err(|err| DeserializeError::state_id_error(err, what))?;
633
121k
    Ok((sid, StateID::SIZE))
634
121k
}
635
636
/// Reads a state ID from the given slice. If the slice has insufficient
637
/// length, then this panics. Otherwise, the deserialized integer is assumed
638
/// to be a valid state ID.
639
///
640
/// This also returns the number of bytes read.
641
25.4k
pub(crate) fn read_state_id_unchecked(slice: &[u8]) -> (StateID, usize) {
642
25.4k
    let sid = StateID::from_ne_bytes_unchecked(
643
25.4k
        slice[..StateID::SIZE].try_into().unwrap(),
644
    );
645
25.4k
    (sid, StateID::SIZE)
646
25.4k
}
647
648
/// Writes the given state ID to the beginning of the given slice of bytes
649
/// using the specified endianness. The given slice must have length at least
650
/// `StateID::SIZE`, or else this panics. Upon success, the total number of
651
/// bytes written is returned.
652
0
pub(crate) fn write_state_id<E: Endian>(
653
0
    sid: StateID,
654
0
    dst: &mut [u8],
655
0
) -> usize {
656
0
    E::write_u32(sid.as_u32(), dst);
657
0
    StateID::SIZE
658
0
}
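The checked and unchecked ID readers above differ only in whether the decoded integer is validated against a limit. A sketch of that split, where `Id` and the limit `MAX_ID` are hypothetical (the crate defines its own ID limits elsewhere, and they are not shown here):

```rust
/// Hypothetical limit for this sketch only.
const MAX_ID: u32 = i32::MAX as u32;

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct Id(u32);

/// Checked read: panics on a short slice, and returns the raw value as an
/// error if it exceeds `MAX_ID`.
fn read_id(slice: &[u8]) -> Result<(Id, usize), u32> {
    let mut b = [0u8; 4];
    b.copy_from_slice(&slice[..4]);
    let n = u32::from_ne_bytes(b);
    if n > MAX_ID {
        return Err(n);
    }
    Ok((Id(n), 4))
}

/// Unchecked read: trusts the bytes, e.g. for data validated earlier.
fn read_id_unchecked(slice: &[u8]) -> (Id, usize) {
    let mut b = [0u8; 4];
    b.copy_from_slice(&slice[..4]);
    (Id(u32::from_ne_bytes(b)), 4)
}
```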
659
660
/// Try to read a u16 as a usize from the beginning of the given slice in
661
/// native endian format. If the slice has fewer than 2 bytes or if the
662
/// deserialized number cannot be represented by usize, then this returns an
663
/// error. The error message will include the `what` description of what is
664
/// being deserialized, for better error messages. `what` should be a noun in
665
/// singular form.
666
///
667
/// Upon success, this also returns the number of bytes read.
668
37.1k
pub(crate) fn try_read_u16_as_usize(
669
37.1k
    slice: &[u8],
670
37.1k
    what: &'static str,
671
37.1k
) -> Result<(usize, usize), DeserializeError> {
672
37.1k
    try_read_u16(slice, what).and_then(|(n, nr)| {
673
37.1k
        usize::try_from(n)
674
37.1k
            .map(|n| (n, nr))
675
37.1k
            .map_err(|_| DeserializeError::invalid_usize(what))
676
37.1k
    })
677
37.1k
}
678
679
/// Try to read a u32 as a usize from the beginning of the given slice in
680
/// native endian format. If the slice has fewer than 4 bytes or if the
681
/// deserialized number cannot be represented by usize, then this returns an
682
/// error. The error message will include the `what` description of what is
683
/// being deserialized, for better error messages. `what` should be a noun in
684
/// singular form.
685
///
686
/// Upon success, this also returns the number of bytes read.
687
40.6k
pub(crate) fn try_read_u32_as_usize(
688
40.6k
    slice: &[u8],
689
40.6k
    what: &'static str,
690
40.6k
) -> Result<(usize, usize), DeserializeError> {
691
40.6k
    try_read_u32(slice, what).and_then(|(n, nr)| {
692
40.5k
        usize::try_from(n)
693
40.5k
            .map(|n| (n, nr))
694
40.5k
            .map_err(|_| DeserializeError::invalid_usize(what))
695
40.5k
    })
696
40.6k
}
697
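The byte count returned alongside each value is what lets a caller walk through a buffer field by field. The sketch below assumes a hypothetical header made of two consecutive native-endian `u32` fields, labeled "state length" and "stride" purely for illustration. It advances the slice by the returned count after each read. `usize::try_from(u32)` can only fail on targets where `usize` is narrower than 32 bits.

```rust
use std::convert::TryFrom;

/// Hypothetical error carrying the name of the field being decoded.
#[derive(Debug, PartialEq, Eq)]
pub struct DecodeError(pub &'static str);

/// Read a native-endian u32 as a usize, returning (value, bytes_read).
pub fn try_read_u32_as_usize(
    slice: &[u8],
    what: &'static str,
) -> Result<(usize, usize), DecodeError> {
    if slice.len() < 4 {
        return Err(DecodeError(what));
    }
    let n = u32::from_ne_bytes([slice[0], slice[1], slice[2], slice[3]]);
    // Fallible only on targets where usize is narrower than 32 bits.
    let n = usize::try_from(n).map_err(|_| DecodeError(what))?;
    Ok((n, 4))
}

/// Hypothetical header of two u32 fields. Returns both values plus the total
/// number of bytes consumed.
pub fn read_header(mut slice: &[u8]) -> Result<(usize, usize, usize), DecodeError> {
    let start_len = slice.len();
    let (state_len, nr) = try_read_u32_as_usize(slice, "state length")?;
    slice = &slice[nr..]; // skip past the field just decoded
    let (stride, nr) = try_read_u32_as_usize(slice, "stride")?;
    slice = &slice[nr..];
    Ok((state_len, stride, start_len - slice.len()))
}
```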
698
/// Try to read a u16 from the beginning of the given slice in native endian
699
/// format. If the slice has fewer than 2 bytes, then this returns an error.
700
/// The error message will include the `what` description of what is being
701
/// deserialized, for better error messages. `what` should be a noun in
702
/// singular form.
703
///
704
/// Upon success, this also returns the number of bytes read.
705
37.1k
pub(crate) fn try_read_u16(
706
37.1k
    slice: &[u8],
707
37.1k
    what: &'static str,
708
37.1k
) -> Result<(u16, usize), DeserializeError> {
709
37.1k
    check_slice_len(slice, size_of::<u16>(), what)?;
710
37.1k
    Ok((read_u16(slice), size_of::<u16>()))
711
37.1k
}
712
713
/// Try to read a u32 from the beginning of the given slice in native endian
714
/// format. If the slice has fewer than 4 bytes, then this returns an error.
715
/// The error message will include the `what` description of what is being
716
/// deserialized, for better error messages. `what` should be a noun in
717
/// singular form.
718
///
719
/// Upon success, this also returns the number of bytes read.
720
85.5k
pub(crate) fn try_read_u32(
721
85.5k
    slice: &[u8],
722
85.5k
    what: &'static str,
723
85.5k
) -> Result<(u32, usize), DeserializeError> {
724
85.5k
    check_slice_len(slice, size_of::<u32>(), what)?;
725
85.3k
    Ok((read_u32(slice), size_of::<u32>()))
726
85.5k
}
727
728
/// Try to read a u128 from the beginning of the given slice in native endian
729
/// format. If the slice has fewer than 16 bytes, then this returns an error.
730
/// The error message will include the `what` description of what is being
731
/// deserialized, for better error messages. `what` should be a noun in
732
/// singular form.
733
///
734
/// Upon success, this also returns the number of bytes read.
735
8.45k
pub(crate) fn try_read_u128(
736
8.45k
    slice: &[u8],
737
8.45k
    what: &'static str,
738
8.45k
) -> Result<(u128, usize), DeserializeError> {
739
8.45k
    check_slice_len(slice, size_of::<u128>(), what)?;
740
8.45k
    Ok((read_u128(slice), size_of::<u128>()))
741
8.45k
}
742
743
/// Read a u16 from the beginning of the given slice in native endian format.
744
/// If the slice has fewer than 2 bytes, then this panics.
745
///
746
/// Marked as inline (with `perf-inline`) to speed up sparse searching,
747
/// which decodes integers from its automaton at search time.
748
#[cfg_attr(feature = "perf-inline", inline(always))]
749
229k
pub(crate) fn read_u16(slice: &[u8]) -> u16 {
750
229k
    let bytes: [u8; 2] = slice[..size_of::<u16>()].try_into().unwrap();
751
229k
    u16::from_ne_bytes(bytes)
752
229k
}
753
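The sketch below makes "native endian" concrete. It compares a simplified `read_u16` with a hypothetical `read_u16_explicit` that chooses `from_le_bytes` or `from_be_bytes` based on `cfg!(target_endian = ...)`. The two agree on any host. Unlike the original, this version applies `#[inline(always)]` unconditionally rather than only under the `perf-inline` feature.

```rust
/// Simplified: read the first two bytes of `slice` as a native-endian u16.
/// Panics if `slice` has fewer than two bytes. Extra bytes are ignored.
#[inline(always)]
pub fn read_u16(slice: &[u8]) -> u16 {
    let mut bytes = [0u8; 2];
    bytes.copy_from_slice(&slice[..2]);
    u16::from_ne_bytes(bytes)
}

/// Hypothetical: what `read_u16` is equivalent to on the current target.
pub fn read_u16_explicit(slice: &[u8]) -> u16 {
    let bytes = [slice[0], slice[1]];
    if cfg!(target_endian = "little") {
        u16::from_le_bytes(bytes)
    } else {
        u16::from_be_bytes(bytes)
    }
}
```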
754
/// Read a u32 from the beginning of the given slice in native endian format.
755
/// If the slice has fewer than 4 bytes, then this panics.
756
///
757
/// Marked as inline (with `perf-inline`) to speed up sparse searching,
758
/// which decodes integers from its automaton at search time.
759
#[cfg_attr(feature = "perf-inline", inline(always))]
760
131k
pub(crate) fn read_u32(slice: &[u8]) -> u32 {
761
131k
    let bytes: [u8; 4] = slice[..size_of::<u32>()].try_into().unwrap();
762
131k
    u32::from_ne_bytes(bytes)
763
131k
}
764
765
/// Read a u128 from the beginning of the given slice in native endian format.
766
/// If the slice has fewer than 16 bytes, then this panics.
767
8.45k
pub(crate) fn read_u128(slice: &[u8]) -> u128 {
768
8.45k
    let bytes: [u8; 16] = slice[..size_of::<u128>()].try_into().unwrap();
769
8.45k
    u128::from_ne_bytes(bytes)
770
8.45k
}
771
772
/// Checks that the given slice has some minimal length. If it's smaller than
773
/// the bound given, then a "buffer too small" error is returned with `what`
774
/// describing what the buffer represents.
775
371k
pub(crate) fn check_slice_len<T>(
776
371k
    slice: &[T],
777
371k
    at_least_len: usize,
778
371k
    what: &'static str,
779
371k
) -> Result<(), DeserializeError> {
780
371k
    if slice.len() < at_least_len {
781
1.07k
        return Err(DeserializeError::buffer_too_small(what));
782
370k
    }
783
370k
    Ok(())
784
371k
}
785
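Because `check_slice_len` is generic over the element type, `at_least_len` counts elements, not bytes. The sketch below replaces `DeserializeError` with a hypothetical `BufferTooSmall` error. It shows that three `u32` words (12 bytes) satisfy a bound of 3 but not a bound of 12.

```rust
/// Hypothetical error naming the buffer that was too short.
#[derive(Debug, PartialEq, Eq)]
pub struct BufferTooSmall(pub &'static str);

/// Succeeds if `slice` holds at least `at_least_len` elements of type `T`.
pub fn check_slice_len<T>(
    slice: &[T],
    at_least_len: usize,
    what: &'static str,
) -> Result<(), BufferTooSmall> {
    if slice.len() < at_least_len {
        return Err(BufferTooSmall(what));
    }
    Ok(())
}
```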
786
/// Multiply the given numbers, and on overflow, return an error that includes
787
/// `what` in the error message.
788
///
789
/// This is useful when doing arithmetic with untrusted data.
790
34.6k
pub(crate) fn mul(
791
34.6k
    a: usize,
792
34.6k
    b: usize,
793
34.6k
    what: &'static str,
794
34.6k
) -> Result<usize, DeserializeError> {
795
34.6k
    match a.checked_mul(b) {
796
34.6k
        Some(c) => Ok(c),
797
0
        None => Err(DeserializeError::arithmetic_overflow(what)),
798
    }
799
34.6k
}
800
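The sketch below shows why checked multiplication matters with untrusted input. It computes the byte size of a transition table from two counts read out of a buffer. `table_bytes`, its 4-byte entry size and the `Overflow` error type are all hypothetical. An absurd count is reported as an error that names the failing step, instead of wrapping around.

```rust
/// Hypothetical overflow error carrying the `what` description.
#[derive(Debug, PartialEq, Eq)]
pub struct Overflow(pub &'static str);

/// Multiply, reporting overflow as an error tagged with `what`.
pub fn mul(a: usize, b: usize, what: &'static str) -> Result<usize, Overflow> {
    a.checked_mul(b).ok_or(Overflow(what))
}

/// Hypothetical: bytes needed for `state_len * stride` 4-byte transitions.
pub fn table_bytes(state_len: usize, stride: usize) -> Result<usize, Overflow> {
    let count = mul(state_len, stride, "transition count")?;
    mul(count, 4, "transition table size")
}
```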
801
/// Add the given numbers, and on overflow, return an error that includes
802
/// `what` in the error message.
803
///
804
/// This is useful when doing arithmetic with untrusted data.
805
44.7k
pub(crate) fn add(
806
44.7k
    a: usize,
807
44.7k
    b: usize,
808
44.7k
    what: &'static str,
809
44.7k
) -> Result<usize, DeserializeError> {
810
44.7k
    match a.checked_add(b) {
811
44.7k
        Some(c) => Ok(c),
812
0
        None => Err(DeserializeError::arithmetic_overflow(what)),
813
    }
814
44.7k
}
815
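Without the check, `offset + len` would panic in a debug build. With the default release profile, it would silently wrap, so an untrusted offset near `usize::MAX` could yield a small end position that passes a later bounds check. A minimal sketch follows, using a hypothetical `region_end` helper and `Overflow` error:

```rust
/// Hypothetical overflow error carrying the `what` description.
#[derive(Debug, PartialEq, Eq)]
pub struct Overflow(pub &'static str);

/// Add, reporting overflow as an error tagged with `what`.
pub fn add(a: usize, b: usize, what: &'static str) -> Result<usize, Overflow> {
    a.checked_add(b).ok_or(Overflow(what))
}

/// Hypothetical: the exclusive end of a region described by an untrusted
/// `offset` and `len` read from a buffer.
pub fn region_end(offset: usize, len: usize) -> Result<usize, Overflow> {
    add(offset, len, "region end")
}
```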
816
/// Shift `a` left by `b`. If `b` is not less than `usize::BITS`, return an
817
/// error naming `what`. Bits shifted out of `a` are silently discarded.
818
///
819
/// This is useful when doing arithmetic with untrusted data.
820
3.31k
pub(crate) fn shl(
821
3.31k
    a: usize,
822
3.31k
    b: usize,
823
3.31k
    what: &'static str,
824
3.31k
) -> Result<usize, DeserializeError> {
825
3.31k
    let amount = u32::try_from(b)
826
3.31k
        .map_err(|_| DeserializeError::arithmetic_overflow(what))?;
827
3.31k
    match a.checked_shl(amount) {
828
3.31k
        Some(c) => Ok(c),
829
0
        None => Err(DeserializeError::arithmetic_overflow(what)),
830
    }
831
3.31k
}
832
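`usize::checked_shl` returns `None` only when the shift amount is greater than or equal to the bit width of `usize`. Bits shifted out of the top of `a` are discarded rather than treated as overflow. The sketch below reproduces `shl` with a hypothetical `Overflow` error type and makes that behavior visible.

```rust
use std::convert::TryFrom;

/// Hypothetical overflow error carrying the `what` description.
#[derive(Debug, PartialEq, Eq)]
pub struct Overflow(pub &'static str);

/// Shift `a` left by `b`. Fails if `b` does not fit in a `u32` or is at
/// least `usize::BITS`. High bits of `a` that are shifted out are dropped.
pub fn shl(a: usize, b: usize, what: &'static str) -> Result<usize, Overflow> {
    let amount = u32::try_from(b).map_err(|_| Overflow(what))?;
    a.checked_shl(amount).ok_or(Overflow(what))
}
```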
833
/// Returns the number of padding bytes that must be added to the given
834
/// length to make the total length a multiple of 4. The return value is
835
/// always less than 4.
836
6.77k
pub(crate) fn padding_len(non_padding_len: usize) -> usize {
837
6.77k
    (4 - (non_padding_len & 0b11)) & 0b11
838
6.77k
}
839
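The bit mask in `padding_len` is a branch-free form of `(4 - n % 4) % 4`. For unsigned `n`, `n & 0b11` equals `n % 4`, and the outer `& 0b11` maps a result of 4 back to 0. The sketch below checks that equivalence. It also uses a hypothetical `align4` helper to round lengths up to a 4-byte boundary.

```rust
/// Bytes of padding needed to round `non_padding_len` up to a multiple of 4.
pub fn padding_len(non_padding_len: usize) -> usize {
    // `n & 0b11` is in 0..=3, so `4 - ...` cannot underflow. The outer mask
    // turns 4 (already aligned) into 0.
    (4 - (non_padding_len & 0b11)) & 0b11
}

/// Hypothetical helper: round `len` up to the next 4-byte boundary.
pub fn align4(len: usize) -> usize {
    len + padding_len(len)
}
```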
840
/// A simple trait for writing code generic over endianness.
841
///
842
/// This is similar to what byteorder provides, but we only need a very small
843
/// subset.
844
pub(crate) trait Endian {
845
    /// Writes a u16 to the given destination buffer in a particular
846
    /// endianness. If the destination buffer has a length smaller than 2, then
847
    /// this panics.
848
    fn write_u16(n: u16, dst: &mut [u8]);
849
850
    /// Writes a u32 to the given destination buffer in a particular
851
    /// endianness. If the destination buffer has a length smaller than 4, then
852
    /// this panics.
853
    fn write_u32(n: u32, dst: &mut [u8]);
854
855
    /// Writes a u128 to the given destination buffer in a particular
856
    /// endianness. If the destination buffer has a length smaller than 16,
857
    /// then this panics.
858
    fn write_u128(n: u128, dst: &mut [u8]);
859
}
860
861
/// Little endian writing.
862
pub(crate) enum LE {}
863
/// Big endian writing.
864
pub(crate) enum BE {}
865
866
#[cfg(target_endian = "little")]
867
pub(crate) type NE = LE;
868
#[cfg(target_endian = "big")]
869
pub(crate) type NE = BE;
870
871
impl Endian for LE {
872
0
    fn write_u16(n: u16, dst: &mut [u8]) {
873
0
        dst[..2].copy_from_slice(&n.to_le_bytes());
874
0
    }
875
876
0
    fn write_u32(n: u32, dst: &mut [u8]) {
877
0
        dst[..4].copy_from_slice(&n.to_le_bytes());
878
0
    }
879
880
0
    fn write_u128(n: u128, dst: &mut [u8]) {
881
0
        dst[..16].copy_from_slice(&n.to_le_bytes());
882
0
    }
883
}
884
885
impl Endian for BE {
886
0
    fn write_u16(n: u16, dst: &mut [u8]) {
887
0
        dst[..2].copy_from_slice(&n.to_be_bytes());
888
0
    }
889
890
0
    fn write_u32(n: u32, dst: &mut [u8]) {
891
0
        dst[..4].copy_from_slice(&n.to_be_bytes());
892
0
    }
893
894
0
    fn write_u128(n: u128, dst: &mut [u8]) {
895
0
        dst[..16].copy_from_slice(&n.to_be_bytes());
896
0
    }
897
}
898
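The sketch below is a cut-down version of the `Endian` trait and its `LE`/`BE` implementations, restricted to `u32`. It adds a hypothetical `write_all` function that shows how code is written once and instantiated per byte order. The uninhabited `enum LE {}` and `enum BE {}` serve only as type-level markers. `NE` resolves to one of them at compile time through `cfg(target_endian)`.

```rust
/// Cut-down endianness trait: only the u32 writer is shown.
pub trait Endian {
    fn write_u32(n: u32, dst: &mut [u8]);
}

/// Marker types with no values. They are only used as type parameters.
pub enum LE {}
pub enum BE {}

impl Endian for LE {
    fn write_u32(n: u32, dst: &mut [u8]) {
        dst[..4].copy_from_slice(&n.to_le_bytes());
    }
}

impl Endian for BE {
    fn write_u32(n: u32, dst: &mut [u8]) {
        dst[..4].copy_from_slice(&n.to_be_bytes());
    }
}

// Native endianness, chosen at compile time.
#[cfg(target_endian = "little")]
pub type NE = LE;
#[cfg(target_endian = "big")]
pub type NE = BE;

/// Hypothetical generic serializer: writes each value in order using the
/// byte order chosen by `E` and returns the number of bytes written.
pub fn write_all<E: Endian>(values: &[u32], dst: &mut [u8]) -> usize {
    for (i, &v) in values.iter().enumerate() {
        E::write_u32(v, &mut dst[i * 4..]);
    }
    values.len() * 4
}
```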
899
#[cfg(all(test, feature = "alloc"))]
900
mod tests {
901
    use super::*;
902
903
    #[test]
904
    fn labels() {
905
        let mut buf = [0; 1024];
906
907
        let nwrite = write_label("fooba", &mut buf).unwrap();
908
        assert_eq!(nwrite, 8);
909
        assert_eq!(&buf[..nwrite], b"fooba\x00\x00\x00");
910
911
        let nread = read_label(&buf, "fooba").unwrap();
912
        assert_eq!(nread, 8);
913
    }
914
915
    #[test]
916
    #[should_panic]
917
    fn bad_label_interior_nul() {
918
        // interior NULs are not allowed
919
        write_label("foo\x00bar", &mut [0; 1024]).unwrap();
920
    }
921
922
    #[test]
923
    fn bad_label_almost_too_long() {
924
        // ok
925
        write_label(&"z".repeat(255), &mut [0; 1024]).unwrap();
926
    }
927
928
    #[test]
929
    #[should_panic]
930
    fn bad_label_too_long() {
931
        // labels longer than 255 bytes are banned
932
        write_label(&"z".repeat(256), &mut [0; 1024]).unwrap();
933
    }
934
935
    #[test]
936
    fn padding() {
937
        assert_eq!(0, padding_len(8));
938
        assert_eq!(3, padding_len(9));
939
        assert_eq!(2, padding_len(10));
940
        assert_eq!(1, padding_len(11));
941
        assert_eq!(0, padding_len(12));
942
        assert_eq!(3, padding_len(13));
943
        assert_eq!(2, padding_len(14));
944
        assert_eq!(1, padding_len(15));
945
        assert_eq!(0, padding_len(16));
946
    }
947
}