Coverage Report

Created: 2025-08-12 07:37

/src/libjxl/lib/jxl/dec_ans.h
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#ifndef LIB_JXL_DEC_ANS_H_
7
#define LIB_JXL_DEC_ANS_H_
8
9
// Library to decode the ANS population counts from the bit-stream and build a
10
// decoding table from them.
11
12
#include <jxl/memory_manager.h>
13
#include <jxl/types.h>
14
15
#include <algorithm>
16
#include <cstddef>
17
#include <cstdint>
18
#include <cstring>
19
#include <vector>
20
21
#include "lib/jxl/ans_common.h"
22
#include "lib/jxl/ans_params.h"
23
#include "lib/jxl/base/bits.h"
24
#include "lib/jxl/base/compiler_specific.h"
25
#include "lib/jxl/base/status.h"
26
#include "lib/jxl/dec_bit_reader.h"
27
#include "lib/jxl/dec_huffman.h"
28
#include "lib/jxl/field_encodings.h"
29
#include "lib/jxl/memory_manager_internal.h"
30
31
namespace jxl {
32
33
class ANSSymbolReader;
34
35
// Experiments show that best performance is typically achieved for a
36
// split-exponent of 3 or 4. Trend seems to be that '4' is better
37
// for large-ish pictures, and '3' better for rather small-ish pictures.
38
// This is plausible - the more special symbols we have, the better
39
// statistics we need to get a benefit out of them.
40
41
// Our hybrid-encoding scheme has dedicated tokens for the smallest
42
// (1 << split_exponents) numbers, and for the rest
43
// encodes (number of bits) + (msb_in_token sub-leading binary digits) +
44
// (lsb_in_token lowest binary digits) in the token, with the remaining bits
45
// then being encoded as data.
46
//
47
// Example with split_exponent = 4, msb_in_token = 2, lsb_in_token = 0.
48
//
49
// Numbers N in [0 .. 15]:
50
//   These get represented as (token=N, bits='').
51
// Numbers N >= 16:
52
//   If n is such that 2**n <= N < 2**(n+1),
53
//   and m = N - 2**n is the 'mantissa',
54
//   these get represented as:
55
// (token=split_token +
56
//        ((n - split_exponent) * 4) +
57
//        (m >> (n - msb_in_token)),
58
//  bits=m & (1 << (n - msb_in_token)) - 1)
59
// Specifically, we would get:
60
// N = 0 - 15:          (token=N, nbits=0, bits='')
61
// N = 16 (10000):      (token=16, nbits=2, bits='00')
62
// N = 17 (10001):      (token=16, nbits=2, bits='01')
63
// N = 20 (10100):      (token=17, nbits=2, bits='00')
64
// N = 24 (11000):      (token=18, nbits=2, bits='00')
65
// N = 28 (11100):      (token=19, nbits=2, bits='00')
66
// N = 32 (100000):     (token=20, nbits=3, bits='000')
67
// N = 65535:           (token=63, nbits=13, bits='1111111111111')
68
struct HybridUintConfig {
69
  uint32_t split_exponent;
70
  uint32_t split_token;
71
  uint32_t msb_in_token;
72
  uint32_t lsb_in_token;
73
  JXL_INLINE void Encode(uint32_t value, uint32_t* JXL_RESTRICT token,
74
                         uint32_t* JXL_RESTRICT nbits,
75
66.1M
                         uint32_t* JXL_RESTRICT bits) const {
76
66.1M
    if (value < split_token) {
77
58.3M
      *token = value;
78
58.3M
      *nbits = 0;
79
58.3M
      *bits = 0;
80
58.3M
    } else {
81
7.78M
      uint32_t n = FloorLog2Nonzero(value);
82
7.78M
      uint32_t m = value - (1 << n);
83
7.78M
      *token = split_token +
84
7.78M
               ((n - split_exponent) << (msb_in_token + lsb_in_token)) +
85
7.78M
               ((m >> (n - msb_in_token)) << lsb_in_token) +
86
7.78M
               (m & ((1 << lsb_in_token) - 1));
87
7.78M
      *nbits = n - msb_in_token - lsb_in_token;
88
7.78M
      *bits = (value >> lsb_in_token) & ((1UL << *nbits) - 1);
89
7.78M
    }
90
66.1M
  }
91
92
37.1k
  JXL_INLINE uint32_t LsbMask() const { return (1 << lsb_in_token) - 1; }
93
94
  explicit HybridUintConfig(uint32_t split_exponent = 4,
95
                            uint32_t msb_in_token = 2,
96
                            uint32_t lsb_in_token = 0)
97
1.36M
      : split_exponent(split_exponent),
98
1.36M
        split_token(1 << split_exponent),
99
1.36M
        msb_in_token(msb_in_token),
100
1.36M
        lsb_in_token(lsb_in_token) {
101
1.36M
    JXL_DASSERT(split_exponent >= msb_in_token + lsb_in_token);
102
1.36M
  }
103
};
104
105
struct LZ77Params : public Fields {
106
  LZ77Params();
107
  JXL_FIELDS_NAME(LZ77Params)
108
  Status VisitFields(Visitor* JXL_RESTRICT visitor) override;
109
  bool enabled;
110
111
  // Symbols above min_symbol use a special hybrid uint encoding and
112
  // represent a length, to be added to min_length.
113
  uint32_t min_symbol;
114
  uint32_t min_length;
115
116
  // Not serialized by VisitFields.
117
  HybridUintConfig length_uint_config{0, 0, 0};
118
119
  size_t nonserialized_distance_context;
120
};
121
122
// Number of entries in the LZ77 window (a power of two, so positions can be
// wrapped with a bit mask).
static constexpr size_t kWindowSize = size_t{1} << 20;
// Number of rows in kSpecialDistances.
static constexpr size_t kNumSpecialDistances = 120;
// Table of special distance codes from WebP lossless.
// Each row is a pair {a, b}; SpecialDistance() combines it as
// a + multiplier * b.
static constexpr int8_t kSpecialDistances[kNumSpecialDistances][2] = {
    {0, 1},  {1, 0},  {1, 1},  {-1, 1}, {0, 2},  {2, 0},  {1, 2},  {-1, 2},
    {2, 1},  {-2, 1}, {2, 2},  {-2, 2}, {0, 3},  {3, 0},  {1, 3},  {-1, 3},
    {3, 1},  {-3, 1}, {2, 3},  {-2, 3}, {3, 2},  {-3, 2}, {0, 4},  {4, 0},
    {1, 4},  {-1, 4}, {4, 1},  {-4, 1}, {3, 3},  {-3, 3}, {2, 4},  {-2, 4},
    {4, 2},  {-4, 2}, {0, 5},  {3, 4},  {-3, 4}, {4, 3},  {-4, 3}, {5, 0},
    {1, 5},  {-1, 5}, {5, 1},  {-5, 1}, {2, 5},  {-2, 5}, {5, 2},  {-5, 2},
    {4, 4},  {-4, 4}, {3, 5},  {-3, 5}, {5, 3},  {-5, 3}, {0, 6},  {6, 0},
    {1, 6},  {-1, 6}, {6, 1},  {-6, 1}, {2, 6},  {-2, 6}, {6, 2},  {-6, 2},
    {4, 5},  {-4, 5}, {5, 4},  {-5, 4}, {3, 6},  {-3, 6}, {6, 3},  {-6, 3},
    {0, 7},  {7, 0},  {1, 7},  {-1, 7}, {5, 5},  {-5, 5}, {7, 1},  {-7, 1},
    {4, 6},  {-4, 6}, {6, 4},  {-6, 4}, {2, 7},  {-2, 7}, {7, 2},  {-7, 2},
    {3, 7},  {-3, 7}, {7, 3},  {-7, 3}, {5, 6},  {-5, 6}, {6, 5},  {-6, 5},
    {8, 0},  {4, 7},  {-4, 7}, {7, 4},  {-7, 4}, {8, 1},  {8, 2},  {6, 6},
    {-6, 6}, {8, 3},  {5, 7},  {-5, 7}, {7, 5},  {-7, 5}, {8, 4},  {6, 7},
    {-6, 7}, {7, 6},  {-7, 6}, {8, 5},  {7, 7},  {-7, 7}, {8, 6},  {8, 7}};
141
405k
// Turns special distance code `index` into a linear distance:
// kSpecialDistances[index][0] + multiplier * kSpecialDistances[index][1],
// clamped from below to 1.
// `index` must be smaller than kNumSpecialDistances; it is not checked here.
// NOTE(review): `multiplier` is presumably the row width of the data being
// decoded - confirm at the call sites.
static JXL_INLINE int SpecialDistance(size_t index, int multiplier) {
  const int dist = kSpecialDistances[index][0] +
                   multiplier * kSpecialDistances[index][1];
  return (dist > 1) ? dist : 1;
}
146
147
struct ANSCode {
148
  AlignedMemory alias_tables;
149
  std::vector<HuffmanDecodingData> huffman_data;
150
  std::vector<HybridUintConfig> uint_config;
151
  std::vector<int> degenerate_symbols;
152
  bool use_prefix_code;
153
  uint8_t log_alpha_size;  // for ANS.
154
  LZ77Params lz77;
155
  // Maximum number of bits necessary to represent the result of a
156
  // ReadHybridUint call done with this ANSCode.
157
  size_t max_num_bits = 0;
158
  JxlMemoryManager* memory_manager;
159
  void UpdateMaxNumBits(size_t ctx, size_t symbol);
160
};
161
162
class ANSSymbolReader {
163
 public:
164
  // Invalid symbol reader, to be overwritten.
165
56.3k
  ANSSymbolReader() = default;
166
  static StatusOr<ANSSymbolReader> Create(const ANSCode* code,
167
                                          BitReader* JXL_RESTRICT br,
168
                                          size_t distance_multiplier = 0);
169
170
  JXL_INLINE size_t ReadSymbolANSWithoutRefill(const size_t histo_idx,
171
20.8M
                                               BitReader* JXL_RESTRICT br) {
172
20.8M
    const uint32_t res = state_ & (ANS_TAB_SIZE - 1u);
173
174
20.8M
    const AliasTable::Entry* table =
175
20.8M
        &alias_tables_[histo_idx << log_alpha_size_];
176
20.8M
    const AliasTable::Symbol symbol =
177
20.8M
        AliasTable::Lookup(table, res, log_entry_size_, entry_size_minus_1_);
178
20.8M
    state_ = symbol.freq * (state_ >> ANS_LOG_TAB_SIZE) + symbol.offset;
179
180
20.8M
#if JXL_TRUE
181
    // Branchless version is about equally fast on SKX.
182
20.8M
    const uint32_t new_state =
183
20.8M
        (state_ << 16u) | static_cast<uint32_t>(br->PeekFixedBits<16>());
184
20.8M
    const bool normalize = state_ < (1u << 16u);
185
20.8M
    state_ = normalize ? new_state : state_;
186
20.8M
    br->Consume(normalize ? 16 : 0);
187
#else
188
    if (JXL_UNLIKELY(state_ < (1u << 16u))) {
189
      state_ = (state_ << 16u) | br->PeekFixedBits<16>();
190
      br->Consume(16);
191
    }
192
#endif
193
20.8M
    const uint32_t next_res = state_ & (ANS_TAB_SIZE - 1u);
194
20.8M
    AliasTable::Prefetch(table, next_res, log_entry_size_);
195
196
20.8M
    return symbol.value;
197
20.8M
  }
198
199
  JXL_INLINE size_t ReadSymbolHuffWithoutRefill(const size_t histo_idx,
200
162M
                                                BitReader* JXL_RESTRICT br) {
201
162M
    return huffman_data_[histo_idx].ReadSymbol(br);
202
162M
  }
203
204
  JXL_INLINE size_t ReadSymbolWithoutRefill(const size_t histo_idx,
205
183M
                                            BitReader* JXL_RESTRICT br) {
206
    // TODO(veluca): hoist if in hotter loops.
207
183M
    if (JXL_UNLIKELY(use_prefix_code_)) {
208
162M
      return ReadSymbolHuffWithoutRefill(histo_idx, br);
209
162M
    }
210
20.8M
    return ReadSymbolANSWithoutRefill(histo_idx, br);
211
183M
  }
212
213
  JXL_INLINE size_t ReadSymbol(const size_t histo_idx,
214
0
                               BitReader* JXL_RESTRICT br) {
215
0
    br->Refill();
216
0
    return ReadSymbolWithoutRefill(histo_idx, br);
217
0
  }
218
219
#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION
220
88.7k
  bool CheckANSFinalState() const { return true; }
221
#else
222
  bool CheckANSFinalState() const { return state_ == (ANS_SIGNATURE << 16u); }
223
#endif
224
225
  template <typename BitReader>
226
  static JXL_INLINE uint32_t ReadHybridUintConfig(
227
183M
      const HybridUintConfig& config, size_t token, BitReader* br) {
228
183M
    size_t split_token = config.split_token;
229
183M
    size_t msb_in_token = config.msb_in_token;
230
183M
    size_t lsb_in_token = config.lsb_in_token;
231
183M
    size_t split_exponent = config.split_exponent;
232
    // Fast-track version of hybrid integer decoding.
233
183M
    if (token < split_token) return token;
234
3.16M
    uint32_t nbits = split_exponent - (msb_in_token + lsb_in_token) +
235
3.16M
                     ((token - split_token) >> (msb_in_token + lsb_in_token));
236
    // Max amount of bits for ReadBits is 32 and max valid left shift is 29
237
    // bits. However, for speed no error is propagated here, instead limit the
238
    // nbits size. If nbits > 29, the code stream is invalid, but no error is
239
    // returned.
240
    // Note that in most cases we will emit an error if the histogram allows
241
    // representing numbers that would cause invalid shifts, but we need to
242
    // keep this check as when LZ77 is enabled it might make sense to have an
243
    // histogram that could in principle cause invalid shifts.
244
3.16M
    nbits &= 31u;
245
3.16M
    uint32_t low = token & ((1 << lsb_in_token) - 1);
246
3.16M
    token >>= lsb_in_token;
247
3.16M
    const size_t bits = br->PeekBits(nbits);
248
3.16M
    br->Consume(nbits);
249
3.16M
    size_t ret = (((((1 << msb_in_token) | (token & ((1 << msb_in_token) - 1)))
250
3.16M
                    << nbits) |
251
3.16M
                   bits)
252
3.16M
                  << lsb_in_token) |
253
3.16M
                 low;
254
    // TODO(eustas): mark BitReader as unhealthy if nbits > 29 or ret does not
255
    //               fit uint32_t
256
3.16M
    return static_cast<uint32_t>(ret);
257
183M
  }
258
259
  // Takes a *clustered* idx. Can only use if HuffRleOnly() is true.
260
  JXL_INLINE void ReadHybridUintClusteredHuffRleOnly(size_t ctx,
261
                                                     BitReader* JXL_RESTRICT br,
262
                                                     uint32_t* value,
263
48
                                                     uint32_t* run) {
264
48
    JXL_DASSERT(IsHuffRleOnly());
265
48
    br->Refill();  // covers ReadSymbolWithoutRefill + PeekBits
266
48
    size_t token = ReadSymbolHuffWithoutRefill(ctx, br);
267
48
    if (JXL_UNLIKELY(token >= lz77_threshold_)) {
268
0
      *run =
269
0
          ReadHybridUintConfig(lz77_length_uint_, token - lz77_threshold_, br) +
270
0
          lz77_min_length_ - 1;
271
0
      return;
272
0
    }
273
48
    *value = ReadHybridUintConfig(configs[ctx], token, br);
274
48
  }
275
4
  bool IsHuffRleOnly() const {
276
4
    if (lz77_window_ == nullptr) return false;
277
4
    if (!use_prefix_code_) return false;
278
36
    for (size_t i = 0; i < kHuffmanTableBits; i++) {
279
32
      if (huffman_data_[lz77_ctx_].table_[i].bits) return false;
280
32
      if (huffman_data_[lz77_ctx_].table_[i].value != 1) return false;
281
32
    }
282
4
    if (configs[lz77_ctx_].split_token > 1) return false;
283
4
    return true;
284
4
  }
285
868k
  bool UsesLZ77() { return lz77_window_ != nullptr; }
286
287
  // Takes a *clustered* idx. Inlined, for use in hot paths.
288
  template <bool uses_lz77>
289
  JXL_INLINE size_t ReadHybridUintClusteredInlined(size_t ctx,
290
199M
                                                   BitReader* JXL_RESTRICT br) {
291
199M
    if (uses_lz77) {
292
97.4M
      if (JXL_UNLIKELY(num_to_copy_ > 0)) {
293
16.7M
        size_t ret = lz77_window_[(copy_pos_++) & kWindowMask];
294
16.7M
        num_to_copy_--;
295
16.7M
        lz77_window_[(num_decoded_++) & kWindowMask] = ret;
296
16.7M
        return ret;
297
16.7M
      }
298
97.4M
    }
299
300
182M
    br->Refill();  // covers ReadSymbolWithoutRefill + PeekBits
301
182M
    size_t token = ReadSymbolWithoutRefill(ctx, br);
302
182M
    if (uses_lz77) {
303
80.7M
      if (JXL_UNLIKELY(token >= lz77_threshold_)) {
304
6.71k
        num_to_copy_ = ReadHybridUintConfig(lz77_length_uint_,
305
6.71k
                                            token - lz77_threshold_, br) +
306
6.71k
                       lz77_min_length_;
307
6.71k
        br->Refill();  // covers ReadSymbolWithoutRefill + PeekBits
308
        // Distance code.
309
6.71k
        size_t d_token = ReadSymbolWithoutRefill(lz77_ctx_, br);
310
6.71k
        size_t distance = ReadHybridUintConfig(configs[lz77_ctx_], d_token, br);
311
6.71k
        if (JXL_LIKELY(distance < num_special_distances_)) {
312
419
          distance = special_distances_[distance];
313
6.29k
        } else {
314
6.29k
          distance = distance + 1 - num_special_distances_;
315
6.29k
        }
316
6.71k
        if (JXL_UNLIKELY(distance > num_decoded_)) {
317
2.93k
          distance = num_decoded_;
318
2.93k
        }
319
6.71k
        if (JXL_UNLIKELY(distance > kWindowSize)) {
320
50
          distance = kWindowSize;
321
50
        }
322
6.71k
        copy_pos_ = num_decoded_ - distance;
323
6.71k
        if (JXL_UNLIKELY(distance == 0)) {
324
1.00k
          JXL_DASSERT(lz77_window_ != nullptr);
325
          // distance 0 -> num_decoded_ == copy_pos_ == 0
326
1.00k
          size_t to_fill = std::min<size_t>(num_to_copy_, kWindowSize);
327
1.00k
          memset(lz77_window_, 0, to_fill * sizeof(lz77_window_[0]));
328
1.00k
        }
329
        // TODO(eustas): overflow; mark BitReader as unhealthy
330
6.71k
        if (num_to_copy_ < lz77_min_length_) return 0;
331
        // the code below is the same as doing this:
332
        //        return ReadHybridUintClustered<uses_lz77>(ctx, br);
333
        // but gcc doesn't like recursive inlining
334
335
6.23k
        size_t ret = lz77_window_[(copy_pos_++) & kWindowMask];
336
6.23k
        num_to_copy_--;
337
6.23k
        lz77_window_[(num_decoded_++) & kWindowMask] = ret;
338
6.23k
        return ret;
339
6.71k
      }
340
80.7M
    }
341
182M
    size_t ret = ReadHybridUintConfig(configs[ctx], token, br);
342
182M
    if (uses_lz77 && lz77_window_)
343
62.6M
      lz77_window_[(num_decoded_++) & kWindowMask] = ret;
344
182M
    return ret;
345
182M
  }
unsigned long jxl::ANSSymbolReader::ReadHybridUintClusteredInlined<true>(unsigned long, jxl::BitReader*)
Line
Count
Source
290
97.4M
                                                   BitReader* JXL_RESTRICT br) {
291
97.4M
    if (uses_lz77) {
292
97.4M
      if (JXL_UNLIKELY(num_to_copy_ > 0)) {
293
16.7M
        size_t ret = lz77_window_[(copy_pos_++) & kWindowMask];
294
16.7M
        num_to_copy_--;
295
16.7M
        lz77_window_[(num_decoded_++) & kWindowMask] = ret;
296
16.7M
        return ret;
297
16.7M
      }
298
97.4M
    }
299
300
80.7M
    br->Refill();  // covers ReadSymbolWithoutRefill + PeekBits
301
80.7M
    size_t token = ReadSymbolWithoutRefill(ctx, br);
302
80.7M
    if (uses_lz77) {
303
80.7M
      if (JXL_UNLIKELY(token >= lz77_threshold_)) {
304
6.71k
        num_to_copy_ = ReadHybridUintConfig(lz77_length_uint_,
305
6.71k
                                            token - lz77_threshold_, br) +
306
6.71k
                       lz77_min_length_;
307
6.71k
        br->Refill();  // covers ReadSymbolWithoutRefill + PeekBits
308
        // Distance code.
309
6.71k
        size_t d_token = ReadSymbolWithoutRefill(lz77_ctx_, br);
310
6.71k
        size_t distance = ReadHybridUintConfig(configs[lz77_ctx_], d_token, br);
311
6.71k
        if (JXL_LIKELY(distance < num_special_distances_)) {
312
419
          distance = special_distances_[distance];
313
6.29k
        } else {
314
6.29k
          distance = distance + 1 - num_special_distances_;
315
6.29k
        }
316
6.71k
        if (JXL_UNLIKELY(distance > num_decoded_)) {
317
2.93k
          distance = num_decoded_;
318
2.93k
        }
319
6.71k
        if (JXL_UNLIKELY(distance > kWindowSize)) {
320
50
          distance = kWindowSize;
321
50
        }
322
6.71k
        copy_pos_ = num_decoded_ - distance;
323
6.71k
        if (JXL_UNLIKELY(distance == 0)) {
324
1.00k
          JXL_DASSERT(lz77_window_ != nullptr);
325
          // distance 0 -> num_decoded_ == copy_pos_ == 0
326
1.00k
          size_t to_fill = std::min<size_t>(num_to_copy_, kWindowSize);
327
1.00k
          memset(lz77_window_, 0, to_fill * sizeof(lz77_window_[0]));
328
1.00k
        }
329
        // TODO(eustas): overflow; mark BitReader as unhealthy
330
6.71k
        if (num_to_copy_ < lz77_min_length_) return 0;
331
        // the code below is the same as doing this:
332
        //        return ReadHybridUintClustered<uses_lz77>(ctx, br);
333
        // but gcc doesn't like recursive inlining
334
335
6.23k
        size_t ret = lz77_window_[(copy_pos_++) & kWindowMask];
336
6.23k
        num_to_copy_--;
337
6.23k
        lz77_window_[(num_decoded_++) & kWindowMask] = ret;
338
6.23k
        return ret;
339
6.71k
      }
340
80.7M
    }
341
80.7M
    size_t ret = ReadHybridUintConfig(configs[ctx], token, br);
342
80.7M
    if (uses_lz77 && lz77_window_)
343
62.6M
      lz77_window_[(num_decoded_++) & kWindowMask] = ret;
344
80.7M
    return ret;
345
80.7M
  }
unsigned long jxl::ANSSymbolReader::ReadHybridUintClusteredInlined<false>(unsigned long, jxl::BitReader*)
Line
Count
Source
290
102M
                                                   BitReader* JXL_RESTRICT br) {
291
102M
    if (uses_lz77) {
292
0
      if (JXL_UNLIKELY(num_to_copy_ > 0)) {
293
0
        size_t ret = lz77_window_[(copy_pos_++) & kWindowMask];
294
0
        num_to_copy_--;
295
0
        lz77_window_[(num_decoded_++) & kWindowMask] = ret;
296
0
        return ret;
297
0
      }
298
0
    }
299
300
102M
    br->Refill();  // covers ReadSymbolWithoutRefill + PeekBits
301
102M
    size_t token = ReadSymbolWithoutRefill(ctx, br);
302
102M
    if (uses_lz77) {
303
0
      if (JXL_UNLIKELY(token >= lz77_threshold_)) {
304
0
        num_to_copy_ = ReadHybridUintConfig(lz77_length_uint_,
305
0
                                            token - lz77_threshold_, br) +
306
0
                       lz77_min_length_;
307
0
        br->Refill();  // covers ReadSymbolWithoutRefill + PeekBits
308
        // Distance code.
309
0
        size_t d_token = ReadSymbolWithoutRefill(lz77_ctx_, br);
310
0
        size_t distance = ReadHybridUintConfig(configs[lz77_ctx_], d_token, br);
311
0
        if (JXL_LIKELY(distance < num_special_distances_)) {
312
0
          distance = special_distances_[distance];
313
0
        } else {
314
0
          distance = distance + 1 - num_special_distances_;
315
0
        }
316
0
        if (JXL_UNLIKELY(distance > num_decoded_)) {
317
0
          distance = num_decoded_;
318
0
        }
319
0
        if (JXL_UNLIKELY(distance > kWindowSize)) {
320
0
          distance = kWindowSize;
321
0
        }
322
0
        copy_pos_ = num_decoded_ - distance;
323
0
        if (JXL_UNLIKELY(distance == 0)) {
324
0
          JXL_DASSERT(lz77_window_ != nullptr);
325
          // distance 0 -> num_decoded_ == copy_pos_ == 0
326
0
          size_t to_fill = std::min<size_t>(num_to_copy_, kWindowSize);
327
0
          memset(lz77_window_, 0, to_fill * sizeof(lz77_window_[0]));
328
0
        }
329
        // TODO(eustas): overflow; mark BitReader as unhealthy
330
0
        if (num_to_copy_ < lz77_min_length_) return 0;
331
        // the code below is the same as doing this:
332
        //        return ReadHybridUintClustered<uses_lz77>(ctx, br);
333
        // but gcc doesn't like recursive inlining
334
335
0
        size_t ret = lz77_window_[(copy_pos_++) & kWindowMask];
336
0
        num_to_copy_--;
337
0
        lz77_window_[(num_decoded_++) & kWindowMask] = ret;
338
0
        return ret;
339
0
      }
340
0
    }
341
102M
    size_t ret = ReadHybridUintConfig(configs[ctx], token, br);
342
102M
    if (uses_lz77 && lz77_window_)
343
0
      lz77_window_[(num_decoded_++) & kWindowMask] = ret;
344
102M
    return ret;
345
102M
  }
346
347
  // same but not inlined
348
  template <bool uses_lz77>
349
36.6M
  size_t ReadHybridUintClustered(size_t ctx, BitReader* JXL_RESTRICT br) {
350
36.6M
    return ReadHybridUintClusteredInlined<uses_lz77>(ctx, br);
351
36.6M
  }
unsigned long jxl::ANSSymbolReader::ReadHybridUintClustered<true>(unsigned long, jxl::BitReader*)
Line
Count
Source
349
35.1M
  size_t ReadHybridUintClustered(size_t ctx, BitReader* JXL_RESTRICT br) {
350
35.1M
    return ReadHybridUintClusteredInlined<uses_lz77>(ctx, br);
351
35.1M
  }
unsigned long jxl::ANSSymbolReader::ReadHybridUintClustered<false>(unsigned long, jxl::BitReader*)
Line
Count
Source
349
1.45M
  size_t ReadHybridUintClustered(size_t ctx, BitReader* JXL_RESTRICT br) {
350
1.45M
    return ReadHybridUintClusteredInlined<uses_lz77>(ctx, br);
351
1.45M
  }
352
353
  // inlined only in the no-lz77 case
354
  template <bool uses_lz77>
355
  JXL_INLINE size_t
356
14.8M
  ReadHybridUintClusteredMaybeInlined(size_t ctx, BitReader* JXL_RESTRICT br) {
357
14.8M
    if (uses_lz77) {
358
13.6M
      return ReadHybridUintClustered<uses_lz77>(ctx, br);
359
13.6M
    } else {
360
1.19M
      return ReadHybridUintClusteredInlined<uses_lz77>(ctx, br);
361
1.19M
    }
362
14.8M
  }
unsigned long jxl::ANSSymbolReader::ReadHybridUintClusteredMaybeInlined<true>(unsigned long, jxl::BitReader*)
Line
Count
Source
356
13.6M
  ReadHybridUintClusteredMaybeInlined(size_t ctx, BitReader* JXL_RESTRICT br) {
357
13.6M
    if (uses_lz77) {
358
13.6M
      return ReadHybridUintClustered<uses_lz77>(ctx, br);
359
13.6M
    } else {
360
0
      return ReadHybridUintClusteredInlined<uses_lz77>(ctx, br);
361
0
    }
362
13.6M
  }
unsigned long jxl::ANSSymbolReader::ReadHybridUintClusteredMaybeInlined<false>(unsigned long, jxl::BitReader*)
Line
Count
Source
356
1.19M
  ReadHybridUintClusteredMaybeInlined(size_t ctx, BitReader* JXL_RESTRICT br) {
357
1.19M
    if (uses_lz77) {
358
0
      return ReadHybridUintClustered<uses_lz77>(ctx, br);
359
1.19M
    } else {
360
1.19M
      return ReadHybridUintClusteredInlined<uses_lz77>(ctx, br);
361
1.19M
    }
362
1.19M
  }
363
364
  // inlined, for use in hot paths
365
  template <bool uses_lz77>
366
  JXL_INLINE size_t
367
  ReadHybridUintInlined(size_t ctx, BitReader* JXL_RESTRICT br,
368
10.9M
                        const std::vector<uint8_t>& context_map) {
369
10.9M
    return ReadHybridUintClustered<uses_lz77>(context_map[ctx], br);
370
10.9M
  }
unsigned long jxl::ANSSymbolReader::ReadHybridUintInlined<true>(unsigned long, jxl::BitReader*, std::__1::vector<unsigned char, std::__1::allocator<unsigned char> > const&)
Line
Count
Source
368
9.95M
                        const std::vector<uint8_t>& context_map) {
369
9.95M
    return ReadHybridUintClustered<uses_lz77>(context_map[ctx], br);
370
9.95M
  }
unsigned long jxl::ANSSymbolReader::ReadHybridUintInlined<false>(unsigned long, jxl::BitReader*, std::__1::vector<unsigned char, std::__1::allocator<unsigned char> > const&)
Line
Count
Source
368
1.01M
                        const std::vector<uint8_t>& context_map) {
369
1.01M
    return ReadHybridUintClustered<uses_lz77>(context_map[ctx], br);
370
1.01M
  }
371
372
  // not inlined, for use in non-hot paths
373
  size_t ReadHybridUint(size_t ctx, BitReader* JXL_RESTRICT br,
374
11.3M
                        const std::vector<uint8_t>& context_map) {
375
11.3M
    return ReadHybridUintClustered</*uses_lz77=*/true>(context_map[ctx], br);
376
11.3M
  }
377
378
  // ctx is a *clustered* context!
379
  // This function will modify the ANS state as if `count` symbols have been
380
  // decoded.
381
217k
  bool IsSingleValueAndAdvance(size_t ctx, uint32_t* value, size_t count) {
382
    // TODO(veluca): No optimization for Huffman mode yet.
383
217k
    if (use_prefix_code_) return false;
384
    // TODO(eustas): Check if we could deal with copy tail as well.
385
101k
    if (num_to_copy_ != 0) return false;
386
    // TODO(eustas): propagate "degenerate_symbol" to simplify this method.
387
100k
    const uint32_t res = state_ & (ANS_TAB_SIZE - 1u);
388
100k
    const AliasTable::Entry* table = &alias_tables_[ctx << log_alpha_size_];
389
100k
    AliasTable::Symbol symbol =
390
100k
        AliasTable::Lookup(table, res, log_entry_size_, entry_size_minus_1_);
391
100k
    if (symbol.freq != ANS_TAB_SIZE) return false;
392
88.2k
    if (configs[ctx].split_token <= symbol.value) return false;
393
87.9k
    if (symbol.value >= lz77_threshold_) return false;
394
87.9k
    *value = symbol.value;
395
87.9k
    if (lz77_window_) {
396
18.9M
      for (size_t i = 0; i < count; i++) {
397
18.9M
        lz77_window_[(num_decoded_++) & kWindowMask] = symbol.value;
398
18.9M
      }
399
8.38k
    }
400
87.9k
    return true;
401
87.9k
  }
402
403
  static constexpr size_t kMaxCheckpointInterval = 512;
404
  struct Checkpoint {
405
    uint32_t state;
406
    uint32_t num_to_copy;
407
    uint32_t copy_pos;
408
    uint32_t num_decoded;
409
    uint32_t lz77_window[kMaxCheckpointInterval];
410
  };
411
11.8k
  void Save(Checkpoint* checkpoint) {
412
11.8k
    checkpoint->state = state_;
413
11.8k
    checkpoint->num_decoded = num_decoded_;
414
11.8k
    checkpoint->num_to_copy = num_to_copy_;
415
11.8k
    checkpoint->copy_pos = copy_pos_;
416
11.8k
    if (lz77_window_) {
417
3.57k
      size_t win_start = num_decoded_ & kWindowMask;
418
3.57k
      size_t win_end = (num_decoded_ + kMaxCheckpointInterval) & kWindowMask;
419
3.57k
      if (win_end > win_start) {
420
3.57k
        memcpy(checkpoint->lz77_window, lz77_window_ + win_start,
421
3.57k
               (win_end - win_start) * sizeof(*lz77_window_));
422
3.57k
      } else {
423
0
        memcpy(checkpoint->lz77_window, lz77_window_ + win_start,
424
0
               (kWindowSize - win_start) * sizeof(*lz77_window_));
425
0
        memcpy(checkpoint->lz77_window + (kWindowSize - win_start),
426
0
               lz77_window_, win_end * sizeof(*lz77_window_));
427
0
      }
428
3.57k
    }
429
11.8k
  }
430
1.55k
  void Restore(const Checkpoint& checkpoint) {
431
1.55k
    state_ = checkpoint.state;
432
1.55k
    JXL_DASSERT(num_decoded_ <=
433
1.55k
                checkpoint.num_decoded + kMaxCheckpointInterval);
434
1.55k
    num_decoded_ = checkpoint.num_decoded;
435
1.55k
    num_to_copy_ = checkpoint.num_to_copy;
436
1.55k
    copy_pos_ = checkpoint.copy_pos;
437
1.55k
    if (lz77_window_) {
438
855
      size_t win_start = num_decoded_ & kWindowMask;
439
855
      size_t win_end = (num_decoded_ + kMaxCheckpointInterval) & kWindowMask;
440
855
      if (win_end > win_start) {
441
855
        memcpy(lz77_window_ + win_start, checkpoint.lz77_window,
442
855
               (win_end - win_start) * sizeof(*lz77_window_));
443
855
      } else {
444
0
        memcpy(lz77_window_ + win_start, checkpoint.lz77_window,
445
0
               (kWindowSize - win_start) * sizeof(*lz77_window_));
446
0
        memcpy(lz77_window_, checkpoint.lz77_window + (kWindowSize - win_start),
447
0
               win_end * sizeof(*lz77_window_));
448
0
      }
449
855
    }
450
1.55k
  }
451
452
 private:
453
  ANSSymbolReader(const ANSCode* code, BitReader* JXL_RESTRICT br,
454
                  size_t distance_multiplier,
455
                  AlignedMemory&& lz77_window_storage);
456
457
  const AliasTable::Entry* JXL_RESTRICT alias_tables_;  // not owned
458
  const HuffmanDecodingData* huffman_data_;
459
  bool use_prefix_code_;
460
  uint32_t state_ = ANS_SIGNATURE << 16u;
461
  const HybridUintConfig* JXL_RESTRICT configs;
462
  uint32_t log_alpha_size_{};
463
  uint32_t log_entry_size_{};
464
  uint32_t entry_size_minus_1_{};
465
466
  // LZ77 structures and constants.
467
  static constexpr size_t kWindowMask = kWindowSize - 1;
468
  // a std::vector incurs unacceptable decoding speed loss because of
469
  // initialization.
470
  AlignedMemory lz77_window_storage_;
471
  uint32_t* lz77_window_ = nullptr;
472
  uint32_t num_decoded_ = 0;
473
  uint32_t num_to_copy_ = 0;
474
  uint32_t copy_pos_ = 0;
475
  uint32_t lz77_ctx_ = 0;
476
  uint32_t lz77_min_length_ = 0;
477
  uint32_t lz77_threshold_ = 1 << 20;  // bigger than any symbol.
478
  HybridUintConfig lz77_length_uint_;
479
  uint32_t special_distances_[kNumSpecialDistances]{};
480
  uint32_t num_special_distances_{};
481
};
482
483
// Reads the entropy-code description for `num_contexts` contexts from `br`
// into `code` and `context_map`. Defined out of line.
// NOTE(review): `disallow_lz77` presumably makes streams that enable LZ77 an
// error - confirm in the implementation.
Status DecodeHistograms(JxlMemoryManager* memory_manager, BitReader* br,
                        size_t num_contexts, ANSCode* code,
                        std::vector<uint8_t>* context_map,
                        bool disallow_lz77 = false);

// Exposed for tests.
// Reads the hybrid-uint configurations in `*uint_config` from `br`. Defined
// out of line.
// NOTE(review): the vector is presumably pre-sized by the caller - confirm.
Status DecodeUintConfigs(size_t log_alpha_size,
                         std::vector<HybridUintConfig>* uint_config,
                         BitReader* br);
492
493
}  // namespace jxl
494
495
#endif  // LIB_JXL_DEC_ANS_H_