/src/aom/av1/encoder/mv_prec.c
Line | Count | Source (jump to first uncovered line) |
1 | | /* |
2 | | * Copyright (c) 2019, Alliance for Open Media. All rights reserved. |
3 | | * |
4 | | * This source code is subject to the terms of the BSD 2 Clause License and |
5 | | * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License |
6 | | * was not distributed with this source code in the LICENSE file, you can |
7 | | * obtain it at www.aomedia.org/license/software. If the Alliance for Open |
8 | | * Media Patent License 1.0 was not distributed with this source code in the |
9 | | * PATENTS file, you can obtain it at www.aomedia.org/license/patent. |
10 | | */ |
11 | | |
12 | | #include "config/aom_config.h" |
13 | | |
14 | | #include "av1/encoder/encodemv.h" |
15 | | #if !CONFIG_REALTIME_ONLY |
16 | | #include "av1/encoder/misc_model_weights.h" |
17 | | #endif // !CONFIG_REALTIME_ONLY |
18 | | #include "av1/encoder/mv_prec.h" |
19 | | |
20 | | #if !CONFIG_REALTIME_ONLY |
21 | | static inline int_mv get_ref_mv_for_mv_stats( |
22 | | const MB_MODE_INFO *mbmi, const MB_MODE_INFO_EXT_FRAME *mbmi_ext_frame, |
23 | 0 | int ref_idx) { |
24 | 0 | int ref_mv_idx = mbmi->ref_mv_idx; |
25 | 0 | if (mbmi->mode == NEAR_NEWMV || mbmi->mode == NEW_NEARMV) { |
26 | 0 | assert(has_second_ref(mbmi)); |
27 | 0 | ref_mv_idx += 1; |
28 | 0 | } |
29 | |
|
30 | 0 | const MV_REFERENCE_FRAME *ref_frames = mbmi->ref_frame; |
31 | 0 | const int8_t ref_frame_type = av1_ref_frame_type(ref_frames); |
32 | 0 | const CANDIDATE_MV *curr_ref_mv_stack = mbmi_ext_frame->ref_mv_stack; |
33 | |
|
34 | 0 | if (ref_frames[1] > INTRA_FRAME) { |
35 | 0 | assert(ref_idx == 0 || ref_idx == 1); |
36 | 0 | return ref_idx ? curr_ref_mv_stack[ref_mv_idx].comp_mv |
37 | 0 | : curr_ref_mv_stack[ref_mv_idx].this_mv; |
38 | 0 | } |
39 | | |
40 | 0 | assert(ref_idx == 0); |
41 | 0 | return ref_mv_idx < mbmi_ext_frame->ref_mv_count |
42 | 0 | ? curr_ref_mv_stack[ref_mv_idx].this_mv |
43 | 0 | : mbmi_ext_frame->global_mvs[ref_frame_type]; |
44 | 0 | } |
45 | | |
46 | 0 | static inline int get_symbol_cost(const aom_cdf_prob *cdf, int symbol) { |
47 | 0 | const aom_cdf_prob cur_cdf = AOM_ICDF(cdf[symbol]); |
48 | 0 | const aom_cdf_prob prev_cdf = symbol ? AOM_ICDF(cdf[symbol - 1]) : 0; |
49 | 0 | const aom_cdf_prob p15 = AOMMAX(cur_cdf - prev_cdf, EC_MIN_PROB); |
50 | |
|
51 | 0 | return av1_cost_symbol(p15); |
52 | 0 | } |
53 | | |
54 | | static inline int keep_one_comp_stat(MV_STATS *mv_stats, int comp_val, |
55 | | int comp_idx, const AV1_COMP *cpi, |
56 | 0 | int *rates) { |
57 | 0 | assert(comp_val != 0 && "mv component should not have zero value!"); |
58 | 0 | const int sign = comp_val < 0; |
59 | 0 | const int mag = sign ? -comp_val : comp_val; |
60 | 0 | const int mag_minus_1 = mag - 1; |
61 | 0 | int offset; |
62 | 0 | const int mv_class = av1_get_mv_class(mag_minus_1, &offset); |
63 | 0 | const int int_part = offset >> 3; // int mv data |
64 | 0 | const int frac_part = (offset >> 1) & 3; // fractional mv data |
65 | 0 | const int high_part = offset & 1; // high precision mv data |
66 | 0 | const int use_hp = cpi->common.features.allow_high_precision_mv; |
67 | 0 | int r_idx = 0; |
68 | |
|
69 | 0 | const MACROBLOCK *const x = &cpi->td.mb; |
70 | 0 | const MACROBLOCKD *const xd = &x->e_mbd; |
71 | 0 | FRAME_CONTEXT *ec_ctx = xd->tile_ctx; |
72 | 0 | nmv_context *nmvc = &ec_ctx->nmvc; |
73 | 0 | nmv_component *mvcomp_ctx = nmvc->comps; |
74 | 0 | nmv_component *cur_mvcomp_ctx = &mvcomp_ctx[comp_idx]; |
75 | 0 | aom_cdf_prob *sign_cdf = cur_mvcomp_ctx->sign_cdf; |
76 | 0 | aom_cdf_prob *class_cdf = cur_mvcomp_ctx->classes_cdf; |
77 | 0 | aom_cdf_prob *class0_cdf = cur_mvcomp_ctx->class0_cdf; |
78 | 0 | aom_cdf_prob(*bits_cdf)[3] = cur_mvcomp_ctx->bits_cdf; |
79 | 0 | aom_cdf_prob *frac_part_cdf = mv_class |
80 | 0 | ? (cur_mvcomp_ctx->fp_cdf) |
81 | 0 | : (cur_mvcomp_ctx->class0_fp_cdf[int_part]); |
82 | 0 | aom_cdf_prob *high_part_cdf = |
83 | 0 | mv_class ? (cur_mvcomp_ctx->hp_cdf) : (cur_mvcomp_ctx->class0_hp_cdf); |
84 | |
|
85 | 0 | const int sign_rate = get_symbol_cost(sign_cdf, sign); |
86 | 0 | rates[r_idx++] = sign_rate; |
87 | 0 | update_cdf(sign_cdf, sign, 2); |
88 | |
|
89 | 0 | const int class_rate = get_symbol_cost(class_cdf, mv_class); |
90 | 0 | rates[r_idx++] = class_rate; |
91 | 0 | update_cdf(class_cdf, mv_class, MV_CLASSES); |
92 | |
|
93 | 0 | int int_bit_rate = 0; |
94 | 0 | if (mv_class == MV_CLASS_0) { |
95 | 0 | int_bit_rate = get_symbol_cost(class0_cdf, int_part); |
96 | 0 | update_cdf(class0_cdf, int_part, CLASS0_SIZE); |
97 | 0 | } else { |
98 | 0 | const int n = mv_class + CLASS0_BITS - 1; // number of bits |
99 | 0 | for (int i = 0; i < n; ++i) { |
100 | 0 | int_bit_rate += get_symbol_cost(bits_cdf[i], (int_part >> i) & 1); |
101 | 0 | update_cdf(bits_cdf[i], (int_part >> i) & 1, 2); |
102 | 0 | } |
103 | 0 | } |
104 | 0 | rates[r_idx++] = int_bit_rate; |
105 | 0 | const int frac_part_rate = get_symbol_cost(frac_part_cdf, frac_part); |
106 | 0 | rates[r_idx++] = frac_part_rate; |
107 | 0 | update_cdf(frac_part_cdf, frac_part, MV_FP_SIZE); |
108 | 0 | const int high_part_rate = |
109 | 0 | use_hp ? get_symbol_cost(high_part_cdf, high_part) : 0; |
110 | 0 | if (use_hp) { |
111 | 0 | update_cdf(high_part_cdf, high_part, 2); |
112 | 0 | } |
113 | 0 | rates[r_idx++] = high_part_rate; |
114 | |
|
115 | 0 | mv_stats->last_bit_zero += !high_part; |
116 | 0 | mv_stats->last_bit_nonzero += high_part; |
117 | 0 | const int total_rate = |
118 | 0 | (sign_rate + class_rate + int_bit_rate + frac_part_rate + high_part_rate); |
119 | 0 | return total_rate; |
120 | 0 | } |
121 | | |
122 | | static inline void keep_one_mv_stat(MV_STATS *mv_stats, const MV *ref_mv, |
123 | 0 | const MV *cur_mv, const AV1_COMP *cpi) { |
124 | 0 | const MACROBLOCK *const x = &cpi->td.mb; |
125 | 0 | const MACROBLOCKD *const xd = &x->e_mbd; |
126 | 0 | FRAME_CONTEXT *ec_ctx = xd->tile_ctx; |
127 | 0 | nmv_context *nmvc = &ec_ctx->nmvc; |
128 | 0 | aom_cdf_prob *joint_cdf = nmvc->joints_cdf; |
129 | 0 | const int use_hp = cpi->common.features.allow_high_precision_mv; |
130 | |
|
131 | 0 | const MV diff = { cur_mv->row - ref_mv->row, cur_mv->col - ref_mv->col }; |
132 | 0 | const int mv_joint = av1_get_mv_joint(&diff); |
133 | | // TODO(chiyotsai@google.com): Estimate hp_diff when we are using lp |
134 | 0 | const MV hp_diff = diff; |
135 | 0 | const int hp_mv_joint = av1_get_mv_joint(&hp_diff); |
136 | 0 | const MV truncated_diff = { (diff.row / 2) * 2, (diff.col / 2) * 2 }; |
137 | 0 | const MV lp_diff = use_hp ? truncated_diff : diff; |
138 | 0 | const int lp_mv_joint = av1_get_mv_joint(&lp_diff); |
139 | |
|
140 | 0 | const int mv_joint_rate = get_symbol_cost(joint_cdf, mv_joint); |
141 | 0 | const int hp_mv_joint_rate = get_symbol_cost(joint_cdf, hp_mv_joint); |
142 | 0 | const int lp_mv_joint_rate = get_symbol_cost(joint_cdf, lp_mv_joint); |
143 | |
|
144 | 0 | update_cdf(joint_cdf, mv_joint, MV_JOINTS); |
145 | |
|
146 | 0 | mv_stats->total_mv_rate += mv_joint_rate; |
147 | 0 | mv_stats->hp_total_mv_rate += hp_mv_joint_rate; |
148 | 0 | mv_stats->lp_total_mv_rate += lp_mv_joint_rate; |
149 | 0 | mv_stats->mv_joint_count[mv_joint]++; |
150 | |
|
151 | 0 | for (int comp_idx = 0; comp_idx < 2; comp_idx++) { |
152 | 0 | const int comp_val = comp_idx ? diff.col : diff.row; |
153 | 0 | const int hp_comp_val = comp_idx ? hp_diff.col : hp_diff.row; |
154 | 0 | const int lp_comp_val = comp_idx ? lp_diff.col : lp_diff.row; |
155 | 0 | int rates[5]; |
156 | 0 | av1_zero_array(rates, 5); |
157 | |
|
158 | 0 | const int comp_rate = |
159 | 0 | comp_val ? keep_one_comp_stat(mv_stats, comp_val, comp_idx, cpi, rates) |
160 | 0 | : 0; |
161 | | // TODO(chiyotsai@google.com): Properly get hp rate when use_hp is false |
162 | 0 | const int hp_rate = |
163 | 0 | hp_comp_val ? rates[0] + rates[1] + rates[2] + rates[3] + rates[4] : 0; |
164 | 0 | const int lp_rate = |
165 | 0 | lp_comp_val ? rates[0] + rates[1] + rates[2] + rates[3] : 0; |
166 | |
|
167 | 0 | mv_stats->total_mv_rate += comp_rate; |
168 | 0 | mv_stats->hp_total_mv_rate += hp_rate; |
169 | 0 | mv_stats->lp_total_mv_rate += lp_rate; |
170 | 0 | } |
171 | 0 | } |
172 | | |
173 | | static inline void collect_mv_stats_b(MV_STATS *mv_stats, const AV1_COMP *cpi, |
174 | 0 | int mi_row, int mi_col) { |
175 | 0 | const AV1_COMMON *cm = &cpi->common; |
176 | 0 | const CommonModeInfoParams *const mi_params = &cm->mi_params; |
177 | |
|
178 | 0 | if (mi_row >= mi_params->mi_rows || mi_col >= mi_params->mi_cols) { |
179 | 0 | return; |
180 | 0 | } |
181 | | |
182 | 0 | const MB_MODE_INFO *mbmi = |
183 | 0 | mi_params->mi_grid_base[mi_row * mi_params->mi_stride + mi_col]; |
184 | 0 | const MB_MODE_INFO_EXT_FRAME *mbmi_ext_frame = |
185 | 0 | cpi->mbmi_ext_info.frame_base + |
186 | 0 | get_mi_ext_idx(mi_row, mi_col, cm->mi_params.mi_alloc_bsize, |
187 | 0 | cpi->mbmi_ext_info.stride); |
188 | |
|
189 | 0 | if (!is_inter_block(mbmi)) { |
190 | 0 | mv_stats->intra_count++; |
191 | 0 | return; |
192 | 0 | } |
193 | 0 | mv_stats->inter_count++; |
194 | |
|
195 | 0 | const PREDICTION_MODE mode = mbmi->mode; |
196 | 0 | const int is_compound = has_second_ref(mbmi); |
197 | |
|
198 | 0 | if (mode == NEWMV || mode == NEW_NEWMV) { |
199 | | // All mvs are new |
200 | 0 | for (int ref_idx = 0; ref_idx < 1 + is_compound; ++ref_idx) { |
201 | 0 | const MV ref_mv = |
202 | 0 | get_ref_mv_for_mv_stats(mbmi, mbmi_ext_frame, ref_idx).as_mv; |
203 | 0 | const MV cur_mv = mbmi->mv[ref_idx].as_mv; |
204 | 0 | keep_one_mv_stat(mv_stats, &ref_mv, &cur_mv, cpi); |
205 | 0 | } |
206 | 0 | } else if (mode == NEAREST_NEWMV || mode == NEAR_NEWMV || |
207 | 0 | mode == NEW_NEARESTMV || mode == NEW_NEARMV) { |
208 | | // has exactly one new_mv |
209 | 0 | mv_stats->default_mvs += 1; |
210 | |
|
211 | 0 | const int ref_idx = (mode == NEAREST_NEWMV || mode == NEAR_NEWMV); |
212 | 0 | const MV ref_mv = |
213 | 0 | get_ref_mv_for_mv_stats(mbmi, mbmi_ext_frame, ref_idx).as_mv; |
214 | 0 | const MV cur_mv = mbmi->mv[ref_idx].as_mv; |
215 | |
|
216 | 0 | keep_one_mv_stat(mv_stats, &ref_mv, &cur_mv, cpi); |
217 | 0 | } else { |
218 | | // No new_mv |
219 | 0 | mv_stats->default_mvs += 1 + is_compound; |
220 | 0 | } |
221 | | |
222 | | // Add texture information |
223 | 0 | const BLOCK_SIZE bsize = mbmi->bsize; |
224 | 0 | const int num_rows = block_size_high[bsize]; |
225 | 0 | const int num_cols = block_size_wide[bsize]; |
226 | 0 | const int y_stride = cpi->source->y_stride; |
227 | 0 | const int px_row = 4 * mi_row, px_col = 4 * mi_col; |
228 | 0 | const int buf_is_hbd = cpi->source->flags & YV12_FLAG_HIGHBITDEPTH; |
229 | 0 | const int bd = cm->seq_params->bit_depth; |
230 | 0 | if (buf_is_hbd) { |
231 | 0 | uint16_t *source_buf = |
232 | 0 | CONVERT_TO_SHORTPTR(cpi->source->y_buffer) + px_row * y_stride + px_col; |
233 | 0 | for (int row = 0; row < num_rows - 1; row++) { |
234 | 0 | for (int col = 0; col < num_cols - 1; col++) { |
235 | 0 | const int offset = row * y_stride + col; |
236 | 0 | const int horz_diff = |
237 | 0 | abs(source_buf[offset + 1] - source_buf[offset]) >> (bd - 8); |
238 | 0 | const int vert_diff = |
239 | 0 | abs(source_buf[offset + y_stride] - source_buf[offset]) >> (bd - 8); |
240 | 0 | mv_stats->horz_text += horz_diff; |
241 | 0 | mv_stats->vert_text += vert_diff; |
242 | 0 | mv_stats->diag_text += horz_diff * vert_diff; |
243 | 0 | } |
244 | 0 | } |
245 | 0 | } else { |
246 | 0 | uint8_t *source_buf = cpi->source->y_buffer + px_row * y_stride + px_col; |
247 | 0 | for (int row = 0; row < num_rows - 1; row++) { |
248 | 0 | for (int col = 0; col < num_cols - 1; col++) { |
249 | 0 | const int offset = row * y_stride + col; |
250 | 0 | const int horz_diff = abs(source_buf[offset + 1] - source_buf[offset]); |
251 | 0 | const int vert_diff = |
252 | 0 | abs(source_buf[offset + y_stride] - source_buf[offset]); |
253 | 0 | mv_stats->horz_text += horz_diff; |
254 | 0 | mv_stats->vert_text += vert_diff; |
255 | 0 | mv_stats->diag_text += horz_diff * vert_diff; |
256 | 0 | } |
257 | 0 | } |
258 | 0 | } |
259 | 0 | } |
260 | | |
261 | | // Split block |
262 | | static inline void collect_mv_stats_sb(MV_STATS *mv_stats, const AV1_COMP *cpi, |
263 | | int mi_row, int mi_col, |
264 | 0 | BLOCK_SIZE bsize) { |
265 | 0 | assert(bsize < BLOCK_SIZES_ALL); |
266 | 0 | const AV1_COMMON *cm = &cpi->common; |
267 | |
|
268 | 0 | if (mi_row >= cm->mi_params.mi_rows || mi_col >= cm->mi_params.mi_cols) |
269 | 0 | return; |
270 | | |
271 | 0 | const PARTITION_TYPE partition = get_partition(cm, mi_row, mi_col, bsize); |
272 | 0 | const BLOCK_SIZE subsize = get_partition_subsize(bsize, partition); |
273 | |
|
274 | 0 | const int hbs = mi_size_wide[bsize] / 2; |
275 | 0 | const int qbs = mi_size_wide[bsize] / 4; |
276 | 0 | switch (partition) { |
277 | 0 | case PARTITION_NONE: |
278 | 0 | collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col); |
279 | 0 | break; |
280 | 0 | case PARTITION_HORZ: |
281 | 0 | collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col); |
282 | 0 | collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col); |
283 | 0 | break; |
284 | 0 | case PARTITION_VERT: |
285 | 0 | collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col); |
286 | 0 | collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col + hbs); |
287 | 0 | break; |
288 | 0 | case PARTITION_SPLIT: |
289 | 0 | collect_mv_stats_sb(mv_stats, cpi, mi_row, mi_col, subsize); |
290 | 0 | collect_mv_stats_sb(mv_stats, cpi, mi_row, mi_col + hbs, subsize); |
291 | 0 | collect_mv_stats_sb(mv_stats, cpi, mi_row + hbs, mi_col, subsize); |
292 | 0 | collect_mv_stats_sb(mv_stats, cpi, mi_row + hbs, mi_col + hbs, subsize); |
293 | 0 | break; |
294 | 0 | case PARTITION_HORZ_A: |
295 | 0 | collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col); |
296 | 0 | collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col + hbs); |
297 | 0 | collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col); |
298 | 0 | break; |
299 | 0 | case PARTITION_HORZ_B: |
300 | 0 | collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col); |
301 | 0 | collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col); |
302 | 0 | collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col + hbs); |
303 | 0 | break; |
304 | 0 | case PARTITION_VERT_A: |
305 | 0 | collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col); |
306 | 0 | collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col); |
307 | 0 | collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col + hbs); |
308 | 0 | break; |
309 | 0 | case PARTITION_VERT_B: |
310 | 0 | collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col); |
311 | 0 | collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col + hbs); |
312 | 0 | collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col + hbs); |
313 | 0 | break; |
314 | 0 | case PARTITION_HORZ_4: |
315 | 0 | for (int i = 0; i < 4; ++i) { |
316 | 0 | const int this_mi_row = mi_row + i * qbs; |
317 | 0 | collect_mv_stats_b(mv_stats, cpi, this_mi_row, mi_col); |
318 | 0 | } |
319 | 0 | break; |
320 | 0 | case PARTITION_VERT_4: |
321 | 0 | for (int i = 0; i < 4; ++i) { |
322 | 0 | const int this_mi_col = mi_col + i * qbs; |
323 | 0 | collect_mv_stats_b(mv_stats, cpi, mi_row, this_mi_col); |
324 | 0 | } |
325 | 0 | break; |
326 | 0 | default: assert(0); |
327 | 0 | } |
328 | 0 | } |
329 | | |
330 | | static inline void collect_mv_stats_tile(MV_STATS *mv_stats, |
331 | | const AV1_COMP *cpi, |
332 | 0 | const TileInfo *tile_info) { |
333 | 0 | const AV1_COMMON *cm = &cpi->common; |
334 | 0 | const int mi_row_start = tile_info->mi_row_start; |
335 | 0 | const int mi_row_end = tile_info->mi_row_end; |
336 | 0 | const int mi_col_start = tile_info->mi_col_start; |
337 | 0 | const int mi_col_end = tile_info->mi_col_end; |
338 | 0 | const int sb_size_mi = cm->seq_params->mib_size; |
339 | 0 | BLOCK_SIZE sb_size = cm->seq_params->sb_size; |
340 | 0 | for (int mi_row = mi_row_start; mi_row < mi_row_end; mi_row += sb_size_mi) { |
341 | 0 | for (int mi_col = mi_col_start; mi_col < mi_col_end; mi_col += sb_size_mi) { |
342 | 0 | collect_mv_stats_sb(mv_stats, cpi, mi_row, mi_col, sb_size); |
343 | 0 | } |
344 | 0 | } |
345 | 0 | } |
346 | | |
347 | 0 | void av1_collect_mv_stats(AV1_COMP *cpi, int current_q) { |
348 | 0 | MV_STATS *mv_stats = &cpi->mv_stats; |
349 | 0 | const AV1_COMMON *cm = &cpi->common; |
350 | 0 | const int tile_cols = cm->tiles.cols; |
351 | 0 | const int tile_rows = cm->tiles.rows; |
352 | |
|
353 | 0 | for (int tile_row = 0; tile_row < tile_rows; tile_row++) { |
354 | 0 | TileInfo tile_info; |
355 | 0 | av1_tile_set_row(&tile_info, cm, tile_row); |
356 | 0 | for (int tile_col = 0; tile_col < tile_cols; tile_col++) { |
357 | 0 | const int tile_idx = tile_row * tile_cols + tile_col; |
358 | 0 | av1_tile_set_col(&tile_info, cm, tile_col); |
359 | 0 | cpi->tile_data[tile_idx].tctx = *cm->fc; |
360 | 0 | cpi->td.mb.e_mbd.tile_ctx = &cpi->tile_data[tile_idx].tctx; |
361 | 0 | collect_mv_stats_tile(mv_stats, cpi, &tile_info); |
362 | 0 | } |
363 | 0 | } |
364 | |
|
365 | 0 | mv_stats->q = current_q; |
366 | 0 | mv_stats->order = cpi->common.current_frame.order_hint; |
367 | 0 | mv_stats->valid = 1; |
368 | 0 | } |
369 | | |
370 | | static inline int get_smart_mv_prec(AV1_COMP *cpi, const MV_STATS *mv_stats, |
371 | 0 | int current_q) { |
372 | 0 | const AV1_COMMON *cm = &cpi->common; |
373 | 0 | const int order_hint = cpi->common.current_frame.order_hint; |
374 | 0 | const int order_diff = order_hint - mv_stats->order; |
375 | 0 | const float area = (float)(cm->width * cm->height); |
376 | 0 | float features[MV_PREC_FEATURE_SIZE] = { |
377 | 0 | (float)current_q, |
378 | 0 | (float)mv_stats->q, |
379 | 0 | (float)order_diff, |
380 | 0 | mv_stats->inter_count / area, |
381 | 0 | mv_stats->intra_count / area, |
382 | 0 | mv_stats->default_mvs / area, |
383 | 0 | mv_stats->mv_joint_count[0] / area, |
384 | 0 | mv_stats->mv_joint_count[1] / area, |
385 | 0 | mv_stats->mv_joint_count[2] / area, |
386 | 0 | mv_stats->mv_joint_count[3] / area, |
387 | 0 | mv_stats->last_bit_zero / area, |
388 | 0 | mv_stats->last_bit_nonzero / area, |
389 | 0 | mv_stats->total_mv_rate / area, |
390 | 0 | mv_stats->hp_total_mv_rate / area, |
391 | 0 | mv_stats->lp_total_mv_rate / area, |
392 | 0 | mv_stats->horz_text / area, |
393 | 0 | mv_stats->vert_text / area, |
394 | 0 | mv_stats->diag_text / area, |
395 | 0 | }; |
396 | |
|
397 | 0 | for (int f_idx = 0; f_idx < MV_PREC_FEATURE_SIZE; f_idx++) { |
398 | 0 | features[f_idx] = |
399 | 0 | (features[f_idx] - av1_mv_prec_mean[f_idx]) / av1_mv_prec_std[f_idx]; |
400 | 0 | } |
401 | 0 | float score = 0.0f; |
402 | |
|
403 | 0 | av1_nn_predict(features, &av1_mv_prec_dnn_config, 1, &score); |
404 | |
|
405 | 0 | const int use_high_hp = score >= 0.0f; |
406 | 0 | return use_high_hp; |
407 | 0 | } |
408 | | #endif // !CONFIG_REALTIME_ONLY |
409 | | |
410 | 0 | void av1_pick_and_set_high_precision_mv(AV1_COMP *cpi, int qindex) { |
411 | 0 | int use_hp = qindex < HIGH_PRECISION_MV_QTHRESH; |
412 | 0 | #if !CONFIG_REALTIME_ONLY |
413 | 0 | MV_STATS *mv_stats = &cpi->mv_stats; |
414 | 0 | #endif // !CONFIG_REALTIME_ONLY |
415 | |
|
416 | 0 | if (cpi->sf.hl_sf.high_precision_mv_usage == QTR_ONLY) { |
417 | 0 | use_hp = 0; |
418 | 0 | } |
419 | 0 | #if !CONFIG_REALTIME_ONLY |
420 | 0 | else if (cpi->sf.hl_sf.high_precision_mv_usage == LAST_MV_DATA && |
421 | 0 | av1_frame_allows_smart_mv(cpi) && mv_stats->valid) { |
422 | 0 | use_hp = get_smart_mv_prec(cpi, mv_stats, qindex); |
423 | 0 | } |
424 | 0 | #endif // !CONFIG_REALTIME_ONLY |
425 | |
|
426 | 0 | av1_set_high_precision_mv(cpi, use_hp, |
427 | 0 | cpi->common.features.cur_frame_force_integer_mv); |
428 | 0 | } |