Improvements on segment mask
Adds a few options to make the compound mask lightly dependent on
the two predictors.
Also adds high bit depth support.
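
For reference, a minimal standalone sketch of the new difference-weighted
weight computation (mirroring diffwtd_mask() below). clamp_int, MAX_ALPHA
and diffwtd_sketch are local stand-ins for illustration, not libaom APIs;
the constants 64 (AOM_BLEND_A64_MAX_ALPHA) and DIFF_FACTOR match the diff:

  #include <stdint.h>
  #include <stdio.h>
  #include <stdlib.h>

  #define MAX_ALPHA 64   /* stand-in for AOM_BLEND_A64_MAX_ALPHA */
  #define DIFF_FACTOR 16 /* same divisor the patch uses */

  static int clamp_int(int v, int lo, int hi) {
    return v < lo ? lo : (v > hi ? hi : v);
  }

  /* Weight starts at mask_base and grows with the absolute difference of
   * the two predictors, saturating at MAX_ALPHA; the inverse variant
   * mirrors it around MAX_ALPHA. */
  static void diffwtd_sketch(uint8_t *mask, int inverse, int mask_base,
                             const uint8_t *p0, const uint8_t *p1, int n) {
    for (int i = 0; i < n; ++i) {
      int m = clamp_int(mask_base + abs(p0[i] - p1[i]) / DIFF_FACTOR,
                        0, MAX_ALPHA);
      mask[i] = inverse ? MAX_ALPHA - m : m;
    }
  }

  int main(void) {
    const uint8_t p0[4] = { 100, 100, 100, 100 };
    const uint8_t p1[4] = { 100, 110, 180, 255 };
    uint8_t mask[4];
    diffwtd_sketch(mask, 0, 47, p0, p1, 4);
    for (int i = 0; i < 4; ++i) printf("%d ", mask[i]);
    printf("\n"); /* prints: 47 47 52 56 */
    return 0;
  }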
Change-Id: If57b6e8ddd140e0c00fd9d4738927d37225091cb
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index 0e5a599..9f735bc 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -41,19 +41,35 @@
#if CONFIG_EXT_INTER
// Should we try rectangular interintra predictions?
#define USE_RECT_INTERINTRA 1
-#if CONFIG_COMPOUND_SEGMENT
-#define MAX_SEG_MASK_BITS 3
+#if CONFIG_COMPOUND_SEGMENT
+
+// Set COMPOUND_SEGMENT_TYPE to one of the two:
+// 0: Uniform
+// 1: Difference weighted
+#define COMPOUND_SEGMENT_TYPE 1
+
+#if COMPOUND_SEGMENT_TYPE == 0
+#define MAX_SEG_MASK_BITS 1
// SEG_MASK_TYPES should not surpass 1 << MAX_SEG_MASK_BITS
typedef enum {
UNIFORM_45 = 0,
UNIFORM_45_INV,
- UNIFORM_55,
- UNIFORM_55_INV,
SEG_MASK_TYPES,
} SEG_MASK_TYPE;
+
+#elif COMPOUND_SEGMENT_TYPE == 1
+#define MAX_SEG_MASK_BITS 1
+// SEG_MASK_TYPES should not surpass 1 << MAX_SEG_MASK_BITS
+typedef enum {
+ DIFFWTD_45 = 0,
+ DIFFWTD_45_INV,
+ SEG_MASK_TYPES,
+} SEG_MASK_TYPE;
+
+#endif // COMPOUND_SEGMENT_TYPE
#endif // CONFIG_COMPOUND_SEGMENT
-#endif
+#endif // CONFIG_EXT_INTER
typedef enum {
KEY_FRAME = 0,
diff --git a/av1/common/reconinter.c b/av1/common/reconinter.c
index dc22483..4bf4aa4 100644
--- a/av1/common/reconinter.c
+++ b/av1/common/reconinter.c
@@ -299,8 +299,9 @@
}
#if CONFIG_COMPOUND_SEGMENT
-void uniform_mask(uint8_t *mask, int which_inverse, BLOCK_SIZE sb_type, int h,
- int w, int mask_val) {
+#if COMPOUND_SEGMENT_TYPE == 0
+static void uniform_mask(uint8_t *mask, int which_inverse, BLOCK_SIZE sb_type,
+ int h, int w, int mask_val) {
int i, j;
int block_stride = block_size_wide[sb_type];
for (i = 0; i < h; ++i)
@@ -321,11 +322,103 @@
switch (mask_type) {
case UNIFORM_45: uniform_mask(mask, 0, sb_type, h, w, 45); break;
case UNIFORM_45_INV: uniform_mask(mask, 1, sb_type, h, w, 45); break;
- case UNIFORM_55: uniform_mask(mask, 0, sb_type, h, w, 55); break;
- case UNIFORM_55_INV: uniform_mask(mask, 1, sb_type, h, w, 55); break;
default: assert(0);
}
}
+
+#if CONFIG_AOM_HIGHBITDEPTH
+void build_compound_seg_mask_highbd(uint8_t *mask, SEG_MASK_TYPE mask_type,
+ const uint8_t *src0, int src0_stride,
+ const uint8_t *src1, int src1_stride,
+ BLOCK_SIZE sb_type, int h, int w, int bd) {
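+  // Uniform masks do not depend on the predictor pixels, so the source
+  // arguments are unused here.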
+ (void)src0;
+ (void)src1;
+ (void)src0_stride;
+ (void)src1_stride;
+ (void)bd;
+ switch (mask_type) {
+ case UNIFORM_45: uniform_mask(mask, 0, sb_type, h, w, 45); break;
+ case UNIFORM_45_INV: uniform_mask(mask, 1, sb_type, h, w, 45); break;
+ default: assert(0);
+ }
+}
+#endif // CONFIG_AOM_HIGHBITDEPTH
+
+#elif COMPOUND_SEGMENT_TYPE == 1
+#define DIFF_FACTOR 16
+static void diffwtd_mask(uint8_t *mask, int which_inverse, int mask_base,
+ const uint8_t *src0, int src0_stride,
+ const uint8_t *src1, int src1_stride,
+ BLOCK_SIZE sb_type, int h, int w) {
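+  // Per-pixel weight: mask_base plus a term that grows with the absolute
+  // predictor difference, clamped to [0, AOM_BLEND_A64_MAX_ALPHA].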
+ int i, j, m, diff;
+ int block_stride = block_size_wide[sb_type];
+ for (i = 0; i < h; ++i) {
+ for (j = 0; j < w; ++j) {
+ diff =
+ abs((int)src0[i * src0_stride + j] - (int)src1[i * src1_stride + j]);
+ m = clamp(mask_base + (diff / DIFF_FACTOR), 0, AOM_BLEND_A64_MAX_ALPHA);
+ mask[i * block_stride + j] =
+ which_inverse ? AOM_BLEND_A64_MAX_ALPHA - m : m;
+ }
+ }
+}
+
+void build_compound_seg_mask(uint8_t *mask, SEG_MASK_TYPE mask_type,
+ const uint8_t *src0, int src0_stride,
+ const uint8_t *src1, int src1_stride,
+ BLOCK_SIZE sb_type, int h, int w) {
+ switch (mask_type) {
+ case DIFFWTD_45:
+ diffwtd_mask(mask, 0, 47, src0, src0_stride, src1, src1_stride, sb_type,
+ h, w);
+ break;
+ case DIFFWTD_45_INV:
+ diffwtd_mask(mask, 1, 47, src0, src0_stride, src1, src1_stride, sb_type,
+ h, w);
+ break;
+ default: assert(0);
+ }
+}
+
+#if CONFIG_AOM_HIGHBITDEPTH
+static void diffwtd_mask_highbd(uint8_t *mask, int which_inverse, int mask_base,
+ const uint16_t *src0, int src0_stride,
+ const uint16_t *src1, int src1_stride,
+ BLOCK_SIZE sb_type, int h, int w, int bd) {
+ int i, j, m, diff;
+ int block_stride = block_size_wide[sb_type];
+ for (i = 0; i < h; ++i) {
+ for (j = 0; j < w; ++j) {
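+      // Scale the difference back to an 8-bit range so DIFF_FACTOR behaves
+      // the same at 10- and 12-bit depths.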
+ diff = abs((int)src0[i * src0_stride + j] -
+ (int)src1[i * src1_stride + j]) >>
+ (bd - 8);
+ m = clamp(mask_base + (diff / DIFF_FACTOR), 0, AOM_BLEND_A64_MAX_ALPHA);
+ mask[i * block_stride + j] =
+ which_inverse ? AOM_BLEND_A64_MAX_ALPHA - m : m;
+ }
+ }
+}
+
+void build_compound_seg_mask_highbd(uint8_t *mask, SEG_MASK_TYPE mask_type,
+ const uint8_t *src0, int src0_stride,
+ const uint8_t *src1, int src1_stride,
+ BLOCK_SIZE sb_type, int h, int w, int bd) {
+ switch (mask_type) {
+    case DIFFWTD_45:
+      diffwtd_mask_highbd(mask, 0, 47, CONVERT_TO_SHORTPTR(src0), src0_stride,
+                          CONVERT_TO_SHORTPTR(src1), src1_stride, sb_type, h, w,
+                          bd);
+      break;
+    case DIFFWTD_45_INV:
+      diffwtd_mask_highbd(mask, 1, 47, CONVERT_TO_SHORTPTR(src0), src0_stride,
+                          CONVERT_TO_SHORTPTR(src1), src1_stride, sb_type, h, w,
+                          bd);
+      break;
+ default: assert(0);
+ }
+}
+#endif // CONFIG_AOM_HIGHBITDEPTH
+#endif // COMPOUND_SEGMENT_TYPE
#endif // CONFIG_COMPOUND_SEGMENT
static void init_wedge_master_masks() {
@@ -470,16 +563,18 @@
}
#if CONFIG_AOM_HIGHBITDEPTH
-static void build_masked_compound_wedge_highbd(
+static void build_masked_compound_highbd(
uint8_t *dst_8, int dst_stride, const uint8_t *src0_8, int src0_stride,
- const uint8_t *src1_8, int src1_stride, int wedge_index, int wedge_sign,
- BLOCK_SIZE sb_type, int h, int w, int bd) {
+ const uint8_t *src1_8, int src1_stride,
+ const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE sb_type, int h,
+ int w, int bd) {
// Derive subsampling from h and w passed in. May be refactored to
// pass in subsampling factors directly.
const int subh = (2 << b_height_log2_lookup[sb_type]) == h;
const int subw = (2 << b_width_log2_lookup[sb_type]) == w;
- const uint8_t *mask =
- av1_get_contiguous_soft_mask(wedge_index, wedge_sign, sb_type);
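+  // The mask now comes from the compound type (wedge or segment) carried in
+  // comp_data, rather than directly from the wedge tables.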
+ const uint8_t *mask = av1_get_compound_type_mask(comp_data, sb_type);
aom_highbd_blend_a64_mask(dst_8, dst_stride, src0_8, src0_stride, src1_8,
src1_stride, mask, block_size_wide[sb_type], h, w,
subh, subw, bd);
@@ -534,11 +629,22 @@
comp_data->wedge_index, comp_data->wedge_sign, mi->mbmi.sb_type,
wedge_offset_x, wedge_offset_y, h, w);
#else
+#if CONFIG_COMPOUND_SEGMENT
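+  // Segment masks are pixel-dependent, so build one from the two single
+  // predictions here (luma plane only) before blending.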
+ if (!plane && comp_data->type == COMPOUND_SEG) {
+ if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
+ build_compound_seg_mask_highbd(comp_data->seg_mask, comp_data->mask_type,
+ dst, dst_stride, tmp_dst, MAX_SB_SIZE,
+ mi->mbmi.sb_type, h, w, xd->bd);
+ else
+ build_compound_seg_mask(comp_data->seg_mask, comp_data->mask_type, dst,
+ dst_stride, tmp_dst, MAX_SB_SIZE,
+ mi->mbmi.sb_type, h, w);
+ }
+#endif // CONFIG_COMPOUND_SEGMENT
if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
- build_masked_compound_wedge_highbd(
- dst, dst_stride, dst, dst_stride, tmp_dst, MAX_SB_SIZE,
- comp_data->wedge_index, comp_data->wedge_sign, mi->mbmi.sb_type, h, w,
- xd->bd);
+ build_masked_compound_highbd(dst, dst_stride, dst, dst_stride, tmp_dst,
+ MAX_SB_SIZE, comp_data, mi->mbmi.sb_type, h, w,
+ xd->bd);
else
build_masked_compound(dst, dst_stride, dst, dst_stride, tmp_dst,
MAX_SB_SIZE, comp_data, mi->mbmi.sb_type, h, w);
@@ -2707,17 +2813,32 @@
if (is_compound &&
is_masked_compound_type(mbmi->interinter_compound_data.type)) {
#if CONFIG_COMPOUND_SEGMENT
+#if CONFIG_AOM_HIGHBITDEPTH
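+  // Same segment-mask build as above, here fed by the encoder's ext_dst
+  // prediction buffers.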
+ if (!plane && comp_data->type == COMPOUND_SEG) {
+ if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
+ build_compound_seg_mask_highbd(
+ comp_data->seg_mask, comp_data->mask_type,
+ CONVERT_TO_BYTEPTR(ext_dst0), ext_dst_stride0,
+ CONVERT_TO_BYTEPTR(ext_dst1), ext_dst_stride1, mbmi->sb_type, h, w,
+ xd->bd);
+ else
+ build_compound_seg_mask(comp_data->seg_mask, comp_data->mask_type,
+ ext_dst0, ext_dst_stride0, ext_dst1,
+ ext_dst_stride1, mbmi->sb_type, h, w);
+ }
+#else
if (!plane && comp_data->type == COMPOUND_SEG)
build_compound_seg_mask(comp_data->seg_mask, comp_data->mask_type,
ext_dst0, ext_dst_stride0, ext_dst1,
ext_dst_stride1, mbmi->sb_type, h, w);
+#endif // CONFIG_AOM_HIGHBITDEPTH
#endif // CONFIG_COMPOUND_SEGMENT
#if CONFIG_AOM_HIGHBITDEPTH
if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
- build_masked_compound_wedge_highbd(
+ build_masked_compound_highbd(
dst, dst_buf->stride, CONVERT_TO_BYTEPTR(ext_dst0), ext_dst_stride0,
- CONVERT_TO_BYTEPTR(ext_dst1), ext_dst_stride1, comp_data->wedge_index,
- comp_data->wedge_sign, mbmi->sb_type, h, w, xd->bd);
+ CONVERT_TO_BYTEPTR(ext_dst1), ext_dst_stride1, comp_data,
+ mbmi->sb_type, h, w, xd->bd);
else
#endif // CONFIG_AOM_HIGHBITDEPTH
build_masked_compound(dst, dst_buf->stride, ext_dst0, ext_dst_stride0,
diff --git a/av1/common/reconinter.h b/av1/common/reconinter.h
index 581a977..7bea9ed 100644
--- a/av1/common/reconinter.h
+++ b/av1/common/reconinter.h
@@ -196,6 +196,12 @@
const uint8_t *src0, int src0_stride,
const uint8_t *src1, int src1_stride,
BLOCK_SIZE sb_type, int h, int w);
+#if CONFIG_AOM_HIGHBITDEPTH
+void build_compound_seg_mask_highbd(uint8_t *mask, SEG_MASK_TYPE mask_type,
+ const uint8_t *src0, int src0_stride,
+ const uint8_t *src1, int src1_stride,
+ BLOCK_SIZE sb_type, int h, int w, int bd);
+#endif // CONFIG_AOM_HIGHBITDEPTH
#endif // CONFIG_COMPOUND_SEGMENT
#endif // CONFIG_EXT_INTER