Rework conv_params setup
Set the single prediction index in a compound mode. Decide if to
do regular average or use mask multiplications depending on the
request mode next.
Change-Id: I8d52d2aef26941c051b2c8c65c059b341469eb0a
diff --git a/av1/common/convolve.h b/av1/common/convolve.h
index 6056647..04df86c 100644
--- a/av1/common/convolve.h
+++ b/av1/common/convolve.h
@@ -26,6 +26,7 @@
int round_1;
int plane;
int is_compound;
+ int compound_index; // 0: the first single in compound mode, 1: the second.
int use_dist_wtd_comp_avg;
int fwd_offset;
int bck_offset;
@@ -61,13 +62,14 @@
ConvolveParams *conv_params,
const struct scale_factors *sf);
-static INLINE ConvolveParams get_conv_params_no_round(int do_average, int plane,
+static INLINE ConvolveParams get_conv_params_no_round(int cmp_index, int plane,
CONV_BUF_TYPE *dst,
int dst_stride,
int is_compound, int bd) {
ConvolveParams conv_params;
- conv_params.do_average = do_average;
- assert(IMPLIES(do_average, is_compound));
+ conv_params.compound_index = cmp_index;
+ assert(IMPLIES(cmp_index, is_compound));
+
conv_params.is_compound = is_compound;
conv_params.round_0 = ROUND0_BITS;
conv_params.round_1 = is_compound ? COMPOUND_ROUND1_BITS
@@ -83,6 +85,10 @@
conv_params.dst = dst;
conv_params.dst_stride = dst_stride;
conv_params.plane = plane;
+
+ // By default, set do average to 1 if this is the second single prediction
+ // in a compound mode.
+ conv_params.do_average = cmp_index;
return conv_params;
}
diff --git a/av1/common/reconinter.c b/av1/common/reconinter.c
index 52b4173..83a0150 100644
--- a/av1/common/reconinter.c
+++ b/av1/common/reconinter.c
@@ -116,6 +116,8 @@
inter_pred_params->sb_type = bsize;
inter_pred_params->mask_comp = *mask_comp;
inter_pred_params->comp_mode = MASK_COMP;
+
+ inter_pred_params->conv_params.do_average = 0;
}
void av1_make_inter_predictor(const uint8_t *src, int src_stride, uint8_t *dst,
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index 284cbd0..f801eab 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -695,6 +695,7 @@
int row = row_start;
int src_stride;
+ ref = 0;
for (int y = 0; y < b8_h; y += b4_h) {
int col = col_start;
for (int x = 0; x < b8_w; x += b4_w) {
@@ -704,12 +705,10 @@
InterPredParams inter_pred_params;
assert(bw < 8 || bh < 8);
inter_pred_params.conv_params = get_conv_params_no_round(
- 0, plane, xd->tmp_conv_dst, tmp_dst_stride, is_compound, xd->bd);
+ ref, plane, xd->tmp_conv_dst, tmp_dst_stride, is_compound, xd->bd);
inter_pred_params.conv_params.use_dist_wtd_comp_avg = 0;
struct buf_2d *const dst_buf = &pd->dst;
uint8_t *dst = dst_buf->buf + dst_buf->stride * y + x;
-
- ref = 0;
const RefCntBuffer *ref_buf =
get_ref_frame_buf(cm, this_mbmi->ref_frame[ref]);
const struct scale_factors *ref_scale_factors =
@@ -805,13 +804,6 @@
&src_stride[ref]);
}
- inter_pred_params.conv_params = get_conv_params_no_round(
- 0, plane, xd->tmp_conv_dst, MAX_SB_SIZE, is_compound, xd->bd);
- av1_dist_wtd_comp_weight_assign(
- cm, mi, 0, &inter_pred_params.conv_params.fwd_offset,
- &inter_pred_params.conv_params.bck_offset,
- &inter_pred_params.conv_params.use_dist_wtd_comp_avg, is_compound);
-
for (ref = 0; ref < 1 + is_compound; ++ref) {
struct buf_2d *const pre_buf = is_intrabc ? dst_buf : &pd->pre[ref];
const struct scale_factors *const sf =
@@ -827,6 +819,15 @@
pd->subsampling_y, xd->bd, is_cur_buf_hbd(xd),
mi->use_intrabc, sf, pre_buf, mi->interp_filters);
if (is_compound) av1_init_comp_mode(&inter_pred_params);
+
+ inter_pred_params.conv_params = get_conv_params_no_round(
+ ref, plane, xd->tmp_conv_dst, MAX_SB_SIZE, is_compound, xd->bd);
+
+ av1_dist_wtd_comp_weight_assign(
+ cm, mi, 0, &inter_pred_params.conv_params.fwd_offset,
+ &inter_pred_params.conv_params.bck_offset,
+ &inter_pred_params.conv_params.use_dist_wtd_comp_avg, is_compound);
+
if (!build_for_obmc)
av1_init_warp_params(&inter_pred_params, &pd->pre[ref], &warp_types,
ref, xd, mi);
diff --git a/av1/encoder/reconinter_enc.c b/av1/encoder/reconinter_enc.c
index 6c9ec55..ec60123 100644
--- a/av1/encoder/reconinter_enc.c
+++ b/av1/encoder/reconinter_enc.c
@@ -158,7 +158,7 @@
pre_buf, this_mbmi->interp_filters);
inter_pred_params.conv_params = get_conv_params_no_round(
- 0, plane, xd->tmp_conv_dst, tmp_dst_stride, 0, xd->bd);
+ ref, plane, xd->tmp_conv_dst, tmp_dst_stride, 0, xd->bd);
inter_pred_params.conv_params.use_dist_wtd_comp_avg = 0;
av1_build_inter_predictor(pre_buf->buf, pre_buf->stride, dst,
@@ -201,15 +201,13 @@
if (is_compound) av1_init_comp_mode(&inter_pred_params);
inter_pred_params.conv_params = get_conv_params_no_round(
- 0, plane, xd->tmp_conv_dst, MAX_SB_SIZE, is_compound, xd->bd);
+ ref, plane, xd->tmp_conv_dst, MAX_SB_SIZE, is_compound, xd->bd);
av1_dist_wtd_comp_weight_assign(
cm, mi, 0, &inter_pred_params.conv_params.fwd_offset,
&inter_pred_params.conv_params.bck_offset,
&inter_pred_params.conv_params.use_dist_wtd_comp_avg, is_compound);
- inter_pred_params.conv_params.do_average = ref;
-
av1_init_warp_params(&inter_pred_params, &pre_buf, &warp_types, ref, xd,
mi);