Coverage Report

Created: 2025-06-22 08:04

/src/libjxl/lib/jxl/enc_cluster.cc
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#include "lib/jxl/enc_cluster.h"
7
8
#include <algorithm>
9
#include <cmath>
10
#include <limits>
11
#include <map>
12
#include <numeric>
13
#include <queue>
14
#include <tuple>
15
16
#include "lib/jxl/base/status.h"
17
18
#undef HWY_TARGET_INCLUDE
19
#define HWY_TARGET_INCLUDE "lib/jxl/enc_cluster.cc"
20
#include <hwy/foreach_target.h>
21
#include <hwy/highway.h>
22
23
#include "lib/jxl/base/fast_math-inl.h"
24
#include "lib/jxl/enc_ans.h"
25
HWY_BEFORE_NAMESPACE();
26
namespace jxl {
27
namespace HWY_NAMESPACE {
28
29
// These templates are not found via ADL.
30
using hwy::HWY_NAMESPACE::Eq;
31
using hwy::HWY_NAMESPACE::IfThenZeroElse;
32
33
template <class V>
34
0
V Entropy(V count, V inv_total, V total) {
35
0
  const HWY_CAPPED(float, Histogram::kRounding) d;
36
0
  const auto zero = Set(d, 0.0f);
37
  // TODO(eustas): why (0 - x) instead of Neg(x)?
38
0
  return IfThenZeroElse(
39
0
      Eq(count, total),
40
0
      Sub(zero, Mul(count, FastLog2f(d, Mul(inv_total, count)))));
41
0
}
42
43
0
void HistogramEntropy(const Histogram& a) {
44
0
  a.entropy_ = 0.0f;
45
0
  if (a.total_count_ == 0) return;
46
47
0
  const HWY_CAPPED(float, Histogram::kRounding) df;
48
0
  const HWY_CAPPED(int32_t, Histogram::kRounding) di;
49
50
0
  const auto inv_tot = Set(df, 1.0f / a.total_count_);
51
0
  auto entropy_lanes = Zero(df);
52
0
  auto total = Set(df, a.total_count_);
53
54
0
  for (size_t i = 0; i < a.data_.size(); i += Lanes(di)) {
55
0
    const auto counts = LoadU(di, &a.data_[i]);
56
0
    entropy_lanes =
57
0
        Add(entropy_lanes, Entropy(ConvertTo(df, counts), inv_tot, total));
58
0
  }
59
0
  a.entropy_ += GetLane(SumOfLanes(df, entropy_lanes));
60
0
}
61
62
0
float HistogramDistance(const Histogram& a, const Histogram& b) {
63
0
  if (a.total_count_ == 0 || b.total_count_ == 0) return 0;
64
65
0
  const HWY_CAPPED(float, Histogram::kRounding) df;
66
0
  const HWY_CAPPED(int32_t, Histogram::kRounding) di;
67
68
0
  const auto inv_tot = Set(df, 1.0f / (a.total_count_ + b.total_count_));
69
0
  auto distance_lanes = Zero(df);
70
0
  auto total = Set(df, a.total_count_ + b.total_count_);
71
72
0
  for (size_t i = 0; i < std::max(a.data_.size(), b.data_.size());
73
0
       i += Lanes(di)) {
74
0
    const auto a_counts =
75
0
        a.data_.size() > i ? LoadU(di, &a.data_[i]) : Zero(di);
76
0
    const auto b_counts =
77
0
        b.data_.size() > i ? LoadU(di, &b.data_[i]) : Zero(di);
78
0
    const auto counts = ConvertTo(df, Add(a_counts, b_counts));
79
0
    distance_lanes = Add(distance_lanes, Entropy(counts, inv_tot, total));
80
0
  }
81
0
  const float total_distance = GetLane(SumOfLanes(df, distance_lanes));
82
0
  return total_distance - a.entropy_ - b.entropy_;
83
0
}
84
85
constexpr const float kInfinity = std::numeric_limits<float>::infinity();
86
87
0
float HistogramKLDivergence(const Histogram& actual, const Histogram& coding) {
88
0
  if (actual.total_count_ == 0) return 0;
89
0
  if (coding.total_count_ == 0) return kInfinity;
90
91
0
  const HWY_CAPPED(float, Histogram::kRounding) df;
92
0
  const HWY_CAPPED(int32_t, Histogram::kRounding) di;
93
94
0
  const auto coding_inv = Set(df, 1.0f / coding.total_count_);
95
0
  auto cost_lanes = Zero(df);
96
97
0
  for (size_t i = 0; i < actual.data_.size(); i += Lanes(di)) {
98
0
    const auto counts = LoadU(di, &actual.data_[i]);
99
0
    const auto coding_counts =
100
0
        coding.data_.size() > i ? LoadU(di, &coding.data_[i]) : Zero(di);
101
0
    const auto coding_probs = Mul(ConvertTo(df, coding_counts), coding_inv);
102
0
    const auto neg_coding_cost = BitCast(
103
0
        df,
104
0
        IfThenZeroElse(Eq(counts, Zero(di)),
105
0
                       IfThenElse(Eq(coding_counts, Zero(di)),
106
0
                                  BitCast(di, Set(df, -kInfinity)),
107
0
                                  BitCast(di, FastLog2f(df, coding_probs)))));
108
0
    cost_lanes = NegMulAdd(ConvertTo(df, counts), neg_coding_cost, cost_lanes);
109
0
  }
110
0
  const float total_cost = GetLane(SumOfLanes(df, cost_lanes));
111
0
  return total_cost - actual.entropy_;
112
0
}
113
114
// First step of a k-means clustering with a fancy distance metric.
115
Status FastClusterHistograms(const std::vector<Histogram>& in,
116
                             size_t max_histograms, std::vector<Histogram>* out,
117
0
                             std::vector<uint32_t>* histogram_symbols) {
118
0
  const size_t prev_histograms = out->size();
119
0
  out->reserve(max_histograms);
120
0
  histogram_symbols->clear();
121
0
  histogram_symbols->resize(in.size(), max_histograms);
122
123
0
  std::vector<float> dists(in.size(), std::numeric_limits<float>::max());
124
0
  size_t largest_idx = 0;
125
0
  for (size_t i = 0; i < in.size(); i++) {
126
0
    if (in[i].total_count_ == 0) {
127
0
      (*histogram_symbols)[i] = 0;
128
0
      dists[i] = 0.0f;
129
0
      continue;
130
0
    }
131
0
    HistogramEntropy(in[i]);
132
0
    if (in[i].total_count_ > in[largest_idx].total_count_) {
133
0
      largest_idx = i;
134
0
    }
135
0
  }
136
137
0
  if (prev_histograms > 0) {
138
0
    for (size_t j = 0; j < prev_histograms; ++j) {
139
0
      HistogramEntropy((*out)[j]);
140
0
    }
141
0
    for (size_t i = 0; i < in.size(); i++) {
142
0
      if (dists[i] == 0.0f) continue;
143
0
      for (size_t j = 0; j < prev_histograms; ++j) {
144
0
        dists[i] = std::min(HistogramKLDivergence(in[i], (*out)[j]), dists[i]);
145
0
      }
146
0
    }
147
0
    auto max_dist = std::max_element(dists.begin(), dists.end());
148
0
    if (*max_dist > 0.0f) {
149
0
      largest_idx = max_dist - dists.begin();
150
0
    }
151
0
  }
152
153
0
  constexpr float kMinDistanceForDistinct = 48.0f;
154
0
  while (out->size() < max_histograms) {
155
0
    (*histogram_symbols)[largest_idx] = out->size();
156
0
    out->push_back(in[largest_idx]);
157
0
    dists[largest_idx] = 0.0f;
158
0
    largest_idx = 0;
159
0
    for (size_t i = 0; i < in.size(); i++) {
160
0
      if (dists[i] == 0.0f) continue;
161
0
      dists[i] = std::min(HistogramDistance(in[i], out->back()), dists[i]);
162
0
      if (dists[i] > dists[largest_idx]) largest_idx = i;
163
0
    }
164
0
    if (dists[largest_idx] < kMinDistanceForDistinct) break;
165
0
  }
166
167
0
  for (size_t i = 0; i < in.size(); i++) {
168
0
    if ((*histogram_symbols)[i] != max_histograms) continue;
169
0
    size_t best = 0;
170
0
    float best_dist = std::numeric_limits<float>::max();
171
0
    for (size_t j = 0; j < out->size(); j++) {
172
0
      float dist = j < prev_histograms ? HistogramKLDivergence(in[i], (*out)[j])
173
0
                                       : HistogramDistance(in[i], (*out)[j]);
174
0
      if (dist < best_dist) {
175
0
        best = j;
176
0
        best_dist = dist;
177
0
      }
178
0
    }
179
0
    JXL_ENSURE(best_dist < std::numeric_limits<float>::max());
180
0
    if (best >= prev_histograms) {
181
0
      (*out)[best].AddHistogram(in[i]);
182
0
      HistogramEntropy((*out)[best]);
183
0
    }
184
0
    (*histogram_symbols)[i] = best;
185
0
  }
186
0
  return true;
187
0
}
188
189
// NOLINTNEXTLINE(google-readability-namespace-comments)
190
}  // namespace HWY_NAMESPACE
191
}  // namespace jxl
192
HWY_AFTER_NAMESPACE();
193
194
#if HWY_ONCE
195
namespace jxl {
196
HWY_EXPORT(FastClusterHistograms);  // Local function
197
HWY_EXPORT(HistogramEntropy);       // Local function
198
199
0
// Forwards this histogram's bin counts to ANSPopulationCost and returns its
// result unchanged.
StatusOr<float> Histogram::PopulationCost() const {
  const auto* counts = data_.data();
  const size_t num_counts = data_.size();
  return ANSPopulationCost(counts, num_counts);
}
202
203
0
float Histogram::ShannonEntropy() const {
204
0
  HWY_DYNAMIC_DISPATCH(HistogramEntropy)(*this);
205
0
  return entropy_;
206
0
}
207
208
namespace {
209
// -----------------------------------------------------------------------------
210
// Histogram refinement
211
212
// Reorder histograms in *out so that the new symbols in *symbols come in
213
// increasing order.
214
void HistogramReindex(std::vector<Histogram>* out, size_t prev_histograms,
215
0
                      std::vector<uint32_t>* symbols) {
216
0
  std::vector<Histogram> tmp(*out);
217
0
  std::map<int, int> new_index;
218
0
  for (size_t i = 0; i < prev_histograms; ++i) {
219
0
    new_index[i] = i;
220
0
  }
221
0
  int next_index = prev_histograms;
222
0
  for (uint32_t symbol : *symbols) {
223
0
    if (new_index.find(symbol) == new_index.end()) {
224
0
      new_index[symbol] = next_index;
225
0
      (*out)[next_index] = tmp[symbol];
226
0
      ++next_index;
227
0
    }
228
0
  }
229
0
  out->resize(next_index);
230
0
  for (uint32_t& symbol : *symbols) {
231
0
    symbol = new_index[symbol];
232
0
  }
233
0
}
234
235
}  // namespace
236
237
// Clusters similar histograms in 'in' together, the selected histograms are
238
// placed in 'out', and for each index in 'in', *histogram_symbols will
239
// indicate which of the 'out' histograms is the best approximation.
240
Status ClusterHistograms(const HistogramParams& params,
241
                         const std::vector<Histogram>& in,
242
                         size_t max_histograms, std::vector<Histogram>* out,
243
0
                         std::vector<uint32_t>* histogram_symbols) {
244
0
  size_t prev_histograms = out->size();
245
0
  max_histograms = std::min(max_histograms, params.max_histograms);
246
0
  max_histograms = std::min(max_histograms, in.size());
247
0
  if (params.clustering == HistogramParams::ClusteringType::kFastest) {
248
0
    max_histograms = std::min(max_histograms, static_cast<size_t>(4));
249
0
  }
250
251
0
  JXL_RETURN_IF_ERROR(HWY_DYNAMIC_DISPATCH(FastClusterHistograms)(
252
0
      in, prev_histograms + max_histograms, out, histogram_symbols));
253
254
0
  if (prev_histograms == 0 &&
255
0
      params.clustering == HistogramParams::ClusteringType::kBest) {
256
0
    for (auto& histo : *out) {
257
0
      JXL_ASSIGN_OR_RETURN(
258
0
          histo.entropy_,
259
0
          ANSPopulationCost(histo.data_.data(), histo.data_.size()));
260
0
    }
261
0
    uint32_t next_version = 2;
262
0
    std::vector<uint32_t> version(out->size(), 1);
263
0
    std::vector<uint32_t> renumbering(out->size());
264
0
    std::iota(renumbering.begin(), renumbering.end(), 0);
265
266
    // Try to pair up clusters if doing so reduces the total cost.
267
268
0
    struct HistogramPair {
269
      // validity of a pair: p.version == max(version[i], version[j])
270
0
      float cost;
271
0
      uint32_t first;
272
0
      uint32_t second;
273
0
      uint32_t version;
274
      // We use > because priority queues sort in *decreasing* order, but we
275
      // want lower cost elements to appear first.
276
0
      bool operator<(const HistogramPair& other) const {
277
0
        return std::make_tuple(cost, first, second, version) >
278
0
               std::make_tuple(other.cost, other.first, other.second,
279
0
                               other.version);
280
0
      }
281
0
    };
282
283
    // Create list of all pairs by increasing merging cost.
284
0
    std::priority_queue<HistogramPair> pairs_to_merge;
285
0
    for (uint32_t i = 0; i < out->size(); i++) {
286
0
      for (uint32_t j = i + 1; j < out->size(); j++) {
287
0
        Histogram histo;
288
0
        histo.AddHistogram((*out)[i]);
289
0
        histo.AddHistogram((*out)[j]);
290
0
        JXL_ASSIGN_OR_RETURN(float cost, ANSPopulationCost(histo.data_.data(),
291
0
                                                           histo.data_.size()));
292
0
        cost -= (*out)[i].entropy_ + (*out)[j].entropy_;
293
        // Avoid enqueueing pairs that are not advantageous to merge.
294
0
        if (cost >= 0) continue;
295
0
        pairs_to_merge.push(
296
0
            HistogramPair{cost, i, j, std::max(version[i], version[j])});
297
0
      }
298
0
    }
299
300
    // Merge the best pair to merge, add new pairs that get formed as a
301
    // consequence.
302
0
    while (!pairs_to_merge.empty()) {
303
0
      uint32_t first = pairs_to_merge.top().first;
304
0
      uint32_t second = pairs_to_merge.top().second;
305
0
      uint32_t ver = pairs_to_merge.top().version;
306
0
      pairs_to_merge.pop();
307
0
      if (ver != std::max(version[first], version[second]) ||
308
0
          version[first] == 0 || version[second] == 0) {
309
0
        continue;
310
0
      }
311
0
      (*out)[first].AddHistogram((*out)[second]);
312
0
      JXL_ASSIGN_OR_RETURN(float cost,
313
0
                           ANSPopulationCost((*out)[first].data_.data(),
314
0
                                             (*out)[first].data_.size()));
315
0
      (*out)[first].entropy_ = cost;
316
0
      for (uint32_t& item : renumbering) {
317
0
        if (item == second) {
318
0
          item = first;
319
0
        }
320
0
      }
321
0
      version[second] = 0;
322
0
      version[first] = next_version++;
323
0
      for (uint32_t j = 0; j < out->size(); j++) {
324
0
        if (j == first) continue;
325
0
        if (version[j] == 0) continue;
326
0
        Histogram histo;
327
0
        histo.AddHistogram((*out)[first]);
328
0
        histo.AddHistogram((*out)[j]);
329
0
        JXL_ASSIGN_OR_RETURN(float cost, ANSPopulationCost(histo.data_.data(),
330
0
                                                           histo.data_.size()));
331
0
        cost -= (*out)[first].entropy_ + (*out)[j].entropy_;
332
        // Avoid enqueueing pairs that are not advantageous to merge.
333
0
        if (cost >= 0) continue;
334
0
        pairs_to_merge.push(
335
0
            HistogramPair{cost, std::min(first, j), std::max(first, j),
336
0
                          std::max(version[first], version[j])});
337
0
      }
338
0
    }
339
0
    std::vector<uint32_t> reverse_renumbering(out->size(), -1);
340
0
    size_t num_alive = 0;
341
0
    for (size_t i = 0; i < out->size(); i++) {
342
0
      if (version[i] == 0) continue;
343
0
      (*out)[num_alive++] = (*out)[i];
344
0
      reverse_renumbering[i] = num_alive - 1;
345
0
    }
346
0
    out->resize(num_alive);
347
0
    for (uint32_t& item : *histogram_symbols) {
348
0
      item = reverse_renumbering[renumbering[item]];
349
0
    }
350
0
  }
351
352
  // Convert the context map to a canonical form.
353
0
  HistogramReindex(out, prev_histograms, histogram_symbols);
354
0
  return true;
355
0
}
356
357
}  // namespace jxl
358
#endif  // HWY_ONCE