Refactor compound_segment to try different segmentation masks

Replace the two precomputed masks and the 'which' selector in
INTERINTER_COMPOUND_DATA with a single seg_mask buffer plus a
SEG_MASK_TYPE, so the encoder can build and evaluate each mask type in
turn (currently the uniform 45 and 55 weights and their inverses).
av1_get_compound_type_mask() is split into a plain and an inverse
variant; the inverse of a COMPOUND_SEG mask is now computed on demand
into a caller-provided buffer instead of being stored.
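
As a rough illustration of the uniform mask semantics, here is a
minimal standalone sketch (not part of the patch; AOM_BLEND_A64_MAX_ALPHA
is assumed to be 64 as defined in aom_dsp/blend.h, and the demo_* names
and the 4x4 block are hypothetical):

  #include <stdint.h>
  #include <stdio.h>

  #define MAX_ALPHA 64 /* stand-in for AOM_BLEND_A64_MAX_ALPHA */

  /* Mirrors uniform_mask(): a constant weight for the first predictor,
   * complemented when 'inverse' is set. */
  static void demo_uniform_mask(uint8_t *mask, int inverse, int h, int w,
                                int stride, int mask_val) {
    for (int i = 0; i < h; ++i)
      for (int j = 0; j < w; ++j)
        mask[i * stride + j] = inverse ? MAX_ALPHA - mask_val : mask_val;
  }

  int main(void) {
    uint8_t mask[4 * 4];
    demo_uniform_mask(mask, 0, 4, 4, 4, 45); /* UNIFORM_45 */
    printf("%d/%d\n", mask[0], MAX_ALPHA - mask[0]); /* 45/19 */
    demo_uniform_mask(mask, 1, 4, 4, 4, 45); /* UNIFORM_45_INV */
    printf("%d/%d\n", mask[0], MAX_ALPHA - mask[0]); /* 19/45 */
    return 0;
  }

With MAX_SEG_MASK_BITS at 3, up to eight such mask types can be
signalled.
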
Change-Id: I7c992c9aae895aebcfb5c147cb179cf665c0ac10
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index f78d9ff..0e5a599 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -41,6 +41,18 @@
#if CONFIG_EXT_INTER
// Should we try rectangular interintra predictions?
#define USE_RECT_INTERINTRA 1
+#if CONFIG_COMPOUND_SEGMENT
+#define MAX_SEG_MASK_BITS 3
+
+// SEG_MASK_TYPES should not exceed 1 << MAX_SEG_MASK_BITS
+typedef enum {
+ UNIFORM_45 = 0,
+ UNIFORM_45_INV,
+ UNIFORM_55,
+ UNIFORM_55_INV,
+ SEG_MASK_TYPES,
+} SEG_MASK_TYPE;
+#endif // CONFIG_COMPOUND_SEGMENT
#endif
typedef enum {
@@ -256,8 +268,8 @@
int wedge_index;
int wedge_sign;
#if CONFIG_COMPOUND_SEGMENT
- int which;
- DECLARE_ALIGNED(16, uint8_t, seg_mask[2][2 * MAX_SB_SQUARE]);
+ SEG_MASK_TYPE mask_type;
+ DECLARE_ALIGNED(16, uint8_t, seg_mask[2 * MAX_SB_SQUARE]);
#endif // CONFIG_COMPOUND_SEGMENT
} INTERINTER_COMPOUND_DATA;
#endif // CONFIG_EXT_INTER
diff --git a/av1/common/reconinter.c b/av1/common/reconinter.c
index 89f627c..601ca3e 100644
--- a/av1/common/reconinter.c
+++ b/av1/common/reconinter.c
@@ -251,45 +251,80 @@
return mask;
}
-// get a mask according to the compound type
-// TODO(sarahparker) this needs to be extended for other experiments and
-// is currently only intended for ext_inter alone
-const uint8_t *av1_get_compound_type_mask(INTERINTER_COMPOUND_DATA *comp_data,
- BLOCK_SIZE sb_type, int invert) {
+#if CONFIG_COMPOUND_SEGMENT
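+// Fill mask_inv_buffer with the complement of 'mask' (i.e. the weights for
+// the second predictor) and return it.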
+static uint8_t *invert_mask(uint8_t *mask_inv_buffer, const uint8_t *const mask,
+ int h, int w, int stride) {
+ int i, j;
+
+ for (i = 0; i < h; ++i)
+ for (j = 0; j < w; ++j) {
+ mask_inv_buffer[i * stride + j] =
+ AOM_BLEND_A64_MAX_ALPHA - mask[i * stride + j];
+ }
+ return mask_inv_buffer;
+}
+#endif // CONFIG_COMPOUND_SEGMENT
+
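+// Return the complement of the compound-type mask, i.e. the weights for the
+// second predictor. For COMPOUND_WEDGE this is the precomputed mask for the
+// opposite wedge sign; for COMPOUND_SEG it is computed on the fly into the
+// caller-provided mask_buffer.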
+const uint8_t *av1_get_compound_type_mask_inverse(
+ const INTERINTER_COMPOUND_DATA *const comp_data,
+#if CONFIG_COMPOUND_SEGMENT
+ uint8_t *mask_buffer, int h, int w, int stride,
+#endif
+ BLOCK_SIZE sb_type) {
assert(is_masked_compound_type(comp_data->type));
switch (comp_data->type) {
case COMPOUND_WEDGE:
- return av1_get_contiguous_soft_mask(
- comp_data->wedge_index,
- invert ? !comp_data->wedge_sign : comp_data->wedge_sign, sb_type);
+ return av1_get_contiguous_soft_mask(comp_data->wedge_index,
+ !comp_data->wedge_sign, sb_type);
#if CONFIG_COMPOUND_SEGMENT
case COMPOUND_SEG:
- if (invert) return comp_data->seg_mask[!comp_data->which];
- return comp_data->seg_mask[comp_data->which];
+ return invert_mask(mask_buffer, comp_data->seg_mask, h, w, stride);
+#endif // CONFIG_COMPOUND_SEGMENT
+ default: assert(0); return NULL;
+ }
+}
+
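+// Return the compound-type mask, i.e. the weights applied to the first
+// predictor when blending.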
+const uint8_t *av1_get_compound_type_mask(
+ const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE sb_type) {
+ assert(is_masked_compound_type(comp_data->type));
+ switch (comp_data->type) {
+ case COMPOUND_WEDGE:
+ return av1_get_contiguous_soft_mask(comp_data->wedge_index,
+ comp_data->wedge_sign, sb_type);
+#if CONFIG_COMPOUND_SEGMENT
+ case COMPOUND_SEG: return comp_data->seg_mask;
#endif // CONFIG_COMPOUND_SEGMENT
default: assert(0); return NULL;
}
}
#if CONFIG_COMPOUND_SEGMENT
-// temporary placeholder mask, this will be generated using segmentation later
-void build_compound_seg_mask(INTERINTER_COMPOUND_DATA *comp_data,
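+// Fill 'mask' with a single constant weight for the first predictor,
+// complemented when which_inverse is set.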
+static void uniform_mask(uint8_t *mask, int which_inverse, BLOCK_SIZE sb_type,
+                         int h, int w, int mask_val) {
+ int i, j;
+ int block_stride = block_size_wide[sb_type];
+ for (i = 0; i < h; ++i)
+ for (j = 0; j < w; ++j) {
+ mask[i * block_stride + j] =
+ which_inverse ? AOM_BLEND_A64_MAX_ALPHA - mask_val : mask_val;
+ }
+}
+
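+// Build the weight mask selected by mask_type into 'mask'. The source pixels
+// are currently unused (all mask types are uniform) but remain in the
+// signature for mask types that will depend on them.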
+void build_compound_seg_mask(uint8_t *mask, SEG_MASK_TYPE mask_type,
const uint8_t *src0, int src0_stride,
const uint8_t *src1, int src1_stride,
BLOCK_SIZE sb_type, int h, int w) {
- int block_stride = block_size_wide[sb_type];
- int i, j;
(void)src0;
(void)src1;
(void)src0_stride;
(void)src1_stride;
- for (i = 0; i < h; ++i)
- for (j = 0; j < w; ++j) {
- // if which == 0, put more weight on the first predictor
- comp_data->seg_mask[0][i * block_stride + j] = 45;
- comp_data->seg_mask[1][i * block_stride + j] =
- AOM_BLEND_A64_MAX_ALPHA - 45;
- }
+ switch (mask_type) {
+ case UNIFORM_45: uniform_mask(mask, 0, sb_type, h, w, 45); break;
+ case UNIFORM_45_INV: uniform_mask(mask, 1, sb_type, h, w, 45); break;
+ case UNIFORM_55: uniform_mask(mask, 0, sb_type, h, w, 55); break;
+ case UNIFORM_55_INV: uniform_mask(mask, 1, sb_type, h, w, 55); break;
+ default: assert(0);
+ }
}
#endif // CONFIG_COMPOUND_SEGMENT
@@ -420,16 +455,16 @@
#endif // CONFIG_AOM_HIGHBITDEPTH
#endif // CONFIG_SUPERTX
-static void build_masked_compound(uint8_t *dst, int dst_stride,
- const uint8_t *src0, int src0_stride,
- const uint8_t *src1, int src1_stride,
- INTERINTER_COMPOUND_DATA *comp_data,
- BLOCK_SIZE sb_type, int h, int w) {
+static void build_masked_compound(
+ uint8_t *dst, int dst_stride, const uint8_t *src0, int src0_stride,
+ const uint8_t *src1, int src1_stride,
+ const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE sb_type, int h,
+ int w) {
// Derive subsampling from h and w passed in. May be refactored to
// pass in subsampling factors directly.
const int subh = (2 << b_height_log2_lookup[sb_type]) == h;
const int subw = (2 << b_width_log2_lookup[sb_type]) == w;
- const uint8_t *mask = av1_get_compound_type_mask(comp_data, sb_type, 0);
+ const uint8_t *mask = av1_get_compound_type_mask(comp_data, sb_type);
aom_blend_a64_mask(dst, dst_stride, src0, src0_stride, src1, src1_stride,
mask, block_size_wide[sb_type], h, w, subh, subw);
}
@@ -520,8 +555,9 @@
#else
#if CONFIG_COMPOUND_SEGMENT
if (!plane && comp_data->type == COMPOUND_SEG)
- build_compound_seg_mask(comp_data, dst, dst_stride, tmp_dst, MAX_SB_SIZE,
- mi->mbmi.sb_type, h, w);
+ build_compound_seg_mask(comp_data->seg_mask, comp_data->mask_type, dst,
+ dst_stride, tmp_dst, MAX_SB_SIZE, mi->mbmi.sb_type,
+ h, w);
#endif // CONFIG_COMPOUND_SEGMENT
build_masked_compound(dst, dst_stride, dst, dst_stride, tmp_dst, MAX_SB_SIZE,
comp_data, mi->mbmi.sb_type, h, w);
@@ -2216,7 +2252,8 @@
is_masked_compound_type(mbmi->interinter_compound_data.type)) {
#if CONFIG_COMPOUND_SEGMENT
if (!plane && comp_data->type == COMPOUND_SEG)
- build_compound_seg_mask(comp_data, ext_dst0, ext_dst_stride0, ext_dst1,
+ build_compound_seg_mask(comp_data->seg_mask, comp_data->mask_type,
+ ext_dst0, ext_dst_stride0, ext_dst1,
ext_dst_stride1, mbmi->sb_type, h, w);
#endif // CONFIG_COMPOUND_SEGMENT
#if CONFIG_AOM_HIGHBITDEPTH
diff --git a/av1/common/reconinter.h b/av1/common/reconinter.h
index 1819bb2..408f83a 100644
--- a/av1/common/reconinter.h
+++ b/av1/common/reconinter.h
@@ -189,7 +189,7 @@
}
#if CONFIG_COMPOUND_SEGMENT
-void build_compound_seg_mask(INTERINTER_COMPOUND_DATA *comp_data,
+void build_compound_seg_mask(uint8_t *mask, SEG_MASK_TYPE mask_type,
const uint8_t *src0, int src0_stride,
const uint8_t *src1, int src1_stride,
BLOCK_SIZE sb_type, int h, int w);
@@ -514,8 +514,15 @@
BLOCK_SIZE sb_type, int wedge_offset_x,
int wedge_offset_y);
-const uint8_t *av1_get_compound_type_mask(INTERINTER_COMPOUND_DATA *comp_data,
- BLOCK_SIZE sb_type, int invert);
+const uint8_t *av1_get_compound_type_mask_inverse(
+ const INTERINTER_COMPOUND_DATA *const comp_data,
+#if CONFIG_COMPOUND_SEGMENT
+ uint8_t *mask_buffer, int h, int w, int stride,
+#endif
+ BLOCK_SIZE sb_type);
+
+const uint8_t *av1_get_compound_type_mask(
+ const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE sb_type);
void av1_build_interintra_predictors(MACROBLOCKD *xd, uint8_t *ypred,
uint8_t *upred, uint8_t *vpred,