Introduce skip flag prediction in super_block_yrd()

Early prediction of the skip flag based on residual statistics is
introduced in super_block_yrd() for inter blocks. For the speed = 3 and
speed = 4 presets, encode time reductions of 0.3% and 0.9% respectively
are seen (averaged across multiple test cases), with BD-rate impacts of
-0.01% and 0.01%.
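
In outline, the predictor works as follows (an illustrative sketch, not
the committed code; the function name and signature here are
hypothetical):

    #include <stdint.h>

    // Return 1 if the inter residual looks cheap enough that skipping
    // its coding is likely the optimal RD decision.
    static int sketch_predict_skip(int64_t sse, int bw, int bh,
                                   int16_t dc_q) {
      // Stage 1: mean squared error of the residual vs. a threshold
      // derived from the DC quantizer step (>> 3 removes the transform
      // upscaling factor).
      const int64_t mse = sse / bw / bh;
      const int16_t norm_dc_q = dc_q >> 3;
      const int64_t mse_thresh = (int64_t)norm_dc_q * norm_dc_q / 8;
      if (mse > mse_thresh) return 0;
      // Stage 2 (elided here): forward transform each max-size tx block
      // and require every DCT coefficient to stay below a per-block-size
      // fraction of the corresponding quantizer step.
      return 1;
    }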

Change-Id: I27743f026f6afc7d99eb294934b91f2009324a23
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index b1bc873..b19a169 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -3765,6 +3765,164 @@
   x->tx_split_prune_flag = 0;
 }
 
+// Thresholds scaled as origin_threshold * 128 / 100; rows are indexed by
+// bit depth (8, 10, 12), columns by block size.
+static const uint32_t skip_pred_threshold[3][BLOCK_SIZES_ALL] = {
+  {
+      64, 64, 64, 70, 60, 60, 68, 68, 68, 68, 68,
+      68, 68, 68, 68, 68, 64, 64, 70, 70, 68, 68,
+  },
+  {
+      88, 88, 88, 86, 87, 87, 68, 68, 68, 68, 68,
+      68, 68, 68, 68, 68, 88, 88, 86, 86, 68, 68,
+  },
+  {
+      90, 93, 93, 90, 93, 93, 74, 74, 74, 74, 74,
+      74, 74, 74, 74, 74, 90, 90, 90, 90, 74, 74,
+  },
+};
+
+// Lookup table for predict_skip_flag, equivalent to:
+// int max_tx_size = max_txsize_rect_lookup[bsize];
+// if (tx_size_high[max_tx_size] > 16 || tx_size_wide[max_tx_size] > 16)
+//   max_tx_size = AOMMIN(max_txsize_lookup[bsize], TX_16X16);
+static const TX_SIZE max_predict_sf_tx_size[BLOCK_SIZES_ALL] = {
+  TX_4X4,   TX_4X8,   TX_8X4,   TX_8X8,   TX_8X16,  TX_16X8,
+  TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16,
+  TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_4X16,  TX_16X4,
+  TX_8X8,   TX_8X8,   TX_16X16, TX_16X16,
+};
+
+// Uses simple features computed from the DCT coefficients of the residual
+// to quickly predict whether the optimal RD decision is to skip encoding
+// it. The SSE of the residual is stored in *dist.
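+// Two checks are applied: (1) the mean squared error of the residual must
+// not exceed a quantizer-derived threshold, and (2) every DCT coefficient
+// of each max-size transform block must stay below a per-block-size
+// fraction of the corresponding quantizer step.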
+static int predict_skip_flag(MACROBLOCK *x, BLOCK_SIZE bsize, int64_t *dist,
+                             int reduced_tx_set) {
+  const int bw = block_size_wide[bsize];
+  const int bh = block_size_high[bsize];
+  const MACROBLOCKD *xd = &x->e_mbd;
+  const int16_t dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd);
+
+  *dist = pixel_diff_dist(x, 0, 0, 0, bsize, bsize, NULL);
+
+  const int64_t mse = *dist / bw / bh;
+  // The normalized quantizer takes the transform upscaling factor (8 for
+  // tx sizes smaller than 32x32) into account.
+  const int16_t normalized_dc_q = dc_q >> 3;
+  const int64_t mse_thresh = (int64_t)normalized_dc_q * normalized_dc_q / 8;
+  // Predict not to skip when the mse is larger than the threshold.
+  if (mse > mse_thresh) return 0;
+
+  const int max_tx_size = max_predict_sf_tx_size[bsize];
+  const int tx_h = tx_size_high[max_tx_size];
+  const int tx_w = tx_size_wide[max_tx_size];
+  DECLARE_ALIGNED(32, tran_low_t, coefs[32 * 32]);
+  TxfmParam param;
+  param.tx_type = DCT_DCT;
+  param.tx_size = max_tx_size;
+  param.bd = xd->bd;
+  param.is_hbd = is_cur_buf_hbd(xd);
+  param.lossless = 0;
+  param.tx_set_type = av1_get_ext_tx_set_type(
+      param.tx_size, is_inter_block(xd->mi[0]), reduced_tx_set);
+  const int bd_idx = (xd->bd == 8) ? 0 : ((xd->bd == 10) ? 1 : 2);
+  const uint32_t max_qcoef_thresh = skip_pred_threshold[bd_idx][bsize];
+  const int16_t *src_diff = x->plane[0].src_diff;
+  const int n_coeff = tx_w * tx_h;
+  const int16_t ac_q = av1_ac_quant_QTX(x->qindex, 0, xd->bd);
+  const uint32_t dc_thresh = max_qcoef_thresh * dc_q;
+  const uint32_t ac_thresh = max_qcoef_thresh * ac_q;
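+  // The table thresholds are pre-scaled by 128 / 100 and the coefficients
+  // below are scaled by 128 (<< 7), so the comparison is effectively
+  // |coef| >= (origin_threshold / 100) * q.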
+  for (int row = 0; row < bh; row += tx_h) {
+    for (int col = 0; col < bw; col += tx_w) {
+      av1_fwd_txfm(src_diff + col, coefs, bw, &param);
+      // Operating on TX domain, not pixels; we want the QTX quantizers
+      const uint32_t dc_coef = (((uint32_t)abs(coefs[0])) << 7);
+      if (dc_coef >= dc_thresh) return 0;
+      for (int i = 1; i < n_coeff; ++i) {
+        const uint32_t ac_coef = (((uint32_t)abs(coefs[i])) << 7);
+        if (ac_coef >= ac_thresh) return 0;
+      }
+    }
+    src_diff += tx_h * bw;
+  }
+  return 1;
+}
+
+#if CONFIG_ONE_PASS_SVM
+static void calc_regional_sse(MACROBLOCK *x, BLOCK_SIZE bsize, int64_t dist,
+                              RD_STATS *rd_stats) {
+  // TODO(chiyotsai@google.com): Regional SSEs are only needed when the
+  // partition being evaluated is PARTITION_NONE.
+  const int bw = block_size_wide[bsize];
+  const int bh = block_size_high[bsize];
+  const int bw_mi = bw >> tx_size_wide_log2[0];
+  const int bh_mi = bh >> tx_size_high_log2[0];
+  const BLOCK_SIZE split_size = get_partition_subsize(bsize, PARTITION_SPLIT);
+  int64_t dist_0, dist_1, dist_2, dist_3;
+  MACROBLOCKD *xd = &x->e_mbd;
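+  // SSEs of the four quadrants of the block: top-left, top-right,
+  // bottom-left, bottom-right.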
+  dist_0 = pixel_diff_dist(x, AOM_PLANE_Y, 0, 0, bsize, split_size, NULL);
+  dist_1 =
+      pixel_diff_dist(x, AOM_PLANE_Y, 0, bw_mi / 2, bsize, split_size, NULL);
+  dist_2 =
+      pixel_diff_dist(x, AOM_PLANE_Y, bh_mi / 2, 0, bsize, split_size, NULL);
+  dist_3 = pixel_diff_dist(x, AOM_PLANE_Y, bh_mi / 2, bw_mi / 2, bsize,
+                           split_size, NULL);
+
+  if (is_cur_buf_hbd(xd)) {
+    dist = ROUND_POWER_OF_TWO(dist, (xd->bd - 8) * 2);
+    dist_0 = ROUND_POWER_OF_TWO(dist_0, (xd->bd - 8) * 2);
+    dist_1 = ROUND_POWER_OF_TWO(dist_1, (xd->bd - 8) * 2);
+    dist_2 = ROUND_POWER_OF_TWO(dist_2, (xd->bd - 8) * 2);
+    dist_3 = ROUND_POWER_OF_TWO(dist_3, (xd->bd - 8) * 2);
+  }
+  const int scaling_factor = MAX_MIB_SIZE * MAX_MIB_SIZE;
+  rd_stats->y_sse = (dist << 4);
+  rd_stats->sse_0 = (dist_0 << 4) * scaling_factor;
+  rd_stats->sse_1 = (dist_1 << 4) * scaling_factor;
+  rd_stats->sse_2 = (dist_2 << 4) * scaling_factor;
+  rd_stats->sse_3 = (dist_3 << 4) * scaling_factor;
+  av1_reg_stat_skipmode_update(rd_stats, x->rdmult);
+}
+#endif
+
+// Used to set proper context for early termination with skip = 1.
+static void set_skip_flag(MACROBLOCK *x, RD_STATS *rd_stats, int bsize,
+                          int64_t dist) {
+  MACROBLOCKD *const xd = &x->e_mbd;
+  MB_MODE_INFO *const mbmi = xd->mi[0];
+  const int n4 = bsize_to_num_blk(bsize);
+  const TX_SIZE tx_size = max_txsize_rect_lookup[bsize];
+  memset(mbmi->txk_type, DCT_DCT, sizeof(mbmi->txk_type[0]) * TXK_TYPE_BUF_LEN);
+  memset(mbmi->inter_tx_size, tx_size, sizeof(mbmi->inter_tx_size));
+  mbmi->tx_size = tx_size;
+  for (int i = 0; i < n4; ++i) set_blk_skip(x, 0, i, 1);
+  rd_stats->skip = 1;
+  if (is_cur_buf_hbd(xd)) dist = ROUND_POWER_OF_TWO(dist, (xd->bd - 8) * 2);
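+  // Pixel-domain SSE is scaled by 16 to match the precision of the
+  // distortion values used elsewhere in the rd search.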
+  rd_stats->dist = rd_stats->sse = (dist << 4);
+  // Though the decision is to mark the block as skip based on luma stats,
+  // it is possible that the block becomes non-skip after chroma rd. In
+  // addition, the intermediate non-skip costs calculated by the caller
+  // would be incorrect if the rate were set to zero (i.e., if zero_blk_rate
+  // were not accounted for). Hence an intermediate rate is populated here
+  // to code the luma tx blocks as skip; the caller then sets the final rate
+  // according to the final rd decision (i.e., skip vs non-skip). The rate
+  // populated here corresponds to coding all the tx blocks in the current
+  // block with zero_blk_rate (based on the largest possible tx size). E.g.,
+  // for a 128x128 block, the rate would be 4 * zero_blk_rate, where
+  // zero_blk_rate corresponds to coding one 64x64 tx block as 'all zeros'.
+  ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
+  ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
+  av1_get_entropy_contexts(bsize, &xd->plane[0], ctxa, ctxl);
+  ENTROPY_CONTEXT *ta = ctxa;
+  ENTROPY_CONTEXT *tl = ctxl;
+  const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
+  TXB_CTX txb_ctx;
+  get_txb_ctx(bsize, tx_size, 0, ta, tl, &txb_ctx);
+  const int zero_blk_rate = x->coeff_costs[txs_ctx][PLANE_TYPE_Y]
+                                .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
+  rd_stats->rate = zero_blk_rate *
+                   (block_size_wide[bsize] >> tx_size_wide_log2[tx_size]) *
+                   (block_size_high[bsize] >> tx_size_high_log2[tx_size]);
+}
+
 static void super_block_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
                             RD_STATS *rd_stats, BLOCK_SIZE bs,
                             int64_t ref_best_rd) {
@@ -3773,6 +3931,27 @@
 
   assert(bs == xd->mi[0]->sb_type);
 
+  // If skip is predicted to be the optimal RD decision, set the respective
+  // context and terminate early.
+  int64_t dist;
+  int is_inter = is_inter_block(xd->mi[0]);
+  if (cpi->sf.tx_type_search.use_skip_flag_prediction && is_inter &&
+      (!xd->lossless[xd->mi[0]->segment_id]) &&
+      predict_skip_flag(x, bs, &dist, cpi->common.reduced_tx_set_used)) {
+    // Populate rd_stats according to the skip decision.
+    set_skip_flag(x, rd_stats, bs, dist);
+#if CONFIG_ONE_PASS_SVM
+    if (bs >= BLOCK_8X8 && mi_size_wide[bs] == mi_size_high[bs] &&
+        xd->mi[0]->partition == PARTITION_NONE) {
+      calc_regional_sse(x, bs, dist, rd_stats);
+    }
+#endif
+    // Reset the pruning flags.
+    av1_zero(x->tx_search_prune);
+    x->tx_split_prune_flag = 0;
+    return;
+  }
+
   if (xd->lossless[xd->mi[0]->segment_id]) {
     choose_smallest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
   } else if (cpi->sf.tx_size_search_method == USE_LARGESTALL) {
@@ -5695,164 +5874,6 @@
   return 1;
 }
 
-// origin_threshold * 128 / 100
-static const uint32_t skip_pred_threshold[3][BLOCK_SIZES_ALL] = {
-  {
-      64, 64, 64, 70, 60, 60, 68, 68, 68, 68, 68,
-      68, 68, 68, 68, 68, 64, 64, 70, 70, 68, 68,
-  },
-  {
-      88, 88, 88, 86, 87, 87, 68, 68, 68, 68, 68,
-      68, 68, 68, 68, 68, 88, 88, 86, 86, 68, 68,
-  },
-  {
-      90, 93, 93, 90, 93, 93, 74, 74, 74, 74, 74,
-      74, 74, 74, 74, 74, 90, 90, 90, 90, 74, 74,
-  },
-};
-
-// lookup table for predict_skip_flag
-// int max_tx_size = max_txsize_rect_lookup[bsize];
-// if (tx_size_high[max_tx_size] > 16 || tx_size_wide[max_tx_size] > 16)
-//   max_tx_size = AOMMIN(max_txsize_lookup[bsize], TX_16X16);
-static const TX_SIZE max_predict_sf_tx_size[BLOCK_SIZES_ALL] = {
-  TX_4X4,   TX_4X8,   TX_8X4,   TX_8X8,   TX_8X16,  TX_16X8,
-  TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16,
-  TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_4X16,  TX_16X4,
-  TX_8X8,   TX_8X8,   TX_16X16, TX_16X16,
-};
-
-// Uses simple features on top of DCT coefficients to quickly predict
-// whether optimal RD decision is to skip encoding the residual.
-// The sse value is stored in dist.
-static int predict_skip_flag(MACROBLOCK *x, BLOCK_SIZE bsize, int64_t *dist,
-                             int reduced_tx_set) {
-  const int bw = block_size_wide[bsize];
-  const int bh = block_size_high[bsize];
-  const MACROBLOCKD *xd = &x->e_mbd;
-  const int16_t dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd);
-
-  *dist = pixel_diff_dist(x, 0, 0, 0, bsize, bsize, NULL);
-
-  const int64_t mse = *dist / bw / bh;
-  // Normalized quantizer takes the transform upscaling factor (8 for tx size
-  // smaller than 32) into account.
-  const int16_t normalized_dc_q = dc_q >> 3;
-  const int64_t mse_thresh = (int64_t)normalized_dc_q * normalized_dc_q / 8;
-  // Predict not to skip when mse is larger than threshold.
-  if (mse > mse_thresh) return 0;
-
-  const int max_tx_size = max_predict_sf_tx_size[bsize];
-  const int tx_h = tx_size_high[max_tx_size];
-  const int tx_w = tx_size_wide[max_tx_size];
-  DECLARE_ALIGNED(32, tran_low_t, coefs[32 * 32]);
-  TxfmParam param;
-  param.tx_type = DCT_DCT;
-  param.tx_size = max_tx_size;
-  param.bd = xd->bd;
-  param.is_hbd = is_cur_buf_hbd(xd);
-  param.lossless = 0;
-  param.tx_set_type = av1_get_ext_tx_set_type(
-      param.tx_size, is_inter_block(xd->mi[0]), reduced_tx_set);
-  const int bd_idx = (xd->bd == 8) ? 0 : ((xd->bd == 10) ? 1 : 2);
-  const uint32_t max_qcoef_thresh = skip_pred_threshold[bd_idx][bsize];
-  const int16_t *src_diff = x->plane[0].src_diff;
-  const int n_coeff = tx_w * tx_h;
-  const int16_t ac_q = av1_ac_quant_QTX(x->qindex, 0, xd->bd);
-  const uint32_t dc_thresh = max_qcoef_thresh * dc_q;
-  const uint32_t ac_thresh = max_qcoef_thresh * ac_q;
-  for (int row = 0; row < bh; row += tx_h) {
-    for (int col = 0; col < bw; col += tx_w) {
-      av1_fwd_txfm(src_diff + col, coefs, bw, &param);
-      // Operating on TX domain, not pixels; we want the QTX quantizers
-      const uint32_t dc_coef = (((uint32_t)abs(coefs[0])) << 7);
-      if (dc_coef >= dc_thresh) return 0;
-      for (int i = 1; i < n_coeff; ++i) {
-        const uint32_t ac_coef = (((uint32_t)abs(coefs[i])) << 7);
-        if (ac_coef >= ac_thresh) return 0;
-      }
-    }
-    src_diff += tx_h * bw;
-  }
-  return 1;
-}
-
-#if CONFIG_ONE_PASS_SVM
-static void calc_regional_sse(MACROBLOCK *x, BLOCK_SIZE bsize, int64_t dist,
-                              RD_STATS *rd_stats) {
-  // TODO(chiyotsai@google.com): Don't need regional sse's unless we are doing
-  // none.
-  const int bw = block_size_wide[bsize];
-  const int bw_mi = bw >> tx_size_wide_log2[0];
-  const int bh_mi = bw >> tx_size_high_log2[0];
-  const BLOCK_SIZE split_size = get_partition_subsize(bsize, PARTITION_SPLIT);
-  int64_t dist_0, dist_1, dist_2, dist_3;
-  MACROBLOCKD *xd = &x->e_mbd;
-  dist_0 = pixel_diff_dist(x, AOM_PLANE_Y, 0, 0, bsize, split_size, NULL);
-  dist_1 =
-      pixel_diff_dist(x, AOM_PLANE_Y, 0, bw_mi / 2, bsize, split_size, NULL);
-  dist_2 =
-      pixel_diff_dist(x, AOM_PLANE_Y, bh_mi / 2, 0, bsize, split_size, NULL);
-  dist_3 = pixel_diff_dist(x, AOM_PLANE_Y, bh_mi / 2, bw_mi / 2, bsize,
-                           split_size, NULL);
-
-  if (is_cur_buf_hbd(xd)) {
-    dist = ROUND_POWER_OF_TWO(dist, (xd->bd - 8) * 2);
-    dist_0 = ROUND_POWER_OF_TWO(dist_0, (xd->bd - 8) * 2);
-    dist_1 = ROUND_POWER_OF_TWO(dist_1, (xd->bd - 8) * 2);
-    dist_2 = ROUND_POWER_OF_TWO(dist_2, (xd->bd - 8) * 2);
-    dist_3 = ROUND_POWER_OF_TWO(dist_3, (xd->bd - 8) * 2);
-  }
-  const int scaling_factor = MAX_MIB_SIZE * MAX_MIB_SIZE;
-  rd_stats->y_sse = (dist << 4);
-  rd_stats->sse_0 = (dist_0 << 4) * scaling_factor;
-  rd_stats->sse_1 = (dist_1 << 4) * scaling_factor;
-  rd_stats->sse_2 = (dist_2 << 4) * scaling_factor;
-  rd_stats->sse_3 = (dist_3 << 4) * scaling_factor;
-  av1_reg_stat_skipmode_update(rd_stats, x->rdmult);
-}
-#endif
-
-// Used to set proper context for early termination with skip = 1.
-static void set_skip_flag(MACROBLOCK *x, RD_STATS *rd_stats, int bsize,
-                          int64_t dist) {
-  MACROBLOCKD *const xd = &x->e_mbd;
-  MB_MODE_INFO *const mbmi = xd->mi[0];
-  const int n4 = bsize_to_num_blk(bsize);
-  const TX_SIZE tx_size = max_txsize_rect_lookup[bsize];
-  memset(mbmi->txk_type, DCT_DCT, sizeof(mbmi->txk_type[0]) * TXK_TYPE_BUF_LEN);
-  memset(mbmi->inter_tx_size, tx_size, sizeof(mbmi->inter_tx_size));
-  mbmi->tx_size = tx_size;
-  for (int i = 0; i < n4; ++i) set_blk_skip(x, 0, i, 1);
-  rd_stats->skip = 1;
-  if (is_cur_buf_hbd(xd)) dist = ROUND_POWER_OF_TWO(dist, (xd->bd - 8) * 2);
-  rd_stats->dist = rd_stats->sse = (dist << 4);
-  // Though decision is to make the block as skip based on luma stats,
-  // it is possible that block becomes non skip after chroma rd. In addition
-  // intermediate non skip costs calculated by caller function will be
-  // incorrect, if rate is set as  zero (i.e., if zero_blk_rate is not
-  // accounted). Hence intermediate rate is populated to code the luma tx blks
-  // as skip, the caller function based on final rd decision (i.e., skip vs
-  // non-skip) sets the final rate accordingly. Here the rate populated
-  // corresponds to coding all the tx blocks with zero_blk_rate (based on max tx
-  // size possible) in the current block. Eg: For 128*128 block, rate would be
-  // 4 * zero_blk_rate where zero_blk_rate corresponds to coding of one 64x64 tx
-  // block as 'all zeros'
-  ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
-  ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
-  av1_get_entropy_contexts(bsize, &xd->plane[0], ctxa, ctxl);
-  ENTROPY_CONTEXT *ta = ctxa;
-  ENTROPY_CONTEXT *tl = ctxl;
-  const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
-  TXB_CTX txb_ctx;
-  get_txb_ctx(bsize, tx_size, 0, ta, tl, &txb_ctx);
-  const int zero_blk_rate = x->coeff_costs[txs_ctx][PLANE_TYPE_Y]
-                                .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
-  rd_stats->rate = zero_blk_rate *
-                   (block_size_wide[bsize] >> tx_size_wide_log2[tx_size]) *
-                   (block_size_high[bsize] >> tx_size_high_log2[tx_size]);
-}
-
 // Search for best transform size and type for luma inter blocks.
 static void pick_tx_size_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
                                   RD_STATS *rd_stats, BLOCK_SIZE bsize,