Refactor compound search loop

Refactors the two-loop compound search into a generic framework in
which any set of modes can be searched in the first pass and the
remaining modes in the second pass.

Also, whether to use the 2-pass or 1-pass search is now a speed
feature (two_loop_comp_search). When turned on, it gives a 4%
speedup with ~0.08% loss, as tested on lowres, midres, and
AWCY-Objective-1-Fast.

Change-Id: Iaca3cc6ac36193cbfe34805795c51024df538ee8
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index e6e4da9..a505320 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -9736,7 +9736,7 @@
 
 static int compound_type_rd(
     const AV1_COMP *const cpi, MACROBLOCK *x, BLOCK_SIZE bsize, int mi_col,
-    int mi_row, int_mv *cur_mv, int do_comp_distwtd, int masked_compound_used,
+    int mi_row, int_mv *cur_mv, int mode_search_mask, int masked_compound_used,
     const BUFFER_SET *orig_dst, const BUFFER_SET *tmp_dst,
     CompoundTypeRdBuffers *buffers, int *rate_mv, int64_t *rd,
     RD_STATS *rd_stats, int64_t ref_best_rd, int *is_luma_interp_done) {
@@ -9773,10 +9773,10 @@
     if (cur_type >= COMPOUND_WEDGE && !masked_compound_used) break;
     if (!is_interinter_compound_used(cur_type, bsize)) continue;
     if (cur_type == COMPOUND_DISTWTD &&
-        (!do_comp_distwtd ||
-         !cm->seq_params.order_hint_info.enable_dist_wtd_comp ||
+        (!cm->seq_params.order_hint_info.enable_dist_wtd_comp ||
          cpi->sf.use_dist_wtd_comp_flag == DIST_WTD_COMP_DISABLED))
       continue;
+    if (((1 << cur_type) & mode_search_mask) == 0) continue;
     tmp_rate_mv = *rate_mv;
     int64_t best_rd_cur = INT64_MAX;
     mbmi->interinter_comp.type = cur_type;
@@ -9822,7 +9822,7 @@
         }
       }
       // use spare buffer for following compound type try
-      restore_dst_buf(xd, *tmp_dst, 1);
+      if (cur_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
     } else {
       mbmi->comp_group_idx = 1;
       mbmi->compound_idx = 1;
@@ -9926,8 +9926,6 @@
   int_mv mv;
 } inter_mode_info;
 
-#define SEPARATE_COMP_DISTWTD_RD 1
-
 static int64_t handle_inter_mode(
     const AV1_COMP *const cpi, TileDataEnc *tile_data, MACROBLOCK *x,
     BLOCK_SIZE bsize, RD_STATS *rd_stats, RD_STATS *rd_stats_y,
@@ -9981,14 +9979,11 @@
   inter_mode_info mode_info[MAX_REF_MV_SERCH];
 
   int comp_idx;
-#if SEPARATE_COMP_DISTWTD_RD
   const int search_dist_wtd_comp =
       is_comp_pred & cm->seq_params.order_hint_info.enable_dist_wtd_comp &
       (mbmi->mode != GLOBAL_GLOBALMV) &
-      (cpi->sf.use_dist_wtd_comp_flag != DIST_WTD_COMP_DISABLED);
-#else
-  const int search_dist_wtd_comp = 0;
-#endif  // SEPARATE_COMP_DISTWTD_RD
+      (cpi->sf.use_dist_wtd_comp_flag != DIST_WTD_COMP_DISABLED) &
+      cpi->sf.two_loop_comp_search;
 
   // TODO(jingning): This should be deprecated shortly.
   const int has_nearmv = have_nearmv_in_inter_mode(mbmi->mode) ? 1 : 0;
@@ -10044,12 +10039,13 @@
       continue;
     }
 
+    mbmi->comp_group_idx = 0;
+    mbmi->compound_idx = 1;
     const RD_STATS backup_rd_stats = *rd_stats;
     // If !search_dist_wtd_comp, we need to force mbmi->compound_idx = 1.
     for (comp_idx = 1; comp_idx >= !search_dist_wtd_comp; --comp_idx) {
       int rs = 0;
       int compmode_interinter_cost = 0;
-      mbmi->compound_idx = comp_idx;
 
       if (is_comp_pred && comp_idx == 0) *rd_stats = backup_rd_stats;
 
@@ -10070,10 +10066,7 @@
           newmv_ret_val = args->single_newmv_valid[ref_mv_idx][ref0] ? 0 : 1;
           cur_mv[0] = args->single_newmv[ref_mv_idx][ref0];
           rate_mv = args->single_newmv_rate[ref_mv_idx][ref0];
-        } else if (!(search_dist_wtd_comp &&
-                     (cpi->sf.use_dist_wtd_comp_flag ==
-                      DIST_WTD_COMP_SKIP_MV_SEARCH) &&
-                     comp_idx == 0)) {
+        } else if (comp_idx == 1) {
           newmv_ret_val = handle_newmv(cpi, x, bsize, cur_mv, mi_row, mi_col,
                                        &rate_mv, args);
 
@@ -10179,17 +10172,30 @@
 
       int skip_build_pred = 0;
       if (is_comp_pred) {
-        if (comp_idx == 0) {
-          mbmi->interinter_comp.type = COMPOUND_DISTWTD;
+        int mode_search_mask;
+        if (cpi->sf.two_loop_comp_search) {
+          mode_search_mask = comp_idx ? (1 << COMPOUND_AVERAGE) |
+                                            (1 << COMPOUND_WEDGE) |
+                                            (1 << COMPOUND_DIFFWTD)
+                                      : (1 << COMPOUND_DISTWTD);
+        } else {
+          mode_search_mask = (1 << COMPOUND_AVERAGE) | (1 << COMPOUND_WEDGE) |
+                             (1 << COMPOUND_DIFFWTD) | (1 << COMPOUND_DISTWTD);
+        }
+        if (mode_search_mask == (1 << COMPOUND_DISTWTD) ||
+            mode_search_mask == (1 << COMPOUND_AVERAGE)) {
+          mbmi->interinter_comp.type =
+              (mode_search_mask == (1 << COMPOUND_AVERAGE)) ? COMPOUND_AVERAGE
+                                                            : COMPOUND_DISTWTD;
           mbmi->num_proj_ref = 0;
           mbmi->motion_mode = SIMPLE_TRANSLATION;
           mbmi->comp_group_idx = 0;
+          mbmi->compound_idx = (mode_search_mask == (1 << COMPOUND_AVERAGE));
 
           const int comp_index_ctx = get_comp_index_context(cm, xd);
           compmode_interinter_cost +=
               x->comp_idx_cost[comp_index_ctx][mbmi->compound_idx];
-        }
-        if (comp_idx) {
+        } else {
           // Find matching interp filter or set to default interp filter
           const int need_search =
               av1_is_interp_needed(xd) && av1_is_interp_search_needed(xd);
@@ -10204,13 +10210,13 @@
           }
 
           int64_t best_rd_compound;
-          const int do_comp_distwtd = !SEPARATE_COMP_DISTWTD_RD;
           compmode_interinter_cost = compound_type_rd(
-              cpi, x, bsize, mi_col, mi_row, cur_mv, do_comp_distwtd,
+              cpi, x, bsize, mi_col, mi_row, cur_mv, mode_search_mask,
               masked_compound_used, &orig_dst, &tmp_dst, rd_buffers, &rate_mv,
               &best_rd_compound, rd_stats, ref_best_rd, &is_luma_interp_done);
           if (ref_best_rd < INT64_MAX &&
-              (best_rd_compound >> 4) * (11 + 2 * SEPARATE_COMP_DISTWTD_RD) >
+              (best_rd_compound >> 4) *
+                      (11 + 2 * cpi->sf.two_loop_comp_search) >
                   ref_best_rd) {
             restore_dst_buf(xd, orig_dst, num_planes);
             continue;
@@ -10266,8 +10272,7 @@
       }
       rd_stats->rate += compmode_interinter_cost;
 
-      if (search_dist_wtd_comp && cpi->sf.dist_wtd_comp_fast_tx_search &&
-          comp_idx == 0) {
+      if (cpi->sf.second_loop_comp_fast_tx_search && comp_idx == 0) {
         // TODO(chengchen): this speed feature introduces big loss.
         // Need better estimation of rate distortion.
         int dummy_rate;
diff --git a/av1/encoder/speed_features.c b/av1/encoder/speed_features.c
index 5e9eda3..b5121f0 100644
--- a/av1/encoder/speed_features.c
+++ b/av1/encoder/speed_features.c
@@ -719,7 +719,8 @@
   sf->use_inter_txb_hash = 0;
   sf->use_mb_rd_hash = 1;
   sf->optimize_b_precheck = 0;
-  sf->dist_wtd_comp_fast_tx_search = 0;
+  sf->two_loop_comp_search = 1;
+  sf->second_loop_comp_fast_tx_search = 0;
   sf->use_dist_wtd_comp_flag = DIST_WTD_COMP_ENABLED;
   sf->reuse_inter_intra_mode = 0;
   sf->intra_angle_estimation = 0;
diff --git a/av1/encoder/speed_features.h b/av1/encoder/speed_features.h
index 66c096d..f0816f8 100644
--- a/av1/encoder/speed_features.h
+++ b/av1/encoder/speed_features.h
@@ -562,8 +562,11 @@
   // Calculate RD cost before doing optimize_b, and skip if the cost is large.
   int optimize_b_precheck;
 
-  // Use model rd instead of transform search in dist_wtd_comp
-  int dist_wtd_comp_fast_tx_search;
+  // Use two-loop compound search
+  int two_loop_comp_search;
+
+  // Use model rd instead of transform search in second loop of compound search
+  int second_loop_comp_fast_tx_search;
 
   // Decide when and how to use joint_comp.
   DIST_WTD_COMP_FLAG use_dist_wtd_comp_flag;