Coverage Report

Created: 2025-07-16 07:53

/src/libjxl/lib/jxl/enc_adaptive_quantization.cc
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#include "lib/jxl/enc_adaptive_quantization.h"
7
8
#include <jxl/memory_manager.h>
9
10
#include <algorithm>
#include <atomic>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <vector>
17
18
#include "lib/jxl/memory_manager_internal.h"
19
20
#undef HWY_TARGET_INCLUDE
21
#define HWY_TARGET_INCLUDE "lib/jxl/enc_adaptive_quantization.cc"
22
#include <hwy/foreach_target.h>
23
#include <hwy/highway.h>
24
25
#include "lib/jxl/ac_strategy.h"
26
#include "lib/jxl/base/common.h"
27
#include "lib/jxl/base/compiler_specific.h"
28
#include "lib/jxl/base/data_parallel.h"
29
#include "lib/jxl/base/fast_math-inl.h"
30
#include "lib/jxl/base/rect.h"
31
#include "lib/jxl/base/status.h"
32
#include "lib/jxl/butteraugli/butteraugli.h"
33
#include "lib/jxl/convolve.h"
34
#include "lib/jxl/dec_cache.h"
35
#include "lib/jxl/dec_group.h"
36
#include "lib/jxl/enc_aux_out.h"
37
#include "lib/jxl/enc_butteraugli_comparator.h"
38
#include "lib/jxl/enc_cache.h"
39
#include "lib/jxl/enc_debug_image.h"
40
#include "lib/jxl/enc_group.h"
41
#include "lib/jxl/enc_modular.h"
42
#include "lib/jxl/enc_params.h"
43
#include "lib/jxl/enc_transforms-inl.h"
44
#include "lib/jxl/epf.h"
45
#include "lib/jxl/frame_dimensions.h"
46
#include "lib/jxl/image.h"
47
#include "lib/jxl/image_bundle.h"
48
#include "lib/jxl/image_ops.h"
49
#include "lib/jxl/quant_weights.h"
50
51
// Set JXL_DEBUG_ADAPTIVE_QUANTIZATION to 1 to enable debugging.
52
#ifndef JXL_DEBUG_ADAPTIVE_QUANTIZATION
53
0
#define JXL_DEBUG_ADAPTIVE_QUANTIZATION 0
54
#endif
55
56
HWY_BEFORE_NAMESPACE();
57
namespace jxl {
58
namespace HWY_NAMESPACE {
59
namespace {
60
61
// These templates are not found via ADL.
62
using hwy::HWY_NAMESPACE::AbsDiff;
63
using hwy::HWY_NAMESPACE::Add;
64
using hwy::HWY_NAMESPACE::And;
65
using hwy::HWY_NAMESPACE::Gt;
66
using hwy::HWY_NAMESPACE::IfThenElseZero;
67
using hwy::HWY_NAMESPACE::Max;
68
using hwy::HWY_NAMESPACE::Min;
69
using hwy::HWY_NAMESPACE::Rebind;
70
using hwy::HWY_NAMESPACE::Sqrt;
71
using hwy::HWY_NAMESPACE::ZeroIfNegative;
72
73
// The following functions modulate an exponent (out_val) and return the updated
74
// value. Their descriptor is limited to 8 lanes for 8x8 blocks.
75
76
// Hack for mask estimation. Eventually replace this code with butteraugli's
// masking.
//
// Maps an eroded masking value to the mask used for AC strategy selection:
// 1 / (out_val + 0.001). The small offset keeps the result finite when
// out_val is exactly zero.
float ComputeMaskForAcStrategyUse(const float out_val) {
  constexpr float kNumerator = 1.0f;
  constexpr float kEpsilon = 0.001f;
  const float denominator = out_val + kEpsilon;
  return kNumerator / denominator;
}
83
84
// Lane-wise rational masking curve applied to the eroded masking value:
//   s = max(out_val * 0.8006..., 1e-3)
//   result = -0.7647 + c / (s^2 + b / 4) + a / (s + A) + b' / (s^2 + b)
// with the fitted constants below. The fused operations are evaluated in the
// same order as the reference formulation so results are bit-identical.
template <class D, class V>
V ComputeMask(const D d, const V out_val) {
  const auto base = Set(d, -0.7647f);
  const auto in_scale = Set(d, 0.80061762862741759f);
  const auto lin_mul = Set(d, 17.35036561631863f);
  const auto lin_off = Set(d, 302.59587815579727f);
  const auto sq_mul = Set(d, 6.7943250517376494f);
  const auto sq_off = Set(d, 3.7179635626140772f);
  const auto quarter_sq_mul = Set(d, 9.4708735624378946f);
  const auto quarter_sq_off = Mul(Set(d, 0.25f), sq_off);
  const auto one = Set(d, 1.0f);

  // Clamp the scaled input to at least 1e-3 before using it in the quotients.
  const auto s = Max(Mul(out_val, in_scale), Set(d, 1e-3f));
  const auto inv_lin = Div(one, Add(s, lin_off));
  const auto inv_sq = Div(one, MulAdd(s, s, sq_off));
  const auto inv_quarter_sq = Div(one, MulAdd(s, s, quarter_sq_off));

  // TODO(jyrki):
  // A log or two here could make sense. In butteraugli we have effectively
  // log(log(x + C)) for this kind of use, as a single log is used in
  // saturating visual masking and here the modulation values are exponential,
  // another log would counter that.
  const auto partial = MulAdd(lin_mul, inv_lin, Mul(sq_mul, inv_sq));
  return Add(base, MulAdd(quarter_sq_mul, inv_quarter_sq, partial));
}
108
109
// Scaling between jxl's opsin values and butteraugli's SimpleGamma:
// kSGmul scales the input, kSGmul2 scales the output.
constexpr float kSGmul = 226.77216153508914f;
constexpr float kSGmul2 = 1.0f / 73.377132366608819f;
// Natural logarithm of 2.
constexpr float kLog2 = 0.693147181f;
// SimpleGamma output multiplier; includes the correction factor for
// std::log -> log2.
constexpr float kSGRetMul = kSGmul2 * 18.6580932135f * kLog2;
// Offset added to the scaled input before taking the logarithm.
constexpr float kSGVOffset = 7.7825991679894591f;
116
117
template <bool invert, typename D, typename V>
118
0
V RatioOfDerivativesOfCubicRootToSimpleGamma(const D d, V v) {
119
  // The opsin space in jxl is the cubic root of photons, i.e., v * v * v
120
  // is related to the number of photons.
121
  //
122
  // SimpleGamma(v * v * v) is the psychovisual space in butteraugli.
123
  // This ratio allows quantization to move from jxl's opsin space to
124
  // butteraugli's log-gamma space.
125
0
  float kEpsilon = 1e-2;
126
0
  v = ZeroIfNegative(v);
127
0
  const auto kNumMul = Set(d, kSGRetMul * 3 * kSGmul);
128
0
  const auto kVOffset = Set(d, kSGVOffset * kLog2 + kEpsilon);
129
0
  const auto kDenMul = Set(d, kLog2 * kSGmul);
130
131
0
  const auto v2 = Mul(v, v);
132
133
0
  const auto num = MulAdd(kNumMul, v2, Set(d, kEpsilon));
134
0
  const auto den = MulAdd(Mul(kDenMul, v), v2, kVOffset);
135
0
  return invert ? Div(num, den) : Div(den, num);
136
0
}
Unexecuted instantiation: enc_adaptive_quantization.cc:hwy::N_SCALAR::Vec1<float> jxl::N_SCALAR::(anonymous namespace)::RatioOfDerivativesOfCubicRootToSimpleGamma<false, hwy::N_SCALAR::Simd<float, 1ul, 0>, hwy::N_SCALAR::Vec1<float> >(hwy::N_SCALAR::Simd<float, 1ul, 0>, hwy::N_SCALAR::Vec1<float>)
Unexecuted instantiation: enc_adaptive_quantization.cc:hwy::N_SCALAR::Vec1<float> jxl::N_SCALAR::(anonymous namespace)::RatioOfDerivativesOfCubicRootToSimpleGamma<true, hwy::N_SCALAR::Simd<float, 1ul, 0>, hwy::N_SCALAR::Vec1<float> >(hwy::N_SCALAR::Simd<float, 1ul, 0>, hwy::N_SCALAR::Vec1<float>)
137
138
template <bool invert = false>
139
0
float RatioOfDerivativesOfCubicRootToSimpleGamma(float v) {
140
0
  using DScalar = HWY_CAPPED(float, 1);
141
0
  auto vscalar = Load(DScalar(), &v);
142
0
  return GetLane(
143
0
      RatioOfDerivativesOfCubicRootToSimpleGamma<invert>(DScalar(), vscalar));
144
0
}
145
146
// TODO(veluca): this function computes an approximation of the derivative of
147
// SimpleGamma with (f(x+eps)-f(x))/eps. Consider two-sided approximation or
148
// exact derivatives. For reference, SimpleGamma was:
149
/*
150
template <typename D, typename V>
151
V SimpleGamma(const D d, V v) {
152
  // A simple HDR compatible gamma function.
153
  const auto mul = Set(d, kSGmul);
154
  const auto kRetMul = Set(d, kSGRetMul);
155
  const auto kRetAdd = Set(d, kSGmul2 * -20.2789020414f);
156
  const auto kVOffset = Set(d, kSGVOffset);
157
158
  v *= mul;
159
160
  // This should happen rarely, but may lead to a NaN, which is rather
161
  // undesirable. Since negative photons don't exist we solve the NaNs by
162
  // clamping here.
163
  // TODO(veluca): with FastLog2f, this no longer leads to NaNs.
164
  v = ZeroIfNegative(v);
165
  return kRetMul * FastLog2f(d, v + kVOffset) + kRetAdd;
166
}
167
*/
168
169
template <class D, class V>
170
V GammaModulation(const D d, const size_t x, const size_t y,
171
                  const ImageF& xyb_x, const ImageF& xyb_y, const Rect& rect,
172
0
                  const V out_val) {
173
0
  const float kBias = 0.16f;
174
0
  JXL_DASSERT(kBias > jxl::cms::kOpsinAbsorbanceBias[0]);
175
0
  JXL_DASSERT(kBias > jxl::cms::kOpsinAbsorbanceBias[1]);
176
0
  JXL_DASSERT(kBias > jxl::cms::kOpsinAbsorbanceBias[2]);
177
0
  auto overall_ratio = Zero(d);
178
0
  auto bias = Set(d, kBias);
179
0
  for (size_t dy = 0; dy < 8; ++dy) {
180
0
    const float* const JXL_RESTRICT row_in_x = rect.ConstRow(xyb_x, y + dy);
181
0
    const float* const JXL_RESTRICT row_in_y = rect.ConstRow(xyb_y, y + dy);
182
0
    for (size_t dx = 0; dx < 8; dx += Lanes(d)) {
183
0
      const auto iny = Add(Load(d, row_in_y + x + dx), bias);
184
0
      const auto inx = Load(d, row_in_x + x + dx);
185
186
0
      const auto r = Sub(iny, inx);
187
0
      const auto ratio_r =
188
0
          RatioOfDerivativesOfCubicRootToSimpleGamma</*invert=*/true>(d, r);
189
0
      overall_ratio = Add(overall_ratio, ratio_r);
190
191
0
      const auto g = Add(iny, inx);
192
0
      const auto ratio_g =
193
0
          RatioOfDerivativesOfCubicRootToSimpleGamma</*invert=*/true>(d, g);
194
0
      overall_ratio = Add(overall_ratio, ratio_g);
195
0
    }
196
0
  }
197
0
  overall_ratio = Mul(SumOfLanes(d, overall_ratio), Set(d, 0.5f / 64));
198
  // ideally -1.0, but likely optimal correction adds some entropy, so slightly
199
  // less than that.
200
0
  const auto kGamma = Set(d, 0.1005613337192697f);
201
0
  return MulAdd(kGamma, FastLog2f(d, overall_ratio), out_val);
202
0
}
203
204
// Change precision in 8x8 blocks that have significant amounts of blue
205
// content (but are not close to solid blue).
206
// This is based on the idea that M and L cone activations saturate the
207
// S (blue) receptors, and the S reception becomes more important when
208
// both M and L levels are low. In that case M and L receptors may be
209
// observing S-spectra instead and viewing them with higher spatial
210
// accuracy, justifying spending more bits here.
211
template <class D, class V>
212
V BlueModulation(const D d, const size_t x, const size_t y,
213
                 const ImageF& planex, const ImageF& planey,
214
0
                 const ImageF& planeb, const Rect& rect, const V out_val) {
215
0
  auto sum = Zero(d);
216
0
  static const float kLimit = 0.027121074570634722;
217
0
  static const float kOffset = 0.084381641171960495;
218
0
  for (size_t dy = 0; dy < 8; ++dy) {
219
0
    const float* JXL_RESTRICT row_in_x = rect.ConstRow(planex, y + dy) + x;
220
0
    const float* JXL_RESTRICT row_in_y = rect.ConstRow(planey, y + dy) + x;
221
0
    const float* JXL_RESTRICT row_in_b = rect.ConstRow(planeb, y + dy) + x;
222
0
    for (size_t dx = 0; dx < 8; dx += Lanes(d)) {
223
0
      const auto p_x = Load(d, row_in_x + dx);
224
0
      const auto p_b = Load(d, row_in_b + dx);
225
0
      const auto p_y_raw = Add(Load(d, row_in_y + dx), Set(d, kOffset));
226
0
      const auto p_y_effective = Add(p_y_raw, Abs(p_x));
227
0
      sum = Add(sum,
228
0
                IfThenElseZero(Gt(p_b, p_y_effective),
229
0
                               Min(Sub(p_b, p_y_effective), Set(d, kLimit))));
230
0
    }
231
0
  }
232
0
  static const float kMul = 0.14207000358439159;
233
0
  sum = SumOfLanes(d, sum);
234
0
  float scalar_sum = GetLane(sum);
235
  // If it is all blue, don't boost the quantization.
236
  // All blue likely means low frequency blue. Let's not make the most
237
  // perfect sky ever.
238
0
  if (scalar_sum >= 32 * kLimit) {
239
0
    scalar_sum = 64 * kLimit - scalar_sum;
240
0
  }
241
0
  static const float kMaxLimit = 15.398788439047934f;
242
0
  if (scalar_sum >= kMaxLimit * kLimit) {
243
0
    scalar_sum = kMaxLimit * kLimit;
244
0
  }
245
0
  scalar_sum *= kMul;
246
0
  return Add(Set(d, scalar_sum), out_val);
247
0
}
248
249
// Change precision in 8x8 blocks that have high frequency content.
//
// Sums, over the 8x8 block at (x, y) of `rect`, the absolute differences of
// xyb_y between each pixel and its right neighbour and its lower neighbour,
// each difference capped at valmin_y. Returns out_val + (kMul_y * sum +
// kOffset); kMul_y is negative, so a larger difference sum yields a smaller
// added term.
template <class D, class V>
V HfModulation(const D d, const size_t x, const size_t y, const ImageF& xyb_y,
               const Rect& rect, const V out_val) {
  // Zero out the invalid differences for the rightmost value per row
  // (column 7 has no right neighbour inside the block).
  const Rebind<uint32_t, D> du;
  HWY_ALIGN constexpr uint32_t kMaskRight[kBlockDim] = {~0u, ~0u, ~0u, ~0u,
                                                        ~0u, ~0u, ~0u, 0};

  // Sum of capped deltas of the Y component between (approximate)
  // 4-connected pixels; only right and down neighbours are visited.
  auto sum_y = Zero(d);
  static const float valmin_y = 0.0206;
  auto valminv_y = Set(d, valmin_y);
  for (size_t dy = 0; dy < 8; ++dy) {
    const float* JXL_RESTRICT row_in_y = rect.ConstRow(xyb_y, y + dy) + x;
    // The bottom row of the block is compared against itself, so its
    // vertical differences contribute zero.
    const float* JXL_RESTRICT row_in_y_next =
        dy == 7 ? row_in_y : rect.ConstRow(xyb_y, y + dy + 1) + x;

    // In SCALAR, there is no guarantee of having extra row padding.
    // Hence, we need to ensure we don't access pixels outside the row itself.
    // In SIMD modes, however, rows are padded, so it's safe to access one
    // garbage value after the row. The vector then gets masked with kMaskRight
    // to remove the influence of that value.
#if HWY_TARGET != HWY_SCALAR
    for (size_t dx = 0; dx < 8; dx += Lanes(d)) {
#else
    for (size_t dx = 0; dx < 7; dx += Lanes(d)) {
#endif
      const auto mask = BitCast(d, Load(du, kMaskRight + dx));
      {
        // Horizontal difference, masked off for the last column.
        const auto p_y = Load(d, row_in_y + dx);
        const auto pr_y = LoadU(d, row_in_y + dx + 1);
        sum_y = Add(sum_y, And(mask, Min(valminv_y, AbsDiff(p_y, pr_y))));
        // Vertical difference.
        const auto pd_y = Load(d, row_in_y_next + dx);
        sum_y = Add(sum_y, Min(valminv_y, AbsDiff(p_y, pd_y)));
      }
    }
#if HWY_TARGET == HWY_SCALAR
    // The scalar loop above stops before column 7; add that column's
    // vertical difference here (it has no horizontal one).
    const auto p_y = Load(d, row_in_y + 7);
    const auto pd_y = Load(d, row_in_y_next + 7);
    sum_y = Add(sum_y, Min(valminv_y, AbsDiff(p_y, pd_y)));
#endif
  }
  static const float kMul_y = -0.38;
  sum_y = SumOfLanes(d, sum_y);

  float scalar_sum_y = GetLane(sum_y);
  scalar_sum_y *= kMul_y;

  // higher value -> more bpp
  float kOffset = 0.42;
  scalar_sum_y += kOffset;

  return Add(Set(d, scalar_sum_y), out_val);
}
305
306
void PerBlockModulations(const float butteraugli_target, const ImageF& xyb_x,
307
                         const ImageF& xyb_y, const ImageF& xyb_b,
308
                         const Rect& rect_in, const float scale,
309
0
                         const Rect& rect_out, ImageF* out) {
310
0
  float base_level = 0.48f * scale;
311
0
  float kDampenRampStart = 2.0f;
312
0
  float kDampenRampEnd = 14.0f;
313
0
  float dampen = 1.0f;
314
0
  if (butteraugli_target >= kDampenRampStart) {
315
0
    dampen = 1.0f - ((butteraugli_target - kDampenRampStart) /
316
0
                     (kDampenRampEnd - kDampenRampStart));
317
0
    if (dampen < 0) {
318
0
      dampen = 0;
319
0
    }
320
0
  }
321
0
  const float mul = scale * dampen;
322
0
  const float add = (1.0f - dampen) * base_level;
323
0
  for (size_t iy = rect_out.y0(); iy < rect_out.y1(); iy++) {
324
0
    const size_t y = iy * 8;
325
0
    float* const JXL_RESTRICT row_out = out->Row(iy);
326
0
    const HWY_CAPPED(float, kBlockDim) df;
327
0
    for (size_t ix = rect_out.x0(); ix < rect_out.x1(); ix++) {
328
0
      size_t x = ix * 8;
329
0
      auto out_val = Set(df, row_out[ix]);
330
0
      out_val = ComputeMask(df, out_val);
331
0
      out_val = HfModulation(df, x, y, xyb_y, rect_in, out_val);
332
0
      out_val = GammaModulation(df, x, y, xyb_x, xyb_y, rect_in, out_val);
333
0
      out_val = BlueModulation(df, x, y, xyb_x, xyb_y, xyb_b, rect_in, out_val);
334
      // We want multiplicative quantization field, so everything
335
      // until this point has been modulating the exponent.
336
0
      row_out[ix] = FastPow2f(GetLane(out_val) * 1.442695041f) * mul + add;
337
0
    }
338
0
  }
339
0
}
340
341
// Lane-wise compression of a (clamped, squared) difference:
//   0.25 * sqrt(v * sqrt(kMul * 1e8) + kLogOffset)
template <typename D, typename V>
V MaskingSqrt(const D d, V v) {
  constexpr float kLogOffset = 27.505837037000106f;
  constexpr float kMul = 211.66567973503678f;
  const auto slope = Sqrt(Set(d, kMul * 1e8));
  const auto shifted = MulAdd(v, slope, Set(d, kLogOffset));
  return Mul(Set(d, 0.25f), Sqrt(shifted));
}
349
350
0
float MaskingSqrt(const float v) {
351
0
  using DScalar = HWY_CAPPED(float, 1);
352
0
  auto vscalar = Load(DScalar(), &v);
353
0
  return GetLane(MaskingSqrt(DScalar(), vscalar));
354
0
}
355
356
// Inserts v into the running list of the four smallest values, which callers
// keep ordered as min0 <= min1 <= min2 <= min3; the previous min3 drops out.
// A v that is not strictly smaller than min3 leaves all four untouched.
void StoreMin4(const float v, float& min0, float& min1, float& min2,
               float& min3) {
  if (!(v < min3)) return;  // Not among the four smallest seen so far.
  if (v < min0) {
    min3 = min2;
    min2 = min1;
    min1 = min0;
    min0 = v;
    return;
  }
  if (v < min1) {
    min3 = min2;
    min2 = min1;
    min1 = v;
    return;
  }
  if (v < min2) {
    min3 = min2;
    min2 = v;
    return;
  }
  min3 = v;
}
376
377
// Look for smooth areas near the area of degradation.
// If the areas are generally smooth, don't do masking.
// Output is downsampled 2x.
//
// For every pixel of `from` inside `from_rect`, the nine values of its 3x3
// neighbourhood are gathered (neighbours are clamped to the image border).
// The four smallest are combined as
// kMul0 * min0 + kMul1 * min1 + kMul2 * min2 + kMul3 * min3.
// Each pixel of `to` inside `to_rect` receives the sum of these values over
// the matching 2x2 group of input pixels. Returns an error if to_rect is not
// exactly half of from_rect in each dimension.
Status FuzzyErosion(const float butteraugli_target, const Rect& from_rect,
                    const ImageF& from, const Rect& to_rect, ImageF* to) {
  const size_t xsize = from.xsize();
  const size_t ysize = from.ysize();
  constexpr int kStep = 1;
  static_assert(kStep == 1, "Step must be 1");
  JXL_ENSURE(to_rect.xsize() * 2 == from_rect.xsize());
  JXL_ENSURE(to_rect.ysize() * 2 == from_rect.ysize());
  // Base weights of the 1st..4th smallest neighbourhood values, and the
  // amount each weight changes (scaled by `mul` below) at low targets.
  static const float kMulBase0 = 0.125;
  static const float kMulBase1 = 0.10;
  static const float kMulBase2 = 0.09;
  static const float kMulBase3 = 0.06;
  static const float kMulAdd0 = 0.0;
  static const float kMulAdd1 = -0.10;
  static const float kMulAdd2 = -0.09;
  static const float kMulAdd3 = -0.06;

  // 0 for butteraugli_target >= 2, rising linearly to 1 at a target of 0. At
  // mul == 1 only the smallest value keeps a non-zero weight.
  float mul = 0.0;
  if (butteraugli_target < 2.0f) {
    mul = (2.0f - butteraugli_target) * (1.0f / 2.0f);
  }
  float kMul0 = kMulBase0 + mul * kMulAdd0;
  float kMul1 = kMulBase1 + mul * kMulAdd1;
  float kMul2 = kMulBase2 + mul * kMulAdd2;
  float kMul3 = kMulBase3 + mul * kMulAdd3;
  // Rescale so that the four weights always sum to kTotal.
  static const float kTotal = 0.29959705784054957;
  float norm = kTotal / (kMul0 + kMul1 + kMul2 + kMul3);
  kMul0 *= norm;
  kMul1 *= norm;
  kMul2 *= norm;
  kMul3 *= norm;

  for (size_t fy = 0; fy < from_rect.ysize(); ++fy) {
    size_t y = fy + from_rect.y0();
    // Rows above/below, clamped to the image.
    size_t ym1 = y >= kStep ? y - kStep : y;
    size_t yp1 = y + kStep < ysize ? y + kStep : y;
    const float* rowt = from.Row(ym1);
    const float* row = from.Row(y);
    const float* rowb = from.Row(yp1);
    float* row_out = to_rect.Row(to, fy / 2);
    for (size_t fx = 0; fx < from_rect.xsize(); ++fx) {
      size_t x = fx + from_rect.x0();
      // Columns left/right, clamped to the image.
      size_t xm1 = x >= kStep ? x - kStep : x;
      size_t xp1 = x + kStep < xsize ? x + kStep : x;
      float min0 = row[x];
      float min1 = row[xm1];
      float min2 = row[xp1];
      float min3 = rowt[xm1];
      // Sort the first four values.
      if (min0 > min1) std::swap(min0, min1);
      if (min0 > min2) std::swap(min0, min2);
      if (min0 > min3) std::swap(min0, min3);
      if (min1 > min2) std::swap(min1, min2);
      if (min1 > min3) std::swap(min1, min3);
      if (min2 > min3) std::swap(min2, min3);
      // The remaining five values of a 3x3 neighbourhood.
      StoreMin4(rowt[x], min0, min1, min2, min3);
      StoreMin4(rowt[xp1], min0, min1, min2, min3);
      StoreMin4(rowb[xm1], min0, min1, min2, min3);
      StoreMin4(rowb[x], min0, min1, min2, min3);
      StoreMin4(rowb[xp1], min0, min1, min2, min3);

      float v = kMul0 * min0 + kMul1 * min1 + kMul2 * min2 + kMul3 * min3;
      // The top-left pixel of each 2x2 group is visited first and initialises
      // the output; the other three accumulate into it.
      if (fx % 2 == 0 && fy % 2 == 0) {
        row_out[fx / 2] = v;
      } else {
        row_out[fx / 2] += v;
      }
    }
  }
  return true;
}
452
453
// Worker state for AdaptiveQuantizationMap. It holds per-thread scratch
// buffers (one diff_buffer row and one pre_erosion image per thread index),
// plus the block-resolution output aq_map that all tiles write into.
struct AdaptiveQuantizationImpl {
  // Makes scratch space available for `num_threads` threads. diff_buffer is
  // (re)allocated with one row of kEncTileDim + 8 floats per thread.
  // pre_erosion only grows, to one image per thread.
  Status PrepareBuffers(JxlMemoryManager* memory_manager, size_t num_threads) {
    JXL_ASSIGN_OR_RETURN(
        diff_buffer,
        ImageF::Create(memory_manager, kEncTileDim + 8, num_threads));
    for (size_t i = pre_erosion.size(); i < num_threads; i++) {
      JXL_ASSIGN_OR_RETURN(
          ImageF tmp,
          ImageF::Create(memory_manager, kEncTileDimInBlocks * 2 + 2,
                         kEncTileDimInBlocks * 2 + 2));
      pre_erosion.emplace_back(std::move(tmp));
    }
    return true;
  }

  // Processes one tile. `rect_in` is the region of `xyb` being encoded.
  // `rect_out` is the tile, in blocks, relative to rect_in. Writes:
  //  - *mask1x1 (absolute image coordinates): the per-pixel masking value
  //    1 / (log1p(|gammac * (Y - mean of its 4 neighbours)|) + 0.01);
  //  - *mask (block coordinates): ComputeMaskForAcStrategyUse of the eroded
  //    values;
  //  - aq_map (block coordinates): the modulated quantization field.
  // `thread` selects the scratch buffers set up by PrepareBuffers.
  Status ComputeTile(float butteraugli_target, float scale, const Image3F& xyb,
                     const Rect& rect_in, const Rect& rect_out,
                     const int thread, ImageF* mask, ImageF* mask1x1) {
    JXL_ENSURE(rect_in.x0() % kBlockDim == 0);
    JXL_ENSURE(rect_in.y0() % kBlockDim == 0);
    const size_t xsize = xyb.xsize();
    const size_t ysize = xyb.ysize();

    // The XYB gamma is 3.0 to be able to decode faster with two muls.
    // Butteraugli's gamma is matching the gamma of human eye, around 2.6.
    // We approximate the gamma difference by adding one cubic root into
    // the adaptive quantization. This gives us a total gamma of 2.6666
    // for quantization uses.
    const float match_gamma_offset = 0.019;

    const HWY_FULL(float) df;

    // Pixel range of this tile for the 1x1 pass. Where the tile touches a side
    // of rect_in that is not the image border, extend it by 2 pixels.
    size_t y_start_1x1 = rect_in.y0() + rect_out.y0() * 8;
    size_t y_end_1x1 = y_start_1x1 + rect_out.ysize() * 8;

    size_t x_start_1x1 = rect_in.x0() + rect_out.x0() * 8;
    size_t x_end_1x1 = x_start_1x1 + rect_out.xsize() * 8;

    if (rect_in.x0() != 0 && rect_out.x0() == 0) x_start_1x1 -= 2;
    if (rect_in.x1() < xsize && rect_out.x1() * 8 == rect_in.xsize()) {
      x_end_1x1 += 2;
    }
    if (rect_in.y0() != 0 && rect_out.y0() == 0) y_start_1x1 -= 2;
    if (rect_in.y1() < ysize && rect_out.y1() * 8 == rect_in.ysize()) {
      y_end_1x1 += 2;
    }

    // First pass: 1x1 Laplacian of intensity (plane 1), gamma-corrected and
    // compressed with log1p, stored as a reciprocal into *mask1x1. Neighbours
    // are clamped to the image border.
    for (size_t y = y_start_1x1; y < y_end_1x1; ++y) {
      const size_t y2 = y + 1 < ysize ? y + 1 : y;
      const size_t y1 = y > 0 ? y - 1 : y;
      const float* row_in = xyb.ConstPlaneRow(1, y);
      const float* row_in1 = xyb.ConstPlaneRow(1, y1);
      const float* row_in2 = xyb.ConstPlaneRow(1, y2);
      float* mask1x1_out = mask1x1->Row(y);
      auto scalar_pixel1x1 = [&](size_t x) {
        const size_t x2 = x + 1 < xsize ? x + 1 : x;
        const size_t x1 = x > 0 ? x - 1 : x;
        const float base =
            0.25f * (row_in2[x] + row_in1[x] + row_in[x1] + row_in[x2]);
        const float gammac = RatioOfDerivativesOfCubicRootToSimpleGamma(
            row_in[x] + match_gamma_offset);
        float diff = fabs(gammac * (row_in[x] - base));
        // kScaler is 1.0, so this multiply leaves diff unchanged.
        static const double kScaler = 1.0;
        diff *= kScaler;
        diff = log1p(diff);
        static const float kMul = 1.0;
        static const float kOffset = 0.01;
        mask1x1_out[x] = kMul / (diff + kOffset);
      };
      for (size_t x = x_start_1x1; x < x_end_1x1; ++x) {
        scalar_pixel1x1(x);
      }
    }

    // Second pass: computes an image (padded to multiple of 8x8) of local
    // pixel differences, subsampled by 4 in both directions, into
    // pre_erosion[thread]. Where the tile is not at the image border, the
    // range is extended by 4 pixels on that side so that the erosion below
    // sees neighbouring cells.
    size_t y_start = rect_in.y0() + rect_out.y0() * 8;
    size_t y_end = y_start + rect_out.ysize() * 8;

    size_t x_start = rect_in.x0() + rect_out.x0() * 8;
    size_t x_end = x_start + rect_out.xsize() * 8;

    if (x_start != 0) x_start -= 4;
    if (x_end != xsize) x_end += 4;
    if (y_start != 0) y_start -= 4;
    if (y_end != ysize) y_end += 4;
    JXL_RETURN_IF_ERROR(pre_erosion[thread].ShrinkTo((x_end - x_start) / 4,
                                                     (y_end - y_start) / 4));

    // Upper bound applied to the squared gamma-corrected difference.
    static const float limit = 0.2f;
    for (size_t y = y_start; y < y_end; ++y) {
      size_t y2 = y + 1 < ysize ? y + 1 : y;
      size_t y1 = y > 0 ? y - 1 : y;

      const float* row_in = xyb.ConstPlaneRow(1, y);
      const float* row_in1 = xyb.ConstPlaneRow(1, y1);
      const float* row_in2 = xyb.ConstPlaneRow(1, y2);
      // Per-thread row that accumulates 4 consecutive image rows; the first
      // row of each group of 4 (y % 4 == 0) overwrites it.
      float* JXL_RESTRICT row_out = diff_buffer.Row(thread);

      auto scalar_pixel = [&](size_t x) {
        const size_t x2 = x + 1 < xsize ? x + 1 : x;
        const size_t x1 = x > 0 ? x - 1 : x;
        const float base =
            0.25f * (row_in2[x] + row_in1[x] + row_in[x1] + row_in[x2]);
        const float gammac = RatioOfDerivativesOfCubicRootToSimpleGamma(
            row_in[x] + match_gamma_offset);
        float diff = gammac * (row_in[x] - base);
        diff *= diff;
        if (diff >= limit) {
          diff = limit;
        }
        diff = MaskingSqrt(diff);
        if ((y % 4) != 0) {
          row_out[x - x_start] += diff;
        } else {
          row_out[x - x_start] = diff;
        }
      };

      size_t x = x_start;
      // First pixel of the row (has no left neighbour, so it is done in
      // scalar code; this also keeps x - 1 valid in the SIMD loop).
      if (x_start == 0) {
        scalar_pixel(x_start);
        ++x;
      }
      // SIMD: same computation as scalar_pixel. The loop bound keeps the
      // x + 1 load of the right neighbour inside [x_start, x_end).
      const auto match_gamma_offset_v = Set(df, match_gamma_offset);
      const auto quarter = Set(df, 0.25f);
      for (; x + 1 + Lanes(df) < x_end; x += Lanes(df)) {
        const auto in = LoadU(df, row_in + x);
        const auto in_r = LoadU(df, row_in + x + 1);
        const auto in_l = LoadU(df, row_in + x - 1);
        const auto in_t = LoadU(df, row_in2 + x);
        const auto in_b = LoadU(df, row_in1 + x);
        auto base = Mul(quarter, Add(Add(in_r, in_l), Add(in_t, in_b)));
        auto gammacv =
            RatioOfDerivativesOfCubicRootToSimpleGamma</*invert=*/false>(
                df, Add(in, match_gamma_offset_v));
        auto diff = Mul(gammacv, Sub(in, base));
        diff = Mul(diff, diff);
        diff = Min(diff, Set(df, limit));
        diff = MaskingSqrt(df, diff);
        if ((y & 3) != 0) {
          diff = Add(diff, LoadU(df, row_out + x - x_start));
        }
        StoreU(diff, df, row_out + x - x_start);
      }
      // Scalar tail of the row.
      for (; x < x_end; ++x) {
        scalar_pixel(x);
      }
      // After the 4th row of a group, write one pre_erosion row: each value
      // is the sum over a 4x4 pixel cell, multiplied by 0.25.
      if (y % 4 == 3) {
        float* row_d_out = pre_erosion[thread].Row((y - y_start) / 4);
        for (size_t x = 0; x < (x_end - x_start) / 4; x++) {
          row_d_out[x] = (row_out[x * 4] + row_out[x * 4 + 1] +
                          row_out[x * 4 + 2] + row_out[x * 4 + 3]) *
                         0.25f;
        }
      }
    }
    // The 4-row grouping above relies on these alignments.
    JXL_ENSURE(x_start % (kBlockDim / 2) == 0);
    JXL_ENSURE(y_start % (kBlockDim / 2) == 0);
    // Skip the one-cell (4-pixel) margin in pre_erosion if one was added.
    Rect from_rect(x_start % 8 == 0 ? 0 : 1, y_start % 8 == 0 ? 0 : 1,
                   rect_out.xsize() * 2, rect_out.ysize() * 2);
    JXL_RETURN_IF_ERROR(FuzzyErosion(butteraugli_target, from_rect,
                                     pre_erosion[thread], rect_out, &aq_map));
    // Derive the AC-strategy mask from the eroded values before
    // PerBlockModulations overwrites aq_map in place.
    for (size_t y = 0; y < rect_out.ysize(); ++y) {
      const float* aq_map_row = rect_out.ConstRow(aq_map, y);
      float* mask_row = rect_out.Row(mask, y);
      for (size_t x = 0; x < rect_out.xsize(); ++x) {
        mask_row[x] = ComputeMaskForAcStrategyUse(aq_map_row[x]);
      }
    }
    PerBlockModulations(butteraugli_target, xyb.Plane(0), xyb.Plane(1),
                        xyb.Plane(2), rect_in, scale, rect_out, &aq_map);
    return true;
  }
  // One image per thread index; 4x4-subsampled differences before erosion.
  std::vector<ImageF> pre_erosion;
  // Output quantization field, one value per 8x8 block of rect_in.
  ImageF aq_map;
  // One accumulation row per thread index.
  ImageF diff_buffer;
};
635
636
// Blurs the `rect` region of *mask1x1 with a normalized symmetric 5x5 kernel
// (via Symmetric5) and replaces *mask1x1 with the result, which is allocated
// at rect's size. On entry *mask1x1 holds per-pixel masking values derived
// from the Laplacian of the intensity channel.
Status Blur1x1Masking(JxlMemoryManager* memory_manager, ThreadPool* pool,
                      ImageF* mask1x1, const Rect& rect) {
  static const float kFilterMask1x1[5] = {
      static_cast<float>(0.25647067633737227),
      static_cast<float>(0.2050056912354399075),
      static_cast<float>(0.154082048668497307),
      static_cast<float>(0.08149576591362004441),
      static_cast<float>(0.0512750104812308467),
  };
  // Kernel normalizer: the centre tap (1.0) counts once, taps 0, 1, 2 and 4
  // four times each, and tap 3 eight times. Floored at 1e-5.
  const double total =
      std::max(1.0 + 4 * (kFilterMask1x1[0] + kFilterMask1x1[1] +
                          kFilterMask1x1[2] + kFilterMask1x1[4] +
                          2 * kFilterMask1x1[3]),
               1e-5);
  const float inv_total = static_cast<float>(1.0 / total);
  // Taps are passed in the order: centre, [0], [2], [1], [4], [3].
  WeightsSymmetric5 weights{{HWY_REP4(inv_total)},
                            {HWY_REP4(inv_total * kFilterMask1x1[0])},
                            {HWY_REP4(inv_total * kFilterMask1x1[2])},
                            {HWY_REP4(inv_total * kFilterMask1x1[1])},
                            {HWY_REP4(inv_total * kFilterMask1x1[4])},
                            {HWY_REP4(inv_total * kFilterMask1x1[3])}};
  JXL_ASSIGN_OR_RETURN(
      ImageF blurred,
      ImageF::Create(memory_manager, rect.xsize(), rect.ysize()));
  JXL_RETURN_IF_ERROR(Symmetric5(*mask1x1, rect, weights, pool, &blurred));
  *mask1x1 = std::move(blurred);
  return true;
}
669
670
// Computes the adaptive quantization field for `rect` of `xyb`, one value
// per 8x8 block, and returns it. Work is split into tiles of
// kEncTileDimInBlocks x kEncTileDimInBlocks blocks run on `pool`.
// Also fills:
//  - *mask: per-block ComputeMaskForAcStrategyUse of the eroded values;
//  - *mask1x1: per-pixel masking image, finally blurred by Blur1x1Masking
//    and replaced with an image of rect's size.
// Returns an error if rect is not a whole number of blocks or if any
// allocation or tile computation fails.
StatusOr<ImageF> AdaptiveQuantizationMap(const float butteraugli_target,
                                         const Image3F& xyb, const Rect& rect,
                                         float scale, ThreadPool* pool,
                                         ImageF* mask, ImageF* mask1x1) {
  JXL_ENSURE(rect.xsize() % kBlockDim == 0);
  JXL_ENSURE(rect.ysize() % kBlockDim == 0);
  AdaptiveQuantizationImpl impl;
  const size_t xsize_blocks = rect.xsize() / kBlockDim;
  const size_t ysize_blocks = rect.ysize() / kBlockDim;
  // The tile grid is the same for every tile: size it once here instead of
  // recomputing it inside each process_tile call.
  const size_t n_tiles_x = DivCeil(xsize_blocks, kEncTileDimInBlocks);
  const size_t n_tiles_y = DivCeil(ysize_blocks, kEncTileDimInBlocks);
  JxlMemoryManager* memory_manager = xyb.memory_manager();
  JXL_ASSIGN_OR_RETURN(
      impl.aq_map, ImageF::Create(memory_manager, xsize_blocks, ysize_blocks));
  JXL_ASSIGN_OR_RETURN(
      *mask, ImageF::Create(memory_manager, xsize_blocks, ysize_blocks));
  JXL_ASSIGN_OR_RETURN(
      *mask1x1, ImageF::Create(memory_manager, xyb.xsize(), xyb.ysize()));
  const auto prepare = [&](const size_t num_threads) -> Status {
    JXL_RETURN_IF_ERROR(impl.PrepareBuffers(memory_manager, num_threads));
    return true;
  };
  const auto process_tile = [&](const uint32_t tid,
                                const size_t thread) -> Status {
    // Tiles are numbered in row-major order over the tile grid.
    const size_t tx = tid % n_tiles_x;
    const size_t ty = tid / n_tiles_x;
    const size_t by0 = ty * kEncTileDimInBlocks;
    const size_t by1 =
        std::min((ty + 1) * kEncTileDimInBlocks, ysize_blocks);
    const size_t bx0 = tx * kEncTileDimInBlocks;
    const size_t bx1 =
        std::min((tx + 1) * kEncTileDimInBlocks, xsize_blocks);
    const Rect rect_out(bx0, by0, bx1 - bx0, by1 - by0);
    JXL_RETURN_IF_ERROR(impl.ComputeTile(butteraugli_target, scale, xyb, rect,
                                         rect_out, thread, mask, mask1x1));
    return true;
  };
  // With an empty grid num_tiles is 0, so process_tile (and its modulo by
  // n_tiles_x) is never invoked.
  const size_t num_tiles = n_tiles_x * n_tiles_y;
  JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, num_tiles, prepare, process_tile,
                                "AQ DiffPrecompute"));

  JXL_RETURN_IF_ERROR(Blur1x1Masking(memory_manager, pool, mask1x1, rect));
  return std::move(impl).aq_map;
}
712
713
}  // namespace
714
715
// NOLINTNEXTLINE(google-readability-namespace-comments)
716
}  // namespace HWY_NAMESPACE
717
}  // namespace jxl
718
HWY_AFTER_NAMESPACE();
719
720
#if HWY_ONCE
721
namespace jxl {
722
HWY_EXPORT(AdaptiveQuantizationMap);
723
724
namespace {
725
726
// Debug switch. When true and JXL_DEBUG_ADAPTIVE_QUANTIZATION is enabled,
// FindBestQuantization dumps the raw quantization map on every iteration.
constexpr bool FLAGS_dump_quant_state{false};
728
729
// Debug helper: renders `image` as a heat map using `good_threshold` and
// `bad_threshold`, and dumps it under the name `label` followed by the
// zero-padded current butteraugli iteration count.
// Does nothing unless JXL_DEBUG_ADAPTIVE_QUANTIZATION is enabled.
// `aux_out` may be null (FindBestQuantization forwards a possibly-null
// pointer). In that case the iteration suffix is 0.
Status DumpHeatmap(const CompressParams& cparams, const AuxOut* aux_out,
                   const std::string& label, const ImageF& image,
                   float good_threshold, float bad_threshold) {
  if (JXL_DEBUG_ADAPTIVE_QUANTIZATION) {
    JXL_ASSIGN_OR_RETURN(
        Image3F heatmap,
        CreateHeatMapImage(image, good_threshold, bad_threshold));
    // Guard against a null aux_out instead of dereferencing it blindly.
    const int iteration =
        aux_out != nullptr ? aux_out->num_butteraugli_iters : 0;
    char filename[200];
    snprintf(filename, sizeof(filename), "%s%05d", label.c_str(), iteration);
    JXL_RETURN_IF_ERROR(DumpImage(cparams, filename, heatmap));
  }
  return true;
}
743
744
// Debug helper: dumps three heat maps.
//   - "quant_heatmap": the element-wise reciprocal of `quant_field`.
//   - "tile_heatmap": the pooled per-tile distance map.
//   - "bt_diffmap": the raw butteraugli diffmap.
// Does nothing unless JXL_DEBUG_ADAPTIVE_QUANTIZATION is enabled and
// WantDebugOutput(cparams) is true.
Status DumpHeatmaps(const CompressParams& cparams, const AuxOut* aux_out,
                    float butteraugli_target, const ImageF& quant_field,
                    const ImageF& tile_heatmap, const ImageF& bt_diffmap) {
  if (!JXL_DEBUG_ADAPTIVE_QUANTIZATION || !WantDebugOutput(cparams)) {
    return true;
  }
  const size_t width = quant_field.xsize();
  const size_t height = quant_field.ysize();
  JXL_ASSIGN_OR_RETURN(
      ImageF inv_qmap,
      ImageF::Create(quant_field.memory_manager(), width, height));
  for (size_t y = 0; y < height; ++y) {
    const float* JXL_RESTRICT src = quant_field.ConstRow(y);
    float* JXL_RESTRICT dst = inv_qmap.Row(y);
    // Quant field entries are never zero, so the reciprocal is defined.
    for (size_t x = 0; x < width; ++x) dst[x] = 1.0f / src[x];
  }
  JXL_RETURN_IF_ERROR(DumpHeatmap(cparams, aux_out, "quant_heatmap", inv_qmap,
                                  4.0f * butteraugli_target,
                                  6.0f * butteraugli_target));
  JXL_RETURN_IF_ERROR(DumpHeatmap(cparams, aux_out, "tile_heatmap",
                                  tile_heatmap, butteraugli_target,
                                  1.5f * butteraugli_target));
  // Thresholds match the heat maps produced by the command line tool.
  JXL_RETURN_IF_ERROR(DumpHeatmap(cparams, aux_out, "bt_diffmap", bt_diffmap,
                                  ButteraugliFuzzyInverse(1.5),
                                  ButteraugliFuzzyInverse(0.5)));
  return true;
}
773
774
// Pools the per-pixel distance map `distmap` into one value per tile.
// Tile (tx, ty) is looked up in `ac_strategy` at block (tx, ty).
//
// For each varblock (processed at its first block), the window is its
// footprint: covered_blocks * tile_size pixels per axis, grown by `margin`
// pixels on every side and clipped to the image. The value computed is
//   kTileNorm * (sum(w * d^16) / sum(w))^(1/16),
// and it is written to every tile the varblock covers. Weights are 1,
// except when margin != 0: edge rows/columns get kBorderMul and the four
// corners get kCornerMul.
StatusOr<ImageF> TileDistMap(const ImageF& distmap, int tile_size, int margin,
                             const AcStrategyImage& ac_strategy) {
  constexpr float kBorderMul = 0.98f;
  constexpr float kCornerMul = 0.7f;
  // 16th norm is less than the max norm, we reduce the difference
  // with this normalization factor.
  constexpr float kTileNorm = 1.2f;

  const int tiles_x = (distmap.xsize() + tile_size - 1) / tile_size;
  const int tiles_y = (distmap.ysize() + tile_size - 1) / tile_size;
  JXL_ASSIGN_OR_RETURN(
      ImageF tile_distmap,
      ImageF::Create(distmap.memory_manager(), tiles_x, tiles_y));
  const size_t out_stride = tile_distmap.PixelsPerRow();
  const bool weigh_edges = (margin != 0);

  for (int ty = 0; ty < tiles_y; ++ty) {
    AcStrategyRow acs_row = ac_strategy.ConstRow(ty);
    float* JXL_RESTRICT out_row = tile_distmap.Row(ty);
    for (int tx = 0; tx < tiles_x; ++tx) {
      AcStrategy acs = acs_row[tx];
      if (!acs.IsFirstBlock()) continue;

      // Pixel window covered by this varblock, plus margin, clipped.
      const int span_x = acs.covered_blocks_x() * tile_size;
      const int span_y = acs.covered_blocks_y() * tile_size;
      const int x0 = std::max<int>(0, tile_size * tx - margin);
      const int x1 =
          std::min<int>(distmap.xsize(), tile_size * tx + span_x + margin);
      const int y0 = std::max<int>(0, tile_size * ty - margin);
      const int y1 =
          std::min<int>(distmap.ysize(), tile_size * ty + span_y + margin);

      float weighted_sum = 0.0;
      double total_weight = 0;
      for (int y = y0; y < y1; ++y) {
        const bool y_edge = weigh_edges && (y == y0 || y == y1 - 1);
        const float* const JXL_RESTRICT src_row = distmap.Row(y);
        for (int x = x0; x < x1; ++x) {
          const bool x_edge = weigh_edges && (x == x0 || x == x1 - 1);
          float w = 1.0;
          if (x_edge && y_edge) {
            w = kCornerMul;
          } else if (x_edge || y_edge) {
            w = kBorderMul;
          }
          // d^16 via four successive squarings.
          float p = src_row[x];
          p *= p;
          p *= p;
          p *= p;
          p *= p;
          weighted_sum += w * p;
          total_weight += w;
        }
      }
      if (total_weight == 0) total_weight = 1;

      const float tile_dist =
          kTileNorm * std::pow(weighted_sum / total_weight, 1.0f / 16.0f);
      for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
        float* dst = out_row + tx + out_stride * iy;
        for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) {
          dst[ix] = tile_dist;
        }
      }
    }
  }
  return tile_distmap;
}
840
841
// Exponent applied to the butteraugli target in InitialQuantDC().
constexpr float kDcQuantPow = 0.83f;
// Numerator of the DC quant scale in InitialQuantDC() (kDcQuant / dc_target).
constexpr float kDcQuant = 1.095924047623553f;
// Numerator of the AC quant scale in InitialQuantField() (kAcQuant / target).
constexpr float kAcQuant = 0.725f;
844
845
// Computes the decoded image for a given set of compression parameters.
846
StatusOr<ImageBundle> RoundtripImage(const FrameHeader& frame_header,
847
                                     const Image3F& opsin,
848
                                     PassesEncoderState* enc_state,
849
                                     const JxlCmsInterface& cms,
850
0
                                     ThreadPool* pool) {
851
0
  JxlMemoryManager* memory_manager = enc_state->memory_manager();
852
0
  std::unique_ptr<PassesDecoderState> dec_state =
853
0
      jxl::make_unique<PassesDecoderState>(memory_manager);
854
0
  JXL_RETURN_IF_ERROR(dec_state->output_encoding_info.SetFromMetadata(
855
0
      *enc_state->shared.metadata));
856
0
  dec_state->shared = &enc_state->shared;
857
0
  JXL_ENSURE(opsin.ysize() % kBlockDim == 0);
858
859
0
  const size_t xsize_groups = DivCeil(opsin.xsize(), kGroupDim);
860
0
  const size_t ysize_groups = DivCeil(opsin.ysize(), kGroupDim);
861
0
  const size_t num_groups = xsize_groups * ysize_groups;
862
863
0
  size_t num_special_frames = enc_state->special_frames.size();
864
0
  size_t num_passes = enc_state->progressive_splitter.GetNumPasses();
865
0
  JXL_ASSIGN_OR_RETURN(ModularFrameEncoder modular_frame_encoder,
866
0
                       ModularFrameEncoder::Create(memory_manager, frame_header,
867
0
                                                   enc_state->cparams, false));
868
0
  JXL_RETURN_IF_ERROR(InitializePassesEncoder(frame_header, opsin, Rect(opsin),
869
0
                                              cms, pool, enc_state,
870
0
                                              &modular_frame_encoder, nullptr));
871
0
  JXL_RETURN_IF_ERROR(dec_state->Init(frame_header));
872
0
  JXL_RETURN_IF_ERROR(dec_state->InitForAC(num_passes, pool));
873
874
0
  ImageBundle decoded(memory_manager, &enc_state->shared.metadata->m);
875
0
  decoded.origin = frame_header.frame_origin;
876
0
  JXL_ASSIGN_OR_RETURN(
877
0
      Image3F tmp,
878
0
      Image3F::Create(memory_manager, opsin.xsize(), opsin.ysize()));
879
0
  JXL_RETURN_IF_ERROR(decoded.SetFromImage(
880
0
      std::move(tmp), dec_state->output_encoding_info.color_encoding));
881
882
0
  PassesDecoderState::PipelineOptions options;
883
0
  options.use_slow_render_pipeline = false;
884
0
  options.coalescing = false;
885
0
  options.render_spotcolors = false;
886
0
  options.render_noise = false;
887
888
  // Same as frame_header.nonserialized_metadata->m
889
0
  const ImageMetadata& metadata = *decoded.metadata();
890
891
0
  JXL_RETURN_IF_ERROR(dec_state->PreparePipeline(
892
0
      frame_header, &enc_state->shared.metadata->m, &decoded, options));
893
894
0
  AlignedArray<GroupDecCache> group_dec_caches;
895
0
  const auto allocate_storage = [&](const size_t num_threads) -> Status {
896
0
    JXL_RETURN_IF_ERROR(
897
0
        dec_state->render_pipeline->PrepareForThreads(num_threads,
898
0
                                                      /*use_group_ids=*/false));
899
0
    JXL_ASSIGN_OR_RETURN(group_dec_caches, AlignedArray<GroupDecCache>::Create(
900
0
                                               memory_manager, num_threads));
901
0
    return true;
902
0
  };
903
0
  const auto process_group = [&](const uint32_t group_index,
904
0
                                 const size_t thread) -> Status {
905
0
    if (frame_header.loop_filter.epf_iters > 0) {
906
0
      JXL_RETURN_IF_ERROR(
907
0
          ComputeSigma(frame_header.loop_filter,
908
0
                       dec_state->shared->frame_dim.BlockGroupRect(group_index),
909
0
                       dec_state.get()));
910
0
    }
911
0
    RenderPipelineInput input =
912
0
        dec_state->render_pipeline->GetInputBuffers(group_index, thread);
913
0
    JXL_RETURN_IF_ERROR(DecodeGroupForRoundtrip(
914
0
        frame_header, enc_state->coeffs, group_index, dec_state.get(),
915
0
        &group_dec_caches[thread], thread, input, nullptr, nullptr));
916
0
    for (size_t c = 0; c < metadata.num_extra_channels; c++) {
917
0
      std::pair<ImageF*, Rect> ri = input.GetBuffer(3 + c);
918
0
      FillPlane(0.0f, ri.first, ri.second);
919
0
    }
920
0
    JXL_RETURN_IF_ERROR(input.Done());
921
0
    return true;
922
0
  };
923
0
  JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, num_groups, allocate_storage,
924
0
                                process_group, "AQ loop"));
925
926
  // Ensure we don't create any new special frames.
927
0
  enc_state->special_frames.resize(num_special_frames);
928
929
0
  return decoded;
930
0
}
931
932
// Refinement rounds used by FindBestQuantization at SpeedTier::kTortoise
// and by FindBestQuantizationMaxError (which runs this many + 1 roundtrips).
constexpr int kMaxButteraugliIters{4};
933
934
// Butteraugli-guided search for the AC quant field.
//
// `quant_field` is first passed through AdjustQuantField. Each round then:
//   1. installs the field in the quantizer;
//   2. roundtrips `opsin` with RoundtripImage;
//   3. compares the result against the `linear` reference with butteraugli
//      and pools the diffmap per varblock with TileDistMap;
//   4. rescales the field. Entries whose pooled distance exceeds the
//      original target are raised. In rounds with a non-zero power (kPow),
//      entries below the target are lowered.
// Entries are kept within [qf_lower, qf_higher], which are derived from the
// initial field's min/max. There are kMaxButteraugliIters refinement rounds
// at SpeedTier::kTortoise and 2 otherwise.
//
// On return, the quantizer and enc_state->shared.raw_quant_field hold the
// final field. For resampled images the search is skipped unless the
// distance target is high.
Status FindBestQuantization(const FrameHeader& frame_header,
                            const Image3F& linear, const Image3F& opsin,
                            ImageF& quant_field, PassesEncoderState* enc_state,
                            const JxlCmsInterface& cms, ThreadPool* pool,
                            AuxOut* aux_out) {
  const CompressParams& cparams = enc_state->cparams;
  if (cparams.resampling > 1 &&
      cparams.original_butteraugli_distance <= 4.0 * cparams.resampling) {
    // For downsampled opsin image, the butteraugli based adaptive quantization
    // loop would only make the size bigger without improving the distance much,
    // so in this case we enable it only for very high butteraugli targets.
    return true;
  }
  JxlMemoryManager* memory_manager = enc_state->memory_manager();
  Quantizer& quantizer = enc_state->shared.quantizer;
  ImageI& raw_quant_field = enc_state->shared.raw_quant_field;

  const float butteraugli_target = cparams.butteraugli_distance;
  const float original_butteraugli = cparams.original_butteraugli_distance;
  ButteraugliParams params;
  const auto& tf = frame_header.nonserialized_metadata->m.color_encoding.Tf();
  params.intensity_target =
      tf.IsPQ() || tf.IsHLG()
          ? frame_header.nonserialized_metadata->m.IntensityTarget()
          : 80.f;
  JxlButteraugliComparator comparator(params, cms);
  JXL_RETURN_IF_ERROR(comparator.SetLinearReferenceImage(linear));
  const bool lower_is_better =
      (comparator.GoodQualityScore() < comparator.BadQualityScore());
  const float initial_quant_dc = InitialQuantDC(butteraugli_target);
  JXL_RETURN_IF_ERROR(AdjustQuantField(enc_state->shared.ac_strategy,
                                       Rect(quant_field), original_butteraugli,
                                       &quant_field));
  ImageF tile_distmap;
  JXL_ASSIGN_OR_RETURN(
      ImageF initial_quant_field,
      ImageF::Create(memory_manager, quant_field.xsize(), quant_field.ysize()));
  JXL_RETURN_IF_ERROR(CopyImageTo(quant_field, &initial_quant_field));

  // Bounds for the quant field, derived from the initial field's range.
  float initial_qf_min;
  float initial_qf_max;
  ImageMinMax(initial_quant_field, &initial_qf_min, &initial_qf_max);
  const float initial_qf_ratio = initial_qf_max / initial_qf_min;
  const float qf_max_deviation_low = std::sqrt(250 / initial_qf_ratio);
  float asymmetry = 2;
  if (qf_max_deviation_low < asymmetry) asymmetry = qf_max_deviation_low;
  const float qf_lower = initial_qf_min / (asymmetry * qf_max_deviation_low);
  const float qf_higher = initial_qf_max * (qf_max_deviation_low / asymmetry);

  JXL_ENSURE(qf_higher / qf_lower < 253);

  // Multiplies `q` by `diff` (expected > 1). If that leaves the rounded
  // integer quant value unchanged, `q` is instead set to its old value plus
  // quantizer.Scale() (presumably one raw quant step), so the raise is not
  // lost to rounding.
  const auto raise_quant = [&quantizer](float& q, const float diff) {
    const float old = q;
    q *= diff;
    const int qf_old =
        static_cast<int>(std::lround(old * quantizer.InvGlobalScale()));
    const int qf_new =
        static_cast<int>(std::lround(q * quantizer.InvGlobalScale()));
    if (qf_old == qf_new) {
      q = old + quantizer.Scale();
    }
  };

  // Per-round exponents used to lower the quant field where the distance is
  // below the target. A zero power means "only raise".
  static constexpr double kPow[8] = {
      0.2, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  };
  static constexpr double kPowMod[8] = {
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  };

  constexpr int kOriginalComparisonRound = 1;
  int iters = kMaxButteraugliIters;
  if (cparams.speed_tier != SpeedTier::kTortoise) {
    iters = 2;
  }
  for (int i = 0; i < iters + 1; ++i) {
    if (JXL_DEBUG_ADAPTIVE_QUANTIZATION) {
      printf("\nQuantization field:\n");
      for (size_t y = 0; y < quant_field.ysize(); ++y) {
        for (size_t x = 0; x < quant_field.xsize(); ++x) {
          printf(" %.5f", quant_field.Row(y)[x]);
        }
        printf("\n");
      }
    }
    JXL_RETURN_IF_ERROR(quantizer.SetQuantField(initial_quant_dc, quant_field,
                                                &raw_quant_field));
    JXL_ASSIGN_OR_RETURN(
        ImageBundle dec_linear,
        RoundtripImage(frame_header, opsin, enc_state, cms, pool));
    float score;
    ImageF diffmap;
    JXL_RETURN_IF_ERROR(comparator.CompareWith(dec_linear, &diffmap, &score));
    if (!lower_is_better) {
      score = -score;
      ScaleImage(-1.0f, &diffmap);
    }
    JXL_ASSIGN_OR_RETURN(tile_distmap,
                         TileDistMap(diffmap, 8 * cparams.resampling, 0,
                                     enc_state->shared.ac_strategy));
    if (JXL_DEBUG_ADAPTIVE_QUANTIZATION && WantDebugOutput(cparams)) {
      JXL_RETURN_IF_ERROR(DumpImage(cparams, ("dec" + ToString(i)).c_str(),
                                    *dec_linear.color()));
      JXL_RETURN_IF_ERROR(DumpHeatmaps(cparams, aux_out, butteraugli_target,
                                       quant_field, tile_distmap, diffmap));
    }
    if (aux_out != nullptr) ++aux_out->num_butteraugli_iters;
    if (JXL_DEBUG_ADAPTIVE_QUANTIZATION) {
      float minval;
      float maxval;
      ImageMinMax(quant_field, &minval, &maxval);
      // Report the iteration count actually used (2 outside kTortoise).
      printf("\nButteraugli iter: %d/%d\n", i, iters);
      printf("Butteraugli distance: %f  (target = %f)\n", score,
             original_butteraugli);
      printf("quant range: %f ... %f  DC quant: %f\n", minval, maxval,
             initial_quant_dc);
      if (FLAGS_dump_quant_state) {
        quantizer.DumpQuantizationMap(raw_quant_field);
      }
    }

    // The last round only measures; no further adjustment.
    if (i == iters) break;

    if (i == kOriginalComparisonRound) {
      // Don't allow optimization to make the quant field a lot worse than
      // what the initial guess was. This allows the AC field to have enough
      // precision to reduce the oscillations due to the dc reconstruction.
      double kInitMul = 0.6;
      const double kOneMinusInitMul = 1.0 - kInitMul;
      for (size_t y = 0; y < quant_field.ysize(); ++y) {
        float* const JXL_RESTRICT row_q = quant_field.Row(y);
        const float* const JXL_RESTRICT row_init = initial_quant_field.Row(y);
        for (size_t x = 0; x < quant_field.xsize(); ++x) {
          double clamp = kOneMinusInitMul * row_q[x] + kInitMul * row_init[x];
          if (row_q[x] < clamp) {
            row_q[x] = clamp;
            if (row_q[x] > qf_higher) row_q[x] = qf_higher;
            if (row_q[x] < qf_lower) row_q[x] = qf_lower;
          }
        }
      }
    }

    // kPow has 8 entries; here i < iters <= kMaxButteraugliIters.
    double cur_pow = 0.0;
    if (i < 7) {
      cur_pow = kPow[i] + (original_butteraugli - 1.0) * kPowMod[i];
      if (cur_pow < 0) {
        cur_pow = 0;
      }
    }

    for (size_t y = 0; y < quant_field.ysize(); ++y) {
      const float* const JXL_RESTRICT row_dist = tile_distmap.Row(y);
      float* const JXL_RESTRICT row_q = quant_field.Row(y);
      for (size_t x = 0; x < quant_field.xsize(); ++x) {
        const float diff = row_dist[x] / original_butteraugli;
        if (cur_pow == 0.0) {
          // Only raise entries whose distance exceeds the target.
          if (diff > 1.0f) raise_quant(row_q[x], diff);
        } else if (diff <= 1.0f) {
          // Below target: lower the entry by diff^cur_pow.
          row_q[x] *= std::pow(diff, cur_pow);
        } else {
          raise_quant(row_q[x], diff);
        }
        if (row_q[x] > qf_higher) row_q[x] = qf_higher;
        if (row_q[x] < qf_lower) row_q[x] = qf_lower;
      }
    }
  }
  JXL_RETURN_IF_ERROR(
      quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field));
  return true;
}
1121
1122
// Max-error mode search for the AC quant field.
//
// Runs kMaxButteraugliIters + 1 rounds. Each round:
//   1. installs `quant_field` in the quantizer;
//   2. roundtrips `opsin`;
//   3. for every varblock, takes the largest absolute per-channel difference
//      between `opsin` and the decoded image, scaled by
//      1 / cparams.max_error[c].
// The varblock's quant field entries are then multiplied by 2 * error when
// the error is below 0.5, by the error when it exceeds 1, and left unchanged
// otherwise. The final field is installed in the quantizer before returning.
Status FindBestQuantizationMaxError(const FrameHeader& frame_header,
                                    const Image3F& opsin, ImageF& quant_field,
                                    PassesEncoderState* enc_state,
                                    const JxlCmsInterface& cms,
                                    ThreadPool* pool, AuxOut* aux_out) {
  // TODO(szabadka): Make this work for non-opsin color spaces.
  const CompressParams& cparams = enc_state->cparams;
  Quantizer& quantizer = enc_state->shared.quantizer;
  ImageI& raw_quant_field = enc_state->shared.raw_quant_field;
  const AcStrategyImage& ac_strategy = enc_state->shared.ac_strategy;
  const auto& frame_dim = enc_state->shared.frame_dim;

  // TODO(veluca): better choice of this value.
  const float initial_quant_dc =
      16 * std::sqrt(0.1f / cparams.butteraugli_distance);
  JXL_RETURN_IF_ERROR(AdjustQuantField(ac_strategy, Rect(quant_field),
                                       cparams.original_butteraugli_distance,
                                       &quant_field));

  float inv_max_err[3];
  for (size_t c = 0; c < 3; ++c) {
    inv_max_err[c] = 1.0f / cparams.max_error[c];
  }

  for (int iter = 0; iter <= kMaxButteraugliIters; ++iter) {
    JXL_RETURN_IF_ERROR(quantizer.SetQuantField(initial_quant_dc, quant_field,
                                                &raw_quant_field));
    if (JXL_DEBUG_ADAPTIVE_QUANTIZATION && aux_out) {
      JXL_RETURN_IF_ERROR(
          DumpXybImage(cparams, ("ops" + ToString(iter)).c_str(), opsin));
    }
    JXL_ASSIGN_OR_RETURN(
        ImageBundle decoded,
        RoundtripImage(frame_header, opsin, enc_state, cms, pool));
    if (JXL_DEBUG_ADAPTIVE_QUANTIZATION && aux_out) {
      JXL_RETURN_IF_ERROR(DumpXybImage(
          cparams, ("dec" + ToString(iter)).c_str(), *decoded.color()));
    }
    const size_t dec_xsize = decoded.xsize();
    const size_t dec_ysize = decoded.ysize();

    for (size_t by = 0; by < frame_dim.ysize_blocks; ++by) {
      AcStrategyRow ac_strategy_row = ac_strategy.ConstRow(by);
      for (size_t bx = 0; bx < frame_dim.xsize_blocks; ++bx) {
        AcStrategy acs = ac_strategy_row[bx];
        if (!acs.IsFirstBlock()) continue;

        // Pixel window of this varblock, clipped to the decoded image.
        const size_t x_begin = bx * kBlockDim;
        const size_t x_end = std::min<size_t>(
            (bx + acs.covered_blocks_x()) * kBlockDim, dec_xsize);
        const size_t y_begin = by * kBlockDim;
        const size_t y_end = std::min<size_t>(
            (by + acs.covered_blocks_y()) * kBlockDim, dec_ysize);

        float max_error = 0;
        for (size_t c = 0; c < 3; ++c) {
          for (size_t y = y_begin; y < y_end; ++y) {
            const float* JXL_RESTRICT in_row = opsin.ConstPlaneRow(c, y);
            const float* JXL_RESTRICT dec_row =
                decoded.color()->ConstPlaneRow(c, y);
            for (size_t x = x_begin; x < x_end; ++x) {
              max_error = std::max(
                  std::abs(in_row[x] - dec_row[x]) * inv_max_err[c], max_error);
            }
          }
        }

        // Aim for a normalized error in [0.5, 1]. Above 1, scale the qf up
        // by the error. Below 0.5, scale it down by 2 * error. Otherwise
        // leave it as is.
        float qf_mul = 1.0f;
        if (max_error < 0.5f) {
          qf_mul = max_error * 2.0f;
        } else if (max_error > 1.0f) {
          qf_mul = max_error;
        }
        for (size_t qy = by; qy < by + acs.covered_blocks_y(); ++qy) {
          float* JXL_RESTRICT quant_field_row = quant_field.Row(qy);
          for (size_t qx = bx; qx < bx + acs.covered_blocks_x(); ++qx) {
            quant_field_row[qx] *= qf_mul;
          }
        }
      }
    }
  }
  JXL_RETURN_IF_ERROR(
      quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field));
  return true;
}
1200
1201
}  // namespace
1202
1203
// Makes the quant field uniform over each varblock that starts in `rect`.
//
// For every varblock whose first block lies in `rect`, all quant field
// entries it covers are replaced by a single value:
//   - varblocks covering fewer than 4 blocks: the maximum over the varblock;
//   - larger varblocks:
//       mean_max_mixer * max + (1 - mean_max_mixer) * mean.
// mean_max_mixer is 1 up to a butteraugli target of kLimit. Above that it
// decreases with slope kMul, floored at kMin.
//
// Returns an error (JXL_ENSURE) if a varblock would extend past the quant
// field.
Status AdjustQuantField(const AcStrategyImage& ac_strategy, const Rect& rect,
                        float butteraugli_target, ImageF* quant_field) {
  const size_t stride = quant_field->PixelsPerRow();

  // At low distances it is great to use max, but mean works better
  // at high distances. We interpolate between them for a distance
  // range.
  constexpr float kLimit = 1.54138f;
  constexpr float kMul = 0.56391f;
  constexpr float kMin = 0.0f;
  float mean_max_mixer = 1.0f;
  if (butteraugli_target > kLimit) {
    mean_max_mixer -= (butteraugli_target - kLimit) * kMul;
    if (mean_max_mixer < kMin) {
      mean_max_mixer = kMin;
    }
  }

  for (size_t y = 0; y < rect.ysize(); ++y) {
    AcStrategyRow ac_strategy_row = ac_strategy.ConstRow(rect, y);
    float* JXL_RESTRICT quant_row = rect.Row(quant_field, y);
    for (size_t x = 0; x < rect.xsize(); ++x) {
      AcStrategy acs = ac_strategy_row[x];
      if (!acs.IsFirstBlock()) continue;
      const size_t blocks_x = acs.covered_blocks_x();
      const size_t blocks_y = acs.covered_blocks_y();
      // quant_row is already offset by rect.x0()/rect.y0(), so the bounds
      // check must use absolute quant-field coordinates. Without the offset
      // it would under-check for sub-rects.
      JXL_ENSURE(rect.x0() + x + blocks_x <= quant_field->xsize());
      JXL_ENSURE(rect.y0() + y + blocks_y <= quant_field->ysize());
      float max = quant_row[x];
      float mean = 0.0;
      for (size_t iy = 0; iy < blocks_y; iy++) {
        for (size_t ix = 0; ix < blocks_x; ix++) {
          const float v = quant_row[x + ix + iy * stride];
          mean += v;
          max = std::max(v, max);
        }
      }
      mean /= blocks_y * blocks_x;
      if (blocks_y * blocks_x >= 4) {
        max *= mean_max_mixer;
        max += (1.0f - mean_max_mixer) * mean;
      }
      for (size_t iy = 0; iy < blocks_y; iy++) {
        for (size_t ix = 0; ix < blocks_x; ix++) {
          quant_row[x + ix + iy * stride] = max;
        }
      }
    }
  }
  return true;
}
1254
1255
0
float InitialQuantDC(float butteraugli_target) {
1256
0
  const float kDcMul = 0.3;  // Butteraugli target where non-linearity kicks in.
1257
0
  const float butteraugli_target_dc = std::max<float>(
1258
0
      0.5f * butteraugli_target,
1259
0
      std::min<float>(butteraugli_target,
1260
0
                      kDcMul * std::pow((1.0f / kDcMul) * butteraugli_target,
1261
0
                                        kDcQuantPow)));
1262
  // We want the maximum DC value to be at most 2**15 * kInvDCQuant / quant_dc.
1263
  // The maximum DC value might not be in the kXybRange because of inverse
1264
  // gaborish, so we add some slack to the maximum theoretical quant obtained
1265
  // this way (64).
1266
0
  return std::min(kDcQuant / butteraugli_target_dc, 50.f);
1267
0
}
1268
1269
// Computes the initial adaptive quant field for `rect` of `opsin`.
// It dispatches to the SIMD AdaptiveQuantizationMap with an AC scale of
// (kAcQuant / butteraugli_target) * rescale. `mask` and `mask1x1` are
// (re)allocated and filled by that implementation as side outputs.
StatusOr<ImageF> InitialQuantField(const float butteraugli_target,
                                   const Image3F& opsin, const Rect& rect,
                                   ThreadPool* pool, float rescale,
                                   ImageF* mask, ImageF* mask1x1) {
  const float ac_scale = rescale * (kAcQuant / butteraugli_target);
  return HWY_DYNAMIC_DISPATCH(AdaptiveQuantizationMap)(
      butteraugli_target, opsin, rect, ac_scale, pool, mask, mask1x1);
}
1277
1278
// Refines `quant_field` (and the quantizer state in `enc_state`).
//   - In max-error mode: FindBestQuantizationMaxError.
//   - Otherwise, when a linear reference is available and
//     cparams.speed_tier <= SpeedTier::kKitten: the butteraugli loop in
//     FindBestQuantization.
//   - Otherwise: leaves everything unchanged.
// `rescale` is not used by this function.
Status FindBestQuantizer(const FrameHeader& frame_header, const Image3F* linear,
                         const Image3F& opsin, ImageF& quant_field,
                         PassesEncoderState* enc_state,
                         const JxlCmsInterface& cms, ThreadPool* pool,
                         AuxOut* aux_out, double rescale) {
  const CompressParams& cparams = enc_state->cparams;
  if (cparams.max_error_mode) {
    JXL_RETURN_IF_ERROR(FindBestQuantizationMaxError(
        frame_header, opsin, quant_field, enc_state, cms, pool, aux_out));
    return true;
  }
  const bool use_butteraugli_loop =
      (linear != nullptr) && cparams.speed_tier <= SpeedTier::kKitten;
  if (!use_butteraugli_loop) return true;
  // Normal encoding to a butteraugli score.
  JXL_RETURN_IF_ERROR(FindBestQuantization(frame_header, *linear, opsin,
                                           quant_field, enc_state, cms, pool,
                                           aux_out));
  return true;
}
1295
1296
}  // namespace jxl
1297
#endif  // HWY_ONCE