Refactor ext-inter to loop through all masked modes in rdopt
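No change in performance. The inter-inter wedge fields are grouped into
INTERINTER_COMPOUND_DATA, and masked compound prediction now obtains its mask
through av1_get_compound_type_mask().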
Change-Id: Ie105a7baf6a2c2258d3ef117e727957e4393f51b
diff --git a/av1/common/reconinter.c b/av1/common/reconinter.c
index bc77687..1ae85b0 100644
--- a/av1/common/reconinter.c
+++ b/av1/common/reconinter.c
@@ -251,6 +251,24 @@
return mask;
}
+// Returns the blend mask for a masked inter-inter compound; a nonzero invert
+// flips the wedge sign. TODO(sarahparker) extend to other experiments; this is
+// currently only intended for ext_inter alone (COMPOUND_WEDGE).
+#if CONFIG_EXT_INTER
+const uint8_t *av1_get_compound_type_mask(
+ const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE sb_type,
+ int invert) {
+ assert(is_masked_compound_type(comp_data->type));  // must be a masked type
+ switch (comp_data->type) {
+ case COMPOUND_WEDGE:
+ return av1_get_contiguous_soft_mask(
+ comp_data->wedge_index,
+ invert ? !comp_data->wedge_sign : comp_data->wedge_sign, sb_type);
+ default: assert(0); return NULL;  // no other masked type is handled yet
+ }
+}
+#endif // CONFIG_EXT_INTER
+
static void init_wedge_master_masks() {
int i, j, s;
const int w = MASK_MASTER_SIZE;
@@ -378,17 +396,16 @@
#endif // CONFIG_AOM_HIGHBITDEPTH
#endif // CONFIG_SUPERTX
-static void build_masked_compound_wedge(uint8_t *dst, int dst_stride,
- const uint8_t *src0, int src0_stride,
- const uint8_t *src1, int src1_stride,
- int wedge_index, int wedge_sign,
- BLOCK_SIZE sb_type, int h, int w) {
+static void build_masked_compound(  // blend src0/src1 into dst with comp_data's mask
+ uint8_t *dst, int dst_stride, const uint8_t *src0, int src0_stride,
+ const uint8_t *src1, int src1_stride,
+ const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE sb_type, int h,
+ int w) {
// Derive subsampling from h and w passed in. May be refactored to
// pass in subsampling factors directly.
const int subh = (2 << b_height_log2_lookup[sb_type]) == h;
const int subw = (2 << b_width_log2_lookup[sb_type]) == w;
- const uint8_t *mask =
- av1_get_contiguous_soft_mask(wedge_index, wedge_sign, sb_type);
+ const uint8_t *mask = av1_get_compound_type_mask(comp_data, sb_type, 0);  // 0: do not invert
aom_blend_a64_mask(dst, dst_stride, src0, src0_stride, src1, src1_stride,
mask, block_size_wide[sb_type], h, w, subh, subw);
}
@@ -402,8 +419,15 @@
   // pass in subsampling factors directly.
   const int subh = (2 << b_height_log2_lookup[sb_type]) == h;
   const int subw = (2 << b_width_log2_lookup[sb_type]) == w;
-  const uint8_t *mask =
-      av1_get_contiguous_soft_mask(wedge_index, wedge_sign, sb_type);
+  // This helper still receives wedge_index/wedge_sign individually (there is
+  // no comp_data in scope here), so describe them as a COMPOUND_WEDGE; only
+  // the fields read by av1_get_compound_type_mask() are filled in.
+  INTERINTER_COMPOUND_DATA comp_data;
+  const uint8_t *mask;
+  comp_data.type = COMPOUND_WEDGE;
+  comp_data.wedge_index = wedge_index;
+  comp_data.wedge_sign = wedge_sign;
+  mask = av1_get_compound_type_mask(&comp_data, sb_type, 0);
   aom_highbd_blend_a64_mask(dst_8, dst_stride, src0_8, src0_stride, src1_8,
                             src1_stride, mask, block_size_wide[sb_type], h, w,
                             subh, subw, bd);
@@ -426,6 +442,8 @@
#endif // CONFIG_SUPERTX
const MACROBLOCKD *xd) {
const MODE_INFO *mi = xd->mi[0];
+ const INTERINTER_COMPOUND_DATA *const comp_data =
+ &mi->mbmi.interinter_compound_data;
// The prediction filter types used here should be those for
// the second reference block.
#if CONFIG_DUAL_FILTER
@@ -446,39 +464,35 @@
if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
build_masked_compound_wedge_extend_highbd(
dst, dst_stride, dst, dst_stride, tmp_dst, MAX_SB_SIZE,
- mi->mbmi.interinter_wedge_index, mi->mbmi.interinter_wedge_sign,
- mi->mbmi.sb_type, wedge_offset_x, wedge_offset_y, h, w, xd->bd);
+ comp_data->wedge_index, comp_data->wedge_sign, mi->mbmi.sb_type,
+ wedge_offset_x, wedge_offset_y, h, w, xd->bd);
else
build_masked_compound_wedge_extend(
dst, dst_stride, dst, dst_stride, tmp_dst, MAX_SB_SIZE,
- mi->mbmi.interinter_wedge_index, mi->mbmi.interinter_wedge_sign,
- mi->mbmi.sb_type, wedge_offset_x, wedge_offset_y, h, w);
+ comp_data->wedge_index, comp_data->wedge_sign, mi->mbmi.sb_type,
+ wedge_offset_x, wedge_offset_y, h, w);
#else
if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
build_masked_compound_wedge_highbd(
dst, dst_stride, dst, dst_stride, tmp_dst, MAX_SB_SIZE,
- mi->mbmi.interinter_wedge_index, mi->mbmi.interinter_wedge_sign,
- mi->mbmi.sb_type, h, w, xd->bd);
+ comp_data->wedge_index, comp_data->wedge_sign, mi->mbmi.sb_type, h, w,
+ xd->bd);
else
- build_masked_compound_wedge(dst, dst_stride, dst, dst_stride, tmp_dst,
- MAX_SB_SIZE, mi->mbmi.interinter_wedge_index,
- mi->mbmi.interinter_wedge_sign,
- mi->mbmi.sb_type, h, w);
+ build_masked_compound(dst, dst_stride, dst, dst_stride, tmp_dst,
+ MAX_SB_SIZE, comp_data, mi->mbmi.sb_type, h, w);
#endif // CONFIG_SUPERTX
#else // CONFIG_AOM_HIGHBITDEPTH
DECLARE_ALIGNED(16, uint8_t, tmp_dst[MAX_SB_SQUARE]);
av1_make_inter_predictor(pre, pre_stride, tmp_dst, MAX_SB_SIZE, subpel_x,
subpel_y, sf, w, h, 0, tmp_ipf, xs, ys, xd);
#if CONFIG_SUPERTX
- build_masked_compound_wedge_extend(
- dst, dst_stride, dst, dst_stride, tmp_dst, MAX_SB_SIZE,
- mi->mbmi.interinter_wedge_index, mi->mbmi.interinter_wedge_sign,
- mi->mbmi.sb_type, wedge_offset_x, wedge_offset_y, h, w);
+ build_masked_compound_wedge_extend(dst, dst_stride, dst, dst_stride, tmp_dst,
+ MAX_SB_SIZE, comp_data->wedge_index,
+ comp_data->wedge_sign, mi->mbmi.sb_type,
+ wedge_offset_x, wedge_offset_y, h, w);
#else
- build_masked_compound_wedge(dst, dst_stride, dst, dst_stride, tmp_dst,
- MAX_SB_SIZE, mi->mbmi.interinter_wedge_index,
- mi->mbmi.interinter_wedge_sign, mi->mbmi.sb_type,
- h, w);
+ build_masked_compound(dst, dst_stride, dst, dst_stride, tmp_dst, MAX_SB_SIZE,
+ comp_data, mi->mbmi.sb_type, h, w);
#endif // CONFIG_SUPERTX
#endif // CONFIG_AOM_HIGHBITDEPTH
}
@@ -630,8 +644,8 @@
(scaled_mv.col >> SUBPEL_BITS);
#if CONFIG_EXT_INTER
- if (ref && is_interinter_wedge_used(mi->mbmi.sb_type) &&
- mi->mbmi.interinter_compound == COMPOUND_WEDGE)
+ if (ref &&
+ is_masked_compound_type(mi->mbmi.interinter_compound_data.type))
av1_make_masked_inter_predictor(
pre, pre_buf->stride, dst, dst_buf->stride, subpel_x, subpel_y,
sf, w, h, mi->mbmi.interp_filter, xs, ys,
@@ -696,8 +710,7 @@
(scaled_mv.col >> SUBPEL_BITS);
#if CONFIG_EXT_INTER
- if (ref && is_interinter_wedge_used(mi->mbmi.sb_type) &&
- mi->mbmi.interinter_compound == COMPOUND_WEDGE)
+ if (ref && is_masked_compound_type(mi->mbmi.interinter_compound_data.type))
av1_make_masked_inter_predictor(pre, pre_buf->stride, dst,
dst_buf->stride, subpel_x, subpel_y, sf,
w, h, mi->mbmi.interp_filter, xs, ys,
@@ -1280,9 +1293,9 @@
void modify_neighbor_predictor_for_obmc(MB_MODE_INFO *mbmi) {
if (is_interintra_pred(mbmi)) {
mbmi->ref_frame[1] = NONE;
- } else if (has_second_ref(mbmi) && is_interinter_wedge_used(mbmi->sb_type) &&
- mbmi->interinter_compound == COMPOUND_WEDGE) {
- mbmi->interinter_compound = COMPOUND_AVERAGE;
+ } else if (has_second_ref(mbmi) &&
+ is_masked_compound_type(mbmi->interinter_compound_data.type)) {
+ mbmi->interinter_compound_data.type = COMPOUND_AVERAGE;
mbmi->ref_frame[1] = NONE;
}
return;
@@ -2080,22 +2093,22 @@
MACROBLOCKD_PLANE *const pd = &xd->plane[plane];
struct buf_2d *const dst_buf = &pd->dst;
uint8_t *const dst = dst_buf->buf + dst_buf->stride * y + x;
+ const INTERINTER_COMPOUND_DATA *const comp_data =
+ &mbmi->interinter_compound_data;
- if (is_compound && is_interinter_wedge_used(mbmi->sb_type) &&
- mbmi->interinter_compound == COMPOUND_WEDGE) {
+ if (is_compound &&
+ is_masked_compound_type(mbmi->interinter_compound_data.type)) {
#if CONFIG_AOM_HIGHBITDEPTH
if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
build_masked_compound_wedge_highbd(
dst, dst_buf->stride, CONVERT_TO_BYTEPTR(ext_dst0), ext_dst_stride0,
- CONVERT_TO_BYTEPTR(ext_dst1), ext_dst_stride1,
- mbmi->interinter_wedge_index, mbmi->interinter_wedge_sign,
- mbmi->sb_type, h, w, xd->bd);
+ CONVERT_TO_BYTEPTR(ext_dst1), ext_dst_stride1, comp_data->wedge_index,
+ comp_data->wedge_sign, mbmi->sb_type, h, w, xd->bd);
else
#endif // CONFIG_AOM_HIGHBITDEPTH
- build_masked_compound_wedge(
- dst, dst_buf->stride, ext_dst0, ext_dst_stride0, ext_dst1,
- ext_dst_stride1, mbmi->interinter_wedge_index,
- mbmi->interinter_wedge_sign, mbmi->sb_type, h, w);
+ build_masked_compound(dst, dst_buf->stride, ext_dst0, ext_dst_stride0,
+ ext_dst1, ext_dst_stride1, comp_data, mbmi->sb_type,
+ h, w);
} else {
#if CONFIG_AOM_HIGHBITDEPTH
if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)