Refactor the ext-tx experiment

Use common structure for inter and intra tx type information when
possible.

Change-Id: I1fd3bc86033871ffbcc2b496a31dca00b7d64b31
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index c381e3a..a9a1f6a 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -933,75 +933,142 @@
 #if CONFIG_EXT_TX
 #define ALLOW_INTRA_EXT_TX 1
 
-typedef enum {
-  // DCT only
-  EXT_TX_SET_DCTONLY = 0,
-  // DCT + Identity only
-  EXT_TX_SET_DCT_IDTX,
+// Number of transform types in each set type
+static const int av1_num_ext_tx_set[EXT_TX_SET_TYPES] = {
+  1, 2,
 #if CONFIG_MRC_TX
-  // DCT + MRC_DCT
+  2, 3,
+#endif  // CONFIG_MRC_TX
+  5, 7, 12, 16,
+};
+
+// Maps intra set index to the set type
+static const int av1_ext_tx_set_type_intra[EXT_TX_SETS_INTRA] = {
+  EXT_TX_SET_DCTONLY, EXT_TX_SET_DTT4_IDTX_1DDCT, EXT_TX_SET_DTT4_IDTX,
+#if CONFIG_MRC_TX
   EXT_TX_SET_MRC_DCT,
-  // DCT + MRC_DCT + IDTX
+#endif  // CONFIG_MRC_TX
+};
+
+// Maps inter set index to the set type
+static const int av1_ext_tx_set_type_inter[EXT_TX_SETS_INTER] = {
+  EXT_TX_SET_DCTONLY,         EXT_TX_SET_ALL16,
+  EXT_TX_SET_DTT9_IDTX_1DDCT, EXT_TX_SET_DCT_IDTX,
+#if CONFIG_MRC_TX
   EXT_TX_SET_MRC_DCT_IDTX,
 #endif  // CONFIG_MRC_TX
-  // Discrete Trig transforms w/o flip (4) + Identity (1)
-  EXT_TX_SET_DTT4_IDTX,
-  // Discrete Trig transforms w/o flip (4) + Identity (1) + 1D Hor/vert DCT (2)
-  EXT_TX_SET_DTT4_IDTX_1DDCT,
-  // Discrete Trig transforms w/ flip (9) + Identity (1) + 1D Hor/Ver DCT (2)
-  EXT_TX_SET_DTT9_IDTX_1DDCT,
-  // Discrete Trig transforms w/ flip (9) + Identity (1) + 1D Hor/Ver (6)
-  EXT_TX_SET_ALL16,
-  EXT_TX_SET_TYPES
-} TxSetType;
+};
 
+// Maps set types above to the indices used for intra
+static const int ext_tx_set_index_intra[EXT_TX_SET_TYPES] = {
+  0, -1,
 #if CONFIG_MRC_TX
-// Number of transform types in each set type
-static const int num_ext_tx_set[EXT_TX_SET_TYPES] = {
-  1, 2, 2, 3, 5, 7, 12, 16
+  3, -1,
+#endif  // CONFIG_MRC_TX
+  2, 1,  -1, -1,
 };
 
-// Maps intra set index to the set type
-static const int ext_tx_set_type_intra[EXT_TX_SETS_INTRA] = {
-  EXT_TX_SET_DCTONLY, EXT_TX_SET_DTT4_IDTX_1DDCT, EXT_TX_SET_DTT4_IDTX,
-  EXT_TX_SET_MRC_DCT
-};
-
-// Maps inter set index to the set type
-static const int ext_tx_set_type_inter[EXT_TX_SETS_INTER] = {
-  EXT_TX_SET_DCTONLY, EXT_TX_SET_ALL16, EXT_TX_SET_DTT9_IDTX_1DDCT,
-  EXT_TX_SET_DCT_IDTX, EXT_TX_SET_MRC_DCT_IDTX
-};
-
-// Maps set types above to the indices used for intra
-static const int ext_tx_set_index_intra[EXT_TX_SET_TYPES] = { 0, -1, 3,  -1,
-                                                              2, 1,  -1, -1 };
-
-// Maps set types above to the indices used for inter
-static const int ext_tx_set_index_inter[EXT_TX_SET_TYPES] = { 0,  3,  -1, 4,
-                                                              -1, -1, 2,  1 };
-#else   // CONFIG_MRC_TX
-// Number of transform types in each set type
-static const int num_ext_tx_set[EXT_TX_SET_TYPES] = { 1, 2, 5, 7, 12, 16 };
-
-// Maps intra set index to the set type
-static const int ext_tx_set_type_intra[EXT_TX_SETS_INTRA] = {
-  EXT_TX_SET_DCTONLY, EXT_TX_SET_DTT4_IDTX_1DDCT, EXT_TX_SET_DTT4_IDTX
-};
-
-// Maps inter set index to the set type
-static const int ext_tx_set_type_inter[EXT_TX_SETS_INTER] = {
-  EXT_TX_SET_DCTONLY, EXT_TX_SET_ALL16, EXT_TX_SET_DTT9_IDTX_1DDCT,
-  EXT_TX_SET_DCT_IDTX
-};
-
-// Maps set types above to the indices used for intra
-static const int ext_tx_set_index_intra[EXT_TX_SET_TYPES] = { 0, -1, 2,
-                                                              1, -1, -1 };
-
 // Maps set types above to the indices used for inter
 static const int ext_tx_set_index_inter[EXT_TX_SET_TYPES] = {
-  0, 3, -1, -1, 2, 1
+  0,  3,
+#if CONFIG_MRC_TX
+  -1, 4,
+#endif  // CONFIG_MRC_TX
+  -1, -1, 2, 1,
+};
+
+static const int use_intra_ext_tx_for_txsize[EXT_TX_SETS_INTRA][EXT_TX_SIZES] =
+    {
+#if CONFIG_CHROMA_2X2
+      { 1, 1, 1, 1, 1 },  // unused
+      { 0, 1, 1, 0, 0 },
+      { 0, 0, 0, 1, 0 },
+#if CONFIG_MRC_TX
+      { 0, 0, 0, 0, 1 },
+#endif  // CONFIG_MRC_TX
+#else   // CONFIG_CHROMA_2X2
+      { 1, 1, 1, 1 },  // unused
+      { 1, 1, 0, 0 },
+      { 0, 0, 1, 0 },
+#if CONFIG_MRC_TX
+      { 0, 0, 0, 1 },
+#endif  // CONFIG_MRC_TX
+#endif  // CONFIG_CHROMA_2X2
+    };
+
+static const int use_inter_ext_tx_for_txsize[EXT_TX_SETS_INTER][EXT_TX_SIZES] =
+    {
+#if CONFIG_CHROMA_2X2
+      { 1, 1, 1, 1, 1 },  // unused
+      { 0, 1, 1, 0, 0 }, { 0, 0, 0, 1, 0 }, { 0, 0, 0, 0, 1 },
+#if CONFIG_MRC_TX
+      { 0, 0, 0, 0, 1 },
+#endif  // CONFIG_MRC_TX
+#else   // CONFIG_CHROMA_2X2
+      { 1, 1, 1, 1 },  // unused
+      { 1, 1, 0, 0 }, { 0, 0, 1, 0 }, { 0, 0, 0, 1 },
+#if CONFIG_MRC_TX
+      { 0, 0, 0, 1 },
+#endif  // CONFIG_MRC_TX
+#endif  // CONFIG_CHROMA_2X2
+    };
+
+// 1D Transforms used in inter set, this needs to be changed if
+// ext_tx_used_inter is changed
+static const int ext_tx_used_inter_1D[EXT_TX_SETS_INTER][TX_TYPES_1D] = {
+  { 1, 0, 0, 0 }, { 1, 1, 1, 1 }, { 1, 1, 1, 1 }, { 1, 0, 0, 1 },
+#if CONFIG_MRC_TX
+  { 1, 0, 0, 1 },
+#endif  // CONFIG_MRC_TX
+};
+
+#if CONFIG_MRC_TX
+static const int av1_ext_tx_used[EXT_TX_SET_TYPES][TX_TYPES] = {
+  {
+      1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+  },
+  {
+      1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
+  },
+  {
+      1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
+  },
+  {
+      1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1,
+  },
+  {
+      1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
+  },
+  {
+      1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0,
+  },
+  {
+      1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
+  },
+  {
+      1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
+  },
+};
+#else   // CONFIG_MRC_TX
+static const int av1_ext_tx_used[EXT_TX_SET_TYPES][TX_TYPES] = {
+  {
+      1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+  },
+  {
+      1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
+  },
+  {
+      1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
+  },
+  {
+      1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0,
+  },
+  {
+      1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
+  },
+  {
+      1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+  },
 };
 #endif  // CONFIG_MRC_TX
 
@@ -1043,124 +1110,11 @@
                   : ext_tx_set_index_intra[set_type];
 }
 
-#if CONFIG_MRC_TX
-static const int use_intra_ext_tx_for_txsize[EXT_TX_SETS_INTRA][EXT_TX_SIZES] =
-    {
-#if CONFIG_CHROMA_2X2
-      { 1, 1, 1, 1, 1 },  // unused
-      { 0, 1, 1, 0, 0 },
-      { 0, 0, 0, 1, 0 },
-      { 0, 0, 0, 0, 1 },
-#else
-      { 1, 1, 1, 1 },  // unused
-      { 1, 1, 0, 0 },
-      { 0, 0, 1, 0 },
-      { 0, 0, 0, 1 },
-#endif  // CONFIG_CHROMA_2X2
-    };
-
-static const int use_inter_ext_tx_for_txsize[EXT_TX_SETS_INTER][EXT_TX_SIZES] =
-    {
-#if CONFIG_CHROMA_2X2
-      { 1, 1, 1, 1, 1 },  // unused
-      { 0, 1, 1, 0, 0 }, { 0, 0, 0, 1, 0 },
-      { 0, 0, 0, 0, 1 }, { 0, 0, 0, 0, 1 },
-#else
-      { 1, 1, 1, 1 },  // unused
-      { 1, 1, 0, 0 }, { 0, 0, 1, 0 }, { 0, 0, 0, 1 }, { 0, 0, 0, 1 },
-#endif  // CONFIG_CHROMA_2X2
-    };
-
-// Transform types used in each intra set
-static const int ext_tx_used_intra[EXT_TX_SETS_INTRA][TX_TYPES] = {
-  { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
-  { 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0 },
-  { 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 },
-  { 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1 },
-};
-
-// Numbers of transform types used in each intra set
-static const int ext_tx_cnt_intra[EXT_TX_SETS_INTRA] = { 1, 7, 5, 2 };
-
-// Transform types used in each inter set
-static const int ext_tx_used_inter[EXT_TX_SETS_INTER][TX_TYPES] = {
-  { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
-  { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0 },
-  { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0 },
-  { 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 },
-  { 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1 },
-};
-
-// Numbers of transform types used in each inter set
-static const int ext_tx_cnt_inter[EXT_TX_SETS_INTER] = { 1, 16, 12, 2, 3 };
-
-// 1D Transforms used in inter set, this needs to be changed if
-// ext_tx_used_inter is changed
-static const int ext_tx_used_inter_1D[EXT_TX_SETS_INTER][TX_TYPES_1D] = {
-  { 1, 0, 0, 0 }, { 1, 1, 1, 1 }, { 1, 1, 1, 1 }, { 1, 0, 0, 1 }, { 1, 0, 0, 1 }
-};
-#else  // CONFIG_MRC_TX
-static const int use_intra_ext_tx_for_txsize[EXT_TX_SETS_INTRA][EXT_TX_SIZES] =
-    {
-#if CONFIG_CHROMA_2X2
-      { 1, 1, 1, 1, 1 },  // unused
-      { 0, 1, 1, 0, 0 },
-      { 0, 0, 0, 1, 0 },
-#else
-      { 1, 1, 1, 1 },  // unused
-      { 1, 1, 0, 0 },
-      { 0, 0, 1, 0 },
-#endif  // CONFIG_CHROMA_2X2
-    };
-
-static const int use_inter_ext_tx_for_txsize[EXT_TX_SETS_INTER][EXT_TX_SIZES] =
-    {
-#if CONFIG_CHROMA_2X2
-      { 1, 1, 1, 1, 1 },  // unused
-      { 0, 1, 1, 0, 0 },
-      { 0, 0, 0, 1, 0 },
-      { 0, 0, 0, 0, 1 },
-#else
-      { 1, 1, 1, 1 },  // unused
-      { 1, 1, 0, 0 },
-      { 0, 0, 1, 0 },
-      { 0, 0, 0, 1 },
-#endif  // CONFIG_CHROMA_2X2
-    };
-
-// Transform types used in each intra set
-static const int ext_tx_used_intra[EXT_TX_SETS_INTRA][TX_TYPES] = {
-  { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
-  { 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0 },
-  { 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0 },
-};
-
-// Numbers of transform types used in each intra set
-static const int ext_tx_cnt_intra[EXT_TX_SETS_INTRA] = { 1, 7, 5 };
-
-// Transform types used in each inter set
-static const int ext_tx_used_inter[EXT_TX_SETS_INTER][TX_TYPES] = {
-  { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
-  { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 },
-  { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0 },
-  { 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0 },
-};
-
-// Numbers of transform types used in each inter set
-static const int ext_tx_cnt_inter[EXT_TX_SETS_INTER] = { 1, 16, 12, 2 };
-
-// 1D Transforms used in inter set, this needs to be changed if
-// ext_tx_used_inter is changed
-static const int ext_tx_used_inter_1D[EXT_TX_SETS_INTER][TX_TYPES_1D] = {
-  { 1, 0, 0, 0 }, { 1, 1, 1, 1 }, { 1, 1, 1, 1 }, { 1, 0, 0, 1 },
-};
-#endif  // CONFIG_MRC_TX
-
 static INLINE int get_ext_tx_types(TX_SIZE tx_size, BLOCK_SIZE bs, int is_inter,
                                    int use_reduced_set) {
   const int set_type =
       get_ext_tx_set_type(tx_size, bs, is_inter, use_reduced_set);
-  return num_ext_tx_set[set_type];
+  return av1_num_ext_tx_set[set_type];
 }
 
 #if CONFIG_RECT_TX
diff --git a/av1/common/entropymode.c b/av1/common/entropymode.c
index 05bf08c..f30f3fd 100644
--- a/av1/common/entropymode.c
+++ b/av1/common/entropymode.c
@@ -1409,10 +1409,8 @@
 #endif
 #endif
 #if CONFIG_EXT_TX
-int av1_ext_tx_intra_ind[EXT_TX_SETS_INTRA][TX_TYPES];
-int av1_ext_tx_intra_inv[EXT_TX_SETS_INTRA][TX_TYPES];
-int av1_ext_tx_inter_ind[EXT_TX_SETS_INTER][TX_TYPES];
-int av1_ext_tx_inter_inv[EXT_TX_SETS_INTER][TX_TYPES];
+int av1_ext_tx_ind[EXT_TX_SET_TYPES][TX_TYPES];
+int av1_ext_tx_inv[EXT_TX_SET_TYPES][TX_TYPES];
 #endif
 
 #if CONFIG_SMOOTH_HV
@@ -2485,70 +2483,59 @@
 
 #if CONFIG_EXT_TX
 /* clang-format off */
-const aom_tree_index av1_ext_tx_inter_tree[EXT_TX_SETS_INTER]
-                                           [TREE_SIZE(TX_TYPES)] = {
-  { // ToDo(yaowu): remove used entry 0.
-    0
-  }, {
-    -IDTX, 2,
-    4, 14,
-    6, 8,
-    -V_DCT, -H_DCT,
-    10, 12,
-    -V_ADST, -H_ADST,
-    -V_FLIPADST, -H_FLIPADST,
-    -DCT_DCT, 16,
-    18, 24,
-    20, 22,
-    -ADST_DCT, -DCT_ADST,
-    -FLIPADST_DCT, -DCT_FLIPADST,
-    26, 28,
-    -ADST_ADST, -FLIPADST_FLIPADST,
-    -ADST_FLIPADST, -FLIPADST_ADST
-  }, {
-    -IDTX, 2,
-    4, 6,
-    -V_DCT, -H_DCT,
-    -DCT_DCT, 8,
-    10, 16,
-    12, 14,
-    -ADST_DCT, -DCT_ADST,
-    -FLIPADST_DCT, -DCT_FLIPADST,
-    18, 20,
-    -ADST_ADST, -FLIPADST_FLIPADST,
-    -ADST_FLIPADST, -FLIPADST_ADST
-  }, {
-    -IDTX, -DCT_DCT,
-  },
+const aom_tree_index av1_ext_tx_tree[EXT_TX_SET_TYPES][TREE_SIZE(TX_TYPES)] = {
+    // TODO(yaowu@google.com): remove used entry 0.
+    { 0 },
+    { -IDTX, -DCT_DCT, },
 #if CONFIG_MRC_TX
-  {
-    -IDTX, 2, -DCT_DCT, -MRC_DCT,
-  }
+    { -DCT_DCT, -MRC_DCT, },
+    {   -IDTX, 2,
+        -DCT_DCT, -MRC_DCT, },
 #endif  // CONFIG_MRC_TX
-};
-
-const aom_tree_index av1_ext_tx_intra_tree[EXT_TX_SETS_INTRA]
-                                           [TREE_SIZE(TX_TYPES)] = {
-  {  // ToDo(yaowu): remove unused entry 0.
-    0
-  }, {
-    -IDTX, 2,
-    -DCT_DCT, 4,
-    6, 8,
-    -V_DCT, -H_DCT,
-    -ADST_ADST, 10,
-    -ADST_DCT, -DCT_ADST,
-  }, {
-    -IDTX, 2,
-    -DCT_DCT, 4,
-    -ADST_ADST, 6,
-    -ADST_DCT, -DCT_ADST,
-  },
-#if CONFIG_MRC_TX
-  {
-    -DCT_DCT, -MRC_DCT,
-  }
-#endif  // CONFIG_MRC_TX
+    {
+        -IDTX, 2,
+        -DCT_DCT, 4,
+        -ADST_ADST, 6,
+        -ADST_DCT, -DCT_ADST,
+    },
+    {
+        -IDTX, 2,
+        -DCT_DCT, 4,
+        6, 8,
+        -V_DCT, -H_DCT,
+        -ADST_ADST, 10,
+        -ADST_DCT, -DCT_ADST,
+    },
+    {
+        -IDTX, 2,
+        4, 6,
+        -V_DCT, -H_DCT,
+        -DCT_DCT, 8,
+        10, 16,
+        12, 14,
+        -ADST_DCT, -DCT_ADST,
+        -FLIPADST_DCT, -DCT_FLIPADST,
+        18, 20,
+        -ADST_ADST, -FLIPADST_FLIPADST,
+        -ADST_FLIPADST, -FLIPADST_ADST,
+    },
+    {
+        -IDTX, 2,
+        4, 14,
+        6, 8,
+        -V_DCT, -H_DCT,
+        10, 12,
+        -V_ADST, -H_ADST,
+        -V_FLIPADST, -H_FLIPADST,
+        -DCT_DCT, 16,
+        18, 24,
+        20, 22,
+        -ADST_DCT, -DCT_ADST,
+        -FLIPADST_DCT, -DCT_FLIPADST,
+        26, 28,
+        -ADST_ADST, -FLIPADST_FLIPADST,
+        -ADST_FLIPADST, -FLIPADST_ADST,
+    },
 };
 /* clang-format on */
 
@@ -5531,17 +5518,20 @@
     int s;
     for (s = 1; s < EXT_TX_SETS_INTER; ++s) {
       if (use_inter_ext_tx_for_txsize[s][i]) {
-        aom_tree_merge_probs(
-            av1_ext_tx_inter_tree[s], pre_fc->inter_ext_tx_prob[s][i],
-            counts->inter_ext_tx[s][i], fc->inter_ext_tx_prob[s][i]);
+        aom_tree_merge_probs(av1_ext_tx_tree[av1_ext_tx_set_type_inter[s]],
+                             pre_fc->inter_ext_tx_prob[s][i],
+                             counts->inter_ext_tx[s][i],
+                             fc->inter_ext_tx_prob[s][i]);
       }
     }
     for (s = 1; s < EXT_TX_SETS_INTRA; ++s) {
       if (use_intra_ext_tx_for_txsize[s][i]) {
-        for (j = 0; j < INTRA_MODES; ++j)
-          aom_tree_merge_probs(
-              av1_ext_tx_intra_tree[s], pre_fc->intra_ext_tx_prob[s][i][j],
-              counts->intra_ext_tx[s][i][j], fc->intra_ext_tx_prob[s][i][j]);
+        for (j = 0; j < INTRA_MODES; ++j) {
+          aom_tree_merge_probs(av1_ext_tx_tree[av1_ext_tx_set_type_intra[s]],
+                               pre_fc->intra_ext_tx_prob[s][i][j],
+                               counts->intra_ext_tx[s][i][j],
+                               fc->intra_ext_tx_prob[s][i][j]);
+        }
       }
     }
   }
diff --git a/av1/common/entropymode.h b/av1/common/entropymode.h
index 05440e8..eb0dfa9 100644
--- a/av1/common/entropymode.h
+++ b/av1/common/entropymode.h
@@ -551,10 +551,8 @@
 extern const int av1_intra_mode_ind[INTRA_MODES];
 extern const int av1_intra_mode_inv[INTRA_MODES];
 #if CONFIG_EXT_TX
-extern int av1_ext_tx_intra_ind[EXT_TX_SETS_INTRA][TX_TYPES];
-extern int av1_ext_tx_intra_inv[EXT_TX_SETS_INTRA][TX_TYPES];
-extern int av1_ext_tx_inter_ind[EXT_TX_SETS_INTER][TX_TYPES];
-extern int av1_ext_tx_inter_inv[EXT_TX_SETS_INTER][TX_TYPES];
+extern int av1_ext_tx_ind[EXT_TX_SET_TYPES][TX_TYPES];
+extern int av1_ext_tx_inv[EXT_TX_SET_TYPES][TX_TYPES];
 #endif
 
 #if CONFIG_EXT_INTER
@@ -588,6 +586,8 @@
                                                  [TREE_SIZE(TX_TYPES)];
 extern const aom_tree_index av1_ext_tx_intra_tree[EXT_TX_SETS_INTRA]
                                                  [TREE_SIZE(TX_TYPES)];
+extern const aom_tree_index av1_ext_tx_tree[EXT_TX_SET_TYPES]
+                                           [TREE_SIZE(TX_TYPES)];
 #else
 extern const aom_tree_index av1_ext_tx_tree[TREE_SIZE(TX_TYPES)];
 #endif  // CONFIG_EXT_TX
diff --git a/av1/common/enums.h b/av1/common/enums.h
index 4465735..6857818 100644
--- a/av1/common/enums.h
+++ b/av1/common/enums.h
@@ -292,6 +292,28 @@
 } TX_TYPE;
 
 #if CONFIG_EXT_TX
+typedef enum {
+  // DCT only
+  EXT_TX_SET_DCTONLY = 0,
+  // DCT + Identity only
+  EXT_TX_SET_DCT_IDTX,
+#if CONFIG_MRC_TX
+  // DCT + MRC_DCT
+  EXT_TX_SET_MRC_DCT,
+  // DCT + MRC_DCT + IDTX
+  EXT_TX_SET_MRC_DCT_IDTX,
+#endif  // CONFIG_MRC_TX
+  // Discrete Trig transforms w/o flip (4) + Identity (1)
+  EXT_TX_SET_DTT4_IDTX,
+  // Discrete Trig transforms w/o flip (4) + Identity (1) + 1D Hor/vert DCT (2)
+  EXT_TX_SET_DTT4_IDTX_1DDCT,
+  // Discrete Trig transforms w/ flip (9) + Identity (1) + 1D Hor/Ver DCT (2)
+  EXT_TX_SET_DTT9_IDTX_1DDCT,
+  // Discrete Trig transforms w/ flip (9) + Identity (1) + 1D Hor/Ver (6)
+  EXT_TX_SET_ALL16,
+  EXT_TX_SET_TYPES
+} TxSetType;
+
 #define IS_2D_TRANSFORM(tx_type) (tx_type < IDTX)
 #else
 #define IS_2D_TRANSFORM(tx_type) 1
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index 00501eb..b19332a 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -2468,10 +2468,12 @@
         const int eset =
             get_ext_tx_set(supertx_size, bsize, 1, cm->reduced_tx_set_used);
         if (eset > 0) {
+          const TxSetType tx_set_type = get_ext_tx_set_type(
+              supertx_size, bsize, 1, cm->reduced_tx_set_used);
           const int packed_sym =
               aom_read_symbol(r, ec_ctx->inter_ext_tx_cdf[eset][supertx_size],
-                              ext_tx_cnt_inter[eset], ACCT_STR);
-          txfm = av1_ext_tx_inter_inv[eset][packed_sym];
+                              av1_num_ext_tx_set[tx_set_type], ACCT_STR);
+          txfm = av1_ext_tx_inv[tx_set_type][packed_sym];
           if (xd->counts) ++xd->counts->inter_ext_tx[eset][supertx_size][txfm];
         }
       }
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index dc8eff7..f293f81 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -985,6 +985,8 @@
         !supertx_enabled &&
 #endif  // CONFIG_SUPERTX
         !segfeature_active(&cm->seg, mbmi->segment_id, SEG_LVL_SKIP)) {
+      const TxSetType tx_set_type = get_ext_tx_set_type(
+          tx_size, mbmi->sb_type, inter_block, cm->reduced_tx_set_used);
       const int eset = get_ext_tx_set(tx_size, mbmi->sb_type, inter_block,
                                       cm->reduced_tx_set_used);
       // eset == 0 should correspond to a set with only DCT_DCT and
@@ -993,14 +995,14 @@
       FRAME_COUNTS *counts = xd->counts;
 
       if (inter_block) {
-        *tx_type = av1_ext_tx_inter_inv[eset][aom_read_symbol(
+        *tx_type = av1_ext_tx_inv[tx_set_type][aom_read_symbol(
             r, ec_ctx->inter_ext_tx_cdf[eset][square_tx_size],
-            ext_tx_cnt_inter[eset], ACCT_STR)];
+            av1_num_ext_tx_set[tx_set_type], ACCT_STR)];
         if (counts) ++counts->inter_ext_tx[eset][square_tx_size][*tx_type];
       } else if (ALLOW_INTRA_EXT_TX) {
-        *tx_type = av1_ext_tx_intra_inv[eset][aom_read_symbol(
+        *tx_type = av1_ext_tx_inv[tx_set_type][aom_read_symbol(
             r, ec_ctx->intra_ext_tx_cdf[eset][square_tx_size][mbmi->mode],
-            ext_tx_cnt_intra[eset], ACCT_STR)];
+            av1_num_ext_tx_set[tx_set_type], ACCT_STR)];
         if (counts)
           ++counts->intra_ext_tx[eset][square_tx_size][mbmi->mode][*tx_type];
       }
diff --git a/av1/decoder/decoder.c b/av1/decoder/decoder.c
index d983a0f..ded933e 100644
--- a/av1/decoder/decoder.c
+++ b/av1/decoder/decoder.c
@@ -55,13 +55,10 @@
     av1_indices_from_tree(av1_switchable_interp_ind, av1_switchable_interp_inv,
                           av1_switchable_interp_tree);
 #if CONFIG_EXT_TX
-    int s;
-    for (s = 1; s < EXT_TX_SETS_INTRA; ++s)
-      av1_indices_from_tree(av1_ext_tx_intra_ind[s], av1_ext_tx_intra_inv[s],
-                            av1_ext_tx_intra_tree[s]);
-    for (s = 1; s < EXT_TX_SETS_INTER; ++s)
-      av1_indices_from_tree(av1_ext_tx_inter_ind[s], av1_ext_tx_inter_inv[s],
-                            av1_ext_tx_inter_tree[s]);
+    for (int s = 1; s < EXT_TX_SET_TYPES; ++s) {
+      av1_indices_from_tree(av1_ext_tx_ind[s], av1_ext_tx_inv[s],
+                            av1_ext_tx_tree[s]);
+    }
 #else
     av1_indices_from_tree(av1_ext_tx_ind, av1_ext_tx_inv, av1_ext_tx_tree);
 #endif
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 88a31bb..0dcf735 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -121,16 +121,24 @@
 void av1_encode_token_init(void) {
 #if CONFIG_EXT_TX
   int s;
-#endif  // CONFIG_EXT_TX
-#if CONFIG_EXT_TX
   for (s = 1; s < EXT_TX_SETS_INTER; ++s) {
-    av1_tokens_from_tree(ext_tx_inter_encodings[s], av1_ext_tx_inter_tree[s]);
+    av1_tokens_from_tree(ext_tx_inter_encodings[s],
+                         av1_ext_tx_tree[av1_ext_tx_set_type_inter[s]]);
   }
   for (s = 1; s < EXT_TX_SETS_INTRA; ++s) {
-    av1_tokens_from_tree(ext_tx_intra_encodings[s], av1_ext_tx_intra_tree[s]);
+    av1_tokens_from_tree(ext_tx_intra_encodings[s],
+                         av1_ext_tx_tree[av1_ext_tx_set_type_intra[s]]);
+  }
+  for (s = 1; s < EXT_TX_SET_TYPES; ++s) {
+    av1_indices_from_tree(av1_ext_tx_ind[s], av1_ext_tx_inv[s],
+                          av1_ext_tx_tree[s]);
   }
 #else
   av1_tokens_from_tree(ext_tx_encodings, av1_ext_tx_tree);
+  /* This hack is necessary because the four TX_TYPES are not consecutive,
+      e.g., 0, 1, 2, 3, when doing an in-order traversal of the av1_ext_tx_tree
+      structure. */
+  av1_indices_from_tree(av1_ext_tx_ind, av1_ext_tx_inv, av1_ext_tx_tree);
 #endif  // CONFIG_EXT_TX
 
 #if CONFIG_EXT_INTRA && CONFIG_INTRA_INTERP
@@ -158,19 +166,6 @@
       an in-order traversal of the av1_switchable_interp_tree structure. */
   av1_indices_from_tree(av1_switchable_interp_ind, av1_switchable_interp_inv,
                         av1_switchable_interp_tree);
-/* This hack is necessary because the four TX_TYPES are not consecutive,
-    e.g., 0, 1, 2, 3, when doing an in-order traversal of the av1_ext_tx_tree
-    structure. */
-#if CONFIG_EXT_TX
-  for (s = 1; s < EXT_TX_SETS_INTRA; ++s)
-    av1_indices_from_tree(av1_ext_tx_intra_ind[s], av1_ext_tx_intra_inv[s],
-                          av1_ext_tx_intra_tree[s]);
-  for (s = 1; s < EXT_TX_SETS_INTER; ++s)
-    av1_indices_from_tree(av1_ext_tx_inter_ind[s], av1_ext_tx_inter_inv[s],
-                          av1_ext_tx_inter_tree[s]);
-#else
-  av1_indices_from_tree(av1_ext_tx_ind, av1_ext_tx_inv, av1_ext_tx_tree);
-#endif
 }
 
 static void write_intra_mode_kf(const AV1_COMMON *cm, FRAME_CONTEXT *frame_ctx,
@@ -1607,23 +1602,23 @@
       if (tx_type == MRC_DCT)
         assert(mbmi->valid_mrc_mask && "Invalid MRC mask");
 #endif  // CONFIG_MRC_TX
-
+      const TxSetType tx_set_type = get_ext_tx_set_type(
+          tx_size, bsize, is_inter, cm->reduced_tx_set_used);
       const int eset =
           get_ext_tx_set(tx_size, bsize, is_inter, cm->reduced_tx_set_used);
       // eset == 0 should correspond to a set with only DCT_DCT and there
       // is no need to send the tx_type
       assert(eset > 0);
+      assert(av1_ext_tx_used[tx_set_type][tx_type]);
       if (is_inter) {
-        assert(ext_tx_used_inter[eset][tx_type]);
-        aom_write_symbol(w, av1_ext_tx_inter_ind[eset][tx_type],
+        aom_write_symbol(w, av1_ext_tx_ind[tx_set_type][tx_type],
                          ec_ctx->inter_ext_tx_cdf[eset][square_tx_size],
-                         ext_tx_cnt_inter[eset]);
+                         av1_num_ext_tx_set[tx_set_type]);
       } else if (ALLOW_INTRA_EXT_TX) {
-        assert(ext_tx_used_intra[eset][tx_type]);
         aom_write_symbol(
-            w, av1_ext_tx_intra_ind[eset][tx_type],
+            w, av1_ext_tx_ind[tx_set_type][tx_type],
             ec_ctx->intra_ext_tx_cdf[eset][square_tx_size][mbmi->mode],
-            ext_tx_cnt_intra[eset]);
+            av1_num_ext_tx_set[tx_set_type]);
       }
     }
 #else
@@ -3082,10 +3077,12 @@
         !skip) {
       const int eset =
           get_ext_tx_set(supertx_size, bsize, 1, cm->reduced_tx_set_used);
+      const int tx_set_type =
+          get_ext_tx_set_type(supertx_size, bsize, 1, cm->reduced_tx_set_used);
       if (eset > 0) {
-        aom_write_symbol(w, av1_ext_tx_inter_ind[eset][mbmi->tx_type],
+        aom_write_symbol(w, av1_ext_tx_ind[tx_set_type][mbmi->tx_type],
                          ec_ctx->inter_ext_tx_cdf[eset][supertx_size],
-                         ext_tx_cnt_inter[eset]);
+                         av1_num_ext_tx_set[tx_set_type]);
       }
     }
 #else
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 201f858..40f0064 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -6073,16 +6073,20 @@
     const int eset =
         get_ext_tx_set(tx_size, bsize, is_inter, cm->reduced_tx_set_used);
     if (eset > 0) {
+      const TxSetType tx_set_type = get_ext_tx_set_type(
+          tx_size, bsize, is_inter, cm->reduced_tx_set_used);
       if (is_inter) {
         update_cdf(fc->inter_ext_tx_cdf[eset][txsize_sqr_map[tx_size]],
-                   av1_ext_tx_inter_ind[eset][tx_type], ext_tx_cnt_inter[eset]);
+                   av1_ext_tx_ind[tx_set_type][tx_type],
+                   av1_num_ext_tx_set[tx_set_type]);
         ++counts->inter_ext_tx[eset][txsize_sqr_map[tx_size]][tx_type];
       } else {
         ++counts->intra_ext_tx[eset][txsize_sqr_map[tx_size]][mbmi->mode]
                               [tx_type];
         update_cdf(
             fc->intra_ext_tx_cdf[eset][txsize_sqr_map[tx_size]][mbmi->mode],
-            av1_ext_tx_intra_ind[eset][tx_type], ext_tx_cnt_intra[eset]);
+            av1_ext_tx_ind[tx_set_type][tx_type],
+            av1_num_ext_tx_set[tx_set_type]);
       }
     }
   }
@@ -7327,9 +7331,6 @@
   TX_SIZE tx_size;
   MB_MODE_INFO *mbmi;
   TX_TYPE tx_type, best_tx_nostx;
-#if CONFIG_EXT_TX
-  int ext_tx_set;
-#endif  // CONFIG_EXT_TX
   int tmp_rate_tx = 0, skip_tx = 0;
   int64_t tmp_dist_tx = 0, rd_tx, bestrd_tx = INT64_MAX;
 
@@ -7399,7 +7400,9 @@
   tx_size = max_txsize_lookup[bsize];
   av1_subtract_plane(x, bsize, 0);
 #if CONFIG_EXT_TX
-  ext_tx_set = get_ext_tx_set(tx_size, bsize, 1, cm->reduced_tx_set_used);
+  int ext_tx_set = get_ext_tx_set(tx_size, bsize, 1, cm->reduced_tx_set_used);
+  const TxSetType tx_set_type =
+      get_ext_tx_set_type(tx_size, bsize, 1, cm->reduced_tx_set_used);
 #endif  // CONFIG_EXT_TX
   for (tx_type = DCT_DCT; tx_type < TX_TYPES; ++tx_type) {
 #if CONFIG_VAR_TX
@@ -7410,7 +7413,7 @@
 #endif  // CONFIG_VAR_TX
 
 #if CONFIG_EXT_TX
-    if (!ext_tx_used_inter[ext_tx_set][tx_type]) continue;
+    if (!av1_ext_tx_used[tx_set_type][tx_type]) continue;
 #else
     if (tx_size >= TX_32X32 && tx_type != DCT_DCT) continue;
 #endif  // CONFIG_EXT_TX
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index 1300578..b06c43d 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -2044,12 +2044,11 @@
     }
 
 #if CONFIG_EXT_TX
-    int is_inter = is_inter_block(mbmi);
-    int ext_tx_set = get_ext_tx_set(get_min_tx_size(tx_size), mbmi->sb_type,
-                                    is_inter, cm->reduced_tx_set_used);
-    if (!(is_inter && ext_tx_used_inter[ext_tx_set][tx_type]) &&
-        !(!is_inter && ext_tx_used_intra[ext_tx_set][tx_type]))
-      continue;
+    const int is_inter = is_inter_block(mbmi);
+    const TxSetType tx_set_type =
+        get_ext_tx_set_type(get_min_tx_size(tx_size), mbmi->sb_type, is_inter,
+                            cm->reduced_tx_set_used);
+    if (!av1_ext_tx_used[tx_set_type][tx_type]) continue;
 #endif  // CONFIG_EXT_TX
 
     RD_STATS this_rd_stats;
diff --git a/av1/encoder/rd.c b/av1/encoder/rd.c
index 8a6ce0b..fd8b96e 100644
--- a/av1/encoder/rd.c
+++ b/av1/encoder/rd.c
@@ -179,15 +179,16 @@
       if (use_inter_ext_tx_for_txsize[s][i]) {
         av1_cost_tokens_from_cdf(x->inter_tx_type_costs[s][i],
                                  fc->inter_ext_tx_cdf[s][i],
-                                 av1_ext_tx_inter_inv[s]);
+                                 av1_ext_tx_inv[av1_ext_tx_set_type_inter[s]]);
       }
     }
     for (s = 1; s < EXT_TX_SETS_INTRA; ++s) {
       if (use_intra_ext_tx_for_txsize[s][i]) {
-        for (j = 0; j < INTRA_MODES; ++j)
-          av1_cost_tokens_from_cdf(x->intra_tx_type_costs[s][i][j],
-                                   fc->intra_ext_tx_cdf[s][i][j],
-                                   av1_ext_tx_intra_inv[s]);
+        for (j = 0; j < INTRA_MODES; ++j) {
+          av1_cost_tokens_from_cdf(
+              x->intra_tx_type_costs[s][i][j], fc->intra_ext_tx_cdf[s][i][j],
+              av1_ext_tx_inv[av1_ext_tx_set_type_intra[s]]);
+        }
       }
     }
   }
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 019e1f5..cabfd7c 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -2311,7 +2311,7 @@
   }
 }
 
-// #TODO(angiebird): use this function whenever it's possible
+// TODO(angiebird): use this function whenever it's possible
 int av1_tx_type_cost(const AV1_COMMON *cm, const MACROBLOCK *x,
                      const MACROBLOCKD *xd, BLOCK_SIZE bsize, int plane,
                      TX_SIZE tx_size, TX_TYPE tx_type) {
@@ -2442,10 +2442,10 @@
   if (max_tx_size >= TX_32X32 && tx_size == TX_4X4) return 1;
 #if CONFIG_EXT_TX
   const AV1_COMMON *const cm = &cpi->common;
-  int ext_tx_set =
-      get_ext_tx_set(tx_size, bs, is_inter, cm->reduced_tx_set_used);
+  const TxSetType tx_set_type =
+      get_ext_tx_set_type(tx_size, bs, is_inter, cm->reduced_tx_set_used);
+  if (!av1_ext_tx_used[tx_set_type][tx_type]) return 1;
   if (is_inter) {
-    if (!ext_tx_used_inter[ext_tx_set][tx_type]) return 1;
     if (cpi->sf.tx_type_search.prune_mode > NO_PRUNE) {
       if (!do_tx_type_search(tx_type, prune)) return 1;
     }
@@ -2453,7 +2453,6 @@
     if (!ALLOW_INTRA_EXT_TX && bs >= BLOCK_8X8) {
       if (tx_type != intra_mode_to_tx_type_context[mbmi->mode]) return 1;
     }
-    if (!ext_tx_used_intra[ext_tx_set][tx_type]) return 1;
   }
 #else   // CONFIG_EXT_TX
   if (tx_size >= TX_32X32 && tx_type != DCT_DCT) return 1;
@@ -2494,9 +2493,6 @@
   const int is_inter = is_inter_block(mbmi);
   int prune = 0;
   const int plane = 0;
-#if CONFIG_EXT_TX
-  int ext_tx_set;
-#endif  // CONFIG_EXT_TX
   av1_invalid_rd_stats(rd_stats);
 
   mbmi->tx_size = tx_size_from_tx_mode(bs, cm->tx_mode, is_inter);
@@ -2504,8 +2500,10 @@
   mbmi->min_tx_size = get_min_tx_size(mbmi->tx_size);
 #endif  // CONFIG_VAR_TX
 #if CONFIG_EXT_TX
-  ext_tx_set =
+  int ext_tx_set =
       get_ext_tx_set(mbmi->tx_size, bs, is_inter, cm->reduced_tx_set_used);
+  const TxSetType tx_set_type =
+      get_ext_tx_set_type(mbmi->tx_size, bs, is_inter, cm->reduced_tx_set_used);
 #endif  // CONFIG_EXT_TX
 
   if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE)
@@ -2526,12 +2524,12 @@
 #endif  // CONFIG_PVQ
 
     for (tx_type = DCT_DCT; tx_type < TX_TYPES; ++tx_type) {
+      if (!av1_ext_tx_used[tx_set_type][tx_type]) continue;
       RD_STATS this_rd_stats;
       if (is_inter) {
         if (x->use_default_inter_tx_type &&
             tx_type != get_default_tx_type(0, xd, 0, mbmi->tx_size))
           continue;
-        if (!ext_tx_used_inter[ext_tx_set][tx_type]) continue;
         if (cpi->sf.tx_type_search.prune_mode > NO_PRUNE) {
           if (!do_tx_type_search(tx_type, prune)) continue;
         }
@@ -2542,7 +2540,6 @@
         if (!ALLOW_INTRA_EXT_TX && bs >= BLOCK_8X8) {
           if (tx_type != intra_mode_to_tx_type_context[mbmi->mode]) continue;
         }
-        if (!ext_tx_used_intra[ext_tx_set][tx_type]) continue;
       }
 
       mbmi->tx_type = tx_type;
@@ -2696,10 +2693,9 @@
       if (mbmi->ref_mv_idx > 0 && tx_type != DCT_DCT) continue;
       const TX_SIZE rect_tx_size = max_txsize_rect_lookup[bs];
       RD_STATS this_rd_stats;
-      int ext_tx_set =
-          get_ext_tx_set(rect_tx_size, bs, is_inter, cm->reduced_tx_set_used);
-      if ((is_inter && ext_tx_used_inter[ext_tx_set][tx_type]) ||
-          (!is_inter && ext_tx_used_intra[ext_tx_set][tx_type])) {
+      const TxSetType tx_set_type = get_ext_tx_set_type(
+          rect_tx_size, bs, is_inter, cm->reduced_tx_set_used);
+      if (av1_ext_tx_used[tx_set_type][tx_type]) {
         rd = txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs, tx_type,
                       rect_tx_size);
         ref_best_rd = AOMMIN(rd, ref_best_rd);
@@ -2745,10 +2741,9 @@
       if (mbmi->ref_mv_idx > 0 && tx_type != DCT_DCT) continue;
       const TX_SIZE tx_size = quarter_txsize_lookup[bs];
       RD_STATS this_rd_stats;
-      int ext_tx_set =
-          get_ext_tx_set(tx_size, bs, is_inter, cm->reduced_tx_set_used);
-      if ((is_inter && ext_tx_used_inter[ext_tx_set][tx_type]) ||
-          (!is_inter && ext_tx_used_intra[ext_tx_set][tx_type])) {
+      const TxSetType tx_set_type =
+          get_ext_tx_set_type(tx_size, bs, is_inter, cm->reduced_tx_set_used);
+      if (av1_ext_tx_used[tx_set_type][tx_type]) {
         rd =
             txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs, tx_type, tx_size);
         if (rd < best_rd) {
@@ -5271,6 +5266,8 @@
   RD_STATS rd_stats_stack[4];
 #endif  // CONFIG_EXT_PARTITION
 #if CONFIG_EXT_TX
+  const TxSetType tx_set_type = get_ext_tx_set_type(
+      max_tx_size, bsize, is_inter, cm->reduced_tx_set_used);
   const int ext_tx_set =
       get_ext_tx_set(max_tx_size, bsize, is_inter, cm->reduced_tx_set_used);
 #endif  // CONFIG_EXT_TX
@@ -5315,8 +5312,8 @@
       continue;
 #endif  // CONFIG_MRC_TX
 #if CONFIG_EXT_TX
+    if (!av1_ext_tx_used[tx_set_type][tx_type]) continue;
     if (is_inter) {
-      if (!ext_tx_used_inter[ext_tx_set][tx_type]) continue;
       if (cpi->sf.tx_type_search.prune_mode > NO_PRUNE) {
         if (!do_tx_type_search(tx_type, prune)) continue;
       }
@@ -5324,7 +5321,6 @@
       if (!ALLOW_INTRA_EXT_TX && bsize >= BLOCK_8X8) {
         if (tx_type != intra_mode_to_tx_type_context[mbmi->mode]) continue;
       }
-      if (!ext_tx_used_intra[ext_tx_set][tx_type]) continue;
     }
 #else   // CONFIG_EXT_TX
     if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE &&