Refactor the ext-tx experiment

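Use a common structure for inter and intra tx type information when
possible: the separate ext_tx_used_inter/ext_tx_used_intra lookups are
replaced by the shared av1_ext_tx_used table indexed by TxSetType.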

Change-Id: I1fd3bc86033871ffbcc2b496a31dca00b7d64b31
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 019e1f5..cabfd7c 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -2311,7 +2311,7 @@
   }
 }
 
-// #TODO(angiebird): use this function whenever it's possible
+// TODO(angiebird): use this function whenever it's possible
 int av1_tx_type_cost(const AV1_COMMON *cm, const MACROBLOCK *x,
                      const MACROBLOCKD *xd, BLOCK_SIZE bsize, int plane,
                      TX_SIZE tx_size, TX_TYPE tx_type) {
@@ -2442,10 +2442,10 @@
   if (max_tx_size >= TX_32X32 && tx_size == TX_4X4) return 1;
 #if CONFIG_EXT_TX
   const AV1_COMMON *const cm = &cpi->common;
-  int ext_tx_set =
-      get_ext_tx_set(tx_size, bs, is_inter, cm->reduced_tx_set_used);
+  const TxSetType tx_set_type =
+      get_ext_tx_set_type(tx_size, bs, is_inter, cm->reduced_tx_set_used);
+  if (!av1_ext_tx_used[tx_set_type][tx_type]) return 1;
   if (is_inter) {
-    if (!ext_tx_used_inter[ext_tx_set][tx_type]) return 1;
     if (cpi->sf.tx_type_search.prune_mode > NO_PRUNE) {
       if (!do_tx_type_search(tx_type, prune)) return 1;
     }
@@ -2453,7 +2453,6 @@
     if (!ALLOW_INTRA_EXT_TX && bs >= BLOCK_8X8) {
       if (tx_type != intra_mode_to_tx_type_context[mbmi->mode]) return 1;
     }
-    if (!ext_tx_used_intra[ext_tx_set][tx_type]) return 1;
   }
 #else   // CONFIG_EXT_TX
   if (tx_size >= TX_32X32 && tx_type != DCT_DCT) return 1;
@@ -2494,9 +2493,6 @@
   const int is_inter = is_inter_block(mbmi);
   int prune = 0;
   const int plane = 0;
-#if CONFIG_EXT_TX
-  int ext_tx_set;
-#endif  // CONFIG_EXT_TX
   av1_invalid_rd_stats(rd_stats);
 
   mbmi->tx_size = tx_size_from_tx_mode(bs, cm->tx_mode, is_inter);
@@ -2504,8 +2500,10 @@
   mbmi->min_tx_size = get_min_tx_size(mbmi->tx_size);
 #endif  // CONFIG_VAR_TX
 #if CONFIG_EXT_TX
-  ext_tx_set =
+  int ext_tx_set =
       get_ext_tx_set(mbmi->tx_size, bs, is_inter, cm->reduced_tx_set_used);
+  const TxSetType tx_set_type =
+      get_ext_tx_set_type(mbmi->tx_size, bs, is_inter, cm->reduced_tx_set_used);
 #endif  // CONFIG_EXT_TX
 
   if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE)
@@ -2526,12 +2524,12 @@
 #endif  // CONFIG_PVQ
 
     for (tx_type = DCT_DCT; tx_type < TX_TYPES; ++tx_type) {
+      if (!av1_ext_tx_used[tx_set_type][tx_type]) continue;
       RD_STATS this_rd_stats;
       if (is_inter) {
         if (x->use_default_inter_tx_type &&
             tx_type != get_default_tx_type(0, xd, 0, mbmi->tx_size))
           continue;
-        if (!ext_tx_used_inter[ext_tx_set][tx_type]) continue;
         if (cpi->sf.tx_type_search.prune_mode > NO_PRUNE) {
           if (!do_tx_type_search(tx_type, prune)) continue;
         }
@@ -2542,7 +2540,6 @@
         if (!ALLOW_INTRA_EXT_TX && bs >= BLOCK_8X8) {
           if (tx_type != intra_mode_to_tx_type_context[mbmi->mode]) continue;
         }
-        if (!ext_tx_used_intra[ext_tx_set][tx_type]) continue;
       }
 
       mbmi->tx_type = tx_type;
@@ -2696,10 +2693,9 @@
       if (mbmi->ref_mv_idx > 0 && tx_type != DCT_DCT) continue;
       const TX_SIZE rect_tx_size = max_txsize_rect_lookup[bs];
       RD_STATS this_rd_stats;
-      int ext_tx_set =
-          get_ext_tx_set(rect_tx_size, bs, is_inter, cm->reduced_tx_set_used);
-      if ((is_inter && ext_tx_used_inter[ext_tx_set][tx_type]) ||
-          (!is_inter && ext_tx_used_intra[ext_tx_set][tx_type])) {
+      const TxSetType tx_set_type = get_ext_tx_set_type(
+          rect_tx_size, bs, is_inter, cm->reduced_tx_set_used);
+      if (av1_ext_tx_used[tx_set_type][tx_type]) {
         rd = txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs, tx_type,
                       rect_tx_size);
         ref_best_rd = AOMMIN(rd, ref_best_rd);
@@ -2745,10 +2741,9 @@
       if (mbmi->ref_mv_idx > 0 && tx_type != DCT_DCT) continue;
       const TX_SIZE tx_size = quarter_txsize_lookup[bs];
       RD_STATS this_rd_stats;
-      int ext_tx_set =
-          get_ext_tx_set(tx_size, bs, is_inter, cm->reduced_tx_set_used);
-      if ((is_inter && ext_tx_used_inter[ext_tx_set][tx_type]) ||
-          (!is_inter && ext_tx_used_intra[ext_tx_set][tx_type])) {
+      const TxSetType tx_set_type =
+          get_ext_tx_set_type(tx_size, bs, is_inter, cm->reduced_tx_set_used);
+      if (av1_ext_tx_used[tx_set_type][tx_type]) {
         rd =
             txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs, tx_type, tx_size);
         if (rd < best_rd) {
@@ -5271,6 +5266,8 @@
   RD_STATS rd_stats_stack[4];
 #endif  // CONFIG_EXT_PARTITION
 #if CONFIG_EXT_TX
+  const TxSetType tx_set_type = get_ext_tx_set_type(
+      max_tx_size, bsize, is_inter, cm->reduced_tx_set_used);
   const int ext_tx_set =
       get_ext_tx_set(max_tx_size, bsize, is_inter, cm->reduced_tx_set_used);
 #endif  // CONFIG_EXT_TX
@@ -5315,8 +5312,8 @@
       continue;
 #endif  // CONFIG_MRC_TX
 #if CONFIG_EXT_TX
+    if (!av1_ext_tx_used[tx_set_type][tx_type]) continue;
     if (is_inter) {
-      if (!ext_tx_used_inter[ext_tx_set][tx_type]) continue;
       if (cpi->sf.tx_type_search.prune_mode > NO_PRUNE) {
         if (!do_tx_type_search(tx_type, prune)) continue;
       }
@@ -5324,7 +5321,6 @@
       if (!ALLOW_INTRA_EXT_TX && bsize >= BLOCK_8X8) {
         if (tx_type != intra_mode_to_tx_type_context[mbmi->mode]) continue;
       }
-      if (!ext_tx_used_intra[ext_tx_set][tx_type]) continue;
     }
 #else   // CONFIG_EXT_TX
     if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE &&