/src/libjxl/lib/jxl/enc_lz77.cc
Line | Count | Source |
1 | | // Copyright (c) the JPEG XL Project Authors. All rights reserved. |
2 | | // |
3 | | // Use of this source code is governed by a BSD-style |
4 | | // license that can be found in the LICENSE file. |
5 | | |
6 | | #include "lib/jxl/enc_lz77.h" |
7 | | |
8 | | #include <algorithm> |
9 | | #include <cmath> |
10 | | #include <cstddef> |
11 | | #include <cstdint> |
12 | | #include <limits> |
13 | | #include <unordered_map> |
14 | | #include <vector> |
15 | | |
16 | | #include "lib/jxl/ans_params.h" |
17 | | #include "lib/jxl/base/bits.h" |
18 | | #include "lib/jxl/base/fast_math-inl.h" |
19 | | #include "lib/jxl/base/status.h" |
20 | | #include "lib/jxl/dec_ans.h" |
21 | | #include "lib/jxl/enc_ans.h" |
22 | | #include "lib/jxl/enc_ans_params.h" |
23 | | |
24 | | namespace jxl { |
25 | | |
26 | | namespace { |
27 | | |
28 | | class SymbolCostEstimator { |
29 | | public: |
30 | | SymbolCostEstimator(size_t num_contexts, bool force_huffman, |
31 | | const std::vector<std::vector<Token>>& tokens, |
32 | 63.8k | const LZ77Params& lz77) { |
33 | 63.8k | std::vector<Histogram> builder(num_contexts); |
34 | | // Build histograms for estimating lz77 savings. |
35 | 63.8k | HybridUintConfig uint_config; |
36 | 109k | for (const auto& stream : tokens) { |
37 | 101M | for (const auto& token : stream) { |
38 | 101M | uint32_t tok, nbits, bits; |
39 | 101M | (token.is_lz77_length ? lz77.length_uint_config : uint_config) |
40 | 101M | .Encode(token.value, &tok, &nbits, &bits); |
41 | 101M | tok += token.is_lz77_length ? lz77.min_symbol : 0; |
42 | 101M | JXL_DASSERT(token.context < num_contexts); |
43 | 101M | builder[token.context].Add(tok); |
44 | 101M | } |
45 | 109k | } |
46 | 63.8k | max_alphabet_size_ = 0; |
47 | 249k | for (size_t i = 0; i < num_contexts; i++) { |
48 | 185k | max_alphabet_size_ = |
49 | 185k | std::max(max_alphabet_size_, builder[i].counts.size()); |
50 | 185k | } |
51 | 63.8k | bits_.resize(num_contexts * max_alphabet_size_); |
52 | | // TODO(veluca): SIMD? |
53 | 63.8k | add_symbol_cost_.resize(num_contexts); |
54 | 249k | for (size_t i = 0; i < num_contexts; i++) { |
55 | 185k | float inv_total = 1.0f / (builder[i].total_count + 1e-8f); |
56 | 185k | float total_cost = 0; |
57 | 3.98M | for (size_t j = 0; j < builder[i].counts.size(); j++) { |
58 | 3.79M | size_t cnt = builder[i].counts[j]; |
59 | 3.79M | float cost = 0; |
60 | 3.79M | if (cnt != 0 && cnt != builder[i].total_count) { |
61 | 1.80M | cost = -FastLog2f(cnt * inv_total); |
62 | 1.80M | if (force_huffman) cost = std::ceil(cost); |
63 | 1.99M | } else if (cnt == 0) { |
64 | 1.97M | cost = ANS_LOG_TAB_SIZE; // Highest possible cost. |
65 | 1.97M | } |
66 | 3.79M | bits_[i * max_alphabet_size_ + j] = cost; |
67 | 3.79M | total_cost += cost * builder[i].counts[j]; |
68 | 3.79M | } |
69 | | // Penalty for adding a lz77 symbol to this contest (only used for static |
70 | | // cost model). Higher penalty for contexts that have a very low |
71 | | // per-symbol entropy. |
72 | 185k | add_symbol_cost_[i] = std::max(0.0f, 6.0f - total_cost * inv_total); |
73 | 185k | } |
74 | 63.8k | } |
75 | 101M | float Bits(size_t ctx, size_t sym) const { |
76 | 101M | return bits_[ctx * max_alphabet_size_ + sym]; |
77 | 101M | } |
78 | 0 | float LenCost(size_t ctx, size_t len, const LZ77Params& lz77) const { |
79 | 0 | uint32_t nbits, bits, tok; |
80 | 0 | lz77.length_uint_config.Encode(len, &tok, &nbits, &bits); |
81 | 0 | tok += lz77.min_symbol; |
82 | 0 | return nbits + Bits(ctx, tok); |
83 | 0 | } |
84 | 0 | float DistCost(size_t len, const LZ77Params& lz77) const { |
85 | 0 | uint32_t nbits, bits, tok; |
86 | 0 | HybridUintConfig().Encode(len, &tok, &nbits, &bits); |
87 | 0 | return nbits + Bits(lz77.nonserialized_distance_context, tok); |
88 | 0 | } |
89 | 0 | float AddSymbolCost(size_t idx) const { return add_symbol_cost_[idx]; } |
90 | | |
91 | | private: |
92 | | size_t max_alphabet_size_; |
93 | | std::vector<float> bits_; |
94 | | std::vector<float> add_symbol_cost_; |
95 | | }; |
96 | | |
97 | | std::vector<std::vector<Token>> ApplyLZ77_RLE( |
98 | | const HistogramParams& params, size_t num_contexts, |
99 | 63.8k | const std::vector<std::vector<Token>>& tokens, const LZ77Params& lz77) { |
100 | 63.8k | std::vector<std::vector<Token>> tokens_lz77(tokens.size()); |
101 | | // TODO(veluca): tune heuristics here. |
102 | 63.8k | SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77); |
103 | 63.8k | float bit_decrease = 0; |
104 | 63.8k | size_t total_symbols = 0; |
105 | 63.8k | std::vector<float> sym_cost; |
106 | 63.8k | HybridUintConfig uint_config; |
107 | 173k | for (size_t stream = 0; stream < tokens.size(); stream++) { |
108 | 109k | size_t distance_multiplier = |
109 | 109k | params.image_widths.size() > stream ? params.image_widths[stream] : 0; |
110 | 109k | const auto& in = tokens[stream]; |
111 | 109k | auto& out = tokens_lz77[stream]; |
112 | 109k | total_symbols += in.size(); |
113 | | // Cumulative sum of bit costs. |
114 | 109k | sym_cost.resize(in.size() + 1); |
115 | 101M | for (size_t i = 0; i < in.size(); i++) { |
116 | 101M | uint32_t tok, nbits, unused_bits; |
117 | 101M | uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits); |
118 | 101M | sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i]; |
119 | 101M | } |
120 | 109k | out.reserve(in.size()); |
121 | 38.3M | for (size_t i = 0; i < in.size(); i++) { |
122 | 38.2M | size_t num_to_copy = 0; |
123 | 38.2M | size_t distance_symbol = 0; // 1 for RLE. |
124 | 38.2M | if (distance_multiplier != 0) { |
125 | 15.3M | distance_symbol = 1; // Special distance 1 if enabled. |
126 | 15.3M | JXL_DASSERT(kSpecialDistances[1][0] == 1); |
127 | 15.3M | JXL_DASSERT(kSpecialDistances[1][1] == 0); |
128 | 15.3M | } |
129 | 38.2M | if (i > 0) { |
130 | 107M | for (; i + num_to_copy < in.size(); num_to_copy++) { |
131 | 107M | if (in[i + num_to_copy].value != in[i - 1].value) { |
132 | 38.1M | break; |
133 | 38.1M | } |
134 | 107M | } |
135 | 38.1M | } |
136 | 38.2M | if (num_to_copy == 0) { |
137 | 32.4M | out.push_back(in[i]); |
138 | 32.4M | continue; |
139 | 32.4M | } |
140 | 5.79M | float cost = sym_cost[i + num_to_copy] - sym_cost[i]; |
141 | | // This subtraction might overflow, but that's OK. |
142 | 5.79M | size_t lz77_len = num_to_copy - lz77.min_length; |
143 | 5.79M | float lz77_cost = num_to_copy >= lz77.min_length |
144 | 5.79M | ? CeilLog2Nonzero(lz77_len + 1) + 1 |
145 | 5.79M | : 0; |
146 | 5.79M | if (num_to_copy < lz77.min_length || cost <= lz77_cost) { |
147 | 25.6M | for (size_t j = 0; j < num_to_copy; j++) { |
148 | 20.5M | out.push_back(in[i + j]); |
149 | 20.5M | } |
150 | 5.04M | i += num_to_copy - 1; |
151 | 5.04M | continue; |
152 | 5.04M | } |
153 | | // Output the LZ77 length |
154 | 751k | out.emplace_back(in[i].context, static_cast<uint32_t>(lz77_len)); |
155 | 751k | out.back().is_lz77_length = true; |
156 | 751k | i += num_to_copy - 1; |
157 | 751k | bit_decrease += cost - lz77_cost; |
158 | | // Output the LZ77 copy distance. |
159 | 751k | out.emplace_back( |
160 | 751k | static_cast<uint32_t>(lz77.nonserialized_distance_context), |
161 | 751k | static_cast<uint32_t>(distance_symbol)); |
162 | 751k | } |
163 | 109k | } |
164 | | |
165 | 63.8k | if (bit_decrease > total_symbols * 0.2 + 16) { |
166 | 6.84k | return tokens_lz77; |
167 | 6.84k | } |
168 | 56.9k | return {}; |
169 | 63.8k | } |
170 | | |
171 | | // Hash chain for LZ77 matching |
172 | | struct HashChain { |
173 | | size_t size_; |
174 | | std::vector<uint32_t> data_; |
175 | | |
176 | | unsigned hash_num_values_ = 32768; |
177 | | unsigned hash_mask_ = hash_num_values_ - 1; |
178 | | unsigned hash_shift_ = 5; |
179 | | |
180 | | std::vector<int> head; |
181 | | std::vector<uint32_t> chain; |
182 | | std::vector<int> val; |
183 | | |
184 | | // Speed up repetitions of zero |
185 | | std::vector<int> headz; |
186 | | std::vector<uint32_t> chainz; |
187 | | std::vector<uint32_t> zeros; |
188 | | uint32_t numzeros = 0; |
189 | | |
190 | | size_t window_size_; |
191 | | size_t window_mask_; |
192 | | size_t min_length_; |
193 | | size_t max_length_; |
194 | | |
195 | | // Map of special distance codes. |
196 | | std::unordered_map<int, int> special_dist_table_; |
197 | | size_t num_special_distances_ = 0; |
198 | | |
199 | | uint32_t maxchainlength = 256; // window_size_ to allow all |
200 | | |
201 | | HashChain(const Token* data, size_t size, size_t window_size, |
202 | | size_t min_length, size_t max_length, size_t distance_multiplier) |
203 | 0 | : size_(size), |
204 | 0 | window_size_(window_size), |
205 | 0 | window_mask_(window_size - 1), |
206 | 0 | min_length_(min_length), |
207 | 0 | max_length_(max_length) { |
208 | 0 | data_.resize(size); |
209 | 0 | for (size_t i = 0; i < size; i++) { |
210 | 0 | data_[i] = data[i].value; |
211 | 0 | } |
212 | |
|
213 | 0 | head.resize(hash_num_values_, -1); |
214 | 0 | val.resize(window_size_, -1); |
215 | 0 | chain.resize(window_size_); |
216 | 0 | for (uint32_t i = 0; i < window_size_; ++i) { |
217 | 0 | chain[i] = i; // same value as index indicates uninitialized |
218 | 0 | } |
219 | |
|
220 | 0 | zeros.resize(window_size_); |
221 | 0 | headz.resize(window_size_ + 1, -1); |
222 | 0 | chainz.resize(window_size_); |
223 | 0 | for (uint32_t i = 0; i < window_size_; ++i) { |
224 | 0 | chainz[i] = i; |
225 | 0 | } |
226 | | // Translate distance to special distance code. |
227 | 0 | if (distance_multiplier) { |
228 | | // Count down, so if due to small distance multiplier multiple distances |
229 | | // map to the same code, the smallest code will be used in the end. |
230 | 0 | for (int i = kNumSpecialDistances - 1; i >= 0; --i) { |
231 | 0 | special_dist_table_[SpecialDistance(i, distance_multiplier)] = i; |
232 | 0 | } |
233 | 0 | num_special_distances_ = kNumSpecialDistances; |
234 | 0 | } |
235 | 0 | } |
236 | | |
237 | 0 | uint32_t GetHash(size_t pos) const { |
238 | 0 | uint32_t result = 0; |
239 | 0 | if (pos + 2 < size_) { |
240 | | // TODO(lode): take the MSB's of the uint32_t values into account as well, |
241 | | // given that the hash code itself is less than 32 bits. |
242 | 0 | result ^= static_cast<uint32_t>(data_[pos + 0] << 0u); |
243 | 0 | result ^= static_cast<uint32_t>(data_[pos + 1] << hash_shift_); |
244 | 0 | result ^= static_cast<uint32_t>(data_[pos + 2] << (hash_shift_ * 2)); |
245 | 0 | } else { |
246 | | // No need to compute hash of last 2 bytes, the length 2 is too short. |
247 | 0 | return 0; |
248 | 0 | } |
249 | 0 | return result & hash_mask_; |
250 | 0 | } |
251 | | |
252 | 0 | uint32_t CountZeros(size_t pos, uint32_t prevzeros) const { |
253 | 0 | size_t end = pos + window_size_; |
254 | 0 | if (end > size_) end = size_; |
255 | 0 | if (prevzeros > 0) { |
256 | 0 | if (prevzeros >= window_mask_ && data_[end - 1] == 0 && |
257 | 0 | end == pos + window_size_) { |
258 | 0 | return prevzeros; |
259 | 0 | } else { |
260 | 0 | return prevzeros - 1; |
261 | 0 | } |
262 | 0 | } |
263 | 0 | uint32_t num = 0; |
264 | 0 | while (pos + num < end && data_[pos + num] == 0) num++; |
265 | 0 | return num; |
266 | 0 | } |
267 | | |
268 | 0 | void Update(size_t pos) { |
269 | 0 | uint32_t hashval = GetHash(pos); |
270 | 0 | uint32_t wpos = pos & window_mask_; |
271 | |
|
272 | 0 | val[wpos] = static_cast<int>(hashval); |
273 | 0 | if (head[hashval] != -1) chain[wpos] = head[hashval]; |
274 | 0 | head[hashval] = wpos; |
275 | |
|
276 | 0 | if (pos > 0 && data_[pos] != data_[pos - 1]) numzeros = 0; |
277 | 0 | numzeros = CountZeros(pos, numzeros); |
278 | |
|
279 | 0 | zeros[wpos] = numzeros; |
280 | 0 | if (headz[numzeros] != -1) chainz[wpos] = headz[numzeros]; |
281 | 0 | headz[numzeros] = wpos; |
282 | 0 | } |
283 | | |
284 | 0 | void Update(size_t pos, size_t len) { |
285 | 0 | for (size_t i = 0; i < len; i++) { |
286 | 0 | Update(pos + i); |
287 | 0 | } |
288 | 0 | } |
289 | | |
290 | | template <typename CB> |
291 | 0 | void FindMatches(size_t pos, int max_dist, const CB& found_match) const { |
292 | 0 | uint32_t wpos = pos & window_mask_; |
293 | 0 | uint32_t hashval = GetHash(pos); |
294 | 0 | uint32_t hashpos = chain[wpos]; |
295 | |
|
296 | 0 | int prev_dist = 0; |
297 | 0 | int end = std::min<int>(pos + max_length_, size_); |
298 | 0 | uint32_t chainlength = 0; |
299 | 0 | uint32_t best_len = 0; |
300 | 0 | for (;;) { |
301 | 0 | int dist = (hashpos <= wpos) ? (wpos - hashpos) |
302 | 0 | : (wpos - hashpos + window_mask_ + 1); |
303 | 0 | if (dist < prev_dist) break; |
304 | 0 | prev_dist = dist; |
305 | 0 | uint32_t len = 0; |
306 | 0 | if (dist > 0) { |
307 | 0 | int i = pos; |
308 | 0 | int j = pos - dist; |
309 | 0 | if (numzeros > 3) { |
310 | 0 | int r = std::min<int>(numzeros - 1, zeros[hashpos]); |
311 | 0 | if (i + r >= end) r = end - i - 1; |
312 | 0 | i += r; |
313 | 0 | j += r; |
314 | 0 | } |
315 | 0 | while (i < end && data_[i] == data_[j]) { |
316 | 0 | i++; |
317 | 0 | j++; |
318 | 0 | } |
319 | 0 | len = i - pos; |
320 | | // This can trigger even if the new length is slightly smaller than the |
321 | | // best length, because it is possible for a slightly cheaper distance |
322 | | // symbol to occur. |
323 | 0 | if (len >= min_length_ && len + 2 >= best_len) { |
324 | 0 | auto it = special_dist_table_.find(dist); |
325 | 0 | int dist_symbol = (it == special_dist_table_.end()) |
326 | 0 | ? (num_special_distances_ + dist - 1) |
327 | 0 | : it->second; |
328 | 0 | found_match(len, dist_symbol); |
329 | 0 | if (len > best_len) best_len = len; |
330 | 0 | } |
331 | 0 | } |
332 | |
|
333 | 0 | chainlength++; |
334 | 0 | if (chainlength >= maxchainlength) break; |
335 | | |
336 | 0 | if (numzeros >= 3 && len > numzeros) { |
337 | 0 | if (hashpos == chainz[hashpos]) break; |
338 | 0 | hashpos = chainz[hashpos]; |
339 | 0 | if (zeros[hashpos] != numzeros) break; |
340 | 0 | } else { |
341 | 0 | if (hashpos == chain[hashpos]) break; |
342 | 0 | hashpos = chain[hashpos]; |
343 | 0 | if (val[hashpos] != static_cast<int>(hashval)) { |
344 | | // outdated hash value |
345 | 0 | break; |
346 | 0 | } |
347 | 0 | } |
348 | 0 | } |
349 | 0 | } Unexecuted instantiation: enc_lz77.cc:void jxl::(anonymous namespace)::HashChain::FindMatches<jxl::(anonymous namespace)::HashChain::FindMatch(unsigned long, int, unsigned long*, unsigned long*) const::{lambda(unsigned long, unsigned long)#1}>(unsigned long, int, jxl::(anonymous namespace)::HashChain::FindMatch(unsigned long, int, unsigned long*, unsigned long*) const::{lambda(unsigned long, unsigned long)#1} const&) constUnexecuted instantiation: enc_lz77.cc:void jxl::(anonymous namespace)::HashChain::FindMatches<jxl::(anonymous namespace)::ApplyLZ77_Optimal(jxl::HistogramParams const&, unsigned long, std::__1::vector<std::__1::vector<jxl::Token, std::__1::allocator<jxl::Token> >, std::__1::allocator<std::__1::vector<jxl::Token, std::__1::allocator<jxl::Token> > > > const&, jxl::LZ77Params const&)::$_0>(unsigned long, int, jxl::(anonymous namespace)::ApplyLZ77_Optimal(jxl::HistogramParams const&, unsigned long, std::__1::vector<std::__1::vector<jxl::Token, std::__1::allocator<jxl::Token> >, std::__1::allocator<std::__1::vector<jxl::Token, std::__1::allocator<jxl::Token> > > > const&, jxl::LZ77Params const&)::$_0 const&) const |
350 | | void FindMatch(size_t pos, int max_dist, size_t* result_dist_symbol, |
351 | 0 | size_t* result_len) const { |
352 | 0 | *result_dist_symbol = 0; |
353 | 0 | *result_len = 1; |
354 | 0 | FindMatches(pos, max_dist, [&](size_t len, size_t dist_symbol) { |
355 | 0 | if (len > *result_len || |
356 | 0 | (len == *result_len && *result_dist_symbol > dist_symbol)) { |
357 | 0 | *result_len = len; |
358 | 0 | *result_dist_symbol = dist_symbol; |
359 | 0 | } |
360 | 0 | }); |
361 | 0 | } |
362 | | }; |
363 | | |
364 | 0 | float LenCost(size_t len) { |
365 | 0 | uint32_t nbits, bits, tok; |
366 | 0 | HybridUintConfig(1, 0, 0).Encode(len, &tok, &nbits, &bits); |
367 | 0 | constexpr float kCostTable[] = { |
368 | 0 | 2.797667318563126, 3.213177690381199, 2.5706009246743737, |
369 | 0 | 2.408392498667534, 2.829649191872326, 3.3923087753324577, |
370 | 0 | 4.029267451554331, 4.415576699706408, 4.509357574741465, |
371 | 0 | 9.21481543803004, 10.020590190114898, 11.858671627804766, |
372 | 0 | 12.45853300490526, 11.713105831990857, 12.561996324849314, |
373 | 0 | 13.775477692278367, 13.174027068768641, |
374 | 0 | }; |
375 | 0 | size_t table_size = sizeof kCostTable / sizeof *kCostTable; |
376 | 0 | if (tok >= table_size) tok = table_size - 1; |
377 | 0 | return kCostTable[tok] + nbits; |
378 | 0 | } |
379 | | |
380 | | // TODO(veluca): this does not take into account usage or non-usage of distance |
381 | | // multipliers. |
382 | 0 | float DistCost(size_t dist) { |
383 | 0 | uint32_t nbits, bits, tok; |
384 | 0 | HybridUintConfig(7, 0, 0).Encode(dist, &tok, &nbits, &bits); |
385 | 0 | constexpr float kCostTable[] = { |
386 | 0 | 6.368282626312716, 5.680793277090298, 8.347404197105247, |
387 | 0 | 7.641619201599141, 6.914328374119438, 7.959808291537444, |
388 | 0 | 8.70023120759855, 8.71378518934703, 9.379132523982769, |
389 | 0 | 9.110472749092708, 9.159029569270908, 9.430936766731973, |
390 | 0 | 7.278284055315169, 7.8278514904267755, 10.026641158289236, |
391 | 0 | 9.976049229827066, 9.64351607048908, 9.563403863480442, |
392 | 0 | 10.171474111762747, 10.45950155077234, 9.994813912104219, |
393 | 0 | 10.322524683741156, 8.465808729388186, 8.756254166066853, |
394 | 0 | 10.160930174662234, 10.247329273413435, 10.04090403724809, |
395 | 0 | 10.129398517544082, 9.342311691539546, 9.07608009102374, |
396 | 0 | 10.104799540677513, 10.378079384990906, 10.165828974075072, |
397 | 0 | 10.337595322341553, 7.940557464567944, 10.575665823319431, |
398 | 0 | 11.023344321751955, 10.736144698831827, 11.118277044595054, |
399 | 0 | 7.468468230648442, 10.738305230932939, 10.906980780216568, |
400 | 0 | 10.163468216353817, 10.17805759656433, 11.167283670483565, |
401 | 0 | 11.147050200274544, 10.517921919244333, 10.651764778156886, |
402 | 0 | 10.17074446448919, 11.217636876224745, 11.261630721139484, |
403 | 0 | 11.403140815247259, 10.892472096873417, 11.1859607804481, |
404 | 0 | 8.017346947551262, 7.895143720278828, 11.036577113822025, |
405 | 0 | 11.170562110315794, 10.326988722591086, 10.40872184751056, |
406 | 0 | 11.213498225466386, 11.30580635516863, 10.672272515665442, |
407 | 0 | 10.768069466228063, 11.145257364153565, 11.64668307145549, |
408 | 0 | 10.593156194627339, 11.207499484844943, 10.767517766396908, |
409 | 0 | 10.826629811407042, 10.737764794499988, 10.6200448518045, |
410 | 0 | 10.191315385198092, 8.468384171390085, 11.731295299170432, |
411 | 0 | 11.824619886654398, 10.41518844301179, 10.16310536548649, |
412 | 0 | 10.539423685097576, 10.495136599328031, 10.469112847728267, |
413 | 0 | 11.72057686174922, 10.910326337834674, 11.378921834673758, |
414 | 0 | 11.847759036098536, 11.92071647623854, 10.810628276345282, |
415 | 0 | 11.008601085273893, 11.910326337834674, 11.949212023423133, |
416 | 0 | 11.298614839104337, 11.611603659010392, 10.472930394619985, |
417 | 0 | 11.835564720850282, 11.523267392285337, 12.01055816679611, |
418 | 0 | 8.413029688994023, 11.895784139536406, 11.984679534970505, |
419 | 0 | 11.220654278717394, 11.716311684833672, 10.61036646226114, |
420 | 0 | 10.89849965960364, 10.203762898863669, 10.997560826267238, |
421 | 0 | 11.484217379438984, 11.792836176993665, 12.24310468755171, |
422 | 0 | 11.464858097919262, 12.212747017409377, 11.425595666074955, |
423 | 0 | 11.572048533398757, 12.742093965163013, 11.381874288645637, |
424 | 0 | 12.191870445817015, 11.683156920035426, 11.152442115262197, |
425 | 0 | 11.90303691580457, 11.653292787169159, 11.938615382266098, |
426 | 0 | 16.970641701570223, 16.853602280380002, 17.26240782594733, |
427 | 0 | 16.644655390108507, 17.14310889757499, 16.910935455445955, |
428 | 0 | 17.505678976959697, 17.213498225466388, 2.4162310293553024, |
429 | 0 | 3.494587244462329, 3.5258600986408344, 3.4959806589517095, |
430 | 0 | 3.098390886949687, 3.343454654302911, 3.588847442290287, |
431 | 0 | 4.14614790111827, 5.152948641990529, 7.433696808092598, |
432 | 0 | 9.716311684833672, |
433 | 0 | }; |
434 | 0 | size_t table_size = sizeof kCostTable / sizeof *kCostTable; |
435 | 0 | if (tok >= table_size) tok = table_size - 1; |
436 | 0 | return kCostTable[tok] + nbits; |
437 | 0 | } |
438 | | |
439 | | std::vector<std::vector<Token>> ApplyLZ77_LZ77( |
440 | | const HistogramParams& params, size_t num_contexts, |
441 | 0 | const std::vector<std::vector<Token>>& tokens, const LZ77Params& lz77) { |
442 | 0 | std::vector<std::vector<Token>> tokens_lz77(tokens.size()); |
443 | | // TODO(veluca): tune heuristics here. |
444 | 0 | SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77); |
445 | 0 | float bit_decrease = 0; |
446 | 0 | size_t total_symbols = 0; |
447 | 0 | HybridUintConfig uint_config; |
448 | 0 | std::vector<float> sym_cost; |
449 | 0 | for (size_t stream = 0; stream < tokens.size(); stream++) { |
450 | 0 | size_t distance_multiplier = |
451 | 0 | params.image_widths.size() > stream ? params.image_widths[stream] : 0; |
452 | 0 | const auto& in = tokens[stream]; |
453 | 0 | auto& out = tokens_lz77[stream]; |
454 | 0 | total_symbols += in.size(); |
455 | | // Cumulative sum of bit costs. |
456 | 0 | sym_cost.resize(in.size() + 1); |
457 | 0 | for (size_t i = 0; i < in.size(); i++) { |
458 | 0 | uint32_t tok, nbits, unused_bits; |
459 | 0 | uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits); |
460 | 0 | sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i]; |
461 | 0 | } |
462 | |
|
463 | 0 | out.reserve(in.size()); |
464 | 0 | size_t max_distance = in.size(); |
465 | 0 | size_t min_length = lz77.min_length; |
466 | 0 | JXL_DASSERT(min_length >= 3); |
467 | 0 | size_t max_length = in.size(); |
468 | | |
469 | | // Use next power of two as window size. |
470 | 0 | size_t window_size = 1; |
471 | 0 | while (window_size < max_distance && window_size < kWindowSize) { |
472 | 0 | window_size <<= 1; |
473 | 0 | } |
474 | |
|
475 | 0 | HashChain chain(in.data(), in.size(), window_size, min_length, max_length, |
476 | 0 | distance_multiplier); |
477 | 0 | size_t len; |
478 | 0 | size_t dist_symbol; |
479 | |
|
480 | 0 | const size_t max_lazy_match_len = 256; // 0 to disable lazy matching |
481 | | |
482 | | // Whether the next symbol was already updated (to test lazy matching) |
483 | 0 | bool already_updated = false; |
484 | 0 | for (size_t i = 0; i < in.size(); i++) { |
485 | 0 | out.push_back(in[i]); |
486 | 0 | if (!already_updated) chain.Update(i); |
487 | 0 | already_updated = false; |
488 | 0 | chain.FindMatch(i, max_distance, &dist_symbol, &len); |
489 | 0 | if (len >= min_length) { |
490 | 0 | if (len < max_lazy_match_len && i + 1 < in.size()) { |
491 | | // Try length at next symbol lazy matching |
492 | 0 | chain.Update(i + 1); |
493 | 0 | already_updated = true; |
494 | 0 | size_t len2, dist_symbol2; |
495 | 0 | chain.FindMatch(i + 1, max_distance, &dist_symbol2, &len2); |
496 | 0 | if (len2 > len) { |
497 | | // Use the lazy match. Add literal, and use the next length starting |
498 | | // from the next byte. |
499 | 0 | ++i; |
500 | 0 | already_updated = false; |
501 | 0 | len = len2; |
502 | 0 | dist_symbol = dist_symbol2; |
503 | 0 | out.push_back(in[i]); |
504 | 0 | } |
505 | 0 | } |
506 | |
|
507 | 0 | float cost = sym_cost[i + len] - sym_cost[i]; |
508 | 0 | size_t lz77_len = len - lz77.min_length; |
509 | 0 | float lz77_cost = LenCost(lz77_len) + DistCost(dist_symbol) + |
510 | 0 | sce.AddSymbolCost(out.back().context); |
511 | |
|
512 | 0 | if (lz77_cost <= cost) { |
513 | 0 | out.back().value = len - min_length; |
514 | 0 | out.back().is_lz77_length = true; |
515 | 0 | out.emplace_back( |
516 | 0 | static_cast<uint32_t>(lz77.nonserialized_distance_context), |
517 | 0 | static_cast<uint32_t>(dist_symbol)); |
518 | 0 | bit_decrease += cost - lz77_cost; |
519 | 0 | } else { |
520 | | // LZ77 match ignored, and symbol already pushed. Push all other |
521 | | // symbols and skip. |
522 | 0 | for (size_t j = 1; j < len; j++) { |
523 | 0 | out.push_back(in[i + j]); |
524 | 0 | } |
525 | 0 | } |
526 | |
|
527 | 0 | if (already_updated) { |
528 | 0 | chain.Update(i + 2, len - 2); |
529 | 0 | already_updated = false; |
530 | 0 | } else { |
531 | 0 | chain.Update(i + 1, len - 1); |
532 | 0 | } |
533 | 0 | i += len - 1; |
534 | 0 | } else { |
535 | | // Literal, already pushed |
536 | 0 | } |
537 | 0 | } |
538 | 0 | } |
539 | | |
540 | 0 | if (bit_decrease > total_symbols * 0.2 + 16) { |
541 | 0 | return tokens_lz77; |
542 | 0 | } |
543 | 0 | return {}; |
544 | 0 | } |
545 | | |
546 | | std::vector<std::vector<Token>> ApplyLZ77_Optimal( |
547 | | const HistogramParams& params, size_t num_contexts, |
548 | 0 | const std::vector<std::vector<Token>>& tokens, const LZ77Params& lz77) { |
549 | 0 | std::vector<std::vector<Token>> tokens_for_cost_estimate = |
550 | 0 | ApplyLZ77_LZ77(params, num_contexts, tokens, lz77); |
551 | | // If greedy-LZ77 does not give better compression than no-lz77, no reason to |
552 | | // run the optimal matching. |
553 | 0 | if (tokens_for_cost_estimate.empty()) return {}; |
554 | 0 | SymbolCostEstimator sce(num_contexts + 1, params.force_huffman, |
555 | 0 | tokens_for_cost_estimate, lz77); |
556 | 0 | std::vector<std::vector<Token>> tokens_lz77(tokens.size()); |
557 | 0 | HybridUintConfig uint_config; |
558 | 0 | std::vector<float> sym_cost; |
559 | 0 | std::vector<uint32_t> dist_symbols; |
560 | 0 | for (size_t stream = 0; stream < tokens.size(); stream++) { |
561 | 0 | size_t distance_multiplier = |
562 | 0 | params.image_widths.size() > stream ? params.image_widths[stream] : 0; |
563 | 0 | const auto& in = tokens[stream]; |
564 | 0 | auto& out = tokens_lz77[stream]; |
565 | | // Cumulative sum of bit costs. |
566 | 0 | sym_cost.resize(in.size() + 1); |
567 | 0 | for (size_t i = 0; i < in.size(); i++) { |
568 | 0 | uint32_t tok, nbits, unused_bits; |
569 | 0 | uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits); |
570 | 0 | sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i]; |
571 | 0 | } |
572 | |
|
573 | 0 | out.reserve(in.size()); |
574 | 0 | size_t max_distance = in.size(); |
575 | 0 | size_t min_length = lz77.min_length; |
576 | 0 | JXL_DASSERT(min_length >= 3); |
577 | 0 | size_t max_length = in.size(); |
578 | | |
579 | | // Use next power of two as window size. |
580 | 0 | size_t window_size = 1; |
581 | 0 | while (window_size < max_distance && window_size < kWindowSize) { |
582 | 0 | window_size <<= 1; |
583 | 0 | } |
584 | |
|
585 | 0 | HashChain chain(in.data(), in.size(), window_size, min_length, max_length, |
586 | 0 | distance_multiplier); |
587 | |
|
588 | 0 | struct MatchInfo { |
589 | 0 | uint32_t len; |
590 | 0 | uint32_t dist_symbol; |
591 | 0 | uint32_t ctx; |
592 | 0 | float total_cost = std::numeric_limits<float>::max(); |
593 | 0 | }; |
594 | | // Total cost to encode the first N symbols. |
595 | 0 | std::vector<MatchInfo> prefix_costs(in.size() + 1); |
596 | 0 | prefix_costs[0].total_cost = 0; |
597 | |
|
598 | 0 | size_t rle_length = 0; |
599 | 0 | size_t skip_lz77 = 0; |
600 | 0 | for (size_t i = 0; i < in.size(); i++) { |
601 | 0 | chain.Update(i); |
602 | 0 | float lit_cost = |
603 | 0 | prefix_costs[i].total_cost + sym_cost[i + 1] - sym_cost[i]; |
604 | 0 | if (prefix_costs[i + 1].total_cost > lit_cost) { |
605 | 0 | prefix_costs[i + 1].dist_symbol = 0; |
606 | 0 | prefix_costs[i + 1].len = 1; |
607 | 0 | prefix_costs[i + 1].ctx = in[i].context; |
608 | 0 | prefix_costs[i + 1].total_cost = lit_cost; |
609 | 0 | } |
610 | 0 | if (skip_lz77 > 0) { |
611 | 0 | skip_lz77--; |
612 | 0 | continue; |
613 | 0 | } |
614 | 0 | dist_symbols.clear(); |
615 | 0 | chain.FindMatches(i, max_distance, |
616 | 0 | [&dist_symbols](size_t len, size_t dist_symbol) { |
617 | 0 | if (dist_symbols.size() <= len) { |
618 | 0 | dist_symbols.resize(len + 1, dist_symbol); |
619 | 0 | } |
620 | 0 | if (dist_symbol < dist_symbols[len]) { |
621 | 0 | dist_symbols[len] = dist_symbol; |
622 | 0 | } |
623 | 0 | }); |
624 | 0 | if (dist_symbols.size() <= min_length) continue; |
625 | 0 | { |
626 | 0 | size_t best_cost = dist_symbols.back(); |
627 | 0 | for (size_t j = dist_symbols.size() - 1; j >= min_length; j--) { |
628 | 0 | if (dist_symbols[j] < best_cost) { |
629 | 0 | best_cost = dist_symbols[j]; |
630 | 0 | } |
631 | 0 | dist_symbols[j] = best_cost; |
632 | 0 | } |
633 | 0 | } |
634 | 0 | for (size_t j = min_length; j < dist_symbols.size(); j++) { |
635 | | // Cost model that uses results from lazy LZ77. |
636 | 0 | float lz77_cost = sce.LenCost(in[i].context, j - min_length, lz77) + |
637 | 0 | sce.DistCost(dist_symbols[j], lz77); |
638 | 0 | float cost = prefix_costs[i].total_cost + lz77_cost; |
639 | 0 | if (prefix_costs[i + j].total_cost > cost) { |
640 | 0 | prefix_costs[i + j].len = j; |
641 | 0 | prefix_costs[i + j].dist_symbol = dist_symbols[j] + 1; |
642 | 0 | prefix_costs[i + j].ctx = in[i].context; |
643 | 0 | prefix_costs[i + j].total_cost = cost; |
644 | 0 | } |
645 | 0 | } |
646 | | // We are in a RLE sequence: skip all the symbols except the first 8 and |
647 | | // the last 8. This avoid quadratic costs for sequences with long runs of |
648 | | // the same symbol. |
649 | 0 | if ((dist_symbols.back() == 0 && distance_multiplier == 0) || |
650 | 0 | (dist_symbols.back() == 1 && distance_multiplier != 0)) { |
651 | 0 | rle_length++; |
652 | 0 | } else { |
653 | 0 | rle_length = 0; |
654 | 0 | } |
655 | 0 | if (rle_length >= 8 && dist_symbols.size() > 9) { |
656 | 0 | skip_lz77 = dist_symbols.size() - 10; |
657 | 0 | rle_length = 0; |
658 | 0 | } |
659 | 0 | } |
660 | 0 | size_t pos = in.size(); |
661 | 0 | while (pos > 0) { |
662 | 0 | bool is_lz77_length = prefix_costs[pos].dist_symbol != 0; |
663 | 0 | if (is_lz77_length) { |
664 | 0 | size_t dist_symbol = prefix_costs[pos].dist_symbol - 1; |
665 | 0 | out.emplace_back( |
666 | 0 | static_cast<uint32_t>(lz77.nonserialized_distance_context), |
667 | 0 | static_cast<uint32_t>(dist_symbol)); |
668 | 0 | } |
669 | 0 | uint32_t val = |
670 | 0 | is_lz77_length |
671 | 0 | ? (prefix_costs[pos].len - static_cast<uint32_t>(min_length)) |
672 | 0 | : in[pos - 1].value; |
673 | 0 | out.emplace_back(prefix_costs[pos].ctx, val); |
674 | 0 | out.back().is_lz77_length = is_lz77_length; |
675 | 0 | pos -= prefix_costs[pos].len; |
676 | 0 | } |
677 | 0 | if (!out.empty()) std::reverse(out.begin(), out.end()); |
678 | 0 | } |
679 | 0 | return tokens_lz77; |
680 | 0 | } |
681 | | |
682 | | } // namespace |
683 | | |
684 | | std::vector<std::vector<Token>> ApplyLZ77( |
685 | | const HistogramParams& params, size_t num_contexts, |
686 | 77.4k | const std::vector<std::vector<Token>>& tokens, const LZ77Params& lz77) { |
687 | 77.4k | switch (params.lz77_method) { |
688 | 63.8k | case HistogramParams::LZ77Method::kRLE: |
689 | 63.8k | return ApplyLZ77_RLE(params, num_contexts, tokens, lz77); |
690 | 0 | case HistogramParams::LZ77Method::kLZ77: |
691 | 0 | return ApplyLZ77_LZ77(params, num_contexts, tokens, lz77); |
692 | 0 | case HistogramParams::LZ77Method::kOptimal: |
693 | 0 | return ApplyLZ77_Optimal(params, num_contexts, tokens, lz77); |
694 | 13.5k | default: |
695 | 13.5k | return {}; |
696 | 77.4k | } |
697 | 77.4k | } |
698 | | |
699 | | } // namespace jxl |