Coverage Report

Created: 2026-03-31 06:56

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/libjxl/lib/jxl/modular/encoding/enc_ma.cc
Line
Count
Source
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#include "lib/jxl/modular/encoding/enc_ma.h"
7
8
#include <algorithm>
9
#include <cstdint>
10
#include <cstdlib>
11
#include <cstring>
12
#include <limits>
13
#include <numeric>
14
#include <queue>
15
#include <vector>
16
17
#include "lib/jxl/ans_params.h"
18
#include "lib/jxl/base/bits.h"
19
#include "lib/jxl/base/common.h"
20
#include "lib/jxl/base/compiler_specific.h"
21
#include "lib/jxl/base/status.h"
22
#include "lib/jxl/dec_ans.h"
23
#include "lib/jxl/modular/encoding/dec_ma.h"
24
#include "lib/jxl/modular/encoding/ma_common.h"
25
#include "lib/jxl/modular/modular_image.h"
26
27
#undef HWY_TARGET_INCLUDE
28
#define HWY_TARGET_INCLUDE "lib/jxl/modular/encoding/enc_ma.cc"
29
#include <hwy/foreach_target.h>
30
#include <hwy/highway.h>
31
32
#include "lib/jxl/base/fast_math-inl.h"
33
#include "lib/jxl/base/random.h"
34
#include "lib/jxl/enc_ans.h"
35
#include "lib/jxl/modular/encoding/context_predict.h"
36
#include "lib/jxl/modular/options.h"
37
#include "lib/jxl/pack_signed.h"
38
HWY_BEFORE_NAMESPACE();
39
namespace jxl {
40
namespace HWY_NAMESPACE {
41
42
// These templates are not found via ADL.
43
using hwy::HWY_NAMESPACE::Eq;
44
using hwy::HWY_NAMESPACE::IfThenElse;
45
using hwy::HWY_NAMESPACE::Lt;
46
using hwy::HWY_NAMESPACE::Max;
47
48
const HWY_FULL(float) df;
49
const HWY_FULL(int32_t) di;
50
346k
size_t Padded(size_t x) { return RoundUpTo(x, Lanes(df)); }
Unexecuted instantiation: jxl::N_SSE4::Padded(unsigned long)
jxl::N_AVX2::Padded(unsigned long)
Line
Count
Source
50
346k
size_t Padded(size_t x) { return RoundUpTo(x, Lanes(df)); }
Unexecuted instantiation: jxl::N_SSE2::Padded(unsigned long)
51
52
// Compute entropy of the histogram, taking into account the minimum probability
53
// for symbols with non-zero counts.
54
26.8M
float EstimateBits(const int32_t *counts, size_t num_symbols) {
55
26.8M
  int32_t total = std::accumulate(counts, counts + num_symbols, 0);
56
26.8M
  const auto zero = Zero(df);
57
26.8M
  const auto minprob = Set(df, 1.0f / ANS_TAB_SIZE);
58
26.8M
  const auto inv_total = Set(df, 1.0f / total);
59
26.8M
  auto bits_lanes = Zero(df);
60
26.8M
  auto total_v = Set(di, total);
61
194M
  for (size_t i = 0; i < num_symbols; i += Lanes(df)) {
62
167M
    const auto counts_iv = LoadU(di, &counts[i]);
63
167M
    const auto counts_fv = ConvertTo(df, counts_iv);
64
167M
    const auto probs = Mul(counts_fv, inv_total);
65
167M
    const auto mprobs = Max(probs, minprob);
66
167M
    const auto nbps = IfThenElse(Eq(counts_iv, total_v), BitCast(di, zero),
67
167M
                                 BitCast(di, FastLog2f(df, mprobs)));
68
167M
    bits_lanes = Sub(bits_lanes, Mul(counts_fv, BitCast(df, nbps)));
69
167M
  }
70
26.8M
  return GetLane(SumOfLanes(df, bits_lanes));
71
26.8M
}
Unexecuted instantiation: jxl::N_SSE4::EstimateBits(int const*, unsigned long)
jxl::N_AVX2::EstimateBits(int const*, unsigned long)
Line
Count
Source
54
26.8M
float EstimateBits(const int32_t *counts, size_t num_symbols) {
55
26.8M
  int32_t total = std::accumulate(counts, counts + num_symbols, 0);
56
26.8M
  const auto zero = Zero(df);
57
26.8M
  const auto minprob = Set(df, 1.0f / ANS_TAB_SIZE);
58
26.8M
  const auto inv_total = Set(df, 1.0f / total);
59
26.8M
  auto bits_lanes = Zero(df);
60
26.8M
  auto total_v = Set(di, total);
61
194M
  for (size_t i = 0; i < num_symbols; i += Lanes(df)) {
62
167M
    const auto counts_iv = LoadU(di, &counts[i]);
63
167M
    const auto counts_fv = ConvertTo(df, counts_iv);
64
167M
    const auto probs = Mul(counts_fv, inv_total);
65
167M
    const auto mprobs = Max(probs, minprob);
66
167M
    const auto nbps = IfThenElse(Eq(counts_iv, total_v), BitCast(di, zero),
67
167M
                                 BitCast(di, FastLog2f(df, mprobs)));
68
167M
    bits_lanes = Sub(bits_lanes, Mul(counts_fv, BitCast(df, nbps)));
69
167M
  }
70
26.8M
  return GetLane(SumOfLanes(df, bits_lanes));
71
26.8M
}
Unexecuted instantiation: jxl::N_SSE2::EstimateBits(int const*, unsigned long)
72
73
void MakeSplitNode(size_t pos, int property, int splitval, Predictor lpred,
74
172k
                   int64_t loff, Predictor rpred, int64_t roff, Tree *tree) {
75
  // Note that the tree splits on *strictly greater*.
76
172k
  (*tree)[pos].lchild = tree->size();
77
172k
  (*tree)[pos].rchild = tree->size() + 1;
78
172k
  (*tree)[pos].splitval = splitval;
79
172k
  (*tree)[pos].property = property;
80
172k
  tree->emplace_back();
81
172k
  tree->back().property = -1;
82
172k
  tree->back().predictor = rpred;
83
172k
  tree->back().predictor_offset = roff;
84
172k
  tree->back().multiplier = 1;
85
172k
  tree->emplace_back();
86
172k
  tree->back().property = -1;
87
172k
  tree->back().predictor = lpred;
88
172k
  tree->back().predictor_offset = loff;
89
172k
  tree->back().multiplier = 1;
90
172k
}
Unexecuted instantiation: jxl::N_SSE4::MakeSplitNode(unsigned long, int, int, jxl::Predictor, long, jxl::Predictor, long, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
jxl::N_AVX2::MakeSplitNode(unsigned long, int, int, jxl::Predictor, long, jxl::Predictor, long, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
Line
Count
Source
74
172k
                   int64_t loff, Predictor rpred, int64_t roff, Tree *tree) {
75
  // Note that the tree splits on *strictly greater*.
76
172k
  (*tree)[pos].lchild = tree->size();
77
172k
  (*tree)[pos].rchild = tree->size() + 1;
78
172k
  (*tree)[pos].splitval = splitval;
79
172k
  (*tree)[pos].property = property;
80
172k
  tree->emplace_back();
81
172k
  tree->back().property = -1;
82
172k
  tree->back().predictor = rpred;
83
172k
  tree->back().predictor_offset = roff;
84
172k
  tree->back().multiplier = 1;
85
172k
  tree->emplace_back();
86
172k
  tree->back().property = -1;
87
172k
  tree->back().predictor = lpred;
88
172k
  tree->back().predictor_offset = loff;
89
172k
  tree->back().multiplier = 1;
90
172k
}
Unexecuted instantiation: jxl::N_SSE2::MakeSplitNode(unsigned long, int, int, jxl::Predictor, long, jxl::Predictor, long, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
91
92
enum class IntersectionType { kNone, kPartial, kInside };
93
IntersectionType BoxIntersects(StaticPropRange needle, StaticPropRange haystack,
94
84.8k
                               uint32_t &partial_axis, uint32_t &partial_val) {
95
84.8k
  bool partial = false;
96
254k
  for (size_t i = 0; i < kNumStaticProperties; i++) {
97
169k
    if (haystack[i][0] >= needle[i][1]) {
98
0
      return IntersectionType::kNone;
99
0
    }
100
169k
    if (haystack[i][1] <= needle[i][0]) {
101
0
      return IntersectionType::kNone;
102
0
    }
103
169k
    if (haystack[i][0] <= needle[i][0] && haystack[i][1] >= needle[i][1]) {
104
169k
      continue;
105
169k
    }
106
0
    partial = true;
107
0
    partial_axis = i;
108
0
    if (haystack[i][0] > needle[i][0] && haystack[i][0] < needle[i][1]) {
109
0
      partial_val = haystack[i][0] - 1;
110
0
    } else {
111
0
      JXL_DASSERT(haystack[i][1] > needle[i][0] &&
112
0
                  haystack[i][1] < needle[i][1]);
113
0
      partial_val = haystack[i][1] - 1;
114
0
    }
115
0
  }
116
84.8k
  return partial ? IntersectionType::kPartial : IntersectionType::kInside;
117
84.8k
}
Unexecuted instantiation: jxl::N_SSE4::BoxIntersects(std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, unsigned int&, unsigned int&)
jxl::N_AVX2::BoxIntersects(std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, unsigned int&, unsigned int&)
Line
Count
Source
94
84.8k
                               uint32_t &partial_axis, uint32_t &partial_val) {
95
84.8k
  bool partial = false;
96
254k
  for (size_t i = 0; i < kNumStaticProperties; i++) {
97
169k
    if (haystack[i][0] >= needle[i][1]) {
98
0
      return IntersectionType::kNone;
99
0
    }
100
169k
    if (haystack[i][1] <= needle[i][0]) {
101
0
      return IntersectionType::kNone;
102
0
    }
103
169k
    if (haystack[i][0] <= needle[i][0] && haystack[i][1] >= needle[i][1]) {
104
169k
      continue;
105
169k
    }
106
0
    partial = true;
107
0
    partial_axis = i;
108
0
    if (haystack[i][0] > needle[i][0] && haystack[i][0] < needle[i][1]) {
109
0
      partial_val = haystack[i][0] - 1;
110
0
    } else {
111
0
      JXL_DASSERT(haystack[i][1] > needle[i][0] &&
112
0
                  haystack[i][1] < needle[i][1]);
113
0
      partial_val = haystack[i][1] - 1;
114
0
    }
115
0
  }
116
84.8k
  return partial ? IntersectionType::kPartial : IntersectionType::kInside;
117
84.8k
}
Unexecuted instantiation: jxl::N_SSE2::BoxIntersects(std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, unsigned int&, unsigned int&)
118
119
template<bool S>
120
void SplitTreeSamples(TreeSamples &tree_samples, size_t begin, size_t pos,
121
172k
                      size_t end, size_t prop, uint32_t val) {
122
172k
  size_t begin_pos = begin;
123
172k
  size_t end_pos = pos;
124
9.76M
  do {
125
24.2M
    while (begin_pos < pos &&
126
24.2M
           tree_samples.Property<S>(prop, begin_pos) <= val) {
127
14.5M
      ++begin_pos;
128
14.5M
    }
129
26.6M
    while (end_pos < end && tree_samples.Property<S>(prop, end_pos) > val) {
130
16.8M
      ++end_pos;
131
16.8M
    }
132
9.76M
    if (begin_pos < pos && end_pos < end) {
133
9.72M
      tree_samples.Swap(begin_pos, end_pos);
134
9.72M
    }
135
9.76M
    ++begin_pos;
136
9.76M
    ++end_pos;
137
9.76M
  } while (begin_pos < pos && end_pos < end);
138
172k
}
Unexecuted instantiation: void jxl::N_SSE4::SplitTreeSamples<true>(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int)
Unexecuted instantiation: void jxl::N_SSE4::SplitTreeSamples<false>(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int)
void jxl::N_AVX2::SplitTreeSamples<true>(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int)
Line
Count
Source
121
6.75k
                      size_t end, size_t prop, uint32_t val) {
122
6.75k
  size_t begin_pos = begin;
123
6.75k
  size_t end_pos = pos;
124
546k
  do {
125
2.87M
    while (begin_pos < pos &&
126
2.87M
           tree_samples.Property<S>(prop, begin_pos) <= val) {
127
2.33M
      ++begin_pos;
128
2.33M
    }
129
1.87M
    while (end_pos < end && tree_samples.Property<S>(prop, end_pos) > val) {
130
1.33M
      ++end_pos;
131
1.33M
    }
132
546k
    if (begin_pos < pos && end_pos < end) {
133
543k
      tree_samples.Swap(begin_pos, end_pos);
134
543k
    }
135
546k
    ++begin_pos;
136
546k
    ++end_pos;
137
546k
  } while (begin_pos < pos && end_pos < end);
138
6.75k
}
void jxl::N_AVX2::SplitTreeSamples<false>(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int)
Line
Count
Source
121
165k
                      size_t end, size_t prop, uint32_t val) {
122
165k
  size_t begin_pos = begin;
123
165k
  size_t end_pos = pos;
124
9.21M
  do {
125
21.4M
    while (begin_pos < pos &&
126
21.3M
           tree_samples.Property<S>(prop, begin_pos) <= val) {
127
12.1M
      ++begin_pos;
128
12.1M
    }
129
24.7M
    while (end_pos < end && tree_samples.Property<S>(prop, end_pos) > val) {
130
15.5M
      ++end_pos;
131
15.5M
    }
132
9.21M
    if (begin_pos < pos && end_pos < end) {
133
9.18M
      tree_samples.Swap(begin_pos, end_pos);
134
9.18M
    }
135
9.21M
    ++begin_pos;
136
9.21M
    ++end_pos;
137
9.21M
  } while (begin_pos < pos && end_pos < end);
138
165k
}
Unexecuted instantiation: void jxl::N_SSE2::SplitTreeSamples<true>(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int)
Unexecuted instantiation: void jxl::N_SSE2::SplitTreeSamples<false>(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int)
139
140
template <bool S>
141
void CollectExtraBitsIncrease(TreeSamples &tree_samples,
142
                              const std::vector<ResidualToken> &rtokens,
143
                              std::vector<int> &count_increase,
144
                              std::vector<size_t> &extra_bits_increase,
145
                              size_t begin, size_t end, size_t prop_idx,
146
2.25M
                              size_t max_symbols) {
147
402M
  for (size_t i2 = begin; i2 < end; i2++) {
148
399M
    const ResidualToken &rt = rtokens[i2];
149
399M
    size_t cnt = tree_samples.Count(i2);
150
399M
    size_t p = tree_samples.Property<S>(prop_idx, i2);
151
399M
    size_t sym = rt.tok;
152
399M
    size_t ebi = rt.nbits * cnt;
153
399M
    count_increase[p * max_symbols + sym] += cnt;
154
399M
    extra_bits_increase[p] += ebi;
155
399M
  }
156
2.25M
}
Unexecuted instantiation: void jxl::N_SSE4::CollectExtraBitsIncrease<true>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long)
Unexecuted instantiation: void jxl::N_SSE4::CollectExtraBitsIncrease<false>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long)
void jxl::N_AVX2::CollectExtraBitsIncrease<true>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long)
Line
Count
Source
146
418k
                              size_t max_symbols) {
147
79.4M
  for (size_t i2 = begin; i2 < end; i2++) {
148
79.0M
    const ResidualToken &rt = rtokens[i2];
149
79.0M
    size_t cnt = tree_samples.Count(i2);
150
79.0M
    size_t p = tree_samples.Property<S>(prop_idx, i2);
151
79.0M
    size_t sym = rt.tok;
152
79.0M
    size_t ebi = rt.nbits * cnt;
153
79.0M
    count_increase[p * max_symbols + sym] += cnt;
154
79.0M
    extra_bits_increase[p] += ebi;
155
79.0M
  }
156
418k
}
void jxl::N_AVX2::CollectExtraBitsIncrease<false>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long)
Line
Count
Source
146
1.83M
                              size_t max_symbols) {
147
322M
  for (size_t i2 = begin; i2 < end; i2++) {
148
320M
    const ResidualToken &rt = rtokens[i2];
149
320M
    size_t cnt = tree_samples.Count(i2);
150
320M
    size_t p = tree_samples.Property<S>(prop_idx, i2);
151
320M
    size_t sym = rt.tok;
152
320M
    size_t ebi = rt.nbits * cnt;
153
320M
    count_increase[p * max_symbols + sym] += cnt;
154
320M
    extra_bits_increase[p] += ebi;
155
320M
  }
156
1.83M
}
Unexecuted instantiation: void jxl::N_SSE2::CollectExtraBitsIncrease<true>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long)
Unexecuted instantiation: void jxl::N_SSE2::CollectExtraBitsIncrease<false>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long)
157
158
void FindBestSplit(TreeSamples &tree_samples, float threshold,
159
                   const std::vector<ModularMultiplierInfo> &mul_info,
160
                   StaticPropRange initial_static_prop_range,
161
2.28k
                   float fast_decode_multiplier, Tree *tree) {
162
2.28k
  struct NodeInfo {
163
2.28k
    size_t pos;
164
2.28k
    size_t begin;
165
2.28k
    size_t end;
166
2.28k
    StaticPropRange static_prop_range;
167
2.28k
  };
168
2.28k
  std::vector<NodeInfo> nodes;
169
2.28k
  nodes.push_back(NodeInfo{0, 0, tree_samples.NumDistinctSamples(),
170
2.28k
                           initial_static_prop_range});
171
172
2.28k
  size_t num_predictors = tree_samples.NumPredictors();
173
2.28k
  size_t num_properties = tree_samples.NumProperties();
174
175
  // TODO(veluca): consider parallelizing the search (processing multiple nodes
176
  // at a time).
177
348k
  while (!nodes.empty()) {
178
346k
    size_t pos = nodes.back().pos;
179
346k
    size_t begin = nodes.back().begin;
180
346k
    size_t end = nodes.back().end;
181
182
346k
    StaticPropRange static_prop_range = nodes.back().static_prop_range;
183
346k
    nodes.pop_back();
184
346k
    if (begin == end) continue;
185
186
346k
    struct SplitInfo {
187
346k
      size_t prop = 0;
188
346k
      uint32_t val = 0;
189
346k
      size_t pos = 0;
190
346k
      float lcost = std::numeric_limits<float>::max();
191
346k
      float rcost = std::numeric_limits<float>::max();
192
346k
      Predictor lpred = Predictor::Zero;
193
346k
      Predictor rpred = Predictor::Zero;
194
15.0M
      float Cost() const { return lcost + rcost; }
Unexecuted instantiation: enc_ma.cc:jxl::N_SSE4::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::SplitInfo::Cost() const
enc_ma.cc:jxl::N_AVX2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::SplitInfo::Cost() const
Line
Count
Source
194
15.0M
      float Cost() const { return lcost + rcost; }
Unexecuted instantiation: enc_ma.cc:jxl::N_SSE2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::SplitInfo::Cost() const
195
346k
    };
196
197
346k
    SplitInfo best_split_static_constant;
198
346k
    SplitInfo best_split_static;
199
346k
    SplitInfo best_split_nonstatic;
200
346k
    SplitInfo best_split_nowp;
201
202
346k
    JXL_DASSERT(begin <= end);
203
346k
    JXL_DASSERT(end <= tree_samples.NumDistinctSamples());
204
205
    // Compute the maximum token in the range.
206
346k
    size_t max_symbols = 0;
207
693k
    for (size_t pred = 0; pred < num_predictors; pred++) {
208
57.7M
      for (size_t i = begin; i < end; i++) {
209
57.4M
        uint32_t tok = tree_samples.Token(pred, i);
210
57.4M
        max_symbols = max_symbols > tok + 1 ? max_symbols : tok + 1;
211
57.4M
      }
212
346k
    }
213
346k
    max_symbols = Padded(max_symbols);
214
346k
    std::vector<int32_t> counts(max_symbols * num_predictors);
215
346k
    std::vector<uint32_t> tot_extra_bits(num_predictors);
216
693k
    for (size_t pred = 0; pred < num_predictors; pred++) {
217
346k
      size_t extra_bits = 0;
218
346k
      const std::vector<ResidualToken>& rtokens = tree_samples.RTokens(pred);
219
57.7M
      for (size_t i = begin; i < end; i++) {
220
57.4M
        const ResidualToken& rt = rtokens[i];
221
57.4M
        size_t count = tree_samples.Count(i);
222
57.4M
        size_t eb = rt.nbits * count;
223
57.4M
        counts[pred * max_symbols + rt.tok] += count;
224
57.4M
        extra_bits += eb;
225
57.4M
      }
226
346k
      tot_extra_bits[pred] = extra_bits;
227
346k
    }
228
229
346k
    float base_bits;
230
346k
    {
231
346k
      size_t pred = tree_samples.PredictorIndex((*tree)[pos].predictor);
232
346k
      base_bits =
233
346k
          EstimateBits(counts.data() + pred * max_symbols, max_symbols) +
234
346k
          tot_extra_bits[pred];
235
346k
    }
236
237
346k
    SplitInfo *best = &best_split_nonstatic;
238
239
346k
    SplitInfo forced_split;
240
    // The multiplier ranges cut halfway through the current ranges of static
241
    // properties. We do this even if the current node is not a leaf, to
242
    // minimize the number of nodes in the resulting tree.
243
346k
    for (const auto &mmi : mul_info) {
244
84.8k
      uint32_t axis;
245
84.8k
      uint32_t val;
246
84.8k
      IntersectionType t =
247
84.8k
          BoxIntersects(static_prop_range, mmi.range, axis, val);
248
84.8k
      if (t == IntersectionType::kNone) continue;
249
84.8k
      if (t == IntersectionType::kInside) {
250
84.8k
        (*tree)[pos].multiplier = mmi.multiplier;
251
84.8k
        break;
252
84.8k
      }
253
0
      if (t == IntersectionType::kPartial) {
254
0
        JXL_DASSERT(axis < kNumStaticProperties);
255
0
        forced_split.val = tree_samples.QuantizeStaticProperty(axis, val);
256
0
        forced_split.prop = axis;
257
0
        forced_split.lcost = forced_split.rcost = base_bits / 2 - threshold;
258
0
        forced_split.lpred = forced_split.rpred = (*tree)[pos].predictor;
259
0
        best = &forced_split;
260
0
        best->pos = begin;
261
0
        JXL_DASSERT(best->prop == tree_samples.PropertyFromIndex(best->prop));
262
0
        if (best->prop < tree_samples.NumStaticProps()) {
263
0
        for (size_t x = begin; x < end; x++) {
264
0
          if (tree_samples.Property<true>(best->prop, x) <= best->val) {
265
0
            best->pos++;
266
0
          }
267
0
        }
268
0
      } else {
269
0
        size_t prop = best->prop - tree_samples.NumStaticProps();
270
0
        for (size_t x = begin; x < end; x++) {
271
0
          if (tree_samples.Property<false>(prop, x) <= best->val) {
272
0
            best->pos++;
273
0
          }
274
0
        }
275
0
      }
276
0
        break;
277
0
      }
278
0
    }
279
280
346k
    if (best != &forced_split) {
281
346k
      std::vector<int> prop_value_used_count;
282
346k
      std::vector<int> count_increase;
283
346k
      std::vector<size_t> extra_bits_increase;
284
      // For each property, compute which of its values are used, and what
285
      // tokens correspond to those usages. Then, iterate through the values,
286
      // and compute the entropy of each side of the split (of the form `prop >
287
      // threshold`). Finally, find the split that minimizes the cost.
288
346k
      struct CostInfo {
289
346k
        float cost = std::numeric_limits<float>::max();
290
346k
        float extra_cost = 0;
291
26.4M
        float Cost() const { return cost + extra_cost; }
Unexecuted instantiation: enc_ma.cc:jxl::N_SSE4::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::CostInfo::Cost() const
enc_ma.cc:jxl::N_AVX2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::CostInfo::Cost() const
Line
Count
Source
291
26.4M
        float Cost() const { return cost + extra_cost; }
Unexecuted instantiation: enc_ma.cc:jxl::N_SSE2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::CostInfo::Cost() const
292
346k
        Predictor pred;  // will be uninitialized in some cases, but never used.
293
346k
      };
294
346k
      std::vector<CostInfo> costs_l;
295
346k
      std::vector<CostInfo> costs_r;
296
297
346k
      std::vector<int32_t> counts_above(max_symbols);
298
346k
      std::vector<int32_t> counts_below(max_symbols);
299
300
      // The lower the threshold, the higher the expected noisiness of the
301
      // estimate. Thus, discourage changing predictors.
302
346k
      float change_pred_penalty = 800.0f / (100.0f + threshold);
303
2.60M
      for (size_t prop = 0; prop < num_properties && base_bits > threshold;
304
2.25M
           prop++) {
305
2.25M
        costs_l.clear();
306
2.25M
        costs_r.clear();
307
2.25M
        size_t prop_size = tree_samples.NumPropertyValues(prop);
308
2.25M
        if (extra_bits_increase.size() < prop_size) {
309
813k
          count_increase.resize(prop_size * max_symbols);
310
813k
          extra_bits_increase.resize(prop_size);
311
813k
        }
312
        // Clear prop_value_used_count (which cannot be cleared "on the go")
313
2.25M
        prop_value_used_count.clear();
314
2.25M
        prop_value_used_count.resize(prop_size);
315
316
2.25M
        size_t first_used = prop_size;
317
2.25M
        size_t last_used = 0;
318
319
        // TODO(veluca): consider finding multiple splits along a single
320
        // property at the same time, possibly with a bottom-up approach.
321
2.25M
        if (prop < tree_samples.NumStaticProps()) {
322
79.4M
          for (size_t i = begin; i < end; i++) {
323
79.0M
            size_t p = tree_samples.Property<true>(prop, i);
324
79.0M
            prop_value_used_count[p]++;
325
79.0M
            last_used = std::max(last_used, p);
326
79.0M
            first_used = std::min(first_used, p);
327
79.0M
          }
328
1.83M
        } else {
329
1.83M
          size_t prop_idx = prop - tree_samples.NumStaticProps();
330
322M
          for (size_t i = begin; i < end; i++) {
331
320M
            size_t p = tree_samples.Property<false>(prop_idx, i);
332
320M
            prop_value_used_count[p]++;
333
320M
            last_used = std::max(last_used, p);
334
320M
            first_used = std::min(first_used, p);
335
320M
          }
336
1.83M
        }
337
2.25M
        costs_l.resize(last_used - first_used);
338
2.25M
        costs_r.resize(last_used - first_used);
339
        // For all predictors, compute the right and left costs of each split.
340
4.51M
        for (size_t pred = 0; pred < num_predictors; pred++) {
341
          // Compute cost and histogram increments for each property value.
342
2.25M
          const std::vector<ResidualToken> &rtokens =
343
2.25M
              tree_samples.RTokens(pred);
344
2.25M
          if (prop < tree_samples.NumStaticProps()) {
345
418k
            CollectExtraBitsIncrease<true>(tree_samples, rtokens,
346
418k
                                           count_increase, extra_bits_increase,
347
418k
                                           begin, end, prop, max_symbols);
348
1.83M
          } else {
349
1.83M
            CollectExtraBitsIncrease<false>(
350
1.83M
                tree_samples, rtokens, count_increase, extra_bits_increase,
351
1.83M
                begin, end, prop - tree_samples.NumStaticProps(), max_symbols);
352
1.83M
          }
353
2.25M
          memcpy(counts_above.data(), counts.data() + pred * max_symbols,
354
2.25M
                 max_symbols * sizeof counts_above[0]);
355
2.25M
          memset(counts_below.data(), 0, max_symbols * sizeof counts_below[0]);
356
2.25M
          size_t extra_bits_below = 0;
357
          // Exclude last used: this ensures neither counts_above nor
358
          // counts_below is empty.
359
32.7M
          for (size_t i = first_used; i < last_used; i++) {
360
30.4M
            if (!prop_value_used_count[i]) continue;
361
13.2M
            extra_bits_below += extra_bits_increase[i];
362
            // The increase for this property value has been used, and will not
363
            // be used again: clear it. Also below.
364
13.2M
            extra_bits_increase[i] = 0;
365
674M
            for (size_t sym = 0; sym < max_symbols; sym++) {
366
660M
              counts_above[sym] -= count_increase[i * max_symbols + sym];
367
660M
              counts_below[sym] += count_increase[i * max_symbols + sym];
368
660M
              count_increase[i * max_symbols + sym] = 0;
369
660M
            }
370
13.2M
            float rcost = EstimateBits(counts_above.data(), max_symbols) +
371
13.2M
                          tot_extra_bits[pred] - extra_bits_below;
372
13.2M
            float lcost = EstimateBits(counts_below.data(), max_symbols) +
373
13.2M
                          extra_bits_below;
374
13.2M
            JXL_DASSERT(extra_bits_below <= tot_extra_bits[pred]);
375
13.2M
            float penalty = 0;
376
            // Never discourage moving away from the Weighted predictor.
377
13.2M
            if (tree_samples.PredictorFromIndex(pred) !=
378
13.2M
                    (*tree)[pos].predictor &&
379
0
                (*tree)[pos].predictor != Predictor::Weighted) {
380
0
              penalty = change_pred_penalty;
381
0
            }
382
            // If everything else is equal, disfavour Weighted (slower) and
383
            // favour Zero (faster if it's the only predictor used in a
384
            // group+channel combination)
385
13.2M
            if (tree_samples.PredictorFromIndex(pred) == Predictor::Weighted) {
386
0
              penalty += 1e-8;
387
0
            }
388
13.2M
            if (tree_samples.PredictorFromIndex(pred) == Predictor::Zero) {
389
0
              penalty -= 1e-8;
390
0
            }
391
13.2M
            if (rcost + penalty < costs_r[i - first_used].Cost()) {
392
13.2M
              costs_r[i - first_used].cost = rcost;
393
13.2M
              costs_r[i - first_used].extra_cost = penalty;
394
13.2M
              costs_r[i - first_used].pred =
395
13.2M
                  tree_samples.PredictorFromIndex(pred);
396
13.2M
            }
397
13.2M
            if (lcost + penalty < costs_l[i - first_used].Cost()) {
398
13.2M
              costs_l[i - first_used].cost = lcost;
399
13.2M
              costs_l[i - first_used].extra_cost = penalty;
400
13.2M
              costs_l[i - first_used].pred =
401
13.2M
                  tree_samples.PredictorFromIndex(pred);
402
13.2M
            }
403
13.2M
          }
404
2.25M
        }
405
        // Iterate through the possible splits and find the one with minimum sum
406
        // of costs of the two sides.
407
2.25M
        size_t split = begin;
408
32.7M
        for (size_t i = first_used; i < last_used; i++) {
409
30.4M
          if (!prop_value_used_count[i]) continue;
410
13.2M
          split += prop_value_used_count[i];
411
13.2M
          float rcost = costs_r[i - first_used].cost;
412
13.2M
          float lcost = costs_l[i - first_used].cost;
413
414
13.2M
          bool uses_wp = tree_samples.PropertyFromIndex(prop) == kWPProp ||
415
11.0M
                         costs_l[i - first_used].pred == Predictor::Weighted ||
416
11.0M
                         costs_r[i - first_used].pred == Predictor::Weighted;
417
13.2M
          bool zero_entropy_side = rcost == 0 || lcost == 0;
418
419
13.2M
          SplitInfo &best_ref =
420
13.2M
              tree_samples.PropertyFromIndex(prop) < kNumStaticProperties
421
13.2M
                  ? (zero_entropy_side ? best_split_static_constant
422
90.6k
                                       : best_split_static)
423
13.2M
                  : (uses_wp ? best_split_nonstatic : best_split_nowp);
424
13.2M
          if (lcost + rcost < best_ref.Cost()) {
425
2.40M
            best_ref.prop = prop;
426
2.40M
            best_ref.val = i;
427
2.40M
            best_ref.pos = split;
428
2.40M
            best_ref.lcost = lcost;
429
2.40M
            best_ref.lpred = costs_l[i - first_used].pred;
430
2.40M
            best_ref.rcost = rcost;
431
2.40M
            best_ref.rpred = costs_r[i - first_used].pred;
432
2.40M
          }
433
13.2M
        }
434
        // Clear extra_bits_increase and count_increase for last_used.
435
2.25M
        extra_bits_increase[last_used] = 0;
436
115M
        for (size_t sym = 0; sym < max_symbols; sym++) {
437
113M
          count_increase[last_used * max_symbols + sym] = 0;
438
113M
        }
439
2.25M
      }
440
441
      // Try to avoid introducing WP.
442
346k
      if (best_split_nowp.Cost() + threshold < base_bits &&
443
166k
          best_split_nowp.Cost() <= fast_decode_multiplier * best->Cost()) {
444
159k
        best = &best_split_nowp;
445
159k
      }
446
      // Split along static props if possible and not significantly more
447
      // expensive.
448
346k
      if (best_split_static.Cost() + threshold < base_bits &&
449
30.1k
          best_split_static.Cost() <= fast_decode_multiplier * best->Cost()) {
450
6.40k
        best = &best_split_static;
451
6.40k
      }
452
      // Split along static props to create constant nodes if possible.
453
346k
      if (best_split_static_constant.Cost() + threshold < base_bits) {
454
489
        best = &best_split_static_constant;
455
489
      }
456
346k
    }
457
458
346k
    if (best->Cost() + threshold < base_bits) {
459
172k
      uint32_t p = tree_samples.PropertyFromIndex(best->prop);
460
172k
      pixel_type dequant =
461
172k
          tree_samples.UnquantizeProperty(best->prop, best->val);
462
      // Split node and try to split children.
463
172k
      MakeSplitNode(pos, p, dequant, best->lpred, 0, best->rpred, 0, tree);
464
      // "Sort" according to winning property
465
172k
      if (best->prop < tree_samples.NumStaticProps()) {
466
6.75k
        SplitTreeSamples<true>(tree_samples, begin, best->pos, end, best->prop,
467
6.75k
                               best->val);
468
165k
      } else {
469
165k
        SplitTreeSamples<false>(tree_samples, begin, best->pos, end,
470
165k
                                best->prop - tree_samples.NumStaticProps(),
471
165k
                                best->val);
472
165k
      }
473
172k
      auto new_sp_range = static_prop_range;
474
172k
      if (p < kNumStaticProperties) {
475
6.75k
        JXL_DASSERT(static_cast<uint32_t>(dequant + 1) <= new_sp_range[p][1]);
476
6.75k
        new_sp_range[p][1] = dequant + 1;
477
6.75k
        JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]);
478
6.75k
      }
479
172k
      nodes.push_back(
480
172k
          NodeInfo{(*tree)[pos].rchild, begin, best->pos, new_sp_range});
481
172k
      new_sp_range = static_prop_range;
482
172k
      if (p < kNumStaticProperties) {
483
6.75k
        JXL_DASSERT(new_sp_range[p][0] <= static_cast<uint32_t>(dequant + 1));
484
6.75k
        new_sp_range[p][0] = dequant + 1;
485
6.75k
        JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]);
486
6.75k
      }
487
172k
      nodes.push_back(
488
172k
          NodeInfo{(*tree)[pos].lchild, best->pos, end, new_sp_range});
489
172k
    }
490
346k
  }
491
2.28k
}
Unexecuted instantiation: jxl::N_SSE4::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
jxl::N_AVX2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
Line
Count
Source
161
2.28k
                   float fast_decode_multiplier, Tree *tree) {
162
2.28k
  struct NodeInfo {
163
2.28k
    size_t pos;
164
2.28k
    size_t begin;
165
2.28k
    size_t end;
166
2.28k
    StaticPropRange static_prop_range;
167
2.28k
  };
168
2.28k
  std::vector<NodeInfo> nodes;
169
2.28k
  nodes.push_back(NodeInfo{0, 0, tree_samples.NumDistinctSamples(),
170
2.28k
                           initial_static_prop_range});
171
172
2.28k
  size_t num_predictors = tree_samples.NumPredictors();
173
2.28k
  size_t num_properties = tree_samples.NumProperties();
174
175
  // TODO(veluca): consider parallelizing the search (processing multiple nodes
176
  // at a time).
177
348k
  while (!nodes.empty()) {
178
346k
    size_t pos = nodes.back().pos;
179
346k
    size_t begin = nodes.back().begin;
180
346k
    size_t end = nodes.back().end;
181
182
346k
    StaticPropRange static_prop_range = nodes.back().static_prop_range;
183
346k
    nodes.pop_back();
184
346k
    if (begin == end) continue;
185
186
346k
    struct SplitInfo {
187
346k
      size_t prop = 0;
188
346k
      uint32_t val = 0;
189
346k
      size_t pos = 0;
190
346k
      float lcost = std::numeric_limits<float>::max();
191
346k
      float rcost = std::numeric_limits<float>::max();
192
346k
      Predictor lpred = Predictor::Zero;
193
346k
      Predictor rpred = Predictor::Zero;
194
346k
      float Cost() const { return lcost + rcost; }
195
346k
    };
196
197
346k
    SplitInfo best_split_static_constant;
198
346k
    SplitInfo best_split_static;
199
346k
    SplitInfo best_split_nonstatic;
200
346k
    SplitInfo best_split_nowp;
201
202
346k
    JXL_DASSERT(begin <= end);
203
346k
    JXL_DASSERT(end <= tree_samples.NumDistinctSamples());
204
205
    // Compute the maximum token in the range.
206
346k
    size_t max_symbols = 0;
207
693k
    for (size_t pred = 0; pred < num_predictors; pred++) {
208
57.7M
      for (size_t i = begin; i < end; i++) {
209
57.4M
        uint32_t tok = tree_samples.Token(pred, i);
210
57.4M
        max_symbols = max_symbols > tok + 1 ? max_symbols : tok + 1;
211
57.4M
      }
212
346k
    }
213
346k
    max_symbols = Padded(max_symbols);
214
346k
    std::vector<int32_t> counts(max_symbols * num_predictors);
215
346k
    std::vector<uint32_t> tot_extra_bits(num_predictors);
216
693k
    for (size_t pred = 0; pred < num_predictors; pred++) {
217
346k
      size_t extra_bits = 0;
218
346k
      const std::vector<ResidualToken>& rtokens = tree_samples.RTokens(pred);
219
57.7M
      for (size_t i = begin; i < end; i++) {
220
57.4M
        const ResidualToken& rt = rtokens[i];
221
57.4M
        size_t count = tree_samples.Count(i);
222
57.4M
        size_t eb = rt.nbits * count;
223
57.4M
        counts[pred * max_symbols + rt.tok] += count;
224
57.4M
        extra_bits += eb;
225
57.4M
      }
226
346k
      tot_extra_bits[pred] = extra_bits;
227
346k
    }
228
229
346k
    float base_bits;
230
346k
    {
231
346k
      size_t pred = tree_samples.PredictorIndex((*tree)[pos].predictor);
232
346k
      base_bits =
233
346k
          EstimateBits(counts.data() + pred * max_symbols, max_symbols) +
234
346k
          tot_extra_bits[pred];
235
346k
    }
236
237
346k
    SplitInfo *best = &best_split_nonstatic;
238
239
346k
    SplitInfo forced_split;
240
    // The multiplier ranges cut halfway through the current ranges of static
241
    // properties. We do this even if the current node is not a leaf, to
242
    // minimize the number of nodes in the resulting tree.
243
346k
    for (const auto &mmi : mul_info) {
244
84.8k
      uint32_t axis;
245
84.8k
      uint32_t val;
246
84.8k
      IntersectionType t =
247
84.8k
          BoxIntersects(static_prop_range, mmi.range, axis, val);
248
84.8k
      if (t == IntersectionType::kNone) continue;
249
84.8k
      if (t == IntersectionType::kInside) {
250
84.8k
        (*tree)[pos].multiplier = mmi.multiplier;
251
84.8k
        break;
252
84.8k
      }
253
0
      if (t == IntersectionType::kPartial) {
254
0
        JXL_DASSERT(axis < kNumStaticProperties);
255
0
        forced_split.val = tree_samples.QuantizeStaticProperty(axis, val);
256
0
        forced_split.prop = axis;
257
0
        forced_split.lcost = forced_split.rcost = base_bits / 2 - threshold;
258
0
        forced_split.lpred = forced_split.rpred = (*tree)[pos].predictor;
259
0
        best = &forced_split;
260
0
        best->pos = begin;
261
0
        JXL_DASSERT(best->prop == tree_samples.PropertyFromIndex(best->prop));
262
0
        if (best->prop < tree_samples.NumStaticProps()) {
263
0
        for (size_t x = begin; x < end; x++) {
264
0
          if (tree_samples.Property<true>(best->prop, x) <= best->val) {
265
0
            best->pos++;
266
0
          }
267
0
        }
268
0
      } else {
269
0
        size_t prop = best->prop - tree_samples.NumStaticProps();
270
0
        for (size_t x = begin; x < end; x++) {
271
0
          if (tree_samples.Property<false>(prop, x) <= best->val) {
272
0
            best->pos++;
273
0
          }
274
0
        }
275
0
      }
276
0
        break;
277
0
      }
278
0
    }
279
280
346k
    if (best != &forced_split) {
281
346k
      std::vector<int> prop_value_used_count;
282
346k
      std::vector<int> count_increase;
283
346k
      std::vector<size_t> extra_bits_increase;
284
      // For each property, compute which of its values are used, and what
285
      // tokens correspond to those usages. Then, iterate through the values,
286
      // and compute the entropy of each side of the split (of the form `prop >
287
      // threshold`). Finally, find the split that minimizes the cost.
288
346k
      struct CostInfo {
289
346k
        float cost = std::numeric_limits<float>::max();
290
346k
        float extra_cost = 0;
291
346k
        float Cost() const { return cost + extra_cost; }
292
346k
        Predictor pred;  // will be uninitialized in some cases, but never used.
293
346k
      };
294
346k
      std::vector<CostInfo> costs_l;
295
346k
      std::vector<CostInfo> costs_r;
296
297
346k
      std::vector<int32_t> counts_above(max_symbols);
298
346k
      std::vector<int32_t> counts_below(max_symbols);
299
300
      // The lower the threshold, the higher the expected noisiness of the
301
      // estimate. Thus, discourage changing predictors.
302
346k
      float change_pred_penalty = 800.0f / (100.0f + threshold);
303
2.60M
      for (size_t prop = 0; prop < num_properties && base_bits > threshold;
304
2.25M
           prop++) {
305
2.25M
        costs_l.clear();
306
2.25M
        costs_r.clear();
307
2.25M
        size_t prop_size = tree_samples.NumPropertyValues(prop);
308
2.25M
        if (extra_bits_increase.size() < prop_size) {
309
813k
          count_increase.resize(prop_size * max_symbols);
310
813k
          extra_bits_increase.resize(prop_size);
311
813k
        }
312
        // Clear prop_value_used_count (which cannot be cleared "on the go")
313
2.25M
        prop_value_used_count.clear();
314
2.25M
        prop_value_used_count.resize(prop_size);
315
316
2.25M
        size_t first_used = prop_size;
317
2.25M
        size_t last_used = 0;
318
319
        // TODO(veluca): consider finding multiple splits along a single
320
        // property at the same time, possibly with a bottom-up approach.
321
2.25M
        if (prop < tree_samples.NumStaticProps()) {
322
79.4M
          for (size_t i = begin; i < end; i++) {
323
79.0M
            size_t p = tree_samples.Property<true>(prop, i);
324
79.0M
            prop_value_used_count[p]++;
325
79.0M
            last_used = std::max(last_used, p);
326
79.0M
            first_used = std::min(first_used, p);
327
79.0M
          }
328
1.83M
        } else {
329
1.83M
          size_t prop_idx = prop - tree_samples.NumStaticProps();
330
322M
          for (size_t i = begin; i < end; i++) {
331
320M
            size_t p = tree_samples.Property<false>(prop_idx, i);
332
320M
            prop_value_used_count[p]++;
333
320M
            last_used = std::max(last_used, p);
334
320M
            first_used = std::min(first_used, p);
335
320M
          }
336
1.83M
        }
337
2.25M
        costs_l.resize(last_used - first_used);
338
2.25M
        costs_r.resize(last_used - first_used);
339
        // For all predictors, compute the right and left costs of each split.
340
4.51M
        for (size_t pred = 0; pred < num_predictors; pred++) {
341
          // Compute cost and histogram increments for each property value.
342
2.25M
          const std::vector<ResidualToken> &rtokens =
343
2.25M
              tree_samples.RTokens(pred);
344
2.25M
          if (prop < tree_samples.NumStaticProps()) {
345
418k
            CollectExtraBitsIncrease<true>(tree_samples, rtokens,
346
418k
                                           count_increase, extra_bits_increase,
347
418k
                                           begin, end, prop, max_symbols);
348
1.83M
          } else {
349
1.83M
            CollectExtraBitsIncrease<false>(
350
1.83M
                tree_samples, rtokens, count_increase, extra_bits_increase,
351
1.83M
                begin, end, prop - tree_samples.NumStaticProps(), max_symbols);
352
1.83M
          }
353
2.25M
          memcpy(counts_above.data(), counts.data() + pred * max_symbols,
354
2.25M
                 max_symbols * sizeof counts_above[0]);
355
2.25M
          memset(counts_below.data(), 0, max_symbols * sizeof counts_below[0]);
356
2.25M
          size_t extra_bits_below = 0;
357
          // Exclude last used: this ensures neither counts_above nor
358
          // counts_below is empty.
359
32.7M
          for (size_t i = first_used; i < last_used; i++) {
360
30.4M
            if (!prop_value_used_count[i]) continue;
361
13.2M
            extra_bits_below += extra_bits_increase[i];
362
            // The increase for this property value has been used, and will not
363
            // be used again: clear it. Also below.
364
13.2M
            extra_bits_increase[i] = 0;
365
674M
            for (size_t sym = 0; sym < max_symbols; sym++) {
366
660M
              counts_above[sym] -= count_increase[i * max_symbols + sym];
367
660M
              counts_below[sym] += count_increase[i * max_symbols + sym];
368
660M
              count_increase[i * max_symbols + sym] = 0;
369
660M
            }
370
13.2M
            float rcost = EstimateBits(counts_above.data(), max_symbols) +
371
13.2M
                          tot_extra_bits[pred] - extra_bits_below;
372
13.2M
            float lcost = EstimateBits(counts_below.data(), max_symbols) +
373
13.2M
                          extra_bits_below;
374
13.2M
            JXL_DASSERT(extra_bits_below <= tot_extra_bits[pred]);
375
13.2M
            float penalty = 0;
376
            // Never discourage moving away from the Weighted predictor.
377
13.2M
            if (tree_samples.PredictorFromIndex(pred) !=
378
13.2M
                    (*tree)[pos].predictor &&
379
0
                (*tree)[pos].predictor != Predictor::Weighted) {
380
0
              penalty = change_pred_penalty;
381
0
            }
382
            // If everything else is equal, disfavour Weighted (slower) and
383
            // favour Zero (faster if it's the only predictor used in a
384
            // group+channel combination)
385
13.2M
            if (tree_samples.PredictorFromIndex(pred) == Predictor::Weighted) {
386
0
              penalty += 1e-8;
387
0
            }
388
13.2M
            if (tree_samples.PredictorFromIndex(pred) == Predictor::Zero) {
389
0
              penalty -= 1e-8;
390
0
            }
391
13.2M
            if (rcost + penalty < costs_r[i - first_used].Cost()) {
392
13.2M
              costs_r[i - first_used].cost = rcost;
393
13.2M
              costs_r[i - first_used].extra_cost = penalty;
394
13.2M
              costs_r[i - first_used].pred =
395
13.2M
                  tree_samples.PredictorFromIndex(pred);
396
13.2M
            }
397
13.2M
            if (lcost + penalty < costs_l[i - first_used].Cost()) {
398
13.2M
              costs_l[i - first_used].cost = lcost;
399
13.2M
              costs_l[i - first_used].extra_cost = penalty;
400
13.2M
              costs_l[i - first_used].pred =
401
13.2M
                  tree_samples.PredictorFromIndex(pred);
402
13.2M
            }
403
13.2M
          }
404
2.25M
        }
405
        // Iterate through the possible splits and find the one with minimum sum
406
        // of costs of the two sides.
407
2.25M
        size_t split = begin;
408
32.7M
        for (size_t i = first_used; i < last_used; i++) {
409
30.4M
          if (!prop_value_used_count[i]) continue;
410
13.2M
          split += prop_value_used_count[i];
411
13.2M
          float rcost = costs_r[i - first_used].cost;
412
13.2M
          float lcost = costs_l[i - first_used].cost;
413
414
13.2M
          bool uses_wp = tree_samples.PropertyFromIndex(prop) == kWPProp ||
415
11.0M
                         costs_l[i - first_used].pred == Predictor::Weighted ||
416
11.0M
                         costs_r[i - first_used].pred == Predictor::Weighted;
417
13.2M
          bool zero_entropy_side = rcost == 0 || lcost == 0;
418
419
13.2M
          SplitInfo &best_ref =
420
13.2M
              tree_samples.PropertyFromIndex(prop) < kNumStaticProperties
421
13.2M
                  ? (zero_entropy_side ? best_split_static_constant
422
90.6k
                                       : best_split_static)
423
13.2M
                  : (uses_wp ? best_split_nonstatic : best_split_nowp);
424
13.2M
          if (lcost + rcost < best_ref.Cost()) {
425
2.40M
            best_ref.prop = prop;
426
2.40M
            best_ref.val = i;
427
2.40M
            best_ref.pos = split;
428
2.40M
            best_ref.lcost = lcost;
429
2.40M
            best_ref.lpred = costs_l[i - first_used].pred;
430
2.40M
            best_ref.rcost = rcost;
431
2.40M
            best_ref.rpred = costs_r[i - first_used].pred;
432
2.40M
          }
433
13.2M
        }
434
        // Clear extra_bits_increase and count_increase for last_used.
435
2.25M
        extra_bits_increase[last_used] = 0;
436
115M
        for (size_t sym = 0; sym < max_symbols; sym++) {
437
113M
          count_increase[last_used * max_symbols + sym] = 0;
438
113M
        }
439
2.25M
      }
440
441
      // Try to avoid introducing WP.
442
346k
      if (best_split_nowp.Cost() + threshold < base_bits &&
443
166k
          best_split_nowp.Cost() <= fast_decode_multiplier * best->Cost()) {
444
159k
        best = &best_split_nowp;
445
159k
      }
446
      // Split along static props if possible and not significantly more
447
      // expensive.
448
346k
      if (best_split_static.Cost() + threshold < base_bits &&
449
30.1k
          best_split_static.Cost() <= fast_decode_multiplier * best->Cost()) {
450
6.40k
        best = &best_split_static;
451
6.40k
      }
452
      // Split along static props to create constant nodes if possible.
453
346k
      if (best_split_static_constant.Cost() + threshold < base_bits) {
454
489
        best = &best_split_static_constant;
455
489
      }
456
346k
    }
457
458
346k
    if (best->Cost() + threshold < base_bits) {
459
172k
      uint32_t p = tree_samples.PropertyFromIndex(best->prop);
460
172k
      pixel_type dequant =
461
172k
          tree_samples.UnquantizeProperty(best->prop, best->val);
462
      // Split node and try to split children.
463
172k
      MakeSplitNode(pos, p, dequant, best->lpred, 0, best->rpred, 0, tree);
464
      // "Sort" according to winning property
465
172k
      if (best->prop < tree_samples.NumStaticProps()) {
466
6.75k
        SplitTreeSamples<true>(tree_samples, begin, best->pos, end, best->prop,
467
6.75k
                               best->val);
468
165k
      } else {
469
165k
        SplitTreeSamples<false>(tree_samples, begin, best->pos, end,
470
165k
                                best->prop - tree_samples.NumStaticProps(),
471
165k
                                best->val);
472
165k
      }
473
172k
      auto new_sp_range = static_prop_range;
474
172k
      if (p < kNumStaticProperties) {
475
6.75k
        JXL_DASSERT(static_cast<uint32_t>(dequant + 1) <= new_sp_range[p][1]);
476
6.75k
        new_sp_range[p][1] = dequant + 1;
477
6.75k
        JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]);
478
6.75k
      }
479
172k
      nodes.push_back(
480
172k
          NodeInfo{(*tree)[pos].rchild, begin, best->pos, new_sp_range});
481
172k
      new_sp_range = static_prop_range;
482
172k
      if (p < kNumStaticProperties) {
483
6.75k
        JXL_DASSERT(new_sp_range[p][0] <= static_cast<uint32_t>(dequant + 1));
484
6.75k
        new_sp_range[p][0] = dequant + 1;
485
6.75k
        JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]);
486
6.75k
      }
487
172k
      nodes.push_back(
488
172k
          NodeInfo{(*tree)[pos].lchild, best->pos, end, new_sp_range});
489
172k
    }
490
346k
  }
491
2.28k
}
Unexecuted instantiation: jxl::N_SSE2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
492
493
// NOLINTNEXTLINE(google-readability-namespace-comments)
494
}  // namespace HWY_NAMESPACE
495
}  // namespace jxl
496
HWY_AFTER_NAMESPACE();
497
498
#if HWY_ONCE
499
namespace jxl {
500
501
HWY_EXPORT(FindBestSplit);  // Local function.
502
503
// Builds a decision tree from `tree_samples` into `tree`.
//
// A root leaf is appended to `tree` first: no split property (-1), the
// predictor at index 0 of `tree_samples`, zero predictor offset and a
// multiplier of 1. The target-specific FindBestSplit is then dispatched with
// `threshold`, `mul_info`, `static_prop_range` and `fast_decode_multiplier`
// forwarded unchanged.
//
// The two JXL_ENSURE checks run after the root has been appended; presumably a
// failed check makes the function return a failure status (confirm against the
// macro definition). Otherwise returns true.
Status ComputeBestTree(TreeSamples &tree_samples, float threshold,
                       const std::vector<ModularMultiplierInfo> &mul_info,
                       StaticPropRange static_prop_range,
                       float fast_decode_multiplier, Tree *tree) {
  // TODO(veluca): take into account that different contexts can have different
  // uint configs.

  // Start from a single leaf node that the search below will try to split.
  tree->emplace_back();
  auto &root = tree->back();
  root.property = -1;
  root.predictor = tree_samples.PredictorFromIndex(0);
  root.predictor_offset = 0;
  root.multiplier = 1;

  // Limits on the sample set that the split search relies on.
  JXL_ENSURE(tree_samples.NumProperties() < 64);
  JXL_ENSURE(tree_samples.NumDistinctSamples() <=
             std::numeric_limits<uint32_t>::max());

  HWY_DYNAMIC_DISPATCH(FindBestSplit)
  (tree_samples, threshold, mul_info, static_prop_range, fast_decode_multiplier,
   tree);
  return true;
}
525
526
// Out-of-class definitions for two static constexpr data members of
// TreeSamples. They are only compiled for language levels below C++17, where
// an odr-used static constexpr member still needs a namespace-scope
// definition; from C++17 on such members are implicitly inline.
#if JXL_CXX_LANG < JXL_CXX_17
527
constexpr int32_t TreeSamples::kPropertyRange;
528
constexpr uint32_t TreeSamples::kDedupEntryUnused;
529
#endif
530
531
// Chooses the set of predictors for which residuals are collected and sizes
// `residuals` to have one entry per chosen predictor.
//
//  - TreeMode::kWPOnly: only Predictor::Weighted is used, whatever `predictor`
//    is.
//  - TreeMode::kNoWP combined with Predictor::Weighted is rejected with a
//    failure status.
//  - Predictor::Variable: every predictor id in [0, kNumModularPredictors) is
//    used, with Weighted swapped to index 0 and Gradient swapped to index 1.
//  - Predictor::Best: Weighted and Gradient.
//  - Anything else: just `predictor`.
// In kNoWP mode Weighted is finally removed from the chosen list.
//
// Returns true on success.
Status TreeSamples::SetPredictor(Predictor predictor,
                                 ModularOptions::TreeMode wp_tree_mode) {
  if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) {
    predictors = {Predictor::Weighted};
    residuals.resize(1);
    return true;
  }
  if (wp_tree_mode == ModularOptions::TreeMode::kNoWP &&
      predictor == Predictor::Weighted) {
    return JXL_FAILURE("Invalid predictor settings");
  }
  if (predictor == Predictor::Variable) {
    // Replace (not extend) any previous selection, like every other branch
    // does. The swaps below index by predictor id and are only correct when
    // the list starts out empty.
    predictors.clear();
    predictors.reserve(kNumModularPredictors);
    for (size_t i = 0; i < kNumModularPredictors; i++) {
      predictors.push_back(static_cast<Predictor>(i));
    }
    std::swap(predictors[0], predictors[static_cast<int>(Predictor::Weighted)]);
    std::swap(predictors[1], predictors[static_cast<int>(Predictor::Gradient)]);
  } else if (predictor == Predictor::Best) {
    predictors = {Predictor::Weighted, Predictor::Gradient};
  } else {
    predictors = {predictor};
  }
  if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) {
    predictors.erase(
        std::remove(predictors.begin(), predictors.end(), Predictor::Weighted),
        predictors.end());
  }
  residuals.resize(predictors.size());
  return true;
}
561
562
// Chooses the properties the tree may split on and sizes `props` accordingly.
//
// The list starts as a copy of `properties` and is then adjusted by mode:
// kWPOnly keeps only kWPProp, kGradientOnly keeps only kGradientProp, and
// kNoWP removes kWPProp. An empty resulting list is reported as a failure.
//
// `num_static_props` is set to the number of selected properties whose id is
// below kNumStaticProperties; in debug builds each of them is asserted to sit
// at the list index equal to its id. `props` gets one entry per remaining
// (non-static) property. Returns true on success.
Status TreeSamples::SetProperties(const std::vector<uint32_t> &properties,
                                  ModularOptions::TreeMode wp_tree_mode) {
  props_to_use = properties;
  // The three modes are mutually exclusive, so at most one branch applies.
  if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) {
    props_to_use = {static_cast<uint32_t>(kWPProp)};
  } else if (wp_tree_mode == ModularOptions::TreeMode::kGradientOnly) {
    props_to_use = {static_cast<uint32_t>(kGradientProp)};
  } else if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) {
    const auto kept_end =
        std::remove(props_to_use.begin(), props_to_use.end(), kWPProp);
    props_to_use.erase(kept_end, props_to_use.end());
  }
  if (props_to_use.empty()) {
    return JXL_FAILURE("Invalid property set configuration");
  }

  // Count static properties; when present they must come first, in id order.
  num_static_props = 0;
  for (size_t i = 0; i < props_to_use.size(); ++i) {
    const uint32_t prop = props_to_use[i];
    if (prop >= kNumStaticProperties) continue;
    JXL_DASSERT(i == prop);
    num_static_props++;
  }
  props.resize(props_to_use.size() - num_static_props);
  return true;
}
591
592
5.66k
void TreeSamples::InitTable(size_t log_size) {
593
5.66k
  size_t size = 1ULL << log_size;
594
5.66k
  if (dedup_table_.size() == size) return;
595
4.08k
  dedup_table_.resize(size, kDedupEntryUnused);
596
5.17M
  for (size_t i = 0; i < NumDistinctSamples(); i++) {
597
5.16M
    if (sample_counts[i] != std::numeric_limits<uint16_t>::max()) {
598
5.16M
      AddToTable(i);
599
5.16M
    }
600
5.16M
  }
601
4.08k
}
602
603
44.6M
bool TreeSamples::AddToTableAndMerge(size_t a) {
604
44.6M
  size_t pos1 = Hash1(a);
605
44.6M
  size_t pos2 = Hash2(a);
606
44.6M
  if (dedup_table_[pos1] != kDedupEntryUnused &&
607
40.8M
      IsSameSample(a, dedup_table_[pos1])) {
608
33.9M
    JXL_DASSERT(sample_counts[a] == 1);
609
33.9M
    sample_counts[dedup_table_[pos1]]++;
610
    // Remove from hash table samples that are saturated.
611
33.9M
    if (sample_counts[dedup_table_[pos1]] ==
612
33.9M
        std::numeric_limits<uint16_t>::max()) {
613
18
      dedup_table_[pos1] = kDedupEntryUnused;
614
18
    }
615
33.9M
    return true;
616
33.9M
  }
617
10.6M
  if (dedup_table_[pos2] != kDedupEntryUnused &&
618
6.18M
      IsSameSample(a, dedup_table_[pos2])) {
619
4.51M
    JXL_DASSERT(sample_counts[a] == 1);
620
4.51M
    sample_counts[dedup_table_[pos2]]++;
621
    // Remove from hash table samples that are saturated.
622
4.51M
    if (sample_counts[dedup_table_[pos2]] ==
623
4.51M
        std::numeric_limits<uint16_t>::max()) {
624
0
      dedup_table_[pos2] = kDedupEntryUnused;
625
0
    }
626
4.51M
    return true;
627
4.51M
  }
628
6.11M
  AddToTable(a);
629
6.11M
  return false;
630
10.6M
}
631
632
11.2M
void TreeSamples::AddToTable(size_t a) {
633
11.2M
  size_t pos1 = Hash1(a);
634
11.2M
  size_t pos2 = Hash2(a);
635
11.2M
  if (dedup_table_[pos1] == kDedupEntryUnused) {
636
5.46M
    dedup_table_[pos1] = a;
637
5.81M
  } else if (dedup_table_[pos2] == kDedupEntryUnused) {
638
2.99M
    dedup_table_[pos2] = a;
639
2.99M
  }
640
11.2M
}
641
642
5.66k
void TreeSamples::PrepareForSamples(size_t extra_num_samples) {
643
5.66k
  for (auto &res : residuals) {
644
5.66k
    res.reserve(res.size() + extra_num_samples);
645
5.66k
  }
646
14.3k
  for (size_t i = 0; i < num_static_props; ++i) {
647
8.66k
    static_props[i].reserve(static_props[i].size() + extra_num_samples);
648
8.66k
  }
649
30.9k
  for (auto &p : props) {
650
30.9k
    p.reserve(p.size() + extra_num_samples);
651
30.9k
  }
652
5.66k
  size_t total_num_samples = extra_num_samples + sample_counts.size();
653
5.66k
  size_t next_size = CeilLog2Nonzero(total_num_samples * 3 / 2);
654
5.66k
  InitTable(next_size);
655
5.66k
}
656
657
55.8M
size_t TreeSamples::Hash1(size_t a) const {
658
55.8M
  constexpr uint64_t constant = 0x1e35a7bd;
659
55.8M
  uint64_t h = constant;
660
55.8M
  for (const auto &r : residuals) {
661
55.8M
    h = h * constant + r[a].tok;
662
55.8M
    h = h * constant + r[a].nbits;
663
55.8M
  }
664
130M
  for (size_t i = 0; i < num_static_props; ++i) {
665
74.4M
    h = h * constant + static_props[i][a];
666
74.4M
  }
667
316M
  for (const auto &p : props) {
668
316M
    h = h * constant + p[a];
669
316M
  }
670
55.8M
  return (h >> 16) & (dedup_table_.size() - 1);
671
55.8M
}
672
55.8M
size_t TreeSamples::Hash2(size_t a) const {
673
55.8M
  constexpr uint64_t constant = 0x1e35a7bd1e35a7bd;
674
55.8M
  uint64_t h = constant;
675
130M
  for (size_t i = 0; i < num_static_props; ++i) {
676
74.4M
    h = h * constant ^ static_props[i][a];
677
74.4M
  }
678
316M
  for (const auto &p : props) {
679
316M
    h = h * constant ^ p[a];
680
316M
  }
681
55.8M
  for (const auto &r : residuals) {
682
55.8M
    h = h * constant ^ r[a].tok;
683
55.8M
    h = h * constant ^ r[a].nbits;
684
55.8M
  }
685
55.8M
  return (h >> 16) & (dedup_table_.size() - 1);
686
55.8M
}
687
688
47.0M
bool TreeSamples::IsSameSample(size_t a, size_t b) const {
689
47.0M
  bool ret = true;
690
47.0M
  for (const auto &r : residuals) {
691
47.0M
    if (r[a].tok != r[b].tok) {
692
3.93M
      ret = false;
693
3.93M
    }
694
47.0M
    if (r[a].nbits != r[b].nbits) {
695
2.94M
      ret = false;
696
2.94M
    }
697
47.0M
  }
698
109M
  for (size_t i = 0; i < num_static_props; ++i) {
699
61.9M
    if (static_props[i][a] != static_props[i][b]) {
700
2.47M
      ret = false;
701
2.47M
    }
702
61.9M
  }
703
267M
  for (const auto &p : props) {
704
267M
    if (p[a] != p[b]) {
705
23.7M
      ret = false;
706
23.7M
    }
707
267M
  }
708
47.0M
  return ret;
709
47.0M
}
710
711
// Records one pixel as a training sample.
//
// For every configured predictor the residual `pixel - prediction` is
// tokenized with HybridUintConfig(4, 1, 2) and its (tok, nbits) pair is
// appended to that predictor's residual column. The selected properties are
// quantized and appended to their columns, and the sample starts with a count
// of 1. If an identical sample is already registered in the deduplication
// table the new entry is merged into it and the rows just appended are
// removed again. `num_samples` is incremented in either case.
//
// `predictions` is indexed by predictor id; `properties` is indexed by
// property id.
void TreeSamples::AddSample(pixel_type_w pixel, const Properties &properties,
                            const pixel_type_w *predictions) {
  for (size_t i = 0; i < predictors.size(); i++) {
    const int predictor_id = static_cast<int>(predictors[i]);
    pixel_type v = pixel - predictions[predictor_id];
    uint32_t tok, nbits, bits;
    HybridUintConfig(4, 1, 2).Encode(PackSigned(v), &tok, &nbits, &bits);
    // Both values are stored as bytes below.
    JXL_DASSERT(tok < 256);
    JXL_DASSERT(nbits < 256);
    const ResidualToken token = {static_cast<uint8_t>(tok),
                                 static_cast<uint8_t>(nbits)};
    residuals[i].push_back(token);
  }
  for (size_t s = 0; s < num_static_props; ++s) {
    static_props[s].push_back(QuantizeStaticProperty(s, properties[s]));
  }
  // Non-static properties: `i` indexes props_to_use, `column` indexes props.
  for (size_t i = num_static_props, column = 0; i < props_to_use.size();
       ++i, ++column) {
    props[column].push_back(QuantizeProperty(i, properties[props_to_use[i]]));
  }
  sample_counts.push_back(1);
  num_samples++;

  const size_t new_index = sample_counts.size() - 1;
  if (!AddToTableAndMerge(new_index)) return;

  // Merged into an existing identical sample: drop the rows appended above.
  for (auto &r : residuals) r.pop_back();
  for (size_t s = 0; s < num_static_props; ++s) static_props[s].pop_back();
  for (auto &p : props) p.pop_back();
  sample_counts.pop_back();
}
738
739
9.72M
void TreeSamples::Swap(size_t a, size_t b) {
740
9.72M
  if (a == b) return;
741
9.72M
  for (auto &r : residuals) {
742
9.72M
    std::swap(r[a], r[b]);
743
9.72M
  }
744
22.9M
  for (size_t i = 0; i < num_static_props; ++i) {
745
13.2M
    std::swap(static_props[i][a], static_props[i][b]);
746
13.2M
  }
747
54.8M
  for (auto &p : props) {
748
54.8M
    std::swap(p[a], p[b]);
749
54.8M
  }
750
9.72M
  std::swap(sample_counts[a], sample_counts[b]);
751
9.72M
}
752
753
namespace {
754
std::vector<int32_t> QuantizeHistogram(const std::vector<uint32_t> &histogram,
755
7.43k
                                       size_t num_chunks) {
756
7.43k
  if (histogram.empty() || num_chunks == 0) return {};
757
7.43k
  uint64_t sum = std::accumulate(histogram.begin(), histogram.end(), 0LU);
758
7.43k
  if (sum == 0) return {};
759
  // TODO(veluca): selecting distinct quantiles is likely not the best
760
  // way to go about this.
761
7.20k
  std::vector<int32_t> thresholds;
762
7.20k
  uint64_t cumsum = 0;
763
7.20k
  uint64_t threshold = 1;
764
4.01M
  for (size_t i = 0; i < histogram.size(); i++) {
765
4.00M
    cumsum += histogram[i];
766
4.00M
    if (cumsum * num_chunks >= threshold * sum) {
767
42.7k
      thresholds.push_back(i);
768
734k
      while (cumsum * num_chunks >= threshold * sum) threshold++;
769
42.7k
    }
770
4.00M
  }
771
7.20k
  JXL_DASSERT(thresholds.size() <= num_chunks);
772
  // last value collects all histogram and is not really a threshold
773
7.20k
  thresholds.pop_back();
774
7.20k
  return thresholds;
775
7.43k
}
776
777
std::vector<int32_t> QuantizeSamples(const std::vector<int32_t> &samples,
778
5.03k
                                     size_t num_chunks) {
779
5.03k
  if (samples.empty()) return {};
780
3.90k
  int min = *std::min_element(samples.begin(), samples.end());
781
3.90k
  constexpr int kRange = 512;
782
3.90k
  min = jxl::Clamp1(min, -kRange, kRange);
783
3.90k
  std::vector<uint32_t> counts(2 * kRange + 1);
784
6.63M
  for (int s : samples) {
785
6.63M
    uint32_t sample_offset = jxl::Clamp1(s, -kRange, kRange) - min;
786
6.63M
    counts[sample_offset]++;
787
6.63M
  }
788
3.90k
  std::vector<int32_t> thresholds = QuantizeHistogram(counts, num_chunks);
789
33.3k
  for (auto &v : thresholds) v += min;
790
3.90k
  return thresholds;
791
5.03k
}
792
793
// Builds a lookup table from (biased) property values to threshold indices.
// `to` is resized to `num_pegs`; entry `to[i]` receives the number of leading
// entries of `from` that are strictly smaller than `i - bias`, i.e. the first
// index `v` with `from[v] >= i - bias` (or `from.size()` if there is none).
// This is because the decision node in the tree splits on (property) > i,
// hence everything that is not > of a threshold should be clustered
// together.
template <typename T>
void QuantMap(const std::vector<int32_t> &from, std::vector<T> &to,
              size_t num_pegs, int bias) {
  to.resize(num_pegs);
  // `cursor` only ever moves forward, so the whole table is filled in a single
  // pass over `from`.
  size_t cursor = 0;
  for (size_t peg = 0; peg < num_pegs; ++peg) {
    const int value = static_cast<int>(peg) - bias;
    for (; cursor < from.size() && value > from[cursor]; ++cursor) {
    }
    // The index must be representable in the table's element type.
    JXL_DASSERT(static_cast<T>(cursor) == cursor);
    to[peg] = static_cast<T>(cursor);
  }
}
enc_ma.cc:void jxl::(anonymous namespace)::QuantMap<unsigned short>(std::__1::vector<int, std::__1::allocator<int> > const&, std::__1::vector<unsigned short, std::__1::allocator<unsigned short> >&, unsigned long, int)
Line
Count
Source
799
3.53k
              size_t num_pegs, int bias) {
800
3.53k
  to.resize(num_pegs);
801
3.53k
  size_t mapped = 0;
802
3.61M
  for (size_t i = 0; i < num_pegs; i++) {
803
3.61M
    while (mapped < from.size() && static_cast<int>(i) - bias > from[mapped]) {
804
2.14k
      mapped++;
805
2.14k
    }
806
3.61M
    JXL_DASSERT(static_cast<T>(mapped) == mapped);
807
3.61M
    to[i] = mapped;
808
3.61M
  }
809
3.53k
}
enc_ma.cc:void jxl::(anonymous namespace)::QuantMap<unsigned char>(std::__1::vector<int, std::__1::allocator<int> > const&, std::__1::vector<unsigned char, std::__1::allocator<unsigned char> >&, unsigned long, int)
Line
Count
Source
799
13.6k
              size_t num_pegs, int bias) {
800
13.6k
  to.resize(num_pegs);
801
13.6k
  size_t mapped = 0;
802
13.9M
  for (size_t i = 0; i < num_pegs; i++) {
803
14.2M
    while (mapped < from.size() && static_cast<int>(i) - bias > from[mapped]) {
804
283k
      mapped++;
805
283k
    }
806
13.9M
    JXL_DASSERT(static_cast<T>(mapped) == mapped);
807
13.9M
    to[i] = mapped;
808
13.9M
  }
809
13.6k
}
810
}  // namespace
811
812
void TreeSamples::PreQuantizeProperties(
813
    const StaticPropRange &range,
814
    const std::vector<ModularMultiplierInfo> &multiplier_info,
815
    const std::vector<uint32_t> &group_pixel_count,
816
    const std::vector<uint32_t> &channel_pixel_count,
817
    std::vector<pixel_type> &pixel_samples,
818
2.44k
    std::vector<pixel_type> &diff_samples, size_t max_property_values) {
819
  // If we have forced splits because of multipliers, choose channel and group
820
  // thresholds accordingly.
821
2.44k
  std::vector<int32_t> group_multiplier_thresholds;
822
2.44k
  std::vector<int32_t> channel_multiplier_thresholds;
823
2.44k
  for (const auto &v : multiplier_info) {
824
785
    if (v.range[0][0] != range[0][0]) {
825
0
      channel_multiplier_thresholds.push_back(v.range[0][0] - 1);
826
0
    }
827
785
    if (v.range[0][1] != range[0][1]) {
828
0
      channel_multiplier_thresholds.push_back(v.range[0][1] - 1);
829
0
    }
830
785
    if (v.range[1][0] != range[1][0]) {
831
0
      group_multiplier_thresholds.push_back(v.range[1][0] - 1);
832
0
    }
833
785
    if (v.range[1][1] != range[1][1]) {
834
0
      group_multiplier_thresholds.push_back(v.range[1][1] - 1);
835
0
    }
836
785
  }
837
2.44k
  std::sort(channel_multiplier_thresholds.begin(),
838
2.44k
            channel_multiplier_thresholds.end());
839
2.44k
  channel_multiplier_thresholds.resize(
840
2.44k
      std::unique(channel_multiplier_thresholds.begin(),
841
2.44k
                  channel_multiplier_thresholds.end()) -
842
2.44k
      channel_multiplier_thresholds.begin());
843
2.44k
  std::sort(group_multiplier_thresholds.begin(),
844
2.44k
            group_multiplier_thresholds.end());
845
2.44k
  group_multiplier_thresholds.resize(
846
2.44k
      std::unique(group_multiplier_thresholds.begin(),
847
2.44k
                  group_multiplier_thresholds.end()) -
848
2.44k
      group_multiplier_thresholds.begin());
849
850
2.44k
  compact_properties.resize(props_to_use.size());
851
2.44k
  auto quantize_channel = [&]() {
852
2.44k
    if (!channel_multiplier_thresholds.empty()) {
853
0
      return channel_multiplier_thresholds;
854
0
    }
855
2.44k
    return QuantizeHistogram(channel_pixel_count, max_property_values);
856
2.44k
  };
857
2.44k
  auto quantize_group_id = [&]() {
858
1.08k
    if (!group_multiplier_thresholds.empty()) {
859
0
      return group_multiplier_thresholds;
860
0
    }
861
1.08k
    return QuantizeHistogram(group_pixel_count, max_property_values);
862
1.08k
  };
863
2.44k
  auto quantize_coordinate = [&]() {
864
0
    std::vector<int32_t> quantized;
865
0
    quantized.reserve(max_property_values - 1);
866
0
    for (size_t i = 0; i + 1 < max_property_values; i++) {
867
0
      quantized.push_back((i + 1) * 256 / max_property_values - 1);
868
0
    }
869
0
    return quantized;
870
0
  };
871
2.44k
  std::vector<int32_t> abs_pixel_thresholds;
872
2.44k
  std::vector<int32_t> pixel_thresholds;
873
2.44k
  auto quantize_pixel_property = [&]() {
874
0
    if (pixel_thresholds.empty()) {
875
0
      pixel_thresholds = QuantizeSamples(pixel_samples, max_property_values);
876
0
    }
877
0
    return pixel_thresholds;
878
0
  };
879
2.44k
  auto quantize_abs_pixel_property = [&]() {
880
0
    if (abs_pixel_thresholds.empty()) {
881
0
      quantize_pixel_property();  // Compute the non-abs thresholds.
882
0
      for (auto &v : pixel_samples) v = std::abs(v);
883
0
      abs_pixel_thresholds =
884
0
          QuantizeSamples(pixel_samples, max_property_values);
885
0
    }
886
0
    return abs_pixel_thresholds;
887
0
  };
888
2.44k
  std::vector<int32_t> abs_diff_thresholds;
889
2.44k
  std::vector<int32_t> diff_thresholds;
890
11.1k
  auto quantize_diff_property = [&]() {
891
11.1k
    if (diff_thresholds.empty()) {
892
5.03k
      diff_thresholds = QuantizeSamples(diff_samples, max_property_values);
893
5.03k
    }
894
11.1k
    return diff_thresholds;
895
11.1k
  };
896
2.44k
  auto quantize_abs_diff_property = [&]() {
897
0
    if (abs_diff_thresholds.empty()) {
898
0
      quantize_diff_property();  // Compute the non-abs thresholds.
899
0
      for (auto &v : diff_samples) v = std::abs(v);
900
0
      abs_diff_thresholds = QuantizeSamples(diff_samples, max_property_values);
901
0
    }
902
0
    return abs_diff_thresholds;
903
0
  };
904
2.44k
  auto quantize_wp = [&]() {
905
2.44k
    if (max_property_values < 32) {
906
0
      return std::vector<int32_t>{-127, -63, -31, -15, -7, -3, -1, 0,
907
0
                                  1,    3,   7,   15,  31, 63, 127};
908
0
    }
909
2.44k
    if (max_property_values < 64) {
910
0
      return std::vector<int32_t>{-255, -191, -127, -95, -63, -47, -31, -23,
911
0
                                  -15,  -11,  -7,   -5,  -3,  -1,  0,   1,
912
0
                                  3,    5,    7,    11,  15,  23,  31,  47,
913
0
                                  63,   95,   127,  191, 255};
914
0
    }
915
2.44k
    return std::vector<int32_t>{
916
2.44k
        -255, -223, -191, -159, -127, -111, -95, -79, -63, -55, -47,
917
2.44k
        -39,  -31,  -27,  -23,  -19,  -15,  -13, -11, -9,  -7,  -6,
918
2.44k
        -5,   -4,   -3,   -2,   -1,   0,    1,   2,   3,   4,   5,
919
2.44k
        6,    7,    9,    11,   13,   15,   19,  23,  27,  31,  39,
920
2.44k
        47,   55,   63,   79,   95,   111,  127, 159, 191, 223, 255};
921
2.44k
  };
922
923
2.44k
  property_mapping.resize(props_to_use.size() - num_static_props);
924
19.5k
  for (size_t i = 0; i < props_to_use.size(); i++) {
925
17.1k
    if (props_to_use[i] == 0) {
926
2.44k
      compact_properties[i] = quantize_channel();
927
14.6k
    } else if (props_to_use[i] == 1) {
928
1.08k
      compact_properties[i] = quantize_group_id();
929
13.6k
    } else if (props_to_use[i] == 2 || props_to_use[i] == 3) {
930
0
      compact_properties[i] = quantize_coordinate();
931
13.6k
    } else if (props_to_use[i] == 6 || props_to_use[i] == 7 ||
932
13.6k
               props_to_use[i] == 8 ||
933
13.6k
               (props_to_use[i] >= kNumNonrefProperties &&
934
0
                (props_to_use[i] - kNumNonrefProperties) % 4 == 1)) {
935
0
      compact_properties[i] = quantize_pixel_property();
936
13.6k
    } else if (props_to_use[i] == 4 || props_to_use[i] == 5 ||
937
13.6k
               (props_to_use[i] >= kNumNonrefProperties &&
938
0
                (props_to_use[i] - kNumNonrefProperties) % 4 == 0)) {
939
0
      compact_properties[i] = quantize_abs_pixel_property();
940
13.6k
    } else if (props_to_use[i] >= kNumNonrefProperties &&
941
0
               (props_to_use[i] - kNumNonrefProperties) % 4 == 2) {
942
0
      compact_properties[i] = quantize_abs_diff_property();
943
13.6k
    } else if (props_to_use[i] == kWPProp) {
944
2.44k
      compact_properties[i] = quantize_wp();
945
11.1k
    } else {
946
11.1k
      compact_properties[i] = quantize_diff_property();
947
11.1k
    }
948
17.1k
    if (i < num_static_props) {
949
3.53k
      QuantMap(compact_properties[i], static_property_mapping[i],
950
3.53k
               kPropertyRange * 2 + 1, kPropertyRange);
951
13.6k
    } else {
952
13.6k
      QuantMap(compact_properties[i], property_mapping[i - num_static_props],
953
13.6k
               kPropertyRange * 2 + 1, kPropertyRange);
954
13.6k
    }
955
17.1k
  }
956
2.44k
}
957
958
void CollectPixelSamples(const Image &image, const ModularOptions &options,
959
                         uint32_t group_id,
960
                         std::vector<uint32_t> &group_pixel_count,
961
                         std::vector<uint32_t> &channel_pixel_count,
962
                         std::vector<pixel_type> &pixel_samples,
963
3.57k
                         std::vector<pixel_type> &diff_samples) {
964
3.57k
  if (options.nb_repeats == 0) return;
965
3.57k
  if (group_pixel_count.size() <= group_id) {
966
3.57k
    group_pixel_count.resize(group_id + 1);
967
3.57k
  }
968
3.57k
  if (channel_pixel_count.size() < image.channel.size()) {
969
2.49k
    channel_pixel_count.resize(image.channel.size());
970
2.49k
  }
971
3.57k
  Rng rng(group_id);
972
  // Sample 10% of the final number of samples for property quantization.
973
3.57k
  float fraction = std::min(options.nb_repeats * 0.1, 0.99);
974
3.57k
  Rng::GeometricDistribution dist = Rng::MakeGeometric(fraction);
975
3.57k
  size_t total_pixels = 0;
976
3.57k
  std::vector<size_t> channel_ids;
977
9.23k
  for (size_t i = 0; i < image.channel.size(); i++) {
978
6.10k
    if (i >= image.nb_meta_channels &&
979
5.48k
        (image.channel[i].w > options.max_chan_size ||
980
5.07k
         image.channel[i].h > options.max_chan_size)) {
981
443
      break;
982
443
    }
983
5.66k
    if (image.channel[i].w <= 1 || image.channel[i].h == 0) {
984
60
      continue;  // skip empty or width-1 channels.
985
60
    }
986
5.60k
    channel_ids.push_back(i);
987
5.60k
    group_pixel_count[group_id] += image.channel[i].w * image.channel[i].h;
988
5.60k
    channel_pixel_count[i] += image.channel[i].w * image.channel[i].h;
989
5.60k
    total_pixels += image.channel[i].w * image.channel[i].h;
990
5.60k
  }
991
3.57k
  if (channel_ids.empty()) return;
992
3.34k
  pixel_samples.reserve(pixel_samples.size() + fraction * total_pixels);
993
3.34k
  diff_samples.reserve(diff_samples.size() + fraction * total_pixels);
994
3.34k
  size_t i = 0;
995
3.34k
  size_t y = 0;
996
3.34k
  size_t x = 0;
997
4.39M
  auto advance = [&](size_t amount) {
998
4.39M
    x += amount;
999
    // Detect row overflow (rare).
1000
4.92M
    while (x >= image.channel[channel_ids[i]].w) {
1001
533k
      x -= image.channel[channel_ids[i]].w;
1002
533k
      y++;
1003
      // Detect end-of-channel (even rarer).
1004
533k
      if (y == image.channel[channel_ids[i]].h) {
1005
5.60k
        i++;
1006
5.60k
        y = 0;
1007
5.60k
        if (i >= channel_ids.size()) {
1008
3.34k
          return;
1009
3.34k
        }
1010
5.60k
      }
1011
533k
    }
1012
4.39M
  };
1013
3.34k
  advance(rng.Geometric(dist));
1014
4.39M
  for (; i < channel_ids.size(); advance(rng.Geometric(dist) + 1)) {
1015
4.39M
    const pixel_type *row = image.channel[channel_ids[i]].Row(y);
1016
4.39M
    pixel_samples.push_back(row[x]);
1017
4.39M
    size_t xp = x == 0 ? 1 : x - 1;
1018
4.39M
    diff_samples.push_back(static_cast<int64_t>(row[x]) - row[xp]);
1019
4.39M
  }
1020
3.34k
}
1021
1022
// TODO(veluca): very simple encoding scheme. This should be improved.
1023
// Serializes `tree` into `tokens`, visiting nodes in breadth-first order
// starting from node 0, and rebuilds the same tree in that visiting order into
// `decoder_tree` (which is cleared first).
//
// For a leaf (property == -1) the predictor, packed predictor offset and the
// multiplier (split into its trailing-zero count and the remaining bits minus
// one) are emitted; for an inner node the packed split value is emitted.
// Fails if the tree exceeds kMaxTreeSize, a node has property < -1, or a leaf
// predictor is not below Predictor::Best.
Status TokenizeTree(const Tree &tree, std::vector<Token> *tokens,
                    Tree *decoder_tree) {
  JXL_ENSURE(tree.size() <= kMaxTreeSize);
  decoder_tree->clear();
  size_t num_leaves = 0;
  std::queue<int> pending;
  pending.push(0);
  while (!pending.empty()) {
    const int index = pending.front();
    pending.pop();
    const auto &node = tree[index];
    JXL_ENSURE(node.property >= -1);
    tokens->emplace_back(kPropertyContext, node.property + 1);
    if (node.property != -1) {
      // Inner node: its children will be placed right after every node that
      // is already emitted or still waiting in the queue.
      const size_t next_slot = decoder_tree->size() + pending.size();
      decoder_tree->emplace_back(node.property, node.splitval,
                                 static_cast<int>(next_slot + 1),
                                 static_cast<int>(next_slot + 2),
                                 Predictor::Zero, 0, 1);
      pending.push(node.lchild);
      pending.push(node.rchild);
      tokens->emplace_back(kSplitValContext, PackSigned(node.splitval));
    } else {
      // Leaf node.
      tokens->emplace_back(kPredictorContext,
                           static_cast<int>(node.predictor));
      tokens->emplace_back(kOffsetContext, PackSigned(node.predictor_offset));
      uint32_t mul_log = Num0BitsBelowLS1Bit_Nonzero(node.multiplier);
      uint32_t mul_bits = (node.multiplier >> mul_log) - 1;
      tokens->emplace_back(kMultiplierLogContext, mul_log);
      tokens->emplace_back(kMultiplierBitsContext, mul_bits);
      JXL_ENSURE(node.predictor < Predictor::Best);
      decoder_tree->emplace_back(-1, 0, static_cast<int>(num_leaves), 0,
                                 node.predictor, node.predictor_offset,
                                 node.multiplier);
      ++num_leaves;
    }
  }
  return true;
}
1062
1063
}  // namespace jxl
1064
#endif  // HWY_ONCE