Use simple-translation to filter near mv

Before performing the motion vector search, do a simple translation
and check the RDCOST. Only do the full motion vector search on
mv that perform well in the initial check.

STATS_CHANGED
Fewer MV for near MV will be searched.
For lowres, BD-rate (psnr) drop is 0.06%, speed up is ~3%
For midres, BD-rate (psnr) drop is 0.04%, speed up is ~5%

Change-Id: I159cb9f173702dae4a420c1783e7c044819d6871
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 266f702..6a50454 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -225,7 +225,7 @@
   INTERPOLATION_FILTER_STATS interp_filter_stats[2][MAX_INTERP_FILTER_STATS];
   int interp_filter_stats_idx[2];
 
-  // prune_comp_search_by_single_result (3:MAX_REF_MV_SERCH)
+  // prune_comp_search_by_single_result (3:MAX_REF_MV_SEARCH)
   SimpleRDState simple_rd_state[SINGLE_REF_MODES][3];
 
   // Activate constrained coding block partition search range.
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index c7383d3..18fe9b8 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -716,12 +716,12 @@
   int64_t best_pred_rd[REFERENCE_MODES];
   int64_t best_pred_diff[REFERENCE_MODES];
   // Save a set of single_newmv for each checked ref_mv.
-  int_mv single_newmv[MAX_REF_MV_SERCH][REF_FRAMES];
-  int single_newmv_rate[MAX_REF_MV_SERCH][REF_FRAMES];
-  int single_newmv_valid[MAX_REF_MV_SERCH][REF_FRAMES];
-  int64_t modelled_rd[MB_MODE_COUNT][MAX_REF_MV_SERCH][REF_FRAMES];
+  int_mv single_newmv[MAX_REF_MV_SEARCH][REF_FRAMES];
+  int single_newmv_rate[MAX_REF_MV_SEARCH][REF_FRAMES];
+  int single_newmv_valid[MAX_REF_MV_SEARCH][REF_FRAMES];
+  int64_t modelled_rd[MB_MODE_COUNT][MAX_REF_MV_SEARCH][REF_FRAMES];
   // The rd of simple translation in single inter modes
-  int64_t simple_rd[MB_MODE_COUNT][MAX_REF_MV_SERCH][REF_FRAMES];
+  int64_t simple_rd[MB_MODE_COUNT][MAX_REF_MV_SEARCH][REF_FRAMES];
 
   // Single search results by [directions][modes][reference frames]
   SingleInterModeState single_state[2][SINGLE_INTER_MODE_NUM][FWD_REFS];
@@ -8034,11 +8034,11 @@
   int (*single_newmv_valid)[REF_FRAMES];
   // Pointer to array of predicted rate-distortion
   // Should point to first of 2 arrays in 2D array
-  int64_t (*modelled_rd)[MAX_REF_MV_SERCH][REF_FRAMES];
+  int64_t (*modelled_rd)[MAX_REF_MV_SEARCH][REF_FRAMES];
   InterpFilter single_filter[MB_MODE_COUNT][REF_FRAMES];
   int ref_frame_cost;
   int single_comp_cost;
-  int64_t (*simple_rd)[MAX_REF_MV_SERCH][REF_FRAMES];
+  int64_t (*simple_rd)[MAX_REF_MV_SEARCH][REF_FRAMES];
   int skip_motion_mode;
   INTERINTRA_MODE *inter_intra_mode;
   int single_ref_first_pass;
@@ -9822,7 +9822,7 @@
 
 static INLINE int get_drl_cost(const MB_MODE_INFO *mbmi,
                                const MB_MODE_INFO_EXT *mbmi_ext,
-                               int (*drl_mode_cost0)[2],
+                               const int (*const drl_mode_cost0)[2],
                                int8_t ref_frame_type) {
   int cost = 0;
   if (mbmi->mode == NEWMV || mbmi->mode == NEW_NEWMV) {
@@ -10132,8 +10132,8 @@
   return best_compmode_interinter_cost;
 }
 
-static INLINE int is_single_newmv_valid(HandleInterModeArgs *args,
-                                        MB_MODE_INFO *mbmi,
+static INLINE int is_single_newmv_valid(const HandleInterModeArgs *const args,
+                                        const MB_MODE_INFO *const mbmi,
                                         PREDICTION_MODE this_mode) {
   for (int ref_idx = 0; ref_idx < 2; ++ref_idx) {
     const PREDICTION_MODE single_mode = get_single_mode(this_mode, ref_idx, 1);
@@ -10157,11 +10157,55 @@
   const int has_drl =
       (has_nearmv && ref_mv_count > 2) || (only_newmv && ref_mv_count > 1);
   const int ref_set =
-      has_drl ? AOMMIN(MAX_REF_MV_SERCH, ref_mv_count - has_nearmv) : 1;
+      has_drl ? AOMMIN(MAX_REF_MV_SEARCH, ref_mv_count - has_nearmv) : 1;
 
   return ref_set;
 }
 
+// Whether this reference motion vector can be skipped, based on initial
+// heuristics.
+static bool ref_mv_idx_early_breakout(MACROBLOCK *x,
+                                      const SPEED_FEATURES *const sf,
+                                      const HandleInterModeArgs *const args,
+                                      int64_t ref_best_rd, int ref_mv_idx) {
+  MACROBLOCKD *xd = &x->e_mbd;
+  MB_MODE_INFO *mbmi = xd->mi[0];
+  const MB_MODE_INFO_EXT *const mbmi_ext = x->mbmi_ext;
+  const int8_t ref_frame_type = av1_ref_frame_type(mbmi->ref_frame);
+  if (sf->reduce_inter_modes && ref_mv_idx > 0) {
+    if (mbmi->ref_frame[0] == LAST2_FRAME ||
+        mbmi->ref_frame[0] == LAST3_FRAME ||
+        mbmi->ref_frame[1] == LAST2_FRAME ||
+        mbmi->ref_frame[1] == LAST3_FRAME) {
+      const int has_nearmv = have_nearmv_in_inter_mode(mbmi->mode) ? 1 : 0;
+      if (mbmi_ext->weight[ref_frame_type][ref_mv_idx + has_nearmv] <
+          REF_CAT_LEVEL) {
+        return true;
+      }
+    }
+  }
+  const int is_comp_pred = has_second_ref(mbmi);
+  if (sf->prune_single_motion_modes_by_simple_trans && !is_comp_pred &&
+      args->single_ref_first_pass == 0) {
+    if (args->simple_rd_state[ref_mv_idx].early_skipped) {
+      return true;
+    }
+  }
+  mbmi->ref_mv_idx = ref_mv_idx;
+  if (is_comp_pred && (!is_single_newmv_valid(args, mbmi, mbmi->mode))) {
+    return true;
+  }
+  size_t est_rd_rate = args->ref_frame_cost + args->single_comp_cost;
+  const int drl_cost =
+      get_drl_cost(mbmi, mbmi_ext, x->drl_mode_cost0, ref_frame_type);
+  est_rd_rate += drl_cost;
+  if (RDCOST(x->rdmult, est_rd_rate, 0) > ref_best_rd &&
+      mbmi->mode != NEARESTMV && mbmi->mode != NEAREST_NEARESTMV) {
+    return true;
+  }
+  return false;
+}
+
 typedef struct {
   int64_t rd;
   int drl_cost;
@@ -10169,6 +10213,152 @@
   int_mv mv;
 } inter_mode_info;
 
+// Compute the estimated RD cost for the motion vector with simple translation.
+static int64_t simple_translation_pred_rd(
+    AV1_COMP *const cpi, MACROBLOCK *x, RD_STATS *rd_stats,
+    HandleInterModeArgs *args, int ref_mv_idx, inter_mode_info *mode_info,
+    int64_t ref_best_rd, BLOCK_SIZE bsize, int mi_row, int mi_col) {
+  MACROBLOCKD *xd = &x->e_mbd;
+  MB_MODE_INFO *mbmi = xd->mi[0];
+  MB_MODE_INFO_EXT *const mbmi_ext = x->mbmi_ext;
+  const int8_t ref_frame_type = av1_ref_frame_type(mbmi->ref_frame);
+  const AV1_COMMON *cm = &cpi->common;
+  const int is_comp_pred = has_second_ref(mbmi);
+
+  struct macroblockd_plane *p = xd->plane;
+  const BUFFER_SET orig_dst = {
+    { p[0].dst.buf, p[1].dst.buf, p[2].dst.buf },
+    { p[0].dst.stride, p[1].dst.stride, p[2].dst.stride },
+  };
+  av1_init_rd_stats(rd_stats);
+
+  mbmi->interinter_comp.type = COMPOUND_AVERAGE;
+  mbmi->comp_group_idx = 0;
+  mbmi->compound_idx = 1;
+  if (mbmi->ref_frame[1] == INTRA_FRAME) {
+    mbmi->ref_frame[1] = NONE_FRAME;
+  }
+  int16_t mode_ctx =
+      av1_mode_context_analyzer(mbmi_ext->mode_context, mbmi->ref_frame);
+
+  mbmi->num_proj_ref = 0;
+  mbmi->motion_mode = SIMPLE_TRANSLATION;
+  mbmi->ref_mv_idx = ref_mv_idx;
+
+  rd_stats->rate += args->ref_frame_cost + args->single_comp_cost;
+  const int drl_cost =
+      get_drl_cost(mbmi, mbmi_ext, x->drl_mode_cost0, ref_frame_type);
+  rd_stats->rate += drl_cost;
+  mode_info[ref_mv_idx].drl_cost = drl_cost;
+
+  int_mv cur_mv[2];
+  if (!build_cur_mv(cur_mv, mbmi->mode, cm, x)) {
+    return INT64_MAX;
+  }
+  assert(have_nearmv_in_inter_mode(mbmi->mode));
+  for (int i = 0; i < is_comp_pred + 1; ++i) {
+    mbmi->mv[i].as_int = cur_mv[i].as_int;
+  }
+  const int ref_mv_cost = cost_mv_ref(x, mbmi->mode, mode_ctx);
+  rd_stats->rate += ref_mv_cost;
+
+  if (RDCOST(x->rdmult, rd_stats->rate, 0) > ref_best_rd) {
+    return INT64_MAX;
+  }
+
+  mbmi->motion_mode = SIMPLE_TRANSLATION;
+  mbmi->num_proj_ref = 0;
+  if (is_comp_pred) {
+    // Only compound_average
+    mbmi->interinter_comp.type = COMPOUND_AVERAGE;
+    mbmi->comp_group_idx = 0;
+    mbmi->compound_idx = 1;
+  }
+  set_default_interp_filters(mbmi, cm->interp_filter);
+
+  av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, &orig_dst, bsize,
+                                AOM_PLANE_Y, AOM_PLANE_Y);
+  int est_rate;
+  int64_t est_dist;
+  model_rd_sb_fn[MODELRD_CURVFIT](cpi, bsize, x, xd, 0, 0, mi_row, mi_col,
+                                  &est_rate, &est_dist, NULL, NULL, NULL, NULL,
+                                  NULL);
+  return RDCOST(x->rdmult, rd_stats->rate + est_rate, est_dist);
+}
+
+// Represents a set of integers, from 0 to sizeof(int) * 8, as bits in
+// an integer. 0 for the i-th bit means that integer is excluded, 1 means
+// it is included.
+static INLINE void mask_set_bit(int *mask, int index) { *mask |= (1 << index); }
+
+static INLINE void mask_clear_bit(int *mask, int index) {
+  *mask &= ~(1 << index);
+}
+
+static INLINE bool mask_check_bit(int mask, int index) {
+  return (mask >> index) & 0x1;
+}
+
+// Before performing the full MV search in handle_inter_mode, do a simple
+// translation search and see if we can eliminate any motion vectors.
+// Returns an integer where, if the i-th bit is set, it means that the i-th
+// motion vector should be searched. This is only set for NEAR_MV.
+static int ref_mv_idx_to_search(AV1_COMP *const cpi, MACROBLOCK *x,
+                                RD_STATS *rd_stats,
+                                HandleInterModeArgs *const args,
+                                int64_t ref_best_rd, inter_mode_info *mode_info,
+                                BLOCK_SIZE bsize, int mi_row, int mi_col,
+                                const int ref_set) {
+  const MACROBLOCKD *const xd = &x->e_mbd;
+  const MB_MODE_INFO *const mbmi = xd->mi[0];
+  const PREDICTION_MODE this_mode = mbmi->mode;
+
+  // Only search indices if they have some chance of being good.
+  int good_indices = 0;
+  for (int i = 0; i < ref_set; ++i) {
+    if (ref_mv_idx_early_breakout(x, &cpi->sf, args, ref_best_rd, i)) {
+      continue;
+    }
+    mask_set_bit(&good_indices, i);
+  }
+
+  // Only prune in NEARMV mode, if the speed feature is set, and the block size
+  // is large enough. If these conditions are not met, return all good indices
+  // found so far.
+  if (!cpi->sf.prune_mode_search_simple_translation) return good_indices;
+  if (!have_nearmv_in_inter_mode(this_mode)) return good_indices;
+  if (num_pels_log2_lookup[bsize] <= 6) return good_indices;
+
+  // Calculate the RD cost for the motion vectors using simple translation.
+  int64_t idx_rdcost[] = { INT64_MAX, INT64_MAX, INT64_MAX };
+  for (int ref_mv_idx = 0; ref_mv_idx < ref_set; ++ref_mv_idx) {
+    // If this index is bad, ignore it.
+    if (!mask_check_bit(good_indices, ref_mv_idx)) {
+      continue;
+    }
+    idx_rdcost[ref_mv_idx] = simple_translation_pred_rd(
+        cpi, x, rd_stats, args, ref_mv_idx, mode_info, ref_best_rd, bsize,
+        mi_row, mi_col);
+  }
+  // Find the index with the best RD cost.
+  int best_idx = 0;
+  for (int i = 1; i < MAX_REF_MV_SEARCH; ++i) {
+    if (idx_rdcost[i] < idx_rdcost[best_idx]) {
+      best_idx = i;
+    }
+  }
+  // Only include indices that are good and within a % of the best.
+  const double dth = has_second_ref(mbmi) ? 1.05 : 1.001;
+  int result = 0;
+  for (int i = 0; i < ref_set; ++i) {
+    if (mask_check_bit(good_indices, i) &&
+        (1.0 * idx_rdcost[i]) / idx_rdcost[best_idx] < dth) {
+      mask_set_bit(&result, i);
+    }
+  }
+  return result;
+}
+
 static int64_t handle_inter_mode(
     AV1_COMP *const cpi, TileDataEnc *tile_data, MACROBLOCK *x,
     BLOCK_SIZE bsize, RD_STATS *rd_stats, RD_STATS *rd_stats_y,
@@ -10219,7 +10409,7 @@
   int64_t newmv_ret_val = INT64_MAX;
   int_mv backup_mv[2] = { { 0 } };
   int backup_rate_mv = 0;
-  inter_mode_info mode_info[MAX_REF_MV_SERCH];
+  inter_mode_info mode_info[MAX_REF_MV_SEARCH];
 
   int mode_search_mask[2];
   const int do_two_loop_comp_search =
@@ -10236,30 +10426,21 @@
                          (1 << COMPOUND_WEDGE) | (1 << COMPOUND_DIFFWTD)) -
                         mode_search_mask[0];
 
-  // TODO(jingning): This should be deprecated shortly.
-  const int has_nearmv = have_nearmv_in_inter_mode(mbmi->mode) ? 1 : 0;
+  // First, perform a simple translation search for each of the indices. If
+  // an index performs well, it will be fully searched here.
   const int ref_set = get_drl_refmv_count(x, mbmi->ref_frame, this_mode);
-
+  int idx_mask =
+      ref_mv_idx_to_search(cpi, x, rd_stats, args, ref_best_rd, mode_info,
+                           bsize, mi_row, mi_col, ref_set);
   for (int ref_mv_idx = 0; ref_mv_idx < ref_set; ++ref_mv_idx) {
     mode_info[ref_mv_idx].mv.as_int = INVALID_MV;
     mode_info[ref_mv_idx].rd = INT64_MAX;
-
-    if (cpi->sf.reduce_inter_modes && ref_mv_idx > 0) {
-      if (mbmi->ref_frame[0] == LAST2_FRAME ||
-          mbmi->ref_frame[0] == LAST3_FRAME ||
-          mbmi->ref_frame[1] == LAST2_FRAME ||
-          mbmi->ref_frame[1] == LAST3_FRAME) {
-        if (mbmi_ext->weight[ref_frame_type][ref_mv_idx + has_nearmv] <
-            REF_CAT_LEVEL) {
-          continue;
-        }
-      }
+    if (!mask_check_bit(idx_mask, ref_mv_idx)) {
+      // MV did not perform well in simple translation search. Skip it.
+      continue;
     }
-    if (cpi->sf.prune_single_motion_modes_by_simple_trans && !is_comp_pred &&
-        args->single_ref_first_pass == 0) {
-      if (args->simple_rd_state[ref_mv_idx].early_skipped) {
-        continue;
-      }
+    if (ref_mv_idx_early_breakout(x, &cpi->sf, args, ref_best_rd, ref_mv_idx)) {
+      continue;
     }
     av1_init_rd_stats(rd_stats);
 
@@ -10275,21 +10456,12 @@
     mbmi->motion_mode = SIMPLE_TRANSLATION;
     mbmi->ref_mv_idx = ref_mv_idx;
 
-    if (is_comp_pred && (!is_single_newmv_valid(args, mbmi, this_mode))) {
-      continue;
-    }
-
     rd_stats->rate += args->ref_frame_cost + args->single_comp_cost;
     const int drl_cost =
         get_drl_cost(mbmi, mbmi_ext, x->drl_mode_cost0, ref_frame_type);
     rd_stats->rate += drl_cost;
     mode_info[ref_mv_idx].drl_cost = drl_cost;
 
-    if (RDCOST(x->rdmult, rd_stats->rate, 0) > ref_best_rd &&
-        mbmi->mode != NEARESTMV && mbmi->mode != NEAREST_NEARESTMV) {
-      continue;
-    }
-
     const RD_STATS backup_rd_stats = *rd_stats;
 
     for (int comp_loop_idx = 0; comp_loop_idx <= do_two_loop_comp_search;
@@ -11627,7 +11799,7 @@
   av1_zero(search_state->single_newmv_rate);
   av1_zero(search_state->single_newmv_valid);
   for (int i = 0; i < MB_MODE_COUNT; ++i) {
-    for (int j = 0; j < MAX_REF_MV_SERCH; ++j) {
+    for (int j = 0; j < MAX_REF_MV_SEARCH; ++j) {
       for (int ref_frame = 0; ref_frame < REF_FRAMES; ++ref_frame) {
         search_state->modelled_rd[i][j][ref_frame] = INT64_MAX;
         search_state->simple_rd[i][j][ref_frame] = INT64_MAX;
@@ -12462,7 +12634,7 @@
   for (int i = 0; i < SINGLE_REF_MODES; ++i) {
     const MODE_DEFINITION *mode_order = &av1_mode_order[i];
     const MV_REFERENCE_FRAME ref_frame = mode_order->ref_frame[0];
-    for (int k = 0; k < MAX_REF_MV_SERCH; ++k) {
+    for (int k = 0; k < MAX_REF_MV_SEARCH; ++k) {
       const int64_t rd = x->simple_rd_state[i][k].rd_stats.rdcost;
       rdcosts[ref_frame] = AOMMIN(rdcosts[ref_frame], rd);
       min_rd = AOMMIN(min_rd, rd);
@@ -12644,7 +12816,7 @@
       if (!comp_pred) {  // single ref mode
         if (args.single_ref_first_pass) {
           // clear stats
-          for (int k = 0; k < MAX_REF_MV_SERCH; ++k) {
+          for (int k = 0; k < MAX_REF_MV_SEARCH; ++k) {
             x->simple_rd_state[midx][k].rd_stats.rdcost = INT64_MAX;
             x->simple_rd_state[midx][k].early_skipped = 0;
           }
@@ -13288,7 +13460,7 @@
       if (!comp_pred && ref_frame != INTRA_FRAME) {  // single ref mode
         if (args.single_ref_first_pass) {
           // clear stats
-          for (int k = 0; k < MAX_REF_MV_SERCH; ++k) {
+          for (int k = 0; k < MAX_REF_MV_SEARCH; ++k) {
             x->simple_rd_state[midx][k].rd_stats.rdcost = INT64_MAX;
             x->simple_rd_state[midx][k].early_skipped = 0;
           }
diff --git a/av1/encoder/rdopt.h b/av1/encoder/rdopt.h
index 3293834..ae00b90 100644
--- a/av1/encoder/rdopt.h
+++ b/av1/encoder/rdopt.h
@@ -26,7 +26,7 @@
 extern "C" {
 #endif
 
-#define MAX_REF_MV_SERCH 3
+#define MAX_REF_MV_SEARCH 3
 #define DEFAULT_LUMA_INTERP_SKIP_FLAG 1
 #define DEFAULT_CHROMA_INTERP_SKIP_FLAG 2
 #define DEFAULT_INTERP_SKIP_FLAG \
diff --git a/av1/encoder/speed_features.c b/av1/encoder/speed_features.c
index 120f9d5..8a64cbb 100644
--- a/av1/encoder/speed_features.c
+++ b/av1/encoder/speed_features.c
@@ -223,6 +223,7 @@
   sf->inter_mode_rd_model_estimation = 1;
   sf->inter_mode_rd_model_estimation_adaptive = 0;
 
+  sf->prune_mode_search_simple_translation = 1;
   sf->two_loop_comp_search = 0;
   sf->prune_ref_frame_for_rect_partitions =
       boosted ? 0 : (is_boosted_arf2_bwd_type ? 1 : 2);
@@ -445,6 +446,7 @@
   // TODO(debargha): Test, tweak and turn on either 1 or 2
   sf->inter_mode_rd_model_estimation = 0;
   sf->inter_mode_rd_model_estimation_adaptive = 0;
+  sf->prune_mode_search_simple_translation = 0;
   sf->two_loop_comp_search = 0;
 
   sf->prune_ref_frame_for_rect_partitions = !boosted;
@@ -787,6 +789,7 @@
   sf->inter_mode_rd_model_estimation = 0;
   sf->inter_mode_rd_model_estimation_adaptive = 0;
 
+  sf->prune_mode_search_simple_translation = 0;
   sf->obmc_full_pixel_search_level = 0;
   sf->skip_sharp_interp_filter_search = 0;
   sf->prune_comp_type_by_comp_avg = 0;
diff --git a/av1/encoder/speed_features.h b/av1/encoder/speed_features.h
index 1eab76d..ce1f021 100644
--- a/av1/encoder/speed_features.h
+++ b/av1/encoder/speed_features.h
@@ -711,6 +711,11 @@
 
   // Perform coarse ME before calculating variance in variance-based partition
   int estimate_motion_for_var_based_partition;
+
+  // Instead of performing a full MV search, do a simple translation first
+  // and only perform a full MV search on the motion vectors that performed
+  // well.
+  int prune_mode_search_simple_translation;
 } SPEED_FEATURES;
 
 struct AV1_COMP;