Rework conv_params setup

Set the single prediction index in a compound mode. Decide if to
do regular average or use mask multiplications depending on the
request mode next.

Change-Id: I8d52d2aef26941c051b2c8c65c059b341469eb0a
diff --git a/av1/common/convolve.h b/av1/common/convolve.h
index 6056647..04df86c 100644
--- a/av1/common/convolve.h
+++ b/av1/common/convolve.h
@@ -26,6 +26,7 @@
   int round_1;
   int plane;
   int is_compound;
+  int compound_index;  // 0: the first single in compound mode, 1: the second.
   int use_dist_wtd_comp_avg;
   int fwd_offset;
   int bck_offset;
@@ -61,13 +62,14 @@
                             ConvolveParams *conv_params,
                             const struct scale_factors *sf);
 
-static INLINE ConvolveParams get_conv_params_no_round(int do_average, int plane,
+static INLINE ConvolveParams get_conv_params_no_round(int cmp_index, int plane,
                                                       CONV_BUF_TYPE *dst,
                                                       int dst_stride,
                                                       int is_compound, int bd) {
   ConvolveParams conv_params;
-  conv_params.do_average = do_average;
-  assert(IMPLIES(do_average, is_compound));
+  conv_params.compound_index = cmp_index;
+  assert(IMPLIES(cmp_index, is_compound));
+
   conv_params.is_compound = is_compound;
   conv_params.round_0 = ROUND0_BITS;
   conv_params.round_1 = is_compound ? COMPOUND_ROUND1_BITS
@@ -83,6 +85,10 @@
   conv_params.dst = dst;
   conv_params.dst_stride = dst_stride;
   conv_params.plane = plane;
+
+  // By default, set do average to 1 if this is the second single prediction
+  // in a compound mode.
+  conv_params.do_average = cmp_index;
   return conv_params;
 }
 
diff --git a/av1/common/reconinter.c b/av1/common/reconinter.c
index 52b4173..83a0150 100644
--- a/av1/common/reconinter.c
+++ b/av1/common/reconinter.c
@@ -116,6 +116,8 @@
   inter_pred_params->sb_type = bsize;
   inter_pred_params->mask_comp = *mask_comp;
   inter_pred_params->comp_mode = MASK_COMP;
+
+  inter_pred_params->conv_params.do_average = 0;
 }
 
 void av1_make_inter_predictor(const uint8_t *src, int src_stride, uint8_t *dst,
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index 284cbd0..f801eab 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -695,6 +695,7 @@
 
     int row = row_start;
     int src_stride;
+    ref = 0;
     for (int y = 0; y < b8_h; y += b4_h) {
       int col = col_start;
       for (int x = 0; x < b8_w; x += b4_w) {
@@ -704,12 +705,10 @@
         InterPredParams inter_pred_params;
         assert(bw < 8 || bh < 8);
         inter_pred_params.conv_params = get_conv_params_no_round(
-            0, plane, xd->tmp_conv_dst, tmp_dst_stride, is_compound, xd->bd);
+            ref, plane, xd->tmp_conv_dst, tmp_dst_stride, is_compound, xd->bd);
         inter_pred_params.conv_params.use_dist_wtd_comp_avg = 0;
         struct buf_2d *const dst_buf = &pd->dst;
         uint8_t *dst = dst_buf->buf + dst_buf->stride * y + x;
-
-        ref = 0;
         const RefCntBuffer *ref_buf =
             get_ref_frame_buf(cm, this_mbmi->ref_frame[ref]);
         const struct scale_factors *ref_scale_factors =
@@ -805,13 +804,6 @@
                        &src_stride[ref]);
     }
 
-    inter_pred_params.conv_params = get_conv_params_no_round(
-        0, plane, xd->tmp_conv_dst, MAX_SB_SIZE, is_compound, xd->bd);
-    av1_dist_wtd_comp_weight_assign(
-        cm, mi, 0, &inter_pred_params.conv_params.fwd_offset,
-        &inter_pred_params.conv_params.bck_offset,
-        &inter_pred_params.conv_params.use_dist_wtd_comp_avg, is_compound);
-
     for (ref = 0; ref < 1 + is_compound; ++ref) {
       struct buf_2d *const pre_buf = is_intrabc ? dst_buf : &pd->pre[ref];
       const struct scale_factors *const sf =
@@ -827,6 +819,15 @@
                             pd->subsampling_y, xd->bd, is_cur_buf_hbd(xd),
                             mi->use_intrabc, sf, pre_buf, mi->interp_filters);
       if (is_compound) av1_init_comp_mode(&inter_pred_params);
+
+      inter_pred_params.conv_params = get_conv_params_no_round(
+          ref, plane, xd->tmp_conv_dst, MAX_SB_SIZE, is_compound, xd->bd);
+
+      av1_dist_wtd_comp_weight_assign(
+          cm, mi, 0, &inter_pred_params.conv_params.fwd_offset,
+          &inter_pred_params.conv_params.bck_offset,
+          &inter_pred_params.conv_params.use_dist_wtd_comp_avg, is_compound);
+
       if (!build_for_obmc)
         av1_init_warp_params(&inter_pred_params, &pd->pre[ref], &warp_types,
                              ref, xd, mi);
diff --git a/av1/encoder/reconinter_enc.c b/av1/encoder/reconinter_enc.c
index 6c9ec55..ec60123 100644
--- a/av1/encoder/reconinter_enc.c
+++ b/av1/encoder/reconinter_enc.c
@@ -158,7 +158,7 @@
                               pre_buf, this_mbmi->interp_filters);
 
         inter_pred_params.conv_params = get_conv_params_no_round(
-            0, plane, xd->tmp_conv_dst, tmp_dst_stride, 0, xd->bd);
+            ref, plane, xd->tmp_conv_dst, tmp_dst_stride, 0, xd->bd);
         inter_pred_params.conv_params.use_dist_wtd_comp_avg = 0;
 
         av1_build_inter_predictor(pre_buf->buf, pre_buf->stride, dst,
@@ -201,15 +201,13 @@
       if (is_compound) av1_init_comp_mode(&inter_pred_params);
 
       inter_pred_params.conv_params = get_conv_params_no_round(
-          0, plane, xd->tmp_conv_dst, MAX_SB_SIZE, is_compound, xd->bd);
+          ref, plane, xd->tmp_conv_dst, MAX_SB_SIZE, is_compound, xd->bd);
 
       av1_dist_wtd_comp_weight_assign(
           cm, mi, 0, &inter_pred_params.conv_params.fwd_offset,
           &inter_pred_params.conv_params.bck_offset,
           &inter_pred_params.conv_params.use_dist_wtd_comp_avg, is_compound);
 
-      inter_pred_params.conv_params.do_average = ref;
-
       av1_init_warp_params(&inter_pred_params, &pre_buf, &warp_types, ref, xd,
                            mi);