Coverage Report

Created: 2026-02-14 07:09

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/libjxl/lib/jxl/enc_modular_simd.cc
Line
Count
Source
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#include "lib/jxl/enc_modular_simd.h"
7
8
#include <cstdint>
9
10
#include "lib/jxl/base/common.h"
11
#include "lib/jxl/base/status.h"
12
#include "lib/jxl/dec_ans.h"
13
#include "lib/jxl/enc_ans_params.h"
14
#include "lib/jxl/memory_manager_internal.h"
15
#include "lib/jxl/modular/modular_image.h"
16
17
#undef HWY_TARGET_INCLUDE
18
#define HWY_TARGET_INCLUDE "lib/jxl/enc_modular_simd.cc"
19
#include <hwy/foreach_target.h>
20
#include <hwy/highway.h>
21
22
#if HWY_TARGET == HWY_SCALAR
23
#include "lib/jxl/modular/encoding/context_predict.h"
24
#include "lib/jxl/pack_signed.h"
25
#endif
26
27
HWY_BEFORE_NAMESPACE();
28
namespace jxl {
29
namespace HWY_NAMESPACE {
30
31
// These templates are not found via ADL.
32
using hwy::HWY_NAMESPACE::Add;
33
using hwy::HWY_NAMESPACE::And;
34
using hwy::HWY_NAMESPACE::Ge;
35
using hwy::HWY_NAMESPACE::GetLane;
36
using hwy::HWY_NAMESPACE::Gt;
37
using hwy::HWY_NAMESPACE::IfThenElse;
38
using hwy::HWY_NAMESPACE::IfThenElseZero;
39
using hwy::HWY_NAMESPACE::Iota;
40
using hwy::HWY_NAMESPACE::Load;
41
using hwy::HWY_NAMESPACE::LoadU;
42
using hwy::HWY_NAMESPACE::Lt;
43
using hwy::HWY_NAMESPACE::Max;
44
using hwy::HWY_NAMESPACE::Min;
45
using hwy::HWY_NAMESPACE::Mul;
46
using hwy::HWY_NAMESPACE::Not;
47
using hwy::HWY_NAMESPACE::Set;
48
using hwy::HWY_NAMESPACE::ShiftLeft;
49
using hwy::HWY_NAMESPACE::ShiftRight;
50
using hwy::HWY_NAMESPACE::Store;
51
using hwy::HWY_NAMESPACE::StoreU;
52
using hwy::HWY_NAMESPACE::Sub;
53
using hwy::HWY_NAMESPACE::Xor;
54
using hwy::HWY_NAMESPACE::Zero;
55
56
0
// Approximates the cost of coding every channel of `img`.
//
// Each sample is predicted with the clamped gradient of its left/top/topleft
// neighbours. The residual is packed to unsigned and split into a hybrid-uint
// token plus raw extra bits. Tokens are grouped into contexts by the max-min
// spread of those three neighbours.
//
// The result is the sum over channels of the per-context histogram Shannon
// entropies (histograms are cleared after every channel) plus the total
// number of raw extra bits.
StatusOr<float> EstimateCost(const Image& img) {
  size_t histo_cost = 0;
  // Fractional parts of the per-histogram entropies are summed separately and
  // only truncated once at the end.
  float histo_cost_frac = 0.0f;
  size_t extra_bits = 0;

#if HWY_TARGET == HWY_SCALAR
  // Reference implementation: one sample at a time.
  HybridUintConfig config;
  uint32_t cutoffs[] = {0,  1,  3,  5,   7,   11,  15,  23, 31,
                        47, 63, 95, 127, 191, 255, 392, 500};
  constexpr size_t nc = sizeof(cutoffs) / sizeof(*cutoffs) + 1;
  Histogram histo[nc] = {};
  for (const Channel& ch : img.channel) {
    const ptrdiff_t onerow = ch.plane.PixelsPerRow();
    for (size_t y = 0; y < ch.h; y++) {
      const pixel_type* JXL_RESTRICT r = ch.Row(y);
      for (size_t x = 0; x < ch.w; x++) {
        pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
        pixel_type_w top = (y ? *(r + x - onerow) : left);
        pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left);
        size_t max_diff =
            std::max({left, top, topleft}) - std::min({left, top, topleft});
        // Context = number of cutoffs strictly greater than max_diff.
        size_t ctx = 0;
        for (uint32_t c : cutoffs) {
          ctx += (max_diff < c) ? 1 : 0;
        }
        pixel_type res = r[x] - ClampedGradient(top, left, topleft);
        uint32_t token;
        uint32_t nbits;
        uint32_t bits;
        config.Encode(PackSigned(res), &token, &nbits, &bits);
        histo[ctx].Add(token);
        extra_bits += nbits;
      }
    }
    for (auto& h : histo) {
      float f_cost = h.ShannonEntropy();
      size_t i_cost = f_cost;
      histo_cost += i_cost;
      histo_cost_frac += f_cost - i_cost;
      h.Clear();
    }
  }
#else
  JxlMemoryManager* memory_manager = img.memory_manager();
  const auto& ctx_map = estimate_cost_detail::ContextMap();
  const HWY_FULL(int32_t) di;
  const HWY_FULL(uint32_t) du;
  const HWY_FULL(float) df;
  // Tokens are derived from the float representation of the packed residual
  // instead of a per-lane log2. Values below kSplit are literal tokens. For
  // larger values, the exponent gives the number of extra bits, and the top 2
  // mantissa bits are folded into the token. This presumably matches the
  // default HybridUintConfig used by the scalar path -- TODO confirm if the
  // default ever changes.
  const auto kOne = Set(du, 1);
  const auto kSplit = Set(du, 16);
  const auto kExpOffset2 = Set(du, 129);  // 127 + 2
  const auto kTokenBias = Set(du, 8);
  const auto kTokenMul = Set(du, 4);
  const auto kMsbMask = Set(du, 3);
  const auto kMaxDiffCap = Set(du, estimate_cost_detail::kLastThreshold - 1);
  const auto kLanes = Set(du, Lanes(du));
  const auto kIota = Iota(du, 0);
  // Packed values above 2^22 - 1 are shifted right by kLargeShiftVal before
  // the float conversion, so the value fits in a float's 24-bit mantissa. The
  // shift is added back to the extra-bit count afterwards.
  const auto kLargeThreshold = Set(du, (1 << 22) - 1);
  constexpr size_t kLargeShiftVal = 10;
  const auto kLargeShift = Set(du, kLargeShiftVal);

  size_t max_w = 0;
  for (const Channel& ch : img.channel) {
    if (ch.h == 0) continue;
    max_w = std::max(max_w, ch.w);
  }
  max_w = RoundUpTo(max_w, Lanes(du));
  // The primer rows below are written at offset 1 with a full vector, so they
  // need at least Lanes + 1 entries.
  max_w = std::max(max_w, 2 * Lanes(du));

  JXL_ASSIGN_OR_RETURN(
      AlignedMemory buffer,
      AlignedMemory::Create(memory_manager, max_w * 2 * sizeof(uint32_t)));
  uint32_t* max_diff_row = buffer.address<uint32_t>();
  uint32_t* token_row = max_diff_row + max_w;
  // primer/top_primer share storage with max_diff_row/token_row. They hold
  // copies of the first vector of the current/previous row shifted by one
  // sample, and are only read in the x == 0 iteration. Those reads happen
  // before that iteration's stores overwrite the shared storage.
  int32_t* primer = buffer.address<int32_t>();
  int32_t* top_primer = primer + max_w;

  Histogram histo[estimate_cost_detail::kLastCtx + 1] = {};
  // Per-lane extra-bit counters are only 32 bits wide. They are flushed into
  // the size_t `extra_bits` after every row so they cannot wrap on very large
  // images.
  auto extra_bits_lanes = Zero(du);
  for (const Channel& ch : img.channel) {
    if (ch.h == 0 || ch.w == 0) continue;
    for (auto& h : histo) {
      h.EnsureCapacity(32 * 4);
    }
    // Row 0: predict from the left neighbour only (0 for the first sample);
    // all of its tokens go to context 0.
    const pixel_type* JXL_RESTRICT r = ch.Row(0);
    const pixel_type* JXL_RESTRICT last = primer;
    primer[0] = 0;
    StoreU(Load(di, r), di, primer + 1);
    auto pos = kIota;
    const auto last_pos = Set(du, ch.w);
    for (size_t x = 0; x < ch.w; x += Lanes(di)) {
      const auto left = LoadU(di, last);
      const auto central = Load(di, r + x);
      const auto ures = BitCast(du, Sub(central, left));
      // Signed -> unsigned zig-zag packing (PackSigned).
      const auto packed =
          Xor(ShiftLeft<1>(ures), Sub(ShiftRight<31>(Not(ures)), kOne));
      const auto is_large = Gt(packed, kLargeThreshold);
      const auto packed_shifted = ShiftRight<kLargeShiftVal>(packed);
      const auto not_literal = Ge(packed, kSplit);
      const auto packed_fixed = IfThenElse(is_large, packed_shifted, packed);
      const auto v = BitCast(du, ConvertTo(df, packed_fixed));
      const auto eb_raw = Sub(ShiftRight<23>(v), kExpOffset2);
      const auto eb = IfThenElse(is_large, Add(eb_raw, kLargeShift), eb_raw);
      const auto token = Add(Add(kTokenBias, Mul(eb, kTokenMul)),
                             And(ShiftRight<21>(v), kMsbMask));
      // Lanes past the channel width must not contribute extra bits.
      const auto tail_mask = Lt(pos, last_pos);
      const auto eb_fixed = IfThenElseZero(not_literal, eb);
      const auto token_fixed = IfThenElse(not_literal, token, packed);
      extra_bits_lanes =
          Add(extra_bits_lanes, IfThenElseZero(tail_mask, eb_fixed));
      Store(token_fixed, du, token_row + x);
      pos = Add(pos, kLanes);
      last = r + x + Lanes(di) - 1;
    }
    extra_bits += GetLane(SumOfLanes(du, extra_bits_lanes));
    extra_bits_lanes = Zero(du);
    for (size_t x = 0; x < ch.w; x++) {
      histo[0].FastAdd(token_row[x]);
    }
    for (size_t y = 1; y < ch.h; y++) {
      // Remaining rows: clamped-gradient prediction. The context comes from
      // the (capped) max-min spread of left/top/topleft.
      r = ch.Row(y);
      const pixel_type* JXL_RESTRICT t = ch.Row(y - 1);
      // At x == 0 both the left and topleft neighbours are taken to be t[0].
      last = primer;
      primer[0] = t[0];
      StoreU(Load(di, r), di, primer + 1);
      top_primer[0] = t[0];
      StoreU(Load(di, t), di, top_primer + 1);
      const pixel_type* JXL_RESTRICT top_last = top_primer;
      pos = kIota;
      for (size_t x = 0; x < ch.w; x += Lanes(di)) {
        const auto left = LoadU(di, last);
        const auto central = Load(di, r + x);
        const auto topleft = LoadU(di, top_last);
        const auto top = Load(di, t + x);
        const auto l_ge_t = Ge(left, top);
        const auto m = IfThenElse(l_ge_t, top, left);
        const auto M = IfThenElse(l_ge_t, left, top);
        const auto maxx = Max(topleft, M);
        const auto minn = Min(topleft, m);
        const auto max_diff = BitCast(du, Sub(maxx, minn));
        Store(Min(max_diff, kMaxDiffCap), du, max_diff_row + x);
        const auto overshoot = Lt(topleft, m);
        const auto undershoot = Gt(topleft, M);
        // top + left - topleft, computed with wrap-around (unsigned) math.
        const auto grad =
            BitCast(di, Sub(Add(BitCast(du, top), BitCast(du, left)),
                            BitCast(du, topleft)));
        const auto prediction =
            IfThenElse(undershoot, m, IfThenElse(overshoot, M, grad));
        const auto ures = BitCast(du, Sub(central, prediction));
        const auto packed =
            Xor(ShiftLeft<1>(ures), Sub(ShiftRight<31>(Not(ures)), kOne));
        const auto is_large = Gt(packed, kLargeThreshold);
        const auto packed_shifted = ShiftRight<kLargeShiftVal>(packed);
        const auto not_literal = Ge(packed, kSplit);
        const auto packed_fixed = IfThenElse(is_large, packed_shifted, packed);
        const auto v = BitCast(du, ConvertTo(df, packed_fixed));
        const auto eb_raw = Sub(ShiftRight<23>(v), kExpOffset2);
        const auto eb = IfThenElse(is_large, Add(eb_raw, kLargeShift), eb_raw);
        const auto token = Add(Add(kTokenBias, Mul(eb, kTokenMul)),
                               And(ShiftRight<21>(v), kMsbMask));
        const auto tail_mask = Lt(pos, last_pos);
        const auto eb_fixed = IfThenElseZero(not_literal, eb);
        const auto token_fixed = IfThenElse(not_literal, token, packed);
        extra_bits_lanes =
            Add(extra_bits_lanes, IfThenElseZero(tail_mask, eb_fixed));
        Store(token_fixed, du, token_row + x);
        pos = Add(pos, kLanes);
        last = r + x + Lanes(di) - 1;
        top_last = t + x + Lanes(di) - 1;
      }
      extra_bits += GetLane(SumOfLanes(du, extra_bits_lanes));
      extra_bits_lanes = Zero(du);
      for (size_t x = 0; x < ch.w; x++) {
        size_t ctx = ctx_map[max_diff_row[x]];
        histo[ctx].FastAdd(token_row[x]);
      }
    }
    for (auto& h : histo) {
      h.Condition();
      float f_cost = h.ShannonEntropy();
      size_t i_cost = f_cost;
      histo_cost += i_cost;
      histo_cost_frac += f_cost - i_cost;
      h.Clear();
    }
  }
#endif
  size_t total_cost =
      extra_bits + histo_cost + static_cast<size_t>(histo_cost_frac);
  return total_cost;
}
246
247
// NOLINTNEXTLINE(google-readability-namespace-comments)
248
}  // namespace HWY_NAMESPACE
249
}  // namespace jxl
250
HWY_AFTER_NAMESPACE();
251
252
#if HWY_ONCE
253
namespace jxl {
254
255
HWY_EXPORT(EstimateCost);
256
257
0
// Public entry point. It selects, at run time, the per-target
// EstimateCost implementation exported above via HWY_EXPORT and forwards
// `img` to it unchanged.
StatusOr<float> EstimateCost(const Image& img) {
  const auto best_impl = HWY_DYNAMIC_DISPATCH(EstimateCost);
  return best_impl(img);
}
260
261
namespace estimate_cost_detail {
262
/*
263
cutoffs = [0, 1, 3, 5, 7, 11, 15, 23, 31, 47, 63, 95, 127, 191, 255, 392, 500]
264
ctx_map = [[c for c,v in enumerate(cutoffs) if v <= i][0] for i in range(501)]
265
*/
266
0
const std::array<uint8_t, kLastThreshold>& ContextMap() {
267
0
  static const std::array<uint8_t, kLastThreshold> kCtxMap = {
268
0
      0,  1,  1,  2,  2,  3,  3,  4,  4,  4,  4,  5,  5,  5,  5,  6,  6,  6,
269
0
      6,  6,  6,  6,  6,  7,  7,  7,  7,  7,  7,  7,  7,  8,  8,  8,  8,  8,
270
0
      8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  9,  9,  9,  9,  9,  9,  9,
271
0
      9,  9,  9,  9,  9,  9,  9,  9,  9,  10, 10, 10, 10, 10, 10, 10, 10, 10,
272
0
      10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
273
0
      10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11,
274
0
      11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11,
275
0
      11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
276
0
      12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
277
0
      12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
278
0
      12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13,
279
0
      13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
280
0
      13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
281
0
      13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
282
0
      13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
283
0
      14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
284
0
      14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
285
0
      14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
286
0
      14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
287
0
      14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
288
0
      14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
289
0
      14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15,
290
0
      15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
291
0
      15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
292
0
      15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
293
0
      15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
294
0
      15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
295
0
      15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 16};
296
0
  return kCtxMap;
297
0
}
298
}  // namespace estimate_cost_detail
299
300
}  // namespace jxl
301
#endif