/src/libjxl/lib/jxl/dec_ans.h
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright (c) the JPEG XL Project Authors. All rights reserved. |
2 | | // |
3 | | // Use of this source code is governed by a BSD-style |
4 | | // license that can be found in the LICENSE file. |
5 | | |
6 | | #ifndef LIB_JXL_DEC_ANS_H_ |
7 | | #define LIB_JXL_DEC_ANS_H_ |
8 | | |
9 | | // Library to decode the ANS population counts from the bit-stream and build a |
10 | | // decoding table from them. |
11 | | |
12 | | #include <jxl/memory_manager.h> |
13 | | #include <jxl/types.h> |
14 | | |
15 | | #include <algorithm> |
16 | | #include <cstddef> |
17 | | #include <cstdint> |
18 | | #include <cstring> |
19 | | #include <vector> |
20 | | |
21 | | #include "lib/jxl/ans_common.h" |
22 | | #include "lib/jxl/ans_params.h" |
23 | | #include "lib/jxl/base/bits.h" |
24 | | #include "lib/jxl/base/compiler_specific.h" |
25 | | #include "lib/jxl/base/status.h" |
26 | | #include "lib/jxl/dec_bit_reader.h" |
27 | | #include "lib/jxl/dec_huffman.h" |
28 | | #include "lib/jxl/field_encodings.h" |
29 | | #include "lib/jxl/memory_manager_internal.h" |
30 | | |
31 | | namespace jxl { |
32 | | |
33 | | class ANSSymbolReader; |
34 | | |
35 | | // Experiments show that best performance is typically achieved for a |
36 | | // split-exponent of 3 or 4. Trend seems to be that '4' is better |
37 | | // for large-ish pictures, and '3' better for rather small-ish pictures. |
38 | | // This is plausible - the more special symbols we have, the better |
39 | | // statistics we need to get a benefit out of them. |
40 | | |
41 | | // Our hybrid-encoding scheme has dedicated tokens for the smallest |
42 | | // (1 << split_exponent) numbers, and for the rest |
43 | | // encodes (number of bits) + (msb_in_token sub-leading binary digits) + |
44 | | // (lsb_in_token lowest binary digits) in the token, with the remaining bits |
45 | | // then being encoded as data. |
46 | | // |
47 | | // Example with split_exponent = 4, msb_in_token = 2, lsb_in_token = 0. |
48 | | // |
49 | | // Numbers N in [0 .. 15]: |
50 | | // These get represented as (token=N, bits=''). |
51 | | // Numbers N >= 16: |
52 | | // If n is such that 2**n <= N < 2**(n+1), |
53 | | // and m = N - 2**n is the 'mantissa', |
54 | | // these get represented as: |
55 | | // (token=split_token + |
56 | | // ((n - split_exponent) * 4) + |
57 | | // (m >> (n - msb_in_token)), |
58 | | // bits=m & (1 << (n - msb_in_token)) - 1) |
59 | | // Specifically, we would get: |
60 | | // N = 0 - 15: (token=N, nbits=0, bits='') |
61 | | // N = 16 (10000): (token=16, nbits=2, bits='00') |
62 | | // N = 17 (10001): (token=16, nbits=2, bits='01') |
63 | | // N = 20 (10100): (token=17, nbits=2, bits='00') |
64 | | // N = 24 (11000): (token=18, nbits=2, bits='00') |
65 | | // N = 28 (11100): (token=19, nbits=2, bits='00') |
66 | | // N = 32 (100000): (token=20, nbits=3, bits='000') |
67 | | // N = 65535: (token=63, nbits=13, bits='1111111111111') |
68 | | struct HybridUintConfig { |
69 | | uint32_t split_exponent; |
70 | | uint32_t split_token; |
71 | | uint32_t msb_in_token; |
72 | | uint32_t lsb_in_token; |
73 | | JXL_INLINE void Encode(uint32_t value, uint32_t* JXL_RESTRICT token, |
74 | | uint32_t* JXL_RESTRICT nbits, |
75 | 511M | uint32_t* JXL_RESTRICT bits) const { |
76 | 511M | if (value < split_token) { |
77 | 460M | *token = value; |
78 | 460M | *nbits = 0; |
79 | 460M | *bits = 0; |
80 | 460M | } else { |
81 | 51.0M | uint32_t n = FloorLog2Nonzero(value); |
82 | 51.0M | uint32_t m = value - (1 << n); |
83 | 51.0M | *token = split_token + |
84 | 51.0M | ((n - split_exponent) << (msb_in_token + lsb_in_token)) + |
85 | 51.0M | ((m >> (n - msb_in_token)) << lsb_in_token) + |
86 | 51.0M | (m & ((1 << lsb_in_token) - 1)); |
87 | 51.0M | *nbits = n - msb_in_token - lsb_in_token; |
88 | 51.0M | *bits = (value >> lsb_in_token) & ((1UL << *nbits) - 1); |
89 | 51.0M | } |
90 | 511M | } |
91 | | |
92 | 415k | JXL_INLINE uint32_t LsbMask() const { return (1 << lsb_in_token) - 1; } |
93 | | |
94 | | explicit HybridUintConfig(uint32_t split_exponent = 4, |
95 | | uint32_t msb_in_token = 2, |
96 | | uint32_t lsb_in_token = 0) |
97 | 14.0M | : split_exponent(split_exponent), |
98 | 14.0M | split_token(1 << split_exponent), |
99 | 14.0M | msb_in_token(msb_in_token), |
100 | 14.0M | lsb_in_token(lsb_in_token) { |
101 | 14.0M | JXL_DASSERT(split_exponent >= msb_in_token + lsb_in_token); |
102 | 14.0M | } |
103 | | }; |
104 | | |
105 | | struct LZ77Params : public Fields { |
106 | | LZ77Params(); |
107 | | JXL_FIELDS_NAME(LZ77Params) |
108 | | Status VisitFields(Visitor* JXL_RESTRICT visitor) override; |
109 | | bool enabled; |
110 | | |
111 | | // Symbols above min_symbol use a special hybrid uint encoding and |
112 | | // represent a length, to be added to min_length. |
113 | | uint32_t min_symbol; |
114 | | uint32_t min_length; |
115 | | |
116 | | // Not serialized by VisitFields. |
117 | | HybridUintConfig length_uint_config{0, 0, 0}; |
118 | | |
119 | | size_t nonserialized_distance_context; |
120 | | }; |
121 | | |
// Number of entries in the LZ77 window (a power of two, so that
// kWindowSize - 1 can be used as an index mask).
static constexpr size_t kWindowSize = size_t{1} << 20;
// Number of rows in kSpecialDistances.
static constexpr size_t kNumSpecialDistances = 120;
// Table of special distance codes from WebP lossless. Each row is a pair of
// small signed offsets; SpecialDistance() combines them into one distance.
static constexpr int8_t kSpecialDistances[kNumSpecialDistances][2] = {
    {0, 1},  {1, 0},  {1, 1},  {-1, 1}, {0, 2},  {2, 0},  {1, 2},  {-1, 2},
    {2, 1},  {-2, 1}, {2, 2},  {-2, 2}, {0, 3},  {3, 0},  {1, 3},  {-1, 3},
    {3, 1},  {-3, 1}, {2, 3},  {-2, 3}, {3, 2},  {-3, 2}, {0, 4},  {4, 0},
    {1, 4},  {-1, 4}, {4, 1},  {-4, 1}, {3, 3},  {-3, 3}, {2, 4},  {-2, 4},
    {4, 2},  {-4, 2}, {0, 5},  {3, 4},  {-3, 4}, {4, 3},  {-4, 3}, {5, 0},
    {1, 5},  {-1, 5}, {5, 1},  {-5, 1}, {2, 5},  {-2, 5}, {5, 2},  {-5, 2},
    {4, 4},  {-4, 4}, {3, 5},  {-3, 5}, {5, 3},  {-5, 3}, {0, 6},  {6, 0},
    {1, 6},  {-1, 6}, {6, 1},  {-6, 1}, {2, 6},  {-2, 6}, {6, 2},  {-6, 2},
    {4, 5},  {-4, 5}, {5, 4},  {-5, 4}, {3, 6},  {-3, 6}, {6, 3},  {-6, 3},
    {0, 7},  {7, 0},  {1, 7},  {-1, 7}, {5, 5},  {-5, 5}, {7, 1},  {-7, 1},
    {4, 6},  {-4, 6}, {6, 4},  {-6, 4}, {2, 7},  {-2, 7}, {7, 2},  {-7, 2},
    {3, 7},  {-3, 7}, {7, 3},  {-7, 3}, {5, 6},  {-5, 6}, {6, 5},  {-6, 5},
    {8, 0},  {4, 7},  {-4, 7}, {7, 4},  {-7, 4}, {8, 1},  {8, 2},  {6, 6},
    {-6, 6}, {8, 3},  {5, 7},  {-5, 7}, {7, 5},  {-7, 5}, {8, 4},  {6, 7},
    {-6, 7}, {7, 6},  {-7, 6}, {8, 5},  {7, 7},  {-7, 7}, {8, 6},  {8, 7}};
141 | 1.13M | static JXL_INLINE int SpecialDistance(size_t index, int multiplier) { |
142 | 1.13M | int dist = kSpecialDistances[index][0] + |
143 | 1.13M | static_cast<int>(multiplier) * kSpecialDistances[index][1]; |
144 | 1.13M | return (dist > 1) ? dist : 1; |
145 | 1.13M | } Unexecuted instantiation: encode.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_jpeg_data.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: decode.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: icc_codec.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: encoding.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: modular_image.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: transform.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: quant_weights.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_fast_lossless.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_frame.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_group.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_heuristics.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_icc_codec.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_modular.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_modular_simd.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_patch_dictionary.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_quant_weights.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_splines.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_encoding.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_ma.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_rct.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_transform.cc:jxl::SpecialDistance(unsigned long, int) dec_ans.cc:jxl::SpecialDistance(unsigned long, int) Line | Count | Source | 141 | 1.13M | static JXL_INLINE int SpecialDistance(size_t index, int multiplier) { | 142 | 
1.13M | int dist = kSpecialDistances[index][0] + | 143 | 1.13M | static_cast<int>(multiplier) * kSpecialDistances[index][1]; | 144 | 1.13M | return (dist > 1) ? dist : 1; | 145 | 1.13M | } |
Unexecuted instantiation: dec_cache.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: dec_context_map.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: dec_external_image.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: dec_frame.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: dec_group.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: dec_modular.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: dec_noise.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: dec_patch_dictionary.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: epf.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: dec_ma.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: palette.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: rct.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: stage_blending.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: stage_epf.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: stage_write.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: splines.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_ac_strategy.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_adaptive_quantization.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_ans.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_ans_simd.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_cache.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_chroma_from_luma.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_cluster.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_coeff_order.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted 
instantiation: enc_context_map.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_debug_image.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_dot_dictionary.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_entropy_coder.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_lz77.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_palette.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: coeff_order.cc:jxl::SpecialDistance(unsigned long, int) Unexecuted instantiation: enc_detect_dots.cc:jxl::SpecialDistance(unsigned long, int) |
146 | | |
147 | | struct ANSCode { |
148 | | AlignedMemory alias_tables; |
149 | | std::vector<HuffmanDecodingData> huffman_data; |
150 | | std::vector<HybridUintConfig> uint_config; |
151 | | std::vector<int> degenerate_symbols; |
152 | | bool use_prefix_code; |
153 | | uint8_t log_alpha_size; // for ANS. |
154 | | LZ77Params lz77; |
155 | | // Maximum number of bits necessary to represent the result of a |
156 | | // ReadHybridUint call done with this ANSCode. |
157 | | size_t max_num_bits = 0; |
158 | | JxlMemoryManager* memory_manager; |
159 | | void UpdateMaxNumBits(size_t ctx, size_t symbol); |
160 | | }; |
161 | | |
162 | | class ANSSymbolReader { |
163 | | public: |
164 | | // Invalid symbol reader, to be overwritten. |
165 | 67.6k | ANSSymbolReader() = default; |
166 | | static StatusOr<ANSSymbolReader> Create(const ANSCode* code, |
167 | | BitReader* JXL_RESTRICT br, |
168 | | size_t distance_multiplier = 0); |
169 | | |
170 | | JXL_INLINE size_t ReadSymbolANSWithoutRefill(const size_t histo_idx, |
171 | 204M | BitReader* JXL_RESTRICT br) { |
172 | 204M | const uint32_t res = state_ & (ANS_TAB_SIZE - 1u); |
173 | | |
174 | 204M | const AliasTable::Entry* table = |
175 | 204M | &alias_tables_[histo_idx << log_alpha_size_]; |
176 | 204M | const AliasTable::Symbol symbol = |
177 | 204M | AliasTable::Lookup(table, res, log_entry_size_, entry_size_minus_1_); |
178 | 204M | state_ = symbol.freq * (state_ >> ANS_LOG_TAB_SIZE) + symbol.offset; |
179 | | |
180 | 204M | #if JXL_TRUE |
181 | | // Branchless version is about equally fast on SKX. |
182 | 204M | const uint32_t new_state = |
183 | 204M | (state_ << 16u) | static_cast<uint32_t>(br->PeekFixedBits<16>()); |
184 | 204M | const bool normalize = state_ < (1u << 16u); |
185 | 204M | state_ = normalize ? new_state : state_; |
186 | 204M | br->Consume(normalize ? 16 : 0); |
187 | | #else |
188 | | if (JXL_UNLIKELY(state_ < (1u << 16u))) { |
189 | | state_ = (state_ << 16u) | br->PeekFixedBits<16>(); |
190 | | br->Consume(16); |
191 | | } |
192 | | #endif |
193 | 204M | const uint32_t next_res = state_ & (ANS_TAB_SIZE - 1u); |
194 | 204M | AliasTable::Prefetch(table, next_res, log_entry_size_); |
195 | | |
196 | 204M | return symbol.value; |
197 | 204M | } |
198 | | |
199 | | JXL_INLINE size_t ReadSymbolHuffWithoutRefill(const size_t histo_idx, |
200 | 134M | BitReader* JXL_RESTRICT br) { |
201 | 134M | return huffman_data_[histo_idx].ReadSymbol(br); |
202 | 134M | } |
203 | | |
204 | | JXL_INLINE size_t ReadSymbolWithoutRefill(const size_t histo_idx, |
205 | 338M | BitReader* JXL_RESTRICT br) { |
206 | | // TODO(veluca): hoist if in hotter loops. |
207 | 338M | if (JXL_UNLIKELY(use_prefix_code_)) { |
208 | 134M | return ReadSymbolHuffWithoutRefill(histo_idx, br); |
209 | 134M | } |
210 | 204M | return ReadSymbolANSWithoutRefill(histo_idx, br); |
211 | 338M | } |
212 | | |
213 | | JXL_INLINE size_t ReadSymbol(const size_t histo_idx, |
214 | 0 | BitReader* JXL_RESTRICT br) { |
215 | 0 | br->Refill(); |
216 | 0 | return ReadSymbolWithoutRefill(histo_idx, br); |
217 | 0 | } |
218 | | |
219 | | #ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION |
220 | 149k | bool CheckANSFinalState() const { return true; } |
221 | | #else |
222 | | bool CheckANSFinalState() const { return state_ == (ANS_SIGNATURE << 16u); } |
223 | | #endif |
224 | | |
225 | | template <typename BitReader> |
226 | | static JXL_INLINE uint32_t ReadHybridUintConfig( |
227 | 338M | const HybridUintConfig& config, size_t token, BitReader* br) { |
228 | 338M | size_t split_token = config.split_token; |
229 | 338M | size_t msb_in_token = config.msb_in_token; |
230 | 338M | size_t lsb_in_token = config.lsb_in_token; |
231 | 338M | size_t split_exponent = config.split_exponent; |
232 | | // Fast-track version of hybrid integer decoding. |
233 | 338M | if (token < split_token) return token; |
234 | 11.8M | uint32_t nbits = split_exponent - (msb_in_token + lsb_in_token) + |
235 | 11.8M | ((token - split_token) >> (msb_in_token + lsb_in_token)); |
236 | | // Max amount of bits for ReadBits is 32 and max valid left shift is 29 |
237 | | // bits. However, for speed no error is propagated here, instead limit the |
238 | | // nbits size. If nbits > 29, the code stream is invalid, but no error is |
239 | | // returned. |
240 | | // Note that in most cases we will emit an error if the histogram allows |
241 | | // representing numbers that would cause invalid shifts, but we need to |
242 | | // keep this check as when LZ77 is enabled it might make sense to have an |
243 | | // histogram that could in principle cause invalid shifts. |
244 | 11.8M | nbits &= 31u; |
245 | 11.8M | uint32_t low = token & ((1 << lsb_in_token) - 1); |
246 | 11.8M | token >>= lsb_in_token; |
247 | 11.8M | const size_t bits = br->PeekBits(nbits); |
248 | 11.8M | br->Consume(nbits); |
249 | 11.8M | size_t ret = (((((1 << msb_in_token) | (token & ((1 << msb_in_token) - 1))) |
250 | 11.8M | << nbits) | |
251 | 11.8M | bits) |
252 | 11.8M | << lsb_in_token) | |
253 | 11.8M | low; |
254 | | // TODO(eustas): mark BitReader as unhealthy if nbits > 29 or ret does not |
255 | | // fit uint32_t |
256 | 11.8M | return static_cast<uint32_t>(ret); |
257 | 338M | } |
258 | | |
259 | | // Takes a *clustered* idx. Can only use if HuffRleOnly() is true. |
260 | | JXL_INLINE void ReadHybridUintClusteredHuffRleOnly(size_t ctx, |
261 | | BitReader* JXL_RESTRICT br, |
262 | | uint32_t* value, |
263 | 33.3k | uint32_t* run) { |
264 | 33.3k | JXL_DASSERT(IsHuffRleOnly()); |
265 | 33.3k | br->Refill(); // covers ReadSymbolWithoutRefill + PeekBits |
266 | 33.3k | size_t token = ReadSymbolHuffWithoutRefill(ctx, br); |
267 | 33.3k | if (JXL_UNLIKELY(token >= lz77_threshold_)) { |
268 | 0 | *run = |
269 | 0 | ReadHybridUintConfig(lz77_length_uint_, token - lz77_threshold_, br) + |
270 | 0 | lz77_min_length_ - 1; |
271 | 0 | return; |
272 | 0 | } |
273 | 33.3k | *value = ReadHybridUintConfig(configs[ctx], token, br); |
274 | 33.3k | } |
275 | 33.6k | bool IsHuffRleOnly() const { |
276 | 33.6k | if (lz77_window_ == nullptr) return false; |
277 | 33.6k | if (!use_prefix_code_) return false; |
278 | 300k | for (size_t i = 0; i < kHuffmanTableBits; i++) { |
279 | 267k | if (huffman_data_[lz77_ctx_].table_[i].bits) return false; |
280 | 267k | if (huffman_data_[lz77_ctx_].table_[i].value != 1) return false; |
281 | 267k | } |
282 | 33.4k | if (configs[lz77_ctx_].split_token > 1) return false; |
283 | 33.4k | return true; |
284 | 33.4k | } |
285 | 556k | bool UsesLZ77() { return lz77_window_ != nullptr; } |
286 | | |
287 | | // Takes a *clustered* idx. Inlined, for use in hot paths. |
288 | | template <bool uses_lz77> |
289 | | JXL_INLINE size_t ReadHybridUintClusteredInlined(size_t ctx, |
290 | 341M | BitReader* JXL_RESTRICT br) { |
291 | 341M | if (uses_lz77) { |
292 | 166M | if (JXL_UNLIKELY(num_to_copy_ > 0)) { |
293 | 2.60M | size_t ret = lz77_window_[(copy_pos_++) & kWindowMask]; |
294 | 2.60M | num_to_copy_--; |
295 | 2.60M | lz77_window_[(num_decoded_++) & kWindowMask] = ret; |
296 | 2.60M | return ret; |
297 | 2.60M | } |
298 | 166M | } |
299 | | |
300 | 338M | br->Refill(); // covers ReadSymbolWithoutRefill + PeekBits |
301 | 338M | size_t token = ReadSymbolWithoutRefill(ctx, br); |
302 | 338M | if (uses_lz77) { |
303 | 164M | if (JXL_UNLIKELY(token >= lz77_threshold_)) { |
304 | 47.4k | num_to_copy_ = ReadHybridUintConfig(lz77_length_uint_, |
305 | 47.4k | token - lz77_threshold_, br) + |
306 | 47.4k | lz77_min_length_; |
307 | 47.4k | br->Refill(); // covers ReadSymbolWithoutRefill + PeekBits |
308 | | // Distance code. |
309 | 47.4k | size_t d_token = ReadSymbolWithoutRefill(lz77_ctx_, br); |
310 | 47.4k | size_t distance = ReadHybridUintConfig(configs[lz77_ctx_], d_token, br); |
311 | 47.4k | if (JXL_LIKELY(distance < num_special_distances_)) { |
312 | 27.3k | distance = special_distances_[distance]; |
313 | 27.3k | } else { |
314 | 20.0k | distance = distance + 1 - num_special_distances_; |
315 | 20.0k | } |
316 | 47.4k | if (JXL_UNLIKELY(distance > num_decoded_)) { |
317 | 8.10k | distance = num_decoded_; |
318 | 8.10k | } |
319 | 47.4k | if (JXL_UNLIKELY(distance > kWindowSize)) { |
320 | 0 | distance = kWindowSize; |
321 | 0 | } |
322 | 47.4k | copy_pos_ = num_decoded_ - distance; |
323 | 47.4k | if (JXL_UNLIKELY(distance == 0)) { |
324 | 377 | JXL_DASSERT(lz77_window_ != nullptr); |
325 | | // distance 0 -> num_decoded_ == copy_pos_ == 0 |
326 | 377 | size_t to_fill = std::min<size_t>(num_to_copy_, kWindowSize); |
327 | 377 | memset(lz77_window_, 0, to_fill * sizeof(lz77_window_[0])); |
328 | 377 | } |
329 | | // TODO(eustas): overflow; mark BitReader as unhealthy |
330 | 47.4k | if (num_to_copy_ < lz77_min_length_) return 0; |
331 | | // the code below is the same as doing this: |
332 | | // return ReadHybridUintClustered<uses_lz77>(ctx, br); |
333 | | // but gcc doesn't like recursive inlining |
334 | | |
335 | 47.3k | size_t ret = lz77_window_[(copy_pos_++) & kWindowMask]; |
336 | 47.3k | num_to_copy_--; |
337 | 47.3k | lz77_window_[(num_decoded_++) & kWindowMask] = ret; |
338 | 47.3k | return ret; |
339 | 47.4k | } |
340 | 164M | } |
341 | 338M | size_t ret = ReadHybridUintConfig(configs[ctx], token, br); |
342 | 338M | if (uses_lz77 && lz77_window_) |
343 | 29.1M | lz77_window_[(num_decoded_++) & kWindowMask] = ret; |
344 | 338M | return ret; |
345 | 338M | } unsigned long jxl::ANSSymbolReader::ReadHybridUintClusteredInlined<true>(unsigned long, jxl::BitReader*) Line | Count | Source | 290 | 166M | BitReader* JXL_RESTRICT br) { | 291 | 166M | if (uses_lz77) { | 292 | 166M | if (JXL_UNLIKELY(num_to_copy_ > 0)) { | 293 | 2.60M | size_t ret = lz77_window_[(copy_pos_++) & kWindowMask]; | 294 | 2.60M | num_to_copy_--; | 295 | 2.60M | lz77_window_[(num_decoded_++) & kWindowMask] = ret; | 296 | 2.60M | return ret; | 297 | 2.60M | } | 298 | 166M | } | 299 | | | 300 | 164M | br->Refill(); // covers ReadSymbolWithoutRefill + PeekBits | 301 | 164M | size_t token = ReadSymbolWithoutRefill(ctx, br); | 302 | 164M | if (uses_lz77) { | 303 | 164M | if (JXL_UNLIKELY(token >= lz77_threshold_)) { | 304 | 47.4k | num_to_copy_ = ReadHybridUintConfig(lz77_length_uint_, | 305 | 47.4k | token - lz77_threshold_, br) + | 306 | 47.4k | lz77_min_length_; | 307 | 47.4k | br->Refill(); // covers ReadSymbolWithoutRefill + PeekBits | 308 | | // Distance code. 
| 309 | 47.4k | size_t d_token = ReadSymbolWithoutRefill(lz77_ctx_, br); | 310 | 47.4k | size_t distance = ReadHybridUintConfig(configs[lz77_ctx_], d_token, br); | 311 | 47.4k | if (JXL_LIKELY(distance < num_special_distances_)) { | 312 | 27.3k | distance = special_distances_[distance]; | 313 | 27.3k | } else { | 314 | 20.0k | distance = distance + 1 - num_special_distances_; | 315 | 20.0k | } | 316 | 47.4k | if (JXL_UNLIKELY(distance > num_decoded_)) { | 317 | 8.10k | distance = num_decoded_; | 318 | 8.10k | } | 319 | 47.4k | if (JXL_UNLIKELY(distance > kWindowSize)) { | 320 | 0 | distance = kWindowSize; | 321 | 0 | } | 322 | 47.4k | copy_pos_ = num_decoded_ - distance; | 323 | 47.4k | if (JXL_UNLIKELY(distance == 0)) { | 324 | 377 | JXL_DASSERT(lz77_window_ != nullptr); | 325 | | // distance 0 -> num_decoded_ == copy_pos_ == 0 | 326 | 377 | size_t to_fill = std::min<size_t>(num_to_copy_, kWindowSize); | 327 | 377 | memset(lz77_window_, 0, to_fill * sizeof(lz77_window_[0])); | 328 | 377 | } | 329 | | // TODO(eustas): overflow; mark BitReader as unhealthy | 330 | 47.4k | if (num_to_copy_ < lz77_min_length_) return 0; | 331 | | // the code below is the same as doing this: | 332 | | // return ReadHybridUintClustered<uses_lz77>(ctx, br); | 333 | | // but gcc doesn't like recursive inlining | 334 | | | 335 | 47.3k | size_t ret = lz77_window_[(copy_pos_++) & kWindowMask]; | 336 | 47.3k | num_to_copy_--; | 337 | 47.3k | lz77_window_[(num_decoded_++) & kWindowMask] = ret; | 338 | 47.3k | return ret; | 339 | 47.4k | } | 340 | 164M | } | 341 | 164M | size_t ret = ReadHybridUintConfig(configs[ctx], token, br); | 342 | 164M | if (uses_lz77 && lz77_window_) | 343 | 29.1M | lz77_window_[(num_decoded_++) & kWindowMask] = ret; | 344 | 164M | return ret; | 345 | 164M | } |
unsigned long jxl::ANSSymbolReader::ReadHybridUintClusteredInlined<false>(unsigned long, jxl::BitReader*) Line | Count | Source | 290 | 174M | BitReader* JXL_RESTRICT br) { | 291 | 174M | if (uses_lz77) { | 292 | 0 | if (JXL_UNLIKELY(num_to_copy_ > 0)) { | 293 | 0 | size_t ret = lz77_window_[(copy_pos_++) & kWindowMask]; | 294 | 0 | num_to_copy_--; | 295 | 0 | lz77_window_[(num_decoded_++) & kWindowMask] = ret; | 296 | 0 | return ret; | 297 | 0 | } | 298 | 0 | } | 299 | | | 300 | 174M | br->Refill(); // covers ReadSymbolWithoutRefill + PeekBits | 301 | 174M | size_t token = ReadSymbolWithoutRefill(ctx, br); | 302 | 174M | if (uses_lz77) { | 303 | 0 | if (JXL_UNLIKELY(token >= lz77_threshold_)) { | 304 | 0 | num_to_copy_ = ReadHybridUintConfig(lz77_length_uint_, | 305 | 0 | token - lz77_threshold_, br) + | 306 | 0 | lz77_min_length_; | 307 | 0 | br->Refill(); // covers ReadSymbolWithoutRefill + PeekBits | 308 | | // Distance code. | 309 | 0 | size_t d_token = ReadSymbolWithoutRefill(lz77_ctx_, br); | 310 | 0 | size_t distance = ReadHybridUintConfig(configs[lz77_ctx_], d_token, br); | 311 | 0 | if (JXL_LIKELY(distance < num_special_distances_)) { | 312 | 0 | distance = special_distances_[distance]; | 313 | 0 | } else { | 314 | 0 | distance = distance + 1 - num_special_distances_; | 315 | 0 | } | 316 | 0 | if (JXL_UNLIKELY(distance > num_decoded_)) { | 317 | 0 | distance = num_decoded_; | 318 | 0 | } | 319 | 0 | if (JXL_UNLIKELY(distance > kWindowSize)) { | 320 | 0 | distance = kWindowSize; | 321 | 0 | } | 322 | 0 | copy_pos_ = num_decoded_ - distance; | 323 | 0 | if (JXL_UNLIKELY(distance == 0)) { | 324 | 0 | JXL_DASSERT(lz77_window_ != nullptr); | 325 | | // distance 0 -> num_decoded_ == copy_pos_ == 0 | 326 | 0 | size_t to_fill = std::min<size_t>(num_to_copy_, kWindowSize); | 327 | 0 | memset(lz77_window_, 0, to_fill * sizeof(lz77_window_[0])); | 328 | 0 | } | 329 | | // TODO(eustas): overflow; mark BitReader as unhealthy | 330 | 0 | if (num_to_copy_ < 
lz77_min_length_) return 0; | 331 | | // the code below is the same as doing this: | 332 | | // return ReadHybridUintClustered<uses_lz77>(ctx, br); | 333 | | // but gcc doesn't like recursive inlining | 334 | | | 335 | 0 | size_t ret = lz77_window_[(copy_pos_++) & kWindowMask]; | 336 | 0 | num_to_copy_--; | 337 | 0 | lz77_window_[(num_decoded_++) & kWindowMask] = ret; | 338 | 0 | return ret; | 339 | 0 | } | 340 | 0 | } | 341 | 174M | size_t ret = ReadHybridUintConfig(configs[ctx], token, br); | 342 | 174M | if (uses_lz77 && lz77_window_) | 343 | 0 | lz77_window_[(num_decoded_++) & kWindowMask] = ret; | 344 | 174M | return ret; | 345 | 174M | } |
|
346 | | |
347 | | // same but not inlined |
348 | | template <bool uses_lz77> |
349 | 157M | size_t ReadHybridUintClustered(size_t ctx, BitReader* JXL_RESTRICT br) { |
350 | 157M | return ReadHybridUintClusteredInlined<uses_lz77>(ctx, br); |
351 | 157M | } unsigned long jxl::ANSSymbolReader::ReadHybridUintClustered<true>(unsigned long, jxl::BitReader*) Line | Count | Source | 349 | 153M | size_t ReadHybridUintClustered(size_t ctx, BitReader* JXL_RESTRICT br) { | 350 | 153M | return ReadHybridUintClusteredInlined<uses_lz77>(ctx, br); | 351 | 153M | } |
unsigned long jxl::ANSSymbolReader::ReadHybridUintClustered<false>(unsigned long, jxl::BitReader*) Line | Count | Source | 349 | 3.66M | size_t ReadHybridUintClustered(size_t ctx, BitReader* JXL_RESTRICT br) { | 350 | 3.66M | return ReadHybridUintClusteredInlined<uses_lz77>(ctx, br); | 351 | 3.66M | } |
|
352 | | |
353 | | // inlined only in the no-lz77 case |
354 | | template <bool uses_lz77> |
355 | | JXL_INLINE size_t |
356 | 20.3M | ReadHybridUintClusteredMaybeInlined(size_t ctx, BitReader* JXL_RESTRICT br) { |
357 | 20.3M | if (uses_lz77) { |
358 | 9.20M | return ReadHybridUintClustered<uses_lz77>(ctx, br); |
359 | 11.1M | } else { |
360 | 11.1M | return ReadHybridUintClusteredInlined<uses_lz77>(ctx, br); |
361 | 11.1M | } |
362 | 20.3M | } unsigned long jxl::ANSSymbolReader::ReadHybridUintClusteredMaybeInlined<true>(unsigned long, jxl::BitReader*) Line | Count | Source | 356 | 9.20M | ReadHybridUintClusteredMaybeInlined(size_t ctx, BitReader* JXL_RESTRICT br) { | 357 | 9.20M | if (uses_lz77) { | 358 | 9.20M | return ReadHybridUintClustered<uses_lz77>(ctx, br); | 359 | 9.20M | } else { | 360 | 0 | return ReadHybridUintClusteredInlined<uses_lz77>(ctx, br); | 361 | 0 | } | 362 | 9.20M | } |
unsigned long jxl::ANSSymbolReader::ReadHybridUintClusteredMaybeInlined<false>(unsigned long, jxl::BitReader*) Line | Count | Source | 356 | 11.1M | ReadHybridUintClusteredMaybeInlined(size_t ctx, BitReader* JXL_RESTRICT br) { | 357 | 11.1M | if (uses_lz77) { | 358 | 0 | return ReadHybridUintClustered<uses_lz77>(ctx, br); | 359 | 11.1M | } else { | 360 | 11.1M | return ReadHybridUintClusteredInlined<uses_lz77>(ctx, br); | 361 | 11.1M | } | 362 | 11.1M | } |
|
363 | | |
364 | | // inlined, for use in hot paths |
365 | | template <bool uses_lz77> |
366 | | JXL_INLINE size_t |
367 | | ReadHybridUintInlined(size_t ctx, BitReader* JXL_RESTRICT br, |
368 | 2.67M | const std::vector<uint8_t>& context_map) { |
369 | 2.67M | return ReadHybridUintClustered<uses_lz77>(context_map[ctx], br); |
370 | 2.67M | } unsigned long jxl::ANSSymbolReader::ReadHybridUintInlined<true>(unsigned long, jxl::BitReader*, std::__1::vector<unsigned char, std::__1::allocator<unsigned char> > const&) Line | Count | Source | 368 | 2.12M | const std::vector<uint8_t>& context_map) { | 369 | 2.12M | return ReadHybridUintClustered<uses_lz77>(context_map[ctx], br); | 370 | 2.12M | } |
unsigned long jxl::ANSSymbolReader::ReadHybridUintInlined<false>(unsigned long, jxl::BitReader*, std::__1::vector<unsigned char, std::__1::allocator<unsigned char> > const&) Line | Count | Source | 368 | 550k | const std::vector<uint8_t>& context_map) { | 369 | 550k | return ReadHybridUintClustered<uses_lz77>(context_map[ctx], br); | 370 | 550k | } |
|
371 | | |
372 | | // not inlined, for use in non-hot paths |
373 | | size_t ReadHybridUint(size_t ctx, BitReader* JXL_RESTRICT br, |
374 | 140M | const std::vector<uint8_t>& context_map) { |
375 | 140M | return ReadHybridUintClustered</*uses_lz77=*/true>(context_map[ctx], br); |
376 | 140M | } |
377 | | |
378 | | // ctx is a *clustered* context! |
379 | | // This function will modify the ANS state as if `count` symbols have been |
380 | | // decoded. |
381 | 128k | bool IsSingleValueAndAdvance(size_t ctx, uint32_t* value, size_t count) { |
382 | | // TODO(veluca): No optimization for Huffman mode yet. |
383 | 128k | if (use_prefix_code_) return false; |
384 | | // TODO(eustas): Check if we could deal with copy tail as well. |
385 | 71.3k | if (num_to_copy_ != 0) return false; |
386 | | // TODO(eustas): propagate "degenerate_symbol" to simplify this method. |
387 | 71.3k | const uint32_t res = state_ & (ANS_TAB_SIZE - 1u); |
388 | 71.3k | const AliasTable::Entry* table = &alias_tables_[ctx << log_alpha_size_]; |
389 | 71.3k | AliasTable::Symbol symbol = |
390 | 71.3k | AliasTable::Lookup(table, res, log_entry_size_, entry_size_minus_1_); |
391 | 71.3k | if (symbol.freq != ANS_TAB_SIZE) return false; |
392 | 41.0k | if (configs[ctx].split_token <= symbol.value) return false; |
393 | 40.1k | if (symbol.value >= lz77_threshold_) return false; |
394 | 40.1k | *value = symbol.value; |
395 | 40.1k | if (lz77_window_) { |
396 | 16.5M | for (size_t i = 0; i < count; i++) { |
397 | 16.5M | lz77_window_[(num_decoded_++) & kWindowMask] = symbol.value; |
398 | 16.5M | } |
399 | 22.5k | } |
400 | 40.1k | return true; |
401 | 40.1k | } |
402 | | |
403 | | static constexpr size_t kMaxCheckpointInterval = 512; |
404 | | struct Checkpoint { |
405 | | uint32_t state; |
406 | | uint32_t num_to_copy; |
407 | | uint32_t copy_pos; |
408 | | uint32_t num_decoded; |
409 | | uint32_t lz77_window[kMaxCheckpointInterval]; |
410 | | }; |
411 | 21.6k | void Save(Checkpoint* checkpoint) { |
412 | 21.6k | checkpoint->state = state_; |
413 | 21.6k | checkpoint->num_decoded = num_decoded_; |
414 | 21.6k | checkpoint->num_to_copy = num_to_copy_; |
415 | 21.6k | checkpoint->copy_pos = copy_pos_; |
416 | 21.6k | if (lz77_window_) { |
417 | 13.8k | size_t win_start = num_decoded_ & kWindowMask; |
418 | 13.8k | size_t win_end = (num_decoded_ + kMaxCheckpointInterval) & kWindowMask; |
419 | 13.8k | if (win_end > win_start) { |
420 | 13.8k | memcpy(checkpoint->lz77_window, lz77_window_ + win_start, |
421 | 13.8k | (win_end - win_start) * sizeof(*lz77_window_)); |
422 | 13.8k | } else { |
423 | 1 | memcpy(checkpoint->lz77_window, lz77_window_ + win_start, |
424 | 1 | (kWindowSize - win_start) * sizeof(*lz77_window_)); |
425 | 1 | memcpy(checkpoint->lz77_window + (kWindowSize - win_start), |
426 | 1 | lz77_window_, win_end * sizeof(*lz77_window_)); |
427 | 1 | } |
428 | 13.8k | } |
429 | 21.6k | } |
430 | 1.49k | void Restore(const Checkpoint& checkpoint) { |
431 | 1.49k | state_ = checkpoint.state; |
432 | 1.49k | JXL_DASSERT(num_decoded_ <= |
433 | 1.49k | checkpoint.num_decoded + kMaxCheckpointInterval); |
434 | 1.49k | num_decoded_ = checkpoint.num_decoded; |
435 | 1.49k | num_to_copy_ = checkpoint.num_to_copy; |
436 | 1.49k | copy_pos_ = checkpoint.copy_pos; |
437 | 1.49k | if (lz77_window_) { |
438 | 1.23k | size_t win_start = num_decoded_ & kWindowMask; |
439 | 1.23k | size_t win_end = (num_decoded_ + kMaxCheckpointInterval) & kWindowMask; |
440 | 1.23k | if (win_end > win_start) { |
441 | 1.23k | memcpy(lz77_window_ + win_start, checkpoint.lz77_window, |
442 | 1.23k | (win_end - win_start) * sizeof(*lz77_window_)); |
443 | 1.23k | } else { |
444 | 0 | memcpy(lz77_window_ + win_start, checkpoint.lz77_window, |
445 | 0 | (kWindowSize - win_start) * sizeof(*lz77_window_)); |
446 | 0 | memcpy(lz77_window_, checkpoint.lz77_window + (kWindowSize - win_start), |
447 | 0 | win_end * sizeof(*lz77_window_)); |
448 | 0 | } |
449 | 1.23k | } |
450 | 1.49k | } |
451 | | |
452 | | private: |
453 | | ANSSymbolReader(const ANSCode* code, BitReader* JXL_RESTRICT br, |
454 | | size_t distance_multiplier, |
455 | | AlignedMemory&& lz77_window_storage); |
456 | | |
457 | | const AliasTable::Entry* JXL_RESTRICT alias_tables_; // not owned |
458 | | const HuffmanDecodingData* huffman_data_; |
459 | | bool use_prefix_code_; |
460 | | uint32_t state_ = ANS_SIGNATURE << 16u; |
461 | | const HybridUintConfig* JXL_RESTRICT configs; |
462 | | uint32_t log_alpha_size_{}; |
463 | | uint32_t log_entry_size_{}; |
464 | | uint32_t entry_size_minus_1_{}; |
465 | | |
466 | | // LZ77 structures and constants. |
467 | | static constexpr size_t kWindowMask = kWindowSize - 1; |
468 | | // a std::vector incurs unacceptable decoding speed loss because of |
469 | | // initialization. |
470 | | AlignedMemory lz77_window_storage_; |
471 | | uint32_t* lz77_window_ = nullptr; |
472 | | uint32_t num_decoded_ = 0; |
473 | | uint32_t num_to_copy_ = 0; |
474 | | uint32_t copy_pos_ = 0; |
475 | | uint32_t lz77_ctx_ = 0; |
476 | | uint32_t lz77_min_length_ = 0; |
477 | | uint32_t lz77_threshold_ = 1 << 20; // bigger than any symbol. |
478 | | HybridUintConfig lz77_length_uint_; |
479 | | uint32_t special_distances_[kNumSpecialDistances]{}; |
480 | | uint32_t num_special_distances_{}; |
481 | | }; |
482 | | |
483 | | Status DecodeHistograms(JxlMemoryManager* memory_manager, BitReader* br, |
484 | | size_t num_contexts, ANSCode* code, |
485 | | std::vector<uint8_t>* context_map, |
486 | | bool disallow_lz77 = false); |
487 | | |
488 | | // Exposed for tests. |
489 | | Status DecodeUintConfigs(size_t log_alpha_size, |
490 | | std::vector<HybridUintConfig>* uint_config, |
491 | | BitReader* br); |
492 | | |
493 | | } // namespace jxl |
494 | | |
495 | | #endif // LIB_JXL_DEC_ANS_H_ |