Coverage Report

Created: 2025-07-23 08:18

/src/libjxl/lib/jxl/modular/encoding/enc_ma.cc
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#include "lib/jxl/modular/encoding/enc_ma.h"
7
8
#include <algorithm>
9
#include <cstdint>
10
#include <cstdlib>
11
#include <cstring>
12
#include <limits>
13
#include <numeric>
14
#include <queue>
15
#include <vector>
16
17
#include "lib/jxl/ans_params.h"
18
#include "lib/jxl/base/bits.h"
19
#include "lib/jxl/base/common.h"
20
#include "lib/jxl/base/compiler_specific.h"
21
#include "lib/jxl/base/status.h"
22
#include "lib/jxl/dec_ans.h"
23
#include "lib/jxl/modular/encoding/dec_ma.h"
24
#include "lib/jxl/modular/encoding/ma_common.h"
25
#include "lib/jxl/modular/modular_image.h"
26
27
#undef HWY_TARGET_INCLUDE
28
#define HWY_TARGET_INCLUDE "lib/jxl/modular/encoding/enc_ma.cc"
29
#include <hwy/foreach_target.h>
30
#include <hwy/highway.h>
31
32
#include "lib/jxl/base/fast_math-inl.h"
33
#include "lib/jxl/base/random.h"
34
#include "lib/jxl/enc_ans.h"
35
#include "lib/jxl/modular/encoding/context_predict.h"
36
#include "lib/jxl/modular/options.h"
37
#include "lib/jxl/pack_signed.h"
38
HWY_BEFORE_NAMESPACE();
39
namespace jxl {
40
namespace HWY_NAMESPACE {
41
42
// These templates are not found via ADL.
43
using hwy::HWY_NAMESPACE::Eq;
44
using hwy::HWY_NAMESPACE::IfThenElse;
45
using hwy::HWY_NAMESPACE::Lt;
46
using hwy::HWY_NAMESPACE::Max;
47
48
const HWY_FULL(float) df;
49
const HWY_FULL(int32_t) di;
50
121k
size_t Padded(size_t x) { return RoundUpTo(x, Lanes(df)); }
Unexecuted instantiation: jxl::N_SSE4::Padded(unsigned long)
jxl::N_AVX2::Padded(unsigned long)
Line
Count
Source
50
121k
size_t Padded(size_t x) { return RoundUpTo(x, Lanes(df)); }
Unexecuted instantiation: jxl::N_AVX3::Padded(unsigned long)
Unexecuted instantiation: jxl::N_AVX3_ZEN4::Padded(unsigned long)
Unexecuted instantiation: jxl::N_AVX3_SPR::Padded(unsigned long)
Unexecuted instantiation: jxl::N_SSE2::Padded(unsigned long)
51
52
// Compute entropy of the histogram, taking into account the minimum probability
53
// for symbols with non-zero counts.
54
8.95M
float EstimateBits(const int32_t *counts, size_t num_symbols) {
55
8.95M
  int32_t total = std::accumulate(counts, counts + num_symbols, 0);
56
8.95M
  const auto zero = Zero(df);
57
8.95M
  const auto minprob = Set(df, 1.0f / ANS_TAB_SIZE);
58
8.95M
  const auto inv_total = Set(df, 1.0f / total);
59
8.95M
  auto bits_lanes = Zero(df);
60
8.95M
  auto total_v = Set(di, total);
61
50.5M
  for (size_t i = 0; i < num_symbols; i += Lanes(df)) {
62
41.6M
    const auto counts_iv = LoadU(di, &counts[i]);
63
41.6M
    const auto counts_fv = ConvertTo(df, counts_iv);
64
41.6M
    const auto probs = Mul(counts_fv, inv_total);
65
41.6M
    const auto mprobs = Max(probs, minprob);
66
41.6M
    const auto nbps = IfThenElse(Eq(counts_iv, total_v), BitCast(di, zero),
67
41.6M
                                 BitCast(di, FastLog2f(df, mprobs)));
68
41.6M
    bits_lanes = Sub(bits_lanes, Mul(counts_fv, BitCast(df, nbps)));
69
41.6M
  }
70
8.95M
  return GetLane(SumOfLanes(df, bits_lanes));
71
8.95M
}
Unexecuted instantiation: jxl::N_SSE4::EstimateBits(int const*, unsigned long)
jxl::N_AVX2::EstimateBits(int const*, unsigned long)
Line
Count
Source
54
8.95M
float EstimateBits(const int32_t *counts, size_t num_symbols) {
55
8.95M
  int32_t total = std::accumulate(counts, counts + num_symbols, 0);
56
8.95M
  const auto zero = Zero(df);
57
8.95M
  const auto minprob = Set(df, 1.0f / ANS_TAB_SIZE);
58
8.95M
  const auto inv_total = Set(df, 1.0f / total);
59
8.95M
  auto bits_lanes = Zero(df);
60
8.95M
  auto total_v = Set(di, total);
61
50.5M
  for (size_t i = 0; i < num_symbols; i += Lanes(df)) {
62
41.6M
    const auto counts_iv = LoadU(di, &counts[i]);
63
41.6M
    const auto counts_fv = ConvertTo(df, counts_iv);
64
41.6M
    const auto probs = Mul(counts_fv, inv_total);
65
41.6M
    const auto mprobs = Max(probs, minprob);
66
41.6M
    const auto nbps = IfThenElse(Eq(counts_iv, total_v), BitCast(di, zero),
67
41.6M
                                 BitCast(di, FastLog2f(df, mprobs)));
68
41.6M
    bits_lanes = Sub(bits_lanes, Mul(counts_fv, BitCast(df, nbps)));
69
41.6M
  }
70
8.95M
  return GetLane(SumOfLanes(df, bits_lanes));
71
8.95M
}
Unexecuted instantiation: jxl::N_AVX3::EstimateBits(int const*, unsigned long)
Unexecuted instantiation: jxl::N_AVX3_ZEN4::EstimateBits(int const*, unsigned long)
Unexecuted instantiation: jxl::N_AVX3_SPR::EstimateBits(int const*, unsigned long)
Unexecuted instantiation: jxl::N_SSE2::EstimateBits(int const*, unsigned long)
72
73
void MakeSplitNode(size_t pos, int property, int splitval, Predictor lpred,
74
60.3k
                   int64_t loff, Predictor rpred, int64_t roff, Tree *tree) {
75
  // Note that the tree splits on *strictly greater*.
76
60.3k
  (*tree)[pos].lchild = tree->size();
77
60.3k
  (*tree)[pos].rchild = tree->size() + 1;
78
60.3k
  (*tree)[pos].splitval = splitval;
79
60.3k
  (*tree)[pos].property = property;
80
60.3k
  tree->emplace_back();
81
60.3k
  tree->back().property = -1;
82
60.3k
  tree->back().predictor = rpred;
83
60.3k
  tree->back().predictor_offset = roff;
84
60.3k
  tree->back().multiplier = 1;
85
60.3k
  tree->emplace_back();
86
60.3k
  tree->back().property = -1;
87
60.3k
  tree->back().predictor = lpred;
88
60.3k
  tree->back().predictor_offset = loff;
89
60.3k
  tree->back().multiplier = 1;
90
60.3k
}
Unexecuted instantiation: jxl::N_SSE4::MakeSplitNode(unsigned long, int, int, jxl::Predictor, long, jxl::Predictor, long, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
jxl::N_AVX2::MakeSplitNode(unsigned long, int, int, jxl::Predictor, long, jxl::Predictor, long, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
Line
Count
Source
74
60.3k
                   int64_t loff, Predictor rpred, int64_t roff, Tree *tree) {
75
  // Note that the tree splits on *strictly greater*.
76
60.3k
  (*tree)[pos].lchild = tree->size();
77
60.3k
  (*tree)[pos].rchild = tree->size() + 1;
78
60.3k
  (*tree)[pos].splitval = splitval;
79
60.3k
  (*tree)[pos].property = property;
80
60.3k
  tree->emplace_back();
81
60.3k
  tree->back().property = -1;
82
60.3k
  tree->back().predictor = rpred;
83
60.3k
  tree->back().predictor_offset = roff;
84
60.3k
  tree->back().multiplier = 1;
85
60.3k
  tree->emplace_back();
86
60.3k
  tree->back().property = -1;
87
60.3k
  tree->back().predictor = lpred;
88
60.3k
  tree->back().predictor_offset = loff;
89
60.3k
  tree->back().multiplier = 1;
90
60.3k
}
Unexecuted instantiation: jxl::N_AVX3::MakeSplitNode(unsigned long, int, int, jxl::Predictor, long, jxl::Predictor, long, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
Unexecuted instantiation: jxl::N_AVX3_ZEN4::MakeSplitNode(unsigned long, int, int, jxl::Predictor, long, jxl::Predictor, long, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
Unexecuted instantiation: jxl::N_AVX3_SPR::MakeSplitNode(unsigned long, int, int, jxl::Predictor, long, jxl::Predictor, long, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
Unexecuted instantiation: jxl::N_SSE2::MakeSplitNode(unsigned long, int, int, jxl::Predictor, long, jxl::Predictor, long, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
91
92
enum class IntersectionType { kNone, kPartial, kInside };
93
IntersectionType BoxIntersects(StaticPropRange needle, StaticPropRange haystack,
94
121k
                               uint32_t &partial_axis, uint32_t &partial_val) {
95
121k
  bool partial = false;
96
365k
  for (size_t i = 0; i < kNumStaticProperties; i++) {
97
243k
    if (haystack[i][0] >= needle[i][1]) {
98
0
      return IntersectionType::kNone;
99
0
    }
100
243k
    if (haystack[i][1] <= needle[i][0]) {
101
0
      return IntersectionType::kNone;
102
0
    }
103
243k
    if (haystack[i][0] <= needle[i][0] && haystack[i][1] >= needle[i][1]) {
104
243k
      continue;
105
243k
    }
106
0
    partial = true;
107
0
    partial_axis = i;
108
0
    if (haystack[i][0] > needle[i][0] && haystack[i][0] < needle[i][1]) {
109
0
      partial_val = haystack[i][0] - 1;
110
0
    } else {
111
0
      JXL_DASSERT(haystack[i][1] > needle[i][0] &&
112
0
                  haystack[i][1] < needle[i][1]);
113
0
      partial_val = haystack[i][1] - 1;
114
0
    }
115
0
  }
116
121k
  return partial ? IntersectionType::kPartial : IntersectionType::kInside;
117
121k
}
Unexecuted instantiation: jxl::N_SSE4::BoxIntersects(std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, unsigned int&, unsigned int&)
jxl::N_AVX2::BoxIntersects(std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, unsigned int&, unsigned int&)
Line
Count
Source
94
121k
                               uint32_t &partial_axis, uint32_t &partial_val) {
95
121k
  bool partial = false;
96
365k
  for (size_t i = 0; i < kNumStaticProperties; i++) {
97
243k
    if (haystack[i][0] >= needle[i][1]) {
98
0
      return IntersectionType::kNone;
99
0
    }
100
243k
    if (haystack[i][1] <= needle[i][0]) {
101
0
      return IntersectionType::kNone;
102
0
    }
103
243k
    if (haystack[i][0] <= needle[i][0] && haystack[i][1] >= needle[i][1]) {
104
243k
      continue;
105
243k
    }
106
0
    partial = true;
107
0
    partial_axis = i;
108
0
    if (haystack[i][0] > needle[i][0] && haystack[i][0] < needle[i][1]) {
109
0
      partial_val = haystack[i][0] - 1;
110
0
    } else {
111
0
      JXL_DASSERT(haystack[i][1] > needle[i][0] &&
112
0
                  haystack[i][1] < needle[i][1]);
113
0
      partial_val = haystack[i][1] - 1;
114
0
    }
115
0
  }
116
121k
  return partial ? IntersectionType::kPartial : IntersectionType::kInside;
117
121k
}
Unexecuted instantiation: jxl::N_AVX3::BoxIntersects(std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, unsigned int&, unsigned int&)
Unexecuted instantiation: jxl::N_AVX3_ZEN4::BoxIntersects(std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, unsigned int&, unsigned int&)
Unexecuted instantiation: jxl::N_AVX3_SPR::BoxIntersects(std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, unsigned int&, unsigned int&)
Unexecuted instantiation: jxl::N_SSE2::BoxIntersects(std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, unsigned int&, unsigned int&)
118
119
void SplitTreeSamples(TreeSamples &tree_samples, size_t begin, size_t pos,
120
60.3k
                      size_t end, size_t prop, uint32_t val) {
121
60.3k
  size_t begin_pos = begin;
122
60.3k
  size_t end_pos = pos;
123
6.61M
  do {
124
15.8M
    while (begin_pos < pos && tree_samples.Property(prop, begin_pos) <= val) {
125
9.23M
      ++begin_pos;
126
9.23M
    }
127
16.7M
    while (end_pos < end && tree_samples.Property(prop, end_pos) > val) {
128
10.1M
      ++end_pos;
129
10.1M
    }
130
6.61M
    if (begin_pos < pos && end_pos < end) {
131
6.59M
      tree_samples.Swap(begin_pos, end_pos);
132
6.59M
    }
133
6.61M
    ++begin_pos;
134
6.61M
    ++end_pos;
135
6.61M
  } while (begin_pos < pos && end_pos < end);
136
60.3k
}
Unexecuted instantiation: jxl::N_SSE4::SplitTreeSamples(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int)
jxl::N_AVX2::SplitTreeSamples(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int)
Line
Count
Source
120
60.3k
                      size_t end, size_t prop, uint32_t val) {
121
60.3k
  size_t begin_pos = begin;
122
60.3k
  size_t end_pos = pos;
123
6.61M
  do {
124
15.8M
    while (begin_pos < pos && tree_samples.Property(prop, begin_pos) <= val) {
125
9.23M
      ++begin_pos;
126
9.23M
    }
127
16.7M
    while (end_pos < end && tree_samples.Property(prop, end_pos) > val) {
128
10.1M
      ++end_pos;
129
10.1M
    }
130
6.61M
    if (begin_pos < pos && end_pos < end) {
131
6.59M
      tree_samples.Swap(begin_pos, end_pos);
132
6.59M
    }
133
6.61M
    ++begin_pos;
134
6.61M
    ++end_pos;
135
6.61M
  } while (begin_pos < pos && end_pos < end);
136
60.3k
}
Unexecuted instantiation: jxl::N_AVX3::SplitTreeSamples(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int)
Unexecuted instantiation: jxl::N_AVX3_ZEN4::SplitTreeSamples(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int)
Unexecuted instantiation: jxl::N_AVX3_SPR::SplitTreeSamples(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int)
Unexecuted instantiation: jxl::N_SSE2::SplitTreeSamples(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int)
137
138
void FindBestSplit(TreeSamples &tree_samples, float threshold,
139
                   const std::vector<ModularMultiplierInfo> &mul_info,
140
                   StaticPropRange initial_static_prop_range,
141
1.00k
                   float fast_decode_multiplier, Tree *tree) {
142
1.00k
  struct NodeInfo {
143
1.00k
    size_t pos;
144
1.00k
    size_t begin;
145
1.00k
    size_t end;
146
1.00k
    uint64_t used_properties;
147
1.00k
    StaticPropRange static_prop_range;
148
1.00k
  };
149
1.00k
  std::vector<NodeInfo> nodes;
150
1.00k
  nodes.push_back(NodeInfo{0, 0, tree_samples.NumDistinctSamples(), 0,
151
1.00k
                           initial_static_prop_range});
152
153
1.00k
  size_t num_predictors = tree_samples.NumPredictors();
154
1.00k
  size_t num_properties = tree_samples.NumProperties();
155
156
  // TODO(veluca): consider parallelizing the search (processing multiple nodes
157
  // at a time).
158
122k
  while (!nodes.empty()) {
159
121k
    size_t pos = nodes.back().pos;
160
121k
    size_t begin = nodes.back().begin;
161
121k
    size_t end = nodes.back().end;
162
121k
    uint64_t used_properties = nodes.back().used_properties;
163
121k
    StaticPropRange static_prop_range = nodes.back().static_prop_range;
164
121k
    nodes.pop_back();
165
121k
    if (begin == end) continue;
166
167
121k
    struct SplitInfo {
168
121k
      size_t prop = 0;
169
121k
      uint32_t val = 0;
170
121k
      size_t pos = 0;
171
121k
      float lcost = std::numeric_limits<float>::max();
172
121k
      float rcost = std::numeric_limits<float>::max();
173
121k
      Predictor lpred = Predictor::Zero;
174
121k
      Predictor rpred = Predictor::Zero;
175
5.03M
      float Cost() const { return lcost + rcost; }
Unexecuted instantiation: enc_ma.cc:jxl::N_SSE4::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::SplitInfo::Cost() const
enc_ma.cc:jxl::N_AVX2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::SplitInfo::Cost() const
Line
Count
Source
175
5.03M
      float Cost() const { return lcost + rcost; }
Unexecuted instantiation: enc_ma.cc:jxl::N_AVX3::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::SplitInfo::Cost() const
Unexecuted instantiation: enc_ma.cc:jxl::N_AVX3_ZEN4::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::SplitInfo::Cost() const
Unexecuted instantiation: enc_ma.cc:jxl::N_AVX3_SPR::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::SplitInfo::Cost() const
Unexecuted instantiation: enc_ma.cc:jxl::N_SSE2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::SplitInfo::Cost() const
176
121k
    };
177
178
121k
    SplitInfo best_split_static_constant;
179
121k
    SplitInfo best_split_static;
180
121k
    SplitInfo best_split_nonstatic;
181
121k
    SplitInfo best_split_nowp;
182
183
121k
    JXL_DASSERT(begin <= end);
184
121k
    JXL_DASSERT(end <= tree_samples.NumDistinctSamples());
185
186
    // Compute the maximum token in the range.
187
121k
    size_t max_symbols = 0;
188
243k
    for (size_t pred = 0; pred < num_predictors; pred++) {
189
37.3M
      for (size_t i = begin; i < end; i++) {
190
37.2M
        uint32_t tok = tree_samples.Token(pred, i);
191
37.2M
        max_symbols = max_symbols > tok + 1 ? max_symbols : tok + 1;
192
37.2M
      }
193
121k
    }
194
121k
    max_symbols = Padded(max_symbols);
195
121k
    std::vector<int32_t> counts(max_symbols * num_predictors);
196
121k
    std::vector<uint32_t> tot_extra_bits(num_predictors);
197
243k
    for (size_t pred = 0; pred < num_predictors; pred++) {
198
121k
      size_t extra_bits = 0;
199
37.3M
      for (size_t i = begin; i < end; i++) {
200
37.2M
        const ResidualToken& rt = tree_samples.RToken(pred, i);
201
37.2M
        size_t count = tree_samples.Count(i);
202
37.2M
        counts[pred * max_symbols + rt.tok] += count;
203
37.2M
        extra_bits += rt.nbits * count;
204
37.2M
      }
205
121k
      tot_extra_bits[pred] = extra_bits;
206
121k
    }
207
208
121k
    float base_bits;
209
121k
    {
210
121k
      size_t pred = tree_samples.PredictorIndex((*tree)[pos].predictor);
211
121k
      base_bits =
212
121k
          EstimateBits(counts.data() + pred * max_symbols, max_symbols) +
213
121k
          tot_extra_bits[pred];
214
121k
    }
215
216
121k
    SplitInfo *best = &best_split_nonstatic;
217
218
121k
    SplitInfo forced_split;
219
    // The multiplier ranges cut halfway through the current ranges of static
220
    // properties. We do this even if the current node is not a leaf, to
221
    // minimize the number of nodes in the resulting tree.
222
121k
    for (const auto &mmi : mul_info) {
223
121k
      uint32_t axis;
224
121k
      uint32_t val;
225
121k
      IntersectionType t =
226
121k
          BoxIntersects(static_prop_range, mmi.range, axis, val);
227
121k
      if (t == IntersectionType::kNone) continue;
228
121k
      if (t == IntersectionType::kInside) {
229
121k
        (*tree)[pos].multiplier = mmi.multiplier;
230
121k
        break;
231
121k
      }
232
0
      if (t == IntersectionType::kPartial) {
233
0
        JXL_DASSERT(axis < kNumStaticProperties);
234
0
        forced_split.val = tree_samples.QuantizeStaticProperty(axis, val);
235
0
        forced_split.prop = axis;
236
0
        forced_split.lcost = forced_split.rcost = base_bits / 2 - threshold;
237
0
        forced_split.lpred = forced_split.rpred = (*tree)[pos].predictor;
238
0
        best = &forced_split;
239
0
        best->pos = begin;
240
0
        JXL_DASSERT(best->prop == tree_samples.PropertyFromIndex(best->prop));
241
0
        for (size_t x = begin; x < end; x++) {
242
0
          if (tree_samples.Property(best->prop, x) <= best->val) {
243
0
            best->pos++;
244
0
          }
245
0
        }
246
0
        break;
247
0
      }
248
0
    }
249
250
121k
    if (best != &forced_split) {
251
121k
      std::vector<int> prop_value_used_count;
252
121k
      std::vector<int> count_increase;
253
121k
      std::vector<size_t> extra_bits_increase;
254
      // For each property, compute which of its values are used, and what
255
      // tokens correspond to those usages. Then, iterate through the values,
256
      // and compute the entropy of each side of the split (of the form `prop >
257
      // threshold`). Finally, find the split that minimizes the cost.
258
121k
      struct CostInfo {
259
121k
        float cost = std::numeric_limits<float>::max();
260
121k
        float extra_cost = 0;
261
8.82M
        float Cost() const { return cost + extra_cost; }
Unexecuted instantiation: enc_ma.cc:jxl::N_SSE4::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::CostInfo::Cost() const
enc_ma.cc:jxl::N_AVX2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::CostInfo::Cost() const
Line
Count
Source
261
8.82M
        float Cost() const { return cost + extra_cost; }
Unexecuted instantiation: enc_ma.cc:jxl::N_AVX3::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::CostInfo::Cost() const
Unexecuted instantiation: enc_ma.cc:jxl::N_AVX3_ZEN4::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::CostInfo::Cost() const
Unexecuted instantiation: enc_ma.cc:jxl::N_AVX3_SPR::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::CostInfo::Cost() const
Unexecuted instantiation: enc_ma.cc:jxl::N_SSE2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::CostInfo::Cost() const
262
121k
        Predictor pred;  // will be uninitialized in some cases, but never used.
263
121k
      };
264
121k
      std::vector<CostInfo> costs_l;
265
121k
      std::vector<CostInfo> costs_r;
266
267
121k
      std::vector<int32_t> counts_above(max_symbols);
268
121k
      std::vector<int32_t> counts_below(max_symbols);
269
270
      // The lower the threshold, the higher the expected noisiness of the
271
      // estimate. Thus, discourage changing predictors.
272
121k
      float change_pred_penalty = 800.0f / (100.0f + threshold);
273
960k
      for (size_t prop = 0; prop < num_properties && base_bits > threshold;
274
838k
           prop++) {
275
838k
        costs_l.clear();
276
838k
        costs_r.clear();
277
838k
        size_t prop_size = tree_samples.NumPropertyValues(prop);
278
838k
        if (extra_bits_increase.size() < prop_size) {
279
240k
          count_increase.resize(prop_size * max_symbols);
280
240k
          extra_bits_increase.resize(prop_size);
281
240k
        }
282
        // Clear prop_value_used_count (which cannot be cleared "on the go")
283
838k
        prop_value_used_count.clear();
284
838k
        prop_value_used_count.resize(prop_size);
285
286
838k
        size_t first_used = prop_size;
287
838k
        size_t last_used = 0;
288
289
        // TODO(veluca): consider finding multiple splits along a single
290
        // property at the same time, possibly with a bottom-up approach.
291
261M
        for (size_t i = begin; i < end; i++) {
292
260M
          size_t p = tree_samples.Property(prop, i);
293
260M
          prop_value_used_count[p]++;
294
260M
          last_used = std::max(last_used, p);
295
260M
          first_used = std::min(first_used, p);
296
260M
        }
297
838k
        costs_l.resize(last_used - first_used);
298
838k
        costs_r.resize(last_used - first_used);
299
        // For all predictors, compute the right and left costs of each split.
300
1.67M
        for (size_t pred = 0; pred < num_predictors; pred++) {
301
          // Compute cost and histogram increments for each property value.
302
261M
          for (size_t i = begin; i < end; i++) {
303
260M
            size_t p = tree_samples.Property(prop, i);
304
260M
            size_t cnt = tree_samples.Count(i);
305
260M
            const ResidualToken& rt = tree_samples.RToken(pred, i);
306
260M
            size_t sym = rt.tok;
307
260M
            count_increase[p * max_symbols + sym] += cnt;
308
260M
            extra_bits_increase[p] += rt.nbits * cnt;
309
260M
          }
310
838k
          memcpy(counts_above.data(), counts.data() + pred * max_symbols,
311
838k
                 max_symbols * sizeof counts_above[0]);
312
838k
          memset(counts_below.data(), 0, max_symbols * sizeof counts_below[0]);
313
838k
          size_t extra_bits_below = 0;
314
          // Exclude last used: this ensures neither counts_above nor
315
          // counts_below is empty.
316
6.86M
          for (size_t i = first_used; i < last_used; i++) {
317
6.02M
            if (!prop_value_used_count[i]) continue;
318
4.41M
            extra_bits_below += extra_bits_increase[i];
319
            // The increase for this property value has been used, and will not
320
            // be used again: clear it. Also below.
321
4.41M
            extra_bits_increase[i] = 0;
322
168M
            for (size_t sym = 0; sym < max_symbols; sym++) {
323
164M
              counts_above[sym] -= count_increase[i * max_symbols + sym];
324
164M
              counts_below[sym] += count_increase[i * max_symbols + sym];
325
164M
              count_increase[i * max_symbols + sym] = 0;
326
164M
            }
327
4.41M
            float rcost = EstimateBits(counts_above.data(), max_symbols) +
328
4.41M
                          tot_extra_bits[pred] - extra_bits_below;
329
4.41M
            float lcost = EstimateBits(counts_below.data(), max_symbols) +
330
4.41M
                          extra_bits_below;
331
4.41M
            JXL_DASSERT(extra_bits_below <= tot_extra_bits[pred]);
332
4.41M
            float penalty = 0;
333
            // Never discourage moving away from the Weighted predictor.
334
4.41M
            if (tree_samples.PredictorFromIndex(pred) !=
335
4.41M
                    (*tree)[pos].predictor &&
336
4.41M
                (*tree)[pos].predictor != Predictor::Weighted) {
337
0
              penalty = change_pred_penalty;
338
0
            }
339
            // If everything else is equal, disfavour Weighted (slower) and
340
            // favour Zero (faster if it's the only predictor used in a
341
            // group+channel combination)
342
4.41M
            if (tree_samples.PredictorFromIndex(pred) == Predictor::Weighted) {
343
0
              penalty += 1e-8;
344
0
            }
345
4.41M
            if (tree_samples.PredictorFromIndex(pred) == Predictor::Zero) {
346
0
              penalty -= 1e-8;
347
0
            }
348
4.41M
            if (rcost + penalty < costs_r[i - first_used].Cost()) {
349
4.41M
              costs_r[i - first_used].cost = rcost;
350
4.41M
              costs_r[i - first_used].extra_cost = penalty;
351
4.41M
              costs_r[i - first_used].pred =
352
4.41M
                  tree_samples.PredictorFromIndex(pred);
353
4.41M
            }
354
4.41M
            if (lcost + penalty < costs_l[i - first_used].Cost()) {
355
4.41M
              costs_l[i - first_used].cost = lcost;
356
4.41M
              costs_l[i - first_used].extra_cost = penalty;
357
4.41M
              costs_l[i - first_used].pred =
358
4.41M
                  tree_samples.PredictorFromIndex(pred);
359
4.41M
            }
360
4.41M
          }
361
838k
        }
362
        // Iterate through the possible splits and find the one with minimum sum
363
        // of costs of the two sides.
364
838k
        size_t split = begin;
365
6.86M
        for (size_t i = first_used; i < last_used; i++) {
366
6.02M
          if (!prop_value_used_count[i]) continue;
367
4.41M
          split += prop_value_used_count[i];
368
4.41M
          float rcost = costs_r[i - first_used].cost;
369
4.41M
          float lcost = costs_l[i - first_used].cost;
370
          // WP was not used + we would use the WP property or predictor
371
4.41M
          bool adds_wp =
372
4.41M
              (tree_samples.PropertyFromIndex(prop) == kWPProp &&
373
4.41M
               (used_properties & (1LU << prop)) == 0) ||
374
4.41M
              ((costs_l[i - first_used].pred == Predictor::Weighted ||
375
2.75M
                costs_r[i - first_used].pred == Predictor::Weighted) &&
376
2.75M
               (*tree)[pos].predictor != Predictor::Weighted);
377
4.41M
          bool zero_entropy_side = rcost == 0 || lcost == 0;
378
379
4.41M
          SplitInfo &best_ref =
380
4.41M
              tree_samples.PropertyFromIndex(prop) < kNumStaticProperties
381
4.41M
                  ? (zero_entropy_side ? best_split_static_constant
382
34.5k
                                       : best_split_static)
383
4.41M
                  : (adds_wp ? best_split_nonstatic : best_split_nowp);
384
4.41M
          if (lcost + rcost < best_ref.Cost()) {
385
1.01M
            best_ref.prop = prop;
386
1.01M
            best_ref.val = i;
387
1.01M
            best_ref.pos = split;
388
1.01M
            best_ref.lcost = lcost;
389
1.01M
            best_ref.lpred = costs_l[i - first_used].pred;
390
1.01M
            best_ref.rcost = rcost;
391
1.01M
            best_ref.rpred = costs_r[i - first_used].pred;
392
1.01M
          }
393
4.41M
        }
394
        // Clear extra_bits_increase and cost_increase for last_used.
395
838k
        extra_bits_increase[last_used] = 0;
396
31.3M
        for (size_t sym = 0; sym < max_symbols; sym++) {
397
30.4M
          count_increase[last_used * max_symbols + sym] = 0;
398
30.4M
        }
399
838k
      }
400
401
      // Try to avoid introducing WP.
402
121k
      if (best_split_nowp.Cost() + threshold < base_bits &&
403
121k
          best_split_nowp.Cost() <= fast_decode_multiplier * best->Cost()) {
404
56.3k
        best = &best_split_nowp;
405
56.3k
      }
406
      // Split along static props if possible and not significantly more
407
      // expensive.
408
121k
      if (best_split_static.Cost() + threshold < base_bits &&
409
121k
          best_split_static.Cost() <= fast_decode_multiplier * best->Cost()) {
410
4.72k
        best = &best_split_static;
411
4.72k
      }
412
      // Split along static props to create constant nodes if possible.
413
121k
      if (best_split_static_constant.Cost() + threshold < base_bits) {
414
72
        best = &best_split_static_constant;
415
72
      }
416
121k
    }
417
418
121k
    if (best->Cost() + threshold < base_bits) {
419
60.3k
      uint32_t p = tree_samples.PropertyFromIndex(best->prop);
420
60.3k
      pixel_type dequant =
421
60.3k
          tree_samples.UnquantizeProperty(best->prop, best->val);
422
      // Split node and try to split children.
423
60.3k
      MakeSplitNode(pos, p, dequant, best->lpred, 0, best->rpred, 0, tree);
424
      // "Sort" according to winning property
425
60.3k
      SplitTreeSamples(tree_samples, begin, best->pos, end, best->prop,
426
60.3k
                       best->val);
427
60.3k
      if (p >= kNumStaticProperties) {
428
55.5k
        used_properties |= 1 << best->prop;
429
55.5k
      }
430
60.3k
      auto new_sp_range = static_prop_range;
431
60.3k
      if (p < kNumStaticProperties) {
432
4.79k
        JXL_DASSERT(static_cast<uint32_t>(dequant + 1) <= new_sp_range[p][1]);
433
4.79k
        new_sp_range[p][1] = dequant + 1;
434
4.79k
        JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]);
435
4.79k
      }
436
60.3k
      nodes.push_back(NodeInfo{(*tree)[pos].rchild, begin, best->pos,
437
60.3k
                               used_properties, new_sp_range});
438
60.3k
      new_sp_range = static_prop_range;
439
60.3k
      if (p < kNumStaticProperties) {
440
4.79k
        JXL_DASSERT(new_sp_range[p][0] <= static_cast<uint32_t>(dequant + 1));
441
4.79k
        new_sp_range[p][0] = dequant + 1;
442
4.79k
        JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]);
443
4.79k
      }
444
60.3k
      nodes.push_back(NodeInfo{(*tree)[pos].lchild, best->pos, end,
445
60.3k
                               used_properties, new_sp_range});
446
60.3k
    }
447
121k
  }
448
1.00k
}
Unexecuted instantiation: jxl::N_SSE4::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
jxl::N_AVX2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
Line
Count
Source
141
1.00k
                   float fast_decode_multiplier, Tree *tree) {
142
1.00k
  struct NodeInfo {
143
1.00k
    size_t pos;
144
1.00k
    size_t begin;
145
1.00k
    size_t end;
146
1.00k
    uint64_t used_properties;
147
1.00k
    StaticPropRange static_prop_range;
148
1.00k
  };
149
1.00k
  std::vector<NodeInfo> nodes;
150
1.00k
  nodes.push_back(NodeInfo{0, 0, tree_samples.NumDistinctSamples(), 0,
151
1.00k
                           initial_static_prop_range});
152
153
1.00k
  size_t num_predictors = tree_samples.NumPredictors();
154
1.00k
  size_t num_properties = tree_samples.NumProperties();
155
156
  // TODO(veluca): consider parallelizing the search (processing multiple nodes
157
  // at a time).
158
122k
  while (!nodes.empty()) {
159
121k
    size_t pos = nodes.back().pos;
160
121k
    size_t begin = nodes.back().begin;
161
121k
    size_t end = nodes.back().end;
162
121k
    uint64_t used_properties = nodes.back().used_properties;
163
121k
    StaticPropRange static_prop_range = nodes.back().static_prop_range;
164
121k
    nodes.pop_back();
165
121k
    if (begin == end) continue;
166
167
121k
    struct SplitInfo {
168
121k
      size_t prop = 0;
169
121k
      uint32_t val = 0;
170
121k
      size_t pos = 0;
171
121k
      float lcost = std::numeric_limits<float>::max();
172
121k
      float rcost = std::numeric_limits<float>::max();
173
121k
      Predictor lpred = Predictor::Zero;
174
121k
      Predictor rpred = Predictor::Zero;
175
121k
      float Cost() const { return lcost + rcost; }
176
121k
    };
177
178
121k
    SplitInfo best_split_static_constant;
179
121k
    SplitInfo best_split_static;
180
121k
    SplitInfo best_split_nonstatic;
181
121k
    SplitInfo best_split_nowp;
182
183
121k
    JXL_DASSERT(begin <= end);
184
121k
    JXL_DASSERT(end <= tree_samples.NumDistinctSamples());
185
186
    // Compute the maximum token in the range.
187
121k
    size_t max_symbols = 0;
188
243k
    for (size_t pred = 0; pred < num_predictors; pred++) {
189
37.3M
      for (size_t i = begin; i < end; i++) {
190
37.2M
        uint32_t tok = tree_samples.Token(pred, i);
191
37.2M
        max_symbols = max_symbols > tok + 1 ? max_symbols : tok + 1;
192
37.2M
      }
193
121k
    }
194
121k
    max_symbols = Padded(max_symbols);
195
121k
    std::vector<int32_t> counts(max_symbols * num_predictors);
196
121k
    std::vector<uint32_t> tot_extra_bits(num_predictors);
197
243k
    for (size_t pred = 0; pred < num_predictors; pred++) {
198
121k
      size_t extra_bits = 0;
199
37.3M
      for (size_t i = begin; i < end; i++) {
200
37.2M
        const ResidualToken& rt = tree_samples.RToken(pred, i);
201
37.2M
        size_t count = tree_samples.Count(i);
202
37.2M
        counts[pred * max_symbols + rt.tok] += count;
203
37.2M
        extra_bits += rt.nbits * count;
204
37.2M
      }
205
121k
      tot_extra_bits[pred] = extra_bits;
206
121k
    }
207
208
121k
    float base_bits;
209
121k
    {
210
121k
      size_t pred = tree_samples.PredictorIndex((*tree)[pos].predictor);
211
121k
      base_bits =
212
121k
          EstimateBits(counts.data() + pred * max_symbols, max_symbols) +
213
121k
          tot_extra_bits[pred];
214
121k
    }
215
216
121k
    SplitInfo *best = &best_split_nonstatic;
217
218
121k
    SplitInfo forced_split;
219
    // The multiplier ranges cut halfway through the current ranges of static
220
    // properties. We do this even if the current node is not a leaf, to
221
    // minimize the number of nodes in the resulting tree.
222
121k
    for (const auto &mmi : mul_info) {
223
121k
      uint32_t axis;
224
121k
      uint32_t val;
225
121k
      IntersectionType t =
226
121k
          BoxIntersects(static_prop_range, mmi.range, axis, val);
227
121k
      if (t == IntersectionType::kNone) continue;
228
121k
      if (t == IntersectionType::kInside) {
229
121k
        (*tree)[pos].multiplier = mmi.multiplier;
230
121k
        break;
231
121k
      }
232
0
      if (t == IntersectionType::kPartial) {
233
0
        JXL_DASSERT(axis < kNumStaticProperties);
234
0
        forced_split.val = tree_samples.QuantizeStaticProperty(axis, val);
235
0
        forced_split.prop = axis;
236
0
        forced_split.lcost = forced_split.rcost = base_bits / 2 - threshold;
237
0
        forced_split.lpred = forced_split.rpred = (*tree)[pos].predictor;
238
0
        best = &forced_split;
239
0
        best->pos = begin;
240
0
        JXL_DASSERT(best->prop == tree_samples.PropertyFromIndex(best->prop));
241
0
        for (size_t x = begin; x < end; x++) {
242
0
          if (tree_samples.Property(best->prop, x) <= best->val) {
243
0
            best->pos++;
244
0
          }
245
0
        }
246
0
        break;
247
0
      }
248
0
    }
249
250
121k
    if (best != &forced_split) {
251
121k
      std::vector<int> prop_value_used_count;
252
121k
      std::vector<int> count_increase;
253
121k
      std::vector<size_t> extra_bits_increase;
254
      // For each property, compute which of its values are used, and what
255
      // tokens correspond to those usages. Then, iterate through the values,
256
      // and compute the entropy of each side of the split (of the form `prop >
257
      // threshold`). Finally, find the split that minimizes the cost.
258
121k
      struct CostInfo {
259
121k
        float cost = std::numeric_limits<float>::max();
260
121k
        float extra_cost = 0;
261
121k
        float Cost() const { return cost + extra_cost; }
262
121k
        Predictor pred;  // will be uninitialized in some cases, but never used.
263
121k
      };
264
121k
      std::vector<CostInfo> costs_l;
265
121k
      std::vector<CostInfo> costs_r;
266
267
121k
      std::vector<int32_t> counts_above(max_symbols);
268
121k
      std::vector<int32_t> counts_below(max_symbols);
269
270
      // The lower the threshold, the higher the expected noisiness of the
271
      // estimate. Thus, discourage changing predictors.
272
121k
      float change_pred_penalty = 800.0f / (100.0f + threshold);
273
960k
      for (size_t prop = 0; prop < num_properties && base_bits > threshold;
274
838k
           prop++) {
275
838k
        costs_l.clear();
276
838k
        costs_r.clear();
277
838k
        size_t prop_size = tree_samples.NumPropertyValues(prop);
278
838k
        if (extra_bits_increase.size() < prop_size) {
279
240k
          count_increase.resize(prop_size * max_symbols);
280
240k
          extra_bits_increase.resize(prop_size);
281
240k
        }
282
        // Clear prop_value_used_count (which cannot be cleared "on the go")
283
838k
        prop_value_used_count.clear();
284
838k
        prop_value_used_count.resize(prop_size);
285
286
838k
        size_t first_used = prop_size;
287
838k
        size_t last_used = 0;
288
289
        // TODO(veluca): consider finding multiple splits along a single
290
        // property at the same time, possibly with a bottom-up approach.
291
261M
        for (size_t i = begin; i < end; i++) {
292
260M
          size_t p = tree_samples.Property(prop, i);
293
260M
          prop_value_used_count[p]++;
294
260M
          last_used = std::max(last_used, p);
295
260M
          first_used = std::min(first_used, p);
296
260M
        }
297
838k
        costs_l.resize(last_used - first_used);
298
838k
        costs_r.resize(last_used - first_used);
299
        // For all predictors, compute the right and left costs of each split.
300
1.67M
        for (size_t pred = 0; pred < num_predictors; pred++) {
301
          // Compute cost and histogram increments for each property value.
302
261M
          for (size_t i = begin; i < end; i++) {
303
260M
            size_t p = tree_samples.Property(prop, i);
304
260M
            size_t cnt = tree_samples.Count(i);
305
260M
            const ResidualToken& rt = tree_samples.RToken(pred, i);
306
260M
            size_t sym = rt.tok;
307
260M
            count_increase[p * max_symbols + sym] += cnt;
308
260M
            extra_bits_increase[p] += rt.nbits * cnt;
309
260M
          }
310
838k
          memcpy(counts_above.data(), counts.data() + pred * max_symbols,
311
838k
                 max_symbols * sizeof counts_above[0]);
312
838k
          memset(counts_below.data(), 0, max_symbols * sizeof counts_below[0]);
313
838k
          size_t extra_bits_below = 0;
314
          // Exclude last used: this ensures neither counts_above nor
315
          // counts_below is empty.
316
6.86M
          for (size_t i = first_used; i < last_used; i++) {
317
6.02M
            if (!prop_value_used_count[i]) continue;
318
4.41M
            extra_bits_below += extra_bits_increase[i];
319
            // The increase for this property value has been used, and will not
320
            // be used again: clear it. Also below.
321
4.41M
            extra_bits_increase[i] = 0;
322
168M
            for (size_t sym = 0; sym < max_symbols; sym++) {
323
164M
              counts_above[sym] -= count_increase[i * max_symbols + sym];
324
164M
              counts_below[sym] += count_increase[i * max_symbols + sym];
325
164M
              count_increase[i * max_symbols + sym] = 0;
326
164M
            }
327
4.41M
            float rcost = EstimateBits(counts_above.data(), max_symbols) +
328
4.41M
                          tot_extra_bits[pred] - extra_bits_below;
329
4.41M
            float lcost = EstimateBits(counts_below.data(), max_symbols) +
330
4.41M
                          extra_bits_below;
331
4.41M
            JXL_DASSERT(extra_bits_below <= tot_extra_bits[pred]);
332
4.41M
            float penalty = 0;
333
            // Never discourage moving away from the Weighted predictor.
334
4.41M
            if (tree_samples.PredictorFromIndex(pred) !=
335
4.41M
                    (*tree)[pos].predictor &&
336
4.41M
                (*tree)[pos].predictor != Predictor::Weighted) {
337
0
              penalty = change_pred_penalty;
338
0
            }
339
            // If everything else is equal, disfavour Weighted (slower) and
340
            // favour Zero (faster if it's the only predictor used in a
341
            // group+channel combination)
342
4.41M
            if (tree_samples.PredictorFromIndex(pred) == Predictor::Weighted) {
343
0
              penalty += 1e-8;
344
0
            }
345
4.41M
            if (tree_samples.PredictorFromIndex(pred) == Predictor::Zero) {
346
0
              penalty -= 1e-8;
347
0
            }
348
4.41M
            if (rcost + penalty < costs_r[i - first_used].Cost()) {
349
4.41M
              costs_r[i - first_used].cost = rcost;
350
4.41M
              costs_r[i - first_used].extra_cost = penalty;
351
4.41M
              costs_r[i - first_used].pred =
352
4.41M
                  tree_samples.PredictorFromIndex(pred);
353
4.41M
            }
354
4.41M
            if (lcost + penalty < costs_l[i - first_used].Cost()) {
355
4.41M
              costs_l[i - first_used].cost = lcost;
356
4.41M
              costs_l[i - first_used].extra_cost = penalty;
357
4.41M
              costs_l[i - first_used].pred =
358
4.41M
                  tree_samples.PredictorFromIndex(pred);
359
4.41M
            }
360
4.41M
          }
361
838k
        }
362
        // Iterate through the possible splits and find the one with minimum sum
363
        // of costs of the two sides.
364
838k
        size_t split = begin;
365
6.86M
        for (size_t i = first_used; i < last_used; i++) {
366
6.02M
          if (!prop_value_used_count[i]) continue;
367
4.41M
          split += prop_value_used_count[i];
368
4.41M
          float rcost = costs_r[i - first_used].cost;
369
4.41M
          float lcost = costs_l[i - first_used].cost;
370
          // WP was not used + we would use the WP property or predictor
371
4.41M
          bool adds_wp =
372
4.41M
              (tree_samples.PropertyFromIndex(prop) == kWPProp &&
373
4.41M
               (used_properties & (1LU << prop)) == 0) ||
374
4.41M
              ((costs_l[i - first_used].pred == Predictor::Weighted ||
375
2.75M
                costs_r[i - first_used].pred == Predictor::Weighted) &&
376
2.75M
               (*tree)[pos].predictor != Predictor::Weighted);
377
4.41M
          bool zero_entropy_side = rcost == 0 || lcost == 0;
378
379
4.41M
          SplitInfo &best_ref =
380
4.41M
              tree_samples.PropertyFromIndex(prop) < kNumStaticProperties
381
4.41M
                  ? (zero_entropy_side ? best_split_static_constant
382
34.5k
                                       : best_split_static)
383
4.41M
                  : (adds_wp ? best_split_nonstatic : best_split_nowp);
384
4.41M
          if (lcost + rcost < best_ref.Cost()) {
385
1.01M
            best_ref.prop = prop;
386
1.01M
            best_ref.val = i;
387
1.01M
            best_ref.pos = split;
388
1.01M
            best_ref.lcost = lcost;
389
1.01M
            best_ref.lpred = costs_l[i - first_used].pred;
390
1.01M
            best_ref.rcost = rcost;
391
1.01M
            best_ref.rpred = costs_r[i - first_used].pred;
392
1.01M
          }
393
4.41M
        }
394
        // Clear extra_bits_increase and count_increase for last_used.
395
838k
        extra_bits_increase[last_used] = 0;
396
31.3M
        for (size_t sym = 0; sym < max_symbols; sym++) {
397
30.4M
          count_increase[last_used * max_symbols + sym] = 0;
398
30.4M
        }
399
838k
      }
400
401
      // Try to avoid introducing WP.
402
121k
      if (best_split_nowp.Cost() + threshold < base_bits &&
403
121k
          best_split_nowp.Cost() <= fast_decode_multiplier * best->Cost()) {
404
56.3k
        best = &best_split_nowp;
405
56.3k
      }
406
      // Split along static props if possible and not significantly more
407
      // expensive.
408
121k
      if (best_split_static.Cost() + threshold < base_bits &&
409
121k
          best_split_static.Cost() <= fast_decode_multiplier * best->Cost()) {
410
4.72k
        best = &best_split_static;
411
4.72k
      }
412
      // Split along static props to create constant nodes if possible.
413
121k
      if (best_split_static_constant.Cost() + threshold < base_bits) {
414
72
        best = &best_split_static_constant;
415
72
      }
416
121k
    }
417
418
121k
    if (best->Cost() + threshold < base_bits) {
419
60.3k
      uint32_t p = tree_samples.PropertyFromIndex(best->prop);
420
60.3k
      pixel_type dequant =
421
60.3k
          tree_samples.UnquantizeProperty(best->prop, best->val);
422
      // Split node and try to split children.
423
60.3k
      MakeSplitNode(pos, p, dequant, best->lpred, 0, best->rpred, 0, tree);
424
      // "Sort" according to winning property
425
60.3k
      SplitTreeSamples(tree_samples, begin, best->pos, end, best->prop,
426
60.3k
                       best->val);
427
60.3k
      if (p >= kNumStaticProperties) {
428
55.5k
        used_properties |= 1 << best->prop;
429
55.5k
      }
430
60.3k
      auto new_sp_range = static_prop_range;
431
60.3k
      if (p < kNumStaticProperties) {
432
4.79k
        JXL_DASSERT(static_cast<uint32_t>(dequant + 1) <= new_sp_range[p][1]);
433
4.79k
        new_sp_range[p][1] = dequant + 1;
434
4.79k
        JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]);
435
4.79k
      }
436
60.3k
      nodes.push_back(NodeInfo{(*tree)[pos].rchild, begin, best->pos,
437
60.3k
                               used_properties, new_sp_range});
438
60.3k
      new_sp_range = static_prop_range;
439
60.3k
      if (p < kNumStaticProperties) {
440
4.79k
        JXL_DASSERT(new_sp_range[p][0] <= static_cast<uint32_t>(dequant + 1));
441
4.79k
        new_sp_range[p][0] = dequant + 1;
442
4.79k
        JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]);
443
4.79k
      }
444
60.3k
      nodes.push_back(NodeInfo{(*tree)[pos].lchild, best->pos, end,
445
60.3k
                               used_properties, new_sp_range});
446
60.3k
    }
447
121k
  }
448
1.00k
}
Unexecuted instantiation: jxl::N_AVX3::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
Unexecuted instantiation: jxl::N_AVX3_ZEN4::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
Unexecuted instantiation: jxl::N_AVX3_SPR::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
Unexecuted instantiation: jxl::N_SSE2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)
449
450
// NOLINTNEXTLINE(google-readability-namespace-comments)
451
}  // namespace HWY_NAMESPACE
452
}  // namespace jxl
453
HWY_AFTER_NAMESPACE();
454
455
#if HWY_ONCE
456
namespace jxl {
457
458
HWY_EXPORT(FindBestSplit);  // Local function.
459
460
// Entry point for tree learning: appends a fresh leaf node to `*tree` and then
// runs the SIMD-dispatched FindBestSplit over `tree_samples` to grow it.
//
// Returns an error (through JXL_ENSURE) when `tree_samples` has 64 or more
// properties, or more distinct samples than fit in a uint32_t. Note that the
// leaf has already been appended to `*tree` when those checks run.
Status ComputeBestTree(TreeSamples &tree_samples, float threshold,
                       const std::vector<ModularMultiplierInfo> &mul_info,
                       StaticPropRange static_prop_range,
                       float fast_decode_multiplier, Tree *tree) {
  // TODO(veluca): take into account that different contexts can have different
  // uint configs.
  //
  // Initialize tree: one leaf using the first configured predictor, no
  // offset, unit multiplier.
  tree->emplace_back();
  auto &root = tree->back();
  root.property = -1;
  root.predictor = tree_samples.PredictorFromIndex(0);
  root.predictor_offset = 0;
  root.multiplier = 1;

  // FindBestSplit tracks used properties in a uint64_t bitmask.
  JXL_ENSURE(tree_samples.NumProperties() < 64);
  JXL_ENSURE(tree_samples.NumDistinctSamples() <=
             std::numeric_limits<uint32_t>::max());

  HWY_DYNAMIC_DISPATCH(FindBestSplit)
  (tree_samples, threshold, mul_info, static_prop_range, fast_decode_multiplier,
   tree);
  return true;
}
482
483
#if JXL_CXX_LANG < JXL_CXX_17
484
constexpr int32_t TreeSamples::kPropertyRange;
485
constexpr uint32_t TreeSamples::kDedupEntryUnused;
486
#endif
487
488
// Selects the list of predictors for which residuals are collected, and sizes
// `residuals` to hold one stream per selected predictor.
//
//   - kWPOnly tree mode: always exactly {Weighted}, whatever `predictor` is.
//   - kNoWP tree mode: fails if `predictor` is Weighted; otherwise Weighted is
//     filtered out of the resulting list.
//   - Predictor::Variable: every modular predictor, reordered so that Weighted
//     is first and Gradient second.
//   - Predictor::Best: {Weighted, Gradient}.
//   - anything else: just `predictor`.
//
// Every branch replaces any previously configured list, so calling this more
// than once is safe.
Status TreeSamples::SetPredictor(Predictor predictor,
                                 ModularOptions::TreeMode wp_tree_mode) {
  if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) {
    predictors = {Predictor::Weighted};
    residuals.resize(1);
    return true;
  }
  if (wp_tree_mode == ModularOptions::TreeMode::kNoWP &&
      predictor == Predictor::Weighted) {
    return JXL_FAILURE("Invalid predictor settings");
  }
  if (predictor == Predictor::Variable) {
    // Start from an empty list: the other branches assign, so this one must
    // not append to leftovers (the swaps below rely on predictors[i] == i).
    predictors.clear();
    for (size_t i = 0; i < kNumModularPredictors; i++) {
      predictors.push_back(static_cast<Predictor>(i));
    }
    // Move Weighted and Gradient to the front of the list.
    std::swap(predictors[0], predictors[static_cast<int>(Predictor::Weighted)]);
    std::swap(predictors[1], predictors[static_cast<int>(Predictor::Gradient)]);
  } else if (predictor == Predictor::Best) {
    predictors = {Predictor::Weighted, Predictor::Gradient};
  } else {
    predictors = {predictor};
  }
  if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) {
    predictors.erase(
        std::remove(predictors.begin(), predictors.end(), Predictor::Weighted),
        predictors.end());
  }
  residuals.resize(predictors.size());
  return true;
}
518
519
// Chooses which properties are sampled for tree learning.
//
// Starts from `properties` and then applies the tree mode: kWPOnly keeps only
// kWPProp, kGradientOnly keeps only kGradientProp, kNoWP removes kWPProp.
// Fails if nothing is left. Afterwards counts the static properties (indices
// below kNumStaticProperties) and sizes `props` for the remaining ones.
Status TreeSamples::SetProperties(const std::vector<uint32_t> &properties,
                                  ModularOptions::TreeMode wp_tree_mode) {
  props_to_use = properties;
  switch (wp_tree_mode) {
    case ModularOptions::TreeMode::kWPOnly:
      props_to_use = {static_cast<uint32_t>(kWPProp)};
      break;
    case ModularOptions::TreeMode::kGradientOnly:
      props_to_use = {static_cast<uint32_t>(kGradientProp)};
      break;
    case ModularOptions::TreeMode::kNoWP:
      props_to_use.erase(
          std::remove(props_to_use.begin(), props_to_use.end(), kWPProp),
          props_to_use.end());
      break;
    default:
      break;
  }
  if (props_to_use.empty()) {
    return JXL_FAILURE("Invalid property set configuration");
  }

  // Check that if static properties present, then those are at the beginning
  // (static property `p` is expected at position `p`).
  num_static_props = 0;
  for (size_t idx = 0; idx < props_to_use.size(); ++idx) {
    const uint32_t prop = props_to_use[idx];
    if (prop >= kNumStaticProperties) continue;
    JXL_DASSERT(idx == prop);
    num_static_props++;
  }
  props.resize(props_to_use.size() - num_static_props);
  return true;
}
548
549
3.00k
void TreeSamples::InitTable(size_t log_size) {
550
3.00k
  size_t size = 1ULL << log_size;
551
3.00k
  if (dedup_table_.size() == size) return;
552
1.87k
  dedup_table_.resize(size, kDedupEntryUnused);
553
1.93M
  for (size_t i = 0; i < NumDistinctSamples(); i++) {
554
1.93M
    if (sample_counts[i] != std::numeric_limits<uint16_t>::max()) {
555
1.93M
      AddToTable(i);
556
1.93M
    }
557
1.93M
  }
558
1.87k
}
559
560
12.1M
bool TreeSamples::AddToTableAndMerge(size_t a) {
561
12.1M
  size_t pos1 = Hash1(a);
562
12.1M
  size_t pos2 = Hash2(a);
563
12.1M
  if (dedup_table_[pos1] != kDedupEntryUnused &&
564
12.1M
      IsSameSample(a, dedup_table_[pos1])) {
565
6.04M
    JXL_DASSERT(sample_counts[a] == 1);
566
6.04M
    sample_counts[dedup_table_[pos1]]++;
567
    // Remove from hash table samples that are saturated.
568
6.04M
    if (sample_counts[dedup_table_[pos1]] ==
569
6.04M
        std::numeric_limits<uint16_t>::max()) {
570
0
      dedup_table_[pos1] = kDedupEntryUnused;
571
0
    }
572
6.04M
    return true;
573
6.04M
  }
574
6.12M
  if (dedup_table_[pos2] != kDedupEntryUnused &&
575
6.12M
      IsSameSample(a, dedup_table_[pos2])) {
576
1.75M
    JXL_DASSERT(sample_counts[a] == 1);
577
1.75M
    sample_counts[dedup_table_[pos2]]++;
578
    // Remove from hash table samples that are saturated.
579
1.75M
    if (sample_counts[dedup_table_[pos2]] ==
580
1.75M
        std::numeric_limits<uint16_t>::max()) {
581
0
      dedup_table_[pos2] = kDedupEntryUnused;
582
0
    }
583
1.75M
    return true;
584
1.75M
  }
585
4.36M
  AddToTable(a);
586
4.36M
  return false;
587
6.12M
}
588
589
6.30M
void TreeSamples::AddToTable(size_t a) {
590
6.30M
  size_t pos1 = Hash1(a);
591
6.30M
  size_t pos2 = Hash2(a);
592
6.30M
  if (dedup_table_[pos1] == kDedupEntryUnused) {
593
3.16M
    dedup_table_[pos1] = a;
594
3.16M
  } else if (dedup_table_[pos2] == kDedupEntryUnused) {
595
1.82M
    dedup_table_[pos2] = a;
596
1.82M
  }
597
6.30M
}
598
599
3.00k
void TreeSamples::PrepareForSamples(size_t extra_num_samples) {
600
3.00k
  for (auto &res : residuals) {
601
3.00k
    res.reserve(res.size() + extra_num_samples);
602
3.00k
  }
603
9.00k
  for (size_t i = 0; i < num_static_props; ++i) {
604
6.00k
    static_props[i].reserve(static_props[i].size() + extra_num_samples);
605
6.00k
  }
606
15.0k
  for (auto &p : props) {
607
15.0k
    p.reserve(p.size() + extra_num_samples);
608
15.0k
  }
609
3.00k
  size_t total_num_samples = extra_num_samples + sample_counts.size();
610
3.00k
  size_t next_size = CeilLog2Nonzero(total_num_samples * 3 / 2);
611
3.00k
  InitTable(next_size);
612
3.00k
}
613
614
18.4M
size_t TreeSamples::Hash1(size_t a) const {
615
18.4M
  constexpr uint64_t constant = 0x1e35a7bd;
616
18.4M
  uint64_t h = constant;
617
18.4M
  for (const auto &r : residuals) {
618
18.4M
    h = h * constant + r[a].tok;
619
18.4M
    h = h * constant + r[a].nbits;
620
18.4M
  }
621
55.4M
  for (size_t i = 0; i < num_static_props; ++i) {
622
36.9M
    h = h * constant + static_props[i][a];
623
36.9M
  }
624
92.3M
  for (const auto &p : props) {
625
92.3M
    h = h * constant + p[a];
626
92.3M
  }
627
18.4M
  return (h >> 16) & (dedup_table_.size() - 1);
628
18.4M
}
629
18.4M
size_t TreeSamples::Hash2(size_t a) const {
630
18.4M
  constexpr uint64_t constant = 0x1e35a7bd1e35a7bd;
631
18.4M
  uint64_t h = constant;
632
55.4M
  for (size_t i = 0; i < num_static_props; ++i) {
633
36.9M
    h = h * constant ^ static_props[i][a];
634
36.9M
  }
635
92.3M
  for (const auto &p : props) {
636
92.3M
    h = h * constant ^ p[a];
637
92.3M
  }
638
18.4M
  for (const auto &r : residuals) {
639
18.4M
    h = h * constant ^ r[a].tok;
640
18.4M
    h = h * constant ^ r[a].nbits;
641
18.4M
  }
642
18.4M
  return (h >> 16) & (dedup_table_.size() - 1);
643
18.4M
}
644
645
12.7M
bool TreeSamples::IsSameSample(size_t a, size_t b) const {
646
12.7M
  bool ret = true;
647
12.7M
  for (const auto &r : residuals) {
648
12.7M
    if (r[a].tok != r[b].tok) {
649
2.69M
      ret = false;
650
2.69M
    }
651
12.7M
    if (r[a].nbits != r[b].nbits) {
652
1.40M
      ret = false;
653
1.40M
    }
654
12.7M
  }
655
38.2M
  for (size_t i = 0; i < num_static_props; ++i) {
656
25.4M
    if (static_props[i][a] != static_props[i][b]) {
657
2.42M
      ret = false;
658
2.42M
    }
659
25.4M
  }
660
63.7M
  for (const auto &p : props) {
661
63.7M
    if (p[a] != p[b]) {
662
15.5M
      ret = false;
663
15.5M
    }
664
63.7M
  }
665
12.7M
  return ret;
666
12.7M
}
667
668
// Records one pixel as a training sample.
//
// For every configured predictor the residual `pixel - prediction` is packed
// and tokenized with HybridUintConfig(4, 1, 2); only the token and the number
// of extra bits are stored. Static properties and the remaining selected
// properties are stored in quantized form. The new sample is appended with a
// count of 1; if an identical sample is already known, the existing one
// absorbs it and the freshly appended entries are removed again.
void TreeSamples::AddSample(pixel_type_w pixel, const Properties &properties,
                            const pixel_type_w *predictions) {
  for (size_t pred = 0; pred < predictors.size(); ++pred) {
    const size_t prediction_slot = static_cast<int>(predictors[pred]);
    pixel_type residual = pixel - predictions[prediction_slot];
    uint32_t tok;
    uint32_t nbits;
    uint32_t bits;
    HybridUintConfig(4, 1, 2).Encode(PackSigned(residual), &tok, &nbits, &bits);
    // Both values are stored in 8 bits below.
    JXL_DASSERT(tok < 256);
    JXL_DASSERT(nbits < 256);
    residuals[pred].emplace_back(
        ResidualToken{static_cast<uint8_t>(tok), static_cast<uint8_t>(nbits)});
  }
  for (size_t sp = 0; sp < num_static_props; ++sp) {
    static_props[sp].push_back(QuantizeStaticProperty(sp, properties[sp]));
  }
  // `props` only holds the non-static properties, hence the index shift.
  for (size_t idx = num_static_props; idx < props_to_use.size(); ++idx) {
    const auto quantized = QuantizeProperty(idx, properties[props_to_use[idx]]);
    props[idx - num_static_props].push_back(quantized);
  }
  sample_counts.push_back(1);
  num_samples++;

  const size_t new_index = sample_counts.size() - 1;
  if (!AddToTableAndMerge(new_index)) return;

  // Duplicate of an existing sample: drop the entries appended above.
  for (auto &res : residuals) res.pop_back();
  for (size_t sp = 0; sp < num_static_props; ++sp) static_props[sp].pop_back();
  for (auto &prop : props) prop.pop_back();
  sample_counts.pop_back();
}
694
695
6.59M
void TreeSamples::Swap(size_t a, size_t b) {
696
6.59M
  if (a == b) return;
697
6.59M
  for (auto &r : residuals) {
698
6.59M
    std::swap(r[a], r[b]);
699
6.59M
  }
700
19.7M
  for (size_t i = 0; i < num_static_props; ++i) {
701
13.1M
    std::swap(static_props[i][a], static_props[i][b]);
702
13.1M
  }
703
32.9M
  for (auto &p : props) {
704
32.9M
    std::swap(p[a], p[b]);
705
32.9M
  }
706
6.59M
  std::swap(sample_counts[a], sample_counts[b]);
707
6.59M
}
708
709
namespace {
710
std::vector<int32_t> QuantizeHistogram(const std::vector<uint32_t> &histogram,
711
3.00k
                                       size_t num_chunks) {
712
3.00k
  if (histogram.empty() || num_chunks == 0) return {};
713
3.00k
  uint64_t sum = std::accumulate(histogram.begin(), histogram.end(), 0LU);
714
3.00k
  if (sum == 0) return {};
715
  // TODO(veluca): selecting distinct quantiles is likely not the best
716
  // way to go about this.
717
3.00k
  std::vector<int32_t> thresholds;
718
3.00k
  uint64_t cumsum = 0;
719
3.00k
  uint64_t threshold = 1;
720
1.03M
  for (size_t i = 0; i < histogram.size(); i++) {
721
1.03M
    cumsum += histogram[i];
722
1.03M
    if (cumsum * num_chunks >= threshold * sum) {
723
18.8k
      thresholds.push_back(i);
724
163k
      while (cumsum * num_chunks >= threshold * sum) threshold++;
725
18.8k
    }
726
1.03M
  }
727
3.00k
  JXL_DASSERT(thresholds.size() <= num_chunks);
728
  // last value collects all histogram and is not really a threshold
729
3.00k
  thresholds.pop_back();
730
3.00k
  return thresholds;
731
3.00k
}
732
733
std::vector<int32_t> QuantizeSamples(const std::vector<int32_t> &samples,
734
1.00k
                                     size_t num_chunks) {
735
1.00k
  if (samples.empty()) return {};
736
1.00k
  int min = *std::min_element(samples.begin(), samples.end());
737
1.00k
  constexpr int kRange = 512;
738
1.00k
  min = jxl::Clamp1(min, -kRange, kRange);
739
1.00k
  std::vector<uint32_t> counts(2 * kRange + 1);
740
1.17M
  for (int s : samples) {
741
1.17M
    uint32_t sample_offset = jxl::Clamp1(s, -kRange, kRange) - min;
742
1.17M
    counts[sample_offset]++;
743
1.17M
  }
744
1.00k
  std::vector<int32_t> thresholds = QuantizeHistogram(counts, num_chunks);
745
13.8k
  for (auto &v : thresholds) v += min;
746
1.00k
  return thresholds;
747
1.00k
}
748
749
// Builds a lookup table from biased property values to threshold indices.
//
// `to` is resized to `num_pegs` entries. Entry `i` stands for the value
// `i - bias` and receives the index `v` of the first element of `from` that is
// not strictly below that value, i.e. `from[v - 1] < i - bias <= from[v]`
// (with `v == from.size()` when every element is below it). Because the scan
// position never moves backwards, this is only meaningful for `from` given in
// increasing order.
// This is because the decision node in the tree splits on (property) > i,
// hence everything that is not > of a threshold should be clustered
// together.
template <typename T>
void QuantMap(const std::vector<int32_t> &from, std::vector<T> &to,
              size_t num_pegs, int bias) {
  to.resize(num_pegs);
  size_t bucket = 0;
  for (size_t peg = 0; peg < num_pegs; ++peg) {
    const int value = static_cast<int>(peg) - bias;
    // Skip every threshold that `value` already exceeds.
    for (; bucket < from.size() && value > from[bucket]; ++bucket) {
    }
    // The bucket index must be representable in the table's element type.
    JXL_DASSERT(static_cast<T>(bucket) == bucket);
    to[peg] = bucket;
  }
}
enc_ma.cc:void jxl::(anonymous namespace)::QuantMap<unsigned short>(std::__1::vector<int, std::__1::allocator<int> > const&, std::__1::vector<unsigned short, std::__1::allocator<unsigned short> >&, unsigned long, int)
Line
Count
Source
755
2.00k
              size_t num_pegs, int bias) {
756
2.00k
  to.resize(num_pegs);
757
2.00k
  size_t mapped = 0;
758
2.05M
  for (size_t i = 0; i < num_pegs; i++) {
759
2.05M
    while (mapped < from.size() && static_cast<int>(i) - bias > from[mapped]) {
760
2.00k
      mapped++;
761
2.00k
    }
762
2.04M
    JXL_DASSERT(static_cast<T>(mapped) == mapped);
763
2.04M
    to[i] = mapped;
764
2.04M
  }
765
2.00k
}
enc_ma.cc:void jxl::(anonymous namespace)::QuantMap<unsigned char>(std::__1::vector<int, std::__1::allocator<int> > const&, std::__1::vector<unsigned char, std::__1::allocator<unsigned char> >&, unsigned long, int)
Line
Count
Source
755
5.00k
              size_t num_pegs, int bias) {
756
5.00k
  to.resize(num_pegs);
757
5.00k
  size_t mapped = 0;
758
5.12M
  for (size_t i = 0; i < num_pegs; i++) {
759
5.20M
    while (mapped < from.size() && static_cast<int>(i) - bias > from[mapped]) {
760
84.2k
      mapped++;
761
84.2k
    }
762
5.12M
    JXL_DASSERT(static_cast<T>(mapped) == mapped);
763
5.12M
    to[i] = mapped;
764
5.12M
  }
765
5.00k
}
766
}  // namespace
767
768
void TreeSamples::PreQuantizeProperties(
769
    const StaticPropRange &range,
770
    const std::vector<ModularMultiplierInfo> &multiplier_info,
771
    const std::vector<uint32_t> &group_pixel_count,
772
    const std::vector<uint32_t> &channel_pixel_count,
773
    std::vector<pixel_type> &pixel_samples,
774
1.00k
    std::vector<pixel_type> &diff_samples, size_t max_property_values) {
775
  // If we have forced splits because of multipliers, choose channel and group
776
  // thresholds accordingly.
777
1.00k
  std::vector<int32_t> group_multiplier_thresholds;
778
1.00k
  std::vector<int32_t> channel_multiplier_thresholds;
779
1.00k
  for (const auto &v : multiplier_info) {
780
1.00k
    if (v.range[0][0] != range[0][0]) {
781
0
      channel_multiplier_thresholds.push_back(v.range[0][0] - 1);
782
0
    }
783
1.00k
    if (v.range[0][1] != range[0][1]) {
784
0
      channel_multiplier_thresholds.push_back(v.range[0][1] - 1);
785
0
    }
786
1.00k
    if (v.range[1][0] != range[1][0]) {
787
0
      group_multiplier_thresholds.push_back(v.range[1][0] - 1);
788
0
    }
789
1.00k
    if (v.range[1][1] != range[1][1]) {
790
0
      group_multiplier_thresholds.push_back(v.range[1][1] - 1);
791
0
    }
792
1.00k
  }
793
1.00k
  std::sort(channel_multiplier_thresholds.begin(),
794
1.00k
            channel_multiplier_thresholds.end());
795
1.00k
  channel_multiplier_thresholds.resize(
796
1.00k
      std::unique(channel_multiplier_thresholds.begin(),
797
1.00k
                  channel_multiplier_thresholds.end()) -
798
1.00k
      channel_multiplier_thresholds.begin());
799
1.00k
  std::sort(group_multiplier_thresholds.begin(),
800
1.00k
            group_multiplier_thresholds.end());
801
1.00k
  group_multiplier_thresholds.resize(
802
1.00k
      std::unique(group_multiplier_thresholds.begin(),
803
1.00k
                  group_multiplier_thresholds.end()) -
804
1.00k
      group_multiplier_thresholds.begin());
805
806
1.00k
  compact_properties.resize(props_to_use.size());
807
1.00k
  auto quantize_channel = [&]() {
808
1.00k
    if (!channel_multiplier_thresholds.empty()) {
809
0
      return channel_multiplier_thresholds;
810
0
    }
811
1.00k
    return QuantizeHistogram(channel_pixel_count, max_property_values);
812
1.00k
  };
813
1.00k
  auto quantize_group_id = [&]() {
814
1.00k
    if (!group_multiplier_thresholds.empty()) {
815
0
      return group_multiplier_thresholds;
816
0
    }
817
1.00k
    return QuantizeHistogram(group_pixel_count, max_property_values);
818
1.00k
  };
819
1.00k
  auto quantize_coordinate = [&]() {
820
0
    std::vector<int32_t> quantized;
821
0
    quantized.reserve(max_property_values - 1);
822
0
    for (size_t i = 0; i + 1 < max_property_values; i++) {
823
0
      quantized.push_back((i + 1) * 256 / max_property_values - 1);
824
0
    }
825
0
    return quantized;
826
0
  };
827
1.00k
  std::vector<int32_t> abs_pixel_thresholds;
828
1.00k
  std::vector<int32_t> pixel_thresholds;
829
1.00k
  auto quantize_pixel_property = [&]() {
830
0
    if (pixel_thresholds.empty()) {
831
0
      pixel_thresholds = QuantizeSamples(pixel_samples, max_property_values);
832
0
    }
833
0
    return pixel_thresholds;
834
0
  };
835
1.00k
  auto quantize_abs_pixel_property = [&]() {
836
0
    if (abs_pixel_thresholds.empty()) {
837
0
      quantize_pixel_property();  // Compute the non-abs thresholds.
838
0
      for (auto &v : pixel_samples) v = std::abs(v);
839
0
      abs_pixel_thresholds =
840
0
          QuantizeSamples(pixel_samples, max_property_values);
841
0
    }
842
0
    return abs_pixel_thresholds;
843
0
  };
844
1.00k
  std::vector<int32_t> abs_diff_thresholds;
845
1.00k
  std::vector<int32_t> diff_thresholds;
846
4.00k
  auto quantize_diff_property = [&]() {
847
4.00k
    if (diff_thresholds.empty()) {
848
1.00k
      diff_thresholds = QuantizeSamples(diff_samples, max_property_values);
849
1.00k
    }
850
4.00k
    return diff_thresholds;
851
4.00k
  };
852
1.00k
  auto quantize_abs_diff_property = [&]() {
853
0
    if (abs_diff_thresholds.empty()) {
854
0
      quantize_diff_property();  // Compute the non-abs thresholds.
855
0
      for (auto &v : diff_samples) v = std::abs(v);
856
0
      abs_diff_thresholds = QuantizeSamples(diff_samples, max_property_values);
857
0
    }
858
0
    return abs_diff_thresholds;
859
0
  };
860
1.00k
  auto quantize_wp = [&]() {
861
1.00k
    if (max_property_values < 32) {
862
0
      return std::vector<int32_t>{-127, -63, -31, -15, -7, -3, -1, 0,
863
0
                                  1,    3,   7,   15,  31, 63, 127};
864
0
    }
865
1.00k
    if (max_property_values < 64) {
866
1.00k
      return std::vector<int32_t>{-255, -191, -127, -95, -63, -47, -31, -23,
867
1.00k
                                  -15,  -11,  -7,   -5,  -3,  -1,  0,   1,
868
1.00k
                                  3,    5,    7,    11,  15,  23,  31,  47,
869
1.00k
                                  63,   95,   127,  191, 255};
870
1.00k
    }
871
0
    return std::vector<int32_t>{
872
0
        -255, -223, -191, -159, -127, -111, -95, -79, -63, -55, -47,
873
0
        -39,  -31,  -27,  -23,  -19,  -15,  -13, -11, -9,  -7,  -6,
874
0
        -5,   -4,   -3,   -2,   -1,   0,    1,   2,   3,   4,   5,
875
0
        6,    7,    9,    11,   13,   15,   19,  23,  27,  31,  39,
876
0
        47,   55,   63,   79,   95,   111,  127, 159, 191, 223, 255};
877
1.00k
  };
878
879
1.00k
  property_mapping.resize(props_to_use.size() - num_static_props);
880
8.00k
  for (size_t i = 0; i < props_to_use.size(); i++) {
881
7.00k
    if (props_to_use[i] == 0) {
882
1.00k
      compact_properties[i] = quantize_channel();
883
6.00k
    } else if (props_to_use[i] == 1) {
884
1.00k
      compact_properties[i] = quantize_group_id();
885
5.00k
    } else if (props_to_use[i] == 2 || props_to_use[i] == 3) {
886
0
      compact_properties[i] = quantize_coordinate();
887
5.00k
    } else if (props_to_use[i] == 6 || props_to_use[i] == 7 ||
888
5.00k
               props_to_use[i] == 8 ||
889
5.00k
               (props_to_use[i] >= kNumNonrefProperties &&
890
5.00k
                (props_to_use[i] - kNumNonrefProperties) % 4 == 1)) {
891
0
      compact_properties[i] = quantize_pixel_property();
892
5.00k
    } else if (props_to_use[i] == 4 || props_to_use[i] == 5 ||
893
5.00k
               (props_to_use[i] >= kNumNonrefProperties &&
894
5.00k
                (props_to_use[i] - kNumNonrefProperties) % 4 == 0)) {
895
0
      compact_properties[i] = quantize_abs_pixel_property();
896
5.00k
    } else if (props_to_use[i] >= kNumNonrefProperties &&
897
5.00k
               (props_to_use[i] - kNumNonrefProperties) % 4 == 2) {
898
0
      compact_properties[i] = quantize_abs_diff_property();
899
5.00k
    } else if (props_to_use[i] == kWPProp) {
900
1.00k
      compact_properties[i] = quantize_wp();
901
4.00k
    } else {
902
4.00k
      compact_properties[i] = quantize_diff_property();
903
4.00k
    }
904
7.00k
    if (i < num_static_props) {
905
2.00k
      QuantMap(compact_properties[i], static_property_mapping[i],
906
2.00k
               kPropertyRange * 2 + 1, kPropertyRange);
907
5.00k
    } else {
908
5.00k
      QuantMap(compact_properties[i], property_mapping[i - num_static_props],
909
5.00k
               kPropertyRange * 2 + 1, kPropertyRange);
910
5.00k
    }
911
7.00k
  }
912
1.00k
}
913
914
void CollectPixelSamples(const Image &image, const ModularOptions &options,
915
                         uint32_t group_id,
916
                         std::vector<uint32_t> &group_pixel_count,
917
                         std::vector<uint32_t> &channel_pixel_count,
918
                         std::vector<pixel_type> &pixel_samples,
919
1.00k
                         std::vector<pixel_type> &diff_samples) {
920
1.00k
  if (options.nb_repeats == 0) return;
921
1.00k
  if (group_pixel_count.size() <= group_id) {
922
1.00k
    group_pixel_count.resize(group_id + 1);
923
1.00k
  }
924
1.00k
  if (channel_pixel_count.size() < image.channel.size()) {
925
1.00k
    channel_pixel_count.resize(image.channel.size());
926
1.00k
  }
927
1.00k
  Rng rng(group_id);
928
  // Sample 10% of the final number of samples for property quantization.
929
1.00k
  float fraction = std::min(options.nb_repeats * 0.1, 0.99);
930
1.00k
  Rng::GeometricDistribution dist = Rng::MakeGeometric(fraction);
931
1.00k
  size_t total_pixels = 0;
932
1.00k
  std::vector<size_t> channel_ids;
933
4.00k
  for (size_t i = 0; i < image.channel.size(); i++) {
934
3.00k
    if (i >= image.nb_meta_channels &&
935
3.00k
        (image.channel[i].w > options.max_chan_size ||
936
3.00k
         image.channel[i].h > options.max_chan_size)) {
937
0
      break;
938
0
    }
939
3.00k
    if (image.channel[i].w <= 1 || image.channel[i].h == 0) {
940
0
      continue;  // skip empty or width-1 channels.
941
0
    }
942
3.00k
    channel_ids.push_back(i);
943
3.00k
    group_pixel_count[group_id] += image.channel[i].w * image.channel[i].h;
944
3.00k
    channel_pixel_count[i] += image.channel[i].w * image.channel[i].h;
945
3.00k
    total_pixels += image.channel[i].w * image.channel[i].h;
946
3.00k
  }
947
1.00k
  if (channel_ids.empty()) return;
948
1.00k
  pixel_samples.reserve(pixel_samples.size() + fraction * total_pixels);
949
1.00k
  diff_samples.reserve(diff_samples.size() + fraction * total_pixels);
950
1.00k
  size_t i = 0;
951
1.00k
  size_t y = 0;
952
1.00k
  size_t x = 0;
953
1.17M
  auto advance = [&](size_t amount) {
954
1.17M
    x += amount;
955
    // Detect row overflow (rare).
956
1.37M
    while (x >= image.channel[channel_ids[i]].w) {
957
197k
      x -= image.channel[channel_ids[i]].w;
958
197k
      y++;
959
      // Detect end-of-channel (even rarer).
960
197k
      if (y == image.channel[channel_ids[i]].h) {
961
3.00k
        i++;
962
3.00k
        y = 0;
963
3.00k
        if (i >= channel_ids.size()) {
964
1.00k
          return;
965
1.00k
        }
966
3.00k
      }
967
197k
    }
968
1.17M
  };
969
1.00k
  advance(rng.Geometric(dist));
970
1.17M
  for (; i < channel_ids.size(); advance(rng.Geometric(dist) + 1)) {
971
1.17M
    const pixel_type *row = image.channel[channel_ids[i]].Row(y);
972
1.17M
    pixel_samples.push_back(row[x]);
973
1.17M
    size_t xp = x == 0 ? 1 : x - 1;
974
1.17M
    diff_samples.push_back(static_cast<int64_t>(row[x]) - row[xp]);
975
1.17M
  }
976
1.00k
}
977
978
// TODO(veluca): very simple encoding scheme. This should be improved.
979
Status TokenizeTree(const Tree &tree, std::vector<Token> *tokens,
980
3.13k
                    Tree *decoder_tree) {
981
3.13k
  JXL_ENSURE(tree.size() <= kMaxTreeSize);
982
3.13k
  std::queue<int> q;
983
3.13k
  q.push(0);
984
3.13k
  size_t leaf_id = 0;
985
3.13k
  decoder_tree->clear();
986
191k
  while (!q.empty()) {
987
188k
    int cur = q.front();
988
188k
    q.pop();
989
188k
    JXL_ENSURE(tree[cur].property >= -1);
990
188k
    tokens->emplace_back(kPropertyContext, tree[cur].property + 1);
991
188k
    if (tree[cur].property == -1) {
992
95.7k
      tokens->emplace_back(kPredictorContext,
993
95.7k
                           static_cast<int>(tree[cur].predictor));
994
95.7k
      tokens->emplace_back(kOffsetContext,
995
95.7k
                           PackSigned(tree[cur].predictor_offset));
996
95.7k
      uint32_t mul_log = Num0BitsBelowLS1Bit_Nonzero(tree[cur].multiplier);
997
95.7k
      uint32_t mul_bits = (tree[cur].multiplier >> mul_log) - 1;
998
95.7k
      tokens->emplace_back(kMultiplierLogContext, mul_log);
999
95.7k
      tokens->emplace_back(kMultiplierBitsContext, mul_bits);
1000
95.7k
      JXL_ENSURE(tree[cur].predictor < Predictor::Best);
1001
95.7k
      decoder_tree->emplace_back(-1, 0, leaf_id++, 0, tree[cur].predictor,
1002
95.7k
                                 tree[cur].predictor_offset,
1003
95.7k
                                 tree[cur].multiplier);
1004
95.7k
      continue;
1005
95.7k
    }
1006
92.6k
    decoder_tree->emplace_back(tree[cur].property, tree[cur].splitval,
1007
92.6k
                               decoder_tree->size() + q.size() + 1,
1008
92.6k
                               decoder_tree->size() + q.size() + 2,
1009
92.6k
                               Predictor::Zero, 0, 1);
1010
92.6k
    q.push(tree[cur].lchild);
1011
92.6k
    q.push(tree[cur].rchild);
1012
92.6k
    tokens->emplace_back(kSplitValContext, PackSigned(tree[cur].splitval));
1013
92.6k
  }
1014
3.13k
  return true;
1015
3.13k
}
1016
1017
}  // namespace jxl
1018
#endif  // HWY_ONCE