/src/libjxl/lib/jxl/quant_weights.h
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright (c) the JPEG XL Project Authors. All rights reserved. |
2 | | // |
3 | | // Use of this source code is governed by a BSD-style |
4 | | // license that can be found in the LICENSE file. |
5 | | |
6 | | #ifndef LIB_JXL_QUANT_WEIGHTS_H_ |
7 | | #define LIB_JXL_QUANT_WEIGHTS_H_ |
8 | | |
#include <jxl/memory_manager.h>

#include <array>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <utility>
#include <vector>

#include "lib/jxl/ac_strategy.h"
#include "lib/jxl/base/common.h"
#include "lib/jxl/base/compiler_specific.h"
#include "lib/jxl/base/status.h"
#include "lib/jxl/dec_bit_reader.h"
#include "lib/jxl/frame_dimensions.h"
#include "lib/jxl/memory_manager_internal.h"
23 | | |
24 | | namespace jxl { |
25 | | |
26 | | static constexpr size_t kMaxQuantTableSize = AcStrategy::kMaxCoeffArea; |
27 | | static constexpr size_t kNumPredefinedTables = 1; |
28 | | static constexpr size_t kCeilLog2NumPredefinedTables = 0; |
29 | | static constexpr size_t kLog2NumQuantModes = 3; |
30 | | |
// Distance-band parameters for a DCT quantization table: three rows of up to
// kMaxDistanceBands values each, of which the first `num_distance_bands`
// entries per row are meaningful.
struct DctQuantWeightParams {
  static constexpr size_t kLog2MaxDistanceBands = 4;
  static constexpr size_t kMaxDistanceBands = 1 + (1 << kLog2MaxDistanceBands);
  using DistanceBandsArray =
      std::array<std::array<float, kMaxDistanceBands>, 3>;

  // Number of valid entries in each row of `distance_bands`.
  size_t num_distance_bands = 0;
  // Band values; entries past `num_distance_bands` stay zero-initialized
  // unless the caller supplied them through the array constructor.
  DistanceBandsArray distance_bands = {};

  // Creates an empty parameter set (zero bands, all values zero).
  constexpr DctQuantWeightParams() : num_distance_bands(0) {}

  // Copies a full band array and records how many entries per row are in use.
  constexpr DctQuantWeightParams(const DistanceBandsArray& dist_bands,
                                 size_t num_dist_bands)
      : num_distance_bands(num_dist_bands), distance_bands(dist_bands) {}

  // Builds the parameters from a plain 3 x num_dist_bands array. The band
  // count is deduced from the argument type, so it is checked at compile time
  // against the capacity of `distance_bands`.
  template <size_t num_dist_bands>
  explicit DctQuantWeightParams(const float dist_bands[3][num_dist_bands]) {
    // Without this check an oversized input would make memcpy write past the
    // end of each kMaxDistanceBands-element row.
    static_assert(num_dist_bands <= kMaxDistanceBands,
                  "Too many distance bands for DctQuantWeightParams");
    num_distance_bands = num_dist_bands;
    for (size_t c = 0; c < 3; c++) {
      memcpy(distance_bands[c].data(), dist_bands[c],
             sizeof(float) * num_dist_bands);
    }
  }
};
55 | | |
56 | | // NOLINTNEXTLINE(clang-analyzer-optin.performance.Padding) |
57 | | struct QuantEncodingInternal { |
58 | | enum Mode { |
59 | | kQuantModeLibrary, |
60 | | kQuantModeID, |
61 | | kQuantModeDCT2, |
62 | | kQuantModeDCT4, |
63 | | kQuantModeDCT4X8, |
64 | | kQuantModeAFV, |
65 | | kQuantModeDCT, |
66 | | kQuantModeRAW, |
67 | | }; |
68 | | |
69 | | template <Mode mode> |
70 | | struct Tag {}; |
71 | | |
72 | | typedef std::array<std::array<float, 3>, 3> IdWeights; |
73 | | typedef std::array<std::array<float, 6>, 3> DCT2Weights; |
74 | | typedef std::array<std::array<float, 2>, 3> DCT4Multipliers; |
75 | | typedef std::array<std::array<float, 9>, 3> AFVWeights; |
76 | | typedef std::array<float, 3> DCT4x8Multipliers; |
77 | | |
78 | | template <size_t A> |
79 | 153k | static constexpr QuantEncodingInternal Library() { |
80 | 153k | static_assert(A < kNumPredefinedTables); |
81 | 153k | return QuantEncodingInternal(Tag<kQuantModeLibrary>(), A); |
82 | 153k | } |
83 | | constexpr QuantEncodingInternal(Tag<kQuantModeLibrary> /* tag */, |
84 | | uint8_t predefined) |
85 | 153k | : mode(kQuantModeLibrary), predefined(predefined) {} |
86 | | |
87 | | // Identity |
88 | | // xybweights is an array of {xweights, yweights, bweights}. |
89 | 1 | static constexpr QuantEncodingInternal Identity(const IdWeights& xybweights) { |
90 | 1 | return QuantEncodingInternal(Tag<kQuantModeID>(), xybweights); |
91 | 1 | } |
92 | | constexpr QuantEncodingInternal(Tag<kQuantModeID> /* tag */, |
93 | | const IdWeights& xybweights) |
94 | 1 | : mode(kQuantModeID), idweights(xybweights) {} |
95 | | |
96 | | // DCT2 |
97 | 1 | static constexpr QuantEncodingInternal DCT2(const DCT2Weights& xybweights) { |
98 | 1 | return QuantEncodingInternal(Tag<kQuantModeDCT2>(), xybweights); |
99 | 1 | } |
100 | | constexpr QuantEncodingInternal(Tag<kQuantModeDCT2> /* tag */, |
101 | | const DCT2Weights& xybweights) |
102 | 1 | : mode(kQuantModeDCT2), dct2weights(xybweights) {} |
103 | | |
104 | | // DCT4 |
105 | | static constexpr QuantEncodingInternal DCT4( |
106 | 2 | const DctQuantWeightParams& params, const DCT4Multipliers& xybmul) { |
107 | 2 | return QuantEncodingInternal(Tag<kQuantModeDCT4>(), params, xybmul); |
108 | 2 | } |
109 | | constexpr QuantEncodingInternal(Tag<kQuantModeDCT4> /* tag */, |
110 | | const DctQuantWeightParams& params, |
111 | | const DCT4Multipliers& xybmul) |
112 | 2 | : mode(kQuantModeDCT4), dct_params(params), dct4multipliers(xybmul) {} |
113 | | |
114 | | // DCT4x8 |
115 | | static constexpr QuantEncodingInternal DCT4X8( |
116 | 2 | const DctQuantWeightParams& params, const DCT4x8Multipliers& xybmul) { |
117 | 2 | return QuantEncodingInternal(Tag<kQuantModeDCT4X8>(), params, xybmul); |
118 | 2 | } |
119 | | constexpr QuantEncodingInternal(Tag<kQuantModeDCT4X8> /* tag */, |
120 | | const DctQuantWeightParams& params, |
121 | | const DCT4x8Multipliers& xybmul) |
122 | 2 | : mode(kQuantModeDCT4X8), dct_params(params), dct4x8multipliers(xybmul) {} |
123 | | |
124 | | // DCT |
125 | | static constexpr QuantEncodingInternal DCT( |
126 | 12 | const DctQuantWeightParams& params) { |
127 | 12 | return QuantEncodingInternal(Tag<kQuantModeDCT>(), params); |
128 | 12 | } |
129 | | constexpr QuantEncodingInternal(Tag<kQuantModeDCT> /* tag */, |
130 | | const DctQuantWeightParams& params) |
131 | 12 | : mode(kQuantModeDCT), dct_params(params) {} |
132 | | |
133 | | // AFV |
134 | | static constexpr QuantEncodingInternal AFV( |
135 | | const DctQuantWeightParams& params4x8, |
136 | 1 | const DctQuantWeightParams& params4x4, const AFVWeights& weights) { |
137 | 1 | return QuantEncodingInternal(Tag<kQuantModeAFV>(), params4x8, params4x4, |
138 | 1 | weights); |
139 | 1 | } |
140 | | constexpr QuantEncodingInternal(Tag<kQuantModeAFV> /* tag */, |
141 | | const DctQuantWeightParams& params4x8, |
142 | | const DctQuantWeightParams& params4x4, |
143 | | const AFVWeights& weights) |
144 | 1 | : mode(kQuantModeAFV), |
145 | 1 | dct_params(params4x8), |
146 | 1 | afv_weights(weights), |
147 | 1 | dct_params_afv_4x4(params4x4) {} |
148 | | |
149 | | // This constructor is not constexpr so it can't be used in any of the |
150 | | // constexpr cases above. |
151 | 0 | explicit QuantEncodingInternal(Mode mode) : mode(mode) {} |
152 | | |
153 | | Mode mode; |
154 | | |
155 | | // Weights for DCT4+ tables. |
156 | | DctQuantWeightParams dct_params; |
157 | | |
158 | | union { |
159 | | // Weights for identity. |
160 | | IdWeights idweights; |
161 | | |
162 | | // Weights for DCT2. |
163 | | DCT2Weights dct2weights; |
164 | | |
165 | | // Extra multipliers for coefficients 01/10 and 11 for DCT4 and AFV. |
166 | | DCT4Multipliers dct4multipliers; |
167 | | |
168 | | // Weights for AFV. {0, 1} are used directly for coefficients (0, 1) and (1, |
169 | | // 0); {2, 3, 4} are used directly corner DC, (1,0) - (0,1) and (0, 1) + |
170 | | // (1, 0) - (0, 0) inside the AFV block. Values from 5 to 8 are interpolated |
171 | | // as in GetQuantWeights for DC and are used for other coefficients. |
172 | | AFVWeights afv_weights = {}; |
173 | | |
174 | | // Extra multipliers for coefficients 01 or 10 for DCT4X8 and DCT8X4. |
175 | | DCT4x8Multipliers dct4x8multipliers; |
176 | | |
177 | | // Only used in kQuantModeRAW mode. |
178 | | struct { |
179 | | // explicit quantization table (like in JPEG) |
180 | | std::vector<int>* qtable = nullptr; |
181 | | float qtable_den = 1.f / (8 * 255); |
182 | | } qraw; |
183 | | }; |
184 | | |
185 | | // Weights for 4x4 sub-block in AFV. |
186 | | DctQuantWeightParams dct_params_afv_4x4; |
187 | | |
188 | | union { |
189 | | // Which predefined table to use. Only used if mode is kQuantModeLibrary. |
190 | | uint8_t predefined = 0; |
191 | | |
192 | | // Which other quant table to copy; must copy from a table that comes before |
193 | | // the current one. Only used if mode is kQuantModeCopy. |
194 | | uint8_t source; |
195 | | }; |
196 | | }; |
197 | | |
198 | | class QuantEncoding final : public QuantEncodingInternal { |
199 | | public: |
200 | | QuantEncoding(const QuantEncoding& other) |
201 | 2.61M | : QuantEncodingInternal( |
202 | 2.61M | static_cast<const QuantEncodingInternal&>(other)) { |
203 | 2.61M | if (mode == kQuantModeRAW && qraw.qtable) { |
204 | | // Need to make a copy of the passed *qtable. |
205 | 0 | qraw.qtable = new std::vector<int>(*other.qraw.qtable); |
206 | 0 | } |
207 | 2.61M | } |
208 | | QuantEncoding(QuantEncoding&& other) noexcept |
209 | 0 | : QuantEncodingInternal( |
210 | 0 | static_cast<const QuantEncodingInternal&>(other)) { |
211 | | // Steal the qtable from the other object if any. |
212 | 0 | if (mode == kQuantModeRAW) { |
213 | 0 | other.qraw.qtable = nullptr; |
214 | 0 | } |
215 | 0 | } |
216 | 0 | QuantEncoding& operator=(const QuantEncoding& other) { |
217 | 0 | if (mode == kQuantModeRAW && qraw.qtable) { |
218 | 0 | delete qraw.qtable; |
219 | 0 | } |
220 | 0 | *static_cast<QuantEncodingInternal*>(this) = |
221 | 0 | QuantEncodingInternal(static_cast<const QuantEncodingInternal&>(other)); |
222 | 0 | if (mode == kQuantModeRAW && qraw.qtable) { |
223 | | // Need to make a copy of the passed *qtable. |
224 | 0 | qraw.qtable = new std::vector<int>(*other.qraw.qtable); |
225 | 0 | } |
226 | 0 | return *this; |
227 | 0 | } |
228 | | |
229 | 2.76M | ~QuantEncoding() { |
230 | 2.76M | if (mode == kQuantModeRAW && qraw.qtable) { |
231 | 58 | delete qraw.qtable; |
232 | 58 | } |
233 | 2.76M | } |
234 | | |
235 | | // Wrappers of the QuantEncodingInternal:: static functions that return a |
236 | | // QuantEncoding instead. This is using the explicit and private cast from |
237 | | // QuantEncodingInternal to QuantEncoding, which would be inlined anyway. |
238 | | // In general, you should use this wrappers. The only reason to directly |
239 | | // create a QuantEncodingInternal instance is if you need a constexpr version |
240 | | // of this class. Note that RAW() is not supported in that case since it uses |
241 | | // a std::vector. |
242 | | template <size_t A> |
243 | 153k | static QuantEncoding Library() { |
244 | 153k | return QuantEncoding(QuantEncodingInternal::Library<A>()); |
245 | 153k | } |
246 | 0 | static QuantEncoding Identity(const IdWeights& xybweights) { |
247 | 0 | return QuantEncoding(QuantEncodingInternal::Identity(xybweights)); |
248 | 0 | } |
249 | 0 | static QuantEncoding DCT2(const DCT2Weights& xybweights) { |
250 | 0 | return QuantEncoding(QuantEncodingInternal::DCT2(xybweights)); |
251 | 0 | } |
252 | | static QuantEncoding DCT4(const DctQuantWeightParams& params, |
253 | 0 | const DCT4Multipliers& xybmul) { |
254 | 0 | return QuantEncoding(QuantEncodingInternal::DCT4(params, xybmul)); |
255 | 0 | } |
256 | | static QuantEncoding DCT4X8(const DctQuantWeightParams& params, |
257 | 0 | const DCT4x8Multipliers& xybmul) { |
258 | 0 | return QuantEncoding(QuantEncodingInternal::DCT4X8(params, xybmul)); |
259 | 0 | } |
260 | 0 | static QuantEncoding DCT(const DctQuantWeightParams& params) { |
261 | 0 | return QuantEncoding(QuantEncodingInternal::DCT(params)); |
262 | 0 | } |
263 | | static QuantEncoding AFV(const DctQuantWeightParams& params4x8, |
264 | | const DctQuantWeightParams& params4x4, |
265 | 0 | const AFVWeights& weights) { |
266 | 0 | return QuantEncoding( |
267 | 0 | QuantEncodingInternal::AFV(params4x8, params4x4, weights)); |
268 | 0 | } |
269 | | |
270 | | // RAW, note that this one is not a constexpr one. |
271 | 0 | static QuantEncoding RAW(std::vector<int>&& qtable, int shift = 0) { |
272 | 0 | QuantEncoding encoding(kQuantModeRAW); |
273 | 0 | encoding.qraw.qtable = new std::vector<int>(); |
274 | 0 | *encoding.qraw.qtable = qtable; |
275 | 0 | encoding.qraw.qtable_den = (1 << shift) * (1.f / (8 * 255)); |
276 | 0 | return encoding; |
277 | 0 | } |
278 | | |
279 | | private: |
280 | | explicit QuantEncoding(const QuantEncodingInternal& other) |
281 | 153k | : QuantEncodingInternal(other) {} |
282 | | |
283 | | explicit QuantEncoding(QuantEncodingInternal::Mode mode_arg) |
284 | 0 | : QuantEncodingInternal(mode_arg) {} |
285 | | }; |
286 | | |
287 | | // A constexpr QuantEncodingInternal instance is often downcasted to the |
288 | | // QuantEncoding subclass even if the instance wasn't an instance of the |
289 | | // subclass. This is safe because user will upcast to QuantEncodingInternal to |
290 | | // access any of its members. |
291 | | static_assert(sizeof(QuantEncoding) == sizeof(QuantEncodingInternal), |
292 | | "Don't add any members to QuantEncoding"); |
293 | | |
// Default DC quantization multipliers, one entry per channel index.
// Let's try to keep these 2**N for possible future simplicity.
// Declared constexpr so both arrays are constant-initialized and usable in
// constant expressions.
static constexpr float kInvDCQuant[3] = {
    4096.0f,
    512.0f,
    256.0f,
};

// Reciprocals of kInvDCQuant, entry by entry.
static constexpr float kDCQuant[3] = {
    1.0f / kInvDCQuant[0],
    1.0f / kInvDCQuant[1],
    1.0f / kInvDCQuant[2],
};

// Defined elsewhere; only pointers to these types are needed in this header.
class ModularFrameEncoder;
class ModularFrameDecoder;
309 | | |
// Identifies one quantization table. Several AC strategies share a table; the
// commented-out names below mark the additional strategies that map to the
// entry listed just above them (compare the repeated entries in
// kAcStrategyToQuantTableMap).
enum class QuantTable : size_t {
  DCT = 0,
  IDENTITY,
  DCT2X2,
  DCT4X4,
  DCT16X16,
  DCT32X32,
  // DCT16X8
  DCT8X16,
  // DCT32X8
  DCT8X32,
  // DCT32X16
  DCT16X32,
  DCT4X8,
  // DCT8X4
  AFV0,
  // AFV1
  // AFV2
  // AFV3
  DCT64X64,
  // DCT64X32,
  DCT32X64,
  DCT128X128,
  // DCT128X64,
  DCT64X128,
  DCT256X256,
  // DCT256X128,
  DCT128X256
};

// The count below is stored in a uint8_t; make sure the cast cannot truncate.
static_assert(static_cast<size_t>(QuantTable::DCT128X256) < 255,
              "kNumQuantTables does not fit in uint8_t");

// Number of distinct quant tables; DCT128X256 must remain the last enumerator.
static constexpr uint8_t kNumQuantTables =
    static_cast<uint8_t>(QuantTable::DCT128X256) + 1;
342 | | |
343 | | static const std::array<QuantTable, AcStrategy::kNumValidStrategies> |
344 | | kAcStrategyToQuantTableMap = { |
345 | | QuantTable::DCT, QuantTable::IDENTITY, QuantTable::DCT2X2, |
346 | | QuantTable::DCT4X4, QuantTable::DCT16X16, QuantTable::DCT32X32, |
347 | | QuantTable::DCT8X16, QuantTable::DCT8X16, QuantTable::DCT8X32, |
348 | | QuantTable::DCT8X32, QuantTable::DCT16X32, QuantTable::DCT16X32, |
349 | | QuantTable::DCT4X8, QuantTable::DCT4X8, QuantTable::AFV0, |
350 | | QuantTable::AFV0, QuantTable::AFV0, QuantTable::AFV0, |
351 | | QuantTable::DCT64X64, QuantTable::DCT32X64, QuantTable::DCT32X64, |
352 | | QuantTable::DCT128X128, QuantTable::DCT64X128, QuantTable::DCT64X128, |
353 | | QuantTable::DCT256X256, QuantTable::DCT128X256, QuantTable::DCT128X256, |
354 | | }; |
355 | | |
356 | | class DequantMatrices { |
357 | | public: |
358 | | DequantMatrices(); |
359 | | |
360 | | static const QuantEncoding* Library(); |
361 | | |
362 | | typedef std::array<QuantEncodingInternal, |
363 | | kNumPredefinedTables * kNumQuantTables> |
364 | | DequantLibraryInternal; |
365 | | // Return the array of library kNumPredefinedTables QuantEncoding entries as |
366 | | // a constexpr array. Use Library() to obtain a pointer to the copy in the |
367 | | // .cc file. |
368 | | static DequantLibraryInternal LibraryInit(); |
369 | | |
370 | | // Returns aligned memory. |
371 | 203k | JXL_INLINE const float* Matrix(AcStrategyType quant_kind, size_t c) const { |
372 | 203k | JXL_DASSERT((1 << static_cast<uint32_t>(quant_kind)) & computed_mask_); |
373 | 203k | return &table_[table_offsets_[static_cast<size_t>(quant_kind) * 3 + c]]; |
374 | 203k | } |
375 | | |
376 | 0 | JXL_INLINE const float* InvMatrix(AcStrategyType quant_kind, size_t c) const { |
377 | 0 | size_t quant_table_idx = static_cast<uint32_t>(quant_kind); |
378 | 0 | JXL_DASSERT((1 << quant_table_idx) & computed_mask_); |
379 | 0 | return &inv_table_[table_offsets_[quant_table_idx * 3 + c]]; |
380 | 0 | } |
381 | | |
382 | | // DC quants are used in modular mode for XYB multipliers. |
383 | 137k | JXL_INLINE float DCQuant(size_t c) const { return dc_quant_[c]; } |
384 | 46.5k | JXL_INLINE const float* DCQuants() const { return dc_quant_; } |
385 | | |
386 | 137k | JXL_INLINE float InvDCQuant(size_t c) const { return inv_dc_quant_[c]; } |
387 | | |
388 | | // For encoder. |
389 | 0 | void SetEncodings(const std::vector<QuantEncoding>& encodings) { |
390 | 0 | encodings_ = encodings; |
391 | 0 | computed_mask_ = 0; |
392 | 0 | } |
393 | | |
394 | | // For encoder. |
395 | 0 | void SetDCQuant(const float dc[3]) { |
396 | 0 | for (size_t c = 0; c < 3; c++) { |
397 | 0 | dc_quant_[c] = 1.0f / dc[c]; |
398 | 0 | inv_dc_quant_[c] = dc[c]; |
399 | 0 | } |
400 | 0 | } |
401 | | |
402 | | Status Decode(JxlMemoryManager* memory_manager, BitReader* br, |
403 | | ModularFrameDecoder* modular_frame_decoder = nullptr); |
404 | | Status DecodeDC(BitReader* br); |
405 | | |
406 | 0 | const std::vector<QuantEncoding>& encodings() const { return encodings_; } |
407 | | |
408 | | static constexpr auto required_size_x = |
409 | | to_array<int>({1, 1, 1, 1, 2, 4, 1, 1, 2, 1, 1, 8, 4, 16, 8, 32, 16}); |
410 | | static_assert(kNumQuantTables == required_size_x.size(), |
411 | | "Update this array when adding or removing quant tables."); |
412 | | |
413 | | static constexpr auto required_size_y = |
414 | | to_array<int>({1, 1, 1, 1, 2, 4, 2, 4, 4, 1, 1, 8, 8, 16, 16, 32, 32}); |
415 | | static_assert(kNumQuantTables == required_size_y.size(), |
416 | | "Update this array when adding or removing quant tables."); |
417 | | |
418 | | // MUST be equal `sum(dot(required_size_x, required_size_y))`. |
419 | | static constexpr size_t kSumRequiredXy = 2056; |
420 | | |
421 | | Status EnsureComputed(JxlMemoryManager* memory_manager, uint32_t acs_mask); |
422 | | |
423 | | private: |
424 | | static constexpr size_t kTotalTableSize = kSumRequiredXy * kDCTBlockSize * 3; |
425 | | |
426 | | uint32_t computed_mask_ = 0; |
427 | | // kTotalTableSize entries followed by kTotalTableSize for inv_table |
428 | | AlignedMemory table_storage_; |
429 | | const float* table_; |
430 | | const float* inv_table_; |
431 | | float dc_quant_[3] = {kDCQuant[0], kDCQuant[1], kDCQuant[2]}; |
432 | | float inv_dc_quant_[3] = {kInvDCQuant[0], kInvDCQuant[1], kInvDCQuant[2]}; |
433 | | size_t table_offsets_[AcStrategy::kNumValidStrategies * 3]; |
434 | | std::vector<QuantEncoding> encodings_; |
435 | | }; |
436 | | |
437 | | } // namespace jxl |
438 | | |
439 | | #endif // LIB_JXL_QUANT_WEIGHTS_H_ |