Improvements on segment mask

Adds a few options to make the compound mask lightly dependent on the
two predictors.

Also adds high bit depth support

Change-Id: If57b6e8ddd140e0c00fd9d4738927d37225091cb
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index 0e5a599..9f735bc 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -41,19 +41,35 @@
 #if CONFIG_EXT_INTER
 // Should we try rectangular interintra predictions?
 #define USE_RECT_INTERINTRA 1
-#if CONFIG_COMPOUND_SEGMENT
-#define MAX_SEG_MASK_BITS 3
 
+#if CONFIG_COMPOUND_SEGMENT
+
+// Set COMPOUND_SEGMENT_TYPE to one of the following:
+// 0: Uniform
+// 1: Difference weighted
+#define COMPOUND_SEGMENT_TYPE 1
+
+#if COMPOUND_SEGMENT_TYPE == 0
+#define MAX_SEG_MASK_BITS 1
 // SEG_MASK_TYPES should not surpass 1 << MAX_SEG_MASK_BITS
 typedef enum {
   UNIFORM_45 = 0,
   UNIFORM_45_INV,
-  UNIFORM_55,
-  UNIFORM_55_INV,
   SEG_MASK_TYPES,
 } SEG_MASK_TYPE;
+
+#elif COMPOUND_SEGMENT_TYPE == 1
+#define MAX_SEG_MASK_BITS 1
+// SEG_MASK_TYPES should not surpass 1 << MAX_SEG_MASK_BITS
+typedef enum {
+  DIFFWTD_45 = 0,
+  DIFFWTD_45_INV,
+  SEG_MASK_TYPES,
+} SEG_MASK_TYPE;
+
+#endif  // COMPOUND_SEGMENT_TYPE
 #endif  // CONFIG_COMPOUND_SEGMENT
-#endif
+#endif  // CONFIG_EXT_INTER
 
 typedef enum {
   KEY_FRAME = 0,
diff --git a/av1/common/reconinter.c b/av1/common/reconinter.c
index dc22483..4bf4aa4 100644
--- a/av1/common/reconinter.c
+++ b/av1/common/reconinter.c
@@ -299,8 +299,9 @@
 }
 
 #if CONFIG_COMPOUND_SEGMENT
-void uniform_mask(uint8_t *mask, int which_inverse, BLOCK_SIZE sb_type, int h,
-                  int w, int mask_val) {
+#if COMPOUND_SEGMENT_TYPE == 0
+static void uniform_mask(uint8_t *mask, int which_inverse, BLOCK_SIZE sb_type,
+                         int h, int w, int mask_val) {
   int i, j;
   int block_stride = block_size_wide[sb_type];
   for (i = 0; i < h; ++i)
@@ -321,11 +322,103 @@
   switch (mask_type) {
     case UNIFORM_45: uniform_mask(mask, 0, sb_type, h, w, 45); break;
     case UNIFORM_45_INV: uniform_mask(mask, 1, sb_type, h, w, 45); break;
-    case UNIFORM_55: uniform_mask(mask, 0, sb_type, h, w, 55); break;
-    case UNIFORM_55_INV: uniform_mask(mask, 1, sb_type, h, w, 55); break;
     default: assert(0);
   }
 }
+
+#if CONFIG_AOM_HIGHBITDEPTH
+void build_compound_seg_mask_highbd(uint8_t *mask, SEG_MASK_TYPE mask_type,
+                                    const uint8_t *src0, int src0_stride,
+                                    const uint8_t *src1, int src1_stride,
+                                    BLOCK_SIZE sb_type, int h, int w, int bd) {
+  (void)src0;
+  (void)src1;
+  (void)src0_stride;
+  (void)src1_stride;
+  (void)bd;
+  switch (mask_type) {
+    case UNIFORM_45: uniform_mask(mask, 0, sb_type, h, w, 45); break;
+    case UNIFORM_45_INV: uniform_mask(mask, 1, sb_type, h, w, 45); break;
+    default: assert(0);
+  }
+}
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+
+#elif COMPOUND_SEGMENT_TYPE == 1
+#define DIFF_FACTOR 16
+static void diffwtd_mask(uint8_t *mask, int which_inverse, int mask_base,
+                         const uint8_t *src0, int src0_stride,
+                         const uint8_t *src1, int src1_stride,
+                         BLOCK_SIZE sb_type, int h, int w) {
+  int i, j, m, diff;
+  int block_stride = block_size_wide[sb_type];
+  for (i = 0; i < h; ++i) {
+    for (j = 0; j < w; ++j) {
+      diff =
+          abs((int)src0[i * src0_stride + j] - (int)src1[i * src1_stride + j]);
+      m = clamp(mask_base + (diff / DIFF_FACTOR), 0, AOM_BLEND_A64_MAX_ALPHA);
+      mask[i * block_stride + j] =
+          which_inverse ? AOM_BLEND_A64_MAX_ALPHA - m : m;
+    }
+  }
+}
+
+void build_compound_seg_mask(uint8_t *mask, SEG_MASK_TYPE mask_type,
+                             const uint8_t *src0, int src0_stride,
+                             const uint8_t *src1, int src1_stride,
+                             BLOCK_SIZE sb_type, int h, int w) {
+  switch (mask_type) {
+    case DIFFWTD_45:
+      diffwtd_mask(mask, 0, 47, src0, src0_stride, src1, src1_stride, sb_type,
+                   h, w);
+      break;
+    case DIFFWTD_45_INV:
+      diffwtd_mask(mask, 1, 47, src0, src0_stride, src1, src1_stride, sb_type,
+                   h, w);
+      break;
+    default: assert(0);
+  }
+}
+
+#if CONFIG_AOM_HIGHBITDEPTH
+static void diffwtd_mask_highbd(uint8_t *mask, int which_inverse, int mask_base,
+                                const uint16_t *src0, int src0_stride,
+                                const uint16_t *src1, int src1_stride,
+                                BLOCK_SIZE sb_type, int h, int w, int bd) {
+  int i, j, m, diff;
+  int block_stride = block_size_wide[sb_type];
+  for (i = 0; i < h; ++i) {
+    for (j = 0; j < w; ++j) {
+      diff = abs((int)src0[i * src0_stride + j] -
+                 (int)src1[i * src1_stride + j]) >>
+             (bd - 8);
+      m = clamp(mask_base + (diff / DIFF_FACTOR), 0, AOM_BLEND_A64_MAX_ALPHA);
+      mask[i * block_stride + j] =
+          which_inverse ? AOM_BLEND_A64_MAX_ALPHA - m : m;
+    }
+  }
+}
+
+// High-bitdepth mask builder: src0/src1 are byte-pointer handles that are
+// unwrapped with CONVERT_TO_SHORTPTR; mask_type selects normal or inverse.
+void build_compound_seg_mask_highbd(uint8_t *mask, SEG_MASK_TYPE mask_type,
+                                    const uint8_t *src0, int src0_stride,
+                                    const uint8_t *src1, int src1_stride,
+                                    BLOCK_SIZE sb_type, int h, int w, int bd) {
+  switch (mask_type) {
+    case DIFFWTD_45:
+    case DIFFWTD_45_INV:
+      // Same base weight (47) as the 8-bit build_compound_seg_mask().
+      diffwtd_mask_highbd(mask, mask_type == DIFFWTD_45_INV, 47,
+                          CONVERT_TO_SHORTPTR(src0), src0_stride,
+                          CONVERT_TO_SHORTPTR(src1), src1_stride, sb_type, h, w,
+                          bd);
+      break;
+    default: assert(0);
+  }
+}
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+#endif  // COMPOUND_SEGMENT_TYPE
 #endif  // CONFIG_COMPOUND_SEGMENT
 
 static void init_wedge_master_masks() {
@@ -470,16 +563,18 @@
 }
 
 #if CONFIG_AOM_HIGHBITDEPTH
-static void build_masked_compound_wedge_highbd(
+static void build_masked_compound_highbd(
     uint8_t *dst_8, int dst_stride, const uint8_t *src0_8, int src0_stride,
-    const uint8_t *src1_8, int src1_stride, int wedge_index, int wedge_sign,
-    BLOCK_SIZE sb_type, int h, int w, int bd) {
+    const uint8_t *src1_8, int src1_stride,
+    const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE sb_type, int h,
+    int w, int bd) {
   // Derive subsampling from h and w passed in. May be refactored to
   // pass in subsampling factors directly.
   const int subh = (2 << b_height_log2_lookup[sb_type]) == h;
   const int subw = (2 << b_width_log2_lookup[sb_type]) == w;
-  const uint8_t *mask =
-      av1_get_contiguous_soft_mask(wedge_index, wedge_sign, sb_type);
+  const uint8_t *mask = av1_get_compound_type_mask(comp_data, sb_type);
+  // const uint8_t *mask =
+  //     av1_get_contiguous_soft_mask(wedge_index, wedge_sign, sb_type);
   aom_highbd_blend_a64_mask(dst_8, dst_stride, src0_8, src0_stride, src1_8,
                             src1_stride, mask, block_size_wide[sb_type], h, w,
                             subh, subw, bd);
@@ -534,11 +629,22 @@
         comp_data->wedge_index, comp_data->wedge_sign, mi->mbmi.sb_type,
         wedge_offset_x, wedge_offset_y, h, w);
 #else
+#if CONFIG_COMPOUND_SEGMENT
+  if (!plane && comp_data->type == COMPOUND_SEG) {
+    if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
+      build_compound_seg_mask_highbd(comp_data->seg_mask, comp_data->mask_type,
+                                     dst, dst_stride, tmp_dst, MAX_SB_SIZE,
+                                     mi->mbmi.sb_type, h, w, xd->bd);
+    else
+      build_compound_seg_mask(comp_data->seg_mask, comp_data->mask_type, dst,
+                              dst_stride, tmp_dst, MAX_SB_SIZE,
+                              mi->mbmi.sb_type, h, w);
+  }
+#endif  // CONFIG_COMPOUND_SEGMENT
   if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
-    build_masked_compound_wedge_highbd(
-        dst, dst_stride, dst, dst_stride, tmp_dst, MAX_SB_SIZE,
-        comp_data->wedge_index, comp_data->wedge_sign, mi->mbmi.sb_type, h, w,
-        xd->bd);
+    build_masked_compound_highbd(dst, dst_stride, dst, dst_stride, tmp_dst,
+                                 MAX_SB_SIZE, comp_data, mi->mbmi.sb_type, h, w,
+                                 xd->bd);
   else
     build_masked_compound(dst, dst_stride, dst, dst_stride, tmp_dst,
                           MAX_SB_SIZE, comp_data, mi->mbmi.sb_type, h, w);
@@ -2707,17 +2813,32 @@
   if (is_compound &&
       is_masked_compound_type(mbmi->interinter_compound_data.type)) {
 #if CONFIG_COMPOUND_SEGMENT
+#if CONFIG_AOM_HIGHBITDEPTH
+    if (!plane && comp_data->type == COMPOUND_SEG) {
+      if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
+        build_compound_seg_mask_highbd(
+            comp_data->seg_mask, comp_data->mask_type,
+            CONVERT_TO_BYTEPTR(ext_dst0), ext_dst_stride0,
+            CONVERT_TO_BYTEPTR(ext_dst1), ext_dst_stride1, mbmi->sb_type, h, w,
+            xd->bd);
+      else
+        build_compound_seg_mask(comp_data->seg_mask, comp_data->mask_type,
+                                ext_dst0, ext_dst_stride0, ext_dst1,
+                                ext_dst_stride1, mbmi->sb_type, h, w);
+    }
+#else
     if (!plane && comp_data->type == COMPOUND_SEG)
       build_compound_seg_mask(comp_data->seg_mask, comp_data->mask_type,
                               ext_dst0, ext_dst_stride0, ext_dst1,
                               ext_dst_stride1, mbmi->sb_type, h, w);
+#endif  // CONFIG_AOM_HIGHBITDEPTH
 #endif  // CONFIG_COMPOUND_SEGMENT
 #if CONFIG_AOM_HIGHBITDEPTH
     if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
-      build_masked_compound_wedge_highbd(
+      build_masked_compound_highbd(
           dst, dst_buf->stride, CONVERT_TO_BYTEPTR(ext_dst0), ext_dst_stride0,
-          CONVERT_TO_BYTEPTR(ext_dst1), ext_dst_stride1, comp_data->wedge_index,
-          comp_data->wedge_sign, mbmi->sb_type, h, w, xd->bd);
+          CONVERT_TO_BYTEPTR(ext_dst1), ext_dst_stride1, comp_data,
+          mbmi->sb_type, h, w, xd->bd);
     else
 #endif  // CONFIG_AOM_HIGHBITDEPTH
       build_masked_compound(dst, dst_buf->stride, ext_dst0, ext_dst_stride0,
diff --git a/av1/common/reconinter.h b/av1/common/reconinter.h
index 581a977..7bea9ed 100644
--- a/av1/common/reconinter.h
+++ b/av1/common/reconinter.h
@@ -196,6 +196,12 @@
                              const uint8_t *src0, int src0_stride,
                              const uint8_t *src1, int src1_stride,
                              BLOCK_SIZE sb_type, int h, int w);
+#if CONFIG_AOM_HIGHBITDEPTH
+void build_compound_seg_mask_highbd(uint8_t *mask, SEG_MASK_TYPE mask_type,
+                                    const uint8_t *src0, int src0_stride,
+                                    const uint8_t *src1, int src1_stride,
+                                    BLOCK_SIZE sb_type, int h, int w, int bd);
+#endif  // CONFIG_AOM_HIGHBITDEPTH
 #endif  // CONFIG_COMPOUND_SEGMENT
 #endif  // CONFIG_EXT_INTER