Enable rectangular transforms for intra blocks as well.

They are used only when both the EXT_TX and RECT_TX experiments are enabled.

Results
=======

Derf Set:
--------
All Intra frames: 1.8% avg improvement (and 1.78% BD-rate improvement)
Video: 0.230% avg improvement (and 0.262% BD-rate improvement)

Objective-1-fast set
--------------------
Video: 0.52 PSNR improvement

Change-Id: I1893465929858e38419f327752dc61c19b96b997
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 9418017..5369e2b 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -1076,8 +1076,11 @@
       const PLANE_TYPE plane_type = plane == 0 ? PLANE_TYPE_Y : PLANE_TYPE_UV;
 
       INV_TXFM_PARAM inv_txfm_param;
+      const int block_raster_idx =
+          av1_block_index_to_raster_order(tx_size, block);
 
-      inv_txfm_param.tx_type = get_tx_type(plane_type, xd, block, tx_size);
+      inv_txfm_param.tx_type =
+          get_tx_type(plane_type, xd, block_raster_idx, tx_size);
       inv_txfm_param.tx_size = tx_size;
       inv_txfm_param.eob = eob;
       inv_txfm_param.lossless = xd->lossless[mbmi->segment_id];
@@ -1360,6 +1363,29 @@
 }
 #endif  // CONFIG_SUPERTX
 
+// Returns the rate of coding tx_size for a block of size bsize, or 0 when the
+// transform size is not signalled (tx_mode is not TX_MODE_SELECT, or the
+// block is smaller than BLOCK_8X8).
+static int tx_size_cost(const AV1_COMP *const cpi, MACROBLOCK *x,
+                        BLOCK_SIZE bsize, TX_SIZE tx_size) {
+  MACROBLOCKD *const xd = &x->e_mbd;
+  MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
+  int tx_size_cat;
+  int depth;
+  int tx_size_ctx;
+
+  if (cpi->common.tx_mode != TX_MODE_SELECT || mbmi->sb_type < BLOCK_8X8) {
+    return 0;
+  }
+
+  tx_size_cat = is_inter_block(mbmi) ? inter_tx_size_cat_lookup[bsize]
+                                     : intra_tx_size_cat_lookup[bsize];
+  // The size actually coded is the square one given by txsize_sqr_up_map.
+  depth = tx_size_to_depth(txsize_sqr_up_map[tx_size]);
+  tx_size_ctx = get_tx_size_context(xd);
+  return cpi->tx_size_cost[tx_size_cat][tx_size_ctx][depth];
+}
+
 static int64_t txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
                         RD_STATS *rd_stats, int64_t ref_best_rd, BLOCK_SIZE bs,
                         TX_TYPE tx_type, int tx_size) {
@@ -1370,16 +1396,10 @@
   aom_prob skip_prob = av1_get_skip_prob(cm, xd);
   int s0, s1;
   const int is_inter = is_inter_block(mbmi);
-
-  const int tx_size_cat =
-      is_inter ? inter_tx_size_cat_lookup[bs] : intra_tx_size_cat_lookup[bs];
-  const TX_SIZE coded_tx_size = txsize_sqr_up_map[tx_size];
-  const int depth = tx_size_to_depth(coded_tx_size);
   const int tx_select =
       cm->tx_mode == TX_MODE_SELECT && mbmi->sb_type >= BLOCK_8X8;
-  const int tx_size_ctx = tx_select ? get_tx_size_context(xd) : 0;
-  const int r_tx_size =
-      tx_select ? cpi->tx_size_cost[tx_size_cat][tx_size_ctx][depth] : 0;
+
+  const int r_tx_size = tx_size_cost(cpi, x, bs, tx_size);
 
   assert(skip_prob > 0);
 #if CONFIG_EXT_TX && CONFIG_RECT_TX
@@ -1405,8 +1425,9 @@
                                     [mbmi->tx_type];
     } else {
       if (ext_tx_set > 0 && ALLOW_INTRA_EXT_TX)
-        rd_stats->rate += cpi->intra_tx_type_costs[ext_tx_set][mbmi->tx_size]
-                                                  [mbmi->mode][mbmi->tx_type];
+        rd_stats->rate +=
+            cpi->intra_tx_type_costs[ext_tx_set][txsize_sqr_map[mbmi->tx_size]]
+                                    [mbmi->mode][mbmi->tx_type];
     }
   }
 #else
@@ -1468,6 +1489,7 @@
 #endif  // CONFIG_RECT_TX
   int ext_tx_set;
 #endif  // CONFIG_EXT_TX
+  assert(bs >= BLOCK_8X8);
 
   if (tx_select) {
 #if CONFIG_EXT_TX && CONFIG_RECT_TX
@@ -1494,8 +1516,9 @@
   if (evaluate_rect_tx) {
     const TX_SIZE rect_tx_size = max_txsize_rect_lookup[bs];
     RD_STATS this_rd_stats;
-    ext_tx_set = get_ext_tx_set(rect_tx_size, bs, 1);
-    if (ext_tx_used_inter[ext_tx_set][tx_type]) {
+    ext_tx_set = get_ext_tx_set(rect_tx_size, bs, is_inter);
+    if ((is_inter && ext_tx_used_inter[ext_tx_set][tx_type]) ||
+        (!is_inter && ext_tx_used_intra[ext_tx_set][tx_type])) {
       rd = txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs, tx_type,
                     rect_tx_size);
       best_tx_size = rect_tx_size;
@@ -1651,13 +1674,15 @@
         if (is_inter) {
           if (ext_tx_set > 0)
             this_rd_stats.rate +=
-                cpi->inter_tx_type_costs[ext_tx_set][mbmi->tx_size]
+                cpi->inter_tx_type_costs[ext_tx_set]
+                                        [txsize_sqr_map[mbmi->tx_size]]
                                         [mbmi->tx_type];
         } else {
           if (ext_tx_set > 0 && ALLOW_INTRA_EXT_TX)
             this_rd_stats.rate +=
-                cpi->intra_tx_type_costs[ext_tx_set][mbmi->tx_size][mbmi->mode]
-                                        [mbmi->tx_type];
+                cpi->intra_tx_type_costs[ext_tx_set]
+                                        [txsize_sqr_map[mbmi->tx_size]]
+                                        [mbmi->mode][mbmi->tx_type];
         }
       }
 
@@ -1977,10 +2002,7 @@
       }
       this_rd = RDCOST(x->rdmult, x->rddiv, this_rate, tokenonly_rd_stats.dist);
       if (!xd->lossless[mbmi->segment_id] && mbmi->sb_type >= BLOCK_8X8) {
-        tokenonly_rd_stats.rate -=
-            cpi->tx_size_cost[max_txsize_lookup[bsize] - TX_8X8]
-                             [get_tx_size_context(xd)]
-                             [tx_size_to_depth(mbmi->tx_size)];
+        tokenonly_rd_stats.rate -= tx_size_cost(cpi, x, bsize, mbmi->tx_size);
       }
       if (this_rd < *best_rd) {
         *best_rd = this_rd;
@@ -2005,11 +2027,48 @@
 }
 #endif  // CONFIG_PALETTE
 
-static int64_t rd_pick_intra4x4block(
+// The 4x4 inverse transforms take a trailing lossless flag; the 4x8/8x4 ones
+// do not. These adapters ignore the flag so all share one pointer type.
+static void inv_txfm_add_4x8_wrapper(const tran_low_t *input, uint8_t *dest,
+                                     int stride, int eob, TX_TYPE tx_type,
+                                     int lossless) {
+  (void)lossless;
+  av1_inv_txfm_add_4x8(input, dest, stride, eob, tx_type);
+}
+
+static void inv_txfm_add_8x4_wrapper(const tran_low_t *input, uint8_t *dest,
+                                     int stride, int eob, TX_TYPE tx_type,
+                                     int lossless) {
+  (void)lossless;
+  av1_inv_txfm_add_8x4(input, dest, stride, eob, tx_type);
+}
+
+typedef void (*inv_txfm_func_ptr)(const tran_low_t *, uint8_t *, int, int,
+                                  TX_TYPE, int);
+#if CONFIG_AOM_HIGHBITDEPTH
+static void highbd_inv_txfm_add_4x8_wrapper(
+    const tran_low_t *input, uint8_t *dest, int stride, int eob, int bd,
+    TX_TYPE tx_type, int lossless) {
+  (void)lossless;
+  av1_highbd_inv_txfm_add_4x8(input, dest, stride, eob, bd, tx_type);
+}
+
+static void highbd_inv_txfm_add_8x4_wrapper(
+    const tran_low_t *input, uint8_t *dest, int stride, int eob, int bd,
+    TX_TYPE tx_type, int lossless) {
+  (void)lossless;
+  av1_highbd_inv_txfm_add_8x4(input, dest, stride, eob, bd, tx_type);
+}
+
+typedef void (*highbd_inv_txfm_func_ptr)(const tran_low_t *, uint8_t *, int,
+                                         int, int, TX_TYPE, int);
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+
+static int64_t rd_pick_intra_sub_8x8_y_subblock_mode(
     const AV1_COMP *const cpi, MACROBLOCK *x, int row, int col,
     PREDICTION_MODE *best_mode, const int *bmode_costs, ENTROPY_CONTEXT *a,
     ENTROPY_CONTEXT *l, int *bestrate, int *bestratey, int64_t *bestdistortion,
-    BLOCK_SIZE bsize, int *y_skip, int64_t rd_thresh) {
+    BLOCK_SIZE bsize, TX_SIZE tx_size, int *y_skip, int64_t rd_thresh) {
   const AV1_COMMON *const cm = &cpi->common;
   PREDICTION_MODE mode;
   MACROBLOCKD *const xd = &x->e_mbd;
@@ -2029,14 +2088,38 @@
   ENTROPY_CONTEXT ta[2], tempa[2];
   ENTROPY_CONTEXT tl[2], templ[2];
 #endif
-  const int num_4x4_blocks_wide = num_4x4_blocks_wide_lookup[bsize];
-  const int num_4x4_blocks_high = num_4x4_blocks_high_lookup[bsize];
+
+  const int pred_width_in_4x4_blocks = num_4x4_blocks_wide_lookup[bsize];
+  const int pred_height_in_4x4_blocks = num_4x4_blocks_high_lookup[bsize];
+  const int tx_width_unit = tx_size_wide_unit[tx_size];
+  const int tx_height_unit = tx_size_high_unit[tx_size];
+  const int pred_block_width = block_size_wide[bsize];
+  const int pred_block_height = block_size_high[bsize];
+  const int tx_width = tx_size_wide[tx_size];
+  const int tx_height = tx_size_high[tx_size];
+  const int pred_width_in_transform_blocks = pred_block_width / tx_width;
+  const int pred_height_in_transform_blocks = pred_block_height / tx_height;
   int idx, idy;
   int best_can_skip = 0;
   uint8_t best_dst[8 * 8];
+  inv_txfm_func_ptr inv_txfm_func =
+      (tx_size == TX_4X4) ? av1_inv_txfm_add_4x4
+                          : (tx_size == TX_4X8) ? inv_txfm_add_4x8_wrapper
+                                                : inv_txfm_add_8x4_wrapper;
 #if CONFIG_AOM_HIGHBITDEPTH
   uint16_t best_dst16[8 * 8];
+  highbd_inv_txfm_func_ptr highbd_inv_txfm_func =
+      (tx_size == TX_4X4)
+          ? av1_highbd_inv_txfm_add_4x4
+          : (tx_size == TX_4X8) ? highbd_inv_txfm_add_4x8_wrapper
+                                : highbd_inv_txfm_add_8x4_wrapper;
 #endif
+  const int is_lossless = xd->lossless[xd->mi[0]->mbmi.segment_id];
+#if CONFIG_EXT_TX && CONFIG_RECT_TX
+  const int sub_bsize = bsize;
+#else
+  const int sub_bsize = BLOCK_4X4;
+#endif  // CONFIG_EXT_TX && CONFIG_RECT_TX
 
 #if CONFIG_PVQ
   od_rollback_buffer pre_buf, post_buf;
@@ -2044,9 +2127,19 @@
   od_encode_checkpoint(&x->daala_enc, &post_buf);
 #endif
 
-  memcpy(ta, a, num_4x4_blocks_wide * sizeof(a[0]));
-  memcpy(tl, l, num_4x4_blocks_high * sizeof(l[0]));
-  xd->mi[0]->mbmi.tx_size = TX_4X4;
+  assert(bsize < BLOCK_8X8);
+  assert(tx_width < 8 || tx_height < 8);
+#if CONFIG_EXT_TX && CONFIG_RECT_TX
+  assert(tx_width == pred_block_width && tx_height == pred_block_height);
+#else
+  assert(tx_width == 4 && tx_height == 4);
+#endif  // CONFIG_EXT_TX && CONFIG_RECT_TX
+
+  memcpy(ta, a, pred_width_in_transform_blocks * sizeof(a[0]));
+  memcpy(tl, l, pred_height_in_transform_blocks * sizeof(l[0]));
+
+  xd->mi[0]->mbmi.tx_size = tx_size;
+
 #if CONFIG_PALETTE
   xd->mi[0]->mbmi.palette_mode_info.palette_size[0] = 0;
 #endif  // CONFIG_PALETTE
@@ -2060,7 +2153,9 @@
       int rate = bmode_costs[mode];
       int can_skip = 1;
 
-      if (!(cpi->sf.intra_y_mode_mask[TX_4X4] & (1 << mode))) continue;
+      if (!(cpi->sf.intra_y_mode_mask[txsize_sqr_up_map[tx_size]] &
+            (1 << mode)))
+        continue;
 
       // Only do the oblique modes if the best so far is
       // one of the neighboring directional modes
@@ -2068,70 +2163,97 @@
         if (conditional_skipintra(mode, *best_mode)) continue;
       }
 
-      memcpy(tempa, ta, num_4x4_blocks_wide * sizeof(ta[0]));
-      memcpy(templ, tl, num_4x4_blocks_high * sizeof(tl[0]));
+      memcpy(tempa, ta, pred_width_in_transform_blocks * sizeof(ta[0]));
+      memcpy(templ, tl, pred_height_in_transform_blocks * sizeof(tl[0]));
 
-      for (idy = 0; idy < num_4x4_blocks_high; ++idy) {
-        for (idx = 0; idx < num_4x4_blocks_wide; ++idx) {
-          const int block = (row + idy) * 2 + (col + idx);
+      for (idy = 0; idy < pred_height_in_transform_blocks; ++idy) {
+        for (idx = 0; idx < pred_width_in_transform_blocks; ++idx) {
+          const int block_raster_idx = (row + idy) * 2 + (col + idx);
+          const int block =
+              av1_raster_order_to_block_index(tx_size, block_raster_idx);
           const uint8_t *const src = &src_init[idx * 4 + idy * 4 * src_stride];
           uint8_t *const dst = &dst_init[idx * 4 + idy * 4 * dst_stride];
-          int16_t *const src_diff =
-              av1_raster_block_offset_int16(BLOCK_8X8, block, p->src_diff);
-          xd->mi[0]->bmi[block].as_mode = mode;
-          av1_predict_intra_block(xd, pd->width, pd->height, TX_4X4, mode, dst,
+          int16_t *const src_diff = av1_raster_block_offset_int16(
+              BLOCK_8X8, block_raster_idx, p->src_diff);
+          int skip;
+          assert(block < 4);
+          assert(IMPLIES(tx_size == TX_4X8 || tx_size == TX_8X4,
+                         idx == 0 && idy == 0));
+          assert(IMPLIES(tx_size == TX_4X8 || tx_size == TX_8X4,
+                         block == 0 || block == 2));
+          xd->mi[0]->bmi[block_raster_idx].as_mode = mode;
+          av1_predict_intra_block(xd, pd->width, pd->height, tx_size, mode, dst,
                                   dst_stride, dst, dst_stride, col + idx,
                                   row + idy, 0);
-          aom_highbd_subtract_block(4, 4, src_diff, 8, src, src_stride, dst,
-                                    dst_stride, xd->bd);
-          if (xd->lossless[xd->mi[0]->mbmi.segment_id]) {
-            TX_TYPE tx_type = get_tx_type(PLANE_TYPE_Y, xd, block, TX_4X4);
-            const SCAN_ORDER *scan_order = get_scan(cm, TX_4X4, tx_type, 0);
+          aom_highbd_subtract_block(tx_height, tx_width, src_diff, 8, src,
+                                    src_stride, dst, dst_stride, xd->bd);
+          if (is_lossless) {
+            TX_TYPE tx_type =
+                get_tx_type(PLANE_TYPE_Y, xd, block_raster_idx, tx_size);
+            const SCAN_ORDER *scan_order = get_scan(cm, tx_size, tx_type, 0);
             const int coeff_ctx =
-                combine_entropy_contexts(*(tempa + idx), *(templ + idy));
+                combine_entropy_contexts(tempa[idx], templ[idy]);
 #if CONFIG_NEW_QUANT
             av1_xform_quant(cm, x, 0, block, row + idy, col + idx, BLOCK_8X8,
-                            TX_4X4, coeff_ctx, AV1_XFORM_QUANT_FP_NUQ);
+                            tx_size, coeff_ctx, AV1_XFORM_QUANT_FP_NUQ);
 #else
             av1_xform_quant(cm, x, 0, block, row + idy, col + idx, BLOCK_8X8,
-                            TX_4X4, coeff_ctx, AV1_XFORM_QUANT_FP);
+                            tx_size, coeff_ctx, AV1_XFORM_QUANT_FP);
 #endif  // CONFIG_NEW_QUANT
-            ratey += av1_cost_coeffs(cm, x, 0, block, coeff_ctx, TX_4X4,
+            ratey += av1_cost_coeffs(cm, x, 0, block, coeff_ctx, tx_size,
                                      scan_order->scan, scan_order->neighbors,
                                      cpi->sf.use_fast_coef_costing);
-            *(tempa + idx) = !(p->eobs[block] == 0);
-            *(templ + idy) = !(p->eobs[block] == 0);
-            can_skip &= (p->eobs[block] == 0);
+            skip = (p->eobs[block] == 0);
+            can_skip &= skip;
+            tempa[idx] = !skip;
+            templ[idy] = !skip;
+#if CONFIG_EXT_TX
+            if (tx_size == TX_8X4) {
+              tempa[idx + 1] = tempa[idx];
+            } else if (tx_size == TX_4X8) {
+              templ[idy + 1] = templ[idy];
+            }
+#endif  // CONFIG_EXT_TX
+
             if (RDCOST(x->rdmult, x->rddiv, ratey, distortion) >= best_rd)
               goto next_highbd;
-            av1_highbd_inv_txfm_add_4x4(BLOCK_OFFSET(pd->dqcoeff, block), dst,
-                                        dst_stride, p->eobs[block], xd->bd,
-                                        DCT_DCT, 1);
+            highbd_inv_txfm_func(BLOCK_OFFSET(pd->dqcoeff, block), dst,
+                                 dst_stride, p->eobs[block], xd->bd, DCT_DCT,
+                                 1);
           } else {
             int64_t dist;
             unsigned int tmp;
-            TX_TYPE tx_type = get_tx_type(PLANE_TYPE_Y, xd, block, TX_4X4);
-            const SCAN_ORDER *scan_order = get_scan(cm, TX_4X4, tx_type, 0);
+            TX_TYPE tx_type =
+                get_tx_type(PLANE_TYPE_Y, xd, block_raster_idx, tx_size);
+            const SCAN_ORDER *scan_order = get_scan(cm, tx_size, tx_type, 0);
             const int coeff_ctx =
-                combine_entropy_contexts(*(tempa + idx), *(templ + idy));
+                combine_entropy_contexts(tempa[idx], templ[idy]);
 #if CONFIG_NEW_QUANT
             av1_xform_quant(cm, x, 0, block, row + idy, col + idx, BLOCK_8X8,
-                            TX_4X4, coeff_ctx, AV1_XFORM_QUANT_FP_NUQ);
+                            tx_size, coeff_ctx, AV1_XFORM_QUANT_FP_NUQ);
 #else
             av1_xform_quant(cm, x, 0, block, row + idy, col + idx, BLOCK_8X8,
-                            TX_4X4, coeff_ctx, AV1_XFORM_QUANT_FP);
+                            tx_size, coeff_ctx, AV1_XFORM_QUANT_FP);
 #endif  // CONFIG_NEW_QUANT
-            av1_optimize_b(cm, x, 0, block, TX_4X4, coeff_ctx);
-            ratey += av1_cost_coeffs(cm, x, 0, block, coeff_ctx, TX_4X4,
+            av1_optimize_b(cm, x, 0, block, tx_size, coeff_ctx);
+            ratey += av1_cost_coeffs(cm, x, 0, block, coeff_ctx, tx_size,
                                      scan_order->scan, scan_order->neighbors,
                                      cpi->sf.use_fast_coef_costing);
-            *(tempa + idx) = !(p->eobs[block] == 0);
-            *(templ + idy) = !(p->eobs[block] == 0);
-            can_skip &= (p->eobs[block] == 0);
-            av1_highbd_inv_txfm_add_4x4(BLOCK_OFFSET(pd->dqcoeff, block), dst,
-                                        dst_stride, p->eobs[block], xd->bd,
-                                        tx_type, 0);
-            cpi->fn_ptr[BLOCK_4X4].vf(src, src_stride, dst, dst_stride, &tmp);
+            skip = (p->eobs[block] == 0);
+            can_skip &= skip;
+            tempa[idx] = !skip;
+            templ[idy] = !skip;
+#if CONFIG_EXT_TX
+            if (tx_size == TX_8X4) {
+              tempa[idx + 1] = tempa[idx];
+            } else if (tx_size == TX_4X8) {
+              templ[idy + 1] = templ[idy];
+            }
+#endif  // CONFIG_EXT_TX
+            highbd_inv_txfm_func(BLOCK_OFFSET(pd->dqcoeff, block), dst,
+                                 dst_stride, p->eobs[block], xd->bd, tx_type,
+                                 0);
+            cpi->fn_ptr[sub_bsize].vf(src, src_stride, dst, dst_stride, &tmp);
             dist = (int64_t)tmp << 4;
             distortion += dist;
             if (RDCOST(x->rdmult, x->rddiv, ratey, distortion) >= best_rd)
@@ -2150,12 +2272,12 @@
         best_rd = this_rd;
         best_can_skip = can_skip;
         *best_mode = mode;
-        memcpy(a, tempa, num_4x4_blocks_wide * sizeof(tempa[0]));
-        memcpy(l, templ, num_4x4_blocks_high * sizeof(templ[0]));
-        for (idy = 0; idy < num_4x4_blocks_high * 4; ++idy) {
+        memcpy(a, tempa, pred_width_in_transform_blocks * sizeof(tempa[0]));
+        memcpy(l, templ, pred_height_in_transform_blocks * sizeof(templ[0]));
+        for (idy = 0; idy < pred_height_in_transform_blocks * 4; ++idy) {
           memcpy(best_dst16 + idy * 8,
                  CONVERT_TO_SHORTPTR(dst_init + idy * dst_stride),
-                 num_4x4_blocks_wide * 4 * sizeof(uint16_t));
+                 pred_width_in_transform_blocks * 4 * sizeof(uint16_t));
         }
       }
     next_highbd : {}
@@ -2165,9 +2287,10 @@
 
     if (y_skip) *y_skip &= best_can_skip;
 
-    for (idy = 0; idy < num_4x4_blocks_high * 4; ++idy) {
+    for (idy = 0; idy < pred_height_in_transform_blocks * 4; ++idy) {
       memcpy(CONVERT_TO_SHORTPTR(dst_init + idy * dst_stride),
-             best_dst16 + idy * 8, num_4x4_blocks_wide * 4 * sizeof(uint16_t));
+             best_dst16 + idy * 8,
+             pred_width_in_transform_blocks * 4 * sizeof(uint16_t));
     }
 
     return best_rd;
@@ -2185,7 +2308,10 @@
     int rate = bmode_costs[mode];
     int can_skip = 1;
 
-    if (!(cpi->sf.intra_y_mode_mask[TX_4X4] & (1 << mode))) continue;
+    if (!(cpi->sf.intra_y_mode_mask[txsize_sqr_up_map[tx_size]] &
+          (1 << mode))) {
+      continue;
+    }
 
     // Only do the oblique modes if the best so far is
     // one of the neighboring directional modes
@@ -2193,25 +2319,29 @@
       if (conditional_skipintra(mode, *best_mode)) continue;
     }
 
-    memcpy(tempa, ta, num_4x4_blocks_wide * sizeof(ta[0]));
-    memcpy(templ, tl, num_4x4_blocks_high * sizeof(tl[0]));
+    memcpy(tempa, ta, pred_width_in_transform_blocks * sizeof(ta[0]));
+    memcpy(templ, tl, pred_height_in_transform_blocks * sizeof(tl[0]));
 
-    for (idy = 0; idy < num_4x4_blocks_high; ++idy) {
-      for (idx = 0; idx < num_4x4_blocks_wide; ++idx) {
-        int block = (row + idy) * 2 + (col + idx);
+    for (idy = 0; idy < pred_height_in_4x4_blocks; idy += tx_height_unit) {
+      for (idx = 0; idx < pred_width_in_4x4_blocks; idx += tx_width_unit) {
+        const int block_raster_idx = (row + idy) * 2 + (col + idx);
+        int block = av1_raster_order_to_block_index(tx_size, block_raster_idx);
         const uint8_t *const src = &src_init[idx * 4 + idy * 4 * src_stride];
         uint8_t *const dst = &dst_init[idx * 4 + idy * 4 * dst_stride];
 #if !CONFIG_PVQ
-        int16_t *const src_diff =
-            av1_raster_block_offset_int16(BLOCK_8X8, block, p->src_diff);
+        int16_t *const src_diff = av1_raster_block_offset_int16(
+            BLOCK_8X8, block_raster_idx, p->src_diff);
 #else
-        int i, j, tx_blk_size;
-        int skip;
-
-        tx_blk_size = 4;
+        int i, j;
 #endif
-        xd->mi[0]->bmi[block].as_mode = mode;
-        av1_predict_intra_block(xd, pd->width, pd->height, TX_4X4, mode, dst,
+        int skip;
+        assert(block < 4);
+        assert(IMPLIES(tx_size == TX_4X8 || tx_size == TX_8X4,
+                       idx == 0 && idy == 0));
+        assert(IMPLIES(tx_size == TX_4X8 || tx_size == TX_8X4,
+                       block == 0 || block == 2));
+        xd->mi[0]->bmi[block_raster_idx].as_mode = mode;
+        av1_predict_intra_block(xd, pd->width, pd->height, tx_size, mode, dst,
                                 dst_stride, dst, dst_stride,
 #if CONFIG_CB4X4
                                 2 * (col + idx), 2 * (row + idy),
@@ -2220,21 +2350,23 @@
 #endif
                                 0);
 #if !CONFIG_PVQ
-        aom_subtract_block(4, 4, src_diff, 8, src, src_stride, dst, dst_stride);
+        aom_subtract_block(tx_height, tx_width, src_diff, 8, src, src_stride,
+                           dst, dst_stride);
 #endif
 
-        if (xd->lossless[xd->mi[0]->mbmi.segment_id]) {
-          TX_TYPE tx_type = get_tx_type(PLANE_TYPE_Y, xd, block, TX_4X4);
-          const SCAN_ORDER *scan_order = get_scan(cm, TX_4X4, tx_type, 0);
+        if (is_lossless) {
+          TX_TYPE tx_type =
+              get_tx_type(PLANE_TYPE_Y, xd, block_raster_idx, tx_size);
+          const SCAN_ORDER *scan_order = get_scan(cm, tx_size, tx_type, 0);
           const int coeff_ctx =
-              combine_entropy_contexts(*(tempa + idx), *(templ + idy));
+              combine_entropy_contexts(tempa[idx], templ[idy]);
 #if CONFIG_CB4X4
           block = 4 * block;
 #endif
 #if !CONFIG_PVQ
 #if CONFIG_NEW_QUANT
           av1_xform_quant(cm, x, 0, block, row + idy, col + idx, BLOCK_8X8,
-                          TX_4X4, coeff_ctx, AV1_XFORM_QUANT_B_NUQ);
+                          tx_size, coeff_ctx, AV1_XFORM_QUANT_B_NUQ);
 #else
           av1_xform_quant(cm, x, 0, block,
 #if CONFIG_CB4X4
@@ -2242,14 +2374,22 @@
 #else
                           row + idy, col + idx,
 #endif
-                          BLOCK_8X8, TX_4X4, coeff_ctx, AV1_XFORM_QUANT_B);
+                          BLOCK_8X8, tx_size, coeff_ctx, AV1_XFORM_QUANT_B);
 #endif  // CONFIG_NEW_QUANT
-          ratey += av1_cost_coeffs(cm, x, 0, block, coeff_ctx, TX_4X4,
+          ratey += av1_cost_coeffs(cm, x, 0, block, coeff_ctx, tx_size,
                                    scan_order->scan, scan_order->neighbors,
                                    cpi->sf.use_fast_coef_costing);
-          *(tempa + idx) = !(p->eobs[block] == 0);
-          *(templ + idy) = !(p->eobs[block] == 0);
-          can_skip &= (p->eobs[block] == 0);
+          skip = (p->eobs[block] == 0);
+          can_skip &= skip;
+          tempa[idx] = !skip;
+          templ[idy] = !skip;
+#if CONFIG_EXT_TX
+          if (tx_size == TX_8X4) {
+            tempa[idx + 1] = tempa[idx];
+          } else if (tx_size == TX_4X8) {
+            templ[idy + 1] = templ[idy];
+          }
+#endif  // CONFIG_EXT_TX
 #else
           (void)scan_order;
 
@@ -2259,40 +2399,41 @@
 #else
                           row + idy, col + idx,
 #endif
-                          BLOCK_8X8, TX_4X4, coeff_ctx, AV1_XFORM_QUANT_B);
+                          BLOCK_8X8, tx_size, coeff_ctx, AV1_XFORM_QUANT_B);
 
           ratey += x->rate;
           skip = x->pvq_skip[0];
-          *(tempa + idx) = !skip;
-          *(templ + idy) = !skip;
+          tempa[idx] = !skip;
+          templ[idy] = !skip;
           can_skip &= skip;
 #endif
           if (RDCOST(x->rdmult, x->rddiv, ratey, distortion) >= best_rd)
             goto next;
 #if CONFIG_PVQ
           if (!skip) {
-            for (j = 0; j < tx_blk_size; j++)
-              for (i = 0; i < tx_blk_size; i++) dst[j * dst_stride + i] = 0;
+            for (j = 0; j < tx_height; j++)
+              for (i = 0; i < tx_width; i++) dst[j * dst_stride + i] = 0;
 #endif
-            av1_inv_txfm_add_4x4(BLOCK_OFFSET(pd->dqcoeff, block), dst,
-                                 dst_stride, p->eobs[block], DCT_DCT, 1);
+            inv_txfm_func(BLOCK_OFFSET(pd->dqcoeff, block), dst, dst_stride,
+                          p->eobs[block], DCT_DCT, 1);
 #if CONFIG_PVQ
           }
 #endif
         } else {
           int64_t dist;
           unsigned int tmp;
-          TX_TYPE tx_type = get_tx_type(PLANE_TYPE_Y, xd, block, TX_4X4);
-          const SCAN_ORDER *scan_order = get_scan(cm, TX_4X4, tx_type, 0);
+          TX_TYPE tx_type =
+              get_tx_type(PLANE_TYPE_Y, xd, block_raster_idx, tx_size);
+          const SCAN_ORDER *scan_order = get_scan(cm, tx_size, tx_type, 0);
           const int coeff_ctx =
-              combine_entropy_contexts(*(tempa + idx), *(templ + idy));
+              combine_entropy_contexts(tempa[idx], templ[idy]);
 #if CONFIG_CB4X4
           block = 4 * block;
 #endif
 #if !CONFIG_PVQ
 #if CONFIG_NEW_QUANT
           av1_xform_quant(cm, x, 0, block, row + idy, col + idx, BLOCK_8X8,
-                          TX_4X4, coeff_ctx, AV1_XFORM_QUANT_FP_NUQ);
+                          tx_size, coeff_ctx, AV1_XFORM_QUANT_FP_NUQ);
 #else
           av1_xform_quant(cm, x, 0, block,
 #if CONFIG_CB4X4
@@ -2300,15 +2441,23 @@
 #else
                           row + idy, col + idx,
 #endif
-                          BLOCK_8X8, TX_4X4, coeff_ctx, AV1_XFORM_QUANT_FP);
+                          BLOCK_8X8, tx_size, coeff_ctx, AV1_XFORM_QUANT_FP);
 #endif  // CONFIG_NEW_QUANT
-          av1_optimize_b(cm, x, 0, block, TX_4X4, coeff_ctx);
-          ratey += av1_cost_coeffs(cm, x, 0, block, coeff_ctx, TX_4X4,
+          av1_optimize_b(cm, x, 0, block, tx_size, coeff_ctx);
+          ratey += av1_cost_coeffs(cm, x, 0, block, coeff_ctx, tx_size,
                                    scan_order->scan, scan_order->neighbors,
                                    cpi->sf.use_fast_coef_costing);
-          *(tempa + idx) = !(p->eobs[block] == 0);
-          *(templ + idy) = !(p->eobs[block] == 0);
-          can_skip &= (p->eobs[block] == 0);
+          skip = (p->eobs[block] == 0);
+          can_skip &= skip;
+          tempa[idx] = !skip;
+          templ[idy] = !skip;
+#if CONFIG_EXT_TX
+          if (tx_size == TX_8X4) {
+            tempa[idx + 1] = tempa[idx];
+          } else if (tx_size == TX_4X8) {
+            templ[idy + 1] = templ[idy];
+          }
+#endif  // CONFIG_EXT_TX
 #else
           (void)scan_order;
 
@@ -2318,25 +2467,25 @@
 #else
                           row + idy, col + idx,
 #endif
-                          BLOCK_8X8, TX_4X4, coeff_ctx, AV1_XFORM_QUANT_FP);
+                          BLOCK_8X8, tx_size, coeff_ctx, AV1_XFORM_QUANT_FP);
           ratey += x->rate;
           skip = x->pvq_skip[0];
-          *(tempa + idx) = !skip;
-          *(templ + idy) = !skip;
+          tempa[idx] = !skip;
+          templ[idy] = !skip;
           can_skip &= skip;
 #endif
 #if CONFIG_PVQ
           if (!skip) {
-            for (j = 0; j < tx_blk_size; j++)
-              for (i = 0; i < tx_blk_size; i++) dst[j * dst_stride + i] = 0;
+            for (j = 0; j < tx_height; j++)
+              for (i = 0; i < tx_width; i++) dst[j * dst_stride + i] = 0;
 #endif
-            av1_inv_txfm_add_4x4(BLOCK_OFFSET(pd->dqcoeff, block), dst,
-                                 dst_stride, p->eobs[block], tx_type, 0);
+            inv_txfm_func(BLOCK_OFFSET(pd->dqcoeff, block), dst, dst_stride,
+                          p->eobs[block], tx_type, 0);
 #if CONFIG_PVQ
           }
 #endif
           // No need for av1_block_error2_c because the ssz is unused
-          cpi->fn_ptr[BLOCK_4X4].vf(src, src_stride, dst, dst_stride, &tmp);
+          cpi->fn_ptr[sub_bsize].vf(src, src_stride, dst, dst_stride, &tmp);
           dist = (int64_t)tmp << 4;
           distortion += dist;
           // To use the pixel domain distortion, the step below needs to be
@@ -2358,14 +2507,14 @@
       best_rd = this_rd;
       best_can_skip = can_skip;
       *best_mode = mode;
-      memcpy(a, tempa, num_4x4_blocks_wide * sizeof(tempa[0]));
-      memcpy(l, templ, num_4x4_blocks_high * sizeof(templ[0]));
+      memcpy(a, tempa, pred_width_in_transform_blocks * sizeof(tempa[0]));
+      memcpy(l, templ, pred_height_in_transform_blocks * sizeof(templ[0]));
 #if CONFIG_PVQ
       od_encode_checkpoint(&x->daala_enc, &post_buf);
 #endif
-      for (idy = 0; idy < num_4x4_blocks_high * 4; ++idy)
+      for (idy = 0; idy < pred_height_in_transform_blocks * 4; ++idy)
         memcpy(best_dst + idy * 8, dst_init + idy * dst_stride,
-               num_4x4_blocks_wide * 4);
+               pred_width_in_transform_blocks * 4);
     }
   next : {}
 #if CONFIG_PVQ
@@ -2381,9 +2530,9 @@
 
   if (y_skip) *y_skip &= best_can_skip;
 
-  for (idy = 0; idy < num_4x4_blocks_high * 4; ++idy)
+  for (idy = 0; idy < pred_height_in_transform_blocks * 4; ++idy)
     memcpy(dst_init + idy * dst_stride, best_dst + idy * 8,
-           num_4x4_blocks_wide * 4);
+           pred_width_in_transform_blocks * 4);
 
   return best_rd;
 }
@@ -2392,55 +2541,65 @@
                                             MACROBLOCK *mb, int *rate,
                                             int *rate_y, int64_t *distortion,
                                             int *y_skip, int64_t best_rd) {
-  int i, j;
   const MACROBLOCKD *const xd = &mb->e_mbd;
   MODE_INFO *const mic = xd->mi[0];
   const MODE_INFO *above_mi = xd->above_mi;
   const MODE_INFO *left_mi = xd->left_mi;
-  const BLOCK_SIZE bsize = xd->mi[0]->mbmi.sb_type;
-  const int num_4x4_blocks_wide = num_4x4_blocks_wide_lookup[bsize];
-  const int num_4x4_blocks_high = num_4x4_blocks_high_lookup[bsize];
+  MB_MODE_INFO *const mbmi = &mic->mbmi;
+  const BLOCK_SIZE bsize = mbmi->sb_type;
+  const int pred_width_in_4x4_blocks = num_4x4_blocks_wide_lookup[bsize];
+  const int pred_height_in_4x4_blocks = num_4x4_blocks_high_lookup[bsize];
   int idx, idy;
   int cost = 0;
   int64_t total_distortion = 0;
   int tot_rate_y = 0;
   int64_t total_rd = 0;
   const int *bmode_costs = cpi->mbmode_cost[0];
+  const int is_lossless = xd->lossless[mbmi->segment_id];
+#if CONFIG_EXT_TX && CONFIG_RECT_TX
+  const TX_SIZE tx_size = is_lossless ? TX_4X4 : max_txsize_rect_lookup[bsize];
+#else
+  const TX_SIZE tx_size = TX_4X4;
+#endif  // CONFIG_EXT_TX && CONFIG_RECT_TX
 
 #if CONFIG_EXT_INTRA
 #if CONFIG_INTRA_INTERP
-  mic->mbmi.intra_filter = INTRA_FILTER_LINEAR;
+  mbmi->intra_filter = INTRA_FILTER_LINEAR;
 #endif  // CONFIG_INTRA_INTERP
 #endif  // CONFIG_EXT_INTRA
 #if CONFIG_FILTER_INTRA
-  mic->mbmi.filter_intra_mode_info.use_filter_intra_mode[0] = 0;
+  mbmi->filter_intra_mode_info.use_filter_intra_mode[0] = 0;
 #endif  // CONFIG_FILTER_INTRA
 
   // TODO(any): Add search of the tx_type to improve rd performance at the
   // expense of speed.
-  mic->mbmi.tx_type = DCT_DCT;
-  mic->mbmi.tx_size = TX_4X4;
+  mbmi->tx_type = DCT_DCT;
+  mbmi->tx_size = tx_size;
 
   if (y_skip) *y_skip = 1;
 
-  // Pick modes for each sub-block (of size 4x4, 4x8, or 8x4) in an 8x8 block.
-  for (idy = 0; idy < 2; idy += num_4x4_blocks_high) {
-    for (idx = 0; idx < 2; idx += num_4x4_blocks_wide) {
+  // Pick modes for each prediction sub-block (of size 4x4, 4x8, or 8x4) in this
+  // 8x8 coding block.
+  for (idy = 0; idy < 2; idy += pred_height_in_4x4_blocks) {
+    for (idx = 0; idx < 2; idx += pred_width_in_4x4_blocks) {
       PREDICTION_MODE best_mode = DC_PRED;
       int r = INT_MAX, ry = INT_MAX;
       int64_t d = INT64_MAX, this_rd = INT64_MAX;
-      i = idy * 2 + idx;
+      int j;
+      const int pred_block_idx = idy * 2 + idx;
       if (cpi->common.frame_type == KEY_FRAME) {
-        const PREDICTION_MODE A = av1_above_block_mode(mic, above_mi, i);
-        const PREDICTION_MODE L = av1_left_block_mode(mic, left_mi, i);
+        const PREDICTION_MODE A =
+            av1_above_block_mode(mic, above_mi, pred_block_idx);
+        const PREDICTION_MODE L =
+            av1_left_block_mode(mic, left_mi, pred_block_idx);
 
         bmode_costs = cpi->y_mode_costs[A][L];
       }
 
-      this_rd = rd_pick_intra4x4block(
+      this_rd = rd_pick_intra_sub_8x8_y_subblock_mode(
           cpi, mb, idy, idx, &best_mode, bmode_costs,
           xd->plane[0].above_context + idx, xd->plane[0].left_context + idy, &r,
-          &ry, &d, bsize, y_skip, best_rd - total_rd);
+          &ry, &d, bsize, tx_size, y_skip, best_rd - total_rd);
       if (this_rd >= best_rd - total_rd) return INT64_MAX;
 
       total_rd += this_rd;
@@ -2448,33 +2607,33 @@
       total_distortion += d;
       tot_rate_y += ry;
 
-      mic->bmi[i].as_mode = best_mode;
-      for (j = 1; j < num_4x4_blocks_high; ++j)
-        mic->bmi[i + j * 2].as_mode = best_mode;
-      for (j = 1; j < num_4x4_blocks_wide; ++j)
-        mic->bmi[i + j].as_mode = best_mode;
+      mic->bmi[pred_block_idx].as_mode = best_mode;
+      for (j = 1; j < pred_height_in_4x4_blocks; ++j)
+        mic->bmi[pred_block_idx + j * 2].as_mode = best_mode;
+      for (j = 1; j < pred_width_in_4x4_blocks; ++j)
+        mic->bmi[pred_block_idx + j].as_mode = best_mode;
 
       if (total_rd >= best_rd) return INT64_MAX;
     }
   }
-  mic->mbmi.mode = mic->bmi[3].as_mode;
+  mbmi->mode = mic->bmi[3].as_mode;
 
   // Add in the cost of the transform type
-  if (!xd->lossless[mic->mbmi.segment_id]) {
+  if (!is_lossless) {
     int rate_tx_type = 0;
 #if CONFIG_EXT_TX
-    if (get_ext_tx_types(TX_4X4, bsize, 0) > 1) {
-      const int eset = get_ext_tx_set(TX_4X4, bsize, 0);
-      rate_tx_type = cpi->intra_tx_type_costs[eset][TX_4X4][mic->mbmi.mode]
-                                             [mic->mbmi.tx_type];
+    if (get_ext_tx_types(tx_size, bsize, 0) > 1) {
+      const int eset = get_ext_tx_set(tx_size, bsize, 0);
+      rate_tx_type = cpi->intra_tx_type_costs[eset][txsize_sqr_map[tx_size]]
+                                             [mbmi->mode][mbmi->tx_type];
     }
 #else
     rate_tx_type =
-        cpi->intra_tx_type_costs[TX_4X4]
-                                [intra_mode_to_tx_type_context[mic->mbmi.mode]]
-                                [mic->mbmi.tx_type];
+        cpi->intra_tx_type_costs[txsize_sqr_map[tx_size]]
+                                [intra_mode_to_tx_type_context[mbmi->mode]]
+                                [mbmi->tx_type];
 #endif
-    assert(mic->mbmi.tx_size == TX_4X4);
+    assert(mbmi->tx_size == tx_size);
     cost += rate_tx_type;
     tot_rate_y += rate_tx_type;
   }
@@ -2884,7 +3043,6 @@
   const PREDICTION_MODE A = av1_above_block_mode(mic, above_mi, 0);
   const PREDICTION_MODE L = av1_left_block_mode(mic, left_mi, 0);
   const PREDICTION_MODE FINAL_MODE_SEARCH = TM_PRED + 1;
-  const TX_SIZE max_tx_size = max_txsize_lookup[bsize];
 #if CONFIG_PVQ
   od_rollback_buffer pre_buf, post_buf;
 
@@ -2962,9 +3120,7 @@
       // tokenonly rate, but for intra blocks, tx_size is always coded
       // (prediction granularity), so we account for it in the full rate,
       // not the tokenonly rate.
-      this_rate_tokenonly -=
-          cpi->tx_size_cost[max_tx_size - TX_8X8][get_tx_size_context(xd)]
-                           [tx_size_to_depth(mbmi->tx_size)];
+      this_rate_tokenonly -= tx_size_cost(cpi, x, bsize, mbmi->tx_size);
     }
 #if CONFIG_PALETTE
     if (cpi->common.allow_screen_content_tools && mbmi->mode == DC_PRED)
@@ -4073,7 +4229,9 @@
   pmi->palette_size[1] = 0;
 #endif  // CONFIG_PALETTE
   for (mode = DC_PRED; mode <= TM_PRED; ++mode) {
-    if (!(cpi->sf.intra_uv_mode_mask[max_tx_size] & (1 << mode))) continue;
+    if (!(cpi->sf.intra_uv_mode_mask[txsize_sqr_up_map[max_tx_size]] &
+          (1 << mode)))
+      continue;
 
     mbmi->uv_mode = mode;
 #if CONFIG_EXT_INTRA
@@ -4189,6 +4347,8 @@
   pmi->palette_size[1] = palette_mode_info.palette_size[1];
 #endif  // CONFIG_PALETTE
 
+  // Make sure we actually chose a mode
+  assert(best_rd < INT64_MAX);
   return best_rd;
 }
 
@@ -4550,16 +4710,11 @@
   for (idy = 0; idy < txb_height; idy += num_4x4_h) {
     for (idx = 0; idx < txb_width; idx += num_4x4_w) {
       int64_t dist, ssz, rd, rd1, rd2;
-      int block;
       int coeff_ctx;
-      int k;
-
-      k = i + (idy * 2 + idx);
-      if (tx_size == TX_4X4)
-        block = k;
-      else
-        block = (i ? 2 : 0);
-
+      const int k = i + (idy * 2 + idx);
+      const int block = av1_raster_order_to_block_index(tx_size, k);
+      assert(IMPLIES(tx_size == TX_4X8 || tx_size == TX_8X4,
+                     idx == 0 && idy == 0));
       coeff_ctx = combine_entropy_contexts(*(ta + (k & 1)), *(tl + (k >> 1)));
 #if !CONFIG_PVQ
 #if CONFIG_NEW_QUANT
@@ -8414,7 +8569,6 @@
 #if CONFIG_PALETTE
   PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
 #endif  // CONFIG_PALETTE
-  const TX_SIZE max_tx_size = max_txsize_lookup[bsize];
   int rate2 = 0, rate_y = INT_MAX, skippable = 0, rate_uv, rate_dummy, i;
   int dc_mode_index;
   const int *const intra_mode_cost = cpi->mbmode_cost[size_group_lookup[bsize]];
@@ -8491,8 +8645,7 @@
     // tokenonly rate, but for intra blocks, tx_size is always coded
     // (prediction granularity), so we account for it in the full rate,
     // not the tokenonly rate.
-    rate_y -= cpi->tx_size_cost[max_tx_size - TX_8X8][get_tx_size_context(xd)]
-                               [tx_size_to_depth(mbmi->tx_size)];
+    rate_y -= tx_size_cost(cpi, x, bsize, mbmi->tx_size);
   }
 
   rate2 += av1_cost_bit(cm->fc->filter_intra_probs[0],
@@ -8642,21 +8795,21 @@
   int64_t best_intra_rd = INT64_MAX;
   unsigned int best_pred_sse = UINT_MAX;
   PREDICTION_MODE best_intra_mode = DC_PRED;
-  int rate_uv_intra[TX_SIZES], rate_uv_tokenonly[TX_SIZES];
-  int64_t dist_uvs[TX_SIZES];
-  int skip_uvs[TX_SIZES];
-  PREDICTION_MODE mode_uv[TX_SIZES];
+  int rate_uv_intra[TX_SIZES_ALL], rate_uv_tokenonly[TX_SIZES_ALL];
+  int64_t dist_uvs[TX_SIZES_ALL];
+  int skip_uvs[TX_SIZES_ALL];
+  PREDICTION_MODE mode_uv[TX_SIZES_ALL];
 #if CONFIG_PALETTE
-  PALETTE_MODE_INFO pmi_uv[TX_SIZES];
+  PALETTE_MODE_INFO pmi_uv[TX_SIZES_ALL];
 #endif  // CONFIG_PALETTE
 #if CONFIG_EXT_INTRA
-  int8_t uv_angle_delta[TX_SIZES];
+  int8_t uv_angle_delta[TX_SIZES_ALL];
   int is_directional_mode, angle_stats_ready = 0;
   uint8_t directional_mode_skip_mask[INTRA_MODES];
 #endif  // CONFIG_EXT_INTRA
 #if CONFIG_FILTER_INTRA
   int8_t dc_skipped = 1;
-  FILTER_INTRA_MODE_INFO filter_intra_mode_info_uv[TX_SIZES];
+  FILTER_INTRA_MODE_INFO filter_intra_mode_info_uv[TX_SIZES_ALL];
 #endif  // CONFIG_FILTER_INTRA
   const int intra_cost_penalty = av1_get_intra_cost_penalty(
       cm->base_qindex, cm->y_dc_delta_q, cm->bit_depth);
@@ -8676,7 +8829,6 @@
   int64_t mode_threshold[MAX_MODES];
   int *mode_map = tile_data->mode_map[bsize];
   const int mode_search_skip_flags = sf->mode_search_skip_flags;
-  const TX_SIZE max_tx_size = max_txsize_lookup[bsize];
 #if CONFIG_PVQ
   od_rollback_buffer pre_buf;
 #endif
@@ -8751,7 +8903,7 @@
                            &comp_mode_p);
 
   for (i = 0; i < REFERENCE_MODES; ++i) best_pred_rd[i] = INT64_MAX;
-  for (i = 0; i < TX_SIZES; i++) rate_uv_intra[i] = INT_MAX;
+  for (i = 0; i < TX_SIZES_ALL; i++) rate_uv_intra[i] = INT_MAX;
   for (i = 0; i < TOTAL_REFS_PER_FRAME; ++i) x->pred_sse[i] = INT_MAX;
   for (i = 0; i < MB_MODE_COUNT; ++i) {
     for (k = 0; k < TOTAL_REFS_PER_FRAME; ++k) {
@@ -9281,9 +9433,7 @@
         // tokenonly rate, but for intra blocks, tx_size is always coded
         // (prediction granularity), so we account for it in the full rate,
         // not the tokenonly rate.
-        rate_y -=
-            cpi->tx_size_cost[max_tx_size - TX_8X8][get_tx_size_context(xd)]
-                             [tx_size_to_depth(mbmi->tx_size)];
+        rate_y -= tx_size_cost(cpi, x, bsize, mbmi->tx_size);
       }
 #if CONFIG_EXT_INTRA
       if (is_directional_mode) {