Coverage Report

Created: 2026-02-14 07:42

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/libjxl/lib/jxl/quant_weights.cc
Line
Count
Source
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
#include "lib/jxl/quant_weights.h"
6
7
#include <jxl/memory_manager.h>
8
9
#include <cmath>
10
#include <cstdint>
11
#include <cstdio>
12
#include <cstdlib>
13
#include <vector>
14
15
#include "lib/jxl/ac_strategy.h"
16
#include "lib/jxl/base/compiler_specific.h"
17
#include "lib/jxl/base/status.h"
18
#include "lib/jxl/coeff_order_fwd.h"
19
#include "lib/jxl/dct_scales.h"
20
#include "lib/jxl/dec_bit_reader.h"
21
#include "lib/jxl/dec_modular.h"
22
#include "lib/jxl/fields.h"
23
#include "lib/jxl/frame_dimensions.h"
24
#include "lib/jxl/memory_manager_internal.h"
25
26
#undef HWY_TARGET_INCLUDE
27
#define HWY_TARGET_INCLUDE "lib/jxl/quant_weights.cc"
28
#include <hwy/foreach_target.h>
29
#include <hwy/highway.h>
30
31
#include "lib/jxl/base/fast_math-inl.h"
32
33
HWY_BEFORE_NAMESPACE();
34
namespace jxl {
35
namespace HWY_NAMESPACE {
36
37
// These templates are not found via ADL.
38
using hwy::HWY_NAMESPACE::Lt;
39
using hwy::HWY_NAMESPACE::MulAdd;
40
using hwy::HWY_NAMESPACE::Sqrt;
41
42
// kQuantWeights[N * N * c + N * y + x] is the relative weight of the (x, y)
43
// coefficient in component c. Higher weights correspond to finer quantization
44
// intervals and more bits spent in encoding.
45
46
static constexpr const float kAlmostZero = 1e-8f;
47
48
void GetQuantWeightsDCT2(const QuantEncoding::DCT2Weights& dct2weights,
49
1.72k
                         float* weights) {
50
6.89k
  for (size_t c = 0; c < 3; c++) {
51
5.16k
    size_t start = c * 64;
52
5.16k
    weights[start] = 0xBAD;
53
5.16k
    weights[start + 1] = weights[start + 8] = dct2weights[c][0];
54
5.16k
    weights[start + 9] = dct2weights[c][1];
55
15.5k
    for (size_t y = 0; y < 2; y++) {
56
31.0k
      for (size_t x = 0; x < 2; x++) {
57
20.6k
        weights[start + y * 8 + x + 2] = dct2weights[c][2];
58
20.6k
        weights[start + (y + 2) * 8 + x] = dct2weights[c][2];
59
20.6k
      }
60
10.3k
    }
61
15.5k
    for (size_t y = 0; y < 2; y++) {
62
31.0k
      for (size_t x = 0; x < 2; x++) {
63
20.6k
        weights[start + (y + 2) * 8 + x + 2] = dct2weights[c][3];
64
20.6k
      }
65
10.3k
    }
66
25.8k
    for (size_t y = 0; y < 4; y++) {
67
103k
      for (size_t x = 0; x < 4; x++) {
68
82.7k
        weights[start + y * 8 + x + 4] = dct2weights[c][4];
69
82.7k
        weights[start + (y + 4) * 8 + x] = dct2weights[c][4];
70
82.7k
      }
71
20.6k
    }
72
25.8k
    for (size_t y = 0; y < 4; y++) {
73
103k
      for (size_t x = 0; x < 4; x++) {
74
82.7k
        weights[start + (y + 4) * 8 + x + 4] = dct2weights[c][5];
75
82.7k
      }
76
20.6k
    }
77
5.16k
  }
78
1.72k
}
jxl::N_SSE4::GetQuantWeightsDCT2(std::__1::array<std::__1::array<float, 6ul>, 3ul> const&, float*)
Line
Count
Source
49
90
                         float* weights) {
50
360
  for (size_t c = 0; c < 3; c++) {
51
270
    size_t start = c * 64;
52
270
    weights[start] = 0xBAD;
53
270
    weights[start + 1] = weights[start + 8] = dct2weights[c][0];
54
270
    weights[start + 9] = dct2weights[c][1];
55
810
    for (size_t y = 0; y < 2; y++) {
56
1.62k
      for (size_t x = 0; x < 2; x++) {
57
1.08k
        weights[start + y * 8 + x + 2] = dct2weights[c][2];
58
1.08k
        weights[start + (y + 2) * 8 + x] = dct2weights[c][2];
59
1.08k
      }
60
540
    }
61
810
    for (size_t y = 0; y < 2; y++) {
62
1.62k
      for (size_t x = 0; x < 2; x++) {
63
1.08k
        weights[start + (y + 2) * 8 + x + 2] = dct2weights[c][3];
64
1.08k
      }
65
540
    }
66
1.35k
    for (size_t y = 0; y < 4; y++) {
67
5.40k
      for (size_t x = 0; x < 4; x++) {
68
4.32k
        weights[start + y * 8 + x + 4] = dct2weights[c][4];
69
4.32k
        weights[start + (y + 4) * 8 + x] = dct2weights[c][4];
70
4.32k
      }
71
1.08k
    }
72
1.35k
    for (size_t y = 0; y < 4; y++) {
73
5.40k
      for (size_t x = 0; x < 4; x++) {
74
4.32k
        weights[start + (y + 4) * 8 + x + 4] = dct2weights[c][5];
75
4.32k
      }
76
1.08k
    }
77
270
  }
78
90
}
jxl::N_AVX2::GetQuantWeightsDCT2(std::__1::array<std::__1::array<float, 6ul>, 3ul> const&, float*)
Line
Count
Source
49
1.56k
                         float* weights) {
50
6.26k
  for (size_t c = 0; c < 3; c++) {
51
4.70k
    size_t start = c * 64;
52
4.70k
    weights[start] = 0xBAD;
53
4.70k
    weights[start + 1] = weights[start + 8] = dct2weights[c][0];
54
4.70k
    weights[start + 9] = dct2weights[c][1];
55
14.1k
    for (size_t y = 0; y < 2; y++) {
56
28.2k
      for (size_t x = 0; x < 2; x++) {
57
18.8k
        weights[start + y * 8 + x + 2] = dct2weights[c][2];
58
18.8k
        weights[start + (y + 2) * 8 + x] = dct2weights[c][2];
59
18.8k
      }
60
9.40k
    }
61
14.1k
    for (size_t y = 0; y < 2; y++) {
62
28.2k
      for (size_t x = 0; x < 2; x++) {
63
18.8k
        weights[start + (y + 2) * 8 + x + 2] = dct2weights[c][3];
64
18.8k
      }
65
9.40k
    }
66
23.5k
    for (size_t y = 0; y < 4; y++) {
67
94.0k
      for (size_t x = 0; x < 4; x++) {
68
75.2k
        weights[start + y * 8 + x + 4] = dct2weights[c][4];
69
75.2k
        weights[start + (y + 4) * 8 + x] = dct2weights[c][4];
70
75.2k
      }
71
18.8k
    }
72
23.5k
    for (size_t y = 0; y < 4; y++) {
73
94.0k
      for (size_t x = 0; x < 4; x++) {
74
75.2k
        weights[start + (y + 4) * 8 + x + 4] = dct2weights[c][5];
75
75.2k
      }
76
18.8k
    }
77
4.70k
  }
78
1.56k
}
jxl::N_SSE2::GetQuantWeightsDCT2(std::__1::array<std::__1::array<float, 6ul>, 3ul> const&, float*)
Line
Count
Source
49
66
                         float* weights) {
50
264
  for (size_t c = 0; c < 3; c++) {
51
198
    size_t start = c * 64;
52
198
    weights[start] = 0xBAD;
53
198
    weights[start + 1] = weights[start + 8] = dct2weights[c][0];
54
198
    weights[start + 9] = dct2weights[c][1];
55
594
    for (size_t y = 0; y < 2; y++) {
56
1.18k
      for (size_t x = 0; x < 2; x++) {
57
792
        weights[start + y * 8 + x + 2] = dct2weights[c][2];
58
792
        weights[start + (y + 2) * 8 + x] = dct2weights[c][2];
59
792
      }
60
396
    }
61
594
    for (size_t y = 0; y < 2; y++) {
62
1.18k
      for (size_t x = 0; x < 2; x++) {
63
792
        weights[start + (y + 2) * 8 + x + 2] = dct2weights[c][3];
64
792
      }
65
396
    }
66
990
    for (size_t y = 0; y < 4; y++) {
67
3.96k
      for (size_t x = 0; x < 4; x++) {
68
3.16k
        weights[start + y * 8 + x + 4] = dct2weights[c][4];
69
3.16k
        weights[start + (y + 4) * 8 + x] = dct2weights[c][4];
70
3.16k
      }
71
792
    }
72
990
    for (size_t y = 0; y < 4; y++) {
73
3.96k
      for (size_t x = 0; x < 4; x++) {
74
3.16k
        weights[start + (y + 4) * 8 + x + 4] = dct2weights[c][5];
75
3.16k
      }
76
792
    }
77
198
  }
78
66
}
79
80
void GetQuantWeightsIdentity(const QuantEncoding::IdWeights& idweights,
81
1.77k
                             float* weights) {
82
7.08k
  for (size_t c = 0; c < 3; c++) {
83
345k
    for (int i = 0; i < 64; i++) {
84
340k
      weights[64 * c + i] = idweights[c][0];
85
340k
    }
86
5.31k
    weights[64 * c + 1] = idweights[c][1];
87
5.31k
    weights[64 * c + 8] = idweights[c][1];
88
5.31k
    weights[64 * c + 9] = idweights[c][2];
89
5.31k
  }
90
1.77k
}
jxl::N_SSE4::GetQuantWeightsIdentity(std::__1::array<std::__1::array<float, 3ul>, 3ul> const&, float*)
Line
Count
Source
81
68
                             float* weights) {
82
272
  for (size_t c = 0; c < 3; c++) {
83
13.2k
    for (int i = 0; i < 64; i++) {
84
13.0k
      weights[64 * c + i] = idweights[c][0];
85
13.0k
    }
86
204
    weights[64 * c + 1] = idweights[c][1];
87
204
    weights[64 * c + 8] = idweights[c][1];
88
204
    weights[64 * c + 9] = idweights[c][2];
89
204
  }
90
68
}
jxl::N_AVX2::GetQuantWeightsIdentity(std::__1::array<std::__1::array<float, 3ul>, 3ul> const&, float*)
Line
Count
Source
81
1.65k
                             float* weights) {
82
6.61k
  for (size_t c = 0; c < 3; c++) {
83
322k
    for (int i = 0; i < 64; i++) {
84
317k
      weights[64 * c + i] = idweights[c][0];
85
317k
    }
86
4.96k
    weights[64 * c + 1] = idweights[c][1];
87
4.96k
    weights[64 * c + 8] = idweights[c][1];
88
4.96k
    weights[64 * c + 9] = idweights[c][2];
89
4.96k
  }
90
1.65k
}
jxl::N_SSE2::GetQuantWeightsIdentity(std::__1::array<std::__1::array<float, 3ul>, 3ul> const&, float*)
Line
Count
Source
81
50
                             float* weights) {
82
200
  for (size_t c = 0; c < 3; c++) {
83
9.75k
    for (int i = 0; i < 64; i++) {
84
9.60k
      weights[64 * c + i] = idweights[c][0];
85
9.60k
    }
86
150
    weights[64 * c + 1] = idweights[c][1];
87
150
    weights[64 * c + 8] = idweights[c][1];
88
150
    weights[64 * c + 9] = idweights[c][2];
89
150
  }
90
50
}
91
92
StatusOr<float> Interpolate(float pos, float max, const float* array,
93
82.1k
                            size_t len) {
94
82.1k
  float scaled_pos = pos * (len - 1) / max;
95
82.1k
  size_t idx = scaled_pos;
96
82.1k
  JXL_ENSURE(idx + 1 < len);
97
82.1k
  float a = array[idx];
98
82.1k
  float b = array[idx + 1];
99
82.1k
  return a * FastPowf(b / a, scaled_pos - idx);
100
82.1k
}
jxl::N_SSE4::Interpolate(float, float, float const*, unsigned long)
Line
Count
Source
93
1.71k
                            size_t len) {
94
1.71k
  float scaled_pos = pos * (len - 1) / max;
95
1.71k
  size_t idx = scaled_pos;
96
1.71k
  JXL_ENSURE(idx + 1 < len);
97
1.71k
  float a = array[idx];
98
1.71k
  float b = array[idx + 1];
99
1.71k
  return a * FastPowf(b / a, scaled_pos - idx);
100
1.71k
}
jxl::N_AVX2::Interpolate(float, float, float const*, unsigned long)
Line
Count
Source
93
78.7k
                            size_t len) {
94
78.7k
  float scaled_pos = pos * (len - 1) / max;
95
78.7k
  size_t idx = scaled_pos;
96
78.7k
  JXL_ENSURE(idx + 1 < len);
97
78.7k
  float a = array[idx];
98
78.7k
  float b = array[idx + 1];
99
78.7k
  return a * FastPowf(b / a, scaled_pos - idx);
100
78.7k
}
jxl::N_SSE2::Interpolate(float, float, float const*, unsigned long)
Line
Count
Source
93
1.68k
                            size_t len) {
94
1.68k
  float scaled_pos = pos * (len - 1) / max;
95
1.68k
  size_t idx = scaled_pos;
96
1.68k
  JXL_ENSURE(idx + 1 < len);
97
1.68k
  float a = array[idx];
98
1.68k
  float b = array[idx + 1];
99
1.68k
  return a * FastPowf(b / a, scaled_pos - idx);
100
1.68k
}
101
102
518k
float Mult(float v) {
103
518k
  if (v > 0.0f) return 1.0f + v;
104
517k
  return 1.0f / (1.0f - v);
105
518k
}
jxl::N_SSE4::Mult(float)
Line
Count
Source
102
61.4k
float Mult(float v) {
103
61.4k
  if (v > 0.0f) return 1.0f + v;
104
61.3k
  return 1.0f / (1.0f - v);
105
61.4k
}
jxl::N_AVX2::Mult(float)
Line
Count
Source
102
403k
float Mult(float v) {
103
403k
  if (v > 0.0f) return 1.0f + v;
104
402k
  return 1.0f / (1.0f - v);
105
403k
}
jxl::N_SSE2::Mult(float)
Line
Count
Source
102
53.9k
float Mult(float v) {
103
53.9k
  if (v > 0.0f) return 1.0f + v;
104
53.7k
  return 1.0f / (1.0f - v);
105
53.9k
}
106
107
using DF4 = HWY_CAPPED(float, 4);
108
109
hwy::HWY_NAMESPACE::Vec<DF4> InterpolateVec(
110
22.5M
    hwy::HWY_NAMESPACE::Vec<DF4> scaled_pos, const float* array) {
111
22.5M
  HWY_CAPPED(int32_t, 4) di;
112
113
22.5M
  auto idx = ConvertTo(di, scaled_pos);
114
115
22.5M
  auto frac = Sub(scaled_pos, ConvertTo(DF4(), idx));
116
117
  // TODO(veluca): in theory, this could be done with 8 TableLookupBytes, but
118
  // it's probably slower.
119
22.5M
  auto a = GatherIndex(DF4(), array, idx);
120
22.5M
  auto b = GatherIndex(DF4(), array + 1, idx);
121
122
22.5M
  return Mul(a, FastPowf(DF4(), Div(b, a), frac));
123
22.5M
}
jxl::N_SSE4::InterpolateVec(hwy::N_SSE4::Vec128<float, 4ul>, float const*)
Line
Count
Source
110
3.18M
    hwy::HWY_NAMESPACE::Vec<DF4> scaled_pos, const float* array) {
111
3.18M
  HWY_CAPPED(int32_t, 4) di;
112
113
3.18M
  auto idx = ConvertTo(di, scaled_pos);
114
115
3.18M
  auto frac = Sub(scaled_pos, ConvertTo(DF4(), idx));
116
117
  // TODO(veluca): in theory, this could be done with 8 TableLookupBytes, but
118
  // it's probably slower.
119
3.18M
  auto a = GatherIndex(DF4(), array, idx);
120
3.18M
  auto b = GatherIndex(DF4(), array + 1, idx);
121
122
3.18M
  return Mul(a, FastPowf(DF4(), Div(b, a), frac));
123
3.18M
}
jxl::N_AVX2::InterpolateVec(hwy::N_AVX2::Vec128<float, 4ul>, float const*)
Line
Count
Source
110
12.3M
    hwy::HWY_NAMESPACE::Vec<DF4> scaled_pos, const float* array) {
111
12.3M
  HWY_CAPPED(int32_t, 4) di;
112
113
12.3M
  auto idx = ConvertTo(di, scaled_pos);
114
115
12.3M
  auto frac = Sub(scaled_pos, ConvertTo(DF4(), idx));
116
117
  // TODO(veluca): in theory, this could be done with 8 TableLookupBytes, but
118
  // it's probably slower.
119
12.3M
  auto a = GatherIndex(DF4(), array, idx);
120
12.3M
  auto b = GatherIndex(DF4(), array + 1, idx);
121
122
12.3M
  return Mul(a, FastPowf(DF4(), Div(b, a), frac));
123
12.3M
}
jxl::N_SSE2::InterpolateVec(hwy::N_SSE2::Vec128<float, 4ul>, float const*)
Line
Count
Source
110
6.97M
    hwy::HWY_NAMESPACE::Vec<DF4> scaled_pos, const float* array) {
111
6.97M
  HWY_CAPPED(int32_t, 4) di;
112
113
6.97M
  auto idx = ConvertTo(di, scaled_pos);
114
115
6.97M
  auto frac = Sub(scaled_pos, ConvertTo(DF4(), idx));
116
117
  // TODO(veluca): in theory, this could be done with 8 TableLookupBytes, but
118
  // it's probably slower.
119
6.97M
  auto a = GatherIndex(DF4(), array, idx);
120
6.97M
  auto b = GatherIndex(DF4(), array + 1, idx);
121
122
6.97M
  return Mul(a, FastPowf(DF4(), Div(b, a), frac));
123
6.97M
}
124
125
// Computes quant weights for a COLS*ROWS-sized transform, using num_bands
126
// eccentricity bands and num_ebands eccentricity bands. If print_mode is 1,
127
// prints the resulting matrix; if print_mode is 2, prints the matrix in a
128
// format suitable for a 3d plot with gnuplot.
129
Status GetQuantWeights(
130
    size_t ROWS, size_t COLS,
131
    const DctQuantWeightParams::DistanceBandsArray& distance_bands,
132
32.7k
    size_t num_bands, float* out) {
133
131k
  for (size_t c = 0; c < 3; c++) {
134
98.3k
    float bands[DctQuantWeightParams::kMaxDistanceBands] = {
135
98.3k
        distance_bands[c][0]};
136
98.3k
    if (bands[0] < kAlmostZero) return JXL_FAILURE("Invalid distance bands");
137
596k
    for (size_t i = 1; i < num_bands; i++) {
138
497k
      bands[i] = bands[i - 1] * Mult(distance_bands[c][i]);
139
497k
      if (bands[i] < kAlmostZero) return JXL_FAILURE("Invalid distance bands");
140
497k
    }
141
98.2k
    float scale = (num_bands - 1) / (kSqrt2 + 1e-6f);
142
98.2k
    float rcpcol = scale / (COLS - 1);
143
98.2k
    float rcprow = scale / (ROWS - 1);
144
98.2k
    JXL_ENSURE(COLS >= Lanes(DF4()));
145
98.2k
    HWY_ALIGN float l0123[4] = {0, 1, 2, 3};
146
1.55M
    for (uint32_t y = 0; y < ROWS; y++) {
147
1.45M
      float dy = y * rcprow;
148
1.45M
      float dy2 = dy * dy;
149
24.0M
      for (uint32_t x = 0; x < COLS; x += Lanes(DF4())) {
150
22.6M
        auto dx =
151
22.6M
            Mul(Add(Set(DF4(), x), Load(DF4(), l0123)), Set(DF4(), rcpcol));
152
22.6M
        auto scaled_distance = Sqrt(MulAdd(dx, dx, Set(DF4(), dy2)));
153
22.6M
        auto weight = num_bands == 1 ? Set(DF4(), bands[0])
154
22.6M
                                     : InterpolateVec(scaled_distance, bands);
155
22.6M
        StoreU(weight, DF4(), out + c * COLS * ROWS + y * COLS + x);
156
22.6M
      }
157
1.45M
    }
158
98.2k
  }
159
32.7k
  return true;
160
32.7k
}
jxl::N_SSE4::GetQuantWeights(unsigned long, unsigned long, std::__1::array<std::__1::array<float, 17ul>, 3ul> const&, unsigned long, float*)
Line
Count
Source
132
3.94k
    size_t num_bands, float* out) {
133
15.7k
  for (size_t c = 0; c < 3; c++) {
134
11.8k
    float bands[DctQuantWeightParams::kMaxDistanceBands] = {
135
11.8k
        distance_bands[c][0]};
136
11.8k
    if (bands[0] < kAlmostZero) return JXL_FAILURE("Invalid distance bands");
137
72.8k
    for (size_t i = 1; i < num_bands; i++) {
138
61.0k
      bands[i] = bands[i - 1] * Mult(distance_bands[c][i]);
139
61.0k
      if (bands[i] < kAlmostZero) return JXL_FAILURE("Invalid distance bands");
140
61.0k
    }
141
11.8k
    float scale = (num_bands - 1) / (kSqrt2 + 1e-6f);
142
11.8k
    float rcpcol = scale / (COLS - 1);
143
11.8k
    float rcprow = scale / (ROWS - 1);
144
11.8k
    JXL_ENSURE(COLS >= Lanes(DF4()));
145
11.8k
    HWY_ALIGN float l0123[4] = {0, 1, 2, 3};
146
194k
    for (uint32_t y = 0; y < ROWS; y++) {
147
183k
      float dy = y * rcprow;
148
183k
      float dy2 = dy * dy;
149
3.39M
      for (uint32_t x = 0; x < COLS; x += Lanes(DF4())) {
150
3.21M
        auto dx =
151
3.21M
            Mul(Add(Set(DF4(), x), Load(DF4(), l0123)), Set(DF4(), rcpcol));
152
3.21M
        auto scaled_distance = Sqrt(MulAdd(dx, dx, Set(DF4(), dy2)));
153
3.21M
        auto weight = num_bands == 1 ? Set(DF4(), bands[0])
154
3.21M
                                     : InterpolateVec(scaled_distance, bands);
155
3.21M
        StoreU(weight, DF4(), out + c * COLS * ROWS + y * COLS + x);
156
3.21M
      }
157
183k
    }
158
11.8k
  }
159
3.93k
  return true;
160
3.94k
}
jxl::N_AVX2::GetQuantWeights(unsigned long, unsigned long, std::__1::array<std::__1::array<float, 17ul>, 3ul> const&, unsigned long, float*)
Line
Count
Source
132
25.5k
    size_t num_bands, float* out) {
133
102k
  for (size_t c = 0; c < 3; c++) {
134
76.7k
    float bands[DctQuantWeightParams::kMaxDistanceBands] = {
135
76.7k
        distance_bands[c][0]};
136
76.7k
    if (bands[0] < kAlmostZero) return JXL_FAILURE("Invalid distance bands");
137
460k
    for (size_t i = 1; i < num_bands; i++) {
138
383k
      bands[i] = bands[i - 1] * Mult(distance_bands[c][i]);
139
383k
      if (bands[i] < kAlmostZero) return JXL_FAILURE("Invalid distance bands");
140
383k
    }
141
76.7k
    float scale = (num_bands - 1) / (kSqrt2 + 1e-6f);
142
76.7k
    float rcpcol = scale / (COLS - 1);
143
76.7k
    float rcprow = scale / (ROWS - 1);
144
76.7k
    JXL_ENSURE(COLS >= Lanes(DF4()));
145
76.7k
    HWY_ALIGN float l0123[4] = {0, 1, 2, 3};
146
1.06M
    for (uint32_t y = 0; y < ROWS; y++) {
147
990k
      float dy = y * rcprow;
148
990k
      float dy2 = dy * dy;
149
13.3M
      for (uint32_t x = 0; x < COLS; x += Lanes(DF4())) {
150
12.3M
        auto dx =
151
12.3M
            Mul(Add(Set(DF4(), x), Load(DF4(), l0123)), Set(DF4(), rcpcol));
152
12.3M
        auto scaled_distance = Sqrt(MulAdd(dx, dx, Set(DF4(), dy2)));
153
12.3M
        auto weight = num_bands == 1 ? Set(DF4(), bands[0])
154
12.3M
                                     : InterpolateVec(scaled_distance, bands);
155
12.3M
        StoreU(weight, DF4(), out + c * COLS * ROWS + y * COLS + x);
156
12.3M
      }
157
990k
    }
158
76.7k
  }
159
25.5k
  return true;
160
25.5k
}
jxl::N_SSE2::GetQuantWeights(unsigned long, unsigned long, std::__1::array<std::__1::array<float, 17ul>, 3ul> const&, unsigned long, float*)
Line
Count
Source
132
3.25k
    size_t num_bands, float* out) {
133
12.9k
  for (size_t c = 0; c < 3; c++) {
134
9.74k
    float bands[DctQuantWeightParams::kMaxDistanceBands] = {
135
9.74k
        distance_bands[c][0]};
136
9.74k
    if (bands[0] < kAlmostZero) return JXL_FAILURE("Invalid distance bands");
137
63.2k
    for (size_t i = 1; i < num_bands; i++) {
138
53.5k
      bands[i] = bands[i - 1] * Mult(distance_bands[c][i]);
139
53.5k
      if (bands[i] < kAlmostZero) return JXL_FAILURE("Invalid distance bands");
140
53.5k
    }
141
9.74k
    float scale = (num_bands - 1) / (kSqrt2 + 1e-6f);
142
9.74k
    float rcpcol = scale / (COLS - 1);
143
9.74k
    float rcprow = scale / (ROWS - 1);
144
9.74k
    JXL_ENSURE(COLS >= Lanes(DF4()));
145
9.74k
    HWY_ALIGN float l0123[4] = {0, 1, 2, 3};
146
288k
    for (uint32_t y = 0; y < ROWS; y++) {
147
279k
      float dy = y * rcprow;
148
279k
      float dy2 = dy * dy;
149
7.28M
      for (uint32_t x = 0; x < COLS; x += Lanes(DF4())) {
150
7.00M
        auto dx =
151
7.00M
            Mul(Add(Set(DF4(), x), Load(DF4(), l0123)), Set(DF4(), rcpcol));
152
7.00M
        auto scaled_distance = Sqrt(MulAdd(dx, dx, Set(DF4(), dy2)));
153
7.00M
        auto weight = num_bands == 1 ? Set(DF4(), bands[0])
154
7.00M
                                     : InterpolateVec(scaled_distance, bands);
155
7.00M
        StoreU(weight, DF4(), out + c * COLS * ROWS + y * COLS + x);
156
7.00M
      }
157
279k
    }
158
9.74k
  }
159
3.24k
  return true;
160
3.25k
}
161
162
// TODO(veluca): SIMD-fy. With 256x256, this is actually slow.
163
Status ComputeQuantTable(const QuantEncoding& encoding,
164
                         float* JXL_RESTRICT table,
165
                         float* JXL_RESTRICT inv_table, size_t table_num,
166
34.3k
                         QuantTable kind, size_t* pos) {
167
34.3k
  constexpr size_t N = kBlockDim;
168
34.3k
  size_t quant_table_idx = static_cast<size_t>(kind);
169
34.3k
  size_t wrows = 8 * DequantMatrices::required_size_x[quant_table_idx];
170
34.3k
  size_t wcols = 8 * DequantMatrices::required_size_y[quant_table_idx];
171
34.3k
  size_t num = wrows * wcols;
172
173
34.3k
  std::vector<float> weights(3 * num);
174
175
34.3k
  switch (encoding.mode) {
176
0
    case QuantEncoding::kQuantModeLibrary: {
177
      // Library and copy quant encoding should get replaced by the actual
178
      // parameters by the caller.
179
0
      JXL_ENSURE(false);
180
0
      break;
181
0
    }
182
1.77k
    case QuantEncoding::kQuantModeID: {
183
1.77k
      JXL_ENSURE(num == kDCTBlockSize);
184
1.77k
      GetQuantWeightsIdentity(encoding.idweights, weights.data());
185
1.77k
      break;
186
1.77k
    }
187
1.72k
    case QuantEncoding::kQuantModeDCT2: {
188
1.72k
      JXL_ENSURE(num == kDCTBlockSize);
189
1.72k
      GetQuantWeightsDCT2(encoding.dct2weights, weights.data());
190
1.72k
      break;
191
1.72k
    }
192
1.25k
    case QuantEncoding::kQuantModeDCT4: {
193
1.25k
      JXL_ENSURE(num == kDCTBlockSize);
194
1.25k
      float weights4x4[3 * 4 * 4];
195
      // Always use 4x4 GetQuantWeights for DCT4 quantization tables.
196
1.25k
      JXL_RETURN_IF_ERROR(
197
1.25k
          GetQuantWeights(4, 4, encoding.dct_params.distance_bands,
198
1.25k
                          encoding.dct_params.num_distance_bands, weights4x4));
199
5.00k
      for (size_t c = 0; c < 3; c++) {
200
33.7k
        for (size_t y = 0; y < kBlockDim; y++) {
201
270k
          for (size_t x = 0; x < kBlockDim; x++) {
202
240k
            weights[c * num + y * kBlockDim + x] =
203
240k
                weights4x4[c * 16 + (y / 2) * 4 + (x / 2)];
204
240k
          }
205
30.0k
        }
206
3.75k
        weights[c * num + 1] /= encoding.dct4multipliers[c][0];
207
3.75k
        weights[c * num + N] /= encoding.dct4multipliers[c][0];
208
3.75k
        weights[c * num + N + 1] /= encoding.dct4multipliers[c][1];
209
3.75k
      }
210
1.25k
      break;
211
1.25k
    }
212
1.32k
    case QuantEncoding::kQuantModeDCT4X8: {
213
1.32k
      JXL_ENSURE(num == kDCTBlockSize);
214
1.32k
      float weights4x8[3 * 4 * 8];
215
      // Always use 4x8 GetQuantWeights for DCT4X8 quantization tables.
216
1.32k
      JXL_RETURN_IF_ERROR(
217
1.32k
          GetQuantWeights(4, 8, encoding.dct_params.distance_bands,
218
1.32k
                          encoding.dct_params.num_distance_bands, weights4x8));
219
5.28k
      for (size_t c = 0; c < 3; c++) {
220
35.6k
        for (size_t y = 0; y < kBlockDim; y++) {
221
285k
          for (size_t x = 0; x < kBlockDim; x++) {
222
253k
            weights[c * num + y * kBlockDim + x] =
223
253k
                weights4x8[c * 32 + (y / 2) * 8 + x];
224
253k
          }
225
31.7k
        }
226
3.96k
        weights[c * num + N] /= encoding.dct4x8multipliers[c];
227
3.96k
      }
228
1.32k
      break;
229
1.32k
    }
230
25.6k
    case QuantEncoding::kQuantModeDCT: {
231
25.6k
      JXL_RETURN_IF_ERROR(GetQuantWeights(
232
25.6k
          wrows, wcols, encoding.dct_params.distance_bands,
233
25.6k
          encoding.dct_params.num_distance_bands, weights.data()));
234
25.5k
      break;
235
25.6k
    }
236
25.5k
    case QuantEncoding::kQuantModeRAW: {
237
374
      if (!encoding.qraw.qtable || encoding.qraw.qtable->size() != 3 * num) {
238
0
        return JXL_FAILURE("Invalid table encoding");
239
0
      }
240
374
      int* qtable = encoding.qraw.qtable->data();
241
73.3k
      for (size_t i = 0; i < 3 * num; i++) {
242
72.9k
        weights[i] = 1.f / (encoding.qraw.qtable_den * qtable[i]);
243
72.9k
      }
244
374
      break;
245
374
    }
246
2.29k
    case QuantEncoding::kQuantModeAFV: {
247
2.29k
      constexpr float kFreqs[] = {
248
2.29k
          0xBAD,
249
2.29k
          0xBAD,
250
2.29k
          0.8517778890324296,
251
2.29k
          5.37778436506804,
252
2.29k
          0xBAD,
253
2.29k
          0xBAD,
254
2.29k
          4.734747904497923,
255
2.29k
          5.449245381693219,
256
2.29k
          1.6598270267479331,
257
2.29k
          4,
258
2.29k
          7.275749096817861,
259
2.29k
          10.423227632456525,
260
2.29k
          2.662932286148962,
261
2.29k
          7.630657783650829,
262
2.29k
          8.962388608184032,
263
2.29k
          12.97166202570235,
264
2.29k
      };
265
266
2.29k
      float weights4x8[3 * 4 * 8];
267
2.29k
      JXL_RETURN_IF_ERROR((
268
2.29k
          GetQuantWeights(4, 8, encoding.dct_params.distance_bands,
269
2.29k
                          encoding.dct_params.num_distance_bands, weights4x8)));
270
2.29k
      float weights4x4[3 * 4 * 4];
271
2.29k
      JXL_RETURN_IF_ERROR((GetQuantWeights(
272
2.29k
          4, 4, encoding.dct_params_afv_4x4.distance_bands,
273
2.29k
          encoding.dct_params_afv_4x4.num_distance_bands, weights4x4)));
274
275
2.29k
      constexpr float lo = 0.8517778890324296;
276
2.29k
      constexpr float hi = 12.97166202570235f - lo + 1e-6f;
277
9.13k
      for (size_t c = 0; c < 3; c++) {
278
6.86k
        float bands[4];
279
6.86k
        bands[0] = encoding.afv_weights[c][5];
280
6.86k
        if (bands[0] < kAlmostZero) return JXL_FAILURE("Invalid AFV bands");
281
27.3k
        for (size_t i = 1; i < 4; i++) {
282
20.5k
          bands[i] = bands[i - 1] * Mult(encoding.afv_weights[c][i + 5]);
283
20.5k
          if (bands[i] < kAlmostZero) return JXL_FAILURE("Invalid AFV bands");
284
20.5k
        }
285
6.84k
        size_t start = c * 64;
286
116k
        auto set_weight = [&start, &weights](size_t x, size_t y, float val) {
287
116k
          weights[start + y * 8 + x] = val;
288
116k
        };
quant_weights.cc:jxl::N_SSE4::ComputeQuantTable(jxl::QuantEncoding const&, float*, float*, unsigned long, jxl::QuantTable, unsigned long*)::$_0::operator()(unsigned long, unsigned long, float) const
Line
Count
Source
286
2.43k
        auto set_weight = [&start, &weights](size_t x, size_t y, float val) {
287
2.43k
          weights[start + y * 8 + x] = val;
288
2.43k
        };
quant_weights.cc:jxl::N_AVX2::ComputeQuantTable(jxl::QuantEncoding const&, float*, float*, unsigned long, jxl::QuantTable, unsigned long*)::$_0::operator()(unsigned long, unsigned long, float) const
Line
Count
Source
286
111k
        auto set_weight = [&start, &weights](size_t x, size_t y, float val) {
287
111k
          weights[start + y * 8 + x] = val;
288
111k
        };
quant_weights.cc:jxl::N_SSE2::ComputeQuantTable(jxl::QuantEncoding const&, float*, float*, unsigned long, jxl::QuantTable, unsigned long*)::$_0::operator()(unsigned long, unsigned long, float) const
Line
Count
Source
286
2.38k
        auto set_weight = [&start, &weights](size_t x, size_t y, float val) {
287
2.38k
          weights[start + y * 8 + x] = val;
288
2.38k
        };
289
6.84k
        weights[start] = 1;  // Not used, but causes MSAN error otherwise.
290
        // Weights for (0, 1) and (1, 0).
291
6.84k
        set_weight(0, 1, encoding.afv_weights[c][0]);
292
6.84k
        set_weight(1, 0, encoding.afv_weights[c][1]);
293
        // AFV special weights for 3-pixel corner.
294
6.84k
        set_weight(0, 2, encoding.afv_weights[c][2]);
295
6.84k
        set_weight(2, 0, encoding.afv_weights[c][3]);
296
6.84k
        set_weight(2, 2, encoding.afv_weights[c][4]);
297
298
        // All other AFV weights.
299
34.2k
        for (size_t y = 0; y < 4; y++) {
300
136k
          for (size_t x = 0; x < 4; x++) {
301
109k
            if (x < 2 && y < 2) continue;
302
164k
            JXL_ASSIGN_OR_RETURN(
303
164k
                float val, Interpolate(kFreqs[y * 4 + x] - lo, hi, bands, 4));
304
164k
            set_weight(2 * x, 2 * y, val);
305
164k
          }
306
27.3k
        }
307
308
        // Put 4x8 weights in odd rows, except (1, 0).
309
34.2k
        for (size_t y = 0; y < kBlockDim / 2; y++) {
310
246k
          for (size_t x = 0; x < kBlockDim; x++) {
311
218k
            if (x == 0 && y == 0) continue;
312
212k
            weights[c * num + (2 * y + 1) * kBlockDim + x] =
313
212k
                weights4x8[c * 32 + y * 8 + x];
314
212k
          }
315
27.3k
        }
316
        // Put 4x4 weights in even rows / odd columns, except (0, 1).
317
34.2k
        for (size_t y = 0; y < kBlockDim / 2; y++) {
318
136k
          for (size_t x = 0; x < kBlockDim / 2; x++) {
319
109k
            if (x == 0 && y == 0) continue;
320
102k
            weights[c * num + (2 * y) * kBlockDim + 2 * x + 1] =
321
102k
                weights4x4[c * 16 + y * 4 + x];
322
102k
          }
323
27.3k
        }
324
6.84k
      }
325
2.27k
      break;
326
2.29k
    }
327
34.3k
  }
328
34.3k
  size_t prev_pos = *pos;
329
34.3k
  HWY_CAPPED(float, 64) d;
330
16.6M
  for (size_t i = 0; i < num * 3; i += Lanes(d)) {
331
16.5M
    auto inv_val = LoadU(d, weights.data() + i);
332
16.5M
    if (JXL_UNLIKELY(!AllFalse(d, Ge(inv_val, Set(d, 1.0f / kAlmostZero))) ||
333
16.5M
                     !AllFalse(d, Lt(inv_val, Set(d, kAlmostZero))))) {
334
47
      return JXL_FAILURE("Invalid quantization table");
335
47
    }
336
16.5M
    auto val = Div(Set(d, 1.0f), inv_val);
337
16.5M
    StoreU(val, d, table + *pos + i);
338
16.5M
    StoreU(inv_val, d, inv_table + *pos + i);
339
16.5M
  }
340
34.2k
  (*pos) += 3 * num;
341
342
  // Ensure that the lowest frequencies have a 0 inverse table.
343
  // This does not affect en/decoding, but allows AC strategy selection to be
344
  // slightly simpler.
345
34.2k
  size_t xs = DequantMatrices::required_size_x[quant_table_idx];
346
34.2k
  size_t ys = DequantMatrices::required_size_y[quant_table_idx];
347
34.2k
  CoefficientLayout(&ys, &xs);
348
137k
  for (size_t c = 0; c < 3; c++) {
349
299k
    for (size_t y = 0; y < ys; y++) {
350
1.62M
      for (size_t x = 0; x < xs; x++) {
351
1.43M
        inv_table[prev_pos + c * ys * xs * kDCTBlockSize + y * kBlockDim * xs +
352
1.43M
                  x] = 0;
353
1.43M
      }
354
196k
    }
355
102k
  }
356
34.2k
  return true;
357
34.3k
}
jxl::N_SSE4::ComputeQuantTable(jxl::QuantEncoding const&, float*, float*, unsigned long, jxl::QuantTable, unsigned long*)
Line
Count
Source
166
4.13k
                         QuantTable kind, size_t* pos) {
167
4.13k
  constexpr size_t N = kBlockDim;
168
4.13k
  size_t quant_table_idx = static_cast<size_t>(kind);
169
4.13k
  size_t wrows = 8 * DequantMatrices::required_size_x[quant_table_idx];
170
4.13k
  size_t wcols = 8 * DequantMatrices::required_size_y[quant_table_idx];
171
4.13k
  size_t num = wrows * wcols;
172
173
4.13k
  std::vector<float> weights(3 * num);
174
175
4.13k
  switch (encoding.mode) {
176
0
    case QuantEncoding::kQuantModeLibrary: {
177
      // Library and copy quant encoding should get replaced by the actual
178
      // parameters by the caller.
179
0
      JXL_ENSURE(false);
180
0
      break;
181
0
    }
182
68
    case QuantEncoding::kQuantModeID: {
183
68
      JXL_ENSURE(num == kDCTBlockSize);
184
68
      GetQuantWeightsIdentity(encoding.idweights, weights.data());
185
68
      break;
186
68
    }
187
90
    case QuantEncoding::kQuantModeDCT2: {
188
90
      JXL_ENSURE(num == kDCTBlockSize);
189
90
      GetQuantWeightsDCT2(encoding.dct2weights, weights.data());
190
90
      break;
191
90
    }
192
35
    case QuantEncoding::kQuantModeDCT4: {
193
35
      JXL_ENSURE(num == kDCTBlockSize);
194
35
      float weights4x4[3 * 4 * 4];
195
      // Always use 4x4 GetQuantWeights for DCT4 quantization tables.
196
35
      JXL_RETURN_IF_ERROR(
197
35
          GetQuantWeights(4, 4, encoding.dct_params.distance_bands,
198
35
                          encoding.dct_params.num_distance_bands, weights4x4));
199
136
      for (size_t c = 0; c < 3; c++) {
200
918
        for (size_t y = 0; y < kBlockDim; y++) {
201
7.34k
          for (size_t x = 0; x < kBlockDim; x++) {
202
6.52k
            weights[c * num + y * kBlockDim + x] =
203
6.52k
                weights4x4[c * 16 + (y / 2) * 4 + (x / 2)];
204
6.52k
          }
205
816
        }
206
102
        weights[c * num + 1] /= encoding.dct4multipliers[c][0];
207
102
        weights[c * num + N] /= encoding.dct4multipliers[c][0];
208
102
        weights[c * num + N + 1] /= encoding.dct4multipliers[c][1];
209
102
      }
210
34
      break;
211
35
    }
212
58
    case QuantEncoding::kQuantModeDCT4X8: {
213
58
      JXL_ENSURE(num == kDCTBlockSize);
214
58
      float weights4x8[3 * 4 * 8];
215
      // Always use 4x8 GetQuantWeights for DCT4X8 quantization tables.
216
58
      JXL_RETURN_IF_ERROR(
217
58
          GetQuantWeights(4, 8, encoding.dct_params.distance_bands,
218
58
                          encoding.dct_params.num_distance_bands, weights4x8));
219
228
      for (size_t c = 0; c < 3; c++) {
220
1.53k
        for (size_t y = 0; y < kBlockDim; y++) {
221
12.3k
          for (size_t x = 0; x < kBlockDim; x++) {
222
10.9k
            weights[c * num + y * kBlockDim + x] =
223
10.9k
                weights4x8[c * 32 + (y / 2) * 8 + x];
224
10.9k
          }
225
1.36k
        }
226
171
        weights[c * num + N] /= encoding.dct4x8multipliers[c];
227
171
      }
228
57
      break;
229
58
    }
230
3.74k
    case QuantEncoding::kQuantModeDCT: {
231
3.74k
      JXL_RETURN_IF_ERROR(GetQuantWeights(
232
3.74k
          wrows, wcols, encoding.dct_params.distance_bands,
233
3.74k
          encoding.dct_params.num_distance_bands, weights.data()));
234
3.74k
      break;
235
3.74k
    }
236
3.74k
    case QuantEncoding::kQuantModeRAW: {
237
85
      if (!encoding.qraw.qtable || encoding.qraw.qtable->size() != 3 * num) {
238
0
        return JXL_FAILURE("Invalid table encoding");
239
0
      }
240
85
      int* qtable = encoding.qraw.qtable->data();
241
16.7k
      for (size_t i = 0; i < 3 * num; i++) {
242
16.7k
        weights[i] = 1.f / (encoding.qraw.qtable_den * qtable[i]);
243
16.7k
      }
244
85
      break;
245
85
    }
246
52
    case QuantEncoding::kQuantModeAFV: {
247
52
      constexpr float kFreqs[] = {
248
52
          0xBAD,
249
52
          0xBAD,
250
52
          0.8517778890324296,
251
52
          5.37778436506804,
252
52
          0xBAD,
253
52
          0xBAD,
254
52
          4.734747904497923,
255
52
          5.449245381693219,
256
52
          1.6598270267479331,
257
52
          4,
258
52
          7.275749096817861,
259
52
          10.423227632456525,
260
52
          2.662932286148962,
261
52
          7.630657783650829,
262
52
          8.962388608184032,
263
52
          12.97166202570235,
264
52
      };
265
266
52
      float weights4x8[3 * 4 * 8];
267
52
      JXL_RETURN_IF_ERROR((
268
52
          GetQuantWeights(4, 8, encoding.dct_params.distance_bands,
269
52
                          encoding.dct_params.num_distance_bands, weights4x8)));
270
51
      float weights4x4[3 * 4 * 4];
271
51
      JXL_RETURN_IF_ERROR((GetQuantWeights(
272
51
          4, 4, encoding.dct_params_afv_4x4.distance_bands,
273
51
          encoding.dct_params_afv_4x4.num_distance_bands, weights4x4)));
274
275
50
      constexpr float lo = 0.8517778890324296;
276
50
      constexpr float hi = 12.97166202570235f - lo + 1e-6f;
277
193
      for (size_t c = 0; c < 3; c++) {
278
147
        float bands[4];
279
147
        bands[0] = encoding.afv_weights[c][5];
280
147
        if (bands[0] < kAlmostZero) return JXL_FAILURE("Invalid AFV bands");
281
578
        for (size_t i = 1; i < 4; i++) {
282
435
          bands[i] = bands[i - 1] * Mult(encoding.afv_weights[c][i + 5]);
283
435
          if (bands[i] < kAlmostZero) return JXL_FAILURE("Invalid AFV bands");
284
435
        }
285
143
        size_t start = c * 64;
286
143
        auto set_weight = [&start, &weights](size_t x, size_t y, float val) {
287
143
          weights[start + y * 8 + x] = val;
288
143
        };
289
143
        weights[start] = 1;  // Not used, but causes MSAN error otherwise.
290
        // Weights for (0, 1) and (1, 0).
291
143
        set_weight(0, 1, encoding.afv_weights[c][0]);
292
143
        set_weight(1, 0, encoding.afv_weights[c][1]);
293
        // AFV special weights for 3-pixel corner.
294
143
        set_weight(0, 2, encoding.afv_weights[c][2]);
295
143
        set_weight(2, 0, encoding.afv_weights[c][3]);
296
143
        set_weight(2, 2, encoding.afv_weights[c][4]);
297
298
        // All other AFV weights.
299
715
        for (size_t y = 0; y < 4; y++) {
300
2.86k
          for (size_t x = 0; x < 4; x++) {
301
2.28k
            if (x < 2 && y < 2) continue;
302
3.43k
            JXL_ASSIGN_OR_RETURN(
303
3.43k
                float val, Interpolate(kFreqs[y * 4 + x] - lo, hi, bands, 4));
304
3.43k
            set_weight(2 * x, 2 * y, val);
305
3.43k
          }
306
572
        }
307
308
        // Put 4x8 weights in odd rows, except (1, 0).
309
715
        for (size_t y = 0; y < kBlockDim / 2; y++) {
310
5.14k
          for (size_t x = 0; x < kBlockDim; x++) {
311
4.57k
            if (x == 0 && y == 0) continue;
312
4.43k
            weights[c * num + (2 * y + 1) * kBlockDim + x] =
313
4.43k
                weights4x8[c * 32 + y * 8 + x];
314
4.43k
          }
315
572
        }
316
        // Put 4x4 weights in even rows / odd columns, except (0, 1).
317
715
        for (size_t y = 0; y < kBlockDim / 2; y++) {
318
2.86k
          for (size_t x = 0; x < kBlockDim / 2; x++) {
319
2.28k
            if (x == 0 && y == 0) continue;
320
2.14k
            weights[c * num + (2 * y) * kBlockDim + 2 * x + 1] =
321
2.14k
                weights4x4[c * 16 + y * 4 + x];
322
2.14k
          }
323
572
        }
324
143
      }
325
46
      break;
326
50
    }
327
4.13k
  }
328
4.12k
  size_t prev_pos = *pos;
329
4.12k
  HWY_CAPPED(float, 64) d;
330
3.23M
  for (size_t i = 0; i < num * 3; i += Lanes(d)) {
331
3.22M
    auto inv_val = LoadU(d, weights.data() + i);
332
3.22M
    if (JXL_UNLIKELY(!AllFalse(d, Ge(inv_val, Set(d, 1.0f / kAlmostZero))) ||
333
3.22M
                     !AllFalse(d, Lt(inv_val, Set(d, kAlmostZero))))) {
334
9
      return JXL_FAILURE("Invalid quantization table");
335
9
    }
336
3.22M
    auto val = Div(Set(d, 1.0f), inv_val);
337
3.22M
    StoreU(val, d, table + *pos + i);
338
3.22M
    StoreU(inv_val, d, inv_table + *pos + i);
339
3.22M
  }
340
4.11k
  (*pos) += 3 * num;
341
342
  // Ensure that the lowest frequencies have a 0 inverse table.
343
  // This does not affect en/decoding, but allows AC strategy selection to be
344
  // slightly simpler.
345
4.11k
  size_t xs = DequantMatrices::required_size_x[quant_table_idx];
346
4.11k
  size_t ys = DequantMatrices::required_size_y[quant_table_idx];
347
4.11k
  CoefficientLayout(&ys, &xs);
348
16.4k
  for (size_t c = 0; c < 3; c++) {
349
36.0k
    for (size_t y = 0; y < ys; y++) {
350
225k
      for (size_t x = 0; x < xs; x++) {
351
201k
        inv_table[prev_pos + c * ys * xs * kDCTBlockSize + y * kBlockDim * xs +
352
201k
                  x] = 0;
353
201k
      }
354
23.7k
    }
355
12.3k
  }
356
4.11k
  return true;
357
4.12k
}
jxl::N_AVX2::ComputeQuantTable(jxl::QuantEncoding const&, float*, float*, unsigned long, jxl::QuantTable, unsigned long*)
Line
Count
Source
166
26.8k
                         QuantTable kind, size_t* pos) {
167
26.8k
  constexpr size_t N = kBlockDim;
168
26.8k
  size_t quant_table_idx = static_cast<size_t>(kind);
169
26.8k
  size_t wrows = 8 * DequantMatrices::required_size_x[quant_table_idx];
170
26.8k
  size_t wcols = 8 * DequantMatrices::required_size_y[quant_table_idx];
171
26.8k
  size_t num = wrows * wcols;
172
173
26.8k
  std::vector<float> weights(3 * num);
174
175
26.8k
  switch (encoding.mode) {
176
0
    case QuantEncoding::kQuantModeLibrary: {
177
      // Library and copy quant encoding should get replaced by the actual
178
      // parameters by the caller.
179
0
      JXL_ENSURE(false);
180
0
      break;
181
0
    }
182
1.65k
    case QuantEncoding::kQuantModeID: {
183
1.65k
      JXL_ENSURE(num == kDCTBlockSize);
184
1.65k
      GetQuantWeightsIdentity(encoding.idweights, weights.data());
185
1.65k
      break;
186
1.65k
    }
187
1.56k
    case QuantEncoding::kQuantModeDCT2: {
188
1.56k
      JXL_ENSURE(num == kDCTBlockSize);
189
1.56k
      GetQuantWeightsDCT2(encoding.dct2weights, weights.data());
190
1.56k
      break;
191
1.56k
    }
192
1.19k
    case QuantEncoding::kQuantModeDCT4: {
193
1.19k
      JXL_ENSURE(num == kDCTBlockSize);
194
1.19k
      float weights4x4[3 * 4 * 4];
195
      // Always use 4x4 GetQuantWeights for DCT4 quantization tables.
196
1.19k
      JXL_RETURN_IF_ERROR(
197
1.19k
          GetQuantWeights(4, 4, encoding.dct_params.distance_bands,
198
1.19k
                          encoding.dct_params.num_distance_bands, weights4x4));
199
4.75k
      for (size_t c = 0; c < 3; c++) {
200
32.0k
        for (size_t y = 0; y < kBlockDim; y++) {
201
256k
          for (size_t x = 0; x < kBlockDim; x++) {
202
228k
            weights[c * num + y * kBlockDim + x] =
203
228k
                weights4x4[c * 16 + (y / 2) * 4 + (x / 2)];
204
228k
          }
205
28.5k
        }
206
3.56k
        weights[c * num + 1] /= encoding.dct4multipliers[c][0];
207
3.56k
        weights[c * num + N] /= encoding.dct4multipliers[c][0];
208
3.56k
        weights[c * num + N + 1] /= encoding.dct4multipliers[c][1];
209
3.56k
      }
210
1.18k
      break;
211
1.19k
    }
212
1.21k
    case QuantEncoding::kQuantModeDCT4X8: {
213
1.21k
      JXL_ENSURE(num == kDCTBlockSize);
214
1.21k
      float weights4x8[3 * 4 * 8];
215
      // Always use 4x8 GetQuantWeights for DCT4X8 quantization tables.
216
1.21k
      JXL_RETURN_IF_ERROR(
217
1.21k
          GetQuantWeights(4, 8, encoding.dct_params.distance_bands,
218
1.21k
                          encoding.dct_params.num_distance_bands, weights4x8));
219
4.83k
      for (size_t c = 0; c < 3; c++) {
220
32.6k
        for (size_t y = 0; y < kBlockDim; y++) {
221
261k
          for (size_t x = 0; x < kBlockDim; x++) {
222
232k
            weights[c * num + y * kBlockDim + x] =
223
232k
                weights4x8[c * 32 + (y / 2) * 8 + x];
224
232k
          }
225
29.0k
        }
226
3.62k
        weights[c * num + N] /= encoding.dct4x8multipliers[c];
227
3.62k
      }
228
1.20k
      break;
229
1.21k
    }
230
18.7k
    case QuantEncoding::kQuantModeDCT: {
231
18.7k
      JXL_RETURN_IF_ERROR(GetQuantWeights(
232
18.7k
          wrows, wcols, encoding.dct_params.distance_bands,
233
18.7k
          encoding.dct_params.num_distance_bands, weights.data()));
234
18.7k
      break;
235
18.7k
    }
236
18.7k
    case QuantEncoding::kQuantModeRAW: {
237
254
      if (!encoding.qraw.qtable || encoding.qraw.qtable->size() != 3 * num) {
238
0
        return JXL_FAILURE("Invalid table encoding");
239
0
      }
240
254
      int* qtable = encoding.qraw.qtable->data();
241
49.4k
      for (size_t i = 0; i < 3 * num; i++) {
242
49.1k
        weights[i] = 1.f / (encoding.qraw.qtable_den * qtable[i]);
243
49.1k
      }
244
254
      break;
245
254
    }
246
2.19k
    case QuantEncoding::kQuantModeAFV: {
247
2.19k
      constexpr float kFreqs[] = {
248
2.19k
          0xBAD,
249
2.19k
          0xBAD,
250
2.19k
          0.8517778890324296,
251
2.19k
          5.37778436506804,
252
2.19k
          0xBAD,
253
2.19k
          0xBAD,
254
2.19k
          4.734747904497923,
255
2.19k
          5.449245381693219,
256
2.19k
          1.6598270267479331,
257
2.19k
          4,
258
2.19k
          7.275749096817861,
259
2.19k
          10.423227632456525,
260
2.19k
          2.662932286148962,
261
2.19k
          7.630657783650829,
262
2.19k
          8.962388608184032,
263
2.19k
          12.97166202570235,
264
2.19k
      };
265
266
2.19k
      float weights4x8[3 * 4 * 8];
267
2.19k
      JXL_RETURN_IF_ERROR((
268
2.19k
          GetQuantWeights(4, 8, encoding.dct_params.distance_bands,
269
2.19k
                          encoding.dct_params.num_distance_bands, weights4x8)));
270
2.19k
      float weights4x4[3 * 4 * 4];
271
2.19k
      JXL_RETURN_IF_ERROR((GetQuantWeights(
272
2.19k
          4, 4, encoding.dct_params_afv_4x4.distance_bands,
273
2.19k
          encoding.dct_params_afv_4x4.num_distance_bands, weights4x4)));
274
275
2.19k
      constexpr float lo = 0.8517778890324296;
276
2.19k
      constexpr float hi = 12.97166202570235f - lo + 1e-6f;
277
8.75k
      for (size_t c = 0; c < 3; c++) {
278
6.57k
        float bands[4];
279
6.57k
        bands[0] = encoding.afv_weights[c][5];
280
6.57k
        if (bands[0] < kAlmostZero) return JXL_FAILURE("Invalid AFV bands");
281
26.2k
        for (size_t i = 1; i < 4; i++) {
282
19.6k
          bands[i] = bands[i - 1] * Mult(encoding.afv_weights[c][i + 5]);
283
19.6k
          if (bands[i] < kAlmostZero) return JXL_FAILURE("Invalid AFV bands");
284
19.6k
        }
285
6.56k
        size_t start = c * 64;
286
6.56k
        auto set_weight = [&start, &weights](size_t x, size_t y, float val) {
287
6.56k
          weights[start + y * 8 + x] = val;
288
6.56k
        };
289
6.56k
        weights[start] = 1;  // Not used, but causes MSAN error otherwise.
290
        // Weights for (0, 1) and (1, 0).
291
6.56k
        set_weight(0, 1, encoding.afv_weights[c][0]);
292
6.56k
        set_weight(1, 0, encoding.afv_weights[c][1]);
293
        // AFV special weights for 3-pixel corner.
294
6.56k
        set_weight(0, 2, encoding.afv_weights[c][2]);
295
6.56k
        set_weight(2, 0, encoding.afv_weights[c][3]);
296
6.56k
        set_weight(2, 2, encoding.afv_weights[c][4]);
297
298
        // All other AFV weights.
299
32.8k
        for (size_t y = 0; y < 4; y++) {
300
131k
          for (size_t x = 0; x < 4; x++) {
301
104k
            if (x < 2 && y < 2) continue;
302
157k
            JXL_ASSIGN_OR_RETURN(
303
157k
                float val, Interpolate(kFreqs[y * 4 + x] - lo, hi, bands, 4));
304
157k
            set_weight(2 * x, 2 * y, val);
305
157k
          }
306
26.2k
        }
307
308
        // Put 4x8 weights in odd rows, except (1, 0).
309
32.8k
        for (size_t y = 0; y < kBlockDim / 2; y++) {
310
236k
          for (size_t x = 0; x < kBlockDim; x++) {
311
209k
            if (x == 0 && y == 0) continue;
312
203k
            weights[c * num + (2 * y + 1) * kBlockDim + x] =
313
203k
                weights4x8[c * 32 + y * 8 + x];
314
203k
          }
315
26.2k
        }
316
        // Put 4x4 weights in even rows / odd columns, except (0, 1).
317
32.8k
        for (size_t y = 0; y < kBlockDim / 2; y++) {
318
131k
          for (size_t x = 0; x < kBlockDim / 2; x++) {
319
104k
            if (x == 0 && y == 0) continue;
320
98.4k
            weights[c * num + (2 * y) * kBlockDim + 2 * x + 1] =
321
98.4k
                weights4x4[c * 16 + y * 4 + x];
322
98.4k
          }
323
26.2k
        }
324
6.56k
      }
325
2.18k
      break;
326
2.19k
    }
327
26.8k
  }
328
26.8k
  size_t prev_pos = *pos;
329
26.8k
  HWY_CAPPED(float, 64) d;
330
6.34M
  for (size_t i = 0; i < num * 3; i += Lanes(d)) {
331
6.32M
    auto inv_val = LoadU(d, weights.data() + i);
332
6.32M
    if (JXL_UNLIKELY(!AllFalse(d, Ge(inv_val, Set(d, 1.0f / kAlmostZero))) ||
333
6.32M
                     !AllFalse(d, Lt(inv_val, Set(d, kAlmostZero))))) {
334
26
      return JXL_FAILURE("Invalid quantization table");
335
26
    }
336
6.32M
    auto val = Div(Set(d, 1.0f), inv_val);
337
6.32M
    StoreU(val, d, table + *pos + i);
338
6.32M
    StoreU(inv_val, d, inv_table + *pos + i);
339
6.32M
  }
340
26.8k
  (*pos) += 3 * num;
341
342
  // Ensure that the lowest frequencies have a 0 inverse table.
343
  // This does not affect en/decoding, but allows AC strategy selection to be
344
  // slightly simpler.
345
26.8k
  size_t xs = DequantMatrices::required_size_x[quant_table_idx];
346
26.8k
  size_t ys = DequantMatrices::required_size_y[quant_table_idx];
347
26.8k
  CoefficientLayout(&ys, &xs);
348
107k
  for (size_t c = 0; c < 3; c++) {
349
218k
    for (size_t y = 0; y < ys; y++) {
350
927k
      for (size_t x = 0; x < xs; x++) {
351
790k
        inv_table[prev_pos + c * ys * xs * kDCTBlockSize + y * kBlockDim * xs +
352
790k
                  x] = 0;
353
790k
      }
354
137k
    }
355
80.4k
  }
356
26.8k
  return true;
357
26.8k
}
jxl::N_SSE2::ComputeQuantTable(jxl::QuantEncoding const&, float*, float*, unsigned long, jxl::QuantTable, unsigned long*)
Line
Count
Source
166
3.35k
                         QuantTable kind, size_t* pos) {
167
3.35k
  constexpr size_t N = kBlockDim;
168
3.35k
  size_t quant_table_idx = static_cast<size_t>(kind);
169
3.35k
  size_t wrows = 8 * DequantMatrices::required_size_x[quant_table_idx];
170
3.35k
  size_t wcols = 8 * DequantMatrices::required_size_y[quant_table_idx];
171
3.35k
  size_t num = wrows * wcols;
172
173
3.35k
  std::vector<float> weights(3 * num);
174
175
3.35k
  switch (encoding.mode) {
176
0
    case QuantEncoding::kQuantModeLibrary: {
177
      // Library and copy quant encoding should get replaced by the actual
178
      // parameters by the caller.
179
0
      JXL_ENSURE(false);
180
0
      break;
181
0
    }
182
50
    case QuantEncoding::kQuantModeID: {
183
50
      JXL_ENSURE(num == kDCTBlockSize);
184
50
      GetQuantWeightsIdentity(encoding.idweights, weights.data());
185
50
      break;
186
50
    }
187
66
    case QuantEncoding::kQuantModeDCT2: {
188
66
      JXL_ENSURE(num == kDCTBlockSize);
189
66
      GetQuantWeightsDCT2(encoding.dct2weights, weights.data());
190
66
      break;
191
66
    }
192
31
    case QuantEncoding::kQuantModeDCT4: {
193
31
      JXL_ENSURE(num == kDCTBlockSize);
194
31
      float weights4x4[3 * 4 * 4];
195
      // Always use 4x4 GetQuantWeights for DCT4 quantization tables.
196
31
      JXL_RETURN_IF_ERROR(
197
31
          GetQuantWeights(4, 4, encoding.dct_params.distance_bands,
198
31
                          encoding.dct_params.num_distance_bands, weights4x4));
199
116
      for (size_t c = 0; c < 3; c++) {
200
783
        for (size_t y = 0; y < kBlockDim; y++) {
201
6.26k
          for (size_t x = 0; x < kBlockDim; x++) {
202
5.56k
            weights[c * num + y * kBlockDim + x] =
203
5.56k
                weights4x4[c * 16 + (y / 2) * 4 + (x / 2)];
204
5.56k
          }
205
696
        }
206
87
        weights[c * num + 1] /= encoding.dct4multipliers[c][0];
207
87
        weights[c * num + N] /= encoding.dct4multipliers[c][0];
208
87
        weights[c * num + N + 1] /= encoding.dct4multipliers[c][1];
209
87
      }
210
29
      break;
211
31
    }
212
56
    case QuantEncoding::kQuantModeDCT4X8: {
213
56
      JXL_ENSURE(num == kDCTBlockSize);
214
56
      float weights4x8[3 * 4 * 8];
215
      // Always use 4x8 GetQuantWeights for DCT4X8 quantization tables.
216
56
      JXL_RETURN_IF_ERROR(
217
56
          GetQuantWeights(4, 8, encoding.dct_params.distance_bands,
218
56
                          encoding.dct_params.num_distance_bands, weights4x8));
219
220
      for (size_t c = 0; c < 3; c++) {
220
1.48k
        for (size_t y = 0; y < kBlockDim; y++) {
221
11.8k
          for (size_t x = 0; x < kBlockDim; x++) {
222
10.5k
            weights[c * num + y * kBlockDim + x] =
223
10.5k
                weights4x8[c * 32 + (y / 2) * 8 + x];
224
10.5k
          }
225
1.32k
        }
226
165
        weights[c * num + N] /= encoding.dct4x8multipliers[c];
227
165
      }
228
55
      break;
229
56
    }
230
3.06k
    case QuantEncoding::kQuantModeDCT: {
231
3.06k
      JXL_RETURN_IF_ERROR(GetQuantWeights(
232
3.06k
          wrows, wcols, encoding.dct_params.distance_bands,
233
3.06k
          encoding.dct_params.num_distance_bands, weights.data()));
234
3.06k
      break;
235
3.06k
    }
236
3.06k
    case QuantEncoding::kQuantModeRAW: {
237
35
      if (!encoding.qraw.qtable || encoding.qraw.qtable->size() != 3 * num) {
238
0
        return JXL_FAILURE("Invalid table encoding");
239
0
      }
240
35
      int* qtable = encoding.qraw.qtable->data();
241
7.13k
      for (size_t i = 0; i < 3 * num; i++) {
242
7.10k
        weights[i] = 1.f / (encoding.qraw.qtable_den * qtable[i]);
243
7.10k
      }
244
35
      break;
245
35
    }
246
50
    case QuantEncoding::kQuantModeAFV: {
247
50
      constexpr float kFreqs[] = {
248
50
          0xBAD,
249
50
          0xBAD,
250
50
          0.8517778890324296,
251
50
          5.37778436506804,
252
50
          0xBAD,
253
50
          0xBAD,
254
50
          4.734747904497923,
255
50
          5.449245381693219,
256
50
          1.6598270267479331,
257
50
          4,
258
50
          7.275749096817861,
259
50
          10.423227632456525,
260
50
          2.662932286148962,
261
50
          7.630657783650829,
262
50
          8.962388608184032,
263
50
          12.97166202570235,
264
50
      };
265
266
50
      float weights4x8[3 * 4 * 8];
267
50
      JXL_RETURN_IF_ERROR((
268
50
          GetQuantWeights(4, 8, encoding.dct_params.distance_bands,
269
50
                          encoding.dct_params.num_distance_bands, weights4x8)));
270
50
      float weights4x4[3 * 4 * 4];
271
50
      JXL_RETURN_IF_ERROR((GetQuantWeights(
272
50
          4, 4, encoding.dct_params_afv_4x4.distance_bands,
273
50
          encoding.dct_params_afv_4x4.num_distance_bands, weights4x4)));
274
275
49
      constexpr float lo = 0.8517778890324296;
276
49
      constexpr float hi = 12.97166202570235f - lo + 1e-6f;
277
189
      for (size_t c = 0; c < 3; c++) {
278
144
        float bands[4];
279
144
        bands[0] = encoding.afv_weights[c][5];
280
144
        if (bands[0] < kAlmostZero) return JXL_FAILURE("Invalid AFV bands");
281
566
        for (size_t i = 1; i < 4; i++) {
282
426
          bands[i] = bands[i - 1] * Mult(encoding.afv_weights[c][i + 5]);
283
426
          if (bands[i] < kAlmostZero) return JXL_FAILURE("Invalid AFV bands");
284
426
        }
285
140
        size_t start = c * 64;
286
140
        auto set_weight = [&start, &weights](size_t x, size_t y, float val) {
287
140
          weights[start + y * 8 + x] = val;
288
140
        };
289
140
        weights[start] = 1;  // Not used, but causes MSAN error otherwise.
290
        // Weights for (0, 1) and (1, 0).
291
140
        set_weight(0, 1, encoding.afv_weights[c][0]);
292
140
        set_weight(1, 0, encoding.afv_weights[c][1]);
293
        // AFV special weights for 3-pixel corner.
294
140
        set_weight(0, 2, encoding.afv_weights[c][2]);
295
140
        set_weight(2, 0, encoding.afv_weights[c][3]);
296
140
        set_weight(2, 2, encoding.afv_weights[c][4]);
297
298
        // All other AFV weights.
299
700
        for (size_t y = 0; y < 4; y++) {
300
2.80k
          for (size_t x = 0; x < 4; x++) {
301
2.24k
            if (x < 2 && y < 2) continue;
302
3.36k
            JXL_ASSIGN_OR_RETURN(
303
3.36k
                float val, Interpolate(kFreqs[y * 4 + x] - lo, hi, bands, 4));
304
3.36k
            set_weight(2 * x, 2 * y, val);
305
3.36k
          }
306
560
        }
307
308
        // Put 4x8 weights in odd rows, except (1, 0).
309
700
        for (size_t y = 0; y < kBlockDim / 2; y++) {
310
5.04k
          for (size_t x = 0; x < kBlockDim; x++) {
311
4.48k
            if (x == 0 && y == 0) continue;
312
4.34k
            weights[c * num + (2 * y + 1) * kBlockDim + x] =
313
4.34k
                weights4x8[c * 32 + y * 8 + x];
314
4.34k
          }
315
560
        }
316
        // Put 4x4 weights in even rows / odd columns, except (0, 1).
317
700
        for (size_t y = 0; y < kBlockDim / 2; y++) {
318
2.80k
          for (size_t x = 0; x < kBlockDim / 2; x++) {
319
2.24k
            if (x == 0 && y == 0) continue;
320
2.10k
            weights[c * num + (2 * y) * kBlockDim + 2 * x + 1] =
321
2.10k
                weights4x4[c * 16 + y * 4 + x];
322
2.10k
          }
323
560
        }
324
140
      }
325
45
      break;
326
49
    }
327
3.35k
  }
328
3.34k
  size_t prev_pos = *pos;
329
3.34k
  HWY_CAPPED(float, 64) d;
330
7.01M
  for (size_t i = 0; i < num * 3; i += Lanes(d)) {
331
7.01M
    auto inv_val = LoadU(d, weights.data() + i);
332
7.01M
    if (JXL_UNLIKELY(!AllFalse(d, Ge(inv_val, Set(d, 1.0f / kAlmostZero))) ||
333
7.01M
                     !AllFalse(d, Lt(inv_val, Set(d, kAlmostZero))))) {
334
12
      return JXL_FAILURE("Invalid quantization table");
335
12
    }
336
7.01M
    auto val = Div(Set(d, 1.0f), inv_val);
337
7.01M
    StoreU(val, d, table + *pos + i);
338
7.01M
    StoreU(inv_val, d, inv_table + *pos + i);
339
7.01M
  }
340
3.33k
  (*pos) += 3 * num;
341
342
  // Ensure that the lowest frequencies have a 0 inverse table.
343
  // This does not affect en/decoding, but allows AC strategy selection to be
344
  // slightly simpler.
345
3.33k
  size_t xs = DequantMatrices::required_size_x[quant_table_idx];
346
3.33k
  size_t ys = DequantMatrices::required_size_y[quant_table_idx];
347
3.33k
  CoefficientLayout(&ys, &xs);
348
13.3k
  for (size_t c = 0; c < 3; c++) {
349
45.4k
    for (size_t y = 0; y < ys; y++) {
350
473k
      for (size_t x = 0; x < xs; x++) {
351
438k
        inv_table[prev_pos + c * ys * xs * kDCTBlockSize + y * kBlockDim * xs +
352
438k
                  x] = 0;
353
438k
      }
354
35.4k
    }
355
9.99k
  }
356
3.33k
  return true;
357
3.34k
}
358
359
// NOLINTNEXTLINE(google-readability-namespace-comments)
360
}  // namespace HWY_NAMESPACE
361
}  // namespace jxl
362
HWY_AFTER_NAMESPACE();
363
364
#if HWY_ONCE
365
366
namespace jxl {
367
namespace {
368
369
HWY_EXPORT(ComputeQuantTable);
370
371
constexpr const float kAlmostZero = 1e-8f;
372
373
272
Status DecodeDctParams(BitReader* br, DctQuantWeightParams* params) {
374
272
  params->num_distance_bands =
375
272
      br->ReadFixedBits<DctQuantWeightParams::kLog2MaxDistanceBands>() + 1;
376
967
  for (size_t c = 0; c < 3; c++) {
377
3.30k
    for (size_t i = 0; i < params->num_distance_bands; i++) {
378
2.55k
      JXL_RETURN_IF_ERROR(F16Coder::Read(br, &params->distance_bands[c][i]));
379
2.55k
    }
380
743
    if (params->distance_bands[c][0] < kAlmostZero) {
381
48
      return JXL_FAILURE("Distance band seed is too small");
382
48
    }
383
695
    params->distance_bands[c][0] *= 64.0f;
384
695
  }
385
216
  return true;
386
272
}
387
388
Status Decode(JxlMemoryManager* memory_manager, BitReader* br,
389
              QuantEncoding* encoding, size_t required_size_x,
390
              size_t required_size_y, size_t idx,
391
14.6k
              ModularFrameDecoder* modular_frame_decoder) {
392
14.6k
  size_t required_size = required_size_x * required_size_y;
393
14.6k
  required_size_x *= kBlockDim;
394
14.6k
  required_size_y *= kBlockDim;
395
14.6k
  int mode = br->ReadFixedBits<kLog2NumQuantModes>();
396
14.6k
  switch (mode) {
397
13.5k
    case QuantEncoding::kQuantModeLibrary: {
398
13.5k
      encoding->predefined = br->ReadFixedBits<kCeilLog2NumPredefinedTables>();
399
13.5k
      if (encoding->predefined >= kNumPredefinedTables) {
400
0
        return JXL_FAILURE("Invalid predefined table");
401
0
      }
402
13.5k
      break;
403
13.5k
    }
404
13.5k
    case QuantEncoding::kQuantModeID: {
405
82
      if (required_size != 1) return JXL_FAILURE("Invalid mode");
406
204
      for (size_t c = 0; c < 3; c++) {
407
599
        for (size_t i = 0; i < 3; i++) {
408
471
          JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->idweights[c][i]));
409
451
          if (std::abs(encoding->idweights[c][i]) < kAlmostZero) {
410
26
            return JXL_FAILURE("ID Quantizer is too small");
411
26
          }
412
425
          encoding->idweights[c][i] *= 64;
413
425
        }
414
174
      }
415
30
      break;
416
76
    }
417
115
    case QuantEncoding::kQuantModeDCT2: {
418
115
      if (required_size != 1) return JXL_FAILURE("Invalid mode");
419
237
      for (size_t c = 0; c < 3; c++) {
420
1.19k
        for (size_t i = 0; i < 6; i++) {
421
1.06k
          JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->dct2weights[c][i]));
422
1.02k
          if (std::abs(encoding->dct2weights[c][i]) < kAlmostZero) {
423
48
            return JXL_FAILURE("Quantizer is too small");
424
48
          }
425
981
          encoding->dct2weights[c][i] *= 64;
426
981
        }
427
216
      }
428
21
      break;
429
107
    }
430
85
    case QuantEncoding::kQuantModeDCT4X8: {
431
85
      if (required_size != 1) return JXL_FAILURE("Invalid mode");
432
259
      for (size_t c = 0; c < 3; c++) {
433
204
        JXL_RETURN_IF_ERROR(
434
204
            F16Coder::Read(br, &encoding->dct4x8multipliers[c]));
435
196
        if (std::abs(encoding->dct4x8multipliers[c]) < kAlmostZero) {
436
12
          return JXL_FAILURE("DCT4X8 multiplier is too small");
437
12
        }
438
196
      }
439
55
      JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params));
440
44
      break;
441
55
    }
442
56
    case QuantEncoding::kQuantModeDCT4: {
443
56
      if (required_size != 1) return JXL_FAILURE("Invalid mode");
444
153
      for (size_t c = 0; c < 3; c++) {
445
344
        for (size_t i = 0; i < 2; i++) {
446
242
          JXL_RETURN_IF_ERROR(
447
242
              F16Coder::Read(br, &encoding->dct4multipliers[c][i]));
448
233
          if (std::abs(encoding->dct4multipliers[c][i]) < kAlmostZero) {
449
15
            return JXL_FAILURE("DCT4 multiplier is too small");
450
15
          }
451
233
        }
452
126
      }
453
27
      JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params));
454
25
      break;
455
27
    }
456
97
    case QuantEncoding::kQuantModeAFV: {
457
97
      if (required_size != 1) return JXL_FAILURE("Invalid mode");
458
274
      for (size_t c = 0; c < 3; c++) {
459
2.05k
        for (size_t i = 0; i < 9; i++) {
460
1.87k
          JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->afv_weights[c][i]));
461
1.87k
        }
462
1.26k
        for (size_t i = 0; i < 6; i++) {
463
1.08k
          encoding->afv_weights[c][i] *= 64;
464
1.08k
        }
465
180
      }
466
42
      JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params));
467
35
      JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params_afv_4x4));
468
33
      break;
469
35
    }
470
113
    case QuantEncoding::kQuantModeDCT: {
471
113
      JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params));
472
79
      break;
473
113
    }
474
497
    case QuantEncoding::kQuantModeRAW: {
475
      // Set mode early, to avoid mem-leak.
476
497
      encoding->mode = QuantEncoding::kQuantModeRAW;
477
497
      JXL_RETURN_IF_ERROR(ModularFrameDecoder::DecodeQuantTable(
478
497
          memory_manager, required_size_x, required_size_y, br, encoding, idx,
479
497
          modular_frame_decoder));
480
424
      break;
481
497
    }
482
424
    default:
483
0
      return JXL_FAILURE("Invalid quantization table encoding");
484
14.6k
  }
485
14.2k
  encoding->mode = static_cast<QuantEncoding::Mode>(mode);
486
14.2k
  return true;
487
14.6k
}
488
489
}  // namespace
490
491
// Out-of-class definitions for the static constexpr data members of
// DequantMatrices. Before C++17 such members are not implicitly inline, so a
// namespace-scope definition is required once they are odr-used (for example
// required_size_x / required_size_y are indexed in DequantMatrices::Decode).
// From C++17 on these redeclarations are redundant, hence the guard.
#if JXL_CXX_LANG < JXL_CXX_17
492
constexpr const std::array<int, 17> DequantMatrices::required_size_x;
493
constexpr const std::array<int, 17> DequantMatrices::required_size_y;
494
constexpr const size_t DequantMatrices::kSumRequiredXy;
495
#endif
496
497
// Reads the quantization-table encodings from `br` into encodings_.
//
// A single leading bit selects "all default". When it is set no further data
// is read and every one of the kNumQuantTables entries stays at
// QuantEncoding::Library<0>(). Otherwise each entry is parsed in index order by
// jxl::Decode, which is given the required dimensions for that table index.
//
// On success computed_mask_ is reset to 0 and true is returned. If any
// jxl::Decode call fails its error is returned immediately; encodings_ has
// already been overwritten at that point and computed_mask_ is left unchanged.
//
// NOTE(review): unlike DecodeDC, the flag bit is not followed by an
// AllReadsWithinBounds() check here - confirm the caller validates the reader.
Status DequantMatrices::Decode(JxlMemoryManager* memory_manager, BitReader* br,
498
17.8k
                               ModularFrameDecoder* modular_frame_decoder) {
499
17.8k
  size_t all_default = br->ReadBits(1);
500
17.8k
  // Zero tables to parse when everything is default.
  size_t num_tables = all_default ? 0 : static_cast<size_t>(kNumQuantTables);
501
17.8k
  // Start from a clean slate: every entry is the library default.
  encodings_.clear();
502
17.8k
  encodings_.resize(kNumQuantTables, QuantEncoding::Library<0>());
503
32.0k
  for (size_t i = 0; i < num_tables; i++) {
504
14.6k
    // i < num_tables <= kNumQuantTables, so `i % kNumQuantTables` equals i.
    JXL_RETURN_IF_ERROR(jxl::Decode(memory_manager, br, &encodings_[i],
505
14.6k
                                    required_size_x[i % kNumQuantTables],
506
14.6k
                                    required_size_y[i % kNumQuantTables], i,
507
14.6k
                                    modular_frame_decoder));
508
14.6k
  }
509
17.4k
  computed_mask_ = 0;
510
17.4k
  return true;
511
17.8k
}
512
513
97.7k
// Reads the DC quantization factors from `br`.
//
// One leading bit selects "all default"; when it is set dc_quant_ and
// inv_dc_quant_ are left untouched. Otherwise three values are read with
// F16Coder::Read, scaled by 1/128, validated and stored in dc_quant_, with
// their reciprocals stored in inv_dc_quant_.
//
// Fails when the reader has gone out of bounds after the flag bit, when an
// F16Coder::Read fails, or when a scaled value is below kAlmostZero. On a
// failure part-way through, the channels already processed keep their new
// values.
Status DequantMatrices::DecodeDC(BitReader* br) {
514
97.7k
  bool all_default = static_cast<bool>(br->ReadBits(1));
515
97.7k
  if (!br->AllReadsWithinBounds()) return JXL_FAILURE("EOS during DecodeDC");
516
96.4k
  if (!all_default) {
517
11.4k
    for (size_t c = 0; c < 3; c++) {
518
8.64k
      JXL_RETURN_IF_ERROR(F16Coder::Read(br, &dc_quant_[c]));
519
8.63k
      dc_quant_[c] *= 1.0f / 128.0f;
520
      // Negative values and nearly zero are invalid values.
521
8.63k
      // Rejecting them here also keeps the reciprocal below finite.
      if (dc_quant_[c] < kAlmostZero) {
522
94
        return JXL_FAILURE("Invalid dc_quant: coefficient is too small.");
523
94
      }
524
8.53k
      inv_dc_quant_[c] = 1.0f / dc_quant_[c];
525
8.53k
    }
526
2.93k
  }
527
96.3k
  return true;
528
96.4k
}
529
530
0
// Narrows a literal to float. The tables below are written with
// double-precision literals; the conversion happens when the argument is bound
// to the float parameter, so the value is returned as-is.
constexpr float V(float v) { return v; }
531
532
namespace {
533
struct DequantMatricesLibraryDef {
534
  // DCT8: library default encoding, built with QuantEncodingInternal::DCT.
  // DctQuantWeightParams receives three rows of 6 values (presumably one row
  // per component c = 0..2 - confirm) and the count 6, which matches the row
  // length.
535
4
  static constexpr QuantEncodingInternal DCT() {
536
4
    return QuantEncodingInternal::DCT(DctQuantWeightParams({{{{
537
4
                                                                 V(3150.0),
538
4
                                                                 V(0.0),
539
4
                                                                 V(-0.4),
540
4
                                                                 V(-0.4),
541
4
                                                                 V(-0.4),
542
4
                                                                 V(-2.0),
543
4
                                                             }},
544
4
                                                             {{
545
4
                                                                 V(560.0),
546
4
                                                                 V(0.0),
547
4
                                                                 V(-0.3),
548
4
                                                                 V(-0.3),
549
4
                                                                 V(-0.3),
550
4
                                                                 V(-0.3),
551
4
                                                             }},
552
4
                                                             {{
553
4
                                                                 V(512.0),
554
4
                                                                 V(-2.0),
555
4
                                                                 V(-1.0),
556
4
                                                                 V(0.0),
557
4
                                                                 V(-1.0),
558
4
                                                                 V(-2.0),
559
4
                                                             }}}},
560
4
                                                           6));
561
4
  }
562
563
  // Identity: library default encoding, built with
  // QuantEncodingInternal::Identity from three rows of 3 weights (presumably
  // one row per component c = 0..2 - confirm).
564
4
  static constexpr QuantEncodingInternal IDENTITY() {
565
4
    return QuantEncodingInternal::Identity({{{{
566
4
                                                 V(280.0),
567
4
                                                 V(3160.0),
568
4
                                                 V(3160.0),
569
4
                                             }},
570
4
                                             {{
571
4
                                                 V(60.0),
572
4
                                                 V(864.0),
573
4
                                                 V(864.0),
574
4
                                             }},
575
4
                                             {{
576
4
                                                 V(18.0),
577
4
                                                 V(200.0),
578
4
                                                 V(200.0),
579
4
                                             }}}});
580
4
  }
581
582
  // DCT2: library default encoding, built with QuantEncodingInternal::DCT2
  // from three rows of 6 weights (presumably one row per component
  // c = 0..2 - confirm).
583
4
  static constexpr QuantEncodingInternal DCT2X2() {
584
4
    return QuantEncodingInternal::DCT2({{{{
585
4
                                             V(3840.0),
586
4
                                             V(2560.0),
587
4
                                             V(1280.0),
588
4
                                             V(640.0),
589
4
                                             V(480.0),
590
4
                                             V(300.0),
591
4
                                         }},
592
4
                                         {{
593
4
                                             V(960.0),
594
4
                                             V(640.0),
595
4
                                             V(320.0),
596
4
                                             V(180.0),
597
4
                                             V(140.0),
598
4
                                             V(120.0),
599
4
                                         }},
600
4
                                         {{
601
4
                                             V(640.0),
602
4
                                             V(320.0),
603
4
                                             V(128.0),
604
4
                                             V(64.0),
605
4
                                             V(32.0),
606
4
                                             V(16.0),
607
4
                                         }}}});
608
4
  }
609
610
  // DCT4 (quant_kind 3): library default encoding, built with
  // QuantEncodingInternal::DCT4. The first argument is a DctQuantWeightParams
  // with three rows of 4 values and the count 4; the second is a 3x2 table of
  // multipliers, all 1.0 here. AFV0() below reuses this encoding's dct_params.
611
8
  static constexpr QuantEncodingInternal DCT4X4() {
612
8
    return QuantEncodingInternal::DCT4(DctQuantWeightParams({{{{
613
8
                                                                  V(2200.0),
614
8
                                                                  V(0.0),
615
8
                                                                  V(0.0),
616
8
                                                                  V(0.0),
617
8
                                                              }},
618
8
                                                              {{
619
8
                                                                  V(392.0),
620
8
                                                                  V(0.0),
621
8
                                                                  V(0.0),
622
8
                                                                  V(0.0),
623
8
                                                              }},
624
8
                                                              {{
625
8
                                                                  V(112.0),
626
8
                                                                  V(-0.25),
627
8
                                                                  V(-0.25),
628
8
                                                                  V(-0.5),
629
8
                                                              }}}},
630
8
                                                            4),
631
                                       /* kMul */
632
8
                                       {{{{
633
8
                                             V(1.0),
634
8
                                             V(1.0),
635
8
                                         }},
636
8
                                         {{
637
8
                                             V(1.0),
638
8
                                             V(1.0),
639
8
                                         }},
640
8
                                         {{
641
8
                                             V(1.0),
642
8
                                             V(1.0),
643
8
                                         }}}});
644
8
  }
645
646
  // DCT16: library default encoding, built with QuantEncodingInternal::DCT.
  // DctQuantWeightParams receives three rows of 7 values (presumably one row
  // per component c = 0..2 - confirm) and the count 7, which matches the row
  // length.
647
4
  static constexpr QuantEncodingInternal DCT16X16() {
648
4
    return QuantEncodingInternal::DCT(
649
4
        DctQuantWeightParams({{{{
650
4
                                   V(8996.8725711814115328),
651
4
                                   V(-1.3000777393353804),
652
4
                                   V(-0.49424529824571225),
653
4
                                   V(-0.439093774457103443),
654
4
                                   V(-0.6350101832695744),
655
4
                                   V(-0.90177264050827612),
656
4
                                   V(-1.6162099239887414),
657
4
                               }},
658
4
                               {{
659
4
                                   V(3191.48366296844234752),
660
4
                                   V(-0.67424582104194355),
661
4
                                   V(-0.80745813428471001),
662
4
                                   V(-0.44925837484843441),
663
4
                                   V(-0.35865440981033403),
664
4
                                   V(-0.31322389111877305),
665
4
                                   V(-0.37615025315725483),
666
4
                               }},
667
4
                               {{
668
4
                                   V(1157.50408145487200256),
669
4
                                   V(-2.0531423165804414),
670
4
                                   V(-1.4),
671
4
                                   V(-0.50687130033378396),
672
4
                                   V(-0.42708730624733904),
673
4
                                   V(-1.4856834539296244),
674
4
                                   V(-4.9209142884401604),
675
4
                               }}}},
676
4
                             7));
677
4
  }
678
679
  // DCT32: library default encoding, built with QuantEncodingInternal::DCT.
  // DctQuantWeightParams receives three rows of 8 values (presumably one row
  // per component c = 0..2 - confirm) and the count 8, which matches the row
  // length.
680
4
  static constexpr QuantEncodingInternal DCT32X32() {
681
4
    return QuantEncodingInternal::DCT(
682
4
        DctQuantWeightParams({{{{
683
4
                                   V(15718.40830982518931456),
684
4
                                   V(-1.025),
685
4
                                   V(-0.98),
686
4
                                   V(-0.9012),
687
4
                                   V(-0.4),
688
4
                                   V(-0.48819395464),
689
4
                                   V(-0.421064),
690
4
                                   V(-0.27),
691
4
                               }},
692
4
                               {{
693
4
                                   V(7305.7636810695983104),
694
4
                                   V(-0.8041958212306401),
695
4
                                   V(-0.7633036457487539),
696
4
                                   V(-0.55660379990111464),
697
4
                                   V(-0.49785304658857626),
698
4
                                   V(-0.43699592683512467),
699
4
                                   V(-0.40180866526242109),
700
4
                                   V(-0.27321683125358037),
701
4
                               }},
702
4
                               {{
703
4
                                   V(3803.53173721215041536),
704
4
                                   V(-3.060733579805728),
705
4
                                   V(-2.0413270132490346),
706
4
                                   V(-2.0235650159727417),
707
4
                                   V(-0.5495389509954993),
708
4
                                   V(-0.4),
709
4
                                   V(-0.4),
710
4
                                   V(-0.3),
711
4
                               }}}},
712
4
                             8));
713
4
  }
714
715
  // DCT16X8 (the accessor is named DCT8X16, matching QuantTable::DCT8X16;
  // presumably one table serves both orientations - confirm).
  // Built with QuantEncodingInternal::DCT from three rows of 7 values and the
  // count 7, which matches the row length.
716
4
  static constexpr QuantEncodingInternal DCT8X16() {
717
4
    return QuantEncodingInternal::DCT(
718
4
        DctQuantWeightParams({{{{
719
4
                                   V(7240.7734393502),
720
4
                                   V(-0.7),
721
4
                                   V(-0.7),
722
4
                                   V(-0.2),
723
4
                                   V(-0.2),
724
4
                                   V(-0.2),
725
4
                                   V(-0.5),
726
4
                               }},
727
4
                               {{
728
4
                                   V(1448.15468787004),
729
4
                                   V(-0.5),
730
4
                                   V(-0.5),
731
4
                                   V(-0.5),
732
4
                                   V(-0.2),
733
4
                                   V(-0.2),
734
4
                                   V(-0.2),
735
4
                               }},
736
4
                               {{
737
4
                                   V(506.854140754517),
738
4
                                   V(-1.4),
739
4
                                   V(-0.2),
740
4
                                   V(-0.5),
741
4
                                   V(-0.5),
742
4
                                   V(-1.5),
743
4
                                   V(-3.6),
744
4
                               }}}},
745
4
                             7));
746
4
  }
747
748
  // DCT32X8 (the accessor is named DCT8X32; presumably one table serves both
  // orientations - confirm).
  // Built with QuantEncodingInternal::DCT from three rows of 8 values and the
  // count 8, which matches the row length.
749
4
  static constexpr QuantEncodingInternal DCT8X32() {
750
4
    return QuantEncodingInternal::DCT(
751
4
        DctQuantWeightParams({{{{
752
4
                                   V(16283.2494710648897),
753
4
                                   V(-1.7812845336559429),
754
4
                                   V(-1.6309059012653515),
755
4
                                   V(-1.0382179034313539),
756
4
                                   V(-0.85),
757
4
                                   V(-0.7),
758
4
                                   V(-0.9),
759
4
                                   V(-1.2360638576849587),
760
4
                               }},
761
4
                               {{
762
4
                                   V(5089.15750884921511936),
763
4
                                   V(-0.320049391452786891),
764
4
                                   V(-0.35362849922161446),
765
4
                                   V(-0.30340000000000003),
766
4
                                   V(-0.61),
767
4
                                   V(-0.5),
768
4
                                   V(-0.5),
769
4
                                   V(-0.6),
770
4
                               }},
771
4
                               {{
772
4
                                   V(3397.77603275308720128),
773
4
                                   V(-0.321327362693153371),
774
4
                                   V(-0.34507619223117997),
775
4
                                   V(-0.70340000000000003),
776
4
                                   V(-0.9),
777
4
                                   V(-1.0),
778
4
                                   V(-1.0),
779
4
                                   V(-1.1754605576265209),
780
4
                               }}}},
781
4
                             8));
782
4
  }
783
784
  // DCT32X16 (the accessor is named DCT16X32; presumably one table serves both
  // orientations - confirm).
  // Built with QuantEncodingInternal::DCT from three rows of 8 values and the
  // count 8, which matches the row length.
785
4
  static constexpr QuantEncodingInternal DCT16X32() {
786
4
    return QuantEncodingInternal::DCT(
787
4
        DctQuantWeightParams({{{{
788
4
                                   V(13844.97076442300573),
789
4
                                   V(-0.97113799999999995),
790
4
                                   V(-0.658),
791
4
                                   V(-0.42026),
792
4
                                   V(-0.22712),
793
4
                                   V(-0.2206),
794
4
                                   V(-0.226),
795
4
                                   V(-0.6),
796
4
                               }},
797
4
                               {{
798
4
                                   V(4798.964084220744293),
799
4
                                   V(-0.61125308982767057),
800
4
                                   V(-0.83770786552491361),
801
4
                                   V(-0.79014862079498627),
802
4
                                   V(-0.2692727459704829),
803
4
                                   V(-0.38272769465388551),
804
4
                                   V(-0.22924222653091453),
805
4
                                   V(-0.20719098826199578),
806
4
                               }},
807
4
                               {{
808
4
                                   V(1807.236946760964614),
809
4
                                   V(-1.2),
810
4
                                   V(-1.2),
811
4
                                   V(-0.7),
812
4
                                   V(-0.7),
813
4
                                   V(-0.7),
814
4
                                   V(-0.4),
815
4
                                   V(-0.5),
816
4
                               }}}},
817
4
                             8));
818
4
  }
819
820
  // DCT4X8 and 8x4: library default encoding, built with
  // QuantEncodingInternal::DCT4X8. The first argument is a
  // DctQuantWeightParams with three rows of 4 values and the count 4; the
  // second is a list of three multipliers, all 1.0 here. AFV0() below reuses
  // this encoding's dct_params.
821
8
  static constexpr QuantEncodingInternal DCT4X8() {
822
8
    return QuantEncodingInternal::DCT4X8(
823
8
        DctQuantWeightParams({{
824
8
                                 {{
825
8
                                     V(2198.050556016380522),
826
8
                                     V(-0.96269623020744692),
827
8
                                     V(-0.76194253026666783),
828
8
                                     V(-0.6551140670773547),
829
8
                                 }},
830
8
                                 {{
831
8
                                     V(764.3655248643528689),
832
8
                                     V(-0.92630200888366945),
833
8
                                     V(-0.9675229603596517),
834
8
                                     V(-0.27845290869168118),
835
8
                                 }},
836
8
                                 {{
837
8
                                     V(527.107573587542228),
838
8
                                     V(-1.4594385811273854),
839
8
                                     V(-1.450082094097871593),
840
8
                                     V(-1.5843722511996204),
841
8
                                 }},
842
8
                             }},
843
8
                             4),
844
        /* kMuls */
845
8
        {{
846
8
            V(1.0),
847
8
            V(1.0),
848
8
            V(1.0),
849
8
        }});
850
8
  }
851
  // AFV: library default encoding, built with QuantEncodingInternal::AFV.
  // It reuses the dct_params of DCT4X8() and DCT4X4() and adds three rows of
  // 9 weights, grouped by the inline comments as 2 "DC tendency" values,
  // 3 "AFV corner" values and 4 "AFV high freqs" values.
  // Not constexpr, unlike the accessors above.
852
4
  static QuantEncodingInternal AFV0() {
853
4
    return QuantEncodingInternal::AFV(DCT4X8().dct_params, DCT4X4().dct_params,
854
4
                                      {{{{
855
                                            // 4x4/4x8 DC tendency.
856
4
                                            V(3072.0),
857
4
                                            V(3072.0),
858
                                            // AFV corner.
859
4
                                            V(256.0),
860
4
                                            V(256.0),
861
4
                                            V(256.0),
862
                                            // AFV high freqs.
863
4
                                            V(414.0),
864
4
                                            V(0.0),
865
4
                                            V(0.0),
866
4
                                            V(0.0),
867
4
                                        }},
868
4
                                        {{
869
                                            // 4x4/4x8 DC tendency.
870
4
                                            V(1024.0),
871
4
                                            V(1024.0),
872
                                            // AFV corner.
873
4
                                            V(50),
874
4
                                            V(50),
875
4
                                            V(50),
876
                                            // AFV high freqs.
877
4
                                            V(58.0),
878
4
                                            V(0.0),
879
4
                                            V(0.0),
880
4
                                            V(0.0),
881
4
                                        }},
882
4
                                        {{
883
                                            // 4x4/4x8 DC tendency.
884
4
                                            V(384.0),
885
4
                                            V(384.0),
886
                                            // AFV corner.
887
4
                                            V(12.0),
888
4
                                            V(12.0),
889
4
                                            V(12.0),
890
                                            // AFV high freqs.
891
4
                                            V(22.0),
892
4
                                            V(-0.25),
893
4
                                            V(-0.25),
894
4
                                            V(-0.25),
895
4
                                        }}}});
896
4
  }
897
898
  // DCT64: library default encoding, built with QuantEncodingInternal::DCT
  // from three rows of 8 values and the count 8. The first value of each row
  // is written as 0.9 times a base constant; the same base constants appear in
  // DCT128X128() and DCT256X256() with factors 1.8 and 3.6, and the remaining
  // seven values per row are identical across those three accessors.
899
4
  static QuantEncodingInternal DCT64X64() {
900
4
    return QuantEncodingInternal::DCT(
901
4
        DctQuantWeightParams({{{{
902
4
                                   V(0.9 * 26629.073922049845),
903
4
                                   V(-1.025),
904
4
                                   V(-0.78),
905
4
                                   V(-0.65012),
906
4
                                   V(-0.19041574084286472),
907
4
                                   V(-0.20819395464),
908
4
                                   V(-0.421064),
909
4
                                   V(-0.32733845535848671),
910
4
                               }},
911
4
                               {{
912
4
                                   V(0.9 * 9311.3238710010046),
913
4
                                   V(-0.3041958212306401),
914
4
                                   V(-0.3633036457487539),
915
4
                                   V(-0.35660379990111464),
916
4
                                   V(-0.3443074455424403),
917
4
                                   V(-0.33699592683512467),
918
4
                                   V(-0.30180866526242109),
919
4
                                   V(-0.27321683125358037),
920
4
                               }},
921
4
                               {{
922
4
                                   V(0.9 * 4992.2486445538634),
923
4
                                   V(-1.2),
924
4
                                   V(-1.2),
925
4
                                   V(-0.8),
926
4
                                   V(-0.7),
927
4
                                   V(-0.7),
928
4
                                   V(-0.4),
929
4
                                   V(-0.5),
930
4
                               }}}},
931
4
                             8));
932
4
  }
933
934
  // DCT64X32 (the accessor is named DCT32X64; presumably one table serves both
  // orientations - confirm).
  // Built with QuantEncodingInternal::DCT from three rows of 8 values and the
  // count 8. The first value of each row is written as 0.65 times a base
  // constant; the same base constants appear in DCT64X128() and DCT128X256()
  // with factors 1.3 and 2.6.
935
4
  static QuantEncodingInternal DCT32X64() {
936
4
    return QuantEncodingInternal::DCT(
937
4
        DctQuantWeightParams({{{{
938
4
                                   V(0.65 * 23629.073922049845),
939
4
                                   V(-1.025),
940
4
                                   V(-0.78),
941
4
                                   V(-0.65012),
942
4
                                   V(-0.19041574084286472),
943
4
                                   V(-0.20819395464),
944
4
                                   V(-0.421064),
945
4
                                   V(-0.32733845535848671),
946
4
                               }},
947
4
                               {{
948
4
                                   V(0.65 * 8611.3238710010046),
949
4
                                   V(-0.3041958212306401),
950
4
                                   V(-0.3633036457487539),
951
4
                                   V(-0.35660379990111464),
952
4
                                   V(-0.3443074455424403),
953
4
                                   V(-0.33699592683512467),
954
4
                                   V(-0.30180866526242109),
955
4
                                   V(-0.27321683125358037),
956
4
                               }},
957
4
                               {{
958
4
                                   V(0.65 * 4492.2486445538634),
959
4
                                   V(-1.2),
960
4
                                   V(-1.2),
961
4
                                   V(-0.8),
962
4
                                   V(-0.7),
963
4
                                   V(-0.7),
964
4
                                   V(-0.4),
965
4
                                   V(-0.5),
966
4
                               }}}},
967
4
                             8));
968
4
  }
969
  // DCT128X128: library default encoding, built with
  // QuantEncodingInternal::DCT from three rows of 8 values and the count 8.
  // Same values as DCT64X64() except that the first value of each row uses the
  // factor 1.8 instead of 0.9.
970
4
  static QuantEncodingInternal DCT128X128() {
971
4
    return QuantEncodingInternal::DCT(
972
4
        DctQuantWeightParams({{{{
973
4
                                   V(1.8 * 26629.073922049845),
974
4
                                   V(-1.025),
975
4
                                   V(-0.78),
976
4
                                   V(-0.65012),
977
4
                                   V(-0.19041574084286472),
978
4
                                   V(-0.20819395464),
979
4
                                   V(-0.421064),
980
4
                                   V(-0.32733845535848671),
981
4
                               }},
982
4
                               {{
983
4
                                   V(1.8 * 9311.3238710010046),
984
4
                                   V(-0.3041958212306401),
985
4
                                   V(-0.3633036457487539),
986
4
                                   V(-0.35660379990111464),
987
4
                                   V(-0.3443074455424403),
988
4
                                   V(-0.33699592683512467),
989
4
                                   V(-0.30180866526242109),
990
4
                                   V(-0.27321683125358037),
991
4
                               }},
992
4
                               {{
993
4
                                   V(1.8 * 4992.2486445538634),
994
4
                                   V(-1.2),
995
4
                                   V(-1.2),
996
4
                                   V(-0.8),
997
4
                                   V(-0.7),
998
4
                                   V(-0.7),
999
4
                                   V(-0.4),
1000
4
                                   V(-0.5),
1001
4
                               }}}},
1002
4
                             8));
1003
4
  }
1004
1005
  // DCT128X64 (the accessor is named DCT64X128; presumably one table serves
  // both orientations - confirm).
  // Built with QuantEncodingInternal::DCT from three rows of 8 values and the
  // count 8. Same values as DCT32X64() except that the first value of each row
  // uses the factor 1.3 instead of 0.65.
1006
4
  static QuantEncodingInternal DCT64X128() {
1007
4
    return QuantEncodingInternal::DCT(
1008
4
        DctQuantWeightParams({{{{
1009
4
                                   V(1.3 * 23629.073922049845),
1010
4
                                   V(-1.025),
1011
4
                                   V(-0.78),
1012
4
                                   V(-0.65012),
1013
4
                                   V(-0.19041574084286472),
1014
4
                                   V(-0.20819395464),
1015
4
                                   V(-0.421064),
1016
4
                                   V(-0.32733845535848671),
1017
4
                               }},
1018
4
                               {{
1019
4
                                   V(1.3 * 8611.3238710010046),
1020
4
                                   V(-0.3041958212306401),
1021
4
                                   V(-0.3633036457487539),
1022
4
                                   V(-0.35660379990111464),
1023
4
                                   V(-0.3443074455424403),
1024
4
                                   V(-0.33699592683512467),
1025
4
                                   V(-0.30180866526242109),
1026
4
                                   V(-0.27321683125358037),
1027
4
                               }},
1028
4
                               {{
1029
4
                                   V(1.3 * 4492.2486445538634),
1030
4
                                   V(-1.2),
1031
4
                                   V(-1.2),
1032
4
                                   V(-0.8),
1033
4
                                   V(-0.7),
1034
4
                                   V(-0.7),
1035
4
                                   V(-0.4),
1036
4
                                   V(-0.5),
1037
4
                               }}}},
1038
4
                             8));
1039
4
  }
1040
  // DCT256X256: library default encoding, built with
  // QuantEncodingInternal::DCT from three rows of 8 values and the count 8.
  // Same values as DCT64X64() except that the first value of each row uses the
  // factor 3.6 instead of 0.9.
1041
4
  static QuantEncodingInternal DCT256X256() {
1042
4
    return QuantEncodingInternal::DCT(
1043
4
        DctQuantWeightParams({{{{
1044
4
                                   V(3.6 * 26629.073922049845),
1045
4
                                   V(-1.025),
1046
4
                                   V(-0.78),
1047
4
                                   V(-0.65012),
1048
4
                                   V(-0.19041574084286472),
1049
4
                                   V(-0.20819395464),
1050
4
                                   V(-0.421064),
1051
4
                                   V(-0.32733845535848671),
1052
4
                               }},
1053
4
                               {{
1054
4
                                   V(3.6 * 9311.3238710010046),
1055
4
                                   V(-0.3041958212306401),
1056
4
                                   V(-0.3633036457487539),
1057
4
                                   V(-0.35660379990111464),
1058
4
                                   V(-0.3443074455424403),
1059
4
                                   V(-0.33699592683512467),
1060
4
                                   V(-0.30180866526242109),
1061
4
                                   V(-0.27321683125358037),
1062
4
                               }},
1063
4
                               {{
1064
4
                                   V(3.6 * 4992.2486445538634),
1065
4
                                   V(-1.2),
1066
4
                                   V(-1.2),
1067
4
                                   V(-0.8),
1068
4
                                   V(-0.7),
1069
4
                                   V(-0.7),
1070
4
                                   V(-0.4),
1071
4
                                   V(-0.5),
1072
4
                               }}}},
1073
4
                             8));
1074
4
  }
1075
1076
  // DCT256X128 (the accessor is named DCT128X256; presumably one table serves
  // both orientations - confirm).
  // Built with QuantEncodingInternal::DCT from three rows of 8 values and the
  // count 8. Same values as DCT32X64() except that the first value of each row
  // uses the factor 2.6 instead of 0.65.
1077
4
  static QuantEncodingInternal DCT128X256() {
1078
4
    return QuantEncodingInternal::DCT(
1079
4
        DctQuantWeightParams({{{{
1080
4
                                   V(2.6 * 23629.073922049845),
1081
4
                                   V(-1.025),
1082
4
                                   V(-0.78),
1083
4
                                   V(-0.65012),
1084
4
                                   V(-0.19041574084286472),
1085
4
                                   V(-0.20819395464),
1086
4
                                   V(-0.421064),
1087
4
                                   V(-0.32733845535848671),
1088
4
                               }},
1089
4
                               {{
1090
4
                                   V(2.6 * 8611.3238710010046),
1091
4
                                   V(-0.3041958212306401),
1092
4
                                   V(-0.3633036457487539),
1093
4
                                   V(-0.35660379990111464),
1094
4
                                   V(-0.3443074455424403),
1095
4
                                   V(-0.33699592683512467),
1096
4
                                   V(-0.30180866526242109),
1097
4
                                   V(-0.27321683125358037),
1098
4
                               }},
1099
4
                               {{
1100
4
                                   V(2.6 * 4492.2486445538634),
1101
4
                                   V(-1.2),
1102
4
                                   V(-1.2),
1103
4
                                   V(-0.8),
1104
4
                                   V(-0.7),
1105
4
                                   V(-0.7),
1106
4
                                   V(-0.4),
1107
4
                                   V(-0.5),
1108
4
                               }}}},
1109
4
                             8));
1110
4
  }
1111
};
1112
}  // namespace
1113
1114
4
// Builds the library of default encodings, one entry per QuantTable value.
// The entries are listed in the numeric order of the QuantTable enumerators;
// the static_asserts below fail the build if that order or the table count
// ever changes without this function being updated.
DequantMatrices::DequantLibraryInternal DequantMatrices::LibraryInit() {
  using Def = DequantMatricesLibraryDef;

  static_assert(kNumQuantTables == 17,
                "Update this function when adding new quantization kinds.");
  static_assert(kNumPredefinedTables == 1,
                "Update this function when adding new quantization matrices to "
                "the library.");

  // The library and the indices need to be kept in sync manually.
  static_assert(static_cast<uint8_t>(QuantTable::DCT) == 0,
                "Update the DequantLibrary array below.");
  static_assert(static_cast<uint8_t>(QuantTable::IDENTITY) == 1,
                "Update the DequantLibrary array below.");
  static_assert(static_cast<uint8_t>(QuantTable::DCT2X2) == 2,
                "Update the DequantLibrary array below.");
  static_assert(static_cast<uint8_t>(QuantTable::DCT4X4) == 3,
                "Update the DequantLibrary array below.");
  static_assert(static_cast<uint8_t>(QuantTable::DCT16X16) == 4,
                "Update the DequantLibrary array below.");
  static_assert(static_cast<uint8_t>(QuantTable::DCT32X32) == 5,
                "Update the DequantLibrary array below.");
  static_assert(static_cast<uint8_t>(QuantTable::DCT8X16) == 6,
                "Update the DequantLibrary array below.");
  static_assert(static_cast<uint8_t>(QuantTable::DCT8X32) == 7,
                "Update the DequantLibrary array below.");
  static_assert(static_cast<uint8_t>(QuantTable::DCT16X32) == 8,
                "Update the DequantLibrary array below.");
  static_assert(static_cast<uint8_t>(QuantTable::DCT4X8) == 9,
                "Update the DequantLibrary array below.");
  static_assert(static_cast<uint8_t>(QuantTable::AFV0) == 10,
                "Update the DequantLibrary array below.");
  static_assert(static_cast<uint8_t>(QuantTable::DCT64X64) == 11,
                "Update the DequantLibrary array below.");
  static_assert(static_cast<uint8_t>(QuantTable::DCT32X64) == 12,
                "Update the DequantLibrary array below.");
  static_assert(static_cast<uint8_t>(QuantTable::DCT128X128) == 13,
                "Update the DequantLibrary array below.");
  static_assert(static_cast<uint8_t>(QuantTable::DCT64X128) == 14,
                "Update the DequantLibrary array below.");
  static_assert(static_cast<uint8_t>(QuantTable::DCT256X256) == 15,
                "Update the DequantLibrary array below.");
  static_assert(static_cast<uint8_t>(QuantTable::DCT128X256) == 16,
                "Update the DequantLibrary array below.");

  return DequantMatrices::DequantLibraryInternal{{
      Def::DCT(),         // 0
      Def::IDENTITY(),    // 1
      Def::DCT2X2(),      // 2
      Def::DCT4X4(),      // 3
      Def::DCT16X16(),    // 4
      Def::DCT32X32(),    // 5
      Def::DCT8X16(),     // 6
      Def::DCT8X32(),     // 7
      Def::DCT16X32(),    // 8
      Def::DCT4X8(),      // 9
      Def::AFV0(),        // 10
      Def::DCT64X64(),    // 11
      Def::DCT32X64(),    // 12
      // Same default for large transforms (128+) as for 64x* transforms.
      Def::DCT128X128(),  // 13
      Def::DCT64X128(),   // 14
      Def::DCT256X256(),  // 15
      Def::DCT128X256(),  // 16
  }};
}
1177
1178
18.8k
// Returns a pointer to the first entry of the default-encoding library.
// The library is a function-local static built by LibraryInit() the first
// time this function runs; later calls return the same storage.
const QuantEncoding* DequantMatrices::Library() {
  static const DequantMatrices::DequantLibraryInternal kDequantLibrary =
      DequantMatrices::LibraryInit();
  // The stored elements are QuantEncodingInternal. They are handed out as
  // QuantEncoding because, per the original note here, QuantEncoding is a
  // subclass that adds no members, and it is the type that can also produce
  // std::vector-backed QuantEncoding::RAW() instances. Callers that need the
  // members go through QuantEncodingInternal.
  const auto* first_entry = kDequantLibrary.data();
  return reinterpret_cast<const QuantEncoding*>(first_entry);
}
1189
1190
256k
// Fills encodings_ with kNumQuantTables copies of QuantEncoding::Library<0>()
// and precomputes table_offsets_, the start offset of each
// (AC strategy, channel) plane.
DequantMatrices::DequantMatrices() {
  encodings_.resize(kNumQuantTables, QuantEncoding::Library<0>());

  // Start of every (quant table, channel) plane. Planes are laid out back to
  // back: three equally sized channels per table, tables in index order.
  size_t plane_start[kNumQuantTables * 3];
  size_t cursor = 0;
  for (size_t kind = 0; kind < static_cast<size_t>(kNumQuantTables); ++kind) {
    const size_t block_count =
        static_cast<size_t>(required_size_x[kind]) * required_size_y[kind];
    const size_t plane_len = block_count * kDCTBlockSize;
    for (size_t c = 0; c < 3; ++c) {
      plane_start[3 * kind + c] = cursor;
      cursor += plane_len;
    }
  }

  // Each AC strategy reuses the planes of the quant table it maps to.
  for (size_t acs = 0; acs < AcStrategy::kNumValidStrategies; ++acs) {
    const size_t kind = static_cast<size_t>(kAcStrategyToQuantTableMap[acs]);
    for (size_t c = 0; c < 3; ++c) {
      table_offsets_[3 * acs + c] = plane_start[3 * kind + c];
    }
  }
}
1210
1211
// Computes the quantization tables needed by the AC strategies selected in
// `acs_mask` (bit i <=> strategy i) that have not been computed yet.
//
// On first use this allocates one buffer through `memory_manager` holding
// 2 * kTotalTableSize floats; table_ points at the first half and inv_table_
// at the second. Returns an error if the allocation, a table computation or
// one of the size consistency checks fails; in that case computed_mask_ is
// left unchanged.
Status DequantMatrices::EnsureComputed(JxlMemoryManager* memory_manager,
                                       uint32_t acs_mask) {
  const QuantEncoding* library = Library();

  if (!table_storage_) {
    const size_t table_storage_bytes = 2 * kTotalTableSize * sizeof(float);
    JXL_ASSIGN_OR_RETURN(
        table_storage_,
        AlignedMemory::Create(memory_manager, table_storage_bytes));
    table_ = table_storage_.address<float>();
    inv_table_ = table_ + kTotalTableSize;
  }

  // Start of every (quant table, channel) plane, plus one trailing entry
  // holding the total size so that entry [3 * t + 3] is always the end of
  // table t.
  size_t plane_start[kNumQuantTables * 3 + 1];
  size_t cursor = 0;
  for (size_t kind = 0; kind < kNumQuantTables; ++kind) {
    const size_t block_count =
        static_cast<size_t>(required_size_x[kind]) * required_size_y[kind];
    const size_t plane_len = block_count * kDCTBlockSize;
    for (size_t c = 0; c < 3; ++c) {
      plane_start[3 * kind + c] = cursor;
      cursor += plane_len;
    }
  }
  plane_start[kNumQuantTables * 3] = cursor;
  JXL_ENSURE(cursor == kTotalTableSize);

  // Translate both strategy masks into quant-table masks: the tables being
  // requested now and the tables covered by earlier calls.
  uint32_t requested_kinds = 0;
  uint32_t finished_kinds = 0;
  for (size_t acs = 0; acs < AcStrategy::kNumValidStrategies; ++acs) {
    const uint32_t acs_bit = 1u << acs;
    const uint32_t kind_bit =
        1u << static_cast<uint32_t>(kAcStrategyToQuantTableMap[acs]);
    if (acs_mask & acs_bit) requested_kinds |= kind_bit;
    if (computed_mask_ & acs_bit) finished_kinds |= kind_bit;
  }

  for (size_t table = 0; table < kNumQuantTables; ++table) {
    const uint32_t table_bit = 1u << table;
    if (finished_kinds & table_bit) continue;      // already available
    if (!(requested_kinds & table_bit)) continue;  // not needed this time

    size_t offset = plane_start[table * 3];
    float* mutable_table = table_storage_.address<float>();

    // Tables still in library mode are computed from the shared default
    // encoding; all others use the encoding stored in this object.
    const QuantEncoding* source = &encodings_[table];
    if (encodings_[table].mode == QuantEncoding::kQuantModeLibrary) {
      source = &library[table];
    }
    JXL_RETURN_IF_ERROR(HWY_DYNAMIC_DISPATCH(ComputeQuantTable)(
        *source, mutable_table, mutable_table + kTotalTableSize, table,
        QuantTable(table), &offset));

    // The computation must have advanced exactly to the end of this table.
    JXL_ENSURE(offset == plane_start[table * 3 + 3]);
  }
  computed_mask_ |= acs_mask;

  return true;
}
1271
1272
}  // namespace jxl
1273
#endif