Add option to use 1:4/4:1 transform sizes for 1:4/4:1 block sizes

Change-Id: I96e5ff72caee8935efb7535afa3a534175bc425c
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index b8d8533..56eb28c 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -1047,10 +1047,17 @@
 }
 #endif
 
+// Returns the max_txsize_rect_lookup entry for bsize when is_inter is
+// nonzero, and the max_txsize_rect_intra_lookup entry otherwise.
+static INLINE TX_SIZE get_max_rect_tx_size(BLOCK_SIZE bsize, int is_inter) {
+  return is_inter ? max_txsize_rect_lookup[bsize]
+                  : max_txsize_rect_intra_lookup[bsize];
+}
+
 static INLINE TX_SIZE tx_size_from_tx_mode(BLOCK_SIZE bsize, TX_MODE tx_mode,
                                            int is_inter) {
   const TX_SIZE largest_tx_size = tx_mode_to_biggest_tx_size[tx_mode];
-  const TX_SIZE max_rect_tx_size = max_txsize_rect_lookup[bsize];
+  const TX_SIZE max_rect_tx_size = get_max_rect_tx_size(bsize, is_inter);
   (void)is_inter;
   if (bsize == BLOCK_4X4)
     return AOMMIN(max_txsize_lookup[bsize], largest_tx_size);
@@ -1293,7 +1300,7 @@
 static INLINE int get_vartx_max_txsize(const MB_MODE_INFO *const mbmi,
                                        BLOCK_SIZE bsize, int subsampled) {
   (void)mbmi;
-  TX_SIZE max_txsize = max_txsize_rect_lookup[bsize];
+  TX_SIZE max_txsize = get_max_rect_tx_size(bsize, is_inter_block(mbmi));
 
 #if CONFIG_EXT_PARTITION && CONFIG_TX64X64
   // The decoder is designed so that it can process 64x64 luma pixels at a
diff --git a/av1/common/common_data.h b/av1/common/common_data.h
index ab3fc4e..5edd740 100644
--- a/av1/common/common_data.h
+++ b/av1/common/common_data.h
@@ -647,6 +647,46 @@
 #endif  // CONFIG_EXT_PARTITION
 };
 
+static const TX_SIZE max_txsize_rect_intra_lookup[BLOCK_SIZES_ALL] = {
+  // 2X2,    2X4,      4X2,
+  TX_4X4,    TX_4X4,   TX_4X4,
+  //                   4X4
+                       TX_4X4,
+  // 4X8,    8X4,      8X8
+  TX_4X8,    TX_8X4,   TX_8X8,
+  // 8X16,   16X8,     16X16
+  TX_8X16,   TX_16X8,  TX_16X16,
+  // 16X32,  32X16,    32X32
+  TX_16X32,  TX_32X16, TX_32X32,
+#if CONFIG_TX64X64
+  // 32X64,  64X32,
+  TX_32X64,  TX_64X32,
+  // 64X64
+  TX_64X64,
+#if CONFIG_EXT_PARTITION
+  // 64x128, 128x64,   128x128
+  TX_64X64,  TX_64X64, TX_64X64,
+#endif  // CONFIG_EXT_PARTITION
+#else  // !CONFIG_TX64X64
+  // 32X64,  64X32,
+  TX_32X32,  TX_32X32,
+  // 64X64
+  TX_32X32,
+#if CONFIG_EXT_PARTITION
+  // 64x128, 128x64,   128x128
+  TX_32X32,  TX_32X32, TX_32X32,
+#endif  // CONFIG_EXT_PARTITION
+#endif  // CONFIG_TX64X64
+  // 4x16,   16x4,     8x32
+  TX_4X8,    TX_8X4,   TX_8X16,
+  // 32x8,   16x64,    64x16
+  TX_16X8,   TX_16X32, TX_32X16,
+#if CONFIG_EXT_PARTITION
+  // 32x128, 128x32
+  TX_32X32,  TX_32X32
+#endif  // CONFIG_EXT_PARTITION
+};
+
 static const TX_SIZE max_txsize_rect_lookup[BLOCK_SIZES_ALL] = {
   // 2X2,    2X4,      4X2,
   TX_4X4,    TX_4X4,   TX_4X4,
@@ -677,11 +717,11 @@
   TX_32X32,  TX_32X32, TX_32X32,
 #endif  // CONFIG_EXT_PARTITION
 #endif  // CONFIG_TX64X64
-#if CONFIG_RECT_TX_EXT
+#if CONFIG_EXT_PARTITION_TYPES && USE_RECT_TX_EXT
   // 4x16,   16x4,     8x32
-  TX_4X16,   TX_16X4,  TX_8X32,
+  TX_4X16,   TX_16X4,  TX_8X16,
   // 32x8
-  TX_32X8,
+  TX_16X8,
 #else
   // 4x16,   16x4,     8x32
   TX_4X8,    TX_8X4,   TX_8X16,
diff --git a/av1/common/entropy.h b/av1/common/entropy.h
index 478bc32..39c3db3 100644
--- a/av1/common/entropy.h
+++ b/av1/common/entropy.h
@@ -330,7 +330,7 @@
       left_ec = !!*(const uint64_t *)l;
       break;
 #endif  // CONFIG_TX64X64
-#if CONFIG_RECT_TX_EXT
+#if CONFIG_RECT_TX_EXT || (CONFIG_EXT_PARTITION_TYPES && USE_RECT_TX_EXT)
     case TX_4X16:
       above_ec = a[0] != 0;
       left_ec = !!*(const uint32_t *)l;
diff --git a/av1/common/enums.h b/av1/common/enums.h
index 8b6d615..3ab61e3 100644
--- a/av1/common/enums.h
+++ b/av1/common/enums.h
@@ -37,6 +37,10 @@
 #define MAX_SB_SIZE (1 << MAX_SB_SIZE_LOG2)
 #define MAX_SB_SQUARE (MAX_SB_SIZE * MAX_SB_SIZE)
 
+#if CONFIG_EXT_PARTITION_TYPES
+#define USE_RECT_TX_EXT 0
+#endif
+
 // Min superblock size
 #define MIN_SB_SIZE_LOG2 6
 
diff --git a/av1/common/idct.c b/av1/common/idct.c
index 7a96237..bfc348b 100644
--- a/av1/common/idct.c
+++ b/av1/common/idct.c
@@ -1866,7 +1866,7 @@
 }
 
 // These will be used by the masked-tx experiment in the future.
-#if CONFIG_RECT_TX_EXT
+#if CONFIG_RECT_TX_EXT || (CONFIG_EXT_PARTITION_TYPES && USE_RECT_TX_EXT)
 static void inv_txfm_add_4x16(const tran_low_t *input, uint8_t *dest,
                               int stride, const TxfmParam *txfm_param) {
   av1_iht4x16_64_add(input, dest, stride, txfm_param);
@@ -2375,7 +2375,7 @@
       // case.
       inv_txfm_add_4x4(input, dest, stride, txfm_param);
       break;
-#if CONFIG_RECT_TX_EXT
+#if CONFIG_RECT_TX_EXT || (CONFIG_EXT_PARTITION_TYPES && USE_RECT_TX_EXT)
     case TX_32X8: inv_txfm_add_32x8(input, dest, stride, txfm_param); break;
     case TX_8X32: inv_txfm_add_8x32(input, dest, stride, txfm_param); break;
     case TX_16X4: inv_txfm_add_16x4(input, dest, stride, txfm_param); break;
diff --git a/av1/common/reconintra.c b/av1/common/reconintra.c
index 76af119..94ce012 100644
--- a/av1/common/reconintra.c
+++ b/av1/common/reconintra.c
@@ -2429,7 +2429,7 @@
   const int block_width = block_size_wide[bsize];
   const int block_height = block_size_high[bsize];
 #if INTRA_USES_RECT_TRANSFORMS
-  const TX_SIZE tx_size = max_txsize_rect_lookup[bsize];
+  const TX_SIZE tx_size = get_max_rect_tx_size(bsize, 0);
   assert(tx_size < TX_SIZES_ALL);
 #else
   const TX_SIZE tx_size = max_txsize_lookup[bsize];
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index 2edf667..e22e1aa 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -350,31 +350,23 @@
     assert(IMPLIES(tx_size <= TX_4X4, sub_txs == tx_size));
     assert(IMPLIES(tx_size > TX_4X4, sub_txs < tx_size));
 #endif
-    const int bsl = tx_size_wide_unit[sub_txs];
-    int sub_step = tx_size_wide_unit[sub_txs] * tx_size_high_unit[sub_txs];
-    int i;
+    const int bsw = tx_size_wide_unit[sub_txs];
+    const int bsh = tx_size_high_unit[sub_txs];
+    const int sub_step = bsw * bsh;
 
-    assert(bsl > 0);
+    assert(bsw > 0 && bsh > 0);
 
-    for (i = 0; i < 4; ++i) {
-#if CONFIG_RECT_TX_EXT
-      int is_wide_tx = tx_size_wide_unit[sub_txs] > tx_size_high_unit[sub_txs];
-      const int offsetr =
-          is_qttx ? (is_wide_tx ? i * tx_size_high_unit[sub_txs] : 0)
-                  : blk_row + ((i >> 1) * bsl);
-      const int offsetc =
-          is_qttx ? (is_wide_tx ? 0 : i * tx_size_wide_unit[sub_txs])
-                  : blk_col + (i & 0x01) * bsl;
-#else
-      const int offsetr = blk_row + (i >> 1) * bsl;
-      const int offsetc = blk_col + (i & 0x01) * bsl;
-#endif
+    for (int row = 0; row < tx_size_high_unit[tx_size]; row += bsh) {
+      for (int col = 0; col < tx_size_wide_unit[tx_size]; col += bsw) {
+        const int offsetr = blk_row + row;
+        const int offsetc = blk_col + col;
 
-      if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
+        if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
 
-      decode_reconstruct_tx(cm, xd, r, mbmi, plane, plane_bsize, offsetr,
-                            offsetc, block, sub_txs, eob_total);
-      block += sub_step;
+        decode_reconstruct_tx(cm, xd, r, mbmi, plane, plane_bsize, offsetr,
+                              offsetc, block, sub_txs, eob_total);
+        block += sub_step;
+      }
     }
   }
 }
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index d885269..83376f6 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -469,8 +469,8 @@
 
   if (is_split) {
     const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
-    const int bsl = tx_size_wide_unit[sub_txs];
-    int i;
+    const int bsw = tx_size_wide_unit[sub_txs];
+    const int bsh = tx_size_high_unit[sub_txs];
 
     if (counts) ++counts->txfm_partition[ctx][1];
 
@@ -487,12 +487,14 @@
       return;
     }
 
-    assert(bsl > 0);
-    for (i = 0; i < 4; ++i) {
-      int offsetr = blk_row + (i >> 1) * bsl;
-      int offsetc = blk_col + (i & 0x01) * bsl;
-      read_tx_size_vartx(cm, xd, mbmi, counts, sub_txs, depth + 1, offsetr,
-                         offsetc, r);
+    assert(bsw > 0 && bsh > 0);
+    for (int row = 0; row < tx_size_high_unit[tx_size]; row += bsh) {
+      for (int col = 0; col < tx_size_wide_unit[tx_size]; col += bsw) {
+        int offsetr = blk_row + row;
+        int offsetc = blk_col + col;
+        read_tx_size_vartx(cm, xd, mbmi, counts, sub_txs, depth + 1, offsetr,
+                           offsetc, r);
+      }
     }
   } else {
     int idx, idy;
@@ -554,11 +556,11 @@
             quarter_tx = 1;
           }
           return quarter_tx ? quarter_txsize_lookup[bsize]
-                            : max_txsize_rect_lookup[bsize];
+                            : get_max_rect_tx_size(bsize, is_inter);
         }
 #endif  // CONFIG_RECT_TX_EXT
 
-        return max_txsize_rect_lookup[bsize];
+        return get_max_rect_tx_size(bsize, is_inter);
       }
       return coded_tx_size;
     } else {
@@ -566,7 +568,7 @@
     }
   } else {
     assert(IMPLIES(tx_mode == ONLY_4X4, bsize == BLOCK_4X4));
-    return max_txsize_rect_lookup[bsize];
+    return get_max_rect_tx_size(bsize, is_inter);
   }
 }
 
@@ -1086,7 +1088,7 @@
     int idx, idy;
     if ((cm->tx_mode == TX_MODE_SELECT && block_signals_txsize(bsize) &&
          !xd->lossless[mbmi->segment_id] && !mbmi->skip)) {
-      const TX_SIZE max_tx_size = max_txsize_rect_lookup[bsize];
+      const TX_SIZE max_tx_size = get_max_rect_tx_size(bsize, 0);
       const int bh = tx_size_high_unit[max_tx_size];
       const int bw = tx_size_wide_unit[max_tx_size];
       mbmi->min_tx_size = TX_SIZES_ALL;
@@ -2690,7 +2692,7 @@
 
   if (cm->tx_mode == TX_MODE_SELECT && block_signals_txsize(bsize) &&
       !mbmi->skip && inter_block && !xd->lossless[mbmi->segment_id]) {
-    const TX_SIZE max_tx_size = max_txsize_rect_lookup[bsize];
+    const TX_SIZE max_tx_size = get_max_rect_tx_size(bsize, inter_block);
     const int bh = tx_size_high_unit[max_tx_size];
     const int bw = tx_size_wide_unit[max_tx_size];
     const int width = block_size_wide[bsize] >> tx_size_wide_log2[0];
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 550afde..8656b28 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -279,8 +279,8 @@
     // TODO(yuec): set correct txfm partition update for qttx
   } else {
     const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
-    const int bsl = tx_size_wide_unit[sub_txs];
-    int i;
+    const int bsw = tx_size_wide_unit[sub_txs];
+    const int bsh = tx_size_high_unit[sub_txs];
 
 #if CONFIG_NEW_MULTISYMBOL
     aom_write_symbol(w, 1, ec_ctx->txfm_partition_cdf[ctx], 2);
@@ -294,13 +294,14 @@
       return;
     }
 
-    assert(bsl > 0);
-    for (i = 0; i < 4; ++i) {
-      int offsetr = blk_row + (i >> 1) * bsl;
-      int offsetc = blk_col + (i & 0x01) * bsl;
-      write_tx_size_vartx(cm, xd, mbmi, sub_txs, depth + 1, offsetr, offsetc,
-                          w);
-    }
+    assert(bsw > 0 && bsh > 0);
+    for (int row = 0; row < tx_size_high_unit[tx_size]; row += bsh)
+      for (int col = 0; col < tx_size_wide_unit[tx_size]; col += bsw) {
+        int offsetr = blk_row + row;
+        int offsetc = blk_col + col;
+        write_tx_size_vartx(cm, xd, mbmi, sub_txs, depth + 1, offsetr, offsetc,
+                            w);
+      }
   }
 }
 
@@ -654,21 +655,24 @@
 #endif
   } else {
     const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
-    const int bsl = tx_size_wide_unit[sub_txs];
-    int i;
+    const int bsw = tx_size_wide_unit[sub_txs];
+    const int bsh = tx_size_high_unit[sub_txs];
 
-    assert(bsl > 0);
+    assert(bsw > 0 && bsh > 0);
 
-    for (i = 0; i < 4; ++i) {
-      const int offsetr = blk_row + (i >> 1) * bsl;
-      const int offsetc = blk_col + (i & 0x01) * bsl;
-      const int step = tx_size_wide_unit[sub_txs] * tx_size_high_unit[sub_txs];
+    for (int r = 0; r < tx_size_high_unit[tx_size]; r += bsh) {
+      for (int c = 0; c < tx_size_wide_unit[tx_size]; c += bsw) {
+        const int offsetr = blk_row + r;
+        const int offsetc = blk_col + c;
+        const int step = bsh * bsw;
 
-      if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
+        if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
 
-      pack_txb_tokens(w, cm, x, tp, tok_end, xd, mbmi, plane, plane_bsize,
-                      bit_depth, block, offsetr, offsetc, sub_txs, token_stats);
-      block += step;
+        pack_txb_tokens(w, cm, x, tp, tok_end, xd, mbmi, plane, plane_bsize,
+                        bit_depth, block, offsetr, offsetc, sub_txs,
+                        token_stats);
+        block += step;
+      }
     }
   }
 }
@@ -716,31 +720,23 @@
 #else
     const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
 #endif
-    const int bsl = tx_size_wide_unit[sub_txs];
-    int i;
+    const int bsw = tx_size_wide_unit[sub_txs];
+    const int bsh = tx_size_high_unit[sub_txs];
 
-    assert(bsl > 0);
+    assert(bsw > 0 && bsh > 0);
 
-    for (i = 0; i < 4; ++i) {
-#if CONFIG_RECT_TX_EXT
-      int is_wide_tx = tx_size_wide_unit[sub_txs] > tx_size_high_unit[sub_txs];
-      const int offsetr =
-          is_qttx ? (is_wide_tx ? i * tx_size_high_unit[sub_txs] : 0)
-                  : blk_row + (i >> 1) * bsl;
-      const int offsetc =
-          is_qttx ? (is_wide_tx ? 0 : i * tx_size_wide_unit[sub_txs])
-                  : blk_col + (i & 0x01) * bsl;
-#else
-      const int offsetr = blk_row + (i >> 1) * bsl;
-      const int offsetc = blk_col + (i & 0x01) * bsl;
-#endif
-      const int step = tx_size_wide_unit[sub_txs] * tx_size_high_unit[sub_txs];
+    for (int r = 0; r < tx_size_high_unit[tx_size]; r += bsh) {
+      for (int c = 0; c < tx_size_wide_unit[tx_size]; c += bsw) {
+        const int offsetr = blk_row + r;
+        const int offsetc = blk_col + c;
+        const int step = bsh * bsw;
 
-      if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
+        if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
 
-      pack_txb_tokens(w, tp, tok_end, xd, mbmi, plane, plane_bsize, bit_depth,
-                      block, offsetr, offsetc, sub_txs, token_stats);
-      block += step;
+        pack_txb_tokens(w, tp, tok_end, xd, mbmi, plane, plane_bsize, bit_depth,
+                        block, offsetr, offsetc, sub_txs, token_stats);
+        block += step;
+      }
     }
   }
 }
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 8e6aac2..f0619b9 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -4626,8 +4626,8 @@
                           xd->left_txfm_context + blk_row, tx_size, tx_size);
   } else {
     const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
-    const int bs = tx_size_wide_unit[sub_txs];
-    int i;
+    const int bsw = tx_size_wide_unit[sub_txs];
+    const int bsh = tx_size_high_unit[sub_txs];
 
     ++counts->txfm_partition[ctx][1];
 #if CONFIG_NEW_MULTISYMBOL
@@ -4644,11 +4644,14 @@
       return;
     }
 
-    for (i = 0; i < 4; ++i) {
-      int offsetr = (i >> 1) * bs;
-      int offsetc = (i & 0x01) * bs;
-      update_txfm_count(x, xd, counts, sub_txs, depth + 1, blk_row + offsetr,
-                        blk_col + offsetc, allow_update_cdf);
+    for (int row = 0; row < tx_size_high_unit[tx_size]; row += bsh) {
+      for (int col = 0; col < tx_size_wide_unit[tx_size]; col += bsw) {
+        int offsetr = row;
+        int offsetc = col;
+
+        update_txfm_count(x, xd, counts, sub_txs, depth + 1, blk_row + offsetr,
+                          blk_col + offsetc, allow_update_cdf);
+      }
     }
   }
 }
@@ -4936,14 +4939,15 @@
         tx_partition_count_update(cm, x, bsize, mi_row, mi_col, td->counts,
                                   tile_data->allow_update_cdf);
       } else {
-        if (tx_size != max_txsize_rect_lookup[bsize]) ++x->txb_split_count;
+        if (tx_size != get_max_rect_tx_size(bsize, 0)) ++x->txb_split_count;
       }
 
 #if CONFIG_RECT_TX_EXT
       if (is_quarter_tx_allowed(xd, mbmi, is_inter) &&
-          quarter_txsize_lookup[bsize] != max_txsize_rect_lookup[bsize] &&
+          quarter_txsize_lookup[bsize] !=
+              get_max_rect_tx_size(bsize, is_inter) &&
           (mbmi->tx_size == quarter_txsize_lookup[bsize] ||
-           mbmi->tx_size == max_txsize_rect_lookup[bsize])) {
+           mbmi->tx_size == get_max_rect_tx_size(bsize, is_inter))) {
         const int use_qttx = mbmi->tx_size == quarter_txsize_lookup[bsize];
         ++td->counts->quarter_tx_size[use_qttx];
 #if CONFIG_NEW_MULTISYMBOL
@@ -4975,7 +4979,8 @@
             mi_8x8[mis * j + i]->mbmi.tx_size = intra_tx_size;
 
       mbmi->min_tx_size = get_min_tx_size(intra_tx_size);
-      if (intra_tx_size != max_txsize_rect_lookup[bsize]) ++x->txb_split_count;
+      if (intra_tx_size != get_max_rect_tx_size(bsize, is_inter))
+        ++x->txb_split_count;
     }
 
 #if !CONFIG_TXK_SEL
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index 1c9f3f4..a95fa6f 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -670,30 +670,22 @@
     assert(IMPLIES(tx_size > TX_4X4, sub_txs < tx_size));
 #endif
     // This is the square transform block partition entry point.
-    int bsl = tx_size_wide_unit[sub_txs];
-    int i;
-    assert(bsl > 0);
+    const int bsw = tx_size_wide_unit[sub_txs];
+    const int bsh = tx_size_high_unit[sub_txs];
+    const int step = bsh * bsw;
+    assert(bsw > 0 && bsh > 0);
 
-    for (i = 0; i < 4; ++i) {
-#if CONFIG_RECT_TX_EXT
-      int is_wide_tx = tx_size_wide_unit[sub_txs] > tx_size_high_unit[sub_txs];
-      const int offsetr =
-          is_qttx ? (is_wide_tx ? i * tx_size_high_unit[sub_txs] : 0)
-                  : blk_row + ((i >> 1) * bsl);
-      const int offsetc =
-          is_qttx ? (is_wide_tx ? 0 : i * tx_size_wide_unit[sub_txs])
-                  : blk_col + ((i & 0x01) * bsl);
-#else
-      const int offsetr = blk_row + ((i >> 1) * bsl);
-      const int offsetc = blk_col + ((i & 0x01) * bsl);
-#endif
-      int step = tx_size_wide_unit[sub_txs] * tx_size_high_unit[sub_txs];
+    for (int row = 0; row < tx_size_high_unit[tx_size]; row += bsh) {
+      for (int col = 0; col < tx_size_wide_unit[tx_size]; col += bsw) {
+        const int offsetr = blk_row + row;
+        const int offsetc = blk_col + col;
 
-      if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
+        if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
 
-      encode_block_inter(plane, block, offsetr, offsetc, plane_bsize, sub_txs,
-                         arg);
-      block += step;
+        encode_block_inter(plane, block, offsetr, offsetc, plane_bsize, sub_txs,
+                           arg);
+        block += step;
+      }
     }
   }
 }
diff --git a/av1/encoder/hybrid_fwd_txfm.c b/av1/encoder/hybrid_fwd_txfm.c
index 76b417c..fa9a41b 100644
--- a/av1/encoder/hybrid_fwd_txfm.c
+++ b/av1/encoder/hybrid_fwd_txfm.c
@@ -145,7 +145,7 @@
 }
 #endif  // CONFIG_TX64X64
 
-#if CONFIG_RECT_TX_EXT
+#if CONFIG_RECT_TX_EXT || (CONFIG_EXT_PARTITION_TYPES && USE_RECT_TX_EXT)
 static void fwd_txfm_16x4(const int16_t *src_diff, tran_low_t *coeff,
                           int diff_stride, TxfmParam *txfm_param) {
   av1_fht16x4(src_diff, coeff, diff_stride, txfm_param);
@@ -503,7 +503,7 @@
       fwd_txfm_32x16(src_diff, coeff, diff_stride, txfm_param);
       break;
     case TX_4X4: fwd_txfm_4x4(src_diff, coeff, diff_stride, txfm_param); break;
-#if CONFIG_RECT_TX_EXT
+#if CONFIG_RECT_TX_EXT || (CONFIG_EXT_PARTITION_TYPES && USE_RECT_TX_EXT)
     case TX_4X16:
       fwd_txfm_4x16(src_diff, coeff, diff_stride, txfm_param);
       break;
diff --git a/av1/encoder/rd.c b/av1/encoder/rd.c
index c001277..f74fac6 100644
--- a/av1/encoder/rd.c
+++ b/av1/encoder/rd.c
@@ -897,7 +897,7 @@
       for (i = 0; i < num_4x4_h; i += 4)
         t_left[i] = !!*(const uint32_t *)&left[i];
       break;
-#if CONFIG_RECT_TX_EXT
+#if CONFIG_RECT_TX_EXT || (CONFIG_EXT_PARTITION_TYPES && USE_RECT_TX_EXT)
     case TX_4X16:
       memcpy(t_above, above, sizeof(ENTROPY_CONTEXT) * num_4x4_w);
       for (i = 0; i < num_4x4_h; i += 4)
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index abfa631..4954c8f 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -2663,7 +2663,7 @@
     TX_TYPE tx_type;
     for (tx_type = tx_start; tx_type < tx_end; ++tx_type) {
       if (mbmi->ref_mv_idx > 0 && tx_type != DCT_DCT) continue;
-      const TX_SIZE rect_tx_size = max_txsize_rect_lookup[bs];
+      const TX_SIZE rect_tx_size = get_max_rect_tx_size(bs, is_inter);
       RD_STATS this_rd_stats;
       const TxSetType tx_set_type = get_ext_tx_set_type(
           rect_tx_size, bs, is_inter, cm->reduced_tx_set_used);
@@ -3890,7 +3890,6 @@
   int64_t this_rd = INT64_MAX;
   ENTROPY_CONTEXT *pta = ta + blk_col;
   ENTROPY_CONTEXT *ptl = tl + blk_row;
-  int i;
   int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
                                    mbmi->sb_type, tx_size);
   int64_t sum_rd = INT64_MAX;
@@ -3904,7 +3903,7 @@
 #if CONFIG_RECT_TX_EXT
   TX_SIZE quarter_txsize = quarter_txsize_lookup[mbmi->sb_type];
   int check_qttx = is_quarter_tx_allowed(xd, mbmi, is_inter_block(mbmi)) &&
-                   tx_size == max_txsize_rect_lookup[mbmi->sb_type] &&
+                   tx_size == get_max_rect_tx_size(mbmi->sb_type, 1) &&
                    quarter_txsize != tx_size;
   int is_qttx_picked = 0;
   int eobs_qttx[2] = { 0, 0 };
@@ -4102,13 +4101,14 @@
 #endif  // CONFIG_MRC_TX
       ) {
     const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
-    const int bsl = tx_size_wide_unit[sub_txs];
-    int sub_step = tx_size_wide_unit[sub_txs] * tx_size_high_unit[sub_txs];
+    const int bsw = tx_size_wide_unit[sub_txs];
+    const int bsh = tx_size_high_unit[sub_txs];
+    int sub_step = bsw * bsh;
     RD_STATS this_rd_stats;
     int this_cost_valid = 1;
     int64_t tmp_rd = 0;
 #if CONFIG_DIST_8X8
-    int sub8x8_eob[4];
+    int sub8x8_eob[4] = { 0, 0, 0, 0 };
 #endif
     sum_rd_stats.rate = x->txfm_partition_cost[ctx][1];
 
@@ -4116,29 +4116,35 @@
 
     ref_best_rd = AOMMIN(this_rd, ref_best_rd);
 
-    for (i = 0; i < 4 && this_cost_valid; ++i) {
-      int offsetr = blk_row + (i >> 1) * bsl;
-      int offsetc = blk_col + (i & 0x01) * bsl;
+    for (int r = 0; r < tx_size_high_unit[tx_size]; r += bsh) {
+      for (int c = 0; c < tx_size_wide_unit[tx_size]; c += bsw) {
+        int offsetr = blk_row + r;
+        int offsetc = blk_col + c;
 
-      if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
+        if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
 
-      select_tx_block(cpi, x, offsetr, offsetc, plane, block, sub_txs,
-                      depth + 1, plane_bsize, ta, tl, tx_above, tx_left,
-                      &this_rd_stats, ref_best_rd - tmp_rd, &this_cost_valid,
-                      fast, 0);
+        select_tx_block(cpi, x, offsetr, offsetc, plane, block, sub_txs,
+                        depth + 1, plane_bsize, ta, tl, tx_above, tx_left,
+                        &this_rd_stats, ref_best_rd - tmp_rd, &this_cost_valid,
+                        fast, 0);
+        // Stop once this_cost_valid is cleared. goto is needed because a
+        // plain break would only leave the inner (column) loop.
+        if (!this_cost_valid) goto LOOP_EXIT;
 #if CONFIG_DIST_8X8
-      if (x->using_dist_8x8 && plane == 0 && tx_size == TX_8X8) {
-        sub8x8_eob[i] = p->eobs[block];
-      }
+        if (x->using_dist_8x8 && plane == 0 && tx_size == TX_8X8) {
+          sub8x8_eob[2 * (r / bsh) + (c / bsw)] = p->eobs[block];
+        }
 #endif  // CONFIG_DIST_8X8
-      av1_merge_rd_stats(&sum_rd_stats, &this_rd_stats);
+        av1_merge_rd_stats(&sum_rd_stats, &this_rd_stats);
 
-      tmp_rd = RDCOST(x->rdmult, sum_rd_stats.rate, sum_rd_stats.dist);
+        tmp_rd = RDCOST(x->rdmult, sum_rd_stats.rate, sum_rd_stats.dist);
 #if CONFIG_DIST_8X8
-      if (!x->using_dist_8x8)
+        if (!x->using_dist_8x8)
 #endif
-        if (this_rd < tmp_rd) break;
-      block += sub_step;
+          if (this_rd < tmp_rd) goto LOOP_EXIT;
+        block += sub_step;
+      }
     }
+  LOOP_EXIT:;
 #if CONFIG_DIST_8X8
     if (x->using_dist_8x8 && this_cost_valid && plane == 0 &&
@@ -4158,7 +4164,7 @@
       const int pred_idx = (blk_row * pred_stride + blk_col)
                            << tx_size_wide_log2[0];
       int16_t *pred = &pd->pred[pred_idx];
-      int j;
+      int i, j;
       int row, col;
 
 #if CONFIG_HIGHBITDEPTH
@@ -4325,7 +4331,7 @@
     const BLOCK_SIZE plane_bsize = get_plane_block_size(bsize, pd);
     const int mi_width = block_size_wide[plane_bsize] >> tx_size_wide_log2[0];
     const int mi_height = block_size_high[plane_bsize] >> tx_size_high_log2[0];
-    const TX_SIZE max_tx_size = max_txsize_rect_lookup[plane_bsize];
+    const TX_SIZE max_tx_size = get_max_rect_tx_size(plane_bsize, 1);
     const int bh = tx_size_high_unit[max_tx_size];
     const int bw = tx_size_wide_unit[max_tx_size];
     int idx, idy;
@@ -4698,7 +4704,8 @@
   DECLARE_ALIGNED(32, tran_low_t, DCT_coefs[32 * 32]);
   TxfmParam param;
   param.tx_type = DCT_DCT;
-  param.tx_size = max_txsize_rect_lookup[bsize];
+  param.tx_size =
+      get_max_rect_tx_size(bsize, is_inter_block(&x->e_mbd.mi[0]->mbmi));
   param.bd = xd->bd;
   param.is_hbd = get_bitdepth_data_path_index(xd);
   param.lossless = 0;
@@ -4740,7 +4747,7 @@
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
   const int n4 = bsize_to_num_blk(bsize);
-  const TX_SIZE tx_size = max_txsize_rect_lookup[bsize];
+  const TX_SIZE tx_size = get_max_rect_tx_size(bsize, is_inter_block(mbmi));
   mbmi->tx_type = DCT_DCT;
   for (int idy = 0; idy < xd->n8_h; ++idy)
     for (int idx = 0; idx < xd->n8_w; ++idx)
@@ -5042,7 +5049,8 @@
       const int mi_width = block_size_wide[plane_bsize] >> tx_size_wide_log2[0];
       const int mi_height =
           block_size_high[plane_bsize] >> tx_size_high_log2[0];
-      const TX_SIZE max_tx_size = max_txsize_rect_lookup[plane_bsize];
+      const TX_SIZE max_tx_size =
+          get_max_rect_tx_size(plane_bsize, is_inter_block(mbmi));
       const int bh = tx_size_high_unit[max_tx_size];
       const int bw = tx_size_wide_unit[max_tx_size];
       int idx, idy;
@@ -5052,7 +5060,6 @@
       ENTROPY_CONTEXT tl[2 * MAX_MIB_SIZE];
       RD_STATS pn_rd_stats;
       av1_init_rd_stats(&pn_rd_stats);
-
       av1_get_entropy_contexts(bsize, 0, pd, ta, tl);
 
       for (idy = 0; idy < mi_height; idy += bh) {
diff --git a/av1/encoder/tokenize.c b/av1/encoder/tokenize.c
index ef18298..4bc36c9 100644
--- a/av1/encoder/tokenize.c
+++ b/av1/encoder/tokenize.c
@@ -612,32 +612,23 @@
     // Half the block size in transform block unit.
     const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
 #endif
-    const int bsl = tx_size_wide_unit[sub_txs];
-    int i;
+    const int bsw = tx_size_wide_unit[sub_txs];
+    const int bsh = tx_size_high_unit[sub_txs];
+    const int step = bsw * bsh;
 
-    assert(bsl > 0);
+    assert(bsw > 0 && bsh > 0);
 
-    for (i = 0; i < 4; ++i) {
-#if CONFIG_RECT_TX_EXT
-      int is_wide_tx = tx_size_wide_unit[sub_txs] > tx_size_high_unit[sub_txs];
-      const int offsetr =
-          is_qttx ? (is_wide_tx ? i * tx_size_high_unit[sub_txs] : 0)
-                  : blk_row + ((i >> 1) * bsl);
-      const int offsetc =
-          is_qttx ? (is_wide_tx ? 0 : i * tx_size_wide_unit[sub_txs])
-                  : blk_col + ((i & 0x01) * bsl);
-#else
-      const int offsetr = blk_row + ((i >> 1) * bsl);
-      const int offsetc = blk_col + ((i & 0x01) * bsl);
-#endif
+    for (int row = 0; row < tx_size_high_unit[tx_size]; row += bsh) {
+      for (int col = 0; col < tx_size_wide_unit[tx_size]; col += bsw) {
+        const int offsetr = blk_row + row;
+        const int offsetc = blk_col + col;
 
-      int step = tx_size_wide_unit[sub_txs] * tx_size_high_unit[sub_txs];
+        if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
 
-      if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
-
-      tokenize_vartx(td, t, dry_run, sub_txs, plane_bsize, offsetr, offsetc,
-                     block, plane, arg);
-      block += step;
+        tokenize_vartx(td, t, dry_run, sub_txs, plane_bsize, offsetr, offsetc,
+                       block, plane, arg);
+        block += step;
+      }
     }
   }
 }