/src/libjxl/lib/jxl/quant_weights.h
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright (c) the JPEG XL Project Authors. All rights reserved. |
2 | | // |
3 | | // Use of this source code is governed by a BSD-style |
4 | | // license that can be found in the LICENSE file. |
5 | | |
6 | | #ifndef LIB_JXL_QUANT_WEIGHTS_H_ |
7 | | #define LIB_JXL_QUANT_WEIGHTS_H_ |
8 | | |
#include <jxl/memory_manager.h>

#include <array>
#include <cstdint>
#include <cstring>
#include <utility>
#include <vector>

#include "lib/jxl/ac_strategy.h"
#include "lib/jxl/base/common.h"
#include "lib/jxl/base/compiler_specific.h"
#include "lib/jxl/base/status.h"
#include "lib/jxl/dec_bit_reader.h"
#include "lib/jxl/frame_dimensions.h"
#include "lib/jxl/memory_manager_internal.h"
23 | | |
24 | | namespace jxl { |
25 | | |
26 | | static constexpr size_t kMaxQuantTableSize = AcStrategy::kMaxCoeffArea; |
27 | | static constexpr size_t kNumPredefinedTables = 1; |
28 | | static constexpr size_t kCeilLog2NumPredefinedTables = 0; |
29 | | static constexpr size_t kLog2NumQuantModes = 3; |
30 | | |
// Distance-band parameters for DCT quantization weights: for each of the 3
// channels, storage for up to kMaxDistanceBands values, of which the first
// `num_distance_bands` are filled in by the constructors below.
struct DctQuantWeightParams {
  static constexpr size_t kLog2MaxDistanceBands = 4;
  static constexpr size_t kMaxDistanceBands = 1 + (1 << kLog2MaxDistanceBands);
  using DistanceBandsArray =
      std::array<std::array<float, kMaxDistanceBands>, 3>;

  size_t num_distance_bands = 0;
  DistanceBandsArray distance_bands = {};

  constexpr DctQuantWeightParams() : num_distance_bands(0) {}

  constexpr DctQuantWeightParams(const DistanceBandsArray& dist_bands,
                                 size_t num_dist_bands)
      : num_distance_bands(num_dist_bands), distance_bands(dist_bands) {}

  // Builds the parameters from a C array holding `num_dist_bands` values per
  // channel. Entries past `num_dist_bands` stay zero. The band count is
  // checked at compile time so the memcpy below can never write past the end
  // of a `distance_bands` row.
  template <size_t num_dist_bands>
  explicit DctQuantWeightParams(const float dist_bands[3][num_dist_bands]) {
    static_assert(num_dist_bands <= kMaxDistanceBands,
                  "Too many distance bands for DctQuantWeightParams");
    num_distance_bands = num_dist_bands;
    for (size_t c = 0; c < 3; c++) {
      memcpy(distance_bands[c].data(), dist_bands[c],
             sizeof(float) * num_dist_bands);
    }
  }
};
55 | | |
56 | | // NOLINTNEXTLINE(clang-analyzer-optin.performance.Padding) |
57 | | struct QuantEncodingInternal { |
58 | | enum Mode { |
59 | | kQuantModeLibrary, |
60 | | kQuantModeID, |
61 | | kQuantModeDCT2, |
62 | | kQuantModeDCT4, |
63 | | kQuantModeDCT4X8, |
64 | | kQuantModeAFV, |
65 | | kQuantModeDCT, |
66 | | kQuantModeRAW, |
67 | | }; |
68 | | |
69 | | template <Mode mode> |
70 | | struct Tag {}; |
71 | | |
72 | | using IdWeights = std::array<std::array<float, 3>, 3>; |
73 | | using DCT2Weights = std::array<std::array<float, 6>, 3>; |
74 | | using DCT4Multipliers = std::array<std::array<float, 2>, 3>; |
75 | | using AFVWeights = std::array<std::array<float, 9>, 3>; |
76 | | using DCT4x8Multipliers = std::array<float, 3>; |
77 | | |
78 | | template <size_t A> |
79 | 56.8k | static constexpr QuantEncodingInternal Library() { |
80 | 56.8k | static_assert(A < kNumPredefinedTables, "Library index out of bounds"); |
81 | 56.8k | return QuantEncodingInternal(Tag<kQuantModeLibrary>(), A); |
82 | 56.8k | } |
83 | | constexpr QuantEncodingInternal(Tag<kQuantModeLibrary> /* tag */, |
84 | | uint8_t predefined) |
85 | 56.8k | : mode(kQuantModeLibrary), predefined(predefined) {} |
86 | | |
87 | | // Identity |
88 | | // xybweights is an array of {xweights, yweights, bweights}. |
89 | 2 | static constexpr QuantEncodingInternal Identity(const IdWeights& xybweights) { |
90 | 2 | return QuantEncodingInternal(Tag<kQuantModeID>(), xybweights); |
91 | 2 | } |
92 | | constexpr QuantEncodingInternal(Tag<kQuantModeID> /* tag */, |
93 | | const IdWeights& xybweights) |
94 | 2 | : mode(kQuantModeID), idweights(xybweights) {} |
95 | | |
96 | | // DCT2 |
97 | 2 | static constexpr QuantEncodingInternal DCT2(const DCT2Weights& xybweights) { |
98 | 2 | return QuantEncodingInternal(Tag<kQuantModeDCT2>(), xybweights); |
99 | 2 | } |
100 | | constexpr QuantEncodingInternal(Tag<kQuantModeDCT2> /* tag */, |
101 | | const DCT2Weights& xybweights) |
102 | 2 | : mode(kQuantModeDCT2), dct2weights(xybweights) {} |
103 | | |
104 | | // DCT4 |
105 | | static constexpr QuantEncodingInternal DCT4( |
106 | 4 | const DctQuantWeightParams& params, const DCT4Multipliers& xybmul) { |
107 | 4 | return QuantEncodingInternal(Tag<kQuantModeDCT4>(), params, xybmul); |
108 | 4 | } |
109 | | constexpr QuantEncodingInternal(Tag<kQuantModeDCT4> /* tag */, |
110 | | const DctQuantWeightParams& params, |
111 | | const DCT4Multipliers& xybmul) |
112 | 4 | : mode(kQuantModeDCT4), dct_params(params), dct4multipliers(xybmul) {} |
113 | | |
114 | | // DCT4x8 |
115 | | static constexpr QuantEncodingInternal DCT4X8( |
116 | 4 | const DctQuantWeightParams& params, const DCT4x8Multipliers& xybmul) { |
117 | 4 | return QuantEncodingInternal(Tag<kQuantModeDCT4X8>(), params, xybmul); |
118 | 4 | } |
119 | | constexpr QuantEncodingInternal(Tag<kQuantModeDCT4X8> /* tag */, |
120 | | const DctQuantWeightParams& params, |
121 | | const DCT4x8Multipliers& xybmul) |
122 | 4 | : mode(kQuantModeDCT4X8), dct_params(params), dct4x8multipliers(xybmul) {} |
123 | | |
124 | | // DCT |
125 | | static constexpr QuantEncodingInternal DCT( |
126 | 24 | const DctQuantWeightParams& params) { |
127 | 24 | return QuantEncodingInternal(Tag<kQuantModeDCT>(), params); |
128 | 24 | } |
129 | | constexpr QuantEncodingInternal(Tag<kQuantModeDCT> /* tag */, |
130 | | const DctQuantWeightParams& params) |
131 | 24 | : mode(kQuantModeDCT), dct_params(params) {} |
132 | | |
133 | | // AFV |
134 | | static constexpr QuantEncodingInternal AFV( |
135 | | const DctQuantWeightParams& params4x8, |
136 | 2 | const DctQuantWeightParams& params4x4, const AFVWeights& weights) { |
137 | 2 | return QuantEncodingInternal(Tag<kQuantModeAFV>(), params4x8, params4x4, |
138 | 2 | weights); |
139 | 2 | } |
140 | | constexpr QuantEncodingInternal(Tag<kQuantModeAFV> /* tag */, |
141 | | const DctQuantWeightParams& params4x8, |
142 | | const DctQuantWeightParams& params4x4, |
143 | | const AFVWeights& weights) |
144 | 2 | : mode(kQuantModeAFV), |
145 | 2 | dct_params(params4x8), |
146 | 2 | afv_weights(weights), |
147 | 2 | dct_params_afv_4x4(params4x4) {} |
148 | | |
149 | | // This constructor is not constexpr so it can't be used in any of the |
150 | | // constexpr cases above. |
151 | 0 | explicit QuantEncodingInternal(Mode mode) : mode(mode) {} |
152 | | |
153 | | Mode mode; |
154 | | |
155 | | // Weights for DCT4+ tables. |
156 | | DctQuantWeightParams dct_params; |
157 | | |
158 | | union { |
159 | | // Weights for identity. |
160 | | IdWeights idweights; |
161 | | |
162 | | // Weights for DCT2. |
163 | | DCT2Weights dct2weights; |
164 | | |
165 | | // Extra multipliers for coefficients 01/10 and 11 for DCT4 and AFV. |
166 | | DCT4Multipliers dct4multipliers; |
167 | | |
168 | | // Weights for AFV. {0, 1} are used directly for coefficients (0, 1) and (1, |
169 | | // 0); {2, 3, 4} are used directly corner DC, (1,0) - (0,1) and (0, 1) + |
170 | | // (1, 0) - (0, 0) inside the AFV block. Values from 5 to 8 are interpolated |
171 | | // as in GetQuantWeights for DC and are used for other coefficients. |
172 | | AFVWeights afv_weights = {}; |
173 | | |
174 | | // Extra multipliers for coefficients 01 or 10 for DCT4X8 and DCT8X4. |
175 | | DCT4x8Multipliers dct4x8multipliers; |
176 | | |
177 | | // Only used in kQuantModeRAW mode. |
178 | | struct { |
179 | | // explicit quantization table (like in JPEG) |
180 | | std::vector<int>* qtable = nullptr; |
181 | | float qtable_den = 1.f / (8 * 255); |
182 | | } qraw; |
183 | | }; |
184 | | |
185 | | // Weights for 4x4 sub-block in AFV. |
186 | | DctQuantWeightParams dct_params_afv_4x4; |
187 | | |
188 | | // Which predefined table to use. Only used if mode is kQuantModeLibrary. |
189 | | uint8_t predefined = 0; |
190 | | }; |
191 | | |
192 | | class QuantEncoding final : public QuantEncodingInternal { |
193 | | public: |
194 | | QuantEncoding(const QuantEncoding& other) |
195 | 966k | : QuantEncodingInternal( |
196 | 966k | static_cast<const QuantEncodingInternal&>(other)) { |
197 | 966k | if (mode == kQuantModeRAW && qraw.qtable) { |
198 | | // Need to make a copy of the passed *qtable. |
199 | 0 | qraw.qtable = new std::vector<int>(*other.qraw.qtable); |
200 | 0 | } |
201 | 966k | } |
202 | | QuantEncoding(QuantEncoding&& other) noexcept |
203 | 0 | : QuantEncodingInternal( |
204 | 0 | static_cast<const QuantEncodingInternal&>(other)) { |
205 | | // Steal the qtable from the other object if any. |
206 | 0 | if (mode == kQuantModeRAW) { |
207 | 0 | other.qraw.qtable = nullptr; |
208 | 0 | } |
209 | 0 | } |
210 | 0 | QuantEncoding& operator=(const QuantEncoding& other) { |
211 | 0 | if (mode == kQuantModeRAW && qraw.qtable) { |
212 | 0 | delete qraw.qtable; |
213 | 0 | } |
214 | 0 | *static_cast<QuantEncodingInternal*>(this) = |
215 | 0 | QuantEncodingInternal(static_cast<const QuantEncodingInternal&>(other)); |
216 | 0 | if (mode == kQuantModeRAW && qraw.qtable) { |
217 | | // Need to make a copy of the passed *qtable. |
218 | 0 | qraw.qtable = new std::vector<int>(*other.qraw.qtable); |
219 | 0 | } |
220 | 0 | return *this; |
221 | 0 | } |
222 | | |
223 | 1.02M | ~QuantEncoding() { |
224 | 1.02M | if (mode == kQuantModeRAW && qraw.qtable) { |
225 | 46 | delete qraw.qtable; |
226 | 46 | } |
227 | 1.02M | } |
228 | | |
229 | | // Wrappers of the QuantEncodingInternal:: static functions that return a |
230 | | // QuantEncoding instead. This is using the explicit and private cast from |
231 | | // QuantEncodingInternal to QuantEncoding, which would be inlined anyway. |
232 | | // In general, you should use this wrappers. The only reason to directly |
233 | | // create a QuantEncodingInternal instance is if you need a constexpr version |
234 | | // of this class. Note that RAW() is not supported in that case since it uses |
235 | | // a std::vector. |
236 | | template <size_t A> |
237 | 56.8k | static QuantEncoding Library() { |
238 | 56.8k | return QuantEncoding(QuantEncodingInternal::Library<A>()); |
239 | 56.8k | } |
240 | 0 | static QuantEncoding Identity(const IdWeights& xybweights) { |
241 | 0 | return QuantEncoding(QuantEncodingInternal::Identity(xybweights)); |
242 | 0 | } |
243 | 0 | static QuantEncoding DCT2(const DCT2Weights& xybweights) { |
244 | 0 | return QuantEncoding(QuantEncodingInternal::DCT2(xybweights)); |
245 | 0 | } |
246 | | static QuantEncoding DCT4(const DctQuantWeightParams& params, |
247 | 0 | const DCT4Multipliers& xybmul) { |
248 | 0 | return QuantEncoding(QuantEncodingInternal::DCT4(params, xybmul)); |
249 | 0 | } |
250 | | static QuantEncoding DCT4X8(const DctQuantWeightParams& params, |
251 | 0 | const DCT4x8Multipliers& xybmul) { |
252 | 0 | return QuantEncoding(QuantEncodingInternal::DCT4X8(params, xybmul)); |
253 | 0 | } |
254 | 0 | static QuantEncoding DCT(const DctQuantWeightParams& params) { |
255 | 0 | return QuantEncoding(QuantEncodingInternal::DCT(params)); |
256 | 0 | } |
257 | | static QuantEncoding AFV(const DctQuantWeightParams& params4x8, |
258 | | const DctQuantWeightParams& params4x4, |
259 | 0 | const AFVWeights& weights) { |
260 | 0 | return QuantEncoding( |
261 | 0 | QuantEncodingInternal::AFV(params4x8, params4x4, weights)); |
262 | 0 | } |
263 | | |
264 | | // RAW, note that this one is not a constexpr one. |
265 | 0 | static QuantEncoding RAW(std::vector<int>&& qtable, int shift = 0) { |
266 | 0 | QuantEncoding encoding(kQuantModeRAW); |
267 | 0 | encoding.qraw.qtable = new std::vector<int>(); |
268 | 0 | *encoding.qraw.qtable = qtable; |
269 | 0 | encoding.qraw.qtable_den = (1 << shift) * (1.f / (8 * 255)); |
270 | 0 | return encoding; |
271 | 0 | } |
272 | | |
273 | | private: |
274 | | explicit QuantEncoding(const QuantEncodingInternal& other) |
275 | 56.8k | : QuantEncodingInternal(other) {} |
276 | | |
277 | | explicit QuantEncoding(QuantEncodingInternal::Mode mode_arg) |
278 | 0 | : QuantEncodingInternal(mode_arg) {} |
279 | | }; |
280 | | |
281 | | // A constexpr QuantEncodingInternal instance is often downcasted to the |
282 | | // QuantEncoding subclass even if the instance wasn't an instance of the |
283 | | // subclass. This is safe because user will upcast to QuantEncodingInternal to |
284 | | // access any of its members. |
285 | | static_assert(sizeof(QuantEncoding) == sizeof(QuantEncodingInternal), |
286 | | "Don't add any members to QuantEncoding"); |
287 | | |
// Per-channel DC quantization reciprocals (kInvDCQuant) and factors
// (kDCQuant). Let's try to keep these 2**N for possible future simplicity.
// `constexpr` guarantees compile-time initialization and allows use in
// constant expressions.
constexpr float kInvDCQuant[3] = {
    4096.0f,
    512.0f,
    256.0f,
};

constexpr float kDCQuant[3] = {
    1.0f / kInvDCQuant[0],
    1.0f / kInvDCQuant[1],
    1.0f / kInvDCQuant[2],
};
300 | | |
// Forward declarations.
class ModularFrameEncoder;
class ModularFrameDecoder;

// Index of each distinct quantization table. AC strategies named in the
// trailing comments have no entry of their own and use the table on that line.
enum class QuantTable : size_t {
  DCT = 0,
  IDENTITY = 1,
  DCT2X2 = 2,
  DCT4X4 = 3,
  DCT16X16 = 4,
  DCT32X32 = 5,
  DCT8X16 = 6,      // DCT16X8 too
  DCT8X32 = 7,      // DCT32X8 too
  DCT16X32 = 8,     // DCT32X16 too
  DCT4X8 = 9,       // DCT8X4 too
  AFV0 = 10,        // AFV1, AFV2, AFV3 too
  DCT64X64 = 11,
  DCT32X64 = 12,    // DCT64X32 too
  DCT128X128 = 13,
  DCT64X128 = 14,   // DCT128X64 too
  DCT256X256 = 15,
  DCT128X256 = 16,  // DCT256X128 too
};

// Total number of QuantTable entries (one past the last enumerator).
static constexpr uint8_t kNumQuantTables =
    static_cast<uint8_t>(static_cast<size_t>(QuantTable::DCT128X256) + 1);
336 | | |
337 | | static const std::array<QuantTable, AcStrategy::kNumValidStrategies> |
338 | | kAcStrategyToQuantTableMap = { |
339 | | QuantTable::DCT, QuantTable::IDENTITY, QuantTable::DCT2X2, |
340 | | QuantTable::DCT4X4, QuantTable::DCT16X16, QuantTable::DCT32X32, |
341 | | QuantTable::DCT8X16, QuantTable::DCT8X16, QuantTable::DCT8X32, |
342 | | QuantTable::DCT8X32, QuantTable::DCT16X32, QuantTable::DCT16X32, |
343 | | QuantTable::DCT4X8, QuantTable::DCT4X8, QuantTable::AFV0, |
344 | | QuantTable::AFV0, QuantTable::AFV0, QuantTable::AFV0, |
345 | | QuantTable::DCT64X64, QuantTable::DCT32X64, QuantTable::DCT32X64, |
346 | | QuantTable::DCT128X128, QuantTable::DCT64X128, QuantTable::DCT64X128, |
347 | | QuantTable::DCT256X256, QuantTable::DCT128X256, QuantTable::DCT128X256, |
348 | | }; |
349 | | |
350 | | class DequantMatrices { |
351 | | public: |
352 | | DequantMatrices(); |
353 | | |
354 | | static const QuantEncoding* Library(); |
355 | | |
356 | | using DequantLibraryInternal = |
357 | | std::array<QuantEncodingInternal, kNumPredefinedTables * kNumQuantTables>; |
358 | | // Return the array of library kNumPredefinedTables QuantEncoding entries as |
359 | | // a constexpr array. Use Library() to obtain a pointer to the copy in the |
360 | | // .cc file. |
361 | | static DequantLibraryInternal LibraryInit(); |
362 | | |
363 | | // Returns aligned memory. |
364 | 16.2M | JXL_INLINE const float* Matrix(AcStrategyType quant_kind, size_t c) const { |
365 | 16.2M | JXL_DASSERT((1 << static_cast<uint32_t>(quant_kind)) & computed_mask_); |
366 | 16.2M | return &table_[table_offsets_[static_cast<size_t>(quant_kind) * 3 + c]]; |
367 | 16.2M | } |
368 | | |
369 | 17.4M | JXL_INLINE const float* InvMatrix(AcStrategyType quant_kind, size_t c) const { |
370 | 17.4M | size_t quant_table_idx = static_cast<uint32_t>(quant_kind); |
371 | 17.4M | JXL_DASSERT((1 << quant_table_idx) & computed_mask_); |
372 | 17.4M | return &inv_table_[table_offsets_[quant_table_idx * 3 + c]]; |
373 | 17.4M | } |
374 | | |
375 | | // DC quants are used in modular mode for XYB multipliers. |
376 | 48.1k | JXL_INLINE float DCQuant(size_t c) const { return dc_quant_[c]; } |
377 | 47.7k | JXL_INLINE const float* DCQuants() const { return dc_quant_; } |
378 | | |
379 | 532k | JXL_INLINE float InvDCQuant(size_t c) const { return inv_dc_quant_[c]; } |
380 | | |
381 | | // For encoder. |
382 | 0 | void SetEncodings(const std::vector<QuantEncoding>& encodings) { |
383 | 0 | encodings_ = encodings; |
384 | 0 | computed_mask_ = 0; |
385 | 0 | } |
386 | | |
387 | | // For encoder. |
388 | 283 | void SetDCQuant(const float dc[3]) { |
389 | 1.13k | for (size_t c = 0; c < 3; c++) { |
390 | 849 | dc_quant_[c] = 1.0f / dc[c]; |
391 | 849 | inv_dc_quant_[c] = dc[c]; |
392 | 849 | } |
393 | 283 | } |
394 | | |
395 | | Status Decode(JxlMemoryManager* memory_manager, BitReader* br, |
396 | | ModularFrameDecoder* modular_frame_decoder = nullptr); |
397 | | Status DecodeDC(BitReader* br); |
398 | | |
399 | 186 | const std::vector<QuantEncoding>& encodings() const { return encodings_; } |
400 | | |
401 | | static constexpr auto required_size_x = |
402 | | to_array<int>({1, 1, 1, 1, 2, 4, 1, 1, 2, 1, 1, 8, 4, 16, 8, 32, 16}); |
403 | | static_assert(kNumQuantTables == required_size_x.size(), |
404 | | "Update this array when adding or removing quant tables."); |
405 | | |
406 | | static constexpr auto required_size_y = |
407 | | to_array<int>({1, 1, 1, 1, 2, 4, 2, 4, 4, 1, 1, 8, 8, 16, 16, 32, 32}); |
408 | | static_assert(kNumQuantTables == required_size_y.size(), |
409 | | "Update this array when adding or removing quant tables."); |
410 | | |
411 | | // MUST be equal `sum(dot(required_size_x, required_size_y))`. |
412 | | static constexpr size_t kSumRequiredXy = 2056; |
413 | | |
414 | | Status EnsureComputed(JxlMemoryManager* memory_manager, uint32_t acs_mask); |
415 | | |
416 | | private: |
417 | | static constexpr size_t kTotalTableSize = kSumRequiredXy * kDCTBlockSize * 3; |
418 | | |
419 | | uint32_t computed_mask_ = 0; |
420 | | // kTotalTableSize entries followed by kTotalTableSize for inv_table |
421 | | AlignedMemory table_storage_; |
422 | | const float* table_; |
423 | | const float* inv_table_; |
424 | | float dc_quant_[3] = {kDCQuant[0], kDCQuant[1], kDCQuant[2]}; |
425 | | float inv_dc_quant_[3] = {kInvDCQuant[0], kInvDCQuant[1], kInvDCQuant[2]}; |
426 | | size_t table_offsets_[AcStrategy::kNumValidStrategies * 3]; |
427 | | std::vector<QuantEncoding> encodings_; |
428 | | }; |
429 | | |
430 | | } // namespace jxl |
431 | | |
432 | | #endif // LIB_JXL_QUANT_WEIGHTS_H_ |