Use tx_size 1 level down for transform type search

This addresses an inconsistency between the set used
to decode the tx_type in the bitstream and the set used
for the tx_type search. Previously, the set used to
read/write the tx_type was based on the smallest tx_size
in the vartx partitioning, but the search used a set
based on the largest possible tx_size. This patch
changes the tx_type search to use the transform type
set associated with the tx_size 1 recursive level down from
the max square tx_size to make the search more consistent
with the bitstream syntax. If the tx_type being searched is
not valid for a selected tx_size, DCT_DCT is used for that
partition instead.

This patch also adds assertions to all exposed transform
functions to ensure that no illegal transform type/size
combinations occur.

This currently causes a 0.1% drop in performance on lowres.
The drop is due to the reduction of the tx_types available
for 32x16 and 16x32 transform sizes. Before this patch,
32x16 and 16x32 transforms were getting assigned a
set of 12 tx_types, some of which we did not intend to
support for these sizes.

Change-Id: I44aca4876b261c345623cd04ad6235bca4532701
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 933b7d0..4908137 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -1976,6 +1976,9 @@
 #if CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
                                     mrc_mask,
 #endif  // CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
+#if CONFIG_EXT_TX
+                                    plane,
+#endif  // CONFIG_EXT_TX
                                     tx_type, tx_size, recon, MAX_TX_SIZE, eob);
 
 #if CONFIG_DIST_8X8
@@ -4006,6 +4009,9 @@
 #if CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
                               mrc_mask,
 #endif  // CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
+#if CONFIG_EXT_TX
+                              plane,
+#endif  // CONFIG_EXT_TX
                               tx_type, tx_size, rec_buffer, MAX_TX_SIZE, eob);
   if (eob > 0) {
 #if CONFIG_DIST_8X8
@@ -4897,6 +4903,17 @@
   param.tx_size = max_txsize_rect_lookup[bsize];
   param.bd = 8;
   param.lossless = 0;
+#if CONFIG_EXT_TX
+  const MACROBLOCKD *xd = &x->e_mbd;
+  const struct macroblockd_plane *const pd = &xd->plane[0];
+  const BLOCK_SIZE plane_bsize =
+      get_plane_block_size(xd->mi[0]->mbmi.sb_type, pd);
+  // TODO(sarahparker) This assumes reduced_tx_set_used == 0. I will do a
+  // follow up refactor to make the actual value of reduced_tx_set_used
+  // available within this function.
+  param.tx_set_type = get_ext_tx_set_type(param.tx_size, plane_bsize,
+                                          is_inter_block(&xd->mi[0]->mbmi), 0);
+#endif  // CONFIG_EXT_TX
 
 #if CONFIG_TXMG
   av1_highbd_fwd_txfm(p->src_diff, DCT_coefs, bw, &param);
@@ -5001,8 +5018,12 @@
   int idx, idy;
   int prune = 0;
 #if CONFIG_EXT_TX
+  const TX_SIZE sqr_up_tx_size =
+      txsize_sqr_up_map[max_txsize_rect_lookup[bsize]];
+  // Get the tx_size 1 level down
+  TX_SIZE min_tx_size = sub_tx_size_map[sqr_up_tx_size];
   const TxSetType tx_set_type = get_ext_tx_set_type(
-      max_tx_size, bsize, is_inter, cm->reduced_tx_set_used);
+      min_tx_size, bsize, is_inter, cm->reduced_tx_set_used);
 #endif  // CONFIG_EXT_TX
   int within_border = (mi_row + mi_size_high[bsize] <= cm->mi_rows) &&
                       (mi_col + mi_size_wide[bsize] <= cm->mi_cols);
@@ -5070,6 +5091,11 @@
 #endif  // CONFIG_EXT_TX && CONFIG_MRC_TX
 #if CONFIG_EXT_TX
     if (!av1_ext_tx_used[tx_set_type][tx_type]) continue;
+    (void)prune;
+// TODO(sarahparker) This speed feature has been temporarily disabled
+// with ext-tx because it is not compatible with the current
+// search method. It will be fixed in a followup.
+/*
     if (is_inter) {
       if (cpi->sf.tx_type_search.prune_mode > NO_PRUNE) {
         if (!do_tx_type_search(tx_type, prune,
@@ -5081,6 +5107,7 @@
         if (tx_type != intra_mode_to_tx_type_context[mbmi->mode]) continue;
       }
     }
+*/
 #else   // CONFIG_EXT_TX
     if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE &&
         !do_tx_type_search(tx_type, prune, cpi->sf.tx_type_search.prune_mode))
@@ -5095,6 +5122,15 @@
 
     rd = select_tx_size_fix_type(cpi, x, &this_rd_stats, bsize, mi_row, mi_col,
                                  ref_best_rd, tx_type);
+#if CONFIG_EXT_TX
+    // If the current tx_type is not included in the tx_set for the smallest
+    // tx size found, then all vartx partitions were actually transformed with
+    // DCT_DCT and we should avoid picking it.
+    const TxSetType min_tx_set_type = get_ext_tx_set_type(
+        mbmi->min_tx_size, bsize, is_inter, cm->reduced_tx_set_used);
+    if (!av1_ext_tx_used[min_tx_set_type][tx_type]) continue;
+#endif  // CONFIG_EXT_TX
+
     ref_best_rd = AOMMIN(rd, ref_best_rd);
     if (rd < best_rd) {
       best_rd = rd;