Coverage Report

Created: 2025-06-16 07:00

/src/libjxl/lib/jxl/modular/encoding/enc_ma.cc
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#include "lib/jxl/modular/encoding/enc_ma.h"
7
8
#include <algorithm>
9
#include <cstdint>
10
#include <cstdlib>
11
#include <cstring>
12
#include <limits>
13
#include <numeric>
14
#include <queue>
15
#include <vector>
16
17
#include "lib/jxl/ans_params.h"
18
#include "lib/jxl/base/bits.h"
19
#include "lib/jxl/base/common.h"
20
#include "lib/jxl/base/compiler_specific.h"
21
#include "lib/jxl/base/status.h"
22
#include "lib/jxl/dec_ans.h"
23
#include "lib/jxl/modular/encoding/dec_ma.h"
24
#include "lib/jxl/modular/encoding/ma_common.h"
25
#include "lib/jxl/modular/modular_image.h"
26
27
#undef HWY_TARGET_INCLUDE
28
#define HWY_TARGET_INCLUDE "lib/jxl/modular/encoding/enc_ma.cc"
29
#include <hwy/foreach_target.h>
30
#include <hwy/highway.h>
31
32
#include "lib/jxl/base/fast_math-inl.h"
33
#include "lib/jxl/base/random.h"
34
#include "lib/jxl/enc_ans.h"
35
#include "lib/jxl/modular/encoding/context_predict.h"
36
#include "lib/jxl/modular/options.h"
37
#include "lib/jxl/pack_signed.h"
38
HWY_BEFORE_NAMESPACE();
39
namespace jxl {
40
namespace HWY_NAMESPACE {
41
42
// These templates are not found via ADL.
43
using hwy::HWY_NAMESPACE::Eq;
44
using hwy::HWY_NAMESPACE::IfThenElse;
45
using hwy::HWY_NAMESPACE::Lt;
46
using hwy::HWY_NAMESPACE::Max;
47
48
const HWY_FULL(float) df;
49
const HWY_FULL(int32_t) di;
50
6.07k
size_t Padded(size_t x) { return RoundUpTo(x, Lanes(df)); }
Unexecuted instantiation: jxl::N_SSE4::Padded(unsigned long)
jxl::N_AVX2::Padded(unsigned long)
Line
Count
Source
50
6.07k
size_t Padded(size_t x) { return RoundUpTo(x, Lanes(df)); }
Unexecuted instantiation: jxl::N_SSE2::Padded(unsigned long)
51
52
// Compute entropy of the histogram, taking into account the minimum probability
53
// for symbols with non-zero counts.
54
370k
float EstimateBits(const int32_t *counts, size_t num_symbols) {
55
370k
  int32_t total = std::accumulate(counts, counts + num_symbols, 0);
56
370k
  const auto zero = Zero(df);
57
370k
  const auto minprob = Set(df, 1.0f / ANS_TAB_SIZE);
58
370k
  const auto inv_total = Set(df, 1.0f / total);
59
370k
  auto bits_lanes = Zero(df);
60
370k
  auto total_v = Set(di, total);
61
2.77M
  for (size_t i = 0; i < num_symbols; i += Lanes(df)) {
62
2.40M
    const auto counts_iv = LoadU(di, &counts[i]);
63
2.40M
    const auto counts_fv = ConvertTo(df, counts_iv);
64
2.40M
    const auto probs = Mul(counts_fv, inv_total);
65
2.40M
    const auto mprobs = Max(probs, minprob);
66
2.40M
    const auto nbps = IfThenElse(Eq(counts_iv, total_v), BitCast(di, zero),
67
2.40M
                                 BitCast(di, FastLog2f(df, mprobs)));
68
2.40M
    bits_lanes = Sub(bits_lanes, Mul(counts_fv, BitCast(df, nbps)));
69
2.40M
  }
70
370k
  return GetLane(SumOfLanes(df, bits_lanes));
71
370k
}
Unexecuted instantiation: jxl::N_SSE4::EstimateBits(int const*, unsigned long)
jxl::N_AVX2::EstimateBits(int const*, unsigned long)
Line
Count
Source
54
370k
float EstimateBits(const int32_t *counts, size_t num_symbols) {
55
370k
  int32_t total = std::accumulate(counts, counts + num_symbols, 0);
56
370k
  const auto zero = Zero(df);
57
370k
  const auto minprob = Set(df, 1.0f / ANS_TAB_SIZE);
58
370k
  const auto inv_total = Set(df, 1.0f / total);
59
370k
  auto bits_lanes = Zero(df);
60
370k
  auto total_v = Set(di, total);
61
2.77M
  for (size_t i = 0; i < num_symbols; i += Lanes(df)) {
62
2.40M
    const auto counts_iv = LoadU(di, &counts[i]);
63
2.40M
    const auto counts_fv = ConvertTo(df, counts_iv);
64
2.40M
    const auto probs = Mul(counts_fv, inv_total);
65
2.40M
    const auto mprobs = Max(probs, minprob);
66
2.40M
    const auto nbps = IfThenElse(Eq(counts_iv, total_v), BitCast(di, zero),
67
2.40M
                                 BitCast(di, FastLog2f(df, mprobs)));
68
2.40M
    bits_lanes = Sub(bits_lanes, Mul(counts_fv, BitCast(df, nbps)));
69
2.40M
  }
70
370k
  return GetLane(SumOfLanes(df, bits_lanes));
71
370k
}
Unexecuted instantiation: jxl::N_SSE2::EstimateBits(int const*, unsigned long)
72
73
void MakeSplitNode(size_t pos, int property, int splitval, Predictor lpred,
74
2.98k
                   int64_t loff, Predictor rpred, int64_t roff, Tree *tree) {
75
  // Note that the tree splits on *strictly greater*.
76
2.98k
  (*tree)[pos].lchild = tree->size();
77
2.98k
  (*tree)[pos].rchild = tree->size() + 1;
78
2.98k
  (*tree)[pos].splitval = splitval;
79
2.98k
  (*tree)[pos].property = property;
80
2.98k
  tree->emplace_back();
81
2.98k
  tree->back().property = -1;
82
2.98k
  tree->back().predictor = rpred;
83
2.98k
  tree->back().predictor_offset = roff;
84
2.98k
  tree->back().multiplier = 1;
85
2.98k
  tree->emplace_back();
86
2.98k
  tree->back().property = -1;
87
2.98k
  tree->back().predictor = lpred;
88
2.98k
  tree->back().predictor_offset = loff;
89
2.98k
  tree->back().multiplier = 1;
90
2.98k
}
Unexecuted instantiation: jxl::N_SSE4::MakeSplitNode(unsigned long, int, int, jxl::Predictor, long, jxl::Predictor, long, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
jxl::N_AVX2::MakeSplitNode(unsigned long, int, int, jxl::Predictor, long, jxl::Predictor, long, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
Line
Count
Source
74
2.98k
                   int64_t loff, Predictor rpred, int64_t roff, Tree *tree) {
75
  // Note that the tree splits on *strictly greater*.
76
2.98k
  (*tree)[pos].lchild = tree->size();
77
2.98k
  (*tree)[pos].rchild = tree->size() + 1;
78
2.98k
  (*tree)[pos].splitval = splitval;
79
2.98k
  (*tree)[pos].property = property;
80
2.98k
  tree->emplace_back();
81
2.98k
  tree->back().property = -1;
82
2.98k
  tree->back().predictor = rpred;
83
2.98k
  tree->back().predictor_offset = roff;
84
2.98k
  tree->back().multiplier = 1;
85
2.98k
  tree->emplace_back();
86
2.98k
  tree->back().property = -1;
87
2.98k
  tree->back().predictor = lpred;
88
2.98k
  tree->back().predictor_offset = loff;
89
2.98k
  tree->back().multiplier = 1;
90
2.98k
}
Unexecuted instantiation: jxl::N_SSE2::MakeSplitNode(unsigned long, int, int, jxl::Predictor, long, jxl::Predictor, long, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
91
92
enum class IntersectionType { kNone, kPartial, kInside };
93
IntersectionType BoxIntersects(StaticPropRange needle, StaticPropRange haystack,
94
5.89k
                               uint32_t &partial_axis, uint32_t &partial_val) {
95
5.89k
  bool partial = false;
96
17.6k
  for (size_t i = 0; i < kNumStaticProperties; i++) {
97
11.7k
    if (haystack[i][0] >= needle[i][1]) {
98
0
      return IntersectionType::kNone;
99
0
    }
100
11.7k
    if (haystack[i][1] <= needle[i][0]) {
101
0
      return IntersectionType::kNone;
102
0
    }
103
11.7k
    if (haystack[i][0] <= needle[i][0] && haystack[i][1] >= needle[i][1]) {
104
11.7k
      continue;
105
11.7k
    }
106
0
    partial = true;
107
0
    partial_axis = i;
108
0
    if (haystack[i][0] > needle[i][0] && haystack[i][0] < needle[i][1]) {
109
0
      partial_val = haystack[i][0] - 1;
110
0
    } else {
111
0
      JXL_DASSERT(haystack[i][1] > needle[i][0] &&
112
0
                  haystack[i][1] < needle[i][1]);
113
0
      partial_val = haystack[i][1] - 1;
114
0
    }
115
0
  }
116
5.89k
  return partial ? IntersectionType::kPartial : IntersectionType::kInside;
117
5.89k
}
Unexecuted instantiation: jxl::N_SSE4::BoxIntersects(std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, unsigned int&, unsigned int&)
jxl::N_AVX2::BoxIntersects(std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, unsigned int&, unsigned int&)
Line
Count
Source
94
5.89k
                               uint32_t &partial_axis, uint32_t &partial_val) {
95
5.89k
  bool partial = false;
96
17.6k
  for (size_t i = 0; i < kNumStaticProperties; i++) {
97
11.7k
    if (haystack[i][0] >= needle[i][1]) {
98
0
      return IntersectionType::kNone;
99
0
    }
100
11.7k
    if (haystack[i][1] <= needle[i][0]) {
101
0
      return IntersectionType::kNone;
102
0
    }
103
11.7k
    if (haystack[i][0] <= needle[i][0] && haystack[i][1] >= needle[i][1]) {
104
11.7k
      continue;
105
11.7k
    }
106
0
    partial = true;
107
0
    partial_axis = i;
108
0
    if (haystack[i][0] > needle[i][0] && haystack[i][0] < needle[i][1]) {
109
0
      partial_val = haystack[i][0] - 1;
110
0
    } else {
111
0
      JXL_DASSERT(haystack[i][1] > needle[i][0] &&
112
0
                  haystack[i][1] < needle[i][1]);
113
0
      partial_val = haystack[i][1] - 1;
114
0
    }
115
0
  }
116
5.89k
  return partial ? IntersectionType::kPartial : IntersectionType::kInside;
117
5.89k
}
Unexecuted instantiation: jxl::N_SSE2::BoxIntersects(std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, unsigned int&, unsigned int&)
118
119
void SplitTreeSamples(TreeSamples &tree_samples, size_t begin, size_t pos,
120
2.98k
                      size_t end, size_t prop, uint32_t val) {
121
2.98k
  size_t begin_pos = begin;
122
2.98k
  size_t end_pos = pos;
123
196k
  do {
124
502k
    while (begin_pos < pos && tree_samples.Property(prop, begin_pos) <= val) {
125
306k
      ++begin_pos;
126
306k
    }
127
521k
    while (end_pos < end && tree_samples.Property(prop, end_pos) > val) {
128
324k
      ++end_pos;
129
324k
    }
130
196k
    if (begin_pos < pos && end_pos < end) {
131
195k
      tree_samples.Swap(begin_pos, end_pos);
132
195k
    }
133
196k
    ++begin_pos;
134
196k
    ++end_pos;
135
196k
  } while (begin_pos < pos && end_pos < end);
136
2.98k
}
Unexecuted instantiation: jxl::N_SSE4::SplitTreeSamples(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int)
jxl::N_AVX2::SplitTreeSamples(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int)
Line
Count
Source
120
2.98k
                      size_t end, size_t prop, uint32_t val) {
121
2.98k
  size_t begin_pos = begin;
122
2.98k
  size_t end_pos = pos;
123
196k
  do {
124
502k
    while (begin_pos < pos && tree_samples.Property(prop, begin_pos) <= val) {
125
306k
      ++begin_pos;
126
306k
    }
127
521k
    while (end_pos < end && tree_samples.Property(prop, end_pos) > val) {
128
324k
      ++end_pos;
129
324k
    }
130
196k
    if (begin_pos < pos && end_pos < end) {
131
195k
      tree_samples.Swap(begin_pos, end_pos);
132
195k
    }
133
196k
    ++begin_pos;
134
196k
    ++end_pos;
135
196k
  } while (begin_pos < pos && end_pos < end);
136
2.98k
}
Unexecuted instantiation: jxl::N_SSE2::SplitTreeSamples(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int)
137
138
void FindBestSplit(TreeSamples &tree_samples, float threshold,
139
                   const std::vector<ModularMultiplierInfo> &mul_info,
140
                   StaticPropRange initial_static_prop_range,
141
111
                   float fast_decode_multiplier, Tree *tree) {
142
111
  struct NodeInfo {
143
111
    size_t pos;
144
111
    size_t begin;
145
111
    size_t end;
146
111
    uint64_t used_properties;
147
111
    StaticPropRange static_prop_range;
148
111
  };
149
111
  std::vector<NodeInfo> nodes;
150
111
  nodes.push_back(NodeInfo{0, 0, tree_samples.NumDistinctSamples(), 0,
151
111
                           initial_static_prop_range});
152
153
111
  size_t num_predictors = tree_samples.NumPredictors();
154
111
  size_t num_properties = tree_samples.NumProperties();
155
156
  // TODO(veluca): consider parallelizing the search (processing multiple nodes
157
  // at a time).
158
6.18k
  while (!nodes.empty()) {
159
6.07k
    size_t pos = nodes.back().pos;
160
6.07k
    size_t begin = nodes.back().begin;
161
6.07k
    size_t end = nodes.back().end;
162
6.07k
    uint64_t used_properties = nodes.back().used_properties;
163
6.07k
    StaticPropRange static_prop_range = nodes.back().static_prop_range;
164
6.07k
    nodes.pop_back();
165
6.07k
    if (begin == end) continue;
166
167
6.07k
    struct SplitInfo {
168
6.07k
      size_t prop = 0;
169
6.07k
      uint32_t val = 0;
170
6.07k
      size_t pos = 0;
171
6.07k
      float lcost = std::numeric_limits<float>::max();
172
6.07k
      float rcost = std::numeric_limits<float>::max();
173
6.07k
      Predictor lpred = Predictor::Zero;
174
6.07k
      Predictor rpred = Predictor::Zero;
175
213k
      float Cost() const { return lcost + rcost; }
Unexecuted instantiation: enc_ma.cc:jxl::N_SSE4::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::SplitInfo::Cost() const
enc_ma.cc:jxl::N_AVX2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::SplitInfo::Cost() const
Line
Count
Source
175
213k
      float Cost() const { return lcost + rcost; }
Unexecuted instantiation: enc_ma.cc:jxl::N_SSE2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::SplitInfo::Cost() const
176
6.07k
    };
177
178
6.07k
    SplitInfo best_split_static_constant;
179
6.07k
    SplitInfo best_split_static;
180
6.07k
    SplitInfo best_split_nonstatic;
181
6.07k
    SplitInfo best_split_nowp;
182
183
6.07k
    JXL_DASSERT(begin <= end);
184
6.07k
    JXL_DASSERT(end <= tree_samples.NumDistinctSamples());
185
186
    // Compute the maximum token in the range.
187
6.07k
    size_t max_symbols = 0;
188
12.1k
    for (size_t pred = 0; pred < num_predictors; pred++) {
189
1.23M
      for (size_t i = begin; i < end; i++) {
190
1.22M
        uint32_t tok = tree_samples.Token(pred, i);
191
1.22M
        max_symbols = max_symbols > tok + 1 ? max_symbols : tok + 1;
192
1.22M
      }
193
6.07k
    }
194
6.07k
    max_symbols = Padded(max_symbols);
195
6.07k
    std::vector<int32_t> counts(max_symbols * num_predictors);
196
6.07k
    std::vector<uint32_t> tot_extra_bits(num_predictors);
197
12.1k
    for (size_t pred = 0; pred < num_predictors; pred++) {
198
1.23M
      for (size_t i = begin; i < end; i++) {
199
1.22M
        counts[pred * max_symbols + tree_samples.Token(pred, i)] +=
200
1.22M
            tree_samples.Count(i);
201
1.22M
        tot_extra_bits[pred] +=
202
1.22M
            tree_samples.NBits(pred, i) * tree_samples.Count(i);
203
1.22M
      }
204
6.07k
    }
205
206
6.07k
    float base_bits;
207
6.07k
    {
208
6.07k
      size_t pred = tree_samples.PredictorIndex((*tree)[pos].predictor);
209
6.07k
      base_bits =
210
6.07k
          EstimateBits(counts.data() + pred * max_symbols, max_symbols) +
211
6.07k
          tot_extra_bits[pred];
212
6.07k
    }
213
214
6.07k
    SplitInfo *best = &best_split_nonstatic;
215
216
6.07k
    SplitInfo forced_split;
217
    // The multiplier ranges cut halfway through the current ranges of static
218
    // properties. We do this even if the current node is not a leaf, to
219
    // minimize the number of nodes in the resulting tree.
220
6.07k
    for (const auto &mmi : mul_info) {
221
5.89k
      uint32_t axis;
222
5.89k
      uint32_t val;
223
5.89k
      IntersectionType t =
224
5.89k
          BoxIntersects(static_prop_range, mmi.range, axis, val);
225
5.89k
      if (t == IntersectionType::kNone) continue;
226
5.89k
      if (t == IntersectionType::kInside) {
227
5.89k
        (*tree)[pos].multiplier = mmi.multiplier;
228
5.89k
        break;
229
5.89k
      }
230
0
      if (t == IntersectionType::kPartial) {
231
0
        forced_split.val = tree_samples.QuantizeProperty(axis, val);
232
0
        forced_split.prop = axis;
233
0
        forced_split.lcost = forced_split.rcost = base_bits / 2 - threshold;
234
0
        forced_split.lpred = forced_split.rpred = (*tree)[pos].predictor;
235
0
        best = &forced_split;
236
0
        best->pos = begin;
237
0
        JXL_DASSERT(best->prop == tree_samples.PropertyFromIndex(best->prop));
238
0
        for (size_t x = begin; x < end; x++) {
239
0
          if (tree_samples.Property(best->prop, x) <= best->val) {
240
0
            best->pos++;
241
0
          }
242
0
        }
243
0
        break;
244
0
      }
245
0
    }
246
247
6.07k
    if (best != &forced_split) {
248
6.07k
      std::vector<int> prop_value_used_count;
249
6.07k
      std::vector<int> count_increase;
250
6.07k
      std::vector<size_t> extra_bits_increase;
251
      // For each property, compute which of its values are used, and what
252
      // tokens correspond to those usages. Then, iterate through the values,
253
      // and compute the entropy of each side of the split (of the form `prop >
254
      // threshold`). Finally, find the split that minimizes the cost.
255
6.07k
      struct CostInfo {
256
6.07k
        float cost = std::numeric_limits<float>::max();
257
6.07k
        float extra_cost = 0;
258
364k
        float Cost() const { return cost + extra_cost; }
Unexecuted instantiation: enc_ma.cc:jxl::N_SSE4::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::CostInfo::Cost() const
enc_ma.cc:jxl::N_AVX2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::CostInfo::Cost() const
Line
Count
Source
258
364k
        float Cost() const { return cost + extra_cost; }
Unexecuted instantiation: enc_ma.cc:jxl::N_SSE2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::CostInfo::Cost() const
259
6.07k
        Predictor pred;  // will be uninitialized in some cases, but never used.
260
6.07k
      };
261
6.07k
      std::vector<CostInfo> costs_l;
262
6.07k
      std::vector<CostInfo> costs_r;
263
264
6.07k
      std::vector<int32_t> counts_above(max_symbols);
265
6.07k
      std::vector<int32_t> counts_below(max_symbols);
266
267
      // The lower the threshold, the higher the expected noisiness of the
268
      // estimate. Thus, discourage changing predictors.
269
6.07k
      float change_pred_penalty = 800.0f / (100.0f + threshold);
270
48.1k
      for (size_t prop = 0; prop < num_properties && base_bits > threshold;
271
42.1k
           prop++) {
272
42.1k
        costs_l.clear();
273
42.1k
        costs_r.clear();
274
42.1k
        size_t prop_size = tree_samples.NumPropertyValues(prop);
275
42.1k
        if (extra_bits_increase.size() < prop_size) {
276
12.2k
          count_increase.resize(prop_size * max_symbols);
277
12.2k
          extra_bits_increase.resize(prop_size);
278
12.2k
        }
279
        // Clear prop_value_used_count (which cannot be cleared "on the go")
280
42.1k
        prop_value_used_count.clear();
281
42.1k
        prop_value_used_count.resize(prop_size);
282
283
42.1k
        size_t first_used = prop_size;
284
42.1k
        size_t last_used = 0;
285
286
        // TODO(veluca): consider finding multiple splits along a single
287
        // property at the same time, possibly with a bottom-up approach.
288
8.62M
        for (size_t i = begin; i < end; i++) {
289
8.57M
          size_t p = tree_samples.Property(prop, i);
290
8.57M
          prop_value_used_count[p]++;
291
8.57M
          last_used = std::max(last_used, p);
292
8.57M
          first_used = std::min(first_used, p);
293
8.57M
        }
294
42.1k
        costs_l.resize(last_used - first_used);
295
42.1k
        costs_r.resize(last_used - first_used);
296
        // For all predictors, compute the right and left costs of each split.
297
84.2k
        for (size_t pred = 0; pred < num_predictors; pred++) {
298
          // Compute cost and histogram increments for each property value.
299
8.62M
          for (size_t i = begin; i < end; i++) {
300
8.57M
            size_t p = tree_samples.Property(prop, i);
301
8.57M
            size_t cnt = tree_samples.Count(i);
302
8.57M
            size_t sym = tree_samples.Token(pred, i);
303
8.57M
            count_increase[p * max_symbols + sym] += cnt;
304
8.57M
            extra_bits_increase[p] += tree_samples.NBits(pred, i) * cnt;
305
8.57M
          }
306
42.1k
          memcpy(counts_above.data(), counts.data() + pred * max_symbols,
307
42.1k
                 max_symbols * sizeof counts_above[0]);
308
42.1k
          memset(counts_below.data(), 0, max_symbols * sizeof counts_below[0]);
309
42.1k
          size_t extra_bits_below = 0;
310
          // Exclude last used: this ensures neither counts_above nor
311
          // counts_below is empty.
312
351k
          for (size_t i = first_used; i < last_used; i++) {
313
309k
            if (!prop_value_used_count[i]) continue;
314
182k
            extra_bits_below += extra_bits_increase[i];
315
            // The increase for this property value has been used, and will not
316
            // be used again: clear it. Also below.
317
182k
            extra_bits_increase[i] = 0;
318
9.63M
            for (size_t sym = 0; sym < max_symbols; sym++) {
319
9.44M
              counts_above[sym] -= count_increase[i * max_symbols + sym];
320
9.44M
              counts_below[sym] += count_increase[i * max_symbols + sym];
321
9.44M
              count_increase[i * max_symbols + sym] = 0;
322
9.44M
            }
323
182k
            float rcost = EstimateBits(counts_above.data(), max_symbols) +
324
182k
                          tot_extra_bits[pred] - extra_bits_below;
325
182k
            float lcost = EstimateBits(counts_below.data(), max_symbols) +
326
182k
                          extra_bits_below;
327
182k
            JXL_DASSERT(extra_bits_below <= tot_extra_bits[pred]);
328
182k
            float penalty = 0;
329
            // Never discourage moving away from the Weighted predictor.
330
182k
            if (tree_samples.PredictorFromIndex(pred) !=
331
182k
                    (*tree)[pos].predictor &&
332
182k
                (*tree)[pos].predictor != Predictor::Weighted) {
333
0
              penalty = change_pred_penalty;
334
0
            }
335
            // If everything else is equal, disfavour Weighted (slower) and
336
            // favour Zero (faster if it's the only predictor used in a
337
            // group+channel combination)
338
182k
            if (tree_samples.PredictorFromIndex(pred) == Predictor::Weighted) {
339
0
              penalty += 1e-8;
340
0
            }
341
182k
            if (tree_samples.PredictorFromIndex(pred) == Predictor::Zero) {
342
0
              penalty -= 1e-8;
343
0
            }
344
182k
            if (rcost + penalty < costs_r[i - first_used].Cost()) {
345
182k
              costs_r[i - first_used].cost = rcost;
346
182k
              costs_r[i - first_used].extra_cost = penalty;
347
182k
              costs_r[i - first_used].pred =
348
182k
                  tree_samples.PredictorFromIndex(pred);
349
182k
            }
350
182k
            if (lcost + penalty < costs_l[i - first_used].Cost()) {
351
182k
              costs_l[i - first_used].cost = lcost;
352
182k
              costs_l[i - first_used].extra_cost = penalty;
353
182k
              costs_l[i - first_used].pred =
354
182k
                  tree_samples.PredictorFromIndex(pred);
355
182k
            }
356
182k
          }
357
42.1k
        }
358
        // Iterate through the possible splits and find the one with minimum sum
359
        // of costs of the two sides.
360
42.1k
        size_t split = begin;
361
351k
        for (size_t i = first_used; i < last_used; i++) {
362
309k
          if (!prop_value_used_count[i]) continue;
363
182k
          split += prop_value_used_count[i];
364
182k
          float rcost = costs_r[i - first_used].cost;
365
182k
          float lcost = costs_l[i - first_used].cost;
366
          // WP was not used + we would use the WP property or predictor
367
182k
          bool adds_wp =
368
182k
              (tree_samples.PropertyFromIndex(prop) == kWPProp &&
369
182k
               (used_properties & (1LU << prop)) == 0) ||
370
182k
              ((costs_l[i - first_used].pred == Predictor::Weighted ||
371
128k
                costs_r[i - first_used].pred == Predictor::Weighted) &&
372
128k
               (*tree)[pos].predictor != Predictor::Weighted);
373
182k
          bool zero_entropy_side = rcost == 0 || lcost == 0;
374
375
182k
          SplitInfo &best_ref =
376
182k
              tree_samples.PropertyFromIndex(prop) < kNumStaticProperties
377
182k
                  ? (zero_entropy_side ? best_split_static_constant
378
1.27k
                                       : best_split_static)
379
182k
                  : (adds_wp ? best_split_nonstatic : best_split_nowp);
380
182k
          if (lcost + rcost < best_ref.Cost()) {
381
39.1k
            best_ref.prop = prop;
382
39.1k
            best_ref.val = i;
383
39.1k
            best_ref.pos = split;
384
39.1k
            best_ref.lcost = lcost;
385
39.1k
            best_ref.lpred = costs_l[i - first_used].pred;
386
39.1k
            best_ref.rcost = rcost;
387
39.1k
            best_ref.rpred = costs_r[i - first_used].pred;
388
39.1k
          }
389
182k
        }
390
        // Clear extra_bits_increase and cost_increase for last_used.
391
42.1k
        extra_bits_increase[last_used] = 0;
392
2.39M
        for (size_t sym = 0; sym < max_symbols; sym++) {
393
2.35M
          count_increase[last_used * max_symbols + sym] = 0;
394
2.35M
        }
395
42.1k
      }
396
397
      // Try to avoid introducing WP.
398
6.07k
      if (best_split_nowp.Cost() + threshold < base_bits &&
399
6.07k
          best_split_nowp.Cost() <= fast_decode_multiplier * best->Cost()) {
400
2.76k
        best = &best_split_nowp;
401
2.76k
      }
402
      // Split along static props if possible and not significantly more
403
      // expensive.
404
6.07k
      if (best_split_static.Cost() + threshold < base_bits &&
405
6.07k
          best_split_static.Cost() <= fast_decode_multiplier * best->Cost()) {
406
402
        best = &best_split_static;
407
402
      }
408
      // Split along static props to create constant nodes if possible.
409
6.07k
      if (best_split_static_constant.Cost() + threshold < base_bits) {
410
2
        best = &best_split_static_constant;
411
2
      }
412
6.07k
    }
413
414
6.07k
    if (best->Cost() + threshold < base_bits) {
415
2.98k
      uint32_t p = tree_samples.PropertyFromIndex(best->prop);
416
2.98k
      pixel_type dequant =
417
2.98k
          tree_samples.UnquantizeProperty(best->prop, best->val);
418
      // Split node and try to split children.
419
2.98k
      MakeSplitNode(pos, p, dequant, best->lpred, 0, best->rpred, 0, tree);
420
      // "Sort" according to winning property
421
2.98k
      SplitTreeSamples(tree_samples, begin, best->pos, end, best->prop,
422
2.98k
                       best->val);
423
2.98k
      if (p >= kNumStaticProperties) {
424
2.57k
        used_properties |= 1 << best->prop;
425
2.57k
      }
426
2.98k
      auto new_sp_range = static_prop_range;
427
2.98k
      if (p < kNumStaticProperties) {
428
404
        JXL_DASSERT(static_cast<uint32_t>(dequant + 1) <= new_sp_range[p][1]);
429
404
        new_sp_range[p][1] = dequant + 1;
430
404
        JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]);
431
404
      }
432
2.98k
      nodes.push_back(NodeInfo{(*tree)[pos].rchild, begin, best->pos,
433
2.98k
                               used_properties, new_sp_range});
434
2.98k
      new_sp_range = static_prop_range;
435
2.98k
      if (p < kNumStaticProperties) {
436
404
        JXL_DASSERT(new_sp_range[p][0] <= static_cast<uint32_t>(dequant + 1));
437
404
        new_sp_range[p][0] = dequant + 1;
438
404
        JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]);
439
404
      }
440
2.98k
      nodes.push_back(NodeInfo{(*tree)[pos].lchild, best->pos, end,
441
2.98k
                               used_properties, new_sp_range});
442
2.98k
    }
443
6.07k
  }
444
111
}
Unexecuted instantiation: jxl::N_SSE4::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
jxl::N_AVX2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
Line
Count
Source
141
111
                   float fast_decode_multiplier, Tree *tree) {
142
111
  struct NodeInfo {
143
111
    size_t pos;
144
111
    size_t begin;
145
111
    size_t end;
146
111
    uint64_t used_properties;
147
111
    StaticPropRange static_prop_range;
148
111
  };
149
111
  std::vector<NodeInfo> nodes;
150
111
  nodes.push_back(NodeInfo{0, 0, tree_samples.NumDistinctSamples(), 0,
151
111
                           initial_static_prop_range});
152
153
111
  size_t num_predictors = tree_samples.NumPredictors();
154
111
  size_t num_properties = tree_samples.NumProperties();
155
156
  // TODO(veluca): consider parallelizing the search (processing multiple nodes
157
  // at a time).
158
6.18k
  while (!nodes.empty()) {
159
6.07k
    size_t pos = nodes.back().pos;
160
6.07k
    size_t begin = nodes.back().begin;
161
6.07k
    size_t end = nodes.back().end;
162
6.07k
    uint64_t used_properties = nodes.back().used_properties;
163
6.07k
    StaticPropRange static_prop_range = nodes.back().static_prop_range;
164
6.07k
    nodes.pop_back();
165
6.07k
    if (begin == end) continue;
166
167
6.07k
    struct SplitInfo {
168
6.07k
      size_t prop = 0;
169
6.07k
      uint32_t val = 0;
170
6.07k
      size_t pos = 0;
171
6.07k
      float lcost = std::numeric_limits<float>::max();
172
6.07k
      float rcost = std::numeric_limits<float>::max();
173
6.07k
      Predictor lpred = Predictor::Zero;
174
6.07k
      Predictor rpred = Predictor::Zero;
175
6.07k
      float Cost() const { return lcost + rcost; }
176
6.07k
    };
177
178
6.07k
    SplitInfo best_split_static_constant;
179
6.07k
    SplitInfo best_split_static;
180
6.07k
    SplitInfo best_split_nonstatic;
181
6.07k
    SplitInfo best_split_nowp;
182
183
6.07k
    JXL_DASSERT(begin <= end);
184
6.07k
    JXL_DASSERT(end <= tree_samples.NumDistinctSamples());
185
186
    // Compute the maximum token in the range.
187
6.07k
    size_t max_symbols = 0;
188
12.1k
    for (size_t pred = 0; pred < num_predictors; pred++) {
189
1.23M
      for (size_t i = begin; i < end; i++) {
190
1.22M
        uint32_t tok = tree_samples.Token(pred, i);
191
1.22M
        max_symbols = max_symbols > tok + 1 ? max_symbols : tok + 1;
192
1.22M
      }
193
6.07k
    }
194
6.07k
    max_symbols = Padded(max_symbols);
195
6.07k
    std::vector<int32_t> counts(max_symbols * num_predictors);
196
6.07k
    std::vector<uint32_t> tot_extra_bits(num_predictors);
197
12.1k
    for (size_t pred = 0; pred < num_predictors; pred++) {
198
1.23M
      for (size_t i = begin; i < end; i++) {
199
1.22M
        counts[pred * max_symbols + tree_samples.Token(pred, i)] +=
200
1.22M
            tree_samples.Count(i);
201
1.22M
        tot_extra_bits[pred] +=
202
1.22M
            tree_samples.NBits(pred, i) * tree_samples.Count(i);
203
1.22M
      }
204
6.07k
    }
205
206
6.07k
    float base_bits;
207
6.07k
    {
208
6.07k
      size_t pred = tree_samples.PredictorIndex((*tree)[pos].predictor);
209
6.07k
      base_bits =
210
6.07k
          EstimateBits(counts.data() + pred * max_symbols, max_symbols) +
211
6.07k
          tot_extra_bits[pred];
212
6.07k
    }
213
214
6.07k
    SplitInfo *best = &best_split_nonstatic;
215
216
6.07k
    SplitInfo forced_split;
217
    // The multiplier ranges cut halfway through the current ranges of static
218
    // properties. We do this even if the current node is not a leaf, to
219
    // minimize the number of nodes in the resulting tree.
220
6.07k
    for (const auto &mmi : mul_info) {
221
5.89k
      uint32_t axis;
222
5.89k
      uint32_t val;
223
5.89k
      IntersectionType t =
224
5.89k
          BoxIntersects(static_prop_range, mmi.range, axis, val);
225
5.89k
      if (t == IntersectionType::kNone) continue;
226
5.89k
      if (t == IntersectionType::kInside) {
227
5.89k
        (*tree)[pos].multiplier = mmi.multiplier;
228
5.89k
        break;
229
5.89k
      }
230
0
      if (t == IntersectionType::kPartial) {
231
0
        forced_split.val = tree_samples.QuantizeProperty(axis, val);
232
0
        forced_split.prop = axis;
233
0
        forced_split.lcost = forced_split.rcost = base_bits / 2 - threshold;
234
0
        forced_split.lpred = forced_split.rpred = (*tree)[pos].predictor;
235
0
        best = &forced_split;
236
0
        best->pos = begin;
237
0
        JXL_DASSERT(best->prop == tree_samples.PropertyFromIndex(best->prop));
238
0
        for (size_t x = begin; x < end; x++) {
239
0
          if (tree_samples.Property(best->prop, x) <= best->val) {
240
0
            best->pos++;
241
0
          }
242
0
        }
243
0
        break;
244
0
      }
245
0
    }
246
247
6.07k
    if (best != &forced_split) {
248
6.07k
      std::vector<int> prop_value_used_count;
249
6.07k
      std::vector<int> count_increase;
250
6.07k
      std::vector<size_t> extra_bits_increase;
251
      // For each property, compute which of its values are used, and what
252
      // tokens correspond to those usages. Then, iterate through the values,
253
      // and compute the entropy of each side of the split (of the form `prop >
254
      // threshold`). Finally, find the split that minimizes the cost.
255
6.07k
      struct CostInfo {
256
6.07k
        float cost = std::numeric_limits<float>::max();
257
6.07k
        float extra_cost = 0;
258
6.07k
        float Cost() const { return cost + extra_cost; }
259
6.07k
        Predictor pred;  // will be uninitialized in some cases, but never used.
260
6.07k
      };
261
6.07k
      std::vector<CostInfo> costs_l;
262
6.07k
      std::vector<CostInfo> costs_r;
263
264
6.07k
      std::vector<int32_t> counts_above(max_symbols);
265
6.07k
      std::vector<int32_t> counts_below(max_symbols);
266
267
      // The lower the threshold, the higher the expected noisiness of the
268
      // estimate. Thus, discourage changing predictors.
269
6.07k
      float change_pred_penalty = 800.0f / (100.0f + threshold);
270
48.1k
      for (size_t prop = 0; prop < num_properties && base_bits > threshold;
271
42.1k
           prop++) {
272
42.1k
        costs_l.clear();
273
42.1k
        costs_r.clear();
274
42.1k
        size_t prop_size = tree_samples.NumPropertyValues(prop);
275
42.1k
        if (extra_bits_increase.size() < prop_size) {
276
12.2k
          count_increase.resize(prop_size * max_symbols);
277
12.2k
          extra_bits_increase.resize(prop_size);
278
12.2k
        }
279
        // Clear prop_value_used_count (which cannot be cleared "on the go")
280
42.1k
        prop_value_used_count.clear();
281
42.1k
        prop_value_used_count.resize(prop_size);
282
283
42.1k
        size_t first_used = prop_size;
284
42.1k
        size_t last_used = 0;
285
286
        // TODO(veluca): consider finding multiple splits along a single
287
        // property at the same time, possibly with a bottom-up approach.
288
8.62M
        for (size_t i = begin; i < end; i++) {
289
8.57M
          size_t p = tree_samples.Property(prop, i);
290
8.57M
          prop_value_used_count[p]++;
291
8.57M
          last_used = std::max(last_used, p);
292
8.57M
          first_used = std::min(first_used, p);
293
8.57M
        }
294
42.1k
        costs_l.resize(last_used - first_used);
295
42.1k
        costs_r.resize(last_used - first_used);
296
        // For all predictors, compute the right and left costs of each split.
297
84.2k
        for (size_t pred = 0; pred < num_predictors; pred++) {
298
          // Compute cost and histogram increments for each property value.
299
8.62M
          for (size_t i = begin; i < end; i++) {
300
8.57M
            size_t p = tree_samples.Property(prop, i);
301
8.57M
            size_t cnt = tree_samples.Count(i);
302
8.57M
            size_t sym = tree_samples.Token(pred, i);
303
8.57M
            count_increase[p * max_symbols + sym] += cnt;
304
8.57M
            extra_bits_increase[p] += tree_samples.NBits(pred, i) * cnt;
305
8.57M
          }
306
42.1k
          memcpy(counts_above.data(), counts.data() + pred * max_symbols,
307
42.1k
                 max_symbols * sizeof counts_above[0]);
308
42.1k
          memset(counts_below.data(), 0, max_symbols * sizeof counts_below[0]);
309
42.1k
          size_t extra_bits_below = 0;
310
          // Exclude last used: this ensures neither counts_above nor
311
          // counts_below is empty.
312
351k
          for (size_t i = first_used; i < last_used; i++) {
313
309k
            if (!prop_value_used_count[i]) continue;
314
182k
            extra_bits_below += extra_bits_increase[i];
315
            // The increase for this property value has been used, and will not
316
            // be used again: clear it. Also below.
317
182k
            extra_bits_increase[i] = 0;
318
9.63M
            for (size_t sym = 0; sym < max_symbols; sym++) {
319
9.44M
              counts_above[sym] -= count_increase[i * max_symbols + sym];
320
9.44M
              counts_below[sym] += count_increase[i * max_symbols + sym];
321
9.44M
              count_increase[i * max_symbols + sym] = 0;
322
9.44M
            }
323
182k
            float rcost = EstimateBits(counts_above.data(), max_symbols) +
324
182k
                          tot_extra_bits[pred] - extra_bits_below;
325
182k
            float lcost = EstimateBits(counts_below.data(), max_symbols) +
326
182k
                          extra_bits_below;
327
182k
            JXL_DASSERT(extra_bits_below <= tot_extra_bits[pred]);
328
182k
            float penalty = 0;
329
            // Never discourage moving away from the Weighted predictor.
330
182k
            if (tree_samples.PredictorFromIndex(pred) !=
331
182k
                    (*tree)[pos].predictor &&
332
182k
                (*tree)[pos].predictor != Predictor::Weighted) {
333
0
              penalty = change_pred_penalty;
334
0
            }
335
            // If everything else is equal, disfavour Weighted (slower) and
336
            // favour Zero (faster if it's the only predictor used in a
337
            // group+channel combination)
338
182k
            if (tree_samples.PredictorFromIndex(pred) == Predictor::Weighted) {
339
0
              penalty += 1e-8;
340
0
            }
341
182k
            if (tree_samples.PredictorFromIndex(pred) == Predictor::Zero) {
342
0
              penalty -= 1e-8;
343
0
            }
344
182k
            if (rcost + penalty < costs_r[i - first_used].Cost()) {
345
182k
              costs_r[i - first_used].cost = rcost;
346
182k
              costs_r[i - first_used].extra_cost = penalty;
347
182k
              costs_r[i - first_used].pred =
348
182k
                  tree_samples.PredictorFromIndex(pred);
349
182k
            }
350
182k
            if (lcost + penalty < costs_l[i - first_used].Cost()) {
351
182k
              costs_l[i - first_used].cost = lcost;
352
182k
              costs_l[i - first_used].extra_cost = penalty;
353
182k
              costs_l[i - first_used].pred =
354
182k
                  tree_samples.PredictorFromIndex(pred);
355
182k
            }
356
182k
          }
357
42.1k
        }
358
        // Iterate through the possible splits and find the one with minimum sum
359
        // of costs of the two sides.
360
42.1k
        size_t split = begin;
361
351k
        for (size_t i = first_used; i < last_used; i++) {
362
309k
          if (!prop_value_used_count[i]) continue;
363
182k
          split += prop_value_used_count[i];
364
182k
          float rcost = costs_r[i - first_used].cost;
365
182k
          float lcost = costs_l[i - first_used].cost;
366
          // WP was not used + we would use the WP property or predictor
367
182k
          bool adds_wp =
368
182k
              (tree_samples.PropertyFromIndex(prop) == kWPProp &&
369
182k
               (used_properties & (1LU << prop)) == 0) ||
370
182k
              ((costs_l[i - first_used].pred == Predictor::Weighted ||
371
128k
                costs_r[i - first_used].pred == Predictor::Weighted) &&
372
128k
               (*tree)[pos].predictor != Predictor::Weighted);
373
182k
          bool zero_entropy_side = rcost == 0 || lcost == 0;
374
375
182k
          SplitInfo &best_ref =
376
182k
              tree_samples.PropertyFromIndex(prop) < kNumStaticProperties
377
182k
                  ? (zero_entropy_side ? best_split_static_constant
378
1.27k
                                       : best_split_static)
379
182k
                  : (adds_wp ? best_split_nonstatic : best_split_nowp);
380
182k
          if (lcost + rcost < best_ref.Cost()) {
381
39.1k
            best_ref.prop = prop;
382
39.1k
            best_ref.val = i;
383
39.1k
            best_ref.pos = split;
384
39.1k
            best_ref.lcost = lcost;
385
39.1k
            best_ref.lpred = costs_l[i - first_used].pred;
386
39.1k
            best_ref.rcost = rcost;
387
39.1k
            best_ref.rpred = costs_r[i - first_used].pred;
388
39.1k
          }
389
182k
        }
390
        // Clear extra_bits_increase and cost_increase for last_used.
391
42.1k
        extra_bits_increase[last_used] = 0;
392
2.39M
        for (size_t sym = 0; sym < max_symbols; sym++) {
393
2.35M
          count_increase[last_used * max_symbols + sym] = 0;
394
2.35M
        }
395
42.1k
      }
396
397
      // Try to avoid introducing WP.
398
6.07k
      if (best_split_nowp.Cost() + threshold < base_bits &&
399
6.07k
          best_split_nowp.Cost() <= fast_decode_multiplier * best->Cost()) {
400
2.76k
        best = &best_split_nowp;
401
2.76k
      }
402
      // Split along static props if possible and not significantly more
403
      // expensive.
404
6.07k
      if (best_split_static.Cost() + threshold < base_bits &&
405
6.07k
          best_split_static.Cost() <= fast_decode_multiplier * best->Cost()) {
406
402
        best = &best_split_static;
407
402
      }
408
      // Split along static props to create constant nodes if possible.
409
6.07k
      if (best_split_static_constant.Cost() + threshold < base_bits) {
410
2
        best = &best_split_static_constant;
411
2
      }
412
6.07k
    }
413
414
6.07k
    if (best->Cost() + threshold < base_bits) {
415
2.98k
      uint32_t p = tree_samples.PropertyFromIndex(best->prop);
416
2.98k
      pixel_type dequant =
417
2.98k
          tree_samples.UnquantizeProperty(best->prop, best->val);
418
      // Split node and try to split children.
419
2.98k
      MakeSplitNode(pos, p, dequant, best->lpred, 0, best->rpred, 0, tree);
420
      // "Sort" according to winning property
421
2.98k
      SplitTreeSamples(tree_samples, begin, best->pos, end, best->prop,
422
2.98k
                       best->val);
423
2.98k
      if (p >= kNumStaticProperties) {
424
2.57k
        used_properties |= 1 << best->prop;
425
2.57k
      }
426
2.98k
      auto new_sp_range = static_prop_range;
427
2.98k
      if (p < kNumStaticProperties) {
428
404
        JXL_DASSERT(static_cast<uint32_t>(dequant + 1) <= new_sp_range[p][1]);
429
404
        new_sp_range[p][1] = dequant + 1;
430
404
        JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]);
431
404
      }
432
2.98k
      nodes.push_back(NodeInfo{(*tree)[pos].rchild, begin, best->pos,
433
2.98k
                               used_properties, new_sp_range});
434
2.98k
      new_sp_range = static_prop_range;
435
2.98k
      if (p < kNumStaticProperties) {
436
404
        JXL_DASSERT(new_sp_range[p][0] <= static_cast<uint32_t>(dequant + 1));
437
404
        new_sp_range[p][0] = dequant + 1;
438
404
        JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]);
439
404
      }
440
2.98k
      nodes.push_back(NodeInfo{(*tree)[pos].lchild, best->pos, end,
441
2.98k
                               used_properties, new_sp_range});
442
2.98k
    }
443
6.07k
  }
444
111
}
Unexecuted instantiation: jxl::N_SSE2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
445
446
// NOLINTNEXTLINE(google-readability-namespace-comments)
447
}  // namespace HWY_NAMESPACE
448
}  // namespace jxl
449
HWY_AFTER_NAMESPACE();
450
451
#if HWY_ONCE
452
namespace jxl {
453
454
HWY_EXPORT(FindBestSplit);  // Local function.
455
456
Status ComputeBestTree(TreeSamples &tree_samples, float threshold,
457
                       const std::vector<ModularMultiplierInfo> &mul_info,
458
                       StaticPropRange static_prop_range,
459
111
                       float fast_decode_multiplier, Tree *tree) {
460
  // TODO(veluca): take into account that different contexts can have different
461
  // uint configs.
462
  //
463
  // Initialize tree.
464
111
  tree->emplace_back();
465
111
  tree->back().property = -1;
466
111
  tree->back().predictor = tree_samples.PredictorFromIndex(0);
467
111
  tree->back().predictor_offset = 0;
468
111
  tree->back().multiplier = 1;
469
111
  JXL_ENSURE(tree_samples.NumProperties() < 64);
470
471
111
  JXL_ENSURE(tree_samples.NumDistinctSamples() <=
472
111
             std::numeric_limits<uint32_t>::max());
473
111
  HWY_DYNAMIC_DISPATCH(FindBestSplit)
474
111
  (tree_samples, threshold, mul_info, static_prop_range, fast_decode_multiplier,
475
111
   tree);
476
111
  return true;
477
111
}
478
479
#if JXL_CXX_LANG < JXL_CXX_17
// Before C++17, an odr-used static constexpr data member still needs exactly
// one out-of-class definition. From C++17 on such members are implicitly
// inline, so these definitions are only emitted for older language modes.
constexpr uint32_t TreeSamples::kDedupEntryUnused;
constexpr int32_t TreeSamples::kPropertyRange;
#endif
483
484
// Selects the candidate predictors whose residuals will be sampled and sizes
// `residuals` to hold one residual column per candidate.
//
// `predictor` may be a concrete predictor, Predictor::Best (Weighted and
// Gradient), or Predictor::Variable (all kNumModularPredictors predictors,
// with Weighted moved to index 0 and Gradient to index 1).
//
// `wp_tree_mode` can override this:
//  - kWPOnly forces {Weighted}.
//  - kNoWP rejects an explicit Weighted request and removes Weighted from the
//    Best/Variable sets.
Status TreeSamples::SetPredictor(Predictor predictor,
                                 ModularOptions::TreeMode wp_tree_mode) {
  if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) {
    predictors = {Predictor::Weighted};
    residuals.resize(1);
    return true;
  }
  if (wp_tree_mode == ModularOptions::TreeMode::kNoWP &&
      predictor == Predictor::Weighted) {
    return JXL_FAILURE("Invalid predictor settings");
  }
  if (predictor == Predictor::Variable) {
    // Start from an empty list, as the other branches do by assignment. The
    // swaps below rely on predictors[i] == static_cast<Predictor>(i), which
    // would not hold if `predictors` had been populated by an earlier call.
    predictors.clear();
    for (size_t i = 0; i < kNumModularPredictors; i++) {
      predictors.push_back(static_cast<Predictor>(i));
    }
    // Move Weighted to index 0 and Gradient to index 1.
    std::swap(predictors[0], predictors[static_cast<int>(Predictor::Weighted)]);
    std::swap(predictors[1], predictors[static_cast<int>(Predictor::Gradient)]);
  } else if (predictor == Predictor::Best) {
    predictors = {Predictor::Weighted, Predictor::Gradient};
  } else {
    predictors = {predictor};
  }
  if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) {
    auto wp_it =
        std::find(predictors.begin(), predictors.end(), Predictor::Weighted);
    if (wp_it != predictors.end()) {
      predictors.erase(wp_it);
    }
  }
  residuals.resize(predictors.size());
  return true;
}
516
517
// Chooses the properties the tree is allowed to split on and sizes `props` to
// hold one quantized-value column per chosen property.
//
// `properties` is taken as-is unless `wp_tree_mode` overrides it:
//  - kWPOnly keeps only kWPProp.
//  - kGradientOnly keeps only kGradientProp.
//  - kNoWP removes the first occurrence of kWPProp.
// Fails if no property is left.
Status TreeSamples::SetProperties(const std::vector<uint32_t> &properties,
                                  ModularOptions::TreeMode wp_tree_mode) {
  using TreeMode = ModularOptions::TreeMode;
  if (wp_tree_mode == TreeMode::kWPOnly) {
    props_to_use = {static_cast<uint32_t>(kWPProp)};
  } else if (wp_tree_mode == TreeMode::kGradientOnly) {
    props_to_use = {static_cast<uint32_t>(kGradientProp)};
  } else {
    props_to_use = properties;
    if (wp_tree_mode == TreeMode::kNoWP) {
      const auto wp_position =
          std::find(props_to_use.begin(), props_to_use.end(), kWPProp);
      if (wp_position != props_to_use.end()) props_to_use.erase(wp_position);
    }
  }

  if (props_to_use.empty()) {
    return JXL_FAILURE("Invalid property set configuration");
  }
  props.resize(props_to_use.size());
  return true;
}
538
539
306
void TreeSamples::InitTable(size_t log_size) {
540
306
  size_t size = 1ULL << log_size;
541
306
  if (dedup_table_.size() == size) return;
542
167
  dedup_table_.resize(size, kDedupEntryUnused);
543
47.1k
  for (size_t i = 0; i < NumDistinctSamples(); i++) {
544
46.9k
    if (sample_counts[i] != std::numeric_limits<uint16_t>::max()) {
545
46.9k
      AddToTable(i);
546
46.9k
    }
547
46.9k
  }
548
167
}
549
550
806k
bool TreeSamples::AddToTableAndMerge(size_t a) {
551
806k
  size_t pos1 = Hash1(a);
552
806k
  size_t pos2 = Hash2(a);
553
806k
  if (dedup_table_[pos1] != kDedupEntryUnused &&
554
806k
      IsSameSample(a, dedup_table_[pos1])) {
555
472k
    JXL_DASSERT(sample_counts[a] == 1);
556
472k
    sample_counts[dedup_table_[pos1]]++;
557
    // Remove from hash table samples that are saturated.
558
472k
    if (sample_counts[dedup_table_[pos1]] ==
559
472k
        std::numeric_limits<uint16_t>::max()) {
560
0
      dedup_table_[pos1] = kDedupEntryUnused;
561
0
    }
562
472k
    return true;
563
472k
  }
564
333k
  if (dedup_table_[pos2] != kDedupEntryUnused &&
565
333k
      IsSameSample(a, dedup_table_[pos2])) {
566
148k
    JXL_DASSERT(sample_counts[a] == 1);
567
148k
    sample_counts[dedup_table_[pos2]]++;
568
    // Remove from hash table samples that are saturated.
569
148k
    if (sample_counts[dedup_table_[pos2]] ==
570
148k
        std::numeric_limits<uint16_t>::max()) {
571
0
      dedup_table_[pos2] = kDedupEntryUnused;
572
0
    }
573
148k
    return true;
574
148k
  }
575
185k
  AddToTable(a);
576
185k
  return false;
577
333k
}
578
579
232k
void TreeSamples::AddToTable(size_t a) {
580
232k
  size_t pos1 = Hash1(a);
581
232k
  size_t pos2 = Hash2(a);
582
232k
  if (dedup_table_[pos1] == kDedupEntryUnused) {
583
121k
    dedup_table_[pos1] = a;
584
121k
  } else if (dedup_table_[pos2] == kDedupEntryUnused) {
585
68.9k
    dedup_table_[pos2] = a;
586
68.9k
  }
587
232k
}
588
589
306
void TreeSamples::PrepareForSamples(size_t extra_num_samples) {
590
306
  for (auto &res : residuals) {
591
306
    res.reserve(res.size() + extra_num_samples);
592
306
  }
593
2.14k
  for (auto &p : props) {
594
2.14k
    p.reserve(p.size() + extra_num_samples);
595
2.14k
  }
596
306
  size_t total_num_samples = extra_num_samples + sample_counts.size();
597
306
  size_t next_size = CeilLog2Nonzero(total_num_samples * 3 / 2);
598
306
  InitTable(next_size);
599
306
}
600
601
1.03M
size_t TreeSamples::Hash1(size_t a) const {
602
1.03M
  constexpr uint64_t constant = 0x1e35a7bd;
603
1.03M
  uint64_t h = constant;
604
1.03M
  for (const auto &r : residuals) {
605
1.03M
    h = h * constant + r[a].tok;
606
1.03M
    h = h * constant + r[a].nbits;
607
1.03M
  }
608
7.26M
  for (const auto &p : props) {
609
7.26M
    h = h * constant + p[a];
610
7.26M
  }
611
1.03M
  return (h >> 16) & (dedup_table_.size() - 1);
612
1.03M
}
613
1.03M
size_t TreeSamples::Hash2(size_t a) const {
614
1.03M
  constexpr uint64_t constant = 0x1e35a7bd1e35a7bd;
615
1.03M
  uint64_t h = constant;
616
7.26M
  for (const auto &p : props) {
617
7.26M
    h = h * constant ^ p[a];
618
7.26M
  }
619
1.03M
  for (const auto &r : residuals) {
620
1.03M
    h = h * constant ^ r[a].tok;
621
1.03M
    h = h * constant ^ r[a].nbits;
622
1.03M
  }
623
1.03M
  return (h >> 16) & (dedup_table_.size() - 1);
624
1.03M
}
625
626
895k
bool TreeSamples::IsSameSample(size_t a, size_t b) const {
627
895k
  bool ret = true;
628
895k
  for (const auto &r : residuals) {
629
895k
    if (r[a].tok != r[b].tok) {
630
102k
      ret = false;
631
102k
    }
632
895k
    if (r[a].nbits != r[b].nbits) {
633
81.4k
      ret = false;
634
81.4k
    }
635
895k
  }
636
6.26M
  for (const auto &p : props) {
637
6.26M
    if (p[a] != p[b]) {
638
738k
      ret = false;
639
738k
    }
640
6.26M
  }
641
895k
  return ret;
642
895k
}
643
644
// Records one pixel as a tree-learning sample.
//
// For each candidate predictor, the residual `pixel - prediction` is signed-
// packed and tokenized with HybridUintConfig(4, 1, 2); the token and its
// extra-bit count are stored. Each used property is stored after
// QuantizeProperty.
//
// If the new row is identical to an already indexed sample, that sample's
// count is bumped and the new row is removed again. `num_samples` is
// incremented on every call, merged or not.
void TreeSamples::AddSample(pixel_type_w pixel, const Properties &properties,
                            const pixel_type_w *predictions) {
  HybridUintConfig residual_config(4, 1, 2);
  const size_t num_predictors = predictors.size();
  for (size_t pred_idx = 0; pred_idx < num_predictors; ++pred_idx) {
    const int prediction_slot = static_cast<int>(predictors[pred_idx]);
    const pixel_type residual = pixel - predictions[prediction_slot];
    uint32_t token = 0;
    uint32_t num_extra_bits = 0;
    uint32_t extra_bits = 0;
    residual_config.Encode(PackSigned(residual), &token, &num_extra_bits,
                           &extra_bits);
    JXL_DASSERT(token < 256);
    JXL_DASSERT(num_extra_bits < 256);
    residuals[pred_idx].emplace_back(ResidualToken{
        static_cast<uint8_t>(token), static_cast<uint8_t>(num_extra_bits)});
  }
  for (size_t prop_idx = 0; prop_idx < props_to_use.size(); ++prop_idx) {
    props[prop_idx].push_back(
        QuantizeProperty(prop_idx, properties[props_to_use[prop_idx]]));
  }
  sample_counts.push_back(1);
  num_samples++;

  const size_t new_index = sample_counts.size() - 1;
  if (!AddToTableAndMerge(new_index)) return;

  // Merged into an existing identical sample: drop the row just appended.
  for (auto &column : residuals) column.pop_back();
  for (auto &column : props) column.pop_back();
  sample_counts.pop_back();
}
666
667
195k
void TreeSamples::Swap(size_t a, size_t b) {
668
195k
  if (a == b) return;
669
195k
  for (auto &r : residuals) {
670
195k
    std::swap(r[a], r[b]);
671
195k
  }
672
1.36M
  for (auto &p : props) {
673
1.36M
    std::swap(p[a], p[b]);
674
1.36M
  }
675
195k
  std::swap(sample_counts[a], sample_counts[b]);
676
195k
}
677
678
namespace {
679
std::vector<int32_t> QuantizeHistogram(const std::vector<uint32_t> &histogram,
680
321
                                       size_t num_chunks) {
681
321
  if (histogram.empty() || num_chunks == 0) return {};
682
321
  uint64_t sum = std::accumulate(histogram.begin(), histogram.end(), 0LU);
683
321
  if (sum == 0) return {};
684
  // TODO(veluca): selecting distinct quantiles is likely not the best
685
  // way to go about this.
686
319
  std::vector<int32_t> thresholds;
687
319
  uint64_t cumsum = 0;
688
319
  uint64_t threshold = 1;
689
116k
  for (size_t i = 0; i < histogram.size(); i++) {
690
116k
    cumsum += histogram[i];
691
116k
    if (cumsum * num_chunks >= threshold * sum) {
692
1.72k
      thresholds.push_back(i);
693
17.0k
      while (cumsum * num_chunks >= threshold * sum) threshold++;
694
1.72k
    }
695
116k
  }
696
319
  JXL_DASSERT(thresholds.size() <= num_chunks);
697
  // last value collects all histogram and is not really a threshold
698
319
  thresholds.pop_back();
699
319
  return thresholds;
700
321
}
701
702
std::vector<int32_t> QuantizeSamples(const std::vector<int32_t> &samples,
703
123
                                     size_t num_chunks) {
704
123
  if (samples.empty()) return {};
705
113
  int min = *std::min_element(samples.begin(), samples.end());
706
113
  constexpr int kRange = 512;
707
113
  min = jxl::Clamp1(min, -kRange, kRange);
708
113
  std::vector<uint32_t> counts(2 * kRange + 1);
709
76.2k
  for (int s : samples) {
710
76.2k
    uint32_t sample_offset = jxl::Clamp1(s, -kRange, kRange) - min;
711
76.2k
    counts[sample_offset]++;
712
76.2k
  }
713
113
  std::vector<int32_t> thresholds = QuantizeHistogram(counts, num_chunks);
714
1.21k
  for (auto &v : thresholds) v += min;
715
113
  return thresholds;
716
123
}
717
}  // namespace
718
719
void TreeSamples::PreQuantizeProperties(
720
    const StaticPropRange &range,
721
    const std::vector<ModularMultiplierInfo> &multiplier_info,
722
    const std::vector<uint32_t> &group_pixel_count,
723
    const std::vector<uint32_t> &channel_pixel_count,
724
    std::vector<pixel_type> &pixel_samples,
725
111
    std::vector<pixel_type> &diff_samples, size_t max_property_values) {
726
  // If we have forced splits because of multipliers, choose channel and group
727
  // thresholds accordingly.
728
111
  std::vector<int32_t> group_multiplier_thresholds;
729
111
  std::vector<int32_t> channel_multiplier_thresholds;
730
111
  for (const auto &v : multiplier_info) {
731
97
    if (v.range[0][0] != range[0][0]) {
732
0
      channel_multiplier_thresholds.push_back(v.range[0][0] - 1);
733
0
    }
734
97
    if (v.range[0][1] != range[0][1]) {
735
0
      channel_multiplier_thresholds.push_back(v.range[0][1] - 1);
736
0
    }
737
97
    if (v.range[1][0] != range[1][0]) {
738
0
      group_multiplier_thresholds.push_back(v.range[1][0] - 1);
739
0
    }
740
97
    if (v.range[1][1] != range[1][1]) {
741
0
      group_multiplier_thresholds.push_back(v.range[1][1] - 1);
742
0
    }
743
97
  }
744
111
  std::sort(channel_multiplier_thresholds.begin(),
745
111
            channel_multiplier_thresholds.end());
746
111
  channel_multiplier_thresholds.resize(
747
111
      std::unique(channel_multiplier_thresholds.begin(),
748
111
                  channel_multiplier_thresholds.end()) -
749
111
      channel_multiplier_thresholds.begin());
750
111
  std::sort(group_multiplier_thresholds.begin(),
751
111
            group_multiplier_thresholds.end());
752
111
  group_multiplier_thresholds.resize(
753
111
      std::unique(group_multiplier_thresholds.begin(),
754
111
                  group_multiplier_thresholds.end()) -
755
111
      group_multiplier_thresholds.begin());
756
757
111
  compact_properties.resize(props_to_use.size());
758
111
  auto quantize_channel = [&]() {
759
111
    if (!channel_multiplier_thresholds.empty()) {
760
0
      return channel_multiplier_thresholds;
761
0
    }
762
111
    return QuantizeHistogram(channel_pixel_count, max_property_values);
763
111
  };
764
111
  auto quantize_group_id = [&]() {
765
97
    if (!group_multiplier_thresholds.empty()) {
766
0
      return group_multiplier_thresholds;
767
0
    }
768
97
    return QuantizeHistogram(group_pixel_count, max_property_values);
769
97
  };
770
111
  auto quantize_coordinate = [&]() {
771
0
    std::vector<int32_t> quantized;
772
0
    quantized.reserve(max_property_values - 1);
773
0
    for (size_t i = 0; i + 1 < max_property_values; i++) {
774
0
      quantized.push_back((i + 1) * 256 / max_property_values - 1);
775
0
    }
776
0
    return quantized;
777
0
  };
778
111
  std::vector<int32_t> abs_pixel_thresholds;
779
111
  std::vector<int32_t> pixel_thresholds;
780
111
  auto quantize_pixel_property = [&]() {
781
0
    if (pixel_thresholds.empty()) {
782
0
      pixel_thresholds = QuantizeSamples(pixel_samples, max_property_values);
783
0
    }
784
0
    return pixel_thresholds;
785
0
  };
786
111
  auto quantize_abs_pixel_property = [&]() {
787
0
    if (abs_pixel_thresholds.empty()) {
788
0
      quantize_pixel_property();  // Compute the non-abs thresholds.
789
0
      for (auto &v : pixel_samples) v = std::abs(v);
790
0
      abs_pixel_thresholds =
791
0
          QuantizeSamples(pixel_samples, max_property_values);
792
0
    }
793
0
    return abs_pixel_thresholds;
794
0
  };
795
111
  std::vector<int32_t> abs_diff_thresholds;
796
111
  std::vector<int32_t> diff_thresholds;
797
458
  auto quantize_diff_property = [&]() {
798
458
    if (diff_thresholds.empty()) {
799
123
      diff_thresholds = QuantizeSamples(diff_samples, max_property_values);
800
123
    }
801
458
    return diff_thresholds;
802
458
  };
803
111
  auto quantize_abs_diff_property = [&]() {
804
0
    if (abs_diff_thresholds.empty()) {
805
0
      quantize_diff_property();  // Compute the non-abs thresholds.
806
0
      for (auto &v : diff_samples) v = std::abs(v);
807
0
      abs_diff_thresholds = QuantizeSamples(diff_samples, max_property_values);
808
0
    }
809
0
    return abs_diff_thresholds;
810
0
  };
811
111
  auto quantize_wp = [&]() {
812
111
    if (max_property_values < 32) {
813
0
      return std::vector<int32_t>{-127, -63, -31, -15, -7, -3, -1, 0,
814
0
                                  1,    3,   7,   15,  31, 63, 127};
815
0
    }
816
111
    if (max_property_values < 64) {
817
111
      return std::vector<int32_t>{-255, -191, -127, -95, -63, -47, -31, -23,
818
111
                                  -15,  -11,  -7,   -5,  -3,  -1,  0,   1,
819
111
                                  3,    5,    7,    11,  15,  23,  31,  47,
820
111
                                  63,   95,   127,  191, 255};
821
111
    }
822
0
    return std::vector<int32_t>{
823
0
        -255, -223, -191, -159, -127, -111, -95, -79, -63, -55, -47,
824
0
        -39,  -31,  -27,  -23,  -19,  -15,  -13, -11, -9,  -7,  -6,
825
0
        -5,   -4,   -3,   -2,   -1,   0,    1,   2,   3,   4,   5,
826
0
        6,    7,    9,    11,   13,   15,   19,  23,  27,  31,  39,
827
0
        47,   55,   63,   79,   95,   111,  127, 159, 191, 223, 255};
828
111
  };
829
830
111
  property_mapping.resize(props_to_use.size());
831
888
  for (size_t i = 0; i < props_to_use.size(); i++) {
832
777
    if (props_to_use[i] == 0) {
833
111
      compact_properties[i] = quantize_channel();
834
666
    } else if (props_to_use[i] == 1) {
835
97
      compact_properties[i] = quantize_group_id();
836
569
    } else if (props_to_use[i] == 2 || props_to_use[i] == 3) {
837
0
      compact_properties[i] = quantize_coordinate();
838
569
    } else if (props_to_use[i] == 6 || props_to_use[i] == 7 ||
839
569
               props_to_use[i] == 8 ||
840
569
               (props_to_use[i] >= kNumNonrefProperties &&
841
569
                (props_to_use[i] - kNumNonrefProperties) % 4 == 1)) {
842
0
      compact_properties[i] = quantize_pixel_property();
843
569
    } else if (props_to_use[i] == 4 || props_to_use[i] == 5 ||
844
569
               (props_to_use[i] >= kNumNonrefProperties &&
845
569
                (props_to_use[i] - kNumNonrefProperties) % 4 == 0)) {
846
0
      compact_properties[i] = quantize_abs_pixel_property();
847
569
    } else if (props_to_use[i] >= kNumNonrefProperties &&
848
569
               (props_to_use[i] - kNumNonrefProperties) % 4 == 2) {
849
0
      compact_properties[i] = quantize_abs_diff_property();
850
569
    } else if (props_to_use[i] == kWPProp) {
851
111
      compact_properties[i] = quantize_wp();
852
458
    } else {
853
458
      compact_properties[i] = quantize_diff_property();
854
458
    }
855
777
    property_mapping[i].resize(kPropertyRange * 2 + 1);
856
777
    size_t mapped = 0;
857
795k
    for (size_t j = 0; j < property_mapping[i].size(); j++) {
858
803k
      while (mapped < compact_properties[i].size() &&
859
803k
             static_cast<int>(j) - kPropertyRange >
860
431k
                 compact_properties[i][mapped]) {
861
8.29k
        mapped++;
862
8.29k
      }
863
794k
      JXL_DASSERT(mapped < 256);
864
      // property_mapping[i] of a value V is `mapped` if
865
      // compact_properties[i][mapped] <= j and
866
      // compact_properties[i][mapped-1] > j
867
      // This is because the decision node in the tree splits on (property) > j,
868
      // hence everything that is not > of a threshold should be clustered
869
      // together.
870
794k
      property_mapping[i][j] = mapped;
871
794k
    }
872
777
  }
873
111
}
874
875
void CollectPixelSamples(const Image &image, const ModularOptions &options,
876
                         uint32_t group_id,
877
                         std::vector<uint32_t> &group_pixel_count,
878
                         std::vector<uint32_t> &channel_pixel_count,
879
                         std::vector<pixel_type> &pixel_samples,
880
111
                         std::vector<pixel_type> &diff_samples) {
881
111
  if (options.nb_repeats == 0) return;
882
111
  if (group_pixel_count.size() <= group_id) {
883
111
    group_pixel_count.resize(group_id + 1);
884
111
  }
885
111
  if (channel_pixel_count.size() < image.channel.size()) {
886
111
    channel_pixel_count.resize(image.channel.size());
887
111
  }
888
111
  Rng rng(group_id);
889
  // Sample 10% of the final number of samples for property quantization.
890
111
  float fraction = std::min(options.nb_repeats * 0.1, 0.99);
891
111
  Rng::GeometricDistribution dist = Rng::MakeGeometric(fraction);
892
111
  size_t total_pixels = 0;
893
111
  std::vector<size_t> channel_ids;
894
417
  for (size_t i = 0; i < image.channel.size(); i++) {
895
306
    if (i >= image.nb_meta_channels &&
896
306
        (image.channel[i].w > options.max_chan_size ||
897
305
         image.channel[i].h > options.max_chan_size)) {
898
0
      break;
899
0
    }
900
306
    if (image.channel[i].w <= 1 || image.channel[i].h == 0) {
901
2
      continue;  // skip empty or width-1 channels.
902
2
    }
903
304
    channel_ids.push_back(i);
904
304
    group_pixel_count[group_id] += image.channel[i].w * image.channel[i].h;
905
304
    channel_pixel_count[i] += image.channel[i].w * image.channel[i].h;
906
304
    total_pixels += image.channel[i].w * image.channel[i].h;
907
304
  }
908
111
  if (channel_ids.empty()) return;
909
109
  pixel_samples.reserve(pixel_samples.size() + fraction * total_pixels);
910
109
  diff_samples.reserve(diff_samples.size() + fraction * total_pixels);
911
109
  size_t i = 0;
912
109
  size_t y = 0;
913
109
  size_t x = 0;
914
76.1k
  auto advance = [&](size_t amount) {
915
76.1k
    x += amount;
916
    // Detect row overflow (rare).
917
94.8k
    while (x >= image.channel[channel_ids[i]].w) {
918
18.7k
      x -= image.channel[channel_ids[i]].w;
919
18.7k
      y++;
920
      // Detect end-of-channel (even rarer).
921
18.7k
      if (y == image.channel[channel_ids[i]].h) {
922
304
        i++;
923
304
        y = 0;
924
304
        if (i >= channel_ids.size()) {
925
109
          return;
926
109
        }
927
304
      }
928
18.7k
    }
929
76.1k
  };
930
109
  advance(rng.Geometric(dist));
931
76.1k
  for (; i < channel_ids.size(); advance(rng.Geometric(dist) + 1)) {
932
76.0k
    const pixel_type *row = image.channel[channel_ids[i]].Row(y);
933
76.0k
    pixel_samples.push_back(row[x]);
934
76.0k
    size_t xp = x == 0 ? 1 : x - 1;
935
76.0k
    diff_samples.push_back(static_cast<int64_t>(row[x]) - row[xp]);
936
76.0k
  }
937
109
}
938
939
// TODO(veluca): very simple encoding scheme. This should be improved.
940
Status TokenizeTree(const Tree &tree, std::vector<Token> *tokens,
941
283
                    Tree *decoder_tree) {
942
283
  JXL_ENSURE(tree.size() <= kMaxTreeSize);
943
283
  std::queue<int> q;
944
283
  q.push(0);
945
283
  size_t leaf_id = 0;
946
283
  decoder_tree->clear();
947
14.9k
  while (!q.empty()) {
948
14.6k
    int cur = q.front();
949
14.6k
    q.pop();
950
14.6k
    JXL_ENSURE(tree[cur].property >= -1);
951
14.6k
    tokens->emplace_back(kPropertyContext, tree[cur].property + 1);
952
14.6k
    if (tree[cur].property == -1) {
953
7.47k
      tokens->emplace_back(kPredictorContext,
954
7.47k
                           static_cast<int>(tree[cur].predictor));
955
7.47k
      tokens->emplace_back(kOffsetContext,
956
7.47k
                           PackSigned(tree[cur].predictor_offset));
957
7.47k
      uint32_t mul_log = Num0BitsBelowLS1Bit_Nonzero(tree[cur].multiplier);
958
7.47k
      uint32_t mul_bits = (tree[cur].multiplier >> mul_log) - 1;
959
7.47k
      tokens->emplace_back(kMultiplierLogContext, mul_log);
960
7.47k
      tokens->emplace_back(kMultiplierBitsContext, mul_bits);
961
7.47k
      JXL_ENSURE(tree[cur].predictor < Predictor::Best);
962
7.47k
      decoder_tree->emplace_back(-1, 0, leaf_id++, 0, tree[cur].predictor,
963
7.47k
                                 tree[cur].predictor_offset,
964
7.47k
                                 tree[cur].multiplier);
965
7.47k
      continue;
966
7.47k
    }
967
7.19k
    decoder_tree->emplace_back(tree[cur].property, tree[cur].splitval,
968
7.19k
                               decoder_tree->size() + q.size() + 1,
969
7.19k
                               decoder_tree->size() + q.size() + 2,
970
7.19k
                               Predictor::Zero, 0, 1);
971
7.19k
    q.push(tree[cur].lchild);
972
7.19k
    q.push(tree[cur].rchild);
973
7.19k
    tokens->emplace_back(kSplitValContext, PackSigned(tree[cur].splitval));
974
7.19k
  }
975
283
  return true;
976
283
}
977
978
}  // namespace jxl
979
#endif  // HWY_ONCE