Coverage Report

Created: 2026-06-07 07:20

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/libjxl/lib/jxl/modular/encoding/enc_encoding.cc
Line
Count
Source
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#include <jxl/memory_manager.h>
7
8
#include <algorithm>
9
#include <array>
10
#include <cstddef>
11
#include <cstdint>
12
#include <cstdlib>
13
#include <limits>
14
#include <queue>
15
#include <utility>
16
#include <vector>
17
18
#include "lib/jxl/base/bits.h"
19
#include "lib/jxl/base/common.h"
20
#include "lib/jxl/base/compiler_specific.h"
21
#include "lib/jxl/base/printf_macros.h"
22
#include "lib/jxl/base/status.h"
23
#include "lib/jxl/enc_ans.h"
24
#include "lib/jxl/enc_ans_params.h"
25
#include "lib/jxl/enc_aux_out.h"
26
#include "lib/jxl/enc_bit_writer.h"
27
#include "lib/jxl/enc_fields.h"
28
#include "lib/jxl/fields.h"
29
#include "lib/jxl/image.h"
30
#include "lib/jxl/image_ops.h"
31
#include "lib/jxl/modular/encoding/context_predict.h"
32
#include "lib/jxl/modular/encoding/dec_ma.h"
33
#include "lib/jxl/modular/encoding/enc_ma.h"
34
#include "lib/jxl/modular/encoding/encoding.h"
35
#include "lib/jxl/modular/encoding/ma_common.h"
36
#include "lib/jxl/modular/modular_image.h"
37
#include "lib/jxl/modular/options.h"
38
#include "lib/jxl/pack_signed.h"
39
40
namespace jxl {
41
42
namespace {
43
// Plot tree (if enabled) and predictor usage map.
44
constexpr bool kWantDebug = true;
45
// constexpr bool kPrintTree = false;
46
47
296M
inline std::array<uint8_t, 3> PredictorColor(Predictor p) {
48
296M
  switch (p) {
49
20.1M
    case Predictor::Zero:
50
20.1M
      return {{0, 0, 0}};
51
6.27M
    case Predictor::Left:
52
6.27M
      return {{255, 0, 0}};
53
0
    case Predictor::Top:
54
0
      return {{0, 255, 0}};
55
0
    case Predictor::Average0:
56
0
      return {{0, 0, 255}};
57
0
    case Predictor::Average4:
58
0
      return {{192, 128, 128}};
59
0
    case Predictor::Select:
60
0
      return {{255, 255, 0}};
61
270M
    case Predictor::Gradient:
62
270M
      return {{255, 0, 255}};
63
24.6k
    case Predictor::Weighted:
64
24.6k
      return {{0, 255, 255}};
65
      // TODO(jon)
66
0
    default:
67
0
      return {{255, 255, 255}};
68
296M
  };
69
0
}
70
71
// `cutoffs` must be sorted.
72
Tree MakeFixedTree(int property, const std::vector<int32_t> &cutoffs,
73
2.75k
                   Predictor pred, size_t num_pixels, int bitdepth) {
74
2.75k
  size_t log_px = CeilLog2Nonzero(num_pixels);
75
2.75k
  size_t min_gap = 0;
76
  // Reduce fixed tree height when encoding small images.
77
2.75k
  if (log_px < 14) {
78
2.22k
    min_gap = 8 * (14 - log_px);
79
2.22k
  }
80
2.75k
  const int shift = bitdepth > 11 ? std::min(4, bitdepth - 11) : 0;
81
2.75k
  const int mul = 1 << shift;
82
2.75k
  Tree tree;
83
2.75k
  struct NodeInfo {
84
2.75k
    size_t begin, end, pos;
85
2.75k
  };
86
2.75k
  std::queue<NodeInfo> q;
87
  // Leaf IDs will be set by roundtrip decoding the tree.
88
2.75k
  tree.push_back(PropertyDecisionNode::Leaf(pred));
89
2.75k
  q.push(NodeInfo{0, cutoffs.size(), 0});
90
44.0k
  while (!q.empty()) {
91
41.2k
    NodeInfo info = q.front();
92
41.2k
    q.pop();
93
41.2k
    if (info.begin + min_gap >= info.end) continue;
94
19.2k
    uint32_t split = (info.begin + info.end) / 2;
95
19.2k
    int32_t cutoff = cutoffs[split] * mul;
96
19.2k
    tree[info.pos] = PropertyDecisionNode::Split(property, cutoff, tree.size());
97
19.2k
    q.push(NodeInfo{split + 1, info.end, tree.size()});
98
19.2k
    tree.push_back(PropertyDecisionNode::Leaf(pred));
99
19.2k
    q.push(NodeInfo{info.begin, split, tree.size()});
100
19.2k
    tree.push_back(PropertyDecisionNode::Leaf(pred));
101
19.2k
  }
102
2.75k
  return tree;
103
2.75k
}
104
105
Status GatherTreeData(const Image &image, pixel_type chan, size_t group_id,
106
                      const weighted::Header &wp_header,
107
                      const ModularOptions &options, TreeSamples &tree_samples,
108
6.17k
                      size_t *total_pixels) {
109
6.17k
  const Channel &channel = image.channel[chan];
110
6.17k
  JxlMemoryManager *memory_manager = channel.memory_manager();
111
112
6.17k
  JXL_DEBUG_V(7, "Learning %" PRIuS "x%" PRIuS " channel %d", channel.w,
113
6.17k
              channel.h, chan);
114
115
6.17k
  std::array<pixel_type, kNumStaticProperties> static_props = {
116
6.17k
      {chan, static_cast<int>(group_id)}};
117
6.17k
  Properties properties(kNumNonrefProperties +
118
6.17k
                        kExtraPropsPerChannel * options.max_properties);
119
6.17k
  double pixel_fraction = std::min(1.0f, options.nb_repeats);
120
  // a fraction of 0 is used to disable learning entirely.
121
6.17k
  if (pixel_fraction > 0) {
122
6.17k
    pixel_fraction = std::max(pixel_fraction,
123
6.17k
                              std::min(1.0, 1024.0 / (channel.w * channel.h)));
124
6.17k
  }
125
6.17k
  uint64_t threshold =
126
6.17k
      (std::numeric_limits<uint64_t>::max() >> 32) * pixel_fraction;
127
6.17k
  uint64_t s[2] = {static_cast<uint64_t>(0x94D049BB133111EBull),
128
6.17k
                   static_cast<uint64_t>(0xBF58476D1CE4E5B9ull)};
129
  // Xorshift128+ adapted from xorshift128+-inl.h
130
111M
  auto use_sample = [&]() {
131
111M
    auto s1 = s[0];
132
111M
    const auto s0 = s[1];
133
111M
    const auto bits = s1 + s0;  // b, c
134
111M
    s[0] = s0;
135
111M
    s1 ^= s1 << 23;
136
111M
    s1 ^= s0 ^ (s1 >> 18) ^ (s0 >> 5);
137
111M
    s[1] = s1;
138
111M
    return (bits >> 32) <= threshold;
139
111M
  };
140
141
6.17k
  const ptrdiff_t onerow = channel.plane.PixelsPerRow();
142
6.17k
  JXL_ASSIGN_OR_RETURN(
143
6.17k
      Channel references,
144
6.17k
      Channel::Create(memory_manager, properties.size() - kNumNonrefProperties,
145
6.17k
                      channel.w));
146
6.17k
  weighted::State wp_state(wp_header, channel.w, channel.h);
147
6.17k
  tree_samples.PrepareForSamples(pixel_fraction * channel.h * channel.w + 64);
148
6.17k
  const bool multiple_predictors = tree_samples.NumPredictors() != 1;
149
4.62M
  auto compute_sample = [&](const pixel_type *p, size_t x, size_t y) {
150
4.62M
    pixel_type_w pred[kNumModularPredictors];
151
4.62M
    if (multiple_predictors) {
152
0
      PredictLearnAll(&properties, channel.w, p + x, onerow, x, y, references,
153
0
                      &wp_state, pred);
154
4.62M
    } else {
155
4.62M
      pred[static_cast<int>(tree_samples.PredictorFromIndex(0))] =
156
4.62M
          PredictLearn(&properties, channel.w, p + x, onerow, x, y,
157
4.62M
                       tree_samples.PredictorFromIndex(0), references,
158
4.62M
                       &wp_state)
159
4.62M
              .guess;
160
4.62M
    }
161
4.62M
    (*total_pixels)++;
162
4.62M
    if (use_sample()) {
163
2.50M
      tree_samples.AddSample(p[x], properties, pred);
164
2.50M
    }
165
4.62M
    wp_state.UpdateErrors(p[x], x, y, channel.w);
166
4.62M
  };
167
168
636k
  for (size_t y = 0; y < channel.h; y++) {
169
629k
    const pixel_type *JXL_RESTRICT p = channel.Row(y);
170
629k
    PrecomputeReferences(channel, y, image, chan, &references);
171
629k
    InitPropsRow(&properties, static_props, y);
172
173
    // TODO(veluca): avoid computing WP if we don't use its property or
174
    // predictions.
175
629k
    if (y > 1 && channel.w > 8 && references.w == 0) {
176
1.82M
      for (size_t x = 0; x < 2; x++) {
177
1.21M
        compute_sample(p, x, y);
178
1.21M
      }
179
107M
      for (size_t x = 2; x < channel.w - 2; x++) {
180
106M
        pixel_type_w pred[kNumModularPredictors];
181
106M
        if (multiple_predictors) {
182
0
          PredictLearnAllNEC(&properties, channel.w, p + x, onerow, x, y,
183
0
                             references, &wp_state, pred);
184
106M
        } else {
185
106M
          pred[static_cast<int>(tree_samples.PredictorFromIndex(0))] =
186
106M
              PredictLearnNEC(&properties, channel.w, p + x, onerow, x, y,
187
106M
                              tree_samples.PredictorFromIndex(0), references,
188
106M
                              &wp_state)
189
106M
                  .guess;
190
106M
        }
191
106M
        (*total_pixels)++;
192
106M
        if (use_sample()) {
193
53.6M
          tree_samples.AddSample(p[x], properties, pred);
194
53.6M
        }
195
106M
        wp_state.UpdateErrors(p[x], x, y, channel.w);
196
106M
      }
197
1.82M
      for (size_t x = channel.w - 2; x < channel.w; x++) {
198
1.21M
        compute_sample(p, x, y);
199
1.21M
      }
200
607k
    } else {
201
2.22M
      for (size_t x = 0; x < channel.w; x++) {
202
2.20M
        compute_sample(p, x, y);
203
2.20M
      }
204
22.7k
    }
205
629k
  }
206
6.17k
  return true;
207
6.17k
}
208
209
StatusOr<Tree> LearnTree(
210
    TreeSamples &&tree_samples, size_t total_pixels,
211
    const ModularOptions &options,
212
    const std::vector<ModularMultiplierInfo> &multiplier_info = {},
213
2.83k
    StaticPropRange static_prop_range = {}) {
214
2.83k
  Tree tree;
215
8.50k
  for (size_t i = 0; i < kNumStaticProperties; i++) {
216
5.66k
    if (static_prop_range[i][1] == 0) {
217
0
      static_prop_range[i][1] = std::numeric_limits<uint32_t>::max();
218
0
    }
219
5.66k
  }
220
2.83k
  if (!tree_samples.HasSamples()) {
221
129
    tree.emplace_back();
222
129
    tree.back().predictor = tree_samples.PredictorFromIndex(0);
223
129
    tree.back().property = -1;
224
129
    tree.back().predictor_offset = 0;
225
129
    tree.back().multiplier = 1;
226
129
    return tree;
227
129
  }
228
2.70k
  float pixel_fraction = tree_samples.NumSamples() * 1.0f / total_pixels;
229
2.70k
  float required_cost = pixel_fraction * 0.9 + 0.1;
230
2.70k
  tree_samples.AllSamplesDone();
231
2.70k
  JXL_RETURN_IF_ERROR(ComputeBestTree(
232
2.70k
      tree_samples, options.splitting_heuristics_node_threshold * required_cost,
233
2.70k
      multiplier_info, static_prop_range, options.fast_decode_multiplier,
234
2.70k
      &tree));
235
2.70k
  return tree;
236
2.70k
}
237
238
Status EncodeModularChannelMAANS(const Image &image, pixel_type chan,
239
                                 const weighted::Header &wp_header,
240
                                 const Tree &global_tree, Token **tokenpp,
241
25.3k
                                 size_t group_id, bool skip_encoder_fast_path) {
242
25.3k
  const Channel &channel = image.channel[chan];
243
25.3k
  JxlMemoryManager *memory_manager = channel.memory_manager();
244
25.3k
  Token *tokenp = *tokenpp;
245
25.3k
  JXL_ENSURE(channel.w != 0 && channel.h != 0);
246
247
25.3k
  Image3F predictor_img;
248
25.3k
  if (kWantDebug) {
249
25.3k
    JXL_ASSIGN_OR_RETURN(predictor_img,
250
25.3k
                         Image3F::Create(memory_manager, channel.w, channel.h));
251
25.3k
  }
252
253
25.3k
  JXL_DEBUG_V(6,
254
25.3k
              "Encoding %" PRIuS "x%" PRIuS
255
25.3k
              " channel %d, "
256
25.3k
              "(shift=%i,%i)",
257
25.3k
              channel.w, channel.h, chan, channel.hshift, channel.vshift);
258
259
25.3k
  std::array<pixel_type, kNumStaticProperties> static_props = {
260
25.3k
      {chan, static_cast<int>(group_id)}};
261
25.3k
  bool use_wp;
262
25.3k
  bool is_wp_only;
263
25.3k
  bool is_gradient_only;
264
25.3k
  size_t num_props;
265
25.3k
  FlatTree tree = FilterTree(global_tree, static_props, &num_props, &use_wp,
266
25.3k
                             &is_wp_only, &is_gradient_only);
267
25.3k
  MATreeLookup tree_lookup(tree);
268
25.3k
  JXL_DEBUG_V(3, "Encoding using a MA tree with %" PRIuS " nodes", tree.size());
269
270
  // Check if this tree is a WP-only tree with a small enough property value
271
  // range.
272
  // Initialized to avoid clang-tidy complaining.
273
25.3k
  auto tree_lut = jxl::make_unique<TreeLut<uint16_t, false, false>>();
274
25.3k
  if (is_wp_only) {
275
8.20k
    is_wp_only = TreeToLookupTable(tree, *tree_lut);
276
8.20k
  }
277
25.3k
  if (is_gradient_only) {
278
4.22k
    is_gradient_only = TreeToLookupTable(tree, *tree_lut);
279
4.22k
  }
280
281
25.3k
  if (is_wp_only && !skip_encoder_fast_path) {
282
32.8k
    for (size_t c = 0; c < 3; c++) {
283
24.6k
      FillImage(static_cast<float>(PredictorColor(Predictor::Weighted)[c]),
284
24.6k
                &predictor_img.Plane(c));
285
24.6k
    }
286
8.20k
    const ptrdiff_t onerow = channel.plane.PixelsPerRow();
287
8.20k
    weighted::State wp_state(wp_header, channel.w, channel.h);
288
8.20k
    Properties properties(1);
289
8.20k
    bool unhealthy = false;
290
294k
    for (size_t y = 0; y < channel.h; y++) {
291
286k
      const pixel_type *JXL_RESTRICT r = channel.Row(y);
292
14.6M
      for (size_t x = 0; x < channel.w; x++) {
293
14.3M
        size_t offset = 0;
294
14.3M
        pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
295
14.3M
        pixel_type_w top = (y ? *(r + x - onerow) : left);
296
14.3M
        pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left);
297
14.3M
        pixel_type_w topright =
298
14.3M
            (x + 1 < channel.w && y ? *(r + x + 1 - onerow) : top);
299
14.3M
        pixel_type_w toptop = (y > 1 ? *(r + x - onerow - onerow) : top);
300
14.3M
        int32_t guess = wp_state.Predict</*compute_properties=*/true>(
301
14.3M
            x, y, channel.w, top, left, topright, topleft, toptop, &properties,
302
14.3M
            offset);
303
14.3M
        uint32_t pos =
304
14.3M
            kPropRangeFast +
305
14.3M
            jxl::Clamp1(properties[0], -kPropRangeFast, kPropRangeFast - 1);
306
14.3M
        uint32_t ctx_id = tree_lut->context_lookup[pos];
307
14.3M
        int32_t residual;
308
14.3M
        unhealthy |= SubOverflow(r[x], guess, residual);
309
14.3M
        *tokenp++ = Token(ctx_id, PackSigned(residual));
310
14.3M
        wp_state.UpdateErrors(r[x], x, y, channel.w);
311
14.3M
      }
312
286k
    }
313
8.20k
    if (unhealthy) {
314
0
      return JXL_FAILURE("Residual overflow");
315
0
    }
316
17.1k
  } else if (tree.size() == 1 && tree[0].predictor == Predictor::Gradient &&
317
3.55k
             tree[0].multiplier == 1 && tree[0].predictor_offset == 0 &&
318
3.55k
             !skip_encoder_fast_path) {
319
14.2k
    for (size_t c = 0; c < 3; c++) {
320
10.6k
      FillImage(static_cast<float>(PredictorColor(Predictor::Gradient)[c]),
321
10.6k
                &predictor_img.Plane(c));
322
10.6k
    }
323
3.55k
    const ptrdiff_t onerow = channel.plane.PixelsPerRow();
324
3.55k
    bool unhealthy = false;
325
128k
    for (size_t y = 0; y < channel.h; y++) {
326
125k
      const pixel_type *JXL_RESTRICT r = channel.Row(y);
327
19.8M
      for (size_t x = 0; x < channel.w; x++) {
328
19.7M
        pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
329
19.7M
        pixel_type_w top = (y ? *(r + x - onerow) : left);
330
19.7M
        pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left);
331
19.7M
        int32_t guess = ClampedGradient(top, left, topleft);
332
19.7M
        int32_t residual;
333
19.7M
        unhealthy |= SubOverflow(r[x], guess, residual);
334
19.7M
        *tokenp++ = Token(tree[0].childID, PackSigned(residual));
335
19.7M
      }
336
125k
    }
337
3.55k
    if (unhealthy) {
338
7
      return JXL_FAILURE("Residual overflow");
339
7
    }
340
13.5k
  } else if (is_gradient_only && !skip_encoder_fast_path) {
341
2.69k
    for (size_t c = 0; c < 3; c++) {
342
2.02k
      FillImage(static_cast<float>(PredictorColor(Predictor::Gradient)[c]),
343
2.02k
                &predictor_img.Plane(c));
344
2.02k
    }
345
674
    const ptrdiff_t onerow = channel.plane.PixelsPerRow();
346
674
    bool unhealthy = false;
347
20.0k
    for (size_t y = 0; y < channel.h; y++) {
348
19.3k
      const pixel_type *JXL_RESTRICT r = channel.Row(y);
349
1.45M
      for (size_t x = 0; x < channel.w; x++) {
350
1.43M
        pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
351
1.43M
        pixel_type_w top = (y ? *(r + x - onerow) : left);
352
1.43M
        pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left);
353
1.43M
        int32_t guess = ClampedGradient(top, left, topleft);
354
1.43M
        uint32_t pos =
355
1.43M
            kPropRangeFast +
356
1.43M
            std::min<pixel_type_w>(
357
1.43M
                std::max<pixel_type_w>(-kPropRangeFast, top + left - topleft),
358
1.43M
                kPropRangeFast - 1);
359
1.43M
        uint32_t ctx_id = tree_lut->context_lookup[pos];
360
1.43M
        int32_t residual;
361
1.43M
        unhealthy |= SubOverflow(r[x], guess, residual);
362
1.43M
        *tokenp++ = Token(ctx_id, PackSigned(residual));
363
1.43M
      }
364
19.3k
    }
365
674
    if (unhealthy) {
366
14
      return JXL_FAILURE("Residual overflow");
367
14
    }
368
12.8k
  } else if (tree.size() == 1 && tree[0].predictor == Predictor::Zero &&
369
0
             tree[0].multiplier == 1 && tree[0].predictor_offset == 0 &&
370
0
             !skip_encoder_fast_path) {
371
0
    for (size_t c = 0; c < 3; c++) {
372
0
      FillImage(static_cast<float>(PredictorColor(Predictor::Zero)[c]),
373
0
                &predictor_img.Plane(c));
374
0
    }
375
0
    for (size_t y = 0; y < channel.h; y++) {
376
0
      const pixel_type *JXL_RESTRICT p = channel.Row(y);
377
0
      for (size_t x = 0; x < channel.w; x++) {
378
0
        *tokenp++ = Token(tree[0].childID, PackSigned(p[x]));
379
0
      }
380
0
    }
381
12.8k
  } else if (tree.size() == 1 && tree[0].predictor != Predictor::Weighted &&
382
6.56k
             (tree[0].multiplier & (tree[0].multiplier - 1)) == 0 &&
383
6.56k
             tree[0].predictor_offset == 0 && !skip_encoder_fast_path) {
384
    // multiplier is a power of 2.
385
26.2k
    for (size_t c = 0; c < 3; c++) {
386
19.6k
      FillImage(static_cast<float>(PredictorColor(tree[0].predictor)[c]),
387
19.6k
                &predictor_img.Plane(c));
388
19.6k
    }
389
6.56k
    uint32_t mul_shift =
390
6.56k
        FloorLog2Nonzero(static_cast<uint32_t>(tree[0].multiplier));
391
6.56k
    const ptrdiff_t onerow = channel.plane.PixelsPerRow();
392
34.1k
    for (size_t y = 0; y < channel.h; y++) {
393
27.5k
      const pixel_type *JXL_RESTRICT r = channel.Row(y);
394
272k
      for (size_t x = 0; x < channel.w; x++) {
395
244k
        PredictionResult pred = PredictNoTreeNoWP(channel.w, r + x, onerow, x,
396
244k
                                                  y, tree[0].predictor);
397
244k
        pixel_type_w residual = r[x] - pred.guess;
398
244k
        JXL_DASSERT((residual >> mul_shift) * tree[0].multiplier == residual);
399
244k
        *tokenp++ = Token(tree[0].childID, PackSigned(residual >> mul_shift));
400
244k
      }
401
27.5k
    }
402
403
6.56k
  } else if (!use_wp && !skip_encoder_fast_path) {
404
2.90k
    const ptrdiff_t onerow = channel.plane.PixelsPerRow();
405
2.90k
    Properties properties(num_props);
406
2.90k
    JXL_ASSIGN_OR_RETURN(
407
2.90k
        Channel references,
408
2.90k
        Channel::Create(memory_manager,
409
2.90k
                        properties.size() - kNumNonrefProperties, channel.w));
410
128k
    for (size_t y = 0; y < channel.h; y++) {
411
125k
      const pixel_type *JXL_RESTRICT p = channel.Row(y);
412
125k
      PrecomputeReferences(channel, y, image, chan, &references);
413
125k
      float *pred_img_row[3];
414
125k
      if (kWantDebug) {
415
502k
        for (size_t c = 0; c < 3; c++) {
416
376k
          pred_img_row[c] = predictor_img.PlaneRow(c, y);
417
376k
        }
418
125k
      }
419
125k
      InitPropsRow(&properties, static_props, y);
420
13.1M
      for (size_t x = 0; x < channel.w; x++) {
421
12.9M
        PredictionResult res =
422
12.9M
            PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y,
423
12.9M
                            tree_lookup, references);
424
12.9M
        if (kWantDebug) {
425
51.9M
          for (size_t i = 0; i < 3; i++) {
426
38.9M
            pred_img_row[i][x] = PredictorColor(res.predictor)[i];
427
38.9M
          }
428
12.9M
        }
429
12.9M
        pixel_type_w residual = p[x] - res.guess;
430
12.9M
        JXL_DASSERT(residual % res.multiplier == 0);
431
12.9M
        *tokenp++ = Token(res.context, PackSigned(residual / res.multiplier));
432
12.9M
      }
433
125k
    }
434
3.41k
  } else {
435
3.41k
    const ptrdiff_t onerow = channel.plane.PixelsPerRow();
436
3.41k
    Properties properties(num_props);
437
3.41k
    JXL_ASSIGN_OR_RETURN(
438
3.41k
        Channel references,
439
3.41k
        Channel::Create(memory_manager,
440
3.41k
                        properties.size() - kNumNonrefProperties, channel.w));
441
3.41k
    weighted::State wp_state(wp_header, channel.w, channel.h);
442
462k
    for (size_t y = 0; y < channel.h; y++) {
443
459k
      const pixel_type *JXL_RESTRICT p = channel.Row(y);
444
459k
      PrecomputeReferences(channel, y, image, chan, &references);
445
459k
      float *pred_img_row[3];
446
459k
      if (kWantDebug) {
447
1.83M
        for (size_t c = 0; c < 3; c++) {
448
1.37M
          pred_img_row[c] = predictor_img.PlaneRow(c, y);
449
1.37M
        }
450
459k
      }
451
459k
      InitPropsRow(&properties, static_props, y);
452
86.4M
      for (size_t x = 0; x < channel.w; x++) {
453
85.9M
        PredictionResult res =
454
85.9M
            PredictTreeWP(&properties, channel.w, p + x, onerow, x, y,
455
85.9M
                          tree_lookup, references, &wp_state);
456
85.9M
        if (kWantDebug) {
457
343M
          for (size_t i = 0; i < 3; i++) {
458
257M
            pred_img_row[i][x] = PredictorColor(res.predictor)[i];
459
257M
          }
460
85.9M
        }
461
85.9M
        pixel_type_w residual = p[x] - res.guess;
462
85.9M
        JXL_DASSERT(residual % res.multiplier == 0);
463
85.9M
        *tokenp++ = Token(res.context, PackSigned(residual / res.multiplier));
464
85.9M
        wp_state.UpdateErrors(p[x], x, y, channel.w);
465
85.9M
      }
466
459k
    }
467
3.41k
  }
468
  /* TODO(szabadka): Add cparams to the call stack here.
469
  if (kWantDebug && WantDebugOutput(cparams)) {
470
    DumpImage(
471
        cparams,
472
        ("pred_" + ToString(group_id) + "_" + ToString(chan)).c_str(),
473
        predictor_img);
474
  }
475
  */
476
25.2k
  *tokenpp = tokenp;
477
25.2k
  return true;
478
25.3k
}
479
480
}  // namespace
481
482
Tree PredefinedTree(ModularOptions::TreeKind tree_kind, size_t total_pixels,
483
5.51k
                    int bitdepth, int prevprop) {
484
5.51k
  switch (tree_kind) {
485
0
    case ModularOptions::TreeKind::kJpegTranscodeACMeta:
486
      // All the data is 0, so no need for a fancy tree.
487
0
      return {PropertyDecisionNode::Leaf(Predictor::Zero)};
488
0
    case ModularOptions::TreeKind::kTrivialTreeNoPredictor:
489
      // All the data is 0, so no need for a fancy tree.
490
0
      return {PropertyDecisionNode::Leaf(Predictor::Zero)};
491
0
    case ModularOptions::TreeKind::kFalconACMeta:
492
      // All the data is 0 except the quant field. TODO(veluca): make that 0
493
      // too.
494
0
      return {PropertyDecisionNode::Leaf(Predictor::Left)};
495
2.75k
    case ModularOptions::TreeKind::kACMeta: {
496
      // Small image.
497
2.75k
      if (total_pixels < 1024) {
498
1.66k
        return {PropertyDecisionNode::Leaf(Predictor::Left)};
499
1.66k
      }
500
1.09k
      Tree tree;
501
      // 0: c > 1
502
1.09k
      tree.push_back(PropertyDecisionNode::Split(0, 1, 1));
503
      // 1: c > 2
504
1.09k
      tree.push_back(PropertyDecisionNode::Split(0, 2, 3));
505
      // 2: c > 0
506
1.09k
      tree.push_back(PropertyDecisionNode::Split(0, 0, 5));
507
      // 3: EPF control field (all 0 or 4), top > 3
508
1.09k
      tree.push_back(PropertyDecisionNode::Split(6, 3, 21));
509
      // 4: ACS+QF, y > 0
510
1.09k
      tree.push_back(PropertyDecisionNode::Split(2, 0, 7));
511
      // 5: CfL x
512
1.09k
      tree.push_back(PropertyDecisionNode::Leaf(Predictor::Gradient));
513
      // 6: CfL b
514
1.09k
      tree.push_back(PropertyDecisionNode::Leaf(Predictor::Gradient));
515
      // 7: QF: split according to the left quant value.
516
1.09k
      tree.push_back(PropertyDecisionNode::Split(7, 5, 9));
517
      // 8: ACS: split in 4 segments (8x8 from 0 to 3, large square 4-5, large
518
      // rectangular 6-11, 8x8 12+), according to previous ACS value.
519
1.09k
      tree.push_back(PropertyDecisionNode::Split(7, 5, 15));
520
      // QF
521
1.09k
      tree.push_back(PropertyDecisionNode::Split(7, 11, 11));
522
1.09k
      tree.push_back(PropertyDecisionNode::Split(7, 3, 13));
523
1.09k
      tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left));
524
1.09k
      tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left));
525
1.09k
      tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left));
526
1.09k
      tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left));
527
      // ACS
528
1.09k
      tree.push_back(PropertyDecisionNode::Split(7, 11, 17));
529
1.09k
      tree.push_back(PropertyDecisionNode::Split(7, 3, 19));
530
1.09k
      tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
531
1.09k
      tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
532
1.09k
      tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
533
1.09k
      tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
534
      // EPF, left > 3
535
1.09k
      tree.push_back(PropertyDecisionNode::Split(7, 3, 23));
536
1.09k
      tree.push_back(PropertyDecisionNode::Split(7, 3, 25));
537
1.09k
      tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
538
1.09k
      tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
539
1.09k
      tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
540
1.09k
      tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
541
1.09k
      return tree;
542
2.75k
    }
543
2.75k
    case ModularOptions::TreeKind::kWPFixedDC: {
544
2.75k
      std::vector<int32_t> cutoffs = {
545
2.75k
          -500, -392, -255, -191, -127, -95, -63, -47, -31, -23, -15,
546
2.75k
          -11,  -7,   -4,   -3,   -1,   0,   1,   3,   5,   7,   11,
547
2.75k
          15,   23,   31,   47,   63,   95,  127, 191, 255, 392, 500};
548
2.75k
      return MakeFixedTree(kWPProp, cutoffs, Predictor::Weighted, total_pixels,
549
2.75k
                           bitdepth);
550
2.75k
    }
551
0
    case ModularOptions::TreeKind::kGradientFixedDC: {
552
0
      std::vector<int32_t> cutoffs = {
553
0
          -500, -392, -255, -191, -127, -95, -63, -47, -31, -23, -15,
554
0
          -11,  -7,   -4,   -3,   -1,   0,   1,   3,   5,   7,   11,
555
0
          15,   23,   31,   47,   63,   95,  127, 191, 255, 392, 500};
556
0
      return MakeFixedTree(
557
0
          prevprop > 0 ? kNumNonrefProperties + 2 : kGradientProp, cutoffs,
558
0
          Predictor::Gradient, total_pixels, bitdepth);
559
2.75k
    }
560
0
    case ModularOptions::TreeKind::kLearn: {
561
0
      JXL_DEBUG_ABORT("internal: kLearn is not predefined tree");
562
0
      return {};
563
2.75k
    }
564
5.51k
  }
565
0
  JXL_DEBUG_ABORT("internal: unexpected TreeKind: %d",
566
0
                  static_cast<int>(tree_kind));
567
0
  return {};
568
5.51k
}
569
570
StatusOr<Tree> LearnTree(
571
    const Image *images, const ModularOptions *options, const uint32_t start,
572
    const uint32_t stop,
573
2.83k
    const std::vector<ModularMultiplierInfo> &multiplier_info = {}) {
574
2.83k
  TreeSamples tree_samples;
575
2.83k
  JXL_RETURN_IF_ERROR(tree_samples.SetPredictor(options[start].predictor,
576
2.83k
                                                options[start].wp_tree_mode));
577
2.83k
  JXL_RETURN_IF_ERROR(
578
2.83k
      tree_samples.SetProperties(options[start].splitting_heuristics_properties,
579
2.83k
                                 options[start].wp_tree_mode));
580
2.83k
  uint32_t max_c = 0;
581
2.83k
  std::vector<pixel_type> pixel_samples;
582
2.83k
  std::vector<pixel_type> diff_samples;
583
2.83k
  std::vector<uint32_t> group_pixel_count;
584
2.83k
  std::vector<uint32_t> channel_pixel_count;
585
6.70k
  for (uint32_t i = start; i < stop; i++) {
586
3.87k
    max_c = std::max<uint32_t>(images[i].channel.size(), max_c);
587
3.87k
    CollectPixelSamples(images[i], options[i], i, group_pixel_count,
588
3.87k
                        channel_pixel_count, pixel_samples, diff_samples);
589
3.87k
  }
590
2.83k
  StaticPropRange range;
591
2.83k
  range[0] = {{0, max_c}};
592
2.83k
  range[1] = {{start, stop}};
593
594
2.83k
  tree_samples.PreQuantizeProperties(
595
2.83k
      range, multiplier_info, group_pixel_count, channel_pixel_count,
596
2.83k
      pixel_samples, diff_samples, options[start].max_property_values);
597
598
2.83k
  size_t total_pixels = 0;
599
7.84k
  for (size_t i = 0; i < images[start].channel.size(); i++) {
600
5.42k
    if (i >= images[start].nb_meta_channels &&
601
4.86k
        (images[start].channel[i].w > options[start].max_chan_size ||
602
4.47k
         images[start].channel[i].h > options[start].max_chan_size)) {
603
414
      break;
604
414
    }
605
5.00k
    total_pixels += images[start].channel[i].w * images[start].channel[i].h;
606
5.00k
  }
607
2.83k
  total_pixels = std::max<size_t>(total_pixels, 1);
608
609
2.83k
  weighted::Header wp_header;
610
611
6.70k
  for (size_t i = start; i < stop; i++) {
612
3.87k
    size_t nb_channels = images[i].channel.size();
613
614
3.87k
    if (images[i].w == 0 || images[i].h == 0 || nb_channels < 1)
615
0
      continue;  // is there any use for a zero-channel image?
616
3.87k
    if (images[i].error) return JXL_FAILURE("Invalid image");
617
3.87k
    JXL_ENSURE(options[i].tree_kind == ModularOptions::TreeKind::kLearn);
618
619
3.87k
    JXL_DEBUG_V(
620
3.87k
        2, "Encoding %" PRIuS "-channel, %i-bit, %" PRIuS "x%" PRIuS " image.",
621
3.87k
        nb_channels, images[i].bitdepth, images[i].w, images[i].h);
622
623
    // encode transforms
624
3.87k
    Bundle::Init(&wp_header);
625
3.87k
    if (PredictorHasWeighted(options[i].predictor)) {
626
0
      weighted::PredictorMode(options[i].wp_mode, &wp_header);
627
0
    }
628
629
    // Gather tree data
630
10.0k
    for (size_t c = 0; c < nb_channels; c++) {
631
6.58k
      if (c >= images[i].nb_meta_channels &&
632
5.90k
          (images[i].channel[c].w > options[i].max_chan_size ||
633
5.51k
           images[i].channel[c].h > options[i].max_chan_size)) {
634
414
        break;
635
414
      }
636
6.17k
      if (!images[i].channel[c].w || !images[i].channel[c].h) {
637
0
        continue;  // skip empty channels
638
0
      }
639
6.17k
      JXL_RETURN_IF_ERROR(GatherTreeData(images[i], c, i, wp_header, options[i],
640
6.17k
                                         tree_samples, &total_pixels));
641
6.17k
    }
642
3.87k
  }
643
644
  // TODO(veluca): parallelize more.
645
2.83k
  JXL_ASSIGN_OR_RETURN(Tree tree,
646
2.83k
                       LearnTree(std::move(tree_samples), total_pixels,
647
2.83k
                                 options[start], multiplier_info, range));
648
2.83k
  return tree;
649
2.83k
}
650
651
Status ModularCompress(const Image &image, const ModularOptions &options,
652
                       size_t group_id, const Tree &tree, GroupHeader &header,
653
79.3k
                       std::vector<Token> &tokens, size_t *width) {
654
79.3k
  size_t nb_channels = image.channel.size();
655
656
79.3k
  if (image.w == 0 || image.h == 0 || nb_channels < 1)
657
70.0k
    return true;  // is there any use for a zero-channel image?
658
9.34k
  if (image.error) return JXL_FAILURE("Invalid image");
659
660
9.34k
  JXL_DEBUG_V(
661
9.34k
      2, "Encoding %" PRIuS "-channel, %i-bit, %" PRIuS "x%" PRIuS " image.",
662
9.34k
      nb_channels, image.bitdepth, image.w, image.h);
663
664
  // encode transforms
665
9.34k
  Bundle::Init(&header);
666
9.34k
  if (PredictorHasWeighted(options.predictor)) {
667
2.73k
    weighted::PredictorMode(options.wp_mode, &header.wp_header);
668
2.73k
  }
669
9.34k
  header.transforms = image.transform;
670
9.34k
  header.use_global_tree = true;
671
672
9.34k
  size_t image_width = 0;
673
9.34k
  size_t total_tokens = 0;
674
34.6k
  for (size_t i = 0; i < nb_channels; i++) {
675
25.7k
    if (i >= image.nb_meta_channels &&
676
25.0k
        (image.channel[i].w > options.max_chan_size ||
677
24.6k
         image.channel[i].h > options.max_chan_size)) {
678
414
      break;
679
414
    }
680
25.3k
    if (image.channel[i].w > image_width) image_width = image.channel[i].w;
681
25.3k
    total_tokens += image.channel[i].w * image.channel[i].h;
682
25.3k
  }
683
9.34k
  if (options.zero_tokens) {
684
0
    tokens.resize(tokens.size() + total_tokens, {0, 0});
685
9.34k
  } else {
686
    // Do one big allocation for all the tokens we'll need,
687
    // to avoid reallocs that might require copying.
688
9.34k
    size_t pos = tokens.size();
689
9.34k
    tokens.resize(pos + total_tokens);
690
9.34k
    Token *tokenp = tokens.data() + pos;
691
34.6k
    for (size_t i = 0; i < nb_channels; i++) {
692
25.7k
      if (i >= image.nb_meta_channels &&
693
25.0k
          (image.channel[i].w > options.max_chan_size ||
694
24.6k
           image.channel[i].h > options.max_chan_size)) {
695
414
        break;
696
414
      }
697
25.3k
      if (!image.channel[i].w || !image.channel[i].h) {
698
0
        continue;  // skip empty channels
699
0
      }
700
25.3k
      JXL_RETURN_IF_ERROR(
701
25.3k
          EncodeModularChannelMAANS(image, i, header.wp_header, tree, &tokenp,
702
25.3k
                                    group_id, options.skip_encoder_fast_path));
703
25.3k
    }
704
    // Make sure we actually wrote all tokens
705
9.31k
    JXL_ENSURE(tokenp == tokens.data() + tokens.size());
706
9.31k
  }
707
708
9.31k
  *width = image_width;
709
710
9.31k
  return true;
711
9.34k
}
712
713
Status ModularGenericCompress(const Image &image, const ModularOptions &opts,
714
                              BitWriter &writer, AuxOut *aux_out,
715
1.02k
                              LayerType layer, size_t group_id) {
716
1.02k
  size_t nb_channels = image.channel.size();
717
718
1.02k
  if (image.w == 0 || image.h == 0 || nb_channels < 1)
719
0
    return true;  // is there any use for a zero-channel image?
720
1.02k
  if (image.error) return JXL_FAILURE("Invalid image");
721
722
1.02k
  ModularOptions options = opts;  // Make a copy to modify it.
723
1.02k
  if (options.predictor == kUndefinedPredictor) {
724
0
    options.predictor = Predictor::Gradient;
725
0
  }
726
727
1.02k
  size_t bits = writer.BitsWritten();
728
729
1.02k
  JxlMemoryManager *memory_manager = image.memory_manager();
730
1.02k
  JXL_DEBUG_V(
731
1.02k
      2, "Encoding %" PRIuS "-channel, %i-bit, %" PRIuS "x%" PRIuS " image.",
732
1.02k
      nb_channels, image.bitdepth, image.w, image.h);
733
734
  // encode transforms
735
1.02k
  GroupHeader header;
736
1.02k
  Bundle::Init(&header);
737
1.02k
  if (PredictorHasWeighted(options.predictor)) {
738
139
    weighted::PredictorMode(options.wp_mode, &header.wp_header);
739
139
  }
740
1.02k
  header.transforms = image.transform;
741
742
1.02k
  JXL_RETURN_IF_ERROR(Bundle::Write(header, &writer, layer, aux_out));
743
744
  // Compute tree.
745
1.02k
  Tree tree;
746
1.02k
  if (options.tree_kind == ModularOptions::TreeKind::kLearn) {
747
745
    JXL_ASSIGN_OR_RETURN(tree, LearnTree(&image, &options, 0, 1));
748
745
  } else {
749
278
    size_t total_pixels = 0;
750
1.25k
    for (size_t i = 0; i < nb_channels; i++) {
751
973
      if (i >= image.nb_meta_channels &&
752
973
          (image.channel[i].w > options.max_chan_size ||
753
973
           image.channel[i].h > options.max_chan_size)) {
754
0
        break;
755
0
      }
756
973
      total_pixels += image.channel[i].w * image.channel[i].h;
757
973
    }
758
278
    total_pixels = std::max<size_t>(total_pixels, 1);
759
760
278
    tree = PredefinedTree(options.tree_kind, total_pixels, image.bitdepth,
761
278
                          options.max_properties);
762
278
  }
763
764
1.02k
  Tree decoded_tree;
765
1.02k
  std::vector<std::vector<Token>> tree_tokens(1);
766
1.02k
  JXL_RETURN_IF_ERROR(TokenizeTree(tree, tree_tokens.data(), &decoded_tree));
767
1.02k
  JXL_ENSURE(tree.size() == decoded_tree.size());
768
1.02k
  tree = std::move(decoded_tree);
769
770
  /* TODO(szabadka) Add text output callback
771
  if (kWantDebug && kPrintTree && WantDebugOutput(aux_out)) {
772
    PrintTree(*tree, aux_out->debug_prefix + "/tree_" + ToString(group_id));
773
  } */
774
775
  // Write tree
776
1.02k
  EntropyEncodingData code;
777
1.02k
  JXL_ASSIGN_OR_RETURN(
778
1.02k
      size_t cost,
779
1.02k
      BuildAndEncodeHistograms(memory_manager, options.histogram_params,
780
1.02k
                               kNumTreeContexts, tree_tokens, &code, &writer,
781
1.02k
                               LayerType::ModularTree, aux_out));
782
1.02k
  JXL_RETURN_IF_ERROR(WriteTokens(tree_tokens[0], code, 0, &writer,
783
1.02k
                                  LayerType::ModularTree, aux_out));
784
785
1.02k
  size_t image_width = 0;
786
1.02k
  std::vector<std::vector<Token>> tokens(1);
787
  // it puts `use_global_tree = true` in the header, but this is not used
788
  // further
789
1.02k
  JXL_RETURN_IF_ERROR(ModularCompress(image, options, group_id, tree, header,
790
1.02k
                                      tokens[0], &image_width));
791
792
  // Write data
793
1.02k
  code = {};
794
1.02k
  HistogramParams histo_params = options.histogram_params;
795
1.02k
  histo_params.image_widths.push_back(image_width);
796
1.02k
  JXL_ASSIGN_OR_RETURN(
797
1.02k
      cost, BuildAndEncodeHistograms(memory_manager, histo_params,
798
1.02k
                                     (tree.size() + 1) / 2, tokens, &code,
799
1.02k
                                     &writer, layer, aux_out));
800
1.02k
  (void)cost;
801
1.02k
  JXL_RETURN_IF_ERROR(WriteTokens(tokens[0], code, 0, &writer, layer, aux_out));
802
803
1.02k
  bits = writer.BitsWritten() - bits;
804
1.02k
  JXL_DEBUG_V(4,
805
1.02k
              "Modular-encoded a %" PRIuS "x%" PRIuS
806
1.02k
              " bitdepth=%i nbchans=%" PRIuS " image in %" PRIuS " bytes",
807
1.02k
              image.w, image.h, image.bitdepth, image.channel.size(), bits / 8);
808
1.02k
  (void)bits;
809
810
1.02k
  return true;
811
1.02k
}
812
813
}  // namespace jxl