Use tx_size 1 level down for transform type search

This addresses an inconsistency between the set used
to decode the tx_type in the bitstream and the set used
for the tx_type search. Previously, the set used to
read/write the tx_type was based on the smallest tx_size
in the vartx partitioning, but the search used a set
based on the largest possible tx_size. This patch
changes the tx_type search to use the transform type
set associated with the tx_size 1 recursive level down from
the max square tx_size to make the search more consistent
with the bitstream syntax. If a tx_size is selected for which
the searched tx_type is invalid, DCT_DCT is used for that
partition instead.

This patch also adds assertions to all exposed transform
functions to ensure that no illegal transform type/size
combinations occur.

This currently causes a 0.1% drop in performance on lowres.
The drop is due to the reduction of the tx_types available
for 32x16 and 16x32 transform sizes. Before this patch,
32x16 and 16x32 transforms were getting assigned a
set of 12 tx_types, some of which we did not intend to
support for these sizes.

Change-Id: I44aca4876b261c345623cd04ad6235bca4532701
diff --git a/av1/common/idct.c b/av1/common/idct.c
index 7bd8302..de6e05b 100644
--- a/av1/common/idct.c
+++ b/av1/common/idct.c
@@ -256,6 +256,9 @@
 // apply. Otherwise they return 0
 int get_lgt4(const TxfmParam *txfm_param, int is_col,
              const tran_high_t **lgtmtx) {
+#if CONFIG_EXT_TX
+  assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
+#endif  // CONFIG_EXT_TX
   if (is_col && (vtx_tab[txfm_param->tx_type] == ADST_1D ||
                  vtx_tab[txfm_param->tx_type] == FLIPADST_1D)) {
     lgtmtx[0] = txfm_param->is_inter ? &lgt4_170[0][0] : &lgt4_140[0][0];
@@ -271,6 +274,9 @@
 
 int get_lgt8(const TxfmParam *txfm_param, int is_col,
              const tran_high_t **lgtmtx) {
+#if CONFIG_EXT_TX
+  assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
+#endif  // CONFIG_EXT_TX
   if (is_col && (vtx_tab[txfm_param->tx_type] == ADST_1D ||
                  vtx_tab[txfm_param->tx_type] == FLIPADST_1D)) {
     lgtmtx[0] = txfm_param->is_inter ? &lgt8_170[0][0] : &lgt8_150[0][0];
@@ -388,6 +394,9 @@
 
 void get_lgt4_from_pred(const TxfmParam *txfm_param, int is_col,
                         const tran_high_t **lgtmtx, int ntx) {
+#if CONFIG_EXT_TX
+  assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
+#endif  // CONFIG_EXT_TX
   PREDICTION_MODE mode = txfm_param->mode;
   int stride = txfm_param->stride;
   uint8_t *dst = txfm_param->dst;
@@ -469,6 +478,9 @@
 
 void get_lgt8_from_pred(const TxfmParam *txfm_param, int is_col,
                         const tran_high_t **lgtmtx, int ntx) {
+#if CONFIG_EXT_TX
+  assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
+#endif  // CONFIG_EXT_TX
   PREDICTION_MODE mode = txfm_param->mode;
   int stride = txfm_param->stride;
   uint8_t *dst = txfm_param->dst;
@@ -538,6 +550,9 @@
 // will just call DCT or ADST
 void get_lgt16up_from_pred(const TxfmParam *txfm_param, int is_col,
                            const tran_high_t **lgtmtx, int ntx) {
+#if CONFIG_EXT_TX
+  assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
+#endif  // CONFIG_EXT_TX
   int tx_length = is_col ? tx_size_high[txfm_param->tx_size]
                          : tx_size_wide[txfm_param->tx_size];
   assert(tx_length == 16 || tx_length == 32);
@@ -2414,6 +2429,9 @@
 // idct
 void av1_idct4x4_add(const tran_low_t *input, uint8_t *dest, int stride,
                      const TxfmParam *txfm_param) {
+#if CONFIG_EXT_TX
+  assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
+#endif  // CONFIG_EXT_TX
   const int eob = txfm_param->eob;
   if (eob > 1)
     av1_iht4x4_16_add(input, dest, stride, txfm_param);
@@ -2423,6 +2441,9 @@
 
 void av1_iwht4x4_add(const tran_low_t *input, uint8_t *dest, int stride,
                      const TxfmParam *txfm_param) {
+#if CONFIG_EXT_TX
+  assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
+#endif  // CONFIG_EXT_TX
   const int eob = txfm_param->eob;
   if (eob > 1)
     aom_iwht4x4_16_add(input, dest, stride);
@@ -2897,6 +2918,9 @@
 
 void av1_highbd_inv_txfm_add_4x4(const tran_low_t *input, uint8_t *dest,
                                  int stride, const TxfmParam *txfm_param) {
+#if CONFIG_EXT_TX
+  assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
+#endif  // CONFIG_EXT_TX
   int eob = txfm_param->eob;
   int bd = txfm_param->bd;
   int lossless = txfm_param->lossless;
@@ -2942,6 +2966,9 @@
 
 void av1_highbd_inv_txfm_add_4x8(const tran_low_t *input, uint8_t *dest,
                                  int stride, const TxfmParam *txfm_param) {
+#if CONFIG_EXT_TX
+  assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
+#endif  // CONFIG_EXT_TX
   const int32_t *src = cast_to_int32(input);
   av1_inv_txfm2d_add_4x8_c(src, CONVERT_TO_SHORTPTR(dest), stride,
                            txfm_param->tx_type, txfm_param->bd);
@@ -2949,6 +2976,9 @@
 
 void av1_highbd_inv_txfm_add_8x4(const tran_low_t *input, uint8_t *dest,
                                  int stride, const TxfmParam *txfm_param) {
+#if CONFIG_EXT_TX
+  assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
+#endif  // CONFIG_EXT_TX
   const int32_t *src = cast_to_int32(input);
   av1_inv_txfm2d_add_8x4_c(src, CONVERT_TO_SHORTPTR(dest), stride,
                            txfm_param->tx_type, txfm_param->bd);
@@ -3158,6 +3188,9 @@
 
 void av1_inv_txfm_add(const tran_low_t *input, uint8_t *dest, int stride,
                       TxfmParam *txfm_param) {
+#if CONFIG_EXT_TX
+  assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
+#endif  // CONFIG_EXT_TX
   const TX_SIZE tx_size = txfm_param->tx_size;
 #if CONFIG_LGT_FROM_PRED
   if (txfm_param->use_lgt) {
@@ -3199,13 +3232,27 @@
   }
 }
 
-static void init_txfm_param(const MACROBLOCKD *xd, TX_SIZE tx_size,
-                            TX_TYPE tx_type, int eob, TxfmParam *txfm_param) {
+static void init_txfm_param(const MACROBLOCKD *xd,
+#if CONFIG_EXT_TX
+                            int plane,
+#endif  // CONFIG_EXT_TX
+                            TX_SIZE tx_size, TX_TYPE tx_type, int eob,
+                            TxfmParam *txfm_param) {
   txfm_param->tx_type = tx_type;
   txfm_param->tx_size = tx_size;
   txfm_param->eob = eob;
   txfm_param->lossless = xd->lossless[xd->mi[0]->mbmi.segment_id];
   txfm_param->bd = xd->bd;
+#if CONFIG_EXT_TX
+  const struct macroblockd_plane *const pd = &xd->plane[plane];
+  const BLOCK_SIZE plane_bsize =
+      get_plane_block_size(xd->mi[0]->mbmi.sb_type, pd);
+  // TODO(sarahparker) This assumes reduced_tx_set_used == 0. I will do a
+  // follow up refactor to make the actual value of reduced_tx_set_used
+  // available within this function.
+  txfm_param->tx_set_type = get_ext_tx_set_type(
+      txfm_param->tx_size, plane_bsize, is_inter_block(&xd->mi[0]->mbmi), 0);
+#endif  // CONFIG_EXT_TX
 #if CONFIG_LGT
   txfm_param->is_inter = is_inter_block(&xd->mi[0]->mbmi);
 #endif
@@ -3234,12 +3281,19 @@
 #if CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
                                  uint8_t *mrc_mask,
 #endif  // CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
+#if CONFIG_EXT_TX
+                                 int plane,
+#endif  // CONFIG_EXT_TX
                                  TX_TYPE tx_type, TX_SIZE tx_size, uint8_t *dst,
                                  int stride, int eob) {
   if (!eob) return;
 
   TxfmParam txfm_param;
-  init_txfm_param(xd, tx_size, tx_type, eob, &txfm_param);
+  init_txfm_param(xd,
+#if CONFIG_EXT_TX
+                  plane,
+#endif  // CONFIG_EXT_TX
+                  tx_size, tx_type, eob, &txfm_param);
 #if CONFIG_LGT || CONFIG_MRC_TX
   txfm_param.is_inter = is_inter_block(&xd->mi[0]->mbmi);
 #endif  // CONFIG_LGT || CONFIG_MRC_TX
@@ -3253,6 +3307,9 @@
   txfm_param.mode = mode;
 #endif  // CONFIG_LGT_FROM_PRED
 #endif  // CONFIG_LGT_FROM_PRED || CONFIG_MRC_TX
+#if CONFIG_EXT_TX
+  assert(av1_ext_tx_used[txfm_param.tx_set_type][txfm_param.tx_type]);
+#endif  // CONFIG_EXT_TX
 
   const int is_hbd = get_bitdepth_data_path_index(xd);
 #if CONFIG_TXMG
@@ -3304,12 +3361,18 @@
 #if CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
                               mrc_mask,
 #endif  // CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
+#if CONFIG_EXT_TX
+                              plane,
+#endif  // CONFIG_EXT_TX
                               tx_type, tx_size, dst, dst_stride, eob);
 }
 
 void av1_highbd_inv_txfm_add(const tran_low_t *input, uint8_t *dest, int stride,
                              TxfmParam *txfm_param) {
   const TX_SIZE tx_size = txfm_param->tx_size;
+#if CONFIG_EXT_TX
+  assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
+#endif  // CONFIG_EXT_TX
   switch (tx_size) {
 #if CONFIG_TX64X64
     case TX_64X64: