Coverage Report

Created: 2026-06-16 07:20

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/libjxl/lib/jxl/enc_lz77.cc
Line
Count
Source
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#include "lib/jxl/enc_lz77.h"
7
8
#include <algorithm>
9
#include <cmath>
10
#include <cstddef>
11
#include <cstdint>
12
#include <limits>
13
#include <unordered_map>
14
#include <vector>
15
16
#include "lib/jxl/ans_params.h"
17
#include "lib/jxl/base/bits.h"
18
#include "lib/jxl/base/fast_math-inl.h"
19
#include "lib/jxl/base/status.h"
20
#include "lib/jxl/dec_ans.h"
21
#include "lib/jxl/enc_ans.h"
22
#include "lib/jxl/enc_ans_params.h"
23
24
namespace jxl {
25
26
namespace {
27
28
class SymbolCostEstimator {
29
 public:
30
  SymbolCostEstimator(size_t num_contexts, bool force_huffman,
31
                      const std::vector<std::vector<Token>>& tokens,
32
63.8k
                      const LZ77Params& lz77) {
33
63.8k
    std::vector<Histogram> builder(num_contexts);
34
    // Build histograms for estimating lz77 savings.
35
63.8k
    HybridUintConfig uint_config;
36
109k
    for (const auto& stream : tokens) {
37
101M
      for (const auto& token : stream) {
38
101M
        uint32_t tok, nbits, bits;
39
101M
        (token.is_lz77_length ? lz77.length_uint_config : uint_config)
40
101M
            .Encode(token.value, &tok, &nbits, &bits);
41
101M
        tok += token.is_lz77_length ? lz77.min_symbol : 0;
42
101M
        JXL_DASSERT(token.context < num_contexts);
43
101M
        builder[token.context].Add(tok);
44
101M
      }
45
109k
    }
46
63.8k
    max_alphabet_size_ = 0;
47
249k
    for (size_t i = 0; i < num_contexts; i++) {
48
185k
      max_alphabet_size_ =
49
185k
          std::max(max_alphabet_size_, builder[i].counts.size());
50
185k
    }
51
63.8k
    bits_.resize(num_contexts * max_alphabet_size_);
52
    // TODO(veluca): SIMD?
53
63.8k
    add_symbol_cost_.resize(num_contexts);
54
249k
    for (size_t i = 0; i < num_contexts; i++) {
55
185k
      float inv_total = 1.0f / (builder[i].total_count + 1e-8f);
56
185k
      float total_cost = 0;
57
3.98M
      for (size_t j = 0; j < builder[i].counts.size(); j++) {
58
3.79M
        size_t cnt = builder[i].counts[j];
59
3.79M
        float cost = 0;
60
3.79M
        if (cnt != 0 && cnt != builder[i].total_count) {
61
1.80M
          cost = -FastLog2f(cnt * inv_total);
62
1.80M
          if (force_huffman) cost = std::ceil(cost);
63
1.99M
        } else if (cnt == 0) {
64
1.97M
          cost = ANS_LOG_TAB_SIZE;  // Highest possible cost.
65
1.97M
        }
66
3.79M
        bits_[i * max_alphabet_size_ + j] = cost;
67
3.79M
        total_cost += cost * builder[i].counts[j];
68
3.79M
      }
69
      // Penalty for adding a lz77 symbol to this contest (only used for static
70
      // cost model). Higher penalty for contexts that have a very low
71
      // per-symbol entropy.
72
185k
      add_symbol_cost_[i] = std::max(0.0f, 6.0f - total_cost * inv_total);
73
185k
    }
74
63.8k
  }
75
101M
  float Bits(size_t ctx, size_t sym) const {
76
101M
    return bits_[ctx * max_alphabet_size_ + sym];
77
101M
  }
78
0
  float LenCost(size_t ctx, size_t len, const LZ77Params& lz77) const {
79
0
    uint32_t nbits, bits, tok;
80
0
    lz77.length_uint_config.Encode(len, &tok, &nbits, &bits);
81
0
    tok += lz77.min_symbol;
82
0
    return nbits + Bits(ctx, tok);
83
0
  }
84
0
  float DistCost(size_t len, const LZ77Params& lz77) const {
85
0
    uint32_t nbits, bits, tok;
86
0
    HybridUintConfig().Encode(len, &tok, &nbits, &bits);
87
0
    return nbits + Bits(lz77.nonserialized_distance_context, tok);
88
0
  }
89
0
  float AddSymbolCost(size_t idx) const { return add_symbol_cost_[idx]; }
90
91
 private:
92
  size_t max_alphabet_size_;
93
  std::vector<float> bits_;
94
  std::vector<float> add_symbol_cost_;
95
};
96
97
std::vector<std::vector<Token>> ApplyLZ77_RLE(
98
    const HistogramParams& params, size_t num_contexts,
99
63.8k
    const std::vector<std::vector<Token>>& tokens, const LZ77Params& lz77) {
100
63.8k
  std::vector<std::vector<Token>> tokens_lz77(tokens.size());
101
  // TODO(veluca): tune heuristics here.
102
63.8k
  SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
103
63.8k
  float bit_decrease = 0;
104
63.8k
  size_t total_symbols = 0;
105
63.8k
  std::vector<float> sym_cost;
106
63.8k
  HybridUintConfig uint_config;
107
173k
  for (size_t stream = 0; stream < tokens.size(); stream++) {
108
109k
    size_t distance_multiplier =
109
109k
        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
110
109k
    const auto& in = tokens[stream];
111
109k
    auto& out = tokens_lz77[stream];
112
109k
    total_symbols += in.size();
113
    // Cumulative sum of bit costs.
114
109k
    sym_cost.resize(in.size() + 1);
115
101M
    for (size_t i = 0; i < in.size(); i++) {
116
101M
      uint32_t tok, nbits, unused_bits;
117
101M
      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
118
101M
      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
119
101M
    }
120
109k
    out.reserve(in.size());
121
38.3M
    for (size_t i = 0; i < in.size(); i++) {
122
38.2M
      size_t num_to_copy = 0;
123
38.2M
      size_t distance_symbol = 0;  // 1 for RLE.
124
38.2M
      if (distance_multiplier != 0) {
125
15.3M
        distance_symbol = 1;  // Special distance 1 if enabled.
126
15.3M
        JXL_DASSERT(kSpecialDistances[1][0] == 1);
127
15.3M
        JXL_DASSERT(kSpecialDistances[1][1] == 0);
128
15.3M
      }
129
38.2M
      if (i > 0) {
130
107M
        for (; i + num_to_copy < in.size(); num_to_copy++) {
131
107M
          if (in[i + num_to_copy].value != in[i - 1].value) {
132
38.1M
            break;
133
38.1M
          }
134
107M
        }
135
38.1M
      }
136
38.2M
      if (num_to_copy == 0) {
137
32.4M
        out.push_back(in[i]);
138
32.4M
        continue;
139
32.4M
      }
140
5.79M
      float cost = sym_cost[i + num_to_copy] - sym_cost[i];
141
      // This subtraction might overflow, but that's OK.
142
5.79M
      size_t lz77_len = num_to_copy - lz77.min_length;
143
5.79M
      float lz77_cost = num_to_copy >= lz77.min_length
144
5.79M
                            ? CeilLog2Nonzero(lz77_len + 1) + 1
145
5.79M
                            : 0;
146
5.79M
      if (num_to_copy < lz77.min_length || cost <= lz77_cost) {
147
25.6M
        for (size_t j = 0; j < num_to_copy; j++) {
148
20.5M
          out.push_back(in[i + j]);
149
20.5M
        }
150
5.04M
        i += num_to_copy - 1;
151
5.04M
        continue;
152
5.04M
      }
153
      // Output the LZ77 length
154
751k
      out.emplace_back(in[i].context, static_cast<uint32_t>(lz77_len));
155
751k
      out.back().is_lz77_length = true;
156
751k
      i += num_to_copy - 1;
157
751k
      bit_decrease += cost - lz77_cost;
158
      // Output the LZ77 copy distance.
159
751k
      out.emplace_back(
160
751k
          static_cast<uint32_t>(lz77.nonserialized_distance_context),
161
751k
          static_cast<uint32_t>(distance_symbol));
162
751k
    }
163
109k
  }
164
165
63.8k
  if (bit_decrease > total_symbols * 0.2 + 16) {
166
6.84k
    return tokens_lz77;
167
6.84k
  }
168
56.9k
  return {};
169
63.8k
}
170
171
// Hash chain for LZ77 matching
172
struct HashChain {
173
  size_t size_;
174
  std::vector<uint32_t> data_;
175
176
  unsigned hash_num_values_ = 32768;
177
  unsigned hash_mask_ = hash_num_values_ - 1;
178
  unsigned hash_shift_ = 5;
179
180
  std::vector<int> head;
181
  std::vector<uint32_t> chain;
182
  std::vector<int> val;
183
184
  // Speed up repetitions of zero
185
  std::vector<int> headz;
186
  std::vector<uint32_t> chainz;
187
  std::vector<uint32_t> zeros;
188
  uint32_t numzeros = 0;
189
190
  size_t window_size_;
191
  size_t window_mask_;
192
  size_t min_length_;
193
  size_t max_length_;
194
195
  // Map of special distance codes.
196
  std::unordered_map<int, int> special_dist_table_;
197
  size_t num_special_distances_ = 0;
198
199
  uint32_t maxchainlength = 256;  // window_size_ to allow all
200
201
  HashChain(const Token* data, size_t size, size_t window_size,
202
            size_t min_length, size_t max_length, size_t distance_multiplier)
203
0
      : size_(size),
204
0
        window_size_(window_size),
205
0
        window_mask_(window_size - 1),
206
0
        min_length_(min_length),
207
0
        max_length_(max_length) {
208
0
    data_.resize(size);
209
0
    for (size_t i = 0; i < size; i++) {
210
0
      data_[i] = data[i].value;
211
0
    }
212
213
0
    head.resize(hash_num_values_, -1);
214
0
    val.resize(window_size_, -1);
215
0
    chain.resize(window_size_);
216
0
    for (uint32_t i = 0; i < window_size_; ++i) {
217
0
      chain[i] = i;  // same value as index indicates uninitialized
218
0
    }
219
220
0
    zeros.resize(window_size_);
221
0
    headz.resize(window_size_ + 1, -1);
222
0
    chainz.resize(window_size_);
223
0
    for (uint32_t i = 0; i < window_size_; ++i) {
224
0
      chainz[i] = i;
225
0
    }
226
    // Translate distance to special distance code.
227
0
    if (distance_multiplier) {
228
      // Count down, so if due to small distance multiplier multiple distances
229
      // map to the same code, the smallest code will be used in the end.
230
0
      for (int i = kNumSpecialDistances - 1; i >= 0; --i) {
231
0
        special_dist_table_[SpecialDistance(i, distance_multiplier)] = i;
232
0
      }
233
0
      num_special_distances_ = kNumSpecialDistances;
234
0
    }
235
0
  }
236
237
0
  uint32_t GetHash(size_t pos) const {
238
0
    uint32_t result = 0;
239
0
    if (pos + 2 < size_) {
240
      // TODO(lode): take the MSB's of the uint32_t values into account as well,
241
      // given that the hash code itself is less than 32 bits.
242
0
      result ^= static_cast<uint32_t>(data_[pos + 0] << 0u);
243
0
      result ^= static_cast<uint32_t>(data_[pos + 1] << hash_shift_);
244
0
      result ^= static_cast<uint32_t>(data_[pos + 2] << (hash_shift_ * 2));
245
0
    } else {
246
      // No need to compute hash of last 2 bytes, the length 2 is too short.
247
0
      return 0;
248
0
    }
249
0
    return result & hash_mask_;
250
0
  }
251
252
0
  uint32_t CountZeros(size_t pos, uint32_t prevzeros) const {
253
0
    size_t end = pos + window_size_;
254
0
    if (end > size_) end = size_;
255
0
    if (prevzeros > 0) {
256
0
      if (prevzeros >= window_mask_ && data_[end - 1] == 0 &&
257
0
          end == pos + window_size_) {
258
0
        return prevzeros;
259
0
      } else {
260
0
        return prevzeros - 1;
261
0
      }
262
0
    }
263
0
    uint32_t num = 0;
264
0
    while (pos + num < end && data_[pos + num] == 0) num++;
265
0
    return num;
266
0
  }
267
268
0
  void Update(size_t pos) {
269
0
    uint32_t hashval = GetHash(pos);
270
0
    uint32_t wpos = pos & window_mask_;
271
272
0
    val[wpos] = static_cast<int>(hashval);
273
0
    if (head[hashval] != -1) chain[wpos] = head[hashval];
274
0
    head[hashval] = wpos;
275
276
0
    if (pos > 0 && data_[pos] != data_[pos - 1]) numzeros = 0;
277
0
    numzeros = CountZeros(pos, numzeros);
278
279
0
    zeros[wpos] = numzeros;
280
0
    if (headz[numzeros] != -1) chainz[wpos] = headz[numzeros];
281
0
    headz[numzeros] = wpos;
282
0
  }
283
284
0
  void Update(size_t pos, size_t len) {
285
0
    for (size_t i = 0; i < len; i++) {
286
0
      Update(pos + i);
287
0
    }
288
0
  }
289
290
  template <typename CB>
291
0
  void FindMatches(size_t pos, int max_dist, const CB& found_match) const {
292
0
    uint32_t wpos = pos & window_mask_;
293
0
    uint32_t hashval = GetHash(pos);
294
0
    uint32_t hashpos = chain[wpos];
295
296
0
    int prev_dist = 0;
297
0
    int end = std::min<int>(pos + max_length_, size_);
298
0
    uint32_t chainlength = 0;
299
0
    uint32_t best_len = 0;
300
0
    for (;;) {
301
0
      int dist = (hashpos <= wpos) ? (wpos - hashpos)
302
0
                                   : (wpos - hashpos + window_mask_ + 1);
303
0
      if (dist < prev_dist) break;
304
0
      prev_dist = dist;
305
0
      uint32_t len = 0;
306
0
      if (dist > 0) {
307
0
        int i = pos;
308
0
        int j = pos - dist;
309
0
        if (numzeros > 3) {
310
0
          int r = std::min<int>(numzeros - 1, zeros[hashpos]);
311
0
          if (i + r >= end) r = end - i - 1;
312
0
          i += r;
313
0
          j += r;
314
0
        }
315
0
        while (i < end && data_[i] == data_[j]) {
316
0
          i++;
317
0
          j++;
318
0
        }
319
0
        len = i - pos;
320
        // This can trigger even if the new length is slightly smaller than the
321
        // best length, because it is possible for a slightly cheaper distance
322
        // symbol to occur.
323
0
        if (len >= min_length_ && len + 2 >= best_len) {
324
0
          auto it = special_dist_table_.find(dist);
325
0
          int dist_symbol = (it == special_dist_table_.end())
326
0
                                ? (num_special_distances_ + dist - 1)
327
0
                                : it->second;
328
0
          found_match(len, dist_symbol);
329
0
          if (len > best_len) best_len = len;
330
0
        }
331
0
      }
332
333
0
      chainlength++;
334
0
      if (chainlength >= maxchainlength) break;
335
336
0
      if (numzeros >= 3 && len > numzeros) {
337
0
        if (hashpos == chainz[hashpos]) break;
338
0
        hashpos = chainz[hashpos];
339
0
        if (zeros[hashpos] != numzeros) break;
340
0
      } else {
341
0
        if (hashpos == chain[hashpos]) break;
342
0
        hashpos = chain[hashpos];
343
0
        if (val[hashpos] != static_cast<int>(hashval)) {
344
          // outdated hash value
345
0
          break;
346
0
        }
347
0
      }
348
0
    }
349
0
  }
Unexecuted instantiation: enc_lz77.cc:void jxl::(anonymous namespace)::HashChain::FindMatches<jxl::(anonymous namespace)::HashChain::FindMatch(unsigned long, int, unsigned long*, unsigned long*) const::{lambda(unsigned long, unsigned long)#1}>(unsigned long, int, jxl::(anonymous namespace)::HashChain::FindMatch(unsigned long, int, unsigned long*, unsigned long*) const::{lambda(unsigned long, unsigned long)#1} const&) const
Unexecuted instantiation: enc_lz77.cc:void jxl::(anonymous namespace)::HashChain::FindMatches<jxl::(anonymous namespace)::ApplyLZ77_Optimal(jxl::HistogramParams const&, unsigned long, std::__1::vector<std::__1::vector<jxl::Token, std::__1::allocator<jxl::Token> >, std::__1::allocator<std::__1::vector<jxl::Token, std::__1::allocator<jxl::Token> > > > const&, jxl::LZ77Params const&)::$_0>(unsigned long, int, jxl::(anonymous namespace)::ApplyLZ77_Optimal(jxl::HistogramParams const&, unsigned long, std::__1::vector<std::__1::vector<jxl::Token, std::__1::allocator<jxl::Token> >, std::__1::allocator<std::__1::vector<jxl::Token, std::__1::allocator<jxl::Token> > > > const&, jxl::LZ77Params const&)::$_0 const&) const
350
  void FindMatch(size_t pos, int max_dist, size_t* result_dist_symbol,
351
0
                 size_t* result_len) const {
352
0
    *result_dist_symbol = 0;
353
0
    *result_len = 1;
354
0
    FindMatches(pos, max_dist, [&](size_t len, size_t dist_symbol) {
355
0
      if (len > *result_len ||
356
0
          (len == *result_len && *result_dist_symbol > dist_symbol)) {
357
0
        *result_len = len;
358
0
        *result_dist_symbol = dist_symbol;
359
0
      }
360
0
    });
361
0
  }
362
};
363
364
0
float LenCost(size_t len) {
365
0
  uint32_t nbits, bits, tok;
366
0
  HybridUintConfig(1, 0, 0).Encode(len, &tok, &nbits, &bits);
367
0
  constexpr float kCostTable[] = {
368
0
      2.797667318563126,  3.213177690381199,  2.5706009246743737,
369
0
      2.408392498667534,  2.829649191872326,  3.3923087753324577,
370
0
      4.029267451554331,  4.415576699706408,  4.509357574741465,
371
0
      9.21481543803004,   10.020590190114898, 11.858671627804766,
372
0
      12.45853300490526,  11.713105831990857, 12.561996324849314,
373
0
      13.775477692278367, 13.174027068768641,
374
0
  };
375
0
  size_t table_size = sizeof kCostTable / sizeof *kCostTable;
376
0
  if (tok >= table_size) tok = table_size - 1;
377
0
  return kCostTable[tok] + nbits;
378
0
}
379
380
// TODO(veluca): this does not take into account usage or non-usage of distance
381
// multipliers.
382
0
float DistCost(size_t dist) {
383
0
  uint32_t nbits, bits, tok;
384
0
  HybridUintConfig(7, 0, 0).Encode(dist, &tok, &nbits, &bits);
385
0
  constexpr float kCostTable[] = {
386
0
      6.368282626312716,  5.680793277090298,  8.347404197105247,
387
0
      7.641619201599141,  6.914328374119438,  7.959808291537444,
388
0
      8.70023120759855,   8.71378518934703,   9.379132523982769,
389
0
      9.110472749092708,  9.159029569270908,  9.430936766731973,
390
0
      7.278284055315169,  7.8278514904267755, 10.026641158289236,
391
0
      9.976049229827066,  9.64351607048908,   9.563403863480442,
392
0
      10.171474111762747, 10.45950155077234,  9.994813912104219,
393
0
      10.322524683741156, 8.465808729388186,  8.756254166066853,
394
0
      10.160930174662234, 10.247329273413435, 10.04090403724809,
395
0
      10.129398517544082, 9.342311691539546,  9.07608009102374,
396
0
      10.104799540677513, 10.378079384990906, 10.165828974075072,
397
0
      10.337595322341553, 7.940557464567944,  10.575665823319431,
398
0
      11.023344321751955, 10.736144698831827, 11.118277044595054,
399
0
      7.468468230648442,  10.738305230932939, 10.906980780216568,
400
0
      10.163468216353817, 10.17805759656433,  11.167283670483565,
401
0
      11.147050200274544, 10.517921919244333, 10.651764778156886,
402
0
      10.17074446448919,  11.217636876224745, 11.261630721139484,
403
0
      11.403140815247259, 10.892472096873417, 11.1859607804481,
404
0
      8.017346947551262,  7.895143720278828,  11.036577113822025,
405
0
      11.170562110315794, 10.326988722591086, 10.40872184751056,
406
0
      11.213498225466386, 11.30580635516863,  10.672272515665442,
407
0
      10.768069466228063, 11.145257364153565, 11.64668307145549,
408
0
      10.593156194627339, 11.207499484844943, 10.767517766396908,
409
0
      10.826629811407042, 10.737764794499988, 10.6200448518045,
410
0
      10.191315385198092, 8.468384171390085,  11.731295299170432,
411
0
      11.824619886654398, 10.41518844301179,  10.16310536548649,
412
0
      10.539423685097576, 10.495136599328031, 10.469112847728267,
413
0
      11.72057686174922,  10.910326337834674, 11.378921834673758,
414
0
      11.847759036098536, 11.92071647623854,  10.810628276345282,
415
0
      11.008601085273893, 11.910326337834674, 11.949212023423133,
416
0
      11.298614839104337, 11.611603659010392, 10.472930394619985,
417
0
      11.835564720850282, 11.523267392285337, 12.01055816679611,
418
0
      8.413029688994023,  11.895784139536406, 11.984679534970505,
419
0
      11.220654278717394, 11.716311684833672, 10.61036646226114,
420
0
      10.89849965960364,  10.203762898863669, 10.997560826267238,
421
0
      11.484217379438984, 11.792836176993665, 12.24310468755171,
422
0
      11.464858097919262, 12.212747017409377, 11.425595666074955,
423
0
      11.572048533398757, 12.742093965163013, 11.381874288645637,
424
0
      12.191870445817015, 11.683156920035426, 11.152442115262197,
425
0
      11.90303691580457,  11.653292787169159, 11.938615382266098,
426
0
      16.970641701570223, 16.853602280380002, 17.26240782594733,
427
0
      16.644655390108507, 17.14310889757499,  16.910935455445955,
428
0
      17.505678976959697, 17.213498225466388, 2.4162310293553024,
429
0
      3.494587244462329,  3.5258600986408344, 3.4959806589517095,
430
0
      3.098390886949687,  3.343454654302911,  3.588847442290287,
431
0
      4.14614790111827,   5.152948641990529,  7.433696808092598,
432
0
      9.716311684833672,
433
0
  };
434
0
  size_t table_size = sizeof kCostTable / sizeof *kCostTable;
435
0
  if (tok >= table_size) tok = table_size - 1;
436
0
  return kCostTable[tok] + nbits;
437
0
}
438
439
std::vector<std::vector<Token>> ApplyLZ77_LZ77(
440
    const HistogramParams& params, size_t num_contexts,
441
0
    const std::vector<std::vector<Token>>& tokens, const LZ77Params& lz77) {
442
0
  std::vector<std::vector<Token>> tokens_lz77(tokens.size());
443
  // TODO(veluca): tune heuristics here.
444
0
  SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
445
0
  float bit_decrease = 0;
446
0
  size_t total_symbols = 0;
447
0
  HybridUintConfig uint_config;
448
0
  std::vector<float> sym_cost;
449
0
  for (size_t stream = 0; stream < tokens.size(); stream++) {
450
0
    size_t distance_multiplier =
451
0
        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
452
0
    const auto& in = tokens[stream];
453
0
    auto& out = tokens_lz77[stream];
454
0
    total_symbols += in.size();
455
    // Cumulative sum of bit costs.
456
0
    sym_cost.resize(in.size() + 1);
457
0
    for (size_t i = 0; i < in.size(); i++) {
458
0
      uint32_t tok, nbits, unused_bits;
459
0
      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
460
0
      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
461
0
    }
462
463
0
    out.reserve(in.size());
464
0
    size_t max_distance = in.size();
465
0
    size_t min_length = lz77.min_length;
466
0
    JXL_DASSERT(min_length >= 3);
467
0
    size_t max_length = in.size();
468
469
    // Use next power of two as window size.
470
0
    size_t window_size = 1;
471
0
    while (window_size < max_distance && window_size < kWindowSize) {
472
0
      window_size <<= 1;
473
0
    }
474
475
0
    HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
476
0
                    distance_multiplier);
477
0
    size_t len;
478
0
    size_t dist_symbol;
479
480
0
    const size_t max_lazy_match_len = 256;  // 0 to disable lazy matching
481
482
    // Whether the next symbol was already updated (to test lazy matching)
483
0
    bool already_updated = false;
484
0
    for (size_t i = 0; i < in.size(); i++) {
485
0
      out.push_back(in[i]);
486
0
      if (!already_updated) chain.Update(i);
487
0
      already_updated = false;
488
0
      chain.FindMatch(i, max_distance, &dist_symbol, &len);
489
0
      if (len >= min_length) {
490
0
        if (len < max_lazy_match_len && i + 1 < in.size()) {
491
          // Try length at next symbol lazy matching
492
0
          chain.Update(i + 1);
493
0
          already_updated = true;
494
0
          size_t len2, dist_symbol2;
495
0
          chain.FindMatch(i + 1, max_distance, &dist_symbol2, &len2);
496
0
          if (len2 > len) {
497
            // Use the lazy match. Add literal, and use the next length starting
498
            // from the next byte.
499
0
            ++i;
500
0
            already_updated = false;
501
0
            len = len2;
502
0
            dist_symbol = dist_symbol2;
503
0
            out.push_back(in[i]);
504
0
          }
505
0
        }
506
507
0
        float cost = sym_cost[i + len] - sym_cost[i];
508
0
        size_t lz77_len = len - lz77.min_length;
509
0
        float lz77_cost = LenCost(lz77_len) + DistCost(dist_symbol) +
510
0
                          sce.AddSymbolCost(out.back().context);
511
512
0
        if (lz77_cost <= cost) {
513
0
          out.back().value = len - min_length;
514
0
          out.back().is_lz77_length = true;
515
0
          out.emplace_back(
516
0
              static_cast<uint32_t>(lz77.nonserialized_distance_context),
517
0
              static_cast<uint32_t>(dist_symbol));
518
0
          bit_decrease += cost - lz77_cost;
519
0
        } else {
520
          // LZ77 match ignored, and symbol already pushed. Push all other
521
          // symbols and skip.
522
0
          for (size_t j = 1; j < len; j++) {
523
0
            out.push_back(in[i + j]);
524
0
          }
525
0
        }
526
527
0
        if (already_updated) {
528
0
          chain.Update(i + 2, len - 2);
529
0
          already_updated = false;
530
0
        } else {
531
0
          chain.Update(i + 1, len - 1);
532
0
        }
533
0
        i += len - 1;
534
0
      } else {
535
        // Literal, already pushed
536
0
      }
537
0
    }
538
0
  }
539
540
0
  if (bit_decrease > total_symbols * 0.2 + 16) {
541
0
    return tokens_lz77;
542
0
  }
543
0
  return {};
544
0
}
545
546
std::vector<std::vector<Token>> ApplyLZ77_Optimal(
547
    const HistogramParams& params, size_t num_contexts,
548
0
    const std::vector<std::vector<Token>>& tokens, const LZ77Params& lz77) {
549
0
  std::vector<std::vector<Token>> tokens_for_cost_estimate =
550
0
      ApplyLZ77_LZ77(params, num_contexts, tokens, lz77);
551
  // If greedy-LZ77 does not give better compression than no-lz77, no reason to
552
  // run the optimal matching.
553
0
  if (tokens_for_cost_estimate.empty()) return {};
554
0
  SymbolCostEstimator sce(num_contexts + 1, params.force_huffman,
555
0
                          tokens_for_cost_estimate, lz77);
556
0
  std::vector<std::vector<Token>> tokens_lz77(tokens.size());
557
0
  HybridUintConfig uint_config;
558
0
  std::vector<float> sym_cost;
559
0
  std::vector<uint32_t> dist_symbols;
560
0
  for (size_t stream = 0; stream < tokens.size(); stream++) {
561
0
    size_t distance_multiplier =
562
0
        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
563
0
    const auto& in = tokens[stream];
564
0
    auto& out = tokens_lz77[stream];
565
    // Cumulative sum of bit costs.
566
0
    sym_cost.resize(in.size() + 1);
567
0
    for (size_t i = 0; i < in.size(); i++) {
568
0
      uint32_t tok, nbits, unused_bits;
569
0
      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
570
0
      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
571
0
    }
572
573
0
    out.reserve(in.size());
574
0
    size_t max_distance = in.size();
575
0
    size_t min_length = lz77.min_length;
576
0
    JXL_DASSERT(min_length >= 3);
577
0
    size_t max_length = in.size();
578
579
    // Use next power of two as window size.
580
0
    size_t window_size = 1;
581
0
    while (window_size < max_distance && window_size < kWindowSize) {
582
0
      window_size <<= 1;
583
0
    }
584
585
0
    HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
586
0
                    distance_multiplier);
587
588
0
    struct MatchInfo {
589
0
      uint32_t len;
590
0
      uint32_t dist_symbol;
591
0
      uint32_t ctx;
592
0
      float total_cost = std::numeric_limits<float>::max();
593
0
    };
594
    // Total cost to encode the first N symbols.
595
0
    std::vector<MatchInfo> prefix_costs(in.size() + 1);
596
0
    prefix_costs[0].total_cost = 0;
597
598
0
    size_t rle_length = 0;
599
0
    size_t skip_lz77 = 0;
600
0
    for (size_t i = 0; i < in.size(); i++) {
601
0
      chain.Update(i);
602
0
      float lit_cost =
603
0
          prefix_costs[i].total_cost + sym_cost[i + 1] - sym_cost[i];
604
0
      if (prefix_costs[i + 1].total_cost > lit_cost) {
605
0
        prefix_costs[i + 1].dist_symbol = 0;
606
0
        prefix_costs[i + 1].len = 1;
607
0
        prefix_costs[i + 1].ctx = in[i].context;
608
0
        prefix_costs[i + 1].total_cost = lit_cost;
609
0
      }
610
0
      if (skip_lz77 > 0) {
611
0
        skip_lz77--;
612
0
        continue;
613
0
      }
614
0
      dist_symbols.clear();
615
0
      chain.FindMatches(i, max_distance,
616
0
                        [&dist_symbols](size_t len, size_t dist_symbol) {
617
0
                          if (dist_symbols.size() <= len) {
618
0
                            dist_symbols.resize(len + 1, dist_symbol);
619
0
                          }
620
0
                          if (dist_symbol < dist_symbols[len]) {
621
0
                            dist_symbols[len] = dist_symbol;
622
0
                          }
623
0
                        });
624
0
      if (dist_symbols.size() <= min_length) continue;
625
0
      {
626
0
        size_t best_cost = dist_symbols.back();
627
0
        for (size_t j = dist_symbols.size() - 1; j >= min_length; j--) {
628
0
          if (dist_symbols[j] < best_cost) {
629
0
            best_cost = dist_symbols[j];
630
0
          }
631
0
          dist_symbols[j] = best_cost;
632
0
        }
633
0
      }
634
0
      for (size_t j = min_length; j < dist_symbols.size(); j++) {
635
        // Cost model that uses results from lazy LZ77.
636
0
        float lz77_cost = sce.LenCost(in[i].context, j - min_length, lz77) +
637
0
                          sce.DistCost(dist_symbols[j], lz77);
638
0
        float cost = prefix_costs[i].total_cost + lz77_cost;
639
0
        if (prefix_costs[i + j].total_cost > cost) {
640
0
          prefix_costs[i + j].len = j;
641
0
          prefix_costs[i + j].dist_symbol = dist_symbols[j] + 1;
642
0
          prefix_costs[i + j].ctx = in[i].context;
643
0
          prefix_costs[i + j].total_cost = cost;
644
0
        }
645
0
      }
646
      // We are in a RLE sequence: skip all the symbols except the first 8 and
647
      // the last 8. This avoid quadratic costs for sequences with long runs of
648
      // the same symbol.
649
0
      if ((dist_symbols.back() == 0 && distance_multiplier == 0) ||
650
0
          (dist_symbols.back() == 1 && distance_multiplier != 0)) {
651
0
        rle_length++;
652
0
      } else {
653
0
        rle_length = 0;
654
0
      }
655
0
      if (rle_length >= 8 && dist_symbols.size() > 9) {
656
0
        skip_lz77 = dist_symbols.size() - 10;
657
0
        rle_length = 0;
658
0
      }
659
0
    }
660
0
    size_t pos = in.size();
661
0
    while (pos > 0) {
662
0
      bool is_lz77_length = prefix_costs[pos].dist_symbol != 0;
663
0
      if (is_lz77_length) {
664
0
        size_t dist_symbol = prefix_costs[pos].dist_symbol - 1;
665
0
        out.emplace_back(
666
0
            static_cast<uint32_t>(lz77.nonserialized_distance_context),
667
0
            static_cast<uint32_t>(dist_symbol));
668
0
      }
669
0
      uint32_t val =
670
0
          is_lz77_length
671
0
              ? (prefix_costs[pos].len - static_cast<uint32_t>(min_length))
672
0
              : in[pos - 1].value;
673
0
      out.emplace_back(prefix_costs[pos].ctx, val);
674
0
      out.back().is_lz77_length = is_lz77_length;
675
0
      pos -= prefix_costs[pos].len;
676
0
    }
677
0
    if (!out.empty()) std::reverse(out.begin(), out.end());
678
0
  }
679
0
  return tokens_lz77;
680
0
}
681
682
}  // namespace
683
684
std::vector<std::vector<Token>> ApplyLZ77(
685
    const HistogramParams& params, size_t num_contexts,
686
77.4k
    const std::vector<std::vector<Token>>& tokens, const LZ77Params& lz77) {
687
77.4k
  switch (params.lz77_method) {
688
63.8k
    case HistogramParams::LZ77Method::kRLE:
689
63.8k
      return ApplyLZ77_RLE(params, num_contexts, tokens, lz77);
690
0
    case HistogramParams::LZ77Method::kLZ77:
691
0
      return ApplyLZ77_LZ77(params, num_contexts, tokens, lz77);
692
0
    case HistogramParams::LZ77Method::kOptimal:
693
0
      return ApplyLZ77_Optimal(params, num_contexts, tokens, lz77);
694
13.5k
    default:
695
13.5k
      return {};
696
77.4k
  }
697
77.4k
}
698
699
}  // namespace jxl