Coverage Report

Created: 2025-11-14 07:32

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/libjxl/lib/jxl/enc_lz77.cc
Line
Count
Source
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#include "lib/jxl/enc_lz77.h"
7
8
#include <jxl/memory_manager.h>
9
#include <jxl/types.h>
10
11
#include <algorithm>
12
#include <cmath>
13
#include <cstddef>
14
#include <cstdint>
15
#include <limits>
16
#include <unordered_map>
17
#include <utility>
18
#include <vector>
19
20
#include "lib/jxl/ans_common.h"
21
#include "lib/jxl/ans_params.h"
22
#include "lib/jxl/base/bits.h"
23
#include "lib/jxl/base/common.h"
24
#include "lib/jxl/base/compiler_specific.h"
25
#include "lib/jxl/base/fast_math-inl.h"
26
#include "lib/jxl/base/status.h"
27
#include "lib/jxl/common.h"
28
#include "lib/jxl/dec_ans.h"
29
#include "lib/jxl/enc_ans_params.h"
30
#include "lib/jxl/enc_aux_out.h"
31
#include "lib/jxl/enc_cluster.h"
32
#include "lib/jxl/enc_context_map.h"
33
#include "lib/jxl/enc_fields.h"
34
#include "lib/jxl/enc_huffman.h"
35
#include "lib/jxl/enc_params.h"
36
#include "lib/jxl/fields.h"
37
38
namespace jxl {
39
40
namespace {
41
42
class SymbolCostEstimator {
43
 public:
44
  SymbolCostEstimator(size_t num_contexts, bool force_huffman,
45
                      const std::vector<std::vector<Token>>& tokens,
46
32.7k
                      const LZ77Params& lz77) {
47
32.7k
    std::vector<Histogram> builder(num_contexts);
48
    // Build histograms for estimating lz77 savings.
49
32.7k
    HybridUintConfig uint_config;
50
51.9k
    for (const auto& stream : tokens) {
51
36.4M
      for (const auto& token : stream) {
52
36.4M
        uint32_t tok, nbits, bits;
53
36.4M
        (token.is_lz77_length ? lz77.length_uint_config : uint_config)
54
36.4M
            .Encode(token.value, &tok, &nbits, &bits);
55
36.4M
        tok += token.is_lz77_length ? lz77.min_symbol : 0;
56
36.4M
        JXL_DASSERT(token.context < num_contexts);
57
36.4M
        builder[token.context].Add(tok);
58
36.4M
      }
59
51.9k
    }
60
32.7k
    max_alphabet_size_ = 0;
61
141k
    for (size_t i = 0; i < num_contexts; i++) {
62
108k
      max_alphabet_size_ =
63
108k
          std::max(max_alphabet_size_, builder[i].counts.size());
64
108k
    }
65
32.7k
    bits_.resize(num_contexts * max_alphabet_size_);
66
    // TODO(veluca): SIMD?
67
32.7k
    add_symbol_cost_.resize(num_contexts);
68
141k
    for (size_t i = 0; i < num_contexts; i++) {
69
108k
      float inv_total = 1.0f / (builder[i].total_count + 1e-8f);
70
108k
      float total_cost = 0;
71
2.72M
      for (size_t j = 0; j < builder[i].counts.size(); j++) {
72
2.61M
        size_t cnt = builder[i].counts[j];
73
2.61M
        float cost = 0;
74
2.61M
        if (cnt != 0 && cnt != builder[i].total_count) {
75
1.22M
          cost = -FastLog2f(cnt * inv_total);
76
1.22M
          if (force_huffman) cost = std::ceil(cost);
77
1.39M
        } else if (cnt == 0) {
78
1.38M
          cost = ANS_LOG_TAB_SIZE;  // Highest possible cost.
79
1.38M
        }
80
2.61M
        bits_[i * max_alphabet_size_ + j] = cost;
81
2.61M
        total_cost += cost * builder[i].counts[j];
82
2.61M
      }
83
      // Penalty for adding a lz77 symbol to this contest (only used for static
84
      // cost model). Higher penalty for contexts that have a very low
85
      // per-symbol entropy.
86
108k
      add_symbol_cost_[i] = std::max(0.0f, 6.0f - total_cost * inv_total);
87
108k
    }
88
32.7k
  }
89
36.4M
  float Bits(size_t ctx, size_t sym) const {
90
36.4M
    return bits_[ctx * max_alphabet_size_ + sym];
91
36.4M
  }
92
0
  float LenCost(size_t ctx, size_t len, const LZ77Params& lz77) const {
93
0
    uint32_t nbits, bits, tok;
94
0
    lz77.length_uint_config.Encode(len, &tok, &nbits, &bits);
95
0
    tok += lz77.min_symbol;
96
0
    return nbits + Bits(ctx, tok);
97
0
  }
98
0
  float DistCost(size_t len, const LZ77Params& lz77) const {
99
0
    uint32_t nbits, bits, tok;
100
0
    HybridUintConfig().Encode(len, &tok, &nbits, &bits);
101
0
    return nbits + Bits(lz77.nonserialized_distance_context, tok);
102
0
  }
103
0
  float AddSymbolCost(size_t idx) const { return add_symbol_cost_[idx]; }
104
105
 private:
106
  size_t max_alphabet_size_;
107
  std::vector<float> bits_;
108
  std::vector<float> add_symbol_cost_;
109
};
110
111
std::vector<std::vector<Token>> ApplyLZ77_RLE(
112
    const HistogramParams& params, size_t num_contexts,
113
32.7k
    const std::vector<std::vector<Token>>& tokens, const LZ77Params& lz77) {
114
32.7k
  std::vector<std::vector<Token>> tokens_lz77(tokens.size());
115
  // TODO(veluca): tune heuristics here.
116
32.7k
  SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
117
32.7k
  float bit_decrease = 0;
118
32.7k
  size_t total_symbols = 0;
119
32.7k
  std::vector<float> sym_cost;
120
32.7k
  HybridUintConfig uint_config;
121
84.6k
  for (size_t stream = 0; stream < tokens.size(); stream++) {
122
51.9k
    size_t distance_multiplier =
123
51.9k
        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
124
51.9k
    const auto& in = tokens[stream];
125
51.9k
    auto& out = tokens_lz77[stream];
126
51.9k
    total_symbols += in.size();
127
    // Cumulative sum of bit costs.
128
51.9k
    sym_cost.resize(in.size() + 1);
129
36.5M
    for (size_t i = 0; i < in.size(); i++) {
130
36.4M
      uint32_t tok, nbits, unused_bits;
131
36.4M
      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
132
36.4M
      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
133
36.4M
    }
134
51.9k
    out.reserve(in.size());
135
20.0M
    for (size_t i = 0; i < in.size(); i++) {
136
20.0M
      size_t num_to_copy = 0;
137
20.0M
      size_t distance_symbol = 0;  // 1 for RLE.
138
20.0M
      if (distance_multiplier != 0) {
139
8.46M
        distance_symbol = 1;  // Special distance 1 if enabled.
140
8.46M
        JXL_DASSERT(kSpecialDistances[1][0] == 1);
141
8.46M
        JXL_DASSERT(kSpecialDistances[1][1] == 0);
142
8.46M
      }
143
20.0M
      if (i > 0) {
144
38.8M
        for (; i + num_to_copy < in.size(); num_to_copy++) {
145
38.8M
          if (in[i + num_to_copy].value != in[i - 1].value) {
146
19.9M
            break;
147
19.9M
          }
148
38.8M
        }
149
19.9M
      }
150
20.0M
      if (num_to_copy == 0) {
151
17.5M
        out.push_back(in[i]);
152
17.5M
        continue;
153
17.5M
      }
154
2.41M
      float cost = sym_cost[i + num_to_copy] - sym_cost[i];
155
      // This subtraction might overflow, but that's OK.
156
2.41M
      size_t lz77_len = num_to_copy - lz77.min_length;
157
2.41M
      float lz77_cost = num_to_copy >= lz77.min_length
158
2.41M
                            ? CeilLog2Nonzero(lz77_len + 1) + 1
159
2.41M
                            : 0;
160
2.41M
      if (num_to_copy < lz77.min_length || cost <= lz77_cost) {
161
11.6M
        for (size_t j = 0; j < num_to_copy; j++) {
162
9.49M
          out.push_back(in[i + j]);
163
9.49M
        }
164
2.13M
        i += num_to_copy - 1;
165
2.13M
        continue;
166
2.13M
      }
167
      // Output the LZ77 length
168
277k
      out.emplace_back(in[i].context, static_cast<uint32_t>(lz77_len));
169
277k
      out.back().is_lz77_length = true;
170
277k
      i += num_to_copy - 1;
171
277k
      bit_decrease += cost - lz77_cost;
172
      // Output the LZ77 copy distance.
173
277k
      out.emplace_back(
174
277k
          static_cast<uint32_t>(lz77.nonserialized_distance_context),
175
277k
          static_cast<uint32_t>(distance_symbol));
176
277k
    }
177
51.9k
  }
178
179
32.7k
  if (bit_decrease > total_symbols * 0.2 + 16) {
180
2.83k
    return tokens_lz77;
181
2.83k
  }
182
29.9k
  return {};
183
32.7k
}
184
185
// Hash chain for LZ77 matching
186
struct HashChain {
187
  size_t size_;
188
  std::vector<uint32_t> data_;
189
190
  unsigned hash_num_values_ = 32768;
191
  unsigned hash_mask_ = hash_num_values_ - 1;
192
  unsigned hash_shift_ = 5;
193
194
  std::vector<int> head;
195
  std::vector<uint32_t> chain;
196
  std::vector<int> val;
197
198
  // Speed up repetitions of zero
199
  std::vector<int> headz;
200
  std::vector<uint32_t> chainz;
201
  std::vector<uint32_t> zeros;
202
  uint32_t numzeros = 0;
203
204
  size_t window_size_;
205
  size_t window_mask_;
206
  size_t min_length_;
207
  size_t max_length_;
208
209
  // Map of special distance codes.
210
  std::unordered_map<int, int> special_dist_table_;
211
  size_t num_special_distances_ = 0;
212
213
  uint32_t maxchainlength = 256;  // window_size_ to allow all
214
215
  HashChain(const Token* data, size_t size, size_t window_size,
216
            size_t min_length, size_t max_length, size_t distance_multiplier)
217
0
      : size_(size),
218
0
        window_size_(window_size),
219
0
        window_mask_(window_size - 1),
220
0
        min_length_(min_length),
221
0
        max_length_(max_length) {
222
0
    data_.resize(size);
223
0
    for (size_t i = 0; i < size; i++) {
224
0
      data_[i] = data[i].value;
225
0
    }
226
227
0
    head.resize(hash_num_values_, -1);
228
0
    val.resize(window_size_, -1);
229
0
    chain.resize(window_size_);
230
0
    for (uint32_t i = 0; i < window_size_; ++i) {
231
0
      chain[i] = i;  // same value as index indicates uninitialized
232
0
    }
233
234
0
    zeros.resize(window_size_);
235
0
    headz.resize(window_size_ + 1, -1);
236
0
    chainz.resize(window_size_);
237
0
    for (uint32_t i = 0; i < window_size_; ++i) {
238
0
      chainz[i] = i;
239
0
    }
240
    // Translate distance to special distance code.
241
0
    if (distance_multiplier) {
242
      // Count down, so if due to small distance multiplier multiple distances
243
      // map to the same code, the smallest code will be used in the end.
244
0
      for (int i = kNumSpecialDistances - 1; i >= 0; --i) {
245
0
        special_dist_table_[SpecialDistance(i, distance_multiplier)] = i;
246
0
      }
247
0
      num_special_distances_ = kNumSpecialDistances;
248
0
    }
249
0
  }
250
251
0
  uint32_t GetHash(size_t pos) const {
252
0
    uint32_t result = 0;
253
0
    if (pos + 2 < size_) {
254
      // TODO(lode): take the MSB's of the uint32_t values into account as well,
255
      // given that the hash code itself is less than 32 bits.
256
0
      result ^= static_cast<uint32_t>(data_[pos + 0] << 0u);
257
0
      result ^= static_cast<uint32_t>(data_[pos + 1] << hash_shift_);
258
0
      result ^= static_cast<uint32_t>(data_[pos + 2] << (hash_shift_ * 2));
259
0
    } else {
260
      // No need to compute hash of last 2 bytes, the length 2 is too short.
261
0
      return 0;
262
0
    }
263
0
    return result & hash_mask_;
264
0
  }
265
266
0
  uint32_t CountZeros(size_t pos, uint32_t prevzeros) const {
267
0
    size_t end = pos + window_size_;
268
0
    if (end > size_) end = size_;
269
0
    if (prevzeros > 0) {
270
0
      if (prevzeros >= window_mask_ && data_[end - 1] == 0 &&
271
0
          end == pos + window_size_) {
272
0
        return prevzeros;
273
0
      } else {
274
0
        return prevzeros - 1;
275
0
      }
276
0
    }
277
0
    uint32_t num = 0;
278
0
    while (pos + num < end && data_[pos + num] == 0) num++;
279
0
    return num;
280
0
  }
281
282
0
  void Update(size_t pos) {
283
0
    uint32_t hashval = GetHash(pos);
284
0
    uint32_t wpos = pos & window_mask_;
285
286
0
    val[wpos] = static_cast<int>(hashval);
287
0
    if (head[hashval] != -1) chain[wpos] = head[hashval];
288
0
    head[hashval] = wpos;
289
290
0
    if (pos > 0 && data_[pos] != data_[pos - 1]) numzeros = 0;
291
0
    numzeros = CountZeros(pos, numzeros);
292
293
0
    zeros[wpos] = numzeros;
294
0
    if (headz[numzeros] != -1) chainz[wpos] = headz[numzeros];
295
0
    headz[numzeros] = wpos;
296
0
  }
297
298
0
  void Update(size_t pos, size_t len) {
299
0
    for (size_t i = 0; i < len; i++) {
300
0
      Update(pos + i);
301
0
    }
302
0
  }
303
304
  template <typename CB>
305
0
  void FindMatches(size_t pos, int max_dist, const CB& found_match) const {
306
0
    uint32_t wpos = pos & window_mask_;
307
0
    uint32_t hashval = GetHash(pos);
308
0
    uint32_t hashpos = chain[wpos];
309
310
0
    int prev_dist = 0;
311
0
    int end = std::min<int>(pos + max_length_, size_);
312
0
    uint32_t chainlength = 0;
313
0
    uint32_t best_len = 0;
314
0
    for (;;) {
315
0
      int dist = (hashpos <= wpos) ? (wpos - hashpos)
316
0
                                   : (wpos - hashpos + window_mask_ + 1);
317
0
      if (dist < prev_dist) break;
318
0
      prev_dist = dist;
319
0
      uint32_t len = 0;
320
0
      if (dist > 0) {
321
0
        int i = pos;
322
0
        int j = pos - dist;
323
0
        if (numzeros > 3) {
324
0
          int r = std::min<int>(numzeros - 1, zeros[hashpos]);
325
0
          if (i + r >= end) r = end - i - 1;
326
0
          i += r;
327
0
          j += r;
328
0
        }
329
0
        while (i < end && data_[i] == data_[j]) {
330
0
          i++;
331
0
          j++;
332
0
        }
333
0
        len = i - pos;
334
        // This can trigger even if the new length is slightly smaller than the
335
        // best length, because it is possible for a slightly cheaper distance
336
        // symbol to occur.
337
0
        if (len >= min_length_ && len + 2 >= best_len) {
338
0
          auto it = special_dist_table_.find(dist);
339
0
          int dist_symbol = (it == special_dist_table_.end())
340
0
                                ? (num_special_distances_ + dist - 1)
341
0
                                : it->second;
342
0
          found_match(len, dist_symbol);
343
0
          if (len > best_len) best_len = len;
344
0
        }
345
0
      }
346
347
0
      chainlength++;
348
0
      if (chainlength >= maxchainlength) break;
349
350
0
      if (numzeros >= 3 && len > numzeros) {
351
0
        if (hashpos == chainz[hashpos]) break;
352
0
        hashpos = chainz[hashpos];
353
0
        if (zeros[hashpos] != numzeros) break;
354
0
      } else {
355
0
        if (hashpos == chain[hashpos]) break;
356
0
        hashpos = chain[hashpos];
357
0
        if (val[hashpos] != static_cast<int>(hashval)) {
358
          // outdated hash value
359
0
          break;
360
0
        }
361
0
      }
362
0
    }
363
0
  }
Unexecuted instantiation: enc_lz77.cc:void jxl::(anonymous namespace)::HashChain::FindMatches<jxl::(anonymous namespace)::HashChain::FindMatch(unsigned long, int, unsigned long*, unsigned long*) const::{lambda(unsigned long, unsigned long)#1}>(unsigned long, int, jxl::(anonymous namespace)::HashChain::FindMatch(unsigned long, int, unsigned long*, unsigned long*) const::{lambda(unsigned long, unsigned long)#1} const&) const
Unexecuted instantiation: enc_lz77.cc:void jxl::(anonymous namespace)::HashChain::FindMatches<jxl::(anonymous namespace)::ApplyLZ77_Optimal(jxl::HistogramParams const&, unsigned long, std::__1::vector<std::__1::vector<jxl::Token, std::__1::allocator<jxl::Token> >, std::__1::allocator<std::__1::vector<jxl::Token, std::__1::allocator<jxl::Token> > > > const&, jxl::LZ77Params const&)::$_0>(unsigned long, int, jxl::(anonymous namespace)::ApplyLZ77_Optimal(jxl::HistogramParams const&, unsigned long, std::__1::vector<std::__1::vector<jxl::Token, std::__1::allocator<jxl::Token> >, std::__1::allocator<std::__1::vector<jxl::Token, std::__1::allocator<jxl::Token> > > > const&, jxl::LZ77Params const&)::$_0 const&) const
364
  void FindMatch(size_t pos, int max_dist, size_t* result_dist_symbol,
365
0
                 size_t* result_len) const {
366
0
    *result_dist_symbol = 0;
367
0
    *result_len = 1;
368
0
    FindMatches(pos, max_dist, [&](size_t len, size_t dist_symbol) {
369
0
      if (len > *result_len ||
370
0
          (len == *result_len && *result_dist_symbol > dist_symbol)) {
371
0
        *result_len = len;
372
0
        *result_dist_symbol = dist_symbol;
373
0
      }
374
0
    });
375
0
  }
376
};
377
378
0
float LenCost(size_t len) {
379
0
  uint32_t nbits, bits, tok;
380
0
  HybridUintConfig(1, 0, 0).Encode(len, &tok, &nbits, &bits);
381
0
  constexpr float kCostTable[] = {
382
0
      2.797667318563126,  3.213177690381199,  2.5706009246743737,
383
0
      2.408392498667534,  2.829649191872326,  3.3923087753324577,
384
0
      4.029267451554331,  4.415576699706408,  4.509357574741465,
385
0
      9.21481543803004,   10.020590190114898, 11.858671627804766,
386
0
      12.45853300490526,  11.713105831990857, 12.561996324849314,
387
0
      13.775477692278367, 13.174027068768641,
388
0
  };
389
0
  size_t table_size = sizeof kCostTable / sizeof *kCostTable;
390
0
  if (tok >= table_size) tok = table_size - 1;
391
0
  return kCostTable[tok] + nbits;
392
0
}
393
394
// TODO(veluca): this does not take into account usage or non-usage of distance
395
// multipliers.
396
0
float DistCost(size_t dist) {
397
0
  uint32_t nbits, bits, tok;
398
0
  HybridUintConfig(7, 0, 0).Encode(dist, &tok, &nbits, &bits);
399
0
  constexpr float kCostTable[] = {
400
0
      6.368282626312716,  5.680793277090298,  8.347404197105247,
401
0
      7.641619201599141,  6.914328374119438,  7.959808291537444,
402
0
      8.70023120759855,   8.71378518934703,   9.379132523982769,
403
0
      9.110472749092708,  9.159029569270908,  9.430936766731973,
404
0
      7.278284055315169,  7.8278514904267755, 10.026641158289236,
405
0
      9.976049229827066,  9.64351607048908,   9.563403863480442,
406
0
      10.171474111762747, 10.45950155077234,  9.994813912104219,
407
0
      10.322524683741156, 8.465808729388186,  8.756254166066853,
408
0
      10.160930174662234, 10.247329273413435, 10.04090403724809,
409
0
      10.129398517544082, 9.342311691539546,  9.07608009102374,
410
0
      10.104799540677513, 10.378079384990906, 10.165828974075072,
411
0
      10.337595322341553, 7.940557464567944,  10.575665823319431,
412
0
      11.023344321751955, 10.736144698831827, 11.118277044595054,
413
0
      7.468468230648442,  10.738305230932939, 10.906980780216568,
414
0
      10.163468216353817, 10.17805759656433,  11.167283670483565,
415
0
      11.147050200274544, 10.517921919244333, 10.651764778156886,
416
0
      10.17074446448919,  11.217636876224745, 11.261630721139484,
417
0
      11.403140815247259, 10.892472096873417, 11.1859607804481,
418
0
      8.017346947551262,  7.895143720278828,  11.036577113822025,
419
0
      11.170562110315794, 10.326988722591086, 10.40872184751056,
420
0
      11.213498225466386, 11.30580635516863,  10.672272515665442,
421
0
      10.768069466228063, 11.145257364153565, 11.64668307145549,
422
0
      10.593156194627339, 11.207499484844943, 10.767517766396908,
423
0
      10.826629811407042, 10.737764794499988, 10.6200448518045,
424
0
      10.191315385198092, 8.468384171390085,  11.731295299170432,
425
0
      11.824619886654398, 10.41518844301179,  10.16310536548649,
426
0
      10.539423685097576, 10.495136599328031, 10.469112847728267,
427
0
      11.72057686174922,  10.910326337834674, 11.378921834673758,
428
0
      11.847759036098536, 11.92071647623854,  10.810628276345282,
429
0
      11.008601085273893, 11.910326337834674, 11.949212023423133,
430
0
      11.298614839104337, 11.611603659010392, 10.472930394619985,
431
0
      11.835564720850282, 11.523267392285337, 12.01055816679611,
432
0
      8.413029688994023,  11.895784139536406, 11.984679534970505,
433
0
      11.220654278717394, 11.716311684833672, 10.61036646226114,
434
0
      10.89849965960364,  10.203762898863669, 10.997560826267238,
435
0
      11.484217379438984, 11.792836176993665, 12.24310468755171,
436
0
      11.464858097919262, 12.212747017409377, 11.425595666074955,
437
0
      11.572048533398757, 12.742093965163013, 11.381874288645637,
438
0
      12.191870445817015, 11.683156920035426, 11.152442115262197,
439
0
      11.90303691580457,  11.653292787169159, 11.938615382266098,
440
0
      16.970641701570223, 16.853602280380002, 17.26240782594733,
441
0
      16.644655390108507, 17.14310889757499,  16.910935455445955,
442
0
      17.505678976959697, 17.213498225466388, 2.4162310293553024,
443
0
      3.494587244462329,  3.5258600986408344, 3.4959806589517095,
444
0
      3.098390886949687,  3.343454654302911,  3.588847442290287,
445
0
      4.14614790111827,   5.152948641990529,  7.433696808092598,
446
0
      9.716311684833672,
447
0
  };
448
0
  size_t table_size = sizeof kCostTable / sizeof *kCostTable;
449
0
  if (tok >= table_size) tok = table_size - 1;
450
0
  return kCostTable[tok] + nbits;
451
0
}
452
453
std::vector<std::vector<Token>> ApplyLZ77_LZ77(
454
    const HistogramParams& params, size_t num_contexts,
455
0
    const std::vector<std::vector<Token>>& tokens, const LZ77Params& lz77) {
456
0
  std::vector<std::vector<Token>> tokens_lz77(tokens.size());
457
  // TODO(veluca): tune heuristics here.
458
0
  SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
459
0
  float bit_decrease = 0;
460
0
  size_t total_symbols = 0;
461
0
  HybridUintConfig uint_config;
462
0
  std::vector<float> sym_cost;
463
0
  for (size_t stream = 0; stream < tokens.size(); stream++) {
464
0
    size_t distance_multiplier =
465
0
        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
466
0
    const auto& in = tokens[stream];
467
0
    auto& out = tokens_lz77[stream];
468
0
    total_symbols += in.size();
469
    // Cumulative sum of bit costs.
470
0
    sym_cost.resize(in.size() + 1);
471
0
    for (size_t i = 0; i < in.size(); i++) {
472
0
      uint32_t tok, nbits, unused_bits;
473
0
      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
474
0
      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
475
0
    }
476
477
0
    out.reserve(in.size());
478
0
    size_t max_distance = in.size();
479
0
    size_t min_length = lz77.min_length;
480
0
    JXL_DASSERT(min_length >= 3);
481
0
    size_t max_length = in.size();
482
483
    // Use next power of two as window size.
484
0
    size_t window_size = 1;
485
0
    while (window_size < max_distance && window_size < kWindowSize) {
486
0
      window_size <<= 1;
487
0
    }
488
489
0
    HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
490
0
                    distance_multiplier);
491
0
    size_t len;
492
0
    size_t dist_symbol;
493
494
0
    const size_t max_lazy_match_len = 256;  // 0 to disable lazy matching
495
496
    // Whether the next symbol was already updated (to test lazy matching)
497
0
    bool already_updated = false;
498
0
    for (size_t i = 0; i < in.size(); i++) {
499
0
      out.push_back(in[i]);
500
0
      if (!already_updated) chain.Update(i);
501
0
      already_updated = false;
502
0
      chain.FindMatch(i, max_distance, &dist_symbol, &len);
503
0
      if (len >= min_length) {
504
0
        if (len < max_lazy_match_len && i + 1 < in.size()) {
505
          // Try length at next symbol lazy matching
506
0
          chain.Update(i + 1);
507
0
          already_updated = true;
508
0
          size_t len2, dist_symbol2;
509
0
          chain.FindMatch(i + 1, max_distance, &dist_symbol2, &len2);
510
0
          if (len2 > len) {
511
            // Use the lazy match. Add literal, and use the next length starting
512
            // from the next byte.
513
0
            ++i;
514
0
            already_updated = false;
515
0
            len = len2;
516
0
            dist_symbol = dist_symbol2;
517
0
            out.push_back(in[i]);
518
0
          }
519
0
        }
520
521
0
        float cost = sym_cost[i + len] - sym_cost[i];
522
0
        size_t lz77_len = len - lz77.min_length;
523
0
        float lz77_cost = LenCost(lz77_len) + DistCost(dist_symbol) +
524
0
                          sce.AddSymbolCost(out.back().context);
525
526
0
        if (lz77_cost <= cost) {
527
0
          out.back().value = len - min_length;
528
0
          out.back().is_lz77_length = true;
529
0
          out.emplace_back(
530
0
              static_cast<uint32_t>(lz77.nonserialized_distance_context),
531
0
              static_cast<uint32_t>(dist_symbol));
532
0
          bit_decrease += cost - lz77_cost;
533
0
        } else {
534
          // LZ77 match ignored, and symbol already pushed. Push all other
535
          // symbols and skip.
536
0
          for (size_t j = 1; j < len; j++) {
537
0
            out.push_back(in[i + j]);
538
0
          }
539
0
        }
540
541
0
        if (already_updated) {
542
0
          chain.Update(i + 2, len - 2);
543
0
          already_updated = false;
544
0
        } else {
545
0
          chain.Update(i + 1, len - 1);
546
0
        }
547
0
        i += len - 1;
548
0
      } else {
549
        // Literal, already pushed
550
0
      }
551
0
    }
552
0
  }
553
554
0
  if (bit_decrease > total_symbols * 0.2 + 16) {
555
0
    return tokens_lz77;
556
0
  }
557
0
  return {};
558
0
}
559
560
std::vector<std::vector<Token>> ApplyLZ77_Optimal(
561
    const HistogramParams& params, size_t num_contexts,
562
0
    const std::vector<std::vector<Token>>& tokens, const LZ77Params& lz77) {
563
0
  std::vector<std::vector<Token>> tokens_for_cost_estimate =
564
0
      ApplyLZ77_LZ77(params, num_contexts, tokens, lz77);
565
  // If greedy-LZ77 does not give better compression than no-lz77, no reason to
566
  // run the optimal matching.
567
0
  if (tokens_for_cost_estimate.empty()) return {};
568
0
  SymbolCostEstimator sce(num_contexts + 1, params.force_huffman,
569
0
                          tokens_for_cost_estimate, lz77);
570
0
  std::vector<std::vector<Token>> tokens_lz77(tokens.size());
571
0
  HybridUintConfig uint_config;
572
0
  std::vector<float> sym_cost;
573
0
  std::vector<uint32_t> dist_symbols;
574
0
  for (size_t stream = 0; stream < tokens.size(); stream++) {
575
0
    size_t distance_multiplier =
576
0
        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
577
0
    const auto& in = tokens[stream];
578
0
    auto& out = tokens_lz77[stream];
579
    // Cumulative sum of bit costs.
580
0
    sym_cost.resize(in.size() + 1);
581
0
    for (size_t i = 0; i < in.size(); i++) {
582
0
      uint32_t tok, nbits, unused_bits;
583
0
      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
584
0
      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
585
0
    }
586
587
0
    out.reserve(in.size());
588
0
    size_t max_distance = in.size();
589
0
    size_t min_length = lz77.min_length;
590
0
    JXL_DASSERT(min_length >= 3);
591
0
    size_t max_length = in.size();
592
593
    // Use next power of two as window size.
594
0
    size_t window_size = 1;
595
0
    while (window_size < max_distance && window_size < kWindowSize) {
596
0
      window_size <<= 1;
597
0
    }
598
599
0
    HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
600
0
                    distance_multiplier);
601
602
0
    struct MatchInfo {
603
0
      uint32_t len;
604
0
      uint32_t dist_symbol;
605
0
      uint32_t ctx;
606
0
      float total_cost = std::numeric_limits<float>::max();
607
0
    };
608
    // Total cost to encode the first N symbols.
609
0
    std::vector<MatchInfo> prefix_costs(in.size() + 1);
610
0
    prefix_costs[0].total_cost = 0;
611
612
0
    size_t rle_length = 0;
613
0
    size_t skip_lz77 = 0;
614
0
    for (size_t i = 0; i < in.size(); i++) {
615
0
      chain.Update(i);
616
0
      float lit_cost =
617
0
          prefix_costs[i].total_cost + sym_cost[i + 1] - sym_cost[i];
618
0
      if (prefix_costs[i + 1].total_cost > lit_cost) {
619
0
        prefix_costs[i + 1].dist_symbol = 0;
620
0
        prefix_costs[i + 1].len = 1;
621
0
        prefix_costs[i + 1].ctx = in[i].context;
622
0
        prefix_costs[i + 1].total_cost = lit_cost;
623
0
      }
624
0
      if (skip_lz77 > 0) {
625
0
        skip_lz77--;
626
0
        continue;
627
0
      }
628
0
      dist_symbols.clear();
629
0
      chain.FindMatches(i, max_distance,
630
0
                        [&dist_symbols](size_t len, size_t dist_symbol) {
631
0
                          if (dist_symbols.size() <= len) {
632
0
                            dist_symbols.resize(len + 1, dist_symbol);
633
0
                          }
634
0
                          if (dist_symbol < dist_symbols[len]) {
635
0
                            dist_symbols[len] = dist_symbol;
636
0
                          }
637
0
                        });
638
0
      if (dist_symbols.size() <= min_length) continue;
639
0
      {
640
0
        size_t best_cost = dist_symbols.back();
641
0
        for (size_t j = dist_symbols.size() - 1; j >= min_length; j--) {
642
0
          if (dist_symbols[j] < best_cost) {
643
0
            best_cost = dist_symbols[j];
644
0
          }
645
0
          dist_symbols[j] = best_cost;
646
0
        }
647
0
      }
648
0
      for (size_t j = min_length; j < dist_symbols.size(); j++) {
649
        // Cost model that uses results from lazy LZ77.
650
0
        float lz77_cost = sce.LenCost(in[i].context, j - min_length, lz77) +
651
0
                          sce.DistCost(dist_symbols[j], lz77);
652
0
        float cost = prefix_costs[i].total_cost + lz77_cost;
653
0
        if (prefix_costs[i + j].total_cost > cost) {
654
0
          prefix_costs[i + j].len = j;
655
0
          prefix_costs[i + j].dist_symbol = dist_symbols[j] + 1;
656
0
          prefix_costs[i + j].ctx = in[i].context;
657
0
          prefix_costs[i + j].total_cost = cost;
658
0
        }
659
0
      }
660
      // We are in a RLE sequence: skip all the symbols except the first 8 and
661
      // the last 8. This avoid quadratic costs for sequences with long runs of
662
      // the same symbol.
663
0
      if ((dist_symbols.back() == 0 && distance_multiplier == 0) ||
664
0
          (dist_symbols.back() == 1 && distance_multiplier != 0)) {
665
0
        rle_length++;
666
0
      } else {
667
0
        rle_length = 0;
668
0
      }
669
0
      if (rle_length >= 8 && dist_symbols.size() > 9) {
670
0
        skip_lz77 = dist_symbols.size() - 10;
671
0
        rle_length = 0;
672
0
      }
673
0
    }
674
0
    size_t pos = in.size();
675
0
    while (pos > 0) {
676
0
      bool is_lz77_length = prefix_costs[pos].dist_symbol != 0;
677
0
      if (is_lz77_length) {
678
0
        size_t dist_symbol = prefix_costs[pos].dist_symbol - 1;
679
0
        out.emplace_back(
680
0
            static_cast<uint32_t>(lz77.nonserialized_distance_context),
681
0
            static_cast<uint32_t>(dist_symbol));
682
0
      }
683
0
      uint32_t val =
684
0
          is_lz77_length
685
0
              ? (prefix_costs[pos].len - static_cast<uint32_t>(min_length))
686
0
              : in[pos - 1].value;
687
0
      out.emplace_back(prefix_costs[pos].ctx, val);
688
0
      out.back().is_lz77_length = is_lz77_length;
689
0
      pos -= prefix_costs[pos].len;
690
0
    }
691
0
    if (!out.empty()) std::reverse(out.begin(), out.end());
692
0
  }
693
0
  return tokens_lz77;
694
0
}
695
696
}  // namespace
697
698
std::vector<std::vector<Token>> ApplyLZ77(
699
    const HistogramParams& params, size_t num_contexts,
700
41.8k
    const std::vector<std::vector<Token>>& tokens, const LZ77Params& lz77) {
701
41.8k
  switch (params.lz77_method) {
702
32.7k
    case HistogramParams::LZ77Method::kRLE:
703
32.7k
      return ApplyLZ77_RLE(params, num_contexts, tokens, lz77);
704
0
    case HistogramParams::LZ77Method::kLZ77:
705
0
      return ApplyLZ77_LZ77(params, num_contexts, tokens, lz77);
706
0
    case HistogramParams::LZ77Method::kOptimal:
707
0
      return ApplyLZ77_Optimal(params, num_contexts, tokens, lz77);
708
9.06k
    default:
709
9.06k
      return {};
710
41.8k
  }
711
41.8k
}
712
713
}  // namespace jxl