Prune inter modes based on tpl stats

This patch introduces a speed feature ‘prune_inter_modes_based_on_tpl’
to prune inter modes using inter cost obtained during tpl.
This speed feature is enabled for cpu-used 5 for non-boosted frames.

cpu-used   Instruction Count    BD-Rate Drop
              Reduction     avg.psnr  ovr.psnr  ssim
    5          3.9%           0.23%    0.24%    0.23%

STATS_CHANGED

Change-Id: Id4ba193d17e84c11283e411c259e3d010540760d
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 8756531..eceaf25 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -3771,9 +3771,6 @@
 #undef NUM_SIMPLE_MOTION_FEATURES
 
 #if !CONFIG_REALTIME_ONLY
-static INLINE int coded_to_superres_mi(int mi_col, int denom) {
-  return (mi_col * denom + SCALE_NUMERATOR / 2) / SCALE_NUMERATOR;
-}
 
 static int get_rdmult_delta(AV1_COMP *cpi, BLOCK_SIZE bsize, int analysis_type,
                             int mi_row, int mi_col, int orig_rdmult) {
@@ -4460,8 +4457,28 @@
     for (int col = mi_col; col < mi_col_end; col += step) {
       TplDepStats *this_stats =
           &tpl_stats[av1_tpl_ptr_pos(cpi, row, col, tpl_stride)];
+      int64_t tpl_pred_error[INTER_REFS_PER_FRAME];
+      memset(tpl_pred_error, 0, sizeof(tpl_pred_error));
+
+      // Find the winner ref frame idx for the current block
+      int64_t best_inter_cost = this_stats->pred_error[0];
+      int best_rf_idx = 0;
+      for (int idx = 1; idx < INTER_REFS_PER_FRAME; ++idx) {
+        if ((this_stats->pred_error[idx] < best_inter_cost) &&
+            (this_stats->pred_error[idx] != 0)) {
+          best_inter_cost = this_stats->pred_error[idx];
+          best_rf_idx = idx;
+        }
+      }
+      // Populate tpl_pred_error of
+      // 1. LAST_FRAME
+      // 2. best_ref w.r.t. LAST_FRAME.
+      tpl_pred_error[LAST_FRAME - 1] = this_stats->pred_error[LAST_FRAME - 1];
+      tpl_pred_error[best_rf_idx] =
+          this_stats->pred_error[best_rf_idx] - tpl_pred_error[LAST_FRAME - 1];
+
       for (int rf_idx = 0; rf_idx < INTER_REFS_PER_FRAME; ++rf_idx)
-        inter_cost[rf_idx] += this_stats->pred_error[rf_idx];
+        inter_cost[rf_idx] += tpl_pred_error[rf_idx];
     }
   }
 
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 4490591..d188b7b 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -13,6 +13,7 @@
 #include <math.h>
 #include <stdbool.h>
 
+#include "config/aom_config.h"
 #include "config/aom_dsp_rtcd.h"
 #include "config/av1_rtcd.h"
 
@@ -2049,6 +2050,167 @@
   int num_motion_mode_cand;
 } motion_mode_best_st_candidate;
 
+// Checks if the current reference frame matches with neighbouring block's
+// (top/left) reference frames
+static AOM_INLINE int ref_match_found_in_nb_blocks(MB_MODE_INFO *cur_mbmi,
+                                                   MB_MODE_INFO *nb_mbmi) {
+  MV_REFERENCE_FRAME nb_ref_frames[2] = { nb_mbmi->ref_frame[0],
+                                          nb_mbmi->ref_frame[1] };
+  MV_REFERENCE_FRAME cur_ref_frames[2] = { cur_mbmi->ref_frame[0],
+                                           cur_mbmi->ref_frame[1] };
+  const int is_cur_comp_pred = has_second_ref(cur_mbmi);
+  int match_found = 0;
+
+  for (int i = 0; i < (is_cur_comp_pred + 1); i++) {
+    if ((cur_ref_frames[i] == nb_ref_frames[0]) ||
+        (cur_ref_frames[i] == nb_ref_frames[1]))
+      match_found = 1;
+  }
+  return match_found;
+}
+
+static AOM_INLINE int find_ref_match_in_above_nbs(const int total_mi_cols,
+                                                  MACROBLOCKD *xd) {
+  if (!xd->up_available) return 0;
+  const int mi_col = xd->mi_col;
+  MB_MODE_INFO **cur_mbmi = xd->mi;
+  // prev_row_mi points into the mi array, starting at the beginning of the
+  // previous row.
+  MB_MODE_INFO **prev_row_mi = xd->mi - mi_col - 1 * xd->mi_stride;
+  const int end_col = AOMMIN(mi_col + xd->n4_w, total_mi_cols);
+  uint8_t mi_step;
+  for (int above_mi_col = mi_col; above_mi_col < end_col;
+       above_mi_col += mi_step) {
+    MB_MODE_INFO **above_mi = prev_row_mi + above_mi_col;
+    mi_step = mi_size_wide[above_mi[0]->sb_type];
+    int match_found = 0;
+    if (is_inter_block(*above_mi))
+      match_found = ref_match_found_in_nb_blocks(*cur_mbmi, *above_mi);
+    if (match_found) return 1;
+  }
+  return 0;
+}
+
+static AOM_INLINE int find_ref_match_in_left_nbs(const int total_mi_rows,
+                                                 MACROBLOCKD *xd) {
+  if (!xd->left_available) return 0;
+  const int mi_row = xd->mi_row;
+  MB_MODE_INFO **cur_mbmi = xd->mi;
+  // prev_col_mi points into the mi array, starting at the top of the
+  // previous column
+  MB_MODE_INFO **prev_col_mi = xd->mi - 1 - mi_row * xd->mi_stride;
+  const int end_row = AOMMIN(mi_row + xd->n4_h, total_mi_rows);
+  uint8_t mi_step;
+  for (int left_mi_row = mi_row; left_mi_row < end_row;
+       left_mi_row += mi_step) {
+    MB_MODE_INFO **left_mi = prev_col_mi + left_mi_row * xd->mi_stride;
+    mi_step = mi_size_high[left_mi[0]->sb_type];
+    int match_found = 0;
+    if (is_inter_block(*left_mi))
+      match_found = ref_match_found_in_nb_blocks(*cur_mbmi, *left_mi);
+    if (match_found) return 1;
+  }
+  return 0;
+}
+
+typedef struct {
+  int64_t best_inter_cost;
+  int64_t ref_inter_cost[INTER_REFS_PER_FRAME];
+} PruneInfoFromTpl;
+
+#if !CONFIG_REALTIME_ONLY
+// TODO(Remya): Check if get_tpl_stats_b() can be reused
+static AOM_INLINE void get_block_level_tpl_stats(
+    AV1_COMP *cpi, BLOCK_SIZE bsize, int mi_row, int mi_col, int *valid_refs,
+    PruneInfoFromTpl *inter_cost_info_from_tpl) {
+  const GF_GROUP *const gf_group = &cpi->gf_group;
+  AV1_COMMON *const cm = &cpi->common;
+
+  assert(IMPLIES(gf_group->size > 0, gf_group->index < gf_group->size));
+  const int tpl_idx = gf_group->index;
+  TplDepFrame *tpl_frame = &cpi->tpl_frame[tpl_idx];
+  TplDepStats *tpl_stats = tpl_frame->tpl_stats_ptr;
+
+  const int mi_wide = mi_size_wide[bsize];
+  const int mi_high = mi_size_high[bsize];
+  if (tpl_frame->is_valid) {
+    int64_t best_inter_cost = INT64_MAX;
+    int tpl_stride = tpl_frame->stride;
+    const int step = 1 << cpi->tpl_stats_block_mis_log2;
+    const int mi_col_sr =
+        coded_to_superres_mi(mi_col, cm->superres_scale_denominator);
+    const int mi_col_end_sr =
+        coded_to_superres_mi(mi_col + mi_wide, cm->superres_scale_denominator);
+    const int mi_cols_sr = av1_pixels_to_mi(cm->superres_upscaled_width);
+
+    for (int row = mi_row; row < mi_row + mi_high; row += step) {
+      for (int col = mi_col_sr; col < mi_col_end_sr; col += step) {
+        if (row >= cm->mi_rows || col >= mi_cols_sr) continue;
+        TplDepStats *this_stats =
+            &tpl_stats[av1_tpl_ptr_pos(cpi, row, col, tpl_stride)];
+
+        // Sums up the inter cost of corresponding ref frames
+        for (int ref_idx = 0; ref_idx < INTER_REFS_PER_FRAME; ref_idx++) {
+          inter_cost_info_from_tpl->ref_inter_cost[ref_idx] +=
+              this_stats->pred_error[ref_idx];
+        }
+      }
+    }
+
+    // Computes the best inter cost (minimum inter_cost)
+    for (int ref_idx = 0; ref_idx < INTER_REFS_PER_FRAME; ref_idx++) {
+      int64_t cur_inter_cost =
+          inter_cost_info_from_tpl->ref_inter_cost[ref_idx];
+      // For invalid ref frames, cur_inter_cost = 0 and has to be handled while
+      // calculating the minimum inter_cost
+      if (cur_inter_cost != 0 && (cur_inter_cost < best_inter_cost) &&
+          (valid_refs[ref_idx]))
+        best_inter_cost = cur_inter_cost;
+    }
+    inter_cost_info_from_tpl->best_inter_cost = best_inter_cost;
+  }
+}
+#endif
+
+static AOM_INLINE int prune_modes_based_on_tpl_stats(
+    PruneInfoFromTpl *inter_cost_info_from_tpl, const int *refs, int ref_mv_idx,
+    const PREDICTION_MODE this_mode) {
+  int64_t cur_inter_cost;
+
+  int is_globalmv = (this_mode == GLOBALMV) || (this_mode == GLOBAL_GLOBALMV);
+  int prune_index = is_globalmv ? MAX_REF_MV_SEARCH : ref_mv_idx;
+
+  // Thresholds used for pruning:
+  // Lower value indicates aggressive pruning for ref_mv indices 1, 2 and for
+  // GLOBAL/GLOBAL_GLOBALMV. Higher value indicates conservative pruning for
+  // ref_mv_idx 0. 'prune_index' 0, 1, 2 corresponds to ref_mv indices 0, 1
+  // and 2. prune_index 3 corresponds to GLOBALMV/GLOBAL_GLOBALMV
+  const int tpl_inter_mode_prune_mul_factor[MAX_REF_MV_SEARCH + 1] = { 3, 2, 2,
+                                                                       2 };
+
+  int is_comp_pred = (refs[1] > INTRA_FRAME);
+
+  if (!is_comp_pred) {
+    cur_inter_cost = inter_cost_info_from_tpl->ref_inter_cost[refs[0] - 1];
+  } else {
+    int64_t inter_cost_ref0 =
+        inter_cost_info_from_tpl->ref_inter_cost[refs[0] - 1];
+    int64_t inter_cost_ref1 =
+        inter_cost_info_from_tpl->ref_inter_cost[refs[1] - 1];
+    // Choose maximum inter_cost among inter_cost_ref0 and inter_cost_ref1 for
+    // more aggressive pruning
+    cur_inter_cost = AOMMAX(inter_cost_ref0, inter_cost_ref1);
+  }
+
+  // Prune the mode if cur_inter_cost is greater than threshold times
+  // best_inter_cost
+  int64_t best_inter_cost = inter_cost_info_from_tpl->best_inter_cost;
+  if (cur_inter_cost >
+      ((tpl_inter_mode_prune_mul_factor[prune_index] * best_inter_cost) >> 1))
+    return 1;
+  return 0;
+}
+
 static int64_t handle_inter_mode(
     AV1_COMP *const cpi, TileDataEnc *tile_data, MACROBLOCK *x,
     BLOCK_SIZE bsize, RD_STATS *rd_stats, RD_STATS *rd_stats_y,
@@ -2056,7 +2218,8 @@
     int64_t ref_best_rd, uint8_t *const tmp_buf,
     const CompoundTypeRdBuffers *rd_buffers, int64_t *best_est_rd,
     const int do_tx_search, InterModesInfo *inter_modes_info,
-    motion_mode_candidate *motion_mode_cand, int64_t *skip_rd) {
+    motion_mode_candidate *motion_mode_cand, int64_t *skip_rd,
+    PruneInfoFromTpl *inter_cost_info_from_tpl) {
   const AV1_COMMON *cm = &cpi->common;
   const int num_planes = av1_num_planes(cm);
   MACROBLOCKD *xd = &x->e_mbd;
@@ -2064,6 +2227,12 @@
   MB_MODE_INFO_EXT *const mbmi_ext = x->mbmi_ext;
   const int is_comp_pred = has_second_ref(mbmi);
   const PREDICTION_MODE this_mode = mbmi->mode;
+
+  const GF_GROUP *const gf_group = &cpi->gf_group;
+  const int tpl_idx = gf_group->index;
+  TplDepFrame *tpl_frame = &cpi->tpl_frame[tpl_idx];
+  int prune_modes_based_on_tpl =
+      cpi->sf.inter_sf.prune_inter_modes_based_on_tpl && tpl_frame->is_valid;
   int i;
   const int refs[2] = { mbmi->ref_frame[0],
                         (mbmi->ref_frame[1] < 0 ? 0 : mbmi->ref_frame[1]) };
@@ -2100,6 +2269,18 @@
   int mode_search_mask = (1 << COMPOUND_AVERAGE) | (1 << COMPOUND_DISTWTD) |
                          (1 << COMPOUND_WEDGE) | (1 << COMPOUND_DIFFWTD);
 
+  // Do not prune the mode based on inter cost from tpl if the current ref frame
+  // is the winner ref in neighbouring blocks.
+  int ref_match_found_in_above_nb = 0;
+  int ref_match_found_in_left_nb = 0;
+  if (prune_modes_based_on_tpl) {
+    const int total_mi_cols = cm->mi_cols;
+    ref_match_found_in_above_nb =
+        find_ref_match_in_above_nbs(total_mi_cols, xd);
+    const int total_mi_rows = cm->mi_rows;
+    ref_match_found_in_left_nb = find_ref_match_in_left_nbs(total_mi_rows, xd);
+  }
+
   // First, perform a simple translation search for each of the indices. If
   // an index performs well, it will be fully searched here.
   const int ref_set = get_drl_refmv_count(x, mbmi->ref_frame, this_mode);
@@ -2121,6 +2302,12 @@
       // MV did not perform well in simple translation search. Skip it.
       continue;
     }
+    if (prune_modes_based_on_tpl && !ref_match_found_in_above_nb &&
+        !ref_match_found_in_left_nb) {
+      if (prune_modes_based_on_tpl_stats(inter_cost_info_from_tpl, refs,
+                                         ref_mv_idx, this_mode))
+        continue;
+    }
     av1_init_rd_stats(rd_stats);
 
     mbmi->interinter_comp.type = COMPOUND_AVERAGE;
@@ -4077,6 +4264,37 @@
   // Need to tweak the threshold for hdres speed 0 & 1.
   const int mi_row = xd->mi_row;
   const int mi_col = xd->mi_col;
+  // x->search_ref_frame[id] = 1 => no pruning in
+  // prune_ref_by_selective_ref_frame()
+  // x->search_ref_frame[id] = 0  => ref frame can be pruned in
+  // prune_ref_by_selective_ref_frame()
+  // Populating valid_refs[idx] = 1 ensures that
+  // 'inter_cost_info_from_tpl.best_inter_cost' does not correspond to a pruned
+  // ref frame
+  int valid_refs[INTER_REFS_PER_FRAME];
+  memset(valid_refs, 0, sizeof(valid_refs));
+
+  for (MV_REFERENCE_FRAME frame = LAST_FRAME; frame < REF_FRAMES; frame++) {
+    MV_REFERENCE_FRAME refs[2] = { frame, NONE_FRAME };
+    valid_refs[frame - 1] = x->search_ref_frame[frame];
+    if (!valid_refs[frame - 1]) {
+      valid_refs[frame - 1] = (!prune_ref_by_selective_ref_frame(
+          cpi, x, refs, cm->cur_frame->ref_display_order_hint,
+          cm->current_frame.display_order_hint));
+    }
+  }
+
+  // Obtain the relevant tpl stats for pruning inter modes
+  PruneInfoFromTpl inter_cost_info_from_tpl;
+  inter_cost_info_from_tpl.best_inter_cost = 0;
+  memset(inter_cost_info_from_tpl.ref_inter_cost, 0,
+         sizeof(inter_cost_info_from_tpl.ref_inter_cost));
+#if !CONFIG_REALTIME_ONLY
+  if (cpi->sf.inter_sf.prune_inter_modes_based_on_tpl) {
+    get_block_level_tpl_stats(cpi, bsize, mi_row, mi_col, valid_refs,
+                              &inter_cost_info_from_tpl);
+  }
+#endif
   const int do_pruning =
       (AOMMIN(cm->width, cm->height) > 480 && cpi->speed <= 1) ? 0 : 1;
   if (do_pruning && sf->intra_sf.skip_intra_in_interframe) {
@@ -4250,7 +4468,7 @@
         cpi, tile_data, x, bsize, &rd_stats, &rd_stats_y, &rd_stats_uv,
         &disable_skip, &args, ref_best_rd, tmp_buf, &x->comp_rd_buffer,
         &best_est_rd, do_tx_search, inter_modes_info, &motion_mode_cand,
-        &skip_rd);
+        &skip_rd, &inter_cost_info_from_tpl);
 
     if (sf->inter_sf.prune_comp_search_by_single_result > 0 &&
         is_inter_singleref_mode(this_mode) && args.single_ref_first_pass) {
diff --git a/av1/encoder/rdopt.h b/av1/encoder/rdopt.h
index 22f83ff..986ac28 100644
--- a/av1/encoder/rdopt.h
+++ b/av1/encoder/rdopt.h
@@ -118,6 +118,12 @@
 void av1_inter_mode_data_init(struct TileDataEnc *tile_data);
 void av1_inter_mode_data_fit(TileDataEnc *tile_data, int rdmult);
 
+#if !CONFIG_REALTIME_ONLY
+static INLINE int coded_to_superres_mi(int mi_col, int denom) {
+  return (mi_col * denom + SCALE_NUMERATOR / 2) / SCALE_NUMERATOR;
+}
+#endif
+
 static INLINE int av1_encoder_get_relative_dist(const OrderHintInfo *oh, int a,
                                                 int b) {
   if (!oh->enable_order_hint) return 0;
diff --git a/av1/encoder/speed_features.c b/av1/encoder/speed_features.c
index 6b0e1a6..5dd2e6c 100644
--- a/av1/encoder/speed_features.c
+++ b/av1/encoder/speed_features.c
@@ -573,6 +573,7 @@
     sf->part_sf.ext_partition_eval_thresh =
         cm->allow_screen_content_tools ? BLOCK_8X8 : BLOCK_16X16;
 
+    sf->inter_sf.prune_inter_modes_based_on_tpl = boosted ? 0 : 1;
     sf->inter_sf.disable_interinter_wedge = 1;
     sf->inter_sf.disable_obmc = 1;
     sf->inter_sf.disable_onesided_comp = 1;
@@ -968,6 +969,7 @@
   inter_sf->reuse_inter_intra_mode = 0;
   inter_sf->disable_sb_level_coeff_cost_upd = 0;
   inter_sf->disable_sb_level_mv_cost_upd = 0;
+  inter_sf->prune_inter_modes_based_on_tpl = 0;
   inter_sf->prune_comp_search_by_single_result = 0;
   inter_sf->skip_repeated_ref_mv = 0;
   inter_sf->skip_repeated_newmv = 0;
diff --git a/av1/encoder/speed_features.h b/av1/encoder/speed_features.h
index d610383..e45c18c 100644
--- a/av1/encoder/speed_features.h
+++ b/av1/encoder/speed_features.h
@@ -656,6 +656,9 @@
   // cpi->oxcf.coeff_cost_upd_freq = COST_UPD_SB (i.e. set at SB level)
   int disable_sb_level_mv_cost_upd;
 
+  // Prune inter modes based on tpl stats
+  int prune_inter_modes_based_on_tpl;
+
   // Model based breakout after interpolation filter search
   // 0: no breakout
   // 1: use model based rd breakout
diff --git a/av1/encoder/tpl_model.c b/av1/encoder/tpl_model.c
index 7917213..2c50db9 100644
--- a/av1/encoder/tpl_model.c
+++ b/av1/encoder/tpl_model.c
@@ -378,13 +378,13 @@
 
     inter_cost = tpl_get_satd_cost(x, src_diff, bw, src_mb_buffer, src_stride,
                                    predictor, bw, coeff, bw, bh, tx_size);
+    // Store inter cost for each ref frame
+    tpl_stats->pred_error[rf_idx] = AOMMAX(1, inter_cost);
 
     if (inter_cost < best_inter_cost) {
       memcpy(best_coeff, coeff, sizeof(best_coeff));
       best_rf_idx = rf_idx;
 
-      if (rf_idx == 0) tpl_stats->pred_error[rf_idx] = inter_cost;
-
       best_inter_cost = inter_cost;
       best_mv.as_int = x->best_mv.as_int;
       if (best_inter_cost < best_intra_cost) {
@@ -395,11 +395,6 @@
     }
   }
 
-  if (best_rf_idx >= 0) {
-    tpl_stats->pred_error[best_rf_idx] =
-        best_inter_cost - tpl_stats->pred_error[0];
-  }
-
   if (best_inter_cost < INT64_MAX) {
     uint16_t eob;
     get_quantize_error(x, 0, best_coeff, qcoeff, dqcoeff, tx_size, &eob,