CWG-D191 : Enable IST for inter blocks

STATS_CHANGED
diff --git a/aom/aom_encoder.h b/aom/aom_encoder.h
index 80385ba..7c517fc 100644
--- a/aom/aom_encoder.h
+++ b/aom/aom_encoder.h
@@ -364,6 +364,12 @@
    *
    */
   unsigned int enable_ist;
+#if CONFIG_INTER_IST
+  /*!\brief enable Inter secondary transform
+   *
+   */
+  unsigned int enable_inter_ist;
+#endif  // CONFIG_INTER_IST
   /*!\brief enable cross-chroma component transform
    *
    */
diff --git a/aom_dsp/txfm_common.h b/aom_dsp/txfm_common.h
index 4b5dff8..aec2d17 100644
--- a/aom_dsp/txfm_common.h
+++ b/aom_dsp/txfm_common.h
@@ -37,6 +37,9 @@
   TX_TYPE sec_tx_type;
   // intra prediction mode used for the current tx block
   PREDICTION_MODE intra_mode;
+#if CONFIG_INTER_IST
+  int is_inter;
+#endif  // CONFIG_INTER_IST
   CctxType cctx_type;
   TX_SIZE tx_size;
   int lossless;
diff --git a/apps/aomenc.c b/apps/aomenc.c
index a9472a5..feb7a46 100644
--- a/apps/aomenc.c
+++ b/apps/aomenc.c
@@ -452,6 +452,9 @@
   &g_av1_codec_arg_defs.enable_idif,
 #endif  // CONFIG_IDIF
   &g_av1_codec_arg_defs.enable_ist,
+#if CONFIG_INTER_IST
+  &g_av1_codec_arg_defs.enable_inter_ist,
+#endif  // CONFIG_INTER_IST
   &g_av1_codec_arg_defs.enable_cctx,
   &g_av1_codec_arg_defs.enable_ibp,
   &g_av1_codec_arg_defs.explicit_ref_frame_map,
@@ -664,6 +667,9 @@
   config->enable_idif = 1;
 #endif  // CONFIG_IDIF
   config->enable_ist = 1;
+#if CONFIG_INTER_IST
+  config->enable_inter_ist = 1;
+#endif  // CONFIG_INTER_IST
   config->enable_cctx = 1;
   config->enable_ibp = 1;
   config->enable_adaptive_mvd = 1;
@@ -1522,6 +1528,10 @@
           encoder_cfg->enable_sdp);
   fprintf(stdout, "                               : IST (%d)\n",
           encoder_cfg->enable_ist);
+#if CONFIG_INTER_IST
+  fprintf(stdout, "                               : Inter IST (%d)\n",
+          encoder_cfg->enable_inter_ist);
+#endif  // CONFIG_INTER_IST
   fprintf(stdout,
           "Tool setting (Intra)           : SmoothIntra (%d), CfL (%d), "
           "FilterIntra (%d)\n",
diff --git a/av1/arg_defs.c b/av1/arg_defs.c
index 60d8cba..a149f7e 100644
--- a/av1/arg_defs.c
+++ b/av1/arg_defs.c
@@ -437,6 +437,11 @@
   .enable_ist = ARG_DEF(NULL, "enable-ist", 1,
                         "Enable intra secondary transform"
                         "(0: false, 1: true (default))"),
+#if CONFIG_INTER_IST
+  .enable_inter_ist = ARG_DEF(NULL, "enable-inter-ist", 1,
+                              "Enable inter secondary transform"
+                              "(0: false, 1: true (default))"),
+#endif  // CONFIG_INTER_IST
   .enable_cctx = ARG_DEF(NULL, "enable-cctx", 1,
                          "Enable cross-chroma component transform "
                          "(0: false, 1: true(default))"),
diff --git a/av1/arg_defs.h b/av1/arg_defs.h
index 1a93327..6ca7e17 100644
--- a/av1/arg_defs.h
+++ b/av1/arg_defs.h
@@ -170,6 +170,9 @@
   arg_def_t enable_idif;
 #endif  // CONFIG_IDIF
   arg_def_t enable_ist;
+#if CONFIG_INTER_IST
+  arg_def_t enable_inter_ist;
+#endif  // CONFIG_INTER_IST
   arg_def_t enable_cctx;
   arg_def_t enable_ibp;
   arg_def_t enable_adaptive_mvd;
diff --git a/av1/av1_cx_iface.c b/av1/av1_cx_iface.c
index 522ff92..1c2be59 100644
--- a/av1/av1_cx_iface.c
+++ b/av1/av1_cx_iface.c
@@ -142,9 +142,12 @@
   int enable_fsc;   // enable forward skip coding
   int enable_orip;  // enable ORIP
 #if CONFIG_IDIF
-  int enable_idif;          // enable IDIF
-#endif                      // CONFIG_IDIF
-  int enable_ist;           // enable intra secondary transform
+  int enable_idif;  // enable IDIF
+#endif              // CONFIG_IDIF
+  int enable_ist;   // enable intra secondary transform
+#if CONFIG_INTER_IST
+  int enable_inter_ist;     // enable inter secondary transform
+#endif                      // CONFIG_INTER_IST
   int enable_cctx;          // enable cross-chroma component transform
   int enable_ibp;           // enable intra bi-prediction
   int enable_adaptive_mvd;  // enable adaptive MVD resolution
@@ -485,6 +488,9 @@
   1,    // enable IDIF
 #endif  // CONFIG_IDIF
   1,    // enable intra secondary transform
+#if CONFIG_INTER_IST
+  1,    // enable inter secondary transform
+#endif  // CONFIG_INTER_IST
   1,    // enable cross-chroma component transform
   1,    // enable intra bi-prediction
   1,    // enable adaptive mvd resolution
@@ -1012,6 +1018,9 @@
   cfg->enable_idif = extra_cfg->enable_idif;
 #endif  // CONFIG_IDIF
   cfg->enable_ist = extra_cfg->enable_ist;
+#if CONFIG_INTER_IST
+  cfg->enable_inter_ist = extra_cfg->enable_inter_ist;
+#endif  // CONFIG_INTER_IST
   cfg->enable_cctx = extra_cfg->enable_cctx;
   cfg->enable_ibp = extra_cfg->enable_ibp;
   cfg->enable_adaptive_mvd = extra_cfg->enable_adaptive_mvd;
@@ -1138,6 +1147,9 @@
   extra_cfg->enable_idif = cfg->enable_idif;
 #endif  // CONFIG_IDIF
   extra_cfg->enable_ist = cfg->enable_ist;
+#if CONFIG_INTER_IST
+  extra_cfg->enable_inter_ist = cfg->enable_inter_ist;
+#endif  // CONFIG_INTER_IST
   extra_cfg->enable_cctx = cfg->enable_cctx;
   extra_cfg->enable_ibp = cfg->enable_ibp;
   extra_cfg->enable_adaptive_mvd = cfg->enable_adaptive_mvd;
@@ -1747,6 +1759,10 @@
   txfm_cfg->disable_ml_transform_speed_features =
       extra_cfg->disable_ml_transform_speed_features;
   txfm_cfg->enable_ist = extra_cfg->enable_ist && !extra_cfg->lossless;
+#if CONFIG_INTER_IST
+  txfm_cfg->enable_inter_ist =
+      extra_cfg->enable_inter_ist && !extra_cfg->lossless;
+#endif  // CONFIG_INTER_IST
   txfm_cfg->enable_cctx =
       tool_cfg->enable_monochrome ? 0 : extra_cfg->enable_cctx;
 
@@ -3946,6 +3962,11 @@
   } else if (arg_match_helper(&arg, &g_av1_codec_arg_defs.enable_ist, argv,
                               err_string)) {
     extra_cfg.enable_ist = arg_parse_int_helper(&arg, err_string);
+#if CONFIG_INTER_IST
+  } else if (arg_match_helper(&arg, &g_av1_codec_arg_defs.enable_inter_ist,
+                              argv, err_string)) {
+    extra_cfg.enable_inter_ist = arg_parse_int_helper(&arg, err_string);
+#endif  // CONFIG_INTER_IST
   } else if (arg_match_helper(&arg, &g_av1_codec_arg_defs.enable_cctx, argv,
                               err_string)) {
     extra_cfg.enable_cctx = arg_parse_int_helper(&arg, err_string);
@@ -4454,6 +4475,9 @@
         1,
 #endif      // CONFIG_IDIF
         1,  // IST
+#if CONFIG_INTER_IST
+        1,  // inter IST
+#endif      // CONFIG_INTER_IST
         1,  // enable_cctx
         1, 1,   1,
 #if CONFIG_IMPROVED_CFL
diff --git a/av1/common/av1_common_int.h b/av1/common/av1_common_int.h
index 29491a4..5183035 100644
--- a/av1/common/av1_common_int.h
+++ b/av1/common/av1_common_int.h
@@ -479,7 +479,10 @@
   uint8_t
       enable_idif;  // enables/disables Intra Directional Interpolation Filter
 #endif              // CONFIG_IDIF
-  uint8_t enable_ist;   // enables/disables intra secondary transform
+  uint8_t enable_ist;  // enables/disables intra secondary transform
+#if CONFIG_INTER_IST
+  uint8_t enable_inter_ist;  // enables/disables inter secondary transform
+#endif                       // CONFIG_INTER_IST
   uint8_t enable_cctx;  // enables/disables cross-chroma component transform
   uint8_t enable_ibp;   // enables/disables intra bi-prediction(IBP)
   uint8_t enable_adaptive_mvd;  // enables/disables adaptive MVD resolution
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index 4315db8..204d7c4 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -3470,7 +3470,11 @@
  *
  */
 static INLINE void disable_secondary_tx_type(TX_TYPE *tx_type) {
+#if CONFIG_IST_SET_FLAG
+  *tx_type &= 0x000f;
+#else
   *tx_type &= 0x0f;
+#endif
 }
 /*
  * This function masks primary transform type used by the transform block
@@ -3486,7 +3490,11 @@
  * This function returns primary transform type used by the transform block
  */
 static INLINE TX_TYPE get_primary_tx_type(TX_TYPE tx_type) {
+#if CONFIG_IST_SET_FLAG
+  return tx_type & 0x000f;
+#else
   return tx_type & 0x0f;
+#endif
 }
 /*
  * This function returns secondary transform type used by the transform block
@@ -3529,14 +3537,20 @@
 static INLINE int block_signals_sec_tx_type(const MACROBLOCKD *xd,
                                             TX_SIZE tx_size, TX_TYPE tx_type,
                                             int eob) {
+#if CONFIG_INTER_IST
+  int should_return =
+      (is_inter_block(xd->mi[0], xd->tree_type) ? (eob <= 3) : (eob <= 1));
+  if (should_return) return 0;
+#else
   if (eob <= 1) return 0;
+#endif  // CONFIG_INTER_IST
   const MB_MODE_INFO *mbmi = xd->mi[0];
   PREDICTION_MODE intra_dir;
   if (mbmi->filter_intra_mode_info.use_filter_intra) {
     intra_dir =
         fimode_to_intradir[mbmi->filter_intra_mode_info.filter_intra_mode];
   } else {
-    intra_dir = mbmi->mode;
+    intra_dir = get_intra_mode(mbmi, AOM_PLANE_Y);
   }
   const BLOCK_SIZE bs = mbmi->sb_type[PLANE_TYPE_Y];
   const TX_TYPE primary_tx_type = get_primary_tx_type(tx_type);
@@ -3550,10 +3564,22 @@
     ist_eob = 0;
   }
   const int is_depth0 = tx_size_is_depth0(tx_size, bs);
+#if CONFIG_INTER_IST
+  bool condition = (primary_tx_type == DCT_DCT && width >= 16 && height >= 16);
+  bool mode_dependent_condition =
+      (is_inter_block(mbmi, xd->tree_type)
+           ? condition
+           : (intra_dir < PAETH_PRED &&
+              !(mbmi->filter_intra_mode_info.use_filter_intra)));
+  const int code_stx =
+      (primary_tx_type == DCT_DCT || primary_tx_type == ADST_ADST) &&
+      mode_dependent_condition && is_depth0 && ist_eob;
+#else
   const int code_stx =
       (primary_tx_type == DCT_DCT || primary_tx_type == ADST_ADST) &&
       (intra_dir < PAETH_PRED) &&
       !(mbmi->filter_intra_mode_info.use_filter_intra) && is_depth0 && ist_eob;
+#endif  // CONFIG_INTER_IST
   return code_stx;
 }
 
diff --git a/av1/common/entropymode.c b/av1/common/entropymode.c
index b740017..99629fb 100644
--- a/av1/common/entropymode.c
+++ b/av1/common/entropymode.c
@@ -6790,11 +6790,26 @@
 
 #if CONFIG_ENTROPY_PARA
 #if CONFIG_IST_ANY_SET
+#if CONFIG_INTER_IST
+static const aom_cdf_prob default_stx_cdf[2][TX_SIZES][CDF_SIZE(STX_TYPES)] = {
+  { { AOM_CDF4(293, 11683, 25053), 0 },
+    { AOM_CDF4(2952, 9945, 16750), 0 },
+    { AOM_CDF4(2684, 9484, 16065), 0 },
+    { AOM_CDF4(3552, 10398, 15130), 0 },
+    { AOM_CDF4(10685, 14127, 17177), 1 } },
+  { { AOM_CDF4(293, 11683, 25053), 0 },
+    { AOM_CDF4(2952, 9945, 16750), 0 },
+    { AOM_CDF4(2684, 9484, 16065), 0 },
+    { AOM_CDF4(3552, 10398, 15130), 0 },
+    { AOM_CDF4(10685, 14127, 17177), 1 } }
+};
+#else
 static const aom_cdf_prob default_stx_cdf[TX_SIZES][CDF_SIZE(STX_TYPES)] = {
   { AOM_CDF4(303, 12789, 26360), 75 }, { AOM_CDF4(1671, 11400, 19958), 30 },
   { AOM_CDF4(2286, 9675, 16955), 5 },  { AOM_CDF4(3524, 9155, 13661), 0 },
   { AOM_CDF4(8277, 13215, 16769), 6 },
 };
+#endif  // CONFIG_INTER_IST
 #else
 static const aom_cdf_prob default_stx_cdf[TX_SIZES][CDF_SIZE(STX_TYPES)] = {
   { AOM_CDF4(1542, 11565, 24287), 0 },  { AOM_CDF4(4776, 13664, 21624), 0 },
diff --git a/av1/common/entropymode.h b/av1/common/entropymode.h
index f1d79cf..f8c15e4 100644
--- a/av1/common/entropymode.h
+++ b/av1/common/entropymode.h
@@ -588,7 +588,11 @@
                                [CDF_SIZE(TX_TYPES)];
   aom_cdf_prob cfl_sign_cdf[CDF_SIZE(CFL_JOINT_SIGNS)];
   aom_cdf_prob cfl_alpha_cdf[CFL_ALPHA_CONTEXTS][CDF_SIZE(CFL_ALPHABET_SIZE)];
+#if CONFIG_INTER_IST
+  aom_cdf_prob stx_cdf[2][TX_SIZES][CDF_SIZE(STX_TYPES)];
+#else
   aom_cdf_prob stx_cdf[TX_SIZES][CDF_SIZE(STX_TYPES)];
+#endif  // CONFIG_INTER_IST
 #if CONFIG_IST_SET_FLAG
 #if CONFIG_INTRA_TX_IST_PARSE
   aom_cdf_prob most_probable_stx_set_cdf[CDF_SIZE(IST_DIR_SIZE)];
diff --git a/av1/common/idct.c b/av1/common/idct.c
index fb3a44b..24ffc9f 100644
--- a/av1/common/idct.c
+++ b/av1/common/idct.c
@@ -326,6 +326,23 @@
 #endif  // CONFIG_IST_SET_FLAG
   txfm_param->sec_tx_type = 0;
   txfm_param->intra_mode = get_intra_mode(mbmi, plane);
+#if CONFIG_INTER_IST
+  txfm_param->is_inter = is_inter_block(xd->mi[0], xd->tree_type);
+  const int width = tx_size_wide[tx_size];
+  const int height = tx_size_high[tx_size];
+  bool mode_dependent_condition =
+      (txfm_param->is_inter
+           ? (txfm_param->tx_type == DCT_DCT && width >= 16 && height >= 16)
+           : (txfm_param->intra_mode < PAETH_PRED &&
+              !(mbmi->filter_intra_mode_info.use_filter_intra)));
+  if (mode_dependent_condition && !xd->lossless[mbmi->segment_id]) {
+    // updated EOB condition
+    txfm_param->sec_tx_type = get_secondary_tx_type(tx_type);
+#if CONFIG_IST_SET_FLAG
+    txfm_param->sec_tx_set = get_secondary_tx_set(tx_type);
+#endif  // CONFIG_IST_SET_FLAG
+  }
+#else
   if ((txfm_param->intra_mode < PAETH_PRED) &&
       !xd->lossless[mbmi->segment_id] &&
       !(mbmi->filter_intra_mode_info.use_filter_intra)) {
@@ -335,6 +352,7 @@
     txfm_param->sec_tx_set = get_secondary_tx_set(tx_type);
 #endif  // CONFIG_IST_SET_FLAG
   }
+#endif  // CONFIG_INTER_IST
   txfm_param->tx_size = tx_size;
   // EOB needs to adjusted after inverse IST
   if (txfm_param->sec_tx_type) {
@@ -507,7 +525,13 @@
   MB_MODE_INFO *const mbmi = xd->mi[0];
   PREDICTION_MODE intra_mode = get_intra_mode(mbmi, plane);
   const int filter = mbmi->filter_intra_mode_info.use_filter_intra;
+#if CONFIG_INTER_IST
+  if (!is_inter_block(xd->mi[0], xd->tree_type))
+    assert(((intra_mode >= PAETH_PRED || filter) && txfm_param.sec_tx_type) ==
+           0);
+#else
   assert(((intra_mode >= PAETH_PRED || filter) && txfm_param.sec_tx_type) == 0);
+#endif  // CONFIG_INTER_IST
   (void)intra_mode;
   (void)filter;
 
@@ -582,7 +606,12 @@
                          : 32;
 
   if ((width >= 4 && height >= 4) && stx_type) {
+#if CONFIG_INTER_IST
+    const PREDICTION_MODE intra_mode =
+        (txfm_param->is_inter ? DC_PRED : txfm_param->intra_mode);
+#else
     const PREDICTION_MODE intra_mode = txfm_param->intra_mode;
+#endif  // CONFIG_INTER_IST
     PREDICTION_MODE mode = 0, mode_t = 0;
     const int log2width = tx_size_wide_log2[txfm_param->tx_size];
 
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index 1813519..9662d7d 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -6700,6 +6700,9 @@
 #endif  // CONFIG_ALLOW_SAME_REF_COMPOUND
   seq_params->enable_sdp = aom_rb_read_bit(rb);
   seq_params->enable_ist = aom_rb_read_bit(rb);
+#if CONFIG_INTER_IST
+  seq_params->enable_inter_ist = aom_rb_read_bit(rb);
+#endif  // CONFIG_INTER_IST
   seq_params->enable_cctx = seq_params->monochrome ? 0 : aom_rb_read_bit(rb);
   seq_params->enable_mrls = aom_rb_read_bit(rb);
   seq_params->enable_tip = aom_rb_read_literal(rb, 2);
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index a1a304e..62d2ae1 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -1405,6 +1405,32 @@
 }
 
 // This function reads a 'secondary tx set' from the bitstream
+#if CONFIG_INTER_IST
+static void read_secondary_tx_set(MACROBLOCKD *xd, FRAME_CONTEXT *ec_ctx,
+                                  aom_reader *r, MB_MODE_INFO *mbmi,
+                                  TX_TYPE *tx_type) {
+  const int inter_block = is_inter_block(mbmi, xd->tree_type);
+  TX_TYPE stx_set_flag = DC_PRED;
+  if (!inter_block) {
+    uint8_t intra_mode = get_intra_mode(mbmi, AOM_PLANE_Y);
+#if CONFIG_INTRA_TX_IST_PARSE
+    const TX_TYPE reordered_stx_set_flag =
+        aom_read_symbol(r, ec_ctx->most_probable_stx_set_cdf, IST_DIR_SIZE,
+                        ACCT_INFO("stx_set_flag"));
+    stx_set_flag =
+        inv_most_probable_stx_mapping[intra_mode][reordered_stx_set_flag];
+#else
+    uint8_t stx_set_ctx = stx_transpose_mapping[intra_mode];
+    assert(stx_set_ctx < IST_DIR_SIZE);
+    stx_set_flag = aom_read_symbol(r, ec_ctx->stx_set_cdf[stx_set_ctx],
+                                   IST_DIR_SIZE, ACCT_INFO("stx_set_flag"));
+#endif  // CONFIG_INTRA_TX_IST_PARSE
+    assert(stx_set_flag < IST_DIR_SIZE);
+  }
+  if (get_primary_tx_type(*tx_type) == ADST_ADST) stx_set_flag += IST_DIR_SIZE;
+  set_secondary_tx_set(tx_type, stx_set_flag);
+}
+#else
 static void read_secondary_tx_set(FRAME_CONTEXT *ec_ctx, aom_reader *r,
                                   MB_MODE_INFO *mbmi, TX_TYPE *tx_type) {
   uint8_t intra_mode = get_intra_mode(mbmi, PLANE_TYPE_Y);
@@ -1425,6 +1451,7 @@
   if (get_primary_tx_type(*tx_type) == ADST_ADST) stx_set_flag += IST_DIR_SIZE;
   set_secondary_tx_set(tx_type, stx_set_flag);
 }
+#endif  // CONFIG_INTER_IST
 
 void av1_read_sec_tx_type(const AV1_COMMON *const cm, MACROBLOCKD *xd,
                           int blk_row, int blk_col, TX_SIZE tx_size,
@@ -1446,26 +1473,53 @@
       1) {
     FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
     const TX_SIZE square_tx_size = txsize_sqr_map[tx_size];
+#if !CONFIG_INTER_IST
     if (!inter_block) {
+#endif  // CONFIG_INTER_IST
       if (block_signals_sec_tx_type(xd, tx_size, *tx_type, *eob)) {
+#if CONFIG_INTER_IST
         const uint8_t stx_flag =
-            aom_read_symbol(r, ec_ctx->stx_cdf[square_tx_size], STX_TYPES,
-                            ACCT_INFO("stx_flag"));
+            aom_read_symbol(r, ec_ctx->stx_cdf[inter_block][square_tx_size],
+                            STX_TYPES, ACCT_INFO("stx_flag"));
+#else
+      const uint8_t stx_flag = aom_read_symbol(
+          r, ec_ctx->stx_cdf[square_tx_size], STX_TYPES, ACCT_INFO("stx_flag"));
+#endif  // CONFIG_INTER_IST
         *tx_type |= (stx_flag << 4);
 #if CONFIG_IST_SET_FLAG
+#if CONFIG_INTER_IST
+        if (stx_flag > 0) read_secondary_tx_set(xd, ec_ctx, r, mbmi, tx_type);
+#else
         if (stx_flag > 0) read_secondary_tx_set(ec_ctx, r, mbmi, tx_type);
+#endif  // CONFIG_INTER_IST
 #endif  // CONFIG_IST_SET_FLAG
       }
+#if !CONFIG_INTER_IST
     }
+#endif  // CONFIG_INTER_IST
+#if CONFIG_INTER_IST
+  } else {
+#else
   } else if (!inter_block) {
+#endif  // CONFIG_INTER_IST
     FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
     const TX_SIZE square_tx_size = txsize_sqr_map[tx_size];
     if (block_signals_sec_tx_type(xd, tx_size, *tx_type, *eob)) {
+#if CONFIG_INTER_IST
+      const uint8_t stx_flag =
+          aom_read_symbol(r, ec_ctx->stx_cdf[inter_block][square_tx_size],
+                          STX_TYPES, ACCT_INFO("stx_flag"));
+#else
       const uint8_t stx_flag = aom_read_symbol(
           r, ec_ctx->stx_cdf[square_tx_size], STX_TYPES, ACCT_INFO("stx_flag"));
+#endif  // // CONFIG_INTER_IST
       *tx_type |= (stx_flag << 4);
 #if CONFIG_IST_SET_FLAG
+#if CONFIG_INTER_IST
+      if (stx_flag > 0) read_secondary_tx_set(xd, ec_ctx, r, mbmi, tx_type);
+#else
       if (stx_flag > 0) read_secondary_tx_set(ec_ctx, r, mbmi, tx_type);
+#endif  // CONFIG_INTER_IST
 #endif  // CONFIG_IST_SET_FLAG
     }
   }
diff --git a/av1/decoder/decodetxb.c b/av1/decoder/decodetxb.c
index 85edaae..dd6373f 100644
--- a/av1/decoder/decodetxb.c
+++ b/av1/decoder/decodetxb.c
@@ -809,7 +809,14 @@
 
   // read  sec_tx_type here
   // Only y plane's sec_tx_type is transmitted
+#if CONFIG_INTER_IST
+  if ((plane == AOM_PLANE_Y) &&
+      (is_inter_block(mbmi, xd->tree_type)
+           ? (*eob > 3 && cm->seq_params.enable_inter_ist)
+           : (*eob != 1 && cm->seq_params.enable_ist))) {
+#else
   if ((plane == AOM_PLANE_Y) && (cm->seq_params.enable_ist) && (*eob != 1)) {
+#endif  // // CONFIG_INTER_IST
     av1_read_sec_tx_type(cm, xd, blk_row, blk_col, tx_size, eob, r);
   }
   //
@@ -1212,6 +1219,10 @@
   av1_set_entropy_contexts(xd, pd, plane, plane_bsize, tx_size, cul_level, col,
                            row);
   if (is_inter_block(mbmi, xd->tree_type)) {
+#if CONFIG_INTER_IST
+    const TX_TYPE tx_type1 = av1_get_tx_type(xd, plane_type, row, col, tx_size,
+                                             cm->features.reduced_tx_set_used);
+#endif  // CONFIG_INTER_IST
     if (plane == 0) {
       const int txw = tx_size_wide_unit[tx_size];
       const int txh = tx_size_high_unit[tx_size];
@@ -1225,7 +1236,11 @@
         const int stride = xd->tx_type_map_stride;
         for (int idy = 0; idy < txh; idy += tx_unit) {
           for (int idx = 0; idx < txw; idx += tx_unit) {
+#if CONFIG_INTER_IST
+            xd->tx_type_map[(row + idy) * stride + col + idx] = tx_type1;
+#else
             xd->tx_type_map[(row + idy) * stride + col + idx] = tx_type;
+#endif  // CONFIG_INTER_IST
           }
         }
       }
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index e5c076f..4e82142 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1610,10 +1610,17 @@
     }
     if (is_inter) {
       const int eob_tx_ctx = get_lp2tx_ctx(tx_size, get_txb_bwl(tx_size), eob);
+#if CONFIG_INTER_IST
+      aom_write_symbol(
+          w, av1_ext_tx_ind[tx_set_type][get_primary_tx_type(tx_type)],
+          ec_ctx->inter_ext_tx_cdf[eset][eob_tx_ctx][square_tx_size],
+          av1_num_ext_tx_set[tx_set_type]);
+#else
       aom_write_symbol(
           w, av1_ext_tx_ind[tx_set_type][tx_type],
           ec_ctx->inter_ext_tx_cdf[eset][eob_tx_ctx][square_tx_size],
           av1_num_ext_tx_set[tx_set_type]);
+#endif  // CONFIG_INTER_IST
     } else {
       if (mbmi->fsc_mode[xd->tree_type == CHROMA_PART]) {
         return;
@@ -1691,26 +1698,52 @@
       !segfeature_active(&cm->seg, mbmi->segment_id, SEG_LVL_SKIP)) {
     FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
     const TX_SIZE square_tx_size = txsize_sqr_map[tx_size];
+#if !CONFIG_INTER_IST
     if (!is_inter) {
+#endif  // CONFIG_INTER_IST
       const TX_TYPE stx_flag = get_secondary_tx_type(tx_type);
       assert(stx_flag <= STX_TYPES - 1);
       if (block_signals_sec_tx_type(xd, tx_size, tx_type, eob)) {
-        aom_write_symbol(w, stx_flag, ec_ctx->stx_cdf[square_tx_size],
+#if CONFIG_INTER_IST
+        aom_write_symbol(w, stx_flag, ec_ctx->stx_cdf[is_inter][square_tx_size],
                          STX_TYPES);
+#else
+      aom_write_symbol(w, stx_flag, ec_ctx->stx_cdf[square_tx_size], STX_TYPES);
+#endif  // CONFIG_INTER_IST
 #if CONFIG_IST_SET_FLAG
+#if CONFIG_INTER_IST
+        if (stx_flag > 0 && !is_inter)
+          write_sec_tx_set(ec_ctx, w, mbmi, tx_type);
+#else
         if (stx_flag > 0) write_sec_tx_set(ec_ctx, w, mbmi, tx_type);
+#endif  // CONFIG_INTER_IST
 #endif  // CONFIG_IST_SET_FLAG
       }
+#if !CONFIG_INTER_IST
     }
+#endif  // CONFIG_INTER_IST
+#if CONFIG_INTER_IST
+  } else if (!xd->lossless[mbmi->segment_id]) {
+#else
   } else if (!is_inter && !xd->lossless[mbmi->segment_id]) {
+#endif  // CONFIG_INTER_IST
     FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
     const TX_SIZE square_tx_size = txsize_sqr_map[tx_size];
     TX_TYPE stx_flag = get_secondary_tx_type(tx_type);
     assert(stx_flag <= STX_TYPES - 1);
     if (block_signals_sec_tx_type(xd, tx_size, tx_type, eob)) {
+#if CONFIG_INTER_IST
+      aom_write_symbol(w, stx_flag, ec_ctx->stx_cdf[is_inter][square_tx_size],
+                       STX_TYPES);
+#else
       aom_write_symbol(w, stx_flag, ec_ctx->stx_cdf[square_tx_size], STX_TYPES);
+#endif  // CONFIG_INTER_IST
 #if CONFIG_IST_SET_FLAG
+#if CONFIG_INTER_IST
+      if (stx_flag > 0 && !is_inter) write_sec_tx_set(ec_ctx, w, mbmi, tx_type);
+#else
       if (stx_flag > 0) write_sec_tx_set(ec_ctx, w, mbmi, tx_type);
+#endif  // CONFIG_INTER_IST
 #endif  // CONFIG_IST_SET_FLAG
     }
   }
@@ -5415,6 +5448,9 @@
 #endif  // CONFIG_ALLOW_SAME_REF_COMPOUND
   aom_wb_write_bit(wb, seq_params->enable_sdp);
   aom_wb_write_bit(wb, seq_params->enable_ist);
+#if CONFIG_INTER_IST
+  aom_wb_write_bit(wb, seq_params->enable_inter_ist);
+#endif  // CONFIG_INTER_IST
   if (!seq_params->monochrome) aom_wb_write_bit(wb, seq_params->enable_cctx);
   aom_wb_write_bit(wb, seq_params->enable_mrls);
   aom_wb_write_literal(wb, seq_params->enable_tip, 2);
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index de77f4c..5ff0559 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -931,7 +931,11 @@
 #endif  // CONFIG_AIMC
 
   //! Cost of signaling secondary transform index
+#if CONFIG_INTER_IST
+  int stx_flag_cost[2][TX_SIZES][STX_TYPES];
+#else
   int stx_flag_cost[TX_SIZES][STX_TYPES];
+#endif  // CONFIG_INTER_IST
 #if CONFIG_IST_SET_FLAG
   //! Cost of signaling secondary transform set index
 #if CONFIG_INTRA_TX_IST_PARSE
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index 5b543ca..1ee70ad 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -546,8 +546,14 @@
   const PREDICTION_MODE intra_mode = get_intra_mode(mbmi, plane);
   const int filter = mbmi->filter_intra_mode_info.use_filter_intra;
   const int is_depth0 = tx_size_is_depth0(txfm_param->tx_size, plane_bsize);
+#if CONFIG_INTER_IST
+  if (!is_inter_block(mbmi, xd->tree_type))
+    assert(((intra_mode >= PAETH_PRED || filter || !is_depth0) &&
+            txfm_param->sec_tx_type) == 0);
+#else
   assert(((intra_mode >= PAETH_PRED || filter || !is_depth0) &&
           txfm_param->sec_tx_type) == 0);
+#endif  // CONFIG_INTER_IST
   (void)intra_mode;
   (void)filter;
   (void)is_depth0;
@@ -651,6 +657,25 @@
 #endif  // CONFIG_IST_SET_FLAG
   txfm_param->sec_tx_type = 0;
   txfm_param->intra_mode = get_intra_mode(mbmi, plane);
+#if CONFIG_INTER_IST
+  txfm_param->is_inter = is_inter_block(xd->mi[0], xd->tree_type);
+  const int width = tx_size_wide[tx_size];
+  const int height = tx_size_high[tx_size];
+  bool mode_dependent_condition =
+      (txfm_param->is_inter
+           ? (txfm_param->tx_type == DCT_DCT && width >= 16 && height >= 16 &&
+              cm->seq_params.enable_inter_ist)
+           : (txfm_param->intra_mode < PAETH_PRED &&
+              !(mbmi->filter_intra_mode_info.use_filter_intra) &&
+              cm->seq_params.enable_ist));
+  if (mode_dependent_condition && !xd->lossless[mbmi->segment_id] &&
+      !(mbmi->fsc_mode[xd->tree_type == CHROMA_PART])) {
+#if CONFIG_IST_SET_FLAG
+    txfm_param->sec_tx_set = get_secondary_tx_set(tx_type);
+#endif  // CONFIG_IST_SET_FLAG
+    txfm_param->sec_tx_type = get_secondary_tx_type(tx_type);
+  }
+#else
   if ((txfm_param->intra_mode < PAETH_PRED) &&
       !xd->lossless[mbmi->segment_id] &&
       !(mbmi->filter_intra_mode_info.use_filter_intra) &&
@@ -661,6 +686,7 @@
 #endif  // CONFIG_IST_SET_FLAG
     txfm_param->sec_tx_type = get_secondary_tx_type(tx_type);
   }
+#endif  // CONFIG_INTER_IST
   txfm_param->cctx_type = cctx_type;
   txfm_param->tx_size = tx_size;
   txfm_param->lossless = xd->lossless[mbmi->segment_id];
diff --git a/av1/encoder/encoder.c b/av1/encoder/encoder.c
index cb37ae7..5fb06dd 100644
--- a/av1/encoder/encoder.c
+++ b/av1/encoder/encoder.c
@@ -460,6 +460,9 @@
   seq->enable_idif = oxcf->intra_mode_cfg.enable_idif;
 #endif  // CONFIG_IDIF
   seq->enable_ist = oxcf->txfm_cfg.enable_ist;
+#if CONFIG_INTER_IST
+  seq->enable_inter_ist = oxcf->txfm_cfg.enable_inter_ist;
+#endif  // CONFIG_INTER_IST
   seq->enable_cctx = oxcf->txfm_cfg.enable_cctx;
   seq->enable_ibp = oxcf->intra_mode_cfg.enable_ibp;
   seq->enable_adaptive_mvd = tool_cfg->enable_adaptive_mvd;
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index a3d766d..aae2e02 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -338,6 +338,12 @@
    * Flag to indicate if intra secondary transform should be enabled.
    */
   bool enable_ist;
+#if CONFIG_INTER_IST
+  /*!
+   * Flag to indicate if inter secondary transform should be enabled.
+   */
+  bool enable_inter_ist;
+#endif  // CONFIG_INTER_IST
   /*!
    * Flag to indicate if cross chroma component transform is enabled.
    */
@@ -1279,7 +1285,11 @@
   unsigned int delta_lf_multi_cnts[FRAME_LF_COUNT][CDF_SIZE(DELTA_LF_PROBS +
                                                             1)];  // placeholder
   unsigned int delta_lf_cnts[CDF_SIZE(DELTA_LF_PROBS + 1)];       // placeholder
-  unsigned int stx_cnts[TX_SIZES][CDF_SIZE(STX_TYPES)];           // placeholder
+#if CONFIG_INTER_IST
+  unsigned int stx_cnts[2][TX_SIZES][CDF_SIZE(STX_TYPES)];  // placeholder
+#else
+  unsigned int stx_cnts[TX_SIZES][CDF_SIZE(STX_TYPES)];      // placeholder
+#endif  // CONFIG_INTER_IST
 #if CONFIG_IST_SET_FLAG
 #if CONFIG_INTRA_TX_IST_PARSE
   unsigned int stx_set_cnts[CDF_SIZE(IST_DIR_SIZE)];  // placeholder
diff --git a/av1/encoder/encodetxb.c b/av1/encoder/encodetxb.c
index 2dcfbb7..42e9931 100644
--- a/av1/encoder/encodetxb.c
+++ b/av1/encoder/encodetxb.c
@@ -891,7 +891,14 @@
 
   // write sec_tx_type here
   // Only y plane's sec_tx_type is transmitted
+#if CONFIG_INTER_IST
+  if ((plane == AOM_PLANE_Y) &&
+      (is_inter_block(xd->mi[0], xd->tree_type)
+           ? (eob > 3 && cm->seq_params.enable_inter_ist)
+           : (eob != 1 && cm->seq_params.enable_ist))) {
+#else
   if ((plane == AOM_PLANE_Y) && (cm->seq_params.enable_ist) && eob != 1) {
+#endif  // CONFIG_INTER_IST
     av1_write_sec_tx_type(cm, xd, tx_type, tx_size, eob, w);
   }
 
@@ -1658,8 +1665,23 @@
         const int esc_eob = is_fsc ? bob_code : eob;
         const int eob_tx_ctx =
             get_lp2tx_ctx(tx_size, get_txb_bwl(tx_size), esc_eob);
+#if CONFIG_INTER_IST
+        int tx_type_cost = 0;
+        tx_type_cost =
+            x->mode_costs
+                .inter_tx_type_costs[ext_tx_set][eob_tx_ctx][square_tx_size]
+                                    [get_primary_tx_type(tx_type)];
+        if (block_signals_sec_tx_type(xd, tx_size, tx_type, eob) &&
+            xd->enable_ist) {
+          tx_type_cost +=
+              x->mode_costs.stx_flag_cost[is_inter][square_tx_size]
+                                         [get_secondary_tx_type(tx_type)];
+        }
+        return tx_type_cost;
+#else
         return x->mode_costs.inter_tx_type_costs[ext_tx_set][eob_tx_ctx]
                                                 [square_tx_size][tx_type];
+#endif  // CONFIG_INTER_IST
       }
     } else {
       if (ext_tx_set > 0) {
@@ -1691,9 +1713,15 @@
         }
         if (block_signals_sec_tx_type(xd, tx_size, tx_type, eob) &&
             xd->enable_ist) {
+#if CONFIG_INTER_IST
+          tx_type_cost +=
+              x->mode_costs.stx_flag_cost[is_inter][square_tx_size]
+                                         [get_secondary_tx_type(tx_type)];
+#else
           tx_type_cost +=
               x->mode_costs.stx_flag_cost[square_tx_size]
                                          [get_secondary_tx_type(tx_type)];
+#endif  // CONFIG_INTER_IST
 #if CONFIG_IST_SET_FLAG
           if (get_secondary_tx_type(tx_type) > 0)
             tx_type_cost += get_sec_tx_set_cost(x, mbmi, tx_type);
@@ -1702,15 +1730,30 @@
         return tx_type_cost;
       }
     }
+#if CONFIG_INTER_IST
+  } else if (!xd->lossless[xd->mi[0]->segment_id]) {
+#else
   } else if (!is_inter && !xd->lossless[xd->mi[0]->segment_id]) {
+#endif  // CONFIG_INTER_IST
     if (block_signals_sec_tx_type(xd, tx_size, tx_type, eob) &&
         xd->enable_ist) {
+#if CONFIG_INTER_IST
+      int tx_type_cost =
+          x->mode_costs.stx_flag_cost[is_inter][square_tx_size]
+                                     [get_secondary_tx_type(tx_type)];
+#else
       int tx_type_cost =
           x->mode_costs
               .stx_flag_cost[square_tx_size][get_secondary_tx_type(tx_type)];
+#endif  // CONFIG_INTER_IST
 #if CONFIG_IST_SET_FLAG
+#if CONFIG_INTER_IST
+      if (get_secondary_tx_type(tx_type) > 0 && !is_inter)
+        tx_type_cost += get_sec_tx_set_cost(x, mbmi, tx_type);
+#else
       if (get_secondary_tx_type(tx_type) > 0)
         tx_type_cost += get_sec_tx_set_cost(x, mbmi, tx_type);
+#endif  // CONFIG_INTER_IST
 #endif  // CONFIG_IST_SET_FLAG
       return tx_type_cost;
     }
@@ -4575,10 +4618,23 @@
         const int eob_tx_ctx =
             get_lp2tx_ctx(tx_size, get_txb_bwl(tx_size), esc_eob);
         if (allow_update_cdf) {
+#if CONFIG_INTER_IST
+          update_cdf(
+              fc->inter_ext_tx_cdf[eset][eob_tx_ctx][txsize_sqr_map[tx_size]],
+              av1_ext_tx_ind[tx_set_type][get_primary_tx_type(tx_type)],
+              av1_num_ext_tx_set[tx_set_type]);
+          // Modified condition for CDF update
+          if (cm->seq_params.enable_inter_ist &&
+              block_signals_sec_tx_type(xd, tx_size, tx_type, eob)) {
+            update_cdf(fc->stx_cdf[is_inter][txsize_sqr_map[tx_size]],
+                       (int8_t)get_secondary_tx_type(tx_type), STX_TYPES);
+          }
+#else
           update_cdf(
               fc->inter_ext_tx_cdf[eset][eob_tx_ctx][txsize_sqr_map[tx_size]],
               av1_ext_tx_ind[tx_set_type][tx_type],
               av1_num_ext_tx_set[tx_set_type]);
+#endif  // CONFIG_INTER_IST
         }
 #if CONFIG_ENTROPY_STATS
         ++counts->inter_ext_tx[eset][eob_tx_ctx][txsize_sqr_map[tx_size]]
@@ -4635,8 +4691,13 @@
           // Modified condition for CDF update
           if (cm->seq_params.enable_ist &&
               block_signals_sec_tx_type(xd, tx_size, tx_type, eob)) {
+#if CONFIG_INTER_IST
+            update_cdf(fc->stx_cdf[is_inter][txsize_sqr_map[tx_size]],
+                       (int8_t)get_secondary_tx_type(tx_type), STX_TYPES);
+#else
             update_cdf(fc->stx_cdf[txsize_sqr_map[tx_size]],
                        (int8_t)get_secondary_tx_type(tx_type), STX_TYPES);
+#endif  // CONFIG_INTER_IST
 #if CONFIG_IST_SET_FLAG
             if (get_secondary_tx_type(tx_type) > 0)
               update_sec_tx_set_cdf(fc, mbmi, tx_type);
@@ -4647,18 +4708,38 @@
     }
   }
   // CDF update for txsize_sqr_up_map[tx_size] >= TX_32X32
+#if CONFIG_INTER_IST
+  else if (cm->quant_params.base_qindex > 0 &&
+           !mbmi->skip_txfm[xd->tree_type == CHROMA_PART] &&
+           !segfeature_active(&cm->seg, mbmi->segment_id, SEG_LVL_SKIP) &&
+           (is_inter ? cm->seq_params.enable_inter_ist
+                     : cm->seq_params.enable_ist) &&
+           block_signals_sec_tx_type(xd, tx_size, tx_type, eob)) {
+    if (eob == 1 && !is_inter && allow_update_cdf) return;
+#else
   else if (!is_inter && cm->quant_params.base_qindex > 0 &&
            !mbmi->skip_txfm[xd->tree_type == CHROMA_PART] &&
            !segfeature_active(&cm->seg, mbmi->segment_id, SEG_LVL_SKIP) &&
            cm->seq_params.enable_ist &&
            block_signals_sec_tx_type(xd, tx_size, tx_type, eob)) {
     if (eob == 1 && allow_update_cdf) return;
+#endif  // CONFIG_INTER_IST
     if (allow_update_cdf) {
+#if CONFIG_INTER_IST
+      update_cdf(fc->stx_cdf[is_inter][txsize_sqr_map[tx_size]],
+                 (int8_t)get_secondary_tx_type(tx_type), STX_TYPES);
+#else
       update_cdf(fc->stx_cdf[txsize_sqr_map[tx_size]],
                  (int8_t)get_secondary_tx_type(tx_type), STX_TYPES);
+#endif  // CONFIG_INTER_IST
 #if CONFIG_IST_SET_FLAG
+#if CONFIG_INTER_IST
+      if (get_secondary_tx_type(tx_type) > 0 && !is_inter)
+        update_sec_tx_set_cdf(fc, mbmi, tx_type);
+#else
       if (get_secondary_tx_type(tx_type) > 0)
         update_sec_tx_set_cdf(fc, mbmi, tx_type);
+#endif  // CONFIG_INTER_IST
 #endif  // CONFIG_IST_SET_FLAG
     }
   }
diff --git a/av1/encoder/hybrid_fwd_txfm.c b/av1/encoder/hybrid_fwd_txfm.c
index 952ade2..ecd61de 100644
--- a/av1/encoder/hybrid_fwd_txfm.c
+++ b/av1/encoder/hybrid_fwd_txfm.c
@@ -419,7 +419,12 @@
                          : 32;
 
   if ((width >= 4 && height >= 4) && stx_type) {
+#if CONFIG_INTER_IST
+    const PREDICTION_MODE intra_mode =
+        (txfm_param->is_inter ? DC_PRED : txfm_param->intra_mode);
+#else
     const PREDICTION_MODE intra_mode = txfm_param->intra_mode;
+#endif  // CONFIG_INTER_IST
     PREDICTION_MODE mode = 0, mode_t = 0;
     const int log2width = tx_size_wide_log2[txfm_param->tx_size];
     const int sb_size = (width >= 8 && height >= 8) ? 8 : 4;
diff --git a/av1/encoder/rd.c b/av1/encoder/rd.c
index 157faef..d32653b 100644
--- a/av1/encoder/rd.c
+++ b/av1/encoder/rd.c
@@ -555,10 +555,19 @@
   }
 #endif  // CONFIG_MORPH_PRED
 
+#if CONFIG_INTER_IST
+  for (j = 0; j < 2; ++j) {
+    for (i = 0; i < TX_SIZES; ++i) {
+      av1_cost_tokens_from_cdf(mode_costs->stx_flag_cost[j][i],
+                               fc->stx_cdf[j][i], NULL);
+    }
+  }
+#else
   for (i = 0; i < TX_SIZES; ++i) {
     av1_cost_tokens_from_cdf(mode_costs->stx_flag_cost[i], fc->stx_cdf[i],
                              NULL);
   }
+#endif  // CONFIG_INTER_IST
 
 #if CONFIG_IST_SET_FLAG
 #if CONFIG_INTRA_TX_IST_PARSE
diff --git a/av1/encoder/tx_search.c b/av1/encoder/tx_search.c
index d8a492f..b6705dc 100644
--- a/av1/encoder/tx_search.c
+++ b/av1/encoder/tx_search.c
@@ -2620,9 +2620,13 @@
       // Therefore transform domain distortion is not valid for these
       // transform sizes.
       (txsize_sqr_up_map[tx_size] != TX_64X64) &&
-      // Use pixel domain distortion for IST
-      // TODO(any): Make IST compatible with tx domain distortion
+  // Use pixel domain distortion for IST
+  // TODO(any): Make IST compatible with tx domain distortion
+#if CONFIG_INTER_IST
+      !(cm->seq_params.enable_ist || cm->seq_params.enable_inter_ist) &&
+#else
       !cm->seq_params.enable_ist &&
+#endif  // CONFIG_INTER_IST
       // Use pixel domain distortion for DC only blocks
       !dc_only_blk;
   // Flag to indicate if an extra calculation of distortion in the pixel domain
@@ -2678,43 +2682,60 @@
       continue;
     }
     bool skip_idx = false;
+#if CONFIG_INTER_IST
+    xd->enable_ist =
+        (is_inter_block(mbmi, xd->tree_type) ? cm->seq_params.enable_inter_ist
+                                             : cm->seq_params.enable_ist) &&
+#else
     xd->enable_ist = cm->seq_params.enable_ist &&
-                     !cpi->sf.tx_sf.tx_type_search.skip_stx_search &&
-                     !mbmi->fsc_mode[xd->tree_type == CHROMA_PART] &&
-                     !xd->lossless[mbmi->segment_id];
+#endif  // CONFIG_INTER_IST
+        !cpi->sf.tx_sf.tx_type_search.skip_stx_search &&
+        !mbmi->fsc_mode[xd->tree_type == CHROMA_PART] &&
+        !xd->lossless[mbmi->segment_id];
 
     const PREDICTION_MODE intra_mode = get_intra_mode(mbmi, plane);
     const int filter = mbmi->filter_intra_mode_info.use_filter_intra;
     const int is_depth0 = tx_size_is_depth0(tx_size, plane_bsize);
-
+#if CONFIG_INTER_IST
+    bool skip_stx =
+        ((primary_tx_type != DCT_DCT && primary_tx_type != ADST_ADST) ||
+         plane != 0 ||
+         (is_inter_block(mbmi, xd->tree_type)
+              ? (primary_tx_type == ADST_ADST || txw < 16 || txh < 16)
+              : (intra_mode >= PAETH_PRED || filter)) ||
+         dc_only_blk || !is_depth0 || (eob_found) || !xd->enable_ist);
+#else
     bool skip_stx =
         ((primary_tx_type != DCT_DCT && primary_tx_type != ADST_ADST) ||
          plane != 0 || is_inter_block(mbmi, xd->tree_type) || dc_only_blk ||
          intra_mode >= PAETH_PRED || filter || !is_depth0 || (eob_found) ||
          !xd->enable_ist);
+#endif  // CONFIG_INTER_IST
 
 #if CONFIG_IST_ANY_SET
+#if CONFIG_INTER_IST
+    int init_set_id = 0;
+    int max_set_id =
+        (skip_stx || is_inter_block(mbmi, xd->tree_type)) ? 1 : IST_DIR_SIZE;
+#else
     int max_set_id = skip_stx ? 1 : IST_DIR_SIZE;
+#endif  // CONFIG_INTER_IST
 
     // Iterate through all possible secondary tx sets for given primary tx type
+#if CONFIG_INTER_IST
+    for (int set_id = init_set_id; set_id < max_set_id; ++set_id) {
+#else
     for (int set_id = 0; set_id < max_set_id; ++set_id) {
-      TX_TYPE tx_type = (TX_TYPE)txk_map[idx];
-      uint16_t stx_set =
-          (tx_type == ADST_ADST) ? set_id + IST_DIR_SIZE : set_id;
-      if (skip_stx) stx_set = 0;
-      assert(stx_set < IST_SET_SIZE);
-      set_secondary_tx_set(&tx_type, stx_set);
-      assert(tx_type < (1 << (PRIMARY_TX_BITS + SECONDARY_TX_BITS +
-                              SECONDARY_TX_SET_BITS)));
-      txfm_param.sec_tx_set = stx_set;
+#endif  // CONFIG_INTER_IST
 #endif  // CONFIG_IST_ANY_SET
 
       const int max_stx = xd->enable_ist && !(eob_found) ? 4 : 1;
 
       for (int stx = 0; stx < max_stx; ++stx) {
 #if CONFIG_IST_ANY_SET
-        tx_type = (TX_TYPE)txk_map[idx];
+        TX_TYPE tx_type = (TX_TYPE)txk_map[idx];
         if (eob_found) skip_stx = true;
+        uint16_t stx_set = 0;
 #else   // CONFIG_IST_ANY_SET
       TX_TYPE tx_type = (TX_TYPE)txk_map[idx];
       skip_stx |= eob_found;
@@ -2741,7 +2762,11 @@
         txfm_param.sec_tx_set = stx_set;
 #endif  // !CONFIG_IST_ANY_SET && CONFIG_IST_SET_FLAG
 #if CONFIG_IST_ANY_SET
+        stx_set = (primary_tx_type == ADST_ADST && stx) ? set_id + IST_DIR_SIZE
+                                                        : set_id;
         set_secondary_tx_set(&tx_type, stx_set);
+        txfm_param.sec_tx_set = stx_set;
+        assert(stx_set < IST_SET_SIZE);
         assert(tx_type < (1 << (PRIMARY_TX_BITS + SECONDARY_TX_BITS +
                                 SECONDARY_TX_SET_BITS)));
 #endif  // CONFIG_IST_ANY_SET
@@ -2783,6 +2808,7 @@
 
         // pre-skip DC only case to make things faster
         uint16_t *const eob = &p->eobs[block];
+#if CONFIG_INTER_IST
         if (*eob == 1 && plane == PLANE_TYPE_Y && !is_inter) {
           if (tx_type1 == DCT_DCT) eob_found = 1;
           if (tx_type1 != DCT_DCT || (stx && primary_tx_type)) {
@@ -2790,7 +2816,19 @@
             continue;
           }
         }
-
+        if (*eob <= 3 && plane == PLANE_TYPE_Y && is_inter && stx) {
+          update_txk_array(xd, blk_row, blk_col, tx_size, primary_tx_type);
+          continue;
+        }
+#else
+      if (*eob == 1 && plane == PLANE_TYPE_Y && !is_inter) {
+        if (tx_type1 == DCT_DCT) eob_found = 1;
+        if (tx_type1 != DCT_DCT || (stx && primary_tx_type)) {
+          update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
+          continue;
+        }
+      }
+#endif  // CONFIG_INTER_IST
 #if CONFIG_IMPROVEIDTX_RDPH
         if (fsc_mode_in && quant_param.use_optimize_b) {
           av1_optimize_fsc(cpi, x, plane, block, tx_size, tx_type, txb_ctx,
@@ -2818,19 +2856,33 @@
               cost_coeffs(cm, x, plane, block, tx_size, tx_type, CCTX_NONE,
                           txb_ctx, cm->features.reduced_tx_set_used);
         }
-
+#if CONFIG_INTER_IST
         if (*eob == 1 && plane == PLANE_TYPE_Y && !is_inter) {
-          // post quant-skip DC only case
           if (tx_type1 == DCT_DCT) eob_found = 1;
           if (tx_type1 != DCT_DCT || (stx && primary_tx_type)) {
-            if (plane == 0)
-              update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
+            update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
             continue;
           }
           if (get_secondary_tx_type(tx_type) > 0) continue;
           if (txfm_param.sec_tx_type > 0) continue;
         }
-
+        if (*eob <= 3 && plane == PLANE_TYPE_Y && is_inter && stx) {
+          update_txk_array(xd, blk_row, blk_col, tx_size, primary_tx_type);
+          continue;
+        }
+#else
+      if (*eob == 1 && plane == PLANE_TYPE_Y && !is_inter) {
+        // post quant-skip DC only case
+        if (tx_type1 == DCT_DCT) eob_found = 1;
+        if (tx_type1 != DCT_DCT || (stx && primary_tx_type)) {
+          if (plane == 0)
+            update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
+          continue;
+        }
+        if (get_secondary_tx_type(tx_type) > 0) continue;
+        if (txfm_param.sec_tx_type > 0) continue;
+      }
+#endif  // CONFIG_INTER_IST
         // If rd cost based on coeff rate alone is already more than best_rd,
         // terminate early.
         if (RDCOST(x->rdmult, rate_cost, 0) > best_rd) continue;
@@ -2961,6 +3013,9 @@
         }
       }  // for (int stx = 0;
 #if CONFIG_IST_ANY_SET
+#if CONFIG_INTER_IST
+      if (skip_idx) break;
+#endif  // CONFIG_INTER_IST
     }   // for (int stx_set = 0;
 #endif  // CONFIG_IST_ANY_SET
     if (skip_idx) break;
@@ -3496,7 +3551,11 @@
   int64_t best_rd = INT64_MAX;
   TX_PARTITION_TYPE best_tx_partition = TX_PARTITION_INVALID;
   uint8_t best_partition_entropy_ctxs[MAX_TX_PARTITIONS] = { 0 };
+#if CONFIG_INTER_IST
+  TX_TYPE best_partition_tx_types[MAX_TX_PARTITIONS] = { 0 };
+#else
   TX_PARTITION_TYPE best_partition_tx_types[MAX_TX_PARTITIONS] = { 0 };
+#endif  // CONFIG_INTER_IST
   uint8_t full_blk_skip[MAX_TX_PARTITIONS] = { 0 };
 
   const int threshold = cpi->sf.tx_sf.tx_type_search.ml_tx_split_thresh;
@@ -3585,7 +3644,11 @@
     get_tx_partition_sizes(type, max_tx_size, &txb_pos, sub_txs);
     uint8_t this_blk_skip[MAX_TX_PARTITIONS] = { 0 };
     uint8_t partition_entropy_ctxs[MAX_TX_PARTITIONS] = { 0 };
+#if CONFIG_INTER_IST
+    TX_TYPE partition_tx_types[MAX_TX_PARTITIONS] = { 0 };
+#else
     TX_PARTITION_TYPE partition_tx_types[MAX_TX_PARTITIONS] = { 0 };
+#endif  // CONFIG_INTER_IST
     int cur_block = block;
 
     // Compute cost of each tx size in this partition
@@ -3642,7 +3705,11 @@
     int blk_idx = 0;
     uint8_t this_blk_skip[MAX_TX_PARTITIONS] = { 0 };
     uint8_t partition_entropy_ctxs[MAX_TX_PARTITIONS] = { 0 };
+#if CONFIG_INTER_IST
+    TX_TYPE partition_tx_types[MAX_TX_PARTITIONS] = { 0 };
+#else
     TX_PARTITION_TYPE partition_tx_types[MAX_TX_PARTITIONS] = { 0 };
+#endif  // CONFIG_INTER_IST
     int cur_block = block;
 
     // Compute cost of each tx size in this partition
diff --git a/build/cmake/aom_config_defaults.cmake b/build/cmake/aom_config_defaults.cmake
index 91ea097..6625c9b 100644
--- a/build/cmake/aom_config_defaults.cmake
+++ b/build/cmake/aom_config_defaults.cmake
@@ -406,6 +406,7 @@
 set_aom_config_var(
   CONFIG_INTRA_TX_IST_PARSE 1
   "Parsing dependency removal for intra tx type and IST set signaling.")
+set_aom_config_var(CONFIG_INTER_IST 1 "Enable IST for inter blocks.")
 #
 # Variables in this section control optional features of the build system.
 #
diff --git a/common/args.c b/common/args.c
index fc28b62..01ce286 100644
--- a/common/args.c
+++ b/common/args.c
@@ -104,6 +104,9 @@
     GET_PARAMS(enable_idif);
 #endif  // CONFIG_IDIF
     GET_PARAMS(enable_ist);
+#if CONFIG_INTER_IST
+    GET_PARAMS(enable_inter_ist);
+#endif  // CONFIG_INTER_IST
     GET_PARAMS(enable_cctx);
     GET_PARAMS(enable_ibp);
     GET_PARAMS(enable_adaptive_mvd);
diff --git a/common/av1_config.c b/common/av1_config.c
index 9174341..f44141a 100644
--- a/common/av1_config.c
+++ b/common/av1_config.c
@@ -252,6 +252,9 @@
 #endif  // CONFIG_OUTPUT_FRAME_BASED_ON_ORDER_HINT
   AV1C_READ_BIT_OR_RETURN_ERROR(enable_sdp);
   AV1C_READ_BIT_OR_RETURN_ERROR(enable_ist);
+#if CONFIG_INTER_IST
+  AV1C_READ_BIT_OR_RETURN_ERROR(enable_inter_ist);
+#endif  // CONFIG_INTER_IST
   AV1C_READ_BIT_OR_RETURN_ERROR(enable_cctx);
   AV1C_READ_BIT_OR_RETURN_ERROR(enable_mrls);
   AV1C_READ_BIT_OR_RETURN_ERROR(enable_tip);
diff --git a/tools/aom_entropy_optimizer.c b/tools/aom_entropy_optimizer.c
index e834c09..f88324a 100644
--- a/tools/aom_entropy_optimizer.c
+++ b/tools/aom_entropy_optimizer.c
@@ -718,12 +718,22 @@
                      "[CDF_SIZE(DELTA_Q_PROBS + 1)]",
                      0, &total_count, 0, mem_wanted, "Filters");
 
+#if CONFIG_INTER_IST
+  cts_each_dim[0] = 2;
+  cts_each_dim[1] = TX_SIZES;
+  cts_each_dim[2] = STX_TYPES;
+  optimize_cdf_table(&fc.stx_cnts[0][0][0], probsfile, 3, cts_each_dim,
+                     "static aom_cdf_prob default_stx_cdf"
+                     "[2][TX_SIZES][CDF_SIZE(STX_TYPES)]",
+                     0, &total_count, 0, mem_wanted, "Transforms");
+#else
   cts_each_dim[0] = TX_SIZES;
   cts_each_dim[1] = STX_TYPES;
   optimize_cdf_table(&fc.stx_cnts[0][0], probsfile, 2, cts_each_dim,
                      "static aom_cdf_prob default_stx_cdf"
                      "[TX_SIZES][CDF_SIZE(STX_TYPES)]",
                      0, &total_count, 0, mem_wanted, "Transforms");
+#endif  // CONFIG_INTER_IST
 
 #if CONFIG_IST_ANY_SET
 #if CONFIG_INTRA_TX_IST_PARSE