Coverage Report

Created: 2025-07-11 06:39

/rust/registry/src/index.crates.io-6f17d22bba15001f/regex-automata-0.2.0/src/util/bytes.rs
Line
Count
Source (jump to first uncovered line)
1
/*
2
A collection of helper functions, types and traits for serializing automata.
3
4
This crate defines its own bespoke serialization mechanism for some structures
5
provided in the public API, namely, DFAs. A bespoke mechanism was developed
6
primarily because structures like automata demand a specific binary format.
7
Attempting to encode their rich structure in an existing serialization
8
format is just not feasible. Moreover, the format for each structure is
9
generally designed such that deserialization is cheap. More specifically, that
10
deserialization can be done in constant time. (The idea being that you can
11
embed it into your binary or mmap it, and then use it immediately.)
12
13
In order to achieve this, most of the structures in this crate use an in-memory
14
representation that very closely corresponds to its binary serialized form.
15
This pervades and complicates everything, and in some cases, requires dealing
16
with alignment and reasoning about safety.
17
18
This technique does have major advantages. In particular, it permits doing
19
the potentially costly work of compiling a finite state machine in an offline
20
manner, and then loading it at runtime not only without having to re-compile
21
the regex, but even without the code required to do the compilation. This, for
22
example, permits one to use a pre-compiled DFA not only in environments without
23
Rust's standard library, but also in environments without a heap.
24
25
In the code below, whenever we insert some kind of padding, it's to enforce a
26
4-byte alignment, unless otherwise noted. Namely, u32 is the only state ID type
27
supported. (In a previous version of this library, DFAs were generic over the
28
state ID representation.)
29
30
Also, serialization generally requires the caller to specify endianness,
31
whereas deserialization always assumes native endianness (otherwise cheap
32
deserialization would be impossible). This implies that serializing a structure
33
generally requires serializing both its big-endian and little-endian variants,
34
and then loading the correct one based on the target's endianness.
35
*/
36
37
use core::{
38
    cmp,
39
    convert::{TryFrom, TryInto},
40
    mem::size_of,
41
};
42
43
#[cfg(feature = "alloc")]
44
use alloc::{vec, vec::Vec};
45
46
use crate::util::id::{PatternID, PatternIDError, StateID, StateIDError};
47
48
/// An error that occurs when serializing an object from this crate.
///
/// In this crate, serialization always means turning a structure (such as a
/// DFA) into a bespoke binary format represented by `&[u8]`. That process is
/// generally infallible. The one way it can fail is when the caller hands
/// over a destination buffer that is too small, and that failure is reported
/// with this type.
///
/// A `SerializeError` offers no introspection. The only thing a caller can do
/// with it is render it as a human readable message.
///
/// The `std::error::Error` trait is implemented only when the `std` feature
/// is enabled. The type itself exists in every configuration.
#[derive(Debug)]
pub struct SerializeError {
    /// A noun describing the object that did not fit in the destination
    /// buffer.
    ///
    /// The only serialization error today is a caller mistake: supplying a
    /// destination buffer too small for the serialized object. Conceptually,
    /// every valid value of a type should be serializable.
    ///
    /// This leaks into the public API. For example, the
    /// `to_bytes_{big,little}_endian` APIs return a `Vec<u8>` and promise to
    /// never panic or error. They can only promise this because they always
    /// allocate a `Vec<u8>` that is large enough.
    ///
    /// Consequently, adding any new kind of serialization error needs
    /// careful thought.
    what: &'static str,
}

impl SerializeError {
    /// Builds the error reported when the destination buffer cannot hold
    /// `what`.
    pub(crate) fn buffer_too_small(what: &'static str) -> SerializeError {
        Self { what }
    }
}

impl core::fmt::Display for SerializeError {
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        let SerializeError { what } = *self;
        write!(f, "destination buffer is too small to write {}", what)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for SerializeError {}
96
97
/// An error that occurs when deserializing an object defined in this crate.
98
///
99
/// Serialization, as used in this crate, universally refers to the process
100
/// of transforming a structure (like a DFA) into a custom binary format
101
/// represented by `&[u8]`. Deserialization, then, refers to the process of
102
/// cheaply converting this binary format back to the object's in-memory
103
/// representation as defined in this crate. To the extent possible,
104
/// deserialization will report this error whenever this process fails.
105
///
106
/// A `DeserializeError` provides no introspection capabilities. Its only
107
/// supported operation is conversion to a human readable error message.
108
///
109
/// This error type implements the `std::error::Error` trait only when the
110
/// `std` feature is enabled. Otherwise, this type is defined in all
111
/// configurations.
112
#[derive(Debug)]
113
pub struct DeserializeError(DeserializeErrorKind);
114
115
#[derive(Debug)]
116
enum DeserializeErrorKind {
117
    Generic { msg: &'static str },
118
    BufferTooSmall { what: &'static str },
119
    InvalidUsize { what: &'static str },
120
    InvalidVarint { what: &'static str },
121
    VersionMismatch { expected: u32, found: u32 },
122
    EndianMismatch { expected: u32, found: u32 },
123
    AlignmentMismatch { alignment: usize, address: usize },
124
    LabelMismatch { expected: &'static str },
125
    ArithmeticOverflow { what: &'static str },
126
    PatternID { err: PatternIDError, what: &'static str },
127
    StateID { err: StateIDError, what: &'static str },
128
}
129
130
impl DeserializeError {
131
0
    pub(crate) fn generic(msg: &'static str) -> DeserializeError {
132
0
        DeserializeError(DeserializeErrorKind::Generic { msg })
133
0
    }
134
135
0
    pub(crate) fn buffer_too_small(what: &'static str) -> DeserializeError {
136
0
        DeserializeError(DeserializeErrorKind::BufferTooSmall { what })
137
0
    }
138
139
0
    pub(crate) fn invalid_usize(what: &'static str) -> DeserializeError {
140
0
        DeserializeError(DeserializeErrorKind::InvalidUsize { what })
141
0
    }
142
143
0
    fn invalid_varint(what: &'static str) -> DeserializeError {
144
0
        DeserializeError(DeserializeErrorKind::InvalidVarint { what })
145
0
    }
146
147
0
    fn version_mismatch(expected: u32, found: u32) -> DeserializeError {
148
0
        DeserializeError(DeserializeErrorKind::VersionMismatch {
149
0
            expected,
150
0
            found,
151
0
        })
152
0
    }
153
154
0
    fn endian_mismatch(expected: u32, found: u32) -> DeserializeError {
155
0
        DeserializeError(DeserializeErrorKind::EndianMismatch {
156
0
            expected,
157
0
            found,
158
0
        })
159
0
    }
160
161
0
    fn alignment_mismatch(
162
0
        alignment: usize,
163
0
        address: usize,
164
0
    ) -> DeserializeError {
165
0
        DeserializeError(DeserializeErrorKind::AlignmentMismatch {
166
0
            alignment,
167
0
            address,
168
0
        })
169
0
    }
170
171
0
    fn label_mismatch(expected: &'static str) -> DeserializeError {
172
0
        DeserializeError(DeserializeErrorKind::LabelMismatch { expected })
173
0
    }
174
175
0
    fn arithmetic_overflow(what: &'static str) -> DeserializeError {
176
0
        DeserializeError(DeserializeErrorKind::ArithmeticOverflow { what })
177
0
    }
178
179
0
    pub(crate) fn pattern_id_error(
180
0
        err: PatternIDError,
181
0
        what: &'static str,
182
0
    ) -> DeserializeError {
183
0
        DeserializeError(DeserializeErrorKind::PatternID { err, what })
184
0
    }
185
186
0
    pub(crate) fn state_id_error(
187
0
        err: StateIDError,
188
0
        what: &'static str,
189
0
    ) -> DeserializeError {
190
0
        DeserializeError(DeserializeErrorKind::StateID { err, what })
191
0
    }
192
}
193
194
#[cfg(feature = "std")]
195
impl std::error::Error for DeserializeError {}
196
197
impl core::fmt::Display for DeserializeError {
198
0
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
199
        use self::DeserializeErrorKind::*;
200
201
0
        match self.0 {
202
0
            Generic { msg } => write!(f, "{}", msg),
203
0
            BufferTooSmall { what } => {
204
0
                write!(f, "buffer is too small to read {}", what)
205
            }
206
0
            InvalidUsize { what } => {
207
0
                write!(f, "{} is too big to fit in a usize", what)
208
            }
209
0
            InvalidVarint { what } => {
210
0
                write!(f, "could not decode valid varint for {}", what)
211
            }
212
0
            VersionMismatch { expected, found } => write!(
213
0
                f,
214
0
                "unsupported version: \
215
0
                 expected version {} but found version {}",
216
0
                expected, found,
217
0
            ),
218
0
            EndianMismatch { expected, found } => write!(
219
0
                f,
220
0
                "endianness mismatch: expected 0x{:X} but got 0x{:X}. \
221
0
                 (Are you trying to load an object serialized with a \
222
0
                 different endianness?)",
223
0
                expected, found,
224
0
            ),
225
0
            AlignmentMismatch { alignment, address } => write!(
226
0
                f,
227
0
                "alignment mismatch: slice starts at address \
228
0
                 0x{:X}, which is not aligned to a {} byte boundary",
229
0
                address, alignment,
230
0
            ),
231
0
            LabelMismatch { expected } => write!(
232
0
                f,
233
0
                "label mismatch: start of serialized object should \
234
0
                 contain a NUL terminated {:?} label, but a different \
235
0
                 label was found",
236
0
                expected,
237
0
            ),
238
0
            ArithmeticOverflow { what } => {
239
0
                write!(f, "arithmetic overflow for {}", what)
240
            }
241
0
            PatternID { ref err, what } => {
242
0
                write!(f, "failed to read pattern ID for {}: {}", what, err)
243
            }
244
0
            StateID { ref err, what } => {
245
0
                write!(f, "failed to read state ID for {}: {}", what, err)
246
            }
247
        }
248
0
    }
249
}
250
251
/// Checks that the given slice has an alignment that matches `T`.
252
///
253
/// This is useful for checking that a slice has an appropriate alignment
254
/// before casting it to a &[T]. Note though that alignment is not itself
255
/// sufficient to perform the cast for any `T`.
256
0
pub fn check_alignment<T>(slice: &[u8]) -> Result<(), DeserializeError> {
257
0
    let alignment = core::mem::align_of::<T>();
258
0
    let address = slice.as_ptr() as usize;
259
0
    if address % alignment == 0 {
260
0
        return Ok(());
261
0
    }
262
0
    Err(DeserializeError::alignment_mismatch(alignment, address))
263
0
}
Unexecuted instantiation: regex_automata::util::bytes::check_alignment::<regex_automata::util::id::StateID>
Unexecuted instantiation: regex_automata::util::bytes::check_alignment::<regex_automata::util::id::PatternID>
Unexecuted instantiation: regex_automata::util::bytes::check_alignment::<u32>
264
265
/// Counts the leading NUL bytes of `slice`, stopping after at most 7.
///
/// A serialized object may be preceded by NUL padding so that it begins at a
/// correctly aligned address. That padding sits immediately before the
/// label, and this function measures how much of it there is.
///
/// Returns the number of padding bytes consumed, which is between 0 and 7.
pub fn skip_initial_padding(slice: &[u8]) -> usize {
    slice.iter().take(7).take_while(|&&byte| byte == 0).count()
}
281
282
/// Allocate a byte buffer of the given size, along with some initial padding
283
/// such that `buf[padding..]` has the same alignment as `T`, where the
284
/// alignment of `T` must be at most `8`. In particular, callers should treat
285
/// the first N bytes (second return value) as padding bytes that must not be
286
/// overwritten. In all cases, the following identity holds:
287
///
288
/// ```ignore
289
/// let (buf, padding) = alloc_aligned_buffer::<StateID>(SIZE);
290
/// assert_eq!(SIZE, buf[padding..].len());
291
/// ```
292
///
293
/// In practice, padding is often zero.
294
///
295
/// The requirement for `8` as a maximum here is somewhat arbitrary. In
296
/// practice, we never need anything bigger in this crate, and so this function
297
/// does some sanity asserts under the assumption of a max alignment of `8`.
298
#[cfg(feature = "alloc")]
299
pub fn alloc_aligned_buffer<T>(size: usize) -> (Vec<u8>, usize) {
    // FIXME: This is a kludge because there's no easy way to allocate a
    // Vec<u8> with an alignment guaranteed to be greater than 1. We could
    // create a Vec<u32>, but this cannot be safely transmuted to a Vec<u8>
    // without concern, since reallocing or dropping the Vec<u8> is UB
    // (different alignment than the initial allocation). We could define a
    // wrapper type to manage this for us, but it seems like more machinery
    // than it's worth.
    let align = core::mem::align_of::<T>();
    let buf = vec![0; size];
    if buf.as_ptr() as usize % align == 0 {
        return (buf, 0);
    }
    // The first allocation is misaligned, so over-allocate by `align - 1`
    // bytes. Some offset within that slack is then guaranteed to be aligned.
    //
    // The padding MUST be computed against the allocation that is actually
    // returned. Growing a Vec (e.g., with `extend`) may reallocate and move
    // its contents, which would invalidate padding computed beforehand.
    // `truncate` never reallocates, so shrinking afterwards keeps the address
    // stable.
    let extra = align - 1;
    let total = size
        .checked_add(extra)
        .expect("aligned buffer size overflows usize");
    let mut buf = vec![0; total];
    let address = buf.as_ptr() as usize;
    // Distance from `address` up to the next multiple of `align` (0 if the
    // new allocation happens to be aligned already).
    let padding = (align - address % align) % align;
    assert!(padding <= 7, "padding of {} is bigger than 7", padding);
    assert!(
        padding <= extra,
        "padding of {} is bigger than extra {} bytes",
        padding,
        extra,
    );
    buf.truncate(size + padding);
    assert_eq!(size + padding, buf.len());
    assert_eq!(
        0,
        buf[padding..].as_ptr() as usize % align,
        "expected end of initial padding to be aligned to {}",
        align,
    );
    (buf, padding)
}
331
332
/// Reads a NUL terminated label starting at the beginning of the given slice.
333
///
334
/// If a NUL terminated label could not be found, then an error is returned.
335
/// Similary, if a label is found but doesn't match the expected label, then
336
/// an error is returned.
337
///
338
/// Upon success, the total number of bytes read (including padding bytes) is
339
/// returned.
340
0
pub fn read_label(
341
0
    slice: &[u8],
342
0
    expected_label: &'static str,
343
0
) -> Result<usize, DeserializeError> {
344
0
    // Set an upper bound on how many bytes we scan for a NUL. Since no label
345
0
    // in this crate is longer than 256 bytes, if we can't find one within that
346
0
    // range, then we have corrupted data.
347
0
    let first_nul =
348
0
        slice[..cmp::min(slice.len(), 256)].iter().position(|&b| b == 0);
349
0
    let first_nul = match first_nul {
350
0
        Some(first_nul) => first_nul,
351
        None => {
352
0
            return Err(DeserializeError::generic(
353
0
                "could not find NUL terminated label \
354
0
                 at start of serialized object",
355
0
            ));
356
        }
357
    };
358
0
    let len = first_nul + padding_len(first_nul);
359
0
    if slice.len() < len {
360
0
        return Err(DeserializeError::generic(
361
0
            "could not find properly sized label at start of serialized object"
362
0
        ));
363
0
    }
364
0
    if expected_label.as_bytes() != &slice[..first_nul] {
365
0
        return Err(DeserializeError::label_mismatch(expected_label));
366
0
    }
367
0
    Ok(len)
368
0
}
369
370
/// Writes the given label to the buffer as a NUL terminated string. The label
371
/// given must not contain NUL, otherwise this will panic. Similarly, the label
372
/// must not be longer than 255 bytes, otherwise this will panic.
373
///
374
/// Additional NUL bytes are written as necessary to ensure that the number of
375
/// bytes written is always a multiple of 4.
376
///
377
/// Upon success, the total number of bytes written (including padding) is
378
/// returned.
379
0
pub fn write_label(
380
0
    label: &str,
381
0
    dst: &mut [u8],
382
0
) -> Result<usize, SerializeError> {
383
0
    let nwrite = write_label_len(label);
384
0
    if dst.len() < nwrite {
385
0
        return Err(SerializeError::buffer_too_small("label"));
386
0
    }
387
0
    dst[..label.len()].copy_from_slice(label.as_bytes());
388
0
    for i in 0..(nwrite - label.len()) {
389
0
        dst[label.len() + i] = 0;
390
0
    }
391
0
    assert_eq!(nwrite % 4, 0);
392
0
    Ok(nwrite)
393
0
}
394
395
/// Returns the total number of bytes (including padding) that would be written
396
/// for the given label. This panics if the given label contains a NUL byte or
397
/// is longer than 255 bytes. (The size restriction exists so that searching
398
/// for a label during deserialization can be done in small bounded space.)
399
0
pub fn write_label_len(label: &str) -> usize {
400
0
    if label.len() > 255 {
401
0
        panic!("label must not be longer than 255 bytes");
402
0
    }
403
0
    if label.as_bytes().iter().position(|&b| b == 0).is_some() {
404
0
        panic!("label must not contain NUL bytes");
405
0
    }
406
0
    let label_len = label.len() + 1; // +1 for the NUL terminator
407
0
    label_len + padding_len(label_len)
408
0
}
409
410
/// Reads the endianness check from the beginning of the given slice and
411
/// confirms that the endianness of the serialized object matches the expected
412
/// endianness. If the slice is too small or if the endianness check fails,
413
/// this returns an error.
414
///
415
/// Upon success, the total number of bytes read is returned.
416
0
pub fn read_endianness_check(slice: &[u8]) -> Result<usize, DeserializeError> {
417
0
    let (n, nr) = try_read_u32(slice, "endianness check")?;
418
0
    assert_eq!(nr, write_endianness_check_len());
419
0
    if n != 0xFEFF {
420
0
        return Err(DeserializeError::endian_mismatch(0xFEFF, n));
421
0
    }
422
0
    Ok(nr)
423
0
}
424
425
/// Writes 0xFEFF as an integer using the given endianness.
426
///
427
/// This is useful for writing into the header of a serialized object. It can
428
/// be read during deserialization as a sanity check to ensure the proper
429
/// endianness is used.
430
///
431
/// Upon success, the total number of bytes written is returned.
432
0
pub fn write_endianness_check<E: Endian>(
433
0
    dst: &mut [u8],
434
0
) -> Result<usize, SerializeError> {
435
0
    let nwrite = write_endianness_check_len();
436
0
    if dst.len() < nwrite {
437
0
        return Err(SerializeError::buffer_too_small("endianness check"));
438
0
    }
439
0
    E::write_u32(0xFEFF, dst);
440
0
    Ok(nwrite)
441
0
}
442
443
/// Returns the size in bytes of the serialized endianness marker. The marker
/// is a single `u32`.
pub fn write_endianness_check_len() -> usize {
    size_of::<u32>()
}
447
448
/// Reads a version number from the beginning of the given slice and confirms
449
/// that is matches the expected version number given. If the slice is too
450
/// small or if the version numbers aren't equivalent, this returns an error.
451
///
452
/// Upon success, the total number of bytes read is returned.
453
///
454
/// N.B. Currently, we require that the version number is exactly equivalent.
455
/// In the future, if we bump the version number without a semver bump, then
456
/// we'll need to relax this a bit and support older versions.
457
0
pub fn read_version(
458
0
    slice: &[u8],
459
0
    expected_version: u32,
460
0
) -> Result<usize, DeserializeError> {
461
0
    let (n, nr) = try_read_u32(slice, "version")?;
462
0
    assert_eq!(nr, write_version_len());
463
0
    if n != expected_version {
464
0
        return Err(DeserializeError::version_mismatch(expected_version, n));
465
0
    }
466
0
    Ok(nr)
467
0
}
468
469
/// Writes the given version number to the beginning of the given slice.
470
///
471
/// This is useful for writing into the header of a serialized object. It can
472
/// be read during deserialization as a sanity check to ensure that the library
473
/// code supports the format of the serialized object.
474
///
475
/// Upon success, the total number of bytes written is returned.
476
0
pub fn write_version<E: Endian>(
477
0
    version: u32,
478
0
    dst: &mut [u8],
479
0
) -> Result<usize, SerializeError> {
480
0
    let nwrite = write_version_len();
481
0
    if dst.len() < nwrite {
482
0
        return Err(SerializeError::buffer_too_small("version number"));
483
0
    }
484
0
    E::write_u32(version, dst);
485
0
    Ok(nwrite)
486
0
}
487
488
/// Returns the size in bytes of a serialized version number. The version is
/// a single `u32`.
pub fn write_version_len() -> usize {
    size_of::<u32>()
}
492
493
/// Reads a pattern ID from the given slice. If the slice has insufficient
494
/// length, then this panics. If the deserialized integer exceeds the pattern
495
/// ID limit for the current target, then this returns an error.
496
///
497
/// Upon success, this also returns the number of bytes read.
498
0
pub fn read_pattern_id(
499
0
    slice: &[u8],
500
0
    what: &'static str,
501
0
) -> Result<(PatternID, usize), DeserializeError> {
502
0
    let bytes: [u8; PatternID::SIZE] =
503
0
        slice[..PatternID::SIZE].try_into().unwrap();
504
0
    let pid = PatternID::from_ne_bytes(bytes)
505
0
        .map_err(|err| DeserializeError::pattern_id_error(err, what))?;
506
0
    Ok((pid, PatternID::SIZE))
507
0
}
508
509
/// Reads a pattern ID from the given slice. If the slice has insufficient
510
/// length, then this panics. Otherwise, the deserialized integer is assumed
511
/// to be a valid pattern ID.
512
///
513
/// This also returns the number of bytes read.
514
0
pub fn read_pattern_id_unchecked(slice: &[u8]) -> (PatternID, usize) {
515
0
    let pid = PatternID::from_ne_bytes_unchecked(
516
0
        slice[..PatternID::SIZE].try_into().unwrap(),
517
0
    );
518
0
    (pid, PatternID::SIZE)
519
0
}
520
521
/// Write the given pattern ID to the beginning of the given slice of bytes
522
/// using the specified endianness. The given slice must have length at least
523
/// `PatternID::SIZE`, or else this panics. Upon success, the total number of
524
/// bytes written is returned.
525
0
pub fn write_pattern_id<E: Endian>(pid: PatternID, dst: &mut [u8]) -> usize {
526
0
    E::write_u32(pid.as_u32(), dst);
527
0
    PatternID::SIZE
528
0
}
529
530
/// Attempts to read a state ID from the given slice. If the slice has an
531
/// insufficient number of bytes or if the state ID exceeds the limit for
532
/// the current target, then this returns an error.
533
///
534
/// Upon success, this also returns the number of bytes read.
535
0
pub fn try_read_state_id(
536
0
    slice: &[u8],
537
0
    what: &'static str,
538
0
) -> Result<(StateID, usize), DeserializeError> {
539
0
    if slice.len() < StateID::SIZE {
540
0
        return Err(DeserializeError::buffer_too_small(what));
541
0
    }
542
0
    read_state_id(slice, what)
543
0
}
544
545
/// Reads a state ID from the given slice. If the slice has insufficient
546
/// length, then this panics. If the deserialized integer exceeds the state ID
547
/// limit for the current target, then this returns an error.
548
///
549
/// Upon success, this also returns the number of bytes read.
550
0
pub fn read_state_id(
551
0
    slice: &[u8],
552
0
    what: &'static str,
553
0
) -> Result<(StateID, usize), DeserializeError> {
554
0
    let bytes: [u8; StateID::SIZE] =
555
0
        slice[..StateID::SIZE].try_into().unwrap();
556
0
    let sid = StateID::from_ne_bytes(bytes)
557
0
        .map_err(|err| DeserializeError::state_id_error(err, what))?;
558
0
    Ok((sid, StateID::SIZE))
559
0
}
560
561
/// Reads a state ID from the given slice. If the slice has insufficient
562
/// length, then this panics. Otherwise, the deserialized integer is assumed
563
/// to be a valid state ID.
564
///
565
/// This also returns the number of bytes read.
566
0
pub fn read_state_id_unchecked(slice: &[u8]) -> (StateID, usize) {
567
0
    let sid = StateID::from_ne_bytes_unchecked(
568
0
        slice[..StateID::SIZE].try_into().unwrap(),
569
0
    );
570
0
    (sid, StateID::SIZE)
571
0
}
572
573
/// Write the given state ID to the beginning of the given slice of bytes
574
/// using the specified endianness. The given slice must have length at least
575
/// `StateID::SIZE`, or else this panics. Upon success, the total number of
576
/// bytes written is returned.
577
0
pub fn write_state_id<E: Endian>(sid: StateID, dst: &mut [u8]) -> usize {
578
0
    E::write_u32(sid.as_u32(), dst);
579
0
    StateID::SIZE
580
0
}
581
582
/// Try to read a u16 as a usize from the beginning of the given slice in
583
/// native endian format. If the slice has fewer than 2 bytes or if the
584
/// deserialized number cannot be represented by usize, then this returns an
585
/// error. The error message will include the `what` description of what is
586
/// being deserialized, for better error messages. `what` should be a noun in
587
/// singular form.
588
///
589
/// Upon success, this also returns the number of bytes read.
590
0
pub fn try_read_u16_as_usize(
591
0
    slice: &[u8],
592
0
    what: &'static str,
593
0
) -> Result<(usize, usize), DeserializeError> {
594
0
    try_read_u16(slice, what).and_then(|(n, nr)| {
595
0
        usize::try_from(n)
596
0
            .map(|n| (n, nr))
597
0
            .map_err(|_| DeserializeError::invalid_usize(what))
598
0
    })
599
0
}
600
601
/// Try to read a u32 as a usize from the beginning of the given slice in
602
/// native endian format. If the slice has fewer than 4 bytes or if the
603
/// deserialized number cannot be represented by usize, then this returns an
604
/// error. The error message will include the `what` description of what is
605
/// being deserialized, for better error messages. `what` should be a noun in
606
/// singular form.
607
///
608
/// Upon success, this also returns the number of bytes read.
609
0
pub fn try_read_u32_as_usize(
610
0
    slice: &[u8],
611
0
    what: &'static str,
612
0
) -> Result<(usize, usize), DeserializeError> {
613
0
    try_read_u32(slice, what).and_then(|(n, nr)| {
614
0
        usize::try_from(n)
615
0
            .map(|n| (n, nr))
616
0
            .map_err(|_| DeserializeError::invalid_usize(what))
617
0
    })
618
0
}
619
620
/// Try to read a u16 from the beginning of the given slice in native endian
621
/// format. If the slice has fewer than 2 bytes, then this returns an error.
622
/// The error message will include the `what` description of what is being
623
/// deserialized, for better error messages. `what` should be a noun in
624
/// singular form.
625
///
626
/// Upon success, this also returns the number of bytes read.
627
0
pub fn try_read_u16(
628
0
    slice: &[u8],
629
0
    what: &'static str,
630
0
) -> Result<(u16, usize), DeserializeError> {
631
0
    if slice.len() < size_of::<u16>() {
632
0
        return Err(DeserializeError::buffer_too_small(what));
633
0
    }
634
0
    Ok((read_u16(slice), size_of::<u16>()))
635
0
}
636
637
/// Try to read a u32 from the beginning of the given slice in native endian
638
/// format. If the slice has fewer than 4 bytes, then this returns an error.
639
/// The error message will include the `what` description of what is being
640
/// deserialized, for better error messages. `what` should be a noun in
641
/// singular form.
642
///
643
/// Upon success, this also returns the number of bytes read.
644
0
pub fn try_read_u32(
645
0
    slice: &[u8],
646
0
    what: &'static str,
647
0
) -> Result<(u32, usize), DeserializeError> {
648
0
    if slice.len() < size_of::<u32>() {
649
0
        return Err(DeserializeError::buffer_too_small(what));
650
0
    }
651
0
    Ok((read_u32(slice), size_of::<u32>()))
652
0
}
653
654
/// Decodes a native-endian `u16` from the first 2 bytes of `slice`.
///
/// # Panics
///
/// Panics if `slice` is shorter than 2 bytes.
///
/// Forced inline because sparse searching decodes integers from its
/// automaton at search time.
#[inline(always)]
pub fn read_u16(slice: &[u8]) -> u16 {
    let mut raw = [0u8; 2];
    raw.copy_from_slice(&slice[..size_of::<u16>()]);
    u16::from_ne_bytes(raw)
}
664
665
/// Decodes a native-endian `u32` from the first 4 bytes of `slice`.
///
/// # Panics
///
/// Panics if `slice` is shorter than 4 bytes.
///
/// Forced inline because sparse searching decodes integers from its
/// automaton at search time.
#[inline(always)]
pub fn read_u32(slice: &[u8]) -> u32 {
    let mut raw = [0u8; 4];
    raw.copy_from_slice(&slice[..size_of::<u32>()]);
    u32::from_ne_bytes(raw)
}
675
676
/// Decodes a native-endian `u64` from the first 8 bytes of `slice`.
///
/// # Panics
///
/// Panics if `slice` is shorter than 8 bytes.
///
/// Forced inline because sparse searching decodes integers from its
/// automaton at search time.
#[inline(always)]
pub fn read_u64(slice: &[u8]) -> u64 {
    let mut raw = [0u8; 8];
    raw.copy_from_slice(&slice[..size_of::<u64>()]);
    u64::from_ne_bytes(raw)
}
686
687
/// Write a variable sized integer and return the total number of bytes
688
/// written. If the slice was not big enough to contain the bytes, then this
689
/// returns an error including the "what" description in it. This does no
690
/// padding.
691
///
692
/// See: https://developers.google.com/protocol-buffers/docs/encoding#varints
693
#[allow(dead_code)]
694
0
pub fn write_varu64(
695
0
    mut n: u64,
696
0
    what: &'static str,
697
0
    dst: &mut [u8],
698
0
) -> Result<usize, SerializeError> {
699
0
    let mut i = 0;
700
0
    while n >= 0b1000_0000 {
701
0
        if i >= dst.len() {
702
0
            return Err(SerializeError::buffer_too_small(what));
703
0
        }
704
0
        dst[i] = (n as u8) | 0b1000_0000;
705
0
        n >>= 7;
706
0
        i += 1;
707
    }
708
0
    if i >= dst.len() {
709
0
        return Err(SerializeError::buffer_too_small(what));
710
0
    }
711
0
    dst[i] = n as u8;
712
0
    Ok(i + 1)
713
0
}
714
715
/// Returns how many bytes `write_varu64` would use to encode `n`. That is one
/// byte per 7-bit group, with a minimum of one byte.
///
/// See: https://developers.google.com/protocol-buffers/docs/encoding#varints
#[allow(dead_code)]
pub fn write_varu64_len(mut n: u64) -> usize {
    let mut bytes = 1;
    while n > 0b0111_1111 {
        n >>= 7;
        bytes += 1;
    }
    bytes
}
728
729
/// Like read_varu64, but attempts to cast the result to usize. If the integer
730
/// cannot fit into a usize, then an error is returned.
731
#[allow(dead_code)]
732
0
pub fn read_varu64_as_usize(
733
0
    slice: &[u8],
734
0
    what: &'static str,
735
0
) -> Result<(usize, usize), DeserializeError> {
736
0
    let (n, nread) = read_varu64(slice, what)?;
737
0
    let n = usize::try_from(n)
738
0
        .map_err(|_| DeserializeError::invalid_usize(what))?;
739
0
    Ok((n, nread))
740
0
}
741
742
/// Reads a variable sized integer from the beginning of slice, and returns the
743
/// integer along with the total number of bytes read. If a valid variable
744
/// sized integer could not be found, then an error is returned that includes
745
/// the "what" description in it.
746
///
747
/// https://developers.google.com/protocol-buffers/docs/encoding#varints
748
#[allow(dead_code)]
749
0
pub fn read_varu64(
750
0
    slice: &[u8],
751
0
    what: &'static str,
752
0
) -> Result<(u64, usize), DeserializeError> {
753
0
    let mut n: u64 = 0;
754
0
    let mut shift: u32 = 0;
755
0
    // The biggest possible value is u64::MAX, which needs all 64 bits which
756
0
    // requires 10 bytes (because 7 * 9 < 64). We use a limit to avoid reading
757
0
    // an unnecessary number of bytes.
758
0
    let limit = cmp::min(slice.len(), 10);
759
0
    for (i, &b) in slice[..limit].iter().enumerate() {
760
0
        if b < 0b1000_0000 {
761
0
            return match (b as u64).checked_shl(shift) {
762
0
                None => Err(DeserializeError::invalid_varint(what)),
763
0
                Some(b) => Ok((n | b, i + 1)),
764
            };
765
0
        }
766
0
        match ((b as u64) & 0b0111_1111).checked_shl(shift) {
767
0
            None => return Err(DeserializeError::invalid_varint(what)),
768
0
            Some(b) => n |= b,
769
0
        }
770
0
        shift += 7;
771
    }
772
0
    Err(DeserializeError::invalid_varint(what))
773
0
}
774
775
/// Checks that the given slice has some minimal length. If it's smaller than
776
/// the bound given, then a "buffer too small" error is returned with `what`
777
/// describing what the buffer represents.
778
0
pub fn check_slice_len<T>(
779
0
    slice: &[T],
780
0
    at_least_len: usize,
781
0
    what: &'static str,
782
0
) -> Result<(), DeserializeError> {
783
0
    if slice.len() < at_least_len {
784
0
        return Err(DeserializeError::buffer_too_small(what));
785
0
    }
786
0
    Ok(())
787
0
}
788
789
/// Multiply the given numbers, and on overflow, return an error that includes
790
/// 'what' in the error message.
791
///
792
/// This is useful when doing arithmetic with untrusted data.
793
0
pub fn mul(
794
0
    a: usize,
795
0
    b: usize,
796
0
    what: &'static str,
797
0
) -> Result<usize, DeserializeError> {
798
0
    match a.checked_mul(b) {
799
0
        Some(c) => Ok(c),
800
0
        None => Err(DeserializeError::arithmetic_overflow(what)),
801
    }
802
0
}
803
804
/// Add the given numbers, and on overflow, return an error that includes
805
/// 'what' in the error message.
806
///
807
/// This is useful when doing arithmetic with untrusted data.
808
0
pub fn add(
809
0
    a: usize,
810
0
    b: usize,
811
0
    what: &'static str,
812
0
) -> Result<usize, DeserializeError> {
813
0
    match a.checked_add(b) {
814
0
        Some(c) => Ok(c),
815
0
        None => Err(DeserializeError::arithmetic_overflow(what)),
816
    }
817
0
}
818
819
/// Shift `a` left by `b`, and on overflow, return an error that includes
820
/// 'what' in the error message.
821
///
822
/// This is useful when doing arithmetic with untrusted data.
823
0
pub fn shl(
824
0
    a: usize,
825
0
    b: usize,
826
0
    what: &'static str,
827
0
) -> Result<usize, DeserializeError> {
828
0
    let amount = u32::try_from(b)
829
0
        .map_err(|_| DeserializeError::arithmetic_overflow(what))?;
830
0
    match a.checked_shl(amount) {
831
0
        Some(c) => Ok(c),
832
0
        None => Err(DeserializeError::arithmetic_overflow(what)),
833
    }
834
0
}
835
836
/// Abstracts over byte order so serialization code can be written once and
/// used for either little or big endian output.
///
/// Only the few operations this crate needs are provided, in the spirit of a
/// tiny subset of the `byteorder` crate.
pub trait Endian {
    /// Encodes `n` into the first 2 bytes of `dst` using this byte order.
    ///
    /// # Panics
    ///
    /// Panics when `dst` is shorter than 2 bytes.
    fn write_u16(n: u16, dst: &mut [u8]);

    /// Encodes `n` into the first 4 bytes of `dst` using this byte order.
    ///
    /// # Panics
    ///
    /// Panics when `dst` is shorter than 4 bytes.
    fn write_u32(n: u32, dst: &mut [u8]);

    /// Encodes `n` into the first 8 bytes of `dst` using this byte order.
    ///
    /// # Panics
    ///
    /// Panics when `dst` is shorter than 8 bytes.
    fn write_u64(n: u64, dst: &mut [u8]);
}
856
857
/// Uninhabited marker type that selects little endian byte order when writing.
pub enum LE {}

/// Uninhabited marker type that selects big endian byte order when writing.
pub enum BE {}

/// The native byte order of the compilation target (here: `LE`).
#[cfg(target_endian = "little")]
pub type NE = LE;

/// The native byte order of the compilation target (here: `BE`).
#[cfg(target_endian = "big")]
pub type NE = BE;
866
867
impl Endian for LE {
    // Each method slices `dst` to the width of the encoded value, so a
    // destination that is too short panics on the index before any copy.
    fn write_u16(n: u16, dst: &mut [u8]) {
        let bytes = n.to_le_bytes();
        dst[..bytes.len()].copy_from_slice(&bytes);
    }

    fn write_u32(n: u32, dst: &mut [u8]) {
        let bytes = n.to_le_bytes();
        dst[..bytes.len()].copy_from_slice(&bytes);
    }

    fn write_u64(n: u64, dst: &mut [u8]) {
        let bytes = n.to_le_bytes();
        dst[..bytes.len()].copy_from_slice(&bytes);
    }
}
880
881
impl Endian for BE {
    // Each method slices `dst` to the width of the encoded value, so a
    // destination that is too short panics on the index before any copy.
    fn write_u16(n: u16, dst: &mut [u8]) {
        let bytes = n.to_be_bytes();
        dst[..bytes.len()].copy_from_slice(&bytes);
    }

    fn write_u32(n: u32, dst: &mut [u8]) {
        let bytes = n.to_be_bytes();
        dst[..bytes.len()].copy_from_slice(&bytes);
    }

    fn write_u64(n: u64, dst: &mut [u8]) {
        let bytes = n.to_be_bytes();
        dst[..bytes.len()].copy_from_slice(&bytes);
    }
}
894
895
/// Returns the number of additional bytes required to add to the given length
896
/// in order to make the total length a multiple of 4. The return value is
897
/// always less than 4.
898
0
pub fn padding_len(non_padding_len: usize) -> usize {
899
0
    (4 - (non_padding_len & 0b11)) & 0b11
900
0
}
901
902
#[cfg(all(test, feature = "alloc"))]
mod tests {
    use super::*;

    #[test]
    fn labels() {
        let mut buf = [0; 1024];

        // "fooba" is 5 bytes; it is NUL-padded out to an 8 byte label.
        let nwrite = write_label("fooba", &mut buf).unwrap();
        assert_eq!(nwrite, 8);
        assert_eq!(&buf[..nwrite], b"fooba\x00\x00\x00");

        assert_eq!(read_label(&buf, "fooba").unwrap(), 8);
    }

    #[test]
    #[should_panic]
    fn bad_label_interior_nul() {
        // A label may not contain an interior NUL byte.
        write_label("foo\x00bar", &mut [0; 1024]).unwrap();
    }

    #[test]
    fn bad_label_almost_too_long() {
        // Exactly 255 bytes is the longest accepted label.
        write_label(&"z".repeat(255), &mut [0; 1024]).unwrap();
    }

    #[test]
    #[should_panic]
    fn bad_label_too_long() {
        // One byte past the 255 byte limit must be rejected.
        write_label(&"z".repeat(256), &mut [0; 1024]).unwrap();
    }

    #[test]
    fn padding() {
        let cases = [
            (8, 0),
            (9, 3),
            (10, 2),
            (11, 1),
            (12, 0),
            (13, 3),
            (14, 2),
            (15, 1),
            (16, 0),
        ];
        for &(len, expected) in cases.iter() {
            assert_eq!(expected, padding_len(len));
        }
    }
}