Enable rectangular transforms for Intra blocks as well.

These are enabled under the EXT_TX + RECT_TX experiment combination.
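
A rough illustration of the approach taken in reconintra.c below: a
rectangular transform block is intra-predicted as two square sub-blocks,
with the second sub-block reusing the reconstructed boundary of the first
as its 'above' (or 'left') neighbors. The stand-alone sketch uses a
stand-in DC predictor; predict_square_dc() and predict_tall_rect() are
illustrative names, not libaom functions.

  #include <stdint.h>

  // Stand-in for a real square intra predictor (plain DC prediction).
  static void predict_square_dc(uint8_t *dst, int stride, int size,
                                const uint8_t *above, const uint8_t *left) {
    int sum = 0;
    for (int i = 0; i < size; ++i) sum += above[i] + left[i];
    const uint8_t dc = (uint8_t)((sum + size) / (2 * size));
    for (int r = 0; r < size; ++r)
      for (int c = 0; c < size; ++c) dst[r * stride + c] = dc;
  }

  // Predict a tall size x (2*size) block as two stacked squares: the bottom
  // square uses the last reconstructed row of the top square as its 'above'
  // neighbors. 'left' must hold 2*size left-neighbor pixels.
  static void predict_tall_rect(uint8_t *dst, int stride, int size,
                                const uint8_t *above, const uint8_t *left) {
    predict_square_dc(dst, stride, size, above, left);
    predict_square_dc(dst + size * stride, stride, size,
                      dst + (size - 1) * stride, left + size);
  }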

Results
=======

Derf Set:
---------
All Intra frames: 1.8% avg improvement (and 1.78% BD-rate improvement)
Video: 0.230% avg improvement (and 0.262% BD-rate improvement)

Objective-1-fast set:
---------------------
Video: 0.52 PSNR improvement

Change-Id: I1893465929858e38419f327752dc61c19b96b997
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index 9f735bc..285adfa 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -688,7 +688,7 @@
 
 static INLINE int is_rect_tx_allowed(const MACROBLOCKD *xd,
                                      const MB_MODE_INFO *mbmi) {
-  return is_inter_block(mbmi) && is_rect_tx_allowed_bsize(mbmi->sb_type) &&
+  return is_rect_tx_allowed_bsize(mbmi->sb_type) &&
          !xd->lossless[mbmi->segment_id];
 }
 
@@ -699,40 +699,33 @@
 static INLINE TX_SIZE tx_size_from_tx_mode(BLOCK_SIZE bsize, TX_MODE tx_mode,
                                            int is_inter) {
   const TX_SIZE largest_tx_size = tx_mode_to_biggest_tx_size[tx_mode];
-#if CONFIG_VAR_TX
-  const TX_SIZE max_tx_size = max_txsize_rect_lookup[bsize];
-
-#if CONFIG_CB4X4
-  if (!is_inter || bsize == BLOCK_4X4)
-    return AOMMIN(max_txsize_lookup[bsize], largest_tx_size);
-#else
-  if (!is_inter || bsize < BLOCK_8X8)
-    return AOMMIN(max_txsize_lookup[bsize], largest_tx_size);
-#endif
-
-  if (txsize_sqr_map[max_tx_size] <= largest_tx_size)
-    return max_tx_size;
-  else
-    return largest_tx_size;
+#if CONFIG_VAR_TX || (CONFIG_EXT_TX && CONFIG_RECT_TX)
+  const TX_SIZE max_rect_tx_size = max_txsize_rect_lookup[bsize];
 #else
   const TX_SIZE max_tx_size = max_txsize_lookup[bsize];
+#endif  // CONFIG_VAR_TX || (CONFIG_EXT_TX && CONFIG_RECT_TX)
+  (void)is_inter;
+#if CONFIG_VAR_TX
+#if CONFIG_CB4X4
+  if (bsize == BLOCK_4X4)
+    return AOMMIN(max_txsize_lookup[bsize], largest_tx_size);
+#else
+  if (bsize < BLOCK_8X8)
+    return AOMMIN(max_txsize_lookup[bsize], largest_tx_size);
 #endif
-
-#if CONFIG_EXT_TX && CONFIG_RECT_TX
-  if (!is_inter) {
-    return AOMMIN(max_tx_size, largest_tx_size);
+  if (txsize_sqr_map[max_rect_tx_size] <= largest_tx_size)
+    return max_rect_tx_size;
+  else
+    return largest_tx_size;
+#elif CONFIG_EXT_TX && CONFIG_RECT_TX
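+  // For example, BLOCK_8X16 with largest_tx_size == TX_16X16 gives
+  // max_rect_tx_size == TX_8X16; its square-up size TX_16X16 fits within
+  // largest_tx_size, so TX_8X16 is returned.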
+  if (txsize_sqr_up_map[max_rect_tx_size] <= largest_tx_size) {
+    return max_rect_tx_size;
   } else {
-    const TX_SIZE max_rect_tx_size = max_txsize_rect_lookup[bsize];
-    if (txsize_sqr_up_map[max_rect_tx_size] <= largest_tx_size) {
-      return max_rect_tx_size;
-    } else {
-      return largest_tx_size;
-    }
+    return largest_tx_size;
   }
 #else
-  (void)is_inter;
   return AOMMIN(max_tx_size, largest_tx_size);
-#endif  // CONFIG_EXT_TX && CONFIG_RECT_TX
+#endif  // CONFIG_VAR_TX
 }
 
 #if CONFIG_FILTER_INTRA
diff --git a/av1/common/common_data.h b/av1/common/common_data.h
index e1e1dd1..bcfdf1a 100644
--- a/av1/common/common_data.h
+++ b/av1/common/common_data.h
@@ -487,6 +487,42 @@
 #endif  // CONFIG_TX64X64
 };
 
+#if CONFIG_EXT_TX && CONFIG_RECT_TX
+// Same as "max_txsize_lookup[bsize] - TX_8X8", except for rectangular
+// blocks, which may use a rectangular transform, in which case it is
+// "(max_txsize_lookup[bsize] + 1) - TX_8X8". Invalid for bsize < 8X8.
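+// For example, BLOCK_16X8 may use a TX_16X8 transform, so its category is
+// (max_txsize_lookup[BLOCK_16X8] + 1) - TX_8X8 = TX_16X16 - TX_8X8 = 1.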
+static const int32_t intra_tx_size_cat_lookup[BLOCK_SIZES] = {
+#if CONFIG_CB4X4
+  // 2X2,             2X4,                4X2,
+  INT32_MIN,          INT32_MIN,          INT32_MIN,
+#endif
+  //                                      4X4
+                                          INT32_MIN,
+  // 4X8,             8X4,                8X8
+  INT32_MIN,          INT32_MIN,          TX_8X8 - TX_8X8,
+  // 8X16,            16X8,               16X16
+  TX_16X16 - TX_8X8,  TX_16X16 - TX_8X8,  TX_16X16 - TX_8X8,
+  // 16X32,           32X16,              32X32
+  TX_32X32 - TX_8X8,  TX_32X32 - TX_8X8,  TX_32X32 - TX_8X8,
+  // 32X64,           64X32,
+  TX_32X32 - TX_8X8,  TX_32X32 - TX_8X8,
+#if CONFIG_TX64X64
+  // 64X64
+  TX_64X64 - TX_8X8,
+#if CONFIG_EXT_PARTITION
+  // 64x128,          128x64,             128x128
+  TX_64X64 - TX_8X8,  TX_64X64 - TX_8X8,  TX_64X64 - TX_8X8,
+#endif  // CONFIG_EXT_PARTITION
+#else
+  // 64X64
+  TX_32X32 - TX_8X8,
+#if CONFIG_EXT_PARTITION
+  // 64x128,          128x64,             128x128
+  TX_32X32 - TX_8X8,  TX_32X32 - TX_8X8,  TX_32X32 - TX_8X8,
+#endif  // CONFIG_EXT_PARTITION
+#endif  // CONFIG_TX64X64
+};
+#else
 // Same as "max_txsize_lookup[bsize] - TX_8X8", invalid for bsize < 8X8
 static const int32_t intra_tx_size_cat_lookup[BLOCK_SIZES] = {
 #if CONFIG_CB4X4
@@ -519,46 +555,10 @@
 #endif  // CONFIG_EXT_PARTITION
 #endif  // CONFIG_TX64X64
 };
-
-#if CONFIG_EXT_TX && CONFIG_RECT_TX
-// Same as "max_txsize_lookup[bsize] - TX_8X8", except for rectangular
-// block which may use a rectangular transform, in which  case it is
-// "(max_txsize_lookup[bsize] + 1) - TX_8X8", invalid for bsize < 8X8
-static const int32_t inter_tx_size_cat_lookup[BLOCK_SIZES] = {
-#if CONFIG_CB4X4
-  // 2X2,             2X4,                4X2,
-  INT32_MIN,          INT32_MIN,          INT32_MIN,
-#endif
-  //                                      4X4
-                                          INT32_MIN,
-  // 4X8,             8X4,                8X8
-  INT32_MIN,          INT32_MIN,           TX_8X8 - TX_8X8,
-  // 8X16,            16X8,               16X16
-  TX_16X16 - TX_8X8,  TX_16X16 - TX_8X8,  TX_16X16 - TX_8X8,
-  // 16X32,           32X16,              32X32
-  TX_32X32 - TX_8X8,  TX_32X32 - TX_8X8,  TX_32X32 - TX_8X8,
-  // 32X64,           64X32,
-  TX_32X32 - TX_8X8,  TX_32X32 - TX_8X8,
-#if CONFIG_TX64X64
-  // 64X64
-  TX_64X64 - TX_8X8,
-#if CONFIG_EXT_PARTITION
-  // 64x128,          128x64,             128x128
-  TX_64X64 - TX_8X8,  TX_64X64 - TX_8X8,  TX_64X64 - TX_8X8,
-#endif  // CONFIG_EXT_PARTITION
-#else
-  // 64X64
-  TX_32X32 - TX_8X8,
-#if CONFIG_EXT_PARTITION
-  // 64x128,          128x64,             128x128
-  TX_32X32 - TX_8X8,  TX_32X32 - TX_8X8,  TX_32X32 - TX_8X8,
-#endif  // CONFIG_EXT_PARTITION
-#endif  // CONFIG_TX64X64
-};
-#else
-#define inter_tx_size_cat_lookup intra_tx_size_cat_lookup
 #endif  // CONFIG_EXT_TX && CONFIG_RECT_TX
 
+#define inter_tx_size_cat_lookup intra_tx_size_cat_lookup
+
 /* clang-format on */
 
 static const TX_SIZE sub_tx_size_map[TX_SIZES_ALL] = {
diff --git a/av1/common/entropy.c b/av1/common/entropy.c
index 7f6cd0d..686b535 100644
--- a/av1/common/entropy.c
+++ b/av1/common/entropy.c
@@ -3534,6 +3534,9 @@
       (const unsigned int(*)[REF_TYPES][COEF_BANDS]
                             [COEFF_CONTEXTS])cm->counts.eob_branch[tx_size];
   int i, j, k, l, m;
+#if CONFIG_RECT_TX
+  assert(!is_rect_tx(tx_size));
+#endif  // CONFIG_RECT_TX
 
   for (i = 0; i < PLANE_TYPES; ++i)
     for (j = 0; j < REF_TYPES; ++j)
diff --git a/av1/common/reconinter.c b/av1/common/reconinter.c
index ba2ed88..4df6079 100644
--- a/av1/common/reconinter.c
+++ b/av1/common/reconinter.c
@@ -2526,6 +2526,7 @@
 }
 #endif  // CONFIG_AOM_HIGHBITDEPTH
 
+// TODO(urvang/davidbarker): Refactor with av1_predict_intra_block().
 static void build_intra_predictors_for_interintra(MACROBLOCKD *xd, uint8_t *ref,
                                                   int ref_stride, uint8_t *dst,
                                                   int dst_stride,
diff --git a/av1/common/reconintra.c b/av1/common/reconintra.c
index ff93ada..bd6efcb 100644
--- a/av1/common/reconintra.c
+++ b/av1/common/reconintra.c
@@ -1387,6 +1387,8 @@
       filter_intra_mode_info->filter_intra_mode[plane != 0];
 #endif  // CONFIG_FILTER_INTRA
   int base = 128 << (xd->bd - 8);
+  assert(tx_size_wide[tx_size] == tx_size_high[tx_size]);
+
 // 127 127 127 .. 127 127 127 127 127 127
 // 129  A   B  ..  Y   Z
 // 129  C   D  ..  W   X
@@ -1552,6 +1554,7 @@
   const FILTER_INTRA_MODE filter_intra_mode =
       filter_intra_mode_info->filter_intra_mode[plane != 0];
 #endif  // CONFIG_FILTER_INTRA
+  assert(tx_size_wide[tx_size] == tx_size_high[tx_size]);
 
 // 127 127 127 .. 127 127 127 127 127 127
 // 129  A   B  ..  Y   Z
@@ -1687,11 +1690,11 @@
   }
 }
 
-void av1_predict_intra_block(const MACROBLOCKD *xd, int wpx, int hpx,
-                             TX_SIZE tx_size, PREDICTION_MODE mode,
-                             const uint8_t *ref, int ref_stride, uint8_t *dst,
-                             int dst_stride, int col_off, int row_off,
-                             int plane) {
+static void predict_square_intra_block(const MACROBLOCKD *xd, int wpx, int hpx,
+                                       TX_SIZE tx_size, PREDICTION_MODE mode,
+                                       const uint8_t *ref, int ref_stride,
+                                       uint8_t *dst, int dst_stride,
+                                       int col_off, int row_off, int plane) {
   const BLOCK_SIZE bsize = xd->mi[0]->mbmi.sb_type;
   const struct macroblockd_plane *const pd = &xd->plane[plane];
   const int txw = tx_size_wide_unit[tx_size];
@@ -1725,6 +1728,7 @@
                     tx_size, row_off, col_off, pd->subsampling_x);
   const int have_bottom = av1_has_bottom(bsize, mi_row, mi_col, yd > 0, tx_size,
                                          row_off, col_off, pd->subsampling_y);
+  assert(txwpx == txhpx);
 
 #if CONFIG_PALETTE
   if (xd->mi[0]->mbmi.palette_mode_info.palette_size[plane != 0] > 0) {
@@ -1782,6 +1786,142 @@
                          plane);
 }
 
+void av1_predict_intra_block(const MACROBLOCKD *xd, int wpx, int hpx,
+                             TX_SIZE tx_size, PREDICTION_MODE mode,
+                             const uint8_t *ref, int ref_stride, uint8_t *dst,
+                             int dst_stride, int col_off, int row_off,
+                             int plane) {
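+  // For a rectangular transform block, the prediction is done as two square
+  // sub-blocks in sequence: the boundary row (or column) of the first
+  // sub-block is temporarily copied into 'ref' so that the second sub-block
+  // can use it as its 'above' (or 'left') neighbors.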
+  const int tx_width = tx_size_wide[tx_size];
+  const int tx_height = tx_size_high[tx_size];
+  if (tx_width == tx_height) {
+    predict_square_intra_block(xd, wpx, hpx, tx_size, mode, ref, ref_stride,
+                               dst, dst_stride, col_off, row_off, plane);
+  } else {
+#if CONFIG_EXT_TX && CONFIG_RECT_TX
+#if CONFIG_AOM_HIGHBITDEPTH
+    uint16_t tmp16[MAX_SB_SIZE];
+#endif
+    uint8_t tmp[MAX_SB_SIZE];
+    const TX_SIZE sub_tx_size = txsize_sqr_map[tx_size];
+    assert(sub_tx_size < TX_SIZES);
+    assert((tx_width == wpx && tx_height == hpx) ||
+           (tx_width == (wpx >> 1) && tx_height == hpx) ||
+           (tx_width == wpx && tx_height == (hpx >> 1)));
+
+    if (tx_width < tx_height) {
+      assert(tx_height == (tx_width << 1));
+      // Predict the top square sub-block.
+      predict_square_intra_block(xd, wpx, hpx, sub_tx_size, mode, ref,
+                                 ref_stride, dst, dst_stride, col_off, row_off,
+                                 plane);
+      {
+        const int half_tx_height = tx_height >> 1;
+        const int half_txh_unit = tx_size_high_unit[tx_size] >> 1;
+        // Cast away const to modify 'ref' temporarily; will be restored later.
+        uint8_t *src_2 = (uint8_t *)ref + half_tx_height * ref_stride;
+        uint8_t *dst_2 = dst + half_tx_height * dst_stride;
+        const int row_off_2 = row_off + half_txh_unit;
+        // Save the last row of top square sub-block as 'above' row for bottom
+        // square sub-block.
+        if (src_2 != dst_2 || ref_stride != dst_stride) {
+#if CONFIG_AOM_HIGHBITDEPTH
+          if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+            uint16_t *src_2_16 = CONVERT_TO_SHORTPTR(src_2);
+            uint16_t *dst_2_16 = CONVERT_TO_SHORTPTR(dst_2);
+            memcpy(tmp16, src_2_16 - ref_stride, tx_width * sizeof(*src_2_16));
+            memcpy(src_2_16 - ref_stride, dst_2_16 - dst_stride,
+                   tx_width * sizeof(*src_2_16));
+          } else {
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+            memcpy(tmp, src_2 - ref_stride, tx_width * sizeof(*src_2));
+            memcpy(src_2 - ref_stride, dst_2 - dst_stride,
+                   tx_width * sizeof(*src_2));
+#if CONFIG_AOM_HIGHBITDEPTH
+          }
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+        }
+        // Predict the bottom square sub-block.
+        predict_square_intra_block(xd, wpx, hpx, sub_tx_size, mode, src_2,
+                                   ref_stride, dst_2, dst_stride, col_off,
+                                   row_off_2, plane);
+        // Restore the last row of top square sub-block.
+        if (src_2 != dst_2 || ref_stride != dst_stride) {
+#if CONFIG_AOM_HIGHBITDEPTH
+          if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+            uint16_t *src_2_16 = CONVERT_TO_SHORTPTR(src_2);
+            memcpy(src_2_16 - ref_stride, tmp16, tx_width * sizeof(*src_2_16));
+          } else {
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+            memcpy(src_2 - ref_stride, tmp, tx_width * sizeof(*src_2));
+#if CONFIG_AOM_HIGHBITDEPTH
+          }
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+        }
+      }
+    } else {  // tx_width > tx_height
+      assert(tx_width == (tx_height << 1));
+      // Predict the left square sub-block
+      predict_square_intra_block(xd, wpx, hpx, sub_tx_size, mode, ref,
+                                 ref_stride, dst, dst_stride, col_off, row_off,
+                                 plane);
+      {
+        int i;
+        const int half_tx_width = tx_width >> 1;
+        const int half_txw_unit = tx_size_wide_unit[tx_size] >> 1;
+        // Cast away const to modify 'ref' temporarily; will be restored later.
+        uint8_t *src_2 = (uint8_t *)ref + half_tx_width;
+        uint8_t *dst_2 = dst + half_tx_width;
+        const int col_off_2 = col_off + half_txw_unit;
+        // Save the last column of left square sub-block as 'left' column for
+        // right square sub-block.
+        if (src_2 != dst_2 || ref_stride != dst_stride) {
+#if CONFIG_AOM_HIGHBITDEPTH
+          if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+            uint16_t *src_2_16 = CONVERT_TO_SHORTPTR(src_2);
+            uint16_t *dst_2_16 = CONVERT_TO_SHORTPTR(dst_2);
+            for (i = 0; i < tx_height; ++i) {
+              tmp16[i] = src_2_16[i * ref_stride - 1];
+              src_2_16[i * ref_stride - 1] = dst_2_16[i * dst_stride - 1];
+            }
+          } else {
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+            for (i = 0; i < tx_height; ++i) {
+              tmp[i] = src_2[i * ref_stride - 1];
+              src_2[i * ref_stride - 1] = dst_2[i * dst_stride - 1];
+            }
+#if CONFIG_AOM_HIGHBITDEPTH
+          }
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+        }
+        // Predict the right square sub-block.
+        predict_square_intra_block(xd, wpx, hpx, sub_tx_size, mode, src_2,
+                                   ref_stride, dst_2, dst_stride, col_off_2,
+                                   row_off, plane);
+        // Restore the last column of left square sub-block.
+        if (src_2 != dst_2 || ref_stride != dst_stride) {
+#if CONFIG_AOM_HIGHBITDEPTH
+          if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+            uint16_t *src_2_16 = CONVERT_TO_SHORTPTR(src_2);
+            for (i = 0; i < tx_height; ++i) {
+              src_2_16[i * ref_stride - 1] = tmp16[i];
+            }
+          } else {
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+            for (i = 0; i < tx_height; ++i) {
+              src_2[i * ref_stride - 1] = tmp[i];
+            }
+#if CONFIG_AOM_HIGHBITDEPTH
+          }
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+        }
+      }
+    }
+#else
+    assert(0);
+#endif  // CONFIG_EXT_TX && CONFIG_RECT_TX
+  }
+}
+
 void av1_init_intra_predictors(void) {
   once(av1_init_intra_predictors_internal);
 }
diff --git a/av1/common/scan.c b/av1/common/scan.c
index b5cde7c..7522a72 100644
--- a/av1/common/scan.c
+++ b/av1/common/scan.c
@@ -47,7 +47,6 @@
   17, 15, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
 };
 
-#if CONFIG_EXT_TX
 DECLARE_ALIGNED(16, static const int16_t, mcol_scan_4x8[32]) = {
   0, 4, 8,  12, 16, 20, 24, 28, 1, 5, 9,  13, 17, 21, 25, 29,
   2, 6, 10, 14, 18, 22, 26, 30, 3, 7, 11, 15, 19, 23, 27, 31,
@@ -57,14 +56,12 @@
   0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
   16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
 };
-#endif  // CONFIG_EXT_TX
 
 DECLARE_ALIGNED(16, static const int16_t, default_scan_8x4[32]) = {
   0,  1,  8,  9, 2,  16, 10, 17, 18, 3,  24, 11, 25, 19, 26, 4,
   12, 27, 20, 5, 28, 13, 21, 29, 6,  14, 22, 30, 7,  15, 23, 31,
 };
 
-#if CONFIG_EXT_TX
 DECLARE_ALIGNED(16, static const int16_t, mcol_scan_8x4[32]) = {
   0, 8,  16, 24, 1, 9,  17, 25, 2, 10, 18, 26, 3, 11, 19, 27,
   4, 12, 20, 28, 5, 13, 21, 29, 6, 14, 22, 30, 7, 15, 23, 31,
@@ -74,7 +71,6 @@
   0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
   16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
 };
-#endif
 
 DECLARE_ALIGNED(16, static const int16_t, default_scan_4x16[64]) = {
   0,  1,  4,  2,  5,  8,  3,  6,  9,  12, 7,  10, 13, 16, 11, 14,
@@ -306,7 +302,6 @@
   122, 63, 78,  93,  108, 123, 79, 94, 109, 124, 95,  110, 125, 111, 126, 127,
 };
 
-#if CONFIG_EXT_TX
 DECLARE_ALIGNED(16, static const int16_t, mcol_scan_8x16[128]) = {
   0, 8,  16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96,  104, 112, 120,
   1, 9,  17, 25, 33, 41, 49, 57, 65, 73, 81, 89, 97,  105, 113, 121,
@@ -352,7 +347,6 @@
   105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119,
   120, 121, 122, 123, 124, 125, 126, 127,
 };
-#endif
 
 DECLARE_ALIGNED(16, static const int16_t, default_scan_16x32[512]) = {
   0,   1,   16,  2,   17,  32,  3,   18,  33,  48,  4,   19,  34,  49,  64,
@@ -430,7 +424,6 @@
   510, 511,
 };
 
-#if CONFIG_EXT_TX
 DECLARE_ALIGNED(16, static const int16_t, mcol_scan_16x32[512]) = {
   0,   16,  32,  48,  64,  80,  96,  112, 128, 144, 160, 176, 192, 208, 224,
   240, 256, 272, 288, 304, 320, 336, 352, 368, 384, 400, 416, 432, 448, 464,
@@ -579,7 +572,6 @@
   495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509,
   510, 511,
 };
-#endif  // CONFIG_EXT_TX
 
 DECLARE_ALIGNED(16, static const int16_t, default_scan_16x16[256]) = {
   0,   16,  1,   32,  17,  2,   48,  33,  18,  3,   64,  34,  49,  19,  65,
@@ -1548,7 +1540,6 @@
   24, 22, 25, 23, 26, 24, 24, 25, 28, 26, 29, 27, 30, 0,  0
 };
 
-#if CONFIG_EXT_TX
 DECLARE_ALIGNED(16, static const int16_t,
                 mcol_scan_4x8_neighbors[33 * MAX_NEIGHBORS]) = {
   0, 0, 0,  0,  4,  4,  8,  8,  12, 12, 16, 16, 20, 20, 24, 24, 0,
@@ -1564,7 +1555,6 @@
   13, 16, 14, 17, 15, 18, 16, 16, 17, 20, 18, 21, 19, 22, 20, 20, 21,
   24, 22, 25, 23, 26, 24, 24, 25, 28, 26, 29, 27, 30, 0,  0
 };
-#endif
 
 DECLARE_ALIGNED(16, static const int16_t,
                 default_scan_8x4_neighbors[33 * MAX_NEIGHBORS]) = {
@@ -1574,7 +1564,6 @@
   13, 14, 21, 22, 29, 6, 6,  7,  14, 15, 22, 23, 30, 0,  0
 };
 
-#if CONFIG_EXT_TX
 DECLARE_ALIGNED(16, static const int16_t,
                 mcol_scan_8x4_neighbors[33 * MAX_NEIGHBORS]) = {
   0,  0,  0,  0,  8,  8,  16, 16, 0,  0,  1,  8,  9,  16, 17, 24, 1,
@@ -1590,7 +1579,6 @@
   9,  16, 10, 17, 11, 18, 12, 19, 13, 20, 14, 21, 15, 22, 16, 16, 17,
   24, 18, 25, 19, 26, 20, 27, 21, 28, 22, 29, 23, 30, 0,  0
 };
-#endif  // CONFIG_EXT_TX
 
 DECLARE_ALIGNED(16, static const int16_t,
                 default_scan_4x16_neighbors[65 * MAX_NEIGHBORS]) = {
@@ -1995,7 +1983,6 @@
   126, 0,   0
 };
 
-#if CONFIG_EXT_TX
 DECLARE_ALIGNED(16, static const int16_t,
                 mcol_scan_8x16_neighbors[129 * MAX_NEIGHBORS]) = {
   0,  0,  0,  0,  8,  8,  16, 16, 24, 24,  32,  32,  40,  40,  48,  48,
@@ -2081,7 +2068,6 @@
   104, 119, 105, 120, 106, 121, 107, 122, 108, 123, 109, 124, 110, 125, 111,
   126, 0,   0
 };
-#endif
 
 DECLARE_ALIGNED(16, static const int16_t,
                 default_scan_16x32_neighbors[513 * MAX_NEIGHBORS]) = {
@@ -2229,7 +2215,6 @@
   478, 509, 479, 510, 0,   0
 };
 
-#if CONFIG_EXT_TX
 DECLARE_ALIGNED(16, static const int16_t,
                 mcol_scan_16x32_neighbors[513 * MAX_NEIGHBORS]) = {
   0,   0,   0,   0,   16,  16,  32,  32,  48,  48,  64,  64,  80,  80,  96,
@@ -2521,7 +2506,6 @@
   501, 471, 502, 472, 503, 473, 504, 474, 505, 475, 506, 476, 507, 477, 508,
   478, 509, 479, 510, 0,   0
 };
-#endif  // CONFIG_EXT_TX
 
 #if CONFIG_EXT_TX
 DECLARE_ALIGNED(16, static const int16_t,
@@ -4297,7 +4281,6 @@
   15, 16, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
 };
 
-#if CONFIG_EXT_TX
 DECLARE_ALIGNED(16, static const int16_t, av1_mcol_iscan_4x8[32]) = {
   0, 8,  16, 24, 1, 9,  17, 25, 2, 10, 18, 26, 3, 11, 19, 27,
   4, 12, 20, 28, 5, 13, 21, 29, 6, 14, 22, 30, 7, 15, 23, 31,
@@ -4307,14 +4290,12 @@
   0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
   16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
 };
-#endif
 
 DECLARE_ALIGNED(16, static const int16_t, av1_default_iscan_8x4[32]) = {
   0, 1, 4, 9,  15, 19, 24, 28, 2,  3,  6,  11, 16, 21, 25, 29,
   5, 7, 8, 13, 18, 22, 26, 30, 10, 12, 14, 17, 20, 23, 27, 31,
 };
 
-#if CONFIG_EXT_TX
 DECLARE_ALIGNED(16, static const int16_t, av1_mcol_iscan_8x4[32]) = {
   0, 4, 8,  12, 16, 20, 24, 28, 1, 5, 9,  13, 17, 21, 25, 29,
   2, 6, 10, 14, 18, 22, 26, 30, 3, 7, 11, 15, 19, 23, 27, 31,
@@ -4324,7 +4305,6 @@
   0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
   16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
 };
-#endif  // CONFIG_EXT_TX
 
 DECLARE_ALIGNED(16, static const int16_t, av1_default_iscan_4x16[64]) = {
   0,  1,  3,  6,  2,  4,  7,  10, 5,  8,  11, 14, 9,  12, 15, 18,
@@ -4554,7 +4534,6 @@
   35, 43, 51, 59, 67, 75, 83, 91, 99, 106, 112, 117, 121, 124, 126, 127,
 };
 
-#if CONFIG_EXT_TX
 DECLARE_ALIGNED(16, static const int16_t, av1_mcol_iscan_8x16[128]) = {
   0,  16, 32, 48, 64, 80, 96,  112, 1,  17, 33, 49, 65, 81, 97,  113,
   2,  18, 34, 50, 66, 82, 98,  114, 3,  19, 35, 51, 67, 83, 99,  115,
@@ -4600,7 +4579,6 @@
   105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119,
   120, 121, 122, 123, 124, 125, 126, 127,
 };
-#endif
 
 DECLARE_ALIGNED(16, static const int16_t, av1_default_iscan_16x32[512]) = {
   0,   1,   3,   6,   10,  15,  21,  28,  36,  45,  55,  66,  78,  91,  105,
@@ -4678,7 +4656,6 @@
   510, 511,
 };
 
-#if CONFIG_EXT_TX
 DECLARE_ALIGNED(16, static const int16_t, av1_mcol_iscan_16x32[512]) = {
   0,  32, 64, 96,  128, 160, 192, 224, 256, 288, 320, 352, 384, 416, 448, 480,
   1,  33, 65, 97,  129, 161, 193, 225, 257, 289, 321, 353, 385, 417, 449, 481,
@@ -4828,8 +4805,6 @@
   510, 511,
 };
 
-#endif  // CONFIG_EXT_TX
-
 #if CONFIG_EXT_TX
 DECLARE_ALIGNED(16, static const int16_t, av1_mcol_iscan_16x16[256]) = {
   0,  16, 32, 48, 64, 80, 96,  112, 128, 144, 160, 176, 192, 208, 224, 240,
@@ -5756,7 +5731,7 @@
 #endif  // CONFIG_TX64X64
 };
 
-const SCAN_ORDER av1_intra_scan_orders[TX_SIZES][TX_TYPES] = {
+const SCAN_ORDER av1_intra_scan_orders[TX_SIZES_ALL][TX_TYPES] = {
 #if CONFIG_CB4X4
   {
       // TX_2X2
@@ -5909,8 +5884,162 @@
       { default_scan_64x64, av1_default_iscan_64x64,
         default_scan_64x64_neighbors },
 #endif  // CONFIG_EXT_TX
-  }
+  },
 #endif  // CONFIG_TX64X64
+  {
+      // TX_4X8
+      { default_scan_4x8, av1_default_iscan_4x8, default_scan_4x8_neighbors },
+      { mrow_scan_4x8, av1_mrow_iscan_4x8, mrow_scan_4x8_neighbors },
+      { mcol_scan_4x8, av1_mcol_iscan_4x8, mcol_scan_4x8_neighbors },
+      { default_scan_4x8, av1_default_iscan_4x8, default_scan_4x8_neighbors },
+#if CONFIG_EXT_TX
+      { default_scan_4x8, av1_default_iscan_4x8, default_scan_4x8_neighbors },
+      { default_scan_4x8, av1_default_iscan_4x8, default_scan_4x8_neighbors },
+      { default_scan_4x8, av1_default_iscan_4x8, default_scan_4x8_neighbors },
+      { default_scan_4x8, av1_default_iscan_4x8, default_scan_4x8_neighbors },
+      { default_scan_4x8, av1_default_iscan_4x8, default_scan_4x8_neighbors },
+      { mrow_scan_4x8, av1_mrow_iscan_4x8, mrow_scan_4x8_neighbors },
+      { mrow_scan_4x8, av1_mrow_iscan_4x8, mrow_scan_4x8_neighbors },
+      { mcol_scan_4x8, av1_mcol_iscan_4x8, mcol_scan_4x8_neighbors },
+      { mrow_scan_4x8, av1_mrow_iscan_4x8, mrow_scan_4x8_neighbors },
+      { mcol_scan_4x8, av1_mcol_iscan_4x8, mcol_scan_4x8_neighbors },
+      { mrow_scan_4x8, av1_mrow_iscan_4x8, mrow_scan_4x8_neighbors },
+      { mcol_scan_4x8, av1_mcol_iscan_4x8, mcol_scan_4x8_neighbors },
+#endif  // CONFIG_EXT_TX
+  },
+  {
+      // TX_8X4
+      { default_scan_8x4, av1_default_iscan_8x4, default_scan_8x4_neighbors },
+      { mrow_scan_8x4, av1_mrow_iscan_8x4, mrow_scan_8x4_neighbors },
+      { mcol_scan_8x4, av1_mcol_iscan_8x4, mcol_scan_8x4_neighbors },
+      { default_scan_8x4, av1_default_iscan_8x4, default_scan_8x4_neighbors },
+#if CONFIG_EXT_TX
+      { default_scan_8x4, av1_default_iscan_8x4, default_scan_8x4_neighbors },
+      { default_scan_8x4, av1_default_iscan_8x4, default_scan_8x4_neighbors },
+      { default_scan_8x4, av1_default_iscan_8x4, default_scan_8x4_neighbors },
+      { default_scan_8x4, av1_default_iscan_8x4, default_scan_8x4_neighbors },
+      { default_scan_8x4, av1_default_iscan_8x4, default_scan_8x4_neighbors },
+      { mrow_scan_8x4, av1_mrow_iscan_8x4, mrow_scan_8x4_neighbors },
+      { mrow_scan_8x4, av1_mrow_iscan_8x4, mrow_scan_8x4_neighbors },
+      { mcol_scan_8x4, av1_mcol_iscan_8x4, mcol_scan_8x4_neighbors },
+      { mrow_scan_8x4, av1_mrow_iscan_8x4, mrow_scan_8x4_neighbors },
+      { mcol_scan_8x4, av1_mcol_iscan_8x4, mcol_scan_8x4_neighbors },
+      { mrow_scan_8x4, av1_mrow_iscan_8x4, mrow_scan_8x4_neighbors },
+      { mcol_scan_8x4, av1_mcol_iscan_8x4, mcol_scan_8x4_neighbors },
+#endif  // CONFIG_EXT_TX
+  },
+  {
+      // TX_8X16
+      { default_scan_8x16, av1_default_iscan_8x16,
+        default_scan_8x16_neighbors },
+      { mrow_scan_8x16, av1_mrow_iscan_8x16, mrow_scan_8x16_neighbors },
+      { mcol_scan_8x16, av1_mcol_iscan_8x16, mcol_scan_8x16_neighbors },
+      { default_scan_8x16, av1_default_iscan_8x16,
+        default_scan_8x16_neighbors },
+#if CONFIG_EXT_TX
+      { default_scan_8x16, av1_default_iscan_8x16,
+        default_scan_8x16_neighbors },
+      { default_scan_8x16, av1_default_iscan_8x16,
+        default_scan_8x16_neighbors },
+      { default_scan_8x16, av1_default_iscan_8x16,
+        default_scan_8x16_neighbors },
+      { default_scan_8x16, av1_default_iscan_8x16,
+        default_scan_8x16_neighbors },
+      { default_scan_8x16, av1_default_iscan_8x16,
+        default_scan_8x16_neighbors },
+      { mrow_scan_8x16, av1_mrow_iscan_8x16, mrow_scan_8x16_neighbors },
+      { mrow_scan_8x16, av1_mrow_iscan_8x16, mrow_scan_8x16_neighbors },
+      { mcol_scan_8x16, av1_mcol_iscan_8x16, mcol_scan_8x16_neighbors },
+      { mrow_scan_8x16, av1_mrow_iscan_8x16, mrow_scan_8x16_neighbors },
+      { mcol_scan_8x16, av1_mcol_iscan_8x16, mcol_scan_8x16_neighbors },
+      { mrow_scan_8x16, av1_mrow_iscan_8x16, mrow_scan_8x16_neighbors },
+      { mcol_scan_8x16, av1_mcol_iscan_8x16, mcol_scan_8x16_neighbors },
+#endif  // CONFIG_EXT_TX
+  },
+  {
+      // TX_16X8
+      { default_scan_16x8, av1_default_iscan_16x8,
+        default_scan_16x8_neighbors },
+      { mrow_scan_16x8, av1_mrow_iscan_16x8, mrow_scan_16x8_neighbors },
+      { mcol_scan_16x8, av1_mcol_iscan_16x8, mcol_scan_16x8_neighbors },
+      { default_scan_16x8, av1_default_iscan_16x8,
+        default_scan_16x8_neighbors },
+#if CONFIG_EXT_TX
+      { default_scan_16x8, av1_default_iscan_16x8,
+        default_scan_16x8_neighbors },
+      { default_scan_16x8, av1_default_iscan_16x8,
+        default_scan_16x8_neighbors },
+      { default_scan_16x8, av1_default_iscan_16x8,
+        default_scan_16x8_neighbors },
+      { default_scan_16x8, av1_default_iscan_16x8,
+        default_scan_16x8_neighbors },
+      { default_scan_16x8, av1_default_iscan_16x8,
+        default_scan_16x8_neighbors },
+      { mrow_scan_16x8, av1_mrow_iscan_16x8, mrow_scan_16x8_neighbors },
+      { mrow_scan_16x8, av1_mrow_iscan_16x8, mrow_scan_16x8_neighbors },
+      { mcol_scan_16x8, av1_mcol_iscan_16x8, mcol_scan_16x8_neighbors },
+      { mrow_scan_16x8, av1_mrow_iscan_16x8, mrow_scan_16x8_neighbors },
+      { mcol_scan_16x8, av1_mcol_iscan_16x8, mcol_scan_16x8_neighbors },
+      { mrow_scan_16x8, av1_mrow_iscan_16x8, mrow_scan_16x8_neighbors },
+      { mcol_scan_16x8, av1_mcol_iscan_16x8, mcol_scan_16x8_neighbors },
+#endif  // CONFIG_EXT_TX
+  },
+  {
+      // TX_16X32
+      { default_scan_16x32, av1_default_iscan_16x32,
+        default_scan_16x32_neighbors },
+      { mrow_scan_16x32, av1_mrow_iscan_16x32, mrow_scan_16x32_neighbors },
+      { mcol_scan_16x32, av1_mcol_iscan_16x32, mcol_scan_16x32_neighbors },
+      { default_scan_16x32, av1_default_iscan_16x32,
+        default_scan_16x32_neighbors },
+#if CONFIG_EXT_TX
+      { default_scan_16x32, av1_default_iscan_16x32,
+        default_scan_16x32_neighbors },
+      { default_scan_16x32, av1_default_iscan_16x32,
+        default_scan_16x32_neighbors },
+      { default_scan_16x32, av1_default_iscan_16x32,
+        default_scan_16x32_neighbors },
+      { default_scan_16x32, av1_default_iscan_16x32,
+        default_scan_16x32_neighbors },
+      { default_scan_16x32, av1_default_iscan_16x32,
+        default_scan_16x32_neighbors },
+      { mrow_scan_16x32, av1_mrow_iscan_16x32, mrow_scan_16x32_neighbors },
+      { mrow_scan_16x32, av1_mrow_iscan_16x32, mrow_scan_16x32_neighbors },
+      { mcol_scan_16x32, av1_mcol_iscan_16x32, mcol_scan_16x32_neighbors },
+      { mrow_scan_16x32, av1_mrow_iscan_16x32, mrow_scan_16x32_neighbors },
+      { mcol_scan_16x32, av1_mcol_iscan_16x32, mcol_scan_16x32_neighbors },
+      { mrow_scan_16x32, av1_mrow_iscan_16x32, mrow_scan_16x32_neighbors },
+      { mcol_scan_16x32, av1_mcol_iscan_16x32, mcol_scan_16x32_neighbors },
+#endif  // CONFIG_EXT_TX
+  },
+  {
+      // TX_32X16
+      { default_scan_32x16, av1_default_iscan_32x16,
+        default_scan_32x16_neighbors },
+      { mrow_scan_32x16, av1_mrow_iscan_32x16, mrow_scan_32x16_neighbors },
+      { mcol_scan_32x16, av1_mcol_iscan_32x16, mcol_scan_32x16_neighbors },
+      { default_scan_32x16, av1_default_iscan_32x16,
+        default_scan_32x16_neighbors },
+#if CONFIG_EXT_TX
+      { default_scan_32x16, av1_default_iscan_32x16,
+        default_scan_32x16_neighbors },
+      { default_scan_32x16, av1_default_iscan_32x16,
+        default_scan_32x16_neighbors },
+      { default_scan_32x16, av1_default_iscan_32x16,
+        default_scan_32x16_neighbors },
+      { default_scan_32x16, av1_default_iscan_32x16,
+        default_scan_32x16_neighbors },
+      { default_scan_32x16, av1_default_iscan_32x16,
+        default_scan_32x16_neighbors },
+      { mrow_scan_32x16, av1_mrow_iscan_32x16, mrow_scan_32x16_neighbors },
+      { mrow_scan_32x16, av1_mrow_iscan_32x16, mrow_scan_32x16_neighbors },
+      { mcol_scan_32x16, av1_mcol_iscan_32x16, mcol_scan_32x16_neighbors },
+      { mrow_scan_32x16, av1_mrow_iscan_32x16, mrow_scan_32x16_neighbors },
+      { mcol_scan_32x16, av1_mcol_iscan_32x16, mcol_scan_32x16_neighbors },
+      { mrow_scan_32x16, av1_mrow_iscan_32x16, mrow_scan_32x16_neighbors },
+      { mcol_scan_32x16, av1_mcol_iscan_32x16, mcol_scan_32x16_neighbors },
+#endif  // CONFIG_EXT_TX
+  },
 };
 
 const SCAN_ORDER av1_inter_scan_orders[TX_SIZES_ALL][TX_TYPES] = {
diff --git a/av1/common/scan.h b/av1/common/scan.h
index 71868d0..9047359 100644
--- a/av1/common/scan.h
+++ b/av1/common/scan.h
@@ -26,7 +26,7 @@
 #define MAX_NEIGHBORS 2
 
 extern const SCAN_ORDER av1_default_scan_orders[TX_SIZES];
-extern const SCAN_ORDER av1_intra_scan_orders[TX_SIZES][TX_TYPES];
+extern const SCAN_ORDER av1_intra_scan_orders[TX_SIZES_ALL][TX_TYPES];
 extern const SCAN_ORDER av1_inter_scan_orders[TX_SIZES_ALL][TX_TYPES];
 
 #if CONFIG_ADAPT_SCAN
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index 48c0f5f..7d822ae 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -466,7 +466,7 @@
   PREDICTION_MODE mode = (plane == 0) ? mbmi->mode : mbmi->uv_mode;
   PLANE_TYPE plane_type = (plane == 0) ? PLANE_TYPE_Y : PLANE_TYPE_UV;
   uint8_t *dst;
-  int block_idx = (row << 1) + col;
+  const int block_idx = (row << 1) + col;
 #if CONFIG_PVQ
   (void)cm;
   (void)r;
@@ -475,7 +475,7 @@
 
 #if !CONFIG_CB4X4
   if (mbmi->sb_type < BLOCK_8X8)
-    if (plane == 0) mode = xd->mi[0]->bmi[(row << 1) + col].as_mode;
+    if (plane == 0) mode = xd->mi[0]->bmi[block_idx].as_mode;
 #endif
 
   av1_predict_intra_block(xd, pd->width, pd->height, tx_size, mode, dst,
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index 7c59417..84de6aa 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -384,45 +384,32 @@
                                      int tx_size_cat, aom_reader *r) {
   FRAME_COUNTS *counts = xd->counts;
   const int ctx = get_tx_size_context(xd);
-  int depth = aom_read_tree(r, av1_tx_size_tree[tx_size_cat],
-                            cm->fc->tx_size_probs[tx_size_cat][ctx], ACCT_STR);
-  TX_SIZE tx_size = depth_to_tx_size(depth);
+  const int depth =
+      aom_read_tree(r, av1_tx_size_tree[tx_size_cat],
+                    cm->fc->tx_size_probs[tx_size_cat][ctx], ACCT_STR);
+  const TX_SIZE tx_size = depth_to_tx_size(depth);
+#if CONFIG_RECT_TX
+  assert(!is_rect_tx(tx_size));
+#endif  // CONFIG_RECT_TX
   if (counts) ++counts->tx_size[tx_size_cat][ctx][depth];
   return tx_size;
 }
 
-static TX_SIZE read_tx_size_intra(AV1_COMMON *cm, MACROBLOCKD *xd,
-                                  aom_reader *r) {
-  TX_MODE tx_mode = cm->tx_mode;
-  BLOCK_SIZE bsize = xd->mi[0]->mbmi.sb_type;
-  if (xd->lossless[xd->mi[0]->mbmi.segment_id]) return TX_4X4;
-  if (bsize >= BLOCK_8X8) {
-    if (tx_mode == TX_MODE_SELECT) {
-      const TX_SIZE tx_size =
-          read_selected_tx_size(cm, xd, intra_tx_size_cat_lookup[bsize], r);
-      assert(tx_size <= max_txsize_lookup[bsize]);
-      return tx_size;
-    } else {
-      return tx_size_from_tx_mode(bsize, cm->tx_mode, 0);
-    }
-  } else {
-    return TX_4X4;
-  }
-}
-
-static TX_SIZE read_tx_size_inter(AV1_COMMON *cm, MACROBLOCKD *xd,
-                                  int allow_select, aom_reader *r) {
-  TX_MODE tx_mode = cm->tx_mode;
-  BLOCK_SIZE bsize = xd->mi[0]->mbmi.sb_type;
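+// Returns the transform size for the current block: read from the bitstream
+// when TX_MODE_SELECT permits a choice, otherwise derived from the TX mode
+// and block size.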
+static TX_SIZE read_tx_size(AV1_COMMON *cm, MACROBLOCKD *xd, int is_inter,
+                            int allow_select_inter, aom_reader *r) {
+  const TX_MODE tx_mode = cm->tx_mode;
+  const BLOCK_SIZE bsize = xd->mi[0]->mbmi.sb_type;
   if (xd->lossless[xd->mi[0]->mbmi.segment_id]) return TX_4X4;
 #if CONFIG_CB4X4 && CONFIG_VAR_TX
   if (bsize > BLOCK_4X4) {
 #else
   if (bsize >= BLOCK_8X8) {
-#endif
-    if (allow_select && tx_mode == TX_MODE_SELECT) {
+#endif  // CONFIG_CB4X4 && CONFIG_VAR_TX
+    if ((!is_inter || allow_select_inter) && tx_mode == TX_MODE_SELECT) {
+      const int32_t tx_size_cat = is_inter ? inter_tx_size_cat_lookup[bsize]
+                                           : intra_tx_size_cat_lookup[bsize];
       const TX_SIZE coded_tx_size =
-          read_selected_tx_size(cm, xd, inter_tx_size_cat_lookup[bsize], r);
+          read_selected_tx_size(cm, xd, tx_size_cat, r);
 #if CONFIG_EXT_TX && CONFIG_RECT_TX
       if (coded_tx_size > max_txsize_lookup[bsize]) {
         assert(coded_tx_size == max_txsize_lookup[bsize] + 1);
@@ -433,7 +420,7 @@
 #endif  // CONFIG_EXT_TX && CONFIG_RECT_TX
       return coded_tx_size;
     } else {
-      return tx_size_from_tx_mode(bsize, cm->tx_mode, 1);
+      return tx_size_from_tx_mode(bsize, tx_mode, is_inter);
     }
   } else {
 #if CONFIG_EXT_TX && CONFIG_RECT_TX
@@ -441,7 +428,7 @@
     return max_txsize_rect_lookup[bsize];
 #else
     return TX_4X4;
-#endif
+#endif  // CONFIG_EXT_TX && CONFIG_RECT_TX
   }
 }
 
@@ -711,6 +698,7 @@
 #endif
   if (!FIXED_TX_TYPE) {
 #if CONFIG_EXT_TX
+    const TX_SIZE square_tx_size = txsize_sqr_map[tx_size];
     if (get_ext_tx_types(tx_size, mbmi->sb_type, inter_block) > 1 &&
         cm->base_qindex > 0 && !mbmi->skip &&
 #if CONFIG_SUPERTX
@@ -724,19 +712,19 @@
         if (eset > 0) {
           mbmi->tx_type = aom_read_tree(
               r, av1_ext_tx_inter_tree[eset],
-              cm->fc->inter_ext_tx_prob[eset][txsize_sqr_map[tx_size]],
-              ACCT_STR);
+              cm->fc->inter_ext_tx_prob[eset][square_tx_size], ACCT_STR);
           if (counts)
-            ++counts->inter_ext_tx[eset][txsize_sqr_map[tx_size]]
-                                  [mbmi->tx_type];
+            ++counts->inter_ext_tx[eset][square_tx_size][mbmi->tx_type];
         }
       } else if (ALLOW_INTRA_EXT_TX) {
         if (eset > 0) {
           mbmi->tx_type = aom_read_tree(
               r, av1_ext_tx_intra_tree[eset],
-              cm->fc->intra_ext_tx_prob[eset][tx_size][mbmi->mode], ACCT_STR);
+              cm->fc->intra_ext_tx_prob[eset][square_tx_size][mbmi->mode],
+              ACCT_STR);
           if (counts)
-            ++counts->intra_ext_tx[eset][tx_size][mbmi->mode][mbmi->tx_type];
+            ++counts->intra_ext_tx[eset][square_tx_size][mbmi->mode]
+                                  [mbmi->tx_type];
         }
       }
     } else {
@@ -807,7 +795,7 @@
   }
 #endif
 
-  mbmi->tx_size = read_tx_size_intra(cm, xd, r);
+  mbmi->tx_size = read_tx_size(cm, xd, 0, 1, r);
   mbmi->ref_frame[0] = INTRA_FRAME;
   mbmi->ref_frame[1] = NONE;
 
@@ -1967,10 +1955,7 @@
           read_tx_size_vartx(cm, xd, mbmi, xd->counts, max_tx_size,
                              height != width, idy, idx, r);
     } else {
-      if (inter_block)
-        mbmi->tx_size = read_tx_size_inter(cm, xd, !mbmi->skip, r);
-      else
-        mbmi->tx_size = read_tx_size_intra(cm, xd, r);
+      mbmi->tx_size = read_tx_size(cm, xd, inter_block, !mbmi->skip, r);
 
       if (inter_block) {
         const int width = block_size_wide[bsize] >> tx_size_wide_log2[0];
@@ -1984,10 +1969,7 @@
       set_txfm_ctxs(mbmi->tx_size, xd->n8_w, xd->n8_h, mbmi->skip, xd);
     }
 #else
-  if (inter_block)
-    mbmi->tx_size = read_tx_size_inter(cm, xd, !mbmi->skip, r);
-  else
-    mbmi->tx_size = read_tx_size_intra(cm, xd, r);
+  mbmi->tx_size = read_tx_size(cm, xd, inter_block, !mbmi->skip, r);
 #endif  // CONFIG_VAR_TX
 #if CONFIG_SUPERTX
   }
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index d18c8ca..1545aae 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1169,9 +1169,10 @@
   const TX_SIZE tx_size = is_inter ? mbmi->min_tx_size : mbmi->tx_size;
 #else
   const TX_SIZE tx_size = mbmi->tx_size;
-#endif
+#endif  // CONFIG_VAR_TX
   if (!FIXED_TX_TYPE) {
 #if CONFIG_EXT_TX
+    const TX_SIZE square_tx_size = txsize_sqr_map[tx_size];
     const BLOCK_SIZE bsize = mbmi->sb_type;
     if (get_ext_tx_types(tx_size, bsize, is_inter) > 1 && cm->base_qindex > 0 &&
         !mbmi->skip &&
@@ -1182,16 +1183,19 @@
       int eset = get_ext_tx_set(tx_size, bsize, is_inter);
       if (is_inter) {
         assert(ext_tx_used_inter[eset][mbmi->tx_type]);
-        if (eset > 0)
-          av1_write_token(
-              w, av1_ext_tx_inter_tree[eset],
-              cm->fc->inter_ext_tx_prob[eset][txsize_sqr_map[tx_size]],
-              &ext_tx_inter_encodings[eset][mbmi->tx_type]);
+        if (eset > 0) {
+          av1_write_token(w, av1_ext_tx_inter_tree[eset],
+                          cm->fc->inter_ext_tx_prob[eset][square_tx_size],
+                          &ext_tx_inter_encodings[eset][mbmi->tx_type]);
+        }
       } else if (ALLOW_INTRA_EXT_TX) {
-        if (eset > 0)
-          av1_write_token(w, av1_ext_tx_intra_tree[eset],
-                          cm->fc->intra_ext_tx_prob[eset][tx_size][mbmi->mode],
-                          &ext_tx_intra_encodings[eset][mbmi->tx_type]);
+        assert(ext_tx_used_intra[eset][mbmi->tx_type]);
+        if (eset > 0) {
+          av1_write_token(
+              w, av1_ext_tx_intra_tree[eset],
+              cm->fc->intra_ext_tx_prob[eset][square_tx_size][mbmi->mode],
+              &ext_tx_intra_encodings[eset][mbmi->tx_type]);
+        }
       }
     }
 #else
@@ -2641,6 +2645,9 @@
   unsigned int(*eob_branch_ct)[REF_TYPES][COEF_BANDS][COEFF_CONTEXTS] =
       cpi->common.counts.eob_branch[tx_size];
   int i, j, k, l, m;
+#if CONFIG_RECT_TX
+  assert(!is_rect_tx(tx_size));
+#endif  // CONFIG_RECT_TX
 
   for (i = 0; i < PLANE_TYPES; ++i) {
     for (j = 0; j < REF_TYPES; ++j) {
@@ -2679,6 +2686,9 @@
 #else
   const int probwt = 1;
 #endif
+#if CONFIG_RECT_TX
+  assert(!is_rect_tx(tx_size));
+#endif  // CONFIG_RECT_TX
 
   switch (cpi->sf.use_fast_coef_updates) {
     case TWO_LOOP: {
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index b8e886b..5e29d9a 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -204,6 +204,27 @@
 #endif
 };
 
+// Converts block_index for given transform size to index of the block in raster
+// order.
+static inline int av1_block_index_to_raster_order(TX_SIZE tx_size,
+                                                  int block_idx) {
+  // For transform size 4x8, the possible block_idx values are 0 & 2, because
+  // block_idx values are incremented in steps of size 'tx_width_unit x
+  // tx_height_unit'. But, for this transform size, block_idx = 2 corresponds to
+  // block number 1 in raster order, inside an 8x8 MI block.
+  // For any other transform size, the two indices are equivalent.
+  return (tx_size == TX_4X8 && block_idx == 2) ? 1 : block_idx;
+}
+
+// Inverse of above function.
+// Note: only implemented for transform sizes 4x4, 4x8 and 8x4 right now.
+static inline int av1_raster_order_to_block_index(TX_SIZE tx_size,
+                                                  int raster_order) {
+  assert(tx_size == TX_4X4 || tx_size == TX_4X8 || tx_size == TX_8X4);
+  // We ensure that block indices are 0 & 2 if tx size is 4x8 or 8x4.
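+  // For example, with TX_4X8, raster_order 1 maps back to block index 2.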
+  return (tx_size == TX_4X4) ? raster_order : (raster_order > 0) ? 2 : 0;
+}
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 69b1a02..c5a9c41 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -5581,13 +5581,17 @@
         } else {
           intra_tx_size = tx_size_from_tx_mode(bsize, cm->tx_mode, 1);
         }
-#if CONFIG_EXT_TX && CONFIG_RECT_TX
-        ++td->counts->tx_size_implied[max_txsize_lookup[bsize]]
-                                     [txsize_sqr_up_map[tx_size]];
-#endif  // CONFIG_EXT_TX && CONFIG_RECT_TX
       } else {
+#if CONFIG_EXT_TX && CONFIG_RECT_TX
+        intra_tx_size = tx_size;
+#else
         intra_tx_size = (bsize >= BLOCK_8X8) ? tx_size : TX_4X4;
+#endif  // CONFIG_EXT_TX && CONFIG_RECT_TX
       }
+#if CONFIG_EXT_TX && CONFIG_RECT_TX
+      ++td->counts->tx_size_implied[max_txsize_lookup[bsize]]
+                                   [txsize_sqr_up_map[tx_size]];
+#endif  // CONFIG_EXT_TX && CONFIG_RECT_TX
 
       for (j = 0; j < mi_height; j++)
         for (i = 0; i < mi_width; i++)
@@ -5613,7 +5617,8 @@
           ++td->counts->inter_ext_tx[eset][txsize_sqr_map[tx_size]]
                                     [mbmi->tx_type];
         } else {
-          ++td->counts->intra_ext_tx[eset][tx_size][mbmi->mode][mbmi->tx_type];
+          ++td->counts->intra_ext_tx[eset][txsize_sqr_map[tx_size]][mbmi->mode]
+                                    [mbmi->tx_type];
         }
       }
     }
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index a869f82..a031387 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -101,7 +101,7 @@
   av1_token_state tokens[MAX_TX_SQUARE + 1][2];
   unsigned best_index[MAX_TX_SQUARE + 1][2];
   uint8_t token_cache[MAX_TX_SQUARE];
-  const tran_low_t *const coeff = BLOCK_OFFSET(mb->plane[plane].coeff, block);
+  const tran_low_t *const coeff = BLOCK_OFFSET(p->coeff, block);
   tran_low_t *const qcoeff = BLOCK_OFFSET(p->qcoeff, block);
   tran_low_t *const dqcoeff = BLOCK_OFFSET(pd->dqcoeff, block);
   const int eob = p->eobs[block];
@@ -109,7 +109,8 @@
   const int default_eob = tx_size_2d[tx_size];
   const int16_t *const dequant_ptr = pd->dequant;
   const uint8_t *const band_translate = get_band_translate(tx_size);
-  TX_TYPE tx_type = get_tx_type(plane_type, xd, block, tx_size);
+  const int block_raster_idx = av1_block_index_to_raster_order(tx_size, block);
+  TX_TYPE tx_type = get_tx_type(plane_type, xd, block_raster_idx, tx_size);
   const SCAN_ORDER *const scan_order =
       get_scan(cm, tx_size, tx_type, is_inter_block(&xd->mi[0]->mbmi));
   const int16_t *const scan = scan_order->scan;
@@ -486,7 +487,8 @@
   struct macroblockd_plane *const pd = &xd->plane[plane];
 #endif
   PLANE_TYPE plane_type = (plane == 0) ? PLANE_TYPE_Y : PLANE_TYPE_UV;
-  TX_TYPE tx_type = get_tx_type(plane_type, xd, block, tx_size);
+  const int block_raster_idx = av1_block_index_to_raster_order(tx_size, block);
+  TX_TYPE tx_type = get_tx_type(plane_type, xd, block_raster_idx, tx_size);
   const int is_inter = is_inter_block(&xd->mi[0]->mbmi);
   const SCAN_ORDER *const scan_order = get_scan(cm, tx_size, tx_type, is_inter);
   tran_low_t *const coeff = BLOCK_OFFSET(p->coeff, block);
@@ -626,8 +628,9 @@
   uint8_t *dst;
   ENTROPY_CONTEXT *a, *l;
   INV_TXFM_PARAM inv_txfm_param;
+  const int block_raster_idx = av1_block_index_to_raster_order(tx_size, block);
 #if CONFIG_PVQ
-  int tx_blk_size;
+  int tx_width_pixels, tx_height_pixels;
   int i, j;
 #endif
 #if CONFIG_VAR_TX
@@ -690,18 +693,20 @@
   if (x->pvq_skip[plane]) return;
 
   // transform block size in pixels
-  tx_blk_size = tx_size_wide[tx_size];
+  tx_width_pixels = tx_size_wide[tx_size];
+  tx_height_pixels = tx_size_high[tx_size];
 
   // Since av1 does not have separate function which does inverse transform
   // but av1_inv_txfm_add_*x*() also does addition of predicted image to
   // inverse transformed image,
   // pass blank dummy image to av1_inv_txfm_add_*x*(), i.e. set dst as zeros
-  for (j = 0; j < tx_blk_size; j++)
-    for (i = 0; i < tx_blk_size; i++) dst[j * pd->dst.stride + i] = 0;
+  for (j = 0; j < tx_height_pixels; j++)
+    for (i = 0; i < tx_width_pixels; i++) dst[j * pd->dst.stride + i] = 0;
 #endif
 
   // inverse transform parameters
-  inv_txfm_param.tx_type = get_tx_type(pd->plane_type, xd, block, tx_size);
+  inv_txfm_param.tx_type =
+      get_tx_type(pd->plane_type, xd, block_raster_idx, tx_size);
   inv_txfm_param.tx_size = tx_size;
   inv_txfm_param.eob = p->eobs[block];
   inv_txfm_param.lossless = xd->lossless[xd->mi[0]->mbmi.segment_id];
@@ -927,7 +932,9 @@
   struct macroblockd_plane *const pd = &xd->plane[plane];
   tran_low_t *dqcoeff = BLOCK_OFFSET(pd->dqcoeff, block);
   PLANE_TYPE plane_type = (plane == 0) ? PLANE_TYPE_Y : PLANE_TYPE_UV;
-  const TX_TYPE tx_type = get_tx_type(plane_type, xd, block, tx_size);
+  const int block_raster_idx = av1_block_index_to_raster_order(tx_size, block);
+  const TX_TYPE tx_type =
+      get_tx_type(plane_type, xd, block_raster_idx, tx_size);
   PREDICTION_MODE mode;
   const int diff_stride = block_size_wide[plane_bsize];
   uint8_t *src, *dst;
@@ -945,13 +952,11 @@
   int i, j;
 #endif
 
-  assert(tx1d_width == tx1d_height);
-
   dst = &pd->dst.buf[(blk_row * dst_stride + blk_col) << tx_size_wide_log2[0]];
   src = &p->src.buf[(blk_row * src_stride + blk_col) << tx_size_wide_log2[0]];
   src_diff =
       &p->src_diff[(blk_row * diff_stride + blk_col) << tx_size_wide_log2[0]];
-  mode = plane == 0 ? get_y_mode(xd->mi[0], block) : mbmi->uv_mode;
+  mode = (plane == 0) ? get_y_mode(xd->mi[0], block_raster_idx) : mbmi->uv_mode;
   av1_predict_intra_block(xd, pd->width, pd->height, tx_size, mode, dst,
                           dst_stride, dst, dst_stride, blk_col, blk_row, plane);
 
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 9418017..5369e2b 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -1076,8 +1076,11 @@
       const PLANE_TYPE plane_type = plane == 0 ? PLANE_TYPE_Y : PLANE_TYPE_UV;
 
       INV_TXFM_PARAM inv_txfm_param;
+      const int block_raster_idx =
+          av1_block_index_to_raster_order(tx_size, block);
 
-      inv_txfm_param.tx_type = get_tx_type(plane_type, xd, block, tx_size);
+      inv_txfm_param.tx_type =
+          get_tx_type(plane_type, xd, block_raster_idx, tx_size);
       inv_txfm_param.tx_size = tx_size;
       inv_txfm_param.eob = eob;
       inv_txfm_param.lossless = xd->lossless[mbmi->segment_id];
@@ -1360,6 +1363,29 @@
 }
 #endif  // CONFIG_SUPERTX
 
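+// Returns the rate cost of explicitly signalling 'tx_size' for this block, or
+// 0 when the transform size is not signalled (not TX_MODE_SELECT, or block
+// size below 8x8).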
+static int tx_size_cost(const AV1_COMP *const cpi, MACROBLOCK *x,
+                        BLOCK_SIZE bsize, TX_SIZE tx_size) {
+  const AV1_COMMON *const cm = &cpi->common;
+  MACROBLOCKD *const xd = &x->e_mbd;
+  MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
+
+  const int tx_select =
+      cm->tx_mode == TX_MODE_SELECT && mbmi->sb_type >= BLOCK_8X8;
+
+  if (tx_select) {
+    const int is_inter = is_inter_block(mbmi);
+    const int tx_size_cat = is_inter ? inter_tx_size_cat_lookup[bsize]
+                                     : intra_tx_size_cat_lookup[bsize];
+    const TX_SIZE coded_tx_size = txsize_sqr_up_map[tx_size];
+    const int depth = tx_size_to_depth(coded_tx_size);
+    const int tx_size_ctx = get_tx_size_context(xd);
+    const int r_tx_size = cpi->tx_size_cost[tx_size_cat][tx_size_ctx][depth];
+    return r_tx_size;
+  } else {
+    return 0;
+  }
+}
+
 static int64_t txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
                         RD_STATS *rd_stats, int64_t ref_best_rd, BLOCK_SIZE bs,
                         TX_TYPE tx_type, int tx_size) {
@@ -1370,16 +1396,10 @@
   aom_prob skip_prob = av1_get_skip_prob(cm, xd);
   int s0, s1;
   const int is_inter = is_inter_block(mbmi);
-
-  const int tx_size_cat =
-      is_inter ? inter_tx_size_cat_lookup[bs] : intra_tx_size_cat_lookup[bs];
-  const TX_SIZE coded_tx_size = txsize_sqr_up_map[tx_size];
-  const int depth = tx_size_to_depth(coded_tx_size);
   const int tx_select =
       cm->tx_mode == TX_MODE_SELECT && mbmi->sb_type >= BLOCK_8X8;
-  const int tx_size_ctx = tx_select ? get_tx_size_context(xd) : 0;
-  const int r_tx_size =
-      tx_select ? cpi->tx_size_cost[tx_size_cat][tx_size_ctx][depth] : 0;
+
+  const int r_tx_size = tx_size_cost(cpi, x, bs, tx_size);
 
   assert(skip_prob > 0);
 #if CONFIG_EXT_TX && CONFIG_RECT_TX
@@ -1405,8 +1425,9 @@
                                     [mbmi->tx_type];
     } else {
       if (ext_tx_set > 0 && ALLOW_INTRA_EXT_TX)
-        rd_stats->rate += cpi->intra_tx_type_costs[ext_tx_set][mbmi->tx_size]
-                                                  [mbmi->mode][mbmi->tx_type];
+        rd_stats->rate +=
+            cpi->intra_tx_type_costs[ext_tx_set][txsize_sqr_map[mbmi->tx_size]]
+                                    [mbmi->mode][mbmi->tx_type];
     }
   }
 #else
@@ -1468,6 +1489,7 @@
 #endif  // CONFIG_RECT_TX
   int ext_tx_set;
 #endif  // CONFIG_EXT_TX
+  assert(bs >= BLOCK_8X8);
 
   if (tx_select) {
 #if CONFIG_EXT_TX && CONFIG_RECT_TX
@@ -1494,8 +1516,9 @@
   if (evaluate_rect_tx) {
     const TX_SIZE rect_tx_size = max_txsize_rect_lookup[bs];
     RD_STATS this_rd_stats;
-    ext_tx_set = get_ext_tx_set(rect_tx_size, bs, 1);
-    if (ext_tx_used_inter[ext_tx_set][tx_type]) {
+    ext_tx_set = get_ext_tx_set(rect_tx_size, bs, is_inter);
+    if ((is_inter && ext_tx_used_inter[ext_tx_set][tx_type]) ||
+        (!is_inter && ext_tx_used_intra[ext_tx_set][tx_type])) {
       rd = txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs, tx_type,
                     rect_tx_size);
       best_tx_size = rect_tx_size;
@@ -1651,13 +1674,15 @@
         if (is_inter) {
           if (ext_tx_set > 0)
             this_rd_stats.rate +=
-                cpi->inter_tx_type_costs[ext_tx_set][mbmi->tx_size]
+                cpi->inter_tx_type_costs[ext_tx_set]
+                                        [txsize_sqr_map[mbmi->tx_size]]
                                         [mbmi->tx_type];
         } else {
           if (ext_tx_set > 0 && ALLOW_INTRA_EXT_TX)
             this_rd_stats.rate +=
-                cpi->intra_tx_type_costs[ext_tx_set][mbmi->tx_size][mbmi->mode]
-                                        [mbmi->tx_type];
+                cpi->intra_tx_type_costs[ext_tx_set]
+                                        [txsize_sqr_map[mbmi->tx_size]]
+                                        [mbmi->mode][mbmi->tx_type];
         }
       }
 
@@ -1977,10 +2002,7 @@
       }
       this_rd = RDCOST(x->rdmult, x->rddiv, this_rate, tokenonly_rd_stats.dist);
       if (!xd->lossless[mbmi->segment_id] && mbmi->sb_type >= BLOCK_8X8) {
-        tokenonly_rd_stats.rate -=
-            cpi->tx_size_cost[max_txsize_lookup[bsize] - TX_8X8]
-                             [get_tx_size_context(xd)]
-                             [tx_size_to_depth(mbmi->tx_size)];
+        tokenonly_rd_stats.rate -= tx_size_cost(cpi, x, bsize, mbmi->tx_size);
       }
       if (this_rd < *best_rd) {
         *best_rd = this_rd;
@@ -2005,11 +2027,48 @@
 }
 #endif  // CONFIG_PALETTE
 
-static int64_t rd_pick_intra4x4block(
+// Wrappers to make function pointers usable.
+static void inv_txfm_add_4x8_wrapper(const tran_low_t *input, uint8_t *dest,
+                                     int stride, int eob, TX_TYPE tx_type,
+                                     int lossless) {
+  (void)lossless;
+  av1_inv_txfm_add_4x8(input, dest, stride, eob, tx_type);
+}
+
+static void inv_txfm_add_8x4_wrapper(const tran_low_t *input, uint8_t *dest,
+                                     int stride, int eob, TX_TYPE tx_type,
+                                     int lossless) {
+  (void)lossless;
+  av1_inv_txfm_add_8x4(input, dest, stride, eob, tx_type);
+}
+
+typedef void (*inv_txfm_func_ptr)(const tran_low_t *, uint8_t *, int, int,
+                                  TX_TYPE, int);
+#if CONFIG_AOM_HIGHBITDEPTH
+
+void highbd_inv_txfm_add_4x8_wrapper(const tran_low_t *input, uint8_t *dest,
+                                     int stride, int eob, int bd,
+                                     TX_TYPE tx_type, int is_lossless) {
+  (void)is_lossless;
+  av1_highbd_inv_txfm_add_4x8(input, dest, stride, eob, bd, tx_type);
+}
+
+void highbd_inv_txfm_add_8x4_wrapper(const tran_low_t *input, uint8_t *dest,
+                                     int stride, int eob, int bd,
+                                     TX_TYPE tx_type, int is_lossless) {
+  (void)is_lossless;
+  av1_highbd_inv_txfm_add_8x4(input, dest, stride, eob, bd, tx_type);
+}
+
+typedef void (*highbd_inv_txfm_func_ptr)(const tran_low_t *, uint8_t *, int,
+                                         int, int, TX_TYPE, int);
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+
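+// Does the RD search over intra prediction modes for a single luma sub-block
+// of a sub-8x8 partition, using the given (possibly rectangular) transform
+// size.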
+static int64_t rd_pick_intra_sub_8x8_y_subblock_mode(
     const AV1_COMP *const cpi, MACROBLOCK *x, int row, int col,
     PREDICTION_MODE *best_mode, const int *bmode_costs, ENTROPY_CONTEXT *a,
     ENTROPY_CONTEXT *l, int *bestrate, int *bestratey, int64_t *bestdistortion,
-    BLOCK_SIZE bsize, int *y_skip, int64_t rd_thresh) {
+    BLOCK_SIZE bsize, TX_SIZE tx_size, int *y_skip, int64_t rd_thresh) {
   const AV1_COMMON *const cm = &cpi->common;
   PREDICTION_MODE mode;
   MACROBLOCKD *const xd = &x->e_mbd;
@@ -2029,14 +2088,38 @@
   ENTROPY_CONTEXT ta[2], tempa[2];
   ENTROPY_CONTEXT tl[2], templ[2];
 #endif
-  const int num_4x4_blocks_wide = num_4x4_blocks_wide_lookup[bsize];
-  const int num_4x4_blocks_high = num_4x4_blocks_high_lookup[bsize];
+
+  const int pred_width_in_4x4_blocks = num_4x4_blocks_wide_lookup[bsize];
+  const int pred_height_in_4x4_blocks = num_4x4_blocks_high_lookup[bsize];
+  const int tx_width_unit = tx_size_wide_unit[tx_size];
+  const int tx_height_unit = tx_size_high_unit[tx_size];
+  const int pred_block_width = block_size_wide[bsize];
+  const int pred_block_height = block_size_high[bsize];
+  const int tx_width = tx_size_wide[tx_size];
+  const int tx_height = tx_size_high[tx_size];
+  const int pred_width_in_transform_blocks = pred_block_width / tx_width;
+  const int pred_height_in_transform_blocks = pred_block_height / tx_height;
   int idx, idy;
   int best_can_skip = 0;
   uint8_t best_dst[8 * 8];
+  inv_txfm_func_ptr inv_txfm_func =
+      (tx_size == TX_4X4) ? av1_inv_txfm_add_4x4
+                          : (tx_size == TX_4X8) ? inv_txfm_add_4x8_wrapper
+                                                : inv_txfm_add_8x4_wrapper;
 #if CONFIG_AOM_HIGHBITDEPTH
   uint16_t best_dst16[8 * 8];
+  highbd_inv_txfm_func_ptr highbd_inv_txfm_func =
+      (tx_size == TX_4X4)
+          ? av1_highbd_inv_txfm_add_4x4
+          : (tx_size == TX_4X8) ? highbd_inv_txfm_add_4x8_wrapper
+                                : highbd_inv_txfm_add_8x4_wrapper;
 #endif
+  const int is_lossless = xd->lossless[xd->mi[0]->mbmi.segment_id];
+#if CONFIG_EXT_TX && CONFIG_RECT_TX
+  const int sub_bsize = bsize;
+#else
+  const int sub_bsize = BLOCK_4X4;
+#endif  // CONFIG_EXT_TX && CONFIG_RECT_TX
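[Annotation] The 4x8/8x4 wrappers and the function-pointer selection above only exist so that one call site can serve TX_4X4, TX_4X8 and TX_8X4 with a uniform signature, and sub_bsize likewise picks the variance function that matches the transform footprint. A minimal sketch of the same dispatch written as a helper, assuming the wrappers and the inv_txfm_func_ptr typedef above are in scope (the patch itself inlines the choice with ternaries):

    static inv_txfm_func_ptr get_inv_txfm_func(TX_SIZE tx_size) {
      switch (tx_size) {
        case TX_4X4: return av1_inv_txfm_add_4x4;
        case TX_4X8: return inv_txfm_add_4x8_wrapper;
        case TX_8X4: return inv_txfm_add_8x4_wrapper;
        default: assert(0); return NULL;
      }
    }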
 
 #if CONFIG_PVQ
   od_rollback_buffer pre_buf, post_buf;
@@ -2044,9 +2127,19 @@
   od_encode_checkpoint(&x->daala_enc, &post_buf);
 #endif
 
-  memcpy(ta, a, num_4x4_blocks_wide * sizeof(a[0]));
-  memcpy(tl, l, num_4x4_blocks_high * sizeof(l[0]));
-  xd->mi[0]->mbmi.tx_size = TX_4X4;
+  assert(bsize < BLOCK_8X8);
+  assert(tx_width < 8 || tx_height < 8);
+#if CONFIG_EXT_TX && CONFIG_RECT_TX
+  // In lossless mode the caller passes TX_4X4 even for 4x8/8x4 sub-blocks.
+  assert(IMPLIES(!is_lossless, tx_width == pred_block_width &&
+                                   tx_height == pred_block_height));
+  assert(IMPLIES(is_lossless, tx_width == 4 && tx_height == 4));
+#else
+  assert(tx_width == 4 && tx_height == 4);
+#endif  // CONFIG_EXT_TX && CONFIG_RECT_TX
+
+  memcpy(ta, a, pred_width_in_transform_blocks * sizeof(a[0]));
+  memcpy(tl, l, pred_height_in_transform_blocks * sizeof(l[0]));
+
+  xd->mi[0]->mbmi.tx_size = tx_size;
+
 #if CONFIG_PALETTE
   xd->mi[0]->mbmi.palette_mode_info.palette_size[0] = 0;
 #endif  // CONFIG_PALETTE
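[Annotation] For the sub-8x8 sizes this function handles, the asserts above pin down the geometry: with EXT_TX + RECT_TX a single rectangular transform covers the whole prediction sub-block, otherwise the sub-block is tiled by 4x4 transforms. A worked example for BLOCK_4X8, assuming the usual block_size_* and tx_size_* tables from common_data.h:

    assert(block_size_wide[BLOCK_4X8] / tx_size_wide[TX_4X8] == 1);
    assert(block_size_high[BLOCK_4X8] / tx_size_high[TX_4X8] == 1);  // one 4x8 tx
    assert(block_size_wide[BLOCK_4X8] / tx_size_wide[TX_4X4] == 1);
    assert(block_size_high[BLOCK_4X8] / tx_size_high[TX_4X4] == 2);  // two 4x4 txs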
@@ -2060,7 +2153,9 @@
       int rate = bmode_costs[mode];
       int can_skip = 1;
 
-      if (!(cpi->sf.intra_y_mode_mask[TX_4X4] & (1 << mode))) continue;
+      if (!(cpi->sf.intra_y_mode_mask[txsize_sqr_up_map[tx_size]] &
+            (1 << mode)))
+        continue;
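[Annotation] The mask lookup changes here (and in the low-bitdepth loop below) are needed because intra_y_mode_mask is indexed by square transform sizes only, so a rectangular tx_size is first mapped up to its enclosing square. Illustrative checks, assuming the usual txsize_sqr_up_map contents:

    assert(txsize_sqr_up_map[TX_4X8] == TX_8X8);
    assert(txsize_sqr_up_map[TX_8X4] == TX_8X8);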
 
       // Only do the oblique modes if the best so far is
       // one of the neighboring directional modes
@@ -2068,70 +2163,97 @@
         if (conditional_skipintra(mode, *best_mode)) continue;
       }
 
-      memcpy(tempa, ta, num_4x4_blocks_wide * sizeof(ta[0]));
-      memcpy(templ, tl, num_4x4_blocks_high * sizeof(tl[0]));
+      memcpy(tempa, ta, pred_width_in_transform_blocks * sizeof(ta[0]));
+      memcpy(templ, tl, pred_height_in_transform_blocks * sizeof(tl[0]));
 
-      for (idy = 0; idy < num_4x4_blocks_high; ++idy) {
-        for (idx = 0; idx < num_4x4_blocks_wide; ++idx) {
-          const int block = (row + idy) * 2 + (col + idx);
+      for (idy = 0; idy < pred_height_in_transform_blocks; ++idy) {
+        for (idx = 0; idx < pred_width_in_transform_blocks; ++idx) {
+          const int block_raster_idx = (row + idy) * 2 + (col + idx);
+          const int block =
+              av1_raster_order_to_block_index(tx_size, block_raster_idx);
           const uint8_t *const src = &src_init[idx * 4 + idy * 4 * src_stride];
           uint8_t *const dst = &dst_init[idx * 4 + idy * 4 * dst_stride];
-          int16_t *const src_diff =
-              av1_raster_block_offset_int16(BLOCK_8X8, block, p->src_diff);
-          xd->mi[0]->bmi[block].as_mode = mode;
-          av1_predict_intra_block(xd, pd->width, pd->height, TX_4X4, mode, dst,
+          int16_t *const src_diff = av1_raster_block_offset_int16(
+              BLOCK_8X8, block_raster_idx, p->src_diff);
+          int skip;
+          assert(block < 4);
+          assert(IMPLIES(tx_size == TX_4X8 || tx_size == TX_8X4,
+                         idx == 0 && idy == 0));
+          assert(IMPLIES(tx_size == TX_4X8 || tx_size == TX_8X4,
+                         block == 0 || block == 2));
+          xd->mi[0]->bmi[block_raster_idx].as_mode = mode;
+          av1_predict_intra_block(xd, pd->width, pd->height, tx_size, mode, dst,
                                   dst_stride, dst, dst_stride, col + idx,
                                   row + idy, 0);
-          aom_highbd_subtract_block(4, 4, src_diff, 8, src, src_stride, dst,
-                                    dst_stride, xd->bd);
-          if (xd->lossless[xd->mi[0]->mbmi.segment_id]) {
-            TX_TYPE tx_type = get_tx_type(PLANE_TYPE_Y, xd, block, TX_4X4);
-            const SCAN_ORDER *scan_order = get_scan(cm, TX_4X4, tx_type, 0);
+          aom_highbd_subtract_block(tx_height, tx_width, src_diff, 8, src,
+                                    src_stride, dst, dst_stride, xd->bd);
+          if (is_lossless) {
+            TX_TYPE tx_type =
+                get_tx_type(PLANE_TYPE_Y, xd, block_raster_idx, tx_size);
+            const SCAN_ORDER *scan_order = get_scan(cm, tx_size, tx_type, 0);
             const int coeff_ctx =
-                combine_entropy_contexts(*(tempa + idx), *(templ + idy));
+                combine_entropy_contexts(tempa[idx], templ[idy]);
 #if CONFIG_NEW_QUANT
             av1_xform_quant(cm, x, 0, block, row + idy, col + idx, BLOCK_8X8,
-                            TX_4X4, coeff_ctx, AV1_XFORM_QUANT_FP_NUQ);
+                            tx_size, coeff_ctx, AV1_XFORM_QUANT_FP_NUQ);
 #else
             av1_xform_quant(cm, x, 0, block, row + idy, col + idx, BLOCK_8X8,
-                            TX_4X4, coeff_ctx, AV1_XFORM_QUANT_FP);
+                            tx_size, coeff_ctx, AV1_XFORM_QUANT_FP);
 #endif  // CONFIG_NEW_QUANT
-            ratey += av1_cost_coeffs(cm, x, 0, block, coeff_ctx, TX_4X4,
+            ratey += av1_cost_coeffs(cm, x, 0, block, coeff_ctx, tx_size,
                                      scan_order->scan, scan_order->neighbors,
                                      cpi->sf.use_fast_coef_costing);
-            *(tempa + idx) = !(p->eobs[block] == 0);
-            *(templ + idy) = !(p->eobs[block] == 0);
-            can_skip &= (p->eobs[block] == 0);
+            skip = (p->eobs[block] == 0);
+            can_skip &= skip;
+            tempa[idx] = !skip;
+            templ[idy] = !skip;
+#if CONFIG_EXT_TX
+            if (tx_size == TX_8X4) {
+              tempa[idx + 1] = tempa[idx];
+            } else if (tx_size == TX_4X8) {
+              templ[idy + 1] = templ[idy];
+            }
+#endif  // CONFIG_EXT_TX
+
             if (RDCOST(x->rdmult, x->rddiv, ratey, distortion) >= best_rd)
               goto next_highbd;
-            av1_highbd_inv_txfm_add_4x4(BLOCK_OFFSET(pd->dqcoeff, block), dst,
-                                        dst_stride, p->eobs[block], xd->bd,
-                                        DCT_DCT, 1);
+            highbd_inv_txfm_func(BLOCK_OFFSET(pd->dqcoeff, block), dst,
+                                 dst_stride, p->eobs[block], xd->bd, DCT_DCT,
+                                 1);
           } else {
             int64_t dist;
             unsigned int tmp;
-            TX_TYPE tx_type = get_tx_type(PLANE_TYPE_Y, xd, block, TX_4X4);
-            const SCAN_ORDER *scan_order = get_scan(cm, TX_4X4, tx_type, 0);
+            TX_TYPE tx_type =
+                get_tx_type(PLANE_TYPE_Y, xd, block_raster_idx, tx_size);
+            const SCAN_ORDER *scan_order = get_scan(cm, tx_size, tx_type, 0);
             const int coeff_ctx =
-                combine_entropy_contexts(*(tempa + idx), *(templ + idy));
+                combine_entropy_contexts(tempa[idx], templ[idy]);
 #if CONFIG_NEW_QUANT
             av1_xform_quant(cm, x, 0, block, row + idy, col + idx, BLOCK_8X8,
-                            TX_4X4, coeff_ctx, AV1_XFORM_QUANT_FP_NUQ);
+                            tx_size, coeff_ctx, AV1_XFORM_QUANT_FP_NUQ);
 #else
             av1_xform_quant(cm, x, 0, block, row + idy, col + idx, BLOCK_8X8,
-                            TX_4X4, coeff_ctx, AV1_XFORM_QUANT_FP);
+                            tx_size, coeff_ctx, AV1_XFORM_QUANT_FP);
 #endif  // CONFIG_NEW_QUANT
-            av1_optimize_b(cm, x, 0, block, TX_4X4, coeff_ctx);
-            ratey += av1_cost_coeffs(cm, x, 0, block, coeff_ctx, TX_4X4,
+            av1_optimize_b(cm, x, 0, block, tx_size, coeff_ctx);
+            ratey += av1_cost_coeffs(cm, x, 0, block, coeff_ctx, tx_size,
                                      scan_order->scan, scan_order->neighbors,
                                      cpi->sf.use_fast_coef_costing);
-            *(tempa + idx) = !(p->eobs[block] == 0);
-            *(templ + idy) = !(p->eobs[block] == 0);
-            can_skip &= (p->eobs[block] == 0);
-            av1_highbd_inv_txfm_add_4x4(BLOCK_OFFSET(pd->dqcoeff, block), dst,
-                                        dst_stride, p->eobs[block], xd->bd,
-                                        tx_type, 0);
-            cpi->fn_ptr[BLOCK_4X4].vf(src, src_stride, dst, dst_stride, &tmp);
+            skip = (p->eobs[block] == 0);
+            can_skip &= skip;
+            tempa[idx] = !skip;
+            templ[idy] = !skip;
+#if CONFIG_EXT_TX
+            if (tx_size == TX_8X4) {
+              tempa[idx + 1] = tempa[idx];
+            } else if (tx_size == TX_4X8) {
+              templ[idy + 1] = templ[idy];
+            }
+#endif  // CONFIG_EXT_TX
+            highbd_inv_txfm_func(BLOCK_OFFSET(pd->dqcoeff, block), dst,
+                                 dst_stride, p->eobs[block], xd->bd, tx_type,
+                                 0);
+            cpi->fn_ptr[sub_bsize].vf(src, src_stride, dst, dst_stride, &tmp);
             dist = (int64_t)tmp << 4;
             distortion += dist;
             if (RDCOST(x->rdmult, x->rddiv, ratey, distortion) >= best_rd)
@@ -2150,12 +2272,12 @@
         best_rd = this_rd;
         best_can_skip = can_skip;
         *best_mode = mode;
-        memcpy(a, tempa, num_4x4_blocks_wide * sizeof(tempa[0]));
-        memcpy(l, templ, num_4x4_blocks_high * sizeof(templ[0]));
-        for (idy = 0; idy < num_4x4_blocks_high * 4; ++idy) {
+        memcpy(a, tempa, pred_width_in_transform_blocks * sizeof(tempa[0]));
+        memcpy(l, templ, pred_height_in_transform_blocks * sizeof(templ[0]));
+        for (idy = 0; idy < pred_height_in_transform_blocks * 4; ++idy) {
           memcpy(best_dst16 + idy * 8,
                  CONVERT_TO_SHORTPTR(dst_init + idy * dst_stride),
-                 num_4x4_blocks_wide * 4 * sizeof(uint16_t));
+                 pred_width_in_transform_blocks * 4 * sizeof(uint16_t));
         }
       }
     next_highbd : {}
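[Annotation] The tempa/templ updates added in the hunks above replace the old per-4x4 writes: a TX_8X4 transform spans two 4x4 columns and a TX_4X8 spans two 4x4 rows, so every above/left context slot the transform covers must record the same end-of-block flag. A short sketch of the general rule the hard-coded duplication implements, assuming the tx_size_*_unit tables (the patch spells out only the two rectangular cases):

    static void set_txb_ctx(TX_SIZE tx_size, int nonzero, ENTROPY_CONTEXT *a,
                            ENTROPY_CONTEXT *l) {
      int i;
      for (i = 0; i < tx_size_wide_unit[tx_size]; ++i) a[i] = nonzero;
      for (i = 0; i < tx_size_high_unit[tx_size]; ++i) l[i] = nonzero;
    }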
@@ -2165,9 +2287,10 @@
 
     if (y_skip) *y_skip &= best_can_skip;
 
-    for (idy = 0; idy < num_4x4_blocks_high * 4; ++idy) {
+    for (idy = 0; idy < pred_height_in_transform_blocks * 4; ++idy) {
       memcpy(CONVERT_TO_SHORTPTR(dst_init + idy * dst_stride),
-             best_dst16 + idy * 8, num_4x4_blocks_wide * 4 * sizeof(uint16_t));
+             best_dst16 + idy * 8,
+             pred_width_in_transform_blocks * 4 * sizeof(uint16_t));
     }
 
     return best_rd;
@@ -2185,7 +2308,10 @@
     int rate = bmode_costs[mode];
     int can_skip = 1;
 
-    if (!(cpi->sf.intra_y_mode_mask[TX_4X4] & (1 << mode))) continue;
+    if (!(cpi->sf.intra_y_mode_mask[txsize_sqr_up_map[tx_size]] &
+          (1 << mode))) {
+      continue;
+    }
 
     // Only do the oblique modes if the best so far is
     // one of the neighboring directional modes
@@ -2193,25 +2319,29 @@
       if (conditional_skipintra(mode, *best_mode)) continue;
     }
 
-    memcpy(tempa, ta, num_4x4_blocks_wide * sizeof(ta[0]));
-    memcpy(templ, tl, num_4x4_blocks_high * sizeof(tl[0]));
+    memcpy(tempa, ta, pred_width_in_transform_blocks * sizeof(ta[0]));
+    memcpy(templ, tl, pred_height_in_transform_blocks * sizeof(tl[0]));
 
-    for (idy = 0; idy < num_4x4_blocks_high; ++idy) {
-      for (idx = 0; idx < num_4x4_blocks_wide; ++idx) {
-        int block = (row + idy) * 2 + (col + idx);
+    for (idy = 0; idy < pred_height_in_4x4_blocks; idy += tx_height_unit) {
+      for (idx = 0; idx < pred_width_in_4x4_blocks; idx += tx_width_unit) {
+        const int block_raster_idx = (row + idy) * 2 + (col + idx);
+        int block = av1_raster_order_to_block_index(tx_size, block_raster_idx);
         const uint8_t *const src = &src_init[idx * 4 + idy * 4 * src_stride];
         uint8_t *const dst = &dst_init[idx * 4 + idy * 4 * dst_stride];
 #if !CONFIG_PVQ
-        int16_t *const src_diff =
-            av1_raster_block_offset_int16(BLOCK_8X8, block, p->src_diff);
+        int16_t *const src_diff = av1_raster_block_offset_int16(
+            BLOCK_8X8, block_raster_idx, p->src_diff);
 #else
-        int i, j, tx_blk_size;
-        int skip;
-
-        tx_blk_size = 4;
+        int i, j;
 #endif
-        xd->mi[0]->bmi[block].as_mode = mode;
-        av1_predict_intra_block(xd, pd->width, pd->height, TX_4X4, mode, dst,
+        int skip;
+        assert(block < 4);
+        assert(IMPLIES(tx_size == TX_4X8 || tx_size == TX_8X4,
+                       idx == 0 && idy == 0));
+        assert(IMPLIES(tx_size == TX_4X8 || tx_size == TX_8X4,
+                       block == 0 || block == 2));
+        xd->mi[0]->bmi[block_raster_idx].as_mode = mode;
+        av1_predict_intra_block(xd, pd->width, pd->height, tx_size, mode, dst,
                                 dst_stride, dst, dst_stride,
 #if CONFIG_CB4X4
                                 2 * (col + idx), 2 * (row + idy),
@@ -2220,21 +2350,23 @@
 #endif
                                 0);
 #if !CONFIG_PVQ
-        aom_subtract_block(4, 4, src_diff, 8, src, src_stride, dst, dst_stride);
+        aom_subtract_block(tx_height, tx_width, src_diff, 8, src, src_stride,
+                           dst, dst_stride);
 #endif
 
-        if (xd->lossless[xd->mi[0]->mbmi.segment_id]) {
-          TX_TYPE tx_type = get_tx_type(PLANE_TYPE_Y, xd, block, TX_4X4);
-          const SCAN_ORDER *scan_order = get_scan(cm, TX_4X4, tx_type, 0);
+        if (is_lossless) {
+          TX_TYPE tx_type =
+              get_tx_type(PLANE_TYPE_Y, xd, block_raster_idx, tx_size);
+          const SCAN_ORDER *scan_order = get_scan(cm, tx_size, tx_type, 0);
           const int coeff_ctx =
-              combine_entropy_contexts(*(tempa + idx), *(templ + idy));
+              combine_entropy_contexts(tempa[idx], templ[idy]);
 #if CONFIG_CB4X4
           block = 4 * block;
 #endif
 #if !CONFIG_PVQ
 #if CONFIG_NEW_QUANT
           av1_xform_quant(cm, x, 0, block, row + idy, col + idx, BLOCK_8X8,
-                          TX_4X4, coeff_ctx, AV1_XFORM_QUANT_B_NUQ);
+                          tx_size, coeff_ctx, AV1_XFORM_QUANT_B_NUQ);
 #else
           av1_xform_quant(cm, x, 0, block,
 #if CONFIG_CB4X4
@@ -2242,14 +2374,22 @@
 #else
                           row + idy, col + idx,
 #endif
-                          BLOCK_8X8, TX_4X4, coeff_ctx, AV1_XFORM_QUANT_B);
+                          BLOCK_8X8, tx_size, coeff_ctx, AV1_XFORM_QUANT_B);
 #endif  // CONFIG_NEW_QUANT
-          ratey += av1_cost_coeffs(cm, x, 0, block, coeff_ctx, TX_4X4,
+          ratey += av1_cost_coeffs(cm, x, 0, block, coeff_ctx, tx_size,
                                    scan_order->scan, scan_order->neighbors,
                                    cpi->sf.use_fast_coef_costing);
-          *(tempa + idx) = !(p->eobs[block] == 0);
-          *(templ + idy) = !(p->eobs[block] == 0);
-          can_skip &= (p->eobs[block] == 0);
+          skip = (p->eobs[block] == 0);
+          can_skip &= skip;
+          tempa[idx] = !skip;
+          templ[idy] = !skip;
+#if CONFIG_EXT_TX
+          if (tx_size == TX_8X4) {
+            tempa[idx + 1] = tempa[idx];
+          } else if (tx_size == TX_4X8) {
+            templ[idy + 1] = templ[idy];
+          }
+#endif  // CONFIG_EXT_TX
 #else
           (void)scan_order;
 
@@ -2259,40 +2399,41 @@
 #else
                           row + idy, col + idx,
 #endif
-                          BLOCK_8X8, TX_4X4, coeff_ctx, AV1_XFORM_QUANT_B);
+                          BLOCK_8X8, tx_size, coeff_ctx, AV1_XFORM_QUANT_B);
 
           ratey += x->rate;
           skip = x->pvq_skip[0];
-          *(tempa + idx) = !skip;
-          *(templ + idy) = !skip;
+          tempa[idx] = !skip;
+          templ[idy] = !skip;
           can_skip &= skip;
 #endif
           if (RDCOST(x->rdmult, x->rddiv, ratey, distortion) >= best_rd)
             goto next;
 #if CONFIG_PVQ
           if (!skip) {
-            for (j = 0; j < tx_blk_size; j++)
-              for (i = 0; i < tx_blk_size; i++) dst[j * dst_stride + i] = 0;
+            for (j = 0; j < tx_height; j++)
+              for (i = 0; i < tx_width; i++) dst[j * dst_stride + i] = 0;
 #endif
-            av1_inv_txfm_add_4x4(BLOCK_OFFSET(pd->dqcoeff, block), dst,
-                                 dst_stride, p->eobs[block], DCT_DCT, 1);
+            inv_txfm_func(BLOCK_OFFSET(pd->dqcoeff, block), dst, dst_stride,
+                          p->eobs[block], DCT_DCT, 1);
 #if CONFIG_PVQ
           }
 #endif
         } else {
           int64_t dist;
           unsigned int tmp;
-          TX_TYPE tx_type = get_tx_type(PLANE_TYPE_Y, xd, block, TX_4X4);
-          const SCAN_ORDER *scan_order = get_scan(cm, TX_4X4, tx_type, 0);
+          TX_TYPE tx_type =
+              get_tx_type(PLANE_TYPE_Y, xd, block_raster_idx, tx_size);
+          const SCAN_ORDER *scan_order = get_scan(cm, tx_size, tx_type, 0);
           const int coeff_ctx =
-              combine_entropy_contexts(*(tempa + idx), *(templ + idy));
+              combine_entropy_contexts(tempa[idx], templ[idy]);
 #if CONFIG_CB4X4
           block = 4 * block;
 #endif
 #if !CONFIG_PVQ
 #if CONFIG_NEW_QUANT
           av1_xform_quant(cm, x, 0, block, row + idy, col + idx, BLOCK_8X8,
-                          TX_4X4, coeff_ctx, AV1_XFORM_QUANT_FP_NUQ);
+                          tx_size, coeff_ctx, AV1_XFORM_QUANT_FP_NUQ);
 #else
           av1_xform_quant(cm, x, 0, block,
 #if CONFIG_CB4X4
@@ -2300,15 +2441,23 @@
 #else
                           row + idy, col + idx,
 #endif
-                          BLOCK_8X8, TX_4X4, coeff_ctx, AV1_XFORM_QUANT_FP);
+                          BLOCK_8X8, tx_size, coeff_ctx, AV1_XFORM_QUANT_FP);
 #endif  // CONFIG_NEW_QUANT
-          av1_optimize_b(cm, x, 0, block, TX_4X4, coeff_ctx);
-          ratey += av1_cost_coeffs(cm, x, 0, block, coeff_ctx, TX_4X4,
+          av1_optimize_b(cm, x, 0, block, tx_size, coeff_ctx);
+          ratey += av1_cost_coeffs(cm, x, 0, block, coeff_ctx, tx_size,
                                    scan_order->scan, scan_order->neighbors,
                                    cpi->sf.use_fast_coef_costing);
-          *(tempa + idx) = !(p->eobs[block] == 0);
-          *(templ + idy) = !(p->eobs[block] == 0);
-          can_skip &= (p->eobs[block] == 0);
+          skip = (p->eobs[block] == 0);
+          can_skip &= skip;
+          tempa[idx] = !skip;
+          templ[idy] = !skip;
+#if CONFIG_EXT_TX
+          if (tx_size == TX_8X4) {
+            tempa[idx + 1] = tempa[idx];
+          } else if (tx_size == TX_4X8) {
+            templ[idy + 1] = templ[idy];
+          }
+#endif  // CONFIG_EXT_TX
 #else
           (void)scan_order;
 
@@ -2318,25 +2467,25 @@
 #else
                           row + idy, col + idx,
 #endif
-                          BLOCK_8X8, TX_4X4, coeff_ctx, AV1_XFORM_QUANT_FP);
+                          BLOCK_8X8, tx_size, coeff_ctx, AV1_XFORM_QUANT_FP);
           ratey += x->rate;
           skip = x->pvq_skip[0];
-          *(tempa + idx) = !skip;
-          *(templ + idy) = !skip;
+          tempa[idx] = !skip;
+          templ[idy] = !skip;
           can_skip &= skip;
 #endif
 #if CONFIG_PVQ
           if (!skip) {
-            for (j = 0; j < tx_blk_size; j++)
-              for (i = 0; i < tx_blk_size; i++) dst[j * dst_stride + i] = 0;
+            for (j = 0; j < tx_height; j++)
+              for (i = 0; i < tx_width; i++) dst[j * dst_stride + i] = 0;
 #endif
-            av1_inv_txfm_add_4x4(BLOCK_OFFSET(pd->dqcoeff, block), dst,
-                                 dst_stride, p->eobs[block], tx_type, 0);
+            inv_txfm_func(BLOCK_OFFSET(pd->dqcoeff, block), dst, dst_stride,
+                          p->eobs[block], tx_type, 0);
 #if CONFIG_PVQ
           }
 #endif
           // No need for av1_block_error2_c because the ssz is unused
-          cpi->fn_ptr[BLOCK_4X4].vf(src, src_stride, dst, dst_stride, &tmp);
+          cpi->fn_ptr[sub_bsize].vf(src, src_stride, dst, dst_stride, &tmp);
           dist = (int64_t)tmp << 4;
           distortion += dist;
           // To use the pixel domain distortion, the step below needs to be
@@ -2358,14 +2507,14 @@
       best_rd = this_rd;
       best_can_skip = can_skip;
       *best_mode = mode;
-      memcpy(a, tempa, num_4x4_blocks_wide * sizeof(tempa[0]));
-      memcpy(l, templ, num_4x4_blocks_high * sizeof(templ[0]));
+      memcpy(a, tempa, pred_width_in_transform_blocks * sizeof(tempa[0]));
+      memcpy(l, templ, pred_height_in_transform_blocks * sizeof(templ[0]));
 #if CONFIG_PVQ
       od_encode_checkpoint(&x->daala_enc, &post_buf);
 #endif
-      for (idy = 0; idy < num_4x4_blocks_high * 4; ++idy)
+      for (idy = 0; idy < pred_height_in_transform_blocks * 4; ++idy)
         memcpy(best_dst + idy * 8, dst_init + idy * dst_stride,
-               num_4x4_blocks_wide * 4);
+               pred_width_in_transform_blocks * 4);
     }
   next : {}
 #if CONFIG_PVQ
@@ -2381,9 +2530,9 @@
 
   if (y_skip) *y_skip &= best_can_skip;
 
-  for (idy = 0; idy < num_4x4_blocks_high * 4; ++idy)
+  for (idy = 0; idy < pred_height_in_transform_blocks * 4; ++idy)
     memcpy(dst_init + idy * dst_stride, best_dst + idy * 8,
-           num_4x4_blocks_wide * 4);
+           pred_width_in_transform_blocks * 4);
 
   return best_rd;
 }
@@ -2392,55 +2541,65 @@
                                             MACROBLOCK *mb, int *rate,
                                             int *rate_y, int64_t *distortion,
                                             int *y_skip, int64_t best_rd) {
-  int i, j;
   const MACROBLOCKD *const xd = &mb->e_mbd;
   MODE_INFO *const mic = xd->mi[0];
   const MODE_INFO *above_mi = xd->above_mi;
   const MODE_INFO *left_mi = xd->left_mi;
-  const BLOCK_SIZE bsize = xd->mi[0]->mbmi.sb_type;
-  const int num_4x4_blocks_wide = num_4x4_blocks_wide_lookup[bsize];
-  const int num_4x4_blocks_high = num_4x4_blocks_high_lookup[bsize];
+  MB_MODE_INFO *const mbmi = &mic->mbmi;
+  const BLOCK_SIZE bsize = mbmi->sb_type;
+  const int pred_width_in_4x4_blocks = num_4x4_blocks_wide_lookup[bsize];
+  const int pred_height_in_4x4_blocks = num_4x4_blocks_high_lookup[bsize];
   int idx, idy;
   int cost = 0;
   int64_t total_distortion = 0;
   int tot_rate_y = 0;
   int64_t total_rd = 0;
   const int *bmode_costs = cpi->mbmode_cost[0];
+  const int is_lossless = xd->lossless[mbmi->segment_id];
+#if CONFIG_EXT_TX && CONFIG_RECT_TX
+  const TX_SIZE tx_size = is_lossless ? TX_4X4 : max_txsize_rect_lookup[bsize];
+#else
+  const TX_SIZE tx_size = TX_4X4;
+#endif  // CONFIG_EXT_TX && CONFIG_RECT_TX
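[Annotation] With EXT_TX + RECT_TX each prediction sub-block now gets a single transform of its own shape, except in lossless mode, which stays on the 4x4 transform. For the three sizes this function sees, max_txsize_rect_lookup is expected to give:

    assert(max_txsize_rect_lookup[BLOCK_4X4] == TX_4X4);
    assert(max_txsize_rect_lookup[BLOCK_4X8] == TX_4X8);
    assert(max_txsize_rect_lookup[BLOCK_8X4] == TX_8X4);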
 
 #if CONFIG_EXT_INTRA
 #if CONFIG_INTRA_INTERP
-  mic->mbmi.intra_filter = INTRA_FILTER_LINEAR;
+  mbmi->intra_filter = INTRA_FILTER_LINEAR;
 #endif  // CONFIG_INTRA_INTERP
 #endif  // CONFIG_EXT_INTRA
 #if CONFIG_FILTER_INTRA
-  mic->mbmi.filter_intra_mode_info.use_filter_intra_mode[0] = 0;
+  mbmi->filter_intra_mode_info.use_filter_intra_mode[0] = 0;
 #endif  // CONFIG_FILTER_INTRA
 
   // TODO(any): Add search of the tx_type to improve rd performance at the
   // expense of speed.
-  mic->mbmi.tx_type = DCT_DCT;
-  mic->mbmi.tx_size = TX_4X4;
+  mbmi->tx_type = DCT_DCT;
+  mbmi->tx_size = tx_size;
 
   if (y_skip) *y_skip = 1;
 
-  // Pick modes for each sub-block (of size 4x4, 4x8, or 8x4) in an 8x8 block.
-  for (idy = 0; idy < 2; idy += num_4x4_blocks_high) {
-    for (idx = 0; idx < 2; idx += num_4x4_blocks_wide) {
+  // Pick modes for each prediction sub-block (of size 4x4, 4x8, or 8x4) in this
+  // 8x8 coding block.
+  for (idy = 0; idy < 2; idy += pred_height_in_4x4_blocks) {
+    for (idx = 0; idx < 2; idx += pred_width_in_4x4_blocks) {
       PREDICTION_MODE best_mode = DC_PRED;
       int r = INT_MAX, ry = INT_MAX;
       int64_t d = INT64_MAX, this_rd = INT64_MAX;
-      i = idy * 2 + idx;
+      int j;
+      const int pred_block_idx = idy * 2 + idx;
       if (cpi->common.frame_type == KEY_FRAME) {
-        const PREDICTION_MODE A = av1_above_block_mode(mic, above_mi, i);
-        const PREDICTION_MODE L = av1_left_block_mode(mic, left_mi, i);
+        const PREDICTION_MODE A =
+            av1_above_block_mode(mic, above_mi, pred_block_idx);
+        const PREDICTION_MODE L =
+            av1_left_block_mode(mic, left_mi, pred_block_idx);
 
         bmode_costs = cpi->y_mode_costs[A][L];
       }
 
-      this_rd = rd_pick_intra4x4block(
+      this_rd = rd_pick_intra_sub_8x8_y_subblock_mode(
           cpi, mb, idy, idx, &best_mode, bmode_costs,
           xd->plane[0].above_context + idx, xd->plane[0].left_context + idy, &r,
-          &ry, &d, bsize, y_skip, best_rd - total_rd);
+          &ry, &d, bsize, tx_size, y_skip, best_rd - total_rd);
       if (this_rd >= best_rd - total_rd) return INT64_MAX;
 
       total_rd += this_rd;
@@ -2448,33 +2607,33 @@
       total_distortion += d;
       tot_rate_y += ry;
 
-      mic->bmi[i].as_mode = best_mode;
-      for (j = 1; j < num_4x4_blocks_high; ++j)
-        mic->bmi[i + j * 2].as_mode = best_mode;
-      for (j = 1; j < num_4x4_blocks_wide; ++j)
-        mic->bmi[i + j].as_mode = best_mode;
+      mic->bmi[pred_block_idx].as_mode = best_mode;
+      for (j = 1; j < pred_height_in_4x4_blocks; ++j)
+        mic->bmi[pred_block_idx + j * 2].as_mode = best_mode;
+      for (j = 1; j < pred_width_in_4x4_blocks; ++j)
+        mic->bmi[pred_block_idx + j].as_mode = best_mode;
 
       if (total_rd >= best_rd) return INT64_MAX;
     }
   }
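[Annotation] For reference, the loop bounds above visit each prediction sub-block of the 8x8 coding block exactly once; a worked example of the (idy, idx) origins, assuming the standard num_4x4_blocks_{wide,high}_lookup values:

    /* BLOCK_8X4: wide = 2, high = 1 -> (idy, idx) = (0, 0), (1, 0)
     *            (top and bottom 8x4 sub-blocks)
     * BLOCK_4X8: wide = 1, high = 2 -> (idy, idx) = (0, 0), (0, 1)
     *            (left and right 4x8 sub-blocks)
     * BLOCK_4X4: wide = high = 1   -> all four 4x4 sub-blocks            */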
-  mic->mbmi.mode = mic->bmi[3].as_mode;
+  mbmi->mode = mic->bmi[3].as_mode;
 
   // Add in the cost of the transform type
-  if (!xd->lossless[mic->mbmi.segment_id]) {
+  if (!is_lossless) {
     int rate_tx_type = 0;
 #if CONFIG_EXT_TX
-    if (get_ext_tx_types(TX_4X4, bsize, 0) > 1) {
-      const int eset = get_ext_tx_set(TX_4X4, bsize, 0);
-      rate_tx_type = cpi->intra_tx_type_costs[eset][TX_4X4][mic->mbmi.mode]
-                                             [mic->mbmi.tx_type];
+    if (get_ext_tx_types(tx_size, bsize, 0) > 1) {
+      const int eset = get_ext_tx_set(tx_size, bsize, 0);
+      rate_tx_type = cpi->intra_tx_type_costs[eset][txsize_sqr_map[tx_size]]
+                                             [mbmi->mode][mbmi->tx_type];
     }
 #else
     rate_tx_type =
-        cpi->intra_tx_type_costs[TX_4X4]
-                                [intra_mode_to_tx_type_context[mic->mbmi.mode]]
-                                [mic->mbmi.tx_type];
+        cpi->intra_tx_type_costs[txsize_sqr_map[tx_size]]
+                                [intra_mode_to_tx_type_context[mbmi->mode]]
+                                [mbmi->tx_type];
 #endif
-    assert(mic->mbmi.tx_size == TX_4X4);
+    assert(mbmi->tx_size == tx_size);
     cost += rate_tx_type;
     tot_rate_y += rate_tx_type;
   }
@@ -2884,7 +3043,6 @@
   const PREDICTION_MODE A = av1_above_block_mode(mic, above_mi, 0);
   const PREDICTION_MODE L = av1_left_block_mode(mic, left_mi, 0);
   const PREDICTION_MODE FINAL_MODE_SEARCH = TM_PRED + 1;
-  const TX_SIZE max_tx_size = max_txsize_lookup[bsize];
 #if CONFIG_PVQ
   od_rollback_buffer pre_buf, post_buf;
 
@@ -2962,9 +3120,7 @@
       // tokenonly rate, but for intra blocks, tx_size is always coded
       // (prediction granularity), so we account for it in the full rate,
       // not the tokenonly rate.
-      this_rate_tokenonly -=
-          cpi->tx_size_cost[max_tx_size - TX_8X8][get_tx_size_context(xd)]
-                           [tx_size_to_depth(mbmi->tx_size)];
+      this_rate_tokenonly -= tx_size_cost(cpi, x, bsize, mbmi->tx_size);
     }
 #if CONFIG_PALETTE
     if (cpi->common.allow_screen_content_tools && mbmi->mode == DC_PRED)
@@ -4073,7 +4229,9 @@
   pmi->palette_size[1] = 0;
 #endif  // CONFIG_PALETTE
   for (mode = DC_PRED; mode <= TM_PRED; ++mode) {
-    if (!(cpi->sf.intra_uv_mode_mask[max_tx_size] & (1 << mode))) continue;
+    if (!(cpi->sf.intra_uv_mode_mask[txsize_sqr_up_map[max_tx_size]] &
+          (1 << mode)))
+      continue;
 
     mbmi->uv_mode = mode;
 #if CONFIG_EXT_INTRA
@@ -4189,6 +4347,8 @@
   pmi->palette_size[1] = palette_mode_info.palette_size[1];
 #endif  // CONFIG_PALETTE
 
+  // Make sure we actually chose a mode.
+  assert(best_rd < INT64_MAX);
   return best_rd;
 }
 
@@ -4550,16 +4710,11 @@
   for (idy = 0; idy < txb_height; idy += num_4x4_h) {
     for (idx = 0; idx < txb_width; idx += num_4x4_w) {
       int64_t dist, ssz, rd, rd1, rd2;
-      int block;
       int coeff_ctx;
-      int k;
-
-      k = i + (idy * 2 + idx);
-      if (tx_size == TX_4X4)
-        block = k;
-      else
-        block = (i ? 2 : 0);
-
+      const int k = i + (idy * 2 + idx);
+      const int block = av1_raster_order_to_block_index(tx_size, k);
+      assert(IMPLIES(tx_size == TX_4X8 || tx_size == TX_8X4,
+                     idx == 0 && idy == 0));
       coeff_ctx = combine_entropy_contexts(*(ta + (k & 1)), *(tl + (k >> 1)));
 #if !CONFIG_PVQ
 #if CONFIG_NEW_QUANT
@@ -8414,7 +8569,6 @@
 #if CONFIG_PALETTE
   PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
 #endif  // CONFIG_PALETTE
-  const TX_SIZE max_tx_size = max_txsize_lookup[bsize];
   int rate2 = 0, rate_y = INT_MAX, skippable = 0, rate_uv, rate_dummy, i;
   int dc_mode_index;
   const int *const intra_mode_cost = cpi->mbmode_cost[size_group_lookup[bsize]];
@@ -8491,8 +8645,7 @@
     // tokenonly rate, but for intra blocks, tx_size is always coded
     // (prediction granularity), so we account for it in the full rate,
     // not the tokenonly rate.
-    rate_y -= cpi->tx_size_cost[max_tx_size - TX_8X8][get_tx_size_context(xd)]
-                               [tx_size_to_depth(mbmi->tx_size)];
+    rate_y -= tx_size_cost(cpi, x, bsize, mbmi->tx_size);
   }
 
   rate2 += av1_cost_bit(cm->fc->filter_intra_probs[0],
@@ -8642,21 +8795,21 @@
   int64_t best_intra_rd = INT64_MAX;
   unsigned int best_pred_sse = UINT_MAX;
   PREDICTION_MODE best_intra_mode = DC_PRED;
-  int rate_uv_intra[TX_SIZES], rate_uv_tokenonly[TX_SIZES];
-  int64_t dist_uvs[TX_SIZES];
-  int skip_uvs[TX_SIZES];
-  PREDICTION_MODE mode_uv[TX_SIZES];
+  int rate_uv_intra[TX_SIZES_ALL], rate_uv_tokenonly[TX_SIZES_ALL];
+  int64_t dist_uvs[TX_SIZES_ALL];
+  int skip_uvs[TX_SIZES_ALL];
+  PREDICTION_MODE mode_uv[TX_SIZES_ALL];
 #if CONFIG_PALETTE
-  PALETTE_MODE_INFO pmi_uv[TX_SIZES];
+  PALETTE_MODE_INFO pmi_uv[TX_SIZES_ALL];
 #endif  // CONFIG_PALETTE
 #if CONFIG_EXT_INTRA
-  int8_t uv_angle_delta[TX_SIZES];
+  int8_t uv_angle_delta[TX_SIZES_ALL];
   int is_directional_mode, angle_stats_ready = 0;
   uint8_t directional_mode_skip_mask[INTRA_MODES];
 #endif  // CONFIG_EXT_INTRA
 #if CONFIG_FILTER_INTRA
   int8_t dc_skipped = 1;
-  FILTER_INTRA_MODE_INFO filter_intra_mode_info_uv[TX_SIZES];
+  FILTER_INTRA_MODE_INFO filter_intra_mode_info_uv[TX_SIZES_ALL];
 #endif  // CONFIG_FILTER_INTRA
   const int intra_cost_penalty = av1_get_intra_cost_penalty(
       cm->base_qindex, cm->y_dc_delta_q, cm->bit_depth);
@@ -8676,7 +8829,6 @@
   int64_t mode_threshold[MAX_MODES];
   int *mode_map = tile_data->mode_map[bsize];
   const int mode_search_skip_flags = sf->mode_search_skip_flags;
-  const TX_SIZE max_tx_size = max_txsize_lookup[bsize];
 #if CONFIG_PVQ
   od_rollback_buffer pre_buf;
 #endif
@@ -8751,7 +8903,7 @@
                            &comp_mode_p);
 
   for (i = 0; i < REFERENCE_MODES; ++i) best_pred_rd[i] = INT64_MAX;
-  for (i = 0; i < TX_SIZES; i++) rate_uv_intra[i] = INT_MAX;
+  for (i = 0; i < TX_SIZES_ALL; i++) rate_uv_intra[i] = INT_MAX;
   for (i = 0; i < TOTAL_REFS_PER_FRAME; ++i) x->pred_sse[i] = INT_MAX;
   for (i = 0; i < MB_MODE_COUNT; ++i) {
     for (k = 0; k < TOTAL_REFS_PER_FRAME; ++k) {
@@ -9281,9 +9433,7 @@
         // tokenonly rate, but for intra blocks, tx_size is always coded
         // (prediction granularity), so we account for it in the full rate,
         // not the tokenonly rate.
-        rate_y -=
-            cpi->tx_size_cost[max_tx_size - TX_8X8][get_tx_size_context(xd)]
-                             [tx_size_to_depth(mbmi->tx_size)];
+        rate_y -= tx_size_cost(cpi, x, bsize, mbmi->tx_size);
       }
 #if CONFIG_EXT_INTRA
       if (is_directional_mode) {
diff --git a/av1/encoder/tokenize.c b/av1/encoder/tokenize.c
index 5725154..222b8ba 100644
--- a/av1/encoder/tokenize.c
+++ b/av1/encoder/tokenize.c
@@ -333,7 +333,8 @@
   struct macroblockd_plane *pd = &xd->plane[plane];
   const PLANE_TYPE type = pd->plane_type;
   const int ref = is_inter_block(mbmi);
-  const TX_TYPE tx_type = get_tx_type(type, xd, block, tx_size);
+  const int block_raster_idx = av1_block_index_to_raster_order(tx_size, block);
+  const TX_TYPE tx_type = get_tx_type(type, xd, block_raster_idx, tx_size);
   const SCAN_ORDER *const scan_order = get_scan(cm, tx_size, tx_type, ref);
   int pt = get_entropy_context(tx_size, pd->above_context + blk_col,
                                pd->left_context + blk_row);
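[Annotation] In this hunk (and the similar one below), "block" is the coefficient-block index in 4x4 units into the qcoeff buffer, while get_tx_type() apparently expects the raster-order index used for the per-4x4 intra modes, hence the conversion before the lookup. A hedged sketch of the relationship, inferred from the asserts in the rdopt.c changes above rather than from the helper's definition (which is added elsewhere in this patch):

    /* For an 8x8 area split into two TX_4X8 or TX_8X4 transforms, the second
     * transform's coefficients start two 4x4 blocks (32 coefficients) into the
     * buffer, and the two helpers are expected to be mutual inverses:
     *   av1_block_index_to_raster_order(tx_size,
     *       av1_raster_order_to_block_index(tx_size, r)) == r            */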
@@ -438,7 +439,7 @@
   int pt; /* near block/prev token context index */
   int c;
   TOKENEXTRA *t = *tp; /* store tokens starting here */
-  int eob = p->eobs[block];
+  const int eob = p->eobs[block];
   const PLANE_TYPE type = pd->plane_type;
   const tran_low_t *qcoeff = BLOCK_OFFSET(p->qcoeff, block);
 #if CONFIG_SUPERTX
@@ -447,7 +448,8 @@
   const int segment_id = mbmi->segment_id;
 #endif  // CONFIG_SUPERTX
   const int16_t *scan, *nb;
-  const TX_TYPE tx_type = get_tx_type(type, xd, block, tx_size);
+  const int block_raster_idx = av1_block_index_to_raster_order(tx_size, block);
+  const TX_TYPE tx_type = get_tx_type(type, xd, block_raster_idx, tx_size);
   const SCAN_ORDER *const scan_order =
       get_scan(cm, tx_size, tx_type, is_inter_block(mbmi));
   const int ref = is_inter_block(mbmi);
@@ -497,6 +499,7 @@
     skip_eob = (token == ZERO_TOKEN);
   }
   if (c < seg_eob) {
+    assert(!skip_eob);  // The last token must be non-zero.
     add_token(&t, coef_probs[band[c]][pt],
 #if CONFIG_EC_MULTISYMBOL
               NULL,