Coverage Report

Created: 2025-07-23 08:18

/src/libjxl/lib/jxl/enc_ans.cc
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#include "lib/jxl/enc_ans.h"
7
8
#include <jxl/memory_manager.h>
9
#include <jxl/types.h>
10
11
#include <algorithm>
12
#include <array>
13
#include <cmath>
14
#include <cstddef>
15
#include <cstdint>
16
#include <limits>
17
#include <utility>
18
#include <vector>
19
20
#include "lib/jxl/ans_common.h"
21
#include "lib/jxl/ans_params.h"
22
#include "lib/jxl/base/bits.h"
23
#include "lib/jxl/base/common.h"
24
#include "lib/jxl/base/compiler_specific.h"
25
#include "lib/jxl/base/status.h"
26
#include "lib/jxl/common.h"
27
#include "lib/jxl/dec_ans.h"
28
#include "lib/jxl/enc_ans_params.h"
29
#include "lib/jxl/enc_ans_simd.h"
30
#include "lib/jxl/enc_aux_out.h"
31
#include "lib/jxl/enc_cluster.h"
32
#include "lib/jxl/enc_context_map.h"
33
#include "lib/jxl/enc_fields.h"
34
#include "lib/jxl/enc_huffman.h"
35
#include "lib/jxl/enc_lz77.h"
36
#include "lib/jxl/enc_params.h"
37
#include "lib/jxl/fields.h"
38
#include "lib/jxl/memory_manager_internal.h"
39
#include "lib/jxl/modular/options.h"
40
#include "lib/jxl/simd_util.h"
41
42
namespace jxl {
43
44
namespace {
45
46
#if (!JXL_IS_DEBUG_BUILD)
47
constexpr
48
#endif
49
    bool ans_fuzzer_friendly_ = false;
50
51
const int kMaxNumSymbolsForSmallCode = 2;
52
53
template <typename Writer>
54
3.99M
void StoreVarLenUint8(size_t n, Writer* writer) {
55
3.99M
  JXL_DASSERT(n <= 255);
56
3.99M
  if (n == 0) {
57
310k
    writer->Write(1, 0);
58
3.68M
  } else {
59
3.68M
    writer->Write(1, 1);
60
3.68M
    size_t nbits = FloorLog2Nonzero(n);
61
3.68M
    writer->Write(3, nbits);
62
3.68M
    writer->Write(nbits, n - (1ULL << nbits));
63
3.68M
  }
64
3.99M
}
enc_ans.cc:void jxl::(anonymous namespace)::StoreVarLenUint8<jxl::SizeWriter>(unsigned long, jxl::SizeWriter*)
Line
Count
Source
54
3.85M
void StoreVarLenUint8(size_t n, Writer* writer) {
55
3.85M
  JXL_DASSERT(n <= 255);
56
3.85M
  if (n == 0) {
57
299k
    writer->Write(1, 0);
58
3.55M
  } else {
59
3.55M
    writer->Write(1, 1);
60
3.55M
    size_t nbits = FloorLog2Nonzero(n);
61
3.55M
    writer->Write(3, nbits);
62
3.55M
    writer->Write(nbits, n - (1ULL << nbits));
63
3.55M
  }
64
3.85M
}
enc_ans.cc:void jxl::(anonymous namespace)::StoreVarLenUint8<jxl::BitWriter>(unsigned long, jxl::BitWriter*)
Line
Count
Source
54
140k
void StoreVarLenUint8(size_t n, Writer* writer) {
55
140k
  JXL_DASSERT(n <= 255);
56
140k
  if (n == 0) {
57
10.8k
    writer->Write(1, 0);
58
130k
  } else {
59
130k
    writer->Write(1, 1);
60
130k
    size_t nbits = FloorLog2Nonzero(n);
61
130k
    writer->Write(3, nbits);
62
130k
    writer->Write(nbits, n - (1ULL << nbits));
63
130k
  }
64
140k
}
65
66
template <typename Writer>
67
22.5k
void StoreVarLenUint16(size_t n, Writer* writer) {
68
22.5k
  JXL_DASSERT(n <= 65535);
69
22.5k
  if (n == 0) {
70
327
    writer->Write(1, 0);
71
22.2k
  } else {
72
22.2k
    writer->Write(1, 1);
73
22.2k
    size_t nbits = FloorLog2Nonzero(n);
74
22.2k
    writer->Write(4, nbits);
75
22.2k
    writer->Write(nbits, n - (1ULL << nbits));
76
22.2k
  }
77
22.5k
}
enc_ans.cc:void jxl::(anonymous namespace)::StoreVarLenUint16<jxl::BitWriter>(unsigned long, jxl::BitWriter*)
Line
Count
Source
67
6.48k
void StoreVarLenUint16(size_t n, Writer* writer) {
68
6.48k
  JXL_DASSERT(n <= 65535);
69
6.48k
  if (n == 0) {
70
327
    writer->Write(1, 0);
71
6.15k
  } else {
72
6.15k
    writer->Write(1, 1);
73
6.15k
    size_t nbits = FloorLog2Nonzero(n);
74
6.15k
    writer->Write(4, nbits);
75
6.15k
    writer->Write(nbits, n - (1ULL << nbits));
76
6.15k
  }
77
6.48k
}
enc_ans.cc:void jxl::(anonymous namespace)::StoreVarLenUint16<jxl::SizeWriter>(unsigned long, jxl::SizeWriter*)
Line
Count
Source
67
16.0k
void StoreVarLenUint16(size_t n, Writer* writer) {
68
16.0k
  JXL_DASSERT(n <= 65535);
69
16.0k
  if (n == 0) {
70
0
    writer->Write(1, 0);
71
16.0k
  } else {
72
16.0k
    writer->Write(1, 1);
73
16.0k
    size_t nbits = FloorLog2Nonzero(n);
74
16.0k
    writer->Write(4, nbits);
75
16.0k
    writer->Write(nbits, n - (1ULL << nbits));
76
16.0k
  }
77
16.0k
}
78
79
class ANSEncodingHistogram {
80
 public:
81
106k
  const std::vector<ANSHistBin>& Counts() const { return counts_; }
82
544k
  float Cost() const { return cost_; }
83
  // The only way to construct valid histogram for ANS encoding
84
  static StatusOr<ANSEncodingHistogram> ComputeBest(
85
      const Histogram& histo,
86
544k
      HistogramParams::ANSHistogramStrategy ans_histogram_strategy) {
87
544k
    ANSEncodingHistogram result;
88
89
544k
    result.alphabet_size_ = histo.alphabet_size();
90
544k
    if (result.alphabet_size_ > ANS_MAX_ALPHABET_SIZE)
91
0
      return JXL_FAILURE("Too many entries in an ANS histogram");
92
93
544k
    if (result.alphabet_size_ > 0) {
94
      // Flat code
95
544k
      result.method_ = 0;
96
544k
      result.num_symbols_ = result.alphabet_size_;
97
544k
      result.counts_ = CreateFlatHistogram(result.alphabet_size_, ANS_TAB_SIZE);
98
      // in this case length can be non-suitable for SIMD - fix it
99
544k
      result.counts_.resize(histo.counts.size());
100
544k
      SizeWriter writer;
101
544k
      JXL_RETURN_IF_ERROR(result.Encode(&writer));
102
544k
      result.cost_ = writer.size + EstimateDataBitsFlat(histo);
103
544k
    } else {
104
      // Empty histogram
105
0
      result.method_ = 1;
106
0
      result.num_symbols_ = 0;
107
0
      result.cost_ = 3;
108
0
      return result;
109
0
    }
110
111
544k
    size_t symbol_count = 0;
112
18.5M
    for (size_t n = 0; n < result.alphabet_size_; ++n) {
113
17.9M
      if (histo.counts[n] > 0) {
114
8.72M
        if (symbol_count < kMaxNumSymbolsForSmallCode) {
115
1.04M
          result.symbols_[symbol_count] = n;
116
1.04M
        }
117
8.72M
        ++symbol_count;
118
8.72M
      }
119
17.9M
    }
120
544k
    result.num_symbols_ = symbol_count;
121
544k
    if (symbol_count == 1) {
122
      // Single-bin histogram
123
48.2k
      result.method_ = 1;
124
48.2k
      result.counts_ = histo.counts;
125
48.2k
      result.counts_[result.symbols_[0]] = ANS_TAB_SIZE;
126
48.2k
      SizeWriter writer;
127
48.2k
      JXL_RETURN_IF_ERROR(result.Encode(&writer));
128
48.2k
      result.cost_ = writer.size;
129
48.2k
      return result;
130
48.2k
    }
131
132
    // Here min 2 symbols
133
496k
    ANSEncodingHistogram normalized = result;
134
1.94M
    auto try_shift = [&](uint32_t shift) -> Status {
135
      // `shift = 12` and `shift = 11` are the same
136
1.94M
      normalized.method_ = std::min(shift, ANS_LOG_TAB_SIZE - 1) + 1;
137
138
1.94M
      if (!normalized.RebalanceHistogram(histo)) {
139
0
        return JXL_FAILURE("Logic error: couldn't rebalance a histogram");
140
0
      }
141
1.94M
      SizeWriter writer;
142
1.94M
      JXL_RETURN_IF_ERROR(normalized.Encode(&writer));
143
1.94M
      normalized.cost_ = writer.size + normalized.EstimateDataBits(histo);
144
1.94M
      if (normalized.cost_ < result.cost_) {
145
497k
        result = normalized;
146
497k
      }
147
1.94M
      return true;
148
1.94M
    };
149
150
496k
    switch (ans_histogram_strategy) {
151
10.1k
      case HistogramParams::ANSHistogramStrategy::kPrecise:
152
132k
        for (uint32_t shift = 0; shift < ANS_LOG_TAB_SIZE; shift++) {
153
122k
          JXL_RETURN_IF_ERROR(try_shift(shift));
154
122k
        }
155
10.1k
        break;
156
92.0k
      case HistogramParams::ANSHistogramStrategy::kApproximate:
157
736k
        for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE; shift += 2) {
158
644k
          JXL_RETURN_IF_ERROR(try_shift(shift));
159
644k
        }
160
92.0k
        break;
161
394k
      case HistogramParams::ANSHistogramStrategy::kFast:
162
394k
        JXL_RETURN_IF_ERROR(try_shift(0));
163
394k
        JXL_RETURN_IF_ERROR(try_shift(ANS_LOG_TAB_SIZE / 2));
164
394k
        JXL_RETURN_IF_ERROR(try_shift(ANS_LOG_TAB_SIZE));
165
394k
        break;
166
496k
    }
167
168
      // Sanity check
169
496k
#if JXL_IS_DEBUG_BUILD
170
496k
    JXL_ENSURE(histo.counts.size() == result.counts_.size());
171
496k
    ANSHistBin total = 0;  // Used only in assert.
172
18.3M
    for (size_t i = 0; i < result.alphabet_size_; ++i) {
173
17.8M
      JXL_ENSURE(result.counts_[i] >= 0);
174
      // For non-flat histogram values should be zero or non-zero simultaneously
175
      // for the same symbol in both initial and normalized histograms.
176
17.8M
      JXL_ENSURE(result.method_ == 0 ||
177
17.8M
                 (histo.counts[i] > 0) == (result.counts_[i] > 0));
178
      // Check accuracy of the histogram values
179
17.8M
      if (result.method_ > 0 && result.counts_[i] > 0 &&
180
17.8M
          i != result.omit_pos_) {
181
6.29M
        int logcounts = FloorLog2Nonzero<uint32_t>(result.counts_[i]);
182
6.29M
        int bitcount =
183
6.29M
            GetPopulationCountPrecision(logcounts, result.method_ - 1);
184
6.29M
        int drop_bits = logcounts - bitcount;
185
        // Check that the value is divisible by 2^drop_bits
186
6.29M
        JXL_ENSURE((result.counts_[i] & ((1 << drop_bits) - 1)) == 0);
187
6.29M
      }
188
17.8M
      total += result.counts_[i];
189
17.8M
    }
190
2.16M
    for (size_t i = result.alphabet_size_; i < result.counts_.size(); ++i) {
191
1.66M
      JXL_ENSURE(histo.counts[i] == 0);
192
1.66M
      JXL_ENSURE(result.counts_[i] == 0);
193
1.66M
    }
194
496k
    JXL_ENSURE((histo.total_count == 0) || (total == ANS_TAB_SIZE));
195
496k
#endif
196
496k
    return result;
197
496k
  }
198
199
  template <typename Writer>
200
2.64M
  Status Encode(Writer* writer) {
201
    // The check ensures also that all RLE sequences can be
202
    // encoded by `StoreVarLenUint8`
203
2.64M
    JXL_ENSURE(alphabet_size_ <= ANS_MAX_ALPHABET_SIZE);
204
205
    /// Flat histogram.
206
2.64M
    if (method_ == 0) {
207
      // Mark non-small tree.
208
548k
      writer->Write(1, 0);
209
      // Mark uniform histogram.
210
548k
      writer->Write(1, 1);
211
548k
      JXL_ENSURE(alphabet_size_ > 0);
212
      // Encode alphabet size.
213
548k
      StoreVarLenUint8(alphabet_size_ - 1, writer);
214
215
548k
      return true;
216
548k
    }
217
218
    /// Small tree.
219
2.09M
    if (num_symbols_ <= kMaxNumSymbolsForSmallCode) {
220
      // Small tree marker to encode 1-2 symbols.
221
66.4k
      writer->Write(1, 1);
222
66.4k
      if (num_symbols_ == 0) {
223
0
        writer->Write(1, 0);
224
0
        StoreVarLenUint8(0, writer);
225
66.4k
      } else {
226
66.4k
        writer->Write(1, num_symbols_ - 1);
227
148k
        for (size_t i = 0; i < num_symbols_; ++i) {
228
81.6k
          StoreVarLenUint8(symbols_[i], writer);
229
81.6k
        }
230
66.4k
      }
231
66.4k
      if (num_symbols_ == 2) {
232
15.2k
        writer->Write(ANS_LOG_TAB_SIZE, counts_[symbols_[0]]);
233
15.2k
      }
234
235
66.4k
      return true;
236
66.4k
    }
237
238
    /// General tree.
239
    // Mark non-small tree.
240
2.03M
    writer->Write(1, 0);
241
    // Mark non-flat histogram.
242
2.03M
    writer->Write(1, 0);
243
244
    // Elias gamma-like code for `shift = method - 1`. Only difference is that
245
    // if the number of bits to be encoded is equal to `upper_bound_log`,
246
    // we skip the terminating 0 in unary coding.
247
2.03M
    int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1);
248
2.03M
    int log = FloorLog2Nonzero(method_);
249
2.03M
    writer->Write(log, (1 << log) - 1);
250
2.03M
    if (log != upper_bound_log) writer->Write(1, 0);
251
2.03M
    writer->Write(log, ((1 << log) - 1) & method_);
252
253
    // Since `num_symbols_ >= 3`, we know that `alphabet_size_ >= 3`, therefore
254
    // we encode `alphabet_size_ - 3`.
255
2.03M
    StoreVarLenUint8(alphabet_size_ - 3, writer);
256
257
    // Precompute sequences for RLE encoding. Contains the number of identical
258
    // values starting at a given index. Only contains that value at the first
259
    // element of the series.
260
2.03M
    uint8_t same[ANS_MAX_ALPHABET_SIZE] = {};
261
2.03M
    size_t last = 0;
262
75.0M
    for (size_t i = 1; i <= alphabet_size_; i++) {
263
      // Store the sequence length once different symbol reached, or we are
264
      // near the omit_pos_, or we're at the end. We don't support including the
265
      // omit_pos_ in an RLE sequence because this value may use a different
266
      // amount of log2 bits than standard, it is too complex to handle in the
267
      // decoder.
268
72.9M
      if (i == alphabet_size_ || i == omit_pos_ || i == omit_pos_ + 1 ||
269
72.9M
          counts_[i] != counts_[last]) {
270
38.6M
        same[last] = i - last;
271
38.6M
        last = i;
272
38.6M
      }
273
72.9M
    }
274
275
2.03M
    uint8_t bit_width[ANS_MAX_ALPHABET_SIZE] = {};
276
    // Use shortest possible Huffman code to encode `omit_pos` (see
277
    // `kBitWidthLengths`). `bit_width` value at `omit_pos` should be the
278
    // first of maximal values in the whole `bit_width` array, so it can be
279
    // increased without changing that property
280
2.03M
    int omit_width = 10;
281
75.0M
    for (size_t i = 0; i < alphabet_size_; ++i) {
282
72.9M
      if (i != omit_pos_ && counts_[i] > 0) {
283
32.8M
        bit_width[i] = FloorLog2Nonzero<uint32_t>(counts_[i]) + 1;
284
32.8M
        omit_width = std::max(omit_width, bit_width[i] + int{i < omit_pos_});
285
32.8M
      }
286
72.9M
    }
287
2.03M
    bit_width[omit_pos_] = static_cast<uint8_t>(omit_width);
288
289
    // The bit widths are encoded with a static Huffman code.
290
    // The last symbol is used as RLE sequence.
291
2.03M
    constexpr uint8_t kBitWidthLengths[ANS_LOG_TAB_SIZE + 2] = {
292
2.03M
        5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 6, 7, 7,
293
2.03M
    };
294
2.03M
    constexpr uint8_t kBitWidthSymbols[ANS_LOG_TAB_SIZE + 2] = {
295
2.03M
        17, 11, 15, 3, 9, 7, 4, 2, 5, 6, 0, 33, 1, 65,
296
2.03M
    };
297
2.03M
    constexpr uint8_t kMinReps = 5;
298
2.03M
    constexpr size_t rep = ANS_LOG_TAB_SIZE + 1;
299
    // Encode count bit widths
300
48.5M
    for (size_t i = 0; i < alphabet_size_; ++i) {
301
46.5M
      writer->Write(kBitWidthLengths[bit_width[i]],
302
46.5M
                    kBitWidthSymbols[bit_width[i]]);
303
46.5M
      if (same[i] >= kMinReps) {
304
        // Encode the RLE symbol and skip the repeated ones.
305
1.33M
        writer->Write(kBitWidthLengths[rep], kBitWidthSymbols[rep]);
306
1.33M
        StoreVarLenUint8(same[i] - kMinReps, writer);
307
1.33M
        i += same[i] - 1;
308
1.33M
      }
309
46.5M
    }
310
    // Encode additional bits of accuracy
311
2.03M
    uint32_t shift = method_ - 1;
312
2.03M
    if (shift != 0) {  // otherwise `bitcount = 0`
313
35.2M
      for (size_t i = 0; i < alphabet_size_; ++i) {
314
33.7M
        if (bit_width[i] > 1 && i != omit_pos_) {
315
23.3M
          int bitcount = GetPopulationCountPrecision(bit_width[i] - 1, shift);
316
23.3M
          int drop_bits = bit_width[i] - 1 - bitcount;
317
23.3M
          JXL_DASSERT((counts_[i] & ((1 << drop_bits) - 1)) == 0);
318
23.3M
          writer->Write(bitcount, (counts_[i] >> drop_bits) - (1 << bitcount));
319
23.3M
        }
320
33.7M
        if (same[i] >= kMinReps) {
321
          // Skip symbols encoded by RLE.
322
923k
          i += same[i] - 1;
323
923k
        }
324
33.7M
      }
325
1.47M
    }
326
2.03M
    return true;
327
2.03M
  }
enc_ans.cc:jxl::Status jxl::(anonymous namespace)::ANSEncodingHistogram::Encode<jxl::SizeWriter>(jxl::SizeWriter*)
Line
Count
Source
200
2.54M
  Status Encode(Writer* writer) {
201
    // The check ensures also that all RLE sequences can be
202
    // encoded by `StoreVarLenUint8`
203
2.54M
    JXL_ENSURE(alphabet_size_ <= ANS_MAX_ALPHABET_SIZE);
204
205
    /// Flat histogram.
206
2.54M
    if (method_ == 0) {
207
      // Mark non-small tree.
208
544k
      writer->Write(1, 0);
209
      // Mark uniform histogram.
210
544k
      writer->Write(1, 1);
211
544k
      JXL_ENSURE(alphabet_size_ > 0);
212
      // Encode alphabet size.
213
544k
      StoreVarLenUint8(alphabet_size_ - 1, writer);
214
215
544k
      return true;
216
544k
    }
217
218
    /// Small tree.
219
1.99M
    if (num_symbols_ <= kMaxNumSymbolsForSmallCode) {
220
      // Small tree marker to encode 1-2 symbols.
221
62.8k
      writer->Write(1, 1);
222
62.8k
      if (num_symbols_ == 0) {
223
0
        writer->Write(1, 0);
224
0
        StoreVarLenUint8(0, writer);
225
62.8k
      } else {
226
62.8k
        writer->Write(1, num_symbols_ - 1);
227
140k
        for (size_t i = 0; i < num_symbols_; ++i) {
228
77.5k
          StoreVarLenUint8(symbols_[i], writer);
229
77.5k
        }
230
62.8k
      }
231
62.8k
      if (num_symbols_ == 2) {
232
14.6k
        writer->Write(ANS_LOG_TAB_SIZE, counts_[symbols_[0]]);
233
14.6k
      }
234
235
62.8k
      return true;
236
62.8k
    }
237
238
    /// General tree.
239
    // Mark non-small tree.
240
1.93M
    writer->Write(1, 0);
241
    // Mark non-flat histogram.
242
1.93M
    writer->Write(1, 0);
243
244
    // Elias gamma-like code for `shift = method - 1`. Only difference is that
245
    // if the number of bits to be encoded is equal to `upper_bound_log`,
246
    // we skip the terminating 0 in unary coding.
247
1.93M
    int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1);
248
1.93M
    int log = FloorLog2Nonzero(method_);
249
1.93M
    writer->Write(log, (1 << log) - 1);
250
1.93M
    if (log != upper_bound_log) writer->Write(1, 0);
251
1.93M
    writer->Write(log, ((1 << log) - 1) & method_);
252
253
    // Since `num_symbols_ >= 3`, we know that `alphabet_size_ >= 3`, therefore
254
    // we encode `alphabet_size_ - 3`.
255
1.93M
    StoreVarLenUint8(alphabet_size_ - 3, writer);
256
257
    // Precompute sequences for RLE encoding. Contains the number of identical
258
    // values starting at a given index. Only contains that value at the first
259
    // element of the series.
260
1.93M
    uint8_t same[ANS_MAX_ALPHABET_SIZE] = {};
261
1.93M
    size_t last = 0;
262
72.3M
    for (size_t i = 1; i <= alphabet_size_; i++) {
263
      // Store the sequence length once different symbol reached, or we are
264
      // near the omit_pos_, or we're at the end. We don't support including the
265
      // omit_pos_ in an RLE sequence because this value may use a different
266
      // amount of log2 bits than standard, it is too complex to handle in the
267
      // decoder.
268
70.4M
      if (i == alphabet_size_ || i == omit_pos_ || i == omit_pos_ + 1 ||
269
70.4M
          counts_[i] != counts_[last]) {
270
37.0M
        same[last] = i - last;
271
37.0M
        last = i;
272
37.0M
      }
273
70.4M
    }
274
275
1.93M
    uint8_t bit_width[ANS_MAX_ALPHABET_SIZE] = {};
276
    // Use shortest possible Huffman code to encode `omit_pos` (see
277
    // `kBitWidthLengths`). `bit_width` value at `omit_pos` should be the
278
    // first of maximal values in the whole `bit_width` array, so it can be
279
    // increased without changing that property
280
1.93M
    int omit_width = 10;
281
72.3M
    for (size_t i = 0; i < alphabet_size_; ++i) {
282
70.4M
      if (i != omit_pos_ && counts_[i] > 0) {
283
31.4M
        bit_width[i] = FloorLog2Nonzero<uint32_t>(counts_[i]) + 1;
284
31.4M
        omit_width = std::max(omit_width, bit_width[i] + int{i < omit_pos_});
285
31.4M
      }
286
70.4M
    }
287
1.93M
    bit_width[omit_pos_] = static_cast<uint8_t>(omit_width);
288
289
    // The bit widths are encoded with a static Huffman code.
290
    // The last symbol is used as RLE sequence.
291
1.93M
    constexpr uint8_t kBitWidthLengths[ANS_LOG_TAB_SIZE + 2] = {
292
1.93M
        5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 6, 7, 7,
293
1.93M
    };
294
1.93M
    constexpr uint8_t kBitWidthSymbols[ANS_LOG_TAB_SIZE + 2] = {
295
1.93M
        17, 11, 15, 3, 9, 7, 4, 2, 5, 6, 0, 33, 1, 65,
296
1.93M
    };
297
1.93M
    constexpr uint8_t kMinReps = 5;
298
1.93M
    constexpr size_t rep = ANS_LOG_TAB_SIZE + 1;
299
    // Encode count bit widths
300
46.5M
    for (size_t i = 0; i < alphabet_size_; ++i) {
301
44.6M
      writer->Write(kBitWidthLengths[bit_width[i]],
302
44.6M
                    kBitWidthSymbols[bit_width[i]]);
303
44.6M
      if (same[i] >= kMinReps) {
304
        // Encode the RLE symbol and skip the repeated ones.
305
1.29M
        writer->Write(kBitWidthLengths[rep], kBitWidthSymbols[rep]);
306
1.29M
        StoreVarLenUint8(same[i] - kMinReps, writer);
307
1.29M
        i += same[i] - 1;
308
1.29M
      }
309
44.6M
    }
310
    // Encode additional bits of accuracy
311
1.93M
    uint32_t shift = method_ - 1;
312
1.93M
    if (shift != 0) {  // otherwise `bitcount = 0`
313
34.5M
      for (size_t i = 0; i < alphabet_size_; ++i) {
314
33.0M
        if (bit_width[i] > 1 && i != omit_pos_) {
315
22.8M
          int bitcount = GetPopulationCountPrecision(bit_width[i] - 1, shift);
316
22.8M
          int drop_bits = bit_width[i] - 1 - bitcount;
317
22.8M
          JXL_DASSERT((counts_[i] & ((1 << drop_bits) - 1)) == 0);
318
22.8M
          writer->Write(bitcount, (counts_[i] >> drop_bits) - (1 << bitcount));
319
22.8M
        }
320
33.0M
        if (same[i] >= kMinReps) {
321
          // Skip symbols encoded by RLE.
322
909k
          i += same[i] - 1;
323
909k
        }
324
33.0M
      }
325
1.44M
    }
326
1.93M
    return true;
327
1.93M
  }
enc_ans.cc:jxl::Status jxl::(anonymous namespace)::ANSEncodingHistogram::Encode<jxl::BitWriter>(jxl::BitWriter*)
Line
Count
Source
200
102k
  Status Encode(Writer* writer) {
201
    // The check ensures also that all RLE sequences can be
202
    // encoded by `StoreVarLenUint8`
203
102k
    JXL_ENSURE(alphabet_size_ <= ANS_MAX_ALPHABET_SIZE);
204
205
    /// Flat histogram.
206
102k
    if (method_ == 0) {
207
      // Mark non-small tree.
208
3.42k
      writer->Write(1, 0);
209
      // Mark uniform histogram.
210
3.42k
      writer->Write(1, 1);
211
3.42k
      JXL_ENSURE(alphabet_size_ > 0);
212
      // Encode alphabet size.
213
3.42k
      StoreVarLenUint8(alphabet_size_ - 1, writer);
214
215
3.42k
      return true;
216
3.42k
    }
217
218
    /// Small tree.
219
99.0k
    if (num_symbols_ <= kMaxNumSymbolsForSmallCode) {
220
      // Small tree marker to encode 1-2 symbols.
221
3.58k
      writer->Write(1, 1);
222
3.58k
      if (num_symbols_ == 0) {
223
0
        writer->Write(1, 0);
224
0
        StoreVarLenUint8(0, writer);
225
3.58k
      } else {
226
3.58k
        writer->Write(1, num_symbols_ - 1);
227
7.68k
        for (size_t i = 0; i < num_symbols_; ++i) {
228
4.10k
          StoreVarLenUint8(symbols_[i], writer);
229
4.10k
        }
230
3.58k
      }
231
3.58k
      if (num_symbols_ == 2) {
232
524
        writer->Write(ANS_LOG_TAB_SIZE, counts_[symbols_[0]]);
233
524
      }
234
235
3.58k
      return true;
236
3.58k
    }
237
238
    /// General tree.
239
    // Mark non-small tree.
240
95.4k
    writer->Write(1, 0);
241
    // Mark non-flat histogram.
242
95.4k
    writer->Write(1, 0);
243
244
    // Elias gamma-like code for `shift = method - 1`. Only difference is that
245
    // if the number of bits to be encoded is equal to `upper_bound_log`,
246
    // we skip the terminating 0 in unary coding.
247
95.4k
    int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1);
248
95.4k
    int log = FloorLog2Nonzero(method_);
249
95.4k
    writer->Write(log, (1 << log) - 1);
250
95.4k
    if (log != upper_bound_log) writer->Write(1, 0);
251
95.4k
    writer->Write(log, ((1 << log) - 1) & method_);
252
253
    // Since `num_symbols_ >= 3`, we know that `alphabet_size_ >= 3`, therefore
254
    // we encode `alphabet_size_ - 3`.
255
95.4k
    StoreVarLenUint8(alphabet_size_ - 3, writer);
256
257
    // Precompute sequences for RLE encoding. Contains the number of identical
258
    // values starting at a given index. Only contains that value at the first
259
    // element of the series.
260
95.4k
    uint8_t same[ANS_MAX_ALPHABET_SIZE] = {};
261
95.4k
    size_t last = 0;
262
2.63M
    for (size_t i = 1; i <= alphabet_size_; i++) {
263
      // Store the sequence length once different symbol reached, or we are
264
      // near the omit_pos_, or we're at the end. We don't support including the
265
      // omit_pos_ in an RLE sequence because this value may use a different
266
      // amount of log2 bits than standard, it is too complex to handle in the
267
      // decoder.
268
2.54M
      if (i == alphabet_size_ || i == omit_pos_ || i == omit_pos_ + 1 ||
269
2.54M
          counts_[i] != counts_[last]) {
270
1.56M
        same[last] = i - last;
271
1.56M
        last = i;
272
1.56M
      }
273
2.54M
    }
274
275
95.4k
    uint8_t bit_width[ANS_MAX_ALPHABET_SIZE] = {};
276
    // Use shortest possible Huffman code to encode `omit_pos` (see
277
    // `kBitWidthLengths`). `bit_width` value at `omit_pos` should be the
278
    // first of maximal values in the whole `bit_width` array, so it can be
279
    // increased without changing that property
280
95.4k
    int omit_width = 10;
281
2.63M
    for (size_t i = 0; i < alphabet_size_; ++i) {
282
2.54M
      if (i != omit_pos_ && counts_[i] > 0) {
283
1.36M
        bit_width[i] = FloorLog2Nonzero<uint32_t>(counts_[i]) + 1;
284
1.36M
        omit_width = std::max(omit_width, bit_width[i] + int{i < omit_pos_});
285
1.36M
      }
286
2.54M
    }
287
95.4k
    bit_width[omit_pos_] = static_cast<uint8_t>(omit_width);
288
289
    // The bit widths are encoded with a static Huffman code.
290
    // The last symbol is used as RLE sequence.
291
95.4k
    constexpr uint8_t kBitWidthLengths[ANS_LOG_TAB_SIZE + 2] = {
292
95.4k
        5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 6, 7, 7,
293
95.4k
    };
294
95.4k
    constexpr uint8_t kBitWidthSymbols[ANS_LOG_TAB_SIZE + 2] = {
295
95.4k
        17, 11, 15, 3, 9, 7, 4, 2, 5, 6, 0, 33, 1, 65,
296
95.4k
    };
297
95.4k
    constexpr uint8_t kMinReps = 5;
298
95.4k
    constexpr size_t rep = ANS_LOG_TAB_SIZE + 1;
299
    // Encode count bit widths
300
2.00M
    for (size_t i = 0; i < alphabet_size_; ++i) {
301
1.90M
      writer->Write(kBitWidthLengths[bit_width[i]],
302
1.90M
                    kBitWidthSymbols[bit_width[i]]);
303
1.90M
      if (same[i] >= kMinReps) {
304
        // Encode the RLE symbol and skip the repeated ones.
305
37.9k
        writer->Write(kBitWidthLengths[rep], kBitWidthSymbols[rep]);
306
37.9k
        StoreVarLenUint8(same[i] - kMinReps, writer);
307
37.9k
        i += same[i] - 1;
308
37.9k
      }
309
1.90M
    }
310
    // Encode additional bits of accuracy
311
95.4k
    uint32_t shift = method_ - 1;
312
95.4k
    if (shift != 0) {  // otherwise `bitcount = 0`
313
775k
      for (size_t i = 0; i < alphabet_size_; ++i) {
314
738k
        if (bit_width[i] > 1 && i != omit_pos_) {
315
509k
          int bitcount = GetPopulationCountPrecision(bit_width[i] - 1, shift);
316
509k
          int drop_bits = bit_width[i] - 1 - bitcount;
317
509k
          JXL_DASSERT((counts_[i] & ((1 << drop_bits) - 1)) == 0);
318
509k
          writer->Write(bitcount, (counts_[i] >> drop_bits) - (1 << bitcount));
319
509k
        }
320
738k
        if (same[i] >= kMinReps) {
321
          // Skip symbols encoded by RLE.
322
14.4k
          i += same[i] - 1;
323
14.4k
        }
324
738k
      }
325
36.6k
    }
326
95.4k
    return true;
327
95.4k
  }
328
329
  // Fills the per-symbol encoder table `info` from the normalized counts and
  // the alias table `table` (built for `log_alpha_size`).
  //
  // For every symbol this stores its frequency, optionally the rounded-up
  // reciprocal of that frequency, and a reverse map with one slot per unit of
  // frequency. The reverse map is then populated by walking all ANS_TAB_SIZE
  // table positions through AliasTable::Lookup.
  void ANSBuildInfoTable(const AliasTable::Entry* table, size_t log_alpha_size,
                         ANSEncSymbolInfo* info) {
    // An empty alphabet still gets a single entry that owns the whole table,
    // so that the alias table stays valid for empty streams.
    const size_t num_entries = std::max(size_t{1}, alphabet_size_);
    for (size_t sym = 0; sym < num_entries; ++sym) {
      ANSEncSymbolInfo& entry = info[sym];
      const ANSHistBin freq =
          (sym == alphabet_size_) ? ANS_TAB_SIZE : counts_[sym];
      entry.freq_ = static_cast<uint16_t>(freq);
#ifdef USE_MULT_BY_RECIPROCAL
      if (freq == 0) {
        // Shouldn't matter (symbol shouldn't occur), but...
        entry.ifreq_ = 1;
      } else {
        // ceil(2^RECIPROCAL_PRECISION / freq)
        entry.ifreq_ =
            ((1ull << RECIPROCAL_PRECISION) + entry.freq_ - 1) / entry.freq_;
      }
#endif
      entry.reverse_map_.resize(freq);
    }

    const size_t log_entry_size = ANS_LOG_TAB_SIZE - log_alpha_size;
    const size_t entry_size_minus_1 = (1 << log_entry_size) - 1;
    for (int pos = 0; pos < ANS_TAB_SIZE; ++pos) {
      const AliasTable::Symbol hit =
          AliasTable::Lookup(table, pos, log_entry_size, entry_size_minus_1);
      // Position `pos` decodes to `hit.value` at offset `hit.offset`.
      info[hit.value].reverse_map_[hit.offset] = pos;
    }
  }
354
355
 private:
356
544k
  // Private default constructor: every data member starts from its in-class
  // initializer. Code outside the class obtains instances through static
  // members such as ComputeBest.
  ANSEncodingHistogram() {}
357
358
  // Fixed-point log2 LUT for values in [0, 4096]:
  // lg2[i] = round(2^31 * log2(i) / ANS_LOG_TAB_SIZE), with lg2[0] = 0.
359
  using Lg2LUT = std::array<uint32_t, ANS_TAB_SIZE + 1>;
360
  static const Lg2LUT lg2;
361
362
1.94M
  float EstimateDataBits(const Histogram& histo) {
363
1.94M
    int64_t sum = 0;
364
72.5M
    for (size_t i = 0; i < alphabet_size_; ++i) {
365
      // += histogram[i] * -log(counts[i]/total_counts)
366
70.5M
      sum += histo.counts[i] * int64_t{lg2[counts_[i]]};
367
70.5M
    }
368
1.94M
    return (histo.total_count - ldexpf(sum, -31)) * ANS_LOG_TAB_SIZE;
369
1.94M
  }
370
371
544k
  static float EstimateDataBitsFlat(const Histogram& histo) {
372
544k
    size_t len = histo.alphabet_size();
373
544k
    int64_t flat_bits = int64_t{lg2[len]} * ANS_LOG_TAB_SIZE;
374
544k
    return ldexpf(histo.total_count * flat_bits, -31);
375
544k
  }
376
377
  // One entry of the table of histogram counts that are representable for a
  // given precision shift. Entries are produced by the `allowed_counts`
  // initializer and consumed by `RebalanceHistogram`.
  struct CountsEntropy {
378
    ANSHistBin count : 16;     // allowed value of counts in a histogram bin
379
    ANSHistBin step_log : 16;  // log2 of increase step size (can use 5 bits)
380
    int32_t delta_lg2;  // change of log between that value and the next allowed
381
  };
382
383
  // For each possible shift, the array is sorted by decreasing allowed count.
384
  // Single-bin histograms are excluded before `RebalanceHistogram`, which
385
  // lets us cap counts at 4095; shifts of 11 and 12 produce the same table,
386
  // so only ANS_LOG_TAB_SIZE tables are stored.
387
  using CountsArray =
388
      std::array<std::array<CountsEntropy, ANS_TAB_SIZE>, ANS_LOG_TAB_SIZE>;
  // For each shift, maps an allowed count value to its slot in the matching
  // `CountsArray` row (only slots of allowed counts are filled in).
389
  using CountsIndex =
390
      std::array<std::array<uint16_t, ANS_TAB_SIZE>, ANS_LOG_TAB_SIZE>;
  // Bundles the allowed-count tables with their value -> slot index.
391
  struct AllowedCounts {
392
    CountsArray array;
393
    CountsIndex index;
394
  };
  // Built once by the initializer that defines this member after the class.
395
  static const AllowedCounts allowed_counts;
396
397
  // Returns the difference between largest count that can be represented and is
398
  // smaller than "count" and smallest representable count larger than "count".
399
82.8M
  static uint32_t SmallestIncrementLog(uint32_t count, uint32_t shift) {
400
82.8M
    if (count == 0) return 0;
401
45.6M
    uint32_t bits = FloorLog2Nonzero(count);
402
45.6M
    uint32_t drop_bits = bits - GetPopulationCountPrecision(bits, shift);
403
45.6M
    return drop_bits;
404
82.8M
  }
405
  // We are growing/reducing histogram step by step trying to maximize total
406
  // entropy i.e. sum of `freq[n] * log[counts[n]]` with a given sum of
407
  // `counts[n]` chosen from `allowed_counts[shift]`. This sum is balanced by
408
  // the `counts[omit_pos_]` in the highest bin of histogram. We start from
409
  // close to correct solution and each time a step with maximum entropy
410
  // increase per unit of bin change is chosen. This greedy scheme is not
411
  // guaranteed to achieve the global maximum, but cannot produce invalid
412
  // histogram. We use a fixed-point approximation for logarithms and all
413
  // arithmetic is integer besides initial approximation. Sum of `freq` and each
414
  // of `lg2[counts]` are supposed to be limited to `int32_t` range, so that the
415
  // sum of their products should not exceed `int64_t`.
416
1.94M
  bool RebalanceHistogram(const Histogram& histo) {
417
1.94M
    constexpr ANSHistBin table_size = ANS_TAB_SIZE;
418
1.94M
    uint32_t shift = method_ - 1;
419
420
1.94M
    struct EntropyDelta {
421
1.94M
      ANSHistBin freq;   // initial count
422
1.94M
      size_t count_ind;  // index of current bin value in `allowed_counts`
423
1.94M
      size_t bin_ind;    // index of current bin in `counts`
424
1.94M
    };
425
    // Penalties corresponding to different step sizes - entropy decrease in
426
    // balancing bin, step of size (1 << ANS_LOG_TAB_SIZE - 1) is not possible
427
1.94M
    std::array<int64_t, ANS_LOG_TAB_SIZE - 1> balance_inc = {};
428
1.94M
    std::array<int64_t, ANS_LOG_TAB_SIZE - 1> balance_dec = {};
429
1.94M
    const auto& ac = allowed_counts.array[shift];
430
1.94M
    const auto& ai = allowed_counts.index[shift];
431
    // TODO(ivan) separate cases of shift >= 11 - all steps are 1 there, and
432
    // possibly 10 - all relevant steps are 2.
433
    // Total entropy change by a step: increase/decrease in current bin
434
    // together with corresponding decrease/increase in the balancing bin.
435
    // Inc steps increase current bin, dec steps decrease
436
622M
    const auto delta_entropy_inc = [&](const EntropyDelta& a) {
437
622M
      return a.freq * int64_t{ac[a.count_ind].delta_lg2} -
438
622M
             balance_inc[ac[a.count_ind].step_log];
439
622M
    };
440
91.7M
    const auto delta_entropy_dec = [&](const EntropyDelta& a) {
441
91.7M
      return a.freq * int64_t{ac[a.count_ind + 1].delta_lg2} -
442
91.7M
             balance_dec[ac[a.count_ind + 1].step_log];
443
91.7M
    };
444
    // Compare steps by entropy increase per unit of histogram bin change.
445
    // Truncation is OK here, accuracy is anyway better than float
446
305M
    const auto IncLess = [&](const EntropyDelta& a, const EntropyDelta& b) {
447
305M
      return delta_entropy_inc(a) >> ac[a.count_ind].step_log <
448
305M
             delta_entropy_inc(b) >> ac[b.count_ind].step_log;
449
305M
    };
450
44.6M
    const auto DecLess = [&](const EntropyDelta& a, const EntropyDelta& b) {
451
44.6M
      return delta_entropy_dec(a) >> ac[a.count_ind + 1].step_log <
452
44.6M
             delta_entropy_dec(b) >> ac[b.count_ind + 1].step_log;
453
44.6M
    };
454
    // Vector of adjustable bins from `allowed_counts`
455
1.94M
    std::vector<EntropyDelta> bins;
456
1.94M
    bins.reserve(256);
457
458
1.94M
    double norm = double{table_size} / histo.total_count;
459
460
1.94M
    size_t remainder_pos = 0;  // highest balancing bin in the histogram
461
1.94M
    int64_t max_freq = 0;
462
1.94M
    ANSHistBin rest = table_size;  // reserve of histogram counts to distribute
463
72.5M
    for (size_t n = 0; n < alphabet_size_; ++n) {
464
70.5M
      ANSHistBin freq = histo.counts[n];
465
70.5M
      if (freq > max_freq) {
466
4.11M
        remainder_pos = n;
467
4.11M
        max_freq = freq;
468
4.11M
      }
469
470
70.5M
      double target = freq * norm;  // rounding
471
      // Keep zeros and clamp nonzero freq counts to [1, table_size)
472
70.5M
      ANSHistBin count = std::max<ANSHistBin>(round(target), freq > 0);
473
70.5M
      count = std::min<ANSHistBin>(count, table_size - 1);
474
70.5M
      uint32_t step_log = SmallestIncrementLog(count, shift);
475
70.5M
      ANSHistBin inc = 1 << step_log;
476
70.5M
      count &= ~(inc - 1);
477
478
70.5M
      counts_[n] = count;
479
70.5M
      rest -= count;
480
70.5M
      if (target > 1.0) {
481
33.1M
        bins.push_back({freq, ai[count], n});
482
33.1M
      }
483
70.5M
    }
484
485
    // Delete the highest balancing bin from adjustable by `allowed_counts`
486
1.94M
    bins.erase(std::find_if(
487
1.94M
        bins.begin(), bins.end(),
488
10.4M
        [&](const EntropyDelta& a) { return a.bin_ind == remainder_pos; }));
489
    // From now on `rest` is the height of balancing bin,
490
    // here it can be negative, but will be tracted into positive domain later
491
1.94M
    rest += counts_[remainder_pos];
492
493
1.94M
    if (!bins.empty()) {
494
1.94M
      const uint32_t max_log = ac[1].step_log;
495
11.9M
      while (true) {
496
        // Update balancing bin penalties setting guards and tractors
497
105M
        for (uint32_t log = 0; log <= max_log; ++log) {
498
94.0M
          ANSHistBin delta = 1 << log;
499
94.0M
          if (rest >= table_size) {
500
            // Tract large `rest` into allowed domain:
501
0
            balance_inc[log] = 0;  // permit all inc steps
502
0
            balance_dec[log] = 0;  // forbid all dec steps
503
94.0M
          } else if (rest > 1) {
504
            // `rest` is OK, put guards against non-possible steps
505
94.0M
            balance_inc[log] =
506
94.0M
                rest > delta  // possible step
507
94.0M
                    ? max_freq * int64_t{lg2[rest] - lg2[rest - delta]}
508
94.0M
                    : std::numeric_limits<int64_t>::max();  // forbidden
509
94.0M
            balance_dec[log] =
510
94.0M
                rest + delta < table_size  // possible step
511
94.0M
                    ? max_freq * int64_t{lg2[rest + delta] - lg2[rest]}
512
94.0M
                    : 0;  // forbidden
513
94.0M
          } else {
514
            // Tract negative or zero `rest` into positive:
515
            // forbid all inc steps
516
276
            balance_inc[log] = std::numeric_limits<int64_t>::max();
517
            // permit all dec steps
518
276
            balance_dec[log] = std::numeric_limits<int64_t>::max();
519
276
          }
520
94.0M
        }
521
        // Try to increase entropy
522
11.9M
        auto best_bin_inc = std::max_element(bins.begin(), bins.end(), IncLess);
523
11.9M
        if (delta_entropy_inc(*best_bin_inc) > 0) {
524
          // Grow the bin with the best histogram entropy increase
525
9.49M
          rest -= 1 << ac[best_bin_inc->count_ind--].step_log;
526
9.49M
        } else {
527
          // This still implies that entropy is strictly increasing each step
528
          // (or `rest` is tracted into positive domain), so we cannot loop
529
          // infinitely
530
2.49M
          auto best_bin_dec =
531
2.49M
              std::min_element(bins.begin(), bins.end(), DecLess);
532
          // Break if no reverse steps can grow entropy (or valid)
533
2.49M
          if (delta_entropy_dec(*best_bin_dec) >= 0) break;
534
          // Decrease the bin with the best histogram entropy increase
535
546k
          rest += 1 << ac[++best_bin_dec->count_ind].step_log;
536
546k
        }
537
11.9M
      }
538
      // Set counts besides the balancing bin
539
31.1M
      for (auto& a : bins) counts_[a.bin_ind] = ac[a.count_ind].count;
540
541
      // The scheme works fine if we have room to grow `bit_width` of balancing
542
      // bin, otherwise we need to put balancing bin to the first bin of 12 bit
543
      // width. In this case both that bin and balancing one should be close to
544
      // 2048 in targets, so exchange of them will not produce much worse
545
      // histogram
546
17.2M
      for (size_t n = 0; n < remainder_pos; ++n) {
547
15.2M
        if (counts_[n] >= 2048) {
548
5.21k
          counts_[remainder_pos] = counts_[n];
549
5.21k
          remainder_pos = n;
550
5.21k
          break;
551
5.21k
        }
552
15.2M
      }
553
1.94M
    }
554
    // Set balancing bin
555
1.94M
    counts_[remainder_pos] = rest;
556
1.94M
    omit_pos_ = remainder_pos;
557
558
1.94M
    return counts_[remainder_pos] > 0;
559
1.94M
  }
560
561
  // Cost associated with this histogram.
  // NOTE(review): presumably the value returned by Cost(); it is assigned
  // outside the code visible here - confirm.
  float cost_ = 0;
562
  // Encoding method; `method_ - 1` is used as the precision shift by
  // RebalanceHistogram and when writing the extra accuracy bits.
  uint32_t method_ = 0;
563
  // Index of the balancing bin chosen by RebalanceHistogram; this bin is
  // skipped when the extra accuracy bits are written.
  size_t omit_pos_ = 0;
564
  // Number of histogram bins the methods of this class iterate over.
  size_t alphabet_size_ = 0;
565
  // NOTE(review): `num_symbols_` and `symbols_` are not used in the visible
  // part of the class; presumably they describe histograms with at most
  // kMaxNumSymbolsForSmallCode symbols - confirm.
  size_t num_symbols_ = 0;
566
  size_t symbols_[kMaxNumSymbolsForSmallCode] = {};
567
  // Normalized counts, one per bin; RebalanceHistogram makes them add up to
  // ANS_TAB_SIZE.
  std::vector<ANSHistBin> counts_{};
568
};
569
570
using AEH = ANSEncodingHistogram;
571
572
250
const AEH::Lg2LUT AEH::lg2 = [] {
573
250
  Lg2LUT lg2;
574
250
  lg2[0] = 0;  // for entropy calculations it is OK
575
1.02M
  for (size_t i = 1; i < lg2.size(); ++i) {
576
1.02M
    lg2[i] = round(ldexp(log2(i) / ANS_LOG_TAB_SIZE, 31));
577
1.02M
  }
578
250
  return lg2;
579
250
}();
580
581
250
// Builds, for every precision shift, the list of representable histogram
// counts in decreasing order (`array`) together with the reverse mapping from
// a representable count to its slot (`index`).
//
// Each slot also records the step to the previous (larger) allowed count:
// `step_log` is log2 of the step size and `delta_lg2` is the fixed-point
// change of the logarithm, round(2^31 * log2(prev / curr) / ANS_LOG_TAB_SIZE).
const AEH::AllowedCounts AEH::allowed_counts = [] {
  // Value-initialize so that entries which are never filled in below (unused
  // trailing slots, index entries of non-representable counts) hold zeros
  // instead of indeterminate values.
  AllowedCounts result{};

  for (uint32_t shift = 0; shift < result.array.size(); ++shift) {
    auto& ac = result.array[shift];
    auto& ai = result.index[shift];
    ANSHistBin last = ~0;
    size_t slot = 0;
    // The largest allowed count cannot be increased any further; slot 0 keeps
    // these values because neither branch below overwrites it.
    // TODO(eustas): are those "default" values relevant?
    ac[0].delta_lg2 = 0;
    ac[0].step_log = 0;
    for (int32_t i = ac.size() - 1; i >= 0; --i) {
      // Round `i` down to the nearest representable count; several
      // consecutive `i` may map to the same value, which is stored only once.
      int32_t curr = i & ~((1 << SmallestIncrementLog(i, shift)) - 1);
      if (curr == last) continue;
      last = curr;
      ac[slot].count = curr;
      ai[curr] = slot;
      if (curr == 0) {
        // Guards against non-possible steps:
        // at max value [0] - 0 (by init), at min value - max
        ac[slot].delta_lg2 = std::numeric_limits<int32_t>::max();
        ac[slot].step_log = 0;
      } else if (slot > 0) {
        const ANSHistBin prev = ac[slot - 1].count;
        ac[slot].delta_lg2 = round(ldexp(
            log2(static_cast<double>(prev) / curr) / ANS_LOG_TAB_SIZE, 31));
        ac[slot].step_log = FloorLog2Nonzero<uint32_t>(prev - curr);
      }
      slot++;
    }
  }

  return result;
}();
616
617
}  // namespace
618
619
438k
// Returns the cost reported by the best ANS encoding histogram found with the
// fast strategy, or the largest float when this histogram has more bins than
// ANS_MAX_ALPHABET_SIZE. Errors from ComputeBest are propagated.
StatusOr<float> Histogram::ANSPopulationCost() const {
  const bool too_wide = counts.size() > ANS_MAX_ALPHABET_SIZE;
  if (too_wide) {
    return std::numeric_limits<float>::max();
  }
  constexpr auto kStrategy = HistogramParams::ANSHistogramStrategy::kFast;
  JXL_ASSIGN_OR_RETURN(ANSEncodingHistogram best,
                       ANSEncodingHistogram::ComputeBest(*this, kStrategy));
  return best.Cost();
}
629
630
// Returns an estimate or exact cost of encoding this histogram and the
631
// corresponding data.
632
StatusOr<size_t> EntropyEncodingData::BuildAndStoreANSEncodingData(
633
    JxlMemoryManager* memory_manager,
634
    HistogramParams::ANSHistogramStrategy ans_histogram_strategy,
635
129k
    const Histogram& histogram, BitWriter* writer) {
636
129k
  ANSEncSymbolInfo* info = encoding_info.back().data();
637
129k
  size_t size = histogram.alphabet_size();
638
129k
  if (use_prefix_code) {
639
22.5k
    size_t cost = 0;
640
22.5k
    if (size <= 1) return 0;
641
22.2k
    std::vector<uint32_t> histo(size);
642
187k
    for (size_t i = 0; i < size; i++) {
643
165k
      JXL_ENSURE(histogram.counts[i] >= 0);
644
165k
      histo[i] = histogram.counts[i];
645
165k
    }
646
22.2k
    std::vector<uint8_t> depths(size);
647
22.2k
    std::vector<uint16_t> bits(size);
648
22.2k
    if (writer == nullptr) {
649
16.0k
      BitWriter tmp_writer{memory_manager};
650
16.0k
      JXL_RETURN_IF_ERROR(tmp_writer.WithMaxBits(
651
16.0k
          8 * size + 8,  // safe upper bound
652
16.0k
          LayerType::Header, /*aux_out=*/nullptr, [&] {
653
16.0k
            return BuildAndStoreHuffmanTree(histo.data(), size, depths.data(),
654
16.0k
                                            bits.data(), &tmp_writer);
655
16.0k
          }));
656
16.0k
      cost = tmp_writer.BitsWritten();
657
16.0k
    } else {
658
6.15k
      size_t start = writer->BitsWritten();
659
6.15k
      JXL_RETURN_IF_ERROR(BuildAndStoreHuffmanTree(
660
6.15k
          histo.data(), size, depths.data(), bits.data(), writer));
661
6.15k
      cost = writer->BitsWritten() - start;
662
6.15k
    }
663
187k
    for (size_t i = 0; i < size; i++) {
664
165k
      info[i].bits = depths[i] == 0 ? 0 : bits[i];
665
165k
      info[i].depth = depths[i];
666
165k
    }
667
    // Estimate data cost.
668
187k
    for (size_t i = 0; i < size; i++) {
669
165k
      cost += histo[i] * info[i].depth;
670
165k
    }
671
22.2k
    return cost;
672
22.2k
  }
673
213k
  JXL_ASSIGN_OR_RETURN(
674
213k
      ANSEncodingHistogram normalized,
675
213k
      ANSEncodingHistogram::ComputeBest(histogram, ans_histogram_strategy));
676
213k
  AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE];
677
213k
  JXL_RETURN_IF_ERROR(
678
213k
      InitAliasTable(normalized.Counts(), ANS_LOG_TAB_SIZE, log_alpha_size, a));
679
106k
  normalized.ANSBuildInfoTable(a, log_alpha_size, info);
680
106k
  if (writer != nullptr) {
681
    // size_t start = writer->BitsWritten();
682
102k
    JXL_RETURN_IF_ERROR(normalized.Encode(writer));
683
    // return writer->BitsWritten() - start;
684
102k
  }
685
106k
  return static_cast<size_t>(ceilf(normalized.Cost()));
686
106k
}
687
688
namespace {
689
690
// Reconstructs a histogram from already built symbol info. For prefix codes a
// symbol of depth d > 0 gets the count 2^(PREFIX_MAX_BITS - d) and unused
// symbols get 0; for ANS the stored frequency is used. The counts vector is
// sized up to a multiple of Histogram::kRounding.
Histogram HistogramFromSymbolInfo(
    const std::vector<ANSEncSymbolInfo>& encoding_info, bool use_prefix_code) {
  const size_t num_symbols = encoding_info.size();
  Histogram histo;
  histo.counts.resize(DivCeil(encoding_info.size(), Histogram::kRounding) *
                      Histogram::kRounding);
  histo.total_count = 0;
  for (size_t sym = 0; sym < num_symbols; ++sym) {
    const ANSEncSymbolInfo& info = encoding_info[sym];
    int count;
    if (!use_prefix_code) {
      count = info.freq_;
    } else if (info.depth) {
      count = 1u << (PREFIX_MAX_BITS - info.depth);
    } else {
      count = 0;
    }
    histo.counts[sym] = count;
    histo.total_count += count;
  }
  return histo;
}
706
707
}  // namespace
708
709
// Chooses `log_alpha_size` and one HybridUintConfig per clustered histogram.
//
// With a fixed uint method (neither kBest nor kFast), in streaming mode, or in
// fuzzer-friendly mode, the defaults are kept and the function returns early.
// Otherwise each histogram tries a list of candidate configs: the tokens that
// map to the histogram are re-tokenized with the candidate, and the candidate
// with the lowest estimated cost (population cost + extra bits + signaling
// cost of the config) wins. The winning re-tokenized histogram replaces
// `clustered_histograms[h]`. LZ77 length tokens are then added to their
// histograms using `lz77.length_uint_config`, and for ANS `log_alpha_size` is
// shrunk to the smallest value (at least 5) that fits the largest token.
//
// `tokens` are the token streams whose contexts index `context_map`.
// Returns an error if the resulting alphabet does not fit the maximum allowed
// `log_alpha_size`, or if a helper fails.
Status EntropyEncodingData::ChooseUintConfigs(
    JxlMemoryManager* memory_manager, const HistogramParams& params,
    const std::vector<std::vector<Token>>& tokens,
    std::vector<Histogram>& clustered_histograms) {
  // Set sane default `log_alpha_size`.
  if (use_prefix_code) {
    log_alpha_size = PREFIX_MAX_BITS;
  } else if (params.streaming_mode) {
    // TODO(szabadka) Figure out if we can use lower values here.
    log_alpha_size = 8;
  } else if (lz77.enabled) {
    log_alpha_size = 8;
  } else {
    log_alpha_size = 7;
  }

  if (ans_fuzzer_friendly_) {
    uint_config.assign(1, HybridUintConfig(7, 0, 0));
    return true;
  }

  uint_config.assign(clustered_histograms.size(), params.UintConfig());
  // If the uint config is fixed, just use it.
  if (params.uint_method != HistogramParams::HybridUintMethod::kBest &&
      params.uint_method != HistogramParams::HybridUintMethod::kFast) {
    return true;
  }
  // Even if the uint config is adaptive, just stick with the default in
  // streaming mode.
  if (params.streaming_mode) {
    return true;
  }

  // Brute-force method that tries a few options.
  std::vector<HybridUintConfig> configs;
  if (params.uint_method == HistogramParams::HybridUintMethod::kBest) {
    configs = {
        HybridUintConfig(4, 2, 0),  // default
        HybridUintConfig(4, 1, 0),  // less precise
        HybridUintConfig(4, 2, 1),  // add sign
        HybridUintConfig(4, 2, 2),  // add sign+parity
        HybridUintConfig(4, 1, 2),  // add parity but less msb
        // Same as above, but more direct coding.
        HybridUintConfig(5, 2, 0), HybridUintConfig(5, 1, 0),
        HybridUintConfig(5, 2, 1), HybridUintConfig(5, 2, 2),
        HybridUintConfig(5, 1, 2),
        // Same as above, but less direct coding.
        HybridUintConfig(3, 2, 0), HybridUintConfig(3, 1, 0),
        HybridUintConfig(3, 2, 1), HybridUintConfig(3, 1, 2),
        // For near-lossless.
        HybridUintConfig(4, 1, 3), HybridUintConfig(5, 1, 4),
        HybridUintConfig(5, 2, 3), HybridUintConfig(6, 1, 5),
        HybridUintConfig(6, 2, 4), HybridUintConfig(6, 0, 0),
        // Other
        HybridUintConfig(0, 0, 0),   // varlenuint
        HybridUintConfig(2, 0, 1),   // works well for ctx map
        HybridUintConfig(7, 0, 0),   // direct coding
        HybridUintConfig(8, 0, 0),   // direct coding
        HybridUintConfig(9, 0, 0),   // direct coding
        HybridUintConfig(10, 0, 0),  // direct coding
        HybridUintConfig(11, 0, 0),  // direct coding
        HybridUintConfig(12, 0, 0),  // direct coding
    };
  } else {
    JXL_DASSERT(params.uint_method == HistogramParams::HybridUintMethod::kFast);
    configs = {
        HybridUintConfig(4, 2, 0),  // default
        HybridUintConfig(4, 1, 2),  // add parity but less msb
        HybridUintConfig(0, 0, 0),  // smallest histograms
        HybridUintConfig(2, 0, 1),  // works well for ctx map
    };
  }

  // Token values are bucketed per histogram. Buckets [0, num_histo) hold
  // regular tokens, buckets [num_histo, 2 * num_histo) hold LZ77 length
  // tokens of the same histogram.
  size_t num_histo = clustered_histograms.size();
  std::vector<size_t> histo_volume(2 * num_histo);
  std::vector<size_t> histo_offset(2 * num_histo + 1);
  std::vector<uint32_t> max_value_per_histo(2 * num_histo);

  // TODO(veluca): do not ignore lz77 commands.

  // First pass: count the tokens of every bucket.
  for (const auto& stream : tokens) {
    for (const auto& token : stream) {
      size_t histo = context_map[token.context];
      histo_volume[histo + (token.is_lz77_length ? num_histo : 0)]++;
    }
  }
  // Prefix sums give the start of every bucket in `transposed`.
  size_t max_histo_volume = 0;
  for (size_t h = 0; h < 2 * num_histo; ++h) {
    max_histo_volume = std::max(max_histo_volume, histo_volume[h]);
    histo_offset[h + 1] = histo_offset[h] + histo_volume[h];
  }

  // Second pass: scatter the token values into their buckets. The buffer is
  // over-allocated by one vector width.
  const size_t max_vec_size = MaxVectorSize();
  std::vector<uint32_t> transposed(histo_offset[num_histo * 2] + max_vec_size);
  {
    std::vector<size_t> next_offset = histo_offset;  // copy
    for (const auto& stream : tokens) {
      for (const auto& token : stream) {
        size_t histo =
            context_map[token.context] + (token.is_lz77_length ? num_histo : 0);
        transposed[next_offset[histo]++] = token.value;
      }
    }
  }
  for (size_t h = 0; h < 2 * num_histo; ++h) {
    max_value_per_histo[h] =
        MaxValue(transposed.data() + histo_offset[h], histo_volume[h]);
  }

  // Wider histograms are assigned max cost in PopulationCost anyway
  // and therefore will not be used
  size_t max_alpha = ANS_MAX_ALPHABET_SIZE;

  // Scratch buffer receiving the re-tokenized values of one bucket.
  JXL_ASSIGN_OR_RETURN(
      AlignedMemory tmp,
      AlignedMemory::Create(memory_manager, (max_histo_volume + max_vec_size) *
                                                sizeof(uint32_t)));
  for (size_t h = 0; h < num_histo; h++) {
    float best_cost = std::numeric_limits<float>::max();
    for (HybridUintConfig cfg : configs) {
      uint32_t max_v = max_value_per_histo[h];
      size_t capacity;
      {
        // The token of the largest value bounds the alphabet for this config.
        uint32_t tok, nbits, bits;
        cfg.Encode(max_v, &tok, &nbits, &bits);
        tok |= cfg.LsbMask();
        if (tok >= max_alpha || (lz77.enabled && tok >= lz77.min_symbol)) {
          continue;  // Not valid config for this context
        }
        capacity = tok + 1;
      }

      Histogram histo;
      histo.EnsureCapacity(capacity);
      size_t len = histo_volume[h];
      uint32_t* data = transposed.data() + histo_offset[h];
      size_t extra_bits = EstimateTokenCost(data, len, cfg, tmp);
      uint32_t* tmp_tokens = tmp.address<uint32_t>();
      for (size_t i = 0; i < len; ++i) {
        histo.FastAdd(tmp_tokens[i]);
      }
      histo.Condition();
      JXL_ASSIGN_OR_RETURN(float cost, histo.ANSPopulationCost());
      cost += extra_bits;
      // Add signaling cost of the hybriduintconfig itself.
      cost += CeilLog2Nonzero(cfg.split_exponent + 1);
      cost += CeilLog2Nonzero(cfg.split_exponent - cfg.msb_in_token + 1);
      if (cost < best_cost) {
        uint_config[h] = cfg;
        best_cost = cost;
        clustered_histograms[h].swap(histo);
      }
    }
  }

  // Add the LZ77 length tokens and find the largest token overall.
  size_t max_tok = 0;
  for (size_t h = 0; h < num_histo; ++h) {
    Histogram& histo = clustered_histograms[h];
    max_tok = std::max(max_tok, histo.MaxSymbol());
    size_t len = histo_volume[num_histo + h];
    if (len == 0) continue;  // E.g. when lz77 not enabled
    size_t max_histo_tok = max_value_per_histo[num_histo + h];
    uint32_t tok, nbits, bits;
    lz77.length_uint_config.Encode(max_histo_tok, &tok, &nbits, &bits);
    tok |= lz77.length_uint_config.LsbMask();
    tok += lz77.min_symbol;
    histo.EnsureCapacity(tok + 1);
    uint32_t* data = transposed.data() + histo_offset[num_histo + h];
    // Only the re-tokenized values left in `tmp` are needed here.
    uint32_t unused =
        EstimateTokenCost(data, len, lz77.length_uint_config, tmp);
    (void)unused;
    uint32_t* tmp_tokens = tmp.address<uint32_t>();
    for (size_t i = 0; i < len; ++i) {
      histo.FastAdd(tmp_tokens[i] + lz77.min_symbol);
    }
    histo.Condition();
    max_tok = std::max(max_tok, histo.MaxSymbol());
  }

  // `log_alpha_size - 5` is encoded in the header, so min is 5.
  size_t log_size = 5;
  while (max_tok >= (1u << log_size)) ++log_size;

  size_t max_log_alpha_size = use_prefix_code ? PREFIX_MAX_BITS : 8;
  JXL_ENSURE(log_size <= max_log_alpha_size);

  if (use_prefix_code) {
    log_alpha_size = PREFIX_MAX_BITS;
  } else {
    log_alpha_size = log_size;
  }

  return true;
}
909
910
// NOTE: `layer` is only for clustered_entropy; caller does ReclaimAndCharge.
911
// Returns cost (in bits).
912
StatusOr<size_t> EntropyEncodingData::BuildAndStoreEntropyCodes(
913
    JxlMemoryManager* memory_manager, const HistogramParams& params,
914
    const std::vector<std::vector<Token>>& tokens,
915
    const std::vector<Histogram>& builder, BitWriter* writer, LayerType layer,
916
34.5k
    AuxOut* aux_out) {
917
34.5k
  const size_t prev_histograms = encoding_info.size();
918
34.5k
  std::vector<Histogram> clustered_histograms;
919
34.5k
  for (size_t i = 0; i < prev_histograms; ++i) {
920
0
    clustered_histograms.push_back(
921
0
        HistogramFromSymbolInfo(encoding_info[i], use_prefix_code));
922
0
  }
923
34.5k
  size_t context_offset = context_map.size();
924
34.5k
  context_map.resize(context_offset + builder.size());
925
34.5k
  if (builder.size() > 1) {
926
13.1k
    if (!ans_fuzzer_friendly_) {
927
13.1k
      std::vector<uint32_t> histogram_symbols;
928
13.1k
      JXL_RETURN_IF_ERROR(ClusterHistograms(params, builder, kClustersLimit,
929
13.1k
                                            &clustered_histograms,
930
13.1k
                                            &histogram_symbols));
931
12.4M
      for (size_t c = 0; c < builder.size(); ++c) {
932
12.3M
        context_map[context_offset + c] =
933
12.3M
            static_cast<uint8_t>(histogram_symbols[c]);
934
12.3M
      }
935
13.1k
    } else {
936
0
      JXL_ENSURE(encoding_info.empty());
937
0
      std::fill(context_map.begin(), context_map.end(), 0);
938
0
      size_t max_symbol = 0;
939
0
      for (const Histogram& h : builder) {
940
0
        max_symbol = std::max(h.counts.size(), max_symbol);
941
0
      }
942
0
      size_t num_symbols = 1 << CeilLog2Nonzero(max_symbol + 1);
943
0
      clustered_histograms.resize(1);
944
0
      clustered_histograms[0].Clear();
945
0
      for (size_t i = 0; i < num_symbols; i++) {
946
0
        clustered_histograms[0].Add(i);
947
0
      }
948
0
    }
949
13.1k
    if (writer != nullptr) {
950
11.6k
      JXL_RETURN_IF_ERROR(EncodeContextMap(
951
11.6k
          context_map, clustered_histograms.size(), writer, layer, aux_out));
952
11.6k
    }
953
21.3k
  } else {
954
21.3k
    JXL_ENSURE(encoding_info.empty());
955
21.3k
    clustered_histograms.push_back(builder[0]);
956
21.3k
  }
957
34.5k
  if (aux_out != nullptr) {
958
0
    for (size_t i = prev_histograms; i < clustered_histograms.size(); ++i) {
959
0
      aux_out->layer(layer).clustered_entropy +=
960
0
          clustered_histograms[i].ShannonEntropy();
961
0
    }
962
0
  }
963
964
34.5k
  JXL_RETURN_IF_ERROR(
965
34.5k
      ChooseUintConfigs(memory_manager, params, tokens, clustered_histograms));
966
967
34.5k
  SizeWriter size_writer;  // Used if writer == nullptr to estimate costs.
968
34.5k
  size_t cost = use_prefix_code ? 1 : 3;
969
970
34.5k
  if (writer) writer->Write(1, TO_JXL_BOOL(use_prefix_code));
971
34.5k
  if (writer == nullptr) {
972
18.7k
    EncodeUintConfigs(uint_config, &size_writer, log_alpha_size);
973
18.7k
  } else {
974
15.7k
    if (!use_prefix_code) writer->Write(2, log_alpha_size - 5);
975
15.7k
    EncodeUintConfigs(uint_config, writer, log_alpha_size);
976
15.7k
  }
977
34.5k
  if (use_prefix_code) {
978
22.5k
    for (const auto& histo : clustered_histograms) {
979
22.5k
      size_t alphabet_size = std::max<size_t>(1, histo.alphabet_size());
980
22.5k
      if (writer) {
981
6.48k
        StoreVarLenUint16(alphabet_size - 1, writer);
982
16.0k
      } else {
983
16.0k
        StoreVarLenUint16(alphabet_size - 1, &size_writer);
984
16.0k
      }
985
22.5k
    }
986
22.3k
  }
987
34.5k
  cost += size_writer.size;
988
163k
  for (size_t c = prev_histograms; c < clustered_histograms.size(); ++c) {
989
129k
    size_t alphabet_size = clustered_histograms[c].alphabet_size();
990
129k
    encoding_info.emplace_back();
991
129k
    encoding_info.back().resize(alphabet_size);
992
129k
    BitWriter* histo_writer = writer;
993
129k
    if (params.streaming_mode) {
994
0
      encoded_histograms.emplace_back(memory_manager);
995
0
      histo_writer = &encoded_histograms.back();
996
0
    }
997
129k
    const auto& body = [&]() -> Status {
998
129k
      JXL_ASSIGN_OR_RETURN(size_t ans_cost,
999
129k
                           BuildAndStoreANSEncodingData(
1000
129k
                               memory_manager, params.ans_histogram_strategy,
1001
129k
                               clustered_histograms[c], histo_writer));
1002
129k
      cost += ans_cost;
1003
129k
      return true;
1004
129k
    };
1005
129k
    if (histo_writer) {
1006
108k
      JXL_RETURN_IF_ERROR(histo_writer->WithMaxBits(
1007
108k
          256 + alphabet_size * 24, layer, aux_out, body,
1008
108k
          /*finished_histogram=*/true));
1009
108k
    } else {
1010
20.1k
      JXL_RETURN_IF_ERROR(body());
1011
20.1k
    }
1012
129k
    if (params.streaming_mode) {
1013
0
      JXL_RETURN_IF_ERROR(writer->AppendUnaligned(*histo_writer));
1014
0
    }
1015
129k
  }
1016
34.5k
  return cost;
1017
34.5k
}
1018
1019
template <typename Writer>
1020
void EncodeUintConfig(const HybridUintConfig uint_config, Writer* writer,
1021
131k
                      size_t log_alpha_size) {
1022
131k
  writer->Write(CeilLog2Nonzero(log_alpha_size + 1),
1023
131k
                uint_config.split_exponent);
1024
131k
  if (uint_config.split_exponent == log_alpha_size) {
1025
89
    return;  // msb/lsb don't matter.
1026
89
  }
1027
131k
  size_t nbits = CeilLog2Nonzero(uint_config.split_exponent + 1);
1028
131k
  writer->Write(nbits, uint_config.msb_in_token);
1029
131k
  nbits = CeilLog2Nonzero(uint_config.split_exponent -
1030
131k
                          uint_config.msb_in_token + 1);
1031
131k
  writer->Write(nbits, uint_config.lsb_in_token);
1032
131k
}
void jxl::EncodeUintConfig<jxl::SizeWriter>(jxl::HybridUintConfig, jxl::SizeWriter*, unsigned long)
Line
Count
Source
1021
21.7k
                      size_t log_alpha_size) {
1022
21.7k
  writer->Write(CeilLog2Nonzero(log_alpha_size + 1),
1023
21.7k
                uint_config.split_exponent);
1024
21.7k
  if (uint_config.split_exponent == log_alpha_size) {
1025
0
    return;  // msb/lsb don't matter.
1026
0
  }
1027
21.7k
  size_t nbits = CeilLog2Nonzero(uint_config.split_exponent + 1);
1028
21.7k
  writer->Write(nbits, uint_config.msb_in_token);
1029
21.7k
  nbits = CeilLog2Nonzero(uint_config.split_exponent -
1030
21.7k
                          uint_config.msb_in_token + 1);
1031
21.7k
  writer->Write(nbits, uint_config.lsb_in_token);
1032
21.7k
}
void jxl::EncodeUintConfig<jxl::BitWriter>(jxl::HybridUintConfig, jxl::BitWriter*, unsigned long)
Line
Count
Source
1021
110k
                      size_t log_alpha_size) {
1022
110k
  writer->Write(CeilLog2Nonzero(log_alpha_size + 1),
1023
110k
                uint_config.split_exponent);
1024
110k
  if (uint_config.split_exponent == log_alpha_size) {
1025
89
    return;  // msb/lsb don't matter.
1026
89
  }
1027
109k
  size_t nbits = CeilLog2Nonzero(uint_config.split_exponent + 1);
1028
109k
  writer->Write(nbits, uint_config.msb_in_token);
1029
109k
  nbits = CeilLog2Nonzero(uint_config.split_exponent -
1030
109k
                          uint_config.msb_in_token + 1);
1031
109k
  writer->Write(nbits, uint_config.lsb_in_token);
1032
109k
}
1033
template <typename Writer>
1034
void EncodeUintConfigs(const std::vector<HybridUintConfig>& uint_config,
1035
34.5k
                       Writer* writer, size_t log_alpha_size) {
1036
  // TODO(veluca): RLE?
1037
129k
  for (const auto& cfg : uint_config) {
1038
129k
    EncodeUintConfig(cfg, writer, log_alpha_size);
1039
129k
  }
1040
34.5k
}
void jxl::EncodeUintConfigs<jxl::BitWriter>(std::__1::vector<jxl::HybridUintConfig, std::__1::allocator<jxl::HybridUintConfig> > const&, jxl::BitWriter*, unsigned long)
Line
Count
Source
1035
15.7k
                       Writer* writer, size_t log_alpha_size) {
1036
  // TODO(veluca): RLE?
1037
108k
  for (const auto& cfg : uint_config) {
1038
108k
    EncodeUintConfig(cfg, writer, log_alpha_size);
1039
108k
  }
1040
15.7k
}
void jxl::EncodeUintConfigs<jxl::SizeWriter>(std::__1::vector<jxl::HybridUintConfig, std::__1::allocator<jxl::HybridUintConfig> > const&, jxl::SizeWriter*, unsigned long)
Line
Count
Source
1035
18.7k
                       Writer* writer, size_t log_alpha_size) {
1036
  // TODO(veluca): RLE?
1037
20.1k
  for (const auto& cfg : uint_config) {
1038
20.1k
    EncodeUintConfig(cfg, writer, log_alpha_size);
1039
20.1k
  }
1040
18.7k
}
1041
template void EncodeUintConfigs(const std::vector<HybridUintConfig>&,
1042
                                BitWriter*, size_t);
1043
1044
// Writes the already-built entropy-code description in `codes` to `writer`.
//
// Emitted in this order, all inside one WithMaxBits allotment:
//   1. the LZ77 parameter bundle, plus its length uint config when LZ77 is
//      enabled;
//   2. the context map;
//   3. one bit selecting prefix coding, and for the non-prefix case a 2-bit
//      field holding log_alpha_size - 5;
//   4. the per-histogram uint configs;
//   5. for prefix coding, each alphabet size minus one;
//   6. every pre-encoded histogram in `codes.encoded_histograms`.
//
// Returns the first error reported by any step, or the status of WithMaxBits.
Status EncodeHistograms(const EntropyEncodingData& codes, BitWriter* writer,
                        LayerType layer, AuxOut* aux_out) {
  const auto emit = [&]() -> Status {
    JXL_RETURN_IF_ERROR(Bundle::Write(codes.lz77, writer, layer, aux_out));
    if (codes.lz77.enabled) {
      EncodeUintConfig(codes.lz77.length_uint_config, writer,
                       /*log_alpha_size=*/8);
    }
    JXL_RETURN_IF_ERROR(EncodeContextMap(codes.context_map,
                                         codes.encoding_info.size(), writer,
                                         layer, aux_out));

    const bool prefix = codes.use_prefix_code;
    writer->Write(1, TO_JXL_BOOL(codes.use_prefix_code));
    size_t log_alpha_size = 8;  // streaming_mode
    if (prefix) {
      log_alpha_size = PREFIX_MAX_BITS;
    } else {
      writer->Write(2, log_alpha_size - 5);
    }
    EncodeUintConfigs(codes.uint_config, writer, log_alpha_size);

    if (prefix) {
      for (const auto& info : codes.encoding_info) {
        StoreVarLenUint16(info.size() - 1, writer);
      }
    }
    for (const auto& histo_writer : codes.encoded_histograms) {
      JXL_RETURN_IF_ERROR(writer->AppendUnaligned(histo_writer));
    }
    return true;
  };
  return writer->WithMaxBits(128 + kClustersLimit * 136, layer, aux_out, emit,
                             /*finished_histogram=*/true);
}
1078
1079
// Builds the per-context symbol histograms for `tokens`, decides between
// prefix and ANS coding, and hands everything to
// codes->BuildAndStoreEntropyCodes.
//
// Parameters:
//   memory_manager  forwarded to the histogram writers / entropy-code builder.
//   params          histogram-building options.
//   num_contexts    number of contexts used by `tokens`; one extra context is
//                   added internally when LZ77 is enabled.
//   tokens          token streams; replaced by the LZ77-transformed streams
//                   when ApplyLZ77 returns a non-empty result.
//   codes           receives the LZ77 settings, the prefix/ANS choice and the
//                   entropy codes.
//   writer          destination bit stream; may be null, in which case the
//                   LZ77 header is only sized (not written) and added to the
//                   returned cost.
//   layer, aux_out  accounting; `aux_out` may be null.
//
// Returns the accumulated bit cost, or the first error encountered.
StatusOr<size_t> BuildAndEncodeHistograms(
    JxlMemoryManager* memory_manager, const HistogramParams& params,
    size_t num_contexts, std::vector<std::vector<Token>>& tokens,
    EntropyEncodingData* codes, BitWriter* writer, LayerType layer,
    AuxOut* aux_out) {
  // TODO(Ivan): presumably not needed - default
  // if (params.initialize_global_state) codes->lz77.enabled = false;
  codes->lz77.nonserialized_distance_context = num_contexts;
  codes->lz77.min_symbol = params.force_huffman ? 512 : 224;
  std::vector<std::vector<Token>> tokens_lz77 =
      ApplyLZ77(params, num_contexts, tokens, codes->lz77);
  if (!tokens_lz77.empty()) {
    codes->lz77.enabled = true;
  }
  if (ans_fuzzer_friendly_) {
    codes->lz77.length_uint_config = HybridUintConfig(10, 0, 0);
    codes->lz77.min_symbol = 2048;
  }

  size_t cost = 0;
  // Computed from the caller-supplied context count, i.e. before the extra
  // LZ77 distance context is added below.
  const size_t max_contexts = std::min(num_contexts, kClustersLimit);

  const auto build_and_encode = [&]() -> Status {
    // LZ77 parameter bundle: written for real, or only sized.
    if (writer != nullptr) {
      JXL_RETURN_IF_ERROR(Bundle::Write(codes->lz77, writer, layer, aux_out));
    } else {
      size_t ext_bits;
      size_t bundle_bits;
      JXL_RETURN_IF_ERROR(
          Bundle::CanEncode(codes->lz77, &ext_bits, &bundle_bits));
      cost += bundle_bits;
    }

    if (codes->lz77.enabled) {
      if (writer != nullptr) {
        const size_t bits_before = writer->BitsWritten();
        EncodeUintConfig(codes->lz77.length_uint_config, writer,
                         /*log_alpha_size=*/8);
        cost += writer->BitsWritten() - bits_before;
      } else {
        SizeWriter size_writer;
        EncodeUintConfig(codes->lz77.length_uint_config, &size_writer,
                         /*log_alpha_size=*/8);
        cost += size_writer.size;
      }
      // LZ77 adds one context and replaces the token streams.
      num_contexts += 1;
      tokens = std::move(tokens_lz77);
    }

    // Build histograms.
    std::vector<Histogram> builder(num_contexts);
    HybridUintConfig uint_config = params.UintConfig();
    if (ans_fuzzer_friendly_) {
      uint_config = HybridUintConfig(10, 0, 0);
    }

    size_t total_tokens = 0;
    for (const auto& stream : tokens) {
      total_tokens += stream.size();
    }

    if (codes->lz77.enabled) {
      // LZ77 length tokens use their own uint config and are shifted by
      // min_symbol.
      for (const auto& stream : tokens) {
        for (const auto& token : stream) {
          uint32_t tok, nbits, bits;
          HybridUintConfig& cfg = token.is_lz77_length
                                      ? codes->lz77.length_uint_config
                                      : uint_config;
          cfg.Encode(token.value, &tok, &nbits, &bits);
          tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
          JXL_DASSERT(token.context < num_contexts);
          builder[token.context].Add(tok);
        }
      }
    } else if (num_contexts == 1) {
      // Single context: everything lands in histogram 0.
      for (const auto& stream : tokens) {
        for (const auto& token : stream) {
          uint32_t tok, nbits, bits;
          uint_config.Encode(token.value, &tok, &nbits, &bits);
          builder[0].Add(tok);
        }
      }
    } else {
      for (const auto& stream : tokens) {
        for (const auto& token : stream) {
          uint32_t tok, nbits, bits;
          uint_config.Encode(token.value, &tok, &nbits, &bits);
          JXL_DASSERT(token.context < num_contexts);
          builder[token.context].Add(tok);
        }
      }
    }

    if (params.add_missing_symbols) {
      for (Histogram& histo : builder) {
        for (int symbol = 0; symbol < ANS_MAX_ALPHABET_SIZE; ++symbol) {
          histo.Add(symbol);
        }
      }
    }

    if (params.initialize_global_state) {
      bool use_prefix_code =
          params.force_huffman || total_tokens < 100 ||
          params.clustering == HistogramParams::ClusteringType::kFastest ||
          ans_fuzzer_friendly_;
      if (!use_prefix_code) {
        // Fall back to prefix codes when no histogram carries measurable
        // entropy. Every histogram is visited; there is no early exit.
        size_t num_with_entropy = 0;
        for (size_t i = 0; i < num_contexts; i++) {
          if (builder[i].ShannonEntropy() >= 1e-5) {
            ++num_with_entropy;
          }
        }
        if (num_with_entropy == 0) {
          use_prefix_code = true;
        }
      }
      codes->use_prefix_code = use_prefix_code;
    }

    if (params.add_fixed_histograms) {
      // TODO(szabadka) Add more fixed histograms.
      // TODO(szabadka) Reduce alphabet size by choosing a non-default
      // uint_config.
      const size_t alphabet_size = ANS_MAX_ALPHABET_SIZE;
      codes->log_alpha_size = 8;
      JXL_ENSURE(alphabet_size == 1u << codes->log_alpha_size);
      static_assert(ANS_MAX_ALPHABET_SIZE <= ANS_TAB_SIZE,
                    "Alphabet does not fit table");
      codes->encoding_info.emplace_back();
      codes->encoding_info.back().resize(alphabet_size);
      codes->encoded_histograms.emplace_back(memory_manager);
      BitWriter* histo_writer = &codes->encoded_histograms.back();
      const auto store_flat = [&]() -> Status {
        JXL_ASSIGN_OR_RETURN(
            size_t ans_cost,
            codes->BuildAndStoreANSEncodingData(
                memory_manager, params.ans_histogram_strategy,
                Histogram::Flat(alphabet_size, ANS_TAB_SIZE), histo_writer));
        (void)ans_cost;
        return true;
      };
      JXL_RETURN_IF_ERROR(histo_writer->WithMaxBits(
          256 + alphabet_size * 24, LayerType::Header, nullptr, store_flat));
    }

    // Encode histograms.
    JXL_ASSIGN_OR_RETURN(
        size_t entropy_bits,
        codes->BuildAndStoreEntropyCodes(memory_manager, params, tokens,
                                         builder, writer, layer, aux_out));
    cost += entropy_bits;
    return true;
  };

  if (writer != nullptr) {
    const size_t max_bits = 128 + num_contexts * 40 + max_contexts * 96;
    JXL_RETURN_IF_ERROR(writer->WithMaxBits(max_bits, layer, aux_out,
                                            build_and_encode,
                                            /*finished_histogram=*/true));
  } else {
    JXL_RETURN_IF_ERROR(build_and_encode());
  }

  if (aux_out != nullptr) {
    aux_out->layer(layer).num_clustered_histograms +=
        codes->encoding_info.size();
  }
  return cost;
}
1233
1234
size_t WriteTokens(const std::vector<Token>& tokens,
1235
                   const EntropyEncodingData& codes, size_t context_offset,
1236
21.0k
                   BitWriter* writer) {
1237
21.0k
  size_t num_extra_bits = 0;
1238
21.0k
  if (codes.use_prefix_code) {
1239
273k
    for (const auto& token : tokens) {
1240
273k
      uint32_t tok, nbits, bits;
1241
273k
      size_t histo = codes.context_map[context_offset + token.context];
1242
273k
      (token.is_lz77_length ? codes.lz77.length_uint_config
1243
273k
                            : codes.uint_config[histo])
1244
273k
          .Encode(token.value, &tok, &nbits, &bits);
1245
273k
      tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
1246
      // Combine two calls to the BitWriter. Equivalent to:
1247
      // writer->Write(codes.encoding_info[histo][tok].depth,
1248
      //               codes.encoding_info[histo][tok].bits);
1249
      // writer->Write(nbits, bits);
1250
273k
      uint64_t data = codes.encoding_info[histo][tok].bits;
1251
273k
      data |= static_cast<uint64_t>(bits)
1252
273k
              << codes.encoding_info[histo][tok].depth;
1253
273k
      writer->Write(codes.encoding_info[histo][tok].depth + nbits, data);
1254
273k
      num_extra_bits += nbits;
1255
273k
    }
1256
7.08k
    return num_extra_bits;
1257
7.08k
  }
1258
13.9k
  std::vector<uint64_t> out;
1259
13.9k
  std::vector<uint8_t> out_nbits;
1260
13.9k
  out.reserve(tokens.size());
1261
13.9k
  out_nbits.reserve(tokens.size());
1262
13.9k
  uint64_t allbits = 0;
1263
13.9k
  size_t numallbits = 0;
1264
  // Writes in *reversed* order.
1265
396M
  auto addbits = [&](size_t bits, size_t nbits) {
1266
396M
    if (JXL_UNLIKELY(nbits)) {
1267
36.8M
      JXL_DASSERT(bits >> nbits == 0);
1268
36.8M
      if (JXL_UNLIKELY(numallbits + nbits > BitWriter::kMaxBitsPerCall)) {
1269
7.62M
        out.push_back(allbits);
1270
7.62M
        out_nbits.push_back(numallbits);
1271
7.62M
        numallbits = allbits = 0;
1272
7.62M
      }
1273
36.8M
      allbits <<= nbits;
1274
36.8M
      allbits |= bits;
1275
36.8M
      numallbits += nbits;
1276
36.8M
    }
1277
396M
  };
1278
13.9k
  const int end = tokens.size();
1279
13.9k
  ANSCoder ans;
1280
13.9k
  if (codes.lz77.enabled || codes.context_map.size() > 1) {
1281
195M
    for (int i = end - 1; i >= 0; --i) {
1282
195M
      const Token token = tokens[i];
1283
195M
      const uint8_t histo = codes.context_map[context_offset + token.context];
1284
195M
      uint32_t tok, nbits, bits;
1285
195M
      (token.is_lz77_length ? codes.lz77.length_uint_config
1286
195M
                            : codes.uint_config[histo])
1287
195M
          .Encode(tokens[i].value, &tok, &nbits, &bits);
1288
195M
      tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
1289
195M
      const ANSEncSymbolInfo& info = codes.encoding_info[histo][tok];
1290
195M
      JXL_DASSERT(info.freq_ > 0);
1291
      // Extra bits first as this is reversed.
1292
195M
      addbits(bits, nbits);
1293
195M
      num_extra_bits += nbits;
1294
195M
      uint8_t ans_nbits = 0;
1295
195M
      uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits);
1296
195M
      addbits(ans_bits, ans_nbits);
1297
195M
    }
1298
13.2k
  } else {
1299
2.88M
    for (int i = end - 1; i >= 0; --i) {
1300
2.88M
      uint32_t tok, nbits, bits;
1301
2.88M
      codes.uint_config[0].Encode(tokens[i].value, &tok, &nbits, &bits);
1302
2.88M
      const ANSEncSymbolInfo& info = codes.encoding_info[0][tok];
1303
      // Extra bits first as this is reversed.
1304
2.88M
      addbits(bits, nbits);
1305
2.88M
      num_extra_bits += nbits;
1306
2.88M
      uint8_t ans_nbits = 0;
1307
2.88M
      uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits);
1308
2.88M
      addbits(ans_bits, ans_nbits);
1309
2.88M
    }
1310
645
  }
1311
13.9k
  const uint32_t state = ans.GetState();
1312
13.9k
  writer->Write(32, state);
1313
13.9k
  writer->Write(numallbits, allbits);
1314
7.63M
  for (int i = out.size(); i > 0; --i) {
1315
7.62M
    writer->Write(out_nbits[i - 1], out[i - 1]);
1316
7.62M
  }
1317
13.9k
  return num_extra_bits;
1318
13.9k
}
1319
1320
// Convenience overload: reserves a bit budget on `writer`, entropy-codes
// `tokens` via the size_t-returning WriteTokens overload, and, when `aux_out`
// is non-null, adds the number of extra bits to that layer's statistics.
Status WriteTokens(const std::vector<Token>& tokens,
                   const EntropyEncodingData& codes, size_t context_offset,
                   BitWriter* writer, LayerType layer, AuxOut* aux_out) {
  // Theoretically, we could have 15 prefix code bits + 31 extra bits.
  const size_t max_bits = 46 * tokens.size() + 32 * 1024 * 4;
  const auto write_all = [&] {
    const size_t extra = WriteTokens(tokens, codes, context_offset, writer);
    if (aux_out != nullptr) {
      aux_out->layer(layer).extra_bits += extra;
    }
    return true;
  };
  return writer->WithMaxBits(max_bits, layer, aux_out, write_all);
}
1334
1335
0
// Sets the file-local `ans_fuzzer_friendly_` flag. Only effective when
// JXL_IS_DEBUG_BUILD is set; otherwise the call is a no-op.
void SetANSFuzzerFriendly(bool ans_fuzzer_friendly) {
#if JXL_IS_DEBUG_BUILD
  // Guard against accidental / malicious changes.
  ans_fuzzer_friendly_ = ans_fuzzer_friendly;
#else
  (void)ans_fuzzer_friendly;
#endif
}
1340
1341
// Derives a HistogramParams from the compression settings.
//
// `streaming_mode` is copied through. The clustering type, ANS histogram
// strategy, LZ77 method and hybrid-uint method are chosen from
// `cparams.speed_tier` and `cparams.modular_mode`, and from whether the first
// entry of `extra_dc_precision` is non-zero. `cparams.decoding_speed_tier >= 2`
// limits the number of histograms to 12. With the Zero predictor in modular
// mode, the LZ77 method is re-selected from the speed tier alone.
HistogramParams HistogramParams::ForModular(
    const CompressParams& cparams,
    const std::vector<uint8_t>& extra_dc_precision, bool streaming_mode) {
  const auto tier = cparams.speed_tier;

  HistogramParams params;
  params.streaming_mode = streaming_mode;

  if (tier > SpeedTier::kKitten) {
    params.clustering = ClusteringType::kFast;

    if (tier > SpeedTier::kThunder) {
      params.ans_histogram_strategy = ANSHistogramStrategy::kFast;
    } else {
      params.ans_histogram_strategy = ANSHistogramStrategy::kApproximate;
    }

    if (cparams.modular_mode && tier <= SpeedTier::kHare) {
      params.lz77_method = LZ77Method::kRLE;
    } else {
      params.lz77_method = LZ77Method::kNone;
    }

    // Near-lossless DC, as well as modular mode, require choosing hybrid uint
    // more carefully.
    const bool has_extra_dc_precision =
        !extra_dc_precision.empty() && extra_dc_precision[0] != 0;
    const bool careful_modular =
        cparams.modular_mode && tier < SpeedTier::kCheetah;
    params.uint_method = (has_extra_dc_precision || careful_modular)
                             ? HybridUintMethod::kFast
                             : HybridUintMethod::kNone;
  } else if (tier <= SpeedTier::kTortoise) {
    params.lz77_method = LZ77Method::kOptimal;
  } else {
    params.lz77_method = LZ77Method::kLZ77;
  }

  if (cparams.decoding_speed_tier >= 2) {
    params.max_histograms = 12;
  }

  // No predictor requires LZ77 to compress residuals.
  // Effort 3 and lower have forced predictors, so kNone is set.
  if (cparams.options.predictor == Predictor::Zero && cparams.modular_mode) {
    if (tier >= SpeedTier::kFalcon) {
      params.lz77_method = LZ77Method::kNone;
    } else if (tier >= SpeedTier::kHare) {
      params.lz77_method = LZ77Method::kRLE;
    } else if (tier >= SpeedTier::kKitten) {
      params.lz77_method = LZ77Method::kLZ77;
    } else {
      params.lz77_method = LZ77Method::kOptimal;
    }
  }
  return params;
}
1385
}  // namespace jxl