Coverage Report

Created: 2025-06-22 08:04

Source file: /src/aom/av1/encoder/tx_search.c

Every executable line of this file recorded an execution count of 0 in this
capture, i.e. the file is entirely uncovered; the annotated source listing
follows.

/*
 * Copyright (c) 2020, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include "av1/common/cfl.h"
#include "av1/common/reconintra.h"
#include "av1/encoder/block.h"
#include "av1/encoder/hybrid_fwd_txfm.h"
#include "av1/common/idct.h"
#include "av1/encoder/model_rd.h"
#include "av1/encoder/random.h"
#include "av1/encoder/rdopt_utils.h"
#include "av1/encoder/sorting_network.h"
#include "av1/encoder/tx_prune_model_weights.h"
#include "av1/encoder/tx_search.h"
#include "av1/encoder/txb_rdopt.h"

#define PROB_THRESH_OFFSET_TX_TYPE 100

struct rdcost_block_args {
  const AV1_COMP *cpi;
  MACROBLOCK *x;
  ENTROPY_CONTEXT t_above[MAX_MIB_SIZE];
  ENTROPY_CONTEXT t_left[MAX_MIB_SIZE];
  RD_STATS rd_stats;
  int64_t current_rd;
  int64_t best_rd;
  int exit_early;
  int incomplete_exit;
  FAST_TX_SEARCH_MODE ftxs_mode;
  int skip_trellis;
};

typedef struct {
  int64_t rd;
  int txb_entropy_ctx;
  TX_TYPE tx_type;
} TxCandidateInfo;

// origin_threshold * 128 / 100
static const uint32_t skip_pred_threshold[3][BLOCK_SIZES_ALL] = {
  {
      64, 64, 64, 70, 60, 60, 68, 68, 68, 68, 68,
      68, 68, 68, 68, 68, 64, 64, 70, 70, 68, 68,
  },
  {
      88, 88, 88, 86, 87, 87, 68, 68, 68, 68, 68,
      68, 68, 68, 68, 68, 88, 88, 86, 86, 68, 68,
  },
  {
      90, 93, 93, 90, 93, 93, 74, 74, 74, 74, 74,
      74, 74, 74, 74, 74, 90, 90, 90, 90, 74, 74,
  },
};
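
// Worked example of the scaling noted above (illustrative numbers): an
// original threshold of 50 becomes 50 * 128 / 100 = 64, the first entry of
// the first row. predict_skip_txfm() below multiplies these entries by the
// DC/AC quantizer steps to form its per-coefficient skip thresholds.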

// lookup table for predict_skip_txfm
// int max_tx_size = max_txsize_rect_lookup[bsize];
// if (tx_size_high[max_tx_size] > 16 || tx_size_wide[max_tx_size] > 16)
//   max_tx_size = AOMMIN(max_txsize_lookup[bsize], TX_16X16);
static const TX_SIZE max_predict_sf_tx_size[BLOCK_SIZES_ALL] = {
  TX_4X4,   TX_4X8,   TX_8X4,   TX_8X8,   TX_8X16,  TX_16X8,
  TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16,
  TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_4X16,  TX_16X4,
  TX_8X8,   TX_8X8,   TX_16X16, TX_16X16,
};

// look-up table for sqrt of number of pixels in a transform block
// rounded up to the nearest integer.
static const int sqrt_tx_pixels_2d[TX_SIZES_ALL] = { 4,  8,  16, 32, 32, 6,  6,
                                                     12, 12, 23, 23, 32, 32, 8,
                                                     8,  16, 16, 23, 23 };
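
// Example of the rounding above: TX_8X16 covers 8 * 16 = 128 pixels, and
// sqrt(128) = 11.31..., which rounds up to the table entry 12.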

static inline uint32_t get_block_residue_hash(MACROBLOCK *x, BLOCK_SIZE bsize) {
  const int rows = block_size_high[bsize];
  const int cols = block_size_wide[bsize];
  const int16_t *diff = x->plane[0].src_diff;
  const uint32_t hash =
      av1_get_crc32c_value(&x->txfm_search_info.mb_rd_record->crc_calculator,
                           (uint8_t *)diff, 2 * rows * cols);
  return (hash << 5) + bsize;
}
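
// A minimal sketch of the key construction above, assuming only a generic
// 32-bit hash h and a block-size enum bs (combine_hash_with_bsize is an
// illustrative helper, not an encoder function): the residue CRC is shifted
// left by 5 bits and the block size is packed into the low bits (the 22
// BLOCK_SIZES_ALL values fit in 5 bits), so two blocks with identical
// residue but different dimensions map to different keys.
static inline uint32_t combine_hash_with_bsize(uint32_t h, uint32_t bs) {
  return (h << 5) + bs;  // low 5 bits carry the block size
}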

static inline int32_t find_mb_rd_info(const MB_RD_RECORD *const mb_rd_record,
                                      const int64_t ref_best_rd,
                                      const uint32_t hash) {
  int32_t match_index = -1;
  if (ref_best_rd != INT64_MAX) {
    for (int i = 0; i < mb_rd_record->num; ++i) {
      const int index = (mb_rd_record->index_start + i) % RD_RECORD_BUFFER_LEN;
      // If there is a match in the mb_rd_record, fetch the RD decision and
      // terminate early.
      if (mb_rd_record->mb_rd_info[index].hash_value == hash) {
        match_index = index;
        break;
      }
    }
  }
  return match_index;
}

static inline void fetch_mb_rd_info(int n4, const MB_RD_INFO *const mb_rd_info,
                                    RD_STATS *const rd_stats,
                                    MACROBLOCK *const x) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  mbmi->tx_size = mb_rd_info->tx_size;
  memcpy(x->txfm_search_info.blk_skip, mb_rd_info->blk_skip,
         sizeof(mb_rd_info->blk_skip[0]) * n4);
  av1_copy(mbmi->inter_tx_size, mb_rd_info->inter_tx_size);
  av1_copy_array(xd->tx_type_map, mb_rd_info->tx_type_map, n4);
  *rd_stats = mb_rd_info->rd_stats;
}

int64_t av1_pixel_diff_dist(const MACROBLOCK *x, int plane, int blk_row,
                            int blk_col, const BLOCK_SIZE plane_bsize,
                            const BLOCK_SIZE tx_bsize,
                            unsigned int *block_mse_q8) {
  int visible_rows, visible_cols;
  const MACROBLOCKD *xd = &x->e_mbd;
  get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize, NULL,
                     NULL, &visible_cols, &visible_rows);
  const int diff_stride = block_size_wide[plane_bsize];
  const int16_t *diff = x->plane[plane].src_diff;

  diff += ((blk_row * diff_stride + blk_col) << MI_SIZE_LOG2);
  uint64_t sse =
      aom_sum_squares_2d_i16(diff, diff_stride, visible_cols, visible_rows);
  if (block_mse_q8 != NULL) {
    if (visible_cols > 0 && visible_rows > 0)
      *block_mse_q8 =
          (unsigned int)((256 * sse) / (visible_cols * visible_rows));
    else
      *block_mse_q8 = UINT_MAX;
  }
  return sse;
}
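
// Worked example of the Q8 value above (illustrative numbers): a fully
// visible 16x16 block with sse = 4096 yields
// block_mse_q8 = 256 * 4096 / 256 = 4096, i.e. a mean squared error of 16.0
// expressed in 1/256 (Q8) units.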

// Computes the residual block's SSE and mean on all visible 4x4s in the
// transform block
static inline int64_t pixel_diff_stats(
    MACROBLOCK *x, int plane, int blk_row, int blk_col,
    const BLOCK_SIZE plane_bsize, const BLOCK_SIZE tx_bsize,
    unsigned int *block_mse_q8, int64_t *per_px_mean, uint64_t *block_var) {
  int visible_rows, visible_cols;
  const MACROBLOCKD *xd = &x->e_mbd;
  get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize, NULL,
                     NULL, &visible_cols, &visible_rows);
  const int diff_stride = block_size_wide[plane_bsize];
  const int16_t *diff = x->plane[plane].src_diff;

  diff += ((blk_row * diff_stride + blk_col) << MI_SIZE_LOG2);
  uint64_t sse = 0;
  int sum = 0;
  sse = aom_sum_sse_2d_i16(diff, diff_stride, visible_cols, visible_rows, &sum);
  if (visible_cols > 0 && visible_rows > 0) {
    double norm_factor = 1.0 / (visible_cols * visible_rows);
    int sign_sum = sum > 0 ? 1 : -1;
    // Conversion to transform domain
    *per_px_mean = (int64_t)(norm_factor * abs(sum)) << 7;
    *per_px_mean = sign_sum * (*per_px_mean);
    *block_mse_q8 = (unsigned int)(norm_factor * (256 * sse));
    *block_var = (uint64_t)(sse - (uint64_t)(norm_factor * sum * sum));
  } else {
    *block_mse_q8 = UINT_MAX;
  }
  return sse;
}

// Uses simple features on top of DCT coefficients to quickly predict
// whether optimal RD decision is to skip encoding the residual.
// The sse value is stored in dist.
static int predict_skip_txfm(MACROBLOCK *x, BLOCK_SIZE bsize, int64_t *dist,
                             int reduced_tx_set) {
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  const MACROBLOCKD *xd = &x->e_mbd;
  const int16_t dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd);

  *dist = av1_pixel_diff_dist(x, 0, 0, 0, bsize, bsize, NULL);

  const int64_t mse = *dist / bw / bh;
  // Normalized quantizer takes the transform upscaling factor (8 for tx size
  // smaller than 32) into account.
  const int16_t normalized_dc_q = dc_q >> 3;
  const int64_t mse_thresh = (int64_t)normalized_dc_q * normalized_dc_q / 8;
  // For faster early skip decision, use dist to compare against threshold so
  // that quality risk is less for the skip=1 decision. Otherwise, use mse
  // since the fwd_txfm coeff checks will take care of quality
  // TODO(any): Use dist to return 0 when skip_txfm_level is 1
  int64_t pred_err = (txfm_params->skip_txfm_level >= 2) ? *dist : mse;
  // Predict not to skip when error is larger than threshold.
  if (pred_err > mse_thresh) return 0;
  // Return as skip otherwise for aggressive early skip
  else if (txfm_params->skip_txfm_level >= 2)
    return 1;

  const int max_tx_size = max_predict_sf_tx_size[bsize];
  const int tx_h = tx_size_high[max_tx_size];
  const int tx_w = tx_size_wide[max_tx_size];
  DECLARE_ALIGNED(32, tran_low_t, coefs[32 * 32]);
  TxfmParam param;
  param.tx_type = DCT_DCT;
  param.tx_size = max_tx_size;
  param.bd = xd->bd;
  param.is_hbd = is_cur_buf_hbd(xd);
  param.lossless = 0;
  param.tx_set_type = av1_get_ext_tx_set_type(
      param.tx_size, is_inter_block(xd->mi[0]), reduced_tx_set);
  const int bd_idx = (xd->bd == 8) ? 0 : ((xd->bd == 10) ? 1 : 2);
  const uint32_t max_qcoef_thresh = skip_pred_threshold[bd_idx][bsize];
  const int16_t *src_diff = x->plane[0].src_diff;
  const int n_coeff = tx_w * tx_h;
  const int16_t ac_q = av1_ac_quant_QTX(x->qindex, 0, xd->bd);
  const uint32_t dc_thresh = max_qcoef_thresh * dc_q;
  const uint32_t ac_thresh = max_qcoef_thresh * ac_q;
  for (int row = 0; row < bh; row += tx_h) {
    for (int col = 0; col < bw; col += tx_w) {
      av1_fwd_txfm(src_diff + col, coefs, bw, &param);
      // Operating on TX domain, not pixels; we want the QTX quantizers
      const uint32_t dc_coef = (((uint32_t)abs(coefs[0])) << 7);
      if (dc_coef >= dc_thresh) return 0;
      for (int i = 1; i < n_coeff; ++i) {
        const uint32_t ac_coef = (((uint32_t)abs(coefs[i])) << 7);
        if (ac_coef >= ac_thresh) return 0;
      }
    }
    src_diff += tx_h * bw;
  }
  return 1;
}
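
// Worked example of the early-skip threshold above (illustrative numbers):
// with dc_q = 1024 the normalized quantizer is 1024 >> 3 = 128, giving
// mse_thresh = 128 * 128 / 8 = 2048. A block whose per-pixel MSE (or total
// distortion, when skip_txfm_level >= 2) exceeds 2048 is predicted not to
// skip and proceeds to the DCT coefficient checks or the full search.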

// Used to set proper context for early termination with skip = 1.
static inline void set_skip_txfm(MACROBLOCK *x, RD_STATS *rd_stats,
                                 BLOCK_SIZE bsize, int64_t dist) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const int n4 = bsize_to_num_blk(bsize);
  const TX_SIZE tx_size = max_txsize_rect_lookup[bsize];
  memset(xd->tx_type_map, DCT_DCT, sizeof(xd->tx_type_map[0]) * n4);
  memset(mbmi->inter_tx_size, tx_size, sizeof(mbmi->inter_tx_size));
  mbmi->tx_size = tx_size;
  for (int i = 0; i < n4; ++i)
    set_blk_skip(x->txfm_search_info.blk_skip, 0, i, 1);
  rd_stats->skip_txfm = 1;
  if (is_cur_buf_hbd(xd)) dist = ROUND_POWER_OF_TWO(dist, (xd->bd - 8) * 2);
  rd_stats->dist = rd_stats->sse = (dist << 4);
  // Though the decision is to mark the block as skip based on luma stats,
  // the block may become non-skip after the chroma rd. In addition, the
  // intermediate non-skip costs calculated by the caller function will be
  // incorrect if the rate is set to zero (i.e., if zero_blk_rate is not
  // accounted for). Hence an intermediate rate is populated here to code the
  // luma tx blocks as skip; the caller function sets the final rate based on
  // the final rd decision (i.e., skip vs non-skip). The rate populated here
  // corresponds to coding all the tx blocks in the current block with
  // zero_blk_rate (based on the maximum possible tx size). E.g., for a
  // 128x128 block, the rate would be 4 * zero_blk_rate, where zero_blk_rate
  // corresponds to coding one 64x64 tx block as 'all zeros'.
  ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
  ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
  av1_get_entropy_contexts(bsize, &xd->plane[0], ctxa, ctxl);
  ENTROPY_CONTEXT *ta = ctxa;
  ENTROPY_CONTEXT *tl = ctxl;
  const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
  TXB_CTX txb_ctx;
  get_txb_ctx(bsize, tx_size, 0, ta, tl, &txb_ctx);
  const int zero_blk_rate = x->coeff_costs.coeff_costs[txs_ctx][PLANE_TYPE_Y]
                                .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
  rd_stats->rate = zero_blk_rate *
                   (block_size_wide[bsize] >> tx_size_wide_log2[tx_size]) *
                   (block_size_high[bsize] >> tx_size_high_log2[tx_size]);
}

static inline void save_mb_rd_info(int n4, uint32_t hash,
                                   const MACROBLOCK *const x,
                                   const RD_STATS *const rd_stats,
                                   MB_RD_RECORD *mb_rd_record) {
  int index;
  if (mb_rd_record->num < RD_RECORD_BUFFER_LEN) {
    index =
        (mb_rd_record->index_start + mb_rd_record->num) % RD_RECORD_BUFFER_LEN;
    ++mb_rd_record->num;
  } else {
    index = mb_rd_record->index_start;
    mb_rd_record->index_start =
        (mb_rd_record->index_start + 1) % RD_RECORD_BUFFER_LEN;
  }
  MB_RD_INFO *const mb_rd_info = &mb_rd_record->mb_rd_info[index];
  const MACROBLOCKD *const xd = &x->e_mbd;
  const MB_MODE_INFO *const mbmi = xd->mi[0];
  mb_rd_info->hash_value = hash;
  mb_rd_info->tx_size = mbmi->tx_size;
  memcpy(mb_rd_info->blk_skip, x->txfm_search_info.blk_skip,
         sizeof(mb_rd_info->blk_skip[0]) * n4);
  av1_copy(mb_rd_info->inter_tx_size, mbmi->inter_tx_size);
  av1_copy_array(mb_rd_info->tx_type_map, xd->tx_type_map, n4);
  mb_rd_info->rd_stats = *rd_stats;
}
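
// A self-contained sketch of the ring-buffer discipline used by
// save_mb_rd_info above. RING_LEN, DemoRecord and demo_ring_push are
// illustrative stand-ins for RD_RECORD_BUFFER_LEN and the encoder types,
// not libaom APIs: while the buffer is not full, records are appended after
// index_start; once full, the oldest slot at index_start is overwritten and
// index_start advances, so find_mb_rd_info walks entries oldest to newest.
enum { RING_LEN = 8 };
typedef struct {
  uint32_t hash;
} DemoRecord;
typedef struct {
  DemoRecord buf[RING_LEN];
  int index_start;
  int num;
} DemoRing;
static void demo_ring_push(DemoRing *r, DemoRecord rec) {
  int index;
  if (r->num < RING_LEN) {
    index = (r->index_start + r->num) % RING_LEN;  // append at the tail
    ++r->num;
  } else {
    index = r->index_start;  // overwrite the oldest entry
    r->index_start = (r->index_start + 1) % RING_LEN;
  }
  r->buf[index] = rec;
}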

static int get_search_init_depth(int mi_width, int mi_height, int is_inter,
                                 const SPEED_FEATURES *sf,
                                 int tx_size_search_method) {
  if (tx_size_search_method == USE_LARGESTALL) return MAX_VARTX_DEPTH;

  if (sf->tx_sf.tx_size_search_lgr_block) {
    if (mi_width > mi_size_wide[BLOCK_64X64] ||
        mi_height > mi_size_high[BLOCK_64X64])
      return MAX_VARTX_DEPTH;
  }

  if (is_inter) {
    return (mi_height != mi_width)
               ? sf->tx_sf.inter_tx_size_search_init_depth_rect
               : sf->tx_sf.inter_tx_size_search_init_depth_sqr;
  } else {
    return (mi_height != mi_width)
               ? sf->tx_sf.intra_tx_size_search_init_depth_rect
               : sf->tx_sf.intra_tx_size_search_init_depth_sqr;
  }
}

static inline void select_tx_block(
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
    ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
    RD_STATS *rd_stats, int64_t prev_level_rd, int64_t ref_best_rd,
    int *is_cost_valid, FAST_TX_SEARCH_MODE ftxs_mode);

// NOTE: CONFIG_COLLECT_RD_STATS has 3 possible values
// 0: Do not collect any RD stats
// 1: Collect RD stats for transform units
// 2: Collect RD stats for partition units
#if CONFIG_COLLECT_RD_STATS

static inline void get_energy_distribution_fine(
    const AV1_COMP *cpi, BLOCK_SIZE bsize, const uint8_t *src, int src_stride,
    const uint8_t *dst, int dst_stride, int need_4th, double *hordist,
    double *verdist) {
  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  unsigned int esq[16] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };

  if (bsize < BLOCK_16X16 || (bsize >= BLOCK_4X16 && bsize <= BLOCK_32X8)) {
    // Special cases: calculate 'esq' values manually, as we don't have 'vf'
    // functions for the 16 (very small) sub-blocks of this block.
    const int w_shift = (bw == 4) ? 0 : (bw == 8) ? 1 : (bw == 16) ? 2 : 3;
    const int h_shift = (bh == 4) ? 0 : (bh == 8) ? 1 : (bh == 16) ? 2 : 3;
    assert(bw <= 32);
    assert(bh <= 32);
    assert(((bw - 1) >> w_shift) + (((bh - 1) >> h_shift) << 2) == 15);
    if (cpi->common.seq_params->use_highbitdepth) {
      const uint16_t *src16 = CONVERT_TO_SHORTPTR(src);
      const uint16_t *dst16 = CONVERT_TO_SHORTPTR(dst);
      for (int i = 0; i < bh; ++i)
        for (int j = 0; j < bw; ++j) {
          const int index = (j >> w_shift) + ((i >> h_shift) << 2);
          esq[index] +=
              (src16[j + i * src_stride] - dst16[j + i * dst_stride]) *
              (src16[j + i * src_stride] - dst16[j + i * dst_stride]);
        }
    } else {
      for (int i = 0; i < bh; ++i)
        for (int j = 0; j < bw; ++j) {
          const int index = (j >> w_shift) + ((i >> h_shift) << 2);
          esq[index] += (src[j + i * src_stride] - dst[j + i * dst_stride]) *
                        (src[j + i * src_stride] - dst[j + i * dst_stride]);
        }
    }
  } else {  // Calculate 'esq' values using 'vf' functions on the 16 sub-blocks.
    const int f_index =
        (bsize < BLOCK_SIZES) ? bsize - BLOCK_16X16 : bsize - BLOCK_8X16;
    assert(f_index >= 0 && f_index < BLOCK_SIZES_ALL);
    const BLOCK_SIZE subsize = (BLOCK_SIZE)f_index;
    assert(block_size_wide[bsize] == 4 * block_size_wide[subsize]);
    assert(block_size_high[bsize] == 4 * block_size_high[subsize]);
    cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[0]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4,
                                 dst_stride, &esq[1]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2,
                                 dst_stride, &esq[2]);
    cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
                                 dst_stride, &esq[3]);
    src += bh / 4 * src_stride;
    dst += bh / 4 * dst_stride;

    cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[4]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4,
                                 dst_stride, &esq[5]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2,
                                 dst_stride, &esq[6]);
    cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
                                 dst_stride, &esq[7]);
    src += bh / 4 * src_stride;
    dst += bh / 4 * dst_stride;

    cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[8]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4,
                                 dst_stride, &esq[9]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2,
                                 dst_stride, &esq[10]);
    cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
                                 dst_stride, &esq[11]);
    src += bh / 4 * src_stride;
    dst += bh / 4 * dst_stride;

    cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[12]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4,
                                 dst_stride, &esq[13]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2,
                                 dst_stride, &esq[14]);
    cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
                                 dst_stride, &esq[15]);
  }

  double total = (double)esq[0] + esq[1] + esq[2] + esq[3] + esq[4] + esq[5] +
                 esq[6] + esq[7] + esq[8] + esq[9] + esq[10] + esq[11] +
                 esq[12] + esq[13] + esq[14] + esq[15];
  if (total > 0) {
    const double e_recip = 1.0 / total;
    hordist[0] = ((double)esq[0] + esq[4] + esq[8] + esq[12]) * e_recip;
    hordist[1] = ((double)esq[1] + esq[5] + esq[9] + esq[13]) * e_recip;
    hordist[2] = ((double)esq[2] + esq[6] + esq[10] + esq[14]) * e_recip;
    if (need_4th) {
      hordist[3] = ((double)esq[3] + esq[7] + esq[11] + esq[15]) * e_recip;
    }
    verdist[0] = ((double)esq[0] + esq[1] + esq[2] + esq[3]) * e_recip;
    verdist[1] = ((double)esq[4] + esq[5] + esq[6] + esq[7]) * e_recip;
    verdist[2] = ((double)esq[8] + esq[9] + esq[10] + esq[11]) * e_recip;
    if (need_4th) {
      verdist[3] = ((double)esq[12] + esq[13] + esq[14] + esq[15]) * e_recip;
    }
  } else {
    hordist[0] = verdist[0] = 0.25;
    hordist[1] = verdist[1] = 0.25;
    hordist[2] = verdist[2] = 0.25;
    if (need_4th) {
      hordist[3] = verdist[3] = 0.25;
    }
  }
}

static double get_sse_norm(const int16_t *diff, int stride, int w, int h) {
  double sum = 0.0;
  for (int j = 0; j < h; ++j) {
    for (int i = 0; i < w; ++i) {
      const int err = diff[j * stride + i];
      sum += err * err;
    }
  }
  assert(w > 0 && h > 0);
  return sum / (w * h);
}

static double get_sad_norm(const int16_t *diff, int stride, int w, int h) {
  double sum = 0.0;
  for (int j = 0; j < h; ++j) {
    for (int i = 0; i < w; ++i) {
      sum += abs(diff[j * stride + i]);
    }
  }
  assert(w > 0 && h > 0);
  return sum / (w * h);
}

static inline void get_2x2_normalized_sses_and_sads(
    const AV1_COMP *const cpi, BLOCK_SIZE tx_bsize, const uint8_t *const src,
    int src_stride, const uint8_t *const dst, int dst_stride,
    const int16_t *const src_diff, int diff_stride, double *const sse_norm_arr,
    double *const sad_norm_arr) {
  const BLOCK_SIZE tx_bsize_half =
      get_partition_subsize(tx_bsize, PARTITION_SPLIT);
  if (tx_bsize_half == BLOCK_INVALID) {  // manually calculate stats
    const int half_width = block_size_wide[tx_bsize] / 2;
    const int half_height = block_size_high[tx_bsize] / 2;
    for (int row = 0; row < 2; ++row) {
      for (int col = 0; col < 2; ++col) {
        const int16_t *const this_src_diff =
            src_diff + row * half_height * diff_stride + col * half_width;
        if (sse_norm_arr) {
          sse_norm_arr[row * 2 + col] =
              get_sse_norm(this_src_diff, diff_stride, half_width, half_height);
        }
        if (sad_norm_arr) {
          sad_norm_arr[row * 2 + col] =
              get_sad_norm(this_src_diff, diff_stride, half_width, half_height);
        }
      }
    }
  } else {  // use function pointers to calculate stats
    const int half_width = block_size_wide[tx_bsize_half];
    const int half_height = block_size_high[tx_bsize_half];
    const int num_samples_half = half_width * half_height;
    for (int row = 0; row < 2; ++row) {
      for (int col = 0; col < 2; ++col) {
        const uint8_t *const this_src =
            src + row * half_height * src_stride + col * half_width;
        const uint8_t *const this_dst =
            dst + row * half_height * dst_stride + col * half_width;

        if (sse_norm_arr) {
          unsigned int this_sse;
          cpi->ppi->fn_ptr[tx_bsize_half].vf(this_src, src_stride, this_dst,
                                             dst_stride, &this_sse);
          sse_norm_arr[row * 2 + col] = (double)this_sse / num_samples_half;
        }

        if (sad_norm_arr) {
          const unsigned int this_sad = cpi->ppi->fn_ptr[tx_bsize_half].sdf(
              this_src, src_stride, this_dst, dst_stride);
          sad_norm_arr[row * 2 + col] = (double)this_sad / num_samples_half;
        }
      }
    }
  }
}

#if CONFIG_COLLECT_RD_STATS == 1
static double get_mean(const int16_t *diff, int stride, int w, int h) {
  double sum = 0.0;
  for (int j = 0; j < h; ++j) {
    for (int i = 0; i < w; ++i) {
      sum += diff[j * stride + i];
    }
  }
  assert(w > 0 && h > 0);
  return sum / (w * h);
}
static inline void PrintTransformUnitStats(
    const AV1_COMP *const cpi, MACROBLOCK *x, const RD_STATS *const rd_stats,
    int blk_row, int blk_col, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
    TX_TYPE tx_type, int64_t rd) {
  if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return;

  // Generate small sample to restrict output size.
  static unsigned int seed = 21743;
  if (lcg_rand16(&seed) % 256 > 0) return;

  const char output_file[] = "tu_stats.txt";
  FILE *fout = fopen(output_file, "a");
  if (!fout) return;

  const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
  const MACROBLOCKD *const xd = &x->e_mbd;
  const int plane = 0;
  struct macroblock_plane *const p = &x->plane[plane];
  const struct macroblockd_plane *const pd = &xd->plane[plane];
  const int txw = tx_size_wide[tx_size];
  const int txh = tx_size_high[tx_size];
  const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
  const int q_step = p->dequant_QTX[1] >> dequant_shift;
  const int num_samples = txw * txh;

  const double rate_norm = (double)rd_stats->rate / num_samples;
  const double dist_norm = (double)rd_stats->dist / num_samples;

  fprintf(fout, "%g %g", rate_norm, dist_norm);

  const int src_stride = p->src.stride;
  const uint8_t *const src =
      &p->src.buf[(blk_row * src_stride + blk_col) << MI_SIZE_LOG2];
  const int dst_stride = pd->dst.stride;
  const uint8_t *const dst =
      &pd->dst.buf[(blk_row * dst_stride + blk_col) << MI_SIZE_LOG2];
  unsigned int sse;
  cpi->ppi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse);
  const double sse_norm = (double)sse / num_samples;

  const unsigned int sad =
      cpi->ppi->fn_ptr[tx_bsize].sdf(src, src_stride, dst, dst_stride);
  const double sad_norm = (double)sad / num_samples;

  fprintf(fout, " %g %g", sse_norm, sad_norm);

  const int diff_stride = block_size_wide[plane_bsize];
  const int16_t *const src_diff =
      &p->src_diff[(blk_row * diff_stride + blk_col) << MI_SIZE_LOG2];

  double sse_norm_arr[4], sad_norm_arr[4];
  get_2x2_normalized_sses_and_sads(cpi, tx_bsize, src, src_stride, dst,
                                   dst_stride, src_diff, diff_stride,
                                   sse_norm_arr, sad_norm_arr);
  for (int i = 0; i < 4; ++i) {
    fprintf(fout, " %g", sse_norm_arr[i]);
  }
  for (int i = 0; i < 4; ++i) {
    fprintf(fout, " %g", sad_norm_arr[i]);
  }

  const TX_TYPE_1D tx_type_1d_row = htx_tab[tx_type];
  const TX_TYPE_1D tx_type_1d_col = vtx_tab[tx_type];

  fprintf(fout, " %d %d %d %d %d", q_step, tx_size_wide[tx_size],
          tx_size_high[tx_size], tx_type_1d_row, tx_type_1d_col);

  int model_rate;
  int64_t model_dist;
  model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, tx_bsize, plane, sse, num_samples,
                                   &model_rate, &model_dist);
  const double model_rate_norm = (double)model_rate / num_samples;
  const double model_dist_norm = (double)model_dist / num_samples;
  fprintf(fout, " %g %g", model_rate_norm, model_dist_norm);

  const double mean = get_mean(src_diff, diff_stride, txw, txh);
  float hor_corr, vert_corr;
  av1_get_horver_correlation_full(src_diff, diff_stride, txw, txh, &hor_corr,
                                  &vert_corr);
  fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);

  double hdist[4] = { 0 }, vdist[4] = { 0 };
  get_energy_distribution_fine(cpi, tx_bsize, src, src_stride, dst, dst_stride,
                               1, hdist, vdist);
  fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
          hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);

  fprintf(fout, " %d %" PRId64, x->rdmult, rd);

  fprintf(fout, "\n");
  fclose(fout);
}
#endif  // CONFIG_COLLECT_RD_STATS == 1

#if CONFIG_COLLECT_RD_STATS >= 2
static int64_t get_sse(const AV1_COMP *cpi, const MACROBLOCK *x) {
  const AV1_COMMON *cm = &cpi->common;
  const int num_planes = av1_num_planes(cm);
  const MACROBLOCKD *xd = &x->e_mbd;
  const MB_MODE_INFO *mbmi = xd->mi[0];
  int64_t total_sse = 0;
  for (int plane = 0; plane < num_planes; ++plane) {
    const struct macroblock_plane *const p = &x->plane[plane];
    const struct macroblockd_plane *const pd = &xd->plane[plane];
    const BLOCK_SIZE bs =
        get_plane_block_size(mbmi->bsize, pd->subsampling_x, pd->subsampling_y);
    unsigned int sse;

    if (plane) continue;

    cpi->ppi->fn_ptr[bs].vf(p->src.buf, p->src.stride, pd->dst.buf,
                            pd->dst.stride, &sse);
    total_sse += sse;
  }
  total_sse <<= 4;
  return total_sse;
}

static int get_est_rate_dist(const TileDataEnc *tile_data, BLOCK_SIZE bsize,
                             int64_t sse, int *est_residue_cost,
                             int64_t *est_dist) {
  const InterModeRdModel *md = &tile_data->inter_mode_rd_models[bsize];
  if (md->ready) {
    if (sse < md->dist_mean) {
      *est_residue_cost = 0;
      *est_dist = sse;
    } else {
      *est_dist = (int64_t)round(md->dist_mean);
      const double est_ld = md->a * sse + md->b;
      // Clamp estimated rate cost by INT_MAX / 2.
      // TODO(angiebird@google.com): find better solution than clamping.
      if (fabs(est_ld) < 1e-2) {
        *est_residue_cost = INT_MAX / 2;
      } else {
        double est_residue_cost_dbl = ((sse - md->dist_mean) / est_ld);
        if (est_residue_cost_dbl < 0) {
          *est_residue_cost = 0;
        } else {
          *est_residue_cost =
              (int)AOMMIN((int64_t)round(est_residue_cost_dbl), INT_MAX / 2);
        }
      }
      if (*est_residue_cost <= 0) {
        *est_residue_cost = 0;
        *est_dist = sse;
      }
    }
    return 1;
  }
  return 0;
}
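
// The estimate above follows the linear rate-distortion model stored in
// InterModeRdModel: the Lagrangian slope is modeled as est_ld = a * sse + b,
// so the residue rate is recovered as (sse - dist_mean) / est_ld. For
// illustration (made-up numbers): with est_ld = 2.0, sse = 1000 and
// dist_mean = 200, the estimated residue cost is (1000 - 200) / 2.0 = 400.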

static double get_highbd_diff_mean(const uint8_t *src8, int src_stride,
                                   const uint8_t *dst8, int dst_stride, int w,
                                   int h) {
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
  double sum = 0.0;
  for (int j = 0; j < h; ++j) {
    for (int i = 0; i < w; ++i) {
      const int diff = src[j * src_stride + i] - dst[j * dst_stride + i];
      sum += diff;
    }
  }
  assert(w > 0 && h > 0);
  return sum / (w * h);
}

static double get_diff_mean(const uint8_t *src, int src_stride,
                            const uint8_t *dst, int dst_stride, int w, int h) {
  double sum = 0.0;
  for (int j = 0; j < h; ++j) {
    for (int i = 0; i < w; ++i) {
      const int diff = src[j * src_stride + i] - dst[j * dst_stride + i];
      sum += diff;
    }
  }
  assert(w > 0 && h > 0);
  return sum / (w * h);
}

static inline void PrintPredictionUnitStats(const AV1_COMP *const cpi,
                                            const TileDataEnc *tile_data,
                                            MACROBLOCK *x,
                                            const RD_STATS *const rd_stats,
                                            BLOCK_SIZE plane_bsize) {
  if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return;

  if (cpi->sf.inter_sf.inter_mode_rd_model_estimation == 1 &&
      (tile_data == NULL ||
       !tile_data->inter_mode_rd_models[plane_bsize].ready))
    return;
  (void)tile_data;
  // Generate small sample to restrict output size.
  static unsigned int seed = 95014;

  if ((lcg_rand16(&seed) % (1 << (14 - num_pels_log2_lookup[plane_bsize]))) !=
      1)
    return;

  const char output_file[] = "pu_stats.txt";
  FILE *fout = fopen(output_file, "a");
  if (!fout) return;

  MACROBLOCKD *const xd = &x->e_mbd;
  const int plane = 0;
  struct macroblock_plane *const p = &x->plane[plane];
  struct macroblockd_plane *pd = &xd->plane[plane];
  const int diff_stride = block_size_wide[plane_bsize];
  int bw, bh;
  get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL, &bw,
                     &bh);
  const int num_samples = bw * bh;
  const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
  const int q_step = p->dequant_QTX[1] >> dequant_shift;
  const int shift = (xd->bd - 8);

  const double rate_norm = (double)rd_stats->rate / num_samples;
  const double dist_norm = (double)rd_stats->dist / num_samples;
  const double rdcost_norm =
      (double)RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) / num_samples;

  fprintf(fout, "%g %g %g", rate_norm, dist_norm, rdcost_norm);

  const int src_stride = p->src.stride;
  const uint8_t *const src = p->src.buf;
  const int dst_stride = pd->dst.stride;
  const uint8_t *const dst = pd->dst.buf;
  const int16_t *const src_diff = p->src_diff;

  int64_t sse = calculate_sse(xd, p, pd, bw, bh);
  const double sse_norm = (double)sse / num_samples;

  const unsigned int sad =
      cpi->ppi->fn_ptr[plane_bsize].sdf(src, src_stride, dst, dst_stride);
  const double sad_norm =
      (double)sad / (1 << num_pels_log2_lookup[plane_bsize]);

  fprintf(fout, " %g %g", sse_norm, sad_norm);

  double sse_norm_arr[4], sad_norm_arr[4];
  get_2x2_normalized_sses_and_sads(cpi, plane_bsize, src, src_stride, dst,
                                   dst_stride, src_diff, diff_stride,
                                   sse_norm_arr, sad_norm_arr);
  if (shift) {
    for (int k = 0; k < 4; ++k) sse_norm_arr[k] /= (1 << (2 * shift));
    for (int k = 0; k < 4; ++k) sad_norm_arr[k] /= (1 << shift);
  }
  for (int i = 0; i < 4; ++i) {
    fprintf(fout, " %g", sse_norm_arr[i]);
  }
  for (int i = 0; i < 4; ++i) {
    fprintf(fout, " %g", sad_norm_arr[i]);
  }

  fprintf(fout, " %d %d %d %d", q_step, x->rdmult, bw, bh);

  int model_rate;
  int64_t model_dist;
  model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, plane_bsize, plane, sse, num_samples,
                                   &model_rate, &model_dist);
  const double model_rdcost_norm =
      (double)RDCOST(x->rdmult, model_rate, model_dist) / num_samples;
  const double model_rate_norm = (double)model_rate / num_samples;
  const double model_dist_norm = (double)model_dist / num_samples;
  fprintf(fout, " %g %g %g", model_rate_norm, model_dist_norm,
          model_rdcost_norm);

  double mean;
  if (is_cur_buf_hbd(xd)) {
    mean = get_highbd_diff_mean(p->src.buf, p->src.stride, pd->dst.buf,
                                pd->dst.stride, bw, bh);
  } else {
    mean = get_diff_mean(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
                         bw, bh);
  }
  mean /= (1 << shift);
  float hor_corr, vert_corr;
  av1_get_horver_correlation_full(src_diff, diff_stride, bw, bh, &hor_corr,
                                  &vert_corr);
  fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);

  double hdist[4] = { 0 }, vdist[4] = { 0 };
  get_energy_distribution_fine(cpi, plane_bsize, src, src_stride, dst,
                               dst_stride, 1, hdist, vdist);
  fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
          hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);

  if (cpi->sf.inter_sf.inter_mode_rd_model_estimation == 1) {
    assert(tile_data->inter_mode_rd_models[plane_bsize].ready);
    const int64_t overall_sse = get_sse(cpi, x);
    int est_residue_cost = 0;
    int64_t est_dist = 0;
    get_est_rate_dist(tile_data, plane_bsize, overall_sse, &est_residue_cost,
                      &est_dist);
    const double est_residue_cost_norm = (double)est_residue_cost / num_samples;
    const double est_dist_norm = (double)est_dist / num_samples;
    const double est_rdcost_norm =
        (double)RDCOST(x->rdmult, est_residue_cost, est_dist) / num_samples;
    fprintf(fout, " %g %g %g", est_residue_cost_norm, est_dist_norm,
            est_rdcost_norm);
  }

  fprintf(fout, "\n");
  fclose(fout);
}
#endif  // CONFIG_COLLECT_RD_STATS >= 2
#endif  // CONFIG_COLLECT_RD_STATS

static inline void inverse_transform_block_facade(MACROBLOCK *const x,
                                                  int plane, int block,
                                                  int blk_row, int blk_col,
                                                  int eob, int reduced_tx_set) {
  if (!eob) return;
  struct macroblock_plane *const p = &x->plane[plane];
  MACROBLOCKD *const xd = &x->e_mbd;
  tran_low_t *dqcoeff = p->dqcoeff + BLOCK_OFFSET(block);
  const PLANE_TYPE plane_type = get_plane_type(plane);
  const TX_SIZE tx_size = av1_get_tx_size(plane, xd);
  const TX_TYPE tx_type = av1_get_tx_type(xd, plane_type, blk_row, blk_col,
                                          tx_size, reduced_tx_set);

  struct macroblockd_plane *const pd = &xd->plane[plane];
  const int dst_stride = pd->dst.stride;
  uint8_t *dst = &pd->dst.buf[(blk_row * dst_stride + blk_col) << MI_SIZE_LOG2];
  av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, dst,
                              dst_stride, eob, reduced_tx_set);
}

static inline void recon_intra(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
                               int block, int blk_row, int blk_col,
                               BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
                               const TXB_CTX *const txb_ctx, int skip_trellis,
                               TX_TYPE best_tx_type, int do_quant,
                               int *rate_cost, uint16_t best_eob) {
  const AV1_COMMON *cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *mbmi = xd->mi[0];
  const int is_inter = is_inter_block(mbmi);
  if (!is_inter && best_eob &&
      (blk_row + tx_size_high_unit[tx_size] < mi_size_high[plane_bsize] ||
       blk_col + tx_size_wide_unit[tx_size] < mi_size_wide[plane_bsize])) {
    // If the quantized coefficients are stored in the dqcoeff buffer, we don't
    // need to do transform and quantization again.
    if (do_quant) {
      TxfmParam txfm_param_intra;
      QUANT_PARAM quant_param_intra;
      av1_setup_xform(cm, x, tx_size, best_tx_type, &txfm_param_intra);
      av1_setup_quant(tx_size, !skip_trellis,
                      skip_trellis
                          ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B
                                                    : AV1_XFORM_QUANT_FP)
                          : AV1_XFORM_QUANT_FP,
                      cpi->oxcf.q_cfg.quant_b_adapt, &quant_param_intra);
      av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, best_tx_type,
                        &quant_param_intra);
      av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize,
                      &txfm_param_intra, &quant_param_intra);
      if (quant_param_intra.use_optimize_b) {
        av1_optimize_b(cpi, x, plane, block, tx_size, best_tx_type, txb_ctx,
                       rate_cost);
      }
    }

    inverse_transform_block_facade(x, plane, block, blk_row, blk_col,
                                   x->plane[plane].eobs[block],
                                   cm->features.reduced_tx_set_used);

    // This may happen because of hash collision. The eob stored in the hash
    // table is non-zero, but the real eob is zero. We need to make sure tx_type
    // is DCT_DCT in this case.
    if (plane == 0 && x->plane[plane].eobs[block] == 0 &&
        best_tx_type != DCT_DCT) {
      update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
    }
  }
}

static unsigned pixel_dist_visible_only(
    const AV1_COMP *const cpi, const MACROBLOCK *x, const uint8_t *src,
    const int src_stride, const uint8_t *dst, const int dst_stride,
    const BLOCK_SIZE tx_bsize, int txb_rows, int txb_cols, int visible_rows,
    int visible_cols) {
  unsigned sse;

  if (txb_rows == visible_rows && txb_cols == visible_cols) {
    cpi->ppi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse);
    return sse;
  }

#if CONFIG_AV1_HIGHBITDEPTH
  const MACROBLOCKD *xd = &x->e_mbd;
  if (is_cur_buf_hbd(xd)) {
    uint64_t sse64 = aom_highbd_sse_odd_size(src, src_stride, dst, dst_stride,
                                             visible_cols, visible_rows);
    return (unsigned int)ROUND_POWER_OF_TWO(sse64, (xd->bd - 8) * 2);
  }
#else
  (void)x;
#endif
  sse = aom_sse_odd_size(src, src_stride, dst, dst_stride, visible_cols,
                         visible_rows);
  return sse;
}

// Compute the pixel domain distortion from src and dst on all visible 4x4s in
// the transform block.
static unsigned pixel_dist(const AV1_COMP *const cpi, const MACROBLOCK *x,
                           int plane, const uint8_t *src, const int src_stride,
                           const uint8_t *dst, const int dst_stride,
                           int blk_row, int blk_col,
                           const BLOCK_SIZE plane_bsize,
                           const BLOCK_SIZE tx_bsize) {
  int txb_rows, txb_cols, visible_rows, visible_cols;
  const MACROBLOCKD *xd = &x->e_mbd;

  get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize,
                     &txb_cols, &txb_rows, &visible_cols, &visible_rows);
  assert(visible_rows > 0);
  assert(visible_cols > 0);

  unsigned sse = pixel_dist_visible_only(cpi, x, src, src_stride, dst,
                                         dst_stride, tx_bsize, txb_rows,
                                         txb_cols, visible_rows, visible_cols);

  return sse;
}

static inline int64_t dist_block_px_domain(const AV1_COMP *cpi, MACROBLOCK *x,
                                           int plane, BLOCK_SIZE plane_bsize,
                                           int block, int blk_row, int blk_col,
                                           TX_SIZE tx_size) {
  MACROBLOCKD *const xd = &x->e_mbd;
  const struct macroblock_plane *const p = &x->plane[plane];
  const uint16_t eob = p->eobs[block];
  const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
  const int bsw = block_size_wide[tx_bsize];
  const int bsh = block_size_high[tx_bsize];
  const int src_stride = x->plane[plane].src.stride;
  const int dst_stride = xd->plane[plane].dst.stride;
  // Scale the transform block index to pixel unit.
  const int src_idx = (blk_row * src_stride + blk_col) << MI_SIZE_LOG2;
  const int dst_idx = (blk_row * dst_stride + blk_col) << MI_SIZE_LOG2;
  const uint8_t *src = &x->plane[plane].src.buf[src_idx];
  const uint8_t *dst = &xd->plane[plane].dst.buf[dst_idx];
  const tran_low_t *dqcoeff = p->dqcoeff + BLOCK_OFFSET(block);

  assert(cpi != NULL);
  assert(tx_size_wide_log2[0] == tx_size_high_log2[0]);

  uint8_t *recon;
  DECLARE_ALIGNED(16, uint16_t, recon16[MAX_TX_SQUARE]);

#if CONFIG_AV1_HIGHBITDEPTH
  if (is_cur_buf_hbd(xd)) {
    recon = CONVERT_TO_BYTEPTR(recon16);
    aom_highbd_convolve_copy(CONVERT_TO_SHORTPTR(dst), dst_stride,
                             CONVERT_TO_SHORTPTR(recon), MAX_TX_SIZE, bsw, bsh);
  } else {
    recon = (uint8_t *)recon16;
    aom_convolve_copy(dst, dst_stride, recon, MAX_TX_SIZE, bsw, bsh);
  }
#else
  recon = (uint8_t *)recon16;
  aom_convolve_copy(dst, dst_stride, recon, MAX_TX_SIZE, bsw, bsh);
#endif

  const PLANE_TYPE plane_type = get_plane_type(plane);
  TX_TYPE tx_type = av1_get_tx_type(xd, plane_type, blk_row, blk_col, tx_size,
                                    cpi->common.features.reduced_tx_set_used);
  av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, recon,
                              MAX_TX_SIZE, eob,
                              cpi->common.features.reduced_tx_set_used);

  return 16 * pixel_dist(cpi, x, plane, src, src_stride, recon, MAX_TX_SIZE,
                         blk_row, blk_col, plane_bsize, tx_bsize);
}

// pruning thresholds for prune_txk_type and prune_txk_type_separ
static const int prune_factors[5] = { 200, 200, 120, 80, 40 };  // scale 1000
static const int mul_factors[5] = { 80, 80, 70, 50, 30 };       // scale 100

// R-D costs are sorted in ascending order.
static inline void sort_rd(int64_t rds[], int txk[], int len) {
  int i, j, k;

  for (i = 1; i <= len - 1; ++i) {
    for (j = 0; j < i; ++j) {
      if (rds[j] > rds[i]) {
        int64_t temprd;
        int tempi;

        temprd = rds[i];
        tempi = txk[i];

        for (k = i; k > j; k--) {
          rds[k] = rds[k - 1];
          txk[k] = txk[k - 1];
        }

        rds[j] = temprd;
        txk[j] = tempi;
        break;
      }
    }
  }
}
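
// sort_rd() is a straight insertion sort that permutes the tx-type indices
// in lockstep with their R-D costs. A minimal usage sketch (illustrative
// values):
//   int64_t rds[4] = { 900, 250, 700, 400 };
//   int txk[4] = { 0, 1, 2, 3 };
//   sort_rd(rds, txk, 4);
//   // rds -> { 250, 400, 700, 900 }, txk -> { 1, 3, 2, 0 }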

static inline int64_t av1_block_error_qm(
    const tran_low_t *coeff, const tran_low_t *dqcoeff, intptr_t block_size,
    const qm_val_t *qmatrix, const int16_t *scan, int64_t *ssz, int bd) {
  int i;
  int64_t error = 0, sqcoeff = 0;
  int shift = 2 * (bd - 8);
  int rounding = (1 << shift) >> 1;

  for (i = 0; i < block_size; i++) {
    int64_t weight = qmatrix[scan[i]];
    int64_t dd = coeff[i] - dqcoeff[i];
    dd *= weight;
    int64_t cc = coeff[i];
    cc *= weight;
    // The ranges of coeff and dqcoeff are
    //  bd8 : 18 bits (including sign)
    //  bd10: 20 bits (including sign)
    //  bd12: 22 bits (including sign)
    // As AOM_QM_BITS is 5, the intermediate quantities in the calculation
    // below should fit in 54 bits, thus no overflow should happen.
    error += (dd * dd + (1 << (2 * AOM_QM_BITS - 1))) >> (2 * AOM_QM_BITS);
    sqcoeff += (cc * cc + (1 << (2 * AOM_QM_BITS - 1))) >> (2 * AOM_QM_BITS);
  }

  error = (error + rounding) >> shift;
  sqcoeff = (sqcoeff + rounding) >> shift;

  *ssz = sqcoeff;
  return error;
}
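
// Worked example of the rounding-shift pattern above (illustrative numbers):
// with AOM_QM_BITS = 5, a weighted squared difference dd * dd = 4000 becomes
// (4000 + (1 << 9)) >> 10 = 4, i.e. the product of two Q5 weights is scaled
// back down with round-to-nearest. The final shift does the same for the
// high-bit-depth scaling: at bd = 10, shift = 4 and rounding = 8.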

static inline void dist_block_tx_domain(MACROBLOCK *x, int plane, int block,
                                        TX_SIZE tx_size,
                                        const qm_val_t *qmatrix,
                                        const int16_t *scan, int64_t *out_dist,
                                        int64_t *out_sse) {
  const struct macroblock_plane *const p = &x->plane[plane];
  // Transform domain distortion computation is more efficient as it does
  // not involve an inverse transform, but it is less accurate.
  const int buffer_length = av1_get_max_eob(tx_size);
  int64_t this_sse;
  // TX-domain results need to shift down to Q2/D10 to match pixel
  // domain distortion values which are in Q2^2
  int shift = (MAX_TX_SCALE - av1_get_tx_scale(tx_size)) * 2;
  const int block_offset = BLOCK_OFFSET(block);
  tran_low_t *const coeff = p->coeff + block_offset;
  tran_low_t *const dqcoeff = p->dqcoeff + block_offset;
#if CONFIG_AV1_HIGHBITDEPTH
  MACROBLOCKD *const xd = &x->e_mbd;
  if (is_cur_buf_hbd(xd)) {
    if (qmatrix == NULL || !x->txfm_search_params.use_qm_dist_metric) {
      *out_dist = av1_highbd_block_error(coeff, dqcoeff, buffer_length,
                                         &this_sse, xd->bd);
    } else {
      *out_dist = av1_block_error_qm(coeff, dqcoeff, buffer_length, qmatrix,
                                     scan, &this_sse, xd->bd);
    }
  } else {
#endif
    if (qmatrix == NULL || !x->txfm_search_params.use_qm_dist_metric) {
      *out_dist = av1_block_error(coeff, dqcoeff, buffer_length, &this_sse);
    } else {
      *out_dist = av1_block_error_qm(coeff, dqcoeff, buffer_length, qmatrix,
                                     scan, &this_sse, 8);
    }
#if CONFIG_AV1_HIGHBITDEPTH
  }
#endif

  *out_dist = RIGHT_SIGNED_SHIFT(*out_dist, shift);
  *out_sse = RIGHT_SIGNED_SHIFT(this_sse, shift);
}
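
// Design note on the final shift above: coefficient-domain SSE is computed
// at a higher fixed-point precision than the pixel-domain distortion (which
// callers such as dist_block_px_domain scale by 16, i.e. Q2 squared), so
// both dist and sse are shifted down together before they are compared with,
// or summed into, the same RD_STATS fields as pixel-domain values.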

static uint16_t prune_txk_type_separ(
    const AV1_COMP *cpi, MACROBLOCK *x, int plane, int block, TX_SIZE tx_size,
    int blk_row, int blk_col, BLOCK_SIZE plane_bsize, int *txk_map,
    int16_t allowed_tx_mask, int prune_factor, const TXB_CTX *const txb_ctx,
    int reduced_tx_set_used, int64_t ref_best_rd, int num_sel) {
  const AV1_COMMON *cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;

  int idx;

  int64_t rds_v[4];
  int64_t rds_h[4];
  int idx_v[4] = { 0, 1, 2, 3 };
  int idx_h[4] = { 0, 1, 2, 3 };
  int skip_v[4] = { 0 };
  int skip_h[4] = { 0 };
  const int idx_map[16] = {
    DCT_DCT,      DCT_ADST,      DCT_FLIPADST,      V_DCT,
    ADST_DCT,     ADST_ADST,     ADST_FLIPADST,     V_ADST,
    FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
    H_DCT,        H_ADST,        H_FLIPADST,        IDTX
  };

  const int sel_pattern_v[16] = {
    0, 0, 1, 1, 0, 2, 1, 2, 2, 0, 3, 1, 3, 2, 3, 3
  };
  const int sel_pattern_h[16] = {
    0, 1, 0, 1, 2, 0, 2, 1, 2, 3, 0, 3, 1, 3, 2, 3
  };

  QUANT_PARAM quant_param;
  TxfmParam txfm_param;
  av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
  av1_setup_quant(tx_size, 1, AV1_XFORM_QUANT_B, cpi->oxcf.q_cfg.quant_b_adapt,
                  &quant_param);
  int tx_type;
  // Use EXT_TX_SET_ALL16 so we can try tx types even outside the ext_tx_set
  // of the current block. This function should only be called for size < 16.
  assert(txsize_sqr_up_map[tx_size] <= TX_16X16);
  txfm_param.tx_set_type = EXT_TX_SET_ALL16;

  int rate_cost = 0;
  int64_t dist = 0, sse = 0;
  // evaluate horizontal with vertical DCT
  for (idx = 0; idx < 4; ++idx) {
    tx_type = idx_map[idx];
    txfm_param.tx_type = tx_type;

    av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
                      &quant_param);

    av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
                    &quant_param);

    const SCAN_ORDER *const scan_order =
        get_scan(txfm_param.tx_size, txfm_param.tx_type);
    dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
                         scan_order->scan, &dist, &sse);

    rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
                                              txb_ctx, reduced_tx_set_used, 0);

    rds_h[idx] = RDCOST(x->rdmult, rate_cost, dist);

    if ((rds_h[idx] - (rds_h[idx] >> 2)) > ref_best_rd) {
      skip_h[idx] = 1;
    }
  }
  sort_rd(rds_h, idx_h, 4);
  for (idx = 1; idx < 4; idx++) {
    if (rds_h[idx] > rds_h[0] * 1.2) skip_h[idx_h[idx]] = 1;
  }

  if (skip_h[idx_h[0]]) return (uint16_t)0xFFFF;

  // evaluate vertical with the best horizontal chosen
  rds_v[0] = rds_h[0];
  int start_v = 1, end_v = 4;
  const int *idx_map_v = idx_map + idx_h[0];

  for (idx = start_v; idx < end_v; ++idx) {
    tx_type = idx_map_v[idx_v[idx] * 4];
    txfm_param.tx_type = tx_type;

    av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
                      &quant_param);

    av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
                    &quant_param);

    const SCAN_ORDER *const scan_order =
        get_scan(txfm_param.tx_size, txfm_param.tx_type);
    dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
                         scan_order->scan, &dist, &sse);

    rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
                                              txb_ctx, reduced_tx_set_used, 0);

    rds_v[idx] = RDCOST(x->rdmult, rate_cost, dist);

    if ((rds_v[idx] - (rds_v[idx] >> 2)) > ref_best_rd) {
      skip_v[idx] = 1;
    }
  }
  sort_rd(rds_v, idx_v, 4);
  for (idx = 1; idx < 4; idx++) {
    if (rds_v[idx] > rds_v[0] * 1.2) skip_v[idx_v[idx]] = 1;
  }

  // combine rd_h and rd_v to prune tx candidates
  int i_v, i_h;
  int64_t rds[16];
  int num_cand = 0, last = TX_TYPES - 1;

  for (int i = 0; i < 16; i++) {
    i_v = sel_pattern_v[i];
    i_h = sel_pattern_h[i];
    tx_type = idx_map[idx_v[i_v] * 4 + idx_h[i_h]];
    if (!(allowed_tx_mask & (1 << tx_type)) || skip_h[idx_h[i_h]] ||
        skip_v[idx_v[i_v]]) {
      txk_map[last] = tx_type;
      last--;
    } else {
      txk_map[num_cand] = tx_type;
      rds[num_cand] = rds_v[i_v] + rds_h[i_h];
      if (rds[num_cand] == 0) rds[num_cand] = 1;
      num_cand++;
    }
  }
  sort_rd(rds, txk_map, num_cand);

  uint16_t prune = (uint16_t)(~(1 << txk_map[0]));
  num_sel = AOMMIN(num_sel, num_cand);

  for (int i = 1; i < num_sel; i++) {
    int64_t factor = 1800 * (rds[i] - rds[0]) / (rds[0]);
    if (factor < (int64_t)prune_factor)
      prune &= ~(1 << txk_map[i]);
    else
      break;
  }
  return prune;
}
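
// The separable strategy above factorizes the 2-D TX-type search: the first
// row of idx_map varies the horizontal 1-D kernel with the vertical kernel
// fixed to DCT, and stepping through idx_map with a stride of 4 varies the
// vertical kernel while the horizontal one is held at the best of the first
// pass. The combined cost rds_v[i_v] + rds_h[i_h] is therefore only an
// estimate, but it is built from 4 + 3 measured transforms instead of all
// 16 combinations.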
1260
1261
static uint16_t prune_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
1262
                               int block, TX_SIZE tx_size, int blk_row,
1263
                               int blk_col, BLOCK_SIZE plane_bsize,
1264
                               int *txk_map, uint16_t allowed_tx_mask,
1265
                               int prune_factor, const TXB_CTX *const txb_ctx,
1266
0
                               int reduced_tx_set_used) {
1267
0
  const AV1_COMMON *cm = &cpi->common;
1268
0
  MACROBLOCKD *xd = &x->e_mbd;
1269
0
  int tx_type;
1270
1271
0
  int64_t rds[TX_TYPES];
1272
1273
0
  int num_cand = 0;
1274
0
  int last = TX_TYPES - 1;
1275
1276
0
  TxfmParam txfm_param;
1277
0
  QUANT_PARAM quant_param;
1278
0
  av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
1279
0
  av1_setup_quant(tx_size, 1, AV1_XFORM_QUANT_B, cpi->oxcf.q_cfg.quant_b_adapt,
1280
0
                  &quant_param);
1281
1282
0
  for (int idx = 0; idx < TX_TYPES; idx++) {
1283
0
    tx_type = idx;
1284
0
    int rate_cost = 0;
1285
0
    int64_t dist = 0, sse = 0;
1286
0
    if (!(allowed_tx_mask & (1 << tx_type))) {
1287
0
      txk_map[last] = tx_type;
1288
0
      last--;
1289
0
      continue;
1290
0
    }
1291
0
    txfm_param.tx_type = tx_type;
1292
1293
0
    av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
1294
0
                      &quant_param);
1295
1296
    // do txfm and quantization
1297
0
    av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
1298
0
                    &quant_param);
1299
    // estimate rate cost
1300
0
    rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
1301
0
                                              txb_ctx, reduced_tx_set_used, 0);
1302
    // tx domain dist
1303
0
    const SCAN_ORDER *const scan_order =
1304
0
        get_scan(txfm_param.tx_size, txfm_param.tx_type);
1305
0
    dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
1306
0
                         scan_order->scan, &dist, &sse);
1307
1308
0
    txk_map[num_cand] = tx_type;
1309
0
    rds[num_cand] = RDCOST(x->rdmult, rate_cost, dist);
1310
0
    if (rds[num_cand] == 0) rds[num_cand] = 1;
1311
0
    num_cand++;
1312
0
  }
1313
1314
0
  if (num_cand == 0) return (uint16_t)0xFFFF;
1315
1316
0
  sort_rd(rds, txk_map, num_cand);
1317
0
  uint16_t prune = (uint16_t)(~(1 << txk_map[0]));
1318
1319
  // 0 < prune_factor <= 1000 controls aggressiveness
1320
0
  int64_t factor = 0;
1321
0
  for (int idx = 1; idx < num_cand; idx++) {
1322
0
    factor = 1000 * (rds[idx] - rds[0]) / rds[0];
1323
0
    if (factor < (int64_t)prune_factor)
1324
0
      prune &= ~(1 << txk_map[idx]);
1325
0
    else
1326
0
      break;
1327
0
  }
1328
0
  return prune;
1329
0
}
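A minimal standalone sketch of the relative-RD keep rule used by the loop above
(illustrative names, not part of tx_search.c): after sorting, a candidate
survives only while its RD cost stays within prune_factor per-mille of the best
candidate. prune_txk_type_separ above applies the same rule with a scale of
1800 instead of 1000.

#include <stdint.h>

// Keep the best candidate plus every candidate whose RD cost is within
// `prune_factor` per-mille (1/1000) of the best; all other types stay pruned.
// `rds`/`types` must be sorted by ascending RD cost, and rds[0] > 0 (the
// search clamps zero costs to 1 before sorting).
static uint16_t sketch_prune_by_relative_rd(const int64_t *rds,
                                            const int *types, int num_cand,
                                            int prune_factor) {
  uint16_t prune = (uint16_t)~(1u << types[0]);  // prune all but the best
  for (int i = 1; i < num_cand; i++) {
    const int64_t factor = 1000 * (rds[i] - rds[0]) / rds[0];
    if (factor >= (int64_t)prune_factor) break;  // sorted: rest stay pruned
    prune &= (uint16_t)~(1u << types[i]);        // close to the best: keep
  }
  return prune;
}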
1330
1331
// These thresholds were calibrated to provide a certain number of TX types
1332
// pruned by the model on average, i.e. selecting a threshold with index i
1333
// will lead to pruning i+1 TX types on average
1334
static const float *prune_2D_adaptive_thresholds[] = {
1335
  // TX_4X4
1336
  (float[]){ 0.00549f, 0.01306f, 0.02039f, 0.02747f, 0.03406f, 0.04065f,
1337
             0.04724f, 0.05383f, 0.06067f, 0.06799f, 0.07605f, 0.08533f,
1338
             0.09778f, 0.11780f },
1339
  // TX_8X8
1340
  (float[]){ 0.00037f, 0.00183f, 0.00525f, 0.01038f, 0.01697f, 0.02502f,
1341
             0.03381f, 0.04333f, 0.05286f, 0.06287f, 0.07434f, 0.08850f,
1342
             0.10803f, 0.14124f },
1343
  // TX_16X16
1344
  (float[]){ 0.01404f, 0.02000f, 0.04211f, 0.05164f, 0.05798f, 0.06335f,
1345
             0.06897f, 0.07629f, 0.08875f, 0.11169f },
1346
  // TX_32X32
1347
  NULL,
1348
  // TX_64X64
1349
  NULL,
1350
  // TX_4X8
1351
  (float[]){ 0.00183f, 0.00745f, 0.01428f, 0.02185f, 0.02966f, 0.03723f,
1352
             0.04456f, 0.05188f, 0.05920f, 0.06702f, 0.07605f, 0.08704f,
1353
             0.10168f, 0.12585f },
1354
  // TX_8X4
1355
  (float[]){ 0.00085f, 0.00476f, 0.01135f, 0.01892f, 0.02698f, 0.03528f,
1356
             0.04358f, 0.05164f, 0.05994f, 0.06848f, 0.07849f, 0.09021f,
1357
             0.10583f, 0.13123f },
1358
  // TX_8X16
1359
  (float[]){ 0.00037f, 0.00232f, 0.00671f, 0.01257f, 0.01965f, 0.02722f,
1360
             0.03552f, 0.04382f, 0.05237f, 0.06189f, 0.07336f, 0.08728f,
1361
             0.10730f, 0.14221f },
1362
  // TX_16X8
1363
  (float[]){ 0.00061f, 0.00330f, 0.00818f, 0.01453f, 0.02185f, 0.02966f,
1364
             0.03772f, 0.04578f, 0.05383f, 0.06262f, 0.07288f, 0.08582f,
1365
             0.10339f, 0.13464f },
1366
  // TX_16X32
1367
  NULL,
1368
  // TX_32X16
1369
  NULL,
1370
  // TX_32X64
1371
  NULL,
1372
  // TX_64X32
1373
  NULL,
1374
  // TX_4X16
1375
  (float[]){ 0.00232f, 0.00671f, 0.01257f, 0.01941f, 0.02673f, 0.03430f,
1376
             0.04211f, 0.04968f, 0.05750f, 0.06580f, 0.07507f, 0.08655f,
1377
             0.10242f, 0.12878f },
1378
  // TX_16X4
1379
  (float[]){ 0.00110f, 0.00525f, 0.01208f, 0.01990f, 0.02795f, 0.03601f,
1380
             0.04358f, 0.05115f, 0.05896f, 0.06702f, 0.07629f, 0.08752f,
1381
             0.10217f, 0.12610f },
1382
  // TX_8X32
1383
  NULL,
1384
  // TX_32X8
1385
  NULL,
1386
  // TX_16X64
1387
  NULL,
1388
  // TX_64X16
1389
  NULL,
1390
};
1391
1392
static inline float get_adaptive_thresholds(
1393
    TX_SIZE tx_size, TxSetType tx_set_type,
1394
0
    TX_TYPE_PRUNE_MODE prune_2d_txfm_mode) {
1395
0
  const int prune_aggr_table[5][2] = {
1396
0
    { 4, 1 }, { 6, 3 }, { 9, 6 }, { 9, 6 }, { 12, 9 }
1397
0
  };
1398
0
  int pruning_aggressiveness = 0;
1399
0
  if (tx_set_type == EXT_TX_SET_ALL16)
1400
0
    pruning_aggressiveness =
1401
0
        prune_aggr_table[prune_2d_txfm_mode - TX_TYPE_PRUNE_1][0];
1402
0
  else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
1403
0
    pruning_aggressiveness =
1404
0
        prune_aggr_table[prune_2d_txfm_mode - TX_TYPE_PRUNE_1][1];
1405
1406
0
  return prune_2D_adaptive_thresholds[tx_size][pruning_aggressiveness];
1407
0
}
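A worked example of the calibration note above: with prune_2d_txfm_mode ==
TX_TYPE_PRUNE_2 (row 1 of prune_aggr_table) and tx_set_type ==
EXT_TX_SET_ALL16, the aggressiveness index is 6, so the returned threshold is
prune_2D_adaptive_thresholds[tx_size][6] and, per the calibration, the model
prunes about 6 + 1 = 7 of the 16 TX types on average (for tx sizes that have
a threshold table; the NULL rows correspond to sizes the model does not cover).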
1408
1409
static inline void get_energy_distribution_finer(const int16_t *diff,
1410
                                                 int stride, int bw, int bh,
1411
                                                 float *hordist,
1412
0
                                                 float *verdist) {
1413
  // First compute downscaled block energy values (esq); downscale factors
1414
  // are defined by w_shift and h_shift.
1415
0
  unsigned int esq[256];
1416
0
  const int w_shift = bw <= 8 ? 0 : 1;
1417
0
  const int h_shift = bh <= 8 ? 0 : 1;
1418
0
  const int esq_w = bw >> w_shift;
1419
0
  const int esq_h = bh >> h_shift;
1420
0
  const int esq_sz = esq_w * esq_h;
1421
0
  int i, j;
1422
0
  memset(esq, 0, esq_sz * sizeof(esq[0]));
1423
0
  if (w_shift) {
1424
0
    for (i = 0; i < bh; i++) {
1425
0
      unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w;
1426
0
      const int16_t *cur_diff_row = diff + i * stride;
1427
0
      for (j = 0; j < bw; j += 2) {
1428
0
        cur_esq_row[j >> 1] += (cur_diff_row[j] * cur_diff_row[j] +
1429
0
                                cur_diff_row[j + 1] * cur_diff_row[j + 1]);
1430
0
      }
1431
0
    }
1432
0
  } else {
1433
0
    for (i = 0; i < bh; i++) {
1434
0
      unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w;
1435
0
      const int16_t *cur_diff_row = diff + i * stride;
1436
0
      for (j = 0; j < bw; j++) {
1437
0
        cur_esq_row[j] += cur_diff_row[j] * cur_diff_row[j];
1438
0
      }
1439
0
    }
1440
0
  }
1441
1442
0
  uint64_t total = 0;
1443
0
  for (i = 0; i < esq_sz; i++) total += esq[i];
1444
1445
  // Output hordist and verdist arrays are normalized 1D projections of esq
1446
0
  if (total == 0) {
1447
0
    float hor_val = 1.0f / esq_w;
1448
0
    for (j = 0; j < esq_w - 1; j++) hordist[j] = hor_val;
1449
0
    float ver_val = 1.0f / esq_h;
1450
0
    for (i = 0; i < esq_h - 1; i++) verdist[i] = ver_val;
1451
0
    return;
1452
0
  }
1453
1454
0
  const float e_recip = 1.0f / (float)total;
1455
0
  memset(hordist, 0, (esq_w - 1) * sizeof(hordist[0]));
1456
0
  memset(verdist, 0, (esq_h - 1) * sizeof(verdist[0]));
1457
0
  const unsigned int *cur_esq_row;
1458
0
  for (i = 0; i < esq_h - 1; i++) {
1459
0
    cur_esq_row = esq + i * esq_w;
1460
0
    for (j = 0; j < esq_w - 1; j++) {
1461
0
      hordist[j] += (float)cur_esq_row[j];
1462
0
      verdist[i] += (float)cur_esq_row[j];
1463
0
    }
1464
0
    verdist[i] += (float)cur_esq_row[j];
1465
0
  }
1466
0
  cur_esq_row = esq + i * esq_w;
1467
0
  for (j = 0; j < esq_w - 1; j++) hordist[j] += (float)cur_esq_row[j];
1468
1469
0
  for (j = 0; j < esq_w - 1; j++) hordist[j] *= e_recip;
1470
0
  for (i = 0; i < esq_h - 1; i++) verdist[i] *= e_recip;
1471
0
}
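A compact sketch of the projection above, omitting the downscaling path for
dimensions above 8 (illustrative name, not part of the encoder; like the real
function, only the first bw-1 and bh-1 bins are written):

#include <stdint.h>

// Square each residual, accumulate per column and per row, then normalize by
// the total energy so the bins form probability-like 1D profiles. An all-zero
// residue falls back to a uniform distribution, as in the code above.
static void sketch_energy_profiles(const int16_t *diff, int stride, int bw,
                                   int bh, float *hordist, float *verdist) {
  uint64_t total = 0;
  float col_e[64] = { 0.0f }, row_e[64] = { 0.0f };
  for (int i = 0; i < bh; i++) {
    for (int j = 0; j < bw; j++) {
      const int d = diff[i * stride + j];
      col_e[j] += (float)(d * d);
      row_e[i] += (float)(d * d);
      total += (uint64_t)(d * d);
    }
  }
  for (int j = 0; j < bw - 1; j++)
    hordist[j] = total ? col_e[j] / (float)total : 1.0f / bw;
  for (int i = 0; i < bh - 1; i++)
    verdist[i] = total ? row_e[i] / (float)total : 1.0f / bh;
}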
1472
1473
0
static inline bool check_bit_mask(uint16_t mask, int val) {
1474
0
  return mask & (1 << val);
1475
0
}
1476
1477
0
static inline void set_bit_mask(uint16_t *mask, int val) {
1478
0
  *mask |= (1 << val);
1479
0
}
1480
1481
0
static inline void unset_bit_mask(uint16_t *mask, int val) {
1482
0
  *mask &= ~(1 << val);
1483
0
}
1484
1485
static void prune_tx_2D(MACROBLOCK *x, BLOCK_SIZE bsize, TX_SIZE tx_size,
1486
                        int blk_row, int blk_col, TxSetType tx_set_type,
1487
                        TX_TYPE_PRUNE_MODE prune_2d_txfm_mode, int *txk_map,
1488
0
                        uint16_t *allowed_tx_mask) {
1489
  // This table is used because the search order is different from the enum
1490
  // order.
1491
0
  static const int tx_type_table_2D[16] = {
1492
0
    DCT_DCT,      DCT_ADST,      DCT_FLIPADST,      V_DCT,
1493
0
    ADST_DCT,     ADST_ADST,     ADST_FLIPADST,     V_ADST,
1494
0
    FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
1495
0
    H_DCT,        H_ADST,        H_FLIPADST,        IDTX
1496
0
  };
1497
0
  if (tx_set_type != EXT_TX_SET_ALL16 &&
1498
0
      tx_set_type != EXT_TX_SET_DTT9_IDTX_1DDCT)
1499
0
    return;
1500
#if CONFIG_NN_V2
1501
  NN_CONFIG_V2 *nn_config_hor = av1_tx_type_nnconfig_map_hor[tx_size];
1502
  NN_CONFIG_V2 *nn_config_ver = av1_tx_type_nnconfig_map_ver[tx_size];
1503
#else
1504
0
  const NN_CONFIG *nn_config_hor = av1_tx_type_nnconfig_map_hor[tx_size];
1505
0
  const NN_CONFIG *nn_config_ver = av1_tx_type_nnconfig_map_ver[tx_size];
1506
0
#endif
1507
0
  if (!nn_config_hor || !nn_config_ver) return;  // Model not established yet.
1508
1509
0
  float hfeatures[16], vfeatures[16];
1510
0
  float hscores[4], vscores[4];
1511
0
  float scores_2D_raw[16];
1512
0
  const int bw = tx_size_wide[tx_size];
1513
0
  const int bh = tx_size_high[tx_size];
1514
0
  const int hfeatures_num = bw <= 8 ? bw : bw / 2;
1515
0
  const int vfeatures_num = bh <= 8 ? bh : bh / 2;
1516
0
  assert(hfeatures_num <= 16);
1517
0
  assert(vfeatures_num <= 16);
1518
1519
0
  const struct macroblock_plane *const p = &x->plane[0];
1520
0
  const int diff_stride = block_size_wide[bsize];
1521
0
  const int16_t *diff = p->src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
1522
0
  get_energy_distribution_finer(diff, diff_stride, bw, bh, hfeatures,
1523
0
                                vfeatures);
1524
1525
0
  av1_get_horver_correlation_full(diff, diff_stride, bw, bh,
1526
0
                                  &hfeatures[hfeatures_num - 1],
1527
0
                                  &vfeatures[vfeatures_num - 1]);
1528
1529
#if CONFIG_NN_V2
1530
  av1_nn_predict_v2(hfeatures, nn_config_hor, 0, hscores);
1531
  av1_nn_predict_v2(vfeatures, nn_config_ver, 0, vscores);
1532
#else
1533
0
  av1_nn_predict(hfeatures, nn_config_hor, 1, hscores);
1534
0
  av1_nn_predict(vfeatures, nn_config_ver, 1, vscores);
1535
0
#endif
1536
1537
0
  for (int i = 0; i < 4; i++) {
1538
0
    float *cur_scores_2D = scores_2D_raw + i * 4;
1539
0
    cur_scores_2D[0] = vscores[i] * hscores[0];
1540
0
    cur_scores_2D[1] = vscores[i] * hscores[1];
1541
0
    cur_scores_2D[2] = vscores[i] * hscores[2];
1542
0
    cur_scores_2D[3] = vscores[i] * hscores[3];
1543
0
  }
1544
1545
0
  assert(TX_TYPES == 16);
1546
  // This version of the function only works when there are at most 16 classes.
1547
  // So we will need to change the optimization or use av1_nn_softmax instead if
1548
  // this ever gets changed.
1549
0
  av1_nn_fast_softmax_16(scores_2D_raw, scores_2D_raw);
1550
1551
0
  const float score_thresh =
1552
0
      get_adaptive_thresholds(tx_size, tx_set_type, prune_2d_txfm_mode);
1553
1554
  // Always keep the TX type with the highest score, prune all others with
1555
  // score below score_thresh.
1556
0
  int max_score_i = 0;
1557
0
  float max_score = 0.0f;
1558
0
  uint16_t allow_bitmask = 0;
1559
0
  float sum_score = 0.0;
1560
  // Calculate the sum of the allowed tx type scores and populate the allow
1561
  // bit mask based on score_thresh and allowed_tx_mask
1562
0
  int allow_count = 0;
1563
0
  int tx_type_allowed[16] = { TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
1564
0
                              TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
1565
0
                              TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
1566
0
                              TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
1567
0
                              TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
1568
0
                              TX_TYPE_INVALID };
1569
0
  float scores_2D[16] = {
1570
0
    -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
1571
0
  };
1572
0
  for (int tx_idx = 0; tx_idx < TX_TYPES; tx_idx++) {
1573
0
    const int allow_tx_type =
1574
0
        check_bit_mask(*allowed_tx_mask, tx_type_table_2D[tx_idx]);
1575
0
    if (!allow_tx_type) {
1576
0
      continue;
1577
0
    }
1578
0
    if (scores_2D_raw[tx_idx] > max_score) {
1579
0
      max_score = scores_2D_raw[tx_idx];
1580
0
      max_score_i = tx_idx;
1581
0
    }
1582
0
    if (scores_2D_raw[tx_idx] >= score_thresh) {
1583
      // Set allow mask based on score_thresh
1584
0
      set_bit_mask(&allow_bitmask, tx_type_table_2D[tx_idx]);
1585
1586
      // Accumulate score of allowed tx type
1587
0
      sum_score += scores_2D_raw[tx_idx];
1588
1589
0
      scores_2D[allow_count] = scores_2D_raw[tx_idx];
1590
0
      tx_type_allowed[allow_count] = tx_type_table_2D[tx_idx];
1591
0
      allow_count += 1;
1592
0
    }
1593
0
  }
1594
0
  if (!check_bit_mask(allow_bitmask, tx_type_table_2D[max_score_i])) {
1595
    // If even the tx_type with max score is pruned, this means that no other
1596
    // tx_type is feasible. When this happens, we force enable max_score_i and
1597
    // end the search.
1598
0
    set_bit_mask(&allow_bitmask, tx_type_table_2D[max_score_i]);
1599
0
    memcpy(txk_map, tx_type_table_2D, sizeof(tx_type_table_2D));
1600
0
    *allowed_tx_mask = allow_bitmask;
1601
0
    return;
1602
0
  }
1603
1604
  // Sort the probability scores of all allowed tx types
1605
0
  if (allow_count <= 8) {
1606
0
    av1_sort_fi32_8(scores_2D, tx_type_allowed);
1607
0
  } else {
1608
0
    av1_sort_fi32_16(scores_2D, tx_type_allowed);
1609
0
  }
1610
1611
  // Enable more pruning based on tx type probability and number of allowed tx
1612
  // types
1613
0
  if (prune_2d_txfm_mode >= TX_TYPE_PRUNE_4) {
1614
0
    float temp_score = 0.0;
1615
0
    float score_ratio = 0.0;
1616
0
    int tx_idx, tx_count = 0;
1617
0
    const float inv_sum_score = 100 / sum_score;
1618
    // Get allowed tx types based on sorted probability score and tx count
1619
0
    for (tx_idx = 0; tx_idx < allow_count; tx_idx++) {
1620
      // Stop once the cumulative probability of the allowed tx types exceeds
1621
      // 30% and at least 2 tx types have already been allowed
1622
0
      if (score_ratio > 30.0 && tx_count >= 2) break;
1623
1624
0
      assert(check_bit_mask(allow_bitmask, tx_type_allowed[tx_idx]));
1625
      // Calculate cumulative probability
1626
0
      temp_score += scores_2D[tx_idx];
1627
1628
      // Calculate percentage of cumulative probability of allowed tx type
1629
0
      score_ratio = temp_score * inv_sum_score;
1630
0
      tx_count++;
1631
0
    }
1632
    // Set remaining tx types as pruned
1633
0
    for (; tx_idx < allow_count; tx_idx++)
1634
0
      unset_bit_mask(&allow_bitmask, tx_type_allowed[tx_idx]);
1635
0
  }
1636
1637
0
  memcpy(txk_map, tx_type_allowed, sizeof(tx_type_table_2D));
1638
0
  *allowed_tx_mask = allow_bitmask;
1639
0
}
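The TX_TYPE_PRUNE_4 stage above keeps only a short prefix of the
probability-sorted types. A standalone sketch of that cut-off rule
(illustrative name, not part of the encoder):

// Given softmax scores sorted in descending order, return how many leading
// TX types to keep: stop once the kept types cover more than 30% of the
// total score mass and at least 2 types have already been kept.
static int sketch_keep_count(const float *sorted_scores, int n,
                             float sum_score) {
  const float inv_sum_score = 100.0f / sum_score;  // convert to percent
  float temp_score = 0.0f, score_ratio = 0.0f;
  int kept = 0;
  for (int i = 0; i < n; i++) {
    if (score_ratio > 30.0f && kept >= 2) break;
    temp_score += sorted_scores[i];
    score_ratio = temp_score * inv_sum_score;
    kept++;
  }
  return kept;  // entries at index >= kept are removed from the allow mask
}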
1640
1641
0
static float get_dev(float mean, double x2_sum, int num) {
1642
0
  const float e_x2 = (float)(x2_sum / num);
1643
0
  const float diff = e_x2 - mean * mean;
1644
0
  const float dev = (diff > 0) ? sqrtf(diff) : 0;
1645
0
  return dev;
1646
0
}
1647
1648
// Writes the features required by the ML model to predict tx split based on
1649
// mean and standard deviation values of the block and sub-blocks.
1650
// Returns the number of elements written to the output array, which is at
1652
// most 12 currently. Hence the 'features' buffer should be able to
1653
// accommodate at least 12 elements.
1653
static inline int get_mean_dev_features(const int16_t *data, int stride, int bw,
1654
0
                                        int bh, float *features) {
1655
0
  const int16_t *const data_ptr = &data[0];
1656
0
  const int subh = (bh >= bw) ? (bh >> 1) : bh;
1657
0
  const int subw = (bw >= bh) ? (bw >> 1) : bw;
1658
0
  const int num = bw * bh;
1659
0
  const int sub_num = subw * subh;
1660
0
  int feature_idx = 2;
1661
0
  int total_x_sum = 0;
1662
0
  int64_t total_x2_sum = 0;
1663
0
  int num_sub_blks = 0;
1664
0
  double mean2_sum = 0.0f;
1665
0
  float dev_sum = 0.0f;
1666
1667
0
  for (int row = 0; row < bh; row += subh) {
1668
0
    for (int col = 0; col < bw; col += subw) {
1669
0
      int x_sum;
1670
0
      int64_t x2_sum;
1671
      // TODO(any): Write a SIMD version. Clear registers.
1672
0
      aom_get_blk_sse_sum(data_ptr + row * stride + col, stride, subw, subh,
1673
0
                          &x_sum, &x2_sum);
1674
0
      total_x_sum += x_sum;
1675
0
      total_x2_sum += x2_sum;
1676
1677
0
      const float mean = (float)x_sum / sub_num;
1678
0
      const float dev = get_dev(mean, (double)x2_sum, sub_num);
1679
0
      features[feature_idx++] = mean;
1680
0
      features[feature_idx++] = dev;
1681
0
      mean2_sum += (double)(mean * mean);
1682
0
      dev_sum += dev;
1683
0
      num_sub_blks++;
1684
0
    }
1685
0
  }
1686
1687
0
  const float lvl0_mean = (float)total_x_sum / num;
1688
0
  features[0] = lvl0_mean;
1689
0
  features[1] = get_dev(lvl0_mean, (double)total_x2_sum, num);
1690
1691
  // Deviation of means.
1692
0
  features[feature_idx++] = get_dev(lvl0_mean, mean2_sum, num_sub_blks);
1693
  // Mean of deviations.
1694
0
  features[feature_idx++] = dev_sum / num_sub_blks;
1695
1696
0
  return feature_idx;
1697
0
}
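The 12-element bound in the comment above follows directly from the sub-block
grid; a quick arithmetic check (hypothetical helper, not in the encoder):

// The longer dimension is halved, so square blocks yield a 2x2 sub-grid
// (4 sub-blocks) while rectangular blocks yield a 1x2 or 2x1 grid (2).
static int sketch_feature_count(int bw, int bh) {
  const int num_sub = (bw == bh) ? 4 : 2;
  // 2 block-level stats + (mean, dev) per sub-block
  // + deviation-of-means + mean-of-deviations.
  return 2 + 2 * num_sub + 2;  // 12 for square blocks, 8 for rectangular
}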
1698
1699
static int ml_predict_tx_split(MACROBLOCK *x, BLOCK_SIZE bsize, int blk_row,
1700
0
                               int blk_col, TX_SIZE tx_size) {
1701
0
  const NN_CONFIG *nn_config = av1_tx_split_nnconfig_map[tx_size];
1702
0
  if (!nn_config) return -1;
1703
1704
0
  const int diff_stride = block_size_wide[bsize];
1705
0
  const int16_t *diff =
1706
0
      x->plane[0].src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
1707
0
  const int bw = tx_size_wide[tx_size];
1708
0
  const int bh = tx_size_high[tx_size];
1709
1710
0
  float features[64] = { 0.0f };
1711
0
  get_mean_dev_features(diff, diff_stride, bw, bh, features);
1712
1713
0
  float score = 0.0f;
1714
0
  av1_nn_predict(features, nn_config, 1, &score);
1715
1716
0
  int int_score = (int)(score * 10000);
1717
0
  return clamp(int_score, -80000, 80000);
1718
0
}
1719
1720
static inline uint16_t get_tx_mask(
1721
    const AV1_COMP *cpi, MACROBLOCK *x, int plane, int block, int blk_row,
1722
    int blk_col, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
1723
    const TXB_CTX *const txb_ctx, FAST_TX_SEARCH_MODE ftxs_mode,
1724
0
    int64_t ref_best_rd, TX_TYPE *allowed_txk_types, int *txk_map) {
1725
0
  const AV1_COMMON *cm = &cpi->common;
1726
0
  MACROBLOCKD *xd = &x->e_mbd;
1727
0
  MB_MODE_INFO *mbmi = xd->mi[0];
1728
0
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
1729
0
  const int is_inter = is_inter_block(mbmi);
1730
0
  const int fast_tx_search = ftxs_mode & FTXS_DCT_AND_1D_DCT_ONLY;
1731
  // If txk_allowed == TX_TYPES, more than one tx type is allowed; otherwise,
1733
  // if txk_allowed < TX_TYPES, only that specific tx type is allowed.
1733
0
  TX_TYPE txk_allowed = TX_TYPES;
1734
1735
0
  const FRAME_UPDATE_TYPE update_type =
1736
0
      get_frame_update_type(&cpi->ppi->gf_group, cpi->gf_frame_index);
1737
0
  int use_actual_frame_probs = 1;
1738
0
  const int *tx_type_probs;
1739
#if CONFIG_FPMT_TEST
1740
  use_actual_frame_probs =
1741
      (cpi->ppi->fpmt_unit_test_cfg == PARALLEL_SIMULATION_ENCODE) ? 0 : 1;
1742
  if (!use_actual_frame_probs) {
1743
    tx_type_probs =
1744
        (int *)cpi->ppi->temp_frame_probs.tx_type_probs[update_type][tx_size];
1745
  }
1746
#endif
1747
0
  if (use_actual_frame_probs) {
1748
0
    tx_type_probs = cpi->ppi->frame_probs.tx_type_probs[update_type][tx_size];
1749
0
  }
1750
1751
0
  if ((!is_inter && txfm_params->use_default_intra_tx_type) ||
1752
0
      (is_inter && txfm_params->default_inter_tx_type_prob_thresh == 0)) {
1753
0
    txk_allowed =
1754
0
        get_default_tx_type(0, xd, tx_size, cpi->use_screen_content_tools);
1755
0
  } else if (is_inter &&
1756
0
             txfm_params->default_inter_tx_type_prob_thresh != INT_MAX) {
1757
0
    if (tx_type_probs[DEFAULT_INTER_TX_TYPE] >
1758
0
        txfm_params->default_inter_tx_type_prob_thresh) {
1759
0
      txk_allowed = DEFAULT_INTER_TX_TYPE;
1760
0
    } else {
1761
0
      int force_tx_type = 0;
1762
0
      int max_prob = 0;
1763
0
      const int tx_type_prob_threshold =
1764
0
          txfm_params->default_inter_tx_type_prob_thresh +
1765
0
          PROB_THRESH_OFFSET_TX_TYPE;
1766
0
      for (int i = 1; i < TX_TYPES; i++) {  // find maximum probability.
1767
0
        if (tx_type_probs[i] > max_prob) {
1768
0
          max_prob = tx_type_probs[i];
1769
0
          force_tx_type = i;
1770
0
        }
1771
0
      }
1772
0
      if (max_prob > tx_type_prob_threshold)  // force tx type with max prob.
1773
0
        txk_allowed = force_tx_type;
1774
0
      else if (x->rd_model == LOW_TXFM_RD) {
1775
0
        if (plane == 0) txk_allowed = DCT_DCT;
1776
0
      }
1777
0
    }
1778
0
  } else if (x->rd_model == LOW_TXFM_RD) {
1779
0
    if (plane == 0) txk_allowed = DCT_DCT;
1780
0
  }
1781
1782
0
  const TxSetType tx_set_type = av1_get_ext_tx_set_type(
1783
0
      tx_size, is_inter, cm->features.reduced_tx_set_used);
1784
1785
0
  TX_TYPE uv_tx_type = DCT_DCT;
1786
0
  if (plane) {
1787
    // tx_type of PLANE_TYPE_UV should be the same as PLANE_TYPE_Y
1788
0
    uv_tx_type = txk_allowed =
1789
0
        av1_get_tx_type(xd, get_plane_type(plane), blk_row, blk_col, tx_size,
1790
0
                        cm->features.reduced_tx_set_used);
1791
0
  }
1792
0
  PREDICTION_MODE intra_dir =
1793
0
      mbmi->filter_intra_mode_info.use_filter_intra
1794
0
          ? fimode_to_intradir[mbmi->filter_intra_mode_info.filter_intra_mode]
1795
0
          : mbmi->mode;
1796
0
  uint16_t ext_tx_used_flag =
1797
0
      cpi->sf.tx_sf.tx_type_search.use_reduced_intra_txset != 0 &&
1798
0
              tx_set_type == EXT_TX_SET_DTT4_IDTX_1DDCT
1799
0
          ? av1_reduced_intra_tx_used_flag[intra_dir]
1800
0
          : av1_ext_tx_used_flag[tx_set_type];
1801
1802
0
  if (cpi->sf.tx_sf.tx_type_search.use_reduced_intra_txset == 2)
1803
0
    ext_tx_used_flag &= av1_derived_intra_tx_used_flag[intra_dir];
1804
1805
0
  if (xd->lossless[mbmi->segment_id] || txsize_sqr_up_map[tx_size] > TX_32X32 ||
1806
0
      ext_tx_used_flag == 0x0001 ||
1807
0
      (is_inter && cpi->oxcf.txfm_cfg.use_inter_dct_only) ||
1808
0
      (!is_inter && cpi->oxcf.txfm_cfg.use_intra_dct_only)) {
1809
0
    txk_allowed = DCT_DCT;
1810
0
  }
1811
1812
0
  if (cpi->oxcf.txfm_cfg.enable_flip_idtx == 0)
1813
0
    ext_tx_used_flag &= DCT_ADST_TX_MASK;
1814
1815
0
  uint16_t allowed_tx_mask = 0;  // 1: allow; 0: skip.
1816
0
  if (txk_allowed < TX_TYPES) {
1817
0
    allowed_tx_mask = 1 << txk_allowed;
1818
0
    allowed_tx_mask &= ext_tx_used_flag;
1819
0
  } else if (fast_tx_search) {
1820
0
    allowed_tx_mask = 0x0c01;  // V_DCT, H_DCT, DCT_DCT
1821
0
    allowed_tx_mask &= ext_tx_used_flag;
1822
0
  } else {
1823
0
    assert(plane == 0);
1824
0
    allowed_tx_mask = ext_tx_used_flag;
1825
0
    int num_allowed = 0;
1826
0
    int i;
1827
1828
0
    if (cpi->sf.tx_sf.tx_type_search.prune_tx_type_using_stats) {
1829
0
      static const int thresh_arr[2][7] = { { 10, 15, 15, 10, 15, 15, 15 },
1830
0
                                            { 10, 17, 17, 10, 17, 17, 17 } };
1831
0
      const int thresh =
1832
0
          thresh_arr[cpi->sf.tx_sf.tx_type_search.prune_tx_type_using_stats - 1]
1833
0
                    [update_type];
1834
0
      uint16_t prune = 0;
1835
0
      int max_prob = -1;
1836
0
      int max_idx = 0;
1837
0
      for (i = 0; i < TX_TYPES; i++) {
1838
0
        if (tx_type_probs[i] > max_prob && (allowed_tx_mask & (1 << i))) {
1839
0
          max_prob = tx_type_probs[i];
1840
0
          max_idx = i;
1841
0
        }
1842
0
        if (tx_type_probs[i] < thresh) prune |= (1 << i);
1843
0
      }
1844
0
      if ((prune >> max_idx) & 0x01) prune &= ~(1 << max_idx);
1845
0
      allowed_tx_mask &= (~prune);
1846
0
    }
1847
0
    for (i = 0; i < TX_TYPES; i++) {
1848
0
      if (allowed_tx_mask & (1 << i)) num_allowed++;
1849
0
    }
1850
0
    assert(num_allowed > 0);
1851
1852
0
    if (num_allowed > 2 && cpi->sf.tx_sf.tx_type_search.prune_tx_type_est_rd) {
1853
0
      int pf = prune_factors[txfm_params->prune_2d_txfm_mode];
1854
0
      int mf = mul_factors[txfm_params->prune_2d_txfm_mode];
1855
0
      if (num_allowed <= 7) {
1856
0
        const uint16_t prune =
1857
0
            prune_txk_type(cpi, x, plane, block, tx_size, blk_row, blk_col,
1858
0
                           plane_bsize, txk_map, allowed_tx_mask, pf, txb_ctx,
1859
0
                           cm->features.reduced_tx_set_used);
1860
0
        allowed_tx_mask &= (~prune);
1861
0
      } else {
1862
0
        const int num_sel = (num_allowed * mf + 50) / 100;
1863
0
        const uint16_t prune = prune_txk_type_separ(
1864
0
            cpi, x, plane, block, tx_size, blk_row, blk_col, plane_bsize,
1865
0
            txk_map, allowed_tx_mask, pf, txb_ctx,
1866
0
            cm->features.reduced_tx_set_used, ref_best_rd, num_sel);
1867
1868
0
        allowed_tx_mask &= (~prune);
1869
0
      }
1870
0
    } else {
1871
0
      assert(num_allowed > 0);
1872
0
      int allowed_tx_count =
1873
0
          (txfm_params->prune_2d_txfm_mode >= TX_TYPE_PRUNE_4) ? 1 : 5;
1874
      // !fast_tx_search && txk_end != txk_start && plane == 0
1875
0
      if (txfm_params->prune_2d_txfm_mode >= TX_TYPE_PRUNE_1 && is_inter &&
1876
0
          num_allowed > allowed_tx_count) {
1877
0
        prune_tx_2D(x, plane_bsize, tx_size, blk_row, blk_col, tx_set_type,
1878
0
                    txfm_params->prune_2d_txfm_mode, txk_map, &allowed_tx_mask);
1879
0
      }
1880
0
    }
1881
0
  }
1882
1883
  // Need to have at least one transform type allowed.
1884
0
  if (allowed_tx_mask == 0) {
1885
0
    txk_allowed = (plane ? uv_tx_type : DCT_DCT);
1886
0
    allowed_tx_mask = (1 << txk_allowed);
1887
0
  }
1888
1889
0
  assert(IMPLIES(txk_allowed < TX_TYPES, allowed_tx_mask == 1 << txk_allowed));
1890
0
  *allowed_txk_types = txk_allowed;
1891
0
  return allowed_tx_mask;
1892
0
}
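A worked example of the mask arithmetic above: 0x0c01 sets bits 0, 10 and 11,
which under the AV1 TX_TYPE enum (DCT_DCT == 0, V_DCT == 10, H_DCT == 11) is
exactly the DCT_DCT / V_DCT / H_DCT trio named in the fast-search comment.
A one-line sanity check (hypothetical helper):

#include <assert.h>
#include <stdint.h>

static void sketch_check_fast_tx_mask(void) {
  // 1: allow; 0: skip -- the same convention as allowed_tx_mask above.
  const uint16_t fast_mask = 0x0c01;
  assert(fast_mask == ((1u << 0) | (1u << 10) | (1u << 11)));
}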
1893
1894
#if CONFIG_RD_DEBUG
1895
static inline void update_txb_coeff_cost(RD_STATS *rd_stats, int plane,
1896
                                         int txb_coeff_cost) {
1897
  rd_stats->txb_coeff_cost[plane] += txb_coeff_cost;
1898
}
1899
#endif
1900
1901
static inline int cost_coeffs(MACROBLOCK *x, int plane, int block,
1902
                              TX_SIZE tx_size, const TX_TYPE tx_type,
1903
                              const TXB_CTX *const txb_ctx,
1904
0
                              int reduced_tx_set_used) {
1905
#if TXCOEFF_COST_TIMER
1906
  struct aom_usec_timer timer;
1907
  aom_usec_timer_start(&timer);
1908
#endif
1909
0
  const int cost = av1_cost_coeffs_txb(x, plane, block, tx_size, tx_type,
1910
0
                                       txb_ctx, reduced_tx_set_used);
1911
#if TXCOEFF_COST_TIMER
1912
  AV1_COMMON *tmp_cm = (AV1_COMMON *)&cpi->common;
1913
  aom_usec_timer_mark(&timer);
1914
  const int64_t elapsed_time = aom_usec_timer_elapsed(&timer);
1915
  tmp_cm->txcoeff_cost_timer += elapsed_time;
1916
  ++tmp_cm->txcoeff_cost_count;
1917
#endif
1918
0
  return cost;
1919
0
}
1920
1921
static int skip_trellis_opt_based_on_satd(MACROBLOCK *x,
1922
                                          QUANT_PARAM *quant_param, int plane,
1923
                                          int block, TX_SIZE tx_size,
1924
                                          int quant_b_adapt, int qstep,
1925
                                          unsigned int coeff_opt_satd_threshold,
1926
0
                                          int skip_trellis, int dc_only_blk) {
1927
0
  if (skip_trellis || (coeff_opt_satd_threshold == UINT_MAX))
1928
0
    return skip_trellis;
1929
1930
0
  const struct macroblock_plane *const p = &x->plane[plane];
1931
0
  const int block_offset = BLOCK_OFFSET(block);
1932
0
  tran_low_t *const coeff_ptr = p->coeff + block_offset;
1933
0
  const int n_coeffs = av1_get_max_eob(tx_size);
1934
0
  const int shift = (MAX_TX_SCALE - av1_get_tx_scale(tx_size));
1935
0
  int satd = (dc_only_blk) ? abs(coeff_ptr[0]) : aom_satd(coeff_ptr, n_coeffs);
1936
0
  satd = RIGHT_SIGNED_SHIFT(satd, shift);
1937
0
  satd >>= (x->e_mbd.bd - 8);
1938
1939
0
  const int skip_block_trellis =
1940
0
      ((uint64_t)satd >
1941
0
       (uint64_t)coeff_opt_satd_threshold * qstep * sqrt_tx_pixels_2d[tx_size]);
1942
1943
0
  av1_setup_quant(
1944
0
      tx_size, !skip_block_trellis,
1945
0
      skip_block_trellis
1946
0
          ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B : AV1_XFORM_QUANT_FP)
1947
0
          : AV1_XFORM_QUANT_FP,
1948
0
      quant_b_adapt, quant_param);
1949
1950
0
  return skip_block_trellis;
1951
0
}
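A condensed sketch of the gate above (illustrative name, not the encoder's
implementation): trellis optimization is kept only when the normalized SATD is
small relative to the quantization step, scaled by the square root of the
block's pixel count.

#include <stdint.h>

// Returns 1 when trellis coefficient optimization should be skipped. `shift`
// undoes the forward-transform scaling (the encoder uses a rounding shift
// there) and `bd` is the bit depth, used to normalize to an 8-bit scale.
static int sketch_skip_trellis(int satd, int shift, int bd, int qstep,
                               unsigned threshold, int sqrt_px) {
  satd >>= shift;
  satd >>= (bd - 8);
  return (uint64_t)satd > (uint64_t)threshold * qstep * sqrt_px;
}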
1952
1953
// Predict DC only blocks if the residual variance is below a qstep-based
1954
// threshold. For such blocks, transform type search is bypassed.
1955
static inline void predict_dc_only_block(
1956
    MACROBLOCK *x, int plane, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
1957
    int block, int blk_row, int blk_col, RD_STATS *best_rd_stats,
1958
    int64_t *block_sse, unsigned int *block_mse_q8, int64_t *per_px_mean,
1959
0
    int *dc_only_blk) {
1960
0
  MACROBLOCKD *xd = &x->e_mbd;
1961
0
  MB_MODE_INFO *mbmi = xd->mi[0];
1962
0
  const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
1963
0
  const int qstep = x->plane[plane].dequant_QTX[1] >> dequant_shift;
1964
0
  uint64_t block_var = UINT64_MAX;
1965
0
  const int dc_qstep = x->plane[plane].dequant_QTX[0] >> 3;
1966
0
  *block_sse = pixel_diff_stats(x, plane, blk_row, blk_col, plane_bsize,
1967
0
                                txsize_to_bsize[tx_size], block_mse_q8,
1968
0
                                per_px_mean, &block_var);
1969
0
  assert((*block_mse_q8) != UINT_MAX);
1970
0
  uint64_t var_threshold = (uint64_t)(1.8 * qstep * qstep);
1971
0
  if (is_cur_buf_hbd(xd))
1972
0
    block_var = ROUND_POWER_OF_TWO(block_var, (xd->bd - 8) * 2);
1973
1974
0
  if (block_var >= var_threshold) return;
1975
0
  const unsigned int predict_dc_level = x->txfm_search_params.predict_dc_level;
1976
0
  assert(predict_dc_level != 0);
1977
1978
  // Predict a skip block if the residual mean and variance are less
1979
  // than qstep-based thresholds
1980
0
  if ((llabs(*per_px_mean) * dc_coeff_scale[tx_size]) < (dc_qstep << 12)) {
1981
    // If the normalized mean of the residual block is less than the dc qstep
1982
    // and the normalized block variance is less than the ac qstep, then the
1983
    // block is assumed to be a skip block and its rdcost is updated accordingly.
1984
0
    best_rd_stats->skip_txfm = 1;
1985
1986
0
    x->plane[plane].eobs[block] = 0;
1987
1988
0
    if (is_cur_buf_hbd(xd))
1989
0
      *block_sse = ROUND_POWER_OF_TWO((*block_sse), (xd->bd - 8) * 2);
1990
1991
0
    best_rd_stats->dist = (*block_sse) << 4;
1992
0
    best_rd_stats->sse = best_rd_stats->dist;
1993
1994
0
    ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
1995
0
    ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
1996
0
    av1_get_entropy_contexts(plane_bsize, &xd->plane[plane], ctxa, ctxl);
1997
0
    ENTROPY_CONTEXT *ta = ctxa;
1998
0
    ENTROPY_CONTEXT *tl = ctxl;
1999
0
    const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
2000
0
    TXB_CTX txb_ctx_tmp;
2001
0
    const PLANE_TYPE plane_type = get_plane_type(plane);
2002
0
    get_txb_ctx(plane_bsize, tx_size, plane, ta, tl, &txb_ctx_tmp);
2003
0
    const int zero_blk_rate = x->coeff_costs.coeff_costs[txs_ctx][plane_type]
2004
0
                                  .txb_skip_cost[txb_ctx_tmp.txb_skip_ctx][1];
2005
0
    best_rd_stats->rate = zero_blk_rate;
2006
2007
0
    best_rd_stats->rdcost =
2008
0
        RDCOST(x->rdmult, best_rd_stats->rate, best_rd_stats->sse);
2009
2010
0
    x->plane[plane].txb_entropy_ctx[block] = 0;
2011
0
  } else if (predict_dc_level > 1) {
2012
    // Predict DC only blocks based on residual variance.
2013
    // For chroma plane, this prediction is disabled for intra blocks.
2014
0
    if ((plane == 0) || (plane > 0 && is_inter_block(mbmi))) *dc_only_blk = 1;
2015
0
  }
2016
0
}
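A self-contained sketch of the two quantizer-based gates above (illustrative
name; the 1.8 multiplier and the << 12 scale are taken from the code, while
the predict_dc_level > 1 condition on the DC-only path is elided):

#include <stdint.h>
#include <stdlib.h>

// 0: run the full tx type search; 1: treat as a skip block; 2: treat as a
// DC-only block (DCT_DCT only).
static int sketch_classify_block(uint64_t block_var, int64_t per_px_mean,
                                 int64_t dc_scale, int qstep, int dc_qstep) {
  const uint64_t var_threshold = (uint64_t)(1.8 * qstep * qstep);
  if (block_var >= var_threshold) return 0;
  if (llabs(per_px_mean) * dc_scale < ((int64_t)dc_qstep << 12)) return 1;
  return 2;
}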
2017
2018
// Search for the best transform type for a given transform block.
2019
// This function can be used for both inter and intra, both luma and chroma.
2020
static void search_tx_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
2021
                           int block, int blk_row, int blk_col,
2022
                           BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
2023
                           const TXB_CTX *const txb_ctx,
2024
                           FAST_TX_SEARCH_MODE ftxs_mode, int skip_trellis,
2025
0
                           int64_t ref_best_rd, RD_STATS *best_rd_stats) {
2026
0
  const AV1_COMMON *cm = &cpi->common;
2027
0
  MACROBLOCKD *xd = &x->e_mbd;
2028
0
  MB_MODE_INFO *mbmi = xd->mi[0];
2029
0
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
2030
0
  int64_t best_rd = INT64_MAX;
2031
0
  uint16_t best_eob = 0;
2032
0
  TX_TYPE best_tx_type = DCT_DCT;
2033
0
  int rate_cost = 0;
2034
0
  struct macroblock_plane *const p = &x->plane[plane];
2035
0
  tran_low_t *orig_dqcoeff = p->dqcoeff;
2036
0
  tran_low_t *best_dqcoeff = x->dqcoeff_buf;
2037
0
  const int tx_type_map_idx =
2038
0
      plane ? 0 : blk_row * xd->tx_type_map_stride + blk_col;
2039
0
  av1_invalid_rd_stats(best_rd_stats);
2040
2041
0
  skip_trellis |= !is_trellis_used(cpi->optimize_seg_arr[xd->mi[0]->segment_id],
2042
0
                                   DRY_RUN_NORMAL);
2043
2044
0
  uint8_t best_txb_ctx = 0;
2045
  // txk_allowed = TX_TYPES: >1 tx types are allowed
2046
  // txk_allowed < TX_TYPES: only that specific tx type is allowed.
2047
0
  TX_TYPE txk_allowed = TX_TYPES;
2048
0
  int txk_map[TX_TYPES] = {
2049
0
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
2050
0
  };
2051
0
  const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
2052
0
  const int qstep = x->plane[plane].dequant_QTX[1] >> dequant_shift;
2053
2054
0
  const uint8_t txw = tx_size_wide[tx_size];
2055
0
  const uint8_t txh = tx_size_high[tx_size];
2056
0
  int64_t block_sse;
2057
0
  unsigned int block_mse_q8;
2058
0
  int dc_only_blk = 0;
2059
0
  const bool predict_dc_block =
2060
0
      txfm_params->predict_dc_level >= 1 && txw != 64 && txh != 64;
2061
0
  int64_t per_px_mean = INT64_MAX;
2062
0
  if (predict_dc_block) {
2063
0
    predict_dc_only_block(x, plane, plane_bsize, tx_size, block, blk_row,
2064
0
                          blk_col, best_rd_stats, &block_sse, &block_mse_q8,
2065
0
                          &per_px_mean, &dc_only_blk);
2066
0
    if (best_rd_stats->skip_txfm == 1) {
2067
0
      const TX_TYPE tx_type = DCT_DCT;
2068
0
      if (plane == 0) xd->tx_type_map[tx_type_map_idx] = tx_type;
2069
0
      return;
2070
0
    }
2071
0
  } else {
2072
0
    block_sse = av1_pixel_diff_dist(x, plane, blk_row, blk_col, plane_bsize,
2073
0
                                    txsize_to_bsize[tx_size], &block_mse_q8);
2074
0
    assert(block_mse_q8 != UINT_MAX);
2075
0
  }
2076
2077
  // Bit mask to indicate which transform types are allowed in the RD search.
2078
0
  uint16_t tx_mask;
2079
2080
  // Use DCT_DCT transform for DC only block.
2081
0
  if (dc_only_blk || cpi->sf.rt_sf.dct_only_palette_nonrd == 1)
2082
0
    tx_mask = 1 << DCT_DCT;
2083
0
  else
2084
0
    tx_mask = get_tx_mask(cpi, x, plane, block, blk_row, blk_col, plane_bsize,
2085
0
                          tx_size, txb_ctx, ftxs_mode, ref_best_rd,
2086
0
                          &txk_allowed, txk_map);
2087
0
  const uint16_t allowed_tx_mask = tx_mask;
2088
2089
0
  if (is_cur_buf_hbd(xd)) {
2090
0
    block_sse = ROUND_POWER_OF_TWO(block_sse, (xd->bd - 8) * 2);
2091
0
    block_mse_q8 = ROUND_POWER_OF_TWO(block_mse_q8, (xd->bd - 8) * 2);
2092
0
  }
2093
0
  block_sse *= 16;
2094
  // Use mse / qstep^2 based threshold logic to decide whether to perform
2095
  // R-D optimization of coeffs. For smaller residuals, coeff optimization
2096
  // would be helpful. For larger residuals, R-D optimization may not be
2097
  // effective.
2098
  // TODO(any): Experiment with variance and mean based thresholds
2099
0
  const int perform_block_coeff_opt =
2100
0
      ((uint64_t)block_mse_q8 <=
2101
0
       (uint64_t)txfm_params->coeff_opt_thresholds[0] * qstep * qstep);
2102
0
  skip_trellis |= !perform_block_coeff_opt;
2103
2104
  // Flag to indicate whether distortion should be calculated in the transform
2105
  // domain while iterating through transform type candidates.
2106
  // Transform domain distortion is accurate for higher residuals.
2107
  // TODO(any): Experiment with variance and mean based thresholds
2108
0
  int use_transform_domain_distortion =
2109
0
      (txfm_params->use_transform_domain_distortion > 0) &&
2110
0
      (block_mse_q8 >= txfm_params->tx_domain_dist_threshold) &&
2111
      // Any 64-pt transforms only preserves half the coefficients.
2112
      // Therefore transform domain distortion is not valid for these
2113
      // transform sizes.
2114
0
      (txsize_sqr_up_map[tx_size] != TX_64X64) &&
2115
      // Use pixel domain distortion for DC only blocks
2116
0
      !dc_only_blk;
2117
  // Flag to indicate if an extra calculation of distortion in the pixel domain
2118
  // should be performed at the end, after the best transform type has been
2119
  // decided.
2120
0
  int calc_pixel_domain_distortion_final =
2121
0
      txfm_params->use_transform_domain_distortion == 1 &&
2122
0
      use_transform_domain_distortion && x->rd_model != LOW_TXFM_RD;
2123
0
  if (calc_pixel_domain_distortion_final &&
2124
0
      (txk_allowed < TX_TYPES || allowed_tx_mask == 0x0001))
2125
0
    calc_pixel_domain_distortion_final = use_transform_domain_distortion = 0;
2126
2127
0
  const uint16_t *eobs_ptr = x->plane[plane].eobs;
2128
2129
0
  TxfmParam txfm_param;
2130
0
  QUANT_PARAM quant_param;
2131
0
  int skip_trellis_based_on_satd[TX_TYPES] = { 0 };
2132
0
  av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
2133
0
  av1_setup_quant(tx_size, !skip_trellis,
2134
0
                  skip_trellis ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B
2135
0
                                                         : AV1_XFORM_QUANT_FP)
2136
0
                               : AV1_XFORM_QUANT_FP,
2137
0
                  cpi->oxcf.q_cfg.quant_b_adapt, &quant_param);
2138
2139
  // Iterate through all transform type candidates.
2140
0
  for (int idx = 0; idx < TX_TYPES; ++idx) {
2141
0
    const TX_TYPE tx_type = (TX_TYPE)txk_map[idx];
2142
0
    if (tx_type == TX_TYPE_INVALID || !check_bit_mask(allowed_tx_mask, tx_type))
2143
0
      continue;
2144
0
    txfm_param.tx_type = tx_type;
2145
0
    if (av1_use_qmatrix(&cm->quant_params, xd, mbmi->segment_id)) {
2146
0
      av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
2147
0
                        &quant_param);
2148
0
    }
2149
0
    if (plane == 0) xd->tx_type_map[tx_type_map_idx] = tx_type;
2150
0
    RD_STATS this_rd_stats;
2151
0
    av1_invalid_rd_stats(&this_rd_stats);
2152
2153
0
    if (!dc_only_blk)
2154
0
      av1_xform(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param);
2155
0
    else
2156
0
      av1_xform_dc_only(x, plane, block, &txfm_param, per_px_mean);
2157
2158
0
    skip_trellis_based_on_satd[tx_type] = skip_trellis_opt_based_on_satd(
2159
0
        x, &quant_param, plane, block, tx_size, cpi->oxcf.q_cfg.quant_b_adapt,
2160
0
        qstep, txfm_params->coeff_opt_thresholds[1], skip_trellis, dc_only_blk);
2161
2162
0
    av1_quant(x, plane, block, &txfm_param, &quant_param);
2163
2164
    // Calculate rate cost of quantized coefficients.
2165
0
    if (quant_param.use_optimize_b) {
2166
      // TODO(aomedia:3209): update Trellis quantization to take into account
2167
      // quantization matrices.
2168
0
      av1_optimize_b(cpi, x, plane, block, tx_size, tx_type, txb_ctx,
2169
0
                     &rate_cost);
2170
0
    } else {
2171
0
      rate_cost = cost_coeffs(x, plane, block, tx_size, tx_type, txb_ctx,
2172
0
                              cm->features.reduced_tx_set_used);
2173
0
    }
2174
2175
    // If rd cost based on coeff rate alone is already more than best_rd,
2176
    // terminate early.
2177
0
    if (RDCOST(x->rdmult, rate_cost, 0) > best_rd) continue;
2178
2179
    // Calculate distortion.
2180
0
    if (eobs_ptr[block] == 0) {
2181
      // When eob is 0, pixel domain distortion is more efficient and accurate.
2182
0
      this_rd_stats.dist = this_rd_stats.sse = block_sse;
2183
0
    } else if (dc_only_blk) {
2184
0
      this_rd_stats.sse = block_sse;
2185
0
      this_rd_stats.dist = dist_block_px_domain(
2186
0
          cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
2187
0
    } else if (use_transform_domain_distortion) {
2188
0
      const SCAN_ORDER *const scan_order =
2189
0
          get_scan(txfm_param.tx_size, txfm_param.tx_type);
2190
0
      dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
2191
0
                           scan_order->scan, &this_rd_stats.dist,
2192
0
                           &this_rd_stats.sse);
2193
0
    } else {
2194
0
      int64_t sse_diff = INT64_MAX;
2195
      // high_energy threshold assumes that every pixel within a txfm block
2196
      // has a residue energy of at least 25% of the maximum, i.e. 128 * 128
2197
      // for 8 bit.
2198
0
      const int64_t high_energy_thresh =
2199
0
          ((int64_t)128 * 128 * tx_size_2d[tx_size]);
2200
0
      const int is_high_energy = (block_sse >= high_energy_thresh);
2201
0
      if (tx_size == TX_64X64 || is_high_energy) {
2202
        // Because 3 out of 4 quadrants of transform coefficients are forced
2203
        // to zero, the inverse transform has a tendency to overflow. sse_diff
2204
        // is effectively the energy of those 3 quadrants; here we use it
2205
        // to decide if we should do pixel domain distortion. If the energy
2206
        // is mostly in the first quadrant, then it is unlikely that we have
2207
        // an overflow issue in the inverse transform.
2208
0
        const SCAN_ORDER *const scan_order =
2209
0
            get_scan(txfm_param.tx_size, txfm_param.tx_type);
2210
0
        dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
2211
0
                             scan_order->scan, &this_rd_stats.dist,
2212
0
                             &this_rd_stats.sse);
2213
0
        sse_diff = block_sse - this_rd_stats.sse;
2214
0
      }
2215
0
      if (tx_size != TX_64X64 || !is_high_energy ||
2216
0
          (sse_diff * 2) < this_rd_stats.sse) {
2217
0
        const int64_t tx_domain_dist = this_rd_stats.dist;
2218
0
        this_rd_stats.dist = dist_block_px_domain(
2219
0
            cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
2220
        // For high energy blocks, occasionally, the pixel domain distortion
2221
        // can be artificially low due to clamping at the reconstruction stage,
2222
        // even when the inverse transform output differs hugely from the
2223
        // actual residue.
2224
0
        if (is_high_energy && this_rd_stats.dist < tx_domain_dist)
2225
0
          this_rd_stats.dist = tx_domain_dist;
2226
0
      } else {
2227
0
        assert(sse_diff < INT64_MAX);
2228
0
        this_rd_stats.dist += sse_diff;
2229
0
      }
2230
0
      this_rd_stats.sse = block_sse;
2231
0
    }
2232
2233
0
    this_rd_stats.rate = rate_cost;
2234
2235
0
    const int64_t rd =
2236
0
        RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
2237
2238
0
    if (rd < best_rd) {
2239
0
      best_rd = rd;
2240
0
      *best_rd_stats = this_rd_stats;
2241
0
      best_tx_type = tx_type;
2242
0
      best_txb_ctx = x->plane[plane].txb_entropy_ctx[block];
2243
0
      best_eob = x->plane[plane].eobs[block];
2244
      // Swap dqcoeff buffers
2245
0
      tran_low_t *const tmp_dqcoeff = best_dqcoeff;
2246
0
      best_dqcoeff = p->dqcoeff;
2247
0
      p->dqcoeff = tmp_dqcoeff;
2248
0
    }
2249
2250
#if CONFIG_COLLECT_RD_STATS == 1
2251
    if (plane == 0) {
2252
      PrintTransformUnitStats(cpi, x, &this_rd_stats, blk_row, blk_col,
2253
                              plane_bsize, tx_size, tx_type, rd);
2254
    }
2255
#endif  // CONFIG_COLLECT_RD_STATS == 1
2256
2257
#if COLLECT_TX_SIZE_DATA
2258
    // Generate small sample to restrict output size.
2259
    static unsigned int seed = 21743;
2260
    if (lcg_rand16(&seed) % 200 == 0) {
2261
      FILE *fp = NULL;
2262
2263
      if (within_border) {
2264
        fp = fopen(av1_tx_size_data_output_file, "a");
2265
      }
2266
2267
      if (fp) {
2268
        // Transform info and RD
2269
        const int txb_w = tx_size_wide[tx_size];
2270
        const int txb_h = tx_size_high[tx_size];
2271
2272
        // Residue signal.
2273
        const int diff_stride = block_size_wide[plane_bsize];
2274
        struct macroblock_plane *const p = &x->plane[plane];
2275
        const int16_t *src_diff =
2276
            &p->src_diff[(blk_row * diff_stride + blk_col) * 4];
2277
2278
        for (int r = 0; r < txb_h; ++r) {
2279
          for (int c = 0; c < txb_w; ++c) {
2280
            fprintf(fp, "%d,", src_diff[c]);
2281
          }
2282
          src_diff += diff_stride;
2283
        }
2284
2285
        fprintf(fp, "%d,%d,%d,%" PRId64, txb_w, txb_h, tx_type, rd);
2286
        fprintf(fp, "\n");
2287
        fclose(fp);
2288
      }
2289
    }
2290
#endif  // COLLECT_TX_SIZE_DATA
2291
2292
    // If the current best RD cost is much worse than the reference RD cost,
2293
    // terminate early.
2294
0
    if (cpi->sf.tx_sf.adaptive_txb_search_level) {
2295
0
      if ((best_rd - (best_rd >> cpi->sf.tx_sf.adaptive_txb_search_level)) >
2296
0
          ref_best_rd) {
2297
0
        break;
2298
0
      }
2299
0
    }
2300
2301
    // Terminate transform type search if the block has been quantized to
2302
    // all zero.
2303
0
    if (cpi->sf.tx_sf.tx_type_search.skip_tx_search && !best_eob) break;
2304
0
  }
2305
2306
0
  assert(best_rd != INT64_MAX);
2307
2308
0
  best_rd_stats->skip_txfm = best_eob == 0;
2309
0
  if (plane == 0) update_txk_array(xd, blk_row, blk_col, tx_size, best_tx_type);
2310
0
  x->plane[plane].txb_entropy_ctx[block] = best_txb_ctx;
2311
0
  x->plane[plane].eobs[block] = best_eob;
2312
0
  skip_trellis = skip_trellis_based_on_satd[best_tx_type];
2313
2314
  // Point dqcoeff to the quantized coefficients corresponding to the best
2315
  // transform type, then we can skip transform and quantization, e.g. in the
2316
  // final pixel domain distortion calculation and recon_intra().
2317
0
  p->dqcoeff = best_dqcoeff;
2318
2319
0
  if (calc_pixel_domain_distortion_final && best_eob) {
2320
0
    best_rd_stats->dist = dist_block_px_domain(
2321
0
        cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
2322
0
    best_rd_stats->sse = block_sse;
2323
0
  }
2324
2325
  // Intra mode needs decoded pixels such that the next transform block
2326
  // can use them for prediction.
2327
0
  recon_intra(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
2328
0
              txb_ctx, skip_trellis, best_tx_type, 0, &rate_cost, best_eob);
2329
0
  p->dqcoeff = orig_dqcoeff;
2330
0
}
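Several decisions inside search_tx_type reduce to simple qstep comparisons; a
compact sketch of the coefficient-optimization gate (illustrative name,
assuming the q8 suffix denotes a Q8, i.e. x256, fixed-point MSE):

#include <stdint.h>

// R-D optimization of coefficients is only attempted when the block's mean
// squared error is small relative to the squared quantization step.
static int sketch_perform_coeff_opt(unsigned block_mse_q8,
                                    unsigned opt_threshold, int qstep) {
  return (uint64_t)block_mse_q8 <= (uint64_t)opt_threshold * qstep * qstep;
}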
2331
2332
// Pick transform type for a luma transform block of tx_size. Note this function
2333
// is used only for inter-predicted blocks.
2334
static inline void tx_type_rd(const AV1_COMP *cpi, MACROBLOCK *x,
2335
                              TX_SIZE tx_size, int blk_row, int blk_col,
2336
                              int block, int plane_bsize, TXB_CTX *txb_ctx,
2337
                              RD_STATS *rd_stats, FAST_TX_SEARCH_MODE ftxs_mode,
2338
0
                              int64_t ref_rdcost) {
2339
0
  assert(is_inter_block(x->e_mbd.mi[0]));
2340
0
  RD_STATS this_rd_stats;
2341
0
  const int skip_trellis = 0;
2342
0
  search_tx_type(cpi, x, 0, block, blk_row, blk_col, plane_bsize, tx_size,
2343
0
                 txb_ctx, ftxs_mode, skip_trellis, ref_rdcost, &this_rd_stats);
2344
2345
0
  av1_merge_rd_stats(rd_stats, &this_rd_stats);
2346
0
}
2347
2348
static inline void try_tx_block_no_split(
2349
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
2350
    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize,
2351
    const ENTROPY_CONTEXT *ta, const ENTROPY_CONTEXT *tl,
2352
    int txfm_partition_ctx, RD_STATS *rd_stats, int64_t ref_best_rd,
2353
0
    FAST_TX_SEARCH_MODE ftxs_mode, TxCandidateInfo *no_split) {
2354
0
  MACROBLOCKD *const xd = &x->e_mbd;
2355
0
  MB_MODE_INFO *const mbmi = xd->mi[0];
2356
0
  struct macroblock_plane *const p = &x->plane[0];
2357
0
  const int bw = mi_size_wide[plane_bsize];
2358
0
  const ENTROPY_CONTEXT *const pta = ta + blk_col;
2359
0
  const ENTROPY_CONTEXT *const ptl = tl + blk_row;
2360
0
  const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
2361
0
  TXB_CTX txb_ctx;
2362
0
  get_txb_ctx(plane_bsize, tx_size, 0, pta, ptl, &txb_ctx);
2363
0
  const int zero_blk_rate = x->coeff_costs.coeff_costs[txs_ctx][PLANE_TYPE_Y]
2364
0
                                .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
2365
0
  rd_stats->zero_rate = zero_blk_rate;
2366
0
  const int index = av1_get_txb_size_index(plane_bsize, blk_row, blk_col);
2367
0
  mbmi->inter_tx_size[index] = tx_size;
2368
0
  tx_type_rd(cpi, x, tx_size, blk_row, blk_col, block, plane_bsize, &txb_ctx,
2369
0
             rd_stats, ftxs_mode, ref_best_rd);
2370
0
  assert(rd_stats->rate < INT_MAX);
2371
2372
0
  const int pick_skip_txfm =
2373
0
      !xd->lossless[mbmi->segment_id] &&
2374
0
      (rd_stats->skip_txfm == 1 ||
2375
0
       RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >=
2376
0
           RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse));
2377
0
  if (pick_skip_txfm) {
2378
#if CONFIG_RD_DEBUG
2379
    update_txb_coeff_cost(rd_stats, 0, zero_blk_rate - rd_stats->rate);
2380
#endif  // CONFIG_RD_DEBUG
2381
0
    rd_stats->rate = zero_blk_rate;
2382
0
    rd_stats->dist = rd_stats->sse;
2383
0
    p->eobs[block] = 0;
2384
0
    update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
2385
0
  }
2386
0
  rd_stats->skip_txfm = pick_skip_txfm;
2387
0
  set_blk_skip(x->txfm_search_info.blk_skip, 0, blk_row * bw + blk_col,
2388
0
               pick_skip_txfm);
2389
2390
0
  if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
2391
0
    rd_stats->rate += x->mode_costs.txfm_partition_cost[txfm_partition_ctx][0];
2392
2393
0
  no_split->rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
2394
0
  no_split->txb_entropy_ctx = p->txb_entropy_ctx[block];
2395
0
  no_split->tx_type =
2396
0
      xd->tx_type_map[blk_row * xd->tx_type_map_stride + blk_col];
2397
0
}
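The pick_skip_txfm decision above is a straight RDCOST comparison: coding the
block as all-zero costs zero_blk_rate bits with the full SSE as distortion,
and the coded coefficients are kept only if they beat that trade. In sketch
form (illustrative name; the two inputs are the RDCOST values computed above):

#include <stdint.h>

static int sketch_pick_skip_txfm(int64_t coded_rd, int64_t zero_rd,
                                 int already_skip, int lossless) {
  if (lossless) return 0;  // lossless blocks are never forced to zero
  return already_skip || coded_rd >= zero_rd;
}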
2398
2399
static inline void try_tx_block_split(
2400
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
2401
    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
2402
    ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
2403
    int txfm_partition_ctx, int64_t no_split_rd, int64_t ref_best_rd,
2404
0
    FAST_TX_SEARCH_MODE ftxs_mode, RD_STATS *split_rd_stats) {
2405
0
  assert(tx_size < TX_SIZES_ALL);
2406
0
  MACROBLOCKD *const xd = &x->e_mbd;
2407
0
  const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
2408
0
  const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);
2409
0
  const int txb_width = tx_size_wide_unit[tx_size];
2410
0
  const int txb_height = tx_size_high_unit[tx_size];
2411
  // Transform size after splitting current block.
2412
0
  const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
2413
0
  const int sub_txb_width = tx_size_wide_unit[sub_txs];
2414
0
  const int sub_txb_height = tx_size_high_unit[sub_txs];
2415
0
  const int sub_step = sub_txb_width * sub_txb_height;
2416
0
  const int nblks = (txb_height / sub_txb_height) * (txb_width / sub_txb_width);
2417
0
  assert(nblks > 0);
2418
0
  av1_init_rd_stats(split_rd_stats);
2419
0
  split_rd_stats->rate =
2420
0
      x->mode_costs.txfm_partition_cost[txfm_partition_ctx][1];
2421
2422
0
  for (int r = 0, blk_idx = 0; r < txb_height; r += sub_txb_height) {
2423
0
    const int offsetr = blk_row + r;
2424
0
    if (offsetr >= max_blocks_high) break;
2425
0
    for (int c = 0; c < txb_width; c += sub_txb_width, ++blk_idx) {
2426
0
      assert(blk_idx < 4);
2427
0
      const int offsetc = blk_col + c;
2428
0
      if (offsetc >= max_blocks_wide) continue;
2429
2430
0
      RD_STATS this_rd_stats;
2431
0
      int this_cost_valid = 1;
2432
0
      select_tx_block(cpi, x, offsetr, offsetc, block, sub_txs, depth + 1,
2433
0
                      plane_bsize, ta, tl, tx_above, tx_left, &this_rd_stats,
2434
0
                      no_split_rd / nblks, ref_best_rd - split_rd_stats->rdcost,
2435
0
                      &this_cost_valid, ftxs_mode);
2436
0
      if (!this_cost_valid) {
2437
0
        split_rd_stats->rdcost = INT64_MAX;
2438
0
        return;
2439
0
      }
2440
0
      av1_merge_rd_stats(split_rd_stats, &this_rd_stats);
2441
0
      split_rd_stats->rdcost =
2442
0
          RDCOST(x->rdmult, split_rd_stats->rate, split_rd_stats->dist);
2443
0
      if (split_rd_stats->rdcost > ref_best_rd) {
2444
0
        split_rd_stats->rdcost = INT64_MAX;
2445
0
        return;
2446
0
      }
2447
0
      block += sub_step;
2448
0
    }
2449
0
  }
2450
0
}
2451
2452
0
static float get_var(float mean, double x2_sum, int num) {
2453
0
  const float e_x2 = (float)(x2_sum / num);
2454
0
  const float diff = e_x2 - mean * mean;
2455
0
  return diff;
2456
0
}
2457
2458
static inline void get_blk_var_dev(const int16_t *data, int stride, int bw,
2459
                                   int bh, float *dev_of_mean,
2460
0
                                   float *var_of_vars) {
2461
0
  const int16_t *const data_ptr = &data[0];
2462
0
  const int subh = (bh >= bw) ? (bh >> 1) : bh;
2463
0
  const int subw = (bw >= bh) ? (bw >> 1) : bw;
2464
0
  const int num = bw * bh;
2465
0
  const int sub_num = subw * subh;
2466
0
  int total_x_sum = 0;
2467
0
  int64_t total_x2_sum = 0;
2468
0
  int blk_idx = 0;
2469
0
  float var_sum = 0.0f;
2470
0
  float mean_sum = 0.0f;
2471
0
  double var2_sum = 0.0f;
2472
0
  double mean2_sum = 0.0f;
2473
2474
0
  for (int row = 0; row < bh; row += subh) {
2475
0
    for (int col = 0; col < bw; col += subw) {
2476
0
      int x_sum;
2477
0
      int64_t x2_sum;
2478
0
      aom_get_blk_sse_sum(data_ptr + row * stride + col, stride, subw, subh,
2479
0
                          &x_sum, &x2_sum);
2480
0
      total_x_sum += x_sum;
2481
0
      total_x2_sum += x2_sum;
2482
2483
0
      const float mean = (float)x_sum / sub_num;
2484
0
      const float var = get_var(mean, (double)x2_sum, sub_num);
2485
0
      mean_sum += mean;
2486
0
      mean2_sum += (double)(mean * mean);
2487
0
      var_sum += var;
2488
0
      var2_sum += var * var;
2489
0
      blk_idx++;
2490
0
    }
2491
0
  }
2492
2493
0
  const float lvl0_mean = (float)total_x_sum / num;
2494
0
  const float block_var = get_var(lvl0_mean, (double)total_x2_sum, num);
2495
0
  mean_sum += lvl0_mean;
2496
0
  mean2_sum += (double)(lvl0_mean * lvl0_mean);
2497
0
  var_sum += block_var;
2498
0
  var2_sum += block_var * block_var;
2499
0
  const float av_mean = mean_sum / 5;
2500
2501
0
  if (blk_idx > 1) {
2502
    // Deviation of means.
2503
0
    *dev_of_mean = get_dev(av_mean, mean2_sum, (blk_idx + 1));
2504
    // Variance of variances.
2505
0
    const float mean_var = var_sum / (blk_idx + 1);
2506
0
    *var_of_vars = get_var(mean_var, var2_sum, (blk_idx + 1));
2507
0
  }
2508
0
}
2509
2510
static void prune_tx_split_no_split(MACROBLOCK *x, BLOCK_SIZE bsize,
2511
                                    int blk_row, int blk_col, TX_SIZE tx_size,
2512
                                    int *try_no_split, int *try_split,
2513
0
                                    int pruning_level) {
2514
0
  const int diff_stride = block_size_wide[bsize];
2515
0
  const int16_t *diff =
2516
0
      x->plane[0].src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
2517
0
  const int bw = tx_size_wide[tx_size];
2518
0
  const int bh = tx_size_high[tx_size];
2519
0
  float dev_of_means = 0.0f;
2520
0
  float var_of_vars = 0.0f;
2521
2522
  // This function calculates the deviation of means and the variance of pixel
2523
  // variances of the block as well as its sub-blocks.
2524
0
  get_blk_var_dev(diff, diff_stride, bw, bh, &dev_of_means, &var_of_vars);
2525
0
  const int dc_q = x->plane[0].dequant_QTX[0] >> 3;
2526
0
  const int ac_q = x->plane[0].dequant_QTX[1] >> 3;
2527
0
  const int no_split_thresh_scales[4] = { 0, 24, 8, 8 };
2528
0
  const int no_split_thresh_scale = no_split_thresh_scales[pruning_level];
2529
0
  const int split_thresh_scales[4] = { 0, 24, 10, 8 };
2530
0
  const int split_thresh_scale = split_thresh_scales[pruning_level];
2531
2532
0
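  // Heuristic: when the sub-block means are close together and their variances
  // are alike relative to the quantizer step sizes, splitting is unlikely to
  // help, so the split search is skipped. Conversely, a highly non-stationary
  // residual skips the no-split search.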
  if ((dev_of_means <= dc_q) &&
2533
0
      (split_thresh_scale * var_of_vars <= ac_q * ac_q)) {
2534
0
    *try_split = 0;
2535
0
  }
2536
0
  if ((dev_of_means > no_split_thresh_scale * dc_q) &&
2537
0
      (var_of_vars > no_split_thresh_scale * ac_q * ac_q)) {
2538
0
    *try_no_split = 0;
2539
0
  }
2540
0
}
2541
2542
// Search for the best transform partition (recursive) and type for a given
2543
// inter-predicted luma block. The obtained transform selection will be saved
2544
// in xd->mi[0], and the corresponding RD stats will be saved in rd_stats.
2545
static inline void select_tx_block(
2546
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
2547
    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
2548
    ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
2549
    RD_STATS *rd_stats, int64_t prev_level_rd, int64_t ref_best_rd,
2550
0
    int *is_cost_valid, FAST_TX_SEARCH_MODE ftxs_mode) {
2551
0
  assert(tx_size < TX_SIZES_ALL);
2552
0
  av1_init_rd_stats(rd_stats);
2553
0
  if (ref_best_rd < 0) {
2554
0
    *is_cost_valid = 0;
2555
0
    return;
2556
0
  }
2557
2558
0
  MACROBLOCKD *const xd = &x->e_mbd;
2559
0
  assert(blk_row < max_block_high(xd, plane_bsize, 0) &&
2560
0
         blk_col < max_block_wide(xd, plane_bsize, 0));
2561
0
  MB_MODE_INFO *const mbmi = xd->mi[0];
2562
0
  const int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
2563
0
                                         mbmi->bsize, tx_size);
2564
0
  struct macroblock_plane *const p = &x->plane[0];
2565
2566
0
  int try_no_split = (cpi->oxcf.txfm_cfg.enable_tx64 ||
2567
0
                      txsize_sqr_up_map[tx_size] != TX_64X64) &&
2568
0
                     (cpi->oxcf.txfm_cfg.enable_rect_tx ||
2569
0
                      tx_size_wide[tx_size] == tx_size_high[tx_size]);
2570
0
  int try_split = tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH;
2571
0
  TxCandidateInfo no_split = { INT64_MAX, 0, TX_TYPES };
2572
2573
  // Prune tx_split and no-split based on sub-block properties.
2574
0
  if (tx_size != TX_4X4 && try_split == 1 && try_no_split == 1 &&
2575
0
      cpi->sf.tx_sf.prune_tx_size_level > 0) {
2576
0
    prune_tx_split_no_split(x, plane_bsize, blk_row, blk_col, tx_size,
2577
0
                            &try_no_split, &try_split,
2578
0
                            cpi->sf.tx_sf.prune_tx_size_level);
2579
0
  }
2580
2581
0
  if (cpi->sf.rt_sf.skip_tx_no_split_var_based_partition) {
2582
0
    if (x->try_merge_partition && try_split && p->eobs[block]) try_no_split = 0;
2583
0
  }
2584
2585
  // Try using current block as a single transform block without split.
2586
0
  if (try_no_split) {
2587
0
    try_tx_block_no_split(cpi, x, blk_row, blk_col, block, tx_size, depth,
2588
0
                          plane_bsize, ta, tl, ctx, rd_stats, ref_best_rd,
2589
0
                          ftxs_mode, &no_split);
2590
2591
    // Speed features for early termination.
2592
0
    const int search_level = cpi->sf.tx_sf.adaptive_txb_search_level;
2593
0
    if (search_level) {
2594
0
      if ((no_split.rd - (no_split.rd >> (1 + search_level))) > ref_best_rd) {
2595
0
        *is_cost_valid = 0;
2596
0
        return;
2597
0
      }
2598
0
      if (no_split.rd - (no_split.rd >> (2 + search_level)) > prev_level_rd) {
2599
0
        try_split = 0;
2600
0
      }
2601
0
    }
2602
0
    if (cpi->sf.tx_sf.txb_split_cap) {
2603
0
      if (p->eobs[block] == 0) try_split = 0;
2604
0
    }
2605
0
  }
2606
2607
  // ML based speed feature to skip searching for split transform blocks.
2608
0
  if (x->e_mbd.bd == 8 && try_split &&
2609
0
      !(ref_best_rd == INT64_MAX && no_split.rd == INT64_MAX)) {
2610
0
    const int threshold = cpi->sf.tx_sf.tx_type_search.ml_tx_split_thresh;
2611
0
    if (threshold >= 0) {
2612
0
      const int split_score =
2613
0
          ml_predict_tx_split(x, plane_bsize, blk_row, blk_col, tx_size);
2614
0
      if (split_score < -threshold) try_split = 0;
2615
0
    }
2616
0
  }
2617
2618
0
  RD_STATS split_rd_stats;
2619
0
  split_rd_stats.rdcost = INT64_MAX;
2620
  // Try splitting current block into smaller transform blocks.
2621
0
  if (try_split) {
2622
0
    try_tx_block_split(cpi, x, blk_row, blk_col, block, tx_size, depth,
2623
0
                       plane_bsize, ta, tl, tx_above, tx_left, ctx, no_split.rd,
2624
0
                       AOMMIN(no_split.rd, ref_best_rd), ftxs_mode,
2625
0
                       &split_rd_stats);
2626
0
  }
2627
2628
0
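  // Keep whichever of the no-split and split candidates has the lower
  // Lagrangian cost (RDCOST folds rate and distortion into a single metric),
  // and record the winning transform partition, sizes and type below.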
  if (no_split.rd < split_rd_stats.rdcost) {
2629
0
    ENTROPY_CONTEXT *pta = ta + blk_col;
2630
0
    ENTROPY_CONTEXT *ptl = tl + blk_row;
2631
0
    p->txb_entropy_ctx[block] = no_split.txb_entropy_ctx;
2632
0
    av1_set_txb_context(x, 0, block, tx_size, pta, ptl);
2633
0
    txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size,
2634
0
                          tx_size);
2635
0
    for (int idy = 0; idy < tx_size_high_unit[tx_size]; ++idy) {
2636
0
      for (int idx = 0; idx < tx_size_wide_unit[tx_size]; ++idx) {
2637
0
        const int index =
2638
0
            av1_get_txb_size_index(plane_bsize, blk_row + idy, blk_col + idx);
2639
0
        mbmi->inter_tx_size[index] = tx_size;
2640
0
      }
2641
0
    }
2642
0
    mbmi->tx_size = tx_size;
2643
0
    update_txk_array(xd, blk_row, blk_col, tx_size, no_split.tx_type);
2644
0
    const int bw = mi_size_wide[plane_bsize];
2645
0
    set_blk_skip(x->txfm_search_info.blk_skip, 0, blk_row * bw + blk_col,
2646
0
                 rd_stats->skip_txfm);
2647
0
  } else {
2648
0
    *rd_stats = split_rd_stats;
2649
0
    if (split_rd_stats.rdcost == INT64_MAX) *is_cost_valid = 0;
2650
0
  }
2651
0
}
2652
2653
static inline void choose_largest_tx_size(const AV1_COMP *const cpi,
2654
                                          MACROBLOCK *x, RD_STATS *rd_stats,
2655
0
                                          int64_t ref_best_rd, BLOCK_SIZE bs) {
2656
0
  MACROBLOCKD *const xd = &x->e_mbd;
2657
0
  MB_MODE_INFO *const mbmi = xd->mi[0];
2658
0
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
2659
0
  mbmi->tx_size = tx_size_from_tx_mode(bs, txfm_params->tx_mode_search_type);
2660
2661
  // If tx64 is not enabled, we need to go down to the next available size
2662
0
  if (!cpi->oxcf.txfm_cfg.enable_tx64 && cpi->oxcf.txfm_cfg.enable_rect_tx) {
2663
0
    static const TX_SIZE tx_size_max_32[TX_SIZES_ALL] = {
2664
0
      TX_4X4,    // 4x4 transform
2665
0
      TX_8X8,    // 8x8 transform
2666
0
      TX_16X16,  // 16x16 transform
2667
0
      TX_32X32,  // 32x32 transform
2668
0
      TX_32X32,  // 64x64 transform
2669
0
      TX_4X8,    // 4x8 transform
2670
0
      TX_8X4,    // 8x4 transform
2671
0
      TX_8X16,   // 8x16 transform
2672
0
      TX_16X8,   // 16x8 transform
2673
0
      TX_16X32,  // 16x32 transform
2674
0
      TX_32X16,  // 32x16 transform
2675
0
      TX_32X32,  // 32x64 transform
2676
0
      TX_32X32,  // 64x32 transform
2677
0
      TX_4X16,   // 4x16 transform
2678
0
      TX_16X4,   // 16x4 transform
2679
0
      TX_8X32,   // 8x32 transform
2680
0
      TX_32X8,   // 32x8 transform
2681
0
      TX_16X32,  // 16x64 transform
2682
0
      TX_32X16,  // 64x16 transform
2683
0
    };
2684
0
    mbmi->tx_size = tx_size_max_32[mbmi->tx_size];
2685
0
  } else if (cpi->oxcf.txfm_cfg.enable_tx64 &&
2686
0
             !cpi->oxcf.txfm_cfg.enable_rect_tx) {
2687
0
    static const TX_SIZE tx_size_max_square[TX_SIZES_ALL] = {
2688
0
      TX_4X4,    // 4x4 transform
2689
0
      TX_8X8,    // 8x8 transform
2690
0
      TX_16X16,  // 16x16 transform
2691
0
      TX_32X32,  // 32x32 transform
2692
0
      TX_64X64,  // 64x64 transform
2693
0
      TX_4X4,    // 4x8 transform
2694
0
      TX_4X4,    // 8x4 transform
2695
0
      TX_8X8,    // 8x16 transform
2696
0
      TX_8X8,    // 16x8 transform
2697
0
      TX_16X16,  // 16x32 transform
2698
0
      TX_16X16,  // 32x16 transform
2699
0
      TX_32X32,  // 32x64 transform
2700
0
      TX_32X32,  // 64x32 transform
2701
0
      TX_4X4,    // 4x16 transform
2702
0
      TX_4X4,    // 16x4 transform
2703
0
      TX_8X8,    // 8x32 transform
2704
0
      TX_8X8,    // 32x8 transform
2705
0
      TX_16X16,  // 16x64 transform
2706
0
      TX_16X16,  // 64x16 transform
2707
0
    };
2708
0
    mbmi->tx_size = tx_size_max_square[mbmi->tx_size];
2709
0
  } else if (!cpi->oxcf.txfm_cfg.enable_tx64 &&
2710
0
             !cpi->oxcf.txfm_cfg.enable_rect_tx) {
2711
0
    static const TX_SIZE tx_size_max_32_square[TX_SIZES_ALL] = {
2712
0
      TX_4X4,    // 4x4 transform
2713
0
      TX_8X8,    // 8x8 transform
2714
0
      TX_16X16,  // 16x16 transform
2715
0
      TX_32X32,  // 32x32 transform
2716
0
      TX_32X32,  // 64x64 transform
2717
0
      TX_4X4,    // 4x8 transform
2718
0
      TX_4X4,    // 8x4 transform
2719
0
      TX_8X8,    // 8x16 transform
2720
0
      TX_8X8,    // 16x8 transform
2721
0
      TX_16X16,  // 16x32 transform
2722
0
      TX_16X16,  // 32x16 transform
2723
0
      TX_32X32,  // 32x64 transform
2724
0
      TX_32X32,  // 64x32 transform
2725
0
      TX_4X4,    // 4x16 transform
2726
0
      TX_4X4,    // 16x4 transform
2727
0
      TX_8X8,    // 8x32 transform
2728
0
      TX_8X8,    // 32x8 transform
2729
0
      TX_16X16,  // 16x64 transform
2730
0
      TX_16X16,  // 64x16 transform
2731
0
    };
2732
2733
0
    mbmi->tx_size = tx_size_max_32_square[mbmi->tx_size];
2734
0
  }
2735
2736
0
  const int skip_ctx = av1_get_skip_txfm_context(xd);
2737
0
  const int no_skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][0];
2738
0
  const int skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][1];
2739
  // The skip RD cost is used only for inter blocks.
2740
0
  const int64_t skip_txfm_rd =
2741
0
      is_inter_block(mbmi) ? RDCOST(x->rdmult, skip_txfm_rate, 0) : INT64_MAX;
2742
0
  const int64_t no_skip_txfm_rd = RDCOST(x->rdmult, no_skip_txfm_rate, 0);
2743
0
  const int skip_trellis = 0;
2744
0
  av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd,
2745
0
                       AOMMIN(no_skip_txfm_rd, skip_txfm_rd), AOM_PLANE_Y, bs,
2746
0
                       mbmi->tx_size, FTXS_NONE, skip_trellis);
2747
0
}
2748
2749
static inline void choose_smallest_tx_size(const AV1_COMP *const cpi,
2750
                                           MACROBLOCK *x, RD_STATS *rd_stats,
2751
0
                                           int64_t ref_best_rd, BLOCK_SIZE bs) {
2752
0
  MACROBLOCKD *const xd = &x->e_mbd;
2753
0
  MB_MODE_INFO *const mbmi = xd->mi[0];
2754
2755
0
  mbmi->tx_size = TX_4X4;
2756
  // TODO(any) : Pass this_rd based on skip/non-skip cost
2757
0
  const int skip_trellis = 0;
2758
0
  av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd, 0, 0, bs, mbmi->tx_size,
2759
0
                       FTXS_NONE, skip_trellis);
2760
0
}
2761
2762
#if !CONFIG_REALTIME_ONLY
2763
static void ml_predict_intra_tx_depth_prune(MACROBLOCK *x, int blk_row,
2764
                                            int blk_col, BLOCK_SIZE bsize,
2765
0
                                            TX_SIZE tx_size) {
2766
0
  const MACROBLOCKD *const xd = &x->e_mbd;
2767
0
  const MB_MODE_INFO *const mbmi = xd->mi[0];
2768
2769
  // Disable the NN-model-based pruning logic for the following cases:
2770
  // 1) Lossless coding, as only the 4x4 transform is evaluated in this case.
2771
  // 2) When the transform and current block sizes do not match, as the
2772
  // features are obtained over the current block.
2773
  // 3) When the operating bit-depth is not 8-bit, as the input features are
2774
  // not scaled according to bit-depth.
2775
0
  if (xd->lossless[mbmi->segment_id] || txsize_to_bsize[tx_size] != bsize ||
2776
0
      xd->bd != 8)
2777
0
    return;
2778
2779
  // Currently, NN-model-based pruning is supported only when the largest
2780
  // transform size is 8x8.
2781
0
  if (tx_size != TX_8X8) return;
2782
2783
  // The neural network model is a sequential neural net trained with the SGD
2784
  // optimizer. The model could be further improved in terms of speed/quality
2785
  // by considering the following experiments:
2786
  // 1) Generate the ML model by training with balanced data for different
2787
  // learning rates and optimizers.
2788
  // 2) Experiment with the ML model by adding features related to the
2789
  // statistics of top and left pixels, to capture the accuracy of the
2790
  // reconstructed neighbouring pixels for the 4x4 blocks numbered 1, 2 and 3
2791
  // in an 8x8 block, the source variance of 4x4 sub-blocks, etc.
2792
  // 3) Generate ML models for transform blocks other than 8x8.
2793
0
  const NN_CONFIG *const nn_config = &av1_intra_tx_split_nnconfig_8x8;
2794
0
  const float *const intra_tx_prune_thresh = av1_intra_tx_prune_nn_thresh_8x8;
2795
2796
0
  float features[NUM_INTRA_TX_SPLIT_FEATURES] = { 0.0f };
2797
0
  const int diff_stride = block_size_wide[bsize];
2798
2799
0
  const int16_t *diff = x->plane[0].src_diff + MI_SIZE * blk_row * diff_stride +
2800
0
                        MI_SIZE * blk_col;
2801
0
  const int bw = tx_size_wide[tx_size];
2802
0
  const int bh = tx_size_high[tx_size];
2803
2804
0
  int feature_idx = get_mean_dev_features(diff, diff_stride, bw, bh, features);
2805
2806
0
  features[feature_idx++] = log1pf((float)x->source_variance);
2807
2808
0
  const int dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd) >> (xd->bd - 8);
2809
0
  const float log_dc_q_square = log1pf((float)(dc_q * dc_q) / 256.0f);
2810
0
  features[feature_idx++] = log_dc_q_square;
2811
0
  assert(feature_idx == NUM_INTRA_TX_SPLIT_FEATURES);
2812
0
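  // Standardize each feature (z-score) with the training-set mean and standard
  // deviation before feeding it to the network.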
  for (int i = 0; i < NUM_INTRA_TX_SPLIT_FEATURES; i++) {
2813
0
    features[i] = (features[i] - av1_intra_tx_split_8x8_mean[i]) /
2814
0
                  av1_intra_tx_split_8x8_std[i];
2815
0
  }
2816
2817
0
  float score;
2818
0
  av1_nn_predict(features, nn_config, 1, &score);
2819
2820
0
  TxfmSearchParams *const txfm_params = &x->txfm_search_params;
2821
0
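  // A low score predicts that splitting will not help, so the split depths are
  // pruned; a high score predicts that the largest transform is unlikely to
  // win, so the largest depth is pruned instead.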
  if (score <= intra_tx_prune_thresh[0])
2822
0
    txfm_params->nn_prune_depths_for_intra_tx = TX_PRUNE_SPLIT;
2823
0
  else if (score > intra_tx_prune_thresh[1])
2824
0
    txfm_params->nn_prune_depths_for_intra_tx = TX_PRUNE_LARGEST;
2825
0
}
2826
#endif  // !CONFIG_REALTIME_ONLY
2827
2828
/*!\brief Transform type search for luma macroblock with fixed transform size.
2829
 *
2830
 * \ingroup transform_search
2831
 * Search for the best transform type and return the transform coefficients RD
2832
 * cost of current luma macroblock with the given uniform transform size.
2833
 *
2834
 * \param[in]    x              Pointer to structure holding the data for the
2835
                                current encoding macroblock
2836
 * \param[in]    cpi            Top-level encoder structure
2837
 * \param[in]    rd_stats       Pointer to struct to keep track of the RD stats
2838
 * \param[in]    ref_best_rd    Best RD cost seen for this block so far
2839
 * \param[in]    bs             Size of the current macroblock
2840
 * \param[in]    tx_size        The given transform size
2841
 * \param[in]    ftxs_mode      Transform search mode specifying desired speed
2842
                                and quality tradeoff
2843
 * \param[in]    skip_trellis   Binary flag indicating if trellis optimization
2844
                                should be skipped
2845
 * \return       An int64_t value that is the best RD cost found.
2846
 */
2847
static int64_t uniform_txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
2848
                                RD_STATS *rd_stats, int64_t ref_best_rd,
2849
                                BLOCK_SIZE bs, TX_SIZE tx_size,
2850
                                FAST_TX_SEARCH_MODE ftxs_mode,
2851
0
                                int skip_trellis) {
2852
0
  assert(IMPLIES(is_rect_tx(tx_size), is_rect_tx_allowed_bsize(bs)));
2853
0
  MACROBLOCKD *const xd = &x->e_mbd;
2854
0
  MB_MODE_INFO *const mbmi = xd->mi[0];
2855
0
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
2856
0
  const ModeCosts *mode_costs = &x->mode_costs;
2857
0
  const int is_inter = is_inter_block(mbmi);
2858
0
  const int tx_select = txfm_params->tx_mode_search_type == TX_MODE_SELECT &&
2859
0
                        block_signals_txsize(mbmi->bsize);
2860
0
  int tx_size_rate = 0;
2861
0
  if (tx_select) {
2862
0
    const int ctx = txfm_partition_context(
2863
0
        xd->above_txfm_context, xd->left_txfm_context, mbmi->bsize, tx_size);
2864
0
    tx_size_rate = is_inter ? mode_costs->txfm_partition_cost[ctx][0]
2865
0
                            : tx_size_cost(x, bs, tx_size);
2866
0
  }
2867
0
  const int skip_ctx = av1_get_skip_txfm_context(xd);
2868
0
  const int no_skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][0];
2869
0
  const int skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][1];
2870
0
  const int64_t skip_txfm_rd =
2871
0
      is_inter ? RDCOST(x->rdmult, skip_txfm_rate, 0) : INT64_MAX;
2872
0
  const int64_t no_this_rd =
2873
0
      RDCOST(x->rdmult, no_skip_txfm_rate + tx_size_rate, 0);
2874
2875
0
  mbmi->tx_size = tx_size;
2876
0
  av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd,
2877
0
                       AOMMIN(no_this_rd, skip_txfm_rd), AOM_PLANE_Y, bs,
2878
0
                       tx_size, ftxs_mode, skip_trellis);
2879
0
  if (rd_stats->rate == INT_MAX) return INT64_MAX;
2880
2881
0
  int64_t rd;
2882
  // rd_stats->rate should include all the rate except the skip/non-skip cost,
2883
  // as that is accounted for in the caller functions after RD evaluation of
2884
  // all planes. However, the decision should be made after considering the
2885
  // skip/non-skip header cost.
2886
0
  if (rd_stats->skip_txfm && is_inter) {
2887
0
    rd = RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
2888
0
  } else {
2889
    // Intra blocks are always signalled as non-skip
2890
0
    rd = RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_rate + tx_size_rate,
2891
0
                rd_stats->dist);
2892
0
    rd_stats->rate += tx_size_rate;
2893
0
  }
2894
  // Check if forcing the block to skip transform leads to smaller RD cost.
2895
0
  if (is_inter && !rd_stats->skip_txfm && !xd->lossless[mbmi->segment_id]) {
2896
0
    int64_t temp_skip_txfm_rd =
2897
0
        RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
2898
0
    if (temp_skip_txfm_rd <= rd) {
2899
0
      rd = temp_skip_txfm_rd;
2900
0
      rd_stats->rate = 0;
2901
0
      rd_stats->dist = rd_stats->sse;
2902
0
      rd_stats->skip_txfm = 1;
2903
0
    }
2904
0
  }
2905
2906
0
  return rd;
2907
0
}
2908
2909
// Search for the best uniform transform size and type for the current coding
// block.
2910
static inline void choose_tx_size_type_from_rd(const AV1_COMP *const cpi,
2911
                                               MACROBLOCK *x,
2912
                                               RD_STATS *rd_stats,
2913
                                               int64_t ref_best_rd,
2914
0
                                               BLOCK_SIZE bs) {
2915
0
  av1_invalid_rd_stats(rd_stats);
2916
2917
0
  MACROBLOCKD *const xd = &x->e_mbd;
2918
0
  MB_MODE_INFO *const mbmi = xd->mi[0];
2919
0
  TxfmSearchParams *const txfm_params = &x->txfm_search_params;
2920
0
  const TX_SIZE max_rect_tx_size = max_txsize_rect_lookup[bs];
2921
0
  const int tx_select = txfm_params->tx_mode_search_type == TX_MODE_SELECT;
2922
0
  int start_tx;
2923
  // The split depth can be at most MAX_TX_DEPTH, so init_depth controls how
2924
  // many levels of splitting are allowed during the RD search.
2925
0
  int init_depth;
2926
2927
0
  if (tx_select) {
2928
0
    start_tx = max_rect_tx_size;
2929
0
    init_depth = get_search_init_depth(mi_size_wide[bs], mi_size_high[bs],
2930
0
                                       is_inter_block(mbmi), &cpi->sf,
2931
0
                                       txfm_params->tx_size_search_method);
2932
0
    if (init_depth == MAX_TX_DEPTH && !cpi->oxcf.txfm_cfg.enable_tx64 &&
2933
0
        txsize_sqr_up_map[start_tx] == TX_64X64) {
2934
0
      start_tx = sub_tx_size_map[start_tx];
2935
0
    }
2936
0
  } else {
2937
0
    const TX_SIZE chosen_tx_size =
2938
0
        tx_size_from_tx_mode(bs, txfm_params->tx_mode_search_type);
2939
0
    start_tx = chosen_tx_size;
2940
0
    init_depth = MAX_TX_DEPTH;
2941
0
  }
2942
2943
0
  const int skip_trellis = 0;
2944
0
  uint8_t best_txk_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
2945
0
  uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
2946
0
  TX_SIZE best_tx_size = max_rect_tx_size;
2947
0
  int64_t best_rd = INT64_MAX;
2948
0
  const int num_blks = bsize_to_num_blk(bs);
2949
0
  x->rd_model = FULL_TXFM_RD;
2950
0
  int64_t rd[MAX_TX_DEPTH + 1] = { INT64_MAX, INT64_MAX, INT64_MAX };
2951
0
  TxfmSearchInfo *txfm_info = &x->txfm_search_info;
2952
0
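  // Walk from the largest candidate size down one depth at a time (via
  // sub_tx_size_map), keeping the best uniform transform size/type found.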
  for (int tx_size = start_tx, depth = init_depth; depth <= MAX_TX_DEPTH;
2953
0
       depth++, tx_size = sub_tx_size_map[tx_size]) {
2954
0
    if ((!cpi->oxcf.txfm_cfg.enable_tx64 &&
2955
0
         txsize_sqr_up_map[tx_size] == TX_64X64) ||
2956
0
        (!cpi->oxcf.txfm_cfg.enable_rect_tx &&
2957
0
         tx_size_wide[tx_size] != tx_size_high[tx_size])) {
2958
0
      continue;
2959
0
    }
2960
2961
0
#if !CONFIG_REALTIME_ONLY
2962
0
    if (txfm_params->nn_prune_depths_for_intra_tx == TX_PRUNE_SPLIT) break;
2963
2964
    // Set the flag to enable the evaluation of the NN classifier to prune
2965
    // transform depths. As the features are based on the intra residual of the
2966
    // largest transform, the evaluation of the NN model is enabled only for
2967
    // this case.
2968
0
    txfm_params->enable_nn_prune_intra_tx_depths =
2969
0
        (cpi->sf.tx_sf.prune_intra_tx_depths_using_nn && tx_size == start_tx);
2970
0
#endif
2971
2972
0
    RD_STATS this_rd_stats;
2973
    // When the speed feature use_rd_based_breakout_for_intra_tx_search is
2974
    // enabled, use the known minimum best_rd for early termination.
2975
0
    const int64_t rd_thresh =
2976
0
        cpi->sf.tx_sf.use_rd_based_breakout_for_intra_tx_search
2977
0
            ? AOMMIN(ref_best_rd, best_rd)
2978
0
            : ref_best_rd;
2979
0
    rd[depth] = uniform_txfm_yrd(cpi, x, &this_rd_stats, rd_thresh, bs, tx_size,
2980
0
                                 FTXS_NONE, skip_trellis);
2981
0
    if (rd[depth] < best_rd) {
2982
0
      av1_copy_array(best_blk_skip, txfm_info->blk_skip, num_blks);
2983
0
      av1_copy_array(best_txk_type_map, xd->tx_type_map, num_blks);
2984
0
      best_tx_size = tx_size;
2985
0
      best_rd = rd[depth];
2986
0
      *rd_stats = this_rd_stats;
2987
0
    }
2988
0
    if (tx_size == TX_4X4) break;
2989
    // If we are searching three depths, prune the smallest size depending
2990
    // on the RD results of the first two depths for low-contrast blocks.
2991
0
    if (depth > init_depth && depth != MAX_TX_DEPTH &&
2992
0
        x->source_variance < 256) {
2993
0
      if (rd[depth - 1] != INT64_MAX && rd[depth] > rd[depth - 1]) break;
2994
0
    }
2995
0
  }
2996
2997
0
  if (rd_stats->rate != INT_MAX) {
2998
0
    mbmi->tx_size = best_tx_size;
2999
0
    av1_copy_array(xd->tx_type_map, best_txk_type_map, num_blks);
3000
0
    av1_copy_array(txfm_info->blk_skip, best_blk_skip, num_blks);
3001
0
  }
3002
3003
0
#if !CONFIG_REALTIME_ONLY
3004
  // Reset the flags to avoid any unintentional evaluation of NN model and
3005
  // consumption of prune depths.
3006
0
  txfm_params->enable_nn_prune_intra_tx_depths = false;
3007
0
  txfm_params->nn_prune_depths_for_intra_tx = TX_PRUNE_NONE;
3008
0
#endif
3009
0
}
3010
3011
// Search for the best transform type for the given transform block in the
3012
// given plane/channel, and calculate the corresponding RD cost.
3013
static inline void block_rd_txfm(int plane, int block, int blk_row, int blk_col,
3014
                                 BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
3015
0
                                 void *arg) {
3016
0
  struct rdcost_block_args *args = arg;
3017
0
  if (args->exit_early) {
3018
0
    args->incomplete_exit = 1;
3019
0
    return;
3020
0
  }
3021
3022
0
  MACROBLOCK *const x = args->x;
3023
0
  MACROBLOCKD *const xd = &x->e_mbd;
3024
0
  const int is_inter = is_inter_block(xd->mi[0]);
3025
0
  const AV1_COMP *cpi = args->cpi;
3026
0
  ENTROPY_CONTEXT *a = args->t_above + blk_col;
3027
0
  ENTROPY_CONTEXT *l = args->t_left + blk_row;
3028
0
  const AV1_COMMON *cm = &cpi->common;
3029
0
  RD_STATS this_rd_stats;
3030
0
  av1_init_rd_stats(&this_rd_stats);
3031
3032
0
  if (!is_inter) {
3033
0
    av1_predict_intra_block_facade(cm, xd, plane, blk_col, blk_row, tx_size);
3034
0
    av1_subtract_txb(x, plane, plane_bsize, blk_col, blk_row, tx_size);
3035
0
#if !CONFIG_REALTIME_ONLY
3036
0
    const TxfmSearchParams *const txfm_params = &x->txfm_search_params;
3037
0
    if (txfm_params->enable_nn_prune_intra_tx_depths) {
3038
0
      ml_predict_intra_tx_depth_prune(x, blk_row, blk_col, plane_bsize,
3039
0
                                      tx_size);
3040
0
      if (txfm_params->nn_prune_depths_for_intra_tx == TX_PRUNE_LARGEST) {
3041
0
        av1_invalid_rd_stats(&args->rd_stats);
3042
0
        args->exit_early = 1;
3043
0
        return;
3044
0
      }
3045
0
    }
3046
0
#endif
3047
0
  }
3048
3049
0
  TXB_CTX txb_ctx;
3050
0
  get_txb_ctx(plane_bsize, tx_size, plane, a, l, &txb_ctx);
3051
0
  search_tx_type(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
3052
0
                 &txb_ctx, args->ftxs_mode, args->skip_trellis,
3053
0
                 args->best_rd - args->current_rd, &this_rd_stats);
3054
3055
0
#if !CONFIG_REALTIME_ONLY
3056
0
  if (plane == AOM_PLANE_Y && xd->cfl.store_y) {
3057
0
    assert(!is_inter || plane_bsize < BLOCK_8X8);
3058
0
    cfl_store_tx(xd, blk_row, blk_col, tx_size, plane_bsize);
3059
0
  }
3060
0
#endif
3061
3062
#if CONFIG_RD_DEBUG
3063
  update_txb_coeff_cost(&this_rd_stats, plane, this_rd_stats.rate);
3064
#endif  // CONFIG_RD_DEBUG
3065
0
  av1_set_txb_context(x, plane, block, tx_size, a, l);
3066
3067
0
  const int blk_idx =
3068
0
      blk_row * (block_size_wide[plane_bsize] >> MI_SIZE_LOG2) + blk_col;
3069
3070
0
  TxfmSearchInfo *txfm_info = &x->txfm_search_info;
3071
0
  if (plane == 0)
3072
0
    set_blk_skip(txfm_info->blk_skip, plane, blk_idx,
3073
0
                 x->plane[plane].eobs[block] == 0);
3074
0
  else
3075
0
    set_blk_skip(txfm_info->blk_skip, plane, blk_idx, 0);
3076
3077
0
  int64_t rd;
3078
0
  if (is_inter) {
3079
0
    const int64_t no_skip_txfm_rd =
3080
0
        RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
3081
0
    const int64_t skip_txfm_rd = RDCOST(x->rdmult, 0, this_rd_stats.sse);
3082
0
    rd = AOMMIN(no_skip_txfm_rd, skip_txfm_rd);
3083
0
    this_rd_stats.skip_txfm &= !x->plane[plane].eobs[block];
3084
0
  } else {
3085
    // Signal non-skip_txfm for Intra blocks
3086
0
    rd = RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
3087
0
    this_rd_stats.skip_txfm = 0;
3088
0
  }
3089
3090
0
  av1_merge_rd_stats(&args->rd_stats, &this_rd_stats);
3091
3092
0
  args->current_rd += rd;
3093
0
  if (args->current_rd > args->best_rd) args->exit_early = 1;
3094
0
}
3095
3096
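// Cheaply estimate the transform RD cost of a luma block: only the DCT_DCT
// transform type is evaluated and the distortion is computed in the transform
// domain, instead of running the full per-block transform type search.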
int64_t av1_estimate_txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
3097
                              RD_STATS *rd_stats, int64_t ref_best_rd,
3098
0
                              BLOCK_SIZE bs, TX_SIZE tx_size) {
3099
0
  MACROBLOCKD *const xd = &x->e_mbd;
3100
0
  MB_MODE_INFO *const mbmi = xd->mi[0];
3101
0
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
3102
0
  const ModeCosts *mode_costs = &x->mode_costs;
3103
0
  const int is_inter = is_inter_block(mbmi);
3104
0
  const int tx_select = txfm_params->tx_mode_search_type == TX_MODE_SELECT &&
3105
0
                        block_signals_txsize(mbmi->bsize);
3106
0
  int tx_size_rate = 0;
3107
0
  if (tx_select) {
3108
0
    const int ctx = txfm_partition_context(
3109
0
        xd->above_txfm_context, xd->left_txfm_context, mbmi->bsize, tx_size);
3110
0
    tx_size_rate = mode_costs->txfm_partition_cost[ctx][0];
3111
0
  }
3112
0
  const int skip_ctx = av1_get_skip_txfm_context(xd);
3113
0
  const int no_skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][0];
3114
0
  const int skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][1];
3115
0
  const int64_t skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_rate, 0);
3116
0
  const int64_t no_this_rd =
3117
0
      RDCOST(x->rdmult, no_skip_txfm_rate + tx_size_rate, 0);
3118
0
  mbmi->tx_size = tx_size;
3119
3120
0
  const uint8_t txw_unit = tx_size_wide_unit[tx_size];
3121
0
  const uint8_t txh_unit = tx_size_high_unit[tx_size];
3122
0
  const int step = txw_unit * txh_unit;
3123
0
  const int max_blocks_wide = max_block_wide(xd, bs, 0);
3124
0
  const int max_blocks_high = max_block_high(xd, bs, 0);
3125
3126
0
  struct rdcost_block_args args;
3127
0
  av1_zero(args);
3128
0
  args.x = x;
3129
0
  args.cpi = cpi;
3130
0
  args.best_rd = ref_best_rd;
3131
0
  args.current_rd = AOMMIN(no_this_rd, skip_txfm_rd);
3132
0
  av1_init_rd_stats(&args.rd_stats);
3133
0
  av1_get_entropy_contexts(bs, &xd->plane[0], args.t_above, args.t_left);
3134
0
  int i = 0;
3135
0
  for (int blk_row = 0; blk_row < max_blocks_high && !args.incomplete_exit;
3136
0
       blk_row += txh_unit) {
3137
0
    for (int blk_col = 0; blk_col < max_blocks_wide; blk_col += txw_unit) {
3138
0
      RD_STATS this_rd_stats;
3139
0
      av1_init_rd_stats(&this_rd_stats);
3140
3141
0
      if (args.exit_early) {
3142
0
        args.incomplete_exit = 1;
3143
0
        break;
3144
0
      }
3145
3146
0
      ENTROPY_CONTEXT *a = args.t_above + blk_col;
3147
0
      ENTROPY_CONTEXT *l = args.t_left + blk_row;
3148
0
      TXB_CTX txb_ctx;
3149
0
      get_txb_ctx(bs, tx_size, 0, a, l, &txb_ctx);
3150
3151
0
      TxfmParam txfm_param;
3152
0
      QUANT_PARAM quant_param;
3153
0
      av1_setup_xform(&cpi->common, x, tx_size, DCT_DCT, &txfm_param);
3154
0
      av1_setup_quant(tx_size, 0, AV1_XFORM_QUANT_B, 0, &quant_param);
3155
3156
0
      av1_xform(x, 0, i, blk_row, blk_col, bs, &txfm_param);
3157
0
      av1_quant(x, 0, i, &txfm_param, &quant_param);
3158
3159
0
      this_rd_stats.rate =
3160
0
          cost_coeffs(x, 0, i, tx_size, txfm_param.tx_type, &txb_ctx, 0);
3161
3162
0
      const SCAN_ORDER *const scan_order =
3163
0
          get_scan(txfm_param.tx_size, txfm_param.tx_type);
3164
0
      dist_block_tx_domain(x, 0, i, tx_size, quant_param.qmatrix,
3165
0
                           scan_order->scan, &this_rd_stats.dist,
3166
0
                           &this_rd_stats.sse);
3167
3168
0
      const int64_t no_skip_txfm_rd =
3169
0
          RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
3170
0
      const int64_t skip_rd = RDCOST(x->rdmult, 0, this_rd_stats.sse);
3171
3172
0
      this_rd_stats.skip_txfm &= !x->plane[0].eobs[i];
3173
3174
0
      av1_merge_rd_stats(&args.rd_stats, &this_rd_stats);
3175
0
      args.current_rd += AOMMIN(no_skip_txfm_rd, skip_rd);
3176
3177
0
      if (args.current_rd > ref_best_rd) {
3178
0
        args.exit_early = 1;
3179
0
        break;
3180
0
      }
3181
3182
0
      av1_set_txb_context(x, 0, i, tx_size, a, l);
3183
0
      i += step;
3184
0
    }
3185
0
  }
3186
3187
0
  if (args.incomplete_exit) av1_invalid_rd_stats(&args.rd_stats);
3188
3189
0
  *rd_stats = args.rd_stats;
3190
0
  if (rd_stats->rate == INT_MAX) return INT64_MAX;
3191
3192
0
  int64_t rd;
3193
  // rd_stats->rate should include all the rate except the skip/non-skip cost,
3194
  // as that is accounted for in the caller functions after RD evaluation of
3195
  // all planes. However, the decision should be made after considering the
3196
  // skip/non-skip header cost.
3197
0
  if (rd_stats->skip_txfm && is_inter) {
3198
0
    rd = RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
3199
0
  } else {
3200
    // Intra blocks are always signalled as non-skip
3201
0
    rd = RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_rate + tx_size_rate,
3202
0
                rd_stats->dist);
3203
0
    rd_stats->rate += tx_size_rate;
3204
0
  }
3205
  // Check if forcing the block to skip transform leads to smaller RD cost.
3206
0
  if (is_inter && !rd_stats->skip_txfm && !xd->lossless[mbmi->segment_id]) {
3207
0
    int64_t temp_skip_txfm_rd =
3208
0
        RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
3209
0
    if (temp_skip_txfm_rd <= rd) {
3210
0
      rd = temp_skip_txfm_rd;
3211
0
      rd_stats->rate = 0;
3212
0
      rd_stats->dist = rd_stats->sse;
3213
0
      rd_stats->skip_txfm = 1;
3214
0
    }
3215
0
  }
3216
3217
0
  return rd;
3218
0
}
3219
3220
// Search for the best transform type for a luma inter-predicted block, given
3221
// the transform block partitions.
3222
// This function is used only when some speed features are enabled.
3223
static inline void tx_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row,
3224
                                int blk_col, int block, TX_SIZE tx_size,
3225
                                BLOCK_SIZE plane_bsize, int depth,
3226
                                ENTROPY_CONTEXT *above_ctx,
3227
                                ENTROPY_CONTEXT *left_ctx,
3228
                                TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
3229
                                int64_t ref_best_rd, RD_STATS *rd_stats,
3230
0
                                FAST_TX_SEARCH_MODE ftxs_mode) {
3231
0
  assert(tx_size < TX_SIZES_ALL);
3232
0
  MACROBLOCKD *const xd = &x->e_mbd;
3233
0
  MB_MODE_INFO *const mbmi = xd->mi[0];
3234
0
  assert(is_inter_block(mbmi));
3235
0
  const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
3236
0
  const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);
3237
3238
0
  if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;
3239
3240
0
  const TX_SIZE plane_tx_size = mbmi->inter_tx_size[av1_get_txb_size_index(
3241
0
      plane_bsize, blk_row, blk_col)];
3242
0
  const int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
3243
0
                                         mbmi->bsize, tx_size);
3244
3245
0
  av1_init_rd_stats(rd_stats);
3246
0
  if (tx_size == plane_tx_size) {
3247
0
    ENTROPY_CONTEXT *ta = above_ctx + blk_col;
3248
0
    ENTROPY_CONTEXT *tl = left_ctx + blk_row;
3249
0
    const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
3250
0
    TXB_CTX txb_ctx;
3251
0
    get_txb_ctx(plane_bsize, tx_size, 0, ta, tl, &txb_ctx);
3252
3253
0
    const int zero_blk_rate =
3254
0
        x->coeff_costs.coeff_costs[txs_ctx][get_plane_type(0)]
3255
0
            .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
3256
0
    rd_stats->zero_rate = zero_blk_rate;
3257
0
    tx_type_rd(cpi, x, tx_size, blk_row, blk_col, block, plane_bsize, &txb_ctx,
3258
0
               rd_stats, ftxs_mode, ref_best_rd);
3259
0
    const int mi_width = mi_size_wide[plane_bsize];
3260
0
    TxfmSearchInfo *txfm_info = &x->txfm_search_info;
3261
0
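    // Force the block to all-zero coefficients if coding them is no cheaper,
    // in RD terms, than transmitting a skipped (all-zero) transform block.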
    if (RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >=
3262
0
            RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse) ||
3263
0
        rd_stats->skip_txfm == 1) {
3264
0
      rd_stats->rate = zero_blk_rate;
3265
0
      rd_stats->dist = rd_stats->sse;
3266
0
      rd_stats->skip_txfm = 1;
3267
0
      set_blk_skip(txfm_info->blk_skip, 0, blk_row * mi_width + blk_col, 1);
3268
0
      x->plane[0].eobs[block] = 0;
3269
0
      x->plane[0].txb_entropy_ctx[block] = 0;
3270
0
      update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
3271
0
    } else {
3272
0
      rd_stats->skip_txfm = 0;
3273
0
      set_blk_skip(txfm_info->blk_skip, 0, blk_row * mi_width + blk_col, 0);
3274
0
    }
3275
0
    if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
3276
0
      rd_stats->rate += x->mode_costs.txfm_partition_cost[ctx][0];
3277
0
    av1_set_txb_context(x, 0, block, tx_size, ta, tl);
3278
0
    txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size,
3279
0
                          tx_size);
3280
0
  } else {
3281
0
    const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
3282
0
    const int txb_width = tx_size_wide_unit[sub_txs];
3283
0
    const int txb_height = tx_size_high_unit[sub_txs];
3284
0
    const int step = txb_height * txb_width;
3285
0
    const int row_end =
3286
0
        AOMMIN(tx_size_high_unit[tx_size], max_blocks_high - blk_row);
3287
0
    const int col_end =
3288
0
        AOMMIN(tx_size_wide_unit[tx_size], max_blocks_wide - blk_col);
3289
0
    RD_STATS pn_rd_stats;
3290
0
    int64_t this_rd = 0;
3291
0
    assert(txb_width > 0 && txb_height > 0);
3292
3293
0
    for (int row = 0; row < row_end; row += txb_height) {
3294
0
      const int offsetr = blk_row + row;
3295
0
      for (int col = 0; col < col_end; col += txb_width) {
3296
0
        const int offsetc = blk_col + col;
3297
3298
0
        av1_init_rd_stats(&pn_rd_stats);
3299
0
        tx_block_yrd(cpi, x, offsetr, offsetc, block, sub_txs, plane_bsize,
3300
0
                     depth + 1, above_ctx, left_ctx, tx_above, tx_left,
3301
0
                     ref_best_rd - this_rd, &pn_rd_stats, ftxs_mode);
3302
0
        if (pn_rd_stats.rate == INT_MAX) {
3303
0
          av1_invalid_rd_stats(rd_stats);
3304
0
          return;
3305
0
        }
3306
0
        av1_merge_rd_stats(rd_stats, &pn_rd_stats);
3307
0
        this_rd += RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist);
3308
0
        block += step;
3309
0
      }
3310
0
    }
3311
3312
0
    if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
3313
0
      rd_stats->rate += x->mode_costs.txfm_partition_cost[ctx][1];
3314
0
  }
3315
0
}
3316
3317
// Search for the tx type, with tx sizes already decided, for an inter-predicted
3318
// luma partition block. It's used only when some speed features are enabled.
3319
// Return value 0: early termination triggered, no valid rd cost available;
3320
//              1: rd cost values are valid.
3321
static int inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
3322
                           RD_STATS *rd_stats, BLOCK_SIZE bsize,
3323
0
                           int64_t ref_best_rd, FAST_TX_SEARCH_MODE ftxs_mode) {
3324
0
  if (ref_best_rd < 0) {
3325
0
    av1_invalid_rd_stats(rd_stats);
3326
0
    return 0;
3327
0
  }
3328
3329
0
  av1_init_rd_stats(rd_stats);
3330
3331
0
  MACROBLOCKD *const xd = &x->e_mbd;
3332
0
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
3333
0
  const struct macroblockd_plane *const pd = &xd->plane[0];
3334
0
  const int mi_width = mi_size_wide[bsize];
3335
0
  const int mi_height = mi_size_high[bsize];
3336
0
  const TX_SIZE max_tx_size = get_vartx_max_txsize(xd, bsize, 0);
3337
0
  const int bh = tx_size_high_unit[max_tx_size];
3338
0
  const int bw = tx_size_wide_unit[max_tx_size];
3339
0
  const int step = bw * bh;
3340
0
  const int init_depth = get_search_init_depth(
3341
0
      mi_width, mi_height, 1, &cpi->sf, txfm_params->tx_size_search_method);
3342
0
  ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
3343
0
  ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
3344
0
  TXFM_CONTEXT tx_above[MAX_MIB_SIZE];
3345
0
  TXFM_CONTEXT tx_left[MAX_MIB_SIZE];
3346
0
  av1_get_entropy_contexts(bsize, pd, ctxa, ctxl);
3347
0
  memcpy(tx_above, xd->above_txfm_context, sizeof(TXFM_CONTEXT) * mi_width);
3348
0
  memcpy(tx_left, xd->left_txfm_context, sizeof(TXFM_CONTEXT) * mi_height);
3349
3350
0
  int64_t this_rd = 0;
3351
0
  for (int idy = 0, block = 0; idy < mi_height; idy += bh) {
3352
0
    for (int idx = 0; idx < mi_width; idx += bw) {
3353
0
      RD_STATS pn_rd_stats;
3354
0
      av1_init_rd_stats(&pn_rd_stats);
3355
0
      tx_block_yrd(cpi, x, idy, idx, block, max_tx_size, bsize, init_depth,
3356
0
                   ctxa, ctxl, tx_above, tx_left, ref_best_rd - this_rd,
3357
0
                   &pn_rd_stats, ftxs_mode);
3358
0
      if (pn_rd_stats.rate == INT_MAX) {
3359
0
        av1_invalid_rd_stats(rd_stats);
3360
0
        return 0;
3361
0
      }
3362
0
      av1_merge_rd_stats(rd_stats, &pn_rd_stats);
3363
0
      this_rd +=
3364
0
          AOMMIN(RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist),
3365
0
                 RDCOST(x->rdmult, pn_rd_stats.zero_rate, pn_rd_stats.sse));
3366
0
      block += step;
3367
0
    }
3368
0
  }
3369
3370
0
  const int skip_ctx = av1_get_skip_txfm_context(xd);
3371
0
  const int no_skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][0];
3372
0
  const int skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][1];
3373
0
  const int64_t skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
3374
0
  this_rd =
3375
0
      RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_rate, rd_stats->dist);
3376
0
  if (skip_txfm_rd < this_rd) {
3377
0
    this_rd = skip_txfm_rd;
3378
0
    rd_stats->rate = 0;
3379
0
    rd_stats->dist = rd_stats->sse;
3380
0
    rd_stats->skip_txfm = 1;
3381
0
  }
3382
3383
0
  const int is_cost_valid = this_rd <= ref_best_rd;
3384
0
  if (!is_cost_valid) {
3385
    // reset cost value
3386
0
    av1_invalid_rd_stats(rd_stats);
3387
0
  }
3388
0
  return is_cost_valid;
3389
0
}
3390
3391
// Search for the best transform size and type for the current inter-predicted
3392
// luma block with recursive transform block partitioning. The obtained
3393
// transform selection will be saved in xd->mi[0], and the corresponding RD
3394
// stats will be saved in rd_stats. The returned value is the resulting RD cost.
3395
static int64_t select_tx_size_and_type(const AV1_COMP *cpi, MACROBLOCK *x,
3396
                                       RD_STATS *rd_stats, BLOCK_SIZE bsize,
3397
0
                                       int64_t ref_best_rd) {
3398
0
  MACROBLOCKD *const xd = &x->e_mbd;
3399
0
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
3400
0
  assert(is_inter_block(xd->mi[0]));
3401
0
  assert(bsize < BLOCK_SIZES_ALL);
3402
0
  const int fast_tx_search = txfm_params->tx_size_search_method > USE_FULL_RD;
3403
0
  int64_t rd_thresh = ref_best_rd;
3404
0
  if (rd_thresh == 0) {
3405
0
    av1_invalid_rd_stats(rd_stats);
3406
0
    return INT64_MAX;
3407
0
  }
3408
0
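  // For fast transform search, relax the reference threshold by 1/8 (the fast
  // search produces coarser RD estimates); the check guards against signed
  // 64-bit overflow.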
  if (fast_tx_search && rd_thresh < INT64_MAX) {
3409
0
    if (INT64_MAX - rd_thresh > (rd_thresh >> 3)) rd_thresh += (rd_thresh >> 3);
3410
0
  }
3411
0
  assert(rd_thresh > 0);
3412
0
  const FAST_TX_SEARCH_MODE ftxs_mode =
3413
0
      fast_tx_search ? FTXS_DCT_AND_1D_DCT_ONLY : FTXS_NONE;
3414
0
  const struct macroblockd_plane *const pd = &xd->plane[0];
3415
0
  assert(bsize < BLOCK_SIZES_ALL);
3416
0
  const int mi_width = mi_size_wide[bsize];
3417
0
  const int mi_height = mi_size_high[bsize];
3418
0
  ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
3419
0
  ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
3420
0
  TXFM_CONTEXT tx_above[MAX_MIB_SIZE];
3421
0
  TXFM_CONTEXT tx_left[MAX_MIB_SIZE];
3422
0
  av1_get_entropy_contexts(bsize, pd, ctxa, ctxl);
3423
0
  memcpy(tx_above, xd->above_txfm_context, sizeof(TXFM_CONTEXT) * mi_width);
3424
0
  memcpy(tx_left, xd->left_txfm_context, sizeof(TXFM_CONTEXT) * mi_height);
3425
0
  const int init_depth = get_search_init_depth(
3426
0
      mi_width, mi_height, 1, &cpi->sf, txfm_params->tx_size_search_method);
3427
0
  const TX_SIZE max_tx_size = max_txsize_rect_lookup[bsize];
3428
0
  const int bh = tx_size_high_unit[max_tx_size];
3429
0
  const int bw = tx_size_wide_unit[max_tx_size];
3430
0
  const int step = bw * bh;
3431
0
  const int skip_ctx = av1_get_skip_txfm_context(xd);
3432
0
  const int no_skip_txfm_cost = x->mode_costs.skip_txfm_cost[skip_ctx][0];
3433
0
  const int skip_txfm_cost = x->mode_costs.skip_txfm_cost[skip_ctx][1];
3434
0
  int64_t skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_cost, 0);
3435
0
  int64_t no_skip_txfm_rd = RDCOST(x->rdmult, no_skip_txfm_cost, 0);
3436
0
  int block = 0;
3437
3438
0
  av1_init_rd_stats(rd_stats);
3439
0
  for (int idy = 0; idy < max_block_high(xd, bsize, 0); idy += bh) {
3440
0
    for (int idx = 0; idx < max_block_wide(xd, bsize, 0); idx += bw) {
3441
0
      const int64_t best_rd_sofar =
3442
0
          (rd_thresh == INT64_MAX)
3443
0
              ? INT64_MAX
3444
0
              : (rd_thresh - (AOMMIN(skip_txfm_rd, no_skip_txfm_rd)));
3445
0
      int is_cost_valid = 1;
3446
0
      RD_STATS pn_rd_stats;
3447
      // Search for the best transform block size and type for the sub-block.
3448
0
      select_tx_block(cpi, x, idy, idx, block, max_tx_size, init_depth, bsize,
3449
0
                      ctxa, ctxl, tx_above, tx_left, &pn_rd_stats, INT64_MAX,
3450
0
                      best_rd_sofar, &is_cost_valid, ftxs_mode);
3451
0
      if (!is_cost_valid || pn_rd_stats.rate == INT_MAX) {
3452
0
        av1_invalid_rd_stats(rd_stats);
3453
0
        return INT64_MAX;
3454
0
      }
3455
0
      av1_merge_rd_stats(rd_stats, &pn_rd_stats);
3456
0
      skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_cost, rd_stats->sse);
3457
0
      no_skip_txfm_rd =
3458
0
          RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_cost, rd_stats->dist);
3459
0
      block += step;
3460
0
    }
3461
0
  }
3462
3463
0
  if (rd_stats->rate == INT_MAX) return INT64_MAX;
3464
3465
0
  rd_stats->skip_txfm = (skip_txfm_rd <= no_skip_txfm_rd);
3466
3467
  // If fast_tx_search is true, only DCT and 1D DCT were tested in the
3468
  // select_tx_block() calls above. Do a better search for the tx type with
3469
  // the tx sizes already decided.
3470
0
  if (fast_tx_search && cpi->sf.tx_sf.refine_fast_tx_search_results) {
3471
0
    if (!inter_block_yrd(cpi, x, rd_stats, bsize, ref_best_rd, FTXS_NONE))
3472
0
      return INT64_MAX;
3473
0
  }
3474
3475
0
  int64_t final_rd;
3476
0
  if (rd_stats->skip_txfm) {
3477
0
    final_rd = RDCOST(x->rdmult, skip_txfm_cost, rd_stats->sse);
3478
0
  } else {
3479
0
    final_rd =
3480
0
        RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_cost, rd_stats->dist);
3481
0
    if (!xd->lossless[xd->mi[0]->segment_id]) {
3482
0
      final_rd =
3483
0
          AOMMIN(final_rd, RDCOST(x->rdmult, skip_txfm_cost, rd_stats->sse));
3484
0
    }
3485
0
  }
3486
3487
0
  return final_rd;
3488
0
}
3489
3490
// Return 1 to terminate the transform search early. The decision is based on
3491
// a comparison between the reference RD cost and the model-estimated RD cost.
3492
static inline int model_based_tx_search_prune(const AV1_COMP *cpi,
3493
                                              MACROBLOCK *x, BLOCK_SIZE bsize,
3494
0
                                              int64_t ref_best_rd) {
3495
0
  const int level = cpi->sf.tx_sf.model_based_prune_tx_search_level;
3496
0
  assert(level >= 0 && level <= 2);
3497
0
  int model_rate;
3498
0
  int64_t model_dist;
3499
0
  uint8_t model_skip;
3500
0
  MACROBLOCKD *const xd = &x->e_mbd;
3501
0
  model_rd_sb_fn[MODELRD_TYPE_TX_SEARCH_PRUNE](
3502
0
      cpi, bsize, x, xd, 0, 0, &model_rate, &model_dist, &model_skip, NULL,
3503
0
      NULL, NULL, NULL);
3504
0
  if (model_skip) return 0;
3505
0
  const int64_t model_rd = RDCOST(x->rdmult, model_rate, model_dist);
3506
  // TODO(debargha, urvang): Improve the model and make the check below
3507
  // tighter.
3508
0
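  // Prune when a fraction of the model-estimated RD (3/8 at level 1, 5/8 at
  // level 2) already exceeds the best RD cost seen so far.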
  static const int prune_factor_by8[] = { 3, 5 };
3509
0
  const int factor = prune_factor_by8[level - 1];
3510
0
  return ((model_rd * factor) >> 3) > ref_best_rd;
3511
0
}
3512
3513
void av1_pick_recursive_tx_size_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
3514
                                         RD_STATS *rd_stats, BLOCK_SIZE bsize,
3515
0
                                         int64_t ref_best_rd) {
3516
0
  MACROBLOCKD *const xd = &x->e_mbd;
3517
0
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
3518
0
  assert(is_inter_block(xd->mi[0]));
3519
3520
0
  av1_invalid_rd_stats(rd_stats);
3521
3522
  // If modeled RD cost is a lot worse than the best so far, terminate early.
3523
0
  if (cpi->sf.tx_sf.model_based_prune_tx_search_level &&
3524
0
      ref_best_rd != INT64_MAX) {
3525
0
    if (model_based_tx_search_prune(cpi, x, bsize, ref_best_rd)) return;
3526
0
  }
3527
3528
  // Hashing based speed feature. If the hash of the prediction residue block is
3529
  // found in the hash table, use previous search results and terminate early.
3530
0
  uint32_t hash = 0;
3531
0
  MB_RD_RECORD *mb_rd_record = NULL;
3532
0
  const int mi_row = x->e_mbd.mi_row;
3533
0
  const int mi_col = x->e_mbd.mi_col;
3534
0
  const int within_border =
3535
0
      mi_row >= xd->tile.mi_row_start &&
3536
0
      (mi_row + mi_size_high[bsize] < xd->tile.mi_row_end) &&
3537
0
      mi_col >= xd->tile.mi_col_start &&
3538
0
      (mi_col + mi_size_wide[bsize] < xd->tile.mi_col_end);
3539
0
  const int is_mb_rd_hash_enabled =
3540
0
      (within_border && cpi->sf.rd_sf.use_mb_rd_hash);
3541
0
  const int n4 = bsize_to_num_blk(bsize);
3542
0
  if (is_mb_rd_hash_enabled) {
3543
0
    hash = get_block_residue_hash(x, bsize);
3544
0
    mb_rd_record = x->txfm_search_info.mb_rd_record;
3545
0
    const int match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash);
3546
0
    if (match_index != -1) {
3547
0
      MB_RD_INFO *mb_rd_info = &mb_rd_record->mb_rd_info[match_index];
3548
0
      fetch_mb_rd_info(n4, mb_rd_info, rd_stats, x);
3549
0
      return;
3550
0
    }
3551
0
  }
3552
3553
  // If we predict that skip is the optimal RD decision, set the respective
3554
  // context and terminate early.
3555
0
  int64_t dist;
3556
0
  if (txfm_params->skip_txfm_level &&
3557
0
      predict_skip_txfm(x, bsize, &dist,
3558
0
                        cpi->common.features.reduced_tx_set_used)) {
3559
0
    set_skip_txfm(x, rd_stats, bsize, dist);
3560
    // Save the RD search results into mb_rd_record.
3561
0
    if (is_mb_rd_hash_enabled)
3562
0
      save_mb_rd_info(n4, hash, x, rd_stats, mb_rd_record);
3563
0
    return;
3564
0
  }
3565
#if CONFIG_SPEED_STATS
3566
  ++x->txfm_search_info.tx_search_count;
3567
#endif  // CONFIG_SPEED_STATS
3568
3569
0
  const int64_t rd =
3570
0
      select_tx_size_and_type(cpi, x, rd_stats, bsize, ref_best_rd);
3571
3572
0
  if (rd == INT64_MAX) {
3573
    // We should always find at least one candidate unless ref_best_rd is less
3574
    // than INT64_MAX (in which case, all the calls to select_tx_size_and_type
3575
    // might have failed to find something better).
3576
0
    assert(ref_best_rd != INT64_MAX);
3577
0
    av1_invalid_rd_stats(rd_stats);
3578
0
    return;
3579
0
  }
3580
3581
  // Save the RD search results into mb_rd_record.
3582
0
  if (is_mb_rd_hash_enabled) {
3583
0
    assert(mb_rd_record != NULL);
3584
0
    save_mb_rd_info(n4, hash, x, rd_stats, mb_rd_record);
3585
0
  }
3586
0
}
3587
3588
void av1_pick_uniform_tx_size_type_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
3589
                                       RD_STATS *rd_stats, BLOCK_SIZE bs,
3590
0
                                       int64_t ref_best_rd) {
3591
0
  MACROBLOCKD *const xd = &x->e_mbd;
3592
0
  MB_MODE_INFO *const mbmi = xd->mi[0];
3593
0
  const TxfmSearchParams *tx_params = &x->txfm_search_params;
3594
0
  assert(bs == mbmi->bsize);
3595
0
  const int is_inter = is_inter_block(mbmi);
3596
0
  const int mi_row = xd->mi_row;
3597
0
  const int mi_col = xd->mi_col;
3598
3599
0
  av1_init_rd_stats(rd_stats);
3600
3601
  // Hashing based speed feature for inter blocks. If the hash of the residue
3602
  // block is found in the table, use previously saved search results and
3603
  // terminate early.
3604
0
  uint32_t hash = 0;
3605
0
  MB_RD_RECORD *mb_rd_record = NULL;
3606
0
  const int num_blks = bsize_to_num_blk(bs);
3607
0
  if (is_inter && cpi->sf.rd_sf.use_mb_rd_hash) {
3608
0
    const int within_border =
3609
0
        mi_row >= xd->tile.mi_row_start &&
3610
0
        (mi_row + mi_size_high[bs] < xd->tile.mi_row_end) &&
3611
0
        mi_col >= xd->tile.mi_col_start &&
3612
0
        (mi_col + mi_size_wide[bs] < xd->tile.mi_col_end);
3613
0
    if (within_border) {
3614
0
      hash = get_block_residue_hash(x, bs);
3615
0
      mb_rd_record = x->txfm_search_info.mb_rd_record;
3616
0
      const int match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash);
3617
0
      if (match_index != -1) {
3618
0
        MB_RD_INFO *mb_rd_info = &mb_rd_record->mb_rd_info[match_index];
3619
0
        fetch_mb_rd_info(num_blks, mb_rd_info, rd_stats, x);
3620
0
        return;
3621
0
      }
3622
0
    }
3623
0
  }
3624
3625
  // If we predict that skip is the optimal RD decision, set the respective
3626
  // context and terminate early.
3627
0
  int64_t dist;
3628
0
  if (tx_params->skip_txfm_level && is_inter &&
3629
0
      !xd->lossless[mbmi->segment_id] &&
3630
0
      predict_skip_txfm(x, bs, &dist,
3631
0
                        cpi->common.features.reduced_tx_set_used)) {
3632
    // Populate rd_stats as per the skip decision.
3633
0
    set_skip_txfm(x, rd_stats, bs, dist);
3634
    // Save the RD search results into mb_rd_record.
3635
0
    if (mb_rd_record) {
3636
0
      save_mb_rd_info(num_blks, hash, x, rd_stats, mb_rd_record);
3637
0
    }
3638
0
    return;
3639
0
  }
3640
3641
0
  if (xd->lossless[mbmi->segment_id]) {
3642
    // Lossless mode can only pick the smallest (4x4) transform size.
3643
0
    choose_smallest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
3644
0
  } else if (tx_params->tx_size_search_method == USE_LARGESTALL) {
3645
0
    choose_largest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
3646
0
  } else {
3647
0
    choose_tx_size_type_from_rd(cpi, x, rd_stats, ref_best_rd, bs);
3648
0
  }
3649
3650
  // Save the RD search results into mb_rd_record for possible reuse in future.
3651
0
  if (mb_rd_record) {
3652
0
    save_mb_rd_info(num_blks, hash, x, rd_stats, mb_rd_record);
3653
0
  }
3654
0
}
3655
3656
int av1_txfm_uvrd(const AV1_COMP *const cpi, MACROBLOCK *x, RD_STATS *rd_stats,
3657
0
                  BLOCK_SIZE bsize, int64_t ref_best_rd) {
3658
0
  av1_init_rd_stats(rd_stats);
3659
0
  if (ref_best_rd < 0) return 0;
3660
0
  if (!x->e_mbd.is_chroma_ref) return 1;
3661
3662
0
  MACROBLOCKD *const xd = &x->e_mbd;
3663
0
  MB_MODE_INFO *const mbmi = xd->mi[0];
3664
0
  struct macroblockd_plane *const pd = &xd->plane[AOM_PLANE_U];
3665
0
  const int is_inter = is_inter_block(mbmi);
3666
0
  int64_t this_rd = 0, skip_txfm_rd = 0;
3667
0
  const BLOCK_SIZE plane_bsize =
3668
0
      get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
3669
3670
0
  if (is_inter) {
3671
0
    for (int plane = 1; plane < MAX_MB_PLANE; ++plane)
3672
0
      av1_subtract_plane(x, plane_bsize, plane);
3673
0
  }
3674
3675
0
  const int skip_trellis = 0;
3676
0
  const TX_SIZE uv_tx_size = av1_get_tx_size(AOM_PLANE_U, xd);
3677
0
  int is_cost_valid = 1;
3678
0
  for (int plane = 1; plane < MAX_MB_PLANE; ++plane) {
3679
0
    RD_STATS this_rd_stats;
3680
0
    int64_t chroma_ref_best_rd = ref_best_rd;
3681
    // For inter blocks, a refined ref_best_rd is used for early exit.
3682
    // For intra blocks, even though the current rd crosses ref_best_rd, early
3683
    // exit is not recommended, as the current rd is also used for gating
3684
    // subsequent modes (say, angular modes).
3685
    // TODO(any): Extend the early exit mechanism for intra modes as well
3686
0
    if (cpi->sf.inter_sf.perform_best_rd_based_gating_for_chroma && is_inter &&
3687
0
        chroma_ref_best_rd != INT64_MAX)
3688
0
      chroma_ref_best_rd = ref_best_rd - AOMMIN(this_rd, skip_txfm_rd);
3689
0
    av1_txfm_rd_in_plane(x, cpi, &this_rd_stats, chroma_ref_best_rd, 0, plane,
3690
0
                         plane_bsize, uv_tx_size, FTXS_NONE, skip_trellis);
3691
0
    if (this_rd_stats.rate == INT_MAX) {
3692
0
      is_cost_valid = 0;
3693
0
      break;
3694
0
    }
3695
0
    av1_merge_rd_stats(rd_stats, &this_rd_stats);
3696
0
    this_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
3697
0
    skip_txfm_rd = RDCOST(x->rdmult, 0, rd_stats->sse);
3698
0
    if (AOMMIN(this_rd, skip_txfm_rd) > ref_best_rd) {
3699
0
      is_cost_valid = 0;
3700
0
      break;
3701
0
    }
3702
0
  }
3703
3704
0
  if (!is_cost_valid) {
3705
    // Reset the cost value.
3706
0
    av1_invalid_rd_stats(rd_stats);
3707
0
  }
3708
3709
0
  return is_cost_valid;
3710
0
}
3711
3712
void av1_txfm_rd_in_plane(MACROBLOCK *x, const AV1_COMP *cpi,
3713
                          RD_STATS *rd_stats, int64_t ref_best_rd,
3714
                          int64_t current_rd, int plane, BLOCK_SIZE plane_bsize,
3715
                          TX_SIZE tx_size, FAST_TX_SEARCH_MODE ftxs_mode,
3716
0
                          int skip_trellis) {
3717
0
  assert(IMPLIES(plane == 0, x->e_mbd.mi[0]->tx_size == tx_size));
3718
3719
0
  if (!cpi->oxcf.txfm_cfg.enable_tx64 &&
3720
0
      txsize_sqr_up_map[tx_size] == TX_64X64) {
3721
0
    av1_invalid_rd_stats(rd_stats);
3722
0
    return;
3723
0
  }
3724
3725
0
  if (current_rd > ref_best_rd) {
3726
0
    av1_invalid_rd_stats(rd_stats);
3727
0
    return;
3728
0
  }
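  // Both guards above return invalid stats rather than a large cost: a
  // pruned evaluation must be distinguishable from an expensive-but-valid
  // one, and callers check rd_stats->rate == INT_MAX for exactly that.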
3729
3730
0
  MACROBLOCKD *const xd = &x->e_mbd;
3731
0
  const struct macroblockd_plane *const pd = &xd->plane[plane];
3732
0
  struct rdcost_block_args args;
3733
0
  av1_zero(args);
3734
0
  args.x = x;
3735
0
  args.cpi = cpi;
3736
0
  args.best_rd = ref_best_rd;
3737
0
  args.current_rd = current_rd;
3738
0
  args.ftxs_mode = ftxs_mode;
3739
0
  args.skip_trellis = skip_trellis;
3740
0
  av1_init_rd_stats(&args.rd_stats);
3741
3742
0
  av1_get_entropy_contexts(plane_bsize, pd, args.t_above, args.t_left);
3743
0
  av1_foreach_transformed_block_in_plane(xd, plane_bsize, plane, block_rd_txfm,
3744
0
                                         &args);
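  // block_rd_txfm() (defined earlier in this file) runs once per transform
  // block in the plane; it is expected to accumulate rate/distortion into
  // args.rd_stats and to raise args.exit_early or args.incomplete_exit once
  // the running cost can no longer beat args.best_rd. Those flags are
  // consumed just below.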
3745
3746
0
  MB_MODE_INFO *const mbmi = xd->mi[0];
3747
0
  const int is_inter = is_inter_block(mbmi);
3748
0
  const int invalid_rd = is_inter ? args.incomplete_exit : args.exit_early;
3749
3750
0
  if (invalid_rd) {
3751
0
    av1_invalid_rd_stats(rd_stats);
3752
0
  } else {
3753
0
    *rd_stats = args.rd_stats;
3754
0
  }
3755
0
}
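// A hypothetical call for the luma plane (an illustration, not a call site
// in this file; names mirror the parameters above): evaluate the block at
// its current tx_size with no accumulated cost and trellis enabled:
//   RD_STATS rd_stats;
//   av1_txfm_rd_in_plane(x, cpi, &rd_stats, ref_best_rd, /*current_rd=*/0,
//                        AOM_PLANE_Y, bsize, x->e_mbd.mi[0]->tx_size,
//                        FTXS_NONE, /*skip_trellis=*/0);
//   if (rd_stats.rate == INT_MAX) { /* pruned or disallowed; skip mode */ }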
3756
3757
int av1_txfm_search(const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bsize,
3758
                    RD_STATS *rd_stats, RD_STATS *rd_stats_y,
3759
0
                    RD_STATS *rd_stats_uv, int mode_rate, int64_t ref_best_rd) {
3760
0
  MACROBLOCKD *const xd = &x->e_mbd;
3761
0
  TxfmSearchParams *txfm_params = &x->txfm_search_params;
3762
0
  const int skip_ctx = av1_get_skip_txfm_context(xd);
3763
0
  const int skip_txfm_cost[2] = { x->mode_costs.skip_txfm_cost[skip_ctx][0],
3764
0
                                  x->mode_costs.skip_txfm_cost[skip_ctx][1] };
3765
0
  const int64_t min_header_rate =
3766
0
      mode_rate + AOMMIN(skip_txfm_cost[0], skip_txfm_cost[1]);
3767
  // Account for the minimum of the skip and non-skip rd costs: eventually,
3768
  // one of the two will be added to mode_rate.
3769
0
  const int64_t min_header_rd_possible = RDCOST(x->rdmult, min_header_rate, 0);
3770
0
  if (min_header_rd_possible > ref_best_rd) {
3771
0
    av1_invalid_rd_stats(rd_stats_y);
3772
0
    return 0;
3773
0
  }
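  // Worked example of the pruning above (made-up numbers, RDCOST as defined
  // in av1/encoder/rd.h): mode_rate = 600 and skip_txfm_cost = {60, 30}
  // give min_header_rate = 630; with x->rdmult = 128 the header-only cost is
  //   RDCOST(128, 630, 0) = ((630 * 128 + 256) >> 9) + 0 = 158,
  // so any ref_best_rd below 158 rejects the mode before the transform
  // search starts.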
3774
3775
0
  const AV1_COMMON *cm = &cpi->common;
3776
0
  MB_MODE_INFO *const mbmi = xd->mi[0];
3777
0
  const int64_t mode_rd = RDCOST(x->rdmult, mode_rate, 0);
3778
0
  const int64_t rd_thresh =
3779
0
      ref_best_rd == INT64_MAX ? INT64_MAX : ref_best_rd - mode_rd;
3780
0
  av1_init_rd_stats(rd_stats);
3781
0
  av1_init_rd_stats(rd_stats_y);
3782
0
  rd_stats->rate = mode_rate;
3783
3784
  // Compute the luma cost and distortion.
3785
0
  av1_subtract_plane(x, bsize, 0);
3786
0
  if (txfm_params->tx_mode_search_type == TX_MODE_SELECT &&
3787
0
      !xd->lossless[mbmi->segment_id]) {
3788
0
    av1_pick_recursive_tx_size_type_yrd(cpi, x, rd_stats_y, bsize, rd_thresh);
3789
#if CONFIG_COLLECT_RD_STATS == 2
3790
    PrintPredictionUnitStats(cpi, tile_data, x, rd_stats_y, bsize);
3791
#endif  // CONFIG_COLLECT_RD_STATS == 2
3792
0
  } else {
3793
0
    av1_pick_uniform_tx_size_type_yrd(cpi, x, rd_stats_y, bsize, rd_thresh);
3794
0
    memset(mbmi->inter_tx_size, mbmi->tx_size, sizeof(mbmi->inter_tx_size));
3795
0
    for (int i = 0; i < xd->height * xd->width; ++i)
3796
0
      set_blk_skip(x->txfm_search_info.blk_skip, 0, i, rd_stats_y->skip_txfm);
3797
0
  }
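  // Note on the memset above: it writes the single byte mbmi->tx_size into
  // every entry of inter_tx_size, which relies on TX_SIZE being a one-byte
  // packed enum; a wider element type would need an explicit loop.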
3798
3799
0
  if (rd_stats_y->rate == INT_MAX) return 0;
3800
3801
0
  av1_merge_rd_stats(rd_stats, rd_stats_y);
3802
3803
0
  const int64_t non_skip_txfm_rdcosty =
3804
0
      RDCOST(x->rdmult, rd_stats->rate + skip_txfm_cost[0], rd_stats->dist);
3805
0
  const int64_t skip_txfm_rdcosty =
3806
0
      RDCOST(x->rdmult, mode_rate + skip_txfm_cost[1], rd_stats->sse);
3807
0
  const int64_t min_rdcosty = AOMMIN(non_skip_txfm_rdcosty, skip_txfm_rdcosty);
3808
0
  if (min_rdcosty > ref_best_rd) return 0;
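  // min_rdcosty is a valid lower bound on the final cost: whichever way the
  // skip decision goes, only nonnegative chroma terms are added afterwards,
  // so exceeding ref_best_rd at this point is already conclusive.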
3809
3810
0
  av1_init_rd_stats(rd_stats_uv);
3811
0
  const int num_planes = av1_num_planes(cm);
3812
0
  if (num_planes > 1) {
3813
0
    int64_t ref_best_chroma_rd = ref_best_rd;
3814
    // Calculate the best rd cost possible for chroma.
3815
0
    if (cpi->sf.inter_sf.perform_best_rd_based_gating_for_chroma &&
3816
0
        (ref_best_chroma_rd != INT64_MAX)) {
3817
0
      ref_best_chroma_rd = (ref_best_chroma_rd -
3818
0
                            AOMMIN(non_skip_txfm_rdcosty, skip_txfm_rdcosty));
3819
0
    }
3820
0
    const int is_cost_valid_uv =
3821
0
        av1_txfm_uvrd(cpi, x, rd_stats_uv, bsize, ref_best_chroma_rd);
3822
0
    if (!is_cost_valid_uv) return 0;
3823
0
    av1_merge_rd_stats(rd_stats, rd_stats_uv);
3824
0
  }
3825
3826
0
  int choose_skip_txfm = rd_stats->skip_txfm;
3827
0
  if (!choose_skip_txfm && !xd->lossless[mbmi->segment_id]) {
3828
0
    const int64_t rdcost_no_skip_txfm = RDCOST(
3829
0
        x->rdmult, rd_stats_y->rate + rd_stats_uv->rate + skip_txfm_cost[0],
3830
0
        rd_stats->dist);
3831
0
    const int64_t rdcost_skip_txfm =
3832
0
        RDCOST(x->rdmult, skip_txfm_cost[1], rd_stats->sse);
3833
0
    if (rdcost_no_skip_txfm >= rdcost_skip_txfm) choose_skip_txfm = 1;
3834
0
  }
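  // Worked example of the comparison above (made-up numbers, RDCOST as in
  // av1/encoder/rd.h): rdmult = 128, luma+chroma rate = 800,
  // skip_txfm_cost = {40, 25}, dist = 120000, sse = 150000:
  //   no-skip: ((840 * 128 + 256) >> 9) + (120000 << 7) = 210 + 15360000
  //   skip:    ((25 * 128 + 256) >> 9) + (150000 << 7)  = 6 + 19200000
  // so coding the coefficients wins and choose_skip_txfm stays 0.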
3835
0
  if (choose_skip_txfm) {
3836
0
    rd_stats_y->rate = 0;
3837
0
    rd_stats_uv->rate = 0;
3838
0
    rd_stats->rate = mode_rate + skip_txfm_cost[1];
3839
0
    rd_stats->dist = rd_stats->sse;
3840
0
    rd_stats_y->dist = rd_stats_y->sse;
3841
0
    rd_stats_uv->dist = rd_stats_uv->sse;
3842
0
    mbmi->skip_txfm = 1;
3843
0
    if (rd_stats->skip_txfm) {
3844
0
      const int64_t tmprd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
3845
0
      if (tmprd > ref_best_rd) return 0;
3846
0
    }
3847
0
  } else {
3848
0
    rd_stats->rate += skip_txfm_cost[0];
3849
0
    mbmi->skip_txfm = 0;
3850
0
  }
3851
3852
0
  return 1;
3853
0
}