/src/libjxl/lib/jxl/modular/encoding/enc_ma.cc
Line | Count | Source |
1 | | // Copyright (c) the JPEG XL Project Authors. All rights reserved. |
2 | | // |
3 | | // Use of this source code is governed by a BSD-style |
4 | | // license that can be found in the LICENSE file. |
5 | | |
6 | | #include "lib/jxl/modular/encoding/enc_ma.h" |
7 | | |
8 | | #include <algorithm> |
9 | | #include <cstdint> |
10 | | #include <cstdlib> |
11 | | #include <cstring> |
12 | | #include <limits> |
13 | | #include <numeric> |
14 | | #include <queue> |
15 | | #include <vector> |
16 | | |
17 | | #include "lib/jxl/ans_params.h" |
18 | | #include "lib/jxl/base/bits.h" |
19 | | #include "lib/jxl/base/common.h" |
20 | | #include "lib/jxl/base/compiler_specific.h" |
21 | | #include "lib/jxl/base/status.h" |
22 | | #include "lib/jxl/dec_ans.h" |
23 | | #include "lib/jxl/modular/encoding/dec_ma.h" |
24 | | #include "lib/jxl/modular/encoding/ma_common.h" |
25 | | #include "lib/jxl/modular/modular_image.h" |
26 | | |
27 | | #undef HWY_TARGET_INCLUDE |
28 | | #define HWY_TARGET_INCLUDE "lib/jxl/modular/encoding/enc_ma.cc" |
29 | | #include <hwy/foreach_target.h> |
30 | | #include <hwy/highway.h> |
31 | | |
32 | | #include "lib/jxl/base/fast_math-inl.h" |
33 | | #include "lib/jxl/base/random.h" |
34 | | #include "lib/jxl/enc_ans.h" |
35 | | #include "lib/jxl/modular/encoding/context_predict.h" |
36 | | #include "lib/jxl/modular/options.h" |
37 | | #include "lib/jxl/pack_signed.h" |
38 | | HWY_BEFORE_NAMESPACE(); |
39 | | namespace jxl { |
40 | | namespace HWY_NAMESPACE { |
41 | | |
42 | | // These templates are not found via ADL. |
43 | | using hwy::HWY_NAMESPACE::Eq; |
44 | | using hwy::HWY_NAMESPACE::IfThenElse; |
45 | | using hwy::HWY_NAMESPACE::Lt; |
46 | | using hwy::HWY_NAMESPACE::Max; |
47 | | |
48 | | const HWY_FULL(float) df; |
49 | | const HWY_FULL(int32_t) di; |
50 | 326k | size_t Padded(size_t x) { return RoundUpTo(x, Lanes(df)); }Unexecuted instantiation: jxl::N_SSE4::Padded(unsigned long) jxl::N_AVX2::Padded(unsigned long) Line | Count | Source | 50 | 326k | size_t Padded(size_t x) { return RoundUpTo(x, Lanes(df)); } |
Unexecuted instantiation: jxl::N_SSE2::Padded(unsigned long) |
51 | | |
// Compute entropy of the histogram, taking into account the minimum probability
// for symbols with non-zero counts.
// `num_symbols` must be a multiple of Lanes(df) (callers pad it via Padded());
// LoadU below reads full vectors. Assumes `total` is nonzero — presumably the
// histogram always has at least one count; TODO confirm with callers.
float EstimateBits(const int32_t *counts, size_t num_symbols) {
  int32_t total = std::accumulate(counts, counts + num_symbols, 0);
  const auto zero = Zero(df);
  const auto minprob = Set(df, 1.0f / ANS_TAB_SIZE);
  const auto inv_total = Set(df, 1.0f / total);
  auto bits_lanes = Zero(df);
  auto total_v = Set(di, total);
  for (size_t i = 0; i < num_symbols; i += Lanes(df)) {
    const auto counts_iv = LoadU(di, &counts[i]);
    const auto counts_fv = ConvertTo(df, counts_iv);
    // Probability of each symbol, clamped below by the minimum ANS
    // probability so log2 stays finite for tiny nonzero counts.
    const auto probs = Mul(counts_fv, inv_total);
    const auto mprobs = Max(probs, minprob);
    // Negative bits-per-symbol. If one symbol holds ALL counts its cost is
    // exactly zero; the select is done in the integer domain (via BitCast)
    // because Eq on integer vectors yields an integer mask.
    const auto nbps = IfThenElse(Eq(counts_iv, total_v), BitCast(di, zero),
                                 BitCast(di, FastLog2f(df, mprobs)));
    // Accumulate count * log2(prob) per lane (negated into `bits_lanes`).
    bits_lanes = Sub(bits_lanes, Mul(counts_fv, BitCast(df, nbps)));
  }
  // Horizontal sum of the per-lane partial entropies.
  return GetLane(SumOfLanes(df, bits_lanes));
}
72 | | |
73 | | void MakeSplitNode(size_t pos, int property, int splitval, Predictor lpred, |
74 | 162k | int64_t loff, Predictor rpred, int64_t roff, Tree *tree) { |
75 | | // Note that the tree splits on *strictly greater*. |
76 | 162k | (*tree)[pos].lchild = tree->size(); |
77 | 162k | (*tree)[pos].rchild = tree->size() + 1; |
78 | 162k | (*tree)[pos].splitval = splitval; |
79 | 162k | (*tree)[pos].property = property; |
80 | 162k | tree->emplace_back(); |
81 | 162k | tree->back().property = -1; |
82 | 162k | tree->back().predictor = rpred; |
83 | 162k | tree->back().predictor_offset = roff; |
84 | 162k | tree->back().multiplier = 1; |
85 | 162k | tree->emplace_back(); |
86 | 162k | tree->back().property = -1; |
87 | 162k | tree->back().predictor = lpred; |
88 | 162k | tree->back().predictor_offset = loff; |
89 | 162k | tree->back().multiplier = 1; |
90 | 162k | } Unexecuted instantiation: jxl::N_SSE4::MakeSplitNode(unsigned long, int, int, jxl::Predictor, long, jxl::Predictor, long, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*) jxl::N_AVX2::MakeSplitNode(unsigned long, int, int, jxl::Predictor, long, jxl::Predictor, long, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*) Line | Count | Source | 74 | 162k | int64_t loff, Predictor rpred, int64_t roff, Tree *tree) { | 75 | | // Note that the tree splits on *strictly greater*. | 76 | 162k | (*tree)[pos].lchild = tree->size(); | 77 | 162k | (*tree)[pos].rchild = tree->size() + 1; | 78 | 162k | (*tree)[pos].splitval = splitval; | 79 | 162k | (*tree)[pos].property = property; | 80 | 162k | tree->emplace_back(); | 81 | 162k | tree->back().property = -1; | 82 | 162k | tree->back().predictor = rpred; | 83 | 162k | tree->back().predictor_offset = roff; | 84 | 162k | tree->back().multiplier = 1; | 85 | 162k | tree->emplace_back(); | 86 | 162k | tree->back().property = -1; | 87 | 162k | tree->back().predictor = lpred; | 88 | 162k | tree->back().predictor_offset = loff; | 89 | 162k | tree->back().multiplier = 1; | 90 | 162k | } |
Unexecuted instantiation: jxl::N_SSE2::MakeSplitNode(unsigned long, int, int, jxl::Predictor, long, jxl::Predictor, long, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*) |
91 | | |
92 | | enum class IntersectionType { kNone, kPartial, kInside }; |
93 | | IntersectionType BoxIntersects(StaticPropRange needle, StaticPropRange haystack, |
94 | 86.2k | uint32_t &partial_axis, uint32_t &partial_val) { |
95 | 86.2k | bool partial = false; |
96 | 258k | for (size_t i = 0; i < kNumStaticProperties; i++) { |
97 | 172k | if (haystack[i][0] >= needle[i][1]) { |
98 | 0 | return IntersectionType::kNone; |
99 | 0 | } |
100 | 172k | if (haystack[i][1] <= needle[i][0]) { |
101 | 0 | return IntersectionType::kNone; |
102 | 0 | } |
103 | 172k | if (haystack[i][0] <= needle[i][0] && haystack[i][1] >= needle[i][1]) { |
104 | 172k | continue; |
105 | 172k | } |
106 | 0 | partial = true; |
107 | 0 | partial_axis = i; |
108 | 0 | if (haystack[i][0] > needle[i][0] && haystack[i][0] < needle[i][1]) { |
109 | 0 | partial_val = haystack[i][0] - 1; |
110 | 0 | } else { |
111 | 0 | JXL_DASSERT(haystack[i][1] > needle[i][0] && |
112 | 0 | haystack[i][1] < needle[i][1]); |
113 | 0 | partial_val = haystack[i][1] - 1; |
114 | 0 | } |
115 | 0 | } |
116 | 86.2k | return partial ? IntersectionType::kPartial : IntersectionType::kInside; |
117 | 86.2k | } Unexecuted instantiation: jxl::N_SSE4::BoxIntersects(std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, unsigned int&, unsigned int&) jxl::N_AVX2::BoxIntersects(std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, unsigned int&, unsigned int&) Line | Count | Source | 94 | 86.2k | uint32_t &partial_axis, uint32_t &partial_val) { | 95 | 86.2k | bool partial = false; | 96 | 258k | for (size_t i = 0; i < kNumStaticProperties; i++) { | 97 | 172k | if (haystack[i][0] >= needle[i][1]) { | 98 | 0 | return IntersectionType::kNone; | 99 | 0 | } | 100 | 172k | if (haystack[i][1] <= needle[i][0]) { | 101 | 0 | return IntersectionType::kNone; | 102 | 0 | } | 103 | 172k | if (haystack[i][0] <= needle[i][0] && haystack[i][1] >= needle[i][1]) { | 104 | 172k | continue; | 105 | 172k | } | 106 | 0 | partial = true; | 107 | 0 | partial_axis = i; | 108 | 0 | if (haystack[i][0] > needle[i][0] && haystack[i][0] < needle[i][1]) { | 109 | 0 | partial_val = haystack[i][0] - 1; | 110 | 0 | } else { | 111 | 0 | JXL_DASSERT(haystack[i][1] > needle[i][0] && | 112 | 0 | haystack[i][1] < needle[i][1]); | 113 | 0 | partial_val = haystack[i][1] - 1; | 114 | 0 | } | 115 | 0 | } | 116 | 86.2k | return partial ? IntersectionType::kPartial : IntersectionType::kInside; | 117 | 86.2k | } |
Unexecuted instantiation: jxl::N_SSE2::BoxIntersects(std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, unsigned int&, unsigned int&) |
118 | | |
// Partitions the distinct samples in [begin, end) in place so that samples
// with Property<S>(prop, .) <= val end up in [begin, pos) and the rest in
// [pos, end). `pos` is the precomputed boundary: presumably exactly
// `pos - begin` samples satisfy the predicate — TODO confirm with the caller
// (FindBestSplit passes best->pos counted this way). `S` selects the static
// (true) vs. non-static (false) property accessor of TreeSamples.
template<bool S>
void SplitTreeSamples(TreeSamples &tree_samples, size_t begin, size_t pos,
                      size_t end, size_t prop, uint32_t val) {
  size_t begin_pos = begin;
  size_t end_pos = pos;
  do {
    // Skip over left-side samples that are already on the correct
    // (<= val) side of the boundary.
    while (begin_pos < pos &&
           tree_samples.Property<S>(prop, begin_pos) <= val) {
      ++begin_pos;
    }
    // Skip over right-side samples that are already correct (> val).
    while (end_pos < end && tree_samples.Property<S>(prop, end_pos) > val) {
      ++end_pos;
    }
    // If both cursors stopped on misplaced samples, exchange them.
    if (begin_pos < pos && end_pos < end) {
      tree_samples.Swap(begin_pos, end_pos);
    }
    // Advance past the pair just handled (or past the end, terminating).
    ++begin_pos;
    ++end_pos;
  } while (begin_pos < pos && end_pos < end);
}
139 | | |
140 | | template <bool S> |
141 | | void CollectExtraBitsIncrease(TreeSamples &tree_samples, |
142 | | const std::vector<ResidualToken> &rtokens, |
143 | | std::vector<int> &count_increase, |
144 | | std::vector<size_t> &extra_bits_increase, |
145 | | size_t begin, size_t end, size_t prop_idx, |
146 | 2.09M | size_t max_symbols) { |
147 | 337M | for (size_t i2 = begin; i2 < end; i2++) { |
148 | 335M | const ResidualToken &rt = rtokens[i2]; |
149 | 335M | size_t cnt = tree_samples.Count(i2); |
150 | 335M | size_t p = tree_samples.Property<S>(prop_idx, i2); |
151 | 335M | size_t sym = rt.tok; |
152 | 335M | size_t ebi = rt.nbits * cnt; |
153 | 335M | count_increase[p * max_symbols + sym] += cnt; |
154 | 335M | extra_bits_increase[p] += ebi; |
155 | 335M | } |
156 | 2.09M | } Unexecuted instantiation: void jxl::N_SSE4::CollectExtraBitsIncrease<true>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long) Unexecuted instantiation: void jxl::N_SSE4::CollectExtraBitsIncrease<false>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long) void jxl::N_AVX2::CollectExtraBitsIncrease<true>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long) Line | Count | Source | 146 | 384k | size_t max_symbols) { | 147 | 68.8M | for (size_t i2 = begin; i2 < end; i2++) { | 148 | 68.4M | const ResidualToken &rt = rtokens[i2]; | 149 | 68.4M | size_t cnt = tree_samples.Count(i2); | 150 | 68.4M | size_t p = tree_samples.Property<S>(prop_idx, i2); | 151 | 68.4M | size_t sym = rt.tok; | 152 | 68.4M | size_t ebi = rt.nbits * cnt; | 153 | 68.4M | count_increase[p * max_symbols + sym] += cnt; | 154 | 68.4M | extra_bits_increase[p] += ebi; | 155 | 68.4M | } | 156 | 384k | } |
void jxl::N_AVX2::CollectExtraBitsIncrease<false>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long) Line | Count | Source | 146 | 1.71M | size_t max_symbols) { | 147 | 269M | for (size_t i2 = begin; i2 < end; i2++) { | 148 | 267M | const ResidualToken &rt = rtokens[i2]; | 149 | 267M | size_t cnt = tree_samples.Count(i2); | 150 | 267M | size_t p = tree_samples.Property<S>(prop_idx, i2); | 151 | 267M | size_t sym = rt.tok; | 152 | 267M | size_t ebi = rt.nbits * cnt; | 153 | 267M | count_increase[p * max_symbols + sym] += cnt; | 154 | 267M | extra_bits_increase[p] += ebi; | 155 | 267M | } | 156 | 1.71M | } |
Unexecuted instantiation: void jxl::N_SSE2::CollectExtraBitsIncrease<true>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long) Unexecuted instantiation: void jxl::N_SSE2::CollectExtraBitsIncrease<false>(jxl::TreeSamples&, std::__1::vector<jxl::ResidualToken, std::__1::allocator<jxl::ResidualToken> > const&, std::__1::vector<int, std::__1::allocator<int> >&, std::__1::vector<unsigned long, std::__1::allocator<unsigned long> >&, unsigned long, unsigned long, unsigned long, unsigned long) |
157 | | |
158 | | void FindBestSplit(TreeSamples &tree_samples, float threshold, |
159 | | const std::vector<ModularMultiplierInfo> &mul_info, |
160 | | StaticPropRange initial_static_prop_range, |
161 | 1.91k | float fast_decode_multiplier, Tree *tree) { |
162 | 1.91k | struct NodeInfo { |
163 | 1.91k | size_t pos; |
164 | 1.91k | size_t begin; |
165 | 1.91k | size_t end; |
166 | 1.91k | StaticPropRange static_prop_range; |
167 | 1.91k | }; |
168 | 1.91k | std::vector<NodeInfo> nodes; |
169 | 1.91k | nodes.push_back(NodeInfo{0, 0, tree_samples.NumDistinctSamples(), |
170 | 1.91k | initial_static_prop_range}); |
171 | | |
172 | 1.91k | size_t num_predictors = tree_samples.NumPredictors(); |
173 | 1.91k | size_t num_properties = tree_samples.NumProperties(); |
174 | | |
175 | | // TODO(veluca): consider parallelizing the search (processing multiple nodes |
176 | | // at a time). |
177 | 327k | while (!nodes.empty()) { |
178 | 326k | size_t pos = nodes.back().pos; |
179 | 326k | size_t begin = nodes.back().begin; |
180 | 326k | size_t end = nodes.back().end; |
181 | | |
182 | 326k | StaticPropRange static_prop_range = nodes.back().static_prop_range; |
183 | 326k | nodes.pop_back(); |
184 | 326k | if (begin == end) continue; |
185 | | |
186 | 326k | struct SplitInfo { |
187 | 326k | size_t prop = 0; |
188 | 326k | uint32_t val = 0; |
189 | 326k | size_t pos = 0; |
190 | 326k | float lcost = std::numeric_limits<float>::max(); |
191 | 326k | float rcost = std::numeric_limits<float>::max(); |
192 | 326k | Predictor lpred = Predictor::Zero; |
193 | 326k | Predictor rpred = Predictor::Zero; |
194 | 13.4M | float Cost() const { return lcost + rcost; }Unexecuted instantiation: enc_ma.cc:jxl::N_SSE4::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::SplitInfo::Cost() const enc_ma.cc:jxl::N_AVX2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::SplitInfo::Cost() const Line | Count | Source | 194 | 13.4M | float Cost() const { return lcost + rcost; } |
Unexecuted instantiation: enc_ma.cc:jxl::N_SSE2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::SplitInfo::Cost() const |
195 | 326k | }; |
196 | | |
197 | 326k | SplitInfo best_split_static_constant; |
198 | 326k | SplitInfo best_split_static; |
199 | 326k | SplitInfo best_split_nonstatic; |
200 | 326k | SplitInfo best_split_nowp; |
201 | | |
202 | 326k | JXL_DASSERT(begin <= end); |
203 | 326k | JXL_DASSERT(end <= tree_samples.NumDistinctSamples()); |
204 | | |
205 | | // Compute the maximum token in the range. |
206 | 326k | size_t max_symbols = 0; |
207 | 652k | for (size_t pred = 0; pred < num_predictors; pred++) { |
208 | 48.5M | for (size_t i = begin; i < end; i++) { |
209 | 48.2M | uint32_t tok = tree_samples.Token(pred, i); |
210 | 48.2M | max_symbols = max_symbols > tok + 1 ? max_symbols : tok + 1; |
211 | 48.2M | } |
212 | 326k | } |
213 | 326k | max_symbols = Padded(max_symbols); |
214 | 326k | std::vector<int32_t> counts(max_symbols * num_predictors); |
215 | 326k | std::vector<uint32_t> tot_extra_bits(num_predictors); |
216 | 652k | for (size_t pred = 0; pred < num_predictors; pred++) { |
217 | 326k | size_t extra_bits = 0; |
218 | 326k | const std::vector<ResidualToken>& rtokens = tree_samples.RTokens(pred); |
219 | 48.5M | for (size_t i = begin; i < end; i++) { |
220 | 48.2M | const ResidualToken& rt = rtokens[i]; |
221 | 48.2M | size_t count = tree_samples.Count(i); |
222 | 48.2M | size_t eb = rt.nbits * count; |
223 | 48.2M | counts[pred * max_symbols + rt.tok] += count; |
224 | 48.2M | extra_bits += eb; |
225 | 48.2M | } |
226 | 326k | tot_extra_bits[pred] = extra_bits; |
227 | 326k | } |
228 | | |
229 | 326k | float base_bits; |
230 | 326k | { |
231 | 326k | size_t pred = tree_samples.PredictorIndex((*tree)[pos].predictor); |
232 | 326k | base_bits = |
233 | 326k | EstimateBits(counts.data() + pred * max_symbols, max_symbols) + |
234 | 326k | tot_extra_bits[pred]; |
235 | 326k | } |
236 | | |
237 | 326k | SplitInfo *best = &best_split_nonstatic; |
238 | | |
239 | 326k | SplitInfo forced_split; |
240 | | // The multiplier ranges cut halfway through the current ranges of static |
241 | | // properties. We do this even if the current node is not a leaf, to |
242 | | // minimize the number of nodes in the resulting tree. |
243 | 326k | for (const auto &mmi : mul_info) { |
244 | 86.2k | uint32_t axis; |
245 | 86.2k | uint32_t val; |
246 | 86.2k | IntersectionType t = |
247 | 86.2k | BoxIntersects(static_prop_range, mmi.range, axis, val); |
248 | 86.2k | if (t == IntersectionType::kNone) continue; |
249 | 86.2k | if (t == IntersectionType::kInside) { |
250 | 86.2k | (*tree)[pos].multiplier = mmi.multiplier; |
251 | 86.2k | break; |
252 | 86.2k | } |
253 | 0 | if (t == IntersectionType::kPartial) { |
254 | 0 | JXL_DASSERT(axis < kNumStaticProperties); |
255 | 0 | forced_split.val = tree_samples.QuantizeStaticProperty(axis, val); |
256 | 0 | forced_split.prop = axis; |
257 | 0 | forced_split.lcost = forced_split.rcost = base_bits / 2 - threshold; |
258 | 0 | forced_split.lpred = forced_split.rpred = (*tree)[pos].predictor; |
259 | 0 | best = &forced_split; |
260 | 0 | best->pos = begin; |
261 | 0 | JXL_DASSERT(best->prop == tree_samples.PropertyFromIndex(best->prop)); |
262 | 0 | if (best->prop < tree_samples.NumStaticProps()) { |
263 | 0 | for (size_t x = begin; x < end; x++) { |
264 | 0 | if (tree_samples.Property<true>(best->prop, x) <= best->val) { |
265 | 0 | best->pos++; |
266 | 0 | } |
267 | 0 | } |
268 | 0 | } else { |
269 | 0 | size_t prop = best->prop - tree_samples.NumStaticProps(); |
270 | 0 | for (size_t x = begin; x < end; x++) { |
271 | 0 | if (tree_samples.Property<false>(prop, x) <= best->val) { |
272 | 0 | best->pos++; |
273 | 0 | } |
274 | 0 | } |
275 | 0 | } |
276 | 0 | break; |
277 | 0 | } |
278 | 0 | } |
279 | | |
280 | 326k | if (best != &forced_split) { |
281 | 326k | std::vector<int> prop_value_used_count; |
282 | 326k | std::vector<int> count_increase; |
283 | 326k | std::vector<size_t> extra_bits_increase; |
284 | | // For each property, compute which of its values are used, and what |
285 | | // tokens correspond to those usages. Then, iterate through the values, |
286 | | // and compute the entropy of each side of the split (of the form `prop > |
287 | | // threshold`). Finally, find the split that minimizes the cost. |
288 | 326k | struct CostInfo { |
289 | 326k | float cost = std::numeric_limits<float>::max(); |
290 | 326k | float extra_cost = 0; |
291 | 23.6M | float Cost() const { return cost + extra_cost; }Unexecuted instantiation: enc_ma.cc:jxl::N_SSE4::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::CostInfo::Cost() const enc_ma.cc:jxl::N_AVX2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::CostInfo::Cost() const Line | Count | Source | 291 | 23.6M | float Cost() const { return cost + extra_cost; } |
Unexecuted instantiation: enc_ma.cc:jxl::N_SSE2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*)::CostInfo::Cost() const |
292 | 326k | Predictor pred; // will be uninitialized in some cases, but never used. |
293 | 326k | }; |
294 | 326k | std::vector<CostInfo> costs_l; |
295 | 326k | std::vector<CostInfo> costs_r; |
296 | | |
297 | 326k | std::vector<int32_t> counts_above(max_symbols); |
298 | 326k | std::vector<int32_t> counts_below(max_symbols); |
299 | | |
300 | | // The lower the threshold, the higher the expected noisiness of the |
301 | | // estimate. Thus, discourage changing predictors. |
302 | 326k | float change_pred_penalty = 800.0f / (100.0f + threshold); |
303 | 2.42M | for (size_t prop = 0; prop < num_properties && base_bits > threshold; |
304 | 2.09M | prop++) { |
305 | 2.09M | costs_l.clear(); |
306 | 2.09M | costs_r.clear(); |
307 | 2.09M | size_t prop_size = tree_samples.NumPropertyValues(prop); |
308 | 2.09M | if (extra_bits_increase.size() < prop_size) { |
309 | 763k | count_increase.resize(prop_size * max_symbols); |
310 | 763k | extra_bits_increase.resize(prop_size); |
311 | 763k | } |
312 | | // Clear prop_value_used_count (which cannot be cleared "on the go") |
313 | 2.09M | prop_value_used_count.clear(); |
314 | 2.09M | prop_value_used_count.resize(prop_size); |
315 | | |
316 | 2.09M | size_t first_used = prop_size; |
317 | 2.09M | size_t last_used = 0; |
318 | | |
319 | | // TODO(veluca): consider finding multiple splits along a single |
320 | | // property at the same time, possibly with a bottom-up approach. |
321 | 2.09M | if (prop < tree_samples.NumStaticProps()) { |
322 | 68.8M | for (size_t i = begin; i < end; i++) { |
323 | 68.4M | size_t p = tree_samples.Property<true>(prop, i); |
324 | 68.4M | prop_value_used_count[p]++; |
325 | 68.4M | last_used = std::max(last_used, p); |
326 | 68.4M | first_used = std::min(first_used, p); |
327 | 68.4M | } |
328 | 1.71M | } else { |
329 | 1.71M | size_t prop_idx = prop - tree_samples.NumStaticProps(); |
330 | 269M | for (size_t i = begin; i < end; i++) { |
331 | 267M | size_t p = tree_samples.Property<false>(prop_idx, i); |
332 | 267M | prop_value_used_count[p]++; |
333 | 267M | last_used = std::max(last_used, p); |
334 | 267M | first_used = std::min(first_used, p); |
335 | 267M | } |
336 | 1.71M | } |
337 | 2.09M | costs_l.resize(last_used - first_used); |
338 | 2.09M | costs_r.resize(last_used - first_used); |
339 | | // For all predictors, compute the right and left costs of each split. |
340 | 4.19M | for (size_t pred = 0; pred < num_predictors; pred++) { |
341 | | // Compute cost and histogram increments for each property value. |
342 | 2.09M | const std::vector<ResidualToken> &rtokens = |
343 | 2.09M | tree_samples.RTokens(pred); |
344 | 2.09M | if (prop < tree_samples.NumStaticProps()) { |
345 | 384k | CollectExtraBitsIncrease<true>(tree_samples, rtokens, |
346 | 384k | count_increase, extra_bits_increase, |
347 | 384k | begin, end, prop, max_symbols); |
348 | 1.71M | } else { |
349 | 1.71M | CollectExtraBitsIncrease<false>( |
350 | 1.71M | tree_samples, rtokens, count_increase, extra_bits_increase, |
351 | 1.71M | begin, end, prop - tree_samples.NumStaticProps(), max_symbols); |
352 | 1.71M | } |
353 | 2.09M | memcpy(counts_above.data(), counts.data() + pred * max_symbols, |
354 | 2.09M | max_symbols * sizeof counts_above[0]); |
355 | 2.09M | memset(counts_below.data(), 0, max_symbols * sizeof counts_below[0]); |
356 | 2.09M | size_t extra_bits_below = 0; |
357 | | // Exclude last used: this ensures neither counts_above nor |
358 | | // counts_below is empty. |
359 | 29.9M | for (size_t i = first_used; i < last_used; i++) { |
360 | 27.8M | if (!prop_value_used_count[i]) continue; |
361 | 11.8M | extra_bits_below += extra_bits_increase[i]; |
362 | | // The increase for this property value has been used, and will not |
363 | | // be used again: clear it. Also below. |
364 | 11.8M | extra_bits_increase[i] = 0; |
365 | 588M | for (size_t sym = 0; sym < max_symbols; sym++) { |
366 | 576M | counts_above[sym] -= count_increase[i * max_symbols + sym]; |
367 | 576M | counts_below[sym] += count_increase[i * max_symbols + sym]; |
368 | 576M | count_increase[i * max_symbols + sym] = 0; |
369 | 576M | } |
370 | 11.8M | float rcost = EstimateBits(counts_above.data(), max_symbols) + |
371 | 11.8M | tot_extra_bits[pred] - extra_bits_below; |
372 | 11.8M | float lcost = EstimateBits(counts_below.data(), max_symbols) + |
373 | 11.8M | extra_bits_below; |
374 | 11.8M | JXL_DASSERT(extra_bits_below <= tot_extra_bits[pred]); |
375 | 11.8M | float penalty = 0; |
376 | | // Never discourage moving away from the Weighted predictor. |
377 | 11.8M | if (tree_samples.PredictorFromIndex(pred) != |
378 | 11.8M | (*tree)[pos].predictor && |
379 | 0 | (*tree)[pos].predictor != Predictor::Weighted) { |
380 | 0 | penalty = change_pred_penalty; |
381 | 0 | } |
382 | | // If everything else is equal, disfavour Weighted (slower) and |
383 | | // favour Zero (faster if it's the only predictor used in a |
384 | | // group+channel combination) |
385 | 11.8M | if (tree_samples.PredictorFromIndex(pred) == Predictor::Weighted) { |
386 | 0 | penalty += 1e-8; |
387 | 0 | } |
388 | 11.8M | if (tree_samples.PredictorFromIndex(pred) == Predictor::Zero) { |
389 | 0 | penalty -= 1e-8; |
390 | 0 | } |
391 | 11.8M | if (rcost + penalty < costs_r[i - first_used].Cost()) { |
392 | 11.8M | costs_r[i - first_used].cost = rcost; |
393 | 11.8M | costs_r[i - first_used].extra_cost = penalty; |
394 | 11.8M | costs_r[i - first_used].pred = |
395 | 11.8M | tree_samples.PredictorFromIndex(pred); |
396 | 11.8M | } |
397 | 11.8M | if (lcost + penalty < costs_l[i - first_used].Cost()) { |
398 | 11.8M | costs_l[i - first_used].cost = lcost; |
399 | 11.8M | costs_l[i - first_used].extra_cost = penalty; |
400 | 11.8M | costs_l[i - first_used].pred = |
401 | 11.8M | tree_samples.PredictorFromIndex(pred); |
402 | 11.8M | } |
403 | 11.8M | } |
404 | 2.09M | } |
405 | | // Iterate through the possible splits and find the one with minimum sum |
406 | | // of costs of the two sides. |
407 | 2.09M | size_t split = begin; |
408 | 29.9M | for (size_t i = first_used; i < last_used; i++) { |
409 | 27.8M | if (!prop_value_used_count[i]) continue; |
410 | 11.8M | split += prop_value_used_count[i]; |
411 | 11.8M | float rcost = costs_r[i - first_used].cost; |
412 | 11.8M | float lcost = costs_l[i - first_used].cost; |
413 | | |
414 | 11.8M | bool uses_wp = tree_samples.PropertyFromIndex(prop) == kWPProp || |
415 | 9.70M | costs_l[i - first_used].pred == Predictor::Weighted || |
416 | 9.70M | costs_r[i - first_used].pred == Predictor::Weighted; |
417 | 11.8M | bool zero_entropy_side = rcost == 0 || lcost == 0; |
418 | | |
419 | 11.8M | SplitInfo &best_ref = |
420 | 11.8M | tree_samples.PropertyFromIndex(prop) < kNumStaticProperties |
421 | 11.8M | ? (zero_entropy_side ? best_split_static_constant |
422 | 22.2k | : best_split_static) |
423 | 11.8M | : (uses_wp ? best_split_nonstatic : best_split_nowp); |
424 | 11.8M | if (lcost + rcost < best_ref.Cost()) { |
425 | 2.16M | best_ref.prop = prop; |
426 | 2.16M | best_ref.val = i; |
427 | 2.16M | best_ref.pos = split; |
428 | 2.16M | best_ref.lcost = lcost; |
429 | 2.16M | best_ref.lpred = costs_l[i - first_used].pred; |
430 | 2.16M | best_ref.rcost = rcost; |
431 | 2.16M | best_ref.rpred = costs_r[i - first_used].pred; |
432 | 2.16M | } |
433 | 11.8M | } |
434 | | // Clear extra_bits_increase and cost_increase for last_used. |
435 | 2.09M | extra_bits_increase[last_used] = 0; |
436 | 104M | for (size_t sym = 0; sym < max_symbols; sym++) { |
437 | 102M | count_increase[last_used * max_symbols + sym] = 0; |
438 | 102M | } |
439 | 2.09M | } |
440 | | |
441 | | // Try to avoid introducing WP. |
442 | 326k | if (best_split_nowp.Cost() + threshold < base_bits && |
443 | 157k | best_split_nowp.Cost() <= fast_decode_multiplier * best->Cost()) { |
444 | 151k | best = &best_split_nowp; |
445 | 151k | } |
446 | | // Split along static props if possible and not significantly more |
447 | | // expensive. |
448 | 326k | if (best_split_static.Cost() + threshold < base_bits && |
449 | 7.95k | best_split_static.Cost() <= fast_decode_multiplier * best->Cost()) { |
450 | 3.26k | best = &best_split_static; |
451 | 3.26k | } |
452 | | // Split along static props to create constant nodes if possible. |
453 | 326k | if (best_split_static_constant.Cost() + threshold < base_bits) { |
454 | 386 | best = &best_split_static_constant; |
455 | 386 | } |
456 | 326k | } |
457 | | |
458 | 326k | if (best->Cost() + threshold < base_bits) { |
459 | 162k | uint32_t p = tree_samples.PropertyFromIndex(best->prop); |
460 | 162k | pixel_type dequant = |
461 | 162k | tree_samples.UnquantizeProperty(best->prop, best->val); |
462 | | // Split node and try to split children. |
463 | 162k | MakeSplitNode(pos, p, dequant, best->lpred, 0, best->rpred, 0, tree); |
464 | | // "Sort" according to winning property |
465 | 162k | if (best->prop < tree_samples.NumStaticProps()) { |
466 | 3.52k | SplitTreeSamples<true>(tree_samples, begin, best->pos, end, best->prop, |
467 | 3.52k | best->val); |
468 | 158k | } else { |
469 | 158k | SplitTreeSamples<false>(tree_samples, begin, best->pos, end, |
470 | 158k | best->prop - tree_samples.NumStaticProps(), |
471 | 158k | best->val); |
472 | 158k | } |
473 | 162k | auto new_sp_range = static_prop_range; |
474 | 162k | if (p < kNumStaticProperties) { |
475 | 3.52k | JXL_DASSERT(static_cast<uint32_t>(dequant + 1) <= new_sp_range[p][1]); |
476 | 3.52k | new_sp_range[p][1] = dequant + 1; |
477 | 3.52k | JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]); |
478 | 3.52k | } |
479 | 162k | nodes.push_back( |
480 | 162k | NodeInfo{(*tree)[pos].rchild, begin, best->pos, new_sp_range}); |
481 | 162k | new_sp_range = static_prop_range; |
482 | 162k | if (p < kNumStaticProperties) { |
483 | 3.52k | JXL_DASSERT(new_sp_range[p][0] <= static_cast<uint32_t>(dequant + 1)); |
484 | 3.52k | new_sp_range[p][0] = dequant + 1; |
485 | 3.52k | JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]); |
486 | 3.52k | } |
487 | 162k | nodes.push_back( |
488 | 162k | NodeInfo{(*tree)[pos].lchild, best->pos, end, new_sp_range}); |
489 | 162k | } |
490 | 326k | } |
491 | 1.91k | } Unexecuted instantiation: jxl::N_SSE4::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*) jxl::N_AVX2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*) Line | Count | Source | 161 | 1.91k | float fast_decode_multiplier, Tree *tree) { | 162 | 1.91k | struct NodeInfo { | 163 | 1.91k | size_t pos; | 164 | 1.91k | size_t begin; | 165 | 1.91k | size_t end; | 166 | 1.91k | StaticPropRange static_prop_range; | 167 | 1.91k | }; | 168 | 1.91k | std::vector<NodeInfo> nodes; | 169 | 1.91k | nodes.push_back(NodeInfo{0, 0, tree_samples.NumDistinctSamples(), | 170 | 1.91k | initial_static_prop_range}); | 171 | | | 172 | 1.91k | size_t num_predictors = tree_samples.NumPredictors(); | 173 | 1.91k | size_t num_properties = tree_samples.NumProperties(); | 174 | | | 175 | | // TODO(veluca): consider parallelizing the search (processing multiple nodes | 176 | | // at a time). 
| 177 | 327k | while (!nodes.empty()) { | 178 | 326k | size_t pos = nodes.back().pos; | 179 | 326k | size_t begin = nodes.back().begin; | 180 | 326k | size_t end = nodes.back().end; | 181 | | | 182 | 326k | StaticPropRange static_prop_range = nodes.back().static_prop_range; | 183 | 326k | nodes.pop_back(); | 184 | 326k | if (begin == end) continue; | 185 | | | 186 | 326k | struct SplitInfo { | 187 | 326k | size_t prop = 0; | 188 | 326k | uint32_t val = 0; | 189 | 326k | size_t pos = 0; | 190 | 326k | float lcost = std::numeric_limits<float>::max(); | 191 | 326k | float rcost = std::numeric_limits<float>::max(); | 192 | 326k | Predictor lpred = Predictor::Zero; | 193 | 326k | Predictor rpred = Predictor::Zero; | 194 | 326k | float Cost() const { return lcost + rcost; } | 195 | 326k | }; | 196 | | | 197 | 326k | SplitInfo best_split_static_constant; | 198 | 326k | SplitInfo best_split_static; | 199 | 326k | SplitInfo best_split_nonstatic; | 200 | 326k | SplitInfo best_split_nowp; | 201 | | | 202 | 326k | JXL_DASSERT(begin <= end); | 203 | 326k | JXL_DASSERT(end <= tree_samples.NumDistinctSamples()); | 204 | | | 205 | | // Compute the maximum token in the range. | 206 | 326k | size_t max_symbols = 0; | 207 | 652k | for (size_t pred = 0; pred < num_predictors; pred++) { | 208 | 48.5M | for (size_t i = begin; i < end; i++) { | 209 | 48.2M | uint32_t tok = tree_samples.Token(pred, i); | 210 | 48.2M | max_symbols = max_symbols > tok + 1 ? 
max_symbols : tok + 1; | 211 | 48.2M | } | 212 | 326k | } | 213 | 326k | max_symbols = Padded(max_symbols); | 214 | 326k | std::vector<int32_t> counts(max_symbols * num_predictors); | 215 | 326k | std::vector<uint32_t> tot_extra_bits(num_predictors); | 216 | 652k | for (size_t pred = 0; pred < num_predictors; pred++) { | 217 | 326k | size_t extra_bits = 0; | 218 | 326k | const std::vector<ResidualToken>& rtokens = tree_samples.RTokens(pred); | 219 | 48.5M | for (size_t i = begin; i < end; i++) { | 220 | 48.2M | const ResidualToken& rt = rtokens[i]; | 221 | 48.2M | size_t count = tree_samples.Count(i); | 222 | 48.2M | size_t eb = rt.nbits * count; | 223 | 48.2M | counts[pred * max_symbols + rt.tok] += count; | 224 | 48.2M | extra_bits += eb; | 225 | 48.2M | } | 226 | 326k | tot_extra_bits[pred] = extra_bits; | 227 | 326k | } | 228 | | | 229 | 326k | float base_bits; | 230 | 326k | { | 231 | 326k | size_t pred = tree_samples.PredictorIndex((*tree)[pos].predictor); | 232 | 326k | base_bits = | 233 | 326k | EstimateBits(counts.data() + pred * max_symbols, max_symbols) + | 234 | 326k | tot_extra_bits[pred]; | 235 | 326k | } | 236 | | | 237 | 326k | SplitInfo *best = &best_split_nonstatic; | 238 | | | 239 | 326k | SplitInfo forced_split; | 240 | | // The multiplier ranges cut halfway through the current ranges of static | 241 | | // properties. We do this even if the current node is not a leaf, to | 242 | | // minimize the number of nodes in the resulting tree. 
| 243 | 326k | for (const auto &mmi : mul_info) { | 244 | 86.2k | uint32_t axis; | 245 | 86.2k | uint32_t val; | 246 | 86.2k | IntersectionType t = | 247 | 86.2k | BoxIntersects(static_prop_range, mmi.range, axis, val); | 248 | 86.2k | if (t == IntersectionType::kNone) continue; | 249 | 86.2k | if (t == IntersectionType::kInside) { | 250 | 86.2k | (*tree)[pos].multiplier = mmi.multiplier; | 251 | 86.2k | break; | 252 | 86.2k | } | 253 | 0 | if (t == IntersectionType::kPartial) { | 254 | 0 | JXL_DASSERT(axis < kNumStaticProperties); | 255 | 0 | forced_split.val = tree_samples.QuantizeStaticProperty(axis, val); | 256 | 0 | forced_split.prop = axis; | 257 | 0 | forced_split.lcost = forced_split.rcost = base_bits / 2 - threshold; | 258 | 0 | forced_split.lpred = forced_split.rpred = (*tree)[pos].predictor; | 259 | 0 | best = &forced_split; | 260 | 0 | best->pos = begin; | 261 | 0 | JXL_DASSERT(best->prop == tree_samples.PropertyFromIndex(best->prop)); | 262 | 0 | if (best->prop < tree_samples.NumStaticProps()) { | 263 | 0 | for (size_t x = begin; x < end; x++) { | 264 | 0 | if (tree_samples.Property<true>(best->prop, x) <= best->val) { | 265 | 0 | best->pos++; | 266 | 0 | } | 267 | 0 | } | 268 | 0 | } else { | 269 | 0 | size_t prop = best->prop - tree_samples.NumStaticProps(); | 270 | 0 | for (size_t x = begin; x < end; x++) { | 271 | 0 | if (tree_samples.Property<false>(prop, x) <= best->val) { | 272 | 0 | best->pos++; | 273 | 0 | } | 274 | 0 | } | 275 | 0 | } | 276 | 0 | break; | 277 | 0 | } | 278 | 0 | } | 279 | | | 280 | 326k | if (best != &forced_split) { | 281 | 326k | std::vector<int> prop_value_used_count; | 282 | 326k | std::vector<int> count_increase; | 283 | 326k | std::vector<size_t> extra_bits_increase; | 284 | | // For each property, compute which of its values are used, and what | 285 | | // tokens correspond to those usages. 
Then, iterate through the values, | 286 | | // and compute the entropy of each side of the split (of the form `prop > | 287 | | // threshold`). Finally, find the split that minimizes the cost. | 288 | 326k | struct CostInfo { | 289 | 326k | float cost = std::numeric_limits<float>::max(); | 290 | 326k | float extra_cost = 0; | 291 | 326k | float Cost() const { return cost + extra_cost; } | 292 | 326k | Predictor pred; // will be uninitialized in some cases, but never used. | 293 | 326k | }; | 294 | 326k | std::vector<CostInfo> costs_l; | 295 | 326k | std::vector<CostInfo> costs_r; | 296 | | | 297 | 326k | std::vector<int32_t> counts_above(max_symbols); | 298 | 326k | std::vector<int32_t> counts_below(max_symbols); | 299 | | | 300 | | // The lower the threshold, the higher the expected noisiness of the | 301 | | // estimate. Thus, discourage changing predictors. | 302 | 326k | float change_pred_penalty = 800.0f / (100.0f + threshold); | 303 | 2.42M | for (size_t prop = 0; prop < num_properties && base_bits > threshold; | 304 | 2.09M | prop++) { | 305 | 2.09M | costs_l.clear(); | 306 | 2.09M | costs_r.clear(); | 307 | 2.09M | size_t prop_size = tree_samples.NumPropertyValues(prop); | 308 | 2.09M | if (extra_bits_increase.size() < prop_size) { | 309 | 763k | count_increase.resize(prop_size * max_symbols); | 310 | 763k | extra_bits_increase.resize(prop_size); | 311 | 763k | } | 312 | | // Clear prop_value_used_count (which cannot be cleared "on the go") | 313 | 2.09M | prop_value_used_count.clear(); | 314 | 2.09M | prop_value_used_count.resize(prop_size); | 315 | | | 316 | 2.09M | size_t first_used = prop_size; | 317 | 2.09M | size_t last_used = 0; | 318 | | | 319 | | // TODO(veluca): consider finding multiple splits along a single | 320 | | // property at the same time, possibly with a bottom-up approach. 
| 321 | 2.09M | if (prop < tree_samples.NumStaticProps()) { | 322 | 68.8M | for (size_t i = begin; i < end; i++) { | 323 | 68.4M | size_t p = tree_samples.Property<true>(prop, i); | 324 | 68.4M | prop_value_used_count[p]++; | 325 | 68.4M | last_used = std::max(last_used, p); | 326 | 68.4M | first_used = std::min(first_used, p); | 327 | 68.4M | } | 328 | 1.71M | } else { | 329 | 1.71M | size_t prop_idx = prop - tree_samples.NumStaticProps(); | 330 | 269M | for (size_t i = begin; i < end; i++) { | 331 | 267M | size_t p = tree_samples.Property<false>(prop_idx, i); | 332 | 267M | prop_value_used_count[p]++; | 333 | 267M | last_used = std::max(last_used, p); | 334 | 267M | first_used = std::min(first_used, p); | 335 | 267M | } | 336 | 1.71M | } | 337 | 2.09M | costs_l.resize(last_used - first_used); | 338 | 2.09M | costs_r.resize(last_used - first_used); | 339 | | // For all predictors, compute the right and left costs of each split. | 340 | 4.19M | for (size_t pred = 0; pred < num_predictors; pred++) { | 341 | | // Compute cost and histogram increments for each property value. 
| 342 | 2.09M | const std::vector<ResidualToken> &rtokens = | 343 | 2.09M | tree_samples.RTokens(pred); | 344 | 2.09M | if (prop < tree_samples.NumStaticProps()) { | 345 | 384k | CollectExtraBitsIncrease<true>(tree_samples, rtokens, | 346 | 384k | count_increase, extra_bits_increase, | 347 | 384k | begin, end, prop, max_symbols); | 348 | 1.71M | } else { | 349 | 1.71M | CollectExtraBitsIncrease<false>( | 350 | 1.71M | tree_samples, rtokens, count_increase, extra_bits_increase, | 351 | 1.71M | begin, end, prop - tree_samples.NumStaticProps(), max_symbols); | 352 | 1.71M | } | 353 | 2.09M | memcpy(counts_above.data(), counts.data() + pred * max_symbols, | 354 | 2.09M | max_symbols * sizeof counts_above[0]); | 355 | 2.09M | memset(counts_below.data(), 0, max_symbols * sizeof counts_below[0]); | 356 | 2.09M | size_t extra_bits_below = 0; | 357 | | // Exclude last used: this ensures neither counts_above nor | 358 | | // counts_below is empty. | 359 | 29.9M | for (size_t i = first_used; i < last_used; i++) { | 360 | 27.8M | if (!prop_value_used_count[i]) continue; | 361 | 11.8M | extra_bits_below += extra_bits_increase[i]; | 362 | | // The increase for this property value has been used, and will not | 363 | | // be used again: clear it. Also below. 
| 364 | 11.8M | extra_bits_increase[i] = 0; | 365 | 588M | for (size_t sym = 0; sym < max_symbols; sym++) { | 366 | 576M | counts_above[sym] -= count_increase[i * max_symbols + sym]; | 367 | 576M | counts_below[sym] += count_increase[i * max_symbols + sym]; | 368 | 576M | count_increase[i * max_symbols + sym] = 0; | 369 | 576M | } | 370 | 11.8M | float rcost = EstimateBits(counts_above.data(), max_symbols) + | 371 | 11.8M | tot_extra_bits[pred] - extra_bits_below; | 372 | 11.8M | float lcost = EstimateBits(counts_below.data(), max_symbols) + | 373 | 11.8M | extra_bits_below; | 374 | 11.8M | JXL_DASSERT(extra_bits_below <= tot_extra_bits[pred]); | 375 | 11.8M | float penalty = 0; | 376 | | // Never discourage moving away from the Weighted predictor. | 377 | 11.8M | if (tree_samples.PredictorFromIndex(pred) != | 378 | 11.8M | (*tree)[pos].predictor && | 379 | 0 | (*tree)[pos].predictor != Predictor::Weighted) { | 380 | 0 | penalty = change_pred_penalty; | 381 | 0 | } | 382 | | // If everything else is equal, disfavour Weighted (slower) and | 383 | | // favour Zero (faster if it's the only predictor used in a | 384 | | // group+channel combination) | 385 | 11.8M | if (tree_samples.PredictorFromIndex(pred) == Predictor::Weighted) { | 386 | 0 | penalty += 1e-8; | 387 | 0 | } | 388 | 11.8M | if (tree_samples.PredictorFromIndex(pred) == Predictor::Zero) { | 389 | 0 | penalty -= 1e-8; | 390 | 0 | } | 391 | 11.8M | if (rcost + penalty < costs_r[i - first_used].Cost()) { | 392 | 11.8M | costs_r[i - first_used].cost = rcost; | 393 | 11.8M | costs_r[i - first_used].extra_cost = penalty; | 394 | 11.8M | costs_r[i - first_used].pred = | 395 | 11.8M | tree_samples.PredictorFromIndex(pred); | 396 | 11.8M | } | 397 | 11.8M | if (lcost + penalty < costs_l[i - first_used].Cost()) { | 398 | 11.8M | costs_l[i - first_used].cost = lcost; | 399 | 11.8M | costs_l[i - first_used].extra_cost = penalty; | 400 | 11.8M | costs_l[i - first_used].pred = | 401 | 11.8M | 
tree_samples.PredictorFromIndex(pred); | 402 | 11.8M | } | 403 | 11.8M | } | 404 | 2.09M | } | 405 | | // Iterate through the possible splits and find the one with minimum sum | 406 | | // of costs of the two sides. | 407 | 2.09M | size_t split = begin; | 408 | 29.9M | for (size_t i = first_used; i < last_used; i++) { | 409 | 27.8M | if (!prop_value_used_count[i]) continue; | 410 | 11.8M | split += prop_value_used_count[i]; | 411 | 11.8M | float rcost = costs_r[i - first_used].cost; | 412 | 11.8M | float lcost = costs_l[i - first_used].cost; | 413 | | | 414 | 11.8M | bool uses_wp = tree_samples.PropertyFromIndex(prop) == kWPProp || | 415 | 9.70M | costs_l[i - first_used].pred == Predictor::Weighted || | 416 | 9.70M | costs_r[i - first_used].pred == Predictor::Weighted; | 417 | 11.8M | bool zero_entropy_side = rcost == 0 || lcost == 0; | 418 | | | 419 | 11.8M | SplitInfo &best_ref = | 420 | 11.8M | tree_samples.PropertyFromIndex(prop) < kNumStaticProperties | 421 | 11.8M | ? (zero_entropy_side ? best_split_static_constant | 422 | 22.2k | : best_split_static) | 423 | 11.8M | : (uses_wp ? best_split_nonstatic : best_split_nowp); | 424 | 11.8M | if (lcost + rcost < best_ref.Cost()) { | 425 | 2.16M | best_ref.prop = prop; | 426 | 2.16M | best_ref.val = i; | 427 | 2.16M | best_ref.pos = split; | 428 | 2.16M | best_ref.lcost = lcost; | 429 | 2.16M | best_ref.lpred = costs_l[i - first_used].pred; | 430 | 2.16M | best_ref.rcost = rcost; | 431 | 2.16M | best_ref.rpred = costs_r[i - first_used].pred; | 432 | 2.16M | } | 433 | 11.8M | } | 434 | | // Clear extra_bits_increase and cost_increase for last_used. | 435 | 2.09M | extra_bits_increase[last_used] = 0; | 436 | 104M | for (size_t sym = 0; sym < max_symbols; sym++) { | 437 | 102M | count_increase[last_used * max_symbols + sym] = 0; | 438 | 102M | } | 439 | 2.09M | } | 440 | | | 441 | | // Try to avoid introducing WP. 
| 442 | 326k | if (best_split_nowp.Cost() + threshold < base_bits && | 443 | 157k | best_split_nowp.Cost() <= fast_decode_multiplier * best->Cost()) { | 444 | 151k | best = &best_split_nowp; | 445 | 151k | } | 446 | | // Split along static props if possible and not significantly more | 447 | | // expensive. | 448 | 326k | if (best_split_static.Cost() + threshold < base_bits && | 449 | 7.95k | best_split_static.Cost() <= fast_decode_multiplier * best->Cost()) { | 450 | 3.26k | best = &best_split_static; | 451 | 3.26k | } | 452 | | // Split along static props to create constant nodes if possible. | 453 | 326k | if (best_split_static_constant.Cost() + threshold < base_bits) { | 454 | 386 | best = &best_split_static_constant; | 455 | 386 | } | 456 | 326k | } | 457 | | | 458 | 326k | if (best->Cost() + threshold < base_bits) { | 459 | 162k | uint32_t p = tree_samples.PropertyFromIndex(best->prop); | 460 | 162k | pixel_type dequant = | 461 | 162k | tree_samples.UnquantizeProperty(best->prop, best->val); | 462 | | // Split node and try to split children. 
| 463 | 162k | MakeSplitNode(pos, p, dequant, best->lpred, 0, best->rpred, 0, tree); | 464 | | // "Sort" according to winning property | 465 | 162k | if (best->prop < tree_samples.NumStaticProps()) { | 466 | 3.52k | SplitTreeSamples<true>(tree_samples, begin, best->pos, end, best->prop, | 467 | 3.52k | best->val); | 468 | 158k | } else { | 469 | 158k | SplitTreeSamples<false>(tree_samples, begin, best->pos, end, | 470 | 158k | best->prop - tree_samples.NumStaticProps(), | 471 | 158k | best->val); | 472 | 158k | } | 473 | 162k | auto new_sp_range = static_prop_range; | 474 | 162k | if (p < kNumStaticProperties) { | 475 | 3.52k | JXL_DASSERT(static_cast<uint32_t>(dequant + 1) <= new_sp_range[p][1]); | 476 | 3.52k | new_sp_range[p][1] = dequant + 1; | 477 | 3.52k | JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]); | 478 | 3.52k | } | 479 | 162k | nodes.push_back( | 480 | 162k | NodeInfo{(*tree)[pos].rchild, begin, best->pos, new_sp_range}); | 481 | 162k | new_sp_range = static_prop_range; | 482 | 162k | if (p < kNumStaticProperties) { | 483 | 3.52k | JXL_DASSERT(new_sp_range[p][0] <= static_cast<uint32_t>(dequant + 1)); | 484 | 3.52k | new_sp_range[p][0] = dequant + 1; | 485 | 3.52k | JXL_DASSERT(new_sp_range[p][0] < new_sp_range[p][1]); | 486 | 3.52k | } | 487 | 162k | nodes.push_back( | 488 | 162k | NodeInfo{(*tree)[pos].lchild, best->pos, end, new_sp_range}); | 489 | 162k | } | 490 | 326k | } | 491 | 1.91k | } |
Unexecuted instantiation: jxl::N_SSE2::FindBestSplit(jxl::TreeSamples&, float, std::__1::vector<jxl::ModularMultiplierInfo, std::__1::allocator<jxl::ModularMultiplierInfo> > const&, std::__1::array<std::__1::array<unsigned int, 2ul>, 2ul>, float, std::__1::vector<jxl::PropertyDecisionNode, std::__1::allocator<jxl::PropertyDecisionNode> >*) |
492 | | |
493 | | // NOLINTNEXTLINE(google-readability-namespace-comments) |
494 | | } // namespace HWY_NAMESPACE |
495 | | } // namespace jxl |
496 | | HWY_AFTER_NAMESPACE(); |
497 | | |
498 | | #if HWY_ONCE |
499 | | namespace jxl { |
500 | | |
501 | | HWY_EXPORT(FindBestSplit); // Local function. |
502 | | |
503 | | Status ComputeBestTree(TreeSamples &tree_samples, float threshold, |
504 | | const std::vector<ModularMultiplierInfo> &mul_info, |
505 | | StaticPropRange static_prop_range, |
506 | 1.91k | float fast_decode_multiplier, Tree *tree) { |
507 | | // TODO(veluca): take into account that different contexts can have different |
508 | | // uint configs. |
509 | | // |
510 | | // Initialize tree. |
511 | 1.91k | tree->emplace_back(); |
512 | 1.91k | tree->back().property = -1; |
513 | 1.91k | tree->back().predictor = tree_samples.PredictorFromIndex(0); |
514 | 1.91k | tree->back().predictor_offset = 0; |
515 | 1.91k | tree->back().multiplier = 1; |
516 | 1.91k | JXL_ENSURE(tree_samples.NumProperties() < 64); |
517 | | |
518 | 1.91k | JXL_ENSURE(tree_samples.NumDistinctSamples() <= |
519 | 1.91k | std::numeric_limits<uint32_t>::max()); |
520 | 1.91k | HWY_DYNAMIC_DISPATCH(FindBestSplit) |
521 | 1.91k | (tree_samples, threshold, mul_info, static_prop_range, fast_decode_multiplier, |
522 | 1.91k | tree); |
523 | 1.91k | return true; |
524 | 1.91k | } |
525 | | |
526 | | #if JXL_CXX_LANG < JXL_CXX_17 |
527 | | constexpr int32_t TreeSamples::kPropertyRange; |
528 | | constexpr uint32_t TreeSamples::kDedupEntryUnused; |
529 | | #endif |
530 | | |
531 | | Status TreeSamples::SetPredictor(Predictor predictor, |
532 | 2.07k | ModularOptions::TreeMode wp_tree_mode) { |
533 | 2.07k | if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) { |
534 | 0 | predictors = {Predictor::Weighted}; |
535 | 0 | residuals.resize(1); |
536 | 0 | return true; |
537 | 0 | } |
538 | 2.07k | if (wp_tree_mode == ModularOptions::TreeMode::kNoWP && |
539 | 0 | predictor == Predictor::Weighted) { |
540 | 0 | return JXL_FAILURE("Invalid predictor settings"); |
541 | 0 | } |
542 | 2.07k | if (predictor == Predictor::Variable) { |
543 | 0 | for (size_t i = 0; i < kNumModularPredictors; i++) { |
544 | 0 | predictors.push_back(static_cast<Predictor>(i)); |
545 | 0 | } |
546 | 0 | std::swap(predictors[0], predictors[static_cast<int>(Predictor::Weighted)]); |
547 | 0 | std::swap(predictors[1], predictors[static_cast<int>(Predictor::Gradient)]); |
548 | 2.07k | } else if (predictor == Predictor::Best) { |
549 | 0 | predictors = {Predictor::Weighted, Predictor::Gradient}; |
550 | 2.07k | } else { |
551 | 2.07k | predictors = {predictor}; |
552 | 2.07k | } |
553 | 2.07k | if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) { |
554 | 0 | predictors.erase( |
555 | 0 | std::remove(predictors.begin(), predictors.end(), Predictor::Weighted), |
556 | 0 | predictors.end()); |
557 | 0 | } |
558 | 2.07k | residuals.resize(predictors.size()); |
559 | 2.07k | return true; |
560 | 2.07k | } |
561 | | |
562 | | Status TreeSamples::SetProperties(const std::vector<uint32_t> &properties, |
563 | 2.07k | ModularOptions::TreeMode wp_tree_mode) { |
564 | 2.07k | props_to_use = properties; |
565 | 2.07k | if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) { |
566 | 0 | props_to_use = {static_cast<uint32_t>(kWPProp)}; |
567 | 0 | } |
568 | 2.07k | if (wp_tree_mode == ModularOptions::TreeMode::kGradientOnly) { |
569 | 0 | props_to_use = {static_cast<uint32_t>(kGradientProp)}; |
570 | 0 | } |
571 | 2.07k | if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) { |
572 | 0 | props_to_use.erase( |
573 | 0 | std::remove(props_to_use.begin(), props_to_use.end(), kWPProp), |
574 | 0 | props_to_use.end()); |
575 | 0 | } |
576 | 2.07k | if (props_to_use.empty()) { |
577 | 0 | return JXL_FAILURE("Invalid property set configuration"); |
578 | 0 | } |
579 | 2.07k | num_static_props = 0; |
580 | | // Check that if static properties present, then those are at the beginning. |
581 | 16.6k | for (size_t i = 0; i < props_to_use.size(); ++i) { |
582 | 14.5k | uint32_t prop = props_to_use[i]; |
583 | 14.5k | if (prop < kNumStaticProperties) { |
584 | 2.86k | JXL_DASSERT(i == prop); |
585 | 2.86k | num_static_props++; |
586 | 2.86k | } |
587 | 14.5k | } |
588 | 2.07k | props.resize(props_to_use.size() - num_static_props); |
589 | 2.07k | return true; |
590 | 2.07k | } |
591 | | |
592 | 5.11k | void TreeSamples::InitTable(size_t log_size) { |
593 | 5.11k | size_t size = 1ULL << log_size; |
594 | 5.11k | if (dedup_table_.size() == size) return; |
595 | 3.49k | dedup_table_.resize(size, kDedupEntryUnused); |
596 | 4.31M | for (size_t i = 0; i < NumDistinctSamples(); i++) { |
597 | 4.30M | if (sample_counts[i] != std::numeric_limits<uint16_t>::max()) { |
598 | 4.30M | AddToTable(i); |
599 | 4.30M | } |
600 | 4.30M | } |
601 | 3.49k | } |
602 | | |
603 | 38.5M | bool TreeSamples::AddToTableAndMerge(size_t a) { |
604 | 38.5M | size_t pos1 = Hash1(a); |
605 | 38.5M | size_t pos2 = Hash2(a); |
606 | 38.5M | if (dedup_table_[pos1] != kDedupEntryUnused && |
607 | 35.3M | IsSameSample(a, dedup_table_[pos1])) { |
608 | 29.2M | JXL_DASSERT(sample_counts[a] == 1); |
609 | 29.2M | sample_counts[dedup_table_[pos1]]++; |
610 | | // Remove from hash table samples that are saturated. |
611 | 29.2M | if (sample_counts[dedup_table_[pos1]] == |
612 | 29.2M | std::numeric_limits<uint16_t>::max()) { |
613 | 21 | dedup_table_[pos1] = kDedupEntryUnused; |
614 | 21 | } |
615 | 29.2M | return true; |
616 | 29.2M | } |
617 | 9.32M | if (dedup_table_[pos2] != kDedupEntryUnused && |
618 | 5.58M | IsSameSample(a, dedup_table_[pos2])) { |
619 | 4.17M | JXL_DASSERT(sample_counts[a] == 1); |
620 | 4.17M | sample_counts[dedup_table_[pos2]]++; |
621 | | // Remove from hash table samples that are saturated. |
622 | 4.17M | if (sample_counts[dedup_table_[pos2]] == |
623 | 4.17M | std::numeric_limits<uint16_t>::max()) { |
624 | 0 | dedup_table_[pos2] = kDedupEntryUnused; |
625 | 0 | } |
626 | 4.17M | return true; |
627 | 4.17M | } |
628 | 5.15M | AddToTable(a); |
629 | 5.15M | return false; |
630 | 9.32M | } |
631 | | |
632 | 9.45M | void TreeSamples::AddToTable(size_t a) { |
633 | 9.45M | size_t pos1 = Hash1(a); |
634 | 9.45M | size_t pos2 = Hash2(a); |
635 | 9.45M | if (dedup_table_[pos1] == kDedupEntryUnused) { |
636 | 4.58M | dedup_table_[pos1] = a; |
637 | 4.87M | } else if (dedup_table_[pos2] == kDedupEntryUnused) { |
638 | 2.56M | dedup_table_[pos2] = a; |
639 | 2.56M | } |
640 | 9.45M | } |
641 | | |
642 | 5.11k | void TreeSamples::PrepareForSamples(size_t extra_num_samples) { |
643 | 5.11k | for (auto &res : residuals) { |
644 | 5.11k | res.reserve(res.size() + extra_num_samples); |
645 | 5.11k | } |
646 | 12.9k | for (size_t i = 0; i < num_static_props; ++i) { |
647 | 7.87k | static_props[i].reserve(static_props[i].size() + extra_num_samples); |
648 | 7.87k | } |
649 | 27.9k | for (auto &p : props) { |
650 | 27.9k | p.reserve(p.size() + extra_num_samples); |
651 | 27.9k | } |
652 | 5.11k | size_t total_num_samples = extra_num_samples + sample_counts.size(); |
653 | 5.11k | size_t next_size = CeilLog2Nonzero(total_num_samples * 3 / 2); |
654 | 5.11k | InitTable(next_size); |
655 | 5.11k | } |
656 | | |
657 | 47.9M | size_t TreeSamples::Hash1(size_t a) const { |
658 | 47.9M | constexpr uint64_t constant = 0x1e35a7bd; |
659 | 47.9M | uint64_t h = constant; |
660 | 47.9M | for (const auto &r : residuals) { |
661 | 47.9M | h = h * constant + r[a].tok; |
662 | 47.9M | h = h * constant + r[a].nbits; |
663 | 47.9M | } |
664 | 108M | for (size_t i = 0; i < num_static_props; ++i) { |
665 | 60.8M | h = h * constant + static_props[i][a]; |
666 | 60.8M | } |
667 | 275M | for (const auto &p : props) { |
668 | 275M | h = h * constant + p[a]; |
669 | 275M | } |
670 | 47.9M | return (h >> 16) & (dedup_table_.size() - 1); |
671 | 47.9M | } |
672 | 47.9M | size_t TreeSamples::Hash2(size_t a) const { |
673 | 47.9M | constexpr uint64_t constant = 0x1e35a7bd1e35a7bd; |
674 | 47.9M | uint64_t h = constant; |
675 | 108M | for (size_t i = 0; i < num_static_props; ++i) { |
676 | 60.8M | h = h * constant ^ static_props[i][a]; |
677 | 60.8M | } |
678 | 275M | for (const auto &p : props) { |
679 | 275M | h = h * constant ^ p[a]; |
680 | 275M | } |
681 | 47.9M | for (const auto &r : residuals) { |
682 | 47.9M | h = h * constant ^ r[a].tok; |
683 | 47.9M | h = h * constant ^ r[a].nbits; |
684 | 47.9M | } |
685 | 47.9M | return (h >> 16) & (dedup_table_.size() - 1); |
686 | 47.9M | } |
687 | | |
688 | 40.8M | bool TreeSamples::IsSameSample(size_t a, size_t b) const { |
689 | 40.8M | bool ret = true; |
690 | 40.8M | for (const auto &r : residuals) { |
691 | 40.8M | if (r[a].tok != r[b].tok) { |
692 | 3.25M | ret = false; |
693 | 3.25M | } |
694 | 40.8M | if (r[a].nbits != r[b].nbits) { |
695 | 2.34M | ret = false; |
696 | 2.34M | } |
697 | 40.8M | } |
698 | 91.0M | for (size_t i = 0; i < num_static_props; ++i) { |
699 | 50.1M | if (static_props[i][a] != static_props[i][b]) { |
700 | 1.70M | ret = false; |
701 | 1.70M | } |
702 | 50.1M | } |
703 | 236M | for (const auto &p : props) { |
704 | 236M | if (p[a] != p[b]) { |
705 | 20.0M | ret = false; |
706 | 20.0M | } |
707 | 236M | } |
708 | 40.8M | return ret; |
709 | 40.8M | } |
710 | | |
711 | | void TreeSamples::AddSample(pixel_type_w pixel, const Properties &properties, |
712 | 38.5M | const pixel_type_w *predictions) { |
713 | 77.0M | for (size_t i = 0; i < predictors.size(); i++) { |
714 | 38.5M | pixel_type v = pixel - predictions[static_cast<int>(predictors[i])]; |
715 | 38.5M | uint32_t tok, nbits, bits; |
716 | 38.5M | HybridUintConfig(4, 1, 2).Encode(PackSigned(v), &tok, &nbits, &bits); |
717 | 38.5M | JXL_DASSERT(tok < 256); |
718 | 38.5M | JXL_DASSERT(nbits < 256); |
719 | 38.5M | ResidualToken token = {static_cast<uint8_t>(tok), |
720 | 38.5M | static_cast<uint8_t>(nbits)}; |
721 | 38.5M | residuals[i].push_back(token); |
722 | 38.5M | } |
723 | 85.8M | for (size_t i = 0; i < num_static_props; ++i) { |
724 | 47.3M | static_props[i].push_back(QuantizeStaticProperty(i, properties[i])); |
725 | 47.3M | } |
726 | 260M | for (size_t i = num_static_props; i < props_to_use.size(); i++) { |
727 | 222M | props[i - num_static_props].push_back(QuantizeProperty(i, properties[props_to_use[i]])); |
728 | 222M | } |
729 | 38.5M | sample_counts.push_back(1); |
730 | 38.5M | num_samples++; |
731 | 38.5M | if (AddToTableAndMerge(sample_counts.size() - 1)) { |
732 | 33.3M | for (auto &r : residuals) r.pop_back(); |
733 | 73.0M | for (size_t i = 0; i < num_static_props; ++i) static_props[i].pop_back(); |
734 | 193M | for (auto &p : props) p.pop_back(); |
735 | 33.3M | sample_counts.pop_back(); |
736 | 33.3M | } |
737 | 38.5M | } |
738 | | |
739 | 8.14M | void TreeSamples::Swap(size_t a, size_t b) { |
740 | 8.14M | if (a == b) return; |
741 | 8.14M | for (auto &r : residuals) { |
742 | 8.14M | std::swap(r[a], r[b]); |
743 | 8.14M | } |
744 | 19.5M | for (size_t i = 0; i < num_static_props; ++i) { |
745 | 11.4M | std::swap(static_props[i][a], static_props[i][b]); |
746 | 11.4M | } |
747 | 45.5M | for (auto &p : props) { |
748 | 45.5M | std::swap(p[a], p[b]); |
749 | 45.5M | } |
750 | 8.14M | std::swap(sample_counts[a], sample_counts[b]); |
751 | 8.14M | } |
752 | | |
753 | | namespace { |
// Quantizes a histogram into (at most) `num_chunks` chunks of roughly equal
// total weight. Returns the bin indices at which each chunk boundary falls;
// the final boundary (which always covers the whole histogram) is dropped.
// Returns an empty vector for an empty or all-zero histogram, or when
// num_chunks is 0.
std::vector<int32_t> QuantizeHistogram(const std::vector<uint32_t> &histogram,
                                       size_t num_chunks) {
  if (histogram.empty() || num_chunks == 0) return {};
  // The init value fixes std::accumulate's accumulator type. The previous
  // `0LU` made it `unsigned long`, which is 32 bits on LLP64 targets
  // (e.g. Windows) and could overflow for large histograms; use uint64_t.
  const uint64_t sum =
      std::accumulate(histogram.begin(), histogram.end(), uint64_t{0});
  if (sum == 0) return {};
  // TODO(veluca): selecting distinct quantiles is likely not the best
  // way to go about this.
  std::vector<int32_t> thresholds;
  uint64_t cumsum = 0;
  uint64_t threshold = 1;
  for (size_t i = 0; i < histogram.size(); i++) {
    cumsum += histogram[i];
    // Record bin i once the running mass reaches the next quantile, then
    // advance past every quantile this bin satisfies so duplicates are
    // never emitted (thresholds.size() <= num_chunks holds as a result).
    if (cumsum * num_chunks >= threshold * sum) {
      thresholds.push_back(i);
      while (cumsum * num_chunks >= threshold * sum) threshold++;
    }
  }
  // The last value collects the entire histogram and is not a real threshold.
  thresholds.pop_back();
  return thresholds;
}
776 | | |
777 | | std::vector<int32_t> QuantizeSamples(const std::vector<int32_t> &samples, |
778 | 4.12k | size_t num_chunks) { |
779 | 4.12k | if (samples.empty()) return {}; |
780 | 2.95k | int min = *std::min_element(samples.begin(), samples.end()); |
781 | 2.95k | constexpr int kRange = 512; |
782 | 2.95k | min = jxl::Clamp1(min, -kRange, kRange); |
783 | 2.95k | std::vector<uint32_t> counts(2 * kRange + 1); |
784 | 5.42M | for (int s : samples) { |
785 | 5.42M | uint32_t sample_offset = jxl::Clamp1(s, -kRange, kRange) - min; |
786 | 5.42M | counts[sample_offset]++; |
787 | 5.42M | } |
788 | 2.95k | std::vector<int32_t> thresholds = QuantizeHistogram(counts, num_chunks); |
789 | 31.3k | for (auto &v : thresholds) v += min; |
790 | 2.95k | return thresholds; |
791 | 4.12k | } |
792 | | |
793 | | // `to[i]` is assigned value `v` conforming `from[v] <= i && from[v-1] > i`. |
794 | | // This is because the decision node in the tree splits on (property) > i, |
795 | | // hence everything that is not > of a threshold should be clustered |
796 | | // together. |
797 | | template <typename T> |
798 | | void QuantMap(const std::vector<int32_t> &from, std::vector<T> &to, |
799 | 14.5k | size_t num_pegs, int bias) { |
800 | 14.5k | to.resize(num_pegs); |
801 | 14.5k | size_t mapped = 0; |
802 | 14.8M | for (size_t i = 0; i < num_pegs; i++) { |
803 | 15.1M | while (mapped < from.size() && static_cast<int>(i) - bias > from[mapped]) { |
804 | 256k | mapped++; |
805 | 256k | } |
806 | 14.8M | JXL_DASSERT(static_cast<T>(mapped) == mapped); |
807 | 14.8M | to[i] = mapped; |
808 | 14.8M | } |
809 | 14.5k | } enc_ma.cc:void jxl::(anonymous namespace)::QuantMap<unsigned short>(std::__1::vector<int, std::__1::allocator<int> > const&, std::__1::vector<unsigned short, std::__1::allocator<unsigned short> >&, unsigned long, int) Line | Count | Source | 799 | 2.86k | size_t num_pegs, int bias) { | 800 | 2.86k | to.resize(num_pegs); | 801 | 2.86k | size_t mapped = 0; | 802 | 2.92M | for (size_t i = 0; i < num_pegs; i++) { | 803 | 2.92M | while (mapped < from.size() && static_cast<int>(i) - bias > from[mapped]) { | 804 | 2.04k | mapped++; | 805 | 2.04k | } | 806 | 2.92M | JXL_DASSERT(static_cast<T>(mapped) == mapped); | 807 | 2.92M | to[i] = mapped; | 808 | 2.92M | } | 809 | 2.86k | } |
enc_ma.cc:void jxl::(anonymous namespace)::QuantMap<unsigned char>(std::__1::vector<int, std::__1::allocator<int> > const&, std::__1::vector<unsigned char, std::__1::allocator<unsigned char> >&, unsigned long, int) Line | Count | Source | 799 | 11.6k | size_t num_pegs, int bias) { | 800 | 11.6k | to.resize(num_pegs); | 801 | 11.6k | size_t mapped = 0; | 802 | 11.9M | for (size_t i = 0; i < num_pegs; i++) { | 803 | 12.2M | while (mapped < from.size() && static_cast<int>(i) - bias > from[mapped]) { | 804 | 254k | mapped++; | 805 | 254k | } | 806 | 11.9M | JXL_DASSERT(static_cast<T>(mapped) == mapped); | 807 | 11.9M | to[i] = mapped; | 808 | 11.9M | } | 809 | 11.6k | } |
|
810 | | } // namespace |
811 | | |
// Precomputes, for every property in props_to_use, a small set of candidate
// split thresholds (`compact_properties`) and a dense lookup table
// (`property_mapping` / `static_property_mapping`, built via QuantMap) that
// maps a raw property value to its quantized cluster. `pixel_samples` and
// `diff_samples` are consumed destructively: the abs-variant lambdas below
// overwrite them with absolute values.
void TreeSamples::PreQuantizeProperties(
    const StaticPropRange &range,
    const std::vector<ModularMultiplierInfo> &multiplier_info,
    const std::vector<uint32_t> &group_pixel_count,
    const std::vector<uint32_t> &channel_pixel_count,
    std::vector<pixel_type> &pixel_samples,
    std::vector<pixel_type> &diff_samples, size_t max_property_values) {
  // If we have forced splits because of multipliers, choose channel and group
  // thresholds accordingly.
  std::vector<int32_t> group_multiplier_thresholds;
  std::vector<int32_t> channel_multiplier_thresholds;
  for (const auto &v : multiplier_info) {
    // Splits are of the form (property > t), so a range boundary b forces a
    // threshold at b - 1. Boundaries equal to the full range need no split.
    if (v.range[0][0] != range[0][0]) {
      channel_multiplier_thresholds.push_back(v.range[0][0] - 1);
    }
    if (v.range[0][1] != range[0][1]) {
      channel_multiplier_thresholds.push_back(v.range[0][1] - 1);
    }
    if (v.range[1][0] != range[1][0]) {
      group_multiplier_thresholds.push_back(v.range[1][0] - 1);
    }
    if (v.range[1][1] != range[1][1]) {
      group_multiplier_thresholds.push_back(v.range[1][1] - 1);
    }
  }
  // Sort and deduplicate both forced-threshold lists.
  std::sort(channel_multiplier_thresholds.begin(),
            channel_multiplier_thresholds.end());
  channel_multiplier_thresholds.resize(
      std::unique(channel_multiplier_thresholds.begin(),
                  channel_multiplier_thresholds.end()) -
      channel_multiplier_thresholds.begin());
  std::sort(group_multiplier_thresholds.begin(),
            group_multiplier_thresholds.end());
  group_multiplier_thresholds.resize(
      std::unique(group_multiplier_thresholds.begin(),
                  group_multiplier_thresholds.end()) -
      group_multiplier_thresholds.begin());

  compact_properties.resize(props_to_use.size());
  // Channel property: forced multiplier thresholds win; otherwise split the
  // per-channel pixel-count histogram into equal-weight chunks.
  auto quantize_channel = [&]() {
    if (!channel_multiplier_thresholds.empty()) {
      return channel_multiplier_thresholds;
    }
    return QuantizeHistogram(channel_pixel_count, max_property_values);
  };
  // Group-id property: same scheme as quantize_channel.
  auto quantize_group_id = [&]() {
    if (!group_multiplier_thresholds.empty()) {
      return group_multiplier_thresholds;
    }
    return QuantizeHistogram(group_pixel_count, max_property_values);
  };
  // Coordinate properties: uniform thresholds over [0, 256).
  auto quantize_coordinate = [&]() {
    std::vector<int32_t> quantized;
    quantized.reserve(max_property_values - 1);
    for (size_t i = 0; i + 1 < max_property_values; i++) {
      quantized.push_back((i + 1) * 256 / max_property_values - 1);
    }
    return quantized;
  };
  std::vector<int32_t> abs_pixel_thresholds;
  std::vector<int32_t> pixel_thresholds;
  // Lazily computed, cached across properties sharing the same sample set.
  auto quantize_pixel_property = [&]() {
    if (pixel_thresholds.empty()) {
      pixel_thresholds = QuantizeSamples(pixel_samples, max_property_values);
    }
    return pixel_thresholds;
  };
  // NOTE: mutates pixel_samples in place (abs), so the non-abs thresholds
  // must be computed first.
  auto quantize_abs_pixel_property = [&]() {
    if (abs_pixel_thresholds.empty()) {
      quantize_pixel_property();  // Compute the non-abs thresholds.
      for (auto &v : pixel_samples) v = std::abs(v);
      abs_pixel_thresholds =
          QuantizeSamples(pixel_samples, max_property_values);
    }
    return abs_pixel_thresholds;
  };
  std::vector<int32_t> abs_diff_thresholds;
  std::vector<int32_t> diff_thresholds;
  auto quantize_diff_property = [&]() {
    if (diff_thresholds.empty()) {
      diff_thresholds = QuantizeSamples(diff_samples, max_property_values);
    }
    return diff_thresholds;
  };
  // Same in-place abs mutation as quantize_abs_pixel_property, on
  // diff_samples.
  auto quantize_abs_diff_property = [&]() {
    if (abs_diff_thresholds.empty()) {
      quantize_diff_property();  // Compute the non-abs thresholds.
      for (auto &v : diff_samples) v = std::abs(v);
      abs_diff_thresholds = QuantizeSamples(diff_samples, max_property_values);
    }
    return abs_diff_thresholds;
  };
  // Weighted-predictor property: fixed, roughly logarithmic threshold grids,
  // with coarser grids when fewer property values are allowed.
  auto quantize_wp = [&]() {
    if (max_property_values < 32) {
      return std::vector<int32_t>{-127, -63, -31, -15, -7, -3, -1, 0,
                                  1,    3,   7,   15,  31, 63, 127};
    }
    if (max_property_values < 64) {
      return std::vector<int32_t>{-255, -191, -127, -95, -63, -47, -31, -23,
                                  -15,  -11,  -7,   -5,  -3,  -1,  0,   1,
                                  3,    5,    7,    11,  15,  23,  31,  47,
                                  63,   95,   127,  191, 255};
    }
    return std::vector<int32_t>{
        -255, -223, -191, -159, -127, -111, -95, -79, -63, -55, -47,
        -39,  -31,  -27,  -23,  -19,  -15,  -13, -11, -9,  -7,  -6,
        -5,   -4,   -3,   -2,   -1,   0,    1,   2,   3,   4,   5,
        6,    7,    9,    11,   13,   15,   19,  23,  27,  31,  39,
        47,   55,   63,   79,   95,   111,  127, 159, 191, 223, 255};
  };

  property_mapping.resize(props_to_use.size() - num_static_props);
  // Dispatch each used property to its quantizer by property id.
  // Ids: 0 = channel, 1 = group id, 2/3 = coordinates, then pixel-valued and
  // difference-valued properties (reference-channel properties start at
  // kNumNonrefProperties, cycling abs-pixel/pixel/abs-diff/diff mod 4 —
  // inferred from the lambda names; confirm against property definitions).
  for (size_t i = 0; i < props_to_use.size(); i++) {
    if (props_to_use[i] == 0) {
      compact_properties[i] = quantize_channel();
    } else if (props_to_use[i] == 1) {
      compact_properties[i] = quantize_group_id();
    } else if (props_to_use[i] == 2 || props_to_use[i] == 3) {
      compact_properties[i] = quantize_coordinate();
    } else if (props_to_use[i] == 6 || props_to_use[i] == 7 ||
               props_to_use[i] == 8 ||
               (props_to_use[i] >= kNumNonrefProperties &&
                (props_to_use[i] - kNumNonrefProperties) % 4 == 1)) {
      compact_properties[i] = quantize_pixel_property();
    } else if (props_to_use[i] == 4 || props_to_use[i] == 5 ||
               (props_to_use[i] >= kNumNonrefProperties &&
                (props_to_use[i] - kNumNonrefProperties) % 4 == 0)) {
      compact_properties[i] = quantize_abs_pixel_property();
    } else if (props_to_use[i] >= kNumNonrefProperties &&
               (props_to_use[i] - kNumNonrefProperties) % 4 == 2) {
      compact_properties[i] = quantize_abs_diff_property();
    } else if (props_to_use[i] == kWPProp) {
      compact_properties[i] = quantize_wp();
    } else {
      compact_properties[i] = quantize_diff_property();
    }
    // Build the dense value->cluster table covering [-kPropertyRange,
    // kPropertyRange]; static properties use their own mapping array.
    if (i < num_static_props) {
      QuantMap(compact_properties[i], static_property_mapping[i],
               kPropertyRange * 2 + 1, kPropertyRange);
    } else {
      QuantMap(compact_properties[i], property_mapping[i - num_static_props],
               kPropertyRange * 2 + 1, kPropertyRange);
    }
  }
}
957 | | |
// Collects a random subset of pixel values and horizontal differences from
// the group's channels, for later property quantization
// (PreQuantizeProperties). Also accumulates per-group and per-channel pixel
// counts. Sampling uses geometric gaps so roughly `fraction` of pixels are
// visited in a single linear pass over all eligible channels.
void CollectPixelSamples(const Image &image, const ModularOptions &options,
                         uint32_t group_id,
                         std::vector<uint32_t> &group_pixel_count,
                         std::vector<uint32_t> &channel_pixel_count,
                         std::vector<pixel_type> &pixel_samples,
                         std::vector<pixel_type> &diff_samples) {
  if (options.nb_repeats == 0) return;
  // Grow the count arrays on demand; entries are accumulated, not reset.
  if (group_pixel_count.size() <= group_id) {
    group_pixel_count.resize(group_id + 1);
  }
  if (channel_pixel_count.size() < image.channel.size()) {
    channel_pixel_count.resize(image.channel.size());
  }
  // Seeded by group id so sampling is deterministic per group.
  Rng rng(group_id);
  // Sample 10% of the final number of samples for property quantization.
  float fraction = std::min(options.nb_repeats * 0.1, 0.99);
  Rng::GeometricDistribution dist = Rng::MakeGeometric(fraction);
  size_t total_pixels = 0;
  std::vector<size_t> channel_ids;
  for (size_t i = 0; i < image.channel.size(); i++) {
    // Stop at the first non-meta channel exceeding max_chan_size (channels
    // after it are skipped as well).
    if (i >= image.nb_meta_channels &&
        (image.channel[i].w > options.max_chan_size ||
         image.channel[i].h > options.max_chan_size)) {
      break;
    }
    // Width >= 2 is also required so a horizontal neighbor always exists
    // for the diff samples below.
    if (image.channel[i].w <= 1 || image.channel[i].h == 0) {
      continue;  // skip empty or width-1 channels.
    }
    channel_ids.push_back(i);
    group_pixel_count[group_id] += image.channel[i].w * image.channel[i].h;
    channel_pixel_count[i] += image.channel[i].w * image.channel[i].h;
    total_pixels += image.channel[i].w * image.channel[i].h;
  }
  if (channel_ids.empty()) return;
  pixel_samples.reserve(pixel_samples.size() + fraction * total_pixels);
  diff_samples.reserve(diff_samples.size() + fraction * total_pixels);
  // Cursor over the concatenated pixels of all eligible channels:
  // i = index into channel_ids, (x, y) = position within that channel.
  size_t i = 0;
  size_t y = 0;
  size_t x = 0;
  // Advances the cursor by `amount` pixels, wrapping across rows and
  // channels; leaves i == channel_ids.size() when the end is reached.
  auto advance = [&](size_t amount) {
    x += amount;
    // Detect row overflow (rare).
    while (x >= image.channel[channel_ids[i]].w) {
      x -= image.channel[channel_ids[i]].w;
      y++;
      // Detect end-of-channel (even rarer).
      if (y == image.channel[channel_ids[i]].h) {
        i++;
        y = 0;
        if (i >= channel_ids.size()) {
          return;
        }
      }
    }
  };
  advance(rng.Geometric(dist));
  // Geometric gaps (+1 so the cursor always moves) give ~fraction coverage.
  for (; i < channel_ids.size(); advance(rng.Geometric(dist) + 1)) {
    const pixel_type *row = image.channel[channel_ids[i]].Row(y);
    pixel_samples.push_back(row[x]);
    // Horizontal neighbor; at x == 0 use the pixel to the right instead
    // (valid because width-1 channels were excluded above).
    size_t xp = x == 0 ? 1 : x - 1;
    diff_samples.push_back(static_cast<int64_t>(row[x]) - row[xp]);
  }
}
1021 | | |
1022 | | // TODO(veluca): very simple encoding scheme. This should be improved. |
1023 | | Status TokenizeTree(const Tree &tree, std::vector<Token> *tokens, |
1024 | 3.30k | Tree *decoder_tree) { |
1025 | 3.30k | JXL_ENSURE(tree.size() <= kMaxTreeSize); |
1026 | 3.30k | std::queue<int> q; |
1027 | 3.30k | q.push(0); |
1028 | 3.30k | size_t leaf_id = 0; |
1029 | 3.30k | decoder_tree->clear(); |
1030 | 396k | while (!q.empty()) { |
1031 | 393k | int cur = q.front(); |
1032 | 393k | q.pop(); |
1033 | 393k | JXL_ENSURE(tree[cur].property >= -1); |
1034 | 393k | tokens->emplace_back(kPropertyContext, tree[cur].property + 1); |
1035 | 393k | if (tree[cur].property == -1) { |
1036 | 198k | tokens->emplace_back(kPredictorContext, |
1037 | 198k | static_cast<int>(tree[cur].predictor)); |
1038 | 198k | tokens->emplace_back(kOffsetContext, |
1039 | 198k | PackSigned(tree[cur].predictor_offset)); |
1040 | 198k | uint32_t mul_log = Num0BitsBelowLS1Bit_Nonzero(tree[cur].multiplier); |
1041 | 198k | uint32_t mul_bits = (tree[cur].multiplier >> mul_log) - 1; |
1042 | 198k | tokens->emplace_back(kMultiplierLogContext, mul_log); |
1043 | 198k | tokens->emplace_back(kMultiplierBitsContext, mul_bits); |
1044 | 198k | JXL_ENSURE(tree[cur].predictor < Predictor::Best); |
1045 | 198k | decoder_tree->emplace_back( |
1046 | 198k | -1, 0, static_cast<int>(leaf_id), 0, tree[cur].predictor, |
1047 | 198k | tree[cur].predictor_offset, tree[cur].multiplier); |
1048 | 198k | leaf_id++; |
1049 | 198k | continue; |
1050 | 198k | } |
1051 | 194k | decoder_tree->emplace_back( |
1052 | 194k | tree[cur].property, tree[cur].splitval, |
1053 | 194k | static_cast<int>(decoder_tree->size() + q.size() + 1), |
1054 | 194k | static_cast<int>(decoder_tree->size() + q.size() + 2), Predictor::Zero, |
1055 | 194k | 0, 1); |
1056 | 194k | q.push(tree[cur].lchild); |
1057 | 194k | q.push(tree[cur].rchild); |
1058 | 194k | tokens->emplace_back(kSplitValContext, PackSigned(tree[cur].splitval)); |
1059 | 194k | } |
1060 | 3.30k | return true; |
1061 | 3.30k | } |
1062 | | |
1063 | | } // namespace jxl |
1064 | | #endif // HWY_ONCE |