Predict skip flag to speed up the TX type search
Average speed-up (lowres):
low bitrates: 6.6%
mid bitrates: 2.5%
high bitrates: 0.0%
Average PSNR loss:
lowres: 0.010%
midres: 0.005%
Change-Id: Id34fb247e5e31f04ca324c58142e4b5ac4edacda
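For illustration, the decision rule reduces to the standalone sketch below
(a hypothetical predict_skip() helper; dc_q, ac_q and threshold stand in
for the av1_dc_quant()/av1_ac_quant() step sizes and the tuned
threshold_table entry from the patch):

  #include <stdint.h>
  #include <stdio.h>
  #include <stdlib.h>

  /* Returns 1 if every coefficient is predicted to quantize to (near)
   * zero, i.e. the largest magnitude, in 1/100ths of its quantization
   * step, stays under the tuned threshold. coefs[0] is DC, the rest AC. */
  static int predict_skip(const int32_t *coefs, int n, uint32_t dc_q,
                          uint32_t ac_q, uint32_t threshold) {
    uint32_t max_ratio = (100 * (uint32_t)abs(coefs[0])) / dc_q;
    for (int i = 1; i < n; i++) {
      uint32_t ratio = (100 * (uint32_t)abs(coefs[i])) / ac_q;
      if (ratio > max_ratio) max_ratio = ratio;
    }
    return max_ratio < threshold;
  }

  int main(void) {
    /* Largest coefficient 120 against an AC step of 350 gives
     * 100 * 120 / 350 = 34, below a threshold of 37, so skip is
     * predicted. */
    const int32_t coefs[16] = { 60, -120, 25, 10 }; /* rest are zero */
    printf("skip = %d\n", predict_skip(coefs, 16, 300, 350, 37));
    return 0;
  }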
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 401e50f..843cad2 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -5169,6 +5169,101 @@
*rd_stats = tx_rd_info->rd_stats;
}
+// Uses simple features on top of DCT coefficients to quickly predict
+// whether the optimal RD decision is to skip encoding the residual.
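+// The feature is the largest coefficient magnitude, measured in 1/100ths
+// of the corresponding quantization step; skip is predicted when it stays
+// below a per-block-size threshold.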
+static int predict_skip_flag_8bit(const MACROBLOCK *x, BLOCK_SIZE bsize) {
+ if (bsize > BLOCK_16X16) return 0;
+ // Tuned for a target false-positive rate of 5% for all block sizes:
+ const uint32_t threshold_table[] = { 50, 50, 50, 55, 47, 47, 53, 22, 22, 37 };
+ const struct macroblock_plane *const p = &x->plane[0];
+ const int bw = block_size_wide[bsize];
+ const int bh = block_size_high[bsize];
+ tran_low_t DCT_coefs[32 * 32];
+ TxfmParam param;
+ param.tx_type = DCT_DCT;
+#if CONFIG_RECT_TX && (CONFIG_EXT_TX || CONFIG_VAR_TX)
+ param.tx_size = max_txsize_rect_lookup[bsize];
+#else
+ param.tx_size = max_txsize_lookup[bsize];
+#endif
+ param.bd = 8;
+ param.lossless = 0;
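+ // Forward-transform the source residual with DCT_DCT at the largest
+ // transform size available for this block.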
+ av1_fwd_txfm(p->src_diff, DCT_coefs, bw, &param);
+
+ uint32_t dc = (uint32_t)av1_dc_quant(x->qindex, 0, AOM_BITS_8);
+ uint32_t ac = (uint32_t)av1_ac_quant(x->qindex, 0, AOM_BITS_8);
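+ // Compare each coefficient against its quantization step (DC or AC) in
+ // 1/100ths of a step, tracking the maximum over the whole block.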
+ uint32_t max_quantized_coef = (100 * (uint32_t)abs(DCT_coefs[0])) / dc;
+ for (int i = 1; i < bw * bh; i++) {
+ uint32_t cur_quantized_coef = (100 * (uint32_t)abs(DCT_coefs[i])) / ac;
+ if (cur_quantized_coef > max_quantized_coef)
+ max_quantized_coef = cur_quantized_coef;
+ }
+
+ return max_quantized_coef < threshold_table[AOMMAX(bsize - BLOCK_4X4, 0)];
+}
+
+// Sets the proper context for early termination with skip = 1.
+static void set_skip_flag(const AV1_COMP *cpi, MACROBLOCK *x,
+ RD_STATS *rd_stats, int bsize) {
+ MACROBLOCKD *const xd = &x->e_mbd;
+ MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
+ const int n4 = bsize_to_num_blk(bsize);
+#if CONFIG_RECT_TX && (CONFIG_EXT_TX || CONFIG_VAR_TX)
+ const TX_SIZE tx_size = max_txsize_rect_lookup[bsize];
+#else
+ const TX_SIZE tx_size = max_txsize_lookup[bsize];
+#endif
+ mbmi->tx_type = DCT_DCT;
+ for (int idy = 0; idy < xd->n8_h; ++idy)
+ for (int idx = 0; idx < xd->n8_w; ++idx)
+ mbmi->inter_tx_size[idy][idx] = tx_size;
+ mbmi->tx_size = tx_size;
+ mbmi->min_tx_size = get_min_tx_size(tx_size);
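+ // Mark every luma block in the partition as skipped.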
+ memset(x->blk_skip[0], 1, sizeof(uint8_t) * n4);
+ rd_stats->skip = 1;
+
+ // Rate.
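+ // With skip = 1, the coefficient rate reduces to signaling an all-zero
+ // block in the current entropy context; TX partition and TX type bits
+ // are added below where applicable.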
+ const int tx_size_ctx = txsize_sqr_map[tx_size];
+ ENTROPY_CONTEXT ctxa[2 * MAX_MIB_SIZE];
+ ENTROPY_CONTEXT ctxl[2 * MAX_MIB_SIZE];
+ av1_get_entropy_contexts(bsize, 0, &xd->plane[0], ctxa, ctxl);
+ int coeff_ctx = get_entropy_context(tx_size, ctxa, ctxl);
+ int rate = x->token_head_costs[tx_size_ctx][PLANE_TYPE_Y][1][0][coeff_ctx][0];
+ if (tx_size > TX_4X4) {
+ int ctx = txfm_partition_context(
+ xd->above_txfm_context, xd->left_txfm_context, mbmi->sb_type, tx_size);
+ rate += av1_cost_bit(cpi->common.fc->txfm_partition_prob[ctx], 0);
+ }
+#if !CONFIG_TXK_SEL
+#if CONFIG_EXT_TX
+ const AV1_COMMON *cm = &cpi->common;
+ const int ext_tx_set = get_ext_tx_set(max_txsize_lookup[bsize], bsize, 1,
+ cm->reduced_tx_set_used);
+ if (get_ext_tx_types(mbmi->min_tx_size, bsize, 1, cm->reduced_tx_set_used) >
+ 1 &&
+ !xd->lossless[xd->mi[0]->mbmi.segment_id]) {
+ if (ext_tx_set > 0)
+ rate +=
+ x->inter_tx_type_costs[ext_tx_set][txsize_sqr_map[mbmi->min_tx_size]]
+ [mbmi->tx_type];
+ }
+#else
+ if (mbmi->min_tx_size < TX_32X32 && !xd->lossless[xd->mi[0]->mbmi.segment_id])
+ rate += x->inter_tx_type_costs[mbmi->min_tx_size][mbmi->tx_type];
+#endif // CONFIG_EXT_TX
+#endif // CONFIG_TXK_SEL
+ rd_stats->rate = rate;
+
+ // Distortion.
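+ // With the residual skipped, distortion equals the SSE of the prediction
+ // residual; the << 4 matches the x16 distortion scale used elsewhere in
+ // the RD code.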
+ int64_t tmp = pixel_diff_dist(x, 0, x->plane[0].src_diff,
+ block_size_wide[bsize], 0, 0, bsize, bsize);
+#if CONFIG_HIGHBITDEPTH
+ if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
+ tmp = ROUND_POWER_OF_TWO(tmp, (xd->bd - 8) * 2);
+#endif // CONFIG_HIGHBITDEPTH
+ rd_stats->dist = rd_stats->sse = (tmp << 4);
+}
+
static void select_tx_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
RD_STATS *rd_stats, BLOCK_SIZE bsize,
int64_t ref_best_rd) {
@@ -5207,13 +5302,6 @@
get_ext_tx_set(max_tx_size, bsize, is_inter, cm->reduced_tx_set_used);
#endif // CONFIG_EXT_TX
- if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE)
-#if CONFIG_EXT_TX
- prune = prune_tx_types(cpi, bsize, x, xd, ext_tx_set);
-#else
- prune = prune_tx_types(cpi, bsize, x, xd, 0);
-#endif // CONFIG_EXT_TX
-
av1_invalid_rd_stats(rd_stats);
for (idx = 0; idx < count32; ++idx)
@@ -5236,6 +5324,26 @@
}
}
+ // If we predict that skip is the optimal RD decision, set the respective
+ // context and terminate the TX type search early.
+#if CONFIG_HIGHBITDEPTH
+ if (!(xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH))
+#endif // CONFIG_HIGHBITDEPTH
+ {
+ if (is_inter && cpi->sf.tx_type_search.use_skip_flag_prediction &&
+ predict_skip_flag_8bit(x, bsize)) {
+ set_skip_flag(cpi, x, rd_stats, bsize);
+ return;
+ }
+ }
+
+ if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE)
+#if CONFIG_EXT_TX
+ prune = prune_tx_types(cpi, bsize, x, xd, ext_tx_set);
+#else
+ prune = prune_tx_types(cpi, bsize, x, xd, 0);
+#endif // CONFIG_EXT_TX
+
for (tx_type = txk_start; tx_type < txk_end; ++tx_type) {
RD_STATS this_rd_stats;
av1_init_rd_stats(&this_rd_stats);