Coverage Report

Created: 2025-06-13 07:37

/src/libjxl/lib/jxl/dec_group.cc
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#include "lib/jxl/dec_group.h"
7
8
#include <jxl/memory_manager.h>
9
10
#include <algorithm>
11
#include <array>
12
#include <cstdint>
13
#include <cstdio>
14
#include <cstdlib>
15
#include <cstring>
16
#include <memory>
17
#include <utility>
18
#include <vector>
19
20
#include "lib/jxl/base/compiler_specific.h"
21
#include "lib/jxl/chroma_from_luma.h"
22
#include "lib/jxl/coeff_order_fwd.h"
23
#include "lib/jxl/dct_util.h"
24
#include "lib/jxl/dec_ans.h"
25
#include "lib/jxl/frame_dimensions.h"
26
#include "lib/jxl/frame_header.h"
27
#include "lib/jxl/image.h"
28
#include "lib/jxl/image_ops.h"
29
#include "lib/jxl/jpeg/jpeg_data.h"
30
#include "lib/jxl/render_pipeline/render_pipeline.h"
31
#include "lib/jxl/render_pipeline/render_pipeline_stage.h"
32
33
#undef HWY_TARGET_INCLUDE
34
#define HWY_TARGET_INCLUDE "lib/jxl/dec_group.cc"
35
#include <hwy/foreach_target.h>
36
#include <hwy/highway.h>
37
38
#include "lib/jxl/ac_context.h"
39
#include "lib/jxl/ac_strategy.h"
40
#include "lib/jxl/base/bits.h"
41
#include "lib/jxl/base/common.h"
42
#include "lib/jxl/base/printf_macros.h"
43
#include "lib/jxl/base/rect.h"
44
#include "lib/jxl/base/status.h"
45
#include "lib/jxl/coeff_order.h"
46
#include "lib/jxl/common.h"  // kMaxNumPasses
47
#include "lib/jxl/dec_cache.h"
48
#include "lib/jxl/dec_transforms-inl.h"
49
#include "lib/jxl/dec_xyb.h"
50
#include "lib/jxl/entropy_coder.h"
51
#include "lib/jxl/quant_weights.h"
52
#include "lib/jxl/quantizer-inl.h"
53
#include "lib/jxl/quantizer.h"
54
55
#ifndef LIB_JXL_DEC_GROUP_CC
56
#define LIB_JXL_DEC_GROUP_CC
57
namespace jxl {
58
59
struct AuxOut;
60
61
// Interface for reading groups for DecodeGroupImpl.
62
class GetBlock {
63
 public:
64
  virtual void StartRow(size_t by) = 0;
65
  virtual Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs,
66
                           size_t size, size_t log2_covered_blocks,
67
                           ACPtr block[3], ACType ac_type) = 0;
68
27.1k
  virtual ~GetBlock() {}
69
};
70
71
// Controls whether DecodeGroupImpl renders to pixels or not.
72
enum DrawMode {
73
  // Render to pixels.
74
  kDraw = 0,
75
  // Don't render to pixels.
76
  kDontDraw = 1,
77
};
78
79
}  // namespace jxl
80
#endif  // LIB_JXL_DEC_GROUP_CC
81
82
HWY_BEFORE_NAMESPACE();
83
namespace jxl {
84
namespace HWY_NAMESPACE {
85
86
// These templates are not found via ADL.
87
using hwy::HWY_NAMESPACE::AllFalse;
88
using hwy::HWY_NAMESPACE::Gt;
89
using hwy::HWY_NAMESPACE::Le;
90
using hwy::HWY_NAMESPACE::MaskFromVec;
91
using hwy::HWY_NAMESPACE::Or;
92
using hwy::HWY_NAMESPACE::Rebind;
93
using hwy::HWY_NAMESPACE::ShiftRight;
94
95
using D = HWY_FULL(float);
96
using DU = HWY_FULL(uint32_t);
97
using DI = HWY_FULL(int32_t);
98
using DI16 = Rebind<int16_t, DI>;
99
using DI16_FULL = HWY_CAPPED(int16_t, kDCTBlockSize);
100
constexpr D d;
101
constexpr DI di;
102
constexpr DI16 di16;
103
constexpr DI16_FULL di16_full;
104
105
// TODO(veluca): consider SIMDfying.
106
376
void Transpose8x8InPlace(int32_t* JXL_RESTRICT block) {
107
3.38k
  for (size_t x = 0; x < 8; x++) {
108
13.5k
    for (size_t y = x + 1; y < 8; y++) {
109
10.5k
      std::swap(block[y * 8 + x], block[x * 8 + y]);
110
10.5k
    }
111
3.00k
  }
112
376
}
jxl::N_SSE4::Transpose8x8InPlace(int*)
Line
Count
Source
106
95
void Transpose8x8InPlace(int32_t* JXL_RESTRICT block) {
107
855
  for (size_t x = 0; x < 8; x++) {
108
3.42k
    for (size_t y = x + 1; y < 8; y++) {
109
2.66k
      std::swap(block[y * 8 + x], block[x * 8 + y]);
110
2.66k
    }
111
760
  }
112
95
}
jxl::N_AVX2::Transpose8x8InPlace(int*)
Line
Count
Source
106
263
void Transpose8x8InPlace(int32_t* JXL_RESTRICT block) {
107
2.36k
  for (size_t x = 0; x < 8; x++) {
108
9.46k
    for (size_t y = x + 1; y < 8; y++) {
109
7.36k
      std::swap(block[y * 8 + x], block[x * 8 + y]);
110
7.36k
    }
111
2.10k
  }
112
263
}
jxl::N_SSE2::Transpose8x8InPlace(int*)
Line
Count
Source
106
18
void Transpose8x8InPlace(int32_t* JXL_RESTRICT block) {
107
162
  for (size_t x = 0; x < 8; x++) {
108
648
    for (size_t y = x + 1; y < 8; y++) {
109
504
      std::swap(block[y * 8 + x], block[x * 8 + y]);
110
504
    }
111
144
  }
112
18
}
113
114
template <ACType ac_type>
115
void DequantLane(Vec<D> scaled_dequant_x, Vec<D> scaled_dequant_y,
116
                 Vec<D> scaled_dequant_b,
117
                 const float* JXL_RESTRICT dequant_matrices, size_t size,
118
                 size_t k, Vec<D> x_cc_mul, Vec<D> b_cc_mul,
119
                 const float* JXL_RESTRICT biases, ACPtr qblock[3],
120
23.0M
                 float* JXL_RESTRICT block) {
121
23.0M
  const auto x_mul = Mul(Load(d, dequant_matrices + k), scaled_dequant_x);
122
23.0M
  const auto y_mul =
123
23.0M
      Mul(Load(d, dequant_matrices + size + k), scaled_dequant_y);
124
23.0M
  const auto b_mul =
125
23.0M
      Mul(Load(d, dequant_matrices + 2 * size + k), scaled_dequant_b);
126
127
23.0M
  Vec<DI> quantized_x_int;
128
23.0M
  Vec<DI> quantized_y_int;
129
23.0M
  Vec<DI> quantized_b_int;
130
23.0M
  if (ac_type == ACType::k16) {
131
12.2M
    quantized_x_int = PromoteTo(di, Load(di16, qblock[0].ptr16 + k));
132
12.2M
    quantized_y_int = PromoteTo(di, Load(di16, qblock[1].ptr16 + k));
133
12.2M
    quantized_b_int = PromoteTo(di, Load(di16, qblock[2].ptr16 + k));
134
12.2M
  } else {
135
10.7M
    quantized_x_int = Load(di, qblock[0].ptr32 + k);
136
10.7M
    quantized_y_int = Load(di, qblock[1].ptr32 + k);
137
10.7M
    quantized_b_int = Load(di, qblock[2].ptr32 + k);
138
10.7M
  }
139
140
23.0M
  const auto dequant_x_cc =
141
23.0M
      Mul(AdjustQuantBias(di, 0, quantized_x_int, biases), x_mul);
142
23.0M
  const auto dequant_y =
143
23.0M
      Mul(AdjustQuantBias(di, 1, quantized_y_int, biases), y_mul);
144
23.0M
  const auto dequant_b_cc =
145
23.0M
      Mul(AdjustQuantBias(di, 2, quantized_b_int, biases), b_mul);
146
147
23.0M
  const auto dequant_x = MulAdd(x_cc_mul, dequant_y, dequant_x_cc);
148
23.0M
  const auto dequant_b = MulAdd(b_cc_mul, dequant_y, dequant_b_cc);
149
23.0M
  Store(dequant_x, d, block + k);
150
23.0M
  Store(dequant_y, d, block + size + k);
151
23.0M
  Store(dequant_b, d, block + 2 * size + k);
152
23.0M
}
void jxl::N_SSE4::DequantLane<(jxl::ACType)0>(hwy::N_SSE4::Vec128<float, 4ul>, hwy::N_SSE4::Vec128<float, 4ul>, hwy::N_SSE4::Vec128<float, 4ul>, float const*, unsigned long, unsigned long, hwy::N_SSE4::Vec128<float, 4ul>, hwy::N_SSE4::Vec128<float, 4ul>, float const*, jxl::ACPtr*, float*)
Line
Count
Source
120
3.40M
                 float* JXL_RESTRICT block) {
121
3.40M
  const auto x_mul = Mul(Load(d, dequant_matrices + k), scaled_dequant_x);
122
3.40M
  const auto y_mul =
123
3.40M
      Mul(Load(d, dequant_matrices + size + k), scaled_dequant_y);
124
3.40M
  const auto b_mul =
125
3.40M
      Mul(Load(d, dequant_matrices + 2 * size + k), scaled_dequant_b);
126
127
3.40M
  Vec<DI> quantized_x_int;
128
3.40M
  Vec<DI> quantized_y_int;
129
3.40M
  Vec<DI> quantized_b_int;
130
3.41M
  if (ac_type == ACType::k16) {
131
3.41M
    quantized_x_int = PromoteTo(di, Load(di16, qblock[0].ptr16 + k));
132
3.41M
    quantized_y_int = PromoteTo(di, Load(di16, qblock[1].ptr16 + k));
133
3.41M
    quantized_b_int = PromoteTo(di, Load(di16, qblock[2].ptr16 + k));
134
18.4E
  } else {
135
18.4E
    quantized_x_int = Load(di, qblock[0].ptr32 + k);
136
18.4E
    quantized_y_int = Load(di, qblock[1].ptr32 + k);
137
18.4E
    quantized_b_int = Load(di, qblock[2].ptr32 + k);
138
18.4E
  }
139
140
3.40M
  const auto dequant_x_cc =
141
3.40M
      Mul(AdjustQuantBias(di, 0, quantized_x_int, biases), x_mul);
142
3.40M
  const auto dequant_y =
143
3.40M
      Mul(AdjustQuantBias(di, 1, quantized_y_int, biases), y_mul);
144
3.40M
  const auto dequant_b_cc =
145
3.40M
      Mul(AdjustQuantBias(di, 2, quantized_b_int, biases), b_mul);
146
147
3.40M
  const auto dequant_x = MulAdd(x_cc_mul, dequant_y, dequant_x_cc);
148
3.40M
  const auto dequant_b = MulAdd(b_cc_mul, dequant_y, dequant_b_cc);
149
3.40M
  Store(dequant_x, d, block + k);
150
3.40M
  Store(dequant_y, d, block + size + k);
151
3.40M
  Store(dequant_b, d, block + 2 * size + k);
152
3.40M
}
void jxl::N_SSE4::DequantLane<(jxl::ACType)1>(hwy::N_SSE4::Vec128<float, 4ul>, hwy::N_SSE4::Vec128<float, 4ul>, hwy::N_SSE4::Vec128<float, 4ul>, float const*, unsigned long, unsigned long, hwy::N_SSE4::Vec128<float, 4ul>, hwy::N_SSE4::Vec128<float, 4ul>, float const*, jxl::ACPtr*, float*)
Line
Count
Source
120
5.43M
                 float* JXL_RESTRICT block) {
121
5.43M
  const auto x_mul = Mul(Load(d, dequant_matrices + k), scaled_dequant_x);
122
5.43M
  const auto y_mul =
123
5.43M
      Mul(Load(d, dequant_matrices + size + k), scaled_dequant_y);
124
5.43M
  const auto b_mul =
125
5.43M
      Mul(Load(d, dequant_matrices + 2 * size + k), scaled_dequant_b);
126
127
5.43M
  Vec<DI> quantized_x_int;
128
5.43M
  Vec<DI> quantized_y_int;
129
5.43M
  Vec<DI> quantized_b_int;
130
5.43M
  if (ac_type == ACType::k16) {
131
0
    quantized_x_int = PromoteTo(di, Load(di16, qblock[0].ptr16 + k));
132
0
    quantized_y_int = PromoteTo(di, Load(di16, qblock[1].ptr16 + k));
133
0
    quantized_b_int = PromoteTo(di, Load(di16, qblock[2].ptr16 + k));
134
5.43M
  } else {
135
5.43M
    quantized_x_int = Load(di, qblock[0].ptr32 + k);
136
5.43M
    quantized_y_int = Load(di, qblock[1].ptr32 + k);
137
5.43M
    quantized_b_int = Load(di, qblock[2].ptr32 + k);
138
5.43M
  }
139
140
5.43M
  const auto dequant_x_cc =
141
5.43M
      Mul(AdjustQuantBias(di, 0, quantized_x_int, biases), x_mul);
142
5.43M
  const auto dequant_y =
143
5.43M
      Mul(AdjustQuantBias(di, 1, quantized_y_int, biases), y_mul);
144
5.43M
  const auto dequant_b_cc =
145
5.43M
      Mul(AdjustQuantBias(di, 2, quantized_b_int, biases), b_mul);
146
147
5.43M
  const auto dequant_x = MulAdd(x_cc_mul, dequant_y, dequant_x_cc);
148
5.43M
  const auto dequant_b = MulAdd(b_cc_mul, dequant_y, dequant_b_cc);
149
5.43M
  Store(dequant_x, d, block + k);
150
5.43M
  Store(dequant_y, d, block + size + k);
151
5.43M
  Store(dequant_b, d, block + 2 * size + k);
152
5.43M
}
void jxl::N_AVX2::DequantLane<(jxl::ACType)0>(hwy::N_AVX2::Vec256<float>, hwy::N_AVX2::Vec256<float>, hwy::N_AVX2::Vec256<float>, float const*, unsigned long, unsigned long, hwy::N_AVX2::Vec256<float>, hwy::N_AVX2::Vec256<float>, float const*, jxl::ACPtr*, float*)
Line
Count
Source
120
5.74M
                 float* JXL_RESTRICT block) {
121
5.74M
  const auto x_mul = Mul(Load(d, dequant_matrices + k), scaled_dequant_x);
122
5.74M
  const auto y_mul =
123
5.74M
      Mul(Load(d, dequant_matrices + size + k), scaled_dequant_y);
124
5.74M
  const auto b_mul =
125
5.74M
      Mul(Load(d, dequant_matrices + 2 * size + k), scaled_dequant_b);
126
127
5.74M
  Vec<DI> quantized_x_int;
128
5.74M
  Vec<DI> quantized_y_int;
129
5.74M
  Vec<DI> quantized_b_int;
130
5.74M
  if (ac_type == ACType::k16) {
131
5.74M
    quantized_x_int = PromoteTo(di, Load(di16, qblock[0].ptr16 + k));
132
5.74M
    quantized_y_int = PromoteTo(di, Load(di16, qblock[1].ptr16 + k));
133
5.74M
    quantized_b_int = PromoteTo(di, Load(di16, qblock[2].ptr16 + k));
134
18.4E
  } else {
135
18.4E
    quantized_x_int = Load(di, qblock[0].ptr32 + k);
136
18.4E
    quantized_y_int = Load(di, qblock[1].ptr32 + k);
137
18.4E
    quantized_b_int = Load(di, qblock[2].ptr32 + k);
138
18.4E
  }
139
140
5.74M
  const auto dequant_x_cc =
141
5.74M
      Mul(AdjustQuantBias(di, 0, quantized_x_int, biases), x_mul);
142
5.74M
  const auto dequant_y =
143
5.74M
      Mul(AdjustQuantBias(di, 1, quantized_y_int, biases), y_mul);
144
5.74M
  const auto dequant_b_cc =
145
5.74M
      Mul(AdjustQuantBias(di, 2, quantized_b_int, biases), b_mul);
146
147
5.74M
  const auto dequant_x = MulAdd(x_cc_mul, dequant_y, dequant_x_cc);
148
5.74M
  const auto dequant_b = MulAdd(b_cc_mul, dequant_y, dequant_b_cc);
149
5.74M
  Store(dequant_x, d, block + k);
150
5.74M
  Store(dequant_y, d, block + size + k);
151
5.74M
  Store(dequant_b, d, block + 2 * size + k);
152
5.74M
}
void jxl::N_AVX2::DequantLane<(jxl::ACType)1>(hwy::N_AVX2::Vec256<float>, hwy::N_AVX2::Vec256<float>, hwy::N_AVX2::Vec256<float>, float const*, unsigned long, unsigned long, hwy::N_AVX2::Vec256<float>, hwy::N_AVX2::Vec256<float>, float const*, jxl::ACPtr*, float*)
Line
Count
Source
120
2.50M
                 float* JXL_RESTRICT block) {
121
2.50M
  const auto x_mul = Mul(Load(d, dequant_matrices + k), scaled_dequant_x);
122
2.50M
  const auto y_mul =
123
2.50M
      Mul(Load(d, dequant_matrices + size + k), scaled_dequant_y);
124
2.50M
  const auto b_mul =
125
2.50M
      Mul(Load(d, dequant_matrices + 2 * size + k), scaled_dequant_b);
126
127
2.50M
  Vec<DI> quantized_x_int;
128
2.50M
  Vec<DI> quantized_y_int;
129
2.50M
  Vec<DI> quantized_b_int;
130
2.50M
  if (ac_type == ACType::k16) {
131
0
    quantized_x_int = PromoteTo(di, Load(di16, qblock[0].ptr16 + k));
132
0
    quantized_y_int = PromoteTo(di, Load(di16, qblock[1].ptr16 + k));
133
0
    quantized_b_int = PromoteTo(di, Load(di16, qblock[2].ptr16 + k));
134
2.50M
  } else {
135
2.50M
    quantized_x_int = Load(di, qblock[0].ptr32 + k);
136
2.50M
    quantized_y_int = Load(di, qblock[1].ptr32 + k);
137
2.50M
    quantized_b_int = Load(di, qblock[2].ptr32 + k);
138
2.50M
  }
139
140
2.50M
  const auto dequant_x_cc =
141
2.50M
      Mul(AdjustQuantBias(di, 0, quantized_x_int, biases), x_mul);
142
2.50M
  const auto dequant_y =
143
2.50M
      Mul(AdjustQuantBias(di, 1, quantized_y_int, biases), y_mul);
144
2.50M
  const auto dequant_b_cc =
145
2.50M
      Mul(AdjustQuantBias(di, 2, quantized_b_int, biases), b_mul);
146
147
2.50M
  const auto dequant_x = MulAdd(x_cc_mul, dequant_y, dequant_x_cc);
148
2.50M
  const auto dequant_b = MulAdd(b_cc_mul, dequant_y, dequant_b_cc);
149
2.50M
  Store(dequant_x, d, block + k);
150
2.50M
  Store(dequant_y, d, block + size + k);
151
2.50M
  Store(dequant_b, d, block + 2 * size + k);
152
2.50M
}
void jxl::N_SSE2::DequantLane<(jxl::ACType)0>(hwy::N_SSE2::Vec128<float, 4ul>, hwy::N_SSE2::Vec128<float, 4ul>, hwy::N_SSE2::Vec128<float, 4ul>, float const*, unsigned long, unsigned long, hwy::N_SSE2::Vec128<float, 4ul>, hwy::N_SSE2::Vec128<float, 4ul>, float const*, jxl::ACPtr*, float*)
Line
Count
Source
120
3.07M
                 float* JXL_RESTRICT block) {
121
3.07M
  const auto x_mul = Mul(Load(d, dequant_matrices + k), scaled_dequant_x);
122
3.07M
  const auto y_mul =
123
3.07M
      Mul(Load(d, dequant_matrices + size + k), scaled_dequant_y);
124
3.07M
  const auto b_mul =
125
3.07M
      Mul(Load(d, dequant_matrices + 2 * size + k), scaled_dequant_b);
126
127
3.07M
  Vec<DI> quantized_x_int;
128
3.07M
  Vec<DI> quantized_y_int;
129
3.07M
  Vec<DI> quantized_b_int;
130
3.07M
  if (ac_type == ACType::k16) {
131
3.07M
    quantized_x_int = PromoteTo(di, Load(di16, qblock[0].ptr16 + k));
132
3.07M
    quantized_y_int = PromoteTo(di, Load(di16, qblock[1].ptr16 + k));
133
3.07M
    quantized_b_int = PromoteTo(di, Load(di16, qblock[2].ptr16 + k));
134
18.4E
  } else {
135
18.4E
    quantized_x_int = Load(di, qblock[0].ptr32 + k);
136
18.4E
    quantized_y_int = Load(di, qblock[1].ptr32 + k);
137
18.4E
    quantized_b_int = Load(di, qblock[2].ptr32 + k);
138
18.4E
  }
139
140
3.07M
  const auto dequant_x_cc =
141
3.07M
      Mul(AdjustQuantBias(di, 0, quantized_x_int, biases), x_mul);
142
3.07M
  const auto dequant_y =
143
3.07M
      Mul(AdjustQuantBias(di, 1, quantized_y_int, biases), y_mul);
144
3.07M
  const auto dequant_b_cc =
145
3.07M
      Mul(AdjustQuantBias(di, 2, quantized_b_int, biases), b_mul);
146
147
3.07M
  const auto dequant_x = MulAdd(x_cc_mul, dequant_y, dequant_x_cc);
148
3.07M
  const auto dequant_b = MulAdd(b_cc_mul, dequant_y, dequant_b_cc);
149
3.07M
  Store(dequant_x, d, block + k);
150
3.07M
  Store(dequant_y, d, block + size + k);
151
3.07M
  Store(dequant_b, d, block + 2 * size + k);
152
3.07M
}
void jxl::N_SSE2::DequantLane<(jxl::ACType)1>(hwy::N_SSE2::Vec128<float, 4ul>, hwy::N_SSE2::Vec128<float, 4ul>, hwy::N_SSE2::Vec128<float, 4ul>, float const*, unsigned long, unsigned long, hwy::N_SSE2::Vec128<float, 4ul>, hwy::N_SSE2::Vec128<float, 4ul>, float const*, jxl::ACPtr*, float*)
Line
Count
Source
120
2.85M
                 float* JXL_RESTRICT block) {
121
2.85M
  const auto x_mul = Mul(Load(d, dequant_matrices + k), scaled_dequant_x);
122
2.85M
  const auto y_mul =
123
2.85M
      Mul(Load(d, dequant_matrices + size + k), scaled_dequant_y);
124
2.85M
  const auto b_mul =
125
2.85M
      Mul(Load(d, dequant_matrices + 2 * size + k), scaled_dequant_b);
126
127
2.85M
  Vec<DI> quantized_x_int;
128
2.85M
  Vec<DI> quantized_y_int;
129
2.85M
  Vec<DI> quantized_b_int;
130
2.85M
  if (ac_type == ACType::k16) {
131
0
    quantized_x_int = PromoteTo(di, Load(di16, qblock[0].ptr16 + k));
132
0
    quantized_y_int = PromoteTo(di, Load(di16, qblock[1].ptr16 + k));
133
0
    quantized_b_int = PromoteTo(di, Load(di16, qblock[2].ptr16 + k));
134
2.85M
  } else {
135
2.85M
    quantized_x_int = Load(di, qblock[0].ptr32 + k);
136
2.85M
    quantized_y_int = Load(di, qblock[1].ptr32 + k);
137
2.85M
    quantized_b_int = Load(di, qblock[2].ptr32 + k);
138
2.85M
  }
139
140
2.85M
  const auto dequant_x_cc =
141
2.85M
      Mul(AdjustQuantBias(di, 0, quantized_x_int, biases), x_mul);
142
2.85M
  const auto dequant_y =
143
2.85M
      Mul(AdjustQuantBias(di, 1, quantized_y_int, biases), y_mul);
144
2.85M
  const auto dequant_b_cc =
145
2.85M
      Mul(AdjustQuantBias(di, 2, quantized_b_int, biases), b_mul);
146
147
2.85M
  const auto dequant_x = MulAdd(x_cc_mul, dequant_y, dequant_x_cc);
148
2.85M
  const auto dequant_b = MulAdd(b_cc_mul, dequant_y, dequant_b_cc);
149
2.85M
  Store(dequant_x, d, block + k);
150
2.85M
  Store(dequant_y, d, block + size + k);
151
2.85M
  Store(dequant_b, d, block + 2 * size + k);
152
2.85M
}
153
154
template <ACType ac_type>
155
void DequantBlock(float inv_global_scale, int quant, float x_dm_multiplier,
156
                  float b_dm_multiplier, Vec<D> x_cc_mul, Vec<D> b_cc_mul,
157
                  AcStrategyType kind, size_t size, const Quantizer& quantizer,
158
                  size_t covered_blocks, const size_t* sbx,
159
                  const float* JXL_RESTRICT* JXL_RESTRICT dc_row,
160
                  size_t dc_stride, const float* JXL_RESTRICT biases,
161
                  ACPtr qblock[3], float* JXL_RESTRICT block,
162
1.51M
                  float* JXL_RESTRICT scratch) {
163
1.51M
  const auto scaled_dequant_s = inv_global_scale / quant;
164
165
1.51M
  const auto scaled_dequant_x = Set(d, scaled_dequant_s * x_dm_multiplier);
166
1.51M
  const auto scaled_dequant_y = Set(d, scaled_dequant_s);
167
1.51M
  const auto scaled_dequant_b = Set(d, scaled_dequant_s * b_dm_multiplier);
168
169
1.51M
  const float* dequant_matrices = quantizer.DequantMatrix(kind, 0);
170
171
24.5M
  for (size_t k = 0; k < covered_blocks * kDCTBlockSize; k += Lanes(d)) {
172
23.0M
    DequantLane<ac_type>(scaled_dequant_x, scaled_dequant_y, scaled_dequant_b,
173
23.0M
                         dequant_matrices, size, k, x_cc_mul, b_cc_mul, biases,
174
23.0M
                         qblock, block);
175
23.0M
  }
176
6.06M
  for (size_t c = 0; c < 3; c++) {
177
4.54M
    LowestFrequenciesFromDC(kind, dc_row[c] + sbx[c], dc_stride,
178
4.54M
                            block + c * size, scratch);
179
4.54M
  }
180
1.51M
}
void jxl::N_SSE4::DequantBlock<(jxl::ACType)0>(float, int, float, float, hwy::N_SSE4::Vec128<float, 4ul>, hwy::N_SSE4::Vec128<float, 4ul>, jxl::AcStrategyType, unsigned long, jxl::Quantizer const&, unsigned long, unsigned long const*, float const* restrict*, unsigned long, float const*, jxl::ACPtr*, float*, float*)
Line
Count
Source
162
119k
                  float* JXL_RESTRICT scratch) {
163
119k
  const auto scaled_dequant_s = inv_global_scale / quant;
164
165
119k
  const auto scaled_dequant_x = Set(d, scaled_dequant_s * x_dm_multiplier);
166
119k
  const auto scaled_dequant_y = Set(d, scaled_dequant_s);
167
119k
  const auto scaled_dequant_b = Set(d, scaled_dequant_s * b_dm_multiplier);
168
169
119k
  const float* dequant_matrices = quantizer.DequantMatrix(kind, 0);
170
171
3.52M
  for (size_t k = 0; k < covered_blocks * kDCTBlockSize; k += Lanes(d)) {
172
3.40M
    DequantLane<ac_type>(scaled_dequant_x, scaled_dequant_y, scaled_dequant_b,
173
3.40M
                         dequant_matrices, size, k, x_cc_mul, b_cc_mul, biases,
174
3.40M
                         qblock, block);
175
3.40M
  }
176
478k
  for (size_t c = 0; c < 3; c++) {
177
358k
    LowestFrequenciesFromDC(kind, dc_row[c] + sbx[c], dc_stride,
178
358k
                            block + c * size, scratch);
179
358k
  }
180
119k
}
void jxl::N_SSE4::DequantBlock<(jxl::ACType)1>(float, int, float, float, hwy::N_SSE4::Vec128<float, 4ul>, hwy::N_SSE4::Vec128<float, 4ul>, jxl::AcStrategyType, unsigned long, jxl::Quantizer const&, unsigned long, unsigned long const*, float const* restrict*, unsigned long, float const*, jxl::ACPtr*, float*, float*)
Line
Count
Source
162
313k
                  float* JXL_RESTRICT scratch) {
163
313k
  const auto scaled_dequant_s = inv_global_scale / quant;
164
165
313k
  const auto scaled_dequant_x = Set(d, scaled_dequant_s * x_dm_multiplier);
166
313k
  const auto scaled_dequant_y = Set(d, scaled_dequant_s);
167
313k
  const auto scaled_dequant_b = Set(d, scaled_dequant_s * b_dm_multiplier);
168
169
313k
  const float* dequant_matrices = quantizer.DequantMatrix(kind, 0);
170
171
5.75M
  for (size_t k = 0; k < covered_blocks * kDCTBlockSize; k += Lanes(d)) {
172
5.43M
    DequantLane<ac_type>(scaled_dequant_x, scaled_dequant_y, scaled_dequant_b,
173
5.43M
                         dequant_matrices, size, k, x_cc_mul, b_cc_mul, biases,
174
5.43M
                         qblock, block);
175
5.43M
  }
176
1.25M
  for (size_t c = 0; c < 3; c++) {
177
939k
    LowestFrequenciesFromDC(kind, dc_row[c] + sbx[c], dc_stride,
178
939k
                            block + c * size, scratch);
179
939k
  }
180
313k
}
void jxl::N_AVX2::DequantBlock<(jxl::ACType)0>(float, int, float, float, hwy::N_AVX2::Vec256<float>, hwy::N_AVX2::Vec256<float>, jxl::AcStrategyType, unsigned long, jxl::Quantizer const&, unsigned long, unsigned long const*, float const* restrict*, unsigned long, float const*, jxl::ACPtr*, float*, float*)
Line
Count
Source
162
566k
                  float* JXL_RESTRICT scratch) {
163
566k
  const auto scaled_dequant_s = inv_global_scale / quant;
164
165
566k
  const auto scaled_dequant_x = Set(d, scaled_dequant_s * x_dm_multiplier);
166
566k
  const auto scaled_dequant_y = Set(d, scaled_dequant_s);
167
566k
  const auto scaled_dequant_b = Set(d, scaled_dequant_s * b_dm_multiplier);
168
169
566k
  const float* dequant_matrices = quantizer.DequantMatrix(kind, 0);
170
171
6.30M
  for (size_t k = 0; k < covered_blocks * kDCTBlockSize; k += Lanes(d)) {
172
5.74M
    DequantLane<ac_type>(scaled_dequant_x, scaled_dequant_y, scaled_dequant_b,
173
5.74M
                         dequant_matrices, size, k, x_cc_mul, b_cc_mul, biases,
174
5.74M
                         qblock, block);
175
5.74M
  }
176
2.26M
  for (size_t c = 0; c < 3; c++) {
177
1.69M
    LowestFrequenciesFromDC(kind, dc_row[c] + sbx[c], dc_stride,
178
1.69M
                            block + c * size, scratch);
179
1.69M
  }
180
566k
}
void jxl::N_AVX2::DequantBlock<(jxl::ACType)1>(float, int, float, float, hwy::N_AVX2::Vec256<float>, hwy::N_AVX2::Vec256<float>, jxl::AcStrategyType, unsigned long, jxl::Quantizer const&, unsigned long, unsigned long const*, float const* restrict*, unsigned long, float const*, jxl::ACPtr*, float*, float*)
Line
Count
Source
162
252k
                  float* JXL_RESTRICT scratch) {
163
252k
  const auto scaled_dequant_s = inv_global_scale / quant;
164
165
252k
  const auto scaled_dequant_x = Set(d, scaled_dequant_s * x_dm_multiplier);
166
252k
  const auto scaled_dequant_y = Set(d, scaled_dequant_s);
167
252k
  const auto scaled_dequant_b = Set(d, scaled_dequant_s * b_dm_multiplier);
168
169
252k
  const float* dequant_matrices = quantizer.DequantMatrix(kind, 0);
170
171
2.75M
  for (size_t k = 0; k < covered_blocks * kDCTBlockSize; k += Lanes(d)) {
172
2.50M
    DequantLane<ac_type>(scaled_dequant_x, scaled_dequant_y, scaled_dequant_b,
173
2.50M
                         dequant_matrices, size, k, x_cc_mul, b_cc_mul, biases,
174
2.50M
                         qblock, block);
175
2.50M
  }
176
1.00M
  for (size_t c = 0; c < 3; c++) {
177
756k
    LowestFrequenciesFromDC(kind, dc_row[c] + sbx[c], dc_stride,
178
756k
                            block + c * size, scratch);
179
756k
  }
180
252k
}
void jxl::N_SSE2::DequantBlock<(jxl::ACType)0>(float, int, float, float, hwy::N_SSE2::Vec128<float, 4ul>, hwy::N_SSE2::Vec128<float, 4ul>, jxl::AcStrategyType, unsigned long, jxl::Quantizer const&, unsigned long, unsigned long const*, float const* restrict*, unsigned long, float const*, jxl::ACPtr*, float*, float*)
Line
Count
Source
162
109k
                  float* JXL_RESTRICT scratch) {
163
109k
  const auto scaled_dequant_s = inv_global_scale / quant;
164
165
109k
  const auto scaled_dequant_x = Set(d, scaled_dequant_s * x_dm_multiplier);
166
109k
  const auto scaled_dequant_y = Set(d, scaled_dequant_s);
167
109k
  const auto scaled_dequant_b = Set(d, scaled_dequant_s * b_dm_multiplier);
168
169
109k
  const float* dequant_matrices = quantizer.DequantMatrix(kind, 0);
170
171
3.18M
  for (size_t k = 0; k < covered_blocks * kDCTBlockSize; k += Lanes(d)) {
172
3.07M
    DequantLane<ac_type>(scaled_dequant_x, scaled_dequant_y, scaled_dequant_b,
173
3.07M
                         dequant_matrices, size, k, x_cc_mul, b_cc_mul, biases,
174
3.07M
                         qblock, block);
175
3.07M
  }
176
436k
  for (size_t c = 0; c < 3; c++) {
177
327k
    LowestFrequenciesFromDC(kind, dc_row[c] + sbx[c], dc_stride,
178
327k
                            block + c * size, scratch);
179
327k
  }
180
109k
}
void jxl::N_SSE2::DequantBlock<(jxl::ACType)1>(float, int, float, float, hwy::N_SSE2::Vec128<float, 4ul>, hwy::N_SSE2::Vec128<float, 4ul>, jxl::AcStrategyType, unsigned long, jxl::Quantizer const&, unsigned long, unsigned long const*, float const* restrict*, unsigned long, float const*, jxl::ACPtr*, float*, float*)
Line
Count
Source
162
156k
                  float* JXL_RESTRICT scratch) {
163
156k
  const auto scaled_dequant_s = inv_global_scale / quant;
164
165
156k
  const auto scaled_dequant_x = Set(d, scaled_dequant_s * x_dm_multiplier);
166
156k
  const auto scaled_dequant_y = Set(d, scaled_dequant_s);
167
156k
  const auto scaled_dequant_b = Set(d, scaled_dequant_s * b_dm_multiplier);
168
169
156k
  const float* dequant_matrices = quantizer.DequantMatrix(kind, 0);
170
171
3.01M
  for (size_t k = 0; k < covered_blocks * kDCTBlockSize; k += Lanes(d)) {
172
2.85M
    DequantLane<ac_type>(scaled_dequant_x, scaled_dequant_y, scaled_dequant_b,
173
2.85M
                         dequant_matrices, size, k, x_cc_mul, b_cc_mul, biases,
174
2.85M
                         qblock, block);
175
2.85M
  }
176
624k
  for (size_t c = 0; c < 3; c++) {
177
468k
    LowestFrequenciesFromDC(kind, dc_row[c] + sbx[c], dc_stride,
178
468k
                            block + c * size, scratch);
179
468k
  }
180
156k
}
181
182
Status DecodeGroupImpl(const FrameHeader& frame_header,
183
                       GetBlock* JXL_RESTRICT get_block,
184
                       GroupDecCache* JXL_RESTRICT group_dec_cache,
185
                       PassesDecoderState* JXL_RESTRICT dec_state,
186
                       size_t thread, size_t group_idx,
187
                       RenderPipelineInput& render_pipeline_input,
188
25.5k
                       jpeg::JPEGData* jpeg_data, DrawMode draw) {
189
  // TODO(veluca): investigate cache usage in this function.
190
25.5k
  const Rect block_rect =
191
25.5k
      dec_state->shared->frame_dim.BlockGroupRect(group_idx);
192
25.5k
  const AcStrategyImage& ac_strategy = dec_state->shared->ac_strategy;
193
194
25.5k
  const size_t xsize_blocks = block_rect.xsize();
195
25.5k
  const size_t ysize_blocks = block_rect.ysize();
196
197
25.5k
  const size_t dc_stride = dec_state->shared->dc->PixelsPerRow();
198
199
25.5k
  const float inv_global_scale = dec_state->shared->quantizer.InvGlobalScale();
200
201
25.5k
  const YCbCrChromaSubsampling& cs = frame_header.chroma_subsampling;
202
203
25.5k
  const auto kJpegDctMin = Set(di16_full, -4095);
204
25.5k
  const auto kJpegDctMax = Set(di16_full, 4095);
205
206
25.5k
  size_t idct_stride[3];
207
102k
  for (size_t c = 0; c < 3; c++) {
208
76.5k
    idct_stride[c] = render_pipeline_input.GetBuffer(c).first->PixelsPerRow();
209
76.5k
  }
210
211
25.5k
  HWY_ALIGN int32_t scaled_qtable[64 * 3];
212
213
25.5k
  ACType ac_type = dec_state->coefficients->Type();
214
25.5k
  auto dequant_block = ac_type == ACType::k16 ? DequantBlock<ACType::k16>
215
25.5k
                                              : DequantBlock<ACType::k32>;
216
  // Whether or not coefficients should be stored for future usage, and/or read
217
  // from past usage.
218
25.5k
  bool accumulate = !dec_state->coefficients->IsEmpty();
219
  // Offset of the current block in the group.
220
25.5k
  size_t offset = 0;
221
222
25.5k
  std::array<int, 3> jpeg_c_map;
223
25.5k
  bool jpeg_is_gray = false;
224
25.5k
  std::array<int, 3> dcoff = {};
225
226
  // TODO(veluca): all of this should be done only once per image.
227
25.5k
  const ColorCorrelation& color_correlation = dec_state->shared->cmap.base();
228
25.5k
  if (jpeg_data) {
229
421
    if (!color_correlation.IsJPEGCompatible()) {
230
22
      return JXL_FAILURE("The CfL map is not JPEG-compatible");
231
22
    }
232
399
    jpeg_is_gray = (jpeg_data->components.size() == 1);
233
399
    JXL_ENSURE(frame_header.color_transform != ColorTransform::kXYB);
234
399
    jpeg_c_map = JpegOrder(frame_header.color_transform, jpeg_is_gray);
235
399
    const std::vector<QuantEncoding>& qe =
236
399
        dec_state->shared->matrices.encodings();
237
399
    if (qe.empty() || qe[0].mode != QuantEncoding::Mode::kQuantModeRAW ||
238
399
        std::abs(qe[0].qraw.qtable_den - 1.f / (8 * 255)) > 1e-8f) {
239
0
      return JXL_FAILURE(
240
0
          "Quantization table is not a JPEG quantization table.");
241
0
    }
242
399
    JXL_ENSURE(qe[0].qraw.qtable->size() == 3 * 8 * 8);
243
399
    int* qtable = qe[0].qraw.qtable->data();
244
1.55k
    for (size_t c = 0; c < 3; c++) {
245
1.17k
      if (frame_header.color_transform == ColorTransform::kNone) {
246
51
        dcoff[c] = 1024 / qtable[64 * c];
247
51
      }
248
75.3k
      for (size_t i = 0; i < 64; i++) {
249
        // Transpose the matrix, as it will be used on the transposed block.
250
74.1k
        int num = qtable[64 + i];
251
74.1k
        int den = qtable[64 * c + i];
252
74.1k
        if (num <= 0 || den <= 0 || num >= 65536 || den >= 65536) {
253
20
          return JXL_FAILURE("Invalid JPEG quantization table");
254
20
        }
255
74.1k
        scaled_qtable[64 * c + (i % 8) * 8 + (i / 8)] =
256
74.1k
            (1 << kCFLFixedPointPrecision) * num / den;
257
74.1k
      }
258
1.17k
    }
259
399
  }
260
261
25.4k
  size_t hshift[3] = {cs.HShift(0), cs.HShift(1), cs.HShift(2)};
262
25.4k
  size_t vshift[3] = {cs.VShift(0), cs.VShift(1), cs.VShift(2)};
263
25.4k
  Rect r[3];
264
101k
  for (size_t i = 0; i < 3; i++) {
265
76.4k
    r[i] =
266
76.4k
        Rect(block_rect.x0() >> hshift[i], block_rect.y0() >> vshift[i],
267
76.4k
             block_rect.xsize() >> hshift[i], block_rect.ysize() >> vshift[i]);
268
76.4k
    if (!r[i].IsInside({0, 0, dec_state->shared->dc->Plane(i).xsize(),
269
76.4k
                        dec_state->shared->dc->Plane(i).ysize()})) {
270
0
      return JXL_FAILURE("Frame dimensions are too big for the image.");
271
0
    }
272
76.4k
  }
273
274
230k
  for (size_t by = 0; by < ysize_blocks; ++by) {
275
206k
    get_block->StartRow(by);
276
206k
    size_t sby[3] = {by >> vshift[0], by >> vshift[1], by >> vshift[2]};
277
278
206k
    const int32_t* JXL_RESTRICT row_quant =
279
206k
        block_rect.ConstRow(dec_state->shared->raw_quant_field, by);
280
281
206k
    const float* JXL_RESTRICT dc_rows[3] = {
282
206k
        r[0].ConstPlaneRow(*dec_state->shared->dc, 0, sby[0]),
283
206k
        r[1].ConstPlaneRow(*dec_state->shared->dc, 1, sby[1]),
284
206k
        r[2].ConstPlaneRow(*dec_state->shared->dc, 2, sby[2]),
285
206k
    };
286
287
206k
    const size_t ty = (block_rect.y0() + by) / kColorTileDimInBlocks;
288
206k
    AcStrategyRow acs_row = ac_strategy.ConstRow(block_rect, by);
289
290
206k
    const int8_t* JXL_RESTRICT row_cmap[3] = {
291
206k
        dec_state->shared->cmap.ytox_map.ConstRow(ty),
292
206k
        nullptr,
293
206k
        dec_state->shared->cmap.ytob_map.ConstRow(ty),
294
206k
    };
295
296
206k
    float* JXL_RESTRICT idct_row[3];
297
206k
    int16_t* JXL_RESTRICT jpeg_row[3];
298
824k
    for (size_t c = 0; c < 3; c++) {
299
617k
      const auto& buffer = render_pipeline_input.GetBuffer(c);
300
617k
      idct_row[c] = buffer.second.Row(buffer.first, sby[c] * kBlockDim);
301
617k
      if (jpeg_data) {
302
1.13k
        auto& component = jpeg_data->components[jpeg_c_map[c]];
303
1.13k
        jpeg_row[c] =
304
1.13k
            component.coeffs.data() +
305
1.13k
            (component.width_in_blocks * (r[c].y0() + sby[c]) + r[c].x0()) *
306
1.13k
                kDCTBlockSize;
307
1.13k
      }
308
617k
    }
309
310
206k
    size_t bx = 0;
311
560k
    for (size_t tx = 0; tx < DivCeil(xsize_blocks, kColorTileDimInBlocks);
312
355k
         tx++) {
313
355k
      size_t abs_tx = tx + block_rect.x0() / kColorTileDimInBlocks;
314
355k
      auto x_cc_mul = Set(d, color_correlation.YtoXRatio(row_cmap[0][abs_tx]));
315
355k
      auto b_cc_mul = Set(d, color_correlation.YtoBRatio(row_cmap[2][abs_tx]));
316
      // Increment bx by llf_x because those iterations would otherwise
317
      // immediately continue (!IsFirstBlock). Reduces mispredictions.
318
1.99M
      for (; bx < xsize_blocks && bx < (tx + 1) * kColorTileDimInBlocks;) {
319
1.64M
        size_t sbx[3] = {bx >> hshift[0], bx >> hshift[1], bx >> hshift[2]};
320
1.64M
        AcStrategy acs = acs_row[bx];
321
1.64M
        const size_t llf_x = acs.covered_blocks_x();
322
323
        // Can only happen in the second or lower rows of a varblock.
324
1.64M
        if (JXL_UNLIKELY(!acs.IsFirstBlock())) {
325
124k
          bx += llf_x;
326
124k
          continue;
327
124k
        }
328
1.51M
        const size_t log2_covered_blocks = acs.log2_covered_blocks();
329
330
1.51M
        const size_t covered_blocks = 1 << log2_covered_blocks;
331
1.51M
        const size_t size = covered_blocks * kDCTBlockSize;
332
333
1.51M
        ACPtr qblock[3];
334
1.51M
        if (accumulate) {
335
5.40k
          for (size_t c = 0; c < 3; c++) {
336
4.05k
            qblock[c] = dec_state->coefficients->PlaneRow(c, group_idx, offset);
337
4.05k
          }
338
1.51M
        } else {
339
          // No point in reading from bitstream without accumulating and not
340
          // drawing.
341
1.51M
          JXL_ENSURE(draw == kDraw);
342
1.51M
          if (ac_type == ACType::k16) {
343
795k
            memset(group_dec_cache->dec_group_qblock16, 0,
344
795k
                   size * 3 * sizeof(int16_t));
345
3.18M
            for (size_t c = 0; c < 3; c++) {
346
2.38M
              qblock[c].ptr16 = group_dec_cache->dec_group_qblock16 + c * size;
347
2.38M
            }
348
795k
          } else {
349
721k
            memset(group_dec_cache->dec_group_qblock, 0,
350
721k
                   size * 3 * sizeof(int32_t));
351
2.88M
            for (size_t c = 0; c < 3; c++) {
352
2.16M
              qblock[c].ptr32 = group_dec_cache->dec_group_qblock + c * size;
353
2.16M
            }
354
721k
          }
355
1.51M
        }
356
1.51M
        JXL_RETURN_IF_ERROR(get_block->LoadBlock(
357
1.51M
            bx, by, acs, size, log2_covered_blocks, qblock, ac_type));
358
1.51M
        offset += size;
359
1.51M
        if (draw == kDontDraw) {
360
1.14k
          bx += llf_x;
361
1.14k
          continue;
362
1.14k
        }
363
364
1.51M
        if (JXL_UNLIKELY(jpeg_data)) {
365
379
          if (acs.Strategy() != AcStrategyType::DCT) {
366
3
            return JXL_FAILURE(
367
3
                "Can only decode to JPEG if only DCT-8 is used.");
368
3
          }
369
370
376
          HWY_ALIGN int32_t transposed_dct_y[64];
371
1.12k
          for (size_t c : {1, 0, 2}) {
372
            // Propagate only Y for grayscale.
373
1.12k
            if (jpeg_is_gray && c != 1) {
374
746
              continue;
375
746
            }
376
376
            if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) {
377
0
              continue;
378
0
            }
379
376
            int16_t* JXL_RESTRICT jpeg_pos =
380
376
                jpeg_row[c] + sbx[c] * kDCTBlockSize;
381
            // JPEG XL is transposed, JPEG is not.
382
376
            auto* transposed_dct = qblock[c].ptr32;
383
376
            Transpose8x8InPlace(transposed_dct);
384
            // No CfL - no need to store the y block converted to integers.
385
376
            if (!cs.Is444() ||
386
376
                (row_cmap[0][abs_tx] == 0 && row_cmap[2][abs_tx] == 0)) {
387
3.61k
              for (size_t i = 0; i < 64; i += Lanes(d)) {
388
3.29k
                const auto ini = Load(di, transposed_dct + i);
389
3.29k
                const auto ini16 = DemoteTo(di16, ini);
390
3.29k
                StoreU(ini16, di16, jpeg_pos + i);
391
3.29k
              }
392
314
            } else if (c == 1) {
393
              // Y channel: save for restoring X/B, but nothing else to do.
394
678
              for (size_t i = 0; i < 64; i += Lanes(d)) {
395
616
                const auto ini = Load(di, transposed_dct + i);
396
616
                Store(ini, di, transposed_dct_y + i);
397
616
                const auto ini16 = DemoteTo(di16, ini);
398
616
                StoreU(ini16, di16, jpeg_pos + i);
399
616
              }
400
62
            } else {
401
              // transposed_dct_y contains the y channel block, transposed.
402
0
              const auto scale =
403
0
                  Set(di, ColorCorrelation::RatioJPEG(row_cmap[c][abs_tx]));
404
0
              const auto round = Set(di, 1 << (kCFLFixedPointPrecision - 1));
405
0
              for (int i = 0; i < 64; i += Lanes(d)) {
406
0
                auto in = Load(di, transposed_dct + i);
407
0
                auto in_y = Load(di, transposed_dct_y + i);
408
0
                auto qt = Load(di, scaled_qtable + c * size + i);
409
0
                auto coeff_scale = ShiftRight<kCFLFixedPointPrecision>(
410
0
                    Add(Mul(qt, scale), round));
411
0
                auto cfl_factor = ShiftRight<kCFLFixedPointPrecision>(
412
0
                    Add(Mul(in_y, coeff_scale), round));
413
0
                StoreU(DemoteTo(di16, Add(in, cfl_factor)), di16, jpeg_pos + i);
414
0
              }
415
0
            }
416
376
            jpeg_pos[0] =
417
376
                Clamp1<float>(dc_rows[c][sbx[c]] - dcoff[c], -2047, 2047);
418
376
            auto overflow = MaskFromVec(Set(di16_full, 0));
419
376
            auto underflow = MaskFromVec(Set(di16_full, 0));
420
2.33k
            for (int i = 0; i < 64; i += Lanes(di16_full)) {
421
1.95k
              auto in = LoadU(di16_full, jpeg_pos + i);
422
1.95k
              overflow = Or(overflow, Gt(in, kJpegDctMax));
423
1.95k
              underflow = Or(underflow, Lt(in, kJpegDctMin));
424
1.95k
            }
425
376
            if (!AllFalse(di16_full, Or(overflow, underflow))) {
426
3
              return JXL_FAILURE("JPEG DCT coefficients out of range");
427
3
            }
428
376
          }
429
1.51M
        } else {
430
1.51M
          HWY_ALIGN float* const block = group_dec_cache->dec_group_block;
431
          // Dequantize and add predictions.
432
1.51M
          dequant_block(
433
1.51M
              inv_global_scale, row_quant[bx], dec_state->x_dm_multiplier,
434
1.51M
              dec_state->b_dm_multiplier, x_cc_mul, b_cc_mul, acs.Strategy(),
435
1.51M
              size, dec_state->shared->quantizer,
436
1.51M
              acs.covered_blocks_y() * acs.covered_blocks_x(), sbx, dc_rows,
437
1.51M
              dc_stride,
438
1.51M
              dec_state->output_encoding_info.opsin_params.quant_biases, qblock,
439
1.51M
              block, group_dec_cache->scratch_space);
440
441
4.54M
          for (size_t c : {1, 0, 2}) {
442
4.54M
            if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) {
443
1.10M
              continue;
444
1.10M
            }
445
            // IDCT
446
3.44M
            float* JXL_RESTRICT idct_pos = idct_row[c] + sbx[c] * kBlockDim;
447
3.44M
            TransformToPixels(acs.Strategy(), block + c * size, idct_pos,
448
3.44M
                              idct_stride[c], group_dec_cache->scratch_space);
449
3.44M
          }
450
1.51M
        }
451
1.51M
        bx += llf_x;
452
1.51M
      }
453
355k
    }
454
206k
  }
455
24.6k
  return true;
456
25.4k
}
jxl::N_SSE4::DecodeGroupImpl(jxl::FrameHeader const&, jxl::GetBlock*, jxl::GroupDecCache*, jxl::PassesDecoderState*, unsigned long, unsigned long, jxl::RenderPipelineInput&, jxl::jpeg::JPEGData*, jxl::DrawMode)
Line
Count
Source
188
5.05k
                       jpeg::JPEGData* jpeg_data, DrawMode draw) {
189
  // TODO(veluca): investigate cache usage in this function.
190
5.05k
  const Rect block_rect =
191
5.05k
      dec_state->shared->frame_dim.BlockGroupRect(group_idx);
192
5.05k
  const AcStrategyImage& ac_strategy = dec_state->shared->ac_strategy;
193
194
5.05k
  const size_t xsize_blocks = block_rect.xsize();
195
5.05k
  const size_t ysize_blocks = block_rect.ysize();
196
197
5.05k
  const size_t dc_stride = dec_state->shared->dc->PixelsPerRow();
198
199
5.05k
  const float inv_global_scale = dec_state->shared->quantizer.InvGlobalScale();
200
201
5.05k
  const YCbCrChromaSubsampling& cs = frame_header.chroma_subsampling;
202
203
5.05k
  const auto kJpegDctMin = Set(di16_full, -4095);
204
5.05k
  const auto kJpegDctMax = Set(di16_full, 4095);
205
206
5.05k
  size_t idct_stride[3];
207
20.2k
  for (size_t c = 0; c < 3; c++) {
208
15.1k
    idct_stride[c] = render_pipeline_input.GetBuffer(c).first->PixelsPerRow();
209
15.1k
  }
210
211
5.05k
  HWY_ALIGN int32_t scaled_qtable[64 * 3];
212
213
5.05k
  ACType ac_type = dec_state->coefficients->Type();
214
5.05k
  auto dequant_block = ac_type == ACType::k16 ? DequantBlock<ACType::k16>
215
5.05k
                                              : DequantBlock<ACType::k32>;
216
  // Whether or not coefficients should be stored for future usage, and/or read
217
  // from past usage.
218
5.05k
  bool accumulate = !dec_state->coefficients->IsEmpty();
219
  // Offset of the current block in the group.
220
5.05k
  size_t offset = 0;
221
222
5.05k
  std::array<int, 3> jpeg_c_map;
223
5.05k
  bool jpeg_is_gray = false;
224
5.05k
  std::array<int, 3> dcoff = {};
225
226
  // TODO(veluca): all of this should be done only once per image.
227
5.05k
  const ColorCorrelation& color_correlation = dec_state->shared->cmap.base();
228
5.05k
  if (jpeg_data) {
229
108
    if (!color_correlation.IsJPEGCompatible()) {
230
6
      return JXL_FAILURE("The CfL map is not JPEG-compatible");
231
6
    }
232
102
    jpeg_is_gray = (jpeg_data->components.size() == 1);
233
102
    JXL_ENSURE(frame_header.color_transform != ColorTransform::kXYB);
234
102
    jpeg_c_map = JpegOrder(frame_header.color_transform, jpeg_is_gray);
235
102
    const std::vector<QuantEncoding>& qe =
236
102
        dec_state->shared->matrices.encodings();
237
102
    if (qe.empty() || qe[0].mode != QuantEncoding::Mode::kQuantModeRAW ||
238
102
        std::abs(qe[0].qraw.qtable_den - 1.f / (8 * 255)) > 1e-8f) {
239
0
      return JXL_FAILURE(
240
0
          "Quantization table is not a JPEG quantization table.");
241
0
    }
242
102
    JXL_ENSURE(qe[0].qraw.qtable->size() == 3 * 8 * 8);
243
102
    int* qtable = qe[0].qraw.qtable->data();
244
396
    for (size_t c = 0; c < 3; c++) {
245
300
      if (frame_header.color_transform == ColorTransform::kNone) {
246
21
        dcoff[c] = 1024 / qtable[64 * c];
247
21
      }
248
19.1k
      for (size_t i = 0; i < 64; i++) {
249
        // Transpose the matrix, as it will be used on the transposed block.
250
18.8k
        int num = qtable[64 + i];
251
18.8k
        int den = qtable[64 * c + i];
252
18.8k
        if (num <= 0 || den <= 0 || num >= 65536 || den >= 65536) {
253
6
          return JXL_FAILURE("Invalid JPEG quantization table");
254
6
        }
255
18.8k
        scaled_qtable[64 * c + (i % 8) * 8 + (i / 8)] =
256
18.8k
            (1 << kCFLFixedPointPrecision) * num / den;
257
18.8k
      }
258
300
    }
259
102
  }
260
261
5.04k
  size_t hshift[3] = {cs.HShift(0), cs.HShift(1), cs.HShift(2)};
262
5.04k
  size_t vshift[3] = {cs.VShift(0), cs.VShift(1), cs.VShift(2)};
263
5.04k
  Rect r[3];
264
20.1k
  for (size_t i = 0; i < 3; i++) {
265
15.1k
    r[i] =
266
15.1k
        Rect(block_rect.x0() >> hshift[i], block_rect.y0() >> vshift[i],
267
15.1k
             block_rect.xsize() >> hshift[i], block_rect.ysize() >> vshift[i]);
268
15.1k
    if (!r[i].IsInside({0, 0, dec_state->shared->dc->Plane(i).xsize(),
269
15.1k
                        dec_state->shared->dc->Plane(i).ysize()})) {
270
0
      return JXL_FAILURE("Frame dimensions are too big for the image.");
271
0
    }
272
15.1k
  }
273
274
68.2k
  for (size_t by = 0; by < ysize_blocks; ++by) {
275
63.3k
    get_block->StartRow(by);
276
63.3k
    size_t sby[3] = {by >> vshift[0], by >> vshift[1], by >> vshift[2]};
277
278
63.3k
    const int32_t* JXL_RESTRICT row_quant =
279
63.3k
        block_rect.ConstRow(dec_state->shared->raw_quant_field, by);
280
281
63.3k
    const float* JXL_RESTRICT dc_rows[3] = {
282
63.3k
        r[0].ConstPlaneRow(*dec_state->shared->dc, 0, sby[0]),
283
63.3k
        r[1].ConstPlaneRow(*dec_state->shared->dc, 1, sby[1]),
284
63.3k
        r[2].ConstPlaneRow(*dec_state->shared->dc, 2, sby[2]),
285
63.3k
    };
286
287
63.3k
    const size_t ty = (block_rect.y0() + by) / kColorTileDimInBlocks;
288
63.3k
    AcStrategyRow acs_row = ac_strategy.ConstRow(block_rect, by);
289
290
63.3k
    const int8_t* JXL_RESTRICT row_cmap[3] = {
291
63.3k
        dec_state->shared->cmap.ytox_map.ConstRow(ty),
292
63.3k
        nullptr,
293
63.3k
        dec_state->shared->cmap.ytob_map.ConstRow(ty),
294
63.3k
    };
295
296
63.3k
    float* JXL_RESTRICT idct_row[3];
297
63.3k
    int16_t* JXL_RESTRICT jpeg_row[3];
298
253k
    for (size_t c = 0; c < 3; c++) {
299
190k
      const auto& buffer = render_pipeline_input.GetBuffer(c);
300
190k
      idct_row[c] = buffer.second.Row(buffer.first, sby[c] * kBlockDim);
301
190k
      if (jpeg_data) {
302
288
        auto& component = jpeg_data->components[jpeg_c_map[c]];
303
288
        jpeg_row[c] =
304
288
            component.coeffs.data() +
305
288
            (component.width_in_blocks * (r[c].y0() + sby[c]) + r[c].x0()) *
306
288
                kDCTBlockSize;
307
288
      }
308
190k
    }
309
310
63.3k
    size_t bx = 0;
311
167k
    for (size_t tx = 0; tx < DivCeil(xsize_blocks, kColorTileDimInBlocks);
312
104k
         tx++) {
313
104k
      size_t abs_tx = tx + block_rect.x0() / kColorTileDimInBlocks;
314
104k
      auto x_cc_mul = Set(d, color_correlation.YtoXRatio(row_cmap[0][abs_tx]));
315
104k
      auto b_cc_mul = Set(d, color_correlation.YtoBRatio(row_cmap[2][abs_tx]));
316
      // Increment bx by llf_x because those iterations would otherwise
317
      // immediately continue (!IsFirstBlock). Reduces mispredictions.
318
580k
      for (; bx < xsize_blocks && bx < (tx + 1) * kColorTileDimInBlocks;) {
319
476k
        size_t sbx[3] = {bx >> hshift[0], bx >> hshift[1], bx >> hshift[2]};
320
476k
        AcStrategy acs = acs_row[bx];
321
476k
        const size_t llf_x = acs.covered_blocks_x();
322
323
        // Can only happen in the second or lower rows of a varblock.
324
476k
        if (JXL_UNLIKELY(!acs.IsFirstBlock())) {
325
43.1k
          bx += llf_x;
326
43.1k
          continue;
327
43.1k
        }
328
433k
        const size_t log2_covered_blocks = acs.log2_covered_blocks();
329
330
433k
        const size_t covered_blocks = 1 << log2_covered_blocks;
331
433k
        const size_t size = covered_blocks * kDCTBlockSize;
332
333
433k
        ACPtr qblock[3];
334
433k
        if (accumulate) {
335
112
          for (size_t c = 0; c < 3; c++) {
336
84
            qblock[c] = dec_state->coefficients->PlaneRow(c, group_idx, offset);
337
84
          }
338
433k
        } else {
339
          // No point in reading from bitstream without accumulating and not
340
          // drawing.
341
433k
          JXL_ENSURE(draw == kDraw);
342
433k
          if (ac_type == ACType::k16) {
343
119k
            memset(group_dec_cache->dec_group_qblock16, 0,
344
119k
                   size * 3 * sizeof(int16_t));
345
478k
            for (size_t c = 0; c < 3; c++) {
346
359k
              qblock[c].ptr16 = group_dec_cache->dec_group_qblock16 + c * size;
347
359k
            }
348
313k
          } else {
349
313k
            memset(group_dec_cache->dec_group_qblock, 0,
350
313k
                   size * 3 * sizeof(int32_t));
351
1.25M
            for (size_t c = 0; c < 3; c++) {
352
940k
              qblock[c].ptr32 = group_dec_cache->dec_group_qblock + c * size;
353
940k
            }
354
313k
          }
355
433k
        }
356
433k
        JXL_RETURN_IF_ERROR(get_block->LoadBlock(
357
433k
            bx, by, acs, size, log2_covered_blocks, qblock, ac_type));
358
432k
        offset += size;
359
432k
        if (draw == kDontDraw) {
360
27
          bx += llf_x;
361
27
          continue;
362
27
        }
363
364
432k
        if (JXL_UNLIKELY(jpeg_data)) {
365
96
          if (acs.Strategy() != AcStrategyType::DCT) {
366
1
            return JXL_FAILURE(
367
1
                "Can only decode to JPEG if only DCT-8 is used.");
368
1
          }
369
370
95
          HWY_ALIGN int32_t transposed_dct_y[64];
371
283
          for (size_t c : {1, 0, 2}) {
372
            // Propagate only Y for grayscale.
373
283
            if (jpeg_is_gray && c != 1) {
374
188
              continue;
375
188
            }
376
95
            if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) {
377
0
              continue;
378
0
            }
379
95
            int16_t* JXL_RESTRICT jpeg_pos =
380
95
                jpeg_row[c] + sbx[c] * kDCTBlockSize;
381
            // JPEG XL is transposed, JPEG is not.
382
95
            auto* transposed_dct = qblock[c].ptr32;
383
95
            Transpose8x8InPlace(transposed_dct);
384
            // No CfL - no need to store the y block converted to integers.
385
95
            if (!cs.Is444() ||
386
95
                (row_cmap[0][abs_tx] == 0 && row_cmap[2][abs_tx] == 0)) {
387
1.44k
              for (size_t i = 0; i < 64; i += Lanes(d)) {
388
1.36k
                const auto ini = Load(di, transposed_dct + i);
389
1.36k
                const auto ini16 = DemoteTo(di16, ini);
390
1.36k
                StoreU(ini16, di16, jpeg_pos + i);
391
1.36k
              }
392
85
            } else if (c == 1) {
393
              // Y channel: save for restoring X/B, but nothing else to do.
394
170
              for (size_t i = 0; i < 64; i += Lanes(d)) {
395
160
                const auto ini = Load(di, transposed_dct + i);
396
160
                Store(ini, di, transposed_dct_y + i);
397
160
                const auto ini16 = DemoteTo(di16, ini);
398
160
                StoreU(ini16, di16, jpeg_pos + i);
399
160
              }
400
10
            } else {
401
              // transposed_dct_y contains the y channel block, transposed.
402
0
              const auto scale =
403
0
                  Set(di, ColorCorrelation::RatioJPEG(row_cmap[c][abs_tx]));
404
0
              const auto round = Set(di, 1 << (kCFLFixedPointPrecision - 1));
405
0
              for (int i = 0; i < 64; i += Lanes(d)) {
406
0
                auto in = Load(di, transposed_dct + i);
407
0
                auto in_y = Load(di, transposed_dct_y + i);
408
0
                auto qt = Load(di, scaled_qtable + c * size + i);
409
0
                auto coeff_scale = ShiftRight<kCFLFixedPointPrecision>(
410
0
                    Add(Mul(qt, scale), round));
411
0
                auto cfl_factor = ShiftRight<kCFLFixedPointPrecision>(
412
0
                    Add(Mul(in_y, coeff_scale), round));
413
0
                StoreU(DemoteTo(di16, Add(in, cfl_factor)), di16, jpeg_pos + i);
414
0
              }
415
0
            }
416
95
            jpeg_pos[0] =
417
95
                Clamp1<float>(dc_rows[c][sbx[c]] - dcoff[c], -2047, 2047);
418
95
            auto overflow = MaskFromVec(Set(di16_full, 0));
419
95
            auto underflow = MaskFromVec(Set(di16_full, 0));
420
855
            for (int i = 0; i < 64; i += Lanes(di16_full)) {
421
760
              auto in = LoadU(di16_full, jpeg_pos + i);
422
760
              overflow = Or(overflow, Gt(in, kJpegDctMax));
423
760
              underflow = Or(underflow, Lt(in, kJpegDctMin));
424
760
            }
425
95
            if (!AllFalse(di16_full, Or(overflow, underflow))) {
426
1
              return JXL_FAILURE("JPEG DCT coefficients out of range");
427
1
            }
428
95
          }
429
432k
        } else {
430
432k
          HWY_ALIGN float* const block = group_dec_cache->dec_group_block;
431
          // Dequantize and add predictions.
432
432k
          dequant_block(
433
432k
              inv_global_scale, row_quant[bx], dec_state->x_dm_multiplier,
434
432k
              dec_state->b_dm_multiplier, x_cc_mul, b_cc_mul, acs.Strategy(),
435
432k
              size, dec_state->shared->quantizer,
436
432k
              acs.covered_blocks_y() * acs.covered_blocks_x(), sbx, dc_rows,
437
432k
              dc_stride,
438
432k
              dec_state->output_encoding_info.opsin_params.quant_biases, qblock,
439
432k
              block, group_dec_cache->scratch_space);
440
441
1.29M
          for (size_t c : {1, 0, 2}) {
442
1.29M
            if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) {
443
443k
              continue;
444
443k
            }
445
            // IDCT
446
853k
            float* JXL_RESTRICT idct_pos = idct_row[c] + sbx[c] * kBlockDim;
447
853k
            TransformToPixels(acs.Strategy(), block + c * size, idct_pos,
448
853k
                              idct_stride[c], group_dec_cache->scratch_space);
449
853k
          }
450
432k
        }
451
432k
        bx += llf_x;
452
432k
      }
453
104k
    }
454
63.3k
  }
455
4.85k
  return true;
456
5.04k
}
jxl::N_AVX2::DecodeGroupImpl(jxl::FrameHeader const&, jxl::GetBlock*, jxl::GroupDecCache*, jxl::PassesDecoderState*, unsigned long, unsigned long, jxl::RenderPipelineInput&, jxl::jpeg::JPEGData*, jxl::DrawMode)
Line
Count
Source
188
16.4k
                       jpeg::JPEGData* jpeg_data, DrawMode draw) {
189
  // TODO(veluca): investigate cache usage in this function.
190
16.4k
  const Rect block_rect =
191
16.4k
      dec_state->shared->frame_dim.BlockGroupRect(group_idx);
192
16.4k
  const AcStrategyImage& ac_strategy = dec_state->shared->ac_strategy;
193
194
16.4k
  const size_t xsize_blocks = block_rect.xsize();
195
16.4k
  const size_t ysize_blocks = block_rect.ysize();
196
197
16.4k
  const size_t dc_stride = dec_state->shared->dc->PixelsPerRow();
198
199
16.4k
  const float inv_global_scale = dec_state->shared->quantizer.InvGlobalScale();
200
201
16.4k
  const YCbCrChromaSubsampling& cs = frame_header.chroma_subsampling;
202
203
16.4k
  const auto kJpegDctMin = Set(di16_full, -4095);
204
16.4k
  const auto kJpegDctMax = Set(di16_full, 4095);
205
206
16.4k
  size_t idct_stride[3];
207
65.7k
  for (size_t c = 0; c < 3; c++) {
208
49.2k
    idct_stride[c] = render_pipeline_input.GetBuffer(c).first->PixelsPerRow();
209
49.2k
  }
210
211
16.4k
  HWY_ALIGN int32_t scaled_qtable[64 * 3];
212
213
16.4k
  ACType ac_type = dec_state->coefficients->Type();
214
16.4k
  auto dequant_block = ac_type == ACType::k16 ? DequantBlock<ACType::k16>
215
16.4k
                                              : DequantBlock<ACType::k32>;
216
  // Whether or not coefficients should be stored for future usage, and/or read
217
  // from past usage.
218
16.4k
  bool accumulate = !dec_state->coefficients->IsEmpty();
219
  // Offset of the current block in the group.
220
16.4k
  size_t offset = 0;
221
222
16.4k
  std::array<int, 3> jpeg_c_map;
223
16.4k
  bool jpeg_is_gray = false;
224
16.4k
  std::array<int, 3> dcoff = {};
225
226
  // TODO(veluca): all of this should be done only once per image.
227
16.4k
  const ColorCorrelation& color_correlation = dec_state->shared->cmap.base();
228
16.4k
  if (jpeg_data) {
229
282
    if (!color_correlation.IsJPEGCompatible()) {
230
11
      return JXL_FAILURE("The CfL map is not JPEG-compatible");
231
11
    }
232
271
    jpeg_is_gray = (jpeg_data->components.size() == 1);
233
271
    JXL_ENSURE(frame_header.color_transform != ColorTransform::kXYB);
234
271
    jpeg_c_map = JpegOrder(frame_header.color_transform, jpeg_is_gray);
235
271
    const std::vector<QuantEncoding>& qe =
236
271
        dec_state->shared->matrices.encodings();
237
271
    if (qe.empty() || qe[0].mode != QuantEncoding::Mode::kQuantModeRAW ||
238
271
        std::abs(qe[0].qraw.qtable_den - 1.f / (8 * 255)) > 1e-8f) {
239
0
      return JXL_FAILURE(
240
0
          "Quantization table is not a JPEG quantization table.");
241
0
    }
242
271
    JXL_ENSURE(qe[0].qraw.qtable->size() == 3 * 8 * 8);
243
271
    int* qtable = qe[0].qraw.qtable->data();
244
1.06k
    for (size_t c = 0; c < 3; c++) {
245
805
      if (frame_header.color_transform == ColorTransform::kNone) {
246
15
        dcoff[c] = 1024 / qtable[64 * c];
247
15
      }
248
51.9k
      for (size_t i = 0; i < 64; i++) {
249
        // Transpose the matrix, as it will be used on the transposed block.
250
51.1k
        int num = qtable[64 + i];
251
51.1k
        int den = qtable[64 * c + i];
252
51.1k
        if (num <= 0 || den <= 0 || num >= 65536 || den >= 65536) {
253
7
          return JXL_FAILURE("Invalid JPEG quantization table");
254
7
        }
255
51.1k
        scaled_qtable[64 * c + (i % 8) * 8 + (i / 8)] =
256
51.1k
            (1 << kCFLFixedPointPrecision) * num / den;
257
51.1k
      }
258
805
    }
259
271
  }
260
261
16.4k
  size_t hshift[3] = {cs.HShift(0), cs.HShift(1), cs.HShift(2)};
262
16.4k
  size_t vshift[3] = {cs.VShift(0), cs.VShift(1), cs.VShift(2)};
263
16.4k
  Rect r[3];
264
65.6k
  for (size_t i = 0; i < 3; i++) {
265
49.2k
    r[i] =
266
49.2k
        Rect(block_rect.x0() >> hshift[i], block_rect.y0() >> vshift[i],
267
49.2k
             block_rect.xsize() >> hshift[i], block_rect.ysize() >> vshift[i]);
268
49.2k
    if (!r[i].IsInside({0, 0, dec_state->shared->dc->Plane(i).xsize(),
269
49.2k
                        dec_state->shared->dc->Plane(i).ysize()})) {
270
0
      return JXL_FAILURE("Frame dimensions are too big for the image.");
271
0
    }
272
49.2k
  }
273
274
108k
  for (size_t by = 0; by < ysize_blocks; ++by) {
275
92.5k
    get_block->StartRow(by);
276
92.5k
    size_t sby[3] = {by >> vshift[0], by >> vshift[1], by >> vshift[2]};
277
278
92.5k
    const int32_t* JXL_RESTRICT row_quant =
279
92.5k
        block_rect.ConstRow(dec_state->shared->raw_quant_field, by);
280
281
92.5k
    const float* JXL_RESTRICT dc_rows[3] = {
282
92.5k
        r[0].ConstPlaneRow(*dec_state->shared->dc, 0, sby[0]),
283
92.5k
        r[1].ConstPlaneRow(*dec_state->shared->dc, 1, sby[1]),
284
92.5k
        r[2].ConstPlaneRow(*dec_state->shared->dc, 2, sby[2]),
285
92.5k
    };
286
287
92.5k
    const size_t ty = (block_rect.y0() + by) / kColorTileDimInBlocks;
288
92.5k
    AcStrategyRow acs_row = ac_strategy.ConstRow(block_rect, by);
289
290
92.5k
    const int8_t* JXL_RESTRICT row_cmap[3] = {
291
92.5k
        dec_state->shared->cmap.ytox_map.ConstRow(ty),
292
92.5k
        nullptr,
293
92.5k
        dec_state->shared->cmap.ytob_map.ConstRow(ty),
294
92.5k
    };
295
296
92.5k
    float* JXL_RESTRICT idct_row[3];
297
92.5k
    int16_t* JXL_RESTRICT jpeg_row[3];
298
369k
    for (size_t c = 0; c < 3; c++) {
299
277k
      const auto& buffer = render_pipeline_input.GetBuffer(c);
300
277k
      idct_row[c] = buffer.second.Row(buffer.first, sby[c] * kBlockDim);
301
277k
      if (jpeg_data) {
302
792
        auto& component = jpeg_data->components[jpeg_c_map[c]];
303
792
        jpeg_row[c] =
304
792
            component.coeffs.data() +
305
792
            (component.width_in_blocks * (r[c].y0() + sby[c]) + r[c].x0()) *
306
792
                kDCTBlockSize;
307
792
      }
308
277k
    }
309
310
92.5k
    size_t bx = 0;
311
266k
    for (size_t tx = 0; tx < DivCeil(xsize_blocks, kColorTileDimInBlocks);
312
174k
         tx++) {
313
174k
      size_t abs_tx = tx + block_rect.x0() / kColorTileDimInBlocks;
314
174k
      auto x_cc_mul = Set(d, color_correlation.YtoXRatio(row_cmap[0][abs_tx]));
315
174k
      auto b_cc_mul = Set(d, color_correlation.YtoBRatio(row_cmap[2][abs_tx]));
316
      // Increment bx by llf_x because those iterations would otherwise
317
      // immediately continue (!IsFirstBlock). Reduces mispredictions.
318
1.03M
      for (; bx < xsize_blocks && bx < (tx + 1) * kColorTileDimInBlocks;) {
319
866k
        size_t sbx[3] = {bx >> hshift[0], bx >> hshift[1], bx >> hshift[2]};
320
866k
        AcStrategy acs = acs_row[bx];
321
866k
        const size_t llf_x = acs.covered_blocks_x();
322
323
        // Can only happen in the second or lower rows of a varblock.
324
866k
        if (JXL_UNLIKELY(!acs.IsFirstBlock())) {
325
45.9k
          bx += llf_x;
326
45.9k
          continue;
327
45.9k
        }
328
820k
        const size_t log2_covered_blocks = acs.log2_covered_blocks();
329
330
820k
        const size_t covered_blocks = 1 << log2_covered_blocks;
331
820k
        const size_t size = covered_blocks * kDCTBlockSize;
332
333
820k
        ACPtr qblock[3];
334
820k
        if (accumulate) {
335
5.17k
          for (size_t c = 0; c < 3; c++) {
336
3.87k
            qblock[c] = dec_state->coefficients->PlaneRow(c, group_idx, offset);
337
3.87k
          }
338
818k
        } else {
339
          // No point in reading from bitstream without accumulating and not
340
          // drawing.
341
818k
          JXL_ENSURE(draw == kDraw);
342
818k
          if (ac_type == ACType::k16) {
343
566k
            memset(group_dec_cache->dec_group_qblock16, 0,
344
566k
                   size * 3 * sizeof(int16_t));
345
2.26M
            for (size_t c = 0; c < 3; c++) {
346
1.69M
              qblock[c].ptr16 = group_dec_cache->dec_group_qblock16 + c * size;
347
1.69M
            }
348
566k
          } else {
349
252k
            memset(group_dec_cache->dec_group_qblock, 0,
350
252k
                   size * 3 * sizeof(int32_t));
351
1.01M
            for (size_t c = 0; c < 3; c++) {
352
758k
              qblock[c].ptr32 = group_dec_cache->dec_group_qblock + c * size;
353
758k
            }
354
252k
          }
355
818k
        }
356
820k
        JXL_RETURN_IF_ERROR(get_block->LoadBlock(
357
820k
            bx, by, acs, size, log2_covered_blocks, qblock, ac_type));
358
819k
        offset += size;
359
819k
        if (draw == kDontDraw) {
360
1.09k
          bx += llf_x;
361
1.09k
          continue;
362
1.09k
        }
363
364
818k
        if (JXL_UNLIKELY(jpeg_data)) {
365
264
          if (acs.Strategy() != AcStrategyType::DCT) {
366
1
            return JXL_FAILURE(
367
1
                "Can only decode to JPEG if only DCT-8 is used.");
368
1
          }
369
370
263
          HWY_ALIGN int32_t transposed_dct_y[64];
371
787
          for (size_t c : {1, 0, 2}) {
372
            // Propagate only Y for grayscale.
373
787
            if (jpeg_is_gray && c != 1) {
374
524
              continue;
375
524
            }
376
263
            if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) {
377
0
              continue;
378
0
            }
379
263
            int16_t* JXL_RESTRICT jpeg_pos =
380
263
                jpeg_row[c] + sbx[c] * kDCTBlockSize;
381
            // JPEG XL is transposed, JPEG is not.
382
263
            auto* transposed_dct = qblock[c].ptr32;
383
263
            Transpose8x8InPlace(transposed_dct);
384
            // No CfL - no need to store the y block converted to integers.
385
263
            if (!cs.Is444() ||
386
263
                (row_cmap[0][abs_tx] == 0 && row_cmap[2][abs_tx] == 0)) {
387
1.94k
              for (size_t i = 0; i < 64; i += Lanes(d)) {
388
1.72k
                const auto ini = Load(di, transposed_dct + i);
389
1.72k
                const auto ini16 = DemoteTo(di16, ini);
390
1.72k
                StoreU(ini16, di16, jpeg_pos + i);
391
1.72k
              }
392
216
            } else if (c == 1) {
393
              // Y channel: save for restoring X/B, but nothing else to do.
394
423
              for (size_t i = 0; i < 64; i += Lanes(d)) {
395
376
                const auto ini = Load(di, transposed_dct + i);
396
376
                Store(ini, di, transposed_dct_y + i);
397
376
                const auto ini16 = DemoteTo(di16, ini);
398
376
                StoreU(ini16, di16, jpeg_pos + i);
399
376
              }
400
47
            } else {
401
              // transposed_dct_y contains the y channel block, transposed.
402
0
              const auto scale =
403
0
                  Set(di, ColorCorrelation::RatioJPEG(row_cmap[c][abs_tx]));
404
0
              const auto round = Set(di, 1 << (kCFLFixedPointPrecision - 1));
405
0
              for (int i = 0; i < 64; i += Lanes(d)) {
406
0
                auto in = Load(di, transposed_dct + i);
407
0
                auto in_y = Load(di, transposed_dct_y + i);
408
0
                auto qt = Load(di, scaled_qtable + c * size + i);
409
0
                auto coeff_scale = ShiftRight<kCFLFixedPointPrecision>(
410
0
                    Add(Mul(qt, scale), round));
411
0
                auto cfl_factor = ShiftRight<kCFLFixedPointPrecision>(
412
0
                    Add(Mul(in_y, coeff_scale), round));
413
0
                StoreU(DemoteTo(di16, Add(in, cfl_factor)), di16, jpeg_pos + i);
414
0
              }
415
0
            }
416
263
            jpeg_pos[0] =
417
263
                Clamp1<float>(dc_rows[c][sbx[c]] - dcoff[c], -2047, 2047);
418
263
            auto overflow = MaskFromVec(Set(di16_full, 0));
419
263
            auto underflow = MaskFromVec(Set(di16_full, 0));
420
1.31k
            for (int i = 0; i < 64; i += Lanes(di16_full)) {
421
1.05k
              auto in = LoadU(di16_full, jpeg_pos + i);
422
1.05k
              overflow = Or(overflow, Gt(in, kJpegDctMax));
423
1.05k
              underflow = Or(underflow, Lt(in, kJpegDctMin));
424
1.05k
            }
425
263
            if (!AllFalse(di16_full, Or(overflow, underflow))) {
426
1
              return JXL_FAILURE("JPEG DCT coefficients out of range");
427
1
            }
428
263
          }
429
818k
        } else {
430
818k
          HWY_ALIGN float* const block = group_dec_cache->dec_group_block;
431
          // Dequantize and add predictions.
432
818k
          dequant_block(
433
818k
              inv_global_scale, row_quant[bx], dec_state->x_dm_multiplier,
434
818k
              dec_state->b_dm_multiplier, x_cc_mul, b_cc_mul, acs.Strategy(),
435
818k
              size, dec_state->shared->quantizer,
436
818k
              acs.covered_blocks_y() * acs.covered_blocks_x(), sbx, dc_rows,
437
818k
              dc_stride,
438
818k
              dec_state->output_encoding_info.opsin_params.quant_biases, qblock,
439
818k
              block, group_dec_cache->scratch_space);
440
441
2.45M
          for (size_t c : {1, 0, 2}) {
442
2.45M
            if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) {
443
445k
              continue;
444
445k
            }
445
            // IDCT
446
2.00M
            float* JXL_RESTRICT idct_pos = idct_row[c] + sbx[c] * kBlockDim;
447
2.00M
            TransformToPixels(acs.Strategy(), block + c * size, idct_pos,
448
2.00M
                              idct_stride[c], group_dec_cache->scratch_space);
449
2.00M
          }
450
818k
        }
451
818k
        bx += llf_x;
452
818k
      }
453
174k
    }
454
92.5k
  }
455
15.8k
  return true;
456
16.4k
}
jxl::N_SSE2::DecodeGroupImpl(jxl::FrameHeader const&, jxl::GetBlock*, jxl::GroupDecCache*, jxl::PassesDecoderState*, unsigned long, unsigned long, jxl::RenderPipelineInput&, jxl::jpeg::JPEGData*, jxl::DrawMode)
Line
Count
Source
188
4.03k
                       jpeg::JPEGData* jpeg_data, DrawMode draw) {
189
  // TODO(veluca): investigate cache usage in this function.
190
4.03k
  const Rect block_rect =
191
4.03k
      dec_state->shared->frame_dim.BlockGroupRect(group_idx);
192
4.03k
  const AcStrategyImage& ac_strategy = dec_state->shared->ac_strategy;
193
194
4.03k
  const size_t xsize_blocks = block_rect.xsize();
195
4.03k
  const size_t ysize_blocks = block_rect.ysize();
196
197
4.03k
  const size_t dc_stride = dec_state->shared->dc->PixelsPerRow();
198
199
4.03k
  const float inv_global_scale = dec_state->shared->quantizer.InvGlobalScale();
200
201
4.03k
  const YCbCrChromaSubsampling& cs = frame_header.chroma_subsampling;
202
203
4.03k
  const auto kJpegDctMin = Set(di16_full, -4095);
204
4.03k
  const auto kJpegDctMax = Set(di16_full, 4095);
205
206
4.03k
  size_t idct_stride[3];
207
16.1k
  for (size_t c = 0; c < 3; c++) {
208
12.1k
    idct_stride[c] = render_pipeline_input.GetBuffer(c).first->PixelsPerRow();
209
12.1k
  }
210
211
4.03k
  HWY_ALIGN int32_t scaled_qtable[64 * 3];
212
213
4.03k
  ACType ac_type = dec_state->coefficients->Type();
214
4.03k
  auto dequant_block = ac_type == ACType::k16 ? DequantBlock<ACType::k16>
215
4.03k
                                              : DequantBlock<ACType::k32>;
216
  // Whether or not coefficients should be stored for future usage, and/or read
217
  // from past usage.
218
4.03k
  bool accumulate = !dec_state->coefficients->IsEmpty();
219
  // Offset of the current block in the group.
220
4.03k
  size_t offset = 0;
221
222
4.03k
  std::array<int, 3> jpeg_c_map;
223
4.03k
  bool jpeg_is_gray = false;
224
4.03k
  std::array<int, 3> dcoff = {};
225
226
  // TODO(veluca): all of this should be done only once per image.
227
4.03k
  const ColorCorrelation& color_correlation = dec_state->shared->cmap.base();
228
4.03k
  if (jpeg_data) {
229
31
    if (!color_correlation.IsJPEGCompatible()) {
230
5
      return JXL_FAILURE("The CfL map is not JPEG-compatible");
231
5
    }
232
26
    jpeg_is_gray = (jpeg_data->components.size() == 1);
233
26
    JXL_ENSURE(frame_header.color_transform != ColorTransform::kXYB);
234
26
    jpeg_c_map = JpegOrder(frame_header.color_transform, jpeg_is_gray);
235
26
    const std::vector<QuantEncoding>& qe =
236
26
        dec_state->shared->matrices.encodings();
237
26
    if (qe.empty() || qe[0].mode != QuantEncoding::Mode::kQuantModeRAW ||
238
26
        std::abs(qe[0].qraw.qtable_den - 1.f / (8 * 255)) > 1e-8f) {
239
0
      return JXL_FAILURE(
240
0
          "Quantization table is not a JPEG quantization table.");
241
0
    }
242
26
    JXL_ENSURE(qe[0].qraw.qtable->size() == 3 * 8 * 8);
243
26
    int* qtable = qe[0].qraw.qtable->data();
244
89
    for (size_t c = 0; c < 3; c++) {
245
70
      if (frame_header.color_transform == ColorTransform::kNone) {
246
15
        dcoff[c] = 1024 / qtable[64 * c];
247
15
      }
248
4.19k
      for (size_t i = 0; i < 64; i++) {
249
        // Transpose the matrix, as it will be used on the transposed block.
250
4.13k
        int num = qtable[64 + i];
251
4.13k
        int den = qtable[64 * c + i];
252
4.13k
        if (num <= 0 || den <= 0 || num >= 65536 || den >= 65536) {
253
7
          return JXL_FAILURE("Invalid JPEG quantization table");
254
7
        }
255
4.12k
        scaled_qtable[64 * c + (i % 8) * 8 + (i / 8)] =
256
4.12k
            (1 << kCFLFixedPointPrecision) * num / den;
257
4.12k
      }
258
70
    }
259
26
  }
260
261
4.02k
  size_t hshift[3] = {cs.HShift(0), cs.HShift(1), cs.HShift(2)};
262
4.02k
  size_t vshift[3] = {cs.VShift(0), cs.VShift(1), cs.VShift(2)};
263
4.02k
  Rect r[3];
264
16.0k
  for (size_t i = 0; i < 3; i++) {
265
12.0k
    r[i] =
266
12.0k
        Rect(block_rect.x0() >> hshift[i], block_rect.y0() >> vshift[i],
267
12.0k
             block_rect.xsize() >> hshift[i], block_rect.ysize() >> vshift[i]);
268
12.0k
    if (!r[i].IsInside({0, 0, dec_state->shared->dc->Plane(i).xsize(),
269
12.0k
                        dec_state->shared->dc->Plane(i).ysize()})) {
270
0
      return JXL_FAILURE("Frame dimensions are too big for the image.");
271
0
    }
272
12.0k
  }
273
274
54.0k
  for (size_t by = 0; by < ysize_blocks; ++by) {
275
50.1k
    get_block->StartRow(by);
276
50.1k
    size_t sby[3] = {by >> vshift[0], by >> vshift[1], by >> vshift[2]};
277
278
50.1k
    const int32_t* JXL_RESTRICT row_quant =
279
50.1k
        block_rect.ConstRow(dec_state->shared->raw_quant_field, by);
280
281
50.1k
    const float* JXL_RESTRICT dc_rows[3] = {
282
50.1k
        r[0].ConstPlaneRow(*dec_state->shared->dc, 0, sby[0]),
283
50.1k
        r[1].ConstPlaneRow(*dec_state->shared->dc, 1, sby[1]),
284
50.1k
        r[2].ConstPlaneRow(*dec_state->shared->dc, 2, sby[2]),
285
50.1k
    };
286
287
50.1k
    const size_t ty = (block_rect.y0() + by) / kColorTileDimInBlocks;
288
50.1k
    AcStrategyRow acs_row = ac_strategy.ConstRow(block_rect, by);
289
290
50.1k
    const int8_t* JXL_RESTRICT row_cmap[3] = {
291
50.1k
        dec_state->shared->cmap.ytox_map.ConstRow(ty),
292
50.1k
        nullptr,
293
50.1k
        dec_state->shared->cmap.ytob_map.ConstRow(ty),
294
50.1k
    };
295
296
50.1k
    float* JXL_RESTRICT idct_row[3];
297
50.1k
    int16_t* JXL_RESTRICT jpeg_row[3];
298
200k
    for (size_t c = 0; c < 3; c++) {
299
150k
      const auto& buffer = render_pipeline_input.GetBuffer(c);
300
150k
      idct_row[c] = buffer.second.Row(buffer.first, sby[c] * kBlockDim);
301
150k
      if (jpeg_data) {
302
57
        auto& component = jpeg_data->components[jpeg_c_map[c]];
303
57
        jpeg_row[c] =
304
57
            component.coeffs.data() +
305
57
            (component.width_in_blocks * (r[c].y0() + sby[c]) + r[c].x0()) *
306
57
                kDCTBlockSize;
307
57
      }
308
150k
    }
309
310
50.1k
    size_t bx = 0;
311
126k
    for (size_t tx = 0; tx < DivCeil(xsize_blocks, kColorTileDimInBlocks);
312
76.8k
         tx++) {
313
76.8k
      size_t abs_tx = tx + block_rect.x0() / kColorTileDimInBlocks;
314
76.8k
      auto x_cc_mul = Set(d, color_correlation.YtoXRatio(row_cmap[0][abs_tx]));
315
76.8k
      auto b_cc_mul = Set(d, color_correlation.YtoBRatio(row_cmap[2][abs_tx]));
316
      // Increment bx by llf_x because those iterations would otherwise
317
      // immediately continue (!IsFirstBlock). Reduces mispredictions.
318
377k
      for (; bx < xsize_blocks && bx < (tx + 1) * kColorTileDimInBlocks;) {
319
300k
        size_t sbx[3] = {bx >> hshift[0], bx >> hshift[1], bx >> hshift[2]};
320
300k
        AcStrategy acs = acs_row[bx];
321
300k
        const size_t llf_x = acs.covered_blocks_x();
322
323
        // Can only happen in the second or lower rows of a varblock.
324
300k
        if (JXL_UNLIKELY(!acs.IsFirstBlock())) {
325
35.5k
          bx += llf_x;
326
35.5k
          continue;
327
35.5k
        }
328
265k
        const size_t log2_covered_blocks = acs.log2_covered_blocks();
329
330
265k
        const size_t covered_blocks = 1 << log2_covered_blocks;
331
265k
        const size_t size = covered_blocks * kDCTBlockSize;
332
333
265k
        ACPtr qblock[3];
334
265k
        if (accumulate) {
335
116
          for (size_t c = 0; c < 3; c++) {
336
87
            qblock[c] = dec_state->coefficients->PlaneRow(c, group_idx, offset);
337
87
          }
338
265k
        } else {
339
          // No point in reading from bitstream without accumulating and not
340
          // drawing.
341
265k
          JXL_ENSURE(draw == kDraw);
342
265k
          if (ac_type == ACType::k16) {
343
109k
            memset(group_dec_cache->dec_group_qblock16, 0,
344
109k
                   size * 3 * sizeof(int16_t));
345
436k
            for (size_t c = 0; c < 3; c++) {
346
327k
              qblock[c].ptr16 = group_dec_cache->dec_group_qblock16 + c * size;
347
327k
            }
348
156k
          } else {
349
156k
            memset(group_dec_cache->dec_group_qblock, 0,
350
156k
                   size * 3 * sizeof(int32_t));
351
624k
            for (size_t c = 0; c < 3; c++) {
352
468k
              qblock[c].ptr32 = group_dec_cache->dec_group_qblock + c * size;
353
468k
            }
354
156k
          }
355
265k
        }
356
265k
        JXL_RETURN_IF_ERROR(get_block->LoadBlock(
357
265k
            bx, by, acs, size, log2_covered_blocks, qblock, ac_type));
358
265k
        offset += size;
359
265k
        if (draw == kDontDraw) {
360
28
          bx += llf_x;
361
28
          continue;
362
28
        }
363
364
265k
        if (JXL_UNLIKELY(jpeg_data)) {
365
19
          if (acs.Strategy() != AcStrategyType::DCT) {
366
1
            return JXL_FAILURE(
367
1
                "Can only decode to JPEG if only DCT-8 is used.");
368
1
          }
369
370
18
          HWY_ALIGN int32_t transposed_dct_y[64];
371
52
          for (size_t c : {1, 0, 2}) {
372
            // Propagate only Y for grayscale.
373
52
            if (jpeg_is_gray && c != 1) {
374
34
              continue;
375
34
            }
376
18
            if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) {
377
0
              continue;
378
0
            }
379
18
            int16_t* JXL_RESTRICT jpeg_pos =
380
18
                jpeg_row[c] + sbx[c] * kDCTBlockSize;
381
            // JPEG XL is transposed, JPEG is not.
382
18
            auto* transposed_dct = qblock[c].ptr32;
383
18
            Transpose8x8InPlace(transposed_dct);
384
            // No CfL - no need to store the y block converted to integers.
385
18
            if (!cs.Is444() ||
386
18
                (row_cmap[0][abs_tx] == 0 && row_cmap[2][abs_tx] == 0)) {
387
221
              for (size_t i = 0; i < 64; i += Lanes(d)) {
388
208
                const auto ini = Load(di, transposed_dct + i);
389
208
                const auto ini16 = DemoteTo(di16, ini);
390
208
                StoreU(ini16, di16, jpeg_pos + i);
391
208
              }
392
13
            } else if (c == 1) {
393
              // Y channel: save for restoring X/B, but nothing else to do.
394
85
              for (size_t i = 0; i < 64; i += Lanes(d)) {
395
80
                const auto ini = Load(di, transposed_dct + i);
396
80
                Store(ini, di, transposed_dct_y + i);
397
80
                const auto ini16 = DemoteTo(di16, ini);
398
80
                StoreU(ini16, di16, jpeg_pos + i);
399
80
              }
400
5
            } else {
401
              // transposed_dct_y contains the y channel block, transposed.
402
0
              const auto scale =
403
0
                  Set(di, ColorCorrelation::RatioJPEG(row_cmap[c][abs_tx]));
404
0
              const auto round = Set(di, 1 << (kCFLFixedPointPrecision - 1));
405
0
              for (int i = 0; i < 64; i += Lanes(d)) {
406
0
                auto in = Load(di, transposed_dct + i);
407
0
                auto in_y = Load(di, transposed_dct_y + i);
408
0
                auto qt = Load(di, scaled_qtable + c * size + i);
409
0
                auto coeff_scale = ShiftRight<kCFLFixedPointPrecision>(
410
0
                    Add(Mul(qt, scale), round));
411
0
                auto cfl_factor = ShiftRight<kCFLFixedPointPrecision>(
412
0
                    Add(Mul(in_y, coeff_scale), round));
413
0
                StoreU(DemoteTo(di16, Add(in, cfl_factor)), di16, jpeg_pos + i);
414
0
              }
415
0
            }
416
18
            jpeg_pos[0] =
417
18
                Clamp1<float>(dc_rows[c][sbx[c]] - dcoff[c], -2047, 2047);
418
18
            auto overflow = MaskFromVec(Set(di16_full, 0));
419
18
            auto underflow = MaskFromVec(Set(di16_full, 0));
420
162
            for (int i = 0; i < 64; i += Lanes(di16_full)) {
421
144
              auto in = LoadU(di16_full, jpeg_pos + i);
422
144
              overflow = Or(overflow, Gt(in, kJpegDctMax));
423
144
              underflow = Or(underflow, Lt(in, kJpegDctMin));
424
144
            }
425
18
            if (!AllFalse(di16_full, Or(overflow, underflow))) {
426
1
              return JXL_FAILURE("JPEG DCT coefficients out of range");
427
1
            }
428
18
          }
429
265k
        } else {
430
265k
          HWY_ALIGN float* const block = group_dec_cache->dec_group_block;
431
          // Dequantize and add predictions.
432
265k
          dequant_block(
433
265k
              inv_global_scale, row_quant[bx], dec_state->x_dm_multiplier,
434
265k
              dec_state->b_dm_multiplier, x_cc_mul, b_cc_mul, acs.Strategy(),
435
265k
              size, dec_state->shared->quantizer,
436
265k
              acs.covered_blocks_y() * acs.covered_blocks_x(), sbx, dc_rows,
437
265k
              dc_stride,
438
265k
              dec_state->output_encoding_info.opsin_params.quant_biases, qblock,
439
265k
              block, group_dec_cache->scratch_space);
440
441
795k
          for (size_t c : {1, 0, 2}) {
442
795k
            if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) {
443
215k
              continue;
444
215k
            }
445
            // IDCT
446
579k
            float* JXL_RESTRICT idct_pos = idct_row[c] + sbx[c] * kBlockDim;
447
579k
            TransformToPixels(acs.Strategy(), block + c * size, idct_pos,
448
579k
                              idct_stride[c], group_dec_cache->scratch_space);
449
579k
          }
450
265k
        }
451
265k
        bx += llf_x;
452
265k
      }
453
76.8k
    }
454
50.1k
  }
455
3.85k
  return true;
456
4.02k
}
457
458
// NOLINTNEXTLINE(google-readability-namespace-comments)
459
}  // namespace HWY_NAMESPACE
460
}  // namespace jxl
461
HWY_AFTER_NAMESPACE();
462
463
#if HWY_ONCE
464
namespace jxl {
465
namespace {
466
// Decode quantized AC coefficients of DCT blocks.
467
// LLF components in the output block will not be modified.
468
template <ACType ac_type, bool uses_lz77>
469
Status DecodeACVarBlock(size_t ctx_offset, size_t log2_covered_blocks,
470
                        int32_t* JXL_RESTRICT row_nzeros,
471
                        const int32_t* JXL_RESTRICT row_nzeros_top,
472
                        size_t nzeros_stride, size_t c, size_t bx, size_t by,
473
                        size_t lbx, AcStrategy acs,
474
                        const coeff_order_t* JXL_RESTRICT coeff_order,
475
                        BitReader* JXL_RESTRICT br,
476
                        ANSSymbolReader* JXL_RESTRICT decoder,
477
                        const std::vector<uint8_t>& context_map,
478
                        const uint8_t* qdc_row, const int32_t* qf_row,
479
                        const BlockCtxMap& block_ctx_map, ACPtr block,
480
3.44M
                        size_t shift = 0) {
481
  // Equal to number of LLF coefficients.
482
3.44M
  const size_t covered_blocks = 1 << log2_covered_blocks;
483
3.44M
  const size_t size = covered_blocks * kDCTBlockSize;
484
3.44M
  int32_t predicted_nzeros =
485
3.44M
      PredictFromTopAndLeft(row_nzeros_top, row_nzeros, bx, 32);
486
487
3.44M
  size_t ord = kStrategyOrder[acs.RawStrategy()];
488
3.44M
  const coeff_order_t* JXL_RESTRICT order =
489
3.44M
      &coeff_order[CoeffOrderOffset(ord, c)];
490
491
3.44M
  size_t block_ctx = block_ctx_map.Context(qdc_row[lbx], qf_row[bx], ord, c);
492
3.44M
  const int32_t nzero_ctx =
493
3.44M
      block_ctx_map.NonZeroContext(predicted_nzeros, block_ctx) + ctx_offset;
494
495
3.44M
  size_t nzeros =
496
3.44M
      decoder->ReadHybridUintInlined<uses_lz77>(nzero_ctx, br, context_map);
497
3.44M
  if (nzeros > size - covered_blocks) {
498
343
    return JXL_FAILURE("Invalid AC: nzeros %" PRIuS " too large for %" PRIuS
499
343
                       " 8x8 blocks",
500
343
                       nzeros, covered_blocks);
501
343
  }
502
7.28M
  for (size_t y = 0; y < acs.covered_blocks_y(); y++) {
503
8.62M
    for (size_t x = 0; x < acs.covered_blocks_x(); x++) {
504
4.79M
      row_nzeros[bx + x + y * nzeros_stride] =
505
4.79M
          (nzeros + covered_blocks - 1) >> log2_covered_blocks;
506
4.79M
    }
507
3.83M
  }
508
509
3.44M
  const size_t histo_offset =
510
3.44M
      ctx_offset + block_ctx_map.ZeroDensityContextsOffset(block_ctx);
511
512
3.44M
  size_t prev = (nzeros > size / 16 ? 0 : 1);
513
23.3M
  for (size_t k = covered_blocks; k < size && nzeros != 0; ++k) {
514
19.9M
    const size_t ctx =
515
19.9M
        histo_offset + ZeroDensityContext(nzeros, k, covered_blocks,
516
19.9M
                                          log2_covered_blocks, prev);
517
19.9M
    const size_t u_coeff =
518
19.9M
        decoder->ReadHybridUintInlined<uses_lz77>(ctx, br, context_map);
519
    // Hand-rolled version of UnpackSigned, shifting before the conversion to
520
    // signed integer to avoid undefined behavior of shifting negative numbers.
521
19.9M
    const size_t magnitude = u_coeff >> 1;
522
19.9M
    const size_t neg_sign = (~u_coeff) & 1;
523
19.9M
    const intptr_t coeff =
524
19.9M
        static_cast<intptr_t>((magnitude ^ (neg_sign - 1)) << shift);
525
19.9M
    if (ac_type == ACType::k16) {
526
6.49M
      block.ptr16[order[k]] += coeff;
527
13.4M
    } else {
528
13.4M
      block.ptr32[order[k]] += coeff;
529
13.4M
    }
530
19.9M
    prev = static_cast<size_t>(u_coeff != 0);
531
19.9M
    nzeros -= prev;
532
19.9M
  }
533
3.44M
  if (JXL_UNLIKELY(nzeros != 0)) {
534
533
    return JXL_FAILURE("Invalid AC: nzeros at end of block is %" PRIuS
535
533
                       ", should be 0. Block (%" PRIuS ", %" PRIuS
536
533
                       "), channel %" PRIuS,
537
533
                       nzeros, bx, by, c);
538
533
  }
539
540
3.44M
  return true;
541
3.44M
}
dec_group.cc:jxl::Status jxl::(anonymous namespace)::DecodeACVarBlock<(jxl::ACType)0, true>(unsigned long, unsigned long, int*, int const*, unsigned long, unsigned long, unsigned long, unsigned long, unsigned long, jxl::AcStrategy, unsigned int const*, jxl::BitReader*, jxl::ANSSymbolReader*, std::__1::vector<unsigned char, std::__1::allocator<unsigned char> > const&, unsigned char const*, int const*, jxl::BlockCtxMap const&, jxl::ACPtr, unsigned long)
Line
Count
Source
480
410k
                        size_t shift = 0) {
481
  // Equal to number of LLF coefficients.
482
410k
  const size_t covered_blocks = 1 << log2_covered_blocks;
483
410k
  const size_t size = covered_blocks * kDCTBlockSize;
484
410k
  int32_t predicted_nzeros =
485
410k
      PredictFromTopAndLeft(row_nzeros_top, row_nzeros, bx, 32);
486
487
410k
  size_t ord = kStrategyOrder[acs.RawStrategy()];
488
410k
  const coeff_order_t* JXL_RESTRICT order =
489
410k
      &coeff_order[CoeffOrderOffset(ord, c)];
490
491
410k
  size_t block_ctx = block_ctx_map.Context(qdc_row[lbx], qf_row[bx], ord, c);
492
410k
  const int32_t nzero_ctx =
493
410k
      block_ctx_map.NonZeroContext(predicted_nzeros, block_ctx) + ctx_offset;
494
495
410k
  size_t nzeros =
496
410k
      decoder->ReadHybridUintInlined<uses_lz77>(nzero_ctx, br, context_map);
497
410k
  if (nzeros > size - covered_blocks) {
498
55
    return JXL_FAILURE("Invalid AC: nzeros %" PRIuS " too large for %" PRIuS
499
55
                       " 8x8 blocks",
500
55
                       nzeros, covered_blocks);
501
55
  }
502
873k
  for (size_t y = 0; y < acs.covered_blocks_y(); y++) {
503
1.06M
    for (size_t x = 0; x < acs.covered_blocks_x(); x++) {
504
596k
      row_nzeros[bx + x + y * nzeros_stride] =
505
596k
          (nzeros + covered_blocks - 1) >> log2_covered_blocks;
506
596k
    }
507
463k
  }
508
509
410k
  const size_t histo_offset =
510
410k
      ctx_offset + block_ctx_map.ZeroDensityContextsOffset(block_ctx);
511
512
410k
  size_t prev = (nzeros > size / 16 ? 0 : 1);
513
1.10M
  for (size_t k = covered_blocks; k < size && nzeros != 0; ++k) {
514
692k
    const size_t ctx =
515
692k
        histo_offset + ZeroDensityContext(nzeros, k, covered_blocks,
516
692k
                                          log2_covered_blocks, prev);
517
692k
    const size_t u_coeff =
518
692k
        decoder->ReadHybridUintInlined<uses_lz77>(ctx, br, context_map);
519
    // Hand-rolled version of UnpackSigned, shifting before the conversion to
520
    // signed integer to avoid undefined behavior of shifting negative numbers.
521
692k
    const size_t magnitude = u_coeff >> 1;
522
692k
    const size_t neg_sign = (~u_coeff) & 1;
523
692k
    const intptr_t coeff =
524
692k
        static_cast<intptr_t>((magnitude ^ (neg_sign - 1)) << shift);
525
692k
    if (ac_type == ACType::k16) {
526
692k
      block.ptr16[order[k]] += coeff;
527
18.4E
    } else {
528
18.4E
      block.ptr32[order[k]] += coeff;
529
18.4E
    }
530
692k
    prev = static_cast<size_t>(u_coeff != 0);
531
692k
    nzeros -= prev;
532
692k
  }
533
410k
  if (JXL_UNLIKELY(nzeros != 0)) {
534
116
    return JXL_FAILURE("Invalid AC: nzeros at end of block is %" PRIuS
535
116
                       ", should be 0. Block (%" PRIuS ", %" PRIuS
536
116
                       "), channel %" PRIuS,
537
116
                       nzeros, bx, by, c);
538
116
  }
539
540
410k
  return true;
541
410k
}
dec_group.cc:jxl::Status jxl::(anonymous namespace)::DecodeACVarBlock<(jxl::ACType)1, true>(unsigned long, unsigned long, int*, int const*, unsigned long, unsigned long, unsigned long, unsigned long, unsigned long, jxl::AcStrategy, unsigned int const*, jxl::BitReader*, jxl::ANSSymbolReader*, std::__1::vector<unsigned char, std::__1::allocator<unsigned char> > const&, unsigned char const*, int const*, jxl::BlockCtxMap const&, jxl::ACPtr, unsigned long)
Line
Count
Source
480
71.1k
                        size_t shift = 0) {
481
  // Equal to number of LLF coefficients.
482
71.1k
  const size_t covered_blocks = 1 << log2_covered_blocks;
483
71.1k
  const size_t size = covered_blocks * kDCTBlockSize;
484
71.1k
  int32_t predicted_nzeros =
485
71.1k
      PredictFromTopAndLeft(row_nzeros_top, row_nzeros, bx, 32);
486
487
71.1k
  size_t ord = kStrategyOrder[acs.RawStrategy()];
488
71.1k
  const coeff_order_t* JXL_RESTRICT order =
489
71.1k
      &coeff_order[CoeffOrderOffset(ord, c)];
490
491
71.1k
  size_t block_ctx = block_ctx_map.Context(qdc_row[lbx], qf_row[bx], ord, c);
492
71.1k
  const int32_t nzero_ctx =
493
71.1k
      block_ctx_map.NonZeroContext(predicted_nzeros, block_ctx) + ctx_offset;
494
495
71.1k
  size_t nzeros =
496
71.1k
      decoder->ReadHybridUintInlined<uses_lz77>(nzero_ctx, br, context_map);
497
71.1k
  if (nzeros > size - covered_blocks) {
498
85
    return JXL_FAILURE("Invalid AC: nzeros %" PRIuS " too large for %" PRIuS
499
85
                       " 8x8 blocks",
500
85
                       nzeros, covered_blocks);
501
85
  }
502
153k
  for (size_t y = 0; y < acs.covered_blocks_y(); y++) {
503
199k
    for (size_t x = 0; x < acs.covered_blocks_x(); x++) {
504
117k
      row_nzeros[bx + x + y * nzeros_stride] =
505
117k
          (nzeros + covered_blocks - 1) >> log2_covered_blocks;
506
117k
    }
507
82.0k
  }
508
509
71.0k
  const size_t histo_offset =
510
71.0k
      ctx_offset + block_ctx_map.ZeroDensityContextsOffset(block_ctx);
511
512
71.0k
  size_t prev = (nzeros > size / 16 ? 0 : 1);
513
914k
  for (size_t k = covered_blocks; k < size && nzeros != 0; ++k) {
514
843k
    const size_t ctx =
515
843k
        histo_offset + ZeroDensityContext(nzeros, k, covered_blocks,
516
843k
                                          log2_covered_blocks, prev);
517
843k
    const size_t u_coeff =
518
843k
        decoder->ReadHybridUintInlined<uses_lz77>(ctx, br, context_map);
519
    // Hand-rolled version of UnpackSigned, shifting before the conversion to
520
    // signed integer to avoid undefined behavior of shifting negative numbers.
521
843k
    const size_t magnitude = u_coeff >> 1;
522
843k
    const size_t neg_sign = (~u_coeff) & 1;
523
843k
    const intptr_t coeff =
524
843k
        static_cast<intptr_t>((magnitude ^ (neg_sign - 1)) << shift);
525
843k
    if (ac_type == ACType::k16) {
526
0
      block.ptr16[order[k]] += coeff;
527
843k
    } else {
528
843k
      block.ptr32[order[k]] += coeff;
529
843k
    }
530
843k
    prev = static_cast<size_t>(u_coeff != 0);
531
843k
    nzeros -= prev;
532
843k
  }
533
71.0k
  if (JXL_UNLIKELY(nzeros != 0)) {
534
65
    return JXL_FAILURE("Invalid AC: nzeros at end of block is %" PRIuS
535
65
                       ", should be 0. Block (%" PRIuS ", %" PRIuS
536
65
                       "), channel %" PRIuS,
537
65
                       nzeros, bx, by, c);
538
65
  }
539
540
70.9k
  return true;
541
71.0k
}
dec_group.cc:jxl::Status jxl::(anonymous namespace)::DecodeACVarBlock<(jxl::ACType)0, false>(unsigned long, unsigned long, int*, int const*, unsigned long, unsigned long, unsigned long, unsigned long, unsigned long, jxl::AcStrategy, unsigned int const*, jxl::BitReader*, jxl::ANSSymbolReader*, std::__1::vector<unsigned char, std::__1::allocator<unsigned char> > const&, unsigned char const*, int const*, jxl::BlockCtxMap const&, jxl::ACPtr, unsigned long)
Line
Count
Source
480
1.65M
                        size_t shift = 0) {
481
  // Equal to number of LLF coefficients.
482
1.65M
  const size_t covered_blocks = 1 << log2_covered_blocks;
483
1.65M
  const size_t size = covered_blocks * kDCTBlockSize;
484
1.65M
  int32_t predicted_nzeros =
485
1.65M
      PredictFromTopAndLeft(row_nzeros_top, row_nzeros, bx, 32);
486
487
1.65M
  size_t ord = kStrategyOrder[acs.RawStrategy()];
488
1.65M
  const coeff_order_t* JXL_RESTRICT order =
489
1.65M
      &coeff_order[CoeffOrderOffset(ord, c)];
490
491
1.65M
  size_t block_ctx = block_ctx_map.Context(qdc_row[lbx], qf_row[bx], ord, c);
492
1.65M
  const int32_t nzero_ctx =
493
1.65M
      block_ctx_map.NonZeroContext(predicted_nzeros, block_ctx) + ctx_offset;
494
495
1.65M
  size_t nzeros =
496
1.65M
      decoder->ReadHybridUintInlined<uses_lz77>(nzero_ctx, br, context_map);
497
1.65M
  if (nzeros > size - covered_blocks) {
498
48
    return JXL_FAILURE("Invalid AC: nzeros %" PRIuS " too large for %" PRIuS
499
48
                       " 8x8 blocks",
500
48
                       nzeros, covered_blocks);
501
48
  }
502
3.50M
  for (size_t y = 0; y < acs.covered_blocks_y(); y++) {
503
4.32M
    for (size_t x = 0; x < acs.covered_blocks_x(); x++) {
504
2.47M
      row_nzeros[bx + x + y * nzeros_stride] =
505
2.47M
          (nzeros + covered_blocks - 1) >> log2_covered_blocks;
506
2.47M
    }
507
1.84M
  }
508
509
1.65M
  const size_t histo_offset =
510
1.65M
      ctx_offset + block_ctx_map.ZeroDensityContextsOffset(block_ctx);
511
512
1.65M
  size_t prev = (nzeros > size / 16 ? 0 : 1);
513
7.45M
  for (size_t k = covered_blocks; k < size && nzeros != 0; ++k) {
514
5.80M
    const size_t ctx =
515
5.80M
        histo_offset + ZeroDensityContext(nzeros, k, covered_blocks,
516
5.80M
                                          log2_covered_blocks, prev);
517
5.80M
    const size_t u_coeff =
518
5.80M
        decoder->ReadHybridUintInlined<uses_lz77>(ctx, br, context_map);
519
    // Hand-rolled version of UnpackSigned, shifting before the conversion to
520
    // signed integer to avoid undefined behavior of shifting negative numbers.
521
5.80M
    const size_t magnitude = u_coeff >> 1;
522
5.80M
    const size_t neg_sign = (~u_coeff) & 1;
523
5.80M
    const intptr_t coeff =
524
5.80M
        static_cast<intptr_t>((magnitude ^ (neg_sign - 1)) << shift);
525
5.80M
    if (ac_type == ACType::k16) {
526
5.79M
      block.ptr16[order[k]] += coeff;
527
5.79M
    } else {
528
4.31k
      block.ptr32[order[k]] += coeff;
529
4.31k
    }
530
5.80M
    prev = static_cast<size_t>(u_coeff != 0);
531
5.80M
    nzeros -= prev;
532
5.80M
  }
533
1.65M
  if (JXL_UNLIKELY(nzeros != 0)) {
534
315
    return JXL_FAILURE("Invalid AC: nzeros at end of block is %" PRIuS
535
315
                       ", should be 0. Block (%" PRIuS ", %" PRIuS
536
315
                       "), channel %" PRIuS,
537
315
                       nzeros, bx, by, c);
538
315
  }
539
540
1.65M
  return true;
541
1.65M
}
dec_group.cc:jxl::Status jxl::(anonymous namespace)::DecodeACVarBlock<(jxl::ACType)1, false>(unsigned long, unsigned long, int*, int const*, unsigned long, unsigned long, unsigned long, unsigned long, unsigned long, jxl::AcStrategy, unsigned int const*, jxl::BitReader*, jxl::ANSSymbolReader*, std::__1::vector<unsigned char, std::__1::allocator<unsigned char> > const&, unsigned char const*, int const*, jxl::BlockCtxMap const&, jxl::ACPtr, unsigned long)
Line
Count
Source
480
1.31M
                        size_t shift = 0) {
481
  // Equal to number of LLF coefficients.
482
1.31M
  const size_t covered_blocks = 1 << log2_covered_blocks;
483
1.31M
  const size_t size = covered_blocks * kDCTBlockSize;
484
1.31M
  int32_t predicted_nzeros =
485
1.31M
      PredictFromTopAndLeft(row_nzeros_top, row_nzeros, bx, 32);
486
487
1.31M
  size_t ord = kStrategyOrder[acs.RawStrategy()];
488
1.31M
  const coeff_order_t* JXL_RESTRICT order =
489
1.31M
      &coeff_order[CoeffOrderOffset(ord, c)];
490
491
1.31M
  size_t block_ctx = block_ctx_map.Context(qdc_row[lbx], qf_row[bx], ord, c);
492
1.31M
  const int32_t nzero_ctx =
493
1.31M
      block_ctx_map.NonZeroContext(predicted_nzeros, block_ctx) + ctx_offset;
494
495
1.31M
  size_t nzeros =
496
1.31M
      decoder->ReadHybridUintInlined<uses_lz77>(nzero_ctx, br, context_map);
497
1.31M
  if (nzeros > size - covered_blocks) {
498
155
    return JXL_FAILURE("Invalid AC: nzeros %" PRIuS " too large for %" PRIuS
499
155
                       " 8x8 blocks",
500
155
                       nzeros, covered_blocks);
501
155
  }
502
2.74M
  for (size_t y = 0; y < acs.covered_blocks_y(); y++) {
503
3.04M
    for (size_t x = 0; x < acs.covered_blocks_x(); x++) {
504
1.60M
      row_nzeros[bx + x + y * nzeros_stride] =
505
1.60M
          (nzeros + covered_blocks - 1) >> log2_covered_blocks;
506
1.60M
    }
507
1.43M
  }
508
509
1.31M
  const size_t histo_offset =
510
1.31M
      ctx_offset + block_ctx_map.ZeroDensityContextsOffset(block_ctx);
511
512
1.31M
  size_t prev = (nzeros > size / 16 ? 0 : 1);
513
13.9M
  for (size_t k = covered_blocks; k < size && nzeros != 0; ++k) {
514
12.6M
    const size_t ctx =
515
12.6M
        histo_offset + ZeroDensityContext(nzeros, k, covered_blocks,
516
12.6M
                                          log2_covered_blocks, prev);
517
12.6M
    const size_t u_coeff =
518
12.6M
        decoder->ReadHybridUintInlined<uses_lz77>(ctx, br, context_map);
519
    // Hand-rolled version of UnpackSigned, shifting before the conversion to
520
    // signed integer to avoid undefined behavior of shifting negative numbers.
521
12.6M
    const size_t magnitude = u_coeff >> 1;
522
12.6M
    const size_t neg_sign = (~u_coeff) & 1;
523
12.6M
    const intptr_t coeff =
524
12.6M
        static_cast<intptr_t>((magnitude ^ (neg_sign - 1)) << shift);
525
12.6M
    if (ac_type == ACType::k16) {
526
0
      block.ptr16[order[k]] += coeff;
527
12.6M
    } else {
528
12.6M
      block.ptr32[order[k]] += coeff;
529
12.6M
    }
530
12.6M
    prev = static_cast<size_t>(u_coeff != 0);
531
12.6M
    nzeros -= prev;
532
12.6M
  }
533
1.31M
  if (JXL_UNLIKELY(nzeros != 0)) {
534
37
    return JXL_FAILURE("Invalid AC: nzeros at end of block is %" PRIuS
535
37
                       ", should be 0. Block (%" PRIuS ", %" PRIuS
536
37
                       "), channel %" PRIuS,
537
37
                       nzeros, bx, by, c);
538
37
  }
539
540
1.31M
  return true;
541
1.31M
}
542
543
// Structs used by DecodeGroupImpl to get a quantized block.
544
// GetBlockFromBitstream uses ANS decoding (and thus keeps track of row
545
// pointers in row_nzeros), GetBlockFromEncoder simply reads the coefficient
546
// image provided by the encoder.
547
548
struct GetBlockFromBitstream : public GetBlock {
549
205k
  void StartRow(size_t by) override {
550
205k
    qf_row = rect.ConstRow(*qf, by);
551
820k
    for (size_t c = 0; c < 3; c++) {
552
615k
      size_t sby = by >> vshift[c];
553
615k
      quant_dc_row = quant_dc->ConstRow(rect.y0() + by) + rect.x0();
554
1.23M
      for (size_t i = 0; i < num_passes; i++) {
555
616k
        row_nzeros[i][c] = group_dec_cache->num_nzeroes[i].PlaneRow(c, sby);
556
616k
        row_nzeros_top[i][c] =
557
616k
            sby == 0
558
616k
                ? nullptr
559
616k
                : group_dec_cache->num_nzeroes[i].ConstPlaneRow(c, sby - 1);
560
616k
      }
561
615k
    }
562
205k
  }
563
564
  Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs, size_t size,
565
                   size_t log2_covered_blocks, ACPtr block[3],
566
1.51M
                   ACType ac_type) override {
567
1.51M
    ;
568
4.55M
    for (size_t c : {1, 0, 2}) {
569
4.55M
      size_t sbx = bx >> hshift[c];
570
4.55M
      size_t sby = by >> vshift[c];
571
4.55M
      if (JXL_UNLIKELY((sbx << hshift[c] != bx) || (sby << vshift[c] != by))) {
572
1.10M
        continue;
573
1.10M
      }
574
575
6.89M
      for (size_t pass = 0; JXL_UNLIKELY(pass < num_passes); pass++) {
576
3.44M
        auto decode_ac_varblock =
577
3.44M
            decoders[pass].UsesLZ77()
578
3.44M
                ? (ac_type == ACType::k16 ? DecodeACVarBlock<ACType::k16, 1>
579
481k
                                          : DecodeACVarBlock<ACType::k32, 1>)
580
3.44M
                : (ac_type == ACType::k16 ? DecodeACVarBlock<ACType::k16, 0>
581
2.96M
                                          : DecodeACVarBlock<ACType::k32, 0>);
582
3.44M
        JXL_RETURN_IF_ERROR(decode_ac_varblock(
583
3.44M
            ctx_offset[pass], log2_covered_blocks, row_nzeros[pass][c],
584
3.44M
            row_nzeros_top[pass][c], nzeros_stride, c, sbx, sby, bx, acs,
585
3.44M
            &coeff_orders[pass * coeff_order_size], readers[pass],
586
3.44M
            &decoders[pass], context_map[pass], quant_dc_row, qf_row,
587
3.44M
            *block_ctx_map, block[c], shift_for_pass[pass]));
588
3.44M
      }
589
3.44M
    }
590
1.51M
    return true;
591
1.51M
  }
592
593
  Status Init(const FrameHeader& frame_header,
594
              BitReader* JXL_RESTRICT* JXL_RESTRICT readers_,
595
              size_t num_passes_, size_t group_idx, size_t histo_selector_bits,
596
              const Rect& rect_, GroupDecCache* JXL_RESTRICT group_dec_cache_,
597
24.7k
              PassesDecoderState* dec_state, size_t first_pass) {
598
98.9k
    for (size_t i = 0; i < 3; i++) {
599
74.2k
      hshift[i] = frame_header.chroma_subsampling.HShift(i);
600
74.2k
      vshift[i] = frame_header.chroma_subsampling.VShift(i);
601
74.2k
    }
602
24.7k
    coeff_order_size = dec_state->shared->coeff_order_size;
603
24.7k
    coeff_orders =
604
24.7k
        dec_state->shared->coeff_orders.data() + first_pass * coeff_order_size;
605
24.7k
    context_map = dec_state->context_map.data() + first_pass;
606
24.7k
    readers = readers_;
607
24.7k
    num_passes = num_passes_;
608
24.7k
    shift_for_pass = frame_header.passes.shift + first_pass;
609
24.7k
    group_dec_cache = group_dec_cache_;
610
24.7k
    rect = rect_;
611
24.7k
    block_ctx_map = &dec_state->shared->block_ctx_map;
612
24.7k
    qf = &dec_state->shared->raw_quant_field;
613
24.7k
    quant_dc = &dec_state->shared->quant_dc;
614
615
49.5k
    for (size_t pass = 0; pass < num_passes; pass++) {
616
      // Select which histogram set to use among those of the current pass.
617
24.7k
      size_t cur_histogram = 0;
618
24.7k
      if (histo_selector_bits != 0) {
619
7.49k
        cur_histogram = readers[pass]->ReadBits(histo_selector_bits);
620
7.49k
      }
621
24.7k
      if (cur_histogram >= dec_state->shared->num_histograms) {
622
19
        return JXL_FAILURE("Invalid histogram selector");
623
19
      }
624
24.7k
      ctx_offset[pass] = cur_histogram * block_ctx_map->NumACContexts();
625
626
24.7k
      JXL_ASSIGN_OR_RETURN(
627
24.7k
          decoders[pass],
628
24.7k
          ANSSymbolReader::Create(&dec_state->code[pass + first_pass],
629
24.7k
                                  readers[pass]));
630
24.7k
    }
631
24.7k
    nzeros_stride = group_dec_cache->num_nzeroes[0].PixelsPerRow();
632
49.4k
    for (size_t i = 0; i < num_passes; i++) {
633
24.7k
      JXL_ENSURE(
634
24.7k
          nzeros_stride ==
635
24.7k
          static_cast<size_t>(group_dec_cache->num_nzeroes[i].PixelsPerRow()));
636
24.7k
    }
637
24.7k
    return true;
638
24.7k
  }
639
640
  const uint32_t* shift_for_pass = nullptr;  // not owned
641
  const coeff_order_t* JXL_RESTRICT coeff_orders;
642
  size_t coeff_order_size;
643
  const std::vector<uint8_t>* JXL_RESTRICT context_map;
644
  ANSSymbolReader decoders[kMaxNumPasses];
645
  BitReader* JXL_RESTRICT* JXL_RESTRICT readers;
646
  size_t num_passes;
647
  size_t ctx_offset[kMaxNumPasses];
648
  size_t nzeros_stride;
649
  int32_t* JXL_RESTRICT row_nzeros[kMaxNumPasses][3];
650
  const int32_t* JXL_RESTRICT row_nzeros_top[kMaxNumPasses][3];
651
  GroupDecCache* JXL_RESTRICT group_dec_cache;
652
  const BlockCtxMap* block_ctx_map;
653
  const ImageI* qf;
654
  const ImageB* quant_dc;
655
  const int32_t* qf_row;
656
  const uint8_t* quant_dc_row;
657
  Rect rect;
658
  size_t hshift[3], vshift[3];
659
};
660
661
struct GetBlockFromEncoder : public GetBlock {
662
804
  void StartRow(size_t by) override {}
663
664
  Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs, size_t size,
665
                   size_t log2_covered_blocks, ACPtr block[3],
666
804
                   ACType ac_type) override {
667
804
    JXL_ENSURE(ac_type == ACType::k32);
668
3.21k
    for (size_t c = 0; c < 3; c++) {
669
      // for each pass
670
4.82k
      for (size_t i = 0; i < quantized_ac->size(); i++) {
671
156k
        for (size_t k = 0; k < size; k++) {
672
          // TODO(veluca): SIMD.
673
154k
          block[c].ptr32[k] +=
674
154k
              rows[i][c][offset + k] * (1 << shift_for_pass[i]);
675
154k
        }
676
2.41k
      }
677
2.41k
    }
678
804
    offset += size;
679
804
    return true;
680
804
  }
681
682
  static StatusOr<GetBlockFromEncoder> Create(
683
      const std::vector<std::unique_ptr<ACImage>>& ac, size_t group_idx,
684
804
      const uint32_t* shift_for_pass) {
685
804
    GetBlockFromEncoder result(ac, group_idx, shift_for_pass);
686
    // TODO(veluca): not supported with chroma subsampling.
687
1.60k
    for (size_t i = 0; i < ac.size(); i++) {
688
804
      JXL_ENSURE(ac[i]->Type() == ACType::k32);
689
3.21k
      for (size_t c = 0; c < 3; c++) {
690
2.41k
        result.rows[i][c] = ac[i]->PlaneRow(c, group_idx, 0).ptr32;
691
2.41k
      }
692
804
    }
693
804
    return result;
694
804
  }
695
696
  const std::vector<std::unique_ptr<ACImage>>* JXL_RESTRICT quantized_ac;
697
  size_t offset = 0;
698
  const int32_t* JXL_RESTRICT rows[kMaxNumPasses][3];
699
  const uint32_t* shift_for_pass = nullptr;  // not owned
700
701
 private:
702
  GetBlockFromEncoder(const std::vector<std::unique_ptr<ACImage>>& ac,
703
                      size_t group_idx, const uint32_t* shift_for_pass)
704
804
      : quantized_ac(&ac), shift_for_pass(shift_for_pass) {}
705
};
706
707
HWY_EXPORT(DecodeGroupImpl);
708
709
}  // namespace
710
711
Status DecodeGroup(const FrameHeader& frame_header,
712
                   BitReader* JXL_RESTRICT* JXL_RESTRICT readers,
713
                   size_t num_passes, size_t group_idx,
714
                   PassesDecoderState* JXL_RESTRICT dec_state,
715
                   GroupDecCache* JXL_RESTRICT group_dec_cache, size_t thread,
716
                   RenderPipelineInput& render_pipeline_input,
717
                   jpeg::JPEGData* JXL_RESTRICT jpeg_data, size_t first_pass,
718
24.7k
                   bool force_draw, bool dc_only, bool* should_run_pipeline) {
719
24.7k
  JxlMemoryManager* memory_manager = dec_state->memory_manager();
720
24.7k
  DrawMode draw =
721
24.7k
      (num_passes + first_pass == frame_header.passes.num_passes) || force_draw
722
24.7k
          ? kDraw
723
24.7k
          : kDontDraw;
724
725
24.7k
  if (should_run_pipeline) {
726
24.7k
    *should_run_pipeline = draw != kDontDraw;
727
24.7k
  }
728
729
24.7k
  if (draw == kDraw && num_passes == 0 && first_pass == 0) {
730
0
    JXL_RETURN_IF_ERROR(group_dec_cache->InitDCBufferOnce(memory_manager));
731
0
    const YCbCrChromaSubsampling& cs = frame_header.chroma_subsampling;
732
0
    for (size_t c : {0, 1, 2}) {
733
0
      size_t hs = cs.HShift(c);
734
0
      size_t vs = cs.VShift(c);
735
      // We reuse filter_input_storage here as it is not currently in use.
736
0
      const Rect src_rect_precs =
737
0
          dec_state->shared->frame_dim.BlockGroupRect(group_idx);
738
0
      const Rect src_rect =
739
0
          Rect(src_rect_precs.x0() >> hs, src_rect_precs.y0() >> vs,
740
0
               src_rect_precs.xsize() >> hs, src_rect_precs.ysize() >> vs);
741
0
      const Rect copy_rect(kRenderPipelineXOffset, 2, src_rect.xsize(),
742
0
                           src_rect.ysize());
743
0
      JXL_RETURN_IF_ERROR(
744
0
          CopyImageToWithPadding(src_rect, dec_state->shared->dc->Plane(c), 2,
745
0
                                 copy_rect, &group_dec_cache->dc_buffer));
746
      // Mirrorpad. Interleaving left and right padding ensures that padding
747
      // works out correctly even for images with DC size of 1.
748
0
      for (size_t y = 0; y < src_rect.ysize() + 4; y++) {
749
0
        size_t xend = kRenderPipelineXOffset +
750
0
                      (dec_state->shared->dc->Plane(c).xsize() >> hs) -
751
0
                      src_rect.x0();
752
0
        for (size_t ix = 0; ix < 2; ix++) {
753
0
          if (src_rect.x0() == 0) {
754
0
            group_dec_cache->dc_buffer.Row(y)[kRenderPipelineXOffset - ix - 1] =
755
0
                group_dec_cache->dc_buffer.Row(y)[kRenderPipelineXOffset + ix];
756
0
          }
757
0
          if (src_rect.x0() + src_rect.xsize() + 2 >=
758
0
              (dec_state->shared->dc->xsize() >> hs)) {
759
0
            group_dec_cache->dc_buffer.Row(y)[xend + ix] =
760
0
                group_dec_cache->dc_buffer.Row(y)[xend - ix - 1];
761
0
          }
762
0
        }
763
0
      }
764
0
      const auto& buffer = render_pipeline_input.GetBuffer(c);
765
0
      Rect dst_rect = buffer.second;
766
0
      ImageF* upsampling_dst = buffer.first;
767
0
      JXL_ENSURE(dst_rect.IsInside(*upsampling_dst));
768
769
0
      RenderPipelineStage::RowInfo input_rows(1, std::vector<float*>(5));
770
0
      RenderPipelineStage::RowInfo output_rows(1, std::vector<float*>(8));
771
0
      for (size_t y = src_rect.y0(); y < src_rect.y0() + src_rect.ysize();
772
0
           y++) {
773
0
        for (ssize_t iy = 0; iy < 5; iy++) {
774
0
          input_rows[0][iy] = group_dec_cache->dc_buffer.Row(
775
0
              Mirror(static_cast<ssize_t>(y) + iy - 2,
776
0
                     dec_state->shared->dc->Plane(c).ysize() >> vs) +
777
0
              2 - src_rect.y0());
778
0
        }
779
0
        for (size_t iy = 0; iy < 8; iy++) {
780
0
          output_rows[0][iy] =
781
0
              dst_rect.Row(upsampling_dst, ((y - src_rect.y0()) << 3) + iy) -
782
0
              kRenderPipelineXOffset;
783
0
        }
784
        // Arguments set to 0/nullptr are not used.
785
0
        JXL_RETURN_IF_ERROR(dec_state->upsampler8x->ProcessRow(
786
0
            input_rows, output_rows,
787
0
            /*xextra=*/0, src_rect.xsize(), 0, 0, thread));
788
0
      }
789
0
    }
790
0
    return true;
791
0
  }
792
793
24.7k
  size_t histo_selector_bits = 0;
794
24.7k
  if (dc_only) {
795
0
    JXL_ENSURE(num_passes == 0);
796
24.7k
  } else {
797
24.7k
    JXL_ENSURE(dec_state->shared->num_histograms > 0);
798
24.7k
    histo_selector_bits = CeilLog2Nonzero(dec_state->shared->num_histograms);
799
24.7k
  }
800
801
24.7k
  auto get_block = jxl::make_unique<GetBlockFromBitstream>();
802
24.7k
  JXL_RETURN_IF_ERROR(get_block->Init(
803
24.7k
      frame_header, readers, num_passes, group_idx, histo_selector_bits,
804
24.7k
      dec_state->shared->frame_dim.BlockGroupRect(group_idx), group_dec_cache,
805
24.7k
      dec_state, first_pass));
806
807
24.7k
  JXL_RETURN_IF_ERROR(HWY_DYNAMIC_DISPATCH(DecodeGroupImpl)(
808
24.7k
      frame_header, get_block.get(), group_dec_cache, dec_state, thread,
809
24.7k
      group_idx, render_pipeline_input, jpeg_data, draw));
810
811
47.6k
  for (size_t pass = 0; pass < num_passes; pass++) {
812
23.8k
    if (!get_block->decoders[pass].CheckANSFinalState()) {
813
0
      return JXL_FAILURE("ANS checksum failure.");
814
0
    }
815
23.8k
  }
816
23.7k
  return true;
817
23.7k
}
818
819
Status DecodeGroupForRoundtrip(const FrameHeader& frame_header,
820
                               const std::vector<std::unique_ptr<ACImage>>& ac,
821
                               size_t group_idx,
822
                               PassesDecoderState* JXL_RESTRICT dec_state,
823
                               GroupDecCache* JXL_RESTRICT group_dec_cache,
824
                               size_t thread,
825
                               RenderPipelineInput& render_pipeline_input,
826
                               jpeg::JPEGData* JXL_RESTRICT jpeg_data,
827
804
                               AuxOut* aux_out) {
828
804
  JxlMemoryManager* memory_manager = dec_state->memory_manager();
829
804
  JXL_ASSIGN_OR_RETURN(
830
804
      GetBlockFromEncoder get_block,
831
804
      GetBlockFromEncoder::Create(ac, group_idx, frame_header.passes.shift));
832
804
  JXL_RETURN_IF_ERROR(group_dec_cache->InitOnce(
833
804
      memory_manager,
834
804
      /*num_passes=*/0,
835
804
      /*used_acs=*/(1u << AcStrategy::kNumValidStrategies) - 1));
836
837
804
  return HWY_DYNAMIC_DISPATCH(DecodeGroupImpl)(
838
804
      frame_header, &get_block, group_dec_cache, dec_state, thread, group_idx,
839
804
      render_pipeline_input, jpeg_data, kDraw);
840
804
}
841
842
}  // namespace jxl
843
#endif  // HWY_ONCE