Coverage Report

Created: 2025-06-22 08:04

/src/libjxl/lib/jxl/compressed_dc.cc
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#include "lib/jxl/compressed_dc.h"
7
8
#include <jxl/memory_manager.h>
9
10
#include <algorithm>
11
#include <cstdint>
12
#include <cstdlib>
13
#include <cstring>
14
#include <vector>
15
16
#undef HWY_TARGET_INCLUDE
17
#define HWY_TARGET_INCLUDE "lib/jxl/compressed_dc.cc"
18
#include <hwy/foreach_target.h>
19
#include <hwy/highway.h>
20
21
#include "lib/jxl/base/compiler_specific.h"
22
#include "lib/jxl/base/data_parallel.h"
23
#include "lib/jxl/base/rect.h"
24
#include "lib/jxl/base/status.h"
25
#include "lib/jxl/image.h"
26
HWY_BEFORE_NAMESPACE();
27
namespace jxl {
28
namespace HWY_NAMESPACE {
29
30
using D = HWY_FULL(float);
31
using DScalar = HWY_CAPPED(float, 1);
32
33
// These templates are not found via ADL.
34
using hwy::HWY_NAMESPACE::Abs;
35
using hwy::HWY_NAMESPACE::Add;
36
using hwy::HWY_NAMESPACE::Div;
37
using hwy::HWY_NAMESPACE::Max;
38
using hwy::HWY_NAMESPACE::Mul;
39
using hwy::HWY_NAMESPACE::MulAdd;
40
using hwy::HWY_NAMESPACE::Rebind;
41
using hwy::HWY_NAMESPACE::Sub;
42
using hwy::HWY_NAMESPACE::Vec;
43
using hwy::HWY_NAMESPACE::ZeroIfNegative;
44
45
// TODO(veluca): optimize constants.
46
// Weights of the 3x3 smoothing kernel used by ComputePixelChannel: w1 is
// applied to the four edge-adjacent neighbours (left, right, top, bottom).
const float w1 = 0.20345139757231578f;
47
// w2 is applied to the four diagonal (corner) neighbours.
const float w2 = 0.0334829185968739f;
48
// Center weight, defined so that w0 + 4 * w1 + 4 * w2 == 1.
const float w0 = 1.0f - 4.0f * (w1 + w2);
49
50
// Returns the lane-wise maximum of `a` and `b`.
// On the AVX3 target with HWY_COMPILER_CLANG <= 800 the maximum is computed
// with a compare + IfThenElse instead of Max, to avoid the compiler error
// quoted below; every other configuration simply forwards to Max.
template <class V>
51
461k
V MaxWorkaround(V a, V b) {
52
#if (HWY_TARGET == HWY_AVX3) && HWY_COMPILER_CLANG <= 800
53
  // Prevents "Do not know how to split the result of this operator" error
54
  return IfThenElse(a > b, a, b);
55
#else
56
461k
  return Max(a, b);
57
461k
#endif
58
461k
}
59
60
// Computes the 3x3 weighted average for one channel at column(s) starting at
// `x`, and accumulates how far that average is from the original value.
//
//   d          - vector descriptor selecting how many lanes are processed.
//   dc_factor  - divisor applied to |center - smoothed| before it is folded
//                into `gap`.
//   row_top, row, row_bottom - the rows above, at and below the current row;
//                columns x - 1 .. x + Lanes - 1 + 1 are read from each.
//   mc  (out)  - receives the unmodified center values (row[x...]).
//   sm  (out)  - receives w0 * center + w1 * sides + w2 * corners.
//   gap (in/out) - updated to max(gap, |mc - sm| / dc_factor), so the caller
//                can share a single running maximum across channels.
//
// NOTE(review): the center column is read with Load while the x - 1 / x + 1
// columns use LoadU; for multi-lane `d` this presumably relies on `row + x`
// being vector-aligned - confirm against the caller's choice of `x`.
template <typename D>
61
JXL_INLINE void ComputePixelChannel(const D d, const float dc_factor,
62
                                    const float* JXL_RESTRICT row_top,
63
                                    const float* JXL_RESTRICT row,
64
                                    const float* JXL_RESTRICT row_bottom,
65
                                    Vec<D>* JXL_RESTRICT mc,
66
                                    Vec<D>* JXL_RESTRICT sm,
67
462k
                                    Vec<D>* JXL_RESTRICT gap, size_t x) {
68
462k
  // Top row: left, center, right neighbours.
  const auto tl = LoadU(d, row_top + x - 1);
69
462k
  const auto tc = Load(d, row_top + x);
70
462k
  const auto tr = LoadU(d, row_top + x + 1);
71
72
462k
  // Middle row: the center value is written straight to the output *mc.
  const auto ml = LoadU(d, row + x - 1);
73
462k
  *mc = Load(d, row + x);
74
462k
  const auto mr = LoadU(d, row + x + 1);
75
76
462k
  // Bottom row: left, center, right neighbours.
  const auto bl = LoadU(d, row_bottom + x - 1);
77
462k
  const auto bc = Load(d, row_bottom + x);
78
462k
  const auto br = LoadU(d, row_bottom + x + 1);
79
80
462k
  const auto w_center = Set(d, w0);
81
462k
  const auto w_side = Set(d, w1);
82
462k
  const auto w_corner = Set(d, w2);
83
84
462k
  // Sum neighbours that share a weight before multiplying.
  const auto corner = Add(Add(tl, tr), Add(bl, br));
85
462k
  const auto side = Add(Add(ml, mr), Add(tc, bc));
86
462k
  // sm = corner * w2 + (side * w1 + center * w0)
  *sm = MulAdd(corner, w_corner, MulAdd(side, w_side, Mul(*mc, w_center)));
87
88
462k
  const auto dc_quant = Set(d, dc_factor);
89
462k
  // Keep the largest normalized deviation seen so far.
  *gap = MaxWorkaround(*gap, Abs(Div(Sub(*mc, *sm), dc_quant)));
90
462k
}
91
92
// Smooths all three channels at column(s) starting at `x` and writes the
// result to `out_rows`.
//
// Each channel's 3x3 average and center value are obtained from
// ComputePixelChannel; the three calls share one `gap` accumulator that starts
// at 0.5, so after the calls gap = max(0.5, largest per-channel
// |center - smoothed| / dc_factors[c]). The blend factor is
// max(0, 3 - 4 * gap): it is 1 when gap == 0.5 and reaches 0 when
// gap >= 0.75. Every channel is then written as
// center + (smoothed - center) * factor.
//
//   dc_factors  - three per-channel divisors, indexed like the row arrays.
//   rows_top, rows, rows_bottom - per-channel pointers to the rows above, at
//                 and below the current row.
//   out_rows    - per-channel destination rows; Lanes(D) values are stored at
//                 offset `x` in each.
template <typename D>
93
JXL_INLINE void ComputePixel(
94
    const float* JXL_RESTRICT dc_factors,
95
    const float* JXL_RESTRICT* JXL_RESTRICT rows_top,
96
    const float* JXL_RESTRICT* JXL_RESTRICT rows,
97
    const float* JXL_RESTRICT* JXL_RESTRICT rows_bottom,
98
289k
    float* JXL_RESTRICT* JXL_RESTRICT out_rows, size_t x) {
99
289k
  const D d;
100
289k
  // mc_* / sm_* are pure outputs of ComputePixelChannel, hence Undefined.
  auto mc_x = Undefined(d);
101
289k
  auto mc_y = Undefined(d);
102
289k
  auto mc_b = Undefined(d);
103
289k
  auto sm_x = Undefined(d);
104
289k
  auto sm_y = Undefined(d);
105
289k
  auto sm_b = Undefined(d);
106
289k
  // Lower bound of the shared deviation; caps the blend factor below at 1.
  auto gap = Set(d, 0.5f);
107
289k
  ComputePixelChannel(d, dc_factors[0], rows_top[0], rows[0], rows_bottom[0],
108
289k
                      &mc_x, &sm_x, &gap, x);
109
289k
  ComputePixelChannel(d, dc_factors[1], rows_top[1], rows[1], rows_bottom[1],
110
289k
                      &mc_y, &sm_y, &gap, x);
111
289k
  ComputePixelChannel(d, dc_factors[2], rows_top[2], rows[2], rows_bottom[2],
112
289k
                      &mc_b, &sm_b, &gap, x);
113
289k
  // factor = 3 - 4 * gap, clamped to be non-negative.
  auto factor = MulAdd(Set(d, -4.0f), gap, Set(d, 3.0f));
114
289k
  factor = ZeroIfNegative(factor);
115
116
289k
  // Same factor for all channels: out = mc + (sm - mc) * factor.
  auto out = MulAdd(Sub(sm_x, mc_x), factor, mc_x);
117
289k
  Store(out, d, out_rows[0] + x);
118
289k
  out = MulAdd(Sub(sm_y, mc_y), factor, mc_y);
119
289k
  Store(out, d, out_rows[1] + x);
120
289k
  out = MulAdd(Sub(sm_b, mc_b), factor, mc_b);
121
289k
  Store(out, d, out_rows[2] + x);
122
289k
}
123
124
// Per-target implementation: replaces `*dc` with an adaptively smoothed copy.
//
// Images with xsize <= 2 or ysize <= 2 are returned unchanged. Otherwise a
// new Image3F of the same size is created through `memory_manager`, border
// rows and columns are copied verbatim from `*dc`, every interior pixel is
// computed by ComputePixel, and finally the new image is swapped into `*dc`.
//
//   memory_manager - passed to Image3F::Create for the temporary image.
//   dc_factors     - three per-channel factors forwarded to ComputePixel.
//   dc             - image to smooth; read during processing, swapped with the
//                    result at the end.
//   pool           - passed to RunOnPool to process rows.
//
// Returns true on success; errors from JXL_ENSURE, Image3F::Create or
// RunOnPool are propagated by the respective macros.
Status AdaptiveDCSmoothing(JxlMemoryManager* memory_manager,
125
                           const float* dc_factors, Image3F* dc,
126
9.97k
                           ThreadPool* pool) {
127
9.97k
  const size_t xsize = dc->xsize();
128
9.97k
  const size_t ysize = dc->ysize();
129
9.97k
  // No interior pixels exist, so there is nothing to smooth.
  if (ysize <= 2 || xsize <= 2) return true;
130
131
  // TODO(veluca): use tile-based processing?
132
  // TODO(veluca): decide if changes to the y channel should be propagated to
133
  // the x and b channels through color correlation.
134
992
  // Equivalent to requiring the center weight w0 = 1 - 4 * (w1 + w2) > 0.
  JXL_ENSURE(w1 + w2 < 0.25f);
135
136
1.98k
  JXL_ASSIGN_OR_RETURN(Image3F smoothed,
137
1.98k
                       Image3F::Create(memory_manager, xsize, ysize));
138
  // Fill in borders that the loop below will not. First and last are unused.
139
3.96k
  for (size_t c = 0; c < 3; c++) {
140
5.95k
    for (size_t y : {static_cast<size_t>(0), ysize - 1}) {
141
5.95k
      memcpy(smoothed.PlaneRow(c, y), dc->PlaneRow(c, y),
142
5.95k
             xsize * sizeof(float));
143
5.95k
    }
144
2.97k
  }
145
28.1k
  // Smooths one row `y`; reads rows y - 1, y and y + 1 of `*dc` and writes
  // row y of `smoothed`. NOTE(review): presumably invoked for y in
  // [1, ysize - 1) given the RunOnPool arguments below - confirm.
  auto process_row = [&](const uint32_t y, size_t /*thread*/) -> Status {
146
28.1k
    const float* JXL_RESTRICT rows_top[3]{
147
28.1k
        dc->ConstPlaneRow(0, y - 1),
148
28.1k
        dc->ConstPlaneRow(1, y - 1),
149
28.1k
        dc->ConstPlaneRow(2, y - 1),
150
28.1k
    };
151
28.1k
    const float* JXL_RESTRICT rows[3] = {
152
28.1k
        dc->ConstPlaneRow(0, y),
153
28.1k
        dc->ConstPlaneRow(1, y),
154
28.1k
        dc->ConstPlaneRow(2, y),
155
28.1k
    };
156
28.1k
    const float* JXL_RESTRICT rows_bottom[3] = {
157
28.1k
        dc->ConstPlaneRow(0, y + 1),
158
28.1k
        dc->ConstPlaneRow(1, y + 1),
159
28.1k
        dc->ConstPlaneRow(2, y + 1),
160
28.1k
    };
161
28.1k
    float* JXL_RESTRICT rows_out[3] = {
162
28.1k
        smoothed.PlaneRow(0, y),
163
28.1k
        smoothed.PlaneRow(1, y),
164
28.1k
        smoothed.PlaneRow(2, y),
165
28.1k
    };
166
56.1k
    // The first and last columns have no full 3x3 neighbourhood: copy as-is.
    for (size_t x : {static_cast<size_t>(0), xsize - 1}) {
167
224k
      for (size_t c = 0; c < 3; c++) {
168
168k
        rows_out[c][x] = rows[c][x];
169
168k
      }
170
56.1k
    }
171
172
28.1k
    size_t x = 1;
173
    // First pixels
174
28.1k
    const size_t N = Lanes(D());
175
28.1k
    // Single-lane processing until x reaches N (or the last interior column),
    // so the full-vector loop below starts at a multiple of N.
    for (; x < std::min(N, xsize - 1); x++) {
176
0
      ComputePixel<DScalar>(dc_factors, rows_top, rows, rows_bottom, rows_out,
177
0
                            x);
178
0
    }
179
    // Full vectors.
180
314k
    for (; x + N <= xsize - 1; x += N) {
181
285k
      ComputePixel<D>(dc_factors, rows_top, rows, rows_bottom, rows_out, x);
182
285k
    }
183
    // Last pixels.
184
28.1k
    for (; x < xsize - 1; x++) {
185
0
      ComputePixel<DScalar>(dc_factors, rows_top, rows, rows_bottom, rows_out,
186
0
                            x);
187
0
    }
188
28.1k
    return true;
189
28.1k
  };
190
1.98k
  JXL_RETURN_IF_ERROR(RunOnPool(pool, 1, ysize - 1, ThreadPool::NoInit,
191
1.98k
                                process_row, "DCSmoothingRow"));
192
992
  // Publish the result; the previous contents end up in `smoothed`.
  dc->Swap(smoothed);
193
992
  return true;
194
1.98k
}
195
196
// DC dequantization.
197
// Per-target implementation: converts quantized integer DC values from `in`
// into float DC values in `*dc`, and fills `*quant_dc` with a per-pixel
// DC context bucket.
//
//   r          - region of `*dc` / `*quant_dc` that is written.
//   dc         - output float image (channels 0, 1, 2 named x, y, b here).
//   quant_dc   - output byte image receiving the bucket index.
//   in         - quantized input; channel[1] is read as x, channel[0] as y
//                and channel[2] as b. Input rows are addressed with Row(y),
//                i.e. not through `r`.
//   dc_factors - per-channel scale, multiplied by `mul`.
//   cfl_factors - [0] and [2] scale the dequantized y value that is added to
//                x and b respectively (4:4:4 path only).
//   chroma_subsampling - selects the 4:4:4 path or the per-channel shifted
//                path, and provides HShift / VShift for bucket lookup.
//   bctx       - provides num_dc_ctxs and the per-channel dc_thresholds.
//
// NOTE(review): the vector loops advance by Lanes(di) and always load/store
// whole vectors, so rows are presumably padded to a multiple of the vector
// size - confirm against the image allocation code.
void DequantDC(const Rect& r, Image3F* dc, ImageB* quant_dc, const Image& in,
198
               const float* dc_factors, float mul, const float* cfl_factors,
199
               const YCbCrChromaSubsampling& chroma_subsampling,
200
11.6k
               const BlockCtxMap& bctx) {
201
11.6k
  const HWY_FULL(float) df;
202
11.6k
  const Rebind<pixel_type, HWY_FULL(float)> di;  // assumes pixel_type <= float
203
11.6k
  if (chroma_subsampling.Is444()) {
204
11.4k
    // All channels share the same geometry: process them together so the
    // y value can be mixed into x and b.
    const auto fac_x = Set(df, dc_factors[0] * mul);
205
11.4k
    const auto fac_y = Set(df, dc_factors[1] * mul);
206
11.4k
    const auto fac_b = Set(df, dc_factors[2] * mul);
207
11.4k
    const auto cfl_fac_x = Set(df, cfl_factors[0]);
208
11.4k
    const auto cfl_fac_b = Set(df, cfl_factors[2]);
209
76.1k
    for (size_t y = 0; y < r.ysize(); y++) {
210
64.7k
      float* dec_row_x = r.PlaneRow(dc, 0, y);
211
64.7k
      float* dec_row_y = r.PlaneRow(dc, 1, y);
212
64.7k
      float* dec_row_b = r.PlaneRow(dc, 2, y);
213
64.7k
      // Note the swapped indices: input channel 1 is x, channel 0 is y.
      const int32_t* quant_row_x = in.channel[1].plane.Row(y);
214
64.7k
      const int32_t* quant_row_y = in.channel[0].plane.Row(y);
215
64.7k
      const int32_t* quant_row_b = in.channel[2].plane.Row(y);
216
1.32M
      for (size_t x = 0; x < r.xsize(); x += Lanes(di)) {
217
1.25M
        const auto in_q_x = Load(di, quant_row_x + x);
218
1.25M
        const auto in_q_y = Load(di, quant_row_y + x);
219
1.25M
        const auto in_q_b = Load(di, quant_row_b + x);
220
1.25M
        const auto in_x = Mul(ConvertTo(df, in_q_x), fac_x);
221
1.25M
        const auto in_y = Mul(ConvertTo(df, in_q_y), fac_y);
222
1.25M
        const auto in_b = Mul(ConvertTo(df, in_q_b), fac_b);
223
1.25M
        // y is stored as-is; x and b get y * cfl_factor added.
        Store(in_y, df, dec_row_y + x);
224
1.25M
        Store(MulAdd(in_y, cfl_fac_x, in_x), df, dec_row_x + x);
225
1.25M
        Store(MulAdd(in_y, cfl_fac_b, in_b), df, dec_row_b + x);
226
1.25M
      }
227
64.7k
    }
228
11.4k
  } else {
229
537
    // Subsampled case: each channel is scaled independently over a rect
    // shifted by that channel's HShift / VShift; no cfl term is applied.
    for (size_t c : {1, 0, 2}) {
230
537
      Rect rect(r.x0() >> chroma_subsampling.HShift(c),
231
537
                r.y0() >> chroma_subsampling.VShift(c),
232
537
                r.xsize() >> chroma_subsampling.HShift(c),
233
537
                r.ysize() >> chroma_subsampling.VShift(c));
234
537
      const auto fac = Set(df, dc_factors[c] * mul);
235
537
      // c ^ 1 swaps 0 and 1, matching the x/y channel order used above.
      const Channel& ch = in.channel[c < 2 ? c ^ 1 : c];
236
5.02k
      for (size_t y = 0; y < rect.ysize(); y++) {
237
4.48k
        const int32_t* quant_row = ch.plane.Row(y);
238
4.48k
        float* row = rect.PlaneRow(dc, c, y);
239
48.1k
        for (size_t x = 0; x < rect.xsize(); x += Lanes(di)) {
240
43.6k
          const auto in_q = Load(di, quant_row + x);
241
43.6k
          const auto in = Mul(ConvertTo(df, in_q), fac);
242
43.6k
          Store(in, df, row + x);
243
43.6k
        }
244
4.48k
      }
245
537
    }
246
179
  }
247
11.6k
  if (bctx.num_dc_ctxs <= 1) {
248
64.0k
    // At most one DC context: every pixel gets bucket 0.
    for (size_t y = 0; y < r.ysize(); y++) {
249
53.1k
      uint8_t* qdc_row = r.Row(quant_dc, y);
250
53.1k
      memset(qdc_row, 0, sizeof(*qdc_row) * r.xsize());
251
53.1k
    }
252
10.9k
  } else {
253
14.2k
    for (size_t y = 0; y < r.ysize(); y++) {
254
13.5k
      uint8_t* qdc_row_val = r.Row(quant_dc, y);
255
13.5k
      // Input rows are looked up at the subsampled row index of each channel.
      const int32_t* quant_row_x =
256
13.5k
          in.channel[1].plane.Row(y >> chroma_subsampling.VShift(0));
257
13.5k
      const int32_t* quant_row_y =
258
13.5k
          in.channel[0].plane.Row(y >> chroma_subsampling.VShift(1));
259
13.5k
      const int32_t* quant_row_b =
260
13.5k
          in.channel[2].plane.Row(y >> chroma_subsampling.VShift(2));
261
436k
      for (size_t x = 0; x < r.xsize(); x++) {
262
422k
        // bucket_c = number of thresholds of channel c that the quantized
        // value strictly exceeds.
        int bucket_x = 0;
263
422k
        int bucket_y = 0;
264
422k
        int bucket_b = 0;
265
1.42M
        for (int t : bctx.dc_thresholds[0]) {
266
1.42M
          if (quant_row_x[x >> chroma_subsampling.HShift(0)] > t) bucket_x++;
267
1.42M
        }
268
1.12M
        for (int t : bctx.dc_thresholds[1]) {
269
1.12M
          if (quant_row_y[x >> chroma_subsampling.HShift(1)] > t) bucket_y++;
270
1.12M
        }
271
422k
        for (int t : bctx.dc_thresholds[2]) {
272
154k
          if (quant_row_b[x >> chroma_subsampling.HShift(2)] > t) bucket_b++;
273
154k
        }
274
422k
        // Mixed-radix combination:
        // (bucket_x * (|thr_b| + 1) + bucket_b) * (|thr_y| + 1) + bucket_y.
        int bucket = bucket_x;
275
422k
        bucket *= bctx.dc_thresholds[2].size() + 1;
276
422k
        bucket += bucket_b;
277
422k
        bucket *= bctx.dc_thresholds[1].size() + 1;
278
422k
        bucket += bucket_y;
279
422k
        // Stored into a uint8_t row, i.e. truncated to 8 bits.
        qdc_row_val[x] = bucket;
280
422k
      }
281
13.5k
    }
282
687
  }
283
11.6k
}
284
285
// NOLINTNEXTLINE(google-readability-namespace-comments)
286
}  // namespace HWY_NAMESPACE
287
}  // namespace jxl
288
HWY_AFTER_NAMESPACE();
289
290
#if HWY_ONCE
291
namespace jxl {
292
293
// Make the per-target versions of both functions available to the
// HWY_DYNAMIC_DISPATCH calls in the wrappers below.
HWY_EXPORT(DequantDC);
294
HWY_EXPORT(AdaptiveDCSmoothing);
295
// Public entry point: forwards all arguments unchanged to the
// HWY_NAMESPACE::AdaptiveDCSmoothing implementation chosen by
// HWY_DYNAMIC_DISPATCH and returns its Status.
Status AdaptiveDCSmoothing(JxlMemoryManager* memory_manager,
296
                           const float* dc_factors, Image3F* dc,
297
9.97k
                           ThreadPool* pool) {
298
9.97k
  return HWY_DYNAMIC_DISPATCH(AdaptiveDCSmoothing)(memory_manager, dc_factors,
299
9.97k
                                                   dc, pool);
300
9.97k
}
301
302
// Public entry point: forwards all arguments unchanged to the
// HWY_NAMESPACE::DequantDC implementation chosen by HWY_DYNAMIC_DISPATCH.
void DequantDC(const Rect& r, Image3F* dc, ImageB* quant_dc, const Image& in,
303
               const float* dc_factors, float mul, const float* cfl_factors,
304
               const YCbCrChromaSubsampling& chroma_subsampling,
305
11.6k
               const BlockCtxMap& bctx) {
306
11.6k
  HWY_DYNAMIC_DISPATCH(DequantDC)
307
11.6k
  (r, dc, quant_dc, in, dc_factors, mul, cfl_factors, chroma_subsampling, bctx);
308
11.6k
}
309
310
}  // namespace jxl
311
#endif  // HWY_ONCE