Coverage Report

Created: 2025-07-16 07:53

/src/libjxl/lib/jxl/dec_ans.cc
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#include "lib/jxl/dec_ans.h"
7
8
#include <jxl/memory_manager.h>
9
10
#include <cstdint>
11
#include <vector>
12
13
#include "lib/jxl/ans_common.h"
14
#include "lib/jxl/ans_params.h"
15
#include "lib/jxl/base/bits.h"
16
#include "lib/jxl/base/printf_macros.h"
17
#include "lib/jxl/base/status.h"
18
#include "lib/jxl/dec_context_map.h"
19
#include "lib/jxl/fields.h"
20
#include "lib/jxl/memory_manager_internal.h"
21
22
namespace jxl {
23
namespace {
24
25
// Decodes a number in the range [0..255], by reading 1 - 11 bits.
26
15.5k
inline int DecodeVarLenUint8(BitReader* input) {
27
15.5k
  if (input->ReadFixedBits<1>()) {
28
9.38k
    int nbits = static_cast<int>(input->ReadFixedBits<3>());
29
9.38k
    if (nbits == 0) {
30
360
      return 1;
31
9.02k
    } else {
32
9.02k
      return static_cast<int>(input->ReadBits(nbits)) + (1 << nbits);
33
9.02k
    }
34
9.38k
  }
35
6.12k
  return 0;
36
15.5k
}
37
38
// Decodes a number in the range [0..65535], by reading 1 - 21 bits.
39
34.4k
inline int DecodeVarLenUint16(BitReader* input) {
40
34.4k
  if (input->ReadFixedBits<1>()) {
41
13.1k
    int nbits = static_cast<int>(input->ReadFixedBits<4>());
42
13.1k
    if (nbits == 0) {
43
0
      return 1;
44
13.1k
    } else {
45
13.1k
      return static_cast<int>(input->ReadBits(nbits)) + (1 << nbits);
46
13.1k
    }
47
13.1k
  }
48
21.2k
  return 0;
49
34.4k
}
50
51
Status ReadHistogram(int precision_bits, std::vector<int32_t>* counts,
52
12.3k
                     BitReader* input) {
53
12.3k
  int range = 1 << precision_bits;
54
12.3k
  int simple_code = input->ReadBits(1);
55
12.3k
  if (simple_code == 1) {
56
2.81k
    int i;
57
2.81k
    int symbols[2] = {0};
58
2.81k
    int max_symbol = 0;
59
2.81k
    const int num_symbols = input->ReadBits(1) + 1;
60
5.79k
    for (i = 0; i < num_symbols; ++i) {
61
2.98k
      symbols[i] = DecodeVarLenUint8(input);
62
2.98k
      if (symbols[i] > max_symbol) max_symbol = symbols[i];
63
2.98k
    }
64
2.81k
    counts->resize(max_symbol + 1);
65
2.81k
    if (num_symbols == 1) {
66
2.65k
      (*counts)[symbols[0]] = range;
67
2.65k
    } else {
68
164
      if (symbols[0] == symbols[1]) {  // corrupt data
69
2
        return false;
70
2
      }
71
162
      (*counts)[symbols[0]] = input->ReadBits(precision_bits);
72
162
      (*counts)[symbols[1]] = range - (*counts)[symbols[0]];
73
162
    }
74
9.54k
  } else {
75
9.54k
    int is_flat = input->ReadBits(1);
76
9.54k
    if (is_flat == 1) {
77
2.46k
      int alphabet_size = DecodeVarLenUint8(input) + 1;
78
2.46k
      JXL_ENSURE(alphabet_size <= range);
79
2.46k
      *counts = CreateFlatHistogram(alphabet_size, range);
80
2.46k
      return true;
81
2.46k
    }
82
83
7.07k
    uint32_t shift;
84
7.07k
    {
85
      // TODO(veluca): speed up reading with table lookups.
86
7.07k
      int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1);
87
7.07k
      int log = 0;
88
9.81k
      for (; log < upper_bound_log; log++) {
89
9.81k
        if (input->ReadFixedBits<1>() == 0) break;
90
9.81k
      }
91
7.07k
      shift = (input->ReadBits(log) | (1 << log)) - 1;
92
7.07k
      if (shift > ANS_LOG_TAB_SIZE + 1) {
93
0
        return JXL_FAILURE("Invalid shift value");
94
0
      }
95
7.07k
    }
96
97
7.07k
    int length = DecodeVarLenUint8(input) + 3;
98
7.07k
    counts->resize(length);
99
7.07k
    int total_count = 0;
100
101
7.07k
    static const uint8_t huff[128][2] = {
102
7.07k
        {3, 10}, {7, 12}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
103
7.07k
        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
104
7.07k
        {3, 10}, {5, 0},  {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
105
7.07k
        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
106
7.07k
        {3, 10}, {6, 11}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
107
7.07k
        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
108
7.07k
        {3, 10}, {5, 0},  {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
109
7.07k
        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
110
7.07k
        {3, 10}, {7, 13}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
111
7.07k
        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
112
7.07k
        {3, 10}, {5, 0},  {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
113
7.07k
        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
114
7.07k
        {3, 10}, {6, 11}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
115
7.07k
        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
116
7.07k
        {3, 10}, {5, 0},  {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
117
7.07k
        {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
118
7.07k
    };
119
120
7.07k
    std::vector<int> logcounts(counts->size());
121
7.07k
    int omit_log = -1;
122
7.07k
    int omit_pos = -1;
123
    // This array remembers which symbols have an RLE length.
124
7.07k
    std::vector<int> same(counts->size(), 0);
125
108k
    for (size_t i = 0; i < logcounts.size(); ++i) {
126
101k
      input->Refill();  // for PeekFixedBits + Advance
127
101k
      int idx = input->PeekFixedBits<7>();
128
101k
      input->Consume(huff[idx][0]);
129
101k
      logcounts[i] = huff[idx][1];
130
      // The RLE symbol.
131
101k
      if (logcounts[i] == ANS_LOG_TAB_SIZE + 1) {
132
2.99k
        int rle_length = DecodeVarLenUint8(input);
133
2.99k
        same[i] = rle_length + 5;
134
2.99k
        i += rle_length + 3;
135
2.99k
        continue;
136
2.99k
      }
137
98.8k
      if (logcounts[i] > omit_log) {
138
10.4k
        omit_log = logcounts[i];
139
10.4k
        omit_pos = i;
140
10.4k
      }
141
98.8k
    }
142
    // Invalid input, e.g. due to invalid usage of RLE.
143
7.07k
    if (omit_pos < 0) return JXL_FAILURE("Invalid histogram.");
144
7.07k
    if (static_cast<size_t>(omit_pos) + 1 < logcounts.size() &&
145
7.07k
        logcounts[omit_pos + 1] == ANS_TAB_SIZE + 1) {
146
0
      return JXL_FAILURE("Invalid histogram.");
147
0
    }
148
7.07k
    int prev = 0;
149
7.07k
    int numsame = 0;
150
145k
    for (size_t i = 0; i < logcounts.size(); ++i) {
151
138k
      if (same[i]) {
152
        // RLE sequence, let this loop output the same count for the next
153
        // iterations.
154
2.99k
        numsame = same[i] - 1;
155
2.99k
        prev = i > 0 ? (*counts)[i - 1] : 0;
156
2.99k
      }
157
138k
      if (numsame > 0) {
158
39.9k
        (*counts)[i] = prev;
159
39.9k
        numsame--;
160
98.8k
      } else {
161
98.8k
        unsigned int code = logcounts[i];
162
        // omit_pos may not be negative at this point (checked before).
163
98.8k
        if (i == static_cast<size_t>(omit_pos)) {
164
7.07k
          continue;
165
91.7k
        } else if (code == 0) {
166
20.5k
          continue;
167
71.2k
        } else if (code == 1) {
168
3.62k
          (*counts)[i] = 1;
169
67.5k
        } else {
170
67.5k
          int bitcount = GetPopulationCountPrecision(code - 1, shift);
171
67.5k
          (*counts)[i] = (1u << (code - 1)) +
172
67.5k
                         (input->ReadBits(bitcount) << (code - 1 - bitcount));
173
67.5k
        }
174
98.8k
      }
175
111k
      total_count += (*counts)[i];
176
111k
    }
177
7.07k
    (*counts)[omit_pos] = range - total_count;
178
7.07k
    if ((*counts)[omit_pos] <= 0) {
179
      // The histogram we've read sums to more than total_count (including at
180
      // least 1 for the omitted value).
181
1
      return JXL_FAILURE("Invalid histogram count.");
182
1
    }
183
7.07k
  }
184
9.89k
  return true;
185
12.3k
}
186
187
}  // namespace
188
189
Status DecodeANSCodes(JxlMemoryManager* memory_manager,
190
                      const size_t num_histograms,
191
                      const size_t max_alphabet_size, BitReader* in,
192
37.7k
                      ANSCode* result) {
193
37.7k
  result->memory_manager = memory_manager;
194
37.7k
  result->degenerate_symbols.resize(num_histograms, -1);
195
37.7k
  if (result->use_prefix_code) {
196
30.9k
    JXL_ENSURE(max_alphabet_size <= 1 << PREFIX_MAX_BITS);
197
30.9k
    result->huffman_data.resize(num_histograms);
198
30.9k
    std::vector<uint16_t> alphabet_sizes(num_histograms);
199
65.4k
    for (size_t c = 0; c < num_histograms; c++) {
200
34.4k
      alphabet_sizes[c] = DecodeVarLenUint16(in) + 1;
201
34.4k
      if (alphabet_sizes[c] > max_alphabet_size) {
202
0
        return JXL_FAILURE("Alphabet size is too long: %u", alphabet_sizes[c]);
203
0
      }
204
34.4k
    }
205
64.9k
    for (size_t c = 0; c < num_histograms; c++) {
206
34.4k
      if (alphabet_sizes[c] > 1) {
207
13.1k
        if (!result->huffman_data[c].ReadFromBitStream(alphabet_sizes[c], in)) {
208
455
          if (!in->AllReadsWithinBounds()) {
209
439
            return JXL_STATUS(StatusCode::kNotEnoughBytes,
210
439
                              "Not enough bytes for huffman code");
211
439
          }
212
16
          return JXL_FAILURE("Invalid huffman tree number %" PRIuS
213
455
                             ", alphabet size %u",
214
455
                             c, alphabet_sizes[c]);
215
455
        }
216
21.2k
      } else {
217
        // 0-bit codes does not require extension tables.
218
21.2k
        result->huffman_data[c].table_.clear();
219
21.2k
        result->huffman_data[c].table_.resize(1u << kHuffmanTableBits);
220
21.2k
      }
221
8.71M
      for (const auto& h : result->huffman_data[c].table_) {
222
8.71M
        if (h.bits <= kHuffmanTableBits) {
223
8.71M
          result->UpdateMaxNumBits(c, h.value);
224
8.71M
        }
225
8.71M
      }
226
33.9k
    }
227
30.9k
  } else {
228
6.79k
    JXL_ENSURE(max_alphabet_size <= ANS_MAX_ALPHABET_SIZE);
229
6.79k
    size_t alloc_size = num_histograms * (1 << result->log_alpha_size) *
230
6.79k
                        sizeof(AliasTable::Entry);
231
6.79k
    JXL_ASSIGN_OR_RETURN(result->alias_tables,
232
6.79k
                         AlignedMemory::Create(memory_manager, alloc_size));
233
6.79k
    AliasTable::Entry* alias_tables =
234
6.79k
        result->alias_tables.address<AliasTable::Entry>();
235
19.1k
    for (size_t c = 0; c < num_histograms; ++c) {
236
12.3k
      std::vector<int32_t> counts;
237
12.3k
      if (!ReadHistogram(ANS_LOG_TAB_SIZE, &counts, in)) {
238
3
        return JXL_FAILURE("Invalid histogram bitstream.");
239
3
      }
240
12.3k
      if (counts.size() > max_alphabet_size) {
241
1
        return JXL_FAILURE("Alphabet size is too long: %" PRIuS, counts.size());
242
1
      }
243
12.4k
      while (!counts.empty() && counts.back() == 0) {
244
128
        counts.pop_back();
245
128
      }
246
172k
      for (size_t s = 0; s < counts.size(); s++) {
247
160k
        if (counts[s] != 0) {
248
95.8k
          result->UpdateMaxNumBits(c, s);
249
95.8k
        }
250
160k
      }
251
      // InitAliasTable "fixes" empty counts to contain degenerate "0" symbol.
252
12.3k
      int degenerate_symbol = counts.empty() ? 0 : (counts.size() - 1);
253
14.9k
      for (int s = 0; s < degenerate_symbol; ++s) {
254
10.1k
        if (counts[s] != 0) {
255
7.50k
          degenerate_symbol = -1;
256
7.50k
          break;
257
7.50k
        }
258
10.1k
      }
259
12.3k
      result->degenerate_symbols[c] = degenerate_symbol;
260
12.3k
      JXL_RETURN_IF_ERROR(
261
12.3k
          InitAliasTable(counts, ANS_LOG_TAB_SIZE, result->log_alpha_size,
262
12.3k
                         alias_tables + c * (1 << result->log_alpha_size)));
263
12.3k
    }
264
6.79k
  }
265
37.3k
  return true;
266
37.7k
}
267
Status DecodeUintConfig(size_t log_alpha_size, HybridUintConfig* uint_config,
268
63.6k
                        BitReader* br) {
269
63.6k
  br->Refill();
270
63.6k
  size_t split_exponent = br->ReadBits(CeilLog2Nonzero(log_alpha_size + 1));
271
63.6k
  size_t msb_in_token = 0;
272
63.6k
  size_t lsb_in_token = 0;
273
63.6k
  if (split_exponent != log_alpha_size) {
274
    // otherwise, msb/lsb don't matter.
275
56.4k
    size_t nbits = CeilLog2Nonzero(split_exponent + 1);
276
56.4k
    msb_in_token = br->ReadBits(nbits);
277
56.4k
    if (msb_in_token > split_exponent) {
278
      // This could be invalid here already and we need to check this before
279
      // we use its value to read more bits.
280
4
      return JXL_FAILURE("Invalid HybridUintConfig");
281
4
    }
282
56.4k
    nbits = CeilLog2Nonzero(split_exponent - msb_in_token + 1);
283
56.4k
    lsb_in_token = br->ReadBits(nbits);
284
56.4k
  }
285
63.6k
  if (lsb_in_token + msb_in_token > split_exponent) {
286
4
    return JXL_FAILURE("Invalid HybridUintConfig");
287
4
  }
288
63.6k
  *uint_config = HybridUintConfig(split_exponent, msb_in_token, lsb_in_token);
289
63.6k
  return true;
290
63.6k
}
291
292
Status DecodeUintConfigs(size_t log_alpha_size,
293
                         std::vector<HybridUintConfig>* uint_config,
294
37.7k
                         BitReader* br) {
295
  // TODO(veluca): RLE?
296
46.8k
  for (auto& cfg : *uint_config) {
297
46.8k
    JXL_RETURN_IF_ERROR(DecodeUintConfig(log_alpha_size, &cfg, br));
298
46.8k
  }
299
37.7k
  return true;
300
37.7k
}
301
302
114k
LZ77Params::LZ77Params() { Bundle::Init(this); }
303
153k
Status LZ77Params::VisitFields(Visitor* JXL_RESTRICT visitor) {
304
153k
  JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &enabled));
305
153k
  if (!visitor->Conditional(enabled)) return true;
306
131k
  JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(224), Val(512), Val(4096),
307
131k
                                         BitsOffset(15, 8), 224, &min_symbol));
308
131k
  JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(3), Val(4), BitsOffset(2, 5),
309
131k
                                         BitsOffset(8, 9), 3, &min_length));
310
131k
  return true;
311
131k
}
312
313
8.81M
void ANSCode::UpdateMaxNumBits(size_t ctx, size_t symbol) {
314
8.81M
  HybridUintConfig* cfg = &uint_config[ctx];
315
  // LZ77 symbols use a different uint config.
316
8.81M
  if (lz77.enabled && lz77.nonserialized_distance_context != ctx &&
317
8.81M
      symbol >= lz77.min_symbol) {
318
6.62k
    symbol -= lz77.min_symbol;
319
6.62k
    cfg = &lz77.length_uint_config;
320
6.62k
  }
321
8.81M
  size_t split_token = cfg->split_token;
322
8.81M
  size_t msb_in_token = cfg->msb_in_token;
323
8.81M
  size_t lsb_in_token = cfg->lsb_in_token;
324
8.81M
  size_t split_exponent = cfg->split_exponent;
325
8.81M
  if (symbol < split_token) {
326
7.87M
    max_num_bits = std::max(max_num_bits, split_exponent);
327
7.87M
    return;
328
7.87M
  }
329
940k
  uint32_t n_extra_bits =
330
940k
      split_exponent - (msb_in_token + lsb_in_token) +
331
940k
      ((symbol - split_token) >> (msb_in_token + lsb_in_token));
332
940k
  size_t total_bits = msb_in_token + lsb_in_token + n_extra_bits + 1;
333
940k
  max_num_bits = std::max(max_num_bits, total_bits);
334
940k
}
335
336
Status DecodeHistograms(JxlMemoryManager* memory_manager, BitReader* br,
337
                        size_t num_contexts, ANSCode* code,
338
38.4k
                        std::vector<uint8_t>* context_map, bool disallow_lz77) {
339
38.4k
  JXL_RETURN_IF_ERROR(Bundle::Read(br, &code->lz77));
340
38.4k
  if (code->lz77.enabled) {
341
16.8k
    num_contexts++;
342
16.8k
    JXL_RETURN_IF_ERROR(DecodeUintConfig(/*log_alpha_size=*/8,
343
16.8k
                                         &code->lz77.length_uint_config, br));
344
16.8k
  }
345
38.4k
  if (code->lz77.enabled && disallow_lz77) {
346
0
    return JXL_FAILURE("Using LZ77 when explicitly disallowed");
347
0
  }
348
38.4k
  size_t num_histograms = 1;
349
38.4k
  context_map->resize(num_contexts);
350
38.4k
  if (num_contexts > 1) {
351
24.7k
    JXL_RETURN_IF_ERROR(
352
24.7k
        DecodeContextMap(memory_manager, context_map, &num_histograms, br));
353
24.7k
  }
354
37.7k
  JXL_DEBUG_V(
355
37.7k
      4, "Decoded context map of size %" PRIuS " and %" PRIuS " histograms",
356
37.7k
      num_contexts, num_histograms);
357
37.7k
  code->lz77.nonserialized_distance_context = context_map->back();
358
37.7k
  code->use_prefix_code = static_cast<bool>(br->ReadFixedBits<1>());
359
37.7k
  if (code->use_prefix_code) {
360
30.9k
    code->log_alpha_size = PREFIX_MAX_BITS;
361
30.9k
  } else {
362
6.79k
    code->log_alpha_size = br->ReadFixedBits<2>() + 5;
363
6.79k
  }
364
37.7k
  code->uint_config.resize(num_histograms);
365
37.7k
  JXL_RETURN_IF_ERROR(
366
37.7k
      DecodeUintConfigs(code->log_alpha_size, &code->uint_config, br));
367
37.7k
  const size_t max_alphabet_size = 1 << code->log_alpha_size;
368
37.7k
  JXL_RETURN_IF_ERROR(DecodeANSCodes(memory_manager, num_histograms,
369
37.7k
                                     max_alphabet_size, br, code));
370
37.3k
  return true;
371
37.7k
}
372
373
StatusOr<ANSSymbolReader> ANSSymbolReader::Create(const ANSCode* code,
374
                                                  BitReader* JXL_RESTRICT br,
375
49.9k
                                                  size_t distance_multiplier) {
376
49.9k
  AlignedMemory lz77_window_storage;
377
49.9k
  if (code->lz77.enabled) {
378
19.5k
    JxlMemoryManager* memory_manager = code->memory_manager;
379
19.5k
    JXL_ASSIGN_OR_RETURN(
380
19.5k
        lz77_window_storage,
381
19.5k
        AlignedMemory::Create(memory_manager, kWindowSize * sizeof(uint32_t)));
382
19.5k
  }
383
49.9k
  return ANSSymbolReader(code, br, distance_multiplier,
384
49.9k
                         std::move(lz77_window_storage));
385
49.9k
}
386
387
ANSSymbolReader::ANSSymbolReader(const ANSCode* code,
388
                                 BitReader* JXL_RESTRICT br,
389
                                 size_t distance_multiplier,
390
                                 AlignedMemory&& lz77_window_storage)
391
49.9k
    : alias_tables_(code->alias_tables.address<AliasTable::Entry>()),
392
49.9k
      huffman_data_(code->huffman_data.data()),
393
49.9k
      use_prefix_code_(code->use_prefix_code),
394
49.9k
      configs(code->uint_config.data()),
395
49.9k
      lz77_window_storage_(std::move(lz77_window_storage)) {
396
49.9k
  if (!use_prefix_code_) {
397
7.54k
    state_ = static_cast<uint32_t>(br->ReadFixedBits<32>());
398
7.54k
    log_alpha_size_ = code->log_alpha_size;
399
7.54k
    log_entry_size_ = ANS_LOG_TAB_SIZE - code->log_alpha_size;
400
7.54k
    entry_size_minus_1_ = (1 << log_entry_size_) - 1;
401
42.4k
  } else {
402
42.4k
    state_ = (ANS_SIGNATURE << 16u);
403
42.4k
  }
404
49.9k
  if (!code->lz77.enabled) return;
405
19.5k
  lz77_window_ = lz77_window_storage_.address<uint32_t>();
406
19.5k
  lz77_ctx_ = code->lz77.nonserialized_distance_context;
407
19.5k
  lz77_length_uint_ = code->lz77.length_uint_config;
408
19.5k
  lz77_threshold_ = code->lz77.min_symbol;
409
19.5k
  lz77_min_length_ = code->lz77.min_length;
410
19.5k
  num_special_distances_ = distance_multiplier == 0 ? 0 : kNumSpecialDistances;
411
372k
  for (size_t i = 0; i < num_special_distances_; i++) {
412
352k
    special_distances_[i] = SpecialDistance(i, distance_multiplier);
413
352k
  }
414
19.5k
}
415
416
}  // namespace jxl