/src/libjxl/lib/jxl/dec_ans.h
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright (c) the JPEG XL Project Authors. All rights reserved. |
2 | | // |
3 | | // Use of this source code is governed by a BSD-style |
4 | | // license that can be found in the LICENSE file. |
5 | | |
6 | | #ifndef LIB_JXL_DEC_ANS_H_ |
7 | | #define LIB_JXL_DEC_ANS_H_ |
8 | | |
9 | | // Library to decode the ANS population counts from the bit-stream and build a |
10 | | // decoding table from them. |
11 | | |
12 | | #include <jxl/memory_manager.h> |
13 | | #include <jxl/types.h> |
14 | | |
15 | | #include <algorithm> |
16 | | #include <cstddef> |
17 | | #include <cstdint> |
18 | | #include <cstring> |
19 | | #include <vector> |
20 | | |
21 | | #include "lib/jxl/ans_common.h" |
22 | | #include "lib/jxl/ans_params.h" |
23 | | #include "lib/jxl/base/bits.h" |
24 | | #include "lib/jxl/base/compiler_specific.h" |
25 | | #include "lib/jxl/base/status.h" |
26 | | #include "lib/jxl/dec_bit_reader.h" |
27 | | #include "lib/jxl/dec_huffman.h" |
28 | | #include "lib/jxl/field_encodings.h" |
29 | | #include "lib/jxl/memory_manager_internal.h" |
30 | | |
31 | | namespace jxl { |
32 | | |
33 | | class ANSSymbolReader; |
34 | | |
35 | | // Experiments show that best performance is typically achieved for a |
36 | | // split-exponent of 3 or 4. Trend seems to be that '4' is better |
37 | | // for large-ish pictures, and '3' better for rather small-ish pictures. |
38 | | // This is plausible - the more special symbols we have, the better |
39 | | // statistics we need to get a benefit out of them. |
40 | | |
41 | | // Our hybrid-encoding scheme has dedicated tokens for the smallest |
42 | | // (1 << split_exponent) numbers, and for the rest |
43 | | // encodes (number of bits) + (msb_in_token sub-leading binary digits) + |
44 | | // (lsb_in_token lowest binary digits) in the token, with the remaining bits |
45 | | // then being encoded as data. |
46 | | // |
47 | | // Example with split_exponent = 4, msb_in_token = 2, lsb_in_token = 0. |
48 | | // |
49 | | // Numbers N in [0 .. 15]: |
50 | | // These get represented as (token=N, bits=''). |
51 | | // Numbers N >= 16: |
52 | | // If n is such that 2**n <= N < 2**(n+1), |
53 | | // and m = N - 2**n is the 'mantissa', |
54 | | // these get represented as: |
55 | | // (token=split_token + |
56 | | // ((n - split_exponent) * 4) + |
57 | | // (m >> (n - msb_in_token)), |
58 | | // bits=m & ((1 << (n - msb_in_token)) - 1)) |
59 | | // Specifically, we would get: |
60 | | // N = 0 - 15: (token=N, nbits=0, bits='') |
61 | | // N = 16 (10000): (token=16, nbits=2, bits='00') |
62 | | // N = 17 (10001): (token=16, nbits=2, bits='01') |
63 | | // N = 20 (10100): (token=17, nbits=2, bits='00') |
64 | | // N = 24 (11000): (token=18, nbits=2, bits='00') |
65 | | // N = 28 (11100): (token=19, nbits=2, bits='00') |
66 | | // N = 32 (100000): (token=20, nbits=3, bits='000') |
67 | | // N = 65535: (token=63, nbits=13, bits='1111111111111') |
68 | | struct HybridUintConfig { |
69 | | uint32_t split_exponent; |
70 | | uint32_t split_token; |
71 | | uint32_t msb_in_token; |
72 | | uint32_t lsb_in_token; |
73 | | JXL_INLINE void Encode(uint32_t value, uint32_t* JXL_RESTRICT token, |
74 | | uint32_t* JXL_RESTRICT nbits, |
75 | 155M | uint32_t* JXL_RESTRICT bits) const { |
76 | 155M | if (value < split_token) { |
77 | 122M | *token = value; |
78 | 122M | *nbits = 0; |
79 | 122M | *bits = 0; |
80 | 122M | } else { |
81 | 32.5M | uint32_t n = FloorLog2Nonzero(value); |
82 | 32.5M | uint32_t m = value - (1 << n); |
83 | 32.5M | *token = split_token + |
84 | 32.5M | ((n - split_exponent) << (msb_in_token + lsb_in_token)) + |
85 | 32.5M | ((m >> (n - msb_in_token)) << lsb_in_token) + |
86 | 32.5M | (m & ((1 << lsb_in_token) - 1)); |
87 | 32.5M | *nbits = n - msb_in_token - lsb_in_token; |
88 | 32.5M | *bits = (value >> lsb_in_token) & ((1UL << *nbits) - 1); |
89 | 32.5M | } |
90 | 155M | } |
91 | | |
92 | | explicit HybridUintConfig(uint32_t split_exponent = 4, |
93 | | uint32_t msb_in_token = 2, |
94 | | uint32_t lsb_in_token = 0) |
95 | 1.44M | : split_exponent(split_exponent), |
96 | 1.44M | split_token(1 << split_exponent), |
97 | 1.44M | msb_in_token(msb_in_token), |
98 | 1.44M | lsb_in_token(lsb_in_token) { |
99 | 1.44M | JXL_DASSERT(split_exponent >= msb_in_token + lsb_in_token); |
100 | 1.44M | } |
101 | | }; |
102 | | |
// LZ77 settings. Derives from Fields; the constructor and VisitFields are only
// declared here and are defined elsewhere.
struct LZ77Params : public Fields {
  LZ77Params();
  JXL_FIELDS_NAME(LZ77Params)
  Status VisitFields(Visitor* JXL_RESTRICT visitor) override;
  bool enabled;

  // Symbols above min_symbol use a special hybrid uint encoding and
  // represent a length, to be added to min_length.
  uint32_t min_symbol;
  uint32_t min_length;

  // Not serialized by VisitFields.
  HybridUintConfig length_uint_config{0, 0, 0};

  // NOTE(review): presumably the context used for LZ77 distance symbols and,
  // going by the name, not serialized either - confirm against the .cc file.
  size_t nonserialized_distance_context;
};
119 | | |
// Number of entries in the LZ77 window; ANSSymbolReader derives its index mask
// (kWindowSize - 1) from it, so it has to stay a power of two.
static constexpr size_t kWindowSize = 1 << 20;
// Number of rows in kSpecialDistances.
static constexpr size_t kNumSpecialDistances = 120;
// Table of special distance codes from WebP lossless.
// SpecialDistance() combines a row as [0] + multiplier * [1].
static constexpr int8_t kSpecialDistances[kNumSpecialDistances][2] = {
    {0, 1},  {1, 0},  {1, 1},  {-1, 1}, {0, 2},  {2, 0},  {1, 2},  {-1, 2},
    {2, 1},  {-2, 1}, {2, 2},  {-2, 2}, {0, 3},  {3, 0},  {1, 3},  {-1, 3},
    {3, 1},  {-3, 1}, {2, 3},  {-2, 3}, {3, 2},  {-3, 2}, {0, 4},  {4, 0},
    {1, 4},  {-1, 4}, {4, 1},  {-4, 1}, {3, 3},  {-3, 3}, {2, 4},  {-2, 4},
    {4, 2},  {-4, 2}, {0, 5},  {3, 4},  {-3, 4}, {4, 3},  {-4, 3}, {5, 0},
    {1, 5},  {-1, 5}, {5, 1},  {-5, 1}, {2, 5},  {-2, 5}, {5, 2},  {-5, 2},
    {4, 4},  {-4, 4}, {3, 5},  {-3, 5}, {5, 3},  {-5, 3}, {0, 6},  {6, 0},
    {1, 6},  {-1, 6}, {6, 1},  {-6, 1}, {2, 6},  {-2, 6}, {6, 2},  {-6, 2},
    {4, 5},  {-4, 5}, {5, 4},  {-5, 4}, {3, 6},  {-3, 6}, {6, 3},  {-6, 3},
    {0, 7},  {7, 0},  {1, 7},  {-1, 7}, {5, 5},  {-5, 5}, {7, 1},  {-7, 1},
    {4, 6},  {-4, 6}, {6, 4},  {-6, 4}, {2, 7},  {-2, 7}, {7, 2},  {-7, 2},
    {3, 7},  {-3, 7}, {7, 3},  {-7, 3}, {5, 6},  {-5, 6}, {6, 5},  {-6, 5},
    {8, 0},  {4, 7},  {-4, 7}, {7, 4},  {-7, 4}, {8, 1},  {8, 2},  {6, 6},
    {-6, 6}, {8, 3},  {5, 7},  {-5, 7}, {7, 5},  {-7, 5}, {8, 4},  {6, 7},
    {-6, 7}, {7, 6},  {-7, 6}, {8, 5},  {7, 7},  {-7, 7}, {8, 6},  {8, 7}};
139 | 683k | static JXL_INLINE int SpecialDistance(size_t index, int multiplier) { |
140 | 683k | int dist = kSpecialDistances[index][0] + |
141 | 683k | static_cast<int>(multiplier) * kSpecialDistances[index][1]; |
142 | 683k | return (dist > 1) ? dist : 1; |
143 | 683k | } Unexecuted instantiation: encode.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_icc_codec.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_ans.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_cluster.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_context_map.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_lz77.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_fast_lossless.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_frame.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_modular.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_patch_dictionary.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_dot_dictionary.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_detect_dots.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_debug_image.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_quant_weights.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_coeff_order.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_heuristics.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_adaptive_quantization.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_cache.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_group.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_splines.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_chroma_from_luma.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_ac_strategy.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_entropy_coder.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: 
enc_jpeg_data.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_encoding.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_ma.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_transform.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_rct.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_palette.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: coeff_order.cc:jxl::SpecialDistance(unsigned long, int) dec_ans.cc:jxl::SpecialDistance(unsigned long, int) Line | Count | Source | 139 | 683k | static JXL_INLINE int SpecialDistance(size_t index, int multiplier) { | 140 | 683k | int dist = kSpecialDistances[index][0] + | 141 | 683k | static_cast<int>(multiplier) * kSpecialDistances[index][1]; | 142 | 683k | return (dist > 1) ? dist : 1; | 143 | 683k | } |
Unexecuted instantiation: dec_cache.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: dec_context_map.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: dec_external_image.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: dec_frame.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: dec_group.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: dec_modular.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: dec_noise.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: dec_patch_dictionary.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: decode.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: epf.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: icc_codec.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: dec_ma.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: encoding.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: modular_image.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: transform.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: rct.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: palette.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: quant_weights.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: stage_blending.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: stage_epf.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: stage_write.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: splines.cc:jxl::SpecialDistance(unsigned long, int) |
144 | | |
145 | | struct ANSCode { |
146 | | AlignedMemory alias_tables; |
147 | | std::vector<HuffmanDecodingData> huffman_data; |
148 | | std::vector<HybridUintConfig> uint_config; |
149 | | std::vector<int> degenerate_symbols; |
150 | | bool use_prefix_code; |
151 | | uint8_t log_alpha_size; // for ANS. |
152 | | LZ77Params lz77; |
153 | | // Maximum number of bits necessary to represent the result of a |
154 | | // ReadHybridUint call done with this ANSCode. |
155 | | size_t max_num_bits = 0; |
156 | | JxlMemoryManager* memory_manager; |
157 | | void UpdateMaxNumBits(size_t ctx, size_t symbol); |
158 | | }; |
159 | | |
160 | | class ANSSymbolReader { |
161 | | public: |
// Invalid symbol reader, to be overwritten.
ANSSymbolReader() = default;
// Factory for a usable reader over `code`, reading from `br`; only declared
// here. The result is wrapped in StatusOr, so creation can fail.
// NOTE(review): `distance_multiplier` presumably feeds SpecialDistance() when
// the special-distance table is built - confirm in the .cc file.
static StatusOr<ANSSymbolReader> Create(const ANSCode* code,
                                        BitReader* JXL_RESTRICT br,
                                        size_t distance_multiplier = 0);
167 | | |
// Decodes one ANS symbol using the alias table of histogram `histo_idx` and
// returns its value. Does not call br->Refill(): the caller must have refilled
// `br` so that 16 bits can be peeked.
JXL_INLINE size_t ReadSymbolANSWithoutRefill(const size_t histo_idx,
                                             BitReader* JXL_RESTRICT br) {
  // Table slot selected by the state (masked with ANS_TAB_SIZE - 1).
  const uint32_t res = state_ & (ANS_TAB_SIZE - 1u);

  // Each histogram owns (1 << log_alpha_size_) consecutive table entries.
  const AliasTable::Entry* table =
      &alias_tables_[histo_idx << log_alpha_size_];
  const AliasTable::Symbol symbol =
      AliasTable::Lookup(table, res, log_entry_size_, entry_size_minus_1_);
  state_ = symbol.freq * (state_ >> ANS_LOG_TAB_SIZE) + symbol.offset;

#if JXL_TRUE
  // Branchless version is about equally fast on SKX.
  // Renormalize: when the state dropped below 2^16, shift in 16 fresh bits
  // and consume them; otherwise consume nothing.
  const uint32_t new_state =
      (state_ << 16u) | static_cast<uint32_t>(br->PeekFixedBits<16>());
  const bool normalize = state_ < (1u << 16u);
  state_ = normalize ? new_state : state_;
  br->Consume(normalize ? 16 : 0);
#else
  if (JXL_UNLIKELY(state_ < (1u << 16u))) {
    state_ = (state_ << 16u) | br->PeekFixedBits<16>();
    br->Consume(16);
  }
#endif
  // Prefetch the entry that a following lookup in the same table would use.
  const uint32_t next_res = state_ & (ANS_TAB_SIZE - 1u);
  AliasTable::Prefetch(table, next_res, log_entry_size_);

  return symbol.value;
}
196 | | |
// Decodes one prefix-coded symbol with the Huffman data of histogram
// `histo_idx`. Does not call br->Refill() itself.
JXL_INLINE size_t ReadSymbolHuffWithoutRefill(const size_t histo_idx,
                                              BitReader* JXL_RESTRICT br) {
  return huffman_data_[histo_idx].ReadSymbol(br);
}
201 | | |
202 | | JXL_INLINE size_t ReadSymbolWithoutRefill(const size_t histo_idx, |
203 | 115M | BitReader* JXL_RESTRICT br) { |
204 | | // TODO(veluca): hoist if in hotter loops. |
205 | 115M | if (JXL_UNLIKELY(use_prefix_code_)) { |
206 | 82.3M | return ReadSymbolHuffWithoutRefill(histo_idx, br); |
207 | 82.3M | } |
208 | 33.5M | return ReadSymbolANSWithoutRefill(histo_idx, br); |
209 | 115M | } |
210 | | |
// Refills `br`, then decodes one symbol for histogram `histo_idx`.
JXL_INLINE size_t ReadSymbol(const size_t histo_idx,
                             BitReader* JXL_RESTRICT br) {
  br->Refill();
  return ReadSymbolWithoutRefill(histo_idx, br);
}
216 | | |
// Returns whether the ANS state equals its initial value
// (ANS_SIGNATURE << 16). Fuzzing builds skip the comparison and always
// report success.
#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION
bool CheckANSFinalState() const { return true; }
#else
bool CheckANSFinalState() const { return state_ == (ANS_SIGNATURE << 16u); }
#endif
222 | | |
223 | | template <typename BitReader> |
224 | | static JXL_INLINE uint32_t ReadHybridUintConfig( |
225 | 115M | const HybridUintConfig& config, size_t token, BitReader* br) { |
226 | 115M | size_t split_token = config.split_token; |
227 | 115M | size_t msb_in_token = config.msb_in_token; |
228 | 115M | size_t lsb_in_token = config.lsb_in_token; |
229 | 115M | size_t split_exponent = config.split_exponent; |
230 | | // Fast-track version of hybrid integer decoding. |
231 | 115M | if (token < split_token) return token; |
232 | 5.57M | uint32_t nbits = split_exponent - (msb_in_token + lsb_in_token) + |
233 | 5.57M | ((token - split_token) >> (msb_in_token + lsb_in_token)); |
234 | | // Max amount of bits for ReadBits is 32 and max valid left shift is 29 |
235 | | // bits. However, for speed no error is propagated here, instead limit the |
236 | | // nbits size. If nbits > 29, the code stream is invalid, but no error is |
237 | | // returned. |
238 | | // Note that in most cases we will emit an error if the histogram allows |
239 | | // representing numbers that would cause invalid shifts, but we need to |
240 | | // keep this check as when LZ77 is enabled it might make sense to have an |
241 | | // histogram that could in principle cause invalid shifts. |
242 | 5.57M | nbits &= 31u; |
243 | 5.57M | uint32_t low = token & ((1 << lsb_in_token) - 1); |
244 | 5.57M | token >>= lsb_in_token; |
245 | 5.57M | const size_t bits = br->PeekBits(nbits); |
246 | 5.57M | br->Consume(nbits); |
247 | 5.57M | size_t ret = (((((1 << msb_in_token) | (token & ((1 << msb_in_token) - 1))) |
248 | 5.57M | << nbits) | |
249 | 5.57M | bits) |
250 | 5.57M | << lsb_in_token) | |
251 | 5.57M | low; |
252 | | // TODO(eustas): mark BitReader as unhealthy if nbits > 29 or ret does not |
253 | | // fit uint32_t |
254 | 5.57M | return static_cast<uint32_t>(ret); |
255 | 115M | } |
256 | | |
257 | | // Takes a *clustered* idx. Can only use if HuffRleOnly() is true. |
258 | | JXL_INLINE void ReadHybridUintClusteredHuffRleOnly(size_t ctx, |
259 | | BitReader* JXL_RESTRICT br, |
260 | | uint32_t* value, |
261 | 174 | uint32_t* run) { |
262 | 174 | JXL_DASSERT(IsHuffRleOnly()); |
263 | 174 | br->Refill(); // covers ReadSymbolWithoutRefill + PeekBits |
264 | 174 | size_t token = ReadSymbolHuffWithoutRefill(ctx, br); |
265 | 174 | if (JXL_UNLIKELY(token >= lz77_threshold_)) { |
266 | 0 | *run = |
267 | 0 | ReadHybridUintConfig(lz77_length_uint_, token - lz77_threshold_, br) + |
268 | 0 | lz77_min_length_ - 1; |
269 | 0 | return; |
270 | 0 | } |
271 | 174 | *value = ReadHybridUintConfig(configs[ctx], token, br); |
272 | 174 | } |
273 | 206 | bool IsHuffRleOnly() const { |
274 | 206 | if (lz77_window_ == nullptr) return false; |
275 | 206 | if (!use_prefix_code_) return false; |
276 | 130 | for (size_t i = 0; i < kHuffmanTableBits; i++) { |
277 | 127 | if (huffman_data_[lz77_ctx_].table_[i].bits) return false; |
278 | 84 | if (huffman_data_[lz77_ctx_].table_[i].value != 1) return false; |
279 | 84 | } |
280 | 3 | if (configs[lz77_ctx_].split_token > 1) return false; |
281 | 3 | return true; |
282 | 3 | } |
283 | 431k | bool UsesLZ77() { return lz77_window_ != nullptr; } |
284 | | |
// Takes a *clustered* idx. Inlined, for use in hot paths.
// Returns the next decoded value for context `ctx`. With uses_lz77 == true the
// value may come from a pending LZ77 copy, or from a copy started by the
// token just read; every value produced this way is also appended to the
// LZ77 window.
template <bool uses_lz77>
JXL_INLINE size_t ReadHybridUintClusteredInlined(size_t ctx,
                                                 BitReader* JXL_RESTRICT br) {
  if (uses_lz77) {
    // A copy is in progress: replay one value from the window without
    // touching the bit reader.
    if (JXL_UNLIKELY(num_to_copy_ > 0)) {
      size_t ret = lz77_window_[(copy_pos_++) & kWindowMask];
      num_to_copy_--;
      lz77_window_[(num_decoded_++) & kWindowMask] = ret;
      return ret;
    }
  }

  br->Refill();  // covers ReadSymbolWithoutRefill + PeekBits
  size_t token = ReadSymbolWithoutRefill(ctx, br);
  if (uses_lz77) {
    // Tokens at or above the threshold start a copy: first the length...
    if (JXL_UNLIKELY(token >= lz77_threshold_)) {
      num_to_copy_ = ReadHybridUintConfig(lz77_length_uint_,
                                          token - lz77_threshold_, br) +
                     lz77_min_length_;
      br->Refill();  // covers ReadSymbolWithoutRefill + PeekBits
      // Distance code.
      size_t d_token = ReadSymbolWithoutRefill(lz77_ctx_, br);
      size_t distance = ReadHybridUintConfig(configs[lz77_ctx_], d_token, br);
      // Small codes index the special-distance table; the rest are literal
      // distances offset by the table size.
      if (JXL_LIKELY(distance < num_special_distances_)) {
        distance = special_distances_[distance];
      } else {
        distance = distance + 1 - num_special_distances_;
      }
      // Clamp to what has been decoded so far and to the window size.
      if (JXL_UNLIKELY(distance > num_decoded_)) {
        distance = num_decoded_;
      }
      if (JXL_UNLIKELY(distance > kWindowSize)) {
        distance = kWindowSize;
      }
      copy_pos_ = num_decoded_ - distance;
      if (JXL_UNLIKELY(distance == 0)) {
        JXL_DASSERT(lz77_window_ != nullptr);
        // distance 0 -> num_decoded_ == copy_pos_ == 0
        // Zero the part of the window the copy is about to read.
        size_t to_fill = std::min<size_t>(num_to_copy_, kWindowSize);
        memset(lz77_window_, 0, to_fill * sizeof(lz77_window_[0]));
      }
      // TODO(eustas): overflow; mark BitReader as unhealthy
      // The 32-bit sum above wrapped around; give up on this copy.
      if (num_to_copy_ < lz77_min_length_) return 0;
      // the code below is the same as doing this:
      // return ReadHybridUintClustered<uses_lz77>(ctx, br);
      // but gcc doesn't like recursive inlining

      size_t ret = lz77_window_[(copy_pos_++) & kWindowMask];
      num_to_copy_--;
      lz77_window_[(num_decoded_++) & kWindowMask] = ret;
      return ret;
    }
  }
  // Plain value: decode it and, if a window exists, record it there.
  size_t ret = ReadHybridUintConfig(configs[ctx], token, br);
  if (uses_lz77 && lz77_window_)
    lz77_window_[(num_decoded_++) & kWindowMask] = ret;
  return ret;
}
344 | | |
// Same as ReadHybridUintClusteredInlined (takes a *clustered* idx), but not
// marked JXL_INLINE.
template <bool uses_lz77>
size_t ReadHybridUintClustered(size_t ctx, BitReader* JXL_RESTRICT br) {
  return ReadHybridUintClusteredInlined<uses_lz77>(ctx, br);
}
350 | | |
351 | | // inlined only in the no-lz77 case |
352 | | template <bool uses_lz77> |
353 | | JXL_INLINE size_t |
354 | 8.26M | ReadHybridUintClusteredMaybeInlined(size_t ctx, BitReader* JXL_RESTRICT br) { |
355 | 8.26M | if (uses_lz77) { |
356 | 3.31M | return ReadHybridUintClustered<uses_lz77>(ctx, br); |
357 | 4.95M | } else { |
358 | 4.95M | return ReadHybridUintClusteredInlined<uses_lz77>(ctx, br); |
359 | 4.95M | } |
360 | 8.26M | } unsigned long jxl::ANSSymbolReader::ReadHybridUintClusteredMaybeInlined<true>(unsigned long, jxl::BitReader*) Line | Count | Source | 354 | 3.31M | ReadHybridUintClusteredMaybeInlined(size_t ctx, BitReader* JXL_RESTRICT br) { | 355 | 3.31M | if (uses_lz77) { | 356 | 3.31M | return ReadHybridUintClustered<uses_lz77>(ctx, br); | 357 | 3.31M | } else { | 358 | 0 | return ReadHybridUintClusteredInlined<uses_lz77>(ctx, br); | 359 | 0 | } | 360 | 3.31M | } |
unsigned long jxl::ANSSymbolReader::ReadHybridUintClusteredMaybeInlined<false>(unsigned long, jxl::BitReader*) Line | Count | Source | 354 | 4.95M | ReadHybridUintClusteredMaybeInlined(size_t ctx, BitReader* JXL_RESTRICT br) { | 355 | 4.95M | if (uses_lz77) { | 356 | 0 | return ReadHybridUintClustered<uses_lz77>(ctx, br); | 357 | 4.95M | } else { | 358 | 4.95M | return ReadHybridUintClusteredInlined<uses_lz77>(ctx, br); | 359 | 4.95M | } | 360 | 4.95M | } |
|
361 | | |
// inlined, for use in hot paths
// Takes an *unclustered* context: `ctx` is mapped through `context_map`
// first. `ctx` must be a valid index into `context_map`.
// NOTE(review): only this wrapper is JXL_INLINE; it forwards to the
// non-inlined ReadHybridUintClustered - confirm that this is intended.
template <bool uses_lz77>
JXL_INLINE size_t
ReadHybridUintInlined(size_t ctx, BitReader* JXL_RESTRICT br,
                      const std::vector<uint8_t>& context_map) {
  return ReadHybridUintClustered<uses_lz77>(context_map[ctx], br);
}
369 | | |
// not inlined, for use in non-hot paths
// Maps `ctx` through `context_map` and always uses the uses_lz77 == true
// variant of the decoder.
size_t ReadHybridUint(size_t ctx, BitReader* JXL_RESTRICT br,
                      const std::vector<uint8_t>& context_map) {
  return ReadHybridUintClustered</*uses_lz77=*/true>(context_map[ctx], br);
}
375 | | |
376 | | // ctx is a *clustered* context! |
377 | | // This function will modify the ANS state as if `count` symbols have been |
378 | | // decoded. |
379 | 141k | bool IsSingleValueAndAdvance(size_t ctx, uint32_t* value, size_t count) { |
380 | | // TODO(veluca): No optimization for Huffman mode yet. |
381 | 141k | if (use_prefix_code_) return false; |
382 | | // TODO(eustas): propagate "degenerate_symbol" to simplify this method. |
383 | 75.7k | const uint32_t res = state_ & (ANS_TAB_SIZE - 1u); |
384 | 75.7k | const AliasTable::Entry* table = &alias_tables_[ctx << log_alpha_size_]; |
385 | 75.7k | AliasTable::Symbol symbol = |
386 | 75.7k | AliasTable::Lookup(table, res, log_entry_size_, entry_size_minus_1_); |
387 | 75.7k | if (symbol.freq != ANS_TAB_SIZE) return false; |
388 | 59.9k | if (configs[ctx].split_token <= symbol.value) return false; |
389 | 59.3k | if (symbol.value >= lz77_threshold_) return false; |
390 | 59.2k | *value = symbol.value; |
391 | 59.2k | if (lz77_window_) { |
392 | 42.5M | for (size_t i = 0; i < count; i++) { |
393 | 42.5M | lz77_window_[(num_decoded_++) & kWindowMask] = symbol.value; |
394 | 42.5M | } |
395 | 18.8k | } |
396 | 59.2k | return true; |
397 | 59.3k | } |
398 | | |
// Number of LZ77 window entries stored in a Checkpoint. Restore() asserts (in
// debug builds) that no more than this many values were decoded since the
// matching Save().
static constexpr size_t kMaxCheckpointInterval = 512;
// Snapshot of the reader state written by Save() and consumed by Restore().
struct Checkpoint {
  uint32_t state;
  uint32_t num_to_copy;
  uint32_t copy_pos;
  uint32_t num_decoded;
  // Window entries starting at position num_decoded (wrapping around the
  // window); only filled when the reader has an LZ77 window.
  uint32_t lz77_window[kMaxCheckpointInterval];
};
407 | 8.07k | void Save(Checkpoint* checkpoint) { |
408 | 8.07k | checkpoint->state = state_; |
409 | 8.07k | checkpoint->num_decoded = num_decoded_; |
410 | 8.07k | checkpoint->num_to_copy = num_to_copy_; |
411 | 8.07k | checkpoint->copy_pos = copy_pos_; |
412 | 8.07k | if (lz77_window_) { |
413 | 2.26k | size_t win_start = num_decoded_ & kWindowMask; |
414 | 2.26k | size_t win_end = (num_decoded_ + kMaxCheckpointInterval) & kWindowMask; |
415 | 2.26k | if (win_end > win_start) { |
416 | 2.26k | memcpy(checkpoint->lz77_window, lz77_window_ + win_start, |
417 | 2.26k | (win_end - win_start) * sizeof(*lz77_window_)); |
418 | 2.26k | } else { |
419 | 0 | memcpy(checkpoint->lz77_window, lz77_window_ + win_start, |
420 | 0 | (kWindowSize - win_start) * sizeof(*lz77_window_)); |
421 | 0 | memcpy(checkpoint->lz77_window + (kWindowSize - win_start), |
422 | 0 | lz77_window_, win_end * sizeof(*lz77_window_)); |
423 | 0 | } |
424 | 2.26k | } |
425 | 8.07k | } |
426 | 404 | void Restore(const Checkpoint& checkpoint) { |
427 | 404 | state_ = checkpoint.state; |
428 | 404 | JXL_DASSERT(num_decoded_ <= |
429 | 404 | checkpoint.num_decoded + kMaxCheckpointInterval); |
430 | 404 | num_decoded_ = checkpoint.num_decoded; |
431 | 404 | num_to_copy_ = checkpoint.num_to_copy; |
432 | 404 | copy_pos_ = checkpoint.copy_pos; |
433 | 404 | if (lz77_window_) { |
434 | 226 | size_t win_start = num_decoded_ & kWindowMask; |
435 | 226 | size_t win_end = (num_decoded_ + kMaxCheckpointInterval) & kWindowMask; |
436 | 226 | if (win_end > win_start) { |
437 | 226 | memcpy(lz77_window_ + win_start, checkpoint.lz77_window, |
438 | 226 | (win_end - win_start) * sizeof(*lz77_window_)); |
439 | 226 | } else { |
440 | 0 | memcpy(lz77_window_ + win_start, checkpoint.lz77_window, |
441 | 0 | (kWindowSize - win_start) * sizeof(*lz77_window_)); |
442 | 0 | memcpy(lz77_window_, checkpoint.lz77_window + (kWindowSize - win_start), |
443 | 0 | win_end * sizeof(*lz77_window_)); |
444 | 0 | } |
445 | 226 | } |
446 | 404 | } |
447 | | |
448 | | private: |
449 | | ANSSymbolReader(const ANSCode* code, BitReader* JXL_RESTRICT br, |
450 | | size_t distance_multiplier, |
451 | | AlignedMemory&& lz77_window_storage); |
452 | | |
453 | | const AliasTable::Entry* JXL_RESTRICT alias_tables_; // not owned |
454 | | const HuffmanDecodingData* huffman_data_; |
455 | | bool use_prefix_code_; |
456 | | uint32_t state_ = ANS_SIGNATURE << 16u; |
457 | | const HybridUintConfig* JXL_RESTRICT configs; |
458 | | uint32_t log_alpha_size_{}; |
459 | | uint32_t log_entry_size_{}; |
460 | | uint32_t entry_size_minus_1_{}; |
461 | | |
462 | | // LZ77 structures and constants. |
463 | | static constexpr size_t kWindowMask = kWindowSize - 1; |
464 | | // a std::vector incurs unacceptable decoding speed loss because of |
465 | | // initialization. |
466 | | AlignedMemory lz77_window_storage_; |
467 | | uint32_t* lz77_window_ = nullptr; |
468 | | uint32_t num_decoded_ = 0; |
469 | | uint32_t num_to_copy_ = 0; |
470 | | uint32_t copy_pos_ = 0; |
471 | | uint32_t lz77_ctx_ = 0; |
472 | | uint32_t lz77_min_length_ = 0; |
473 | | uint32_t lz77_threshold_ = 1 << 20; // bigger than any symbol. |
474 | | HybridUintConfig lz77_length_uint_; |
475 | | uint32_t special_distances_[kNumSpecialDistances]{}; |
476 | | uint32_t num_special_distances_{}; |
477 | | }; |
478 | | |
// Declared here, defined elsewhere. Reads from `br` and fills `*code` and
// `*context_map` for `num_contexts` contexts; failures are reported through
// the returned Status.
// NOTE(review): `disallow_lz77` presumably rejects streams that enable LZ77 -
// confirm in the definition.
Status DecodeHistograms(JxlMemoryManager* memory_manager, BitReader* br,
                        size_t num_contexts, ANSCode* code,
                        std::vector<uint8_t>* context_map,
                        bool disallow_lz77 = false);
483 | | |
// Exposed for tests.
// Declared here, defined elsewhere. Reads from `br` into the entries of
// `*uint_config`.
// NOTE(review): presumably the caller sizes `*uint_config` before the call and
// `log_alpha_size` bounds the accepted parameters - confirm in the definition.
Status DecodeUintConfigs(size_t log_alpha_size,
                         std::vector<HybridUintConfig>* uint_config,
                         BitReader* br);
488 | | |
489 | | } // namespace jxl |
490 | | |
491 | | #endif // LIB_JXL_DEC_ANS_H_ |