Refactor ext-inter to loop through all masked compound modes in rdopt

No change in performance

Change-Id: Ie105a7baf6a2c2258d3ef117e727957e4393f51b
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index 2a88725..c9fcfb2 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -163,6 +163,12 @@
           mode == NEAREST_NEWMV || mode == NEW_NEARESTMV ||
           mode == NEAR_NEWMV || mode == NEW_NEARMV);
 }
+
+// TODO(sarahparker) this will eventually be extended when more
+// masked compound types are added
+static INLINE int is_masked_compound_type(COMPOUND_TYPE type) {
+  return (type == COMPOUND_WEDGE);
+}
 #else
 
 static INLINE int have_newmv_in_inter_mode(PREDICTION_MODE mode) {
@@ -232,6 +238,15 @@
 #endif  // CONFIG_RD_DEBUG
 } RD_STATS;
 
+#if CONFIG_EXT_INTER
+typedef struct {
+  COMPOUND_TYPE type;
+  int wedge_index;
+  int wedge_sign;
+  // TODO(sarahparker) add necessary data for segmentation compound type
+} INTERINTER_COMPOUND_DATA;
+#endif  // CONFIG_EXT_INTER
+
 // This structure now relates to 8x8 block regions.
 typedef struct {
   // Common for both INTER and INTRA blocks
@@ -282,9 +297,7 @@
   int use_wedge_interintra;
   int interintra_wedge_index;
   int interintra_wedge_sign;
-  COMPOUND_TYPE interinter_compound;
-  int interinter_wedge_index;
-  int interinter_wedge_sign;
+  INTERINTER_COMPOUND_DATA interinter_compound_data;
 #endif  // CONFIG_EXT_INTER
   MOTION_MODE motion_mode;
   int_mv mv[2];
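
For orientation, a minimal standalone sketch of how the new struct and predicate fit together. The type names mirror the diff; the enum and the demo harness are reconstructed here, not taken from blockd.h:

#include <assert.h>
#include <stdio.h>

/* Mirrors the patch: COMPOUND_AVERAGE is the unmasked baseline and
   COMPOUND_WEDGE the only masked type so far. */
typedef enum { COMPOUND_AVERAGE, COMPOUND_WEDGE, COMPOUND_TYPES } COMPOUND_TYPE;

typedef struct {
  COMPOUND_TYPE type;
  int wedge_index;
  int wedge_sign;
} INTERINTER_COMPOUND_DATA;

static int is_masked_compound_type(COMPOUND_TYPE type) {
  return (type == COMPOUND_WEDGE);
}

int main(void) {
  INTERINTER_COMPOUND_DATA d = { COMPOUND_WEDGE, 3, 1 };
  /* Masked types carry extra parameters; COMPOUND_AVERAGE carries none. */
  if (is_masked_compound_type(d.type))
    printf("masked: index=%d sign=%d\n", d.wedge_index, d.wedge_sign);
  d.type = COMPOUND_AVERAGE;
  assert(!is_masked_compound_type(d.type));
  return 0;
}
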
diff --git a/av1/common/entropymode.c b/av1/common/entropymode.c
index faa9abf..d91d6b3 100644
--- a/av1/common/entropymode.c
+++ b/av1/common/entropymode.c
@@ -1910,10 +1910,9 @@
   }
 
   for (i = 0; i < BLOCK_SIZES; ++i) {
-    if (is_interinter_wedge_used(i))
-      aom_tree_merge_probs(
-          av1_compound_type_tree, pre_fc->compound_type_prob[i],
-          counts->compound_interinter[i], fc->compound_type_prob[i]);
+    aom_tree_merge_probs(av1_compound_type_tree, pre_fc->compound_type_prob[i],
+                         counts->compound_interinter[i],
+                         fc->compound_type_prob[i]);
   }
 #endif  // CONFIG_EXT_INTER
 
diff --git a/av1/common/reconinter.c b/av1/common/reconinter.c
index bc77687..1ae85b0 100644
--- a/av1/common/reconinter.c
+++ b/av1/common/reconinter.c
@@ -251,6 +251,24 @@
   return mask;
 }
 
+// Get a mask according to the compound type.
+// TODO(sarahparker) this needs to be extended for other experiments and
+// is currently intended for ext_inter only
+#if CONFIG_EXT_INTER
+const uint8_t *av1_get_compound_type_mask(
+    const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE sb_type,
+    int invert) {
+  assert(is_masked_compound_type(comp_data->type));
+  switch (comp_data->type) {
+    case COMPOUND_WEDGE:
+      return av1_get_contiguous_soft_mask(
+          comp_data->wedge_index,
+          invert ? !comp_data->wedge_sign : comp_data->wedge_sign, sb_type);
+    default: assert(0); return NULL;
+  }
+}
+#endif  // CONFIG_EXT_INTER
+
 static void init_wedge_master_masks() {
   int i, j, s;
   const int w = MASK_MASTER_SIZE;
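
The invert flag on av1_get_compound_type_mask exists so callers can fetch the complementary mask for the second reference without duplicating the switch; for a wedge it simply flips wedge_sign. A toy sketch of that dispatch, with 2x2 stand-in masks in place of the real wedge tables:

#include <stddef.h>

typedef enum { COMPOUND_AVERAGE, COMPOUND_WEDGE } COMPOUND_TYPE;
typedef struct { COMPOUND_TYPE type; int wedge_index; int wedge_sign; } DATA;

/* Toy masks: sign 1 is the pointwise complement (64 - m) of sign 0. */
static const unsigned char toy_mask[2][4] = {
  { 64, 64, 0, 0 },
  { 0, 0, 64, 64 },
};

static const unsigned char *get_mask(const DATA *d, int invert) {
  switch (d->type) {
    case COMPOUND_WEDGE:
      /* invert flips the sign, yielding the second reference's mask */
      return toy_mask[invert ? !d->wedge_sign : d->wedge_sign];
    default: return NULL;  /* only masked types are valid here */
  }
}
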
@@ -378,17 +396,16 @@
 #endif  // CONFIG_AOM_HIGHBITDEPTH
 #endif  // CONFIG_SUPERTX
 
-static void build_masked_compound_wedge(uint8_t *dst, int dst_stride,
-                                        const uint8_t *src0, int src0_stride,
-                                        const uint8_t *src1, int src1_stride,
-                                        int wedge_index, int wedge_sign,
-                                        BLOCK_SIZE sb_type, int h, int w) {
+static void build_masked_compound(
+    uint8_t *dst, int dst_stride, const uint8_t *src0, int src0_stride,
+    const uint8_t *src1, int src1_stride,
+    const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE sb_type, int h,
+    int w) {
   // Derive subsampling from h and w passed in. May be refactored to
   // pass in subsampling factors directly.
   const int subh = (2 << b_height_log2_lookup[sb_type]) == h;
   const int subw = (2 << b_width_log2_lookup[sb_type]) == w;
-  const uint8_t *mask =
-      av1_get_contiguous_soft_mask(wedge_index, wedge_sign, sb_type);
+  const uint8_t *mask = av1_get_compound_type_mask(comp_data, sb_type, 0);
   aom_blend_a64_mask(dst, dst_stride, src0, src0_stride, src1, src1_stride,
                      mask, block_size_wide[sb_type], h, w, subh, subw);
 }
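
build_masked_compound now fetches its mask through the type dispatch and hands it to aom_blend_a64_mask, whose "a64" family blends with per-pixel weights summing to 64. A self-contained sketch of that blend; the exact rounding is assumed, not copied from the library:

#include <stdint.h>

/* "a64" blend: weights m and (64 - m) sum to 64, giving a rounded
   weighted average of the two predictors. */
static uint8_t blend_a64(uint8_t v0, uint8_t v1, int m) {
  return (uint8_t)((m * v0 + (64 - m) * v1 + 32) >> 6);
}

/* Blend two same-sized predictors into dst under a per-pixel mask. */
static void blend_block(uint8_t *dst, const uint8_t *src0,
                        const uint8_t *src1, const uint8_t *mask,
                        int w, int h) {
  int r, c;
  for (r = 0; r < h; ++r)
    for (c = 0; c < w; ++c)
      dst[r * w + c] =
          blend_a64(src0[r * w + c], src1[r * w + c], mask[r * w + c]);
}
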
@@ -402,8 +419,7 @@
   // pass in subsampling factors directly.
   const int subh = (2 << b_height_log2_lookup[sb_type]) == h;
   const int subw = (2 << b_width_log2_lookup[sb_type]) == w;
-  const uint8_t *mask =
-      av1_get_contiguous_soft_mask(wedge_index, wedge_sign, sb_type);
+  const uint8_t *mask = av1_get_compound_type_mask(comp_data, sb_type, 0);
   aom_highbd_blend_a64_mask(dst_8, dst_stride, src0_8, src0_stride, src1_8,
                             src1_stride, mask, block_size_wide[sb_type], h, w,
                             subh, subw, bd);
@@ -426,6 +442,8 @@
 #endif  // CONFIG_SUPERTX
                                      const MACROBLOCKD *xd) {
   const MODE_INFO *mi = xd->mi[0];
+  const INTERINTER_COMPOUND_DATA *const comp_data =
+      &mi->mbmi.interinter_compound_data;
 // The prediction filter types used here should be those for
 // the second reference block.
 #if CONFIG_DUAL_FILTER
@@ -446,39 +464,35 @@
   if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
     build_masked_compound_wedge_extend_highbd(
         dst, dst_stride, dst, dst_stride, tmp_dst, MAX_SB_SIZE,
-        mi->mbmi.interinter_wedge_index, mi->mbmi.interinter_wedge_sign,
-        mi->mbmi.sb_type, wedge_offset_x, wedge_offset_y, h, w, xd->bd);
+        comp_data->wedge_index, comp_data->wedge_sign, mi->mbmi.sb_type,
+        wedge_offset_x, wedge_offset_y, h, w, xd->bd);
   else
     build_masked_compound_wedge_extend(
         dst, dst_stride, dst, dst_stride, tmp_dst, MAX_SB_SIZE,
-        mi->mbmi.interinter_wedge_index, mi->mbmi.interinter_wedge_sign,
-        mi->mbmi.sb_type, wedge_offset_x, wedge_offset_y, h, w);
+        comp_data->wedge_index, comp_data->wedge_sign, mi->mbmi.sb_type,
+        wedge_offset_x, wedge_offset_y, h, w);
 #else
   if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
     build_masked_compound_wedge_highbd(
         dst, dst_stride, dst, dst_stride, tmp_dst, MAX_SB_SIZE,
-        mi->mbmi.interinter_wedge_index, mi->mbmi.interinter_wedge_sign,
-        mi->mbmi.sb_type, h, w, xd->bd);
+        comp_data->wedge_index, comp_data->wedge_sign, mi->mbmi.sb_type, h, w,
+        xd->bd);
   else
-    build_masked_compound_wedge(dst, dst_stride, dst, dst_stride, tmp_dst,
-                                MAX_SB_SIZE, mi->mbmi.interinter_wedge_index,
-                                mi->mbmi.interinter_wedge_sign,
-                                mi->mbmi.sb_type, h, w);
+    build_masked_compound(dst, dst_stride, dst, dst_stride, tmp_dst,
+                          MAX_SB_SIZE, comp_data, mi->mbmi.sb_type, h, w);
 #endif  // CONFIG_SUPERTX
 #else   // CONFIG_AOM_HIGHBITDEPTH
   DECLARE_ALIGNED(16, uint8_t, tmp_dst[MAX_SB_SQUARE]);
   av1_make_inter_predictor(pre, pre_stride, tmp_dst, MAX_SB_SIZE, subpel_x,
                            subpel_y, sf, w, h, 0, tmp_ipf, xs, ys, xd);
 #if CONFIG_SUPERTX
-  build_masked_compound_wedge_extend(
-      dst, dst_stride, dst, dst_stride, tmp_dst, MAX_SB_SIZE,
-      mi->mbmi.interinter_wedge_index, mi->mbmi.interinter_wedge_sign,
-      mi->mbmi.sb_type, wedge_offset_x, wedge_offset_y, h, w);
+  build_masked_compound_wedge_extend(dst, dst_stride, dst, dst_stride, tmp_dst,
+                                     MAX_SB_SIZE, comp_data->wedge_index,
+                                     comp_data->wedge_sign, mi->mbmi.sb_type,
+                                     wedge_offset_x, wedge_offset_y, h, w);
 #else
-  build_masked_compound_wedge(dst, dst_stride, dst, dst_stride, tmp_dst,
-                              MAX_SB_SIZE, mi->mbmi.interinter_wedge_index,
-                              mi->mbmi.interinter_wedge_sign, mi->mbmi.sb_type,
-                              h, w);
+  build_masked_compound(dst, dst_stride, dst, dst_stride, tmp_dst, MAX_SB_SIZE,
+                        comp_data, mi->mbmi.sb_type, h, w);
 #endif  // CONFIG_SUPERTX
 #endif  // CONFIG_AOM_HIGHBITDEPTH
 }
@@ -630,8 +644,8 @@
                  (scaled_mv.col >> SUBPEL_BITS);
 
 #if CONFIG_EXT_INTER
-          if (ref && is_interinter_wedge_used(mi->mbmi.sb_type) &&
-              mi->mbmi.interinter_compound == COMPOUND_WEDGE)
+          if (ref &&
+              is_masked_compound_type(mi->mbmi.interinter_compound_data.type))
             av1_make_masked_inter_predictor(
                 pre, pre_buf->stride, dst, dst_buf->stride, subpel_x, subpel_y,
                 sf, w, h, mi->mbmi.interp_filter, xs, ys,
@@ -696,8 +710,7 @@
            (scaled_mv.col >> SUBPEL_BITS);
 
 #if CONFIG_EXT_INTER
-    if (ref && is_interinter_wedge_used(mi->mbmi.sb_type) &&
-        mi->mbmi.interinter_compound == COMPOUND_WEDGE)
+    if (ref && is_masked_compound_type(mi->mbmi.interinter_compound_data.type))
       av1_make_masked_inter_predictor(pre, pre_buf->stride, dst,
                                       dst_buf->stride, subpel_x, subpel_y, sf,
                                       w, h, mi->mbmi.interp_filter, xs, ys,
@@ -1280,9 +1293,9 @@
 void modify_neighbor_predictor_for_obmc(MB_MODE_INFO *mbmi) {
   if (is_interintra_pred(mbmi)) {
     mbmi->ref_frame[1] = NONE;
-  } else if (has_second_ref(mbmi) && is_interinter_wedge_used(mbmi->sb_type) &&
-             mbmi->interinter_compound == COMPOUND_WEDGE) {
-    mbmi->interinter_compound = COMPOUND_AVERAGE;
+  } else if (has_second_ref(mbmi) &&
+             is_masked_compound_type(mbmi->interinter_compound_data.type)) {
+    mbmi->interinter_compound_data.type = COMPOUND_AVERAGE;
     mbmi->ref_frame[1] = NONE;
   }
   return;
@@ -2080,22 +2093,22 @@
   MACROBLOCKD_PLANE *const pd = &xd->plane[plane];
   struct buf_2d *const dst_buf = &pd->dst;
   uint8_t *const dst = dst_buf->buf + dst_buf->stride * y + x;
+  const INTERINTER_COMPOUND_DATA *const comp_data =
+      &mbmi->interinter_compound_data;
 
-  if (is_compound && is_interinter_wedge_used(mbmi->sb_type) &&
-      mbmi->interinter_compound == COMPOUND_WEDGE) {
+  if (is_compound &&
+      is_masked_compound_type(mbmi->interinter_compound_data.type)) {
 #if CONFIG_AOM_HIGHBITDEPTH
     if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
       build_masked_compound_wedge_highbd(
           dst, dst_buf->stride, CONVERT_TO_BYTEPTR(ext_dst0), ext_dst_stride0,
-          CONVERT_TO_BYTEPTR(ext_dst1), ext_dst_stride1,
-          mbmi->interinter_wedge_index, mbmi->interinter_wedge_sign,
-          mbmi->sb_type, h, w, xd->bd);
+          CONVERT_TO_BYTEPTR(ext_dst1), ext_dst_stride1, comp_data->wedge_index,
+          comp_data->wedge_sign, mbmi->sb_type, h, w, xd->bd);
     else
 #endif  // CONFIG_AOM_HIGHBITDEPTH
-      build_masked_compound_wedge(
-          dst, dst_buf->stride, ext_dst0, ext_dst_stride0, ext_dst1,
-          ext_dst_stride1, mbmi->interinter_wedge_index,
-          mbmi->interinter_wedge_sign, mbmi->sb_type, h, w);
+      build_masked_compound(dst, dst_buf->stride, ext_dst0, ext_dst_stride0,
+                            ext_dst1, ext_dst_stride1, comp_data, mbmi->sb_type,
+                            h, w);
   } else {
 #if CONFIG_AOM_HIGHBITDEPTH
     if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
diff --git a/av1/common/reconinter.h b/av1/common/reconinter.h
index 13f581e..62a196f 100644
--- a/av1/common/reconinter.h
+++ b/av1/common/reconinter.h
@@ -522,6 +522,10 @@
                                  BLOCK_SIZE sb_type, int wedge_offset_x,
                                  int wedge_offset_y);
 
+const uint8_t *av1_get_compound_type_mask(
+    const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE sb_type,
+    int invert);
+
 void av1_build_interintra_predictors(MACROBLOCKD *xd, uint8_t *ypred,
                                      uint8_t *upred, uint8_t *vpred,
                                      int ystride, int ustride, int vstride,
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index 2b00c51..8fba4cb 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -4327,10 +4327,8 @@
     }
     if (cm->reference_mode != SINGLE_REFERENCE) {
       for (i = 0; i < BLOCK_SIZES; i++) {
-        if (is_interinter_wedge_used(i)) {
-          for (j = 0; j < COMPOUND_TYPES - 1; j++) {
-            av1_diff_update_prob(&r, &fc->compound_type_prob[i][j], ACCT_STR);
-          }
+        for (j = 0; j < COMPOUND_TYPES - 1; j++) {
+          av1_diff_update_prob(&r, &fc->compound_type_prob[i][j], ACCT_STR);
         }
       }
     }
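
With the is_interinter_wedge_used() guard gone, the update is read for every block size. The COMPOUND_TYPES - 1 bound follows from tree coding: a binary token tree with N leaves has N - 1 internal nodes, one probability each. In miniature:

/* A binary token tree with N leaves carries N - 1 probabilities, hence
   the COMPOUND_TYPES - 1 bound; with only AVERAGE and WEDGE this is a
   single binary symbol. */
enum { COMPOUND_AVERAGE, COMPOUND_WEDGE, COMPOUND_TYPES };

int main(void) {
  int num_probs = COMPOUND_TYPES - 1;  /* = 1 */
  return num_probs == 1 ? 0 : 1;
}
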
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index 101ed3e..05dc0fa 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -1816,21 +1816,22 @@
 #endif  // CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION
 
 #if CONFIG_EXT_INTER
-  mbmi->interinter_compound = COMPOUND_AVERAGE;
+  mbmi->interinter_compound_data.type = COMPOUND_AVERAGE;
   if (cm->reference_mode != SINGLE_REFERENCE &&
-      is_inter_compound_mode(mbmi->mode) &&
+      is_inter_compound_mode(mbmi->mode)
 #if CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION
-      mbmi->motion_mode == SIMPLE_TRANSLATION &&
+      && mbmi->motion_mode == SIMPLE_TRANSLATION
 #endif  // CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION
-      is_interinter_wedge_used(bsize)) {
-    mbmi->interinter_compound = aom_read_tree(
+      ) {
+    mbmi->interinter_compound_data.type = aom_read_tree(
         r, av1_compound_type_tree, cm->fc->compound_type_prob[bsize], ACCT_STR);
     if (xd->counts)
-      xd->counts->compound_interinter[bsize][mbmi->interinter_compound]++;
-    if (mbmi->interinter_compound == COMPOUND_WEDGE) {
-      mbmi->interinter_wedge_index =
+      xd->counts->compound_interinter[bsize]
+                                     [mbmi->interinter_compound_data.type]++;
+    if (mbmi->interinter_compound_data.type == COMPOUND_WEDGE) {
+      mbmi->interinter_compound_data.wedge_index =
           aom_read_literal(r, get_wedge_bits_lookup(bsize), ACCT_STR);
-      mbmi->interinter_wedge_sign = aom_read_bit(r, ACCT_STR);
+      mbmi->interinter_compound_data.wedge_sign = aom_read_bit(r, ACCT_STR);
     }
   }
 #endif  // CONFIG_EXT_INTER
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index acce688..d7d3701 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1599,18 +1599,18 @@
 
 #if CONFIG_EXT_INTER
     if (cpi->common.reference_mode != SINGLE_REFERENCE &&
-        is_inter_compound_mode(mbmi->mode) &&
+        is_inter_compound_mode(mbmi->mode)
 #if CONFIG_MOTION_VAR
-        mbmi->motion_mode == SIMPLE_TRANSLATION &&
+        && mbmi->motion_mode == SIMPLE_TRANSLATION
 #endif  // CONFIG_MOTION_VAR
-        is_interinter_wedge_used(bsize)) {
-      av1_write_token(w, av1_compound_type_tree,
-                      cm->fc->compound_type_prob[bsize],
-                      &compound_type_encodings[mbmi->interinter_compound]);
-      if (mbmi->interinter_compound == COMPOUND_WEDGE) {
-        aom_write_literal(w, mbmi->interinter_wedge_index,
+        ) {
+      av1_write_token(
+          w, av1_compound_type_tree, cm->fc->compound_type_prob[bsize],
+          &compound_type_encodings[mbmi->interinter_compound_data.type]);
+      if (mbmi->interinter_compound_data.type == COMPOUND_WEDGE) {
+        aom_write_literal(w, mbmi->interinter_compound_data.wedge_index,
                           get_wedge_bits_lookup(bsize));
-        aom_write_bit(w, mbmi->interinter_wedge_sign);
+        aom_write_bit(w, mbmi->interinter_compound_data.wedge_sign);
       }
     }
 #endif  // CONFIG_EXT_INTER
@@ -4232,10 +4232,9 @@
     }
     if (cm->reference_mode != SINGLE_REFERENCE) {
       for (i = 0; i < BLOCK_SIZES; i++)
-        if (is_interinter_wedge_used(i))
-          prob_diff_update(av1_compound_type_tree, fc->compound_type_prob[i],
-                           cm->counts.compound_interinter[i], COMPOUND_TYPES,
-                           probwt, header_bc);
+        prob_diff_update(av1_compound_type_tree, fc->compound_type_prob[i],
+                         cm->counts.compound_interinter[i], COMPOUND_TYPES,
+                         probwt, header_bc);
     }
 #endif  // CONFIG_EXT_INTER
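
These writes must mirror the reads in decodemv.c symbol for symbol: one tree-coded type, then a wedge index and sign bit only when the type is COMPOUND_WEDGE. A sketch of that contract, with a printf stand-in for the real aom_write_* calls:

#include <stdio.h>

typedef enum { COMPOUND_AVERAGE, COMPOUND_WEDGE } COMPOUND_TYPE;
typedef struct { COMPOUND_TYPE type; int wedge_index; int wedge_sign; } DATA;

/* stand-in for av1_write_token / aom_write_literal / aom_write_bit */
static void put(const char *name, int v) { printf("%s=%d\n", name, v); }

static void write_compound(const DATA *d, int wedge_bits) {
  put("type", d->type);                  /* tree-coded in the real codec */
  if (d->type == COMPOUND_WEDGE) {       /* parameters only if masked */
    put("wedge_index", d->wedge_index);  /* wedge_bits-bit literal */
    put("wedge_sign", d->wedge_sign);    /* single bit */
  }
  (void)wedge_bits;
}

int main(void) {
  DATA d = { COMPOUND_WEDGE, 5, 0 };
  write_compound(&d, 4);  /* the decoder reads in this identical order */
  return 0;
}
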
 
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index a768c1c..50199b9 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -1992,12 +1992,13 @@
 
 #if CONFIG_EXT_INTER
         if (cm->reference_mode != SINGLE_REFERENCE &&
-            is_inter_compound_mode(mbmi->mode) &&
+            is_inter_compound_mode(mbmi->mode)
 #if CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION
-            mbmi->motion_mode == SIMPLE_TRANSLATION &&
+            && mbmi->motion_mode == SIMPLE_TRANSLATION
 #endif  // CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION
-            is_interinter_wedge_used(bsize)) {
-          counts->compound_interinter[bsize][mbmi->interinter_compound]++;
+            ) {
+          counts->compound_interinter[bsize]
+                                     [mbmi->interinter_compound_data.type]++;
         }
 #endif  // CONFIG_EXT_INTER
       }
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 3f45757..8ead4ac 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -4284,6 +4284,17 @@
 #endif
 }
 
+#if CONFIG_EXT_INTER
+static int get_interinter_compound_type_bits(BLOCK_SIZE bsize,
+                                             COMPOUND_TYPE comp_type) {
+  switch (comp_type) {
+    case COMPOUND_AVERAGE: return 0;
+    case COMPOUND_WEDGE: return get_interinter_wedge_bits(bsize);
+    default: assert(0); return 0;
+  }
+}
+#endif  // CONFIG_EXT_INTER
+
 #if CONFIG_GLOBAL_MOTION
 #define GLOBAL_MOTION_COST_AMORTIZATION_BLKS 8
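
This helper isolates the per-type side-information bits so the rdopt loop can charge every candidate uniformly: nothing extra for COMPOUND_AVERAGE, the wedge-index bits for COMPOUND_WEDGE. A sketch of how it feeds the rate term rs2; the 512-cost-units-per-bit scale attributed to av1_cost_literal is an assumption:

/* rs2 in the loop = literal cost of the type's parameters plus the
   tree-coded cost of the type symbol. COST_PER_BIT mimics
   av1_cost_literal()'s scaling (assumed 1 bit = 512 cost units). */
#define COST_PER_BIT 512

static int compound_type_bits(int is_wedge, int wedge_bits) {
  return is_wedge ? wedge_bits : 0;  /* AVERAGE has no extra parameters */
}

static int compound_rs2(int is_wedge, int wedge_bits, int tree_cost) {
  return compound_type_bits(is_wedge, wedge_bits) * COST_PER_BIT + tree_cost;
}
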
 
@@ -6466,19 +6477,18 @@
   }
 }
 
-static void do_masked_motion_search_indexed(const AV1_COMP *const cpi,
-                                            MACROBLOCK *x, int wedge_index,
-                                            int wedge_sign, BLOCK_SIZE bsize,
-                                            int mi_row, int mi_col,
-                                            int_mv *tmp_mv, int *rate_mv,
-                                            int mv_idx[2], int which) {
+static void do_masked_motion_search_indexed(
+    const AV1_COMP *const cpi, MACROBLOCK *x,
+    const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE bsize,
+    int mi_row, int mi_col, int_mv *tmp_mv, int *rate_mv, int mv_idx[2],
+    int which) {
   // NOTE: which values: 0 - 0 only, 1 - 1 only, 2 - both
   MACROBLOCKD *xd = &x->e_mbd;
   MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
   BLOCK_SIZE sb_type = mbmi->sb_type;
   const uint8_t *mask;
   const int mask_stride = block_size_wide[bsize];
-  mask = av1_get_contiguous_soft_mask(wedge_index, wedge_sign, sb_type);
+  mask = av1_get_compound_type_mask(comp_data, sb_type, 0);
 
   if (which == 0 || which == 2)
     do_masked_motion_search(cpi, x, mask, mask_stride, bsize, mi_row, mi_col,
@@ -6486,7 +6496,7 @@
 
   if (which == 1 || which == 2) {
     // get the negative mask
-    mask = av1_get_contiguous_soft_mask(wedge_index, !wedge_sign, sb_type);
+    mask = av1_get_compound_type_mask(comp_data, sb_type, 1);
     do_masked_motion_search(cpi, x, mask, mask_stride, bsize, mi_row, mi_col,
                             &tmp_mv[1], &rate_mv[1], 1, mv_idx[1]);
   }
@@ -6827,8 +6837,8 @@
     rd = pick_wedge(cpi, x, bsize, p0, p1, &wedge_sign, &wedge_index);
   }
 
-  mbmi->interinter_wedge_sign = wedge_sign;
-  mbmi->interinter_wedge_index = wedge_index;
+  mbmi->interinter_compound_data.wedge_sign = wedge_sign;
+  mbmi->interinter_compound_data.wedge_index = wedge_index;
   return rd;
 }
 
@@ -6851,6 +6861,94 @@
   mbmi->interintra_wedge_index = wedge_index;
   return rd;
 }
+
+static int interinter_compound_motion_search(const AV1_COMP *const cpi,
+                                             MACROBLOCK *x,
+                                             const BLOCK_SIZE bsize,
+                                             const int this_mode, int mi_row,
+                                             int mi_col) {
+  const MACROBLOCKD *const xd = &x->e_mbd;
+  MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
+  int_mv tmp_mv[2];
+  int rate_mvs[2], tmp_rate_mv = 0;
+  if (this_mode == NEW_NEWMV) {
+    int mv_idxs[2] = { 0, 0 };
+    do_masked_motion_search_indexed(cpi, x, &mbmi->interinter_compound_data,
+                                    bsize, mi_row, mi_col, tmp_mv, rate_mvs,
+                                    mv_idxs, 2);
+    tmp_rate_mv = rate_mvs[0] + rate_mvs[1];
+    mbmi->mv[0].as_int = tmp_mv[0].as_int;
+    mbmi->mv[1].as_int = tmp_mv[1].as_int;
+  } else if (this_mode == NEW_NEARESTMV || this_mode == NEW_NEARMV) {
+    int mv_idxs[2] = { 0, 0 };
+    do_masked_motion_search_indexed(cpi, x, &mbmi->interinter_compound_data,
+                                    bsize, mi_row, mi_col, tmp_mv, rate_mvs,
+                                    mv_idxs, 0);
+    tmp_rate_mv = rate_mvs[0];
+    mbmi->mv[0].as_int = tmp_mv[0].as_int;
+  } else if (this_mode == NEAREST_NEWMV || this_mode == NEAR_NEWMV) {
+    int mv_idxs[2] = { 0, 0 };
+    do_masked_motion_search_indexed(cpi, x, &mbmi->interinter_compound_data,
+                                    bsize, mi_row, mi_col, tmp_mv, rate_mvs,
+                                    mv_idxs, 1);
+    tmp_rate_mv = rate_mvs[1];
+    mbmi->mv[1].as_int = tmp_mv[1].as_int;
+  }
+  return tmp_rate_mv;
+}
+
+static int64_t build_and_cost_compound_wedge(
+    const AV1_COMP *const cpi, MACROBLOCK *x, const int_mv *const cur_mv,
+    const BLOCK_SIZE bsize, const int this_mode, int rs2, int rate_mv,
+    int *out_rate_mv, uint8_t **preds0, uint8_t **preds1, int *strides,
+    int mi_row, int mi_col) {
+  MACROBLOCKD *xd = &x->e_mbd;
+  MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
+  int rate_sum;
+  int64_t dist_sum;
+  int64_t best_rd_cur = INT64_MAX;
+  int64_t rd = INT64_MAX;
+  int tmp_skip_txfm_sb;
+  int64_t tmp_skip_sse_sb;
+
+  best_rd_cur = pick_interinter_wedge(cpi, x, bsize, *preds0, *preds1);
+  best_rd_cur += RDCOST(x->rdmult, x->rddiv, rs2 + rate_mv, 0);
+
+  if (have_newmv_in_inter_mode(this_mode)) {
+    *out_rate_mv = interinter_compound_motion_search(cpi, x, bsize, this_mode,
+                                                     mi_row, mi_col);
+    av1_build_inter_predictors_sby(xd, mi_row, mi_col, bsize);
+    model_rd_for_sb(cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum,
+                    &tmp_skip_txfm_sb, &tmp_skip_sse_sb);
+    rd = RDCOST(x->rdmult, x->rddiv, rs2 + *out_rate_mv + rate_sum, dist_sum);
+    if (rd < best_rd_cur) {
+      best_rd_cur = rd;
+    } else {
+      mbmi->mv[0].as_int = cur_mv[0].as_int;
+      mbmi->mv[1].as_int = cur_mv[1].as_int;
+      *out_rate_mv = rate_mv;
+      av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0, strides,
+                                               preds1, strides);
+    }
+    av1_subtract_plane(x, bsize, 0);
+    rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum,
+                             &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX);
+    if (rd != INT64_MAX)
+      rd = RDCOST(x->rdmult, x->rddiv, rs2 + *out_rate_mv + rate_sum, dist_sum);
+    best_rd_cur = rd;
+
+  } else {
+    av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0, strides,
+                                             preds1, strides);
+    av1_subtract_plane(x, bsize, 0);
+    rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum,
+                             &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX);
+    if (rd != INT64_MAX)
+      rd = RDCOST(x->rdmult, x->rddiv, rs2 + rate_mv + rate_sum, dist_sum);
+    best_rd_cur = rd;
+  }
+  return best_rd_cur;
+}
 #endif  // CONFIG_EXT_INTER
 
 static int64_t handle_inter_mode(
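
build_and_cost_compound_wedge folds the previously inlined wedge logic into one helper: pick a wedge, optionally refine the MVs, and keep the refinement only if it wins. Its keep-or-revert core reduces to a Lagrangian comparison (illustrative scaling; the real RDCOST macro's integer form is not reproduced):

/* Keep-or-revert core of build_and_cost_compound_wedge: a Lagrangian
   cost J = D + lambda * R decides whether the refined MVs beat the
   current best; on a loss the caller restores cur_mv and rebuilds the
   prediction from the saved single-reference buffers. */
static long long rd_cost(long long lambda, int rate, long long dist) {
  return dist + lambda * rate;
}

static int keep_refined_mvs(long long lambda, int rs2, int refined_rate_mv,
                            int rate_sum, long long dist_sum,
                            long long best_rd_cur) {
  long long rd = rd_cost(lambda, rs2 + refined_rate_mv + rate_sum, dist_sum);
  return rd < best_rd_cur;
}
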
@@ -6865,7 +6963,7 @@
 #if CONFIG_EXT_INTER
     int_mv single_newmvs[2][TOTAL_REFS_PER_FRAME],
     int single_newmvs_rate[2][TOTAL_REFS_PER_FRAME],
-    int *compmode_interintra_cost, int *compmode_wedge_cost,
+    int *compmode_interintra_cost, int *compmode_interinter_cost,
     int64_t (*const modelled_rd)[TOTAL_REFS_PER_FRAME],
 #else
     int_mv single_newmv[TOTAL_REFS_PER_FRAME],
@@ -6941,8 +7039,8 @@
 #if CONFIG_EXT_INTER
   *compmode_interintra_cost = 0;
   mbmi->use_wedge_interintra = 0;
-  *compmode_wedge_cost = 0;
-  mbmi->interinter_compound = COMPOUND_AVERAGE;
+  *compmode_interinter_cost = 0;
+  mbmi->interinter_compound_data.type = COMPOUND_AVERAGE;
 
   // is_comp_interintra_pred implies !is_comp_pred
   assert(!is_comp_interintra_pred || (!is_comp_pred));
@@ -7351,141 +7449,107 @@
 #endif  // CONFIG_MOTION_VAR
 #endif  // CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION
 
-  if (is_comp_pred && is_interinter_wedge_used(bsize)) {
+  if (is_comp_pred) {
     int rate_sum, rs2;
     int64_t dist_sum;
-    int64_t best_rd_nowedge = INT64_MAX;
-    int64_t best_rd_wedge = INT64_MAX;
+    int64_t best_rd_compound = INT64_MAX, best_rd_cur = INT64_MAX;
+    INTERINTER_COMPOUND_DATA best_compound_data;
+    int_mv best_mv[2];
+    int best_tmp_rate_mv = rate_mv;
     int tmp_skip_txfm_sb;
     int64_t tmp_skip_sse_sb;
     int compound_type_cost[COMPOUND_TYPES];
+    uint8_t pred0[2 * MAX_SB_SQUARE];
+    uint8_t pred1[2 * MAX_SB_SQUARE];
+    uint8_t *preds0[1] = { pred0 };
+    uint8_t *preds1[1] = { pred1 };
+    int strides[1] = { bw };
+    COMPOUND_TYPE cur_type;
 
-    mbmi->interinter_compound = COMPOUND_AVERAGE;
+    best_mv[0].as_int = cur_mv[0].as_int;
+    best_mv[1].as_int = cur_mv[1].as_int;
+    memset(&best_compound_data, 0, sizeof(INTERINTER_COMPOUND_DATA));
     av1_cost_tokens(compound_type_cost, cm->fc->compound_type_prob[bsize],
                     av1_compound_type_tree);
-    rs2 = compound_type_cost[mbmi->interinter_compound];
-    av1_build_inter_predictors_sby(xd, mi_row, mi_col, bsize);
-    av1_subtract_plane(x, bsize, 0);
-    rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum,
-                             &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX);
-    if (rd != INT64_MAX)
-      rd = RDCOST(x->rdmult, x->rddiv, rs2 + rate_mv + rate_sum, dist_sum);
-    best_rd_nowedge = rd;
 
-    // Disbale wedge search if source variance is small
-    if (x->source_variance > cpi->sf.disable_wedge_search_var_thresh &&
-        best_rd_nowedge / 3 < ref_best_rd) {
-      uint8_t pred0[2 * MAX_SB_SQUARE];
-      uint8_t pred1[2 * MAX_SB_SQUARE];
-      uint8_t *preds0[1] = { pred0 };
-      uint8_t *preds1[1] = { pred1 };
-      int strides[1] = { bw };
-
-      mbmi->interinter_compound = COMPOUND_WEDGE;
-      rs2 = av1_cost_literal(get_interinter_wedge_bits(bsize)) +
-            compound_type_cost[mbmi->interinter_compound];
-
+    if (is_interinter_wedge_used(bsize)) {
+      // get inter predictors to use for masked compound modes
       av1_build_inter_predictors_for_planes_single_buf(
           xd, bsize, 0, 0, mi_row, mi_col, 0, preds0, strides);
       av1_build_inter_predictors_for_planes_single_buf(
           xd, bsize, 0, 0, mi_row, mi_col, 1, preds1, strides);
+    }
 
-      // Choose the best wedge
-      best_rd_wedge = pick_interinter_wedge(cpi, x, bsize, pred0, pred1);
-      best_rd_wedge += RDCOST(x->rdmult, x->rddiv, rs2 + rate_mv, 0);
+    for (cur_type = COMPOUND_AVERAGE; cur_type < COMPOUND_TYPES; cur_type++) {
+      best_rd_cur = INT64_MAX;
+      mbmi->interinter_compound_data.type = cur_type;
+      rs2 = av1_cost_literal(get_interinter_compound_type_bits(
+                bsize, mbmi->interinter_compound_data.type)) +
+            compound_type_cost[mbmi->interinter_compound_data.type];
 
-      if (have_newmv_in_inter_mode(this_mode)) {
-        int_mv tmp_mv[2];
-        int rate_mvs[2], tmp_rate_mv = 0;
-        if (this_mode == NEW_NEWMV) {
-          int mv_idxs[2] = { 0, 0 };
-          do_masked_motion_search_indexed(
-              cpi, x, mbmi->interinter_wedge_index, mbmi->interinter_wedge_sign,
-              bsize, mi_row, mi_col, tmp_mv, rate_mvs, mv_idxs, 2);
-          tmp_rate_mv = rate_mvs[0] + rate_mvs[1];
-          mbmi->mv[0].as_int = tmp_mv[0].as_int;
-          mbmi->mv[1].as_int = tmp_mv[1].as_int;
-        } else if (this_mode == NEW_NEARESTMV || this_mode == NEW_NEARMV) {
-          int mv_idxs[2] = { 0, 0 };
-          do_masked_motion_search_indexed(
-              cpi, x, mbmi->interinter_wedge_index, mbmi->interinter_wedge_sign,
-              bsize, mi_row, mi_col, tmp_mv, rate_mvs, mv_idxs, 0);
-          tmp_rate_mv = rate_mvs[0];
-          mbmi->mv[0].as_int = tmp_mv[0].as_int;
-        } else if (this_mode == NEAREST_NEWMV || this_mode == NEAR_NEWMV) {
-          int mv_idxs[2] = { 0, 0 };
-          do_masked_motion_search_indexed(
-              cpi, x, mbmi->interinter_wedge_index, mbmi->interinter_wedge_sign,
-              bsize, mi_row, mi_col, tmp_mv, rate_mvs, mv_idxs, 1);
-          tmp_rate_mv = rate_mvs[1];
-          mbmi->mv[1].as_int = tmp_mv[1].as_int;
-        }
-        av1_build_inter_predictors_sby(xd, mi_row, mi_col, bsize);
-        model_rd_for_sb(cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum,
-                        &tmp_skip_txfm_sb, &tmp_skip_sse_sb);
-        rd =
-            RDCOST(x->rdmult, x->rddiv, rs2 + tmp_rate_mv + rate_sum, dist_sum);
-        if (rd < best_rd_wedge) {
-          best_rd_wedge = rd;
-        } else {
-          mbmi->mv[0].as_int = cur_mv[0].as_int;
-          mbmi->mv[1].as_int = cur_mv[1].as_int;
-          tmp_rate_mv = rate_mv;
-          av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0,
-                                                   strides, preds1, strides);
-        }
-        av1_subtract_plane(x, bsize, 0);
-        rd =
-            estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum,
-                                &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX);
-        if (rd != INT64_MAX)
-          rd = RDCOST(x->rdmult, x->rddiv, rs2 + tmp_rate_mv + rate_sum,
-                      dist_sum);
-        best_rd_wedge = rd;
+      switch (cur_type) {
+        case COMPOUND_AVERAGE:
+          av1_build_inter_predictors_sby(xd, mi_row, mi_col, bsize);
+          av1_subtract_plane(x, bsize, 0);
+          rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum,
+                                   &tmp_skip_txfm_sb, &tmp_skip_sse_sb,
+                                   INT64_MAX);
+          if (rd != INT64_MAX)
+            rd =
+                RDCOST(x->rdmult, x->rddiv, rs2 + rate_mv + rate_sum, dist_sum);
+          best_rd_compound = rd;
+          break;
+        case COMPOUND_WEDGE:
+          if (!is_interinter_wedge_used(bsize)) break;
+          if (x->source_variance > cpi->sf.disable_wedge_search_var_thresh &&
+              best_rd_compound / 3 < ref_best_rd) {
+            int tmp_rate_mv = 0;
+            best_rd_cur = build_and_cost_compound_wedge(
+                cpi, x, cur_mv, bsize, this_mode, rs2, rate_mv, &tmp_rate_mv,
+                preds0, preds1, strides, mi_row, mi_col);
 
-        if (best_rd_wedge < best_rd_nowedge) {
-          mbmi->interinter_compound = COMPOUND_WEDGE;
-          xd->mi[0]->bmi[0].as_mv[0].as_int = mbmi->mv[0].as_int;
-          xd->mi[0]->bmi[0].as_mv[1].as_int = mbmi->mv[1].as_int;
-          rd_stats->rate += tmp_rate_mv - rate_mv;
-          rate_mv = tmp_rate_mv;
-        } else {
-          mbmi->interinter_compound = COMPOUND_AVERAGE;
-          mbmi->mv[0].as_int = cur_mv[0].as_int;
-          mbmi->mv[1].as_int = cur_mv[1].as_int;
-          xd->mi[0]->bmi[0].as_mv[0].as_int = mbmi->mv[0].as_int;
-          xd->mi[0]->bmi[0].as_mv[1].as_int = mbmi->mv[1].as_int;
-        }
-      } else {
-        av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0,
-                                                 strides, preds1, strides);
-        av1_subtract_plane(x, bsize, 0);
-        rd =
-            estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum,
-                                &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX);
-        if (rd != INT64_MAX)
-          rd = RDCOST(x->rdmult, x->rddiv, rs2 + rate_mv + rate_sum, dist_sum);
-        best_rd_wedge = rd;
-        if (best_rd_wedge < best_rd_nowedge) {
-          mbmi->interinter_compound = COMPOUND_WEDGE;
-        } else {
-          mbmi->interinter_compound = COMPOUND_AVERAGE;
-        }
+            if (best_rd_cur < best_rd_compound) {
+              best_rd_compound = best_rd_cur;
+              memcpy(&best_compound_data, &mbmi->interinter_compound_data,
+                     sizeof(best_compound_data));
+              if (have_newmv_in_inter_mode(this_mode)) {
+                best_tmp_rate_mv = tmp_rate_mv;
+                best_mv[0].as_int = mbmi->mv[0].as_int;
+                best_mv[1].as_int = mbmi->mv[1].as_int;
+                // reset to original mvs for next iteration
+                mbmi->mv[0].as_int = cur_mv[0].as_int;
+                mbmi->mv[1].as_int = cur_mv[1].as_int;
+              }
+            }
+          }
+          break;
+        default: assert(0); return 0;
       }
     }
-    if (ref_best_rd < INT64_MAX &&
-        AOMMIN(best_rd_wedge, best_rd_nowedge) / 3 > ref_best_rd) {
+    memcpy(&mbmi->interinter_compound_data, &best_compound_data,
+           sizeof(INTERINTER_COMPOUND_DATA));
+    if (have_newmv_in_inter_mode(this_mode)) {
+      mbmi->mv[0].as_int = best_mv[0].as_int;
+      mbmi->mv[1].as_int = best_mv[1].as_int;
+      xd->mi[0]->bmi[0].as_mv[0].as_int = mbmi->mv[0].as_int;
+      xd->mi[0]->bmi[0].as_mv[1].as_int = mbmi->mv[1].as_int;
+      if (mbmi->interinter_compound_data.type) {
+        rd_stats->rate += best_tmp_rate_mv - rate_mv;
+        rate_mv = best_tmp_rate_mv;
+      }
+    }
+
+    if (ref_best_rd < INT64_MAX && best_rd_compound / 3 > ref_best_rd) {
       restore_dst_buf(xd, orig_dst, orig_dst_stride);
       return INT64_MAX;
     }
 
     pred_exists = 0;
 
-    *compmode_wedge_cost = compound_type_cost[mbmi->interinter_compound];
-
-    if (mbmi->interinter_compound == COMPOUND_WEDGE)
-      *compmode_wedge_cost +=
-          av1_cost_literal(get_interinter_wedge_bits(bsize));
+    *compmode_interinter_cost =
+        compound_type_cost[mbmi->interinter_compound_data.type] +
+        av1_cost_literal(get_interinter_compound_type_bits(
+            bsize, mbmi->interinter_compound_data.type));
   }
 
   if (is_comp_interintra_pred) {
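
Stepping back, the restructured block is the standard evaluate-all-candidates pattern named in the commit title: cost each compound type in turn, snapshot the winning parameters (type, wedge data, refined MVs, MV rate), and restore them after the loop. A condensed sketch, with evaluate_type() as a hypothetical stand-in for the per-type RD work:

typedef enum { COMPOUND_AVERAGE, COMPOUND_WEDGE, COMPOUND_TYPES } COMPOUND_TYPE;
typedef struct { COMPOUND_TYPE type; int wedge_index; int wedge_sign; } DATA;

/* Hypothetical stand-in for the per-type work (predict, subtract,
   estimate rate/distortion); here WEDGE is pretended to cost less. */
static long long evaluate_type(const DATA *d) {
  return d->type == COMPOUND_WEDGE ? 100 : 120;
}

static void pick_best_compound(DATA *mbmi_data) {
  long long best_rd = (long long)1 << 62;
  DATA best = { COMPOUND_AVERAGE, 0, 0 };
  COMPOUND_TYPE t;
  for (t = COMPOUND_AVERAGE; t < COMPOUND_TYPES; ++t) {
    mbmi_data->type = t;  /* candidate under test */
    long long rd = evaluate_type(mbmi_data);
    if (rd < best_rd) {   /* snapshot the winner */
      best_rd = rd;
      best = *mbmi_data;
    }
  }
  *mbmi_data = best;      /* restore the best parameters after the loop */
}
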
@@ -8782,7 +8846,7 @@
     int compmode_cost = 0;
 #if CONFIG_EXT_INTER
     int compmode_interintra_cost = 0;
-    int compmode_wedge_cost = 0;
+    int compmode_interinter_cost = 0;
 #endif  // CONFIG_EXT_INTER
     int rate2 = 0, rate_y = 0, rate_uv = 0;
     int64_t distortion2 = 0, distortion_y = 0, distortion_uv = 0;
@@ -9184,7 +9248,7 @@
 #endif  // CONFIG_MOTION_VAR
 #if CONFIG_EXT_INTER
             single_newmvs, single_newmvs_rate, &compmode_interintra_cost,
-            &compmode_wedge_cost, modelled_rd,
+            &compmode_interinter_cost, modelled_rd,
 #else
             single_newmv,
 #endif  // CONFIG_EXT_INTER
@@ -9280,7 +9344,7 @@
             int dummy_single_newmvs_rate[2][TOTAL_REFS_PER_FRAME] = { { 0 },
                                                                       { 0 } };
             int dummy_compmode_interintra_cost = 0;
-            int dummy_compmode_wedge_cost = 0;
+            int dummy_compmode_interinter_cost = 0;
 #else
             int_mv dummy_single_newmv[TOTAL_REFS_PER_FRAME] = { { 0 } };
 #endif
@@ -9295,8 +9359,8 @@
 #endif  // CONFIG_MOTION_VAR
 #if CONFIG_EXT_INTER
                 dummy_single_newmvs, dummy_single_newmvs_rate,
-                &dummy_compmode_interintra_cost, &dummy_compmode_wedge_cost,
-                NULL,
+                &dummy_compmode_interintra_cost,
+                &dummy_compmode_interinter_cost, NULL,
 #else
                 dummy_single_newmv,
 #endif
@@ -9396,7 +9460,7 @@
 #if CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION
       if (mbmi->motion_mode == SIMPLE_TRANSLATION)
 #endif  // CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION
-        rate2 += compmode_wedge_cost;
+        rate2 += compmode_interinter_cost;
 #endif  // CONFIG_EXT_INTER
 
     // Estimate the reference frame signaling cost and add it
@@ -10278,7 +10342,7 @@
 #endif  // CONFIG_FILTER_INTRA
   mbmi->motion_mode = SIMPLE_TRANSLATION;
 #if CONFIG_EXT_INTER
-  mbmi->interinter_compound = COMPOUND_AVERAGE;
+  mbmi->interinter_compound_data.type = COMPOUND_AVERAGE;
   mbmi->use_wedge_interintra = 0;
 #endif  // CONFIG_EXT_INTER
 #if CONFIG_WARPED_MOTION