Coverage Report

Created: 2026-06-16 07:20

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/aom/av1/encoder/tx_search.c
Line
Count
Source
1
/*
2
 * Copyright (c) 2020, Alliance for Open Media. All rights reserved.
3
 *
4
 * This source code is subject to the terms of the BSD 2 Clause License and
5
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6
 * was not distributed with this source code in the LICENSE file, you can
7
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8
 * Media Patent License 1.0 was not distributed with this source code in the
9
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10
 */
11
12
#include <inttypes.h>
13
14
#include "av1/common/cfl.h"
15
#include "av1/common/reconintra.h"
16
#include "av1/encoder/block.h"
17
#include "av1/encoder/hybrid_fwd_txfm.h"
18
#include "av1/common/idct.h"
19
#include "av1/encoder/model_rd.h"
20
#include "av1/encoder/random.h"
21
#include "av1/encoder/rdopt_utils.h"
22
#include "av1/encoder/sorting_network.h"
23
#include "av1/encoder/tx_prune_model_weights.h"
24
#include "av1/encoder/tx_search.h"
25
#include "av1/encoder/txb_rdopt.h"
26
27
0
#define PROB_THRESH_OFFSET_TX_TYPE 100
28
29
struct rdcost_block_args {
30
  const AV1_COMP *cpi;
31
  MACROBLOCK *x;
32
  ENTROPY_CONTEXT t_above[MAX_MIB_SIZE];
33
  ENTROPY_CONTEXT t_left[MAX_MIB_SIZE];
34
  RD_STATS rd_stats;
35
  int64_t current_rd;
36
  int64_t best_rd;
37
  int exit_early;
38
  int incomplete_exit;
39
  FAST_TX_SEARCH_MODE ftxs_mode;
40
  int skip_trellis;
41
};
42
43
typedef struct {
44
  int64_t rd;
45
  int txb_entropy_ctx;
46
  TX_TYPE tx_type;
47
} TxCandidateInfo;
48
49
// origin_threshold * 128 / 100
50
static const uint32_t skip_pred_threshold[3][BLOCK_SIZES_ALL] = {
51
  {
52
      64, 64, 64, 70, 60, 60, 68, 68, 68, 68, 68,
53
      68, 68, 68, 68, 68, 64, 64, 70, 70, 68, 68,
54
  },
55
  {
56
      88, 88, 88, 86, 87, 87, 68, 68, 68, 68, 68,
57
      68, 68, 68, 68, 68, 88, 88, 86, 86, 68, 68,
58
  },
59
  {
60
      90, 93, 93, 90, 93, 93, 74, 74, 74, 74, 74,
61
      74, 74, 74, 74, 74, 90, 90, 90, 90, 74, 74,
62
  },
63
};
64
65
// lookup table for predict_skip_txfm
66
// int max_tx_size = max_txsize_rect_lookup[bsize];
67
// if (tx_size_high[max_tx_size] > 16 || tx_size_wide[max_tx_size] > 16)
68
//   max_tx_size = AOMMIN(max_txsize_lookup[bsize], TX_16X16);
69
static const TX_SIZE max_predict_sf_tx_size[BLOCK_SIZES_ALL] = {
70
  TX_4X4,   TX_4X8,   TX_8X4,   TX_8X8,   TX_8X16,  TX_16X8,
71
  TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16,
72
  TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_4X16,  TX_16X4,
73
  TX_8X8,   TX_8X8,   TX_16X16, TX_16X16,
74
};
75
76
// look-up table for sqrt of number of pixels in a transform block
77
// rounded up to the nearest integer.
78
static const int sqrt_tx_pixels_2d[TX_SIZES_ALL] = { 4,  8,  16, 32, 32, 6,  6,
79
                                                     12, 12, 23, 23, 32, 32, 8,
80
                                                     8,  16, 16, 23, 23 };
81
82
// Number of best no-split RD costs that are tracked, selected by the
// prune_inter_tx_split_rd_eval_lvl speed feature (index = level - 1).
static const int num_inter_tx_no_split_cand[2] = {
  4,  // level 1
  3,  // level 2
};
85
86
0
// Hashes the luma residual (x->plane[0].src_diff) of a block.
// The CRC32C of the first 2 * height * width bytes of the residual buffer is
// shifted left by 5 bits and the block size is added, so the block size is
// folded into the returned key.
static inline uint32_t get_block_residue_hash(MACROBLOCK *x, BLOCK_SIZE bsize) {
  const int16_t *const residual = x->plane[0].src_diff;
  const int height = block_size_high[bsize];
  const int width = block_size_wide[bsize];
  // Each residual sample is an int16_t, i.e. two bytes.
  const uint32_t crc = av1_get_crc32c_value(
      &x->txfm_search_info.mb_rd_record->crc_calculator, (uint8_t *)residual,
      2 * height * width);
  return (crc << 5) + bsize;
}
95
96
static inline int32_t find_mb_rd_info(const MB_RD_RECORD *const mb_rd_record,
97
                                      const int64_t ref_best_rd,
98
0
                                      const uint32_t hash) {
99
0
  int32_t match_index = -1;
100
0
  if (ref_best_rd != INT64_MAX) {
101
0
    for (int i = 0; i < mb_rd_record->num; ++i) {
102
0
      const int index = (mb_rd_record->index_start + i) % RD_RECORD_BUFFER_LEN;
103
      // If there is a match in the mb_rd_record, fetch the RD decision and
104
      // terminate early.
105
0
      if (mb_rd_record->mb_rd_info[index].hash_value == hash) {
106
0
        match_index = index;
107
0
        break;
108
0
      }
109
0
    }
110
0
  }
111
0
  return match_index;
112
0
}
113
114
static inline void fetch_mb_rd_info(int n4, const MB_RD_INFO *const mb_rd_info,
115
                                    RD_STATS *const rd_stats,
116
0
                                    MACROBLOCK *const x) {
117
0
  MACROBLOCKD *const xd = &x->e_mbd;
118
0
  MB_MODE_INFO *const mbmi = xd->mi[0];
119
0
  mbmi->tx_size = mb_rd_info->tx_size;
120
0
  av1_copy(mbmi->inter_tx_size, mb_rd_info->inter_tx_size);
121
0
  av1_copy_array(xd->tx_type_map, mb_rd_info->tx_type_map, n4);
122
0
  *rd_stats = mb_rd_info->rd_stats;
123
0
}
124
125
int64_t av1_pixel_diff_dist(const MACROBLOCK *x, int plane, int blk_row,
126
                            int blk_col, const BLOCK_SIZE plane_bsize,
127
                            const BLOCK_SIZE tx_bsize,
128
0
                            unsigned int *block_mse_q8) {
129
0
  int visible_rows, visible_cols;
130
0
  const MACROBLOCKD *xd = &x->e_mbd;
131
0
  get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize, NULL,
132
0
                     NULL, &visible_cols, &visible_rows);
133
0
  const int diff_stride = block_size_wide[plane_bsize];
134
0
  const int16_t *diff = x->plane[plane].src_diff;
135
136
0
  diff += ((blk_row * diff_stride + blk_col) << MI_SIZE_LOG2);
137
0
  uint64_t sse =
138
0
      aom_sum_squares_2d_i16(diff, diff_stride, visible_cols, visible_rows);
139
0
  if (block_mse_q8 != NULL) {
140
0
    if (visible_cols > 0 && visible_rows > 0)
141
0
      *block_mse_q8 =
142
0
          (unsigned int)((256 * sse) / (visible_cols * visible_rows));
143
0
    else
144
0
      *block_mse_q8 = UINT_MAX;
145
0
  }
146
0
  return sse;
147
0
}
148
149
// Computes the residual block's SSE and mean on all visible 4x4s in the
150
// transform block
151
static inline int64_t pixel_diff_stats(
152
    MACROBLOCK *x, int plane, int blk_row, int blk_col,
153
    const BLOCK_SIZE plane_bsize, const BLOCK_SIZE tx_bsize,
154
0
    unsigned int *block_mse_q8, int64_t *per_px_mean, uint64_t *block_var) {
155
0
  int visible_rows, visible_cols;
156
0
  const MACROBLOCKD *xd = &x->e_mbd;
157
0
  get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize, NULL,
158
0
                     NULL, &visible_cols, &visible_rows);
159
0
  const int diff_stride = block_size_wide[plane_bsize];
160
0
  const int16_t *diff = x->plane[plane].src_diff;
161
162
0
  diff += ((blk_row * diff_stride + blk_col) << MI_SIZE_LOG2);
163
0
  uint64_t sse = 0;
164
0
  int sum = 0;
165
0
  sse = aom_sum_sse_2d_i16(diff, diff_stride, visible_cols, visible_rows, &sum);
166
0
  if (visible_cols > 0 && visible_rows > 0) {
167
0
    double norm_factor = 1.0 / (visible_cols * visible_rows);
168
0
    int sign_sum = sum > 0 ? 1 : -1;
169
    // Conversion to transform domain
170
0
    *per_px_mean = (int64_t)(norm_factor * abs(sum)) << 7;
171
0
    *per_px_mean = sign_sum * (*per_px_mean);
172
0
    *block_mse_q8 = (unsigned int)(norm_factor * (256 * sse));
173
0
    *block_var = (uint64_t)(sse - (uint64_t)(norm_factor * sum * sum));
174
0
  } else {
175
0
    *block_mse_q8 = UINT_MAX;
176
0
  }
177
0
  return sse;
178
0
}
179
180
// Uses simple features on top of DCT coefficients to quickly predict
181
// whether optimal RD decision is to skip encoding the residual.
182
// The sse value is stored in dist.
183
static int predict_skip_txfm(MACROBLOCK *x, BLOCK_SIZE bsize, int64_t *dist,
184
0
                             int reduced_tx_set) {
185
0
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
186
0
  const int bw = block_size_wide[bsize];
187
0
  const int bh = block_size_high[bsize];
188
0
  const MACROBLOCKD *xd = &x->e_mbd;
189
0
  const int16_t dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd);
190
191
0
  *dist = av1_pixel_diff_dist(x, 0, 0, 0, bsize, bsize, NULL);
192
193
0
  const int64_t mse = *dist / bw / bh;
194
  // Normalized quantizer takes the transform upscaling factor (8 for tx size
195
  // smaller than 32) into account.
196
0
  const int16_t normalized_dc_q = dc_q >> 3;
197
0
  const int64_t mse_thresh = (int64_t)normalized_dc_q * normalized_dc_q / 8;
198
  // For faster early skip decision, use dist to compare against threshold so
199
  // that quality risk is less for the skip=1 decision. Otherwise, use mse
200
  // since the fwd_txfm coeff checks will take care of quality
201
  // TODO(any): Use dist to return 0 when skip_txfm_level is 1
202
0
  int64_t pred_err = (txfm_params->skip_txfm_level >= 2) ? *dist : mse;
203
  // Predict not to skip when error is larger than threshold.
204
0
  if (pred_err > mse_thresh) return 0;
205
  // Return as skip otherwise for aggressive early skip
206
0
  else if (txfm_params->skip_txfm_level >= 2)
207
0
    return 1;
208
209
0
  const int max_tx_size = max_predict_sf_tx_size[bsize];
210
0
  const int tx_h = tx_size_high[max_tx_size];
211
0
  const int tx_w = tx_size_wide[max_tx_size];
212
0
  DECLARE_ALIGNED(32, tran_low_t, coefs[32 * 32]);
213
0
  TxfmParam param;
214
0
  param.tx_type = DCT_DCT;
215
0
  param.tx_size = max_tx_size;
216
0
  param.bd = xd->bd;
217
0
  param.is_hbd = is_cur_buf_hbd(xd);
218
0
  param.lossless = 0;
219
0
  param.tx_set_type = av1_get_ext_tx_set_type(
220
0
      param.tx_size, is_inter_block(xd->mi[0]), reduced_tx_set);
221
0
  const int bd_idx = (xd->bd == 8) ? 0 : ((xd->bd == 10) ? 1 : 2);
222
0
  const uint32_t max_qcoef_thresh = skip_pred_threshold[bd_idx][bsize];
223
0
  const int16_t *src_diff = x->plane[0].src_diff;
224
0
  const int n_coeff = tx_w * tx_h;
225
0
  const int16_t ac_q = av1_ac_quant_QTX(x->qindex, 0, xd->bd);
226
0
  const uint32_t dc_thresh = max_qcoef_thresh * dc_q;
227
0
  const uint32_t ac_thresh = max_qcoef_thresh * ac_q;
228
0
  for (int row = 0; row < bh; row += tx_h) {
229
0
    for (int col = 0; col < bw; col += tx_w) {
230
0
      av1_fwd_txfm(src_diff + col, coefs, bw, &param);
231
      // Operating on TX domain, not pixels; we want the QTX quantizers
232
0
      const uint32_t dc_coef = (((uint32_t)abs(coefs[0])) << 7);
233
0
      if (dc_coef >= dc_thresh) return 0;
234
0
      for (int i = 1; i < n_coeff; ++i) {
235
0
        const uint32_t ac_coef = (((uint32_t)abs(coefs[i])) << 7);
236
0
        if (ac_coef >= ac_thresh) return 0;
237
0
      }
238
0
    }
239
0
    src_diff += tx_h * bw;
240
0
  }
241
0
  return 1;
242
0
}
243
244
// Used to set proper context for early termination with skip = 1.
245
static inline void set_skip_txfm(MACROBLOCK *x, RD_STATS *rd_stats,
246
0
                                 BLOCK_SIZE bsize, int64_t dist) {
247
0
  MACROBLOCKD *const xd = &x->e_mbd;
248
0
  MB_MODE_INFO *const mbmi = xd->mi[0];
249
0
  const int n4 = bsize_to_num_blk(bsize);
250
0
  const TX_SIZE tx_size = max_txsize_rect_lookup[bsize];
251
0
  memset(xd->tx_type_map, DCT_DCT, sizeof(xd->tx_type_map[0]) * n4);
252
0
  memset(mbmi->inter_tx_size, tx_size, sizeof(mbmi->inter_tx_size));
253
0
  mbmi->tx_size = tx_size;
254
0
  rd_stats->skip_txfm = 1;
255
0
  if (is_cur_buf_hbd(xd)) dist = ROUND_POWER_OF_TWO(dist, (xd->bd - 8) * 2);
256
0
  rd_stats->dist = rd_stats->sse = (dist << 4);
257
  // Though decision is to make the block as skip based on luma stats,
258
  // it is possible that block becomes non skip after chroma rd. In addition
259
  // intermediate non skip costs calculated by caller function will be
260
  // incorrect, if rate is set as  zero (i.e., if zero_blk_rate is not
261
  // accounted). Hence intermediate rate is populated to code the luma tx blks
262
  // as skip, the caller function based on final rd decision (i.e., skip vs
263
  // non-skip) sets the final rate accordingly. Here the rate populated
264
  // corresponds to coding all the tx blocks with zero_blk_rate (based on max tx
265
  // size possible) in the current block. Eg: For 128*128 block, rate would be
266
  // 4 * zero_blk_rate where zero_blk_rate corresponds to coding of one 64x64 tx
267
  // block as 'all zeros'
268
0
  ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
269
0
  ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
270
0
  av1_get_entropy_contexts(bsize, &xd->plane[0], ctxa, ctxl);
271
0
  ENTROPY_CONTEXT *ta = ctxa;
272
0
  ENTROPY_CONTEXT *tl = ctxl;
273
0
  const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
274
0
  TXB_CTX txb_ctx;
275
0
  get_txb_ctx(bsize, tx_size, 0, ta, tl, &txb_ctx);
276
0
  const int zero_blk_rate = x->coeff_costs.coeff_costs[txs_ctx][PLANE_TYPE_Y]
277
0
                                .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
278
0
  rd_stats->rate = zero_blk_rate *
279
0
                   (block_size_wide[bsize] >> tx_size_wide_log2[tx_size]) *
280
0
                   (block_size_high[bsize] >> tx_size_high_log2[tx_size]);
281
0
}
282
283
static inline void save_mb_rd_info(int n4, uint32_t hash,
284
                                   const MACROBLOCK *const x,
285
                                   const RD_STATS *const rd_stats,
286
0
                                   MB_RD_RECORD *mb_rd_record) {
287
0
  int index;
288
0
  if (mb_rd_record->num < RD_RECORD_BUFFER_LEN) {
289
0
    index =
290
0
        (mb_rd_record->index_start + mb_rd_record->num) % RD_RECORD_BUFFER_LEN;
291
0
    ++mb_rd_record->num;
292
0
  } else {
293
0
    index = mb_rd_record->index_start;
294
0
    mb_rd_record->index_start =
295
0
        (mb_rd_record->index_start + 1) % RD_RECORD_BUFFER_LEN;
296
0
  }
297
0
  MB_RD_INFO *const mb_rd_info = &mb_rd_record->mb_rd_info[index];
298
0
  const MACROBLOCKD *const xd = &x->e_mbd;
299
0
  const MB_MODE_INFO *const mbmi = xd->mi[0];
300
0
  mb_rd_info->hash_value = hash;
301
0
  mb_rd_info->tx_size = mbmi->tx_size;
302
0
  av1_copy(mb_rd_info->inter_tx_size, mbmi->inter_tx_size);
303
0
  av1_copy_array(mb_rd_info->tx_type_map, xd->tx_type_map, n4);
304
0
  mb_rd_info->rd_stats = *rd_stats;
305
0
}
306
307
// Store the RD Cost of transform no-split.
308
static inline void push_inter_block_tx_no_split_rd(
309
    MACROBLOCK *x, const MB_MODE_INFO *mbmi, int64_t tmp_rd, int blk_idx,
310
0
    int prune_inter_tx_split_rd_eval_lvl) {
311
0
  assert(blk_idx < MAX_TX_BLOCKS_IN_MAX_SB);
312
0
  if (!prune_inter_tx_split_rd_eval_lvl) return;
313
314
0
  if (blk_idx == -1 || tmp_rd == INT64_MAX) return;
315
316
  // Do not store for skip and intraBC modes
317
0
  if (mbmi->skip_mode != 0 || is_intrabc_block(mbmi)) return;
318
319
0
  assert(prune_inter_tx_split_rd_eval_lvl <= 2);
320
0
  const int num_top_cand =
321
0
      num_inter_tx_no_split_cand[prune_inter_tx_split_rd_eval_lvl - 1];
322
0
  assert(num_top_cand <= TOP_INTER_TX_NO_SPLIT_COUNT);
323
324
  // Insert the RD Cost in sorted order
325
0
  for (int i = 0; i < num_top_cand; i++) {
326
0
    if (tmp_rd < x->top_inter_tx_no_split_rd[blk_idx][i]) {
327
0
      for (int j = num_top_cand - 1; j > i; j--) {
328
0
        x->top_inter_tx_no_split_rd[blk_idx][j] =
329
0
            x->top_inter_tx_no_split_rd[blk_idx][j - 1];
330
0
      }
331
0
      x->top_inter_tx_no_split_rd[blk_idx][i] = tmp_rd;
332
0
      break;
333
0
    }
334
0
  }
335
0
}
336
337
// Prune the evaluation of transform split.
338
static inline bool prune_tx_split_eval_using_no_split_rd(
339
    const MACROBLOCK *x, const MB_MODE_INFO *mbmi, int64_t tmp_rd, int blk_idx,
340
0
    int prune_inter_tx_split_rd_eval_lvl) {
341
0
  if (!prune_inter_tx_split_rd_eval_lvl) return false;
342
343
0
  if (blk_idx == -1 || tmp_rd == INT64_MAX) return false;
344
345
  // Do not prune for skip and intraBC modes
346
0
  if (mbmi->skip_mode != 0 || is_intrabc_block(mbmi)) return false;
347
348
0
  assert(prune_inter_tx_split_rd_eval_lvl <= 2);
349
0
  const int num_top_cand =
350
0
      num_inter_tx_no_split_cand[prune_inter_tx_split_rd_eval_lvl - 1];
351
0
  assert(num_top_cand <= TOP_INTER_TX_NO_SPLIT_COUNT);
352
353
  // Do not prune if there is no valid top RD Cost for comparison
354
0
  if (x->top_inter_tx_no_split_rd[blk_idx][num_top_cand - 1] == INT64_MAX)
355
0
    return false;
356
357
0
  if (tmp_rd > x->top_inter_tx_no_split_rd[blk_idx][num_top_cand - 1])
358
0
    return true;
359
360
0
  return false;
361
0
}
362
363
static int get_search_init_depth(int mi_width, int mi_height, int is_inter,
364
                                 const SPEED_FEATURES *sf,
365
0
                                 int tx_size_search_method) {
366
0
  if (tx_size_search_method == USE_LARGESTALL) return MAX_VARTX_DEPTH;
367
368
0
  if (sf->tx_sf.tx_size_search_lgr_block) {
369
0
    if (mi_width > mi_size_wide[BLOCK_64X64] ||
370
0
        mi_height > mi_size_high[BLOCK_64X64])
371
0
      return MAX_VARTX_DEPTH;
372
0
  }
373
374
0
  if (is_inter) {
375
0
    return (mi_height != mi_width)
376
0
               ? sf->tx_sf.inter_tx_size_search_init_depth_rect
377
0
               : sf->tx_sf.inter_tx_size_search_init_depth_sqr;
378
0
  } else {
379
0
    return (mi_height != mi_width)
380
0
               ? sf->tx_sf.intra_tx_size_search_init_depth_rect
381
0
               : sf->tx_sf.intra_tx_size_search_init_depth_sqr;
382
0
  }
383
0
}
384
385
static inline void select_tx_block(
386
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
387
    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
388
    ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
389
    RD_STATS *rd_stats, int64_t prev_level_rd, int64_t ref_best_rd,
390
    int *is_cost_valid, FAST_TX_SEARCH_MODE ftxs_mode, int blk_idx);
391
392
// NOTE: CONFIG_COLLECT_RD_STATS has 3 possible values
393
// 0: Do not collect any RD stats
394
// 1: Collect RD stats for transform units
395
// 2: Collect RD stats for partition units
396
#if CONFIG_COLLECT_RD_STATS
397
398
static inline void get_energy_distribution_fine(
399
    const AV1_COMP *cpi, BLOCK_SIZE bsize, const uint8_t *src, int src_stride,
400
    const uint8_t *dst, int dst_stride, int need_4th, double *hordist,
401
    double *verdist) {
402
  const int bw = block_size_wide[bsize];
403
  const int bh = block_size_high[bsize];
404
  unsigned int esq[16] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
405
406
  if (bsize < BLOCK_16X16 || (bsize >= BLOCK_4X16 && bsize <= BLOCK_32X8)) {
407
    // Special cases: calculate 'esq' values manually, as we don't have 'vf'
408
    // functions for the 16 (very small) sub-blocks of this block.
409
    const int w_shift = (bw == 4) ? 0 : (bw == 8) ? 1 : (bw == 16) ? 2 : 3;
410
    const int h_shift = (bh == 4) ? 0 : (bh == 8) ? 1 : (bh == 16) ? 2 : 3;
411
    assert(bw <= 32);
412
    assert(bh <= 32);
413
    assert(((bw - 1) >> w_shift) + (((bh - 1) >> h_shift) << 2) == 15);
414
    if (cpi->common.seq_params->use_highbitdepth) {
415
      const uint16_t *src16 = CONVERT_TO_SHORTPTR(src);
416
      const uint16_t *dst16 = CONVERT_TO_SHORTPTR(dst);
417
      for (int i = 0; i < bh; ++i)
418
        for (int j = 0; j < bw; ++j) {
419
          const int index = (j >> w_shift) + ((i >> h_shift) << 2);
420
          esq[index] +=
421
              (src16[j + i * src_stride] - dst16[j + i * dst_stride]) *
422
              (src16[j + i * src_stride] - dst16[j + i * dst_stride]);
423
        }
424
    } else {
425
      for (int i = 0; i < bh; ++i)
426
        for (int j = 0; j < bw; ++j) {
427
          const int index = (j >> w_shift) + ((i >> h_shift) << 2);
428
          esq[index] += (src[j + i * src_stride] - dst[j + i * dst_stride]) *
429
                        (src[j + i * src_stride] - dst[j + i * dst_stride]);
430
        }
431
    }
432
  } else {  // Calculate 'esq' values using 'vf' functions on the 16 sub-blocks.
433
    const int f_index =
434
        (bsize < BLOCK_SIZES) ? bsize - BLOCK_16X16 : bsize - BLOCK_8X16;
435
    assert(f_index >= 0 && f_index < BLOCK_SIZES_ALL);
436
    const BLOCK_SIZE subsize = (BLOCK_SIZE)f_index;
437
    assert(block_size_wide[bsize] == 4 * block_size_wide[subsize]);
438
    assert(block_size_high[bsize] == 4 * block_size_high[subsize]);
439
    cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[0]);
440
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4,
441
                                 dst_stride, &esq[1]);
442
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2,
443
                                 dst_stride, &esq[2]);
444
    cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
445
                                 dst_stride, &esq[3]);
446
    src += bh / 4 * src_stride;
447
    dst += bh / 4 * dst_stride;
448
449
    cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[4]);
450
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4,
451
                                 dst_stride, &esq[5]);
452
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2,
453
                                 dst_stride, &esq[6]);
454
    cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
455
                                 dst_stride, &esq[7]);
456
    src += bh / 4 * src_stride;
457
    dst += bh / 4 * dst_stride;
458
459
    cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[8]);
460
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4,
461
                                 dst_stride, &esq[9]);
462
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2,
463
                                 dst_stride, &esq[10]);
464
    cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
465
                                 dst_stride, &esq[11]);
466
    src += bh / 4 * src_stride;
467
    dst += bh / 4 * dst_stride;
468
469
    cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[12]);
470
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4,
471
                                 dst_stride, &esq[13]);
472
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2,
473
                                 dst_stride, &esq[14]);
474
    cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
475
                                 dst_stride, &esq[15]);
476
  }
477
478
  double total = (double)esq[0] + esq[1] + esq[2] + esq[3] + esq[4] + esq[5] +
479
                 esq[6] + esq[7] + esq[8] + esq[9] + esq[10] + esq[11] +
480
                 esq[12] + esq[13] + esq[14] + esq[15];
481
  if (total > 0) {
482
    const double e_recip = 1.0 / total;
483
    hordist[0] = ((double)esq[0] + esq[4] + esq[8] + esq[12]) * e_recip;
484
    hordist[1] = ((double)esq[1] + esq[5] + esq[9] + esq[13]) * e_recip;
485
    hordist[2] = ((double)esq[2] + esq[6] + esq[10] + esq[14]) * e_recip;
486
    if (need_4th) {
487
      hordist[3] = ((double)esq[3] + esq[7] + esq[11] + esq[15]) * e_recip;
488
    }
489
    verdist[0] = ((double)esq[0] + esq[1] + esq[2] + esq[3]) * e_recip;
490
    verdist[1] = ((double)esq[4] + esq[5] + esq[6] + esq[7]) * e_recip;
491
    verdist[2] = ((double)esq[8] + esq[9] + esq[10] + esq[11]) * e_recip;
492
    if (need_4th) {
493
      verdist[3] = ((double)esq[12] + esq[13] + esq[14] + esq[15]) * e_recip;
494
    }
495
  } else {
496
    hordist[0] = verdist[0] = 0.25;
497
    hordist[1] = verdist[1] = 0.25;
498
    hordist[2] = verdist[2] = 0.25;
499
    if (need_4th) {
500
      hordist[3] = verdist[3] = 0.25;
501
    }
502
  }
503
}
504
505
// Returns the mean squared value of a w x h residual block.
//   diff    pointer to the top-left residual sample
//   stride  distance, in samples, between vertically adjacent samples
//   w, h    block dimensions; both must be positive
static double get_sse_norm(const int16_t *diff, int stride, int w, int h) {
  double acc = 0.0;
  for (int y = 0; y < h; ++y) {
    const int16_t *const row = diff + y * stride;
    for (int xpos = 0; xpos < w; ++xpos) {
      const int v = row[xpos];
      acc += v * v;
    }
  }
  assert(w > 0 && h > 0);
  return acc / (w * h);
}
516
517
// Returns the mean absolute value of a w x h residual block.
//   diff    pointer to the top-left residual sample
//   stride  distance, in samples, between vertically adjacent samples
//   w, h    block dimensions; both must be positive
static double get_sad_norm(const int16_t *diff, int stride, int w, int h) {
  double acc = 0.0;
  for (int y = 0; y < h; ++y) {
    const int16_t *const row = diff + y * stride;
    for (int xpos = 0; xpos < w; ++xpos) {
      acc += abs(row[xpos]);
    }
  }
  assert(w > 0 && h > 0);
  return acc / (w * h);
}
527
528
static inline void get_2x2_normalized_sses_and_sads(
529
    const AV1_COMP *const cpi, BLOCK_SIZE tx_bsize, const uint8_t *const src,
530
    int src_stride, const uint8_t *const dst, int dst_stride,
531
    const int16_t *const src_diff, int diff_stride, double *const sse_norm_arr,
532
    double *const sad_norm_arr) {
533
  const BLOCK_SIZE tx_bsize_half =
534
      get_partition_subsize(tx_bsize, PARTITION_SPLIT);
535
  if (tx_bsize_half == BLOCK_INVALID) {  // manually calculate stats
536
    const int half_width = block_size_wide[tx_bsize] / 2;
537
    const int half_height = block_size_high[tx_bsize] / 2;
538
    for (int row = 0; row < 2; ++row) {
539
      for (int col = 0; col < 2; ++col) {
540
        const int16_t *const this_src_diff =
541
            src_diff + row * half_height * diff_stride + col * half_width;
542
        if (sse_norm_arr) {
543
          sse_norm_arr[row * 2 + col] =
544
              get_sse_norm(this_src_diff, diff_stride, half_width, half_height);
545
        }
546
        if (sad_norm_arr) {
547
          sad_norm_arr[row * 2 + col] =
548
              get_sad_norm(this_src_diff, diff_stride, half_width, half_height);
549
        }
550
      }
551
    }
552
  } else {  // use function pointers to calculate stats
553
    const int half_width = block_size_wide[tx_bsize_half];
554
    const int half_height = block_size_high[tx_bsize_half];
555
    const int num_samples_half = half_width * half_height;
556
    for (int row = 0; row < 2; ++row) {
557
      for (int col = 0; col < 2; ++col) {
558
        const uint8_t *const this_src =
559
            src + row * half_height * src_stride + col * half_width;
560
        const uint8_t *const this_dst =
561
            dst + row * half_height * dst_stride + col * half_width;
562
563
        if (sse_norm_arr) {
564
          unsigned int this_sse;
565
          cpi->ppi->fn_ptr[tx_bsize_half].vf(this_src, src_stride, this_dst,
566
                                             dst_stride, &this_sse);
567
          sse_norm_arr[row * 2 + col] = (double)this_sse / num_samples_half;
568
        }
569
570
        if (sad_norm_arr) {
571
          const unsigned int this_sad = cpi->ppi->fn_ptr[tx_bsize_half].sdf(
572
              this_src, src_stride, this_dst, dst_stride);
573
          sad_norm_arr[row * 2 + col] = (double)this_sad / num_samples_half;
574
        }
575
      }
576
    }
577
  }
578
}
579
580
#if CONFIG_COLLECT_RD_STATS == 1
581
// Returns the arithmetic mean of a w x h residual block.
//   diff    pointer to the top-left residual sample
//   stride  distance, in samples, between vertically adjacent samples
//   w, h    block dimensions; both must be positive
static double get_mean(const int16_t *diff, int stride, int w, int h) {
  double acc = 0.0;
  for (int y = 0; y < h; ++y) {
    const int16_t *const row = diff + y * stride;
    for (int xpos = 0; xpos < w; ++xpos) {
      acc += row[xpos];
    }
  }
  assert(w > 0 && h > 0);
  return acc / (w * h);
}
591
static inline void PrintTransformUnitStats(
592
    const AV1_COMP *const cpi, MACROBLOCK *x, const RD_STATS *const rd_stats,
593
    int blk_row, int blk_col, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
594
    TX_TYPE tx_type, int64_t rd) {
595
  if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return;
596
597
  // Generate small sample to restrict output size.
598
  static unsigned int seed = 21743;
599
  if (lcg_rand16(&seed) % 256 > 0) return;
600
601
  const char output_file[] = "tu_stats.txt";
602
  FILE *fout = fopen(output_file, "a");
603
  if (!fout) return;
604
605
  const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
606
  const MACROBLOCKD *const xd = &x->e_mbd;
607
  const int plane = 0;
608
  struct macroblock_plane *const p = &x->plane[plane];
609
  const struct macroblockd_plane *const pd = &xd->plane[plane];
610
  const int txw = tx_size_wide[tx_size];
611
  const int txh = tx_size_high[tx_size];
612
  const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
613
  const int q_step = p->dequant_QTX[1] >> dequant_shift;
614
  const int num_samples = txw * txh;
615
616
  const double rate_norm = (double)rd_stats->rate / num_samples;
617
  const double dist_norm = (double)rd_stats->dist / num_samples;
618
619
  fprintf(fout, "%g %g", rate_norm, dist_norm);
620
621
  const int src_stride = p->src.stride;
622
  const uint8_t *const src =
623
      &p->src.buf[(blk_row * src_stride + blk_col) << MI_SIZE_LOG2];
624
  const int dst_stride = pd->dst.stride;
625
  const uint8_t *const dst =
626
      &pd->dst.buf[(blk_row * dst_stride + blk_col) << MI_SIZE_LOG2];
627
  unsigned int sse;
628
  cpi->ppi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse);
629
  const double sse_norm = (double)sse / num_samples;
630
631
  const unsigned int sad =
632
      cpi->ppi->fn_ptr[tx_bsize].sdf(src, src_stride, dst, dst_stride);
633
  const double sad_norm = (double)sad / num_samples;
634
635
  fprintf(fout, " %g %g", sse_norm, sad_norm);
636
637
  const int diff_stride = block_size_wide[plane_bsize];
638
  const int16_t *const src_diff =
639
      &p->src_diff[(blk_row * diff_stride + blk_col) << MI_SIZE_LOG2];
640
641
  double sse_norm_arr[4], sad_norm_arr[4];
642
  get_2x2_normalized_sses_and_sads(cpi, tx_bsize, src, src_stride, dst,
643
                                   dst_stride, src_diff, diff_stride,
644
                                   sse_norm_arr, sad_norm_arr);
645
  for (int i = 0; i < 4; ++i) {
646
    fprintf(fout, " %g", sse_norm_arr[i]);
647
  }
648
  for (int i = 0; i < 4; ++i) {
649
    fprintf(fout, " %g", sad_norm_arr[i]);
650
  }
651
652
  const TX_TYPE_1D tx_type_1d_row = htx_tab[tx_type];
653
  const TX_TYPE_1D tx_type_1d_col = vtx_tab[tx_type];
654
655
  fprintf(fout, " %d %d %d %d %d", q_step, tx_size_wide[tx_size],
656
          tx_size_high[tx_size], tx_type_1d_row, tx_type_1d_col);
657
658
  int model_rate;
659
  int64_t model_dist;
660
  model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, tx_bsize, plane, sse, num_samples,
661
                                   &model_rate, &model_dist);
662
  const double model_rate_norm = (double)model_rate / num_samples;
663
  const double model_dist_norm = (double)model_dist / num_samples;
664
  fprintf(fout, " %g %g", model_rate_norm, model_dist_norm);
665
666
  const double mean = get_mean(src_diff, diff_stride, txw, txh);
667
  float hor_corr, vert_corr;
668
  av1_get_horver_correlation_full(src_diff, diff_stride, txw, txh, &hor_corr,
669
                                  &vert_corr);
670
  fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);
671
672
  double hdist[4] = { 0 }, vdist[4] = { 0 };
673
  get_energy_distribution_fine(cpi, tx_bsize, src, src_stride, dst, dst_stride,
674
                               1, hdist, vdist);
675
  fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
676
          hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);
677
678
  fprintf(fout, " %d %" PRId64, x->rdmult, rd);
679
680
  fprintf(fout, "\n");
681
  fclose(fout);
682
}
683
#endif  // CONFIG_COLLECT_RD_STATS == 1
684
685
#if CONFIG_COLLECT_RD_STATS >= 2
686
static int64_t get_sse(const AV1_COMP *cpi, const MACROBLOCK *x) {
687
  const AV1_COMMON *cm = &cpi->common;
688
  const int num_planes = av1_num_planes(cm);
689
  const MACROBLOCKD *xd = &x->e_mbd;
690
  const MB_MODE_INFO *mbmi = xd->mi[0];
691
  int64_t total_sse = 0;
692
  for (int plane = 0; plane < num_planes; ++plane) {
693
    const struct macroblock_plane *const p = &x->plane[plane];
694
    const struct macroblockd_plane *const pd = &xd->plane[plane];
695
    const BLOCK_SIZE bs =
696
        get_plane_block_size(mbmi->bsize, pd->subsampling_x, pd->subsampling_y);
697
    unsigned int sse;
698
699
    if (plane) continue;
700
701
    cpi->ppi->fn_ptr[bs].vf(p->src.buf, p->src.stride, pd->dst.buf,
702
                            pd->dst.stride, &sse);
703
    total_sse += sse;
704
  }
705
  total_sse <<= 4;
706
  return total_sse;
707
}
708
709
// Estimates the residue rate cost and the distortion of a block from its SSE
// using the model stored in tile_data->inter_mode_rd_models[bsize].
//
// Returns 1 and fills *est_residue_cost and *est_dist when the model is marked
// ready; returns 0 and leaves both outputs untouched otherwise.
static int get_est_rate_dist(const TileDataEnc *tile_data, BLOCK_SIZE bsize,
                             int64_t sse, int *est_residue_cost,
                             int64_t *est_dist) {
  const InterModeRdModel *md = &tile_data->inter_mode_rd_models[bsize];
  if (!md->ready) return 0;

  if (sse < md->dist_mean) {
    // SSE below the model's dist_mean: zero rate, distortion equals the SSE.
    *est_residue_cost = 0;
    *est_dist = sse;
    return 1;
  }

  *est_dist = (int64_t)round(md->dist_mean);
  const double est_ld = md->a * sse + md->b;
  // Clamp estimated rate cost by INT_MAX / 2.
  // TODO(angiebird@google.com): find better solution than clamping.
  const int max_cost = INT_MAX / 2;
  if (fabs(est_ld) < 1e-2) {
    *est_residue_cost = max_cost;
  } else {
    const double cost_dbl = (sse - md->dist_mean) / est_ld;
    // The clamp is applied while the value is still a double: converting an
    // out-of-range (or NaN) double to an integer type is undefined behavior.
    if (!(cost_dbl > 0)) {
      *est_residue_cost = 0;
    } else if (cost_dbl >= (double)max_cost) {
      *est_residue_cost = max_cost;
    } else {
      *est_residue_cost = (int)round(cost_dbl);
    }
  }

  if (*est_residue_cost <= 0) {
    // No rate is spent, so the distortion falls back to the raw SSE.
    *est_residue_cost = 0;
    *est_dist = sse;
  }
  return 1;
}
742
743
// Returns the mean of (src - dst) over a w x h region. Both buffers are
// passed as uint8_t pointers and read as 16-bit samples after
// CONVERT_TO_SHORTPTR(). Strides are in samples.
static double get_highbd_diff_mean(const uint8_t *src8, int src_stride,
                                   const uint8_t *dst8, int dst_stride, int w,
                                   int h) {
  const uint16_t *const src16 = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *const dst16 = CONVERT_TO_SHORTPTR(dst8);
  double total = 0.0;

  for (int row = 0; row < h; ++row) {
    const uint16_t *const src_row = src16 + row * src_stride;
    const uint16_t *const dst_row = dst16 + row * dst_stride;
    for (int col = 0; col < w; ++col) {
      total += src_row[col] - dst_row[col];
    }
  }

  assert(w > 0 && h > 0);
  return total / (w * h);
}
758
759
// Returns the mean of (src - dst) over a w x h region of 8-bit samples.
// Strides are in samples.
static double get_diff_mean(const uint8_t *src, int src_stride,
                            const uint8_t *dst, int dst_stride, int w, int h) {
  double total = 0.0;

  for (int row = 0; row < h; ++row) {
    const uint8_t *const src_row = src + row * src_stride;
    const uint8_t *const dst_row = dst + row * dst_stride;
    for (int col = 0; col < w; ++col) {
      total += src_row[col] - dst_row[col];
    }
  }

  assert(w > 0 && h > 0);
  return total / (w * h);
}
771
772
static inline void PrintPredictionUnitStats(const AV1_COMP *const cpi,
773
                                            const TileDataEnc *tile_data,
774
                                            MACROBLOCK *x,
775
                                            const RD_STATS *const rd_stats,
776
                                            BLOCK_SIZE plane_bsize) {
777
  if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return;
778
779
  if (cpi->sf.inter_sf.inter_mode_rd_model_estimation == 1 &&
780
      (tile_data == NULL ||
781
       !tile_data->inter_mode_rd_models[plane_bsize].ready))
782
    return;
783
  (void)tile_data;
784
  // Generate small sample to restrict output size.
785
  static unsigned int seed = 95014;
786
787
  if ((lcg_rand16(&seed) % (1 << (14 - num_pels_log2_lookup[plane_bsize]))) !=
788
      1)
789
    return;
790
791
  const char output_file[] = "pu_stats.txt";
792
  FILE *fout = fopen(output_file, "a");
793
  if (!fout) return;
794
795
  MACROBLOCKD *const xd = &x->e_mbd;
796
  const int plane = 0;
797
  struct macroblock_plane *const p = &x->plane[plane];
798
  struct macroblockd_plane *pd = &xd->plane[plane];
799
  const int diff_stride = block_size_wide[plane_bsize];
800
  int bw, bh;
801
  get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL, &bw,
802
                     &bh);
803
  const int num_samples = bw * bh;
804
  const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
805
  const int q_step = p->dequant_QTX[1] >> dequant_shift;
806
  const int shift = (xd->bd - 8);
807
808
  const double rate_norm = (double)rd_stats->rate / num_samples;
809
  const double dist_norm = (double)rd_stats->dist / num_samples;
810
  const double rdcost_norm =
811
      (double)RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) / num_samples;
812
813
  fprintf(fout, "%g %g %g", rate_norm, dist_norm, rdcost_norm);
814
815
  const int src_stride = p->src.stride;
816
  const uint8_t *const src = p->src.buf;
817
  const int dst_stride = pd->dst.stride;
818
  const uint8_t *const dst = pd->dst.buf;
819
  const int16_t *const src_diff = p->src_diff;
820
821
  int64_t sse = calculate_sse(xd, p, pd, bw, bh);
822
  const double sse_norm = (double)sse / num_samples;
823
824
  const unsigned int sad =
825
      cpi->ppi->fn_ptr[plane_bsize].sdf(src, src_stride, dst, dst_stride);
826
  const double sad_norm =
827
      (double)sad / (1 << num_pels_log2_lookup[plane_bsize]);
828
829
  fprintf(fout, " %g %g", sse_norm, sad_norm);
830
831
  double sse_norm_arr[4], sad_norm_arr[4];
832
  get_2x2_normalized_sses_and_sads(cpi, plane_bsize, src, src_stride, dst,
833
                                   dst_stride, src_diff, diff_stride,
834
                                   sse_norm_arr, sad_norm_arr);
835
  if (shift) {
836
    for (int k = 0; k < 4; ++k) sse_norm_arr[k] /= (1 << (2 * shift));
837
    for (int k = 0; k < 4; ++k) sad_norm_arr[k] /= (1 << shift);
838
  }
839
  for (int i = 0; i < 4; ++i) {
840
    fprintf(fout, " %g", sse_norm_arr[i]);
841
  }
842
  for (int i = 0; i < 4; ++i) {
843
    fprintf(fout, " %g", sad_norm_arr[i]);
844
  }
845
846
  fprintf(fout, " %d %d %d %d", q_step, x->rdmult, bw, bh);
847
848
  int model_rate;
849
  int64_t model_dist;
850
  model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, plane_bsize, plane, sse, num_samples,
851
                                   &model_rate, &model_dist);
852
  const double model_rdcost_norm =
853
      (double)RDCOST(x->rdmult, model_rate, model_dist) / num_samples;
854
  const double model_rate_norm = (double)model_rate / num_samples;
855
  const double model_dist_norm = (double)model_dist / num_samples;
856
  fprintf(fout, " %g %g %g", model_rate_norm, model_dist_norm,
857
          model_rdcost_norm);
858
859
  double mean;
860
  if (is_cur_buf_hbd(xd)) {
861
    mean = get_highbd_diff_mean(p->src.buf, p->src.stride, pd->dst.buf,
862
                                pd->dst.stride, bw, bh);
863
  } else {
864
    mean = get_diff_mean(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
865
                         bw, bh);
866
  }
867
  mean /= (1 << shift);
868
  float hor_corr, vert_corr;
869
  av1_get_horver_correlation_full(src_diff, diff_stride, bw, bh, &hor_corr,
870
                                  &vert_corr);
871
  fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);
872
873
  double hdist[4] = { 0 }, vdist[4] = { 0 };
874
  get_energy_distribution_fine(cpi, plane_bsize, src, src_stride, dst,
875
                               dst_stride, 1, hdist, vdist);
876
  fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
877
          hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);
878
879
  if (cpi->sf.inter_sf.inter_mode_rd_model_estimation == 1) {
880
    assert(tile_data->inter_mode_rd_models[plane_bsize].ready);
881
    const int64_t overall_sse = get_sse(cpi, x);
882
    int est_residue_cost = 0;
883
    int64_t est_dist = 0;
884
    get_est_rate_dist(tile_data, plane_bsize, overall_sse, &est_residue_cost,
885
                      &est_dist);
886
    const double est_residue_cost_norm = (double)est_residue_cost / num_samples;
887
    const double est_dist_norm = (double)est_dist / num_samples;
888
    const double est_rdcost_norm =
889
        (double)RDCOST(x->rdmult, est_residue_cost, est_dist) / num_samples;
890
    fprintf(fout, " %g %g %g", est_residue_cost_norm, est_dist_norm,
891
            est_rdcost_norm);
892
  }
893
894
  fprintf(fout, "\n");
895
  fclose(fout);
896
}
897
#endif  // CONFIG_COLLECT_RD_STATS >= 2
898
#endif  // CONFIG_COLLECT_RD_STATS
899
900
// Inverse-transforms the dequantized coefficients of one transform block and
// passes the block's position inside the plane's dst buffer to
// av1_inverse_transform_block(). Does nothing when eob is 0.
static inline void inverse_transform_block_facade(MACROBLOCK *const x,
                                                  int plane, int block,
                                                  int blk_row, int blk_col,
                                                  int eob, int reduced_tx_set) {
  if (eob == 0) return;

  MACROBLOCKD *const xd = &x->e_mbd;
  struct macroblockd_plane *const pd = &xd->plane[plane];

  // Block position inside the destination buffer, in pixels.
  const int dst_stride = pd->dst.stride;
  const int dst_offset = (blk_row * dst_stride + blk_col) << MI_SIZE_LOG2;
  uint8_t *const dst = pd->dst.buf + dst_offset;

  const TX_SIZE tx_size = av1_get_tx_size(plane, xd);
  const TX_TYPE tx_type = av1_get_tx_type(xd, get_plane_type(plane), blk_row,
                                          blk_col, tx_size, reduced_tx_set);
  tran_low_t *const dqcoeff = x->plane[plane].dqcoeff + BLOCK_OFFSET(block);

  av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, dst,
                              dst_stride, eob, reduced_tx_set);
}
919
920
static inline void recon_intra(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
921
                               int block, int blk_row, int blk_col,
922
                               BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
923
                               const TXB_CTX *const txb_ctx, int skip_trellis,
924
                               TX_TYPE best_tx_type, int do_quant,
925
0
                               int *rate_cost, uint16_t best_eob) {
926
0
  const AV1_COMMON *cm = &cpi->common;
927
0
  MACROBLOCKD *xd = &x->e_mbd;
928
0
  MB_MODE_INFO *mbmi = xd->mi[0];
929
0
  const int is_inter = is_inter_block(mbmi);
930
0
  if (!is_inter && best_eob &&
931
0
      (blk_row + tx_size_high_unit[tx_size] < mi_size_high[plane_bsize] ||
932
0
       blk_col + tx_size_wide_unit[tx_size] < mi_size_wide[plane_bsize])) {
933
    // if the quantized coefficients are stored in the dqcoeff buffer, we don't
934
    // need to do transform and quantization again.
935
0
    if (do_quant) {
936
0
      TxfmParam txfm_param_intra;
937
0
      QUANT_PARAM quant_param_intra;
938
0
      av1_setup_xform(cm, x, tx_size, best_tx_type, &txfm_param_intra);
939
0
      av1_setup_quant(tx_size, !skip_trellis,
940
0
                      skip_trellis
941
0
                          ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B
942
0
                                                    : AV1_XFORM_QUANT_FP)
943
0
                          : AV1_XFORM_QUANT_FP,
944
0
                      cpi->oxcf.q_cfg.quant_b_adapt, &quant_param_intra);
945
0
      av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, best_tx_type,
946
0
                        &quant_param_intra);
947
0
      av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize,
948
0
                      &txfm_param_intra, &quant_param_intra);
949
0
      if (quant_param_intra.use_optimize_b) {
950
0
        av1_optimize_b(cpi, x, plane, block, tx_size, best_tx_type, txb_ctx,
951
0
                       rate_cost);
952
0
      }
953
0
    }
954
955
0
    inverse_transform_block_facade(x, plane, block, blk_row, blk_col,
956
0
                                   x->plane[plane].eobs[block],
957
0
                                   cm->features.reduced_tx_set_used);
958
959
    // This may happen because of hash collision. The eob stored in the hash
960
    // table is non-zero, but the real eob is zero. We need to make sure tx_type
961
    // is DCT_DCT in this case.
962
0
    if (plane == 0 && x->plane[plane].eobs[block] == 0 &&
963
0
        best_tx_type != DCT_DCT) {
964
0
      update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
965
0
    }
966
0
  }
967
0
}
968
969
static unsigned pixel_dist_visible_only(
970
    const AV1_COMP *const cpi, const MACROBLOCK *x, const uint8_t *src,
971
    const int src_stride, const uint8_t *dst, const int dst_stride,
972
    const BLOCK_SIZE tx_bsize, int txb_rows, int txb_cols, int visible_rows,
973
0
    int visible_cols) {
974
0
  unsigned sse;
975
976
0
  if (txb_rows == visible_rows && txb_cols == visible_cols) {
977
0
    cpi->ppi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse);
978
0
    return sse;
979
0
  }
980
981
0
#if CONFIG_AV1_HIGHBITDEPTH
982
0
  const MACROBLOCKD *xd = &x->e_mbd;
983
0
  if (is_cur_buf_hbd(xd)) {
984
0
    uint64_t sse64 = aom_highbd_sse_odd_size(src, src_stride, dst, dst_stride,
985
0
                                             visible_cols, visible_rows);
986
0
    return (unsigned int)ROUND_POWER_OF_TWO(sse64, (xd->bd - 8) * 2);
987
0
  }
988
#else
989
  (void)x;
990
#endif
991
0
  sse = aom_sse_odd_size(src, src_stride, dst, dst_stride, visible_cols,
992
0
                         visible_rows);
993
0
  return sse;
994
0
}
995
996
// Compute the pixel domain distortion from src and dst on all visible 4x4s in
997
// the
998
// transform block.
999
static unsigned pixel_dist(const AV1_COMP *const cpi, const MACROBLOCK *x,
1000
                           int plane, const uint8_t *src, const int src_stride,
1001
                           const uint8_t *dst, const int dst_stride,
1002
                           int blk_row, int blk_col,
1003
                           const BLOCK_SIZE plane_bsize,
1004
0
                           const BLOCK_SIZE tx_bsize) {
1005
0
  int txb_rows, txb_cols, visible_rows, visible_cols;
1006
0
  const MACROBLOCKD *xd = &x->e_mbd;
1007
1008
0
  get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize,
1009
0
                     &txb_cols, &txb_rows, &visible_cols, &visible_rows);
1010
0
  assert(visible_rows > 0);
1011
0
  assert(visible_cols > 0);
1012
1013
0
  unsigned sse = pixel_dist_visible_only(cpi, x, src, src_stride, dst,
1014
0
                                         dst_stride, tx_bsize, txb_rows,
1015
0
                                         txb_cols, visible_rows, visible_cols);
1016
1017
0
  return sse;
1018
0
}
1019
1020
static inline int64_t dist_block_px_domain(const AV1_COMP *cpi, MACROBLOCK *x,
1021
                                           int plane, BLOCK_SIZE plane_bsize,
1022
                                           int block, int blk_row, int blk_col,
1023
0
                                           TX_SIZE tx_size) {
1024
0
  MACROBLOCKD *const xd = &x->e_mbd;
1025
0
  const struct macroblock_plane *const p = &x->plane[plane];
1026
0
  const uint16_t eob = p->eobs[block];
1027
0
  const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
1028
0
  const int bsw = block_size_wide[tx_bsize];
1029
0
  const int bsh = block_size_high[tx_bsize];
1030
0
  const int src_stride = x->plane[plane].src.stride;
1031
0
  const int dst_stride = xd->plane[plane].dst.stride;
1032
  // Scale the transform block index to pixel unit.
1033
0
  const int src_idx = (blk_row * src_stride + blk_col) << MI_SIZE_LOG2;
1034
0
  const int dst_idx = (blk_row * dst_stride + blk_col) << MI_SIZE_LOG2;
1035
0
  const uint8_t *src = &x->plane[plane].src.buf[src_idx];
1036
0
  const uint8_t *dst = &xd->plane[plane].dst.buf[dst_idx];
1037
0
  const tran_low_t *dqcoeff = p->dqcoeff + BLOCK_OFFSET(block);
1038
1039
0
  assert(cpi != NULL);
1040
0
  assert(tx_size_wide_log2[0] == tx_size_high_log2[0]);
1041
1042
0
  uint8_t *recon;
1043
0
  DECLARE_ALIGNED(16, uint16_t, recon16[MAX_TX_SQUARE]);
1044
1045
0
#if CONFIG_AV1_HIGHBITDEPTH
1046
0
  if (is_cur_buf_hbd(xd)) {
1047
0
    recon = CONVERT_TO_BYTEPTR(recon16);
1048
0
    aom_highbd_convolve_copy(CONVERT_TO_SHORTPTR(dst), dst_stride,
1049
0
                             CONVERT_TO_SHORTPTR(recon), MAX_TX_SIZE, bsw, bsh);
1050
0
  } else {
1051
0
    recon = (uint8_t *)recon16;
1052
0
    aom_convolve_copy(dst, dst_stride, recon, MAX_TX_SIZE, bsw, bsh);
1053
0
  }
1054
#else
1055
  recon = (uint8_t *)recon16;
1056
  aom_convolve_copy(dst, dst_stride, recon, MAX_TX_SIZE, bsw, bsh);
1057
#endif
1058
1059
0
  const PLANE_TYPE plane_type = get_plane_type(plane);
1060
0
  TX_TYPE tx_type = av1_get_tx_type(xd, plane_type, blk_row, blk_col, tx_size,
1061
0
                                    cpi->common.features.reduced_tx_set_used);
1062
0
  av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, recon,
1063
0
                              MAX_TX_SIZE, eob,
1064
0
                              cpi->common.features.reduced_tx_set_used);
1065
1066
0
  return 16 * pixel_dist(cpi, x, plane, src, src_stride, recon, MAX_TX_SIZE,
1067
0
                         blk_row, blk_col, plane_bsize, tx_bsize);
1068
0
}
1069
1070
// pruning thresholds for prune_txk_type and prune_txk_type_separ
1071
// Five thresholds expressed in thousandths (see the trailing note).
// NOTE(review): presumably one entry per pruning level, passed as the
// prune_factor argument of the pruning functions -- confirm at the call sites.
static const int prune_factors[5] = { 200, 200, 120, 80, 40 };  // scale 1000
1072
// Five multipliers expressed in hundredths (see the trailing note).
// NOTE(review): not referenced in the nearby pruning helpers; presumably
// indexed like prune_factors -- confirm where it is consumed.
static const int mul_factors[5] = { 80, 80, 70, 50, 30 };       // scale 100
1073
1074
// Sorts the first len R-D costs in ascending order and applies the same
// permutation to txk[], so that txk[i] still belongs to rds[i]. Entries with
// equal cost keep their relative order (stable insertion sort).
static inline void sort_rd(int64_t rds[], int txk[], int len) {
  for (int cur = 1; cur < len; ++cur) {
    const int64_t key_rd = rds[cur];
    const int key_txk = txk[cur];
    int slot = cur;

    // Shift strictly larger entries one position to the right.
    while (slot > 0 && rds[slot - 1] > key_rd) {
      rds[slot] = rds[slot - 1];
      txk[slot] = txk[slot - 1];
      --slot;
    }

    rds[slot] = key_rd;
    txk[slot] = key_txk;
  }
}
1099
1100
static inline int64_t av1_block_error_qm(
1101
    const tran_low_t *coeff, const tran_low_t *dqcoeff, intptr_t block_size,
1102
0
    const qm_val_t *qmatrix, const int16_t *scan, int64_t *ssz, int bd) {
1103
0
  int i;
1104
0
  int64_t error = 0, sqcoeff = 0;
1105
0
  int shift = 2 * (bd - 8);
1106
0
  int rounding = (1 << shift) >> 1;
1107
1108
0
  for (i = 0; i < block_size; i++) {
1109
0
    int64_t weight = qmatrix[scan[i]];
1110
0
    int64_t dd = coeff[i] - dqcoeff[i];
1111
0
    dd *= weight;
1112
0
    int64_t cc = coeff[i];
1113
0
    cc *= weight;
1114
    // The ranges of coeff and dqcoeff are
1115
    //  bd8 : 18 bits (including sign)
1116
    //  bd10: 20 bits (including sign)
1117
    //  bd12: 22 bits (including sign)
1118
    // As AOM_QM_BITS is 5, the intermediate quantities in the calculation
1119
    // below should fit in 54 bits, thus no overflow should happen.
1120
0
    error += (dd * dd + (1 << (2 * AOM_QM_BITS - 1))) >> (2 * AOM_QM_BITS);
1121
0
    sqcoeff += (cc * cc + (1 << (2 * AOM_QM_BITS - 1))) >> (2 * AOM_QM_BITS);
1122
0
  }
1123
1124
0
  error = (error + rounding) >> shift;
1125
0
  sqcoeff = (sqcoeff + rounding) >> shift;
1126
1127
0
  *ssz = sqcoeff;
1128
0
  return error;
1129
0
}
1130
1131
static inline void dist_block_tx_domain(MACROBLOCK *x, int plane, int block,
1132
                                        TX_SIZE tx_size,
1133
                                        const qm_val_t *qmatrix,
1134
                                        const int16_t *scan, int64_t *out_dist,
1135
0
                                        int64_t *out_sse) {
1136
0
  const struct macroblock_plane *const p = &x->plane[plane];
1137
  // Transform domain distortion computation is more efficient as it does
1138
  // not involve an inverse transform, but it is less accurate.
1139
0
  const int buffer_length = av1_get_max_eob(tx_size);
1140
0
  int64_t this_sse;
1141
  // TX-domain results need to shift down to Q2/D10 to match pixel
1142
  // domain distortion values which are in Q2^2
1143
0
  int shift = (MAX_TX_SCALE - av1_get_tx_scale(tx_size)) * 2;
1144
0
  const int block_offset = BLOCK_OFFSET(block);
1145
0
  tran_low_t *const coeff = p->coeff + block_offset;
1146
0
  tran_low_t *const dqcoeff = p->dqcoeff + block_offset;
1147
0
#if CONFIG_AV1_HIGHBITDEPTH
1148
0
  MACROBLOCKD *const xd = &x->e_mbd;
1149
0
  if (is_cur_buf_hbd(xd)) {
1150
0
    if (qmatrix == NULL || !x->txfm_search_params.use_qm_dist_metric) {
1151
0
      *out_dist = av1_highbd_block_error(coeff, dqcoeff, buffer_length,
1152
0
                                         &this_sse, xd->bd);
1153
0
    } else {
1154
0
      *out_dist = av1_block_error_qm(coeff, dqcoeff, buffer_length, qmatrix,
1155
0
                                     scan, &this_sse, xd->bd);
1156
0
    }
1157
0
  } else {
1158
0
#endif
1159
0
    if (qmatrix == NULL || !x->txfm_search_params.use_qm_dist_metric) {
1160
0
      *out_dist = av1_block_error(coeff, dqcoeff, buffer_length, &this_sse);
1161
0
    } else {
1162
0
      *out_dist = av1_block_error_qm(coeff, dqcoeff, buffer_length, qmatrix,
1163
0
                                     scan, &this_sse, 8);
1164
0
    }
1165
0
#if CONFIG_AV1_HIGHBITDEPTH
1166
0
  }
1167
0
#endif
1168
1169
0
  *out_dist = RIGHT_SIGNED_SHIFT(*out_dist, shift);
1170
0
  *out_sse = RIGHT_SIGNED_SHIFT(this_sse, shift);
1171
0
}
1172
1173
static uint16_t prune_txk_type_separ(
1174
    const AV1_COMP *cpi, MACROBLOCK *x, int plane, int block, TX_SIZE tx_size,
1175
    int blk_row, int blk_col, BLOCK_SIZE plane_bsize, int *txk_map,
1176
    int16_t allowed_tx_mask, int prune_factor, const TXB_CTX *const txb_ctx,
1177
0
    int reduced_tx_set_used, int64_t ref_best_rd, int num_sel) {
1178
0
  const AV1_COMMON *cm = &cpi->common;
1179
0
  MACROBLOCKD *xd = &x->e_mbd;
1180
1181
0
  int idx;
1182
1183
0
  int64_t rds_v[4];
1184
0
  int64_t rds_h[4];
1185
0
  int idx_v[4] = { 0, 1, 2, 3 };
1186
0
  int idx_h[4] = { 0, 1, 2, 3 };
1187
0
  int skip_v[4] = { 0 };
1188
0
  int skip_h[4] = { 0 };
1189
0
  const int idx_map[16] = {
1190
0
    DCT_DCT,      DCT_ADST,      DCT_FLIPADST,      V_DCT,
1191
0
    ADST_DCT,     ADST_ADST,     ADST_FLIPADST,     V_ADST,
1192
0
    FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
1193
0
    H_DCT,        H_ADST,        H_FLIPADST,        IDTX
1194
0
  };
1195
1196
0
  const int sel_pattern_v[16] = {
1197
0
    0, 0, 1, 1, 0, 2, 1, 2, 2, 0, 3, 1, 3, 2, 3, 3
1198
0
  };
1199
0
  const int sel_pattern_h[16] = {
1200
0
    0, 1, 0, 1, 2, 0, 2, 1, 2, 3, 0, 3, 1, 3, 2, 3
1201
0
  };
1202
1203
0
  QUANT_PARAM quant_param;
1204
0
  TxfmParam txfm_param;
1205
0
  av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
1206
0
  av1_setup_quant(tx_size, 1, AV1_XFORM_QUANT_B, cpi->oxcf.q_cfg.quant_b_adapt,
1207
0
                  &quant_param);
1208
0
  int tx_type;
1209
  // to ensure we can try ones even outside of ext_tx_set of current block
1210
  // this function should only be called for size < 16
1211
0
  assert(txsize_sqr_up_map[tx_size] <= TX_16X16);
1212
0
  txfm_param.tx_set_type = EXT_TX_SET_ALL16;
1213
1214
0
  int rate_cost = 0;
1215
0
  int64_t dist = 0, sse = 0;
1216
  // evaluate horizontal with vertical DCT
1217
0
  for (idx = 0; idx < 4; ++idx) {
1218
0
    tx_type = idx_map[idx];
1219
0
    txfm_param.tx_type = tx_type;
1220
1221
0
    av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
1222
0
                      &quant_param);
1223
1224
0
    av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
1225
0
                    &quant_param);
1226
1227
0
    const SCAN_ORDER *const scan_order =
1228
0
        get_scan(txfm_param.tx_size, txfm_param.tx_type);
1229
0
    dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
1230
0
                         scan_order->scan, &dist, &sse);
1231
1232
0
    rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
1233
0
                                              txb_ctx, reduced_tx_set_used, 0);
1234
1235
0
    rds_h[idx] = RDCOST(x->rdmult, rate_cost, dist);
1236
1237
0
    if ((rds_h[idx] - (rds_h[idx] >> 2)) > ref_best_rd) {
1238
0
      skip_h[idx] = 1;
1239
0
    }
1240
0
  }
1241
0
  sort_rd(rds_h, idx_h, 4);
1242
0
  for (idx = 1; idx < 4; idx++) {
1243
0
    if (rds_h[idx] > rds_h[0] * 1.2) skip_h[idx_h[idx]] = 1;
1244
0
  }
1245
1246
0
  if (skip_h[idx_h[0]]) return (uint16_t)0xFFFF;
1247
1248
  // evaluate vertical with the best horizontal chosen
1249
0
  rds_v[0] = rds_h[0];
1250
0
  int start_v = 1, end_v = 4;
1251
0
  const int *idx_map_v = idx_map + idx_h[0];
1252
1253
0
  for (idx = start_v; idx < end_v; ++idx) {
1254
0
    tx_type = idx_map_v[idx_v[idx] * 4];
1255
0
    txfm_param.tx_type = tx_type;
1256
1257
0
    av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
1258
0
                      &quant_param);
1259
1260
0
    av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
1261
0
                    &quant_param);
1262
1263
0
    const SCAN_ORDER *const scan_order =
1264
0
        get_scan(txfm_param.tx_size, txfm_param.tx_type);
1265
0
    dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
1266
0
                         scan_order->scan, &dist, &sse);
1267
1268
0
    rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
1269
0
                                              txb_ctx, reduced_tx_set_used, 0);
1270
1271
0
    rds_v[idx] = RDCOST(x->rdmult, rate_cost, dist);
1272
1273
0
    if ((rds_v[idx] - (rds_v[idx] >> 2)) > ref_best_rd) {
1274
0
      skip_v[idx] = 1;
1275
0
    }
1276
0
  }
1277
0
  sort_rd(rds_v, idx_v, 4);
1278
0
  for (idx = 1; idx < 4; idx++) {
1279
0
    if (rds_v[idx] > rds_v[0] * 1.2) skip_v[idx_v[idx]] = 1;
1280
0
  }
1281
1282
  // combine rd_h and rd_v to prune tx candidates
1283
0
  int i_v, i_h;
1284
0
  int64_t rds[16];
1285
0
  int num_cand = 0, last = TX_TYPES - 1;
1286
1287
0
  for (int i = 0; i < 16; i++) {
1288
0
    i_v = sel_pattern_v[i];
1289
0
    i_h = sel_pattern_h[i];
1290
0
    tx_type = idx_map[idx_v[i_v] * 4 + idx_h[i_h]];
1291
0
    if (!(allowed_tx_mask & (1 << tx_type)) || skip_h[idx_h[i_h]] ||
1292
0
        skip_v[idx_v[i_v]]) {
1293
0
      txk_map[last] = tx_type;
1294
0
      last--;
1295
0
    } else {
1296
0
      txk_map[num_cand] = tx_type;
1297
0
      rds[num_cand] = rds_v[i_v] + rds_h[i_h];
1298
0
      if (rds[num_cand] == 0) rds[num_cand] = 1;
1299
0
      num_cand++;
1300
0
    }
1301
0
  }
1302
0
  sort_rd(rds, txk_map, num_cand);
1303
1304
0
  uint16_t prune = (uint16_t)(~(1 << txk_map[0]));
1305
0
  num_sel = AOMMIN(num_sel, num_cand);
1306
1307
0
  for (int i = 1; i < num_sel; i++) {
1308
0
    int64_t factor = 1800 * (rds[i] - rds[0]) / (rds[0]);
1309
0
    if (factor < (int64_t)prune_factor)
1310
0
      prune &= ~(1 << txk_map[i]);
1311
0
    else
1312
0
      break;
1313
0
  }
1314
0
  return prune;
1315
0
}
1316
1317
// Prunes transform types by estimating the R-D cost of every type allowed by
// allowed_tx_mask (transform + quantization, Laplacian rate estimate and
// transform-domain distortion).
//
// txk_map receives the evaluated types sorted by ascending cost, followed
// from the back by the types that were not allowed. Returns a 16-bit mask with
// one bit per TX_TYPE: bits are cleared for the best type and for each
// following type whose relative cost gap, in thousandths of the best cost, is
// below prune_factor. Returns 0xFFFF when no type is allowed.
static uint16_t prune_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
                               int block, TX_SIZE tx_size, int blk_row,
                               int blk_col, BLOCK_SIZE plane_bsize,
                               int *txk_map, uint16_t allowed_tx_mask,
                               int prune_factor, const TXB_CTX *const txb_ctx,
                               int reduced_tx_set_used) {
  const AV1_COMMON *const cm = &cpi->common;
  MACROBLOCKD *const xd = &x->e_mbd;
  int64_t rds[TX_TYPES];
  int kept = 0;                     // evaluated types fill txk_map front-first
  int rejected_pos = TX_TYPES - 1;  // disallowed types fill it back-first

  TxfmParam txfm_param;
  QUANT_PARAM quant_param;
  av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
  av1_setup_quant(tx_size, 1, AV1_XFORM_QUANT_B, cpi->oxcf.q_cfg.quant_b_adapt,
                  &quant_param);

  for (int tx_type = 0; tx_type < TX_TYPES; ++tx_type) {
    if ((allowed_tx_mask & (1 << tx_type)) == 0) {
      txk_map[rejected_pos] = tx_type;
      --rejected_pos;
      continue;
    }
    txfm_param.tx_type = tx_type;

    av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
                      &quant_param);

    // do txfm and quantization
    av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
                    &quant_param);
    // estimate rate cost
    const int rate_cost = av1_cost_coeffs_txb_laplacian(
        x, plane, block, tx_size, tx_type, txb_ctx, reduced_tx_set_used, 0);
    // tx domain dist
    int64_t dist = 0, sse = 0;
    const SCAN_ORDER *const scan_order =
        get_scan(txfm_param.tx_size, txfm_param.tx_type);
    dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
                         scan_order->scan, &dist, &sse);

    int64_t rd = RDCOST(x->rdmult, rate_cost, dist);
    // A zero cost is bumped to 1 because the best cost is a divisor below.
    if (rd == 0) rd = 1;
    txk_map[kept] = tx_type;
    rds[kept] = rd;
    ++kept;
  }

  if (kept == 0) return (uint16_t)0xFFFF;

  sort_rd(rds, txk_map, kept);
  uint16_t prune = (uint16_t)(~(1 << txk_map[0]));

  // 0 < prune_factor <= 1000 controls aggressiveness
  const int64_t best_rd = rds[0];
  for (int i = 1; i < kept; ++i) {
    const int64_t factor = 1000 * (rds[i] - best_rd) / best_rd;
    if (factor >= (int64_t)prune_factor) break;
    prune &= ~(1 << txk_map[i]);
  }
  return prune;
}
1386
1387
// Per-transform-size threshold lists for the model based 2D TX type pruning.
static const float prune_2d_thresh_4x4[] = {
  0.00549f, 0.01306f, 0.02039f, 0.02747f, 0.03406f, 0.04065f, 0.04724f,
  0.05383f, 0.06067f, 0.06799f, 0.07605f, 0.08533f, 0.09778f, 0.11780f
};
static const float prune_2d_thresh_8x8[] = {
  0.00037f, 0.00183f, 0.00525f, 0.01038f, 0.01697f, 0.02502f, 0.03381f,
  0.04333f, 0.05286f, 0.06287f, 0.07434f, 0.08850f, 0.10803f, 0.14124f
};
static const float prune_2d_thresh_16x16[] = {
  0.01404f, 0.02000f, 0.04211f, 0.05164f, 0.05798f,
  0.06335f, 0.06897f, 0.07629f, 0.08875f, 0.11169f
};
static const float prune_2d_thresh_4x8[] = {
  0.00183f, 0.00745f, 0.01428f, 0.02185f, 0.02966f, 0.03723f, 0.04456f,
  0.05188f, 0.05920f, 0.06702f, 0.07605f, 0.08704f, 0.10168f, 0.12585f
};
static const float prune_2d_thresh_8x4[] = {
  0.00085f, 0.00476f, 0.01135f, 0.01892f, 0.02698f, 0.03528f, 0.04358f,
  0.05164f, 0.05994f, 0.06848f, 0.07849f, 0.09021f, 0.10583f, 0.13123f
};
static const float prune_2d_thresh_8x16[] = {
  0.00037f, 0.00232f, 0.00671f, 0.01257f, 0.01965f, 0.02722f, 0.03552f,
  0.04382f, 0.05237f, 0.06189f, 0.07336f, 0.08728f, 0.10730f, 0.14221f
};
static const float prune_2d_thresh_16x8[] = {
  0.00061f, 0.00330f, 0.00818f, 0.01453f, 0.02185f, 0.02966f, 0.03772f,
  0.04578f, 0.05383f, 0.06262f, 0.07288f, 0.08582f, 0.10339f, 0.13464f
};
static const float prune_2d_thresh_4x16[] = {
  0.00232f, 0.00671f, 0.01257f, 0.01941f, 0.02673f, 0.03430f, 0.04211f,
  0.04968f, 0.05750f, 0.06580f, 0.07507f, 0.08655f, 0.10242f, 0.12878f
};
static const float prune_2d_thresh_16x4[] = {
  0.00110f, 0.00525f, 0.01208f, 0.01990f, 0.02795f, 0.03601f, 0.04358f,
  0.05115f, 0.05896f, 0.06702f, 0.07629f, 0.08752f, 0.10217f, 0.12610f
};

// These thresholds were calibrated to provide a certain number of TX types
// pruned by the model on average, i.e. selecting a threshold with index i
// will lead to pruning i+1 TX types on average.
// One entry per transform size; NULL marks sizes without a threshold list.
static const float *prune_2D_adaptive_thresholds[] = {
  prune_2d_thresh_4x4,    // TX_4X4
  prune_2d_thresh_8x8,    // TX_8X8
  prune_2d_thresh_16x16,  // TX_16X16
  NULL,                   // TX_32X32
  NULL,                   // TX_64X64
  prune_2d_thresh_4x8,    // TX_4X8
  prune_2d_thresh_8x4,    // TX_8X4
  prune_2d_thresh_8x16,   // TX_8X16
  prune_2d_thresh_16x8,   // TX_16X8
  NULL,                   // TX_16X32
  NULL,                   // TX_32X16
  NULL,                   // TX_32X64
  NULL,                   // TX_64X32
  prune_2d_thresh_4x16,   // TX_4X16
  prune_2d_thresh_16x4,   // TX_16X4
  NULL,                   // TX_8X32
  NULL,                   // TX_32X8
  NULL,                   // TX_16X64
  NULL,                   // TX_64X16
};
1447
1448
static inline float get_adaptive_thresholds(
1449
    TX_SIZE tx_size, TxSetType tx_set_type,
1450
0
    TX_TYPE_PRUNE_MODE prune_2d_txfm_mode) {
1451
0
  const int prune_aggr_table[5][2] = {
1452
0
    { 4, 1 }, { 6, 3 }, { 9, 6 }, { 9, 6 }, { 12, 9 }
1453
0
  };
1454
0
  int pruning_aggressiveness = 0;
1455
0
  if (tx_set_type == EXT_TX_SET_ALL16)
1456
0
    pruning_aggressiveness =
1457
0
        prune_aggr_table[prune_2d_txfm_mode - TX_TYPE_PRUNE_1][0];
1458
0
  else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
1459
0
    pruning_aggressiveness =
1460
0
        prune_aggr_table[prune_2d_txfm_mode - TX_TYPE_PRUNE_1][1];
1461
1462
0
  return prune_2D_adaptive_thresholds[tx_size][pruning_aggressiveness];
1463
0
}
1464
1465
// Computes normalized 1D energy profiles of a residual block.
//
// The squared residual is first accumulated into a grid `esq` of
// esq_w x esq_h cells, where a dimension larger than 8 is downscaled by 2
// (pairs of samples share a cell). The outputs are the column sums
// (hordist, esq_w - 1 values) and row sums (verdist, esq_h - 1 values) of
// that grid divided by the total energy; the last column/row is not written.
// When the block has zero energy, a uniform distribution (1 / esq_w and
// 1 / esq_h) is written instead.
//
// diff/stride: residual samples and the distance between rows.
// bw/bh:       block width and height in samples.
static inline void get_energy_distribution_finer(const int16_t *diff,
                                                 int stride, int bw, int bh,
                                                 float *hordist,
                                                 float *verdist) {
  unsigned int esq[256];
  const int w_shift = bw <= 8 ? 0 : 1;
  const int h_shift = bh <= 8 ? 0 : 1;
  const int esq_w = bw >> w_shift;
  const int esq_h = bh >> h_shift;
  const int esq_sz = esq_w * esq_h;
  memset(esq, 0, esq_sz * sizeof(esq[0]));

  // Accumulate squared samples; the shifts fold neighbouring samples into the
  // same cell whenever that dimension is downscaled.
  for (int r = 0; r < bh; r++) {
    const int16_t *src_row = diff + r * stride;
    unsigned int *cell_row = esq + (r >> h_shift) * esq_w;
    for (int c = 0; c < bw; c++) {
      cell_row[c >> w_shift] += src_row[c] * src_row[c];
    }
  }

  uint64_t total = 0;
  for (int k = 0; k < esq_sz; k++) total += esq[k];

  const int num_hor = esq_w - 1;
  const int num_ver = esq_h - 1;

  // No energy at all: fall back to a flat distribution.
  if (total == 0) {
    const float hor_val = 1.0f / esq_w;
    const float ver_val = 1.0f / esq_h;
    for (int c = 0; c < num_hor; c++) hordist[c] = hor_val;
    for (int r = 0; r < num_ver; r++) verdist[r] = ver_val;
    return;
  }

  const float e_recip = 1.0f / (float)total;

  // Column projection: each column is summed top to bottom over all rows.
  for (int c = 0; c < num_hor; c++) {
    float col_energy = 0.0f;
    for (int r = 0; r < esq_h; r++) col_energy += (float)esq[r * esq_w + c];
    hordist[c] = col_energy * e_recip;
  }

  // Row projection: each row is summed left to right over all columns.
  for (int r = 0; r < num_ver; r++) {
    const unsigned int *cell_row = esq + r * esq_w;
    float row_energy = 0.0f;
    for (int c = 0; c < esq_w; c++) row_energy += (float)cell_row[c];
    verdist[r] = row_energy * e_recip;
  }
}
1528
1529
0
// Returns true when bit `val` of `mask` is set.
static inline bool check_bit_mask(uint16_t mask, int val) {
  const int selected_bit = 1 << val;
  return (mask & selected_bit) != 0;
}
1532
1533
0
// Sets bit `val` in *mask, leaving all other bits untouched.
static inline void set_bit_mask(uint16_t *mask, int val) {
  const int selected_bit = 1 << val;
  *mask = (uint16_t)(*mask | selected_bit);
}
1536
1537
0
// Clears bit `val` in *mask, leaving all other bits untouched.
static inline void unset_bit_mask(uint16_t *mask, int val) {
  const int keep_bits = ~(1 << val);
  *mask = (uint16_t)(*mask & keep_bits);
}
1540
1541
static void prune_tx_2D(MACROBLOCK *x, BLOCK_SIZE bsize, TX_SIZE tx_size,
1542
                        int blk_row, int blk_col, TxSetType tx_set_type,
1543
                        TX_TYPE_PRUNE_MODE prune_2d_txfm_mode, int *txk_map,
1544
0
                        uint16_t *allowed_tx_mask) {
1545
  // This table is used because the search order is different from the enum
1546
  // order.
1547
0
  static const int tx_type_table_2D[16] = {
1548
0
    DCT_DCT,      DCT_ADST,      DCT_FLIPADST,      V_DCT,
1549
0
    ADST_DCT,     ADST_ADST,     ADST_FLIPADST,     V_ADST,
1550
0
    FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
1551
0
    H_DCT,        H_ADST,        H_FLIPADST,        IDTX
1552
0
  };
1553
0
  if (tx_set_type != EXT_TX_SET_ALL16 &&
1554
0
      tx_set_type != EXT_TX_SET_DTT9_IDTX_1DDCT)
1555
0
    return;
1556
#if CONFIG_NN_V2
1557
  NN_CONFIG_V2 *nn_config_hor = av1_tx_type_nnconfig_map_hor[tx_size];
1558
  NN_CONFIG_V2 *nn_config_ver = av1_tx_type_nnconfig_map_ver[tx_size];
1559
#else
1560
0
  const NN_CONFIG *nn_config_hor = av1_tx_type_nnconfig_map_hor[tx_size];
1561
0
  const NN_CONFIG *nn_config_ver = av1_tx_type_nnconfig_map_ver[tx_size];
1562
0
#endif
1563
0
  if (!nn_config_hor || !nn_config_ver) return;  // Model not established yet.
1564
1565
0
  float hfeatures[16], vfeatures[16];
1566
0
  float hscores[4], vscores[4];
1567
0
  float scores_2D_raw[16];
1568
0
  const int bw = tx_size_wide[tx_size];
1569
0
  const int bh = tx_size_high[tx_size];
1570
0
  const int hfeatures_num = bw <= 8 ? bw : bw / 2;
1571
0
  const int vfeatures_num = bh <= 8 ? bh : bh / 2;
1572
0
  assert(hfeatures_num <= 16);
1573
0
  assert(vfeatures_num <= 16);
1574
1575
0
  const struct macroblock_plane *const p = &x->plane[0];
1576
0
  const int diff_stride = block_size_wide[bsize];
1577
0
  const int16_t *diff = p->src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
1578
0
  get_energy_distribution_finer(diff, diff_stride, bw, bh, hfeatures,
1579
0
                                vfeatures);
1580
1581
0
  av1_get_horver_correlation_full(diff, diff_stride, bw, bh,
1582
0
                                  &hfeatures[hfeatures_num - 1],
1583
0
                                  &vfeatures[vfeatures_num - 1]);
1584
1585
#if CONFIG_NN_V2
1586
  av1_nn_predict_v2(hfeatures, nn_config_hor, 0, hscores);
1587
  av1_nn_predict_v2(vfeatures, nn_config_ver, 0, vscores);
1588
#else
1589
0
  av1_nn_predict(hfeatures, nn_config_hor, 1, hscores);
1590
0
  av1_nn_predict(vfeatures, nn_config_ver, 1, vscores);
1591
0
#endif
1592
1593
0
  for (int i = 0; i < 4; i++) {
1594
0
    float *cur_scores_2D = scores_2D_raw + i * 4;
1595
0
    cur_scores_2D[0] = vscores[i] * hscores[0];
1596
0
    cur_scores_2D[1] = vscores[i] * hscores[1];
1597
0
    cur_scores_2D[2] = vscores[i] * hscores[2];
1598
0
    cur_scores_2D[3] = vscores[i] * hscores[3];
1599
0
  }
1600
1601
0
  assert(TX_TYPES == 16);
1602
  // This version of the function only works when there are at most 16 classes.
1603
  // So we will need to change the optimization or use av1_nn_softmax instead if
1604
  // this ever gets changed.
1605
0
  av1_nn_fast_softmax_16(scores_2D_raw, scores_2D_raw);
1606
1607
0
  const float score_thresh =
1608
0
      get_adaptive_thresholds(tx_size, tx_set_type, prune_2d_txfm_mode);
1609
1610
  // Always keep the TX type with the highest score, prune all others with
1611
  // score below score_thresh.
1612
0
  int max_score_i = 0;
1613
0
  float max_score = 0.0f;
1614
0
  uint16_t allow_bitmask = 0;
1615
0
  float sum_score = 0.0;
1616
  // Calculate sum of allowed tx type score and Populate allow bit mask based
1617
  // on score_thresh and allowed_tx_mask
1618
0
  int allow_count = 0;
1619
0
  int tx_type_allowed[16] = { TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
1620
0
                              TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
1621
0
                              TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
1622
0
                              TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
1623
0
                              TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
1624
0
                              TX_TYPE_INVALID };
1625
0
  float scores_2D[16] = {
1626
0
    -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
1627
0
  };
1628
0
  for (int tx_idx = 0; tx_idx < TX_TYPES; tx_idx++) {
1629
0
    const int allow_tx_type =
1630
0
        check_bit_mask(*allowed_tx_mask, tx_type_table_2D[tx_idx]);
1631
0
    if (!allow_tx_type) {
1632
0
      continue;
1633
0
    }
1634
0
    if (scores_2D_raw[tx_idx] > max_score) {
1635
0
      max_score = scores_2D_raw[tx_idx];
1636
0
      max_score_i = tx_idx;
1637
0
    }
1638
0
    if (scores_2D_raw[tx_idx] >= score_thresh) {
1639
      // Set allow mask based on score_thresh
1640
0
      set_bit_mask(&allow_bitmask, tx_type_table_2D[tx_idx]);
1641
1642
      // Accumulate score of allowed tx type
1643
0
      sum_score += scores_2D_raw[tx_idx];
1644
1645
0
      scores_2D[allow_count] = scores_2D_raw[tx_idx];
1646
0
      tx_type_allowed[allow_count] = tx_type_table_2D[tx_idx];
1647
0
      allow_count += 1;
1648
0
    }
1649
0
  }
1650
0
  if (!check_bit_mask(allow_bitmask, tx_type_table_2D[max_score_i])) {
1651
    // If even the tx_type with max score is pruned, this means that no other
1652
    // tx_type is feasible. When this happens, we force enable max_score_i and
1653
    // end the search.
1654
0
    set_bit_mask(&allow_bitmask, tx_type_table_2D[max_score_i]);
1655
0
    memcpy(txk_map, tx_type_table_2D, sizeof(tx_type_table_2D));
1656
0
    *allowed_tx_mask = allow_bitmask;
1657
0
    return;
1658
0
  }
1659
1660
  // Sort tx type probability of all types
1661
0
  if (allow_count <= 8) {
1662
0
    av1_sort_fi32_8(scores_2D, tx_type_allowed);
1663
0
  } else {
1664
0
    av1_sort_fi32_16(scores_2D, tx_type_allowed);
1665
0
  }
1666
1667
  // Enable more pruning based on tx type probability and number of allowed tx
1668
  // types
1669
0
  if (prune_2d_txfm_mode >= TX_TYPE_PRUNE_4) {
1670
0
    float temp_score = 0.0;
1671
0
    float score_ratio = 0.0;
1672
0
    int tx_idx, tx_count = 0;
1673
0
    const float inv_sum_score = 100 / sum_score;
1674
    // Get allowed tx types based on sorted probability score and tx count
1675
0
    for (tx_idx = 0; tx_idx < allow_count; tx_idx++) {
1676
      // Skip the tx type which has more than 30% of cumulative
1677
      // probability and allowed tx type count is more than 2
1678
0
      if (score_ratio > 30.0 && tx_count >= 2) break;
1679
1680
0
      assert(check_bit_mask(allow_bitmask, tx_type_allowed[tx_idx]));
1681
      // Calculate cumulative probability
1682
0
      temp_score += scores_2D[tx_idx];
1683
1684
      // Calculate percentage of cumulative probability of allowed tx type
1685
0
      score_ratio = temp_score * inv_sum_score;
1686
0
      tx_count++;
1687
0
    }
1688
    // Set remaining tx types as pruned
1689
0
    for (; tx_idx < allow_count; tx_idx++)
1690
0
      unset_bit_mask(&allow_bitmask, tx_type_allowed[tx_idx]);
1691
0
  }
1692
1693
0
  memcpy(txk_map, tx_type_allowed, sizeof(tx_type_table_2D));
1694
0
  *allowed_tx_mask = allow_bitmask;
1695
0
}
1696
1697
0
static float get_dev(float mean, double x2_sum, int num) {
1698
0
  const float e_x2 = (float)(x2_sum / num);
1699
0
  const float diff = e_x2 - mean * mean;
1700
0
  const float dev = (diff > 0) ? sqrtf(diff) : 0;
1701
0
  return dev;
1702
0
}
1703
1704
// Fills 'features' with the mean / standard-deviation statistics consumed by
// the tx-split ML model and returns how many values were written.
//
// Layout of the output:
//   features[0], features[1]  mean and deviation of the whole bw x bh block
//   then one (mean, deviation) pair per sub-block, in raster order
//   then the deviation of the sub-block means
//   then the mean of the sub-block deviations
// The block is halved along its longer dimension (both dimensions when it is
// square), so at most 4 sub-blocks and 12 values are produced; 'features'
// must have room for at least 12 elements.
static inline int get_mean_dev_features(const int16_t *data, int stride, int bw,
                                        int bh, float *features) {
  const int subw = (bw >= bh) ? (bw >> 1) : bw;
  const int subh = (bh >= bw) ? (bh >> 1) : bh;
  const int num = bw * bh;
  const int sub_num = subw * subh;

  // Slots 0 and 1 are reserved for the whole-block statistics, which are
  // only known after all sub-blocks have been visited.
  float *next_feature = features + 2;
  int total_x_sum = 0;
  int64_t total_x2_sum = 0;
  int num_sub_blks = 0;
  double mean2_sum = 0.0f;
  float dev_sum = 0.0f;

  for (int row = 0; row < bh; row += subh) {
    const int16_t *row_ptr = data + row * stride;
    for (int col = 0; col < bw; col += subw) {
      int x_sum;
      int64_t x2_sum;
      // TODO(any): Write a SIMD version. Clear registers.
      aom_get_blk_sse_sum(row_ptr + col, stride, subw, subh, &x_sum, &x2_sum);
      total_x_sum += x_sum;
      total_x2_sum += x2_sum;

      const float mean = (float)x_sum / sub_num;
      const float dev = get_dev(mean, (double)x2_sum, sub_num);
      *next_feature++ = mean;
      *next_feature++ = dev;
      mean2_sum += (double)(mean * mean);
      dev_sum += dev;
      num_sub_blks++;
    }
  }

  // Whole-block statistics.
  const float lvl0_mean = (float)total_x_sum / num;
  features[0] = lvl0_mean;
  features[1] = get_dev(lvl0_mean, (double)total_x2_sum, num);

  // Deviation of means.
  *next_feature++ = get_dev(lvl0_mean, mean2_sum, num_sub_blks);
  // Mean of deviations.
  *next_feature++ = dev_sum / num_sub_blks;

  return (int)(next_feature - features);
}
1754
1755
// Runs the tx-split NN model on the luma residual of one transform block.
//
// Returns -1 when no model exists for tx_size. Otherwise returns the model
// score scaled by 10000, truncated toward zero and limited to
// [-80000, 80000].
//
// x:                macroblock whose plane[0].src_diff holds the residual.
// bsize:            block size; its width is the residual stride.
// blk_row, blk_col: transform block position in units of 4 samples.
static int ml_predict_tx_split(MACROBLOCK *x, BLOCK_SIZE bsize, int blk_row,
                               int blk_col, TX_SIZE tx_size) {
  const NN_CONFIG *nn_config = av1_tx_split_nnconfig_map[tx_size];
  if (!nn_config) return -1;

  const int diff_stride = block_size_wide[bsize];
  const int16_t *diff =
      x->plane[0].src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
  const int bw = tx_size_wide[tx_size];
  const int bh = tx_size_high[tx_size];

  float features[64] = { 0.0f };
  get_mean_dev_features(diff, diff_stride, bw, bh, features);

  float score = 0.0f;
  av1_nn_predict(features, nn_config, 1, &score);

  // Limit the range while still in floating point: converting a float whose
  // value does not fit in an int is undefined behavior, so the bound has to
  // be applied before the cast rather than after it. For in-range scores the
  // result is the same as truncating first and clamping afterwards.
  const float scaled_score = score * 10000;
  // Written as a negated comparison so that NaN also takes this branch and
  // never reaches the cast.
  if (!(scaled_score > -80000.0f)) return -80000;
  if (scaled_score > 80000.0f) return 80000;
  return (int)scaled_score;
}
1775
1776
static inline uint16_t get_tx_mask(
1777
    const AV1_COMP *cpi, MACROBLOCK *x, int plane, int block, int blk_row,
1778
    int blk_col, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
1779
    const TXB_CTX *const txb_ctx, FAST_TX_SEARCH_MODE ftxs_mode,
1780
0
    int64_t ref_best_rd, TX_TYPE *allowed_txk_types, int *txk_map) {
1781
0
  const AV1_COMMON *cm = &cpi->common;
1782
0
  MACROBLOCKD *xd = &x->e_mbd;
1783
0
  MB_MODE_INFO *mbmi = xd->mi[0];
1784
0
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
1785
0
  const int is_inter = is_inter_block(mbmi);
1786
0
  const int fast_tx_search = ftxs_mode & FTXS_DCT_AND_1D_DCT_ONLY;
1787
  // if txk_allowed = TX_TYPES, >1 tx types are allowed, else, if txk_allowed <
1788
  // TX_TYPES, only that specific tx type is allowed.
1789
0
  TX_TYPE txk_allowed = TX_TYPES;
1790
1791
0
  const FRAME_UPDATE_TYPE update_type =
1792
0
      get_frame_update_type(&cpi->ppi->gf_group, cpi->gf_frame_index);
1793
0
  int use_actual_frame_probs = 1;
1794
0
  const int *tx_type_probs;
1795
#if CONFIG_FPMT_TEST
1796
  use_actual_frame_probs =
1797
      (cpi->ppi->fpmt_unit_test_cfg == PARALLEL_SIMULATION_ENCODE) ? 0 : 1;
1798
  if (!use_actual_frame_probs) {
1799
    tx_type_probs =
1800
        (int *)cpi->ppi->temp_frame_probs.tx_type_probs[update_type][tx_size];
1801
  }
1802
#endif
1803
0
  if (use_actual_frame_probs) {
1804
0
    tx_type_probs = cpi->ppi->frame_probs.tx_type_probs[update_type][tx_size];
1805
0
  }
1806
1807
0
  if ((!is_inter && txfm_params->use_default_intra_tx_type) ||
1808
0
      (is_inter && txfm_params->default_inter_tx_type_prob_thresh == 0)) {
1809
0
    txk_allowed =
1810
0
        get_default_tx_type(0, xd, tx_size, cpi->use_screen_content_tools);
1811
0
  } else if (is_inter &&
1812
0
             txfm_params->default_inter_tx_type_prob_thresh != INT_MAX) {
1813
0
    if (tx_type_probs[DEFAULT_INTER_TX_TYPE] >
1814
0
        txfm_params->default_inter_tx_type_prob_thresh) {
1815
0
      txk_allowed = DEFAULT_INTER_TX_TYPE;
1816
0
    } else {
1817
0
      int force_tx_type = 0;
1818
0
      int max_prob = 0;
1819
0
      const int tx_type_prob_threshold =
1820
0
          txfm_params->default_inter_tx_type_prob_thresh +
1821
0
          PROB_THRESH_OFFSET_TX_TYPE;
1822
0
      for (int i = 1; i < TX_TYPES; i++) {  // find maximum probability.
1823
0
        if (tx_type_probs[i] > max_prob) {
1824
0
          max_prob = tx_type_probs[i];
1825
0
          force_tx_type = i;
1826
0
        }
1827
0
      }
1828
0
      if (max_prob > tx_type_prob_threshold)  // force tx type with max prob.
1829
0
        txk_allowed = force_tx_type;
1830
0
      else if (x->rd_model == LOW_TXFM_RD) {
1831
0
        if (plane == 0) txk_allowed = DCT_DCT;
1832
0
      }
1833
0
    }
1834
0
  } else if (x->rd_model == LOW_TXFM_RD) {
1835
0
    if (plane == 0) txk_allowed = DCT_DCT;
1836
0
  }
1837
1838
0
  const TxSetType tx_set_type = av1_get_ext_tx_set_type(
1839
0
      tx_size, is_inter, cm->features.reduced_tx_set_used);
1840
1841
0
  TX_TYPE uv_tx_type = DCT_DCT;
1842
0
  if (plane) {
1843
    // tx_type of PLANE_TYPE_UV should be the same as PLANE_TYPE_Y
1844
0
    uv_tx_type = txk_allowed =
1845
0
        av1_get_tx_type(xd, get_plane_type(plane), blk_row, blk_col, tx_size,
1846
0
                        cm->features.reduced_tx_set_used);
1847
0
  }
1848
0
  PREDICTION_MODE intra_dir =
1849
0
      mbmi->filter_intra_mode_info.use_filter_intra
1850
0
          ? fimode_to_intradir[mbmi->filter_intra_mode_info.filter_intra_mode]
1851
0
          : mbmi->mode;
1852
0
  uint16_t ext_tx_used_flag =
1853
0
      cpi->sf.tx_sf.tx_type_search.use_reduced_intra_txset != 0 &&
1854
0
              tx_set_type == EXT_TX_SET_DTT4_IDTX_1DDCT
1855
0
          ? av1_reduced_intra_tx_used_flag[intra_dir]
1856
0
          : av1_ext_tx_used_flag[tx_set_type];
1857
1858
0
  if (cpi->sf.tx_sf.tx_type_search.use_reduced_intra_txset == 2)
1859
0
    ext_tx_used_flag &= av1_derived_intra_tx_used_flag[intra_dir];
1860
1861
0
  if (xd->lossless[mbmi->segment_id] || txsize_sqr_up_map[tx_size] > TX_32X32 ||
1862
0
      ext_tx_used_flag == 0x0001 ||
1863
0
      (is_inter && cpi->oxcf.txfm_cfg.use_inter_dct_only) ||
1864
0
      (!is_inter && cpi->oxcf.txfm_cfg.use_intra_dct_only)) {
1865
0
    txk_allowed = DCT_DCT;
1866
0
  }
1867
1868
0
  if (cpi->oxcf.txfm_cfg.enable_flip_idtx == 0)
1869
0
    ext_tx_used_flag &= DCT_ADST_TX_MASK;
1870
1871
0
  uint16_t allowed_tx_mask = 0;  // 1: allow; 0: skip.
1872
0
  if (txk_allowed < TX_TYPES) {
1873
0
    allowed_tx_mask = 1 << txk_allowed;
1874
0
    allowed_tx_mask &= ext_tx_used_flag;
1875
0
  } else if (fast_tx_search) {
1876
0
    allowed_tx_mask = 0x0c01;  // V_DCT, H_DCT, DCT_DCT
1877
0
    allowed_tx_mask &= ext_tx_used_flag;
1878
0
  } else if (!is_inter && txfm_params->use_derived_intra_tx_type_set) {
1879
0
    allowed_tx_mask = av1_derived_intra_tx_used_flag[intra_dir];
1880
0
    allowed_tx_mask &= ext_tx_used_flag;
1881
0
  } else {
1882
0
    assert(plane == 0);
1883
0
    allowed_tx_mask = ext_tx_used_flag;
1884
0
    int num_allowed = 0;
1885
0
    int i;
1886
1887
0
    if (cpi->sf.tx_sf.tx_type_search.prune_tx_type_using_stats) {
1888
0
      static const int thresh_arr[2][7] = { { 10, 15, 15, 10, 15, 15, 15 },
1889
0
                                            { 10, 17, 17, 10, 17, 17, 17 } };
1890
0
      const int thresh =
1891
0
          thresh_arr[cpi->sf.tx_sf.tx_type_search.prune_tx_type_using_stats - 1]
1892
0
                    [update_type];
1893
0
      uint16_t prune = 0;
1894
0
      int max_prob = -1;
1895
0
      int max_idx = 0;
1896
0
      for (i = 0; i < TX_TYPES; i++) {
1897
0
        if (tx_type_probs[i] > max_prob && (allowed_tx_mask & (1 << i))) {
1898
0
          max_prob = tx_type_probs[i];
1899
0
          max_idx = i;
1900
0
        }
1901
0
        if (tx_type_probs[i] < thresh) prune |= (1 << i);
1902
0
      }
1903
0
      if ((prune >> max_idx) & 0x01) prune &= ~(1 << max_idx);
1904
0
      allowed_tx_mask &= (~prune);
1905
0
    }
1906
0
    for (i = 0; i < TX_TYPES; i++) {
1907
0
      if (allowed_tx_mask & (1 << i)) num_allowed++;
1908
0
    }
1909
0
    assert(num_allowed > 0);
1910
1911
0
    if (num_allowed > 2 && cpi->sf.tx_sf.tx_type_search.prune_tx_type_est_rd) {
1912
0
      int pf = prune_factors[txfm_params->prune_2d_txfm_mode];
1913
0
      int mf = mul_factors[txfm_params->prune_2d_txfm_mode];
1914
0
      if (num_allowed <= 7) {
1915
0
        const uint16_t prune =
1916
0
            prune_txk_type(cpi, x, plane, block, tx_size, blk_row, blk_col,
1917
0
                           plane_bsize, txk_map, allowed_tx_mask, pf, txb_ctx,
1918
0
                           cm->features.reduced_tx_set_used);
1919
0
        allowed_tx_mask &= (~prune);
1920
0
      } else {
1921
0
        const int num_sel = (num_allowed * mf + 50) / 100;
1922
0
        const uint16_t prune = prune_txk_type_separ(
1923
0
            cpi, x, plane, block, tx_size, blk_row, blk_col, plane_bsize,
1924
0
            txk_map, allowed_tx_mask, pf, txb_ctx,
1925
0
            cm->features.reduced_tx_set_used, ref_best_rd, num_sel);
1926
1927
0
        allowed_tx_mask &= (~prune);
1928
0
      }
1929
0
    } else {
1930
0
      assert(num_allowed > 0);
1931
0
      int allowed_tx_count =
1932
0
          (txfm_params->prune_2d_txfm_mode >= TX_TYPE_PRUNE_4) ? 1 : 5;
1933
      // !fast_tx_search && txk_end != txk_start && plane == 0
1934
0
      if (txfm_params->prune_2d_txfm_mode >= TX_TYPE_PRUNE_1 && is_inter &&
1935
0
          num_allowed > allowed_tx_count) {
1936
0
        prune_tx_2D(x, plane_bsize, tx_size, blk_row, blk_col, tx_set_type,
1937
0
                    txfm_params->prune_2d_txfm_mode, txk_map, &allowed_tx_mask);
1938
0
      }
1939
0
    }
1940
0
  }
1941
1942
  // Need to have at least one transform type allowed.
1943
0
  if (allowed_tx_mask == 0) {
1944
0
    txk_allowed = (plane ? uv_tx_type : DCT_DCT);
1945
0
    allowed_tx_mask = (1 << txk_allowed);
1946
0
  }
1947
1948
0
  assert(IMPLIES(txk_allowed < TX_TYPES, allowed_tx_mask == 1 << txk_allowed));
1949
0
  *allowed_txk_types = txk_allowed;
1950
0
  return allowed_tx_mask;
1951
0
}
1952
1953
#if CONFIG_RD_DEBUG
1954
// Adds txb_coeff_cost to the per-plane coefficient-cost counter kept in
// rd_stats (only built when CONFIG_RD_DEBUG is enabled).
static inline void update_txb_coeff_cost(RD_STATS *rd_stats, int plane,
                                         int txb_coeff_cost) {
  rd_stats->txb_coeff_cost[plane] =
      rd_stats->txb_coeff_cost[plane] + txb_coeff_cost;
}
1958
#endif
1959
1960
static inline int cost_coeffs(MACROBLOCK *x, int plane, int block,
1961
                              TX_SIZE tx_size, const TX_TYPE tx_type,
1962
                              const TXB_CTX *const txb_ctx,
1963
0
                              int reduced_tx_set_used) {
1964
#if TXCOEFF_COST_TIMER
1965
  struct aom_usec_timer timer;
1966
  aom_usec_timer_start(&timer);
1967
#endif
1968
0
  const int cost = av1_cost_coeffs_txb(x, plane, block, tx_size, tx_type,
1969
0
                                       txb_ctx, reduced_tx_set_used);
1970
#if TXCOEFF_COST_TIMER
1971
  AV1_COMMON *tmp_cm = (AV1_COMMON *)&cpi->common;
1972
  aom_usec_timer_mark(&timer);
1973
  const int64_t elapsed_time = aom_usec_timer_elapsed(&timer);
1974
  tmp_cm->txcoeff_cost_timer += elapsed_time;
1975
  ++tmp_cm->txcoeff_cost_count;
1976
#endif
1977
0
  return cost;
1978
0
}
1979
1980
static int skip_trellis_opt_based_on_satd(MACROBLOCK *x,
1981
                                          QUANT_PARAM *quant_param, int plane,
1982
                                          int block, TX_SIZE tx_size,
1983
                                          int quant_b_adapt, int qstep,
1984
                                          unsigned int coeff_opt_satd_threshold,
1985
0
                                          int skip_trellis, int dc_only_blk) {
1986
0
  if (skip_trellis || (coeff_opt_satd_threshold == UINT_MAX))
1987
0
    return skip_trellis;
1988
1989
0
  const struct macroblock_plane *const p = &x->plane[plane];
1990
0
  const int block_offset = BLOCK_OFFSET(block);
1991
0
  tran_low_t *const coeff_ptr = p->coeff + block_offset;
1992
0
  const int n_coeffs = av1_get_max_eob(tx_size);
1993
0
  const int shift = (MAX_TX_SCALE - av1_get_tx_scale(tx_size));
1994
0
  int satd = (dc_only_blk) ? abs(coeff_ptr[0]) : aom_satd(coeff_ptr, n_coeffs);
1995
0
  satd = RIGHT_SIGNED_SHIFT(satd, shift);
1996
0
  satd >>= (x->e_mbd.bd - 8);
1997
1998
0
  const int skip_block_trellis =
1999
0
      ((uint64_t)satd >
2000
0
       (uint64_t)coeff_opt_satd_threshold * qstep * sqrt_tx_pixels_2d[tx_size]);
2001
2002
0
  av1_setup_quant(
2003
0
      tx_size, !skip_block_trellis,
2004
0
      skip_block_trellis
2005
0
          ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B : AV1_XFORM_QUANT_FP)
2006
0
          : AV1_XFORM_QUANT_FP,
2007
0
      quant_b_adapt, quant_param);
2008
2009
0
  return skip_block_trellis;
2010
0
}
2011
2012
// Predict DC only blocks if the residual variance is below a qstep based
2013
// threshold.For such blocks, transform type search is bypassed.
2014
static inline void predict_dc_only_block(
2015
    MACROBLOCK *x, int plane, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
2016
    int block, int blk_row, int blk_col, RD_STATS *best_rd_stats,
2017
    int64_t *block_sse, unsigned int *block_mse_q8, int64_t *per_px_mean,
2018
0
    int *dc_only_blk) {
2019
0
  MACROBLOCKD *xd = &x->e_mbd;
2020
0
  MB_MODE_INFO *mbmi = xd->mi[0];
2021
0
  const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
2022
0
  const int qstep = x->plane[plane].dequant_QTX[1] >> dequant_shift;
2023
0
  uint64_t block_var = UINT64_MAX;
2024
0
  const int dc_qstep = x->plane[plane].dequant_QTX[0] >> 3;
2025
0
  *block_sse = pixel_diff_stats(x, plane, blk_row, blk_col, plane_bsize,
2026
0
                                txsize_to_bsize[tx_size], block_mse_q8,
2027
0
                                per_px_mean, &block_var);
2028
0
  assert((*block_mse_q8) != UINT_MAX);
2029
0
  uint64_t var_threshold = (uint64_t)(1.8 * qstep * qstep);
2030
0
  if (is_cur_buf_hbd(xd))
2031
0
    block_var = ROUND_POWER_OF_TWO(block_var, (xd->bd - 8) * 2);
2032
2033
0
  if (block_var >= var_threshold) return;
2034
0
  const unsigned int predict_dc_level = x->txfm_search_params.predict_dc_level;
2035
0
  assert(predict_dc_level != 0);
2036
2037
  // Prediction of skip block if residual mean and variance are less
2038
  // than qstep based threshold
2039
0
  if ((llabs(*per_px_mean) * dc_coeff_scale[tx_size]) < (dc_qstep << 12)) {
2040
    // If the normalized mean of residual block is less than the dc qstep and
2041
    // the  normalized block variance is less than ac qstep, then the block is
2042
    // assumed to be a skip block and its rdcost is updated accordingly.
2043
0
    best_rd_stats->skip_txfm = 1;
2044
2045
0
    x->plane[plane].eobs[block] = 0;
2046
2047
0
    if (is_cur_buf_hbd(xd))
2048
0
      *block_sse = ROUND_POWER_OF_TWO((*block_sse), (xd->bd - 8) * 2);
2049
2050
0
    best_rd_stats->dist = (*block_sse) << 4;
2051
0
    best_rd_stats->sse = best_rd_stats->dist;
2052
2053
0
    ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
2054
0
    ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
2055
0
    av1_get_entropy_contexts(plane_bsize, &xd->plane[plane], ctxa, ctxl);
2056
0
    ENTROPY_CONTEXT *ta = ctxa;
2057
0
    ENTROPY_CONTEXT *tl = ctxl;
2058
0
    const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
2059
0
    TXB_CTX txb_ctx_tmp;
2060
0
    const PLANE_TYPE plane_type = get_plane_type(plane);
2061
0
    get_txb_ctx(plane_bsize, tx_size, plane, ta, tl, &txb_ctx_tmp);
2062
0
    const int zero_blk_rate = x->coeff_costs.coeff_costs[txs_ctx][plane_type]
2063
0
                                  .txb_skip_cost[txb_ctx_tmp.txb_skip_ctx][1];
2064
0
    best_rd_stats->rate = zero_blk_rate;
2065
2066
0
    best_rd_stats->rdcost =
2067
0
        RDCOST(x->rdmult, best_rd_stats->rate, best_rd_stats->sse);
2068
2069
0
    x->plane[plane].txb_entropy_ctx[block] = 0;
2070
0
  } else if (predict_dc_level > 1) {
2071
    // Predict DC only blocks based on residual variance.
2072
    // For chroma plane, this prediction is disabled for intra blocks.
2073
0
    if ((plane == 0) || (plane > 0 && is_inter_block(mbmi))) *dc_only_blk = 1;
2074
0
  }
2075
0
}
2076
2077
// Search for the best transform type for a given transform block.
2078
// This function can be used for both inter and intra, both luma and chroma.
//
// Parameters:
//   cpi, x        - encoder instance and the macroblock being coded.
//   plane, block  - plane index and the index of this transform block in the
//                   per-plane eobs[] / txb_entropy_ctx[] arrays.
//   blk_row/col   - transform block position; forwarded to the helpers and
//                   used to index xd->tx_type_map for luma.
//   plane_bsize   - forwarded to the helper routines.
//   tx_size       - transform size under evaluation.
//   txb_ctx       - entropy context handed to get_tx_mask(), the rate
//                   estimators and recon_intra().
//   ftxs_mode     - forwarded to get_tx_mask().
//   ref_best_rd   - reference RD cost; forwarded to get_tx_mask() and used for
//                   the adaptive early termination inside the loop.
//   best_rd_stats - output. Invalidated on entry, then overwritten with the
//                   stats of the winning transform type.
//
// Besides best_rd_stats this function writes x->plane[plane].eobs[block],
// x->plane[plane].txb_entropy_ctx[block] and, for luma, xd->tx_type_map.
// p->dqcoeff is re-pointed while searching and restored before returning.
2079
static void search_tx_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
2080
                           int block, int blk_row, int blk_col,
2081
                           BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
2082
                           const TXB_CTX *const txb_ctx,
2083
                           FAST_TX_SEARCH_MODE ftxs_mode, int64_t ref_best_rd,
2084
0
                           RD_STATS *best_rd_stats) {
2085
0
  const AV1_COMMON *cm = &cpi->common;
2086
0
  MACROBLOCKD *xd = &x->e_mbd;
2087
0
  MB_MODE_INFO *mbmi = xd->mi[0];
2088
0
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
2089
0
  int64_t best_rd = INT64_MAX;
2090
0
  uint16_t best_eob = 0;
2091
0
  TX_TYPE best_tx_type = DCT_DCT;
2092
0
  int rate_cost = 0;
2093
0
  struct macroblock_plane *const p = &x->plane[plane];
2094
0
  // Two dqcoeff buffers are ping-ponged: p->dqcoeff receives each candidate's
  // output, best_dqcoeff keeps the best candidate seen so far (see the swap in
  // the loop). orig_dqcoeff remembers the pointer to restore on exit.
  tran_low_t *orig_dqcoeff = p->dqcoeff;
2095
0
  tran_low_t *best_dqcoeff = x->dqcoeff_buf;
2096
0
  // Chroma planes use index 0 here; the map itself is only written when
  // plane == 0.
  const int tx_type_map_idx =
2097
0
      plane ? 0 : blk_row * xd->tx_type_map_stride + blk_col;
2098
0
  av1_invalid_rd_stats(best_rd_stats);
2099
2100
0
  int skip_trellis = !is_trellis_used(
2101
0
      cpi->optimize_seg_arr[xd->mi[0]->segment_id], DRY_RUN_NORMAL);
2102
2103
0
  uint8_t best_txb_ctx = 0;
2104
  // txk_allowed = TX_TYPES: >1 tx types are allowed
2105
  // txk_allowed < TX_TYPES: only that specific tx type is allowed.
2106
0
  TX_TYPE txk_allowed = TX_TYPES;
2107
0
  // Candidate visiting order; get_tx_mask() receives this array and may
  // rewrite it (entries equal to TX_TYPE_INVALID are skipped in the loop).
  int txk_map[TX_TYPES] = {
2108
0
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
2109
0
  };
2110
0
  const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
2111
0
  const int qstep = x->plane[plane].dequant_QTX[1] >> dequant_shift;
2112
2113
0
  const uint8_t txw = tx_size_wide[tx_size];
2114
0
  const uint8_t txh = tx_size_high[tx_size];
2115
0
  int64_t block_sse;
2116
0
  unsigned int block_mse_q8;
2117
0
  int dc_only_blk = 0;
2118
0
  // DC-only prediction is not attempted when either transform dimension is 64.
  const bool predict_dc_block =
2119
0
      txfm_params->predict_dc_level >= 1 && txw != 64 && txh != 64;
2120
0
  int64_t per_px_mean = INT64_MAX;
2121
0
  if (predict_dc_block) {
2122
0
    predict_dc_only_block(x, plane, plane_bsize, tx_size, block, blk_row,
2123
0
                          blk_col, best_rd_stats, &block_sse, &block_mse_q8,
2124
0
                          &per_px_mean, &dc_only_blk);
2125
0
    // When the predictor flagged the block as skip it has already filled
    // best_rd_stats; only the luma transform type remains to be recorded.
    if (best_rd_stats->skip_txfm == 1) {
2126
0
      const TX_TYPE tx_type = DCT_DCT;
2127
0
      if (plane == 0) xd->tx_type_map[tx_type_map_idx] = tx_type;
2128
0
      return;
2129
0
    }
2130
0
  } else {
2131
0
    block_sse = av1_pixel_diff_dist(x, plane, blk_row, blk_col, plane_bsize,
2132
0
                                    txsize_to_bsize[tx_size], &block_mse_q8);
2133
0
    assert(block_mse_q8 != UINT_MAX);
2134
0
  }
2135
2136
  // Bit mask to indicate which transform types are allowed in the RD search.
2137
0
  uint16_t tx_mask;
2138
2139
  // Use DCT_DCT transform for DC only block.
2140
0
  if (dc_only_blk || cpi->sf.rt_sf.dct_only_palette_nonrd == 1)
2141
0
    tx_mask = 1 << DCT_DCT;
2142
0
  else
2143
0
    tx_mask = get_tx_mask(cpi, x, plane, block, blk_row, blk_col, plane_bsize,
2144
0
                          tx_size, txb_ctx, ftxs_mode, ref_best_rd,
2145
0
                          &txk_allowed, txk_map);
2146
0
  const uint16_t allowed_tx_mask = tx_mask;
2147
2148
0
  // High bit depth: scale the squared-error quantities by 2 * (bd - 8) bits.
  if (is_cur_buf_hbd(xd)) {
2149
0
    block_sse = ROUND_POWER_OF_TWO(block_sse, (xd->bd - 8) * 2);
2150
0
    block_mse_q8 = ROUND_POWER_OF_TWO(block_mse_q8, (xd->bd - 8) * 2);
2151
0
  }
2152
0
  block_sse *= 16;
2153
  // Use mse / qstep^2 based threshold logic to take decision of R-D
2154
  // optimization of coeffs. For smaller residuals, coeff optimization
2155
  // would be helpful. For larger residuals, R-D optimization may not be
2156
  // effective.
2157
  // TODO(any): Experiment with variance and mean based thresholds
2158
0
  const int perform_block_coeff_opt =
2159
0
      ((uint64_t)block_mse_q8 <=
2160
0
       (uint64_t)txfm_params->coeff_opt_thresholds[0] * qstep * qstep);
2161
0
  skip_trellis |= !perform_block_coeff_opt;
2162
2163
  // Flag to indicate if distortion should be calculated in transform domain or
2164
  // not during iterating through transform type candidates.
2165
  // Transform domain distortion is accurate for higher residuals.
2166
  // TODO(any): Experiment with variance and mean based thresholds
2167
0
  int use_transform_domain_distortion =
2168
0
      (txfm_params->use_transform_domain_distortion > 0) &&
2169
0
      (block_mse_q8 >= txfm_params->tx_domain_dist_threshold) &&
2170
      // Any 64-pt transforms only preserves half the coefficients.
2171
      // Therefore transform domain distortion is not valid for these
2172
      // transform sizes.
2173
0
      (txsize_sqr_up_map[tx_size] != TX_64X64) &&
2174
      // Use pixel domain distortion for DC only blocks
2175
0
      !dc_only_blk;
2176
  // Flag to indicate if an extra calculation of distortion in the pixel domain
2177
  // should be performed at the end, after the best transform type has been
2178
  // decided.
2179
0
  int calc_pixel_domain_distortion_final =
2180
0
      txfm_params->use_transform_domain_distortion == 1 &&
2181
0
      use_transform_domain_distortion && x->rd_model != LOW_TXFM_RD;
2182
0
  // With a single candidate (one fixed tx type, or a mask with only bit 0 set)
  // both flags are cleared, so the loop below takes its final `else`
  // distortion branch and no extra pass is run after the loop.
  if (calc_pixel_domain_distortion_final &&
2183
0
      (txk_allowed < TX_TYPES || allowed_tx_mask == 0x0001))
2184
0
    calc_pixel_domain_distortion_final = use_transform_domain_distortion = 0;
2185
2186
0
  const uint16_t *eobs_ptr = x->plane[plane].eobs;
2187
2188
0
  TxfmParam txfm_param;
2189
0
  QUANT_PARAM quant_param;
2190
0
  // Indexed by tx_type (not by loop position) so the entry belonging to
  // best_tx_type can be looked up after the loop.
  int skip_trellis_based_on_satd[TX_TYPES] = { 0 };
2191
0
  av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
2192
0
  // Without trellis the quantizer is selected by USE_B_QUANT_NO_TRELLIS; with
  // trellis AV1_XFORM_QUANT_FP is used.
  av1_setup_quant(tx_size, !skip_trellis,
2193
0
                  skip_trellis ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B
2194
0
                                                         : AV1_XFORM_QUANT_FP)
2195
0
                               : AV1_XFORM_QUANT_FP,
2196
0
                  cpi->oxcf.q_cfg.quant_b_adapt, &quant_param);
2197
2198
  // Iterate through all transform type candidates.
2199
0
  for (int idx = 0; idx < TX_TYPES; ++idx) {
2200
0
    const TX_TYPE tx_type = (TX_TYPE)txk_map[idx];
2201
0
    if (tx_type == TX_TYPE_INVALID || !check_bit_mask(allowed_tx_mask, tx_type))
2202
0
      continue;
2203
0
    txfm_param.tx_type = tx_type;
2204
0
    if (av1_use_qmatrix(&cm->quant_params, xd, mbmi->segment_id)) {
2205
0
      av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
2206
0
                        &quant_param);
2207
0
    }
2208
0
    if (plane == 0) xd->tx_type_map[tx_type_map_idx] = tx_type;
2209
0
    RD_STATS this_rd_stats;
2210
0
    av1_invalid_rd_stats(&this_rd_stats);
2211
2212
0
    // Forward transform: the DC-only variant is driven by per_px_mean as
    // produced by predict_dc_only_block().
    if (!dc_only_blk)
2213
0
      av1_xform(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param);
2214
0
    else
2215
0
      av1_xform_dc_only(x, plane, block, &txfm_param, per_px_mean);
2216
2217
0
    skip_trellis_based_on_satd[tx_type] = skip_trellis_opt_based_on_satd(
2218
0
        x, &quant_param, plane, block, tx_size, cpi->oxcf.q_cfg.quant_b_adapt,
2219
0
        qstep, txfm_params->coeff_opt_thresholds[1], skip_trellis, dc_only_blk);
2220
2221
0
    av1_quant(x, plane, block, &txfm_param, &quant_param);
2222
2223
    // Calculate rate cost of quantized coefficients.
2224
0
    if (quant_param.use_optimize_b) {
2225
      // TODO(aomedia:3209): update Trellis quantization to take into account
2226
      // quantization matrices.
2227
0
      av1_optimize_b(cpi, x, plane, block, tx_size, tx_type, txb_ctx,
2228
0
                     &rate_cost);
2229
0
    } else {
2230
0
      rate_cost = cost_coeffs(x, plane, block, tx_size, tx_type, txb_ctx,
2231
0
                              cm->features.reduced_tx_set_used);
2232
0
    }
2233
2234
    // If rd cost based on coeff rate alone is already more than best_rd,
2235
    // terminate early.
2236
0
    if (RDCOST(x->rdmult, rate_cost, 0) > best_rd) continue;
2237
2238
    // Calculate distortion.
2239
0
    if (eobs_ptr[block] == 0) {
2240
      // When eob is 0, pixel domain distortion is more efficient and accurate.
2241
0
      this_rd_stats.dist = this_rd_stats.sse = block_sse;
2242
0
    } else if (dc_only_blk) {
2243
0
      this_rd_stats.sse = block_sse;
2244
0
      this_rd_stats.dist = dist_block_px_domain(
2245
0
          cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
2246
0
    } else if (use_transform_domain_distortion) {
2247
0
      const SCAN_ORDER *const scan_order =
2248
0
          get_scan(txfm_param.tx_size, txfm_param.tx_type);
2249
0
      dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
2250
0
                           scan_order->scan, &this_rd_stats.dist,
2251
0
                           &this_rd_stats.sse);
2252
0
    } else {
2253
0
      // INT64_MAX marks "not computed"; it is only replaced when the
      // transform-domain distortion is evaluated below.
      int64_t sse_diff = INT64_MAX;
2254
      // high_energy threshold assumes that every pixel within a txfm block
2255
      // has a residue energy of at least 25% of the maximum, i.e. 128 * 128
2256
      // for 8 bit.
2257
0
      const int64_t high_energy_thresh =
2258
0
          ((int64_t)128 * 128 * tx_size_2d[tx_size]);
2259
0
      const int is_high_energy = (block_sse >= high_energy_thresh);
2260
0
      if (tx_size == TX_64X64 || is_high_energy) {
2261
        // Because 3 out 4 quadrants of transform coefficients are forced to
2262
        // zero, the inverse transform has a tendency to overflow. sse_diff
2263
        // is effectively the energy of those 3 quadrants, here we use it
2264
        // to decide if we should do pixel domain distortion. If the energy
2265
        // is mostly in first quadrant, then it is unlikely that we have
2266
        // overflow issue in inverse transform.
2267
0
        const SCAN_ORDER *const scan_order =
2268
0
            get_scan(txfm_param.tx_size, txfm_param.tx_type);
2269
0
        dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
2270
0
                             scan_order->scan, &this_rd_stats.dist,
2271
0
                             &this_rd_stats.sse);
2272
0
        sse_diff = block_sse - this_rd_stats.sse;
2273
0
      }
2274
0
      // The third operand is only reached when tx_size == TX_64X64 and the
      // block is high energy, i.e. when sse_diff was computed above.
      if (tx_size != TX_64X64 || !is_high_energy ||
2275
0
          (sse_diff * 2) < this_rd_stats.sse) {
2276
0
        const int64_t tx_domain_dist = this_rd_stats.dist;
2277
0
        this_rd_stats.dist = dist_block_px_domain(
2278
0
            cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
2279
        // For high energy blocks, occasionally, the pixel domain distortion
2280
        // can be artificially low due to clamping at reconstruction stage
2281
        // even when inverse transform output is hugely different from the
2282
        // actual residue.
2283
0
        if (is_high_energy && this_rd_stats.dist < tx_domain_dist)
2284
0
          this_rd_stats.dist = tx_domain_dist;
2285
0
      } else {
2286
0
        assert(sse_diff < INT64_MAX);
2287
0
        this_rd_stats.dist += sse_diff;
2288
0
      }
2289
0
      this_rd_stats.sse = block_sse;
2290
0
    }
2291
2292
0
    this_rd_stats.rate = rate_cost;
2293
2294
0
    const int64_t rd =
2295
0
        RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
2296
2297
0
    if (rd < best_rd) {
2298
0
      best_rd = rd;
2299
0
      *best_rd_stats = this_rd_stats;
2300
0
      best_tx_type = tx_type;
2301
0
      best_txb_ctx = x->plane[plane].txb_entropy_ctx[block];
2302
0
      best_eob = x->plane[plane].eobs[block];
2303
      // Swap dqcoeff buffers
      // Afterwards best_dqcoeff holds this candidate's coefficients and
      // p->dqcoeff points at the buffer that the next candidate may overwrite.
2304
0
      tran_low_t *const tmp_dqcoeff = best_dqcoeff;
2305
0
      best_dqcoeff = p->dqcoeff;
2306
0
      p->dqcoeff = tmp_dqcoeff;
2307
0
    }
2308
2309
#if CONFIG_COLLECT_RD_STATS == 1
2310
    if (plane == 0) {
2311
      PrintTransformUnitStats(cpi, x, &this_rd_stats, blk_row, blk_col,
2312
                              plane_bsize, tx_size, tx_type, rd);
2313
    }
2314
#endif  // CONFIG_COLLECT_RD_STATS == 1
2315
2316
#if COLLECT_TX_SIZE_DATA
    // NOTE(review): `within_border` is not declared anywhere in this function;
    // unless it is provided at file scope this debug-only block will not build
    // when COLLECT_TX_SIZE_DATA is enabled - confirm before relying on it.
2317
    // Generate small sample to restrict output size.
2318
    static unsigned int seed = 21743;
2319
    if (lcg_rand16(&seed) % 200 == 0) {
2320
      FILE *fp = NULL;
2321
2322
      if (within_border) {
2323
        fp = fopen(av1_tx_size_data_output_file, "a");
2324
      }
2325
2326
      if (fp) {
2327
        // Transform info and RD
2328
        const int txb_w = tx_size_wide[tx_size];
2329
        const int txb_h = tx_size_high[tx_size];
2330
2331
        // Residue signal.
2332
        const int diff_stride = block_size_wide[plane_bsize];
2333
        struct macroblock_plane *const p = &x->plane[plane];
2334
        const int16_t *src_diff =
2335
            &p->src_diff[(blk_row * diff_stride + blk_col) * 4];
2336
2337
        for (int r = 0; r < txb_h; ++r) {
2338
          for (int c = 0; c < txb_w; ++c) {
2339
            fprintf(fp, "%d,", src_diff[c]);
2340
          }
2341
          src_diff += diff_stride;
2342
        }
2343
2344
        fprintf(fp, "%d,%d,%d,%" PRId64, txb_w, txb_h, tx_type, rd);
2345
        fprintf(fp, "\n");
2346
        fclose(fp);
2347
      }
2348
    }
2349
#endif  // COLLECT_TX_SIZE_DATA
2350
2351
    // If the current best RD cost is much worse than the reference RD cost,
2352
    // terminate early.
2353
0
    if (cpi->sf.tx_sf.adaptive_txb_search_level) {
2354
0
      if ((best_rd - (best_rd >> cpi->sf.tx_sf.adaptive_txb_search_level)) >
2355
0
          ref_best_rd) {
2356
0
        break;
2357
0
      }
2358
0
    }
2359
2360
    // Terminate transform type search if the block has been quantized to
2361
    // all zero.
2362
0
    if (cpi->sf.tx_sf.tx_type_search.skip_tx_search && !best_eob) break;
2363
0
  }
2364
2365
0
  // best_rd is only lowered inside the loop, so this asserts that at least one
  // candidate was accepted.
  assert(best_rd != INT64_MAX);
2366
2367
0
  // Write the winning candidate's state back; the loop left these fields
  // holding whatever the last evaluated candidate produced.
  best_rd_stats->skip_txfm = best_eob == 0;
2368
0
  if (plane == 0) update_txk_array(xd, blk_row, blk_col, tx_size, best_tx_type);
2369
0
  x->plane[plane].txb_entropy_ctx[block] = best_txb_ctx;
2370
0
  x->plane[plane].eobs[block] = best_eob;
2371
0
  skip_trellis = skip_trellis_based_on_satd[best_tx_type];
2372
2373
  // Point dqcoeff to the quantized coefficients corresponding to the best
2374
  // transform type, then we can skip transform and quantization, e.g. in the
2375
  // final pixel domain distortion calculation and recon_intra().
2376
0
  p->dqcoeff = best_dqcoeff;
2377
2378
0
  if (calc_pixel_domain_distortion_final && best_eob) {
2379
0
    best_rd_stats->dist = dist_block_px_domain(
2380
0
        cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
2381
0
    best_rd_stats->sse = block_sse;
2382
0
  }
2383
2384
  // Intra mode needs decoded pixels such that the next transform block
2385
  // can use them for prediction.
2386
0
  recon_intra(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
2387
0
              txb_ctx, skip_trellis, best_tx_type, 0, &rate_cost, best_eob);
2388
0
  // Undo the buffer re-pointing so the caller sees the original dqcoeff
  // pointer.
  p->dqcoeff = orig_dqcoeff;
2389
0
}
2390
2391
// Pick transform type for a luma transform block of tx_size. Note this function
2392
// is used only for inter-predicted blocks.
2393
static inline void tx_type_rd(const AV1_COMP *cpi, MACROBLOCK *x,
2394
                              TX_SIZE tx_size, int blk_row, int blk_col,
2395
                              int block, int plane_bsize, TXB_CTX *txb_ctx,
2396
                              RD_STATS *rd_stats, FAST_TX_SEARCH_MODE ftxs_mode,
2397
0
                              int64_t ref_rdcost) {
2398
0
  assert(is_inter_block(x->e_mbd.mi[0]));
2399
0
  RD_STATS this_rd_stats;
2400
0
  search_tx_type(cpi, x, 0, block, blk_row, blk_col, plane_bsize, tx_size,
2401
0
                 txb_ctx, ftxs_mode, ref_rdcost, &this_rd_stats);
2402
2403
0
  av1_merge_rd_stats(rd_stats, &this_rd_stats);
2404
0
}
2405
2406
static inline void try_tx_block_no_split(
2407
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
2408
    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize,
2409
    const ENTROPY_CONTEXT *ta, const ENTROPY_CONTEXT *tl,
2410
    int txfm_partition_ctx, RD_STATS *rd_stats, int64_t ref_best_rd,
2411
0
    FAST_TX_SEARCH_MODE ftxs_mode, TxCandidateInfo *no_split) {
2412
0
  MACROBLOCKD *const xd = &x->e_mbd;
2413
0
  MB_MODE_INFO *const mbmi = xd->mi[0];
2414
0
  struct macroblock_plane *const p = &x->plane[0];
2415
0
  const ENTROPY_CONTEXT *const pta = ta + blk_col;
2416
0
  const ENTROPY_CONTEXT *const ptl = tl + blk_row;
2417
0
  const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
2418
0
  TXB_CTX txb_ctx;
2419
0
  get_txb_ctx(plane_bsize, tx_size, 0, pta, ptl, &txb_ctx);
2420
0
  const int zero_blk_rate = x->coeff_costs.coeff_costs[txs_ctx][PLANE_TYPE_Y]
2421
0
                                .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
2422
0
  rd_stats->zero_rate = zero_blk_rate;
2423
0
  const int index = av1_get_txb_size_index(plane_bsize, blk_row, blk_col);
2424
0
  mbmi->inter_tx_size[index] = tx_size;
2425
0
  tx_type_rd(cpi, x, tx_size, blk_row, blk_col, block, plane_bsize, &txb_ctx,
2426
0
             rd_stats, ftxs_mode, ref_best_rd);
2427
0
  assert(rd_stats->rate < INT_MAX);
2428
2429
0
  const int pick_skip_txfm =
2430
0
      !xd->lossless[mbmi->segment_id] &&
2431
0
      (rd_stats->skip_txfm == 1 ||
2432
0
       RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >=
2433
0
           RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse));
2434
0
  if (pick_skip_txfm) {
2435
#if CONFIG_RD_DEBUG
2436
    update_txb_coeff_cost(rd_stats, 0, zero_blk_rate - rd_stats->rate);
2437
#endif  // CONFIG_RD_DEBUG
2438
0
    rd_stats->rate = zero_blk_rate;
2439
0
    rd_stats->dist = rd_stats->sse;
2440
0
    p->eobs[block] = 0;
2441
0
    update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
2442
0
  }
2443
0
  rd_stats->skip_txfm = pick_skip_txfm;
2444
2445
0
  if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
2446
0
    rd_stats->rate += x->mode_costs.txfm_partition_cost[txfm_partition_ctx][0];
2447
2448
0
  no_split->rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
2449
0
  no_split->txb_entropy_ctx = p->txb_entropy_ctx[block];
2450
0
  no_split->tx_type =
2451
0
      xd->tx_type_map[blk_row * xd->tx_type_map_stride + blk_col];
2452
0
}
2453
2454
static inline void try_tx_block_split(
2455
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
2456
    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
2457
    ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
2458
    int txfm_partition_ctx, int64_t no_split_rd, int64_t ref_best_rd,
2459
0
    FAST_TX_SEARCH_MODE ftxs_mode, RD_STATS *split_rd_stats) {
2460
0
  assert(tx_size < TX_SIZES_ALL);
2461
0
  MACROBLOCKD *const xd = &x->e_mbd;
2462
0
  const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
2463
0
  const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);
2464
0
  const int txb_width = tx_size_wide_unit[tx_size];
2465
0
  const int txb_height = tx_size_high_unit[tx_size];
2466
  // Transform size after splitting current block.
2467
0
  const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
2468
0
  const int sub_txb_width = tx_size_wide_unit[sub_txs];
2469
0
  const int sub_txb_height = tx_size_high_unit[sub_txs];
2470
0
  const int sub_step = sub_txb_width * sub_txb_height;
2471
0
  const int nblks = (txb_height / sub_txb_height) * (txb_width / sub_txb_width);
2472
0
  assert(nblks > 0);
2473
0
  av1_init_rd_stats(split_rd_stats);
2474
0
  split_rd_stats->rate =
2475
0
      x->mode_costs.txfm_partition_cost[txfm_partition_ctx][1];
2476
2477
0
  for (int r = 0, blk_idx = 0; r < txb_height; r += sub_txb_height) {
2478
0
    const int offsetr = blk_row + r;
2479
0
    if (offsetr >= max_blocks_high) break;
2480
0
    for (int c = 0; c < txb_width; c += sub_txb_width, ++blk_idx) {
2481
0
      assert(blk_idx < 4);
2482
0
      (void)blk_idx;
2483
0
      const int offsetc = blk_col + c;
2484
0
      if (offsetc >= max_blocks_wide) continue;
2485
2486
0
      RD_STATS this_rd_stats;
2487
0
      int this_cost_valid = 1;
2488
0
      select_tx_block(cpi, x, offsetr, offsetc, block, sub_txs, depth + 1,
2489
0
                      plane_bsize, ta, tl, tx_above, tx_left, &this_rd_stats,
2490
0
                      no_split_rd / nblks, ref_best_rd - split_rd_stats->rdcost,
2491
0
                      &this_cost_valid, ftxs_mode, -1);
2492
0
      if (!this_cost_valid) {
2493
0
        split_rd_stats->rdcost = INT64_MAX;
2494
0
        return;
2495
0
      }
2496
0
      av1_merge_rd_stats(split_rd_stats, &this_rd_stats);
2497
0
      split_rd_stats->rdcost =
2498
0
          RDCOST(x->rdmult, split_rd_stats->rate, split_rd_stats->dist);
2499
0
      if (split_rd_stats->rdcost > ref_best_rd) {
2500
0
        split_rd_stats->rdcost = INT64_MAX;
2501
0
        return;
2502
0
      }
2503
0
      block += sub_step;
2504
0
    }
2505
0
  }
2506
0
}
2507
2508
0
// Returns E[x^2] - mean^2 for `num` samples whose squares sum to x2_sum.
// The division is carried out in double and rounded to float before the
// subtraction.
static float get_var(float mean, double x2_sum, int num) {
  const float second_moment = (float)(x2_sum / num);
  return second_moment - mean * mean;
}
2513
2514
// Computes two flatness statistics over a bw x bh block of 16-bit samples and
// its sub-blocks.
//
// The block is halved along its longer dimension (along both when square),
// which yields 2 or 4 sub-blocks. The mean and variance of every sub-block,
// plus the mean and variance of the whole block, form blk_idx + 1 samples.
//   data        - top-left sample of the block.
//   stride      - distance between consecutive rows of `data`, in samples.
//   bw, bh      - block width and height in samples.
//   dev_of_mean - receives get_dev() evaluated over the sample means.
//   var_of_vars - receives the variance of the sample variances.
// Both outputs are left untouched unless more than one sub-block was visited.
static inline void get_blk_var_dev(const int16_t *data, int stride, int bw,
                                   int bh, float *dev_of_mean,
                                   float *var_of_vars) {
  const int subh = (bh >= bw) ? (bh >> 1) : bh;
  const int subw = (bw >= bh) ? (bw >> 1) : bw;
  const int num = bw * bh;
  const int sub_num = subw * subh;
  int total_x_sum = 0;
  int64_t total_x2_sum = 0;
  int blk_idx = 0;
  float var_sum = 0.0f;
  float mean_sum = 0.0f;
  double var2_sum = 0.0f;
  double mean2_sum = 0.0f;

  // Per-sub-block statistics; the raw sums are also accumulated so the
  // whole-block statistics can be derived without a second pass.
  for (int row = 0; row < bh; row += subh) {
    for (int col = 0; col < bw; col += subw) {
      int x_sum;
      int64_t x2_sum;
      aom_get_blk_sse_sum(data + row * stride + col, stride, subw, subh,
                          &x_sum, &x2_sum);
      total_x_sum += x_sum;
      total_x2_sum += x2_sum;

      const float mean = (float)x_sum / sub_num;
      const float var = get_var(mean, (double)x2_sum, sub_num);
      mean_sum += mean;
      mean2_sum += (double)(mean * mean);
      var_sum += var;
      var2_sum += var * var;
      blk_idx++;
    }
  }

  // The whole block contributes one more sample.
  const float lvl0_mean = (float)total_x_sum / num;
  const float block_var = get_var(lvl0_mean, (double)total_x2_sum, num);
  mean_sum += lvl0_mean;
  mean2_sum += (double)(lvl0_mean * lvl0_mean);
  var_sum += block_var;
  var2_sum += block_var * block_var;

  // Average over the samples that were actually accumulated. A fixed divisor
  // of 5 is only correct for the 4-sub-block (square) case; rectangular blocks
  // produce 2 sub-blocks and therefore 3 samples.
  const int num_samples = blk_idx + 1;
  const float av_mean = mean_sum / num_samples;

  if (blk_idx > 1) {
    // Deviation of means.
    *dev_of_mean = get_dev(av_mean, mean2_sum, num_samples);
    // Variance of variances.
    const float mean_var = var_sum / num_samples;
    *var_of_vars = get_var(mean_var, var2_sum, num_samples);
  }
}
2565
2566
// Uses residual statistics of the luma transform block at (blk_row, blk_col)
// to rule out either the "split" or the "no split" transform search.
//   try_no_split, try_split - in/out flags; each is only ever cleared here.
//   pruning_level           - index into the 4-entry threshold scale tables
//                             below.
// The deviation of means is compared against the DC quantizer and the
// variance of variances against the squared AC quantizer, both taken from
// dequant_QTX[] shifted right by 3.
static void prune_tx_split_no_split(MACROBLOCK *x, BLOCK_SIZE bsize,
                                    int blk_row, int blk_col, TX_SIZE tx_size,
                                    int *try_no_split, int *try_split,
                                    int pruning_level) {
  static const int no_split_scale_by_level[4] = { 0, 24, 8, 8 };
  static const int split_scale_by_level[4] = { 0, 24, 10, 8 };

  const struct macroblock_plane *const luma = &x->plane[0];
  const int residual_stride = block_size_wide[bsize];
  const int16_t *const residual =
      luma->src_diff + 4 * blk_row * residual_stride + 4 * blk_col;

  // Deviation of means and variance of pixel variances of the block and its
  // sub-blocks.
  float dev_of_means = 0.0f;
  float var_of_vars = 0.0f;
  get_blk_var_dev(residual, residual_stride, tx_size_wide[tx_size],
                  tx_size_high[tx_size], &dev_of_means, &var_of_vars);

  const int dc_q = luma->dequant_QTX[0] >> 3;
  const int ac_q = luma->dequant_QTX[1] >> 3;
  const int no_split_scale = no_split_scale_by_level[pruning_level];
  const int split_scale = split_scale_by_level[pruning_level];

  // Both statistics are small: do not bother trying a split.
  const int means_are_close = dev_of_means <= dc_q;
  const int vars_are_close = split_scale * var_of_vars <= ac_q * ac_q;
  if (means_are_close && vars_are_close) {
    *try_split = 0;
  }

  // Both statistics are large: do not bother trying the unsplit transform.
  const int means_are_spread = dev_of_means > no_split_scale * dc_q;
  const int vars_are_spread = var_of_vars > no_split_scale * ac_q * ac_q;
  if (means_are_spread && vars_are_spread) {
    *try_no_split = 0;
  }
}
2597
2598
// Search for the best transform partition(recursive)/type for a given
2599
// inter-predicted luma block. The obtained transform selection will be saved
2600
// in xd->mi[0], the corresponding RD stats will be saved in rd_stats.
2601
static inline void select_tx_block(
2602
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
2603
    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
2604
    ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
2605
    RD_STATS *rd_stats, int64_t prev_level_rd, int64_t ref_best_rd,
2606
0
    int *is_cost_valid, FAST_TX_SEARCH_MODE ftxs_mode, int blk_idx) {
2607
0
  assert(tx_size < TX_SIZES_ALL);
2608
0
  av1_init_rd_stats(rd_stats);
2609
0
  if (ref_best_rd < 0) {
2610
0
    *is_cost_valid = 0;
2611
0
    return;
2612
0
  }
2613
2614
0
  MACROBLOCKD *const xd = &x->e_mbd;
2615
0
  assert(blk_row < max_block_high(xd, plane_bsize, 0) &&
2616
0
         blk_col < max_block_wide(xd, plane_bsize, 0));
2617
0
  MB_MODE_INFO *const mbmi = xd->mi[0];
2618
0
  const int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
2619
0
                                         mbmi->bsize, tx_size);
2620
0
  struct macroblock_plane *const p = &x->plane[0];
2621
2622
0
  int try_no_split = (cpi->oxcf.txfm_cfg.enable_tx64 ||
2623
0
                      txsize_sqr_up_map[tx_size] != TX_64X64) &&
2624
0
                     (cpi->oxcf.txfm_cfg.enable_rect_tx ||
2625
0
                      tx_size_wide[tx_size] == tx_size_high[tx_size]);
2626
0
  int try_split = tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH;
2627
0
  TxCandidateInfo no_split = { INT64_MAX, 0, TX_TYPES };
2628
2629
  // Prune tx_split and no-split based on sub-block properties.
2630
0
  if (tx_size != TX_4X4 && try_split == 1 && try_no_split == 1 &&
2631
0
      cpi->sf.tx_sf.prune_tx_size_level > 0) {
2632
0
    prune_tx_split_no_split(x, plane_bsize, blk_row, blk_col, tx_size,
2633
0
                            &try_no_split, &try_split,
2634
0
                            cpi->sf.tx_sf.prune_tx_size_level);
2635
0
  }
2636
2637
0
  if (cpi->sf.rt_sf.skip_tx_no_split_var_based_partition) {
2638
0
    if (x->try_merge_partition && try_split && p->eobs[block]) try_no_split = 0;
2639
0
  }
2640
2641
  // Try using current block as a single transform block without split.
2642
0
  if (try_no_split) {
2643
0
    try_tx_block_no_split(cpi, x, blk_row, blk_col, block, tx_size, depth,
2644
0
                          plane_bsize, ta, tl, ctx, rd_stats, ref_best_rd,
2645
0
                          ftxs_mode, &no_split);
2646
2647
0
    push_inter_block_tx_no_split_rd(
2648
0
        x, mbmi, no_split.rd, blk_idx,
2649
0
        cpi->sf.tx_sf.prune_inter_tx_split_rd_eval_lvl);
2650
2651
    // Speed features for early termination.
2652
0
    const int search_level = cpi->sf.tx_sf.adaptive_txb_search_level;
2653
0
    if (search_level) {
2654
0
      if ((no_split.rd - (no_split.rd >> (1 + search_level))) > ref_best_rd) {
2655
0
        *is_cost_valid = 0;
2656
0
        return;
2657
0
      }
2658
0
      if (no_split.rd - (no_split.rd >> (2 + search_level)) > prev_level_rd) {
2659
0
        try_split = 0;
2660
0
      }
2661
0
    }
2662
0
    if (cpi->sf.tx_sf.txb_split_cap) {
2663
0
      if (p->eobs[block] == 0) try_split = 0;
2664
0
    }
2665
0
    if (prune_tx_split_eval_using_no_split_rd(
2666
0
            x, mbmi, no_split.rd, blk_idx,
2667
0
            cpi->sf.tx_sf.prune_inter_tx_split_rd_eval_lvl)) {
2668
0
      try_split = 0;
2669
0
    }
2670
0
  }
2671
2672
  // ML based speed feature to skip searching for split transform blocks.
2673
0
  if (x->e_mbd.bd == 8 && try_split &&
2674
0
      !(ref_best_rd == INT64_MAX && no_split.rd == INT64_MAX)) {
2675
0
    const int threshold = cpi->sf.tx_sf.tx_type_search.ml_tx_split_thresh;
2676
0
    if (threshold >= 0) {
2677
0
      const int split_score =
2678
0
          ml_predict_tx_split(x, plane_bsize, blk_row, blk_col, tx_size);
2679
0
      if (split_score < -threshold) try_split = 0;
2680
0
    }
2681
0
  }
2682
2683
0
  RD_STATS split_rd_stats;
2684
0
  split_rd_stats.rdcost = INT64_MAX;
2685
  // Try splitting current block into smaller transform blocks.
2686
0
  if (try_split) {
2687
0
    try_tx_block_split(cpi, x, blk_row, blk_col, block, tx_size, depth,
2688
0
                       plane_bsize, ta, tl, tx_above, tx_left, ctx, no_split.rd,
2689
0
                       AOMMIN(no_split.rd, ref_best_rd), ftxs_mode,
2690
0
                       &split_rd_stats);
2691
0
  }
2692
2693
0
  if (no_split.rd < split_rd_stats.rdcost) {
2694
0
    ENTROPY_CONTEXT *pta = ta + blk_col;
2695
0
    ENTROPY_CONTEXT *ptl = tl + blk_row;
2696
0
    p->txb_entropy_ctx[block] = no_split.txb_entropy_ctx;
2697
0
    av1_set_txb_context(x, 0, block, tx_size, pta, ptl);
2698
0
    txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size,
2699
0
                          tx_size);
2700
0
    for (int idy = 0; idy < tx_size_high_unit[tx_size]; ++idy) {
2701
0
      for (int idx = 0; idx < tx_size_wide_unit[tx_size]; ++idx) {
2702
0
        const int index =
2703
0
            av1_get_txb_size_index(plane_bsize, blk_row + idy, blk_col + idx);
2704
0
        mbmi->inter_tx_size[index] = tx_size;
2705
0
      }
2706
0
    }
2707
0
    mbmi->tx_size = tx_size;
2708
0
    update_txk_array(xd, blk_row, blk_col, tx_size, no_split.tx_type);
2709
0
  } else {
2710
0
    *rd_stats = split_rd_stats;
2711
0
    if (split_rd_stats.rdcost == INT64_MAX) *is_cost_valid = 0;
2712
0
  }
2713
0
}
2714
2715
static inline void choose_largest_tx_size(const AV1_COMP *const cpi,
2716
                                          MACROBLOCK *x, RD_STATS *rd_stats,
2717
0
                                          int64_t ref_best_rd, BLOCK_SIZE bs) {
2718
0
  MACROBLOCKD *const xd = &x->e_mbd;
2719
0
  MB_MODE_INFO *const mbmi = xd->mi[0];
2720
0
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
2721
0
  mbmi->tx_size = tx_size_from_tx_mode(bs, txfm_params->tx_mode_search_type);
2722
2723
  // If tx64 is not enabled, we need to go down to the next available size
2724
0
  if (!cpi->oxcf.txfm_cfg.enable_tx64 && cpi->oxcf.txfm_cfg.enable_rect_tx) {
2725
0
    static const TX_SIZE tx_size_max_32[TX_SIZES_ALL] = {
2726
0
      TX_4X4,    // 4x4 transform
2727
0
      TX_8X8,    // 8x8 transform
2728
0
      TX_16X16,  // 16x16 transform
2729
0
      TX_32X32,  // 32x32 transform
2730
0
      TX_32X32,  // 64x64 transform
2731
0
      TX_4X8,    // 4x8 transform
2732
0
      TX_8X4,    // 8x4 transform
2733
0
      TX_8X16,   // 8x16 transform
2734
0
      TX_16X8,   // 16x8 transform
2735
0
      TX_16X32,  // 16x32 transform
2736
0
      TX_32X16,  // 32x16 transform
2737
0
      TX_32X32,  // 32x64 transform
2738
0
      TX_32X32,  // 64x32 transform
2739
0
      TX_4X16,   // 4x16 transform
2740
0
      TX_16X4,   // 16x4 transform
2741
0
      TX_8X32,   // 8x32 transform
2742
0
      TX_32X8,   // 32x8 transform
2743
0
      TX_16X32,  // 16x64 transform
2744
0
      TX_32X16,  // 64x16 transform
2745
0
    };
2746
0
    mbmi->tx_size = tx_size_max_32[mbmi->tx_size];
2747
0
  } else if (cpi->oxcf.txfm_cfg.enable_tx64 &&
2748
0
             !cpi->oxcf.txfm_cfg.enable_rect_tx) {
2749
0
    static const TX_SIZE tx_size_max_square[TX_SIZES_ALL] = {
2750
0
      TX_4X4,    // 4x4 transform
2751
0
      TX_8X8,    // 8x8 transform
2752
0
      TX_16X16,  // 16x16 transform
2753
0
      TX_32X32,  // 32x32 transform
2754
0
      TX_64X64,  // 64x64 transform
2755
0
      TX_4X4,    // 4x8 transform
2756
0
      TX_4X4,    // 8x4 transform
2757
0
      TX_8X8,    // 8x16 transform
2758
0
      TX_8X8,    // 16x8 transform
2759
0
      TX_16X16,  // 16x32 transform
2760
0
      TX_16X16,  // 32x16 transform
2761
0
      TX_32X32,  // 32x64 transform
2762
0
      TX_32X32,  // 64x32 transform
2763
0
      TX_4X4,    // 4x16 transform
2764
0
      TX_4X4,    // 16x4 transform
2765
0
      TX_8X8,    // 8x32 transform
2766
0
      TX_8X8,    // 32x8 transform
2767
0
      TX_16X16,  // 16x64 transform
2768
0
      TX_16X16,  // 64x16 transform
2769
0
    };
2770
0
    mbmi->tx_size = tx_size_max_square[mbmi->tx_size];
2771
0
  } else if (!cpi->oxcf.txfm_cfg.enable_tx64 &&
2772
0
             !cpi->oxcf.txfm_cfg.enable_rect_tx) {
2773
0
    static const TX_SIZE tx_size_max_32_square[TX_SIZES_ALL] = {
2774
0
      TX_4X4,    // 4x4 transform
2775
0
      TX_8X8,    // 8x8 transform
2776
0
      TX_16X16,  // 16x16 transform
2777
0
      TX_32X32,  // 32x32 transform
2778
0
      TX_32X32,  // 64x64 transform
2779
0
      TX_4X4,    // 4x8 transform
2780
0
      TX_4X4,    // 8x4 transform
2781
0
      TX_8X8,    // 8x16 transform
2782
0
      TX_8X8,    // 16x8 transform
2783
0
      TX_16X16,  // 16x32 transform
2784
0
      TX_16X16,  // 32x16 transform
2785
0
      TX_32X32,  // 32x64 transform
2786
0
      TX_32X32,  // 64x32 transform
2787
0
      TX_4X4,    // 4x16 transform
2788
0
      TX_4X4,    // 16x4 transform
2789
0
      TX_8X8,    // 8x32 transform
2790
0
      TX_8X8,    // 32x8 transform
2791
0
      TX_16X16,  // 16x64 transform
2792
0
      TX_16X16,  // 64x16 transform
2793
0
    };
2794
2795
0
    mbmi->tx_size = tx_size_max_32_square[mbmi->tx_size];
2796
0
  }
2797
2798
0
  const int skip_ctx = av1_get_skip_txfm_context(xd);
2799
0
  const int no_skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][0];
2800
0
  const int skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][1];
2801
  // Skip RDcost is used only for Inter blocks
2802
0
  const int64_t skip_txfm_rd =
2803
0
      is_inter_block(mbmi) ? RDCOST(x->rdmult, skip_txfm_rate, 0) : INT64_MAX;
2804
0
  const int64_t no_skip_txfm_rd = RDCOST(x->rdmult, no_skip_txfm_rate, 0);
2805
0
  av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd,
2806
0
                       AOMMIN(no_skip_txfm_rd, skip_txfm_rd), AOM_PLANE_Y, bs,
2807
0
                       mbmi->tx_size, FTXS_NONE);
2808
0
}
2809
2810
static inline void choose_smallest_tx_size(const AV1_COMP *const cpi,
2811
                                           MACROBLOCK *x, RD_STATS *rd_stats,
2812
0
                                           int64_t ref_best_rd, BLOCK_SIZE bs) {
2813
0
  MACROBLOCKD *const xd = &x->e_mbd;
2814
0
  MB_MODE_INFO *const mbmi = xd->mi[0];
2815
2816
0
  mbmi->tx_size = TX_4X4;
2817
  // TODO(any) : Pass this_rd based on skip/non-skip cost
2818
0
  av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd, 0, 0, bs, mbmi->tx_size,
2819
0
                       FTXS_NONE);
2820
0
}
2821
2822
#if !CONFIG_REALTIME_ONLY
2823
static void ml_predict_intra_tx_depth_prune(MACROBLOCK *x, int blk_row,
2824
                                            int blk_col, BLOCK_SIZE bsize,
2825
0
                                            TX_SIZE tx_size) {
2826
0
  const MACROBLOCKD *const xd = &x->e_mbd;
2827
0
  const MB_MODE_INFO *const mbmi = xd->mi[0];
2828
2829
  // Disable the pruning logic using NN model for the following cases:
2830
  // 1) Lossless coding as only 4x4 transform is evaluated in this case
2831
  // 2) When transform and current block sizes do not match as the features are
2832
  // obtained over the current block
2833
  // 3) When operating bit-depth is not 8-bit as the input features are not
2834
  // scaled according to bit-depth.
2835
0
  if (xd->lossless[mbmi->segment_id] || txsize_to_bsize[tx_size] != bsize ||
2836
0
      xd->bd != 8)
2837
0
    return;
2838
2839
  // Currently NN model based pruning is supported only when largest transform
2840
  // size is 8x8
2841
0
  if (tx_size != TX_8X8) return;
2842
2843
  // Neural network model is a sequential neural net and was trained using SGD
2844
  // optimizer. The model can be further improved in terms of speed/quality by
2845
  // considering the following experiments:
2846
  // 1) Generate ML model by training with balanced data for different learning
2847
  // rates and optimizers.
2848
  // 2) Experiment with ML model by adding features related to the statistics of
2849
  // top and left pixels to capture the accuracy of reconstructed neighbouring
2850
  // pixels for 4x4 blocks numbered 1, 2, 3 in 8x8 block, source variance of 4x4
2851
  // sub-blocks, etc.
2852
  // 3) Generate ML models for transform blocks other than 8x8.
2853
0
  const NN_CONFIG *const nn_config = &av1_intra_tx_split_nnconfig_8x8;
2854
0
  const float *const intra_tx_prune_thresh = av1_intra_tx_prune_nn_thresh_8x8;
2855
2856
0
  float features[NUM_INTRA_TX_SPLIT_FEATURES] = { 0.0f };
2857
0
  const int diff_stride = block_size_wide[bsize];
2858
2859
0
  const int16_t *diff = x->plane[0].src_diff + MI_SIZE * blk_row * diff_stride +
2860
0
                        MI_SIZE * blk_col;
2861
0
  const int bw = tx_size_wide[tx_size];
2862
0
  const int bh = tx_size_high[tx_size];
2863
2864
0
  int feature_idx = get_mean_dev_features(diff, diff_stride, bw, bh, features);
2865
2866
0
  features[feature_idx++] = log1pf((float)x->source_variance);
2867
2868
0
  const int dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd) >> (xd->bd - 8);
2869
0
  const float log_dc_q_square = log1pf((float)(dc_q * dc_q) / 256.0f);
2870
0
  features[feature_idx++] = log_dc_q_square;
2871
0
  assert(feature_idx == NUM_INTRA_TX_SPLIT_FEATURES);
2872
0
  for (int i = 0; i < NUM_INTRA_TX_SPLIT_FEATURES; i++) {
2873
0
    features[i] = (features[i] - av1_intra_tx_split_8x8_mean[i]) /
2874
0
                  av1_intra_tx_split_8x8_std[i];
2875
0
  }
2876
2877
0
  float score;
2878
0
  av1_nn_predict(features, nn_config, 1, &score);
2879
2880
0
  TxfmSearchParams *const txfm_params = &x->txfm_search_params;
2881
0
  if (score <= intra_tx_prune_thresh[0])
2882
0
    txfm_params->nn_prune_depths_for_intra_tx = TX_PRUNE_SPLIT;
2883
0
  else if (score > intra_tx_prune_thresh[1])
2884
0
    txfm_params->nn_prune_depths_for_intra_tx = TX_PRUNE_LARGEST;
2885
0
}
2886
#endif  // !CONFIG_REALTIME_ONLY
2887
2888
/*!\brief Transform type search for luma macroblock with fixed transform size.
2889
 *
2890
 * \ingroup transform_search
2891
 * Search for the best transform type and return the transform coefficients RD
2892
 * cost of current luma macroblock with the given uniform transform size.
2893
 *
2894
 * \param[in]    x              Pointer to structure holding the data for the
2895
                                current encoding macroblock
2896
 * \param[in]    cpi            Top-level encoder structure
2897
 * \param[in]    rd_stats       Pointer to struct to keep track of the RD stats
2898
 * \param[in]    ref_best_rd    Best RD cost seen for this block so far
2899
 * \param[in]    bs             Size of the current macroblock
2900
 * \param[in]    tx_size        The given transform size
2901
 * \param[in]    ftxs_mode      Transform search mode specifying desired speed
2902
                                and quality tradeoff
2903
 * \return       An int64_t value that is the best RD cost found.
2904
 */
2905
static int64_t uniform_txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
2906
                                RD_STATS *rd_stats, int64_t ref_best_rd,
2907
                                BLOCK_SIZE bs, TX_SIZE tx_size,
2908
0
                                FAST_TX_SEARCH_MODE ftxs_mode) {
2909
0
  assert(IMPLIES(is_rect_tx(tx_size), is_rect_tx_allowed_bsize(bs)));
2910
0
  MACROBLOCKD *const xd = &x->e_mbd;
2911
0
  MB_MODE_INFO *const mbmi = xd->mi[0];
2912
0
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
2913
0
  const ModeCosts *mode_costs = &x->mode_costs;
2914
0
  const int is_inter = is_inter_block(mbmi);
2915
0
  const int tx_select = txfm_params->tx_mode_search_type == TX_MODE_SELECT &&
2916
0
                        block_signals_txsize(mbmi->bsize);
2917
0
  int tx_size_rate = 0;
2918
0
  if (tx_select) {
2919
0
    const int ctx = txfm_partition_context(
2920
0
        xd->above_txfm_context, xd->left_txfm_context, mbmi->bsize, tx_size);
2921
0
    tx_size_rate = is_inter ? mode_costs->txfm_partition_cost[ctx][0]
2922
0
                            : tx_size_cost(x, bs, tx_size);
2923
0
  }
2924
0
  const int skip_ctx = av1_get_skip_txfm_context(xd);
2925
0
  const int no_skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][0];
2926
0
  const int skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][1];
2927
0
  const int64_t skip_txfm_rd =
2928
0
      is_inter ? RDCOST(x->rdmult, skip_txfm_rate, 0) : INT64_MAX;
2929
0
  const int64_t no_this_rd =
2930
0
      RDCOST(x->rdmult, no_skip_txfm_rate + tx_size_rate, 0);
2931
2932
0
  mbmi->tx_size = tx_size;
2933
0
  av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd,
2934
0
                       AOMMIN(no_this_rd, skip_txfm_rd), AOM_PLANE_Y, bs,
2935
0
                       tx_size, ftxs_mode);
2936
0
  if (rd_stats->rate == INT_MAX) return INT64_MAX;
2937
2938
0
  int64_t rd;
2939
  // rdstats->rate should include all the rate except skip/non-skip cost as the
2940
  // same is accounted in the caller functions after rd evaluation of all
2941
  // planes. However the decisions should be done after considering the
2942
  // skip/non-skip header cost
2943
0
  if (rd_stats->skip_txfm && is_inter) {
2944
0
    rd = RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
2945
0
  } else {
2946
    // Intra blocks are always signalled as non-skip
2947
0
    rd = RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_rate + tx_size_rate,
2948
0
                rd_stats->dist);
2949
0
    rd_stats->rate += tx_size_rate;
2950
0
  }
2951
  // Check if forcing the block to skip transform leads to smaller RD cost.
2952
0
  if (is_inter && !rd_stats->skip_txfm && !xd->lossless[mbmi->segment_id]) {
2953
0
    int64_t temp_skip_txfm_rd =
2954
0
        RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
2955
0
    if (temp_skip_txfm_rd <= rd) {
2956
0
      rd = temp_skip_txfm_rd;
2957
0
      rd_stats->rate = 0;
2958
0
      rd_stats->dist = rd_stats->sse;
2959
0
      rd_stats->skip_txfm = 1;
2960
0
    }
2961
0
  }
2962
2963
0
  return rd;
2964
0
}
2965
2966
// Search for the best uniform transform size and type for current coding block.
2967
static inline void choose_tx_size_type_from_rd(const AV1_COMP *const cpi,
2968
                                               MACROBLOCK *x,
2969
                                               RD_STATS *rd_stats,
2970
                                               int64_t ref_best_rd,
2971
0
                                               BLOCK_SIZE bs) {
2972
0
  av1_invalid_rd_stats(rd_stats);
2973
2974
0
  MACROBLOCKD *const xd = &x->e_mbd;
2975
0
  MB_MODE_INFO *const mbmi = xd->mi[0];
2976
0
  TxfmSearchParams *const txfm_params = &x->txfm_search_params;
2977
0
  const TX_SIZE max_rect_tx_size = max_txsize_rect_lookup[bs];
2978
0
  const int tx_select = txfm_params->tx_mode_search_type == TX_MODE_SELECT;
2979
0
  int start_tx;
2980
  // The split depth can be at most MAX_TX_DEPTH, so the init_depth controls
2981
  // how many times of splitting is allowed during the RD search.
2982
0
  int init_depth;
2983
2984
0
  if (tx_select) {
2985
0
    start_tx = max_rect_tx_size;
2986
0
    init_depth = get_search_init_depth(mi_size_wide[bs], mi_size_high[bs],
2987
0
                                       is_inter_block(mbmi), &cpi->sf,
2988
0
                                       txfm_params->tx_size_search_method);
2989
0
    if (init_depth == MAX_TX_DEPTH && !cpi->oxcf.txfm_cfg.enable_tx64 &&
2990
0
        txsize_sqr_up_map[start_tx] == TX_64X64) {
2991
0
      start_tx = sub_tx_size_map[start_tx];
2992
0
    }
2993
0
  } else {
2994
0
    const TX_SIZE chosen_tx_size =
2995
0
        tx_size_from_tx_mode(bs, txfm_params->tx_mode_search_type);
2996
0
    start_tx = chosen_tx_size;
2997
0
    init_depth = MAX_TX_DEPTH;
2998
0
  }
2999
3000
0
  uint8_t best_txk_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
3001
0
  TX_SIZE best_tx_size = max_rect_tx_size;
3002
0
  int64_t best_rd = INT64_MAX;
3003
0
  const int num_blks = bsize_to_num_blk(bs);
3004
0
  x->rd_model = FULL_TXFM_RD;
3005
0
  int64_t rd[MAX_TX_DEPTH + 1] = { INT64_MAX, INT64_MAX, INT64_MAX };
3006
0
  for (int tx_size = start_tx, depth = init_depth; depth <= MAX_TX_DEPTH;
3007
0
       depth++, tx_size = sub_tx_size_map[tx_size]) {
3008
0
    if ((!cpi->oxcf.txfm_cfg.enable_tx64 &&
3009
0
         txsize_sqr_up_map[tx_size] == TX_64X64) ||
3010
0
        (!cpi->oxcf.txfm_cfg.enable_rect_tx &&
3011
0
         tx_size_wide[tx_size] != tx_size_high[tx_size])) {
3012
0
      continue;
3013
0
    }
3014
3015
0
#if !CONFIG_REALTIME_ONLY
3016
0
    if (txfm_params->nn_prune_depths_for_intra_tx == TX_PRUNE_SPLIT) break;
3017
3018
    // Set the flag to enable the evaluation of NN classifier to prune transform
3019
    // depths. As the features are based on intra residual information of
3020
    // largest transform, the evaluation of NN model is enabled only for this
3021
    // case.
3022
0
    txfm_params->enable_nn_prune_intra_tx_depths =
3023
0
        (cpi->sf.tx_sf.prune_intra_tx_depths_using_nn && tx_size == start_tx);
3024
0
#endif
3025
3026
0
    RD_STATS this_rd_stats;
3027
    // When the speed feature use_rd_based_breakout_for_intra_tx_search is
3028
    // enabled, use the known minimum best_rd for early termination.
3029
0
    const int64_t rd_thresh =
3030
0
        cpi->sf.tx_sf.use_rd_based_breakout_for_intra_tx_search
3031
0
            ? AOMMIN(ref_best_rd, best_rd)
3032
0
            : ref_best_rd;
3033
0
    rd[depth] = uniform_txfm_yrd(cpi, x, &this_rd_stats, rd_thresh, bs, tx_size,
3034
0
                                 FTXS_NONE);
3035
0
    if (rd[depth] < best_rd) {
3036
0
      av1_copy_array(best_txk_type_map, xd->tx_type_map, num_blks);
3037
0
      best_tx_size = tx_size;
3038
0
      best_rd = rd[depth];
3039
0
      *rd_stats = this_rd_stats;
3040
0
    }
3041
0
    if (tx_size == TX_4X4) break;
3042
    // If we are searching three depths, prune the smallest size depending
3043
    // on rd results for the first two depths for low contrast blocks.
3044
0
    if (depth > init_depth && depth != MAX_TX_DEPTH &&
3045
0
        x->source_variance < 256) {
3046
0
      if (rd[depth - 1] != INT64_MAX && rd[depth] > rd[depth - 1]) break;
3047
0
    }
3048
0
  }
3049
3050
0
  if (rd_stats->rate != INT_MAX) {
3051
0
    mbmi->tx_size = best_tx_size;
3052
0
    av1_copy_array(xd->tx_type_map, best_txk_type_map, num_blks);
3053
0
  }
3054
3055
0
#if !CONFIG_REALTIME_ONLY
3056
  // Reset the flags to avoid any unintentional evaluation of NN model and
3057
  // consumption of prune depths.
3058
0
  txfm_params->enable_nn_prune_intra_tx_depths = false;
3059
0
  txfm_params->nn_prune_depths_for_intra_tx = TX_PRUNE_NONE;
3060
0
#endif
3061
0
}
3062
3063
// Search for the best transform type for the given transform block in the
3064
// given plane/channel, and calculate the corresponding RD cost.
3065
static inline void block_rd_txfm(int plane, int block, int blk_row, int blk_col,
3066
                                 BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
3067
0
                                 void *arg) {
3068
0
  struct rdcost_block_args *args = arg;
3069
0
  if (args->exit_early) {
3070
0
    args->incomplete_exit = 1;
3071
0
    return;
3072
0
  }
3073
3074
0
  MACROBLOCK *const x = args->x;
3075
0
  MACROBLOCKD *const xd = &x->e_mbd;
3076
0
  const int is_inter = is_inter_block(xd->mi[0]);
3077
0
  const AV1_COMP *cpi = args->cpi;
3078
0
  ENTROPY_CONTEXT *a = args->t_above + blk_col;
3079
0
  ENTROPY_CONTEXT *l = args->t_left + blk_row;
3080
0
  const AV1_COMMON *cm = &cpi->common;
3081
0
  RD_STATS this_rd_stats;
3082
0
  av1_init_rd_stats(&this_rd_stats);
3083
3084
0
  if (!is_inter) {
3085
0
    av1_predict_intra_block_facade(cm, xd, plane, blk_col, blk_row, tx_size);
3086
0
    av1_subtract_txb(x, plane, plane_bsize, blk_col, blk_row, tx_size);
3087
0
#if !CONFIG_REALTIME_ONLY
3088
0
    const TxfmSearchParams *const txfm_params = &x->txfm_search_params;
3089
0
    if (txfm_params->enable_nn_prune_intra_tx_depths) {
3090
0
      ml_predict_intra_tx_depth_prune(x, blk_row, blk_col, plane_bsize,
3091
0
                                      tx_size);
3092
0
      if (txfm_params->nn_prune_depths_for_intra_tx == TX_PRUNE_LARGEST) {
3093
0
        av1_invalid_rd_stats(&args->rd_stats);
3094
0
        args->exit_early = 1;
3095
0
        return;
3096
0
      }
3097
0
    }
3098
0
#endif
3099
0
  }
3100
3101
0
  TXB_CTX txb_ctx;
3102
0
  get_txb_ctx(plane_bsize, tx_size, plane, a, l, &txb_ctx);
3103
0
  search_tx_type(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
3104
0
                 &txb_ctx, args->ftxs_mode, args->best_rd - args->current_rd,
3105
0
                 &this_rd_stats);
3106
3107
0
#if !CONFIG_REALTIME_ONLY
3108
0
  if (plane == AOM_PLANE_Y && xd->cfl.store_y) {
3109
0
    assert(!is_inter || plane_bsize < BLOCK_8X8);
3110
0
    cfl_store_tx(xd, blk_row, blk_col, tx_size, plane_bsize);
3111
0
  }
3112
0
#endif
3113
3114
#if CONFIG_RD_DEBUG
3115
  update_txb_coeff_cost(&this_rd_stats, plane, this_rd_stats.rate);
3116
#endif  // CONFIG_RD_DEBUG
3117
0
  av1_set_txb_context(x, plane, block, tx_size, a, l);
3118
3119
0
  int64_t rd;
3120
0
  if (is_inter) {
3121
0
    const int64_t no_skip_txfm_rd =
3122
0
        RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
3123
0
    const int64_t skip_txfm_rd = RDCOST(x->rdmult, 0, this_rd_stats.sse);
3124
0
    rd = AOMMIN(no_skip_txfm_rd, skip_txfm_rd);
3125
0
    this_rd_stats.skip_txfm &= !x->plane[plane].eobs[block];
3126
0
  } else {
3127
    // Signal non-skip_txfm for Intra blocks
3128
0
    rd = RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
3129
0
    this_rd_stats.skip_txfm = 0;
3130
0
  }
3131
3132
0
  av1_merge_rd_stats(&args->rd_stats, &this_rd_stats);
3133
3134
0
  args->current_rd += rd;
3135
0
  if (args->current_rd > args->best_rd) args->exit_early = 1;
3136
0
}
3137
3138
int64_t av1_estimate_txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
3139
                              RD_STATS *rd_stats, int64_t ref_best_rd,
3140
0
                              BLOCK_SIZE bs, TX_SIZE tx_size) {
3141
0
  MACROBLOCKD *const xd = &x->e_mbd;
3142
0
  MB_MODE_INFO *const mbmi = xd->mi[0];
3143
0
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
3144
0
  const ModeCosts *mode_costs = &x->mode_costs;
3145
0
  const int is_inter = is_inter_block(mbmi);
3146
0
  const int tx_select = txfm_params->tx_mode_search_type == TX_MODE_SELECT &&
3147
0
                        block_signals_txsize(mbmi->bsize);
3148
0
  int tx_size_rate = 0;
3149
0
  if (tx_select) {
3150
0
    const int ctx = txfm_partition_context(
3151
0
        xd->above_txfm_context, xd->left_txfm_context, mbmi->bsize, tx_size);
3152
0
    tx_size_rate = mode_costs->txfm_partition_cost[ctx][0];
3153
0
  }
3154
0
  const int skip_ctx = av1_get_skip_txfm_context(xd);
3155
0
  const int no_skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][0];
3156
0
  const int skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][1];
3157
0
  const int64_t skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_rate, 0);
3158
0
  const int64_t no_this_rd =
3159
0
      RDCOST(x->rdmult, no_skip_txfm_rate + tx_size_rate, 0);
3160
0
  mbmi->tx_size = tx_size;
3161
3162
0
  const uint8_t txw_unit = tx_size_wide_unit[tx_size];
3163
0
  const uint8_t txh_unit = tx_size_high_unit[tx_size];
3164
0
  const int step = txw_unit * txh_unit;
3165
0
  const int max_blocks_wide = max_block_wide(xd, bs, 0);
3166
0
  const int max_blocks_high = max_block_high(xd, bs, 0);
3167
3168
0
  struct rdcost_block_args args;
3169
0
  av1_zero(args);
3170
0
  args.x = x;
3171
0
  args.cpi = cpi;
3172
0
  args.best_rd = ref_best_rd;
3173
0
  args.current_rd = AOMMIN(no_this_rd, skip_txfm_rd);
3174
0
  av1_init_rd_stats(&args.rd_stats);
3175
0
  av1_get_entropy_contexts(bs, &xd->plane[0], args.t_above, args.t_left);
3176
0
  int i = 0;
3177
0
  for (int blk_row = 0; blk_row < max_blocks_high && !args.incomplete_exit;
3178
0
       blk_row += txh_unit) {
3179
0
    for (int blk_col = 0; blk_col < max_blocks_wide; blk_col += txw_unit) {
3180
0
      RD_STATS this_rd_stats;
3181
0
      av1_init_rd_stats(&this_rd_stats);
3182
3183
0
      if (args.exit_early) {
3184
0
        args.incomplete_exit = 1;
3185
0
        break;
3186
0
      }
3187
3188
0
      ENTROPY_CONTEXT *a = args.t_above + blk_col;
3189
0
      ENTROPY_CONTEXT *l = args.t_left + blk_row;
3190
0
      TXB_CTX txb_ctx;
3191
0
      get_txb_ctx(bs, tx_size, 0, a, l, &txb_ctx);
3192
3193
0
      TxfmParam txfm_param;
3194
0
      QUANT_PARAM quant_param;
3195
0
      av1_setup_xform(&cpi->common, x, tx_size, DCT_DCT, &txfm_param);
3196
0
      av1_setup_quant(tx_size, 0, AV1_XFORM_QUANT_B, 0, &quant_param);
3197
3198
0
      av1_xform(x, 0, i, blk_row, blk_col, bs, &txfm_param);
3199
0
      av1_quant(x, 0, i, &txfm_param, &quant_param);
3200
3201
0
      this_rd_stats.rate =
3202
0
          cost_coeffs(x, 0, i, tx_size, txfm_param.tx_type, &txb_ctx, 0);
3203
3204
0
      const SCAN_ORDER *const scan_order =
3205
0
          get_scan(txfm_param.tx_size, txfm_param.tx_type);
3206
0
      dist_block_tx_domain(x, 0, i, tx_size, quant_param.qmatrix,
3207
0
                           scan_order->scan, &this_rd_stats.dist,
3208
0
                           &this_rd_stats.sse);
3209
3210
0
      const int64_t no_skip_txfm_rd =
3211
0
          RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
3212
0
      const int64_t skip_rd = RDCOST(x->rdmult, 0, this_rd_stats.sse);
3213
3214
0
      this_rd_stats.skip_txfm &= !x->plane[0].eobs[i];
3215
3216
0
      av1_merge_rd_stats(&args.rd_stats, &this_rd_stats);
3217
0
      args.current_rd += AOMMIN(no_skip_txfm_rd, skip_rd);
3218
3219
0
      if (args.current_rd > ref_best_rd) {
3220
0
        args.exit_early = 1;
3221
0
        break;
3222
0
      }
3223
3224
0
      av1_set_txb_context(x, 0, i, tx_size, a, l);
3225
0
      i += step;
3226
0
    }
3227
0
  }
3228
3229
0
  if (args.incomplete_exit) av1_invalid_rd_stats(&args.rd_stats);
3230
3231
0
  *rd_stats = args.rd_stats;
3232
0
  if (rd_stats->rate == INT_MAX) return INT64_MAX;
3233
3234
0
  int64_t rd;
3235
  // rdstats->rate should include all the rate except skip/non-skip cost as the
3236
  // same is accounted in the caller functions after rd evaluation of all
3237
  // planes. However the decisions should be done after considering the
3238
  // skip/non-skip header cost
3239
0
  if (rd_stats->skip_txfm && is_inter) {
3240
0
    rd = RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
3241
0
  } else {
3242
    // Intra blocks are always signalled as non-skip
3243
0
    rd = RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_rate + tx_size_rate,
3244
0
                rd_stats->dist);
3245
0
    rd_stats->rate += tx_size_rate;
3246
0
  }
3247
  // Check if forcing the block to skip transform leads to smaller RD cost.
3248
0
  if (is_inter && !rd_stats->skip_txfm && !xd->lossless[mbmi->segment_id]) {
3249
0
    int64_t temp_skip_txfm_rd =
3250
0
        RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
3251
0
    if (temp_skip_txfm_rd <= rd) {
3252
0
      rd = temp_skip_txfm_rd;
3253
0
      rd_stats->rate = 0;
3254
0
      rd_stats->dist = rd_stats->sse;
3255
0
      rd_stats->skip_txfm = 1;
3256
0
    }
3257
0
  }
3258
3259
0
  return rd;
3260
0
}
3261
3262
// Search for the best transform type for a luma inter-predicted block, given
3263
// the transform block partitions.
3264
// This function is used only when some speed features are enabled.
3265
static inline void tx_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row,
3266
                                int blk_col, int block, TX_SIZE tx_size,
3267
                                BLOCK_SIZE plane_bsize, int depth,
3268
                                ENTROPY_CONTEXT *above_ctx,
3269
                                ENTROPY_CONTEXT *left_ctx,
3270
                                TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
3271
                                int64_t ref_best_rd, RD_STATS *rd_stats,
3272
0
                                FAST_TX_SEARCH_MODE ftxs_mode) {
3273
0
  assert(tx_size < TX_SIZES_ALL);
3274
0
  MACROBLOCKD *const xd = &x->e_mbd;
3275
0
  MB_MODE_INFO *const mbmi = xd->mi[0];
3276
0
  assert(is_inter_block(mbmi));
3277
0
  const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
3278
0
  const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);
3279
3280
0
  if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;
3281
3282
0
  const TX_SIZE plane_tx_size = mbmi->inter_tx_size[av1_get_txb_size_index(
3283
0
      plane_bsize, blk_row, blk_col)];
3284
0
  const int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
3285
0
                                         mbmi->bsize, tx_size);
3286
3287
0
  av1_init_rd_stats(rd_stats);
3288
0
  if (tx_size == plane_tx_size) {
3289
0
    ENTROPY_CONTEXT *ta = above_ctx + blk_col;
3290
0
    ENTROPY_CONTEXT *tl = left_ctx + blk_row;
3291
0
    const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
3292
0
    TXB_CTX txb_ctx;
3293
0
    get_txb_ctx(plane_bsize, tx_size, 0, ta, tl, &txb_ctx);
3294
3295
0
    const int zero_blk_rate =
3296
0
        x->coeff_costs.coeff_costs[txs_ctx][get_plane_type(0)]
3297
0
            .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
3298
0
    rd_stats->zero_rate = zero_blk_rate;
3299
0
    tx_type_rd(cpi, x, tx_size, blk_row, blk_col, block, plane_bsize, &txb_ctx,
3300
0
               rd_stats, ftxs_mode, ref_best_rd);
3301
0
    if (RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >=
3302
0
            RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse) ||
3303
0
        rd_stats->skip_txfm == 1) {
3304
0
      rd_stats->rate = zero_blk_rate;
3305
0
      rd_stats->dist = rd_stats->sse;
3306
0
      rd_stats->skip_txfm = 1;
3307
0
      x->plane[0].eobs[block] = 0;
3308
0
      x->plane[0].txb_entropy_ctx[block] = 0;
3309
0
      update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
3310
0
    } else {
3311
0
      rd_stats->skip_txfm = 0;
3312
0
    }
3313
0
    if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
3314
0
      rd_stats->rate += x->mode_costs.txfm_partition_cost[ctx][0];
3315
0
    av1_set_txb_context(x, 0, block, tx_size, ta, tl);
3316
0
    txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size,
3317
0
                          tx_size);
3318
0
  } else {
3319
0
    const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
3320
0
    const int txb_width = tx_size_wide_unit[sub_txs];
3321
0
    const int txb_height = tx_size_high_unit[sub_txs];
3322
0
    const int step = txb_height * txb_width;
3323
0
    const int row_end =
3324
0
        AOMMIN(tx_size_high_unit[tx_size], max_blocks_high - blk_row);
3325
0
    const int col_end =
3326
0
        AOMMIN(tx_size_wide_unit[tx_size], max_blocks_wide - blk_col);
3327
0
    RD_STATS pn_rd_stats;
3328
0
    int64_t this_rd = 0;
3329
0
    assert(txb_width > 0 && txb_height > 0);
3330
3331
0
    for (int row = 0; row < row_end; row += txb_height) {
3332
0
      const int offsetr = blk_row + row;
3333
0
      for (int col = 0; col < col_end; col += txb_width) {
3334
0
        const int offsetc = blk_col + col;
3335
3336
0
        av1_init_rd_stats(&pn_rd_stats);
3337
0
        tx_block_yrd(cpi, x, offsetr, offsetc, block, sub_txs, plane_bsize,
3338
0
                     depth + 1, above_ctx, left_ctx, tx_above, tx_left,
3339
0
                     ref_best_rd - this_rd, &pn_rd_stats, ftxs_mode);
3340
0
        if (pn_rd_stats.rate == INT_MAX) {
3341
0
          av1_invalid_rd_stats(rd_stats);
3342
0
          return;
3343
0
        }
3344
0
        av1_merge_rd_stats(rd_stats, &pn_rd_stats);
3345
0
        this_rd += RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist);
3346
0
        block += step;
3347
0
      }
3348
0
    }
3349
3350
0
    if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
3351
0
      rd_stats->rate += x->mode_costs.txfm_partition_cost[ctx][1];
3352
0
  }
3353
0
}
3354
3355
// search for tx type with tx sizes already decided for a inter-predicted luma
3356
// partition block. It's used only when some speed features are enabled.
3357
// Return value 0: early termination triggered, no valid rd cost available;
3358
//              1: rd cost values are valid.
3359
static int inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
3360
                           RD_STATS *rd_stats, BLOCK_SIZE bsize,
3361
0
                           int64_t ref_best_rd, FAST_TX_SEARCH_MODE ftxs_mode) {
3362
0
  if (ref_best_rd < 0) {
3363
0
    av1_invalid_rd_stats(rd_stats);
3364
0
    return 0;
3365
0
  }
3366
3367
0
  av1_init_rd_stats(rd_stats);
3368
3369
0
  MACROBLOCKD *const xd = &x->e_mbd;
3370
0
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
3371
0
  const struct macroblockd_plane *const pd = &xd->plane[0];
3372
0
  const int mi_width = mi_size_wide[bsize];
3373
0
  const int mi_height = mi_size_high[bsize];
3374
0
  const TX_SIZE max_tx_size = get_vartx_max_txsize(xd, bsize, 0);
3375
0
  const int bh = tx_size_high_unit[max_tx_size];
3376
0
  const int bw = tx_size_wide_unit[max_tx_size];
3377
0
  const int step = bw * bh;
3378
0
  const int init_depth = get_search_init_depth(
3379
0
      mi_width, mi_height, 1, &cpi->sf, txfm_params->tx_size_search_method);
3380
0
  ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
3381
0
  ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
3382
0
  TXFM_CONTEXT tx_above[MAX_MIB_SIZE];
3383
0
  TXFM_CONTEXT tx_left[MAX_MIB_SIZE];
3384
0
  av1_get_entropy_contexts(bsize, pd, ctxa, ctxl);
3385
0
  memcpy(tx_above, xd->above_txfm_context, sizeof(TXFM_CONTEXT) * mi_width);
3386
0
  memcpy(tx_left, xd->left_txfm_context, sizeof(TXFM_CONTEXT) * mi_height);
3387
3388
0
  int64_t this_rd = 0;
3389
0
  for (int idy = 0, block = 0; idy < mi_height; idy += bh) {
3390
0
    for (int idx = 0; idx < mi_width; idx += bw) {
3391
0
      RD_STATS pn_rd_stats;
3392
0
      av1_init_rd_stats(&pn_rd_stats);
3393
0
      tx_block_yrd(cpi, x, idy, idx, block, max_tx_size, bsize, init_depth,
3394
0
                   ctxa, ctxl, tx_above, tx_left, ref_best_rd - this_rd,
3395
0
                   &pn_rd_stats, ftxs_mode);
3396
0
      if (pn_rd_stats.rate == INT_MAX) {
3397
0
        av1_invalid_rd_stats(rd_stats);
3398
0
        return 0;
3399
0
      }
3400
0
      av1_merge_rd_stats(rd_stats, &pn_rd_stats);
3401
0
      this_rd +=
3402
0
          AOMMIN(RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist),
3403
0
                 RDCOST(x->rdmult, pn_rd_stats.zero_rate, pn_rd_stats.sse));
3404
0
      block += step;
3405
0
    }
3406
0
  }
3407
3408
0
  const int skip_ctx = av1_get_skip_txfm_context(xd);
3409
0
  const int no_skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][0];
3410
0
  const int skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][1];
3411
0
  const int64_t skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
3412
0
  this_rd =
3413
0
      RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_rate, rd_stats->dist);
3414
0
  if (skip_txfm_rd < this_rd) {
3415
0
    this_rd = skip_txfm_rd;
3416
0
    rd_stats->rate = 0;
3417
0
    rd_stats->dist = rd_stats->sse;
3418
0
    rd_stats->skip_txfm = 1;
3419
0
  }
3420
3421
0
  const int is_cost_valid = this_rd > ref_best_rd;
3422
0
  if (!is_cost_valid) {
3423
    // reset cost value
3424
0
    av1_invalid_rd_stats(rd_stats);
3425
0
  }
3426
0
  return is_cost_valid;
3427
0
}
3428
3429
// Search for the best transform size and type for current inter-predicted
3430
// luma block with recursive transform block partitioning. The obtained
3431
// transform selection will be saved in xd->mi[0], the corresponding RD stats
3432
// will be saved in rd_stats. The returned value is the corresponding RD cost.
3433
static int64_t select_tx_size_and_type(const AV1_COMP *cpi, MACROBLOCK *x,
3434
                                       RD_STATS *rd_stats, BLOCK_SIZE bsize,
3435
0
                                       int64_t ref_best_rd) {
3436
0
  MACROBLOCKD *const xd = &x->e_mbd;
3437
0
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
3438
0
  assert(is_inter_block(xd->mi[0]));
3439
0
  assert(bsize < BLOCK_SIZES_ALL);
3440
0
  const int fast_tx_search = txfm_params->tx_size_search_method > USE_FULL_RD;
3441
0
  int64_t rd_thresh = ref_best_rd;
3442
0
  if (rd_thresh == 0) {
3443
0
    av1_invalid_rd_stats(rd_stats);
3444
0
    return INT64_MAX;
3445
0
  }
3446
0
  if (fast_tx_search && rd_thresh < INT64_MAX) {
3447
0
    if (INT64_MAX - rd_thresh > (rd_thresh >> 3)) rd_thresh += (rd_thresh >> 3);
3448
0
  }
3449
0
  assert(rd_thresh > 0);
3450
0
  const FAST_TX_SEARCH_MODE ftxs_mode =
3451
0
      fast_tx_search ? FTXS_DCT_AND_1D_DCT_ONLY : FTXS_NONE;
3452
0
  const struct macroblockd_plane *const pd = &xd->plane[0];
3453
0
  assert(bsize < BLOCK_SIZES_ALL);
3454
0
  const int mi_width = mi_size_wide[bsize];
3455
0
  const int mi_height = mi_size_high[bsize];
3456
0
  ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
3457
0
  ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
3458
0
  TXFM_CONTEXT tx_above[MAX_MIB_SIZE];
3459
0
  TXFM_CONTEXT tx_left[MAX_MIB_SIZE];
3460
0
  av1_get_entropy_contexts(bsize, pd, ctxa, ctxl);
3461
0
  memcpy(tx_above, xd->above_txfm_context, sizeof(TXFM_CONTEXT) * mi_width);
3462
0
  memcpy(tx_left, xd->left_txfm_context, sizeof(TXFM_CONTEXT) * mi_height);
3463
0
  const int init_depth = get_search_init_depth(
3464
0
      mi_width, mi_height, 1, &cpi->sf, txfm_params->tx_size_search_method);
3465
0
  const TX_SIZE max_tx_size = max_txsize_rect_lookup[bsize];
3466
0
  const int bh = tx_size_high_unit[max_tx_size];
3467
0
  const int bw = tx_size_wide_unit[max_tx_size];
3468
0
  const int step = bw * bh;
3469
0
  const int skip_ctx = av1_get_skip_txfm_context(xd);
3470
0
  const int no_skip_txfm_cost = x->mode_costs.skip_txfm_cost[skip_ctx][0];
3471
0
  const int skip_txfm_cost = x->mode_costs.skip_txfm_cost[skip_ctx][1];
3472
0
  int64_t skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_cost, 0);
3473
0
  int64_t no_skip_txfm_rd = RDCOST(x->rdmult, no_skip_txfm_cost, 0);
3474
0
  int block = 0;
3475
0
  int blk_idx = 0;
3476
3477
0
  av1_init_rd_stats(rd_stats);
3478
0
  for (int idy = 0; idy < max_block_high(xd, bsize, 0); idy += bh) {
3479
0
    for (int idx = 0; idx < max_block_wide(xd, bsize, 0); idx += bw) {
3480
0
      const int64_t best_rd_sofar =
3481
0
          (rd_thresh == INT64_MAX)
3482
0
              ? INT64_MAX
3483
0
              : (rd_thresh - (AOMMIN(skip_txfm_rd, no_skip_txfm_rd)));
3484
0
      int is_cost_valid = 1;
3485
0
      RD_STATS pn_rd_stats;
3486
      // Search for the best transform block size and type for the sub-block.
3487
0
      select_tx_block(cpi, x, idy, idx, block, max_tx_size, init_depth, bsize,
3488
0
                      ctxa, ctxl, tx_above, tx_left, &pn_rd_stats, INT64_MAX,
3489
0
                      best_rd_sofar, &is_cost_valid, ftxs_mode, blk_idx);
3490
0
      blk_idx++;
3491
0
      if (!is_cost_valid || pn_rd_stats.rate == INT_MAX) {
3492
0
        av1_invalid_rd_stats(rd_stats);
3493
0
        return INT64_MAX;
3494
0
      }
3495
0
      av1_merge_rd_stats(rd_stats, &pn_rd_stats);
3496
0
      skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_cost, rd_stats->sse);
3497
0
      no_skip_txfm_rd =
3498
0
          RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_cost, rd_stats->dist);
3499
0
      block += step;
3500
0
    }
3501
0
  }
3502
3503
0
  if (rd_stats->rate == INT_MAX) return INT64_MAX;
3504
3505
0
  rd_stats->skip_txfm = (skip_txfm_rd <= no_skip_txfm_rd);
3506
3507
  // If fast_tx_search is true, only DCT and 1D DCT were tested in
3508
  // select_inter_block_yrd() above. Do a better search for tx type with
3509
  // tx sizes already decided.
3510
0
  if (fast_tx_search && cpi->sf.tx_sf.refine_fast_tx_search_results) {
3511
0
    if (!inter_block_yrd(cpi, x, rd_stats, bsize, ref_best_rd, FTXS_NONE))
3512
0
      return INT64_MAX;
3513
0
  }
3514
3515
0
  int64_t final_rd;
3516
0
  if (rd_stats->skip_txfm) {
3517
0
    final_rd = RDCOST(x->rdmult, skip_txfm_cost, rd_stats->sse);
3518
0
  } else {
3519
0
    final_rd =
3520
0
        RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_cost, rd_stats->dist);
3521
0
    if (!xd->lossless[xd->mi[0]->segment_id]) {
3522
0
      final_rd =
3523
0
          AOMMIN(final_rd, RDCOST(x->rdmult, skip_txfm_cost, rd_stats->sse));
3524
0
    }
3525
0
  }
3526
3527
0
  return final_rd;
3528
0
}
3529
3530
// Return 1 to terminate transform search early. The decision is made based on
3531
// the comparison with the reference RD cost and the model-estimated RD cost.
3532
static inline int model_based_tx_search_prune(const AV1_COMP *cpi,
3533
                                              MACROBLOCK *x, BLOCK_SIZE bsize,
3534
0
                                              int64_t ref_best_rd) {
3535
0
  const int level = cpi->sf.tx_sf.model_based_prune_tx_search_level;
3536
0
  assert(level >= 0 && level <= 2);
3537
0
  int model_rate;
3538
0
  int64_t model_dist;
3539
0
  uint8_t model_skip;
3540
0
  MACROBLOCKD *const xd = &x->e_mbd;
3541
0
  model_rd_sb_fn[MODELRD_TYPE_TX_SEARCH_PRUNE](
3542
0
      cpi, bsize, x, xd, 0, 0, &model_rate, &model_dist, &model_skip, NULL,
3543
0
      NULL, NULL, NULL);
3544
0
  if (model_skip) return 0;
3545
0
  const int64_t model_rd = RDCOST(x->rdmult, model_rate, model_dist);
3546
  // TODO(debargha, urvang): Improve the model and make the check below
3547
  // tighter.
3548
0
  static const int prune_factor_by8[] = { 3, 5 };
3549
0
  const int factor = prune_factor_by8[level - 1];
3550
0
  return ((model_rd * factor) >> 3) > ref_best_rd;
3551
0
}
3552
3553
void av1_pick_recursive_tx_size_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
3554
                                         RD_STATS *rd_stats, BLOCK_SIZE bsize,
3555
0
                                         int64_t ref_best_rd) {
3556
0
  MACROBLOCKD *const xd = &x->e_mbd;
3557
0
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
3558
0
  assert(is_inter_block(xd->mi[0]));
3559
3560
0
  av1_invalid_rd_stats(rd_stats);
3561
3562
  // If modeled RD cost is a lot worse than the best so far, terminate early.
3563
0
  if (cpi->sf.tx_sf.model_based_prune_tx_search_level &&
3564
0
      ref_best_rd != INT64_MAX) {
3565
0
    if (model_based_tx_search_prune(cpi, x, bsize, ref_best_rd)) return;
3566
0
  }
3567
3568
  // Hashing based speed feature. If the hash of the prediction residue block is
3569
  // found in the hash table, use previous search results and terminate early.
3570
0
  uint32_t hash = 0;
3571
0
  MB_RD_RECORD *mb_rd_record = NULL;
3572
0
  const int mi_row = x->e_mbd.mi_row;
3573
0
  const int mi_col = x->e_mbd.mi_col;
3574
0
  const int within_border =
3575
0
      mi_row >= xd->tile.mi_row_start &&
3576
0
      (mi_row + mi_size_high[bsize] < xd->tile.mi_row_end) &&
3577
0
      mi_col >= xd->tile.mi_col_start &&
3578
0
      (mi_col + mi_size_wide[bsize] < xd->tile.mi_col_end);
3579
0
  const int is_mb_rd_hash_enabled =
3580
0
      (within_border && cpi->sf.rd_sf.use_mb_rd_hash);
3581
0
  const int n4 = bsize_to_num_blk(bsize);
3582
0
  if (is_mb_rd_hash_enabled) {
3583
0
    hash = get_block_residue_hash(x, bsize);
3584
0
    mb_rd_record = x->txfm_search_info.mb_rd_record;
3585
0
    const int match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash);
3586
0
    if (match_index != -1) {
3587
0
      MB_RD_INFO *mb_rd_info = &mb_rd_record->mb_rd_info[match_index];
3588
0
      fetch_mb_rd_info(n4, mb_rd_info, rd_stats, x);
3589
0
      return;
3590
0
    }
3591
0
  }
3592
3593
  // If we predict that skip is the optimal RD decision - set the respective
3594
  // context and terminate early.
3595
0
  int64_t dist;
3596
0
  if (txfm_params->skip_txfm_level &&
3597
0
      predict_skip_txfm(x, bsize, &dist,
3598
0
                        cpi->common.features.reduced_tx_set_used)) {
3599
0
    set_skip_txfm(x, rd_stats, bsize, dist);
3600
    // Save the RD search results into mb_rd_record.
3601
0
    if (is_mb_rd_hash_enabled)
3602
0
      save_mb_rd_info(n4, hash, x, rd_stats, mb_rd_record);
3603
0
    return;
3604
0
  }
3605
#if CONFIG_SPEED_STATS
3606
  ++x->txfm_search_info.tx_search_count;
3607
#endif  // CONFIG_SPEED_STATS
3608
3609
0
  const int64_t rd =
3610
0
      select_tx_size_and_type(cpi, x, rd_stats, bsize, ref_best_rd);
3611
3612
0
  if (rd == INT64_MAX) {
3613
    // We should always find at least one candidate unless ref_best_rd is less
3614
    // than INT64_MAX (in which case, all the calls to select_tx_size_fix_type
3615
    // might have failed to find something better)
3616
0
    assert(ref_best_rd != INT64_MAX);
3617
0
    av1_invalid_rd_stats(rd_stats);
3618
0
    return;
3619
0
  }
3620
3621
  // Save the RD search results into mb_rd_record.
3622
0
  if (is_mb_rd_hash_enabled) {
3623
0
    assert(mb_rd_record != NULL);
3624
0
    save_mb_rd_info(n4, hash, x, rd_stats, mb_rd_record);
3625
0
  }
3626
0
}
3627
3628
void av1_pick_uniform_tx_size_type_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
3629
                                       RD_STATS *rd_stats, BLOCK_SIZE bs,
3630
0
                                       int64_t ref_best_rd) {
3631
0
  MACROBLOCKD *const xd = &x->e_mbd;
3632
0
  MB_MODE_INFO *const mbmi = xd->mi[0];
3633
0
  const TxfmSearchParams *tx_params = &x->txfm_search_params;
3634
0
  assert(bs == mbmi->bsize);
3635
0
  const int is_inter = is_inter_block(mbmi);
3636
0
  const int mi_row = xd->mi_row;
3637
0
  const int mi_col = xd->mi_col;
3638
3639
0
  av1_init_rd_stats(rd_stats);
3640
3641
  // Hashing based speed feature for inter blocks. If the hash of the residue
3642
  // block is found in the table, use previously saved search results and
3643
  // terminate early.
3644
0
  uint32_t hash = 0;
3645
0
  MB_RD_RECORD *mb_rd_record = NULL;
3646
0
  const int num_blks = bsize_to_num_blk(bs);
3647
0
  if (is_inter && cpi->sf.rd_sf.use_mb_rd_hash) {
3648
0
    const int within_border =
3649
0
        mi_row >= xd->tile.mi_row_start &&
3650
0
        (mi_row + mi_size_high[bs] < xd->tile.mi_row_end) &&
3651
0
        mi_col >= xd->tile.mi_col_start &&
3652
0
        (mi_col + mi_size_wide[bs] < xd->tile.mi_col_end);
3653
0
    if (within_border) {
3654
0
      hash = get_block_residue_hash(x, bs);
3655
0
      mb_rd_record = x->txfm_search_info.mb_rd_record;
3656
0
      const int match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash);
3657
0
      if (match_index != -1) {
3658
0
        MB_RD_INFO *mb_rd_info = &mb_rd_record->mb_rd_info[match_index];
3659
0
        fetch_mb_rd_info(num_blks, mb_rd_info, rd_stats, x);
3660
0
        return;
3661
0
      }
3662
0
    }
3663
0
  }
3664
3665
  // If we predict that skip is the optimal RD decision - set the respective
3666
  // context and terminate early.
3667
0
  int64_t dist;
3668
0
  if (tx_params->skip_txfm_level && is_inter &&
3669
0
      !xd->lossless[mbmi->segment_id] &&
3670
0
      predict_skip_txfm(x, bs, &dist,
3671
0
                        cpi->common.features.reduced_tx_set_used)) {
3672
    // Populate rdstats as per skip decision
3673
0
    set_skip_txfm(x, rd_stats, bs, dist);
3674
    // Save the RD search results into mb_rd_record.
3675
0
    if (mb_rd_record) {
3676
0
      save_mb_rd_info(num_blks, hash, x, rd_stats, mb_rd_record);
3677
0
    }
3678
0
    return;
3679
0
  }
3680
3681
0
  if (xd->lossless[mbmi->segment_id]) {
3682
    // Lossless mode can only pick the smallest (4x4) transform size.
3683
0
    choose_smallest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
3684
0
  } else if (tx_params->tx_size_search_method == USE_LARGESTALL) {
3685
0
    choose_largest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
3686
0
  } else {
3687
0
    choose_tx_size_type_from_rd(cpi, x, rd_stats, ref_best_rd, bs);
3688
0
  }
3689
3690
  // Save the RD search results into mb_rd_record for possible reuse in future.
3691
0
  if (mb_rd_record) {
3692
0
    save_mb_rd_info(num_blks, hash, x, rd_stats, mb_rd_record);
3693
0
  }
3694
0
}
3695
3696
int av1_txfm_uvrd(const AV1_COMP *const cpi, MACROBLOCK *x, RD_STATS *rd_stats,
3697
0
                  BLOCK_SIZE bsize, int64_t ref_best_rd) {
3698
0
  av1_init_rd_stats(rd_stats);
3699
0
  if (ref_best_rd < 0) return 0;
3700
0
  if (!x->e_mbd.is_chroma_ref) return 1;
3701
3702
0
  MACROBLOCKD *const xd = &x->e_mbd;
3703
0
  MB_MODE_INFO *const mbmi = xd->mi[0];
3704
0
  struct macroblockd_plane *const pd = &xd->plane[AOM_PLANE_U];
3705
0
  const int is_inter = is_inter_block(mbmi);
3706
0
  int64_t this_rd = 0, skip_txfm_rd = 0;
3707
0
  const BLOCK_SIZE plane_bsize =
3708
0
      get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
3709
3710
0
  if (is_inter) {
3711
0
    for (int plane = 1; plane < MAX_MB_PLANE; ++plane)
3712
0
      av1_subtract_plane(x, plane_bsize, plane);
3713
0
  }
3714
3715
0
  const TX_SIZE uv_tx_size = av1_get_tx_size(AOM_PLANE_U, xd);
3716
0
  int is_cost_valid = 1;
3717
0
  for (int plane = 1; plane < MAX_MB_PLANE; ++plane) {
3718
0
    RD_STATS this_rd_stats;
3719
0
    int64_t chroma_ref_best_rd = ref_best_rd;
3720
    // For inter blocks, refined ref_best_rd is used for early exit
3721
    // For intra blocks, even though current rd crosses ref_best_rd, early
3722
    // exit is not recommended as current rd is used for gating subsequent
3723
    // modes as well (say, for angular modes)
3724
    // TODO(any): Extend the early exit mechanism for intra modes as well
3725
0
    if (cpi->sf.inter_sf.perform_best_rd_based_gating_for_chroma && is_inter &&
3726
0
        chroma_ref_best_rd != INT64_MAX)
3727
0
      chroma_ref_best_rd = ref_best_rd - AOMMIN(this_rd, skip_txfm_rd);
3728
0
    av1_txfm_rd_in_plane(x, cpi, &this_rd_stats, chroma_ref_best_rd, 0, plane,
3729
0
                         plane_bsize, uv_tx_size, FTXS_NONE);
3730
0
    if (this_rd_stats.rate == INT_MAX) {
3731
0
      is_cost_valid = 0;
3732
0
      break;
3733
0
    }
3734
0
    av1_merge_rd_stats(rd_stats, &this_rd_stats);
3735
0
    this_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
3736
0
    skip_txfm_rd = RDCOST(x->rdmult, 0, rd_stats->sse);
3737
0
    if (AOMMIN(this_rd, skip_txfm_rd) > ref_best_rd) {
3738
0
      is_cost_valid = 0;
3739
0
      break;
3740
0
    }
3741
0
  }
3742
3743
0
  if (!is_cost_valid) {
3744
    // reset cost value
3745
0
    av1_invalid_rd_stats(rd_stats);
3746
0
  }
3747
3748
0
  return is_cost_valid;
3749
0
}
3750
3751
void av1_txfm_rd_in_plane(MACROBLOCK *x, const AV1_COMP *cpi,
3752
                          RD_STATS *rd_stats, int64_t ref_best_rd,
3753
                          int64_t current_rd, int plane, BLOCK_SIZE plane_bsize,
3754
0
                          TX_SIZE tx_size, FAST_TX_SEARCH_MODE ftxs_mode) {
3755
0
  assert(IMPLIES(plane == 0, x->e_mbd.mi[0]->tx_size == tx_size));
3756
3757
0
  if (!cpi->oxcf.txfm_cfg.enable_tx64 &&
3758
0
      txsize_sqr_up_map[tx_size] == TX_64X64) {
3759
0
    av1_invalid_rd_stats(rd_stats);
3760
0
    return;
3761
0
  }
3762
3763
0
  if (current_rd > ref_best_rd) {
3764
0
    av1_invalid_rd_stats(rd_stats);
3765
0
    return;
3766
0
  }
3767
3768
0
  MACROBLOCKD *const xd = &x->e_mbd;
3769
0
  const struct macroblockd_plane *const pd = &xd->plane[plane];
3770
0
  struct rdcost_block_args args;
3771
0
  av1_zero(args);
3772
0
  args.x = x;
3773
0
  args.cpi = cpi;
3774
0
  args.best_rd = ref_best_rd;
3775
0
  args.current_rd = current_rd;
3776
0
  args.ftxs_mode = ftxs_mode;
3777
0
  args.skip_trellis = 0;
3778
0
  av1_init_rd_stats(&args.rd_stats);
3779
3780
0
  av1_get_entropy_contexts(plane_bsize, pd, args.t_above, args.t_left);
3781
0
  av1_foreach_transformed_block_in_plane(xd, plane_bsize, plane, block_rd_txfm,
3782
0
                                         &args);
3783
3784
0
  MB_MODE_INFO *const mbmi = xd->mi[0];
3785
0
  const int is_inter = is_inter_block(mbmi);
3786
0
  const int invalid_rd = is_inter ? args.incomplete_exit : args.exit_early;
3787
3788
0
  if (invalid_rd) {
3789
0
    av1_invalid_rd_stats(rd_stats);
3790
0
  } else {
3791
0
    *rd_stats = args.rd_stats;
3792
0
  }
3793
0
}
3794
3795
// Full transform search (luma, then chroma when present) for a block whose
// mode signalling costs mode_rate. Fills rd_stats / rd_stats_y / rd_stats_uv
// and sets mbmi->skip_txfm. Returns 1 when a result within ref_best_rd was
// produced and 0 when the search was abandoned.
int av1_txfm_search(const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bsize,
                    RD_STATS *rd_stats, RD_STATS *rd_stats_y,
                    RD_STATS *rd_stats_uv, int mode_rate, int64_t ref_best_rd) {
  MACROBLOCKD *const xd = &x->e_mbd;
  TxfmSearchParams *txfm_params = &x->txfm_search_params;
  const int skip_ctx = av1_get_skip_txfm_context(xd);
  // Cost of signalling "not skipped" and "skipped" respectively.
  const int coded_flag_cost = x->mode_costs.skip_txfm_cost[skip_ctx][0];
  const int skip_flag_cost = x->mode_costs.skip_txfm_cost[skip_ctx][1];

  // Whichever flag ends up being coded, at least the cheaper one is added to
  // mode_rate; if even that is over budget there is nothing to search.
  const int64_t min_header_rate =
      mode_rate + AOMMIN(coded_flag_cost, skip_flag_cost);
  const int64_t min_header_rd_possible = RDCOST(x->rdmult, min_header_rate, 0);
  if (min_header_rd_possible > ref_best_rd) {
    av1_invalid_rd_stats(rd_stats_y);
    return 0;
  }

  const AV1_COMMON *cm = &cpi->common;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const int64_t mode_rd = RDCOST(x->rdmult, mode_rate, 0);
  int64_t luma_budget = INT64_MAX;
  if (ref_best_rd != INT64_MAX) luma_budget = ref_best_rd - mode_rd;

  av1_init_rd_stats(rd_stats);
  av1_init_rd_stats(rd_stats_y);
  rd_stats->rate = mode_rate;

  // Luma: cost and distortion.
  av1_subtract_plane(x, bsize, 0);
  if (txfm_params->tx_mode_search_type == TX_MODE_SELECT &&
      !xd->lossless[mbmi->segment_id]) {
    av1_pick_recursive_tx_size_type_yrd(cpi, x, rd_stats_y, bsize, luma_budget);
#if CONFIG_COLLECT_RD_STATS == 2
    PrintPredictionUnitStats(cpi, tile_data, x, rd_stats_y, bsize);
#endif  // CONFIG_COLLECT_RD_STATS == 2
  } else {
    av1_pick_uniform_tx_size_type_yrd(cpi, x, rd_stats_y, bsize, luma_budget);
    memset(mbmi->inter_tx_size, mbmi->tx_size, sizeof(mbmi->inter_tx_size));
  }

  if (rd_stats_y->rate == INT_MAX) return 0;

  av1_merge_rd_stats(rd_stats, rd_stats_y);

  // Cheapest luma-only outcome, coded or skipped.
  const int64_t non_skip_txfm_rdcosty =
      RDCOST(x->rdmult, rd_stats->rate + coded_flag_cost, rd_stats->dist);
  const int64_t skip_txfm_rdcosty =
      RDCOST(x->rdmult, mode_rate + skip_flag_cost, rd_stats->sse);
  const int64_t min_rdcosty = AOMMIN(non_skip_txfm_rdcosty, skip_txfm_rdcosty);
  if (min_rdcosty > ref_best_rd) return 0;

  av1_init_rd_stats(rd_stats_uv);
  if (av1_num_planes(cm) > 1) {
    // Optionally shrink the chroma budget by what luma already costs.
    int64_t ref_best_chroma_rd = ref_best_rd;
    if (cpi->sf.inter_sf.perform_best_rd_based_gating_for_chroma &&
        (ref_best_chroma_rd != INT64_MAX)) {
      ref_best_chroma_rd = ref_best_chroma_rd - min_rdcosty;
    }
    if (!av1_txfm_uvrd(cpi, x, rd_stats_uv, bsize, ref_best_chroma_rd))
      return 0;
    av1_merge_rd_stats(rd_stats, rd_stats_uv);
  }

  // Decide between coding the residual and skipping it. Lossless segments
  // never switch to skip here.
  int choose_skip_txfm = rd_stats->skip_txfm;
  if (!choose_skip_txfm && !xd->lossless[mbmi->segment_id]) {
    const int64_t rdcost_no_skip_txfm = RDCOST(
        x->rdmult, rd_stats_y->rate + rd_stats_uv->rate + coded_flag_cost,
        rd_stats->dist);
    const int64_t rdcost_skip_txfm =
        RDCOST(x->rdmult, skip_flag_cost, rd_stats->sse);
    if (rdcost_no_skip_txfm >= rdcost_skip_txfm) choose_skip_txfm = 1;
  }

  if (!choose_skip_txfm) {
    rd_stats->rate += coded_flag_cost;
    mbmi->skip_txfm = 0;
    return 1;
  }

  // Skip: no coefficient rate, distortion equals the prediction error.
  rd_stats_y->rate = 0;
  rd_stats_uv->rate = 0;
  rd_stats->rate = mode_rate + skip_flag_cost;
  rd_stats->dist = rd_stats->sse;
  rd_stats_y->dist = rd_stats_y->sse;
  rd_stats_uv->dist = rd_stats_uv->sse;
  mbmi->skip_txfm = 1;
  if (rd_stats->skip_txfm) {
    const int64_t tmprd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
    if (tmprd > ref_best_rd) return 0;
  }
  return 1;
}