/src/libjxl/lib/jxl/modular/encoding/enc_ma.cc
Line | Count | Source |
1 | | // Copyright (c) the JPEG XL Project Authors. All rights reserved. |
2 | | // |
3 | | // Use of this source code is governed by a BSD-style |
4 | | // license that can be found in the LICENSE file. |
5 | | |
6 | | #include "lib/jxl/modular/encoding/enc_ma.h" |
7 | | |
8 | | #include <algorithm> |
9 | | #include <cstdint> |
10 | | #include <cstdlib> |
11 | | #include <cstring> |
12 | | #include <limits> |
13 | | #include <numeric> |
14 | | #include <queue> |
15 | | #include <vector> |
16 | | |
17 | | #include "lib/jxl/ans_params.h" |
18 | | #include "lib/jxl/base/bits.h" |
19 | | #include "lib/jxl/base/common.h" |
20 | | #include "lib/jxl/base/compiler_specific.h" |
21 | | #include "lib/jxl/base/status.h" |
22 | | #include "lib/jxl/dec_ans.h" |
23 | | #include "lib/jxl/modular/encoding/dec_ma.h" |
24 | | #include "lib/jxl/modular/encoding/ma_common.h" |
25 | | #include "lib/jxl/modular/modular_image.h" |
26 | | |
27 | | #undef HWY_TARGET_INCLUDE |
28 | | #define HWY_TARGET_INCLUDE "lib/jxl/modular/encoding/enc_ma.cc" |
29 | | #include <hwy/foreach_target.h> |
30 | | #include <hwy/highway.h> |
31 | | |
32 | | #include "lib/jxl/base/fast_math-inl.h" |
33 | | #include "lib/jxl/base/random.h" |
34 | | #include "lib/jxl/enc_ans.h" |
35 | | #include "lib/jxl/modular/encoding/context_predict.h" |
36 | | #include "lib/jxl/modular/options.h" |
37 | | #include "lib/jxl/pack_signed.h" |
38 | | HWY_BEFORE_NAMESPACE(); |
39 | | namespace jxl { |
40 | | namespace HWY_NAMESPACE { |
41 | | |
42 | | // These templates are not found via ADL. |
43 | | using hwy::HWY_NAMESPACE::Eq; |
44 | | using hwy::HWY_NAMESPACE::IfThenElse; |
45 | | using hwy::HWY_NAMESPACE::Lt; |
46 | | using hwy::HWY_NAMESPACE::Max; |
47 | | |
48 | | const HWY_FULL(float) df; |
49 | | const HWY_FULL(int32_t) di; |
50 | 346k | size_t Padded(size_t x) { return RoundUpTo(x, Lanes(df)); }Unexecuted instantiation: jxl::N_SSE4::Padded(unsigned long) jxl::N_AVX2::Padded(unsigned long) Line | Count | Source | 50 | 346k | size_t Padded(size_t x) { return RoundUpTo(x, Lanes(df)); } |
Unexecuted instantiation: jxl::N_SSE2::Padded(unsigned long) |
51 | | |
52 | | // Compute entropy of the histogram, taking into account the minimum probability |
53 | | // for symbols with non-zero counts. |
54 | 26.8M | float EstimateBits(const int32_t *counts, size_t num_symbols) { |
55 | 26.8M | int32_t total = std::accumulate(counts, counts + num_symbols, 0); |
56 | 26.8M | const auto zero = Zero(df); |
57 | 26.8M | const auto minprob = Set(df, 1.0f / ANS_TAB_SIZE); |
58 | 26.8M | const auto inv_total = Set(df, 1.0f / total); |
59 | 26.8M | auto bits_lanes = Zero(df); |
60 | 26.8M | auto total_v = Set(di, total); |
61 | 194M | for (size_t i = 0; i < num_symbols; i += Lanes(df)) { |
62 | 167M | const auto counts_iv = LoadU(di, &counts[i]); |
63 | 167M | const auto counts_fv = ConvertTo(df, counts_iv); |
64 | 167M | const auto probs = Mul(counts_fv, inv_total); |
65 | 167M | const auto mprobs = Max(probs, minprob); |
66 | 167M | const auto nbps = IfThenElse(Eq(counts_iv, total_v), BitCast(di, zero), |
67 | 167M | BitCast(di, FastLog2f(df, mprobs))); |
68 | 167M | bits_lanes = Sub(bits_lanes, Mul(counts_fv, BitCast(df, nbps))); |
69 | 167M | } |
70 | 26.8M | return GetLane(SumOfLanes(df, bits_lanes)); |
71 | 26.8M | } Unexecuted instantiation: jxl::N_SSE4::EstimateBits(int const*, unsigned long) jxl::N_AVX2::EstimateBits(int const*, unsigned long) Line | Count | Source | 54 | 26.8M | float EstimateBits(const int32_t *counts, size_t num_symbols) { | 55 | 26.8M | int32_t total = std::accumulate(counts, counts + num_symbols, 0); | 56 | 26.8M | const auto zero = Zero(df); | 57 | 26.8M | const auto minprob = Set(df, 1.0f / ANS_TAB_SIZE); | 58 | 26.8M | const auto inv_total = Set(df, 1.0f / total); | 59 | 26.8M | auto bits_lanes = Zero(df); | 60 | 26.8M | auto total_v = Set(di, total); | 61 | 194M | for (size_t i = 0; i < num_symbols; i += Lanes(df)) { | 62 | 167M | const auto counts_iv = LoadU(di, &counts[i]); | 63 | 167M | const auto counts_fv = ConvertTo(df, counts_iv); | 64 | 167M | const auto probs = Mul(counts_fv, inv_total); | 65 | 167M | const auto mprobs = Max(probs, minprob); | 66 | 167M | const auto nbps = IfThenElse(Eq(counts_iv, total_v), BitCast(di, zero), | 67 | 167M | BitCast(di, FastLog2f(df, mprobs))); | 68 | 167M | bits_lanes = Sub(bits_lanes, Mul(counts_fv, BitCast(df, nbps))); | 69 | 167M | } | 70 | 26.8M | return GetLane(SumOfLanes(df, bits_lanes)); | 71 | 26.8M | } |
Unexecuted instantiation: jxl::N_SSE2::EstimateBits(int const*, unsigned long) |
72 | | |
73 | | void MakeSplitNode(size_t pos, int property, int splitval, Predictor lpred, |
74 | 172k | int64_t loff, Predictor rpred, int64_t roff, Tree *tree) { |
75 | | // Note that the tree splits on *strictly greater*. |
76 | 172k | (*tree)[pos].lchild = tree->size(); |
77 | 172k | (*tree)[pos].rchild = tree->size() + 1; |
78 | 172k | (*tree)[pos].splitval = splitval; |
79 | 172k | (*tree)[pos].property = property; |
80 | 172k | tree->emplace_back(); |
81 | 172k | tree->back().property = -1; |
82 | 172k | tree->back().predictor = rpred; |
83 | 172k | tree->back().predictor_offset = roff; |
84 | 172k | tree->back().multiplier = 1; |
85 | 172k | tree->emplace_back(); |
86 | 172k | tree->back().property = -1; |
87 | 172k | tree->back().predictor = lpred; |
88 | 172k | tree->back().predictor_offset = loff; |
89 | 172k | tree->back().multiplier = 1; |
90 | 172k | } Unexecuted instantiation: jxl::N_SSE4::MakeSplitNode(unsigned long, int, int, jxl::Predictor, long, jxl::Predictor, long, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*) jxl::N_AVX2::MakeSplitNode(unsigned long, int, int, jxl::Predictor, long, jxl::Predictor, long, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*) Line | Count | Source | 74 | 172k | int64_t loff, Predictor rpred, int64_t roff, Tree *tree) { | 75 | | // Note that the tree splits on *strictly greater*. | 76 | 172k | (*tree)[pos].lchild = tree->size(); | 77 | 172k | (*tree)[pos].rchild = tree->size() + 1; | 78 | 172k | (*tree)[pos].splitval = splitval; | 79 | 172k | (*tree)[pos].property = property; | 80 | 172k | tree->emplace_back(); | 81 | 172k | tree->back().property = -1; | 82 | 172k | tree->back().predictor = rpred; | 83 | 172k | tree->back().predictor_offset = roff; | 84 | 172k | tree->back().multiplier = 1; | 85 | 172k | tree->emplace_back(); | 86 | 172k | tree->back().property = -1; | 87 | 172k | tree->back().predictor = lpred; | 88 | 172k | tree->back().predictor_offset = loff; | 89 | 172k | tree->back().multiplier = 1; | 90 | 172k | } |
Unexecuted instantiation: jxl::N_SSE2::MakeSplitNode(unsigned long, int, int, jxl::Predictor, long, jxl::Predictor, long, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*) |
91 | | |
92 | | enum class IntersectionType { kNone, kPartial, kInside }; |
93 | | IntersectionType BoxIntersects(StaticPropRange needle, StaticPropRange haystack, |
94 | 84.8k | uint32_t &partial_axis, uint32_t &partial_val) { |
95 | 84.8k | bool partial = false; |
96 | 254k | for (size_t i = 0; i < kNumStaticProperties; i++) { |
97 | 169k | if (haystack[i][0] >= needle[i][1]) { |
98 | 0 | return IntersectionType::kNone; |
99 | 0 | } |
100 | 169k | if (haystack[i][1] <= needle[i][0]) { |
101 | 0 | return IntersectionType::kNone; |
102 | 0 | } |
103 | 169k | if (haystack[i][0] <= needle[i][0] && haystack[i][1] >= needle[i][1]) { |
104 | 169k | continue; |
105 | 169k | } |
106 | 0 | partial = true; |
107 | 0 | partial_axis = i; |
108 | 0 | if (haystack[i][0] > needle[i][0] && haystack[i][0] < needle[i][1]) { |
109 | 0 | partial_val = haystack[i][0] - 1; |
110 | 0 | } else { |
111 | 0 | JXL_DASSERT(haystack[i][1] > needle[i][0] && |
112 | 0 | haystack[i][1] < needle[i][1]); |
113 | 0 | partial_val = haystack[i][1] - 1; |
114 | 0 | } |
115 | 0 | } |
116 | 84.8k | return partial ? IntersectionType::kPartial : IntersectionType::kInside; |
117 | 84.8k | } Unexecuted instantiation: jxl::N_SSE4::BoxIntersects(std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, unsigned int&, unsigned int&) jxl::N_AVX2::BoxIntersects(std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, unsigned int&, unsigned int&) Line | Count | Source | 94 | 84.8k | uint32_t &partial_axis, uint32_t &partial_val) { | 95 | 84.8k | bool partial = false; | 96 | 254k | for (size_t i = 0; i < kNumStaticProperties; i++) { | 97 | 169k | if (haystack[i][0] >= needle[i][1]) { | 98 | 0 | return IntersectionType::kNone; | 99 | 0 | } | 100 | 169k | if (haystack[i][1] <= needle[i][0]) { | 101 | 0 | return IntersectionType::kNone; | 102 | 0 | } | 103 | 169k | if (haystack[i][0] <= needle[i][0] && haystack[i][1] >= needle[i][1]) { | 104 | 169k | continue; | 105 | 169k | } | 106 | 0 | partial = true; | 107 | 0 | partial_axis = i; | 108 | 0 | if (haystack[i][0] > needle[i][0] && haystack[i][0] < needle[i][1]) { | 109 | 0 | partial_val = haystack[i][0] - 1; | 110 | 0 | } else { | 111 | 0 | JXL_DASSERT(haystack[i][1] > needle[i][0] && | 112 | 0 | haystack[i][1] < needle[i][1]); | 113 | 0 | partial_val = haystack[i][1] - 1; | 114 | 0 | } | 115 | 0 | } | 116 | 84.8k | return partial ? IntersectionType::kPartial : IntersectionType::kInside; | 117 | 84.8k | } |
Unexecuted instantiation: jxl::N_SSE2::BoxIntersects(std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, unsigned int&, unsigned int&) |
118 | | |
119 | | template<bool S> |
120 | | void SplitTreeSamples(TreeSamples &tree_samples, size_t begin, size_t pos, |
121 | 172k | size_t end, size_t prop, uint32_t val) { |
122 | 172k | size_t begin_pos = begin; |
123 | 172k | size_t end_pos = pos; |
124 | 9.76M | do { |
125 | 24.2M | while (begin_pos < pos && |
126 | 24.2M | tree_samples.Property<S>(prop, begin_pos) <= val) { |
127 | 14.5M | ++begin_pos; |
128 | 14.5M | } |
129 | 26.6M | while (end_pos < end && tree_samples.Property<S>(prop, end_pos) > val) { |
130 | 16.8M | ++end_pos; |
131 | 16.8M | } |
132 | 9.76M | if (begin_pos < pos && end_pos < end) { |
133 | 9.72M | tree_samples.Swap(begin_pos, end_pos); |
134 | 9.72M | } |
135 | 9.76M | ++begin_pos; |
136 | 9.76M | ++end_pos; |
137 | 9.76M | } while (begin_pos < pos && end_pos < end); |
138 | 172k | } Unexecuted instantiation: void jxl::N_SSE4::SplitTreeSamples<true>(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int) Unexecuted instantiation: void jxl::N_SSE4::SplitTreeSamples<false>(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int) void jxl::N_AVX2::SplitTreeSamples<true>(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int) Line | Count | Source | 121 | 6.75k | size_t end, size_t prop, uint32_t val) { | 122 | 6.75k | size_t begin_pos = begin; | 123 | 6.75k | size_t end_pos = pos; | 124 | 546k | do { | 125 | 2.87M | while (begin_pos < pos && | 126 | 2.87M | tree_samples.Property<S>(prop, begin_pos) <= val) { | 127 | 2.33M | ++begin_pos; | 128 | 2.33M | } | 129 | 1.87M | while (end_pos < end && tree_samples.Property<S>(prop, end_pos) > val) { | 130 | 1.33M | ++end_pos; | 131 | 1.33M | } | 132 | 546k | if (begin_pos < pos && end_pos < end) { | 133 | 543k | tree_samples.Swap(begin_pos, end_pos); | 134 | 543k | } | 135 | 546k | ++begin_pos; | 136 | 546k | ++end_pos; | 137 | 546k | } while (begin_pos < pos && end_pos < end); | 138 | 6.75k | } |
void jxl::N_AVX2::SplitTreeSamples<false>(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int) Line | Count | Source | 121 | 165k | size_t end, size_t prop, uint32_t val) { | 122 | 165k | size_t begin_pos = begin; | 123 | 165k | size_t end_pos = pos; | 124 | 9.21M | do { | 125 | 21.4M | while (begin_pos < pos && | 126 | 21.3M | tree_samples.Property<S>(prop, begin_pos) <= val) { | 127 | 12.1M | ++begin_pos; | 128 | 12.1M | } | 129 | 24.7M | while (end_pos < end && tree_samples.Property<S>(prop, end_pos) > val) { | 130 | 15.5M | ++end_pos; | 131 | 15.5M | } | 132 | 9.21M | if (begin_pos < pos && end_pos < end) { | 133 | 9.18M | tree_samples.Swap(begin_pos, end_pos); | 134 | 9.18M | } | 135 | 9.21M | ++begin_pos; | 136 | 9.21M | ++end_pos; | 137 | 9.21M | } while (begin_pos < pos && end_pos < end); | 138 | 165k | } |
Unexecuted instantiation: void jxl::N_SSE2::SplitTreeSamples<true>(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int) Unexecuted instantiation: void jxl::N_SSE2::SplitTreeSamples<false>(jxl::TreeSamples&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned int) |
139 | | |
140 | | template <bool S> |
141 | | void CollectExtraBitsIncrease(TreeSamples &tree_samples, |
142 | | const std::vector<ResidualToken> &rtokens, |
143 | | std::vector<int> &count_increase, |
144 | | std::vector<size_t> &extra_bits_increase, |
145 | | size_t begin, size_t end, size_t prop_idx, |
146 | 2.25M | size_t max_symbols) { |
147 | 402M | for (size_t i2 = begin; i2 < end; i2++) { |
148 | 399M | const ResidualToken &rt = rtokens[i2]; |
149 | 399M | size_t cnt = tree_samples.Count(i2); |
150 | 399M | size_t p = tree_samples.Property<S>(prop_idx, i2); |
151 | 399M | size_t sym = rt.tok; |
152 | 399M | size_t ebi = rt.nbits * cnt; |
153 | 399M | count_increase[p * max_symbols + sym] += cnt; |
154 | 399M | extra_bits_increase[p] += ebi; |
155 | 399M | } |
156 | 2.25M | } Unexecuted instantiation: void jxl::N_SSE4::CollectExtraBitsIncrease<true>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long) Unexecuted instantiation: void jxl::N_SSE4::CollectExtraBitsIncrease<false>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long) void jxl::N_AVX2::CollectExtraBitsIncrease<true>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long) Line | Count | Source | 146 | 418k | size_t max_symbols) { | 147 | 79.4M | for (size_t i2 = begin; i2 < end; i2++) { | 148 | 79.0M | const ResidualToken &rt = rtokens[i2]; | 149 | 79.0M | size_t cnt = tree_samples.Count(i2); | 150 | 79.0M | size_t p = tree_samples.Property<S>(prop_idx, i2); | 151 | 79.0M | size_t sym = rt.tok; | 152 | 79.0M | size_t ebi = rt.nbits * cnt; | 153 | 79.0M | count_increase[p * max_symbols + sym] += cnt; | 154 | 79.0M | extra_bits_increase[p] += ebi; | 155 | 79.0M | } | 156 | 418k | } |
void jxl::N_AVX2::CollectExtraBitsIncrease<false>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long) Line | Count | Source | 146 | 1.83M | size_t max_symbols) { | 147 | 322M | for (size_t i2 = begin; i2 < end; i2++) { | 148 | 320M | const ResidualToken &rt = rtokens[i2]; | 149 | 320M | size_t cnt = tree_samples.Count(i2); | 150 | 320M | size_t p = tree_samples.Property<S>(prop_idx, i2); | 151 | 320M | size_t sym = rt.tok; | 152 | 320M | size_t ebi = rt.nbits * cnt; | 153 | 320M | count_increase[p * max_symbols + sym] += cnt; | 154 | 320M | extra_bits_increase[p] += ebi; | 155 | 320M | } | 156 | 1.83M | } |
Unexecuted instantiation: void jxl::N_SSE2::CollectExtraBitsIncrease<true>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long) Unexecuted instantiation: void jxl::N_SSE2::CollectExtraBitsIncrease<false>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long) |
157 | | |
158 | | void FindBestSplit(TreeSamples &tree_samples, float threshold, |
159 | | const std::vector<ModularMultiplierInfo> &mul_info, |
160 | | StaticPropRange initial_static_prop_range, |
161 | 2.28k | float fast_decode_multiplier, Tree *tree) { |
162 | 2.28k | struct NodeInfo { |
163 | 2.28k | size_t pos; |
164 | 2.28k | size_t begin; |
165 | 2.28k | size_t end; |
166 | 2.28k | StaticPropRange static_prop_range; |
167 | 2.28k | }; |
168 | 2.28k | std::vector<NodeInfo> nodes; |
169 | 2.28k | nodes.push_back(NodeInfo{0, 0, tree_samples.NumDistinctSamples(), |
170 | 2.28k | initial_static_prop_range}); |
171 | | |
172 | 2.28k | size_t num_predictors = tree_samples.NumPredictors(); |
173 | 2.28k | size_t num_properties = tree_samples.NumProperties(); |
174 | | |
175 | | // TODO(veluca): consider parallelizing the search (processing multiple nodes |
176 | | // at a time). |
177 | 348k | while (!nodes.empty()) { |
178 | 346k | size_t pos = nodes.back().pos; |
179 | 346k | size_t begin = nodes.back().begin; |
180 | 346k | size_t end = nodes.back().end; |
181 | | |
182 | 346k | StaticPropRange static_prop_range = nodes.back().static_prop_range; |
183 | 346k | nodes.pop_back(); |
184 | 346k | if (begin == end) continue; |
185 | | |
186 | 346k | struct SplitInfo { |
187 | 346k | size_t prop = 0; |
188 | 346k | uint32_t val = 0; |
189 | 346k | size_t pos = 0; |
190 | 346k | float lcost = std::numeric_limits<float>::max(); |
191 | 346k | float rcost = std::numeric_limits<float>::max(); |
192 | 346k | Predictor lpred = Predictor::Zero; |
193 | 346k | Predictor rpred = Predictor::Zero; |
194 | 15.0M | float Cost() const { return lcost + rcost; }Unexecuted instantiation: enc_ma.cc:jxl::N_SSE4::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::SplitInfo::Cost() const enc_ma.cc:jxl::N_AVX2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::SplitInfo::Cost() const Line | Count | Source | 194 | 15.0M | float Cost() const { return lcost + rcost; } |
Unexecuted instantiation: enc_ma.cc:jxl::N_SSE2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::SplitInfo::Cost() const |
195 | 346k | }; |
196 | | |
197 | 346k | SplitInfo best_split_static_constant; |
198 | 346k | SplitInfo best_split_static; |
199 | 346k | SplitInfo best_split_nonstatic; |
200 | 346k | SplitInfo best_split_nowp; |
201 | | |
202 | 346k | JXL_DASSERT(begin <= end); |
203 | 346k | JXL_DASSERT(end <= tree_samples.NumDistinctSamples()); |
204 | | |
205 | | // Compute the maximum token in the range. |
206 | 346k | size_t max_symbols = 0; |
207 | 693k | for (size_t pred = 0; pred < num_predictors; pred++) { |
208 | 57.7M | for (size_t i = begin; i < end; i++) { |
209 | 57.4M | uint32_t tok = tree_samples.Token(pred, i); |
210 | 57.4M | max_symbols = max_symbols > tok + 1 ? max_symbols : tok + 1; |
211 | 57.4M | } |
212 | 346k | } |
213 | 346k | max_symbols = Padded(max_symbols); |
214 | 346k | std::vector<int32_t> counts(max_symbols * num_predictors); |
215 | 346k | std::vector<uint32_t> tot_extra_bits(num_predictors); |
216 | 693k | for (size_t pred = 0; pred < num_predictors; pred++) { |
217 | 346k | size_t extra_bits = 0; |
218 | 346k | const std::vector<ResidualToken>& rtokens = tree_samples.RTokens(pred); |
219 | 57.7M | for (size_t i = begin; i < end; i++) { |
220 | 57.4M | const ResidualToken& rt = rtokens[i]; |
221 | 57.4M | size_t count = tree_samples.Count(i); |
222 | 57.4M | size_t eb = rt.nbits * count; |
223 | 57.4M | counts[pred * max_symbols + rt.tok] += count; |
224 | 57.4M | extra_bits += eb; |
225 | 57.4M | } |
226 | 346k | tot_extra_bits[pred] = extra_bits; |
227 | 346k | } |
228 | | |
229 | 346k | float base_bits; |
230 | 346k | { |
231 | 346k | size_t pred = tree_samples.PredictorIndex((*tree)[pos].predictor); |
232 | 346k | base_bits = |
233 | 346k | EstimateBits(counts.data() + pred * max_symbols, max_symbols) + |
234 | 346k | tot_extra_bits[pred]; |
235 | 346k | } |
236 | | |
237 | 346k | SplitInfo *best = &best_split_nonstatic; |
238 | | |
239 | 346k | SplitInfo forced_split; |
240 | | // The multiplier ranges cut halfway through the current ranges of static |
241 | | // properties. We do this even if the current node is not a leaf, to |
242 | | // minimize the number of nodes in the resulting tree. |
243 | 346k | for (const auto &mmi : mul_info) { |
244 | 84.8k | uint32_t axis; |
245 | 84.8k | uint32_t val; |
246 | 84.8k | IntersectionType t = |
247 | 84.8k | BoxIntersects(static_prop_range, mmi.range, axis, val); |
248 | 84.8k | if (t == IntersectionType::kNone) continue; |
249 | 84.8k | if (t == IntersectionType::kInside) { |
250 | 84.8k | (*tree)[pos].multiplier = mmi.multiplier; |
251 | 84.8k | break; |
252 | 84.8k | } |
253 | 0 | if (t == IntersectionType::kPartial) { |
254 | 0 | JXL_DASSERT(axis < kNumStaticProperties); |
255 | 0 | forced_split.val = tree_samples.QuantizeStaticProperty(axis, val); |
256 | 0 | forced_split.prop = axis; |
257 | 0 | forced_split.lcost = forced_split.rcost = base_bits / 2 - threshold; |
258 | 0 | forced_split.lpred = forced_split.rpred = (*tree)[pos].predictor; |
259 | 0 | best = &forced_split; |
260 | 0 | best->pos = begin; |
261 | 0 | JXL_DASSERT(best->prop == tree_samples.PropertyFromIndex(best->prop)); |
262 | 0 | if (best->prop < tree_samples.NumStaticProps()) { |
263 | 0 | for (size_t x = begin; x < end; x++) { |
264 | 0 | if (tree_samples.Property<true>(best->prop, x) <= best->val) { |
265 | 0 | best->pos++; |
266 | 0 | } |
267 | 0 | } |
268 | 0 | } else { |
269 | 0 | size_t prop = best->prop - tree_samples.NumStaticProps(); |
270 | 0 | for (size_t x = begin; x < end; x++) { |
271 | 0 | if (tree_samples.Property<false>(prop, x) <= best->val) { |
272 | 0 | best->pos++; |
273 | 0 | } |
274 | 0 | } |
275 | 0 | } |
276 | 0 | break; |
277 | 0 | } |
278 | 0 | } |
279 | | |
280 | 346k | if (best != &forced_split) { |
281 | 346k | std::vector<int> prop_value_used_count; |
282 | 346k | std::vector<int> count_increase; |
283 | 346k | std::vector<size_t> extra_bits_increase; |
284 | | // For each property, compute which of its values are used, and what |
285 | | // tokens correspond to those usages. Then, iterate through the values, |
286 | | // and compute the entropy of each side of the split (of the form `prop > |
287 | | // threshold`). Finally, find the split that minimizes the cost. |
288 | 346k | struct CostInfo { |
289 | 346k | float cost = std::numeric_limits<float>::max(); |
290 | 346k | float extra_cost = 0; |
291 | 26.4M | float Cost() const { return cost + extra_cost; }Unexecuted instantiation: enc_ma.cc:jxl::N_SSE4::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::CostInfo::Cost() const enc_ma.cc:jxl::N_AVX2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::CostInfo::Cost() const Line | Count | Source | 291 | 26.4M | float Cost() const { return cost + extra_cost; } |
Unexecuted instantiation: enc_ma.cc:jxl::N_SSE2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::CostInfo::Cost() const |
292 | 346k | Predictor pred; // will be uninitialized in some cases, but never used. |
293 | 346k | }; |
294 | 346k | std::vector<CostInfo> costs_l; |
295 | 346k | std::vector<CostInfo> costs_r; |
296 | | |
297 | 346k | std::vector<int32_t> counts_above(max_symbols); |
298 | 346k | std::vector<int32_t> counts_below(max_symbols); |
299 | | |
300 | | // The lower the threshold, the higher the expected noisiness of the |
301 | | // estimate. Thus, discourage changing predictors. |
302 | 346k | float change_pred_penalty = 800.0f / (100.0f + threshold); |
303 | 2.60M | for (size_t prop = 0; prop < num_properties && base_bits > threshold; |
304 | 2.25M | prop++) { |
305 | 2.25M | costs_l.clear(); |
306 | 2.25M | costs_r.clear(); |
307 | 2.25M | size_t prop_size = tree_samples.NumPropertyValues(prop); |
308 | 2.25M | if (extra_bits_increase.size() < prop_size) { |
309 | 813k | count_increase.resize(prop_size * max_symbols); |
310 | 813k | extra_bits_increase.resize(prop_size); |
311 | 813k | } |
312 | | // Clear prop_value_used_count (which cannot be cleared "on the go") |
313 | 2.25M | prop_value_used_count.clear(); |
314 | 2.25M | prop_value_used_count.resize(prop_size); |
315 | | |
316 | 2.25M | size_t first_used = prop_size; |
317 | 2.25M | size_t last_used = 0; |
318 | | |
319 | | // TODO(veluca): consider finding multiple splits along a single |
320 | | // property at the same time, possibly with a bottom-up approach. |
321 | 2.25M | if (prop < tree_samples.NumStaticProps()) { |
322 | 79.4M | for (size_t i = begin; i < end; i++) { |
323 | 79.0M | size_t p = tree_samples.Property<true>(prop, i); |
324 | 79.0M | prop_value_used_count[p]++; |
325 | 79.0M | last_used = std::max(last_used, p); |
326 | 79.0M | first_used = std::min(first_used, p); |
327 | 79.0M | } |
328 | 1.83M | } else { |
329 | 1.83M | size_t prop_idx = prop - tree_samples.NumStaticProps(); |
330 | 322M | for (size_t i = begin; i < end; i++) { |
331 | 320M | size_t p = tree_samples.Property<false>(prop_idx, i); |
332 | 320M | prop_value_used_count[p]++; |
333 | 320M | last_used = std::max(last_used, p); |
334 | 320M | first_used = std::min(first_used, p); |
335 | 320M | } |
336 | 1.83M | } |
337 | 2.25M | costs_l.resize(last_used - first_used); |
338 | 2.25M | costs_r.resize(last_used - first_used); |
339 | | // For all predictors, compute the right and left costs of each split. |
340 | 4.51M | for (size_t pred = 0; pred < num_predictors; pred++) { |
341 | | // Compute cost and histogram increments for each property value. |
342 | 2.25M | const std::vector<ResidualToken> &rtokens = |
343 | 2.25M | tree_samples.RTokens(pred); |
344 | 2.25M | if (prop < tree_samples.NumStaticProps()) { |
345 | 418k | CollectExtraBitsIncrease<true>(tree_samples, rtokens, |
346 | 418k | count_increase, extra_bits_increase, |
347 | 418k | begin, end, prop, max_symbols); |
348 | 1.83M | } else { |
349 | 1.83M | CollectExtraBitsIncrease<false>( |
350 | 1.83M | tree_samples, rtokens, count_increase, extra_bits_increase, |
351 | 1.83M | begin, end, prop - tree_samples.NumStaticProps(), max_symbols); |
352 | 1.83M | } |
353 | 2.25M | memcpy(counts_above.data(), counts.data() + pred * max_symbols, |
354 | 2.25M | max_symbols * sizeof counts_above[0]); |
355 | 2.25M | memset(counts_below.data(), 0, max_symbols * sizeof counts_below[0]); |
356 | 2.25M | size_t extra_bits_below = 0; |
357 | | // Exclude last used: this ensures neither counts_above nor |
358 | | // counts_below is empty. |
359 | 32.7M | for (size_t i = first_used; i < last_used; i++) { |
360 | 30.4M | if (!prop_value_used_count[i]) continue; |
361 | 13.2M | extra_bits_below += extra_bits_increase[i]; |
362 | | // The increase for this property value has been used, and will not |
363 | | // be used again: clear it. Also below. |
364 | 13.2M | extra_bits_increase[i] = 0; |
365 | 674M | for (size_t sym = 0; sym < max_symbols; sym++) { |
366 | 660M | counts_above[sym] -= count_increase[i * max_symbols + sym]; |
367 | 660M | counts_below[sym] += count_increase[i * max_symbols + sym]; |
368 | 660M | count_increase[i * max_symbols + sym] = 0; |
369 | 660M | } |
370 | 13.2M | float rcost = EstimateBits(counts_above.data(), max_symbols) + |
371 | 13.2M | tot_extra_bits[pred] - extra_bits_below; |
372 | 13.2M | float lcost = EstimateBits(counts_below.data(), max_symbols) + |
373 | 13.2M | extra_bits_below; |
374 | 13.2M | JXL_DASSERT(extra_bits_below <= tot_extra_bits[pred]); |
375 | 13.2M | float penalty = 0; |
376 | | // Never discourage moving away from the Weighted predictor. |
377 | 13.2M | if (tree_samples.PredictorFromIndex(pred) != |
378 | 13.2M | (*tree)[pos].predictor && |
379 | 0 | (*tree)[pos].predictor != Predictor::Weighted) { |
380 | 0 | penalty = change_pred_penalty; |
381 | 0 | } |
382 | | // If everything else is equal, disfavour Weighted (slower) and |
383 | | // favour Zero (faster if it's the only predictor used in a |
384 | | // group+channel combination) |
385 | 13.2M | if (tree_samples.PredictorFromIndex(pred) == Predictor::Weighted) { |
386 | 0 | penalty += 1e-8; |
387 | 0 | } |
388 | 13.2M | if (tree_samples.PredictorFromIndex(pred) == Predictor::Zero) { |
389 | 0 | penalty -= 1e-8; |
390 | 0 | } |
391 | 13.2M | if (rcost + penalty < costs_r[i - first_used].Cost()) { |
392 | 13.2M | costs_r[i - first_used].cost = rcost; |
393 | 13.2M | costs_r[i - first_used].extra_cost = penalty; |
394 | 13.2M | costs_r[i - first_used].pred = |
395 | 13.2M | tree_samples.PredictorFromIndex(pred); |
396 | 13.2M | } |
397 | 13.2M | if (lcost + penalty < costs_l[i - first_used].Cost()) { |
398 | 13.2M | costs_l[i - first_used].cost = lcost; |
399 | 13.2M | costs_l[i - first_used].extra_cost = penalty; |
400 | 13.2M | costs_l[i - first_used].pred = |
401 | 13.2M | tree_samples.PredictorFromIndex(pred); |
402 | 13.2M | } |
403 | 13.2M | } |
404 | 2.25M | } |
405 | | // Iterate through the possible splits and find the one with minimum sum |
406 | | // of costs of the two sides. |
407 | 2.25M | size_t split = begin; |
408 | 32.7M | for (size_t i = first_used; i < last_used; i++) { |
409 | 30.4M | if (!prop_value_used_count[i]) continue; |
410 | 13.2M | split += prop_value_used_count[i]; |
411 | 13.2M | float rcost = costs_r[i - first_used].cost; |
412 | 13.2M | float lcost = costs_l[i - first_used].cost; |
413 | | |
414 | 13.2M | bool uses_wp = tree_samples.PropertyFromIndex(prop) == kWPProp || |
415 | 11.0M | costs_l[i - first_used].pred == Predictor::Weighted || |
416 | 11.0M | costs_r[i - first_used].pred == Predictor::Weighted; |
417 | 13.2M | bool zero_entropy_side = rcost == 0 || lcost == 0; |
418 | | |
419 | 13.2M | SplitInfo &best_ref = |
420 | 13.2M | tree_samples.PropertyFromIndex(prop) < kNumStaticProperties |
421 | 13.2M | ? (zero_entropy_side ? best_split_static_constant |
422 | 90.6k | : best_split_static) |
423 | 13.2M | : (uses_wp ? best_split_nonstatic : best_split_nowp); |
424 | 13.2M | if (lcost + rcost < best_ref.Cost()) { |
425 | 2.40M | best_ref.prop = prop; |
426 | 2.40M | best_ref.val = i; |
427 | 2.40M | best_ref.pos = split; |
428 | 2.40M | best_ref.lcost = lcost; |
429 | 2.40M | best_ref.lpred = costs_l[i - first_used].pred; |
430 | 2.40M | best_ref.rcost = rcost; |
431 | 2.40M | best_ref.rpred = costs_r[i - first_used].pred; |
432 | 2.40M | } |
433 | 13.2M | } |
434 | | // Clear extra_bits_increase and cost_increase for last_used. |
435 | 2.25M | extra_bits_increase[last_used] = 0; |
436 | 115M | for (size_t sym = 0; sym < max_symbols; sym++) { |
437 | 113M | count_increase[last_used * max_symbols + sym] = 0; |
438 | 113M | } |
439 | 2.25M | } |
440 | | |
441 | | // Try to avoid introducing WP. |
442 | 346k | if (best_split_nowp.Cost() + threshold < base_bits && |
443 | 166k | best_split_nowp.Cost() <= fast_decode_multiplier * best->Cost()) { |
444 | 159k | best = &best_split_nowp; |
445 | 159k | } |
446 | | // Split along static props if possible and not significantly more |
447 | | // expensive. |
448 | 346k | if (best_split_static.Cost() + threshold < base_bits && |
449 | 30.1k | best_split_static.Cost() <= fast_decode_multiplier * best->Cost()) { |
450 | 6.40k | best = &best_split_static; |
451 | 6.40k | } |
452 | | // Split along static props to create constant nodes if possible. |
453 | 346k | if (best_split_static_constant.Cost() + threshold < base_bits) { |
454 | 489 | best = &best_split_static_constant; |
455 | 489 | } |
456 | 346k | } |
457 | | |
458 | 346k | if (best->Cost() + threshold < base_bits) { |
459 | 172k | uint32_t p = tree_samples.PropertyFromIndex(best->prop); |
460 | 172k | pixel_type dequant = |
461 | 172k | tree_samples.UnquantizeProperty(best->prop, best->val); |
462 | | // Split node and try to split children. |
463 | 172k | MakeSplitNode(pos, p, dequant, best->lpred, 0, best->rpred, 0, tree); |
464 | | // "Sort" according to winning property |
465 | 172k | if (best->prop < tree_samples.NumStaticProps()) { |
466 | 6.75k | SplitTreeSamples<true>(tree_samples, begin, best->pos, end, best->prop, |
467 | 6.75k | best->val); |
468 | 165k | } else { |
469 | 165k | SplitTreeSamples<false>(tree_samples, begin, best->pos, end, |
470 | 165k | best->prop - tree_samples.NumStaticProps(), |
471 | 165k | best->val); |
472 | 165k | } |
473 | 172k | auto new_sp_range = static_prop_range; |
474 | 172k | if (p < kNumStaticProperties) { |
475 | 6.75k | JXL_DASSERT(static_cast<uint32_t>(dequant + 1) <= new_sp_range[p][1]); |
476 | 6.75k | new_sp_range[p][1] = dequant + 1; |
477 | 6.75k | JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]); |
478 | 6.75k | } |
479 | 172k | nodes.push_back( |
480 | 172k | NodeInfo{(*tree)[pos].rchild, begin, best->pos, new_sp_range}); |
481 | 172k | new_sp_range = static_prop_range; |
482 | 172k | if (p < kNumStaticProperties) { |
483 | 6.75k | JXL_DASSERT(new_sp_range[p][0] <= static_cast<uint32_t>(dequant + 1)); |
484 | 6.75k | new_sp_range[p][0] = dequant + 1; |
485 | 6.75k | JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]); |
486 | 6.75k | } |
487 | 172k | nodes.push_back( |
488 | 172k | NodeInfo{(*tree)[pos].lchild, best->pos, end, new_sp_range}); |
489 | 172k | } |
490 | 346k | } |
491 | 2.28k | } Unexecuted instantiation: jxl::N_SSE4::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*) jxl::N_AVX2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*) Line | Count | Source | 161 | 2.28k | float fast_decode_multiplier, Tree *tree) { | 162 | 2.28k | struct NodeInfo { | 163 | 2.28k | size_t pos; | 164 | 2.28k | size_t begin; | 165 | 2.28k | size_t end; | 166 | 2.28k | StaticPropRange static_prop_range; | 167 | 2.28k | }; | 168 | 2.28k | std::vector<NodeInfo> nodes; | 169 | 2.28k | nodes.push_back(NodeInfo{0, 0, tree_samples.NumDistinctSamples(), | 170 | 2.28k | initial_static_prop_range}); | 171 | | | 172 | 2.28k | size_t num_predictors = tree_samples.NumPredictors(); | 173 | 2.28k | size_t num_properties = tree_samples.NumProperties(); | 174 | | | 175 | | // TODO(veluca): consider parallelizing the search (processing multiple nodes | 176 | | // at a time). 
| 177 | 348k | while (!nodes.empty()) { | 178 | 346k | size_t pos = nodes.back().pos; | 179 | 346k | size_t begin = nodes.back().begin; | 180 | 346k | size_t end = nodes.back().end; | 181 | | | 182 | 346k | StaticPropRange static_prop_range = nodes.back().static_prop_range; | 183 | 346k | nodes.pop_back(); | 184 | 346k | if (begin == end) continue; | 185 | | | 186 | 346k | struct SplitInfo { | 187 | 346k | size_t prop = 0; | 188 | 346k | uint32_t val = 0; | 189 | 346k | size_t pos = 0; | 190 | 346k | float lcost = std::numeric_limits<float>::max(); | 191 | 346k | float rcost = std::numeric_limits<float>::max(); | 192 | 346k | Predictor lpred = Predictor::Zero; | 193 | 346k | Predictor rpred = Predictor::Zero; | 194 | 346k | float Cost() const { return lcost + rcost; } | 195 | 346k | }; | 196 | | | 197 | 346k | SplitInfo best_split_static_constant; | 198 | 346k | SplitInfo best_split_static; | 199 | 346k | SplitInfo best_split_nonstatic; | 200 | 346k | SplitInfo best_split_nowp; | 201 | | | 202 | 346k | JXL_DASSERT(begin <= end); | 203 | 346k | JXL_DASSERT(end <= tree_samples.NumDistinctSamples()); | 204 | | | 205 | | // Compute the maximum token in the range. | 206 | 346k | size_t max_symbols = 0; | 207 | 693k | for (size_t pred = 0; pred < num_predictors; pred++) { | 208 | 57.7M | for (size_t i = begin; i < end; i++) { | 209 | 57.4M | uint32_t tok = tree_samples.Token(pred, i); | 210 | 57.4M | max_symbols = max_symbols > tok + 1 ? 
max_symbols : tok + 1; | 211 | 57.4M | } | 212 | 346k | } | 213 | 346k | max_symbols = Padded(max_symbols); | 214 | 346k | std::vector<int32_t> counts(max_symbols * num_predictors); | 215 | 346k | std::vector<uint32_t> tot_extra_bits(num_predictors); | 216 | 693k | for (size_t pred = 0; pred < num_predictors; pred++) { | 217 | 346k | size_t extra_bits = 0; | 218 | 346k | const std::vector<ResidualToken>& rtokens = tree_samples.RTokens(pred); | 219 | 57.7M | for (size_t i = begin; i < end; i++) { | 220 | 57.4M | const ResidualToken& rt = rtokens[i]; | 221 | 57.4M | size_t count = tree_samples.Count(i); | 222 | 57.4M | size_t eb = rt.nbits * count; | 223 | 57.4M | counts[pred * max_symbols + rt.tok] += count; | 224 | 57.4M | extra_bits += eb; | 225 | 57.4M | } | 226 | 346k | tot_extra_bits[pred] = extra_bits; | 227 | 346k | } | 228 | | | 229 | 346k | float base_bits; | 230 | 346k | { | 231 | 346k | size_t pred = tree_samples.PredictorIndex((*tree)[pos].predictor); | 232 | 346k | base_bits = | 233 | 346k | EstimateBits(counts.data() + pred * max_symbols, max_symbols) + | 234 | 346k | tot_extra_bits[pred]; | 235 | 346k | } | 236 | | | 237 | 346k | SplitInfo *best = &best_split_nonstatic; | 238 | | | 239 | 346k | SplitInfo forced_split; | 240 | | // The multiplier ranges cut halfway through the current ranges of static | 241 | | // properties. We do this even if the current node is not a leaf, to | 242 | | // minimize the number of nodes in the resulting tree. 
| 243 | 346k | for (const auto &mmi : mul_info) { | 244 | 84.8k | uint32_t axis; | 245 | 84.8k | uint32_t val; | 246 | 84.8k | IntersectionType t = | 247 | 84.8k | BoxIntersects(static_prop_range, mmi.range, axis, val); | 248 | 84.8k | if (t == IntersectionType::kNone) continue; | 249 | 84.8k | if (t == IntersectionType::kInside) { | 250 | 84.8k | (*tree)[pos].multiplier = mmi.multiplier; | 251 | 84.8k | break; | 252 | 84.8k | } | 253 | 0 | if (t == IntersectionType::kPartial) { | 254 | 0 | JXL_DASSERT(axis < kNumStaticProperties); | 255 | 0 | forced_split.val = tree_samples.QuantizeStaticProperty(axis, val); | 256 | 0 | forced_split.prop = axis; | 257 | 0 | forced_split.lcost = forced_split.rcost = base_bits / 2 - threshold; | 258 | 0 | forced_split.lpred = forced_split.rpred = (*tree)[pos].predictor; | 259 | 0 | best = &forced_split; | 260 | 0 | best->pos = begin; | 261 | 0 | JXL_DASSERT(best->prop == tree_samples.PropertyFromIndex(best->prop)); | 262 | 0 | if (best->prop < tree_samples.NumStaticProps()) { | 263 | 0 | for (size_t x = begin; x < end; x++) { | 264 | 0 | if (tree_samples.Property<true>(best->prop, x) <= best->val) { | 265 | 0 | best->pos++; | 266 | 0 | } | 267 | 0 | } | 268 | 0 | } else { | 269 | 0 | size_t prop = best->prop - tree_samples.NumStaticProps(); | 270 | 0 | for (size_t x = begin; x < end; x++) { | 271 | 0 | if (tree_samples.Property<false>(prop, x) <= best->val) { | 272 | 0 | best->pos++; | 273 | 0 | } | 274 | 0 | } | 275 | 0 | } | 276 | 0 | break; | 277 | 0 | } | 278 | 0 | } | 279 | | | 280 | 346k | if (best != &forced_split) { | 281 | 346k | std::vector<int> prop_value_used_count; | 282 | 346k | std::vector<int> count_increase; | 283 | 346k | std::vector<size_t> extra_bits_increase; | 284 | | // For each property, compute which of its values are used, and what | 285 | | // tokens correspond to those usages. 
Then, iterate through the values, | 286 | | // and compute the entropy of each side of the split (of the form `prop > | 287 | | // threshold`). Finally, find the split that minimizes the cost. | 288 | 346k | struct CostInfo { | 289 | 346k | float cost = std::numeric_limits<float>::max(); | 290 | 346k | float extra_cost = 0; | 291 | 346k | float Cost() const { return cost + extra_cost; } | 292 | 346k | Predictor pred; // will be uninitialized in some cases, but never used. | 293 | 346k | }; | 294 | 346k | std::vector<CostInfo> costs_l; | 295 | 346k | std::vector<CostInfo> costs_r; | 296 | | | 297 | 346k | std::vector<int32_t> counts_above(max_symbols); | 298 | 346k | std::vector<int32_t> counts_below(max_symbols); | 299 | | | 300 | | // The lower the threshold, the higher the expected noisiness of the | 301 | | // estimate. Thus, discourage changing predictors. | 302 | 346k | float change_pred_penalty = 800.0f / (100.0f + threshold); | 303 | 2.60M | for (size_t prop = 0; prop < num_properties && base_bits > threshold; | 304 | 2.25M | prop++) { | 305 | 2.25M | costs_l.clear(); | 306 | 2.25M | costs_r.clear(); | 307 | 2.25M | size_t prop_size = tree_samples.NumPropertyValues(prop); | 308 | 2.25M | if (extra_bits_increase.size() < prop_size) { | 309 | 813k | count_increase.resize(prop_size * max_symbols); | 310 | 813k | extra_bits_increase.resize(prop_size); | 311 | 813k | } | 312 | | // Clear prop_value_used_count (which cannot be cleared "on the go") | 313 | 2.25M | prop_value_used_count.clear(); | 314 | 2.25M | prop_value_used_count.resize(prop_size); | 315 | | | 316 | 2.25M | size_t first_used = prop_size; | 317 | 2.25M | size_t last_used = 0; | 318 | | | 319 | | // TODO(veluca): consider finding multiple splits along a single | 320 | | // property at the same time, possibly with a bottom-up approach. 
| 321 | 2.25M | if (prop < tree_samples.NumStaticProps()) { | 322 | 79.4M | for (size_t i = begin; i < end; i++) { | 323 | 79.0M | size_t p = tree_samples.Property<true>(prop, i); | 324 | 79.0M | prop_value_used_count[p]++; | 325 | 79.0M | last_used = std::max(last_used, p); | 326 | 79.0M | first_used = std::min(first_used, p); | 327 | 79.0M | } | 328 | 1.83M | } else { | 329 | 1.83M | size_t prop_idx = prop - tree_samples.NumStaticProps(); | 330 | 322M | for (size_t i = begin; i < end; i++) { | 331 | 320M | size_t p = tree_samples.Property<false>(prop_idx, i); | 332 | 320M | prop_value_used_count[p]++; | 333 | 320M | last_used = std::max(last_used, p); | 334 | 320M | first_used = std::min(first_used, p); | 335 | 320M | } | 336 | 1.83M | } | 337 | 2.25M | costs_l.resize(last_used - first_used); | 338 | 2.25M | costs_r.resize(last_used - first_used); | 339 | | // For all predictors, compute the right and left costs of each split. | 340 | 4.51M | for (size_t pred = 0; pred < num_predictors; pred++) { | 341 | | // Compute cost and histogram increments for each property value. 
| 342 | 2.25M | const std::vector<ResidualToken> &rtokens = | 343 | 2.25M | tree_samples.RTokens(pred); | 344 | 2.25M | if (prop < tree_samples.NumStaticProps()) { | 345 | 418k | CollectExtraBitsIncrease<true>(tree_samples, rtokens, | 346 | 418k | count_increase, extra_bits_increase, | 347 | 418k | begin, end, prop, max_symbols); | 348 | 1.83M | } else { | 349 | 1.83M | CollectExtraBitsIncrease<false>( | 350 | 1.83M | tree_samples, rtokens, count_increase, extra_bits_increase, | 351 | 1.83M | begin, end, prop - tree_samples.NumStaticProps(), max_symbols); | 352 | 1.83M | } | 353 | 2.25M | memcpy(counts_above.data(), counts.data() + pred * max_symbols, | 354 | 2.25M | max_symbols * sizeof counts_above[0]); | 355 | 2.25M | memset(counts_below.data(), 0, max_symbols * sizeof counts_below[0]); | 356 | 2.25M | size_t extra_bits_below = 0; | 357 | | // Exclude last used: this ensures neither counts_above nor | 358 | | // counts_below is empty. | 359 | 32.7M | for (size_t i = first_used; i < last_used; i++) { | 360 | 30.4M | if (!prop_value_used_count[i]) continue; | 361 | 13.2M | extra_bits_below += extra_bits_increase[i]; | 362 | | // The increase for this property value has been used, and will not | 363 | | // be used again: clear it. Also below. 
| 364 | 13.2M | extra_bits_increase[i] = 0; | 365 | 674M | for (size_t sym = 0; sym < max_symbols; sym++) { | 366 | 660M | counts_above[sym] -= count_increase[i * max_symbols + sym]; | 367 | 660M | counts_below[sym] += count_increase[i * max_symbols + sym]; | 368 | 660M | count_increase[i * max_symbols + sym] = 0; | 369 | 660M | } | 370 | 13.2M | float rcost = EstimateBits(counts_above.data(), max_symbols) + | 371 | 13.2M | tot_extra_bits[pred] - extra_bits_below; | 372 | 13.2M | float lcost = EstimateBits(counts_below.data(), max_symbols) + | 373 | 13.2M | extra_bits_below; | 374 | 13.2M | JXL_DASSERT(extra_bits_below <= tot_extra_bits[pred]); | 375 | 13.2M | float penalty = 0; | 376 | | // Never discourage moving away from the Weighted predictor. | 377 | 13.2M | if (tree_samples.PredictorFromIndex(pred) != | 378 | 13.2M | (*tree)[pos].predictor && | 379 | 0 | (*tree)[pos].predictor != Predictor::Weighted) { | 380 | 0 | penalty = change_pred_penalty; | 381 | 0 | } | 382 | | // If everything else is equal, disfavour Weighted (slower) and | 383 | | // favour Zero (faster if it's the only predictor used in a | 384 | | // group+channel combination) | 385 | 13.2M | if (tree_samples.PredictorFromIndex(pred) == Predictor::Weighted) { | 386 | 0 | penalty += 1e-8; | 387 | 0 | } | 388 | 13.2M | if (tree_samples.PredictorFromIndex(pred) == Predictor::Zero) { | 389 | 0 | penalty -= 1e-8; | 390 | 0 | } | 391 | 13.2M | if (rcost + penalty < costs_r[i - first_used].Cost()) { | 392 | 13.2M | costs_r[i - first_used].cost = rcost; | 393 | 13.2M | costs_r[i - first_used].extra_cost = penalty; | 394 | 13.2M | costs_r[i - first_used].pred = | 395 | 13.2M | tree_samples.PredictorFromIndex(pred); | 396 | 13.2M | } | 397 | 13.2M | if (lcost + penalty < costs_l[i - first_used].Cost()) { | 398 | 13.2M | costs_l[i - first_used].cost = lcost; | 399 | 13.2M | costs_l[i - first_used].extra_cost = penalty; | 400 | 13.2M | costs_l[i - first_used].pred = | 401 | 13.2M | 
tree_samples.PredictorFromIndex(pred); | 402 | 13.2M | } | 403 | 13.2M | } | 404 | 2.25M | } | 405 | | // Iterate through the possible splits and find the one with minimum sum | 406 | | // of costs of the two sides. | 407 | 2.25M | size_t split = begin; | 408 | 32.7M | for (size_t i = first_used; i < last_used; i++) { | 409 | 30.4M | if (!prop_value_used_count[i]) continue; | 410 | 13.2M | split += prop_value_used_count[i]; | 411 | 13.2M | float rcost = costs_r[i - first_used].cost; | 412 | 13.2M | float lcost = costs_l[i - first_used].cost; | 413 | | | 414 | 13.2M | bool uses_wp = tree_samples.PropertyFromIndex(prop) == kWPProp || | 415 | 11.0M | costs_l[i - first_used].pred == Predictor::Weighted || | 416 | 11.0M | costs_r[i - first_used].pred == Predictor::Weighted; | 417 | 13.2M | bool zero_entropy_side = rcost == 0 || lcost == 0; | 418 | | | 419 | 13.2M | SplitInfo &best_ref = | 420 | 13.2M | tree_samples.PropertyFromIndex(prop) < kNumStaticProperties | 421 | 13.2M | ? (zero_entropy_side ? best_split_static_constant | 422 | 90.6k | : best_split_static) | 423 | 13.2M | : (uses_wp ? best_split_nonstatic : best_split_nowp); | 424 | 13.2M | if (lcost + rcost < best_ref.Cost()) { | 425 | 2.40M | best_ref.prop = prop; | 426 | 2.40M | best_ref.val = i; | 427 | 2.40M | best_ref.pos = split; | 428 | 2.40M | best_ref.lcost = lcost; | 429 | 2.40M | best_ref.lpred = costs_l[i - first_used].pred; | 430 | 2.40M | best_ref.rcost = rcost; | 431 | 2.40M | best_ref.rpred = costs_r[i - first_used].pred; | 432 | 2.40M | } | 433 | 13.2M | } | 434 | | // Clear extra_bits_increase and cost_increase for last_used. | 435 | 2.25M | extra_bits_increase[last_used] = 0; | 436 | 115M | for (size_t sym = 0; sym < max_symbols; sym++) { | 437 | 113M | count_increase[last_used * max_symbols + sym] = 0; | 438 | 113M | } | 439 | 2.25M | } | 440 | | | 441 | | // Try to avoid introducing WP. 
| 442 | 346k | if (best_split_nowp.Cost() + threshold < base_bits && | 443 | 166k | best_split_nowp.Cost() <= fast_decode_multiplier * best->Cost()) { | 444 | 159k | best = &best_split_nowp; | 445 | 159k | } | 446 | | // Split along static props if possible and not significantly more | 447 | | // expensive. | 448 | 346k | if (best_split_static.Cost() + threshold < base_bits && | 449 | 30.1k | best_split_static.Cost() <= fast_decode_multiplier * best->Cost()) { | 450 | 6.40k | best = &best_split_static; | 451 | 6.40k | } | 452 | | // Split along static props to create constant nodes if possible. | 453 | 346k | if (best_split_static_constant.Cost() + threshold < base_bits) { | 454 | 489 | best = &best_split_static_constant; | 455 | 489 | } | 456 | 346k | } | 457 | | | 458 | 346k | if (best->Cost() + threshold < base_bits) { | 459 | 172k | uint32_t p = tree_samples.PropertyFromIndex(best->prop); | 460 | 172k | pixel_type dequant = | 461 | 172k | tree_samples.UnquantizeProperty(best->prop, best->val); | 462 | | // Split node and try to split children. 
| 463 | 172k | MakeSplitNode(pos, p, dequant, best->lpred, 0, best->rpred, 0, tree); | 464 | | // "Sort" according to winning property | 465 | 172k | if (best->prop < tree_samples.NumStaticProps()) { | 466 | 6.75k | SplitTreeSamples<true>(tree_samples, begin, best->pos, end, best->prop, | 467 | 6.75k | best->val); | 468 | 165k | } else { | 469 | 165k | SplitTreeSamples<false>(tree_samples, begin, best->pos, end, | 470 | 165k | best->prop - tree_samples.NumStaticProps(), | 471 | 165k | best->val); | 472 | 165k | } | 473 | 172k | auto new_sp_range = static_prop_range; | 474 | 172k | if (p < kNumStaticProperties) { | 475 | 6.75k | JXL_DASSERT(static_cast<uint32_t>(dequant + 1) <= new_sp_range[p][1]); | 476 | 6.75k | new_sp_range[p][1] = dequant + 1; | 477 | 6.75k | JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]); | 478 | 6.75k | } | 479 | 172k | nodes.push_back( | 480 | 172k | NodeInfo{(*tree)[pos].rchild, begin, best->pos, new_sp_range}); | 481 | 172k | new_sp_range = static_prop_range; | 482 | 172k | if (p < kNumStaticProperties) { | 483 | 6.75k | JXL_DASSERT(new_sp_range[p][0] <= static_cast<uint32_t>(dequant + 1)); | 484 | 6.75k | new_sp_range[p][0] = dequant + 1; | 485 | 6.75k | JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]); | 486 | 6.75k | } | 487 | 172k | nodes.push_back( | 488 | 172k | NodeInfo{(*tree)[pos].lchild, best->pos, end, new_sp_range}); | 489 | 172k | } | 490 | 346k | } | 491 | 2.28k | } |
Unexecuted instantiation: jxl::N_SSE2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*) |
492 | | |
493 | | // NOLINTNEXTLINE(google-readability-namespace-comments) |
494 | | } // namespace HWY_NAMESPACE |
495 | | } // namespace jxl |
496 | | HWY_AFTER_NAMESPACE(); |
497 | | |
498 | | #if HWY_ONCE |
499 | | namespace jxl { |
500 | | |
501 | | HWY_EXPORT(FindBestSplit); // Local function. |
502 | | |
503 | | Status ComputeBestTree(TreeSamples &tree_samples, float threshold, |
504 | | const std::vector<ModularMultiplierInfo> &mul_info, |
505 | | StaticPropRange static_prop_range, |
506 | 2.28k | float fast_decode_multiplier, Tree *tree) { |
507 | | // TODO(veluca): take into account that different contexts can have different |
508 | | // uint configs. |
509 | | // |
510 | | // Initialize tree. |
511 | 2.28k | tree->emplace_back(); |
512 | 2.28k | tree->back().property = -1; |
513 | 2.28k | tree->back().predictor = tree_samples.PredictorFromIndex(0); |
514 | 2.28k | tree->back().predictor_offset = 0; |
515 | 2.28k | tree->back().multiplier = 1; |
516 | 2.28k | JXL_ENSURE(tree_samples.NumProperties() < 64); |
517 | | |
518 | 2.28k | JXL_ENSURE(tree_samples.NumDistinctSamples() <= |
519 | 2.28k | std::numeric_limits<uint32_t>::max()); |
520 | 2.28k | HWY_DYNAMIC_DISPATCH(FindBestSplit) |
521 | 2.28k | (tree_samples, threshold, mul_info, static_prop_range, fast_decode_multiplier, |
522 | 2.28k | tree); |
523 | 2.28k | return true; |
524 | 2.28k | } |
525 | | |
526 | | #if JXL_CXX_LANG < JXL_CXX_17 |
527 | | constexpr int32_t TreeSamples::kPropertyRange; |
528 | | constexpr uint32_t TreeSamples::kDedupEntryUnused; |
529 | | #endif |
530 | | |
531 | | Status TreeSamples::SetPredictor(Predictor predictor, |
532 | 2.44k | ModularOptions::TreeMode wp_tree_mode) { |
533 | 2.44k | if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) { |
534 | 0 | predictors = {Predictor::Weighted}; |
535 | 0 | residuals.resize(1); |
536 | 0 | return true; |
537 | 0 | } |
538 | 2.44k | if (wp_tree_mode == ModularOptions::TreeMode::kNoWP && |
539 | 0 | predictor == Predictor::Weighted) { |
540 | 0 | return JXL_FAILURE("Invalid predictor settings"); |
541 | 0 | } |
542 | 2.44k | if (predictor == Predictor::Variable) { |
543 | 0 | for (size_t i = 0; i < kNumModularPredictors; i++) { |
544 | 0 | predictors.push_back(static_cast<Predictor>(i)); |
545 | 0 | } |
546 | 0 | std::swap(predictors[0], predictors[static_cast<int>(Predictor::Weighted)]); |
547 | 0 | std::swap(predictors[1], predictors[static_cast<int>(Predictor::Gradient)]); |
548 | 2.44k | } else if (predictor == Predictor::Best) { |
549 | 0 | predictors = {Predictor::Weighted, Predictor::Gradient}; |
550 | 2.44k | } else { |
551 | 2.44k | predictors = {predictor}; |
552 | 2.44k | } |
553 | 2.44k | if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) { |
554 | 0 | predictors.erase( |
555 | 0 | std::remove(predictors.begin(), predictors.end(), Predictor::Weighted), |
556 | 0 | predictors.end()); |
557 | 0 | } |
558 | 2.44k | residuals.resize(predictors.size()); |
559 | 2.44k | return true; |
560 | 2.44k | } |
561 | | |
562 | | Status TreeSamples::SetProperties(const std::vector<uint32_t> &properties, |
563 | 2.44k | ModularOptions::TreeMode wp_tree_mode) { |
564 | 2.44k | props_to_use = properties; |
565 | 2.44k | if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) { |
566 | 0 | props_to_use = {static_cast<uint32_t>(kWPProp)}; |
567 | 0 | } |
568 | 2.44k | if (wp_tree_mode == ModularOptions::TreeMode::kGradientOnly) { |
569 | 0 | props_to_use = {static_cast<uint32_t>(kGradientProp)}; |
570 | 0 | } |
571 | 2.44k | if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) { |
572 | 0 | props_to_use.erase( |
573 | 0 | std::remove(props_to_use.begin(), props_to_use.end(), kWPProp), |
574 | 0 | props_to_use.end()); |
575 | 0 | } |
576 | 2.44k | if (props_to_use.empty()) { |
577 | 0 | return JXL_FAILURE("Invalid property set configuration"); |
578 | 0 | } |
579 | 2.44k | num_static_props = 0; |
580 | | // Check that if static properties present, then those are at the beginning. |
581 | 19.5k | for (size_t i = 0; i < props_to_use.size(); ++i) { |
582 | 17.1k | uint32_t prop = props_to_use[i]; |
583 | 17.1k | if (prop < kNumStaticProperties) { |
584 | 3.53k | JXL_DASSERT(i == prop); |
585 | 3.53k | num_static_props++; |
586 | 3.53k | } |
587 | 17.1k | } |
588 | 2.44k | props.resize(props_to_use.size() - num_static_props); |
589 | 2.44k | return true; |
590 | 2.44k | } |
591 | | |
592 | 5.66k | void TreeSamples::InitTable(size_t log_size) { |
593 | 5.66k | size_t size = 1ULL << log_size; |
594 | 5.66k | if (dedup_table_.size() == size) return; |
595 | 4.08k | dedup_table_.resize(size, kDedupEntryUnused); |
596 | 5.17M | for (size_t i = 0; i < NumDistinctSamples(); i++) { |
597 | 5.16M | if (sample_counts[i] != std::numeric_limits<uint16_t>::max()) { |
598 | 5.16M | AddToTable(i); |
599 | 5.16M | } |
600 | 5.16M | } |
601 | 4.08k | } |
602 | | |
603 | 44.6M | bool TreeSamples::AddToTableAndMerge(size_t a) { |
604 | 44.6M | size_t pos1 = Hash1(a); |
605 | 44.6M | size_t pos2 = Hash2(a); |
606 | 44.6M | if (dedup_table_[pos1] != kDedupEntryUnused && |
607 | 40.8M | IsSameSample(a, dedup_table_[pos1])) { |
608 | 33.9M | JXL_DASSERT(sample_counts[a] == 1); |
609 | 33.9M | sample_counts[dedup_table_[pos1]]++; |
610 | | // Remove from hash table samples that are saturated. |
611 | 33.9M | if (sample_counts[dedup_table_[pos1]] == |
612 | 33.9M | std::numeric_limits<uint16_t>::max()) { |
613 | 18 | dedup_table_[pos1] = kDedupEntryUnused; |
614 | 18 | } |
615 | 33.9M | return true; |
616 | 33.9M | } |
617 | 10.6M | if (dedup_table_[pos2] != kDedupEntryUnused && |
618 | 6.18M | IsSameSample(a, dedup_table_[pos2])) { |
619 | 4.51M | JXL_DASSERT(sample_counts[a] == 1); |
620 | 4.51M | sample_counts[dedup_table_[pos2]]++; |
621 | | // Remove from hash table samples that are saturated. |
622 | 4.51M | if (sample_counts[dedup_table_[pos2]] == |
623 | 4.51M | std::numeric_limits<uint16_t>::max()) { |
624 | 0 | dedup_table_[pos2] = kDedupEntryUnused; |
625 | 0 | } |
626 | 4.51M | return true; |
627 | 4.51M | } |
628 | 6.11M | AddToTable(a); |
629 | 6.11M | return false; |
630 | 10.6M | } |
631 | | |
632 | 11.2M | void TreeSamples::AddToTable(size_t a) { |
633 | 11.2M | size_t pos1 = Hash1(a); |
634 | 11.2M | size_t pos2 = Hash2(a); |
635 | 11.2M | if (dedup_table_[pos1] == kDedupEntryUnused) { |
636 | 5.46M | dedup_table_[pos1] = a; |
637 | 5.81M | } else if (dedup_table_[pos2] == kDedupEntryUnused) { |
638 | 2.99M | dedup_table_[pos2] = a; |
639 | 2.99M | } |
640 | 11.2M | } |
641 | | |
642 | 5.66k | void TreeSamples::PrepareForSamples(size_t extra_num_samples) { |
643 | 5.66k | for (auto &res : residuals) { |
644 | 5.66k | res.reserve(res.size() + extra_num_samples); |
645 | 5.66k | } |
646 | 14.3k | for (size_t i = 0; i < num_static_props; ++i) { |
647 | 8.66k | static_props[i].reserve(static_props[i].size() + extra_num_samples); |
648 | 8.66k | } |
649 | 30.9k | for (auto &p : props) { |
650 | 30.9k | p.reserve(p.size() + extra_num_samples); |
651 | 30.9k | } |
652 | 5.66k | size_t total_num_samples = extra_num_samples + sample_counts.size(); |
653 | 5.66k | size_t next_size = CeilLog2Nonzero(total_num_samples * 3 / 2); |
654 | 5.66k | InitTable(next_size); |
655 | 5.66k | } |
656 | | |
657 | 55.8M | size_t TreeSamples::Hash1(size_t a) const { |
658 | 55.8M | constexpr uint64_t constant = 0x1e35a7bd; |
659 | 55.8M | uint64_t h = constant; |
660 | 55.8M | for (const auto &r : residuals) { |
661 | 55.8M | h = h * constant + r[a].tok; |
662 | 55.8M | h = h * constant + r[a].nbits; |
663 | 55.8M | } |
664 | 130M | for (size_t i = 0; i < num_static_props; ++i) { |
665 | 74.4M | h = h * constant + static_props[i][a]; |
666 | 74.4M | } |
667 | 316M | for (const auto &p : props) { |
668 | 316M | h = h * constant + p[a]; |
669 | 316M | } |
670 | 55.8M | return (h >> 16) & (dedup_table_.size() - 1); |
671 | 55.8M | } |
672 | 55.8M | size_t TreeSamples::Hash2(size_t a) const { |
673 | 55.8M | constexpr uint64_t constant = 0x1e35a7bd1e35a7bd; |
674 | 55.8M | uint64_t h = constant; |
675 | 130M | for (size_t i = 0; i < num_static_props; ++i) { |
676 | 74.4M | h = h * constant ^ static_props[i][a]; |
677 | 74.4M | } |
678 | 316M | for (const auto &p : props) { |
679 | 316M | h = h * constant ^ p[a]; |
680 | 316M | } |
681 | 55.8M | for (const auto &r : residuals) { |
682 | 55.8M | h = h * constant ^ r[a].tok; |
683 | 55.8M | h = h * constant ^ r[a].nbits; |
684 | 55.8M | } |
685 | 55.8M | return (h >> 16) & (dedup_table_.size() - 1); |
686 | 55.8M | } |
687 | | |
688 | 47.0M | bool TreeSamples::IsSameSample(size_t a, size_t b) const { |
689 | 47.0M | bool ret = true; |
690 | 47.0M | for (const auto &r : residuals) { |
691 | 47.0M | if (r[a].tok != r[b].tok) { |
692 | 3.93M | ret = false; |
693 | 3.93M | } |
694 | 47.0M | if (r[a].nbits != r[b].nbits) { |
695 | 2.94M | ret = false; |
696 | 2.94M | } |
697 | 47.0M | } |
698 | 109M | for (size_t i = 0; i < num_static_props; ++i) { |
699 | 61.9M | if (static_props[i][a] != static_props[i][b]) { |
700 | 2.47M | ret = false; |
701 | 2.47M | } |
702 | 61.9M | } |
703 | 267M | for (const auto &p : props) { |
704 | 267M | if (p[a] != p[b]) { |
705 | 23.7M | ret = false; |
706 | 23.7M | } |
707 | 267M | } |
708 | 47.0M | return ret; |
709 | 47.0M | } |
710 | | |
// Appends one pixel's sample: per-predictor residual tokens plus quantized
// property values, then deduplicates against previously stored samples.
// `predictions` is indexed by predictor id; `properties` by property id.
void TreeSamples::AddSample(pixel_type_w pixel, const Properties &properties,
                            const pixel_type_w *predictions) {
  for (size_t i = 0; i < predictors.size(); i++) {
    pixel_type v = pixel - predictions[static_cast<int>(predictors[i])];
    uint32_t tok, nbits, bits;
    // Tokenize the packed-signed residual; only token and bit-count are kept
    // (the raw bits are not needed for entropy estimation).
    HybridUintConfig(4, 1, 2).Encode(PackSigned(v), &tok, &nbits, &bits);
    JXL_DASSERT(tok < 256);
    JXL_DASSERT(nbits < 256);
    ResidualToken token = {static_cast<uint8_t>(tok),
                           static_cast<uint8_t>(nbits)};
    residuals[i].push_back(token);
  }
  for (size_t i = 0; i < num_static_props; ++i) {
    static_props[i].push_back(QuantizeStaticProperty(i, properties[i]));
  }
  for (size_t i = num_static_props; i < props_to_use.size(); i++) {
    props[i - num_static_props].push_back(QuantizeProperty(i, properties[props_to_use[i]]));
  }
  sample_counts.push_back(1);
  num_samples++;
  // The sample is pushed speculatively: if AddToTableAndMerge finds an
  // identical stored sample (and bumps its count), undo every push_back so
  // only distinct samples remain in the columns.
  if (AddToTableAndMerge(sample_counts.size() - 1)) {
    for (auto &r : residuals) r.pop_back();
    for (size_t i = 0; i < num_static_props; ++i) static_props[i].pop_back();
    for (auto &p : props) p.pop_back();
    sample_counts.pop_back();
  }
}
738 | | |
739 | 9.72M | void TreeSamples::Swap(size_t a, size_t b) { |
740 | 9.72M | if (a == b) return; |
741 | 9.72M | for (auto &r : residuals) { |
742 | 9.72M | std::swap(r[a], r[b]); |
743 | 9.72M | } |
744 | 22.9M | for (size_t i = 0; i < num_static_props; ++i) { |
745 | 13.2M | std::swap(static_props[i][a], static_props[i][b]); |
746 | 13.2M | } |
747 | 54.8M | for (auto &p : props) { |
748 | 54.8M | std::swap(p[a], p[b]); |
749 | 54.8M | } |
750 | 9.72M | std::swap(sample_counts[a], sample_counts[b]); |
751 | 9.72M | } |
752 | | |
753 | | namespace { |
754 | | std::vector<int32_t> QuantizeHistogram(const std::vector<uint32_t> &histogram, |
755 | 7.43k | size_t num_chunks) { |
756 | 7.43k | if (histogram.empty() || num_chunks == 0) return {}; |
757 | 7.43k | uint64_t sum = std::accumulate(histogram.begin(), histogram.end(), 0LU); |
758 | 7.43k | if (sum == 0) return {}; |
759 | | // TODO(veluca): selecting distinct quantiles is likely not the best |
760 | | // way to go about this. |
761 | 7.20k | std::vector<int32_t> thresholds; |
762 | 7.20k | uint64_t cumsum = 0; |
763 | 7.20k | uint64_t threshold = 1; |
764 | 4.01M | for (size_t i = 0; i < histogram.size(); i++) { |
765 | 4.00M | cumsum += histogram[i]; |
766 | 4.00M | if (cumsum * num_chunks >= threshold * sum) { |
767 | 42.7k | thresholds.push_back(i); |
768 | 734k | while (cumsum * num_chunks >= threshold * sum) threshold++; |
769 | 42.7k | } |
770 | 4.00M | } |
771 | 7.20k | JXL_DASSERT(thresholds.size() <= num_chunks); |
772 | | // last value collects all histogram and is not really a threshold |
773 | 7.20k | thresholds.pop_back(); |
774 | 7.20k | return thresholds; |
775 | 7.43k | } |
776 | | |
777 | | std::vector<int32_t> QuantizeSamples(const std::vector<int32_t> &samples, |
778 | 5.03k | size_t num_chunks) { |
779 | 5.03k | if (samples.empty()) return {}; |
780 | 3.90k | int min = *std::min_element(samples.begin(), samples.end()); |
781 | 3.90k | constexpr int kRange = 512; |
782 | 3.90k | min = jxl::Clamp1(min, -kRange, kRange); |
783 | 3.90k | std::vector<uint32_t> counts(2 * kRange + 1); |
784 | 6.63M | for (int s : samples) { |
785 | 6.63M | uint32_t sample_offset = jxl::Clamp1(s, -kRange, kRange) - min; |
786 | 6.63M | counts[sample_offset]++; |
787 | 6.63M | } |
788 | 3.90k | std::vector<int32_t> thresholds = QuantizeHistogram(counts, num_chunks); |
789 | 33.3k | for (auto &v : thresholds) v += min; |
790 | 3.90k | return thresholds; |
791 | 5.03k | } |
792 | | |
// `to[i]` is assigned value `v` conforming `from[v] <= i && from[v-1] > i`.
// This is because the decision node in the tree splits on (property) > i,
// hence everything that is not > of a threshold should be clustered
// together.
template <typename T>
void QuantMap(const std::vector<int32_t> &from, std::vector<T> &to,
              size_t num_pegs, int bias) {
  to.resize(num_pegs);
  // `from` is sorted, so a single forward cursor suffices: it always points
  // at the first threshold that is >= the current (biased) peg value.
  size_t cursor = 0;
  for (size_t peg = 0; peg < num_pegs; peg++) {
    const int value = static_cast<int>(peg) - bias;
    while (cursor < from.size() && value > from[cursor]) cursor++;
    JXL_DASSERT(static_cast<T>(cursor) == cursor);
    to[peg] = cursor;
  }
}
810 | | } // namespace |
811 | | |
// For each property in `props_to_use`, computes a compact set of split
// thresholds (`compact_properties[i]`) and a dense lookup table mapping raw
// property values in [-kPropertyRange, kPropertyRange] to their quantized
// bucket (`static_property_mapping` for the first `num_static_props`
// properties, `property_mapping` for the rest).
//
// range: full static-property range for this tree.
// multiplier_info: forced channel/group splits; when any are present, their
//   bounds replace the histogram-derived thresholds for those properties.
// group_pixel_count / channel_pixel_count: pixels-per-group / per-channel
//   histograms, used to pick equal-population thresholds.
// pixel_samples / diff_samples: sampled pixel values and horizontal
//   differences. NOTE: mutated in place (values replaced by their absolute
//   value) when absolute-value thresholds are requested below.
// max_property_values: upper bound on the number of quantization buckets.
void TreeSamples::PreQuantizeProperties(
    const StaticPropRange &range,
    const std::vector<ModularMultiplierInfo> &multiplier_info,
    const std::vector<uint32_t> &group_pixel_count,
    const std::vector<uint32_t> &channel_pixel_count,
    std::vector<pixel_type> &pixel_samples,
    std::vector<pixel_type> &diff_samples, size_t max_property_values) {
  // If we have forced splits because of multipliers, choose channel and group
  // thresholds accordingly.
  std::vector<int32_t> group_multiplier_thresholds;
  std::vector<int32_t> channel_multiplier_thresholds;
  for (const auto &v : multiplier_info) {
    // A multiplier range bound that differs from the overall range is a
    // boundary the tree must be able to split on; store it as `bound - 1`
    // because tree decisions are of the form (property) > threshold.
    if (v.range[0][0] != range[0][0]) {
      channel_multiplier_thresholds.push_back(v.range[0][0] - 1);
    }
    if (v.range[0][1] != range[0][1]) {
      channel_multiplier_thresholds.push_back(v.range[0][1] - 1);
    }
    if (v.range[1][0] != range[1][0]) {
      group_multiplier_thresholds.push_back(v.range[1][0] - 1);
    }
    if (v.range[1][1] != range[1][1]) {
      group_multiplier_thresholds.push_back(v.range[1][1] - 1);
    }
  }
  // Sort + unique (classic erase idiom) to deduplicate the forced thresholds.
  std::sort(channel_multiplier_thresholds.begin(),
            channel_multiplier_thresholds.end());
  channel_multiplier_thresholds.resize(
      std::unique(channel_multiplier_thresholds.begin(),
                  channel_multiplier_thresholds.end()) -
      channel_multiplier_thresholds.begin());
  std::sort(group_multiplier_thresholds.begin(),
            group_multiplier_thresholds.end());
  group_multiplier_thresholds.resize(
      std::unique(group_multiplier_thresholds.begin(),
                  group_multiplier_thresholds.end()) -
      group_multiplier_thresholds.begin());

  compact_properties.resize(props_to_use.size());
  // Forced multiplier thresholds take precedence; otherwise quantize the
  // per-channel pixel-count histogram.
  auto quantize_channel = [&]() {
    if (!channel_multiplier_thresholds.empty()) {
      return channel_multiplier_thresholds;
    }
    return QuantizeHistogram(channel_pixel_count, max_property_values);
  };
  // Same as above, for the group-id property.
  auto quantize_group_id = [&]() {
    if (!group_multiplier_thresholds.empty()) {
      return group_multiplier_thresholds;
    }
    return QuantizeHistogram(group_pixel_count, max_property_values);
  };
  // Coordinates get uniformly spaced thresholds over [0, 255].
  auto quantize_coordinate = [&]() {
    std::vector<int32_t> quantized;
    quantized.reserve(max_property_values - 1);
    for (size_t i = 0; i + 1 < max_property_values; i++) {
      quantized.push_back((i + 1) * 256 / max_property_values - 1);
    }
    return quantized;
  };
  // The following lambdas cache their result in the captured vectors so each
  // threshold set is computed at most once even if several properties use it.
  std::vector<int32_t> abs_pixel_thresholds;
  std::vector<int32_t> pixel_thresholds;
  auto quantize_pixel_property = [&]() {
    if (pixel_thresholds.empty()) {
      pixel_thresholds = QuantizeSamples(pixel_samples, max_property_values);
    }
    return pixel_thresholds;
  };
  auto quantize_abs_pixel_property = [&]() {
    if (abs_pixel_thresholds.empty()) {
      quantize_pixel_property();  // Compute the non-abs thresholds.
      // Destructive: pixel_samples is replaced by absolute values, so the
      // non-abs thresholds must be computed first (done just above).
      for (auto &v : pixel_samples) v = std::abs(v);
      abs_pixel_thresholds =
          QuantizeSamples(pixel_samples, max_property_values);
    }
    return abs_pixel_thresholds;
  };
  std::vector<int32_t> abs_diff_thresholds;
  std::vector<int32_t> diff_thresholds;
  auto quantize_diff_property = [&]() {
    if (diff_thresholds.empty()) {
      diff_thresholds = QuantizeSamples(diff_samples, max_property_values);
    }
    return diff_thresholds;
  };
  auto quantize_abs_diff_property = [&]() {
    if (abs_diff_thresholds.empty()) {
      quantize_diff_property();  // Compute the non-abs thresholds.
      // Destructive on diff_samples, same pattern as the abs-pixel case.
      for (auto &v : diff_samples) v = std::abs(v);
      abs_diff_thresholds = QuantizeSamples(diff_samples, max_property_values);
    }
    return abs_diff_thresholds;
  };
  // The weighted-predictor property uses fixed, roughly logarithmic grids
  // (denser near 0), with resolution chosen by max_property_values.
  auto quantize_wp = [&]() {
    if (max_property_values < 32) {
      return std::vector<int32_t>{-127, -63, -31, -15, -7, -3, -1, 0,
                                  1,    3,   7,   15,  31, 63, 127};
    }
    if (max_property_values < 64) {
      return std::vector<int32_t>{-255, -191, -127, -95, -63, -47, -31, -23,
                                  -15,  -11,  -7,   -5,  -3,  -1,  0,   1,
                                  3,    5,    7,    11,  15,  23,  31,  47,
                                  63,   95,   127,  191, 255};
    }
    return std::vector<int32_t>{
        -255, -223, -191, -159, -127, -111, -95, -79, -63, -55, -47,
        -39,  -31,  -27,  -23,  -19,  -15,  -13, -11, -9,  -7,  -6,
        -5,   -4,   -3,   -2,   -1,   0,    1,   2,   3,   4,   5,
        6,    7,    9,    11,   13,   15,   19,  23,  27,  31,  39,
        47,   55,   63,   79,   95,   111,  127, 159, 191, 223, 255};
  };

  property_mapping.resize(props_to_use.size() - num_static_props);
  for (size_t i = 0; i < props_to_use.size(); i++) {
    // Dispatch on the property id to the appropriate threshold computation.
    // Ids 0/1 are channel and group, 2/3 coordinates; the remaining cases
    // split pixel-value-like, abs-pixel-like, abs-diff-like, WP and diff-like
    // properties (reference-channel properties cycle mod 4 past
    // kNumNonrefProperties).
    if (props_to_use[i] == 0) {
      compact_properties[i] = quantize_channel();
    } else if (props_to_use[i] == 1) {
      compact_properties[i] = quantize_group_id();
    } else if (props_to_use[i] == 2 || props_to_use[i] == 3) {
      compact_properties[i] = quantize_coordinate();
    } else if (props_to_use[i] == 6 || props_to_use[i] == 7 ||
               props_to_use[i] == 8 ||
               (props_to_use[i] >= kNumNonrefProperties &&
                (props_to_use[i] - kNumNonrefProperties) % 4 == 1)) {
      compact_properties[i] = quantize_pixel_property();
    } else if (props_to_use[i] == 4 || props_to_use[i] == 5 ||
               (props_to_use[i] >= kNumNonrefProperties &&
                (props_to_use[i] - kNumNonrefProperties) % 4 == 0)) {
      compact_properties[i] = quantize_abs_pixel_property();
    } else if (props_to_use[i] >= kNumNonrefProperties &&
               (props_to_use[i] - kNumNonrefProperties) % 4 == 2) {
      compact_properties[i] = quantize_abs_diff_property();
    } else if (props_to_use[i] == kWPProp) {
      compact_properties[i] = quantize_wp();
    } else {
      compact_properties[i] = quantize_diff_property();
    }
    // Expand the sparse threshold list into a dense bucket-lookup table over
    // the full (biased) property range.
    if (i < num_static_props) {
      QuantMap(compact_properties[i], static_property_mapping[i],
               kPropertyRange * 2 + 1, kPropertyRange);
    } else {
      QuantMap(compact_properties[i], property_mapping[i - num_static_props],
               kPropertyRange * 2 + 1, kPropertyRange);
    }
  }
}
957 | | |
// Accumulates per-group and per-channel pixel counts for `image`, and draws a
// sparse pseudo-random subset of pixels — appending each sampled value to
// `pixel_samples` and its difference with the horizontally adjacent pixel to
// `diff_samples` — for later property quantization.
void CollectPixelSamples(const Image &image, const ModularOptions &options,
                         uint32_t group_id,
                         std::vector<uint32_t> &group_pixel_count,
                         std::vector<uint32_t> &channel_pixel_count,
                         std::vector<pixel_type> &pixel_samples,
                         std::vector<pixel_type> &diff_samples) {
  if (options.nb_repeats == 0) return;
  // Grow the count vectors on demand so the indexing below stays in bounds.
  if (group_pixel_count.size() <= group_id) {
    group_pixel_count.resize(group_id + 1);
  }
  if (channel_pixel_count.size() < image.channel.size()) {
    channel_pixel_count.resize(image.channel.size());
  }
  // Seed with the group id so sampling is deterministic per group.
  Rng rng(group_id);
  // Sample 10% of the final number of samples for property quantization.
  float fraction = std::min(options.nb_repeats * 0.1, 0.99);
  Rng::GeometricDistribution dist = Rng::MakeGeometric(fraction);
  size_t total_pixels = 0;
  std::vector<size_t> channel_ids;
  for (size_t i = 0; i < image.channel.size(); i++) {
    // Stop at the first non-meta channel exceeding the configured size limit.
    if (i >= image.nb_meta_channels &&
        (image.channel[i].w > options.max_chan_size ||
         image.channel[i].h > options.max_chan_size)) {
      break;
    }
    if (image.channel[i].w <= 1 || image.channel[i].h == 0) {
      continue;  // skip empty or width-1 channels.
    }
    channel_ids.push_back(i);
    group_pixel_count[group_id] += image.channel[i].w * image.channel[i].h;
    channel_pixel_count[i] += image.channel[i].w * image.channel[i].h;
    total_pixels += image.channel[i].w * image.channel[i].h;
  }
  if (channel_ids.empty()) return;
  pixel_samples.reserve(pixel_samples.size() + fraction * total_pixels);
  diff_samples.reserve(diff_samples.size() + fraction * total_pixels);
  // (i, y, x) together form a cursor over the concatenated pixels of all
  // selected channels; `advance` moves it forward by `amount` pixels,
  // wrapping across rows and channels, and is shared with the loop below.
  size_t i = 0;
  size_t y = 0;
  size_t x = 0;
  auto advance = [&](size_t amount) {
    x += amount;
    // Detect row overflow (rare).
    while (x >= image.channel[channel_ids[i]].w) {
      x -= image.channel[channel_ids[i]].w;
      y++;
      // Detect end-of-channel (even rarer).
      if (y == image.channel[channel_ids[i]].h) {
        i++;
        y = 0;
        if (i >= channel_ids.size()) {
          return;
        }
      }
    }
  };
  // Geometric skip distances give each pixel an (approximately) independent
  // `fraction` chance of being sampled, without visiting every pixel.
  advance(rng.Geometric(dist));
  for (; i < channel_ids.size(); advance(rng.Geometric(dist) + 1)) {
    const pixel_type *row = image.channel[channel_ids[i]].Row(y);
    pixel_samples.push_back(row[x]);
    // Use the left neighbor for the diff; at x == 0 fall back to the right
    // neighbor (valid because width-1 channels were skipped above).
    size_t xp = x == 0 ? 1 : x - 1;
    diff_samples.push_back(static_cast<int64_t>(row[x]) - row[xp]);
  }
}
1021 | | |
// TODO(veluca): very simple encoding scheme. This should be improved.
// Serializes `tree` breadth-first into `tokens` and simultaneously rebuilds
// it in BFS order into `decoder_tree`, the layout the decoder reconstructs.
Status TokenizeTree(const Tree &tree, std::vector<Token> *tokens,
                    Tree *decoder_tree) {
  JXL_ENSURE(tree.size() <= kMaxTreeSize);
  std::queue<int> q;
  q.push(0);  // start the BFS at the root node.
  size_t leaf_id = 0;
  decoder_tree->clear();
  while (!q.empty()) {
    int cur = q.front();
    q.pop();
    JXL_ENSURE(tree[cur].property >= -1);
    // Property is emitted with a +1 bias so that -1 (leaf marker) becomes 0.
    tokens->emplace_back(kPropertyContext, tree[cur].property + 1);
    if (tree[cur].property == -1) {
      // Leaf: emit predictor, signed offset, and the multiplier split into
      // its power-of-two part (mul_log) and remaining odd factor minus one
      // (mul_bits).
      tokens->emplace_back(kPredictorContext,
                           static_cast<int>(tree[cur].predictor));
      tokens->emplace_back(kOffsetContext,
                           PackSigned(tree[cur].predictor_offset));
      uint32_t mul_log = Num0BitsBelowLS1Bit_Nonzero(tree[cur].multiplier);
      uint32_t mul_bits = (tree[cur].multiplier >> mul_log) - 1;
      tokens->emplace_back(kMultiplierLogContext, mul_log);
      tokens->emplace_back(kMultiplierBitsContext, mul_bits);
      JXL_ENSURE(tree[cur].predictor < Predictor::Best);
      decoder_tree->emplace_back(
          -1, 0, static_cast<int>(leaf_id), 0, tree[cur].predictor,
          tree[cur].predictor_offset, tree[cur].multiplier);
      leaf_id++;
      continue;
    }
    // Interior node: children are visited in BFS order, so their positions in
    // `decoder_tree` are the nodes emitted so far plus the nodes still queued,
    // +1 for the left child and +2 for the right child.
    decoder_tree->emplace_back(
        tree[cur].property, tree[cur].splitval,
        static_cast<int>(decoder_tree->size() + q.size() + 1),
        static_cast<int>(decoder_tree->size() + q.size() + 2), Predictor::Zero,
        0, 1);
    q.push(tree[cur].lchild);
    q.push(tree[cur].rchild);
    tokens->emplace_back(kSplitValContext, PackSigned(tree[cur].splitval));
  }
  return true;
}
1062 | | |
1063 | | } // namespace jxl |
1064 | | #endif // HWY_ONCE |