Add speed feature prune_comp_search_by_single_result

Skip some reference frames in compound motion search based on the simple
translation rd and the modelled rd obtained from single motion search.
The single-reference results are categorized by prediction direction
(forward/backward) and by single inter mode, and the top reference
frames of each category are used to decide which reference frames to
skip in compound search. This feature is enabled for speed 2 and above.

Tested on the foreman, city, and students CIF clips (30 frames, bitrate=500).
Speed 2: 11% faster, with a 0.07% drop in coding performance.

STATS_CHANGED

Change-Id: I14bac8ab5aa359bd8dad6a342bbcd98f3d27aafe
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index f6c6d86..45c766b 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -534,6 +534,12 @@
   UV_D113_PRED,   UV_D45_PRED,
 };
 
+typedef struct SingleInterModeState {
+  int64_t rd;
+  MV_REFERENCE_FRAME ref_frame;
+  int valid;
+} SingleInterModeState;
+
 typedef struct InterModeSearchState {
   int64_t best_rd;
   MB_MODE_INFO best_mbmode;
@@ -565,7 +571,18 @@
   int_mv single_newmv[MAX_REF_MV_SERCH][REF_FRAMES];
   int single_newmv_rate[MAX_REF_MV_SERCH][REF_FRAMES];
   int single_newmv_valid[MAX_REF_MV_SERCH][REF_FRAMES];
-  int64_t modelled_rd[MB_MODE_COUNT][REF_FRAMES];
+  int64_t modelled_rd[MB_MODE_COUNT][MAX_REF_MV_SERCH][REF_FRAMES];
+  // The rd of simple translation in single inter modes
+  int64_t simple_rd[MB_MODE_COUNT][MAX_REF_MV_SERCH][REF_FRAMES];
+
+  // Single search results by [directions][modes][reference frames]
+  SingleInterModeState single_state[2][SINGLE_INTER_MODE_NUM][FWD_REFS];
+  int single_state_cnt[2][SINGLE_INTER_MODE_NUM];
+  SingleInterModeState single_state_modelled[2][SINGLE_INTER_MODE_NUM]
+                                            [FWD_REFS];
+  int single_state_modelled_cnt[2][SINGLE_INTER_MODE_NUM];
+
+  MV_REFERENCE_FRAME single_rd_order[2][SINGLE_INTER_MODE_NUM][FWD_REFS];
 } InterModeSearchState;
 
 #if CONFIG_COLLECT_INTER_MODE_RD_STATS
@@ -7754,10 +7771,11 @@
   int (*single_newmv_valid)[REF_FRAMES];
   // Pointer to array of predicted rate-distortion
   // Should point to first of 2 arrays in 2D array
-  int64_t (*modelled_rd)[REF_FRAMES];
+  int64_t (*modelled_rd)[MAX_REF_MV_SERCH][REF_FRAMES];
   InterpFilter single_filter[MB_MODE_COUNT][REF_FRAMES];
   int ref_frame_cost;
   int single_comp_cost;
+  int64_t (*simple_rd)[MAX_REF_MV_SERCH][REF_FRAMES];
 } HandleInterModeArgs;
 
 static INLINE int clamp_and_check_mv(int_mv *out_mv, int_mv in_mv,
@@ -8173,12 +8191,13 @@
   }
   if (args->modelled_rd != NULL) {
     if (has_second_ref(mbmi)) {
+      const int ref_mv_idx = mbmi->ref_mv_idx;
       int refs[2] = { mbmi->ref_frame[0],
                       (mbmi->ref_frame[1] < 0 ? 0 : mbmi->ref_frame[1]) };
       const int mode0 = compound_ref0_mode(mbmi->mode);
       const int mode1 = compound_ref1_mode(mbmi->mode);
-      const int64_t mrd = AOMMIN(args->modelled_rd[mode0][refs[0]],
-                                 args->modelled_rd[mode1][refs[1]]);
+      const int64_t mrd = AOMMIN(args->modelled_rd[mode0][ref_mv_idx][refs[0]],
+                                 args->modelled_rd[mode1][ref_mv_idx][refs[1]]);
       if ((*rd >> 1) > mrd && ref_best_rd < INT64_MAX) {
         return INT64_MAX;
       }
@@ -8838,6 +8857,9 @@
     }
 
     tmp_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
+    if (mbmi->motion_mode == SIMPLE_TRANSLATION &&
+        mbmi->ref_frame[1] != INTRA_FRAME)
+      args->simple_rd[this_mode][mbmi->ref_mv_idx][mbmi->ref_frame[0]] = tmp_rd;
     if ((mbmi->motion_mode == SIMPLE_TRANSLATION &&
          mbmi->ref_frame[1] != INTRA_FRAME) ||
         (tmp_rd < best_rd)) {
@@ -9382,6 +9404,9 @@
           x, cpi, bsize, mi_row, mi_col, &tmp_dst, &orig_dst,
           args->single_filter, &rd, &rs, &skip_txfm_sb, &skip_sse_sb,
           skip_build_pred, args, ref_best_rd);
+      if (args->modelled_rd != NULL && !is_comp_pred) {
+        args->modelled_rd[this_mode][ref_mv_idx][refs[0]] = rd;
+      }
       if (ret_val != 0) {
         restore_dst_buf(xd, orig_dst, num_planes);
         continue;
@@ -9409,14 +9434,13 @@
         if (is_comp_pred) {
           const int mode0 = compound_ref0_mode(this_mode);
           const int mode1 = compound_ref1_mode(this_mode);
-          const int64_t mrd = AOMMIN(args->modelled_rd[mode0][refs[0]],
-                                     args->modelled_rd[mode1][refs[1]]);
+          const int64_t mrd =
+              AOMMIN(args->modelled_rd[mode0][ref_mv_idx][refs[0]],
+                     args->modelled_rd[mode1][ref_mv_idx][refs[1]]);
           if (rd / 4 * 3 > mrd && ref_best_rd < INT64_MAX) {
             restore_dst_buf(xd, orig_dst, num_planes);
             continue;
           }
-        } else {
-          args->modelled_rd[this_mode][refs[0]] = rd;
         }
       }
 
@@ -9491,8 +9515,6 @@
       }
       restore_dst_buf(xd, orig_dst, num_planes);
     }
-
-    args->modelled_rd = NULL;
   }
 
   if (best_rd == INT64_MAX) return INT64_MAX;
@@ -10432,9 +10454,39 @@
   av1_zero(search_state->single_newmv);
   av1_zero(search_state->single_newmv_rate);
   av1_zero(search_state->single_newmv_valid);
-  for (int i = 0; i < MB_MODE_COUNT; ++i)
-    for (int ref_frame = 0; ref_frame < REF_FRAMES; ++ref_frame)
-      search_state->modelled_rd[i][ref_frame] = INT64_MAX;
+  for (int i = 0; i < MB_MODE_COUNT; ++i) {
+    for (int j = 0; j < MAX_REF_MV_SERCH; ++j) {
+      for (int ref_frame = 0; ref_frame < REF_FRAMES; ++ref_frame) {
+        search_state->modelled_rd[i][j][ref_frame] = INT64_MAX;
+        search_state->simple_rd[i][j][ref_frame] = INT64_MAX;
+      }
+    }
+  }
+
+  for (int dir = 0; dir < 2; ++dir) {
+    for (int mode = 0; mode < SINGLE_INTER_MODE_NUM; ++mode) {
+      for (int ref_frame = 0; ref_frame < FWD_REFS; ++ref_frame) {
+        SingleInterModeState *state;
+
+        state = &search_state->single_state[dir][mode][ref_frame];
+        state->ref_frame = NONE_FRAME;
+        state->rd = INT64_MAX;
+
+        state = &search_state->single_state_modelled[dir][mode][ref_frame];
+        state->ref_frame = NONE_FRAME;
+        state->rd = INT64_MAX;
+      }
+    }
+  }
+  for (int dir = 0; dir < 2; ++dir) {
+    for (int mode = 0; mode < SINGLE_INTER_MODE_NUM; ++mode) {
+      for (int ref_frame = 0; ref_frame < FWD_REFS; ++ref_frame) {
+        search_state->single_rd_order[dir][mode][ref_frame] = NONE_FRAME;
+      }
+    }
+  }
+  av1_zero(search_state->single_state_cnt);
+  av1_zero(search_state->single_state_modelled_cnt);
 }
 
 static int inter_mode_search_order_independent_skip(
@@ -10788,6 +10840,256 @@
   return 1;
 }
 
+static void collect_single_states(MACROBLOCK *x,
+                                  InterModeSearchState *search_state,
+                                  const MB_MODE_INFO *const mbmi) {
+  int i, j;
+  const MV_REFERENCE_FRAME ref_frame = mbmi->ref_frame[0];
+  const PREDICTION_MODE this_mode = mbmi->mode;
+  const int dir = ref_frame <= GOLDEN_FRAME ? 0 : 1;
+  const int mode_offset = INTER_OFFSET(this_mode);
+
+  const int8_t ref_frame_type = av1_ref_frame_type(mbmi->ref_frame);
+  MB_MODE_INFO_EXT *const mbmi_ext = x->mbmi_ext;
+  const int has_nearmv = have_nearmv_in_inter_mode(this_mode) ? 1 : 0;
+  const int ref_mv_count = mbmi_ext->ref_mv_count[ref_frame_type];
+  const int only_newmv = (this_mode == NEWMV || this_mode == NEW_NEWMV);
+  const int has_drl =
+      (has_nearmv && ref_mv_count > 2) || (only_newmv && ref_mv_count > 1);
+  const int ref_set =
+      has_drl ? AOMMIN(MAX_REF_MV_SERCH, ref_mv_count - has_nearmv) : 1;
+
+  // Simple rd
+  int64_t simple_rd = search_state->simple_rd[this_mode][0][ref_frame];
+  for (int ref_mv_idx = 1; ref_mv_idx < ref_set; ++ref_mv_idx) {
+    int64_t rd = search_state->simple_rd[this_mode][ref_mv_idx][ref_frame];
+    if (rd < simple_rd) simple_rd = rd;
+  }
+
+  // Insertion sort of single_state
+  SingleInterModeState this_state_s = { simple_rd, ref_frame, 1 };
+  SingleInterModeState *state_s = search_state->single_state[dir][mode_offset];
+  i = search_state->single_state_cnt[dir][mode_offset];
+  for (j = i; j > 0 && state_s[j - 1].rd > this_state_s.rd; --j)
+    state_s[j] = state_s[j - 1];
+  state_s[j] = this_state_s;
+  search_state->single_state_cnt[dir][mode_offset]++;
+
+  // Modelled rd
+  int64_t modelled_rd = search_state->modelled_rd[this_mode][0][ref_frame];
+  for (int ref_mv_idx = 1; ref_mv_idx < ref_set; ++ref_mv_idx) {
+    int64_t rd = search_state->modelled_rd[this_mode][ref_mv_idx][ref_frame];
+    if (rd < modelled_rd) modelled_rd = rd;
+  }
+
+  // Insertion sort of single_state_modelled
+  SingleInterModeState this_state_m = { modelled_rd, ref_frame, 1 };
+  SingleInterModeState *state_m =
+      search_state->single_state_modelled[dir][mode_offset];
+  i = search_state->single_state_modelled_cnt[dir][mode_offset];
+  for (j = i; j > 0 && state_m[j - 1].rd > this_state_m.rd; --j)
+    state_m[j] = state_m[j - 1];
+  state_m[j] = this_state_m;
+  search_state->single_state_modelled_cnt[dir][mode_offset]++;
+}
+
+static void analyze_single_states(const AV1_COMP *cpi,
+                                  InterModeSearchState *search_state) {
+  int i, j, dir, mode;
+  if (cpi->sf.prune_comp_search_by_single_result >= 1) {
+    for (dir = 0; dir < 2; ++dir) {
+      int64_t best_rd;
+      SingleInterModeState(*state)[FWD_REFS];
+
+      // Find the best simple rd of all modes
+      state = search_state->single_state[dir];
+      best_rd = INT64_MAX;
+      for (mode = 0; mode < SINGLE_INTER_MODE_NUM; ++mode) {
+        if (state[mode][0].rd < best_rd) best_rd = state[mode][0].rd;
+      }
+      // Prune the unlikely reference frames
+      for (mode = 0; mode < SINGLE_INTER_MODE_NUM; ++mode) {
+        for (i = 0; i < search_state->single_state_cnt[dir][mode]; ++i) {
+          if (state[mode][i].rd != INT64_MAX &&
+              (state[mode][i].rd >> 1) > best_rd) {
+            state[mode][i].valid = 0;
+          }
+        }
+      }
+
+      // Find the best modelled rd of all modes
+      state = search_state->single_state_modelled[dir];
+      best_rd = INT64_MAX;
+      for (mode = 0; mode < SINGLE_INTER_MODE_NUM; ++mode) {
+        if (state[mode][0].rd < best_rd) best_rd = state[mode][0].rd;
+      }
+      // Prune the unlikely reference frames
+      for (mode = 0; mode < SINGLE_INTER_MODE_NUM; ++mode) {
+        for (i = 0; i < search_state->single_state_modelled_cnt[dir][mode];
+             ++i) {
+          if (state[mode][i].rd != INT64_MAX &&
+              (state[mode][i].rd >> 1) > best_rd) {
+            state[mode][i].valid = 0;
+          }
+        }
+      }
+    }
+  }
+
+  // Ordering by simple rd first, then by modelled rd
+  for (dir = 0; dir < 2; ++dir) {
+    for (mode = 0; mode < SINGLE_INTER_MODE_NUM; ++mode) {
+      const int state_cnt_s = search_state->single_state_cnt[dir][mode];
+      const int state_cnt_m =
+          search_state->single_state_modelled_cnt[dir][mode];
+      SingleInterModeState *state_s = search_state->single_state[dir][mode];
+      SingleInterModeState *state_m =
+          search_state->single_state_modelled[dir][mode];
+      int count = 0;
+      const int max_candidates = AOMMAX(state_cnt_s, state_cnt_m);
+      for (i = 0; i < state_cnt_s; ++i) {
+        if (state_s[i].rd == INT64_MAX) break;
+        if (state_s[i].valid)
+          search_state->single_rd_order[dir][mode][count++] =
+              state_s[i].ref_frame;
+      }
+      if (count < max_candidates) {
+        for (i = 0; i < state_cnt_m; ++i) {
+          if (state_m[i].rd == INT64_MAX) break;
+          if (state_m[i].valid) {
+            int ref_frame = state_m[i].ref_frame;
+            int match = 0;
+            // Check if existing already
+            for (j = 0; j < count; ++j) {
+              if (search_state->single_rd_order[dir][mode][j] == ref_frame) {
+                match = 1;
+                break;
+              }
+            }
+            if (!match) {
+              // Check if this ref_frame is removed in simple rd
+              int valid = 1;
+              for (j = 0; j < state_cnt_s; j++) {
+                if (ref_frame == state_s[j].ref_frame && !state_s[j].valid) {
+                  valid = 0;
+                  break;
+                }
+              }
+              if (valid)
+                search_state->single_rd_order[dir][mode][count++] = ref_frame;
+            }
+            if (count >= max_candidates) break;
+          }
+        }
+      }
+    }
+  }
+}
+
+static int compound_skip_get_candidates(
+    const AV1_COMP *cpi, const InterModeSearchState *search_state,
+    const int dir, const PREDICTION_MODE mode) {
+  const int mode_offset = INTER_OFFSET(mode);
+  const SingleInterModeState *state =
+      search_state->single_state[dir][mode_offset];
+  const SingleInterModeState *state_modelled =
+      search_state->single_state_modelled[dir][mode_offset];
+  int max_candidates = 0;
+  int candidates;
+
+  for (int i = 0; i < FWD_REFS; ++i) {
+    if (search_state->single_rd_order[dir][mode_offset][i] == NONE_FRAME) break;
+    max_candidates++;
+  }
+
+  candidates = max_candidates;
+  if (cpi->sf.prune_comp_search_by_single_result >= 2) {
+    candidates = AOMMIN(2, max_candidates);
+  }
+  if (cpi->sf.prune_comp_search_by_single_result >= 3) {
+    if (state[0].rd != INT64_MAX && state_modelled[0].rd != INT64_MAX &&
+        state[0].ref_frame == state_modelled[0].ref_frame)
+      candidates = 1;
+    if (mode == NEARMV || mode == GLOBALMV) candidates = 1;
+  }
+  return candidates;
+}
+
+static int compound_skip_by_single_states(
+    const AV1_COMP *cpi, const InterModeSearchState *search_state,
+    const PREDICTION_MODE this_mode, const MV_REFERENCE_FRAME ref_frame,
+    const MV_REFERENCE_FRAME second_ref_frame) {
+  const int refs[2] = { ref_frame, second_ref_frame };
+  const int mode0 = compound_ref0_mode(this_mode);
+  const int mode0_offset = INTER_OFFSET(mode0);
+  const int mode0_dir = refs[0] <= GOLDEN_FRAME ? 0 : 1;
+  const int mode1 = compound_ref1_mode(this_mode);
+  const int mode1_offset = INTER_OFFSET(mode1);
+  const int mode1_dir = refs[1] <= GOLDEN_FRAME ? 0 : 1;
+  const int candidates0 =
+      compound_skip_get_candidates(cpi, search_state, mode0_dir, mode0);
+  const int candidates1 =
+      compound_skip_get_candidates(cpi, search_state, mode1_dir, mode1);
+  int ref0_searched = 0;
+  int ref1_searched = 0;
+  int i;
+  const MV_REFERENCE_FRAME *ref0_order =
+      search_state->single_rd_order[mode0_dir][mode0_offset];
+  const MV_REFERENCE_FRAME *ref1_order =
+      search_state->single_rd_order[mode1_dir][mode1_offset];
+
+  for (i = 0; i < search_state->single_state_cnt[mode0_dir][mode0_offset];
+       ++i) {
+    if (search_state->single_state[mode0_dir][mode0_offset][i].ref_frame ==
+        refs[0]) {
+      ref0_searched = 1;
+      break;
+    }
+  }
+  for (i = 0; i < search_state->single_state_cnt[mode1_dir][mode1_offset];
+       ++i) {
+    if (search_state->single_state[mode1_dir][mode1_offset][i].ref_frame ==
+        refs[1]) {
+      ref1_searched = 1;
+      break;
+    }
+  }
+
+  if (mode0_dir != mode1_dir) {
+    // Bi-directional prediction
+    if (ref0_searched) {
+      int match = 0;
+      for (i = 0; i < candidates0; i++) {
+        if (refs[0] == ref0_order[i]) {
+          match = 1;
+          break;
+        }
+      }
+      if (!match) return 1;
+    }
+
+    // Only GLOBALMV is the same as single mode
+    // NEWMV should be similar because it has motion search
+    if (ref1_searched && (mode1 == NEWMV || mode1 == GLOBALMV)) {
+      int match = 0;
+      for (i = 0; i < candidates1; i++) {
+        if (refs[1] == ref1_order[i]) {
+          match = 1;
+          break;
+        }
+      }
+      if (!match) return 1;
+    }
+  } else {
+    // Uni-directional prediction
+    if (ref0_searched && ref1_searched) {
+      if (ref0_order[0] != refs[0] && ref1_order[0] != refs[1]) return 1;
+    }
+  }
+
+  return 0;
+}
+
 void av1_rd_pick_inter_mode_sb(const AV1_COMP *cpi, TileDataEnc *tile_data,
                                MACROBLOCK *x, int mi_row, int mi_col,
                                RD_STATS *rd_cost, BLOCK_SIZE bsize,
@@ -10822,7 +11124,7 @@
     NULL,      NULL,
     NULL,      NULL,
     { { 0 } }, INT_MAX,
-    INT_MAX
+    INT_MAX,   NULL
   };
   for (i = 0; i < REF_FRAMES; ++i) x->pred_sse[i] = INT_MAX;
 
@@ -10848,6 +11150,7 @@
 
   int intra_mode_num = 0;
   int intra_mode_idx_ls[MAX_MODES];
+  int reach_first_comp_mode = 0;
 
   for (int midx = 0; midx < MAX_MODES; ++midx) {
     int mode_index = mode_map[midx];
@@ -10869,11 +11172,25 @@
     x->skip = 0;
     set_ref_ptrs(cm, xd, ref_frame, second_ref_frame);
 
+    // Reach the first compound prediction mode
+    if (sf->prune_comp_search_by_single_result > 0 &&
+        second_ref_frame > INTRA_FRAME && reach_first_comp_mode == 0) {
+      analyze_single_states(cpi, &search_state);
+      reach_first_comp_mode = 1;
+    }
+
     if (inter_mode_search_order_independent_skip(cpi, ctx, x, bsize, mode_index,
                                                  mi_row, mi_col, mode_skip_mask,
                                                  ref_frame_skip_mask))
       continue;
 
+    if (sf->prune_comp_search_by_single_result > 0 &&
+        second_ref_frame > INTRA_FRAME) {
+      if (compound_skip_by_single_states(cpi, &search_state, this_mode,
+                                         ref_frame, second_ref_frame))
+        continue;
+    }
+
     if (ref_frame == INTRA_FRAME) {
       if (sf->skip_intra_in_interframe && search_state.skip_intra_modes)
         continue;
@@ -10958,6 +11275,7 @@
         args.modelled_rd = search_state.modelled_rd;
         args.single_comp_cost = real_compmode_cost;
         args.ref_frame_cost = ref_frame_cost;
+        args.simple_rd = search_state.simple_rd;
 #if CONFIG_COLLECT_INTER_MODE_RD_STATS
         this_rd = handle_inter_mode(cpi, x, bsize, &rd_stats, &rd_stats_y,
                                     &rd_stats_uv, &disable_skip, mi_row, mi_col,
@@ -10975,6 +11293,11 @@
         rate_uv = rd_stats_uv.rate;
       }
 
+      if (sf->prune_comp_search_by_single_result > 0 &&
+          is_inter_singleref_mode(this_mode)) {
+        collect_single_states(x, &search_state, mbmi);
+      }
+
       if (this_rd == INT64_MAX) continue;
 
       this_skip2 = mbmi->skip;
diff --git a/av1/encoder/speed_features.c b/av1/encoder/speed_features.c
index 0e5a9c9..65ad9d6 100644
--- a/av1/encoder/speed_features.c
+++ b/av1/encoder/speed_features.c
@@ -241,6 +241,7 @@
     sf->allow_partition_search_skip = 1;
     sf->disable_wedge_search_var_thresh = 100;
     sf->fast_wedge_sign_estimate = 1;
+    sf->prune_comp_search_by_single_result = 1;
   }
 
   if (speed >= 3) {
@@ -258,6 +259,7 @@
     sf->adaptive_rd_thresh = 2;
     sf->tx_type_search.prune_mode = PRUNE_2D_FAST;
     sf->gm_search_type = GM_DISABLE_SEARCH;
+    sf->prune_comp_search_by_single_result = 2;
   }
 
   if (speed >= 4) {
@@ -502,6 +504,7 @@
   sf->use_fast_interpolation_filter_search = 0;
   sf->skip_repeat_interpolation_filter_search = 0;
   sf->use_hash_based_trellis = 0;
+  sf->prune_comp_search_by_single_result = 0;
 
   // Set decoder side speed feature to use less dual sgr modes
   sf->dual_sgr_penalty_level = 0;
diff --git a/av1/encoder/speed_features.h b/av1/encoder/speed_features.h
index 5a5230d..8ed3c40 100644
--- a/av1/encoder/speed_features.h
+++ b/av1/encoder/speed_features.h
@@ -631,6 +631,13 @@
 
   // Dynamically estimate final rd from prediction error and mode cost
   int inter_mode_rd_model_estimation;
+
+  // Skip some ref frames in compound motion search by single motion search
+  // result. Level 0 disables skipping; levels 1 - 3 skip with increasing
+  // aggressiveness, in that order.
+  // Note: The search order might affect the result. It is better to search
+  // the same single inter mode as a group.
+  int prune_comp_search_by_single_result;
 } SPEED_FEATURES;
 
 struct AV1_COMP;