Add skip rd based gating for tx in compound type rd

Skip rd based gating is introduced for estimate_yrd_for_sb()
calls from compound type rd.
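
In outline: a skip rd for the current compound type (mode rate plus
the SSE of the blended prediction, in RDCOST terms) is checked against
the best skip rd seen so far via check_txfm_eval(), and
estimate_yrd_for_sb() is evaluated only when that gate passes. A
condensed sketch of the pattern added in compound_type.c (names as in
the patch; the surrounding search loop is elided):

  int64_t est_rd = INT64_MAX;
  int eval_txfm = 1;
  if (cpi->sf.inter_sf.txfm_rd_gate_level) {
    const int64_t skip_rd = RDCOST(x->rdmult, rs2 + *rate_mv, sse_y);
    eval_txfm = check_txfm_eval(x, bsize, ref_skip_rd, skip_rd,
                                cpi->sf.inter_sf.txfm_rd_gate_level,
                                /*is_luma_only=*/1);
  }
  if (eval_txfm)
    est_rd = estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &est_rd_stats);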

            Encode Time             Quality loss
cpu-used     Reduction     avg.psnr    ovr.psnr    ssim
   3           0.507%      0.0023%     0.0027%    0.0159%
   4           1.617%      0.0422%     0.0561%    0.0194%
   5           2.961%      0.2044%     0.2238%    0.1039%

STATS_CHANGED

Change-Id: Iaefebd7f88ceeeb35f28a9249751957d7661780b
diff --git a/av1/encoder/compound_type.c b/av1/encoder/compound_type.c
index 14d30b0..83ecdee 100644
--- a/av1/encoder/compound_type.c
+++ b/av1/encoder/compound_type.c
@@ -20,7 +20,8 @@
 typedef int64_t (*pick_interinter_mask_type)(
     const AV1_COMP *const cpi, MACROBLOCK *x, const BLOCK_SIZE bsize,
     const uint8_t *const p0, const uint8_t *const p1,
-    const int16_t *const residual1, const int16_t *const diff10);
+    const int16_t *const residual1, const int16_t *const diff10,
+    uint64_t *best_sse);
 
 // Checks if characteristics of search match
 static INLINE int is_comp_rd_match(const AV1_COMP *const cpi,
@@ -184,7 +185,7 @@
                           const int16_t *const residual1,
                           const int16_t *const diff10,
                           int8_t *const best_wedge_sign,
-                          int8_t *const best_wedge_index) {
+                          int8_t *const best_wedge_index, uint64_t *best_sse) {
   const MACROBLOCKD *const xd = &x->e_mbd;
   const struct buf_2d *const src = &x->plane[0].src;
   const int bw = block_size_wide[bsize];
@@ -247,6 +248,7 @@
       *best_wedge_index = wedge_index;
       *best_wedge_sign = wedge_sign;
       best_rd = rd;
+      *best_sse = sse;
     }
   }
 
@@ -255,13 +257,11 @@
 }
 
 // Choose the best wedge index for the specified sign
-static int64_t pick_wedge_fixed_sign(const AV1_COMP *const cpi,
-                                     const MACROBLOCK *const x,
-                                     const BLOCK_SIZE bsize,
-                                     const int16_t *const residual1,
-                                     const int16_t *const diff10,
-                                     const int8_t wedge_sign,
-                                     int8_t *const best_wedge_index) {
+static int64_t pick_wedge_fixed_sign(
+    const AV1_COMP *const cpi, const MACROBLOCK *const x,
+    const BLOCK_SIZE bsize, const int16_t *const residual1,
+    const int16_t *const diff10, const int8_t wedge_sign,
+    int8_t *const best_wedge_index, uint64_t *best_sse) {
   const MACROBLOCKD *const xd = &x->e_mbd;
 
   const int bw = block_size_wide[bsize];
@@ -290,6 +290,7 @@
     if (rd < best_rd) {
       *best_wedge_index = wedge_index;
       best_rd = rd;
+      *best_sse = sse;
     }
   }
   return best_rd -
@@ -299,7 +300,8 @@
 static int64_t pick_interinter_wedge(
     const AV1_COMP *const cpi, MACROBLOCK *const x, const BLOCK_SIZE bsize,
     const uint8_t *const p0, const uint8_t *const p1,
-    const int16_t *const residual1, const int16_t *const diff10) {
+    const int16_t *const residual1, const int16_t *const diff10,
+    uint64_t *best_sse) {
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = xd->mi[0];
   const int bw = block_size_wide[bsize];
@@ -314,10 +316,10 @@
   if (cpi->sf.inter_sf.fast_wedge_sign_estimate) {
     wedge_sign = estimate_wedge_sign(cpi, x, bsize, p0, bw, p1, bw);
     rd = pick_wedge_fixed_sign(cpi, x, bsize, residual1, diff10, wedge_sign,
-                               &wedge_index);
+                               &wedge_index, best_sse);
   } else {
     rd = pick_wedge(cpi, x, bsize, p0, residual1, diff10, &wedge_sign,
-                    &wedge_index);
+                    &wedge_index, best_sse);
   }
 
   mbmi->interinter_comp.wedge_sign = wedge_sign;
@@ -330,7 +332,8 @@
                                    const uint8_t *const p0,
                                    const uint8_t *const p1,
                                    const int16_t *const residual1,
-                                   const int16_t *const diff10) {
+                                   const int16_t *const diff10,
+                                   uint64_t *best_sse) {
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = xd->mi[0];
   const int bw = block_size_wide[bsize];
@@ -368,6 +371,7 @@
     if (rd0 < best_rd) {
       best_mask_type = cur_mask_type;
       best_rd = rd0;
+      *best_sse = sse;
     }
   }
   mbmi->interinter_comp.mask_type = best_mask_type;
@@ -407,8 +411,9 @@
   aom_subtract_block(bh, bw, diff10, bw, p1, bw, p0, bw);
 #endif
   int8_t wedge_index = -1;
-  int64_t rd =
-      pick_wedge_fixed_sign(cpi, x, bsize, residual1, diff10, 0, &wedge_index);
+  uint64_t sse;
+  int64_t rd = pick_wedge_fixed_sign(cpi, x, bsize, residual1, diff10, 0,
+                                     &wedge_index, &sse);
 
   mbmi->interintra_wedge_index = wedge_index;
   return rd;
@@ -989,7 +994,7 @@
     int mode_rate, int64_t rd_thresh, int *calc_pred_masked_compound,
     int32_t *comp_rate, int64_t *comp_dist, int32_t *comp_model_rate,
     int64_t *comp_model_dist, const int64_t comp_best_model_rd,
-    int64_t *const comp_model_rd_cur, int *comp_rs2) {
+    int64_t *const comp_model_rd_cur, int *comp_rs2, int64_t ref_skip_rd) {
   const AV1_COMMON *const cm = &cpi->common;
   MACROBLOCKD *xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = xd->mi[0];
@@ -1030,10 +1035,13 @@
   // Function pointer to pick the appropriate mask
   // compound_type == COMPOUND_WEDGE, calls pick_interinter_wedge()
   // compound_type == COMPOUND_DIFFWTD, calls pick_interinter_seg()
+  uint64_t cur_sse = UINT64_MAX;
   best_rd_cur = pick_interinter_mask[compound_type - COMPOUND_WEDGE](
-      cpi, x, bsize, *preds0, *preds1, residual1, diff10);
+      cpi, x, bsize, *preds0, *preds1, residual1, diff10, &cur_sse);
   *rs2 += get_interinter_compound_mask_rate(x, mbmi);
   best_rd_cur += RDCOST(x->rdmult, *rs2 + rate_mv, 0);
+  assert(cur_sse != UINT64_MAX);
+  int64_t skip_rd_cur = RDCOST(x->rdmult, *rs2 + rate_mv, cur_sse);
 
   // Although the true rate_mv might be different after motion search, it
   // is unlikely to be the best mode considering the transform rd cost and other
@@ -1044,6 +1052,18 @@
     return INT64_MAX;
   }
 
+  // Check if the mode is good enough based on skip rd
+  // TODO(nithya): Handle wedge_newmv_search if extending for lower speed
+  // setting
+  if (cpi->sf.inter_sf.txfm_rd_gate_level) {
+    int eval_txfm = check_txfm_eval(x, bsize, ref_skip_rd, skip_rd_cur,
+                                    cpi->sf.inter_sf.txfm_rd_gate_level, 1);
+    if (!eval_txfm) {
+      *comp_model_rd_cur = INT64_MAX;
+      return INT64_MAX;
+    }
+  }
+
   // Compute cost if matching record not found, else, reuse data
   if (comp_rate[compound_type] == INT_MAX) {
     // Check whether new MV search for wedge is to be done
@@ -1081,7 +1101,7 @@
 
       tmp_rd = pick_interinter_mask[compound_type - COMPOUND_WEDGE](
           cpi, x, bsize, *tmp_preds0, *tmp_preds1, tmp_buf.residual1,
-          tmp_buf.diff10);
+          tmp_buf.diff10, &cur_sse);
       // we can reuse rs2 here
       tmp_rd += RDCOST(x->rdmult, *rs2 + *out_rate_mv, 0);
 
@@ -1166,7 +1186,8 @@
                          const BUFFER_SET *tmp_dst,
                          const CompoundTypeRdBuffers *buffers, int *rate_mv,
                          int64_t *rd, RD_STATS *rd_stats, int64_t ref_best_rd,
-                         int *is_luma_interp_done, int64_t rd_thresh) {
+                         int64_t ref_skip_rd, int *is_luma_interp_done,
+                         int64_t rd_thresh) {
   const AV1_COMMON *cm = &cpi->common;
   MACROBLOCKD *xd = &x->e_mbd;
   MB_MODE_INFO *mbmi = xd->mi[0];
@@ -1275,6 +1296,7 @@
       }
       restore_dst_buf(xd, *tmp_dst, 1);
     } else {
+      int64_t sse_y[COMPOUND_DISTWTD + 1];
       // Calculate model_rd for COMPOUND_AVERAGE and COMPOUND_DISTWTD
       for (int comp_type = COMPOUND_AVERAGE; comp_type <= COMPOUND_DISTWTD;
            comp_type++) {
@@ -1287,6 +1309,7 @@
         est_rate[comp_type] += masked_type_cost[comp_type];
         comp_model_rate[comp_type] = est_rate[comp_type];
         comp_model_dist[comp_type] = est_dist[comp_type];
+        sse_y[comp_type] = x->pred_sse[xd->mi[0]->ref_frame[0]];
         if (comp_type == COMPOUND_AVERAGE) {
           *is_luma_interp_done = 1;
           restore_dst_buf(xd, *tmp_dst, 1);
@@ -1302,8 +1325,19 @@
       RD_STATS est_rd_stats;
       const int64_t mode_rd = RDCOST(x->rdmult, rs2 + *rate_mv, 0);
       const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd;
-      const int64_t est_rd_ =
-          estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &est_rd_stats);
+      int64_t est_rd_ = INT64_MAX;
+      int eval_txfm = 1;
+      // Check if the mode is good enough based on skip rd
+      if (cpi->sf.inter_sf.txfm_rd_gate_level) {
+        int64_t skip_rd = RDCOST(x->rdmult, rs2 + *rate_mv, sse_y[best_type]);
+        eval_txfm = check_txfm_eval(x, bsize, ref_skip_rd, skip_rd,
+                                    cpi->sf.inter_sf.txfm_rd_gate_level, 1);
+      }
+      // Evaluate further if skip rd is low enough
+      if (eval_txfm) {
+        est_rd_ =
+            estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &est_rd_stats);
+      }
 
       if (est_rd_ != INT64_MAX) {
         best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
@@ -1348,8 +1382,21 @@
           // Compute RD cost for the current type
           RD_STATS est_rd_stats;
           const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd;
-          const int64_t est_rd =
-              estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &est_rd_stats);
+          int64_t est_rd = INT64_MAX;
+          int eval_txfm = 1;
+          // Check if the mode is good enough based on skip rd
+          if (cpi->sf.inter_sf.txfm_rd_gate_level) {
+            int64_t sse_y = compute_sse_plane(x, xd, PLANE_TYPE_Y, bsize);
+            int64_t skip_rd = RDCOST(x->rdmult, rs2 + *rate_mv, sse_y);
+            eval_txfm = check_txfm_eval(x, bsize, ref_skip_rd, skip_rd,
+                                        cpi->sf.inter_sf.txfm_rd_gate_level, 1);
+          }
+          // Evaluate further if skip rd is low enough
+          if (eval_txfm) {
+            est_rd = estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh,
+                                         &est_rd_stats);
+          }
+
           if (est_rd != INT64_MAX) {
             best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
                                  est_rd_stats.dist);
@@ -1393,7 +1440,8 @@
             &tmp_rate_mv, preds0, preds1, buffers->residual1, buffers->diff10,
             strides, rd_stats->rate, tmp_rd_thresh, &calc_pred_masked_compound,
             comp_rate, comp_dist, comp_model_rate, comp_model_dist,
-            best_type_stats.comp_best_model_rd, &comp_model_rd_cur, comp_rs2);
+            best_type_stats.comp_best_model_rd, &comp_model_rd_cur, comp_rs2,
+            ref_skip_rd);
       }
     }
     // Update stats for best compound type
diff --git a/av1/encoder/compound_type.h b/av1/encoder/compound_type.h
index c1ef298..f2bd857 100644
--- a/av1/encoder/compound_type.h
+++ b/av1/encoder/compound_type.h
@@ -38,7 +38,8 @@
                          const BUFFER_SET *tmp_dst,
                          const CompoundTypeRdBuffers *buffers, int *rate_mv,
                          int64_t *rd, RD_STATS *rd_stats, int64_t ref_best_rd,
-                         int *is_luma_interp_done, int64_t rd_thresh);
+                         int64_t ref_skip_rd, int *is_luma_interp_done,
+                         int64_t rd_thresh);
 
 #ifdef __cplusplus
 }  // extern "C"
diff --git a/av1/encoder/model_rd.h b/av1/encoder/model_rd.h
index 1ca636d..c353c8f 100644
--- a/av1/encoder/model_rd.h
+++ b/av1/encoder/model_rd.h
@@ -70,6 +70,21 @@
   return sse;
 }
 
+static AOM_INLINE int64_t compute_sse_plane(MACROBLOCK *x, MACROBLOCKD *xd,
+                                            int plane, const BLOCK_SIZE bsize) {
+  struct macroblockd_plane *const pd = &xd->plane[plane];
+  const BLOCK_SIZE plane_bsize =
+      get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
+  int bw, bh;
+  const struct macroblock_plane *const p = &x->plane[plane];
+  get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL, &bw,
+                     &bh);
+
+  int64_t sse = calculate_sse(xd, p, pd, bw, bh);
+
+  return sse;
+}
+
 static AOM_INLINE void model_rd_from_sse(const AV1_COMP *const cpi,
                                          const MACROBLOCK *const x,
                                          BLOCK_SIZE plane_bsize, int plane,
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 1bf7b81..4e19e24 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -1502,11 +1502,9 @@
       rd_stats->rdcost = est_rd;
       if (rd_stats->rdcost < *best_est_rd) {
         *best_est_rd = rd_stats->rdcost;
-        int64_t skip_rdy = INT64_MAX;
-        if (cpi->sf.inter_sf.txfm_rd_gate_level) {
-          skip_rdy = RDCOST(x->rdmult, mode_rate, sse_y);
-        }
-        ref_skip_rd[1] = skip_rdy;
+        ref_skip_rd[1] = cpi->sf.inter_sf.txfm_rd_gate_level
+                             ? RDCOST(x->rdmult, mode_rate, sse_y)
+                             : INT64_MAX;
       }
       if (cm->current_frame.reference_mode == SINGLE_REFERENCE) {
         if (!is_comp_pred) {
@@ -1535,7 +1533,7 @@
         skip_rd = RDCOST(x->rdmult, rd_stats->rate, curr_sse);
         skip_rdy = RDCOST(x->rdmult, rd_stats->rate, sse_y);
         int eval_txfm = check_txfm_eval(x, bsize, ref_skip_rd[0], skip_rd,
-                                        cpi->sf.inter_sf.txfm_rd_gate_level);
+                                        cpi->sf.inter_sf.txfm_rd_gate_level, 0);
         if (!eval_txfm) continue;
       }
 
@@ -2491,7 +2489,7 @@
       compmode_interinter_cost = av1_compound_type_rd(
           cpi, x, bsize, cur_mv, mode_search_mask, masked_compound_used,
           &orig_dst, &tmp_dst, rd_buffers, &rate_mv, &best_rd_compound,
-          rd_stats, ref_best_rd, &is_luma_interp_done, rd_thresh);
+          rd_stats, ref_best_rd, skip_rd[1], &is_luma_interp_done, rd_thresh);
       if (ref_best_rd < INT64_MAX &&
           (best_rd_compound >> comp_type_rd_shift) * comp_type_rd_scale >
               ref_best_rd) {
@@ -4633,7 +4631,7 @@
         skip_rd = RDCOST(x->rdmult, mode_rate, curr_sse);
         int eval_txfm =
             check_txfm_eval(x, bsize, search_state.best_skip_rd[0], skip_rd,
-                            cpi->sf.inter_sf.txfm_rd_gate_level);
+                            cpi->sf.inter_sf.txfm_rd_gate_level, 0);
         if (!eval_txfm) continue;
       }
 
diff --git a/av1/encoder/rdopt_utils.h b/av1/encoder/rdopt_utils.h
index 4a4344e..4d0b05e 100644
--- a/av1/encoder/rdopt_utils.h
+++ b/av1/encoder/rdopt_utils.h
@@ -329,21 +329,29 @@
 
 static INLINE int check_txfm_eval(MACROBLOCK *const x, BLOCK_SIZE bsize,
                                   int64_t best_skip_rd, int64_t skip_rd,
-                                  int level) {
+                                  int level, int is_luma_only) {
   int eval_txfm = 1;
   // Derive aggressiveness factor for gating the transform search
-  // Lower value indicates more aggresiveness. Be more conservative (high value)
-  // for (i) low quantizers (ii) regions where prediction is poor
-  const int scale[3] = { INT_MAX, 3, 2 };
+  // Lower value indicates more aggressiveness. Be more conservative (high
+  // value) for (i) low quantizers (ii) regions where prediction is poor
+  const int scale[4] = { INT_MAX, 3, 3, 2 };
+
   int aggr_factor =
       AOMMAX(1, ((MAXQ - x->qindex) * 2 + QINDEX_RANGE / 2) >> QINDEX_BITS);
   if (best_skip_rd >
       (x->source_variance << (num_pels_log2_lookup[bsize] + RDDIV_BITS)))
     aggr_factor *= scale[level];
 
-  int64_t rd_thresh = (best_skip_rd == INT64_MAX)
-                          ? best_skip_rd
-                          : (int64_t)(best_skip_rd * aggr_factor);
+  // Be more conservative for luma-only cases (called from compound type rd):
+  // best_skip_rd is computed after interpolation filter search, whereas
+  // skip_rd is computed before it (and, for WEDGE/DIFFWTD, from 8-bit blended
+  // prediction signals rather than 16-bit ones)
+  const int luma_mul[4] = { INT_MAX, 16, 15, 11 };
+  int mul_factor = is_luma_only ? luma_mul[level] : 8;
+  int64_t rd_thresh =
+      (best_skip_rd == INT64_MAX)
+          ? best_skip_rd
+          : (int64_t)(best_skip_rd * aggr_factor * mul_factor >> 3);
   if (skip_rd > rd_thresh) eval_txfm = 0;
   return eval_txfm;
 }
diff --git a/av1/encoder/speed_features.c b/av1/encoder/speed_features.c
index 24639d8..0024c66 100644
--- a/av1/encoder/speed_features.c
+++ b/av1/encoder/speed_features.c
@@ -510,7 +510,7 @@
     sf->inter_sf.alt_ref_search_fp = 1;
     sf->inter_sf.prune_ref_mv_idx_search = 1;
     sf->inter_sf.txfm_rd_gate_level =
-        (boosted || cm->allow_screen_content_tools) ? 0 : 1;
+        (boosted || cm->allow_screen_content_tools) ? 0 : 2;
 
     sf->inter_sf.disable_smooth_interintra = 1;
 
@@ -577,7 +577,7 @@
     sf->inter_sf.disable_obmc = 1;
     sf->inter_sf.disable_onesided_comp = 1;
     sf->inter_sf.txfm_rd_gate_level =
-        (boosted || cm->allow_screen_content_tools) ? 0 : 2;
+        (boosted || cm->allow_screen_content_tools) ? 0 : 3;
     sf->inter_sf.prune_inter_modes_if_skippable = 1;
 
     sf->lpf_sf.lpf_pick = LPF_PICK_FROM_FULL_IMAGE_NON_DUAL;