/src/libjxl/lib/jxl/modular/encoding/enc_ma.cc
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright (c) the JPEG XL Project Authors. All rights reserved. |
2 | | // |
3 | | // Use of this source code is governed by a BSD-style |
4 | | // license that can be found in the LICENSE file. |
5 | | |
6 | | #include "lib/jxl/modular/encoding/enc_ma.h" |
7 | | |
8 | | #include <algorithm> |
9 | | #include <cstdint> |
10 | | #include <cstdlib> |
11 | | #include <cstring> |
12 | | #include <limits> |
13 | | #include <numeric> |
14 | | #include <queue> |
15 | | #include <vector> |
16 | | |
17 | | #include "lib/jxl/ans_params.h" |
18 | | #include "lib/jxl/base/bits.h" |
19 | | #include "lib/jxl/base/common.h" |
20 | | #include "lib/jxl/base/compiler_specific.h" |
21 | | #include "lib/jxl/base/status.h" |
22 | | #include "lib/jxl/dec_ans.h" |
23 | | #include "lib/jxl/modular/encoding/dec_ma.h" |
24 | | #include "lib/jxl/modular/encoding/ma_common.h" |
25 | | #include "lib/jxl/modular/modular_image.h" |
26 | | |
27 | | #undef HWY_TARGET_INCLUDE |
28 | | #define HWY_TARGET_INCLUDE "lib/jxl/modular/encoding/enc_ma.cc" |
29 | | #include <hwy/foreach_target.h> |
30 | | #include <hwy/highway.h> |
31 | | |
32 | | #include "lib/jxl/base/fast_math-inl.h" |
33 | | #include "lib/jxl/base/random.h" |
34 | | #include "lib/jxl/enc_ans.h" |
35 | | #include "lib/jxl/modular/encoding/context_predict.h" |
36 | | #include "lib/jxl/modular/options.h" |
37 | | #include "lib/jxl/pack_signed.h" |
38 | | HWY_BEFORE_NAMESPACE(); |
39 | | namespace jxl { |
40 | | namespace HWY_NAMESPACE { |
41 | | |
42 | | // These templates are not found via ADL. |
43 | | using hwy::HWY_NAMESPACE::Eq; |
44 | | using hwy::HWY_NAMESPACE::IfThenElse; |
45 | | using hwy::HWY_NAMESPACE::Lt; |
46 | | using hwy::HWY_NAMESPACE::Max; |
47 | | |
// Full-width SIMD descriptor tags for the current Highway target (float and
// int32 lanes); shared by the helpers below.
const HWY_FULL(float) df;
const HWY_FULL(int32_t) di;
50 | 5.05k | size_t Padded(size_t x) { return RoundUpTo(x, Lanes(df)); } Unexecuted instantiation: jxl::N_SSE4::Padded(unsigned long) jxl::N_AVX2::Padded(unsigned long) Line | Count | Source | 50 | 5.05k | size_t Padded(size_t x) { return RoundUpTo(x, Lanes(df)); } |
Unexecuted instantiation: jxl::N_SSE2::Padded(unsigned long) |
51 | | |
// Compute entropy of the histogram, taking into account the minimum probability
// for symbols with non-zero counts.
// Returns the estimated number of bits needed to ANS-encode `counts[0 ..
// num_symbols)`, where each symbol's probability is clamped from below to
// 1/ANS_TAB_SIZE. NOTE(review): `num_symbols` must be a multiple of Lanes(df)
// and the buffer padded accordingly (callers obtain it via Padded()); the
// unaligned loads below read full vectors.
float EstimateBits(const int32_t *counts, size_t num_symbols) {
  int32_t total = std::accumulate(counts, counts + num_symbols, 0);
  const auto zero = Zero(df);
  const auto minprob = Set(df, 1.0f / ANS_TAB_SIZE);
  const auto inv_total = Set(df, 1.0f / total);
  auto bits_lanes = Zero(df);
  auto total_v = Set(di, total);
  for (size_t i = 0; i < num_symbols; i += Lanes(df)) {
    const auto counts_iv = LoadU(di, &counts[i]);
    const auto counts_fv = ConvertTo(df, counts_iv);
    // Per-lane probability, clamped so rare symbols are not priced below the
    // coder's minimum representable probability.
    const auto probs = Mul(counts_fv, inv_total);
    const auto mprobs = Max(probs, minprob);
    // A symbol holding the entire mass (count == total) costs exactly 0 bits;
    // special-cased, presumably because FastLog2f is approximate near 1.0 —
    // the select is done on the integer bit pattern via BitCast.
    const auto nbps = IfThenElse(Eq(counts_iv, total_v), BitCast(di, zero),
                                 BitCast(di, FastLog2f(df, mprobs)));
    // Accumulate -count * log2(prob) per lane.
    bits_lanes = Sub(bits_lanes, Mul(counts_fv, BitCast(df, nbps)));
  }
  return GetLane(SumOfLanes(df, bits_lanes));
}
72 | | |
73 | | void MakeSplitNode(size_t pos, int property, int splitval, Predictor lpred, |
74 | 2.49k | int64_t loff, Predictor rpred, int64_t roff, Tree *tree) { |
75 | | // Note that the tree splits on *strictly greater*. |
76 | 2.49k | (*tree)[pos].lchild = tree->size(); |
77 | 2.49k | (*tree)[pos].rchild = tree->size() + 1; |
78 | 2.49k | (*tree)[pos].splitval = splitval; |
79 | 2.49k | (*tree)[pos].property = property; |
80 | 2.49k | tree->emplace_back(); |
81 | 2.49k | tree->back().property = -1; |
82 | 2.49k | tree->back().predictor = rpred; |
83 | 2.49k | tree->back().predictor_offset = roff; |
84 | 2.49k | tree->back().multiplier = 1; |
85 | 2.49k | tree->emplace_back(); |
86 | 2.49k | tree->back().property = -1; |
87 | 2.49k | tree->back().predictor = lpred; |
88 | 2.49k | tree->back().predictor_offset = loff; |
89 | 2.49k | tree->back().multiplier = 1; |
90 | 2.49k | } Unexecuted instantiation: jxl::N_SSE4::MakeSplitNode(unsigned long, int, int, jxl::Predictor, long, jxl::Predictor, long, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*) jxl::N_AVX2::MakeSplitNode(unsigned long, int, int, jxl::Predictor, long, jxl::Predictor, long, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*) Line | Count | Source | 74 | 2.49k | int64_t loff, Predictor rpred, int64_t roff, Tree *tree) { | 75 | | // Note that the tree splits on *strictly greater*. | 76 | 2.49k | (*tree)[pos].lchild = tree->size(); | 77 | 2.49k | (*tree)[pos].rchild = tree->size() + 1; | 78 | 2.49k | (*tree)[pos].splitval = splitval; | 79 | 2.49k | (*tree)[pos].property = property; | 80 | 2.49k | tree->emplace_back(); | 81 | 2.49k | tree->back().property = -1; | 82 | 2.49k | tree->back().predictor = rpred; | 83 | 2.49k | tree->back().predictor_offset = roff; | 84 | 2.49k | tree->back().multiplier = 1; | 85 | 2.49k | tree->emplace_back(); | 86 | 2.49k | tree->back().property = -1; | 87 | 2.49k | tree->back().predictor = lpred; | 88 | 2.49k | tree->back().predictor_offset = loff; | 89 | 2.49k | tree->back().multiplier = 1; | 90 | 2.49k | } |
Unexecuted instantiation: jxl::N_SSE2::MakeSplitNode(unsigned long, int, int, jxl::Predictor, long, jxl::Predictor, long, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*) |
91 | | |
92 | | enum class IntersectionType { kNone, kPartial, kInside }; |
93 | | IntersectionType BoxIntersects(StaticPropRange needle, StaticPropRange haystack, |
94 | 5.03k | uint32_t &partial_axis, uint32_t &partial_val) { |
95 | 5.03k | bool partial = false; |
96 | 15.0k | for (size_t i = 0; i < kNumStaticProperties; i++) { |
97 | 10.0k | if (haystack[i][0] >= needle[i][1]) { |
98 | 0 | return IntersectionType::kNone; |
99 | 0 | } |
100 | 10.0k | if (haystack[i][1] <= needle[i][0]) { |
101 | 0 | return IntersectionType::kNone; |
102 | 0 | } |
103 | 10.0k | if (haystack[i][0] <= needle[i][0] && haystack[i][1] >= needle[i][1]) { |
104 | 10.0k | continue; |
105 | 10.0k | } |
106 | 0 | partial = true; |
107 | 0 | partial_axis = i; |
108 | 0 | if (haystack[i][0] > needle[i][0] && haystack[i][0] < needle[i][1]) { |
109 | 0 | partial_val = haystack[i][0] - 1; |
110 | 0 | } else { |
111 | 0 | JXL_DASSERT(haystack[i][1] > needle[i][0] && |
112 | 0 | haystack[i][1] < needle[i][1]); |
113 | 0 | partial_val = haystack[i][1] - 1; |
114 | 0 | } |
115 | 0 | } |
116 | 5.03k | return partial ? IntersectionType::kPartial : IntersectionType::kInside; |
117 | 5.03k | } Unexecuted instantiation: jxl::N_SSE4::BoxIntersects(std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, unsigned int&, unsigned int&) jxl::N_AVX2::BoxIntersects(std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, unsigned int&, unsigned int&) Line | Count | Source | 94 | 5.03k | uint32_t &partial_axis, uint32_t &partial_val) { | 95 | 5.03k | bool partial = false; | 96 | 15.0k | for (size_t i = 0; i < kNumStaticProperties; i++) { | 97 | 10.0k | if (haystack[i][0] >= needle[i][1]) { | 98 | 0 | return IntersectionType::kNone; | 99 | 0 | } | 100 | 10.0k | if (haystack[i][1] <= needle[i][0]) { | 101 | 0 | return IntersectionType::kNone; | 102 | 0 | } | 103 | 10.0k | if (haystack[i][0] <= needle[i][0] && haystack[i][1] >= needle[i][1]) { | 104 | 10.0k | continue; | 105 | 10.0k | } | 106 | 0 | partial = true; | 107 | 0 | partial_axis = i; | 108 | 0 | if (haystack[i][0] > needle[i][0] && haystack[i][0] < needle[i][1]) { | 109 | 0 | partial_val = haystack[i][0] - 1; | 110 | 0 | } else { | 111 | 0 | JXL_DASSERT(haystack[i][1] > needle[i][0] && | 112 | 0 | haystack[i][1] < needle[i][1]); | 113 | 0 | partial_val = haystack[i][1] - 1; | 114 | 0 | } | 115 | 0 | } | 116 | 5.03k | return partial ? IntersectionType::kPartial : IntersectionType::kInside; | 117 | 5.03k | } |
Unexecuted instantiation: jxl::N_SSE2::BoxIntersects(std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, unsigned int&, unsigned int&) |
118 | | |
// Reorder the samples in [begin, end) in place so that those with
// Property<S>(prop) <= val occupy [begin, pos) and the rest occupy [pos, end).
// Two-cursor (Hoare-style) partition: `begin_pos` scans the left half for
// misplaced samples, `end_pos` scans the right half, and misplaced pairs are
// exchanged via TreeSamples::Swap. `pos` is the precomputed split point, i.e.
// the number of samples on the <= side (callers count it beforehand).
template<bool S>
void SplitTreeSamples(TreeSamples &tree_samples, size_t begin, size_t pos,
                      size_t end, size_t prop, uint32_t val) {
  size_t begin_pos = begin;
  size_t end_pos = pos;
  do {
    // Skip left-half samples that already belong on the left (<= val).
    while (begin_pos < pos &&
           tree_samples.Property<S>(prop, begin_pos) <= val) {
      ++begin_pos;
    }
    // Skip right-half samples that already belong on the right (> val).
    while (end_pos < end && tree_samples.Property<S>(prop, end_pos) > val) {
      ++end_pos;
    }
    // Both cursors (if still in range) now point at misplaced samples;
    // swapping fixes both at once.
    if (begin_pos < pos && end_pos < end) {
      tree_samples.Swap(begin_pos, end_pos);
    }
    // Unconditional advance: either the pair was swapped into place, or a
    // cursor ran off its half and the loop condition below terminates.
    ++begin_pos;
    ++end_pos;
  } while (begin_pos < pos && end_pos < end);
}
139 | | |
140 | | template <bool S> |
141 | | void CollectExtraBitsIncrease(TreeSamples &tree_samples, |
142 | | const std::vector<ResidualToken> &rtokens, |
143 | | std::vector<int> &count_increase, |
144 | | std::vector<size_t> &extra_bits_increase, |
145 | | size_t begin, size_t end, size_t prop_idx, |
146 | 35.0k | size_t max_symbols) { |
147 | 7.84M | for (size_t i2 = begin; i2 < end; i2++) { |
148 | 7.81M | const ResidualToken &rt = rtokens[i2]; |
149 | 7.81M | size_t cnt = tree_samples.Count(i2); |
150 | 7.81M | size_t p = tree_samples.Property<S>(prop_idx, i2); |
151 | 7.81M | size_t sym = rt.tok; |
152 | 7.81M | size_t ebi = rt.nbits * cnt; |
153 | 7.81M | count_increase[p * max_symbols + sym] += cnt; |
154 | 7.81M | extra_bits_increase[p] += ebi; |
155 | 7.81M | } |
156 | 35.0k | } Unexecuted instantiation: void jxl::N_SSE4::CollectExtraBitsIncrease<true>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long) Unexecuted instantiation: void jxl::N_SSE4::CollectExtraBitsIncrease<false>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long) void jxl::N_AVX2::CollectExtraBitsIncrease<true>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long) Line | Count | Source | 146 | 10.0k | size_t max_symbols) { | 147 | 2.24M | for (size_t i2 = begin; i2 < end; i2++) { | 148 | 2.23M | const ResidualToken &rt = rtokens[i2]; | 149 | 2.23M | size_t cnt = tree_samples.Count(i2); | 150 | 2.23M | size_t p = tree_samples.Property<S>(prop_idx, i2); | 151 | 2.23M | size_t sym = rt.tok; | 152 | 2.23M | size_t ebi = rt.nbits * cnt; | 153 | 2.23M | count_increase[p * max_symbols + sym] += cnt; | 154 | 2.23M | extra_bits_increase[p] += ebi; | 155 | 2.23M | } | 156 | 10.0k | } |
void jxl::N_AVX2::CollectExtraBitsIncrease<false>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long) Line | Count | Source | 146 | 25.0k | size_t max_symbols) { | 147 | 5.60M | for (size_t i2 = begin; i2 < end; i2++) { | 148 | 5.58M | const ResidualToken &rt = rtokens[i2]; | 149 | 5.58M | size_t cnt = tree_samples.Count(i2); | 150 | 5.58M | size_t p = tree_samples.Property<S>(prop_idx, i2); | 151 | 5.58M | size_t sym = rt.tok; | 152 | 5.58M | size_t ebi = rt.nbits * cnt; | 153 | 5.58M | count_increase[p * max_symbols + sym] += cnt; | 154 | 5.58M | extra_bits_increase[p] += ebi; | 155 | 5.58M | } | 156 | 25.0k | } |
Unexecuted instantiation: void jxl::N_SSE2::CollectExtraBitsIncrease<true>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long) Unexecuted instantiation: void jxl::N_SSE2::CollectExtraBitsIncrease<false>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long) |
157 | | |
158 | | void FindBestSplit(TreeSamples &tree_samples, float threshold, |
159 | | const std::vector<ModularMultiplierInfo> &mul_info, |
160 | | StaticPropRange initial_static_prop_range, |
161 | 73 | float fast_decode_multiplier, Tree *tree) { |
162 | 73 | struct NodeInfo { |
163 | 73 | size_t pos; |
164 | 73 | size_t begin; |
165 | 73 | size_t end; |
166 | 73 | uint64_t used_properties; |
167 | 73 | StaticPropRange static_prop_range; |
168 | 73 | }; |
169 | 73 | std::vector<NodeInfo> nodes; |
170 | 73 | nodes.push_back(NodeInfo{0, 0, tree_samples.NumDistinctSamples(), 0, |
171 | 73 | initial_static_prop_range}); |
172 | | |
173 | 73 | size_t num_predictors = tree_samples.NumPredictors(); |
174 | 73 | size_t num_properties = tree_samples.NumProperties(); |
175 | | |
176 | | // TODO(veluca): consider parallelizing the search (processing multiple nodes |
177 | | // at a time). |
178 | 5.12k | while (!nodes.empty()) { |
179 | 5.05k | size_t pos = nodes.back().pos; |
180 | 5.05k | size_t begin = nodes.back().begin; |
181 | 5.05k | size_t end = nodes.back().end; |
182 | 5.05k | uint64_t used_properties = nodes.back().used_properties; |
183 | 5.05k | StaticPropRange static_prop_range = nodes.back().static_prop_range; |
184 | 5.05k | nodes.pop_back(); |
185 | 5.05k | if (begin == end) continue; |
186 | | |
187 | 5.05k | struct SplitInfo { |
188 | 5.05k | size_t prop = 0; |
189 | 5.05k | uint32_t val = 0; |
190 | 5.05k | size_t pos = 0; |
191 | 5.05k | float lcost = std::numeric_limits<float>::max(); |
192 | 5.05k | float rcost = std::numeric_limits<float>::max(); |
193 | 5.05k | Predictor lpred = Predictor::Zero; |
194 | 5.05k | Predictor rpred = Predictor::Zero; |
195 | 211k | float Cost() const { return lcost + rcost; } Unexecuted instantiation: enc_ma.cc:jxl::N_SSE4::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::SplitInfo::Cost() const enc_ma.cc:jxl::N_AVX2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::SplitInfo::Cost() const Line | Count | Source | 195 | 211k | float Cost() const { return lcost + rcost; } |
Unexecuted instantiation: enc_ma.cc:jxl::N_SSE2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::SplitInfo::Cost() const |
196 | 5.05k | }; |
197 | | |
198 | 5.05k | SplitInfo best_split_static_constant; |
199 | 5.05k | SplitInfo best_split_static; |
200 | 5.05k | SplitInfo best_split_nonstatic; |
201 | 5.05k | SplitInfo best_split_nowp; |
202 | | |
203 | 5.05k | JXL_DASSERT(begin <= end); |
204 | 5.05k | JXL_DASSERT(end <= tree_samples.NumDistinctSamples()); |
205 | | |
206 | | // Compute the maximum token in the range. |
207 | 5.05k | size_t max_symbols = 0; |
208 | 10.1k | for (size_t pred = 0; pred < num_predictors; pred++) { |
209 | 1.12M | for (size_t i = begin; i < end; i++) { |
210 | 1.11M | uint32_t tok = tree_samples.Token(pred, i); |
211 | 1.11M | max_symbols = max_symbols > tok + 1 ? max_symbols : tok + 1; |
212 | 1.11M | } |
213 | 5.05k | } |
214 | 5.05k | max_symbols = Padded(max_symbols); |
215 | 5.05k | std::vector<int32_t> counts(max_symbols * num_predictors); |
216 | 5.05k | std::vector<uint32_t> tot_extra_bits(num_predictors); |
217 | 10.1k | for (size_t pred = 0; pred < num_predictors; pred++) { |
218 | 5.05k | size_t extra_bits = 0; |
219 | 5.05k | const std::vector<ResidualToken>& rtokens = tree_samples.RTokens(pred); |
220 | 1.12M | for (size_t i = begin; i < end; i++) { |
221 | 1.11M | const ResidualToken& rt = rtokens[i]; |
222 | 1.11M | size_t count = tree_samples.Count(i); |
223 | 1.11M | size_t eb = rt.nbits * count; |
224 | 1.11M | counts[pred * max_symbols + rt.tok] += count; |
225 | 1.11M | extra_bits += eb; |
226 | 1.11M | } |
227 | 5.05k | tot_extra_bits[pred] = extra_bits; |
228 | 5.05k | } |
229 | | |
230 | 5.05k | float base_bits; |
231 | 5.05k | { |
232 | 5.05k | size_t pred = tree_samples.PredictorIndex((*tree)[pos].predictor); |
233 | 5.05k | base_bits = |
234 | 5.05k | EstimateBits(counts.data() + pred * max_symbols, max_symbols) + |
235 | 5.05k | tot_extra_bits[pred]; |
236 | 5.05k | } |
237 | | |
238 | 5.05k | SplitInfo *best = &best_split_nonstatic; |
239 | | |
240 | 5.05k | SplitInfo forced_split; |
241 | | // The multiplier ranges cut halfway through the current ranges of static |
242 | | // properties. We do this even if the current node is not a leaf, to |
243 | | // minimize the number of nodes in the resulting tree. |
244 | 5.05k | for (const auto &mmi : mul_info) { |
245 | 5.03k | uint32_t axis; |
246 | 5.03k | uint32_t val; |
247 | 5.03k | IntersectionType t = |
248 | 5.03k | BoxIntersects(static_prop_range, mmi.range, axis, val); |
249 | 5.03k | if (t == IntersectionType::kNone) continue; |
250 | 5.03k | if (t == IntersectionType::kInside) { |
251 | 5.03k | (*tree)[pos].multiplier = mmi.multiplier; |
252 | 5.03k | break; |
253 | 5.03k | } |
254 | 0 | if (t == IntersectionType::kPartial) { |
255 | 0 | JXL_DASSERT(axis < kNumStaticProperties); |
256 | 0 | forced_split.val = tree_samples.QuantizeStaticProperty(axis, val); |
257 | 0 | forced_split.prop = axis; |
258 | 0 | forced_split.lcost = forced_split.rcost = base_bits / 2 - threshold; |
259 | 0 | forced_split.lpred = forced_split.rpred = (*tree)[pos].predictor; |
260 | 0 | best = &forced_split; |
261 | 0 | best->pos = begin; |
262 | 0 | JXL_DASSERT(best->prop == tree_samples.PropertyFromIndex(best->prop)); |
263 | 0 | if (best->prop < tree_samples.NumStaticProps()) { |
264 | 0 | for (size_t x = begin; x < end; x++) { |
265 | 0 | if (tree_samples.Property<true>(best->prop, x) <= best->val) { |
266 | 0 | best->pos++; |
267 | 0 | } |
268 | 0 | } |
269 | 0 | } else { |
270 | 0 | size_t prop = best->prop - tree_samples.NumStaticProps(); |
271 | 0 | for (size_t x = begin; x < end; x++) { |
272 | 0 | if (tree_samples.Property<false>(prop, x) <= best->val) { |
273 | 0 | best->pos++; |
274 | 0 | } |
275 | 0 | } |
276 | 0 | } |
277 | 0 | break; |
278 | 0 | } |
279 | 0 | } |
280 | | |
281 | 5.05k | if (best != &forced_split) { |
282 | 5.05k | std::vector<int> prop_value_used_count; |
283 | 5.05k | std::vector<int> count_increase; |
284 | 5.05k | std::vector<size_t> extra_bits_increase; |
285 | | // For each property, compute which of its values are used, and what |
286 | | // tokens correspond to those usages. Then, iterate through the values, |
287 | | // and compute the entropy of each side of the split (of the form `prop > |
288 | | // threshold`). Finally, find the split that minimizes the cost. |
289 | 5.05k | struct CostInfo { |
290 | 5.05k | float cost = std::numeric_limits<float>::max(); |
291 | 5.05k | float extra_cost = 0; |
292 | 371k | float Cost() const { return cost + extra_cost; } Unexecuted instantiation: enc_ma.cc:jxl::N_SSE4::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::CostInfo::Cost() const enc_ma.cc:jxl::N_AVX2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::CostInfo::Cost() const Line | Count | Source | 292 | 371k | float Cost() const { return cost + extra_cost; } |
Unexecuted instantiation: enc_ma.cc:jxl::N_SSE2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::CostInfo::Cost() const |
293 | 5.05k | Predictor pred; // will be uninitialized in some cases, but never used. |
294 | 5.05k | }; |
295 | 5.05k | std::vector<CostInfo> costs_l; |
296 | 5.05k | std::vector<CostInfo> costs_r; |
297 | | |
298 | 5.05k | std::vector<int32_t> counts_above(max_symbols); |
299 | 5.05k | std::vector<int32_t> counts_below(max_symbols); |
300 | | |
301 | | // The lower the threshold, the higher the expected noisiness of the |
302 | | // estimate. Thus, discourage changing predictors. |
303 | 5.05k | float change_pred_penalty = 800.0f / (100.0f + threshold); |
304 | 40.1k | for (size_t prop = 0; prop < num_properties && base_bits > threshold; |
305 | 35.0k | prop++) { |
306 | 35.0k | costs_l.clear(); |
307 | 35.0k | costs_r.clear(); |
308 | 35.0k | size_t prop_size = tree_samples.NumPropertyValues(prop); |
309 | 35.0k | if (extra_bits_increase.size() < prop_size) { |
310 | 10.9k | count_increase.resize(prop_size * max_symbols); |
311 | 10.9k | extra_bits_increase.resize(prop_size); |
312 | 10.9k | } |
313 | | // Clear prop_value_used_count (which cannot be cleared "on the go") |
314 | 35.0k | prop_value_used_count.clear(); |
315 | 35.0k | prop_value_used_count.resize(prop_size); |
316 | | |
317 | 35.0k | size_t first_used = prop_size; |
318 | 35.0k | size_t last_used = 0; |
319 | | |
320 | | // TODO(veluca): consider finding multiple splits along a single |
321 | | // property at the same time, possibly with a bottom-up approach. |
322 | 35.0k | if (prop < tree_samples.NumStaticProps()) { |
323 | 2.24M | for (size_t i = begin; i < end; i++) { |
324 | 2.23M | size_t p = tree_samples.Property<true>(prop, i); |
325 | 2.23M | prop_value_used_count[p]++; |
326 | 2.23M | last_used = std::max(last_used, p); |
327 | 2.23M | first_used = std::min(first_used, p); |
328 | 2.23M | } |
329 | 25.0k | } else { |
330 | 25.0k | size_t prop_idx = prop - tree_samples.NumStaticProps(); |
331 | 5.60M | for (size_t i = begin; i < end; i++) { |
332 | 5.58M | size_t p = tree_samples.Property<false>(prop_idx, i); |
333 | 5.58M | prop_value_used_count[p]++; |
334 | 5.58M | last_used = std::max(last_used, p); |
335 | 5.58M | first_used = std::min(first_used, p); |
336 | 5.58M | } |
337 | 25.0k | } |
338 | 35.0k | costs_l.resize(last_used - first_used); |
339 | 35.0k | costs_r.resize(last_used - first_used); |
340 | | // For all predictors, compute the right and left costs of each split. |
341 | 70.1k | for (size_t pred = 0; pred < num_predictors; pred++) { |
342 | | // Compute cost and histogram increments for each property value. |
343 | 35.0k | const std::vector<ResidualToken> &rtokens = |
344 | 35.0k | tree_samples.RTokens(pred); |
345 | 35.0k | if (prop < tree_samples.NumStaticProps()) { |
346 | 10.0k | CollectExtraBitsIncrease<true>(tree_samples, rtokens, |
347 | 10.0k | count_increase, extra_bits_increase, |
348 | 10.0k | begin, end, prop, max_symbols); |
349 | 25.0k | } else { |
350 | 25.0k | CollectExtraBitsIncrease<false>( |
351 | 25.0k | tree_samples, rtokens, count_increase, extra_bits_increase, |
352 | 25.0k | begin, end, prop - tree_samples.NumStaticProps(), max_symbols); |
353 | 25.0k | } |
354 | 35.0k | memcpy(counts_above.data(), counts.data() + pred * max_symbols, |
355 | 35.0k | max_symbols * sizeof counts_above[0]); |
356 | 35.0k | memset(counts_below.data(), 0, max_symbols * sizeof counts_below[0]); |
357 | 35.0k | size_t extra_bits_below = 0; |
358 | | // Exclude last used: this ensures neither counts_above nor |
359 | | // counts_below is empty. |
360 | 329k | for (size_t i = first_used; i < last_used; i++) { |
361 | 294k | if (!prop_value_used_count[i]) continue; |
362 | 185k | extra_bits_below += extra_bits_increase[i]; |
363 | | // The increase for this property value has been used, and will not |
364 | | // be used again: clear it. Also below. |
365 | 185k | extra_bits_increase[i] = 0; |
366 | 9.00M | for (size_t sym = 0; sym < max_symbols; sym++) { |
367 | 8.81M | counts_above[sym] -= count_increase[i * max_symbols + sym]; |
368 | 8.81M | counts_below[sym] += count_increase[i * max_symbols + sym]; |
369 | 8.81M | count_increase[i * max_symbols + sym] = 0; |
370 | 8.81M | } |
371 | 185k | float rcost = EstimateBits(counts_above.data(), max_symbols) + |
372 | 185k | tot_extra_bits[pred] - extra_bits_below; |
373 | 185k | float lcost = EstimateBits(counts_below.data(), max_symbols) + |
374 | 185k | extra_bits_below; |
375 | 185k | JXL_DASSERT(extra_bits_below <= tot_extra_bits[pred]); |
376 | 185k | float penalty = 0; |
377 | | // Never discourage moving away from the Weighted predictor. |
378 | 185k | if (tree_samples.PredictorFromIndex(pred) != |
379 | 185k | (*tree)[pos].predictor && |
380 | 185k | (*tree)[pos].predictor != Predictor::Weighted) { |
381 | 0 | penalty = change_pred_penalty; |
382 | 0 | } |
383 | | // If everything else is equal, disfavour Weighted (slower) and |
384 | | // favour Zero (faster if it's the only predictor used in a |
385 | | // group+channel combination) |
386 | 185k | if (tree_samples.PredictorFromIndex(pred) == Predictor::Weighted) { |
387 | 0 | penalty += 1e-8; |
388 | 0 | } |
389 | 185k | if (tree_samples.PredictorFromIndex(pred) == Predictor::Zero) { |
390 | 0 | penalty -= 1e-8; |
391 | 0 | } |
392 | 185k | if (rcost + penalty < costs_r[i - first_used].Cost()) { |
393 | 185k | costs_r[i - first_used].cost = rcost; |
394 | 185k | costs_r[i - first_used].extra_cost = penalty; |
395 | 185k | costs_r[i - first_used].pred = |
396 | 185k | tree_samples.PredictorFromIndex(pred); |
397 | 185k | } |
398 | 185k | if (lcost + penalty < costs_l[i - first_used].Cost()) { |
399 | 185k | costs_l[i - first_used].cost = lcost; |
400 | 185k | costs_l[i - first_used].extra_cost = penalty; |
401 | 185k | costs_l[i - first_used].pred = |
402 | 185k | tree_samples.PredictorFromIndex(pred); |
403 | 185k | } |
404 | 185k | } |
405 | 35.0k | } |
406 | | // Iterate through the possible splits and find the one with minimum sum |
407 | | // of costs of the two sides. |
408 | 35.0k | size_t split = begin; |
409 | 329k | for (size_t i = first_used; i < last_used; i++) { |
410 | 294k | if (!prop_value_used_count[i]) continue; |
411 | 185k | split += prop_value_used_count[i]; |
412 | 185k | float rcost = costs_r[i - first_used].cost; |
413 | 185k | float lcost = costs_l[i - first_used].cost; |
414 | | // WP was not used + we would use the WP property or predictor |
415 | 185k | bool adds_wp = |
416 | 185k | (tree_samples.PropertyFromIndex(prop) == kWPProp && |
417 | 185k | (used_properties & (1LU << prop)) == 0) || |
418 | 185k | ((costs_l[i - first_used].pred == Predictor::Weighted || |
419 | 137k | costs_r[i - first_used].pred == Predictor::Weighted) && |
420 | 137k | (*tree)[pos].predictor != Predictor::Weighted); |
421 | 185k | bool zero_entropy_side = rcost == 0 || lcost == 0; |
422 | | |
423 | 185k | SplitInfo &best_ref = |
424 | 185k | tree_samples.PropertyFromIndex(prop) < kNumStaticProperties |
425 | 185k | ? (zero_entropy_side ? best_split_static_constant |
426 | 1.07k | : best_split_static) |
427 | 185k | : (adds_wp ? best_split_nonstatic : best_split_nowp); |
428 | 185k | if (lcost + rcost < best_ref.Cost()) { |
429 | 37.3k | best_ref.prop = prop; |
430 | 37.3k | best_ref.val = i; |
431 | 37.3k | best_ref.pos = split; |
432 | 37.3k | best_ref.lcost = lcost; |
433 | 37.3k | best_ref.lpred = costs_l[i - first_used].pred; |
434 | 37.3k | best_ref.rcost = rcost; |
435 | 37.3k | best_ref.rpred = costs_r[i - first_used].pred; |
436 | 37.3k | } |
437 | 185k | } |
438 | | // Clear extra_bits_increase and cost_increase for last_used. |
439 | 35.0k | extra_bits_increase[last_used] = 0; |
440 | 1.72M | for (size_t sym = 0; sym < max_symbols; sym++) { |
441 | 1.69M | count_increase[last_used * max_symbols + sym] = 0; |
442 | 1.69M | } |
443 | 35.0k | } |
444 | | |
445 | | // Try to avoid introducing WP. |
446 | 5.05k | if (best_split_nowp.Cost() + threshold < base_bits && |
447 | 5.05k | best_split_nowp.Cost() <= fast_decode_multiplier * best->Cost()) { |
448 | 2.31k | best = &best_split_nowp; |
449 | 2.31k | } |
450 | | // Split along static props if possible and not significantly more |
451 | | // expensive. |
452 | 5.05k | if (best_split_static.Cost() + threshold < base_bits && |
453 | 5.05k | best_split_static.Cost() <= fast_decode_multiplier * best->Cost()) { |
454 | 318 | best = &best_split_static; |
455 | 318 | } |
456 | | // Split along static props to create constant nodes if possible. |
457 | 5.05k | if (best_split_static_constant.Cost() + threshold < base_bits) { |
458 | 0 | best = &best_split_static_constant; |
459 | 0 | } |
460 | 5.05k | } |
461 | | |
462 | 5.05k | if (best->Cost() + threshold < base_bits) { |
463 | 2.49k | uint32_t p = tree_samples.PropertyFromIndex(best->prop); |
464 | 2.49k | pixel_type dequant = |
465 | 2.49k | tree_samples.UnquantizeProperty(best->prop, best->val); |
466 | | // Split node and try to split children. |
467 | 2.49k | MakeSplitNode(pos, p, dequant, best->lpred, 0, best->rpred, 0, tree); |
468 | | // "Sort" according to winning property |
469 | 2.49k | if (best->prop < tree_samples.NumStaticProps()) { |
470 | 318 | SplitTreeSamples<true>(tree_samples, begin, best->pos, end, best->prop, |
471 | 318 | best->val); |
472 | 2.17k | } else { |
473 | 2.17k | SplitTreeSamples<false>(tree_samples, begin, best->pos, end, |
474 | 2.17k | best->prop - tree_samples.NumStaticProps(), |
475 | 2.17k | best->val); |
476 | 2.17k | } |
477 | 2.49k | if (p >= kNumStaticProperties) { |
478 | 2.17k | used_properties |= 1 << best->prop; |
479 | 2.17k | } |
480 | 2.49k | auto new_sp_range = static_prop_range; |
481 | 2.49k | if (p < kNumStaticProperties) { |
482 | 318 | JXL_DASSERT(static_cast<uint32_t>(dequant + 1) <= new_sp_range[p][1]); |
483 | 318 | new_sp_range[p][1] = dequant + 1; |
484 | 318 | JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]); |
485 | 318 | } |
486 | 2.49k | nodes.push_back(NodeInfo{(*tree)[pos].rchild, begin, best->pos, |
487 | 2.49k | used_properties, new_sp_range}); |
488 | 2.49k | new_sp_range = static_prop_range; |
489 | 2.49k | if (p < kNumStaticProperties) { |
490 | 318 | JXL_DASSERT(new_sp_range[p][0] <= static_cast<uint32_t>(dequant + 1)); |
491 | 318 | new_sp_range[p][0] = dequant + 1; |
492 | 318 | JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]); |
493 | 318 | } |
494 | 2.49k | nodes.push_back(NodeInfo{(*tree)[pos].lchild, best->pos, end, |
495 | 2.49k | used_properties, new_sp_range}); |
496 | 2.49k | } |
497 | 5.05k | } |
498 | 73 | } Unexecuted instantiation: jxl::N_SSE4::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*) jxl::N_AVX2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*) Line | Count | Source | 161 | 73 | float fast_decode_multiplier, Tree *tree) { | 162 | 73 | struct NodeInfo { | 163 | 73 | size_t pos; | 164 | 73 | size_t begin; | 165 | 73 | size_t end; | 166 | 73 | uint64_t used_properties; | 167 | 73 | StaticPropRange static_prop_range; | 168 | 73 | }; | 169 | 73 | std::vector<NodeInfo> nodes; | 170 | 73 | nodes.push_back(NodeInfo{0, 0, tree_samples.NumDistinctSamples(), 0, | 171 | 73 | initial_static_prop_range}); | 172 | | | 173 | 73 | size_t num_predictors = tree_samples.NumPredictors(); | 174 | 73 | size_t num_properties = tree_samples.NumProperties(); | 175 | | | 176 | | // TODO(veluca): consider parallelizing the search (processing multiple nodes | 177 | | // at a time). 
| 178 | 5.12k | while (!nodes.empty()) { | 179 | 5.05k | size_t pos = nodes.back().pos; | 180 | 5.05k | size_t begin = nodes.back().begin; | 181 | 5.05k | size_t end = nodes.back().end; | 182 | 5.05k | uint64_t used_properties = nodes.back().used_properties; | 183 | 5.05k | StaticPropRange static_prop_range = nodes.back().static_prop_range; | 184 | 5.05k | nodes.pop_back(); | 185 | 5.05k | if (begin == end) continue; | 186 | | | 187 | 5.05k | struct SplitInfo { | 188 | 5.05k | size_t prop = 0; | 189 | 5.05k | uint32_t val = 0; | 190 | 5.05k | size_t pos = 0; | 191 | 5.05k | float lcost = std::numeric_limits<float>::max(); | 192 | 5.05k | float rcost = std::numeric_limits<float>::max(); | 193 | 5.05k | Predictor lpred = Predictor::Zero; | 194 | 5.05k | Predictor rpred = Predictor::Zero; | 195 | 5.05k | float Cost() const { return lcost + rcost; } | 196 | 5.05k | }; | 197 | | | 198 | 5.05k | SplitInfo best_split_static_constant; | 199 | 5.05k | SplitInfo best_split_static; | 200 | 5.05k | SplitInfo best_split_nonstatic; | 201 | 5.05k | SplitInfo best_split_nowp; | 202 | | | 203 | 5.05k | JXL_DASSERT(begin <= end); | 204 | 5.05k | JXL_DASSERT(end <= tree_samples.NumDistinctSamples()); | 205 | | | 206 | | // Compute the maximum token in the range. | 207 | 5.05k | size_t max_symbols = 0; | 208 | 10.1k | for (size_t pred = 0; pred < num_predictors; pred++) { | 209 | 1.12M | for (size_t i = begin; i < end; i++) { | 210 | 1.11M | uint32_t tok = tree_samples.Token(pred, i); | 211 | 1.11M | max_symbols = max_symbols > tok + 1 ? 
max_symbols : tok + 1; | 212 | 1.11M | } | 213 | 5.05k | } | 214 | 5.05k | max_symbols = Padded(max_symbols); | 215 | 5.05k | std::vector<int32_t> counts(max_symbols * num_predictors); | 216 | 5.05k | std::vector<uint32_t> tot_extra_bits(num_predictors); | 217 | 10.1k | for (size_t pred = 0; pred < num_predictors; pred++) { | 218 | 5.05k | size_t extra_bits = 0; | 219 | 5.05k | const std::vector<ResidualToken>& rtokens = tree_samples.RTokens(pred); | 220 | 1.12M | for (size_t i = begin; i < end; i++) { | 221 | 1.11M | const ResidualToken& rt = rtokens[i]; | 222 | 1.11M | size_t count = tree_samples.Count(i); | 223 | 1.11M | size_t eb = rt.nbits * count; | 224 | 1.11M | counts[pred * max_symbols + rt.tok] += count; | 225 | 1.11M | extra_bits += eb; | 226 | 1.11M | } | 227 | 5.05k | tot_extra_bits[pred] = extra_bits; | 228 | 5.05k | } | 229 | | | 230 | 5.05k | float base_bits; | 231 | 5.05k | { | 232 | 5.05k | size_t pred = tree_samples.PredictorIndex((*tree)[pos].predictor); | 233 | 5.05k | base_bits = | 234 | 5.05k | EstimateBits(counts.data() + pred * max_symbols, max_symbols) + | 235 | 5.05k | tot_extra_bits[pred]; | 236 | 5.05k | } | 237 | | | 238 | 5.05k | SplitInfo *best = &best_split_nonstatic; | 239 | | | 240 | 5.05k | SplitInfo forced_split; | 241 | | // The multiplier ranges cut halfway through the current ranges of static | 242 | | // properties. We do this even if the current node is not a leaf, to | 243 | | // minimize the number of nodes in the resulting tree. 
| 244 | 5.05k | for (const auto &mmi : mul_info) { | 245 | 5.03k | uint32_t axis; | 246 | 5.03k | uint32_t val; | 247 | 5.03k | IntersectionType t = | 248 | 5.03k | BoxIntersects(static_prop_range, mmi.range, axis, val); | 249 | 5.03k | if (t == IntersectionType::kNone) continue; | 250 | 5.03k | if (t == IntersectionType::kInside) { | 251 | 5.03k | (*tree)[pos].multiplier = mmi.multiplier; | 252 | 5.03k | break; | 253 | 5.03k | } | 254 | 0 | if (t == IntersectionType::kPartial) { | 255 | 0 | JXL_DASSERT(axis < kNumStaticProperties); | 256 | 0 | forced_split.val = tree_samples.QuantizeStaticProperty(axis, val); | 257 | 0 | forced_split.prop = axis; | 258 | 0 | forced_split.lcost = forced_split.rcost = base_bits / 2 - threshold; | 259 | 0 | forced_split.lpred = forced_split.rpred = (*tree)[pos].predictor; | 260 | 0 | best = &forced_split; | 261 | 0 | best->pos = begin; | 262 | 0 | JXL_DASSERT(best->prop == tree_samples.PropertyFromIndex(best->prop)); | 263 | 0 | if (best->prop < tree_samples.NumStaticProps()) { | 264 | 0 | for (size_t x = begin; x < end; x++) { | 265 | 0 | if (tree_samples.Property<true>(best->prop, x) <= best->val) { | 266 | 0 | best->pos++; | 267 | 0 | } | 268 | 0 | } | 269 | 0 | } else { | 270 | 0 | size_t prop = best->prop - tree_samples.NumStaticProps(); | 271 | 0 | for (size_t x = begin; x < end; x++) { | 272 | 0 | if (tree_samples.Property<false>(prop, x) <= best->val) { | 273 | 0 | best->pos++; | 274 | 0 | } | 275 | 0 | } | 276 | 0 | } | 277 | 0 | break; | 278 | 0 | } | 279 | 0 | } | 280 | | | 281 | 5.05k | if (best != &forced_split) { | 282 | 5.05k | std::vector<int> prop_value_used_count; | 283 | 5.05k | std::vector<int> count_increase; | 284 | 5.05k | std::vector<size_t> extra_bits_increase; | 285 | | // For each property, compute which of its values are used, and what | 286 | | // tokens correspond to those usages. 
Then, iterate through the values, | 287 | | // and compute the entropy of each side of the split (of the form `prop > | 288 | | // threshold`). Finally, find the split that minimizes the cost. | 289 | 5.05k | struct CostInfo { | 290 | 5.05k | float cost = std::numeric_limits<float>::max(); | 291 | 5.05k | float extra_cost = 0; | 292 | 5.05k | float Cost() const { return cost + extra_cost; } | 293 | 5.05k | Predictor pred; // will be uninitialized in some cases, but never used. | 294 | 5.05k | }; | 295 | 5.05k | std::vector<CostInfo> costs_l; | 296 | 5.05k | std::vector<CostInfo> costs_r; | 297 | | | 298 | 5.05k | std::vector<int32_t> counts_above(max_symbols); | 299 | 5.05k | std::vector<int32_t> counts_below(max_symbols); | 300 | | | 301 | | // The lower the threshold, the higher the expected noisiness of the | 302 | | // estimate. Thus, discourage changing predictors. | 303 | 5.05k | float change_pred_penalty = 800.0f / (100.0f + threshold); | 304 | 40.1k | for (size_t prop = 0; prop < num_properties && base_bits > threshold; | 305 | 35.0k | prop++) { | 306 | 35.0k | costs_l.clear(); | 307 | 35.0k | costs_r.clear(); | 308 | 35.0k | size_t prop_size = tree_samples.NumPropertyValues(prop); | 309 | 35.0k | if (extra_bits_increase.size() < prop_size) { | 310 | 10.9k | count_increase.resize(prop_size * max_symbols); | 311 | 10.9k | extra_bits_increase.resize(prop_size); | 312 | 10.9k | } | 313 | | // Clear prop_value_used_count (which cannot be cleared "on the go") | 314 | 35.0k | prop_value_used_count.clear(); | 315 | 35.0k | prop_value_used_count.resize(prop_size); | 316 | | | 317 | 35.0k | size_t first_used = prop_size; | 318 | 35.0k | size_t last_used = 0; | 319 | | | 320 | | // TODO(veluca): consider finding multiple splits along a single | 321 | | // property at the same time, possibly with a bottom-up approach. 
| 322 | 35.0k | if (prop < tree_samples.NumStaticProps()) { | 323 | 2.24M | for (size_t i = begin; i < end; i++) { | 324 | 2.23M | size_t p = tree_samples.Property<true>(prop, i); | 325 | 2.23M | prop_value_used_count[p]++; | 326 | 2.23M | last_used = std::max(last_used, p); | 327 | 2.23M | first_used = std::min(first_used, p); | 328 | 2.23M | } | 329 | 25.0k | } else { | 330 | 25.0k | size_t prop_idx = prop - tree_samples.NumStaticProps(); | 331 | 5.60M | for (size_t i = begin; i < end; i++) { | 332 | 5.58M | size_t p = tree_samples.Property<false>(prop_idx, i); | 333 | 5.58M | prop_value_used_count[p]++; | 334 | 5.58M | last_used = std::max(last_used, p); | 335 | 5.58M | first_used = std::min(first_used, p); | 336 | 5.58M | } | 337 | 25.0k | } | 338 | 35.0k | costs_l.resize(last_used - first_used); | 339 | 35.0k | costs_r.resize(last_used - first_used); | 340 | | // For all predictors, compute the right and left costs of each split. | 341 | 70.1k | for (size_t pred = 0; pred < num_predictors; pred++) { | 342 | | // Compute cost and histogram increments for each property value. 
| 343 | 35.0k | const std::vector<ResidualToken> &rtokens = | 344 | 35.0k | tree_samples.RTokens(pred); | 345 | 35.0k | if (prop < tree_samples.NumStaticProps()) { | 346 | 10.0k | CollectExtraBitsIncrease<true>(tree_samples, rtokens, | 347 | 10.0k | count_increase, extra_bits_increase, | 348 | 10.0k | begin, end, prop, max_symbols); | 349 | 25.0k | } else { | 350 | 25.0k | CollectExtraBitsIncrease<false>( | 351 | 25.0k | tree_samples, rtokens, count_increase, extra_bits_increase, | 352 | 25.0k | begin, end, prop - tree_samples.NumStaticProps(), max_symbols); | 353 | 25.0k | } | 354 | 35.0k | memcpy(counts_above.data(), counts.data() + pred * max_symbols, | 355 | 35.0k | max_symbols * sizeof counts_above[0]); | 356 | 35.0k | memset(counts_below.data(), 0, max_symbols * sizeof counts_below[0]); | 357 | 35.0k | size_t extra_bits_below = 0; | 358 | | // Exclude last used: this ensures neither counts_above nor | 359 | | // counts_below is empty. | 360 | 329k | for (size_t i = first_used; i < last_used; i++) { | 361 | 294k | if (!prop_value_used_count[i]) continue; | 362 | 185k | extra_bits_below += extra_bits_increase[i]; | 363 | | // The increase for this property value has been used, and will not | 364 | | // be used again: clear it. Also below. 
| 365 | 185k | extra_bits_increase[i] = 0; | 366 | 9.00M | for (size_t sym = 0; sym < max_symbols; sym++) { | 367 | 8.81M | counts_above[sym] -= count_increase[i * max_symbols + sym]; | 368 | 8.81M | counts_below[sym] += count_increase[i * max_symbols + sym]; | 369 | 8.81M | count_increase[i * max_symbols + sym] = 0; | 370 | 8.81M | } | 371 | 185k | float rcost = EstimateBits(counts_above.data(), max_symbols) + | 372 | 185k | tot_extra_bits[pred] - extra_bits_below; | 373 | 185k | float lcost = EstimateBits(counts_below.data(), max_symbols) + | 374 | 185k | extra_bits_below; | 375 | 185k | JXL_DASSERT(extra_bits_below <= tot_extra_bits[pred]); | 376 | 185k | float penalty = 0; | 377 | | // Never discourage moving away from the Weighted predictor. | 378 | 185k | if (tree_samples.PredictorFromIndex(pred) != | 379 | 185k | (*tree)[pos].predictor && | 380 | 185k | (*tree)[pos].predictor != Predictor::Weighted) { | 381 | 0 | penalty = change_pred_penalty; | 382 | 0 | } | 383 | | // If everything else is equal, disfavour Weighted (slower) and | 384 | | // favour Zero (faster if it's the only predictor used in a | 385 | | // group+channel combination) | 386 | 185k | if (tree_samples.PredictorFromIndex(pred) == Predictor::Weighted) { | 387 | 0 | penalty += 1e-8; | 388 | 0 | } | 389 | 185k | if (tree_samples.PredictorFromIndex(pred) == Predictor::Zero) { | 390 | 0 | penalty -= 1e-8; | 391 | 0 | } | 392 | 185k | if (rcost + penalty < costs_r[i - first_used].Cost()) { | 393 | 185k | costs_r[i - first_used].cost = rcost; | 394 | 185k | costs_r[i - first_used].extra_cost = penalty; | 395 | 185k | costs_r[i - first_used].pred = | 396 | 185k | tree_samples.PredictorFromIndex(pred); | 397 | 185k | } | 398 | 185k | if (lcost + penalty < costs_l[i - first_used].Cost()) { | 399 | 185k | costs_l[i - first_used].cost = lcost; | 400 | 185k | costs_l[i - first_used].extra_cost = penalty; | 401 | 185k | costs_l[i - first_used].pred = | 402 | 185k | tree_samples.PredictorFromIndex(pred); | 
403 | 185k | } | 404 | 185k | } | 405 | 35.0k | } | 406 | | // Iterate through the possible splits and find the one with minimum sum | 407 | | // of costs of the two sides. | 408 | 35.0k | size_t split = begin; | 409 | 329k | for (size_t i = first_used; i < last_used; i++) { | 410 | 294k | if (!prop_value_used_count[i]) continue; | 411 | 185k | split += prop_value_used_count[i]; | 412 | 185k | float rcost = costs_r[i - first_used].cost; | 413 | 185k | float lcost = costs_l[i - first_used].cost; | 414 | | // WP was not used + we would use the WP property or predictor | 415 | 185k | bool adds_wp = | 416 | 185k | (tree_samples.PropertyFromIndex(prop) == kWPProp && | 417 | 185k | (used_properties & (1LU << prop)) == 0) || | 418 | 185k | ((costs_l[i - first_used].pred == Predictor::Weighted || | 419 | 137k | costs_r[i - first_used].pred == Predictor::Weighted) && | 420 | 137k | (*tree)[pos].predictor != Predictor::Weighted); | 421 | 185k | bool zero_entropy_side = rcost == 0 || lcost == 0; | 422 | | | 423 | 185k | SplitInfo &best_ref = | 424 | 185k | tree_samples.PropertyFromIndex(prop) < kNumStaticProperties | 425 | 185k | ? (zero_entropy_side ? best_split_static_constant | 426 | 1.07k | : best_split_static) | 427 | 185k | : (adds_wp ? best_split_nonstatic : best_split_nowp); | 428 | 185k | if (lcost + rcost < best_ref.Cost()) { | 429 | 37.3k | best_ref.prop = prop; | 430 | 37.3k | best_ref.val = i; | 431 | 37.3k | best_ref.pos = split; | 432 | 37.3k | best_ref.lcost = lcost; | 433 | 37.3k | best_ref.lpred = costs_l[i - first_used].pred; | 434 | 37.3k | best_ref.rcost = rcost; | 435 | 37.3k | best_ref.rpred = costs_r[i - first_used].pred; | 436 | 37.3k | } | 437 | 185k | } | 438 | | // Clear extra_bits_increase and cost_increase for last_used. 
| 439 | 35.0k | extra_bits_increase[last_used] = 0; | 440 | 1.72M | for (size_t sym = 0; sym < max_symbols; sym++) { | 441 | 1.69M | count_increase[last_used * max_symbols + sym] = 0; | 442 | 1.69M | } | 443 | 35.0k | } | 444 | | | 445 | | // Try to avoid introducing WP. | 446 | 5.05k | if (best_split_nowp.Cost() + threshold < base_bits && | 447 | 5.05k | best_split_nowp.Cost() <= fast_decode_multiplier * best->Cost()) { | 448 | 2.31k | best = &best_split_nowp; | 449 | 2.31k | } | 450 | | // Split along static props if possible and not significantly more | 451 | | // expensive. | 452 | 5.05k | if (best_split_static.Cost() + threshold < base_bits && | 453 | 5.05k | best_split_static.Cost() <= fast_decode_multiplier * best->Cost()) { | 454 | 318 | best = &best_split_static; | 455 | 318 | } | 456 | | // Split along static props to create constant nodes if possible. | 457 | 5.05k | if (best_split_static_constant.Cost() + threshold < base_bits) { | 458 | 0 | best = &best_split_static_constant; | 459 | 0 | } | 460 | 5.05k | } | 461 | | | 462 | 5.05k | if (best->Cost() + threshold < base_bits) { | 463 | 2.49k | uint32_t p = tree_samples.PropertyFromIndex(best->prop); | 464 | 2.49k | pixel_type dequant = | 465 | 2.49k | tree_samples.UnquantizeProperty(best->prop, best->val); | 466 | | // Split node and try to split children. 
| 467 | 2.49k | MakeSplitNode(pos, p, dequant, best->lpred, 0, best->rpred, 0, tree); | 468 | | // "Sort" according to winning property | 469 | 2.49k | if (best->prop < tree_samples.NumStaticProps()) { | 470 | 318 | SplitTreeSamples<true>(tree_samples, begin, best->pos, end, best->prop, | 471 | 318 | best->val); | 472 | 2.17k | } else { | 473 | 2.17k | SplitTreeSamples<false>(tree_samples, begin, best->pos, end, | 474 | 2.17k | best->prop - tree_samples.NumStaticProps(), | 475 | 2.17k | best->val); | 476 | 2.17k | } | 477 | 2.49k | if (p >= kNumStaticProperties) { | 478 | 2.17k | used_properties |= 1 << best->prop; | 479 | 2.17k | } | 480 | 2.49k | auto new_sp_range = static_prop_range; | 481 | 2.49k | if (p < kNumStaticProperties) { | 482 | 318 | JXL_DASSERT(static_cast<uint32_t>(dequant + 1) <= new_sp_range[p][1]); | 483 | 318 | new_sp_range[p][1] = dequant + 1; | 484 | 318 | JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]); | 485 | 318 | } | 486 | 2.49k | nodes.push_back(NodeInfo{(*tree)[pos].rchild, begin, best->pos, | 487 | 2.49k | used_properties, new_sp_range}); | 488 | 2.49k | new_sp_range = static_prop_range; | 489 | 2.49k | if (p < kNumStaticProperties) { | 490 | 318 | JXL_DASSERT(new_sp_range[p][0] <= static_cast<uint32_t>(dequant + 1)); | 491 | 318 | new_sp_range[p][0] = dequant + 1; | 492 | 318 | JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]); | 493 | 318 | } | 494 | 2.49k | nodes.push_back(NodeInfo{(*tree)[pos].lchild, best->pos, end, | 495 | 2.49k | used_properties, new_sp_range}); | 496 | 2.49k | } | 497 | 5.05k | } | 498 | 73 | } |
Unexecuted instantiation: jxl::N_SSE2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*) |
499 | | |
500 | | // NOLINTNEXTLINE(google-readability-namespace-comments) |
501 | | } // namespace HWY_NAMESPACE |
502 | | } // namespace jxl |
503 | | HWY_AFTER_NAMESPACE(); |
504 | | |
505 | | #if HWY_ONCE |
506 | | namespace jxl { |
507 | | |
508 | | HWY_EXPORT(FindBestSplit); // Local function. |
509 | | |
510 | | Status ComputeBestTree(TreeSamples &tree_samples, float threshold, |
511 | | const std::vector<ModularMultiplierInfo> &mul_info, |
512 | | StaticPropRange static_prop_range, |
513 | 73 | float fast_decode_multiplier, Tree *tree) { |
514 | | // TODO(veluca): take into account that different contexts can have different |
515 | | // uint configs. |
516 | | // |
517 | | // Initialize tree. |
518 | 73 | tree->emplace_back(); |
519 | 73 | tree->back().property = -1; |
520 | 73 | tree->back().predictor = tree_samples.PredictorFromIndex(0); |
521 | 73 | tree->back().predictor_offset = 0; |
522 | 73 | tree->back().multiplier = 1; |
523 | 73 | JXL_ENSURE(tree_samples.NumProperties() < 64); |
524 | | |
525 | 73 | JXL_ENSURE(tree_samples.NumDistinctSamples() <= |
526 | 73 | std::numeric_limits<uint32_t>::max()); |
527 | 73 | HWY_DYNAMIC_DISPATCH(FindBestSplit) |
528 | 73 | (tree_samples, threshold, mul_info, static_prop_range, fast_decode_multiplier, |
529 | 73 | tree); |
530 | 73 | return true; |
531 | 73 | } |
532 | | |
533 | | #if JXL_CXX_LANG < JXL_CXX_17 |
534 | | constexpr int32_t TreeSamples::kPropertyRange; |
535 | | constexpr uint32_t TreeSamples::kDedupEntryUnused; |
536 | | #endif |
537 | | |
538 | | Status TreeSamples::SetPredictor(Predictor predictor, |
539 | 74 | ModularOptions::TreeMode wp_tree_mode) { |
540 | 74 | if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) { |
541 | 0 | predictors = {Predictor::Weighted}; |
542 | 0 | residuals.resize(1); |
543 | 0 | return true; |
544 | 0 | } |
545 | 74 | if (wp_tree_mode == ModularOptions::TreeMode::kNoWP && |
546 | 74 | predictor == Predictor::Weighted) { |
547 | 0 | return JXL_FAILURE("Invalid predictor settings"); |
548 | 0 | } |
549 | 74 | if (predictor == Predictor::Variable) { |
550 | 0 | for (size_t i = 0; i < kNumModularPredictors; i++) { |
551 | 0 | predictors.push_back(static_cast<Predictor>(i)); |
552 | 0 | } |
553 | 0 | std::swap(predictors[0], predictors[static_cast<int>(Predictor::Weighted)]); |
554 | 0 | std::swap(predictors[1], predictors[static_cast<int>(Predictor::Gradient)]); |
555 | 74 | } else if (predictor == Predictor::Best) { |
556 | 0 | predictors = {Predictor::Weighted, Predictor::Gradient}; |
557 | 74 | } else { |
558 | 74 | predictors = {predictor}; |
559 | 74 | } |
560 | 74 | if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) { |
561 | 0 | predictors.erase( |
562 | 0 | std::remove(predictors.begin(), predictors.end(), Predictor::Weighted), |
563 | 0 | predictors.end()); |
564 | 0 | } |
565 | 74 | residuals.resize(predictors.size()); |
566 | 74 | return true; |
567 | 74 | } |
568 | | |
569 | | Status TreeSamples::SetProperties(const std::vector<uint32_t> &properties, |
570 | 74 | ModularOptions::TreeMode wp_tree_mode) { |
571 | 74 | props_to_use = properties; |
572 | 74 | if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) { |
573 | 0 | props_to_use = {static_cast<uint32_t>(kWPProp)}; |
574 | 0 | } |
575 | 74 | if (wp_tree_mode == ModularOptions::TreeMode::kGradientOnly) { |
576 | 0 | props_to_use = {static_cast<uint32_t>(kGradientProp)}; |
577 | 0 | } |
578 | 74 | if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) { |
579 | 0 | props_to_use.erase( |
580 | 0 | std::remove(props_to_use.begin(), props_to_use.end(), kWPProp), |
581 | 0 | props_to_use.end()); |
582 | 0 | } |
583 | 74 | if (props_to_use.empty()) { |
584 | 0 | return JXL_FAILURE("Invalid property set configuration"); |
585 | 0 | } |
586 | 74 | num_static_props = 0; |
587 | | // Check that if static properties present, then those are at the beginning. |
588 | 592 | for (size_t i = 0; i < props_to_use.size(); ++i) { |
589 | 518 | uint32_t prop = props_to_use[i]; |
590 | 518 | if (prop < kNumStaticProperties) { |
591 | 136 | JXL_DASSERT(i == prop); |
592 | 136 | num_static_props++; |
593 | 136 | } |
594 | 518 | } |
595 | 74 | props.resize(props_to_use.size() - num_static_props); |
596 | 74 | return true; |
597 | 74 | } |
598 | | |
599 | 201 | void TreeSamples::InitTable(size_t log_size) { |
600 | 201 | size_t size = 1ULL << log_size; |
601 | 201 | if (dedup_table_.size() == size) return; |
602 | 112 | dedup_table_.resize(size, kDedupEntryUnused); |
603 | 51.7k | for (size_t i = 0; i < NumDistinctSamples(); i++) { |
604 | 51.6k | if (sample_counts[i] != std::numeric_limits<uint16_t>::max()) { |
605 | 51.6k | AddToTable(i); |
606 | 51.6k | } |
607 | 51.6k | } |
608 | 112 | } |
609 | | |
610 | 593k | bool TreeSamples::AddToTableAndMerge(size_t a) { |
611 | 593k | size_t pos1 = Hash1(a); |
612 | 593k | size_t pos2 = Hash2(a); |
613 | 593k | if (dedup_table_[pos1] != kDedupEntryUnused && |
614 | 593k | IsSameSample(a, dedup_table_[pos1])) { |
615 | 360k | JXL_DASSERT(sample_counts[a] == 1); |
616 | 360k | sample_counts[dedup_table_[pos1]]++; |
617 | | // Remove from hash table samples that are saturated. |
618 | 360k | if (sample_counts[dedup_table_[pos1]] == |
619 | 360k | std::numeric_limits<uint16_t>::max()) { |
620 | 0 | dedup_table_[pos1] = kDedupEntryUnused; |
621 | 0 | } |
622 | 360k | return true; |
623 | 360k | } |
624 | 233k | if (dedup_table_[pos2] != kDedupEntryUnused && |
625 | 233k | IsSameSample(a, dedup_table_[pos2])) { |
626 | 81.6k | JXL_DASSERT(sample_counts[a] == 1); |
627 | 81.6k | sample_counts[dedup_table_[pos2]]++; |
628 | | // Remove from hash table samples that are saturated. |
629 | 81.6k | if (sample_counts[dedup_table_[pos2]] == |
630 | 81.6k | std::numeric_limits<uint16_t>::max()) { |
631 | 0 | dedup_table_[pos2] = kDedupEntryUnused; |
632 | 0 | } |
633 | 81.6k | return true; |
634 | 81.6k | } |
635 | 152k | AddToTable(a); |
636 | 152k | return false; |
637 | 233k | } |
638 | | |
639 | 203k | void TreeSamples::AddToTable(size_t a) { |
640 | 203k | size_t pos1 = Hash1(a); |
641 | 203k | size_t pos2 = Hash2(a); |
642 | 203k | if (dedup_table_[pos1] == kDedupEntryUnused) { |
643 | 107k | dedup_table_[pos1] = a; |
644 | 107k | } else if (dedup_table_[pos2] == kDedupEntryUnused) { |
645 | 60.3k | dedup_table_[pos2] = a; |
646 | 60.3k | } |
647 | 203k | } |
648 | | |
649 | 201 | void TreeSamples::PrepareForSamples(size_t extra_num_samples) { |
650 | 201 | for (auto &res : residuals) { |
651 | 201 | res.reserve(res.size() + extra_num_samples); |
652 | 201 | } |
653 | 588 | for (size_t i = 0; i < num_static_props; ++i) { |
654 | 387 | static_props[i].reserve(static_props[i].size() + extra_num_samples); |
655 | 387 | } |
656 | 1.02k | for (auto &p : props) { |
657 | 1.02k | p.reserve(p.size() + extra_num_samples); |
658 | 1.02k | } |
659 | 201 | size_t total_num_samples = extra_num_samples + sample_counts.size(); |
660 | 201 | size_t next_size = CeilLog2Nonzero(total_num_samples * 3 / 2); |
661 | 201 | InitTable(next_size); |
662 | 201 | } |
663 | | |
664 | 797k | size_t TreeSamples::Hash1(size_t a) const { |
665 | 797k | constexpr uint64_t constant = 0x1e35a7bd; |
666 | 797k | uint64_t h = constant; |
667 | 797k | for (const auto &r : residuals) { |
668 | 797k | h = h * constant + r[a].tok; |
669 | 797k | h = h * constant + r[a].nbits; |
670 | 797k | } |
671 | 2.34M | for (size_t i = 0; i < num_static_props; ++i) { |
672 | 1.54M | h = h * constant + static_props[i][a]; |
673 | 1.54M | } |
674 | 4.03M | for (const auto &p : props) { |
675 | 4.03M | h = h * constant + p[a]; |
676 | 4.03M | } |
677 | 797k | return (h >> 16) & (dedup_table_.size() - 1); |
678 | 797k | } |
679 | 797k | size_t TreeSamples::Hash2(size_t a) const { |
680 | 797k | constexpr uint64_t constant = 0x1e35a7bd1e35a7bd; |
681 | 797k | uint64_t h = constant; |
682 | 2.34M | for (size_t i = 0; i < num_static_props; ++i) { |
683 | 1.54M | h = h * constant ^ static_props[i][a]; |
684 | 1.54M | } |
685 | 4.03M | for (const auto &p : props) { |
686 | 4.03M | h = h * constant ^ p[a]; |
687 | 4.03M | } |
688 | 797k | for (const auto &r : residuals) { |
689 | 797k | h = h * constant ^ r[a].tok; |
690 | 797k | h = h * constant ^ r[a].nbits; |
691 | 797k | } |
692 | 797k | return (h >> 16) & (dedup_table_.size() - 1); |
693 | 797k | } |
694 | | |
695 | 625k | bool TreeSamples::IsSameSample(size_t a, size_t b) const { |
696 | 625k | bool ret = true; |
697 | 625k | for (const auto &r : residuals) { |
698 | 625k | if (r[a].tok != r[b].tok) { |
699 | 82.5k | ret = false; |
700 | 82.5k | } |
701 | 625k | if (r[a].nbits != r[b].nbits) { |
702 | 62.4k | ret = false; |
703 | 62.4k | } |
704 | 625k | } |
705 | 1.82M | for (size_t i = 0; i < num_static_props; ++i) { |
706 | 1.20M | if (static_props[i][a] != static_props[i][b]) { |
707 | 72.4k | ret = false; |
708 | 72.4k | } |
709 | 1.20M | } |
710 | 3.17M | for (const auto &p : props) { |
711 | 3.17M | if (p[a] != p[b]) { |
712 | 490k | ret = false; |
713 | 490k | } |
714 | 3.17M | } |
715 | 625k | return ret; |
716 | 625k | } |
717 | | |
718 | | void TreeSamples::AddSample(pixel_type_w pixel, const Properties &properties, |
719 | 593k | const pixel_type_w *predictions) { |
720 | 1.18M | for (size_t i = 0; i < predictors.size(); i++) { |
721 | 593k | pixel_type v = pixel - predictions[static_cast<int>(predictors[i])]; |
722 | 593k | uint32_t tok, nbits, bits; |
723 | 593k | HybridUintConfig(4, 1, 2).Encode(PackSigned(v), &tok, &nbits, &bits); |
724 | 593k | JXL_DASSERT(tok < 256); |
725 | 593k | JXL_DASSERT(nbits < 256); |
726 | 593k | residuals[i].emplace_back( |
727 | 593k | ResidualToken{static_cast<uint8_t>(tok), static_cast<uint8_t>(nbits)}); |
728 | 593k | } |
729 | 1.73M | for (size_t i = 0; i < num_static_props; ++i) { |
730 | 1.13M | static_props[i].push_back(QuantizeStaticProperty(i, properties[i])); |
731 | 1.13M | } |
732 | 3.61M | for (size_t i = num_static_props; i < props_to_use.size(); i++) { |
733 | 3.01M | props[i - num_static_props].push_back(QuantizeProperty(i, properties[props_to_use[i]])); |
734 | 3.01M | } |
735 | 593k | sample_counts.push_back(1); |
736 | 593k | num_samples++; |
737 | 593k | if (AddToTableAndMerge(sample_counts.size() - 1)) { |
738 | 441k | for (auto &r : residuals) r.pop_back(); |
739 | 1.27M | for (size_t i = 0; i < num_static_props; ++i) static_props[i].pop_back(); |
740 | 2.25M | for (auto &p : props) p.pop_back(); |
741 | 441k | sample_counts.pop_back(); |
742 | 441k | } |
743 | 593k | } |
744 | | |
745 | 192k | void TreeSamples::Swap(size_t a, size_t b) { |
746 | 192k | if (a == b) return; |
747 | 192k | for (auto &r : residuals) { |
748 | 192k | std::swap(r[a], r[b]); |
749 | 192k | } |
750 | 578k | for (size_t i = 0; i < num_static_props; ++i) { |
751 | 385k | std::swap(static_props[i][a], static_props[i][b]); |
752 | 385k | } |
753 | 964k | for (auto &p : props) { |
754 | 964k | std::swap(p[a], p[b]); |
755 | 964k | } |
756 | 192k | std::swap(sample_counts[a], sample_counts[b]); |
757 | 192k | } |
758 | | |
759 | | namespace { |
760 | | std::vector<int32_t> QuantizeHistogram(const std::vector<uint32_t> &histogram, |
761 | 240 | size_t num_chunks) { |
762 | 240 | if (histogram.empty() || num_chunks == 0) return {}; |
763 | 240 | uint64_t sum = std::accumulate(histogram.begin(), histogram.end(), 0LU); |
764 | 240 | if (sum == 0) return {}; |
765 | | // TODO(veluca): selecting distinct quantiles is likely not the best |
766 | | // way to go about this. |
767 | 238 | std::vector<int32_t> thresholds; |
768 | 238 | uint64_t cumsum = 0; |
769 | 238 | uint64_t threshold = 1; |
770 | 107k | for (size_t i = 0; i < histogram.size(); i++) { |
771 | 106k | cumsum += histogram[i]; |
772 | 106k | if (cumsum * num_chunks >= threshold * sum) { |
773 | 1.22k | thresholds.push_back(i); |
774 | 12.6k | while (cumsum * num_chunks >= threshold * sum) threshold++; |
775 | 1.22k | } |
776 | 106k | } |
777 | 238 | JXL_DASSERT(thresholds.size() <= num_chunks); |
778 | | // last value collects all histogram and is not really a threshold |
779 | 238 | thresholds.pop_back(); |
780 | 238 | return thresholds; |
781 | 240 | } |
782 | | |
783 | | std::vector<int32_t> QuantizeSamples(const std::vector<int32_t> &samples, |
784 | 114 | size_t num_chunks) { |
785 | 114 | if (samples.empty()) return {}; |
786 | 104 | int min = *std::min_element(samples.begin(), samples.end()); |
787 | 104 | constexpr int kRange = 512; |
788 | 104 | min = jxl::Clamp1(min, -kRange, kRange); |
789 | 104 | std::vector<uint32_t> counts(2 * kRange + 1); |
790 | 74.6k | for (int s : samples) { |
791 | 74.6k | uint32_t sample_offset = jxl::Clamp1(s, -kRange, kRange) - min; |
792 | 74.6k | counts[sample_offset]++; |
793 | 74.6k | } |
794 | 104 | std::vector<int32_t> thresholds = QuantizeHistogram(counts, num_chunks); |
795 | 859 | for (auto &v : thresholds) v += min; |
796 | 104 | return thresholds; |
797 | 114 | } |
798 | | |
799 | | // `to[i]` is assigned value `v` conforming `from[v] <= i && from[v-1] > i`. |
800 | | // This is because the decision node in the tree splits on (property) > i, |
801 | | // hence everything that is not > of a threshold should be clustered |
802 | | // together. |
803 | | template <typename T> |
804 | | void QuantMap(const std::vector<int32_t> &from, std::vector<T> &to, |
805 | 518 | size_t num_pegs, int bias) { |
806 | 518 | to.resize(num_pegs); |
807 | 518 | size_t mapped = 0; |
808 | 530k | for (size_t i = 0; i < num_pegs; i++) { |
809 | 535k | while (mapped < from.size() && static_cast<int>(i) - bias > from[mapped]) { |
810 | 5.72k | mapped++; |
811 | 5.72k | } |
812 | 529k | JXL_DASSERT(static_cast<T>(mapped) == mapped); |
813 | 529k | to[i] = mapped; |
814 | 529k | } |
815 | 518 | } enc_ma.cc:void jxl::(anonymous namespace)::QuantMap<unsigned short>(std::__1::vector<int, std::__1::allocator<int> > const&, std::__1::vector<unsigned short, std::__1::allocator<unsigned short> >&, unsigned long, int) Line | Count | Source | 805 | 136 | size_t num_pegs, int bias) { | 806 | 136 | to.resize(num_pegs); | 807 | 136 | size_t mapped = 0; | 808 | 139k | for (size_t i = 0; i < num_pegs; i++) { | 809 | 139k | while (mapped < from.size() && static_cast<int>(i) - bias > from[mapped]) { | 810 | 125 | mapped++; | 811 | 125 | } | 812 | 139k | JXL_DASSERT(static_cast<T>(mapped) == mapped); | 813 | 139k | to[i] = mapped; | 814 | 139k | } | 815 | 136 | } |
enc_ma.cc:void jxl::(anonymous namespace)::QuantMap<unsigned char>(std::__1::vector<int, std::__1::allocator<int> > const&, std::__1::vector<unsigned char, std::__1::allocator<unsigned char> >&, unsigned long, int) Line | Count | Source | 805 | 382 | size_t num_pegs, int bias) { | 806 | 382 | to.resize(num_pegs); | 807 | 382 | size_t mapped = 0; | 808 | 391k | for (size_t i = 0; i < num_pegs; i++) { | 809 | 396k | while (mapped < from.size() && static_cast<int>(i) - bias > from[mapped]) { | 810 | 5.59k | mapped++; | 811 | 5.59k | } | 812 | 390k | JXL_DASSERT(static_cast<T>(mapped) == mapped); | 813 | 390k | to[i] = mapped; | 814 | 390k | } | 815 | 382 | } |
|
816 | | } // namespace |
817 | | |
// Precomputes per-property quantization data used while collecting samples:
// `compact_properties[i]` holds the split thresholds for property i, and
// `static_property_mapping` / `property_mapping` hold dense lookup tables
// (built by QuantMap) from a biased property value to its bucket.
//
// Thresholds come from three sources: forced multiplier boundaries
// (`multiplier_info` vs `range`) for channel/group, histograms of pixel
// counts, or quantiles of the collected pixel/diff samples, all capped at
// `max_property_values` buckets.
// NOTE(review): `pixel_samples`/`diff_samples` are non-const on purpose —
// the abs-variants below overwrite them in place with absolute values.
void TreeSamples::PreQuantizeProperties(
    const StaticPropRange &range,
    const std::vector<ModularMultiplierInfo> &multiplier_info,
    const std::vector<uint32_t> &group_pixel_count,
    const std::vector<uint32_t> &channel_pixel_count,
    std::vector<pixel_type> &pixel_samples,
    std::vector<pixel_type> &diff_samples, size_t max_property_values) {
  // If we have forced splits because of multipliers, choose channel and group
  // thresholds accordingly.
  std::vector<int32_t> group_multiplier_thresholds;
  std::vector<int32_t> channel_multiplier_thresholds;
  for (const auto &v : multiplier_info) {
    // A multiplier sub-range that does not start/end at the full range's
    // boundary forces a split one value before its boundary (splits are on
    // "property > threshold").
    if (v.range[0][0] != range[0][0]) {
      channel_multiplier_thresholds.push_back(v.range[0][0] - 1);
    }
    if (v.range[0][1] != range[0][1]) {
      channel_multiplier_thresholds.push_back(v.range[0][1] - 1);
    }
    if (v.range[1][0] != range[1][0]) {
      group_multiplier_thresholds.push_back(v.range[1][0] - 1);
    }
    if (v.range[1][1] != range[1][1]) {
      group_multiplier_thresholds.push_back(v.range[1][1] - 1);
    }
  }
  // Sort and deduplicate both forced-threshold lists.
  std::sort(channel_multiplier_thresholds.begin(),
            channel_multiplier_thresholds.end());
  channel_multiplier_thresholds.resize(
      std::unique(channel_multiplier_thresholds.begin(),
                  channel_multiplier_thresholds.end()) -
      channel_multiplier_thresholds.begin());
  std::sort(group_multiplier_thresholds.begin(),
            group_multiplier_thresholds.end());
  group_multiplier_thresholds.resize(
      std::unique(group_multiplier_thresholds.begin(),
                  group_multiplier_thresholds.end()) -
      group_multiplier_thresholds.begin());

  compact_properties.resize(props_to_use.size());
  // Channel property: forced multiplier splits win; otherwise split by
  // per-channel pixel counts.
  auto quantize_channel = [&]() {
    if (!channel_multiplier_thresholds.empty()) {
      return channel_multiplier_thresholds;
    }
    return QuantizeHistogram(channel_pixel_count, max_property_values);
  };
  // Group-id property: same scheme as channels, using per-group counts.
  auto quantize_group_id = [&]() {
    if (!group_multiplier_thresholds.empty()) {
      return group_multiplier_thresholds;
    }
    return QuantizeHistogram(group_pixel_count, max_property_values);
  };
  // Coordinates: evenly spaced thresholds over [0, 256).
  auto quantize_coordinate = [&]() {
    std::vector<int32_t> quantized;
    quantized.reserve(max_property_values - 1);
    for (size_t i = 0; i + 1 < max_property_values; i++) {
      quantized.push_back((i + 1) * 256 / max_property_values - 1);
    }
    return quantized;
  };
  std::vector<int32_t> abs_pixel_thresholds;
  std::vector<int32_t> pixel_thresholds;
  // Pixel-valued properties: quantiles of the sampled pixel values
  // (computed lazily, shared across properties).
  auto quantize_pixel_property = [&]() {
    if (pixel_thresholds.empty()) {
      pixel_thresholds = QuantizeSamples(pixel_samples, max_property_values);
    }
    return pixel_thresholds;
  };
  // Absolute-pixel properties: note this mutates pixel_samples in place.
  auto quantize_abs_pixel_property = [&]() {
    if (abs_pixel_thresholds.empty()) {
      quantize_pixel_property();  // Compute the non-abs thresholds.
      for (auto &v : pixel_samples) v = std::abs(v);
      abs_pixel_thresholds =
          QuantizeSamples(pixel_samples, max_property_values);
    }
    return abs_pixel_thresholds;
  };
  std::vector<int32_t> abs_diff_thresholds;
  std::vector<int32_t> diff_thresholds;
  // Difference-valued properties: quantiles of sampled horizontal diffs.
  auto quantize_diff_property = [&]() {
    if (diff_thresholds.empty()) {
      diff_thresholds = QuantizeSamples(diff_samples, max_property_values);
    }
    return diff_thresholds;
  };
  // Absolute-diff properties: mutates diff_samples in place, like above.
  auto quantize_abs_diff_property = [&]() {
    if (abs_diff_thresholds.empty()) {
      quantize_diff_property();  // Compute the non-abs thresholds.
      for (auto &v : diff_samples) v = std::abs(v);
      abs_diff_thresholds = QuantizeSamples(diff_samples, max_property_values);
    }
    return abs_diff_thresholds;
  };
  // Weighted-predictor error property: fixed pseudo-logarithmic threshold
  // tables, coarser when fewer buckets are allowed.
  auto quantize_wp = [&]() {
    if (max_property_values < 32) {
      return std::vector<int32_t>{-127, -63, -31, -15, -7, -3, -1, 0,
                                  1,    3,   7,   15,  31, 63, 127};
    }
    if (max_property_values < 64) {
      return std::vector<int32_t>{-255, -191, -127, -95, -63, -47, -31, -23,
                                  -15,  -11,  -7,   -5,  -3,  -1,  0,   1,
                                  3,    5,    7,    11,  15,  23,  31,  47,
                                  63,   95,   127,  191, 255};
    }
    return std::vector<int32_t>{
        -255, -223, -191, -159, -127, -111, -95, -79, -63, -55, -47,
        -39,  -31,  -27,  -23,  -19,  -15,  -13, -11, -9,  -7,  -6,
        -5,   -4,   -3,   -2,   -1,   0,    1,   2,   3,   4,   5,
        6,    7,    9,    11,   13,   15,   19,  23,  27,  31,  39,
        47,   55,   63,   79,   95,   111,  127, 159, 191, 223, 255};
  };

  property_mapping.resize(props_to_use.size() - num_static_props);
  // Dispatch each selected property to its quantizer. The numeric cases
  // match the modular property ordering (0=channel, 1=group id, 2/3=
  // coordinates, kWPProp=WP error); indices >= kNumNonrefProperties are
  // reference-channel properties, distinguished by their offset mod 4 —
  // presumably pixel-like (1), abs-pixel-like (0), abs-diff-like (2) —
  // TODO confirm against the property definitions.
  for (size_t i = 0; i < props_to_use.size(); i++) {
    if (props_to_use[i] == 0) {
      compact_properties[i] = quantize_channel();
    } else if (props_to_use[i] == 1) {
      compact_properties[i] = quantize_group_id();
    } else if (props_to_use[i] == 2 || props_to_use[i] == 3) {
      compact_properties[i] = quantize_coordinate();
    } else if (props_to_use[i] == 6 || props_to_use[i] == 7 ||
               props_to_use[i] == 8 ||
               (props_to_use[i] >= kNumNonrefProperties &&
                (props_to_use[i] - kNumNonrefProperties) % 4 == 1)) {
      compact_properties[i] = quantize_pixel_property();
    } else if (props_to_use[i] == 4 || props_to_use[i] == 5 ||
               (props_to_use[i] >= kNumNonrefProperties &&
                (props_to_use[i] - kNumNonrefProperties) % 4 == 0)) {
      compact_properties[i] = quantize_abs_pixel_property();
    } else if (props_to_use[i] >= kNumNonrefProperties &&
               (props_to_use[i] - kNumNonrefProperties) % 4 == 2) {
      compact_properties[i] = quantize_abs_diff_property();
    } else if (props_to_use[i] == kWPProp) {
      compact_properties[i] = quantize_wp();
    } else {
      compact_properties[i] = quantize_diff_property();
    }
    // Expand the thresholds into a dense lookup table over the biased
    // property range; static and dynamic properties go to separate tables.
    if (i < num_static_props) {
      QuantMap(compact_properties[i], static_property_mapping[i],
               kPropertyRange * 2 + 1, kPropertyRange);
    } else {
      QuantMap(compact_properties[i], property_mapping[i - num_static_props],
               kPropertyRange * 2 + 1, kPropertyRange);
    }
  }
}
963 | | |
// Accumulates, for one modular group, the per-group and per-channel pixel
// counts and a random subsample of pixel values and horizontal differences;
// these later feed TreeSamples::PreQuantizeProperties. Channels past the
// meta channels that exceed options.max_chan_size stop the scan, and
// degenerate (empty or width-1) channels are skipped.
void CollectPixelSamples(const Image &image, const ModularOptions &options,
                         uint32_t group_id,
                         std::vector<uint32_t> &group_pixel_count,
                         std::vector<uint32_t> &channel_pixel_count,
                         std::vector<pixel_type> &pixel_samples,
                         std::vector<pixel_type> &diff_samples) {
  if (options.nb_repeats == 0) return;
  // Grow the count arrays on demand so this can be called per group.
  if (group_pixel_count.size() <= group_id) {
    group_pixel_count.resize(group_id + 1);
  }
  if (channel_pixel_count.size() < image.channel.size()) {
    channel_pixel_count.resize(image.channel.size());
  }
  // Seeding with the group id keeps sampling deterministic per group.
  Rng rng(group_id);
  // Sample 10% of the final number of samples for property quantization.
  float fraction = std::min(options.nb_repeats * 0.1, 0.99);
  Rng::GeometricDistribution dist = Rng::MakeGeometric(fraction);
  size_t total_pixels = 0;
  std::vector<size_t> channel_ids;
  // Select the channels to sample and tally their pixel counts.
  for (size_t i = 0; i < image.channel.size(); i++) {
    if (i >= image.nb_meta_channels &&
        (image.channel[i].w > options.max_chan_size ||
         image.channel[i].h > options.max_chan_size)) {
      break;
    }
    if (image.channel[i].w <= 1 || image.channel[i].h == 0) {
      continue;  // skip empty or width-1 channels.
    }
    channel_ids.push_back(i);
    group_pixel_count[group_id] += image.channel[i].w * image.channel[i].h;
    channel_pixel_count[i] += image.channel[i].w * image.channel[i].h;
    total_pixels += image.channel[i].w * image.channel[i].h;
  }
  if (channel_ids.empty()) return;
  pixel_samples.reserve(pixel_samples.size() + fraction * total_pixels);
  diff_samples.reserve(diff_samples.size() + fraction * total_pixels);
  // Walk all selected channels as one long pixel sequence; (i, y, x) is the
  // current position (channel index into channel_ids, row, column).
  size_t i = 0;
  size_t y = 0;
  size_t x = 0;
  // Moves the cursor forward by `amount` pixels, wrapping across rows and
  // channels; leaves i == channel_ids.size() when the sequence is exhausted.
  auto advance = [&](size_t amount) {
    x += amount;
    // Detect row overflow (rare).
    while (x >= image.channel[channel_ids[i]].w) {
      x -= image.channel[channel_ids[i]].w;
      y++;
      // Detect end-of-channel (even rarer).
      if (y == image.channel[channel_ids[i]].h) {
        i++;
        y = 0;
        if (i >= channel_ids.size()) {
          return;
        }
      }
    }
  };
  // Geometric gaps between samples yield an expected sampling rate of
  // roughly `fraction`.
  advance(rng.Geometric(dist));
  for (; i < channel_ids.size(); advance(rng.Geometric(dist) + 1)) {
    const pixel_type *row = image.channel[channel_ids[i]].Row(y);
    pixel_samples.push_back(row[x]);
    // At column 0 use the right neighbor instead of the left one; width > 1
    // is guaranteed because width-1 channels were skipped above.
    size_t xp = x == 0 ? 1 : x - 1;
    diff_samples.push_back(static_cast<int64_t>(row[x]) - row[xp]);
  }
}
1027 | | |
// TODO(veluca): very simple encoding scheme. This should be improved.
// Serializes `tree` into `tokens` in breadth-first order and simultaneously
// builds `decoder_tree`, the BFS-ordered representation the decoder will
// reconstruct. Inner nodes emit (property + 1, splitval); leaves emit
// property 0 (i.e. -1 + 1) followed by predictor, packed predictor offset,
// and the multiplier split into its power-of-two log and remaining bits.
// Child indices in `decoder_tree` are derived from the BFS queue length, so
// token order here must match the decoder's reading order exactly.
Status TokenizeTree(const Tree &tree, std::vector<Token> *tokens,
                    Tree *decoder_tree) {
  JXL_ENSURE(tree.size() <= kMaxTreeSize);
  std::queue<int> q;
  q.push(0);  // start BFS at the root
  size_t leaf_id = 0;
  decoder_tree->clear();
  while (!q.empty()) {
    int cur = q.front();
    q.pop();
    JXL_ENSURE(tree[cur].property >= -1);
    // property is stored off-by-one so that leaves (-1) encode as 0.
    tokens->emplace_back(kPropertyContext, tree[cur].property + 1);
    if (tree[cur].property == -1) {
      // Leaf: emit predictor, offset, and multiplier (log + mantissa bits).
      tokens->emplace_back(kPredictorContext,
                           static_cast<int>(tree[cur].predictor));
      tokens->emplace_back(kOffsetContext,
                           PackSigned(tree[cur].predictor_offset));
      uint32_t mul_log = Num0BitsBelowLS1Bit_Nonzero(tree[cur].multiplier);
      uint32_t mul_bits = (tree[cur].multiplier >> mul_log) - 1;
      tokens->emplace_back(kMultiplierLogContext, mul_log);
      tokens->emplace_back(kMultiplierBitsContext, mul_bits);
      JXL_ENSURE(tree[cur].predictor < Predictor::Best);
      decoder_tree->emplace_back(-1, 0, leaf_id++, 0, tree[cur].predictor,
                                 tree[cur].predictor_offset,
                                 tree[cur].multiplier);
      continue;
    }
    // Inner node: children will occupy the next two BFS slots after the
    // nodes already emitted plus those still queued.
    decoder_tree->emplace_back(tree[cur].property, tree[cur].splitval,
                               decoder_tree->size() + q.size() + 1,
                               decoder_tree->size() + q.size() + 2,
                               Predictor::Zero, 0, 1);
    q.push(tree[cur].lchild);
    q.push(tree[cur].rchild);
    tokens->emplace_back(kSplitValContext, PackSigned(tree[cur].splitval));
  }
  return true;
}
1066 | | |
1067 | | } // namespace jxl |
1068 | | #endif // HWY_ONCE |