Rework recursive transform block partition search

Support transform block level kernel selection in the recursive
transform block partitioning search.

Change-Id: I511c39705ee636b0c9fabbe4720fe5a9764b964a
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index 2e29db5..0101b16 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -1870,7 +1870,11 @@
   TX_TYPE best_tx_type = txk_start;
   int64_t best_rd = INT64_MAX;
   const int coeff_ctx = combine_entropy_contexts(*a, *l);
+  RD_STATS best_rd_stats;
   TX_TYPE tx_type;
+
+  av1_invalid_rd_stats(&best_rd_stats);
+
   for (tx_type = txk_start; tx_type <= txk_end; ++tx_type) {
     if (plane == 0) mbmi->txk_type[block] = tx_type;
     TX_TYPE ref_tx_type =
@@ -1894,10 +1898,13 @@
     int rd = RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
     if (rd < best_rd) {
       best_rd = rd;
-      *rd_stats = this_rd_stats;
+      best_rd_stats = this_rd_stats;
       best_tx_type = tx_type;
     }
   }
+
+  av1_merge_rd_stats(rd_stats, &best_rd_stats);
+
   if (plane == 0) mbmi->txk_type[block] = best_tx_type;
   // TODO(angiebird): Instead of re-call av1_xform_quant and av1_optimize_b,
   // copy the best result in the above tx_type search for loop
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index be12985..991cde2 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -4182,6 +4182,12 @@
 
   int coeff_ctx = get_entropy_context(tx_size, a, l);
 
+#if CONFIG_TXK_SEL
+  av1_search_txk_type(cpi, x, plane, block, blk_row, blk_col, plane_bsize,
+                      tx_size, a, l, 0, rd_stats);
+  return;
+#endif
+
   av1_xform_quant(cm, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
                   coeff_ctx, AV1_XFORM_QUANT_FP);
 
@@ -4288,6 +4294,10 @@
   int zero_blk_rate;
   RD_STATS sum_rd_stats;
   const int tx_size_ctx = txsize_sqr_map[tx_size];
+#if CONFIG_TXK_SEL
+  TX_TYPE best_tx_type = TX_TYPES;
+  int txk_idx = block;
+#endif
 
   av1_init_rd_stats(&sum_rd_stats);
 
@@ -4346,6 +4356,9 @@
           av1_cost_bit(cpi->common.fc->txfm_partition_prob[ctx], 0);
     this_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
     tmp_eob = p->eobs[block];
+#if CONFIG_TXK_SEL
+    best_tx_type = mbmi->txk_type[txk_idx];
+#endif
   }
 
   if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH) {
@@ -4496,6 +4509,9 @@
       for (idx = 0; idx < tx_size_wide_unit[tx_size] / 2; ++idx)
         inter_tx_size[idy][idx] = tx_size;
     mbmi->tx_size = tx_size;
+#if CONFIG_TXK_SEL
+    mbmi->txk_type[txk_idx] = best_tx_type;
+#endif
     if (this_rd == INT64_MAX) *is_cost_valid = 0;
     x->blk_skip[plane][blk_row * bw + blk_col] = rd_stats->skip;
   } else {
@@ -4643,6 +4659,12 @@
   TX_SIZE best_tx = max_txsize_lookup[bsize];
   TX_SIZE best_min_tx_size = TX_SIZES_ALL;
   uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE * 8];
+  TX_TYPE txk_start = DCT_DCT;
+#if CONFIG_TXK_SEL
+  TX_TYPE txk_end = DCT_DCT + 1;
+#else
+  TX_TYPE txk_end = TX_TYPES;
+#endif
   const int n4 = bsize_to_num_blk(bsize);
   int idx, idy;
   int prune = 0;
@@ -4670,7 +4692,7 @@
   for (idx = 0; idx < count32; ++idx)
     av1_invalid_rd_stats(&rd_stats_stack[idx]);
 
-  for (tx_type = DCT_DCT; tx_type < TX_TYPES; ++tx_type) {
+  for (tx_type = txk_start; tx_type < txk_end; ++tx_type) {
     RD_STATS this_rd_stats;
     av1_init_rd_stats(&this_rd_stats);
 #if CONFIG_EXT_TX