/src/libjxl/lib/jxl/enc_cluster.cc
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright (c) the JPEG XL Project Authors. All rights reserved. |
2 | | // |
3 | | // Use of this source code is governed by a BSD-style |
4 | | // license that can be found in the LICENSE file. |
5 | | |
6 | | #include "lib/jxl/enc_cluster.h" |
7 | | |
8 | | #include <algorithm> |
9 | | #include <cmath> |
10 | | #include <limits> |
11 | | #include <map> |
12 | | #include <memory> |
13 | | #include <numeric> |
14 | | #include <queue> |
15 | | #include <tuple> |
16 | | |
17 | | #undef HWY_TARGET_INCLUDE |
18 | | #define HWY_TARGET_INCLUDE "lib/jxl/enc_cluster.cc" |
19 | | #include <hwy/foreach_target.h> |
20 | | #include <hwy/highway.h> |
21 | | |
22 | | #include "lib/jxl/ac_context.h" |
23 | | #include "lib/jxl/fast_math-inl.h" |
24 | | HWY_BEFORE_NAMESPACE(); |
25 | | namespace jxl { |
26 | | namespace HWY_NAMESPACE { |
27 | | |
28 | | // These templates are not found via ADL. |
29 | | using hwy::HWY_NAMESPACE::Eq; |
30 | | using hwy::HWY_NAMESPACE::IfThenZeroElse; |
31 | | |
32 | | template <class V> |
33 | 0 | V Entropy(V count, V inv_total, V total) { |
34 | 0 | const HWY_CAPPED(float, Histogram::kRounding) d; |
35 | 0 | const auto zero = Set(d, 0.0f); |
36 | | // TODO(eustas): why (0 - x) instead of Neg(x)? |
37 | 0 | return IfThenZeroElse( |
38 | 0 | Eq(count, total), |
39 | 0 | Sub(zero, Mul(count, FastLog2f(d, Mul(inv_total, count))))); |
40 | 0 | } Unexecuted instantiation: hwy::N_SSE4::Vec128<float, 4ul> jxl::N_SSE4::Entropy<hwy::N_SSE4::Vec128<float, 4ul> >(hwy::N_SSE4::Vec128<float, 4ul>, hwy::N_SSE4::Vec128<float, 4ul>, hwy::N_SSE4::Vec128<float, 4ul>) Unexecuted instantiation: hwy::N_AVX2::Vec256<float> jxl::N_AVX2::Entropy<hwy::N_AVX2::Vec256<float> >(hwy::N_AVX2::Vec256<float>, hwy::N_AVX2::Vec256<float>, hwy::N_AVX2::Vec256<float>) Unexecuted instantiation: hwy::N_AVX3::Vec256<float> jxl::N_AVX3::Entropy<hwy::N_AVX3::Vec256<float> >(hwy::N_AVX3::Vec256<float>, hwy::N_AVX3::Vec256<float>, hwy::N_AVX3::Vec256<float>) Unexecuted instantiation: hwy::N_AVX3_ZEN4::Vec256<float> jxl::N_AVX3_ZEN4::Entropy<hwy::N_AVX3_ZEN4::Vec256<float> >(hwy::N_AVX3_ZEN4::Vec256<float>, hwy::N_AVX3_ZEN4::Vec256<float>, hwy::N_AVX3_ZEN4::Vec256<float>) Unexecuted instantiation: hwy::N_AVX3_SPR::Vec256<float> jxl::N_AVX3_SPR::Entropy<hwy::N_AVX3_SPR::Vec256<float> >(hwy::N_AVX3_SPR::Vec256<float>, hwy::N_AVX3_SPR::Vec256<float>, hwy::N_AVX3_SPR::Vec256<float>) Unexecuted instantiation: hwy::N_SSE2::Vec128<float, 4ul> jxl::N_SSE2::Entropy<hwy::N_SSE2::Vec128<float, 4ul> >(hwy::N_SSE2::Vec128<float, 4ul>, hwy::N_SSE2::Vec128<float, 4ul>, hwy::N_SSE2::Vec128<float, 4ul>) |
41 | | |
42 | 0 | void HistogramEntropy(const Histogram& a) { |
43 | 0 | a.entropy_ = 0.0f; |
44 | 0 | if (a.total_count_ == 0) return; |
45 | | |
46 | 0 | const HWY_CAPPED(float, Histogram::kRounding) df; |
47 | 0 | const HWY_CAPPED(int32_t, Histogram::kRounding) di; |
48 | |
|
49 | 0 | const auto inv_tot = Set(df, 1.0f / a.total_count_); |
50 | 0 | auto entropy_lanes = Zero(df); |
51 | 0 | auto total = Set(df, a.total_count_); |
52 | |
|
53 | 0 | for (size_t i = 0; i < a.data_.size(); i += Lanes(di)) { |
54 | 0 | const auto counts = LoadU(di, &a.data_[i]); |
55 | 0 | entropy_lanes = |
56 | 0 | Add(entropy_lanes, Entropy(ConvertTo(df, counts), inv_tot, total)); |
57 | 0 | } |
58 | 0 | a.entropy_ += GetLane(SumOfLanes(df, entropy_lanes)); |
59 | 0 | } Unexecuted instantiation: jxl::N_SSE4::HistogramEntropy(jxl::Histogram const&) Unexecuted instantiation: jxl::N_AVX2::HistogramEntropy(jxl::Histogram const&) Unexecuted instantiation: jxl::N_AVX3::HistogramEntropy(jxl::Histogram const&) Unexecuted instantiation: jxl::N_AVX3_ZEN4::HistogramEntropy(jxl::Histogram const&) Unexecuted instantiation: jxl::N_AVX3_SPR::HistogramEntropy(jxl::Histogram const&) Unexecuted instantiation: jxl::N_SSE2::HistogramEntropy(jxl::Histogram const&) |
60 | | |
61 | 0 | float HistogramDistance(const Histogram& a, const Histogram& b) { |
62 | 0 | if (a.total_count_ == 0 || b.total_count_ == 0) return 0; |
63 | | |
64 | 0 | const HWY_CAPPED(float, Histogram::kRounding) df; |
65 | 0 | const HWY_CAPPED(int32_t, Histogram::kRounding) di; |
66 | |
|
67 | 0 | const auto inv_tot = Set(df, 1.0f / (a.total_count_ + b.total_count_)); |
68 | 0 | auto distance_lanes = Zero(df); |
69 | 0 | auto total = Set(df, a.total_count_ + b.total_count_); |
70 | |
|
71 | 0 | for (size_t i = 0; i < std::max(a.data_.size(), b.data_.size()); |
72 | 0 | i += Lanes(di)) { |
73 | 0 | const auto a_counts = |
74 | 0 | a.data_.size() > i ? LoadU(di, &a.data_[i]) : Zero(di); |
75 | 0 | const auto b_counts = |
76 | 0 | b.data_.size() > i ? LoadU(di, &b.data_[i]) : Zero(di); |
77 | 0 | const auto counts = ConvertTo(df, Add(a_counts, b_counts)); |
78 | 0 | distance_lanes = Add(distance_lanes, Entropy(counts, inv_tot, total)); |
79 | 0 | } |
80 | 0 | const float total_distance = GetLane(SumOfLanes(df, distance_lanes)); |
81 | 0 | return total_distance - a.entropy_ - b.entropy_; |
82 | 0 | } Unexecuted instantiation: jxl::N_SSE4::HistogramDistance(jxl::Histogram const&, jxl::Histogram const&) Unexecuted instantiation: jxl::N_AVX2::HistogramDistance(jxl::Histogram const&, jxl::Histogram const&) Unexecuted instantiation: jxl::N_AVX3::HistogramDistance(jxl::Histogram const&, jxl::Histogram const&) Unexecuted instantiation: jxl::N_AVX3_ZEN4::HistogramDistance(jxl::Histogram const&, jxl::Histogram const&) Unexecuted instantiation: jxl::N_AVX3_SPR::HistogramDistance(jxl::Histogram const&, jxl::Histogram const&) Unexecuted instantiation: jxl::N_SSE2::HistogramDistance(jxl::Histogram const&, jxl::Histogram const&) |
83 | | |
84 | | // First step of a k-means clustering with a fancy distance metric. |
85 | | void FastClusterHistograms(const std::vector<Histogram>& in, |
86 | | size_t max_histograms, std::vector<Histogram>* out, |
87 | 0 | std::vector<uint32_t>* histogram_symbols) { |
88 | 0 | out->clear(); |
89 | 0 | out->reserve(max_histograms); |
90 | 0 | histogram_symbols->clear(); |
91 | 0 | histogram_symbols->resize(in.size(), max_histograms); |
92 | |
|
93 | 0 | std::vector<float> dists(in.size(), std::numeric_limits<float>::max()); |
94 | 0 | size_t largest_idx = 0; |
95 | 0 | for (size_t i = 0; i < in.size(); i++) { |
96 | 0 | if (in[i].total_count_ == 0) { |
97 | 0 | (*histogram_symbols)[i] = 0; |
98 | 0 | dists[i] = 0.0f; |
99 | 0 | continue; |
100 | 0 | } |
101 | 0 | HistogramEntropy(in[i]); |
102 | 0 | if (in[i].total_count_ > in[largest_idx].total_count_) { |
103 | 0 | largest_idx = i; |
104 | 0 | } |
105 | 0 | } |
106 | |
|
107 | 0 | constexpr float kMinDistanceForDistinct = 48.0f; |
108 | 0 | while (out->size() < max_histograms) { |
109 | 0 | (*histogram_symbols)[largest_idx] = out->size(); |
110 | 0 | out->push_back(in[largest_idx]); |
111 | 0 | dists[largest_idx] = 0.0f; |
112 | 0 | largest_idx = 0; |
113 | 0 | for (size_t i = 0; i < in.size(); i++) { |
114 | 0 | if (dists[i] == 0.0f) continue; |
115 | 0 | dists[i] = std::min(HistogramDistance(in[i], out->back()), dists[i]); |
116 | 0 | if (dists[i] > dists[largest_idx]) largest_idx = i; |
117 | 0 | } |
118 | 0 | if (dists[largest_idx] < kMinDistanceForDistinct) break; |
119 | 0 | } |
120 | |
|
121 | 0 | for (size_t i = 0; i < in.size(); i++) { |
122 | 0 | if ((*histogram_symbols)[i] != max_histograms) continue; |
123 | 0 | size_t best = 0; |
124 | 0 | float best_dist = HistogramDistance(in[i], (*out)[best]); |
125 | 0 | for (size_t j = 1; j < out->size(); j++) { |
126 | 0 | float dist = HistogramDistance(in[i], (*out)[j]); |
127 | 0 | if (dist < best_dist) { |
128 | 0 | best = j; |
129 | 0 | best_dist = dist; |
130 | 0 | } |
131 | 0 | } |
132 | 0 | (*out)[best].AddHistogram(in[i]); |
133 | 0 | HistogramEntropy((*out)[best]); |
134 | 0 | (*histogram_symbols)[i] = best; |
135 | 0 | } |
136 | 0 | } Unexecuted instantiation: jxl::N_SSE4::FastClusterHistograms(std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> > const&, unsigned long, std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> >*, std::__1::vector<unsigned int, std::__1::allocator<unsigned int> >*) Unexecuted instantiation: jxl::N_AVX2::FastClusterHistograms(std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> > const&, unsigned long, std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> >*, std::__1::vector<unsigned int, std::__1::allocator<unsigned int> >*) Unexecuted instantiation: jxl::N_AVX3::FastClusterHistograms(std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> > const&, unsigned long, std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> >*, std::__1::vector<unsigned int, std::__1::allocator<unsigned int> >*) Unexecuted instantiation: jxl::N_AVX3_ZEN4::FastClusterHistograms(std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> > const&, unsigned long, std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> >*, std::__1::vector<unsigned int, std::__1::allocator<unsigned int> >*) Unexecuted instantiation: jxl::N_AVX3_SPR::FastClusterHistograms(std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> > const&, unsigned long, std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> >*, std::__1::vector<unsigned int, std::__1::allocator<unsigned int> >*) Unexecuted instantiation: jxl::N_SSE2::FastClusterHistograms(std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> > const&, unsigned long, std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> >*, std::__1::vector<unsigned int, std::__1::allocator<unsigned int> >*) |
137 | | |
138 | | // NOLINTNEXTLINE(google-readability-namespace-comments) |
139 | | } // namespace HWY_NAMESPACE |
140 | | } // namespace jxl |
141 | | HWY_AFTER_NAMESPACE(); |
142 | | |
143 | | #if HWY_ONCE |
144 | | namespace jxl { |
145 | | HWY_EXPORT(FastClusterHistograms); // Local function |
146 | | HWY_EXPORT(HistogramEntropy); // Local function |
147 | | |
148 | 0 | float Histogram::ShannonEntropy() const { |
149 | 0 | HWY_DYNAMIC_DISPATCH(HistogramEntropy)(*this); |
150 | 0 | return entropy_; |
151 | 0 | } |
152 | | |
153 | | namespace { |
154 | | // ----------------------------------------------------------------------------- |
155 | | // Histogram refinement |
156 | | |
157 | | // Reorder histograms in *out so that the new symbols in *symbols come in |
158 | | // increasing order. |
159 | | void HistogramReindex(std::vector<Histogram>* out, |
160 | 0 | std::vector<uint32_t>* symbols) { |
161 | 0 | std::vector<Histogram> tmp(*out); |
162 | 0 | std::map<int, int> new_index; |
163 | 0 | int next_index = 0; |
164 | 0 | for (uint32_t symbol : *symbols) { |
165 | 0 | if (new_index.find(symbol) == new_index.end()) { |
166 | 0 | new_index[symbol] = next_index; |
167 | 0 | (*out)[next_index] = tmp[symbol]; |
168 | 0 | ++next_index; |
169 | 0 | } |
170 | 0 | } |
171 | 0 | out->resize(next_index); |
172 | 0 | for (uint32_t& symbol : *symbols) { |
173 | 0 | symbol = new_index[symbol]; |
174 | 0 | } |
175 | 0 | } |
176 | | |
177 | | } // namespace |
178 | | |
179 | | // Clusters similar histograms in 'in' together, the selected histograms are |
180 | | // placed in 'out', and for each index in 'in', *histogram_symbols will |
181 | | // indicate which of the 'out' histograms is the best approximation. |
182 | | void ClusterHistograms(const HistogramParams params, |
183 | | const std::vector<Histogram>& in, size_t max_histograms, |
184 | | std::vector<Histogram>* out, |
185 | 0 | std::vector<uint32_t>* histogram_symbols) { |
186 | 0 | max_histograms = std::min(max_histograms, params.max_histograms); |
187 | 0 | max_histograms = std::min(max_histograms, in.size()); |
188 | 0 | if (params.clustering == HistogramParams::ClusteringType::kFastest) { |
189 | 0 | max_histograms = std::min(max_histograms, static_cast<size_t>(4)); |
190 | 0 | } |
191 | |
|
192 | 0 | HWY_DYNAMIC_DISPATCH(FastClusterHistograms) |
193 | 0 | (in, max_histograms, out, histogram_symbols); |
194 | |
|
195 | 0 | if (params.clustering == HistogramParams::ClusteringType::kBest) { |
196 | 0 | for (size_t i = 0; i < out->size(); i++) { |
197 | 0 | (*out)[i].entropy_ = |
198 | 0 | ANSPopulationCost((*out)[i].data_.data(), (*out)[i].data_.size()); |
199 | 0 | } |
200 | 0 | uint32_t next_version = 2; |
201 | 0 | std::vector<uint32_t> version(out->size(), 1); |
202 | 0 | std::vector<uint32_t> renumbering(out->size()); |
203 | 0 | std::iota(renumbering.begin(), renumbering.end(), 0); |
204 | | |
205 | | // Try to pair up clusters if doing so reduces the total cost. |
206 | |
|
207 | 0 | struct HistogramPair { |
208 | | // validity of a pair: p.version == max(version[i], version[j]) |
209 | 0 | float cost; |
210 | 0 | uint32_t first; |
211 | 0 | uint32_t second; |
212 | 0 | uint32_t version; |
213 | | // We use > because priority queues sort in *decreasing* order, but we |
214 | | // want lower cost elements to appear first. |
215 | 0 | bool operator<(const HistogramPair& other) const { |
216 | 0 | return std::make_tuple(cost, first, second, version) > |
217 | 0 | std::make_tuple(other.cost, other.first, other.second, |
218 | 0 | other.version); |
219 | 0 | } |
220 | 0 | }; |
221 | | |
222 | | // Create list of all pairs by increasing merging cost. |
223 | 0 | std::priority_queue<HistogramPair> pairs_to_merge; |
224 | 0 | for (uint32_t i = 0; i < out->size(); i++) { |
225 | 0 | for (uint32_t j = i + 1; j < out->size(); j++) { |
226 | 0 | Histogram histo; |
227 | 0 | histo.AddHistogram((*out)[i]); |
228 | 0 | histo.AddHistogram((*out)[j]); |
229 | 0 | float cost = ANSPopulationCost(histo.data_.data(), histo.data_.size()) - |
230 | 0 | (*out)[i].entropy_ - (*out)[j].entropy_; |
231 | | // Avoid enqueueing pairs that are not advantageous to merge. |
232 | 0 | if (cost >= 0) continue; |
233 | 0 | pairs_to_merge.push( |
234 | 0 | HistogramPair{cost, i, j, std::max(version[i], version[j])}); |
235 | 0 | } |
236 | 0 | } |
237 | | |
238 | | // Merge the best pair to merge, add new pairs that get formed as a |
239 | | // consequence. |
240 | 0 | while (!pairs_to_merge.empty()) { |
241 | 0 | uint32_t first = pairs_to_merge.top().first; |
242 | 0 | uint32_t second = pairs_to_merge.top().second; |
243 | 0 | uint32_t ver = pairs_to_merge.top().version; |
244 | 0 | pairs_to_merge.pop(); |
245 | 0 | if (ver != std::max(version[first], version[second]) || |
246 | 0 | version[first] == 0 || version[second] == 0) { |
247 | 0 | continue; |
248 | 0 | } |
249 | 0 | (*out)[first].AddHistogram((*out)[second]); |
250 | 0 | (*out)[first].entropy_ = ANSPopulationCost((*out)[first].data_.data(), |
251 | 0 | (*out)[first].data_.size()); |
252 | 0 | for (size_t i = 0; i < renumbering.size(); i++) { |
253 | 0 | if (renumbering[i] == second) { |
254 | 0 | renumbering[i] = first; |
255 | 0 | } |
256 | 0 | } |
257 | 0 | version[second] = 0; |
258 | 0 | version[first] = next_version++; |
259 | 0 | for (uint32_t j = 0; j < out->size(); j++) { |
260 | 0 | if (j == first) continue; |
261 | 0 | if (version[j] == 0) continue; |
262 | 0 | Histogram histo; |
263 | 0 | histo.AddHistogram((*out)[first]); |
264 | 0 | histo.AddHistogram((*out)[j]); |
265 | 0 | float cost = ANSPopulationCost(histo.data_.data(), histo.data_.size()) - |
266 | 0 | (*out)[first].entropy_ - (*out)[j].entropy_; |
267 | | // Avoid enqueueing pairs that are not advantageous to merge. |
268 | 0 | if (cost >= 0) continue; |
269 | 0 | pairs_to_merge.push( |
270 | 0 | HistogramPair{cost, std::min(first, j), std::max(first, j), |
271 | 0 | std::max(version[first], version[j])}); |
272 | 0 | } |
273 | 0 | } |
274 | 0 | std::vector<uint32_t> reverse_renumbering(out->size(), -1); |
275 | 0 | size_t num_alive = 0; |
276 | 0 | for (size_t i = 0; i < out->size(); i++) { |
277 | 0 | if (version[i] == 0) continue; |
278 | 0 | (*out)[num_alive++] = (*out)[i]; |
279 | 0 | reverse_renumbering[i] = num_alive - 1; |
280 | 0 | } |
281 | 0 | out->resize(num_alive); |
282 | 0 | for (size_t i = 0; i < histogram_symbols->size(); i++) { |
283 | 0 | (*histogram_symbols)[i] = |
284 | 0 | reverse_renumbering[renumbering[(*histogram_symbols)[i]]]; |
285 | 0 | } |
286 | 0 | } |
287 | | |
288 | | // Convert the context map to a canonical form. |
289 | 0 | HistogramReindex(out, histogram_symbols); |
290 | 0 | } |
291 | | |
292 | | } // namespace jxl |
293 | | #endif // HWY_ONCE |