Coverage Report

Created: 2025-08-12 07:37

/src/libjxl/lib/jxl/modular/encoding/enc_ma.cc
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#include "lib/jxl/modular/encoding/enc_ma.h"
7
8
#include <algorithm>
9
#include <cstdint>
10
#include <cstdlib>
11
#include <cstring>
12
#include <limits>
13
#include <numeric>
14
#include <queue>
15
#include <vector>
16
17
#include "lib/jxl/ans_params.h"
18
#include "lib/jxl/base/bits.h"
19
#include "lib/jxl/base/common.h"
20
#include "lib/jxl/base/compiler_specific.h"
21
#include "lib/jxl/base/status.h"
22
#include "lib/jxl/dec_ans.h"
23
#include "lib/jxl/modular/encoding/dec_ma.h"
24
#include "lib/jxl/modular/encoding/ma_common.h"
25
#include "lib/jxl/modular/modular_image.h"
26
27
#undef HWY_TARGET_INCLUDE
28
#define HWY_TARGET_INCLUDE "lib/jxl/modular/encoding/enc_ma.cc"
29
#include <hwy/foreach_target.h>
30
#include <hwy/highway.h>
31
32
#include "lib/jxl/base/fast_math-inl.h"
33
#include "lib/jxl/base/random.h"
34
#include "lib/jxl/enc_ans.h"
35
#include "lib/jxl/modular/encoding/context_predict.h"
36
#include "lib/jxl/modular/options.h"
37
#include "lib/jxl/pack_signed.h"
38
HWY_BEFORE_NAMESPACE();
39
namespace jxl {
40
namespace HWY_NAMESPACE {
41
42
// These templates are not found via ADL.
43
using hwy::HWY_NAMESPACE::Eq;
44
using hwy::HWY_NAMESPACE::IfThenElse;
45
using hwy::HWY_NAMESPACE::Lt;
46
using hwy::HWY_NAMESPACE::Max;
47
48
const HWY_FULL(float) df;
49
const HWY_FULL(int32_t) di;
50
5.05k
size_t Padded(size_t x) { return RoundUpTo(x, Lanes(df)); }
Unexecuted instantiation: jxl::N_SSE4::Padded(unsigned long)
jxl::N_AVX2::Padded(unsigned long)
Line
Count
Source
50
5.05k
size_t Padded(size_t x) { return RoundUpTo(x, Lanes(df)); }
Unexecuted instantiation: jxl::N_SSE2::Padded(unsigned long)
51
52
// Compute entropy of the histogram, taking into account the minimum probability
53
// for symbols with non-zero counts.
54
376k
float EstimateBits(const int32_t *counts, size_t num_symbols) {
55
376k
  int32_t total = std::accumulate(counts, counts + num_symbols, 0);
56
376k
  const auto zero = Zero(df);
57
376k
  const auto minprob = Set(df, 1.0f / ANS_TAB_SIZE);
58
376k
  const auto inv_total = Set(df, 1.0f / total);
59
376k
  auto bits_lanes = Zero(df);
60
376k
  auto total_v = Set(di, total);
61
2.61M
  for (size_t i = 0; i < num_symbols; i += Lanes(df)) {
62
2.23M
    const auto counts_iv = LoadU(di, &counts[i]);
63
2.23M
    const auto counts_fv = ConvertTo(df, counts_iv);
64
2.23M
    const auto probs = Mul(counts_fv, inv_total);
65
2.23M
    const auto mprobs = Max(probs, minprob);
66
2.23M
    const auto nbps = IfThenElse(Eq(counts_iv, total_v), BitCast(di, zero),
67
2.23M
                                 BitCast(di, FastLog2f(df, mprobs)));
68
2.23M
    bits_lanes = Sub(bits_lanes, Mul(counts_fv, BitCast(df, nbps)));
69
2.23M
  }
70
376k
  return GetLane(SumOfLanes(df, bits_lanes));
71
376k
}
Unexecuted instantiation: jxl::N_SSE4::EstimateBits(int const*, unsigned long)
jxl::N_AVX2::EstimateBits(int const*, unsigned long)
Line
Count
Source
54
376k
float EstimateBits(const int32_t *counts, size_t num_symbols) {
55
376k
  int32_t total = std::accumulate(counts, counts + num_symbols, 0);
56
376k
  const auto zero = Zero(df);
57
376k
  const auto minprob = Set(df, 1.0f / ANS_TAB_SIZE);
58
376k
  const auto inv_total = Set(df, 1.0f / total);
59
376k
  auto bits_lanes = Zero(df);
60
376k
  auto total_v = Set(di, total);
61
2.61M
  for (size_t i = 0; i < num_symbols; i += Lanes(df)) {
62
2.23M
    const auto counts_iv = LoadU(di, &counts[i]);
63
2.23M
    const auto counts_fv = ConvertTo(df, counts_iv);
64
2.23M
    const auto probs = Mul(counts_fv, inv_total);
65
2.23M
    const auto mprobs = Max(probs, minprob);
66
2.23M
    const auto nbps = IfThenElse(Eq(counts_iv, total_v), BitCast(di, zero),
67
2.23M
                                 BitCast(di, FastLog2f(df, mprobs)));
68
2.23M
    bits_lanes = Sub(bits_lanes, Mul(counts_fv, BitCast(df, nbps)));
69
2.23M
  }
70
376k
  return GetLane(SumOfLanes(df, bits_lanes));
71
376k
}
Unexecuted instantiation: jxl::N_SSE2::EstimateBits(int const*, unsigned long)
72
73
void MakeSplitNode(size_t pos, int property, int splitval, Predictor lpred,
74
2.49k
                   int64_t loff, Predictor rpred, int64_t roff, Tree *tree) {
75
  // Note that the tree splits on *strictly greater*.
76
2.49k
  (*tree)[pos].lchild = tree->size();
77
2.49k
  (*tree)[pos].rchild = tree->size() + 1;
78
2.49k
  (*tree)[pos].splitval = splitval;
79
2.49k
  (*tree)[pos].property = property;
80
2.49k
  tree->emplace_back();
81
2.49k
  tree->back().property = -1;
82
2.49k
  tree->back().predictor = rpred;
83
2.49k
  tree->back().predictor_offset = roff;
84
2.49k
  tree->back().multiplier = 1;
85
2.49k
  tree->emplace_back();
86
2.49k
  tree->back().property = -1;
87
2.49k
  tree->back().predictor = lpred;
88
2.49k
  tree->back().predictor_offset = loff;
89
2.49k
  tree->back().multiplier = 1;
90
2.49k
}
Unexecuted instantiation: jxl::N_SSE4::MakeSplitNode(unsigned long, int, int, jxl::Predictor, long, jxl::Predictor, long, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
jxl::N_AVX2::MakeSplitNode(unsigned long, int, int, jxl::Predictor, long, jxl::Predictor, long, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
Line
Count
Source
74
2.49k
                   int64_t loff, Predictor rpred, int64_t roff, Tree *tree) {
75
  // Note that the tree splits on *strictly greater*.
76
2.49k
  (*tree)[pos].lchild = tree->size();
77
2.49k
  (*tree)[pos].rchild = tree->size() + 1;
78
2.49k
  (*tree)[pos].splitval = splitval;
79
2.49k
  (*tree)[pos].property = property;
80
2.49k
  tree->emplace_back();
81
2.49k
  tree->back().property = -1;
82
2.49k
  tree->back().predictor = rpred;
83
2.49k
  tree->back().predictor_offset = roff;
84
2.49k
  tree->back().multiplier = 1;
85
2.49k
  tree->emplace_back();
86
2.49k
  tree->back().property = -1;
87
2.49k
  tree->back().predictor = lpred;
88
2.49k
  tree->back().predictor_offset = loff;
89
2.49k
  tree->back().multiplier = 1;
90
2.49k
}
Unexecuted instantiation: jxl::N_SSE2::MakeSplitNode(unsigned long, int, int, jxl::Predictor, long, jxl::Predictor, long, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
91
92
enum class IntersectionType { kNone, kPartial, kInside };
93
IntersectionType BoxIntersects(StaticPropRange needle, StaticPropRange haystack,
94
5.03k
                               uint32_t &partial_axis, uint32_t &partial_val) {
95
5.03k
  bool partial = false;
96
15.0k
  for (size_t i = 0; i < kNumStaticProperties; i++) {
97
10.0k
    if (haystack[i][0] >= needle[i][1]) {
98
0
      return IntersectionType::kNone;
99
0
    }
100
10.0k
    if (haystack[i][1] <= needle[i][0]) {
101
0
      return IntersectionType::kNone;
102
0
    }
103
10.0k
    if (haystack[i][0] <= needle[i][0] && haystack[i][1] >= needle[i][1]) {
104
10.0k
      continue;
105
10.0k
    }
106
0
    partial = true;
107
0
    partial_axis = i;
108
0
    if (haystack[i][0] > needle[i][0] && haystack[i][0] < needle[i][1]) {
109
0
      partial_val = haystack[i][0] - 1;
110
0
    } else {
111
0
      JXL_DASSERT(haystack[i][1] > needle[i][0] &&
112
0
                  haystack[i][1] < needle[i][1]);
113
0
      partial_val = haystack[i][1] - 1;
114
0
    }
115
0
  }
116
5.03k
  return partial ? IntersectionType::kPartial : IntersectionType::kInside;
117
5.03k
}
Unexecuted instantiation: jxl::N_SSE4::BoxIntersects(std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, unsigned int&, unsigned int&)
jxl::N_AVX2::BoxIntersects(std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, unsigned int&, unsigned int&)
Line
Count
Source
94
5.03k
                               uint32_t &partial_axis, uint32_t &partial_val) {
95
5.03k
  bool partial = false;
96
15.0k
  for (size_t i = 0; i < kNumStaticProperties; i++) {
97
10.0k
    if (haystack[i][0] >= needle[i][1]) {
98
0
      return IntersectionType::kNone;
99
0
    }
100
10.0k
    if (haystack[i][1] <= needle[i][0]) {
101
0
      return IntersectionType::kNone;
102
0
    }
103
10.0k
    if (haystack[i][0] <= needle[i][0] && haystack[i][1] >= needle[i][1]) {
104
10.0k
      continue;
105
10.0k
    }
106
0
    partial = true;
107
0
    partial_axis = i;
108
0
    if (haystack[i][0] > needle[i][0] && haystack[i][0] < needle[i][1]) {
109
0
      partial_val = haystack[i][0] - 1;
110
0
    } else {
111
0
      JXL_DASSERT(haystack[i][1] > needle[i][0] &&
112
0
                  haystack[i][1] < needle[i][1]);
113
0
      partial_val = haystack[i][1] - 1;
114
0
    }
115
0
  }
116
5.03k
  return partial ? IntersectionType::kPartial : IntersectionType::kInside;
117
5.03k
}
Unexecuted instantiation: jxl::N_SSE2::BoxIntersects(std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, unsigned int&, unsigned int&)
118
119
template<bool S>
120
void SplitTreeSamples(TreeSamples &tree_samples, size_t begin, size_t pos,
121
2.49k
                      size_t end, size_t prop, uint32_t val) {
122
2.49k
  size_t begin_pos = begin;
123
2.49k
  size_t end_pos = pos;
124
193k
  do {
125
476k
    while (begin_pos < pos &&
126
476k
           tree_samples.Property<S>(prop, begin_pos) <= val) {
127
283k
      ++begin_pos;
128
283k
    }
129
476k
    while (end_pos < end && tree_samples.Property<S>(prop, end_pos) > val) {
130
283k
      ++end_pos;
131
283k
    }
132
193k
    if (begin_pos < pos && end_pos < end) {
133
192k
      tree_samples.Swap(begin_pos, end_pos);
134
192k
    }
135
193k
    ++begin_pos;
136
193k
    ++end_pos;
137
193k
  } while (begin_pos < pos && end_pos < end);
138
2.49k
}
Unexecuted instantiation: void jxl::N_SSE4::SplitTreeSamples<true>(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int)
Unexecuted instantiation: void jxl::N_SSE4::SplitTreeSamples<false>(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int)
void jxl::N_AVX2::SplitTreeSamples<true>(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int)
Line
Count
Source
121
318
                      size_t end, size_t prop, uint32_t val) {
122
318
  size_t begin_pos = begin;
123
318
  size_t end_pos = pos;
124
32.2k
  do {
125
99.3k
    while (begin_pos < pos &&
126
99.3k
           tree_samples.Property<S>(prop, begin_pos) <= val) {
127
67.0k
      ++begin_pos;
128
67.0k
    }
129
116k
    while (end_pos < end && tree_samples.Property<S>(prop, end_pos) > val) {
130
84.1k
      ++end_pos;
131
84.1k
    }
132
32.2k
    if (begin_pos < pos && end_pos < end) {
133
32.1k
      tree_samples.Swap(begin_pos, end_pos);
134
32.1k
    }
135
32.2k
    ++begin_pos;
136
32.2k
    ++end_pos;
137
32.2k
  } while (begin_pos < pos && end_pos < end);
138
318
}
void jxl::N_AVX2::SplitTreeSamples<false>(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int)
Line
Count
Source
121
2.17k
                      size_t end, size_t prop, uint32_t val) {
122
2.17k
  size_t begin_pos = begin;
123
2.17k
  size_t end_pos = pos;
124
161k
  do {
125
377k
    while (begin_pos < pos &&
126
377k
           tree_samples.Property<S>(prop, begin_pos) <= val) {
127
216k
      ++begin_pos;
128
216k
    }
129
360k
    while (end_pos < end && tree_samples.Property<S>(prop, end_pos) > val) {
130
199k
      ++end_pos;
131
199k
    }
132
161k
    if (begin_pos < pos && end_pos < end) {
133
160k
      tree_samples.Swap(begin_pos, end_pos);
134
160k
    }
135
161k
    ++begin_pos;
136
161k
    ++end_pos;
137
161k
  } while (begin_pos < pos && end_pos < end);
138
2.17k
}
Unexecuted instantiation: void jxl::N_SSE2::SplitTreeSamples<true>(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int)
Unexecuted instantiation: void jxl::N_SSE2::SplitTreeSamples<false>(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int)
139
140
template <bool S>
141
void CollectExtraBitsIncrease(TreeSamples &tree_samples,
142
                              const std::vector<ResidualToken> &rtokens,
143
                              std::vector<int> &count_increase,
144
                              std::vector<size_t> &extra_bits_increase,
145
                              size_t begin, size_t end, size_t prop_idx,
146
35.0k
                              size_t max_symbols) {
147
7.84M
  for (size_t i2 = begin; i2 < end; i2++) {
148
7.81M
    const ResidualToken &rt = rtokens[i2];
149
7.81M
    size_t cnt = tree_samples.Count(i2);
150
7.81M
    size_t p = tree_samples.Property<S>(prop_idx, i2);
151
7.81M
    size_t sym = rt.tok;
152
7.81M
    size_t ebi = rt.nbits * cnt;
153
7.81M
    count_increase[p * max_symbols + sym] += cnt;
154
7.81M
    extra_bits_increase[p] += ebi;
155
7.81M
  }
156
35.0k
}
Unexecuted instantiation: void jxl::N_SSE4::CollectExtraBitsIncrease<true>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long)
Unexecuted instantiation: void jxl::N_SSE4::CollectExtraBitsIncrease<false>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long)
void jxl::N_AVX2::CollectExtraBitsIncrease<true>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long)
Line
Count
Source
146
10.0k
                              size_t max_symbols) {
147
2.24M
  for (size_t i2 = begin; i2 < end; i2++) {
148
2.23M
    const ResidualToken &rt = rtokens[i2];
149
2.23M
    size_t cnt = tree_samples.Count(i2);
150
2.23M
    size_t p = tree_samples.Property<S>(prop_idx, i2);
151
2.23M
    size_t sym = rt.tok;
152
2.23M
    size_t ebi = rt.nbits * cnt;
153
2.23M
    count_increase[p * max_symbols + sym] += cnt;
154
2.23M
    extra_bits_increase[p] += ebi;
155
2.23M
  }
156
10.0k
}
void jxl::N_AVX2::CollectExtraBitsIncrease<false>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long)
Line
Count
Source
146
25.0k
                              size_t max_symbols) {
147
5.60M
  for (size_t i2 = begin; i2 < end; i2++) {
148
5.58M
    const ResidualToken &rt = rtokens[i2];
149
5.58M
    size_t cnt = tree_samples.Count(i2);
150
5.58M
    size_t p = tree_samples.Property<S>(prop_idx, i2);
151
5.58M
    size_t sym = rt.tok;
152
5.58M
    size_t ebi = rt.nbits * cnt;
153
5.58M
    count_increase[p * max_symbols + sym] += cnt;
154
5.58M
    extra_bits_increase[p] += ebi;
155
5.58M
  }
156
25.0k
}
Unexecuted instantiation: void jxl::N_SSE2::CollectExtraBitsIncrease<true>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long)
Unexecuted instantiation: void jxl::N_SSE2::CollectExtraBitsIncrease<false>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long)
157
158
void FindBestSplit(TreeSamples &tree_samples, float threshold,
159
                   const std::vector<ModularMultiplierInfo> &mul_info,
160
                   StaticPropRange initial_static_prop_range,
161
73
                   float fast_decode_multiplier, Tree *tree) {
162
73
  struct NodeInfo {
163
73
    size_t pos;
164
73
    size_t begin;
165
73
    size_t end;
166
73
    uint64_t used_properties;
167
73
    StaticPropRange static_prop_range;
168
73
  };
169
73
  std::vector<NodeInfo> nodes;
170
73
  nodes.push_back(NodeInfo{0, 0, tree_samples.NumDistinctSamples(), 0,
171
73
                           initial_static_prop_range});
172
173
73
  size_t num_predictors = tree_samples.NumPredictors();
174
73
  size_t num_properties = tree_samples.NumProperties();
175
176
  // TODO(veluca): consider parallelizing the search (processing multiple nodes
177
  // at a time).
178
5.12k
  while (!nodes.empty()) {
179
5.05k
    size_t pos = nodes.back().pos;
180
5.05k
    size_t begin = nodes.back().begin;
181
5.05k
    size_t end = nodes.back().end;
182
5.05k
    uint64_t used_properties = nodes.back().used_properties;
183
5.05k
    StaticPropRange static_prop_range = nodes.back().static_prop_range;
184
5.05k
    nodes.pop_back();
185
5.05k
    if (begin == end) continue;
186
187
5.05k
    struct SplitInfo {
188
5.05k
      size_t prop = 0;
189
5.05k
      uint32_t val = 0;
190
5.05k
      size_t pos = 0;
191
5.05k
      float lcost = std::numeric_limits<float>::max();
192
5.05k
      float rcost = std::numeric_limits<float>::max();
193
5.05k
      Predictor lpred = Predictor::Zero;
194
5.05k
      Predictor rpred = Predictor::Zero;
195
211k
      float Cost() const { return lcost + rcost; }
Unexecuted instantiation: enc_ma.cc:jxl::N_SSE4::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::SplitInfo::Cost() const
enc_ma.cc:jxl::N_AVX2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::SplitInfo::Cost() const
Line
Count
Source
195
211k
      float Cost() const { return lcost + rcost; }
Unexecuted instantiation: enc_ma.cc:jxl::N_SSE2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::SplitInfo::Cost() const
196
5.05k
    };
197
198
5.05k
    SplitInfo best_split_static_constant;
199
5.05k
    SplitInfo best_split_static;
200
5.05k
    SplitInfo best_split_nonstatic;
201
5.05k
    SplitInfo best_split_nowp;
202
203
5.05k
    JXL_DASSERT(begin <= end);
204
5.05k
    JXL_DASSERT(end <= tree_samples.NumDistinctSamples());
205
206
    // Compute the maximum token in the range.
207
5.05k
    size_t max_symbols = 0;
208
10.1k
    for (size_t pred = 0; pred < num_predictors; pred++) {
209
1.12M
      for (size_t i = begin; i < end; i++) {
210
1.11M
        uint32_t tok = tree_samples.Token(pred, i);
211
1.11M
        max_symbols = max_symbols > tok + 1 ? max_symbols : tok + 1;
212
1.11M
      }
213
5.05k
    }
214
5.05k
    max_symbols = Padded(max_symbols);
215
5.05k
    std::vector<int32_t> counts(max_symbols * num_predictors);
216
5.05k
    std::vector<uint32_t> tot_extra_bits(num_predictors);
217
10.1k
    for (size_t pred = 0; pred < num_predictors; pred++) {
218
5.05k
      size_t extra_bits = 0;
219
5.05k
      const std::vector<ResidualToken>& rtokens = tree_samples.RTokens(pred);
220
1.12M
      for (size_t i = begin; i < end; i++) {
221
1.11M
        const ResidualToken& rt = rtokens[i];
222
1.11M
        size_t count = tree_samples.Count(i);
223
1.11M
        size_t eb = rt.nbits * count;
224
1.11M
        counts[pred * max_symbols + rt.tok] += count;
225
1.11M
        extra_bits += eb;
226
1.11M
      }
227
5.05k
      tot_extra_bits[pred] = extra_bits;
228
5.05k
    }
229
230
5.05k
    float base_bits;
231
5.05k
    {
232
5.05k
      size_t pred = tree_samples.PredictorIndex((*tree)[pos].predictor);
233
5.05k
      base_bits =
234
5.05k
          EstimateBits(counts.data() + pred * max_symbols, max_symbols) +
235
5.05k
          tot_extra_bits[pred];
236
5.05k
    }
237
238
5.05k
    SplitInfo *best = &best_split_nonstatic;
239
240
5.05k
    SplitInfo forced_split;
241
    // The multiplier ranges cut halfway through the current ranges of static
242
    // properties. We do this even if the current node is not a leaf, to
243
    // minimize the number of nodes in the resulting tree.
244
5.05k
    for (const auto &mmi : mul_info) {
245
5.03k
      uint32_t axis;
246
5.03k
      uint32_t val;
247
5.03k
      IntersectionType t =
248
5.03k
          BoxIntersects(static_prop_range, mmi.range, axis, val);
249
5.03k
      if (t == IntersectionType::kNone) continue;
250
5.03k
      if (t == IntersectionType::kInside) {
251
5.03k
        (*tree)[pos].multiplier = mmi.multiplier;
252
5.03k
        break;
253
5.03k
      }
254
0
      if (t == IntersectionType::kPartial) {
255
0
        JXL_DASSERT(axis < kNumStaticProperties);
256
0
        forced_split.val = tree_samples.QuantizeStaticProperty(axis, val);
257
0
        forced_split.prop = axis;
258
0
        forced_split.lcost = forced_split.rcost = base_bits / 2 - threshold;
259
0
        forced_split.lpred = forced_split.rpred = (*tree)[pos].predictor;
260
0
        best = &forced_split;
261
0
        best->pos = begin;
262
0
        JXL_DASSERT(best->prop == tree_samples.PropertyFromIndex(best->prop));
263
0
        if (best->prop < tree_samples.NumStaticProps()) {
264
0
        for (size_t x = begin; x < end; x++) {
265
0
          if (tree_samples.Property<true>(best->prop, x) <= best->val) {
266
0
            best->pos++;
267
0
          }
268
0
        }
269
0
      } else {
270
0
        size_t prop = best->prop - tree_samples.NumStaticProps();
271
0
        for (size_t x = begin; x < end; x++) {
272
0
          if (tree_samples.Property<false>(prop, x) <= best->val) {
273
0
            best->pos++;
274
0
          }
275
0
        }
276
0
      }
277
0
        break;
278
0
      }
279
0
    }
280
281
5.05k
    if (best != &forced_split) {
282
5.05k
      std::vector<int> prop_value_used_count;
283
5.05k
      std::vector<int> count_increase;
284
5.05k
      std::vector<size_t> extra_bits_increase;
285
      // For each property, compute which of its values are used, and what
286
      // tokens correspond to those usages. Then, iterate through the values,
287
      // and compute the entropy of each side of the split (of the form `prop >
288
      // threshold`). Finally, find the split that minimizes the cost.
289
5.05k
      struct CostInfo {
290
5.05k
        float cost = std::numeric_limits<float>::max();
291
5.05k
        float extra_cost = 0;
292
371k
        float Cost() const { return cost + extra_cost; }
Unexecuted instantiation: enc_ma.cc:jxl::N_SSE4::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::CostInfo::Cost() const
enc_ma.cc:jxl::N_AVX2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::CostInfo::Cost() const
Line
Count
Source
292
371k
        float Cost() const { return cost + extra_cost; }
Unexecuted instantiation: enc_ma.cc:jxl::N_SSE2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::CostInfo::Cost() const
293
5.05k
        Predictor pred;  // will be uninitialized in some cases, but never used.
294
5.05k
      };
295
5.05k
      std::vector<CostInfo> costs_l;
296
5.05k
      std::vector<CostInfo> costs_r;
297
298
5.05k
      std::vector<int32_t> counts_above(max_symbols);
299
5.05k
      std::vector<int32_t> counts_below(max_symbols);
300
301
      // The lower the threshold, the higher the expected noisiness of the
302
      // estimate. Thus, discourage changing predictors.
303
5.05k
      float change_pred_penalty = 800.0f / (100.0f + threshold);
304
40.1k
      for (size_t prop = 0; prop < num_properties && base_bits > threshold;
305
35.0k
           prop++) {
306
35.0k
        costs_l.clear();
307
35.0k
        costs_r.clear();
308
35.0k
        size_t prop_size = tree_samples.NumPropertyValues(prop);
309
35.0k
        if (extra_bits_increase.size() < prop_size) {
310
10.9k
          count_increase.resize(prop_size * max_symbols);
311
10.9k
          extra_bits_increase.resize(prop_size);
312
10.9k
        }
313
        // Clear prop_value_used_count (which cannot be cleared "on the go")
314
35.0k
        prop_value_used_count.clear();
315
35.0k
        prop_value_used_count.resize(prop_size);
316
317
35.0k
        size_t first_used = prop_size;
318
35.0k
        size_t last_used = 0;
319
320
        // TODO(veluca): consider finding multiple splits along a single
321
        // property at the same time, possibly with a bottom-up approach.
322
35.0k
        if (prop < tree_samples.NumStaticProps()) {
323
2.24M
          for (size_t i = begin; i < end; i++) {
324
2.23M
            size_t p = tree_samples.Property<true>(prop, i);
325
2.23M
            prop_value_used_count[p]++;
326
2.23M
            last_used = std::max(last_used, p);
327
2.23M
            first_used = std::min(first_used, p);
328
2.23M
          }
329
25.0k
        } else {
330
25.0k
          size_t prop_idx = prop - tree_samples.NumStaticProps();
331
5.60M
          for (size_t i = begin; i < end; i++) {
332
5.58M
            size_t p = tree_samples.Property<false>(prop_idx, i);
333
5.58M
            prop_value_used_count[p]++;
334
5.58M
            last_used = std::max(last_used, p);
335
5.58M
            first_used = std::min(first_used, p);
336
5.58M
          }
337
25.0k
        }
338
35.0k
        costs_l.resize(last_used - first_used);
339
35.0k
        costs_r.resize(last_used - first_used);
340
        // For all predictors, compute the right and left costs of each split.
341
70.1k
        for (size_t pred = 0; pred < num_predictors; pred++) {
342
          // Compute cost and histogram increments for each property value.
343
35.0k
          const std::vector<ResidualToken> &rtokens =
344
35.0k
              tree_samples.RTokens(pred);
345
35.0k
          if (prop < tree_samples.NumStaticProps()) {
346
10.0k
            CollectExtraBitsIncrease<true>(tree_samples, rtokens,
347
10.0k
                                           count_increase, extra_bits_increase,
348
10.0k
                                           begin, end, prop, max_symbols);
349
25.0k
          } else {
350
25.0k
            CollectExtraBitsIncrease<false>(
351
25.0k
                tree_samples, rtokens, count_increase, extra_bits_increase,
352
25.0k
                begin, end, prop - tree_samples.NumStaticProps(), max_symbols);
353
25.0k
          }
354
35.0k
          memcpy(counts_above.data(), counts.data() + pred * max_symbols,
355
35.0k
                 max_symbols * sizeof counts_above[0]);
356
35.0k
          memset(counts_below.data(), 0, max_symbols * sizeof counts_below[0]);
357
35.0k
          size_t extra_bits_below = 0;
358
          // Exclude last used: this ensures neither counts_above nor
359
          // counts_below is empty.
360
329k
          for (size_t i = first_used; i < last_used; i++) {
361
294k
            if (!prop_value_used_count[i]) continue;
362
185k
            extra_bits_below += extra_bits_increase[i];
363
            // The increase for this property value has been used, and will not
364
            // be used again: clear it. Also below.
365
185k
            extra_bits_increase[i] = 0;
366
9.00M
            for (size_t sym = 0; sym < max_symbols; sym++) {
367
8.81M
              counts_above[sym] -= count_increase[i * max_symbols + sym];
368
8.81M
              counts_below[sym] += count_increase[i * max_symbols + sym];
369
8.81M
              count_increase[i * max_symbols + sym] = 0;
370
8.81M
            }
371
185k
            float rcost = EstimateBits(counts_above.data(), max_symbols) +
372
185k
                          tot_extra_bits[pred] - extra_bits_below;
373
185k
            float lcost = EstimateBits(counts_below.data(), max_symbols) +
374
185k
                          extra_bits_below;
375
185k
            JXL_DASSERT(extra_bits_below <= tot_extra_bits[pred]);
376
185k
            float penalty = 0;
377
            // Never discourage moving away from the Weighted predictor.
378
185k
            if (tree_samples.PredictorFromIndex(pred) !=
379
185k
                    (*tree)[pos].predictor &&
380
185k
                (*tree)[pos].predictor != Predictor::Weighted) {
381
0
              penalty = change_pred_penalty;
382
0
            }
383
            // If everything else is equal, disfavour Weighted (slower) and
384
            // favour Zero (faster if it's the only predictor used in a
385
            // group+channel combination)
386
185k
            if (tree_samples.PredictorFromIndex(pred) == Predictor::Weighted) {
387
0
              penalty += 1e-8;
388
0
            }
389
185k
            if (tree_samples.PredictorFromIndex(pred) == Predictor::Zero) {
390
0
              penalty -= 1e-8;
391
0
            }
392
185k
            if (rcost + penalty < costs_r[i - first_used].Cost()) {
393
185k
              costs_r[i - first_used].cost = rcost;
394
185k
              costs_r[i - first_used].extra_cost = penalty;
395
185k
              costs_r[i - first_used].pred =
396
185k
                  tree_samples.PredictorFromIndex(pred);
397
185k
            }
398
185k
            if (lcost + penalty < costs_l[i - first_used].Cost()) {
399
185k
              costs_l[i - first_used].cost = lcost;
400
185k
              costs_l[i - first_used].extra_cost = penalty;
401
185k
              costs_l[i - first_used].pred =
402
185k
                  tree_samples.PredictorFromIndex(pred);
403
185k
            }
404
185k
          }
405
35.0k
        }
406
        // Iterate through the possible splits and find the one with minimum sum
407
        // of costs of the two sides.
408
35.0k
        size_t split = begin;
409
329k
        for (size_t i = first_used; i < last_used; i++) {
410
294k
          if (!prop_value_used_count[i]) continue;
411
185k
          split += prop_value_used_count[i];
412
185k
          float rcost = costs_r[i - first_used].cost;
413
185k
          float lcost = costs_l[i - first_used].cost;
414
          // WP was not used + we would use the WP property or predictor
415
185k
          bool adds_wp =
416
185k
              (tree_samples.PropertyFromIndex(prop) == kWPProp &&
417
185k
               (used_properties & (1LU << prop)) == 0) ||
418
185k
              ((costs_l[i - first_used].pred == Predictor::Weighted ||
419
137k
                costs_r[i - first_used].pred == Predictor::Weighted) &&
420
137k
               (*tree)[pos].predictor != Predictor::Weighted);
421
185k
          bool zero_entropy_side = rcost == 0 || lcost == 0;
422
423
185k
          SplitInfo &best_ref =
424
185k
              tree_samples.PropertyFromIndex(prop) < kNumStaticProperties
425
185k
                  ? (zero_entropy_side ? best_split_static_constant
426
1.07k
                                       : best_split_static)
427
185k
                  : (adds_wp ? best_split_nonstatic : best_split_nowp);
428
185k
          if (lcost + rcost < best_ref.Cost()) {
429
37.3k
            best_ref.prop = prop;
430
37.3k
            best_ref.val = i;
431
37.3k
            best_ref.pos = split;
432
37.3k
            best_ref.lcost = lcost;
433
37.3k
            best_ref.lpred = costs_l[i - first_used].pred;
434
37.3k
            best_ref.rcost = rcost;
435
37.3k
            best_ref.rpred = costs_r[i - first_used].pred;
436
37.3k
          }
437
185k
        }
438
        // Clear extra_bits_increase and cost_increase for last_used.
439
35.0k
        extra_bits_increase[last_used] = 0;
440
1.72M
        for (size_t sym = 0; sym < max_symbols; sym++) {
441
1.69M
          count_increase[last_used * max_symbols + sym] = 0;
442
1.69M
        }
443
35.0k
      }
444
445
      // Try to avoid introducing WP.
446
5.05k
      if (best_split_nowp.Cost() + threshold < base_bits &&
447
5.05k
          best_split_nowp.Cost() <= fast_decode_multiplier * best->Cost()) {
448
2.31k
        best = &best_split_nowp;
449
2.31k
      }
450
      // Split along static props if possible and not significantly more
451
      // expensive.
452
5.05k
      if (best_split_static.Cost() + threshold < base_bits &&
453
5.05k
          best_split_static.Cost() <= fast_decode_multiplier * best->Cost()) {
454
318
        best = &best_split_static;
455
318
      }
456
      // Split along static props to create constant nodes if possible.
457
5.05k
      if (best_split_static_constant.Cost() + threshold < base_bits) {
458
0
        best = &best_split_static_constant;
459
0
      }
460
5.05k
    }
461
462
5.05k
    if (best->Cost() + threshold < base_bits) {
463
2.49k
      uint32_t p = tree_samples.PropertyFromIndex(best->prop);
464
2.49k
      pixel_type dequant =
465
2.49k
          tree_samples.UnquantizeProperty(best->prop, best->val);
466
      // Split node and try to split children.
467
2.49k
      MakeSplitNode(pos, p, dequant, best->lpred, 0, best->rpred, 0, tree);
468
      // "Sort" according to winning property
469
2.49k
      if (best->prop < tree_samples.NumStaticProps()) {
470
318
        SplitTreeSamples<true>(tree_samples, begin, best->pos, end, best->prop,
471
318
                               best->val);
472
2.17k
      } else {
473
2.17k
        SplitTreeSamples<false>(tree_samples, begin, best->pos, end,
474
2.17k
                                best->prop - tree_samples.NumStaticProps(),
475
2.17k
                                best->val);
476
2.17k
      }
477
2.49k
      if (p >= kNumStaticProperties) {
478
2.17k
        used_properties |= 1 << best->prop;
479
2.17k
      }
480
2.49k
      auto new_sp_range = static_prop_range;
481
2.49k
      if (p < kNumStaticProperties) {
482
318
        JXL_DASSERT(static_cast<uint32_t>(dequant + 1) <= new_sp_range[p][1]);
483
318
        new_sp_range[p][1] = dequant + 1;
484
318
        JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]);
485
318
      }
486
2.49k
      nodes.push_back(NodeInfo{(*tree)[pos].rchild, begin, best->pos,
487
2.49k
                               used_properties, new_sp_range});
488
2.49k
      new_sp_range = static_prop_range;
489
2.49k
      if (p < kNumStaticProperties) {
490
318
        JXL_DASSERT(new_sp_range[p][0] <= static_cast<uint32_t>(dequant + 1));
491
318
        new_sp_range[p][0] = dequant + 1;
492
318
        JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]);
493
318
      }
494
2.49k
      nodes.push_back(NodeInfo{(*tree)[pos].lchild, best->pos, end,
495
2.49k
                               used_properties, new_sp_range});
496
2.49k
    }
497
5.05k
  }
498
73
}
Unexecuted instantiation: jxl::N_SSE4::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
jxl::N_AVX2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
Line
Count
Source
161
73
                   float fast_decode_multiplier, Tree *tree) {
162
73
  struct NodeInfo {
163
73
    size_t pos;
164
73
    size_t begin;
165
73
    size_t end;
166
73
    uint64_t used_properties;
167
73
    StaticPropRange static_prop_range;
168
73
  };
169
73
  std::vector<NodeInfo> nodes;
170
73
  nodes.push_back(NodeInfo{0, 0, tree_samples.NumDistinctSamples(), 0,
171
73
                           initial_static_prop_range});
172
173
73
  size_t num_predictors = tree_samples.NumPredictors();
174
73
  size_t num_properties = tree_samples.NumProperties();
175
176
  // TODO(veluca): consider parallelizing the search (processing multiple nodes
177
  // at a time).
178
5.12k
  while (!nodes.empty()) {
179
5.05k
    size_t pos = nodes.back().pos;
180
5.05k
    size_t begin = nodes.back().begin;
181
5.05k
    size_t end = nodes.back().end;
182
5.05k
    uint64_t used_properties = nodes.back().used_properties;
183
5.05k
    StaticPropRange static_prop_range = nodes.back().static_prop_range;
184
5.05k
    nodes.pop_back();
185
5.05k
    if (begin == end) continue;
186
187
5.05k
    struct SplitInfo {
188
5.05k
      size_t prop = 0;
189
5.05k
      uint32_t val = 0;
190
5.05k
      size_t pos = 0;
191
5.05k
      float lcost = std::numeric_limits<float>::max();
192
5.05k
      float rcost = std::numeric_limits<float>::max();
193
5.05k
      Predictor lpred = Predictor::Zero;
194
5.05k
      Predictor rpred = Predictor::Zero;
195
5.05k
      float Cost() const { return lcost + rcost; }
196
5.05k
    };
197
198
5.05k
    SplitInfo best_split_static_constant;
199
5.05k
    SplitInfo best_split_static;
200
5.05k
    SplitInfo best_split_nonstatic;
201
5.05k
    SplitInfo best_split_nowp;
202
203
5.05k
    JXL_DASSERT(begin <= end);
204
5.05k
    JXL_DASSERT(end <= tree_samples.NumDistinctSamples());
205
206
    // Compute the maximum token in the range.
207
5.05k
    size_t max_symbols = 0;
208
10.1k
    for (size_t pred = 0; pred < num_predictors; pred++) {
209
1.12M
      for (size_t i = begin; i < end; i++) {
210
1.11M
        uint32_t tok = tree_samples.Token(pred, i);
211
1.11M
        max_symbols = max_symbols > tok + 1 ? max_symbols : tok + 1;
212
1.11M
      }
213
5.05k
    }
214
5.05k
    max_symbols = Padded(max_symbols);
215
5.05k
    std::vector<int32_t> counts(max_symbols * num_predictors);
216
5.05k
    std::vector<uint32_t> tot_extra_bits(num_predictors);
217
10.1k
    for (size_t pred = 0; pred < num_predictors; pred++) {
218
5.05k
      size_t extra_bits = 0;
219
5.05k
      const std::vector<ResidualToken>& rtokens = tree_samples.RTokens(pred);
220
1.12M
      for (size_t i = begin; i < end; i++) {
221
1.11M
        const ResidualToken& rt = rtokens[i];
222
1.11M
        size_t count = tree_samples.Count(i);
223
1.11M
        size_t eb = rt.nbits * count;
224
1.11M
        counts[pred * max_symbols + rt.tok] += count;
225
1.11M
        extra_bits += eb;
226
1.11M
      }
227
5.05k
      tot_extra_bits[pred] = extra_bits;
228
5.05k
    }
229
230
5.05k
    float base_bits;
231
5.05k
    {
232
5.05k
      size_t pred = tree_samples.PredictorIndex((*tree)[pos].predictor);
233
5.05k
      base_bits =
234
5.05k
          EstimateBits(counts.data() + pred * max_symbols, max_symbols) +
235
5.05k
          tot_extra_bits[pred];
236
5.05k
    }
237
238
5.05k
    SplitInfo *best = &best_split_nonstatic;
239
240
5.05k
    SplitInfo forced_split;
241
    // The multiplier ranges cut halfway through the current ranges of static
242
    // properties. We do this even if the current node is not a leaf, to
243
    // minimize the number of nodes in the resulting tree.
244
5.05k
    for (const auto &mmi : mul_info) {
245
5.03k
      uint32_t axis;
246
5.03k
      uint32_t val;
247
5.03k
      IntersectionType t =
248
5.03k
          BoxIntersects(static_prop_range, mmi.range, axis, val);
249
5.03k
      if (t == IntersectionType::kNone) continue;
250
5.03k
      if (t == IntersectionType::kInside) {
251
5.03k
        (*tree)[pos].multiplier = mmi.multiplier;
252
5.03k
        break;
253
5.03k
      }
254
0
      if (t == IntersectionType::kPartial) {
255
0
        JXL_DASSERT(axis < kNumStaticProperties);
256
0
        forced_split.val = tree_samples.QuantizeStaticProperty(axis, val);
257
0
        forced_split.prop = axis;
258
0
        forced_split.lcost = forced_split.rcost = base_bits / 2 - threshold;
259
0
        forced_split.lpred = forced_split.rpred = (*tree)[pos].predictor;
260
0
        best = &forced_split;
261
0
        best->pos = begin;
262
0
        JXL_DASSERT(best->prop == tree_samples.PropertyFromIndex(best->prop));
263
0
        if (best->prop < tree_samples.NumStaticProps()) {
264
0
        for (size_t x = begin; x < end; x++) {
265
0
          if (tree_samples.Property<true>(best->prop, x) <= best->val) {
266
0
            best->pos++;
267
0
          }
268
0
        }
269
0
      } else {
270
0
        size_t prop = best->prop - tree_samples.NumStaticProps();
271
0
        for (size_t x = begin; x < end; x++) {
272
0
          if (tree_samples.Property<false>(prop, x) <= best->val) {
273
0
            best->pos++;
274
0
          }
275
0
        }
276
0
      }
277
0
        break;
278
0
      }
279
0
    }
280
281
5.05k
    if (best != &forced_split) {
282
5.05k
      std::vector<int> prop_value_used_count;
283
5.05k
      std::vector<int> count_increase;
284
5.05k
      std::vector<size_t> extra_bits_increase;
285
      // For each property, compute which of its values are used, and what
286
      // tokens correspond to those usages. Then, iterate through the values,
287
      // and compute the entropy of each side of the split (of the form `prop >
288
      // threshold`). Finally, find the split that minimizes the cost.
289
5.05k
      struct CostInfo {
290
5.05k
        float cost = std::numeric_limits<float>::max();
291
5.05k
        float extra_cost = 0;
292
5.05k
        float Cost() const { return cost + extra_cost; }
293
5.05k
        Predictor pred;  // will be uninitialized in some cases, but never used.
294
5.05k
      };
295
5.05k
      std::vector<CostInfo> costs_l;
296
5.05k
      std::vector<CostInfo> costs_r;
297
298
5.05k
      std::vector<int32_t> counts_above(max_symbols);
299
5.05k
      std::vector<int32_t> counts_below(max_symbols);
300
301
      // The lower the threshold, the higher the expected noisiness of the
302
      // estimate. Thus, discourage changing predictors.
303
5.05k
      float change_pred_penalty = 800.0f / (100.0f + threshold);
304
40.1k
      for (size_t prop = 0; prop < num_properties && base_bits > threshold;
305
35.0k
           prop++) {
306
35.0k
        costs_l.clear();
307
35.0k
        costs_r.clear();
308
35.0k
        size_t prop_size = tree_samples.NumPropertyValues(prop);
309
35.0k
        if (extra_bits_increase.size() < prop_size) {
310
10.9k
          count_increase.resize(prop_size * max_symbols);
311
10.9k
          extra_bits_increase.resize(prop_size);
312
10.9k
        }
313
        // Clear prop_value_used_count (which cannot be cleared "on the go")
314
35.0k
        prop_value_used_count.clear();
315
35.0k
        prop_value_used_count.resize(prop_size);
316
317
35.0k
        size_t first_used = prop_size;
318
35.0k
        size_t last_used = 0;
319
320
        // TODO(veluca): consider finding multiple splits along a single
321
        // property at the same time, possibly with a bottom-up approach.
322
35.0k
        if (prop < tree_samples.NumStaticProps()) {
323
2.24M
          for (size_t i = begin; i < end; i++) {
324
2.23M
            size_t p = tree_samples.Property<true>(prop, i);
325
2.23M
            prop_value_used_count[p]++;
326
2.23M
            last_used = std::max(last_used, p);
327
2.23M
            first_used = std::min(first_used, p);
328
2.23M
          }
329
25.0k
        } else {
330
25.0k
          size_t prop_idx = prop - tree_samples.NumStaticProps();
331
5.60M
          for (size_t i = begin; i < end; i++) {
332
5.58M
            size_t p = tree_samples.Property<false>(prop_idx, i);
333
5.58M
            prop_value_used_count[p]++;
334
5.58M
            last_used = std::max(last_used, p);
335
5.58M
            first_used = std::min(first_used, p);
336
5.58M
          }
337
25.0k
        }
338
35.0k
        costs_l.resize(last_used - first_used);
339
35.0k
        costs_r.resize(last_used - first_used);
340
        // For all predictors, compute the right and left costs of each split.
341
70.1k
        for (size_t pred = 0; pred < num_predictors; pred++) {
342
          // Compute cost and histogram increments for each property value.
343
35.0k
          const std::vector<ResidualToken> &rtokens =
344
35.0k
              tree_samples.RTokens(pred);
345
35.0k
          if (prop < tree_samples.NumStaticProps()) {
346
10.0k
            CollectExtraBitsIncrease<true>(tree_samples, rtokens,
347
10.0k
                                           count_increase, extra_bits_increase,
348
10.0k
                                           begin, end, prop, max_symbols);
349
25.0k
          } else {
350
25.0k
            CollectExtraBitsIncrease<false>(
351
25.0k
                tree_samples, rtokens, count_increase, extra_bits_increase,
352
25.0k
                begin, end, prop - tree_samples.NumStaticProps(), max_symbols);
353
25.0k
          }
354
35.0k
          memcpy(counts_above.data(), counts.data() + pred * max_symbols,
355
35.0k
                 max_symbols * sizeof counts_above[0]);
356
35.0k
          memset(counts_below.data(), 0, max_symbols * sizeof counts_below[0]);
357
35.0k
          size_t extra_bits_below = 0;
358
          // Exclude last used: this ensures neither counts_above nor
359
          // counts_below is empty.
360
329k
          for (size_t i = first_used; i < last_used; i++) {
361
294k
            if (!prop_value_used_count[i]) continue;
362
185k
            extra_bits_below += extra_bits_increase[i];
363
            // The increase for this property value has been used, and will not
364
            // be used again: clear it. Also below.
365
185k
            extra_bits_increase[i] = 0;
366
9.00M
            for (size_t sym = 0; sym < max_symbols; sym++) {
367
8.81M
              counts_above[sym] -= count_increase[i * max_symbols + sym];
368
8.81M
              counts_below[sym] += count_increase[i * max_symbols + sym];
369
8.81M
              count_increase[i * max_symbols + sym] = 0;
370
8.81M
            }
371
185k
            float rcost = EstimateBits(counts_above.data(), max_symbols) +
372
185k
                          tot_extra_bits[pred] - extra_bits_below;
373
185k
            float lcost = EstimateBits(counts_below.data(), max_symbols) +
374
185k
                          extra_bits_below;
375
185k
            JXL_DASSERT(extra_bits_below <= tot_extra_bits[pred]);
376
185k
            float penalty = 0;
377
            // Never discourage moving away from the Weighted predictor.
378
185k
            if (tree_samples.PredictorFromIndex(pred) !=
379
185k
                    (*tree)[pos].predictor &&
380
185k
                (*tree)[pos].predictor != Predictor::Weighted) {
381
0
              penalty = change_pred_penalty;
382
0
            }
383
            // If everything else is equal, disfavour Weighted (slower) and
384
            // favour Zero (faster if it's the only predictor used in a
385
            // group+channel combination)
386
185k
            if (tree_samples.PredictorFromIndex(pred) == Predictor::Weighted) {
387
0
              penalty += 1e-8;
388
0
            }
389
185k
            if (tree_samples.PredictorFromIndex(pred) == Predictor::Zero) {
390
0
              penalty -= 1e-8;
391
0
            }
392
185k
            if (rcost + penalty < costs_r[i - first_used].Cost()) {
393
185k
              costs_r[i - first_used].cost = rcost;
394
185k
              costs_r[i - first_used].extra_cost = penalty;
395
185k
              costs_r[i - first_used].pred =
396
185k
                  tree_samples.PredictorFromIndex(pred);
397
185k
            }
398
185k
            if (lcost + penalty < costs_l[i - first_used].Cost()) {
399
185k
              costs_l[i - first_used].cost = lcost;
400
185k
              costs_l[i - first_used].extra_cost = penalty;
401
185k
              costs_l[i - first_used].pred =
402
185k
                  tree_samples.PredictorFromIndex(pred);
403
185k
            }
404
185k
          }
405
35.0k
        }
406
        // Iterate through the possible splits and find the one with minimum sum
407
        // of costs of the two sides.
408
35.0k
        size_t split = begin;
409
329k
        for (size_t i = first_used; i < last_used; i++) {
410
294k
          if (!prop_value_used_count[i]) continue;
411
185k
          split += prop_value_used_count[i];
412
185k
          float rcost = costs_r[i - first_used].cost;
413
185k
          float lcost = costs_l[i - first_used].cost;
414
          // WP was not used + we would use the WP property or predictor
415
185k
          bool adds_wp =
416
185k
              (tree_samples.PropertyFromIndex(prop) == kWPProp &&
417
185k
               (used_properties & (1LU << prop)) == 0) ||
418
185k
              ((costs_l[i - first_used].pred == Predictor::Weighted ||
419
137k
                costs_r[i - first_used].pred == Predictor::Weighted) &&
420
137k
               (*tree)[pos].predictor != Predictor::Weighted);
421
185k
          bool zero_entropy_side = rcost == 0 || lcost == 0;
422
423
185k
          SplitInfo &best_ref =
424
185k
              tree_samples.PropertyFromIndex(prop) < kNumStaticProperties
425
185k
                  ? (zero_entropy_side ? best_split_static_constant
426
1.07k
                                       : best_split_static)
427
185k
                  : (adds_wp ? best_split_nonstatic : best_split_nowp);
428
185k
          if (lcost + rcost < best_ref.Cost()) {
429
37.3k
            best_ref.prop = prop;
430
37.3k
            best_ref.val = i;
431
37.3k
            best_ref.pos = split;
432
37.3k
            best_ref.lcost = lcost;
433
37.3k
            best_ref.lpred = costs_l[i - first_used].pred;
434
37.3k
            best_ref.rcost = rcost;
435
37.3k
            best_ref.rpred = costs_r[i - first_used].pred;
436
37.3k
          }
437
185k
        }
438
        // Clear extra_bits_increase and cost_increase for last_used.
439
35.0k
        extra_bits_increase[last_used] = 0;
440
1.72M
        for (size_t sym = 0; sym < max_symbols; sym++) {
441
1.69M
          count_increase[last_used * max_symbols + sym] = 0;
442
1.69M
        }
443
35.0k
      }
444
445
      // Try to avoid introducing WP.
446
5.05k
      if (best_split_nowp.Cost() + threshold < base_bits &&
447
5.05k
          best_split_nowp.Cost() <= fast_decode_multiplier * best->Cost()) {
448
2.31k
        best = &best_split_nowp;
449
2.31k
      }
450
      // Split along static props if possible and not significantly more
451
      // expensive.
452
5.05k
      if (best_split_static.Cost() + threshold < base_bits &&
453
5.05k
          best_split_static.Cost() <= fast_decode_multiplier * best->Cost()) {
454
318
        best = &best_split_static;
455
318
      }
456
      // Split along static props to create constant nodes if possible.
457
5.05k
      if (best_split_static_constant.Cost() + threshold < base_bits) {
458
0
        best = &best_split_static_constant;
459
0
      }
460
5.05k
    }
461
462
5.05k
    if (best->Cost() + threshold < base_bits) {
463
2.49k
      uint32_t p = tree_samples.PropertyFromIndex(best->prop);
464
2.49k
      pixel_type dequant =
465
2.49k
          tree_samples.UnquantizeProperty(best->prop, best->val);
466
      // Split node and try to split children.
467
2.49k
      MakeSplitNode(pos, p, dequant, best->lpred, 0, best->rpred, 0, tree);
468
      // "Sort" according to winning property
469
2.49k
      if (best->prop < tree_samples.NumStaticProps()) {
470
318
        SplitTreeSamples<true>(tree_samples, begin, best->pos, end, best->prop,
471
318
                               best->val);
472
2.17k
      } else {
473
2.17k
        SplitTreeSamples<false>(tree_samples, begin, best->pos, end,
474
2.17k
                                best->prop - tree_samples.NumStaticProps(),
475
2.17k
                                best->val);
476
2.17k
      }
477
2.49k
      if (p >= kNumStaticProperties) {
478
2.17k
        used_properties |= 1 << best->prop;
479
2.17k
      }
480
2.49k
      auto new_sp_range = static_prop_range;
481
2.49k
      if (p < kNumStaticProperties) {
482
318
        JXL_DASSERT(static_cast<uint32_t>(dequant + 1) <= new_sp_range[p][1]);
483
318
        new_sp_range[p][1] = dequant + 1;
484
318
        JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]);
485
318
      }
486
2.49k
      nodes.push_back(NodeInfo{(*tree)[pos].rchild, begin, best->pos,
487
2.49k
                               used_properties, new_sp_range});
488
2.49k
      new_sp_range = static_prop_range;
489
2.49k
      if (p < kNumStaticProperties) {
490
318
        JXL_DASSERT(new_sp_range[p][0] <= static_cast<uint32_t>(dequant + 1));
491
318
        new_sp_range[p][0] = dequant + 1;
492
318
        JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]);
493
318
      }
494
2.49k
      nodes.push_back(NodeInfo{(*tree)[pos].lchild, best->pos, end,
495
2.49k
                               used_properties, new_sp_range});
496
2.49k
    }
497
5.05k
  }
498
73
}
Unexecuted instantiation: jxl::N_SSE2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
499
500
// NOLINTNEXTLINE(google-readability-namespace-comments)
501
}  // namespace HWY_NAMESPACE
502
}  // namespace jxl
503
HWY_AFTER_NAMESPACE();
504
505
#if HWY_ONCE
506
namespace jxl {
507
508
HWY_EXPORT(FindBestSplit);  // Local function.
509
510
// Learns an MA tree from `tree_samples` into `tree`.
//
// A single root leaf (no split property, first configured predictor, zero
// offset, unit multiplier) is appended to `tree`, after which the
// target-specific FindBestSplit is dispatched to grow the tree. Returns an
// error if the sample set has 64 or more properties, or more distinct samples
// than fit in a uint32_t.
Status ComputeBestTree(TreeSamples &tree_samples, float threshold,
                       const std::vector<ModularMultiplierInfo> &mul_info,
                       StaticPropRange static_prop_range,
                       float fast_decode_multiplier, Tree *tree) {
  // TODO(veluca): take into account that different contexts can have different
  // uint configs.

  // Seed the tree with its root leaf.
  tree->emplace_back();
  auto &root = tree->back();
  root.property = -1;
  root.predictor = tree_samples.PredictorFromIndex(0);
  root.predictor_offset = 0;
  root.multiplier = 1;

  // FindBestSplit keeps the set of used properties in a 64-bit mask.
  JXL_ENSURE(tree_samples.NumProperties() < 64);
  JXL_ENSURE(tree_samples.NumDistinctSamples() <=
             std::numeric_limits<uint32_t>::max());

  HWY_DYNAMIC_DISPATCH(FindBestSplit)
  (tree_samples, threshold, mul_info, static_prop_range, fast_decode_multiplier,
   tree);
  return true;
}
532
533
// Out-of-class definitions of TreeSamples' static constexpr data members.
// They are only compiled when JXL_CXX_LANG is below JXL_CXX_17: before C++17
// an odr-used static constexpr member still needs a namespace-scope
// definition, whereas from C++17 on such members are implicitly inline.
#if JXL_CXX_LANG < JXL_CXX_17
534
constexpr int32_t TreeSamples::kPropertyRange;
535
constexpr uint32_t TreeSamples::kDedupEntryUnused;
536
#endif
537
538
// Chooses the set of predictors whose residuals will be collected and sizes
// `residuals` to match.
//
//  - TreeMode::kWPOnly: only Predictor::Weighted is used, whatever
//    `predictor` is.
//  - Predictor::Variable: every one of the kNumModularPredictors predictors is
//    listed; slot 0 is then swapped with the Weighted slot and slot 1 with the
//    Gradient slot.
//  - Predictor::Best: Weighted and Gradient.
//  - anything else: just `predictor`.
//  - TreeMode::kNoWP: Weighted is removed from the resulting list; explicitly
//    requesting Weighted in this mode is an error.
//
// Returns a failure for the invalid kNoWP + Weighted combination, success
// otherwise.
Status TreeSamples::SetPredictor(Predictor predictor,
                                 ModularOptions::TreeMode wp_tree_mode) {
  if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) {
    predictors = {Predictor::Weighted};
    residuals.resize(1);
    return true;
  }
  if (wp_tree_mode == ModularOptions::TreeMode::kNoWP &&
      predictor == Predictor::Weighted) {
    return JXL_FAILURE("Invalid predictor settings");
  }
  if (predictor == Predictor::Variable) {
    // Every other branch assigns `predictors` from scratch; do the same here.
    // Appending to a non-empty list would leave stale entries in front and
    // make the index-based swaps below operate on the wrong elements.
    predictors.clear();
    predictors.reserve(kNumModularPredictors);
    for (size_t i = 0; i < kNumModularPredictors; i++) {
      predictors.push_back(static_cast<Predictor>(i));
    }
    std::swap(predictors[0], predictors[static_cast<int>(Predictor::Weighted)]);
    std::swap(predictors[1], predictors[static_cast<int>(Predictor::Gradient)]);
  } else if (predictor == Predictor::Best) {
    predictors = {Predictor::Weighted, Predictor::Gradient};
  } else {
    predictors = {predictor};
  }
  if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) {
    predictors.erase(
        std::remove(predictors.begin(), predictors.end(), Predictor::Weighted),
        predictors.end());
  }
  // One residual stream per selected predictor.
  residuals.resize(predictors.size());
  return true;
}
568
569
// Records which properties the tree learner may split on.
//
// Starts from `properties` and then applies the tree mode: kWPOnly and
// kGradientOnly replace the list with the single matching property, kNoWP
// strips kWPProp from it. Fails if nothing is left. Otherwise counts the
// static properties (indices below kNumStaticProperties, which are expected to
// sit at the front of the list at their own index) and sizes `props` for the
// remaining, non-static ones.
Status TreeSamples::SetProperties(const std::vector<uint32_t> &properties,
                                  ModularOptions::TreeMode wp_tree_mode) {
  using TreeMode = ModularOptions::TreeMode;

  props_to_use = properties;
  switch (wp_tree_mode) {
    case TreeMode::kWPOnly:
      props_to_use = {static_cast<uint32_t>(kWPProp)};
      break;
    case TreeMode::kGradientOnly:
      props_to_use = {static_cast<uint32_t>(kGradientProp)};
      break;
    case TreeMode::kNoWP:
      props_to_use.erase(
          std::remove(props_to_use.begin(), props_to_use.end(), kWPProp),
          props_to_use.end());
      break;
    default:
      break;
  }
  if (props_to_use.empty()) {
    return JXL_FAILURE("Invalid property set configuration");
  }

  // Check that if static properties present, then those are at the beginning.
  num_static_props = 0;
  for (size_t idx = 0; idx < props_to_use.size(); ++idx) {
    const uint32_t property = props_to_use[idx];
    if (property >= kNumStaticProperties) continue;
    JXL_DASSERT(idx == property);
    num_static_props++;
  }

  props.resize(props_to_use.size() - num_static_props);
  return true;
}
598
599
201
void TreeSamples::InitTable(size_t log_size) {
600
201
  size_t size = 1ULL << log_size;
601
201
  if (dedup_table_.size() == size) return;
602
112
  dedup_table_.resize(size, kDedupEntryUnused);
603
51.7k
  for (size_t i = 0; i < NumDistinctSamples(); i++) {
604
51.6k
    if (sample_counts[i] != std::numeric_limits<uint16_t>::max()) {
605
51.6k
      AddToTable(i);
606
51.6k
    }
607
51.6k
  }
608
112
}
609
610
593k
bool TreeSamples::AddToTableAndMerge(size_t a) {
611
593k
  size_t pos1 = Hash1(a);
612
593k
  size_t pos2 = Hash2(a);
613
593k
  if (dedup_table_[pos1] != kDedupEntryUnused &&
614
593k
      IsSameSample(a, dedup_table_[pos1])) {
615
360k
    JXL_DASSERT(sample_counts[a] == 1);
616
360k
    sample_counts[dedup_table_[pos1]]++;
617
    // Remove from hash table samples that are saturated.
618
360k
    if (sample_counts[dedup_table_[pos1]] ==
619
360k
        std::numeric_limits<uint16_t>::max()) {
620
0
      dedup_table_[pos1] = kDedupEntryUnused;
621
0
    }
622
360k
    return true;
623
360k
  }
624
233k
  if (dedup_table_[pos2] != kDedupEntryUnused &&
625
233k
      IsSameSample(a, dedup_table_[pos2])) {
626
81.6k
    JXL_DASSERT(sample_counts[a] == 1);
627
81.6k
    sample_counts[dedup_table_[pos2]]++;
628
    // Remove from hash table samples that are saturated.
629
81.6k
    if (sample_counts[dedup_table_[pos2]] ==
630
81.6k
        std::numeric_limits<uint16_t>::max()) {
631
0
      dedup_table_[pos2] = kDedupEntryUnused;
632
0
    }
633
81.6k
    return true;
634
81.6k
  }
635
152k
  AddToTable(a);
636
152k
  return false;
637
233k
}
638
639
203k
void TreeSamples::AddToTable(size_t a) {
640
203k
  size_t pos1 = Hash1(a);
641
203k
  size_t pos2 = Hash2(a);
642
203k
  if (dedup_table_[pos1] == kDedupEntryUnused) {
643
107k
    dedup_table_[pos1] = a;
644
107k
  } else if (dedup_table_[pos2] == kDedupEntryUnused) {
645
60.3k
    dedup_table_[pos2] = a;
646
60.3k
  }
647
203k
}
648
649
201
void TreeSamples::PrepareForSamples(size_t extra_num_samples) {
650
201
  for (auto &res : residuals) {
651
201
    res.reserve(res.size() + extra_num_samples);
652
201
  }
653
588
  for (size_t i = 0; i < num_static_props; ++i) {
654
387
    static_props[i].reserve(static_props[i].size() + extra_num_samples);
655
387
  }
656
1.02k
  for (auto &p : props) {
657
1.02k
    p.reserve(p.size() + extra_num_samples);
658
1.02k
  }
659
201
  size_t total_num_samples = extra_num_samples + sample_counts.size();
660
201
  size_t next_size = CeilLog2Nonzero(total_num_samples * 3 / 2);
661
201
  InitTable(next_size);
662
201
}
663
664
797k
size_t TreeSamples::Hash1(size_t a) const {
665
797k
  constexpr uint64_t constant = 0x1e35a7bd;
666
797k
  uint64_t h = constant;
667
797k
  for (const auto &r : residuals) {
668
797k
    h = h * constant + r[a].tok;
669
797k
    h = h * constant + r[a].nbits;
670
797k
  }
671
2.34M
  for (size_t i = 0; i < num_static_props; ++i) {
672
1.54M
    h = h * constant + static_props[i][a];
673
1.54M
  }
674
4.03M
  for (const auto &p : props) {
675
4.03M
    h = h * constant + p[a];
676
4.03M
  }
677
797k
  return (h >> 16) & (dedup_table_.size() - 1);
678
797k
}
679
797k
size_t TreeSamples::Hash2(size_t a) const {
680
797k
  constexpr uint64_t constant = 0x1e35a7bd1e35a7bd;
681
797k
  uint64_t h = constant;
682
2.34M
  for (size_t i = 0; i < num_static_props; ++i) {
683
1.54M
    h = h * constant ^ static_props[i][a];
684
1.54M
  }
685
4.03M
  for (const auto &p : props) {
686
4.03M
    h = h * constant ^ p[a];
687
4.03M
  }
688
797k
  for (const auto &r : residuals) {
689
797k
    h = h * constant ^ r[a].tok;
690
797k
    h = h * constant ^ r[a].nbits;
691
797k
  }
692
797k
  return (h >> 16) & (dedup_table_.size() - 1);
693
797k
}
694
695
625k
bool TreeSamples::IsSameSample(size_t a, size_t b) const {
696
625k
  bool ret = true;
697
625k
  for (const auto &r : residuals) {
698
625k
    if (r[a].tok != r[b].tok) {
699
82.5k
      ret = false;
700
82.5k
    }
701
625k
    if (r[a].nbits != r[b].nbits) {
702
62.4k
      ret = false;
703
62.4k
    }
704
625k
  }
705
1.82M
  for (size_t i = 0; i < num_static_props; ++i) {
706
1.20M
    if (static_props[i][a] != static_props[i][b]) {
707
72.4k
      ret = false;
708
72.4k
    }
709
1.20M
  }
710
3.17M
  for (const auto &p : props) {
711
3.17M
    if (p[a] != p[b]) {
712
490k
      ret = false;
713
490k
    }
714
3.17M
  }
715
625k
  return ret;
716
625k
}
717
718
// Appends one sample (a pixel together with its property values and the
// per-predictor predictions) to the sample columns.
//
// For every predictor in `predictors` the residual `pixel - prediction` is
// encoded with HybridUintConfig(4, 1, 2) and its token / bit count are stored.
// The first `num_static_props` properties are quantized with
// QuantizeStaticProperty, the remaining ones in `props_to_use` with
// QuantizeProperty. The new row starts with a count of 1. If
// AddToTableAndMerge reports true for the new row, the row is removed again
// (presumably because it was folded into an existing identical sample; that
// helper is not visible here).
void TreeSamples::AddSample(pixel_type_w pixel, const Properties &properties,
                            const pixel_type_w *predictions) {
  for (size_t p = 0; p < predictors.size(); ++p) {
    const pixel_type_w predicted =
        predictions[static_cast<int>(predictors[p])];
    pixel_type residual = pixel - predicted;
    uint32_t tok;
    uint32_t nbits;
    uint32_t bits;
    HybridUintConfig(4, 1, 2).Encode(PackSigned(residual), &tok, &nbits,
                                     &bits);
    // Both values are stored as bytes below.
    JXL_DASSERT(tok < 256);
    JXL_DASSERT(nbits < 256);
    residuals[p].emplace_back(
        ResidualToken{static_cast<uint8_t>(tok), static_cast<uint8_t>(nbits)});
  }

  size_t prop = 0;
  for (; prop < num_static_props; ++prop) {
    static_props[prop].push_back(
        QuantizeStaticProperty(prop, properties[prop]));
  }
  for (; prop < props_to_use.size(); ++prop) {
    props[prop - num_static_props].push_back(
        QuantizeProperty(prop, properties[props_to_use[prop]]));
  }

  sample_counts.push_back(1);
  num_samples++;

  const size_t new_row = sample_counts.size() - 1;
  if (!AddToTableAndMerge(new_row)) return;

  // The row was merged: drop the freshly appended entry from every column.
  for (auto &column : residuals) column.pop_back();
  for (size_t k = 0; k < num_static_props; ++k) static_props[k].pop_back();
  for (auto &column : props) column.pop_back();
  sample_counts.pop_back();
}
744
745
192k
void TreeSamples::Swap(size_t a, size_t b) {
746
192k
  if (a == b) return;
747
192k
  for (auto &r : residuals) {
748
192k
    std::swap(r[a], r[b]);
749
192k
  }
750
578k
  for (size_t i = 0; i < num_static_props; ++i) {
751
385k
    std::swap(static_props[i][a], static_props[i][b]);
752
385k
  }
753
964k
  for (auto &p : props) {
754
964k
    std::swap(p[a], p[b]);
755
964k
  }
756
192k
  std::swap(sample_counts[a], sample_counts[b]);
757
192k
}
758
759
namespace {
760
std::vector<int32_t> QuantizeHistogram(const std::vector<uint32_t> &histogram,
761
240
                                       size_t num_chunks) {
762
240
  if (histogram.empty() || num_chunks == 0) return {};
763
240
  uint64_t sum = std::accumulate(histogram.begin(), histogram.end(), 0LU);
764
240
  if (sum == 0) return {};
765
  // TODO(veluca): selecting distinct quantiles is likely not the best
766
  // way to go about this.
767
238
  std::vector<int32_t> thresholds;
768
238
  uint64_t cumsum = 0;
769
238
  uint64_t threshold = 1;
770
107k
  for (size_t i = 0; i < histogram.size(); i++) {
771
106k
    cumsum += histogram[i];
772
106k
    if (cumsum * num_chunks >= threshold * sum) {
773
1.22k
      thresholds.push_back(i);
774
12.6k
      while (cumsum * num_chunks >= threshold * sum) threshold++;
775
1.22k
    }
776
106k
  }
777
238
  JXL_DASSERT(thresholds.size() <= num_chunks);
778
  // last value collects all histogram and is not really a threshold
779
238
  thresholds.pop_back();
780
238
  return thresholds;
781
240
}
782
783
std::vector<int32_t> QuantizeSamples(const std::vector<int32_t> &samples,
784
114
                                     size_t num_chunks) {
785
114
  if (samples.empty()) return {};
786
104
  int min = *std::min_element(samples.begin(), samples.end());
787
104
  constexpr int kRange = 512;
788
104
  min = jxl::Clamp1(min, -kRange, kRange);
789
104
  std::vector<uint32_t> counts(2 * kRange + 1);
790
74.6k
  for (int s : samples) {
791
74.6k
    uint32_t sample_offset = jxl::Clamp1(s, -kRange, kRange) - min;
792
74.6k
    counts[sample_offset]++;
793
74.6k
  }
794
104
  std::vector<int32_t> thresholds = QuantizeHistogram(counts, num_chunks);
795
859
  for (auto &v : thresholds) v += min;
796
104
  return thresholds;
797
114
}
798
799
// Builds the lookup table that maps a property value to its quantized bucket.
// Entry `to[i]` stands for the value `i - bias` and receives the number of
// leading thresholds in `from` that are strictly below that value, i.e. the
// bucket `v` with `from[v - 1] < i - bias <= from[v]` when `from` is sorted in
// ascending order. This is because the decision node in the tree splits on
// (property) > threshold, hence all values that are not greater than a given
// threshold must land in the same bucket.
template <typename T>
void QuantMap(const std::vector<int32_t> &from, std::vector<T> &to,
              size_t num_pegs, int bias) {
  to.resize(num_pegs);
  size_t bucket = 0;
  for (size_t peg = 0; peg < num_pegs; ++peg) {
    const int value = static_cast<int>(peg) - bias;
    // The bucket index only ever moves forward as `value` grows.
    for (; bucket < from.size() && value > from[bucket]; ++bucket) {
    }
    // The bucket index must be representable in the table's element type.
    JXL_DASSERT(static_cast<T>(bucket) == bucket);
    to[peg] = bucket;
  }
}
enc_ma.cc:void jxl::(anonymous namespace)::QuantMap<unsigned short>(std::__1::vector<int, std::__1::allocator<int> > const&, std::__1::vector<unsigned short, std::__1::allocator<unsigned short> >&, unsigned long, int)
Line
Count
Source
805
136
              size_t num_pegs, int bias) {
806
136
  to.resize(num_pegs);
807
136
  size_t mapped = 0;
808
139k
  for (size_t i = 0; i < num_pegs; i++) {
809
139k
    while (mapped < from.size() && static_cast<int>(i) - bias > from[mapped]) {
810
125
      mapped++;
811
125
    }
812
139k
    JXL_DASSERT(static_cast<T>(mapped) == mapped);
813
139k
    to[i] = mapped;
814
139k
  }
815
136
}
enc_ma.cc:void jxl::(anonymous namespace)::QuantMap<unsigned char>(std::__1::vector<int, std::__1::allocator<int> > const&, std::__1::vector<unsigned char, std::__1::allocator<unsigned char> >&, unsigned long, int)
Line
Count
Source
805
382
              size_t num_pegs, int bias) {
806
382
  to.resize(num_pegs);
807
382
  size_t mapped = 0;
808
391k
  for (size_t i = 0; i < num_pegs; i++) {
809
396k
    while (mapped < from.size() && static_cast<int>(i) - bias > from[mapped]) {
810
5.59k
      mapped++;
811
5.59k
    }
812
390k
    JXL_DASSERT(static_cast<T>(mapped) == mapped);
813
390k
    to[i] = mapped;
814
390k
  }
815
382
}
816
}  // namespace
817
818
void TreeSamples::PreQuantizeProperties(
819
    const StaticPropRange &range,
820
    const std::vector<ModularMultiplierInfo> &multiplier_info,
821
    const std::vector<uint32_t> &group_pixel_count,
822
    const std::vector<uint32_t> &channel_pixel_count,
823
    std::vector<pixel_type> &pixel_samples,
824
74
    std::vector<pixel_type> &diff_samples, size_t max_property_values) {
825
  // If we have forced splits because of multipliers, choose channel and group
826
  // thresholds accordingly.
827
74
  std::vector<int32_t> group_multiplier_thresholds;
828
74
  std::vector<int32_t> channel_multiplier_thresholds;
829
74
  for (const auto &v : multiplier_info) {
830
62
    if (v.range[0][0] != range[0][0]) {
831
0
      channel_multiplier_thresholds.push_back(v.range[0][0] - 1);
832
0
    }
833
62
    if (v.range[0][1] != range[0][1]) {
834
0
      channel_multiplier_thresholds.push_back(v.range[0][1] - 1);
835
0
    }
836
62
    if (v.range[1][0] != range[1][0]) {
837
0
      group_multiplier_thresholds.push_back(v.range[1][0] - 1);
838
0
    }
839
62
    if (v.range[1][1] != range[1][1]) {
840
0
      group_multiplier_thresholds.push_back(v.range[1][1] - 1);
841
0
    }
842
62
  }
843
74
  std::sort(channel_multiplier_thresholds.begin(),
844
74
            channel_multiplier_thresholds.end());
845
74
  channel_multiplier_thresholds.resize(
846
74
      std::unique(channel_multiplier_thresholds.begin(),
847
74
                  channel_multiplier_thresholds.end()) -
848
74
      channel_multiplier_thresholds.begin());
849
74
  std::sort(group_multiplier_thresholds.begin(),
850
74
            group_multiplier_thresholds.end());
851
74
  group_multiplier_thresholds.resize(
852
74
      std::unique(group_multiplier_thresholds.begin(),
853
74
                  group_multiplier_thresholds.end()) -
854
74
      group_multiplier_thresholds.begin());
855
856
74
  compact_properties.resize(props_to_use.size());
857
74
  auto quantize_channel = [&]() {
858
74
    if (!channel_multiplier_thresholds.empty()) {
859
0
      return channel_multiplier_thresholds;
860
0
    }
861
74
    return QuantizeHistogram(channel_pixel_count, max_property_values);
862
74
  };
863
74
  auto quantize_group_id = [&]() {
864
62
    if (!group_multiplier_thresholds.empty()) {
865
0
      return group_multiplier_thresholds;
866
0
    }
867
62
    return QuantizeHistogram(group_pixel_count, max_property_values);
868
62
  };
869
74
  auto quantize_coordinate = [&]() {
870
0
    std::vector<int32_t> quantized;
871
0
    quantized.reserve(max_property_values - 1);
872
0
    for (size_t i = 0; i + 1 < max_property_values; i++) {
873
0
      quantized.push_back((i + 1) * 256 / max_property_values - 1);
874
0
    }
875
0
    return quantized;
876
0
  };
877
74
  std::vector<int32_t> abs_pixel_thresholds;
878
74
  std::vector<int32_t> pixel_thresholds;
879
74
  auto quantize_pixel_property = [&]() {
880
0
    if (pixel_thresholds.empty()) {
881
0
      pixel_thresholds = QuantizeSamples(pixel_samples, max_property_values);
882
0
    }
883
0
    return pixel_thresholds;
884
0
  };
885
74
  auto quantize_abs_pixel_property = [&]() {
886
0
    if (abs_pixel_thresholds.empty()) {
887
0
      quantize_pixel_property();  // Compute the non-abs thresholds.
888
0
      for (auto &v : pixel_samples) v = std::abs(v);
889
0
      abs_pixel_thresholds =
890
0
          QuantizeSamples(pixel_samples, max_property_values);
891
0
    }
892
0
    return abs_pixel_thresholds;
893
0
  };
894
74
  std::vector<int32_t> abs_diff_thresholds;
895
74
  std::vector<int32_t> diff_thresholds;
896
308
  auto quantize_diff_property = [&]() {
897
308
    if (diff_thresholds.empty()) {
898
114
      diff_thresholds = QuantizeSamples(diff_samples, max_property_values);
899
114
    }
900
308
    return diff_thresholds;
901
308
  };
902
74
  auto quantize_abs_diff_property = [&]() {
903
0
    if (abs_diff_thresholds.empty()) {
904
0
      quantize_diff_property();  // Compute the non-abs thresholds.
905
0
      for (auto &v : diff_samples) v = std::abs(v);
906
0
      abs_diff_thresholds = QuantizeSamples(diff_samples, max_property_values);
907
0
    }
908
0
    return abs_diff_thresholds;
909
0
  };
910
74
  auto quantize_wp = [&]() {
911
74
    if (max_property_values < 32) {
912
0
      return std::vector<int32_t>{-127, -63, -31, -15, -7, -3, -1, 0,
913
0
                                  1,    3,   7,   15,  31, 63, 127};
914
0
    }
915
74
    if (max_property_values < 64) {
916
74
      return std::vector<int32_t>{-255, -191, -127, -95, -63, -47, -31, -23,
917
74
                                  -15,  -11,  -7,   -5,  -3,  -1,  0,   1,
918
74
                                  3,    5,    7,    11,  15,  23,  31,  47,
919
74
                                  63,   95,   127,  191, 255};
920
74
    }
921
0
    return std::vector<int32_t>{
922
0
        -255, -223, -191, -159, -127, -111, -95, -79, -63, -55, -47,
923
0
        -39,  -31,  -27,  -23,  -19,  -15,  -13, -11, -9,  -7,  -6,
924
0
        -5,   -4,   -3,   -2,   -1,   0,    1,   2,   3,   4,   5,
925
0
        6,    7,    9,    11,   13,   15,   19,  23,  27,  31,  39,
926
0
        47,   55,   63,   79,   95,   111,  127, 159, 191, 223, 255};
927
74
  };
928
929
74
  property_mapping.resize(props_to_use.size() - num_static_props);
930
592
  for (size_t i = 0; i < props_to_use.size(); i++) {
931
518
    if (props_to_use[i] == 0) {
932
74
      compact_properties[i] = quantize_channel();
933
444
    } else if (props_to_use[i] == 1) {
934
62
      compact_properties[i] = quantize_group_id();
935
382
    } else if (props_to_use[i] == 2 || props_to_use[i] == 3) {
936
0
      compact_properties[i] = quantize_coordinate();
937
382
    } else if (props_to_use[i] == 6 || props_to_use[i] == 7 ||
938
382
               props_to_use[i] == 8 ||
939
382
               (props_to_use[i] >= kNumNonrefProperties &&
940
382
                (props_to_use[i] - kNumNonrefProperties) % 4 == 1)) {
941
0
      compact_properties[i] = quantize_pixel_property();
942
382
    } else if (props_to_use[i] == 4 || props_to_use[i] == 5 ||
943
382
               (props_to_use[i] >= kNumNonrefProperties &&
944
382
                (props_to_use[i] - kNumNonrefProperties) % 4 == 0)) {
945
0
      compact_properties[i] = quantize_abs_pixel_property();
946
382
    } else if (props_to_use[i] >= kNumNonrefProperties &&
947
382
               (props_to_use[i] - kNumNonrefProperties) % 4 == 2) {
948
0
      compact_properties[i] = quantize_abs_diff_property();
949
382
    } else if (props_to_use[i] == kWPProp) {
950
74
      compact_properties[i] = quantize_wp();
951
308
    } else {
952
308
      compact_properties[i] = quantize_diff_property();
953
308
    }
954
518
    if (i < num_static_props) {
955
136
      QuantMap(compact_properties[i], static_property_mapping[i],
956
136
               kPropertyRange * 2 + 1, kPropertyRange);
957
382
    } else {
958
382
      QuantMap(compact_properties[i], property_mapping[i - num_static_props],
959
382
               kPropertyRange * 2 + 1, kPropertyRange);
960
382
    }
961
518
  }
962
74
}
963
964
void CollectPixelSamples(const Image &image, const ModularOptions &options,
965
                         uint32_t group_id,
966
                         std::vector<uint32_t> &group_pixel_count,
967
                         std::vector<uint32_t> &channel_pixel_count,
968
                         std::vector<pixel_type> &pixel_samples,
969
77
                         std::vector<pixel_type> &diff_samples) {
970
77
  if (options.nb_repeats == 0) return;
971
77
  if (group_pixel_count.size() <= group_id) {
972
77
    group_pixel_count.resize(group_id + 1);
973
77
  }
974
77
  if (channel_pixel_count.size() < image.channel.size()) {
975
74
    channel_pixel_count.resize(image.channel.size());
976
74
  }
977
77
  Rng rng(group_id);
978
  // Sample 10% of the final number of samples for property quantization.
979
77
  float fraction = std::min(options.nb_repeats * 0.1, 0.99);
980
77
  Rng::GeometricDistribution dist = Rng::MakeGeometric(fraction);
981
77
  size_t total_pixels = 0;
982
77
  std::vector<size_t> channel_ids;
983
278
  for (size_t i = 0; i < image.channel.size(); i++) {
984
202
    if (i >= image.nb_meta_channels &&
985
202
        (image.channel[i].w > options.max_chan_size ||
986
201
         image.channel[i].h > options.max_chan_size)) {
987
1
      break;
988
1
    }
989
201
    if (image.channel[i].w <= 1 || image.channel[i].h == 0) {
990
1
      continue;  // skip empty or width-1 channels.
991
1
    }
992
200
    channel_ids.push_back(i);
993
200
    group_pixel_count[group_id] += image.channel[i].w * image.channel[i].h;
994
200
    channel_pixel_count[i] += image.channel[i].w * image.channel[i].h;
995
200
    total_pixels += image.channel[i].w * image.channel[i].h;
996
200
  }
997
77
  if (channel_ids.empty()) return;
998
75
  pixel_samples.reserve(pixel_samples.size() + fraction * total_pixels);
999
75
  diff_samples.reserve(diff_samples.size() + fraction * total_pixels);
1000
75
  size_t i = 0;
1001
75
  size_t y = 0;
1002
75
  size_t x = 0;
1003
56.4k
  auto advance = [&](size_t amount) {
1004
56.4k
    x += amount;
1005
    // Detect row overflow (rare).
1006
69.4k
    while (x >= image.channel[channel_ids[i]].w) {
1007
13.0k
      x -= image.channel[channel_ids[i]].w;
1008
13.0k
      y++;
1009
      // Detect end-of-channel (even rarer).
1010
13.0k
      if (y == image.channel[channel_ids[i]].h) {
1011
200
        i++;
1012
200
        y = 0;
1013
200
        if (i >= channel_ids.size()) {
1014
75
          return;
1015
75
        }
1016
200
      }
1017
13.0k
    }
1018
56.4k
  };
1019
75
  advance(rng.Geometric(dist));
1020
56.4k
  for (; i < channel_ids.size(); advance(rng.Geometric(dist) + 1)) {
1021
56.3k
    const pixel_type *row = image.channel[channel_ids[i]].Row(y);
1022
56.3k
    pixel_samples.push_back(row[x]);
1023
56.3k
    size_t xp = x == 0 ? 1 : x - 1;
1024
56.3k
    diff_samples.push_back(static_cast<int64_t>(row[x]) - row[xp]);
1025
56.3k
  }
1026
75
}
1027
1028
// TODO(veluca): very simple encoding scheme. This should be improved.
1029
// Serializes `tree` into `tokens` in breadth-first order, starting at node 0,
// and rebuilds the same tree in breadth-first layout in `decoder_tree`
// (cleared first).
//
// Every node emits its property + 1. A leaf (property == -1) additionally
// emits predictor, packed predictor offset and the multiplier split into its
// number of trailing zero bits and the remaining odd factor minus one; leaves
// are numbered in visiting order in `decoder_tree`. An inner node emits its
// packed split value after its children were queued.
//
// Fails if the tree exceeds kMaxTreeSize, if a visited node index lies
// outside the tree (this includes an empty tree), if a property is below -1,
// or if a leaf predictor is not below Predictor::Best. On failure `tokens`
// and `decoder_tree` may hold partial output.
Status TokenizeTree(const Tree &tree, std::vector<Token> *tokens,
                    Tree *decoder_tree) {
  JXL_ENSURE(tree.size() <= kMaxTreeSize);
  std::queue<int> q;
  q.push(0);
  size_t leaf_id = 0;
  decoder_tree->clear();
  while (!q.empty()) {
    int cur = q.front();
    q.pop();
    // Reject an empty tree or a dangling child index instead of reading past
    // the end of `tree`.
    JXL_ENSURE(cur >= 0 && static_cast<size_t>(cur) < tree.size());
    const auto &node = tree[cur];
    JXL_ENSURE(node.property >= -1);
    tokens->emplace_back(kPropertyContext, node.property + 1);
    if (node.property == -1) {
      tokens->emplace_back(kPredictorContext,
                           static_cast<int>(node.predictor));
      tokens->emplace_back(kOffsetContext, PackSigned(node.predictor_offset));
      // multiplier == (mul_bits + 1) << mul_log
      uint32_t mul_log = Num0BitsBelowLS1Bit_Nonzero(node.multiplier);
      uint32_t mul_bits = (node.multiplier >> mul_log) - 1;
      tokens->emplace_back(kMultiplierLogContext, mul_log);
      tokens->emplace_back(kMultiplierBitsContext, mul_bits);
      JXL_ENSURE(node.predictor < Predictor::Best);
      decoder_tree->emplace_back(-1, 0, leaf_id++, 0, node.predictor,
                                 node.predictor_offset, node.multiplier);
      continue;
    }
    // In breadth-first layout the children of this node are placed right
    // after all nodes that are already emitted or still queued.
    decoder_tree->emplace_back(node.property, node.splitval,
                               decoder_tree->size() + q.size() + 1,
                               decoder_tree->size() + q.size() + 2,
                               Predictor::Zero, 0, 1);
    q.push(node.lchild);
    q.push(node.rchild);
    tokens->emplace_back(kSplitValContext, PackSigned(node.splitval));
  }
  return true;
}
1066
1067
}  // namespace jxl
1068
#endif  // HWY_ONCE