Coverage Report

Created: 2023-08-28 07:24

/src/libjxl/lib/jxl/enc_ans.cc
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#include "lib/jxl/enc_ans.h"
7
8
#include <stdint.h>
9
10
#include <algorithm>
11
#include <array>
12
#include <cmath>
13
#include <limits>
14
#include <numeric>
15
#include <type_traits>
16
#include <unordered_map>
17
#include <utility>
18
#include <vector>
19
20
#include "lib/jxl/ans_common.h"
21
#include "lib/jxl/base/bits.h"
22
#include "lib/jxl/dec_ans.h"
23
#include "lib/jxl/enc_aux_out.h"
24
#include "lib/jxl/enc_cluster.h"
25
#include "lib/jxl/enc_context_map.h"
26
#include "lib/jxl/enc_fields.h"
27
#include "lib/jxl/enc_huffman.h"
28
#include "lib/jxl/fast_math-inl.h"
29
#include "lib/jxl/fields.h"
30
31
namespace jxl {
32
33
namespace {
34
35
#if !JXL_IS_DEBUG_BUILD
36
constexpr
37
#endif
38
    bool ans_fuzzer_friendly_ = false;
39
40
static const int kMaxNumSymbolsForSmallCode = 4;
41
42
void ANSBuildInfoTable(const ANSHistBin* counts, const AliasTable::Entry* table,
43
                       size_t alphabet_size, size_t log_alpha_size,
44
0
                       ANSEncSymbolInfo* info) {
45
0
  size_t log_entry_size = ANS_LOG_TAB_SIZE - log_alpha_size;
46
0
  size_t entry_size_minus_1 = (1 << log_entry_size) - 1;
47
  // create valid alias table for empty streams.
48
0
  for (size_t s = 0; s < std::max<size_t>(1, alphabet_size); ++s) {
49
0
    const ANSHistBin freq = s == alphabet_size ? ANS_TAB_SIZE : counts[s];
50
0
    info[s].freq_ = static_cast<uint16_t>(freq);
51
0
#ifdef USE_MULT_BY_RECIPROCAL
52
0
    if (freq != 0) {
53
0
      info[s].ifreq_ =
54
0
          ((1ull << RECIPROCAL_PRECISION) + info[s].freq_ - 1) / info[s].freq_;
55
0
    } else {
56
0
      info[s].ifreq_ = 1;  // shouldn't matter (symbol shouldn't occur), but...
57
0
    }
58
0
#endif
59
0
    info[s].reverse_map_.resize(freq);
60
0
  }
61
0
  for (int i = 0; i < ANS_TAB_SIZE; i++) {
62
0
    AliasTable::Symbol s =
63
0
        AliasTable::Lookup(table, i, log_entry_size, entry_size_minus_1);
64
0
    info[s.value].reverse_map_[s.offset] = i;
65
0
  }
66
0
}
67
68
float EstimateDataBits(const ANSHistBin* histogram, const ANSHistBin* counts,
69
0
                       size_t len) {
70
0
  float sum = 0.0f;
71
0
  int total_histogram = 0;
72
0
  int total_counts = 0;
73
0
  for (size_t i = 0; i < len; ++i) {
74
0
    total_histogram += histogram[i];
75
0
    total_counts += counts[i];
76
0
    if (histogram[i] > 0) {
77
0
      JXL_ASSERT(counts[i] > 0);
78
      // += histogram[i] * -log(counts[i]/total_counts)
79
0
      sum += histogram[i] *
80
0
             std::max(0.0f, ANS_LOG_TAB_SIZE - FastLog2f(counts[i]));
81
0
    }
82
0
  }
83
0
  if (total_histogram > 0) {
84
    // Used only in assert.
85
0
    (void)total_counts;
86
0
    JXL_ASSERT(total_counts == ANS_TAB_SIZE);
87
0
  }
88
0
  return sum;
89
0
}
90
91
0
float EstimateDataBitsFlat(const ANSHistBin* histogram, size_t len) {
92
0
  const float flat_bits = std::max(FastLog2f(len), 0.0f);
93
0
  float total_histogram = 0;
94
0
  for (size_t i = 0; i < len; ++i) {
95
0
    total_histogram += histogram[i];
96
0
  }
97
0
  return total_histogram * flat_bits;
98
0
}
99
100
// Static Huffman code for encoding logcounts. The last symbol is used as RLE
101
// sequence.
102
static const uint8_t kLogCountBitLengths[ANS_LOG_TAB_SIZE + 2] = {
103
    5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 6, 7, 7,
104
};
105
static const uint8_t kLogCountSymbols[ANS_LOG_TAB_SIZE + 2] = {
106
    17, 11, 15, 3, 9, 7, 4, 2, 5, 6, 0, 33, 1, 65,
107
};
108
109
// Returns the difference between largest count that can be represented and is
110
// smaller than "count" and smallest representable count larger than "count".
111
0
static int SmallestIncrement(uint32_t count, uint32_t shift) {
112
0
  int bits = count == 0 ? -1 : FloorLog2Nonzero(count);
113
0
  int drop_bits = bits - GetPopulationCountPrecision(bits, shift);
114
0
  return drop_bits < 0 ? 1 : (1 << drop_bits);
115
0
}
116
117
template <bool minimize_error_of_sum>
118
bool RebalanceHistogram(const float* targets, int max_symbol, int table_size,
119
0
                        uint32_t shift, int* omit_pos, ANSHistBin* counts) {
120
0
  int sum = 0;
121
0
  float sum_nonrounded = 0.0;
122
0
  int remainder_pos = 0;  // if all of them are handled in first loop
123
0
  int remainder_log = -1;
124
0
  for (int n = 0; n < max_symbol; ++n) {
125
0
    if (targets[n] > 0 && targets[n] < 1.0f) {
126
0
      counts[n] = 1;
127
0
      sum_nonrounded += targets[n];
128
0
      sum += counts[n];
129
0
    }
130
0
  }
131
0
  const float discount_ratio =
132
0
      (table_size - sum) / (table_size - sum_nonrounded);
133
0
  JXL_ASSERT(discount_ratio > 0);
134
0
  JXL_ASSERT(discount_ratio <= 1.0f);
135
  // Invariant for minimize_error_of_sum == true:
136
  // abs(sum - sum_nonrounded)
137
  //   <= SmallestIncrement(max(targets[])) + max_symbol
138
0
  for (int n = 0; n < max_symbol; ++n) {
139
0
    if (targets[n] >= 1.0f) {
140
0
      sum_nonrounded += targets[n];
141
0
      counts[n] =
142
0
          static_cast<ANSHistBin>(targets[n] * discount_ratio);  // truncate
143
0
      if (counts[n] == 0) counts[n] = 1;
144
0
      if (counts[n] == table_size) counts[n] = table_size - 1;
145
      // Round the count to the closest nonzero multiple of SmallestIncrement
146
      // (when minimize_error_of_sum is false) or one of two closest so as to
147
      // keep the sum as close as possible to sum_nonrounded.
148
0
      int inc = SmallestIncrement(counts[n], shift);
149
0
      counts[n] -= counts[n] & (inc - 1);
150
      // TODO(robryk): Should we rescale targets[n]?
151
0
      const float target =
152
0
          minimize_error_of_sum ? (sum_nonrounded - sum) : targets[n];
153
0
      if (counts[n] == 0 ||
154
0
          (target > counts[n] + inc / 2 && counts[n] + inc < table_size)) {
155
0
        counts[n] += inc;
156
0
      }
157
0
      sum += counts[n];
158
0
      const int count_log = FloorLog2Nonzero(static_cast<uint32_t>(counts[n]));
159
0
      if (count_log > remainder_log) {
160
0
        remainder_pos = n;
161
0
        remainder_log = count_log;
162
0
      }
163
0
    }
164
0
  }
165
0
  JXL_ASSERT(remainder_pos != -1);
166
  // NOTE: This is the only place where counts could go negative. We could
167
  // detect that, return false and make ANSHistBin uint32_t.
168
0
  counts[remainder_pos] -= sum - table_size;
169
0
  *omit_pos = remainder_pos;
170
0
  return counts[remainder_pos] > 0;
171
0
}
Unexecuted instantiation: enc_ans.cc:bool jxl::(anonymous namespace)::RebalanceHistogram<false>(float const*, int, int, unsigned int, int*, int*)
Unexecuted instantiation: enc_ans.cc:bool jxl::(anonymous namespace)::RebalanceHistogram<true>(float const*, int, int, unsigned int, int*, int*)
172
173
Status NormalizeCounts(ANSHistBin* counts, int* omit_pos, const int length,
174
                       const int precision_bits, uint32_t shift,
175
0
                       int* num_symbols, int* symbols) {
176
0
  const int32_t table_size = 1 << precision_bits;  // target sum / table size
177
0
  uint64_t total = 0;
178
0
  int max_symbol = 0;
179
0
  int symbol_count = 0;
180
0
  for (int n = 0; n < length; ++n) {
181
0
    total += counts[n];
182
0
    if (counts[n] > 0) {
183
0
      if (symbol_count < kMaxNumSymbolsForSmallCode) {
184
0
        symbols[symbol_count] = n;
185
0
      }
186
0
      ++symbol_count;
187
0
      max_symbol = n + 1;
188
0
    }
189
0
  }
190
0
  *num_symbols = symbol_count;
191
0
  if (symbol_count == 0) {
192
0
    return true;
193
0
  }
194
0
  if (symbol_count == 1) {
195
0
    counts[symbols[0]] = table_size;
196
0
    return true;
197
0
  }
198
0
  if (symbol_count > table_size)
199
0
    return JXL_FAILURE("Too many entries in an ANS histogram");
200
201
0
  const float norm = 1.f * table_size / total;
202
0
  std::vector<float> targets(max_symbol);
203
0
  for (size_t n = 0; n < targets.size(); ++n) {
204
0
    targets[n] = norm * counts[n];
205
0
  }
206
0
  if (!RebalanceHistogram<false>(&targets[0], max_symbol, table_size, shift,
207
0
                                 omit_pos, counts)) {
208
    // Use an alternative rebalancing mechanism if the one above failed
209
    // to create a histogram that is positive wherever the original one was.
210
0
    if (!RebalanceHistogram<true>(&targets[0], max_symbol, table_size, shift,
211
0
                                  omit_pos, counts)) {
212
0
      return JXL_FAILURE("Logic error: couldn't rebalance a histogram");
213
0
    }
214
0
  }
215
0
  return true;
216
0
}
217
218
struct SizeWriter {
219
  size_t size = 0;
220
0
  void Write(size_t num, size_t bits) { size += num; }
221
};
222
223
template <typename Writer>
224
0
void StoreVarLenUint8(size_t n, Writer* writer) {
225
0
  JXL_DASSERT(n <= 255);
226
0
  if (n == 0) {
227
0
    writer->Write(1, 0);
228
0
  } else {
229
0
    writer->Write(1, 1);
230
0
    size_t nbits = FloorLog2Nonzero(n);
231
0
    writer->Write(3, nbits);
232
0
    writer->Write(nbits, n - (1ULL << nbits));
233
0
  }
234
0
}
Unexecuted instantiation: enc_ans.cc:void jxl::(anonymous namespace)::StoreVarLenUint8<jxl::(anonymous namespace)::SizeWriter>(unsigned long, jxl::(anonymous namespace)::SizeWriter*)
Unexecuted instantiation: enc_ans.cc:void jxl::(anonymous namespace)::StoreVarLenUint8<jxl::BitWriter>(unsigned long, jxl::BitWriter*)
235
236
template <typename Writer>
237
0
void StoreVarLenUint16(size_t n, Writer* writer) {
238
0
  JXL_DASSERT(n <= 65535);
239
0
  if (n == 0) {
240
0
    writer->Write(1, 0);
241
0
  } else {
242
0
    writer->Write(1, 1);
243
0
    size_t nbits = FloorLog2Nonzero(n);
244
0
    writer->Write(4, nbits);
245
0
    writer->Write(nbits, n - (1ULL << nbits));
246
0
  }
247
0
}
Unexecuted instantiation: enc_ans.cc:void jxl::(anonymous namespace)::StoreVarLenUint16<jxl::BitWriter>(unsigned long, jxl::BitWriter*)
Unexecuted instantiation: enc_ans.cc:void jxl::(anonymous namespace)::StoreVarLenUint16<jxl::(anonymous namespace)::SizeWriter>(unsigned long, jxl::(anonymous namespace)::SizeWriter*)
248
249
template <typename Writer>
250
bool EncodeCounts(const ANSHistBin* counts, const int alphabet_size,
251
                  const int omit_pos, const int num_symbols, uint32_t shift,
252
0
                  const int* symbols, Writer* writer) {
253
0
  bool ok = true;
254
0
  if (num_symbols <= 2) {
255
    // Small tree marker to encode 1-2 symbols.
256
0
    writer->Write(1, 1);
257
0
    if (num_symbols == 0) {
258
0
      writer->Write(1, 0);
259
0
      StoreVarLenUint8(0, writer);
260
0
    } else {
261
0
      writer->Write(1, num_symbols - 1);
262
0
      for (int i = 0; i < num_symbols; ++i) {
263
0
        StoreVarLenUint8(symbols[i], writer);
264
0
      }
265
0
    }
266
0
    if (num_symbols == 2) {
267
0
      writer->Write(ANS_LOG_TAB_SIZE, counts[symbols[0]]);
268
0
    }
269
0
  } else {
270
    // Mark non-small tree.
271
0
    writer->Write(1, 0);
272
    // Mark non-flat histogram.
273
0
    writer->Write(1, 0);
274
275
    // Precompute sequences for RLE encoding. Contains the number of identical
276
    // values starting at a given index. Only contains the value at the first
277
    // element of the series.
278
0
    std::vector<uint32_t> same(alphabet_size, 0);
279
0
    int last = 0;
280
0
    for (int i = 1; i < alphabet_size; i++) {
281
      // Store the sequence length once different symbol reached, or we're at
282
      // the end, or the length is longer than we can encode, or we are at
283
      // the omit_pos. We don't support including the omit_pos in an RLE
284
      // sequence because this value may use a different amount of log2 bits
285
      // than standard, it is too complex to handle in the decoder.
286
0
      if (counts[i] != counts[last] || i + 1 == alphabet_size ||
287
0
          (i - last) >= 255 || i == omit_pos || i == omit_pos + 1) {
288
0
        same[last] = (i - last);
289
0
        last = i + 1;
290
0
      }
291
0
    }
292
293
0
    int length = 0;
294
0
    std::vector<int> logcounts(alphabet_size);
295
0
    int omit_log = 0;
296
0
    for (int i = 0; i < alphabet_size; ++i) {
297
0
      JXL_ASSERT(counts[i] <= ANS_TAB_SIZE);
298
0
      JXL_ASSERT(counts[i] >= 0);
299
0
      if (i == omit_pos) {
300
0
        length = i + 1;
301
0
      } else if (counts[i] > 0) {
302
0
        logcounts[i] = FloorLog2Nonzero(static_cast<uint32_t>(counts[i])) + 1;
303
0
        length = i + 1;
304
0
        if (i < omit_pos) {
305
0
          omit_log = std::max(omit_log, logcounts[i] + 1);
306
0
        } else {
307
0
          omit_log = std::max(omit_log, logcounts[i]);
308
0
        }
309
0
      }
310
0
    }
311
0
    logcounts[omit_pos] = omit_log;
312
313
    // Elias gamma-like code for shift. Only difference is that if the number
314
    // of bits to be encoded is equal to FloorLog2(ANS_LOG_TAB_SIZE+1), we skip
315
    // the terminating 0 in unary coding.
316
0
    int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1);
317
0
    int log = FloorLog2Nonzero(shift + 1);
318
0
    writer->Write(log, (1 << log) - 1);
319
0
    if (log != upper_bound_log) writer->Write(1, 0);
320
0
    writer->Write(log, ((1 << log) - 1) & (shift + 1));
321
322
    // Since num_symbols >= 3, we know that length >= 3, therefore we encode
323
    // length - 3.
324
0
    if (length - 3 > 255) {
325
      // Pretend that everything is OK, but complain about correctness later.
326
0
      StoreVarLenUint8(255, writer);
327
0
      ok = false;
328
0
    } else {
329
0
      StoreVarLenUint8(length - 3, writer);
330
0
    }
331
332
    // The logcount values are encoded with a static Huffman code.
333
0
    static const size_t kMinReps = 4;
334
0
    size_t rep = ANS_LOG_TAB_SIZE + 1;
335
0
    for (int i = 0; i < length; ++i) {
336
0
      if (i > 0 && same[i - 1] > kMinReps) {
337
        // Encode the RLE symbol and skip the repeated ones.
338
0
        writer->Write(kLogCountBitLengths[rep], kLogCountSymbols[rep]);
339
0
        StoreVarLenUint8(same[i - 1] - kMinReps - 1, writer);
340
0
        i += same[i - 1] - 2;
341
0
        continue;
342
0
      }
343
0
      writer->Write(kLogCountBitLengths[logcounts[i]],
344
0
                    kLogCountSymbols[logcounts[i]]);
345
0
    }
346
0
    for (int i = 0; i < length; ++i) {
347
0
      if (i > 0 && same[i - 1] > kMinReps) {
348
        // Skip symbols encoded by RLE.
349
0
        i += same[i - 1] - 2;
350
0
        continue;
351
0
      }
352
0
      if (logcounts[i] > 1 && i != omit_pos) {
353
0
        int bitcount = GetPopulationCountPrecision(logcounts[i] - 1, shift);
354
0
        int drop_bits = logcounts[i] - 1 - bitcount;
355
0
        JXL_CHECK((counts[i] & ((1 << drop_bits) - 1)) == 0);
356
0
        writer->Write(bitcount, (counts[i] >> drop_bits) - (1 << bitcount));
357
0
      }
358
0
    }
359
0
  }
360
0
  return ok;
361
0
}
Unexecuted instantiation: enc_ans.cc:bool jxl::(anonymous namespace)::EncodeCounts<jxl::(anonymous namespace)::SizeWriter>(int const*, int, int, int, unsigned int, int const*, jxl::(anonymous namespace)::SizeWriter*)
Unexecuted instantiation: enc_ans.cc:bool jxl::(anonymous namespace)::EncodeCounts<jxl::BitWriter>(int const*, int, int, int, unsigned int, int const*, jxl::BitWriter*)
362
363
0
void EncodeFlatHistogram(const int alphabet_size, BitWriter* writer) {
364
  // Mark non-small tree.
365
0
  writer->Write(1, 0);
366
  // Mark uniform histogram.
367
0
  writer->Write(1, 1);
368
0
  JXL_ASSERT(alphabet_size > 0);
369
  // Encode alphabet size.
370
0
  StoreVarLenUint8(alphabet_size - 1, writer);
371
0
}
372
373
float ComputeHistoAndDataCost(const ANSHistBin* histogram, size_t alphabet_size,
374
0
                              uint32_t method) {
375
0
  if (method == 0) {  // Flat code
376
0
    return ANS_LOG_TAB_SIZE + 2 +
377
0
           EstimateDataBitsFlat(histogram, alphabet_size);
378
0
  }
379
  // Non-flat: shift = method-1.
380
0
  uint32_t shift = method - 1;
381
0
  std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size);
382
0
  int omit_pos = 0;
383
0
  int num_symbols;
384
0
  int symbols[kMaxNumSymbolsForSmallCode] = {};
385
0
  JXL_CHECK(NormalizeCounts(counts.data(), &omit_pos, alphabet_size,
386
0
                            ANS_LOG_TAB_SIZE, shift, &num_symbols, symbols));
387
0
  SizeWriter writer;
388
  // Ignore the correctness, no real encoding happens at this stage.
389
0
  (void)EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols, shift,
390
0
                     symbols, &writer);
391
0
  return writer.size +
392
0
         EstimateDataBits(histogram, counts.data(), alphabet_size);
393
0
}
394
395
uint32_t ComputeBestMethod(
396
    const ANSHistBin* histogram, size_t alphabet_size, float* cost,
397
0
    HistogramParams::ANSHistogramStrategy ans_histogram_strategy) {
398
0
  size_t method = 0;
399
0
  float fcost = ComputeHistoAndDataCost(histogram, alphabet_size, 0);
400
0
  auto try_shift = [&](size_t shift) {
401
0
    float c = ComputeHistoAndDataCost(histogram, alphabet_size, shift + 1);
402
0
    if (c < fcost) {
403
0
      method = shift + 1;
404
0
      fcost = c;
405
0
    }
406
0
  };
407
0
  switch (ans_histogram_strategy) {
408
0
    case HistogramParams::ANSHistogramStrategy::kPrecise: {
409
0
      for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE; shift++) {
410
0
        try_shift(shift);
411
0
      }
412
0
      break;
413
0
    }
414
0
    case HistogramParams::ANSHistogramStrategy::kApproximate: {
415
0
      for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE; shift += 2) {
416
0
        try_shift(shift);
417
0
      }
418
0
      break;
419
0
    }
420
0
    case HistogramParams::ANSHistogramStrategy::kFast: {
421
0
      try_shift(0);
422
0
      try_shift(ANS_LOG_TAB_SIZE / 2);
423
0
      try_shift(ANS_LOG_TAB_SIZE);
424
0
      break;
425
0
    }
426
0
  };
427
0
  *cost = fcost;
428
0
  return method;
429
0
}
430
431
}  // namespace
432
433
// Returns an estimate of the cost of encoding this histogram and the
434
// corresponding data.
435
size_t BuildAndStoreANSEncodingData(
436
    HistogramParams::ANSHistogramStrategy ans_histogram_strategy,
437
    const ANSHistBin* histogram, size_t alphabet_size, size_t log_alpha_size,
438
0
    bool use_prefix_code, ANSEncSymbolInfo* info, BitWriter* writer) {
439
0
  if (use_prefix_code) {
440
0
    if (alphabet_size <= 1) return 0;
441
0
    std::vector<uint32_t> histo(alphabet_size);
442
0
    for (size_t i = 0; i < alphabet_size; i++) {
443
0
      histo[i] = histogram[i];
444
0
      JXL_CHECK(histogram[i] >= 0);
445
0
    }
446
0
    size_t cost = 0;
447
0
    {
448
0
      std::vector<uint8_t> depths(alphabet_size);
449
0
      std::vector<uint16_t> bits(alphabet_size);
450
0
      if (writer == nullptr) {
451
0
        BitWriter tmp_writer;
452
0
        BitWriter::Allotment allotment(
453
0
            &tmp_writer, 8 * alphabet_size + 8);  // safe upper bound
454
0
        BuildAndStoreHuffmanTree(histo.data(), alphabet_size, depths.data(),
455
0
                                 bits.data(), &tmp_writer);
456
0
        allotment.ReclaimAndCharge(&tmp_writer, 0, /*aux_out=*/nullptr);
457
0
        cost = tmp_writer.BitsWritten();
458
0
      } else {
459
0
        size_t start = writer->BitsWritten();
460
0
        BuildAndStoreHuffmanTree(histo.data(), alphabet_size, depths.data(),
461
0
                                 bits.data(), writer);
462
0
        cost = writer->BitsWritten() - start;
463
0
      }
464
0
      for (size_t i = 0; i < alphabet_size; i++) {
465
0
        info[i].bits = depths[i] == 0 ? 0 : bits[i];
466
0
        info[i].depth = depths[i];
467
0
      }
468
0
    }
469
    // Estimate data cost.
470
0
    for (size_t i = 0; i < alphabet_size; i++) {
471
0
      cost += histogram[i] * info[i].depth;
472
0
    }
473
0
    return cost;
474
0
  }
475
0
  JXL_ASSERT(alphabet_size <= ANS_TAB_SIZE);
476
  // Ensure we ignore trailing zeros in the histogram.
477
0
  if (alphabet_size != 0) {
478
0
    size_t largest_symbol = 0;
479
0
    for (size_t i = 0; i < alphabet_size; i++) {
480
0
      if (histogram[i] != 0) largest_symbol = i;
481
0
    }
482
0
    alphabet_size = largest_symbol + 1;
483
0
  }
484
0
  float cost;
485
0
  uint32_t method = ComputeBestMethod(histogram, alphabet_size, &cost,
486
0
                                      ans_histogram_strategy);
487
0
  JXL_ASSERT(cost >= 0);
488
0
  int num_symbols;
489
0
  int symbols[kMaxNumSymbolsForSmallCode] = {};
490
0
  std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size);
491
0
  if (!counts.empty()) {
492
0
    size_t sum = 0;
493
0
    for (size_t i = 0; i < counts.size(); i++) {
494
0
      sum += counts[i];
495
0
    }
496
0
    if (sum == 0) {
497
0
      counts[0] = ANS_TAB_SIZE;
498
0
    }
499
0
  }
500
0
  if (method == 0) {
501
0
    counts = CreateFlatHistogram(alphabet_size, ANS_TAB_SIZE);
502
0
    AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE];
503
0
    InitAliasTable(counts, ANS_TAB_SIZE, log_alpha_size, a);
504
0
    ANSBuildInfoTable(counts.data(), a, alphabet_size, log_alpha_size, info);
505
0
    if (writer != nullptr) {
506
0
      EncodeFlatHistogram(alphabet_size, writer);
507
0
    }
508
0
    return cost;
509
0
  }
510
0
  int omit_pos = 0;
511
0
  uint32_t shift = method - 1;
512
0
  JXL_CHECK(NormalizeCounts(counts.data(), &omit_pos, alphabet_size,
513
0
                            ANS_LOG_TAB_SIZE, shift, &num_symbols, symbols));
514
0
  AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE];
515
0
  InitAliasTable(counts, ANS_TAB_SIZE, log_alpha_size, a);
516
0
  ANSBuildInfoTable(counts.data(), a, alphabet_size, log_alpha_size, info);
517
0
  if (writer != nullptr) {
518
0
    bool ok = EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols,
519
0
                           shift, symbols, writer);
520
0
    (void)ok;
521
0
    JXL_DASSERT(ok);
522
0
  }
523
0
  return cost;
524
0
}
525
526
0
float ANSPopulationCost(const ANSHistBin* data, size_t alphabet_size) {
527
0
  float c;
528
0
  ComputeBestMethod(data, alphabet_size, &c,
529
0
                    HistogramParams::ANSHistogramStrategy::kFast);
530
0
  return c;
531
0
}
532
533
template <typename Writer>
534
void EncodeUintConfig(const HybridUintConfig uint_config, Writer* writer,
535
0
                      size_t log_alpha_size) {
536
0
  writer->Write(CeilLog2Nonzero(log_alpha_size + 1),
537
0
                uint_config.split_exponent);
538
0
  if (uint_config.split_exponent == log_alpha_size) {
539
0
    return;  // msb/lsb don't matter.
540
0
  }
541
0
  size_t nbits = CeilLog2Nonzero(uint_config.split_exponent + 1);
542
0
  writer->Write(nbits, uint_config.msb_in_token);
543
0
  nbits = CeilLog2Nonzero(uint_config.split_exponent -
544
0
                          uint_config.msb_in_token + 1);
545
0
  writer->Write(nbits, uint_config.lsb_in_token);
546
0
}
Unexecuted instantiation: void jxl::EncodeUintConfig<jxl::BitWriter>(jxl::HybridUintConfig, jxl::BitWriter*, unsigned long)
Unexecuted instantiation: enc_ans.cc:void jxl::EncodeUintConfig<jxl::(anonymous namespace)::SizeWriter>(jxl::HybridUintConfig, jxl::(anonymous namespace)::SizeWriter*, unsigned long)
547
template <typename Writer>
548
void EncodeUintConfigs(const std::vector<HybridUintConfig>& uint_config,
549
0
                       Writer* writer, size_t log_alpha_size) {
550
  // TODO(veluca): RLE?
551
0
  for (size_t i = 0; i < uint_config.size(); i++) {
552
0
    EncodeUintConfig(uint_config[i], writer, log_alpha_size);
553
0
  }
554
0
}
Unexecuted instantiation: void jxl::EncodeUintConfigs<jxl::BitWriter>(std::__1::vector<jxl::HybridUintConfig, std::__1::allocator<jxl::HybridUintConfig> > const&, jxl::BitWriter*, unsigned long)
Unexecuted instantiation: enc_ans.cc:void jxl::EncodeUintConfigs<jxl::(anonymous namespace)::SizeWriter>(std::__1::vector<jxl::HybridUintConfig, std::__1::allocator<jxl::HybridUintConfig> > const&, jxl::(anonymous namespace)::SizeWriter*, unsigned long)
555
template void EncodeUintConfigs(const std::vector<HybridUintConfig>&,
556
                                BitWriter*, size_t);
557
558
namespace {
559
560
void ChooseUintConfigs(const HistogramParams& params,
561
                       const std::vector<std::vector<Token>>& tokens,
562
                       const std::vector<uint8_t>& context_map,
563
                       std::vector<Histogram>* clustered_histograms,
564
0
                       EntropyEncodingData* codes, size_t* log_alpha_size) {
565
0
  codes->uint_config.resize(clustered_histograms->size());
566
567
0
  if (params.uint_method == HistogramParams::HybridUintMethod::kNone) return;
568
0
  if (params.uint_method == HistogramParams::HybridUintMethod::k000) {
569
0
    codes->uint_config.clear();
570
0
    codes->uint_config.resize(clustered_histograms->size(),
571
0
                              HybridUintConfig(0, 0, 0));
572
0
    return;
573
0
  }
574
0
  if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) {
575
0
    codes->uint_config.clear();
576
0
    codes->uint_config.resize(clustered_histograms->size(),
577
0
                              HybridUintConfig(2, 0, 1));
578
0
    return;
579
0
  }
580
581
  // Brute-force method that tries a few options.
582
0
  std::vector<HybridUintConfig> configs;
583
0
  if (params.uint_method == HistogramParams::HybridUintMethod::kBest) {
584
0
    configs = {
585
0
        HybridUintConfig(4, 2, 0),  // default
586
0
        HybridUintConfig(4, 1, 0),  // less precise
587
0
        HybridUintConfig(4, 2, 1),  // add sign
588
0
        HybridUintConfig(4, 2, 2),  // add sign+parity
589
0
        HybridUintConfig(4, 1, 2),  // add parity but less msb
590
        // Same as above, but more direct coding.
591
0
        HybridUintConfig(5, 2, 0), HybridUintConfig(5, 1, 0),
592
0
        HybridUintConfig(5, 2, 1), HybridUintConfig(5, 2, 2),
593
0
        HybridUintConfig(5, 1, 2),
594
        // Same as above, but less direct coding.
595
0
        HybridUintConfig(3, 2, 0), HybridUintConfig(3, 1, 0),
596
0
        HybridUintConfig(3, 2, 1), HybridUintConfig(3, 1, 2),
597
        // For near-lossless.
598
0
        HybridUintConfig(4, 1, 3), HybridUintConfig(5, 1, 4),
599
0
        HybridUintConfig(5, 2, 3), HybridUintConfig(6, 1, 5),
600
0
        HybridUintConfig(6, 2, 4), HybridUintConfig(6, 0, 0),
601
        // Other
602
0
        HybridUintConfig(0, 0, 0),   // varlenuint
603
0
        HybridUintConfig(2, 0, 1),   // works well for ctx map
604
0
        HybridUintConfig(7, 0, 0),   // direct coding
605
0
        HybridUintConfig(8, 0, 0),   // direct coding
606
0
        HybridUintConfig(9, 0, 0),   // direct coding
607
0
        HybridUintConfig(10, 0, 0),  // direct coding
608
0
        HybridUintConfig(11, 0, 0),  // direct coding
609
0
        HybridUintConfig(12, 0, 0),  // direct coding
610
0
    };
611
0
  } else if (params.uint_method == HistogramParams::HybridUintMethod::kFast) {
612
0
    configs = {
613
0
        HybridUintConfig(4, 2, 0),  // default
614
0
        HybridUintConfig(4, 1, 2),  // add parity but less msb
615
0
        HybridUintConfig(0, 0, 0),  // smallest histograms
616
0
        HybridUintConfig(2, 0, 1),  // works well for ctx map
617
0
    };
618
0
  }
619
620
0
  std::vector<float> costs(clustered_histograms->size(),
621
0
                           std::numeric_limits<float>::max());
622
0
  std::vector<uint32_t> extra_bits(clustered_histograms->size());
623
0
  std::vector<uint8_t> is_valid(clustered_histograms->size());
624
0
  size_t max_alpha =
625
0
      codes->use_prefix_code ? PREFIX_MAX_ALPHABET_SIZE : ANS_MAX_ALPHABET_SIZE;
626
0
  for (HybridUintConfig cfg : configs) {
627
0
    std::fill(is_valid.begin(), is_valid.end(), true);
628
0
    std::fill(extra_bits.begin(), extra_bits.end(), 0);
629
630
0
    for (size_t i = 0; i < clustered_histograms->size(); i++) {
631
0
      (*clustered_histograms)[i].Clear();
632
0
    }
633
0
    for (size_t i = 0; i < tokens.size(); ++i) {
634
0
      for (size_t j = 0; j < tokens[i].size(); ++j) {
635
0
        const Token token = tokens[i][j];
636
        // TODO(veluca): do not ignore lz77 commands.
637
0
        if (token.is_lz77_length) continue;
638
0
        size_t histo = context_map[token.context];
639
0
        uint32_t tok, nbits, bits;
640
0
        cfg.Encode(token.value, &tok, &nbits, &bits);
641
0
        if (tok >= max_alpha ||
642
0
            (codes->lz77.enabled && tok >= codes->lz77.min_symbol)) {
643
0
          is_valid[histo] = false;
644
0
          continue;
645
0
        }
646
0
        extra_bits[histo] += nbits;
647
0
        (*clustered_histograms)[histo].Add(tok);
648
0
      }
649
0
    }
650
651
0
    for (size_t i = 0; i < clustered_histograms->size(); i++) {
652
0
      if (!is_valid[i]) continue;
653
0
      float cost = (*clustered_histograms)[i].PopulationCost() + extra_bits[i];
654
      // add signaling cost of the hybriduintconfig itself
655
0
      cost += CeilLog2Nonzero(cfg.split_exponent + 1);
656
0
      cost += CeilLog2Nonzero(cfg.split_exponent - cfg.msb_in_token + 1);
657
0
      if (cost < costs[i]) {
658
0
        codes->uint_config[i] = cfg;
659
0
        costs[i] = cost;
660
0
      }
661
0
    }
662
0
  }
663
664
  // Rebuild histograms.
665
0
  for (size_t i = 0; i < clustered_histograms->size(); i++) {
666
0
    (*clustered_histograms)[i].Clear();
667
0
  }
668
0
  *log_alpha_size = 4;
669
0
  for (size_t i = 0; i < tokens.size(); ++i) {
670
0
    for (size_t j = 0; j < tokens[i].size(); ++j) {
671
0
      const Token token = tokens[i][j];
672
0
      uint32_t tok, nbits, bits;
673
0
      size_t histo = context_map[token.context];
674
0
      (token.is_lz77_length ? codes->lz77.length_uint_config
675
0
                            : codes->uint_config[histo])
676
0
          .Encode(token.value, &tok, &nbits, &bits);
677
0
      tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
678
0
      (*clustered_histograms)[histo].Add(tok);
679
0
      while (tok >= (1u << *log_alpha_size)) (*log_alpha_size)++;
680
0
    }
681
0
  }
682
0
#if JXL_ENABLE_ASSERT
683
0
  size_t max_log_alpha_size = codes->use_prefix_code ? PREFIX_MAX_BITS : 8;
684
0
  JXL_ASSERT(*log_alpha_size <= max_log_alpha_size);
685
0
#endif
686
0
}
687
688
class HistogramBuilder {
689
 public:
690
  explicit HistogramBuilder(const size_t num_contexts)
691
0
      : histograms_(num_contexts) {}
692
693
0
  void VisitSymbol(int symbol, size_t histo_idx) {
694
0
    JXL_DASSERT(histo_idx < histograms_.size());
695
0
    histograms_[histo_idx].Add(symbol);
696
0
  }
697
698
  // NOTE: `layer` is only for clustered_entropy; caller does ReclaimAndCharge.
699
  size_t BuildAndStoreEntropyCodes(
700
      const HistogramParams& params,
701
      const std::vector<std::vector<Token>>& tokens, EntropyEncodingData* codes,
702
      std::vector<uint8_t>* context_map, bool use_prefix_code,
703
0
      BitWriter* writer, size_t layer, AuxOut* aux_out) const {
704
0
    size_t cost = 0;
705
0
    codes->encoding_info.clear();
706
0
    std::vector<Histogram> clustered_histograms(histograms_);
707
0
    context_map->resize(histograms_.size());
708
0
    if (histograms_.size() > 1) {
709
0
      if (!ans_fuzzer_friendly_) {
710
0
        std::vector<uint32_t> histogram_symbols;
711
0
        ClusterHistograms(params, histograms_, kClustersLimit,
712
0
                          &clustered_histograms, &histogram_symbols);
713
0
        for (size_t c = 0; c < histograms_.size(); ++c) {
714
0
          (*context_map)[c] = static_cast<uint8_t>(histogram_symbols[c]);
715
0
        }
716
0
      } else {
717
0
        fill(context_map->begin(), context_map->end(), 0);
718
0
        size_t max_symbol = 0;
719
0
        for (const Histogram& h : histograms_) {
720
0
          max_symbol = std::max(h.data_.size(), max_symbol);
721
0
        }
722
0
        size_t num_symbols = 1 << CeilLog2Nonzero(max_symbol + 1);
723
0
        clustered_histograms.resize(1);
724
0
        clustered_histograms[0].Clear();
725
0
        for (size_t i = 0; i < num_symbols; i++) {
726
0
          clustered_histograms[0].Add(i);
727
0
        }
728
0
      }
729
0
      if (writer != nullptr) {
730
0
        EncodeContextMap(*context_map, clustered_histograms.size(), writer,
731
0
                         layer, aux_out);
732
0
      }
733
0
    }
734
0
    if (aux_out != nullptr) {
735
0
      for (size_t i = 0; i < clustered_histograms.size(); ++i) {
736
0
        aux_out->layers[layer].clustered_entropy +=
737
0
            clustered_histograms[i].ShannonEntropy();
738
0
      }
739
0
    }
740
0
    codes->use_prefix_code = use_prefix_code;
741
0
    size_t log_alpha_size = codes->lz77.enabled ? 8 : 7;  // Sane default.
742
0
    if (ans_fuzzer_friendly_) {
743
0
      codes->uint_config.clear();
744
0
      codes->uint_config.resize(1, HybridUintConfig(7, 0, 0));
745
0
    } else {
746
0
      ChooseUintConfigs(params, tokens, *context_map, &clustered_histograms,
747
0
                        codes, &log_alpha_size);
748
0
    }
749
0
    if (log_alpha_size < 5) log_alpha_size = 5;
750
0
    SizeWriter size_writer;  // Used if writer == nullptr to estimate costs.
751
0
    cost += 1;
752
0
    if (writer) writer->Write(1, use_prefix_code);
753
754
0
    if (use_prefix_code) {
755
0
      log_alpha_size = PREFIX_MAX_BITS;
756
0
    } else {
757
0
      cost += 2;
758
0
    }
759
0
    if (writer == nullptr) {
760
0
      EncodeUintConfigs(codes->uint_config, &size_writer, log_alpha_size);
761
0
    } else {
762
0
      if (!use_prefix_code) writer->Write(2, log_alpha_size - 5);
763
0
      EncodeUintConfigs(codes->uint_config, writer, log_alpha_size);
764
0
    }
765
0
    if (use_prefix_code) {
766
0
      for (size_t c = 0; c < clustered_histograms.size(); ++c) {
767
0
        size_t num_symbol = 1;
768
0
        for (size_t i = 0; i < clustered_histograms[c].data_.size(); i++) {
769
0
          if (clustered_histograms[c].data_[i]) num_symbol = i + 1;
770
0
        }
771
0
        if (writer) {
772
0
          StoreVarLenUint16(num_symbol - 1, writer);
773
0
        } else {
774
0
          StoreVarLenUint16(num_symbol - 1, &size_writer);
775
0
        }
776
0
      }
777
0
    }
778
0
    cost += size_writer.size;
779
0
    for (size_t c = 0; c < clustered_histograms.size(); ++c) {
780
0
      size_t num_symbol = 1;
781
0
      for (size_t i = 0; i < clustered_histograms[c].data_.size(); i++) {
782
0
        if (clustered_histograms[c].data_[i]) num_symbol = i + 1;
783
0
      }
784
0
      codes->encoding_info.emplace_back();
785
0
      codes->encoding_info.back().resize(std::max<size_t>(1, num_symbol));
786
787
0
      BitWriter::Allotment allotment(writer, 256 + num_symbol * 24);
788
0
      cost += BuildAndStoreANSEncodingData(
789
0
          params.ans_histogram_strategy, clustered_histograms[c].data_.data(),
790
0
          num_symbol, log_alpha_size, use_prefix_code,
791
0
          codes->encoding_info.back().data(), writer);
792
0
      allotment.FinishedHistogram(writer);
793
0
      allotment.ReclaimAndCharge(writer, layer, aux_out);
794
0
    }
795
0
    return cost;
796
0
  }
797
798
0
  const Histogram& Histo(size_t i) const { return histograms_[i]; }
799
800
 private:
801
  std::vector<Histogram> histograms_;
802
};
803
804
class SymbolCostEstimator {
805
 public:
806
  SymbolCostEstimator(size_t num_contexts, bool force_huffman,
807
                      const std::vector<std::vector<Token>>& tokens,
808
0
                      const LZ77Params& lz77) {
809
0
    HistogramBuilder builder(num_contexts);
810
    // Build histograms for estimating lz77 savings.
811
0
    HybridUintConfig uint_config;
812
0
    for (size_t i = 0; i < tokens.size(); ++i) {
813
0
      for (size_t j = 0; j < tokens[i].size(); ++j) {
814
0
        const Token token = tokens[i][j];
815
0
        uint32_t tok, nbits, bits;
816
0
        (token.is_lz77_length ? lz77.length_uint_config : uint_config)
817
0
            .Encode(token.value, &tok, &nbits, &bits);
818
0
        tok += token.is_lz77_length ? lz77.min_symbol : 0;
819
0
        builder.VisitSymbol(tok, token.context);
820
0
      }
821
0
    }
822
0
    max_alphabet_size_ = 0;
823
0
    for (size_t i = 0; i < num_contexts; i++) {
824
0
      max_alphabet_size_ =
825
0
          std::max(max_alphabet_size_, builder.Histo(i).data_.size());
826
0
    }
827
0
    bits_.resize(num_contexts * max_alphabet_size_);
828
    // TODO(veluca): SIMD?
829
0
    add_symbol_cost_.resize(num_contexts);
830
0
    for (size_t i = 0; i < num_contexts; i++) {
831
0
      float inv_total = 1.0f / (builder.Histo(i).total_count_ + 1e-8f);
832
0
      float total_cost = 0;
833
0
      for (size_t j = 0; j < builder.Histo(i).data_.size(); j++) {
834
0
        size_t cnt = builder.Histo(i).data_[j];
835
0
        float cost = 0;
836
0
        if (cnt != 0 && cnt != builder.Histo(i).total_count_) {
837
0
          cost = -FastLog2f(cnt * inv_total);
838
0
          if (force_huffman) cost = std::ceil(cost);
839
0
        } else if (cnt == 0) {
840
0
          cost = ANS_LOG_TAB_SIZE;  // Highest possible cost.
841
0
        }
842
0
        bits_[i * max_alphabet_size_ + j] = cost;
843
0
        total_cost += cost * builder.Histo(i).data_[j];
844
0
      }
845
      // Penalty for adding a lz77 symbol to this contest (only used for static
846
      // cost model). Higher penalty for contexts that have a very low
847
      // per-symbol entropy.
848
0
      add_symbol_cost_[i] = std::max(0.0f, 6.0f - total_cost * inv_total);
849
0
    }
850
0
  }
851
0
  float Bits(size_t ctx, size_t sym) const {
852
0
    return bits_[ctx * max_alphabet_size_ + sym];
853
0
  }
854
0
  float LenCost(size_t ctx, size_t len, const LZ77Params& lz77) const {
855
0
    uint32_t nbits, bits, tok;
856
0
    lz77.length_uint_config.Encode(len, &tok, &nbits, &bits);
857
0
    tok += lz77.min_symbol;
858
0
    return nbits + Bits(ctx, tok);
859
0
  }
860
0
  float DistCost(size_t len, const LZ77Params& lz77) const {
861
0
    uint32_t nbits, bits, tok;
862
0
    HybridUintConfig().Encode(len, &tok, &nbits, &bits);
863
0
    return nbits + Bits(lz77.nonserialized_distance_context, tok);
864
0
  }
865
0
  float AddSymbolCost(size_t idx) const { return add_symbol_cost_[idx]; }
866
867
 private:
868
  size_t max_alphabet_size_;
869
  std::vector<float> bits_;
870
  std::vector<float> add_symbol_cost_;
871
};
872
873
void ApplyLZ77_RLE(const HistogramParams& params, size_t num_contexts,
874
                   const std::vector<std::vector<Token>>& tokens,
875
                   LZ77Params& lz77,
876
0
                   std::vector<std::vector<Token>>& tokens_lz77) {
877
  // TODO(veluca): tune heuristics here.
878
0
  SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
879
0
  float bit_decrease = 0;
880
0
  size_t total_symbols = 0;
881
0
  tokens_lz77.resize(tokens.size());
882
0
  std::vector<float> sym_cost;
883
0
  HybridUintConfig uint_config;
884
0
  for (size_t stream = 0; stream < tokens.size(); stream++) {
885
0
    size_t distance_multiplier =
886
0
        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
887
0
    const auto& in = tokens[stream];
888
0
    auto& out = tokens_lz77[stream];
889
0
    total_symbols += in.size();
890
    // Cumulative sum of bit costs.
891
0
    sym_cost.resize(in.size() + 1);
892
0
    for (size_t i = 0; i < in.size(); i++) {
893
0
      uint32_t tok, nbits, unused_bits;
894
0
      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
895
0
      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
896
0
    }
897
0
    out.reserve(in.size());
898
0
    for (size_t i = 0; i < in.size(); i++) {
899
0
      size_t num_to_copy = 0;
900
0
      size_t distance_symbol = 0;  // 1 for RLE.
901
0
      if (distance_multiplier != 0) {
902
0
        distance_symbol = 1;  // Special distance 1 if enabled.
903
0
        JXL_DASSERT(kSpecialDistances[1][0] == 1);
904
0
        JXL_DASSERT(kSpecialDistances[1][1] == 0);
905
0
      }
906
0
      if (i > 0) {
907
0
        for (; i + num_to_copy < in.size(); num_to_copy++) {
908
0
          if (in[i + num_to_copy].value != in[i - 1].value) {
909
0
            break;
910
0
          }
911
0
        }
912
0
      }
913
0
      if (num_to_copy == 0) {
914
0
        out.push_back(in[i]);
915
0
        continue;
916
0
      }
917
0
      float cost = sym_cost[i + num_to_copy] - sym_cost[i];
918
      // This subtraction might overflow, but that's OK.
919
0
      size_t lz77_len = num_to_copy - lz77.min_length;
920
0
      float lz77_cost = num_to_copy >= lz77.min_length
921
0
                            ? CeilLog2Nonzero(lz77_len + 1) + 1
922
0
                            : 0;
923
0
      if (num_to_copy < lz77.min_length || cost <= lz77_cost) {
924
0
        for (size_t j = 0; j < num_to_copy; j++) {
925
0
          out.push_back(in[i + j]);
926
0
        }
927
0
        i += num_to_copy - 1;
928
0
        continue;
929
0
      }
930
      // Output the LZ77 length
931
0
      out.emplace_back(in[i].context, lz77_len);
932
0
      out.back().is_lz77_length = true;
933
0
      i += num_to_copy - 1;
934
0
      bit_decrease += cost - lz77_cost;
935
      // Output the LZ77 copy distance.
936
0
      out.emplace_back(lz77.nonserialized_distance_context, distance_symbol);
937
0
    }
938
0
  }
939
940
0
  if (bit_decrease > total_symbols * 0.2 + 16) {
941
0
    lz77.enabled = true;
942
0
  }
943
0
}
944
945
// Hash chain for LZ77 matching
//
// Indexes the positions of one token stream in two ways. The first key is a
// hash of the next three token values (GetHash). The second key is the length
// of the run of zero values starting at each position. FindMatches uses these
// to enumerate earlier positions that may start a match. Positions are stored
// modulo the window size (pos & window_mask_).
struct HashChain {
  // Number of tokens, and a copy of their values.
  size_t size_;
  std::vector<uint32_t> data_;

  unsigned hash_num_values_ = 32768;
  unsigned hash_mask_ = hash_num_values_ - 1;
  unsigned hash_shift_ = 5;

  // head[h]: most recently inserted window slot with hash h, or -1.
  std::vector<int> head;
  // chain[p]: previously inserted window slot with the same hash as slot p.
  // chain[p] == p means there is no predecessor.
  std::vector<uint32_t> chain;
  // val[p]: hash recorded for slot p. FindMatches uses it to stop at slots
  // that have since been overwritten by a different hash.
  std::vector<int> val;

  // Speed up repetitions of zero
  // Same structure as head/chain/val, keyed by zeros[p]: the length of the
  // run of zero values starting at the position in slot p (capped by the
  // window size).
  std::vector<int> headz;
  std::vector<uint32_t> chainz;
  std::vector<uint32_t> zeros;
  // Zero-run length computed by the most recent call to Update(pos).
  uint32_t numzeros = 0;

  size_t window_size_;
  size_t window_mask_;
  size_t min_length_;
  size_t max_length_;

  // Map of special distance codes.
  std::unordered_map<int, int> special_dist_table_;
  size_t num_special_distances_ = 0;

  uint32_t maxchainlength = 256;  // window_size_ to allow all

  // Copies the token values of data[0..size) and initializes empty chains.
  // window_size is used as a mask base (window_size - 1), so it is presumably
  // expected to be a power of two; the visible caller rounds it up to one.
  // When distance_multiplier is nonzero, each distance
  // yi * distance_multiplier + xi from kSpecialDistances (clamped to >= 1) is
  // mapped to its special distance code.
  HashChain(const Token* data, size_t size, size_t window_size,
            size_t min_length, size_t max_length, size_t distance_multiplier)
      : size_(size),
        window_size_(window_size),
        window_mask_(window_size - 1),
        min_length_(min_length),
        max_length_(max_length) {
    data_.resize(size);
    for (size_t i = 0; i < size; i++) {
      data_[i] = data[i].value;
    }

    head.resize(hash_num_values_, -1);
    val.resize(window_size_, -1);
    chain.resize(window_size_);
    for (uint32_t i = 0; i < window_size_; ++i) {
      chain[i] = i;  // same value as index indicates uninitialized
    }

    zeros.resize(window_size_);
    headz.resize(window_size_ + 1, -1);
    chainz.resize(window_size_);
    for (uint32_t i = 0; i < window_size_; ++i) {
      chainz[i] = i;
    }
    // Translate distance to special distance code.
    if (distance_multiplier) {
      // Count down, so if due to small distance multiplier multiple distances
      // map to the same code, the smallest code will be used in the end.
      for (int i = kNumSpecialDistances - 1; i >= 0; --i) {
        int xi = kSpecialDistances[i][0];
        int yi = kSpecialDistances[i][1];
        int distance = yi * distance_multiplier + xi;
        // Ensure that we map distance 1 to the lowest symbols.
        if (distance < 1) distance = 1;
        special_dist_table_[distance] = i;
      }
      num_special_distances_ = kNumSpecialDistances;
    }
  }

  // Hash of data_[pos], data_[pos + 1], data_[pos + 2], masked to
  // hash_num_values_ buckets. Returns 0 when fewer than three values remain.
  uint32_t GetHash(size_t pos) const {
    uint32_t result = 0;
    if (pos + 2 < size_) {
      // TODO(lode): take the MSB's of the uint32_t values into account as well,
      // given that the hash code itself is less than 32 bits.
      result ^= (uint32_t)(data_[pos + 0] << 0u);
      result ^= (uint32_t)(data_[pos + 1] << hash_shift_);
      result ^= (uint32_t)(data_[pos + 2] << (hash_shift_ * 2));
    } else {
      // No need to compute hash of last 2 bytes, the length 2 is too short.
      return 0;
    }
    return result & hash_mask_;
  }

  // Number of consecutive zero values starting at `pos`, scanning at most
  // window_size_ values. If `prevzeros` (the count at pos - 1) is positive,
  // the result is derived from it instead of rescanning. It stays saturated
  // while the window's last value is still zero; otherwise it drops by one.
  uint32_t CountZeros(size_t pos, uint32_t prevzeros) const {
    size_t end = pos + window_size_;
    if (end > size_) end = size_;
    if (prevzeros > 0) {
      if (prevzeros >= window_mask_ && data_[end - 1] == 0 &&
          end == pos + window_size_) {
        return prevzeros;
      } else {
        return prevzeros - 1;
      }
    }
    uint32_t num = 0;
    while (pos + num < end && data_[pos + num] == 0) num++;
    return num;
  }

  // Inserts position `pos` into both the hash chain and the zero-run chain,
  // and updates `numzeros` for `pos`.
  void Update(size_t pos) {
    uint32_t hashval = GetHash(pos);
    uint32_t wpos = pos & window_mask_;

    val[wpos] = (int)hashval;
    if (head[hashval] != -1) chain[wpos] = head[hashval];
    head[hashval] = wpos;

    // A change of value ends any zero run, forcing CountZeros to rescan.
    if (pos > 0 && data_[pos] != data_[pos - 1]) numzeros = 0;
    numzeros = CountZeros(pos, numzeros);

    zeros[wpos] = numzeros;
    if (headz[numzeros] != -1) chainz[wpos] = headz[numzeros];
    headz[numzeros] = wpos;
  }

  // Inserts positions pos .. pos + len - 1, in order.
  void Update(size_t pos, size_t len) {
    for (size_t i = 0; i < len; i++) {
      Update(pos + i);
    }
  }

  // Walks the chain of earlier positions for `pos`. For each one whose match
  // is at least min_length_ tokens long and within 2 of the best length found
  // so far, calls found_match(len, dist_symbol). The distance symbol is the
  // special distance code when the distance is in special_dist_table_,
  // otherwise num_special_distances_ + dist - 1. Match length is capped at
  // max_length_ and at the end of the data.
  //
  // The walk follows the zero-run chain when numzeros >= 3 and the last match
  // was longer than numzeros, and the hash chain otherwise. It stops after
  // maxchainlength steps, at the end of a chain, at a stale entry, or when
  // the distance stops increasing.
  // NOTE(review): `max_dist` is not used; distance is bounded only by the
  // window. The member `numzeros` is used as the zero-run length of `pos`, so
  // this presumably expects Update(pos) to be the most recent update.
  template <typename CB>
  void FindMatches(size_t pos, int max_dist, const CB& found_match) const {
    uint32_t wpos = pos & window_mask_;
    uint32_t hashval = GetHash(pos);
    uint32_t hashpos = chain[wpos];

    int prev_dist = 0;
    int end = std::min<int>(pos + max_length_, size_);
    uint32_t chainlength = 0;
    uint32_t best_len = 0;
    for (;;) {
      int dist = (hashpos <= wpos) ? (wpos - hashpos)
                                   : (wpos - hashpos + window_mask_ + 1);
      if (dist < prev_dist) break;
      prev_dist = dist;
      uint32_t len = 0;
      if (dist > 0) {
        int i = pos;
        int j = pos - dist;
        if (numzeros > 3) {
          // Both positions start with at least r zeros, so skip comparing
          // them (keeping at least one token to compare before `end`).
          int r = std::min<int>(numzeros - 1, zeros[hashpos]);
          if (i + r >= end) r = end - i - 1;
          i += r;
          j += r;
        }
        while (i < end && data_[i] == data_[j]) {
          i++;
          j++;
        }
        len = i - pos;
        // This can trigger even if the new length is slightly smaller than the
        // best length, because it is possible for a slightly cheaper distance
        // symbol to occur.
        if (len >= min_length_ && len + 2 >= best_len) {
          auto it = special_dist_table_.find(dist);
          int dist_symbol = (it == special_dist_table_.end())
                                ? (num_special_distances_ + dist - 1)
                                : it->second;
          found_match(len, dist_symbol);
          if (len > best_len) best_len = len;
        }
      }

      chainlength++;
      if (chainlength >= maxchainlength) break;

      if (numzeros >= 3 && len > numzeros) {
        if (hashpos == chainz[hashpos]) break;
        hashpos = chainz[hashpos];
        if (zeros[hashpos] != numzeros) break;
      } else {
        if (hashpos == chain[hashpos]) break;
        hashpos = chain[hashpos];
        if (val[hashpos] != (int)hashval) break;  // outdated hash value
      }
    }
  }

  // Finds the best match at `pos`: the longest one, with ties broken by the
  // smaller distance symbol. Writes it to the out-params. When no candidate
  // is reported, *result_len stays 1 and *result_dist_symbol stays 0.
  void FindMatch(size_t pos, int max_dist, size_t* result_dist_symbol,
                 size_t* result_len) const {
    *result_dist_symbol = 0;
    *result_len = 1;
    FindMatches(pos, max_dist, [&](size_t len, size_t dist_symbol) {
      if (len > *result_len ||
          (len == *result_len && *result_dist_symbol > dist_symbol)) {
        *result_len = len;
        *result_dist_symbol = dist_symbol;
      }
    });
  }
};
1139
1140
0
float LenCost(size_t len) {
1141
0
  uint32_t nbits, bits, tok;
1142
0
  HybridUintConfig(1, 0, 0).Encode(len, &tok, &nbits, &bits);
1143
0
  constexpr float kCostTable[] = {
1144
0
      2.797667318563126,  3.213177690381199,  2.5706009246743737,
1145
0
      2.408392498667534,  2.829649191872326,  3.3923087753324577,
1146
0
      4.029267451554331,  4.415576699706408,  4.509357574741465,
1147
0
      9.21481543803004,   10.020590190114898, 11.858671627804766,
1148
0
      12.45853300490526,  11.713105831990857, 12.561996324849314,
1149
0
      13.775477692278367, 13.174027068768641,
1150
0
  };
1151
0
  size_t table_size = sizeof kCostTable / sizeof *kCostTable;
1152
0
  if (tok >= table_size) tok = table_size - 1;
1153
0
  return kCostTable[tok] + nbits;
1154
0
}
1155
1156
// TODO(veluca): this does not take into account usage or non-usage of distance
// multipliers.
//
// Static estimate of the bits needed to code LZ77 distance `dist`; the visible
// caller, ApplyLZ77_LZ77, passes a distance symbol. The estimate is a
// per-token cost from kCostTable plus the raw extra bits produced by
// HybridUintConfig(7, 0, 0). Tokens past the end of the table reuse its last
// entry.
// NOTE(review): the table values look empirically tuned; their provenance is
// not visible here.
float DistCost(size_t dist) {
  uint32_t nbits, bits, tok;
  HybridUintConfig(7, 0, 0).Encode(dist, &tok, &nbits, &bits);
  constexpr float kCostTable[] = {
      6.368282626312716,  5.680793277090298,  8.347404197105247,
      7.641619201599141,  6.914328374119438,  7.959808291537444,
      8.70023120759855,   8.71378518934703,   9.379132523982769,
      9.110472749092708,  9.159029569270908,  9.430936766731973,
      7.278284055315169,  7.8278514904267755, 10.026641158289236,
      9.976049229827066,  9.64351607048908,   9.563403863480442,
      10.171474111762747, 10.45950155077234,  9.994813912104219,
      10.322524683741156, 8.465808729388186,  8.756254166066853,
      10.160930174662234, 10.247329273413435, 10.04090403724809,
      10.129398517544082, 9.342311691539546,  9.07608009102374,
      10.104799540677513, 10.378079384990906, 10.165828974075072,
      10.337595322341553, 7.940557464567944,  10.575665823319431,
      11.023344321751955, 10.736144698831827, 11.118277044595054,
      7.468468230648442,  10.738305230932939, 10.906980780216568,
      10.163468216353817, 10.17805759656433,  11.167283670483565,
      11.147050200274544, 10.517921919244333, 10.651764778156886,
      10.17074446448919,  11.217636876224745, 11.261630721139484,
      11.403140815247259, 10.892472096873417, 11.1859607804481,
      8.017346947551262,  7.895143720278828,  11.036577113822025,
      11.170562110315794, 10.326988722591086, 10.40872184751056,
      11.213498225466386, 11.30580635516863,  10.672272515665442,
      10.768069466228063, 11.145257364153565, 11.64668307145549,
      10.593156194627339, 11.207499484844943, 10.767517766396908,
      10.826629811407042, 10.737764794499988, 10.6200448518045,
      10.191315385198092, 8.468384171390085,  11.731295299170432,
      11.824619886654398, 10.41518844301179,  10.16310536548649,
      10.539423685097576, 10.495136599328031, 10.469112847728267,
      11.72057686174922,  10.910326337834674, 11.378921834673758,
      11.847759036098536, 11.92071647623854,  10.810628276345282,
      11.008601085273893, 11.910326337834674, 11.949212023423133,
      11.298614839104337, 11.611603659010392, 10.472930394619985,
      11.835564720850282, 11.523267392285337, 12.01055816679611,
      8.413029688994023,  11.895784139536406, 11.984679534970505,
      11.220654278717394, 11.716311684833672, 10.61036646226114,
      10.89849965960364,  10.203762898863669, 10.997560826267238,
      11.484217379438984, 11.792836176993665, 12.24310468755171,
      11.464858097919262, 12.212747017409377, 11.425595666074955,
      11.572048533398757, 12.742093965163013, 11.381874288645637,
      12.191870445817015, 11.683156920035426, 11.152442115262197,
      11.90303691580457,  11.653292787169159, 11.938615382266098,
      16.970641701570223, 16.853602280380002, 17.26240782594733,
      16.644655390108507, 17.14310889757499,  16.910935455445955,
      17.505678976959697, 17.213498225466388, 2.4162310293553024,
      3.494587244462329,  3.5258600986408344, 3.4959806589517095,
      3.098390886949687,  3.343454654302911,  3.588847442290287,
      4.14614790111827,   5.152948641990529,  7.433696808092598,
      9.716311684833672,
  };
  size_t table_size = sizeof kCostTable / sizeof *kCostTable;
  // Clamp: every token past the table is charged the last entry's cost.
  if (tok >= table_size) tok = table_size - 1;
  return kCostTable[tok] + nbits;
}
1214
1215
// Greedy LZ77 over each token stream, with one step of lazy matching.
//
// At each position the best HashChain match is found. If the match at the
// next position is strictly longer, the current token is emitted as a literal
// and the next match is used instead. A match is emitted as an LZ77 length
// token, in the context of the token it replaces, followed by a distance token
// in lz77.nonserialized_distance_context. This happens only when its static
// cost (LenCost + DistCost + AddSymbolCost) does not exceed the estimated cost
// of the literals it replaces; otherwise those literals are copied through.
// Sets lz77.enabled when the total estimated saving exceeds 0.2 bits per input
// token plus 16 bits.
void ApplyLZ77_LZ77(const HistogramParams& params, size_t num_contexts,
                    const std::vector<std::vector<Token>>& tokens,
                    LZ77Params& lz77,
                    std::vector<std::vector<Token>>& tokens_lz77) {
  // TODO(veluca): tune heuristics here.
  SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
  float bit_decrease = 0;
  size_t total_symbols = 0;
  tokens_lz77.resize(tokens.size());
  HybridUintConfig uint_config;
  std::vector<float> sym_cost;
  for (size_t stream = 0; stream < tokens.size(); stream++) {
    size_t distance_multiplier =
        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
    const auto& in = tokens[stream];
    auto& out = tokens_lz77[stream];
    total_symbols += in.size();
    // Cumulative sum of bit costs.
    // sym_cost[k] is the estimated cost of coding in[0..k) as literals.
    sym_cost.resize(in.size() + 1);
    for (size_t i = 0; i < in.size(); i++) {
      uint32_t tok, nbits, unused_bits;
      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
    }

    out.reserve(in.size());
    size_t max_distance = in.size();
    size_t min_length = lz77.min_length;
    JXL_ASSERT(min_length >= 3);
    size_t max_length = in.size();

    // Use next power of two as window size.
    size_t window_size = 1;
    while (window_size < max_distance && window_size < kWindowSize) {
      window_size <<= 1;
    }

    HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
                    distance_multiplier);
    size_t len, dist_symbol;

    const size_t max_lazy_match_len = 256;  // 0 to disable lazy matching

    // Whether the next symbol was already updated (to test lazy matching)
    bool already_updated = false;
    for (size_t i = 0; i < in.size(); i++) {
      // Optimistically emit the literal; it is rewritten into a length token
      // below if a match is taken.
      out.push_back(in[i]);
      if (!already_updated) chain.Update(i);
      already_updated = false;
      chain.FindMatch(i, max_distance, &dist_symbol, &len);
      if (len >= min_length) {
        if (len < max_lazy_match_len && i + 1 < in.size()) {
          // Try length at next symbol lazy matching
          chain.Update(i + 1);
          already_updated = true;
          size_t len2, dist_symbol2;
          chain.FindMatch(i + 1, max_distance, &dist_symbol2, &len2);
          if (len2 > len) {
            // Use the lazy match. Add literal, and use the next length starting
            // from the next byte.
            ++i;
            already_updated = false;
            len = len2;
            dist_symbol = dist_symbol2;
            out.push_back(in[i]);
          }
        }

        // Compare the literal cost of in[i, i + len) with the match's cost.
        float cost = sym_cost[i + len] - sym_cost[i];
        size_t lz77_len = len - lz77.min_length;
        float lz77_cost = LenCost(lz77_len) + DistCost(dist_symbol) +
                          sce.AddSymbolCost(out.back().context);

        if (lz77_cost <= cost) {
          // Turn the last pushed literal into the length token (keeping its
          // context) and append the distance token.
          out.back().value = len - min_length;
          out.back().is_lz77_length = true;
          out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
          bit_decrease += cost - lz77_cost;
        } else {
          // LZ77 match ignored, and symbol already pushed. Push all other
          // symbols and skip.
          for (size_t j = 1; j < len; j++) {
            out.push_back(in[i + j]);
          }
        }

        // Insert the remaining covered positions into the hash chain. Position
        // i (and i + 1, if it was probed lazily and not taken) is already in.
        if (already_updated) {
          chain.Update(i + 2, len - 2);
          already_updated = false;
        } else {
          chain.Update(i + 1, len - 1);
        }
        i += len - 1;
      } else {
        // Literal, already pushed
      }
    }
  }

  if (bit_decrease > total_symbols * 0.2 + 16) {
    lz77.enabled = true;
  }
}
1318
1319
// Optimal-parse LZ77 over every token stream in `tokens`, writing the result
// into `tokens_lz77` (one output stream per input stream).
//
// First runs the greedy ApplyLZ77_LZ77 pass; its output is only used to build
// a SymbolCostEstimator, and if that pass leaves `lz77.enabled` false this
// function returns immediately without touching `tokens_lz77`.
//
// For each stream it then runs a shortest-path dynamic program over token
// positions: prefix_costs[k] holds the cheapest estimated way to encode the
// first k tokens, where the last step is either one literal or one LZ77 match
// (length token + distance token). The best path is then reconstructed
// backwards from the end of the stream.
void ApplyLZ77_Optimal(const HistogramParams& params, size_t num_contexts,
                       const std::vector<std::vector<Token>>& tokens,
                       LZ77Params& lz77,
                       std::vector<std::vector<Token>>& tokens_lz77) {
  std::vector<std::vector<Token>> tokens_for_cost_estimate;
  ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_for_cost_estimate);
  // If greedy-LZ77 does not give better compression than no-lz77, no reason to
  // run the optimal matching.
  if (!lz77.enabled) return;
  // One extra context (num_contexts + 1) for the LZ77 distance tokens.
  SymbolCostEstimator sce(num_contexts + 1, params.force_huffman,
                          tokens_for_cost_estimate, lz77);
  tokens_lz77.resize(tokens.size());
  HybridUintConfig uint_config;
  // Scratch buffers reused across streams to avoid reallocations.
  std::vector<float> sym_cost;
  std::vector<uint32_t> dist_symbols;
  for (size_t stream = 0; stream < tokens.size(); stream++) {
    // Streams with a known image width use that width as the distance
    // multiplier; otherwise 0 is passed to the hash chain.
    size_t distance_multiplier =
        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
    const auto& in = tokens[stream];
    auto& out = tokens_lz77[stream];
    // Cumulative sum of bit costs: sym_cost[k] is the estimated cost of
    // coding the first k tokens as literals (symbol bits + raw extra bits).
    sym_cost.resize(in.size() + 1);
    for (size_t i = 0; i < in.size(); i++) {
      uint32_t tok, nbits, unused_bits;
      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
    }

    out.reserve(in.size());
    size_t max_distance = in.size();
    size_t min_length = lz77.min_length;
    JXL_ASSERT(min_length >= 3);
    size_t max_length = in.size();

    // Use next power of two as window size.
    size_t window_size = 1;
    while (window_size < max_distance && window_size < kWindowSize) {
      window_size <<= 1;
    }

    HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
                    distance_multiplier);

    // Best known way to reach a position. dist_symbol == 0 marks a literal
    // (len == 1); otherwise it stores the distance symbol + 1 of a match of
    // length `len`. `ctx` is the context of the token that starts the step.
    struct MatchInfo {
      uint32_t len;
      uint32_t dist_symbol;
      uint32_t ctx;
      float total_cost = std::numeric_limits<float>::max();
    };
    // Total cost to encode the first N symbols.
    std::vector<MatchInfo> prefix_costs(in.size() + 1);
    prefix_costs[0].total_cost = 0;

    size_t rle_length = 0;
    size_t skip_lz77 = 0;
    for (size_t i = 0; i < in.size(); i++) {
      chain.Update(i);
      // Relax the literal edge i -> i + 1.
      float lit_cost =
          prefix_costs[i].total_cost + sym_cost[i + 1] - sym_cost[i];
      if (prefix_costs[i + 1].total_cost > lit_cost) {
        prefix_costs[i + 1].dist_symbol = 0;
        prefix_costs[i + 1].len = 1;
        prefix_costs[i + 1].ctx = in[i].context;
        prefix_costs[i + 1].total_cost = lit_cost;
      }
      // Inside a skipped RLE run only the literal edge is relaxed.
      if (skip_lz77 > 0) {
        skip_lz77--;
        continue;
      }
      // dist_symbols[len] = smallest distance symbol reported for a match of
      // exactly `len`; growing the vector fills new slots with the current
      // symbol.
      dist_symbols.clear();
      chain.FindMatches(i, max_distance,
                        [&dist_symbols](size_t len, size_t dist_symbol) {
                          if (dist_symbols.size() <= len) {
                            dist_symbols.resize(len + 1, dist_symbol);
                          }
                          if (dist_symbol < dist_symbols[len]) {
                            dist_symbols[len] = dist_symbol;
                          }
                        });
      if (dist_symbols.size() <= min_length) continue;
      {
        // Suffix minimum: a match of length >= j can also be used for length
        // j, so dist_symbols[j] becomes the smallest symbol over all lengths
        // >= j. (`best_cost` holds a distance symbol, not a bit cost.)
        size_t best_cost = dist_symbols.back();
        for (size_t j = dist_symbols.size() - 1; j >= min_length; j--) {
          if (dist_symbols[j] < best_cost) {
            best_cost = dist_symbols[j];
          }
          dist_symbols[j] = best_cost;
        }
      }
      // Relax every match edge i -> i + j.
      for (size_t j = min_length; j < dist_symbols.size(); j++) {
        // Cost model that uses results from lazy LZ77.
        float lz77_cost = sce.LenCost(in[i].context, j - min_length, lz77) +
                          sce.DistCost(dist_symbols[j], lz77);
        float cost = prefix_costs[i].total_cost + lz77_cost;
        if (prefix_costs[i + j].total_cost > cost) {
          prefix_costs[i + j].len = j;
          prefix_costs[i + j].dist_symbol = dist_symbols[j] + 1;
          prefix_costs[i + j].ctx = in[i].context;
          prefix_costs[i + j].total_cost = cost;
        }
      }
      // We are in a RLE sequence: skip all the symbols except the first 8 and
      // the last 8. This avoids quadratic costs for sequences with long runs of
      // the same symbol.
      // NOTE(review): symbol 0 (no multiplier) / 1 (with multiplier) is
      // presumably the code for distance 1 — confirm against HashChain.
      // After 8 such positions in a row, match search is skipped for the next
      // dist_symbols.size() - 10 positions.
      if ((dist_symbols.back() == 0 && distance_multiplier == 0) ||
          (dist_symbols.back() == 1 && distance_multiplier != 0)) {
        rle_length++;
      } else {
        rle_length = 0;
      }
      if (rle_length >= 8 && dist_symbols.size() > 9) {
        skip_lz77 = dist_symbols.size() - 10;
        rle_length = 0;
      }
    }
    // Backtrack from the end. Tokens are appended in reverse (distance token
    // before its length token) and the whole stream is reversed at the end.
    size_t pos = in.size();
    while (pos > 0) {
      bool is_lz77_length = prefix_costs[pos].dist_symbol != 0;
      if (is_lz77_length) {
        size_t dist_symbol = prefix_costs[pos].dist_symbol - 1;
        out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
      }
      size_t val = is_lz77_length ? prefix_costs[pos].len - min_length
                                  : in[pos - 1].value;
      out.emplace_back(prefix_costs[pos].ctx, val);
      out.back().is_lz77_length = is_lz77_length;
      pos -= prefix_costs[pos].len;
    }
    std::reverse(out.begin(), out.end());
  }
}
1450
1451
void ApplyLZ77(const HistogramParams& params, size_t num_contexts,
1452
               const std::vector<std::vector<Token>>& tokens, LZ77Params& lz77,
1453
0
               std::vector<std::vector<Token>>& tokens_lz77) {
1454
0
  lz77.enabled = false;
1455
0
  if (params.force_huffman) {
1456
0
    lz77.min_symbol = std::min(PREFIX_MAX_ALPHABET_SIZE - 32, 512);
1457
0
  } else {
1458
0
    lz77.min_symbol = 224;
1459
0
  }
1460
0
  if (params.lz77_method == HistogramParams::LZ77Method::kNone) {
1461
0
    return;
1462
0
  } else if (params.lz77_method == HistogramParams::LZ77Method::kRLE) {
1463
0
    ApplyLZ77_RLE(params, num_contexts, tokens, lz77, tokens_lz77);
1464
0
  } else if (params.lz77_method == HistogramParams::LZ77Method::kLZ77) {
1465
0
    ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_lz77);
1466
0
  } else if (params.lz77_method == HistogramParams::LZ77Method::kOptimal) {
1467
0
    ApplyLZ77_Optimal(params, num_contexts, tokens, lz77, tokens_lz77);
1468
0
  } else {
1469
0
    JXL_UNREACHABLE("Not implemented");
1470
0
  }
1471
0
}
1472
}  // namespace
1473
1474
size_t BuildAndEncodeHistograms(const HistogramParams& params,
1475
                                size_t num_contexts,
1476
                                std::vector<std::vector<Token>>& tokens,
1477
                                EntropyEncodingData* codes,
1478
                                std::vector<uint8_t>* context_map,
1479
                                BitWriter* writer, size_t layer,
1480
0
                                AuxOut* aux_out) {
1481
0
  size_t total_bits = 0;
1482
0
  codes->lz77.nonserialized_distance_context = num_contexts;
1483
0
  std::vector<std::vector<Token>> tokens_lz77;
1484
0
  ApplyLZ77(params, num_contexts, tokens, codes->lz77, tokens_lz77);
1485
0
  if (ans_fuzzer_friendly_) {
1486
0
    codes->lz77.length_uint_config = HybridUintConfig(10, 0, 0);
1487
0
    codes->lz77.min_symbol = 2048;
1488
0
  }
1489
1490
0
  const size_t max_contexts = std::min(num_contexts, kClustersLimit);
1491
0
  BitWriter::Allotment allotment(writer,
1492
0
                                 128 + num_contexts * 40 + max_contexts * 96);
1493
0
  if (writer) {
1494
0
    JXL_CHECK(Bundle::Write(codes->lz77, writer, layer, aux_out));
1495
0
  } else {
1496
0
    size_t ebits, bits;
1497
0
    JXL_CHECK(Bundle::CanEncode(codes->lz77, &ebits, &bits));
1498
0
    total_bits += bits;
1499
0
  }
1500
0
  if (codes->lz77.enabled) {
1501
0
    if (writer) {
1502
0
      size_t b = writer->BitsWritten();
1503
0
      EncodeUintConfig(codes->lz77.length_uint_config, writer,
1504
0
                       /*log_alpha_size=*/8);
1505
0
      total_bits += writer->BitsWritten() - b;
1506
0
    } else {
1507
0
      SizeWriter size_writer;
1508
0
      EncodeUintConfig(codes->lz77.length_uint_config, &size_writer,
1509
0
                       /*log_alpha_size=*/8);
1510
0
      total_bits += size_writer.size;
1511
0
    }
1512
0
    num_contexts += 1;
1513
0
    tokens = std::move(tokens_lz77);
1514
0
  }
1515
0
  size_t total_tokens = 0;
1516
  // Build histograms.
1517
0
  HistogramBuilder builder(num_contexts);
1518
0
  HybridUintConfig uint_config;  //  Default config for clustering.
1519
  // Unless we are using the kContextMap histogram option.
1520
0
  if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) {
1521
0
    uint_config = HybridUintConfig(2, 0, 1);
1522
0
  }
1523
0
  if (params.uint_method == HistogramParams::HybridUintMethod::k000) {
1524
0
    uint_config = HybridUintConfig(0, 0, 0);
1525
0
  }
1526
0
  if (ans_fuzzer_friendly_) {
1527
0
    uint_config = HybridUintConfig(10, 0, 0);
1528
0
  }
1529
0
  for (size_t i = 0; i < tokens.size(); ++i) {
1530
0
    if (codes->lz77.enabled) {
1531
0
      for (size_t j = 0; j < tokens[i].size(); ++j) {
1532
0
        const Token& token = tokens[i][j];
1533
0
        total_tokens++;
1534
0
        uint32_t tok, nbits, bits;
1535
0
        (token.is_lz77_length ? codes->lz77.length_uint_config : uint_config)
1536
0
            .Encode(token.value, &tok, &nbits, &bits);
1537
0
        tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
1538
0
        builder.VisitSymbol(tok, token.context);
1539
0
      }
1540
0
    } else if (num_contexts == 1) {
1541
0
      for (size_t j = 0; j < tokens[i].size(); ++j) {
1542
0
        const Token& token = tokens[i][j];
1543
0
        total_tokens++;
1544
0
        uint32_t tok, nbits, bits;
1545
0
        uint_config.Encode(token.value, &tok, &nbits, &bits);
1546
0
        builder.VisitSymbol(tok, /*token.context=*/0);
1547
0
      }
1548
0
    } else {
1549
0
      for (size_t j = 0; j < tokens[i].size(); ++j) {
1550
0
        const Token& token = tokens[i][j];
1551
0
        total_tokens++;
1552
0
        uint32_t tok, nbits, bits;
1553
0
        uint_config.Encode(token.value, &tok, &nbits, &bits);
1554
0
        builder.VisitSymbol(tok, token.context);
1555
0
      }
1556
0
    }
1557
0
  }
1558
1559
0
  bool use_prefix_code =
1560
0
      params.force_huffman || total_tokens < 100 ||
1561
0
      params.clustering == HistogramParams::ClusteringType::kFastest ||
1562
0
      ans_fuzzer_friendly_;
1563
0
  if (!use_prefix_code) {
1564
0
    bool all_singleton = true;
1565
0
    for (size_t i = 0; i < num_contexts; i++) {
1566
0
      if (builder.Histo(i).ShannonEntropy() >= 1e-5) {
1567
0
        all_singleton = false;
1568
0
      }
1569
0
    }
1570
0
    if (all_singleton) {
1571
0
      use_prefix_code = true;
1572
0
    }
1573
0
  }
1574
1575
  // Encode histograms.
1576
0
  total_bits += builder.BuildAndStoreEntropyCodes(params, tokens, codes,
1577
0
                                                  context_map, use_prefix_code,
1578
0
                                                  writer, layer, aux_out);
1579
0
  allotment.FinishedHistogram(writer);
1580
0
  allotment.ReclaimAndCharge(writer, layer, aux_out);
1581
1582
0
  if (aux_out != nullptr) {
1583
0
    aux_out->layers[layer].num_clustered_histograms +=
1584
0
        codes->encoding_info.size();
1585
0
  }
1586
0
  return total_bits;
1587
0
}
1588
1589
// Writes `tokens` to `writer` using the codes in `codes`, mapping each token's
// context to a histogram through `context_map`. Returns the number of raw
// ("extra") bits written besides the entropy-coded symbols.
//
// Prefix codes: tokens are written first to last, each as a single Write() of
// the code word followed by its extra bits.
// ANS: tokens are processed last to first. Their bits are buffered in chunks
// of at most BitWriter::kMaxBitsPerCall, and then written after the 32-bit
// final ANS state, starting with the most recently buffered chunk.
size_t WriteTokens(const std::vector<Token>& tokens,
                   const EntropyEncodingData& codes,
                   const std::vector<uint8_t>& context_map, BitWriter* writer) {
  size_t num_extra_bits = 0;
  if (codes.use_prefix_code) {
    for (const Token& token : tokens) {
      uint32_t tok, nbits, bits;
      const size_t histo = context_map[token.context];
      (token.is_lz77_length ? codes.lz77.length_uint_config
                            : codes.uint_config[histo])
          .Encode(token.value, &tok, &nbits, &bits);
      tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
      const auto& info = codes.encoding_info[histo][tok];
      // Combine two calls to the BitWriter. Equivalent to:
      //   writer->Write(info.depth, info.bits);
      //   writer->Write(nbits, bits);
      // `bits` is a uint32_t: widen it BEFORE shifting, otherwise the shift is
      // done in 32 bits and drops high bits whenever depth + nbits > 32.
      uint64_t data = info.bits;
      data |= static_cast<uint64_t>(bits) << info.depth;
      writer->Write(info.depth + nbits, data);
      num_extra_bits += nbits;
    }
    return num_extra_bits;
  }

  std::vector<uint64_t> out;
  std::vector<uint8_t> out_nbits;
  out.reserve(tokens.size());
  out_nbits.reserve(tokens.size());
  uint64_t allbits = 0;
  size_t numallbits = 0;
  // Writes in *reversed* order: appends `nbits` bits below the current buffer,
  // flushing the buffer to `out` first if it would exceed kMaxBitsPerCall.
  auto addbits = [&](size_t bits, size_t nbits) {
    if (JXL_UNLIKELY(nbits)) {
      JXL_DASSERT(bits >> nbits == 0);
      if (JXL_UNLIKELY(numallbits + nbits > BitWriter::kMaxBitsPerCall)) {
        out.push_back(allbits);
        out_nbits.push_back(numallbits);
        numallbits = allbits = 0;
      }
      allbits <<= nbits;
      allbits |= bits;
      numallbits += nbits;
    }
  };
  ANSCoder ans;
  // Indices are size_t (not int) so very long token streams cannot overflow
  // the loop counter.
  if (codes.lz77.enabled || context_map.size() > 1) {
    for (size_t i = tokens.size(); i-- > 0;) {
      const Token& token = tokens[i];
      const uint8_t histo = context_map[token.context];
      uint32_t tok, nbits, bits;
      (token.is_lz77_length ? codes.lz77.length_uint_config
                            : codes.uint_config[histo])
          .Encode(token.value, &tok, &nbits, &bits);
      tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
      const ANSEncSymbolInfo& info = codes.encoding_info[histo][tok];
      // Extra bits first as this is reversed.
      addbits(bits, nbits);
      num_extra_bits += nbits;
      uint8_t ans_nbits = 0;
      uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits);
      addbits(ans_bits, ans_nbits);
    }
  } else {
    // No LZ77 and at most one context: every token uses histogram 0.
    for (size_t i = tokens.size(); i-- > 0;) {
      uint32_t tok, nbits, bits;
      codes.uint_config[0].Encode(tokens[i].value, &tok, &nbits, &bits);
      const ANSEncSymbolInfo& info = codes.encoding_info[0][tok];
      // Extra bits first as this is reversed.
      addbits(bits, nbits);
      num_extra_bits += nbits;
      uint8_t ans_nbits = 0;
      uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits);
      addbits(ans_bits, ans_nbits);
    }
  }
  const uint32_t state = ans.GetState();
  writer->Write(32, state);
  writer->Write(numallbits, allbits);
  // Flushed chunks are emitted most recent first.
  for (size_t i = out.size(); i > 0; --i) {
    writer->Write(out_nbits[i - 1], out[i - 1]);
  }
  return num_extra_bits;
}
1673
1674
void WriteTokens(const std::vector<Token>& tokens,
1675
                 const EntropyEncodingData& codes,
1676
                 const std::vector<uint8_t>& context_map, BitWriter* writer,
1677
0
                 size_t layer, AuxOut* aux_out) {
1678
0
  BitWriter::Allotment allotment(writer, 32 * tokens.size() + 32 * 1024 * 4);
1679
0
  size_t num_extra_bits = WriteTokens(tokens, codes, context_map, writer);
1680
0
  allotment.ReclaimAndCharge(writer, layer, aux_out);
1681
0
  if (aux_out != nullptr) {
1682
0
    aux_out->layers[layer].extra_bits += num_extra_bits;
1683
0
  }
1684
0
}
1685
1686
0
// Toggles the fuzzer-friendly encoding mode. Only effective in debug builds:
// elsewhere the flag is a constexpr false and the request is ignored.
void SetANSFuzzerFriendly(bool ans_fuzzer_friendly) {
#if JXL_IS_DEBUG_BUILD  // Guard against accidental / malicious changes.
  ans_fuzzer_friendly_ = ans_fuzzer_friendly;
#else
  // The parameter is intentionally unused here; silence -Wunused-parameter.
  (void)ans_fuzzer_friendly;
#endif
}
1691
}  // namespace jxl