Refactor core loop of compound_type_rd

Simplified the core loop structure and gatings in compound_type_rd

Change-Id: Ib761bdae6ef9453810d61260e12662feff289661
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 24a55ef..88731c3 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -7758,7 +7758,7 @@
   }
 }
 
-static int64_t build_and_cost_compound_type(
+static int64_t masked_compound_type_rd(
     const AV1_COMP *const cpi, MACROBLOCK *x, const int_mv *const cur_mv,
     const BLOCK_SIZE bsize, const PREDICTION_MODE this_mode, int *rs2,
     int rate_mv, const BUFFER_SET *ctx, int *out_rate_mv, uint8_t **preds0,
@@ -9804,6 +9804,42 @@
   }
 }
 
+// Computes the valid compound_types to be evaluated in core loop
+static INLINE int compute_valid_comp_types(
+    MACROBLOCK *x, const AV1_COMP *const cpi, BLOCK_SIZE bsize,
+    int try_average_comp, int try_distwtd_comp,
+    int try_average_and_distwtd_comp, int masked_compound_used,
+    int mode_search_mask, COMPOUND_TYPE *valid_comp_types) {
+  int valid_type_count = 0;
+  int comp_type, valid_check;
+  int8_t enable_masked_type[MASKED_COMPOUND_TYPES] = { 0, 0 };
+
+  // Check if COMPOUND_AVERAGE and COMPOUND_DISTWTD are valid cases
+  for (comp_type = COMPOUND_AVERAGE; comp_type <= COMPOUND_DISTWTD;
+       comp_type++) {
+    valid_check =
+        (comp_type == COMPOUND_AVERAGE) ? try_average_comp : try_distwtd_comp;
+    if (!try_average_and_distwtd_comp && valid_check &&
+        is_interinter_compound_used(comp_type, bsize))
+      valid_comp_types[valid_type_count++] = comp_type;
+  }
+  // Check if COMPOUND_WEDGE and COMPOUND_DIFFWTD are valid cases
+  if (masked_compound_used) {
+    // enable_masked_type[0] corresponds to COMPOUND_WEDGE
+    // enable_masked_type[1] corresponds to COMPOUND_DIFFWTD
+    enable_masked_type[0] = enable_wedge_interinter_search(x, cpi);
+    enable_masked_type[1] = cpi->oxcf.enable_diff_wtd_comp;
+    for (comp_type = COMPOUND_WEDGE; comp_type <= COMPOUND_DIFFWTD;
+         comp_type++) {
+      if ((mode_search_mask & (1 << comp_type)) &&
+          is_interinter_compound_used(comp_type, bsize) &&
+          enable_masked_type[comp_type - COMPOUND_WEDGE])
+        valid_comp_types[valid_type_count++] = comp_type;
+    }
+  }
+  return valid_type_count;
+}
+
 static int compound_type_rd(
     const AV1_COMP *const cpi, MACROBLOCK *x, BLOCK_SIZE bsize, int mi_col,
     int mi_row, int_mv *cur_mv, int mode_search_mask, int masked_compound_used,
@@ -9910,6 +9946,7 @@
     const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd;
     const int64_t est_rd_ =
         estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &est_rd_stats);
+
     if (est_rd_ != INT64_MAX) {
       best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
                            est_rd_stats.dist);
@@ -9926,23 +9963,25 @@
     }
   }
 
-  for (cur_type = COMPOUND_AVERAGE; cur_type < COMPOUND_TYPES; cur_type++) {
-    if (((1 << cur_type) & mode_search_mask) == 0) {
-      if (cur_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
-      continue;
-    }
-    if (!is_interinter_compound_used(cur_type, bsize)) continue;
-    if (cur_type >= COMPOUND_WEDGE && !masked_compound_used) break;
-    if (cur_type == COMPOUND_DISTWTD &&
-        (!try_distwtd_comp || try_average_and_distwtd_comp))
-      continue;
-    if (cur_type == COMPOUND_AVERAGE && try_average_and_distwtd_comp) continue;
+  COMPOUND_TYPE valid_comp_types[COMPOUND_TYPES] = {
+    COMPOUND_AVERAGE, COMPOUND_DISTWTD, COMPOUND_WEDGE, COMPOUND_DIFFWTD
+  };
+  int valid_type_count = 0;
+  // compute_valid_comp_types() returns the number of valid compound types to be
+  // evaluated and populates the same in the local array valid_comp_types[]
+  valid_type_count = compute_valid_comp_types(
+      x, cpi, bsize, try_average_comp, try_distwtd_comp,
+      try_average_and_distwtd_comp, masked_compound_used, mode_search_mask,
+      valid_comp_types);
+  if (valid_comp_types[0] != COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
 
+  for (int i = 0; i < valid_type_count; i++) {
+    cur_type = valid_comp_types[i];
     comp_model_rd_cur = INT64_MAX;
     tmp_rate_mv = *rate_mv;
     best_rd_cur = INT64_MAX;
 
-    if (cur_type == COMPOUND_AVERAGE || cur_type == COMPOUND_DISTWTD) {
+    if (cur_type < COMPOUND_WEDGE) {
       update_mbmi_for_compound_type(mbmi, cur_type);
       rs2 = masked_type_cost[cur_type];
       const int64_t mode_rd = RDCOST(x->rdmult, rs2 + rd_stats->rate, 0);
@@ -9987,20 +10026,13 @@
                            cpi->max_comp_type_rd_threshold_mul);
 
       if (approx_rd < ref_best_rd) {
-        int8_t enable_wedge = ((cur_type == COMPOUND_WEDGE) &&
-                               enable_wedge_interinter_search(x, cpi));
-        int8_t enable_diffwtd =
-            ((cur_type == COMPOUND_DIFFWTD) && cpi->oxcf.enable_diff_wtd_comp);
-
-        if (enable_wedge || enable_diffwtd) {
-          const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh);
-          best_rd_cur = build_and_cost_compound_type(
-              cpi, x, cur_mv, bsize, this_mode, &rs2, *rate_mv, orig_dst,
-              &tmp_rate_mv, preds0, preds1, buffers->residual1, buffers->diff10,
-              strides, mi_row, mi_col, rd_stats->rate, tmp_rd_thresh,
-              &calc_pred_masked_compound, comp_rate, comp_dist, comp_model_rd,
-              best_type_stats.comp_best_model_rd, &comp_model_rd_cur);
-        }
+        const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh);
+        best_rd_cur = masked_compound_type_rd(
+            cpi, x, cur_mv, bsize, this_mode, &rs2, *rate_mv, orig_dst,
+            &tmp_rate_mv, preds0, preds1, buffers->residual1, buffers->diff10,
+            strides, mi_row, mi_col, rd_stats->rate, tmp_rd_thresh,
+            &calc_pred_masked_compound, comp_rate, comp_dist, comp_model_rd,
+            best_type_stats.comp_best_model_rd, &comp_model_rd_cur);
       }
     }
     if (best_rd_cur < *rd) {