Refactor masked compound prediction
Unify data access with regular inter prediction.
Change-Id: Idffd43b28b1fb00c5951567745ad51268f539f7d
diff --git a/av1/common/reconinter.c b/av1/common/reconinter.c
index 17e9aa2..22a0429 100644
--- a/av1/common/reconinter.c
+++ b/av1/common/reconinter.c
@@ -106,6 +106,15 @@
inter_pred_params->ref_frame_buf = *ref_buf;
}
+void av1_init_mask_comp(InterPredParams *inter_pred_params, BLOCK_SIZE bsize,
+ const INTERINTER_COMPOUND_DATA *mask_comp) {
+ inter_pred_params->sb_type = bsize;
+ if (inter_pred_params->conv_params.do_average &&
+ is_masked_compound_type(mask_comp->type)) {
+ inter_pred_params->mask_comp = *mask_comp;
+ }
+}
+
void av1_make_inter_predictor(const uint8_t *src, int src_stride, uint8_t *dst,
int dst_stride,
InterPredParams *inter_pred_params,
@@ -553,44 +562,42 @@
uint8_t *dst, int dst_stride, const CONV_BUF_TYPE *src0, int src0_stride,
const CONV_BUF_TYPE *src1, int src1_stride,
const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE sb_type, int h,
- int w, ConvolveParams *conv_params, MACROBLOCKD *xd) {
- // Derive subsampling from h and w passed in. May be refactored to
- // pass in subsampling factors directly.
- const int subh = (2 << mi_size_high_log2[sb_type]) == h;
- const int subw = (2 << mi_size_wide_log2[sb_type]) == w;
+ int w, InterPredParams *inter_pred_params) {
+ const int ssy = inter_pred_params->subsampling_y;
+ const int ssx = inter_pred_params->subsampling_x;
const uint8_t *mask = av1_get_compound_type_mask(comp_data, sb_type);
+ const int mask_stride = block_size_wide[sb_type];
#if CONFIG_AV1_HIGHBITDEPTH
- if (is_cur_buf_hbd(xd)) {
+ if (inter_pred_params->use_hbd_buf) {
aom_highbd_blend_a64_d16_mask(dst, dst_stride, src0, src0_stride, src1,
- src1_stride, mask, block_size_wide[sb_type],
- w, h, subw, subh, conv_params, xd->bd);
+ src1_stride, mask, mask_stride, w, h, ssx,
+ ssy, &inter_pred_params->conv_params,
+ inter_pred_params->bit_depth);
} else {
aom_lowbd_blend_a64_d16_mask(dst, dst_stride, src0, src0_stride, src1,
- src1_stride, mask, block_size_wide[sb_type], w,
- h, subw, subh, conv_params);
+ src1_stride, mask, mask_stride, w, h, ssx, ssy,
+ &inter_pred_params->conv_params);
}
#else
- (void)xd;
aom_lowbd_blend_a64_d16_mask(dst, dst_stride, src0, src0_stride, src1,
- src1_stride, mask, block_size_wide[sb_type], w,
- h, subw, subh, conv_params);
+ src1_stride, mask, mask_stride, w, h, ssx, ssy,
+ &inter_pred_params->conv_params);
#endif
}
void av1_make_masked_inter_predictor(const uint8_t *pre, int pre_stride,
uint8_t *dst, int dst_stride,
InterPredParams *inter_pred_params,
- const SubpelParams *subpel_params, int w,
- int h, int plane, MACROBLOCKD *xd) {
- MB_MODE_INFO *mi = xd->mi[0];
- mi->interinter_comp.seg_mask = xd->seg_mask;
- const INTERINTER_COMPOUND_DATA *comp_data = &mi->interinter_comp;
+ const SubpelParams *subpel_params) {
+ const INTERINTER_COMPOUND_DATA *comp_data = &inter_pred_params->mask_comp;
+ BLOCK_SIZE sb_type = inter_pred_params->sb_type;
// We're going to call av1_make_inter_predictor to generate a prediction into
// a temporary buffer, then will blend that temporary buffer with that from
// the other reference.
DECLARE_ALIGNED(32, uint8_t, tmp_buf[2 * MAX_SB_SQUARE]);
- uint8_t *tmp_dst = get_buf_by_bd(xd, tmp_buf);
+ uint8_t *tmp_dst =
+ inter_pred_params->use_hbd_buf ? CONVERT_TO_BYTEPTR(tmp_buf) : tmp_buf;
const int tmp_buf_stride = MAX_SB_SIZE;
CONV_BUF_TYPE *org_dst = inter_pred_params->conv_params.dst;
@@ -604,15 +611,18 @@
av1_make_inter_predictor(pre, pre_stride, tmp_dst, MAX_SB_SIZE,
inter_pred_params, subpel_params);
- if (!plane && comp_data->type == COMPOUND_DIFFWTD) {
+ if (!inter_pred_params->conv_params.plane &&
+ comp_data->type == COMPOUND_DIFFWTD) {
av1_build_compound_diffwtd_mask_d16(
comp_data->seg_mask, comp_data->mask_type, org_dst, org_dst_stride,
- tmp_buf16, tmp_buf_stride, h, w, &inter_pred_params->conv_params,
- xd->bd);
+ tmp_buf16, tmp_buf_stride, inter_pred_params->block_height,
+ inter_pred_params->block_width, &inter_pred_params->conv_params,
+ inter_pred_params->bit_depth);
}
build_masked_compound_no_round(
dst, dst_stride, org_dst, org_dst_stride, tmp_buf16, tmp_buf_stride,
- comp_data, mi->sb_type, h, w, &inter_pred_params->conv_params, xd);
+ comp_data, sb_type, inter_pred_params->block_height,
+ inter_pred_params->block_width, inter_pred_params);
}
void av1_dist_wtd_comp_weight_assign(const AV1_COMMON *cm,
diff --git a/av1/common/reconinter.h b/av1/common/reconinter.h
index 5e33c72..5e6488a 100644
--- a/av1/common/reconinter.h
+++ b/av1/common/reconinter.h
@@ -112,6 +112,8 @@
const struct scale_factors *scale_factors;
int bit_depth;
int use_hbd_buf;
+ INTERINTER_COMPOUND_DATA mask_comp;
+ BLOCK_SIZE sb_type;
int is_intrabc;
} InterPredParams;
@@ -128,6 +130,9 @@
const WarpTypesAllowed *warp_types, int ref,
const MACROBLOCKD *xd, const MB_MODE_INFO *mi);
+void av1_init_mask_comp(InterPredParams *inter_pred_params, BLOCK_SIZE bsize,
+ const INTERINTER_COMPOUND_DATA *mask_comp);
+
static INLINE int has_scale(int xs, int ys) {
return xs != SCALE_SUBPEL_SHIFTS || ys != SCALE_SUBPEL_SHIFTS;
}
@@ -242,8 +247,7 @@
void av1_make_masked_inter_predictor(const uint8_t *pre, int pre_stride,
uint8_t *dst, int dst_stride,
InterPredParams *inter_pred_params,
- const SubpelParams *subpel_params, int w,
- int h, int plane, MACROBLOCKD *xd);
+ const SubpelParams *subpel_params);
// TODO(jkoleszar): yet another mv clamping function :-(
static INLINE MV clamp_mv_to_umv_border_sb(const MACROBLOCKD *xd,
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index b3d19c5..cbea8f3 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -820,10 +820,6 @@
warp_types.global_warp_allowed = is_global[ref];
warp_types.local_warp_allowed = mi->motion_mode == WARPED_CAUSAL;
inter_pred_params.conv_params.do_average = ref;
- if (is_masked_compound_type(mi->interinter_comp.type)) {
- // masked compound type has its own average mechanism
- inter_pred_params.conv_params.do_average = 0;
- }
av1_init_inter_params(&inter_pred_params, bw, bh,
mi_y >> pd->subsampling_y,
@@ -835,14 +831,21 @@
av1_init_warp_params(&inter_pred_params, &pd->pre[ref], &warp_types,
ref, xd, mi);
- if (ref && is_masked_compound_type(mi->interinter_comp.type))
+ av1_init_mask_comp(&inter_pred_params, mi->sb_type, &mi->interinter_comp);
+ inter_pred_params.mask_comp.seg_mask = xd->seg_mask;
+
+ if (ref && is_masked_compound_type(mi->interinter_comp.type)) {
+ // masked compound type has its own average mechanism
+ inter_pred_params.conv_params.do_average = 0;
+
av1_make_masked_inter_predictor(pre[ref], src_stride[ref], dst,
dst_buf->stride, &inter_pred_params,
- &subpel_params[ref], bw, bh, plane, xd);
- else
+ &subpel_params[ref]);
+ } else {
av1_make_inter_predictor(pre[ref], src_stride[ref], dst,
dst_buf->stride, &inter_pred_params,
&subpel_params[ref]);
+ }
}
}
}
diff --git a/av1/encoder/reconinter_enc.c b/av1/encoder/reconinter_enc.c
index cc21df0..b424829 100644
--- a/av1/encoder/reconinter_enc.c
+++ b/av1/encoder/reconinter_enc.c
@@ -177,6 +177,7 @@
InterPredParams inter_pred_params;
inter_pred_params.conv_params = get_conv_params_no_round(
0, plane, xd->tmp_conv_dst, MAX_SB_SIZE, is_compound, xd->bd);
+
av1_dist_wtd_comp_weight_assign(
cm, mi, 0, &inter_pred_params.conv_params.fwd_offset,
&inter_pred_params.conv_params.bck_offset,
@@ -205,17 +206,22 @@
pd->subsampling_y, xd->bd, is_cur_buf_hbd(xd),
mi->use_intrabc, sf, pre_buf, mi->interp_filters);
+ inter_pred_params.conv_params.do_average = ref;
+
av1_init_warp_params(&inter_pred_params, &pd->pre[ref], &warp_types, ref,
xd, mi);
+ av1_init_mask_comp(&inter_pred_params, mi->sb_type, &mi->interinter_comp);
+ // Assigne physical buffer
+ inter_pred_params.mask_comp.seg_mask = xd->seg_mask;
+
if (ref && is_masked_compound_type(mi->interinter_comp.type)) {
// masked compound type has its own average mechanism
inter_pred_params.conv_params.do_average = 0;
av1_make_masked_inter_predictor(pre, pre_buf->stride, dst,
dst_buf->stride, &inter_pred_params,
- &subpel_params, bw, bh, plane, xd);
+ &subpel_params);
} else {
- inter_pred_params.conv_params.do_average = ref;
av1_make_inter_predictor(pre, pre_buf->stride, dst, dst_buf->stride,
&inter_pred_params, &subpel_params);
}