Coverage Report

Created: 2025-06-16 07:00

/src/libjxl/lib/jxl/quant_weights.h
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#ifndef LIB_JXL_QUANT_WEIGHTS_H_
7
#define LIB_JXL_QUANT_WEIGHTS_H_
8
9
#include <jxl/memory_manager.h>
10
11
#include <array>
12
#include <cstdint>
13
#include <cstring>
14
#include <vector>
15
16
#include "lib/jxl/ac_strategy.h"
17
#include "lib/jxl/base/common.h"
18
#include "lib/jxl/base/compiler_specific.h"
19
#include "lib/jxl/base/status.h"
20
#include "lib/jxl/dec_bit_reader.h"
21
#include "lib/jxl/frame_dimensions.h"
22
#include "lib/jxl/memory_manager_internal.h"
23
24
namespace jxl {
25
26
// Size (in coefficients) of the largest quantization table; equal to the
// largest coefficient area of any AC strategy.
static constexpr size_t kMaxQuantTableSize{AcStrategy::kMaxCoeffArea};
// There is a single predefined ("library") table set, so selecting one needs
// ceil(log2(1)) == 0 bits.
static constexpr size_t kNumPredefinedTables{1};
static constexpr size_t kCeilLog2NumPredefinedTables{0};
// log2 of the number of QuantEncodingInternal::Mode values (8 modes).
static constexpr size_t kLog2NumQuantModes{3};
30
31
// Distance-band parameters for the DCT-based quantization modes: one row of up
// to kMaxDistanceBands floats for each of the 3 channels, of which the first
// num_distance_bands entries are in use.
struct DctQuantWeightParams {
  static constexpr size_t kLog2MaxDistanceBands = 4;
  static constexpr size_t kMaxDistanceBands = 1 + (1 << kLog2MaxDistanceBands);
  using DistanceBandsArray =
      std::array<std::array<float, kMaxDistanceBands>, 3>;

  // Number of leading entries of each distance_bands[c] row that are in use.
  size_t num_distance_bands = 0;
  // Per-channel band values; zero-initialized by default.
  DistanceBandsArray distance_bands = {};

  constexpr DctQuantWeightParams() : num_distance_bands(0) {}

  // Takes a full, already padded array; the caller is responsible for
  // num_dist_bands not exceeding kMaxDistanceBands.
  constexpr DctQuantWeightParams(const DistanceBandsArray& dist_bands,
                                 size_t num_dist_bands)
      : num_distance_bands(num_dist_bands), distance_bands(dist_bands) {}

  // Copies the first num_dist_bands values of each of the 3 channel rows of a
  // C array; the remaining entries stay zero.
  template <size_t num_dist_bands>
  explicit DctQuantWeightParams(const float dist_bands[3][num_dist_bands]) {
    // Each destination row holds only kMaxDistanceBands floats; a wider
    // source would make the memcpy below write past the row.
    static_assert(num_dist_bands <= kMaxDistanceBands,
                  "Too many distance bands");
    num_distance_bands = num_dist_bands;
    for (size_t c = 0; c < 3; c++) {
      memcpy(distance_bands[c].data(), dist_bands[c],
             sizeof(float) * num_dist_bands);
    }
  }
};
55
56
// NOLINTNEXTLINE(clang-analyzer-optin.performance.Padding)
57
struct QuantEncodingInternal {
58
  enum Mode {
59
    kQuantModeLibrary,
60
    kQuantModeID,
61
    kQuantModeDCT2,
62
    kQuantModeDCT4,
63
    kQuantModeDCT4X8,
64
    kQuantModeAFV,
65
    kQuantModeDCT,
66
    kQuantModeRAW,
67
  };
68
69
  template <Mode mode>
70
  struct Tag {};
71
72
  using IdWeights = std::array<std::array<float, 3>, 3>;
73
  using DCT2Weights = std::array<std::array<float, 6>, 3>;
74
  using DCT4Multipliers = std::array<std::array<float, 2>, 3>;
75
  using AFVWeights = std::array<std::array<float, 9>, 3>;
76
  using DCT4x8Multipliers = std::array<float, 3>;
77
78
  template <size_t A>
79
56.8k
  static constexpr QuantEncodingInternal Library() {
80
56.8k
    static_assert(A < kNumPredefinedTables, "Library index out of bounds");
81
56.8k
    return QuantEncodingInternal(Tag<kQuantModeLibrary>(), A);
82
56.8k
  }
83
  constexpr QuantEncodingInternal(Tag<kQuantModeLibrary> /* tag */,
84
                                  uint8_t predefined)
85
56.8k
      : mode(kQuantModeLibrary), predefined(predefined) {}
86
87
  // Identity
88
  // xybweights is an array of {xweights, yweights, bweights}.
89
2
  static constexpr QuantEncodingInternal Identity(const IdWeights& xybweights) {
90
2
    return QuantEncodingInternal(Tag<kQuantModeID>(), xybweights);
91
2
  }
92
  constexpr QuantEncodingInternal(Tag<kQuantModeID> /* tag */,
93
                                  const IdWeights& xybweights)
94
2
      : mode(kQuantModeID), idweights(xybweights) {}
95
96
  // DCT2
97
2
  static constexpr QuantEncodingInternal DCT2(const DCT2Weights& xybweights) {
98
2
    return QuantEncodingInternal(Tag<kQuantModeDCT2>(), xybweights);
99
2
  }
100
  constexpr QuantEncodingInternal(Tag<kQuantModeDCT2> /* tag */,
101
                                  const DCT2Weights& xybweights)
102
2
      : mode(kQuantModeDCT2), dct2weights(xybweights) {}
103
104
  // DCT4
105
  static constexpr QuantEncodingInternal DCT4(
106
4
      const DctQuantWeightParams& params, const DCT4Multipliers& xybmul) {
107
4
    return QuantEncodingInternal(Tag<kQuantModeDCT4>(), params, xybmul);
108
4
  }
109
  constexpr QuantEncodingInternal(Tag<kQuantModeDCT4> /* tag */,
110
                                  const DctQuantWeightParams& params,
111
                                  const DCT4Multipliers& xybmul)
112
4
      : mode(kQuantModeDCT4), dct_params(params), dct4multipliers(xybmul) {}
113
114
  // DCT4x8
115
  static constexpr QuantEncodingInternal DCT4X8(
116
4
      const DctQuantWeightParams& params, const DCT4x8Multipliers& xybmul) {
117
4
    return QuantEncodingInternal(Tag<kQuantModeDCT4X8>(), params, xybmul);
118
4
  }
119
  constexpr QuantEncodingInternal(Tag<kQuantModeDCT4X8> /* tag */,
120
                                  const DctQuantWeightParams& params,
121
                                  const DCT4x8Multipliers& xybmul)
122
4
      : mode(kQuantModeDCT4X8), dct_params(params), dct4x8multipliers(xybmul) {}
123
124
  // DCT
125
  static constexpr QuantEncodingInternal DCT(
126
24
      const DctQuantWeightParams& params) {
127
24
    return QuantEncodingInternal(Tag<kQuantModeDCT>(), params);
128
24
  }
129
  constexpr QuantEncodingInternal(Tag<kQuantModeDCT> /* tag */,
130
                                  const DctQuantWeightParams& params)
131
24
      : mode(kQuantModeDCT), dct_params(params) {}
132
133
  // AFV
134
  static constexpr QuantEncodingInternal AFV(
135
      const DctQuantWeightParams& params4x8,
136
2
      const DctQuantWeightParams& params4x4, const AFVWeights& weights) {
137
2
    return QuantEncodingInternal(Tag<kQuantModeAFV>(), params4x8, params4x4,
138
2
                                 weights);
139
2
  }
140
  constexpr QuantEncodingInternal(Tag<kQuantModeAFV> /* tag */,
141
                                  const DctQuantWeightParams& params4x8,
142
                                  const DctQuantWeightParams& params4x4,
143
                                  const AFVWeights& weights)
144
2
      : mode(kQuantModeAFV),
145
2
        dct_params(params4x8),
146
2
        afv_weights(weights),
147
2
        dct_params_afv_4x4(params4x4) {}
148
149
  // This constructor is not constexpr so it can't be used in any of the
150
  // constexpr cases above.
151
0
  explicit QuantEncodingInternal(Mode mode) : mode(mode) {}
152
153
  Mode mode;
154
155
  // Weights for DCT4+ tables.
156
  DctQuantWeightParams dct_params;
157
158
  union {
159
    // Weights for identity.
160
    IdWeights idweights;
161
162
    // Weights for DCT2.
163
    DCT2Weights dct2weights;
164
165
    // Extra multipliers for coefficients 01/10 and 11 for DCT4 and AFV.
166
    DCT4Multipliers dct4multipliers;
167
168
    // Weights for AFV. {0, 1} are used directly for coefficients (0, 1) and (1,
169
    // 0);  {2, 3, 4} are used directly corner DC, (1,0) - (0,1) and (0, 1) +
170
    // (1, 0) - (0, 0) inside the AFV block. Values from 5 to 8 are interpolated
171
    // as in GetQuantWeights for DC and are used for other coefficients.
172
    AFVWeights afv_weights = {};
173
174
    // Extra multipliers for coefficients 01 or 10 for DCT4X8 and DCT8X4.
175
    DCT4x8Multipliers dct4x8multipliers;
176
177
    // Only used in kQuantModeRAW mode.
178
    struct {
179
      // explicit quantization table (like in JPEG)
180
      std::vector<int>* qtable = nullptr;
181
      float qtable_den = 1.f / (8 * 255);
182
    } qraw;
183
  };
184
185
  // Weights for 4x4 sub-block in AFV.
186
  DctQuantWeightParams dct_params_afv_4x4;
187
188
  // Which predefined table to use. Only used if mode is kQuantModeLibrary.
189
  uint8_t predefined = 0;
190
};
191
192
class QuantEncoding final : public QuantEncodingInternal {
193
 public:
194
  QuantEncoding(const QuantEncoding& other)
195
966k
      : QuantEncodingInternal(
196
966k
            static_cast<const QuantEncodingInternal&>(other)) {
197
966k
    if (mode == kQuantModeRAW && qraw.qtable) {
198
      // Need to make a copy of the passed *qtable.
199
0
      qraw.qtable = new std::vector<int>(*other.qraw.qtable);
200
0
    }
201
966k
  }
202
  QuantEncoding(QuantEncoding&& other) noexcept
203
0
      : QuantEncodingInternal(
204
0
            static_cast<const QuantEncodingInternal&>(other)) {
205
    // Steal the qtable from the other object if any.
206
0
    if (mode == kQuantModeRAW) {
207
0
      other.qraw.qtable = nullptr;
208
0
    }
209
0
  }
210
0
  QuantEncoding& operator=(const QuantEncoding& other) {
211
0
    if (mode == kQuantModeRAW && qraw.qtable) {
212
0
      delete qraw.qtable;
213
0
    }
214
0
    *static_cast<QuantEncodingInternal*>(this) =
215
0
        QuantEncodingInternal(static_cast<const QuantEncodingInternal&>(other));
216
0
    if (mode == kQuantModeRAW && qraw.qtable) {
217
      // Need to make a copy of the passed *qtable.
218
0
      qraw.qtable = new std::vector<int>(*other.qraw.qtable);
219
0
    }
220
0
    return *this;
221
0
  }
222
223
1.02M
  ~QuantEncoding() {
224
1.02M
    if (mode == kQuantModeRAW && qraw.qtable) {
225
46
      delete qraw.qtable;
226
46
    }
227
1.02M
  }
228
229
  // Wrappers of the QuantEncodingInternal:: static functions that return a
230
  // QuantEncoding instead. This is using the explicit and private cast from
231
  // QuantEncodingInternal to QuantEncoding, which would be inlined anyway.
232
  // In general, you should use this wrappers. The only reason to directly
233
  // create a QuantEncodingInternal instance is if you need a constexpr version
234
  // of this class. Note that RAW() is not supported in that case since it uses
235
  // a std::vector.
236
  template <size_t A>
237
56.8k
  static QuantEncoding Library() {
238
56.8k
    return QuantEncoding(QuantEncodingInternal::Library<A>());
239
56.8k
  }
240
0
  static QuantEncoding Identity(const IdWeights& xybweights) {
241
0
    return QuantEncoding(QuantEncodingInternal::Identity(xybweights));
242
0
  }
243
0
  static QuantEncoding DCT2(const DCT2Weights& xybweights) {
244
0
    return QuantEncoding(QuantEncodingInternal::DCT2(xybweights));
245
0
  }
246
  static QuantEncoding DCT4(const DctQuantWeightParams& params,
247
0
                            const DCT4Multipliers& xybmul) {
248
0
    return QuantEncoding(QuantEncodingInternal::DCT4(params, xybmul));
249
0
  }
250
  static QuantEncoding DCT4X8(const DctQuantWeightParams& params,
251
0
                              const DCT4x8Multipliers& xybmul) {
252
0
    return QuantEncoding(QuantEncodingInternal::DCT4X8(params, xybmul));
253
0
  }
254
0
  static QuantEncoding DCT(const DctQuantWeightParams& params) {
255
0
    return QuantEncoding(QuantEncodingInternal::DCT(params));
256
0
  }
257
  static QuantEncoding AFV(const DctQuantWeightParams& params4x8,
258
                           const DctQuantWeightParams& params4x4,
259
0
                           const AFVWeights& weights) {
260
0
    return QuantEncoding(
261
0
        QuantEncodingInternal::AFV(params4x8, params4x4, weights));
262
0
  }
263
264
  // RAW, note that this one is not a constexpr one.
265
0
  static QuantEncoding RAW(std::vector<int>&& qtable, int shift = 0) {
266
0
    QuantEncoding encoding(kQuantModeRAW);
267
0
    encoding.qraw.qtable = new std::vector<int>();
268
0
    *encoding.qraw.qtable = qtable;
269
0
    encoding.qraw.qtable_den = (1 << shift) * (1.f / (8 * 255));
270
0
    return encoding;
271
0
  }
272
273
 private:
274
  explicit QuantEncoding(const QuantEncodingInternal& other)
275
56.8k
      : QuantEncodingInternal(other) {}
276
277
  explicit QuantEncoding(QuantEncodingInternal::Mode mode_arg)
278
0
      : QuantEncodingInternal(mode_arg) {}
279
};
280
281
// A constexpr QuantEncodingInternal instance is often downcasted to the
282
// QuantEncoding subclass even if the instance wasn't an instance of the
283
// subclass. This is safe because user will upcast to QuantEncodingInternal to
284
// access any of its members.
285
static_assert(sizeof(QuantEncoding) == sizeof(QuantEncodingInternal),
286
              "Don't add any members to QuantEncoding");
287
288
// Default per-channel DC quantization factors (DequantMatrices starts from
// these). kDCQuant[c] is the reciprocal of kInvDCQuant[c].
// Let's try to keep these 2**N for possible future simplicity.
//
// Both arrays are constexpr. A `const float` element is not a constant
// expression, so with plain `const` kDCQuant would need a dynamic initializer
// in every translation unit that includes this header. constexpr guarantees
// constant initialization with exactly the same values.
constexpr float kInvDCQuant[3] = {
    4096.0f,
    512.0f,
    256.0f,
};

constexpr float kDCQuant[3] = {
    1.0f / kInvDCQuant[0],
    1.0f / kInvDCQuant[1],
    1.0f / kInvDCQuant[2],
};
300
301
class ModularFrameEncoder;
class ModularFrameDecoder;

// Index of a quantization table. Transforms that are transposes of each other
// (and the four AFV variants) share a single table; the sharing variant is
// noted next to the table it uses. Values are spelled out explicitly.
enum class QuantTable : size_t {
  DCT = 0,
  IDENTITY = 1,
  DCT2X2 = 2,
  DCT4X4 = 3,
  DCT16X16 = 4,
  DCT32X32 = 5,
  DCT8X16 = 6,      // also used for DCT16X8
  DCT8X32 = 7,      // also used for DCT32X8
  DCT16X32 = 8,     // also used for DCT32X16
  DCT4X8 = 9,       // also used for DCT8X4
  AFV0 = 10,        // also used for AFV1, AFV2, AFV3
  DCT64X64 = 11,
  DCT32X64 = 12,    // also used for DCT64X32
  DCT128X128 = 13,
  DCT64X128 = 14,   // also used for DCT128X64
  DCT256X256 = 15,
  DCT128X256 = 16   // also used for DCT256X128
};

// Number of QuantTable values: the last enumerator plus one.
static constexpr uint8_t kNumQuantTables = static_cast<uint8_t>(
    static_cast<size_t>(QuantTable::DCT128X256) + 1);
336
337
static const std::array<QuantTable, AcStrategy::kNumValidStrategies>
338
    kAcStrategyToQuantTableMap = {
339
        QuantTable::DCT,        QuantTable::IDENTITY,   QuantTable::DCT2X2,
340
        QuantTable::DCT4X4,     QuantTable::DCT16X16,   QuantTable::DCT32X32,
341
        QuantTable::DCT8X16,    QuantTable::DCT8X16,    QuantTable::DCT8X32,
342
        QuantTable::DCT8X32,    QuantTable::DCT16X32,   QuantTable::DCT16X32,
343
        QuantTable::DCT4X8,     QuantTable::DCT4X8,     QuantTable::AFV0,
344
        QuantTable::AFV0,       QuantTable::AFV0,       QuantTable::AFV0,
345
        QuantTable::DCT64X64,   QuantTable::DCT32X64,   QuantTable::DCT32X64,
346
        QuantTable::DCT128X128, QuantTable::DCT64X128,  QuantTable::DCT64X128,
347
        QuantTable::DCT256X256, QuantTable::DCT128X256, QuantTable::DCT128X256,
348
};
349
350
class DequantMatrices {
351
 public:
352
  DequantMatrices();
353
354
  static const QuantEncoding* Library();
355
356
  using DequantLibraryInternal =
357
      std::array<QuantEncodingInternal, kNumPredefinedTables * kNumQuantTables>;
358
  // Return the array of library kNumPredefinedTables QuantEncoding entries as
359
  // a constexpr array. Use Library() to obtain a pointer to the copy in the
360
  // .cc file.
361
  static DequantLibraryInternal LibraryInit();
362
363
  // Returns aligned memory.
364
16.2M
  JXL_INLINE const float* Matrix(AcStrategyType quant_kind, size_t c) const {
365
16.2M
    JXL_DASSERT((1 << static_cast<uint32_t>(quant_kind)) & computed_mask_);
366
16.2M
    return &table_[table_offsets_[static_cast<size_t>(quant_kind) * 3 + c]];
367
16.2M
  }
368
369
17.4M
  JXL_INLINE const float* InvMatrix(AcStrategyType quant_kind, size_t c) const {
370
17.4M
    size_t quant_table_idx = static_cast<uint32_t>(quant_kind);
371
17.4M
    JXL_DASSERT((1 << quant_table_idx) & computed_mask_);
372
17.4M
    return &inv_table_[table_offsets_[quant_table_idx * 3 + c]];
373
17.4M
  }
374
375
  // DC quants are used in modular mode for XYB multipliers.
376
48.1k
  JXL_INLINE float DCQuant(size_t c) const { return dc_quant_[c]; }
377
47.7k
  JXL_INLINE const float* DCQuants() const { return dc_quant_; }
378
379
532k
  JXL_INLINE float InvDCQuant(size_t c) const { return inv_dc_quant_[c]; }
380
381
  // For encoder.
382
0
  void SetEncodings(const std::vector<QuantEncoding>& encodings) {
383
0
    encodings_ = encodings;
384
0
    computed_mask_ = 0;
385
0
  }
386
387
  // For encoder.
388
283
  void SetDCQuant(const float dc[3]) {
389
1.13k
    for (size_t c = 0; c < 3; c++) {
390
849
      dc_quant_[c] = 1.0f / dc[c];
391
849
      inv_dc_quant_[c] = dc[c];
392
849
    }
393
283
  }
394
395
  Status Decode(JxlMemoryManager* memory_manager, BitReader* br,
396
                ModularFrameDecoder* modular_frame_decoder = nullptr);
397
  Status DecodeDC(BitReader* br);
398
399
186
  const std::vector<QuantEncoding>& encodings() const { return encodings_; }
400
401
  static constexpr auto required_size_x =
402
      to_array<int>({1, 1, 1, 1, 2, 4, 1, 1, 2, 1, 1, 8, 4, 16, 8, 32, 16});
403
  static_assert(kNumQuantTables == required_size_x.size(),
404
                "Update this array when adding or removing quant tables.");
405
406
  static constexpr auto required_size_y =
407
      to_array<int>({1, 1, 1, 1, 2, 4, 2, 4, 4, 1, 1, 8, 8, 16, 16, 32, 32});
408
  static_assert(kNumQuantTables == required_size_y.size(),
409
                "Update this array when adding or removing quant tables.");
410
411
  // MUST be equal `sum(dot(required_size_x, required_size_y))`.
412
  static constexpr size_t kSumRequiredXy = 2056;
413
414
  Status EnsureComputed(JxlMemoryManager* memory_manager, uint32_t acs_mask);
415
416
 private:
417
  static constexpr size_t kTotalTableSize = kSumRequiredXy * kDCTBlockSize * 3;
418
419
  uint32_t computed_mask_ = 0;
420
  // kTotalTableSize entries followed by kTotalTableSize for inv_table
421
  AlignedMemory table_storage_;
422
  const float* table_;
423
  const float* inv_table_;
424
  float dc_quant_[3] = {kDCQuant[0], kDCQuant[1], kDCQuant[2]};
425
  float inv_dc_quant_[3] = {kInvDCQuant[0], kInvDCQuant[1], kInvDCQuant[2]};
426
  size_t table_offsets_[AcStrategy::kNumValidStrategies * 3];
427
  std::vector<QuantEncoding> encodings_;
428
};
429
430
}  // namespace jxl
431
432
#endif  // LIB_JXL_QUANT_WEIGHTS_H_