Prune inter modes based on tpl stats
This patch introduces a speed feature ‘prune_inter_modes_based_on_tpl’
to prune inter modes using inter cost obtained during tpl.
This speed feature is enabled for cpu-used 5 for non-boosted frames.
cpu-used Instruction Count BD-Rate Drop
Reduction avg.psnr ovr.psnr ssim
5 3.9% 0.23% 0.24% 0.23%
STATS_CHANGED
Change-Id: Id4ba193d17e84c11283e411c259e3d010540760d
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 8756531..eceaf25 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -3771,9 +3771,6 @@
#undef NUM_SIMPLE_MOTION_FEATURES
#if !CONFIG_REALTIME_ONLY
-static INLINE int coded_to_superres_mi(int mi_col, int denom) {
- return (mi_col * denom + SCALE_NUMERATOR / 2) / SCALE_NUMERATOR;
-}
static int get_rdmult_delta(AV1_COMP *cpi, BLOCK_SIZE bsize, int analysis_type,
int mi_row, int mi_col, int orig_rdmult) {
@@ -4460,8 +4457,28 @@
for (int col = mi_col; col < mi_col_end; col += step) {
TplDepStats *this_stats =
&tpl_stats[av1_tpl_ptr_pos(cpi, row, col, tpl_stride)];
+ int64_t tpl_pred_error[INTER_REFS_PER_FRAME];
+ memset(tpl_pred_error, 0, sizeof(tpl_pred_error));
+
+ // Find the winner ref frame idx for the current block
+ int64_t best_inter_cost = this_stats->pred_error[0];
+ int best_rf_idx = 0;
+ for (int idx = 1; idx < INTER_REFS_PER_FRAME; ++idx) {
+ if ((this_stats->pred_error[idx] < best_inter_cost) &&
+ (this_stats->pred_error[idx] != 0)) {
+ best_inter_cost = this_stats->pred_error[idx];
+ best_rf_idx = idx;
+ }
+ }
+ // Populate tpl_pred_error of
+ // 1. LAST_FRAME
+ // 2. best_ref w.r.t. LAST_FRAME.
+ tpl_pred_error[LAST_FRAME - 1] = this_stats->pred_error[LAST_FRAME - 1];
+ tpl_pred_error[best_rf_idx] =
+ this_stats->pred_error[best_rf_idx] - tpl_pred_error[LAST_FRAME - 1];
+
for (int rf_idx = 0; rf_idx < INTER_REFS_PER_FRAME; ++rf_idx)
- inter_cost[rf_idx] += this_stats->pred_error[rf_idx];
+ inter_cost[rf_idx] += tpl_pred_error[rf_idx];
}
}
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 4490591..d188b7b 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -13,6 +13,7 @@
#include <math.h>
#include <stdbool.h>
+#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"
#include "config/av1_rtcd.h"
@@ -2049,6 +2050,167 @@
int num_motion_mode_cand;
} motion_mode_best_st_candidate;
+// Checks if the current reference frame matches with neighbouring block's
+// (top/left) reference frames
+static AOM_INLINE int ref_match_found_in_nb_blocks(MB_MODE_INFO *cur_mbmi,
+ MB_MODE_INFO *nb_mbmi) {
+ MV_REFERENCE_FRAME nb_ref_frames[2] = { nb_mbmi->ref_frame[0],
+ nb_mbmi->ref_frame[1] };
+ MV_REFERENCE_FRAME cur_ref_frames[2] = { cur_mbmi->ref_frame[0],
+ cur_mbmi->ref_frame[1] };
+ const int is_cur_comp_pred = has_second_ref(cur_mbmi);
+ int match_found = 0;
+
+ for (int i = 0; i < (is_cur_comp_pred + 1); i++) {
+ if ((cur_ref_frames[i] == nb_ref_frames[0]) ||
+ (cur_ref_frames[i] == nb_ref_frames[1]))
+ match_found = 1;
+ }
+ return match_found;
+}
+
+static AOM_INLINE int find_ref_match_in_above_nbs(const int total_mi_cols,
+ MACROBLOCKD *xd) {
+ if (!xd->up_available) return 0;
+ const int mi_col = xd->mi_col;
+ MB_MODE_INFO **cur_mbmi = xd->mi;
+ // prev_row_mi points into the mi array, starting at the beginning of the
+ // previous row.
+ MB_MODE_INFO **prev_row_mi = xd->mi - mi_col - 1 * xd->mi_stride;
+ const int end_col = AOMMIN(mi_col + xd->n4_w, total_mi_cols);
+ uint8_t mi_step;
+ for (int above_mi_col = mi_col; above_mi_col < end_col;
+ above_mi_col += mi_step) {
+ MB_MODE_INFO **above_mi = prev_row_mi + above_mi_col;
+ mi_step = mi_size_wide[above_mi[0]->sb_type];
+ int match_found = 0;
+ if (is_inter_block(*above_mi))
+ match_found = ref_match_found_in_nb_blocks(*cur_mbmi, *above_mi);
+ if (match_found) return 1;
+ }
+ return 0;
+}
+
+static AOM_INLINE int find_ref_match_in_left_nbs(const int total_mi_rows,
+ MACROBLOCKD *xd) {
+ if (!xd->left_available) return 0;
+ const int mi_row = xd->mi_row;
+ MB_MODE_INFO **cur_mbmi = xd->mi;
+ // prev_col_mi points into the mi array, starting at the top of the
+ // previous column
+ MB_MODE_INFO **prev_col_mi = xd->mi - 1 - mi_row * xd->mi_stride;
+ const int end_row = AOMMIN(mi_row + xd->n4_h, total_mi_rows);
+ uint8_t mi_step;
+ for (int left_mi_row = mi_row; left_mi_row < end_row;
+ left_mi_row += mi_step) {
+ MB_MODE_INFO **left_mi = prev_col_mi + left_mi_row * xd->mi_stride;
+ mi_step = mi_size_high[left_mi[0]->sb_type];
+ int match_found = 0;
+ if (is_inter_block(*left_mi))
+ match_found = ref_match_found_in_nb_blocks(*cur_mbmi, *left_mi);
+ if (match_found) return 1;
+ }
+ return 0;
+}
+
+typedef struct {
+ int64_t best_inter_cost;
+ int64_t ref_inter_cost[INTER_REFS_PER_FRAME];
+} PruneInfoFromTpl;
+
+#if !CONFIG_REALTIME_ONLY
+// TODO(Remya): Check if get_tpl_stats_b() can be reused
+static AOM_INLINE void get_block_level_tpl_stats(
+ AV1_COMP *cpi, BLOCK_SIZE bsize, int mi_row, int mi_col, int *valid_refs,
+ PruneInfoFromTpl *inter_cost_info_from_tpl) {
+ const GF_GROUP *const gf_group = &cpi->gf_group;
+ AV1_COMMON *const cm = &cpi->common;
+
+ assert(IMPLIES(gf_group->size > 0, gf_group->index < gf_group->size));
+ const int tpl_idx = gf_group->index;
+ TplDepFrame *tpl_frame = &cpi->tpl_frame[tpl_idx];
+ TplDepStats *tpl_stats = tpl_frame->tpl_stats_ptr;
+
+ const int mi_wide = mi_size_wide[bsize];
+ const int mi_high = mi_size_high[bsize];
+ if (tpl_frame->is_valid) {
+ int64_t best_inter_cost = INT64_MAX;
+ int tpl_stride = tpl_frame->stride;
+ const int step = 1 << cpi->tpl_stats_block_mis_log2;
+ const int mi_col_sr =
+ coded_to_superres_mi(mi_col, cm->superres_scale_denominator);
+ const int mi_col_end_sr =
+ coded_to_superres_mi(mi_col + mi_wide, cm->superres_scale_denominator);
+ const int mi_cols_sr = av1_pixels_to_mi(cm->superres_upscaled_width);
+
+ for (int row = mi_row; row < mi_row + mi_high; row += step) {
+ for (int col = mi_col_sr; col < mi_col_end_sr; col += step) {
+ if (row >= cm->mi_rows || col >= mi_cols_sr) continue;
+ TplDepStats *this_stats =
+ &tpl_stats[av1_tpl_ptr_pos(cpi, row, col, tpl_stride)];
+
+ // Sums up the inter cost of corresponding ref frames
+ for (int ref_idx = 0; ref_idx < INTER_REFS_PER_FRAME; ref_idx++) {
+ inter_cost_info_from_tpl->ref_inter_cost[ref_idx] +=
+ this_stats->pred_error[ref_idx];
+ }
+ }
+ }
+
+ // Computes the best inter cost (minimum inter_cost)
+ for (int ref_idx = 0; ref_idx < INTER_REFS_PER_FRAME; ref_idx++) {
+ int64_t cur_inter_cost =
+ inter_cost_info_from_tpl->ref_inter_cost[ref_idx];
+ // For invalid ref frames, cur_inter_cost = 0 and has to be handled while
+ // calculating the minimum inter_cost
+ if (cur_inter_cost != 0 && (cur_inter_cost < best_inter_cost) &&
+ (valid_refs[ref_idx]))
+ best_inter_cost = cur_inter_cost;
+ }
+ inter_cost_info_from_tpl->best_inter_cost = best_inter_cost;
+ }
+}
+#endif
+
+static AOM_INLINE int prune_modes_based_on_tpl_stats(
+ PruneInfoFromTpl *inter_cost_info_from_tpl, const int *refs, int ref_mv_idx,
+ const PREDICTION_MODE this_mode) {
+ int64_t cur_inter_cost;
+
+ int is_globalmv = (this_mode == GLOBALMV) || (this_mode == GLOBAL_GLOBALMV);
+ int prune_index = is_globalmv ? MAX_REF_MV_SEARCH : ref_mv_idx;
+
+ // Thresholds used for pruning:
+ // Lower value indicates aggressive pruning for ref_mv indices 1, 2 and for
+ // GLOBAL/GLOBAL_GLOBALMV. Higher value indicates conservative pruning for
+ // ref_mv_idx 0. 'prune_index' 0, 1, 2 corresponds to ref_mv indices 0, 1
+ // and 2. prune_index 3 corresponds to GLOBALMV/GLOBAL_GLOBALMV
+ const int tpl_inter_mode_prune_mul_factor[MAX_REF_MV_SEARCH + 1] = { 3, 2, 2,
+ 2 };
+
+ int is_comp_pred = (refs[1] > INTRA_FRAME);
+
+ if (!is_comp_pred) {
+ cur_inter_cost = inter_cost_info_from_tpl->ref_inter_cost[refs[0] - 1];
+ } else {
+ int64_t inter_cost_ref0 =
+ inter_cost_info_from_tpl->ref_inter_cost[refs[0] - 1];
+ int64_t inter_cost_ref1 =
+ inter_cost_info_from_tpl->ref_inter_cost[refs[1] - 1];
+ // Choose maximum inter_cost among inter_cost_ref0 and inter_cost_ref1 for
+ // more aggressive pruning
+ cur_inter_cost = AOMMAX(inter_cost_ref0, inter_cost_ref1);
+ }
+
+ // Prune the mode if cur_inter_cost is greater than threshold times
+ // best_inter_cost
+ int64_t best_inter_cost = inter_cost_info_from_tpl->best_inter_cost;
+ if (cur_inter_cost >
+ ((tpl_inter_mode_prune_mul_factor[prune_index] * best_inter_cost) >> 1))
+ return 1;
+ return 0;
+}
+
static int64_t handle_inter_mode(
AV1_COMP *const cpi, TileDataEnc *tile_data, MACROBLOCK *x,
BLOCK_SIZE bsize, RD_STATS *rd_stats, RD_STATS *rd_stats_y,
@@ -2056,7 +2218,8 @@
int64_t ref_best_rd, uint8_t *const tmp_buf,
const CompoundTypeRdBuffers *rd_buffers, int64_t *best_est_rd,
const int do_tx_search, InterModesInfo *inter_modes_info,
- motion_mode_candidate *motion_mode_cand, int64_t *skip_rd) {
+ motion_mode_candidate *motion_mode_cand, int64_t *skip_rd,
+ PruneInfoFromTpl *inter_cost_info_from_tpl) {
const AV1_COMMON *cm = &cpi->common;
const int num_planes = av1_num_planes(cm);
MACROBLOCKD *xd = &x->e_mbd;
@@ -2064,6 +2227,12 @@
MB_MODE_INFO_EXT *const mbmi_ext = x->mbmi_ext;
const int is_comp_pred = has_second_ref(mbmi);
const PREDICTION_MODE this_mode = mbmi->mode;
+
+ const GF_GROUP *const gf_group = &cpi->gf_group;
+ const int tpl_idx = gf_group->index;
+ TplDepFrame *tpl_frame = &cpi->tpl_frame[tpl_idx];
+ int prune_modes_based_on_tpl =
+ cpi->sf.inter_sf.prune_inter_modes_based_on_tpl && tpl_frame->is_valid;
int i;
const int refs[2] = { mbmi->ref_frame[0],
(mbmi->ref_frame[1] < 0 ? 0 : mbmi->ref_frame[1]) };
@@ -2100,6 +2269,18 @@
int mode_search_mask = (1 << COMPOUND_AVERAGE) | (1 << COMPOUND_DISTWTD) |
(1 << COMPOUND_WEDGE) | (1 << COMPOUND_DIFFWTD);
+ // Do not prune the mode based on inter cost from tpl if the current ref frame
+ // is the winner ref in neighbouring blocks.
+ int ref_match_found_in_above_nb = 0;
+ int ref_match_found_in_left_nb = 0;
+ if (prune_modes_based_on_tpl) {
+ const int total_mi_cols = cm->mi_cols;
+ ref_match_found_in_above_nb =
+ find_ref_match_in_above_nbs(total_mi_cols, xd);
+ const int total_mi_rows = cm->mi_rows;
+ ref_match_found_in_left_nb = find_ref_match_in_left_nbs(total_mi_rows, xd);
+ }
+
// First, perform a simple translation search for each of the indices. If
// an index performs well, it will be fully searched here.
const int ref_set = get_drl_refmv_count(x, mbmi->ref_frame, this_mode);
@@ -2121,6 +2302,12 @@
// MV did not perform well in simple translation search. Skip it.
continue;
}
+ if (prune_modes_based_on_tpl && !ref_match_found_in_above_nb &&
+ !ref_match_found_in_left_nb) {
+ if (prune_modes_based_on_tpl_stats(inter_cost_info_from_tpl, refs,
+ ref_mv_idx, this_mode))
+ continue;
+ }
av1_init_rd_stats(rd_stats);
mbmi->interinter_comp.type = COMPOUND_AVERAGE;
@@ -4077,6 +4264,37 @@
// Need to tweak the threshold for hdres speed 0 & 1.
const int mi_row = xd->mi_row;
const int mi_col = xd->mi_col;
+ // x->search_ref_frame[id] = 1 => no pruning in
+ // prune_ref_by_selective_ref_frame()
+ // x->search_ref_frame[id] = 0 => ref frame can be pruned in
+ // prune_ref_by_selective_ref_frame()
+ // Populating valid_refs[idx] = 1 ensures that
+ // 'inter_cost_info_from_tpl.best_inter_cost' does not correspond to a pruned
+ // ref frame
+ int valid_refs[INTER_REFS_PER_FRAME];
+ memset(valid_refs, 0, sizeof(valid_refs));
+
+ for (MV_REFERENCE_FRAME frame = LAST_FRAME; frame < REF_FRAMES; frame++) {
+ MV_REFERENCE_FRAME refs[2] = { frame, NONE_FRAME };
+ valid_refs[frame - 1] = x->search_ref_frame[frame];
+ if (!valid_refs[frame - 1]) {
+ valid_refs[frame - 1] = (!prune_ref_by_selective_ref_frame(
+ cpi, x, refs, cm->cur_frame->ref_display_order_hint,
+ cm->current_frame.display_order_hint));
+ }
+ }
+
+ // Obtain the relevant tpl stats for pruning inter modes
+ PruneInfoFromTpl inter_cost_info_from_tpl;
+ inter_cost_info_from_tpl.best_inter_cost = 0;
+ memset(inter_cost_info_from_tpl.ref_inter_cost, 0,
+ sizeof(inter_cost_info_from_tpl.ref_inter_cost));
+#if !CONFIG_REALTIME_ONLY
+ if (cpi->sf.inter_sf.prune_inter_modes_based_on_tpl) {
+ get_block_level_tpl_stats(cpi, bsize, mi_row, mi_col, valid_refs,
+ &inter_cost_info_from_tpl);
+ }
+#endif
const int do_pruning =
(AOMMIN(cm->width, cm->height) > 480 && cpi->speed <= 1) ? 0 : 1;
if (do_pruning && sf->intra_sf.skip_intra_in_interframe) {
@@ -4250,7 +4468,7 @@
cpi, tile_data, x, bsize, &rd_stats, &rd_stats_y, &rd_stats_uv,
&disable_skip, &args, ref_best_rd, tmp_buf, &x->comp_rd_buffer,
&best_est_rd, do_tx_search, inter_modes_info, &motion_mode_cand,
- &skip_rd);
+ &skip_rd, &inter_cost_info_from_tpl);
if (sf->inter_sf.prune_comp_search_by_single_result > 0 &&
is_inter_singleref_mode(this_mode) && args.single_ref_first_pass) {
diff --git a/av1/encoder/rdopt.h b/av1/encoder/rdopt.h
index 22f83ff..986ac28 100644
--- a/av1/encoder/rdopt.h
+++ b/av1/encoder/rdopt.h
@@ -118,6 +118,12 @@
void av1_inter_mode_data_init(struct TileDataEnc *tile_data);
void av1_inter_mode_data_fit(TileDataEnc *tile_data, int rdmult);
+#if !CONFIG_REALTIME_ONLY
+static INLINE int coded_to_superres_mi(int mi_col, int denom) {
+ return (mi_col * denom + SCALE_NUMERATOR / 2) / SCALE_NUMERATOR;
+}
+#endif
+
static INLINE int av1_encoder_get_relative_dist(const OrderHintInfo *oh, int a,
int b) {
if (!oh->enable_order_hint) return 0;
diff --git a/av1/encoder/speed_features.c b/av1/encoder/speed_features.c
index 6b0e1a6..5dd2e6c 100644
--- a/av1/encoder/speed_features.c
+++ b/av1/encoder/speed_features.c
@@ -573,6 +573,7 @@
sf->part_sf.ext_partition_eval_thresh =
cm->allow_screen_content_tools ? BLOCK_8X8 : BLOCK_16X16;
+ sf->inter_sf.prune_inter_modes_based_on_tpl = boosted ? 0 : 1;
sf->inter_sf.disable_interinter_wedge = 1;
sf->inter_sf.disable_obmc = 1;
sf->inter_sf.disable_onesided_comp = 1;
@@ -968,6 +969,7 @@
inter_sf->reuse_inter_intra_mode = 0;
inter_sf->disable_sb_level_coeff_cost_upd = 0;
inter_sf->disable_sb_level_mv_cost_upd = 0;
+ inter_sf->prune_inter_modes_based_on_tpl = 0;
inter_sf->prune_comp_search_by_single_result = 0;
inter_sf->skip_repeated_ref_mv = 0;
inter_sf->skip_repeated_newmv = 0;
diff --git a/av1/encoder/speed_features.h b/av1/encoder/speed_features.h
index d610383..e45c18c 100644
--- a/av1/encoder/speed_features.h
+++ b/av1/encoder/speed_features.h
@@ -656,6 +656,9 @@
// cpi->oxcf.coeff_cost_upd_freq = COST_UPD_SB (i.e. set at SB level)
int disable_sb_level_mv_cost_upd;
+ // Prune inter modes based on tpl stats
+ int prune_inter_modes_based_on_tpl;
+
// Model based breakout after interpolation filter search
// 0: no breakout
// 1: use model based rd breakout
diff --git a/av1/encoder/tpl_model.c b/av1/encoder/tpl_model.c
index 7917213..2c50db9 100644
--- a/av1/encoder/tpl_model.c
+++ b/av1/encoder/tpl_model.c
@@ -378,13 +378,13 @@
inter_cost = tpl_get_satd_cost(x, src_diff, bw, src_mb_buffer, src_stride,
predictor, bw, coeff, bw, bh, tx_size);
+ // Store inter cost for each ref frame
+ tpl_stats->pred_error[rf_idx] = AOMMAX(1, inter_cost);
if (inter_cost < best_inter_cost) {
memcpy(best_coeff, coeff, sizeof(best_coeff));
best_rf_idx = rf_idx;
- if (rf_idx == 0) tpl_stats->pred_error[rf_idx] = inter_cost;
-
best_inter_cost = inter_cost;
best_mv.as_int = x->best_mv.as_int;
if (best_inter_cost < best_intra_cost) {
@@ -395,11 +395,6 @@
}
}
- if (best_rf_idx >= 0) {
- tpl_stats->pred_error[best_rf_idx] =
- best_inter_cost - tpl_stats->pred_error[0];
- }
-
if (best_inter_cost < INT64_MAX) {
uint16_t eob;
get_quantize_error(x, 0, best_coeff, qcoeff, dqcoeff, tx_size, &eob,