Add temporary dummy mask for compound segmentation

This uses a segmentation mask (for now filled with an arbitrary
placeholder value) to blend the two predictors in compound prediction.
The mask will be computed using a color segmentation in a follow-up
patch.
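
With the placeholder, every pixel of seg_mask[0] is 45 and every pixel
of seg_mask[1] is AOM_BLEND_A64_MAX_ALPHA - 45 = 19, so mask 0 favors
the first predictor and mask 1 the second. As a rough sketch (the
actual blend is performed by aom_blend_a64_mask(); see
aom_dsp/blend.h), each pixel ends up as:

  // m is 45 for seg_mask[0] and 19 for seg_mask[1]
  dst = ROUND_POWER_OF_TWO(m * p0 + (AOM_BLEND_A64_MAX_ALPHA - m) * p1, 6);

A one-bit 'which' field, added to the bitstream below, selects which of
the two masks is applied.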
Change-Id: I2d24cf27a8589211f8a70779a5be2d61746406b9
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index 93d496e..cf646bf 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -180,10 +180,12 @@
           mode == NEAR_NEWMV || mode == NEW_NEARMV);
 }
 
-// TODO(sarahparker) this will eventually be extended when more
-// masked compound types are added
 static INLINE int is_masked_compound_type(COMPOUND_TYPE type) {
+#if CONFIG_COMPOUND_SEGMENT
+  return (type == COMPOUND_WEDGE || type == COMPOUND_SEG);
+#else
   return (type == COMPOUND_WEDGE);
+#endif  // CONFIG_COMPOUND_SEGMENT
 }
 #else
 
@@ -259,7 +261,10 @@
   COMPOUND_TYPE type;
   int wedge_index;
   int wedge_sign;
-  // TODO(sarahparker) add neccesary data for segmentation compound type
+#if CONFIG_COMPOUND_SEGMENT
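+  // which: one bit selecting which of the two masks below is applied;
+  // seg_mask: the placeholder mask and its complement.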
+  int which;
+  uint8_t seg_mask[2][2 * MAX_SB_SQUARE];
+#endif  // CONFIG_COMPOUND_SEGMENT
 } INTERINTER_COMPOUND_DATA;
 #endif  // CONFIG_EXT_INTER
 
diff --git a/av1/common/reconinter.c b/av1/common/reconinter.c
index b9df96b..a8b21e3 100644
--- a/av1/common/reconinter.c
+++ b/av1/common/reconinter.c
@@ -255,18 +255,44 @@
 // TODO(sarahparker) this needs to be extended for other experiments and
 // is currently only intended for ext_inter alone
 #if CONFIG_EXT_INTER
-const uint8_t *av1_get_compound_type_mask(
-    const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE sb_type,
-    int invert) {
+const uint8_t *av1_get_compound_type_mask(INTERINTER_COMPOUND_DATA *comp_data,
+                                          BLOCK_SIZE sb_type, int invert) {
   assert(is_masked_compound_type(comp_data->type));
   switch (comp_data->type) {
     case COMPOUND_WEDGE:
       return av1_get_contiguous_soft_mask(
           comp_data->wedge_index,
           invert ? !comp_data->wedge_sign : comp_data->wedge_sign, sb_type);
+#if CONFIG_COMPOUND_SEGMENT
+    case COMPOUND_SEG:
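+      // 'which' picks the stored mask; 'invert' returns its complement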
+      if (invert) return comp_data->seg_mask[!comp_data->which];
+      return comp_data->seg_mask[comp_data->which];
+#endif  // CONFIG_COMPOUND_SEGMENT
     default: assert(0); return NULL;
   }
 }
+
+#if CONFIG_COMPOUND_SEGMENT
+// Temporary placeholder mask; this will be generated using color
+// segmentation in a follow-up patch.
+void build_compound_seg_mask(INTERINTER_COMPOUND_DATA *comp_data,
+                             const uint8_t *src0, int src0_stride,
+                             const uint8_t *src1, int src1_stride,
+                             BLOCK_SIZE sb_type, int h, int w) {
+  int block_stride = block_size_wide[sb_type];
+  int i, j;
+  (void)src0;
+  (void)src1;
+  (void)src0_stride;
+  (void)src1_stride;
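+  // The source buffers are unused for now; the follow-up color
+  // segmentation will derive the mask from them.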
+  for (i = 0; i < h; ++i)
+    for (j = 0; j < w; ++j) {
+      // if which == 0, put more weight on the first predictor
+      comp_data->seg_mask[0][i * block_stride + j] = 45;
+      comp_data->seg_mask[1][i * block_stride + j] =
+          AOM_BLEND_A64_MAX_ALPHA - 45;
+    }
+}
+#endif  // CONFIG_COMPOUND_SEGMENT
 #endif  // CONFIG_EXT_INTER
 
 static void init_wedge_master_masks() {
@@ -396,11 +422,11 @@
 #endif  // CONFIG_AOM_HIGHBITDEPTH
 #endif  // CONFIG_SUPERTX
 
-static void build_masked_compound(
-    uint8_t *dst, int dst_stride, const uint8_t *src0, int src0_stride,
-    const uint8_t *src1, int src1_stride,
-    const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE sb_type, int h,
-    int w) {
+static void build_masked_compound(uint8_t *dst, int dst_stride,
+                                  const uint8_t *src0, int src0_stride,
+                                  const uint8_t *src1, int src1_stride,
+                                  INTERINTER_COMPOUND_DATA *comp_data,
+                                  BLOCK_SIZE sb_type, int h, int w) {
   // Derive subsampling from h and w passed in. May be refactored to
   // pass in subsampling factors directly.
   const int subh = (2 << b_height_log2_lookup[sb_type]) == h;
@@ -441,10 +467,12 @@
 #if CONFIG_SUPERTX
                                      int wedge_offset_x, int wedge_offset_y,
 #endif  // CONFIG_SUPERTX
-                                     const MACROBLOCKD *xd) {
-  const MODE_INFO *mi = xd->mi[0];
-  const INTERINTER_COMPOUND_DATA *const comp_data =
-      &mi->mbmi.interinter_compound_data;
+#if CONFIG_COMPOUND_SEGMENT
+                                     int plane,
+#endif  // CONFIG_COMPOUND_SEGMENT
+                                     MACROBLOCKD *xd) {
+  MODE_INFO *mi = xd->mi[0];
+  INTERINTER_COMPOUND_DATA *comp_data = &mi->mbmi.interinter_compound_data;
 // The prediction filter types used here should be those for
 // the second reference block.
 #if CONFIG_DUAL_FILTER
@@ -492,6 +520,11 @@
                                      comp_data->wedge_sign, mi->mbmi.sb_type,
                                      wedge_offset_x, wedge_offset_y, h, w);
 #else
+#if CONFIG_COMPOUND_SEGMENT
+  if (!plane && comp_data->type == COMPOUND_SEG)
+    build_compound_seg_mask(comp_data, dst, dst_stride, tmp_dst, MAX_SB_SIZE,
+                            mi->mbmi.sb_type, h, w);
+#endif  // CONFIG_COMPOUND_SEGMENT
   build_masked_compound(dst, dst_stride, dst, dst_stride, tmp_dst, MAX_SB_SIZE,
                         comp_data, mi->mbmi.sb_type, h, w);
 #endif  // CONFIG_SUPERTX
@@ -657,6 +690,9 @@
 #if CONFIG_SUPERTX
                 wedge_offset_x, wedge_offset_y,
 #endif  // CONFIG_SUPERTX
+#if CONFIG_COMPOUND_SEGMENT
+                plane,
+#endif  // CONFIG_COMPOUND_SEGMENT
                 xd);
           else
 #endif  // CONFIG_EXT_INTER
@@ -726,6 +762,9 @@
 #if CONFIG_SUPERTX
                                       wedge_offset_x, wedge_offset_y,
 #endif  // CONFIG_SUPERTX
+#if CONFIG_COMPOUND_SEGMENT
+                                      plane,
+#endif  // CONFIG_COMPOUND_SEGMENT
                                       xd);
     else
 #else  // CONFIG_EXT_INTER
@@ -2179,16 +2218,20 @@
 static void build_wedge_inter_predictor_from_buf(
     MACROBLOCKD *xd, int plane, int x, int y, int w, int h, uint8_t *ext_dst0,
     int ext_dst_stride0, uint8_t *ext_dst1, int ext_dst_stride1) {
-  const MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
+  MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
   const int is_compound = has_second_ref(mbmi);
   MACROBLOCKD_PLANE *const pd = &xd->plane[plane];
   struct buf_2d *const dst_buf = &pd->dst;
   uint8_t *const dst = dst_buf->buf + dst_buf->stride * y + x;
-  const INTERINTER_COMPOUND_DATA *const comp_data =
-      &mbmi->interinter_compound_data;
+  INTERINTER_COMPOUND_DATA *comp_data = &mbmi->interinter_compound_data;
 
   if (is_compound &&
       is_masked_compound_type(mbmi->interinter_compound_data.type)) {
+#if CONFIG_COMPOUND_SEGMENT
+    if (!plane && comp_data->type == COMPOUND_SEG)
+      build_compound_seg_mask(comp_data, ext_dst0, ext_dst_stride0, ext_dst1,
+                              ext_dst_stride1, mbmi->sb_type, h, w);
+#endif  // CONFIG_COMPOUND_SEGMENT
 #if CONFIG_AOM_HIGHBITDEPTH
     if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
       build_masked_compound_wedge_highbd(
diff --git a/av1/common/reconinter.h b/av1/common/reconinter.h
index da6a4d6..80bea00 100644
--- a/av1/common/reconinter.h
+++ b/av1/common/reconinter.h
@@ -211,6 +211,13 @@
 static INLINE int get_interintra_wedge_bits(BLOCK_SIZE sb_type) {
   return wedge_params_lookup[sb_type].bits;
 }
+
+#if CONFIG_COMPOUND_SEGMENT
+void build_compound_seg_mask(INTERINTER_COMPOUND_DATA *comp_data,
+                             const uint8_t *src0, int src0_stride,
+                             const uint8_t *src1, int src1_stride,
+                             BLOCK_SIZE sb_type, int h, int w);
+#endif  // CONFIG_COMPOUND_SEGMENT
 #endif  // CONFIG_EXT_INTER
 
 void build_inter_predictors(MACROBLOCKD *xd, int plane,
@@ -260,7 +267,10 @@
 #if CONFIG_SUPERTX
                                      int wedge_offset_x, int wedge_offset_y,
 #endif  // CONFIG_SUPERTX
-                                     const MACROBLOCKD *xd);
+#if CONFIG_COMPOUND_SEGMENT
+                                     int plane,
+#endif  // CONFIG_COMPOUND_SEGMENT
+                                     MACROBLOCKD *xd);
 #endif  // CONFIG_EXT_INTER
 
 static INLINE int round_mv_comp_q4(int value) {
@@ -528,9 +538,8 @@
                                  BLOCK_SIZE sb_type, int wedge_offset_x,
                                  int wedge_offset_y);
 
-const uint8_t *av1_get_compound_type_mask(
-    const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE sb_type,
-    int invert);
+const uint8_t *av1_get_compound_type_mask(INTERINTER_COMPOUND_DATA *comp_data,
+                                          BLOCK_SIZE sb_type, int invert);
 
 void av1_build_interintra_predictors(MACROBLOCKD *xd, uint8_t *ypred,
                                      uint8_t *upred, uint8_t *vpred,
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index 78ace95..11b5bc8 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -1884,6 +1884,11 @@
           aom_read_literal(r, get_wedge_bits_lookup(bsize), ACCT_STR);
       mbmi->interinter_compound_data.wedge_sign = aom_read_bit(r, ACCT_STR);
     }
+#if CONFIG_COMPOUND_SEGMENT
+    else if (mbmi->interinter_compound_data.type == COMPOUND_SEG) {
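+      // a single bit selects which placeholder mask to apply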
+      mbmi->interinter_compound_data.which = aom_read_bit(r, ACCT_STR);
+    }
+#endif  // CONFIG_COMPOUND_SEGMENT
   }
 #endif  // CONFIG_EXT_INTER
 
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index ad052a6..4799617 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1631,6 +1631,11 @@
                           get_wedge_bits_lookup(bsize));
         aom_write_bit(w, mbmi->interinter_compound_data.wedge_sign);
       }
+#if CONFIG_COMPOUND_SEGMENT
+      else if (mbmi->interinter_compound_data.type == COMPOUND_SEG) {
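+        // mirrors the decoder: a single bit codes 'which' mask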
+        aom_write_bit(w, mbmi->interinter_compound_data.which);
+      }
+#endif  // CONFIG_COMPOUND_SEGMENT
     }
 #endif  // CONFIG_EXT_INTER
 
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 6df0926..4effdf9 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -4333,9 +4333,7 @@
     case COMPOUND_AVERAGE: return 0;
     case COMPOUND_WEDGE: return get_interinter_wedge_bits(bsize);
 #if CONFIG_COMPOUND_SEGMENT
-    // TODO(sarahparker) this 0 is just a placeholder, it is possible this will
-    // need to change once this mode is fully implemented
-    case COMPOUND_SEG: return 0;
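+    // one bit signals 'which' mask (see decodemv.c/bitstream.c)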
+    case COMPOUND_SEG: return 1;
 #endif  // CONFIG_COMPOUND_SEGMENT
     default: assert(0); return 0;
   }
@@ -6526,15 +6524,15 @@
 
 static void do_masked_motion_search_indexed(
     const AV1_COMP *const cpi, MACROBLOCK *x,
-    const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE bsize,
-    int mi_row, int mi_col, int_mv *tmp_mv, int *rate_mv, int mv_idx[2],
-    int which) {
+    INTERINTER_COMPOUND_DATA *comp_data, BLOCK_SIZE bsize, int mi_row,
+    int mi_col, int_mv *tmp_mv, int *rate_mv, int mv_idx[2], int which) {
   // NOTE: which values: 0 - 0 only, 1 - 1 only, 2 - both
   MACROBLOCKD *xd = &x->e_mbd;
   MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
   BLOCK_SIZE sb_type = mbmi->sb_type;
   const uint8_t *mask;
   const int mask_stride = block_size_wide[bsize];
+
   mask = av1_get_compound_type_mask(comp_data, sb_type, 0);
 
   if (which == 0 || which == 2)
@@ -6889,6 +6887,74 @@
   return rd;
 }
 
+#if CONFIG_COMPOUND_SEGMENT
+static int64_t pick_interinter_seg_mask(const AV1_COMP *const cpi,
+                                        const MACROBLOCK *const x,
+                                        const BLOCK_SIZE bsize,
+                                        const uint8_t *const p0,
+                                        const uint8_t *const p1) {
+  const MACROBLOCKD *const xd = &x->e_mbd;
+  MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
+  const struct buf_2d *const src = &x->plane[0].src;
+  const int bw = block_size_wide[bsize];
+  const int bh = block_size_high[bsize];
+  const int N = bw * bh;
+  int rate;
+  uint64_t sse;
+  int64_t dist;
+  int64_t rd0, rd1;
+#if CONFIG_AOM_HIGHBITDEPTH
+  const int hbd = xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH;
+  const int bd_round = hbd ? (xd->bd - 8) * 2 : 0;
+#else
+  const int bd_round = 0;
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+  INTERINTER_COMPOUND_DATA comp_data = mbmi->interinter_compound_data;
+  DECLARE_ALIGNED(32, int16_t, r0[MAX_SB_SQUARE]);
+  DECLARE_ALIGNED(32, int16_t, r1[MAX_SB_SQUARE]);
+  DECLARE_ALIGNED(32, int16_t, d10[MAX_SB_SQUARE]);
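+  // r0/r1: residuals of the source against each predictor;
+  // d10: p1 - p0, fed to av1_wedge_sse_from_residuals() below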
+
+#if CONFIG_AOM_HIGHBITDEPTH
+  if (hbd) {
+    aom_highbd_subtract_block(bh, bw, r0, bw, src->buf, src->stride,
+                              CONVERT_TO_BYTEPTR(p0), bw, xd->bd);
+    aom_highbd_subtract_block(bh, bw, r1, bw, src->buf, src->stride,
+                              CONVERT_TO_BYTEPTR(p1), bw, xd->bd);
+    aom_highbd_subtract_block(bh, bw, d10, bw, CONVERT_TO_BYTEPTR(p1), bw,
+                              CONVERT_TO_BYTEPTR(p0), bw, xd->bd);
+  } else  // NOLINT
+#endif    // CONFIG_AOM_HIGHBITDEPTH
+  {
+    aom_subtract_block(bh, bw, r0, bw, src->buf, src->stride, p0, bw);
+    aom_subtract_block(bh, bw, r1, bw, src->buf, src->stride, p1, bw);
+    aom_subtract_block(bh, bw, d10, bw, p1, bw, p0, bw);
+  }
+
+  // build mask and inverse
+  build_compound_seg_mask(&comp_data, p0, bw, p1, bw, bsize, bh, bw);
+
+  // compute rd for mask0
+  sse = av1_wedge_sse_from_residuals(r1, d10, comp_data.seg_mask[0], N);
+  sse = ROUND_POWER_OF_TWO(sse, bd_round);
+
+  model_rd_from_sse(cpi, xd, bsize, 0, sse, &rate, &dist);
+  rd0 = RDCOST(x->rdmult, x->rddiv, rate, dist);
+
+  // compute rd for mask1
+  sse = av1_wedge_sse_from_residuals(r1, d10, comp_data.seg_mask[1], N);
+  sse = ROUND_POWER_OF_TWO(sse, bd_round);
+
+  model_rd_from_sse(cpi, xd, bsize, 0, sse, &rate, &dist);
+  rd1 = RDCOST(x->rdmult, x->rddiv, rate, dist);
+
+  // pick the better of the two
+  mbmi->interinter_compound_data.which = rd1 < rd0;
+  return mbmi->interinter_compound_data.which ? rd1 : rd0;
+}
+#endif  // CONFIG_COMPOUND_SEGMENT
+
 static int64_t pick_interintra_wedge(const AV1_COMP *const cpi,
                                      const MACROBLOCK *const x,
                                      const BLOCK_SIZE bsize,
@@ -6944,6 +7010,63 @@
   return tmp_rate_mv;
 }
 
+#if CONFIG_COMPOUND_SEGMENT
+// TODO(sarahparker) this and build_and_cost_compound_wedge can probably
+// be combined in a refactor
+static int64_t build_and_cost_compound_seg(
+    const AV1_COMP *const cpi, MACROBLOCK *x, const int_mv *const cur_mv,
+    const BLOCK_SIZE bsize, const int this_mode, int rs2, int rate_mv,
+    BUFFER_SET *ctx, int *out_rate_mv, uint8_t **preds0, uint8_t **preds1,
+    int *strides, int mi_row, int mi_col) {
+  MACROBLOCKD *xd = &x->e_mbd;
+  MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
+  int rate_sum;
+  int64_t dist_sum;
+  int64_t best_rd_cur = INT64_MAX;
+  int64_t rd = INT64_MAX;
+  int tmp_skip_txfm_sb;
+  int64_t tmp_skip_sse_sb;
+
+  best_rd_cur = pick_interinter_seg_mask(cpi, x, bsize, *preds0, *preds1);
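+  // account for the compound-type signaling (rs2) and MV rate on top
+  // of the mask RD estimate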
+  best_rd_cur += RDCOST(x->rdmult, x->rddiv, rs2 + rate_mv, 0);
+
+  if (have_newmv_in_inter_mode(this_mode)) {
+    *out_rate_mv = interinter_compound_motion_search(cpi, x, bsize, this_mode,
+                                                     mi_row, mi_col);
+    av1_build_inter_predictors_sby(xd, mi_row, mi_col, ctx, bsize);
+    model_rd_for_sb(cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum,
+                    &tmp_skip_txfm_sb, &tmp_skip_sse_sb);
+    rd = RDCOST(x->rdmult, x->rddiv, rs2 + *out_rate_mv + rate_sum, dist_sum);
+    if (rd < best_rd_cur) {
+      best_rd_cur = rd;
+    } else {
+      mbmi->mv[0].as_int = cur_mv[0].as_int;
+      mbmi->mv[1].as_int = cur_mv[1].as_int;
+      *out_rate_mv = rate_mv;
+      av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0, strides,
+                                               preds1, strides);
+    }
+    av1_subtract_plane(x, bsize, 0);
+    rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum,
+                             &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX);
+    if (rd != INT64_MAX)
+      rd = RDCOST(x->rdmult, x->rddiv, rs2 + *out_rate_mv + rate_sum, dist_sum);
+    best_rd_cur = rd;
+
+  } else {
+    av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0, strides,
+                                             preds1, strides);
+    av1_subtract_plane(x, bsize, 0);
+    rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum,
+                             &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX);
+    if (rd != INT64_MAX)
+      rd = RDCOST(x->rdmult, x->rddiv, rs2 + rate_mv + rate_sum, dist_sum);
+    best_rd_cur = rd;
+  }
+  return best_rd_cur;
+}
+#endif  // CONFIG_COMPOUND_SEGMENT
+
 static int64_t build_and_cost_compound_wedge(
     const AV1_COMP *const cpi, MACROBLOCK *x, const int_mv *const cur_mv,
     const BLOCK_SIZE bsize, const int this_mode, int rs2, int rate_mv,
@@ -7566,7 +7689,30 @@
           }
           break;
 #if CONFIG_COMPOUND_SEGMENT
-        case COMPOUND_SEG: break;
+        case COMPOUND_SEG:
+          if (!is_interinter_wedge_used(bsize)) break;
+          if (x->source_variance > cpi->sf.disable_wedge_search_var_thresh &&
+              best_rd_compound / 3 < ref_best_rd) {
+            int tmp_rate_mv = 0;
+            best_rd_cur = build_and_cost_compound_seg(
+                cpi, x, cur_mv, bsize, this_mode, rs2, rate_mv, &orig_dst,
+                &tmp_rate_mv, preds0, preds1, strides, mi_row, mi_col);
+
+            if (best_rd_cur < best_rd_compound) {
+              best_rd_compound = best_rd_cur;
+              memcpy(&best_compound_data, &mbmi->interinter_compound_data,
+                     sizeof(best_compound_data));
+              if (have_newmv_in_inter_mode(this_mode)) {
+                best_tmp_rate_mv = tmp_rate_mv;
+                best_mv[0].as_int = mbmi->mv[0].as_int;
+                best_mv[1].as_int = mbmi->mv[1].as_int;
+                // reset to original mvs for next iteration
+                mbmi->mv[0].as_int = cur_mv[0].as_int;
+                mbmi->mv[1].as_int = cur_mv[1].as_int;
+              }
+            }
+          }
+          break;
 #endif  // CONFIG_COMPOUND_SEGMENT
         default: assert(0); return 0;
       }