CONFIG_IST_REDUCTION

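Non-normative encoder speedup: the IST set search evaluates only
IST_REDUCE_SET_SIZE (4) of the IST_DIR_SIZE (7) secondary-transform sets.
Intra blocks visit the sets in a per-mode priority order
(ist_intra_stx_mapping); inter blocks use the natural set order.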
diff --git a/aom_dsp/txfm_common.h b/aom_dsp/txfm_common.h
index 5731e92..6a51664 100644
--- a/aom_dsp/txfm_common.h
+++ b/aom_dsp/txfm_common.h
@@ -28,6 +28,12 @@
   // Primary transform set used for the current tx block.
   TX_TYPE tx_type;
 #if CONFIG_IST_SET_FLAG
+#if CONFIG_IST_REDUCTION
+  // Position of sec_tx_set in the encoder's IST set search order (0 = first
+  // position tried). For intra blocks this indexes a row of
+  // ist_intra_stx_mapping; used by the encoder's IST set search.
+  int sec_tx_set_idx;
+#endif  // CONFIG_IST_REDUCTION
   // for both forward and inverse secondary transforms
   // Secondary transform set used for the current tx block.
   TX_TYPE sec_tx_set;
diff --git a/av1/common/common_data.h b/av1/common/common_data.h
index 27e8476..047bca4 100644
--- a/av1/common/common_data.h
+++ b/av1/common/common_data.h
@@ -1014,6 +1014,34 @@
          block_size_high[bsize1] > block_size_high[bsize2];
 }
 
+#if CONFIG_IST_REDUCTION
+// Encoder IST search order. Row m is stx_transpose_mapping[mode] (mode
+// clamped to SMOOTH_H_PRED); entry [m][i] is the IST set evaluated at search
+// position i. The encoder visits only positions 0..IST_REDUCE_SET_SIZE-1.
+static const uint8_t ist_intra_stx_mapping[IST_DIR_SIZE][IST_DIR_SIZE] = {
+  { 6, 1, 0, 5, 4, 3, 2 },  // DC_PRED
+  { 1, 6, 0, 4, 2, 5, 3 },  // V_PRED, H_PRED, SMOOTH_V_PRED, SMOOTH_H_PRED
+  { 2, 6, 0, 5, 1, 4, 3 },  // D45_PRED
+  { 3, 4, 6, 1, 0, 2, 5 },  // D135_PRED
+  { 4, 1, 3, 6, 0, 5, 2 },  // D113_PRED, D157_PRED
+  { 5, 0, 6, 2, 1, 4, 3 },  // D203_PRED, D67_PRED
+  { 6, 1, 0, 5, 4, 3, 2 },  // SMOOTH_PRED
+};
+
+// Inverse of ist_intra_stx_mapping: entry [m][s] is the search position of
+// IST set s, so inv_ist_intra_stx_mapping[m][ist_intra_stx_mapping[m][i]]
+// == i for every i. Each row must stay a permutation of 0..IST_DIR_SIZE-1.
+static const uint8_t inv_ist_intra_stx_mapping[IST_DIR_SIZE][IST_DIR_SIZE] = {
+  { 2, 1, 6, 5, 4, 3, 0 },  // DC_PRED
+  { 2, 0, 4, 6, 3, 5, 1 },  // V_PRED, H_PRED, SMOOTH_V_PRED, SMOOTH_H_PRED
+  { 2, 4, 0, 6, 5, 3, 1 },  // D45_PRED
+  { 4, 3, 5, 0, 1, 6, 2 },  // D135_PRED
+  { 4, 1, 6, 2, 0, 5, 3 },  // D113_PRED, D157_PRED
+  { 1, 4, 3, 6, 5, 0, 2 },  // D203_PRED, D67_PRED
+  { 2, 1, 6, 5, 4, 3, 0 },  // SMOOTH_PRED
+};
+#endif  // CONFIG_IST_REDUCTION
+
 #if CONFIG_INTRA_TX_IST_PARSE
 // Mapping of IST kernel set to an index based on intra mode.
 // The index will be signaled in the bitstream
diff --git a/av1/common/enums.h b/av1/common/enums.h
index 701a618..60eb51c 100644
--- a/av1/common/enums.h
+++ b/av1/common/enums.h
@@ -28,6 +28,13 @@
 /*!\cond */
 
 #undef MAX_SB_SIZE
+
+#if CONFIG_IST_REDUCTION
+// Number of IST sets (out of IST_DIR_SIZE) the encoder's IST set search
+// evaluates when CONFIG_IST_REDUCTION is enabled. Non-normative.
+#define IST_REDUCE_SET_SIZE 4
+#endif  // CONFIG_IST_REDUCTION
+
 #define BAWP_BUGFIX 1
 #define ADJUST_SUPER_RES_Q 1
 
diff --git a/av1/common/idct.c b/av1/common/idct.c
index 72ff447..c55a804 100644
--- a/av1/common/idct.c
+++ b/av1/common/idct.c
@@ -431,6 +431,9 @@
   txfm_param->tx_type = get_primary_tx_type(tx_type);
 #if CONFIG_IST_SET_FLAG
   txfm_param->sec_tx_set = 0;
+#if CONFIG_IST_REDUCTION
+  txfm_param->sec_tx_set_idx = 0;
+#endif  // CONFIG_IST_REDUCTION
 #endif  // CONFIG_IST_SET_FLAG
   txfm_param->sec_tx_type = 0;
   txfm_param->intra_mode = get_intra_mode(mbmi, plane);
@@ -447,6 +450,25 @@
     txfm_param->sec_tx_type = get_secondary_tx_type(tx_type);
 #if CONFIG_IST_SET_FLAG
     txfm_param->sec_tx_set = get_secondary_tx_set(tx_type);
+#if CONFIG_IST_REDUCTION
+    // Recover the encoder's IST search position (sec_tx_set_idx) from
+    // sec_tx_set by undoing the set mapping used in the tx_search.c IST
+    // loop. Sets IST_DIR_SIZE..IST_SET_SIZE-1 are used by ADST_ADST and
+    // 0..IST_DIR_SIZE-1 by DCT_DCT, so reducing modulo IST_DIR_SIZE gives
+    // the set within its group and keeps the table index in range.
+    const int stx_id = txfm_param->sec_tx_set % IST_DIR_SIZE;
+    if (is_inter_block(mbmi, xd->tree_type)) {
+      // Inter blocks are searched in natural set order.
+      txfm_param->sec_tx_set_idx = stx_id;
+    } else {
+      // Clamp the mode the same way the encoder search does so both use
+      // the same mapping row and stx_transpose_mapping stays in range.
+      const PREDICTION_MODE mode =
+          AOMMIN(txfm_param->intra_mode, SMOOTH_H_PRED);
+      txfm_param->sec_tx_set_idx =
+          inv_ist_intra_stx_mapping[stx_transpose_mapping[mode]][stx_id];
+    }
+#endif  // CONFIG_IST_REDUCTION
 #endif  // CONFIG_IST_SET_FLAG
   }
   txfm_param->tx_size = tx_size;
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index e45ed98..f683a92 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -527,7 +527,13 @@
     // per set) Set 0 ~ IST_DIR_SIZE-1 for DCT_DCT, and Set IST_DIR_SIZE ~
     // IST_SET_SIZE-1 for ADST_ADST
     if (txfm_param->sec_tx_type == 0 &&
-        (txfm_param->sec_tx_set == 0 || txfm_param->sec_tx_set == IST_DIR_SIZE))
+#if CONFIG_IST_REDUCTION
+        // The encoder only tries sec_tx_type == 0 at search position 0.
+        txfm_param->sec_tx_set_idx == 0
+#else
+        (txfm_param->sec_tx_set == 0 || txfm_param->sec_tx_set == IST_DIR_SIZE)
+#endif  // CONFIG_IST_REDUCTION
+        )
 #else
     if (txfm_param->sec_tx_type == 0)
 #endif  // CONFIG_IST_ANY_SET
@@ -649,6 +655,9 @@
   txfm_param->tx_type = get_primary_tx_type(tx_type);
 #if CONFIG_IST_SET_FLAG
   txfm_param->sec_tx_set = 0;
+#if CONFIG_IST_REDUCTION
+  txfm_param->sec_tx_set_idx = 0;
+#endif  // CONFIG_IST_REDUCTION
 #endif  // CONFIG_IST_SET_FLAG
   txfm_param->sec_tx_type = 0;
   txfm_param->intra_mode = get_intra_mode(mbmi, plane);
@@ -666,6 +675,25 @@
       !(mbmi->fsc_mode[xd->tree_type == CHROMA_PART])) {
 #if CONFIG_IST_SET_FLAG
     txfm_param->sec_tx_set = get_secondary_tx_set(tx_type);
+#if CONFIG_IST_REDUCTION
+    // Recover the encoder's IST search position (sec_tx_set_idx) from
+    // sec_tx_set by undoing the set mapping used in the tx_search.c IST
+    // loop. Sets IST_DIR_SIZE..IST_SET_SIZE-1 are used by ADST_ADST and
+    // 0..IST_DIR_SIZE-1 by DCT_DCT, so reducing modulo IST_DIR_SIZE gives
+    // the set within its group and keeps the table index in range.
+    const int stx_id = txfm_param->sec_tx_set % IST_DIR_SIZE;
+    if (is_inter_block(mbmi, xd->tree_type)) {
+      // Inter blocks are searched in natural set order.
+      txfm_param->sec_tx_set_idx = stx_id;
+    } else {
+      // Clamp the mode the same way the encoder search does so both use
+      // the same mapping row and stx_transpose_mapping stays in range.
+      const PREDICTION_MODE mode =
+          AOMMIN(txfm_param->intra_mode, SMOOTH_H_PRED);
+      txfm_param->sec_tx_set_idx =
+          inv_ist_intra_stx_mapping[stx_transpose_mapping[mode]][stx_id];
+    }
+#endif  // CONFIG_IST_REDUCTION
 #endif  // CONFIG_IST_SET_FLAG
     txfm_param->sec_tx_type = get_secondary_tx_type(tx_type);
   }
diff --git a/av1/encoder/tx_search.c b/av1/encoder/tx_search.c
index 260d98a..b4ad627 100644
--- a/av1/encoder/tx_search.c
+++ b/av1/encoder/tx_search.c
@@ -2843,15 +2843,43 @@
     int max_set_id =
         (skip_stx || is_inter_block(mbmi, xd->tree_type)) ? 1 : IST_DIR_SIZE;
 #endif
+#if CONFIG_IST_REDUCTION
+    // Non-normative speedup: when all IST_DIR_SIZE sets would be searched,
+    // evaluate only the first IST_REDUCE_SET_SIZE search positions. Intra
+    // blocks visit sets in the per-mode order of ist_intra_stx_mapping;
+    // inter blocks keep the natural set order.
+    if (max_set_id == IST_DIR_SIZE) {
+      max_set_id = IST_REDUCE_SET_SIZE;
+    }
+    const int ist_is_inter = is_inter_block(mbmi, xd->tree_type);
+    // Row of ist_intra_stx_mapping for this block (unused for inter blocks).
+    int ist_mode_row = 0;
+    if (!ist_is_inter) {
+      ist_mode_row = stx_transpose_mapping[AOMMIN(intra_mode, SMOOTH_H_PRED)];
+    }
+    for (int set_idx = init_set_id; set_idx < max_set_id; ++set_idx) {
+      // set_idx is the search position; set_id is the IST set it selects.
+      txfm_param.sec_tx_set_idx = set_idx;
+      int set_id = set_idx;
+      if (!ist_is_inter) {
+        assert(set_idx < IST_REDUCE_SET_SIZE);
+        set_id = ist_intra_stx_mapping[ist_mode_row][set_idx];
+      }
+#else
     // Iterate through all possible secondary tx sets for given primary tx type
     for (int set_id = init_set_id; set_id < max_set_id; ++set_id) {
+#endif  // CONFIG_IST_REDUCTION
 #endif  // CONFIG_IST_ANY_SET
 
       const int max_stx = xd->enable_ist && !(eob_found) ? 4 : 1;
 
       for (int stx = 0; stx < max_stx; ++stx) {
         // Skip repeated evaluation of no secondary transform.
+#if CONFIG_IST_REDUCTION
+        if (set_idx && !stx) continue;
+#else
         if (set_id && !stx) continue;
+#endif  // CONFIG_IST_REDUCTION
 
 #if CONFIG_IST_ANY_SET
         TX_TYPE tx_type = primary_tx_type;
diff --git a/build/cmake/aom_config_defaults.cmake b/build/cmake/aom_config_defaults.cmake
index bc3d3de..62f8c11 100644
--- a/build/cmake/aom_config_defaults.cmake
+++ b/build/cmake/aom_config_defaults.cmake
@@ -159,6 +159,9 @@
 #IST multiset for inter
 set_aom_config_var(CONFIG_IST_INTER_MULTISET 1
                    "Enable multiset IST for inter TU.")
+#IST set reduction: encoder-only IST set search speedup (non-normative)
+set_aom_config_var(CONFIG_IST_REDUCTION 1
+                   "Use 4 sets for IST encoder search.")
 
 # AV2 experiment flags.
 set_aom_config_var(CONFIG_IMPROVEIDTX 1