Coverage Report

Created: 2024-05-21 06:41

/src/libjxl/lib/jxl/dec_group.cc
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#include "lib/jxl/dec_group.h"
7
8
#include <algorithm>
9
#include <cstdint>
10
#include <cstring>
11
#include <memory>
12
#include <utility>
13
14
#include "lib/jxl/chroma_from_luma.h"
15
#include "lib/jxl/frame_header.h"
16
17
#undef HWY_TARGET_INCLUDE
18
#define HWY_TARGET_INCLUDE "lib/jxl/dec_group.cc"
19
#include <hwy/foreach_target.h>
20
#include <hwy/highway.h>
21
22
#include "lib/jxl/ac_context.h"
23
#include "lib/jxl/ac_strategy.h"
24
#include "lib/jxl/base/bits.h"
25
#include "lib/jxl/base/common.h"
26
#include "lib/jxl/base/printf_macros.h"
27
#include "lib/jxl/base/rect.h"
28
#include "lib/jxl/base/status.h"
29
#include "lib/jxl/coeff_order.h"
30
#include "lib/jxl/common.h"  // kMaxNumPasses
31
#include "lib/jxl/dec_cache.h"
32
#include "lib/jxl/dec_transforms-inl.h"
33
#include "lib/jxl/dec_xyb.h"
34
#include "lib/jxl/entropy_coder.h"
35
#include "lib/jxl/quant_weights.h"
36
#include "lib/jxl/quantizer-inl.h"
37
#include "lib/jxl/quantizer.h"
38
39
#ifndef LIB_JXL_DEC_GROUP_CC
40
#define LIB_JXL_DEC_GROUP_CC
41
namespace jxl {
42
43
struct AuxOut;
44
45
// Interface for reading groups for DecodeGroupImpl.
46
class GetBlock {
47
 public:
48
  virtual void StartRow(size_t by) = 0;
49
  virtual Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs,
50
                           size_t size, size_t log2_covered_blocks,
51
                           ACPtr block[3], ACType ac_type) = 0;
52
11.0k
  virtual ~GetBlock() {}
53
};
54
55
// Selects whether DecodeGroupImpl renders to pixels or not.
enum DrawMode {
  kDraw = 0,      // Render to pixels.
  kDontDraw = 1,  // Don't render to pixels.
};
62
63
}  // namespace jxl
64
#endif  // LIB_JXL_DEC_GROUP_CC
65
66
HWY_BEFORE_NAMESPACE();
67
namespace jxl {
68
namespace HWY_NAMESPACE {
69
70
// These templates are not found via ADL.
71
using hwy::HWY_NAMESPACE::AllFalse;
72
using hwy::HWY_NAMESPACE::Gt;
73
using hwy::HWY_NAMESPACE::Le;
74
using hwy::HWY_NAMESPACE::MaskFromVec;
75
using hwy::HWY_NAMESPACE::Or;
76
using hwy::HWY_NAMESPACE::Rebind;
77
using hwy::HWY_NAMESPACE::ShiftRight;
78
79
// Highway descriptor ("tag") types used by the SIMD code in this file.
// D / DU / DI are declared with HWY_FULL for float, uint32_t and int32_t.
using D = HWY_FULL(float);
80
using DU = HWY_FULL(uint32_t);
81
using DI = HWY_FULL(int32_t);
82
// int16_t tag derived from DI via Rebind (presumably the same lane count as
// DI, per Highway's Rebind semantics -- confirm against the Highway docs).
using DI16 = Rebind<int16_t, DI>;
83
// int16_t tag declared with HWY_CAPPED and kDCTBlockSize as the cap.
using DI16_FULL = HWY_CAPPED(int16_t, kDCTBlockSize);
84
// Tag instances passed to Highway ops (e.g. Load(d, ...), Store(..., d, ...)).
constexpr D d;
85
constexpr DI di;
86
constexpr DI16 di16;
87
constexpr DI16_FULL di16_full;
88
89
// TODO(veluca): consider SIMDfying.
90
271
void Transpose8x8InPlace(int32_t* JXL_RESTRICT block) {
91
2.43k
  for (size_t x = 0; x < 8; x++) {
92
9.75k
    for (size_t y = x + 1; y < 8; y++) {
93
7.58k
      std::swap(block[y * 8 + x], block[x * 8 + y]);
94
7.58k
    }
95
2.16k
  }
96
271
}
jxl::N_SSE4::Transpose8x8InPlace(int*)
Line
Count
Source
90
56
void Transpose8x8InPlace(int32_t* JXL_RESTRICT block) {
91
504
  for (size_t x = 0; x < 8; x++) {
92
2.01k
    for (size_t y = x + 1; y < 8; y++) {
93
1.56k
      std::swap(block[y * 8 + x], block[x * 8 + y]);
94
1.56k
    }
95
448
  }
96
56
}
jxl::N_AVX2::Transpose8x8InPlace(int*)
Line
Count
Source
90
198
void Transpose8x8InPlace(int32_t* JXL_RESTRICT block) {
91
1.78k
  for (size_t x = 0; x < 8; x++) {
92
7.12k
    for (size_t y = x + 1; y < 8; y++) {
93
5.54k
      std::swap(block[y * 8 + x], block[x * 8 + y]);
94
5.54k
    }
95
1.58k
  }
96
198
}
jxl::N_SSE2::Transpose8x8InPlace(int*)
Line
Count
Source
90
17
void Transpose8x8InPlace(int32_t* JXL_RESTRICT block) {
91
153
  for (size_t x = 0; x < 8; x++) {
92
612
    for (size_t y = x + 1; y < 8; y++) {
93
476
      std::swap(block[y * 8 + x], block[x * 8 + y]);
94
476
    }
95
136
  }
96
17
}
97
98
template <ACType ac_type>
99
void DequantLane(Vec<D> scaled_dequant_x, Vec<D> scaled_dequant_y,
100
                 Vec<D> scaled_dequant_b,
101
                 const float* JXL_RESTRICT dequant_matrices, size_t size,
102
                 size_t k, Vec<D> x_cc_mul, Vec<D> b_cc_mul,
103
                 const float* JXL_RESTRICT biases, ACPtr qblock[3],
104
10.8M
                 float* JXL_RESTRICT block) {
105
10.8M
  const auto x_mul = Mul(Load(d, dequant_matrices + k), scaled_dequant_x);
106
10.8M
  const auto y_mul =
107
10.8M
      Mul(Load(d, dequant_matrices + size + k), scaled_dequant_y);
108
10.8M
  const auto b_mul =
109
10.8M
      Mul(Load(d, dequant_matrices + 2 * size + k), scaled_dequant_b);
110
111
10.8M
  Vec<DI> quantized_x_int;
112
10.8M
  Vec<DI> quantized_y_int;
113
10.8M
  Vec<DI> quantized_b_int;
114
10.8M
  if (ac_type == ACType::k16) {
115
6.83M
    Rebind<int16_t, DI> di16;
116
6.83M
    quantized_x_int = PromoteTo(di, Load(di16, qblock[0].ptr16 + k));
117
6.83M
    quantized_y_int = PromoteTo(di, Load(di16, qblock[1].ptr16 + k));
118
6.83M
    quantized_b_int = PromoteTo(di, Load(di16, qblock[2].ptr16 + k));
119
6.83M
  } else {
120
4.01M
    quantized_x_int = Load(di, qblock[0].ptr32 + k);
121
4.01M
    quantized_y_int = Load(di, qblock[1].ptr32 + k);
122
4.01M
    quantized_b_int = Load(di, qblock[2].ptr32 + k);
123
4.01M
  }
124
125
10.8M
  const auto dequant_x_cc =
126
10.8M
      Mul(AdjustQuantBias(di, 0, quantized_x_int, biases), x_mul);
127
10.8M
  const auto dequant_y =
128
10.8M
      Mul(AdjustQuantBias(di, 1, quantized_y_int, biases), y_mul);
129
10.8M
  const auto dequant_b_cc =
130
10.8M
      Mul(AdjustQuantBias(di, 2, quantized_b_int, biases), b_mul);
131
132
10.8M
  const auto dequant_x = MulAdd(x_cc_mul, dequant_y, dequant_x_cc);
133
10.8M
  const auto dequant_b = MulAdd(b_cc_mul, dequant_y, dequant_b_cc);
134
10.8M
  Store(dequant_x, d, block + k);
135
10.8M
  Store(dequant_y, d, block + size + k);
136
10.8M
  Store(dequant_b, d, block + 2 * size + k);
137
10.8M
}
void jxl::N_SSE4::DequantLane<(jxl::ACType)0>(hwy::N_SSE4::Vec128<float, 4ul>, hwy::N_SSE4::Vec128<float, 4ul>, hwy::N_SSE4::Vec128<float, 4ul>, float const*, unsigned long, unsigned long, hwy::N_SSE4::Vec128<float, 4ul>, hwy::N_SSE4::Vec128<float, 4ul>, float const*, jxl::ACPtr*, float*)
Line
Count
Source
104
2.79M
                 float* JXL_RESTRICT block) {
105
2.79M
  const auto x_mul = Mul(Load(d, dequant_matrices + k), scaled_dequant_x);
106
2.79M
  const auto y_mul =
107
2.79M
      Mul(Load(d, dequant_matrices + size + k), scaled_dequant_y);
108
2.79M
  const auto b_mul =
109
2.79M
      Mul(Load(d, dequant_matrices + 2 * size + k), scaled_dequant_b);
110
111
2.79M
  Vec<DI> quantized_x_int;
112
2.79M
  Vec<DI> quantized_y_int;
113
2.79M
  Vec<DI> quantized_b_int;
114
2.80M
  if (ac_type == ACType::k16) {
115
2.80M
    Rebind<int16_t, DI> di16;
116
2.80M
    quantized_x_int = PromoteTo(di, Load(di16, qblock[0].ptr16 + k));
117
2.80M
    quantized_y_int = PromoteTo(di, Load(di16, qblock[1].ptr16 + k));
118
2.80M
    quantized_b_int = PromoteTo(di, Load(di16, qblock[2].ptr16 + k));
119
18.4E
  } else {
120
18.4E
    quantized_x_int = Load(di, qblock[0].ptr32 + k);
121
18.4E
    quantized_y_int = Load(di, qblock[1].ptr32 + k);
122
18.4E
    quantized_b_int = Load(di, qblock[2].ptr32 + k);
123
18.4E
  }
124
125
2.79M
  const auto dequant_x_cc =
126
2.79M
      Mul(AdjustQuantBias(di, 0, quantized_x_int, biases), x_mul);
127
2.79M
  const auto dequant_y =
128
2.79M
      Mul(AdjustQuantBias(di, 1, quantized_y_int, biases), y_mul);
129
2.79M
  const auto dequant_b_cc =
130
2.79M
      Mul(AdjustQuantBias(di, 2, quantized_b_int, biases), b_mul);
131
132
2.79M
  const auto dequant_x = MulAdd(x_cc_mul, dequant_y, dequant_x_cc);
133
2.79M
  const auto dequant_b = MulAdd(b_cc_mul, dequant_y, dequant_b_cc);
134
2.79M
  Store(dequant_x, d, block + k);
135
2.79M
  Store(dequant_y, d, block + size + k);
136
2.79M
  Store(dequant_b, d, block + 2 * size + k);
137
2.79M
}
void jxl::N_SSE4::DequantLane<(jxl::ACType)1>(hwy::N_SSE4::Vec128<float, 4ul>, hwy::N_SSE4::Vec128<float, 4ul>, hwy::N_SSE4::Vec128<float, 4ul>, float const*, unsigned long, unsigned long, hwy::N_SSE4::Vec128<float, 4ul>, hwy::N_SSE4::Vec128<float, 4ul>, float const*, jxl::ACPtr*, float*)
Line
Count
Source
104
1.24M
                 float* JXL_RESTRICT block) {
105
1.24M
  const auto x_mul = Mul(Load(d, dequant_matrices + k), scaled_dequant_x);
106
1.24M
  const auto y_mul =
107
1.24M
      Mul(Load(d, dequant_matrices + size + k), scaled_dequant_y);
108
1.24M
  const auto b_mul =
109
1.24M
      Mul(Load(d, dequant_matrices + 2 * size + k), scaled_dequant_b);
110
111
1.24M
  Vec<DI> quantized_x_int;
112
1.24M
  Vec<DI> quantized_y_int;
113
1.24M
  Vec<DI> quantized_b_int;
114
1.24M
  if (ac_type == ACType::k16) {
115
0
    Rebind<int16_t, DI> di16;
116
0
    quantized_x_int = PromoteTo(di, Load(di16, qblock[0].ptr16 + k));
117
0
    quantized_y_int = PromoteTo(di, Load(di16, qblock[1].ptr16 + k));
118
0
    quantized_b_int = PromoteTo(di, Load(di16, qblock[2].ptr16 + k));
119
1.24M
  } else {
120
1.24M
    quantized_x_int = Load(di, qblock[0].ptr32 + k);
121
1.24M
    quantized_y_int = Load(di, qblock[1].ptr32 + k);
122
1.24M
    quantized_b_int = Load(di, qblock[2].ptr32 + k);
123
1.24M
  }
124
125
1.24M
  const auto dequant_x_cc =
126
1.24M
      Mul(AdjustQuantBias(di, 0, quantized_x_int, biases), x_mul);
127
1.24M
  const auto dequant_y =
128
1.24M
      Mul(AdjustQuantBias(di, 1, quantized_y_int, biases), y_mul);
129
1.24M
  const auto dequant_b_cc =
130
1.24M
      Mul(AdjustQuantBias(di, 2, quantized_b_int, biases), b_mul);
131
132
1.24M
  const auto dequant_x = MulAdd(x_cc_mul, dequant_y, dequant_x_cc);
133
1.24M
  const auto dequant_b = MulAdd(b_cc_mul, dequant_y, dequant_b_cc);
134
1.24M
  Store(dequant_x, d, block + k);
135
1.24M
  Store(dequant_y, d, block + size + k);
136
1.24M
  Store(dequant_b, d, block + 2 * size + k);
137
1.24M
}
void jxl::N_AVX2::DequantLane<(jxl::ACType)0>(hwy::N_AVX2::Vec256<float>, hwy::N_AVX2::Vec256<float>, hwy::N_AVX2::Vec256<float>, float const*, unsigned long, unsigned long, hwy::N_AVX2::Vec256<float>, hwy::N_AVX2::Vec256<float>, float const*, jxl::ACPtr*, float*)
Line
Count
Source
104
1.66M
                 float* JXL_RESTRICT block) {
105
1.66M
  const auto x_mul = Mul(Load(d, dequant_matrices + k), scaled_dequant_x);
106
1.66M
  const auto y_mul =
107
1.66M
      Mul(Load(d, dequant_matrices + size + k), scaled_dequant_y);
108
1.66M
  const auto b_mul =
109
1.66M
      Mul(Load(d, dequant_matrices + 2 * size + k), scaled_dequant_b);
110
111
1.66M
  Vec<DI> quantized_x_int;
112
1.66M
  Vec<DI> quantized_y_int;
113
1.66M
  Vec<DI> quantized_b_int;
114
1.67M
  if (ac_type == ACType::k16) {
115
1.67M
    Rebind<int16_t, DI> di16;
116
1.67M
    quantized_x_int = PromoteTo(di, Load(di16, qblock[0].ptr16 + k));
117
1.67M
    quantized_y_int = PromoteTo(di, Load(di16, qblock[1].ptr16 + k));
118
1.67M
    quantized_b_int = PromoteTo(di, Load(di16, qblock[2].ptr16 + k));
119
18.4E
  } else {
120
18.4E
    quantized_x_int = Load(di, qblock[0].ptr32 + k);
121
18.4E
    quantized_y_int = Load(di, qblock[1].ptr32 + k);
122
18.4E
    quantized_b_int = Load(di, qblock[2].ptr32 + k);
123
18.4E
  }
124
125
1.66M
  const auto dequant_x_cc =
126
1.66M
      Mul(AdjustQuantBias(di, 0, quantized_x_int, biases), x_mul);
127
1.66M
  const auto dequant_y =
128
1.66M
      Mul(AdjustQuantBias(di, 1, quantized_y_int, biases), y_mul);
129
1.66M
  const auto dequant_b_cc =
130
1.66M
      Mul(AdjustQuantBias(di, 2, quantized_b_int, biases), b_mul);
131
132
1.66M
  const auto dequant_x = MulAdd(x_cc_mul, dequant_y, dequant_x_cc);
133
1.66M
  const auto dequant_b = MulAdd(b_cc_mul, dequant_y, dequant_b_cc);
134
1.66M
  Store(dequant_x, d, block + k);
135
1.66M
  Store(dequant_y, d, block + size + k);
136
1.66M
  Store(dequant_b, d, block + 2 * size + k);
137
1.66M
}
void jxl::N_AVX2::DequantLane<(jxl::ACType)1>(hwy::N_AVX2::Vec256<float>, hwy::N_AVX2::Vec256<float>, hwy::N_AVX2::Vec256<float>, float const*, unsigned long, unsigned long, hwy::N_AVX2::Vec256<float>, hwy::N_AVX2::Vec256<float>, float const*, jxl::ACPtr*, float*)
Line
Count
Source
104
884k
                 float* JXL_RESTRICT block) {
105
884k
  const auto x_mul = Mul(Load(d, dequant_matrices + k), scaled_dequant_x);
106
884k
  const auto y_mul =
107
884k
      Mul(Load(d, dequant_matrices + size + k), scaled_dequant_y);
108
884k
  const auto b_mul =
109
884k
      Mul(Load(d, dequant_matrices + 2 * size + k), scaled_dequant_b);
110
111
884k
  Vec<DI> quantized_x_int;
112
884k
  Vec<DI> quantized_y_int;
113
884k
  Vec<DI> quantized_b_int;
114
884k
  if (ac_type == ACType::k16) {
115
0
    Rebind<int16_t, DI> di16;
116
0
    quantized_x_int = PromoteTo(di, Load(di16, qblock[0].ptr16 + k));
117
0
    quantized_y_int = PromoteTo(di, Load(di16, qblock[1].ptr16 + k));
118
0
    quantized_b_int = PromoteTo(di, Load(di16, qblock[2].ptr16 + k));
119
884k
  } else {
120
884k
    quantized_x_int = Load(di, qblock[0].ptr32 + k);
121
884k
    quantized_y_int = Load(di, qblock[1].ptr32 + k);
122
884k
    quantized_b_int = Load(di, qblock[2].ptr32 + k);
123
884k
  }
124
125
884k
  const auto dequant_x_cc =
126
884k
      Mul(AdjustQuantBias(di, 0, quantized_x_int, biases), x_mul);
127
884k
  const auto dequant_y =
128
884k
      Mul(AdjustQuantBias(di, 1, quantized_y_int, biases), y_mul);
129
884k
  const auto dequant_b_cc =
130
884k
      Mul(AdjustQuantBias(di, 2, quantized_b_int, biases), b_mul);
131
132
884k
  const auto dequant_x = MulAdd(x_cc_mul, dequant_y, dequant_x_cc);
133
884k
  const auto dequant_b = MulAdd(b_cc_mul, dequant_y, dequant_b_cc);
134
884k
  Store(dequant_x, d, block + k);
135
884k
  Store(dequant_y, d, block + size + k);
136
884k
  Store(dequant_b, d, block + 2 * size + k);
137
884k
}
void jxl::N_SSE2::DequantLane<(jxl::ACType)0>(hwy::N_SSE2::Vec128<float, 4ul>, hwy::N_SSE2::Vec128<float, 4ul>, hwy::N_SSE2::Vec128<float, 4ul>, float const*, unsigned long, unsigned long, hwy::N_SSE2::Vec128<float, 4ul>, hwy::N_SSE2::Vec128<float, 4ul>, float const*, jxl::ACPtr*, float*)
Line
Count
Source
104
2.36M
                 float* JXL_RESTRICT block) {
105
2.36M
  const auto x_mul = Mul(Load(d, dequant_matrices + k), scaled_dequant_x);
106
2.36M
  const auto y_mul =
107
2.36M
      Mul(Load(d, dequant_matrices + size + k), scaled_dequant_y);
108
2.36M
  const auto b_mul =
109
2.36M
      Mul(Load(d, dequant_matrices + 2 * size + k), scaled_dequant_b);
110
111
2.36M
  Vec<DI> quantized_x_int;
112
2.36M
  Vec<DI> quantized_y_int;
113
2.36M
  Vec<DI> quantized_b_int;
114
2.36M
  if (ac_type == ACType::k16) {
115
2.36M
    Rebind<int16_t, DI> di16;
116
2.36M
    quantized_x_int = PromoteTo(di, Load(di16, qblock[0].ptr16 + k));
117
2.36M
    quantized_y_int = PromoteTo(di, Load(di16, qblock[1].ptr16 + k));
118
2.36M
    quantized_b_int = PromoteTo(di, Load(di16, qblock[2].ptr16 + k));
119
18.4E
  } else {
120
18.4E
    quantized_x_int = Load(di, qblock[0].ptr32 + k);
121
18.4E
    quantized_y_int = Load(di, qblock[1].ptr32 + k);
122
18.4E
    quantized_b_int = Load(di, qblock[2].ptr32 + k);
123
18.4E
  }
124
125
2.36M
  const auto dequant_x_cc =
126
2.36M
      Mul(AdjustQuantBias(di, 0, quantized_x_int, biases), x_mul);
127
2.36M
  const auto dequant_y =
128
2.36M
      Mul(AdjustQuantBias(di, 1, quantized_y_int, biases), y_mul);
129
2.36M
  const auto dequant_b_cc =
130
2.36M
      Mul(AdjustQuantBias(di, 2, quantized_b_int, biases), b_mul);
131
132
2.36M
  const auto dequant_x = MulAdd(x_cc_mul, dequant_y, dequant_x_cc);
133
2.36M
  const auto dequant_b = MulAdd(b_cc_mul, dequant_y, dequant_b_cc);
134
2.36M
  Store(dequant_x, d, block + k);
135
2.36M
  Store(dequant_y, d, block + size + k);
136
2.36M
  Store(dequant_b, d, block + 2 * size + k);
137
2.36M
}
void jxl::N_SSE2::DequantLane<(jxl::ACType)1>(hwy::N_SSE2::Vec128<float, 4ul>, hwy::N_SSE2::Vec128<float, 4ul>, hwy::N_SSE2::Vec128<float, 4ul>, float const*, unsigned long, unsigned long, hwy::N_SSE2::Vec128<float, 4ul>, hwy::N_SSE2::Vec128<float, 4ul>, float const*, jxl::ACPtr*, float*)
Line
Count
Source
104
1.89M
                 float* JXL_RESTRICT block) {
105
1.89M
  const auto x_mul = Mul(Load(d, dequant_matrices + k), scaled_dequant_x);
106
1.89M
  const auto y_mul =
107
1.89M
      Mul(Load(d, dequant_matrices + size + k), scaled_dequant_y);
108
1.89M
  const auto b_mul =
109
1.89M
      Mul(Load(d, dequant_matrices + 2 * size + k), scaled_dequant_b);
110
111
1.89M
  Vec<DI> quantized_x_int;
112
1.89M
  Vec<DI> quantized_y_int;
113
1.89M
  Vec<DI> quantized_b_int;
114
1.89M
  if (ac_type == ACType::k16) {
115
0
    Rebind<int16_t, DI> di16;
116
0
    quantized_x_int = PromoteTo(di, Load(di16, qblock[0].ptr16 + k));
117
0
    quantized_y_int = PromoteTo(di, Load(di16, qblock[1].ptr16 + k));
118
0
    quantized_b_int = PromoteTo(di, Load(di16, qblock[2].ptr16 + k));
119
1.89M
  } else {
120
1.89M
    quantized_x_int = Load(di, qblock[0].ptr32 + k);
121
1.89M
    quantized_y_int = Load(di, qblock[1].ptr32 + k);
122
1.89M
    quantized_b_int = Load(di, qblock[2].ptr32 + k);
123
1.89M
  }
124
125
1.89M
  const auto dequant_x_cc =
126
1.89M
      Mul(AdjustQuantBias(di, 0, quantized_x_int, biases), x_mul);
127
1.89M
  const auto dequant_y =
128
1.89M
      Mul(AdjustQuantBias(di, 1, quantized_y_int, biases), y_mul);
129
1.89M
  const auto dequant_b_cc =
130
1.89M
      Mul(AdjustQuantBias(di, 2, quantized_b_int, biases), b_mul);
131
132
1.89M
  const auto dequant_x = MulAdd(x_cc_mul, dequant_y, dequant_x_cc);
133
1.89M
  const auto dequant_b = MulAdd(b_cc_mul, dequant_y, dequant_b_cc);
134
1.89M
  Store(dequant_x, d, block + k);
135
1.89M
  Store(dequant_y, d, block + size + k);
136
1.89M
  Store(dequant_b, d, block + 2 * size + k);
137
1.89M
}
138
139
template <ACType ac_type>
140
void DequantBlock(const AcStrategy& acs, float inv_global_scale, int quant,
141
                  float x_dm_multiplier, float b_dm_multiplier, Vec<D> x_cc_mul,
142
                  Vec<D> b_cc_mul, size_t kind, size_t size,
143
                  const Quantizer& quantizer, size_t covered_blocks,
144
                  const size_t* sbx,
145
                  const float* JXL_RESTRICT* JXL_RESTRICT dc_row,
146
                  size_t dc_stride, const float* JXL_RESTRICT biases,
147
                  ACPtr qblock[3], float* JXL_RESTRICT block,
148
522k
                  float* JXL_RESTRICT scratch) {
149
522k
  const auto scaled_dequant_s = inv_global_scale / quant;
150
151
522k
  const auto scaled_dequant_x = Set(d, scaled_dequant_s * x_dm_multiplier);
152
522k
  const auto scaled_dequant_y = Set(d, scaled_dequant_s);
153
522k
  const auto scaled_dequant_b = Set(d, scaled_dequant_s * b_dm_multiplier);
154
155
522k
  const float* dequant_matrices = quantizer.DequantMatrix(kind, 0);
156
157
11.3M
  for (size_t k = 0; k < covered_blocks * kDCTBlockSize; k += Lanes(d)) {
158
10.8M
    DequantLane<ac_type>(scaled_dequant_x, scaled_dequant_y, scaled_dequant_b,
159
10.8M
                         dequant_matrices, size, k, x_cc_mul, b_cc_mul, biases,
160
10.8M
                         qblock, block);
161
10.8M
  }
162
2.08M
  for (size_t c = 0; c < 3; c++) {
163
1.56M
    LowestFrequenciesFromDC(acs.Strategy(), dc_row[c] + sbx[c], dc_stride,
164
1.56M
                            block + c * size, scratch);
165
1.56M
  }
166
522k
}
void jxl::N_SSE4::DequantBlock<(jxl::ACType)0>(jxl::AcStrategy const&, float, int, float, float, hwy::N_SSE4::Vec128<float, 4ul>, hwy::N_SSE4::Vec128<float, 4ul>, unsigned long, unsigned long, jxl::Quantizer const&, unsigned long, unsigned long const*, float const* restrict*, unsigned long, float const*, jxl::ACPtr*, float*, float*)
Line
Count
Source
148
94.3k
                  float* JXL_RESTRICT scratch) {
149
94.3k
  const auto scaled_dequant_s = inv_global_scale / quant;
150
151
94.3k
  const auto scaled_dequant_x = Set(d, scaled_dequant_s * x_dm_multiplier);
152
94.3k
  const auto scaled_dequant_y = Set(d, scaled_dequant_s);
153
94.3k
  const auto scaled_dequant_b = Set(d, scaled_dequant_s * b_dm_multiplier);
154
155
94.3k
  const float* dequant_matrices = quantizer.DequantMatrix(kind, 0);
156
157
2.88M
  for (size_t k = 0; k < covered_blocks * kDCTBlockSize; k += Lanes(d)) {
158
2.79M
    DequantLane<ac_type>(scaled_dequant_x, scaled_dequant_y, scaled_dequant_b,
159
2.79M
                         dequant_matrices, size, k, x_cc_mul, b_cc_mul, biases,
160
2.79M
                         qblock, block);
161
2.79M
  }
162
377k
  for (size_t c = 0; c < 3; c++) {
163
282k
    LowestFrequenciesFromDC(acs.Strategy(), dc_row[c] + sbx[c], dc_stride,
164
282k
                            block + c * size, scratch);
165
282k
  }
166
94.3k
}
void jxl::N_SSE4::DequantBlock<(jxl::ACType)1>(jxl::AcStrategy const&, float, int, float, float, hwy::N_SSE4::Vec128<float, 4ul>, hwy::N_SSE4::Vec128<float, 4ul>, unsigned long, unsigned long, jxl::Quantizer const&, unsigned long, unsigned long const*, float const* restrict*, unsigned long, float const*, jxl::ACPtr*, float*, float*)
Line
Count
Source
148
50.7k
                  float* JXL_RESTRICT scratch) {
149
50.7k
  const auto scaled_dequant_s = inv_global_scale / quant;
150
151
50.7k
  const auto scaled_dequant_x = Set(d, scaled_dequant_s * x_dm_multiplier);
152
50.7k
  const auto scaled_dequant_y = Set(d, scaled_dequant_s);
153
50.7k
  const auto scaled_dequant_b = Set(d, scaled_dequant_s * b_dm_multiplier);
154
155
50.7k
  const float* dequant_matrices = quantizer.DequantMatrix(kind, 0);
156
157
1.29M
  for (size_t k = 0; k < covered_blocks * kDCTBlockSize; k += Lanes(d)) {
158
1.24M
    DequantLane<ac_type>(scaled_dequant_x, scaled_dequant_y, scaled_dequant_b,
159
1.24M
                         dequant_matrices, size, k, x_cc_mul, b_cc_mul, biases,
160
1.24M
                         qblock, block);
161
1.24M
  }
162
202k
  for (size_t c = 0; c < 3; c++) {
163
151k
    LowestFrequenciesFromDC(acs.Strategy(), dc_row[c] + sbx[c], dc_stride,
164
151k
                            block + c * size, scratch);
165
151k
  }
166
50.7k
}
void jxl::N_AVX2::DequantBlock<(jxl::ACType)0>(jxl::AcStrategy const&, float, int, float, float, hwy::N_AVX2::Vec256<float>, hwy::N_AVX2::Vec256<float>, unsigned long, unsigned long, jxl::Quantizer const&, unsigned long, unsigned long const*, float const* restrict*, unsigned long, float const*, jxl::ACPtr*, float*, float*)
Line
Count
Source
148
122k
                  float* JXL_RESTRICT scratch) {
149
122k
  const auto scaled_dequant_s = inv_global_scale / quant;
150
151
122k
  const auto scaled_dequant_x = Set(d, scaled_dequant_s * x_dm_multiplier);
152
122k
  const auto scaled_dequant_y = Set(d, scaled_dequant_s);
153
122k
  const auto scaled_dequant_b = Set(d, scaled_dequant_s * b_dm_multiplier);
154
155
122k
  const float* dequant_matrices = quantizer.DequantMatrix(kind, 0);
156
157
1.79M
  for (size_t k = 0; k < covered_blocks * kDCTBlockSize; k += Lanes(d)) {
158
1.67M
    DequantLane<ac_type>(scaled_dequant_x, scaled_dequant_y, scaled_dequant_b,
159
1.67M
                         dequant_matrices, size, k, x_cc_mul, b_cc_mul, biases,
160
1.67M
                         qblock, block);
161
1.67M
  }
162
490k
  for (size_t c = 0; c < 3; c++) {
163
367k
    LowestFrequenciesFromDC(acs.Strategy(), dc_row[c] + sbx[c], dc_stride,
164
367k
                            block + c * size, scratch);
165
367k
  }
166
122k
}
void jxl::N_AVX2::DequantBlock<(jxl::ACType)1>(jxl::AcStrategy const&, float, int, float, float, hwy::N_AVX2::Vec256<float>, hwy::N_AVX2::Vec256<float>, unsigned long, unsigned long, jxl::Quantizer const&, unsigned long, unsigned long const*, float const* restrict*, unsigned long, float const*, jxl::ACPtr*, float*, float*)
Line
Count
Source
148
77.8k
                  float* JXL_RESTRICT scratch) {
149
77.8k
  const auto scaled_dequant_s = inv_global_scale / quant;
150
151
77.8k
  const auto scaled_dequant_x = Set(d, scaled_dequant_s * x_dm_multiplier);
152
77.8k
  const auto scaled_dequant_y = Set(d, scaled_dequant_s);
153
77.8k
  const auto scaled_dequant_b = Set(d, scaled_dequant_s * b_dm_multiplier);
154
155
77.8k
  const float* dequant_matrices = quantizer.DequantMatrix(kind, 0);
156
157
962k
  for (size_t k = 0; k < covered_blocks * kDCTBlockSize; k += Lanes(d)) {
158
885k
    DequantLane<ac_type>(scaled_dequant_x, scaled_dequant_y, scaled_dequant_b,
159
885k
                         dequant_matrices, size, k, x_cc_mul, b_cc_mul, biases,
160
885k
                         qblock, block);
161
885k
  }
162
310k
  for (size_t c = 0; c < 3; c++) {
163
233k
    LowestFrequenciesFromDC(acs.Strategy(), dc_row[c] + sbx[c], dc_stride,
164
233k
                            block + c * size, scratch);
165
233k
  }
166
77.8k
}
void jxl::N_SSE2::DequantBlock<(jxl::ACType)0>(jxl::AcStrategy const&, float, int, float, float, hwy::N_SSE2::Vec128<float, 4ul>, hwy::N_SSE2::Vec128<float, 4ul>, unsigned long, unsigned long, jxl::Quantizer const&, unsigned long, unsigned long const*, float const* restrict*, unsigned long, float const*, jxl::ACPtr*, float*, float*)
Line
Count
Source
148
80.0k
                  float* JXL_RESTRICT scratch) {
149
80.0k
  const auto scaled_dequant_s = inv_global_scale / quant;
150
151
80.0k
  const auto scaled_dequant_x = Set(d, scaled_dequant_s * x_dm_multiplier);
152
80.0k
  const auto scaled_dequant_y = Set(d, scaled_dequant_s);
153
80.0k
  const auto scaled_dequant_b = Set(d, scaled_dequant_s * b_dm_multiplier);
154
155
80.0k
  const float* dequant_matrices = quantizer.DequantMatrix(kind, 0);
156
157
2.43M
  for (size_t k = 0; k < covered_blocks * kDCTBlockSize; k += Lanes(d)) {
158
2.35M
    DequantLane<ac_type>(scaled_dequant_x, scaled_dequant_y, scaled_dequant_b,
159
2.35M
                         dequant_matrices, size, k, x_cc_mul, b_cc_mul, biases,
160
2.35M
                         qblock, block);
161
2.35M
  }
162
320k
  for (size_t c = 0; c < 3; c++) {
163
239k
    LowestFrequenciesFromDC(acs.Strategy(), dc_row[c] + sbx[c], dc_stride,
164
239k
                            block + c * size, scratch);
165
239k
  }
166
80.0k
}
void jxl::N_SSE2::DequantBlock<(jxl::ACType)1>(jxl::AcStrategy const&, float, int, float, float, hwy::N_SSE2::Vec128<float, 4ul>, hwy::N_SSE2::Vec128<float, 4ul>, unsigned long, unsigned long, jxl::Quantizer const&, unsigned long, unsigned long const*, float const* restrict*, unsigned long, float const*, jxl::ACPtr*, float*, float*)
Line
Count
Source
148
97.0k
                  float* JXL_RESTRICT scratch) {
149
97.0k
  const auto scaled_dequant_s = inv_global_scale / quant;
150
151
97.0k
  const auto scaled_dequant_x = Set(d, scaled_dequant_s * x_dm_multiplier);
152
97.0k
  const auto scaled_dequant_y = Set(d, scaled_dequant_s);
153
97.0k
  const auto scaled_dequant_b = Set(d, scaled_dequant_s * b_dm_multiplier);
154
155
97.0k
  const float* dequant_matrices = quantizer.DequantMatrix(kind, 0);
156
157
1.99M
  for (size_t k = 0; k < covered_blocks * kDCTBlockSize; k += Lanes(d)) {
158
1.89M
    DequantLane<ac_type>(scaled_dequant_x, scaled_dequant_y, scaled_dequant_b,
159
1.89M
                         dequant_matrices, size, k, x_cc_mul, b_cc_mul, biases,
160
1.89M
                         qblock, block);
161
1.89M
  }
162
387k
  for (size_t c = 0; c < 3; c++) {
163
290k
    LowestFrequenciesFromDC(acs.Strategy(), dc_row[c] + sbx[c], dc_stride,
164
290k
                            block + c * size, scratch);
165
290k
  }
166
97.0k
}
167
168
// Decodes the AC coefficients of one group.
//
// For every varblock inside the group's block rect, the quantized coefficients
// are loaded through `get_block`. They are then either
//  * written to `jpeg_data` as 16-bit JPEG DCT coefficients (when `jpeg_data`
//    is non-null), or
//  * dequantized and inverse-transformed into the buffers provided by
//    `render_pipeline_input`.
// When `draw` is kDontDraw the coefficients are only loaded (and accumulated
// in `dec_state->coefficients`); nothing is rendered.
//
// `thread` is not used by this function body.
//
// Returns true on success. Fails if the CfL map or the quantization table is
// not JPEG-compatible, if the group rect does not fit in the DC image, if
// `get_block->LoadBlock` fails, if a non-DCT8 strategy is encountered while
// reconstructing a JPEG, or if a reconstructed JPEG coefficient falls outside
// [-4095, 4095].
Status DecodeGroupImpl(const FrameHeader& frame_header,
                       GetBlock* JXL_RESTRICT get_block,
                       GroupDecCache* JXL_RESTRICT group_dec_cache,
                       PassesDecoderState* JXL_RESTRICT dec_state,
                       size_t thread, size_t group_idx,
                       RenderPipelineInput& render_pipeline_input,
                       jpeg::JPEGData* jpeg_data, DrawMode draw) {
  // TODO(veluca): investigate cache usage in this function.
  const Rect block_rect =
      dec_state->shared->frame_dim.BlockGroupRect(group_idx);
  const AcStrategyImage& ac_strategy = dec_state->shared->ac_strategy;

  const size_t xsize_blocks = block_rect.xsize();
  const size_t ysize_blocks = block_rect.ysize();

  const size_t dc_stride = dec_state->shared->dc->PixelsPerRow();

  const float inv_global_scale = dec_state->shared->quantizer.InvGlobalScale();

  const YCbCrChromaSubsampling& cs = frame_header.chroma_subsampling;

  // Valid range of the JPEG coefficients written in the JPEG path below.
  const auto kJpegDctMin = Set(di16_full, -4095);
  const auto kJpegDctMax = Set(di16_full, 4095);

  size_t idct_stride[3];
  for (size_t c = 0; c < 3; c++) {
    idct_stride[c] = render_pipeline_input.GetBuffer(c).first->PixelsPerRow();
  }

  // Per-channel, transposed, fixed-point ratio between the channel-1
  // quantization table and the channel's own table. Only filled in (and only
  // read) when reconstructing a JPEG.
  HWY_ALIGN int32_t scaled_qtable[64 * 3];

  ACType ac_type = dec_state->coefficients->Type();
  auto dequant_block = ac_type == ACType::k16 ? DequantBlock<ACType::k16>
                                              : DequantBlock<ACType::k32>;
  // Whether or not coefficients should be stored for future usage, and/or read
  // from past usage.
  bool accumulate = !dec_state->coefficients->IsEmpty();
  // Offset of the current block in the group.
  size_t offset = 0;

  // Only initialized (and only read) when jpeg_data is non-null.
  std::array<int, 3> jpeg_c_map;
  bool jpeg_is_gray = false;
  std::array<int, 3> dcoff = {};

  // TODO(veluca): all of this should be done only once per image.
  const ColorCorrelation& color_correlation = dec_state->shared->cmap.base();
  if (jpeg_data) {
    if (!color_correlation.IsJPEGCompatible()) {
      return JXL_FAILURE("The CfL map is not JPEG-compatible");
    }
    jpeg_is_gray = (jpeg_data->components.size() == 1);
    jpeg_c_map = JpegOrder(frame_header.color_transform, jpeg_is_gray);
    const std::vector<QuantEncoding>& qe =
        dec_state->shared->matrices.encodings();
    if (qe.empty() || qe[0].mode != QuantEncoding::Mode::kQuantModeRAW ||
        std::abs(qe[0].qraw.qtable_den - 1.f / (8 * 255)) > 1e-8f) {
      return JXL_FAILURE(
          "Quantization table is not a JPEG quantization table.");
    }
    for (size_t c = 0; c < 3; c++) {
      if (frame_header.color_transform == ColorTransform::kNone) {
        // The table entries are only range-checked in the loop below, so the
        // DC quantizer has to be validated here before it is used as a
        // divisor; otherwise a zero entry would be a division by zero.
        const int dc_quant = qe[0].qraw.qtable->at(64 * c);
        if (dc_quant <= 0) {
          return JXL_FAILURE("Invalid JPEG quantization table");
        }
        dcoff[c] = 1024 / dc_quant;
      }
      for (size_t i = 0; i < 64; i++) {
        // Transpose the matrix, as it will be used on the transposed block.
        // `num` is the channel-1 entry, `den` the entry of channel c.
        const int num = qe[0].qraw.qtable->at(64 + i);
        const int den = qe[0].qraw.qtable->at(64 * c + i);
        if (num <= 0 || den <= 0 || num >= 65536 || den >= 65536) {
          return JXL_FAILURE("Invalid JPEG quantization table");
        }
        scaled_qtable[64 * c + (i % 8) * 8 + (i / 8)] =
            (1 << kCFLFixedPointPrecision) * num / den;
      }
    }
  }

  // Per-channel subsampling shifts and the matching (subsampled) block rects.
  size_t hshift[3] = {cs.HShift(0), cs.HShift(1), cs.HShift(2)};
  size_t vshift[3] = {cs.VShift(0), cs.VShift(1), cs.VShift(2)};
  Rect r[3];
  for (size_t i = 0; i < 3; i++) {
    r[i] =
        Rect(block_rect.x0() >> hshift[i], block_rect.y0() >> vshift[i],
             block_rect.xsize() >> hshift[i], block_rect.ysize() >> vshift[i]);
    if (!r[i].IsInside({0, 0, dec_state->shared->dc->Plane(i).xsize(),
                        dec_state->shared->dc->Plane(i).ysize()})) {
      return JXL_FAILURE("Frame dimensions are too big for the image.");
    }
  }

  for (size_t by = 0; by < ysize_blocks; ++by) {
    get_block->StartRow(by);
    size_t sby[3] = {by >> vshift[0], by >> vshift[1], by >> vshift[2]};

    const int32_t* JXL_RESTRICT row_quant =
        block_rect.ConstRow(dec_state->shared->raw_quant_field, by);

    const float* JXL_RESTRICT dc_rows[3] = {
        r[0].ConstPlaneRow(*dec_state->shared->dc, 0, sby[0]),
        r[1].ConstPlaneRow(*dec_state->shared->dc, 1, sby[1]),
        r[2].ConstPlaneRow(*dec_state->shared->dc, 2, sby[2]),
    };

    const size_t ty = (block_rect.y0() + by) / kColorTileDimInBlocks;
    AcStrategyRow acs_row = ac_strategy.ConstRow(block_rect, by);

    // Index 1 is intentionally null: only channels 0 and 2 have a
    // correlation map row.
    const int8_t* JXL_RESTRICT row_cmap[3] = {
        dec_state->shared->cmap.ytox_map.ConstRow(ty),
        nullptr,
        dec_state->shared->cmap.ytob_map.ConstRow(ty),
    };

    float* JXL_RESTRICT idct_row[3];
    // Only set (and only read) when jpeg_data is non-null.
    int16_t* JXL_RESTRICT jpeg_row[3];
    for (size_t c = 0; c < 3; c++) {
      idct_row[c] = render_pipeline_input.GetBuffer(c).second.Row(
          render_pipeline_input.GetBuffer(c).first, sby[c] * kBlockDim);
      if (jpeg_data) {
        auto& component = jpeg_data->components[jpeg_c_map[c]];
        jpeg_row[c] =
            component.coeffs.data() +
            (component.width_in_blocks * (r[c].y0() + sby[c]) + r[c].x0()) *
                kDCTBlockSize;
      }
    }

    size_t bx = 0;
    for (size_t tx = 0; tx < DivCeil(xsize_blocks, kColorTileDimInBlocks);
         tx++) {
      size_t abs_tx = tx + block_rect.x0() / kColorTileDimInBlocks;
      auto x_cc_mul = Set(d, color_correlation.YtoXRatio(row_cmap[0][abs_tx]));
      auto b_cc_mul = Set(d, color_correlation.YtoBRatio(row_cmap[2][abs_tx]));
      // Increment bx by llf_x because those iterations would otherwise
      // immediately continue (!IsFirstBlock). Reduces mispredictions.
      for (; bx < xsize_blocks && bx < (tx + 1) * kColorTileDimInBlocks;) {
        size_t sbx[3] = {bx >> hshift[0], bx >> hshift[1], bx >> hshift[2]};
        AcStrategy acs = acs_row[bx];
        const size_t llf_x = acs.covered_blocks_x();

        // Can only happen in the second or lower rows of a varblock.
        if (JXL_UNLIKELY(!acs.IsFirstBlock())) {
          bx += llf_x;
          continue;
        }
        const size_t log2_covered_blocks = acs.log2_covered_blocks();

        const size_t covered_blocks = 1 << log2_covered_blocks;
        const size_t size = covered_blocks * kDCTBlockSize;

        ACPtr qblock[3];
        if (accumulate) {
          for (size_t c = 0; c < 3; c++) {
            qblock[c] = dec_state->coefficients->PlaneRow(c, group_idx, offset);
          }
        } else {
          // No point in reading from bitstream without accumulating and not
          // drawing.
          JXL_ASSERT(draw == kDraw);
          // Use the zeroed per-thread scratch buffer of the matching width.
          if (ac_type == ACType::k16) {
            memset(group_dec_cache->dec_group_qblock16, 0,
                   size * 3 * sizeof(int16_t));
            for (size_t c = 0; c < 3; c++) {
              qblock[c].ptr16 = group_dec_cache->dec_group_qblock16 + c * size;
            }
          } else {
            memset(group_dec_cache->dec_group_qblock, 0,
                   size * 3 * sizeof(int32_t));
            for (size_t c = 0; c < 3; c++) {
              qblock[c].ptr32 = group_dec_cache->dec_group_qblock + c * size;
            }
          }
        }
        JXL_RETURN_IF_ERROR(get_block->LoadBlock(
            bx, by, acs, size, log2_covered_blocks, qblock, ac_type));
        offset += size;
        if (draw == kDontDraw) {
          bx += llf_x;
          continue;
        }

        if (JXL_UNLIKELY(jpeg_data)) {
          if (acs.Strategy() != AcStrategy::Type::DCT) {
            return JXL_FAILURE(
                "Can only decode to JPEG if only DCT-8 is used.");
          }

          // Channel 1 is processed first so that its block is available when
          // the other two channels are reconstructed.
          HWY_ALIGN int32_t transposed_dct_y[64];
          for (size_t c : {1, 0, 2}) {
            // Propagate only Y for grayscale.
            if (jpeg_is_gray && c != 1) {
              continue;
            }
            // Skip subsampled channels at positions that do not start one of
            // their blocks.
            if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) {
              continue;
            }
            int16_t* JXL_RESTRICT jpeg_pos =
                jpeg_row[c] + sbx[c] * kDCTBlockSize;
            // JPEG XL is transposed, JPEG is not.
            // NOTE(review): reads the 32-bit view of the coefficients; this
            // presumably relies on ac_type being k32 whenever jpeg_data is
            // set - confirm against the caller.
            auto* transposed_dct = qblock[c].ptr32;
            Transpose8x8InPlace(transposed_dct);
            // No CfL - no need to store the y block converted to integers.
            if (!cs.Is444() ||
                (row_cmap[0][abs_tx] == 0 && row_cmap[2][abs_tx] == 0)) {
              for (size_t i = 0; i < 64; i += Lanes(d)) {
                const auto ini = Load(di, transposed_dct + i);
                const auto ini16 = DemoteTo(di16, ini);
                StoreU(ini16, di16, jpeg_pos + i);
              }
            } else if (c == 1) {
              // Y channel: save for restoring X/B, but nothing else to do.
              for (size_t i = 0; i < 64; i += Lanes(d)) {
                const auto ini = Load(di, transposed_dct + i);
                Store(ini, di, transposed_dct_y + i);
                const auto ini16 = DemoteTo(di16, ini);
                StoreU(ini16, di16, jpeg_pos + i);
              }
            } else {
              // transposed_dct_y contains the y channel block, transposed.
              const auto scale =
                  Set(di, ColorCorrelation::RatioJPEG(row_cmap[c][abs_tx]));
              const auto round = Set(di, 1 << (kCFLFixedPointPrecision - 1));
              for (size_t i = 0; i < 64; i += Lanes(d)) {
                auto in = Load(di, transposed_dct + i);
                auto in_y = Load(di, transposed_dct_y + i);
                auto qt = Load(di, scaled_qtable + c * size + i);
                auto coeff_scale = ShiftRight<kCFLFixedPointPrecision>(
                    Add(Mul(qt, scale), round));
                auto cfl_factor = ShiftRight<kCFLFixedPointPrecision>(
                    Add(Mul(in_y, coeff_scale), round));
                StoreU(DemoteTo(di16, Add(in, cfl_factor)), di16, jpeg_pos + i);
              }
            }
            // The DC coefficient comes from the DC image, not from the AC
            // block.
            jpeg_pos[0] =
                Clamp1<float>(dc_rows[c][sbx[c]] - dcoff[c], -2047, 2047);
            // Reject blocks with any coefficient outside the JPEG range.
            auto overflow = MaskFromVec(Set(di16_full, 0));
            auto underflow = MaskFromVec(Set(di16_full, 0));
            for (size_t i = 0; i < 64; i += Lanes(di16_full)) {
              auto in = LoadU(di16_full, jpeg_pos + i);
              overflow = Or(overflow, Gt(in, kJpegDctMax));
              underflow = Or(underflow, Lt(in, kJpegDctMin));
            }
            if (!AllFalse(di16_full, Or(overflow, underflow))) {
              return JXL_FAILURE("JPEG DCT coefficients out of range");
            }
          }
        } else {
          HWY_ALIGN float* const block = group_dec_cache->dec_group_block;
          // Dequantize and add predictions.
          dequant_block(
              acs, inv_global_scale, row_quant[bx], dec_state->x_dm_multiplier,
              dec_state->b_dm_multiplier, x_cc_mul, b_cc_mul, acs.RawStrategy(),
              size, dec_state->shared->quantizer,
              acs.covered_blocks_y() * acs.covered_blocks_x(), sbx, dc_rows,
              dc_stride,
              dec_state->output_encoding_info.opsin_params.quant_biases, qblock,
              block, group_dec_cache->scratch_space);

          for (size_t c : {1, 0, 2}) {
            // Skip subsampled channels at positions that do not start one of
            // their blocks.
            if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) {
              continue;
            }
            // IDCT
            float* JXL_RESTRICT idct_pos = idct_row[c] + sbx[c] * kBlockDim;
            TransformToPixels(acs.Strategy(), block + c * size, idct_pos,
                              idct_stride[c], group_dec_cache->scratch_space);
          }
        }
        bx += llf_x;
      }
    }
  }
  return true;
}
jxl::N_SSE4::DecodeGroupImpl(jxl::FrameHeader const&, jxl::GetBlock*, jxl::GroupDecCache*, jxl::PassesDecoderState*, unsigned long, unsigned long, jxl::RenderPipelineInput&, jxl::jpeg::JPEGData*, jxl::DrawMode)
Line
Count
Source
174
3.08k
                       jpeg::JPEGData* jpeg_data, DrawMode draw) {
175
  // TODO(veluca): investigate cache usage in this function.
176
3.08k
  const Rect block_rect =
177
3.08k
      dec_state->shared->frame_dim.BlockGroupRect(group_idx);
178
3.08k
  const AcStrategyImage& ac_strategy = dec_state->shared->ac_strategy;
179
180
3.08k
  const size_t xsize_blocks = block_rect.xsize();
181
3.08k
  const size_t ysize_blocks = block_rect.ysize();
182
183
3.08k
  const size_t dc_stride = dec_state->shared->dc->PixelsPerRow();
184
185
3.08k
  const float inv_global_scale = dec_state->shared->quantizer.InvGlobalScale();
186
187
3.08k
  const YCbCrChromaSubsampling& cs = frame_header.chroma_subsampling;
188
189
3.08k
  const auto kJpegDctMin = Set(di16_full, -4095);
190
3.08k
  const auto kJpegDctMax = Set(di16_full, 4095);
191
192
3.08k
  size_t idct_stride[3];
193
12.3k
  for (size_t c = 0; c < 3; c++) {
194
9.24k
    idct_stride[c] = render_pipeline_input.GetBuffer(c).first->PixelsPerRow();
195
9.24k
  }
196
197
3.08k
  HWY_ALIGN int32_t scaled_qtable[64 * 3];
198
199
3.08k
  ACType ac_type = dec_state->coefficients->Type();
200
3.08k
  auto dequant_block = ac_type == ACType::k16 ? DequantBlock<ACType::k16>
201
3.08k
                                              : DequantBlock<ACType::k32>;
202
  // Whether or not coefficients should be stored for future usage, and/or read
203
  // from past usage.
204
3.08k
  bool accumulate = !dec_state->coefficients->IsEmpty();
205
  // Offset of the current block in the group.
206
3.08k
  size_t offset = 0;
207
208
3.08k
  std::array<int, 3> jpeg_c_map;
209
3.08k
  bool jpeg_is_gray = false;
210
3.08k
  std::array<int, 3> dcoff = {};
211
212
  // TODO(veluca): all of this should be done only once per image.
213
3.08k
  const ColorCorrelation& color_correlation = dec_state->shared->cmap.base();
214
3.08k
  if (jpeg_data) {
215
68
    if (!color_correlation.IsJPEGCompatible()) {
216
5
      return JXL_FAILURE("The CfL map is not JPEG-compatible");
217
5
    }
218
63
    jpeg_is_gray = (jpeg_data->components.size() == 1);
219
63
    jpeg_c_map = JpegOrder(frame_header.color_transform, jpeg_is_gray);
220
63
    const std::vector<QuantEncoding>& qe =
221
63
        dec_state->shared->matrices.encodings();
222
63
    if (qe.empty() || qe[0].mode != QuantEncoding::Mode::kQuantModeRAW ||
223
63
        std::abs(qe[0].qraw.qtable_den - 1.f / (8 * 255)) > 1e-8f) {
224
0
      return JXL_FAILURE(
225
0
          "Quantization table is not a JPEG quantization table.");
226
0
    }
227
240
    for (size_t c = 0; c < 3; c++) {
228
183
      if (frame_header.color_transform == ColorTransform::kNone) {
229
6
        dcoff[c] = 1024 / (*qe[0].qraw.qtable)[64 * c];
230
6
      }
231
11.5k
      for (size_t i = 0; i < 64; i++) {
232
        // Transpose the matrix, as it will be used on the transposed block.
233
11.3k
        int n = qe[0].qraw.qtable->at(64 + i);
234
11.3k
        int d = qe[0].qraw.qtable->at(64 * c + i);
235
11.3k
        if (n <= 0 || d <= 0 || n >= 65536 || d >= 65536) {
236
6
          return JXL_FAILURE("Invalid JPEG quantization table");
237
6
        }
238
11.3k
        scaled_qtable[64 * c + (i % 8) * 8 + (i / 8)] =
239
11.3k
            (1 << kCFLFixedPointPrecision) * n / d;
240
11.3k
      }
241
183
    }
242
63
  }
243
244
3.07k
  size_t hshift[3] = {cs.HShift(0), cs.HShift(1), cs.HShift(2)};
245
3.07k
  size_t vshift[3] = {cs.VShift(0), cs.VShift(1), cs.VShift(2)};
246
3.07k
  Rect r[3];
247
12.2k
  for (size_t i = 0; i < 3; i++) {
248
9.20k
    r[i] =
249
9.20k
        Rect(block_rect.x0() >> hshift[i], block_rect.y0() >> vshift[i],
250
9.20k
             block_rect.xsize() >> hshift[i], block_rect.ysize() >> vshift[i]);
251
9.20k
    if (!r[i].IsInside({0, 0, dec_state->shared->dc->Plane(i).xsize(),
252
9.20k
                        dec_state->shared->dc->Plane(i).ysize()})) {
253
0
      return JXL_FAILURE("Frame dimensions are too big for the image.");
254
0
    }
255
9.20k
  }
256
257
43.7k
  for (size_t by = 0; by < ysize_blocks; ++by) {
258
41.0k
    get_block->StartRow(by);
259
41.0k
    size_t sby[3] = {by >> vshift[0], by >> vshift[1], by >> vshift[2]};
260
261
41.0k
    const int32_t* JXL_RESTRICT row_quant =
262
41.0k
        block_rect.ConstRow(dec_state->shared->raw_quant_field, by);
263
264
41.0k
    const float* JXL_RESTRICT dc_rows[3] = {
265
41.0k
        r[0].ConstPlaneRow(*dec_state->shared->dc, 0, sby[0]),
266
41.0k
        r[1].ConstPlaneRow(*dec_state->shared->dc, 1, sby[1]),
267
41.0k
        r[2].ConstPlaneRow(*dec_state->shared->dc, 2, sby[2]),
268
41.0k
    };
269
270
41.0k
    const size_t ty = (block_rect.y0() + by) / kColorTileDimInBlocks;
271
41.0k
    AcStrategyRow acs_row = ac_strategy.ConstRow(block_rect, by);
272
273
41.0k
    const int8_t* JXL_RESTRICT row_cmap[3] = {
274
41.0k
        dec_state->shared->cmap.ytox_map.ConstRow(ty),
275
41.0k
        nullptr,
276
41.0k
        dec_state->shared->cmap.ytob_map.ConstRow(ty),
277
41.0k
    };
278
279
41.0k
    float* JXL_RESTRICT idct_row[3];
280
41.0k
    int16_t* JXL_RESTRICT jpeg_row[3];
281
163k
    for (size_t c = 0; c < 3; c++) {
282
122k
      idct_row[c] = render_pipeline_input.GetBuffer(c).second.Row(
283
122k
          render_pipeline_input.GetBuffer(c).first, sby[c] * kBlockDim);
284
122k
      if (jpeg_data) {
285
171
        auto& component = jpeg_data->components[jpeg_c_map[c]];
286
171
        jpeg_row[c] =
287
171
            component.coeffs.data() +
288
171
            (component.width_in_blocks * (r[c].y0() + sby[c]) + r[c].x0()) *
289
171
                kDCTBlockSize;
290
171
      }
291
122k
    }
292
293
41.0k
    size_t bx = 0;
294
101k
    for (size_t tx = 0; tx < DivCeil(xsize_blocks, kColorTileDimInBlocks);
295
60.6k
         tx++) {
296
60.6k
      size_t abs_tx = tx + block_rect.x0() / kColorTileDimInBlocks;
297
60.6k
      auto x_cc_mul = Set(d, color_correlation.YtoXRatio(row_cmap[0][abs_tx]));
298
60.6k
      auto b_cc_mul = Set(d, color_correlation.YtoBRatio(row_cmap[2][abs_tx]));
299
      // Increment bx by llf_x because those iterations would otherwise
300
      // immediately continue (!IsFirstBlock). Reduces mispredictions.
301
223k
      for (; bx < xsize_blocks && bx < (tx + 1) * kColorTileDimInBlocks;) {
302
163k
        size_t sbx[3] = {bx >> hshift[0], bx >> hshift[1], bx >> hshift[2]};
303
163k
        AcStrategy acs = acs_row[bx];
304
163k
        const size_t llf_x = acs.covered_blocks_x();
305
306
        // Can only happen in the second or lower rows of a varblock.
307
163k
        if (JXL_UNLIKELY(!acs.IsFirstBlock())) {
308
18.5k
          bx += llf_x;
309
18.5k
          continue;
310
18.5k
        }
311
145k
        const size_t log2_covered_blocks = acs.log2_covered_blocks();
312
313
145k
        const size_t covered_blocks = 1 << log2_covered_blocks;
314
145k
        const size_t size = covered_blocks * kDCTBlockSize;
315
316
145k
        ACPtr qblock[3];
317
145k
        if (accumulate) {
318
240
          for (size_t c = 0; c < 3; c++) {
319
180
            qblock[c] = dec_state->coefficients->PlaneRow(c, group_idx, offset);
320
180
          }
321
145k
        } else {
322
          // No point in reading from bitstream without accumulating and not
323
          // drawing.
324
145k
          JXL_ASSERT(draw == kDraw);
325
145k
          if (ac_type == ACType::k16) {
326
94.3k
            memset(group_dec_cache->dec_group_qblock16, 0,
327
94.3k
                   size * 3 * sizeof(int16_t));
328
377k
            for (size_t c = 0; c < 3; c++) {
329
283k
              qblock[c].ptr16 = group_dec_cache->dec_group_qblock16 + c * size;
330
283k
            }
331
94.3k
          } else {
332
50.6k
            memset(group_dec_cache->dec_group_qblock, 0,
333
50.6k
                   size * 3 * sizeof(int32_t));
334
203k
            for (size_t c = 0; c < 3; c++) {
335
152k
              qblock[c].ptr32 = group_dec_cache->dec_group_qblock + c * size;
336
152k
            }
337
50.6k
          }
338
145k
        }
339
145k
        JXL_RETURN_IF_ERROR(get_block->LoadBlock(
340
145k
            bx, by, acs, size, log2_covered_blocks, qblock, ac_type));
341
144k
        offset += size;
342
144k
        if (draw == kDontDraw) {
343
54
          bx += llf_x;
344
54
          continue;
345
54
        }
346
347
144k
        if (JXL_UNLIKELY(jpeg_data)) {
348
57
          if (acs.Strategy() != AcStrategy::Type::DCT) {
349
1
            return JXL_FAILURE(
350
1
                "Can only decode to JPEG if only DCT-8 is used.");
351
1
          }
352
353
56
          HWY_ALIGN int32_t transposed_dct_y[64];
354
166
          for (size_t c : {1, 0, 2}) {
355
            // Propagate only Y for grayscale.
356
166
            if (jpeg_is_gray && c != 1) {
357
110
              continue;
358
110
            }
359
56
            if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) {
360
0
              continue;
361
0
            }
362
56
            int16_t* JXL_RESTRICT jpeg_pos =
363
56
                jpeg_row[c] + sbx[c] * kDCTBlockSize;
364
            // JPEG XL is transposed, JPEG is not.
365
56
            auto* transposed_dct = qblock[c].ptr32;
366
56
            Transpose8x8InPlace(transposed_dct);
367
            // No CfL - no need to store the y block converted to integers.
368
56
            if (!cs.Is444() ||
369
56
                (row_cmap[0][abs_tx] == 0 && row_cmap[2][abs_tx] == 0)) {
370
918
              for (size_t i = 0; i < 64; i += Lanes(d)) {
371
864
                const auto ini = Load(di, transposed_dct + i);
372
864
                const auto ini16 = DemoteTo(di16, ini);
373
864
                StoreU(ini16, di16, jpeg_pos + i);
374
864
              }
375
54
            } else if (c == 1) {
376
              // Y channel: save for restoring X/B, but nothing else to do.
377
34
              for (size_t i = 0; i < 64; i += Lanes(d)) {
378
32
                const auto ini = Load(di, transposed_dct + i);
379
32
                Store(ini, di, transposed_dct_y + i);
380
32
                const auto ini16 = DemoteTo(di16, ini);
381
32
                StoreU(ini16, di16, jpeg_pos + i);
382
32
              }
383
2
            } else {
384
              // transposed_dct_y contains the y channel block, transposed.
385
0
              const auto scale =
386
0
                  Set(di, ColorCorrelation::RatioJPEG(row_cmap[c][abs_tx]));
387
0
              const auto round = Set(di, 1 << (kCFLFixedPointPrecision - 1));
388
0
              for (int i = 0; i < 64; i += Lanes(d)) {
389
0
                auto in = Load(di, transposed_dct + i);
390
0
                auto in_y = Load(di, transposed_dct_y + i);
391
0
                auto qt = Load(di, scaled_qtable + c * size + i);
392
0
                auto coeff_scale = ShiftRight<kCFLFixedPointPrecision>(
393
0
                    Add(Mul(qt, scale), round));
394
0
                auto cfl_factor = ShiftRight<kCFLFixedPointPrecision>(
395
0
                    Add(Mul(in_y, coeff_scale), round));
396
0
                StoreU(DemoteTo(di16, Add(in, cfl_factor)), di16, jpeg_pos + i);
397
0
              }
398
0
            }
399
56
            jpeg_pos[0] =
400
56
                Clamp1<float>(dc_rows[c][sbx[c]] - dcoff[c], -2047, 2047);
401
56
            auto overflow = MaskFromVec(Set(di16_full, 0));
402
56
            auto underflow = MaskFromVec(Set(di16_full, 0));
403
504
            for (int i = 0; i < 64; i += Lanes(di16_full)) {
404
448
              auto in = LoadU(di16_full, jpeg_pos + i);
405
448
              overflow = Or(overflow, Gt(in, kJpegDctMax));
406
448
              underflow = Or(underflow, Lt(in, kJpegDctMin));
407
448
            }
408
56
            if (!AllFalse(di16_full, Or(overflow, underflow))) {
409
1
              return JXL_FAILURE("JPEG DCT coefficients out of range");
410
1
            }
411
56
          }
412
144k
        } else {
413
144k
          HWY_ALIGN float* const block = group_dec_cache->dec_group_block;
414
          // Dequantize and add predictions.
415
144k
          dequant_block(
416
144k
              acs, inv_global_scale, row_quant[bx], dec_state->x_dm_multiplier,
417
144k
              dec_state->b_dm_multiplier, x_cc_mul, b_cc_mul, acs.RawStrategy(),
418
144k
              size, dec_state->shared->quantizer,
419
144k
              acs.covered_blocks_y() * acs.covered_blocks_x(), sbx, dc_rows,
420
144k
              dc_stride,
421
144k
              dec_state->output_encoding_info.opsin_params.quant_biases, qblock,
422
144k
              block, group_dec_cache->scratch_space);
423
424
432k
          for (size_t c : {1, 0, 2}) {
425
432k
            if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) {
426
24.8k
              continue;
427
24.8k
            }
428
            // IDCT
429
407k
            float* JXL_RESTRICT idct_pos = idct_row[c] + sbx[c] * kBlockDim;
430
407k
            TransformToPixels(acs.Strategy(), block + c * size, idct_pos,
431
407k
                              idct_stride[c], group_dec_cache->scratch_space);
432
407k
          }
433
144k
        }
434
144k
        bx += llf_x;
435
144k
      }
436
60.6k
    }
437
41.0k
  }
438
2.66k
  return true;
439
3.07k
}
jxl::N_AVX2::DecodeGroupImpl(jxl::FrameHeader const&, jxl::GetBlock*, jxl::GroupDecCache*, jxl::PassesDecoderState*, unsigned long, unsigned long, jxl::RenderPipelineInput&, jxl::jpeg::JPEGData*, jxl::DrawMode)
Line
Count
Source
174
4.23k
                       jpeg::JPEGData* jpeg_data, DrawMode draw) {
175
  // TODO(veluca): investigate cache usage in this function.
176
4.23k
  const Rect block_rect =
177
4.23k
      dec_state->shared->frame_dim.BlockGroupRect(group_idx);
178
4.23k
  const AcStrategyImage& ac_strategy = dec_state->shared->ac_strategy;
179
180
4.23k
  const size_t xsize_blocks = block_rect.xsize();
181
4.23k
  const size_t ysize_blocks = block_rect.ysize();
182
183
4.23k
  const size_t dc_stride = dec_state->shared->dc->PixelsPerRow();
184
185
4.23k
  const float inv_global_scale = dec_state->shared->quantizer.InvGlobalScale();
186
187
4.23k
  const YCbCrChromaSubsampling& cs = frame_header.chroma_subsampling;
188
189
4.23k
  const auto kJpegDctMin = Set(di16_full, -4095);
190
4.23k
  const auto kJpegDctMax = Set(di16_full, 4095);
191
192
4.23k
  size_t idct_stride[3];
193
16.9k
  for (size_t c = 0; c < 3; c++) {
194
12.6k
    idct_stride[c] = render_pipeline_input.GetBuffer(c).first->PixelsPerRow();
195
12.6k
  }
196
197
4.23k
  HWY_ALIGN int32_t scaled_qtable[64 * 3];
198
199
4.23k
  ACType ac_type = dec_state->coefficients->Type();
200
4.23k
  auto dequant_block = ac_type == ACType::k16 ? DequantBlock<ACType::k16>
201
4.23k
                                              : DequantBlock<ACType::k32>;
202
  // Whether or not coefficients should be stored for future usage, and/or read
203
  // from past usage.
204
4.23k
  bool accumulate = !dec_state->coefficients->IsEmpty();
205
  // Offset of the current block in the group.
206
4.23k
  size_t offset = 0;
207
208
4.23k
  std::array<int, 3> jpeg_c_map;
209
4.23k
  bool jpeg_is_gray = false;
210
4.23k
  std::array<int, 3> dcoff = {};
211
212
  // TODO(veluca): all of this should be done only once per image.
213
4.23k
  const ColorCorrelation& color_correlation = dec_state->shared->cmap.base();
214
4.23k
  if (jpeg_data) {
215
215
    if (!color_correlation.IsJPEGCompatible()) {
216
7
      return JXL_FAILURE("The CfL map is not JPEG-compatible");
217
7
    }
218
208
    jpeg_is_gray = (jpeg_data->components.size() == 1);
219
208
    jpeg_c_map = JpegOrder(frame_header.color_transform, jpeg_is_gray);
220
208
    const std::vector<QuantEncoding>& qe =
221
208
        dec_state->shared->matrices.encodings();
222
208
    if (qe.empty() || qe[0].mode != QuantEncoding::Mode::kQuantModeRAW ||
223
208
        std::abs(qe[0].qraw.qtable_den - 1.f / (8 * 255)) > 1e-8f) {
224
0
      return JXL_FAILURE(
225
0
          "Quantization table is not a JPEG quantization table.");
226
0
    }
227
815
    for (size_t c = 0; c < 3; c++) {
228
616
      if (frame_header.color_transform == ColorTransform::kNone) {
229
12
        dcoff[c] = 1024 / (*qe[0].qraw.qtable)[64 * c];
230
12
      }
231
39.6k
      for (size_t i = 0; i < 64; i++) {
232
        // Transpose the matrix, as it will be used on the transposed block.
233
39.0k
        int n = qe[0].qraw.qtable->at(64 + i);
234
39.0k
        int d = qe[0].qraw.qtable->at(64 * c + i);
235
39.0k
        if (n <= 0 || d <= 0 || n >= 65536 || d >= 65536) {
236
9
          return JXL_FAILURE("Invalid JPEG quantization table");
237
9
        }
238
38.9k
        scaled_qtable[64 * c + (i % 8) * 8 + (i / 8)] =
239
38.9k
            (1 << kCFLFixedPointPrecision) * n / d;
240
38.9k
      }
241
616
    }
242
208
  }
243
244
4.21k
  size_t hshift[3] = {cs.HShift(0), cs.HShift(1), cs.HShift(2)};
245
4.21k
  size_t vshift[3] = {cs.VShift(0), cs.VShift(1), cs.VShift(2)};
246
4.21k
  Rect r[3];
247
16.8k
  for (size_t i = 0; i < 3; i++) {
248
12.5k
    r[i] =
249
12.5k
        Rect(block_rect.x0() >> hshift[i], block_rect.y0() >> vshift[i],
250
12.5k
             block_rect.xsize() >> hshift[i], block_rect.ysize() >> vshift[i]);
251
12.5k
    if (!r[i].IsInside({0, 0, dec_state->shared->dc->Plane(i).xsize(),
252
12.5k
                        dec_state->shared->dc->Plane(i).ysize()})) {
253
0
      return JXL_FAILURE("Frame dimensions are too big for the image.");
254
0
    }
255
12.5k
  }
256
257
63.5k
  for (size_t by = 0; by < ysize_blocks; ++by) {
258
59.8k
    get_block->StartRow(by);
259
59.8k
    size_t sby[3] = {by >> vshift[0], by >> vshift[1], by >> vshift[2]};
260
261
59.8k
    const int32_t* JXL_RESTRICT row_quant =
262
59.8k
        block_rect.ConstRow(dec_state->shared->raw_quant_field, by);
263
264
59.8k
    const float* JXL_RESTRICT dc_rows[3] = {
265
59.8k
        r[0].ConstPlaneRow(*dec_state->shared->dc, 0, sby[0]),
266
59.8k
        r[1].ConstPlaneRow(*dec_state->shared->dc, 1, sby[1]),
267
59.8k
        r[2].ConstPlaneRow(*dec_state->shared->dc, 2, sby[2]),
268
59.8k
    };
269
270
59.8k
    const size_t ty = (block_rect.y0() + by) / kColorTileDimInBlocks;
271
59.8k
    AcStrategyRow acs_row = ac_strategy.ConstRow(block_rect, by);
272
273
59.8k
    const int8_t* JXL_RESTRICT row_cmap[3] = {
274
59.8k
        dec_state->shared->cmap.ytox_map.ConstRow(ty),
275
59.8k
        nullptr,
276
59.8k
        dec_state->shared->cmap.ytob_map.ConstRow(ty),
277
59.8k
    };
278
279
59.8k
    float* JXL_RESTRICT idct_row[3];
280
59.8k
    int16_t* JXL_RESTRICT jpeg_row[3];
281
238k
    for (size_t c = 0; c < 3; c++) {
282
179k
      idct_row[c] = render_pipeline_input.GetBuffer(c).second.Row(
283
179k
          render_pipeline_input.GetBuffer(c).first, sby[c] * kBlockDim);
284
179k
      if (jpeg_data) {
285
597
        auto& component = jpeg_data->components[jpeg_c_map[c]];
286
597
        jpeg_row[c] =
287
597
            component.coeffs.data() +
288
597
            (component.width_in_blocks * (r[c].y0() + sby[c]) + r[c].x0()) *
289
597
                kDCTBlockSize;
290
597
      }
291
179k
    }
292
293
59.8k
    size_t bx = 0;
294
142k
    for (size_t tx = 0; tx < DivCeil(xsize_blocks, kColorTileDimInBlocks);
295
83.6k
         tx++) {
296
83.6k
      size_t abs_tx = tx + block_rect.x0() / kColorTileDimInBlocks;
297
83.6k
      auto x_cc_mul = Set(d, color_correlation.YtoXRatio(row_cmap[0][abs_tx]));
298
83.6k
      auto b_cc_mul = Set(d, color_correlation.YtoBRatio(row_cmap[2][abs_tx]));
299
      // Increment bx by llf_x because those iterations would otherwise
300
      // immediately continue (!IsFirstBlock). Reduces mispredictions.
301
311k
      for (; bx < xsize_blocks && bx < (tx + 1) * kColorTileDimInBlocks;) {
302
227k
        size_t sbx[3] = {bx >> hshift[0], bx >> hshift[1], bx >> hshift[2]};
303
227k
        AcStrategy acs = acs_row[bx];
304
227k
        const size_t llf_x = acs.covered_blocks_x();
305
306
        // Can only happen in the second or lower rows of a varblock.
307
227k
        if (JXL_UNLIKELY(!acs.IsFirstBlock())) {
308
27.1k
          bx += llf_x;
309
27.1k
          continue;
310
27.1k
        }
311
200k
        const size_t log2_covered_blocks = acs.log2_covered_blocks();
312
313
200k
        const size_t covered_blocks = 1 << log2_covered_blocks;
314
200k
        const size_t size = covered_blocks * kDCTBlockSize;
315
316
200k
        ACPtr qblock[3];
317
200k
        if (accumulate) {
318
200
          for (size_t c = 0; c < 3; c++) {
319
150
            qblock[c] = dec_state->coefficients->PlaneRow(c, group_idx, offset);
320
150
          }
321
200k
        } else {
322
          // No point in reading from bitstream without accumulating and not
323
          // drawing.
324
200k
          JXL_ASSERT(draw == kDraw);
325
200k
          if (ac_type == ACType::k16) {
326
122k
            memset(group_dec_cache->dec_group_qblock16, 0,
327
122k
                   size * 3 * sizeof(int16_t));
328
490k
            for (size_t c = 0; c < 3; c++) {
329
368k
              qblock[c].ptr16 = group_dec_cache->dec_group_qblock16 + c * size;
330
368k
            }
331
122k
          } else {
332
78.0k
            memset(group_dec_cache->dec_group_qblock, 0,
333
78.0k
                   size * 3 * sizeof(int32_t));
334
312k
            for (size_t c = 0; c < 3; c++) {
335
234k
              qblock[c].ptr32 = group_dec_cache->dec_group_qblock + c * size;
336
234k
            }
337
78.0k
          }
338
200k
        }
339
200k
        JXL_RETURN_IF_ERROR(get_block->LoadBlock(
340
200k
            bx, by, acs, size, log2_covered_blocks, qblock, ac_type));
341
200k
        offset += size;
342
200k
        if (draw == kDontDraw) {
343
43
          bx += llf_x;
344
43
          continue;
345
43
        }
346
347
200k
        if (JXL_UNLIKELY(jpeg_data)) {
348
199
          if (acs.Strategy() != AcStrategy::Type::DCT) {
349
1
            return JXL_FAILURE(
350
1
                "Can only decode to JPEG if only DCT-8 is used.");
351
1
          }
352
353
198
          HWY_ALIGN int32_t transposed_dct_y[64];
354
592
          for (size_t c : {1, 0, 2}) {
355
            // Propagate only Y for grayscale.
356
592
            if (jpeg_is_gray && c != 1) {
357
394
              continue;
358
394
            }
359
198
            if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) {
360
0
              continue;
361
0
            }
362
198
            int16_t* JXL_RESTRICT jpeg_pos =
363
198
                jpeg_row[c] + sbx[c] * kDCTBlockSize;
364
            // JPEG XL is transposed, JPEG is not.
365
198
            auto* transposed_dct = qblock[c].ptr32;
366
198
            Transpose8x8InPlace(transposed_dct);
367
            // No CfL - no need to store the y block converted to integers.
368
198
            if (!cs.Is444() ||
369
198
                (row_cmap[0][abs_tx] == 0 && row_cmap[2][abs_tx] == 0)) {
370
1.66k
              for (size_t i = 0; i < 64; i += Lanes(d)) {
371
1.48k
                const auto ini = Load(di, transposed_dct + i);
372
1.48k
                const auto ini16 = DemoteTo(di16, ini);
373
1.48k
                StoreU(ini16, di16, jpeg_pos + i);
374
1.48k
              }
375
185
            } else if (c == 1) {
376
              // Y channel: save for restoring X/B, but nothing else to do.
377
117
              for (size_t i = 0; i < 64; i += Lanes(d)) {
378
104
                const auto ini = Load(di, transposed_dct + i);
379
104
                Store(ini, di, transposed_dct_y + i);
380
104
                const auto ini16 = DemoteTo(di16, ini);
381
104
                StoreU(ini16, di16, jpeg_pos + i);
382
104
              }
383
13
            } else {
384
              // transposed_dct_y contains the y channel block, transposed.
385
0
              const auto scale =
386
0
                  Set(di, ColorCorrelation::RatioJPEG(row_cmap[c][abs_tx]));
387
0
              const auto round = Set(di, 1 << (kCFLFixedPointPrecision - 1));
388
0
              for (int i = 0; i < 64; i += Lanes(d)) {
389
0
                auto in = Load(di, transposed_dct + i);
390
0
                auto in_y = Load(di, transposed_dct_y + i);
391
0
                auto qt = Load(di, scaled_qtable + c * size + i);
392
0
                auto coeff_scale = ShiftRight<kCFLFixedPointPrecision>(
393
0
                    Add(Mul(qt, scale), round));
394
0
                auto cfl_factor = ShiftRight<kCFLFixedPointPrecision>(
395
0
                    Add(Mul(in_y, coeff_scale), round));
396
0
                StoreU(DemoteTo(di16, Add(in, cfl_factor)), di16, jpeg_pos + i);
397
0
              }
398
0
            }
399
198
            jpeg_pos[0] =
400
198
                Clamp1<float>(dc_rows[c][sbx[c]] - dcoff[c], -2047, 2047);
401
198
            auto overflow = MaskFromVec(Set(di16_full, 0));
402
198
            auto underflow = MaskFromVec(Set(di16_full, 0));
403
990
            for (int i = 0; i < 64; i += Lanes(di16_full)) {
404
792
              auto in = LoadU(di16_full, jpeg_pos + i);
405
792
              overflow = Or(overflow, Gt(in, kJpegDctMax));
406
792
              underflow = Or(underflow, Lt(in, kJpegDctMin));
407
792
            }
408
198
            if (!AllFalse(di16_full, Or(overflow, underflow))) {
409
1
              return JXL_FAILURE("JPEG DCT coefficients out of range");
410
1
            }
411
198
          }
412
200k
        } else {
413
200k
          HWY_ALIGN float* const block = group_dec_cache->dec_group_block;
414
          // Dequantize and add predictions.
415
200k
          dequant_block(
416
200k
              acs, inv_global_scale, row_quant[bx], dec_state->x_dm_multiplier,
417
200k
              dec_state->b_dm_multiplier, x_cc_mul, b_cc_mul, acs.RawStrategy(),
418
200k
              size, dec_state->shared->quantizer,
419
200k
              acs.covered_blocks_y() * acs.covered_blocks_x(), sbx, dc_rows,
420
200k
              dc_stride,
421
200k
              dec_state->output_encoding_info.opsin_params.quant_biases, qblock,
422
200k
              block, group_dec_cache->scratch_space);
423
424
599k
          for (size_t c : {1, 0, 2}) {
425
599k
            if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) {
426
36.2k
              continue;
427
36.2k
            }
428
            // IDCT
429
562k
            float* JXL_RESTRICT idct_pos = idct_row[c] + sbx[c] * kBlockDim;
430
562k
            TransformToPixels(acs.Strategy(), block + c * size, idct_pos,
431
562k
                              idct_stride[c], group_dec_cache->scratch_space);
432
562k
          }
433
200k
        }
434
200k
        bx += llf_x;
435
200k
      }
436
83.6k
    }
437
59.8k
  }
438
3.70k
  return true;
439
4.21k
}
jxl::N_SSE2::DecodeGroupImpl(jxl::FrameHeader const&, jxl::GetBlock*, jxl::GroupDecCache*, jxl::PassesDecoderState*, unsigned long, unsigned long, jxl::RenderPipelineInput&, jxl::jpeg::JPEGData*, jxl::DrawMode)
Line
Count
Source
174
3.44k
                       jpeg::JPEGData* jpeg_data, DrawMode draw) {
175
  // TODO(veluca): investigate cache usage in this function.
176
3.44k
  const Rect block_rect =
177
3.44k
      dec_state->shared->frame_dim.BlockGroupRect(group_idx);
178
3.44k
  const AcStrategyImage& ac_strategy = dec_state->shared->ac_strategy;
179
180
3.44k
  const size_t xsize_blocks = block_rect.xsize();
181
3.44k
  const size_t ysize_blocks = block_rect.ysize();
182
183
3.44k
  const size_t dc_stride = dec_state->shared->dc->PixelsPerRow();
184
185
3.44k
  const float inv_global_scale = dec_state->shared->quantizer.InvGlobalScale();
186
187
3.44k
  const YCbCrChromaSubsampling& cs = frame_header.chroma_subsampling;
188
189
3.44k
  const auto kJpegDctMin = Set(di16_full, -4095);
190
3.44k
  const auto kJpegDctMax = Set(di16_full, 4095);
191
192
3.44k
  size_t idct_stride[3];
193
13.7k
  for (size_t c = 0; c < 3; c++) {
194
10.3k
    idct_stride[c] = render_pipeline_input.GetBuffer(c).first->PixelsPerRow();
195
10.3k
  }
196
197
3.44k
  HWY_ALIGN int32_t scaled_qtable[64 * 3];
198
199
3.44k
  ACType ac_type = dec_state->coefficients->Type();
200
3.44k
  auto dequant_block = ac_type == ACType::k16 ? DequantBlock<ACType::k16>
201
3.44k
                                              : DequantBlock<ACType::k32>;
202
  // Whether or not coefficients should be stored for future usage, and/or read
203
  // from past usage.
204
3.44k
  bool accumulate = !dec_state->coefficients->IsEmpty();
205
  // Offset of the current block in the group.
206
3.44k
  size_t offset = 0;
207
208
3.44k
  std::array<int, 3> jpeg_c_map;
209
3.44k
  bool jpeg_is_gray = false;
210
3.44k
  std::array<int, 3> dcoff = {};
211
212
  // TODO(veluca): all of this should be done only once per image.
213
3.44k
  const ColorCorrelation& color_correlation = dec_state->shared->cmap.base();
214
3.44k
  if (jpeg_data) {
215
29
    if (!color_correlation.IsJPEGCompatible()) {
216
5
      return JXL_FAILURE("The CfL map is not JPEG-compatible");
217
5
    }
218
24
    jpeg_is_gray = (jpeg_data->components.size() == 1);
219
24
    jpeg_c_map = JpegOrder(frame_header.color_transform, jpeg_is_gray);
220
24
    const std::vector<QuantEncoding>& qe =
221
24
        dec_state->shared->matrices.encodings();
222
24
    if (qe.empty() || qe[0].mode != QuantEncoding::Mode::kQuantModeRAW ||
223
24
        std::abs(qe[0].qraw.qtable_den - 1.f / (8 * 255)) > 1e-8f) {
224
0
      return JXL_FAILURE(
225
0
          "Quantization table is not a JPEG quantization table.");
226
0
    }
227
87
    for (size_t c = 0; c < 3; c++) {
228
68
      if (frame_header.color_transform == ColorTransform::kNone) {
229
6
        dcoff[c] = 1024 / (*qe[0].qraw.qtable)[64 * c];
230
6
      }
231
4.15k
      for (size_t i = 0; i < 64; i++) {
232
        // Transpose the matrix, as it will be used on the transposed block.
233
4.09k
        int n = qe[0].qraw.qtable->at(64 + i);
234
4.09k
        int d = qe[0].qraw.qtable->at(64 * c + i);
235
4.09k
        if (n <= 0 || d <= 0 || n >= 65536 || d >= 65536) {
236
5
          return JXL_FAILURE("Invalid JPEG quantization table");
237
5
        }
238
4.09k
        scaled_qtable[64 * c + (i % 8) * 8 + (i / 8)] =
239
4.09k
            (1 << kCFLFixedPointPrecision) * n / d;
240
4.09k
      }
241
68
    }
242
24
  }
243
244
3.43k
  size_t hshift[3] = {cs.HShift(0), cs.HShift(1), cs.HShift(2)};
245
3.43k
  size_t vshift[3] = {cs.VShift(0), cs.VShift(1), cs.VShift(2)};
246
3.43k
  Rect r[3];
247
13.7k
  for (size_t i = 0; i < 3; i++) {
248
10.2k
    r[i] =
249
10.2k
        Rect(block_rect.x0() >> hshift[i], block_rect.y0() >> vshift[i],
250
10.2k
             block_rect.xsize() >> hshift[i], block_rect.ysize() >> vshift[i]);
251
10.2k
    if (!r[i].IsInside({0, 0, dec_state->shared->dc->Plane(i).xsize(),
252
10.2k
                        dec_state->shared->dc->Plane(i).ysize()})) {
253
0
      return JXL_FAILURE("Frame dimensions are too big for the image.");
254
0
    }
255
10.2k
  }
256
257
65.4k
  for (size_t by = 0; by < ysize_blocks; ++by) {
258
62.3k
    get_block->StartRow(by);
259
62.3k
    size_t sby[3] = {by >> vshift[0], by >> vshift[1], by >> vshift[2]};
260
261
62.3k
    const int32_t* JXL_RESTRICT row_quant =
262
62.3k
        block_rect.ConstRow(dec_state->shared->raw_quant_field, by);
263
264
62.3k
    const float* JXL_RESTRICT dc_rows[3] = {
265
62.3k
        r[0].ConstPlaneRow(*dec_state->shared->dc, 0, sby[0]),
266
62.3k
        r[1].ConstPlaneRow(*dec_state->shared->dc, 1, sby[1]),
267
62.3k
        r[2].ConstPlaneRow(*dec_state->shared->dc, 2, sby[2]),
268
62.3k
    };
269
270
62.3k
    const size_t ty = (block_rect.y0() + by) / kColorTileDimInBlocks;
271
62.3k
    AcStrategyRow acs_row = ac_strategy.ConstRow(block_rect, by);
272
273
62.3k
    const int8_t* JXL_RESTRICT row_cmap[3] = {
274
62.3k
        dec_state->shared->cmap.ytox_map.ConstRow(ty),
275
62.3k
        nullptr,
276
62.3k
        dec_state->shared->cmap.ytob_map.ConstRow(ty),
277
62.3k
    };
278
279
62.3k
    float* JXL_RESTRICT idct_row[3];
280
62.3k
    int16_t* JXL_RESTRICT jpeg_row[3];
281
249k
    for (size_t c = 0; c < 3; c++) {
282
186k
      idct_row[c] = render_pipeline_input.GetBuffer(c).second.Row(
283
186k
          render_pipeline_input.GetBuffer(c).first, sby[c] * kBlockDim);
284
186k
      if (jpeg_data) {
285
57
        auto& component = jpeg_data->components[jpeg_c_map[c]];
286
57
        jpeg_row[c] =
287
57
            component.coeffs.data() +
288
57
            (component.width_in_blocks * (r[c].y0() + sby[c]) + r[c].x0()) *
289
57
                kDCTBlockSize;
290
57
      }
291
186k
    }
292
293
62.3k
    size_t bx = 0;
294
142k
    for (size_t tx = 0; tx < DivCeil(xsize_blocks, kColorTileDimInBlocks);
295
80.0k
         tx++) {
296
80.0k
      size_t abs_tx = tx + block_rect.x0() / kColorTileDimInBlocks;
297
80.0k
      auto x_cc_mul = Set(d, color_correlation.YtoXRatio(row_cmap[0][abs_tx]));
298
80.0k
      auto b_cc_mul = Set(d, color_correlation.YtoBRatio(row_cmap[2][abs_tx]));
299
      // Increment bx by llf_x because those iterations would otherwise
300
      // immediately continue (!IsFirstBlock). Reduces mispredictions.
301
285k
      for (; bx < xsize_blocks && bx < (tx + 1) * kColorTileDimInBlocks;) {
302
205k
        size_t sbx[3] = {bx >> hshift[0], bx >> hshift[1], bx >> hshift[2]};
303
205k
        AcStrategy acs = acs_row[bx];
304
205k
        const size_t llf_x = acs.covered_blocks_x();
305
306
        // Can only happen in the second or lower rows of a varblock.
307
205k
        if (JXL_UNLIKELY(!acs.IsFirstBlock())) {
308
28.4k
          bx += llf_x;
309
28.4k
          continue;
310
28.4k
        }
311
177k
        const size_t log2_covered_blocks = acs.log2_covered_blocks();
312
313
177k
        const size_t covered_blocks = 1 << log2_covered_blocks;
314
177k
        const size_t size = covered_blocks * kDCTBlockSize;
315
316
177k
        ACPtr qblock[3];
317
177k
        if (accumulate) {
318
88
          for (size_t c = 0; c < 3; c++) {
319
66
            qblock[c] = dec_state->coefficients->PlaneRow(c, group_idx, offset);
320
66
          }
321
177k
        } else {
322
          // No point in reading from bitstream without accumulating and not
323
          // drawing.
324
177k
          JXL_ASSERT(draw == kDraw);
325
177k
          if (ac_type == ACType::k16) {
326
80.0k
            memset(group_dec_cache->dec_group_qblock16, 0,
327
80.0k
                   size * 3 * sizeof(int16_t));
328
320k
            for (size_t c = 0; c < 3; c++) {
329
240k
              qblock[c].ptr16 = group_dec_cache->dec_group_qblock16 + c * size;
330
240k
            }
331
96.9k
          } else {
332
96.9k
            memset(group_dec_cache->dec_group_qblock, 0,
333
96.9k
                   size * 3 * sizeof(int32_t));
334
388k
            for (size_t c = 0; c < 3; c++) {
335
291k
              qblock[c].ptr32 = group_dec_cache->dec_group_qblock + c * size;
336
291k
            }
337
96.9k
          }
338
177k
        }
339
177k
        JXL_RETURN_IF_ERROR(get_block->LoadBlock(
340
177k
            bx, by, acs, size, log2_covered_blocks, qblock, ac_type));
341
176k
        offset += size;
342
176k
        if (draw == kDontDraw) {
343
20
          bx += llf_x;
344
20
          continue;
345
20
        }
346
347
176k
        if (JXL_UNLIKELY(jpeg_data)) {
348
19
          if (acs.Strategy() != AcStrategy::Type::DCT) {
349
2
            return JXL_FAILURE(
350
2
                "Can only decode to JPEG if only DCT-8 is used.");
351
2
          }
352
353
17
          HWY_ALIGN int32_t transposed_dct_y[64];
354
49
          for (size_t c : {1, 0, 2}) {
355
            // Propagate only Y for grayscale.
356
49
            if (jpeg_is_gray && c != 1) {
357
32
              continue;
358
32
            }
359
17
            if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) {
360
0
              continue;
361
0
            }
362
17
            int16_t* JXL_RESTRICT jpeg_pos =
363
17
                jpeg_row[c] + sbx[c] * kDCTBlockSize;
364
            // JPEG XL is transposed, JPEG is not.
365
17
            auto* transposed_dct = qblock[c].ptr32;
366
17
            Transpose8x8InPlace(transposed_dct);
367
            // No CfL - no need to store the y block converted to integers.
368
17
            if (!cs.Is444() ||
369
17
                (row_cmap[0][abs_tx] == 0 && row_cmap[2][abs_tx] == 0)) {
370
221
              for (size_t i = 0; i < 64; i += Lanes(d)) {
371
208
                const auto ini = Load(di, transposed_dct + i);
372
208
                const auto ini16 = DemoteTo(di16, ini);
373
208
                StoreU(ini16, di16, jpeg_pos + i);
374
208
              }
375
13
            } else if (c == 1) {
376
              // Y channel: save for restoring X/B, but nothing else to do.
377
68
              for (size_t i = 0; i < 64; i += Lanes(d)) {
378
64
                const auto ini = Load(di, transposed_dct + i);
379
64
                Store(ini, di, transposed_dct_y + i);
380
64
                const auto ini16 = DemoteTo(di16, ini);
381
64
                StoreU(ini16, di16, jpeg_pos + i);
382
64
              }
383
4
            } else {
384
              // transposed_dct_y contains the y channel block, transposed.
385
0
              const auto scale =
386
0
                  Set(di, ColorCorrelation::RatioJPEG(row_cmap[c][abs_tx]));
387
0
              const auto round = Set(di, 1 << (kCFLFixedPointPrecision - 1));
388
0
              for (int i = 0; i < 64; i += Lanes(d)) {
389
0
                auto in = Load(di, transposed_dct + i);
390
0
                auto in_y = Load(di, transposed_dct_y + i);
391
0
                auto qt = Load(di, scaled_qtable + c * size + i);
392
0
                auto coeff_scale = ShiftRight<kCFLFixedPointPrecision>(
393
0
                    Add(Mul(qt, scale), round));
394
0
                auto cfl_factor = ShiftRight<kCFLFixedPointPrecision>(
395
0
                    Add(Mul(in_y, coeff_scale), round));
396
0
                StoreU(DemoteTo(di16, Add(in, cfl_factor)), di16, jpeg_pos + i);
397
0
              }
398
0
            }
399
17
            jpeg_pos[0] =
400
17
                Clamp1<float>(dc_rows[c][sbx[c]] - dcoff[c], -2047, 2047);
401
17
            auto overflow = MaskFromVec(Set(di16_full, 0));
402
17
            auto underflow = MaskFromVec(Set(di16_full, 0));
403
153
            for (int i = 0; i < 64; i += Lanes(di16_full)) {
404
136
              auto in = LoadU(di16_full, jpeg_pos + i);
405
136
              overflow = Or(overflow, Gt(in, kJpegDctMax));
406
136
              underflow = Or(underflow, Lt(in, kJpegDctMin));
407
136
            }
408
17
            if (!AllFalse(di16_full, Or(overflow, underflow))) {
409
1
              return JXL_FAILURE("JPEG DCT coefficients out of range");
410
1
            }
411
17
          }
412
176k
        } else {
413
176k
          HWY_ALIGN float* const block = group_dec_cache->dec_group_block;
414
          // Dequantize and add predictions.
415
176k
          dequant_block(
416
176k
              acs, inv_global_scale, row_quant[bx], dec_state->x_dm_multiplier,
417
176k
              dec_state->b_dm_multiplier, x_cc_mul, b_cc_mul, acs.RawStrategy(),
418
176k
              size, dec_state->shared->quantizer,
419
176k
              acs.covered_blocks_y() * acs.covered_blocks_x(), sbx, dc_rows,
420
176k
              dc_stride,
421
176k
              dec_state->output_encoding_info.opsin_params.quant_biases, qblock,
422
176k
              block, group_dec_cache->scratch_space);
423
424
529k
          for (size_t c : {1, 0, 2}) {
425
529k
            if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) {
426
117k
              continue;
427
117k
            }
428
            // IDCT
429
412k
            float* JXL_RESTRICT idct_pos = idct_row[c] + sbx[c] * kBlockDim;
430
412k
            TransformToPixels(acs.Strategy(), block + c * size, idct_pos,
431
412k
                              idct_stride[c], group_dec_cache->scratch_space);
432
412k
          }
433
176k
        }
434
176k
        bx += llf_x;
435
176k
      }
436
80.0k
    }
437
62.3k
  }
438
3.02k
  return true;
439
3.43k
}
440
441
// NOLINTNEXTLINE(google-readability-namespace-comments)
442
}  // namespace HWY_NAMESPACE
443
}  // namespace jxl
444
HWY_AFTER_NAMESPACE();
445
446
#if HWY_ONCE
447
namespace jxl {
448
namespace {
449
// Decode quantized AC coefficients of DCT blocks.
450
// LLF components in the output block will not be modified.
451
template <ACType ac_type, bool uses_lz77>
452
Status DecodeACVarBlock(size_t ctx_offset, size_t log2_covered_blocks,
453
                        int32_t* JXL_RESTRICT row_nzeros,
454
                        const int32_t* JXL_RESTRICT row_nzeros_top,
455
                        size_t nzeros_stride, size_t c, size_t bx, size_t by,
456
                        size_t lbx, AcStrategy acs,
457
                        const coeff_order_t* JXL_RESTRICT coeff_order,
458
                        BitReader* JXL_RESTRICT br,
459
                        ANSSymbolReader* JXL_RESTRICT decoder,
460
                        const std::vector<uint8_t>& context_map,
461
                        const uint8_t* qdc_row, const int32_t* qf_row,
462
                        const BlockCtxMap& block_ctx_map, ACPtr block,
463
1.39M
                        size_t shift = 0) {
464
  // Equal to number of LLF coefficients.
465
1.39M
  const size_t covered_blocks = 1 << log2_covered_blocks;
466
1.39M
  const size_t size = covered_blocks * kDCTBlockSize;
467
1.39M
  int32_t predicted_nzeros =
468
1.39M
      PredictFromTopAndLeft(row_nzeros_top, row_nzeros, bx, 32);
469
470
1.39M
  size_t ord = kStrategyOrder[acs.RawStrategy()];
471
1.39M
  const coeff_order_t* JXL_RESTRICT order =
472
1.39M
      &coeff_order[CoeffOrderOffset(ord, c)];
473
474
1.39M
  size_t block_ctx = block_ctx_map.Context(qdc_row[lbx], qf_row[bx], ord, c);
475
1.39M
  const int32_t nzero_ctx =
476
1.39M
      block_ctx_map.NonZeroContext(predicted_nzeros, block_ctx) + ctx_offset;
477
478
1.39M
  size_t nzeros =
479
1.39M
      decoder->ReadHybridUintInlined<uses_lz77>(nzero_ctx, br, context_map);
480
1.39M
  if (nzeros > size - covered_blocks) {
481
694
    return JXL_FAILURE("Invalid AC: nzeros %" PRIuS " too large for %" PRIuS
482
694
                       " 8x8 blocks",
483
694
                       nzeros, covered_blocks);
484
694
  }
485
3.00M
  for (size_t y = 0; y < acs.covered_blocks_y(); y++) {
486
3.98M
    for (size_t x = 0; x < acs.covered_blocks_x(); x++) {
487
2.37M
      row_nzeros[bx + x + y * nzeros_stride] =
488
2.37M
          (nzeros + covered_blocks - 1) >> log2_covered_blocks;
489
2.37M
    }
490
1.60M
  }
491
492
1.39M
  const size_t histo_offset =
493
1.39M
      ctx_offset + block_ctx_map.ZeroDensityContextsOffset(block_ctx);
494
495
1.39M
  size_t prev = (nzeros > size / 16 ? 0 : 1);
496
23.1M
  for (size_t k = covered_blocks; k < size && nzeros != 0; ++k) {
497
21.7M
    const size_t ctx =
498
21.7M
        histo_offset + ZeroDensityContext(nzeros, k, covered_blocks,
499
21.7M
                                          log2_covered_blocks, prev);
500
21.7M
    const size_t u_coeff =
501
21.7M
        decoder->ReadHybridUintInlined<uses_lz77>(ctx, br, context_map);
502
    // Hand-rolled version of UnpackSigned, shifting before the conversion to
503
    // signed integer to avoid undefined behavior of shifting negative numbers.
504
21.7M
    const size_t magnitude = u_coeff >> 1;
505
21.7M
    const size_t neg_sign = (~u_coeff) & 1;
506
21.7M
    const intptr_t coeff =
507
21.7M
        static_cast<intptr_t>((magnitude ^ (neg_sign - 1)) << shift);
508
21.7M
    if (ac_type == ACType::k16) {
509
13.5M
      block.ptr16[order[k]] += coeff;
510
13.5M
    } else {
511
8.16M
      block.ptr32[order[k]] += coeff;
512
8.16M
    }
513
21.7M
    prev = static_cast<size_t>(u_coeff != 0);
514
21.7M
    nzeros -= prev;
515
21.7M
  }
516
1.39M
  if (JXL_UNLIKELY(nzeros != 0)) {
517
643
    return JXL_FAILURE("Invalid AC: nzeros at end of block is %" PRIuS
518
643
                       ", should be 0. Block (%" PRIuS ", %" PRIuS
519
643
                       "), channel %" PRIuS,
520
643
                       nzeros, bx, by, c);
521
643
  }
522
523
1.39M
  return true;
524
1.39M
}
dec_group.cc:jxl::Status jxl::(anonymous namespace)::DecodeACVarBlock<(jxl::ACType)0, true>(unsigned long, unsigned long, int*, int const*, unsigned long, unsigned long, unsigned long, unsigned long, unsigned long, jxl::AcStrategy, unsigned int const*, jxl::BitReader*, jxl::ANSSymbolReader*, std::__1::vector<unsigned char, std::__1::allocator<unsigned char> > const&, unsigned char const*, int const*, jxl::BlockCtxMap const&, jxl::ACPtr, unsigned long)
Line
Count
Source
463
210k
                        size_t shift = 0) {
464
  // Equal to number of LLF coefficients.
465
210k
  const size_t covered_blocks = 1 << log2_covered_blocks;
466
210k
  const size_t size = covered_blocks * kDCTBlockSize;
467
210k
  int32_t predicted_nzeros =
468
210k
      PredictFromTopAndLeft(row_nzeros_top, row_nzeros, bx, 32);
469
470
210k
  size_t ord = kStrategyOrder[acs.RawStrategy()];
471
210k
  const coeff_order_t* JXL_RESTRICT order =
472
210k
      &coeff_order[CoeffOrderOffset(ord, c)];
473
474
210k
  size_t block_ctx = block_ctx_map.Context(qdc_row[lbx], qf_row[bx], ord, c);
475
210k
  const int32_t nzero_ctx =
476
210k
      block_ctx_map.NonZeroContext(predicted_nzeros, block_ctx) + ctx_offset;
477
478
210k
  size_t nzeros =
479
210k
      decoder->ReadHybridUintInlined<uses_lz77>(nzero_ctx, br, context_map);
480
210k
  if (nzeros > size - covered_blocks) {
481
138
    return JXL_FAILURE("Invalid AC: nzeros %" PRIuS " too large for %" PRIuS
482
138
                       " 8x8 blocks",
483
138
                       nzeros, covered_blocks);
484
138
  }
485
425k
  for (size_t y = 0; y < acs.covered_blocks_y(); y++) {
486
449k
    for (size_t x = 0; x < acs.covered_blocks_x(); x++) {
487
234k
      row_nzeros[bx + x + y * nzeros_stride] =
488
234k
          (nzeros + covered_blocks - 1) >> log2_covered_blocks;
489
234k
    }
490
215k
  }
491
492
210k
  const size_t histo_offset =
493
210k
      ctx_offset + block_ctx_map.ZeroDensityContextsOffset(block_ctx);
494
495
210k
  size_t prev = (nzeros > size / 16 ? 0 : 1);
496
555k
  for (size_t k = covered_blocks; k < size && nzeros != 0; ++k) {
497
345k
    const size_t ctx =
498
345k
        histo_offset + ZeroDensityContext(nzeros, k, covered_blocks,
499
345k
                                          log2_covered_blocks, prev);
500
345k
    const size_t u_coeff =
501
345k
        decoder->ReadHybridUintInlined<uses_lz77>(ctx, br, context_map);
502
    // Hand-rolled version of UnpackSigned, shifting before the conversion to
503
    // signed integer to avoid undefined behavior of shifting negative numbers.
504
345k
    const size_t magnitude = u_coeff >> 1;
505
345k
    const size_t neg_sign = (~u_coeff) & 1;
506
345k
    const intptr_t coeff =
507
345k
        static_cast<intptr_t>((magnitude ^ (neg_sign - 1)) << shift);
508
345k
    if (ac_type == ACType::k16) {
509
345k
      block.ptr16[order[k]] += coeff;
510
18.4E
    } else {
511
18.4E
      block.ptr32[order[k]] += coeff;
512
18.4E
    }
513
345k
    prev = static_cast<size_t>(u_coeff != 0);
514
345k
    nzeros -= prev;
515
345k
  }
516
210k
  if (JXL_UNLIKELY(nzeros != 0)) {
517
74
    return JXL_FAILURE("Invalid AC: nzeros at end of block is %" PRIuS
518
74
                       ", should be 0. Block (%" PRIuS ", %" PRIuS
519
74
                       "), channel %" PRIuS,
520
74
                       nzeros, bx, by, c);
521
74
  }
522
523
210k
  return true;
524
210k
}
dec_group.cc:jxl::Status jxl::(anonymous namespace)::DecodeACVarBlock<(jxl::ACType)1, true>(unsigned long, unsigned long, int*, int const*, unsigned long, unsigned long, unsigned long, unsigned long, unsigned long, jxl::AcStrategy, unsigned int const*, jxl::BitReader*, jxl::ANSSymbolReader*, std::__1::vector<unsigned char, std::__1::allocator<unsigned char> > const&, unsigned char const*, int const*, jxl::BlockCtxMap const&, jxl::ACPtr, unsigned long)
Line
Count
Source
463
58.0k
                        size_t shift = 0) {
464
  // Equal to number of LLF coefficients.
465
58.0k
  const size_t covered_blocks = 1 << log2_covered_blocks;
466
58.0k
  const size_t size = covered_blocks * kDCTBlockSize;
467
58.0k
  int32_t predicted_nzeros =
468
58.0k
      PredictFromTopAndLeft(row_nzeros_top, row_nzeros, bx, 32);
469
470
58.0k
  size_t ord = kStrategyOrder[acs.RawStrategy()];
471
58.0k
  const coeff_order_t* JXL_RESTRICT order =
472
58.0k
      &coeff_order[CoeffOrderOffset(ord, c)];
473
474
58.0k
  size_t block_ctx = block_ctx_map.Context(qdc_row[lbx], qf_row[bx], ord, c);
475
58.0k
  const int32_t nzero_ctx =
476
58.0k
      block_ctx_map.NonZeroContext(predicted_nzeros, block_ctx) + ctx_offset;
477
478
58.0k
  size_t nzeros =
479
58.0k
      decoder->ReadHybridUintInlined<uses_lz77>(nzero_ctx, br, context_map);
480
58.0k
  if (nzeros > size - covered_blocks) {
481
199
    return JXL_FAILURE("Invalid AC: nzeros %" PRIuS " too large for %" PRIuS
482
199
                       " 8x8 blocks",
483
199
                       nzeros, covered_blocks);
484
199
  }
485
115k
  for (size_t y = 0; y < acs.covered_blocks_y(); y++) {
486
133k
    for (size_t x = 0; x < acs.covered_blocks_x(); x++) {
487
75.5k
      row_nzeros[bx + x + y * nzeros_stride] =
488
75.5k
          (nzeros + covered_blocks - 1) >> log2_covered_blocks;
489
75.5k
    }
490
57.5k
  }
491
492
57.8k
  const size_t histo_offset =
493
57.8k
      ctx_offset + block_ctx_map.ZeroDensityContextsOffset(block_ctx);
494
495
57.8k
  size_t prev = (nzeros > size / 16 ? 0 : 1);
496
1.12M
  for (size_t k = covered_blocks; k < size && nzeros != 0; ++k) {
497
1.06M
    const size_t ctx =
498
1.06M
        histo_offset + ZeroDensityContext(nzeros, k, covered_blocks,
499
1.06M
                                          log2_covered_blocks, prev);
500
1.06M
    const size_t u_coeff =
501
1.06M
        decoder->ReadHybridUintInlined<uses_lz77>(ctx, br, context_map);
502
    // Hand-rolled version of UnpackSigned, shifting before the conversion to
503
    // signed integer to avoid undefined behavior of shifting negative numbers.
504
1.06M
    const size_t magnitude = u_coeff >> 1;
505
1.06M
    const size_t neg_sign = (~u_coeff) & 1;
506
1.06M
    const intptr_t coeff =
507
1.06M
        static_cast<intptr_t>((magnitude ^ (neg_sign - 1)) << shift);
508
1.06M
    if (ac_type == ACType::k16) {
509
0
      block.ptr16[order[k]] += coeff;
510
1.06M
    } else {
511
1.06M
      block.ptr32[order[k]] += coeff;
512
1.06M
    }
513
1.06M
    prev = static_cast<size_t>(u_coeff != 0);
514
1.06M
    nzeros -= prev;
515
1.06M
  }
516
57.8k
  if (JXL_UNLIKELY(nzeros != 0)) {
517
38
    return JXL_FAILURE("Invalid AC: nzeros at end of block is %" PRIuS
518
38
                       ", should be 0. Block (%" PRIuS ", %" PRIuS
519
38
                       "), channel %" PRIuS,
520
38
                       nzeros, bx, by, c);
521
38
  }
522
523
57.8k
  return true;
524
57.8k
}
dec_group.cc:jxl::Status jxl::(anonymous namespace)::DecodeACVarBlock<(jxl::ACType)0, false>(unsigned long, unsigned long, int*, int const*, unsigned long, unsigned long, unsigned long, unsigned long, unsigned long, jxl::AcStrategy, unsigned int const*, jxl::BitReader*, jxl::ANSSymbolReader*, std::__1::vector<unsigned char, std::__1::allocator<unsigned char> > const&, unsigned char const*, int const*, jxl::BlockCtxMap const&, jxl::ACPtr, unsigned long)
Line
Count
Source
463
599k
                        size_t shift = 0) {
464
  // Equal to number of LLF coefficients.
465
599k
  const size_t covered_blocks = 1 << log2_covered_blocks;
466
599k
  const size_t size = covered_blocks * kDCTBlockSize;
467
599k
  int32_t predicted_nzeros =
468
599k
      PredictFromTopAndLeft(row_nzeros_top, row_nzeros, bx, 32);
469
470
599k
  size_t ord = kStrategyOrder[acs.RawStrategy()];
471
599k
  const coeff_order_t* JXL_RESTRICT order =
472
599k
      &coeff_order[CoeffOrderOffset(ord, c)];
473
474
599k
  size_t block_ctx = block_ctx_map.Context(qdc_row[lbx], qf_row[bx], ord, c);
475
599k
  const int32_t nzero_ctx =
476
599k
      block_ctx_map.NonZeroContext(predicted_nzeros, block_ctx) + ctx_offset;
477
478
599k
  size_t nzeros =
479
599k
      decoder->ReadHybridUintInlined<uses_lz77>(nzero_ctx, br, context_map);
480
599k
  if (nzeros > size - covered_blocks) {
481
36
    return JXL_FAILURE("Invalid AC: nzeros %" PRIuS " too large for %" PRIuS
482
36
                       " 8x8 blocks",
483
36
                       nzeros, covered_blocks);
484
36
  }
485
1.33M
  for (size_t y = 0; y < acs.covered_blocks_y(); y++) {
486
2.04M
    for (size_t x = 0; x < acs.covered_blocks_x(); x++) {
487
1.31M
      row_nzeros[bx + x + y * nzeros_stride] =
488
1.31M
          (nzeros + covered_blocks - 1) >> log2_covered_blocks;
489
1.31M
    }
490
734k
  }
491
492
599k
  const size_t histo_offset =
493
599k
      ctx_offset + block_ctx_map.ZeroDensityContextsOffset(block_ctx);
494
495
599k
  size_t prev = (nzeros > size / 16 ? 0 : 1);
496
13.8M
  for (size_t k = covered_blocks; k < size && nzeros != 0; ++k) {
497
13.2M
    const size_t ctx =
498
13.2M
        histo_offset + ZeroDensityContext(nzeros, k, covered_blocks,
499
13.2M
                                          log2_covered_blocks, prev);
500
13.2M
    const size_t u_coeff =
501
13.2M
        decoder->ReadHybridUintInlined<uses_lz77>(ctx, br, context_map);
502
    // Hand-rolled version of UnpackSigned, shifting before the conversion to
503
    // signed integer to avoid undefined behavior of shifting negative numbers.
504
13.2M
    const size_t magnitude = u_coeff >> 1;
505
13.2M
    const size_t neg_sign = (~u_coeff) & 1;
506
13.2M
    const intptr_t coeff =
507
13.2M
        static_cast<intptr_t>((magnitude ^ (neg_sign - 1)) << shift);
508
13.2M
    if (ac_type == ACType::k16) {
509
13.2M
      block.ptr16[order[k]] += coeff;
510
13.2M
    } else {
511
17.6k
      block.ptr32[order[k]] += coeff;
512
17.6k
    }
513
13.2M
    prev = static_cast<size_t>(u_coeff != 0);
514
13.2M
    nzeros -= prev;
515
13.2M
  }
516
599k
  if (JXL_UNLIKELY(nzeros != 0)) {
517
344
    return JXL_FAILURE("Invalid AC: nzeros at end of block is %" PRIuS
518
344
                       ", should be 0. Block (%" PRIuS ", %" PRIuS
519
344
                       "), channel %" PRIuS,
520
344
                       nzeros, bx, by, c);
521
344
  }
522
523
599k
  return true;
524
599k
}
dec_group.cc:jxl::Status jxl::(anonymous namespace)::DecodeACVarBlock<(jxl::ACType)1, false>(unsigned long, unsigned long, int*, int const*, unsigned long, unsigned long, unsigned long, unsigned long, unsigned long, jxl::AcStrategy, unsigned int const*, jxl::BitReader*, jxl::ANSSymbolReader*, std::__1::vector<unsigned char, std::__1::allocator<unsigned char> > const&, unsigned char const*, int const*, jxl::BlockCtxMap const&, jxl::ACPtr, unsigned long)
Line
Count
Source
463
523k
                        size_t shift = 0) {
464
  // Equal to number of LLF coefficients.
465
523k
  const size_t covered_blocks = 1 << log2_covered_blocks;
466
523k
  const size_t size = covered_blocks * kDCTBlockSize;
467
523k
  int32_t predicted_nzeros =
468
523k
      PredictFromTopAndLeft(row_nzeros_top, row_nzeros, bx, 32);
469
470
523k
  size_t ord = kStrategyOrder[acs.RawStrategy()];
471
523k
  const coeff_order_t* JXL_RESTRICT order =
472
523k
      &coeff_order[CoeffOrderOffset(ord, c)];
473
474
523k
  size_t block_ctx = block_ctx_map.Context(qdc_row[lbx], qf_row[bx], ord, c);
475
523k
  const int32_t nzero_ctx =
476
523k
      block_ctx_map.NonZeroContext(predicted_nzeros, block_ctx) + ctx_offset;
477
478
523k
  size_t nzeros =
479
523k
      decoder->ReadHybridUintInlined<uses_lz77>(nzero_ctx, br, context_map);
480
523k
  if (nzeros > size - covered_blocks) {
481
321
    return JXL_FAILURE("Invalid AC: nzeros %" PRIuS " too large for %" PRIuS
482
321
                       " 8x8 blocks",
483
321
                       nzeros, covered_blocks);
484
321
  }
485
1.12M
  for (size_t y = 0; y < acs.covered_blocks_y(); y++) {
486
1.35M
    for (size_t x = 0; x < acs.covered_blocks_x(); x++) {
487
753k
      row_nzeros[bx + x + y * nzeros_stride] =
488
753k
          (nzeros + covered_blocks - 1) >> log2_covered_blocks;
489
753k
    }
490
602k
  }
491
492
522k
  const size_t histo_offset =
493
522k
      ctx_offset + block_ctx_map.ZeroDensityContextsOffset(block_ctx);
494
495
522k
  size_t prev = (nzeros > size / 16 ? 0 : 1);
496
7.60M
  for (size_t k = covered_blocks; k < size && nzeros != 0; ++k) {
497
7.08M
    const size_t ctx =
498
7.08M
        histo_offset + ZeroDensityContext(nzeros, k, covered_blocks,
499
7.08M
                                          log2_covered_blocks, prev);
500
7.08M
    const size_t u_coeff =
501
7.08M
        decoder->ReadHybridUintInlined<uses_lz77>(ctx, br, context_map);
502
    // Hand-rolled version of UnpackSigned, shifting before the conversion to
503
    // signed integer to avoid undefined behavior of shifting negative numbers.
504
7.08M
    const size_t magnitude = u_coeff >> 1;
505
7.08M
    const size_t neg_sign = (~u_coeff) & 1;
506
7.08M
    const intptr_t coeff =
507
7.08M
        static_cast<intptr_t>((magnitude ^ (neg_sign - 1)) << shift);
508
7.08M
    if (ac_type == ACType::k16) {
509
0
      block.ptr16[order[k]] += coeff;
510
7.08M
    } else {
511
7.08M
      block.ptr32[order[k]] += coeff;
512
7.08M
    }
513
7.08M
    prev = static_cast<size_t>(u_coeff != 0);
514
7.08M
    nzeros -= prev;
515
7.08M
  }
516
522k
  if (JXL_UNLIKELY(nzeros != 0)) {
517
187
    return JXL_FAILURE("Invalid AC: nzeros at end of block is %" PRIuS
518
187
                       ", should be 0. Block (%" PRIuS ", %" PRIuS
519
187
                       "), channel %" PRIuS,
520
187
                       nzeros, bx, by, c);
521
187
  }
522
523
522k
  return true;
524
522k
}
525
526
// Structs used by DecodeGroupImpl to get a quantized block.
527
// GetBlockFromBitstream uses ANS decoding (and thus keeps track of row
528
// pointers in row_nzeros), GetBlockFromEncoder simply reads the coefficient
529
// image provided by the encoder.
530
531
struct GetBlockFromBitstream : public GetBlock {
532
163k
  void StartRow(size_t by) override {
533
163k
    qf_row = rect.ConstRow(*qf, by);
534
652k
    for (size_t c = 0; c < 3; c++) {
535
489k
      size_t sby = by >> vshift[c];
536
489k
      quant_dc_row = quant_dc->ConstRow(rect.y0() + by) + rect.x0();
537
978k
      for (size_t i = 0; i < num_passes; i++) {
538
489k
        row_nzeros[i][c] = group_dec_cache->num_nzeroes[i].PlaneRow(c, sby);
539
489k
        row_nzeros_top[i][c] =
540
489k
            sby == 0
541
489k
                ? nullptr
542
489k
                : group_dec_cache->num_nzeroes[i].ConstPlaneRow(c, sby - 1);
543
489k
      }
544
489k
    }
545
163k
  }
546
547
  Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs, size_t size,
548
                   size_t log2_covered_blocks, ACPtr block[3],
549
523k
                   ACType ac_type) override {
550
523k
    ;
551
1.56M
    for (size_t c : {1, 0, 2}) {
552
1.56M
      size_t sbx = bx >> hshift[c];
553
1.56M
      size_t sby = by >> vshift[c];
554
1.56M
      if (JXL_UNLIKELY((sbx << hshift[c] != bx) || (sby << vshift[c] != by))) {
555
178k
        continue;
556
178k
      }
557
558
2.77M
      for (size_t pass = 0; JXL_UNLIKELY(pass < num_passes); pass++) {
559
1.39M
        auto decode_ac_varblock =
560
1.39M
            decoders[pass].UsesLZ77()
561
1.39M
                ? (ac_type == ACType::k16 ? DecodeACVarBlock<ACType::k16, 1>
562
268k
                                          : DecodeACVarBlock<ACType::k32, 1>)
563
1.39M
                : (ac_type == ACType::k16 ? DecodeACVarBlock<ACType::k16, 0>
564
1.12M
                                          : DecodeACVarBlock<ACType::k32, 0>);
565
1.39M
        JXL_RETURN_IF_ERROR(decode_ac_varblock(
566
1.39M
            ctx_offset[pass], log2_covered_blocks, row_nzeros[pass][c],
567
1.39M
            row_nzeros_top[pass][c], nzeros_stride, c, sbx, sby, bx, acs,
568
1.39M
            &coeff_orders[pass * coeff_order_size], readers[pass],
569
1.39M
            &decoders[pass], context_map[pass], quant_dc_row, qf_row,
570
1.39M
            *block_ctx_map, block[c], shift_for_pass[pass]));
571
1.39M
      }
572
1.39M
    }
573
522k
    return true;
574
523k
  }
575
576
  Status Init(const FrameHeader& frame_header,
577
              BitReader* JXL_RESTRICT* JXL_RESTRICT readers, size_t num_passes,
578
              size_t group_idx, size_t histo_selector_bits, const Rect& rect,
579
              GroupDecCache* JXL_RESTRICT group_dec_cache,
580
10.9k
              PassesDecoderState* dec_state, size_t first_pass) {
581
43.9k
    for (size_t i = 0; i < 3; i++) {
582
32.9k
      hshift[i] = frame_header.chroma_subsampling.HShift(i);
583
32.9k
      vshift[i] = frame_header.chroma_subsampling.VShift(i);
584
32.9k
    }
585
10.9k
    this->coeff_order_size = dec_state->shared->coeff_order_size;
586
10.9k
    this->coeff_orders =
587
10.9k
        dec_state->shared->coeff_orders.data() + first_pass * coeff_order_size;
588
10.9k
    this->context_map = dec_state->context_map.data() + first_pass;
589
10.9k
    this->readers = readers;
590
10.9k
    this->num_passes = num_passes;
591
10.9k
    this->shift_for_pass = frame_header.passes.shift + first_pass;
592
10.9k
    this->group_dec_cache = group_dec_cache;
593
10.9k
    this->rect = rect;
594
10.9k
    block_ctx_map = &dec_state->shared->block_ctx_map;
595
10.9k
    qf = &dec_state->shared->raw_quant_field;
596
10.9k
    quant_dc = &dec_state->shared->quant_dc;
597
598
21.8k
    for (size_t pass = 0; pass < num_passes; pass++) {
599
      // Select which histogram set to use among those of the current pass.
600
11.0k
      size_t cur_histogram = 0;
601
11.0k
      if (histo_selector_bits != 0) {
602
6.82k
        cur_histogram = readers[pass]->ReadBits(histo_selector_bits);
603
6.82k
      }
604
11.0k
      if (cur_histogram >= dec_state->shared->num_histograms) {
605
223
        return JXL_FAILURE("Invalid histogram selector");
606
223
      }
607
10.8k
      ctx_offset[pass] = cur_histogram * block_ctx_map->NumACContexts();
608
609
10.8k
      JXL_ASSIGN_OR_RETURN(
610
10.8k
          decoders[pass],
611
10.8k
          ANSSymbolReader::Create(&dec_state->code[pass + first_pass],
612
10.8k
                                  readers[pass]));
613
10.8k
    }
614
10.7k
    nzeros_stride = group_dec_cache->num_nzeroes[0].PixelsPerRow();
615
21.5k
    for (size_t i = 0; i < num_passes; i++) {
616
10.8k
      JXL_ASSERT(
617
10.8k
          nzeros_stride ==
618
10.8k
          static_cast<size_t>(group_dec_cache->num_nzeroes[i].PixelsPerRow()));
619
10.8k
    }
620
10.7k
    return true;
621
10.7k
  }
622
623
  const uint32_t* shift_for_pass = nullptr;  // not owned
624
  const coeff_order_t* JXL_RESTRICT coeff_orders;
625
  size_t coeff_order_size;
626
  const std::vector<uint8_t>* JXL_RESTRICT context_map;
627
  ANSSymbolReader decoders[kMaxNumPasses];
628
  BitReader* JXL_RESTRICT* JXL_RESTRICT readers;
629
  size_t num_passes;
630
  size_t ctx_offset[kMaxNumPasses];
631
  size_t nzeros_stride;
632
  int32_t* JXL_RESTRICT row_nzeros[kMaxNumPasses][3];
633
  const int32_t* JXL_RESTRICT row_nzeros_top[kMaxNumPasses][3];
634
  GroupDecCache* JXL_RESTRICT group_dec_cache;
635
  const BlockCtxMap* block_ctx_map;
636
  const ImageI* qf;
637
  const ImageB* quant_dc;
638
  const int32_t* qf_row;
639
  const uint8_t* quant_dc_row;
640
  Rect rect;
641
  size_t hshift[3], vshift[3];
642
};
643
644
struct GetBlockFromEncoder : public GetBlock {
645
0
  void StartRow(size_t by) override {}
646
647
  Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs, size_t size,
648
                   size_t log2_covered_blocks, ACPtr block[3],
649
0
                   ACType ac_type) override {
650
0
    JXL_DASSERT(ac_type == ACType::k32);
651
0
    for (size_t c = 0; c < 3; c++) {
652
      // for each pass
653
0
      for (size_t i = 0; i < quantized_ac->size(); i++) {
654
0
        for (size_t k = 0; k < size; k++) {
655
          // TODO(veluca): SIMD.
656
0
          block[c].ptr32[k] +=
657
0
              rows[i][c][offset + k] * (1 << shift_for_pass[i]);
658
0
        }
659
0
      }
660
0
    }
661
0
    offset += size;
662
0
    return true;
663
0
  }
664
665
  GetBlockFromEncoder(const std::vector<std::unique_ptr<ACImage>>& ac,
666
                      size_t group_idx, const uint32_t* shift_for_pass)
667
0
      : quantized_ac(&ac), shift_for_pass(shift_for_pass) {
668
    // TODO(veluca): not supported with chroma subsampling.
669
0
    for (size_t i = 0; i < quantized_ac->size(); i++) {
670
0
      JXL_CHECK((*quantized_ac)[i]->Type() == ACType::k32);
671
0
      for (size_t c = 0; c < 3; c++) {
672
0
        rows[i][c] = (*quantized_ac)[i]->PlaneRow(c, group_idx, 0).ptr32;
673
0
      }
674
0
    }
675
0
  }
676
677
  const std::vector<std::unique_ptr<ACImage>>* JXL_RESTRICT quantized_ac;
678
  size_t offset = 0;
679
  const int32_t* JXL_RESTRICT rows[kMaxNumPasses][3];
680
  const uint32_t* shift_for_pass = nullptr;  // not owned
681
};
682
683
// Defines the per-target function table for DecodeGroupImpl; it is what
// HWY_DYNAMIC_DISPATCH(DecodeGroupImpl) looks up at runtime.
HWY_EXPORT(DecodeGroupImpl);
684
685
}  // namespace
686
687
// Fills the render pipeline input of group `group_idx` by upsampling its DC
// image 8x, one channel at a time. Helper for DecodeGroup(), used when there
// are no AC passes to decode for the group.
static Status UpsampleGroupDcToPipeline(
    const FrameHeader& frame_header, size_t group_idx,
    PassesDecoderState* JXL_RESTRICT dec_state,
    GroupDecCache* JXL_RESTRICT group_dec_cache, size_t thread,
    RenderPipelineInput& render_pipeline_input,
    JxlMemoryManager* memory_manager) {
  JXL_RETURN_IF_ERROR(group_dec_cache->InitDCBufferOnce(memory_manager));
  const YCbCrChromaSubsampling& cs = frame_header.chroma_subsampling;
  for (size_t c : {0, 1, 2}) {
    const size_t hs = cs.HShift(c);
    const size_t vs = cs.VShift(c);
    // Group rectangle in DC samples of this (possibly subsampled) channel.
    const Rect src_rect_precs =
        dec_state->shared->frame_dim.BlockGroupRect(group_idx);
    const Rect src_rect =
        Rect(src_rect_precs.x0() >> hs, src_rect_precs.y0() >> vs,
             src_rect_precs.xsize() >> hs, src_rect_precs.ysize() >> vs);
    // The DC samples are staged in group_dec_cache->dc_buffer with 2 rows of
    // padding above and kRenderPipelineXOffset columns to the left.
    const Rect copy_rect(kRenderPipelineXOffset, 2, src_rect.xsize(),
                         src_rect.ysize());
    CopyImageToWithPadding(src_rect, dec_state->shared->dc->Plane(c), 2,
                           copy_rect, &group_dec_cache->dc_buffer);

    // Mirrorpad. Interleaving left and right padding ensures that padding
    // works out correctly even for images with DC size of 1.
    const size_t xend = kRenderPipelineXOffset +
                        (dec_state->shared->dc->Plane(c).xsize() >> hs) -
                        src_rect.x0();
    const bool pad_left = src_rect.x0() == 0;
    const bool pad_right = src_rect.x0() + src_rect.xsize() + 2 >=
                           (dec_state->shared->dc->xsize() >> hs);
    for (size_t y = 0; y < src_rect.ysize() + 4; y++) {
      auto* row = group_dec_cache->dc_buffer.Row(y);
      for (size_t ix = 0; ix < 2; ix++) {
        if (pad_left) {
          row[kRenderPipelineXOffset - ix - 1] =
              row[kRenderPipelineXOffset + ix];
        }
        if (pad_right) {
          row[xend + ix] = row[xend - ix - 1];
        }
      }
    }

    const auto& buffer = render_pipeline_input.GetBuffer(c);
    Rect dst_rect = buffer.second;
    ImageF* upsampling_dst = buffer.first;
    JXL_ASSERT(dst_rect.IsInside(*upsampling_dst));

    // Each DC row is upsampled from 5 input rows (2 above, 2 below, mirrored
    // at the image border) into 8 output rows.
    RenderPipelineStage::RowInfo input_rows(1, std::vector<float*>(5));
    RenderPipelineStage::RowInfo output_rows(1, std::vector<float*>(8));
    for (size_t y = src_rect.y0(); y < src_rect.y0() + src_rect.ysize();
         y++) {
      for (ssize_t iy = 0; iy < 5; iy++) {
        input_rows[0][iy] = group_dec_cache->dc_buffer.Row(
            Mirror(static_cast<ssize_t>(y) + iy - 2,
                   dec_state->shared->dc->Plane(c).ysize() >> vs) +
            2 - src_rect.y0());
      }
      for (size_t iy = 0; iy < 8; iy++) {
        output_rows[0][iy] =
            dst_rect.Row(upsampling_dst, ((y - src_rect.y0()) << 3) + iy) -
            kRenderPipelineXOffset;
      }
      // Arguments set to 0/nullptr are not used.
      JXL_RETURN_IF_ERROR(dec_state->upsampler8x->ProcessRow(
          input_rows, output_rows,
          /*xextra=*/0, src_rect.xsize(), 0, 0, thread));
    }
  }
  return true;
}

// Decodes `num_passes` passes, starting at pass `first_pass`, of group
// `group_idx` from `readers` (one BitReader per pass).
//
// The group is rendered into `render_pipeline_input` only when these passes
// complete the frame's pass count or `force_draw` is set; when
// `should_run_pipeline` is non-null it receives that decision.
// If the group is to be drawn but there are no passes at all (num_passes == 0
// and first_pass == 0), the DC image is upsampled instead and no AC data is
// read. `dc_only` requires num_passes == 0.
//
// Returns an error if the histogram selector or the coefficient data is
// invalid, or if an ANS stream does not end in its expected final state.
Status DecodeGroup(const FrameHeader& frame_header,
                   BitReader* JXL_RESTRICT* JXL_RESTRICT readers,
                   size_t num_passes, size_t group_idx,
                   PassesDecoderState* JXL_RESTRICT dec_state,
                   GroupDecCache* JXL_RESTRICT group_dec_cache, size_t thread,
                   RenderPipelineInput& render_pipeline_input,
                   jpeg::JPEGData* JXL_RESTRICT jpeg_data, size_t first_pass,
                   bool force_draw, bool dc_only, bool* should_run_pipeline) {
  JxlMemoryManager* memory_manager = dec_state->memory_manager();
  const bool completes_frame =
      num_passes + first_pass == frame_header.passes.num_passes;
  const DrawMode draw = (completes_frame || force_draw) ? kDraw : kDontDraw;

  if (should_run_pipeline) {
    *should_run_pipeline = draw != kDontDraw;
  }

  if (draw == kDraw && num_passes == 0 && first_pass == 0) {
    return UpsampleGroupDcToPipeline(frame_header, group_idx, dec_state,
                                     group_dec_cache, thread,
                                     render_pipeline_input, memory_manager);
  }

  size_t histo_selector_bits = 0;
  if (dc_only) {
    JXL_ASSERT(num_passes == 0);
  } else {
    JXL_ASSERT(dec_state->shared->num_histograms > 0);
    histo_selector_bits = CeilLog2Nonzero(dec_state->shared->num_histograms);
  }

  auto get_block = jxl::make_unique<GetBlockFromBitstream>();
  JXL_RETURN_IF_ERROR(get_block->Init(
      frame_header, readers, num_passes, group_idx, histo_selector_bits,
      dec_state->shared->frame_dim.BlockGroupRect(group_idx), group_dec_cache,
      dec_state, first_pass));

  JXL_RETURN_IF_ERROR(HWY_DYNAMIC_DISPATCH(DecodeGroupImpl)(
      frame_header, get_block.get(), group_dec_cache, dec_state, thread,
      group_idx, render_pipeline_input, jpeg_data, draw));

  // Every pass' ANS stream must end in its expected final state.
  for (size_t pass = 0; pass < num_passes; pass++) {
    if (!get_block->decoders[pass].CheckANSFinalState()) {
      return JXL_FAILURE("ANS checksum failure.");
    }
  }
  return true;
}
792
793
// Runs the group decoding pipeline on coefficients that are already available
// in memory (`ac`, one ACImage per pass) rather than read from a bitstream.
// The group is always drawn (kDraw). `aux_out` is not used by this function.
Status DecodeGroupForRoundtrip(const FrameHeader& frame_header,
                               const std::vector<std::unique_ptr<ACImage>>& ac,
                               size_t group_idx,
                               PassesDecoderState* JXL_RESTRICT dec_state,
                               GroupDecCache* JXL_RESTRICT group_dec_cache,
                               size_t thread,
                               RenderPipelineInput& render_pipeline_input,
                               jpeg::JPEGData* JXL_RESTRICT jpeg_data,
                               AuxOut* aux_out) {
  JxlMemoryManager* memory_manager = dec_state->memory_manager();

  // Block source backed by the encoder's coefficient images.
  GetBlockFromEncoder encoder_blocks(ac, group_idx, frame_header.passes.shift);

  // The cache is set up for zero passes with every valid AC strategy marked
  // as used.
  JXL_RETURN_IF_ERROR(group_dec_cache->InitOnce(
      memory_manager,
      /*num_passes=*/0,
      /*used_acs=*/(1u << AcStrategy::kNumValidStrategies) - 1));

  return HWY_DYNAMIC_DISPATCH(DecodeGroupImpl)(
      frame_header, &encoder_blocks, group_dec_cache, dec_state, thread,
      group_idx, render_pipeline_input, jpeg_data, kDraw);
}
813
814
}  // namespace jxl
815
#endif  // HWY_ONCE