diff --git a/av1/encoder/tx_search.c b/av1/encoder/tx_search.c
index 10afca9..e1840bd 100644
--- a/av1/encoder/tx_search.c
+++ b/av1/encoder/tx_search.c
@@ -2390,28 +2390,29 @@
   pd->dqcoeff = orig_dqcoeff;
 }
 
-// Pick transform type for a transform block of tx_size.
+// Pick transform type for a luma transform block of tx_size. Note this function
+// is used only for inter-predicted blocks.
 static AOM_INLINE void tx_type_rd(const AV1_COMP *cpi, MACROBLOCK *x,
                                   TX_SIZE tx_size, int blk_row, int blk_col,
-                                  int plane, int block, int plane_bsize,
-                                  TXB_CTX *txb_ctx, RD_STATS *rd_stats,
+                                  int block, int plane_bsize, TXB_CTX *txb_ctx,
+                                  RD_STATS *rd_stats,
                                   FAST_TX_SEARCH_MODE ftxs_mode,
                                   int64_t ref_rdcost,
                                   TXB_RD_INFO *rd_info_array) {
-  const struct macroblock_plane *const p = &x->plane[plane];
+  const struct macroblock_plane *const p = &x->plane[0];
   const uint16_t cur_joint_ctx =
       (txb_ctx->dc_sign_ctx << 8) + txb_ctx->txb_skip_ctx;
   MACROBLOCKD *xd = &x->e_mbd;
-  const int tx_type_map_idx =
-      plane ? 0 : blk_row * xd->tx_type_map_stride + blk_col;
+  assert(is_inter_block(xd->mi[0]));
+  const int tx_type_map_idx = blk_row * xd->tx_type_map_stride + blk_col;
   // Look up RD and terminate early in case when we've already processed exactly
-  // the same residual with exactly the same entropy context.
+  // the same residue with exactly the same entropy context.
   if (rd_info_array != NULL && rd_info_array->valid &&
       rd_info_array->entropy_context == cur_joint_ctx) {
-    if (plane == 0) xd->tx_type_map[tx_type_map_idx] = rd_info_array->tx_type;
+    xd->tx_type_map[tx_type_map_idx] = rd_info_array->tx_type;
     const TX_TYPE ref_tx_type =
-        av1_get_tx_type(&x->e_mbd, get_plane_type(plane), blk_row, blk_col,
-                        tx_size, cpi->common.features.reduced_tx_set_used);
+        av1_get_tx_type(&x->e_mbd, get_plane_type(0), blk_row, blk_col, tx_size,
+                        cpi->common.features.reduced_tx_set_used);
     if (ref_tx_type == rd_info_array->tx_type) {
       rd_stats->rate += rd_info_array->rate;
       rd_stats->dist += rd_info_array->dist;
@@ -2424,9 +2425,8 @@
   }
 
   RD_STATS this_rd_stats;
-
   const int skip_trellis = 0;
-  search_txk_type(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
+  search_txk_type(cpi, x, 0, block, blk_row, blk_col, plane_bsize, tx_size,
                   txb_ctx, ftxs_mode, 0, skip_trellis, ref_rdcost,
                   &this_rd_stats);
 
@@ -2441,7 +2441,7 @@
     rd_info_array->sse = this_rd_stats.sse;
     rd_info_array->eob = p->eobs[block];
     rd_info_array->txb_entropy_ctx = p->txb_entropy_ctx[block];
-    if (plane == 0) rd_info_array->tx_type = xd->tx_type_map[tx_type_map_idx];
+    rd_info_array->tx_type = xd->tx_type_map[tx_type_map_idx];
   }
 }
 
@@ -2466,7 +2466,7 @@
   rd_stats->zero_rate = zero_blk_rate;
   const int index = av1_get_txb_size_index(plane_bsize, blk_row, blk_col);
   mbmi->inter_tx_size[index] = tx_size;
-  tx_type_rd(cpi, x, tx_size, blk_row, blk_col, 0, block, plane_bsize, &txb_ctx,
+  tx_type_rd(cpi, x, tx_size, blk_row, blk_col, block, plane_bsize, &txb_ctx,
              rd_stats, ftxs_mode, ref_best_rd,
              rd_info_node != NULL ? rd_info_node->rd_info_array : NULL);
   assert(rd_stats->rate < INT_MAX);
@@ -2910,27 +2910,28 @@
   return rd;
 }
 
-// Finds rd cost for a y block, given the transform size partitions
+// Search for the best transform type for a luma inter-predicted block, given
+// the transform block partitions.
+// This function is used only when some speed features are enabled.
 static AOM_INLINE void tx_block_yrd(
     const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
     TX_SIZE tx_size, BLOCK_SIZE plane_bsize, int depth,
     ENTROPY_CONTEXT *above_ctx, ENTROPY_CONTEXT *left_ctx,
     TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left, int64_t ref_best_rd,
     RD_STATS *rd_stats, FAST_TX_SEARCH_MODE ftxs_mode) {
+  assert(tx_size < TX_SIZES_ALL);
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = xd->mi[0];
+  assert(is_inter_block(mbmi));
   const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
   const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);
 
-  assert(tx_size < TX_SIZES_ALL);
-
   if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;
 
   const TX_SIZE plane_tx_size = mbmi->inter_tx_size[av1_get_txb_size_index(
       plane_bsize, blk_row, blk_col)];
-
-  int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
-                                   mbmi->sb_type, tx_size);
+  const int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
+                                         mbmi->sb_type, tx_size);
 
   av1_init_rd_stats(rd_stats);
   if (tx_size == plane_tx_size) {
@@ -2943,8 +2944,8 @@
     const int zero_blk_rate = x->coeff_costs[txs_ctx][get_plane_type(0)]
                                   .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
     rd_stats->zero_rate = zero_blk_rate;
-    tx_type_rd(cpi, x, tx_size, blk_row, blk_col, 0, block, plane_bsize,
-               &txb_ctx, rd_stats, ftxs_mode, ref_best_rd, NULL);
+    tx_type_rd(cpi, x, tx_size, blk_row, blk_col, block, plane_bsize, &txb_ctx,
+               rd_stats, ftxs_mode, ref_best_rd, NULL);
     const int mi_width = mi_size_wide[plane_bsize];
     if (RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >=
             RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse) ||
@@ -2967,18 +2968,17 @@
                           tx_size);
   } else {
     const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
-    const int bsw = tx_size_wide_unit[sub_txs];
-    const int bsh = tx_size_high_unit[sub_txs];
-    const int step = bsh * bsw;
+    const int txb_width = tx_size_wide_unit[sub_txs];
+    const int txb_height = tx_size_high_unit[sub_txs];
+    const int step = txb_height * txb_width;
     RD_STATS pn_rd_stats;
     int64_t this_rd = 0;
-    assert(bsw > 0 && bsh > 0);
+    assert(txb_width > 0 && txb_height > 0);
 
-    for (int row = 0; row < tx_size_high_unit[tx_size]; row += bsh) {
-      for (int col = 0; col < tx_size_wide_unit[tx_size]; col += bsw) {
+    for (int row = 0; row < tx_size_high_unit[tx_size]; row += txb_height) {
+      for (int col = 0; col < tx_size_wide_unit[tx_size]; col += txb_width) {
         const int offsetr = blk_row + row;
         const int offsetc = blk_col + col;
-
         if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
 
         av1_init_rd_stats(&pn_rd_stats);
