Add separable model based tx pruning.
Add a function that prunes the transform types based on estimated
RD cost. It checks horizontal and vertical transform types separately
to save speed overhead.
Change-Id: Iec2bcdd6beae9f285fdff095b17d14f5ea573402
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 4352294..9b55fa4 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -2439,6 +2439,137 @@
return prune;
}
+int16_t prune_txk_type_separ(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
+ int block, TX_SIZE tx_size, int blk_row,
+ int blk_col, BLOCK_SIZE plane_bsize, int *txk_map,
+ int16_t allowed_tx_mask, int prune_factor,
+ const TXB_CTX *const txb_ctx,
+ int reduced_tx_set_used, int64_t ref_best_rd,
+ int num_sel) {
+ const AV1_COMMON *cm = &cpi->common;
+
+ int idx;
+
+ int64_t rds_v[4];
+ int64_t rds_h[4];
+ int idx_v[4] = { 0, 1, 2, 3 };
+ int idx_h[4] = { 0, 1, 2, 3 };
+ int skip_v[4] = { 0 };
+ int skip_h[4] = { 0 };
+ const int idx_map[16] = {
+ DCT_DCT, DCT_ADST, DCT_FLIPADST, V_DCT,
+ ADST_DCT, ADST_ADST, ADST_FLIPADST, V_ADST,
+ FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
+ H_DCT, H_ADST, H_FLIPADST, IDTX
+ };
+
+ const int sel_pattern_v[16] = {
+ 0, 0, 1, 1, 0, 2, 1, 2, 2, 0, 3, 1, 3, 2, 3, 3
+ };
+ const int sel_pattern_h[16] = {
+ 0, 1, 0, 1, 2, 0, 2, 1, 2, 3, 0, 3, 1, 3, 2, 3
+ };
+
+ QUANT_PARAM quant_param;
+ TxfmParam txfm_param;
+ av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
+ av1_setup_quant(cm, tx_size, 1, AV1_XFORM_QUANT_B, &quant_param);
+ int tx_type;
+ // to ensure we can try ones even outside of ext_tx_set of current block
+ // this function should only be called for size < 16
+ assert(txsize_sqr_up_map[tx_size] <= TX_16X16);
+ txfm_param.tx_set_type = EXT_TX_SET_ALL16;
+
+ int rate_cost = 0;
+ int64_t dist = 0, sse = 0;
+ // evaluate horizontal with vertical DCT
+ for (idx = 0; idx < 4; ++idx) {
+ tx_type = idx_map[idx];
+ txfm_param.tx_type = tx_type;
+
+ av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
+ &quant_param);
+
+ dist_block_tx_domain(x, plane, block, tx_size, &dist, &sse);
+
+ rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
+ txb_ctx, reduced_tx_set_used, 0);
+
+ rds_h[idx] = RDCOST(x->rdmult, rate_cost, dist);
+
+ if ((rds_h[idx] - (rds_h[idx] >> 2)) > ref_best_rd) {
+ skip_h[idx] = 1;
+ }
+ }
+ sort_rd(rds_h, idx_h, 4);
+ for (idx = 1; idx < 4; idx++) {
+ if (rds_h[idx] > rds_h[0] * 1.2) skip_h[idx_h[idx]] = 1;
+ }
+
+ if (skip_h[idx_h[0]]) return 0xFFFF;
+
+ // evaluate vertical with the best horizontal chosen
+ rds_v[0] = rds_h[0];
+ int start_v = 1, end_v = 4;
+ const int *idx_map_v = idx_map + idx_h[0];
+
+ for (idx = start_v; idx < end_v; ++idx) {
+ tx_type = idx_map_v[idx_v[idx] * 4];
+ txfm_param.tx_type = tx_type;
+
+ av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
+ &quant_param);
+
+ dist_block_tx_domain(x, plane, block, tx_size, &dist, &sse);
+
+ rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
+ txb_ctx, reduced_tx_set_used, 0);
+
+ rds_v[idx] = RDCOST(x->rdmult, rate_cost, dist);
+
+ if ((rds_v[idx] - (rds_v[idx] >> 2)) > ref_best_rd) {
+ skip_v[idx] = 1;
+ }
+ }
+ sort_rd(rds_v, idx_v, 4);
+ for (idx = 1; idx < 4; idx++) {
+ if (rds_v[idx] > rds_v[0] * 1.2) skip_v[idx_v[idx]] = 1;
+ }
+
+ // combine rd_h and rd_v to prune tx candidates
+ int i_v, i_h;
+ int64_t rds[16];
+ int num_cand = 0, last = TX_TYPES - 1;
+
+ for (int i = 0; i < 16; i++) {
+ i_v = sel_pattern_v[i];
+ i_h = sel_pattern_h[i];
+ tx_type = idx_map[idx_v[i_v] * 4 + idx_h[i_h]];
+ if (!(allowed_tx_mask & (1 << tx_type)) || skip_h[idx_h[i_h]] ||
+ skip_v[idx_v[i_v]]) {
+ txk_map[last] = tx_type;
+ last--;
+ } else {
+ txk_map[num_cand] = tx_type;
+ rds[num_cand] = rds_v[i_v] + rds_h[i_h];
+ num_cand++;
+ }
+ }
+ sort_rd(rds, txk_map, num_cand);
+
+ uint16_t prune = ~(1 << txk_map[0]);
+ num_sel = AOMMIN(num_sel, num_cand);
+
+ for (int i = 1; i < num_sel; i++) {
+ int64_t factor = 1800 * (rds[i] - rds[0]) / (rds[0]);
+ if (factor < (int64_t)prune_factor)
+ prune &= ~(1 << txk_map[i]);
+ else
+ break;
+ }
+ return prune;
+}
+
static int64_t search_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
int block, int blk_row, int blk_col,
BLOCK_SIZE plane_bsize, TX_SIZE tx_size,