Correct the rd thresholds passed to estimate_yrd_for_sb

          Encode Time      Quality Impact(%)
Preset    Reduction(%)        (AWCY)
  1         0.00              -0.01
  2         0.00              -0.02
  3         0.00              -0.01
  4         0.00              +0.01

STATS_CHANGED

Change-Id: Ide14ba346765416f71d430cc628d4ca42732c15f
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index c76c4f3..08ff00d 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -636,6 +636,18 @@
   },
 };
 /* clang-format on */
+// Scale ref_best_rd by mul_factor / div_factor, saturating at INT64_MAX.
+// A zero factor leaves ref_best_rd unscaled. Dividing before comparing keeps
+// the overflow guard itself from overflowing when div_factor > mul_factor.
+static INLINE int64_t get_rd_thresh_from_best_rd(int64_t ref_best_rd,
+                                                 int mul_factor,
+                                                 int div_factor) {
+  if (div_factor == 0 || mul_factor == 0) return ref_best_rd;
+  const int64_t quotient = ref_best_rd / div_factor;
+  // Saturate: quotient * mul_factor would not fit in an int64_t.
+  if (quotient >= INT64_MAX / mul_factor) return INT64_MAX;
+  return quotient * mul_factor;
+}
 
 static int get_prediction_mode_idx(PREDICTION_MODE this_mode,
                                    MV_REFERENCE_FRAME ref_frame,
@@ -3544,6 +3556,7 @@
                                    MACROBLOCK *x, int64_t ref_best_rd,
                                    RD_STATS *rd_stats) {
   MACROBLOCKD *const xd = &x->e_mbd;
+  if (ref_best_rd < 0) return INT64_MAX;
   av1_subtract_plane(x, bs, 0);
   x->rd_model = LOW_TXFM_RD;
   int skip_trellis = cpi->optimize_seg_arr[xd->mi[0]->segment_id] ==
@@ -7758,7 +7771,7 @@
     const BLOCK_SIZE bsize, const PREDICTION_MODE this_mode, int *rs2,
     int rate_mv, const BUFFER_SET *ctx, int *out_rate_mv, uint8_t **preds0,
     uint8_t **preds1, int16_t *residual1, int16_t *diff10, int *strides,
-    int mi_row, int mi_col, int mode_rate, int64_t ref_best_rd,
+    int mi_row, int mi_col, int mode_rate, int64_t rd_thresh,
     int *calc_pred_masked_compound, int32_t *comp_rate, int64_t *comp_dist,
     int64_t *const comp_model_rd, const int64_t comp_best_model_rd,
     int64_t *const comp_model_rd_cur) {
@@ -7804,7 +7817,7 @@
   // is unlikely to be the best mode considering the transform rd cost and other
   // mode overhead cost
   int64_t mode_rd = RDCOST(x->rdmult, *rs2 + mode_rate, 0);
-  if (mode_rd > ref_best_rd) {
+  if (mode_rd > rd_thresh) {
     *comp_model_rd_cur = INT64_MAX;
     return INT64_MAX;
   }
@@ -7851,7 +7864,9 @@
       *comp_model_rd_cur = INT64_MAX;
       return INT64_MAX;
     }
-    rd = estimate_yrd_for_sb(cpi, bsize, x, ref_best_rd, &rd_stats);
+    const int64_t tmp_mode_rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv, 0);
+    const int64_t tmp_rd_thresh = rd_thresh - tmp_mode_rd;
+    rd = estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &rd_stats);
     if (rd != INT64_MAX) {
       rd =
           RDCOST(x->rdmult, *rs2 + *out_rate_mv + rd_stats.rate, rd_stats.dist);
@@ -8972,14 +8987,21 @@
     }
 
     RD_STATS rd_stats;
-    rd = estimate_yrd_for_sb(cpi, bsize, x, ref_best_rd, &rd_stats);
+    const int64_t rd_thresh = get_rd_thresh_from_best_rd(
+        ref_best_rd, (1 << INTER_INTRA_RD_THRESH_SHIFT),
+        INTER_INTRA_RD_THRESH_SCALE);
+    const int64_t mode_rd = RDCOST(x->rdmult, *rate_mv + rmode + rwedge, 0);
+    const int64_t tmp_rd_thresh = rd_thresh - mode_rd;
+    rd = estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &rd_stats);
     if (rd != INT64_MAX) {
       rd = RDCOST(x->rdmult, *rate_mv + rmode + rd_stats.rate + rwedge,
                   rd_stats.dist);
     }
     best_interintra_rd = rd;
     if (ref_best_rd < INT64_MAX &&
-        ((best_interintra_rd >> 4) * 9) > ref_best_rd) {
+        ((((best_interintra_rd >> INTER_INTRA_RD_THRESH_SHIFT) *
+           INTER_INTRA_RD_THRESH_SCALE) > ref_best_rd) ||
+         (best_interintra_rd == INT64_MAX))) {
       return -1;
     }
   }
@@ -9069,14 +9091,18 @@
       }
       // Evaluate closer to true rd
       RD_STATS rd_stats;
-      rd = estimate_yrd_for_sb(cpi, bsize, x, ref_best_rd, &rd_stats);
+      const int64_t mode_rd =
+          RDCOST(x->rdmult, rmode + tmp_rate_mv + rwedge, 0);
+      const int64_t tmp_rd_thresh = best_interintra_rd_nowedge - mode_rd;
+      rd = estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &rd_stats);
       if (rd != INT64_MAX) {
         rd = RDCOST(x->rdmult, rmode + tmp_rate_mv + rwedge + rd_stats.rate,
                     rd_stats.dist);
       }
       best_interintra_rd_wedge = rd;
       if ((!cpi->oxcf.enable_smooth_interintra ||
-           cpi->sf.disable_smooth_interintra) &&
+           cpi->sf.disable_smooth_interintra ||
+           best_interintra_rd_nowedge == INT64_MAX) &&
           best_interintra_rd_wedge == INT64_MAX)
         return -1;
       if (best_interintra_rd_wedge < best_interintra_rd_nowedge) {
@@ -9687,7 +9713,8 @@
     int mi_row, int_mv *cur_mv, int mode_search_mask, int masked_compound_used,
     const BUFFER_SET *orig_dst, const BUFFER_SET *tmp_dst,
     CompoundTypeRdBuffers *buffers, int *rate_mv, int64_t *rd,
-    RD_STATS *rd_stats, int64_t ref_best_rd, int *is_luma_interp_done) {
+    RD_STATS *rd_stats, int64_t ref_best_rd, int *is_luma_interp_done,
+    int64_t rd_thresh) {
   const AV1_COMMON *cm = &cpi->common;
   MACROBLOCKD *xd = &x->e_mbd;
   MB_MODE_INFO *mbmi = xd->mi[0];
@@ -9803,10 +9830,12 @@
         mbmi->interinter_comp.type = COMPOUND_AVERAGE;
         mbmi->compound_idx = 1;
         restore_dst_buf(xd, *orig_dst, 1);
-        RD_STATS est_rd_stats;
-        const int64_t est_rd_ =
-            estimate_yrd_for_sb(cpi, bsize, x, ref_best_rd, &est_rd_stats);
         rs2 = masked_type_cost[COMPOUND_AVERAGE];
+        RD_STATS est_rd_stats;
+        const int64_t mode_rd = RDCOST(x->rdmult, rs2 + *rate_mv, 0);
+        const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd;
+        const int64_t est_rd_ =
+            estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &est_rd_stats);
         if (est_rd_ != INT64_MAX) {
           best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
                                est_rd_stats.dist);
@@ -9818,10 +9847,13 @@
         }
         restore_dst_buf(xd, *tmp_dst, 1);
       } else {
-        RD_STATS est_rd_stats;
-        const int64_t est_rd_ =
-            estimate_yrd_for_sb(cpi, bsize, x, ref_best_rd, &est_rd_stats);
         rs2 = masked_type_cost[COMPOUND_DISTWTD];
+        RD_STATS est_rd_stats;
+        const int64_t mode_rd = RDCOST(x->rdmult, rs2 + *rate_mv, 0);
+        const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd;
+        const int64_t est_rd_ =
+            estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &est_rd_stats);
+
         if (est_rd_ != INT64_MAX) {
           best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
                                est_rd_stats.dist);
@@ -9852,8 +9884,9 @@
                                           bsize, AOM_PLANE_Y, AOM_PLANE_Y);
             if (cur_type == COMPOUND_AVERAGE) *is_luma_interp_done = 1;
             RD_STATS est_rd_stats;
-            const int64_t est_rd =
-                estimate_yrd_for_sb(cpi, bsize, x, ref_best_rd, &est_rd_stats);
+            const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd;
+            const int64_t est_rd = estimate_yrd_for_sb(
+                cpi, bsize, x, tmp_rd_thresh, &est_rd_stats);
             if (comp_rate[cur_type] != INT_MAX) {
               assert(comp_rate[cur_type] == est_rd_stats.rate);
               assert(comp_dist[cur_type] == est_rd_stats.dist);
@@ -9896,6 +9929,7 @@
         if (((*rd / cpi->max_comp_type_rd_threshold_div) *
              cpi->max_comp_type_rd_threshold_mul) < ref_best_rd) {
           const COMPOUND_TYPE compound_type = mbmi->interinter_comp.type;
+          const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh);
 
           if (!((compound_type == COMPOUND_WEDGE &&
                  !enable_wedge_interinter_search(x, cpi)) ||
@@ -9905,7 +9939,7 @@
                 cpi, x, cur_mv, bsize, this_mode, &rs2, *rate_mv, orig_dst,
                 &tmp_rate_mv, preds0, preds1, buffers->residual1,
                 buffers->diff10, strides, mi_row, mi_col, rd_stats->rate,
-                ref_best_rd, &calc_pred_masked_compound, comp_rate, comp_dist,
+                tmp_rd_thresh, &calc_pred_masked_compound, comp_rate, comp_dist,
                 comp_model_rd, comp_best_model_rd, &comp_model_rd_cur);
         }
       }
@@ -10465,13 +10499,19 @@
           }
 
           int64_t best_rd_compound;
+          int64_t rd_thresh;
+          const int comp_type_rd_shift = COMP_TYPE_RD_THRESH_SHIFT;
+          const int comp_type_rd_scale =
+              COMP_TYPE_RD_THRESH_SCALE + 2 * do_two_loop_comp_search;
+          rd_thresh = get_rd_thresh_from_best_rd(
+              ref_best_rd, (1 << comp_type_rd_shift), comp_type_rd_scale);
           compmode_interinter_cost = compound_type_rd(
               cpi, x, bsize, mi_col, mi_row, cur_mv,
               mode_search_mask[comp_loop_idx], masked_compound_used, &orig_dst,
               &tmp_dst, rd_buffers, &rate_mv, &best_rd_compound, rd_stats,
-              ref_best_rd, &is_luma_interp_done);
+              ref_best_rd, &is_luma_interp_done, rd_thresh);
           if (ref_best_rd < INT64_MAX &&
-              (best_rd_compound >> 4) * (11 + 2 * do_two_loop_comp_search) >
+              (best_rd_compound >> comp_type_rd_shift) * comp_type_rd_scale >
                   ref_best_rd) {
             restore_dst_buf(xd, orig_dst, num_planes);
             continue;
diff --git a/av1/encoder/rdopt.h b/av1/encoder/rdopt.h
index 4a92ba9..338ccc6 100644
--- a/av1/encoder/rdopt.h
+++ b/av1/encoder/rdopt.h
@@ -31,6 +31,10 @@
 #define DEFAULT_CHROMA_INTERP_SKIP_FLAG 2
 #define DEFAULT_INTERP_SKIP_FLAG \
   (DEFAULT_LUMA_INTERP_SKIP_FLAG | DEFAULT_CHROMA_INTERP_SKIP_FLAG)
+#define INTER_INTRA_RD_THRESH_SCALE 9  // multiplies (interintra rd >> SHIFT)
+#define INTER_INTRA_RD_THRESH_SHIFT 4  // shift for interintra rd pruning
+#define COMP_TYPE_RD_THRESH_SCALE 11  // base scale for (compound rd >> SHIFT)
+#define COMP_TYPE_RD_THRESH_SHIFT 4  // shift for compound type pruning
 
 struct TileInfo;
 struct macroblock;