Coverage Report

Created: 2025-07-01 07:43

/rust/registry/src/index.crates.io-6f17d22bba15001f/fdeflate-0.3.7/src/decompress.rs
Line
Count
Source (jump to first uncovered line)
1
use simd_adler32::Adler32;
2
3
use crate::{
4
    huffman::{self, build_table},
5
    tables::{
6
        self, CLCL_ORDER, DIST_SYM_TO_DIST_BASE, DIST_SYM_TO_DIST_EXTRA, FIXED_DIST_TABLE,
7
        FIXED_LITLEN_TABLE, LEN_SYM_TO_LEN_BASE, LEN_SYM_TO_LEN_EXTRA, LITLEN_TABLE_ENTRIES,
8
    },
9
};
10
11
/// An error encountered while decompressing a deflate stream.
12
#[derive(Debug, PartialEq)]
13
pub enum DecompressionError {
14
    /// The zlib header is corrupt.
15
    BadZlibHeader,
16
    /// All input was consumed, but the end of the stream hasn't been reached.
17
    InsufficientInput,
18
    /// A block header specifies an invalid block type.
19
    InvalidBlockType,
20
    /// An uncompressed block's NLEN value is invalid.
21
    InvalidUncompressedBlockLength,
22
    /// Too many literals were specified.
23
    InvalidHlit,
24
    /// Too many distance codes were specified.
25
    InvalidHdist,
26
    /// Attempted to repeat a previous code before reading any codes, or past the end of the code
27
    /// lengths.
28
    InvalidCodeLengthRepeat,
29
    /// The stream doesn't specify a valid huffman tree.
30
    BadCodeLengthHuffmanTree,
31
    /// The stream doesn't specify a valid huffman tree.
32
    BadLiteralLengthHuffmanTree,
33
    /// The stream doesn't specify a valid huffman tree.
34
    BadDistanceHuffmanTree,
35
    /// The stream contains a literal/length code that was not allowed by the header.
36
    InvalidLiteralLengthCode,
37
    /// The stream contains a distance code that was not allowed by the header.
38
    InvalidDistanceCode,
39
    /// The stream contains contains back-reference as the first symbol.
40
    InputStartsWithRun,
41
    /// The stream contains a back-reference that is too far back.
42
    DistanceTooFarBack,
43
    /// The deflate stream checksum is incorrect.
44
    WrongChecksum,
45
    /// Extra input data.
46
    ExtraInput,
47
}
48
49
struct BlockHeader {
50
    hlit: usize,
51
    hdist: usize,
52
    hclen: usize,
53
    num_lengths_read: usize,
54
55
    /// Low 3-bits are code length code length, high 5-bits are code length code.
56
    table: [u32; 128],
57
    code_lengths: [u8; 320],
58
}
59
60
pub const LITERAL_ENTRY: u32 = 0x8000;
61
pub const EXCEPTIONAL_ENTRY: u32 = 0x4000;
62
pub const SECONDARY_TABLE_ENTRY: u32 = 0x2000;
63
64
/// The Decompressor state for a compressed block.
65
#[derive(Eq, PartialEq, Debug)]
66
struct CompressedBlock {
67
    litlen_table: Box<[u32; 4096]>,
68
    secondary_table: Vec<u16>,
69
70
    dist_table: Box<[u32; 512]>,
71
    dist_secondary_table: Vec<u16>,
72
73
    eof_code: u16,
74
    eof_mask: u16,
75
    eof_bits: u8,
76
}
77
78
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
79
enum State {
80
    ZlibHeader,
81
    BlockHeader,
82
    CodeLengthCodes,
83
    CodeLengths,
84
    CompressedData,
85
    UncompressedData,
86
    Checksum,
87
    Done,
88
}
89
90
/// Decompressor for arbitrary zlib streams.
91
pub struct Decompressor {
92
    /// State for decoding a compressed block.
93
    compression: CompressedBlock,
94
    // State for decoding a block header.
95
    header: BlockHeader,
96
    // Number of bytes left for uncompressed block.
97
    uncompressed_bytes_left: u16,
98
99
    buffer: u64,
100
    nbits: u8,
101
102
    queued_rle: Option<(u8, usize)>,
103
    queued_backref: Option<(usize, usize)>,
104
    last_block: bool,
105
    fixed_table: bool,
106
107
    state: State,
108
    checksum: Adler32,
109
    ignore_adler32: bool,
110
}
111
112
impl Default for Decompressor {
113
0
    fn default() -> Self {
114
0
        Self::new()
115
0
    }
116
}
117
118
impl Decompressor {
119
    /// Create a new decompressor.
120
0
    pub fn new() -> Self {
121
0
        Self {
122
0
            buffer: 0,
123
0
            nbits: 0,
124
0
            compression: CompressedBlock {
125
0
                litlen_table: Box::new([0; 4096]),
126
0
                dist_table: Box::new([0; 512]),
127
0
                secondary_table: Vec::new(),
128
0
                dist_secondary_table: Vec::new(),
129
0
                eof_code: 0,
130
0
                eof_mask: 0,
131
0
                eof_bits: 0,
132
0
            },
133
0
            header: BlockHeader {
134
0
                hlit: 0,
135
0
                hdist: 0,
136
0
                hclen: 0,
137
0
                table: [0; 128],
138
0
                num_lengths_read: 0,
139
0
                code_lengths: [0; 320],
140
0
            },
141
0
            uncompressed_bytes_left: 0,
142
0
            queued_rle: None,
143
0
            queued_backref: None,
144
0
            checksum: Adler32::new(),
145
0
            state: State::ZlibHeader,
146
0
            last_block: false,
147
0
            ignore_adler32: false,
148
0
            fixed_table: false,
149
0
        }
150
0
    }
151
152
    /// Ignore the checksum at the end of the stream.
153
0
    pub fn ignore_adler32(&mut self) {
154
0
        self.ignore_adler32 = true;
155
0
    }
156
157
0
    fn fill_buffer(&mut self, input: &mut &[u8]) {
158
0
        if input.len() >= 8 {
159
0
            self.buffer |= u64::from_le_bytes(input[..8].try_into().unwrap()) << self.nbits;
160
0
            *input = &input[(63 - self.nbits as usize) / 8..];
161
0
            self.nbits |= 56;
162
0
        } else {
163
0
            let nbytes = input.len().min((63 - self.nbits as usize) / 8);
164
0
            let mut input_data = [0; 8];
165
0
            input_data[..nbytes].copy_from_slice(&input[..nbytes]);
166
0
            self.buffer |= u64::from_le_bytes(input_data)
167
0
                .checked_shl(self.nbits as u32)
168
0
                .unwrap_or(0);
169
0
            self.nbits += nbytes as u8 * 8;
170
0
            *input = &input[nbytes..];
171
0
        }
172
0
    }
173
174
0
    fn peak_bits(&mut self, nbits: u8) -> u64 {
175
0
        debug_assert!(nbits <= 56 && nbits <= self.nbits);
176
0
        self.buffer & ((1u64 << nbits) - 1)
177
0
    }
178
0
    fn consume_bits(&mut self, nbits: u8) {
179
0
        debug_assert!(self.nbits >= nbits);
180
0
        self.buffer >>= nbits;
181
0
        self.nbits -= nbits;
182
0
    }
183
184
0
    fn read_block_header(&mut self, remaining_input: &mut &[u8]) -> Result<(), DecompressionError> {
185
0
        self.fill_buffer(remaining_input);
186
0
        if self.nbits < 10 {
187
0
            return Ok(());
188
0
        }
189
0
190
0
        let start = self.peak_bits(3);
191
0
        self.last_block = start & 1 != 0;
192
0
        match start >> 1 {
193
            0b00 => {
194
0
                let align_bits = (self.nbits - 3) % 8;
195
0
                let header_bits = 3 + 32 + align_bits;
196
0
                if self.nbits < header_bits {
197
0
                    return Ok(());
198
0
                }
199
0
200
0
                let len = (self.peak_bits(align_bits + 19) >> (align_bits + 3)) as u16;
201
0
                let nlen = (self.peak_bits(header_bits) >> (align_bits + 19)) as u16;
202
0
                if nlen != !len {
203
0
                    return Err(DecompressionError::InvalidUncompressedBlockLength);
204
0
                }
205
0
206
0
                self.state = State::UncompressedData;
207
0
                self.uncompressed_bytes_left = len;
208
0
                self.consume_bits(header_bits);
209
0
                Ok(())
210
            }
211
            0b01 => {
212
0
                self.consume_bits(3);
213
0
214
0
                // Check for an entirely empty blocks which can happen if there are "partial
215
0
                // flushes" in the deflate stream. With fixed huffman codes, the EOF symbol is
216
0
                // 7-bits of zeros so we peak ahead and see if the next 7-bits are all zero.
217
0
                if self.peak_bits(7) == 0 {
218
0
                    self.consume_bits(7);
219
0
                    if self.last_block {
220
0
                        self.state = State::Checksum;
221
0
                        return Ok(());
222
0
                    }
223
224
                    // At this point we've consumed the entire block and need to read the next block
225
                    // header. If tail call optimization were guaranteed, we could just recurse
226
                    // here. But without it, a long sequence of empty fixed-blocks might cause a
227
                    // stack overflow. Instead, we consume all empty blocks in a loop and then
228
                    // recurse. This is the only recursive call this function, and thus is safe.
229
0
                    while self.nbits >= 10 && self.peak_bits(10) == 0b010 {
230
0
                        self.consume_bits(10);
231
0
                        self.fill_buffer(remaining_input);
232
0
                    }
233
0
                    return self.read_block_header(remaining_input);
234
0
                }
235
0
236
0
                // Build decoding tables if the previous block wasn't also a fixed block.
237
0
                if !self.fixed_table {
238
0
                    self.fixed_table = true;
239
0
                    for chunk in self.compression.litlen_table.chunks_exact_mut(512) {
240
0
                        chunk.copy_from_slice(&FIXED_LITLEN_TABLE);
241
0
                    }
242
0
                    for chunk in self.compression.dist_table.chunks_exact_mut(32) {
243
0
                        chunk.copy_from_slice(&FIXED_DIST_TABLE);
244
0
                    }
245
0
                    self.compression.eof_bits = 7;
246
0
                    self.compression.eof_code = 0;
247
0
                    self.compression.eof_mask = 0x7f;
248
0
                }
249
250
0
                self.state = State::CompressedData;
251
0
                Ok(())
252
            }
253
            0b10 => {
254
0
                if self.nbits < 17 {
255
0
                    return Ok(());
256
0
                }
257
0
258
0
                self.header.hlit = (self.peak_bits(8) >> 3) as usize + 257;
259
0
                self.header.hdist = (self.peak_bits(13) >> 8) as usize + 1;
260
0
                self.header.hclen = (self.peak_bits(17) >> 13) as usize + 4;
261
0
                if self.header.hlit > 286 {
262
0
                    return Err(DecompressionError::InvalidHlit);
263
0
                }
264
0
                if self.header.hdist > 30 {
265
0
                    return Err(DecompressionError::InvalidHdist);
266
0
                }
267
0
268
0
                self.consume_bits(17);
269
0
                self.state = State::CodeLengthCodes;
270
0
                self.fixed_table = false;
271
0
                Ok(())
272
            }
273
0
            0b11 => Err(DecompressionError::InvalidBlockType),
274
0
            _ => unreachable!(),
275
        }
276
0
    }
277
278
0
    fn read_code_length_codes(
279
0
        &mut self,
280
0
        remaining_input: &mut &[u8],
281
0
    ) -> Result<(), DecompressionError> {
282
0
        self.fill_buffer(remaining_input);
283
0
        if self.nbits as usize + remaining_input.len() * 8 < 3 * self.header.hclen {
284
0
            return Ok(());
285
0
        }
286
0
287
0
        let mut code_length_lengths = [0; 19];
288
0
        for i in 0..self.header.hclen {
289
0
            code_length_lengths[CLCL_ORDER[i]] = self.peak_bits(3) as u8;
290
0
            self.consume_bits(3);
291
0
292
0
            // We need to refill the buffer after reading 3 * 18 = 54 bits since the buffer holds
293
0
            // between 56 and 63 bits total.
294
0
            if i == 17 {
295
0
                self.fill_buffer(remaining_input);
296
0
            }
297
        }
298
299
0
        let mut codes = [0; 19];
300
0
        if !build_table(
301
0
            &code_length_lengths,
302
0
            &[],
303
0
            &mut codes,
304
0
            &mut self.header.table,
305
0
            &mut Vec::new(),
306
0
            false,
307
0
            false,
308
0
        ) {
309
0
            return Err(DecompressionError::BadCodeLengthHuffmanTree);
310
0
        }
311
0
312
0
        self.state = State::CodeLengths;
313
0
        self.header.num_lengths_read = 0;
314
0
        Ok(())
315
0
    }
316
317
0
    fn read_code_lengths(&mut self, remaining_input: &mut &[u8]) -> Result<(), DecompressionError> {
318
0
        let total_lengths = self.header.hlit + self.header.hdist;
319
0
        while self.header.num_lengths_read < total_lengths {
320
0
            self.fill_buffer(remaining_input);
321
0
            if self.nbits < 7 {
322
0
                return Ok(());
323
0
            }
324
0
325
0
            let code = self.peak_bits(7);
326
0
            let entry = self.header.table[code as usize];
327
0
            let length = (entry & 0x7) as u8;
328
0
            let symbol = (entry >> 16) as u8;
329
0
330
0
            debug_assert!(length != 0);
331
0
            match symbol {
332
0
                0..=15 => {
333
0
                    self.header.code_lengths[self.header.num_lengths_read] = symbol;
334
0
                    self.header.num_lengths_read += 1;
335
0
                    self.consume_bits(length);
336
0
                }
337
0
                16..=18 => {
338
0
                    let (base_repeat, extra_bits) = match symbol {
339
0
                        16 => (3, 2),
340
0
                        17 => (3, 3),
341
0
                        18 => (11, 7),
342
0
                        _ => unreachable!(),
343
                    };
344
345
0
                    if self.nbits < length + extra_bits {
346
0
                        return Ok(());
347
0
                    }
348
349
0
                    let value = match symbol {
350
                        16 => {
351
0
                            self.header.code_lengths[self
352
0
                                .header
353
0
                                .num_lengths_read
354
0
                                .checked_sub(1)
355
0
                                .ok_or(DecompressionError::InvalidCodeLengthRepeat)?]
356
                            // TODO: is this right?
357
                        }
358
0
                        17 => 0,
359
0
                        18 => 0,
360
0
                        _ => unreachable!(),
361
                    };
362
363
0
                    let repeat =
364
0
                        (self.peak_bits(length + extra_bits) >> length) as usize + base_repeat;
365
0
                    if self.header.num_lengths_read + repeat > total_lengths {
366
0
                        return Err(DecompressionError::InvalidCodeLengthRepeat);
367
0
                    }
368
369
0
                    for i in 0..repeat {
370
0
                        self.header.code_lengths[self.header.num_lengths_read + i] = value;
371
0
                    }
372
0
                    self.header.num_lengths_read += repeat;
373
0
                    self.consume_bits(length + extra_bits);
374
                }
375
0
                _ => unreachable!(),
376
            }
377
        }
378
379
0
        self.header
380
0
            .code_lengths
381
0
            .copy_within(self.header.hlit..total_lengths, 288);
382
0
        for i in self.header.hlit..288 {
383
0
            self.header.code_lengths[i] = 0;
384
0
        }
385
0
        for i in 288 + self.header.hdist..320 {
386
0
            self.header.code_lengths[i] = 0;
387
0
        }
388
389
0
        Self::build_tables(
390
0
            self.header.hlit,
391
0
            &self.header.code_lengths,
392
0
            &mut self.compression,
393
0
        )?;
394
0
        self.state = State::CompressedData;
395
0
        Ok(())
396
0
    }
397
398
0
    fn build_tables(
399
0
        hlit: usize,
400
0
        code_lengths: &[u8],
401
0
        compression: &mut CompressedBlock,
402
0
    ) -> Result<(), DecompressionError> {
403
0
        // If there is no code assigned for the EOF symbol then the bitstream is invalid.
404
0
        if code_lengths[256] == 0 {
405
            // TODO: Return a dedicated error in this case.
406
0
            return Err(DecompressionError::BadLiteralLengthHuffmanTree);
407
0
        }
408
0
409
0
        let mut codes = [0; 288];
410
0
        compression.secondary_table.clear();
411
0
        if !huffman::build_table(
412
0
            &code_lengths[..hlit],
413
0
            &LITLEN_TABLE_ENTRIES,
414
0
            &mut codes[..hlit],
415
0
            &mut *compression.litlen_table,
416
0
            &mut compression.secondary_table,
417
0
            false,
418
0
            true,
419
0
        ) {
420
0
            return Err(DecompressionError::BadCodeLengthHuffmanTree);
421
0
        }
422
0
423
0
        compression.eof_code = codes[256];
424
0
        compression.eof_mask = (1 << code_lengths[256]) - 1;
425
0
        compression.eof_bits = code_lengths[256];
426
0
427
0
        // Build the distance code table.
428
0
        let lengths = &code_lengths[288..320];
429
0
        if lengths == [0; 32] {
430
0
            compression.dist_table.fill(0);
431
0
        } else {
432
0
            let mut dist_codes = [0; 32];
433
0
            if !huffman::build_table(
434
0
                lengths,
435
0
                &tables::DISTANCE_TABLE_ENTRIES,
436
0
                &mut dist_codes,
437
0
                &mut *compression.dist_table,
438
0
                &mut compression.dist_secondary_table,
439
0
                true,
440
0
                false,
441
0
            ) {
442
0
                return Err(DecompressionError::BadDistanceHuffmanTree);
443
0
            }
444
        }
445
446
0
        Ok(())
447
0
    }
448
449
0
    fn read_compressed(
450
0
        &mut self,
451
0
        remaining_input: &mut &[u8],
452
0
        output: &mut [u8],
453
0
        mut output_index: usize,
454
0
    ) -> Result<usize, DecompressionError> {
455
0
        // Fast decoding loop.
456
0
        //
457
0
        // This loop is optimized for speed and is the main decoding loop for the decompressor,
458
0
        // which is used when there are at least 8 bytes of input and output data available. It
459
0
        // assumes that the bitbuffer is full (nbits >= 56) and that litlen_entry has been loaded.
460
0
        //
461
0
        // These assumptions enable a few optimizations:
462
0
        // - Nearly all checks for nbits are avoided.
463
0
        // - Checking the input size is optimized out in the refill function call.
464
0
        // - The litlen_entry for the next loop iteration can be loaded in parallel with refilling
465
0
        //   the bit buffer. This is because when the input is non-empty, the bit buffer actually
466
0
        //   has 64-bits of valid data (even though nbits will be in 56..=63).
467
0
        self.fill_buffer(remaining_input);
468
0
        let mut litlen_entry = self.compression.litlen_table[(self.buffer & 0xfff) as usize];
469
0
        while self.state == State::CompressedData
470
0
            && output_index + 8 <= output.len()
471
0
            && remaining_input.len() >= 8
472
        {
473
            // First check whether the next symbol is a literal. This code does up to 2 additional
474
            // table lookups to decode more literals.
475
            let mut bits;
476
0
            let mut litlen_code_bits = litlen_entry as u8;
477
0
            if litlen_entry & LITERAL_ENTRY != 0 {
478
0
                let litlen_entry2 = self.compression.litlen_table
479
0
                    [(self.buffer >> litlen_code_bits & 0xfff) as usize];
480
0
                let litlen_code_bits2 = litlen_entry2 as u8;
481
0
                let litlen_entry3 = self.compression.litlen_table
482
0
                    [(self.buffer >> (litlen_code_bits + litlen_code_bits2) & 0xfff) as usize];
483
0
                let litlen_code_bits3 = litlen_entry3 as u8;
484
0
                let litlen_entry4 = self.compression.litlen_table[(self.buffer
485
0
                    >> (litlen_code_bits + litlen_code_bits2 + litlen_code_bits3)
486
0
                    & 0xfff)
487
0
                    as usize];
488
0
489
0
                let advance_output_bytes = ((litlen_entry & 0xf00) >> 8) as usize;
490
0
                output[output_index] = (litlen_entry >> 16) as u8;
491
0
                output[output_index + 1] = (litlen_entry >> 24) as u8;
492
0
                output_index += advance_output_bytes;
493
0
494
0
                if litlen_entry2 & LITERAL_ENTRY != 0 {
495
0
                    let advance_output_bytes2 = ((litlen_entry2 & 0xf00) >> 8) as usize;
496
0
                    output[output_index] = (litlen_entry2 >> 16) as u8;
497
0
                    output[output_index + 1] = (litlen_entry2 >> 24) as u8;
498
0
                    output_index += advance_output_bytes2;
499
0
500
0
                    if litlen_entry3 & LITERAL_ENTRY != 0 {
501
0
                        let advance_output_bytes3 = ((litlen_entry3 & 0xf00) >> 8) as usize;
502
0
                        output[output_index] = (litlen_entry3 >> 16) as u8;
503
0
                        output[output_index + 1] = (litlen_entry3 >> 24) as u8;
504
0
                        output_index += advance_output_bytes3;
505
0
506
0
                        litlen_entry = litlen_entry4;
507
0
                        self.consume_bits(litlen_code_bits + litlen_code_bits2 + litlen_code_bits3);
508
0
                        self.fill_buffer(remaining_input);
509
0
                        continue;
510
0
                    } else {
511
0
                        self.consume_bits(litlen_code_bits + litlen_code_bits2);
512
0
                        litlen_entry = litlen_entry3;
513
0
                        litlen_code_bits = litlen_code_bits3;
514
0
                        self.fill_buffer(remaining_input);
515
0
                        bits = self.buffer;
516
0
                    }
517
                } else {
518
0
                    self.consume_bits(litlen_code_bits);
519
0
                    bits = self.buffer;
520
0
                    litlen_entry = litlen_entry2;
521
0
                    litlen_code_bits = litlen_code_bits2;
522
0
                    if self.nbits < 48 {
523
0
                        self.fill_buffer(remaining_input);
524
0
                    }
525
                }
526
0
            } else {
527
0
                bits = self.buffer;
528
0
            }
529
530
            // The next symbol is either a 13+ bit literal, back-reference, or an EOF symbol.
531
0
            let (length_base, length_extra_bits, litlen_code_bits) =
532
0
                if litlen_entry & EXCEPTIONAL_ENTRY == 0 {
533
0
                    (
534
0
                        litlen_entry >> 16,
535
0
                        (litlen_entry >> 8) as u8,
536
0
                        litlen_code_bits,
537
0
                    )
538
0
                } else if litlen_entry & SECONDARY_TABLE_ENTRY != 0 {
539
0
                    let secondary_table_index =
540
0
                        (litlen_entry >> 16) + ((bits >> 12) as u32 & (litlen_entry & 0xff));
541
0
                    let secondary_entry =
542
0
                        self.compression.secondary_table[secondary_table_index as usize];
543
0
                    let litlen_symbol = secondary_entry >> 4;
544
0
                    let litlen_code_bits = (secondary_entry & 0xf) as u8;
545
0
546
0
                    match litlen_symbol {
547
0
                        0..=255 => {
548
0
                            self.consume_bits(litlen_code_bits);
549
0
                            litlen_entry =
550
0
                                self.compression.litlen_table[(self.buffer & 0xfff) as usize];
551
0
                            self.fill_buffer(remaining_input);
552
0
                            output[output_index] = litlen_symbol as u8;
553
0
                            output_index += 1;
554
0
                            continue;
555
                        }
556
                        256 => {
557
0
                            self.consume_bits(litlen_code_bits);
558
0
                            self.state = match self.last_block {
559
0
                                true => State::Checksum,
560
0
                                false => State::BlockHeader,
561
                            };
562
0
                            break;
563
                        }
564
0
                        _ => (
565
0
                            LEN_SYM_TO_LEN_BASE[litlen_symbol as usize - 257] as u32,
566
0
                            LEN_SYM_TO_LEN_EXTRA[litlen_symbol as usize - 257],
567
0
                            litlen_code_bits,
568
0
                        ),
569
                    }
570
0
                } else if litlen_code_bits == 0 {
571
0
                    return Err(DecompressionError::InvalidLiteralLengthCode);
572
                } else {
573
0
                    self.consume_bits(litlen_code_bits);
574
0
                    self.state = match self.last_block {
575
0
                        true => State::Checksum,
576
0
                        false => State::BlockHeader,
577
                    };
578
0
                    break;
579
                };
580
0
            bits >>= litlen_code_bits;
581
0
582
0
            let length_extra_mask = (1 << length_extra_bits) - 1;
583
0
            let length = length_base as usize + (bits & length_extra_mask) as usize;
584
0
            bits >>= length_extra_bits;
585
0
586
0
            let dist_entry = self.compression.dist_table[(bits & 0x1ff) as usize];
587
0
            let (dist_base, dist_extra_bits, dist_code_bits) = if dist_entry & LITERAL_ENTRY != 0 {
588
0
                (
589
0
                    (dist_entry >> 16) as u16,
590
0
                    (dist_entry >> 8) as u8 & 0xf,
591
0
                    dist_entry as u8,
592
0
                )
593
0
            } else if dist_entry >> 8 == 0 {
594
0
                return Err(DecompressionError::InvalidDistanceCode);
595
            } else {
596
0
                let secondary_table_index =
597
0
                    (dist_entry >> 16) + ((bits >> 9) as u32 & (dist_entry & 0xff));
598
0
                let secondary_entry =
599
0
                    self.compression.dist_secondary_table[secondary_table_index as usize];
600
0
                let dist_symbol = (secondary_entry >> 4) as usize;
601
0
                if dist_symbol >= 30 {
602
0
                    return Err(DecompressionError::InvalidDistanceCode);
603
0
                }
604
0
605
0
                (
606
0
                    DIST_SYM_TO_DIST_BASE[dist_symbol],
607
0
                    DIST_SYM_TO_DIST_EXTRA[dist_symbol],
608
0
                    (secondary_entry & 0xf) as u8,
609
0
                )
610
            };
611
0
            bits >>= dist_code_bits;
612
0
613
0
            let dist = dist_base as usize + (bits & ((1 << dist_extra_bits) - 1)) as usize;
614
0
            if dist > output_index {
615
0
                return Err(DecompressionError::DistanceTooFarBack);
616
0
            }
617
0
618
0
            self.consume_bits(
619
0
                litlen_code_bits + length_extra_bits + dist_code_bits + dist_extra_bits,
620
0
            );
621
0
            self.fill_buffer(remaining_input);
622
0
            litlen_entry = self.compression.litlen_table[(self.buffer & 0xfff) as usize];
623
0
624
0
            let copy_length = length.min(output.len() - output_index);
625
0
            if dist == 1 {
626
0
                let last = output[output_index - 1];
627
0
                output[output_index..][..copy_length].fill(last);
628
0
629
0
                if copy_length < length {
630
0
                    self.queued_rle = Some((last, length - copy_length));
631
0
                    output_index = output.len();
632
0
                    break;
633
0
                }
634
0
            } else if output_index + length + 15 <= output.len() {
635
0
                let start = output_index - dist;
636
0
                output.copy_within(start..start + 16, output_index);
637
0
638
0
                if length > 16 || dist < 16 {
639
0
                    for i in (0..length).step_by(dist.min(16)).skip(1) {
640
0
                        output.copy_within(start + i..start + i + 16, output_index + i);
641
0
                    }
642
0
                }
643
            } else {
644
0
                if dist < copy_length {
645
0
                    for i in 0..copy_length {
646
0
                        output[output_index + i] = output[output_index + i - dist];
647
0
                    }
648
                } else {
649
0
                    output.copy_within(
650
0
                        output_index - dist..output_index + copy_length - dist,
651
0
                        output_index,
652
0
                    )
653
                }
654
655
0
                if copy_length < length {
656
0
                    self.queued_backref = Some((dist, length - copy_length));
657
0
                    output_index = output.len();
658
0
                    break;
659
0
                }
660
            }
661
0
            output_index += copy_length;
662
        }
663
664
        // Careful decoding loop.
665
        //
666
        // This loop processes the remaining input when we're too close to the end of the input or
667
        // output to use the fast loop.
668
0
        while let State::CompressedData = self.state {
669
0
            self.fill_buffer(remaining_input);
670
0
            if output_index == output.len() {
671
0
                break;
672
0
            }
673
0
674
0
            let mut bits = self.buffer;
675
0
            let litlen_entry = self.compression.litlen_table[(bits & 0xfff) as usize];
676
0
            let litlen_code_bits = litlen_entry as u8;
677
0
678
0
            if litlen_entry & LITERAL_ENTRY != 0 {
679
                // Fast path: the next symbol is <= 12 bits and a literal, the table specifies the
680
                // output bytes and we can directly write them to the output buffer.
681
0
                let advance_output_bytes = ((litlen_entry & 0xf00) >> 8) as usize;
682
0
683
0
                if self.nbits < litlen_code_bits {
684
0
                    break;
685
0
                } else if output_index + 1 < output.len() {
686
0
                    output[output_index] = (litlen_entry >> 16) as u8;
687
0
                    output[output_index + 1] = (litlen_entry >> 24) as u8;
688
0
                    output_index += advance_output_bytes;
689
0
                    self.consume_bits(litlen_code_bits);
690
0
                    continue;
691
0
                } else if output_index + advance_output_bytes == output.len() {
692
0
                    debug_assert_eq!(advance_output_bytes, 1);
693
0
                    output[output_index] = (litlen_entry >> 16) as u8;
694
0
                    output_index += 1;
695
0
                    self.consume_bits(litlen_code_bits);
696
0
                    break;
697
                } else {
698
0
                    debug_assert_eq!(advance_output_bytes, 2);
699
0
                    output[output_index] = (litlen_entry >> 16) as u8;
700
0
                    self.queued_rle = Some(((litlen_entry >> 24) as u8, 1));
701
0
                    output_index += 1;
702
0
                    self.consume_bits(litlen_code_bits);
703
0
                    break;
704
                }
705
0
            }
706
707
0
            let (length_base, length_extra_bits, litlen_code_bits) =
708
0
                if litlen_entry & EXCEPTIONAL_ENTRY == 0 {
709
0
                    (
710
0
                        litlen_entry >> 16,
711
0
                        (litlen_entry >> 8) as u8,
712
0
                        litlen_code_bits,
713
0
                    )
714
0
                } else if litlen_entry & SECONDARY_TABLE_ENTRY != 0 {
715
0
                    let secondary_table_index =
716
0
                        (litlen_entry >> 16) + ((bits >> 12) as u32 & (litlen_entry & 0xff));
717
0
                    let secondary_entry =
718
0
                        self.compression.secondary_table[secondary_table_index as usize];
719
0
                    let litlen_symbol = secondary_entry >> 4;
720
0
                    let litlen_code_bits = (secondary_entry & 0xf) as u8;
721
0
722
0
                    if self.nbits < litlen_code_bits {
723
0
                        break;
724
0
                    } else if litlen_symbol < 256 {
725
0
                        self.consume_bits(litlen_code_bits);
726
0
                        output[output_index] = litlen_symbol as u8;
727
0
                        output_index += 1;
728
0
                        continue;
729
0
                    } else if litlen_symbol == 256 {
730
0
                        self.consume_bits(litlen_code_bits);
731
0
                        self.state = match self.last_block {
732
0
                            true => State::Checksum,
733
0
                            false => State::BlockHeader,
734
                        };
735
0
                        break;
736
0
                    }
737
0
738
0
                    (
739
0
                        LEN_SYM_TO_LEN_BASE[litlen_symbol as usize - 257] as u32,
740
0
                        LEN_SYM_TO_LEN_EXTRA[litlen_symbol as usize - 257],
741
0
                        litlen_code_bits,
742
0
                    )
743
0
                } else if litlen_code_bits == 0 {
744
0
                    return Err(DecompressionError::InvalidLiteralLengthCode);
745
                } else {
746
0
                    if self.nbits < litlen_code_bits {
747
0
                        break;
748
0
                    }
749
0
                    self.consume_bits(litlen_code_bits);
750
0
                    self.state = match self.last_block {
751
0
                        true => State::Checksum,
752
0
                        false => State::BlockHeader,
753
                    };
754
0
                    break;
755
                };
756
0
            bits >>= litlen_code_bits;
757
0
758
0
            let length_extra_mask = (1 << length_extra_bits) - 1;
759
0
            let length = length_base as usize + (bits & length_extra_mask) as usize;
760
0
            bits >>= length_extra_bits;
761
0
762
0
            let dist_entry = self.compression.dist_table[(bits & 0x1ff) as usize];
763
0
            let (dist_base, dist_extra_bits, dist_code_bits) = if dist_entry & LITERAL_ENTRY != 0 {
764
0
                (
765
0
                    (dist_entry >> 16) as u16,
766
0
                    (dist_entry >> 8) as u8 & 0xf,
767
0
                    dist_entry as u8,
768
0
                )
769
0
            } else if self.nbits > litlen_code_bits + length_extra_bits + 9 {
770
0
                if dist_entry >> 8 == 0 {
771
0
                    return Err(DecompressionError::InvalidDistanceCode);
772
0
                }
773
0
774
0
                let secondary_table_index =
775
0
                    (dist_entry >> 16) + ((bits >> 9) as u32 & (dist_entry & 0xff));
776
0
                let secondary_entry =
777
0
                    self.compression.dist_secondary_table[secondary_table_index as usize];
778
0
                let dist_symbol = (secondary_entry >> 4) as usize;
779
0
                if dist_symbol >= 30 {
780
0
                    return Err(DecompressionError::InvalidDistanceCode);
781
0
                }
782
0
783
0
                (
784
0
                    DIST_SYM_TO_DIST_BASE[dist_symbol],
785
0
                    DIST_SYM_TO_DIST_EXTRA[dist_symbol],
786
0
                    (secondary_entry & 0xf) as u8,
787
0
                )
788
            } else {
789
0
                break;
790
            };
791
0
            bits >>= dist_code_bits;
792
0
793
0
            let dist = dist_base as usize + (bits & ((1 << dist_extra_bits) - 1)) as usize;
794
0
            let total_bits =
795
0
                litlen_code_bits + length_extra_bits + dist_code_bits + dist_extra_bits;
796
0
797
0
            if self.nbits < total_bits {
798
0
                break;
799
0
            } else if dist > output_index {
800
0
                return Err(DecompressionError::DistanceTooFarBack);
801
0
            }
802
0
803
0
            self.consume_bits(total_bits);
804
0
805
0
            let copy_length = length.min(output.len() - output_index);
806
0
            if dist == 1 {
807
0
                let last = output[output_index - 1];
808
0
                output[output_index..][..copy_length].fill(last);
809
0
810
0
                if copy_length < length {
811
0
                    self.queued_rle = Some((last, length - copy_length));
812
0
                    output_index = output.len();
813
0
                    break;
814
0
                }
815
0
            } else if output_index + length + 15 <= output.len() {
816
0
                let start = output_index - dist;
817
0
                output.copy_within(start..start + 16, output_index);
818
0
819
0
                if length > 16 || dist < 16 {
820
0
                    for i in (0..length).step_by(dist.min(16)).skip(1) {
821
0
                        output.copy_within(start + i..start + i + 16, output_index + i);
822
0
                    }
823
0
                }
824
            } else {
825
0
                if dist < copy_length {
826
0
                    for i in 0..copy_length {
827
0
                        output[output_index + i] = output[output_index + i - dist];
828
0
                    }
829
                } else {
830
0
                    output.copy_within(
831
0
                        output_index - dist..output_index + copy_length - dist,
832
0
                        output_index,
833
0
                    )
834
                }
835
836
0
                if copy_length < length {
837
0
                    self.queued_backref = Some((dist, length - copy_length));
838
0
                    output_index = output.len();
839
0
                    break;
840
0
                }
841
            }
842
0
            output_index += copy_length;
843
        }
844
845
0
        if self.state == State::CompressedData
846
0
            && self.queued_backref.is_none()
847
0
            && self.queued_rle.is_none()
848
0
            && self.nbits >= 15
849
0
            && self.peak_bits(15) as u16 & self.compression.eof_mask == self.compression.eof_code
850
        {
851
0
            self.consume_bits(self.compression.eof_bits);
852
0
            self.state = match self.last_block {
853
0
                true => State::Checksum,
854
0
                false => State::BlockHeader,
855
            };
856
0
        }
857
858
0
        Ok(output_index)
859
0
    }
860
861
    /// Decompresses a chunk of data.
    ///
    /// Returns the number of bytes read from `input` and the number of bytes written to `output`,
    /// or an error if the deflate stream is not valid. `input` is the compressed data. `output` is
    /// the buffer to write the decompressed data to, starting at index `output_position`.
    /// `end_of_input` indicates whether more data may be available in the future.
    ///
    /// The contents of `output` after `output_position` are ignored. However, this function may
    /// write additional data to `output` past what is indicated by the return value.
    ///
    /// When this function returns `Ok`, at least one of the following is true:
    /// - The input is fully consumed.
    /// - The output is full but there are more bytes to output.
    /// - The deflate stream is complete (and `is_done` will return true).
    ///
    /// # Panics
    ///
    /// This function will panic if `output_position` is out of bounds.
    pub fn read(
        &mut self,
        input: &[u8],
        output: &mut [u8],
        output_position: usize,
        end_of_input: bool,
    ) -> Result<(usize, usize), DecompressionError> {
        // A finished stream consumes nothing and produces nothing.
        if let State::Done = self.state {
            return Ok((0, 0));
        }

        assert!(output_position <= output.len());

        let mut remaining_input = input;
        let mut output_index = output_position;

        // Flush any run-length fill queued on a previous call when the output buffer filled up.
        // If it still doesn't fit, re-queue the remainder and return early (0 bytes consumed).
        if let Some((data, len)) = self.queued_rle.take() {
            let n = len.min(output.len() - output_index);
            output[output_index..][..n].fill(data);
            output_index += n;
            if n < len {
                self.queued_rle = Some((data, len - n));
                return Ok((0, n));
            }
        }
        // Likewise flush any queued back-reference. The byte-by-byte loop is deliberate: the
        // source and destination ranges may overlap when `dist` is smaller than the copy length.
        if let Some((dist, len)) = self.queued_backref.take() {
            let n = len.min(output.len() - output_index);
            for i in 0..n {
                output[output_index + i] = output[output_index + i - dist];
            }
            output_index += n;
            if n < len {
                self.queued_backref = Some((dist, len - n));
                return Ok((0, n));
            }
        }

        // Main decoding state machine. Each pass runs the handler for the current state; the
        // loop exits as soon as a handler leaves `self.state` unchanged, i.e. it could make no
        // further progress with the input/output available.
        let mut last_state = None;
        while last_state != Some(self.state) {
            last_state = Some(self.state);
            match self.state {
                State::ZlibHeader => {
                    self.fill_buffer(&mut remaining_input);
                    if self.nbits < 16 {
                        break;
                    }

                    // Validate the two-byte zlib header (RFC 1950): compression method must be
                    // 8 (deflate), window size at most 32K, no preset dictionary, and the
                    // CMF/FLG pair must be a multiple of 31.
                    let input0 = self.peak_bits(8);
                    let input1 = self.peak_bits(16) >> 8 & 0xff;
                    if input0 & 0x0f != 0x08
                        || (input0 & 0xf0) > 0x70
                        || input1 & 0x20 != 0
                        || (input0 << 8 | input1) % 31 != 0
                    {
                        return Err(DecompressionError::BadZlibHeader);
                    }

                    self.consume_bits(16);
                    self.state = State::BlockHeader;
                }
                State::BlockHeader => {
                    self.read_block_header(&mut remaining_input)?;
                }
                State::CodeLengthCodes => {
                    self.read_code_length_codes(&mut remaining_input)?;
                }
                State::CodeLengths => {
                    self.read_code_lengths(&mut remaining_input)?;
                }
                State::CompressedData => {
                    output_index =
                        self.read_compressed(&mut remaining_input, output, output_index)?
                }
                State::UncompressedData => {
                    // Drain any bytes from our buffer.
                    // (Stored blocks are byte-aligned, hence the whole-bytes invariant.)
                    debug_assert_eq!(self.nbits % 8, 0);
                    while self.nbits > 0
                        && self.uncompressed_bytes_left > 0
                        && output_index < output.len()
                    {
                        output[output_index] = self.peak_bits(8) as u8;
                        self.consume_bits(8);
                        output_index += 1;
                        self.uncompressed_bytes_left -= 1;
                    }
                    // Buffer may contain one additional byte. Clear it to avoid confusion.
                    if self.nbits == 0 {
                        self.buffer = 0;
                    }

                    // Copy subsequent bytes directly from the input, bypassing the bit buffer.
                    let copy_bytes = (self.uncompressed_bytes_left as usize)
                        .min(remaining_input.len())
                        .min(output.len() - output_index);
                    output[output_index..][..copy_bytes]
                        .copy_from_slice(&remaining_input[..copy_bytes]);
                    remaining_input = &remaining_input[copy_bytes..];
                    output_index += copy_bytes;
                    self.uncompressed_bytes_left -= copy_bytes as u16;

                    if self.uncompressed_bytes_left == 0 {
                        self.state = if self.last_block {
                            State::Checksum
                        } else {
                            State::BlockHeader
                        };
                    }
                }
                State::Checksum => {
                    self.fill_buffer(&mut remaining_input);

                    // Discard bits up to the next byte boundary, then read the 4-byte
                    // big-endian Adler-32 trailer and compare it to the running checksum.
                    let align_bits = self.nbits % 8;
                    if self.nbits >= 32 + align_bits {
                        self.checksum.write(&output[output_position..output_index]);
                        if align_bits != 0 {
                            self.consume_bits(align_bits);
                        }
                        #[cfg(not(fuzzing))]
                        if !self.ignore_adler32
                            && (self.peak_bits(32) as u32).swap_bytes() != self.checksum.finish()
                        {
                            return Err(DecompressionError::WrongChecksum);
                        }
                        self.state = State::Done;
                        self.consume_bits(32);
                        break;
                    }
                }
                State::Done => unreachable!(),
            }
        }

        // Fold this call's output into the running checksum, unless the stream just completed
        // (the Checksum arm above already folded it in before verifying).
        if !self.ignore_adler32 && self.state != State::Done {
            self.checksum.write(&output[output_position..output_index]);
        }

        // Per the contract in the doc comment: Ok if the stream finished, more input may still
        // arrive, or the output is full; otherwise the caller gave us everything available and
        // the stream still could not be completed.
        if self.state == State::Done || !end_of_input || output_index == output.len() {
            let input_left = remaining_input.len();
            Ok((input.len() - input_left, output_index - output_position))
        } else {
            Err(DecompressionError::InsufficientInput)
        }
    }
1023
1024
    /// Returns true if the decompressor has finished decompressing the input.
1025
0
    pub fn is_done(&self) -> bool {
1026
0
        self.state == State::Done
1027
0
    }
1028
}
1029
1030
/// Decompress the given data.
1031
0
pub fn decompress_to_vec(input: &[u8]) -> Result<Vec<u8>, DecompressionError> {
1032
0
    match decompress_to_vec_bounded(input, usize::MAX) {
1033
0
        Ok(output) => Ok(output),
1034
0
        Err(BoundedDecompressionError::DecompressionError { inner }) => Err(inner),
1035
        Err(BoundedDecompressionError::OutputTooLarge { .. }) => {
1036
0
            unreachable!("Impossible to allocate more than isize::MAX bytes")
1037
        }
1038
    }
1039
0
}
1040
1041
/// An error encountered while decompressing a deflate stream given a bounded maximum output.
1042
pub enum BoundedDecompressionError {
1043
    /// The input is not a valid deflate stream.
1044
    DecompressionError {
1045
        /// The underlying error.
1046
        inner: DecompressionError,
1047
    },
1048
1049
    /// The output is too large.
1050
    OutputTooLarge {
1051
        /// The output decoded so far.
1052
        partial_output: Vec<u8>,
1053
    },
1054
}
1055
impl From<DecompressionError> for BoundedDecompressionError {
1056
0
    fn from(inner: DecompressionError) -> Self {
1057
0
        BoundedDecompressionError::DecompressionError { inner }
1058
0
    }
1059
}
1060
1061
/// Decompress the given data, returning an error if the output is larger than
1062
/// `maxlen` bytes.
1063
0
pub fn decompress_to_vec_bounded(
1064
0
    input: &[u8],
1065
0
    maxlen: usize,
1066
0
) -> Result<Vec<u8>, BoundedDecompressionError> {
1067
0
    let mut decoder = Decompressor::new();
1068
0
    let mut output = vec![0; 1024.min(maxlen)];
1069
0
    let mut input_index = 0;
1070
0
    let mut output_index = 0;
1071
    loop {
1072
0
        let (consumed, produced) =
1073
0
            decoder.read(&input[input_index..], &mut output, output_index, true)?;
1074
0
        input_index += consumed;
1075
0
        output_index += produced;
1076
0
        if decoder.is_done() || output_index == maxlen {
1077
0
            break;
1078
0
        }
1079
0
        output.resize((output_index + 32 * 1024).min(maxlen), 0);
1080
    }
1081
0
    output.resize(output_index, 0);
1082
0
1083
0
    if decoder.is_done() {
1084
0
        Ok(output)
1085
    } else {
1086
0
        Err(BoundedDecompressionError::OutputTooLarge {
1087
0
            partial_output: output,
1088
0
        })
1089
    }
1090
0
}
1091
1092
#[cfg(test)]
mod tests {
    use crate::tables::{LENGTH_TO_LEN_EXTRA, LENGTH_TO_SYMBOL};

    use super::*;
    use rand::Rng;

    // Compress `data` with this crate's own compressor and assert that decompressing the
    // result reproduces the input exactly.
    fn roundtrip(data: &[u8]) {
        let compressed = crate::compress_to_vec(data);
        let decompressed = decompress_to_vec(&compressed).unwrap();
        assert_eq!(&decompressed, data);
    }

    // Compress `data` with miniz_oxide (an independent implementation) and assert that our
    // decompressor reproduces the input. Comparing chunk-by-chunk first gives a more precise
    // failure location than the final whole-slice assertion.
    fn roundtrip_miniz_oxide(data: &[u8]) {
        let compressed = miniz_oxide::deflate::compress_to_vec_zlib(data, 3);
        let decompressed = decompress_to_vec(&compressed).unwrap();
        assert_eq!(decompressed.len(), data.len());
        for (i, (a, b)) in decompressed.chunks(1).zip(data.chunks(1)).enumerate() {
            assert_eq!(a, b, "chunk {}..{}", i, i + 1);
        }
        assert_eq!(&decompressed, data);
    }

    // Debugging helper: decompress `data` with both this crate and miniz_oxide and panic with
    // a detailed diff location on the first mismatch. Not used by the tests below.
    #[allow(unused)]
    fn compare_decompression(data: &[u8]) {
        // let decompressed0 = flate2::read::ZlibDecoder::new(std::io::Cursor::new(&data))
        //     .bytes()
        //     .collect::<Result<Vec<_>, _>>()
        //     .unwrap();
        let decompressed = decompress_to_vec(data).unwrap();
        let decompressed2 = miniz_oxide::inflate::decompress_to_vec_zlib(data).unwrap();
        for i in 0..decompressed.len().min(decompressed2.len()) {
            if decompressed[i] != decompressed2[i] {
                panic!(
                    "mismatch at index {} {:?} {:?}",
                    i,
                    &decompressed[i.saturating_sub(1)..(i + 16).min(decompressed.len())],
                    &decompressed2[i.saturating_sub(1)..(i + 16).min(decompressed2.len())]
                );
            }
        }
        if decompressed != decompressed2 {
            panic!(
                "length mismatch {} {} {:x?}",
                decompressed.len(),
                decompressed2.len(),
                &decompressed2[decompressed.len()..][..16]
            );
        }
        //assert_eq!(decompressed, decompressed2);
    }

    // Cross-check the length-symbol lookup tables against each other: every length covered by
    // a (symbol, extra-bits) pair must map back to that symbol and bit count.
    #[test]
    fn tables() {
        for (i, &bits) in LEN_SYM_TO_LEN_EXTRA.iter().enumerate() {
            let len_base = LEN_SYM_TO_LEN_BASE[i];
            for j in 0..(1 << bits) {
                // Symbol 27 with all extra bits set encodes length 258, which is represented
                // by symbol 28 instead — skip that one combination.
                if i == 27 && j == 31 {
                    continue;
                }
                assert_eq!(LENGTH_TO_LEN_EXTRA[len_base + j - 3], bits, "{} {}", i, j);
                assert_eq!(
                    LENGTH_TO_SYMBOL[len_base + j - 3],
                    i as u16 + 257,
                    "{} {}",
                    i,
                    j
                );
            }
        }
    }

    // Building the huffman tables from the fixed code lengths must reproduce the precomputed
    // FIXED_LITLEN_TABLE / FIXED_DIST_TABLE constants.
    #[test]
    fn fixed_tables() {
        let mut compression = CompressedBlock {
            litlen_table: Box::new([0; 4096]),
            dist_table: Box::new([0; 512]),
            secondary_table: Vec::new(),
            dist_secondary_table: Vec::new(),
            eof_code: 0,
            eof_mask: 0,
            eof_bits: 0,
        };
        Decompressor::build_tables(288, &FIXED_CODE_LENGTHS, &mut compression).unwrap();

        assert_eq!(compression.litlen_table[..512], FIXED_LITLEN_TABLE);
        assert_eq!(compression.dist_table[..32], FIXED_DIST_TABLE);
    }

    #[test]
    fn it_works() {
        roundtrip(b"Hello world!");
    }

    // Constant inputs exercise the RLE fast paths (dist == 1 back-references).
    #[test]
    fn constant() {
        roundtrip_miniz_oxide(&[0; 50]);
        roundtrip_miniz_oxide(&vec![5; 2048]);
        roundtrip_miniz_oxide(&vec![128; 2048]);
        roundtrip_miniz_oxide(&vec![254; 2048]);
    }

    // Low-entropy random data (bytes in 0..5) produces many short matches, stressing the
    // back-reference decoding paths.
    #[test]
    fn random() {
        let mut rng = rand::thread_rng();
        let mut data = vec![0; 50000];
        for _ in 0..10 {
            for byte in &mut data {
                *byte = rng.gen::<u8>() % 5;
            }
            println!("Random data: {:?}", data);
            roundtrip_miniz_oxide(&data);
        }
    }

    // A corrupted Adler-32 trailer must be reported as WrongChecksum by default, but accepted
    // when `ignore_adler32` has been called.
    #[test]
    fn ignore_adler32() {
        let mut compressed = crate::compress_to_vec(b"Hello world!");
        let last_byte = compressed.len() - 1;
        compressed[last_byte] = compressed[last_byte].wrapping_add(1);

        match decompress_to_vec(&compressed) {
            Err(DecompressionError::WrongChecksum) => {}
            r => panic!("expected WrongChecksum, got {:?}", r),
        }

        let mut decompressor = Decompressor::new();
        decompressor.ignore_adler32();
        let mut decompressed = vec![0; 1024];
        let decompressed_len = decompressor
            .read(&compressed, &mut decompressed, 0, true)
            .unwrap()
            .1;
        assert_eq!(&decompressed[..decompressed_len], b"Hello world!");
    }

    // Feed everything except the final checksum byte first (end_of_input = false), then the
    // last byte: the stream must only complete on the second call.
    #[test]
    fn checksum_after_eof() {
        let input = b"Hello world!";
        let compressed = crate::compress_to_vec(input);

        let mut decompressor = Decompressor::new();
        let mut decompressed = vec![0; 1024];
        let (input_consumed, output_written) = decompressor
            .read(
                &compressed[..compressed.len() - 1],
                &mut decompressed,
                0,
                false,
            )
            .unwrap();
        assert_eq!(output_written, input.len());
        assert_eq!(input_consumed, compressed.len() - 1);

        let (input_consumed, output_written) = decompressor
            .read(
                &compressed[input_consumed..],
                &mut decompressed[..output_written],
                output_written,
                true,
            )
            .unwrap();
        assert!(decompressor.is_done());
        assert_eq!(input_consumed, 1);
        assert_eq!(output_written, 0);

        assert_eq!(&decompressed[..input.len()], input);
    }

    // An empty stream padded with zero-length stored blocks must decompress fully even though
    // the output buffer is empty.
    #[test]
    fn zero_length() {
        let mut compressed = crate::compress_to_vec(b"").to_vec();

        // Splice in zero-length non-compressed blocks.
        for _ in 0..10 {
            println!("compressed len: {}", compressed.len());
            compressed.splice(2..2, [0u8, 0, 0, 0xff, 0xff].into_iter());
        }

        // Ensure that the full input is decompressed, regardless of whether
        // `end_of_input` is set.
        for end_of_input in [true, false] {
            let mut decompressor = Decompressor::new();
            let (input_consumed, output_written) = decompressor
                .read(&compressed, &mut [], 0, end_of_input)
                .unwrap();

            assert!(decompressor.is_done());
            assert_eq!(input_consumed, compressed.len());
            assert_eq!(output_written, 0);
        }
    }

    mod test_utils;
    use tables::FIXED_CODE_LENGTHS;
    use test_utils::{decompress_by_chunks, TestDecompressionError};

    // Decompress `input` both in one shot and byte-by-byte and assert the results agree;
    // input chunking must never change the outcome.
    fn verify_no_sensitivity_to_input_chunking(
        input: &[u8],
    ) -> Result<Vec<u8>, TestDecompressionError> {
        let r_whole = decompress_by_chunks(input, vec![input.len()], false);
        let r_bytewise = decompress_by_chunks(input, std::iter::repeat(1), false);
        assert_eq!(r_whole, r_bytewise);
        r_whole // Returning an arbitrary result, since this is equal to `r_bytewise`.
    }

    /// This is a regression test found by the `buf_independent` fuzzer from the `png` crate.  When
    /// this test case was found, the results were unexpectedly different when 1) decompressing the
    /// whole input (successful result) vs 2) decompressing byte-by-byte
    /// (`Err(InvalidDistanceCode)`).
    #[test]
    fn test_input_chunking_sensitivity_when_handling_distance_codes() {
        let result = verify_no_sensitivity_to_input_chunking(include_bytes!(
            "../tests/input-chunking-sensitivity-example1.zz"
        ))
        .unwrap();
        assert_eq!(result.len(), 281);
        assert_eq!(simd_adler32::adler32(&result.as_slice()), 751299);
    }

    /// This is a regression test found by the `inflate_bytewise3` fuzzer from the `fdeflate`
    /// crate.  When this test case was found, the results were unexpectedly different when 1)
    /// decompressing the whole input (`Err(DistanceTooFarBack)`) vs 2) decompressing byte-by-byte
    /// (successful result)`).
    #[test]
    fn test_input_chunking_sensitivity_when_no_end_of_block_symbol_example1() {
        let err = verify_no_sensitivity_to_input_chunking(include_bytes!(
            "../tests/input-chunking-sensitivity-example2.zz"
        ))
        .unwrap_err();
        assert_eq!(
            err,
            TestDecompressionError::ProdError(DecompressionError::BadLiteralLengthHuffmanTree)
        );
    }

    /// This is a regression test found by the `inflate_bytewise3` fuzzer from the `fdeflate`
    /// crate.  When this test case was found, the results were unexpectedly different when 1)
    /// decompressing the whole input (`Err(InvalidDistanceCode)`) vs 2) decompressing byte-by-byte
    /// (successful result)`).
    #[test]
    fn test_input_chunking_sensitivity_when_no_end_of_block_symbol_example2() {
        let err = verify_no_sensitivity_to_input_chunking(include_bytes!(
            "../tests/input-chunking-sensitivity-example3.zz"
        ))
        .unwrap_err();
        assert_eq!(
            err,
            TestDecompressionError::ProdError(DecompressionError::BadLiteralLengthHuffmanTree)
        );
    }
}