/src/libjxl/lib/jxl/enc_lz77.cc
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright (c) the JPEG XL Project Authors. All rights reserved. |
2 | | // |
3 | | // Use of this source code is governed by a BSD-style |
4 | | // license that can be found in the LICENSE file. |
5 | | |
6 | | #include "lib/jxl/enc_lz77.h" |
7 | | |
8 | | #include <jxl/memory_manager.h> |
9 | | #include <jxl/types.h> |
10 | | |
11 | | #include <algorithm> |
12 | | #include <cmath> |
13 | | #include <cstddef> |
14 | | #include <cstdint> |
15 | | #include <limits> |
16 | | #include <unordered_map> |
17 | | #include <utility> |
18 | | #include <vector> |
19 | | |
20 | | #include "lib/jxl/ans_common.h" |
21 | | #include "lib/jxl/ans_params.h" |
22 | | #include "lib/jxl/base/bits.h" |
23 | | #include "lib/jxl/base/common.h" |
24 | | #include "lib/jxl/base/compiler_specific.h" |
25 | | #include "lib/jxl/base/fast_math-inl.h" |
26 | | #include "lib/jxl/base/status.h" |
27 | | #include "lib/jxl/common.h" |
28 | | #include "lib/jxl/dec_ans.h" |
29 | | #include "lib/jxl/enc_ans_params.h" |
30 | | #include "lib/jxl/enc_aux_out.h" |
31 | | #include "lib/jxl/enc_cluster.h" |
32 | | #include "lib/jxl/enc_context_map.h" |
33 | | #include "lib/jxl/enc_fields.h" |
34 | | #include "lib/jxl/enc_huffman.h" |
35 | | #include "lib/jxl/enc_params.h" |
36 | | #include "lib/jxl/fields.h" |
37 | | |
38 | | namespace jxl { |
39 | | |
40 | | namespace { |
41 | | |
42 | | class SymbolCostEstimator { |
43 | | public: |
44 | | SymbolCostEstimator(size_t num_contexts, bool force_huffman, |
45 | | const std::vector<std::vector<Token>>& tokens, |
46 | 2.97k | const LZ77Params& lz77) { |
47 | 2.97k | std::vector<Histogram> builder(num_contexts); |
48 | | // Build histograms for estimating lz77 savings. |
49 | 2.97k | HybridUintConfig uint_config; |
50 | 5.01k | for (const auto& stream : tokens) { |
51 | 4.78M | for (const auto& token : stream) { |
52 | 4.78M | uint32_t tok, nbits, bits; |
53 | 4.78M | (token.is_lz77_length ? lz77.length_uint_config : uint_config) |
54 | 4.78M | .Encode(token.value, &tok, &nbits, &bits); |
55 | 4.78M | tok += token.is_lz77_length ? lz77.min_symbol : 0; |
56 | 4.78M | JXL_DASSERT(token.context < num_contexts); |
57 | 4.78M | builder[token.context].Add(tok); |
58 | 4.78M | } |
59 | 5.01k | } |
60 | 2.97k | max_alphabet_size_ = 0; |
61 | 11.1k | for (size_t i = 0; i < num_contexts; i++) { |
62 | 8.18k | max_alphabet_size_ = |
63 | 8.18k | std::max(max_alphabet_size_, builder[i].counts.size()); |
64 | 8.18k | } |
65 | 2.97k | bits_.resize(num_contexts * max_alphabet_size_); |
66 | | // TODO(veluca): SIMD? |
67 | 2.97k | add_symbol_cost_.resize(num_contexts); |
68 | 11.1k | for (size_t i = 0; i < num_contexts; i++) { |
69 | 8.18k | float inv_total = 1.0f / (builder[i].total_count + 1e-8f); |
70 | 8.18k | float total_cost = 0; |
71 | 208k | for (size_t j = 0; j < builder[i].counts.size(); j++) { |
72 | 200k | size_t cnt = builder[i].counts[j]; |
73 | 200k | float cost = 0; |
74 | 200k | if (cnt != 0 && cnt != builder[i].total_count) { |
75 | 90.7k | cost = -FastLog2f(cnt * inv_total); |
76 | 90.7k | if (force_huffman) cost = std::ceil(cost); |
77 | 109k | } else if (cnt == 0) { |
78 | 108k | cost = ANS_LOG_TAB_SIZE; // Highest possible cost. |
79 | 108k | } |
80 | 200k | bits_[i * max_alphabet_size_ + j] = cost; |
81 | 200k | total_cost += cost * builder[i].counts[j]; |
82 | 200k | } |
83 | | // Penalty for adding a lz77 symbol to this contest (only used for static |
84 | | // cost model). Higher penalty for contexts that have a very low |
85 | | // per-symbol entropy. |
86 | 8.18k | add_symbol_cost_[i] = std::max(0.0f, 6.0f - total_cost * inv_total); |
87 | 8.18k | } |
88 | 2.97k | } |
89 | 4.78M | float Bits(size_t ctx, size_t sym) const { |
90 | 4.78M | return bits_[ctx * max_alphabet_size_ + sym]; |
91 | 4.78M | } |
92 | 0 | float LenCost(size_t ctx, size_t len, const LZ77Params& lz77) const { |
93 | 0 | uint32_t nbits, bits, tok; |
94 | 0 | lz77.length_uint_config.Encode(len, &tok, &nbits, &bits); |
95 | 0 | tok += lz77.min_symbol; |
96 | 0 | return nbits + Bits(ctx, tok); |
97 | 0 | } |
98 | 0 | float DistCost(size_t len, const LZ77Params& lz77) const { |
99 | 0 | uint32_t nbits, bits, tok; |
100 | 0 | HybridUintConfig().Encode(len, &tok, &nbits, &bits); |
101 | 0 | return nbits + Bits(lz77.nonserialized_distance_context, tok); |
102 | 0 | } |
103 | 0 | float AddSymbolCost(size_t idx) const { return add_symbol_cost_[idx]; } |
104 | | |
105 | | private: |
106 | | size_t max_alphabet_size_; |
107 | | std::vector<float> bits_; |
108 | | std::vector<float> add_symbol_cost_; |
109 | | }; |
110 | | |
111 | | std::vector<std::vector<Token>> ApplyLZ77_RLE( |
112 | | const HistogramParams& params, size_t num_contexts, |
113 | 2.97k | const std::vector<std::vector<Token>>& tokens, const LZ77Params& lz77) { |
114 | 2.97k | std::vector<std::vector<Token>> tokens_lz77(tokens.size()); |
115 | | // TODO(veluca): tune heuristics here. |
116 | 2.97k | SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77); |
117 | 2.97k | float bit_decrease = 0; |
118 | 2.97k | size_t total_symbols = 0; |
119 | 2.97k | std::vector<float> sym_cost; |
120 | 2.97k | HybridUintConfig uint_config; |
121 | 7.98k | for (size_t stream = 0; stream < tokens.size(); stream++) { |
122 | 5.01k | size_t distance_multiplier = |
123 | 5.01k | params.image_widths.size() > stream ? params.image_widths[stream] : 0; |
124 | 5.01k | const auto& in = tokens[stream]; |
125 | 5.01k | auto& out = tokens_lz77[stream]; |
126 | 5.01k | total_symbols += in.size(); |
127 | | // Cumulative sum of bit costs. |
128 | 5.01k | sym_cost.resize(in.size() + 1); |
129 | 4.79M | for (size_t i = 0; i < in.size(); i++) { |
130 | 4.78M | uint32_t tok, nbits, unused_bits; |
131 | 4.78M | uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits); |
132 | 4.78M | sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i]; |
133 | 4.78M | } |
134 | 5.01k | out.reserve(in.size()); |
135 | 3.77M | for (size_t i = 0; i < in.size(); i++) { |
136 | 3.77M | size_t num_to_copy = 0; |
137 | 3.77M | size_t distance_symbol = 0; // 1 for RLE. |
138 | 3.77M | if (distance_multiplier != 0) { |
139 | 844k | distance_symbol = 1; // Special distance 1 if enabled. |
140 | 844k | JXL_DASSERT(kSpecialDistances[1][0] == 1); |
141 | 844k | JXL_DASSERT(kSpecialDistances[1][1] == 0); |
142 | 844k | } |
143 | 3.77M | if (i > 0) { |
144 | 5.00M | for (; i + num_to_copy < in.size(); num_to_copy++) { |
145 | 5.00M | if (in[i + num_to_copy].value != in[i - 1].value) { |
146 | 3.76M | break; |
147 | 3.76M | } |
148 | 5.00M | } |
149 | 3.76M | } |
150 | 3.77M | if (num_to_copy == 0) { |
151 | 3.54M | out.push_back(in[i]); |
152 | 3.54M | continue; |
153 | 3.54M | } |
154 | 226k | float cost = sym_cost[i + num_to_copy] - sym_cost[i]; |
155 | | // This subtraction might overflow, but that's OK. |
156 | 226k | size_t lz77_len = num_to_copy - lz77.min_length; |
157 | 226k | float lz77_cost = num_to_copy >= lz77.min_length |
158 | 226k | ? CeilLog2Nonzero(lz77_len + 1) + 1 |
159 | 226k | : 0; |
160 | 226k | if (num_to_copy < lz77.min_length || cost <= lz77_cost) { |
161 | 912k | for (size_t j = 0; j < num_to_copy; j++) { |
162 | 725k | out.push_back(in[i + j]); |
163 | 725k | } |
164 | 187k | i += num_to_copy - 1; |
165 | 187k | continue; |
166 | 187k | } |
167 | | // Output the LZ77 length |
168 | 39.4k | out.emplace_back(in[i].context, lz77_len); |
169 | 39.4k | out.back().is_lz77_length = true; |
170 | 39.4k | i += num_to_copy - 1; |
171 | 39.4k | bit_decrease += cost - lz77_cost; |
172 | | // Output the LZ77 copy distance. |
173 | 39.4k | out.emplace_back(lz77.nonserialized_distance_context, distance_symbol); |
174 | 39.4k | } |
175 | 5.01k | } |
176 | | |
177 | 2.97k | if (bit_decrease > total_symbols * 0.2 + 16) { |
178 | 365 | return tokens_lz77; |
179 | 365 | } |
180 | 2.61k | return {}; |
181 | 2.97k | } |
182 | | |
// Hash chain for LZ77 matching
//
// Indexes the values of a token stream so that, for a position `pos`,
// earlier positions within a ring-buffer window that may start a match can be
// enumerated. Two chains are kept:
//  - a hash chain keyed by a hash of three consecutive values, and
//  - a "zeros" chain keyed by the length of the run of zeros starting at a
//    position, which speeds up long runs of zero values.
// Positions must be inserted with Update() in increasing order before they
// are searched with FindMatches()/FindMatch().
struct HashChain {
  // Number of values and a copy of the token values being matched.
  size_t size_;
  std::vector<uint32_t> data_;

  // Hash parameters: three values combined with a shift, masked to
  // hash_num_values_ buckets.
  unsigned hash_num_values_ = 32768;
  unsigned hash_mask_ = hash_num_values_ - 1;
  unsigned hash_shift_ = 5;

  // head[h]: most recently inserted window slot with hash h (-1 if none).
  std::vector<int> head;
  // chain[slot]: previous window slot with the same hash; chain[i] == i marks
  // the end of the chain / an uninitialized slot.
  std::vector<uint32_t> chain;
  // val[slot]: hash stored for that slot, used to detect outdated entries.
  std::vector<int> val;

  // Speed up repetitions of zero
  // headz/chainz mirror head/chain, keyed by the zero-run length stored in
  // zeros[slot]. numzeros is the zero-run length at the most recently
  // updated position.
  std::vector<int> headz;
  std::vector<uint32_t> chainz;
  std::vector<uint32_t> zeros;
  uint32_t numzeros = 0;

  // Window slots are computed as pos & window_mask_, so this assumes
  // window_size_ is a power of two.
  size_t window_size_;
  size_t window_mask_;
  size_t min_length_;
  size_t max_length_;

  // Map of special distance codes.
  // Maps a distance to its special distance code; only filled when a
  // distance multiplier is given.
  std::unordered_map<int, int> special_dist_table_;
  size_t num_special_distances_ = 0;

  // Maximum number of chain candidates examined per search.
  uint32_t maxchainlength = 256;  // window_size_ to allow all

  // Copies the values of `data[0..size)` and sets up empty chains for a
  // window of `window_size` slots. A non-zero `distance_multiplier` enables
  // the special distance codes.
  HashChain(const Token* data, size_t size, size_t window_size,
            size_t min_length, size_t max_length, size_t distance_multiplier)
      : size_(size),
        window_size_(window_size),
        window_mask_(window_size - 1),
        min_length_(min_length),
        max_length_(max_length) {
    data_.resize(size);
    for (size_t i = 0; i < size; i++) {
      data_[i] = data[i].value;
    }

    head.resize(hash_num_values_, -1);
    val.resize(window_size_, -1);
    chain.resize(window_size_);
    for (uint32_t i = 0; i < window_size_; ++i) {
      chain[i] = i;  // same value as index indicates uninitialized
    }

    zeros.resize(window_size_);
    // Zero-run lengths range over [0, window_size_], hence the extra entry.
    headz.resize(window_size_ + 1, -1);
    chainz.resize(window_size_);
    for (uint32_t i = 0; i < window_size_; ++i) {
      chainz[i] = i;
    }
    // Translate distance to special distance code.
    if (distance_multiplier) {
      // Count down, so if due to small distance multiplier multiple distances
      // map to the same code, the smallest code will be used in the end.
      for (int i = kNumSpecialDistances - 1; i >= 0; --i) {
        special_dist_table_[SpecialDistance(i, distance_multiplier)] = i;
      }
      num_special_distances_ = kNumSpecialDistances;
    }
  }

  // Hash of the three values starting at `pos`, or 0 when fewer than three
  // values remain.
  uint32_t GetHash(size_t pos) const {
    uint32_t result = 0;
    if (pos + 2 < size_) {
      // TODO(lode): take the MSB's of the uint32_t values into account as well,
      // given that the hash code itself is less than 32 bits.
      result ^= static_cast<uint32_t>(data_[pos + 0] << 0u);
      result ^= static_cast<uint32_t>(data_[pos + 1] << hash_shift_);
      result ^= static_cast<uint32_t>(data_[pos + 2] << (hash_shift_ * 2));
    } else {
      // No need to compute hash of last 2 bytes, the length 2 is too short.
      return 0;
    }
    return result & hash_mask_;
  }

  // Number of consecutive zeros starting at `pos`, capped by the window and
  // the end of data. When the previous position's count `prevzeros` is
  // non-zero, it is derived incrementally from it instead of rescanning.
  uint32_t CountZeros(size_t pos, uint32_t prevzeros) const {
    size_t end = pos + window_size_;
    if (end > size_) end = size_;
    if (prevzeros > 0) {
      // A run that saturates the window and still continues keeps its count.
      if (prevzeros >= window_mask_ && data_[end - 1] == 0 &&
          end == pos + window_size_) {
        return prevzeros;
      } else {
        return prevzeros - 1;
      }
    }
    uint32_t num = 0;
    while (pos + num < end && data_[pos + num] == 0) num++;
    return num;
  }

  // Inserts position `pos` at the head of its hash chain and of its zero-run
  // chain, and updates numzeros for `pos`.
  void Update(size_t pos) {
    uint32_t hashval = GetHash(pos);
    uint32_t wpos = pos & window_mask_;

    val[wpos] = static_cast<int>(hashval);
    if (head[hashval] != -1) chain[wpos] = head[hashval];
    head[hashval] = wpos;

    // A change of value ends any zero run carried over from pos - 1.
    if (pos > 0 && data_[pos] != data_[pos - 1]) numzeros = 0;
    numzeros = CountZeros(pos, numzeros);

    zeros[wpos] = numzeros;
    if (headz[numzeros] != -1) chainz[wpos] = headz[numzeros];
    headz[numzeros] = wpos;
  }

  // Inserts positions [pos, pos + len).
  void Update(size_t pos, size_t len) {
    for (size_t i = 0; i < len; i++) {
      Update(pos + i);
    }
  }

  // Walks the candidates for `pos` (which must already have been inserted)
  // and calls found_match(len, dist_symbol) for each match of at least
  // min_length_ whose length is within 2 of the best found so far.
  // dist_symbol is the special distance code when one exists, otherwise
  // num_special_distances_ + dist - 1.
  // NOTE(review): max_dist is currently unused; distances are bounded only by
  // the window.
  template <typename CB>
  void FindMatches(size_t pos, int max_dist, const CB& found_match) const {
    uint32_t wpos = pos & window_mask_;
    uint32_t hashval = GetHash(pos);
    uint32_t hashpos = chain[wpos];

    int prev_dist = 0;
    int end = std::min<int>(pos + max_length_, size_);
    uint32_t chainlength = 0;
    uint32_t best_len = 0;
    for (;;) {
      // Distance from the candidate slot back to pos, across ring wraparound.
      int dist = (hashpos <= wpos) ? (wpos - hashpos)
                                   : (wpos - hashpos + window_mask_ + 1);
      // Presumably candidates get farther along the chain; a shrinking
      // distance means the walk reached stale slots.
      if (dist < prev_dist) break;
      prev_dist = dist;
      uint32_t len = 0;
      if (dist > 0) {
        int i = pos;
        int j = pos - dist;
        // Both positions start zero runs, so the shorter run's prefix is
        // known to match and can be skipped.
        if (numzeros > 3) {
          int r = std::min<int>(numzeros - 1, zeros[hashpos]);
          if (i + r >= end) r = end - i - 1;
          i += r;
          j += r;
        }
        while (i < end && data_[i] == data_[j]) {
          i++;
          j++;
        }
        len = i - pos;
        // This can trigger even if the new length is slightly smaller than the
        // best length, because it is possible for a slightly cheaper distance
        // symbol to occur.
        if (len >= min_length_ && len + 2 >= best_len) {
          auto it = special_dist_table_.find(dist);
          int dist_symbol = (it == special_dist_table_.end())
                                ? (num_special_distances_ + dist - 1)
                                : it->second;
          found_match(len, dist_symbol);
          if (len > best_len) best_len = len;
        }
      }

      chainlength++;
      if (chainlength >= maxchainlength) break;

      // Long zero runs follow the zero-run chain; otherwise the hash chain.
      if (numzeros >= 3 && len > numzeros) {
        if (hashpos == chainz[hashpos]) break;
        hashpos = chainz[hashpos];
        if (zeros[hashpos] != numzeros) break;
      } else {
        if (hashpos == chain[hashpos]) break;
        hashpos = chain[hashpos];
        if (val[hashpos] != static_cast<int>(hashval)) {
          // outdated hash value
          break;
        }
      }
    }
  }
  // Longest match at `pos`; ties go to the smaller distance symbol. With no
  // match, reports length 1 and distance symbol 0.
  void FindMatch(size_t pos, int max_dist, size_t* result_dist_symbol,
                 size_t* result_len) const {
    *result_dist_symbol = 0;
    *result_len = 1;
    FindMatches(pos, max_dist, [&](size_t len, size_t dist_symbol) {
      if (len > *result_len ||
          (len == *result_len && *result_dist_symbol > dist_symbol)) {
        *result_len = len;
        *result_dist_symbol = dist_symbol;
      }
    });
  }
};
375 | | |
376 | 0 | float LenCost(size_t len) { |
377 | 0 | uint32_t nbits, bits, tok; |
378 | 0 | HybridUintConfig(1, 0, 0).Encode(len, &tok, &nbits, &bits); |
379 | 0 | constexpr float kCostTable[] = { |
380 | 0 | 2.797667318563126, 3.213177690381199, 2.5706009246743737, |
381 | 0 | 2.408392498667534, 2.829649191872326, 3.3923087753324577, |
382 | 0 | 4.029267451554331, 4.415576699706408, 4.509357574741465, |
383 | 0 | 9.21481543803004, 10.020590190114898, 11.858671627804766, |
384 | 0 | 12.45853300490526, 11.713105831990857, 12.561996324849314, |
385 | 0 | 13.775477692278367, 13.174027068768641, |
386 | 0 | }; |
387 | 0 | size_t table_size = sizeof kCostTable / sizeof *kCostTable; |
388 | 0 | if (tok >= table_size) tok = table_size - 1; |
389 | 0 | return kCostTable[tok] + nbits; |
390 | 0 | } |
391 | | |
392 | | // TODO(veluca): this does not take into account usage or non-usage of distance |
393 | | // multipliers. |
394 | 0 | float DistCost(size_t dist) { |
395 | 0 | uint32_t nbits, bits, tok; |
396 | 0 | HybridUintConfig(7, 0, 0).Encode(dist, &tok, &nbits, &bits); |
397 | 0 | constexpr float kCostTable[] = { |
398 | 0 | 6.368282626312716, 5.680793277090298, 8.347404197105247, |
399 | 0 | 7.641619201599141, 6.914328374119438, 7.959808291537444, |
400 | 0 | 8.70023120759855, 8.71378518934703, 9.379132523982769, |
401 | 0 | 9.110472749092708, 9.159029569270908, 9.430936766731973, |
402 | 0 | 7.278284055315169, 7.8278514904267755, 10.026641158289236, |
403 | 0 | 9.976049229827066, 9.64351607048908, 9.563403863480442, |
404 | 0 | 10.171474111762747, 10.45950155077234, 9.994813912104219, |
405 | 0 | 10.322524683741156, 8.465808729388186, 8.756254166066853, |
406 | 0 | 10.160930174662234, 10.247329273413435, 10.04090403724809, |
407 | 0 | 10.129398517544082, 9.342311691539546, 9.07608009102374, |
408 | 0 | 10.104799540677513, 10.378079384990906, 10.165828974075072, |
409 | 0 | 10.337595322341553, 7.940557464567944, 10.575665823319431, |
410 | 0 | 11.023344321751955, 10.736144698831827, 11.118277044595054, |
411 | 0 | 7.468468230648442, 10.738305230932939, 10.906980780216568, |
412 | 0 | 10.163468216353817, 10.17805759656433, 11.167283670483565, |
413 | 0 | 11.147050200274544, 10.517921919244333, 10.651764778156886, |
414 | 0 | 10.17074446448919, 11.217636876224745, 11.261630721139484, |
415 | 0 | 11.403140815247259, 10.892472096873417, 11.1859607804481, |
416 | 0 | 8.017346947551262, 7.895143720278828, 11.036577113822025, |
417 | 0 | 11.170562110315794, 10.326988722591086, 10.40872184751056, |
418 | 0 | 11.213498225466386, 11.30580635516863, 10.672272515665442, |
419 | 0 | 10.768069466228063, 11.145257364153565, 11.64668307145549, |
420 | 0 | 10.593156194627339, 11.207499484844943, 10.767517766396908, |
421 | 0 | 10.826629811407042, 10.737764794499988, 10.6200448518045, |
422 | 0 | 10.191315385198092, 8.468384171390085, 11.731295299170432, |
423 | 0 | 11.824619886654398, 10.41518844301179, 10.16310536548649, |
424 | 0 | 10.539423685097576, 10.495136599328031, 10.469112847728267, |
425 | 0 | 11.72057686174922, 10.910326337834674, 11.378921834673758, |
426 | 0 | 11.847759036098536, 11.92071647623854, 10.810628276345282, |
427 | 0 | 11.008601085273893, 11.910326337834674, 11.949212023423133, |
428 | 0 | 11.298614839104337, 11.611603659010392, 10.472930394619985, |
429 | 0 | 11.835564720850282, 11.523267392285337, 12.01055816679611, |
430 | 0 | 8.413029688994023, 11.895784139536406, 11.984679534970505, |
431 | 0 | 11.220654278717394, 11.716311684833672, 10.61036646226114, |
432 | 0 | 10.89849965960364, 10.203762898863669, 10.997560826267238, |
433 | 0 | 11.484217379438984, 11.792836176993665, 12.24310468755171, |
434 | 0 | 11.464858097919262, 12.212747017409377, 11.425595666074955, |
435 | 0 | 11.572048533398757, 12.742093965163013, 11.381874288645637, |
436 | 0 | 12.191870445817015, 11.683156920035426, 11.152442115262197, |
437 | 0 | 11.90303691580457, 11.653292787169159, 11.938615382266098, |
438 | 0 | 16.970641701570223, 16.853602280380002, 17.26240782594733, |
439 | 0 | 16.644655390108507, 17.14310889757499, 16.910935455445955, |
440 | 0 | 17.505678976959697, 17.213498225466388, 2.4162310293553024, |
441 | 0 | 3.494587244462329, 3.5258600986408344, 3.4959806589517095, |
442 | 0 | 3.098390886949687, 3.343454654302911, 3.588847442290287, |
443 | 0 | 4.14614790111827, 5.152948641990529, 7.433696808092598, |
444 | 0 | 9.716311684833672, |
445 | 0 | }; |
446 | 0 | size_t table_size = sizeof kCostTable / sizeof *kCostTable; |
447 | 0 | if (tok >= table_size) tok = table_size - 1; |
448 | 0 | return kCostTable[tok] + nbits; |
449 | 0 | } |
450 | | |
// Greedy LZ77 with one-step lazy matching.
//
// At each position, the longest match reported by a HashChain is considered.
// If that match is shorter than max_lazy_match_len and the match starting one
// symbol later is strictly longer, one literal is emitted and the later match
// is used instead. A match is kept only if its estimated cost does not exceed
// the estimated cost of the literals it replaces. That cost is the static
// LenCost/DistCost tables plus the estimator's per-context penalty.
// Returns the rewritten streams, or an empty vector when the total estimated
// saving does not exceed 0.2 bits per symbol plus 16 bits.
std::vector<std::vector<Token>> ApplyLZ77_LZ77(
    const HistogramParams& params, size_t num_contexts,
    const std::vector<std::vector<Token>>& tokens, const LZ77Params& lz77) {
  std::vector<std::vector<Token>> tokens_lz77(tokens.size());
  // TODO(veluca): tune heuristics here.
  SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
  float bit_decrease = 0;
  size_t total_symbols = 0;
  HybridUintConfig uint_config;
  std::vector<float> sym_cost;
  for (size_t stream = 0; stream < tokens.size(); stream++) {
    size_t distance_multiplier =
        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
    const auto& in = tokens[stream];
    auto& out = tokens_lz77[stream];
    total_symbols += in.size();
    // Cumulative sum of bit costs.
    // sym_cost[0] is never written; resize() keeps its initial value of 0.
    sym_cost.resize(in.size() + 1);
    for (size_t i = 0; i < in.size(); i++) {
      uint32_t tok, nbits, unused_bits;
      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
    }

    out.reserve(in.size());
    size_t max_distance = in.size();
    size_t min_length = lz77.min_length;
    JXL_DASSERT(min_length >= 3);
    size_t max_length = in.size();

    // Use next power of two as window size.
    size_t window_size = 1;
    while (window_size < max_distance && window_size < kWindowSize) {
      window_size <<= 1;
    }

    HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
                    distance_multiplier);
    size_t len;
    size_t dist_symbol;

    const size_t max_lazy_match_len = 256;  // 0 to disable lazy matching

    // Whether the next symbol was already updated (to test lazy matching)
    bool already_updated = false;
    for (size_t i = 0; i < in.size(); i++) {
      // Push the literal optimistically; it becomes the length token if a
      // match is accepted below.
      out.push_back(in[i]);
      if (!already_updated) chain.Update(i);
      already_updated = false;
      chain.FindMatch(i, max_distance, &dist_symbol, &len);
      if (len >= min_length) {
        if (len < max_lazy_match_len && i + 1 < in.size()) {
          // Try length at next symbol lazy matching
          chain.Update(i + 1);
          already_updated = true;
          size_t len2, dist_symbol2;
          chain.FindMatch(i + 1, max_distance, &dist_symbol2, &len2);
          if (len2 > len) {
            // Use the lazy match. Add literal, and use the next length starting
            // from the next byte.
            // The new i was inserted just above, so the post-match Update
            // below starts at i + 1.
            ++i;
            already_updated = false;
            len = len2;
            dist_symbol = dist_symbol2;
            out.push_back(in[i]);
          }
        }

        float cost = sym_cost[i + len] - sym_cost[i];
        size_t lz77_len = len - lz77.min_length;
        float lz77_cost = LenCost(lz77_len) + DistCost(dist_symbol) +
                          sce.AddSymbolCost(out.back().context);

        if (lz77_cost <= cost) {
          // Turn the pushed literal into the length token (keeping its
          // context) and follow it with the distance token.
          out.back().value = len - min_length;
          out.back().is_lz77_length = true;
          out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
          bit_decrease += cost - lz77_cost;
        } else {
          // LZ77 match ignored, and symbol already pushed. Push all other
          // symbols and skip.
          for (size_t j = 1; j < len; j++) {
            out.push_back(in[i + j]);
          }
        }

        // Insert the remaining covered positions into the chain. Position
        // i + 1 is skipped if the lazy probe already inserted it.
        if (already_updated) {
          chain.Update(i + 2, len - 2);
          already_updated = false;
        } else {
          chain.Update(i + 1, len - 1);
        }
        i += len - 1;
      } else {
        // Literal, already pushed
      }
    }
  }

  if (bit_decrease > total_symbols * 0.2 + 16) {
    return tokens_lz77;
  }
  return {};
}
555 | | |
// LZ77 parse chosen by dynamic programming over prefix costs.
//
// A greedy ApplyLZ77_LZ77 pass runs first. If it is rejected (empty result),
// this pass is skipped too. Otherwise the greedy output trains a
// SymbolCostEstimator with num_contexts + 1 contexts (presumably the extra one
// is the LZ77 distance context -- TODO confirm against the caller). Then, for
// every position, the literal transition and all matches found by a HashChain
// relax prefix_costs. The cheapest path is traced back from the end. Inside
// long runs at the RLE distance, most match searches are skipped to avoid
// quadratic work.
std::vector<std::vector<Token>> ApplyLZ77_Optimal(
    const HistogramParams& params, size_t num_contexts,
    const std::vector<std::vector<Token>>& tokens, const LZ77Params& lz77) {
  std::vector<std::vector<Token>> tokens_for_cost_estimate =
      ApplyLZ77_LZ77(params, num_contexts, tokens, lz77);
  // If greedy-LZ77 does not give better compression than no-lz77, no reason to
  // run the optimal matching.
  if (tokens_for_cost_estimate.empty()) return {};
  SymbolCostEstimator sce(num_contexts + 1, params.force_huffman,
                          tokens_for_cost_estimate, lz77);
  std::vector<std::vector<Token>> tokens_lz77(tokens.size());
  HybridUintConfig uint_config;
  std::vector<float> sym_cost;
  std::vector<uint32_t> dist_symbols;
  for (size_t stream = 0; stream < tokens.size(); stream++) {
    size_t distance_multiplier =
        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
    const auto& in = tokens[stream];
    auto& out = tokens_lz77[stream];
    // Cumulative sum of bit costs.
    sym_cost.resize(in.size() + 1);
    for (size_t i = 0; i < in.size(); i++) {
      uint32_t tok, nbits, unused_bits;
      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
    }

    out.reserve(in.size());
    size_t max_distance = in.size();
    size_t min_length = lz77.min_length;
    JXL_DASSERT(min_length >= 3);
    size_t max_length = in.size();

    // Use next power of two as window size.
    size_t window_size = 1;
    while (window_size < max_distance && window_size < kWindowSize) {
      window_size <<= 1;
    }

    HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
                    distance_multiplier);

    // Best known way to reach a prefix: the last step's length, its distance
    // symbol + 1 (0 means the step is a literal), the context for its token,
    // and the total cost.
    struct MatchInfo {
      uint32_t len;
      uint32_t dist_symbol;
      uint32_t ctx;
      float total_cost = std::numeric_limits<float>::max();
    };
    // Total cost to encode the first N symbols.
    std::vector<MatchInfo> prefix_costs(in.size() + 1);
    prefix_costs[0].total_cost = 0;

    size_t rle_length = 0;
    size_t skip_lz77 = 0;
    for (size_t i = 0; i < in.size(); i++) {
      chain.Update(i);
      // Relax the literal transition i -> i + 1.
      float lit_cost =
          prefix_costs[i].total_cost + sym_cost[i + 1] - sym_cost[i];
      if (prefix_costs[i + 1].total_cost > lit_cost) {
        prefix_costs[i + 1].dist_symbol = 0;
        prefix_costs[i + 1].len = 1;
        prefix_costs[i + 1].ctx = in[i].context;
        prefix_costs[i + 1].total_cost = lit_cost;
      }
      if (skip_lz77 > 0) {
        skip_lz77--;
        continue;
      }
      // dist_symbols[len]: smallest distance symbol reported for a match of
      // exactly `len` (shorter unseen lengths inherit the value they were
      // resized with).
      dist_symbols.clear();
      chain.FindMatches(i, max_distance,
                        [&dist_symbols](size_t len, size_t dist_symbol) {
                          if (dist_symbols.size() <= len) {
                            dist_symbols.resize(len + 1, dist_symbol);
                          }
                          if (dist_symbol < dist_symbols[len]) {
                            dist_symbols[len] = dist_symbol;
                          }
                        });
      if (dist_symbols.size() <= min_length) continue;
      {
        // Suffix minimum: a match of length j can also use any distance that
        // supports a longer match.
        size_t best_cost = dist_symbols.back();
        for (size_t j = dist_symbols.size() - 1; j >= min_length; j--) {
          if (dist_symbols[j] < best_cost) {
            best_cost = dist_symbols[j];
          }
          dist_symbols[j] = best_cost;
        }
      }
      // Relax the match transitions i -> i + j.
      for (size_t j = min_length; j < dist_symbols.size(); j++) {
        // Cost model that uses results from lazy LZ77.
        float lz77_cost = sce.LenCost(in[i].context, j - min_length, lz77) +
                          sce.DistCost(dist_symbols[j], lz77);
        float cost = prefix_costs[i].total_cost + lz77_cost;
        if (prefix_costs[i + j].total_cost > cost) {
          prefix_costs[i + j].len = j;
          prefix_costs[i + j].dist_symbol = dist_symbols[j] + 1;
          prefix_costs[i + j].ctx = in[i].context;
          prefix_costs[i + j].total_cost = cost;
        }
      }
      // We are in a RLE sequence: skip all the symbols except the first 8 and
      // the last 8. This avoids quadratic costs for sequences with long runs
      // of the same symbol.
      if ((dist_symbols.back() == 0 && distance_multiplier == 0) ||
          (dist_symbols.back() == 1 && distance_multiplier != 0)) {
        rle_length++;
      } else {
        rle_length = 0;
      }
      if (rle_length >= 8 && dist_symbols.size() > 9) {
        skip_lz77 = dist_symbols.size() - 10;
        rle_length = 0;
      }
    }
    // Trace the cheapest path backwards. The distance token is emitted before
    // its length token, so the final reverse() puts length first.
    size_t pos = in.size();
    while (pos > 0) {
      bool is_lz77_length = prefix_costs[pos].dist_symbol != 0;
      if (is_lz77_length) {
        size_t dist_symbol = prefix_costs[pos].dist_symbol - 1;
        out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
      }
      size_t val = is_lz77_length ? prefix_costs[pos].len - min_length
                                  : in[pos - 1].value;
      out.emplace_back(prefix_costs[pos].ctx, val);
      out.back().is_lz77_length = is_lz77_length;
      pos -= prefix_costs[pos].len;
    }
    std::reverse(out.begin(), out.end());
  }
  return tokens_lz77;
}
687 | | |
688 | | } // namespace |
689 | | |
690 | | std::vector<std::vector<Token>> ApplyLZ77( |
691 | | const HistogramParams& params, size_t num_contexts, |
692 | 3.53k | const std::vector<std::vector<Token>>& tokens, const LZ77Params& lz77) { |
693 | 3.53k | switch (params.lz77_method) { |
694 | 2.97k | case HistogramParams::LZ77Method::kRLE: |
695 | 2.97k | return ApplyLZ77_RLE(params, num_contexts, tokens, lz77); |
696 | 0 | case HistogramParams::LZ77Method::kLZ77: |
697 | 0 | return ApplyLZ77_LZ77(params, num_contexts, tokens, lz77); |
698 | 0 | case HistogramParams::LZ77Method::kOptimal: |
699 | 0 | return ApplyLZ77_Optimal(params, num_contexts, tokens, lz77); |
700 | 558 | default: |
701 | 558 | return {}; |
702 | 3.53k | } |
703 | 3.53k | } |
704 | | |
705 | | } // namespace jxl |