/src/libjxl/lib/jxl/enc_cluster.cc
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright (c) the JPEG XL Project Authors. All rights reserved. |
2 | | // |
3 | | // Use of this source code is governed by a BSD-style |
4 | | // license that can be found in the LICENSE file. |
5 | | |
6 | | #include "lib/jxl/enc_cluster.h" |
7 | | |
8 | | #include <algorithm> |
9 | | #include <cstddef> |
10 | | #include <cstdint> |
11 | | #include <limits> |
12 | | #include <map> |
13 | | #include <numeric> |
14 | | #include <queue> |
15 | | #include <tuple> |
16 | | #include <vector> |
17 | | |
18 | | #include "lib/jxl/base/status.h" |
19 | | #include "lib/jxl/enc_ans_params.h" |
20 | | |
21 | | #undef HWY_TARGET_INCLUDE |
22 | | #define HWY_TARGET_INCLUDE "lib/jxl/enc_cluster.cc" |
23 | | #include <hwy/foreach_target.h> |
24 | | #include <hwy/highway.h> |
25 | | |
26 | | #include "lib/jxl/base/fast_math-inl.h" |
27 | | HWY_BEFORE_NAMESPACE(); |
28 | | namespace jxl { |
29 | | namespace HWY_NAMESPACE { |
30 | | |
31 | | // These templates are not found via ADL. |
32 | | using hwy::HWY_NAMESPACE::AllTrue; |
33 | | using hwy::HWY_NAMESPACE::Eq; |
34 | | using hwy::HWY_NAMESPACE::GetLane; |
35 | | using hwy::HWY_NAMESPACE::IfThenZeroElse; |
36 | | using hwy::HWY_NAMESPACE::SumOfLanes; |
37 | | using hwy::HWY_NAMESPACE::Zero; |
38 | | |
39 | | template <class V> |
40 | 380M | V Entropy(V count, V inv_total, V total) { |
41 | 380M | const HWY_CAPPED(float, Histogram::kRounding) d; |
42 | 380M | const auto zero = Set(d, 0.0f); |
43 | | // TODO(eustas): why (0 - x) instead of Neg(x)? |
44 | 380M | return IfThenZeroElse( |
45 | 380M | Eq(count, total), |
46 | 380M | Sub(zero, Mul(count, FastLog2f(d, Mul(inv_total, count))))); |
47 | 380M | } Unexecuted instantiation: hwy::N_SSE4::Vec128<float, 4ul> jxl::N_SSE4::Entropy<hwy::N_SSE4::Vec128<float, 4ul> >(hwy::N_SSE4::Vec128<float, 4ul>, hwy::N_SSE4::Vec128<float, 4ul>, hwy::N_SSE4::Vec128<float, 4ul>) hwy::N_AVX2::Vec256<float> jxl::N_AVX2::Entropy<hwy::N_AVX2::Vec256<float> >(hwy::N_AVX2::Vec256<float>, hwy::N_AVX2::Vec256<float>, hwy::N_AVX2::Vec256<float>) Line | Count | Source | 40 | 380M | V Entropy(V count, V inv_total, V total) { | 41 | 380M | const HWY_CAPPED(float, Histogram::kRounding) d; | 42 | 380M | const auto zero = Set(d, 0.0f); | 43 | | // TODO(eustas): why (0 - x) instead of Neg(x)? | 44 | 380M | return IfThenZeroElse( | 45 | 380M | Eq(count, total), | 46 | 380M | Sub(zero, Mul(count, FastLog2f(d, Mul(inv_total, count))))); | 47 | 380M | } |
Unexecuted instantiation: hwy::N_AVX3::Vec256<float> jxl::N_AVX3::Entropy<hwy::N_AVX3::Vec256<float> >(hwy::N_AVX3::Vec256<float>, hwy::N_AVX3::Vec256<float>, hwy::N_AVX3::Vec256<float>) Unexecuted instantiation: hwy::N_AVX3_ZEN4::Vec256<float> jxl::N_AVX3_ZEN4::Entropy<hwy::N_AVX3_ZEN4::Vec256<float> >(hwy::N_AVX3_ZEN4::Vec256<float>, hwy::N_AVX3_ZEN4::Vec256<float>, hwy::N_AVX3_ZEN4::Vec256<float>) Unexecuted instantiation: hwy::N_AVX3_SPR::Vec256<float> jxl::N_AVX3_SPR::Entropy<hwy::N_AVX3_SPR::Vec256<float> >(hwy::N_AVX3_SPR::Vec256<float>, hwy::N_AVX3_SPR::Vec256<float>, hwy::N_AVX3_SPR::Vec256<float>) Unexecuted instantiation: hwy::N_SSE2::Vec128<float, 4ul> jxl::N_SSE2::Entropy<hwy::N_SSE2::Vec128<float, 4ul> >(hwy::N_SSE2::Vec128<float, 4ul>, hwy::N_SSE2::Vec128<float, 4ul>, hwy::N_SSE2::Vec128<float, 4ul>) |
48 | | |
49 | 404k | void HistogramCondition(Histogram& a) { |
50 | 404k | const HWY_CAPPED(int32_t, Histogram::kRounding) di; |
51 | 404k | const auto kZero = Zero(di); |
52 | 404k | auto total = kZero; |
53 | 404k | int nz_pos = -static_cast<int>(Lanes(di)); |
54 | 2.13M | for (size_t i = 0; i < a.counts.size(); i += Lanes(di)) { |
55 | 1.73M | const auto counts = LoadU(di, &a.counts[i]); |
56 | 1.73M | const bool nz = !AllTrue(di, Eq(counts, kZero)); |
57 | 1.73M | total = Add(total, counts); |
58 | 1.73M | if (nz) nz_pos = i; |
59 | 1.73M | } |
60 | 404k | a.counts.resize(nz_pos + Lanes(di)); |
61 | 404k | a.total_count = GetLane(SumOfLanes(di, total)); |
62 | 404k | } Unexecuted instantiation: jxl::N_SSE4::HistogramCondition(jxl::Histogram&) jxl::N_AVX2::HistogramCondition(jxl::Histogram&) Line | Count | Source | 49 | 404k | void HistogramCondition(Histogram& a) { | 50 | 404k | const HWY_CAPPED(int32_t, Histogram::kRounding) di; | 51 | 404k | const auto kZero = Zero(di); | 52 | 404k | auto total = kZero; | 53 | 404k | int nz_pos = -static_cast<int>(Lanes(di)); | 54 | 2.13M | for (size_t i = 0; i < a.counts.size(); i += Lanes(di)) { | 55 | 1.73M | const auto counts = LoadU(di, &a.counts[i]); | 56 | 1.73M | const bool nz = !AllTrue(di, Eq(counts, kZero)); | 57 | 1.73M | total = Add(total, counts); | 58 | 1.73M | if (nz) nz_pos = i; | 59 | 1.73M | } | 60 | 404k | a.counts.resize(nz_pos + Lanes(di)); | 61 | 404k | a.total_count = GetLane(SumOfLanes(di, total)); | 62 | 404k | } |
Unexecuted instantiation: jxl::N_AVX3::HistogramCondition(jxl::Histogram&) Unexecuted instantiation: jxl::N_AVX3_ZEN4::HistogramCondition(jxl::Histogram&) Unexecuted instantiation: jxl::N_AVX3_SPR::HistogramCondition(jxl::Histogram&) Unexecuted instantiation: jxl::N_SSE2::HistogramCondition(jxl::Histogram&) |
63 | | |
64 | 12.6M | void HistogramEntropy(const Histogram& a) { |
65 | 12.6M | a.entropy = 0.0f; |
66 | 12.6M | if (a.total_count == 0) return; |
67 | | |
68 | 5.26M | const HWY_CAPPED(float, Histogram::kRounding) df; |
69 | 5.26M | const HWY_CAPPED(int32_t, Histogram::kRounding) di; |
70 | | |
71 | 5.26M | const auto inv_tot = Set(df, 1.0f / a.total_count); |
72 | 5.26M | auto entropy_lanes = Zero(df); |
73 | 5.26M | auto total = Set(df, a.total_count); |
74 | | |
75 | 16.1M | for (size_t i = 0; i < a.counts.size(); i += Lanes(di)) { |
76 | 10.9M | const auto counts = LoadU(di, &a.counts[i]); |
77 | 10.9M | entropy_lanes = |
78 | 10.9M | Add(entropy_lanes, Entropy(ConvertTo(df, counts), inv_tot, total)); |
79 | 10.9M | } |
80 | 5.26M | a.entropy += GetLane(SumOfLanes(df, entropy_lanes)); |
81 | 5.26M | } Unexecuted instantiation: jxl::N_SSE4::HistogramEntropy(jxl::Histogram const&) jxl::N_AVX2::HistogramEntropy(jxl::Histogram const&) Line | Count | Source | 64 | 12.6M | void HistogramEntropy(const Histogram& a) { | 65 | 12.6M | a.entropy = 0.0f; | 66 | 12.6M | if (a.total_count == 0) return; | 67 | | | 68 | 5.26M | const HWY_CAPPED(float, Histogram::kRounding) df; | 69 | 5.26M | const HWY_CAPPED(int32_t, Histogram::kRounding) di; | 70 | | | 71 | 5.26M | const auto inv_tot = Set(df, 1.0f / a.total_count); | 72 | 5.26M | auto entropy_lanes = Zero(df); | 73 | 5.26M | auto total = Set(df, a.total_count); | 74 | | | 75 | 16.1M | for (size_t i = 0; i < a.counts.size(); i += Lanes(di)) { | 76 | 10.9M | const auto counts = LoadU(di, &a.counts[i]); | 77 | 10.9M | entropy_lanes = | 78 | 10.9M | Add(entropy_lanes, Entropy(ConvertTo(df, counts), inv_tot, total)); | 79 | 10.9M | } | 80 | 5.26M | a.entropy += GetLane(SumOfLanes(df, entropy_lanes)); | 81 | 5.26M | } |
Unexecuted instantiation: jxl::N_AVX3::HistogramEntropy(jxl::Histogram const&) Unexecuted instantiation: jxl::N_AVX3_ZEN4::HistogramEntropy(jxl::Histogram const&) Unexecuted instantiation: jxl::N_AVX3_SPR::HistogramEntropy(jxl::Histogram const&) Unexecuted instantiation: jxl::N_SSE2::HistogramEntropy(jxl::Histogram const&) |
82 | | |
83 | 124M | float HistogramDistance(const Histogram& a, const Histogram& b) { |
84 | 124M | if (a.total_count == 0 || b.total_count == 0) return 0; |
85 | | |
86 | 124M | const HWY_CAPPED(float, Histogram::kRounding) df; |
87 | 124M | const HWY_CAPPED(int32_t, Histogram::kRounding) di; |
88 | | |
89 | 124M | const auto inv_tot = Set(df, 1.0f / (a.total_count + b.total_count)); |
90 | 124M | auto distance_lanes = Zero(df); |
91 | 124M | auto total = Set(df, a.total_count + b.total_count); |
92 | | |
93 | 493M | for (size_t i = 0; i < std::max(a.counts.size(), b.counts.size()); |
94 | 369M | i += Lanes(di)) { |
95 | 369M | const auto a_counts = |
96 | 369M | a.counts.size() > i ? LoadU(di, &a.counts[i]) : Zero(di); |
97 | 369M | const auto b_counts = |
98 | 369M | b.counts.size() > i ? LoadU(di, &b.counts[i]) : Zero(di); |
99 | 369M | const auto counts = ConvertTo(df, Add(a_counts, b_counts)); |
100 | 369M | distance_lanes = Add(distance_lanes, Entropy(counts, inv_tot, total)); |
101 | 369M | } |
102 | 124M | const float total_distance = GetLane(SumOfLanes(df, distance_lanes)); |
103 | 124M | return total_distance - a.entropy - b.entropy; |
104 | 124M | } Unexecuted instantiation: jxl::N_SSE4::HistogramDistance(jxl::Histogram const&, jxl::Histogram const&) jxl::N_AVX2::HistogramDistance(jxl::Histogram const&, jxl::Histogram const&) Line | Count | Source | 83 | 124M | float HistogramDistance(const Histogram& a, const Histogram& b) { | 84 | 124M | if (a.total_count == 0 || b.total_count == 0) return 0; | 85 | | | 86 | 124M | const HWY_CAPPED(float, Histogram::kRounding) df; | 87 | 124M | const HWY_CAPPED(int32_t, Histogram::kRounding) di; | 88 | | | 89 | 124M | const auto inv_tot = Set(df, 1.0f / (a.total_count + b.total_count)); | 90 | 124M | auto distance_lanes = Zero(df); | 91 | 124M | auto total = Set(df, a.total_count + b.total_count); | 92 | | | 93 | 493M | for (size_t i = 0; i < std::max(a.counts.size(), b.counts.size()); | 94 | 369M | i += Lanes(di)) { | 95 | 369M | const auto a_counts = | 96 | 369M | a.counts.size() > i ? LoadU(di, &a.counts[i]) : Zero(di); | 97 | 369M | const auto b_counts = | 98 | 369M | b.counts.size() > i ? LoadU(di, &b.counts[i]) : Zero(di); | 99 | 369M | const auto counts = ConvertTo(df, Add(a_counts, b_counts)); | 100 | 369M | distance_lanes = Add(distance_lanes, Entropy(counts, inv_tot, total)); | 101 | 369M | } | 102 | 124M | const float total_distance = GetLane(SumOfLanes(df, distance_lanes)); | 103 | 124M | return total_distance - a.entropy - b.entropy; | 104 | 124M | } |
Unexecuted instantiation: jxl::N_AVX3::HistogramDistance(jxl::Histogram const&, jxl::Histogram const&) Unexecuted instantiation: jxl::N_AVX3_ZEN4::HistogramDistance(jxl::Histogram const&, jxl::Histogram const&) Unexecuted instantiation: jxl::N_AVX3_SPR::HistogramDistance(jxl::Histogram const&, jxl::Histogram const&) Unexecuted instantiation: jxl::N_SSE2::HistogramDistance(jxl::Histogram const&, jxl::Histogram const&) |
105 | | |
// Positive float infinity; `constexpr` already implies `const`.
constexpr float kInfinity = std::numeric_limits<float>::infinity();
107 | | |
108 | 0 | float HistogramKLDivergence(const Histogram& actual, const Histogram& coding) { |
109 | 0 | if (actual.total_count == 0) return 0; |
110 | 0 | if (coding.total_count == 0) return kInfinity; |
111 | | |
112 | 0 | const HWY_CAPPED(float, Histogram::kRounding) df; |
113 | 0 | const HWY_CAPPED(int32_t, Histogram::kRounding) di; |
114 | |
|
115 | 0 | const auto coding_inv = Set(df, 1.0f / coding.total_count); |
116 | 0 | auto cost_lanes = Zero(df); |
117 | |
|
118 | 0 | for (size_t i = 0; i < actual.counts.size(); i += Lanes(di)) { |
119 | 0 | const auto counts = LoadU(di, &actual.counts[i]); |
120 | 0 | const auto coding_counts = |
121 | 0 | coding.counts.size() > i ? LoadU(di, &coding.counts[i]) : Zero(di); |
122 | 0 | const auto coding_probs = Mul(ConvertTo(df, coding_counts), coding_inv); |
123 | 0 | const auto neg_coding_cost = BitCast( |
124 | 0 | df, |
125 | 0 | IfThenZeroElse(Eq(counts, Zero(di)), |
126 | 0 | IfThenElse(Eq(coding_counts, Zero(di)), |
127 | 0 | BitCast(di, Set(df, -kInfinity)), |
128 | 0 | BitCast(di, FastLog2f(df, coding_probs))))); |
129 | 0 | cost_lanes = NegMulAdd(ConvertTo(df, counts), neg_coding_cost, cost_lanes); |
130 | 0 | } |
131 | 0 | const float total_cost = GetLane(SumOfLanes(df, cost_lanes)); |
132 | 0 | return total_cost - actual.entropy; |
133 | 0 | } Unexecuted instantiation: jxl::N_SSE4::HistogramKLDivergence(jxl::Histogram const&, jxl::Histogram const&) Unexecuted instantiation: jxl::N_AVX2::HistogramKLDivergence(jxl::Histogram const&, jxl::Histogram const&) Unexecuted instantiation: jxl::N_AVX3::HistogramKLDivergence(jxl::Histogram const&, jxl::Histogram const&) Unexecuted instantiation: jxl::N_AVX3_ZEN4::HistogramKLDivergence(jxl::Histogram const&, jxl::Histogram const&) Unexecuted instantiation: jxl::N_AVX3_SPR::HistogramKLDivergence(jxl::Histogram const&, jxl::Histogram const&) Unexecuted instantiation: jxl::N_SSE2::HistogramKLDivergence(jxl::Histogram const&, jxl::Histogram const&) |
134 | | |
135 | | // First step of a k-means clustering with a fancy distance metric. |
136 | | Status FastClusterHistograms(const std::vector<Histogram>& in, |
137 | | size_t max_histograms, std::vector<Histogram>* out, |
138 | 13.1k | std::vector<uint32_t>* histogram_symbols) { |
139 | 13.1k | const size_t prev_histograms = out->size(); |
140 | 13.1k | out->reserve(max_histograms); |
141 | 13.1k | histogram_symbols->clear(); |
142 | 13.1k | histogram_symbols->resize(in.size(), max_histograms); |
143 | | |
144 | 13.1k | std::vector<float> dists(in.size(), std::numeric_limits<float>::max()); |
145 | 13.1k | size_t largest_idx = 0; |
146 | 12.4M | for (size_t i = 0; i < in.size(); i++) { |
147 | 12.3M | if (in[i].total_count == 0) { |
148 | 10.5M | (*histogram_symbols)[i] = 0; |
149 | 10.5M | dists[i] = 0.0f; |
150 | 10.5M | continue; |
151 | 10.5M | } |
152 | 1.79M | HistogramEntropy(in[i]); |
153 | 1.79M | if (in[i].total_count > in[largest_idx].total_count) { |
154 | 24.8k | largest_idx = i; |
155 | 24.8k | } |
156 | 1.79M | } |
157 | | |
158 | 13.1k | if (prev_histograms > 0) { |
159 | 0 | for (size_t j = 0; j < prev_histograms; ++j) { |
160 | 0 | HistogramEntropy((*out)[j]); |
161 | 0 | } |
162 | 0 | for (size_t i = 0; i < in.size(); i++) { |
163 | 0 | if (dists[i] == 0.0f) continue; |
164 | 0 | for (size_t j = 0; j < prev_histograms; ++j) { |
165 | 0 | dists[i] = std::min(HistogramKLDivergence(in[i], (*out)[j]), dists[i]); |
166 | 0 | } |
167 | 0 | } |
168 | 0 | auto max_dist = std::max_element(dists.begin(), dists.end()); |
169 | 0 | if (*max_dist > 0.0f) { |
170 | 0 | largest_idx = max_dist - dists.begin(); |
171 | 0 | } |
172 | 0 | } |
173 | | |
174 | 13.1k | constexpr float kMinDistanceForDistinct = 48.0f; |
175 | 108k | while (out->size() < max_histograms) { |
176 | 108k | (*histogram_symbols)[largest_idx] = out->size(); |
177 | 108k | out->push_back(in[largest_idx]); |
178 | 108k | dists[largest_idx] = 0.0f; |
179 | 108k | largest_idx = 0; |
180 | 181M | for (size_t i = 0; i < in.size(); i++) { |
181 | 181M | if (dists[i] == 0.0f) continue; |
182 | 63.1M | dists[i] = std::min(HistogramDistance(in[i], out->back()), dists[i]); |
183 | 63.1M | if (dists[i] > dists[largest_idx]) largest_idx = i; |
184 | 63.1M | } |
185 | 108k | if (dists[largest_idx] < kMinDistanceForDistinct) break; |
186 | 108k | } |
187 | | |
188 | 12.4M | for (size_t i = 0; i < in.size(); i++) { |
189 | 12.3M | if ((*histogram_symbols)[i] != max_histograms) continue; |
190 | 1.68M | size_t best = 0; |
191 | 1.68M | float best_dist = std::numeric_limits<float>::max(); |
192 | 62.7M | for (size_t j = 0; j < out->size(); j++) { |
193 | 61.0M | float dist = j < prev_histograms ? HistogramKLDivergence(in[i], (*out)[j]) |
194 | 61.0M | : HistogramDistance(in[i], (*out)[j]); |
195 | 61.0M | if (dist < best_dist) { |
196 | 4.87M | best = j; |
197 | 4.87M | best_dist = dist; |
198 | 4.87M | } |
199 | 61.0M | } |
200 | 1.68M | JXL_ENSURE(best_dist < std::numeric_limits<float>::max()); |
201 | 1.68M | if (best >= prev_histograms) { |
202 | 1.68M | (*out)[best].AddHistogram(in[i]); |
203 | 1.68M | HistogramEntropy((*out)[best]); |
204 | 1.68M | } |
205 | 1.68M | (*histogram_symbols)[i] = best; |
206 | 1.68M | } |
207 | 13.1k | return true; |
208 | 13.1k | } Unexecuted instantiation: jxl::N_SSE4::FastClusterHistograms(std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> > const&, unsigned long, std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> >*, std::__1::vector<unsigned int, std::__1::allocator<unsigned int> >*) jxl::N_AVX2::FastClusterHistograms(std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> > const&, unsigned long, std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> >*, std::__1::vector<unsigned int, std::__1::allocator<unsigned int> >*) Line | Count | Source | 138 | 13.1k | std::vector<uint32_t>* histogram_symbols) { | 139 | 13.1k | const size_t prev_histograms = out->size(); | 140 | 13.1k | out->reserve(max_histograms); | 141 | 13.1k | histogram_symbols->clear(); | 142 | 13.1k | histogram_symbols->resize(in.size(), max_histograms); | 143 | | | 144 | 13.1k | std::vector<float> dists(in.size(), std::numeric_limits<float>::max()); | 145 | 13.1k | size_t largest_idx = 0; | 146 | 12.4M | for (size_t i = 0; i < in.size(); i++) { | 147 | 12.3M | if (in[i].total_count == 0) { | 148 | 10.5M | (*histogram_symbols)[i] = 0; | 149 | 10.5M | dists[i] = 0.0f; | 150 | 10.5M | continue; | 151 | 10.5M | } | 152 | 1.79M | HistogramEntropy(in[i]); | 153 | 1.79M | if (in[i].total_count > in[largest_idx].total_count) { | 154 | 24.8k | largest_idx = i; | 155 | 24.8k | } | 156 | 1.79M | } | 157 | | | 158 | 13.1k | if (prev_histograms > 0) { | 159 | 0 | for (size_t j = 0; j < prev_histograms; ++j) { | 160 | 0 | HistogramEntropy((*out)[j]); | 161 | 0 | } | 162 | 0 | for (size_t i = 0; i < in.size(); i++) { | 163 | 0 | if (dists[i] == 0.0f) continue; | 164 | 0 | for (size_t j = 0; j < prev_histograms; ++j) { | 165 | 0 | dists[i] = std::min(HistogramKLDivergence(in[i], (*out)[j]), dists[i]); | 166 | 0 | } | 167 | 0 | } | 168 | 0 | auto max_dist = std::max_element(dists.begin(), dists.end()); | 169 | 0 | if (*max_dist > 0.0f) { | 170 | 0 | largest_idx 
= max_dist - dists.begin(); | 171 | 0 | } | 172 | 0 | } | 173 | | | 174 | 13.1k | constexpr float kMinDistanceForDistinct = 48.0f; | 175 | 108k | while (out->size() < max_histograms) { | 176 | 108k | (*histogram_symbols)[largest_idx] = out->size(); | 177 | 108k | out->push_back(in[largest_idx]); | 178 | 108k | dists[largest_idx] = 0.0f; | 179 | 108k | largest_idx = 0; | 180 | 181M | for (size_t i = 0; i < in.size(); i++) { | 181 | 181M | if (dists[i] == 0.0f) continue; | 182 | 63.1M | dists[i] = std::min(HistogramDistance(in[i], out->back()), dists[i]); | 183 | 63.1M | if (dists[i] > dists[largest_idx]) largest_idx = i; | 184 | 63.1M | } | 185 | 108k | if (dists[largest_idx] < kMinDistanceForDistinct) break; | 186 | 108k | } | 187 | | | 188 | 12.4M | for (size_t i = 0; i < in.size(); i++) { | 189 | 12.3M | if ((*histogram_symbols)[i] != max_histograms) continue; | 190 | 1.68M | size_t best = 0; | 191 | 1.68M | float best_dist = std::numeric_limits<float>::max(); | 192 | 62.7M | for (size_t j = 0; j < out->size(); j++) { | 193 | 61.0M | float dist = j < prev_histograms ? HistogramKLDivergence(in[i], (*out)[j]) | 194 | 61.0M | : HistogramDistance(in[i], (*out)[j]); | 195 | 61.0M | if (dist < best_dist) { | 196 | 4.87M | best = j; | 197 | 4.87M | best_dist = dist; | 198 | 4.87M | } | 199 | 61.0M | } | 200 | 1.68M | JXL_ENSURE(best_dist < std::numeric_limits<float>::max()); | 201 | 1.68M | if (best >= prev_histograms) { | 202 | 1.68M | (*out)[best].AddHistogram(in[i]); | 203 | 1.68M | HistogramEntropy((*out)[best]); | 204 | 1.68M | } | 205 | 1.68M | (*histogram_symbols)[i] = best; | 206 | 1.68M | } | 207 | 13.1k | return true; | 208 | 13.1k | } |
Unexecuted instantiation: jxl::N_AVX3::FastClusterHistograms(std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> > const&, unsigned long, std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> >*, std::__1::vector<unsigned int, std::__1::allocator<unsigned int> >*) Unexecuted instantiation: jxl::N_AVX3_ZEN4::FastClusterHistograms(std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> > const&, unsigned long, std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> >*, std::__1::vector<unsigned int, std::__1::allocator<unsigned int> >*) Unexecuted instantiation: jxl::N_AVX3_SPR::FastClusterHistograms(std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> > const&, unsigned long, std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> >*, std::__1::vector<unsigned int, std::__1::allocator<unsigned int> >*) Unexecuted instantiation: jxl::N_SSE2::FastClusterHistograms(std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> > const&, unsigned long, std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> >*, std::__1::vector<unsigned int, std::__1::allocator<unsigned int> >*) |
209 | | |
210 | | // NOLINTNEXTLINE(google-readability-namespace-comments) |
211 | | } // namespace HWY_NAMESPACE |
212 | | } // namespace jxl |
213 | | HWY_AFTER_NAMESPACE(); |
214 | | |
215 | | #if HWY_ONCE |
216 | | namespace jxl { |
217 | | HWY_EXPORT(FastClusterHistograms); // Local function |
218 | | HWY_EXPORT(HistogramEntropy); // Local function |
219 | | HWY_EXPORT(HistogramCondition); // Local function |
220 | | |
221 | 404k | void Histogram::Condition() { HWY_DYNAMIC_DISPATCH(HistogramCondition)(*this); } |
222 | | |
223 | 9.20M | float Histogram::ShannonEntropy() const { |
224 | 9.20M | HWY_DYNAMIC_DISPATCH(HistogramEntropy)(*this); |
225 | 9.20M | return entropy; |
226 | 9.20M | } |
227 | | |
228 | | namespace { |
229 | | // ----------------------------------------------------------------------------- |
230 | | // Histogram refinement |
231 | | |
232 | | // Reorder histograms in *out so that the new symbols in *symbols come in |
233 | | // increasing order. |
234 | | void HistogramReindex(std::vector<Histogram>* out, size_t prev_histograms, |
235 | 13.1k | std::vector<uint32_t>* symbols) { |
236 | 13.1k | std::vector<Histogram> tmp(*out); |
237 | 13.1k | std::map<int, int> new_index; |
238 | 13.1k | for (size_t i = 0; i < prev_histograms; ++i) { |
239 | 0 | new_index[i] = i; |
240 | 0 | } |
241 | 13.1k | int next_index = prev_histograms; |
242 | 12.3M | for (uint32_t symbol : *symbols) { |
243 | 12.3M | if (new_index.find(symbol) == new_index.end()) { |
244 | 107k | new_index[symbol] = next_index; |
245 | 107k | (*out)[next_index] = tmp[symbol]; |
246 | 107k | ++next_index; |
247 | 107k | } |
248 | 12.3M | } |
249 | 13.1k | out->resize(next_index); |
250 | 12.3M | for (uint32_t& symbol : *symbols) { |
251 | 12.3M | symbol = new_index[symbol]; |
252 | 12.3M | } |
253 | 13.1k | } |
254 | | |
255 | | } // namespace |
256 | | |
257 | | // Clusters similar histograms in 'in' together, the selected histograms are |
258 | | // placed in 'out', and for each index in 'in', *histogram_symbols will |
259 | | // indicate which of the 'out' histograms is the best approximation. |
260 | | Status ClusterHistograms(const HistogramParams& params, |
261 | | const std::vector<Histogram>& in, |
262 | | size_t max_histograms, std::vector<Histogram>* out, |
263 | 13.1k | std::vector<uint32_t>* histogram_symbols) { |
264 | 13.1k | size_t prev_histograms = out->size(); |
265 | 13.1k | max_histograms = std::min(max_histograms, params.max_histograms); |
266 | 13.1k | max_histograms = std::min(max_histograms, in.size()); |
267 | 13.1k | if (params.clustering == HistogramParams::ClusteringType::kFastest) { |
268 | 0 | max_histograms = std::min(max_histograms, static_cast<size_t>(4)); |
269 | 0 | } |
270 | | |
271 | 13.1k | JXL_RETURN_IF_ERROR(HWY_DYNAMIC_DISPATCH(FastClusterHistograms)( |
272 | 13.1k | in, prev_histograms + max_histograms, out, histogram_symbols)); |
273 | | |
274 | 13.1k | if (prev_histograms == 0 && |
275 | 13.1k | params.clustering == HistogramParams::ClusteringType::kBest) { |
276 | 12.9k | for (auto& histo : *out) { |
277 | 12.9k | JXL_ASSIGN_OR_RETURN(histo.entropy, histo.ANSPopulationCost()); |
278 | 12.9k | } |
279 | 4.84k | uint32_t next_version = 2; |
280 | 4.84k | std::vector<uint32_t> version(out->size(), 1); |
281 | 4.84k | std::vector<uint32_t> renumbering(out->size()); |
282 | 4.84k | std::iota(renumbering.begin(), renumbering.end(), 0); |
283 | | |
284 | | // Try to pair up clusters if doing so reduces the total cost. |
285 | | |
286 | 4.84k | struct HistogramPair { |
287 | | // validity of a pair: p.version == max(version[i], version[j]) |
288 | 4.84k | float cost; |
289 | 4.84k | uint32_t first; |
290 | 4.84k | uint32_t second; |
291 | 4.84k | uint32_t version; |
292 | | // We use > because priority queues sort in *decreasing* order, but we |
293 | | // want lower cost elements to appear first. |
294 | 4.84k | bool operator<(const HistogramPair& other) const { |
295 | 2.43k | return std::make_tuple(cost, first, second, version) > |
296 | 2.43k | std::make_tuple(other.cost, other.first, other.second, |
297 | 2.43k | other.version); |
298 | 2.43k | } |
299 | 4.84k | }; |
300 | | |
301 | | // Create list of all pairs by increasing merging cost. |
302 | 4.84k | std::priority_queue<HistogramPair> pairs_to_merge; |
303 | 17.8k | for (uint32_t i = 0; i < out->size(); i++) { |
304 | 32.1k | for (uint32_t j = i + 1; j < out->size(); j++) { |
305 | 19.1k | Histogram histo; |
306 | 19.1k | histo.AddHistogram((*out)[i]); |
307 | 19.1k | histo.AddHistogram((*out)[j]); |
308 | 19.1k | JXL_ASSIGN_OR_RETURN(float cost, histo.ANSPopulationCost()); |
309 | 19.1k | cost -= (*out)[i].entropy + (*out)[j].entropy; |
310 | | // Avoid enqueueing pairs that are not advantageous to merge. |
311 | 19.1k | if (cost >= 0) continue; |
312 | 1.13k | pairs_to_merge.push( |
313 | 1.13k | HistogramPair{cost, i, j, std::max(version[i], version[j])}); |
314 | 1.13k | } |
315 | 12.9k | } |
316 | | |
317 | | // Merge the best pair to merge, add new pairs that get formed as a |
318 | | // consequence. |
319 | 6.13k | while (!pairs_to_merge.empty()) { |
320 | 1.28k | uint32_t first = pairs_to_merge.top().first; |
321 | 1.28k | uint32_t second = pairs_to_merge.top().second; |
322 | 1.28k | uint32_t ver = pairs_to_merge.top().version; |
323 | 1.28k | pairs_to_merge.pop(); |
324 | 1.28k | if (ver != std::max(version[first], version[second]) || |
325 | 1.28k | version[first] == 0 || version[second] == 0) { |
326 | 644 | continue; |
327 | 644 | } |
328 | 645 | (*out)[first].AddHistogram((*out)[second]); |
329 | 645 | JXL_ASSIGN_OR_RETURN((*out)[first].entropy, |
330 | 645 | (*out)[first].ANSPopulationCost()); |
331 | 3.23k | for (uint32_t& item : renumbering) { |
332 | 3.23k | if (item == second) { |
333 | 686 | item = first; |
334 | 686 | } |
335 | 3.23k | } |
336 | 645 | version[second] = 0; |
337 | 645 | version[first] = next_version++; |
338 | 3.87k | for (uint32_t j = 0; j < out->size(); j++) { |
339 | 3.23k | if (j == first) continue; |
340 | 2.58k | if (version[j] == 0) continue; |
341 | 1.70k | Histogram histo; |
342 | 1.70k | histo.AddHistogram((*out)[first]); |
343 | 1.70k | histo.AddHistogram((*out)[j]); |
344 | 1.70k | JXL_ASSIGN_OR_RETURN(float merge_cost, histo.ANSPopulationCost()); |
345 | 1.70k | merge_cost -= (*out)[first].entropy + (*out)[j].entropy; |
346 | | // Avoid enqueueing pairs that are not advantageous to merge. |
347 | 1.70k | if (merge_cost >= 0) continue; |
348 | 154 | pairs_to_merge.push( |
349 | 154 | HistogramPair{merge_cost, std::min(first, j), std::max(first, j), |
350 | 154 | std::max(version[first], version[j])}); |
351 | 154 | } |
352 | 645 | } |
353 | 4.84k | std::vector<uint32_t> reverse_renumbering(out->size(), -1); |
354 | 4.84k | size_t num_alive = 0; |
355 | 17.8k | for (size_t i = 0; i < out->size(); i++) { |
356 | 12.9k | if (version[i] == 0) continue; |
357 | 12.3k | (*out)[num_alive++] = (*out)[i]; |
358 | 12.3k | reverse_renumbering[i] = num_alive - 1; |
359 | 12.3k | } |
360 | 4.84k | out->resize(num_alive); |
361 | 26.9k | for (uint32_t& item : *histogram_symbols) { |
362 | 26.9k | item = reverse_renumbering[renumbering[item]]; |
363 | 26.9k | } |
364 | 4.84k | } |
365 | | |
366 | | // Convert the context map to a canonical form. |
367 | 13.1k | HistogramReindex(out, prev_histograms, histogram_symbols); |
368 | 13.1k | return true; |
369 | 13.1k | } |
370 | | |
371 | | } // namespace jxl |
372 | | #endif // HWY_ONCE |