Enable tx size search optimization for winner mode

For speed >= 3, tx mode is set conservatively during mode evaluation
and FULL_RD tx mode is set for final mode winner.
This change is not applicable for key-frames.

            Encode Time
cpu-used     Reduction      Quality Loss
   3           5.04%           +0.02%
   4           3.51%           +0.07%

STATS_CHANGED

Change-Id: Id69ade04e82122b6fdd9d10e1127d294205c62db
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 199074c..83c4960 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -443,6 +443,11 @@
   // Strong color activity detection. Used in REALTIME coding mode to enhance
   // the visual quality at the boundary of moving color objects.
   uint8_t color_sensitivity[2];
+
+  // Used to control the tx size search evaluation for mode processing
+  // (normal/winner mode)
+  int tx_size_search_method;
+  TX_MODE tx_mode;
 };
 
 static INLINE int is_rect_tx_allowed_bsize(BLOCK_SIZE bsize) {
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 5f93c21..93d8de7 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -762,6 +762,7 @@
   // coefficients for mode decision
   x->coeff_opt_dist_threshold =
       get_rd_opt_coeff_thresh(cpi->coeff_opt_dist_threshold, 0, 0);
+  set_tx_size_search_method(cpi, x, 0, 1);
 
   // Save rdmult before it might be changed, so it can be restored later.
   const int orig_rdmult = x->rdmult;
@@ -4200,17 +4201,6 @@
                          cm->seq_params.subsampling_y, num_planes);
 }
 
-static TX_MODE select_tx_mode(const AV1_COMP *cpi) {
-  if (cpi->common.coded_lossless) return ONLY_4X4;
-  if (cpi->sf.tx_size_search_method == USE_LARGESTALL)
-    return TX_MODE_LARGEST;
-  else if (cpi->sf.tx_size_search_method == USE_FULL_RD ||
-           cpi->sf.tx_size_search_method == USE_FAST_RD)
-    return TX_MODE_SELECT;
-  else
-    return cpi->common.tx_mode;
-}
-
 void av1_alloc_tile_data(AV1_COMP *cpi) {
   AV1_COMMON *const cm = &cpi->common;
   const int tile_cols = cm->tile_cols;
@@ -4695,7 +4685,7 @@
   cm->coded_lossless = is_coded_lossless(cm, xd);
   cm->all_lossless = cm->coded_lossless && !av1_superres_scaled(cm);
 
-  cm->tx_mode = select_tx_mode(cpi);
+  cm->tx_mode = select_tx_mode(cpi, cpi->sf.tx_size_search_method);
 
   // Fix delta q resolution for the moment
   cm->delta_q_info.delta_q_res = 0;
diff --git a/av1/encoder/nonrd_pickmode.c b/av1/encoder/nonrd_pickmode.c
index 2c72345..e3542da 100644
--- a/av1/encoder/nonrd_pickmode.c
+++ b/av1/encoder/nonrd_pickmode.c
@@ -499,13 +499,14 @@
 }
 
 static TX_SIZE calculate_tx_size(const AV1_COMP *const cpi, BLOCK_SIZE bsize,
-                                 MACROBLOCKD *const xd, unsigned int var,
+                                 MACROBLOCK *const x, unsigned int var,
                                  unsigned int sse) {
+  MACROBLOCKD *const xd = &x->e_mbd;
   TX_SIZE tx_size;
-  if (cpi->common.tx_mode == TX_MODE_SELECT) {
+  if (x->tx_mode == TX_MODE_SELECT) {
     if (sse > (var << 2))
       tx_size = AOMMIN(max_txsize_lookup[bsize],
-                       tx_mode_to_biggest_tx_size[cpi->common.tx_mode]);
+                       tx_mode_to_biggest_tx_size[x->tx_mode]);
     else
       tx_size = TX_8X8;
 
@@ -516,7 +517,7 @@
       tx_size = TX_16X16;
   } else {
     tx_size = AOMMIN(max_txsize_lookup[bsize],
-                     tx_mode_to_biggest_tx_size[cpi->common.tx_mode]);
+                     tx_mode_to_biggest_tx_size[x->tx_mode]);
   }
   if (bsize > BLOCK_32X32) tx_size = TX_16X16;
   return AOMMIN(tx_size, TX_16X16);
@@ -623,7 +624,7 @@
   ac_thr *= ac_thr_factor(cpi->oxcf.speed, cpi->common.width,
                           cpi->common.height, abs(sum) >> (bw + bh));
 
-  tx_size = calculate_tx_size(cpi, bsize, xd, var, sse);
+  tx_size = calculate_tx_size(cpi, bsize, x, var, sse);
   // The code below for setting skip flag assumes tranform size of at least 8x8,
   // so force this lower limit on transform.
   if (tx_size < TX_8X8) tx_size = TX_8X8;
@@ -703,7 +704,7 @@
 
   unsigned int var = cpi->fn_ptr[bsize].vf(p->src.buf, p->src.stride,
                                            pd->dst.buf, pd->dst.stride, &sse);
-  xd->mi[0]->tx_size = calculate_tx_size(cpi, bsize, xd, var, sse);
+  xd->mi[0]->tx_size = calculate_tx_size(cpi, bsize, x, var, sse);
 
   if (cpi->sf.use_modeled_non_rd_cost) {
     const int bwide = block_size_wide[bsize];
@@ -1322,7 +1323,7 @@
     init_mbmi(mi, this_mode, ref_frame, NONE_FRAME, cm);
 
     mi->tx_size = AOMMIN(AOMMIN(max_txsize_lookup[bsize],
-                                tx_mode_to_biggest_tx_size[cm->tx_mode]),
+                                tx_mode_to_biggest_tx_size[x->tx_mode]),
                          TX_16X16);
     memset(mi->inter_tx_size, mi->tx_size, sizeof(mi->inter_tx_size));
     memset(mi->txk_type, DCT_DCT, sizeof(mi->txk_type[0]) * TXK_TYPE_BUF_LEN);
@@ -1624,7 +1625,7 @@
     PRED_BUFFER *const best_pred = best_pickmode.best_pred;
     TX_SIZE intra_tx_size =
         AOMMIN(AOMMIN(max_txsize_lookup[bsize],
-                      tx_mode_to_biggest_tx_size[cpi->common.tx_mode]),
+                      tx_mode_to_biggest_tx_size[x->tx_mode]),
                TX_16X16);
 
     if (reuse_inter_pred && best_pred != NULL) {
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 3476a37..ff1a3c8 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -3500,10 +3500,10 @@
   }
 }
 
-static int tx_size_cost(const AV1_COMMON *const cm, const MACROBLOCK *const x,
-                        BLOCK_SIZE bsize, TX_SIZE tx_size) {
+static int tx_size_cost(const MACROBLOCK *const x, BLOCK_SIZE bsize,
+                        TX_SIZE tx_size) {
   assert(bsize == x->e_mbd.mi[0]->sb_type);
-  if (cm->tx_mode != TX_MODE_SELECT || !block_signals_txsize(bsize)) return 0;
+  if (x->tx_mode != TX_MODE_SELECT || !block_signals_txsize(bsize)) return 0;
 
   const int32_t tx_size_cat = bsize_to_tx_size_cat(bsize);
   const int depth = tx_size_to_depth(tx_size, bsize);
@@ -3516,7 +3516,6 @@
                         RD_STATS *rd_stats, int64_t ref_best_rd, BLOCK_SIZE bs,
                         TX_SIZE tx_size, FAST_TX_SEARCH_MODE ftxs_mode,
                         int skip_trellis) {
-  const AV1_COMMON *const cm = &cpi->common;
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = xd->mi[0];
   int64_t rd = INT64_MAX;
@@ -3524,11 +3523,11 @@
   int s0, s1;
   const int is_inter = is_inter_block(mbmi);
   const int tx_select =
-      cm->tx_mode == TX_MODE_SELECT && block_signals_txsize(mbmi->sb_type);
+      x->tx_mode == TX_MODE_SELECT && block_signals_txsize(mbmi->sb_type);
   int ctx = txfm_partition_context(
       xd->above_txfm_context, xd->left_txfm_context, mbmi->sb_type, tx_size);
-  const int r_tx_size = is_inter ? x->txfm_partition_cost[ctx][0]
-                                 : tx_size_cost(cm, x, bs, tx_size);
+  const int r_tx_size =
+      is_inter ? x->txfm_partition_cost[ctx][0] : tx_size_cost(x, bs, tx_size);
 
   assert(IMPLIES(is_rect_tx(tx_size), is_rect_tx_allowed_bsize(bs)));
 
@@ -3610,10 +3609,9 @@
 static void choose_largest_tx_size(const AV1_COMP *const cpi, MACROBLOCK *x,
                                    RD_STATS *rd_stats, int64_t ref_best_rd,
                                    BLOCK_SIZE bs) {
-  const AV1_COMMON *const cm = &cpi->common;
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = xd->mi[0];
-  mbmi->tx_size = tx_size_from_tx_mode(bs, cm->tx_mode);
+  mbmi->tx_size = tx_size_from_tx_mode(bs, x->tx_mode);
   const int skip_ctx = av1_get_skip_context(xd);
   int s0, s1;
 
@@ -3646,8 +3644,9 @@
 }
 
 static int get_search_init_depth(int mi_width, int mi_height, int is_inter,
-                                 const SPEED_FEATURES *sf) {
-  if (sf->tx_size_search_method == USE_LARGESTALL) return MAX_VARTX_DEPTH;
+                                 const SPEED_FEATURES *sf,
+                                 int tx_size_search_method) {
+  if (tx_size_search_method == USE_LARGESTALL) return MAX_VARTX_DEPTH;
 
   if (sf->tx_size_search_lgr_block) {
     if (mi_width > mi_size_wide[BLOCK_64X64] ||
@@ -3669,20 +3668,20 @@
                                         int64_t ref_best_rd, BLOCK_SIZE bs) {
   av1_invalid_rd_stats(rd_stats);
 
-  const AV1_COMMON *const cm = &cpi->common;
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = xd->mi[0];
   const TX_SIZE max_rect_tx_size = max_txsize_rect_lookup[bs];
-  const int tx_select = cm->tx_mode == TX_MODE_SELECT;
+  const int tx_select = x->tx_mode == TX_MODE_SELECT;
   int start_tx;
   int depth, init_depth;
 
   if (tx_select) {
     start_tx = max_rect_tx_size;
     init_depth = get_search_init_depth(mi_size_wide[bs], mi_size_high[bs],
-                                       is_inter_block(mbmi), &cpi->sf);
+                                       is_inter_block(mbmi), &cpi->sf,
+                                       x->tx_size_search_method);
   } else {
-    const TX_SIZE chosen_tx_size = tx_size_from_tx_mode(bs, cm->tx_mode);
+    const TX_SIZE chosen_tx_size = tx_size_from_tx_mode(bs, x->tx_mode);
     start_tx = chosen_tx_size;
     init_depth = MAX_TX_DEPTH;
   }
@@ -3968,7 +3967,7 @@
 
   if (xd->lossless[xd->mi[0]->segment_id]) {
     choose_smallest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
-  } else if (cpi->sf.tx_size_search_method == USE_LARGESTALL) {
+  } else if (x->tx_size_search_method == USE_LARGESTALL) {
     choose_largest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
   } else {
     choose_tx_size_type_from_rd(cpi, x, rd_stats, ref_best_rd, bs);
@@ -4108,7 +4107,7 @@
   RD_STATS this_rd_stats;
   int row, col;
   int64_t temp_sse, this_rd;
-  TX_SIZE tx_size = tx_size_from_tx_mode(bsize, cm->tx_mode);
+  TX_SIZE tx_size = tx_size_from_tx_mode(bsize, x->tx_mode);
   const int stepr = tx_size_high_unit[tx_size];
   const int stepc = tx_size_wide_unit[tx_size];
   const int max_blocks_wide = max_block_wide(xd, bsize, 0);
@@ -4236,8 +4235,7 @@
   int this_rate = tokenonly_rd_stats.rate + palette_mode_cost;
   int64_t this_rd = RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
   if (!xd->lossless[mbmi->segment_id] && block_signals_txsize(mbmi->sb_type)) {
-    tokenonly_rd_stats.rate -=
-        tx_size_cost(&cpi->common, x, bsize, mbmi->tx_size);
+    tokenonly_rd_stats.rate -= tx_size_cost(x, bsize, mbmi->tx_size);
   }
   if (this_rd < *best_rd) {
     *best_rd = this_rd;
@@ -4689,7 +4687,7 @@
     // tokenonly rate, but for intra blocks, tx_size is always coded
     // (prediction granularity), so we account for it in the full rate,
     // not the tokenonly rate.
-    this_rate_tokenonly -= tx_size_cost(&cpi->common, x, bsize, mbmi->tx_size);
+    this_rate_tokenonly -= tx_size_cost(x, bsize, mbmi->tx_size);
   }
   const int this_rate =
       rd_stats.rate +
@@ -4757,6 +4755,9 @@
   x->coeff_opt_dist_threshold =
       get_rd_opt_coeff_thresh(cpi->coeff_opt_dist_threshold,
                               cpi->sf.enable_winner_mode_for_coeff_opt, 0);
+  // Set the transform size search method for mode evaluation
+  set_tx_size_search_method(cpi, x, cpi->sf.enable_winner_mode_for_tx_size_srch,
+                            0);
 
   MB_MODE_INFO best_mbmi = *mbmi;
   /* Y Search for intra prediction mode */
@@ -4800,8 +4801,7 @@
       // tokenonly rate, but for intra blocks, tx_size is always coded
       // (prediction granularity), so we account for it in the full rate,
       // not the tokenonly rate.
-      this_rate_tokenonly -=
-          tx_size_cost(&cpi->common, x, bsize, mbmi->tx_size);
+      this_rate_tokenonly -= tx_size_cost(x, bsize, mbmi->tx_size);
     }
     this_rate =
         this_rd_stats.rate +
@@ -4838,15 +4838,20 @@
   // If previous searches use only the default tx type/no R-D optimization of
   // quantized coeffs, do an extra search for the best tx type/better R-D
   // optimization of quantized coeffs
+  // TODO(any) : Refactor the winner mode evaluation check control code
   if ((cpi->sf.tx_type_search.fast_intra_tx_type_search &&
        !cpi->oxcf.use_intra_default_tx_only) ||
       (cpi->sf.enable_winner_mode_for_coeff_opt &&
        (cpi->optimize_seg_arr[mbmi->segment_id] != NO_TRELLIS_OPT &&
-        cpi->optimize_seg_arr[mbmi->segment_id] != FINAL_PASS_TRELLIS_OPT))) {
+        cpi->optimize_seg_arr[mbmi->segment_id] != FINAL_PASS_TRELLIS_OPT)) ||
+      cpi->sf.enable_winner_mode_for_tx_size_srch) {
     // Get the threshold for R-D optimization of coefficients for winner mode
     x->coeff_opt_dist_threshold =
         get_rd_opt_coeff_thresh(cpi->coeff_opt_dist_threshold,
                                 cpi->sf.enable_winner_mode_for_coeff_opt, 1);
+    // Set the transform size search method for winner mode processing
+    set_tx_size_search_method(cpi, x,
+                              cpi->sf.enable_winner_mode_for_tx_size_srch, 1);
     *mbmi = best_mbmi;
     x->use_default_intra_tx_type = 0;
     intra_block_yrd(cpi, x, bsize, bmode_costs, &best_rd, rate, rate_tokenonly,
@@ -5310,7 +5315,7 @@
   // will use more complex search given that the transform partitions have
   // already been decided.
 
-  const int fast_tx_search = cpi->sf.tx_size_search_method > USE_FULL_RD;
+  const int fast_tx_search = x->tx_size_search_method > USE_FULL_RD;
   int64_t rd_thresh = ref_best_rd;
   if (fast_tx_search && rd_thresh < INT64_MAX) {
     if (INT64_MAX - rd_thresh > (rd_thresh >> 3)) rd_thresh += (rd_thresh >> 3);
@@ -5336,8 +5341,8 @@
   const int skip_ctx = av1_get_skip_context(xd);
   const int s0 = x->skip_cost[skip_ctx][0];
   const int s1 = x->skip_cost[skip_ctx][1];
-  const int init_depth =
-      get_search_init_depth(mi_width, mi_height, 1, &cpi->sf);
+  const int init_depth = get_search_init_depth(mi_width, mi_height, 1, &cpi->sf,
+                                               x->tx_size_search_method);
   const TX_SIZE max_tx_size = max_txsize_rect_lookup[plane_bsize];
   const int bh = tx_size_high_unit[max_tx_size];
   const int bw = tx_size_wide_unit[max_tx_size];
@@ -5512,8 +5517,8 @@
     const TX_SIZE max_tx_size = get_vartx_max_txsize(xd, plane_bsize, 0);
     const int bh = tx_size_high_unit[max_tx_size];
     const int bw = tx_size_wide_unit[max_tx_size];
-    const int init_depth =
-        get_search_init_depth(mi_width, mi_height, 1, &cpi->sf);
+    const int init_depth = get_search_init_depth(
+        mi_width, mi_height, 1, &cpi->sf, x->tx_size_search_method);
     int idx, idy;
     int block = 0;
     int step = tx_size_wide_unit[max_tx_size] * tx_size_high_unit[max_tx_size];
@@ -8933,7 +8938,7 @@
 
   // cost and distortion
   av1_subtract_plane(x, bsize, 0);
-  if (cm->tx_mode == TX_MODE_SELECT && !xd->lossless[mbmi->segment_id]) {
+  if (x->tx_mode == TX_MODE_SELECT && !xd->lossless[mbmi->segment_id]) {
     pick_tx_size_type_yrd(cpi, x, rd_stats_y, bsize, mi_row, mi_col, rd_thresh);
 #if CONFIG_COLLECT_RD_STATS == 2
     PrintPredictionUnitStats(cpi, tile_data, x, rd_stats_y, bsize);
@@ -11002,6 +11007,8 @@
   // decision
   x->coeff_opt_dist_threshold =
       get_rd_opt_coeff_thresh(cpi->coeff_opt_dist_threshold, 0, 0);
+  // Set the transform size search method for mode evaluation
+  set_tx_size_search_method(cpi, x, 0, 0);
 
   if (intra_yrd < best_rd) {
     // Only store reconstructed luma when there's chroma RDO. When there's no
@@ -11219,7 +11226,7 @@
 
     // Set up tx_size related variables for skip-specific loop filtering.
     search_state->best_mbmode.tx_size =
-        block_signals_txsize(bsize) ? tx_size_from_tx_mode(bsize, cm->tx_mode)
+        block_signals_txsize(bsize) ? tx_size_from_tx_mode(bsize, x->tx_mode)
                                     : max_txsize_rect_lookup[bsize];
     memset(search_state->best_mbmode.inter_tx_size,
            search_state->best_mbmode.tx_size,
@@ -11276,6 +11283,7 @@
   MB_MODE_INFO *const mbmi = xd->mi[0];
   const int num_planes = av1_num_planes(cm);
 
+  // TODO(any) : Refactor the winner mode evaluation check control code
   if (xd->lossless[mbmi->segment_id] == 0 && best_mode_index >= 0 &&
       ((sf->tx_type_search.fast_inter_tx_type_search &&
         !cpi->oxcf.use_inter_dct_only && is_inter_mode(best_mbmode->mode)) ||
@@ -11284,7 +11292,8 @@
         !is_inter_mode(best_mbmode->mode)) ||
        (cpi->sf.enable_winner_mode_for_coeff_opt &&
         (cpi->optimize_seg_arr[mbmi->segment_id] != NO_TRELLIS_OPT &&
-         cpi->optimize_seg_arr[mbmi->segment_id] != FINAL_PASS_TRELLIS_OPT)))) {
+         cpi->optimize_seg_arr[mbmi->segment_id] != FINAL_PASS_TRELLIS_OPT)) ||
+       cpi->sf.enable_winner_mode_for_tx_size_srch)) {
     int skip_blk = 0;
     RD_STATS rd_stats_y, rd_stats_uv;
     const int skip_ctx = av1_get_skip_context(xd);
@@ -11296,6 +11305,9 @@
     x->coeff_opt_dist_threshold =
         get_rd_opt_coeff_thresh(cpi->coeff_opt_dist_threshold,
                                 cpi->sf.enable_winner_mode_for_coeff_opt, 1);
+    // Set the transform size search method for winner mode processing
+    set_tx_size_search_method(cpi, x,
+                              cpi->sf.enable_winner_mode_for_tx_size_srch, 1);
 
     *mbmi = *best_mbmode;
 
@@ -11315,7 +11327,7 @@
         av1_build_obmc_inter_predictors_sb(cm, xd, mi_row, mi_col);
 
       av1_subtract_plane(x, bsize, 0);
-      if (cm->tx_mode == TX_MODE_SELECT && !xd->lossless[mbmi->segment_id]) {
+      if (x->tx_mode == TX_MODE_SELECT && !xd->lossless[mbmi->segment_id]) {
         pick_tx_size_type_yrd(cpi, x, &rd_stats_y, bsize, mi_row, mi_col,
                               INT64_MAX);
         assert(rd_stats_y.rate != INT_MAX);
@@ -11700,6 +11712,9 @@
   x->coeff_opt_dist_threshold =
       get_rd_opt_coeff_thresh(cpi->coeff_opt_dist_threshold,
                               cpi->sf.enable_winner_mode_for_coeff_opt, 0);
+  // Set the transform size search method for mode evaluation
+  set_tx_size_search_method(cpi, x, cpi->sf.enable_winner_mode_for_tx_size_srch,
+                            0);
 
   if (cpi->sf.skip_repeat_interpolation_filter_search) {
     x->interp_filter_stats_idx[0] = 0;
@@ -12310,7 +12325,7 @@
     // tokenonly rate, but for intra blocks, tx_size is always coded
     // (prediction granularity), so we account for it in the full rate,
     // not the tokenonly rate.
-    rd_stats_y->rate -= tx_size_cost(cm, x, bsize, mbmi->tx_size);
+    rd_stats_y->rate -= tx_size_cost(x, bsize, mbmi->tx_size);
   }
   if (num_planes > 1 && !x->skip_chroma_rd) {
     const int uv_mode_cost =
@@ -13158,6 +13173,8 @@
   // Get the threshold for R-D optimization of coefficients for mode evaluation
   x->coeff_opt_dist_threshold =
       get_rd_opt_coeff_thresh(cpi->coeff_opt_dist_threshold, 0, 0);
+  // Set the transform size search method for winner mode processing
+  set_tx_size_search_method(cpi, x, 0, 0);
 
   // Only try palette mode when the best mode so far is an intra mode.
   const int try_palette =
diff --git a/av1/encoder/rdopt.h b/av1/encoder/rdopt.h
index 9fb7f5d..9441caf 100644
--- a/av1/encoder/rdopt.h
+++ b/av1/encoder/rdopt.h
@@ -217,6 +217,29 @@
          USABLE_REF_MV_STACK_SIZE * sizeof(xd->ref_mv_stack[0][0]));
 }
 
+static TX_MODE select_tx_mode(
+    const AV1_COMP *cpi, const TX_SIZE_SEARCH_METHOD tx_size_search_method) {
+  if (cpi->common.coded_lossless) return ONLY_4X4;
+  if (tx_size_search_method == USE_LARGESTALL)
+    return TX_MODE_LARGEST;
+  else if (tx_size_search_method == USE_FULL_RD ||
+           tx_size_search_method == USE_FAST_RD)
+    return TX_MODE_SELECT;
+  else
+    return cpi->common.tx_mode;
+}
+
+static INLINE void set_tx_size_search_method(
+    const struct AV1_COMP *cpi, MACROBLOCK *x,
+    int enable_winner_mode_for_tx_size_srch, int is_winner_mode) {
+  // Populate transform size search method/transform mode appropriately
+  if (enable_winner_mode_for_tx_size_srch && !is_winner_mode) {
+    x->tx_size_search_method = USE_LARGESTALL;
+  } else {
+    x->tx_size_search_method = cpi->sf.tx_size_search_method;
+  }
+  x->tx_mode = select_tx_mode(cpi, x->tx_size_search_method);
+}
 #ifdef __cplusplus
 }  // extern "C"
 #endif
diff --git a/av1/encoder/speed_features.c b/av1/encoder/speed_features.c
index f0c3ce4..d60b2cb 100644
--- a/av1/encoder/speed_features.c
+++ b/av1/encoder/speed_features.c
@@ -342,7 +342,6 @@
 
   if (speed >= 3) {
     sf->reduce_inter_modes = boosted ? 1 : 2;
-    sf->tx_size_search_method = boosted ? USE_FULL_RD : USE_LARGESTALL;
     sf->less_rectangular_check_level = 2;
     // adaptive_motion_search breaks encoder multi-thread tests.
     // The values in x->pred_mv[] differ for single and multi-thread cases.
@@ -368,7 +367,12 @@
         !frame_is_intra_only(&cpi->common) || (cpi->rc.frames_to_key != 1);
     // TODO(any): Experiment on the dependency of this speed feature with
     // use_intra_txb_hash, use_inter_txb_hash and use_mb_rd_hash speed features
+    // TODO(any): Refactor the code related to following winner mode speed
+    // features
     sf->enable_winner_mode_for_coeff_opt = 1;
+    // TODO(any): Experiment with this speed feature by enabling for key frames
+    sf->enable_winner_mode_for_tx_size_srch =
+        frame_is_intra_only(&cpi->common) ? 0 : 1;
     sf->reduce_wiener_window_size = is_boosted_arf2_bwd_type ? 0 : 1;
     sf->mv.subpel_search_method = SUBPEL_TREE_PRUNED;
   }
@@ -859,6 +863,7 @@
   sf->disable_wedge_interintra_search = 0;
   sf->perform_coeff_opt = 0;
   sf->enable_winner_mode_for_coeff_opt = 0;
+  sf->enable_winner_mode_for_tx_size_srch = 0;
   sf->prune_comp_type_by_model_rd = 0;
   sf->disable_smooth_intra = 0;
   sf->perform_best_rd_based_gating_for_chroma = 0;
diff --git a/av1/encoder/speed_features.h b/av1/encoder/speed_features.h
index 0cf9d3d..f4ea4f6 100644
--- a/av1/encoder/speed_features.h
+++ b/av1/encoder/speed_features.h
@@ -673,6 +673,10 @@
   // of quantized coeffs
   int enable_winner_mode_for_coeff_opt;
 
+  // Flag used to control the winner mode processing for transform size
+  // search method
+  int enable_winner_mode_for_tx_size_srch;
+
   // Flag used to control the speed of the eob selection in trellis.
   int trellis_eob_fast;