Coverage Report

Created: 2025-06-22 08:04

/src/aom/av1/encoder/palette.c
Line
Count
Source (jump to first uncovered line)
1
/*
2
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
3
 *
4
 * This source code is subject to the terms of the BSD 2 Clause License and
5
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6
 * was not distributed with this source code in the LICENSE file, you can
7
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8
 * Media Patent License 1.0 was not distributed with this source code in the
9
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10
 */
11
12
#include <math.h>
13
#include <stdlib.h>
14
15
#include "av1/common/pred_common.h"
16
17
#include "av1/encoder/block.h"
18
#include "av1/encoder/cost.h"
19
#include "av1/encoder/encoder.h"
20
#include "av1/encoder/intra_mode_search.h"
21
#include "av1/encoder/intra_mode_search_utils.h"
22
#include "av1/encoder/palette.h"
23
#include "av1/encoder/random.h"
24
#include "av1/encoder/rdopt_utils.h"
25
#include "av1/encoder/tx_search.h"
26
27
0
#define AV1_K_MEANS_DIM 1
28
#include "av1/encoder/k_means_template.h"
29
#undef AV1_K_MEANS_DIM
30
0
#define AV1_K_MEANS_DIM 2
31
#include "av1/encoder/k_means_template.h"
32
#undef AV1_K_MEANS_DIM
33
34
0
// qsort() comparator that orders int16_t values ascending. Both operands are
// promoted to int before the subtraction, so on targets where int is wider
// than 16 bits the difference cannot overflow.
static int int16_comparer(const void *a, const void *b) {
  const int16_t lhs = *(const int16_t *)a;
  const int16_t rhs = *(const int16_t *)b;
  return lhs - rhs;
}

/*!\brief Removes duplicated centroid indices.
 *
 * \ingroup palette_mode_search
 * \param[in,out] centroids         A list of centroids index. Sorted in
 *                                  ascending order on return, with the unique
 *                                  values packed at the front.
 * \param[in]     num_centroids     Number of centroids.
 *
 * \return Returns the number of unique centroids and saves the unique centroids
 * in beginning of the centroids array. Returns 0, leaving the array untouched,
 * when num_centroids is not positive.
 *
 * \attention The centroids should be rounded to integers before calling this
 * method.
 */
static int remove_duplicates(int16_t *centroids, int num_centroids) {
  // An empty list holds no unique value. Bailing out here also keeps a
  // negative count from being converted to a huge size_t for qsort().
  if (num_centroids <= 0) return 0;

  qsort(centroids, (size_t)num_centroids, sizeof(*centroids), int16_comparer);

  // After sorting, duplicates are adjacent: keep a value only when it differs
  // from its predecessor.
  int num_unique = 1;
  for (int i = 1; i < num_centroids; ++i) {
    if (centroids[i] != centroids[i - 1]) {  // found a new unique centroid
      centroids[num_unique++] = centroids[i];
    }
  }
  return num_unique;
}
63
64
// Returns the number of bits needed to delta-code 'num' colors: the first
// color costs 'bit_depth' bits, then (for num > 1) 2 more bits plus one
// fixed-width field per delta between consecutive colors. Each delta is
// asserted to be at least 'min_val'. The field width starts from the largest
// delta (never below bit_depth - 3) and shrinks as the remaining value range
// gets smaller.
// NOTE(review): the 2 extra bits presumably signal the delta width; confirm
// against the bitstream syntax.
static int delta_encode_cost(const int *colors, int num, int bit_depth,
                             int min_val) {
  if (num <= 0) return 0;
  if (num == 1) return bit_depth;

  // First pass: find the widest step between neighbouring colors.
  int largest_step = 0;
  for (int k = 1; k < num; ++k) {
    const int step = colors[k] - colors[k - 1];
    assert(step >= min_val);
    if (step > largest_step) largest_step = step;
  }

  int width = AOMMAX(av1_ceil_log2(largest_step + 1 - min_val), bit_depth - 3);
  assert(width <= bit_depth);

  // Second pass: charge one field per delta, narrowing the field whenever the
  // range still available above the current color needs fewer bits.
  int total_bits = bit_depth + 2;
  int remaining = (1 << bit_depth) - colors[0] - min_val;
  for (int k = 1; k < num; ++k) {
    total_bits += width;
    remaining -= colors[k] - colors[k - 1];
    width = AOMMIN(width, av1_ceil_log2(remaining));
  }
  return total_bits;
}
89
90
int av1_index_color_cache(const uint16_t *color_cache, int n_cache,
91
                          const uint16_t *colors, int n_colors,
92
0
                          uint8_t *cache_color_found, int *out_cache_colors) {
93
0
  if (n_cache <= 0) {
94
0
    for (int i = 0; i < n_colors; ++i) out_cache_colors[i] = colors[i];
95
0
    return n_colors;
96
0
  }
97
0
  memset(cache_color_found, 0, n_cache * sizeof(*cache_color_found));
98
0
  int n_in_cache = 0;
99
0
  int in_cache_flags[PALETTE_MAX_SIZE];
100
0
  memset(in_cache_flags, 0, sizeof(in_cache_flags));
101
0
  for (int i = 0; i < n_cache && n_in_cache < n_colors; ++i) {
102
0
    for (int j = 0; j < n_colors; ++j) {
103
0
      if (colors[j] == color_cache[i]) {
104
0
        in_cache_flags[j] = 1;
105
0
        cache_color_found[i] = 1;
106
0
        ++n_in_cache;
107
0
        break;
108
0
      }
109
0
    }
110
0
  }
111
0
  int j = 0;
112
0
  for (int i = 0; i < n_colors; ++i)
113
0
    if (!in_cache_flags[i]) out_cache_colors[j++] = colors[i];
114
0
  assert(j == n_colors - n_in_cache);
115
0
  return j;
116
0
}
117
118
int av1_get_palette_delta_bits_v(const PALETTE_MODE_INFO *const pmi,
119
                                 int bit_depth, int *zero_count,
120
0
                                 int *min_bits) {
121
0
  const int n = pmi->palette_size[1];
122
0
  const int max_val = 1 << bit_depth;
123
0
  int max_d = 0;
124
0
  *min_bits = bit_depth - 4;
125
0
  *zero_count = 0;
126
0
  for (int i = 1; i < n; ++i) {
127
0
    const int delta = pmi->palette_colors[2 * PALETTE_MAX_SIZE + i] -
128
0
                      pmi->palette_colors[2 * PALETTE_MAX_SIZE + i - 1];
129
0
    const int v = abs(delta);
130
0
    const int d = AOMMIN(v, max_val - v);
131
0
    if (d > max_d) max_d = d;
132
0
    if (d == 0) ++(*zero_count);
133
0
  }
134
0
  return AOMMAX(av1_ceil_log2(max_d + 1), *min_bits);
135
0
}
136
137
int av1_palette_color_cost_y(const PALETTE_MODE_INFO *const pmi,
138
                             const uint16_t *color_cache, int n_cache,
139
0
                             int bit_depth) {
140
0
  const int n = pmi->palette_size[0];
141
0
  int out_cache_colors[PALETTE_MAX_SIZE];
142
0
  uint8_t cache_color_found[2 * PALETTE_MAX_SIZE];
143
0
  const int n_out_cache =
144
0
      av1_index_color_cache(color_cache, n_cache, pmi->palette_colors, n,
145
0
                            cache_color_found, out_cache_colors);
146
0
  const int total_bits =
147
0
      n_cache + delta_encode_cost(out_cache_colors, n_out_cache, bit_depth, 1);
148
0
  return av1_cost_literal(total_bits);
149
0
}
150
151
int av1_palette_color_cost_uv(const PALETTE_MODE_INFO *const pmi,
152
                              const uint16_t *color_cache, int n_cache,
153
0
                              int bit_depth) {
154
0
  const int n = pmi->palette_size[1];
155
0
  int total_bits = 0;
156
  // U channel palette color cost.
157
0
  int out_cache_colors[PALETTE_MAX_SIZE];
158
0
  uint8_t cache_color_found[2 * PALETTE_MAX_SIZE];
159
0
  const int n_out_cache = av1_index_color_cache(
160
0
      color_cache, n_cache, pmi->palette_colors + PALETTE_MAX_SIZE, n,
161
0
      cache_color_found, out_cache_colors);
162
0
  total_bits +=
163
0
      n_cache + delta_encode_cost(out_cache_colors, n_out_cache, bit_depth, 0);
164
165
  // V channel palette color cost.
166
0
  int zero_count = 0, min_bits_v = 0;
167
0
  const int bits_v =
168
0
      av1_get_palette_delta_bits_v(pmi, bit_depth, &zero_count, &min_bits_v);
169
0
  const int bits_using_delta =
170
0
      2 + bit_depth + (bits_v + 1) * (n - 1) - zero_count;
171
0
  const int bits_using_raw = bit_depth * n;
172
0
  total_bits += 1 + AOMMIN(bits_using_delta, bits_using_raw);
173
0
  return av1_cost_literal(total_bits);
174
0
}
175
176
// Grows 'color_map' in place from 'orig_width x orig_height' to 'new_width x
// new_height'. The existing rows are re-laid out with the new row pitch; the
// added columns repeat each row's last valid entry and the added rows repeat
// the last valid row.
static inline void extend_palette_color_map(uint8_t *const color_map,
                                            int orig_width, int orig_height,
                                            int new_width, int new_height) {
  assert(new_width >= orig_width);
  assert(new_height >= orig_height);
  if (new_width == orig_width && new_height == orig_height) return;

  const int pad_cols = new_width - orig_width;
  // Walk from the bottom row upwards so that a row moved to its wider slot
  // never overwrites a row that has not been moved yet.
  for (int row = orig_height; row-- > 0;) {
    uint8_t *const dst = color_map + row * new_width;
    memmove(dst, color_map + row * orig_width, orig_width);
    // Copy last column to extra columns.
    memset(dst + orig_width, dst[orig_width - 1], pad_cols);
  }
  // Copy last row to extra rows.
  for (int row = orig_height; row < new_height; ++row) {
    memcpy(color_map + row * new_width,
           color_map + (orig_height - 1) * new_width, new_width);
  }
}
199
200
// Bias toward using colors in the cache: every centroid (one per 'stride'
// entries of centroids[]) that lies within 4 << (bit_depth - 8) of its
// nearest cache color is replaced by that cache color. On ties the earliest
// cache entry wins. Does nothing when the cache is empty.
// TODO(huisu): Try other schemes to improve compression.
static inline void optimize_palette_colors(uint16_t *color_cache, int n_cache,
                                           int n_colors, int stride,
                                           int16_t *centroids, int bit_depth) {
  if (n_cache <= 0) return;

  const int end = n_colors * stride;
  for (int pos = 0; pos < end; pos += stride) {
    const int value = (int)centroids[pos];

    // Linear scan for the closest cache entry.
    int nearest = 0;
    int nearest_dist = abs(value - (int)color_cache[0]);
    for (int c = 1; c < n_cache; ++c) {
      const int dist = abs(value - (int)color_cache[c]);
      if (dist >= nearest_dist) continue;
      nearest_dist = dist;
      nearest = c;
    }

    const int snap_limit = 4 << (bit_depth - 8);
    if (nearest_dist <= snap_limit) centroids[pos] = color_cache[nearest];
  }
}
220
221
/*!\brief Calculate the luma palette cost from a given color palette
222
 *
223
 * \ingroup palette_mode_search
224
 * \callergraph
225
 * Given the base colors as specified in centroids[], calculate the RD cost
226
 * of palette mode.
227
 */
228
static inline void palette_rd_y(
229
    const AV1_COMP *const cpi, MACROBLOCK *x, MB_MODE_INFO *mbmi,
230
    BLOCK_SIZE bsize, int dc_mode_cost, const int16_t *data, int16_t *centroids,
231
    int n, uint16_t *color_cache, int n_cache, bool do_header_rd_based_gating,
232
    MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map, int64_t *best_rd,
233
    int *rate, int *rate_tokenonly, int64_t *distortion, uint8_t *skippable,
234
    int *beat_best_rd, PICK_MODE_CONTEXT *ctx, uint8_t *blk_skip,
235
    uint8_t *tx_type_map, int *beat_best_palette_rd,
236
0
    bool *do_header_rd_based_breakout, int discount_color_cost) {
237
0
  if (do_header_rd_based_breakout != NULL) *do_header_rd_based_breakout = false;
238
0
  optimize_palette_colors(color_cache, n_cache, n, 1, centroids,
239
0
                          cpi->common.seq_params->bit_depth);
240
0
  const int num_unique_colors = remove_duplicates(centroids, n);
241
0
  if (num_unique_colors < PALETTE_MIN_SIZE) {
242
    // Too few unique colors to create a palette. And DC_PRED will work
243
    // well for that case anyway. So skip.
244
0
    return;
245
0
  }
246
0
  PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
247
0
  if (cpi->common.seq_params->use_highbitdepth) {
248
0
    for (int i = 0; i < num_unique_colors; ++i) {
249
0
      pmi->palette_colors[i] = clip_pixel_highbd(
250
0
          (int)centroids[i], cpi->common.seq_params->bit_depth);
251
0
    }
252
0
  } else {
253
0
    for (int i = 0; i < num_unique_colors; ++i) {
254
0
      pmi->palette_colors[i] = clip_pixel(centroids[i]);
255
0
    }
256
0
  }
257
0
  pmi->palette_size[0] = num_unique_colors;
258
0
  MACROBLOCKD *const xd = &x->e_mbd;
259
0
  uint8_t *const color_map = xd->plane[0].color_index_map;
260
0
  int block_width, block_height, rows, cols;
261
0
  av1_get_block_dimensions(bsize, 0, xd, &block_width, &block_height, &rows,
262
0
                           &cols);
263
0
  av1_calc_indices(data, centroids, color_map, rows * cols, num_unique_colors,
264
0
                   1);
265
0
  extend_palette_color_map(color_map, cols, rows, block_width, block_height);
266
267
0
  RD_STATS tokenonly_rd_stats;
268
0
  int this_rate;
269
270
0
  if (do_header_rd_based_gating) {
271
0
    assert(do_header_rd_based_breakout != NULL);
272
0
    const int palette_mode_rate = intra_mode_info_cost_y(
273
0
        cpi, x, mbmi, bsize, dc_mode_cost, discount_color_cost);
274
0
    const int64_t header_rd = RDCOST(x->rdmult, palette_mode_rate, 0);
275
    // Less aggressive pruning when prune_luma_palette_size_search_level == 1.
276
0
    const int header_rd_shift =
277
0
        (cpi->sf.intra_sf.prune_luma_palette_size_search_level == 1) ? 1 : 0;
278
    // Terminate further palette_size search, if the header cost corresponding
279
    // to lower palette_size is more than *best_rd << header_rd_shift. This
280
    // logic is implemented with a right shift in the LHS to prevent a possible
281
    // overflow with the left shift in RHS.
282
0
    if ((header_rd >> header_rd_shift) > *best_rd) {
283
0
      *do_header_rd_based_breakout = true;
284
0
      return;
285
0
    }
286
0
    av1_pick_uniform_tx_size_type_yrd(cpi, x, &tokenonly_rd_stats, bsize,
287
0
                                      *best_rd);
288
0
    if (tokenonly_rd_stats.rate == INT_MAX) return;
289
0
    this_rate = tokenonly_rd_stats.rate + palette_mode_rate;
290
0
  } else {
291
0
    av1_pick_uniform_tx_size_type_yrd(cpi, x, &tokenonly_rd_stats, bsize,
292
0
                                      *best_rd);
293
0
    if (tokenonly_rd_stats.rate == INT_MAX) return;
294
0
    this_rate = tokenonly_rd_stats.rate +
295
0
                intra_mode_info_cost_y(cpi, x, mbmi, bsize, dc_mode_cost,
296
0
                                       discount_color_cost);
297
0
  }
298
299
0
  int64_t this_rd = RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
300
0
  if (!xd->lossless[mbmi->segment_id] && block_signals_txsize(mbmi->bsize)) {
301
0
    tokenonly_rd_stats.rate -= tx_size_cost(x, bsize, mbmi->tx_size);
302
0
  }
303
  // Collect mode stats for multiwinner mode processing
304
0
  const int txfm_search_done = 1;
305
0
  store_winner_mode_stats(
306
0
      &cpi->common, x, mbmi, NULL, NULL, NULL, THR_DC, color_map, bsize,
307
0
      this_rd, cpi->sf.winner_mode_sf.multi_winner_mode_type, txfm_search_done);
308
0
  if (this_rd < *best_rd) {
309
0
    *best_rd = this_rd;
310
    // Setting beat_best_rd flag because current mode rd is better than best_rd.
311
    // This flag need to be updated only for palette evaluation in key frames
312
0
    if (beat_best_rd) *beat_best_rd = 1;
313
0
    memcpy(best_palette_color_map, color_map,
314
0
           block_width * block_height * sizeof(color_map[0]));
315
0
    *best_mbmi = *mbmi;
316
0
    memcpy(blk_skip, x->txfm_search_info.blk_skip,
317
0
           sizeof(x->txfm_search_info.blk_skip[0]) * ctx->num_4x4_blk);
318
0
    av1_copy_array(tx_type_map, xd->tx_type_map, ctx->num_4x4_blk);
319
0
    if (rate) *rate = this_rate;
320
0
    if (rate_tokenonly) *rate_tokenonly = tokenonly_rd_stats.rate;
321
0
    if (distortion) *distortion = tokenonly_rd_stats.dist;
322
0
    if (skippable) *skippable = tokenonly_rd_stats.skip_txfm;
323
0
    if (beat_best_palette_rd) *beat_best_palette_rd = 1;
324
0
  }
325
0
}
326
327
0
// Returns nonzero once 'curr_idx' has reached or passed 'end_idx' in the
// direction given by the sign of 'step_size' (which must be nonzero).
static inline int is_iter_over(int curr_idx, int end_idx, int step_size) {
  assert(step_size != 0);
  if (step_size > 0) return curr_idx >= end_idx;
  return curr_idx <= end_idx;
}
331
332
// Performs count-based palette search with number of colors in interval
333
// [start_n, end_n) with step size step_size. If step_size < 0, then end_n can
334
// be less than start_n. Saves the last numbers searched in last_n_searched and
335
// returns the best number of colors found.
336
static inline int perform_top_color_palette_search(
337
    const AV1_COMP *const cpi, MACROBLOCK *x, MB_MODE_INFO *mbmi,
338
    BLOCK_SIZE bsize, int dc_mode_cost, const int16_t *data,
339
    int16_t *top_colors, int start_n, int end_n, int step_size,
340
    bool do_header_rd_based_gating, int *last_n_searched, uint16_t *color_cache,
341
    int n_cache, MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map,
342
    int64_t *best_rd, int *rate, int *rate_tokenonly, int64_t *distortion,
343
    uint8_t *skippable, int *beat_best_rd, PICK_MODE_CONTEXT *ctx,
344
0
    uint8_t *best_blk_skip, uint8_t *tx_type_map, int discount_color_cost) {
345
0
  int16_t centroids[PALETTE_MAX_SIZE];
346
0
  int n = start_n;
347
0
  int top_color_winner = end_n;
348
  /* clang-format off */
349
0
  assert(IMPLIES(step_size < 0, start_n > end_n));
350
  /* clang-format on */
351
0
  assert(IMPLIES(step_size > 0, start_n < end_n));
352
0
  while (!is_iter_over(n, end_n, step_size)) {
353
0
    int beat_best_palette_rd = 0;
354
0
    bool do_header_rd_based_breakout = false;
355
0
    memcpy(centroids, top_colors, n * sizeof(top_colors[0]));
356
0
    palette_rd_y(cpi, x, mbmi, bsize, dc_mode_cost, data, centroids, n,
357
0
                 color_cache, n_cache, do_header_rd_based_gating, best_mbmi,
358
0
                 best_palette_color_map, best_rd, rate, rate_tokenonly,
359
0
                 distortion, skippable, beat_best_rd, ctx, best_blk_skip,
360
0
                 tx_type_map, &beat_best_palette_rd,
361
0
                 &do_header_rd_based_breakout, discount_color_cost);
362
0
    *last_n_searched = n;
363
0
    if (do_header_rd_based_breakout) {
364
      // Terminate palette_size search by setting last_n_searched to end_n.
365
0
      *last_n_searched = end_n;
366
0
      break;
367
0
    }
368
0
    if (beat_best_palette_rd) {
369
0
      top_color_winner = n;
370
0
    } else if (cpi->sf.intra_sf.prune_palette_search_level == 2) {
371
      // At search level 2, we return immediately if we don't see an improvement
372
0
      return top_color_winner;
373
0
    }
374
0
    n += step_size;
375
0
  }
376
0
  return top_color_winner;
377
0
}
378
379
// Performs k-means based palette search with number of colors in interval
380
// [start_n, end_n) with step size step_size. If step_size < 0, then end_n can
381
// be less than start_n. Saves the last numbers searched in last_n_searched and
382
// returns the best number of colors found.
383
static inline int perform_k_means_palette_search(
384
    const AV1_COMP *const cpi, MACROBLOCK *x, MB_MODE_INFO *mbmi,
385
    BLOCK_SIZE bsize, int dc_mode_cost, const int16_t *data, int lower_bound,
386
    int upper_bound, int start_n, int end_n, int step_size,
387
    bool do_header_rd_based_gating, int *last_n_searched, uint16_t *color_cache,
388
    int n_cache, MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map,
389
    int64_t *best_rd, int *rate, int *rate_tokenonly, int64_t *distortion,
390
    uint8_t *skippable, int *beat_best_rd, PICK_MODE_CONTEXT *ctx,
391
    uint8_t *best_blk_skip, uint8_t *tx_type_map, uint8_t *color_map,
392
0
    int data_points, int discount_color_cost) {
393
0
  int16_t centroids[PALETTE_MAX_SIZE];
394
0
  const int max_itr = 50;
395
0
  int n = start_n;
396
0
  int top_color_winner = end_n;
397
  /* clang-format off */
398
0
  assert(IMPLIES(step_size < 0, start_n > end_n));
399
  /* clang-format on */
400
0
  assert(IMPLIES(step_size > 0, start_n < end_n));
401
0
  while (!is_iter_over(n, end_n, step_size)) {
402
0
    int beat_best_palette_rd = 0;
403
0
    bool do_header_rd_based_breakout = false;
404
0
    for (int i = 0; i < n; ++i) {
405
0
      centroids[i] =
406
0
          lower_bound + (2 * i + 1) * (upper_bound - lower_bound) / n / 2;
407
0
    }
408
0
    av1_k_means(data, centroids, color_map, data_points, n, 1, max_itr);
409
0
    palette_rd_y(cpi, x, mbmi, bsize, dc_mode_cost, data, centroids, n,
410
0
                 color_cache, n_cache, do_header_rd_based_gating, best_mbmi,
411
0
                 best_palette_color_map, best_rd, rate, rate_tokenonly,
412
0
                 distortion, skippable, beat_best_rd, ctx, best_blk_skip,
413
0
                 tx_type_map, &beat_best_palette_rd,
414
0
                 &do_header_rd_based_breakout, discount_color_cost);
415
0
    *last_n_searched = n;
416
0
    if (do_header_rd_based_breakout) {
417
      // Terminate palette_size search by setting last_n_searched to end_n.
418
0
      *last_n_searched = end_n;
419
0
      break;
420
0
    }
421
0
    if (beat_best_palette_rd) {
422
0
      top_color_winner = n;
423
0
    } else if (cpi->sf.intra_sf.prune_palette_search_level == 2) {
424
      // At search level 2, we return immediately if we don't see an improvement
425
0
      return top_color_winner;
426
0
    }
427
0
    n += step_size;
428
0
  }
429
0
  return top_color_winner;
430
0
}
431
432
// Sets the parameters to search the current number of colors +- 1
433
static inline void set_stage2_params(int *min_n, int *max_n, int *step_size,
434
0
                                     int winner, int end_n) {
435
  // Set min to winner - 1 unless we are already at the border, then we set it
436
  // to winner + 1
437
0
  *min_n = (winner == PALETTE_MIN_SIZE) ? (PALETTE_MIN_SIZE + 1)
438
0
                                        : AOMMAX(winner - 1, PALETTE_MIN_SIZE);
439
  // Set max to winner + 1 unless we are already at the border, then we set it
440
  // to winner - 1
441
0
  *max_n =
442
0
      (winner == end_n) ? (winner - 1) : AOMMIN(winner + 1, PALETTE_MAX_SIZE);
443
444
  // Set the step size to max_n - min_n so we only search those two values.
445
  // If max_n == min_n, then set step_size to 1 to avoid infinite loop later.
446
0
  *step_size = AOMMAX(1, *max_n - *min_n);
447
0
}
448
449
// Copies a 'rows' x 'cols' block of source samples into the densely packed
// data[] array (row pitch 'cols') and reports the smallest and largest sample
// seen through lower_bound / upper_bound. When is_high_bitdepth is set, 'src'
// is converted with CONVERT_TO_SHORTPTR and read as 16-bit samples;
// otherwise it is read as 8-bit samples.
static inline void fill_data_and_get_bounds(const uint8_t *src,
                                            const int src_stride,
                                            const int rows, const int cols,
                                            const int is_high_bitdepth,
                                            int16_t *data, int *lower_bound,
                                            int *upper_bound) {
  int lowest, highest;

  if (is_high_bitdepth) {
    const uint16_t *const src16 = CONVERT_TO_SHORTPTR(src);
    lowest = highest = src16[0];
    for (int r = 0; r < rows; ++r) {
      const uint16_t *const in_row = src16 + r * src_stride;
      int16_t *const out_row = data + r * cols;
      for (int c = 0; c < cols; ++c) {
        const int val = in_row[c];
        out_row[c] = (int16_t)val;
        if (val < lowest) lowest = val;
        if (val > highest) highest = val;
      }
    }
  } else {
    // low bit depth
    lowest = highest = src[0];
    for (int r = 0; r < rows; ++r) {
      const uint8_t *const in_row = src + r * src_stride;
      int16_t *const out_row = data + r * cols;
      for (int c = 0; c < cols; ++c) {
        const int val = in_row[c];
        out_row[c] = (int16_t)val;
        if (val < lowest) lowest = val;
        if (val > highest) highest = val;
      }
    }
  }

  *lower_bound = lowest;
  *upper_bound = highest;
}
484
485
/*! \brief Colors are sorted by their count: the higher the better.
 */
struct ColorCount {
  //! Color index in the histogram.
  int index;
  //! Histogram count.
  int count;
};

// qsort() comparator for struct ColorCount: entries with a higher count sort
// first; ties are broken by the lower histogram index. Returns 0 when both
// count and index match, so an element compared with itself is reported as
// equal, as qsort() requires of a consistent ordering.
static int color_count_comp(const void *c1, const void *c2) {
  const struct ColorCount *color_count1 = (const struct ColorCount *)c1;
  const struct ColorCount *color_count2 = (const struct ColorCount *)c2;
  if (color_count1->count > color_count2->count) return -1;
  if (color_count1->count < color_count2->count) return 1;
  if (color_count1->index < color_count2->index) return -1;
  if (color_count1->index > color_count2->index) return 1;
  return 0;
}
502
503
static void find_top_colors(const int *const count_buf, int bit_depth,
504
0
                            int n_colors, int16_t *top_colors) {
505
  // Top color array, serving as a priority queue if more than n_colors are
506
  // found.
507
0
  struct ColorCount top_color_counts[PALETTE_MAX_SIZE] = { { 0 } };
508
0
  int n_color_count = 0;
509
0
  for (int i = 0; i < (1 << bit_depth); ++i) {
510
0
    if (count_buf[i] > 0) {
511
0
      if (n_color_count < n_colors) {
512
        // Keep adding to the top colors.
513
0
        top_color_counts[n_color_count].index = i;
514
0
        top_color_counts[n_color_count].count = count_buf[i];
515
0
        ++n_color_count;
516
0
        if (n_color_count == n_colors) {
517
0
          qsort(top_color_counts, n_colors, sizeof(top_color_counts[0]),
518
0
                color_count_comp);
519
0
        }
520
0
      } else {
521
        // Check the worst in the sorted top.
522
0
        if (count_buf[i] > top_color_counts[n_colors - 1].count) {
523
0
          int j = n_colors - 1;
524
          // Move up to the best one.
525
0
          while (j >= 1 && count_buf[i] > top_color_counts[j - 1].count) --j;
526
0
          memmove(top_color_counts + j + 1, top_color_counts + j,
527
0
                  (n_colors - j - 1) * sizeof(top_color_counts[0]));
528
0
          top_color_counts[j].index = i;
529
0
          top_color_counts[j].count = count_buf[i];
530
0
        }
531
0
      }
532
0
    }
533
0
  }
534
0
  assert(n_color_count == n_colors);
535
536
0
  for (int i = 0; i < n_colors; ++i) {
537
0
    top_colors[i] = top_color_counts[i].index;
538
0
  }
539
0
}
540
541
void av1_rd_pick_palette_intra_sby(
542
    const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bsize, int dc_mode_cost,
543
    MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map, int64_t *best_rd,
544
    int *rate, int *rate_tokenonly, int64_t *distortion, uint8_t *skippable,
545
    int *beat_best_rd, PICK_MODE_CONTEXT *ctx, uint8_t *best_blk_skip,
546
0
    uint8_t *tx_type_map) {
547
0
  MACROBLOCKD *const xd = &x->e_mbd;
548
0
  MB_MODE_INFO *const mbmi = xd->mi[0];
549
0
  assert(!is_inter_block(mbmi));
550
0
  assert(av1_allow_palette(cpi->common.features.allow_screen_content_tools,
551
0
                           bsize));
552
0
  assert(PALETTE_MAX_SIZE == 8);
553
0
  assert(PALETTE_MIN_SIZE == 2);
554
555
0
  const int src_stride = x->plane[0].src.stride;
556
0
  const uint8_t *const src = x->plane[0].src.buf;
557
0
  int block_width, block_height, rows, cols;
558
0
  av1_get_block_dimensions(bsize, 0, xd, &block_width, &block_height, &rows,
559
0
                           &cols);
560
0
  const SequenceHeader *const seq_params = cpi->common.seq_params;
561
0
  const int is_hbd = seq_params->use_highbitdepth;
562
0
  const int bit_depth = seq_params->bit_depth;
563
0
  const int discount_color_cost = cpi->sf.rt_sf.use_nonrd_pick_mode;
564
0
  int unused;
565
566
0
  int count_buf[1 << 12];  // Maximum (1 << 12) color levels.
567
0
  int colors, colors_threshold = 0;
568
0
  if (is_hbd) {
569
0
    int count_buf_8bit[1 << 8];  // Maximum (1 << 8) bins for hbd path.
570
0
    av1_count_colors_highbd(src, src_stride, rows, cols, bit_depth, count_buf,
571
0
                            count_buf_8bit, &colors_threshold, &colors);
572
0
  } else {
573
0
    av1_count_colors(src, src_stride, rows, cols, count_buf, &colors);
574
0
    colors_threshold = colors;
575
0
  }
576
577
0
  uint8_t *const color_map = xd->plane[0].color_index_map;
578
0
  int color_thresh_palette = x->color_palette_thresh;
579
  // Allow for larger color_threshold for palette search, based on color,
580
  // scene_change, and block source variance.
581
  // Since palette is Y based, only allow larger threshold if block
582
  // color_dist is below threshold.
583
0
  if (cpi->sf.rt_sf.use_nonrd_pick_mode &&
584
0
      cpi->sf.rt_sf.increase_color_thresh_palette && cpi->rc.high_source_sad &&
585
0
      x->source_variance > 50) {
586
0
    int64_t norm_color_dist = 0;
587
0
    if (x->color_sensitivity[0] || x->color_sensitivity[1]) {
588
0
      norm_color_dist = x->min_dist_inter_uv >>
589
0
                        (mi_size_wide_log2[bsize] + mi_size_high_log2[bsize]);
590
0
      if (x->color_sensitivity[0] && x->color_sensitivity[1])
591
0
        norm_color_dist = norm_color_dist >> 1;
592
0
    }
593
0
    if (norm_color_dist < 8000) color_thresh_palette += 20;
594
0
  }
595
0
  if (colors_threshold > 1 && colors_threshold <= color_thresh_palette) {
596
0
    int16_t *const data = x->palette_buffer->kmeans_data_buf;
597
0
    int16_t centroids[PALETTE_MAX_SIZE];
598
0
    int lower_bound, upper_bound;
599
0
    fill_data_and_get_bounds(src, src_stride, rows, cols, is_hbd, data,
600
0
                             &lower_bound, &upper_bound);
601
602
0
    mbmi->mode = DC_PRED;
603
0
    mbmi->filter_intra_mode_info.use_filter_intra = 0;
604
605
0
    uint16_t color_cache[2 * PALETTE_MAX_SIZE];
606
0
    const int n_cache = av1_get_palette_cache(xd, 0, color_cache);
607
608
    // Find the dominant colors, stored in top_colors[].
609
0
    int16_t top_colors[PALETTE_MAX_SIZE] = { 0 };
610
0
    find_top_colors(count_buf, bit_depth, AOMMIN(colors, PALETTE_MAX_SIZE),
611
0
                    top_colors);
612
613
    // The following are the approaches used for header rdcost based gating
614
    // for early termination for different values of prune_palette_search_level.
615
    // 0: Pruning based on header rdcost for ascending order palette_size
616
    // search.
617
    // 1: When colors > PALETTE_MIN_SIZE, enabled only for coarse palette_size
618
    // search and for finer search do_header_rd_based_gating parameter is
619
    // explicitly passed as 'false'.
620
    // 2: Enabled only for ascending order palette_size search and for
621
    // descending order search do_header_rd_based_gating parameter is explicitly
622
    // passed as 'false'.
623
0
    const bool do_header_rd_based_gating =
624
0
        cpi->sf.intra_sf.prune_luma_palette_size_search_level != 0;
625
626
    // TODO(huisu@google.com): Try to avoid duplicate computation in cases
627
    // where the dominant colors and the k-means results are similar.
628
0
    if ((cpi->sf.intra_sf.prune_palette_search_level == 1) &&
629
0
        (colors > PALETTE_MIN_SIZE)) {
630
      // Start index and step size below are chosen to evaluate unique
631
      // candidates in neighbor search, in case a winner candidate is found in
632
      // coarse search. Example,
633
      // 1) 8 colors (end_n = 8): 2,3,4,5,6,7,8. start_n is chosen as 2 and step
634
      // size is chosen as 3. Therefore, coarse search will evaluate 2, 5 and 8.
635
      // If winner is found at 5, then 4 and 6 are evaluated. Similarly, for 2
636
      // (3) and 8 (7).
637
      // 2) 7 colors (end_n = 7): 2,3,4,5,6,7. If start_n is chosen as 2 (same
638
      // as for 8 colors) then step size should also be 2, to cover all
639
      // candidates. Coarse search will evaluate 2, 4 and 6. If winner is either
640
      // 2 or 4, 3 will be evaluated. Instead, if start_n=3 and step_size=3,
641
      // coarse search will evaluate 3 and 6. For the winner, unique neighbors
642
      // (3: 2,4 or 6: 5,7) would be evaluated.
643
644
      // Start index for coarse palette search for dominant colors and k-means
645
0
      const uint8_t start_n_lookup_table[PALETTE_MAX_SIZE + 1] = { 0, 0, 0,
646
0
                                                                   3, 3, 2,
647
0
                                                                   3, 3, 2 };
648
      // Step size for coarse palette search for dominant colors and k-means
649
0
      const uint8_t step_size_lookup_table[PALETTE_MAX_SIZE + 1] = { 0, 0, 0,
650
0
                                                                     3, 3, 3,
651
0
                                                                     3, 3, 3 };
652
653
      // Choose the start index and step size for coarse search based on number
654
      // of colors
655
0
      const int max_n = AOMMIN(colors, PALETTE_MAX_SIZE);
656
0
      const int min_n = start_n_lookup_table[max_n];
657
0
      const int step_size = step_size_lookup_table[max_n];
658
0
      assert(min_n >= PALETTE_MIN_SIZE);
659
      // Perform top color coarse palette search to find the winner candidate
660
0
      const int top_color_winner = perform_top_color_palette_search(
661
0
          cpi, x, mbmi, bsize, dc_mode_cost, data, top_colors, min_n, max_n + 1,
662
0
          step_size, do_header_rd_based_gating, &unused, color_cache, n_cache,
663
0
          best_mbmi, best_palette_color_map, best_rd, rate, rate_tokenonly,
664
0
          distortion, skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map,
665
0
          discount_color_cost);
666
      // Evaluate neighbors for the winner color (if winner is found) in the
667
      // above coarse search for dominant colors
668
0
      if (top_color_winner <= max_n) {
669
0
        int stage2_min_n, stage2_max_n, stage2_step_size;
670
0
        set_stage2_params(&stage2_min_n, &stage2_max_n, &stage2_step_size,
671
0
                          top_color_winner, max_n);
672
        // perform finer search for the winner candidate
673
0
        perform_top_color_palette_search(
674
0
            cpi, x, mbmi, bsize, dc_mode_cost, data, top_colors, stage2_min_n,
675
0
            stage2_max_n + 1, stage2_step_size,
676
            /*do_header_rd_based_gating=*/false, &unused, color_cache, n_cache,
677
0
            best_mbmi, best_palette_color_map, best_rd, rate, rate_tokenonly,
678
0
            distortion, skippable, beat_best_rd, ctx, best_blk_skip,
679
0
            tx_type_map, discount_color_cost);
680
0
      }
681
      // K-means clustering.
682
      // Perform k-means coarse palette search to find the winner candidate
683
0
      const int k_means_winner = perform_k_means_palette_search(
684
0
          cpi, x, mbmi, bsize, dc_mode_cost, data, lower_bound, upper_bound,
685
0
          min_n, max_n + 1, step_size, do_header_rd_based_gating, &unused,
686
0
          color_cache, n_cache, best_mbmi, best_palette_color_map, best_rd,
687
0
          rate, rate_tokenonly, distortion, skippable, beat_best_rd, ctx,
688
0
          best_blk_skip, tx_type_map, color_map, rows * cols,
689
0
          discount_color_cost);
690
      // Evaluate neighbors for the winner color (if winner is found) in the
691
      // above coarse search for k-means
692
0
      if (k_means_winner <= max_n) {
693
0
        int start_n_stage2, end_n_stage2, step_size_stage2;
694
0
        set_stage2_params(&start_n_stage2, &end_n_stage2, &step_size_stage2,
695
0
                          k_means_winner, max_n);
696
        // perform finer search for the winner candidate
697
0
        perform_k_means_palette_search(
698
0
            cpi, x, mbmi, bsize, dc_mode_cost, data, lower_bound, upper_bound,
699
0
            start_n_stage2, end_n_stage2 + 1, step_size_stage2,
700
            /*do_header_rd_based_gating=*/false, &unused, color_cache, n_cache,
701
0
            best_mbmi, best_palette_color_map, best_rd, rate, rate_tokenonly,
702
0
            distortion, skippable, beat_best_rd, ctx, best_blk_skip,
703
0
            tx_type_map, color_map, rows * cols, discount_color_cost);
704
0
      }
705
0
    } else {
706
0
      const int max_n = AOMMIN(colors, PALETTE_MAX_SIZE),
707
0
                min_n = PALETTE_MIN_SIZE;
708
      // Perform top color palette search in ascending order
709
0
      int last_n_searched = min_n;
710
0
      perform_top_color_palette_search(
711
0
          cpi, x, mbmi, bsize, dc_mode_cost, data, top_colors, min_n, max_n + 1,
712
0
          1, do_header_rd_based_gating, &last_n_searched, color_cache, n_cache,
713
0
          best_mbmi, best_palette_color_map, best_rd, rate, rate_tokenonly,
714
0
          distortion, skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map,
715
0
          discount_color_cost);
716
0
      if (last_n_searched < max_n) {
717
        // Search in descending order until we get to the previous best
718
0
        perform_top_color_palette_search(
719
0
            cpi, x, mbmi, bsize, dc_mode_cost, data, top_colors, max_n,
720
0
            last_n_searched, -1, /*do_header_rd_based_gating=*/false, &unused,
721
0
            color_cache, n_cache, best_mbmi, best_palette_color_map, best_rd,
722
0
            rate, rate_tokenonly, distortion, skippable, beat_best_rd, ctx,
723
0
            best_blk_skip, tx_type_map, discount_color_cost);
724
0
      }
725
      // K-means clustering.
726
0
      if (colors == PALETTE_MIN_SIZE) {
727
        // Special case: These colors automatically become the centroids.
728
0
        assert(colors == 2);
729
0
        centroids[0] = lower_bound;
730
0
        centroids[1] = upper_bound;
731
0
        palette_rd_y(cpi, x, mbmi, bsize, dc_mode_cost, data, centroids, colors,
732
0
                     color_cache, n_cache, /*do_header_rd_based_gating=*/false,
733
0
                     best_mbmi, best_palette_color_map, best_rd, rate,
734
0
                     rate_tokenonly, distortion, skippable, beat_best_rd, ctx,
735
0
                     best_blk_skip, tx_type_map, NULL, NULL,
736
0
                     discount_color_cost);
737
0
      } else {
738
        // Perform k-means palette search in ascending order
739
0
        last_n_searched = min_n;
740
0
        perform_k_means_palette_search(
741
0
            cpi, x, mbmi, bsize, dc_mode_cost, data, lower_bound, upper_bound,
742
0
            min_n, max_n + 1, 1, do_header_rd_based_gating, &last_n_searched,
743
0
            color_cache, n_cache, best_mbmi, best_palette_color_map, best_rd,
744
0
            rate, rate_tokenonly, distortion, skippable, beat_best_rd, ctx,
745
0
            best_blk_skip, tx_type_map, color_map, rows * cols,
746
0
            discount_color_cost);
747
0
        if (last_n_searched < max_n) {
748
          // Search in descending order until we get to the previous best
749
0
          perform_k_means_palette_search(
750
0
              cpi, x, mbmi, bsize, dc_mode_cost, data, lower_bound, upper_bound,
751
0
              max_n, last_n_searched, -1, /*do_header_rd_based_gating=*/false,
752
0
              &unused, color_cache, n_cache, best_mbmi, best_palette_color_map,
753
0
              best_rd, rate, rate_tokenonly, distortion, skippable,
754
0
              beat_best_rd, ctx, best_blk_skip, tx_type_map, color_map,
755
0
              rows * cols, discount_color_cost);
756
0
        }
757
0
      }
758
0
    }
759
0
  }
760
761
0
  if (best_mbmi->palette_mode_info.palette_size[0] > 0) {
762
0
    memcpy(color_map, best_palette_color_map,
763
0
           block_width * block_height * sizeof(best_palette_color_map[0]));
764
    // Gather the stats to determine whether to use screen content tools in
765
    // function av1_determine_sc_tools_with_encoding().
766
0
    x->palette_pixels += (block_width * block_height);
767
0
  }
768
0
  *mbmi = *best_mbmi;
769
0
}
770
771
void av1_rd_pick_palette_intra_sbuv(const AV1_COMP *cpi, MACROBLOCK *x,
772
                                    int dc_mode_cost,
773
                                    uint8_t *best_palette_color_map,
774
                                    MB_MODE_INFO *const best_mbmi,
775
                                    int64_t *best_rd, int *rate,
776
                                    int *rate_tokenonly, int64_t *distortion,
777
0
                                    uint8_t *skippable) {
778
0
  MACROBLOCKD *const xd = &x->e_mbd;
779
0
  MB_MODE_INFO *const mbmi = xd->mi[0];
780
0
  assert(!is_inter_block(mbmi));
781
0
  assert(av1_allow_palette(cpi->common.features.allow_screen_content_tools,
782
0
                           mbmi->bsize));
783
0
  PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
784
0
  const BLOCK_SIZE bsize = mbmi->bsize;
785
0
  const SequenceHeader *const seq_params = cpi->common.seq_params;
786
0
  int this_rate;
787
0
  int64_t this_rd;
788
0
  int colors_u, colors_v;
789
0
  int colors_threshold_u = 0, colors_threshold_v = 0, colors_threshold = 0;
790
0
  const int src_stride = x->plane[1].src.stride;
791
0
  const uint8_t *const src_u = x->plane[1].src.buf;
792
0
  const uint8_t *const src_v = x->plane[2].src.buf;
793
0
  uint8_t *const color_map = xd->plane[1].color_index_map;
794
0
  RD_STATS tokenonly_rd_stats;
795
0
  int plane_block_width, plane_block_height, rows, cols;
796
0
  av1_get_block_dimensions(bsize, 1, xd, &plane_block_width,
797
0
                           &plane_block_height, &rows, &cols);
798
799
0
  mbmi->uv_mode = UV_DC_PRED;
800
0
  if (seq_params->use_highbitdepth) {
801
0
    int count_buf[1 << 12];      // Maximum (1 << 12) color levels.
802
0
    int count_buf_8bit[1 << 8];  // Maximum (1 << 8) bins for hbd path.
803
0
    av1_count_colors_highbd(src_u, src_stride, rows, cols,
804
0
                            seq_params->bit_depth, count_buf, count_buf_8bit,
805
0
                            &colors_threshold_u, &colors_u);
806
0
    av1_count_colors_highbd(src_v, src_stride, rows, cols,
807
0
                            seq_params->bit_depth, count_buf, count_buf_8bit,
808
0
                            &colors_threshold_v, &colors_v);
809
0
  } else {
810
0
    int count_buf[1 << 8];
811
0
    av1_count_colors(src_u, src_stride, rows, cols, count_buf, &colors_u);
812
0
    av1_count_colors(src_v, src_stride, rows, cols, count_buf, &colors_v);
813
0
    colors_threshold_u = colors_u;
814
0
    colors_threshold_v = colors_v;
815
0
  }
816
817
0
  uint16_t color_cache[2 * PALETTE_MAX_SIZE];
818
0
  const int n_cache = av1_get_palette_cache(xd, 1, color_cache);
819
820
0
  colors_threshold = colors_threshold_u > colors_threshold_v
821
0
                         ? colors_threshold_u
822
0
                         : colors_threshold_v;
823
0
  if (colors_threshold > 1 && colors_threshold <= 64) {
824
0
    int r, c, n, i, j;
825
0
    const int max_itr = 50;
826
0
    int lb_u, ub_u, val_u;
827
0
    int lb_v, ub_v, val_v;
828
0
    int16_t *const data = x->palette_buffer->kmeans_data_buf;
829
0
    int16_t centroids[2 * PALETTE_MAX_SIZE];
830
831
0
    uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src_u);
832
0
    uint16_t *src_v16 = CONVERT_TO_SHORTPTR(src_v);
833
0
    if (seq_params->use_highbitdepth) {
834
0
      lb_u = src_u16[0];
835
0
      ub_u = src_u16[0];
836
0
      lb_v = src_v16[0];
837
0
      ub_v = src_v16[0];
838
0
    } else {
839
0
      lb_u = src_u[0];
840
0
      ub_u = src_u[0];
841
0
      lb_v = src_v[0];
842
0
      ub_v = src_v[0];
843
0
    }
844
845
0
    for (r = 0; r < rows; ++r) {
846
0
      for (c = 0; c < cols; ++c) {
847
0
        if (seq_params->use_highbitdepth) {
848
0
          val_u = src_u16[r * src_stride + c];
849
0
          val_v = src_v16[r * src_stride + c];
850
0
          data[(r * cols + c) * 2] = val_u;
851
0
          data[(r * cols + c) * 2 + 1] = val_v;
852
0
        } else {
853
0
          val_u = src_u[r * src_stride + c];
854
0
          val_v = src_v[r * src_stride + c];
855
0
          data[(r * cols + c) * 2] = val_u;
856
0
          data[(r * cols + c) * 2 + 1] = val_v;
857
0
        }
858
0
        if (val_u < lb_u)
859
0
          lb_u = val_u;
860
0
        else if (val_u > ub_u)
861
0
          ub_u = val_u;
862
0
        if (val_v < lb_v)
863
0
          lb_v = val_v;
864
0
        else if (val_v > ub_v)
865
0
          ub_v = val_v;
866
0
      }
867
0
    }
868
869
0
    const int colors = colors_u > colors_v ? colors_u : colors_v;
870
0
    const int max_colors =
871
0
        colors > PALETTE_MAX_SIZE ? PALETTE_MAX_SIZE : colors;
872
0
    for (n = PALETTE_MIN_SIZE; n <= max_colors; ++n) {
873
0
      for (i = 0; i < n; ++i) {
874
0
        centroids[i * 2] = lb_u + (2 * i + 1) * (ub_u - lb_u) / n / 2;
875
0
        centroids[i * 2 + 1] = lb_v + (2 * i + 1) * (ub_v - lb_v) / n / 2;
876
0
      }
877
0
      av1_k_means(data, centroids, color_map, rows * cols, n, 2, max_itr);
878
0
      optimize_palette_colors(color_cache, n_cache, n, 2, centroids,
879
0
                              cpi->common.seq_params->bit_depth);
880
      // Sort the U channel colors in ascending order.
881
0
      for (i = 0; i < 2 * (n - 1); i += 2) {
882
0
        int min_idx = i;
883
0
        int min_val = centroids[i];
884
0
        for (j = i + 2; j < 2 * n; j += 2)
885
0
          if (centroids[j] < min_val) min_val = centroids[j], min_idx = j;
886
0
        if (min_idx != i) {
887
0
          int temp_u = centroids[i], temp_v = centroids[i + 1];
888
0
          centroids[i] = centroids[min_idx];
889
0
          centroids[i + 1] = centroids[min_idx + 1];
890
0
          centroids[min_idx] = temp_u, centroids[min_idx + 1] = temp_v;
891
0
        }
892
0
      }
893
0
      av1_calc_indices(data, centroids, color_map, rows * cols, n, 2);
894
0
      extend_palette_color_map(color_map, cols, rows, plane_block_width,
895
0
                               plane_block_height);
896
0
      pmi->palette_size[1] = n;
897
0
      for (i = 1; i < 3; ++i) {
898
0
        for (j = 0; j < n; ++j) {
899
0
          if (seq_params->use_highbitdepth)
900
0
            pmi->palette_colors[i * PALETTE_MAX_SIZE + j] = clip_pixel_highbd(
901
0
                (int)centroids[j * 2 + i - 1], seq_params->bit_depth);
902
0
          else
903
0
            pmi->palette_colors[i * PALETTE_MAX_SIZE + j] =
904
0
                clip_pixel((int)centroids[j * 2 + i - 1]);
905
0
        }
906
0
      }
907
908
0
      if (cpi->sf.intra_sf.early_term_chroma_palette_size_search) {
909
0
        const int palette_mode_rate =
910
0
            intra_mode_info_cost_uv(cpi, x, mbmi, bsize, dc_mode_cost);
911
0
        const int64_t header_rd = RDCOST(x->rdmult, palette_mode_rate, 0);
912
        // Terminate further palette_size search, if header cost corresponding
913
        // to lower palette_size is more than the best_rd.
914
0
        if (header_rd >= *best_rd) break;
915
0
        av1_txfm_uvrd(cpi, x, &tokenonly_rd_stats, bsize, *best_rd);
916
0
        if (tokenonly_rd_stats.rate == INT_MAX) continue;
917
0
        this_rate = tokenonly_rd_stats.rate + palette_mode_rate;
918
0
      } else {
919
0
        av1_txfm_uvrd(cpi, x, &tokenonly_rd_stats, bsize, *best_rd);
920
0
        if (tokenonly_rd_stats.rate == INT_MAX) continue;
921
0
        this_rate = tokenonly_rd_stats.rate +
922
0
                    intra_mode_info_cost_uv(cpi, x, mbmi, bsize, dc_mode_cost);
923
0
      }
924
925
0
      this_rd = RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
926
0
      if (this_rd < *best_rd) {
927
0
        *best_rd = this_rd;
928
0
        *best_mbmi = *mbmi;
929
0
        memcpy(best_palette_color_map, color_map,
930
0
               plane_block_width * plane_block_height *
931
0
                   sizeof(best_palette_color_map[0]));
932
0
        *rate = this_rate;
933
0
        *distortion = tokenonly_rd_stats.dist;
934
0
        *rate_tokenonly = tokenonly_rd_stats.rate;
935
0
        *skippable = tokenonly_rd_stats.skip_txfm;
936
0
      }
937
0
    }
938
0
  }
939
0
  if (best_mbmi->palette_mode_info.palette_size[1] > 0) {
940
0
    memcpy(color_map, best_palette_color_map,
941
0
           plane_block_width * plane_block_height *
942
0
               sizeof(best_palette_color_map[0]));
943
0
  }
944
0
}
945
946
0
/* Rebuilds the chroma color index map of the current block from the chroma
 * palette already stored in its mode info.
 *
 * The U and V source samples are interleaved into the k-means data buffer,
 * the stored U/V palette colors are loaded as 2-D centroids, each sample is
 * assigned an index via av1_calc_indices(), and the resulting map is extended
 * to the full plane block dimensions.
 */
void av1_restore_uv_color_map(const AV1_COMP *cpi, MACROBLOCK *x) {
  MACROBLOCKD *const xd = &x->e_mbd;
  const MB_MODE_INFO *const mbmi = xd->mi[0];
  const PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
  const int is_hbd = cpi->common.seq_params->use_highbitdepth;
  const int stride = x->plane[1].src.stride;
  const uint8_t *const u8 = x->plane[1].src.buf;
  const uint8_t *const v8 = x->plane[2].src.buf;
  const uint16_t *const u16 = CONVERT_TO_SHORTPTR(u8);
  const uint16_t *const v16 = CONVERT_TO_SHORTPTR(v8);
  int16_t *const samples = x->palette_buffer->kmeans_data_buf;
  uint8_t *const index_map = xd->plane[1].color_index_map;
  int16_t centroids[2 * PALETTE_MAX_SIZE];

  int plane_block_width, plane_block_height, rows, cols;
  av1_get_block_dimensions(mbmi->bsize, 1, xd, &plane_block_width,
                           &plane_block_height, &rows, &cols);

  // Interleave the source samples as (U, V) pairs.
  for (int row = 0; row < rows; ++row) {
    for (int col = 0; col < cols; ++col) {
      const int src_pos = row * stride + col;
      const int dst_pos = (row * cols + col) * 2;
      samples[dst_pos] = is_hbd ? u16[src_pos] : u8[src_pos];
      samples[dst_pos + 1] = is_hbd ? v16[src_pos] : v8[src_pos];
    }
  }

  // Load the stored palette as (U, V) centroids: U colors occupy the second
  // PALETTE_MAX_SIZE-sized segment of palette_colors, V colors the third.
  const int num_colors = pmi->palette_size[1];
  for (int k = 0; k < num_colors; ++k) {
    centroids[2 * k] = pmi->palette_colors[PALETTE_MAX_SIZE + k];
    centroids[2 * k + 1] = pmi->palette_colors[2 * PALETTE_MAX_SIZE + k];
  }

  av1_calc_indices(samples, centroids, index_map, rows * cols, num_colors, 2);
  extend_palette_color_map(index_map, cols, rows, plane_block_width,
                           plane_block_height);
}