/src/libjxl/lib/jxl/enc_cluster.cc
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright (c) the JPEG XL Project Authors. All rights reserved. |
2 | | // |
3 | | // Use of this source code is governed by a BSD-style |
4 | | // license that can be found in the LICENSE file. |
5 | | |
6 | | #include "lib/jxl/enc_cluster.h" |
7 | | |
8 | | #include <algorithm> |
9 | | #include <cstddef> |
10 | | #include <cstdint> |
11 | | #include <limits> |
12 | | #include <map> |
13 | | #include <numeric> |
14 | | #include <queue> |
15 | | #include <tuple> |
16 | | #include <vector> |
17 | | |
18 | | #include "lib/jxl/base/status.h" |
19 | | #include "lib/jxl/enc_ans_params.h" |
20 | | |
21 | | #undef HWY_TARGET_INCLUDE |
22 | | #define HWY_TARGET_INCLUDE "lib/jxl/enc_cluster.cc" |
23 | | #include <hwy/foreach_target.h> |
24 | | #include <hwy/highway.h> |
25 | | |
26 | | #include "lib/jxl/base/fast_math-inl.h" |
27 | | HWY_BEFORE_NAMESPACE(); |
28 | | namespace jxl { |
29 | | namespace HWY_NAMESPACE { |
30 | | |
31 | | // These templates are not found via ADL. |
32 | | using hwy::HWY_NAMESPACE::Eq; |
33 | | using hwy::HWY_NAMESPACE::IfThenZeroElse; |
34 | | |
35 | | template <class V> |
36 | 105M | V Entropy(V count, V inv_total, V total) { |
37 | 105M | const HWY_CAPPED(float, Histogram::kRounding) d; |
38 | 105M | const auto zero = Set(d, 0.0f); |
39 | | // TODO(eustas): why (0 - x) instead of Neg(x)? |
40 | 105M | return IfThenZeroElse( |
41 | 105M | Eq(count, total), |
42 | 105M | Sub(zero, Mul(count, FastLog2f(d, Mul(inv_total, count))))); |
43 | 105M | } Unexecuted instantiation: hwy::N_SSE4::Vec128<float, 4ul> jxl::N_SSE4::Entropy<hwy::N_SSE4::Vec128<float, 4ul> >(hwy::N_SSE4::Vec128<float, 4ul>, hwy::N_SSE4::Vec128<float, 4ul>, hwy::N_SSE4::Vec128<float, 4ul>) hwy::N_AVX2::Vec256<float> jxl::N_AVX2::Entropy<hwy::N_AVX2::Vec256<float> >(hwy::N_AVX2::Vec256<float>, hwy::N_AVX2::Vec256<float>, hwy::N_AVX2::Vec256<float>) Line | Count | Source | 36 | 105M | V Entropy(V count, V inv_total, V total) { | 37 | 105M | const HWY_CAPPED(float, Histogram::kRounding) d; | 38 | 105M | const auto zero = Set(d, 0.0f); | 39 | | // TODO(eustas): why (0 - x) instead of Neg(x)? | 40 | 105M | return IfThenZeroElse( | 41 | 105M | Eq(count, total), | 42 | 105M | Sub(zero, Mul(count, FastLog2f(d, Mul(inv_total, count))))); | 43 | 105M | } |
Unexecuted instantiation: hwy::N_SSE2::Vec128<float, 4ul> jxl::N_SSE2::Entropy<hwy::N_SSE2::Vec128<float, 4ul> >(hwy::N_SSE2::Vec128<float, 4ul>, hwy::N_SSE2::Vec128<float, 4ul>, hwy::N_SSE2::Vec128<float, 4ul>) |
44 | | |
45 | 639k | void HistogramEntropy(const Histogram& a) { |
46 | 639k | a.entropy = 0.0f; |
47 | 639k | if (a.total_count == 0) return; |
48 | | |
49 | 483k | const HWY_CAPPED(float, Histogram::kRounding) df; |
50 | 483k | const HWY_CAPPED(int32_t, Histogram::kRounding) di; |
51 | | |
52 | 483k | const auto inv_tot = Set(df, 1.0f / a.total_count); |
53 | 483k | auto entropy_lanes = Zero(df); |
54 | 483k | auto total = Set(df, a.total_count); |
55 | | |
56 | 1.71M | for (size_t i = 0; i < a.counts.size(); i += Lanes(di)) { |
57 | 1.22M | const auto counts = LoadU(di, &a.counts[i]); |
58 | 1.22M | entropy_lanes = |
59 | 1.22M | Add(entropy_lanes, Entropy(ConvertTo(df, counts), inv_tot, total)); |
60 | 1.22M | } |
61 | 483k | a.entropy += GetLane(SumOfLanes(df, entropy_lanes)); |
62 | 483k | } Unexecuted instantiation: jxl::N_SSE4::HistogramEntropy(jxl::Histogram const&) jxl::N_AVX2::HistogramEntropy(jxl::Histogram const&) Line | Count | Source | 45 | 639k | void HistogramEntropy(const Histogram& a) { | 46 | 639k | a.entropy = 0.0f; | 47 | 639k | if (a.total_count == 0) return; | 48 | | | 49 | 483k | const HWY_CAPPED(float, Histogram::kRounding) df; | 50 | 483k | const HWY_CAPPED(int32_t, Histogram::kRounding) di; | 51 | | | 52 | 483k | const auto inv_tot = Set(df, 1.0f / a.total_count); | 53 | 483k | auto entropy_lanes = Zero(df); | 54 | 483k | auto total = Set(df, a.total_count); | 55 | | | 56 | 1.71M | for (size_t i = 0; i < a.counts.size(); i += Lanes(di)) { | 57 | 1.22M | const auto counts = LoadU(di, &a.counts[i]); | 58 | 1.22M | entropy_lanes = | 59 | 1.22M | Add(entropy_lanes, Entropy(ConvertTo(df, counts), inv_tot, total)); | 60 | 1.22M | } | 61 | 483k | a.entropy += GetLane(SumOfLanes(df, entropy_lanes)); | 62 | 483k | } |
Unexecuted instantiation: jxl::N_SSE2::HistogramEntropy(jxl::Histogram const&) |
63 | | |
64 | 29.7M | float HistogramDistance(const Histogram& a, const Histogram& b) { |
65 | 29.7M | if (a.total_count == 0 || b.total_count == 0) return 0; |
66 | | |
67 | 29.7M | const HWY_CAPPED(float, Histogram::kRounding) df; |
68 | 29.7M | const HWY_CAPPED(int32_t, Histogram::kRounding) di; |
69 | | |
70 | 29.7M | const auto inv_tot = Set(df, 1.0f / (a.total_count + b.total_count)); |
71 | 29.7M | auto distance_lanes = Zero(df); |
72 | 29.7M | auto total = Set(df, a.total_count + b.total_count); |
73 | | |
74 | 134M | for (size_t i = 0; i < std::max(a.counts.size(), b.counts.size()); |
75 | 104M | i += Lanes(di)) { |
76 | 104M | const auto a_counts = |
77 | 104M | a.counts.size() > i ? LoadU(di, &a.counts[i]) : Zero(di); |
78 | 104M | const auto b_counts = |
79 | 104M | b.counts.size() > i ? LoadU(di, &b.counts[i]) : Zero(di); |
80 | 104M | const auto counts = ConvertTo(df, Add(a_counts, b_counts)); |
81 | 104M | distance_lanes = Add(distance_lanes, Entropy(counts, inv_tot, total)); |
82 | 104M | } |
83 | 29.7M | const float total_distance = GetLane(SumOfLanes(df, distance_lanes)); |
84 | 29.7M | return total_distance - a.entropy - b.entropy; |
85 | 29.7M | } Unexecuted instantiation: jxl::N_SSE4::HistogramDistance(jxl::Histogram const&, jxl::Histogram const&) jxl::N_AVX2::HistogramDistance(jxl::Histogram const&, jxl::Histogram const&) Line | Count | Source | 64 | 29.7M | float HistogramDistance(const Histogram& a, const Histogram& b) { | 65 | 29.7M | if (a.total_count == 0 || b.total_count == 0) return 0; | 66 | | | 67 | 29.7M | const HWY_CAPPED(float, Histogram::kRounding) df; | 68 | 29.7M | const HWY_CAPPED(int32_t, Histogram::kRounding) di; | 69 | | | 70 | 29.7M | const auto inv_tot = Set(df, 1.0f / (a.total_count + b.total_count)); | 71 | 29.7M | auto distance_lanes = Zero(df); | 72 | 29.7M | auto total = Set(df, a.total_count + b.total_count); | 73 | | | 74 | 134M | for (size_t i = 0; i < std::max(a.counts.size(), b.counts.size()); | 75 | 104M | i += Lanes(di)) { | 76 | 104M | const auto a_counts = | 77 | 104M | a.counts.size() > i ? LoadU(di, &a.counts[i]) : Zero(di); | 78 | 104M | const auto b_counts = | 79 | 104M | b.counts.size() > i ? LoadU(di, &b.counts[i]) : Zero(di); | 80 | 104M | const auto counts = ConvertTo(df, Add(a_counts, b_counts)); | 81 | 104M | distance_lanes = Add(distance_lanes, Entropy(counts, inv_tot, total)); | 82 | 104M | } | 83 | 29.7M | const float total_distance = GetLane(SumOfLanes(df, distance_lanes)); | 84 | 29.7M | return total_distance - a.entropy - b.entropy; | 85 | 29.7M | } |
Unexecuted instantiation: jxl::N_SSE2::HistogramDistance(jxl::Histogram const&, jxl::Histogram const&) |
86 | | |
87 | | constexpr const float kInfinity = std::numeric_limits<float>::infinity(); |
88 | | |
89 | 0 | float HistogramKLDivergence(const Histogram& actual, const Histogram& coding) { |
90 | 0 | if (actual.total_count == 0) return 0; |
91 | 0 | if (coding.total_count == 0) return kInfinity; |
92 | | |
93 | 0 | const HWY_CAPPED(float, Histogram::kRounding) df; |
94 | 0 | const HWY_CAPPED(int32_t, Histogram::kRounding) di; |
95 | |
|
96 | 0 | const auto coding_inv = Set(df, 1.0f / coding.total_count); |
97 | 0 | auto cost_lanes = Zero(df); |
98 | |
|
99 | 0 | for (size_t i = 0; i < actual.counts.size(); i += Lanes(di)) { |
100 | 0 | const auto counts = LoadU(di, &actual.counts[i]); |
101 | 0 | const auto coding_counts = |
102 | 0 | coding.counts.size() > i ? LoadU(di, &coding.counts[i]) : Zero(di); |
103 | 0 | const auto coding_probs = Mul(ConvertTo(df, coding_counts), coding_inv); |
104 | 0 | const auto neg_coding_cost = BitCast( |
105 | 0 | df, |
106 | 0 | IfThenZeroElse(Eq(counts, Zero(di)), |
107 | 0 | IfThenElse(Eq(coding_counts, Zero(di)), |
108 | 0 | BitCast(di, Set(df, -kInfinity)), |
109 | 0 | BitCast(di, FastLog2f(df, coding_probs))))); |
110 | 0 | cost_lanes = NegMulAdd(ConvertTo(df, counts), neg_coding_cost, cost_lanes); |
111 | 0 | } |
112 | 0 | const float total_cost = GetLane(SumOfLanes(df, cost_lanes)); |
113 | 0 | return total_cost - actual.entropy; |
114 | 0 | } Unexecuted instantiation: jxl::N_SSE4::HistogramKLDivergence(jxl::Histogram const&, jxl::Histogram const&) Unexecuted instantiation: jxl::N_AVX2::HistogramKLDivergence(jxl::Histogram const&, jxl::Histogram const&) Unexecuted instantiation: jxl::N_SSE2::HistogramKLDivergence(jxl::Histogram const&, jxl::Histogram const&) |
115 | | |
116 | | // First step of a k-means clustering with a fancy distance metric. |
117 | | Status FastClusterHistograms(const std::vector<Histogram>& in, |
118 | | size_t max_histograms, std::vector<Histogram>* out, |
119 | 1.30k | std::vector<uint32_t>* histogram_symbols) { |
120 | 1.30k | const size_t prev_histograms = out->size(); |
121 | 1.30k | out->reserve(max_histograms); |
122 | 1.30k | histogram_symbols->clear(); |
123 | 1.30k | histogram_symbols->resize(in.size(), max_histograms); |
124 | | |
125 | 1.30k | std::vector<float> dists(in.size(), std::numeric_limits<float>::max()); |
126 | 1.30k | size_t largest_idx = 0; |
127 | 584k | for (size_t i = 0; i < in.size(); i++) { |
128 | 583k | if (in[i].total_count == 0) { |
129 | 415k | (*histogram_symbols)[i] = 0; |
130 | 415k | dists[i] = 0.0f; |
131 | 415k | continue; |
132 | 415k | } |
133 | 168k | HistogramEntropy(in[i]); |
134 | 168k | if (in[i].total_count > in[largest_idx].total_count) { |
135 | 2.52k | largest_idx = i; |
136 | 2.52k | } |
137 | 168k | } |
138 | | |
139 | 1.30k | if (prev_histograms > 0) { |
140 | 0 | for (size_t j = 0; j < prev_histograms; ++j) { |
141 | 0 | HistogramEntropy((*out)[j]); |
142 | 0 | } |
143 | 0 | for (size_t i = 0; i < in.size(); i++) { |
144 | 0 | if (dists[i] == 0.0f) continue; |
145 | 0 | for (size_t j = 0; j < prev_histograms; ++j) { |
146 | 0 | dists[i] = std::min(HistogramKLDivergence(in[i], (*out)[j]), dists[i]); |
147 | 0 | } |
148 | 0 | } |
149 | 0 | auto max_dist = std::max_element(dists.begin(), dists.end()); |
150 | 0 | if (*max_dist > 0.0f) { |
151 | 0 | largest_idx = max_dist - dists.begin(); |
152 | 0 | } |
153 | 0 | } |
154 | | |
155 | 1.30k | constexpr float kMinDistanceForDistinct = 48.0f; |
156 | 20.2k | while (out->size() < max_histograms) { |
157 | 20.1k | (*histogram_symbols)[largest_idx] = out->size(); |
158 | 20.1k | out->push_back(in[largest_idx]); |
159 | 20.1k | dists[largest_idx] = 0.0f; |
160 | 20.1k | largest_idx = 0; |
161 | 21.7M | for (size_t i = 0; i < in.size(); i++) { |
162 | 21.7M | if (dists[i] == 0.0f) continue; |
163 | 15.2M | dists[i] = std::min(HistogramDistance(in[i], out->back()), dists[i]); |
164 | 15.2M | if (dists[i] > dists[largest_idx]) largest_idx = i; |
165 | 15.2M | } |
166 | 20.1k | if (dists[largest_idx] < kMinDistanceForDistinct) break; |
167 | 20.1k | } |
168 | | |
169 | 584k | for (size_t i = 0; i < in.size(); i++) { |
170 | 583k | if ((*histogram_symbols)[i] != max_histograms) continue; |
171 | 148k | size_t best = 0; |
172 | 148k | float best_dist = std::numeric_limits<float>::max(); |
173 | 14.6M | for (size_t j = 0; j < out->size(); j++) { |
174 | 14.4M | float dist = j < prev_histograms ? HistogramKLDivergence(in[i], (*out)[j]) |
175 | 14.4M | : HistogramDistance(in[i], (*out)[j]); |
176 | 14.4M | if (dist < best_dist) { |
177 | 587k | best = j; |
178 | 587k | best_dist = dist; |
179 | 587k | } |
180 | 14.4M | } |
181 | 148k | JXL_ENSURE(best_dist < std::numeric_limits<float>::max()); |
182 | 148k | if (best >= prev_histograms) { |
183 | 148k | (*out)[best].AddHistogram(in[i]); |
184 | 148k | HistogramEntropy((*out)[best]); |
185 | 148k | } |
186 | 148k | (*histogram_symbols)[i] = best; |
187 | 148k | } |
188 | 1.30k | return true; |
189 | 1.30k | } Unexecuted instantiation: jxl::N_SSE4::FastClusterHistograms(std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> > const&, unsigned long, std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> >*, std::__1::vector<unsigned int, std::__1::allocator<unsigned int> >*) jxl::N_AVX2::FastClusterHistograms(std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> > const&, unsigned long, std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> >*, std::__1::vector<unsigned int, std::__1::allocator<unsigned int> >*) Line | Count | Source | 119 | 1.30k | std::vector<uint32_t>* histogram_symbols) { | 120 | 1.30k | const size_t prev_histograms = out->size(); | 121 | 1.30k | out->reserve(max_histograms); | 122 | 1.30k | histogram_symbols->clear(); | 123 | 1.30k | histogram_symbols->resize(in.size(), max_histograms); | 124 | | | 125 | 1.30k | std::vector<float> dists(in.size(), std::numeric_limits<float>::max()); | 126 | 1.30k | size_t largest_idx = 0; | 127 | 584k | for (size_t i = 0; i < in.size(); i++) { | 128 | 583k | if (in[i].total_count == 0) { | 129 | 415k | (*histogram_symbols)[i] = 0; | 130 | 415k | dists[i] = 0.0f; | 131 | 415k | continue; | 132 | 415k | } | 133 | 168k | HistogramEntropy(in[i]); | 134 | 168k | if (in[i].total_count > in[largest_idx].total_count) { | 135 | 2.52k | largest_idx = i; | 136 | 2.52k | } | 137 | 168k | } | 138 | | | 139 | 1.30k | if (prev_histograms > 0) { | 140 | 0 | for (size_t j = 0; j < prev_histograms; ++j) { | 141 | 0 | HistogramEntropy((*out)[j]); | 142 | 0 | } | 143 | 0 | for (size_t i = 0; i < in.size(); i++) { | 144 | 0 | if (dists[i] == 0.0f) continue; | 145 | 0 | for (size_t j = 0; j < prev_histograms; ++j) { | 146 | 0 | dists[i] = std::min(HistogramKLDivergence(in[i], (*out)[j]), dists[i]); | 147 | 0 | } | 148 | 0 | } | 149 | 0 | auto max_dist = std::max_element(dists.begin(), dists.end()); | 150 | 0 | if (*max_dist > 0.0f) { | 151 | 0 | largest_idx = max_dist - dists.begin(); | 152 | 0 | } | 153 | 0 | } | 154 | | | 155 | 1.30k | constexpr float kMinDistanceForDistinct = 48.0f; | 156 | 20.2k | while (out->size() < max_histograms) { | 157 | 20.1k | (*histogram_symbols)[largest_idx] = out->size(); | 158 | 20.1k | out->push_back(in[largest_idx]); | 159 | 20.1k | dists[largest_idx] = 0.0f; | 160 | 20.1k | largest_idx = 0; | 161 | 21.7M | for (size_t i = 0; i < in.size(); i++) { | 162 | 21.7M | if (dists[i] == 0.0f) continue; | 163 | 15.2M | dists[i] = std::min(HistogramDistance(in[i], out->back()), dists[i]); | 164 | 15.2M | if (dists[i] > dists[largest_idx]) largest_idx = i; | 165 | 15.2M | } | 166 | 20.1k | if (dists[largest_idx] < kMinDistanceForDistinct) break; | 167 | 20.1k | } | 168 | | | 169 | 584k | for (size_t i = 0; i < in.size(); i++) { | 170 | 583k | if ((*histogram_symbols)[i] != max_histograms) continue; | 171 | 148k | size_t best = 0; | 172 | 148k | float best_dist = std::numeric_limits<float>::max(); | 173 | 14.6M | for (size_t j = 0; j < out->size(); j++) { | 174 | 14.4M | float dist = j < prev_histograms ? HistogramKLDivergence(in[i], (*out)[j]) | 175 | 14.4M | : HistogramDistance(in[i], (*out)[j]); | 176 | 14.4M | if (dist < best_dist) { | 177 | 587k | best = j; | 178 | 587k | best_dist = dist; | 179 | 587k | } | 180 | 14.4M | } | 181 | 148k | JXL_ENSURE(best_dist < std::numeric_limits<float>::max()); | 182 | 148k | if (best >= prev_histograms) { | 183 | 148k | (*out)[best].AddHistogram(in[i]); | 184 | 148k | HistogramEntropy((*out)[best]); | 185 | 148k | } | 186 | 148k | (*histogram_symbols)[i] = best; | 187 | 148k | } | 188 | 1.30k | return true; | 189 | 1.30k | } |
Unexecuted instantiation: jxl::N_SSE2::FastClusterHistograms(std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> > const&, unsigned long, std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> >*, std::__1::vector<unsigned int, std::__1::allocator<unsigned int> >*) |
190 | | |
191 | | // NOLINTNEXTLINE(google-readability-namespace-comments) |
192 | | } // namespace HWY_NAMESPACE |
193 | | } // namespace jxl |
194 | | HWY_AFTER_NAMESPACE(); |
195 | | |
196 | | #if HWY_ONCE |
197 | | namespace jxl { |
198 | | HWY_EXPORT(FastClusterHistograms); // Local function |
199 | | HWY_EXPORT(HistogramEntropy); // Local function |
200 | | |
201 | 323k | float Histogram::ShannonEntropy() const { |
202 | 323k | HWY_DYNAMIC_DISPATCH(HistogramEntropy)(*this); |
203 | 323k | return entropy; |
204 | 323k | } |
205 | | |
206 | | namespace { |
207 | | // ----------------------------------------------------------------------------- |
208 | | // Histogram refinement |
209 | | |
210 | | // Reorder histograms in *out so that the new symbols in *symbols come in |
211 | | // increasing order. |
212 | | void HistogramReindex(std::vector<Histogram>* out, size_t prev_histograms, |
213 | 1.30k | std::vector<uint32_t>* symbols) { |
214 | 1.30k | std::vector<Histogram> tmp(*out); |
215 | 1.30k | std::map<int, int> new_index; |
216 | 1.30k | for (size_t i = 0; i < prev_histograms; ++i) { |
217 | 0 | new_index[i] = i; |
218 | 0 | } |
219 | 1.30k | int next_index = prev_histograms; |
220 | 583k | for (uint32_t symbol : *symbols) { |
221 | 583k | if (new_index.find(symbol) == new_index.end()) { |
222 | 19.9k | new_index[symbol] = next_index; |
223 | 19.9k | (*out)[next_index] = tmp[symbol]; |
224 | 19.9k | ++next_index; |
225 | 19.9k | } |
226 | 583k | } |
227 | 1.30k | out->resize(next_index); |
228 | 583k | for (uint32_t& symbol : *symbols) { |
229 | 583k | symbol = new_index[symbol]; |
230 | 583k | } |
231 | 1.30k | } |
232 | | |
233 | | } // namespace |
234 | | |
235 | | // Clusters similar histograms in 'in' together, the selected histograms are |
236 | | // placed in 'out', and for each index in 'in', *histogram_symbols will |
237 | | // indicate which of the 'out' histograms is the best approximation. |
238 | | Status ClusterHistograms(const HistogramParams& params, |
239 | | const std::vector<Histogram>& in, |
240 | | size_t max_histograms, std::vector<Histogram>* out, |
241 | 1.30k | std::vector<uint32_t>* histogram_symbols) { |
242 | 1.30k | size_t prev_histograms = out->size(); |
243 | 1.30k | max_histograms = std::min(max_histograms, params.max_histograms); |
244 | 1.30k | max_histograms = std::min(max_histograms, in.size()); |
245 | 1.30k | if (params.clustering == HistogramParams::ClusteringType::kFastest) { |
246 | 0 | max_histograms = std::min(max_histograms, static_cast<size_t>(4)); |
247 | 0 | } |
248 | | |
249 | 1.30k | JXL_RETURN_IF_ERROR(HWY_DYNAMIC_DISPATCH(FastClusterHistograms)( |
250 | 1.30k | in, prev_histograms + max_histograms, out, histogram_symbols)); |
251 | | |
252 | 1.30k | if (prev_histograms == 0 && |
253 | 1.30k | params.clustering == HistogramParams::ClusteringType::kBest) { |
254 | 1.79k | for (auto& histo : *out) { |
255 | 1.79k | JXL_ASSIGN_OR_RETURN(histo.entropy, histo.ANSPopulationCost()); |
256 | 1.79k | } |
257 | 557 | uint32_t next_version = 2; |
258 | 557 | std::vector<uint32_t> version(out->size(), 1); |
259 | 557 | std::vector<uint32_t> renumbering(out->size()); |
260 | 557 | std::iota(renumbering.begin(), renumbering.end(), 0); |
261 | | |
262 | | // Try to pair up clusters if doing so reduces the total cost. |
263 | | |
264 | 557 | struct HistogramPair { |
265 | | // validity of a pair: p.version == max(version[i], version[j]) |
266 | 557 | float cost; |
267 | 557 | uint32_t first; |
268 | 557 | uint32_t second; |
269 | 557 | uint32_t version; |
270 | | // We use > because priority queues sort in *decreasing* order, but we |
271 | | // want lower cost elements to appear first. |
272 | 1.60k | bool operator<(const HistogramPair& other) const { |
273 | 1.60k | return std::make_tuple(cost, first, second, version) > |
274 | 1.60k | std::make_tuple(other.cost, other.first, other.second, |
275 | 1.60k | other.version); |
276 | 1.60k | } |
277 | 557 | }; |
278 | | |
279 | | // Create list of all pairs by increasing merging cost. |
280 | 557 | std::priority_queue<HistogramPair> pairs_to_merge; |
281 | 2.35k | for (uint32_t i = 0; i < out->size(); i++) { |
282 | 5.01k | for (uint32_t j = i + 1; j < out->size(); j++) { |
283 | 3.21k | Histogram histo; |
284 | 3.21k | histo.AddHistogram((*out)[i]); |
285 | 3.21k | histo.AddHistogram((*out)[j]); |
286 | 3.21k | JXL_ASSIGN_OR_RETURN(float cost, histo.ANSPopulationCost()); |
287 | 3.21k | cost -= (*out)[i].entropy + (*out)[j].entropy; |
288 | | // Avoid enqueueing pairs that are not advantageous to merge. |
289 | 3.21k | if (cost >= 0) continue; |
290 | 389 | pairs_to_merge.push( |
291 | 389 | HistogramPair{cost, i, j, std::max(version[i], version[j])}); |
292 | 389 | } |
293 | 1.79k | } |
294 | | |
295 | | // Merge the best pair to merge, add new pairs that get formed as a |
296 | | // consequence. |
297 | 1.04k | while (!pairs_to_merge.empty()) { |
298 | 485 | uint32_t first = pairs_to_merge.top().first; |
299 | 485 | uint32_t second = pairs_to_merge.top().second; |
300 | 485 | uint32_t ver = pairs_to_merge.top().version; |
301 | 485 | pairs_to_merge.pop(); |
302 | 485 | if (ver != std::max(version[first], version[second]) || |
303 | 485 | version[first] == 0 || version[second] == 0) { |
304 | 333 | continue; |
305 | 333 | } |
306 | 152 | (*out)[first].AddHistogram((*out)[second]); |
307 | 152 | JXL_ASSIGN_OR_RETURN((*out)[first].entropy, |
308 | 152 | (*out)[first].ANSPopulationCost()); |
309 | 878 | for (uint32_t& item : renumbering) { |
310 | 878 | if (item == second) { |
311 | 173 | item = first; |
312 | 173 | } |
313 | 878 | } |
314 | 152 | version[second] = 0; |
315 | 152 | version[first] = next_version++; |
316 | 1.03k | for (uint32_t j = 0; j < out->size(); j++) { |
317 | 878 | if (j == first) continue; |
318 | 726 | if (version[j] == 0) continue; |
319 | 470 | Histogram histo; |
320 | 470 | histo.AddHistogram((*out)[first]); |
321 | 470 | histo.AddHistogram((*out)[j]); |
322 | 470 | JXL_ASSIGN_OR_RETURN(float merge_cost, histo.ANSPopulationCost()); |
323 | 470 | merge_cost -= (*out)[first].entropy + (*out)[j].entropy; |
324 | | // Avoid enqueueing pairs that are not advantageous to merge. |
325 | 470 | if (merge_cost >= 0) continue; |
326 | 96 | pairs_to_merge.push( |
327 | 96 | HistogramPair{merge_cost, std::min(first, j), std::max(first, j), |
328 | 96 | std::max(version[first], version[j])}); |
329 | 96 | } |
330 | 152 | } |
331 | 557 | std::vector<uint32_t> reverse_renumbering(out->size(), -1); |
332 | 557 | size_t num_alive = 0; |
333 | 2.35k | for (size_t i = 0; i < out->size(); i++) { |
334 | 1.79k | if (version[i] == 0) continue; |
335 | 1.64k | (*out)[num_alive++] = (*out)[i]; |
336 | 1.64k | reverse_renumbering[i] = num_alive - 1; |
337 | 1.64k | } |
338 | 557 | out->resize(num_alive); |
339 | 2.75k | for (uint32_t& item : *histogram_symbols) { |
340 | 2.75k | item = reverse_renumbering[renumbering[item]]; |
341 | 2.75k | } |
342 | 557 | } |
343 | | |
344 | | // Convert the context map to a canonical form. |
345 | 1.30k | HistogramReindex(out, prev_histograms, histogram_symbols); |
346 | 1.30k | return true; |
347 | 1.30k | } |
348 | | |
349 | | } // namespace jxl |
350 | | #endif // HWY_ONCE |