/src/libjxl/lib/jxl/modular/encoding/enc_ma.cc
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright (c) the JPEG XL Project Authors. All rights reserved. |
2 | | // |
3 | | // Use of this source code is governed by a BSD-style |
4 | | // license that can be found in the LICENSE file. |
5 | | |
6 | | #include "lib/jxl/modular/encoding/enc_ma.h" |
7 | | |
8 | | #include <algorithm> |
9 | | #include <cstdint> |
10 | | #include <cstdlib> |
11 | | #include <cstring> |
12 | | #include <limits> |
13 | | #include <numeric> |
14 | | #include <queue> |
15 | | #include <vector> |
16 | | |
17 | | #include "lib/jxl/ans_params.h" |
18 | | #include "lib/jxl/base/bits.h" |
19 | | #include "lib/jxl/base/common.h" |
20 | | #include "lib/jxl/base/compiler_specific.h" |
21 | | #include "lib/jxl/base/status.h" |
22 | | #include "lib/jxl/dec_ans.h" |
23 | | #include "lib/jxl/modular/encoding/dec_ma.h" |
24 | | #include "lib/jxl/modular/encoding/ma_common.h" |
25 | | #include "lib/jxl/modular/modular_image.h" |
26 | | |
27 | | #undef HWY_TARGET_INCLUDE |
28 | | #define HWY_TARGET_INCLUDE "lib/jxl/modular/encoding/enc_ma.cc" |
29 | | #include <hwy/foreach_target.h> |
30 | | #include <hwy/highway.h> |
31 | | |
32 | | #include "lib/jxl/base/fast_math-inl.h" |
33 | | #include "lib/jxl/base/random.h" |
34 | | #include "lib/jxl/enc_ans.h" |
35 | | #include "lib/jxl/modular/encoding/context_predict.h" |
36 | | #include "lib/jxl/modular/options.h" |
37 | | #include "lib/jxl/pack_signed.h" |
38 | | HWY_BEFORE_NAMESPACE(); |
39 | | namespace jxl { |
40 | | namespace HWY_NAMESPACE { |
41 | | |
42 | | // These templates are not found via ADL. |
43 | | using hwy::HWY_NAMESPACE::Eq; |
44 | | using hwy::HWY_NAMESPACE::IfThenElse; |
45 | | using hwy::HWY_NAMESPACE::Lt; |
46 | | using hwy::HWY_NAMESPACE::Max; |
47 | | |
48 | | const HWY_FULL(float) df; |
49 | | const HWY_FULL(int32_t) di; |
50 | 6.07k | size_t Padded(size_t x) { return RoundUpTo(x, Lanes(df)); } Unexecuted instantiation: jxl::N_SSE4::Padded(unsigned long) jxl::N_AVX2::Padded(unsigned long) Line | Count | Source | 50 | 6.07k | size_t Padded(size_t x) { return RoundUpTo(x, Lanes(df)); } |
Unexecuted instantiation: jxl::N_SSE2::Padded(unsigned long) |
51 | | |
52 | | // Compute entropy of the histogram, taking into account the minimum probability |
53 | | // for symbols with non-zero counts. |
54 | 370k | float EstimateBits(const int32_t *counts, size_t num_symbols) { |
55 | 370k | int32_t total = std::accumulate(counts, counts + num_symbols, 0); |
56 | 370k | const auto zero = Zero(df); |
57 | 370k | const auto minprob = Set(df, 1.0f / ANS_TAB_SIZE); |
58 | 370k | const auto inv_total = Set(df, 1.0f / total); |
59 | 370k | auto bits_lanes = Zero(df); |
60 | 370k | auto total_v = Set(di, total); |
61 | 2.77M | for (size_t i = 0; i < num_symbols; i += Lanes(df)) { |
62 | 2.40M | const auto counts_iv = LoadU(di, &counts[i]); |
63 | 2.40M | const auto counts_fv = ConvertTo(df, counts_iv); |
64 | 2.40M | const auto probs = Mul(counts_fv, inv_total); |
65 | 2.40M | const auto mprobs = Max(probs, minprob); |
66 | 2.40M | const auto nbps = IfThenElse(Eq(counts_iv, total_v), BitCast(di, zero), |
67 | 2.40M | BitCast(di, FastLog2f(df, mprobs))); |
68 | 2.40M | bits_lanes = Sub(bits_lanes, Mul(counts_fv, BitCast(df, nbps))); |
69 | 2.40M | } |
70 | 370k | return GetLane(SumOfLanes(df, bits_lanes)); |
71 | 370k | } Unexecuted instantiation: jxl::N_SSE4::EstimateBits(int const*, unsigned long) jxl::N_AVX2::EstimateBits(int const*, unsigned long) Line | Count | Source | 54 | 370k | float EstimateBits(const int32_t *counts, size_t num_symbols) { | 55 | 370k | int32_t total = std::accumulate(counts, counts + num_symbols, 0); | 56 | 370k | const auto zero = Zero(df); | 57 | 370k | const auto minprob = Set(df, 1.0f / ANS_TAB_SIZE); | 58 | 370k | const auto inv_total = Set(df, 1.0f / total); | 59 | 370k | auto bits_lanes = Zero(df); | 60 | 370k | auto total_v = Set(di, total); | 61 | 2.77M | for (size_t i = 0; i < num_symbols; i += Lanes(df)) { | 62 | 2.40M | const auto counts_iv = LoadU(di, &counts[i]); | 63 | 2.40M | const auto counts_fv = ConvertTo(df, counts_iv); | 64 | 2.40M | const auto probs = Mul(counts_fv, inv_total); | 65 | 2.40M | const auto mprobs = Max(probs, minprob); | 66 | 2.40M | const auto nbps = IfThenElse(Eq(counts_iv, total_v), BitCast(di, zero), | 67 | 2.40M | BitCast(di, FastLog2f(df, mprobs))); | 68 | 2.40M | bits_lanes = Sub(bits_lanes, Mul(counts_fv, BitCast(df, nbps))); | 69 | 2.40M | } | 70 | 370k | return GetLane(SumOfLanes(df, bits_lanes)); | 71 | 370k | } |
Unexecuted instantiation: jxl::N_SSE2::EstimateBits(int const*, unsigned long) |
72 | | |
73 | | void MakeSplitNode(size_t pos, int property, int splitval, Predictor lpred, |
74 | 2.98k | int64_t loff, Predictor rpred, int64_t roff, Tree *tree) { |
75 | | // Note that the tree splits on *strictly greater*. |
76 | 2.98k | (*tree)[pos].lchild = tree->size(); |
77 | 2.98k | (*tree)[pos].rchild = tree->size() + 1; |
78 | 2.98k | (*tree)[pos].splitval = splitval; |
79 | 2.98k | (*tree)[pos].property = property; |
80 | 2.98k | tree->emplace_back(); |
81 | 2.98k | tree->back().property = -1; |
82 | 2.98k | tree->back().predictor = rpred; |
83 | 2.98k | tree->back().predictor_offset = roff; |
84 | 2.98k | tree->back().multiplier = 1; |
85 | 2.98k | tree->emplace_back(); |
86 | 2.98k | tree->back().property = -1; |
87 | 2.98k | tree->back().predictor = lpred; |
88 | 2.98k | tree->back().predictor_offset = loff; |
89 | 2.98k | tree->back().multiplier = 1; |
90 | 2.98k | } Unexecuted instantiation: jxl::N_SSE4::MakeSplitNode(unsigned long, int, int, jxl::Predictor, long, jxl::Predictor, long, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*) jxl::N_AVX2::MakeSplitNode(unsigned long, int, int, jxl::Predictor, long, jxl::Predictor, long, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*) Line | Count | Source | 74 | 2.98k | int64_t loff, Predictor rpred, int64_t roff, Tree *tree) { | 75 | | // Note that the tree splits on *strictly greater*. | 76 | 2.98k | (*tree)[pos].lchild = tree->size(); | 77 | 2.98k | (*tree)[pos].rchild = tree->size() + 1; | 78 | 2.98k | (*tree)[pos].splitval = splitval; | 79 | 2.98k | (*tree)[pos].property = property; | 80 | 2.98k | tree->emplace_back(); | 81 | 2.98k | tree->back().property = -1; | 82 | 2.98k | tree->back().predictor = rpred; | 83 | 2.98k | tree->back().predictor_offset = roff; | 84 | 2.98k | tree->back().multiplier = 1; | 85 | 2.98k | tree->emplace_back(); | 86 | 2.98k | tree->back().property = -1; | 87 | 2.98k | tree->back().predictor = lpred; | 88 | 2.98k | tree->back().predictor_offset = loff; | 89 | 2.98k | tree->back().multiplier = 1; | 90 | 2.98k | } |
Unexecuted instantiation: jxl::N_SSE2::MakeSplitNode(unsigned long, int, int, jxl::Predictor, long, jxl::Predictor, long, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*) |
91 | | |
92 | | enum class IntersectionType { kNone, kPartial, kInside }; |
93 | | IntersectionType BoxIntersects(StaticPropRange needle, StaticPropRange haystack, |
94 | 5.89k | uint32_t &partial_axis, uint32_t &partial_val) { |
95 | 5.89k | bool partial = false; |
96 | 17.6k | for (size_t i = 0; i < kNumStaticProperties; i++) { |
97 | 11.7k | if (haystack[i][0] >= needle[i][1]) { |
98 | 0 | return IntersectionType::kNone; |
99 | 0 | } |
100 | 11.7k | if (haystack[i][1] <= needle[i][0]) { |
101 | 0 | return IntersectionType::kNone; |
102 | 0 | } |
103 | 11.7k | if (haystack[i][0] <= needle[i][0] && haystack[i][1] >= needle[i][1]) { |
104 | 11.7k | continue; |
105 | 11.7k | } |
106 | 0 | partial = true; |
107 | 0 | partial_axis = i; |
108 | 0 | if (haystack[i][0] > needle[i][0] && haystack[i][0] < needle[i][1]) { |
109 | 0 | partial_val = haystack[i][0] - 1; |
110 | 0 | } else { |
111 | 0 | JXL_DASSERT(haystack[i][1] > needle[i][0] && |
112 | 0 | haystack[i][1] < needle[i][1]); |
113 | 0 | partial_val = haystack[i][1] - 1; |
114 | 0 | } |
115 | 0 | } |
116 | 5.89k | return partial ? IntersectionType::kPartial : IntersectionType::kInside; |
117 | 5.89k | } Unexecuted instantiation: jxl::N_SSE4::BoxIntersects(std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, unsigned int&, unsigned int&) jxl::N_AVX2::BoxIntersects(std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, unsigned int&, unsigned int&) Line | Count | Source | 94 | 5.89k | uint32_t &partial_axis, uint32_t &partial_val) { | 95 | 5.89k | bool partial = false; | 96 | 17.6k | for (size_t i = 0; i < kNumStaticProperties; i++) { | 97 | 11.7k | if (haystack[i][0] >= needle[i][1]) { | 98 | 0 | return IntersectionType::kNone; | 99 | 0 | } | 100 | 11.7k | if (haystack[i][1] <= needle[i][0]) { | 101 | 0 | return IntersectionType::kNone; | 102 | 0 | } | 103 | 11.7k | if (haystack[i][0] <= needle[i][0] && haystack[i][1] >= needle[i][1]) { | 104 | 11.7k | continue; | 105 | 11.7k | } | 106 | 0 | partial = true; | 107 | 0 | partial_axis = i; | 108 | 0 | if (haystack[i][0] > needle[i][0] && haystack[i][0] < needle[i][1]) { | 109 | 0 | partial_val = haystack[i][0] - 1; | 110 | 0 | } else { | 111 | 0 | JXL_DASSERT(haystack[i][1] > needle[i][0] && | 112 | 0 | haystack[i][1] < needle[i][1]); | 113 | 0 | partial_val = haystack[i][1] - 1; | 114 | 0 | } | 115 | 0 | } | 116 | 5.89k | return partial ? IntersectionType::kPartial : IntersectionType::kInside; | 117 | 5.89k | } |
Unexecuted instantiation: jxl::N_SSE2::BoxIntersects(std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, unsigned int&, unsigned int&) |
118 | | |
119 | | void SplitTreeSamples(TreeSamples &tree_samples, size_t begin, size_t pos, |
120 | 2.98k | size_t end, size_t prop, uint32_t val) { |
121 | 2.98k | size_t begin_pos = begin; |
122 | 2.98k | size_t end_pos = pos; |
123 | 196k | do { |
124 | 502k | while (begin_pos < pos && tree_samples.Property(prop, begin_pos) <= val) { |
125 | 306k | ++begin_pos; |
126 | 306k | } |
127 | 521k | while (end_pos < end && tree_samples.Property(prop, end_pos) > val) { |
128 | 324k | ++end_pos; |
129 | 324k | } |
130 | 196k | if (begin_pos < pos && end_pos < end) { |
131 | 195k | tree_samples.Swap(begin_pos, end_pos); |
132 | 195k | } |
133 | 196k | ++begin_pos; |
134 | 196k | ++end_pos; |
135 | 196k | } while (begin_pos < pos && end_pos < end); |
136 | 2.98k | } Unexecuted instantiation: jxl::N_SSE4::SplitTreeSamples(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int) jxl::N_AVX2::SplitTreeSamples(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int) Line | Count | Source | 120 | 2.98k | size_t end, size_t prop, uint32_t val) { | 121 | 2.98k | size_t begin_pos = begin; | 122 | 2.98k | size_t end_pos = pos; | 123 | 196k | do { | 124 | 502k | while (begin_pos < pos && tree_samples.Property(prop, begin_pos) <= val) { | 125 | 306k | ++begin_pos; | 126 | 306k | } | 127 | 521k | while (end_pos < end && tree_samples.Property(prop, end_pos) > val) { | 128 | 324k | ++end_pos; | 129 | 324k | } | 130 | 196k | if (begin_pos < pos && end_pos < end) { | 131 | 195k | tree_samples.Swap(begin_pos, end_pos); | 132 | 195k | } | 133 | 196k | ++begin_pos; | 134 | 196k | ++end_pos; | 135 | 196k | } while (begin_pos < pos && end_pos < end); | 136 | 2.98k | } |
Unexecuted instantiation: jxl::N_SSE2::SplitTreeSamples(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int) |
137 | | |
138 | | void FindBestSplit(TreeSamples &tree_samples, float threshold, |
139 | | const std::vector<ModularMultiplierInfo> &mul_info, |
140 | | StaticPropRange initial_static_prop_range, |
141 | 111 | float fast_decode_multiplier, Tree *tree) { |
142 | 111 | struct NodeInfo { |
143 | 111 | size_t pos; |
144 | 111 | size_t begin; |
145 | 111 | size_t end; |
146 | 111 | uint64_t used_properties; |
147 | 111 | StaticPropRange static_prop_range; |
148 | 111 | }; |
149 | 111 | std::vector<NodeInfo> nodes; |
150 | 111 | nodes.push_back(NodeInfo{0, 0, tree_samples.NumDistinctSamples(), 0, |
151 | 111 | initial_static_prop_range}); |
152 | | |
153 | 111 | size_t num_predictors = tree_samples.NumPredictors(); |
154 | 111 | size_t num_properties = tree_samples.NumProperties(); |
155 | | |
156 | | // TODO(veluca): consider parallelizing the search (processing multiple nodes |
157 | | // at a time). |
158 | 6.18k | while (!nodes.empty()) { |
159 | 6.07k | size_t pos = nodes.back().pos; |
160 | 6.07k | size_t begin = nodes.back().begin; |
161 | 6.07k | size_t end = nodes.back().end; |
162 | 6.07k | uint64_t used_properties = nodes.back().used_properties; |
163 | 6.07k | StaticPropRange static_prop_range = nodes.back().static_prop_range; |
164 | 6.07k | nodes.pop_back(); |
165 | 6.07k | if (begin == end) continue; |
166 | | |
167 | 6.07k | struct SplitInfo { |
168 | 6.07k | size_t prop = 0; |
169 | 6.07k | uint32_t val = 0; |
170 | 6.07k | size_t pos = 0; |
171 | 6.07k | float lcost = std::numeric_limits<float>::max(); |
172 | 6.07k | float rcost = std::numeric_limits<float>::max(); |
173 | 6.07k | Predictor lpred = Predictor::Zero; |
174 | 6.07k | Predictor rpred = Predictor::Zero; |
175 | 213k | float Cost() const { return lcost + rcost; } Unexecuted instantiation: enc_ma.cc:jxl::N_SSE4::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::SplitInfo::Cost() const enc_ma.cc:jxl::N_AVX2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::SplitInfo::Cost() const Line | Count | Source | 175 | 213k | float Cost() const { return lcost + rcost; } |
Unexecuted instantiation: enc_ma.cc:jxl::N_SSE2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::SplitInfo::Cost() const |
176 | 6.07k | }; |
177 | | |
178 | 6.07k | SplitInfo best_split_static_constant; |
179 | 6.07k | SplitInfo best_split_static; |
180 | 6.07k | SplitInfo best_split_nonstatic; |
181 | 6.07k | SplitInfo best_split_nowp; |
182 | | |
183 | 6.07k | JXL_DASSERT(begin <= end); |
184 | 6.07k | JXL_DASSERT(end <= tree_samples.NumDistinctSamples()); |
185 | | |
186 | | // Compute the maximum token in the range. |
187 | 6.07k | size_t max_symbols = 0; |
188 | 12.1k | for (size_t pred = 0; pred < num_predictors; pred++) { |
189 | 1.23M | for (size_t i = begin; i < end; i++) { |
190 | 1.22M | uint32_t tok = tree_samples.Token(pred, i); |
191 | 1.22M | max_symbols = max_symbols > tok + 1 ? max_symbols : tok + 1; |
192 | 1.22M | } |
193 | 6.07k | } |
194 | 6.07k | max_symbols = Padded(max_symbols); |
195 | 6.07k | std::vector<int32_t> counts(max_symbols * num_predictors); |
196 | 6.07k | std::vector<uint32_t> tot_extra_bits(num_predictors); |
197 | 12.1k | for (size_t pred = 0; pred < num_predictors; pred++) { |
198 | 1.23M | for (size_t i = begin; i < end; i++) { |
199 | 1.22M | counts[pred * max_symbols + tree_samples.Token(pred, i)] += |
200 | 1.22M | tree_samples.Count(i); |
201 | 1.22M | tot_extra_bits[pred] += |
202 | 1.22M | tree_samples.NBits(pred, i) * tree_samples.Count(i); |
203 | 1.22M | } |
204 | 6.07k | } |
205 | | |
206 | 6.07k | float base_bits; |
207 | 6.07k | { |
208 | 6.07k | size_t pred = tree_samples.PredictorIndex((*tree)[pos].predictor); |
209 | 6.07k | base_bits = |
210 | 6.07k | EstimateBits(counts.data() + pred * max_symbols, max_symbols) + |
211 | 6.07k | tot_extra_bits[pred]; |
212 | 6.07k | } |
213 | | |
214 | 6.07k | SplitInfo *best = &best_split_nonstatic; |
215 | | |
216 | 6.07k | SplitInfo forced_split; |
217 | | // The multiplier ranges cut halfway through the current ranges of static |
218 | | // properties. We do this even if the current node is not a leaf, to |
219 | | // minimize the number of nodes in the resulting tree. |
220 | 6.07k | for (const auto &mmi : mul_info) { |
221 | 5.89k | uint32_t axis; |
222 | 5.89k | uint32_t val; |
223 | 5.89k | IntersectionType t = |
224 | 5.89k | BoxIntersects(static_prop_range, mmi.range, axis, val); |
225 | 5.89k | if (t == IntersectionType::kNone) continue; |
226 | 5.89k | if (t == IntersectionType::kInside) { |
227 | 5.89k | (*tree)[pos].multiplier = mmi.multiplier; |
228 | 5.89k | break; |
229 | 5.89k | } |
230 | 0 | if (t == IntersectionType::kPartial) { |
231 | 0 | forced_split.val = tree_samples.QuantizeProperty(axis, val); |
232 | 0 | forced_split.prop = axis; |
233 | 0 | forced_split.lcost = forced_split.rcost = base_bits / 2 - threshold; |
234 | 0 | forced_split.lpred = forced_split.rpred = (*tree)[pos].predictor; |
235 | 0 | best = &forced_split; |
236 | 0 | best->pos = begin; |
237 | 0 | JXL_DASSERT(best->prop == tree_samples.PropertyFromIndex(best->prop)); |
238 | 0 | for (size_t x = begin; x < end; x++) { |
239 | 0 | if (tree_samples.Property(best->prop, x) <= best->val) { |
240 | 0 | best->pos++; |
241 | 0 | } |
242 | 0 | } |
243 | 0 | break; |
244 | 0 | } |
245 | 0 | } |
246 | | |
247 | 6.07k | if (best != &forced_split) { |
248 | 6.07k | std::vector<int> prop_value_used_count; |
249 | 6.07k | std::vector<int> count_increase; |
250 | 6.07k | std::vector<size_t> extra_bits_increase; |
251 | | // For each property, compute which of its values are used, and what |
252 | | // tokens correspond to those usages. Then, iterate through the values, |
253 | | // and compute the entropy of each side of the split (of the form `prop > |
254 | | // threshold`). Finally, find the split that minimizes the cost. |
255 | 6.07k | struct CostInfo { |
256 | 6.07k | float cost = std::numeric_limits<float>::max(); |
257 | 6.07k | float extra_cost = 0; |
258 | 364k | float Cost() const { return cost + extra_cost; } Unexecuted instantiation: enc_ma.cc:jxl::N_SSE4::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::CostInfo::Cost() const enc_ma.cc:jxl::N_AVX2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::CostInfo::Cost() const Line | Count | Source | 258 | 364k | float Cost() const { return cost + extra_cost; } |
Unexecuted instantiation: enc_ma.cc:jxl::N_SSE2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::CostInfo::Cost() const |
259 | 6.07k | Predictor pred; // will be uninitialized in some cases, but never used. |
260 | 6.07k | }; |
261 | 6.07k | std::vector<CostInfo> costs_l; |
262 | 6.07k | std::vector<CostInfo> costs_r; |
263 | | |
264 | 6.07k | std::vector<int32_t> counts_above(max_symbols); |
265 | 6.07k | std::vector<int32_t> counts_below(max_symbols); |
266 | | |
267 | | // The lower the threshold, the higher the expected noisiness of the |
268 | | // estimate. Thus, discourage changing predictors. |
269 | 6.07k | float change_pred_penalty = 800.0f / (100.0f + threshold); |
270 | 48.1k | for (size_t prop = 0; prop < num_properties && base_bits > threshold; |
271 | 42.1k | prop++) { |
272 | 42.1k | costs_l.clear(); |
273 | 42.1k | costs_r.clear(); |
274 | 42.1k | size_t prop_size = tree_samples.NumPropertyValues(prop); |
275 | 42.1k | if (extra_bits_increase.size() < prop_size) { |
276 | 12.2k | count_increase.resize(prop_size * max_symbols); |
277 | 12.2k | extra_bits_increase.resize(prop_size); |
278 | 12.2k | } |
279 | | // Clear prop_value_used_count (which cannot be cleared "on the go") |
280 | 42.1k | prop_value_used_count.clear(); |
281 | 42.1k | prop_value_used_count.resize(prop_size); |
282 | | |
283 | 42.1k | size_t first_used = prop_size; |
284 | 42.1k | size_t last_used = 0; |
285 | | |
286 | | // TODO(veluca): consider finding multiple splits along a single |
287 | | // property at the same time, possibly with a bottom-up approach. |
288 | 8.62M | for (size_t i = begin; i < end; i++) { |
289 | 8.57M | size_t p = tree_samples.Property(prop, i); |
290 | 8.57M | prop_value_used_count[p]++; |
291 | 8.57M | last_used = std::max(last_used, p); |
292 | 8.57M | first_used = std::min(first_used, p); |
293 | 8.57M | } |
294 | 42.1k | costs_l.resize(last_used - first_used); |
295 | 42.1k | costs_r.resize(last_used - first_used); |
296 | | // For all predictors, compute the right and left costs of each split. |
297 | 84.2k | for (size_t pred = 0; pred < num_predictors; pred++) { |
298 | | // Compute cost and histogram increments for each property value. |
299 | 8.62M | for (size_t i = begin; i < end; i++) { |
300 | 8.57M | size_t p = tree_samples.Property(prop, i); |
301 | 8.57M | size_t cnt = tree_samples.Count(i); |
302 | 8.57M | size_t sym = tree_samples.Token(pred, i); |
303 | 8.57M | count_increase[p * max_symbols + sym] += cnt; |
304 | 8.57M | extra_bits_increase[p] += tree_samples.NBits(pred, i) * cnt; |
305 | 8.57M | } |
306 | 42.1k | memcpy(counts_above.data(), counts.data() + pred * max_symbols, |
307 | 42.1k | max_symbols * sizeof counts_above[0]); |
308 | 42.1k | memset(counts_below.data(), 0, max_symbols * sizeof counts_below[0]); |
309 | 42.1k | size_t extra_bits_below = 0; |
310 | | // Exclude last used: this ensures neither counts_above nor |
311 | | // counts_below is empty. |
312 | 351k | for (size_t i = first_used; i < last_used; i++) { |
313 | 309k | if (!prop_value_used_count[i]) continue; |
314 | 182k | extra_bits_below += extra_bits_increase[i]; |
315 | | // The increase for this property value has been used, and will not |
316 | | // be used again: clear it. Also below. |
317 | 182k | extra_bits_increase[i] = 0; |
318 | 9.63M | for (size_t sym = 0; sym < max_symbols; sym++) { |
319 | 9.44M | counts_above[sym] -= count_increase[i * max_symbols + sym]; |
320 | 9.44M | counts_below[sym] += count_increase[i * max_symbols + sym]; |
321 | 9.44M | count_increase[i * max_symbols + sym] = 0; |
322 | 9.44M | } |
323 | 182k | float rcost = EstimateBits(counts_above.data(), max_symbols) + |
324 | 182k | tot_extra_bits[pred] - extra_bits_below; |
325 | 182k | float lcost = EstimateBits(counts_below.data(), max_symbols) + |
326 | 182k | extra_bits_below; |
327 | 182k | JXL_DASSERT(extra_bits_below <= tot_extra_bits[pred]); |
328 | 182k | float penalty = 0; |
329 | | // Never discourage moving away from the Weighted predictor. |
330 | 182k | if (tree_samples.PredictorFromIndex(pred) != |
331 | 182k | (*tree)[pos].predictor && |
332 | 182k | (*tree)[pos].predictor != Predictor::Weighted) { |
333 | 0 | penalty = change_pred_penalty; |
334 | 0 | } |
335 | | // If everything else is equal, disfavour Weighted (slower) and |
336 | | // favour Zero (faster if it's the only predictor used in a |
337 | | // group+channel combination) |
338 | 182k | if (tree_samples.PredictorFromIndex(pred) == Predictor::Weighted) { |
339 | 0 | penalty += 1e-8; |
340 | 0 | } |
341 | 182k | if (tree_samples.PredictorFromIndex(pred) == Predictor::Zero) { |
342 | 0 | penalty -= 1e-8; |
343 | 0 | } |
344 | 182k | if (rcost + penalty < costs_r[i - first_used].Cost()) { |
345 | 182k | costs_r[i - first_used].cost = rcost; |
346 | 182k | costs_r[i - first_used].extra_cost = penalty; |
347 | 182k | costs_r[i - first_used].pred = |
348 | 182k | tree_samples.PredictorFromIndex(pred); |
349 | 182k | } |
350 | 182k | if (lcost + penalty < costs_l[i - first_used].Cost()) { |
351 | 182k | costs_l[i - first_used].cost = lcost; |
352 | 182k | costs_l[i - first_used].extra_cost = penalty; |
353 | 182k | costs_l[i - first_used].pred = |
354 | 182k | tree_samples.PredictorFromIndex(pred); |
355 | 182k | } |
356 | 182k | } |
357 | 42.1k | } |
358 | | // Iterate through the possible splits and find the one with minimum sum |
359 | | // of costs of the two sides. |
360 | 42.1k | size_t split = begin; |
361 | 351k | for (size_t i = first_used; i < last_used; i++) { |
362 | 309k | if (!prop_value_used_count[i]) continue; |
363 | 182k | split += prop_value_used_count[i]; |
364 | 182k | float rcost = costs_r[i - first_used].cost; |
365 | 182k | float lcost = costs_l[i - first_used].cost; |
366 | | // WP was not used + we would use the WP property or predictor |
367 | 182k | bool adds_wp = |
368 | 182k | (tree_samples.PropertyFromIndex(prop) == kWPProp && |
369 | 182k | (used_properties & (1LU << prop)) == 0) || |
370 | 182k | ((costs_l[i - first_used].pred == Predictor::Weighted || |
371 | 128k | costs_r[i - first_used].pred == Predictor::Weighted) && |
372 | 128k | (*tree)[pos].predictor != Predictor::Weighted); |
373 | 182k | bool zero_entropy_side = rcost == 0 || lcost == 0; |
374 | | |
375 | 182k | SplitInfo &best_ref = |
376 | 182k | tree_samples.PropertyFromIndex(prop) < kNumStaticProperties |
377 | 182k | ? (zero_entropy_side ? best_split_static_constant |
378 | 1.27k | : best_split_static) |
379 | 182k | : (adds_wp ? best_split_nonstatic : best_split_nowp); |
380 | 182k | if (lcost + rcost < best_ref.Cost()) { |
381 | 39.1k | best_ref.prop = prop; |
382 | 39.1k | best_ref.val = i; |
383 | 39.1k | best_ref.pos = split; |
384 | 39.1k | best_ref.lcost = lcost; |
385 | 39.1k | best_ref.lpred = costs_l[i - first_used].pred; |
386 | 39.1k | best_ref.rcost = rcost; |
387 | 39.1k | best_ref.rpred = costs_r[i - first_used].pred; |
388 | 39.1k | } |
389 | 182k | } |
390 | | // Clear extra_bits_increase and cost_increase for last_used. |
391 | 42.1k | extra_bits_increase[last_used] = 0; |
392 | 2.39M | for (size_t sym = 0; sym < max_symbols; sym++) { |
393 | 2.35M | count_increase[last_used * max_symbols + sym] = 0; |
394 | 2.35M | } |
395 | 42.1k | } |
396 | | |
397 | | // Try to avoid introducing WP. |
398 | 6.07k | if (best_split_nowp.Cost() + threshold < base_bits && |
399 | 6.07k | best_split_nowp.Cost() <= fast_decode_multiplier * best->Cost()) { |
400 | 2.76k | best = &best_split_nowp; |
401 | 2.76k | } |
402 | | // Split along static props if possible and not significantly more |
403 | | // expensive. |
404 | 6.07k | if (best_split_static.Cost() + threshold < base_bits && |
405 | 6.07k | best_split_static.Cost() <= fast_decode_multiplier * best->Cost()) { |
406 | 402 | best = &best_split_static; |
407 | 402 | } |
408 | | // Split along static props to create constant nodes if possible. |
409 | 6.07k | if (best_split_static_constant.Cost() + threshold < base_bits) { |
410 | 2 | best = &best_split_static_constant; |
411 | 2 | } |
412 | 6.07k | } |
413 | | |
414 | 6.07k | if (best->Cost() + threshold < base_bits) { |
415 | 2.98k | uint32_t p = tree_samples.PropertyFromIndex(best->prop); |
416 | 2.98k | pixel_type dequant = |
417 | 2.98k | tree_samples.UnquantizeProperty(best->prop, best->val); |
418 | | // Split node and try to split children. |
419 | 2.98k | MakeSplitNode(pos, p, dequant, best->lpred, 0, best->rpred, 0, tree); |
420 | | // "Sort" according to winning property |
421 | 2.98k | SplitTreeSamples(tree_samples, begin, best->pos, end, best->prop, |
422 | 2.98k | best->val); |
423 | 2.98k | if (p >= kNumStaticProperties) { |
424 | 2.57k | used_properties |= 1 << best->prop; |
425 | 2.57k | } |
426 | 2.98k | auto new_sp_range = static_prop_range; |
427 | 2.98k | if (p < kNumStaticProperties) { |
428 | 404 | JXL_DASSERT(static_cast<uint32_t>(dequant + 1) <= new_sp_range[p][1]); |
429 | 404 | new_sp_range[p][1] = dequant + 1; |
430 | 404 | JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]); |
431 | 404 | } |
432 | 2.98k | nodes.push_back(NodeInfo{(*tree)[pos].rchild, begin, best->pos, |
433 | 2.98k | used_properties, new_sp_range}); |
434 | 2.98k | new_sp_range = static_prop_range; |
435 | 2.98k | if (p < kNumStaticProperties) { |
436 | 404 | JXL_DASSERT(new_sp_range[p][0] <= static_cast<uint32_t>(dequant + 1)); |
437 | 404 | new_sp_range[p][0] = dequant + 1; |
438 | 404 | JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]); |
439 | 404 | } |
440 | 2.98k | nodes.push_back(NodeInfo{(*tree)[pos].lchild, best->pos, end, |
441 | 2.98k | used_properties, new_sp_range}); |
442 | 2.98k | } |
443 | 6.07k | } |
444 | 111 | } Unexecuted instantiation: jxl::N_SSE4::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*) jxl::N_AVX2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*) Line | Count | Source | 141 | 111 | float fast_decode_multiplier, Tree *tree) { | 142 | 111 | struct NodeInfo { | 143 | 111 | size_t pos; | 144 | 111 | size_t begin; | 145 | 111 | size_t end; | 146 | 111 | uint64_t used_properties; | 147 | 111 | StaticPropRange static_prop_range; | 148 | 111 | }; | 149 | 111 | std::vector<NodeInfo> nodes; | 150 | 111 | nodes.push_back(NodeInfo{0, 0, tree_samples.NumDistinctSamples(), 0, | 151 | 111 | initial_static_prop_range}); | 152 | | | 153 | 111 | size_t num_predictors = tree_samples.NumPredictors(); | 154 | 111 | size_t num_properties = tree_samples.NumProperties(); | 155 | | | 156 | | // TODO(veluca): consider parallelizing the search (processing multiple nodes | 157 | | // at a time). 
| 158 | 6.18k | while (!nodes.empty()) { | 159 | 6.07k | size_t pos = nodes.back().pos; | 160 | 6.07k | size_t begin = nodes.back().begin; | 161 | 6.07k | size_t end = nodes.back().end; | 162 | 6.07k | uint64_t used_properties = nodes.back().used_properties; | 163 | 6.07k | StaticPropRange static_prop_range = nodes.back().static_prop_range; | 164 | 6.07k | nodes.pop_back(); | 165 | 6.07k | if (begin == end) continue; | 166 | | | 167 | 6.07k | struct SplitInfo { | 168 | 6.07k | size_t prop = 0; | 169 | 6.07k | uint32_t val = 0; | 170 | 6.07k | size_t pos = 0; | 171 | 6.07k | float lcost = std::numeric_limits<float>::max(); | 172 | 6.07k | float rcost = std::numeric_limits<float>::max(); | 173 | 6.07k | Predictor lpred = Predictor::Zero; | 174 | 6.07k | Predictor rpred = Predictor::Zero; | 175 | 6.07k | float Cost() const { return lcost + rcost; } | 176 | 6.07k | }; | 177 | | | 178 | 6.07k | SplitInfo best_split_static_constant; | 179 | 6.07k | SplitInfo best_split_static; | 180 | 6.07k | SplitInfo best_split_nonstatic; | 181 | 6.07k | SplitInfo best_split_nowp; | 182 | | | 183 | 6.07k | JXL_DASSERT(begin <= end); | 184 | 6.07k | JXL_DASSERT(end <= tree_samples.NumDistinctSamples()); | 185 | | | 186 | | // Compute the maximum token in the range. | 187 | 6.07k | size_t max_symbols = 0; | 188 | 12.1k | for (size_t pred = 0; pred < num_predictors; pred++) { | 189 | 1.23M | for (size_t i = begin; i < end; i++) { | 190 | 1.22M | uint32_t tok = tree_samples.Token(pred, i); | 191 | 1.22M | max_symbols = max_symbols > tok + 1 ? 
max_symbols : tok + 1; | 192 | 1.22M | } | 193 | 6.07k | } | 194 | 6.07k | max_symbols = Padded(max_symbols); | 195 | 6.07k | std::vector<int32_t> counts(max_symbols * num_predictors); | 196 | 6.07k | std::vector<uint32_t> tot_extra_bits(num_predictors); | 197 | 12.1k | for (size_t pred = 0; pred < num_predictors; pred++) { | 198 | 1.23M | for (size_t i = begin; i < end; i++) { | 199 | 1.22M | counts[pred * max_symbols + tree_samples.Token(pred, i)] += | 200 | 1.22M | tree_samples.Count(i); | 201 | 1.22M | tot_extra_bits[pred] += | 202 | 1.22M | tree_samples.NBits(pred, i) * tree_samples.Count(i); | 203 | 1.22M | } | 204 | 6.07k | } | 205 | | | 206 | 6.07k | float base_bits; | 207 | 6.07k | { | 208 | 6.07k | size_t pred = tree_samples.PredictorIndex((*tree)[pos].predictor); | 209 | 6.07k | base_bits = | 210 | 6.07k | EstimateBits(counts.data() + pred * max_symbols, max_symbols) + | 211 | 6.07k | tot_extra_bits[pred]; | 212 | 6.07k | } | 213 | | | 214 | 6.07k | SplitInfo *best = &best_split_nonstatic; | 215 | | | 216 | 6.07k | SplitInfo forced_split; | 217 | | // The multiplier ranges cut halfway through the current ranges of static | 218 | | // properties. We do this even if the current node is not a leaf, to | 219 | | // minimize the number of nodes in the resulting tree. 
| 220 | 6.07k | for (const auto &mmi : mul_info) { | 221 | 5.89k | uint32_t axis; | 222 | 5.89k | uint32_t val; | 223 | 5.89k | IntersectionType t = | 224 | 5.89k | BoxIntersects(static_prop_range, mmi.range, axis, val); | 225 | 5.89k | if (t == IntersectionType::kNone) continue; | 226 | 5.89k | if (t == IntersectionType::kInside) { | 227 | 5.89k | (*tree)[pos].multiplier = mmi.multiplier; | 228 | 5.89k | break; | 229 | 5.89k | } | 230 | 0 | if (t == IntersectionType::kPartial) { | 231 | 0 | forced_split.val = tree_samples.QuantizeProperty(axis, val); | 232 | 0 | forced_split.prop = axis; | 233 | 0 | forced_split.lcost = forced_split.rcost = base_bits / 2 - threshold; | 234 | 0 | forced_split.lpred = forced_split.rpred = (*tree)[pos].predictor; | 235 | 0 | best = &forced_split; | 236 | 0 | best->pos = begin; | 237 | 0 | JXL_DASSERT(best->prop == tree_samples.PropertyFromIndex(best->prop)); | 238 | 0 | for (size_t x = begin; x < end; x++) { | 239 | 0 | if (tree_samples.Property(best->prop, x) <= best->val) { | 240 | 0 | best->pos++; | 241 | 0 | } | 242 | 0 | } | 243 | 0 | break; | 244 | 0 | } | 245 | 0 | } | 246 | | | 247 | 6.07k | if (best != &forced_split) { | 248 | 6.07k | std::vector<int> prop_value_used_count; | 249 | 6.07k | std::vector<int> count_increase; | 250 | 6.07k | std::vector<size_t> extra_bits_increase; | 251 | | // For each property, compute which of its values are used, and what | 252 | | // tokens correspond to those usages. Then, iterate through the values, | 253 | | // and compute the entropy of each side of the split (of the form `prop > | 254 | | // threshold`). Finally, find the split that minimizes the cost. | 255 | 6.07k | struct CostInfo { | 256 | 6.07k | float cost = std::numeric_limits<float>::max(); | 257 | 6.07k | float extra_cost = 0; | 258 | 6.07k | float Cost() const { return cost + extra_cost; } | 259 | 6.07k | Predictor pred; // will be uninitialized in some cases, but never used. 
| 260 | 6.07k | }; | 261 | 6.07k | std::vector<CostInfo> costs_l; | 262 | 6.07k | std::vector<CostInfo> costs_r; | 263 | | | 264 | 6.07k | std::vector<int32_t> counts_above(max_symbols); | 265 | 6.07k | std::vector<int32_t> counts_below(max_symbols); | 266 | | | 267 | | // The lower the threshold, the higher the expected noisiness of the | 268 | | // estimate. Thus, discourage changing predictors. | 269 | 6.07k | float change_pred_penalty = 800.0f / (100.0f + threshold); | 270 | 48.1k | for (size_t prop = 0; prop < num_properties && base_bits > threshold; | 271 | 42.1k | prop++) { | 272 | 42.1k | costs_l.clear(); | 273 | 42.1k | costs_r.clear(); | 274 | 42.1k | size_t prop_size = tree_samples.NumPropertyValues(prop); | 275 | 42.1k | if (extra_bits_increase.size() < prop_size) { | 276 | 12.2k | count_increase.resize(prop_size * max_symbols); | 277 | 12.2k | extra_bits_increase.resize(prop_size); | 278 | 12.2k | } | 279 | | // Clear prop_value_used_count (which cannot be cleared "on the go") | 280 | 42.1k | prop_value_used_count.clear(); | 281 | 42.1k | prop_value_used_count.resize(prop_size); | 282 | | | 283 | 42.1k | size_t first_used = prop_size; | 284 | 42.1k | size_t last_used = 0; | 285 | | | 286 | | // TODO(veluca): consider finding multiple splits along a single | 287 | | // property at the same time, possibly with a bottom-up approach. | 288 | 8.62M | for (size_t i = begin; i < end; i++) { | 289 | 8.57M | size_t p = tree_samples.Property(prop, i); | 290 | 8.57M | prop_value_used_count[p]++; | 291 | 8.57M | last_used = std::max(last_used, p); | 292 | 8.57M | first_used = std::min(first_used, p); | 293 | 8.57M | } | 294 | 42.1k | costs_l.resize(last_used - first_used); | 295 | 42.1k | costs_r.resize(last_used - first_used); | 296 | | // For all predictors, compute the right and left costs of each split. | 297 | 84.2k | for (size_t pred = 0; pred < num_predictors; pred++) { | 298 | | // Compute cost and histogram increments for each property value. 
| 299 | 8.62M | for (size_t i = begin; i < end; i++) { | 300 | 8.57M | size_t p = tree_samples.Property(prop, i); | 301 | 8.57M | size_t cnt = tree_samples.Count(i); | 302 | 8.57M | size_t sym = tree_samples.Token(pred, i); | 303 | 8.57M | count_increase[p * max_symbols + sym] += cnt; | 304 | 8.57M | extra_bits_increase[p] += tree_samples.NBits(pred, i) * cnt; | 305 | 8.57M | } | 306 | 42.1k | memcpy(counts_above.data(), counts.data() + pred * max_symbols, | 307 | 42.1k | max_symbols * sizeof counts_above[0]); | 308 | 42.1k | memset(counts_below.data(), 0, max_symbols * sizeof counts_below[0]); | 309 | 42.1k | size_t extra_bits_below = 0; | 310 | | // Exclude last used: this ensures neither counts_above nor | 311 | | // counts_below is empty. | 312 | 351k | for (size_t i = first_used; i < last_used; i++) { | 313 | 309k | if (!prop_value_used_count[i]) continue; | 314 | 182k | extra_bits_below += extra_bits_increase[i]; | 315 | | // The increase for this property value has been used, and will not | 316 | | // be used again: clear it. Also below. | 317 | 182k | extra_bits_increase[i] = 0; | 318 | 9.63M | for (size_t sym = 0; sym < max_symbols; sym++) { | 319 | 9.44M | counts_above[sym] -= count_increase[i * max_symbols + sym]; | 320 | 9.44M | counts_below[sym] += count_increase[i * max_symbols + sym]; | 321 | 9.44M | count_increase[i * max_symbols + sym] = 0; | 322 | 9.44M | } | 323 | 182k | float rcost = EstimateBits(counts_above.data(), max_symbols) + | 324 | 182k | tot_extra_bits[pred] - extra_bits_below; | 325 | 182k | float lcost = EstimateBits(counts_below.data(), max_symbols) + | 326 | 182k | extra_bits_below; | 327 | 182k | JXL_DASSERT(extra_bits_below <= tot_extra_bits[pred]); | 328 | 182k | float penalty = 0; | 329 | | // Never discourage moving away from the Weighted predictor. 
| 330 | 182k | if (tree_samples.PredictorFromIndex(pred) != | 331 | 182k | (*tree)[pos].predictor && | 332 | 182k | (*tree)[pos].predictor != Predictor::Weighted) { | 333 | 0 | penalty = change_pred_penalty; | 334 | 0 | } | 335 | | // If everything else is equal, disfavour Weighted (slower) and | 336 | | // favour Zero (faster if it's the only predictor used in a | 337 | | // group+channel combination) | 338 | 182k | if (tree_samples.PredictorFromIndex(pred) == Predictor::Weighted) { | 339 | 0 | penalty += 1e-8; | 340 | 0 | } | 341 | 182k | if (tree_samples.PredictorFromIndex(pred) == Predictor::Zero) { | 342 | 0 | penalty -= 1e-8; | 343 | 0 | } | 344 | 182k | if (rcost + penalty < costs_r[i - first_used].Cost()) { | 345 | 182k | costs_r[i - first_used].cost = rcost; | 346 | 182k | costs_r[i - first_used].extra_cost = penalty; | 347 | 182k | costs_r[i - first_used].pred = | 348 | 182k | tree_samples.PredictorFromIndex(pred); | 349 | 182k | } | 350 | 182k | if (lcost + penalty < costs_l[i - first_used].Cost()) { | 351 | 182k | costs_l[i - first_used].cost = lcost; | 352 | 182k | costs_l[i - first_used].extra_cost = penalty; | 353 | 182k | costs_l[i - first_used].pred = | 354 | 182k | tree_samples.PredictorFromIndex(pred); | 355 | 182k | } | 356 | 182k | } | 357 | 42.1k | } | 358 | | // Iterate through the possible splits and find the one with minimum sum | 359 | | // of costs of the two sides. 
| 360 | 42.1k | size_t split = begin; | 361 | 351k | for (size_t i = first_used; i < last_used; i++) { | 362 | 309k | if (!prop_value_used_count[i]) continue; | 363 | 182k | split += prop_value_used_count[i]; | 364 | 182k | float rcost = costs_r[i - first_used].cost; | 365 | 182k | float lcost = costs_l[i - first_used].cost; | 366 | | // WP was not used + we would use the WP property or predictor | 367 | 182k | bool adds_wp = | 368 | 182k | (tree_samples.PropertyFromIndex(prop) == kWPProp && | 369 | 182k | (used_properties & (1LU << prop)) == 0) || | 370 | 182k | ((costs_l[i - first_used].pred == Predictor::Weighted || | 371 | 128k | costs_r[i - first_used].pred == Predictor::Weighted) && | 372 | 128k | (*tree)[pos].predictor != Predictor::Weighted); | 373 | 182k | bool zero_entropy_side = rcost == 0 || lcost == 0; | 374 | | | 375 | 182k | SplitInfo &best_ref = | 376 | 182k | tree_samples.PropertyFromIndex(prop) < kNumStaticProperties | 377 | 182k | ? (zero_entropy_side ? best_split_static_constant | 378 | 1.27k | : best_split_static) | 379 | 182k | : (adds_wp ? best_split_nonstatic : best_split_nowp); | 380 | 182k | if (lcost + rcost < best_ref.Cost()) { | 381 | 39.1k | best_ref.prop = prop; | 382 | 39.1k | best_ref.val = i; | 383 | 39.1k | best_ref.pos = split; | 384 | 39.1k | best_ref.lcost = lcost; | 385 | 39.1k | best_ref.lpred = costs_l[i - first_used].pred; | 386 | 39.1k | best_ref.rcost = rcost; | 387 | 39.1k | best_ref.rpred = costs_r[i - first_used].pred; | 388 | 39.1k | } | 389 | 182k | } | 390 | | // Clear extra_bits_increase and cost_increase for last_used. | 391 | 42.1k | extra_bits_increase[last_used] = 0; | 392 | 2.39M | for (size_t sym = 0; sym < max_symbols; sym++) { | 393 | 2.35M | count_increase[last_used * max_symbols + sym] = 0; | 394 | 2.35M | } | 395 | 42.1k | } | 396 | | | 397 | | // Try to avoid introducing WP. 
| 398 | 6.07k | if (best_split_nowp.Cost() + threshold < base_bits && | 399 | 6.07k | best_split_nowp.Cost() <= fast_decode_multiplier * best->Cost()) { | 400 | 2.76k | best = &best_split_nowp; | 401 | 2.76k | } | 402 | | // Split along static props if possible and not significantly more | 403 | | // expensive. | 404 | 6.07k | if (best_split_static.Cost() + threshold < base_bits && | 405 | 6.07k | best_split_static.Cost() <= fast_decode_multiplier * best->Cost()) { | 406 | 402 | best = &best_split_static; | 407 | 402 | } | 408 | | // Split along static props to create constant nodes if possible. | 409 | 6.07k | if (best_split_static_constant.Cost() + threshold < base_bits) { | 410 | 2 | best = &best_split_static_constant; | 411 | 2 | } | 412 | 6.07k | } | 413 | | | 414 | 6.07k | if (best->Cost() + threshold < base_bits) { | 415 | 2.98k | uint32_t p = tree_samples.PropertyFromIndex(best->prop); | 416 | 2.98k | pixel_type dequant = | 417 | 2.98k | tree_samples.UnquantizeProperty(best->prop, best->val); | 418 | | // Split node and try to split children. 
| 419 | 2.98k | MakeSplitNode(pos, p, dequant, best->lpred, 0, best->rpred, 0, tree); | 420 | | // "Sort" according to winning property | 421 | 2.98k | SplitTreeSamples(tree_samples, begin, best->pos, end, best->prop, | 422 | 2.98k | best->val); | 423 | 2.98k | if (p >= kNumStaticProperties) { | 424 | 2.57k | used_properties |= 1 << best->prop; | 425 | 2.57k | } | 426 | 2.98k | auto new_sp_range = static_prop_range; | 427 | 2.98k | if (p < kNumStaticProperties) { | 428 | 404 | JXL_DASSERT(static_cast<uint32_t>(dequant + 1) <= new_sp_range[p][1]); | 429 | 404 | new_sp_range[p][1] = dequant + 1; | 430 | 404 | JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]); | 431 | 404 | } | 432 | 2.98k | nodes.push_back(NodeInfo{(*tree)[pos].rchild, begin, best->pos, | 433 | 2.98k | used_properties, new_sp_range}); | 434 | 2.98k | new_sp_range = static_prop_range; | 435 | 2.98k | if (p < kNumStaticProperties) { | 436 | 404 | JXL_DASSERT(new_sp_range[p][0] <= static_cast<uint32_t>(dequant + 1)); | 437 | 404 | new_sp_range[p][0] = dequant + 1; | 438 | 404 | JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]); | 439 | 404 | } | 440 | 2.98k | nodes.push_back(NodeInfo{(*tree)[pos].lchild, best->pos, end, | 441 | 2.98k | used_properties, new_sp_range}); | 442 | 2.98k | } | 443 | 6.07k | } | 444 | 111 | } |
Unexecuted instantiation: jxl::N_SSE2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*) |
445 | | |
446 | | // NOLINTNEXTLINE(google-readability-namespace-comments) |
447 | | } // namespace HWY_NAMESPACE |
448 | | } // namespace jxl |
449 | | HWY_AFTER_NAMESPACE(); |
450 | | |
451 | | #if HWY_ONCE |
452 | | namespace jxl { |
453 | | |
454 | | HWY_EXPORT(FindBestSplit); // Local function. |
455 | | |
456 | | Status ComputeBestTree(TreeSamples &tree_samples, float threshold, |
457 | | const std::vector<ModularMultiplierInfo> &mul_info, |
458 | | StaticPropRange static_prop_range, |
459 | 111 | float fast_decode_multiplier, Tree *tree) { |
460 | | // TODO(veluca): take into account that different contexts can have different |
461 | | // uint configs. |
462 | | // |
463 | | // Initialize tree. |
464 | 111 | tree->emplace_back(); |
465 | 111 | tree->back().property = -1; |
466 | 111 | tree->back().predictor = tree_samples.PredictorFromIndex(0); |
467 | 111 | tree->back().predictor_offset = 0; |
468 | 111 | tree->back().multiplier = 1; |
469 | 111 | JXL_ENSURE(tree_samples.NumProperties() < 64); |
470 | | |
471 | 111 | JXL_ENSURE(tree_samples.NumDistinctSamples() <= |
472 | 111 | std::numeric_limits<uint32_t>::max()); |
473 | 111 | HWY_DYNAMIC_DISPATCH(FindBestSplit) |
474 | 111 | (tree_samples, threshold, mul_info, static_prop_range, fast_decode_multiplier, |
475 | 111 | tree); |
476 | 111 | return true; |
477 | 111 | } |
478 | | |
479 | | #if JXL_CXX_LANG < JXL_CXX_17 |
480 | | constexpr int32_t TreeSamples::kPropertyRange; |
481 | | constexpr uint32_t TreeSamples::kDedupEntryUnused; |
482 | | #endif |
483 | | |
484 | | Status TreeSamples::SetPredictor(Predictor predictor, |
485 | 111 | ModularOptions::TreeMode wp_tree_mode) { |
486 | 111 | if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) { |
487 | 0 | predictors = {Predictor::Weighted}; |
488 | 0 | residuals.resize(1); |
489 | 0 | return true; |
490 | 0 | } |
491 | 111 | if (wp_tree_mode == ModularOptions::TreeMode::kNoWP && |
492 | 111 | predictor == Predictor::Weighted) { |
493 | 0 | return JXL_FAILURE("Invalid predictor settings"); |
494 | 0 | } |
495 | 111 | if (predictor == Predictor::Variable) { |
496 | 0 | for (size_t i = 0; i < kNumModularPredictors; i++) { |
497 | 0 | predictors.push_back(static_cast<Predictor>(i)); |
498 | 0 | } |
499 | 0 | std::swap(predictors[0], predictors[static_cast<int>(Predictor::Weighted)]); |
500 | 0 | std::swap(predictors[1], predictors[static_cast<int>(Predictor::Gradient)]); |
501 | 111 | } else if (predictor == Predictor::Best) { |
502 | 0 | predictors = {Predictor::Weighted, Predictor::Gradient}; |
503 | 111 | } else { |
504 | 111 | predictors = {predictor}; |
505 | 111 | } |
506 | 111 | if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) { |
507 | 0 | auto wp_it = |
508 | 0 | std::find(predictors.begin(), predictors.end(), Predictor::Weighted); |
509 | 0 | if (wp_it != predictors.end()) { |
510 | 0 | predictors.erase(wp_it); |
511 | 0 | } |
512 | 0 | } |
513 | 111 | residuals.resize(predictors.size()); |
514 | 111 | return true; |
515 | 111 | } |
516 | | |
517 | | Status TreeSamples::SetProperties(const std::vector<uint32_t> &properties, |
518 | 111 | ModularOptions::TreeMode wp_tree_mode) { |
519 | 111 | props_to_use = properties; |
520 | 111 | if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) { |
521 | 0 | props_to_use = {static_cast<uint32_t>(kWPProp)}; |
522 | 0 | } |
523 | 111 | if (wp_tree_mode == ModularOptions::TreeMode::kGradientOnly) { |
524 | 0 | props_to_use = {static_cast<uint32_t>(kGradientProp)}; |
525 | 0 | } |
526 | 111 | if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) { |
527 | 0 | auto it = std::find(props_to_use.begin(), props_to_use.end(), kWPProp); |
528 | 0 | if (it != props_to_use.end()) { |
529 | 0 | props_to_use.erase(it); |
530 | 0 | } |
531 | 0 | } |
532 | 111 | if (props_to_use.empty()) { |
533 | 0 | return JXL_FAILURE("Invalid property set configuration"); |
534 | 0 | } |
535 | 111 | props.resize(props_to_use.size()); |
536 | 111 | return true; |
537 | 111 | } |
538 | | |
539 | 306 | void TreeSamples::InitTable(size_t log_size) { |
540 | 306 | size_t size = 1ULL << log_size; |
541 | 306 | if (dedup_table_.size() == size) return; |
542 | 167 | dedup_table_.resize(size, kDedupEntryUnused); |
543 | 47.1k | for (size_t i = 0; i < NumDistinctSamples(); i++) { |
544 | 46.9k | if (sample_counts[i] != std::numeric_limits<uint16_t>::max()) { |
545 | 46.9k | AddToTable(i); |
546 | 46.9k | } |
547 | 46.9k | } |
548 | 167 | } |
549 | | |
550 | 806k | bool TreeSamples::AddToTableAndMerge(size_t a) { |
551 | 806k | size_t pos1 = Hash1(a); |
552 | 806k | size_t pos2 = Hash2(a); |
553 | 806k | if (dedup_table_[pos1] != kDedupEntryUnused && |
554 | 806k | IsSameSample(a, dedup_table_[pos1])) { |
555 | 472k | JXL_DASSERT(sample_counts[a] == 1); |
556 | 472k | sample_counts[dedup_table_[pos1]]++; |
557 | | // Remove from hash table samples that are saturated. |
558 | 472k | if (sample_counts[dedup_table_[pos1]] == |
559 | 472k | std::numeric_limits<uint16_t>::max()) { |
560 | 0 | dedup_table_[pos1] = kDedupEntryUnused; |
561 | 0 | } |
562 | 472k | return true; |
563 | 472k | } |
564 | 333k | if (dedup_table_[pos2] != kDedupEntryUnused && |
565 | 333k | IsSameSample(a, dedup_table_[pos2])) { |
566 | 148k | JXL_DASSERT(sample_counts[a] == 1); |
567 | 148k | sample_counts[dedup_table_[pos2]]++; |
568 | | // Remove from hash table samples that are saturated. |
569 | 148k | if (sample_counts[dedup_table_[pos2]] == |
570 | 148k | std::numeric_limits<uint16_t>::max()) { |
571 | 0 | dedup_table_[pos2] = kDedupEntryUnused; |
572 | 0 | } |
573 | 148k | return true; |
574 | 148k | } |
575 | 185k | AddToTable(a); |
576 | 185k | return false; |
577 | 333k | } |
578 | | |
579 | 232k | void TreeSamples::AddToTable(size_t a) { |
580 | 232k | size_t pos1 = Hash1(a); |
581 | 232k | size_t pos2 = Hash2(a); |
582 | 232k | if (dedup_table_[pos1] == kDedupEntryUnused) { |
583 | 121k | dedup_table_[pos1] = a; |
584 | 121k | } else if (dedup_table_[pos2] == kDedupEntryUnused) { |
585 | 68.9k | dedup_table_[pos2] = a; |
586 | 68.9k | } |
587 | 232k | } |
588 | | |
589 | 306 | void TreeSamples::PrepareForSamples(size_t extra_num_samples) { |
590 | 306 | for (auto &res : residuals) { |
591 | 306 | res.reserve(res.size() + extra_num_samples); |
592 | 306 | } |
593 | 2.14k | for (auto &p : props) { |
594 | 2.14k | p.reserve(p.size() + extra_num_samples); |
595 | 2.14k | } |
596 | 306 | size_t total_num_samples = extra_num_samples + sample_counts.size(); |
597 | 306 | size_t next_size = CeilLog2Nonzero(total_num_samples * 3 / 2); |
598 | 306 | InitTable(next_size); |
599 | 306 | } |
600 | | |
601 | 1.03M | size_t TreeSamples::Hash1(size_t a) const { |
602 | 1.03M | constexpr uint64_t constant = 0x1e35a7bd; |
603 | 1.03M | uint64_t h = constant; |
604 | 1.03M | for (const auto &r : residuals) { |
605 | 1.03M | h = h * constant + r[a].tok; |
606 | 1.03M | h = h * constant + r[a].nbits; |
607 | 1.03M | } |
608 | 7.26M | for (const auto &p : props) { |
609 | 7.26M | h = h * constant + p[a]; |
610 | 7.26M | } |
611 | 1.03M | return (h >> 16) & (dedup_table_.size() - 1); |
612 | 1.03M | } |
613 | 1.03M | size_t TreeSamples::Hash2(size_t a) const { |
614 | 1.03M | constexpr uint64_t constant = 0x1e35a7bd1e35a7bd; |
615 | 1.03M | uint64_t h = constant; |
616 | 7.26M | for (const auto &p : props) { |
617 | 7.26M | h = h * constant ^ p[a]; |
618 | 7.26M | } |
619 | 1.03M | for (const auto &r : residuals) { |
620 | 1.03M | h = h * constant ^ r[a].tok; |
621 | 1.03M | h = h * constant ^ r[a].nbits; |
622 | 1.03M | } |
623 | 1.03M | return (h >> 16) & (dedup_table_.size() - 1); |
624 | 1.03M | } |
625 | | |
626 | 895k | bool TreeSamples::IsSameSample(size_t a, size_t b) const { |
627 | 895k | bool ret = true; |
628 | 895k | for (const auto &r : residuals) { |
629 | 895k | if (r[a].tok != r[b].tok) { |
630 | 102k | ret = false; |
631 | 102k | } |
632 | 895k | if (r[a].nbits != r[b].nbits) { |
633 | 81.4k | ret = false; |
634 | 81.4k | } |
635 | 895k | } |
636 | 6.26M | for (const auto &p : props) { |
637 | 6.26M | if (p[a] != p[b]) { |
638 | 738k | ret = false; |
639 | 738k | } |
640 | 6.26M | } |
641 | 895k | return ret; |
642 | 895k | } |
643 | | |
644 | | void TreeSamples::AddSample(pixel_type_w pixel, const Properties &properties, |
645 | 806k | const pixel_type_w *predictions) { |
646 | 1.61M | for (size_t i = 0; i < predictors.size(); i++) { |
647 | 806k | pixel_type v = pixel - predictions[static_cast<int>(predictors[i])]; |
648 | 806k | uint32_t tok, nbits, bits; |
649 | 806k | HybridUintConfig(4, 1, 2).Encode(PackSigned(v), &tok, &nbits, &bits); |
650 | 806k | JXL_DASSERT(tok < 256); |
651 | 806k | JXL_DASSERT(nbits < 256); |
652 | 806k | residuals[i].emplace_back( |
653 | 806k | ResidualToken{static_cast<uint8_t>(tok), static_cast<uint8_t>(nbits)}); |
654 | 806k | } |
655 | 6.44M | for (size_t i = 0; i < props_to_use.size(); i++) { |
656 | 5.64M | props[i].push_back(QuantizeProperty(i, properties[props_to_use[i]])); |
657 | 5.64M | } |
658 | 806k | sample_counts.push_back(1); |
659 | 806k | num_samples++; |
660 | 806k | if (AddToTableAndMerge(sample_counts.size() - 1)) { |
661 | 620k | for (auto &r : residuals) r.pop_back(); |
662 | 4.34M | for (auto &p : props) p.pop_back(); |
663 | 620k | sample_counts.pop_back(); |
664 | 620k | } |
665 | 806k | } |
666 | | |
667 | 195k | void TreeSamples::Swap(size_t a, size_t b) { |
668 | 195k | if (a == b) return; |
669 | 195k | for (auto &r : residuals) { |
670 | 195k | std::swap(r[a], r[b]); |
671 | 195k | } |
672 | 1.36M | for (auto &p : props) { |
673 | 1.36M | std::swap(p[a], p[b]); |
674 | 1.36M | } |
675 | 195k | std::swap(sample_counts[a], sample_counts[b]); |
676 | 195k | } |
677 | | |
678 | | namespace { |
679 | | std::vector<int32_t> QuantizeHistogram(const std::vector<uint32_t> &histogram, |
680 | 321 | size_t num_chunks) { |
681 | 321 | if (histogram.empty() || num_chunks == 0) return {}; |
682 | 321 | uint64_t sum = std::accumulate(histogram.begin(), histogram.end(), 0LU); |
683 | 321 | if (sum == 0) return {}; |
684 | | // TODO(veluca): selecting distinct quantiles is likely not the best |
685 | | // way to go about this. |
686 | 319 | std::vector<int32_t> thresholds; |
687 | 319 | uint64_t cumsum = 0; |
688 | 319 | uint64_t threshold = 1; |
689 | 116k | for (size_t i = 0; i < histogram.size(); i++) { |
690 | 116k | cumsum += histogram[i]; |
691 | 116k | if (cumsum * num_chunks >= threshold * sum) { |
692 | 1.72k | thresholds.push_back(i); |
693 | 17.0k | while (cumsum * num_chunks >= threshold * sum) threshold++; |
694 | 1.72k | } |
695 | 116k | } |
696 | 319 | JXL_DASSERT(thresholds.size() <= num_chunks); |
697 | | // last value collects all histogram and is not really a threshold |
698 | 319 | thresholds.pop_back(); |
699 | 319 | return thresholds; |
700 | 321 | } |
701 | | |
702 | | std::vector<int32_t> QuantizeSamples(const std::vector<int32_t> &samples, |
703 | 123 | size_t num_chunks) { |
704 | 123 | if (samples.empty()) return {}; |
705 | 113 | int min = *std::min_element(samples.begin(), samples.end()); |
706 | 113 | constexpr int kRange = 512; |
707 | 113 | min = jxl::Clamp1(min, -kRange, kRange); |
708 | 113 | std::vector<uint32_t> counts(2 * kRange + 1); |
709 | 76.2k | for (int s : samples) { |
710 | 76.2k | uint32_t sample_offset = jxl::Clamp1(s, -kRange, kRange) - min; |
711 | 76.2k | counts[sample_offset]++; |
712 | 76.2k | } |
713 | 113 | std::vector<int32_t> thresholds = QuantizeHistogram(counts, num_chunks); |
714 | 1.21k | for (auto &v : thresholds) v += min; |
715 | 113 | return thresholds; |
716 | 123 | } |
717 | | } // namespace |
718 | | |
// Precomputes, for every property in props_to_use, the threshold list
// (compact_properties) and a dense lookup table (property_mapping) mapping a
// clamped property value to its quantized bucket index. Thresholds come from
// quantiles of the supplied per-group/per-channel pixel counts and the
// pixel/diff sample pools, except where multiplier ranges force specific
// channel/group split points.
void TreeSamples::PreQuantizeProperties(
    const StaticPropRange &range,
    const std::vector<ModularMultiplierInfo> &multiplier_info,
    const std::vector<uint32_t> &group_pixel_count,
    const std::vector<uint32_t> &channel_pixel_count,
    std::vector<pixel_type> &pixel_samples,
    std::vector<pixel_type> &diff_samples, size_t max_property_values) {
  // If we have forced splits because of multipliers, choose channel and group
  // thresholds accordingly.
  std::vector<int32_t> group_multiplier_thresholds;
  std::vector<int32_t> channel_multiplier_thresholds;
  for (const auto &v : multiplier_info) {
    // Splits are of the form "prop > t", so a range boundary at b forces a
    // threshold at b - 1.
    if (v.range[0][0] != range[0][0]) {
      channel_multiplier_thresholds.push_back(v.range[0][0] - 1);
    }
    if (v.range[0][1] != range[0][1]) {
      channel_multiplier_thresholds.push_back(v.range[0][1] - 1);
    }
    if (v.range[1][0] != range[1][0]) {
      group_multiplier_thresholds.push_back(v.range[1][0] - 1);
    }
    if (v.range[1][1] != range[1][1]) {
      group_multiplier_thresholds.push_back(v.range[1][1] - 1);
    }
  }
  // Sort and deduplicate both forced-threshold lists.
  std::sort(channel_multiplier_thresholds.begin(),
            channel_multiplier_thresholds.end());
  channel_multiplier_thresholds.resize(
      std::unique(channel_multiplier_thresholds.begin(),
                  channel_multiplier_thresholds.end()) -
      channel_multiplier_thresholds.begin());
  std::sort(group_multiplier_thresholds.begin(),
            group_multiplier_thresholds.end());
  group_multiplier_thresholds.resize(
      std::unique(group_multiplier_thresholds.begin(),
                  group_multiplier_thresholds.end()) -
      group_multiplier_thresholds.begin());

  compact_properties.resize(props_to_use.size());
  // Forced multiplier thresholds take precedence over quantile-derived ones.
  auto quantize_channel = [&]() {
    if (!channel_multiplier_thresholds.empty()) {
      return channel_multiplier_thresholds;
    }
    return QuantizeHistogram(channel_pixel_count, max_property_values);
  };
  auto quantize_group_id = [&]() {
    if (!group_multiplier_thresholds.empty()) {
      return group_multiplier_thresholds;
    }
    return QuantizeHistogram(group_pixel_count, max_property_values);
  };
  // Coordinates get uniform thresholds over [0, 256).
  auto quantize_coordinate = [&]() {
    std::vector<int32_t> quantized;
    quantized.reserve(max_property_values - 1);
    for (size_t i = 0; i + 1 < max_property_values; i++) {
      quantized.push_back((i + 1) * 256 / max_property_values - 1);
    }
    return quantized;
  };
  // The lambdas below cache their result so multiple properties sharing the
  // same kind of thresholds only compute them once.
  std::vector<int32_t> abs_pixel_thresholds;
  std::vector<int32_t> pixel_thresholds;
  auto quantize_pixel_property = [&]() {
    if (pixel_thresholds.empty()) {
      pixel_thresholds = QuantizeSamples(pixel_samples, max_property_values);
    }
    return pixel_thresholds;
  };
  auto quantize_abs_pixel_property = [&]() {
    if (abs_pixel_thresholds.empty()) {
      // NOTE: overwrites pixel_samples with absolute values, so the non-abs
      // thresholds must be (and are) computed first.
      quantize_pixel_property();  // Compute the non-abs thresholds.
      for (auto &v : pixel_samples) v = std::abs(v);
      abs_pixel_thresholds =
          QuantizeSamples(pixel_samples, max_property_values);
    }
    return abs_pixel_thresholds;
  };
  std::vector<int32_t> abs_diff_thresholds;
  std::vector<int32_t> diff_thresholds;
  auto quantize_diff_property = [&]() {
    if (diff_thresholds.empty()) {
      diff_thresholds = QuantizeSamples(diff_samples, max_property_values);
    }
    return diff_thresholds;
  };
  auto quantize_abs_diff_property = [&]() {
    if (abs_diff_thresholds.empty()) {
      // Same ordering caveat as quantize_abs_pixel_property above.
      quantize_diff_property();  // Compute the non-abs thresholds.
      for (auto &v : diff_samples) v = std::abs(v);
      abs_diff_thresholds = QuantizeSamples(diff_samples, max_property_values);
    }
    return abs_diff_thresholds;
  };
  // Fixed, roughly logarithmic threshold grids for the WP property; the grid
  // density grows with the allowed number of property values.
  auto quantize_wp = [&]() {
    if (max_property_values < 32) {
      return std::vector<int32_t>{-127, -63, -31, -15, -7, -3, -1, 0,
                                  1,    3,   7,   15,  31, 63, 127};
    }
    if (max_property_values < 64) {
      return std::vector<int32_t>{-255, -191, -127, -95, -63, -47, -31, -23,
                                  -15,  -11,  -7,   -5,  -3,  -1,  0,   1,
                                  3,    5,    7,    11,  15,  23,  31,  47,
                                  63,   95,   127,  191, 255};
    }
    return std::vector<int32_t>{
        -255, -223, -191, -159, -127, -111, -95, -79, -63, -55, -47,
        -39,  -31,  -27,  -23,  -19,  -15,  -13, -11, -9,  -7,  -6,
        -5,   -4,   -3,   -2,   -1,   0,    1,   2,   3,   4,   5,
        6,    7,    9,    11,   13,   15,   19,  23,  27,  31,  39,
        47,   55,   63,   79,   95,   111,  127, 159, 191, 223, 255};
  };

  property_mapping.resize(props_to_use.size());
  // Dispatch each selected property to its threshold source. The numeric
  // property indices here (0=channel, 1=group, 2/3=coordinates, 4..8 and the
  // reference-property residues) follow the modular property layout —
  // presumably matching kNumNonrefProperties-based ordering; verify against
  // the property definitions if indices change.
  for (size_t i = 0; i < props_to_use.size(); i++) {
    if (props_to_use[i] == 0) {
      compact_properties[i] = quantize_channel();
    } else if (props_to_use[i] == 1) {
      compact_properties[i] = quantize_group_id();
    } else if (props_to_use[i] == 2 || props_to_use[i] == 3) {
      compact_properties[i] = quantize_coordinate();
    } else if (props_to_use[i] == 6 || props_to_use[i] == 7 ||
               props_to_use[i] == 8 ||
               (props_to_use[i] >= kNumNonrefProperties &&
                (props_to_use[i] - kNumNonrefProperties) % 4 == 1)) {
      compact_properties[i] = quantize_pixel_property();
    } else if (props_to_use[i] == 4 || props_to_use[i] == 5 ||
               (props_to_use[i] >= kNumNonrefProperties &&
                (props_to_use[i] - kNumNonrefProperties) % 4 == 0)) {
      compact_properties[i] = quantize_abs_pixel_property();
    } else if (props_to_use[i] >= kNumNonrefProperties &&
               (props_to_use[i] - kNumNonrefProperties) % 4 == 2) {
      compact_properties[i] = quantize_abs_diff_property();
    } else if (props_to_use[i] == kWPProp) {
      compact_properties[i] = quantize_wp();
    } else {
      compact_properties[i] = quantize_diff_property();
    }
    // Build the dense value -> bucket table over [-kPropertyRange,
    // kPropertyRange].
    property_mapping[i].resize(kPropertyRange * 2 + 1);
    size_t mapped = 0;
    for (size_t j = 0; j < property_mapping[i].size(); j++) {
      while (mapped < compact_properties[i].size() &&
             static_cast<int>(j) - kPropertyRange >
                 compact_properties[i][mapped]) {
        mapped++;
      }
      JXL_DASSERT(mapped < 256);
      // property_mapping[i] of a value V is `mapped` if
      // compact_properties[i][mapped] <= j and
      // compact_properties[i][mapped-1] > j
      // This is because the decision node in the tree splits on (property) > j,
      // hence everything that is not > of a threshold should be clustered
      // together.
      property_mapping[i][j] = mapped;
    }
  }
}
874 | | |
// Gathers a random subset of pixel values and horizontal differences from the
// channels of `image`, used later to choose quantization thresholds for the
// MA-tree properties. Also accumulates per-group and per-channel pixel counts.
//
// Outputs (all appended/accumulated in place):
//  - group_pixel_count[group_id]: total pixels considered in this group.
//  - channel_pixel_count[i]: total pixels considered in channel i.
//  - pixel_samples: sampled raw pixel values.
//  - diff_samples: sampled left-neighbor differences (row[x] - row[x-1]).
void CollectPixelSamples(const Image &image, const ModularOptions &options,
                         uint32_t group_id,
                         std::vector<uint32_t> &group_pixel_count,
                         std::vector<uint32_t> &channel_pixel_count,
                         std::vector<pixel_type> &pixel_samples,
                         std::vector<pixel_type> &diff_samples) {
  // nb_repeats == 0 means "no MA tree learning": nothing to sample.
  if (options.nb_repeats == 0) return;
  // Grow the count vectors on demand so callers can pass shared accumulators.
  if (group_pixel_count.size() <= group_id) {
    group_pixel_count.resize(group_id + 1);
  }
  if (channel_pixel_count.size() < image.channel.size()) {
    channel_pixel_count.resize(image.channel.size());
  }
  // Seeded by group_id: sampling is deterministic per group.
  Rng rng(group_id);
  // Sample 10% of the final number of samples for property quantization.
  float fraction = std::min(options.nb_repeats * 0.1, 0.99);
  Rng::GeometricDistribution dist = Rng::MakeGeometric(fraction);
  size_t total_pixels = 0;
  std::vector<size_t> channel_ids;
  // Select the channels to sample from.
  for (size_t i = 0; i < image.channel.size(); i++) {
    // Non-meta channels larger than max_chan_size stop the scan entirely
    // (channels are ordered, so later ones would be at least as large).
    if (i >= image.nb_meta_channels &&
        (image.channel[i].w > options.max_chan_size ||
         image.channel[i].h > options.max_chan_size)) {
      break;
    }
    // Width <= 1 is skipped so that every sampled pixel has a horizontal
    // neighbor for diff_samples below.
    if (image.channel[i].w <= 1 || image.channel[i].h == 0) {
      continue;  // skip empty or width-1 channels.
    }
    channel_ids.push_back(i);
    group_pixel_count[group_id] += image.channel[i].w * image.channel[i].h;
    channel_pixel_count[i] += image.channel[i].w * image.channel[i].h;
    total_pixels += image.channel[i].w * image.channel[i].h;
  }
  if (channel_ids.empty()) return;
  pixel_samples.reserve(pixel_samples.size() + fraction * total_pixels);
  diff_samples.reserve(diff_samples.size() + fraction * total_pixels);
  // (i, y, x) is a cursor over the concatenation of all selected channels;
  // `advance` moves it forward by `amount` pixels in raster order, carrying
  // across row and channel boundaries.
  size_t i = 0;
  size_t y = 0;
  size_t x = 0;
  auto advance = [&](size_t amount) {
    x += amount;
    // Detect row overflow (rare).
    while (x >= image.channel[channel_ids[i]].w) {
      x -= image.channel[channel_ids[i]].w;
      y++;
      // Detect end-of-channel (even rarer).
      if (y == image.channel[channel_ids[i]].h) {
        i++;
        y = 0;
        if (i >= channel_ids.size()) {
          return;
        }
      }
    }
  };
  // Geometric gaps between samples give (approximately) a `fraction`
  // Bernoulli subsample of all pixels; +1 in the loop step ensures progress.
  advance(rng.Geometric(dist));
  for (; i < channel_ids.size(); advance(rng.Geometric(dist) + 1)) {
    const pixel_type *row = image.channel[channel_ids[i]].Row(y);
    pixel_samples.push_back(row[x]);
    // At x == 0 use the right neighbor instead of the (absent) left one;
    // channels of width <= 1 were excluded above, so row[1] is valid.
    size_t xp = x == 0 ? 1 : x - 1;
    diff_samples.push_back(static_cast<int64_t>(row[x]) - row[xp]);
  }
}
938 | | |
939 | | // TODO(veluca): very simple encoding scheme. This should be improved. |
940 | | Status TokenizeTree(const Tree &tree, std::vector<Token> *tokens, |
941 | 283 | Tree *decoder_tree) { |
942 | 283 | JXL_ENSURE(tree.size() <= kMaxTreeSize); |
943 | 283 | std::queue<int> q; |
944 | 283 | q.push(0); |
945 | 283 | size_t leaf_id = 0; |
946 | 283 | decoder_tree->clear(); |
947 | 14.9k | while (!q.empty()) { |
948 | 14.6k | int cur = q.front(); |
949 | 14.6k | q.pop(); |
950 | 14.6k | JXL_ENSURE(tree[cur].property >= -1); |
951 | 14.6k | tokens->emplace_back(kPropertyContext, tree[cur].property + 1); |
952 | 14.6k | if (tree[cur].property == -1) { |
953 | 7.47k | tokens->emplace_back(kPredictorContext, |
954 | 7.47k | static_cast<int>(tree[cur].predictor)); |
955 | 7.47k | tokens->emplace_back(kOffsetContext, |
956 | 7.47k | PackSigned(tree[cur].predictor_offset)); |
957 | 7.47k | uint32_t mul_log = Num0BitsBelowLS1Bit_Nonzero(tree[cur].multiplier); |
958 | 7.47k | uint32_t mul_bits = (tree[cur].multiplier >> mul_log) - 1; |
959 | 7.47k | tokens->emplace_back(kMultiplierLogContext, mul_log); |
960 | 7.47k | tokens->emplace_back(kMultiplierBitsContext, mul_bits); |
961 | 7.47k | JXL_ENSURE(tree[cur].predictor < Predictor::Best); |
962 | 7.47k | decoder_tree->emplace_back(-1, 0, leaf_id++, 0, tree[cur].predictor, |
963 | 7.47k | tree[cur].predictor_offset, |
964 | 7.47k | tree[cur].multiplier); |
965 | 7.47k | continue; |
966 | 7.47k | } |
967 | 7.19k | decoder_tree->emplace_back(tree[cur].property, tree[cur].splitval, |
968 | 7.19k | decoder_tree->size() + q.size() + 1, |
969 | 7.19k | decoder_tree->size() + q.size() + 2, |
970 | 7.19k | Predictor::Zero, 0, 1); |
971 | 7.19k | q.push(tree[cur].lchild); |
972 | 7.19k | q.push(tree[cur].rchild); |
973 | 7.19k | tokens->emplace_back(kSplitValContext, PackSigned(tree[cur].splitval)); |
974 | 7.19k | } |
975 | 283 | return true; |
976 | 283 | } |
977 | | |
978 | | } // namespace jxl |
979 | | #endif // HWY_ONCE |