Coverage Report

Created: 2025-06-16 07:00

/src/libjxl/lib/jxl/enc_lz77.cc
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#include "lib/jxl/enc_lz77.h"
7
8
#include <jxl/memory_manager.h>
9
#include <jxl/types.h>
10
11
#include <algorithm>
12
#include <cmath>
13
#include <cstddef>
14
#include <cstdint>
15
#include <limits>
16
#include <unordered_map>
17
#include <utility>
18
#include <vector>
19
20
#include "lib/jxl/ans_common.h"
21
#include "lib/jxl/ans_params.h"
22
#include "lib/jxl/base/bits.h"
23
#include "lib/jxl/base/common.h"
24
#include "lib/jxl/base/compiler_specific.h"
25
#include "lib/jxl/base/fast_math-inl.h"
26
#include "lib/jxl/base/status.h"
27
#include "lib/jxl/common.h"
28
#include "lib/jxl/dec_ans.h"
29
#include "lib/jxl/enc_ans_params.h"
30
#include "lib/jxl/enc_aux_out.h"
31
#include "lib/jxl/enc_cluster.h"
32
#include "lib/jxl/enc_context_map.h"
33
#include "lib/jxl/enc_fields.h"
34
#include "lib/jxl/enc_huffman.h"
35
#include "lib/jxl/enc_params.h"
36
#include "lib/jxl/fields.h"
37
38
namespace jxl {
39
40
namespace {
41
42
class SymbolCostEstimator {
43
 public:
44
  SymbolCostEstimator(size_t num_contexts, bool force_huffman,
45
                      const std::vector<std::vector<Token>>& tokens,
46
2.97k
                      const LZ77Params& lz77) {
47
2.97k
    std::vector<Histogram> builder(num_contexts);
48
    // Build histograms for estimating lz77 savings.
49
2.97k
    HybridUintConfig uint_config;
50
5.01k
    for (const auto& stream : tokens) {
51
4.78M
      for (const auto& token : stream) {
52
4.78M
        uint32_t tok, nbits, bits;
53
4.78M
        (token.is_lz77_length ? lz77.length_uint_config : uint_config)
54
4.78M
            .Encode(token.value, &tok, &nbits, &bits);
55
4.78M
        tok += token.is_lz77_length ? lz77.min_symbol : 0;
56
4.78M
        JXL_DASSERT(token.context < num_contexts);
57
4.78M
        builder[token.context].Add(tok);
58
4.78M
      }
59
5.01k
    }
60
2.97k
    max_alphabet_size_ = 0;
61
11.1k
    for (size_t i = 0; i < num_contexts; i++) {
62
8.18k
      max_alphabet_size_ =
63
8.18k
          std::max(max_alphabet_size_, builder[i].counts.size());
64
8.18k
    }
65
2.97k
    bits_.resize(num_contexts * max_alphabet_size_);
66
    // TODO(veluca): SIMD?
67
2.97k
    add_symbol_cost_.resize(num_contexts);
68
11.1k
    for (size_t i = 0; i < num_contexts; i++) {
69
8.18k
      float inv_total = 1.0f / (builder[i].total_count + 1e-8f);
70
8.18k
      float total_cost = 0;
71
208k
      for (size_t j = 0; j < builder[i].counts.size(); j++) {
72
200k
        size_t cnt = builder[i].counts[j];
73
200k
        float cost = 0;
74
200k
        if (cnt != 0 && cnt != builder[i].total_count) {
75
90.7k
          cost = -FastLog2f(cnt * inv_total);
76
90.7k
          if (force_huffman) cost = std::ceil(cost);
77
109k
        } else if (cnt == 0) {
78
108k
          cost = ANS_LOG_TAB_SIZE;  // Highest possible cost.
79
108k
        }
80
200k
        bits_[i * max_alphabet_size_ + j] = cost;
81
200k
        total_cost += cost * builder[i].counts[j];
82
200k
      }
83
      // Penalty for adding a lz77 symbol to this contest (only used for static
84
      // cost model). Higher penalty for contexts that have a very low
85
      // per-symbol entropy.
86
8.18k
      add_symbol_cost_[i] = std::max(0.0f, 6.0f - total_cost * inv_total);
87
8.18k
    }
88
2.97k
  }
89
4.78M
  float Bits(size_t ctx, size_t sym) const {
90
4.78M
    return bits_[ctx * max_alphabet_size_ + sym];
91
4.78M
  }
92
0
  float LenCost(size_t ctx, size_t len, const LZ77Params& lz77) const {
93
0
    uint32_t nbits, bits, tok;
94
0
    lz77.length_uint_config.Encode(len, &tok, &nbits, &bits);
95
0
    tok += lz77.min_symbol;
96
0
    return nbits + Bits(ctx, tok);
97
0
  }
98
0
  float DistCost(size_t len, const LZ77Params& lz77) const {
99
0
    uint32_t nbits, bits, tok;
100
0
    HybridUintConfig().Encode(len, &tok, &nbits, &bits);
101
0
    return nbits + Bits(lz77.nonserialized_distance_context, tok);
102
0
  }
103
0
  float AddSymbolCost(size_t idx) const { return add_symbol_cost_[idx]; }
104
105
 private:
106
  size_t max_alphabet_size_;
107
  std::vector<float> bits_;
108
  std::vector<float> add_symbol_cost_;
109
};
110
111
std::vector<std::vector<Token>> ApplyLZ77_RLE(
112
    const HistogramParams& params, size_t num_contexts,
113
2.97k
    const std::vector<std::vector<Token>>& tokens, const LZ77Params& lz77) {
114
2.97k
  std::vector<std::vector<Token>> tokens_lz77(tokens.size());
115
  // TODO(veluca): tune heuristics here.
116
2.97k
  SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
117
2.97k
  float bit_decrease = 0;
118
2.97k
  size_t total_symbols = 0;
119
2.97k
  std::vector<float> sym_cost;
120
2.97k
  HybridUintConfig uint_config;
121
7.98k
  for (size_t stream = 0; stream < tokens.size(); stream++) {
122
5.01k
    size_t distance_multiplier =
123
5.01k
        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
124
5.01k
    const auto& in = tokens[stream];
125
5.01k
    auto& out = tokens_lz77[stream];
126
5.01k
    total_symbols += in.size();
127
    // Cumulative sum of bit costs.
128
5.01k
    sym_cost.resize(in.size() + 1);
129
4.79M
    for (size_t i = 0; i < in.size(); i++) {
130
4.78M
      uint32_t tok, nbits, unused_bits;
131
4.78M
      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
132
4.78M
      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
133
4.78M
    }
134
5.01k
    out.reserve(in.size());
135
3.77M
    for (size_t i = 0; i < in.size(); i++) {
136
3.77M
      size_t num_to_copy = 0;
137
3.77M
      size_t distance_symbol = 0;  // 1 for RLE.
138
3.77M
      if (distance_multiplier != 0) {
139
844k
        distance_symbol = 1;  // Special distance 1 if enabled.
140
844k
        JXL_DASSERT(kSpecialDistances[1][0] == 1);
141
844k
        JXL_DASSERT(kSpecialDistances[1][1] == 0);
142
844k
      }
143
3.77M
      if (i > 0) {
144
5.00M
        for (; i + num_to_copy < in.size(); num_to_copy++) {
145
5.00M
          if (in[i + num_to_copy].value != in[i - 1].value) {
146
3.76M
            break;
147
3.76M
          }
148
5.00M
        }
149
3.76M
      }
150
3.77M
      if (num_to_copy == 0) {
151
3.54M
        out.push_back(in[i]);
152
3.54M
        continue;
153
3.54M
      }
154
226k
      float cost = sym_cost[i + num_to_copy] - sym_cost[i];
155
      // This subtraction might overflow, but that's OK.
156
226k
      size_t lz77_len = num_to_copy - lz77.min_length;
157
226k
      float lz77_cost = num_to_copy >= lz77.min_length
158
226k
                            ? CeilLog2Nonzero(lz77_len + 1) + 1
159
226k
                            : 0;
160
226k
      if (num_to_copy < lz77.min_length || cost <= lz77_cost) {
161
912k
        for (size_t j = 0; j < num_to_copy; j++) {
162
725k
          out.push_back(in[i + j]);
163
725k
        }
164
187k
        i += num_to_copy - 1;
165
187k
        continue;
166
187k
      }
167
      // Output the LZ77 length
168
39.4k
      out.emplace_back(in[i].context, lz77_len);
169
39.4k
      out.back().is_lz77_length = true;
170
39.4k
      i += num_to_copy - 1;
171
39.4k
      bit_decrease += cost - lz77_cost;
172
      // Output the LZ77 copy distance.
173
39.4k
      out.emplace_back(lz77.nonserialized_distance_context, distance_symbol);
174
39.4k
    }
175
5.01k
  }
176
177
2.97k
  if (bit_decrease > total_symbols * 0.2 + 16) {
178
365
    return tokens_lz77;
179
365
  }
180
2.61k
  return {};
181
2.97k
}
182
183
// Hash chain for LZ77 matching
184
struct HashChain {
185
  size_t size_;
186
  std::vector<uint32_t> data_;
187
188
  unsigned hash_num_values_ = 32768;
189
  unsigned hash_mask_ = hash_num_values_ - 1;
190
  unsigned hash_shift_ = 5;
191
192
  std::vector<int> head;
193
  std::vector<uint32_t> chain;
194
  std::vector<int> val;
195
196
  // Speed up repetitions of zero
197
  std::vector<int> headz;
198
  std::vector<uint32_t> chainz;
199
  std::vector<uint32_t> zeros;
200
  uint32_t numzeros = 0;
201
202
  size_t window_size_;
203
  size_t window_mask_;
204
  size_t min_length_;
205
  size_t max_length_;
206
207
  // Map of special distance codes.
208
  std::unordered_map<int, int> special_dist_table_;
209
  size_t num_special_distances_ = 0;
210
211
  uint32_t maxchainlength = 256;  // window_size_ to allow all
212
213
  HashChain(const Token* data, size_t size, size_t window_size,
214
            size_t min_length, size_t max_length, size_t distance_multiplier)
215
0
      : size_(size),
216
0
        window_size_(window_size),
217
0
        window_mask_(window_size - 1),
218
0
        min_length_(min_length),
219
0
        max_length_(max_length) {
220
0
    data_.resize(size);
221
0
    for (size_t i = 0; i < size; i++) {
222
0
      data_[i] = data[i].value;
223
0
    }
224
225
0
    head.resize(hash_num_values_, -1);
226
0
    val.resize(window_size_, -1);
227
0
    chain.resize(window_size_);
228
0
    for (uint32_t i = 0; i < window_size_; ++i) {
229
0
      chain[i] = i;  // same value as index indicates uninitialized
230
0
    }
231
232
0
    zeros.resize(window_size_);
233
0
    headz.resize(window_size_ + 1, -1);
234
0
    chainz.resize(window_size_);
235
0
    for (uint32_t i = 0; i < window_size_; ++i) {
236
0
      chainz[i] = i;
237
0
    }
238
    // Translate distance to special distance code.
239
0
    if (distance_multiplier) {
240
      // Count down, so if due to small distance multiplier multiple distances
241
      // map to the same code, the smallest code will be used in the end.
242
0
      for (int i = kNumSpecialDistances - 1; i >= 0; --i) {
243
0
        special_dist_table_[SpecialDistance(i, distance_multiplier)] = i;
244
0
      }
245
0
      num_special_distances_ = kNumSpecialDistances;
246
0
    }
247
0
  }
248
249
0
  uint32_t GetHash(size_t pos) const {
250
0
    uint32_t result = 0;
251
0
    if (pos + 2 < size_) {
252
      // TODO(lode): take the MSB's of the uint32_t values into account as well,
253
      // given that the hash code itself is less than 32 bits.
254
0
      result ^= static_cast<uint32_t>(data_[pos + 0] << 0u);
255
0
      result ^= static_cast<uint32_t>(data_[pos + 1] << hash_shift_);
256
0
      result ^= static_cast<uint32_t>(data_[pos + 2] << (hash_shift_ * 2));
257
0
    } else {
258
      // No need to compute hash of last 2 bytes, the length 2 is too short.
259
0
      return 0;
260
0
    }
261
0
    return result & hash_mask_;
262
0
  }
263
264
0
  uint32_t CountZeros(size_t pos, uint32_t prevzeros) const {
265
0
    size_t end = pos + window_size_;
266
0
    if (end > size_) end = size_;
267
0
    if (prevzeros > 0) {
268
0
      if (prevzeros >= window_mask_ && data_[end - 1] == 0 &&
269
0
          end == pos + window_size_) {
270
0
        return prevzeros;
271
0
      } else {
272
0
        return prevzeros - 1;
273
0
      }
274
0
    }
275
0
    uint32_t num = 0;
276
0
    while (pos + num < end && data_[pos + num] == 0) num++;
277
0
    return num;
278
0
  }
279
280
0
  void Update(size_t pos) {
281
0
    uint32_t hashval = GetHash(pos);
282
0
    uint32_t wpos = pos & window_mask_;
283
284
0
    val[wpos] = static_cast<int>(hashval);
285
0
    if (head[hashval] != -1) chain[wpos] = head[hashval];
286
0
    head[hashval] = wpos;
287
288
0
    if (pos > 0 && data_[pos] != data_[pos - 1]) numzeros = 0;
289
0
    numzeros = CountZeros(pos, numzeros);
290
291
0
    zeros[wpos] = numzeros;
292
0
    if (headz[numzeros] != -1) chainz[wpos] = headz[numzeros];
293
0
    headz[numzeros] = wpos;
294
0
  }
295
296
0
  void Update(size_t pos, size_t len) {
297
0
    for (size_t i = 0; i < len; i++) {
298
0
      Update(pos + i);
299
0
    }
300
0
  }
301
302
  template <typename CB>
303
0
  void FindMatches(size_t pos, int max_dist, const CB& found_match) const {
304
0
    uint32_t wpos = pos & window_mask_;
305
0
    uint32_t hashval = GetHash(pos);
306
0
    uint32_t hashpos = chain[wpos];
307
308
0
    int prev_dist = 0;
309
0
    int end = std::min<int>(pos + max_length_, size_);
310
0
    uint32_t chainlength = 0;
311
0
    uint32_t best_len = 0;
312
0
    for (;;) {
313
0
      int dist = (hashpos <= wpos) ? (wpos - hashpos)
314
0
                                   : (wpos - hashpos + window_mask_ + 1);
315
0
      if (dist < prev_dist) break;
316
0
      prev_dist = dist;
317
0
      uint32_t len = 0;
318
0
      if (dist > 0) {
319
0
        int i = pos;
320
0
        int j = pos - dist;
321
0
        if (numzeros > 3) {
322
0
          int r = std::min<int>(numzeros - 1, zeros[hashpos]);
323
0
          if (i + r >= end) r = end - i - 1;
324
0
          i += r;
325
0
          j += r;
326
0
        }
327
0
        while (i < end && data_[i] == data_[j]) {
328
0
          i++;
329
0
          j++;
330
0
        }
331
0
        len = i - pos;
332
        // This can trigger even if the new length is slightly smaller than the
333
        // best length, because it is possible for a slightly cheaper distance
334
        // symbol to occur.
335
0
        if (len >= min_length_ && len + 2 >= best_len) {
336
0
          auto it = special_dist_table_.find(dist);
337
0
          int dist_symbol = (it == special_dist_table_.end())
338
0
                                ? (num_special_distances_ + dist - 1)
339
0
                                : it->second;
340
0
          found_match(len, dist_symbol);
341
0
          if (len > best_len) best_len = len;
342
0
        }
343
0
      }
344
345
0
      chainlength++;
346
0
      if (chainlength >= maxchainlength) break;
347
348
0
      if (numzeros >= 3 && len > numzeros) {
349
0
        if (hashpos == chainz[hashpos]) break;
350
0
        hashpos = chainz[hashpos];
351
0
        if (zeros[hashpos] != numzeros) break;
352
0
      } else {
353
0
        if (hashpos == chain[hashpos]) break;
354
0
        hashpos = chain[hashpos];
355
0
        if (val[hashpos] != static_cast<int>(hashval)) {
356
          // outdated hash value
357
0
          break;
358
0
        }
359
0
      }
360
0
    }
361
0
  }
Unexecuted instantiation: enc_lz77.cc:void jxl::(anonymous namespace)::HashChain::FindMatches<jxl::(anonymous namespace)::HashChain::FindMatch(unsigned long, int, unsigned long*, unsigned long*) const::{lambda(unsigned long, unsigned long)#1}>(unsigned long, int, jxl::(anonymous namespace)::HashChain::FindMatch(unsigned long, int, unsigned long*, unsigned long*) const::{lambda(unsigned long, unsigned long)#1} const&) const
Unexecuted instantiation: enc_lz77.cc:void jxl::(anonymous namespace)::HashChain::FindMatches<jxl::(anonymous namespace)::ApplyLZ77_Optimal(jxl::HistogramParams const&, unsigned long, std::__1::vector<std::__1::vector<jxl::Token, std::__1::allocator<jxl::Token> >, std::__1::allocator<std::__1::vector<jxl::Token, std::__1::allocator<jxl::Token> > > > const&, jxl::LZ77Params const&)::$_0>(unsigned long, int, jxl::(anonymous namespace)::ApplyLZ77_Optimal(jxl::HistogramParams const&, unsigned long, std::__1::vector<std::__1::vector<jxl::Token, std::__1::allocator<jxl::Token> >, std::__1::allocator<std::__1::vector<jxl::Token, std::__1::allocator<jxl::Token> > > > const&, jxl::LZ77Params const&)::$_0 const&) const
362
  void FindMatch(size_t pos, int max_dist, size_t* result_dist_symbol,
363
0
                 size_t* result_len) const {
364
0
    *result_dist_symbol = 0;
365
0
    *result_len = 1;
366
0
    FindMatches(pos, max_dist, [&](size_t len, size_t dist_symbol) {
367
0
      if (len > *result_len ||
368
0
          (len == *result_len && *result_dist_symbol > dist_symbol)) {
369
0
        *result_len = len;
370
0
        *result_dist_symbol = dist_symbol;
371
0
      }
372
0
    });
373
0
  }
374
};
375
376
0
float LenCost(size_t len) {
377
0
  uint32_t nbits, bits, tok;
378
0
  HybridUintConfig(1, 0, 0).Encode(len, &tok, &nbits, &bits);
379
0
  constexpr float kCostTable[] = {
380
0
      2.797667318563126,  3.213177690381199,  2.5706009246743737,
381
0
      2.408392498667534,  2.829649191872326,  3.3923087753324577,
382
0
      4.029267451554331,  4.415576699706408,  4.509357574741465,
383
0
      9.21481543803004,   10.020590190114898, 11.858671627804766,
384
0
      12.45853300490526,  11.713105831990857, 12.561996324849314,
385
0
      13.775477692278367, 13.174027068768641,
386
0
  };
387
0
  size_t table_size = sizeof kCostTable / sizeof *kCostTable;
388
0
  if (tok >= table_size) tok = table_size - 1;
389
0
  return kCostTable[tok] + nbits;
390
0
}
391
392
// TODO(veluca): this does not take into account usage or non-usage of distance
393
// multipliers.
394
0
float DistCost(size_t dist) {
395
0
  uint32_t nbits, bits, tok;
396
0
  HybridUintConfig(7, 0, 0).Encode(dist, &tok, &nbits, &bits);
397
0
  constexpr float kCostTable[] = {
398
0
      6.368282626312716,  5.680793277090298,  8.347404197105247,
399
0
      7.641619201599141,  6.914328374119438,  7.959808291537444,
400
0
      8.70023120759855,   8.71378518934703,   9.379132523982769,
401
0
      9.110472749092708,  9.159029569270908,  9.430936766731973,
402
0
      7.278284055315169,  7.8278514904267755, 10.026641158289236,
403
0
      9.976049229827066,  9.64351607048908,   9.563403863480442,
404
0
      10.171474111762747, 10.45950155077234,  9.994813912104219,
405
0
      10.322524683741156, 8.465808729388186,  8.756254166066853,
406
0
      10.160930174662234, 10.247329273413435, 10.04090403724809,
407
0
      10.129398517544082, 9.342311691539546,  9.07608009102374,
408
0
      10.104799540677513, 10.378079384990906, 10.165828974075072,
409
0
      10.337595322341553, 7.940557464567944,  10.575665823319431,
410
0
      11.023344321751955, 10.736144698831827, 11.118277044595054,
411
0
      7.468468230648442,  10.738305230932939, 10.906980780216568,
412
0
      10.163468216353817, 10.17805759656433,  11.167283670483565,
413
0
      11.147050200274544, 10.517921919244333, 10.651764778156886,
414
0
      10.17074446448919,  11.217636876224745, 11.261630721139484,
415
0
      11.403140815247259, 10.892472096873417, 11.1859607804481,
416
0
      8.017346947551262,  7.895143720278828,  11.036577113822025,
417
0
      11.170562110315794, 10.326988722591086, 10.40872184751056,
418
0
      11.213498225466386, 11.30580635516863,  10.672272515665442,
419
0
      10.768069466228063, 11.145257364153565, 11.64668307145549,
420
0
      10.593156194627339, 11.207499484844943, 10.767517766396908,
421
0
      10.826629811407042, 10.737764794499988, 10.6200448518045,
422
0
      10.191315385198092, 8.468384171390085,  11.731295299170432,
423
0
      11.824619886654398, 10.41518844301179,  10.16310536548649,
424
0
      10.539423685097576, 10.495136599328031, 10.469112847728267,
425
0
      11.72057686174922,  10.910326337834674, 11.378921834673758,
426
0
      11.847759036098536, 11.92071647623854,  10.810628276345282,
427
0
      11.008601085273893, 11.910326337834674, 11.949212023423133,
428
0
      11.298614839104337, 11.611603659010392, 10.472930394619985,
429
0
      11.835564720850282, 11.523267392285337, 12.01055816679611,
430
0
      8.413029688994023,  11.895784139536406, 11.984679534970505,
431
0
      11.220654278717394, 11.716311684833672, 10.61036646226114,
432
0
      10.89849965960364,  10.203762898863669, 10.997560826267238,
433
0
      11.484217379438984, 11.792836176993665, 12.24310468755171,
434
0
      11.464858097919262, 12.212747017409377, 11.425595666074955,
435
0
      11.572048533398757, 12.742093965163013, 11.381874288645637,
436
0
      12.191870445817015, 11.683156920035426, 11.152442115262197,
437
0
      11.90303691580457,  11.653292787169159, 11.938615382266098,
438
0
      16.970641701570223, 16.853602280380002, 17.26240782594733,
439
0
      16.644655390108507, 17.14310889757499,  16.910935455445955,
440
0
      17.505678976959697, 17.213498225466388, 2.4162310293553024,
441
0
      3.494587244462329,  3.5258600986408344, 3.4959806589517095,
442
0
      3.098390886949687,  3.343454654302911,  3.588847442290287,
443
0
      4.14614790111827,   5.152948641990529,  7.433696808092598,
444
0
      9.716311684833672,
445
0
  };
446
0
  size_t table_size = sizeof kCostTable / sizeof *kCostTable;
447
0
  if (tok >= table_size) tok = table_size - 1;
448
0
  return kCostTable[tok] + nbits;
449
0
}
450
451
std::vector<std::vector<Token>> ApplyLZ77_LZ77(
452
    const HistogramParams& params, size_t num_contexts,
453
0
    const std::vector<std::vector<Token>>& tokens, const LZ77Params& lz77) {
454
0
  std::vector<std::vector<Token>> tokens_lz77(tokens.size());
455
  // TODO(veluca): tune heuristics here.
456
0
  SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
457
0
  float bit_decrease = 0;
458
0
  size_t total_symbols = 0;
459
0
  HybridUintConfig uint_config;
460
0
  std::vector<float> sym_cost;
461
0
  for (size_t stream = 0; stream < tokens.size(); stream++) {
462
0
    size_t distance_multiplier =
463
0
        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
464
0
    const auto& in = tokens[stream];
465
0
    auto& out = tokens_lz77[stream];
466
0
    total_symbols += in.size();
467
    // Cumulative sum of bit costs.
468
0
    sym_cost.resize(in.size() + 1);
469
0
    for (size_t i = 0; i < in.size(); i++) {
470
0
      uint32_t tok, nbits, unused_bits;
471
0
      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
472
0
      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
473
0
    }
474
475
0
    out.reserve(in.size());
476
0
    size_t max_distance = in.size();
477
0
    size_t min_length = lz77.min_length;
478
0
    JXL_DASSERT(min_length >= 3);
479
0
    size_t max_length = in.size();
480
481
    // Use next power of two as window size.
482
0
    size_t window_size = 1;
483
0
    while (window_size < max_distance && window_size < kWindowSize) {
484
0
      window_size <<= 1;
485
0
    }
486
487
0
    HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
488
0
                    distance_multiplier);
489
0
    size_t len;
490
0
    size_t dist_symbol;
491
492
0
    const size_t max_lazy_match_len = 256;  // 0 to disable lazy matching
493
494
    // Whether the next symbol was already updated (to test lazy matching)
495
0
    bool already_updated = false;
496
0
    for (size_t i = 0; i < in.size(); i++) {
497
0
      out.push_back(in[i]);
498
0
      if (!already_updated) chain.Update(i);
499
0
      already_updated = false;
500
0
      chain.FindMatch(i, max_distance, &dist_symbol, &len);
501
0
      if (len >= min_length) {
502
0
        if (len < max_lazy_match_len && i + 1 < in.size()) {
503
          // Try length at next symbol lazy matching
504
0
          chain.Update(i + 1);
505
0
          already_updated = true;
506
0
          size_t len2, dist_symbol2;
507
0
          chain.FindMatch(i + 1, max_distance, &dist_symbol2, &len2);
508
0
          if (len2 > len) {
509
            // Use the lazy match. Add literal, and use the next length starting
510
            // from the next byte.
511
0
            ++i;
512
0
            already_updated = false;
513
0
            len = len2;
514
0
            dist_symbol = dist_symbol2;
515
0
            out.push_back(in[i]);
516
0
          }
517
0
        }
518
519
0
        float cost = sym_cost[i + len] - sym_cost[i];
520
0
        size_t lz77_len = len - lz77.min_length;
521
0
        float lz77_cost = LenCost(lz77_len) + DistCost(dist_symbol) +
522
0
                          sce.AddSymbolCost(out.back().context);
523
524
0
        if (lz77_cost <= cost) {
525
0
          out.back().value = len - min_length;
526
0
          out.back().is_lz77_length = true;
527
0
          out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
528
0
          bit_decrease += cost - lz77_cost;
529
0
        } else {
530
          // LZ77 match ignored, and symbol already pushed. Push all other
531
          // symbols and skip.
532
0
          for (size_t j = 1; j < len; j++) {
533
0
            out.push_back(in[i + j]);
534
0
          }
535
0
        }
536
537
0
        if (already_updated) {
538
0
          chain.Update(i + 2, len - 2);
539
0
          already_updated = false;
540
0
        } else {
541
0
          chain.Update(i + 1, len - 1);
542
0
        }
543
0
        i += len - 1;
544
0
      } else {
545
        // Literal, already pushed
546
0
      }
547
0
    }
548
0
  }
549
550
0
  if (bit_decrease > total_symbols * 0.2 + 16) {
551
0
    return tokens_lz77;
552
0
  }
553
0
  return {};
554
0
}
555
556
std::vector<std::vector<Token>> ApplyLZ77_Optimal(
557
    const HistogramParams& params, size_t num_contexts,
558
0
    const std::vector<std::vector<Token>>& tokens, const LZ77Params& lz77) {
559
0
  std::vector<std::vector<Token>> tokens_for_cost_estimate =
560
0
      ApplyLZ77_LZ77(params, num_contexts, tokens, lz77);
561
  // If greedy-LZ77 does not give better compression than no-lz77, no reason to
562
  // run the optimal matching.
563
0
  if (tokens_for_cost_estimate.empty()) return {};
564
0
  SymbolCostEstimator sce(num_contexts + 1, params.force_huffman,
565
0
                          tokens_for_cost_estimate, lz77);
566
0
  std::vector<std::vector<Token>> tokens_lz77(tokens.size());
567
0
  HybridUintConfig uint_config;
568
0
  std::vector<float> sym_cost;
569
0
  std::vector<uint32_t> dist_symbols;
570
0
  for (size_t stream = 0; stream < tokens.size(); stream++) {
571
0
    size_t distance_multiplier =
572
0
        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
573
0
    const auto& in = tokens[stream];
574
0
    auto& out = tokens_lz77[stream];
575
    // Cumulative sum of bit costs.
576
0
    sym_cost.resize(in.size() + 1);
577
0
    for (size_t i = 0; i < in.size(); i++) {
578
0
      uint32_t tok, nbits, unused_bits;
579
0
      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
580
0
      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
581
0
    }
582
583
0
    out.reserve(in.size());
584
0
    size_t max_distance = in.size();
585
0
    size_t min_length = lz77.min_length;
586
0
    JXL_DASSERT(min_length >= 3);
587
0
    size_t max_length = in.size();
588
589
    // Use next power of two as window size.
590
0
    size_t window_size = 1;
591
0
    while (window_size < max_distance && window_size < kWindowSize) {
592
0
      window_size <<= 1;
593
0
    }
594
595
0
    HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
596
0
                    distance_multiplier);
597
598
0
    struct MatchInfo {
599
0
      uint32_t len;
600
0
      uint32_t dist_symbol;
601
0
      uint32_t ctx;
602
0
      float total_cost = std::numeric_limits<float>::max();
603
0
    };
604
    // Total cost to encode the first N symbols.
605
0
    std::vector<MatchInfo> prefix_costs(in.size() + 1);
606
0
    prefix_costs[0].total_cost = 0;
607
608
0
    size_t rle_length = 0;
609
0
    size_t skip_lz77 = 0;
610
0
    for (size_t i = 0; i < in.size(); i++) {
611
0
      chain.Update(i);
612
0
      float lit_cost =
613
0
          prefix_costs[i].total_cost + sym_cost[i + 1] - sym_cost[i];
614
0
      if (prefix_costs[i + 1].total_cost > lit_cost) {
615
0
        prefix_costs[i + 1].dist_symbol = 0;
616
0
        prefix_costs[i + 1].len = 1;
617
0
        prefix_costs[i + 1].ctx = in[i].context;
618
0
        prefix_costs[i + 1].total_cost = lit_cost;
619
0
      }
620
0
      if (skip_lz77 > 0) {
621
0
        skip_lz77--;
622
0
        continue;
623
0
      }
624
0
      dist_symbols.clear();
625
0
      chain.FindMatches(i, max_distance,
626
0
                        [&dist_symbols](size_t len, size_t dist_symbol) {
627
0
                          if (dist_symbols.size() <= len) {
628
0
                            dist_symbols.resize(len + 1, dist_symbol);
629
0
                          }
630
0
                          if (dist_symbol < dist_symbols[len]) {
631
0
                            dist_symbols[len] = dist_symbol;
632
0
                          }
633
0
                        });
634
0
      if (dist_symbols.size() <= min_length) continue;
635
0
      {
636
0
        size_t best_cost = dist_symbols.back();
637
0
        for (size_t j = dist_symbols.size() - 1; j >= min_length; j--) {
638
0
          if (dist_symbols[j] < best_cost) {
639
0
            best_cost = dist_symbols[j];
640
0
          }
641
0
          dist_symbols[j] = best_cost;
642
0
        }
643
0
      }
644
0
      for (size_t j = min_length; j < dist_symbols.size(); j++) {
645
        // Cost model that uses results from lazy LZ77.
646
0
        float lz77_cost = sce.LenCost(in[i].context, j - min_length, lz77) +
647
0
                          sce.DistCost(dist_symbols[j], lz77);
648
0
        float cost = prefix_costs[i].total_cost + lz77_cost;
649
0
        if (prefix_costs[i + j].total_cost > cost) {
650
0
          prefix_costs[i + j].len = j;
651
0
          prefix_costs[i + j].dist_symbol = dist_symbols[j] + 1;
652
0
          prefix_costs[i + j].ctx = in[i].context;
653
0
          prefix_costs[i + j].total_cost = cost;
654
0
        }
655
0
      }
656
      // We are in a RLE sequence: skip all the symbols except the first 8 and
657
      // the last 8. This avoid quadratic costs for sequences with long runs of
658
      // the same symbol.
659
0
      if ((dist_symbols.back() == 0 && distance_multiplier == 0) ||
660
0
          (dist_symbols.back() == 1 && distance_multiplier != 0)) {
661
0
        rle_length++;
662
0
      } else {
663
0
        rle_length = 0;
664
0
      }
665
0
      if (rle_length >= 8 && dist_symbols.size() > 9) {
666
0
        skip_lz77 = dist_symbols.size() - 10;
667
0
        rle_length = 0;
668
0
      }
669
0
    }
670
0
    size_t pos = in.size();
671
0
    while (pos > 0) {
672
0
      bool is_lz77_length = prefix_costs[pos].dist_symbol != 0;
673
0
      if (is_lz77_length) {
674
0
        size_t dist_symbol = prefix_costs[pos].dist_symbol - 1;
675
0
        out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
676
0
      }
677
0
      size_t val = is_lz77_length ? prefix_costs[pos].len - min_length
678
0
                                  : in[pos - 1].value;
679
0
      out.emplace_back(prefix_costs[pos].ctx, val);
680
0
      out.back().is_lz77_length = is_lz77_length;
681
0
      pos -= prefix_costs[pos].len;
682
0
    }
683
0
    std::reverse(out.begin(), out.end());
684
0
  }
685
0
  return tokens_lz77;
686
0
}
687
688
}  // namespace
689
690
std::vector<std::vector<Token>> ApplyLZ77(
691
    const HistogramParams& params, size_t num_contexts,
692
3.53k
    const std::vector<std::vector<Token>>& tokens, const LZ77Params& lz77) {
693
3.53k
  switch (params.lz77_method) {
694
2.97k
    case HistogramParams::LZ77Method::kRLE:
695
2.97k
      return ApplyLZ77_RLE(params, num_contexts, tokens, lz77);
696
0
    case HistogramParams::LZ77Method::kLZ77:
697
0
      return ApplyLZ77_LZ77(params, num_contexts, tokens, lz77);
698
0
    case HistogramParams::LZ77Method::kOptimal:
699
0
      return ApplyLZ77_Optimal(params, num_contexts, tokens, lz77);
700
558
    default:
701
558
      return {};
702
3.53k
  }
703
3.53k
}
704
705
}  // namespace jxl