Coverage Report

Created: 2025-06-16 07:00

/src/libjxl/lib/jxl/enc_heuristics.cc
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#include "lib/jxl/enc_heuristics.h"
7
8
#include <jxl/cms_interface.h>
9
#include <jxl/memory_manager.h>
10
11
#include <algorithm>
12
#include <array>
13
#include <cmath>
14
#include <cstddef>
15
#include <cstdint>
16
#include <cstdlib>
17
#include <limits>
18
#include <memory>
19
#include <numeric>
20
#include <string>
21
#include <utility>
22
#include <vector>
23
24
#include "lib/jxl/ac_context.h"
25
#include "lib/jxl/ac_strategy.h"
26
#include "lib/jxl/base/common.h"
27
#include "lib/jxl/base/compiler_specific.h"
28
#include "lib/jxl/base/data_parallel.h"
29
#include "lib/jxl/base/override.h"
30
#include "lib/jxl/base/rect.h"
31
#include "lib/jxl/base/status.h"
32
#include "lib/jxl/butteraugli/butteraugli.h"
33
#include "lib/jxl/chroma_from_luma.h"
34
#include "lib/jxl/coeff_order.h"
35
#include "lib/jxl/coeff_order_fwd.h"
36
#include "lib/jxl/color_encoding_internal.h"
37
#include "lib/jxl/common.h"
38
#include "lib/jxl/dct_util.h"
39
#include "lib/jxl/dec_cache.h"
40
#include "lib/jxl/dec_group.h"
41
#include "lib/jxl/dec_noise.h"
42
#include "lib/jxl/dec_xyb.h"
43
#include "lib/jxl/enc_ac_strategy.h"
44
#include "lib/jxl/enc_adaptive_quantization.h"
45
#include "lib/jxl/enc_cache.h"
46
#include "lib/jxl/enc_chroma_from_luma.h"
47
#include "lib/jxl/enc_gaborish.h"
48
#include "lib/jxl/enc_modular.h"
49
#include "lib/jxl/enc_noise.h"
50
#include "lib/jxl/enc_params.h"
51
#include "lib/jxl/enc_patch_dictionary.h"
52
#include "lib/jxl/enc_quant_weights.h"
53
#include "lib/jxl/enc_splines.h"
54
#include "lib/jxl/epf.h"
55
#include "lib/jxl/frame_dimensions.h"
56
#include "lib/jxl/frame_header.h"
57
#include "lib/jxl/image.h"
58
#include "lib/jxl/image_metadata.h"
59
#include "lib/jxl/image_ops.h"
60
#include "lib/jxl/memory_manager_internal.h"
61
#include "lib/jxl/passes_state.h"
62
#include "lib/jxl/quant_weights.h"
63
#include "lib/jxl/render_pipeline/render_pipeline.h"
64
65
namespace jxl {
66
67
struct AuxOut;
68
69
void FindBestBlockEntropyModel(const CompressParams& cparams, const ImageI& rqf,
70
                               const AcStrategyImage& ac_strategy,
71
186
                               BlockCtxMap* block_ctx_map) {
72
186
  if (cparams.decoding_speed_tier >= 1) {
73
0
    static constexpr uint8_t kSimpleCtxMap[] = {
74
        // Cluster all blocks together
75
0
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  //
76
0
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,  //
77
0
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,  //
78
0
    };
79
0
    static_assert(
80
0
        3 * kNumOrders == sizeof(kSimpleCtxMap) / sizeof *kSimpleCtxMap,
81
0
        "Update simple context map");
82
83
0
    auto bcm = *block_ctx_map;
84
0
    bcm.ctx_map.assign(std::begin(kSimpleCtxMap), std::end(kSimpleCtxMap));
85
0
    bcm.num_ctxs = 2;
86
0
    bcm.num_dc_ctxs = 1;
87
0
    return;
88
0
  }
89
186
  if (cparams.speed_tier >= SpeedTier::kFalcon) {
90
0
    return;
91
0
  }
92
  // No need to change context modeling for small images.
93
186
  size_t tot = rqf.xsize() * rqf.ysize();
94
186
  size_t size_for_ctx_model = (1 << 10) * cparams.butteraugli_distance;
95
186
  if (tot < size_for_ctx_model) return;
96
97
137
  struct OccCounters {
98
    // count the occurrences of each qf value and each strategy type.
99
137
    OccCounters(const ImageI& rqf, const AcStrategyImage& ac_strategy) {
100
7.23k
      for (size_t y = 0; y < rqf.ysize(); y++) {
101
7.09k
        const int32_t* qf_row = rqf.Row(y);
102
7.09k
        AcStrategyRow acs_row = ac_strategy.ConstRow(y);
103
400k
        for (size_t x = 0; x < rqf.xsize(); x++) {
104
393k
          int ord = kStrategyOrder[acs_row[x].RawStrategy()];
105
393k
          int qf = qf_row[x] - 1;
106
393k
          qf_counts[qf]++;
107
393k
          qf_ord_counts[ord][qf]++;
108
393k
          ord_counts[ord]++;
109
393k
        }
110
7.09k
      }
111
137
    }
112
113
137
    size_t qf_counts[256] = {};
114
137
    size_t qf_ord_counts[kNumOrders][256] = {};
115
137
    size_t ord_counts[kNumOrders] = {};
116
137
  };
117
  // The OccCounters struct is too big to allocate on the stack.
118
137
  std::unique_ptr<OccCounters> counters(new OccCounters(rqf, ac_strategy));
119
120
  // Splitting the context model according to the quantization field seems to
121
  // mostly benefit only large images.
122
137
  size_t size_for_qf_split = (1 << 13) * cparams.butteraugli_distance;
123
137
  size_t num_qf_segments = tot < size_for_qf_split ? 1 : 2;
124
137
  std::vector<uint32_t>& qft = block_ctx_map->qf_thresholds;
125
137
  qft.clear();
126
  // Divide the quant field in up to num_qf_segments segments.
127
137
  size_t cumsum = 0;
128
137
  size_t next = 1;
129
137
  size_t last_cut = 256;
130
137
  size_t cut = tot * next / num_qf_segments;
131
35.2k
  for (uint32_t j = 0; j < 256; j++) {
132
35.0k
    cumsum += counters->qf_counts[j];
133
35.0k
    if (cumsum > cut) {
134
1
      if (j != 0) {
135
1
        qft.push_back(j);
136
1
      }
137
1
      last_cut = j;
138
2
      while (cumsum > cut) {
139
1
        next++;
140
1
        cut = tot * next / num_qf_segments;
141
1
      }
142
35.0k
    } else if (next > qft.size() + 1) {
143
0
      if (j - 1 == last_cut && j != 0) {
144
0
        qft.push_back(j);
145
0
      }
146
0
    }
147
35.0k
  }
148
149
  // Count the occurrences of each segment.
150
137
  std::vector<size_t> counts(kNumOrders * (qft.size() + 1));
151
137
  size_t qft_pos = 0;
152
35.2k
  for (size_t j = 0; j < 256; j++) {
153
35.0k
    if (qft_pos < qft.size() && j == qft[qft_pos]) {
154
1
      qft_pos++;
155
1
    }
156
491k
    for (size_t i = 0; i < kNumOrders; i++) {
157
455k
      counts[qft_pos + i * (qft.size() + 1)] += counters->qf_ord_counts[i][j];
158
455k
    }
159
35.0k
  }
160
161
  // Repeatedly merge the lowest-count pair.
162
137
  std::vector<uint8_t> remap((qft.size() + 1) * kNumOrders);
163
137
  std::iota(remap.begin(), remap.end(), 0);
164
137
  std::vector<uint8_t> clusters(remap);
165
137
  size_t nb_clusters =
166
137
      Clamp1(static_cast<int>(tot / size_for_ctx_model / 2), 2, 9);
167
137
  size_t nb_clusters_chroma =
168
137
      Clamp1(static_cast<int>(tot / size_for_ctx_model / 3), 1, 5);
169
  // This is O(n^2 log n), but n is small.
170
1.65k
  while (clusters.size() > nb_clusters) {
171
1.51k
    std::sort(clusters.begin(), clusters.end(),
172
12.1k
              [&](int a, int b) { return counts[a] > counts[b]; });
173
1.51k
    counts[clusters[clusters.size() - 2]] += counts[clusters.back()];
174
1.51k
    counts[clusters.back()] = 0;
175
1.51k
    remap[clusters.back()] = clusters[clusters.size() - 2];
176
1.51k
    clusters.pop_back();
177
1.51k
  }
178
1.93k
  for (size_t i = 0; i < remap.size(); i++) {
179
3.05k
    while (remap[remap[i]] != remap[i]) {
180
1.26k
      remap[i] = remap[remap[i]];
181
1.26k
    }
182
1.79k
  }
183
  // Relabel starting from 0.
184
137
  std::vector<uint8_t> remap_remap(remap.size(), remap.size());
185
137
  size_t num = 0;
186
1.93k
  for (size_t i = 0; i < remap.size(); i++) {
187
1.79k
    if (remap_remap[remap[i]] == remap.size()) {
188
279
      remap_remap[remap[i]] = num++;
189
279
    }
190
1.79k
    remap[i] = remap_remap[remap[i]];
191
1.79k
  }
192
  // Write the block context map.
193
137
  auto& ctx_map = block_ctx_map->ctx_map;
194
137
  ctx_map = remap;
195
137
  ctx_map.resize(remap.size() * 3);
196
  // for chroma, only use up to nb_clusters_chroma separate block contexts
197
  // (those for the biggest clusters)
198
3.72k
  for (size_t i = remap.size(); i < remap.size() * 3; i++) {
199
3.58k
    ctx_map[i] = num + Clamp1(static_cast<int>(remap[i % remap.size()]), 0,
200
3.58k
                              static_cast<int>(nb_clusters_chroma) - 1);
201
3.58k
  }
202
137
  block_ctx_map->num_ctxs =
203
137
      *std::max_element(ctx_map.begin(), ctx_map.end()) + 1;
204
137
}
205
206
namespace {
207
208
Status FindBestDequantMatrices(JxlMemoryManager* memory_manager,
209
                               const CompressParams& cparams,
210
                               ModularFrameEncoder* modular_frame_encoder,
211
186
                               DequantMatrices* dequant_matrices) {
212
  // TODO(veluca): quant matrices for no-gaborish.
213
  // TODO(veluca): heuristics for in-bitstream quant tables.
214
186
  *dequant_matrices = DequantMatrices();
215
186
  if (cparams.max_error_mode || cparams.disable_perceptual_optimizations) {
216
0
    constexpr float kMSEWeights[3] = {0.001, 0.001, 0.001};
217
0
    const float* wp = cparams.disable_perceptual_optimizations
218
0
                          ? kMSEWeights
219
0
                          : cparams.max_error;
220
    // Set numerators of all quantization matrices to constant values.
221
0
    float weights[3][1] = {{1.0f / wp[0]}, {1.0f / wp[1]}, {1.0f / wp[2]}};
222
0
    DctQuantWeightParams dct_params(weights);
223
0
    std::vector<QuantEncoding> encodings(kNumQuantTables,
224
0
                                         QuantEncoding::DCT(dct_params));
225
0
    JXL_RETURN_IF_ERROR(DequantMatricesSetCustom(dequant_matrices, encodings,
226
0
                                                 modular_frame_encoder));
227
0
    float dc_weights[3] = {1.0f / wp[0], 1.0f / wp[1], 1.0f / wp[2]};
228
0
    JXL_RETURN_IF_ERROR(DequantMatricesSetCustomDC(
229
0
        memory_manager, dequant_matrices, dc_weights));
230
0
  }
231
186
  return true;
232
186
}
233
234
0
void StoreMin2(const float v, float& min1, float& min2) {
235
0
  if (v < min2) {
236
0
    if (v < min1) {
237
0
      min2 = min1;
238
0
      min1 = v;
239
0
    } else {
240
0
      min2 = v;
241
0
    }
242
0
  }
243
0
}
244
245
0
void CreateMask(const ImageF& image, ImageF& mask) {
246
0
  for (size_t y = 0; y < image.ysize(); y++) {
247
0
    const auto* row_n = y > 0 ? image.Row(y - 1) : image.Row(y);
248
0
    const auto* row_in = image.Row(y);
249
0
    const auto* row_s = y + 1 < image.ysize() ? image.Row(y + 1) : image.Row(y);
250
0
    auto* row_out = mask.Row(y);
251
0
    for (size_t x = 0; x < image.xsize(); x++) {
252
      // Center, west, east, north, south values and their absolute difference
253
0
      float c = row_in[x];
254
0
      float w = x > 0 ? row_in[x - 1] : row_in[x];
255
0
      float e = x + 1 < image.xsize() ? row_in[x + 1] : row_in[x];
256
0
      float n = row_n[x];
257
0
      float s = row_s[x];
258
0
      float dw = std::abs(c - w);
259
0
      float de = std::abs(c - e);
260
0
      float dn = std::abs(c - n);
261
0
      float ds = std::abs(c - s);
262
0
      float min = std::numeric_limits<float>::max();
263
0
      float min2 = std::numeric_limits<float>::max();
264
0
      StoreMin2(dw, min, min2);
265
0
      StoreMin2(de, min, min2);
266
0
      StoreMin2(dn, min, min2);
267
0
      StoreMin2(ds, min, min2);
268
0
      row_out[x] = min2;
269
0
    }
270
0
  }
271
0
}
272
273
// Downsamples the image by a factor of 2 with a kernel that's sharper than
274
// the standard 2x2 box kernel used by DownsampleImage.
275
// The kernel is optimized against the result of the 2x2 upsampling kernel used
276
// by the decoder. Ringing is slightly reduced by clamping the values of the
277
// resulting pixels within certain bounds of a small region in the original
278
// image.
279
0
Status DownsampleImage2_Sharper(const ImageF& input, ImageF* output) {
280
0
  const int64_t kernelx = 12;
281
0
  const int64_t kernely = 12;
282
0
  JxlMemoryManager* memory_manager = input.memory_manager();
283
284
0
  static const float kernel[144] = {
285
0
      -0.000314256996835, -0.000314256996835, -0.000897597057705,
286
0
      -0.000562751488849, -0.000176807273646, 0.001864627368902,
287
0
      0.001864627368902,  -0.000176807273646, -0.000562751488849,
288
0
      -0.000897597057705, -0.000314256996835, -0.000314256996835,
289
0
      -0.000314256996835, -0.001527942804748, -0.000121760530512,
290
0
      0.000191123989093,  0.010193185932466,  0.058637519197110,
291
0
      0.058637519197110,  0.010193185932466,  0.000191123989093,
292
0
      -0.000121760530512, -0.001527942804748, -0.000314256996835,
293
0
      -0.000897597057705, -0.000121760530512, 0.000946363683751,
294
0
      0.007113577630288,  0.000437956841058,  -0.000372823835211,
295
0
      -0.000372823835211, 0.000437956841058,  0.007113577630288,
296
0
      0.000946363683751,  -0.000121760530512, -0.000897597057705,
297
0
      -0.000562751488849, 0.000191123989093,  0.007113577630288,
298
0
      0.044592622228814,  0.000222278879007,  -0.162864473015945,
299
0
      -0.162864473015945, 0.000222278879007,  0.044592622228814,
300
0
      0.007113577630288,  0.000191123989093,  -0.000562751488849,
301
0
      -0.000176807273646, 0.010193185932466,  0.000437956841058,
302
0
      0.000222278879007,  -0.000913092543974, -0.017071696107902,
303
0
      -0.017071696107902, -0.000913092543974, 0.000222278879007,
304
0
      0.000437956841058,  0.010193185932466,  -0.000176807273646,
305
0
      0.001864627368902,  0.058637519197110,  -0.000372823835211,
306
0
      -0.162864473015945, -0.017071696107902, 0.414660099370354,
307
0
      0.414660099370354,  -0.017071696107902, -0.162864473015945,
308
0
      -0.000372823835211, 0.058637519197110,  0.001864627368902,
309
0
      0.001864627368902,  0.058637519197110,  -0.000372823835211,
310
0
      -0.162864473015945, -0.017071696107902, 0.414660099370354,
311
0
      0.414660099370354,  -0.017071696107902, -0.162864473015945,
312
0
      -0.000372823835211, 0.058637519197110,  0.001864627368902,
313
0
      -0.000176807273646, 0.010193185932466,  0.000437956841058,
314
0
      0.000222278879007,  -0.000913092543974, -0.017071696107902,
315
0
      -0.017071696107902, -0.000913092543974, 0.000222278879007,
316
0
      0.000437956841058,  0.010193185932466,  -0.000176807273646,
317
0
      -0.000562751488849, 0.000191123989093,  0.007113577630288,
318
0
      0.044592622228814,  0.000222278879007,  -0.162864473015945,
319
0
      -0.162864473015945, 0.000222278879007,  0.044592622228814,
320
0
      0.007113577630288,  0.000191123989093,  -0.000562751488849,
321
0
      -0.000897597057705, -0.000121760530512, 0.000946363683751,
322
0
      0.007113577630288,  0.000437956841058,  -0.000372823835211,
323
0
      -0.000372823835211, 0.000437956841058,  0.007113577630288,
324
0
      0.000946363683751,  -0.000121760530512, -0.000897597057705,
325
0
      -0.000314256996835, -0.001527942804748, -0.000121760530512,
326
0
      0.000191123989093,  0.010193185932466,  0.058637519197110,
327
0
      0.058637519197110,  0.010193185932466,  0.000191123989093,
328
0
      -0.000121760530512, -0.001527942804748, -0.000314256996835,
329
0
      -0.000314256996835, -0.000314256996835, -0.000897597057705,
330
0
      -0.000562751488849, -0.000176807273646, 0.001864627368902,
331
0
      0.001864627368902,  -0.000176807273646, -0.000562751488849,
332
0
      -0.000897597057705, -0.000314256996835, -0.000314256996835};
333
334
0
  int64_t xsize = input.xsize();
335
0
  int64_t ysize = input.ysize();
336
337
0
  JXL_ASSIGN_OR_RETURN(ImageF box_downsample,
338
0
                       ImageF::Create(memory_manager, xsize, ysize));
339
0
  JXL_RETURN_IF_ERROR(CopyImageTo(input, &box_downsample));
340
0
  JXL_ASSIGN_OR_RETURN(box_downsample, DownsampleImage(box_downsample, 2));
341
342
0
  JXL_ASSIGN_OR_RETURN(ImageF mask,
343
0
                       ImageF::Create(memory_manager, box_downsample.xsize(),
344
0
                                      box_downsample.ysize()));
345
0
  CreateMask(box_downsample, mask);
346
347
0
  for (size_t y = 0; y < output->ysize(); y++) {
348
0
    float* row_out = output->Row(y);
349
0
    const float* row_in[kernely];
350
0
    const float* row_mask = mask.Row(y);
351
    // get the rows in the support
352
0
    for (size_t ky = 0; ky < kernely; ky++) {
353
0
      int64_t iy = y * 2 + ky - (kernely - 1) / 2;
354
0
      if (iy < 0) iy = 0;
355
0
      if (iy >= ysize) iy = ysize - 1;
356
0
      row_in[ky] = input.Row(iy);
357
0
    }
358
359
0
    for (size_t x = 0; x < output->xsize(); x++) {
360
      // get min and max values of the original image in the support
361
0
      float min = std::numeric_limits<float>::max();
362
0
      float max = std::numeric_limits<float>::min();
363
      // kernelx - R and kernely - R are the radius of a rectangular region in
364
      // which the values of a pixel are bounded to reduce ringing.
365
0
      static constexpr int64_t R = 5;
366
0
      for (int64_t ky = R; ky + R < kernely; ky++) {
367
0
        for (int64_t kx = R; kx + R < kernelx; kx++) {
368
0
          int64_t ix = x * 2 + kx - (kernelx - 1) / 2;
369
0
          if (ix < 0) ix = 0;
370
0
          if (ix >= xsize) ix = xsize - 1;
371
0
          min = std::min<float>(min, row_in[ky][ix]);
372
0
          max = std::max<float>(max, row_in[ky][ix]);
373
0
        }
374
0
      }
375
376
0
      float sum = 0;
377
0
      for (int64_t ky = 0; ky < kernely; ky++) {
378
0
        for (int64_t kx = 0; kx < kernelx; kx++) {
379
0
          int64_t ix = x * 2 + kx - (kernelx - 1) / 2;
380
0
          if (ix < 0) ix = 0;
381
0
          if (ix >= xsize) ix = xsize - 1;
382
0
          sum += row_in[ky][ix] * kernel[ky * kernelx + kx];
383
0
        }
384
0
      }
385
386
0
      row_out[x] = sum;
387
388
      // Clamp the pixel within the value  of a small area to prevent ringning.
389
      // The mask determines how much to clamp, clamp more to reduce more
390
      // ringing in smooth areas, clamp less in noisy areas to get more
391
      // sharpness. Higher mask_multiplier gives less clamping, so less
392
      // ringing reduction.
393
0
      const constexpr float mask_multiplier = 1;
394
0
      float a = row_mask[x] * mask_multiplier;
395
0
      float clip_min = min - a;
396
0
      float clip_max = max + a;
397
0
      if (row_out[x] < clip_min) {
398
0
        row_out[x] = clip_min;
399
0
      } else if (row_out[x] > clip_max) {
400
0
        row_out[x] = clip_max;
401
0
      }
402
0
    }
403
0
  }
404
0
  return true;
405
0
}
406
407
}  // namespace
408
409
0
Status DownsampleImage2_Sharper(Image3F* opsin) {
410
  // Allocate extra space to avoid a reallocation when padding.
411
0
  JxlMemoryManager* memory_manager = opsin->memory_manager();
412
0
  JXL_ASSIGN_OR_RETURN(
413
0
      Image3F downsampled,
414
0
      Image3F::Create(memory_manager, DivCeil(opsin->xsize(), 2) + kBlockDim,
415
0
                      DivCeil(opsin->ysize(), 2) + kBlockDim));
416
0
  JXL_RETURN_IF_ERROR(downsampled.ShrinkTo(downsampled.xsize() - kBlockDim,
417
0
                                           downsampled.ysize() - kBlockDim));
418
419
0
  for (size_t c = 0; c < 3; c++) {
420
0
    JXL_RETURN_IF_ERROR(
421
0
        DownsampleImage2_Sharper(opsin->Plane(c), &downsampled.Plane(c)));
422
0
  }
423
0
  *opsin = std::move(downsampled);
424
0
  return true;
425
0
}
426
427
namespace {
428
429
// The default upsampling kernels used by Upsampler in the decoder.
430
const constexpr int64_t kSize = 5;
431
432
const float kernel00[25] = {
433
    -0.01716200f, -0.03452303f, -0.04022174f, -0.02921014f, -0.00624645f,
434
    -0.03452303f, 0.14111091f,  0.28896755f,  0.00278718f,  -0.01610267f,
435
    -0.04022174f, 0.28896755f,  0.56661550f,  0.03777607f,  -0.01986694f,
436
    -0.02921014f, 0.00278718f,  0.03777607f,  -0.03144731f, -0.01185068f,
437
    -0.00624645f, -0.01610267f, -0.01986694f, -0.01185068f, -0.00213539f,
438
};
439
const float kernel01[25] = {
440
    -0.00624645f, -0.01610267f, -0.01986694f, -0.01185068f, -0.00213539f,
441
    -0.02921014f, 0.00278718f,  0.03777607f,  -0.03144731f, -0.01185068f,
442
    -0.04022174f, 0.28896755f,  0.56661550f,  0.03777607f,  -0.01986694f,
443
    -0.03452303f, 0.14111091f,  0.28896755f,  0.00278718f,  -0.01610267f,
444
    -0.01716200f, -0.03452303f, -0.04022174f, -0.02921014f, -0.00624645f,
445
};
446
const float kernel10[25] = {
447
    -0.00624645f, -0.02921014f, -0.04022174f, -0.03452303f, -0.01716200f,
448
    -0.01610267f, 0.00278718f,  0.28896755f,  0.14111091f,  -0.03452303f,
449
    -0.01986694f, 0.03777607f,  0.56661550f,  0.28896755f,  -0.04022174f,
450
    -0.01185068f, -0.03144731f, 0.03777607f,  0.00278718f,  -0.02921014f,
451
    -0.00213539f, -0.01185068f, -0.01986694f, -0.01610267f, -0.00624645f,
452
};
453
const float kernel11[25] = {
454
    -0.00213539f, -0.01185068f, -0.01986694f, -0.01610267f, -0.00624645f,
455
    -0.01185068f, -0.03144731f, 0.03777607f,  0.00278718f,  -0.02921014f,
456
    -0.01986694f, 0.03777607f,  0.56661550f,  0.28896755f,  -0.04022174f,
457
    -0.01610267f, 0.00278718f,  0.28896755f,  0.14111091f,  -0.03452303f,
458
    -0.00624645f, -0.02921014f, -0.04022174f, -0.03452303f, -0.01716200f,
459
};
460
461
// Does exactly the same as the Upsampler in dec_upsampler for 2x2 pixels, with
462
// default CustomTransformData.
463
// TODO(lode): use Upsampler instead. However, it requires pre-initialization
464
// and padding on the left side of the image which requires refactoring the
465
// other code using this.
466
0
void UpsampleImage(const ImageF& input, ImageF* output) {
467
0
  int64_t xsize = input.xsize();
468
0
  int64_t ysize = input.ysize();
469
0
  int64_t xsize2 = output->xsize();
470
0
  int64_t ysize2 = output->ysize();
471
0
  for (int64_t y = 0; y < ysize2; y++) {
472
0
    for (int64_t x = 0; x < xsize2; x++) {
473
0
      const auto* kernel = kernel00;
474
0
      if ((x & 1) && (y & 1)) {
475
0
        kernel = kernel11;
476
0
      } else if (x & 1) {
477
0
        kernel = kernel10;
478
0
      } else if (y & 1) {
479
0
        kernel = kernel01;
480
0
      }
481
0
      float sum = 0;
482
0
      int64_t x2 = x / 2;
483
0
      int64_t y2 = y / 2;
484
485
      // get min and max values of the original image in the support
486
0
      float min = std::numeric_limits<float>::max();
487
0
      float max = std::numeric_limits<float>::min();
488
489
0
      for (int64_t ky = 0; ky < kSize; ky++) {
490
0
        for (int64_t kx = 0; kx < kSize; kx++) {
491
0
          int64_t xi = x2 - kSize / 2 + kx;
492
0
          int64_t yi = y2 - kSize / 2 + ky;
493
0
          if (xi < 0) xi = 0;
494
0
          if (xi >= xsize) xi = input.xsize() - 1;
495
0
          if (yi < 0) yi = 0;
496
0
          if (yi >= ysize) yi = input.ysize() - 1;
497
0
          min = std::min<float>(min, input.Row(yi)[xi]);
498
0
          max = std::max<float>(max, input.Row(yi)[xi]);
499
0
        }
500
0
      }
501
502
0
      for (int64_t ky = 0; ky < kSize; ky++) {
503
0
        for (int64_t kx = 0; kx < kSize; kx++) {
504
0
          int64_t xi = x2 - kSize / 2 + kx;
505
0
          int64_t yi = y2 - kSize / 2 + ky;
506
0
          if (xi < 0) xi = 0;
507
0
          if (xi >= xsize) xi = input.xsize() - 1;
508
0
          if (yi < 0) yi = 0;
509
0
          if (yi >= ysize) yi = input.ysize() - 1;
510
0
          sum += input.Row(yi)[xi] * kernel[ky * kSize + kx];
511
0
        }
512
0
      }
513
0
      output->Row(y)[x] = sum;
514
0
      if (output->Row(y)[x] < min) output->Row(y)[x] = min;
515
0
      if (output->Row(y)[x] > max) output->Row(y)[x] = max;
516
0
    }
517
0
  }
518
0
}
519
520
// Returns the derivative of Upsampler, with respect to input pixel x2, y2, to
521
// output pixel x, y (ignoring the clamping).
522
0
float UpsamplerDeriv(int64_t x2, int64_t y2, int64_t x, int64_t y) {
523
0
  const auto* kernel = kernel00;
524
0
  if ((x & 1) && (y & 1)) {
525
0
    kernel = kernel11;
526
0
  } else if (x & 1) {
527
0
    kernel = kernel10;
528
0
  } else if (y & 1) {
529
0
    kernel = kernel01;
530
0
  }
531
532
0
  int64_t ix = x / 2;
533
0
  int64_t iy = y / 2;
534
0
  int64_t kx = x2 - ix + kSize / 2;
535
0
  int64_t ky = y2 - iy + kSize / 2;
536
537
  // This should not happen.
538
0
  if (kx < 0 || kx >= kSize || ky < 0 || ky >= kSize) return 0;
539
540
0
  return kernel[ky * kSize + kx];
541
0
}
542
543
// Apply the derivative of the Upsampler to the input, reversing the effect of
544
// its coefficients. The output image is 2x2 times smaller than the input.
545
0
void AntiUpsample(const ImageF& input, ImageF* d) {
546
0
  int64_t xsize = input.xsize();
547
0
  int64_t ysize = input.ysize();
548
0
  int64_t xsize2 = d->xsize();
549
0
  int64_t ysize2 = d->ysize();
550
0
  int64_t k0 = kSize - 1;
551
0
  int64_t k1 = kSize;
552
0
  for (int64_t y2 = 0; y2 < ysize2; ++y2) {
553
0
    auto* row = d->Row(y2);
554
0
    for (int64_t x2 = 0; x2 < xsize2; ++x2) {
555
0
      int64_t x0 = x2 * 2 - k0;
556
0
      if (x0 < 0) x0 = 0;
557
0
      int64_t x1 = x2 * 2 + k1 + 1;
558
0
      if (x1 > xsize) x1 = xsize;
559
0
      int64_t y0 = y2 * 2 - k0;
560
0
      if (y0 < 0) y0 = 0;
561
0
      int64_t y1 = y2 * 2 + k1 + 1;
562
0
      if (y1 > ysize) y1 = ysize;
563
564
0
      float sum = 0;
565
0
      for (int64_t y = y0; y < y1; ++y) {
566
0
        const auto* row_in = input.Row(y);
567
0
        for (int64_t x = x0; x < x1; ++x) {
568
0
          double deriv = UpsamplerDeriv(x2, y2, x, y);
569
0
          sum += deriv * row_in[x];
570
0
        }
571
0
      }
572
0
      row[x2] = sum;
573
0
    }
574
0
  }
575
0
}
576
577
// Element-wise multiplies two images.
578
template <typename T>
579
Status ElwiseMul(const Plane<T>& image1, const Plane<T>& image2,
580
0
                 Plane<T>* out) {
581
0
  const size_t xsize = image1.xsize();
582
0
  const size_t ysize = image1.ysize();
583
0
  JXL_ENSURE(xsize == image2.xsize());
584
0
  JXL_ENSURE(ysize == image2.ysize());
585
0
  JXL_ENSURE(xsize == out->xsize());
586
0
  JXL_ENSURE(ysize == out->ysize());
587
0
  for (size_t y = 0; y < ysize; ++y) {
588
0
    const T* const JXL_RESTRICT row1 = image1.Row(y);
589
0
    const T* const JXL_RESTRICT row2 = image2.Row(y);
590
0
    T* const JXL_RESTRICT row_out = out->Row(y);
591
0
    for (size_t x = 0; x < xsize; ++x) {
592
0
      row_out[x] = row1[x] * row2[x];
593
0
    }
594
0
  }
595
0
  return true;
596
0
}
597
598
// Element-wise divides two images.
599
template <typename T>
600
Status ElwiseDiv(const Plane<T>& image1, const Plane<T>& image2,
601
0
                 Plane<T>* out) {
602
0
  const size_t xsize = image1.xsize();
603
0
  const size_t ysize = image1.ysize();
604
0
  JXL_ENSURE(xsize == image2.xsize());
605
0
  JXL_ENSURE(ysize == image2.ysize());
606
0
  JXL_ENSURE(xsize == out->xsize());
607
0
  JXL_ENSURE(ysize == out->ysize());
608
0
  for (size_t y = 0; y < ysize; ++y) {
609
0
    const T* const JXL_RESTRICT row1 = image1.Row(y);
610
0
    const T* const JXL_RESTRICT row2 = image2.Row(y);
611
0
    T* const JXL_RESTRICT row_out = out->Row(y);
612
0
    for (size_t x = 0; x < xsize; ++x) {
613
0
      row_out[x] = row1[x] / row2[x];
614
0
    }
615
0
  }
616
0
  return true;
617
0
}
618
619
0
void ReduceRinging(const ImageF& initial, const ImageF& mask, ImageF& down) {
620
0
  int64_t xsize2 = down.xsize();
621
0
  int64_t ysize2 = down.ysize();
622
623
0
  for (size_t y = 0; y < down.ysize(); y++) {
624
0
    const float* row_mask = mask.Row(y);
625
0
    float* row_out = down.Row(y);
626
0
    for (size_t x = 0; x < down.xsize(); x++) {
627
0
      float v = down.Row(y)[x];
628
0
      float min = initial.Row(y)[x];
629
0
      float max = initial.Row(y)[x];
630
0
      for (int64_t yi = -1; yi < 2; yi++) {
631
0
        for (int64_t xi = -1; xi < 2; xi++) {
632
0
          int64_t x2 = static_cast<int64_t>(x) + xi;
633
0
          int64_t y2 = static_cast<int64_t>(y) + yi;
634
0
          if (x2 < 0 || y2 < 0 || x2 >= xsize2 || y2 >= ysize2) continue;
635
0
          min = std::min<float>(min, initial.Row(y2)[x2]);
636
0
          max = std::max<float>(max, initial.Row(y2)[x2]);
637
0
        }
638
0
      }
639
640
0
      row_out[x] = v;
641
642
      // Clamp the pixel within the value  of a small area to prevent ringning.
643
      // The mask determines how much to clamp, clamp more to reduce more
644
      // ringing in smooth areas, clamp less in noisy areas to get more
645
      // sharpness. Higher mask_multiplier gives less clamping, so less
646
      // ringing reduction.
647
0
      const constexpr float mask_multiplier = 2;
648
0
      float a = row_mask[x] * mask_multiplier;
649
0
      float clip_min = min - a;
650
0
      float clip_max = max + a;
651
0
      if (row_out[x] < clip_min) row_out[x] = clip_min;
652
0
      if (row_out[x] > clip_max) row_out[x] = clip_max;
653
0
    }
654
0
  }
655
0
}
656
657
// TODO(lode): move this to a separate file enc_downsample.cc
658
0
Status DownsampleImage2_Iterative(const ImageF& orig, ImageF* output) {
659
0
  int64_t xsize = orig.xsize();
660
0
  int64_t ysize = orig.ysize();
661
0
  int64_t xsize2 = DivCeil(orig.xsize(), 2);
662
0
  int64_t ysize2 = DivCeil(orig.ysize(), 2);
663
0
  JxlMemoryManager* memory_manager = orig.memory_manager();
664
665
0
  JXL_ASSIGN_OR_RETURN(ImageF box_downsample,
666
0
                       ImageF::Create(memory_manager, xsize, ysize));
667
0
  JXL_RETURN_IF_ERROR(CopyImageTo(orig, &box_downsample));
668
0
  JXL_ASSIGN_OR_RETURN(box_downsample, DownsampleImage(box_downsample, 2));
669
0
  JXL_ASSIGN_OR_RETURN(ImageF mask,
670
0
                       ImageF::Create(memory_manager, box_downsample.xsize(),
671
0
                                      box_downsample.ysize()));
672
0
  CreateMask(box_downsample, mask);
673
674
0
  JXL_RETURN_IF_ERROR(output->ShrinkTo(xsize2, ysize2));
675
676
  // Initial result image using the sharper downsampling.
677
  // Allocate extra space to avoid a reallocation when padding.
678
0
  JXL_ASSIGN_OR_RETURN(
679
0
      ImageF initial,
680
0
      ImageF::Create(memory_manager, DivCeil(orig.xsize(), 2) + kBlockDim,
681
0
                     DivCeil(orig.ysize(), 2) + kBlockDim));
682
0
  JXL_RETURN_IF_ERROR(initial.ShrinkTo(initial.xsize() - kBlockDim,
683
0
                                       initial.ysize() - kBlockDim));
684
0
  JXL_RETURN_IF_ERROR(DownsampleImage2_Sharper(orig, &initial));
685
686
0
  JXL_ASSIGN_OR_RETURN(
687
0
      ImageF down,
688
0
      ImageF::Create(memory_manager, initial.xsize(), initial.ysize()));
689
0
  JXL_RETURN_IF_ERROR(CopyImageTo(initial, &down));
690
0
  JXL_ASSIGN_OR_RETURN(ImageF up, ImageF::Create(memory_manager, xsize, ysize));
691
0
  JXL_ASSIGN_OR_RETURN(ImageF corr,
692
0
                       ImageF::Create(memory_manager, xsize, ysize));
693
0
  JXL_ASSIGN_OR_RETURN(ImageF corr2,
694
0
                       ImageF::Create(memory_manager, xsize2, ysize2));
695
696
  // In the weights map, relatively higher values will allow less ringing but
697
  // also less sharpness. With all constant values, it optimizes equally
698
  // everywhere. Even in this case, the weights2 computed from
699
  // this is still used and differs at the borders of the image.
700
  // TODO(lode): Make use of the weights field for anti-ringing and clamping,
701
  // the values are all set to 1 for now, but it is intended to be used for
702
  // reducing ringing based on the mask, and taking clamping into account.
703
0
  JXL_ASSIGN_OR_RETURN(ImageF weights,
704
0
                       ImageF::Create(memory_manager, xsize, ysize));
705
0
  for (size_t y = 0; y < weights.ysize(); y++) {
706
0
    auto* row = weights.Row(y);
707
0
    for (size_t x = 0; x < weights.xsize(); x++) {
708
0
      row[x] = 1;
709
0
    }
710
0
  }
711
0
  JXL_ASSIGN_OR_RETURN(ImageF weights2,
712
0
                       ImageF::Create(memory_manager, xsize2, ysize2));
713
0
  AntiUpsample(weights, &weights2);
714
715
0
  const size_t num_it = 3;
716
0
  for (size_t it = 0; it < num_it; ++it) {
717
0
    UpsampleImage(down, &up);
718
0
    JXL_ASSIGN_OR_RETURN(corr, LinComb<float>(1, orig, -1, up));
719
0
    JXL_RETURN_IF_ERROR(ElwiseMul(corr, weights, &corr));
720
0
    AntiUpsample(corr, &corr2);
721
0
    JXL_RETURN_IF_ERROR(ElwiseDiv(corr2, weights2, &corr2));
722
723
0
    JXL_ASSIGN_OR_RETURN(down, LinComb<float>(1, down, 1, corr2));
724
0
  }
725
726
0
  ReduceRinging(initial, mask, down);
727
728
  // can't just use CopyImage, because the output image was prepared with
729
  // padding.
730
0
  for (size_t y = 0; y < down.ysize(); y++) {
731
0
    for (size_t x = 0; x < down.xsize(); x++) {
732
0
      float v = down.Row(y)[x];
733
0
      output->Row(y)[x] = v;
734
0
    }
735
0
  }
736
0
  return true;
737
0
}
738
739
}  // namespace
740
741
0
Status DownsampleImage2_Iterative(Image3F* opsin) {
742
0
  JxlMemoryManager* memory_manager = opsin->memory_manager();
743
  // Allocate extra space to avoid a reallocation when padding.
744
0
  JXL_ASSIGN_OR_RETURN(
745
0
      Image3F downsampled,
746
0
      Image3F::Create(memory_manager, DivCeil(opsin->xsize(), 2) + kBlockDim,
747
0
                      DivCeil(opsin->ysize(), 2) + kBlockDim));
748
0
  JXL_RETURN_IF_ERROR(downsampled.ShrinkTo(downsampled.xsize() - kBlockDim,
749
0
                                           downsampled.ysize() - kBlockDim));
750
751
0
  JXL_ASSIGN_OR_RETURN(
752
0
      Image3F rgb,
753
0
      Image3F::Create(memory_manager, opsin->xsize(), opsin->ysize()));
754
0
  OpsinParams opsin_params;  // TODO(user): use the ones that are actually used
755
0
  opsin_params.Init(kDefaultIntensityTarget);
756
0
  JXL_RETURN_IF_ERROR(
757
0
      OpsinToLinear(*opsin, Rect(rgb), nullptr, &rgb, opsin_params));
758
759
0
  JXL_ASSIGN_OR_RETURN(
760
0
      ImageF mask,
761
0
      ImageF::Create(memory_manager, opsin->xsize(), opsin->ysize()));
762
0
  ButteraugliParams butter_params;
763
0
  JXL_ASSIGN_OR_RETURN(std::unique_ptr<ButteraugliComparator> butter,
764
0
                       ButteraugliComparator::Make(rgb, butter_params));
765
0
  JXL_RETURN_IF_ERROR(butter->Mask(&mask));
766
0
  JXL_ASSIGN_OR_RETURN(
767
0
      ImageF mask_fuzzy,
768
0
      ImageF::Create(memory_manager, opsin->xsize(), opsin->ysize()));
769
770
0
  for (size_t c = 0; c < 3; c++) {
771
0
    JXL_RETURN_IF_ERROR(
772
0
        DownsampleImage2_Iterative(opsin->Plane(c), &downsampled.Plane(c)));
773
0
  }
774
0
  *opsin = std::move(downsampled);
775
0
  return true;
776
0
}
777
778
StatusOr<Image3F> ReconstructImage(
779
    const FrameHeader& orig_frame_header, const PassesSharedState& shared,
780
558
    const std::vector<std::unique_ptr<ACImage>>& coeffs, ThreadPool* pool) {
781
558
  const FrameDimensions& frame_dim = shared.frame_dim;
782
558
  JxlMemoryManager* memory_manager = shared.memory_manager;
783
784
558
  FrameHeader frame_header = orig_frame_header;
785
558
  frame_header.UpdateFlag(shared.image_features.patches.HasAny(),
786
558
                          FrameHeader::kPatches);
787
558
  frame_header.UpdateFlag(shared.image_features.splines.HasAny(),
788
558
                          FrameHeader::kSplines);
789
558
  frame_header.color_transform = ColorTransform::kNone;
790
791
558
  auto metadata = jxl::make_unique<CodecMetadata>();
792
558
  *metadata = *frame_header.nonserialized_metadata;
793
558
  metadata->m.extra_channel_info.clear();
794
558
  metadata->m.num_extra_channels = metadata->m.extra_channel_info.size();
795
558
  frame_header.nonserialized_metadata = metadata.get();
796
558
  frame_header.extra_channel_upsampling.clear();
797
798
558
  const bool is_gray = shared.metadata->m.color_encoding.IsGray();
799
558
  auto dec_state = jxl::make_unique<PassesDecoderState>(memory_manager);
800
558
  JXL_RETURN_IF_ERROR(
801
558
      dec_state->output_encoding_info.SetFromMetadata(*shared.metadata));
802
558
  JXL_RETURN_IF_ERROR(dec_state->output_encoding_info.MaybeSetColorEncoding(
803
558
      ColorEncoding::LinearSRGB(is_gray)));
804
558
  dec_state->shared = &shared;
805
558
  JXL_RETURN_IF_ERROR(dec_state->Init(frame_header));
806
807
558
  ImageBundle decoded(memory_manager, &shared.metadata->m);
808
558
  decoded.origin = frame_header.frame_origin;
809
558
  JXL_ASSIGN_OR_RETURN(
810
558
      Image3F tmp,
811
558
      Image3F::Create(memory_manager, frame_dim.xsize, frame_dim.ysize));
812
558
  JXL_RETURN_IF_ERROR(decoded.SetFromImage(
813
558
      std::move(tmp), dec_state->output_encoding_info.color_encoding));
814
815
558
  PassesDecoderState::PipelineOptions options;
816
558
  options.use_slow_render_pipeline = false;
817
558
  options.coalescing = false;
818
558
  options.render_spotcolors = false;
819
558
  options.render_noise = true;
820
821
558
  JXL_RETURN_IF_ERROR(dec_state->PreparePipeline(
822
558
      frame_header, &shared.metadata->m, &decoded, options));
823
824
558
  AlignedArray<GroupDecCache> group_dec_caches;
825
558
  const auto allocate_storage = [&](const size_t num_threads) -> Status {
826
558
    JXL_RETURN_IF_ERROR(
827
558
        dec_state->render_pipeline->PrepareForThreads(num_threads,
828
558
                                                      /*use_group_ids=*/false));
829
558
    JXL_ASSIGN_OR_RETURN(group_dec_caches, AlignedArray<GroupDecCache>::Create(
830
558
                                               memory_manager, num_threads));
831
558
    return true;
832
558
  };
833
558
  const auto process_group = [&](const uint32_t group_index,
834
1.78k
                                 const size_t thread) -> Status {
835
1.78k
    if (frame_header.loop_filter.epf_iters > 0) {
836
1.78k
      JXL_RETURN_IF_ERROR(ComputeSigma(frame_header.loop_filter,
837
1.78k
                                       frame_dim.BlockGroupRect(group_index),
838
1.78k
                                       dec_state.get()));
839
1.78k
    }
840
1.78k
    RenderPipelineInput input =
841
1.78k
        dec_state->render_pipeline->GetInputBuffers(group_index, thread);
842
1.78k
    JXL_RETURN_IF_ERROR(DecodeGroupForRoundtrip(
843
1.78k
        frame_header, coeffs, group_index, dec_state.get(),
844
1.78k
        &group_dec_caches[thread], thread, input, nullptr, nullptr));
845
1.78k
    if ((frame_header.flags & FrameHeader::kNoise) != 0) {
846
0
      PrepareNoiseInput(*dec_state, shared.frame_dim, frame_header, group_index,
847
0
                        thread);
848
0
    }
849
1.78k
    JXL_RETURN_IF_ERROR(input.Done());
850
1.78k
    return true;
851
1.78k
  };
852
558
  JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, frame_dim.num_groups, allocate_storage,
853
558
                                process_group, "ReconstructImage"));
854
558
  return std::move(*decoded.color());
855
558
}
856
857
  
858
float ComputeBlockL2Distance(const Image3F& a, const Image3F& b,
859
1.19M
                             const ImageF& mask1x1, size_t by, size_t bx) {
860
1.19M
  Rect rect(bx * kBlockDim, by * kBlockDim, kBlockDim, kBlockDim, a.xsize(),
861
1.19M
            a.ysize());
862
1.19M
  float err2[3] = {0.0f};
863
10.7M
  for (size_t y = 0; y < rect.ysize(); ++y) {
864
9.50M
    const float* row_a[3] = {
865
9.50M
        rect.ConstPlaneRow(a, 0, y),
866
9.50M
        rect.ConstPlaneRow(a, 1, y),
867
9.50M
        rect.ConstPlaneRow(a, 2, y),
868
9.50M
    };
869
9.50M
    const float* row_b[3] = {
870
9.50M
        rect.ConstPlaneRow(b, 0, y),
871
9.50M
        rect.ConstPlaneRow(b, 1, y),
872
9.50M
        rect.ConstPlaneRow(b, 2, y),
873
9.50M
    };
874
9.50M
    const float* row_mask = rect.ConstRow(mask1x1, y);
875
84.9M
    for (size_t x = 0; x < rect.xsize(); ++x) {
876
75.4M
      float mask = row_mask[x];
877
75.4M
      float mask2 = mask * mask;
878
301M
      for (int i = 0; i < 3; ++i) {
879
226M
  float diff = row_a[i][x] - row_b[i][x];
880
226M
  err2[i] += mask2 * diff * diff;
881
226M
      }
882
75.4M
    }
883
9.50M
  }
884
1.19M
  static const double kW[] = {
885
1.19M
      12.339445295782363,
886
1.19M
      1.0,
887
1.19M
      0.2,
888
1.19M
  };
889
1.19M
  float retval = kW[0] * err2[0] + kW[1] * err2[1] + kW[2] * err2[2];
890
1.19M
  return retval;
891
1.19M
}
892
893
Status ComputeARHeuristics(const FrameHeader& frame_header,
894
                           PassesEncoderState* enc_state,
895
                           const Image3F& orig_opsin, const Rect& rect,
896
186
                           ThreadPool* pool) {
897
186
  const CompressParams& cparams = enc_state->cparams;
898
186
  PassesSharedState& shared = enc_state->shared;
899
186
  const FrameDimensions& frame_dim = shared.frame_dim;
900
186
  const ImageF& initial_quant_masking1x1 = enc_state->initial_quant_masking1x1;
901
186
  ImageB& epf_sharpness = shared.epf_sharpness;
902
186
  JxlMemoryManager* memory_manager = enc_state->memory_manager();
903
904
186
  float clamped_butteraugli = std::min(5.0f, cparams.butteraugli_distance);
905
186
  if (cparams.butteraugli_distance < kMinButteraugliForDynamicAR ||
906
186
      cparams.speed_tier > SpeedTier::kWombat ||
907
186
      frame_header.loop_filter.epf_iters == 0) {
908
0
    FillPlane(static_cast<uint8_t>(4), &epf_sharpness, Rect(epf_sharpness));
909
0
    return true;
910
0
  }
911
912
186
  std::vector<uint8_t> epf_steps;
913
186
  if (cparams.butteraugli_distance > 4.5f) {
914
0
    epf_steps.push_back(0);
915
0
    epf_steps.push_back(4);
916
186
  } else {
917
186
    epf_steps.push_back(0);
918
186
    epf_steps.push_back(2);
919
186
    epf_steps.push_back(7);
920
186
  }
921
186
  static const int kNumEPFVals = 8;
922
186
  size_t epf_steps_lut[kNumEPFVals] = {0};
923
186
  {
924
744
    for (size_t i = 0; i < epf_steps.size(); ++i) {
925
558
      epf_steps_lut[epf_steps[i]] = i;
926
558
    }
927
186
  }
928
186
  std::array<ImageF, kNumEPFVals> error_images;
929
558
  for (uint8_t val : epf_steps) {
930
558
    FillPlane(val, &epf_sharpness, Rect(epf_sharpness));
931
558
    JXL_ASSIGN_OR_RETURN(
932
558
        Image3F decoded,
933
558
        ReconstructImage(frame_header, shared, enc_state->coeffs, pool));
934
558
    JXL_ASSIGN_OR_RETURN(error_images[val],
935
558
                         ImageF::Create(memory_manager, frame_dim.xsize_blocks,
936
558
                                        frame_dim.ysize_blocks));
937
22.5k
    for (size_t by = 0; by < frame_dim.ysize_blocks; by++) {
938
21.9k
      float* error_row = error_images[val].Row(by);
939
1.21M
      for (size_t bx = 0; bx < frame_dim.xsize_blocks; bx++) {
940
1.19M
        error_row[bx] = ComputeBlockL2Distance(
941
1.19M
      orig_opsin, decoded, initial_quant_masking1x1, by, bx);
942
1.19M
      }
943
21.9k
    }
944
558
  }
945
186
  std::vector<std::vector<size_t>> histo(9, std::vector<size_t>(kNumEPFVals));
946
186
  std::vector<size_t> totals(9, 1);
947
186
  static const float kFavorNoSmoothing = 0.99;
948
7.51k
  for (size_t by = 0; by < frame_dim.ysize_blocks; by++) {
949
7.32k
    uint8_t* JXL_RESTRICT out_row = epf_sharpness.Row(by);
950
7.32k
    uint8_t* JXL_RESTRICT prev_row = epf_sharpness.Row(by > 0 ? by - 1 : 0);
951
405k
    for (size_t bx = 0; bx < frame_dim.xsize_blocks; bx++) {
952
398k
      uint8_t best_val = 0;
953
398k
      float best_error = std::numeric_limits<float>::max();
954
398k
      uint8_t top_val = by > 0 ? prev_row[bx] : 0;
955
398k
      uint8_t left_val = bx > 0 ? out_row[bx - 1] : 0;
956
398k
      float top_error = error_images[top_val].Row(by)[bx];
957
398k
      float left_error = error_images[left_val].Row(by)[bx];
958
1.19M
      for (uint8_t val : epf_steps) {
959
1.19M
        float error = error_images[val].Row(by)[bx];
960
1.19M
        if (val == 0) {
961
398k
    error *= kFavorNoSmoothing;
962
398k
        }
963
1.19M
        if (error < best_error) {
964
632k
          best_val = val;
965
632k
          best_error = error;
966
632k
        }
967
1.19M
      }
968
398k
      if (best_error < std::min(top_error, left_error)) {
969
216k
        out_row[bx] = best_val;
970
216k
      } else if (top_error < left_error) {
971
28.1k
        out_row[bx] = top_val;
972
153k
      } else {
973
153k
        out_row[bx] = left_val;
974
153k
      }
975
398k
      int context = epf_steps_lut[top_val] * 3 + epf_steps_lut[left_val];
976
398k
      ++histo[context][out_row[bx]];
977
398k
      ++totals[context];
978
398k
    }
979
7.32k
  }
980
186
  const float c3base = 0.98017198824148288;
981
186
  const float c3clamp = 0.85970338919928291;
982
186
  const float c3 = std::max(c3clamp, std::pow(c3base, clamped_butteraugli));
983
186
  static const float c5 = 0.1087690359555803;
984
186
  float mul[3 * 3 * 3] = {0};
985
558
  for (uint8_t top_val : epf_steps) {
986
1.67k
    for (uint8_t left_val : epf_steps) {
987
5.02k
      for (uint8_t val : epf_steps) {
988
5.02k
        int context = epf_steps_lut[top_val] * 3 + epf_steps_lut[left_val];
989
5.02k
        const auto& ctx_histo = histo[context];
990
5.02k
        const int mulix = epf_steps_lut[val] + 3 * context;
991
5.02k
        mul[mulix] = 1.0 / (1.0 + c5 * std::log1p(ctx_histo[val] /
992
5.02k
              totals[context]) / clamped_butteraugli);
993
5.02k
        if (val == 0) {
994
1.67k
          mul[mulix] *= c3;
995
1.67k
        }
996
5.02k
      }
997
1.67k
    }
998
558
  }
999
7.51k
  for (size_t by = 0; by < frame_dim.ysize_blocks; by++) {
1000
7.32k
    uint8_t* JXL_RESTRICT out_row = epf_sharpness.Row(by);
1001
7.32k
    uint8_t* JXL_RESTRICT prev_row = epf_sharpness.Row(by > 0 ? by - 1 : 0);
1002
405k
    for (size_t bx = 0; bx < frame_dim.xsize_blocks; bx++) {
1003
398k
      uint8_t best_val = 0;
1004
398k
      float best_error = std::numeric_limits<float>::max();
1005
398k
      uint8_t top_val = by > 0 ? prev_row[bx] : 0;
1006
398k
      uint8_t left_val = bx > 0 ? out_row[bx - 1] : 0;
1007
398k
      int context = epf_steps_lut[top_val] * 3 + epf_steps_lut[left_val];
1008
1.19M
      for (uint8_t val : epf_steps) {
1009
1.19M
        int mulix = epf_steps_lut[val] + 3 * context;
1010
1.19M
        float error = error_images[val].Row(by)[bx] * mul[mulix];
1011
1.19M
        if (error < best_error) {
1012
614k
          best_val = val;
1013
614k
          best_error = error;
1014
614k
        }
1015
1.19M
      }
1016
398k
      out_row[bx] = best_val;
1017
398k
    }
1018
7.32k
  }
1019
1020
186
  return true;
1021
186
}
1022
1023
Status LossyFrameHeuristics(const FrameHeader& frame_header,
1024
                            PassesEncoderState* enc_state,
1025
                            ModularFrameEncoder* modular_frame_encoder,
1026
                            const Image3F* linear, Image3F* opsin,
1027
                            const Rect& rect, const JxlCmsInterface& cms,
1028
186
                            ThreadPool* pool, AuxOut* aux_out) {
1029
186
  const CompressParams& cparams = enc_state->cparams;
1030
186
  const bool streaming_mode = enc_state->streaming_mode;
1031
186
  const bool initialize_global_state = enc_state->initialize_global_state;
1032
186
  PassesSharedState& shared = enc_state->shared;
1033
186
  const FrameDimensions& frame_dim = shared.frame_dim;
1034
186
  ImageFeatures& image_features = shared.image_features;
1035
186
  DequantMatrices& matrices = shared.matrices;
1036
186
  Quantizer& quantizer = shared.quantizer;
1037
186
  ImageF& initial_quant_masking1x1 = enc_state->initial_quant_masking1x1;
1038
186
  ImageI& raw_quant_field = shared.raw_quant_field;
1039
186
  ColorCorrelationMap& cmap = shared.cmap;
1040
186
  AcStrategyImage& ac_strategy = shared.ac_strategy;
1041
186
  BlockCtxMap& block_ctx_map = shared.block_ctx_map;
1042
186
  JxlMemoryManager* memory_manager = enc_state->memory_manager();
1043
1044
  // Find and subtract splines.
1045
186
  if (cparams.custom_splines.HasAny()) {
1046
0
    image_features.splines = cparams.custom_splines;
1047
0
  }
1048
186
  if (!streaming_mode && cparams.speed_tier <= SpeedTier::kSquirrel) {
1049
186
    if (!cparams.custom_splines.HasAny()) {
1050
186
      image_features.splines = FindSplines(*opsin);
1051
186
    }
1052
186
    JXL_RETURN_IF_ERROR(image_features.splines.InitializeDrawCache(
1053
186
        opsin->xsize(), opsin->ysize(), cmap.base()));
1054
186
    image_features.splines.SubtractFrom(opsin);
1055
186
  }
1056
1057
  // Find and subtract patches/dots.
1058
186
  if (!streaming_mode &&
1059
186
      ApplyOverride(cparams.patches,
1060
186
                    cparams.speed_tier <= SpeedTier::kSquirrel)) {
1061
186
    JXL_RETURN_IF_ERROR(
1062
186
        FindBestPatchDictionary(*opsin, enc_state, cms, pool, aux_out));
1063
186
    JXL_RETURN_IF_ERROR(
1064
186
        PatchDictionaryEncoder::SubtractFrom(image_features.patches, opsin));
1065
186
  }
1066
1067
186
  const float quant_dc = InitialQuantDC(cparams.butteraugli_distance);
1068
1069
  // TODO(veluca): we can now run all the code from here to FindBestQuantizer
1070
  // (excluded) one rect at a time. Do that.
1071
1072
  // Dependency graph:
1073
  //
1074
  // input: either XYB or input image
1075
  //
1076
  // input image -> XYB [optional]
1077
  // XYB -> initial quant field
1078
  // XYB -> Gaborished XYB
1079
  // Gaborished XYB -> CfL1
1080
  // initial quant field, Gaborished XYB, CfL1 -> ACS
1081
  // initial quant field, ACS, Gaborished XYB -> EPF control field
1082
  // initial quant field -> adjusted initial quant field
1083
  // adjusted initial quant field, ACS -> raw quant field
1084
  // raw quant field, ACS, Gaborished XYB -> CfL2
1085
  //
1086
  // output: Gaborished XYB, CfL, ACS, raw quant field, EPF control field.
1087
1088
186
  AcStrategyHeuristics acs_heuristics(memory_manager, cparams);
1089
186
  CfLHeuristics cfl_heuristics(memory_manager);
1090
186
  ImageF initial_quant_field;
1091
186
  ImageF initial_quant_masking;
1092
1093
  // Compute an initial estimate of the quantization field.
1094
  // Call InitialQuantField only in Hare mode or slower. Otherwise, rely
1095
  // on simple heuristics in FindBestAcStrategy, or set a constant for Falcon
1096
  // mode.
1097
186
  if (cparams.speed_tier > SpeedTier::kHare ||
1098
186
      cparams.disable_perceptual_optimizations) {
1099
0
    JXL_ASSIGN_OR_RETURN(initial_quant_field,
1100
0
                         ImageF::Create(memory_manager, frame_dim.xsize_blocks,
1101
0
                                        frame_dim.ysize_blocks));
1102
0
    JXL_ASSIGN_OR_RETURN(initial_quant_masking,
1103
0
                         ImageF::Create(memory_manager, frame_dim.xsize_blocks,
1104
0
                                        frame_dim.ysize_blocks));
1105
0
    float q = 0.79 / cparams.butteraugli_distance;
1106
0
    FillImage(q, &initial_quant_field);
1107
0
    float masking = 1.0f / (q + 0.001f);
1108
0
    FillImage(masking, &initial_quant_masking);
1109
0
    if (cparams.disable_perceptual_optimizations) {
1110
0
      JXL_ASSIGN_OR_RETURN(
1111
0
          initial_quant_masking1x1,
1112
0
          ImageF::Create(memory_manager, frame_dim.xsize, frame_dim.ysize));
1113
0
      FillImage(masking, &initial_quant_masking1x1);
1114
0
    }
1115
0
    quantizer.ComputeGlobalScaleAndQuant(quant_dc, q, 0);
1116
186
  } else {
1117
    // Call this here, as it relies on pre-gaborish values.
1118
186
    float butteraugli_distance_for_iqf = cparams.butteraugli_distance;
1119
186
    if (!frame_header.loop_filter.gab) {
1120
0
      butteraugli_distance_for_iqf *= 0.62f;
1121
0
    }
1122
186
    JXL_ASSIGN_OR_RETURN(
1123
186
        initial_quant_field,
1124
186
        InitialQuantField(butteraugli_distance_for_iqf, *opsin, rect, pool,
1125
186
                          1.0f, &initial_quant_masking,
1126
186
                          &initial_quant_masking1x1));
1127
186
    float q = 0.39 / cparams.butteraugli_distance;
1128
186
    quantizer.ComputeGlobalScaleAndQuant(quant_dc, q, 0);
1129
186
  }
1130
1131
  // TODO(veluca): do something about animations.
1132
1133
  // Apply inverse-gaborish.
1134
186
  if (frame_header.loop_filter.gab) {
1135
    // Changing the weight here to 0.99f would help to reduce ringing in
1136
    // generation loss.
1137
186
    float weight[3] = {
1138
186
        1.0f,
1139
186
        1.0f,
1140
186
        1.0f,
1141
186
    };
1142
186
    JXL_RETURN_IF_ERROR(GaborishInverse(opsin, rect, weight, pool));
1143
186
  }
1144
1145
186
  if (initialize_global_state) {
1146
186
    JXL_RETURN_IF_ERROR(FindBestDequantMatrices(
1147
186
        memory_manager, cparams, modular_frame_encoder, &matrices));
1148
186
  }
1149
1150
186
  JXL_RETURN_IF_ERROR(cfl_heuristics.Init(rect));
1151
186
  JXL_RETURN_IF_ERROR(acs_heuristics.Init(*opsin, rect, initial_quant_field,
1152
186
                                          initial_quant_masking,
1153
186
                                          initial_quant_masking1x1, &matrices));
1154
1155
7.22k
  auto process_tile = [&](const uint32_t tid, const size_t thread) -> Status {
1156
7.22k
    size_t n_enc_tiles = DivCeil(frame_dim.xsize_blocks, kEncTileDimInBlocks);
1157
7.22k
    size_t tx = tid % n_enc_tiles;
1158
7.22k
    size_t ty = tid / n_enc_tiles;
1159
7.22k
    size_t by0 = ty * kEncTileDimInBlocks;
1160
7.22k
    size_t by1 =
1161
7.22k
        std::min((ty + 1) * kEncTileDimInBlocks, frame_dim.ysize_blocks);
1162
7.22k
    size_t bx0 = tx * kEncTileDimInBlocks;
1163
7.22k
    size_t bx1 =
1164
7.22k
        std::min((tx + 1) * kEncTileDimInBlocks, frame_dim.xsize_blocks);
1165
7.22k
    Rect r(bx0, by0, bx1 - bx0, by1 - by0);
1166
1167
    // For speeds up to Wombat, we only compute the color correlation map
1168
    // once we know the transform type and the quantization map.
1169
7.22k
    if (cparams.speed_tier <= SpeedTier::kSquirrel) {
1170
7.22k
      JXL_RETURN_IF_ERROR(cfl_heuristics.ComputeTile(
1171
7.22k
          r, *opsin, rect, matrices,
1172
7.22k
          /*ac_strategy=*/nullptr,
1173
7.22k
          /*raw_quant_field=*/nullptr,
1174
7.22k
          /*quantizer=*/nullptr, /*fast=*/false, thread, &cmap));
1175
7.22k
    }
1176
1177
    // Choose block sizes.
1178
7.22k
    JXL_RETURN_IF_ERROR(
1179
7.22k
        acs_heuristics.ProcessRect(r, cmap, &ac_strategy, thread));
1180
1181
    // Always set the initial quant field, so we can compute the CfL map with
1182
    // more accuracy. The initial quant field might change in slower modes, but
1183
    // adjusting the quant field with butteraugli when all the other encoding
1184
    // parameters are fixed is likely a more reliable choice anyway.
1185
7.22k
    JXL_RETURN_IF_ERROR(AdjustQuantField(
1186
7.22k
        ac_strategy, r, cparams.butteraugli_distance, &initial_quant_field));
1187
7.22k
    quantizer.SetQuantFieldRect(initial_quant_field, r, &raw_quant_field);
1188
1189
    // Compute a non-default CfL map if we are at Hare speed, or slower.
1190
7.22k
    if (cparams.speed_tier <= SpeedTier::kHare) {
1191
7.22k
      JXL_RETURN_IF_ERROR(cfl_heuristics.ComputeTile(
1192
7.22k
          r, *opsin, rect, matrices, &ac_strategy, &raw_quant_field, &quantizer,
1193
7.22k
          /*fast=*/cparams.speed_tier >= SpeedTier::kWombat, thread, &cmap));
1194
7.22k
    }
1195
7.22k
    return true;
1196
7.22k
  };
1197
186
  size_t num_tiles = DivCeil(frame_dim.xsize_blocks, kEncTileDimInBlocks) *
1198
186
                     DivCeil(frame_dim.ysize_blocks, kEncTileDimInBlocks);
1199
186
  const auto prepare = [&](const size_t num_threads) -> Status {
1200
186
    JXL_RETURN_IF_ERROR(acs_heuristics.PrepareForThreads(num_threads));
1201
186
    JXL_RETURN_IF_ERROR(cfl_heuristics.PrepareForThreads(num_threads));
1202
186
    return true;
1203
186
  };
1204
186
  JXL_RETURN_IF_ERROR(
1205
186
      RunOnPool(pool, 0, num_tiles, prepare, process_tile, "Enc Heuristics"));
1206
1207
186
  JXL_RETURN_IF_ERROR(acs_heuristics.Finalize(frame_dim, ac_strategy, aux_out));
1208
1209
  // Refine quantization levels.
1210
186
  if (!streaming_mode && !cparams.disable_perceptual_optimizations) {
1211
186
    ImageB& epf_sharpness = shared.epf_sharpness;
1212
186
    FillPlane(static_cast<uint8_t>(4), &epf_sharpness, Rect(epf_sharpness));
1213
186
    JXL_RETURN_IF_ERROR(FindBestQuantizer(frame_header, linear, *opsin,
1214
186
                                          initial_quant_field, enc_state, cms,
1215
186
                                          pool, aux_out));
1216
186
  }
1217
1218
  // Choose a context model that depends on the amount of quantization for AC.
1219
186
  if (cparams.speed_tier < SpeedTier::kFalcon && initialize_global_state) {
1220
186
    FindBestBlockEntropyModel(cparams, raw_quant_field, ac_strategy,
1221
186
                              &block_ctx_map);
1222
186
  }
1223
186
  return true;
1224
186
}
1225
1226
}  // namespace jxl