ext-inter: Better motion search for compound modes

Use a variant of joint_motion_search to improve the motion search
for compound blocks (both average and masked types) that have only
one NEWMV component.

Change-Id: I7cae812dced24acde638ae49869e6986557ce7dd
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index af0a9ce..c1a3ede 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -7144,127 +7144,307 @@
 }
 
 #if CONFIG_EXT_INTER
-#if CONFIG_COMPOUND_SEGMENT || CONFIG_WEDGE
-static void do_masked_motion_search(const AV1_COMP *const cpi, MACROBLOCK *x,
-                                    const uint8_t *mask, int mask_stride,
-                                    BLOCK_SIZE bsize, int mi_row, int mi_col,
-                                    int_mv *tmp_mv, int *rate_mv, int ref_idx) {
+static void build_second_inter_pred(const AV1_COMP *cpi, MACROBLOCK *x,
+                                    BLOCK_SIZE bsize, int_mv *frame_mv,
+                                    int mi_row, int mi_col, const int block,
+                                    int ref_idx, uint8_t *second_pred) {
+  const AV1_COMMON *const cm = &cpi->common;
+  const int pw = block_size_wide[bsize];
+  const int ph = block_size_high[bsize];
   MACROBLOCKD *xd = &x->e_mbd;
-  const AV1_COMMON *cm = &cpi->common;
   MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
-  struct buf_2d backup_yv12[MAX_MB_PLANE] = { { 0, 0, 0, 0, 0 } };
-  int bestsme = INT_MAX;
-  int step_param;
-  int sadpb = x->sadperbit16;
-  MV mvp_full;
-  int ref = mbmi->ref_frame[ref_idx];
-  MV ref_mv = x->mbmi_ext->ref_mvs[ref][0].as_mv;
+  const int other_ref = mbmi->ref_frame[!ref_idx];
+#if CONFIG_DUAL_FILTER
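+  // interp_filter[0..1] are the filters for ref frame 0 and [2..3] are the
+  // filters for ref frame 1; pick the pair belonging to the other (fixed)
+  // reference.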
+  InterpFilter interp_filter[2] = {
+    (ref_idx == 0) ? mbmi->interp_filter[2] : mbmi->interp_filter[0],
+    (ref_idx == 0) ? mbmi->interp_filter[3] : mbmi->interp_filter[1]
+  };
+#else
+  const InterpFilter interp_filter = mbmi->interp_filter;
+#endif  // CONFIG_DUAL_FILTER
+  struct scale_factors sf;
+#if CONFIG_GLOBAL_MOTION || CONFIG_WARPED_MOTION
+  struct macroblockd_plane *const pd = &xd->plane[0];
+  // ic and ir are the 4x4 coordinates of the sub-8x8 block at index "block"
+  const int ic = block & 1;
+  const int ir = (block - ic) >> 1;
+  const int p_col = ((mi_col * MI_SIZE) >> pd->subsampling_x) + 4 * ic;
+  const int p_row = ((mi_row * MI_SIZE) >> pd->subsampling_y) + 4 * ir;
+#if CONFIG_GLOBAL_MOTION
+  WarpedMotionParams *const wm = &xd->global_motion[other_ref];
+  int is_global = is_global_mv_block(xd->mi[0], block, wm->wmtype);
+#endif  // CONFIG_GLOBAL_MOTION
+#else
+  (void)block;
+#endif  // CONFIG_GLOBAL_MOTION || CONFIG_WARPED_MOTION
 
-  MvLimits tmp_mv_limits = x->mv_limits;
+  // This function should only ever be called for compound modes
+  assert(has_second_ref(mbmi));
 
-  const YV12_BUFFER_CONFIG *scaled_ref_frame =
-      av1_get_scaled_ref_frame(cpi, ref);
-  int i;
-
-  MV pred_mv[3];
-  pred_mv[0] = x->mbmi_ext->ref_mvs[ref][0].as_mv;
-  pred_mv[1] = x->mbmi_ext->ref_mvs[ref][1].as_mv;
-  pred_mv[2] = x->pred_mv[ref];
-
-  av1_set_mvcost(x, ref, ref_idx, mbmi->ref_mv_idx);
+  struct buf_2d backup_yv12[MAX_MB_PLANE];
+  const YV12_BUFFER_CONFIG *const scaled_ref_frame =
+      av1_get_scaled_ref_frame(cpi, other_ref);
 
   if (scaled_ref_frame) {
+    int i;
+    // Swap out the reference frame for a version that's been scaled to
+    // match the resolution of the current frame, allowing the existing
+    // motion search code to be used without additional modifications.
+    for (i = 0; i < MAX_MB_PLANE; i++)
+      backup_yv12[i] = xd->plane[i].pre[!ref_idx];
+    av1_setup_pre_planes(xd, !ref_idx, scaled_ref_frame, mi_row, mi_col, NULL);
+  }
+
+// Since we have scaled the reference frames to match the size of the current
+// frame, we must use a unit scaling factor during mode selection.
+#if CONFIG_HIGHBITDEPTH
+  av1_setup_scale_factors_for_frame(&sf, cm->width, cm->height, cm->width,
+                                    cm->height, cm->use_highbitdepth);
+#else
+  av1_setup_scale_factors_for_frame(&sf, cm->width, cm->height, cm->width,
+                                    cm->height);
+#endif  // CONFIG_HIGHBITDEPTH
+
+  struct buf_2d ref_yv12;
+
+  const int plane = 0;
+  ConvolveParams conv_params = get_conv_params(0, plane);
+#if CONFIG_GLOBAL_MOTION || CONFIG_WARPED_MOTION
+  WarpTypesAllowed warp_types;
+#if CONFIG_GLOBAL_MOTION
+  warp_types.global_warp_allowed = is_global;
+#endif  // CONFIG_GLOBAL_MOTION
+#if CONFIG_WARPED_MOTION
+  warp_types.local_warp_allowed = mbmi->motion_mode == WARPED_CAUSAL;
+#endif  // CONFIG_WARPED_MOTION
+#endif  // CONFIG_GLOBAL_MOTION || CONFIG_WARPED_MOTION
+
+  // Initialized here because of a compiler problem in Visual Studio.
+  ref_yv12 = xd->plane[plane].pre[!ref_idx];
+
+// Get the prediction block from the 'other' reference frame.
+#if CONFIG_HIGHBITDEPTH
+  if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+    av1_highbd_build_inter_predictor(
+        ref_yv12.buf, ref_yv12.stride, second_pred, pw,
+        &frame_mv[other_ref].as_mv, &sf, pw, ph, 0, interp_filter,
+#if CONFIG_GLOBAL_MOTION || CONFIG_WARPED_MOTION
+        &warp_types, p_col, p_row,
+#endif  // CONFIG_GLOBAL_MOTION || CONFIG_WARPED_MOTION
+        plane, MV_PRECISION_Q3, mi_col * MI_SIZE, mi_row * MI_SIZE, xd);
+  } else {
+#endif  // CONFIG_HIGHBITDEPTH
+    av1_build_inter_predictor(
+        ref_yv12.buf, ref_yv12.stride, second_pred, pw,
+        &frame_mv[other_ref].as_mv, &sf, pw, ph, &conv_params, interp_filter,
+#if CONFIG_GLOBAL_MOTION || CONFIG_WARPED_MOTION
+        &warp_types, p_col, p_row, plane, !ref_idx,
+#endif  // CONFIG_GLOBAL_MOTION || CONFIG_WARPED_MOTION
+        MV_PRECISION_Q3, mi_col * MI_SIZE, mi_row * MI_SIZE, xd);
+#if CONFIG_HIGHBITDEPTH
+  }
+#endif  // CONFIG_HIGHBITDEPTH
+
+  if (scaled_ref_frame) {
+    // Restore the prediction frame pointers to their unscaled versions.
+    int i;
+    for (i = 0; i < MAX_MB_PLANE; i++)
+      xd->plane[i].pre[!ref_idx] = backup_yv12[i];
+  }
+}
+
+// Search for the best MV for one component of a compound prediction,
+// given that the MV of the other component is fixed.
+static void compound_single_motion_search(
+    const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bsize, int_mv *frame_mv,
+    int mi_row, int mi_col, int_mv *ref_mv_sub8x8[2],
+    const uint8_t *second_pred, const uint8_t *mask, int mask_stride,
+    int *rate_mv, const int block, int ref_idx) {
+  const int pw = block_size_wide[bsize];
+  const int ph = block_size_high[bsize];
+  MACROBLOCKD *xd = &x->e_mbd;
+  MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
+  const int ref = mbmi->ref_frame[ref_idx];
+  int_mv ref_mv;
+  struct macroblockd_plane *const pd = &xd->plane[0];
+
+  struct buf_2d backup_yv12[MAX_MB_PLANE];
+  int last_besterr = INT_MAX;
+  const YV12_BUFFER_CONFIG *const scaled_ref_frame =
+      av1_get_scaled_ref_frame(cpi, ref);
+
+#if CONFIG_CB4X4
+  (void)ref_mv_sub8x8;
+#endif  // CONFIG_CB4X4
+
+  // Check that this is either an interinter or an interintra block
+  assert(has_second_ref(mbmi) ||
+         (ref_idx == 0 && mbmi->ref_frame[1] == INTRA_FRAME));
+
+#if !CONFIG_CB4X4
+  if (bsize < BLOCK_8X8 && ref_mv_sub8x8 != NULL)
+    ref_mv.as_int = ref_mv_sub8x8[ref_idx]->as_int;
+  else
+#endif  // !CONFIG_CB4X4
+    ref_mv = x->mbmi_ext->ref_mvs[ref][0];
+
+  if (scaled_ref_frame) {
+    int i;
     // Swap out the reference frame for a version that's been scaled to
     // match the resolution of the current frame, allowing the existing
     // motion search code to be used without additional modifications.
     for (i = 0; i < MAX_MB_PLANE; i++)
       backup_yv12[i] = xd->plane[i].pre[ref_idx];
-
     av1_setup_pre_planes(xd, ref_idx, scaled_ref_frame, mi_row, mi_col, NULL);
   }
 
-  av1_set_mv_search_range(&x->mv_limits, &ref_mv);
+  struct buf_2d orig_yv12;
+  int bestsme = INT_MAX;
+  int sadpb = x->sadperbit16;
+  MV *const best_mv = &x->best_mv.as_mv;
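+  // The refining search below only explores a small window of search_range
+  // full pels around the MV predictor.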
+  int search_range = 3;
 
-  // Work out the size of the first step in the mv step search.
-  // 0 here is maximum length first step. 1 is MAX >> 1 etc.
-  if (cpi->sf.mv.auto_mv_step_size && cm->show_frame) {
-    // Take wtd average of the step_params based on the last frame's
-    // max mv magnitude and that based on the best ref mvs of the current
-    // block for the given reference.
-    step_param =
-        (av1_init_search_range(x->max_mv_context[ref]) + cpi->mv_step_param) /
-        2;
-  } else {
-    step_param = cpi->mv_step_param;
+  MvLimits tmp_mv_limits = x->mv_limits;
+  const int plane = 0;
+
+  // Initialized here because of a compiler problem in Visual Studio.
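+  // The motion search code always reads its reference from pre[0], so when
+  // searching the second reference, temporarily point pre[0] at its buffer.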
+  if (ref_idx) {
+    orig_yv12 = xd->plane[plane].pre[0];
+    xd->plane[plane].pre[0] = xd->plane[plane].pre[ref_idx];
   }
 
-  // TODO(debargha): is show_frame needed here?
-  if (cpi->sf.adaptive_motion_search && bsize < cm->sb_size && cm->show_frame) {
-    int boffset =
-        2 * (b_width_log2_lookup[cm->sb_size] -
-             AOMMIN(b_height_log2_lookup[bsize], b_width_log2_lookup[bsize]));
-    step_param = AOMMAX(step_param, boffset);
+  // Do compound motion search on the current reference frame.
+  av1_set_mv_search_range(&x->mv_limits, &ref_mv.as_mv);
+
+  // Use the MV result from the single-reference mode as the MV predictor.
+  *best_mv = frame_mv[ref].as_mv;
+
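+  // Convert the predictor from 1/8-pel units to full-pel units for the
+  // full-pixel search below.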
+  best_mv->col >>= 3;
+  best_mv->row >>= 3;
+
+  av1_set_mvcost(x, ref, ref_idx, mbmi->ref_mv_idx);
+
+  // Small-range full-pixel motion search.
+  bestsme = av1_refining_search_8p_c(x, sadpb, search_range,
+                                     &cpi->fn_ptr[bsize], mask, mask_stride,
+                                     ref_idx, &ref_mv.as_mv, second_pred);
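+  // Re-evaluate the best full-pel MV with the matching compound variance
+  // function: masked if a mask is given, simple average otherwise.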
+  if (bestsme < INT_MAX) {
+    if (mask)
+      bestsme =
+          av1_get_mvpred_mask_var(x, best_mv, &ref_mv.as_mv, second_pred, mask,
+                                  mask_stride, ref_idx, &cpi->fn_ptr[bsize], 1);
+    else
+      bestsme = av1_get_mvpred_av_var(x, best_mv, &ref_mv.as_mv, second_pred,
+                                      &cpi->fn_ptr[bsize], 1);
   }
 
-  if (cpi->sf.adaptive_motion_search) {
-    int bwl = b_width_log2_lookup[bsize];
-    int bhl = b_height_log2_lookup[bsize];
-    int tlevel = x->pred_mv_sad[ref] >> (bwl + bhl + 4);
-
-    if (tlevel < 5) step_param += 2;
-
-    // prev_mv_sad is not setup for dynamically scaled frames.
-    if (cpi->oxcf.resize_mode != RESIZE_DYNAMIC) {
-      for (i = LAST_FRAME; i <= ALTREF_FRAME && cm->show_frame; ++i) {
-        if ((x->pred_mv_sad[ref] >> 3) > x->pred_mv_sad[i]) {
-          x->pred_mv[ref].row = 0;
-          x->pred_mv[ref].col = 0;
-          tmp_mv->as_int = INVALID_MV;
-
-          if (scaled_ref_frame) {
-            int j;
-            for (j = 0; j < MAX_MB_PLANE; ++j)
-              xd->plane[j].pre[ref_idx] = backup_yv12[j];
-          }
-          return;
-        }
-      }
-    }
-  }
-
-  mvp_full = pred_mv[x->mv_best_ref_index[ref]];
-
-  mvp_full.col >>= 3;
-  mvp_full.row >>= 3;
-
-  bestsme = av1_masked_full_pixel_diamond(
-      cpi, x, mask, mask_stride, &mvp_full, step_param, sadpb,
-      MAX_MVSEARCH_STEPS - 1 - step_param, 1, &cpi->fn_ptr[bsize], &ref_mv,
-      &tmp_mv->as_mv, ref_idx);
-
   x->mv_limits = tmp_mv_limits;
 
   if (bestsme < INT_MAX) {
     int dis; /* TODO: use dis in distortion calculation later. */
-    av1_find_best_masked_sub_pixel_tree_up(
-        cpi, x, mask, mask_stride, mi_row, mi_col, &tmp_mv->as_mv, &ref_mv,
-        cm->allow_high_precision_mv, x->errorperbit, &cpi->fn_ptr[bsize],
-        cpi->sf.mv.subpel_force_stop, cpi->sf.mv.subpel_iters_per_step,
-        x->nmvjointcost, x->mvcost, &dis, &x->pred_sse[ref], ref_idx,
-        cpi->sf.use_upsampled_references);
-  }
-  *rate_mv = av1_mv_bit_cost(&tmp_mv->as_mv, &ref_mv, x->nmvjointcost,
-                             x->mvcost, MV_COST_WEIGHT);
+    unsigned int sse;
+    if (cpi->sf.use_upsampled_references) {
+      // Use up-sampled reference frames.
+      struct buf_2d backup_pred = pd->pre[0];
+      const YV12_BUFFER_CONFIG *upsampled_ref = get_upsampled_ref(cpi, ref);
 
-  if (cpi->sf.adaptive_motion_search && cm->show_frame)
-    x->pred_mv[ref] = tmp_mv->as_mv;
+      // Set pred for Y plane
+      setup_pred_plane(&pd->pre[0], bsize, upsampled_ref->y_buffer,
+                       upsampled_ref->y_crop_width,
+                       upsampled_ref->y_crop_height, upsampled_ref->y_stride,
+                       (mi_row << 3), (mi_col << 3), NULL, pd->subsampling_x,
+                       pd->subsampling_y);
+
+// If bsize < BLOCK_8X8, adjust pred pointer for this block
+#if !CONFIG_CB4X4
+      if (bsize < BLOCK_8X8)
+        pd->pre[0].buf =
+            &pd->pre[0].buf[(av1_raster_block_offset(BLOCK_8X8, block,
+                                                     pd->pre[0].stride))
+                            << 3];
+#endif  // !CONFIG_CB4X4
+
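+      // The final argument (1) selects the up-sampled reference path in the
+      // sub-pixel search.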
+      bestsme = cpi->find_fractional_mv_step(
+          x, &ref_mv.as_mv, cpi->common.allow_high_precision_mv, x->errorperbit,
+          &cpi->fn_ptr[bsize], 0, cpi->sf.mv.subpel_iters_per_step, NULL,
+          x->nmvjointcost, x->mvcost, &dis, &sse, second_pred, mask,
+          mask_stride, ref_idx, pw, ph, 1);
+
+      // Restore the reference frames.
+      pd->pre[0] = backup_pred;
+    } else {
+      (void)block;
+      bestsme = cpi->find_fractional_mv_step(
+          x, &ref_mv.as_mv, cpi->common.allow_high_precision_mv, x->errorperbit,
+          &cpi->fn_ptr[bsize], 0, cpi->sf.mv.subpel_iters_per_step, NULL,
+          x->nmvjointcost, x->mvcost, &dis, &sse, second_pred, mask,
+          mask_stride, ref_idx, pw, ph, 0);
+    }
+  }
+
+  // Restore the pointer to the first (possibly scaled) prediction buffer.
+  if (ref_idx) xd->plane[plane].pre[0] = orig_yv12;
+
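+  // Commit the refined MV if the search produced a valid result.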
+  if (bestsme < last_besterr) {
+    frame_mv[ref].as_mv = *best_mv;
+    last_besterr = bestsme;
+  }
+
+  *rate_mv = 0;
 
   if (scaled_ref_frame) {
+    // Restore the prediction frame pointers to their unscaled versions.
+    int i;
     for (i = 0; i < MAX_MB_PLANE; i++)
       xd->plane[i].pre[ref_idx] = backup_yv12[i];
   }
+
+  av1_set_mvcost(x, ref, ref_idx, mbmi->ref_mv_idx);
+#if !CONFIG_CB4X4
+  if (bsize >= BLOCK_8X8)
+#endif  // !CONFIG_CB4X4
+    *rate_mv += av1_mv_bit_cost(&frame_mv[ref].as_mv,
+                                &x->mbmi_ext->ref_mvs[ref][0].as_mv,
+                                x->nmvjointcost, x->mvcost, MV_COST_WEIGHT);
+#if !CONFIG_CB4X4
+  else
+    *rate_mv +=
+        av1_mv_bit_cost(&frame_mv[ref].as_mv, &ref_mv_sub8x8[ref_idx]->as_mv,
+                        x->nmvjointcost, x->mvcost, MV_COST_WEIGHT);
+#endif  // !CONFIG_CB4X4
 }
 
+// Wrapper for compound_single_motion_search, for the common case where the
+// second prediction is also an inter prediction (rather than an intra
+// prediction, as in interintra blocks).
+static void compound_single_motion_search_interinter(
+    const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bsize, int_mv *frame_mv,
+    int mi_row, int mi_col, int_mv *ref_mv_sub8x8[2], const uint8_t *mask,
+    int mask_stride, int *rate_mv, const int block, int ref_idx) {
+  // This function should only ever be called for compound modes
+  assert(has_second_ref(&x->e_mbd.mi[0]->mbmi));
+
+// Prediction buffer for the second (fixed) reference frame.
+#if CONFIG_HIGHBITDEPTH
+  MACROBLOCKD *xd = &x->e_mbd;
+  DECLARE_ALIGNED(16, uint16_t, second_pred_alloc_16[MAX_SB_SQUARE]);
+  uint8_t *second_pred;
+  if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
+    second_pred = CONVERT_TO_BYTEPTR(second_pred_alloc_16);
+  else
+    second_pred = (uint8_t *)second_pred_alloc_16;
+#else
+  DECLARE_ALIGNED(16, uint8_t, second_pred[MAX_SB_SQUARE]);
+#endif  // CONFIG_HIGHBITDEPTH
+
+  build_second_inter_pred(cpi, x, bsize, frame_mv, mi_row, mi_col, block,
+                          ref_idx, second_pred);
+
+  compound_single_motion_search(cpi, x, bsize, frame_mv, mi_row, mi_col,
+                                ref_mv_sub8x8, second_pred, mask, mask_stride,
+                                rate_mv, block, ref_idx);
+}
+
+#if CONFIG_COMPOUND_SEGMENT || CONFIG_WEDGE
 static void do_masked_motion_search_indexed(
     const AV1_COMP *const cpi, MACROBLOCK *x, const int_mv *const cur_mv,
     const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE bsize,
@@ -7278,33 +7458,26 @@
 
   mask = av1_get_compound_type_mask(comp_data, sb_type);
 
-  if (which == 2) {
-    int_mv frame_mv[TOTAL_REFS_PER_FRAME];
-    MV_REFERENCE_FRAME rf[2] = { mbmi->ref_frame[0], mbmi->ref_frame[1] };
-    assert(bsize >= BLOCK_8X8 || CONFIG_CB4X4);
+  int_mv frame_mv[TOTAL_REFS_PER_FRAME];
+  MV_REFERENCE_FRAME rf[2] = { mbmi->ref_frame[0], mbmi->ref_frame[1] };
+  assert(bsize >= BLOCK_8X8 || CONFIG_CB4X4);
 
-    frame_mv[rf[0]].as_int = cur_mv[0].as_int;
-    frame_mv[rf[1]].as_int = cur_mv[1].as_int;
+  frame_mv[rf[0]].as_int = cur_mv[0].as_int;
+  frame_mv[rf[1]].as_int = cur_mv[1].as_int;
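+  // which selects the search: 0 or 1 refines that MV component with the
+  // other held fixed; 2 refines both components jointly.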
+  if (which == 2) {
     joint_motion_search(cpi, x, bsize, frame_mv, mi_row, mi_col, NULL, mask,
                         mask_stride, rate_mv, 0);
-    tmp_mv[0].as_int = frame_mv[rf[0]].as_int;
-    tmp_mv[1].as_int = frame_mv[rf[1]].as_int;
   } else if (which == 0) {
-    do_masked_motion_search(cpi, x, mask, mask_stride, bsize, mi_row, mi_col,
-                            &tmp_mv[0], rate_mv, 0);
+    compound_single_motion_search_interinter(cpi, x, bsize, frame_mv, mi_row,
+                                             mi_col, NULL, mask, mask_stride,
+                                             rate_mv, 0, 0);
   } else if (which == 1) {
-// get the negative mask
-#if CONFIG_COMPOUND_SEGMENT
-    uint8_t inv_mask_buf[2 * MAX_SB_SQUARE];
-    const int h = block_size_high[bsize];
-    mask = av1_get_compound_type_mask_inverse(
-        comp_data, inv_mask_buf, h, mask_stride, mask_stride, sb_type);
-#else
-    mask = av1_get_compound_type_mask_inverse(comp_data, sb_type);
-#endif  // CONFIG_COMPOUND_SEGMENT
-    do_masked_motion_search(cpi, x, mask, mask_stride, bsize, mi_row, mi_col,
-                            &tmp_mv[1], rate_mv, 1);
+    compound_single_motion_search_interinter(cpi, x, bsize, frame_mv, mi_row,
+                                             mi_col, NULL, mask, mask_stride,
+                                             rate_mv, 0, 1);
   }
+  tmp_mv[0].as_int = frame_mv[rf[0]].as_int;
+  tmp_mv[1].as_int = frame_mv[rf[1]].as_int;
 }
 #endif  // CONFIG_COMPOUND_SEGMENT || CONFIG_WEDGE
 #endif  // CONFIG_EXT_INTER
@@ -7946,17 +8119,33 @@
       }
     } else if (this_mode == NEAREST_NEWMV || this_mode == NEAR_NEWMV) {
       frame_mv[refs[1]].as_int = single_newmv[refs[1]].as_int;
-      av1_set_mvcost(x, refs[1], 1, mbmi->ref_mv_idx);
-      *rate_mv = av1_mv_bit_cost(&frame_mv[refs[1]].as_mv,
-                                 &mbmi_ext->ref_mvs[refs[1]][0].as_mv,
-                                 x->nmvjointcost, x->mvcost, MV_COST_WEIGHT);
+      if (cpi->sf.comp_inter_joint_search_thresh <= bsize) {
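+        // Hold ref 0 at its NEAREST/NEAR MV and refine only the NEWMV
+        // component (ref 1).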
+        frame_mv[refs[0]].as_int =
+            mode_mv[compound_ref0_mode(this_mode)][refs[0]].as_int;
+        compound_single_motion_search_interinter(cpi, x, bsize, frame_mv,
+                                                 mi_row, mi_col, NULL, NULL, 0,
+                                                 rate_mv, 0, 1);
+      } else {
+        av1_set_mvcost(x, refs[1], 1, mbmi->ref_mv_idx);
+        *rate_mv = av1_mv_bit_cost(&frame_mv[refs[1]].as_mv,
+                                   &mbmi_ext->ref_mvs[refs[1]][0].as_mv,
+                                   x->nmvjointcost, x->mvcost, MV_COST_WEIGHT);
+      }
     } else {
       assert(this_mode == NEW_NEARESTMV || this_mode == NEW_NEARMV);
       frame_mv[refs[0]].as_int = single_newmv[refs[0]].as_int;
-      av1_set_mvcost(x, refs[0], 0, mbmi->ref_mv_idx);
-      *rate_mv = av1_mv_bit_cost(&frame_mv[refs[0]].as_mv,
-                                 &mbmi_ext->ref_mvs[refs[0]][0].as_mv,
-                                 x->nmvjointcost, x->mvcost, MV_COST_WEIGHT);
+      if (cpi->sf.comp_inter_joint_search_thresh <= bsize) {
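+        // Hold ref 1 at its NEAREST/NEAR MV and refine only the NEWMV
+        // component (ref 0).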
+        frame_mv[refs[1]].as_int =
+            mode_mv[compound_ref1_mode(this_mode)][refs[1]].as_int;
+        compound_single_motion_search_interinter(cpi, x, bsize, frame_mv,
+                                                 mi_row, mi_col, NULL, NULL, 0,
+                                                 rate_mv, 0, 0);
+      } else {
+        av1_set_mvcost(x, refs[0], 0, mbmi->ref_mv_idx);
+        *rate_mv = av1_mv_bit_cost(&frame_mv[refs[0]].as_mv,
+                                   &mbmi_ext->ref_mvs[refs[0]][0].as_mv,
+                                   x->nmvjointcost, x->mvcost, MV_COST_WEIGHT);
+      }
     }
 #else
     // Initialize mv using single prediction mode result.
@@ -8968,8 +9157,12 @@
           // get negative of mask
           const uint8_t *mask = av1_get_contiguous_soft_mask(
               mbmi->interintra_wedge_index, 1, bsize);
-          do_masked_motion_search(cpi, x, mask, bw, bsize, mi_row, mi_col,
-                                  &tmp_mv, &tmp_rate_mv, 0);
+          int_mv frame_mv2[TOTAL_REFS_PER_FRAME];
+          frame_mv2[refs[0]].as_int = x->mbmi_ext->ref_mvs[refs[0]][0].as_int;
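+          // Refine the inter MV against the fixed intra predictor, using the
+          // inverted wedge mask as the compound mask.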
+          compound_single_motion_search(cpi, x, bsize, frame_mv2, mi_row,
+                                        mi_col, NULL, intrapred, mask, bw,
+                                        &tmp_rate_mv, 0, 0);
+          tmp_mv.as_int = frame_mv2[refs[0]].as_int;
           mbmi->mv[0].as_int = tmp_mv.as_int;
           av1_build_inter_predictors_sby(cm, xd, mi_row, mi_col, &orig_dst,
                                          bsize);