Refactor ext-inter to loop through all masked modes in rdopt

Replace the separate interinter wedge fields in MB_MODE_INFO with a
single INTERINTER_COMPOUND_DATA struct and fetch masks through the new
av1_get_compound_type_mask(), so that rate-distortion optimization can
iterate over every masked compound type instead of special-casing
COMPOUND_WEDGE at each call site.

No change in performance.
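
The rdopt.c half of the change is not in this excerpt. As a rough
standalone sketch of the search shape the struct enables (the stub
types mirror the patch; try_compound_type() is invented for
illustration):

    #include <stdio.h>

    typedef enum {
      COMPOUND_AVERAGE,
      COMPOUND_WEDGE,
      COMPOUND_TYPES
    } COMPOUND_TYPE;

    typedef struct {
      int wedge_index;
      int wedge_sign;
      COMPOUND_TYPE type;
    } INTERINTER_COMPOUND_DATA;

    /* Hypothetical RD cost; the real evaluation lives in rdopt.c. */
    static double try_compound_type(const INTERINTER_COMPOUND_DATA *d) {
      return d->type == COMPOUND_WEDGE ? 1.0 : 2.0;
    }

    int main(void) {
      INTERINTER_COMPOUND_DATA best = { 0, 0, COMPOUND_AVERAGE };
      double best_rd = try_compound_type(&best);
      /* One loop over every masked type replaces the per-call-site
         COMPOUND_WEDGE special case. */
      for (int t = COMPOUND_AVERAGE + 1; t < COMPOUND_TYPES; ++t) {
        const INTERINTER_COMPOUND_DATA cur = { 0, 0, (COMPOUND_TYPE)t };
        const double rd = try_compound_type(&cur);
        if (rd < best_rd) {
          best_rd = rd;
          best = cur;
        }
      }
      printf("best compound type: %d (rd %.1f)\n", best.type, best_rd);
      return 0;
    }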

Change-Id: Ie105a7baf6a2c2258d3ef117e727957e4393f51b
diff --git a/av1/common/reconinter.c b/av1/common/reconinter.c
index bc77687..1ae85b0 100644
--- a/av1/common/reconinter.c
+++ b/av1/common/reconinter.c
@@ -251,6 +251,24 @@
   return mask;
 }
 
+// Get the mask for the given compound prediction type.
+// TODO(sarahparker) This currently only handles ext_inter wedge masks; it
+// needs to be extended for other masked compound experiments.
+#if CONFIG_EXT_INTER
+const uint8_t *av1_get_compound_type_mask(
+    const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE sb_type,
+    int invert) {
+  assert(is_masked_compound_type(comp_data->type));
+  switch (comp_data->type) {
+    case COMPOUND_WEDGE:
+      return av1_get_contiguous_soft_mask(
+          comp_data->wedge_index,
+          invert ? !comp_data->wedge_sign : comp_data->wedge_sign, sb_type);
+    default: assert(0); return NULL;
+  }
+}
+#endif  // CONFIG_EXT_INTER
+
 static void init_wedge_master_masks() {
   int i, j, s;
   const int w = MASK_MASTER_SIZE;
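
The invert argument above only flips the wedge sign. For the wedge
codebooks the two signs are meant to be complementary: at each pixel the
two mask values sum to the maximum blend weight of 64, so the inverse
mask selects the other prediction's weights. A self-contained sketch of
that identity (the sample mask row is made up; real rows come from
av1_get_contiguous_soft_mask()):

    #include <assert.h>
    #include <stdio.h>

    #define MAX_ALPHA 64 /* AOM_BLEND_A64_MAX_ALPHA */

    int main(void) {
      /* Made-up soft-mask row standing in for a wedge table row. */
      const unsigned char mask[8] = { 0, 4, 13, 27, 37, 51, 60, 64 };
      unsigned char inv[8];
      for (int i = 0; i < 8; ++i) {
        inv[i] = (unsigned char)(MAX_ALPHA - mask[i]); /* opposite sign */
        assert(mask[i] + inv[i] == MAX_ALPHA); /* weights sum to 64 */
      }
      printf("mask and inverse are complementary\n");
      return 0;
    }
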
@@ -378,17 +396,16 @@
 #endif  // CONFIG_AOM_HIGHBITDEPTH
 #endif  // CONFIG_SUPERTX
 
-static void build_masked_compound_wedge(uint8_t *dst, int dst_stride,
-                                        const uint8_t *src0, int src0_stride,
-                                        const uint8_t *src1, int src1_stride,
-                                        int wedge_index, int wedge_sign,
-                                        BLOCK_SIZE sb_type, int h, int w) {
+static void build_masked_compound(
+    uint8_t *dst, int dst_stride, const uint8_t *src0, int src0_stride,
+    const uint8_t *src1, int src1_stride,
+    const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE sb_type, int h,
+    int w) {
   // Derive subsampling from h and w passed in. May be refactored to
   // pass in subsampling factors directly.
   const int subh = (2 << b_height_log2_lookup[sb_type]) == h;
   const int subw = (2 << b_width_log2_lookup[sb_type]) == w;
-  const uint8_t *mask =
-      av1_get_contiguous_soft_mask(wedge_index, wedge_sign, sb_type);
+  const uint8_t *mask = av1_get_compound_type_mask(comp_data, sb_type, 0);
   aom_blend_a64_mask(dst, dst_stride, src0, src0_stride, src1, src1_stride,
                      mask, block_size_wide[sb_type], h, w, subh, subw);
 }
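
build_masked_compound() bottoms out in aom_blend_a64_mask() (the highbd
variant below in its aom_highbd_blend_a64_mask() counterpart); per pixel
the blend is a 6-bit fixed-point weighted average. A minimal scalar
model of that arithmetic (the real kernels also handle strides, chroma
subsampling, and SIMD dispatch):

    #include <stdint.h>
    #include <stdio.h>

    #define MAX_ALPHA 64 /* AOM_BLEND_A64_MAX_ALPHA */
    #define ROUND_BITS 6 /* log2(MAX_ALPHA) */

    /* dst = (m * src0 + (64 - m) * src1 + 32) >> 6, per pixel. */
    static uint8_t blend_a64(uint8_t m, uint8_t v0, uint8_t v1) {
      const int v = m * v0 + (MAX_ALPHA - m) * v1;
      return (uint8_t)((v + (1 << (ROUND_BITS - 1))) >> ROUND_BITS);
    }

    int main(void) {
      /* m = 64 picks src0, m = 0 picks src1, m = 32 averages. */
      printf("%d %d %d\n", blend_a64(64, 200, 100),
             blend_a64(0, 200, 100), blend_a64(32, 200, 100));
      return 0;
    }
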
@@ -396,14 +413,14 @@
 #if CONFIG_AOM_HIGHBITDEPTH
 static void build_masked_compound_wedge_highbd(
     uint8_t *dst_8, int dst_stride, const uint8_t *src0_8, int src0_stride,
-    const uint8_t *src1_8, int src1_stride, int wedge_index, int wedge_sign,
-    BLOCK_SIZE sb_type, int h, int w, int bd) {
+    const uint8_t *src1_8, int src1_stride,
+    const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE sb_type, int h,
+    int w, int bd) {
   // Derive subsampling from h and w passed in. May be refactored to
   // pass in subsampling factors directly.
   const int subh = (2 << b_height_log2_lookup[sb_type]) == h;
   const int subw = (2 << b_width_log2_lookup[sb_type]) == w;
-  const uint8_t *mask =
-      av1_get_contiguous_soft_mask(wedge_index, wedge_sign, sb_type);
+  const uint8_t *mask = av1_get_compound_type_mask(comp_data, sb_type, 0);
   aom_highbd_blend_a64_mask(dst_8, dst_stride, src0_8, src0_stride, src1_8,
                             src1_stride, mask, block_size_wide[sb_type], h, w,
                             subh, subw, bd);
@@ -426,6 +442,8 @@
 #endif  // CONFIG_SUPERTX
                                      const MACROBLOCKD *xd) {
   const MODE_INFO *mi = xd->mi[0];
+  const INTERINTER_COMPOUND_DATA *const comp_data =
+      &mi->mbmi.interinter_compound_data;
 // The prediction filter types used here should be those for
 // the second reference block.
 #if CONFIG_DUAL_FILTER
@@ -446,39 +464,35 @@
   if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
     build_masked_compound_wedge_extend_highbd(
         dst, dst_stride, dst, dst_stride, tmp_dst, MAX_SB_SIZE,
-        mi->mbmi.interinter_wedge_index, mi->mbmi.interinter_wedge_sign,
-        mi->mbmi.sb_type, wedge_offset_x, wedge_offset_y, h, w, xd->bd);
+        comp_data->wedge_index, comp_data->wedge_sign, mi->mbmi.sb_type,
+        wedge_offset_x, wedge_offset_y, h, w, xd->bd);
   else
     build_masked_compound_wedge_extend(
         dst, dst_stride, dst, dst_stride, tmp_dst, MAX_SB_SIZE,
-        mi->mbmi.interinter_wedge_index, mi->mbmi.interinter_wedge_sign,
-        mi->mbmi.sb_type, wedge_offset_x, wedge_offset_y, h, w);
+        comp_data->wedge_index, comp_data->wedge_sign, mi->mbmi.sb_type,
+        wedge_offset_x, wedge_offset_y, h, w);
 #else
   if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
     build_masked_compound_wedge_highbd(
         dst, dst_stride, dst, dst_stride, tmp_dst, MAX_SB_SIZE,
-        mi->mbmi.interinter_wedge_index, mi->mbmi.interinter_wedge_sign,
-        mi->mbmi.sb_type, h, w, xd->bd);
+        comp_data, mi->mbmi.sb_type, h, w, xd->bd);
   else
-    build_masked_compound_wedge(dst, dst_stride, dst, dst_stride, tmp_dst,
-                                MAX_SB_SIZE, mi->mbmi.interinter_wedge_index,
-                                mi->mbmi.interinter_wedge_sign,
-                                mi->mbmi.sb_type, h, w);
+    build_masked_compound(dst, dst_stride, dst, dst_stride, tmp_dst,
+                          MAX_SB_SIZE, comp_data, mi->mbmi.sb_type, h, w);
 #endif  // CONFIG_SUPERTX
 #else   // CONFIG_AOM_HIGHBITDEPTH
   DECLARE_ALIGNED(16, uint8_t, tmp_dst[MAX_SB_SQUARE]);
   av1_make_inter_predictor(pre, pre_stride, tmp_dst, MAX_SB_SIZE, subpel_x,
                            subpel_y, sf, w, h, 0, tmp_ipf, xs, ys, xd);
 #if CONFIG_SUPERTX
-  build_masked_compound_wedge_extend(
-      dst, dst_stride, dst, dst_stride, tmp_dst, MAX_SB_SIZE,
-      mi->mbmi.interinter_wedge_index, mi->mbmi.interinter_wedge_sign,
-      mi->mbmi.sb_type, wedge_offset_x, wedge_offset_y, h, w);
+  build_masked_compound_wedge_extend(dst, dst_stride, dst, dst_stride, tmp_dst,
+                                     MAX_SB_SIZE, comp_data->wedge_index,
+                                     comp_data->wedge_sign, mi->mbmi.sb_type,
+                                     wedge_offset_x, wedge_offset_y, h, w);
 #else
-  build_masked_compound_wedge(dst, dst_stride, dst, dst_stride, tmp_dst,
-                              MAX_SB_SIZE, mi->mbmi.interinter_wedge_index,
-                              mi->mbmi.interinter_wedge_sign, mi->mbmi.sb_type,
-                              h, w);
+  build_masked_compound(dst, dst_stride, dst, dst_stride, tmp_dst, MAX_SB_SIZE,
+                        comp_data, mi->mbmi.sb_type, h, w);
 #endif  // CONFIG_SUPERTX
 #endif  // CONFIG_AOM_HIGHBITDEPTH
 }
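
For orientation, av1_make_masked_inter_predictor() above works in two
passes: the first reference's prediction already sits in dst, the second
is built into tmp_dst, and build_masked_compound() blends the pair under
the compound mask. A toy end-to-end model of that flow (stub prediction,
4x4 block, same a64 blend as above):

    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    enum { W = 4, H = 4, MAX_ALPHA = 64 };

    /* Stub motion-compensated prediction: just fills the block. */
    static void predict(uint8_t *dst, int stride, uint8_t fill) {
      for (int r = 0; r < H; ++r) memset(dst + r * stride, fill, W);
    }

    int main(void) {
      uint8_t dst[H * W], tmp[H * W];
      /* Toy wedge-like mask: left half weights the first prediction. */
      static const uint8_t mask[H * W] = {
        64, 64, 0, 0, 64, 64, 0, 0, 64, 64, 0, 0, 64, 64, 0, 0,
      };
      predict(dst, W, 200); /* first prediction -> dst */
      predict(tmp, W, 100); /* second prediction -> tmp_dst */
      /* Blend tmp into dst under the mask, as build_masked_compound()
         does through aom_blend_a64_mask(). */
      for (int i = 0; i < H * W; ++i)
        dst[i] = (uint8_t)(
            (mask[i] * dst[i] + (MAX_ALPHA - mask[i]) * tmp[i] + 32) >> 6);
      for (int i = 0; i < H * W; ++i)
        printf("%3d%c", dst[i], (i % W == W - 1) ? '\n' : ' ');
      return 0;
    }
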
@@ -630,8 +644,8 @@
                  (scaled_mv.col >> SUBPEL_BITS);
 
 #if CONFIG_EXT_INTER
-          if (ref && is_interinter_wedge_used(mi->mbmi.sb_type) &&
-              mi->mbmi.interinter_compound == COMPOUND_WEDGE)
+          if (ref &&
+              is_masked_compound_type(mi->mbmi.interinter_compound_data.type))
             av1_make_masked_inter_predictor(
                 pre, pre_buf->stride, dst, dst_buf->stride, subpel_x, subpel_y,
                 sf, w, h, mi->mbmi.interp_filter, xs, ys,
@@ -696,8 +710,7 @@
            (scaled_mv.col >> SUBPEL_BITS);
 
 #if CONFIG_EXT_INTER
-    if (ref && is_interinter_wedge_used(mi->mbmi.sb_type) &&
-        mi->mbmi.interinter_compound == COMPOUND_WEDGE)
+    if (ref && is_masked_compound_type(mi->mbmi.interinter_compound_data.type))
       av1_make_masked_inter_predictor(pre, pre_buf->stride, dst,
                                       dst_buf->stride, subpel_x, subpel_y, sf,
                                       w, h, mi->mbmi.interp_filter, xs, ys,
@@ -1280,9 +1293,9 @@
 void modify_neighbor_predictor_for_obmc(MB_MODE_INFO *mbmi) {
   if (is_interintra_pred(mbmi)) {
     mbmi->ref_frame[1] = NONE;
-  } else if (has_second_ref(mbmi) && is_interinter_wedge_used(mbmi->sb_type) &&
-             mbmi->interinter_compound == COMPOUND_WEDGE) {
-    mbmi->interinter_compound = COMPOUND_AVERAGE;
+  } else if (has_second_ref(mbmi) &&
+             is_masked_compound_type(mbmi->interinter_compound_data.type)) {
+    mbmi->interinter_compound_data.type = COMPOUND_AVERAGE;
     mbmi->ref_frame[1] = NONE;
   }
   return;
@@ -2080,22 +2093,22 @@
   MACROBLOCKD_PLANE *const pd = &xd->plane[plane];
   struct buf_2d *const dst_buf = &pd->dst;
   uint8_t *const dst = dst_buf->buf + dst_buf->stride * y + x;
+  const INTERINTER_COMPOUND_DATA *const comp_data =
+      &mbmi->interinter_compound_data;
 
-  if (is_compound && is_interinter_wedge_used(mbmi->sb_type) &&
-      mbmi->interinter_compound == COMPOUND_WEDGE) {
+  if (is_compound && is_masked_compound_type(comp_data->type)) {
 #if CONFIG_AOM_HIGHBITDEPTH
     if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
       build_masked_compound_wedge_highbd(
           dst, dst_buf->stride, CONVERT_TO_BYTEPTR(ext_dst0), ext_dst_stride0,
-          CONVERT_TO_BYTEPTR(ext_dst1), ext_dst_stride1,
-          mbmi->interinter_wedge_index, mbmi->interinter_wedge_sign,
-          mbmi->sb_type, h, w, xd->bd);
+          CONVERT_TO_BYTEPTR(ext_dst1), ext_dst_stride1, comp_data,
+          mbmi->sb_type, h, w, xd->bd);
     else
 #endif  // CONFIG_AOM_HIGHBITDEPTH
-      build_masked_compound_wedge(
-          dst, dst_buf->stride, ext_dst0, ext_dst_stride0, ext_dst1,
-          ext_dst_stride1, mbmi->interinter_wedge_index,
-          mbmi->interinter_wedge_sign, mbmi->sb_type, h, w);
+      build_masked_compound(dst, dst_buf->stride, ext_dst0, ext_dst_stride0,
+                            ext_dst1, ext_dst_stride1, comp_data, mbmi->sb_type,
+                            h, w);
   } else {
 #if CONFIG_AOM_HIGHBITDEPTH
     if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)