Predict skip flag to speed up the TX type search

Average speed-up (lowres):
low bitrates: 6.6%
mid bitrates: 2.5%
high bitrates: 0.0%

Average PSNR loss:
lowres: 0.010%
midres: 0.005%

Change-Id: Id34fb247e5e31f04ca324c58142e4b5ac4edacda
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 401e50f..843cad2 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -5169,6 +5169,101 @@
   *rd_stats = tx_rd_info->rd_stats;
 }
 
+// Uses simple features on top of DCT coefficients to quickly predict
+// whether optimal RD decision is to skip encoding the residual.
+static int predict_skip_flag_8bit(const MACROBLOCK *x, BLOCK_SIZE bsize) {
+  if (bsize > BLOCK_16X16) return 0;
+  // Tuned for target false-positive rate of 5% for all block sizes:
+  const uint32_t threshold_table[] = { 50, 50, 50, 55, 47, 47, 53, 22, 22, 37 };
+  const struct macroblock_plane *const p = &x->plane[0];
+  const int bw = block_size_wide[bsize];
+  const int bh = block_size_high[bsize];
+  tran_low_t DCT_coefs[32 * 32];
+  TxfmParam param;
+  param.tx_type = DCT_DCT;
+#if CONFIG_RECT_TX && (CONFIG_EXT_TX || CONFIG_VAR_TX)
+  param.tx_size = max_txsize_rect_lookup[bsize];
+#else
+  param.tx_size = max_txsize_lookup[bsize];
+#endif
+  param.bd = 8;
+  param.lossless = 0;
+  av1_fwd_txfm(p->src_diff, DCT_coefs, bw, &param);
+
+  uint32_t dc = (uint32_t)av1_dc_quant(x->qindex, 0, AOM_BITS_8);
+  uint32_t ac = (uint32_t)av1_ac_quant(x->qindex, 0, AOM_BITS_8);
+  uint32_t max_quantized_coef = (100 * (uint32_t)abs(DCT_coefs[0])) / dc;
+  for (int i = 1; i < bw * bh; i++) {
+    uint32_t cur_quantized_coef = (100 * (uint32_t)abs(DCT_coefs[i])) / ac;
+    if (cur_quantized_coef > max_quantized_coef)
+      max_quantized_coef = cur_quantized_coef;
+  }
+
+  return max_quantized_coef < threshold_table[AOMMAX(bsize - BLOCK_4X4, 0)];
+}
+
+// Used to set proper context for early termination with skip = 1.
+static void set_skip_flag(const AV1_COMP *cpi, MACROBLOCK *x,
+                          RD_STATS *rd_stats, int bsize) {
+  MACROBLOCKD *const xd = &x->e_mbd;
+  MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
+  const int n4 = bsize_to_num_blk(bsize);
+#if CONFIG_RECT_TX && (CONFIG_EXT_TX || CONFIG_VAR_TX)
+  const TX_SIZE tx_size = max_txsize_rect_lookup[bsize];
+#else
+  const TX_SIZE tx_size = max_txsize_lookup[bsize];
+#endif
+  mbmi->tx_type = DCT_DCT;
+  for (int idy = 0; idy < xd->n8_h; ++idy)
+    for (int idx = 0; idx < xd->n8_w; ++idx)
+      mbmi->inter_tx_size[idy][idx] = tx_size;
+  mbmi->tx_size = tx_size;
+  mbmi->min_tx_size = get_min_tx_size(tx_size);
+  memset(x->blk_skip[0], 1, sizeof(uint8_t) * n4);
+  rd_stats->skip = 1;
+
+  // Rate.
+  const int tx_size_ctx = txsize_sqr_map[tx_size];
+  ENTROPY_CONTEXT ctxa[2 * MAX_MIB_SIZE];
+  ENTROPY_CONTEXT ctxl[2 * MAX_MIB_SIZE];
+  av1_get_entropy_contexts(bsize, 0, &xd->plane[0], ctxa, ctxl);
+  int coeff_ctx = get_entropy_context(tx_size, ctxa, ctxl);
+  int rate = x->token_head_costs[tx_size_ctx][PLANE_TYPE_Y][1][0][coeff_ctx][0];
+  if (tx_size > TX_4X4) {
+    int ctx = txfm_partition_context(
+        xd->above_txfm_context, xd->left_txfm_context, mbmi->sb_type, tx_size);
+    rate += av1_cost_bit(cpi->common.fc->txfm_partition_prob[ctx], 0);
+  }
+#if !CONFIG_TXK_SEL
+#if CONFIG_EXT_TX
+  const AV1_COMMON *cm = &cpi->common;
+  const int ext_tx_set = get_ext_tx_set(max_txsize_lookup[bsize], bsize, 1,
+                                        cm->reduced_tx_set_used);
+  if (get_ext_tx_types(mbmi->min_tx_size, bsize, 1, cm->reduced_tx_set_used) >
+          1 &&
+      !xd->lossless[xd->mi[0]->mbmi.segment_id]) {
+    if (ext_tx_set > 0)
+      rate +=
+          x->inter_tx_type_costs[ext_tx_set][txsize_sqr_map[mbmi->min_tx_size]]
+                                [mbmi->tx_type];
+  }
+#else
+  if (mbmi->min_tx_size < TX_32X32 && !xd->lossless[xd->mi[0]->mbmi.segment_id])
+    rd_stats->rate += x->inter_tx_type_costs[mbmi->min_tx_size][mbmi->tx_type];
+#endif  // CONFIG_EXT_TX
+#endif  // CONFIG_TXK_SEL
+  rd_stats->rate = rate;
+
+  // Distortion.
+  int64_t tmp = pixel_diff_dist(x, 0, x->plane[0].src_diff,
+                                block_size_wide[bsize], 0, 0, bsize, bsize);
+#if CONFIG_HIGHBITDEPTH
+  if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
+    tmp = ROUND_POWER_OF_TWO(tmp, (xd->bd - 8) * 2);
+#endif  // CONFIG_HIGHBITDEPTH
+  rd_stats->dist = rd_stats->sse = (tmp << 4);
+}
+
 static void select_tx_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
                                RD_STATS *rd_stats, BLOCK_SIZE bsize,
                                int64_t ref_best_rd) {
@@ -5207,13 +5302,6 @@
       get_ext_tx_set(max_tx_size, bsize, is_inter, cm->reduced_tx_set_used);
 #endif  // CONFIG_EXT_TX
 
-  if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE)
-#if CONFIG_EXT_TX
-    prune = prune_tx_types(cpi, bsize, x, xd, ext_tx_set);
-#else
-    prune = prune_tx_types(cpi, bsize, x, xd, 0);
-#endif  // CONFIG_EXT_TX
-
   av1_invalid_rd_stats(rd_stats);
 
   for (idx = 0; idx < count32; ++idx)
@@ -5236,6 +5324,26 @@
     }
   }
 
+// If we predict that skip is the optimal RD decision - set the respective
+// context and terminate early.
+#if CONFIG_HIGHBITDEPTH
+  if (!(xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH))
+#endif  // CONFIG_HIGHBITDEPTH
+  {
+    if (is_inter && cpi->sf.tx_type_search.use_skip_flag_prediction &&
+        predict_skip_flag_8bit(x, bsize)) {
+      set_skip_flag(cpi, x, rd_stats, bsize);
+      return;
+    }
+  }
+
+  if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE)
+#if CONFIG_EXT_TX
+    prune = prune_tx_types(cpi, bsize, x, xd, ext_tx_set);
+#else
+    prune = prune_tx_types(cpi, bsize, x, xd, 0);
+#endif  // CONFIG_EXT_TX
+
   for (tx_type = txk_start; tx_type < txk_end; ++tx_type) {
     RD_STATS this_rd_stats;
     av1_init_rd_stats(&this_rd_stats);