Refactor compute_valid_comp_types

Refactored compute_valid_comp_types() to include
all valid compound types to be evaluated so that
compound_type_rd is simplified

Change-Id: I22979ddaf10303be2a40ff62670c7c2d63d84fac
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 526483a..91379f1 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -9804,22 +9804,31 @@
   }
 }
 
-// Computes the valid compound_types to be evaluated in core loop
+// Computes the valid compound_types to be evaluated
 static INLINE int compute_valid_comp_types(
-    MACROBLOCK *x, const AV1_COMP *const cpi, BLOCK_SIZE bsize,
-    int try_average_comp, int try_distwtd_comp,
-    int try_average_and_distwtd_comp, int masked_compound_used,
+    MACROBLOCK *x, const AV1_COMP *const cpi, int *try_average_and_distwtd_comp,
+    int32_t *comp_rate, BLOCK_SIZE bsize, int masked_compound_used,
     int mode_search_mask, COMPOUND_TYPE *valid_comp_types) {
+  const AV1_COMMON *cm = &cpi->common;
   int valid_type_count = 0;
   int comp_type, valid_check;
   int8_t enable_masked_type[MASKED_COMPOUND_TYPES] = { 0, 0 };
 
+  const int try_average_comp = (mode_search_mask & (1 << COMPOUND_AVERAGE));
+  const int try_distwtd_comp =
+      ((mode_search_mask & (1 << COMPOUND_DISTWTD)) &&
+       cm->seq_params.order_hint_info.enable_dist_wtd_comp == 1 &&
+       cpi->sf.use_dist_wtd_comp_flag != DIST_WTD_COMP_DISABLED);
+  *try_average_and_distwtd_comp = try_average_comp && try_distwtd_comp &&
+                                  comp_rate[COMPOUND_AVERAGE] == INT_MAX &&
+                                  comp_rate[COMPOUND_DISTWTD] == INT_MAX;
+
   // Check if COMPOUND_AVERAGE and COMPOUND_DISTWTD are valid cases
   for (comp_type = COMPOUND_AVERAGE; comp_type <= COMPOUND_DISTWTD;
        comp_type++) {
     valid_check =
         (comp_type == COMPOUND_AVERAGE) ? try_average_comp : try_distwtd_comp;
-    if (!try_average_and_distwtd_comp && valid_check &&
+    if (!*try_average_and_distwtd_comp && valid_check &&
         is_interinter_compound_used(comp_type, bsize))
       valid_comp_types[valid_type_count++] = comp_type;
   }
@@ -9886,19 +9895,20 @@
   int rate_sum, tmp_skip_txfm_sb;
   int64_t dist_sum, tmp_skip_sse_sb;
 
-  // Special handling if both compound_average and compound_distwtd
-  // are to be searched. In this case, first estimate between the two
-  // modes and then call estimate_yrd_for_sb() only for the better of
-  // the two.
-  const int try_average_comp = (mode_search_mask & (1 << COMPOUND_AVERAGE));
-  const int try_distwtd_comp =
-      ((mode_search_mask & (1 << COMPOUND_DISTWTD)) &&
-       cm->seq_params.order_hint_info.enable_dist_wtd_comp == 1 &&
-       cpi->sf.use_dist_wtd_comp_flag != DIST_WTD_COMP_DISABLED);
-  const int try_average_and_distwtd_comp =
-      try_average_comp && try_distwtd_comp &&
-      comp_rate[COMPOUND_AVERAGE] == INT_MAX &&
-      comp_rate[COMPOUND_DISTWTD] == INT_MAX;
+  // Local array to store the valid compound types to be evaluated in the core
+  // loop
+  COMPOUND_TYPE valid_comp_types[COMPOUND_TYPES] = {
+    COMPOUND_AVERAGE, COMPOUND_DISTWTD, COMPOUND_WEDGE, COMPOUND_DIFFWTD
+  };
+  int valid_type_count = 0;
+  int try_average_and_distwtd_comp = 0;
+  // compute_valid_comp_types() returns the number of valid compound types to be
+  // evaluated and populates the same in the local array valid_comp_types[].
+  // It also sets the flag 'try_average_and_distwtd_comp'
+  valid_type_count = compute_valid_comp_types(
+      x, cpi, &try_average_and_distwtd_comp, comp_rate, bsize,
+      masked_compound_used, mode_search_mask, valid_comp_types);
+
   // The following context indices are independent of compound type
   const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
   const int comp_index_ctx = get_comp_index_context(cm, xd);
@@ -9910,7 +9920,10 @@
   int64_t comp_model_rd_cur = INT64_MAX;
   int64_t best_rd_cur = INT64_MAX;
 
-  // Special case of COMPOUND_AVERAGE and COMPOUND_DISTWTD search
+  // Special handling if both compound_average and compound_distwtd
+  // are to be searched. In this case, first estimate between the two
+  // modes and then call estimate_yrd_for_sb() only for the better of
+  // the two.
   if (try_average_and_distwtd_comp) {
     int est_rate[2];
     int64_t est_dist[2], est_rd[2];
@@ -9963,16 +9976,7 @@
     }
   }
 
-  COMPOUND_TYPE valid_comp_types[COMPOUND_TYPES] = {
-    COMPOUND_AVERAGE, COMPOUND_DISTWTD, COMPOUND_WEDGE, COMPOUND_DIFFWTD
-  };
-  int valid_type_count = 0;
-  // compute_valid_comp_types() returns the number of valid compound types to be
-  // evaluated and populates the same in the local array valid_comp_types[]
-  valid_type_count = compute_valid_comp_types(
-      x, cpi, bsize, try_average_comp, try_distwtd_comp,
-      try_average_and_distwtd_comp, masked_compound_used, mode_search_mask,
-      valid_comp_types);
+  // If COMPOUND_AVERAGE is not valid, use the spare buffer
   if (valid_comp_types[0] != COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
 
   for (int i = 0; i < valid_type_count; i++) {