tx skip prediction: add pre-screening on prediction error

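Exit early, predicting no skip, when the mean squared prediction error
of the block exceeds a threshold derived from the DC quantizer. This
reduces the overhead of the tx skip prediction, especially for low
quantizers. The SSE computed for this check is passed on to
set_skip_flag() so that it is not computed a second time.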

BUG=aomedia:1106

Change-Id: Icad8e01cdeb2e8f4cf0befa7f5a89e088f3c17e5
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 8e6835f..8b14613 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -4727,7 +4727,8 @@
 
 // Uses simple features on top of DCT coefficients to quickly predict
 // whether optimal RD decision is to skip encoding the residual.
-static int predict_skip_flag(const MACROBLOCK *x, BLOCK_SIZE bsize) {
+// The sse value is stored in dist.
+static int predict_skip_flag(MACROBLOCK *x, BLOCK_SIZE bsize, int64_t *dist) {
   const int max_tx_size =
       get_max_rect_tx_size(bsize, is_inter_block(&x->e_mbd.mi[0]->mbmi));
   const int tx_h = tx_size_high[max_tx_size];
@@ -4737,6 +4738,17 @@
   const int bw = block_size_wide[bsize];
   const int bh = block_size_high[bsize];
   const MACROBLOCKD *xd = &x->e_mbd;
+  const uint32_t dc_q = (uint32_t)av1_dc_quant_QTX(x->qindex, 0, xd->bd);
+
+  *dist = pixel_diff_dist(x, 0, x->plane[0].src_diff, bw, 0, 0, bsize, bsize);
+  const int64_t mse = *dist / bw / bh;
+  // Normalized quantizer takes the transform upscaling factor (8 for tx size
+  // smaller than 32) into account.
+  const uint32_t normalized_dc_q = dc_q >> 3;
+  const int64_t mse_thresh = (int64_t)normalized_dc_q * normalized_dc_q / 8;
+  // Predict not to skip when mse is larger than threshold.
+  if (mse > mse_thresh) return 0;
+
   DECLARE_ALIGNED(32, tran_low_t, DCT_coefs[32 * 32]);
   TxfmParam param;
   param.tx_type = DCT_DCT;
@@ -4752,8 +4764,7 @@
   // within this function.
   param.tx_set_type = get_ext_tx_set_type(param.tx_size, plane_bsize,
                                           is_inter_block(&xd->mi[0]->mbmi), 0);
-  const uint32_t dc = (uint32_t)av1_dc_quant_QTX(x->qindex, 0, xd->bd);
-  const uint32_t ac = (uint32_t)av1_ac_quant_QTX(x->qindex, 0, xd->bd);
+  const uint32_t ac_q = (uint32_t)av1_ac_quant_QTX(x->qindex, 0, xd->bd);
   uint32_t max_quantized_coef = 0;
   const int16_t *src_diff = x->plane[0].src_diff;
   for (int row = 0; row < bh; row += tx_h) {
@@ -4770,7 +4781,7 @@
       // Operating on TX domain, not pixels; we want the QTX quantizers
       for (int i = 0; i < tx_w * tx_h; ++i) {
         uint32_t cur_quantized_coef =
-            (100 * (uint32_t)abs(DCT_coefs[i])) / (i ? ac : dc);
+            (100 * (uint32_t)abs(DCT_coefs[i])) / (i ? ac_q : dc_q);
         if (cur_quantized_coef > max_quantized_coef)
           max_quantized_coef = cur_quantized_coef;
       }
@@ -4783,7 +4794,7 @@
 
 // Used to set proper context for early termination with skip = 1.
 static void set_skip_flag(const AV1_COMP *cpi, MACROBLOCK *x,
-                          RD_STATS *rd_stats, int bsize) {
+                          RD_STATS *rd_stats, int bsize, int64_t dist) {
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
   const int n4 = bsize_to_num_blk(bsize);
@@ -4838,15 +4849,11 @@
   }
 #endif  // CONFIG_TXK_SEL
   rd_stats->rate = rate;
-
-  // Distortion.
-  int64_t tmp = pixel_diff_dist(x, 0, x->plane[0].src_diff,
-                                block_size_wide[bsize], 0, 0, bsize, bsize);
 #if CONFIG_HIGHBITDEPTH
   if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
-    tmp = ROUND_POWER_OF_TWO(tmp, (xd->bd - 8) * 2);
+    dist = ROUND_POWER_OF_TWO(dist, (xd->bd - 8) * 2);
 #endif  // CONFIG_HIGHBITDEPTH
-  rd_stats->dist = rd_stats->sse = (tmp << 4);
+  rd_stats->dist = rd_stats->sse = (dist << 4);
 }
 
 static void select_tx_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
@@ -4900,9 +4907,10 @@
 
   // If we predict that skip is the optimal RD decision - set the respective
   // context and terminate early.
+  int64_t dist;
   if (is_inter && cpi->sf.tx_type_search.use_skip_flag_prediction &&
-      predict_skip_flag(x, bsize)) {
-    set_skip_flag(cpi, x, rd_stats, bsize);
+      predict_skip_flag(x, bsize, &dist)) {
+    set_skip_flag(cpi, x, rd_stats, bsize, dist);
     // Save the RD search results into tx_rd_record.
     if (within_border) save_tx_rd_info(n4, hash, x, rd_stats, tx_rd_record);
     return;