Coverage Report

Created: 2025-06-16 07:00

/src/libjxl/lib/jxl/enc_ans.cc
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#include "lib/jxl/enc_ans.h"
7
8
#include <jxl/memory_manager.h>
9
#include <jxl/types.h>
10
11
#include <algorithm>
12
#include <array>
13
#include <cmath>
14
#include <cstddef>
15
#include <cstdint>
16
#include <limits>
17
#include <utility>
18
#include <vector>
19
20
#include "lib/jxl/ans_common.h"
21
#include "lib/jxl/ans_params.h"
22
#include "lib/jxl/base/bits.h"
23
#include "lib/jxl/base/common.h"
24
#include "lib/jxl/base/compiler_specific.h"
25
#include "lib/jxl/base/status.h"
26
#include "lib/jxl/common.h"
27
#include "lib/jxl/dec_ans.h"
28
#include "lib/jxl/enc_ans_params.h"
29
#include "lib/jxl/enc_aux_out.h"
30
#include "lib/jxl/enc_cluster.h"
31
#include "lib/jxl/enc_context_map.h"
32
#include "lib/jxl/enc_fields.h"
33
#include "lib/jxl/enc_huffman.h"
34
#include "lib/jxl/enc_lz77.h"
35
#include "lib/jxl/enc_params.h"
36
#include "lib/jxl/fields.h"
37
#include "lib/jxl/modular/options.h"
38
39
namespace jxl {
40
41
namespace {
42
43
#if (!JXL_IS_DEBUG_BUILD)
44
constexpr
45
#endif
46
    bool ans_fuzzer_friendly_ = false;
47
48
const int kMaxNumSymbolsForSmallCode = 2;
49
50
template <typename Writer>
51
535k
void StoreVarLenUint8(size_t n, Writer* writer) {
52
535k
  JXL_DASSERT(n <= 255);
53
535k
  if (n == 0) {
54
45.1k
    writer->Write(1, 0);
55
490k
  } else {
56
490k
    writer->Write(1, 1);
57
490k
    size_t nbits = FloorLog2Nonzero(n);
58
490k
    writer->Write(3, nbits);
59
490k
    writer->Write(nbits, n - (1ULL << nbits));
60
490k
  }
61
535k
}
enc_ans.cc:void jxl::(anonymous namespace)::StoreVarLenUint8<jxl::SizeWriter>(unsigned long, jxl::SizeWriter*)
Line
Count
Source
51
509k
void StoreVarLenUint8(size_t n, Writer* writer) {
52
509k
  JXL_DASSERT(n <= 255);
53
509k
  if (n == 0) {
54
42.8k
    writer->Write(1, 0);
55
466k
  } else {
56
466k
    writer->Write(1, 1);
57
466k
    size_t nbits = FloorLog2Nonzero(n);
58
466k
    writer->Write(3, nbits);
59
466k
    writer->Write(nbits, n - (1ULL << nbits));
60
466k
  }
61
509k
}
enc_ans.cc:void jxl::(anonymous namespace)::StoreVarLenUint8<jxl::BitWriter>(unsigned long, jxl::BitWriter*)
Line
Count
Source
51
25.9k
void StoreVarLenUint8(size_t n, Writer* writer) {
52
25.9k
  JXL_DASSERT(n <= 255);
53
25.9k
  if (n == 0) {
54
2.24k
    writer->Write(1, 0);
55
23.6k
  } else {
56
23.6k
    writer->Write(1, 1);
57
23.6k
    size_t nbits = FloorLog2Nonzero(n);
58
23.6k
    writer->Write(3, nbits);
59
23.6k
    writer->Write(nbits, n - (1ULL << nbits));
60
23.6k
  }
61
25.9k
}
62
63
template <typename Writer>
64
2.29k
void StoreVarLenUint16(size_t n, Writer* writer) {
65
2.29k
  JXL_DASSERT(n <= 65535);
66
2.29k
  if (n == 0) {
67
42
    writer->Write(1, 0);
68
2.24k
  } else {
69
2.24k
    writer->Write(1, 1);
70
2.24k
    size_t nbits = FloorLog2Nonzero(n);
71
2.24k
    writer->Write(4, nbits);
72
2.24k
    writer->Write(nbits, n - (1ULL << nbits));
73
2.24k
  }
74
2.29k
}
enc_ans.cc:void jxl::(anonymous namespace)::StoreVarLenUint16<jxl::BitWriter>(unsigned long, jxl::BitWriter*)
Line
Count
Source
64
540
void StoreVarLenUint16(size_t n, Writer* writer) {
65
540
  JXL_DASSERT(n <= 65535);
66
540
  if (n == 0) {
67
42
    writer->Write(1, 0);
68
498
  } else {
69
498
    writer->Write(1, 1);
70
498
    size_t nbits = FloorLog2Nonzero(n);
71
498
    writer->Write(4, nbits);
72
498
    writer->Write(nbits, n - (1ULL << nbits));
73
498
  }
74
540
}
enc_ans.cc:void jxl::(anonymous namespace)::StoreVarLenUint16<jxl::SizeWriter>(unsigned long, jxl::SizeWriter*)
Line
Count
Source
64
1.75k
void StoreVarLenUint16(size_t n, Writer* writer) {
65
1.75k
  JXL_DASSERT(n <= 65535);
66
1.75k
  if (n == 0) {
67
0
    writer->Write(1, 0);
68
1.75k
  } else {
69
1.75k
    writer->Write(1, 1);
70
1.75k
    size_t nbits = FloorLog2Nonzero(n);
71
1.75k
    writer->Write(4, nbits);
72
1.75k
    writer->Write(nbits, n - (1ULL << nbits));
73
1.75k
  }
74
1.75k
}
75
76
class ANSEncodingHistogram {
77
 public:
78
19.9k
  const std::vector<ANSHistBin>& Counts() const { return counts_; }
79
69.9k
  float Cost() const { return cost_; }
80
  // The only way to construct valid histogram for ANS encoding
81
  static StatusOr<ANSEncodingHistogram> ComputeBest(
82
      const Histogram& histo,
83
69.9k
      HistogramParams::ANSHistogramStrategy ans_histogram_strategy) {
84
69.9k
    ANSEncodingHistogram result;
85
86
69.9k
    result.alphabet_size_ = histo.alphabet_size();
87
69.9k
    if (result.alphabet_size_ > ANS_MAX_ALPHABET_SIZE)
88
0
      return JXL_FAILURE("Too many entries in an ANS histogram");
89
90
69.9k
    if (result.alphabet_size_ > 0) {
91
      // Flat code
92
69.9k
      result.method_ = 0;
93
69.9k
      result.num_symbols_ = result.alphabet_size_;
94
69.9k
      result.counts_ = CreateFlatHistogram(result.alphabet_size_, ANS_TAB_SIZE);
95
      // in this case length can be non-suitable for SIMD - fix it
96
69.9k
      result.counts_.resize(histo.counts.size());
97
69.9k
      SizeWriter writer;
98
69.9k
      JXL_RETURN_IF_ERROR(result.Encode(&writer));
99
69.9k
      result.cost_ = writer.size + EstimateDataBitsFlat(histo);
100
69.9k
    } else {
101
      // Empty histogram
102
0
      result.method_ = 1;
103
0
      result.num_symbols_ = 0;
104
0
      result.cost_ = 3;
105
0
      return result;
106
0
    }
107
108
69.9k
    size_t symbol_count = 0;
109
2.78M
    for (size_t n = 0; n < result.alphabet_size_; ++n) {
110
2.71M
      if (histo.counts[n] > 0) {
111
1.38M
        if (symbol_count < kMaxNumSymbolsForSmallCode) {
112
133k
          result.symbols_[symbol_count] = n;
113
133k
        }
114
1.38M
        ++symbol_count;
115
1.38M
      }
116
2.71M
    }
117
69.9k
    result.num_symbols_ = symbol_count;
118
69.9k
    if (symbol_count == 1) {
119
      // Single-bin histogram
120
6.84k
      result.method_ = 1;
121
6.84k
      result.counts_ = histo.counts;
122
6.84k
      result.counts_[result.symbols_[0]] = ANS_TAB_SIZE;
123
6.84k
      SizeWriter writer;
124
6.84k
      JXL_RETURN_IF_ERROR(result.Encode(&writer));
125
6.84k
      result.cost_ = writer.size;
126
6.84k
      return result;
127
6.84k
    }
128
129
    // Here min 2 symbols
130
63.1k
    ANSEncodingHistogram normalized = result;
131
272k
    auto try_shift = [&](uint32_t shift) -> Status {
132
      // `shift = 12` and `shift = 11` are the same
133
272k
      normalized.method_ = std::min(shift, ANS_LOG_TAB_SIZE - 1) + 1;
134
135
272k
      if (!normalized.RebalanceHistogram(histo)) {
136
0
        return JXL_FAILURE("Logic error: couldn't rebalance a histogram");
137
0
      }
138
272k
      SizeWriter writer;
139
272k
      JXL_RETURN_IF_ERROR(normalized.Encode(&writer));
140
272k
      normalized.cost_ = writer.size + normalized.EstimateDataBits(histo);
141
272k
      if (normalized.cost_ < result.cost_) {
142
70.3k
        result = normalized;
143
70.3k
      }
144
272k
      return true;
145
272k
    };
146
147
63.1k
    switch (ans_histogram_strategy) {
148
1.22k
      case HistogramParams::ANSHistogramStrategy::kPrecise:
149
15.9k
        for (uint32_t shift = 0; shift < ANS_LOG_TAB_SIZE; shift++) {
150
14.6k
          JXL_RETURN_IF_ERROR(try_shift(shift));
151
14.6k
        }
152
1.22k
        break;
153
18.0k
      case HistogramParams::ANSHistogramStrategy::kApproximate:
154
144k
        for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE; shift += 2) {
155
126k
          JXL_RETURN_IF_ERROR(try_shift(shift));
156
126k
        }
157
18.0k
        break;
158
43.8k
      case HistogramParams::ANSHistogramStrategy::kFast:
159
43.8k
        JXL_RETURN_IF_ERROR(try_shift(0));
160
43.8k
        JXL_RETURN_IF_ERROR(try_shift(ANS_LOG_TAB_SIZE / 2));
161
43.8k
        JXL_RETURN_IF_ERROR(try_shift(ANS_LOG_TAB_SIZE));
162
43.8k
        break;
163
63.1k
    }
164
165
      // Sanity check
166
#if JXL_IS_DEBUG_BUILD
167
    JXL_ENSURE(histo.counts.size() == result.counts_.size());
168
    ANSHistBin total = 0;  // Used only in assert.
169
    for (size_t i = 0; i < result.alphabet_size_; ++i) {
170
      JXL_ENSURE(result.counts_[i] >= 0);
171
      // For non-flat histogram values should be zero or non-zero simultaneously
172
      // for the same symbol in both initial and normalized histograms.
173
      JXL_ENSURE(result.method_ == 0 ||
174
                 (histo.counts[i] > 0) == (result.counts_[i] > 0));
175
      // Check accuracy of the histogram values
176
      if (result.method_ > 0 && result.counts_[i] > 0 &&
177
          i != result.omit_pos_) {
178
        int logcounts = FloorLog2Nonzero<uint32_t>(result.counts_[i]);
179
        int bitcount =
180
            GetPopulationCountPrecision(logcounts, result.method_ - 1);
181
        int drop_bits = logcounts - bitcount;
182
        // Check that the value is divisible by 2^drop_bits
183
        JXL_ENSURE((result.counts_[i] & ((1 << drop_bits) - 1)) == 0);
184
      }
185
      total += result.counts_[i];
186
    }
187
    for (size_t i = result.alphabet_size_; i < result.counts_.size(); ++i) {
188
      JXL_ENSURE(histo.counts[i] == 0);
189
      JXL_ENSURE(result.counts_[i] == 0);
190
    }
191
    JXL_ENSURE((histo.total_count == 0) || (total == ANS_TAB_SIZE));
192
#endif
193
63.1k
    return result;
194
63.1k
  }
195
196
  template <typename Writer>
197
369k
  Status Encode(Writer* writer) {
198
    // The check ensures also that all RLE sequences can be
199
    // encoded by `StoreVarLenUint8`
200
369k
    JXL_ENSURE(alphabet_size_ <= ANS_MAX_ALPHABET_SIZE);
201
202
    /// Flat histogram.
203
369k
    if (method_ == 0) {
204
      // Mark non-small tree.
205
70.3k
      writer->Write(1, 0);
206
      // Mark uniform histogram.
207
70.3k
      writer->Write(1, 1);
208
70.3k
      JXL_ENSURE(alphabet_size_ > 0);
209
      // Encode alphabet size.
210
70.3k
      StoreVarLenUint8(alphabet_size_ - 1, writer);
211
212
70.3k
      return true;
213
70.3k
    }
214
215
    /// Small tree.
216
298k
    if (num_symbols_ <= kMaxNumSymbolsForSmallCode) {
217
      // Small tree marker to encode 1-2 symbols.
218
11.4k
      writer->Write(1, 1);
219
11.4k
      if (num_symbols_ == 0) {
220
0
        writer->Write(1, 0);
221
0
        StoreVarLenUint8(0, writer);
222
11.4k
      } else {
223
11.4k
        writer->Write(1, num_symbols_ - 1);
224
27.2k
        for (size_t i = 0; i < num_symbols_; ++i) {
225
15.7k
          StoreVarLenUint8(symbols_[i], writer);
226
15.7k
        }
227
11.4k
      }
228
11.4k
      if (num_symbols_ == 2) {
229
4.25k
        writer->Write(ANS_LOG_TAB_SIZE, counts_[symbols_[0]]);
230
4.25k
      }
231
232
11.4k
      return true;
233
11.4k
    }
234
235
    /// General tree.
236
    // Mark non-small tree.
237
287k
    writer->Write(1, 0);
238
    // Mark non-flat histogram.
239
287k
    writer->Write(1, 0);
240
241
    // Elias gamma-like code for `shift = method - 1`. Only difference is that
242
    // if the number of bits to be encoded is equal to `upper_bound_log`,
243
    // we skip the terminating 0 in unary coding.
244
287k
    int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1);
245
287k
    int log = FloorLog2Nonzero(method_);
246
287k
    writer->Write(log, (1 << log) - 1);
247
287k
    if (log != upper_bound_log) writer->Write(1, 0);
248
287k
    writer->Write(log, ((1 << log) - 1) & method_);
249
250
    // Since `num_symbols_ >= 3`, we know that `alphabet_size_ >= 3`, therefore
251
    // we encode `alphabet_size_ - 3`.
252
287k
    StoreVarLenUint8(alphabet_size_ - 3, writer);
253
254
    // Precompute sequences for RLE encoding. Contains the number of identical
255
    // values starting at a given index. Only contains that value at the first
256
    // element of the series.
257
287k
    uint8_t same[ANS_MAX_ALPHABET_SIZE] = {};
258
287k
    size_t last = 0;
259
12.0M
    for (size_t i = 1; i <= alphabet_size_; i++) {
260
      // Store the sequence length once different symbol reached, or we are
261
      // near the omit_pos_, or we're at the end. We don't support including the
262
      // omit_pos_ in an RLE sequence because this value may use a different
263
      // amount of log2 bits than standard, it is too complex to handle in the
264
      // decoder.
265
11.7M
      if (i == alphabet_size_ || i == omit_pos_ || i == omit_pos_ + 1 ||
266
11.7M
          counts_[i] != counts_[last]) {
267
6.73M
        same[last] = i - last;
268
6.73M
        last = i;
269
6.73M
      }
270
11.7M
    }
271
272
287k
    uint8_t bit_width[ANS_MAX_ALPHABET_SIZE] = {};
273
    // Use shortest possible Huffman code to encode `omit_pos` (see
274
    // `kBitWidthLengths`). `bit_width` value at `omit_pos` should be the
275
    // first of maximal values in the whole `bit_width` array, so it can be
276
    // increased without changing that property
277
287k
    int omit_width = 10;
278
12.0M
    for (size_t i = 0; i < alphabet_size_; ++i) {
279
11.7M
      if (i != omit_pos_ && counts_[i] > 0) {
280
5.85M
        bit_width[i] = FloorLog2Nonzero<uint32_t>(counts_[i]) + 1;
281
5.85M
        omit_width = std::max(omit_width, bit_width[i] + int{i < omit_pos_});
282
5.85M
      }
283
11.7M
    }
284
287k
    bit_width[omit_pos_] = static_cast<uint8_t>(omit_width);
285
286
    // The bit widths are encoded with a static Huffman code.
287
    // The last symbol is used as RLE sequence.
288
287k
    constexpr uint8_t kBitWidthLengths[ANS_LOG_TAB_SIZE + 2] = {
289
287k
        5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 6, 7, 7,
290
287k
    };
291
287k
    constexpr uint8_t kBitWidthSymbols[ANS_LOG_TAB_SIZE + 2] = {
292
287k
        17, 11, 15, 3, 9, 7, 4, 2, 5, 6, 0, 33, 1, 65,
293
287k
    };
294
287k
    constexpr uint8_t kMinReps = 5;
295
287k
    constexpr size_t rep = ANS_LOG_TAB_SIZE + 1;
296
    // Encode count bit widths
297
8.29M
    for (size_t i = 0; i < alphabet_size_; ++i) {
298
8.00M
      writer->Write(kBitWidthLengths[bit_width[i]],
299
8.00M
                    kBitWidthSymbols[bit_width[i]]);
300
8.00M
      if (same[i] >= kMinReps) {
301
        // Encode the RLE symbol and skip the repeated ones.
302
161k
        writer->Write(kBitWidthLengths[rep], kBitWidthSymbols[rep]);
303
161k
        StoreVarLenUint8(same[i] - kMinReps, writer);
304
161k
        i += same[i] - 1;
305
161k
      }
306
8.00M
    }
307
    // Encode additional bits of accuracy
308
287k
    uint32_t shift = method_ - 1;
309
287k
    if (shift != 0) {  // otherwise `bitcount = 0`
310
6.11M
      for (size_t i = 0; i < alphabet_size_; ++i) {
311
5.90M
        if (bit_width[i] > 1 && i != omit_pos_) {
312
4.16M
          int bitcount = GetPopulationCountPrecision(bit_width[i] - 1, shift);
313
4.16M
          int drop_bits = bit_width[i] - 1 - bitcount;
314
4.16M
          JXL_DASSERT((counts_[i] & ((1 << drop_bits) - 1)) == 0);
315
4.16M
          writer->Write(bitcount, (counts_[i] >> drop_bits) - (1 << bitcount));
316
4.16M
        }
317
5.90M
        if (same[i] >= kMinReps) {
318
          // Skip symbols encoded by RLE.
319
111k
          i += same[i] - 1;
320
111k
        }
321
5.90M
      }
322
213k
    }
323
287k
    return true;
324
298k
  }
enc_ans.cc:jxl::Status jxl::(anonymous namespace)::ANSEncodingHistogram::Encode<jxl::SizeWriter>(jxl::SizeWriter*)
Line
Count
Source
197
349k
  Status Encode(Writer* writer) {
198
    // The check ensures also that all RLE sequences can be
199
    // encoded by `StoreVarLenUint8`
200
349k
    JXL_ENSURE(alphabet_size_ <= ANS_MAX_ALPHABET_SIZE);
201
202
    /// Flat histogram.
203
349k
    if (method_ == 0) {
204
      // Mark non-small tree.
205
69.9k
      writer->Write(1, 0);
206
      // Mark uniform histogram.
207
69.9k
      writer->Write(1, 1);
208
69.9k
      JXL_ENSURE(alphabet_size_ > 0);
209
      // Encode alphabet size.
210
69.9k
      StoreVarLenUint8(alphabet_size_ - 1, writer);
211
212
69.9k
      return true;
213
69.9k
    }
214
215
    /// Small tree.
216
279k
    if (num_symbols_ <= kMaxNumSymbolsForSmallCode) {
217
      // Small tree marker to encode 1-2 symbols.
218
10.8k
      writer->Write(1, 1);
219
10.8k
      if (num_symbols_ == 0) {
220
0
        writer->Write(1, 0);
221
0
        StoreVarLenUint8(0, writer);
222
10.8k
      } else {
223
10.8k
        writer->Write(1, num_symbols_ - 1);
224
25.8k
        for (size_t i = 0; i < num_symbols_; ++i) {
225
14.9k
          StoreVarLenUint8(symbols_[i], writer);
226
14.9k
        }
227
10.8k
      }
228
10.8k
      if (num_symbols_ == 2) {
229
4.05k
        writer->Write(ANS_LOG_TAB_SIZE, counts_[symbols_[0]]);
230
4.05k
      }
231
232
10.8k
      return true;
233
10.8k
    }
234
235
    /// General tree.
236
    // Mark non-small tree.
237
268k
    writer->Write(1, 0);
238
    // Mark non-flat histogram.
239
268k
    writer->Write(1, 0);
240
241
    // Elias gamma-like code for `shift = method - 1`. Only difference is that
242
    // if the number of bits to be encoded is equal to `upper_bound_log`,
243
    // we skip the terminating 0 in unary coding.
244
268k
    int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1);
245
268k
    int log = FloorLog2Nonzero(method_);
246
268k
    writer->Write(log, (1 << log) - 1);
247
268k
    if (log != upper_bound_log) writer->Write(1, 0);
248
268k
    writer->Write(log, ((1 << log) - 1) & method_);
249
250
    // Since `num_symbols_ >= 3`, we know that `alphabet_size_ >= 3`, therefore
251
    // we encode `alphabet_size_ - 3`.
252
268k
    StoreVarLenUint8(alphabet_size_ - 3, writer);
253
254
    // Precompute sequences for RLE encoding. Contains the number of identical
255
    // values starting at a given index. Only contains that value at the first
256
    // element of the series.
257
268k
    uint8_t same[ANS_MAX_ALPHABET_SIZE] = {};
258
268k
    size_t last = 0;
259
11.4M
    for (size_t i = 1; i <= alphabet_size_; i++) {
260
      // Store the sequence length once different symbol reached, or we are
261
      // near the omit_pos_, or we're at the end. We don't support including the
262
      // omit_pos_ in an RLE sequence because this value may use a different
263
      // amount of log2 bits than standard, it is too complex to handle in the
264
      // decoder.
265
11.1M
      if (i == alphabet_size_ || i == omit_pos_ || i == omit_pos_ + 1 ||
266
11.1M
          counts_[i] != counts_[last]) {
267
6.36M
        same[last] = i - last;
268
6.36M
        last = i;
269
6.36M
      }
270
11.1M
    }
271
272
268k
    uint8_t bit_width[ANS_MAX_ALPHABET_SIZE] = {};
273
    // Use shortest possible Huffman code to encode `omit_pos` (see
274
    // `kBitWidthLengths`). `bit_width` value at `omit_pos` should be the
275
    // first of maximal values in the whole `bit_width` array, so it can be
276
    // increased without changing that property
277
268k
    int omit_width = 10;
278
11.4M
    for (size_t i = 0; i < alphabet_size_; ++i) {
279
11.1M
      if (i != omit_pos_ && counts_[i] > 0) {
280
5.51M
        bit_width[i] = FloorLog2Nonzero<uint32_t>(counts_[i]) + 1;
281
5.51M
        omit_width = std::max(omit_width, bit_width[i] + int{i < omit_pos_});
282
5.51M
      }
283
11.1M
    }
284
268k
    bit_width[omit_pos_] = static_cast<uint8_t>(omit_width);
285
286
    // The bit widths are encoded with a static Huffman code.
287
    // The last symbol is used as RLE sequence.
288
268k
    constexpr uint8_t kBitWidthLengths[ANS_LOG_TAB_SIZE + 2] = {
289
268k
        5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 6, 7, 7,
290
268k
    };
291
268k
    constexpr uint8_t kBitWidthSymbols[ANS_LOG_TAB_SIZE + 2] = {
292
268k
        17, 11, 15, 3, 9, 7, 4, 2, 5, 6, 0, 33, 1, 65,
293
268k
    };
294
268k
    constexpr uint8_t kMinReps = 5;
295
268k
    constexpr size_t rep = ANS_LOG_TAB_SIZE + 1;
296
    // Encode count bit widths
297
7.82M
    for (size_t i = 0; i < alphabet_size_; ++i) {
298
7.55M
      writer->Write(kBitWidthLengths[bit_width[i]],
299
7.55M
                    kBitWidthSymbols[bit_width[i]]);
300
7.55M
      if (same[i] >= kMinReps) {
301
        // Encode the RLE symbol and skip the repeated ones.
302
155k
        writer->Write(kBitWidthLengths[rep], kBitWidthSymbols[rep]);
303
155k
        StoreVarLenUint8(same[i] - kMinReps, writer);
304
155k
        i += same[i] - 1;
305
155k
      }
306
7.55M
    }
307
    // Encode additional bits of accuracy
308
268k
    uint32_t shift = method_ - 1;
309
268k
    if (shift != 0) {  // otherwise `bitcount = 0`
310
5.95M
      for (size_t i = 0; i < alphabet_size_; ++i) {
311
5.74M
        if (bit_width[i] > 1 && i != omit_pos_) {
312
4.05M
          int bitcount = GetPopulationCountPrecision(bit_width[i] - 1, shift);
313
4.05M
          int drop_bits = bit_width[i] - 1 - bitcount;
314
4.05M
          JXL_DASSERT((counts_[i] & ((1 << drop_bits) - 1)) == 0);
315
4.05M
          writer->Write(bitcount, (counts_[i] >> drop_bits) - (1 << bitcount));
316
4.05M
        }
317
5.74M
        if (same[i] >= kMinReps) {
318
          // Skip symbols encoded by RLE.
319
109k
          i += same[i] - 1;
320
109k
        }
321
5.74M
      }
322
206k
    }
323
268k
    return true;
324
279k
  }
enc_ans.cc:jxl::Status jxl::(anonymous namespace)::ANSEncodingHistogram::Encode<jxl::BitWriter>(jxl::BitWriter*)
Line
Count
Source
197
19.4k
  Status Encode(Writer* writer) {
198
    // The check ensures also that all RLE sequences can be
199
    // encoded by `StoreVarLenUint8`
200
19.4k
    JXL_ENSURE(alphabet_size_ <= ANS_MAX_ALPHABET_SIZE);
201
202
    /// Flat histogram.
203
19.4k
    if (method_ == 0) {
204
      // Mark non-small tree.
205
372
      writer->Write(1, 0);
206
      // Mark uniform histogram.
207
372
      writer->Write(1, 1);
208
372
      JXL_ENSURE(alphabet_size_ > 0);
209
      // Encode alphabet size.
210
372
      StoreVarLenUint8(alphabet_size_ - 1, writer);
211
212
372
      return true;
213
372
    }
214
215
    /// Small tree.
216
19.0k
    if (num_symbols_ <= kMaxNumSymbolsForSmallCode) {
217
      // Small tree marker to encode 1-2 symbols.
218
580
      writer->Write(1, 1);
219
580
      if (num_symbols_ == 0) {
220
0
        writer->Write(1, 0);
221
0
        StoreVarLenUint8(0, writer);
222
580
      } else {
223
580
        writer->Write(1, num_symbols_ - 1);
224
1.35k
        for (size_t i = 0; i < num_symbols_; ++i) {
225
777
          StoreVarLenUint8(symbols_[i], writer);
226
777
        }
227
580
      }
228
580
      if (num_symbols_ == 2) {
229
197
        writer->Write(ANS_LOG_TAB_SIZE, counts_[symbols_[0]]);
230
197
      }
231
232
580
      return true;
233
580
    }
234
235
    /// General tree.
236
    // Mark non-small tree.
237
18.4k
    writer->Write(1, 0);
238
    // Mark non-flat histogram.
239
18.4k
    writer->Write(1, 0);
240
241
    // Elias gamma-like code for `shift = method - 1`. Only difference is that
242
    // if the number of bits to be encoded is equal to `upper_bound_log`,
243
    // we skip the terminating 0 in unary coding.
244
18.4k
    int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1);
245
18.4k
    int log = FloorLog2Nonzero(method_);
246
18.4k
    writer->Write(log, (1 << log) - 1);
247
18.4k
    if (log != upper_bound_log) writer->Write(1, 0);
248
18.4k
    writer->Write(log, ((1 << log) - 1) & method_);
249
250
    // Since `num_symbols_ >= 3`, we know that `alphabet_size_ >= 3`, therefore
251
    // we encode `alphabet_size_ - 3`.
252
18.4k
    StoreVarLenUint8(alphabet_size_ - 3, writer);
253
254
    // Precompute sequences for RLE encoding. Contains the number of identical
255
    // values starting at a given index. Only contains that value at the first
256
    // element of the series.
257
18.4k
    uint8_t same[ANS_MAX_ALPHABET_SIZE] = {};
258
18.4k
    size_t last = 0;
259
565k
    for (size_t i = 1; i <= alphabet_size_; i++) {
260
      // Store the sequence length once different symbol reached, or we are
261
      // near the omit_pos_, or we're at the end. We don't support including the
262
      // omit_pos_ in an RLE sequence because this value may use a different
263
      // amount of log2 bits than standard, it is too complex to handle in the
264
      // decoder.
265
546k
      if (i == alphabet_size_ || i == omit_pos_ || i == omit_pos_ + 1 ||
266
546k
          counts_[i] != counts_[last]) {
267
366k
        same[last] = i - last;
268
366k
        last = i;
269
366k
      }
270
546k
    }
271
272
18.4k
    uint8_t bit_width[ANS_MAX_ALPHABET_SIZE] = {};
273
    // Use shortest possible Huffman code to encode `omit_pos` (see
274
    // `kBitWidthLengths`). `bit_width` value at `omit_pos` should be the
275
    // first of maximal values in the whole `bit_width` array, so it can be
276
    // increased without changing that property
277
18.4k
    int omit_width = 10;
278
565k
    for (size_t i = 0; i < alphabet_size_; ++i) {
279
546k
      if (i != omit_pos_ && counts_[i] > 0) {
280
340k
        bit_width[i] = FloorLog2Nonzero<uint32_t>(counts_[i]) + 1;
281
340k
        omit_width = std::max(omit_width, bit_width[i] + int{i < omit_pos_});
282
340k
      }
283
546k
    }
284
18.4k
    bit_width[omit_pos_] = static_cast<uint8_t>(omit_width);
285
286
    // The bit widths are encoded with a static Huffman code.
287
    // The last symbol is used as RLE sequence.
288
18.4k
    constexpr uint8_t kBitWidthLengths[ANS_LOG_TAB_SIZE + 2] = {
289
18.4k
        5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 6, 7, 7,
290
18.4k
    };
291
18.4k
    constexpr uint8_t kBitWidthSymbols[ANS_LOG_TAB_SIZE + 2] = {
292
18.4k
        17, 11, 15, 3, 9, 7, 4, 2, 5, 6, 0, 33, 1, 65,
293
18.4k
    };
294
18.4k
    constexpr uint8_t kMinReps = 5;
295
18.4k
    constexpr size_t rep = ANS_LOG_TAB_SIZE + 1;
296
    // Encode count bit widths
297
467k
    for (size_t i = 0; i < alphabet_size_; ++i) {
298
449k
      writer->Write(kBitWidthLengths[bit_width[i]],
299
449k
                    kBitWidthSymbols[bit_width[i]]);
300
449k
      if (same[i] >= kMinReps) {
301
        // Encode the RLE symbol and skip the repeated ones.
302
6.28k
        writer->Write(kBitWidthLengths[rep], kBitWidthSymbols[rep]);
303
6.28k
        StoreVarLenUint8(same[i] - kMinReps, writer);
304
6.28k
        i += same[i] - 1;
305
6.28k
      }
306
449k
    }
307
    // Encode additional bits of accuracy
308
18.4k
    uint32_t shift = method_ - 1;
309
18.4k
    if (shift != 0) {  // otherwise `bitcount = 0`
310
163k
      for (size_t i = 0; i < alphabet_size_; ++i) {
311
156k
        if (bit_width[i] > 1 && i != omit_pos_) {
312
109k
          int bitcount = GetPopulationCountPrecision(bit_width[i] - 1, shift);
313
109k
          int drop_bits = bit_width[i] - 1 - bitcount;
314
109k
          JXL_DASSERT((counts_[i] & ((1 << drop_bits) - 1)) == 0);
315
109k
          writer->Write(bitcount, (counts_[i] >> drop_bits) - (1 << bitcount));
316
109k
        }
317
156k
        if (same[i] >= kMinReps) {
318
          // Skip symbols encoded by RLE.
319
2.22k
          i += same[i] - 1;
320
2.22k
        }
321
156k
      }
322
6.39k
    }
323
18.4k
    return true;
324
19.0k
  }
325
326
  // Populates the per-symbol encoder records in `info` from the normalized
  // counts and from the alias table `table`.
  //
  // For every symbol the frequency (and, when USE_MULT_BY_RECIPROCAL is
  // defined, its rounded-up reciprocal) is stored and `reverse_map_` is sized
  // to that frequency. Afterwards every one of the ANS_TAB_SIZE table slots is
  // looked up in the alias table and recorded in the `reverse_map_` of the
  // symbol that owns it, at the offset reported by the lookup.
  //
  // NOTE(review): `info` must have room for at least
  // max(1, alphabet_size_) entries; the allocation is done by the caller.
  void ANSBuildInfoTable(const AliasTable::Entry* table, size_t log_alpha_size,
                         ANSEncSymbolInfo* info) {
    // An empty alphabet still gets one entry that owns the whole table, so
    // that the structure stays valid for empty streams.
    const size_t num_entries = std::max(size_t{1}, alphabet_size_);
    for (size_t sym = 0; sym < num_entries; ++sym) {
      ANSEncSymbolInfo& entry = info[sym];
      const ANSHistBin freq =
          (sym == alphabet_size_) ? ANS_TAB_SIZE : counts_[sym];
      entry.freq_ = static_cast<uint16_t>(freq);
#ifdef USE_MULT_BY_RECIPROCAL
      // A zero-frequency symbol should never be coded; its reciprocal is set
      // to 1 only to keep the field defined.
      entry.ifreq_ =
          (freq != 0)
              ? ((1ull << RECIPROCAL_PRECISION) + entry.freq_ - 1) / entry.freq_
              : 1;
#endif
      entry.reverse_map_.resize(freq);
    }

    const size_t log_entry_size = ANS_LOG_TAB_SIZE - log_alpha_size;
    const size_t entry_size_minus_1 = (1 << log_entry_size) - 1;
    for (int slot = 0; slot < ANS_TAB_SIZE; ++slot) {
      const AliasTable::Symbol owner =
          AliasTable::Lookup(table, slot, log_entry_size, entry_size_minus_1);
      info[owner.value].reverse_map_[owner.offset] = slot;
    }
  }
351
352
 private:
353
69.9k
  // Private default constructor: every data member keeps its in-class
  // initializer (zero cost/method/sizes, empty `counts_`).
  // NOTE(review): instances are presumably obtained through the static
  // `ComputeBest` factory used elsewhere in this file - confirm.
  ANSEncodingHistogram() {}
354
355
  // Fixed-point log2 LUT for values of [0,4096]
356
  using Lg2LUT = std::array<uint32_t, ANS_TAB_SIZE + 1>;
357
  static const Lg2LUT lg2;
358
359
272k
  float EstimateDataBits(const Histogram& histo) {
360
272k
    int64_t sum = 0;
361
11.4M
    for (size_t i = 0; i < alphabet_size_; ++i) {
362
      // += histogram[i] * -log(counts[i]/total_counts)
363
11.2M
      sum += histo.counts[i] * int64_t{lg2[counts_[i]]};
364
11.2M
    }
365
272k
    return (histo.total_count - ldexpf(sum, -31)) * ANS_LOG_TAB_SIZE;
366
272k
  }
367
368
69.9k
  static float EstimateDataBitsFlat(const Histogram& histo) {
369
69.9k
    size_t len = histo.alphabet_size();
370
69.9k
    int64_t flat_bits = int64_t{lg2[len]} * ANS_LOG_TAB_SIZE;
371
69.9k
    return ldexpf(histo.total_count * flat_bits, -31);
372
69.9k
  }
373
374
  // One row of an `allowed_counts` table. Rows are filled in by the
  // `allowed_counts` initializer: `count` is a representable bin value,
  // while `step_log` and `delta_lg2` describe the step from this row to the
  // preceding (larger) row of the same table - floor(log2) of the count
  // difference and the fixed-point (2^31 / ANS_LOG_TAB_SIZE scaled) log2 of
  // the count ratio, respectively.
  struct CountsEntropy {
375
    ANSHistBin count : 16;     // allowed value of counts in a histogram bin
376
    ANSHistBin step_log : 16;  // log2 of increase step size (can use 5 bits)
377
    int32_t delta_lg2;  // change of log between that value and the next allowed
378
  };
379
380
  // Array is sorted by decreasing allowed counts for each possible shift.
381
  // Exclusion of single-bin histograms before `RebalanceHistogram` allows
382
  // to put count upper limit of 4095, and shifts of 11 and 12 produce the
383
  // same table
384
  using CountsArray =
385
      std::array<std::array<CountsEntropy, ANS_TAB_SIZE>, ANS_LOG_TAB_SIZE>;
386
  static const CountsArray allowed_counts;
387
388
  // Returns the difference between largest count that can be represented and is
389
  // smaller than "count" and smallest representable count larger than "count".
390
25.5M
  static uint32_t SmallestIncrementLog(uint32_t count, uint32_t shift) {
391
25.5M
    if (count == 0) return 0;
392
20.1M
    uint32_t bits = FloorLog2Nonzero(count);
393
20.1M
    uint32_t drop_bits = bits - GetPopulationCountPrecision(bits, shift);
394
20.1M
    return drop_bits;
395
25.5M
  }
396
  // We are growing/reducing histogram step by step trying to maximize total
397
  // entropy i.e. sum of `freq[n] * log[counts[n]]` with a given sum of
398
  // `counts[n]` chosen from `allowed_counts[shift]`. This sum is balanced by
399
  // the `counts[omit_pos_]` in the highest bin of histogram. We start from
400
  // close to correct solution and each time a step with maximum entropy
401
  // increase per unit of bin change is chosen. This greedy scheme is not
402
  // guaranteed to achieve the global maximum, but cannot produce invalid
403
  // histogram. We use a fixed-point approximation for logarithms and all
404
  // arithmetic is integer besides initial approximation. Sum of `freq` and each
405
  // of `lg2[counts]` are supposed to be limited to `int32_t` range, so that the
406
  // sum of their products should not exceed `int64_t`.
407
272k
  bool RebalanceHistogram(const Histogram& histo) {
408
272k
    constexpr ANSHistBin table_size = ANS_TAB_SIZE;
409
272k
    uint32_t shift = method_ - 1;
410
411
272k
    struct EntropyDelta {
412
272k
      ANSHistBin freq;   // initial count
413
272k
      size_t count_ind;  // index of current bin value in `allowed_counts`
414
272k
      size_t bin_ind;    // index of current bin in `counts`
415
272k
    };
416
    // Penalties corresponding to different step sizes - entropy decrease in
417
    // balancing bin, step of size (1 << ANS_LOG_TAB_SIZE - 1) is not possible
418
272k
    int64_t balance_inc[ANS_LOG_TAB_SIZE - 1] = {};
419
272k
    int64_t balance_dec[ANS_LOG_TAB_SIZE - 1] = {};
420
272k
    const auto& ac = allowed_counts[shift];
421
    // TODO(ivan) separate cases of shift >= 11 - all steps are 1 there, and
422
    // possibly 10 - all relevant steps are 2.
423
    // Total entropy change by a step: increase/decrease in current bin
424
    // together with corresponding decrease/increase in the balancing bin.
425
    // Inc steps increase current bin, dec steps decrease
426
108M
    const auto delta_entropy_inc = [&](const EntropyDelta& a) {
427
108M
      return a.freq * int64_t{ac[a.count_ind].delta_lg2} -
428
108M
             balance_inc[ac[a.count_ind].step_log];
429
108M
    };
430
16.8M
    const auto delta_entropy_dec = [&](const EntropyDelta& a) {
431
16.8M
      return a.freq * int64_t{ac[a.count_ind + 1].delta_lg2} -
432
16.8M
             balance_dec[ac[a.count_ind + 1].step_log];
433
16.8M
    };
434
    // Compare steps by entropy increase per unit of histogram bin change.
435
    // Truncation is OK here, accuracy is anyway better than float
436
53.4M
    const auto IncLess = [&](const EntropyDelta& a, const EntropyDelta& b) {
437
53.4M
      return delta_entropy_inc(a) >> ac[a.count_ind].step_log <
438
53.4M
             delta_entropy_inc(b) >> ac[b.count_ind].step_log;
439
53.4M
    };
440
8.22M
    const auto DecLess = [&](const EntropyDelta& a, const EntropyDelta& b) {
441
8.22M
      return delta_entropy_dec(a) >> ac[a.count_ind + 1].step_log <
442
8.22M
             delta_entropy_dec(b) >> ac[b.count_ind + 1].step_log;
443
8.22M
    };
444
    // Vector of adjustable bins from `allowed_counts`
445
272k
    std::vector<EntropyDelta> bins;
446
272k
    bins.reserve(256);
447
448
272k
    double norm = double{table_size} / histo.total_count;
449
450
272k
    size_t remainder_pos = 0;  // highest balancing bin in the histogram
451
272k
    int64_t max_freq = 0;
452
272k
    ANSHistBin rest = table_size;  // reserve of histogram counts to distribute
453
11.4M
    for (size_t n = 0; n < alphabet_size_; ++n) {
454
11.2M
      ANSHistBin freq = histo.counts[n];
455
11.2M
      if (freq > max_freq) {
456
582k
        remainder_pos = n;
457
582k
        max_freq = freq;
458
582k
      }
459
460
11.2M
      double target = freq * norm;  // rounding
461
      // Keep zeros and clamp nonzero freq counts to [1, table_size)
462
11.2M
      ANSHistBin count = std::max<ANSHistBin>(round(target), freq > 0);
463
11.2M
      count = std::min<ANSHistBin>(count, table_size - 1);
464
11.2M
      uint32_t step_log = SmallestIncrementLog(count, shift);
465
11.2M
      ANSHistBin inc = 1 << step_log;
466
11.2M
      count &= ~(inc - 1);
467
468
11.2M
      counts_[n] = count;
469
11.2M
      rest -= count;
470
11.2M
      if (target > 1.0) {
471
5.67M
        size_t count_ind = 0;
472
        // TODO(ivan) binary search instead of linear?
473
6.67G
        while (ac[count_ind].count != count) ++count_ind;
474
5.67M
        bins.push_back({freq, count_ind, n});
475
5.67M
      }
476
11.2M
    }
477
478
    // Delete the highest balancing bin from adjustable by `allowed_counts`
479
272k
    bins.erase(std::find_if(
480
272k
        bins.begin(), bins.end(),
481
1.70M
        [&](const EntropyDelta& a) { return a.bin_ind == remainder_pos; }));
482
    // From now on `rest` is the height of balancing bin,
483
    // here it can be negative, but will be tracted into positive domain later
484
272k
    rest += counts_[remainder_pos];
485
486
272k
    if (!bins.empty()) {
487
272k
      const uint32_t max_log = ac[1].step_log;
488
1.96M
      while (true) {
489
        // Update balancing bin penalties setting guards and tractors
490
17.3M
        for (uint32_t log = 0; log <= max_log; ++log) {
491
15.4M
          ANSHistBin delta = 1 << log;
492
15.4M
          if (rest >= table_size) {
493
            // Tract large `rest` into allowed domain:
494
0
            balance_inc[log] = 0;  // permit all inc steps
495
0
            balance_dec[log] = 0;  // forbid all dec steps
496
15.4M
          } else if (rest > 1) {
497
            // `rest` is OK, put guards against non-possible steps
498
15.4M
            balance_inc[log] =
499
15.4M
                rest > delta  // possible step
500
15.4M
                    ? max_freq * int64_t{lg2[rest] - lg2[rest - delta]}
501
15.4M
                    : std::numeric_limits<int64_t>::max();  // forbidden
502
15.4M
            balance_dec[log] =
503
15.4M
                rest + delta < table_size  // possible step
504
15.4M
                    ? max_freq * int64_t{lg2[rest + delta] - lg2[rest]}
505
15.4M
                    : 0;  // forbidden
506
15.4M
          } else {
507
            // Tract negative or zero `rest` into positive:
508
            // forbid all inc steps
509
153
            balance_inc[log] = std::numeric_limits<int64_t>::max();
510
            // permit all dec steps
511
153
            balance_dec[log] = std::numeric_limits<int64_t>::max();
512
153
          }
513
15.4M
        }
514
        // Try to increase entropy
515
1.96M
        auto best_bin_inc = std::max_element(bins.begin(), bins.end(), IncLess);
516
1.96M
        if (delta_entropy_inc(*best_bin_inc) > 0) {
517
          // Grow the bin with the best histogram entropy increase
518
1.59M
          rest -= 1 << ac[best_bin_inc->count_ind--].step_log;
519
1.59M
        } else {
520
          // This still implies that entropy is strictly increasing each step
521
          // (or `rest` is tracted into positive domain), so we cannot loop
522
          // infinitely
523
369k
          auto best_bin_dec =
524
369k
              std::min_element(bins.begin(), bins.end(), DecLess);
525
          // Break if no reverse steps can grow entropy (or valid)
526
369k
          if (delta_entropy_dec(*best_bin_dec) >= 0) break;
527
          // Decrease the bin with the best histogram entropy increase
528
96.7k
          rest += 1 << ac[++best_bin_dec->count_ind].step_log;
529
96.7k
        }
530
1.96M
      }
531
      // Set counts besides the balancing bin
532
5.40M
      for (auto& a : bins) counts_[a.bin_ind] = ac[a.count_ind].count;
533
534
      // The scheme works fine if we have room to grow `bit_width` of balancing
535
      // bin, otherwise we need to put balancing bin to the first bin of 12 bit
536
      // width. In this case both that bin and balancing one should be close to
537
      // 2048 in targets, so exchange of them will not produce much worse
538
      // histogram
539
2.48M
      for (size_t n = 0; n < remainder_pos; ++n) {
540
2.21M
        if (counts_[n] >= 2048) {
541
729
          counts_[remainder_pos] = counts_[n];
542
729
          remainder_pos = n;
543
729
          break;
544
729
        }
545
2.21M
      }
546
272k
    }
547
    // Set balancing bin
548
272k
    counts_[remainder_pos] = rest;
549
272k
    omit_pos_ = remainder_pos;
550
551
272k
    return counts_[remainder_pos] > 0;
552
272k
  }
553
554
  float cost_ = 0;
555
  uint32_t method_ = 0;
556
  size_t omit_pos_ = 0;
557
  size_t alphabet_size_ = 0;
558
  size_t num_symbols_ = 0;
559
  size_t symbols_[kMaxNumSymbolsForSmallCode] = {};
560
  std::vector<ANSHistBin> counts_{};
561
};
562
563
using AEH = ANSEncodingHistogram;
564
565
292
// Fixed-point log2 table: entry n holds round(log2(n) / ANS_LOG_TAB_SIZE *
// 2^31); entry 0 is defined as 0.
const AEH::Lg2LUT AEH::lg2 = [] {
  Lg2LUT table;
  table[0] = 0;  // for entropy calculations it is OK
  for (size_t n = 1; n < table.size(); ++n) {
    const auto normalized_log = log2(n) / ANS_LOG_TAB_SIZE;
    table[n] = round(ldexp(normalized_log, 31));
  }
  return table;
}();
573
574
292
// Builds, for every shift, the table of representable bin counts sorted by
// decreasing value, together with the step size and the fixed-point log2
// ratio between each row and the preceding (larger) one.
const AEH::CountsArray AEH::allowed_counts = [] {
  CountsArray tables = {};

  for (uint32_t shift = 0; shift < tables.size(); ++shift) {
    auto& rows = tables[shift];
    // Mark every value reachable by rounding 1..size-1 down to the precision
    // available at this shift; unreachable slots keep a zero count.
    for (uint32_t value = 1; value < rows.size(); ++value) {
      const int32_t rounded =
          value & ~((1 << SmallestIncrementLog(value, shift)) - 1);
      rows[rounded].count = rounded;
    }
    // Largest count first; the zero-count rows end up at the tail.
    std::sort(rows.begin(), rows.end(),
              [](const CountsEntropy& lhs, const CountsEntropy& rhs) {
                return lhs.count > rhs.count;
              });
    int ind = 1;
    for (; rows[ind].count > 0; ++ind) {
      const CountsEntropy& larger = rows[ind - 1];
      CountsEntropy& current = rows[ind];
      const double ratio = static_cast<double>(larger.count) / current.count;
      current.delta_lg2 = round(ldexp(log2(ratio) / ANS_LOG_TAB_SIZE, 31));
      current.step_log =
          FloorLog2Nonzero<uint32_t>(larger.count - current.count);
    }
    // Guards against non-possible steps:
    // at max value [0] - 0 (by init), at min value - max
    rows[ind].delta_lg2 = std::numeric_limits<int32_t>::max();
  }
  return tables;
}();
603
604
}  // namespace
605
606
50.0k
// Returns the cost reported by the best ANS normalization of this histogram
// found with the `kFast` strategy. Histograms with more than
// ANS_MAX_ALPHABET_SIZE bins get the maximal float cost instead.
StatusOr<float> Histogram::ANSPopulationCost() const {
  const bool too_wide = counts.size() > ANS_MAX_ALPHABET_SIZE;
  if (too_wide) return std::numeric_limits<float>::max();

  JXL_ASSIGN_OR_RETURN(
      ANSEncodingHistogram normalized,
      ANSEncodingHistogram::ComputeBest(
          *this, HistogramParams::ANSHistogramStrategy::kFast));
  return normalized.Cost();
}
616
617
// Returns an estimate or exact cost of encoding this histogram and the
618
// corresponding data.
619
// Builds the symbol-coding data for `histogram` into `encoding_info.back()`
// and, when `writer` is non-null, serializes the code description to it.
//
// Prefix-code mode: builds a Huffman tree, stores depth/bits per symbol and
// returns the header bits written plus sum(count * depth) as the data cost.
// Alphabets of size <= 1 cost 0 and leave `encoding_info.back()` untouched.
// With a null `writer` the header is written to a temporary BitWriter only to
// measure its size.
//
// ANS mode: picks the best normalized histogram according to
// `ans_histogram_strategy`, fills the alias-table based symbol info and
// returns the rounded-up cost estimate reported by the normalized histogram.
//
// Propagates any error reported by the helpers.
StatusOr<size_t> EntropyEncodingData::BuildAndStoreANSEncodingData(
    JxlMemoryManager* memory_manager,
    HistogramParams::ANSHistogramStrategy ans_histogram_strategy,
    const Histogram& histogram, BitWriter* writer) {
  ANSEncSymbolInfo* info = encoding_info.back().data();
  size_t size = histogram.alphabet_size();
  if (use_prefix_code) {
    size_t cost = 0;
    if (size <= 1) return 0;
    std::vector<uint32_t> histo(size);
    for (size_t i = 0; i < size; i++) {
      JXL_ENSURE(histogram.counts[i] >= 0);
      histo[i] = histogram.counts[i];
    }
    std::vector<uint8_t> depths(size);
    std::vector<uint16_t> bits(size);
    if (writer == nullptr) {
      // Cost estimation only: serialize into a scratch writer.
      BitWriter tmp_writer{memory_manager};
      JXL_RETURN_IF_ERROR(tmp_writer.WithMaxBits(
          8 * size + 8,  // safe upper bound
          LayerType::Header, /*aux_out=*/nullptr, [&] {
            return BuildAndStoreHuffmanTree(histo.data(), size, depths.data(),
                                            bits.data(), &tmp_writer);
          }));
      cost = tmp_writer.BitsWritten();
    } else {
      size_t start = writer->BitsWritten();
      JXL_RETURN_IF_ERROR(BuildAndStoreHuffmanTree(
          histo.data(), size, depths.data(), bits.data(), writer));
      cost = writer->BitsWritten() - start;
    }
    for (size_t i = 0; i < size; i++) {
      info[i].bits = depths[i] == 0 ? 0 : bits[i];
      info[i].depth = depths[i];
      // Estimate data cost. Widen before multiplying: the product of a 32-bit
      // count and a depth can exceed 32 bits.
      cost += static_cast<size_t>(histo[i]) * depths[i];
    }
    return cost;
  }
  JXL_ASSIGN_OR_RETURN(
      ANSEncodingHistogram normalized,
      ANSEncodingHistogram::ComputeBest(histogram, ans_histogram_strategy));
  AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE];
  JXL_RETURN_IF_ERROR(
      InitAliasTable(normalized.Counts(), ANS_LOG_TAB_SIZE, log_alpha_size, a));
  normalized.ANSBuildInfoTable(a, log_alpha_size, info);
  if (writer != nullptr) {
    JXL_RETURN_IF_ERROR(normalized.Encode(writer));
  }
  // The returned value is the estimate from `Cost()`, not the number of bits
  // actually written above.
  return static_cast<size_t>(ceilf(normalized.Cost()));
}
674
675
namespace {
676
677
// Reconstructs a histogram from already-built symbol info. For prefix codes a
// symbol of depth d > 0 gets weight 2^(PREFIX_MAX_BITS - d) and unused
// symbols get 0; for ANS the stored frequency is used. The counts vector is
// sized up to a multiple of Histogram::kRounding.
Histogram HistogramFromSymbolInfo(
    const std::vector<ANSEncSymbolInfo>& encoding_info, bool use_prefix_code) {
  const size_t num_symbols = encoding_info.size();
  Histogram result;
  result.counts.resize(DivCeil(num_symbols, Histogram::kRounding) *
                       Histogram::kRounding);
  result.total_count = 0;
  for (size_t sym = 0; sym < num_symbols; ++sym) {
    const ANSEncSymbolInfo& entry = encoding_info[sym];
    int weight;
    if (!use_prefix_code) {
      weight = entry.freq_;
    } else if (entry.depth) {
      weight = 1u << (PREFIX_MAX_BITS - entry.depth);
    } else {
      weight = 0;
    }
    result.counts[sym] = weight;
    result.total_count += weight;
  }
  return result;
}
693
694
}  // namespace
695
696
// Chooses a HybridUintConfig for every clustered histogram and sets
// `log_alpha_size`.
//
// With a fixed `uint_method`, or in streaming mode, every histogram gets
// `params.UintConfig()`. For kBest/kFast a list of candidate configs is tried:
// the tokens are re-histogrammed under each candidate and the candidate with
// the lowest population cost + extra bits + signalling cost wins per
// histogram. Finally `clustered_histograms` is rebuilt with the chosen
// configs and `log_alpha_size` is derived from the largest token seen.
Status EntropyEncodingData::ChooseUintConfigs(
    const HistogramParams& params,
    const std::vector<std::vector<Token>>& tokens,
    std::vector<Histogram>& clustered_histograms) {
  // Set sane default `log_alpha_size`.
  if (use_prefix_code) {
    log_alpha_size = PREFIX_MAX_BITS;
  } else {
    // TODO(szabadka) Figure out if we can use lower values in streaming mode.
    log_alpha_size = (params.streaming_mode || lz77.enabled) ? 8 : 7;
  }

  if (ans_fuzzer_friendly_) {
    uint_config.assign(1, HybridUintConfig(7, 0, 0));
    return true;
  }

  const size_t num_histograms = clustered_histograms.size();
  uint_config.assign(num_histograms, params.UintConfig());

  using Method = HistogramParams::HybridUintMethod;
  const bool want_best = params.uint_method == Method::kBest;
  const bool want_fast = params.uint_method == Method::kFast;
  // A fixed config is used as is; an adaptive one still sticks with the
  // default in streaming mode.
  if ((!want_best && !want_fast) || params.streaming_mode) {
    return true;
  }

  // Brute-force method that tries a few options.
  std::vector<HybridUintConfig> configs;
  if (want_best) {
    configs = {
        HybridUintConfig(4, 2, 0),  // default
        HybridUintConfig(4, 1, 0),  // less precise
        HybridUintConfig(4, 2, 1),  // add sign
        HybridUintConfig(4, 2, 2),  // add sign+parity
        HybridUintConfig(4, 1, 2),  // add parity but less msb
        // Same as above, but more direct coding.
        HybridUintConfig(5, 2, 0), HybridUintConfig(5, 1, 0),
        HybridUintConfig(5, 2, 1), HybridUintConfig(5, 2, 2),
        HybridUintConfig(5, 1, 2),
        // Same as above, but less direct coding.
        HybridUintConfig(3, 2, 0), HybridUintConfig(3, 1, 0),
        HybridUintConfig(3, 2, 1), HybridUintConfig(3, 1, 2),
        // For near-lossless.
        HybridUintConfig(4, 1, 3), HybridUintConfig(5, 1, 4),
        HybridUintConfig(5, 2, 3), HybridUintConfig(6, 1, 5),
        HybridUintConfig(6, 2, 4), HybridUintConfig(6, 0, 0),
        // Other
        HybridUintConfig(0, 0, 0),   // varlenuint
        HybridUintConfig(2, 0, 1),   // works well for ctx map
        HybridUintConfig(7, 0, 0),   // direct coding
        HybridUintConfig(8, 0, 0),   // direct coding
        HybridUintConfig(9, 0, 0),   // direct coding
        HybridUintConfig(10, 0, 0),  // direct coding
        HybridUintConfig(11, 0, 0),  // direct coding
        HybridUintConfig(12, 0, 0),  // direct coding
    };
  } else if (want_fast) {
    configs = {
        HybridUintConfig(4, 2, 0),  // default
        HybridUintConfig(4, 1, 2),  // add parity but less msb
        HybridUintConfig(0, 0, 0),  // smallest histograms
        HybridUintConfig(2, 0, 1),  // works well for ctx map
    };
  }

  std::vector<float> costs(num_histograms, std::numeric_limits<float>::max());
  std::vector<uint32_t> extra_bits(num_histograms);
  std::vector<uint8_t> is_valid(num_histograms);
  // Wider histograms are assigned max cost in PopulationCost anyway
  // and therefore will not be used
  const size_t max_alpha = ANS_MAX_ALPHABET_SIZE;
  for (HybridUintConfig cfg : configs) {
    // Reset the per-candidate state.
    is_valid.assign(num_histograms, true);
    extra_bits.assign(num_histograms, 0);
    for (auto& histo : clustered_histograms) {
      histo.Clear();
    }

    // Histogram all tokens under the candidate config.
    for (const auto& stream : tokens) {
      for (const auto& token : stream) {
        // TODO(veluca): do not ignore lz77 commands.
        if (token.is_lz77_length) continue;
        const size_t cluster = context_map[token.context];
        if (!is_valid[cluster]) continue;
        uint32_t tok, nbits, bits;
        cfg.Encode(token.value, &tok, &nbits, &bits);
        const bool too_wide = tok >= max_alpha;
        const bool clashes_with_lz77 = lz77.enabled && tok >= lz77.min_symbol;
        if (too_wide || clashes_with_lz77) {
          // This candidate cannot represent the cluster; drop it.
          is_valid[cluster] = JXL_FALSE;
          continue;
        }
        extra_bits[cluster] += nbits;
        clustered_histograms[cluster].Add(tok);
      }
    }

    // Keep the candidate wherever it beats the best cost so far.
    for (size_t i = 0; i < num_histograms; i++) {
      if (!is_valid[i]) continue;
      JXL_ASSIGN_OR_RETURN(float cost,
                           clustered_histograms[i].ANSPopulationCost());
      cost += extra_bits[i];
      // add signaling cost of the hybriduintconfig itself
      cost += CeilLog2Nonzero(cfg.split_exponent + 1);
      cost += CeilLog2Nonzero(cfg.split_exponent - cfg.msb_in_token + 1);
      if (cost < costs[i]) {
        uint_config[i] = cfg;
        costs[i] = cost;
      }
    }
  }

  // Rebuild histograms.
  for (auto& histo : clustered_histograms) {
    histo.Clear();
  }
  // `log_alpha_size - 5` is encoded in the header, so min is 5.
  size_t log_size = 5;
  for (const auto& stream : tokens) {
    for (const auto& token : stream) {
      const size_t cluster = context_map[token.context];
      uint32_t tok, nbits, bits;
      if (token.is_lz77_length) {
        lz77.length_uint_config.Encode(token.value, &tok, &nbits, &bits);
        tok += lz77.min_symbol;
      } else {
        uint_config[cluster].Encode(token.value, &tok, &nbits, &bits);
      }
      clustered_histograms[cluster].Add(tok);
      while (tok >= (1u << log_size)) ++log_size;
    }
  }
  size_t max_log_alpha_size = use_prefix_code ? PREFIX_MAX_BITS : 8;
  JXL_ENSURE(log_size <= max_log_alpha_size);

  if (use_prefix_code) {
    log_alpha_size = PREFIX_MAX_BITS;
  } else {
    log_alpha_size = log_size;
  }

  return true;
}
842
843
// NOTE: `layer` is only for clustered_entropy; caller does ReclaimAndCharge.
844
// Returns cost (in bits).
845
// Clusters the histograms in `builder`, extends `context_map`, chooses the
// uint configs and builds (and, when `writer` is non-null, stores) the
// entropy code of every new cluster. Returns the accumulated cost in bits;
// when `writer` is null the header part is measured with a SizeWriter.
// In streaming mode each histogram is written to its own BitWriter in
// `encoded_histograms` and then appended to `writer`.
StatusOr<size_t> EntropyEncodingData::BuildAndStoreEntropyCodes(
    JxlMemoryManager* memory_manager, const HistogramParams& params,
    const std::vector<std::vector<Token>>& tokens,
    const std::vector<Histogram>& builder, BitWriter* writer, LayerType layer,
    AuxOut* aux_out) {
  // Histograms that already have encoding info are reconstructed first.
  const size_t prev_histograms = encoding_info.size();
  std::vector<Histogram> clustered_histograms;
  for (const auto& symbol_info : encoding_info) {
    clustered_histograms.push_back(
        HistogramFromSymbolInfo(symbol_info, use_prefix_code));
  }

  const size_t context_offset = context_map.size();
  context_map.resize(context_offset + builder.size());
  if (builder.size() <= 1) {
    JXL_ENSURE(encoding_info.empty());
    clustered_histograms.push_back(builder[0]);
  } else {
    if (ans_fuzzer_friendly_) {
      // Single cluster with a flat histogram over a power-of-two alphabet.
      JXL_ENSURE(encoding_info.empty());
      std::fill(context_map.begin(), context_map.end(), 0);
      size_t max_symbol = 0;
      for (const Histogram& h : builder) {
        max_symbol = std::max(h.counts.size(), max_symbol);
      }
      size_t num_symbols = 1 << CeilLog2Nonzero(max_symbol + 1);
      clustered_histograms.resize(1);
      clustered_histograms[0].Clear();
      for (size_t sym = 0; sym < num_symbols; sym++) {
        clustered_histograms[0].Add(sym);
      }
    } else {
      std::vector<uint32_t> histogram_symbols;
      JXL_RETURN_IF_ERROR(ClusterHistograms(params, builder, kClustersLimit,
                                            &clustered_histograms,
                                            &histogram_symbols));
      for (size_t ctx = 0; ctx < builder.size(); ++ctx) {
        context_map[context_offset + ctx] =
            static_cast<uint8_t>(histogram_symbols[ctx]);
      }
    }
    if (writer != nullptr) {
      JXL_RETURN_IF_ERROR(EncodeContextMap(
          context_map, clustered_histograms.size(), writer, layer, aux_out));
    }
  }

  if (aux_out != nullptr) {
    auto& clustered_entropy = aux_out->layer(layer).clustered_entropy;
    for (size_t idx = prev_histograms; idx < clustered_histograms.size();
         ++idx) {
      clustered_entropy += clustered_histograms[idx].ShannonEntropy();
    }
  }

  JXL_RETURN_IF_ERROR(ChooseUintConfigs(params, tokens, clustered_histograms));

  SizeWriter size_writer;  // Used if writer == nullptr to estimate costs.
  size_t cost = use_prefix_code ? 1 : 3;

  if (writer != nullptr) {
    writer->Write(1, TO_JXL_BOOL(use_prefix_code));
    if (!use_prefix_code) writer->Write(2, log_alpha_size - 5);
    EncodeUintConfigs(uint_config, writer, log_alpha_size);
  } else {
    EncodeUintConfigs(uint_config, &size_writer, log_alpha_size);
  }

  if (use_prefix_code) {
    // Prefix codes signal every alphabet size explicitly.
    for (const auto& histo : clustered_histograms) {
      const size_t alphabet_size = std::max<size_t>(1, histo.alphabet_size());
      if (writer != nullptr) {
        StoreVarLenUint16(alphabet_size - 1, writer);
      } else {
        StoreVarLenUint16(alphabet_size - 1, &size_writer);
      }
    }
  }
  cost += size_writer.size;

  for (size_t c = prev_histograms; c < clustered_histograms.size(); ++c) {
    const size_t alphabet_size = clustered_histograms[c].alphabet_size();
    encoding_info.emplace_back();
    encoding_info.back().resize(alphabet_size);

    BitWriter* histo_writer = writer;
    if (params.streaming_mode) {
      encoded_histograms.emplace_back(memory_manager);
      histo_writer = &encoded_histograms.back();
    }

    const auto body = [&]() -> Status {
      JXL_ASSIGN_OR_RETURN(size_t ans_cost,
                           BuildAndStoreANSEncodingData(
                               memory_manager, params.ans_histogram_strategy,
                               clustered_histograms[c], histo_writer));
      cost += ans_cost;
      return true;
    };
    if (histo_writer == nullptr) {
      JXL_RETURN_IF_ERROR(body());
    } else {
      JXL_RETURN_IF_ERROR(histo_writer->WithMaxBits(
          256 + alphabet_size * 24, layer, aux_out, body,
          /*finished_histogram=*/true));
    }

    if (params.streaming_mode) {
      JXL_RETURN_IF_ERROR(writer->AppendUnaligned(*histo_writer));
    }
  }
  return cost;
}
950
951
template <typename Writer>
952
void EncodeUintConfig(const HybridUintConfig uint_config, Writer* writer,
953
22.5k
                      size_t log_alpha_size) {
954
22.5k
  writer->Write(CeilLog2Nonzero(log_alpha_size + 1),
955
22.5k
                uint_config.split_exponent);
956
22.5k
  if (uint_config.split_exponent == log_alpha_size) {
957
0
    return;  // msb/lsb don't matter.
958
0
  }
959
22.5k
  size_t nbits = CeilLog2Nonzero(uint_config.split_exponent + 1);
960
22.5k
  writer->Write(nbits, uint_config.msb_in_token);
961
22.5k
  nbits = CeilLog2Nonzero(uint_config.split_exponent -
962
22.5k
                          uint_config.msb_in_token + 1);
963
22.5k
  writer->Write(nbits, uint_config.lsb_in_token);
964
22.5k
}
void jxl::EncodeUintConfig<jxl::SizeWriter>(jxl::HybridUintConfig, jxl::SizeWriter*, unsigned long)
Line
Count
Source
953
2.43k
                      size_t log_alpha_size) {
954
2.43k
  writer->Write(CeilLog2Nonzero(log_alpha_size + 1),
955
2.43k
                uint_config.split_exponent);
956
2.43k
  if (uint_config.split_exponent == log_alpha_size) {
957
0
    return;  // msb/lsb don't matter.
958
0
  }
959
2.43k
  size_t nbits = CeilLog2Nonzero(uint_config.split_exponent + 1);
960
2.43k
  writer->Write(nbits, uint_config.msb_in_token);
961
2.43k
  nbits = CeilLog2Nonzero(uint_config.split_exponent -
962
2.43k
                          uint_config.msb_in_token + 1);
963
2.43k
  writer->Write(nbits, uint_config.lsb_in_token);
964
2.43k
}
void jxl::EncodeUintConfig<jxl::BitWriter>(jxl::HybridUintConfig, jxl::BitWriter*, unsigned long)
Line
Count
Source
953
20.1k
                      size_t log_alpha_size) {
954
20.1k
  writer->Write(CeilLog2Nonzero(log_alpha_size + 1),
955
20.1k
                uint_config.split_exponent);
956
20.1k
  if (uint_config.split_exponent == log_alpha_size) {
957
0
    return;  // msb/lsb don't matter.
958
0
  }
959
20.1k
  size_t nbits = CeilLog2Nonzero(uint_config.split_exponent + 1);
960
20.1k
  writer->Write(nbits, uint_config.msb_in_token);
961
20.1k
  nbits = CeilLog2Nonzero(uint_config.split_exponent -
962
20.1k
                          uint_config.msb_in_token + 1);
963
20.1k
  writer->Write(nbits, uint_config.lsb_in_token);
964
20.1k
}
965
template <typename Writer>
966
void EncodeUintConfigs(const std::vector<HybridUintConfig>& uint_config,
967
3.53k
                       Writer* writer, size_t log_alpha_size) {
968
  // TODO(veluca): RLE?
969
22.1k
  for (const auto& cfg : uint_config) {
970
22.1k
    EncodeUintConfig(cfg, writer, log_alpha_size);
971
22.1k
  }
972
3.53k
}
void jxl::EncodeUintConfigs<jxl::BitWriter>(std::__1::vector<jxl::HybridUintConfig, std::__1::allocator<jxl::HybridUintConfig> > const&, jxl::BitWriter*, unsigned long)
Line
Count
Source
967
1.50k
                       Writer* writer, size_t log_alpha_size) {
968
  // TODO(veluca): RLE?
969
19.9k
  for (const auto& cfg : uint_config) {
970
19.9k
    EncodeUintConfig(cfg, writer, log_alpha_size);
971
19.9k
  }
972
1.50k
}
void jxl::EncodeUintConfigs<jxl::SizeWriter>(std::__1::vector<jxl::HybridUintConfig, std::__1::allocator<jxl::HybridUintConfig> > const&, jxl::SizeWriter*, unsigned long)
Line
Count
Source
967
2.02k
                       Writer* writer, size_t log_alpha_size) {
968
  // TODO(veluca): RLE?
969
2.22k
  for (const auto& cfg : uint_config) {
970
2.22k
    EncodeUintConfig(cfg, writer, log_alpha_size);
971
2.22k
  }
972
2.02k
}
973
template void EncodeUintConfigs(const std::vector<HybridUintConfig>&,
974
                                BitWriter*, size_t);
975
976
Status EncodeHistograms(const EntropyEncodingData& codes, BitWriter* writer,
977
0
                        LayerType layer, AuxOut* aux_out) {
978
0
  return writer->WithMaxBits(
979
0
      128 + kClustersLimit * 136, layer, aux_out,
980
0
      [&]() -> Status {
981
0
        JXL_RETURN_IF_ERROR(Bundle::Write(codes.lz77, writer, layer, aux_out));
982
0
        if (codes.lz77.enabled) {
983
0
          EncodeUintConfig(codes.lz77.length_uint_config, writer,
984
0
                           /*log_alpha_size=*/8);
985
0
        }
986
0
        JXL_RETURN_IF_ERROR(EncodeContextMap(codes.context_map,
987
0
                                             codes.encoding_info.size(), writer,
988
0
                                             layer, aux_out));
989
0
        writer->Write(1, TO_JXL_BOOL(codes.use_prefix_code));
990
0
        size_t log_alpha_size = 8;
991
0
        if (codes.use_prefix_code) {
992
0
          log_alpha_size = PREFIX_MAX_BITS;
993
0
        } else {
994
0
          log_alpha_size = 8;  // streaming_mode
995
0
          writer->Write(2, log_alpha_size - 5);
996
0
        }
997
0
        EncodeUintConfigs(codes.uint_config, writer, log_alpha_size);
998
0
        if (codes.use_prefix_code) {
999
0
          for (const auto& info : codes.encoding_info) {
1000
0
            StoreVarLenUint16(info.size() - 1, writer);
1001
0
          }
1002
0
        }
1003
0
        for (const auto& histo_writer : codes.encoded_histograms) {
1004
0
          JXL_RETURN_IF_ERROR(writer->AppendUnaligned(histo_writer));
1005
0
        }
1006
0
        return true;
1007
0
      },
1008
0
      /*finished_histogram=*/true);
1009
0
}
1010
1011
StatusOr<size_t> BuildAndEncodeHistograms(
1012
    JxlMemoryManager* memory_manager, const HistogramParams& params,
1013
    size_t num_contexts, std::vector<std::vector<Token>>& tokens,
1014
    EntropyEncodingData* codes, BitWriter* writer, LayerType layer,
1015
3.53k
    AuxOut* aux_out) {
1016
  // TODO(Ivan): presumably not needed - default
1017
  // if (params.initialize_global_state) codes->lz77.enabled = false;
1018
3.53k
  codes->lz77.nonserialized_distance_context = num_contexts;
1019
3.53k
  codes->lz77.min_symbol = params.force_huffman ? 512 : 224;
1020
3.53k
  std::vector<std::vector<Token>> tokens_lz77 =
1021
3.53k
      ApplyLZ77(params, num_contexts, tokens, codes->lz77);
1022
3.53k
  if (!tokens_lz77.empty()) codes->lz77.enabled = true;
1023
3.53k
  if (ans_fuzzer_friendly_) {
1024
0
    codes->lz77.length_uint_config = HybridUintConfig(10, 0, 0);
1025
0
    codes->lz77.min_symbol = 2048;
1026
0
  }
1027
1028
3.53k
  size_t cost = 0;
1029
3.53k
  const size_t max_contexts = std::min(num_contexts, kClustersLimit);
1030
3.53k
  const auto& body = [&]() -> Status {
1031
3.53k
    if (writer) {
1032
1.50k
      JXL_RETURN_IF_ERROR(Bundle::Write(codes->lz77, writer, layer, aux_out));
1033
2.02k
    } else {
1034
2.02k
      size_t ebits, bits;
1035
2.02k
      JXL_RETURN_IF_ERROR(Bundle::CanEncode(codes->lz77, &ebits, &bits));
1036
2.02k
      cost += bits;
1037
2.02k
    }
1038
3.53k
    if (codes->lz77.enabled) {
1039
365
      if (writer) {
1040
149
        size_t b = writer->BitsWritten();
1041
149
        EncodeUintConfig(codes->lz77.length_uint_config, writer,
1042
149
                         /*log_alpha_size=*/8);
1043
149
        cost += writer->BitsWritten() - b;
1044
216
      } else {
1045
216
        SizeWriter size_writer;
1046
216
        EncodeUintConfig(codes->lz77.length_uint_config, &size_writer,
1047
216
                         /*log_alpha_size=*/8);
1048
216
        cost += size_writer.size;
1049
216
      }
1050
365
      num_contexts += 1;
1051
365
      tokens = std::move(tokens_lz77);
1052
365
    }
1053
3.53k
    size_t total_tokens = 0;
1054
    // Build histograms.
1055
3.53k
    std::vector<Histogram> builder(num_contexts);
1056
3.53k
    HybridUintConfig uint_config = params.UintConfig();
1057
3.53k
    if (ans_fuzzer_friendly_) {
1058
0
      uint_config = HybridUintConfig(10, 0, 0);
1059
0
    }
1060
10.2k
    for (const auto& stream : tokens) {
1061
10.2k
      if (codes->lz77.enabled) {
1062
422k
        for (const auto& token : stream) {
1063
422k
          total_tokens++;
1064
422k
          uint32_t tok, nbits, bits;
1065
422k
          (token.is_lz77_length ? codes->lz77.length_uint_config : uint_config)
1066
422k
              .Encode(token.value, &tok, &nbits, &bits);
1067
422k
          tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
1068
422k
          JXL_DASSERT(token.context < num_contexts);
1069
422k
          builder[token.context].Add(tok);
1070
422k
        }
1071
9.92k
      } else if (num_contexts == 1) {
1072
319k
        for (const auto& token : stream) {
1073
319k
          total_tokens++;
1074
319k
          uint32_t tok, nbits, bits;
1075
319k
          uint_config.Encode(token.value, &tok, &nbits, &bits);
1076
319k
          builder[0].Add(tok);
1077
319k
        }
1078
7.68k
      } else {
1079
32.8M
        for (const auto& token : stream) {
1080
32.8M
          total_tokens++;
1081
32.8M
          uint32_t tok, nbits, bits;
1082
32.8M
          uint_config.Encode(token.value, &tok, &nbits, &bits);
1083
32.8M
          JXL_DASSERT(token.context < num_contexts);
1084
32.8M
          builder[token.context].Add(tok);
1085
32.8M
        }
1086
7.68k
      }
1087
10.2k
    }
1088
1089
3.53k
    if (params.add_missing_symbols) {
1090
0
      for (size_t c = 0; c < num_contexts; ++c) {
1091
0
        for (int symbol = 0; symbol < ANS_MAX_ALPHABET_SIZE; ++symbol) {
1092
0
          builder[c].Add(symbol);
1093
0
        }
1094
0
      }
1095
0
    }
1096
1097
3.53k
    if (params.initialize_global_state) {
1098
3.53k
      bool use_prefix_code =
1099
3.53k
          params.force_huffman || total_tokens < 100 ||
1100
3.53k
          params.clustering == HistogramParams::ClusteringType::kFastest ||
1101
3.53k
          ans_fuzzer_friendly_;
1102
3.53k
      if (!use_prefix_code) {
1103
1.25k
        bool all_singleton = true;
1104
324k
        for (size_t i = 0; i < num_contexts; i++) {
1105
323k
          if (builder[i].ShannonEntropy() >= 1e-5) {
1106
137k
            all_singleton = false;
1107
137k
          }
1108
323k
        }
1109
1.25k
        if (all_singleton) {
1110
12
          use_prefix_code = true;
1111
12
        }
1112
1.25k
      }
1113
3.53k
      codes->use_prefix_code = use_prefix_code;
1114
3.53k
    }
1115
1116
3.53k
    if (params.add_fixed_histograms) {
1117
      // TODO(szabadka) Add more fixed histograms.
1118
      // TODO(szabadka) Reduce alphabet size by choosing a non-default
1119
      // uint_config.
1120
0
      const size_t alphabet_size = ANS_MAX_ALPHABET_SIZE;
1121
0
      codes->log_alpha_size = 8;
1122
0
      JXL_ENSURE(alphabet_size == 1u << codes->log_alpha_size);
1123
0
      static_assert(ANS_MAX_ALPHABET_SIZE <= ANS_TAB_SIZE,
1124
0
                    "Alphabet does not fit table");
1125
0
      codes->encoding_info.emplace_back();
1126
0
      codes->encoding_info.back().resize(alphabet_size);
1127
0
      codes->encoded_histograms.emplace_back(memory_manager);
1128
0
      BitWriter* histo_writer = &codes->encoded_histograms.back();
1129
0
      JXL_RETURN_IF_ERROR(histo_writer->WithMaxBits(
1130
0
          256 + alphabet_size * 24, LayerType::Header, nullptr,
1131
0
          [&]() -> Status {
1132
0
            JXL_ASSIGN_OR_RETURN(
1133
0
                size_t ans_cost,
1134
0
                codes->BuildAndStoreANSEncodingData(
1135
0
                    memory_manager, params.ans_histogram_strategy,
1136
0
                    Histogram::Flat(alphabet_size, ANS_TAB_SIZE),
1137
0
                    histo_writer));
1138
0
            (void)ans_cost;
1139
0
            return true;
1140
0
          }));
1141
0
    }
1142
1143
    // Encode histograms.
1144
3.53k
    JXL_ASSIGN_OR_RETURN(
1145
3.53k
        size_t entropy_bits,
1146
3.53k
        codes->BuildAndStoreEntropyCodes(memory_manager, params, tokens,
1147
3.53k
                                         builder, writer, layer, aux_out));
1148
3.53k
    cost += entropy_bits;
1149
3.53k
    return true;
1150
3.53k
  };
1151
3.53k
  if (writer) {
1152
1.50k
    JXL_RETURN_IF_ERROR(writer->WithMaxBits(
1153
1.50k
        128 + num_contexts * 40 + max_contexts * 96, layer, aux_out, body,
1154
1.50k
        /*finished_histogram=*/true));
1155
2.02k
  } else {
1156
2.02k
    JXL_RETURN_IF_ERROR(body());
1157
2.02k
  }
1158
1159
3.53k
  if (aux_out != nullptr) {
1160
0
    aux_out->layer(layer).num_clustered_histograms +=
1161
0
        codes->encoding_info.size();
1162
0
  }
1163
3.53k
  return cost;
1164
3.53k
}
1165
1166
size_t WriteTokens(const std::vector<Token>& tokens,
1167
                   const EntropyEncodingData& codes, size_t context_offset,
1168
2.11k
                   BitWriter* writer) {
1169
2.11k
  size_t num_extra_bits = 0;
1170
2.11k
  if (codes.use_prefix_code) {
1171
19.6k
    for (const auto& token : tokens) {
1172
19.6k
      uint32_t tok, nbits, bits;
1173
19.6k
      size_t histo = codes.context_map[context_offset + token.context];
1174
19.6k
      (token.is_lz77_length ? codes.lz77.length_uint_config
1175
19.6k
                            : codes.uint_config[histo])
1176
19.6k
          .Encode(token.value, &tok, &nbits, &bits);
1177
19.6k
      tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
1178
      // Combine two calls to the BitWriter. Equivalent to:
1179
      // writer->Write(codes.encoding_info[histo][tok].depth,
1180
      //               codes.encoding_info[histo][tok].bits);
1181
      // writer->Write(nbits, bits);
1182
19.6k
      uint64_t data = codes.encoding_info[histo][tok].bits;
1183
19.6k
      data |= static_cast<uint64_t>(bits)
1184
19.6k
              << codes.encoding_info[histo][tok].depth;
1185
19.6k
      writer->Write(codes.encoding_info[histo][tok].depth + nbits, data);
1186
19.6k
      num_extra_bits += nbits;
1187
19.6k
    }
1188
583
    return num_extra_bits;
1189
583
  }
1190
1.53k
  std::vector<uint64_t> out;
1191
1.53k
  std::vector<uint8_t> out_nbits;
1192
1.53k
  out.reserve(tokens.size());
1193
1.53k
  out_nbits.reserve(tokens.size());
1194
1.53k
  uint64_t allbits = 0;
1195
1.53k
  size_t numallbits = 0;
1196
  // Writes in *reversed* order.
1197
66.2M
  auto addbits = [&](size_t bits, size_t nbits) {
1198
66.2M
    if (JXL_UNLIKELY(nbits)) {
1199
8.43M
      JXL_DASSERT(bits >> nbits == 0);
1200
8.43M
      if (JXL_UNLIKELY(numallbits + nbits > BitWriter::kMaxBitsPerCall)) {
1201
1.66M
        out.push_back(allbits);
1202
1.66M
        out_nbits.push_back(numallbits);
1203
1.66M
        numallbits = allbits = 0;
1204
1.66M
      }
1205
8.43M
      allbits <<= nbits;
1206
8.43M
      allbits |= bits;
1207
8.43M
      numallbits += nbits;
1208
8.43M
    }
1209
66.2M
  };
1210
1.53k
  const int end = tokens.size();
1211
1.53k
  ANSCoder ans;
1212
1.53k
  if (codes.lz77.enabled || codes.context_map.size() > 1) {
1213
33.0M
    for (int i = end - 1; i >= 0; --i) {
1214
33.0M
      const Token token = tokens[i];
1215
33.0M
      const uint8_t histo = codes.context_map[context_offset + token.context];
1216
33.0M
      uint32_t tok, nbits, bits;
1217
33.0M
      (token.is_lz77_length ? codes.lz77.length_uint_config
1218
33.0M
                            : codes.uint_config[histo])
1219
33.0M
          .Encode(tokens[i].value, &tok, &nbits, &bits);
1220
33.0M
      tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
1221
33.0M
      const ANSEncSymbolInfo& info = codes.encoding_info[histo][tok];
1222
33.0M
      JXL_DASSERT(info.freq_ > 0);
1223
      // Extra bits first as this is reversed.
1224
33.0M
      addbits(bits, nbits);
1225
33.0M
      num_extra_bits += nbits;
1226
33.0M
      uint8_t ans_nbits = 0;
1227
33.0M
      uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits);
1228
33.0M
      addbits(ans_bits, ans_nbits);
1229
33.0M
    }
1230
1.50k
  } else {
1231
82.3k
    for (int i = end - 1; i >= 0; --i) {
1232
82.2k
      uint32_t tok, nbits, bits;
1233
82.2k
      codes.uint_config[0].Encode(tokens[i].value, &tok, &nbits, &bits);
1234
82.2k
      const ANSEncSymbolInfo& info = codes.encoding_info[0][tok];
1235
      // Extra bits first as this is reversed.
1236
82.2k
      addbits(bits, nbits);
1237
82.2k
      num_extra_bits += nbits;
1238
82.2k
      uint8_t ans_nbits = 0;
1239
82.2k
      uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits);
1240
82.2k
      addbits(ans_bits, ans_nbits);
1241
82.2k
    }
1242
31
  }
1243
1.53k
  const uint32_t state = ans.GetState();
1244
1.53k
  writer->Write(32, state);
1245
1.53k
  writer->Write(numallbits, allbits);
1246
1.66M
  for (int i = out.size(); i > 0; --i) {
1247
1.66M
    writer->Write(out_nbits[i - 1], out[i - 1]);
1248
1.66M
  }
1249
1.53k
  return num_extra_bits;
1250
2.11k
}
1251
1252
Status WriteTokens(const std::vector<Token>& tokens,
1253
                   const EntropyEncodingData& codes, size_t context_offset,
1254
1.59k
                   BitWriter* writer, LayerType layer, AuxOut* aux_out) {
1255
  // Theoretically, we could have 15 prefix code bits + 31 extra bits.
1256
1.59k
  return writer->WithMaxBits(
1257
1.59k
      46 * tokens.size() + 32 * 1024 * 4, layer, aux_out, [&] {
1258
1.59k
        size_t num_extra_bits =
1259
1.59k
            WriteTokens(tokens, codes, context_offset, writer);
1260
1.59k
        if (aux_out != nullptr) {
1261
0
          aux_out->layer(layer).extra_bits += num_extra_bits;
1262
0
        }
1263
1.59k
        return true;
1264
1.59k
      });
1265
1.59k
}
1266
1267
0
void SetANSFuzzerFriendly(bool ans_fuzzer_friendly) {
1268
#if JXL_IS_DEBUG_BUILD  // Guard against accidental / malicious changes.
1269
  ans_fuzzer_friendly_ = ans_fuzzer_friendly;
1270
#endif
1271
0
}
1272
1273
HistogramParams HistogramParams::ForModular(
1274
    const CompressParams& cparams,
1275
566
    const std::vector<uint8_t>& extra_dc_precision, bool streaming_mode) {
1276
566
  HistogramParams params;
1277
566
  params.streaming_mode = streaming_mode;
1278
566
  if (cparams.speed_tier > SpeedTier::kKitten) {
1279
566
    params.clustering = HistogramParams::ClusteringType::kFast;
1280
566
    params.ans_histogram_strategy =
1281
566
        cparams.speed_tier > SpeedTier::kThunder
1282
566
            ? HistogramParams::ANSHistogramStrategy::kFast
1283
566
            : HistogramParams::ANSHistogramStrategy::kApproximate;
1284
566
    params.lz77_method =
1285
566
        cparams.modular_mode && cparams.speed_tier <= SpeedTier::kHare
1286
566
            ? HistogramParams::LZ77Method::kRLE
1287
566
            : HistogramParams::LZ77Method::kNone;
1288
    // Near-lossless DC, as well as modular mode, require choosing hybrid uint
1289
    // more carefully.
1290
566
    if ((!extra_dc_precision.empty() && extra_dc_precision[0] != 0) ||
1291
566
        (cparams.modular_mode && cparams.speed_tier < SpeedTier::kCheetah)) {
1292
380
      params.uint_method = HistogramParams::HybridUintMethod::kFast;
1293
380
    } else {
1294
186
      params.uint_method = HistogramParams::HybridUintMethod::kNone;
1295
186
    }
1296
566
  } else if (cparams.speed_tier <= SpeedTier::kTortoise) {
1297
0
    params.lz77_method = HistogramParams::LZ77Method::kOptimal;
1298
0
  } else {
1299
0
    params.lz77_method = HistogramParams::LZ77Method::kLZ77;
1300
0
  }
1301
566
  if (cparams.decoding_speed_tier >= 2) {
1302
0
    params.max_histograms = 12;
1303
0
  }
1304
566
  if ((cparams.decoding_speed_tier >= 3 ||
1305
566
       cparams.options.predictor == Predictor::Zero) &&
1306
566
      cparams.modular_mode) {
1307
0
    params.lz77_method = cparams.speed_tier >= SpeedTier::kCheetah
1308
0
                             ? HistogramParams::LZ77Method::kRLE
1309
0
                         : cparams.speed_tier >= SpeedTier::kKitten
1310
0
                             ? HistogramParams::LZ77Method::kLZ77
1311
0
                             : HistogramParams::LZ77Method::kOptimal;
1312
0
  }
1313
566
  return params;
1314
566
}
1315
}  // namespace jxl