Add separable model based tx pruning.

Add a function that prunes the transform types based on estimated
RD cost. It checks horizontal and vertical transform types separately
to save speed overhead.

Change-Id: Iec2bcdd6beae9f285fdff095b17d14f5ea573402
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 4352294..9b55fa4 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -2439,6 +2439,137 @@
   return prune;
 }
 
+int16_t prune_txk_type_separ(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
+                             int block, TX_SIZE tx_size, int blk_row,
+                             int blk_col, BLOCK_SIZE plane_bsize, int *txk_map,
+                             int16_t allowed_tx_mask, int prune_factor,
+                             const TXB_CTX *const txb_ctx,
+                             int reduced_tx_set_used, int64_t ref_best_rd,
+                             int num_sel) {
+  const AV1_COMMON *cm = &cpi->common;
+
+  int idx;
+
+  int64_t rds_v[4];
+  int64_t rds_h[4];
+  int idx_v[4] = { 0, 1, 2, 3 };
+  int idx_h[4] = { 0, 1, 2, 3 };
+  int skip_v[4] = { 0 };
+  int skip_h[4] = { 0 };
+  const int idx_map[16] = {
+    DCT_DCT,      DCT_ADST,      DCT_FLIPADST,      V_DCT,
+    ADST_DCT,     ADST_ADST,     ADST_FLIPADST,     V_ADST,
+    FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
+    H_DCT,        H_ADST,        H_FLIPADST,        IDTX
+  };
+
+  const int sel_pattern_v[16] = {
+    0, 0, 1, 1, 0, 2, 1, 2, 2, 0, 3, 1, 3, 2, 3, 3
+  };
+  const int sel_pattern_h[16] = {
+    0, 1, 0, 1, 2, 0, 2, 1, 2, 3, 0, 3, 1, 3, 2, 3
+  };
+
+  QUANT_PARAM quant_param;
+  TxfmParam txfm_param;
+  av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
+  av1_setup_quant(cm, tx_size, 1, AV1_XFORM_QUANT_B, &quant_param);
+  int tx_type;
+  // to ensure we can try ones even outside of ext_tx_set of current block
+  // this function should only be called for size < 16
+  assert(txsize_sqr_up_map[tx_size] <= TX_16X16);
+  txfm_param.tx_set_type = EXT_TX_SET_ALL16;
+
+  int rate_cost = 0;
+  int64_t dist = 0, sse = 0;
+  // evaluate horizontal with vertical DCT
+  for (idx = 0; idx < 4; ++idx) {
+    tx_type = idx_map[idx];
+    txfm_param.tx_type = tx_type;
+
+    av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
+                    &quant_param);
+
+    dist_block_tx_domain(x, plane, block, tx_size, &dist, &sse);
+
+    rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
+                                              txb_ctx, reduced_tx_set_used, 0);
+
+    rds_h[idx] = RDCOST(x->rdmult, rate_cost, dist);
+
+    if ((rds_h[idx] - (rds_h[idx] >> 2)) > ref_best_rd) {
+      skip_h[idx] = 1;
+    }
+  }
+  sort_rd(rds_h, idx_h, 4);
+  for (idx = 1; idx < 4; idx++) {
+    if (rds_h[idx] > rds_h[0] * 1.2) skip_h[idx_h[idx]] = 1;
+  }
+
+  if (skip_h[idx_h[0]]) return 0xFFFF;
+
+  // evaluate vertical with the best horizontal chosen
+  rds_v[0] = rds_h[0];
+  int start_v = 1, end_v = 4;
+  const int *idx_map_v = idx_map + idx_h[0];
+
+  for (idx = start_v; idx < end_v; ++idx) {
+    tx_type = idx_map_v[idx_v[idx] * 4];
+    txfm_param.tx_type = tx_type;
+
+    av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
+                    &quant_param);
+
+    dist_block_tx_domain(x, plane, block, tx_size, &dist, &sse);
+
+    rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
+                                              txb_ctx, reduced_tx_set_used, 0);
+
+    rds_v[idx] = RDCOST(x->rdmult, rate_cost, dist);
+
+    if ((rds_v[idx] - (rds_v[idx] >> 2)) > ref_best_rd) {
+      skip_v[idx] = 1;
+    }
+  }
+  sort_rd(rds_v, idx_v, 4);
+  for (idx = 1; idx < 4; idx++) {
+    if (rds_v[idx] > rds_v[0] * 1.2) skip_v[idx_v[idx]] = 1;
+  }
+
+  // combine rd_h and rd_v to prune tx candidates
+  int i_v, i_h;
+  int64_t rds[16];
+  int num_cand = 0, last = TX_TYPES - 1;
+
+  for (int i = 0; i < 16; i++) {
+    i_v = sel_pattern_v[i];
+    i_h = sel_pattern_h[i];
+    tx_type = idx_map[idx_v[i_v] * 4 + idx_h[i_h]];
+    if (!(allowed_tx_mask & (1 << tx_type)) || skip_h[idx_h[i_h]] ||
+        skip_v[idx_v[i_v]]) {
+      txk_map[last] = tx_type;
+      last--;
+    } else {
+      txk_map[num_cand] = tx_type;
+      rds[num_cand] = rds_v[i_v] + rds_h[i_h];
+      num_cand++;
+    }
+  }
+  sort_rd(rds, txk_map, num_cand);
+
+  uint16_t prune = ~(1 << txk_map[0]);
+  num_sel = AOMMIN(num_sel, num_cand);
+
+  for (int i = 1; i < num_sel; i++) {
+    int64_t factor = 1800 * (rds[i] - rds[0]) / (rds[0]);
+    if (factor < (int64_t)prune_factor)
+      prune &= ~(1 << txk_map[i]);
+    else
+      break;
+  }
+  return prune;
+}
+
 static int64_t search_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
                                int block, int blk_row, int blk_col,
                                BLOCK_SIZE plane_bsize, TX_SIZE tx_size,