Coverage Report

Created: 2026-05-16 06:41

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/work/svt-av1/Source/Lib/Codec/palette.c
Line
Count
Source
1
/*
2
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
3
 *
4
 * This source code is subject to the terms of the BSD 2 Clause License and
5
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6
 * was not distributed with this source code in the LICENSE file, you can
7
 * obtain it at https://www.aomedia.org/license/software-license. If the Alliance for Open
8
 * Media Patent License 1.0 was not distributed with this source code in the
9
 * PATENTS file, you can obtain it at https://www.aomedia.org/license/patent-license.
10
 */
11
12
#include <math.h>
13
#include <stdlib.h>
14
#include "definitions.h"
15
#include "md_process.h"
16
#include "aom_dsp_rtcd.h"
17
18
#define DIVIDE_AND_ROUND(x, y) (((x) + ((y) >> 1)) / (y))
19
20
0
#define AV1_K_MEANS_RENAME(func, dim) func##_dim##dim##_c
21
22
void AV1_K_MEANS_RENAME(svt_av1_calc_indices, 1)(const int* data, const int* centroids, uint8_t* indices, int n, int k);
23
void AV1_K_MEANS_RENAME(svt_av1_calc_indices, 2)(const int* data, const int* centroids, uint8_t* indices, int n, int k);
24
void AV1_K_MEANS_RENAME(svt_av1_k_means, 1)(const int* data, int* centroids, uint8_t* indices, int n, int k,
25
                                            int max_itr);
26
void AV1_K_MEANS_RENAME(svt_av1_k_means, 2)(const int* data, int* centroids, uint8_t* indices, int n, int k,
27
                                            int max_itr);
28
29
// Assigns each of the 'n' points in 'data' to one of the 'k' 'centroids' and
// writes the chosen centroid index for every point into 'indices'.
// 'dim' is the dimensionality of a point; only 1 and 2 have an implementation,
// any other value trips the assert and leaves 'indices' untouched.
static inline void av1_calc_indices(const int* data, const int* centroids, uint8_t* indices, int n, int k, int dim) {
    switch (dim) {
    case 1:
        svt_av1_calc_indices_dim1(data, centroids, indices, n, k);
        break;
    case 2:
        svt_av1_calc_indices_dim2(data, centroids, indices, n, k);
        break;
    default:
        assert(0 && "Untemplated k means dimension");
        break;
    }
}
40
41
// Runs at most 'max_itr' k-means iterations over the 'n' points in 'data',
// starting from the caller-supplied guess of 'k' 'centroids'. On return
// 'centroids' holds the updated centroids and 'indices' the centroid index
// chosen for every point.
// 'dim' is the dimensionality of a point; only 1 and 2 have an implementation,
// any other value trips the assert and does nothing else.
// Note: the output centroids are rounded off to nearest integers.
static inline void av1_k_means(const int* data, int* centroids, uint8_t* indices, int n, int k, int dim, int max_itr) {
    switch (dim) {
    case 1:
        svt_av1_k_means_dim1(data, centroids, indices, n, k, max_itr);
        break;
    case 2:
        svt_av1_k_means_dim2(data, centroids, indices, n, k, max_itr);
        break;
    default:
        assert(0 && "Untemplated k means dimension");
        break;
    }
}
54
55
0
#define AV1_K_MEANS_DIM 1
56
#include "k_means_template.h"
57
#undef AV1_K_MEANS_DIM
58
0
#define AV1_K_MEANS_DIM 2
59
#include "k_means_template.h"
60
#undef AV1_K_MEANS_DIM
61
62
// qsort() comparator ordering two ints ascending.
// Returns -1, 0 or 1. The result is built from comparisons rather than from
// the difference '*a - *b', because that subtraction overflows (undefined
// behaviour, and in practice the wrong sign) when the operands are far apart,
// e.g. INT_MIN versus a positive value.
static int int_comparer(const void* a, const void* b) {
    const int lhs = *(const int*)a;
    const int rhs = *(const int*)b;
    return (lhs > rhs) - (lhs < rhs);
}
65
66
0
static int av1_remove_duplicates(int* centroids, int num_centroids) {
67
0
    int num_unique; // number of unique centroids
68
0
    int i;
69
0
    qsort(centroids, num_centroids, sizeof(*centroids), int_comparer);
70
    // Remove duplicates.
71
0
    num_unique = 1;
72
0
    for (i = 1; i < num_centroids; ++i) {
73
0
        if (centroids[i] != centroids[i - 1]) { // found a new unique centroid
74
0
            centroids[num_unique++] = centroids[i];
75
0
        }
76
0
    }
77
0
    return num_unique;
78
0
}
79
80
static int delta_encode_cost(const int* colors, int num, int bit_depth, int min_val) {
81
    if (num <= 0) {
82
        return 0;
83
    }
84
    int bits_cost = bit_depth;
85
    if (num == 1) {
86
        return bits_cost;
87
    }
88
    bits_cost += 2;
89
    int       max_delta = 0;
90
    int       deltas[PALETTE_MAX_SIZE];
91
    const int min_bits = bit_depth - 3;
92
    for (int i = 1; i < num; ++i) {
93
        const int delta = colors[i] - colors[i - 1];
94
        deltas[i - 1]   = delta;
95
        assert(delta >= min_val);
96
        if (delta > max_delta) {
97
            max_delta = delta;
98
        }
99
    }
100
    int bits_per_delta = AOMMAX(av1_ceil_log2(max_delta + 1 - min_val), min_bits);
101
    assert(bits_per_delta <= bit_depth);
102
    int range = (1 << bit_depth) - colors[0] - min_val;
103
    for (int i = 0; i < num - 1; ++i) {
104
        bits_cost += bits_per_delta;
105
        range -= deltas[i];
106
        bits_per_delta = AOMMIN(bits_per_delta, av1_ceil_log2(range));
107
    }
108
    return bits_cost;
109
}
110
111
int svt_av1_index_color_cache(const uint16_t* color_cache, int n_cache, const uint16_t* colors, int n_colors,
112
0
                              uint8_t* cache_color_found, int* out_cache_colors) {
113
0
    if (n_cache <= 0) {
114
0
        for (int i = 0; i < n_colors; ++i) {
115
0
            out_cache_colors[i] = colors[i];
116
0
        }
117
0
        return n_colors;
118
0
    }
119
0
    memset(cache_color_found, 0, n_cache * sizeof(*cache_color_found));
120
0
    int n_in_cache = 0;
121
0
    int in_cache_flags[PALETTE_MAX_SIZE];
122
0
    memset(in_cache_flags, 0, sizeof(in_cache_flags));
123
0
    for (int i = 0; i < n_cache && n_in_cache < n_colors; ++i) {
124
0
        for (int j = 0; j < n_colors; ++j) {
125
0
            if (colors[j] == color_cache[i]) {
126
0
                in_cache_flags[j]    = 1;
127
0
                cache_color_found[i] = 1;
128
0
                ++n_in_cache;
129
0
                break;
130
0
            }
131
0
        }
132
0
    }
133
0
    int j = 0;
134
0
    for (int i = 0; i < n_colors; ++i) {
135
0
        if (!in_cache_flags[i]) {
136
0
            out_cache_colors[j++] = colors[i];
137
0
        }
138
0
    }
139
0
    assert(j == n_colors - n_in_cache);
140
0
    return j;
141
0
}
142
143
int svt_av1_palette_color_cost_y(const PaletteModeInfo* const pmi, uint16_t* color_cache, const int palette_size,
144
0
                                 int n_cache, int bit_depth) {
145
0
    const int n = palette_size;
146
0
    int       out_cache_colors[PALETTE_MAX_SIZE];
147
0
    uint8_t   cache_color_found[2 * PALETTE_MAX_SIZE];
148
0
    const int n_out_cache = svt_av1_index_color_cache(
149
0
        color_cache, n_cache, pmi->palette_colors, n, cache_color_found, out_cache_colors);
150
0
    const int total_bits = n_cache + delta_encode_cost(out_cache_colors, n_out_cache, bit_depth, 1);
151
0
    return av1_cost_literal(total_bits);
152
0
}
153
154
0
// Appends 'val' to 'cache' (which currently holds '*n' entries) and bumps '*n',
// unless 'val' equals the most recently stored entry, in which case nothing
// changes. Only the last entry is compared, not the whole cache.
static void palette_add_to_cache(uint16_t* cache, int* n, uint16_t val) {
    const int count        = *n;
    const int repeats_last = count > 0 && cache[count - 1] == val;
    if (!repeats_last) {
        cache[count] = val;
        *n           = count + 1;
    }
}
162
163
// Get palette cache for luma only
164
0
int svt_get_palette_cache_y(const MacroBlockD* const xd, uint16_t* cache) {
165
0
    const int row = -xd->mb_to_top_edge >> 3;
166
    // Do not refer to above SB row when on SB boundary.
167
0
    const MbModeInfo* const above_mi = (row % (1 << MIN_SB_SIZE_LOG2)) ? xd->above_mbmi : NULL;
168
0
    const MbModeInfo* const left_mi  = xd->left_mbmi;
169
0
    int                     above_n = 0, left_n = 0;
170
0
    if (above_mi) {
171
0
        above_n = above_mi->palette_mode_info.palette_size;
172
0
    }
173
0
    if (left_mi) {
174
0
        left_n = left_mi->palette_mode_info.palette_size;
175
0
    }
176
0
    if (above_n == 0 && left_n == 0) {
177
0
        return 0;
178
0
    }
179
0
    int             above_idx    = 0;
180
0
    int             left_idx     = 0;
181
0
    int             n            = 0;
182
0
    const uint16_t* above_colors = above_mi ? above_mi->palette_mode_info.palette_colors : NULL;
183
0
    const uint16_t* left_colors  = left_mi ? left_mi->palette_mode_info.palette_colors : NULL;
184
    // Merge the sorted lists of base colors from above and left to get
185
    // combined sorted color cache.
186
0
    while (above_n > 0 && left_n > 0) {
187
0
        uint16_t v_above = above_colors[above_idx];
188
0
        uint16_t v_left  = left_colors[left_idx];
189
0
        if (v_left < v_above) {
190
0
            palette_add_to_cache(cache, &n, v_left);
191
0
            ++left_idx, --left_n;
192
0
        } else {
193
0
            palette_add_to_cache(cache, &n, v_above);
194
0
            ++above_idx, --above_n;
195
0
            if (v_left == v_above) {
196
0
                ++left_idx, --left_n;
197
0
            }
198
0
        }
199
0
    }
200
0
    while (above_n-- > 0) {
201
0
        uint16_t val = above_colors[above_idx++];
202
0
        palette_add_to_cache(cache, &n, val);
203
0
    }
204
0
    while (left_n-- > 0) {
205
0
        uint16_t val = left_colors[left_idx++];
206
0
        palette_add_to_cache(cache, &n, val);
207
0
    }
208
0
    assert(n <= 2 * PALETTE_MAX_SIZE);
209
0
    return n;
210
0
}
211
212
// Returns sub-sampled dimensions of the given block.
213
// The output values for 'rows_within_bounds' and 'cols_within_bounds' will
214
// differ from 'height' and 'width' when part of the block is outside the
215
// right
216
// and/or bottom image boundary.
217
void svt_aom_get_block_dimensions(BlockSize bsize, int plane, const MacroBlockD* xd, int* width, int* height,
218
0
                                  int* rows_within_bounds, int* cols_within_bounds) {
219
0
    const int block_height = block_size_high[bsize];
220
0
    const int block_width  = block_size_wide[bsize];
221
0
    const int block_rows   = (xd->mb_to_bottom_edge >= 0) ? block_height : (xd->mb_to_bottom_edge >> 3) + block_height;
222
0
    const int block_cols   = (xd->mb_to_right_edge >= 0) ? block_width : (xd->mb_to_right_edge >> 3) + block_width;
223
224
0
    uint8_t subsampling_x = plane == 0 ? 0 : 1;
225
0
    uint8_t subsampling_y = plane == 0 ? 0 : 1;
226
227
0
    assert(block_width >= block_cols);
228
0
    assert(block_height >= block_rows);
229
0
    const int plane_block_width  = block_width >> subsampling_x;
230
0
    const int plane_block_height = block_height >> subsampling_y;
231
    // Special handling for chroma sub8x8.
232
0
    const int is_chroma_sub8_x = plane > 0 && plane_block_width < 4;
233
0
    const int is_chroma_sub8_y = plane > 0 && plane_block_height < 4;
234
0
    if (width) {
235
0
        *width = plane_block_width + 2 * is_chroma_sub8_x;
236
0
    }
237
0
    if (height) {
238
0
        *height = plane_block_height + 2 * is_chroma_sub8_y;
239
0
    }
240
0
    if (rows_within_bounds) {
241
0
        *rows_within_bounds = (block_rows >> subsampling_y) + 2 * is_chroma_sub8_y;
242
0
    }
243
0
    if (cols_within_bounds) {
244
0
        *cols_within_bounds = (block_cols >> subsampling_x) + 2 * is_chroma_sub8_x;
245
0
    }
246
0
}
247
248
// Bias toward using colors in the cache.
249
// TODO: Try other schemes to improve compression.
250
static AOM_INLINE void optimize_palette_colors(uint16_t* color_cache, int n_cache, int n_colors, int stride,
251
                                               int* centroids, int bit_depth, uint8_t qp_index) {
252
    if (n_cache <= 0) {
253
        return;
254
    }
255
    for (int i = 0; i < n_colors * stride; i += stride) {
256
        int min_diff = abs((int)centroids[i] - (int)color_cache[0]);
257
        int idx      = 0;
258
        for (int j = 1; j < n_cache; ++j) {
259
            const int this_diff = abs((int)centroids[i] - (int)color_cache[j]);
260
            if (this_diff < min_diff) {
261
                min_diff = this_diff;
262
                idx      = j;
263
            }
264
        }
265
        const int min_threshold = (6 + (qp_index >> 6)) << (bit_depth - 8);
266
        if (min_diff <= min_threshold) {
267
            centroids[i] = color_cache[idx];
268
        }
269
    }
270
}
271
272
// Extends 'color_map' array from 'orig_width x orig_height' to 'new_width x
273
// new_height'. Extra rows and columns are filled in by copying last valid
274
// row/column.
275
static AOM_INLINE void extend_palette_color_map(uint8_t* const color_map, int orig_width, int orig_height,
276
                                                int new_width, int new_height) {
277
    int j;
278
    assert(new_width >= orig_width);
279
    assert(new_height >= orig_height);
280
    if (new_width == orig_width && new_height == orig_height) {
281
        return;
282
    }
283
284
    for (j = orig_height - 1; j >= 0; --j) {
285
        memmove(color_map + j * new_width, color_map + j * orig_width, orig_width);
286
        // Copy last column to extra columns.
287
        memset(
288
            color_map + j * new_width + orig_width, color_map[j * new_width + orig_width - 1], new_width - orig_width);
289
    }
290
    // Copy last row to extra rows.
291
    for (j = orig_height; j < new_height; ++j) {
292
        svt_memcpy(color_map + j * new_width, color_map + (orig_height - 1) * new_width, new_width);
293
    }
294
}
295
296
// Turns the 'n' luma 'centroids' into a palette candidate.
//  1. When 'opt_colors' is set, centroids close to an entry of 'color_cache'
//     are snapped to that entry.
//  2. 'centroids' is sorted and de-duplicated in place. If fewer than
//     PALETTE_MIN_SIZE unique values remain, palette_size_array[0] is set to 0
//     and the candidate is abandoned.
//  3. Otherwise the clipped centroids are stored as palette colors,
//     palette_size_array[0] receives their count, and the color index map is
//     computed for the in-picture part of the block and then extended to the
//     full block size.
static void palette_rd_y(PaletteInfo* palette_info, uint8_t* palette_size_array, ModeDecisionContext* ctx,
                         bool opt_colors, BlockSize bsize, const int* data, int* centroids, int n,
                         uint16_t* color_cache, int n_cache, int bit_depth) {
    if (opt_colors) {
        optimize_palette_colors(color_cache, n_cache, n, 1, centroids, bit_depth, ctx->qp_index);
    }

    const int k = av1_remove_duplicates(centroids, n);
    if (k < PALETTE_MIN_SIZE) {
        // Too few unique colors to create a palette. And DC_PRED will work
        // well for that case anyway. So skip.
        palette_size_array[0] = 0;
        return;
    }

    // Clip to the valid sample range of the working bit depth.
    const bool high_bit_depth = bit_depth > EB_EIGHT_BIT;
    for (int i = 0; i < k; ++i) {
        if (high_bit_depth) {
            palette_info->pmi.palette_colors[i] = clip_pixel_highbd((int)centroids[i], bit_depth);
        } else {
            palette_info->pmi.palette_colors[i] = clip_pixel(centroids[i]);
        }
    }
    palette_size_array[0] = k;

    int block_width, block_height, rows, cols;
    svt_aom_get_block_dimensions(bsize, 0, ctx->blk_ptr->av1xd, &block_width, &block_height, &rows, &cols);

    uint8_t* const color_map = palette_info->color_idx_map;
    av1_calc_indices(data, centroids, color_map, rows * cols, k, 1);
    extend_palette_color_map(color_map, cols, rows, block_width, block_height);
}
326
327
int svt_av1_count_colors(const uint8_t* src, int stride, int rows, int cols, int* val_count);
328
int svt_av1_count_colors_highbd(uint16_t* src, int stride, int rows, int cols, int bit_depth, int* val_count);
329
330
static void cache_based_centroid_refinement(int* data, int rows, int cols, int n, int* centroids,
331
0
                                            uint8_t* color_idx_map, uint16_t* color_cache, int n_cache) {
332
0
    const int total = rows * cols;
333
0
    uint8_t   temp_map[MAX_SB_SQUARE];
334
335
    // Compute baseline SSE
336
0
    uint64_t baseline_sse = 0;
337
0
    for (int i = 0; i < total; i++) {
338
0
        int diff = data[i] - centroids[color_idx_map[i]];
339
0
        baseline_sse += (uint64_t)diff * diff;
340
0
    }
341
342
0
    for (int c = 0; c < n; c++) {
343
0
        int      original = centroids[c];
344
0
        int      best_val = original;
345
0
        uint64_t best_sse = baseline_sse;
346
347
0
        for (int k = 0; k < n_cache; k++) {
348
0
            int candidate = color_cache[k];
349
0
            if (candidate == original) {
350
0
                continue;
351
0
            }
352
353
0
            centroids[c] = candidate;
354
355
            // Reassign pixels
356
0
            for (int i = 0; i < total; i++) {
357
0
                int best_idx  = 0;
358
0
                int best_dist = abs(data[i] - centroids[0]);
359
360
0
                for (int j = 1; j < n; j++) {
361
0
                    int dist = abs(data[i] - centroids[j]);
362
0
                    if (dist < best_dist) {
363
0
                        best_dist = dist;
364
0
                        best_idx  = j;
365
0
                    }
366
0
                }
367
0
                temp_map[i] = best_idx;
368
0
            }
369
370
            // Compute SSE
371
0
            uint64_t sse = 0;
372
0
            for (int i = 0; i < total; i++) {
373
0
                int diff = data[i] - centroids[temp_map[i]];
374
0
                sse += (uint64_t)diff * diff;
375
0
            }
376
377
0
            if (sse < best_sse) {
378
0
                best_sse = sse;
379
0
                best_val = candidate;
380
0
                memcpy(color_idx_map, temp_map, total);
381
0
            }
382
0
        }
383
384
0
        centroids[c] = best_val;
385
0
    }
386
0
}
387
388
// Picks the 'n_top' most frequent values out of the histogram 'count_buf'
// ('n_bins' entries) and stores them, most frequent first, in 'top_colors'.
// The histogram is consumed: the count of every picked value is reset to 0.
// When no remaining bin has a positive count, top_colors[rank] keeps whatever
// value the caller initialised it with.
static void palette_pick_dominant_colors(int* count_buf, int n_bins, int n_top, int* top_colors) {
    for (int rank = 0; rank < n_top; ++rank) {
        int best_count = 0;
        for (int value = 0; value < n_bins; ++value) {
            if (count_buf[value] > best_count) {
                best_count       = count_buf[value];
                top_colors[rank] = value;
            }
        }
        // Zero the winner so the next pass finds the next most frequent value.
        count_buf[top_colors[rank]] = 0;
    }
}

// Generates luma palette candidates for the current block.
// The source samples of the in-picture part of the block are counted; blocks
// with at most 1 or more than 64 distinct values produce no candidate.
// Otherwise up to two searches run, each stepping the palette size down from
// min(colors, PALETTE_MAX_SIZE) to PALETTE_MIN_SIZE:
//  - dominant-color search: the centroids are the most frequent sample values;
//  - k-means search: the centroids are spread evenly between the smallest and
//    largest sample and refined by k-means (for exactly PALETTE_MIN_SIZE colors
//    the two extremes are used directly), optionally followed by cache-based
//    centroid refinement.
// A search is disabled when its step control equals (uint8_t)~0.
// Each accepted candidate is written to palette_cand / palette_size_array at
// index *tot_palette_cands, which is then incremented.
void search_palette_luma(PictureControlSet* pcs, ModeDecisionContext* ctx, PaletteInfo* palette_cand,
                         uint8_t* palette_size_array, uint32_t* tot_palette_cands) {
    const bool is16bit = ctx->hbd_md > 0;

    EbPictureBufferDesc* src_pic = is16bit ? pcs->input_frame16bit : pcs->ppcs->enhanced_pic;

    const int      src_stride        = src_pic->y_stride;
    const unsigned palette_bit_depth = is16bit ? EB_TEN_BIT : EB_EIGHT_BIT;

    // Top-left luma sample of the block (byte address; doubled for 16-bit).
    const uint8_t* const src = src_pic->y_buffer +
        (((ctx->blk_org_x) + (ctx->blk_org_y) * src_pic->y_stride) << is16bit);

    int block_width, block_height, rows, cols;
    svt_aom_get_block_dimensions(
        ctx->blk_geom->bsize, 0, ctx->blk_ptr->av1xd, &block_width, &block_height, &rows, &cols);

    int            count_buf[1 << 12];
    const unsigned bit_depth = pcs->ppcs->scs->encoder_bit_depth;

    int colors;
    if (!is16bit) {
        colors = svt_av1_count_colors(src, src_stride, rows, cols, count_buf);
    } else {
        colors = svt_av1_count_colors_highbd((uint16_t*)src, src_stride, rows, cols, bit_depth, count_buf);
    }

    if (colors <= 1 || colors > 64) {
        return;
    }

    const int max_n = AOMMIN(colors, PALETTE_MAX_SIZE);
    const int min_n = PALETTE_MIN_SIZE;

    int* data = ctx->palette_buffer->kmeans_data_buf;
    int  centroids[PALETTE_MAX_SIZE];

    // Copy the samples into the k-means buffer (packed, 'cols' per row) while
    // tracking the smallest (lb) and largest (ub) value.
    int lb = is16bit ? ((uint16_t*)src)[0] : ((uint8_t*)src)[0];
    int ub = lb;
    for (int r = 0; r < rows; ++r) {
        int* const packed_row = data + r * cols;
        for (int c = 0; c < cols; ++c) {
            const int offset = r * src_stride + c;
            const int val    = is16bit ? ((uint16_t*)src)[offset] : ((uint8_t*)src)[offset];

            packed_row[c] = val;
            if (val < lb) {
                lb = val;
            }
            if (val > ub) {
                ub = val;
            }
        }
    }

#if !OPT_SC_STILL_IMAGE
    uint16_t  color_cache[2 * PALETTE_MAX_SIZE];
    const int n_cache = svt_get_palette_cache_y(ctx->blk_ptr->av1xd, color_cache);

    //  Extract dominant colors
    int top_colors[PALETTE_MAX_SIZE] = {0};
    palette_pick_dominant_colors(count_buf, 1 << palette_bit_depth, max_n, top_colors);
#endif

    //  Dominant-color search
    const uint8_t dominant_color_step = pcs->ppcs->palette_ctrls.dominant_color_step;
    if (dominant_color_step != (uint8_t)~0) {
#if OPT_SC_STILL_IMAGE
        //  Extract dominant colors
        int top_colors[PALETTE_MAX_SIZE] = {0};
        palette_pick_dominant_colors(count_buf, 1 << palette_bit_depth, max_n, top_colors);
#endif
        for (int n = max_n; n >= min_n; n -= dominant_color_step) {
            for (int i = 0; i != n; ++i) {
                centroids[i] = top_colors[i];
            }

            const uint32_t cand_index = *tot_palette_cands;

            palette_rd_y(&palette_cand[cand_index],
                         &palette_size_array[cand_index],
                         ctx,
                         false,
                         ctx->blk_geom->bsize,
                         data,
                         centroids,
                         n,
#if OPT_SC_STILL_IMAGE
                         NULL,
                         0,
#else
                         color_cache,
                         n_cache,
#endif
                         palette_bit_depth);

            // Keep the slot only when a usable palette was produced.
            if (palette_size_array[cand_index] >= PALETTE_MIN_SIZE) {
                *tot_palette_cands += 1;
            }
        }
    }

    // K-means search
    const uint8_t kmean_color_step = pcs->ppcs->palette_ctrls.kmean_color_step;
    if (kmean_color_step != (uint8_t)~0) {
#if OPT_SC_STILL_IMAGE
        uint16_t  color_cache[2 * PALETTE_MAX_SIZE];
        const int n_cache = svt_get_palette_cache_y(ctx->blk_ptr->av1xd, color_cache);
#endif
        for (int n = max_n; n >= min_n; n -= kmean_color_step) {
            // The candidate slot does not change until the end of the iteration.
            const uint32_t cand_index = *tot_palette_cands;
            uint8_t* const idx_map    = palette_cand[cand_index].color_idx_map;

            if (colors == PALETTE_MIN_SIZE) {
                centroids[0] = lb;
                centroids[1] = ub;
            } else {
                // Initial guess: n values spread evenly across [lb, ub].
                for (int i = 0; i < n; ++i) {
                    centroids[i] = lb + (2 * i + 1) * (ub - lb) / n / 2;
                }
#if OPT_SC_STILL_IMAGE
                av1_k_means(data, centroids, idx_map, rows * cols, n, 1, pcs->ppcs->palette_ctrls.k_means_max_itr);
#else
                av1_k_means(data, centroids, idx_map, rows * cols, n, 1, 50);
#endif
            }

            if (pcs->ppcs->palette_ctrls.centroid_refinement) {
                cache_based_centroid_refinement(data, rows, cols, n, centroids, idx_map, color_cache, n_cache);
            }

            palette_rd_y(&palette_cand[cand_index],
                         &palette_size_array[cand_index],
                         ctx,
                         true,
                         ctx->blk_geom->bsize,
                         data,
                         centroids,
                         n,
                         color_cache,
                         n_cache,
                         palette_bit_depth);

            // Keep the slot only when a usable palette was produced.
            if (palette_size_array[cand_index] >= PALETTE_MIN_SIZE) {
                *tot_palette_cands += 1;
            }
        }
    }
}
562
563
typedef AomCdfProb (*MapCdf)[PALETTE_COLOR_INDEX_CONTEXTS][CDF_SIZE(PALETTE_COLORS)];
564
typedef const int (*ColorCost)[PALETTE_SIZES][PALETTE_COLOR_INDEX_CONTEXTS][PALETTE_COLORS];
565
566
// Bundle of values describing one plane's palette color-index map. It is
// zeroed and then filled by get_color_map_params() (entropy-coding side) or
// get_color_map_params_rate() (rate-estimation side).
typedef struct {
567
    int       rows; // rows of the block inside the picture (rows_within_bounds of svt_aom_get_block_dimensions)
568
    int       cols; // columns of the block inside the picture (cols_within_bounds of svt_aom_get_block_dimensions)
569
    int       n_colors; // palette size of the plane
570
    int       plane_width; // full plane width of the block (width output of svt_aom_get_block_dimensions)
571
    uint8_t*  color_map; // the block's color_idx_map
572
    MapCdf    map_cdf; // palette color-index CDFs (y or uv) from the frame context; NULL on the rate-estimation side
573
    ColorCost color_cost; // luma color-index cost table from the rate table; NULL for chroma and on the entropy-coding side
574
} Av1ColorMapParam;
575
576
static void get_palette_params(FRAME_CONTEXT* frame_context, EcBlkStruct* blk_ptr, int plane, BlockSize bsize,
577
0
                               Av1ColorMapParam* params) {
578
0
    const MacroBlockD* const xd = blk_ptr->av1xd;
579
0
    params->color_map           = blk_ptr->palette_info->color_idx_map;
580
0
    params->map_cdf    = plane ? frame_context->palette_uv_color_index_cdf : frame_context->palette_y_color_index_cdf;
581
0
    params->color_cost = NULL;
582
0
    params->n_colors   = blk_ptr->palette_size[plane];
583
0
    svt_aom_get_block_dimensions(bsize, plane, xd, &params->plane_width, NULL, &params->rows, &params->cols);
584
0
}
585
586
static void get_color_map_params(FRAME_CONTEXT* frame_context, EcBlkStruct* blk_ptr, int plane, BlockSize bsize,
587
0
                                 TxSize tx_size, COLOR_MAP_TYPE type, Av1ColorMapParam* params) {
588
0
    (void)tx_size;
589
0
    memset(params, 0, sizeof(*params));
590
0
    switch (type) {
591
0
    case PALETTE_MAP:
592
0
        get_palette_params(frame_context, blk_ptr, plane, bsize, params);
593
0
        break;
594
0
    default:
595
0
        assert(0 && "Invalid color map type");
596
0
        return;
597
0
    }
598
0
}
599
600
static void get_palette_params_rate(ModeDecisionCandidate* cand, MdRateEstimationContext* rate_table,
601
0
                                    BlkStruct* blk_ptr, int plane, BlockSize bsize, Av1ColorMapParam* params) {
602
0
    PaletteInfo* palette_info = cand->palette_info;
603
604
0
    const MacroBlockD* const xd = blk_ptr->av1xd;
605
0
    params->color_map           = palette_info->color_idx_map;
606
0
    params->map_cdf             = NULL;
607
0
    params->color_cost          = plane ? NULL : (ColorCost)&rate_table->palette_ycolor_fac_bitss;
608
0
    params->n_colors            = cand->palette_size[plane];
609
610
0
    svt_aom_get_block_dimensions(bsize, plane, xd, &params->plane_width, NULL, &params->rows, &params->cols);
611
0
}
612
613
static void get_color_map_params_rate(ModeDecisionCandidate* cand, MdRateEstimationContext* rate_table,
614
                                      /*const MACROBLOCK *const x*/ BlkStruct* blk_ptr, int plane, BlockSize bsize,
615
0
                                      COLOR_MAP_TYPE type, Av1ColorMapParam* params) {
616
0
    memset(params, 0, sizeof(*params));
617
0
    switch (type) {
618
0
    case PALETTE_MAP:
619
0
        get_palette_params_rate(cand, rate_table, blk_ptr, plane, bsize, params);
620
0
        break;
621
0
    default:
622
0
        assert(0 && "Invalid color map type");
623
0
        return;
624
0
    }
625
0
}
626
627
#define SWAP(i, j)                               \
628
0
    do {                                         \
629
0
        const uint8_t tmp_score = score_rank[i]; \
630
0
        const uint8_t tmp_color = color_rank[i]; \
631
0
        score_rank[i]           = score_rank[j]; \
632
0
        color_rank[i]           = color_rank[j]; \
633
0
        score_rank[j]           = tmp_score;     \
634
0
        color_rank[j]           = tmp_color;     \
635
0
    } while (0)
636
637
#define MAX_COLOR_CONTEXT_HASH 8
638
// Negative values are invalid
639
int svt_aom_palette_color_index_context_lookup[MAX_COLOR_CONTEXT_HASH + 1] = {-1, -1, 0, -1, -1, 4, 3, 2, 1};
640
0
#define NUM_PALETTE_NEIGHBORS 3 // left, top-left and top.
641
0
#define INVALID_COLOR_IDX (UINT8_MAX)
642
643
static inline int av1_fast_palette_color_index_context_on_edge(const uint8_t* color_map, int stride, int r, int c,
644
0
                                                               int* color_idx) {
645
0
    const bool has_left  = (c - 1 >= 0);
646
0
    const bool has_above = (r - 1 >= 0);
647
0
    assert(r > 0 || c > 0);
648
0
    assert(has_above ^ has_left);
649
0
    assert(color_idx);
650
0
    (void)has_left;
651
652
0
    const uint8_t color_neighbor = has_above ? color_map[(r - 1) * stride + (c - 0)]
653
0
                                             : color_map[(r - 0) * stride + (c - 1)];
654
    // If the neighbor color has higher index than current color index, then we
655
    // move up by 1.
656
0
    const uint8_t current_color = *color_idx = color_map[r * stride + c];
657
0
    if (color_neighbor > current_color) {
658
0
        (*color_idx)++;
659
0
    } else if (color_neighbor == current_color) {
660
0
        *color_idx = 0;
661
0
    }
662
663
    // Get hash value of context.
664
    // The non-diagonal neighbors get a weight of 2.
665
0
    const uint8_t color_score          = 2;
666
0
    const uint8_t hash_multiplier      = 1;
667
0
    const uint8_t color_index_ctx_hash = color_score * hash_multiplier;
668
669
    // Lookup context from hash.
670
0
    const int color_index_ctx = svt_aom_palette_color_index_context_lookup[color_index_ctx_hash];
671
0
    assert(color_index_ctx == 0);
672
0
    (void)color_index_ctx;
673
0
    return 0;
674
0
}
675
676
// A faster version of av1_get_palette_color_index_context used by the encoder
677
// exploiting the fact that the encoder does not need to maintain a color order.
678
static inline int av1_fast_palette_color_index_context(const uint8_t* color_map, int stride, int r, int c,
679
0
                                                       int* color_idx) {
680
0
    assert(r > 0 || c > 0);
681
682
0
    const bool has_above = (r - 1 >= 0);
683
0
    const bool has_left  = (c - 1 >= 0);
684
0
    assert(has_above || has_left);
685
0
    if (has_above ^ has_left) {
686
0
        return av1_fast_palette_color_index_context_on_edge(color_map, stride, r, c, color_idx);
687
0
    }
688
689
    // This goes in the order of left, top, and top-left. This has the advantage
690
    // that unless anything here are not distinct or invalid, this will already
691
    // be in sorted order. Furthermore, if either of the first two is
692
    // invalid, we know the last one is also invalid.
693
0
    uint8_t color_neighbors[NUM_PALETTE_NEIGHBORS];
694
0
    color_neighbors[0] = color_map[(r - 0) * stride + (c - 1)];
695
0
    color_neighbors[1] = color_map[(r - 1) * stride + (c - 0)];
696
0
    color_neighbors[2] = color_map[(r - 1) * stride + (c - 1)];
697
698
    // Aggregate duplicated values.
699
    // Since our array is so small, using a couple if statements is faster
700
0
    uint8_t scores[NUM_PALETTE_NEIGHBORS] = {2, 2, 1};
701
0
    uint8_t num_invalid_colors            = 0;
702
0
    if (color_neighbors[0] == color_neighbors[1]) {
703
0
        scores[0] += scores[1];
704
0
        color_neighbors[1] = INVALID_COLOR_IDX;
705
0
        num_invalid_colors += 1;
706
707
0
        if (color_neighbors[0] == color_neighbors[2]) {
708
0
            scores[0] += scores[2];
709
0
            num_invalid_colors += 1;
710
0
        }
711
0
    } else if (color_neighbors[0] == color_neighbors[2]) {
712
0
        scores[0] += scores[2];
713
0
        num_invalid_colors += 1;
714
0
    } else if (color_neighbors[1] == color_neighbors[2]) {
715
0
        scores[1] += scores[2];
716
0
        num_invalid_colors += 1;
717
0
    }
718
719
0
    const uint8_t num_valid_colors = NUM_PALETTE_NEIGHBORS - num_invalid_colors;
720
721
0
    uint8_t* color_rank = color_neighbors;
722
0
    uint8_t* score_rank = scores;
723
724
    // Sort everything
725
0
    if (num_valid_colors > 1) {
726
0
        if (color_neighbors[1] == INVALID_COLOR_IDX) {
727
0
            scores[1]          = scores[2];
728
0
            color_neighbors[1] = color_neighbors[2];
729
0
        }
730
731
        // We need to swap the first two elements if they have the same score but
732
        // the color indices are not in the right order
733
0
        if (score_rank[0] < score_rank[1] || (score_rank[0] == score_rank[1] && color_rank[0] > color_rank[1])) {
734
0
            SWAP(0, 1);
735
0
        }
736
0
        if (num_valid_colors > 2) {
737
0
            if (score_rank[0] < score_rank[2]) {
738
0
                SWAP(0, 2);
739
0
            }
740
0
            if (score_rank[1] < score_rank[2]) {
741
0
                SWAP(1, 2);
742
0
            }
743
0
        }
744
0
    }
745
746
    // If any of the neighbor colors has higher index than current color index,
747
    // then we move up by 1 unless the current color is the same as one of the
748
    // neighbors.
749
0
    const uint8_t current_color = *color_idx = color_map[r * stride + c];
750
0
    for (int idx = 0; idx < num_valid_colors; idx++) {
751
0
        if (color_rank[idx] > current_color) {
752
0
            (*color_idx)++;
753
0
        } else if (color_rank[idx] == current_color) {
754
0
            *color_idx = idx;
755
0
            break;
756
0
        }
757
0
    }
758
759
    // Get hash value of context.
760
0
    uint8_t              color_index_ctx_hash                    = 0;
761
0
    static const uint8_t hash_multipliers[NUM_PALETTE_NEIGHBORS] = {1, 2, 2};
762
0
    for (int idx = 0; idx < num_valid_colors; ++idx) {
763
0
        color_index_ctx_hash += score_rank[idx] * hash_multipliers[idx];
764
0
    }
765
0
    assert(color_index_ctx_hash > 0);
766
0
    assert(color_index_ctx_hash <= MAX_COLOR_CONTEXT_HASH);
767
768
    // Lookup context from hash.
769
0
    const int color_index_ctx = 9 - color_index_ctx_hash;
770
0
    assert(color_index_ctx == svt_aom_palette_color_index_context_lookup[color_index_ctx_hash]);
771
0
    assert(color_index_ctx >= 0);
772
0
    assert(color_index_ctx < PALETTE_COLOR_INDEX_CONTEXTS);
773
0
    return color_index_ctx;
774
0
}
775
776
#undef INVALID_COLOR_IDX
777
#undef SWAP
778
779
// Visits every color map entry except (0, 0) along anti-diagonals and either
// accumulates its rate (calc_rate != 0) or emits one token per entry through
// *t (calc_rate == 0), optionally updating the CDFs held in param->map_cdf.
//
// param            : color map, dimensions, palette size, cost table and CDFs
// t                : token write cursor; only dereferenced when calc_rate == 0
// plane            : only referenced when CONFIG_ENTROPY_STATS is enabled
// calc_rate        : non-zero to compute the rate instead of tokenizing
// allow_update_cdf : non-zero to call update_cdf() after each emitted token
// map_pb_cdf       : CDF table whose rows are attached to the emitted tokens
//
// Returns the accumulated rate; 0 when tokenizing.
static int cost_and_tokenize_map(Av1ColorMapParam* param, TOKENEXTRA** t, int plane, int calc_rate,
                                 int allow_update_cdf, MapCdf map_pb_cdf) {
    const uint8_t* const color_map        = param->color_map;
    MapCdf               map_cdf          = param->map_cdf;
    ColorCost            color_cost       = param->color_cost;
    const int            map_stride       = param->plane_width;
    const int            rows             = param->rows;
    const int            cols             = param->cols;
    const int            n                = param->n_colors;
    const int            palette_size_idx = n - PALETTE_MIN_SIZE;
    const int            num_diagonals    = rows + cols - 1;
    int                  total_rate       = 0;

    (void)plane;

    // Diagonal 0 holds only (0, 0), which is not handled here.
    for (int diag = 1; diag < num_diagonals; ++diag) {
        const int first_col = AOMMIN(diag, cols - 1);
        const int last_col  = AOMMAX(0, diag - rows + 1);
        for (int col = first_col; col >= last_col; --col) {
            const int row = diag - col;
            int       new_idx;
            const int ctx = av1_fast_palette_color_index_context(color_map, map_stride, row, col, &new_idx);
            assert(new_idx >= 0 && new_idx < n);

            if (calc_rate) {
                total_rate += (*color_cost)[palette_size_idx][ctx][new_idx];
                continue;
            }

            // Tokenizing path: fill the current token, then advance the cursor.
            TOKENEXTRA* const tok = *t;
            tok->token            = new_idx;
            tok->color_map_cdf    = map_pb_cdf[palette_size_idx][ctx];
            *t                    = tok + 1;
            if (allow_update_cdf) {
                update_cdf(map_cdf[palette_size_idx][ctx], new_idx, n);
            }
#if CONFIG_ENTROPY_STATS
            // NOTE(review): `counts` is not declared in this function - confirm
            // it is in scope before enabling CONFIG_ENTROPY_STATS.
            if (plane) {
                ++counts->palette_uv_color_index[palette_size_idx][ctx][new_idx];
            } else {
                ++counts->palette_y_color_index[palette_size_idx][ctx][new_idx];
            }
#endif
        }
    }
    return total_rate;
}
821
822
void svt_av1_tokenize_color_map(FRAME_CONTEXT* frame_context, EcBlkStruct* blk_ptr, int plane, TOKENEXTRA** t,
823
0
                                BlockSize bsize, TxSize tx_size, COLOR_MAP_TYPE type, int allow_update_cdf) {
824
0
    assert(plane == 0 || plane == 1);
825
0
    Av1ColorMapParam color_map_params;
826
0
    get_color_map_params(frame_context, blk_ptr, plane, bsize, tx_size, type, &color_map_params);
827
    // The first color index does not use context or entropy.
828
0
    (*t)->token         = color_map_params.color_map[0];
829
0
    (*t)->color_map_cdf = NULL;
830
0
    ++(*t);
831
0
    MapCdf map_pb_cdf = plane ? frame_context->palette_uv_color_index_cdf : frame_context->palette_y_color_index_cdf;
832
0
    cost_and_tokenize_map(&color_map_params, t, plane, 0, allow_update_cdf, map_pb_cdf);
833
0
}
834
835
// Returns the rate of coding a block's palette color map, as accumulated by
// cost_and_tokenize_map() in rate-calculation mode.
//
// cand, rate_table, blk_ptr, bsize, type : forwarded to get_color_map_params_rate()
// plane : 0 or 1
int svt_av1_cost_color_map(ModeDecisionCandidate* cand, MdRateEstimationContext* rate_table, BlkStruct* blk_ptr,
                           int plane, BlockSize bsize, COLOR_MAP_TYPE type) {
    assert(plane == 0 || plane == 1);

    Av1ColorMapParam params;
    get_color_map_params_rate(cand, rate_table, blk_ptr, plane, bsize, type, &params);

    // Rate mode (calc_rate = 1) never touches the token cursor or the
    // per-token CDF table, so both are passed as NULL and CDF updates are off.
    MapCdf no_pb_cdf = NULL;
    const int rate   = cost_and_tokenize_map(&params, NULL, plane, 1, 0, no_pb_cdf);
    return rate;
}