Do masked motion search based on COMPOUND_TYPE

Change-Id: I2d1b5f57a3bb19eb8c00eb4c2e6c7835047dc4ac
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index ce93693..fa82409 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -202,6 +202,10 @@
           mode == NEAR_NEWMV || mode == NEW_NEARMV);
 }
 
+static INLINE int use_masked_motion_search(COMPOUND_TYPE type) {
+  return (type == COMPOUND_WEDGE);  // wedge only; seg skips masked MV search
+}
+
 static INLINE int is_masked_compound_type(COMPOUND_TYPE type) {
 #if CONFIG_COMPOUND_SEGMENT
   return (type == COMPOUND_WEDGE || type == COMPOUND_SEG);
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 13ed5ad..7001c92 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -7164,7 +7164,8 @@
   best_rd_cur = pick_interinter_seg_mask(cpi, x, bsize, *preds0, *preds1);
   best_rd_cur += RDCOST(x->rdmult, x->rddiv, rs2 + rate_mv, 0);
 
-  if (have_newmv_in_inter_mode(this_mode)) {
+  if (have_newmv_in_inter_mode(this_mode) &&
+      use_masked_motion_search(COMPOUND_SEG)) {
     *out_rate_mv = interinter_compound_motion_search(cpi, x, bsize, this_mode,
                                                      mi_row, mi_col);
     av1_build_inter_predictors_sby(xd, mi_row, mi_col, ctx, bsize);
@@ -7218,7 +7219,8 @@
   best_rd_cur = pick_interinter_wedge(cpi, x, bsize, *preds0, *preds1);
   best_rd_cur += RDCOST(x->rdmult, x->rddiv, rs2 + rate_mv, 0);
 
-  if (have_newmv_in_inter_mode(this_mode)) {
+  if (have_newmv_in_inter_mode(this_mode) &&
+      use_masked_motion_search(COMPOUND_WEDGE)) {
     *out_rate_mv = interinter_compound_motion_search(cpi, x, bsize, this_mode,
                                                      mi_row, mi_col);
     av1_build_inter_predictors_sby(xd, mi_row, mi_col, ctx, bsize);
@@ -7762,6 +7764,7 @@
     uint8_t *preds0[1] = { pred0 };
     uint8_t *preds1[1] = { pred1 };
     int strides[1] = { bw };
+    int tmp_rate_mv;
     COMPOUND_TYPE cur_type;
 
     best_mv[0].as_int = cur_mv[0].as_int;
@@ -7779,6 +7782,7 @@
     }
 
     for (cur_type = COMPOUND_AVERAGE; cur_type < COMPOUND_TYPES; cur_type++) {
+      tmp_rate_mv = rate_mv;
       best_rd_cur = INT64_MAX;
       mbmi->interinter_compound_data.type = cur_type;
       rs2 = av1_cost_literal(get_interinter_compound_type_bits(
@@ -7793,32 +7797,17 @@
                                    &tmp_skip_txfm_sb, &tmp_skip_sse_sb,
                                    INT64_MAX);
           if (rd != INT64_MAX)
-            rd =
+            best_rd_cur =
                 RDCOST(x->rdmult, x->rddiv, rs2 + rate_mv + rate_sum, dist_sum);
-          best_rd_compound = rd;
+          best_rd_compound = best_rd_cur;
           break;
         case COMPOUND_WEDGE:
           if (!is_interinter_wedge_used(bsize)) break;
           if (x->source_variance > cpi->sf.disable_wedge_search_var_thresh &&
               best_rd_compound / 3 < ref_best_rd) {
-            int tmp_rate_mv = 0;
             best_rd_cur = build_and_cost_compound_wedge(
                 cpi, x, cur_mv, bsize, this_mode, rs2, rate_mv, &orig_dst,
                 &tmp_rate_mv, preds0, preds1, strides, mi_row, mi_col);
-
-            if (best_rd_cur < best_rd_compound) {
-              best_rd_compound = best_rd_cur;
-              memcpy(&best_compound_data, &mbmi->interinter_compound_data,
-                     sizeof(best_compound_data));
-              if (have_newmv_in_inter_mode(this_mode)) {
-                best_tmp_rate_mv = tmp_rate_mv;
-                best_mv[0].as_int = mbmi->mv[0].as_int;
-                best_mv[1].as_int = mbmi->mv[1].as_int;
-                // reset to original mvs for next iteration
-                mbmi->mv[0].as_int = cur_mv[0].as_int;
-                mbmi->mv[1].as_int = cur_mv[1].as_int;
-              }
-            }
           }
           break;
 #if CONFIG_COMPOUND_SEGMENT
@@ -7826,29 +7815,33 @@
           if (!is_interinter_wedge_used(bsize)) break;
           if (x->source_variance > cpi->sf.disable_wedge_search_var_thresh &&
               best_rd_compound / 3 < ref_best_rd) {
-            int tmp_rate_mv = 0;
             best_rd_cur = build_and_cost_compound_seg(
                 cpi, x, cur_mv, bsize, this_mode, rs2, rate_mv, &orig_dst,
                 &tmp_rate_mv, preds0, preds1, strides, mi_row, mi_col);
-
-            if (best_rd_cur < best_rd_compound) {
-              best_rd_compound = best_rd_cur;
-              memcpy(&best_compound_data, &mbmi->interinter_compound_data,
-                     sizeof(best_compound_data));
-              if (have_newmv_in_inter_mode(this_mode)) {
-                best_tmp_rate_mv = tmp_rate_mv;
-                best_mv[0].as_int = mbmi->mv[0].as_int;
-                best_mv[1].as_int = mbmi->mv[1].as_int;
-                // reset to original mvs for next iteration
-                mbmi->mv[0].as_int = cur_mv[0].as_int;
-                mbmi->mv[1].as_int = cur_mv[1].as_int;
-              }
-            }
           }
           break;
 #endif  // CONFIG_COMPOUND_SEGMENT
         default: assert(0); return 0;
       }
+
+      if (best_rd_cur < best_rd_compound) {
+        best_rd_compound = best_rd_cur;
+        memcpy(&best_compound_data, &mbmi->interinter_compound_data,
+               sizeof(best_compound_data));
+        if (have_newmv_in_inter_mode(this_mode)) {
+          if (use_masked_motion_search(cur_type)) {
+            best_tmp_rate_mv = tmp_rate_mv;
+            best_mv[0].as_int = mbmi->mv[0].as_int;
+            best_mv[1].as_int = mbmi->mv[1].as_int;
+          } else {
+            best_mv[0].as_int = cur_mv[0].as_int;
+            best_mv[1].as_int = cur_mv[1].as_int;
+          }
+        }
+      }
+      // reset to original mvs for next iteration
+      mbmi->mv[0].as_int = cur_mv[0].as_int;
+      mbmi->mv[1].as_int = cur_mv[1].as_int;
     }
     memcpy(&mbmi->interinter_compound_data, &best_compound_data,
            sizeof(INTERINTER_COMPOUND_DATA));
@@ -7857,7 +7850,7 @@
       mbmi->mv[1].as_int = best_mv[1].as_int;
       xd->mi[0]->bmi[0].as_mv[0].as_int = mbmi->mv[0].as_int;
       xd->mi[0]->bmi[0].as_mv[1].as_int = mbmi->mv[1].as_int;
-      if (mbmi->interinter_compound_data.type) {
+      if (use_masked_motion_search(mbmi->interinter_compound_data.type)) {
         rd_stats->rate += best_tmp_rate_mv - rate_mv;
         rate_mv = best_tmp_rate_mv;
       }