Refactor masked compound prediction

Unify data access with regular inter prediction.

Change-Id: Idffd43b28b1fb00c5951567745ad51268f539f7d
diff --git a/av1/common/reconinter.c b/av1/common/reconinter.c
index 17e9aa2..22a0429 100644
--- a/av1/common/reconinter.c
+++ b/av1/common/reconinter.c
@@ -106,6 +106,15 @@
   inter_pred_params->ref_frame_buf = *ref_buf;
 }
 
+void av1_init_mask_comp(InterPredParams *inter_pred_params, BLOCK_SIZE bsize,
+                        const INTERINTER_COMPOUND_DATA *mask_comp) {
+  inter_pred_params->sb_type = bsize;
+  if (inter_pred_params->conv_params.do_average &&
+      is_masked_compound_type(mask_comp->type)) {
+    inter_pred_params->mask_comp = *mask_comp;
+  }
+}
+
 void av1_make_inter_predictor(const uint8_t *src, int src_stride, uint8_t *dst,
                               int dst_stride,
                               InterPredParams *inter_pred_params,
@@ -553,44 +562,42 @@
     uint8_t *dst, int dst_stride, const CONV_BUF_TYPE *src0, int src0_stride,
     const CONV_BUF_TYPE *src1, int src1_stride,
     const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE sb_type, int h,
-    int w, ConvolveParams *conv_params, MACROBLOCKD *xd) {
-  // Derive subsampling from h and w passed in. May be refactored to
-  // pass in subsampling factors directly.
-  const int subh = (2 << mi_size_high_log2[sb_type]) == h;
-  const int subw = (2 << mi_size_wide_log2[sb_type]) == w;
+    int w, InterPredParams *inter_pred_params) {
+  const int ssy = inter_pred_params->subsampling_y;
+  const int ssx = inter_pred_params->subsampling_x;
   const uint8_t *mask = av1_get_compound_type_mask(comp_data, sb_type);
+  const int mask_stride = block_size_wide[sb_type];
 #if CONFIG_AV1_HIGHBITDEPTH
-  if (is_cur_buf_hbd(xd)) {
+  if (inter_pred_params->use_hbd_buf) {
     aom_highbd_blend_a64_d16_mask(dst, dst_stride, src0, src0_stride, src1,
-                                  src1_stride, mask, block_size_wide[sb_type],
-                                  w, h, subw, subh, conv_params, xd->bd);
+                                  src1_stride, mask, mask_stride, w, h, ssx,
+                                  ssy, &inter_pred_params->conv_params,
+                                  inter_pred_params->bit_depth);
   } else {
     aom_lowbd_blend_a64_d16_mask(dst, dst_stride, src0, src0_stride, src1,
-                                 src1_stride, mask, block_size_wide[sb_type], w,
-                                 h, subw, subh, conv_params);
+                                 src1_stride, mask, mask_stride, w, h, ssx, ssy,
+                                 &inter_pred_params->conv_params);
   }
 #else
-  (void)xd;
   aom_lowbd_blend_a64_d16_mask(dst, dst_stride, src0, src0_stride, src1,
-                               src1_stride, mask, block_size_wide[sb_type], w,
-                               h, subw, subh, conv_params);
+                               src1_stride, mask, mask_stride, w, h, ssx, ssy,
+                               &inter_pred_params->conv_params);
 #endif
 }
 
 void av1_make_masked_inter_predictor(const uint8_t *pre, int pre_stride,
                                      uint8_t *dst, int dst_stride,
                                      InterPredParams *inter_pred_params,
-                                     const SubpelParams *subpel_params, int w,
-                                     int h, int plane, MACROBLOCKD *xd) {
-  MB_MODE_INFO *mi = xd->mi[0];
-  mi->interinter_comp.seg_mask = xd->seg_mask;
-  const INTERINTER_COMPOUND_DATA *comp_data = &mi->interinter_comp;
+                                     const SubpelParams *subpel_params) {
+  const INTERINTER_COMPOUND_DATA *comp_data = &inter_pred_params->mask_comp;
+  BLOCK_SIZE sb_type = inter_pred_params->sb_type;
 
   // We're going to call av1_make_inter_predictor to generate a prediction into
   // a temporary buffer, then will blend that temporary buffer with that from
   // the other reference.
   DECLARE_ALIGNED(32, uint8_t, tmp_buf[2 * MAX_SB_SQUARE]);
-  uint8_t *tmp_dst = get_buf_by_bd(xd, tmp_buf);
+  uint8_t *tmp_dst =
+      inter_pred_params->use_hbd_buf ? CONVERT_TO_BYTEPTR(tmp_buf) : tmp_buf;
 
   const int tmp_buf_stride = MAX_SB_SIZE;
   CONV_BUF_TYPE *org_dst = inter_pred_params->conv_params.dst;
@@ -604,15 +611,18 @@
   av1_make_inter_predictor(pre, pre_stride, tmp_dst, MAX_SB_SIZE,
                            inter_pred_params, subpel_params);
 
-  if (!plane && comp_data->type == COMPOUND_DIFFWTD) {
+  if (!inter_pred_params->conv_params.plane &&
+      comp_data->type == COMPOUND_DIFFWTD) {
     av1_build_compound_diffwtd_mask_d16(
         comp_data->seg_mask, comp_data->mask_type, org_dst, org_dst_stride,
-        tmp_buf16, tmp_buf_stride, h, w, &inter_pred_params->conv_params,
-        xd->bd);
+        tmp_buf16, tmp_buf_stride, inter_pred_params->block_height,
+        inter_pred_params->block_width, &inter_pred_params->conv_params,
+        inter_pred_params->bit_depth);
   }
   build_masked_compound_no_round(
       dst, dst_stride, org_dst, org_dst_stride, tmp_buf16, tmp_buf_stride,
-      comp_data, mi->sb_type, h, w, &inter_pred_params->conv_params, xd);
+      comp_data, sb_type, inter_pred_params->block_height,
+      inter_pred_params->block_width, inter_pred_params);
 }
 
 void av1_dist_wtd_comp_weight_assign(const AV1_COMMON *cm,
diff --git a/av1/common/reconinter.h b/av1/common/reconinter.h
index 5e33c72..5e6488a 100644
--- a/av1/common/reconinter.h
+++ b/av1/common/reconinter.h
@@ -112,6 +112,8 @@
   const struct scale_factors *scale_factors;
   int bit_depth;
   int use_hbd_buf;
+  INTERINTER_COMPOUND_DATA mask_comp;
+  BLOCK_SIZE sb_type;
   int is_intrabc;
 } InterPredParams;
 
@@ -128,6 +130,9 @@
                           const WarpTypesAllowed *warp_types, int ref,
                           const MACROBLOCKD *xd, const MB_MODE_INFO *mi);
 
+void av1_init_mask_comp(InterPredParams *inter_pred_params, BLOCK_SIZE bsize,
+                        const INTERINTER_COMPOUND_DATA *mask_comp);
+
 static INLINE int has_scale(int xs, int ys) {
   return xs != SCALE_SUBPEL_SHIFTS || ys != SCALE_SUBPEL_SHIFTS;
 }
@@ -242,8 +247,7 @@
 void av1_make_masked_inter_predictor(const uint8_t *pre, int pre_stride,
                                      uint8_t *dst, int dst_stride,
                                      InterPredParams *inter_pred_params,
-                                     const SubpelParams *subpel_params, int w,
-                                     int h, int plane, MACROBLOCKD *xd);
+                                     const SubpelParams *subpel_params);
 
 // TODO(jkoleszar): yet another mv clamping function :-(
 static INLINE MV clamp_mv_to_umv_border_sb(const MACROBLOCKD *xd,
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index b3d19c5..cbea8f3 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -820,10 +820,6 @@
       warp_types.global_warp_allowed = is_global[ref];
       warp_types.local_warp_allowed = mi->motion_mode == WARPED_CAUSAL;
       inter_pred_params.conv_params.do_average = ref;
-      if (is_masked_compound_type(mi->interinter_comp.type)) {
-        // masked compound type has its own average mechanism
-        inter_pred_params.conv_params.do_average = 0;
-      }
 
       av1_init_inter_params(&inter_pred_params, bw, bh,
                             mi_y >> pd->subsampling_y,
@@ -835,14 +831,21 @@
         av1_init_warp_params(&inter_pred_params, &pd->pre[ref], &warp_types,
                              ref, xd, mi);
 
-      if (ref && is_masked_compound_type(mi->interinter_comp.type))
+      av1_init_mask_comp(&inter_pred_params, mi->sb_type, &mi->interinter_comp);
+      inter_pred_params.mask_comp.seg_mask = xd->seg_mask;
+
+      if (ref && is_masked_compound_type(mi->interinter_comp.type)) {
+        // masked compound type has its own average mechanism
+        inter_pred_params.conv_params.do_average = 0;
+
         av1_make_masked_inter_predictor(pre[ref], src_stride[ref], dst,
                                         dst_buf->stride, &inter_pred_params,
-                                        &subpel_params[ref], bw, bh, plane, xd);
-      else
+                                        &subpel_params[ref]);
+      } else {
         av1_make_inter_predictor(pre[ref], src_stride[ref], dst,
                                  dst_buf->stride, &inter_pred_params,
                                  &subpel_params[ref]);
+      }
     }
   }
 }
diff --git a/av1/encoder/reconinter_enc.c b/av1/encoder/reconinter_enc.c
index cc21df0..b424829 100644
--- a/av1/encoder/reconinter_enc.c
+++ b/av1/encoder/reconinter_enc.c
@@ -177,6 +177,7 @@
     InterPredParams inter_pred_params;
     inter_pred_params.conv_params = get_conv_params_no_round(
         0, plane, xd->tmp_conv_dst, MAX_SB_SIZE, is_compound, xd->bd);
+
     av1_dist_wtd_comp_weight_assign(
         cm, mi, 0, &inter_pred_params.conv_params.fwd_offset,
         &inter_pred_params.conv_params.bck_offset,
@@ -205,17 +206,22 @@
                             pd->subsampling_y, xd->bd, is_cur_buf_hbd(xd),
                             mi->use_intrabc, sf, pre_buf, mi->interp_filters);
 
+      inter_pred_params.conv_params.do_average = ref;
+
       av1_init_warp_params(&inter_pred_params, &pd->pre[ref], &warp_types, ref,
                            xd, mi);
 
+      av1_init_mask_comp(&inter_pred_params, mi->sb_type, &mi->interinter_comp);
+      // Assigne physical buffer
+      inter_pred_params.mask_comp.seg_mask = xd->seg_mask;
+
       if (ref && is_masked_compound_type(mi->interinter_comp.type)) {
         // masked compound type has its own average mechanism
         inter_pred_params.conv_params.do_average = 0;
         av1_make_masked_inter_predictor(pre, pre_buf->stride, dst,
                                         dst_buf->stride, &inter_pred_params,
-                                        &subpel_params, bw, bh, plane, xd);
+                                        &subpel_params);
       } else {
-        inter_pred_params.conv_params.do_average = ref;
         av1_make_inter_predictor(pre, pre_buf->stride, dst, dst_buf->stride,
                                  &inter_pred_params, &subpel_params);
       }