Coverage Report

Created: 2025-06-22 08:04

/src/libjxl/lib/jxl/dec_modular.cc
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#include "lib/jxl/dec_modular.h"
7
8
#include <jxl/memory_manager.h>
9
10
#include <cstdint>
11
#include <vector>
12
13
#include "lib/jxl/frame_header.h"
14
15
#undef HWY_TARGET_INCLUDE
16
#define HWY_TARGET_INCLUDE "lib/jxl/dec_modular.cc"
17
#include <hwy/foreach_target.h>
18
#include <hwy/highway.h>
19
20
#include "lib/jxl/base/compiler_specific.h"
21
#include "lib/jxl/base/printf_macros.h"
22
#include "lib/jxl/base/rect.h"
23
#include "lib/jxl/base/status.h"
24
#include "lib/jxl/compressed_dc.h"
25
#include "lib/jxl/epf.h"
26
#include "lib/jxl/modular/encoding/encoding.h"
27
#include "lib/jxl/modular/modular_image.h"
28
#include "lib/jxl/modular/transform/transform.h"
29
30
HWY_BEFORE_NAMESPACE();
31
namespace jxl {
32
namespace HWY_NAMESPACE {
33
34
// These templates are not found via ADL.
35
using hwy::HWY_NAMESPACE::Add;
36
using hwy::HWY_NAMESPACE::Mul;
37
using hwy::HWY_NAMESPACE::Rebind;
38
39
void MultiplySum(const size_t xsize,
40
                 const pixel_type* const JXL_RESTRICT row_in,
41
                 const pixel_type* const JXL_RESTRICT row_in_Y,
42
1.02M
                 const float factor, float* const JXL_RESTRICT row_out) {
43
1.02M
  const HWY_FULL(float) df;
44
1.02M
  const Rebind<pixel_type, HWY_FULL(float)> di;  // assumes pixel_type <= float
45
1.02M
  const auto factor_v = Set(df, factor);
46
145M
  for (size_t x = 0; x < xsize; x += Lanes(di)) {
47
144M
    const auto in = Add(Load(di, row_in + x), Load(di, row_in_Y + x));
48
144M
    const auto out = Mul(ConvertTo(df, in), factor_v);
49
144M
    Store(out, df, row_out + x);
50
144M
  }
51
1.02M
}
52
53
void RgbFromSingle(const size_t xsize,
54
                   const pixel_type* const JXL_RESTRICT row_in,
55
                   const float factor, float* out_r, float* out_g,
56
13.1k
                   float* out_b) {
57
13.1k
  const HWY_FULL(float) df;
58
13.1k
  const Rebind<pixel_type, HWY_FULL(float)> di;  // assumes pixel_type <= float
59
60
13.1k
  const auto factor_v = Set(df, factor);
61
1.80M
  for (size_t x = 0; x < xsize; x += Lanes(di)) {
62
1.79M
    const auto in = Load(di, row_in + x);
63
1.79M
    const auto out = Mul(ConvertTo(df, in), factor_v);
64
1.79M
    Store(out, df, out_r + x);
65
1.79M
    Store(out, df, out_g + x);
66
1.79M
    Store(out, df, out_b + x);
67
1.79M
  }
68
13.1k
}
69
70
void SingleFromSingle(const size_t xsize,
71
                      const pixel_type* const JXL_RESTRICT row_in,
72
4.06M
                      const float factor, float* row_out) {
73
4.06M
  const HWY_FULL(float) df;
74
4.06M
  const Rebind<pixel_type, HWY_FULL(float)> di;  // assumes pixel_type <= float
75
76
4.06M
  const auto factor_v = Set(df, factor);
77
597M
  for (size_t x = 0; x < xsize; x += Lanes(di)) {
78
593M
    const auto in = Load(di, row_in + x);
79
593M
    const auto out = Mul(ConvertTo(df, in), factor_v);
80
593M
    Store(out, df, row_out + x);
81
593M
  }
82
4.06M
}
83
// NOLINTNEXTLINE(google-readability-namespace-comments)
84
}  // namespace HWY_NAMESPACE
85
}  // namespace jxl
86
HWY_AFTER_NAMESPACE();
87
88
#if HWY_ONCE
89
namespace jxl {
90
HWY_EXPORT(MultiplySum);       // Local function
91
HWY_EXPORT(RgbFromSingle);     // Local function
92
HWY_EXPORT(SingleFromSingle);  // Local function
93
94
// Slow conversion using double precision multiplication, only
95
// needed when the bit depth is too high for single precision
96
void SingleFromSingleAccurate(const size_t xsize,
97
                              const pixel_type* const JXL_RESTRICT row_in,
98
179k
                              const double factor, float* row_out) {
99
9.21M
  for (size_t x = 0; x < xsize; x++) {
100
9.03M
    row_out[x] = row_in[x] * factor;
101
9.03M
  }
102
179k
}
103
104
// convert custom [bits]-bit float (with [exp_bits] exponent bits) stored as int
105
// back to binary32 float
106
Status int_to_float(const pixel_type* const JXL_RESTRICT row_in,
107
                    float* const JXL_RESTRICT row_out, const size_t xsize,
108
502k
                    const int bits, const int exp_bits) {
109
502k
  static_assert(sizeof(pixel_type) == sizeof(float));
110
502k
  if (bits == 32) {
111
316k
    JXL_ENSURE(exp_bits == 8);
112
316k
    memcpy(row_out, row_in, xsize * sizeof(float));
113
316k
    return true;
114
316k
  }
115
186k
  int exp_bias = (1 << (exp_bits - 1)) - 1;
116
186k
  int sign_shift = bits - 1;
117
186k
  int mant_bits = bits - exp_bits - 1;
118
186k
  int mant_shift = 23 - mant_bits;
119
31.8M
  for (size_t x = 0; x < xsize; ++x) {
120
31.6M
    uint32_t f;
121
31.6M
    memcpy(&f, &row_in[x], 4);
122
31.6M
    int signbit = (f >> sign_shift);
123
31.6M
    f &= (1 << sign_shift) - 1;
124
31.6M
    if (f == 0) {
125
11.8M
      row_out[x] = (signbit ? -0.f : 0.f);
126
11.8M
      continue;
127
11.8M
    }
128
19.7M
    int exp = (f >> mant_bits);
129
19.7M
    int mantissa = (f & ((1 << mant_bits) - 1));
130
19.7M
    mantissa <<= mant_shift;
131
    // Try to normalize only if there is space for maneuver.
132
19.7M
    if (exp == 0 && exp_bits < 8) {
133
      // subnormal number
134
25.6M
      while ((mantissa & 0x800000) == 0) {
135
23.2M
        mantissa <<= 1;
136
23.2M
        exp--;
137
23.2M
      }
138
2.44M
      exp++;
139
      // remove leading 1 because it is implicit now
140
2.44M
      mantissa &= 0x7fffff;
141
2.44M
    }
142
19.7M
    exp -= exp_bias;
143
    // broke up the arbitrary float into its parts, now reassemble into
144
    // binary32
145
19.7M
    exp += 127;
146
19.7M
    JXL_ENSURE(exp >= 0);
147
19.7M
    f = (signbit ? 0x80000000 : 0);
148
19.7M
    f |= (exp << 23);
149
19.7M
    f |= mantissa;
150
19.7M
    memcpy(&row_out[x], &f, 4);
151
19.7M
  }
152
186k
  return true;
153
186k
}
154
155
#if JXL_DEBUG_V_LEVEL >= 1
156
std::string ModularStreamId::DebugString() const {
157
  std::ostringstream os;
158
  os << (kind == GlobalData   ? "ModularGlobal"
159
         : kind == VarDCTDC   ? "VarDCTDC"
160
         : kind == ModularDC  ? "ModularDC"
161
         : kind == ACMetadata ? "ACMeta"
162
         : kind == QuantTable ? "QuantTable"
163
         : kind == ModularAC  ? "ModularAC"
164
                              : "");
165
  if (kind == VarDCTDC || kind == ModularDC || kind == ACMetadata ||
166
      kind == ModularAC) {
167
    os << " group " << group_id;
168
  }
169
  if (kind == ModularAC) {
170
    os << " pass " << pass_id;
171
  }
172
  if (kind == QuantTable) {
173
    os << " " << quant_table_id;
174
  }
175
  return os.str();
176
}
177
#endif
178
179
Status ModularFrameDecoder::DecodeGlobalInfo(BitReader* reader,
180
                                             const FrameHeader& frame_header,
181
41.7k
                                             bool allow_truncated_group) {
182
41.7k
  JxlMemoryManager* memory_manager = this->memory_manager();
183
41.7k
  bool decode_color = frame_header.encoding == FrameEncoding::kModular;
184
41.7k
  const auto& metadata = frame_header.nonserialized_metadata->m;
185
41.7k
  bool is_gray = metadata.color_encoding.IsGray();
186
41.7k
  size_t nb_chans = 3;
187
41.7k
  if (is_gray && frame_header.color_transform == ColorTransform::kNone) {
188
894
    nb_chans = 1;
189
894
  }
190
41.7k
  do_color = decode_color;
191
41.7k
  size_t nb_extra = metadata.extra_channel_info.size();
192
41.7k
  bool has_tree = static_cast<bool>(reader->ReadBits(1));
193
41.7k
  if (!allow_truncated_group ||
194
41.7k
      reader->TotalBitsConsumed() < reader->TotalBytes() * kBitsPerByte) {
195
41.7k
    if (has_tree) {
196
18.1k
      size_t tree_size_limit =
197
18.1k
          std::min(static_cast<size_t>(1 << 22),
198
18.1k
                   1024 + frame_dim.xsize * frame_dim.ysize *
199
18.1k
                              (nb_chans + nb_extra) / 16);
200
18.1k
      JXL_RETURN_IF_ERROR(
201
18.1k
          DecodeTree(memory_manager, reader, &tree, tree_size_limit));
202
17.7k
      JXL_RETURN_IF_ERROR(DecodeHistograms(
203
17.7k
          memory_manager, reader, (tree.size() + 1) / 2, &code, &context_map));
204
17.7k
    }
205
41.7k
  }
206
41.3k
  if (!do_color) nb_chans = 0;
207
208
41.3k
  bool fp = metadata.bit_depth.floating_point_sample;
209
210
  // bits_per_sample is just metadata for XYB images.
211
41.3k
  if (metadata.bit_depth.bits_per_sample >= 32 && do_color &&
212
41.3k
      frame_header.color_transform != ColorTransform::kXYB) {
213
6.24k
    if (metadata.bit_depth.bits_per_sample == 32 && fp == false) {
214
0
      return JXL_FAILURE("uint32_t not supported in dec_modular");
215
6.24k
    } else if (metadata.bit_depth.bits_per_sample > 32) {
216
0
      return JXL_FAILURE("bits_per_sample > 32 not supported");
217
0
    }
218
6.24k
  }
219
220
82.6k
  JXL_ASSIGN_OR_RETURN(
221
82.6k
      Image gi,
222
82.6k
      Image::Create(memory_manager, frame_dim.xsize, frame_dim.ysize,
223
82.6k
                    metadata.bit_depth.bits_per_sample, nb_chans + nb_extra));
224
225
82.6k
  all_same_shift = true;
226
82.6k
  if (frame_header.color_transform == ColorTransform::kYCbCr) {
227
33.8k
    for (size_t c = 0; c < nb_chans; c++) {
228
25.2k
      gi.channel[c].hshift = frame_header.chroma_subsampling.HShift(c);
229
25.2k
      gi.channel[c].vshift = frame_header.chroma_subsampling.VShift(c);
230
25.2k
      size_t xsize_shifted =
231
25.2k
          DivCeil(frame_dim.xsize, 1 << gi.channel[c].hshift);
232
25.2k
      size_t ysize_shifted =
233
25.2k
          DivCeil(frame_dim.ysize, 1 << gi.channel[c].vshift);
234
25.2k
      JXL_RETURN_IF_ERROR(gi.channel[c].shrink(xsize_shifted, ysize_shifted));
235
25.2k
      if (gi.channel[c].hshift != gi.channel[0].hshift ||
236
25.2k
          gi.channel[c].vshift != gi.channel[0].vshift)
237
9.83k
        all_same_shift = false;
238
25.2k
    }
239
8.58k
  }
240
241
63.7k
  for (size_t ec = 0, c = nb_chans; ec < nb_extra; ec++, c++) {
242
22.4k
    size_t ecups = frame_header.extra_channel_upsampling[ec];
243
22.4k
    JXL_RETURN_IF_ERROR(
244
22.4k
        gi.channel[c].shrink(DivCeil(frame_dim.xsize_upsampled, ecups),
245
22.4k
                             DivCeil(frame_dim.ysize_upsampled, ecups)));
246
22.4k
    gi.channel[c].hshift = gi.channel[c].vshift =
247
22.4k
        CeilLog2Nonzero(ecups) - CeilLog2Nonzero(frame_header.upsampling);
248
22.4k
    if (gi.channel[c].hshift != gi.channel[0].hshift ||
249
22.4k
        gi.channel[c].vshift != gi.channel[0].vshift)
250
10.6k
      all_same_shift = false;
251
22.4k
  }
252
253
41.3k
  JXL_DEBUG_V(6, "DecodeGlobalInfo: full_image (w/o transforms) %s",
254
41.3k
              gi.DebugString().c_str());
255
41.3k
  ModularOptions options;
256
41.3k
  options.max_chan_size = frame_dim.group_dim;
257
41.3k
  options.group_dim = frame_dim.group_dim;
258
41.3k
  Status dec_status = ModularGenericDecompress(
259
41.3k
      reader, gi, &global_header, ModularStreamId::Global().ID(frame_dim),
260
41.3k
      &options,
261
41.3k
      /*undo_transforms=*/false, &tree, &code, &context_map,
262
41.3k
      allow_truncated_group);
263
41.3k
  if (!allow_truncated_group) JXL_RETURN_IF_ERROR(dec_status);
264
39.3k
  if (dec_status.IsFatalError()) {
265
0
    return JXL_FAILURE("Failed to decode global modular info");
266
0
  }
267
268
  // TODO(eustas): are we sure this can be done after partial decode?
269
39.3k
  have_something = false;
270
275k
  for (size_t c = 0; c < gi.channel.size(); c++) {
271
236k
    Channel& gic = gi.channel[c];
272
236k
    if (c >= gi.nb_meta_channels && gic.w <= frame_dim.group_dim &&
273
236k
        gic.h <= frame_dim.group_dim)
274
223k
      have_something = true;
275
236k
  }
276
  // move global transforms to groups if possible
277
39.3k
  if (!have_something && all_same_shift) {
278
12.7k
    if (gi.transform.size() == 1 && gi.transform[0].id == TransformId::kRCT) {
279
90
      global_transform = gi.transform;
280
90
      gi.transform.clear();
281
      // TODO(jon): also move no-delta-palette out (trickier though)
282
90
    }
283
12.7k
  }
284
39.3k
  full_image = std::move(gi);
285
39.3k
  JXL_DEBUG_V(6, "DecodeGlobalInfo: full_image (with transforms) %s",
286
39.3k
              full_image.DebugString().c_str());
287
39.3k
  return dec_status;
288
39.3k
}
289
290
37.8k
void ModularFrameDecoder::MaybeDropFullImage() {
291
37.8k
  if (full_image.transform.empty() && !have_something && all_same_shift) {
292
11.1k
    use_full_image = false;
293
11.1k
    JXL_DEBUG_V(6, "Dropping full image");
294
11.1k
    for (auto& ch : full_image.channel) {
295
      // keep metadata on channels around, but dealloc their planes
296
1.45k
      ch.plane = Plane<pixel_type>();
297
1.45k
    }
298
11.1k
  }
299
37.8k
}
300
301
Status ModularFrameDecoder::DecodeGroup(
302
    const FrameHeader& frame_header, const Rect& rect, BitReader* reader,
303
    int minShift, int maxShift, const ModularStreamId& stream, bool zerofill,
304
    PassesDecoderState* dec_state, RenderPipelineInput* render_pipeline_input,
305
85.8k
    bool allow_truncated, bool* should_run_pipeline) {
306
85.8k
  JXL_DEBUG_V(6, "Decoding %s with rect %s and shift bracket %d..%d %s",
307
85.8k
              stream.DebugString().c_str(), Description(rect).c_str(), minShift,
308
85.8k
              maxShift, zerofill ? "using zerofill" : "");
309
85.8k
  JXL_ENSURE(stream.kind == ModularStreamId::Kind::ModularDC ||
310
85.8k
             stream.kind == ModularStreamId::Kind::ModularAC);
311
85.8k
  const size_t xsize = rect.xsize();
312
85.8k
  const size_t ysize = rect.ysize();
313
85.8k
  JXL_ASSIGN_OR_RETURN(Image gi, Image::Create(memory_manager_, xsize, ysize,
314
85.8k
                                               full_image.bitdepth, 0));
315
  // start at the first bigger-than-groupsize non-metachannel
316
85.8k
  size_t c = full_image.nb_meta_channels;
317
496k
  for (; c < full_image.channel.size(); c++) {
318
425k
    Channel& fc = full_image.channel[c];
319
425k
    if (fc.w > frame_dim.group_dim || fc.h > frame_dim.group_dim) break;
320
425k
  }
321
85.8k
  size_t beginc = c;
322
170k
  for (; c < full_image.channel.size(); c++) {
323
85.0k
    Channel& fc = full_image.channel[c];
324
85.0k
    int shift = std::min(fc.hshift, fc.vshift);
325
85.0k
    if (shift > maxShift) continue;
326
81.4k
    if (shift < minShift) continue;
327
52.1k
    Rect r(rect.x0() >> fc.hshift, rect.y0() >> fc.vshift,
328
52.1k
           rect.xsize() >> fc.hshift, rect.ysize() >> fc.vshift, fc.w, fc.h);
329
52.1k
    if (r.xsize() == 0 || r.ysize() == 0) continue;
330
51.6k
    if (zerofill && use_full_image) {
331
0
      for (size_t y = 0; y < r.ysize(); ++y) {
332
0
        pixel_type* const JXL_RESTRICT row_out = r.Row(&fc.plane, y);
333
0
        memset(row_out, 0, r.xsize() * sizeof(*row_out));
334
0
      }
335
51.6k
    } else {
336
51.6k
      JXL_ASSIGN_OR_RETURN(
337
51.6k
          Channel gc, Channel::Create(memory_manager_, r.xsize(), r.ysize()));
338
51.6k
      if (zerofill) ZeroFillImage(&gc.plane);
339
51.6k
      gc.hshift = fc.hshift;
340
51.6k
      gc.vshift = fc.vshift;
341
51.6k
      gi.channel.emplace_back(std::move(gc));
342
51.6k
    }
343
51.6k
  }
344
85.8k
  if (zerofill && use_full_image) return true;
345
  // Return early if there's nothing to decode. Otherwise there might be
346
  // problems later (in ModularImageToDecodedRect).
347
85.8k
  if (gi.channel.empty()) {
348
74.8k
    if (dec_state && should_run_pipeline) {
349
35.9k
      const auto* metadata = frame_header.nonserialized_metadata;
350
35.9k
      if (do_color || metadata->m.num_extra_channels > 0) {
351
        // Signal to FrameDecoder that we do not have some of the required input
352
        // for the render pipeline.
353
26.0k
        *should_run_pipeline = false;
354
26.0k
      }
355
35.9k
    }
356
74.8k
    JXL_DEBUG_V(6, "Nothing to decode, returning early.");
357
74.8k
    return true;
358
74.8k
  }
359
10.9k
  ModularOptions options;
360
10.9k
  if (!zerofill) {
361
10.3k
    auto status = ModularGenericDecompress(
362
10.3k
        reader, gi, /*header=*/nullptr, stream.ID(frame_dim), &options,
363
10.3k
        /*undo_transforms=*/true, &tree, &code, &context_map, allow_truncated);
364
10.5k
    if (!allow_truncated) JXL_RETURN_IF_ERROR(status);
365
10.1k
    if (status.IsFatalError()) return status;
366
10.1k
  }
367
  // Undo global transforms that have been pushed to the group level
368
10.6k
  if (!use_full_image) {
369
4.61k
    JXL_ENSURE(render_pipeline_input);
370
4.61k
    for (const auto& t : global_transform) {
371
2.68k
      JXL_RETURN_IF_ERROR(t.Inverse(gi, global_header.wp_header));
372
2.68k
    }
373
4.61k
    JXL_RETURN_IF_ERROR(ModularImageToDecodedRect(
374
4.61k
        frame_header, gi, dec_state, nullptr, *render_pipeline_input,
375
4.61k
        Rect(0, 0, gi.w, gi.h)));
376
4.61k
    return true;
377
4.61k
  }
378
6.06k
  int gic = 0;
379
62.7k
  for (c = beginc; c < full_image.channel.size(); c++) {
380
56.6k
    Channel& fc = full_image.channel[c];
381
56.6k
    int shift = std::min(fc.hshift, fc.vshift);
382
56.6k
    if (shift > maxShift) continue;
383
53.0k
    if (shift < minShift) continue;
384
37.6k
    Rect r(rect.x0() >> fc.hshift, rect.y0() >> fc.vshift,
385
37.6k
           rect.xsize() >> fc.hshift, rect.ysize() >> fc.vshift, fc.w, fc.h);
386
37.6k
    if (r.xsize() == 0 || r.ysize() == 0) continue;
387
37.1k
    JXL_ENSURE(use_full_image);
388
37.1k
    JXL_RETURN_IF_ERROR(
389
37.1k
        CopyImageTo(/*rect_from=*/Rect(0, 0, r.xsize(), r.ysize()),
390
37.1k
                    /*from=*/gi.channel[gic].plane,
391
37.1k
                    /*rect_to=*/r, /*to=*/&fc.plane));
392
37.1k
    gic++;
393
37.1k
  }
394
6.06k
  return true;
395
6.06k
}
396
397
Status ModularFrameDecoder::DecodeVarDCTDC(const FrameHeader& frame_header,
398
                                           size_t group_id, BitReader* reader,
399
12.1k
                                           PassesDecoderState* dec_state) {
400
12.1k
  JxlMemoryManager* memory_manager = dec_state->memory_manager();
401
12.1k
  const Rect r = dec_state->shared->frame_dim.DCGroupRect(group_id);
402
12.1k
  JXL_DEBUG_V(6, "Decoding VarDCT DC with rect %s", Description(r).c_str());
403
  // TODO(eustas): investigate if we could reduce the impact of
404
  //               EvalRationalPolynomial; generally speaking, the limit is
405
  //               2**(128/(3*magic)), where 128 comes from IEEE 754 exponent,
406
  //               3 comes from XybToRgb that cubes the values, and "magic" is
407
  //               the sum of all other contributions. 2**18 is known to lead
408
  //               to NaN on input found by fuzzing (see commit message).
409
12.1k
  JXL_ASSIGN_OR_RETURN(Image image,
410
12.1k
                       Image::Create(memory_manager, r.xsize(), r.ysize(),
411
12.1k
                                     full_image.bitdepth, 3));
412
12.1k
  size_t stream_id = ModularStreamId::VarDCTDC(group_id).ID(frame_dim);
413
12.1k
  reader->Refill();
414
12.1k
  size_t extra_precision = reader->ReadFixedBits<2>();
415
12.1k
  float mul = 1.0f / (1 << extra_precision);
416
12.1k
  ModularOptions options;
417
48.6k
  for (size_t c = 0; c < 3; c++) {
418
36.4k
    Channel& ch = image.channel[c < 2 ? c ^ 1 : c];
419
36.4k
    ch.w >>= frame_header.chroma_subsampling.HShift(c);
420
36.4k
    ch.h >>= frame_header.chroma_subsampling.VShift(c);
421
36.4k
    JXL_RETURN_IF_ERROR(ch.shrink());
422
36.4k
  }
423
12.1k
  if (!ModularGenericDecompress(
424
12.1k
          reader, image, /*header=*/nullptr, stream_id, &options,
425
12.1k
          /*undo_transforms=*/true, &tree, &code, &context_map)) {
426
535
    return JXL_FAILURE("Failed to decode VarDCT DC group (DC group id %d)",
427
535
                       static_cast<int>(group_id));
428
535
  }
429
11.6k
  DequantDC(r, &dec_state->shared_storage.dc_storage,
430
11.6k
            &dec_state->shared_storage.quant_dc, image,
431
11.6k
            dec_state->shared->quantizer.MulDC(), mul,
432
11.6k
            dec_state->shared->cmap.base().DCFactors(),
433
11.6k
            frame_header.chroma_subsampling, dec_state->shared->block_ctx_map);
434
11.6k
  return true;
435
12.1k
}
436
437
Status ModularFrameDecoder::DecodeAcMetadata(const FrameHeader& frame_header,
438
                                             size_t group_id, BitReader* reader,
439
11.6k
                                             PassesDecoderState* dec_state) {
440
11.6k
  JxlMemoryManager* memory_manager = dec_state->memory_manager();
441
11.6k
  const Rect r = dec_state->shared->frame_dim.DCGroupRect(group_id);
442
11.6k
  JXL_DEBUG_V(6, "Decoding AcMetadata with rect %s", Description(r).c_str());
443
11.6k
  size_t upper_bound = r.xsize() * r.ysize();
444
11.6k
  reader->Refill();
445
11.6k
  size_t count = reader->ReadBits(CeilLog2Nonzero(upper_bound)) + 1;
446
11.6k
  size_t stream_id = ModularStreamId::ACMetadata(group_id).ID(frame_dim);
447
  // YToX, YToB, ACS + QF, EPF
448
11.6k
  JXL_ASSIGN_OR_RETURN(Image image,
449
11.6k
                       Image::Create(memory_manager, r.xsize(), r.ysize(),
450
11.6k
                                     full_image.bitdepth, 4));
451
11.6k
  static_assert(kColorTileDimInBlocks == 8, "Color tile size changed");
452
11.6k
  Rect cr(r.x0() >> 3, r.y0() >> 3, (r.xsize() + 7) >> 3, (r.ysize() + 7) >> 3);
453
11.6k
  JXL_ASSIGN_OR_RETURN(
454
11.6k
      image.channel[0],
455
11.6k
      Channel::Create(memory_manager, cr.xsize(), cr.ysize(), 3, 3));
456
11.6k
  JXL_ASSIGN_OR_RETURN(
457
11.6k
      image.channel[1],
458
11.6k
      Channel::Create(memory_manager, cr.xsize(), cr.ysize(), 3, 3));
459
11.6k
  JXL_ASSIGN_OR_RETURN(image.channel[2],
460
11.6k
                       Channel::Create(memory_manager, count, 2, 0, 0));
461
11.6k
  ModularOptions options;
462
11.6k
  if (!ModularGenericDecompress(
463
11.6k
          reader, image, /*header=*/nullptr, stream_id, &options,
464
11.6k
          /*undo_transforms=*/true, &tree, &code, &context_map)) {
465
719
    return JXL_FAILURE("Failed to decode AC metadata");
466
719
  }
467
10.8k
  JXL_RETURN_IF_ERROR(
468
10.8k
      ConvertPlaneAndClamp(Rect(image.channel[0].plane), image.channel[0].plane,
469
10.8k
                           cr, &dec_state->shared_storage.cmap.ytox_map));
470
10.8k
  JXL_RETURN_IF_ERROR(
471
10.8k
      ConvertPlaneAndClamp(Rect(image.channel[1].plane), image.channel[1].plane,
472
10.8k
                           cr, &dec_state->shared_storage.cmap.ytob_map));
473
10.8k
  size_t num = 0;
474
10.8k
  bool is444 = frame_header.chroma_subsampling.Is444();
475
10.8k
  auto& ac_strategy = dec_state->shared_storage.ac_strategy;
476
10.8k
  size_t xlim = std::min(ac_strategy.xsize(), r.x0() + r.xsize());
477
10.8k
  size_t ylim = std::min(ac_strategy.ysize(), r.y0() + r.ysize());
478
10.8k
  uint32_t local_used_acs = 0;
479
63.0k
  for (size_t iy = 0; iy < r.ysize(); iy++) {
480
52.3k
    size_t y = r.y0() + iy;
481
52.3k
    int32_t* row_qf = r.Row(&dec_state->shared_storage.raw_quant_field, iy);
482
52.3k
    uint8_t* row_epf = r.Row(&dec_state->shared_storage.epf_sharpness, iy);
483
52.3k
    int32_t* row_in_1 = image.channel[2].plane.Row(0);
484
52.3k
    int32_t* row_in_2 = image.channel[2].plane.Row(1);
485
52.3k
    int32_t* row_in_3 = image.channel[3].plane.Row(iy);
486
972k
    for (size_t ix = 0; ix < r.xsize(); ix++) {
487
920k
      size_t x = r.x0() + ix;
488
920k
      int sharpness = row_in_3[ix];
489
920k
      if (sharpness < 0 || sharpness >= LoopFilter::kEpfSharpEntries) {
490
65
        return JXL_FAILURE("Corrupted sharpness field");
491
65
      }
492
920k
      row_epf[ix] = sharpness;
493
920k
      if (ac_strategy.IsValid(x, y)) {
494
172k
        continue;
495
172k
      }
496
497
748k
      if (num >= count) return JXL_FAILURE("Corrupted stream");
498
499
748k
      if (!AcStrategy::IsRawStrategyValid(row_in_1[num])) {
500
45
        return JXL_FAILURE("Invalid AC strategy");
501
45
      }
502
748k
      local_used_acs |= 1u << row_in_1[num];
503
748k
      AcStrategy acs = AcStrategy::FromRawStrategy(row_in_1[num]);
504
748k
      if ((acs.covered_blocks_x() > 1 || acs.covered_blocks_y() > 1) &&
505
748k
          !is444) {
506
1
        return JXL_FAILURE(
507
1
            "AC strategy not compatible with chroma subsampling");
508
1
      }
509
      // Ensure that blocks do not overflow *AC* groups.
510
748k
      size_t next_x_ac_block = (x / kGroupDimInBlocks + 1) * kGroupDimInBlocks;
511
748k
      size_t next_y_ac_block = (y / kGroupDimInBlocks + 1) * kGroupDimInBlocks;
512
748k
      size_t next_x_dct_block = x + acs.covered_blocks_x();
513
748k
      size_t next_y_dct_block = y + acs.covered_blocks_y();
514
748k
      if (next_x_dct_block > next_x_ac_block || next_x_dct_block > xlim) {
515
5
        return JXL_FAILURE("Invalid AC strategy, x overflow");
516
5
      }
517
748k
      if (next_y_dct_block > next_y_ac_block || next_y_dct_block > ylim) {
518
5
        return JXL_FAILURE("Invalid AC strategy, y overflow");
519
5
      }
520
748k
      JXL_RETURN_IF_ERROR(
521
748k
          ac_strategy.SetNoBoundsCheck(x, y, AcStrategyType(row_in_1[num])));
522
748k
      row_qf[ix] = 1 + std::max<int32_t>(0, std::min(Quantizer::kQuantMax - 1,
523
748k
                                                     row_in_2[num]));
524
748k
      num++;
525
748k
    }
526
52.3k
  }
527
10.7k
  dec_state->used_acs |= local_used_acs;
528
10.7k
  if (frame_header.loop_filter.epf_iters > 0) {
529
9.99k
    JXL_RETURN_IF_ERROR(ComputeSigma(frame_header.loop_filter, r, dec_state));
530
9.99k
  }
531
10.7k
  return true;
532
10.7k
}
533
534
Status ModularFrameDecoder::ModularImageToDecodedRect(
535
    const FrameHeader& frame_header, Image& gi, PassesDecoderState* dec_state,
536
    jxl::ThreadPool* pool, RenderPipelineInput& render_pipeline_input,
537
33.1k
    Rect modular_rect) const {
538
33.1k
  const auto* metadata = frame_header.nonserialized_metadata;
539
33.1k
  JXL_ENSURE(gi.transform.empty());
540
541
5.15M
  auto get_row = [&](size_t c, size_t y) {
542
5.15M
    const auto& buffer = render_pipeline_input.GetBuffer(c);
543
5.15M
    return buffer.second.Row(buffer.first, y);
544
5.15M
  };
545
546
33.1k
  size_t c = 0;
547
33.1k
  if (do_color) {
548
33.1k
    const bool rgb_from_gray =
549
33.1k
        metadata->m.color_encoding.IsGray() &&
550
33.1k
        frame_header.color_transform == ColorTransform::kNone;
551
33.1k
    const bool fp = metadata->m.bit_depth.floating_point_sample &&
552
33.1k
                    frame_header.color_transform != ColorTransform::kXYB;
553
129k
    for (; c < 3; c++) {
554
97.2k
      double factor = full_image.bitdepth < 32
555
97.2k
                          ? 1.0 / ((1u << full_image.bitdepth) - 1)
556
97.2k
                          : 0;
557
97.2k
      size_t c_in = c;
558
97.2k
      if (frame_header.color_transform == ColorTransform::kXYB) {
559
46.5k
        factor = dec_state->shared->matrices.DCQuants()[c];
560
        // XYB is encoded as YX(B-Y)
561
46.5k
        if (c < 2) c_in = 1 - c;
562
50.7k
      } else if (rgb_from_gray) {
563
869
        c_in = 0;
564
869
      }
565
97.2k
      JXL_ENSURE(c_in < gi.channel.size());
566
97.2k
      Channel& ch_in = gi.channel[c_in];
567
      // TODO(eustas): could we detect it on earlier stage?
568
97.3k
      if (ch_in.w == 0 || ch_in.h == 0) {
569
0
        return JXL_FAILURE("Empty image");
570
0
      }
571
97.2k
      JXL_ENSURE(ch_in.hshift <= 3 && ch_in.vshift <= 3);
572
97.2k
      Rect r = render_pipeline_input.GetBuffer(c).second;
573
97.2k
      Rect mr(modular_rect.x0() >> ch_in.hshift,
574
97.2k
              modular_rect.y0() >> ch_in.vshift,
575
97.2k
              DivCeil(modular_rect.xsize(), 1 << ch_in.hshift),
576
97.2k
              DivCeil(modular_rect.ysize(), 1 << ch_in.vshift));
577
97.2k
      mr = mr.Crop(ch_in.plane);
578
97.2k
      size_t xsize_shifted = r.xsize();
579
97.2k
      size_t ysize_shifted = r.ysize();
580
97.7k
      if (r.ysize() != mr.ysize() || r.xsize() != mr.xsize()) {
581
0
        return JXL_FAILURE("Dimension mismatch: trying to fit a %" PRIuS
582
0
                           "x%" PRIuS
583
0
                           " modular channel into "
584
0
                           "a %" PRIuS "x%" PRIuS " rect",
585
0
                           mr.xsize(), mr.ysize(), r.xsize(), r.ysize());
586
0
      }
587
97.2k
      if (frame_header.color_transform == ColorTransform::kXYB && c == 2) {
588
15.5k
        JXL_ENSURE(!fp);
589
15.5k
        const auto process_row = [&](const uint32_t task,
590
1.02M
                                     size_t /* thread */) -> Status {
591
1.02M
          const size_t y = task;
592
1.02M
          const pixel_type* const JXL_RESTRICT row_in = mr.Row(&ch_in.plane, y);
593
1.02M
          const pixel_type* const JXL_RESTRICT row_in_Y =
594
1.02M
              mr.Row(&gi.channel[0].plane, y);
595
1.02M
          float* const JXL_RESTRICT row_out = get_row(c, y);
596
1.02M
          HWY_DYNAMIC_DISPATCH(MultiplySum)
597
1.02M
          (xsize_shifted, row_in, row_in_Y, factor, row_out);
598
1.02M
          return true;
599
1.02M
        };
600
15.5k
        JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, ysize_shifted,
601
15.5k
                                      ThreadPool::NoInit, process_row,
602
15.5k
                                      "ModularIntToFloat"));
603
81.7k
      } else if (fp) {
604
26.2k
        int bits = metadata->m.bit_depth.bits_per_sample;
605
26.2k
        int exp_bits = metadata->m.bit_depth.exponent_bits_per_sample;
606
26.2k
        const auto process_row = [&](const uint32_t task,
607
487k
                                     size_t /* thread */) -> Status {
608
487k
          const size_t y = task;
609
487k
          const pixel_type* const JXL_RESTRICT row_in = mr.Row(&ch_in.plane, y);
610
487k
          if (rgb_from_gray) {
611
15.8k
            for (size_t cc = 0; cc < 3; cc++) {
612
11.8k
              float* const JXL_RESTRICT row_out = get_row(cc, y);
613
11.8k
              JXL_RETURN_IF_ERROR(
614
11.8k
                  int_to_float(row_in, row_out, xsize_shifted, bits, exp_bits));
615
11.8k
            }
616
483k
          } else {
617
483k
            float* const JXL_RESTRICT row_out = get_row(c, y);
618
483k
            JXL_RETURN_IF_ERROR(
619
483k
                int_to_float(row_in, row_out, xsize_shifted, bits, exp_bits));
620
483k
          }
621
487k
          return true;
622
487k
        };
623
26.2k
        JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, ysize_shifted,
624
26.2k
                                      ThreadPool::NoInit, process_row,
625
26.2k
                                      "ModularIntToFloat_losslessfloat"));
626
55.5k
      } else {
627
55.5k
        const auto process_row = [&](const uint32_t task,
628
3.59M
                                     size_t /* thread */) -> Status {
629
3.59M
          const size_t y = task;
630
3.59M
          const pixel_type* const JXL_RESTRICT row_in = mr.Row(&ch_in.plane, y);
631
3.59M
          if (rgb_from_gray) {
632
20.8k
            if (full_image.bitdepth < 23) {
633
13.1k
              HWY_DYNAMIC_DISPATCH(RgbFromSingle)
634
13.1k
              (xsize_shifted, row_in, factor, get_row(0, y), get_row(1, y),
635
13.1k
               get_row(2, y));
636
13.1k
            } else {
637
7.69k
              SingleFromSingleAccurate(xsize_shifted, row_in, factor,
638
7.69k
                                       get_row(0, y));
639
7.69k
              SingleFromSingleAccurate(xsize_shifted, row_in, factor,
640
7.69k
                                       get_row(1, y));
641
7.69k
              SingleFromSingleAccurate(xsize_shifted, row_in, factor,
642
7.69k
                                       get_row(2, y));
643
7.69k
            }
644
3.57M
          } else {
645
3.57M
            float* const JXL_RESTRICT row_out = get_row(c, y);
646
3.57M
            if (full_image.bitdepth < 23) {
647
3.48M
              HWY_DYNAMIC_DISPATCH(SingleFromSingle)
648
3.48M
              (xsize_shifted, row_in, factor, row_out);
649
3.48M
            } else {
650
83.7k
              SingleFromSingleAccurate(xsize_shifted, row_in, factor, row_out);
651
83.7k
            }
652
3.57M
          }
653
3.59M
          return true;
654
3.59M
        };
655
55.5k
        JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, ysize_shifted,
656
55.5k
                                      ThreadPool::NoInit, process_row,
657
55.5k
                                      "ModularIntToFloat"));
658
55.5k
      }
659
97.2k
      if (rgb_from_gray) {
660
869
        break;
661
869
      }
662
97.2k
    }
663
33.1k
    if (rgb_from_gray) {
664
869
      c = 1;
665
869
    }
666
33.1k
  }
667
33.1k
  size_t num_extra_channels = metadata->m.num_extra_channels;
668
45.8k
  for (size_t ec = 0; ec < num_extra_channels; ec++, c++) {
669
12.6k
    const ExtraChannelInfo& eci = metadata->m.extra_channel_info[ec];
670
12.6k
    int bits = eci.bit_depth.bits_per_sample;
671
12.6k
    int exp_bits = eci.bit_depth.exponent_bits_per_sample;
672
12.6k
    bool fp = eci.bit_depth.floating_point_sample;
673
12.6k
    JXL_ENSURE(fp || bits < 32);
674
12.6k
    const double factor = fp ? 0 : (1.0 / ((1u << bits) - 1));
675
12.6k
    JXL_ENSURE(c < gi.channel.size());
676
12.6k
    Channel& ch_in = gi.channel[c];
677
12.6k
    const auto& buffer = render_pipeline_input.GetBuffer(3 + ec);
678
12.6k
    Rect r = buffer.second;
679
12.6k
    Rect mr(modular_rect.x0() >> ch_in.hshift,
680
12.6k
            modular_rect.y0() >> ch_in.vshift,
681
12.6k
            DivCeil(modular_rect.xsize(), 1 << ch_in.hshift),
682
12.6k
            DivCeil(modular_rect.ysize(), 1 << ch_in.vshift));
683
12.6k
    mr = mr.Crop(ch_in.plane);
684
12.6k
    if (r.ysize() != mr.ysize() || r.xsize() != mr.xsize()) {
685
0
      return JXL_FAILURE("Dimension mismatch: trying to fit a %" PRIuS
686
0
                         "x%" PRIuS
687
0
                         " modular channel into "
688
0
                         "a %" PRIuS "x%" PRIuS " rect",
689
0
                         mr.xsize(), mr.ysize(), r.xsize(), r.ysize());
690
0
    }
691
668k
    for (size_t y = 0; y < r.ysize(); ++y) {
692
655k
      float* const JXL_RESTRICT row_out = r.Row(buffer.first, y);
693
655k
      const pixel_type* const JXL_RESTRICT row_in = mr.Row(&ch_in.plane, y);
694
655k
      if (fp) {
695
7.06k
        JXL_RETURN_IF_ERROR(
696
7.06k
            int_to_float(row_in, row_out, r.xsize(), bits, exp_bits));
697
648k
      } else {
698
648k
        if (full_image.bitdepth < 23) {
699
578k
          HWY_DYNAMIC_DISPATCH(SingleFromSingle)
700
578k
          (r.xsize(), row_in, factor, row_out);
701
578k
        } else {
702
70.5k
          SingleFromSingleAccurate(r.xsize(), row_in, factor, row_out);
703
70.5k
        }
704
648k
      }
705
655k
    }
706
12.6k
  }
707
33.1k
  return true;
708
33.1k
}
709
710
Status ModularFrameDecoder::FinalizeDecoding(const FrameHeader& frame_header,
711
                                             PassesDecoderState* dec_state,
712
                                             jxl::ThreadPool* pool,
713
36.2k
                                             bool inplace) {
714
36.2k
  if (!use_full_image) return true;
715
26.5k
  JxlMemoryManager* memory_manager = dec_state->memory_manager();
716
26.5k
  Image gi{memory_manager};
717
26.5k
  if (inplace) {
718
26.5k
    gi = std::move(full_image);
719
26.5k
  } else {
720
0
    JXL_ASSIGN_OR_RETURN(gi, Image::Clone(full_image));
721
0
  }
722
26.5k
  size_t xsize = gi.w;
723
26.5k
  size_t ysize = gi.h;
724
725
26.5k
  JXL_DEBUG_V(3, "Finalizing decoding for modular image: %s",
726
26.5k
              gi.DebugString().c_str());
727
728
  // Don't use threads if total image size is smaller than a group
729
26.5k
  if (xsize * ysize < frame_dim.group_dim * frame_dim.group_dim) pool = nullptr;
730
731
  // Undo the global transforms
732
26.5k
  gi.undo_transforms(global_header.wp_header, pool);
733
26.5k
  JXL_ENSURE(global_transform.empty());
734
26.5k
  if (gi.error) return JXL_FAILURE("Undoing transforms failed");
735
736
55.1k
  for (size_t i = 0; i < dec_state->shared->frame_dim.num_groups; i++) {
737
28.5k
    dec_state->render_pipeline->ClearDone(i);
738
28.5k
  }
739
740
26.5k
  const auto init = [&](size_t num_threads) -> Status {
741
26.5k
    bool use_group_ids = (frame_header.encoding == FrameEncoding::kVarDCT ||
742
26.5k
                          (frame_header.flags & FrameHeader::kNoise));
743
26.5k
    JXL_RETURN_IF_ERROR(dec_state->render_pipeline->PrepareForThreads(
744
26.5k
        num_threads, use_group_ids));
745
26.5k
    return true;
746
26.5k
  };
747
26.5k
  const auto process_group = [&](const uint32_t group,
748
28.5k
                                 size_t thread_id) -> Status {
749
28.5k
    RenderPipelineInput input =
750
28.5k
        dec_state->render_pipeline->GetInputBuffers(group, thread_id);
751
28.5k
    JXL_RETURN_IF_ERROR(ModularImageToDecodedRect(
752
28.5k
        frame_header, gi, dec_state, nullptr, input,
753
28.5k
        dec_state->shared->frame_dim.GroupRect(group)));
754
28.5k
    JXL_RETURN_IF_ERROR(input.Done());
755
28.5k
    return true;
756
28.5k
  };
757
26.5k
  JXL_RETURN_IF_ERROR(RunOnPool(pool, 0,
758
26.5k
                                dec_state->shared->frame_dim.num_groups, init,
759
26.5k
                                process_group, "ModularToRect"));
760
26.5k
  return true;
761
26.5k
}
762
763
static constexpr const float kAlmostZero = 1e-8f;
764
765
Status ModularFrameDecoder::DecodeQuantTable(
766
    JxlMemoryManager* memory_manager, size_t required_size_x,
767
    size_t required_size_y, BitReader* br, QuantEncoding* encoding, size_t idx,
768
90
    ModularFrameDecoder* modular_frame_decoder) {
769
90
  JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->qraw.qtable_den));
770
88
  if (encoding->qraw.qtable_den < kAlmostZero) {
771
    // qtable[] values are already checked for <= 0 so the denominator may not
772
    // be negative.
773
4
    return JXL_FAILURE("Invalid qtable_den: value too small");
774
4
  }
775
168
  JXL_ASSIGN_OR_RETURN(
776
168
      Image image,
777
168
      Image::Create(memory_manager, required_size_x, required_size_y, 8, 3));
778
168
  ModularOptions options;
779
168
  if (modular_frame_decoder) {
780
84
    JXL_ASSIGN_OR_RETURN(ModularStreamId qt, ModularStreamId::QuantTable(idx));
781
84
    JXL_RETURN_IF_ERROR(ModularGenericDecompress(
782
84
        br, image, /*header=*/nullptr, qt.ID(modular_frame_decoder->frame_dim),
783
84
        &options, /*undo_transforms=*/true, &modular_frame_decoder->tree,
784
84
        &modular_frame_decoder->code, &modular_frame_decoder->context_map));
785
84
  } else {
786
0
    JXL_RETURN_IF_ERROR(ModularGenericDecompress(br, image, /*header=*/nullptr,
787
0
                                                 0, &options,
788
0
                                                 /*undo_transforms=*/true));
789
0
  }
790
58
  if (!encoding->qraw.qtable) {
791
58
    encoding->qraw.qtable =
792
58
        new std::vector<int>(required_size_x * required_size_y * 3);
793
58
  } else {
794
0
    JXL_ENSURE(encoding->qraw.qtable->size() ==
795
0
               required_size_x * required_size_y * 3);
796
0
  }
797
58
  int* qtable = encoding->qraw.qtable->data();
798
148
  for (size_t c = 0; c < 3; c++) {
799
2.14k
    for (size_t y = 0; y < required_size_y; y++) {
800
2.05k
      int32_t* JXL_RESTRICT row = image.channel[c].Row(y);
801
210k
      for (size_t x = 0; x < required_size_x; x++) {
802
208k
        qtable[c * required_size_x * required_size_y + y * required_size_x +
803
208k
               x] = row[x];
804
208k
        if (row[x] <= 0) {
805
35
          return JXL_FAILURE("Invalid raw quantization table");
806
35
        }
807
208k
      }
808
2.05k
    }
809
125
  }
810
23
  return true;
811
58
}
812
813
}  // namespace jxl
814
#endif  // HWY_ONCE