Coverage Report

Created: 2024-05-21 06:24

/src/libjxl/lib/jxl/dec_modular.cc
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#include "lib/jxl/dec_modular.h"
7
8
#include <jxl/memory_manager.h>
9
10
#include <atomic>
11
#include <cstdint>
12
#include <vector>
13
14
#include "lib/jxl/frame_header.h"
15
16
#undef HWY_TARGET_INCLUDE
17
#define HWY_TARGET_INCLUDE "lib/jxl/dec_modular.cc"
18
#include <hwy/foreach_target.h>
19
#include <hwy/highway.h>
20
21
#include "lib/jxl/base/compiler_specific.h"
22
#include "lib/jxl/base/printf_macros.h"
23
#include "lib/jxl/base/rect.h"
24
#include "lib/jxl/base/status.h"
25
#include "lib/jxl/compressed_dc.h"
26
#include "lib/jxl/epf.h"
27
#include "lib/jxl/modular/encoding/encoding.h"
28
#include "lib/jxl/modular/modular_image.h"
29
#include "lib/jxl/modular/transform/transform.h"
30
31
HWY_BEFORE_NAMESPACE();
32
namespace jxl {
33
namespace HWY_NAMESPACE {
34
35
// These templates are not found via ADL.
36
using hwy::HWY_NAMESPACE::Add;
37
using hwy::HWY_NAMESPACE::Mul;
38
using hwy::HWY_NAMESPACE::Rebind;
39
40
void MultiplySum(const size_t xsize,
41
                 const pixel_type* const JXL_RESTRICT row_in,
42
                 const pixel_type* const JXL_RESTRICT row_in_Y,
43
0
                 const float factor, float* const JXL_RESTRICT row_out) {
44
0
  const HWY_FULL(float) df;
45
0
  const Rebind<pixel_type, HWY_FULL(float)> di;  // assumes pixel_type <= float
46
0
  const auto factor_v = Set(df, factor);
47
0
  for (size_t x = 0; x < xsize; x += Lanes(di)) {
48
0
    const auto in = Add(Load(di, row_in + x), Load(di, row_in_Y + x));
49
0
    const auto out = Mul(ConvertTo(df, in), factor_v);
50
0
    Store(out, df, row_out + x);
51
0
  }
52
0
}
Unexecuted instantiation: jxl::N_SSE4::MultiplySum(unsigned long, int const*, int const*, float, float*)
Unexecuted instantiation: jxl::N_AVX2::MultiplySum(unsigned long, int const*, int const*, float, float*)
Unexecuted instantiation: jxl::N_SSE2::MultiplySum(unsigned long, int const*, int const*, float, float*)
53
54
void RgbFromSingle(const size_t xsize,
55
                   const pixel_type* const JXL_RESTRICT row_in,
56
                   const float factor, float* out_r, float* out_g,
57
0
                   float* out_b) {
58
0
  const HWY_FULL(float) df;
59
0
  const Rebind<pixel_type, HWY_FULL(float)> di;  // assumes pixel_type <= float
60
61
0
  const auto factor_v = Set(df, factor);
62
0
  for (size_t x = 0; x < xsize; x += Lanes(di)) {
63
0
    const auto in = Load(di, row_in + x);
64
0
    const auto out = Mul(ConvertTo(df, in), factor_v);
65
0
    Store(out, df, out_r + x);
66
0
    Store(out, df, out_g + x);
67
0
    Store(out, df, out_b + x);
68
0
  }
69
0
}
Unexecuted instantiation: jxl::N_SSE4::RgbFromSingle(unsigned long, int const*, float, float*, float*, float*)
Unexecuted instantiation: jxl::N_AVX2::RgbFromSingle(unsigned long, int const*, float, float*, float*, float*)
Unexecuted instantiation: jxl::N_SSE2::RgbFromSingle(unsigned long, int const*, float, float*, float*, float*)
70
71
void SingleFromSingle(const size_t xsize,
72
                      const pixel_type* const JXL_RESTRICT row_in,
73
0
                      const float factor, float* row_out) {
74
0
  const HWY_FULL(float) df;
75
0
  const Rebind<pixel_type, HWY_FULL(float)> di;  // assumes pixel_type <= float
76
77
0
  const auto factor_v = Set(df, factor);
78
0
  for (size_t x = 0; x < xsize; x += Lanes(di)) {
79
0
    const auto in = Load(di, row_in + x);
80
0
    const auto out = Mul(ConvertTo(df, in), factor_v);
81
0
    Store(out, df, row_out + x);
82
0
  }
83
0
}
Unexecuted instantiation: jxl::N_SSE4::SingleFromSingle(unsigned long, int const*, float, float*)
Unexecuted instantiation: jxl::N_AVX2::SingleFromSingle(unsigned long, int const*, float, float*)
Unexecuted instantiation: jxl::N_SSE2::SingleFromSingle(unsigned long, int const*, float, float*)
84
// NOLINTNEXTLINE(google-readability-namespace-comments)
85
}  // namespace HWY_NAMESPACE
86
}  // namespace jxl
87
HWY_AFTER_NAMESPACE();
88
89
#if HWY_ONCE
90
namespace jxl {
91
HWY_EXPORT(MultiplySum);       // Local function
92
HWY_EXPORT(RgbFromSingle);     // Local function
93
HWY_EXPORT(SingleFromSingle);  // Local function
94
95
// Slow conversion using double precision multiplication, only
96
// needed when the bit depth is too high for single precision
97
void SingleFromSingleAccurate(const size_t xsize,
98
                              const pixel_type* const JXL_RESTRICT row_in,
99
0
                              const double factor, float* row_out) {
100
0
  for (size_t x = 0; x < xsize; x++) {
101
0
    row_out[x] = row_in[x] * factor;
102
0
  }
103
0
}
104
105
// convert custom [bits]-bit float (with [exp_bits] exponent bits) stored as int
106
// back to binary32 float
107
void int_to_float(const pixel_type* const JXL_RESTRICT row_in,
108
                  float* const JXL_RESTRICT row_out, const size_t xsize,
109
0
                  const int bits, const int exp_bits) {
110
0
  if (bits == 32) {
111
0
    JXL_ASSERT(sizeof(pixel_type) == sizeof(float));
112
0
    JXL_ASSERT(exp_bits == 8);
113
0
    memcpy(row_out, row_in, xsize * sizeof(float));
114
0
    return;
115
0
  }
116
0
  int exp_bias = (1 << (exp_bits - 1)) - 1;
117
0
  int sign_shift = bits - 1;
118
0
  int mant_bits = bits - exp_bits - 1;
119
0
  int mant_shift = 23 - mant_bits;
120
0
  for (size_t x = 0; x < xsize; ++x) {
121
0
    uint32_t f;
122
0
    memcpy(&f, &row_in[x], 4);
123
0
    int signbit = (f >> sign_shift);
124
0
    f &= (1 << sign_shift) - 1;
125
0
    if (f == 0) {
126
0
      row_out[x] = (signbit ? -0.f : 0.f);
127
0
      continue;
128
0
    }
129
0
    int exp = (f >> mant_bits);
130
0
    int mantissa = (f & ((1 << mant_bits) - 1));
131
0
    mantissa <<= mant_shift;
132
    // Try to normalize only if there is space for maneuver.
133
0
    if (exp == 0 && exp_bits < 8) {
134
      // subnormal number
135
0
      while ((mantissa & 0x800000) == 0) {
136
0
        mantissa <<= 1;
137
0
        exp--;
138
0
      }
139
0
      exp++;
140
      // remove leading 1 because it is implicit now
141
0
      mantissa &= 0x7fffff;
142
0
    }
143
0
    exp -= exp_bias;
144
    // broke up the arbitrary float into its parts, now reassemble into
145
    // binary32
146
0
    exp += 127;
147
0
    JXL_ASSERT(exp >= 0);
148
0
    f = (signbit ? 0x80000000 : 0);
149
0
    f |= (exp << 23);
150
0
    f |= mantissa;
151
0
    memcpy(&row_out[x], &f, 4);
152
0
  }
153
0
}
154
155
#if JXL_DEBUG_V_LEVEL >= 1
156
std::string ModularStreamId::DebugString() const {
157
  std::ostringstream os;
158
  os << (kind == kGlobalData   ? "ModularGlobal"
159
         : kind == kVarDCTDC   ? "VarDCTDC"
160
         : kind == kModularDC  ? "ModularDC"
161
         : kind == kACMetadata ? "ACMeta"
162
         : kind == kQuantTable ? "QuantTable"
163
         : kind == kModularAC  ? "ModularAC"
164
                               : "");
165
  if (kind == kVarDCTDC || kind == kModularDC || kind == kACMetadata ||
166
      kind == kModularAC) {
167
    os << " group " << group_id;
168
  }
169
  if (kind == kModularAC) {
170
    os << " pass " << pass_id;
171
  }
172
  if (kind == kQuantTable) {
173
    os << " " << quant_table_id;
174
  }
175
  return os.str();
176
}
177
#endif
178
179
// Decodes the frame-global modular section.
//
// Steps, in order:
//  1. Reads the 1-bit "has tree" flag and, if set (and the reader is not
//     already exhausted in truncated mode), decodes the global tree and its
//     histograms.
//  2. Rejects color bit depths this decoder cannot handle.
//  3. Creates the full-frame modular image, applying chroma subsampling shifts
//     (YCbCr) and extra-channel upsampling shifts to the channel sizes.
//  4. Decompresses the global stream without undoing transforms.
//  5. Records whether any non-meta channel fits in a single group
//     (have_something) and, if none does, moves a lone RCT transform to
//     global_transform so it can be undone per group.
//
// With allow_truncated_group, a non-fatal decode status is returned to the
// caller instead of being treated as an error.
Status ModularFrameDecoder::DecodeGlobalInfo(BitReader* reader,
                                             const FrameHeader& frame_header,
                                             bool allow_truncated_group) {
  JxlMemoryManager* memory_manager = this->memory_manager();
  const auto& metadata = frame_header.nonserialized_metadata->m;
  const auto color_transform = frame_header.color_transform;

  // Color lives in the modular stream only for modular-encoded frames.
  do_color = frame_header.encoding == FrameEncoding::kModular;
  const bool is_gray = metadata.color_encoding.IsGray();
  size_t nb_chans = 3;
  if (is_gray && color_transform == ColorTransform::kNone) {
    nb_chans = 1;
  }
  const size_t nb_extra = metadata.extra_channel_info.size();

  const bool has_tree = static_cast<bool>(reader->ReadBits(1));
  const bool may_read_tree =
      !allow_truncated_group ||
      reader->TotalBitsConsumed() < reader->TotalBytes() * kBitsPerByte;
  if (may_read_tree && has_tree) {
    // The tree size is capped both absolutely and relative to the pixel count.
    const size_t pixel_based_limit =
        1024 +
        frame_dim.xsize * frame_dim.ysize * (nb_chans + nb_extra) / 16;
    const size_t tree_size_limit =
        std::min(static_cast<size_t>(1 << 22), pixel_based_limit);
    JXL_RETURN_IF_ERROR(
        DecodeTree(memory_manager, reader, &tree, tree_size_limit));
    JXL_RETURN_IF_ERROR(DecodeHistograms(
        memory_manager, reader, (tree.size() + 1) / 2, &code, &context_map));
  }
  if (!do_color) nb_chans = 0;

  const bool fp = metadata.bit_depth.floating_point_sample;

  // bits_per_sample is just metadata for XYB images.
  if (metadata.bit_depth.bits_per_sample >= 32 && do_color &&
      color_transform != ColorTransform::kXYB) {
    if (metadata.bit_depth.bits_per_sample > 32) {
      return JXL_FAILURE("bits_per_sample > 32 not supported");
    }
    // Exactly 32 bits from here on: only the float flavor is supported.
    if (!fp) {
      return JXL_FAILURE("uint32_t not supported in dec_modular");
    }
  }

  JXL_ASSIGN_OR_RETURN(
      Image gi,
      Image::Create(memory_manager, frame_dim.xsize, frame_dim.ysize,
                    metadata.bit_depth.bits_per_sample, nb_chans + nb_extra));

  all_same_shift = true;
  if (color_transform == ColorTransform::kYCbCr) {
    // Color channels may be chroma-subsampled: shrink each to its own size.
    for (size_t c = 0; c < nb_chans; c++) {
      Channel& ch = gi.channel[c];
      ch.hshift = frame_header.chroma_subsampling.HShift(c);
      ch.vshift = frame_header.chroma_subsampling.VShift(c);
      const size_t xsize_shifted = DivCeil(frame_dim.xsize, 1 << ch.hshift);
      const size_t ysize_shifted = DivCeil(frame_dim.ysize, 1 << ch.vshift);
      JXL_RETURN_IF_ERROR(ch.shrink(xsize_shifted, ysize_shifted));
      if (ch.hshift != gi.channel[0].hshift ||
          ch.vshift != gi.channel[0].vshift) {
        all_same_shift = false;
      }
    }
  }

  // Extra channels follow the color channels and may use their own upsampling.
  for (size_t ec = 0; ec < nb_extra; ec++) {
    Channel& ch = gi.channel[nb_chans + ec];
    const size_t ecups = frame_header.extra_channel_upsampling[ec];
    JXL_RETURN_IF_ERROR(
        ch.shrink(DivCeil(frame_dim.xsize_upsampled, ecups),
                  DivCeil(frame_dim.ysize_upsampled, ecups)));
    ch.vshift =
        CeilLog2Nonzero(ecups) - CeilLog2Nonzero(frame_header.upsampling);
    ch.hshift = ch.vshift;
    if (ch.hshift != gi.channel[0].hshift ||
        ch.vshift != gi.channel[0].vshift) {
      all_same_shift = false;
    }
  }

  JXL_DEBUG_V(6, "DecodeGlobalInfo: full_image (w/o transforms) %s",
              gi.DebugString().c_str());
  ModularOptions options;
  options.max_chan_size = frame_dim.group_dim;
  options.group_dim = frame_dim.group_dim;
  Status dec_status = ModularGenericDecompress(
      reader, gi, &global_header, ModularStreamId::Global().ID(frame_dim),
      &options,
      /*undo_transforms=*/false, &tree, &code, &context_map,
      allow_truncated_group);
  if (!allow_truncated_group) {
    JXL_RETURN_IF_ERROR(dec_status);
  }
  if (dec_status.IsFatalError()) {
    return JXL_FAILURE("Failed to decode global modular info");
  }

  // TODO(eustas): are we sure this can be done after partial decode?
  // have_something: at least one non-meta channel fits within one group.
  have_something = false;
  for (size_t c = gi.nb_meta_channels; c < gi.channel.size(); c++) {
    const Channel& ch = gi.channel[c];
    if (ch.w <= frame_dim.group_dim && ch.h <= frame_dim.group_dim) {
      have_something = true;
    }
  }
  // move global transforms to groups if possible
  if (!have_something && all_same_shift && gi.transform.size() == 1 &&
      gi.transform[0].id == TransformId::kRCT) {
    global_transform = gi.transform;
    gi.transform.clear();
    // TODO(jon): also move no-delta-palette out (trickier though)
  }
  full_image = std::move(gi);
  JXL_DEBUG_V(6, "DecodeGlobalInfo: full_image (with transforms) %s",
              full_image.DebugString().c_str());
  return dec_status;
}
289
290
0
void ModularFrameDecoder::MaybeDropFullImage() {
291
0
  if (full_image.transform.empty() && !have_something && all_same_shift) {
292
0
    use_full_image = false;
293
0
    JXL_DEBUG_V(6, "Dropping full image");
294
0
    for (auto& ch : full_image.channel) {
295
      // keep metadata on channels around, but dealloc their planes
296
0
      ch.plane = Plane<pixel_type>();
297
0
    }
298
0
  }
299
0
}
300
301
// Decodes (or zero-fills) one modular DC/AC group.
//
// Only channels that are larger than a group and whose shift lies within
// [minShift, maxShift] take part. For each such channel the sub-rectangle
// corresponding to `rect` is either
//  - zeroed directly in full_image (zerofill with use_full_image), or
//  - decoded into a temporary per-group image, which is then copied into
//    full_image (use_full_image) or converted straight into the render
//    pipeline input after undoing group-level global transforms.
// If no channel qualifies, returns early and, when color or extra channels
// are expected, clears *should_run_pipeline.
Status ModularFrameDecoder::DecodeGroup(
    const FrameHeader& frame_header, const Rect& rect, BitReader* reader,
    int minShift, int maxShift, const ModularStreamId& stream, bool zerofill,
    PassesDecoderState* dec_state, RenderPipelineInput* render_pipeline_input,
    bool allow_truncated, bool* should_run_pipeline) {
  JXL_DEBUG_V(6, "Decoding %s with rect %s and shift bracket %d..%d %s",
              stream.DebugString().c_str(), Description(rect).c_str(), minShift,
              maxShift, zerofill ? "using zerofill" : "");
  JXL_DASSERT(stream.kind == ModularStreamId::kModularDC ||
              stream.kind == ModularStreamId::kModularAC);
  const size_t xsize = rect.xsize();
  const size_t ysize = rect.ysize();
  JXL_ASSIGN_OR_RETURN(Image gi, Image::Create(memory_manager_, xsize, ysize,
                                               full_image.bitdepth, 0));

  // True if the channel's shift lies within the requested bracket.
  const auto in_shift_bracket = [&](const Channel& ch) {
    const int shift = std::min(ch.hshift, ch.vshift);
    return shift >= minShift && shift <= maxShift;
  };
  // `rect` expressed in the channel's (shifted) coordinates, clamped to the
  // channel dimensions.
  const auto rect_in_channel = [&](const Channel& ch) {
    return Rect(rect.x0() >> ch.hshift, rect.y0() >> ch.vshift,
                rect.xsize() >> ch.hshift, rect.ysize() >> ch.vshift, ch.w,
                ch.h);
  };

  // start at the first bigger-than-groupsize non-metachannel
  size_t beginc = full_image.nb_meta_channels;
  while (beginc < full_image.channel.size()) {
    const Channel& fc = full_image.channel[beginc];
    if (fc.w > frame_dim.group_dim || fc.h > frame_dim.group_dim) break;
    ++beginc;
  }

  for (size_t c = beginc; c < full_image.channel.size(); c++) {
    Channel& fc = full_image.channel[c];
    if (!in_shift_bracket(fc)) continue;
    const Rect r = rect_in_channel(fc);
    if (r.xsize() == 0 || r.ysize() == 0) continue;
    if (zerofill && use_full_image) {
      // Zero the destination in place; no temporary channel is needed.
      for (size_t y = 0; y < r.ysize(); ++y) {
        pixel_type* const JXL_RESTRICT row_out = r.Row(&fc.plane, y);
        memset(row_out, 0, r.xsize() * sizeof(*row_out));
      }
      continue;
    }
    JXL_ASSIGN_OR_RETURN(
        Channel gc, Channel::Create(memory_manager_, r.xsize(), r.ysize()));
    if (zerofill) ZeroFillImage(&gc.plane);
    gc.hshift = fc.hshift;
    gc.vshift = fc.vshift;
    gi.channel.emplace_back(std::move(gc));
  }
  if (zerofill && use_full_image) return true;

  // Return early if there's nothing to decode. Otherwise there might be
  // problems later (in ModularImageToDecodedRect).
  if (gi.channel.empty()) {
    if (dec_state && should_run_pipeline) {
      const auto* metadata = frame_header.nonserialized_metadata;
      if (do_color || metadata->m.num_extra_channels > 0) {
        // Signal to FrameDecoder that we do not have some of the required input
        // for the render pipeline.
        *should_run_pipeline = false;
      }
    }
    JXL_DEBUG_V(6, "Nothing to decode, returning early.");
    return true;
  }

  ModularOptions options;
  if (!zerofill) {
    auto status = ModularGenericDecompress(
        reader, gi, /*header=*/nullptr, stream.ID(frame_dim), &options,
        /*undo_transforms=*/true, &tree, &code, &context_map, allow_truncated);
    if (!allow_truncated) {
      JXL_RETURN_IF_ERROR(status);
    }
    if (status.IsFatalError()) return status;
  }

  // Undo global transforms that have been pushed to the group level
  if (!use_full_image) {
    JXL_ASSERT(render_pipeline_input);
    for (const auto& t : global_transform) {
      JXL_RETURN_IF_ERROR(t.Inverse(gi, global_header.wp_header));
    }
    JXL_RETURN_IF_ERROR(ModularImageToDecodedRect(
        frame_header, gi, dec_state, nullptr, *render_pipeline_input,
        Rect(0, 0, gi.w, gi.h)));
    return true;
  }

  // Copy the decoded channels back into full_image. The channel selection
  // mirrors the allocation loop above, so decoded channels are consumed in
  // the same order in which they were created.
  int next_decoded = 0;
  for (size_t c = beginc; c < full_image.channel.size(); c++) {
    Channel& fc = full_image.channel[c];
    if (!in_shift_bracket(fc)) continue;
    const Rect r = rect_in_channel(fc);
    if (r.xsize() == 0 || r.ysize() == 0) continue;
    JXL_ASSERT(use_full_image);
    CopyImageTo(/*rect_from=*/Rect(0, 0, r.xsize(), r.ysize()),
                /*from=*/gi.channel[next_decoded].plane,
                /*rect_to=*/r, /*to=*/&fc.plane);
    next_decoded++;
  }
  return true;
}
395
396
// Decodes the DC coefficients of one VarDCT DC group.
//
// Reads a 2-bit extra-precision field, decodes a 3-channel modular image
// sized to the DC group rect (with per-channel chroma subsampling applied),
// and hands the result to DequantDC with a multiplier of
// 1 / 2^extra_precision.
Status ModularFrameDecoder::DecodeVarDCTDC(const FrameHeader& frame_header,
                                           size_t group_id, BitReader* reader,
                                           PassesDecoderState* dec_state) {
  JxlMemoryManager* memory_manager = dec_state->memory_manager();
  const Rect r = dec_state->shared->frame_dim.DCGroupRect(group_id);
  JXL_DEBUG_V(6, "Decoding VarDCT DC with rect %s", Description(r).c_str());
  // TODO(eustas): investigate if we could reduce the impact of
  //               EvalRationalPolynomial; generally speaking, the limit is
  //               2**(128/(3*magic)), where 128 comes from IEEE 754 exponent,
  //               3 comes from XybToRgb that cubes the values, and "magic" is
  //               the sum of all other contributions. 2**18 is known to lead
  //               to NaN on input found by fuzzing (see commit message).
  JXL_ASSIGN_OR_RETURN(Image image,
                       Image::Create(memory_manager, r.xsize(), r.ysize(),
                                     full_image.bitdepth, 3));
  const size_t stream_id = ModularStreamId::VarDCTDC(group_id).ID(frame_dim);
  reader->Refill();
  const size_t extra_precision = reader->ReadFixedBits<2>();
  const float mul = 1.0f / (1 << extra_precision);
  ModularOptions options;

  // Component c is stored in modular channel kChannelOf[c]: the first two
  // components are swapped, the third stays in place.
  static constexpr size_t kChannelOf[3] = {1, 0, 2};
  for (size_t c = 0; c < 3; c++) {
    Channel& ch = image.channel[kChannelOf[c]];
    ch.w >>= frame_header.chroma_subsampling.HShift(c);
    ch.h >>= frame_header.chroma_subsampling.VShift(c);
    JXL_RETURN_IF_ERROR(ch.shrink());
  }

  const bool decoded = static_cast<bool>(ModularGenericDecompress(
      reader, image, /*header=*/nullptr, stream_id, &options,
      /*undo_transforms=*/true, &tree, &code, &context_map));
  if (!decoded) {
    return JXL_FAILURE("Failed to decode VarDCT DC group (DC group id %d)",
                       static_cast<int>(group_id));
  }
  DequantDC(r, &dec_state->shared_storage.dc_storage,
            &dec_state->shared_storage.quant_dc, image,
            dec_state->shared->quantizer.MulDC(), mul,
            dec_state->shared->cmap.base().DCFactors(),
            frame_header.chroma_subsampling, dec_state->shared->block_ctx_map);
  return true;
}
435
436
// Decodes the AC metadata of one DC group.
//
// The modular image has four channels: two color-correlation maps (at 1/8
// resolution), a `count` x 2 channel whose first row holds raw AC strategies
// and whose second row holds quant-field values, and a per-block sharpness
// channel. Sharpness is validated and stored for every block; a strategy and
// quant value are consumed only for blocks not yet covered by a previously
// placed strategy. Each strategy is checked for validity, for compatibility
// with chroma subsampling, and for staying inside its AC group and the image.
Status ModularFrameDecoder::DecodeAcMetadata(const FrameHeader& frame_header,
                                             size_t group_id, BitReader* reader,
                                             PassesDecoderState* dec_state) {
  JxlMemoryManager* memory_manager = dec_state->memory_manager();
  const Rect r = dec_state->shared->frame_dim.DCGroupRect(group_id);
  JXL_DEBUG_V(6, "Decoding AcMetadata with rect %s", Description(r).c_str());
  const size_t upper_bound = r.xsize() * r.ysize();
  reader->Refill();
  // Number of (strategy, quant) entries stored in the third channel.
  const size_t count = reader->ReadBits(CeilLog2Nonzero(upper_bound)) + 1;
  const size_t stream_id = ModularStreamId::ACMetadata(group_id).ID(frame_dim);
  // YToX, YToB, ACS + QF, EPF
  JXL_ASSIGN_OR_RETURN(Image image,
                       Image::Create(memory_manager, r.xsize(), r.ysize(),
                                     full_image.bitdepth, 4));
  static_assert(kColorTileDimInBlocks == 8, "Color tile size changed");
  Rect cr(r.x0() >> 3, r.y0() >> 3, (r.xsize() + 7) >> 3, (r.ysize() + 7) >> 3);
  JXL_ASSIGN_OR_RETURN(
      image.channel[0],
      Channel::Create(memory_manager, cr.xsize(), cr.ysize(), 3, 3));
  JXL_ASSIGN_OR_RETURN(
      image.channel[1],
      Channel::Create(memory_manager, cr.xsize(), cr.ysize(), 3, 3));
  JXL_ASSIGN_OR_RETURN(image.channel[2],
                       Channel::Create(memory_manager, count, 2, 0, 0));
  ModularOptions options;
  const bool decoded = static_cast<bool>(ModularGenericDecompress(
      reader, image, /*header=*/nullptr, stream_id, &options,
      /*undo_transforms=*/true, &tree, &code, &context_map));
  if (!decoded) {
    return JXL_FAILURE("Failed to decode AC metadata");
  }
  ConvertPlaneAndClamp(Rect(image.channel[0].plane), image.channel[0].plane, cr,
                       &dec_state->shared_storage.cmap.ytox_map);
  ConvertPlaneAndClamp(Rect(image.channel[1].plane), image.channel[1].plane, cr,
                       &dec_state->shared_storage.cmap.ytob_map);

  const bool is444 = frame_header.chroma_subsampling.Is444();
  auto& ac_strategy = dec_state->shared_storage.ac_strategy;
  const size_t xlim = std::min(ac_strategy.xsize(), r.x0() + r.xsize());
  const size_t ylim = std::min(ac_strategy.ysize(), r.y0() + r.ysize());
  uint32_t local_used_acs = 0;
  // Index of the next unread (strategy, quant) entry.
  size_t num = 0;
  for (size_t iy = 0; iy < r.ysize(); iy++) {
    const size_t y = r.y0() + iy;
    int32_t* row_qf = r.Row(&dec_state->shared_storage.raw_quant_field, iy);
    uint8_t* row_epf = r.Row(&dec_state->shared_storage.epf_sharpness, iy);
    int32_t* row_strategy = image.channel[2].plane.Row(0);
    int32_t* row_quant = image.channel[2].plane.Row(1);
    int32_t* row_sharpness = image.channel[3].plane.Row(iy);
    for (size_t ix = 0; ix < r.xsize(); ix++) {
      const size_t x = r.x0() + ix;
      const int sharpness = row_sharpness[ix];
      if (sharpness < 0 || sharpness >= LoopFilter::kEpfSharpEntries) {
        return JXL_FAILURE("Corrupted sharpness field");
      }
      row_epf[ix] = sharpness;
      // Blocks for which a strategy is already set consume no entry.
      if (ac_strategy.IsValid(x, y)) {
        continue;
      }

      if (num >= count) return JXL_FAILURE("Corrupted stream");

      if (!AcStrategy::IsRawStrategyValid(row_strategy[num])) {
        return JXL_FAILURE("Invalid AC strategy");
      }
      local_used_acs |= 1u << row_strategy[num];
      AcStrategy acs = AcStrategy::FromRawStrategy(row_strategy[num]);
      if ((acs.covered_blocks_x() > 1 || acs.covered_blocks_y() > 1) &&
          !is444) {
        return JXL_FAILURE(
            "AC strategy not compatible with chroma subsampling");
      }
      // Ensure that blocks do not overflow *AC* groups.
      const size_t x_group_end =
          (x / kGroupDimInBlocks + 1) * kGroupDimInBlocks;
      const size_t y_group_end =
          (y / kGroupDimInBlocks + 1) * kGroupDimInBlocks;
      const size_t x_block_end = x + acs.covered_blocks_x();
      const size_t y_block_end = y + acs.covered_blocks_y();
      if (x_block_end > x_group_end || x_block_end > xlim) {
        return JXL_FAILURE("Invalid AC strategy, x overflow");
      }
      if (y_block_end > y_group_end || y_block_end > ylim) {
        return JXL_FAILURE("Invalid AC strategy, y overflow");
      }
      JXL_RETURN_IF_ERROR(ac_strategy.SetNoBoundsCheck(
          x, y, AcStrategy::Type(row_strategy[num])));
      // Clamp the stored value to [0, kQuantMax - 1], then add 1.
      row_qf[ix] = 1 + std::max<int32_t>(0, std::min(Quantizer::kQuantMax - 1,
                                                     row_quant[num]));
      num++;
    }
  }
  dec_state->used_acs |= local_used_acs;
  if (frame_header.loop_filter.epf_iters > 0) {
    ComputeSigma(frame_header.loop_filter, r, dec_state);
  }
  return true;
}
530
531
// Converts the integer channels of a fully inverse-transformed modular image
// into the float buffers of the render pipeline input.
//
// Color channels (only if do_color):
//  - XYB: scaled by the DC quant factors; the first two channels are swapped
//    on input and the third output is computed from input channels 2 and 0.
//  - floating point samples: bit patterns are re-encoded via int_to_float.
//  - otherwise: integers are scaled by 1 / (2^bitdepth - 1), using the SIMD
//    path below 23 bits and the double-precision path otherwise.
//  Grayscale input without a color transform is replicated to all three
//  outputs.
// Extra channels follow and are written to pipeline buffers 3 + ec.
// Fails if a color channel is empty or if a channel rect and its destination
// rect differ in size.
Status ModularFrameDecoder::ModularImageToDecodedRect(
    const FrameHeader& frame_header, Image& gi, PassesDecoderState* dec_state,
    jxl::ThreadPool* pool, RenderPipelineInput& render_pipeline_input,
    Rect modular_rect) const {
  const auto* metadata = frame_header.nonserialized_metadata;
  JXL_CHECK(gi.transform.empty());

  // Row `y` of render pipeline buffer `c`.
  auto get_row = [&](size_t c, size_t y) {
    const auto& buffer = render_pipeline_input.GetBuffer(c);
    return buffer.second.Row(buffer.first, y);
  };
  // modular_rect scaled by the channel's shifts and cropped to its plane.
  auto rect_in_channel = [&](Channel& ch) -> Rect {
    Rect scaled(modular_rect.x0() >> ch.hshift, modular_rect.y0() >> ch.vshift,
                DivCeil(modular_rect.xsize(), 1 << ch.hshift),
                DivCeil(modular_rect.ysize(), 1 << ch.vshift));
    return scaled.Crop(ch.plane);
  };

  size_t c = 0;
  if (do_color) {
    const bool is_xyb = frame_header.color_transform == ColorTransform::kXYB;
    const bool rgb_from_gray =
        metadata->m.color_encoding.IsGray() &&
        frame_header.color_transform == ColorTransform::kNone;
    const bool fp = metadata->m.bit_depth.floating_point_sample && !is_xyb;
    for (; c < 3; c++) {
      double factor = full_image.bitdepth < 32
                          ? 1.0 / ((1u << full_image.bitdepth) - 1)
                          : 0;
      size_t c_in = c;
      if (is_xyb) {
        factor = dec_state->shared->matrices.DCQuants()[c];
        // XYB is encoded as YX(B-Y)
        if (c < 2) c_in = 1 - c;
      } else if (rgb_from_gray) {
        c_in = 0;
      }
      JXL_ASSERT(c_in < gi.channel.size());
      Channel& ch_in = gi.channel[c_in];
      // TODO(eustas): could we detect it on earlier stage?
      if (ch_in.w == 0 || ch_in.h == 0) {
        return JXL_FAILURE("Empty image");
      }
      JXL_CHECK(ch_in.hshift <= 3 && ch_in.vshift <= 3);
      Rect r = render_pipeline_input.GetBuffer(c).second;
      Rect mr = rect_in_channel(ch_in);
      const size_t xsize_shifted = r.xsize();
      const size_t ysize_shifted = r.ysize();
      if (r.ysize() != mr.ysize() || r.xsize() != mr.xsize()) {
        return JXL_FAILURE("Dimension mismatch: trying to fit a %" PRIuS
                           "x%" PRIuS
                           " modular channel into "
                           "a %" PRIuS "x%" PRIuS " rect",
                           mr.xsize(), mr.ysize(), r.xsize(), r.ysize());
      }
      if (is_xyb && c == 2) {
        // Third XYB output: sum of input channels 2 and 0, then scaled.
        JXL_ASSERT(!fp);
        JXL_RETURN_IF_ERROR(RunOnPool(
            pool, 0, ysize_shifted, ThreadPool::NoInit,
            [&](const uint32_t task, size_t /* thread */) {
              const size_t y = task;
              const pixel_type* const JXL_RESTRICT row_in =
                  mr.Row(&ch_in.plane, y);
              const pixel_type* const JXL_RESTRICT row_in_Y =
                  mr.Row(&gi.channel[0].plane, y);
              float* const JXL_RESTRICT row_out = get_row(c, y);
              HWY_DYNAMIC_DISPATCH(MultiplySum)
              (xsize_shifted, row_in, row_in_Y, factor, row_out);
            },
            "ModularIntToFloat"));
      } else if (fp) {
        // Lossless float: re-encode the stored bit patterns as binary32.
        const int bits = metadata->m.bit_depth.bits_per_sample;
        const int exp_bits = metadata->m.bit_depth.exponent_bits_per_sample;
        JXL_RETURN_IF_ERROR(RunOnPool(
            pool, 0, ysize_shifted, ThreadPool::NoInit,
            [&](const uint32_t task, size_t /* thread */) {
              const size_t y = task;
              const pixel_type* const JXL_RESTRICT row_in =
                  mr.Row(&ch_in.plane, y);
              if (!rgb_from_gray) {
                float* const JXL_RESTRICT row_out = get_row(c, y);
                int_to_float(row_in, row_out, xsize_shifted, bits, exp_bits);
                return;
              }
              for (size_t cc = 0; cc < 3; cc++) {
                float* const JXL_RESTRICT row_out = get_row(cc, y);
                int_to_float(row_in, row_out, xsize_shifted, bits, exp_bits);
              }
            },
            "ModularIntToFloat_losslessfloat"));
      } else {
        // Integer samples: scale by `factor`.
        JXL_RETURN_IF_ERROR(RunOnPool(
            pool, 0, ysize_shifted, ThreadPool::NoInit,
            [&](const uint32_t task, size_t /* thread */) {
              const size_t y = task;
              const pixel_type* const JXL_RESTRICT row_in =
                  mr.Row(&ch_in.plane, y);
              const bool use_simd = full_image.bitdepth < 23;
              if (rgb_from_gray) {
                if (use_simd) {
                  HWY_DYNAMIC_DISPATCH(RgbFromSingle)
                  (xsize_shifted, row_in, factor, get_row(0, y), get_row(1, y),
                   get_row(2, y));
                } else {
                  for (size_t cc = 0; cc < 3; cc++) {
                    SingleFromSingleAccurate(xsize_shifted, row_in, factor,
                                             get_row(cc, y));
                  }
                }
              } else {
                float* const JXL_RESTRICT row_out = get_row(c, y);
                if (use_simd) {
                  HWY_DYNAMIC_DISPATCH(SingleFromSingle)
                  (xsize_shifted, row_in, factor, row_out);
                } else {
                  SingleFromSingleAccurate(xsize_shifted, row_in, factor,
                                           row_out);
                }
              }
            },
            "ModularIntToFloat"));
      }
      // A single pass already filled all three outputs from the gray channel.
      if (rgb_from_gray) {
        break;
      }
    }
    if (rgb_from_gray) {
      // Only one modular channel was consumed for color.
      c = 1;
    }
  }

  const size_t num_extra_channels = metadata->m.num_extra_channels;
  for (size_t ec = 0; ec < num_extra_channels; ec++, c++) {
    const ExtraChannelInfo& eci = metadata->m.extra_channel_info[ec];
    const int bits = eci.bit_depth.bits_per_sample;
    const int exp_bits = eci.bit_depth.exponent_bits_per_sample;
    const bool fp = eci.bit_depth.floating_point_sample;
    JXL_ASSERT(fp || bits < 32);
    const double factor = fp ? 0 : (1.0 / ((1u << bits) - 1));
    JXL_ASSERT(c < gi.channel.size());
    Channel& ch_in = gi.channel[c];
    Rect r = render_pipeline_input.GetBuffer(3 + ec).second;
    Rect mr = rect_in_channel(ch_in);
    if (r.ysize() != mr.ysize() || r.xsize() != mr.xsize()) {
      return JXL_FAILURE("Dimension mismatch: trying to fit a %" PRIuS
                         "x%" PRIuS
                         " modular channel into "
                         "a %" PRIuS "x%" PRIuS " rect",
                         mr.xsize(), mr.ysize(), r.xsize(), r.ysize());
    }
    for (size_t y = 0; y < r.ysize(); ++y) {
      float* const JXL_RESTRICT row_out =
          r.Row(render_pipeline_input.GetBuffer(3 + ec).first, y);
      const pixel_type* const JXL_RESTRICT row_in = mr.Row(&ch_in.plane, y);
      if (fp) {
        int_to_float(row_in, row_out, r.xsize(), bits, exp_bits);
      } else if (full_image.bitdepth < 23) {
        HWY_DYNAMIC_DISPATCH(SingleFromSingle)
        (r.xsize(), row_in, factor, row_out);
      } else {
        SingleFromSingleAccurate(r.xsize(), row_in, factor, row_out);
      }
    }
  }
  return true;
}
701
702
// Undoes the global transforms on the full modular image and hands the result
// to the render pipeline one group at a time. Returns immediately with success
// when no full image is in use. When `inplace` is set the stored full image is
// moved from; otherwise a clone of it is processed.
Status ModularFrameDecoder::FinalizeDecoding(const FrameHeader& frame_header,
                                             PassesDecoderState* dec_state,
                                             jxl::ThreadPool* pool,
                                             bool inplace) {
  if (!use_full_image) return true;

  JxlMemoryManager* memory_manager = dec_state->memory_manager();
  Image gi{memory_manager};
  if (!inplace) {
    JXL_ASSIGN_OR_RETURN(gi, Image::Clone(full_image));
  } else {
    gi = std::move(full_image);
  }
  const size_t width = gi.w;
  const size_t height = gi.h;

  JXL_DEBUG_V(3, "Finalizing decoding for modular image: %s",
              gi.DebugString().c_str());

  // An image with fewer pixels than a single group is processed without the
  // thread pool.
  const auto group_area = frame_dim.group_dim * frame_dim.group_dim;
  if (width * height < group_area) {
    pool = nullptr;
  }

  // Invert the global transforms.
  gi.undo_transforms(global_header.wp_header, pool);
  JXL_DASSERT(global_transform.empty());
  if (gi.error) return JXL_FAILURE("Undoing transforms failed");

  const auto& dims = dec_state->shared->frame_dim;
  for (size_t group = 0; group < dims.num_groups; ++group) {
    dec_state->render_pipeline->ClearDone(group);
  }

  // Set by any worker that fails; once set, remaining tasks return early.
  std::atomic<bool> has_error{false};

  const auto prepare = [&](size_t num_threads) {
    const bool use_group_ids =
        (frame_header.encoding == FrameEncoding::kVarDCT ||
         (frame_header.flags & FrameHeader::kNoise));
    return dec_state->render_pipeline->PrepareForThreads(num_threads,
                                                         use_group_ids);
  };

  const auto render_group = [&](const uint32_t group, size_t thread_id) {
    if (has_error.load()) return;
    RenderPipelineInput input =
        dec_state->render_pipeline->GetInputBuffers(group, thread_id);
    // Short-circuit: Done() is only invoked when the conversion succeeded.
    if (!ModularImageToDecodedRect(frame_header, gi, dec_state, nullptr, input,
                                   dims.GroupRect(group)) ||
        !input.Done()) {
      has_error.store(true);
    }
  };

  JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, dims.num_groups, prepare,
                                render_group, "ModularToRect"));
  if (has_error.load()) {
    return JXL_FAILURE("Error producing input to render pipeline");
  }
  return true;
}
759
760
// Smallest raw quantization table denominator (qtable_den) that
// DecodeQuantTable accepts; anything below this is rejected as invalid.
static constexpr const float kAlmostZero = 1e-8f;
761
762
// Reads a raw quantization table from `br` into `encoding->qraw`: first the
// F16 denominator, then a 3-channel modular image of size
// required_size_x x required_size_y whose samples become the table entries
// (stored planar: channel, then row, then column). When
// `modular_frame_decoder` is non-null, its tree, code and context map are
// used with the QuantTable(idx) stream id; otherwise the image is decoded
// with stream id 0. Fails if the denominator is below kAlmostZero or any
// table entry is not strictly positive.
Status ModularFrameDecoder::DecodeQuantTable(
    JxlMemoryManager* memory_manager, size_t required_size_x,
    size_t required_size_y, BitReader* br, QuantEncoding* encoding, size_t idx,
    ModularFrameDecoder* modular_frame_decoder) {
  auto& raw = encoding->qraw;

  JXL_RETURN_IF_ERROR(F16Coder::Read(br, &raw.qtable_den));
  // Table entries are required to be strictly positive (validated below), and
  // the denominator must likewise not be (nearly) zero or negative.
  if (raw.qtable_den < kAlmostZero) {
    return JXL_FAILURE("Invalid qtable_den: value too small");
  }

  JXL_ASSIGN_OR_RETURN(
      Image image,
      Image::Create(memory_manager, required_size_x, required_size_y, 8, 3));
  ModularOptions options;
  if (modular_frame_decoder == nullptr) {
    JXL_RETURN_IF_ERROR(ModularGenericDecompress(br, image, /*header=*/nullptr,
                                                 0, &options,
                                                 /*undo_transforms=*/true));
  } else {
    JXL_RETURN_IF_ERROR(ModularGenericDecompress(
        br, image, /*header=*/nullptr,
        ModularStreamId::QuantTable(idx).ID(modular_frame_decoder->frame_dim),
        &options, /*undo_transforms=*/true, &modular_frame_decoder->tree,
        &modular_frame_decoder->code, &modular_frame_decoder->context_map));
  }

  // The table storage is allocated on first use and kept for later calls.
  if (raw.qtable == nullptr) {
    raw.qtable = new std::vector<int>();
  }
  std::vector<int>& table = *raw.qtable;
  table.resize(required_size_x * required_size_y * 3);

  // `out` walks the table in channel/row/column order, which is exactly
  // c * (size_x * size_y) + y * size_x + x.
  size_t out = 0;
  for (size_t c = 0; c < 3; ++c) {
    for (size_t y = 0; y < required_size_y; ++y) {
      const int32_t* JXL_RESTRICT samples = image.channel[c].Row(y);
      for (size_t x = 0; x < required_size_x; ++x, ++out) {
        const int32_t value = samples[x];
        // The value is stored before validation, so on failure the table
        // holds everything decoded up to and including the offending entry.
        table[out] = value;
        if (value <= 0) {
          return JXL_FAILURE("Invalid raw quantization table");
        }
      }
    }
  }
  return true;
}
805
806
}  // namespace jxl
807
#endif  // HWY_ONCE