Coverage Report

Created: 2025-11-14 07:32

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/libjxl/lib/jxl/dec_ans.cc
Line
Count
Source
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#include "lib/jxl/dec_ans.h"
7
8
#include <jxl/memory_manager.h>
9
10
#include <algorithm>
11
#include <cstddef>
12
#include <cstdint>
13
#include <utility>
14
#include <vector>
15
16
#include "lib/jxl/ans_common.h"
17
#include "lib/jxl/ans_params.h"
18
#include "lib/jxl/base/bits.h"
19
#include "lib/jxl/base/compiler_specific.h"
20
#include "lib/jxl/base/printf_macros.h"
21
#include "lib/jxl/base/status.h"
22
#include "lib/jxl/dec_bit_reader.h"
23
#include "lib/jxl/dec_context_map.h"
24
#include "lib/jxl/dec_huffman.h"
25
#include "lib/jxl/field_encodings.h"
26
#include "lib/jxl/fields.h"
27
#include "lib/jxl/memory_manager_internal.h"
28
29
namespace jxl {
30
namespace {
31
32
// Decodes a number in the range [0..255], by reading 1 - 11 bits.
33
125k
inline int DecodeVarLenUint8(BitReader* input) {
34
125k
  if (input->ReadFixedBits<1>()) {
35
67.6k
    int nbits = static_cast<int>(input->ReadFixedBits<3>());
36
67.6k
    if (nbits == 0) {
37
6.71k
      return 1;
38
60.9k
    } else {
39
60.9k
      return static_cast<int>(input->ReadBits(nbits)) + (1 << nbits);
40
60.9k
    }
41
67.6k
  }
42
57.4k
  return 0;
43
125k
}
44
45
// Decodes a number in the range [0..65535], by reading 1 - 21 bits.
46
101k
inline int DecodeVarLenUint16(BitReader* input) {
47
101k
  if (input->ReadFixedBits<1>()) {
48
21.6k
    int nbits = static_cast<int>(input->ReadFixedBits<4>());
49
21.6k
    if (nbits == 0) {
50
2.31k
      return 1;
51
19.3k
    } else {
52
19.3k
      return static_cast<int>(input->ReadBits(nbits)) + (1 << nbits);
53
19.3k
    }
54
21.6k
  }
55
79.9k
  return 0;
56
101k
}
57
58
Status ReadHistogram(int precision_bits, std::vector<int32_t>* counts,
59
98.4k
                     BitReader* input) {
60
98.4k
  int range = 1 << precision_bits;
61
98.4k
  int simple_code = input->ReadBits(1);
62
98.4k
  if (simple_code == 1) {
63
22.7k
    int i;
64
22.7k
    int symbols[2] = {0};
65
22.7k
    int max_symbol = 0;
66
22.7k
    const int num_symbols = input->ReadBits(1) + 1;
67
50.3k
    for (i = 0; i < num_symbols; ++i) {
68
27.6k
      symbols[i] = DecodeVarLenUint8(input);
69
27.6k
      if (symbols[i] > max_symbol) max_symbol = symbols[i];
70
27.6k
    }
71
22.7k
    counts->resize(max_symbol + 1);
72
22.7k
    if (num_symbols == 1) {
73
17.9k
      (*counts)[symbols[0]] = range;
74
17.9k
    } else {
75
4.81k
      if (symbols[0] == symbols[1]) {  // corrupt data
76
184
        return false;
77
184
      }
78
4.63k
      (*counts)[symbols[0]] = input->ReadBits(precision_bits);
79
4.63k
      (*counts)[symbols[1]] = range - (*counts)[symbols[0]];
80
4.63k
    }
81
75.6k
  } else {
82
75.6k
    int is_flat = input->ReadBits(1);
83
75.6k
    if (is_flat == 1) {
84
14.6k
      int alphabet_size = DecodeVarLenUint8(input) + 1;
85
14.6k
      JXL_ENSURE(alphabet_size <= range);
86
14.6k
      *counts = CreateFlatHistogram(alphabet_size, range);
87
14.6k
      return true;
88
14.6k
    }
89
90
61.0k
    uint32_t shift;
91
61.0k
    {
92
      // TODO(veluca): speed up reading with table lookups.
93
61.0k
      int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1);
94
61.0k
      int log = 0;
95
76.5k
      for (; log < upper_bound_log; log++) {
96
75.8k
        if (input->ReadFixedBits<1>() == 0) break;
97
75.8k
      }
98
61.0k
      shift = (input->ReadBits(log) | (1 << log)) - 1;
99
61.0k
      if (shift > ANS_LOG_TAB_SIZE + 1) {
100
4
        return JXL_FAILURE("Invalid shift value");
101
4
      }
102
61.0k
    }
103
104
61.0k
    const size_t length = DecodeVarLenUint8(input) + 3;
105
61.0k
    counts->resize(length);
106
61.0k
    int total_count = 0;
107
108
61.0k
    static const uint8_t huff[128][2] = {
109
61.0k
        {3, 10}, {7, 12}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
110
61.0k
        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
111
61.0k
        {3, 10}, {5, 0},  {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
112
61.0k
        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
113
61.0k
        {3, 10}, {6, 11}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
114
61.0k
        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
115
61.0k
        {3, 10}, {5, 0},  {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
116
61.0k
        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
117
61.0k
        {3, 10}, {7, 13}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
118
61.0k
        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
119
61.0k
        {3, 10}, {5, 0},  {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
120
61.0k
        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
121
61.0k
        {3, 10}, {6, 11}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
122
61.0k
        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
123
61.0k
        {3, 10}, {5, 0},  {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
124
61.0k
        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
125
61.0k
    };
126
127
61.0k
    std::vector<int> logcounts(length);
128
61.0k
    int omit_log = -1;
129
61.0k
    int omit_pos = -1;
130
    // This array remembers which symbols have an RLE length.
131
61.0k
    std::vector<int> same(length);
132
843k
    for (size_t i = 0; i < length; ++i) {
133
782k
      input->Refill();  // for PeekFixedBits + Advance
134
782k
      int idx = input->PeekFixedBits<7>();
135
782k
      input->Consume(huff[idx][0]);
136
782k
      logcounts[i] = int(huff[idx][1]) - 1;
137
      // The RLE symbol.
138
782k
      if (logcounts[i] == ANS_LOG_TAB_SIZE) {
139
21.9k
        int rle_length = DecodeVarLenUint8(input);
140
21.9k
        same[i] = rle_length + 5;
141
21.9k
        i += rle_length + 3;
142
21.9k
        continue;
143
21.9k
      }
144
760k
      if (logcounts[i] > omit_log) {
145
106k
        omit_log = logcounts[i];
146
106k
        omit_pos = i;
147
106k
      }
148
760k
    }
149
    // Invalid input, e.g. due to invalid usage of RLE.
150
61.0k
    if (omit_pos < 0) return JXL_FAILURE("Invalid histogram.");
151
61.0k
    if (static_cast<size_t>(omit_pos) + 1 < length &&
152
56.1k
        logcounts[omit_pos + 1] == ANS_LOG_TAB_SIZE) {
153
179
      return JXL_FAILURE("Invalid histogram.");
154
179
    }
155
60.8k
    int prev = 0;
156
60.8k
    int numsame = 0;
157
1.05M
    for (size_t i = 0; i < length; ++i) {
158
992k
      if (same[i]) {
159
        // RLE sequence, let this loop output the same count for the next
160
        // iterations.
161
20.8k
        numsame = same[i] - 1;
162
20.8k
        prev = i > 0 ? (*counts)[i - 1] : 0;
163
20.8k
      }
164
992k
      if (numsame > 0) {
165
256k
        (*counts)[i] = prev;
166
256k
        numsame--;
167
736k
      } else {
168
736k
        int code = logcounts[i];
169
        // omit_pos may not be negative at this point (checked before).
170
736k
        if (i == static_cast<size_t>(omit_pos) || code < 0) {
171
201k
          continue;
172
534k
        } else if (shift == 0 || code == 0) {
173
          // `shift = 0` means `bitcount = 0`
174
417k
          (*counts)[i] = 1 << code;
175
417k
        } else {
176
117k
          int bitcount = GetPopulationCountPrecision(code, shift);
177
117k
          (*counts)[i] = (1 << code) +
178
117k
                         (input->ReadBits(bitcount) << (code - bitcount));
179
117k
        }
180
736k
      }
181
791k
      total_count += (*counts)[i];
182
791k
    }
183
60.8k
    (*counts)[omit_pos] = range - total_count;
184
60.8k
    if ((*counts)[omit_pos] <= 0) {
185
      // The histogram we've read sums to more than total_count (including at
186
      // least 1 for the omitted value).
187
1.29k
      return JXL_FAILURE("Invalid histogram count.");
188
1.29k
    }
189
60.8k
  }
190
82.1k
  return true;
191
98.4k
}
192
193
}  // namespace
194
195
Status DecodeANSCodes(JxlMemoryManager* memory_manager,
196
                      const size_t num_histograms,
197
                      const size_t max_alphabet_size, BitReader* in,
198
121k
                      ANSCode* result) {
199
121k
  result->memory_manager = memory_manager;
200
121k
  result->degenerate_symbols.resize(num_histograms, -1);
201
121k
  if (result->use_prefix_code) {
202
67.7k
    JXL_ENSURE(max_alphabet_size <= 1 << PREFIX_MAX_BITS);
203
67.7k
    result->huffman_data.resize(num_histograms);
204
67.7k
    std::vector<uint16_t> alphabet_sizes(num_histograms);
205
169k
    for (size_t c = 0; c < num_histograms; c++) {
206
101k
      alphabet_sizes[c] = DecodeVarLenUint16(in) + 1;
207
101k
      if (alphabet_sizes[c] > max_alphabet_size) {
208
178
        return JXL_FAILURE("Alphabet size is too long: %u", alphabet_sizes[c]);
209
178
      }
210
101k
    }
211
152k
    for (size_t c = 0; c < num_histograms; c++) {
212
88.5k
      if (alphabet_sizes[c] > 1) {
213
15.4k
        if (!result->huffman_data[c].ReadFromBitStream(alphabet_sizes[c], in)) {
214
3.68k
          if (!in->AllReadsWithinBounds()) {
215
3.50k
            return JXL_NOT_ENOUGH_BYTES("Not enough bytes for huffman code");
216
3.50k
          }
217
173
          return JXL_FAILURE("Invalid huffman tree number %" PRIuS
218
3.68k
                             ", alphabet size %u",
219
3.68k
                             c, alphabet_sizes[c]);
220
3.68k
        }
221
73.1k
      } else {
222
        // 0-bit codes does not require extension tables.
223
73.1k
        result->huffman_data[c].table_.clear();
224
73.1k
        result->huffman_data[c].table_.resize(1u << kHuffmanTableBits);
225
73.1k
      }
226
25.8M
      for (const auto& h : result->huffman_data[c].table_) {
227
25.8M
        if (h.bits <= kHuffmanTableBits) {
228
25.7M
          result->UpdateMaxNumBits(c, h.value);
229
25.7M
        }
230
25.8M
      }
231
84.8k
    }
232
67.5k
  } else {
233
54.0k
    JXL_ENSURE(max_alphabet_size <= ANS_MAX_ALPHABET_SIZE);
234
54.0k
    size_t alloc_size = num_histograms * (1 << result->log_alpha_size) *
235
54.0k
                        sizeof(AliasTable::Entry);
236
54.0k
    JXL_ASSIGN_OR_RETURN(result->alias_tables,
237
54.0k
                         AlignedMemory::Create(memory_manager, alloc_size));
238
54.0k
    AliasTable::Entry* alias_tables =
239
54.0k
        result->alias_tables.address<AliasTable::Entry>();
240
150k
    for (size_t c = 0; c < num_histograms; ++c) {
241
98.4k
      std::vector<int32_t> counts;
242
98.4k
      if (!ReadHistogram(ANS_LOG_TAB_SIZE, &counts, in)) {
243
1.66k
        return JXL_FAILURE("Invalid histogram bitstream.");
244
1.66k
      }
245
96.7k
      if (counts.size() > max_alphabet_size) {
246
488
        return JXL_FAILURE("Alphabet size is too long: %" PRIuS, counts.size());
247
488
      }
248
105k
      while (!counts.empty() && counts.back() == 0) {
249
9.04k
        counts.pop_back();
250
9.04k
      }
251
1.55M
      for (size_t s = 0; s < counts.size(); s++) {
252
1.46M
        if (counts[s] != 0) {
253
726k
          result->UpdateMaxNumBits(c, s);
254
726k
        }
255
1.46M
      }
256
      // InitAliasTable "fixes" empty counts to contain degenerate "0" symbol.
257
96.2k
      int degenerate_symbol = counts.empty() ? 0 : (counts.size() - 1);
258
310k
      for (int s = 0; s < degenerate_symbol; ++s) {
259
283k
        if (counts[s] != 0) {
260
68.6k
          degenerate_symbol = -1;
261
68.6k
          break;
262
68.6k
        }
263
283k
      }
264
96.2k
      result->degenerate_symbols[c] = degenerate_symbol;
265
96.2k
      JXL_RETURN_IF_ERROR(
266
96.2k
          InitAliasTable(counts, ANS_LOG_TAB_SIZE, result->log_alpha_size,
267
96.2k
                         alias_tables + c * (1 << result->log_alpha_size)));
268
96.2k
    }
269
54.0k
  }
270
115k
  return true;
271
121k
}
272
Status DecodeUintConfig(size_t log_alpha_size, HybridUintConfig* uint_config,
273
226k
                        BitReader* br) {
274
226k
  br->Refill();
275
226k
  size_t split_exponent = br->ReadBits(CeilLog2Nonzero(log_alpha_size + 1));
276
226k
  size_t msb_in_token = 0;
277
226k
  size_t lsb_in_token = 0;
278
226k
  if (split_exponent != log_alpha_size) {
279
    // otherwise, msb/lsb don't matter.
280
202k
    size_t nbits = CeilLog2Nonzero(split_exponent + 1);
281
202k
    msb_in_token = br->ReadBits(nbits);
282
202k
    if (msb_in_token > split_exponent) {
283
      // This could be invalid here already and we need to check this before
284
      // we use its value to read more bits.
285
94
      return JXL_FAILURE("Invalid HybridUintConfig");
286
94
    }
287
202k
    nbits = CeilLog2Nonzero(split_exponent - msb_in_token + 1);
288
202k
    lsb_in_token = br->ReadBits(nbits);
289
202k
  }
290
226k
  if (lsb_in_token + msb_in_token > split_exponent) {
291
49
    return JXL_FAILURE("Invalid HybridUintConfig");
292
49
  }
293
226k
  *uint_config = HybridUintConfig(split_exponent, msb_in_token, lsb_in_token);
294
226k
  return true;
295
226k
}
296
297
Status DecodeUintConfigs(size_t log_alpha_size,
298
                         std::vector<HybridUintConfig>* uint_config,
299
121k
                         BitReader* br) {
300
  // TODO(veluca): RLE?
301
201k
  for (auto& cfg : *uint_config) {
302
201k
    JXL_RETURN_IF_ERROR(DecodeUintConfig(log_alpha_size, &cfg, br));
303
201k
  }
304
121k
  return true;
305
121k
}
306
307
262k
LZ77Params::LZ77Params() { Bundle::Init(this); }
308
452k
Status LZ77Params::VisitFields(Visitor* JXL_RESTRICT visitor) {
309
452k
  JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &enabled));
310
451k
  if (!visitor->Conditional(enabled)) return true;
311
292k
  JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(224), Val(512), Val(4096),
312
292k
                                         BitsOffset(15, 8), 224, &min_symbol));
313
292k
  JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(3), Val(4), BitsOffset(2, 5),
314
292k
                                         BitsOffset(8, 9), 3, &min_length));
315
291k
  return true;
316
292k
}
317
318
26.4M
void ANSCode::UpdateMaxNumBits(size_t ctx, size_t symbol) {
319
26.4M
  HybridUintConfig* cfg = &uint_config[ctx];
320
  // LZ77 symbols use a different uint config.
321
26.4M
  if (lz77.enabled && lz77.nonserialized_distance_context != ctx &&
322
2.36M
      symbol >= lz77.min_symbol) {
323
88.8k
    symbol -= lz77.min_symbol;
324
88.8k
    cfg = &lz77.length_uint_config;
325
88.8k
  }
326
26.4M
  size_t split_token = cfg->split_token;
327
26.4M
  size_t msb_in_token = cfg->msb_in_token;
328
26.4M
  size_t lsb_in_token = cfg->lsb_in_token;
329
26.4M
  size_t split_exponent = cfg->split_exponent;
330
26.4M
  if (symbol < split_token) {
331
21.8M
    max_num_bits = std::max(max_num_bits, split_exponent);
332
21.8M
    return;
333
21.8M
  }
334
4.65M
  uint32_t n_extra_bits =
335
4.65M
      split_exponent - (msb_in_token + lsb_in_token) +
336
4.65M
      ((symbol - split_token) >> (msb_in_token + lsb_in_token));
337
4.65M
  size_t total_bits = msb_in_token + lsb_in_token + n_extra_bits + 1;
338
4.65M
  max_num_bits = std::max(max_num_bits, total_bits);
339
4.65M
}
340
341
Status DecodeHistograms(JxlMemoryManager* memory_manager, BitReader* br,
342
                        size_t num_contexts, ANSCode* code,
343
128k
                        std::vector<uint8_t>* context_map, bool disallow_lz77) {
344
128k
  JXL_RETURN_IF_ERROR(Bundle::Read(br, &code->lz77));
345
127k
  if (code->lz77.enabled) {
346
25.3k
    num_contexts++;
347
25.3k
    JXL_RETURN_IF_ERROR(DecodeUintConfig(/*log_alpha_size=*/8,
348
25.3k
                                         &code->lz77.length_uint_config, br));
349
25.3k
  }
350
127k
  if (code->lz77.enabled && disallow_lz77) {
351
42
    return JXL_FAILURE("Using LZ77 when explicitly disallowed");
352
42
  }
353
127k
  size_t num_histograms = 1;
354
127k
  context_map->resize(num_contexts);
355
127k
  if (num_contexts > 1) {
356
75.0k
    JXL_RETURN_IF_ERROR(
357
75.0k
        DecodeContextMap(memory_manager, context_map, &num_histograms, br));
358
75.0k
  }
359
121k
  JXL_DEBUG_V(
360
121k
      4, "Decoded context map of size %" PRIuS " and %" PRIuS " histograms",
361
121k
      num_contexts, num_histograms);
362
121k
  code->lz77.nonserialized_distance_context = context_map->back();
363
121k
  code->use_prefix_code = static_cast<bool>(br->ReadFixedBits<1>());
364
121k
  if (code->use_prefix_code) {
365
67.7k
    code->log_alpha_size = PREFIX_MAX_BITS;
366
67.7k
  } else {
367
54.1k
    code->log_alpha_size = br->ReadFixedBits<2>() + 5;
368
54.1k
  }
369
121k
  code->uint_config.resize(num_histograms);
370
121k
  JXL_RETURN_IF_ERROR(
371
121k
      DecodeUintConfigs(code->log_alpha_size, &code->uint_config, br));
372
121k
  const size_t max_alphabet_size = 1 << code->log_alpha_size;
373
121k
  JXL_RETURN_IF_ERROR(DecodeANSCodes(memory_manager, num_histograms,
374
121k
                                     max_alphabet_size, br, code));
375
115k
  return true;
376
121k
}
377
378
StatusOr<ANSSymbolReader> ANSSymbolReader::Create(const ANSCode* code,
379
                                                  BitReader* JXL_RESTRICT br,
380
123k
                                                  size_t distance_multiplier) {
381
123k
  AlignedMemory lz77_window_storage;
382
123k
  if (code->lz77.enabled) {
383
18.9k
    JxlMemoryManager* memory_manager = code->memory_manager;
384
18.9k
    JXL_ASSIGN_OR_RETURN(
385
18.9k
        lz77_window_storage,
386
18.9k
        AlignedMemory::Create(memory_manager, kWindowSize * sizeof(uint32_t)));
387
18.9k
  }
388
123k
  return ANSSymbolReader(code, br, distance_multiplier,
389
123k
                         std::move(lz77_window_storage));
390
123k
}
391
392
ANSSymbolReader::ANSSymbolReader(const ANSCode* code,
393
                                 BitReader* JXL_RESTRICT br,
394
                                 size_t distance_multiplier,
395
                                 AlignedMemory&& lz77_window_storage)
396
123k
    : alias_tables_(code->alias_tables.address<AliasTable::Entry>()),
397
123k
      huffman_data_(code->huffman_data.data()),
398
123k
      use_prefix_code_(code->use_prefix_code),
399
123k
      configs(code->uint_config.data()),
400
123k
      lz77_window_storage_(std::move(lz77_window_storage)) {
401
123k
  if (!use_prefix_code_) {
402
52.3k
    state_ = static_cast<uint32_t>(br->ReadFixedBits<32>());
403
52.3k
    log_alpha_size_ = code->log_alpha_size;
404
52.3k
    log_entry_size_ = ANS_LOG_TAB_SIZE - code->log_alpha_size;
405
52.3k
    entry_size_minus_1_ = (1 << log_entry_size_) - 1;
406
70.9k
  } else {
407
70.9k
    state_ = (ANS_SIGNATURE << 16u);
408
70.9k
  }
409
123k
  if (!code->lz77.enabled) return;
410
18.9k
  lz77_window_ = lz77_window_storage_.address<uint32_t>();
411
18.9k
  lz77_ctx_ = code->lz77.nonserialized_distance_context;
412
18.9k
  lz77_length_uint_ = code->lz77.length_uint_config;
413
18.9k
  lz77_threshold_ = code->lz77.min_symbol;
414
18.9k
  lz77_min_length_ = code->lz77.min_length;
415
18.9k
  num_special_distances_ = distance_multiplier == 0 ? 0 : kNumSpecialDistances;
416
457k
  for (size_t i = 0; i < num_special_distances_; i++) {
417
438k
    special_distances_[i] = SpecialDistance(i, distance_multiplier);
418
438k
  }
419
18.9k
}
420
421
}  // namespace jxl