Refactor try_average_and_distwtd_comp case

Moved try_average_and_distwtd_comp case out of
the core loop in compound_type_rd and refactored
it to remove redundant code

Change-Id: I822a6c70e908e0c717eedaab25d37c8032cd8c06
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index d700a27..24a55ef 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -9870,6 +9870,62 @@
   // Populates masked_type_cost local array for the 4 compound types
   calc_masked_type_cost(x, bsize, comp_group_idx_ctx, comp_index_ctx,
                         masked_compound_used, masked_type_cost);
+
+  int64_t comp_model_rd_cur = INT64_MAX;
+  int64_t best_rd_cur = INT64_MAX;
+
+  // Special case of COMPOUND_AVERAGE and COMPOUND_DISTWTD search
+  if (try_average_and_distwtd_comp) {
+    int est_rate[2];
+    int64_t est_dist[2], est_rd[2];
+    COMPOUND_TYPE best_type;
+
+    // Calculate model_rd for COMPOUND_AVERAGE and COMPOUND_DISTWTD
+    for (int comp_type = COMPOUND_AVERAGE; comp_type <= COMPOUND_DISTWTD;
+         comp_type++) {
+      update_mbmi_for_compound_type(mbmi, comp_type);
+      av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
+                                    AOM_PLANE_Y, AOM_PLANE_Y);
+      model_rd_sb_fn[MODELRD_CURVFIT](
+          cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &est_rate[comp_type],
+          &est_dist[comp_type], NULL, NULL, NULL, NULL, NULL);
+      est_rate[comp_type] += masked_type_cost[comp_type];
+      est_rd[comp_type] = RDCOST(x->rdmult, est_rate[comp_type] + *rate_mv,
+                                 est_dist[comp_type]);
+      if (comp_type == COMPOUND_AVERAGE) {
+        *is_luma_interp_done = 1;
+        restore_dst_buf(xd, *tmp_dst, 1);
+      }
+    }
+    // Choose the better of the two based on modeled cost and call
+    // estimate_yrd_for_sb() for that one.
+    best_type = (est_rd[COMPOUND_AVERAGE] <= est_rd[COMPOUND_DISTWTD])
+                    ? COMPOUND_AVERAGE
+                    : COMPOUND_DISTWTD;
+    update_mbmi_for_compound_type(mbmi, best_type);
+    if (best_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *orig_dst, 1);
+    rs2 = masked_type_cost[best_type];
+    RD_STATS est_rd_stats;
+    const int64_t mode_rd = RDCOST(x->rdmult, rs2 + *rate_mv, 0);
+    const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd;
+    const int64_t est_rd_ =
+        estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &est_rd_stats);
+    if (est_rd_ != INT64_MAX) {
+      best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
+                           est_rd_stats.dist);
+      // Backup rate and distortion for future reuse
+      comp_rate[best_type] = est_rd_stats.rate;
+      comp_dist[best_type] = est_rd_stats.dist;
+      comp_model_rd[best_type] = est_rd[best_type];
+      comp_model_rd_cur = est_rd[best_type];
+    }
+    if (best_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
+    if (best_rd_cur < *rd) {
+      update_best_info(mbmi, rd, &best_type_stats, best_rd_cur,
+                       comp_model_rd_cur, rs2);
+    }
+  }
+
   for (cur_type = COMPOUND_AVERAGE; cur_type < COMPOUND_TYPES; cur_type++) {
     if (((1 << cur_type) & mode_search_mask) == 0) {
       if (cur_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
@@ -9877,144 +9933,73 @@
     }
     if (!is_interinter_compound_used(cur_type, bsize)) continue;
     if (cur_type >= COMPOUND_WEDGE && !masked_compound_used) break;
-    if (cur_type == COMPOUND_DISTWTD && !try_distwtd_comp) continue;
+    if (cur_type == COMPOUND_DISTWTD &&
+        (!try_distwtd_comp || try_average_and_distwtd_comp))
+      continue;
     if (cur_type == COMPOUND_AVERAGE && try_average_and_distwtd_comp) continue;
 
-    int64_t comp_model_rd_cur = INT64_MAX;
+    comp_model_rd_cur = INT64_MAX;
     tmp_rate_mv = *rate_mv;
-    int64_t best_rd_cur = INT64_MAX;
+    best_rd_cur = INT64_MAX;
 
-    if (cur_type == COMPOUND_DISTWTD && try_average_and_distwtd_comp) {
-      int est_rate[2];
-      int64_t est_dist[2], est_rd[2];
+    if (cur_type == COMPOUND_AVERAGE || cur_type == COMPOUND_DISTWTD) {
+      update_mbmi_for_compound_type(mbmi, cur_type);
+      rs2 = masked_type_cost[cur_type];
+      const int64_t mode_rd = RDCOST(x->rdmult, rs2 + rd_stats->rate, 0);
+      if (mode_rd < ref_best_rd) {
+        // Reuse data if matching record is found
+        if (comp_rate[cur_type] == INT_MAX) {
+          av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
+                                        AOM_PLANE_Y, AOM_PLANE_Y);
+          if (cur_type == COMPOUND_AVERAGE) *is_luma_interp_done = 1;
+          RD_STATS est_rd_stats;
+          const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd;
+          const int64_t est_rd =
+              estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &est_rd_stats);
+          if (est_rd != INT64_MAX) {
+            best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
+                                 est_rd_stats.dist);
+            model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
+                cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &rate_sum, &dist_sum,
+                &tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL, NULL);
+            comp_model_rd_cur =
+                RDCOST(x->rdmult, rs2 + *rate_mv + rate_sum, dist_sum);
 
-      // First find the modeled rd cost for COMPOUND_AVERAGE
-      update_mbmi_for_compound_type(mbmi, COMPOUND_AVERAGE);
-      av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
-                                    AOM_PLANE_Y, AOM_PLANE_Y);
-      *is_luma_interp_done = 1;
-      model_rd_sb_fn[MODELRD_CURVFIT](
-          cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &est_rate[COMPOUND_AVERAGE],
-          &est_dist[COMPOUND_AVERAGE], NULL, NULL, NULL, NULL, NULL);
-      est_rate[COMPOUND_AVERAGE] += masked_type_cost[COMPOUND_AVERAGE];
-      est_rd[COMPOUND_AVERAGE] =
-          RDCOST(x->rdmult, est_rate[COMPOUND_AVERAGE] + *rate_mv,
-                 est_dist[COMPOUND_AVERAGE]);
-      restore_dst_buf(xd, *tmp_dst, 1);
-
-      // Next find the modeled rd cost for COMPOUND_DISTWTD
-      update_mbmi_for_compound_type(mbmi, COMPOUND_DISTWTD);
-      av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
-                                    AOM_PLANE_Y, AOM_PLANE_Y);
-      model_rd_sb_fn[MODELRD_CURVFIT](
-          cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &est_rate[COMPOUND_DISTWTD],
-          &est_dist[COMPOUND_DISTWTD], NULL, NULL, NULL, NULL, NULL);
-      est_rate[COMPOUND_DISTWTD] += masked_type_cost[COMPOUND_DISTWTD];
-      est_rd[COMPOUND_DISTWTD] =
-          RDCOST(x->rdmult, est_rate[COMPOUND_DISTWTD] + *rate_mv,
-                 est_dist[COMPOUND_DISTWTD]);
-
-      // Choose the better of the two based on modeled cost and call
-      // estimate_yrd_for_sb() for that one.
-      if (est_rd[COMPOUND_AVERAGE] <= est_rd[COMPOUND_DISTWTD]) {
-        update_mbmi_for_compound_type(mbmi, COMPOUND_AVERAGE);
-        restore_dst_buf(xd, *orig_dst, 1);
-        rs2 = masked_type_cost[COMPOUND_AVERAGE];
-        RD_STATS est_rd_stats;
-        const int64_t mode_rd = RDCOST(x->rdmult, rs2 + *rate_mv, 0);
-        const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd;
-        const int64_t est_rd_ =
-            estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &est_rd_stats);
-        if (est_rd_ != INT64_MAX) {
-          best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
-                               est_rd_stats.dist);
-          restore_dst_buf(xd, *tmp_dst, 1);
-          comp_rate[COMPOUND_AVERAGE] = est_rd_stats.rate;
-          comp_dist[COMPOUND_AVERAGE] = est_rd_stats.dist;
-          comp_model_rd[COMPOUND_AVERAGE] = est_rd[COMPOUND_AVERAGE];
-          comp_model_rd_cur = est_rd[COMPOUND_AVERAGE];
-        }
-        restore_dst_buf(xd, *tmp_dst, 1);
-      } else {
-        rs2 = masked_type_cost[COMPOUND_DISTWTD];
-        RD_STATS est_rd_stats;
-        const int64_t mode_rd = RDCOST(x->rdmult, rs2 + *rate_mv, 0);
-        const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd;
-        const int64_t est_rd_ =
-            estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &est_rd_stats);
-
-        if (est_rd_ != INT64_MAX) {
-          best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
-                               est_rd_stats.dist);
-          comp_rate[COMPOUND_DISTWTD] = est_rd_stats.rate;
-          comp_dist[COMPOUND_DISTWTD] = est_rd_stats.dist;
-          comp_model_rd[COMPOUND_DISTWTD] = est_rd[COMPOUND_DISTWTD];
-          comp_model_rd_cur = est_rd[COMPOUND_DISTWTD];
+            // Backup rate and distortion for future reuse
+            comp_rate[cur_type] = est_rd_stats.rate;
+            comp_dist[cur_type] = est_rd_stats.dist;
+            comp_model_rd[cur_type] = comp_model_rd_cur;
+          }
+        } else {
+          // Calculate RD cost based on stored stats
+          assert(comp_dist[cur_type] != INT64_MAX);
+          best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + comp_rate[cur_type],
+                               comp_dist[cur_type]);
+          comp_model_rd_cur = comp_model_rd[cur_type];
         }
       }
+      // use spare buffer for following compound type try
+      if (cur_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
     } else {
-      if (cur_type == COMPOUND_AVERAGE || cur_type == COMPOUND_DISTWTD) {
-        update_mbmi_for_compound_type(mbmi, cur_type);
-        rs2 = masked_type_cost[cur_type];
-        const int64_t mode_rd = RDCOST(x->rdmult, rs2 + rd_stats->rate, 0);
-        if (mode_rd < ref_best_rd) {
-          // Reuse data if matching record is found
-          if (comp_rate[cur_type] == INT_MAX) {
-            av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst,
-                                          bsize, AOM_PLANE_Y, AOM_PLANE_Y);
-            if (cur_type == COMPOUND_AVERAGE) *is_luma_interp_done = 1;
-            RD_STATS est_rd_stats;
-            const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd;
-            const int64_t est_rd = estimate_yrd_for_sb(
-                cpi, bsize, x, tmp_rd_thresh, &est_rd_stats);
-            if (est_rd != INT64_MAX) {
-              best_rd_cur =
-                  RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
-                         est_rd_stats.dist);
-              model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
-                  cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &rate_sum, &dist_sum,
-                  &tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL, NULL);
-              comp_model_rd_cur =
-                  RDCOST(x->rdmult, rs2 + *rate_mv + rate_sum, dist_sum);
+      update_mbmi_for_compound_type(mbmi, cur_type);
+      rs2 = masked_type_cost[cur_type];
+      int64_t approx_rd = ((*rd / cpi->max_comp_type_rd_threshold_div) *
+                           cpi->max_comp_type_rd_threshold_mul);
 
-              // Backup rate and distortion for future reuse
-              comp_rate[cur_type] = est_rd_stats.rate;
-              comp_dist[cur_type] = est_rd_stats.dist;
-              comp_model_rd[cur_type] = comp_model_rd_cur;
-            }
-          } else {
-            // Calculate RD cost based on stored stats
-            assert(comp_dist[cur_type] != INT64_MAX);
-            best_rd_cur =
-                RDCOST(x->rdmult, rs2 + *rate_mv + comp_rate[cur_type],
-                       comp_dist[cur_type]);
-            comp_model_rd_cur = comp_model_rd[cur_type];
-          }
-        }
-        // use spare buffer for following compound type try
-        if (cur_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
-      } else {
-        update_mbmi_for_compound_type(mbmi, cur_type);
-        rs2 = masked_type_cost[cur_type];
-        int64_t approx_rd = ((*rd / cpi->max_comp_type_rd_threshold_div) *
-                             cpi->max_comp_type_rd_threshold_mul);
+      if (approx_rd < ref_best_rd) {
+        int8_t enable_wedge = ((cur_type == COMPOUND_WEDGE) &&
+                               enable_wedge_interinter_search(x, cpi));
+        int8_t enable_diffwtd =
+            ((cur_type == COMPOUND_DIFFWTD) && cpi->oxcf.enable_diff_wtd_comp);
 
-        if (approx_rd < ref_best_rd) {
-          int8_t enable_wedge = ((cur_type == COMPOUND_WEDGE) &&
-                                 enable_wedge_interinter_search(x, cpi));
-          int8_t enable_diffwtd = ((cur_type == COMPOUND_DIFFWTD) &&
-                                   cpi->oxcf.enable_diff_wtd_comp);
-
-          if (enable_wedge || enable_diffwtd) {
-            const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh);
-            best_rd_cur = build_and_cost_compound_type(
-                cpi, x, cur_mv, bsize, this_mode, &rs2, *rate_mv, orig_dst,
-                &tmp_rate_mv, preds0, preds1, buffers->residual1,
-                buffers->diff10, strides, mi_row, mi_col, rd_stats->rate,
-                tmp_rd_thresh, &calc_pred_masked_compound, comp_rate, comp_dist,
-                comp_model_rd, best_type_stats.comp_best_model_rd,
-                &comp_model_rd_cur);
-          }
+        if (enable_wedge || enable_diffwtd) {
+          const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh);
+          best_rd_cur = build_and_cost_compound_type(
+              cpi, x, cur_mv, bsize, this_mode, &rs2, *rate_mv, orig_dst,
+              &tmp_rate_mv, preds0, preds1, buffers->residual1, buffers->diff10,
+              strides, mi_row, mi_col, rd_stats->rate, tmp_rd_thresh,
+              &calc_pred_masked_compound, comp_rate, comp_dist, comp_model_rd,
+              best_type_stats.comp_best_model_rd, &comp_model_rd_cur);
         }
       }
     }