Coverage Report

Created: 2023-08-28 07:24

/src/libjxl/lib/jxl/enc_cluster.cc
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#include "lib/jxl/enc_cluster.h"
7
8
#include <algorithm>
9
#include <cmath>
10
#include <limits>
11
#include <map>
12
#include <memory>
13
#include <numeric>
14
#include <queue>
15
#include <tuple>
16
17
#undef HWY_TARGET_INCLUDE
18
#define HWY_TARGET_INCLUDE "lib/jxl/enc_cluster.cc"
19
#include <hwy/foreach_target.h>
20
#include <hwy/highway.h>
21
22
#include "lib/jxl/ac_context.h"
23
#include "lib/jxl/fast_math-inl.h"
24
HWY_BEFORE_NAMESPACE();
25
namespace jxl {
26
namespace HWY_NAMESPACE {
27
28
// These templates are not found via ADL.
29
using hwy::HWY_NAMESPACE::Eq;
30
using hwy::HWY_NAMESPACE::IfThenZeroElse;
31
32
template <class V>
33
0
V Entropy(V count, V inv_total, V total) {
34
0
  const HWY_CAPPED(float, Histogram::kRounding) d;
35
0
  const auto zero = Set(d, 0.0f);
36
  // TODO(eustas): why (0 - x) instead of Neg(x)?
37
0
  return IfThenZeroElse(
38
0
      Eq(count, total),
39
0
      Sub(zero, Mul(count, FastLog2f(d, Mul(inv_total, count)))));
40
0
}
Unexecuted instantiation: hwy::N_SSE4::Vec128<float, 4ul> jxl::N_SSE4::Entropy<hwy::N_SSE4::Vec128<float, 4ul> >(hwy::N_SSE4::Vec128<float, 4ul>, hwy::N_SSE4::Vec128<float, 4ul>, hwy::N_SSE4::Vec128<float, 4ul>)
Unexecuted instantiation: hwy::N_AVX2::Vec256<float> jxl::N_AVX2::Entropy<hwy::N_AVX2::Vec256<float> >(hwy::N_AVX2::Vec256<float>, hwy::N_AVX2::Vec256<float>, hwy::N_AVX2::Vec256<float>)
Unexecuted instantiation: hwy::N_AVX3::Vec256<float> jxl::N_AVX3::Entropy<hwy::N_AVX3::Vec256<float> >(hwy::N_AVX3::Vec256<float>, hwy::N_AVX3::Vec256<float>, hwy::N_AVX3::Vec256<float>)
Unexecuted instantiation: hwy::N_AVX3_ZEN4::Vec256<float> jxl::N_AVX3_ZEN4::Entropy<hwy::N_AVX3_ZEN4::Vec256<float> >(hwy::N_AVX3_ZEN4::Vec256<float>, hwy::N_AVX3_ZEN4::Vec256<float>, hwy::N_AVX3_ZEN4::Vec256<float>)
Unexecuted instantiation: hwy::N_AVX3_SPR::Vec256<float> jxl::N_AVX3_SPR::Entropy<hwy::N_AVX3_SPR::Vec256<float> >(hwy::N_AVX3_SPR::Vec256<float>, hwy::N_AVX3_SPR::Vec256<float>, hwy::N_AVX3_SPR::Vec256<float>)
Unexecuted instantiation: hwy::N_SSE2::Vec128<float, 4ul> jxl::N_SSE2::Entropy<hwy::N_SSE2::Vec128<float, 4ul> >(hwy::N_SSE2::Vec128<float, 4ul>, hwy::N_SSE2::Vec128<float, 4ul>, hwy::N_SSE2::Vec128<float, 4ul>)
41
42
0
void HistogramEntropy(const Histogram& a) {
43
0
  a.entropy_ = 0.0f;
44
0
  if (a.total_count_ == 0) return;
45
46
0
  const HWY_CAPPED(float, Histogram::kRounding) df;
47
0
  const HWY_CAPPED(int32_t, Histogram::kRounding) di;
48
49
0
  const auto inv_tot = Set(df, 1.0f / a.total_count_);
50
0
  auto entropy_lanes = Zero(df);
51
0
  auto total = Set(df, a.total_count_);
52
53
0
  for (size_t i = 0; i < a.data_.size(); i += Lanes(di)) {
54
0
    const auto counts = LoadU(di, &a.data_[i]);
55
0
    entropy_lanes =
56
0
        Add(entropy_lanes, Entropy(ConvertTo(df, counts), inv_tot, total));
57
0
  }
58
0
  a.entropy_ += GetLane(SumOfLanes(df, entropy_lanes));
59
0
}
Unexecuted instantiation: jxl::N_SSE4::HistogramEntropy(jxl::Histogram const&)
Unexecuted instantiation: jxl::N_AVX2::HistogramEntropy(jxl::Histogram const&)
Unexecuted instantiation: jxl::N_AVX3::HistogramEntropy(jxl::Histogram const&)
Unexecuted instantiation: jxl::N_AVX3_ZEN4::HistogramEntropy(jxl::Histogram const&)
Unexecuted instantiation: jxl::N_AVX3_SPR::HistogramEntropy(jxl::Histogram const&)
Unexecuted instantiation: jxl::N_SSE2::HistogramEntropy(jxl::Histogram const&)
60
61
0
float HistogramDistance(const Histogram& a, const Histogram& b) {
62
0
  if (a.total_count_ == 0 || b.total_count_ == 0) return 0;
63
64
0
  const HWY_CAPPED(float, Histogram::kRounding) df;
65
0
  const HWY_CAPPED(int32_t, Histogram::kRounding) di;
66
67
0
  const auto inv_tot = Set(df, 1.0f / (a.total_count_ + b.total_count_));
68
0
  auto distance_lanes = Zero(df);
69
0
  auto total = Set(df, a.total_count_ + b.total_count_);
70
71
0
  for (size_t i = 0; i < std::max(a.data_.size(), b.data_.size());
72
0
       i += Lanes(di)) {
73
0
    const auto a_counts =
74
0
        a.data_.size() > i ? LoadU(di, &a.data_[i]) : Zero(di);
75
0
    const auto b_counts =
76
0
        b.data_.size() > i ? LoadU(di, &b.data_[i]) : Zero(di);
77
0
    const auto counts = ConvertTo(df, Add(a_counts, b_counts));
78
0
    distance_lanes = Add(distance_lanes, Entropy(counts, inv_tot, total));
79
0
  }
80
0
  const float total_distance = GetLane(SumOfLanes(df, distance_lanes));
81
0
  return total_distance - a.entropy_ - b.entropy_;
82
0
}
Unexecuted instantiation: jxl::N_SSE4::HistogramDistance(jxl::Histogram const&, jxl::Histogram const&)
Unexecuted instantiation: jxl::N_AVX2::HistogramDistance(jxl::Histogram const&, jxl::Histogram const&)
Unexecuted instantiation: jxl::N_AVX3::HistogramDistance(jxl::Histogram const&, jxl::Histogram const&)
Unexecuted instantiation: jxl::N_AVX3_ZEN4::HistogramDistance(jxl::Histogram const&, jxl::Histogram const&)
Unexecuted instantiation: jxl::N_AVX3_SPR::HistogramDistance(jxl::Histogram const&, jxl::Histogram const&)
Unexecuted instantiation: jxl::N_SSE2::HistogramDistance(jxl::Histogram const&, jxl::Histogram const&)
83
84
// First step of a k-means clustering with a fancy distance metric.
85
void FastClusterHistograms(const std::vector<Histogram>& in,
86
                           size_t max_histograms, std::vector<Histogram>* out,
87
0
                           std::vector<uint32_t>* histogram_symbols) {
88
0
  out->clear();
89
0
  out->reserve(max_histograms);
90
0
  histogram_symbols->clear();
91
0
  histogram_symbols->resize(in.size(), max_histograms);
92
93
0
  std::vector<float> dists(in.size(), std::numeric_limits<float>::max());
94
0
  size_t largest_idx = 0;
95
0
  for (size_t i = 0; i < in.size(); i++) {
96
0
    if (in[i].total_count_ == 0) {
97
0
      (*histogram_symbols)[i] = 0;
98
0
      dists[i] = 0.0f;
99
0
      continue;
100
0
    }
101
0
    HistogramEntropy(in[i]);
102
0
    if (in[i].total_count_ > in[largest_idx].total_count_) {
103
0
      largest_idx = i;
104
0
    }
105
0
  }
106
107
0
  constexpr float kMinDistanceForDistinct = 48.0f;
108
0
  while (out->size() < max_histograms) {
109
0
    (*histogram_symbols)[largest_idx] = out->size();
110
0
    out->push_back(in[largest_idx]);
111
0
    dists[largest_idx] = 0.0f;
112
0
    largest_idx = 0;
113
0
    for (size_t i = 0; i < in.size(); i++) {
114
0
      if (dists[i] == 0.0f) continue;
115
0
      dists[i] = std::min(HistogramDistance(in[i], out->back()), dists[i]);
116
0
      if (dists[i] > dists[largest_idx]) largest_idx = i;
117
0
    }
118
0
    if (dists[largest_idx] < kMinDistanceForDistinct) break;
119
0
  }
120
121
0
  for (size_t i = 0; i < in.size(); i++) {
122
0
    if ((*histogram_symbols)[i] != max_histograms) continue;
123
0
    size_t best = 0;
124
0
    float best_dist = HistogramDistance(in[i], (*out)[best]);
125
0
    for (size_t j = 1; j < out->size(); j++) {
126
0
      float dist = HistogramDistance(in[i], (*out)[j]);
127
0
      if (dist < best_dist) {
128
0
        best = j;
129
0
        best_dist = dist;
130
0
      }
131
0
    }
132
0
    (*out)[best].AddHistogram(in[i]);
133
0
    HistogramEntropy((*out)[best]);
134
0
    (*histogram_symbols)[i] = best;
135
0
  }
136
0
}
Unexecuted instantiation: jxl::N_SSE4::FastClusterHistograms(std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> > const&, unsigned long, std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> >*, std::__1::vector<unsigned int, std::__1::allocator<unsigned int> >*)
Unexecuted instantiation: jxl::N_AVX2::FastClusterHistograms(std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> > const&, unsigned long, std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> >*, std::__1::vector<unsigned int, std::__1::allocator<unsigned int> >*)
Unexecuted instantiation: jxl::N_AVX3::FastClusterHistograms(std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> > const&, unsigned long, std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> >*, std::__1::vector<unsigned int, std::__1::allocator<unsigned int> >*)
Unexecuted instantiation: jxl::N_AVX3_ZEN4::FastClusterHistograms(std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> > const&, unsigned long, std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> >*, std::__1::vector<unsigned int, std::__1::allocator<unsigned int> >*)
Unexecuted instantiation: jxl::N_AVX3_SPR::FastClusterHistograms(std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> > const&, unsigned long, std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> >*, std::__1::vector<unsigned int, std::__1::allocator<unsigned int> >*)
Unexecuted instantiation: jxl::N_SSE2::FastClusterHistograms(std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> > const&, unsigned long, std::__1::vector<jxl::Histogram, std::__1::allocator<jxl::Histogram> >*, std::__1::vector<unsigned int, std::__1::allocator<unsigned int> >*)
137
138
// NOLINTNEXTLINE(google-readability-namespace-comments)
139
}  // namespace HWY_NAMESPACE
140
}  // namespace jxl
141
HWY_AFTER_NAMESPACE();
142
143
#if HWY_ONCE
144
namespace jxl {
145
HWY_EXPORT(FastClusterHistograms);  // Local function
146
HWY_EXPORT(HistogramEntropy);       // Local function
147
148
0
float Histogram::ShannonEntropy() const {
149
0
  HWY_DYNAMIC_DISPATCH(HistogramEntropy)(*this);
150
0
  return entropy_;
151
0
}
152
153
namespace {
154
// -----------------------------------------------------------------------------
155
// Histogram refinement
156
157
// Reorder histograms in *out so that the new symbols in *symbols come in
158
// increasing order.
159
void HistogramReindex(std::vector<Histogram>* out,
160
0
                      std::vector<uint32_t>* symbols) {
161
0
  std::vector<Histogram> tmp(*out);
162
0
  std::map<int, int> new_index;
163
0
  int next_index = 0;
164
0
  for (uint32_t symbol : *symbols) {
165
0
    if (new_index.find(symbol) == new_index.end()) {
166
0
      new_index[symbol] = next_index;
167
0
      (*out)[next_index] = tmp[symbol];
168
0
      ++next_index;
169
0
    }
170
0
  }
171
0
  out->resize(next_index);
172
0
  for (uint32_t& symbol : *symbols) {
173
0
    symbol = new_index[symbol];
174
0
  }
175
0
}
176
177
}  // namespace
178
179
// Clusters similar histograms in 'in' together, the selected histograms are
180
// placed in 'out', and for each index in 'in', *histogram_symbols will
181
// indicate which of the 'out' histograms is the best approximation.
182
void ClusterHistograms(const HistogramParams params,
183
                       const std::vector<Histogram>& in, size_t max_histograms,
184
                       std::vector<Histogram>* out,
185
0
                       std::vector<uint32_t>* histogram_symbols) {
186
0
  max_histograms = std::min(max_histograms, params.max_histograms);
187
0
  max_histograms = std::min(max_histograms, in.size());
188
0
  if (params.clustering == HistogramParams::ClusteringType::kFastest) {
189
0
    max_histograms = std::min(max_histograms, static_cast<size_t>(4));
190
0
  }
191
192
0
  HWY_DYNAMIC_DISPATCH(FastClusterHistograms)
193
0
  (in, max_histograms, out, histogram_symbols);
194
195
0
  if (params.clustering == HistogramParams::ClusteringType::kBest) {
196
0
    for (size_t i = 0; i < out->size(); i++) {
197
0
      (*out)[i].entropy_ =
198
0
          ANSPopulationCost((*out)[i].data_.data(), (*out)[i].data_.size());
199
0
    }
200
0
    uint32_t next_version = 2;
201
0
    std::vector<uint32_t> version(out->size(), 1);
202
0
    std::vector<uint32_t> renumbering(out->size());
203
0
    std::iota(renumbering.begin(), renumbering.end(), 0);
204
205
    // Try to pair up clusters if doing so reduces the total cost.
206
207
0
    struct HistogramPair {
208
      // validity of a pair: p.version == max(version[i], version[j])
209
0
      float cost;
210
0
      uint32_t first;
211
0
      uint32_t second;
212
0
      uint32_t version;
213
      // We use > because priority queues sort in *decreasing* order, but we
214
      // want lower cost elements to appear first.
215
0
      bool operator<(const HistogramPair& other) const {
216
0
        return std::make_tuple(cost, first, second, version) >
217
0
               std::make_tuple(other.cost, other.first, other.second,
218
0
                               other.version);
219
0
      }
220
0
    };
221
222
    // Create list of all pairs by increasing merging cost.
223
0
    std::priority_queue<HistogramPair> pairs_to_merge;
224
0
    for (uint32_t i = 0; i < out->size(); i++) {
225
0
      for (uint32_t j = i + 1; j < out->size(); j++) {
226
0
        Histogram histo;
227
0
        histo.AddHistogram((*out)[i]);
228
0
        histo.AddHistogram((*out)[j]);
229
0
        float cost = ANSPopulationCost(histo.data_.data(), histo.data_.size()) -
230
0
                     (*out)[i].entropy_ - (*out)[j].entropy_;
231
        // Avoid enqueueing pairs that are not advantageous to merge.
232
0
        if (cost >= 0) continue;
233
0
        pairs_to_merge.push(
234
0
            HistogramPair{cost, i, j, std::max(version[i], version[j])});
235
0
      }
236
0
    }
237
238
    // Merge the best pair to merge, add new pairs that get formed as a
239
    // consequence.
240
0
    while (!pairs_to_merge.empty()) {
241
0
      uint32_t first = pairs_to_merge.top().first;
242
0
      uint32_t second = pairs_to_merge.top().second;
243
0
      uint32_t ver = pairs_to_merge.top().version;
244
0
      pairs_to_merge.pop();
245
0
      if (ver != std::max(version[first], version[second]) ||
246
0
          version[first] == 0 || version[second] == 0) {
247
0
        continue;
248
0
      }
249
0
      (*out)[first].AddHistogram((*out)[second]);
250
0
      (*out)[first].entropy_ = ANSPopulationCost((*out)[first].data_.data(),
251
0
                                                 (*out)[first].data_.size());
252
0
      for (size_t i = 0; i < renumbering.size(); i++) {
253
0
        if (renumbering[i] == second) {
254
0
          renumbering[i] = first;
255
0
        }
256
0
      }
257
0
      version[second] = 0;
258
0
      version[first] = next_version++;
259
0
      for (uint32_t j = 0; j < out->size(); j++) {
260
0
        if (j == first) continue;
261
0
        if (version[j] == 0) continue;
262
0
        Histogram histo;
263
0
        histo.AddHistogram((*out)[first]);
264
0
        histo.AddHistogram((*out)[j]);
265
0
        float cost = ANSPopulationCost(histo.data_.data(), histo.data_.size()) -
266
0
                     (*out)[first].entropy_ - (*out)[j].entropy_;
267
        // Avoid enqueueing pairs that are not advantageous to merge.
268
0
        if (cost >= 0) continue;
269
0
        pairs_to_merge.push(
270
0
            HistogramPair{cost, std::min(first, j), std::max(first, j),
271
0
                          std::max(version[first], version[j])});
272
0
      }
273
0
    }
274
0
    std::vector<uint32_t> reverse_renumbering(out->size(), -1);
275
0
    size_t num_alive = 0;
276
0
    for (size_t i = 0; i < out->size(); i++) {
277
0
      if (version[i] == 0) continue;
278
0
      (*out)[num_alive++] = (*out)[i];
279
0
      reverse_renumbering[i] = num_alive - 1;
280
0
    }
281
0
    out->resize(num_alive);
282
0
    for (size_t i = 0; i < histogram_symbols->size(); i++) {
283
0
      (*histogram_symbols)[i] =
284
0
          reverse_renumbering[renumbering[(*histogram_symbols)[i]]];
285
0
    }
286
0
  }
287
288
  // Convert the context map to a canonical form.
289
0
  HistogramReindex(out, histogram_symbols);
290
0
}
291
292
}  // namespace jxl
293
#endif  // HWY_ONCE