ext-inter: Use joint_motion_search for masked compounds

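Add functions that take both components of a masked compound and
compute the resulting SAD/SSE. Extend joint_motion_search to understand
masked compounds, and use it for the NEW_NEWMV motion search when a
masked compound type is selected, instead of running two independent
single-reference masked searches.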

Change-Id: I782199a20d119a6c61c6567df157508125ac7ce7
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index a4aff3c..63e594a 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -5452,7 +5452,8 @@
                                 BLOCK_SIZE bsize, int_mv *frame_mv, int mi_row,
                                 int mi_col,
 #if CONFIG_EXT_INTER
-                                int_mv *ref_mv_sub8x8[2],
+                                int_mv *ref_mv_sub8x8[2], const uint8_t *mask,
+                                int mask_stride,
 #endif  // CONFIG_EXT_INTER
                                 int *rate_mv, const int block) {
   const AV1_COMMON *const cm = &cpi->common;
@@ -5618,10 +5619,21 @@
     // Small-range full-pixel motion search.
     bestsme =
         av1_refining_search_8p_c(x, sadpb, search_range, &cpi->fn_ptr[bsize],
+#if CONFIG_EXT_INTER
+                                 mask, mask_stride, id,
+#endif
                                  &ref_mv[id].as_mv, second_pred);
-    if (bestsme < INT_MAX)
-      bestsme = av1_get_mvpred_av_var(x, best_mv, &ref_mv[id].as_mv,
-                                      second_pred, &cpi->fn_ptr[bsize], 1);
+    if (bestsme < INT_MAX) {
+#if CONFIG_EXT_INTER
+      if (mask)
+        bestsme = av1_get_mvpred_mask_var(x, best_mv, &ref_mv[id].as_mv,
+                                          second_pred, mask, mask_stride, id,
+                                          &cpi->fn_ptr[bsize], 1);
+      else
+#endif
+        bestsme = av1_get_mvpred_av_var(x, best_mv, &ref_mv[id].as_mv,
+                                        second_pred, &cpi->fn_ptr[bsize], 1);
+    }
 
     x->mv_limits = tmp_mv_limits;
 
@@ -5654,7 +5666,11 @@
             x, &ref_mv[id].as_mv, cpi->common.allow_high_precision_mv,
             x->errorperbit, &cpi->fn_ptr[bsize], 0,
             cpi->sf.mv.subpel_iters_per_step, NULL, x->nmvjointcost, x->mvcost,
-            &dis, &sse, second_pred, pw, ph, 1);
+            &dis, &sse, second_pred,
+#if CONFIG_EXT_INTER
+            mask, mask_stride, id,
+#endif
+            pw, ph, 1);
 
         // Restore the reference frames.
         pd->pre[0] = backup_pred;
@@ -5664,7 +5680,11 @@
             x, &ref_mv[id].as_mv, cpi->common.allow_high_precision_mv,
             x->errorperbit, &cpi->fn_ptr[bsize], 0,
             cpi->sf.mv.subpel_iters_per_step, NULL, x->nmvjointcost, x->mvcost,
-            &dis, &sse, second_pred, pw, ph, 0);
+            &dis, &sse, second_pred,
+#if CONFIG_EXT_INTER
+            mask, mask_stride, id,
+#endif
+            pw, ph, 0);
       }
     }
 
@@ -6060,8 +6080,11 @@
                   cpi->sf.mv.subpel_force_stop,
                   cpi->sf.mv.subpel_iters_per_step,
                   cond_cost_list(cpi, cost_list), x->nmvjointcost, x->mvcost,
-                  &distortion, &x->pred_sse[mbmi->ref_frame[0]], NULL, pw, ph,
-                  1);
+                  &distortion, &x->pred_sse[mbmi->ref_frame[0]], NULL,
+#if CONFIG_EXT_INTER
+                  NULL, 0, 0,
+#endif
+                  pw, ph, 1);
 
               if (try_second) {
                 int this_var;
@@ -6088,7 +6111,11 @@
                       cpi->sf.mv.subpel_iters_per_step,
                       cond_cost_list(cpi, cost_list), x->nmvjointcost,
                       x->mvcost, &distortion, &x->pred_sse[mbmi->ref_frame[0]],
-                      NULL, pw, ph, 1);
+                      NULL,
+#if CONFIG_EXT_INTER
+                      NULL, 0, 0,
+#endif
+                      pw, ph, 1);
                   if (this_var < best_mv_var) best_mv = x->best_mv.as_mv;
                   x->best_mv.as_mv = best_mv;
                 }
@@ -6103,7 +6130,11 @@
                   cpi->sf.mv.subpel_force_stop,
                   cpi->sf.mv.subpel_iters_per_step,
                   cond_cost_list(cpi, cost_list), x->nmvjointcost, x->mvcost,
-                  &distortion, &x->pred_sse[mbmi->ref_frame[0]], NULL, 0, 0, 0);
+                  &distortion, &x->pred_sse[mbmi->ref_frame[0]], NULL,
+#if CONFIG_EXT_INTER
+                  NULL, 0, 0,
+#endif
+                  0, 0, 0);
             }
 
 // save motion search result for use in compound prediction
@@ -6165,7 +6196,7 @@
             joint_motion_search(cpi, x, bsize, frame_mv[this_mode], mi_row,
                                 mi_col,
 #if CONFIG_EXT_INTER
-                                bsi->ref_mv,
+                                bsi->ref_mv, NULL, 0,
 #endif  // CONFIG_EXT_INTER
                                 &rate_mv, index);
 #if CONFIG_EXT_INTER
@@ -6958,8 +6989,11 @@
               x, &ref_mv, cm->allow_high_precision_mv, x->errorperbit,
               &cpi->fn_ptr[bsize], cpi->sf.mv.subpel_force_stop,
               cpi->sf.mv.subpel_iters_per_step, cond_cost_list(cpi, cost_list),
-              x->nmvjointcost, x->mvcost, &dis, &x->pred_sse[ref], NULL, pw, ph,
-              1);
+              x->nmvjointcost, x->mvcost, &dis, &x->pred_sse[ref], NULL,
+#if CONFIG_EXT_INTER
+              NULL, 0, 0,
+#endif
+              pw, ph, 1);
 
           if (try_second) {
             const int minc =
@@ -6983,7 +7017,11 @@
                   &cpi->fn_ptr[bsize], cpi->sf.mv.subpel_force_stop,
                   cpi->sf.mv.subpel_iters_per_step,
                   cond_cost_list(cpi, cost_list), x->nmvjointcost, x->mvcost,
-                  &dis, &x->pred_sse[ref], NULL, pw, ph, 1);
+                  &dis, &x->pred_sse[ref], NULL,
+#if CONFIG_EXT_INTER
+                  NULL, 0, 0,
+#endif
+                  pw, ph, 1);
               if (this_var < best_mv_var) best_mv = x->best_mv.as_mv;
               x->best_mv.as_mv = best_mv;
             }
@@ -6996,8 +7034,11 @@
               x, &ref_mv, cm->allow_high_precision_mv, x->errorperbit,
               &cpi->fn_ptr[bsize], cpi->sf.mv.subpel_force_stop,
               cpi->sf.mv.subpel_iters_per_step, cond_cost_list(cpi, cost_list),
-              x->nmvjointcost, x->mvcost, &dis, &x->pred_sse[ref], NULL, 0, 0,
-              0);
+              x->nmvjointcost, x->mvcost, &dis, &x->pred_sse[ref], NULL,
+#if CONFIG_EXT_INTER
+              NULL, 0, 0,
+#endif
+              0, 0, 0);
         }
 #if CONFIG_MOTION_VAR
         break;
@@ -7161,7 +7202,7 @@
 }
 
 static void do_masked_motion_search_indexed(
-    const AV1_COMP *const cpi, MACROBLOCK *x,
+    const AV1_COMP *const cpi, MACROBLOCK *x, const int_mv *const cur_mv,
     const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE bsize,
     int mi_row, int mi_col, int_mv *tmp_mv, int *rate_mv, int which) {
   // NOTE: which values: 0 - 0 only, 1 - 1 only, 2 - both
@@ -7173,11 +7214,21 @@
 
   mask = av1_get_compound_type_mask(comp_data, sb_type);
 
-  if (which == 0 || which == 2)
-    do_masked_motion_search(cpi, x, mask, mask_stride, bsize, mi_row, mi_col,
-                            &tmp_mv[0], &rate_mv[0], 0);
+  if (which == 2) {
+    int_mv frame_mv[TOTAL_REFS_PER_FRAME];
+    MV_REFERENCE_FRAME rf[2] = { mbmi->ref_frame[0], mbmi->ref_frame[1] };
+    assert(bsize >= BLOCK_8X8 || CONFIG_CB4X4);
 
-  if (which == 1 || which == 2) {
+    frame_mv[rf[0]].as_int = cur_mv[0].as_int;
+    frame_mv[rf[1]].as_int = cur_mv[1].as_int;
+    joint_motion_search(cpi, x, bsize, frame_mv, mi_row, mi_col, NULL, mask,
+                        mask_stride, rate_mv, 0);
+    tmp_mv[0].as_int = frame_mv[rf[0]].as_int;
+    tmp_mv[1].as_int = frame_mv[rf[1]].as_int;
+  } else if (which == 0) {
+    do_masked_motion_search(cpi, x, mask, mask_stride, bsize, mi_row, mi_col,
+                            &tmp_mv[0], rate_mv, 0);
+  } else if (which == 1) {
 // get the negative mask
 #if CONFIG_COMPOUND_SEGMENT
     uint8_t inv_mask_buf[2 * MAX_SB_SQUARE];
@@ -7188,7 +7239,7 @@
     mask = av1_get_compound_type_mask_inverse(comp_data, sb_type);
 #endif  // CONFIG_COMPOUND_SEGMENT
     do_masked_motion_search(cpi, x, mask, mask_stride, bsize, mi_row, mi_col,
-                            &tmp_mv[1], &rate_mv[1], 1);
+                            &tmp_mv[1], rate_mv, 1);
   }
 }
 #endif  // CONFIG_COMPOUND_SEGMENT || CONFIG_WEDGE
@@ -7665,15 +7716,13 @@
   }
 }
 
-static int interinter_compound_motion_search(const AV1_COMP *const cpi,
-                                             MACROBLOCK *x,
-                                             const BLOCK_SIZE bsize,
-                                             const int this_mode, int mi_row,
-                                             int mi_col) {
+static int interinter_compound_motion_search(
+    const AV1_COMP *const cpi, MACROBLOCK *x, const int_mv *const cur_mv,
+    const BLOCK_SIZE bsize, const int this_mode, int mi_row, int mi_col) {
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
   int_mv tmp_mv[2];
-  int rate_mvs[2], tmp_rate_mv = 0;
+  int tmp_rate_mv = 0;
   const INTERINTER_COMPOUND_DATA compound_data = {
 #if CONFIG_WEDGE
     mbmi->wedge_index,
@@ -7686,20 +7735,17 @@
     mbmi->interinter_compound_type
   };
   if (this_mode == NEW_NEWMV) {
-    do_masked_motion_search_indexed(cpi, x, &compound_data, bsize, mi_row,
-                                    mi_col, tmp_mv, rate_mvs, 2);
-    tmp_rate_mv = rate_mvs[0] + rate_mvs[1];
+    do_masked_motion_search_indexed(cpi, x, cur_mv, &compound_data, bsize,
+                                    mi_row, mi_col, tmp_mv, &tmp_rate_mv, 2);
     mbmi->mv[0].as_int = tmp_mv[0].as_int;
     mbmi->mv[1].as_int = tmp_mv[1].as_int;
   } else if (this_mode == NEW_NEARESTMV || this_mode == NEW_NEARMV) {
-    do_masked_motion_search_indexed(cpi, x, &compound_data, bsize, mi_row,
-                                    mi_col, tmp_mv, rate_mvs, 0);
-    tmp_rate_mv = rate_mvs[0];
+    do_masked_motion_search_indexed(cpi, x, cur_mv, &compound_data, bsize,
+                                    mi_row, mi_col, tmp_mv, &tmp_rate_mv, 0);
     mbmi->mv[0].as_int = tmp_mv[0].as_int;
   } else if (this_mode == NEAREST_NEWMV || this_mode == NEAR_NEWMV) {
-    do_masked_motion_search_indexed(cpi, x, &compound_data, bsize, mi_row,
-                                    mi_col, tmp_mv, rate_mvs, 1);
-    tmp_rate_mv = rate_mvs[1];
+    do_masked_motion_search_indexed(cpi, x, cur_mv, &compound_data, bsize,
+                                    mi_row, mi_col, tmp_mv, &tmp_rate_mv, 1);
     mbmi->mv[1].as_int = tmp_mv[1].as_int;
   }
   return tmp_rate_mv;
@@ -7726,8 +7772,8 @@
 
   if (have_newmv_in_inter_mode(this_mode) &&
       use_masked_motion_search(compound_type)) {
-    *out_rate_mv = interinter_compound_motion_search(cpi, x, bsize, this_mode,
-                                                     mi_row, mi_col);
+    *out_rate_mv = interinter_compound_motion_search(cpi, x, cur_mv, bsize,
+                                                     this_mode, mi_row, mi_col);
     av1_build_inter_predictors_sby(cm, xd, mi_row, mi_col, ctx, bsize);
     model_rd_for_sb(cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum,
                     &tmp_skip_txfm_sb, &tmp_skip_sse_sb);
@@ -7823,8 +7869,8 @@
       frame_mv[refs[1]].as_int = single_newmv[refs[1]].as_int;
 
       if (cpi->sf.comp_inter_joint_search_thresh <= bsize) {
-        joint_motion_search(cpi, x, bsize, frame_mv, mi_row, mi_col, NULL,
-                            rate_mv, 0);
+        joint_motion_search(cpi, x, bsize, frame_mv, mi_row, mi_col, NULL, NULL,
+                            0, rate_mv, 0);
       } else {
         *rate_mv = 0;
         for (i = 0; i < 2; ++i) {