Add sf to prune compound references

This patch adds a speed feature to prune compound reference
frames based on temporal distance and pred mv sad. The speed
feature is enabled for speed 5 and 6.

cpu-used  Instruction Count     BD-Rate Loss(%)
           Reduction(%)     avg.psnr  ovr.psnr  ssim
   5          2.717         0.1067     0.0964    0.0374
   6          3.611         0.3005     0.3073    0.1812

STATS_CHANGED

Change-Id: Ie5f40b23d17573f4d5b969ca5a44aa2209241aa6
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 5a7056a..d073fd1 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -1016,8 +1016,12 @@
    * This is used to measure how viable a reference frame is.
    */
   int pred_mv_sad[REF_FRAMES];
-  //! The minimum of \ref pred_mv_sad.
-  int best_pred_mv_sad;
+  /*! \brief The minimum of \ref pred_mv_sad.
+   *
+   * Index 0 stores the minimum \ref pred_mv_sad across past reference frames.
+   * Index 1 stores the minimum \ref pred_mv_sad across future reference frames.
+   */
+  int best_pred_mv_sad[2];
   //! The sad of the 1st mv ref (nearest).
   int pred_mv0_sad[REF_FRAMES];
   //! The sad of the 2nd mv ref (near).
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 22641b7..4ec7f77 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -3836,8 +3836,8 @@
   }
 
   if (sf->inter_sf.alt_ref_search_fp) {
-    if (!cm->show_frame && x->best_pred_mv_sad < INT_MAX) {
-      int sad_thresh = x->best_pred_mv_sad + (x->best_pred_mv_sad >> 3);
+    if (!cm->show_frame && x->best_pred_mv_sad[0] < INT_MAX) {
+      int sad_thresh = x->best_pred_mv_sad[0] + (x->best_pred_mv_sad[0] >> 3);
       // Conservatively skip the modes w.r.t. BWDREF, ALTREF2 and ALTREF, if
       // those are past frames
       MV_REFERENCE_FRAME start_frame =
@@ -3862,8 +3862,8 @@
   }
 
   if (sf->rt_sf.prune_inter_modes_wrt_gf_arf_based_on_sad) {
-    if (x->best_pred_mv_sad < INT_MAX) {
-      int sad_thresh = x->best_pred_mv_sad + (x->best_pred_mv_sad >> 1);
+    if (x->best_pred_mv_sad[0] < INT_MAX) {
+      int sad_thresh = x->best_pred_mv_sad[0] + (x->best_pred_mv_sad[0] >> 1);
       const int prune_ref_list[2] = { GOLDEN_FRAME, ALTREF_FRAME };
 
       // Conservatively skip the modes w.r.t. GOLDEN and ALTREF references
@@ -3980,7 +3980,8 @@
 
   const int mi_row = xd->mi_row;
   const int mi_col = xd->mi_col;
-  x->best_pred_mv_sad = INT_MAX;
+  x->best_pred_mv_sad[0] = INT_MAX;
+  x->best_pred_mv_sad[1] = INT_MAX;
 
   for (MV_REFERENCE_FRAME ref_frame = LAST_FRAME; ref_frame <= ALTREF_FRAME;
        ++ref_frame) {
@@ -3998,12 +3999,18 @@
       assert(get_ref_frame_yv12_buf(cm, ref_frame) != NULL);
       setup_buffer_ref_mvs_inter(cpi, x, ref_frame, bsize, yv12_mb);
     }
-    // Store the best pred_mv_sad across all past frames
-    if ((cpi->sf.inter_sf.alt_ref_search_fp ||
-         cpi->sf.rt_sf.prune_inter_modes_wrt_gf_arf_based_on_sad) &&
-        cpi->ref_frame_dist_info.ref_relative_dist[ref_frame - LAST_FRAME] < 0)
-      x->best_pred_mv_sad =
-          AOMMIN(x->best_pred_mv_sad, x->pred_mv_sad[ref_frame]);
+    if (cpi->sf.inter_sf.alt_ref_search_fp ||
+        cpi->sf.rt_sf.prune_inter_modes_wrt_gf_arf_based_on_sad) {
+      // Store the best pred_mv_sad across all past frames
+      if (cpi->ref_frame_dist_info.ref_relative_dist[ref_frame - LAST_FRAME] <
+          0)
+        x->best_pred_mv_sad[0] =
+            AOMMIN(x->best_pred_mv_sad[0], x->pred_mv_sad[ref_frame]);
+      else
+        // Store the best pred_mv_sad across all future frames
+        x->best_pred_mv_sad[1] =
+            AOMMIN(x->best_pred_mv_sad[1], x->pred_mv_sad[ref_frame]);
+    }
   }
 
   if (!cpi->sf.rt_sf.use_real_time_ref_set && is_comp_ref_allowed(bsize)) {
diff --git a/av1/encoder/rdopt.h b/av1/encoder/rdopt.h
index cbc4bf3..2fead8f 100644
--- a/av1/encoder/rdopt.h
+++ b/av1/encoder/rdopt.h
@@ -213,6 +213,31 @@
   return 0;
 }
 
+static INLINE int has_closest_ref_frames(const MV_REFERENCE_FRAME *ref_frame,
+                                         int8_t closest_past_ref,
+                                         int8_t closest_future_ref) {
+  int has_closest_past_ref =
+      (ref_frame[0] == closest_past_ref) || (ref_frame[1] == closest_past_ref);
+  int has_closest_future_ref = (ref_frame[0] == closest_future_ref) ||
+                               (ref_frame[1] == closest_future_ref);
+  return (has_closest_past_ref && has_closest_future_ref);
+}
+
+static INLINE int has_best_pred_mv_sad(const MV_REFERENCE_FRAME *ref_frame,
+                                       const MACROBLOCK *const x) {
+  int has_best_past_pred_mv_sad = 0;
+  int has_best_future_pred_mv_sad = 0;
+  if (x->best_pred_mv_sad[0] < INT_MAX && x->best_pred_mv_sad[1] < INT_MAX) {
+    has_best_past_pred_mv_sad =
+        (x->pred_mv_sad[ref_frame[0]] == x->best_pred_mv_sad[0]) ||
+        (x->pred_mv_sad[ref_frame[1]] == x->best_pred_mv_sad[0]);
+    has_best_future_pred_mv_sad =
+        (x->pred_mv_sad[ref_frame[0]] == x->best_pred_mv_sad[1]) ||
+        (x->pred_mv_sad[ref_frame[1]] == x->best_pred_mv_sad[1]);
+  }
+  return (has_best_past_pred_mv_sad && has_best_future_pred_mv_sad);
+}
+
 static INLINE int prune_ref_by_selective_ref_frame(
     const AV1_COMP *const cpi, const MACROBLOCK *const x,
     const MV_REFERENCE_FRAME *const ref_frame,
@@ -230,11 +255,11 @@
       // Disable pruning if either tpl suggests that we keep the frame or
       // the pred_mv gives us the best sad
       if (x->tpl_keep_ref_frame[LAST3_FRAME] ||
-          x->pred_mv_sad[LAST3_FRAME] == x->best_pred_mv_sad) {
+          x->pred_mv_sad[LAST3_FRAME] == x->best_pred_mv_sad[0]) {
         ref_frame_list[0] = NONE_FRAME;
       }
       if (x->tpl_keep_ref_frame[LAST2_FRAME] ||
-          x->pred_mv_sad[LAST2_FRAME] == x->best_pred_mv_sad) {
+          x->pred_mv_sad[LAST2_FRAME] == x->best_pred_mv_sad[0]) {
         ref_frame_list[1] = NONE_FRAME;
       }
     }
@@ -252,11 +277,11 @@
       // Disable pruning if either tpl suggests that we keep the frame or
       // the pred_mv gives us the best sad
       if (x->tpl_keep_ref_frame[ALTREF2_FRAME] ||
-          x->pred_mv_sad[ALTREF2_FRAME] == x->best_pred_mv_sad) {
+          x->pred_mv_sad[ALTREF2_FRAME] == x->best_pred_mv_sad[0]) {
         ref_frame_list[0] = NONE_FRAME;
       }
       if (x->tpl_keep_ref_frame[BWDREF_FRAME] ||
-          x->pred_mv_sad[BWDREF_FRAME] == x->best_pred_mv_sad) {
+          x->pred_mv_sad[BWDREF_FRAME] == x->best_pred_mv_sad[0]) {
         ref_frame_list[1] = NONE_FRAME;
       }
     }
@@ -267,6 +292,21 @@
       return 1;
   }
 
+  if (x != NULL && sf->inter_sf.prune_comp_ref_frames && comp_pred) {
+    int closest_ref_frames = has_closest_ref_frames(
+        ref_frame, cpi->ref_frame_dist_info.nearest_past_ref,
+        cpi->ref_frame_dist_info.nearest_future_ref);
+    if (closest_ref_frames == 0) {
+      // Prune reference frames which are not the closest to the current frame.
+      if (sf->inter_sf.prune_comp_ref_frames >= 2) {
+        return 1;
+      } else if (sf->inter_sf.prune_comp_ref_frames == 1) {
+        // Prune reference frames with non minimum pred_mv_sad.
+        if (has_best_pred_mv_sad(ref_frame, x) == 0) return 1;
+      }
+    }
+  }
+
   return 0;
 }
 
diff --git a/av1/encoder/speed_features.c b/av1/encoder/speed_features.c
index a0922c2..8f6138c 100644
--- a/av1/encoder/speed_features.c
+++ b/av1/encoder/speed_features.c
@@ -1152,6 +1152,7 @@
     sf->inter_sf.txfm_rd_gate_level = boosted ? 0 : 4;
     // Enable fast search for all valid compound modes.
     sf->inter_sf.enable_fast_compound_mode_search = 2;
+    sf->inter_sf.prune_comp_ref_frames = 1;
 
     sf->intra_sf.chroma_intra_pruning_with_hog = 3;
 
@@ -1179,6 +1180,7 @@
 
     sf->inter_sf.prune_inter_modes_based_on_tpl = boosted ? 0 : 3;
     sf->inter_sf.selective_ref_frame = 6;
+    sf->inter_sf.prune_comp_ref_frames = 2;
     sf->inter_sf.prune_ext_comp_using_neighbors = 3;
 
     sf->intra_sf.chroma_intra_pruning_with_hog = 4;
@@ -1756,6 +1758,7 @@
   inter_sf->model_based_post_interp_filter_breakout = 0;
   inter_sf->reduce_inter_modes = 0;
   inter_sf->alt_ref_search_fp = 0;
+  inter_sf->prune_comp_ref_frames = 0;
   inter_sf->selective_ref_frame = 0;
   inter_sf->prune_ref_frame_for_rect_partitions = 0;
   inter_sf->fast_wedge_sign_estimate = 0;
diff --git a/av1/encoder/speed_features.h b/av1/encoder/speed_features.h
index bcb18a8..d429cfd 100644
--- a/av1/encoder/speed_features.h
+++ b/av1/encoder/speed_features.h
@@ -833,6 +833,17 @@
   // 2 prune inter modes w.r.t BWDREF, ALTREF2 and ALTREF reference frames
   int alt_ref_search_fp;
 
+  // Prune compound reference frames
+  // 0 no pruning
+  // 1 prune compound references which do not satisfy the two conditions:
+  //   a) The references are at a nearest distance from the current frame in
+  //   both past and future direction.
+  //   b) The references have minimum pred_mv_sad in both past and future
+  //   direction.
+  // 2 prune compound references except the one with nearest distance from the
+  //   current frame in both past and future direction.
+  int prune_comp_ref_frames;
+
   // Skip the current ref_mv in NEW_MV mode based on mv, rate cost, etc.
   // This speed feature equaling 0 means no skipping.
   // If the speed feature equals 1 or 2, skip the current ref_mv in NEW_MV mode