Refactor compound_segment to try different segmentation masks
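
Replace the binary `which` flag in INTERINTER_COMPOUND_DATA with a
SEG_MASK_TYPE enum, store a single segmentation mask per block and
compute its inverse on demand, and have the encoder loop over all
mask types in compound_segment rdopt, keeping the one with the
lowest RD cost. The chosen mask type is coded with MAX_SEG_MASK_BITS
bits instead of a single bit.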

Change-Id: I7c992c9aae895aebcfb5c147cb179cf665c0ac10
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index f78d9ff..0e5a599 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -41,6 +41,18 @@
 #if CONFIG_EXT_INTER
 // Should we try rectangular interintra predictions?
 #define USE_RECT_INTERINTRA 1
+#if CONFIG_COMPOUND_SEGMENT
+#define MAX_SEG_MASK_BITS 3
+
+// SEG_MASK_TYPES should not exceed (1 << MAX_SEG_MASK_BITS)
+typedef enum {
+  UNIFORM_45 = 0,
+  UNIFORM_45_INV,
+  UNIFORM_55,
+  UNIFORM_55_INV,
+  SEG_MASK_TYPES,
+} SEG_MASK_TYPE;
+#endif  // CONFIG_COMPOUND_SEGMENT
 #endif
 
 typedef enum {
@@ -256,8 +268,8 @@
   int wedge_index;
   int wedge_sign;
 #if CONFIG_COMPOUND_SEGMENT
-  int which;
-  DECLARE_ALIGNED(16, uint8_t, seg_mask[2][2 * MAX_SB_SQUARE]);
+  SEG_MASK_TYPE mask_type;
+  DECLARE_ALIGNED(16, uint8_t, seg_mask[2 * MAX_SB_SQUARE]);
 #endif  // CONFIG_COMPOUND_SEGMENT
 } INTERINTER_COMPOUND_DATA;
 #endif  // CONFIG_EXT_INTER
diff --git a/av1/common/reconinter.c b/av1/common/reconinter.c
index 89f627c..601ca3e 100644
--- a/av1/common/reconinter.c
+++ b/av1/common/reconinter.c
@@ -251,45 +251,80 @@
   return mask;
 }
 
-// get a mask according to the compound type
-// TODO(sarahparker) this needs to be extended for other experiments and
-// is currently only intended for ext_inter alone
-const uint8_t *av1_get_compound_type_mask(INTERINTER_COMPOUND_DATA *comp_data,
-                                          BLOCK_SIZE sb_type, int invert) {
+#if CONFIG_COMPOUND_SEGMENT
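+// Write the complement of 'mask' (each weight w becomes
+// AOM_BLEND_A64_MAX_ALPHA - w) into 'mask_inv_buffer' and return it.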
+static uint8_t *invert_mask(uint8_t *mask_inv_buffer, const uint8_t *const mask,
+                            int h, int w, int stride) {
+  int i, j;
+
+  for (i = 0; i < h; ++i)
+    for (j = 0; j < w; ++j) {
+      mask_inv_buffer[i * stride + j] =
+          AOM_BLEND_A64_MAX_ALPHA - mask[i * stride + j];
+    }
+  return mask_inv_buffer;
+}
+#endif  // CONFIG_COMPOUND_SEGMENT
+
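+// Return the inverse of the compound mask. For COMPOUND_SEG the inverse
+// is computed on the fly into the caller-provided 'mask_buffer'.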
+const uint8_t *av1_get_compound_type_mask_inverse(
+    const INTERINTER_COMPOUND_DATA *const comp_data,
+#if CONFIG_COMPOUND_SEGMENT
+    uint8_t *mask_buffer, int h, int w, int stride,
+#endif
+    BLOCK_SIZE sb_type) {
   assert(is_masked_compound_type(comp_data->type));
   switch (comp_data->type) {
     case COMPOUND_WEDGE:
-      return av1_get_contiguous_soft_mask(
-          comp_data->wedge_index,
-          invert ? !comp_data->wedge_sign : comp_data->wedge_sign, sb_type);
+      return av1_get_contiguous_soft_mask(comp_data->wedge_index,
+                                          !comp_data->wedge_sign, sb_type);
 #if CONFIG_COMPOUND_SEGMENT
     case COMPOUND_SEG:
-      if (invert) return comp_data->seg_mask[!comp_data->which];
-      return comp_data->seg_mask[comp_data->which];
+      return invert_mask(mask_buffer, comp_data->seg_mask, h, w, stride);
+#endif  // CONFIG_COMPOUND_SEGMENT
+    default: assert(0); return NULL;
+  }
+}
+
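+// Return the compound mask for the block; for COMPOUND_SEG this is the
+// seg_mask already stored in 'comp_data'.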
+const uint8_t *av1_get_compound_type_mask(
+    const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE sb_type) {
+  assert(is_masked_compound_type(comp_data->type));
+  switch (comp_data->type) {
+    case COMPOUND_WEDGE:
+      return av1_get_contiguous_soft_mask(comp_data->wedge_index,
+                                          comp_data->wedge_sign, sb_type);
+#if CONFIG_COMPOUND_SEGMENT
+    case COMPOUND_SEG: return comp_data->seg_mask;
 #endif  // CONFIG_COMPOUND_SEGMENT
     default: assert(0); return NULL;
   }
 }
 
 #if CONFIG_COMPOUND_SEGMENT
-// temporary placeholder mask, this will be generated using segmentation later
-void build_compound_seg_mask(INTERINTER_COMPOUND_DATA *comp_data,
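+// Fill 'mask' with the constant weight 'mask_val', or with its
+// complement when 'which_inverse' is set.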
+void uniform_mask(uint8_t *mask, int which_inverse, BLOCK_SIZE sb_type, int h,
+                  int w, int mask_val) {
+  int i, j;
+  int block_stride = block_size_wide[sb_type];
+  for (i = 0; i < h; ++i)
+    for (j = 0; j < w; ++j) {
+      mask[i * block_stride + j] =
+          which_inverse ? AOM_BLEND_A64_MAX_ALPHA - mask_val : mask_val;
+    }
+}
+
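+// Build the mask selected by 'mask_type'. The predictors src0/src1 are
+// currently unused; masks derived from segmenting the predictors may be
+// generated here later.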
+void build_compound_seg_mask(uint8_t *mask, SEG_MASK_TYPE mask_type,
                              const uint8_t *src0, int src0_stride,
                              const uint8_t *src1, int src1_stride,
                              BLOCK_SIZE sb_type, int h, int w) {
-  int block_stride = block_size_wide[sb_type];
-  int i, j;
   (void)src0;
   (void)src1;
   (void)src0_stride;
   (void)src1_stride;
-  for (i = 0; i < h; ++i)
-    for (j = 0; j < w; ++j) {
-      // if which == 0, put more weight on the first predictor
-      comp_data->seg_mask[0][i * block_stride + j] = 45;
-      comp_data->seg_mask[1][i * block_stride + j] =
-          AOM_BLEND_A64_MAX_ALPHA - 45;
-    }
+  switch (mask_type) {
+    case UNIFORM_45: uniform_mask(mask, 0, sb_type, h, w, 45); break;
+    case UNIFORM_45_INV: uniform_mask(mask, 1, sb_type, h, w, 45); break;
+    case UNIFORM_55: uniform_mask(mask, 0, sb_type, h, w, 55); break;
+    case UNIFORM_55_INV: uniform_mask(mask, 1, sb_type, h, w, 55); break;
+    default: assert(0);
+  }
 }
 #endif  // CONFIG_COMPOUND_SEGMENT
 
@@ -420,16 +455,16 @@
 #endif  // CONFIG_AOM_HIGHBITDEPTH
 #endif  // CONFIG_SUPERTX
 
-static void build_masked_compound(uint8_t *dst, int dst_stride,
-                                  const uint8_t *src0, int src0_stride,
-                                  const uint8_t *src1, int src1_stride,
-                                  INTERINTER_COMPOUND_DATA *comp_data,
-                                  BLOCK_SIZE sb_type, int h, int w) {
+static void build_masked_compound(
+    uint8_t *dst, int dst_stride, const uint8_t *src0, int src0_stride,
+    const uint8_t *src1, int src1_stride,
+    const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE sb_type, int h,
+    int w) {
   // Derive subsampling from h and w passed in. May be refactored to
   // pass in subsampling factors directly.
   const int subh = (2 << b_height_log2_lookup[sb_type]) == h;
   const int subw = (2 << b_width_log2_lookup[sb_type]) == w;
-  const uint8_t *mask = av1_get_compound_type_mask(comp_data, sb_type, 0);
+  const uint8_t *mask = av1_get_compound_type_mask(comp_data, sb_type);
   aom_blend_a64_mask(dst, dst_stride, src0, src0_stride, src1, src1_stride,
                      mask, block_size_wide[sb_type], h, w, subh, subw);
 }
@@ -520,8 +555,9 @@
 #else
 #if CONFIG_COMPOUND_SEGMENT
   if (!plane && comp_data->type == COMPOUND_SEG)
-    build_compound_seg_mask(comp_data, dst, dst_stride, tmp_dst, MAX_SB_SIZE,
-                            mi->mbmi.sb_type, h, w);
+    build_compound_seg_mask(comp_data->seg_mask, comp_data->mask_type, dst,
+                            dst_stride, tmp_dst, MAX_SB_SIZE, mi->mbmi.sb_type,
+                            h, w);
 #endif  // CONFIG_COMPOUND_SEGMENT
   build_masked_compound(dst, dst_stride, dst, dst_stride, tmp_dst, MAX_SB_SIZE,
                         comp_data, mi->mbmi.sb_type, h, w);
@@ -2216,7 +2252,8 @@
       is_masked_compound_type(mbmi->interinter_compound_data.type)) {
 #if CONFIG_COMPOUND_SEGMENT
     if (!plane && comp_data->type == COMPOUND_SEG)
-      build_compound_seg_mask(comp_data, ext_dst0, ext_dst_stride0, ext_dst1,
+      build_compound_seg_mask(comp_data->seg_mask, comp_data->mask_type,
+                              ext_dst0, ext_dst_stride0, ext_dst1,
                               ext_dst_stride1, mbmi->sb_type, h, w);
 #endif  // CONFIG_COMPOUND_SEGMENT
 #if CONFIG_AOM_HIGHBITDEPTH
diff --git a/av1/common/reconinter.h b/av1/common/reconinter.h
index 1819bb2..408f83a 100644
--- a/av1/common/reconinter.h
+++ b/av1/common/reconinter.h
@@ -189,7 +189,7 @@
 }
 
 #if CONFIG_COMPOUND_SEGMENT
-void build_compound_seg_mask(INTERINTER_COMPOUND_DATA *comp_data,
+void build_compound_seg_mask(uint8_t *mask, SEG_MASK_TYPE mask_type,
                              const uint8_t *src0, int src0_stride,
                              const uint8_t *src1, int src1_stride,
                              BLOCK_SIZE sb_type, int h, int w);
@@ -514,8 +514,15 @@
                                  BLOCK_SIZE sb_type, int wedge_offset_x,
                                  int wedge_offset_y);
 
-const uint8_t *av1_get_compound_type_mask(INTERINTER_COMPOUND_DATA *comp_data,
-                                          BLOCK_SIZE sb_type, int invert);
+const uint8_t *av1_get_compound_type_mask_inverse(
+    const INTERINTER_COMPOUND_DATA *const comp_data,
+#if CONFIG_COMPOUND_SEGMENT
+    uint8_t *mask_buffer, int h, int w, int stride,
+#endif
+    BLOCK_SIZE sb_type);
+
+const uint8_t *av1_get_compound_type_mask(
+    const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE sb_type);
 
 void av1_build_interintra_predictors(MACROBLOCKD *xd, uint8_t *ypred,
                                      uint8_t *upred, uint8_t *vpred,
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index e07625d..ce2e30e 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -1888,7 +1888,8 @@
     }
 #if CONFIG_COMPOUND_SEGMENT
     else if (mbmi->interinter_compound_data.type == COMPOUND_SEG) {
-      mbmi->interinter_compound_data.which = aom_read_bit(r, ACCT_STR);
+      mbmi->interinter_compound_data.mask_type =
+          aom_read_literal(r, MAX_SEG_MASK_BITS, ACCT_STR);
     }
 #endif  // CONFIG_COMPOUND_SEGMENT
   }
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index a047a90..ab4e47e 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1626,7 +1626,8 @@
       }
 #if CONFIG_COMPOUND_SEGMENT
       else if (mbmi->interinter_compound_data.type == COMPOUND_SEG) {
-        aom_write_bit(w, mbmi->interinter_compound_data.which);
+        aom_write_literal(w, mbmi->interinter_compound_data.mask_type,
+                          MAX_SEG_MASK_BITS);
       }
 #endif  // CONFIG_COMPOUND_SEGMENT
     }
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 346a63e..0c7eea1 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -6468,8 +6468,9 @@
 
 static void do_masked_motion_search_indexed(
     const AV1_COMP *const cpi, MACROBLOCK *x,
-    INTERINTER_COMPOUND_DATA *comp_data, BLOCK_SIZE bsize, int mi_row,
-    int mi_col, int_mv *tmp_mv, int *rate_mv, int mv_idx[2], int which) {
+    const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE bsize,
+    int mi_row, int mi_col, int_mv *tmp_mv, int *rate_mv, int mv_idx[2],
+    int which) {
   // NOTE: which values: 0 - 0 only, 1 - 1 only, 2 - both
   MACROBLOCKD *xd = &x->e_mbd;
   MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
@@ -6477,15 +6478,22 @@
   const uint8_t *mask;
   const int mask_stride = block_size_wide[bsize];
 
-  mask = av1_get_compound_type_mask(comp_data, sb_type, 0);
+  mask = av1_get_compound_type_mask(comp_data, sb_type);
 
   if (which == 0 || which == 2)
     do_masked_motion_search(cpi, x, mask, mask_stride, bsize, mi_row, mi_col,
                             &tmp_mv[0], &rate_mv[0], 0, mv_idx[0]);
 
   if (which == 1 || which == 2) {
-    // get the negative mask
-    mask = av1_get_compound_type_mask(comp_data, sb_type, 1);
+// get the negative mask
+#if CONFIG_COMPOUND_SEGMENT
+    uint8_t inv_mask_buf[2 * MAX_SB_SQUARE];
+    const int h = block_size_high[bsize];
+    mask = av1_get_compound_type_mask_inverse(
+        comp_data, inv_mask_buf, h, mask_stride, mask_stride, sb_type);
+#else
+    mask = av1_get_compound_type_mask_inverse(comp_data, sb_type);
+#endif
     do_masked_motion_search(cpi, x, mask, mask_stride, bsize, mi_row, mi_col,
                             &tmp_mv[1], &rate_mv[1], 1, mv_idx[1]);
   }
@@ -6846,7 +6854,10 @@
   int rate;
   uint64_t sse;
   int64_t dist;
-  int rd0, rd1;
+  int rd0;
+  SEG_MASK_TYPE cur_mask_type;
+  int64_t best_rd = INT64_MAX;
+  SEG_MASK_TYPE best_mask_type = 0;
 #if CONFIG_AOM_HIGHBITDEPTH
   const int hbd = xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH;
   const int bd_round = hbd ? (xd->bd - 8) * 2 : 0;
@@ -6874,26 +6885,31 @@
     aom_subtract_block(bh, bw, d10, bw, p1, bw, p0, bw);
   }
 
-  // build mask and inverse
-  build_compound_seg_mask(comp_data, p0, bw, p1, bw, bsize, bh, bw);
+  // try each mask type; the inverse masks are separate entries in the enum
+  for (cur_mask_type = 0; cur_mask_type < SEG_MASK_TYPES; cur_mask_type++) {
+    // build the mask for this type
+    build_compound_seg_mask(comp_data->seg_mask, cur_mask_type, p0, bw, p1, bw,
+                            bsize, bh, bw);
 
-  // compute rd for mask0
-  sse = av1_wedge_sse_from_residuals(r1, d10, comp_data->seg_mask[0], N);
-  sse = ROUND_POWER_OF_TWO(sse, bd_round);
+    // compute rd for mask
+    sse = av1_wedge_sse_from_residuals(r1, d10, comp_data->seg_mask, N);
+    sse = ROUND_POWER_OF_TWO(sse, bd_round);
 
-  model_rd_from_sse(cpi, xd, bsize, 0, sse, &rate, &dist);
-  rd0 = RDCOST(x->rdmult, x->rddiv, rate, dist);
+    model_rd_from_sse(cpi, xd, bsize, 0, sse, &rate, &dist);
+    rd0 = RDCOST(x->rdmult, x->rddiv, rate, dist);
 
-  // compute rd for mask1
-  sse = av1_wedge_sse_from_residuals(r1, d10, comp_data->seg_mask[1], N);
-  sse = ROUND_POWER_OF_TWO(sse, bd_round);
+    if (rd0 < best_rd) {
+      best_mask_type = cur_mask_type;
+      best_rd = rd0;
+    }
+  }
 
-  model_rd_from_sse(cpi, xd, bsize, 0, sse, &rate, &dist);
-  rd1 = RDCOST(x->rdmult, x->rddiv, rate, dist);
+  // make final mask
+  comp_data->mask_type = best_mask_type;
+  build_compound_seg_mask(comp_data->seg_mask, comp_data->mask_type, p0, bw, p1,
+                          bw, bsize, bh, bw);
 
-  // pick the better of the two
-  mbmi->interinter_compound_data.which = rd1 < rd0;
-  return mbmi->interinter_compound_data.which ? rd1 : rd0;
+  return best_rd;
 }
 #endif  // CONFIG_COMPOUND_SEGMENT