Refine compound_type_rd based on model_rd for speed >= 4

Gate the estimate_yrd_for_sb() evaluation of the wedge and compound segment
types in compound_type_rd() on the modeled RD (model_rd): when the current
type's modeled RD is already worse than the best modeled RD among the
compound types evaluated so far, the approximate RD evaluation is skipped.

For the speed 4 preset, AWCY runs show a BD-rate impact of 0.00%, with an
average encode time reduction of 3.00% across multiple test cases.
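
For illustration, a minimal sketch of the pruning condition (not the
literal encoder code; the helper name prune_by_model_rd is hypothetical,
while the variables mirror comp_model_rd_cur and comp_best_model_rd from
the patch below):

    #include <stdint.h>

    /* Skip the costly estimate_yrd_for_sb() evaluation of a wedge/segment
     * candidate when its modeled RD is already worse than the best modeled
     * RD recorded for the compound types evaluated earlier. */
    static int prune_by_model_rd(int sf_prune_comp_type_by_model_rd,
                                 int64_t comp_model_rd_cur,
                                 int64_t comp_best_model_rd) {
      return sf_prune_comp_type_by_model_rd &&
             comp_best_model_rd != INT64_MAX &&
             comp_model_rd_cur > comp_best_model_rd;
    }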

STATS_CHANGED
Change-Id: I6877f6d07707e4209ab882bbc2dd4390d7862fd3
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 46947ec..7c4cd00 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -194,6 +194,7 @@
 typedef struct {
   int32_t rate[COMPOUND_TYPES];
   int64_t dist[COMPOUND_TYPES];
+  int64_t comp_model_rd[COMPOUND_TYPES];
   int_mv mv[2];
   MV_REFERENCE_FRAME ref_frames[2];
   PREDICTION_MODE mode;
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index b9f6c82..10c75b1 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -7872,13 +7872,17 @@
     int rate_mv, const BUFFER_SET *ctx, int *out_rate_mv, uint8_t **preds0,
     uint8_t **preds1, int16_t *residual1, int16_t *diff10, int *strides,
     int mi_row, int mi_col, int mode_rate, int64_t ref_best_rd,
-    int *calc_pred_masked_compound, int32_t *comp_rate, int64_t *comp_dist) {
+    int *calc_pred_masked_compound, int32_t *comp_rate, int64_t *comp_dist,
+    int64_t *const comp_model_rd, const int64_t comp_best_model_rd,
+    int64_t *const comp_model_rd_cur) {
   const AV1_COMMON *const cm = &cpi->common;
   MACROBLOCKD *xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = xd->mi[0];
   int64_t best_rd_cur = INT64_MAX;
   int64_t rd = INT64_MAX;
   const COMPOUND_TYPE compound_type = mbmi->interinter_comp.type;
+  int rate_sum, tmp_skip_txfm_sb;
+  int64_t dist_sum, tmp_skip_sse_sb;
 
   // TODO(any): Save pred and mask calculation as well into records. However
   // this may increase memory requirements as compound segment mask needs to be
@@ -7898,8 +7902,10 @@
     const unsigned int mse =
         ROUND_POWER_OF_TWO(sse, num_pels_log2_lookup[bsize]);
     // If two predictors are very similar, skip wedge compound mode search
-    if (mse < 8 || (!have_newmv_in_inter_mode(this_mode) && mse < 64))
+    if (mse < 8 || (!have_newmv_in_inter_mode(this_mode) && mse < 64)) {
+      *comp_model_rd_cur = INT64_MAX;
       return INT64_MAX;
+    }
   }
 
   best_rd_cur =
@@ -7911,7 +7917,10 @@
   // is unlikely to be the best mode considering the transform rd cost and other
   // mode overhead cost
   int64_t mode_rd = RDCOST(x->rdmult, *rs2 + mode_rate, 0);
-  if (mode_rd > ref_best_rd) return INT64_MAX;
+  if (mode_rd > ref_best_rd) {
+    *comp_model_rd_cur = INT64_MAX;
+    return INT64_MAX;
+  }
 
   // Reuse data if matching record is found
   if (comp_rate[compound_type] == INT_MAX) {
@@ -7922,26 +7931,39 @@
           cpi, x, cur_mv, bsize, this_mode, mi_row, mi_col);
       av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, ctx, bsize,
                                     AOM_PLANE_Y, AOM_PLANE_Y);
-      int rate_sum, tmp_skip_txfm_sb;
-      int64_t dist_sum, tmp_skip_sse_sb;
+
       model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
           cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &rate_sum, &dist_sum,
           &tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL, NULL);
       rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv + rate_sum, dist_sum);
+      *comp_model_rd_cur = rd;
       if (rd >= best_rd_cur) {
         mbmi->mv[0].as_int = cur_mv[0].as_int;
         mbmi->mv[1].as_int = cur_mv[1].as_int;
         *out_rate_mv = rate_mv;
         av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0,
                                                  strides, preds1, strides);
+        *comp_model_rd_cur = best_rd_cur;
       }
     } else {
       *out_rate_mv = rate_mv;
       av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0, strides,
                                                preds1, strides);
+      model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
+          cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &rate_sum, &dist_sum,
+          &tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL, NULL);
+      *comp_model_rd_cur =
+          RDCOST(x->rdmult, *rs2 + *out_rate_mv + rate_sum, dist_sum);
     }
 
     RD_STATS rd_stats;
+
+    if (cpi->sf.prune_comp_type_by_model_rd &&
+        (*comp_model_rd_cur > comp_best_model_rd) &&
+        comp_best_model_rd != INT64_MAX) {
+      *comp_model_rd_cur = INT64_MAX;
+      return INT64_MAX;
+    }
     rd = estimate_yrd_for_sb(cpi, bsize, x, ref_best_rd, &rd_stats);
     if (rd != INT64_MAX) {
       rd =
@@ -7949,6 +7971,7 @@
       // Backup rate and distortion for future reuse
       comp_rate[compound_type] = rd_stats.rate;
       comp_dist[compound_type] = rd_stats.dist;
+      comp_model_rd[compound_type] = *comp_model_rd_cur;
     }
   } else {
     assert(comp_dist[compound_type] != INT64_MAX);
@@ -7962,6 +7985,7 @@
     // Calculate RD cost based on stored stats
     rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv + comp_rate[compound_type],
                 comp_dist[compound_type]);
+    *comp_model_rd_cur = comp_model_rd[compound_type];
   }
   return rd;
 }
@@ -8415,7 +8439,8 @@
                                    const MACROBLOCK *const x,
                                    const COMP_RD_STATS *st,
                                    const MB_MODE_INFO *const mi,
-                                   int32_t *comp_rate, int64_t *comp_dist) {
+                                   int32_t *comp_rate, int64_t *comp_dist,
+                                   int64_t *comp_model_rd) {
   // TODO(ranjit): Ensure that compound type search use regular filter always
   // and check if following check can be removed
   // Check if interp filter matches with previous case
@@ -8435,8 +8460,10 @@
   // Store the stats for compound average
   comp_rate[COMPOUND_AVERAGE] = st->rate[COMPOUND_AVERAGE];
   comp_dist[COMPOUND_AVERAGE] = st->dist[COMPOUND_AVERAGE];
+  comp_model_rd[COMPOUND_AVERAGE] = st->comp_model_rd[COMPOUND_AVERAGE];
   comp_rate[COMPOUND_DISTWTD] = st->rate[COMPOUND_DISTWTD];
   comp_dist[COMPOUND_DISTWTD] = st->dist[COMPOUND_DISTWTD];
+  comp_model_rd[COMPOUND_DISTWTD] = st->comp_model_rd[COMPOUND_DISTWTD];
 
   // For compound wedge/segment, reuse data only if NEWMV is not present in
   // either of the directions
@@ -8447,6 +8474,8 @@
            sizeof(comp_rate[COMPOUND_WEDGE]) * 2);
     memcpy(&comp_dist[COMPOUND_WEDGE], &st->dist[COMPOUND_WEDGE],
            sizeof(comp_dist[COMPOUND_WEDGE]) * 2);
+    memcpy(&comp_model_rd[COMPOUND_WEDGE], &st->comp_model_rd[COMPOUND_WEDGE],
+           sizeof(comp_model_rd[COMPOUND_WEDGE]) * 2);
   }
   return 1;
 }
@@ -8469,11 +8498,11 @@
 static INLINE int find_comp_rd_in_stats(const AV1_COMP *const cpi,
                                         const MACROBLOCK *x,
                                         const MB_MODE_INFO *const mbmi,
-                                        int32_t *comp_rate,
-                                        int64_t *comp_dist) {
+                                        int32_t *comp_rate, int64_t *comp_dist,
+                                        int64_t *comp_model_rd) {
   for (int j = 0; j < x->comp_rd_stats_idx; ++j) {
     if (is_comp_rd_match(cpi, x, &x->comp_rd_stats[j], mbmi, comp_rate,
-                         comp_dist)) {
+                         comp_dist, comp_model_rd)) {
       return 1;
     }
   }
@@ -8506,12 +8535,15 @@
                                             const MB_MODE_INFO *const mbmi,
                                             const int32_t *comp_rate,
                                             const int64_t *comp_dist,
+                                            const int64_t *comp_model_rd,
                                             const int_mv *cur_mv) {
   const int offset = x->comp_rd_stats_idx;
   if (offset < MAX_COMP_RD_STATS) {
     COMP_RD_STATS *const rd_stats = x->comp_rd_stats + offset;
     memcpy(rd_stats->rate, comp_rate, sizeof(rd_stats->rate));
     memcpy(rd_stats->dist, comp_dist, sizeof(rd_stats->dist));
+    memcpy(rd_stats->comp_model_rd, comp_model_rd,
+           sizeof(rd_stats->comp_model_rd));
     memcpy(rd_stats->mv, cur_mv, sizeof(rd_stats->mv));
     memcpy(rd_stats->ref_frames, mbmi->ref_frame, sizeof(rd_stats->ref_frames));
     rd_stats->mode = mbmi->mode;
@@ -9782,13 +9814,19 @@
   int64_t comp_dist[COMPOUND_TYPES] = { INT64_MAX, INT64_MAX, INT64_MAX,
                                         INT64_MAX };
   int32_t comp_rate[COMPOUND_TYPES] = { INT_MAX, INT_MAX, INT_MAX, INT_MAX };
+  int64_t comp_model_rd[COMPOUND_TYPES] = { INT64_MAX, INT64_MAX, INT64_MAX,
+                                            INT64_MAX };
   // TODO(debargha): Remove the code related to comp_rd_stats since it is
   // not used.
   const int match_found =
-      find_comp_rd_in_stats(cpi, x, mbmi, comp_rate, comp_dist);
+      find_comp_rd_in_stats(cpi, x, mbmi, comp_rate, comp_dist, comp_model_rd);
+
   best_mv[0].as_int = cur_mv[0].as_int;
   best_mv[1].as_int = cur_mv[1].as_int;
   *rd = INT64_MAX;
+  int rate_sum, tmp_skip_txfm_sb;
+  int64_t dist_sum, tmp_skip_sse_sb;
+  int64_t comp_best_model_rd = INT64_MAX;
   for (cur_type = COMPOUND_AVERAGE; cur_type < COMPOUND_TYPES; cur_type++) {
     if (cur_type >= COMPOUND_WEDGE && !masked_compound_used) break;
     if (!is_interinter_compound_used(cur_type, bsize)) continue;
@@ -9798,6 +9836,7 @@
          mbmi->mode == GLOBAL_GLOBALMV))
       continue;
     if (((1 << cur_type) & mode_search_mask) == 0) continue;
+    int64_t comp_model_rd_cur = INT64_MAX;
     tmp_rate_mv = *rate_mv;
     int64_t best_rd_cur = INT64_MAX;
     mbmi->interinter_comp.type = cur_type;
@@ -9831,15 +9870,23 @@
           if (est_rd != INT64_MAX) {
             best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
                                  est_rd_stats.dist);
+            model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
+                cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &rate_sum, &dist_sum,
+                &tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL, NULL);
+            comp_model_rd_cur =
+                RDCOST(x->rdmult, rs2 + *rate_mv + rate_sum, dist_sum);
+
             // Backup rate and distortion for future reuse
             comp_rate[cur_type] = est_rd_stats.rate;
             comp_dist[cur_type] = est_rd_stats.dist;
+            comp_model_rd[cur_type] = comp_model_rd_cur;
           }
         } else {
           // Calculate RD cost based on stored stats
           assert(comp_dist[cur_type] != INT64_MAX);
           best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + comp_rate[cur_type],
                                comp_dist[cur_type]);
+          comp_model_rd_cur = comp_model_rd[cur_type];
         }
       }
       // use spare buffer for following compound type try
@@ -9865,11 +9912,13 @@
               cpi, x, cur_mv, bsize, this_mode, &rs2, *rate_mv, orig_dst,
               &tmp_rate_mv, preds0, preds1, buffers->residual1, buffers->diff10,
               strides, mi_row, mi_col, rd_stats->rate, ref_best_rd,
-              &calc_pred_masked_compound, comp_rate, comp_dist);
+              &calc_pred_masked_compound, comp_rate, comp_dist, comp_model_rd,
+              comp_best_model_rd, &comp_model_rd_cur);
       }
     }
     if (best_rd_cur < *rd) {
       *rd = best_rd_cur;
+      comp_best_model_rd = comp_model_rd_cur;
       best_compound_data = mbmi->interinter_comp;
       if (masked_compound_used && cur_type != COMPOUND_TYPES - 1) {
         memcpy(buffers->tmp_best_mask_buf, xd->seg_mask, mask_len);
@@ -9906,7 +9955,8 @@
   }
   restore_dst_buf(xd, *orig_dst, 1);
   if (!match_found)
-    save_comp_rd_search_stat(x, mbmi, comp_rate, comp_dist, cur_mv);
+    save_comp_rd_search_stat(x, mbmi, comp_rate, comp_dist, comp_model_rd,
+                             cur_mv);
   return best_compmode_interinter_cost;
 }
 
diff --git a/av1/encoder/speed_features.c b/av1/encoder/speed_features.c
index d4b0228..c66dd6e 100644
--- a/av1/encoder/speed_features.c
+++ b/av1/encoder/speed_features.c
@@ -350,6 +350,7 @@
     sf->alt_ref_search_fp = 1;
     sf->skip_sharp_interp_filter_search = 1;
     sf->perform_coeff_opt = is_boosted_arf2_bwd_type ? 2 : 4;
+    sf->prune_comp_type_by_model_rd = boosted ? 0 : 1;
     sf->adaptive_txb_search_level = boosted ? 2 : 3;
   }
 
@@ -785,6 +786,7 @@
   sf->prune_warp_using_wmtype = 0;
   sf->disable_wedge_interintra_search = 0;
   sf->perform_coeff_opt = 0;
+  sf->prune_comp_type_by_model_rd = 0;
   sf->disable_smooth_intra = 0;
 
   if (oxcf->mode == GOOD)
diff --git a/av1/encoder/speed_features.h b/av1/encoder/speed_features.h
index 7306bd2..cfa5831 100644
--- a/av1/encoder/speed_features.h
+++ b/av1/encoder/speed_features.h
@@ -670,6 +670,10 @@
   // This flag controls the use of non-RD mode decision.
   int use_nonrd_pick_mode;
 
+  // Prune wedge and compound segment approximate rd evaluation based on the
+  // best modeled rd seen among the compound types already evaluated
+  int prune_comp_type_by_model_rd;
+
   // Enable/disable smooth intra modes.
   int disable_smooth_intra;
 } SPEED_FEATURES;