Improving the model for pruning the TX type search

Introduces two new TX type pruning modes that provide a better
speed/quality trade-off than the existing ones. A shallow neural
network with one hidden layer, trained separately for each block
size, is used as the prediction model. The new modes differ in the
thresholds applied to the output of the neural net, so that they
prune different numbers of TX types on average.

Owing to its relatively low quality loss, PRUNE_2D_ACCURATE is used
by default. Starting with speed setting 3 we switch to the
PRUNE_2D_FAST mode to get a better speed-up.

Evaluation results:
----------------------------------------------------------
Prune mode | Avg. speed-up | Quality loss | Quality loss
           |(high bitrates)|   (lowres)   |   (midres)
----------------------------------------------------------
PRUNE_ONE  |     18.7%     |    0.396%    |    0.308%
----------------------------------------------------------
PRUNE_TWO  |     27.2%     |    0.439%    |    0.389%
----------------------------------------------------------
PRUNE_2D_  |     18.8%     |    0.032%    |    0.063%
ACCURATE   |               |              |
----------------------------------------------------------
PRUNE_2D_  |     33.3%     |    0.504%    |     ---
FAST       |               |              |

Change-Id: Ibd59f52eef493a499e529d824edad267daa65f9d
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index ff4d204..6698244 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -62,6 +62,9 @@
 #include "av1/encoder/rd.h"
 #include "av1/encoder/rdopt.h"
 #include "av1/encoder/tokenize.h"
+#if CONFIG_EXT_TX
+#include "av1/encoder/tx_prune_model_weights.h"
+#endif  // CONFIG_EXT_TX
 #if CONFIG_PVQ
 #include "av1/encoder/pvq_encoder.h"
 #include "av1/common/pvq.h"
@@ -1146,50 +1149,315 @@
   { 1, 0, 0, 1 },
 #endif  // CONFIG_MRC_TX
 };
+
+static void get_energy_distribution_finer(const int16_t *diff, int stride,
+                                          int bw, int bh, float *hordist,
+                                          float *verdist) {
+  // First compute downscaled block energy values (esq); downscale factors
+  // are defined by w_shift and h_shift.
+  unsigned int esq[256];
+  const int w_shift = bw <= 8 ? 0 : 1;
+  const int h_shift = bh <= 8 ? 0 : 1;
+  const int esq_w = bw <= 8 ? bw : bw / 2;
+  const int esq_h = bh <= 8 ? bh : bh / 2;
+  const int esq_sz = esq_w * esq_h;
+  int i, j;
+  memset(esq, 0, esq_sz * sizeof(esq[0]));
+  for (i = 0; i < bh; i++) {
+    unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w;
+    const int16_t *cur_diff_row = diff + i * stride;
+    for (j = 0; j < bw; j++) {
+      cur_esq_row[j >> w_shift] += cur_diff_row[j] * cur_diff_row[j];
+    }
+  }
+
+  uint64_t total = 0;
+  for (i = 0; i < esq_sz; i++) total += esq[i];
+
+  // Output hordist and verdist arrays are normalized 1D projections of esq
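+  // Note: only the first esq_w - 1 (esq_h - 1) entries are written; the
+  // caller fills the last feature slot with a correlation value instead.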
+  if (total == 0) {
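+    // Zero total energy: fall back to a uniform distribution.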
+    float hor_val = 1.0f / esq_w;
+    for (j = 0; j < esq_w - 1; j++) hordist[j] = hor_val;
+    float ver_val = 1.0f / esq_h;
+    for (i = 0; i < esq_h - 1; i++) verdist[i] = ver_val;
+    return;
+  }
+
+  const float e_recip = 1.0f / (float)total;
+  memset(hordist, 0, (esq_w - 1) * sizeof(hordist[0]));
+  memset(verdist, 0, (esq_h - 1) * sizeof(verdist[0]));
+  const unsigned int *cur_esq_row;
+  for (i = 0; i < esq_h - 1; i++) {
+    cur_esq_row = esq + i * esq_w;
+    for (j = 0; j < esq_w - 1; j++) {
+      hordist[j] += (float)cur_esq_row[j];
+      verdist[i] += (float)cur_esq_row[j];
+    }
+    verdist[i] += (float)cur_esq_row[j];
+  }
+  cur_esq_row = esq + i * esq_w;
+  for (j = 0; j < esq_w - 1; j++) hordist[j] += (float)cur_esq_row[j];
+
+  for (j = 0; j < esq_w - 1; j++) hordist[j] *= e_recip;
+  for (i = 0; i < esq_h - 1; i++) verdist[i] *= e_recip;
+}
+
+// Similar to get_horver_correlation, but also takes the first row/column
+// into account when computing the horizontal/vertical correlation.
+static void get_horver_correlation_full(const int16_t *diff, int stride, int w,
+                                        int h, float *hcorr, float *vcorr) {
+  const float num_hor = h * (w - 1);
+  const float num_ver = (h - 1) * w;
+  int i, j;
+
+  // The following notation is used:
+  // x - current pixel
+  // y - left neighbor pixel
+  // z - top neighbor pixel
+  int64_t xy_sum = 0, xz_sum = 0;
+  int64_t xhor_sum = 0, xver_sum = 0, y_sum = 0, z_sum = 0;
+  int64_t x2hor_sum = 0, x2ver_sum = 0, y2_sum = 0, z2_sum = 0;
+
+  int16_t x, y, z;
+  for (j = 1; j < w; ++j) {
+    x = diff[j];
+    y = diff[j - 1];
+    xy_sum += x * y;
+    xhor_sum += x;
+    y_sum += y;
+    x2hor_sum += x * x;
+    y2_sum += y * y;
+  }
+  for (i = 1; i < h; ++i) {
+    x = diff[i * stride];
+    z = diff[(i - 1) * stride];
+    xz_sum += x * z;
+    xver_sum += x;
+    z_sum += z;
+    x2ver_sum += x * x;
+    z2_sum += z * z;
+    for (j = 1; j < w; ++j) {
+      x = diff[i * stride + j];
+      y = diff[i * stride + j - 1];
+      z = diff[(i - 1) * stride + j];
+      xy_sum += x * y;
+      xz_sum += x * z;
+      xhor_sum += x;
+      xver_sum += x;
+      y_sum += y;
+      z_sum += z;
+      x2hor_sum += x * x;
+      x2ver_sum += x * x;
+      y2_sum += y * y;
+      z2_sum += z * z;
+    }
+  }
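+  // The scaled variances (n * var) and covariances (n * cov) below yield
+  // the Pearson correlations, since the scale factors cancel in the ratio.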
+  const float xhor_var_n = x2hor_sum - (xhor_sum * xhor_sum) / num_hor;
+  const float y_var_n = y2_sum - (y_sum * y_sum) / num_hor;
+  const float xy_var_n = xy_sum - (xhor_sum * y_sum) / num_hor;
+  const float xver_var_n = x2ver_sum - (xver_sum * xver_sum) / num_ver;
+  const float z_var_n = z2_sum - (z_sum * z_sum) / num_ver;
+  const float xz_var_n = xz_sum - (xver_sum * z_sum) / num_ver;
+
+  *hcorr = *vcorr = 1;
+  if (xhor_var_n > 0 && y_var_n > 0) {
+    *hcorr = xy_var_n / sqrtf(xhor_var_n * y_var_n);
+    *hcorr = *hcorr < 0 ? 0 : *hcorr;
+  }
+  if (xver_var_n > 0 && z_var_n > 0) {
+    *vcorr = xz_var_n / sqrtf(xver_var_n * z_var_n);
+    *vcorr = *vcorr < 0 ? 0 : *vcorr;
+  }
+}
+
+// Performs a forward pass through a neural network with 2 fully-connected
+// layers, using ReLU as the activation function for the hidden layer. The
+// number of output neurons is always 4.
+// fc1, fc2 - weight matrices of the respective layers.
+// b1, b2 - bias vectors of the respective layers.
+static void compute_1D_scores(float *features, int num_features,
+                              const float *fc1, const float *b1,
+                              const float *fc2, const float *b2,
+                              int num_hidden_units, float *dst_scores) {
+  assert(num_hidden_units <= 32);
+  float hidden_layer[32];
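+  // Hidden layer: ReLU(fc1 * features + b1).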
+  for (int i = 0; i < num_hidden_units; i++) {
+    const float *cur_coef = fc1 + i * num_features;
+    hidden_layer[i] = 0.0f;
+    for (int j = 0; j < num_features; j++)
+      hidden_layer[i] += cur_coef[j] * features[j];
+    hidden_layer[i] = AOMMAX(hidden_layer[i] + b1[i], 0.0f);
+  }
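+  // Output layer (no activation): fc2 * hidden_layer + b2.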
+  for (int i = 0; i < 4; i++) {
+    const float *cur_coef = fc2 + i * num_hidden_units;
+    dst_scores[i] = 0.0f;
+    for (int j = 0; j < num_hidden_units; j++)
+      dst_scores[i] += cur_coef[j] * hidden_layer[j];
+    dst_scores[i] += b2[i];
+  }
+}
+
+// Transforms raw scores into a probability distribution across the 16 TX
+// types by raising shifted scores to the 8th power and normalizing.
+static void score_2D_transform_pow8(float *scores_2D, float shift) {
+  float sum = 0.0f;
+  int i;
+
+  for (i = 0; i < 16; i++) {
+    float v, v2, v4;
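+    // Raise the shifted score to the 8th power via repeated squaring.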
+    v = AOMMAX(scores_2D[i] + shift, 0.0f);
+    v2 = v * v;
+    v4 = v2 * v2;
+    scores_2D[i] = v4 * v4;
+    sum += scores_2D[i];
+  }
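+  // Normalize so that the scores sum to 1.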
+  for (i = 0; i < 16; i++) scores_2D[i] /= sum;
+}
+
+static int prune_tx_types_2D(BLOCK_SIZE bsize, const MACROBLOCK *x,
+                             int tx_set_type, int pruning_aggressiveness) {
+  if (bsize >= BLOCK_32X32) return 0;
+  const struct macroblock_plane *const p = &x->plane[0];
+  const int bidx = AOMMAX(bsize - BLOCK_4X4, 0);
+  const float score_thresh =
+      av1_prune_2D_adaptive_thresholds[bidx][pruning_aggressiveness - 1];
+
+  float hfeatures[16], vfeatures[16];
+  float hscores[4], vscores[4];
+  float scores_2D[16];
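+  // Entry (i, j) of this table is the 2D TX type with vertical 1D transform
+  // i and horizontal 1D transform j, the 1D transforms being ordered
+  // {DCT, ADST, FLIPADST, IDTX}.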
+  int tx_type_table_2D[16] = {
+    DCT_DCT,      DCT_ADST,      DCT_FLIPADST,      V_DCT,
+    ADST_DCT,     ADST_ADST,     ADST_FLIPADST,     V_ADST,
+    FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
+    H_DCT,        H_ADST,        H_FLIPADST,        IDTX
+  };
+  const int bw = block_size_wide[bsize], bh = block_size_high[bsize];
+  const int hfeatures_num = bw <= 8 ? bw : bw / 2;
+  const int vfeatures_num = bh <= 8 ? bh : bh / 2;
+  assert(hfeatures_num <= 16);
+  assert(vfeatures_num <= 16);
+
+  get_energy_distribution_finer(p->src_diff, bw, bw, bh, hfeatures, vfeatures);
+  get_horver_correlation_full(p->src_diff, bw, bw, bh,
+                              &hfeatures[hfeatures_num - 1],
+                              &vfeatures[vfeatures_num - 1]);
+
+  const float *fc1_hor = av1_prune_2D_learned_weights_hor[bidx];
+  const float *b1_hor =
+      fc1_hor + av1_prune_2D_num_hidden_units_hor[bidx] * hfeatures_num;
+  const float *fc2_hor = b1_hor + av1_prune_2D_num_hidden_units_hor[bidx];
+  const float *b2_hor = fc2_hor + av1_prune_2D_num_hidden_units_hor[bidx] * 4;
+  compute_1D_scores(hfeatures, hfeatures_num, fc1_hor, b1_hor, fc2_hor, b2_hor,
+                    av1_prune_2D_num_hidden_units_hor[bidx], hscores);
+
+  const float *fc1_ver = av1_prune_2D_learned_weights_ver[bidx];
+  const float *b1_ver =
+      fc1_ver + av1_prune_2D_num_hidden_units_ver[bidx] * vfeatures_num;
+  const float *fc2_ver = b1_ver + av1_prune_2D_num_hidden_units_ver[bidx];
+  const float *b2_ver = fc2_ver + av1_prune_2D_num_hidden_units_ver[bidx] * 4;
+  compute_1D_scores(vfeatures, vfeatures_num, fc1_ver, b1_ver, fc2_ver, b2_ver,
+                    av1_prune_2D_num_hidden_units_ver[bidx], vscores);
+
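+  // Combine the 1D scores into 2D scores as an outer product:
+  // scores_2D[i * 4 + j] = vscores[i] * hscores[j].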
+  float score_2D_average = 0.0f;
+  for (int i = 0; i < 4; i++) {
+    float *cur_scores_2D = scores_2D + i * 4;
+    cur_scores_2D[0] = vscores[i] * hscores[0];
+    cur_scores_2D[1] = vscores[i] * hscores[1];
+    cur_scores_2D[2] = vscores[i] * hscores[2];
+    cur_scores_2D[3] = vscores[i] * hscores[3];
+    score_2D_average += cur_scores_2D[0] + cur_scores_2D[1] + cur_scores_2D[2] +
+                        cur_scores_2D[3];
+  }
+  score_2D_average /= 16;
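+  // Shift the scores so that their average becomes 20 before sharpening.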
+  score_2D_transform_pow8(scores_2D, (20 - score_2D_average));
+
+  // Always keep the TX type with the highest score; prune all others whose
+  // score falls below score_thresh.
+  int max_score_i = 0;
+  float max_score = 0.0f;
+  for (int i = 0; i < 16; i++) {
+    if (scores_2D[i] > max_score &&
+        av1_ext_tx_used[tx_set_type][tx_type_table_2D[i]]) {
+      max_score = scores_2D[i];
+      max_score_i = i;
+    }
+  }
+
+  int prune_bitmask = 0;
+  for (int i = 0; i < 16; i++) {
+    if (scores_2D[i] < score_thresh && i != max_score_i)
+      prune_bitmask |= (1 << tx_type_table_2D[i]);
+  }
+
+  return prune_bitmask;
+}
 #endif  // CONFIG_EXT_TX
 
 static int prune_tx_types(const AV1_COMP *cpi, BLOCK_SIZE bsize, MACROBLOCK *x,
-                          const MACROBLOCKD *const xd, int tx_set) {
+                          const MACROBLOCKD *const xd, int tx_set_type) {
 #if CONFIG_EXT_TX
-  const int *tx_set_1D = tx_set >= 0 ? ext_tx_used_inter_1D[tx_set] : NULL;
+  int tx_set = ext_tx_set_index[1][tx_set_type];
+  assert(tx_set >= 0);
+  const int *tx_set_1D = ext_tx_used_inter_1D[tx_set];
 #else
   const int tx_set_1D[TX_TYPES_1D] = { 0 };
+  (void)tx_set_type;
 #endif  // CONFIG_EXT_TX
 
   switch (cpi->sf.tx_type_search.prune_mode) {
     case NO_PRUNE: return 0; break;
     case PRUNE_ONE:
-      if ((tx_set >= 0) && !(tx_set_1D[FLIPADST_1D] & tx_set_1D[ADST_1D]))
-        return 0;
+      if (!(tx_set_1D[FLIPADST_1D] & tx_set_1D[ADST_1D])) return 0;
       return prune_one_for_sby(cpi, bsize, x, xd);
       break;
 #if CONFIG_EXT_TX
     case PRUNE_TWO:
-      if ((tx_set >= 0) && !(tx_set_1D[FLIPADST_1D] & tx_set_1D[ADST_1D])) {
+      if (!(tx_set_1D[FLIPADST_1D] & tx_set_1D[ADST_1D])) {
         if (!(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D])) return 0;
         return prune_two_for_sby(cpi, bsize, x, xd, 0, 1);
       }
-      if ((tx_set >= 0) && !(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D]))
+      if (!(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D]))
         return prune_two_for_sby(cpi, bsize, x, xd, 1, 0);
       return prune_two_for_sby(cpi, bsize, x, xd, 1, 1);
       break;
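+    // The last argument to prune_tx_types_2D is the pruning aggressiveness;
+    // it selects the per-block-size threshold applied to the model output.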
+    case PRUNE_2D_ACCURATE:
+      if (tx_set_type == EXT_TX_SET_ALL16)
+        return prune_tx_types_2D(bsize, x, tx_set_type, 6);
+      else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
+        return prune_tx_types_2D(bsize, x, tx_set_type, 4);
+      else
+        return 0;
+      break;
+    case PRUNE_2D_FAST:
+      if (tx_set_type == EXT_TX_SET_ALL16)
+        return prune_tx_types_2D(bsize, x, tx_set_type, 10);
+      else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
+        return prune_tx_types_2D(bsize, x, tx_set_type, 7);
+      else
+        return 0;
+      break;
 #endif  // CONFIG_EXT_TX
   }
   assert(0);
   return 0;
 }
 
-static int do_tx_type_search(TX_TYPE tx_type, int prune) {
+static int do_tx_type_search(TX_TYPE tx_type, int prune,
+                             TX_TYPE_PRUNE_MODE mode) {
 // TODO(sarahparker) implement for non ext tx
 #if CONFIG_EXT_TX
-  return !(((prune >> vtx_tab[tx_type]) & 1) |
-           ((prune >> (htx_tab[tx_type] + 8)) & 1));
+  if (mode >= PRUNE_2D_ACCURATE) {
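+    // 2D pruning modes set one bit per 2D TX type.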
+    return !((prune >> tx_type) & 1);
+  } else {
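+    // 1D pruning modes set separate bits for the vertical (low byte) and
+    // horizontal (high byte) 1D transform types.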
+    return !(((prune >> vtx_tab[tx_type]) & 1) |
+             ((prune >> (htx_tab[tx_type] + 8)) & 1));
+  }
 #else
   // temporary to avoid compiler warnings
   (void)vtx_tab;
   (void)htx_tab;
   (void)tx_type;
   (void)prune;
+  (void)mode;
   return 1;
 #endif  // CONFIG_EXT_TX
 }
@@ -2290,16 +2558,11 @@
 }
 
 static int skip_txfm_search(const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bs,
-                            TX_TYPE tx_type, TX_SIZE tx_size) {
+                            TX_TYPE tx_type, TX_SIZE tx_size, int prune) {
   const MACROBLOCKD *const xd = &x->e_mbd;
   const MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
   const TX_SIZE max_tx_size = max_txsize_lookup[bs];
   const int is_inter = is_inter_block(mbmi);
-  int prune = 0;
-  if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE)
-    // passing -1 in for tx_type indicates that all 1D
-    // transforms should be considered for pruning
-    prune = prune_tx_types(cpi, bs, x, xd, -1);
 
 #if CONFIG_MRC_TX
   // MRC_DCT only implemented for TX_32X32 so only include this tx in
@@ -2329,7 +2592,8 @@
   if (!av1_ext_tx_used[tx_set_type][tx_type]) return 1;
   if (is_inter) {
     if (cpi->sf.tx_type_search.prune_mode > NO_PRUNE) {
-      if (!do_tx_type_search(tx_type, prune)) return 1;
+      if (!do_tx_type_search(tx_type, prune, cpi->sf.tx_type_search.prune_mode))
+        return 1;
     }
   } else {
     if (!ALLOW_INTRA_EXT_TX && bs >= BLOCK_8X8) {
@@ -2339,7 +2603,7 @@
 #else   // CONFIG_EXT_TX
   if (tx_size >= TX_32X32 && tx_type != DCT_DCT) return 1;
   if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE &&
-      !do_tx_type_search(tx_type, prune))
+      !do_tx_type_search(tx_type, prune, cpi->sf.tx_type_search.prune_mode))
     return 1;
 #endif  // CONFIG_EXT_TX
   return 0;
@@ -2389,18 +2653,18 @@
   mbmi->min_tx_size = get_min_tx_size(mbmi->tx_size);
 #endif  // CONFIG_VAR_TX
 #if CONFIG_EXT_TX
-  int ext_tx_set =
-      get_ext_tx_set(mbmi->tx_size, bs, is_inter, cm->reduced_tx_set_used);
   const TxSetType tx_set_type =
       get_ext_tx_set_type(mbmi->tx_size, bs, is_inter, cm->reduced_tx_set_used);
 #endif  // CONFIG_EXT_TX
 
-  if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE)
+  if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE &&
+      !x->use_default_inter_tx_type) {
 #if CONFIG_EXT_TX
-    prune = prune_tx_types(cpi, bs, x, xd, ext_tx_set);
+    prune = prune_tx_types(cpi, bs, x, xd, tx_set_type);
 #else
     prune = prune_tx_types(cpi, bs, x, xd, 0);
 #endif  // CONFIG_EXT_TX
+  }
 #if CONFIG_EXT_TX
   if (get_ext_tx_types(mbmi->tx_size, bs, is_inter, cm->reduced_tx_set_used) >
           1 &&
@@ -2420,7 +2684,9 @@
             tx_type != get_default_tx_type(0, xd, 0, mbmi->tx_size))
           continue;
         if (cpi->sf.tx_type_search.prune_mode > NO_PRUNE) {
-          if (!do_tx_type_search(tx_type, prune)) continue;
+          if (!do_tx_type_search(tx_type, prune,
+                                 cpi->sf.tx_type_search.prune_mode))
+            continue;
         }
       } else {
         if (x->use_default_intra_tx_type &&
@@ -2512,7 +2778,8 @@
       av1_tx_type_cost(cm, x, xd, bs, plane, mbmi->tx_size, tx_type);
       if (is_inter) {
         if (cpi->sf.tx_type_search.prune_mode > NO_PRUNE &&
-            !do_tx_type_search(tx_type, prune))
+            !do_tx_type_search(tx_type, prune,
+                               cpi->sf.tx_type_search.prune_mode))
           continue;
       }
       if (this_rd_stats.skip)
@@ -2733,6 +3000,16 @@
     end_tx = chosen_tx_size;
   }
 
+  int prune = 0;
+  if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE &&
+      !x->use_default_inter_tx_type) {
+#if CONFIG_EXT_TX
+    prune = prune_tx_types(cpi, bs, x, xd, EXT_TX_SET_ALL16);
+#else
+    prune = prune_tx_types(cpi, bs, x, xd, 0);
+#endif  // CONFIG_EXT_TX
+  }
+
   last_rd = INT64_MAX;
   for (n = start_tx; n >= end_tx; --n) {
 #if CONFIG_EXT_TX && CONFIG_RECT_TX
@@ -2748,7 +3025,7 @@
     TX_TYPE tx_type;
     for (tx_type = tx_start; tx_type < tx_end; ++tx_type) {
       RD_STATS this_rd_stats;
-      if (skip_txfm_search(cpi, x, bs, tx_type, n)) continue;
+      if (skip_txfm_search(cpi, x, bs, tx_type, n, prune)) continue;
       rd = txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs, tx_type, n);
 #if CONFIG_PVQ
       od_encode_rollback(&x->daala_enc, &buf);
@@ -2785,8 +3062,8 @@
     }
 #if CONFIG_LGT_FROM_PRED
     mbmi->use_lgt = 1;
-    if (is_lgt_allowed(mbmi->mode, n) && !skip_txfm_search(cpi, x, bs, 0, n) &&
-        !breakout) {
+    if (is_lgt_allowed(mbmi->mode, n) &&
+        !skip_txfm_search(cpi, x, bs, 0, n, prune) && !breakout) {
       RD_STATS this_rd_stats;
       rd = txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs, 0, n);
       if (rd < best_rd) {
@@ -5310,8 +5587,6 @@
 #if CONFIG_EXT_TX
   const TxSetType tx_set_type = get_ext_tx_set_type(
       max_tx_size, bsize, is_inter, cm->reduced_tx_set_used);
-  const int ext_tx_set =
-      get_ext_tx_set(max_tx_size, bsize, is_inter, cm->reduced_tx_set_used);
 #endif  // CONFIG_EXT_TX
 
   av1_invalid_rd_stats(rd_stats);
@@ -5353,12 +5628,14 @@
     }
   }
 
-  if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE)
+  if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE &&
+      !x->use_default_inter_tx_type && !xd->lossless[mbmi->segment_id]) {
 #if CONFIG_EXT_TX
-    prune = prune_tx_types(cpi, bsize, x, xd, ext_tx_set);
+    prune = prune_tx_types(cpi, bsize, x, xd, tx_set_type);
 #else
     prune = prune_tx_types(cpi, bsize, x, xd, 0);
 #endif  // CONFIG_EXT_TX
+  }
 
   int found = 0;
 
@@ -5377,7 +5654,9 @@
     if (!av1_ext_tx_used[tx_set_type][tx_type]) continue;
     if (is_inter) {
       if (cpi->sf.tx_type_search.prune_mode > NO_PRUNE) {
-        if (!do_tx_type_search(tx_type, prune)) continue;
+        if (!do_tx_type_search(tx_type, prune,
+                               cpi->sf.tx_type_search.prune_mode))
+          continue;
       }
     } else {
       if (!ALLOW_INTRA_EXT_TX && bsize >= BLOCK_8X8) {
@@ -5386,7 +5665,7 @@
     }
 #else   // CONFIG_EXT_TX
     if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE &&
-        !do_tx_type_search(tx_type, prune))
+        !do_tx_type_search(tx_type, prune, cpi->sf.tx_type_search.prune_mode))
       continue;
 #endif  // CONFIG_EXT_TX
     if (is_inter && x->use_default_inter_tx_type &&