Refactor jnt_comp into handle_inter_mode

Move the jnt_comp search loops so that they run only inside the
handle_inter_mode function, instead of wrapping its call sites.
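
The search over compound_idx is now self-contained in
handle_inter_mode; roughly, as a sketch of the new control flow
(not the literal code):

  for (comp_idx = 0; comp_idx < 1 + is_comp_pred; ++comp_idx) {
    restore the backed-up rd_stats and mbmi
    mbmi->compound_idx = comp_idx
    run motion search, compound type search and motion_mode_rd
    track the rd stats of the best compound_idx so far
  }
  reinstate the best choice before returning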

Change-Id: I9508dea64da7e6dfe31fad09aac8679cc855f55f
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index fea876e..84d8807 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -7757,6 +7757,10 @@
   mbmi->use_wedge_interintra = 0;
   int compmode_interinter_cost = 0;
   mbmi->interinter_compound_type = COMPOUND_AVERAGE;
+#if CONFIG_JNT_COMP
+  mbmi->comp_group_idx = 0;
+  mbmi->compound_idx = 1;
+#endif
 
   if (!cm->allow_interintra_compound && is_comp_interintra_pred)
     return INT64_MAX;
@@ -7790,224 +7794,318 @@
   }
 
   mbmi->motion_mode = SIMPLE_TRANSLATION;
-  if (have_newmv_in_inter_mode(this_mode)) {
-    const int64_t ret_val = handle_newmv(cpi, x, bsize, mode_mv, mi_row, mi_col,
-                                         &rate_mv, single_newmv, args);
+  const int masked_compound_used =
+      is_any_masked_compound_used(bsize) && cm->allow_masked_compound;
+  int64_t ret_val = INT64_MAX;
+#if CONFIG_JNT_COMP
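+  // Back up the incoming rd stats and mode info so that each compound_idx
+  // iteration below starts from the same state.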
+  const RD_STATS backup_rd_stats = *rd_stats;
+  const RD_STATS backup_rd_stats_y = *rd_stats_y;
+  const RD_STATS backup_rd_stats_uv = *rd_stats_uv;
+  RD_STATS best_rd_stats, best_rd_stats_y, best_rd_stats_uv;
+  int64_t best_rd = INT64_MAX;
+  int best_compound_idx = 1;
+  int64_t best_ret_val = INT64_MAX;
+  uint8_t best_blk_skip[MAX_MB_PLANE][MAX_MIB_SIZE * MAX_MIB_SIZE * 4];
+  const MB_MODE_INFO backup_mbmi = *mbmi;
+  MB_MODE_INFO best_mbmi = *mbmi;
+  int64_t early_terminate = 0;
+
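+  // For compound prediction, evaluate both compound_idx values: 0 selects
+  // the distance-weighted (jnt) compound, 1 the plain average. Single
+  // reference modes run the loop once.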
+  int comp_idx;
+  for (comp_idx = 0; comp_idx < 1 + is_comp_pred; ++comp_idx) {
+    early_terminate = 0;
+    *rd_stats = backup_rd_stats;
+    *rd_stats_y = backup_rd_stats_y;
+    *rd_stats_uv = backup_rd_stats_uv;
+    *mbmi = backup_mbmi;
+    mbmi->compound_idx = comp_idx;
+
+    if (is_comp_pred && comp_idx == 0) {
+      mbmi->comp_group_idx = 0;
+      mbmi->compound_idx = 0;
+
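+      // Charge the rate of signaling comp_group_idx = 0 and compound_idx = 0.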
+      const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
+      const int comp_index_ctx = get_comp_index_context(cm, xd);
+      if (masked_compound_used)
+        rd_stats->rate += x->comp_group_idx_cost[comp_group_idx_ctx][0];
+      rd_stats->rate += x->comp_idx_cost[comp_index_ctx][0];
+    }
+#endif  // CONFIG_JNT_COMP
+
+    if (have_newmv_in_inter_mode(this_mode)) {
+      ret_val = handle_newmv(cpi, x, bsize, mode_mv, mi_row, mi_col, &rate_mv,
+                             single_newmv, args);
+#if CONFIG_JNT_COMP
+      if (ret_val != 0) {
+        early_terminate = INT64_MAX;
+        continue;
+      } else {
+        rd_stats->rate += rate_mv;
+      }
+#else
     if (ret_val != 0)
       return ret_val;
     else
       rd_stats->rate += rate_mv;
-  }
-  for (i = 0; i < is_comp_pred + 1; ++i) {
-    cur_mv[i] = frame_mv[refs[i]];
-    // Clip "next_nearest" so that it does not extend to far out of image
-    if (this_mode != NEWMV) clamp_mv2(&cur_mv[i].as_mv, xd);
+#endif  // CONFIG_JNT_COMP
+    }
+    for (i = 0; i < is_comp_pred + 1; ++i) {
+      cur_mv[i] = frame_mv[refs[i]];
+      // Clip "next_nearest" so that it does not extend too far out of image
+      if (this_mode != NEWMV) clamp_mv2(&cur_mv[i].as_mv, xd);
+#if CONFIG_JNT_COMP
+      if (mv_check_bounds(&x->mv_limits, &cur_mv[i].as_mv)) {
+        early_terminate = INT64_MAX;
+        continue;
+      }
+#else
     if (mv_check_bounds(&x->mv_limits, &cur_mv[i].as_mv)) return INT64_MAX;
-    mbmi->mv[i].as_int = cur_mv[i].as_int;
-  }
+#endif  // CONFIG_JNT_COMP
+      mbmi->mv[i].as_int = cur_mv[i].as_int;
+    }
+
+    if (this_mode == NEAREST_NEARESTMV) {
+      if (mbmi_ext->ref_mv_count[ref_frame_type] > 0) {
+        cur_mv[0] = mbmi_ext->ref_mv_stack[ref_frame_type][0].this_mv;
+        cur_mv[1] = mbmi_ext->ref_mv_stack[ref_frame_type][0].comp_mv;
+
+        for (i = 0; i < 2; ++i) {
+          clamp_mv2(&cur_mv[i].as_mv, xd);
+#if CONFIG_JNT_COMP
+          if (mv_check_bounds(&x->mv_limits, &cur_mv[i].as_mv)) {
+            early_terminate = INT64_MAX;
+            continue;
+          }
+#else
+        if (mv_check_bounds(&x->mv_limits, &cur_mv[i].as_mv)) return INT64_MAX;
+#endif  // CONFIG_JNT_COMP
+          mbmi->mv[i].as_int = cur_mv[i].as_int;
+        }
+      }
+    }
+
+    if (mbmi_ext->ref_mv_count[ref_frame_type] > 0) {
+      if (this_mode == NEAREST_NEWMV) {
+        cur_mv[0] = mbmi_ext->ref_mv_stack[ref_frame_type][0].this_mv;
+
+#if CONFIG_AMVR
+        lower_mv_precision(&cur_mv[0].as_mv, cm->allow_high_precision_mv,
+                           cm->cur_frame_force_integer_mv);
+#else
+      lower_mv_precision(&cur_mv[0].as_mv, cm->allow_high_precision_mv);
+#endif
+        clamp_mv2(&cur_mv[0].as_mv, xd);
+#if CONFIG_JNT_COMP
+        if (mv_check_bounds(&x->mv_limits, &cur_mv[0].as_mv)) {
+          early_terminate = INT64_MAX;
+          continue;
+        }
+#else
+      if (mv_check_bounds(&x->mv_limits, &cur_mv[0].as_mv)) return INT64_MAX;
+#endif  // CONFIG_JNT_COMP
+        mbmi->mv[0].as_int = cur_mv[0].as_int;
+      }
+
+      if (this_mode == NEW_NEARESTMV) {
+        cur_mv[1] = mbmi_ext->ref_mv_stack[ref_frame_type][0].comp_mv;
+
+#if CONFIG_AMVR
+        lower_mv_precision(&cur_mv[1].as_mv, cm->allow_high_precision_mv,
+                           cm->cur_frame_force_integer_mv);
+#else
+      lower_mv_precision(&cur_mv[1].as_mv, cm->allow_high_precision_mv);
+#endif
+        clamp_mv2(&cur_mv[1].as_mv, xd);
+#if CONFIG_JNT_COMP
+        if (mv_check_bounds(&x->mv_limits, &cur_mv[1].as_mv)) {
+          early_terminate = INT64_MAX;
+          continue;
+        }
+#else
+      if (mv_check_bounds(&x->mv_limits, &cur_mv[1].as_mv)) return INT64_MAX;
+#endif  // CONFIG_JNT_COMP
+        mbmi->mv[1].as_int = cur_mv[1].as_int;
+      }
+    }
+
+    if (mbmi_ext->ref_mv_count[ref_frame_type] > 1) {
+      int ref_mv_idx = mbmi->ref_mv_idx + 1;
+      if (this_mode == NEAR_NEWMV || this_mode == NEAR_NEARMV) {
+        cur_mv[0] = mbmi_ext->ref_mv_stack[ref_frame_type][ref_mv_idx].this_mv;
+
+#if CONFIG_AMVR
+        lower_mv_precision(&cur_mv[0].as_mv, cm->allow_high_precision_mv,
+                           cm->cur_frame_force_integer_mv);
+#else
+      lower_mv_precision(&cur_mv[0].as_mv, cm->allow_high_precision_mv);
+#endif
+        clamp_mv2(&cur_mv[0].as_mv, xd);
+#if CONFIG_JNT_COMP
+        if (mv_check_bounds(&x->mv_limits, &cur_mv[0].as_mv)) {
+          early_terminate = INT64_MAX;
+          continue;
+        }
+#else
+      if (mv_check_bounds(&x->mv_limits, &cur_mv[0].as_mv)) return INT64_MAX;
+#endif  // CONFIG_JNT_COMP
+        mbmi->mv[0].as_int = cur_mv[0].as_int;
+      }
+
+      if (this_mode == NEW_NEARMV || this_mode == NEAR_NEARMV) {
+        cur_mv[1] = mbmi_ext->ref_mv_stack[ref_frame_type][ref_mv_idx].comp_mv;
+
+#if CONFIG_AMVR
+        lower_mv_precision(&cur_mv[1].as_mv, cm->allow_high_precision_mv,
+                           cm->cur_frame_force_integer_mv);
+#else
+      lower_mv_precision(&cur_mv[1].as_mv, cm->allow_high_precision_mv);
+#endif
+        clamp_mv2(&cur_mv[1].as_mv, xd);
+#if CONFIG_JNT_COMP
+        if (mv_check_bounds(&x->mv_limits, &cur_mv[1].as_mv)) {
+          early_terminate = INT64_MAX;
+          continue;
+        }
+#else
+      if (mv_check_bounds(&x->mv_limits, &cur_mv[1].as_mv)) return INT64_MAX;
+#endif  // CONFIG_JNT_COMP
+        mbmi->mv[1].as_int = cur_mv[1].as_int;
+      }
+    }
+
+    // Do the first prediction into the destination buffer. Do the next
+    // prediction into a temporary buffer. Then keep track of which one
+    // of these currently holds the best predictor, and use the other
+    // one for future predictions. In the end, copy from tmp_buf to
+    // dst if necessary.
+    for (i = 0; i < MAX_MB_PLANE; i++) {
+      tmp_dst.plane[i] = tmp_buf + i * MAX_SB_SQUARE;
+      tmp_dst.stride[i] = MAX_SB_SIZE;
+    }
+    for (i = 0; i < MAX_MB_PLANE; i++) {
+      orig_dst.plane[i] = xd->plane[i].dst.buf;
+      orig_dst.stride[i] = xd->plane[i].dst.stride;
+    }
+
+    // We don't include the cost of the second reference here, because there
+    // are only three options: Last/Golden, ARF/Last or Golden/ARF, or in other
+    // words if you present them in that order, the second one is always known
+    // if the first is known.
+    //
+    // Under some circumstances we discount the cost of new mv mode to encourage
+    // initiation of a motion field.
+    if (discount_newmv_test(cpi, this_mode, frame_mv[refs[0]], mode_mv,
+                            refs[0])) {
+      rd_stats->rate +=
+          AOMMIN(cost_mv_ref(x, this_mode, mode_ctx),
+                 cost_mv_ref(x, is_comp_pred ? NEAREST_NEARESTMV : NEARESTMV,
+                             mode_ctx));
+    } else {
+      rd_stats->rate += cost_mv_ref(x, this_mode, mode_ctx);
+    }
 
 #if CONFIG_JNT_COMP
-  if (is_comp_pred) {
-    if (mbmi->compound_idx == 0) {
-      int masked_compound_used = is_any_masked_compound_used(bsize);
-      masked_compound_used = masked_compound_used && cm->allow_masked_compound;
-
-      if (masked_compound_used) {
-        const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
-        rd_stats->rate += x->comp_group_idx_cost[comp_group_idx_ctx][0];
-      }
-
-      const int comp_index_ctx = get_comp_index_context(cm, xd);
-      rd_stats->rate += x->comp_idx_cost[comp_index_ctx][mbmi->compound_idx];
+    if (RDCOST(x->rdmult, rd_stats->rate, 0) > ref_best_rd &&
+        mbmi->mode != NEARESTMV && mbmi->mode != NEAREST_NEARESTMV) {
+      early_terminate = INT64_MAX;
+      continue;
     }
-  }
-#endif  // CONFIG_JNT_COMP
-
-  if (this_mode == NEAREST_NEARESTMV) {
-    if (mbmi_ext->ref_mv_count[ref_frame_type] > 0) {
-      cur_mv[0] = mbmi_ext->ref_mv_stack[ref_frame_type][0].this_mv;
-      cur_mv[1] = mbmi_ext->ref_mv_stack[ref_frame_type][0].comp_mv;
-
-      for (i = 0; i < 2; ++i) {
-        clamp_mv2(&cur_mv[i].as_mv, xd);
-        if (mv_check_bounds(&x->mv_limits, &cur_mv[i].as_mv)) return INT64_MAX;
-        mbmi->mv[i].as_int = cur_mv[i].as_int;
-      }
-    }
-  }
-
-  if (mbmi_ext->ref_mv_count[ref_frame_type] > 0) {
-    if (this_mode == NEAREST_NEWMV) {
-      cur_mv[0] = mbmi_ext->ref_mv_stack[ref_frame_type][0].this_mv;
-
-#if CONFIG_AMVR
-      lower_mv_precision(&cur_mv[0].as_mv, cm->allow_high_precision_mv,
-                         cm->cur_frame_force_integer_mv);
 #else
-      lower_mv_precision(&cur_mv[0].as_mv, cm->allow_high_precision_mv);
-#endif
-      clamp_mv2(&cur_mv[0].as_mv, xd);
-      if (mv_check_bounds(&x->mv_limits, &cur_mv[0].as_mv)) return INT64_MAX;
-      mbmi->mv[0].as_int = cur_mv[0].as_int;
-    }
-
-    if (this_mode == NEW_NEARESTMV) {
-      cur_mv[1] = mbmi_ext->ref_mv_stack[ref_frame_type][0].comp_mv;
-
-#if CONFIG_AMVR
-      lower_mv_precision(&cur_mv[1].as_mv, cm->allow_high_precision_mv,
-                         cm->cur_frame_force_integer_mv);
-#else
-      lower_mv_precision(&cur_mv[1].as_mv, cm->allow_high_precision_mv);
-#endif
-      clamp_mv2(&cur_mv[1].as_mv, xd);
-      if (mv_check_bounds(&x->mv_limits, &cur_mv[1].as_mv)) return INT64_MAX;
-      mbmi->mv[1].as_int = cur_mv[1].as_int;
-    }
-  }
-
-  if (mbmi_ext->ref_mv_count[ref_frame_type] > 1) {
-    int ref_mv_idx = mbmi->ref_mv_idx + 1;
-    if (this_mode == NEAR_NEWMV || this_mode == NEAR_NEARMV) {
-      cur_mv[0] = mbmi_ext->ref_mv_stack[ref_frame_type][ref_mv_idx].this_mv;
-
-#if CONFIG_AMVR
-      lower_mv_precision(&cur_mv[0].as_mv, cm->allow_high_precision_mv,
-                         cm->cur_frame_force_integer_mv);
-#else
-      lower_mv_precision(&cur_mv[0].as_mv, cm->allow_high_precision_mv);
-#endif
-      clamp_mv2(&cur_mv[0].as_mv, xd);
-      if (mv_check_bounds(&x->mv_limits, &cur_mv[0].as_mv)) return INT64_MAX;
-      mbmi->mv[0].as_int = cur_mv[0].as_int;
-    }
-
-    if (this_mode == NEW_NEARMV || this_mode == NEAR_NEARMV) {
-      cur_mv[1] = mbmi_ext->ref_mv_stack[ref_frame_type][ref_mv_idx].comp_mv;
-
-#if CONFIG_AMVR
-      lower_mv_precision(&cur_mv[1].as_mv, cm->allow_high_precision_mv,
-                         cm->cur_frame_force_integer_mv);
-#else
-      lower_mv_precision(&cur_mv[1].as_mv, cm->allow_high_precision_mv);
-#endif
-      clamp_mv2(&cur_mv[1].as_mv, xd);
-      if (mv_check_bounds(&x->mv_limits, &cur_mv[1].as_mv)) return INT64_MAX;
-      mbmi->mv[1].as_int = cur_mv[1].as_int;
-    }
-  }
-
-  // do first prediction into the destination buffer. Do the next
-  // prediction into a temporary buffer. Then keep track of which one
-  // of these currently holds the best predictor, and use the other
-  // one for future predictions. In the end, copy from tmp_buf to
-  // dst if necessary.
-  for (i = 0; i < MAX_MB_PLANE; i++) {
-    tmp_dst.plane[i] = tmp_buf + i * MAX_SB_SQUARE;
-    tmp_dst.stride[i] = MAX_SB_SIZE;
-  }
-  for (i = 0; i < MAX_MB_PLANE; i++) {
-    orig_dst.plane[i] = xd->plane[i].dst.buf;
-    orig_dst.stride[i] = xd->plane[i].dst.stride;
-  }
-
-  // We don't include the cost of the second reference here, because there
-  // are only three options: Last/Golden, ARF/Last or Golden/ARF, or in other
-  // words if you present them in that order, the second one is always known
-  // if the first is known.
-  //
-  // Under some circumstances we discount the cost of new mv mode to encourage
-  // initiation of a motion field.
-  if (discount_newmv_test(cpi, this_mode, frame_mv[refs[0]], mode_mv,
-                          refs[0])) {
-    rd_stats->rate += AOMMIN(
-        cost_mv_ref(x, this_mode, mode_ctx),
-        cost_mv_ref(x, is_comp_pred ? NEAREST_NEARESTMV : NEARESTMV, mode_ctx));
-  } else {
-    rd_stats->rate += cost_mv_ref(x, this_mode, mode_ctx);
-  }
-
   if (RDCOST(x->rdmult, rd_stats->rate, 0) > ref_best_rd &&
       mbmi->mode != NEARESTMV && mbmi->mode != NEAREST_NEARESTMV)
     return INT64_MAX;
+#endif  // CONFIG_JNT_COMP
 
-  int64_t ret_val = interpolation_filter_search(
-      x, cpi, bsize, mi_row, mi_col, &tmp_dst, &orig_dst, args->single_filter,
-      &rd, &rs, &skip_txfm_sb, &skip_sse_sb);
+    ret_val = interpolation_filter_search(
+        x, cpi, bsize, mi_row, mi_col, &tmp_dst, &orig_dst, args->single_filter,
+        &rd, &rs, &skip_txfm_sb, &skip_sse_sb);
+#if CONFIG_JNT_COMP
+    if (ret_val != 0) {
+      early_terminate = INT64_MAX;
+      continue;
+    }
+#else
   if (ret_val != 0) return ret_val;
+#endif  // CONFIG_JNT_COMP
 
-  best_bmc_mbmi = *mbmi;
-  rate2_bmc_nocoeff = rd_stats->rate;
-  if (cm->interp_filter == SWITCHABLE) rate2_bmc_nocoeff += rs;
-  rate_mv_bmc = rate_mv;
+    best_bmc_mbmi = *mbmi;
+    rate2_bmc_nocoeff = rd_stats->rate;
+    if (cm->interp_filter == SWITCHABLE) rate2_bmc_nocoeff += rs;
+    rate_mv_bmc = rate_mv;
 
 #if CONFIG_JNT_COMP
-  if (is_comp_pred && mbmi->compound_idx)
+    if (is_comp_pred && comp_idx)
 #else
   if (is_comp_pred)
-#endif  // CONFIG_JNT_COMP
-  {
-    int rate_sum, rs2;
-    int64_t dist_sum;
-    int64_t best_rd_compound = INT64_MAX, best_rd_cur = INT64_MAX;
-    INTERINTER_COMPOUND_DATA best_compound_data;
-    int_mv best_mv[2];
-    int best_tmp_rate_mv = rate_mv;
-    int tmp_skip_txfm_sb;
-    int64_t tmp_skip_sse_sb;
-    DECLARE_ALIGNED(16, uint8_t, pred0[2 * MAX_SB_SQUARE]);
-    DECLARE_ALIGNED(16, uint8_t, pred1[2 * MAX_SB_SQUARE]);
-    uint8_t *preds0[1] = { pred0 };
-    uint8_t *preds1[1] = { pred1 };
-    int strides[1] = { bw };
-    int tmp_rate_mv;
-    int masked_compound_used = is_any_masked_compound_used(bsize);
-    masked_compound_used = masked_compound_used && cm->allow_masked_compound;
-    COMPOUND_TYPE cur_type;
-    int best_compmode_interinter_cost = 0;
+#endif
+    {
+      int rate_sum, rs2;
+      int64_t dist_sum;
+      int64_t best_rd_compound = INT64_MAX, best_rd_cur = INT64_MAX;
+      INTERINTER_COMPOUND_DATA best_compound_data;
+      int_mv best_mv[2];
+      int best_tmp_rate_mv = rate_mv;
+      int tmp_skip_txfm_sb;
+      int64_t tmp_skip_sse_sb;
+      DECLARE_ALIGNED(16, uint8_t, pred0[2 * MAX_SB_SQUARE]);
+      DECLARE_ALIGNED(16, uint8_t, pred1[2 * MAX_SB_SQUARE]);
+      uint8_t *preds0[1] = { pred0 };
+      uint8_t *preds1[1] = { pred1 };
+      int strides[1] = { bw };
+      int tmp_rate_mv;
+      COMPOUND_TYPE cur_type;
+      int best_compmode_interinter_cost = 0;
 
-    best_mv[0].as_int = cur_mv[0].as_int;
-    best_mv[1].as_int = cur_mv[1].as_int;
-    memset(&best_compound_data, 0, sizeof(best_compound_data));
-    uint8_t tmp_mask_buf[2 * MAX_SB_SQUARE];
-    best_compound_data.seg_mask = tmp_mask_buf;
+      best_mv[0].as_int = cur_mv[0].as_int;
+      best_mv[1].as_int = cur_mv[1].as_int;
+      memset(&best_compound_data, 0, sizeof(best_compound_data));
+      uint8_t tmp_mask_buf[2 * MAX_SB_SQUARE];
+      best_compound_data.seg_mask = tmp_mask_buf;
 
-    if (masked_compound_used) {
-      // get inter predictors to use for masked compound modes
-      av1_build_inter_predictors_for_planes_single_buf(
-          xd, bsize, 0, 0, mi_row, mi_col, 0, preds0, strides);
-      av1_build_inter_predictors_for_planes_single_buf(
-          xd, bsize, 0, 0, mi_row, mi_col, 1, preds1, strides);
-    }
-
-    for (cur_type = COMPOUND_AVERAGE; cur_type < COMPOUND_TYPES; cur_type++) {
-      if (cur_type != COMPOUND_AVERAGE && !masked_compound_used) break;
-      if (!is_interinter_compound_used(cur_type, bsize)) continue;
-      tmp_rate_mv = rate_mv;
-      best_rd_cur = INT64_MAX;
-      mbmi->interinter_compound_type = cur_type;
-#if CONFIG_JNT_COMP
-      const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
-      int masked_type_cost = 0;
       if (masked_compound_used) {
-        if (cur_type == COMPOUND_AVERAGE) {
-          masked_type_cost += x->comp_group_idx_cost[comp_group_idx_ctx][0];
-
-          const int comp_index_ctx = get_comp_index_context(cm, xd);
-          masked_type_cost += x->comp_idx_cost[comp_index_ctx][1];
-        } else {
-          masked_type_cost += x->comp_group_idx_cost[comp_group_idx_ctx][1];
-
-          masked_type_cost +=
-              x->compound_type_cost[bsize][mbmi->interinter_compound_type - 1];
-        }
-      } else {
-        const int comp_index_ctx = get_comp_index_context(cm, xd);
-        masked_type_cost += x->comp_idx_cost[comp_index_ctx][1];
+        // get inter predictors to use for masked compound modes
+        av1_build_inter_predictors_for_planes_single_buf(
+            xd, bsize, 0, 0, mi_row, mi_col, 0, preds0, strides);
+        av1_build_inter_predictors_for_planes_single_buf(
+            xd, bsize, 0, 0, mi_row, mi_col, 1, preds1, strides);
       }
-      rs2 = av1_cost_literal(get_interinter_compound_type_bits(
-                bsize, mbmi->interinter_compound_type)) +
-            masked_type_cost;
+
+      for (cur_type = COMPOUND_AVERAGE; cur_type < COMPOUND_TYPES; cur_type++) {
+        if (cur_type != COMPOUND_AVERAGE && !masked_compound_used) break;
+        if (!is_interinter_compound_used(cur_type, bsize)) continue;
+        tmp_rate_mv = rate_mv;
+        best_rd_cur = INT64_MAX;
+        mbmi->interinter_compound_type = cur_type;
+#if CONFIG_JNT_COMP
+        int masked_type_cost = 0;
+
+        const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
+        const int comp_index_ctx = get_comp_index_context(cm, xd);
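+        // COMPOUND_AVERAGE is signaled as comp_group_idx 0 / compound_idx 1;
+        // masked types (wedge, seg) use comp_group_idx 1.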
+        if (masked_compound_used) {
+          if (cur_type == COMPOUND_AVERAGE) {
+            mbmi->comp_group_idx = 0;
+            mbmi->compound_idx = 1;
+
+            masked_type_cost += x->comp_group_idx_cost[comp_group_idx_ctx][0];
+            masked_type_cost += x->comp_idx_cost[comp_index_ctx][1];
+          } else {
+            mbmi->comp_group_idx = 1;
+            mbmi->compound_idx = 1;
+
+            masked_type_cost += x->comp_group_idx_cost[comp_group_idx_ctx][1];
+            masked_type_cost +=
+                x->compound_type_cost[bsize]
+                                     [mbmi->interinter_compound_type - 1];
+          }
+        } else {
+          mbmi->comp_group_idx = 0;
+          mbmi->compound_idx = 1;
+
+          masked_type_cost += x->comp_idx_cost[comp_index_ctx][1];
+        }
+
+        rs2 = av1_cost_literal(get_interinter_compound_type_bits(
+                  bsize, mbmi->interinter_compound_type)) +
+              masked_type_cost;
 #else
       int masked_type_cost = 0;
       if (masked_compound_used) {
@@ -8022,275 +8120,340 @@
             masked_type_cost;
 #endif  // CONFIG_JNT_COMP
 
-      switch (cur_type) {
-        case COMPOUND_AVERAGE:
-          av1_build_inter_predictors_sby(cm, xd, mi_row, mi_col, &orig_dst,
-                                         bsize);
+        switch (cur_type) {
+          case COMPOUND_AVERAGE:
+            av1_build_inter_predictors_sby(cm, xd, mi_row, mi_col, &orig_dst,
+                                           bsize);
+            av1_subtract_plane(x, bsize, 0);
+            rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum,
+                                     &tmp_skip_txfm_sb, &tmp_skip_sse_sb,
+                                     INT64_MAX);
+            if (rd != INT64_MAX)
+              best_rd_cur =
+                  RDCOST(x->rdmult, rs2 + rate_mv + rate_sum, dist_sum);
+            best_rd_compound = best_rd_cur;
+            break;
+          case COMPOUND_WEDGE:
+            if (x->source_variance > cpi->sf.disable_wedge_search_var_thresh &&
+                best_rd_compound / 3 < ref_best_rd) {
+              best_rd_cur = build_and_cost_compound_type(
+                  cpi, x, cur_mv, bsize, this_mode, rs2, rate_mv, &orig_dst,
+                  &tmp_rate_mv, preds0, preds1, strides, mi_row, mi_col);
+            }
+            break;
+          case COMPOUND_SEG:
+            if (x->source_variance > cpi->sf.disable_wedge_search_var_thresh &&
+                best_rd_compound / 3 < ref_best_rd) {
+              best_rd_cur = build_and_cost_compound_type(
+                  cpi, x, cur_mv, bsize, this_mode, rs2, rate_mv, &orig_dst,
+                  &tmp_rate_mv, preds0, preds1, strides, mi_row, mi_col);
+            }
+            break;
+          default: assert(0); return 0;
+        }
+
+        if (best_rd_cur < best_rd_compound) {
+          best_rd_compound = best_rd_cur;
+          best_compound_data.wedge_index = mbmi->wedge_index;
+          best_compound_data.wedge_sign = mbmi->wedge_sign;
+          best_compound_data.mask_type = mbmi->mask_type;
+          memcpy(best_compound_data.seg_mask, xd->seg_mask,
+                 2 * MAX_SB_SQUARE * sizeof(uint8_t));
+          best_compound_data.interinter_compound_type =
+              mbmi->interinter_compound_type;
+          best_compmode_interinter_cost = rs2;
+          if (have_newmv_in_inter_mode(this_mode)) {
+            if (use_masked_motion_search(cur_type)) {
+              best_tmp_rate_mv = tmp_rate_mv;
+              best_mv[0].as_int = mbmi->mv[0].as_int;
+              best_mv[1].as_int = mbmi->mv[1].as_int;
+            } else {
+              best_mv[0].as_int = cur_mv[0].as_int;
+              best_mv[1].as_int = cur_mv[1].as_int;
+            }
+          }
+        }
+        // reset to original mvs for next iteration
+        mbmi->mv[0].as_int = cur_mv[0].as_int;
+        mbmi->mv[1].as_int = cur_mv[1].as_int;
+      }
+      mbmi->wedge_index = best_compound_data.wedge_index;
+      mbmi->wedge_sign = best_compound_data.wedge_sign;
+      mbmi->mask_type = best_compound_data.mask_type;
+      memcpy(xd->seg_mask, best_compound_data.seg_mask,
+             2 * MAX_SB_SQUARE * sizeof(uint8_t));
+      mbmi->interinter_compound_type =
+          best_compound_data.interinter_compound_type;
+      if (have_newmv_in_inter_mode(this_mode)) {
+        mbmi->mv[0].as_int = best_mv[0].as_int;
+        mbmi->mv[1].as_int = best_mv[1].as_int;
+        if (use_masked_motion_search(mbmi->interinter_compound_type)) {
+          rd_stats->rate += best_tmp_rate_mv - rate_mv;
+          rate_mv = best_tmp_rate_mv;
+        }
+      }
+
+      if (ref_best_rd < INT64_MAX && best_rd_compound / 3 > ref_best_rd) {
+        restore_dst_buf(xd, orig_dst);
+#if CONFIG_JNT_COMP
+        early_terminate = INT64_MAX;
+        continue;
+#else
+      return INT64_MAX;
+#endif  // CONFIG_JNT_COMP
+      }
+
+      pred_exists = 0;
+
+      compmode_interinter_cost = best_compmode_interinter_cost;
+    }
+
+    if (is_comp_interintra_pred) {
+      INTERINTRA_MODE best_interintra_mode = II_DC_PRED;
+      int64_t best_interintra_rd = INT64_MAX;
+      int rmode, rate_sum;
+      int64_t dist_sum;
+      int j;
+      int tmp_rate_mv = 0;
+      int tmp_skip_txfm_sb;
+      int64_t tmp_skip_sse_sb;
+      DECLARE_ALIGNED(16, uint8_t, intrapred_[2 * MAX_SB_SQUARE]);
+      uint8_t *intrapred;
+
+      if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
+        intrapred = CONVERT_TO_BYTEPTR(intrapred_);
+      else
+        intrapred = intrapred_;
+
+      mbmi->ref_frame[1] = NONE_FRAME;
+      for (j = 0; j < MAX_MB_PLANE; j++) {
+        xd->plane[j].dst.buf = tmp_buf + j * MAX_SB_SQUARE;
+        xd->plane[j].dst.stride = bw;
+      }
+      av1_build_inter_predictors_sby(cm, xd, mi_row, mi_col, &orig_dst, bsize);
+      restore_dst_buf(xd, orig_dst);
+      mbmi->ref_frame[1] = INTRA_FRAME;
+      mbmi->use_wedge_interintra = 0;
+
+      for (j = 0; j < INTERINTRA_MODES; ++j) {
+        mbmi->interintra_mode = (INTERINTRA_MODE)j;
+        rmode = interintra_mode_cost[mbmi->interintra_mode];
+        av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, &orig_dst,
+                                                  intrapred, bw);
+        av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
+        model_rd_for_sb(cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum,
+                        &tmp_skip_txfm_sb, &tmp_skip_sse_sb);
+        rd = RDCOST(x->rdmult, tmp_rate_mv + rate_sum + rmode, dist_sum);
+        if (rd < best_interintra_rd) {
+          best_interintra_rd = rd;
+          best_interintra_mode = mbmi->interintra_mode;
+        }
+      }
+      mbmi->interintra_mode = best_interintra_mode;
+      rmode = interintra_mode_cost[mbmi->interintra_mode];
+      av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, &orig_dst,
+                                                intrapred, bw);
+      av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
+      av1_subtract_plane(x, bsize, 0);
+      rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum,
+                               &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX);
+      if (rd != INT64_MAX)
+        rd = RDCOST(x->rdmult, rate_mv + rmode + rate_sum, dist_sum);
+      best_interintra_rd = rd;
+
+      if (ref_best_rd < INT64_MAX && best_interintra_rd > 2 * ref_best_rd) {
+// Don't need to call restore_dst_buf here
+#if CONFIG_JNT_COMP
+        early_terminate = INT64_MAX;
+        continue;
+#else
+      return INT64_MAX;
+#endif  // CONFIG_JNT_COMP
+      }
+      if (is_interintra_wedge_used(bsize)) {
+        int64_t best_interintra_rd_nowedge = INT64_MAX;
+        int64_t best_interintra_rd_wedge = INT64_MAX;
+        int_mv tmp_mv;
+        int rwedge = x->wedge_interintra_cost[bsize][0];
+        if (rd != INT64_MAX)
+          rd = RDCOST(x->rdmult, rmode + rate_mv + rwedge + rate_sum, dist_sum);
+        best_interintra_rd_nowedge = best_interintra_rd;
+
+        // Disable wedge search if source variance is small
+        if (x->source_variance > cpi->sf.disable_wedge_search_var_thresh) {
+          mbmi->use_wedge_interintra = 1;
+
+          rwedge = av1_cost_literal(get_interintra_wedge_bits(bsize)) +
+                   x->wedge_interintra_cost[bsize][1];
+
+          best_interintra_rd_wedge =
+              pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
+
+          best_interintra_rd_wedge +=
+              RDCOST(x->rdmult, rmode + rate_mv + rwedge, 0);
+          // Refine motion vector.
+          if (have_newmv_in_inter_mode(this_mode)) {
+            // get negative of mask
+            const uint8_t *mask = av1_get_contiguous_soft_mask(
+                mbmi->interintra_wedge_index, 1, bsize);
+            tmp_mv.as_int = x->mbmi_ext->ref_mvs[refs[0]][0].as_int;
+            compound_single_motion_search(cpi, x, bsize, &tmp_mv.as_mv, mi_row,
+                                          mi_col, intrapred, mask, bw,
+                                          &tmp_rate_mv, 0);
+            mbmi->mv[0].as_int = tmp_mv.as_int;
+            av1_build_inter_predictors_sby(cm, xd, mi_row, mi_col, &orig_dst,
+                                           bsize);
+            model_rd_for_sb(cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum,
+                            &tmp_skip_txfm_sb, &tmp_skip_sse_sb);
+            rd = RDCOST(x->rdmult, rmode + tmp_rate_mv + rwedge + rate_sum,
+                        dist_sum);
+            if (rd >= best_interintra_rd_wedge) {
+              tmp_mv.as_int = cur_mv[0].as_int;
+              tmp_rate_mv = rate_mv;
+            }
+          } else {
+            tmp_mv.as_int = cur_mv[0].as_int;
+            tmp_rate_mv = rate_mv;
+            av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
+          }
+          // Evaluate closer to true rd
           av1_subtract_plane(x, bsize, 0);
           rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum,
                                    &tmp_skip_txfm_sb, &tmp_skip_sse_sb,
                                    INT64_MAX);
           if (rd != INT64_MAX)
-            best_rd_cur = RDCOST(x->rdmult, rs2 + rate_mv + rate_sum, dist_sum);
-          best_rd_compound = best_rd_cur;
-          break;
-        case COMPOUND_WEDGE:
-          if (x->source_variance > cpi->sf.disable_wedge_search_var_thresh &&
-              best_rd_compound / 3 < ref_best_rd) {
-            best_rd_cur = build_and_cost_compound_type(
-                cpi, x, cur_mv, bsize, this_mode, rs2, rate_mv, &orig_dst,
-                &tmp_rate_mv, preds0, preds1, strides, mi_row, mi_col);
-          }
-          break;
-        case COMPOUND_SEG:
-          if (x->source_variance > cpi->sf.disable_wedge_search_var_thresh &&
-              best_rd_compound / 3 < ref_best_rd) {
-            best_rd_cur = build_and_cost_compound_type(
-                cpi, x, cur_mv, bsize, this_mode, rs2, rate_mv, &orig_dst,
-                &tmp_rate_mv, preds0, preds1, strides, mi_row, mi_col);
-          }
-          break;
-        default: assert(0); return 0;
-      }
-
-      if (best_rd_cur < best_rd_compound) {
-        best_rd_compound = best_rd_cur;
-        best_compound_data.wedge_index = mbmi->wedge_index;
-        best_compound_data.wedge_sign = mbmi->wedge_sign;
-        best_compound_data.mask_type = mbmi->mask_type;
-        memcpy(best_compound_data.seg_mask, xd->seg_mask,
-               2 * MAX_SB_SQUARE * sizeof(uint8_t));
-        best_compound_data.interinter_compound_type =
-            mbmi->interinter_compound_type;
-        best_compmode_interinter_cost = rs2;
-        if (have_newmv_in_inter_mode(this_mode)) {
-          if (use_masked_motion_search(cur_type)) {
-            best_tmp_rate_mv = tmp_rate_mv;
-            best_mv[0].as_int = mbmi->mv[0].as_int;
-            best_mv[1].as_int = mbmi->mv[1].as_int;
+            rd = RDCOST(x->rdmult, rmode + tmp_rate_mv + rwedge + rate_sum,
+                        dist_sum);
+          best_interintra_rd_wedge = rd;
+          if (best_interintra_rd_wedge < best_interintra_rd_nowedge) {
+            mbmi->use_wedge_interintra = 1;
+            mbmi->mv[0].as_int = tmp_mv.as_int;
+            rd_stats->rate += tmp_rate_mv - rate_mv;
+            rate_mv = tmp_rate_mv;
           } else {
-            best_mv[0].as_int = cur_mv[0].as_int;
-            best_mv[1].as_int = cur_mv[1].as_int;
+            mbmi->use_wedge_interintra = 0;
+            mbmi->mv[0].as_int = cur_mv[0].as_int;
           }
-        }
-      }
-      // reset to original mvs for next iteration
-      mbmi->mv[0].as_int = cur_mv[0].as_int;
-      mbmi->mv[1].as_int = cur_mv[1].as_int;
-    }
-    mbmi->wedge_index = best_compound_data.wedge_index;
-    mbmi->wedge_sign = best_compound_data.wedge_sign;
-    mbmi->mask_type = best_compound_data.mask_type;
-    memcpy(xd->seg_mask, best_compound_data.seg_mask,
-           2 * MAX_SB_SQUARE * sizeof(uint8_t));
-    mbmi->interinter_compound_type =
-        best_compound_data.interinter_compound_type;
-    if (have_newmv_in_inter_mode(this_mode)) {
-      mbmi->mv[0].as_int = best_mv[0].as_int;
-      mbmi->mv[1].as_int = best_mv[1].as_int;
-      if (use_masked_motion_search(mbmi->interinter_compound_type)) {
-        rd_stats->rate += best_tmp_rate_mv - rate_mv;
-        rate_mv = best_tmp_rate_mv;
-      }
-    }
-
-    if (ref_best_rd < INT64_MAX && best_rd_compound / 3 > ref_best_rd) {
-      restore_dst_buf(xd, orig_dst);
-      return INT64_MAX;
-    }
-
-    pred_exists = 0;
-
-    compmode_interinter_cost = best_compmode_interinter_cost;
-  }
-
-  if (is_comp_interintra_pred) {
-    INTERINTRA_MODE best_interintra_mode = II_DC_PRED;
-    int64_t best_interintra_rd = INT64_MAX;
-    int rmode, rate_sum;
-    int64_t dist_sum;
-    int j;
-    int tmp_rate_mv = 0;
-    int tmp_skip_txfm_sb;
-    int64_t tmp_skip_sse_sb;
-    DECLARE_ALIGNED(16, uint8_t, intrapred_[2 * MAX_SB_SQUARE]);
-    uint8_t *intrapred;
-
-    if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
-      intrapred = CONVERT_TO_BYTEPTR(intrapred_);
-    else
-      intrapred = intrapred_;
-
-    mbmi->ref_frame[1] = NONE_FRAME;
-    for (j = 0; j < MAX_MB_PLANE; j++) {
-      xd->plane[j].dst.buf = tmp_buf + j * MAX_SB_SQUARE;
-      xd->plane[j].dst.stride = bw;
-    }
-    av1_build_inter_predictors_sby(cm, xd, mi_row, mi_col, &orig_dst, bsize);
-    restore_dst_buf(xd, orig_dst);
-    mbmi->ref_frame[1] = INTRA_FRAME;
-    mbmi->use_wedge_interintra = 0;
-
-    for (j = 0; j < INTERINTRA_MODES; ++j) {
-      mbmi->interintra_mode = (INTERINTRA_MODE)j;
-      rmode = interintra_mode_cost[mbmi->interintra_mode];
-      av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, &orig_dst,
-                                                intrapred, bw);
-      av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
-      model_rd_for_sb(cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum,
-                      &tmp_skip_txfm_sb, &tmp_skip_sse_sb);
-      rd = RDCOST(x->rdmult, tmp_rate_mv + rate_sum + rmode, dist_sum);
-      if (rd < best_interintra_rd) {
-        best_interintra_rd = rd;
-        best_interintra_mode = mbmi->interintra_mode;
-      }
-    }
-    mbmi->interintra_mode = best_interintra_mode;
-    rmode = interintra_mode_cost[mbmi->interintra_mode];
-    av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, &orig_dst,
-                                              intrapred, bw);
-    av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
-    av1_subtract_plane(x, bsize, 0);
-    rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum,
-                             &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX);
-    if (rd != INT64_MAX)
-      rd = RDCOST(x->rdmult, rate_mv + rmode + rate_sum, dist_sum);
-    best_interintra_rd = rd;
-
-    if (ref_best_rd < INT64_MAX && best_interintra_rd > 2 * ref_best_rd) {
-      // Don't need to call restore_dst_buf here
-      return INT64_MAX;
-    }
-    if (is_interintra_wedge_used(bsize)) {
-      int64_t best_interintra_rd_nowedge = INT64_MAX;
-      int64_t best_interintra_rd_wedge = INT64_MAX;
-      int_mv tmp_mv;
-      int rwedge = x->wedge_interintra_cost[bsize][0];
-      if (rd != INT64_MAX)
-        rd = RDCOST(x->rdmult, rmode + rate_mv + rwedge + rate_sum, dist_sum);
-      best_interintra_rd_nowedge = best_interintra_rd;
-
-      // Disable wedge search if source variance is small
-      if (x->source_variance > cpi->sf.disable_wedge_search_var_thresh) {
-        mbmi->use_wedge_interintra = 1;
-
-        rwedge = av1_cost_literal(get_interintra_wedge_bits(bsize)) +
-                 x->wedge_interintra_cost[bsize][1];
-
-        best_interintra_rd_wedge =
-            pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
-
-        best_interintra_rd_wedge +=
-            RDCOST(x->rdmult, rmode + rate_mv + rwedge, 0);
-        // Refine motion vector.
-        if (have_newmv_in_inter_mode(this_mode)) {
-          // get negative of mask
-          const uint8_t *mask = av1_get_contiguous_soft_mask(
-              mbmi->interintra_wedge_index, 1, bsize);
-          tmp_mv.as_int = x->mbmi_ext->ref_mvs[refs[0]][0].as_int;
-          compound_single_motion_search(cpi, x, bsize, &tmp_mv.as_mv, mi_row,
-                                        mi_col, intrapred, mask, bw,
-                                        &tmp_rate_mv, 0);
-          mbmi->mv[0].as_int = tmp_mv.as_int;
-          av1_build_inter_predictors_sby(cm, xd, mi_row, mi_col, &orig_dst,
-                                         bsize);
-          model_rd_for_sb(cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum,
-                          &tmp_skip_txfm_sb, &tmp_skip_sse_sb);
-          rd = RDCOST(x->rdmult, rmode + tmp_rate_mv + rwedge + rate_sum,
-                      dist_sum);
-          if (rd >= best_interintra_rd_wedge) {
-            tmp_mv.as_int = cur_mv[0].as_int;
-            tmp_rate_mv = rate_mv;
-          }
-        } else {
-          tmp_mv.as_int = cur_mv[0].as_int;
-          tmp_rate_mv = rate_mv;
-          av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
-        }
-        // Evaluate closer to true rd
-        av1_subtract_plane(x, bsize, 0);
-        rd =
-            estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum,
-                                &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX);
-        if (rd != INT64_MAX)
-          rd = RDCOST(x->rdmult, rmode + tmp_rate_mv + rwedge + rate_sum,
-                      dist_sum);
-        best_interintra_rd_wedge = rd;
-        if (best_interintra_rd_wedge < best_interintra_rd_nowedge) {
-          mbmi->use_wedge_interintra = 1;
-          mbmi->mv[0].as_int = tmp_mv.as_int;
-          rd_stats->rate += tmp_rate_mv - rate_mv;
-          rate_mv = tmp_rate_mv;
         } else {
           mbmi->use_wedge_interintra = 0;
-          mbmi->mv[0].as_int = cur_mv[0].as_int;
         }
-      } else {
-        mbmi->use_wedge_interintra = 0;
       }
-    }
 
-    pred_exists = 0;
-    compmode_interintra_cost = x->interintra_cost[size_group_lookup[bsize]][1] +
-                               interintra_mode_cost[mbmi->interintra_mode];
-    if (is_interintra_wedge_used(bsize)) {
-      compmode_interintra_cost +=
-          x->wedge_interintra_cost[bsize][mbmi->use_wedge_interintra];
-      if (mbmi->use_wedge_interintra) {
+      pred_exists = 0;
+      compmode_interintra_cost =
+          x->interintra_cost[size_group_lookup[bsize]][1] +
+          interintra_mode_cost[mbmi->interintra_mode];
+      if (is_interintra_wedge_used(bsize)) {
         compmode_interintra_cost +=
-            av1_cost_literal(get_interintra_wedge_bits(bsize));
+            x->wedge_interintra_cost[bsize][mbmi->use_wedge_interintra];
+        if (mbmi->use_wedge_interintra) {
+          compmode_interintra_cost +=
+              av1_cost_literal(get_interintra_wedge_bits(bsize));
+        }
       }
+    } else if (is_interintra_allowed(mbmi)) {
+      compmode_interintra_cost =
+          x->interintra_cost[size_group_lookup[bsize]][0];
     }
-  } else if (is_interintra_allowed(mbmi)) {
-    compmode_interintra_cost = x->interintra_cost[size_group_lookup[bsize]][0];
-  }
 
-  if (pred_exists == 0) {
-    int tmp_rate;
-    int64_t tmp_dist;
-    av1_build_inter_predictors_sb(cm, xd, mi_row, mi_col, &orig_dst, bsize);
-    model_rd_for_sb(cpi, bsize, x, xd, 0, MAX_MB_PLANE - 1, &tmp_rate,
-                    &tmp_dist, &skip_txfm_sb, &skip_sse_sb);
-    rd = RDCOST(x->rdmult, rs + tmp_rate, tmp_dist);
-  }
+    if (pred_exists == 0) {
+      int tmp_rate;
+      int64_t tmp_dist;
+      av1_build_inter_predictors_sb(cm, xd, mi_row, mi_col, &orig_dst, bsize);
+      model_rd_for_sb(cpi, bsize, x, xd, 0, MAX_MB_PLANE - 1, &tmp_rate,
+                      &tmp_dist, &skip_txfm_sb, &skip_sse_sb);
+      rd = RDCOST(x->rdmult, rs + tmp_rate, tmp_dist);
+    }
 
-  if (!is_comp_pred)
-    args->single_filter[this_mode][refs[0]] =
-        av1_extract_interp_filter(mbmi->interp_filters, 0);
+    if (!is_comp_pred)
+      args->single_filter[this_mode][refs[0]] =
+          av1_extract_interp_filter(mbmi->interp_filters, 0);
 
-  if (args->modelled_rd != NULL) {
-    if (is_comp_pred) {
-      const int mode0 = compound_ref0_mode(this_mode);
-      const int mode1 = compound_ref1_mode(this_mode);
-      const int64_t mrd = AOMMIN(args->modelled_rd[mode0][refs[0]],
-                                 args->modelled_rd[mode1][refs[1]]);
-      if (rd / 4 * 3 > mrd && ref_best_rd < INT64_MAX) {
-        restore_dst_buf(xd, orig_dst);
+    if (args->modelled_rd != NULL) {
+      if (is_comp_pred) {
+        const int mode0 = compound_ref0_mode(this_mode);
+        const int mode1 = compound_ref1_mode(this_mode);
+        const int64_t mrd = AOMMIN(args->modelled_rd[mode0][refs[0]],
+                                   args->modelled_rd[mode1][refs[1]]);
+        if (rd / 4 * 3 > mrd && ref_best_rd < INT64_MAX) {
+          restore_dst_buf(xd, orig_dst);
+#if CONFIG_JNT_COMP
+          early_terminate = INT64_MAX;
+          continue;
+#else
         return INT64_MAX;
+#endif  // CONFIG_JNT_COMP
+        }
+      } else if (!is_comp_interintra_pred) {
+        args->modelled_rd[this_mode][refs[0]] = rd;
       }
-    } else if (!is_comp_interintra_pred) {
-      args->modelled_rd[this_mode][refs[0]] = rd;
     }
-  }
 
-  if (cpi->sf.use_rd_breakout && ref_best_rd < INT64_MAX) {
-    // if current pred_error modeled rd is substantially more than the best
-    // so far, do not bother doing full rd
-    if (rd / 2 > ref_best_rd) {
-      restore_dst_buf(xd, orig_dst);
+    if (cpi->sf.use_rd_breakout && ref_best_rd < INT64_MAX) {
+      // if current pred_error modeled rd is substantially more than the best
+      // so far, do not bother doing full rd
+      if (rd / 2 > ref_best_rd) {
+        restore_dst_buf(xd, orig_dst);
+#if CONFIG_JNT_COMP
+        early_terminate = INT64_MAX;
+        continue;
+#else
       return INT64_MAX;
+#endif  // CONFIG_JNT_COMP
+      }
+    }
+
+    rd_stats->rate += compmode_interintra_cost;
+    rate2_bmc_nocoeff += compmode_interintra_cost;
+    rd_stats->rate += compmode_interinter_cost;
+
+    ret_val = motion_mode_rd(cpi, x, bsize, rd_stats, rd_stats_y, rd_stats_uv,
+                             disable_skip, mode_mv, mi_row, mi_col, args,
+                             ref_best_rd, refs, rate_mv, single_newmv,
+                             rate2_bmc_nocoeff, &best_bmc_mbmi, rate_mv_bmc, rs,
+                             &skip_txfm_sb, &skip_sse_sb, &orig_dst);
+#if CONFIG_JNT_COMP
+    if (is_comp_pred && ret_val != INT64_MAX) {
+      int64_t tmp_rd;
+      const int skip_ctx = av1_get_skip_context(xd);
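+      // Compare compound_idx choices using whichever is cheaper: coding the
+      // residue or signaling skip.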
+      if (RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) <
+          RDCOST(x->rdmult, 0, rd_stats->sse))
+        tmp_rd = RDCOST(x->rdmult, rd_stats->rate + x->skip_cost[skip_ctx][0],
+                        rd_stats->dist);
+      else
+        tmp_rd = RDCOST(x->rdmult,
+                        rd_stats->rate + x->skip_cost[skip_ctx][1] -
+                            rd_stats_y->rate - rd_stats_uv->rate,
+                        rd_stats->sse);
+
+      if (tmp_rd < best_rd) {
+        best_rd_stats = *rd_stats;
+        best_rd_stats_y = *rd_stats_y;
+        best_rd_stats_uv = *rd_stats_uv;
+        best_compound_idx = mbmi->compound_idx;
+        best_ret_val = ret_val;
+        best_rd = tmp_rd;
+        best_mbmi = *mbmi;
+        for (i = 0; i < MAX_MB_PLANE; ++i)
+          memcpy(best_blk_skip[i], x->blk_skip[i],
+                 sizeof(uint8_t) * xd->n8_h * xd->n8_w * 4);
+      }
     }
   }
-
-  rd_stats->rate += compmode_interintra_cost;
-  rate2_bmc_nocoeff += compmode_interintra_cost;
-  rd_stats->rate += compmode_interinter_cost;
-
-  ret_val = motion_mode_rd(cpi, x, bsize, rd_stats, rd_stats_y, rd_stats_uv,
-                           disable_skip, mode_mv, mi_row, mi_col, args,
-                           ref_best_rd, refs, rate_mv, single_newmv,
-                           rate2_bmc_nocoeff, &best_bmc_mbmi, rate_mv_bmc, rs,
-                           &skip_txfm_sb, &skip_sse_sb, &orig_dst);
+  // Reinstate the rd stats, mode info and block skip flags of the best choice.
+  if (is_comp_pred && best_ret_val != INT64_MAX) {
+    *rd_stats = best_rd_stats;
+    *rd_stats_y = best_rd_stats_y;
+    *rd_stats_uv = best_rd_stats_uv;
+    mbmi->compound_idx = best_compound_idx;
+    ret_val = best_ret_val;
+    *mbmi = best_mbmi;
+    for (i = 0; i < MAX_MB_PLANE; ++i)
+      memcpy(x->blk_skip[i], best_blk_skip[i],
+             sizeof(uint8_t) * xd->n8_h * xd->n8_w * 4);
+  }
+  if (early_terminate == INT64_MAX) return INT64_MAX;
+#endif  // CONFIG_JNT_COMP
   if (ret_val != 0) return ret_val;
 
   return 0;  // The rate-distortion cost will be re-calculated by caller.
@@ -9643,86 +9806,6 @@
           }
         }
       }
-#if CONFIG_JNT_COMP
-      {
-        int cum_rate = rate2;
-        MB_MODE_INFO backup_mbmi = *mbmi;
-
-        int_mv backup_frame_mv[MB_MODE_COUNT][TOTAL_REFS_PER_FRAME];
-        int_mv backup_single_newmv[TOTAL_REFS_PER_FRAME];
-        int backup_single_newmv_rate[TOTAL_REFS_PER_FRAME];
-        int64_t backup_modelled_rd[MB_MODE_COUNT][TOTAL_REFS_PER_FRAME];
-
-        memcpy(backup_frame_mv, frame_mv, sizeof(frame_mv));
-        memcpy(backup_single_newmv, single_newmv, sizeof(single_newmv));
-        memcpy(backup_single_newmv_rate, single_newmv_rate,
-               sizeof(single_newmv_rate));
-        memcpy(backup_modelled_rd, modelled_rd, sizeof(modelled_rd));
-
-        InterpFilters backup_interp_filters = mbmi->interp_filters;
-
-        for (int comp_idx = 0; comp_idx < 1 + has_second_ref(mbmi);
-             ++comp_idx) {
-          RD_STATS rd_stats, rd_stats_y, rd_stats_uv;
-          av1_init_rd_stats(&rd_stats);
-          av1_init_rd_stats(&rd_stats_y);
-          av1_init_rd_stats(&rd_stats_uv);
-          rd_stats.rate = cum_rate;
-
-          memcpy(frame_mv, backup_frame_mv, sizeof(frame_mv));
-          memcpy(single_newmv, backup_single_newmv, sizeof(single_newmv));
-          memcpy(single_newmv_rate, backup_single_newmv_rate,
-                 sizeof(single_newmv_rate));
-          memcpy(modelled_rd, backup_modelled_rd, sizeof(modelled_rd));
-
-          mbmi->interp_filters = backup_interp_filters;
-
-          int dummy_disable_skip = 0;
-
-          // Point to variables that are maintained between loop iterations
-          args.single_newmv = single_newmv;
-          args.single_newmv_rate = single_newmv_rate;
-          args.modelled_rd = modelled_rd;
-          mbmi->compound_idx = comp_idx;
-
-          int64_t tmp_rd = handle_inter_mode(
-              cpi, x, bsize, &rd_stats, &rd_stats_y, &rd_stats_uv,
-              &dummy_disable_skip, frame_mv, mi_row, mi_col, &args, best_rd);
-
-          if (tmp_rd < INT64_MAX) {
-            if (RDCOST(x->rdmult, rd_stats.rate, rd_stats.dist) <
-                RDCOST(x->rdmult, 0, rd_stats.sse))
-              tmp_rd =
-                  RDCOST(x->rdmult, rd_stats.rate + x->skip_cost[skip_ctx][0],
-                         rd_stats.dist);
-            else
-              tmp_rd = RDCOST(x->rdmult,
-                              rd_stats.rate + x->skip_cost[skip_ctx][1] -
-                                  rd_stats_y.rate - rd_stats_uv.rate,
-                              rd_stats.sse);
-          }
-
-          if (tmp_rd < this_rd) {
-            this_rd = tmp_rd;
-            rate2 = rd_stats.rate;
-            skippable = rd_stats.skip;
-            distortion2 = rd_stats.dist;
-            total_sse = rd_stats.sse;
-            rate_y = rd_stats_y.rate;
-            rate_uv = rd_stats_uv.rate;
-            disable_skip = dummy_disable_skip;
-            backup_mbmi = *mbmi;
-            for (i = 0; i < MAX_MB_PLANE; ++i)
-              memcpy(x->blk_skip_drl[i], x->blk_skip[i],
-                     sizeof(uint8_t) * ctx->num_4x4_blk);
-          }
-        }
-        *mbmi = backup_mbmi;
-        for (i = 0; i < MAX_MB_PLANE; ++i)
-          memcpy(x->blk_skip[i], x->blk_skip_drl[i],
-                 sizeof(uint8_t) * ctx->num_4x4_blk);
-      }
-#else   // CONFIG_JNT_COMP
       {
         RD_STATS rd_stats, rd_stats_y, rd_stats_uv;
         av1_init_rd_stats(&rd_stats);
@@ -9742,7 +9825,6 @@
         rate_y = rd_stats_y.rate;
         rate_uv = rd_stats_uv.rate;
       }
-#endif  // CONFIG_JNT_COMP
 
       // TODO(jingning): This needs some refactoring to improve code quality
       // and reduce redundant steps.
@@ -9787,21 +9869,11 @@
           memcpy(x->blk_skip_drl[i], x->blk_skip[i],
                  sizeof(uint8_t) * ctx->num_4x4_blk);
 
-#if CONFIG_JNT_COMP
-        for (int sidx = 0; sidx < ref_set * (1 + has_second_ref(mbmi)); ++sidx)
-#else
-        for (ref_idx = 0; ref_idx < ref_set; ++ref_idx)
-#endif  // CONFIG_JNT_COMP
-        {
+        for (ref_idx = 0; ref_idx < ref_set; ++ref_idx) {
           int64_t tmp_alt_rd = INT64_MAX;
           int dummy_disable_skip = 0;
           int_mv cur_mv;
           RD_STATS tmp_rd_stats, tmp_rd_stats_y, tmp_rd_stats_uv;
-#if CONFIG_JNT_COMP
-          ref_idx = sidx;
-          if (has_second_ref(mbmi)) ref_idx /= 2;
-          mbmi->compound_idx = sidx % 2;
-#endif  // CONFIG_JNT_COMP
 
           av1_invalid_rd_stats(&tmp_rd_stats);
 
@@ -9929,9 +10001,6 @@
         for (i = 0; i < MAX_MB_PLANE; ++i)
           memcpy(x->blk_skip[i], x->blk_skip_drl[i],
                  sizeof(uint8_t) * ctx->num_4x4_blk);
-#if CONFIG_JNT_COMP
-        *mbmi = backup_mbmi;
-#endif  // CONFIG_JNT_COMP
       }
       mbmi_ext->ref_mvs[ref_frame][0] = backup_ref_mv[0];
       if (comp_pred) mbmi_ext->ref_mvs[second_ref_frame][0] = backup_ref_mv[1];