Reduce memory usage of inter_tx_size[] in MB_MODE_INFO

Reduce the length of inter_tx_size[] from 1024 to 16.

On a cif test sequence,
encoder memory consumption decreases by 18% (380MB -> 312MB);
decoder memory consumption decreases by 56% (21.4MB -> 9.4MB).

Change-Id: I42928eb9312748f96f4393c8d8040791f38f98b6
diff --git a/av1/common/av1_loopfilter.c b/av1/common/av1_loopfilter.c
index 51100590..deab086 100644
--- a/av1/common/av1_loopfilter.c
+++ b/av1/common/av1_loopfilter.c
@@ -1419,11 +1419,8 @@
     const int col_mask = 1 << c_step;
 
     if (is_inter_block(mbmi) && !mbmi->skip) {
-      const int tx_row_idx =
-          (blk_row * mi_size_high[BLOCK_8X8] << TX_UNIT_HIGH_LOG2) >> 1;
-      const int tx_col_idx =
-          (blk_col * mi_size_wide[BLOCK_8X8] << TX_UNIT_WIDE_LOG2) >> 1;
-      const TX_SIZE mb_tx_size = mbmi->inter_tx_size[tx_row_idx][tx_col_idx];
+      const TX_SIZE mb_tx_size = mbmi->inter_tx_size[av1_get_txb_size_index(
+          sb_type, blk_row, blk_col)];
       tx_size = (plane->plane_type == PLANE_TYPE_UV)
                     ? av1_get_uv_tx_size(mbmi, ss_x, ss_y)
                     : mb_tx_size;
@@ -1990,29 +1987,13 @@
                         : av1_get_uv_tx_size(mbmi, plane_ptr->subsampling_x,
                                              plane_ptr->subsampling_y);
   assert(tx_size < TX_SIZES_ALL);
-
-  // mi_row and mi_col is the absolute position of the MI block.
-  // idx_c and idx_r is the relative offset of the MI within the super block
-  // c and r is the relative offset of the 8x8 block within the supert block
-  // blk_row and block_col is the relative offset of the current 8x8 block
-  // within the current partition.
-  const int idx_c = mi_col & MAX_MIB_MASK;
-  const int idx_r = mi_row & MAX_MIB_MASK;
-  const int c = idx_c >> mi_width_log2_lookup[BLOCK_8X8];
-  const int r = idx_r >> mi_height_log2_lookup[BLOCK_8X8];
-  const BLOCK_SIZE sb_type = mi->mbmi.sb_type;
-  const int blk_row = r & (num_8x8_blocks_high_lookup[sb_type] - 1);
-  const int blk_col = c & (num_8x8_blocks_wide_lookup[sb_type] - 1);
-
   if (is_inter_block(mbmi) && !mbmi->skip) {
-    const int tx_row_idx =
-        (blk_row * mi_size_high[BLOCK_8X8] << TX_UNIT_HIGH_LOG2) >> 1;
-    const int tx_col_idx =
-        (blk_col * mi_size_wide[BLOCK_8X8] << TX_UNIT_WIDE_LOG2) >> 1;
-    const TX_SIZE mb_tx_size = mbmi->inter_tx_size[tx_row_idx][tx_col_idx];
-
+    const BLOCK_SIZE sb_type = mi->mbmi.sb_type;
+    const int blk_row = mi_row & (mi_size_high[sb_type] - 1);
+    const int blk_col = mi_col & (mi_size_wide[sb_type] - 1);
+    const TX_SIZE mb_tx_size =
+        mbmi->inter_tx_size[av1_get_txb_size_index(sb_type, blk_row, blk_col)];
     assert(mb_tx_size < TX_SIZES_ALL);
-
     tx_size = (plane == AOM_PLANE_Y)
                   ? mb_tx_size
                   : av1_get_uv_tx_size(mbmi, plane_ptr->subsampling_x,
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index 3368eb5..9f957d3 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -231,16 +231,19 @@
   COMPOUND_TYPE interinter_compound_type;
 } INTERINTER_COMPOUND_DATA;
 
-// This structure now relates to 8x8 block regions.
+#if CONFIG_TX64X64
+#define INTER_TX_SIZE_BUF_LEN 16
+#else
+#define INTER_TX_SIZE_BUF_LEN 256
+#endif
+// This structure now relates to 4x4 block regions.
 typedef struct MB_MODE_INFO {
   // Common for both INTER and INTRA blocks
   BLOCK_SIZE sb_type;
   PREDICTION_MODE mode;
   TX_SIZE tx_size;
-  // TODO(jingning): This effectively assigned a separate entry for each
-  // 8x8 block. Apparently it takes much more space than needed.
-  TX_SIZE inter_tx_size[MAX_MIB_SIZE][MAX_MIB_SIZE];
   TX_SIZE min_tx_size;
+  uint8_t inter_tx_size[INTER_TX_SIZE_BUF_LEN];
   int8_t skip;
 #if CONFIG_EXT_SKIP
   int8_t skip_mode;
@@ -889,6 +892,20 @@
   return ss_size_lookup[bsize][pd->subsampling_x][pd->subsampling_y];
 }
 
+static INLINE int av1_get_txb_size_index(BLOCK_SIZE bsize, int blk_row,
+                                         int blk_col) {
+  TX_SIZE txs = max_txsize_rect_lookup[1][bsize];
+  for (int level = 0; level < MAX_VARTX_DEPTH - 1; ++level)
+    txs = sub_tx_size_map[1][txs];
+  const int tx_w = tx_size_wide_unit[txs];
+  const int tx_h = tx_size_high_unit[txs];
+  const int bw_uint = mi_size_wide[bsize];
+  const int stride = bw_uint / tx_w;
+  const int index = (blk_row / tx_h) * stride + (blk_col / tx_w);
+  assert(index < INTER_TX_SIZE_BUF_LEN);
+  return index;
+}
+
 static INLINE TX_TYPE av1_get_tx_type(PLANE_TYPE plane_type,
                                       const MACROBLOCKD *xd, int blk_row,
                                       int blk_col, TX_SIZE tx_size,
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index 212d953..aaeebb0 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -176,11 +176,10 @@
   (void)mi_row;
   (void)mi_col;
   const struct macroblockd_plane *const pd = &xd->plane[plane];
-  const int tx_row = blk_row >> (1 - pd->subsampling_y);
-  const int tx_col = blk_col >> (1 - pd->subsampling_x);
   const TX_SIZE plane_tx_size =
       plane ? av1_get_uv_tx_size(mbmi, pd->subsampling_x, pd->subsampling_y)
-            : mbmi->inter_tx_size[tx_row][tx_col];
+            : mbmi->inter_tx_size[av1_get_txb_size_index(plane_bsize, blk_row,
+                                                         blk_col)];
   // Scale to match transform block unit.
   const int max_blocks_high = max_block_high(xd, plane_bsize, plane);
   const int max_blocks_wide = max_block_wide(xd, plane_bsize, plane);
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index 0661451..1ef9021 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -353,22 +353,20 @@
   FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
   (void)cm;
   int is_split = 0;
-  const int tx_row = blk_row >> 1;
-  const int tx_col = blk_col >> 1;
-  const int max_blocks_high = max_block_high(xd, mbmi->sb_type, 0);
-  const int max_blocks_wide = max_block_wide(xd, mbmi->sb_type, 0);
-  TX_SIZE(*const inter_tx_size)
-  [MAX_MIB_SIZE] =
-      (TX_SIZE(*)[MAX_MIB_SIZE]) & mbmi->inter_tx_size[tx_row][tx_col];
+  const BLOCK_SIZE bsize = mbmi->sb_type;
+  const int max_blocks_high = max_block_high(xd, bsize, 0);
+  const int max_blocks_wide = max_block_wide(xd, bsize, 0);
   if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;
   assert(tx_size > TX_4X4);
 
   if (depth == MAX_VARTX_DEPTH) {
-    int idx, idy;
-    inter_tx_size[0][0] = tx_size;
-    for (idy = 0; idy < AOMMAX(1, tx_size_high_unit[tx_size] / 2); ++idy)
-      for (idx = 0; idx < AOMMAX(1, tx_size_wide_unit[tx_size] / 2); ++idx)
-        inter_tx_size[idy][idx] = tx_size;
+    for (int idy = 0; idy < tx_size_high_unit[tx_size]; ++idy) {
+      for (int idx = 0; idx < tx_size_wide_unit[tx_size]; ++idx) {
+        const int index =
+            av1_get_txb_size_index(bsize, blk_row + idy, blk_col + idx);
+        mbmi->inter_tx_size[index] = tx_size;
+      }
+    }
     mbmi->tx_size = tx_size;
     mbmi->min_tx_size = TXSIZEMIN(mbmi->min_tx_size, tx_size);
     txfm_partition_update(xd->above_txfm_context + blk_col,
@@ -376,9 +374,9 @@
     return;
   }
 
-  int ctx = txfm_partition_context(xd->above_txfm_context + blk_col,
-                                   xd->left_txfm_context + blk_row,
-                                   mbmi->sb_type, tx_size);
+  const int ctx = txfm_partition_context(xd->above_txfm_context + blk_col,
+                                         xd->left_txfm_context + blk_row,
+                                         mbmi->sb_type, tx_size);
   is_split = aom_read_symbol(r, ec_ctx->txfm_partition_cdf[ctx], 2, ACCT_STR);
 
   if (is_split) {
@@ -387,11 +385,13 @@
     const int bsh = tx_size_high_unit[sub_txs];
 
     if (sub_txs == TX_4X4) {
-      int idx, idy;
-      inter_tx_size[0][0] = sub_txs;
-      for (idy = 0; idy < AOMMAX(1, tx_size_high_unit[tx_size] / 2); ++idy)
-        for (idx = 0; idx < AOMMAX(1, tx_size_wide_unit[tx_size] / 2); ++idx)
-          inter_tx_size[idy][idx] = inter_tx_size[0][0];
+      for (int idy = 0; idy < tx_size_high_unit[tx_size]; ++idy) {
+        for (int idx = 0; idx < tx_size_wide_unit[tx_size]; ++idx) {
+          const int index =
+              av1_get_txb_size_index(bsize, blk_row + idy, blk_col + idx);
+          mbmi->inter_tx_size[index] = sub_txs;
+        }
+      }
       mbmi->tx_size = sub_txs;
       mbmi->min_tx_size = mbmi->tx_size;
       txfm_partition_update(xd->above_txfm_context + blk_col,
@@ -409,11 +409,13 @@
       }
     }
   } else {
-    int idx, idy;
-    inter_tx_size[0][0] = tx_size;
-    for (idy = 0; idy < AOMMAX(1, tx_size_high_unit[tx_size] / 2); ++idy)
-      for (idx = 0; idx < AOMMAX(1, tx_size_wide_unit[tx_size] / 2); ++idx)
-        inter_tx_size[idy][idx] = tx_size;
+    for (int idy = 0; idy < tx_size_high_unit[tx_size]; ++idy) {
+      for (int idx = 0; idx < tx_size_wide_unit[tx_size]; ++idx) {
+        const int index =
+            av1_get_txb_size_index(bsize, blk_row + idy, blk_col + idx);
+        mbmi->inter_tx_size[index] = tx_size;
+      }
+    }
     mbmi->tx_size = tx_size;
     mbmi->min_tx_size = TXSIZEMIN(mbmi->min_tx_size, tx_size);
     txfm_partition_update(xd->above_txfm_context + blk_col,
@@ -933,9 +935,7 @@
       }
     } else {
       mbmi->tx_size = read_tx_size(cm, xd, 1, !mbmi->skip, r);
-      for (int idy = 0; idy < height; ++idy)
-        for (int idx = 0; idx < width; ++idx)
-          mbmi->inter_tx_size[idy >> 1][idx >> 1] = mbmi->tx_size;
+      memset(mbmi->inter_tx_size, mbmi->tx_size, sizeof(mbmi->inter_tx_size));
       mbmi->min_tx_size = mbmi->tx_size;
       set_txfm_ctxs(mbmi->tx_size, xd->n8_w, xd->n8_h, mbmi->skip, xd);
     }
@@ -2188,14 +2188,8 @@
         read_tx_size_vartx(cm, xd, mbmi, max_tx_size, 0, idy, idx, r);
   } else {
     mbmi->tx_size = read_tx_size(cm, xd, inter_block, !mbmi->skip, r);
-
-    if (inter_block) {
-      const int width = block_size_wide[bsize] >> tx_size_wide_log2[0];
-      const int height = block_size_high[bsize] >> tx_size_high_log2[0];
-      for (int idy = 0; idy < height; ++idy)
-        for (int idx = 0; idx < width; ++idx)
-          mbmi->inter_tx_size[idy >> 1][idx >> 1] = mbmi->tx_size;
-    }
+    if (inter_block)
+      memset(mbmi->inter_tx_size, mbmi->tx_size, sizeof(mbmi->inter_tx_size));
     mbmi->min_tx_size = mbmi->tx_size;
     set_txfm_ctxs(mbmi->tx_size, xd->n8_w, xd->n8_h, mbmi->skip, xd);
   }
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index e5c4d6a..ff1d8dc 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -177,8 +177,6 @@
                                 aom_writer *w) {
   FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
   (void)cm;
-  const int tx_row = blk_row >> 1;
-  const int tx_col = blk_col >> 1;
   const int max_blocks_high = max_block_high(xd, mbmi->sb_type, 0);
   const int max_blocks_wide = max_block_wide(xd, mbmi->sb_type, 0);
 
@@ -190,12 +188,13 @@
     return;
   }
 
-  int ctx = txfm_partition_context(xd->above_txfm_context + blk_col,
-                                   xd->left_txfm_context + blk_row,
-                                   mbmi->sb_type, tx_size);
-
+  const int ctx = txfm_partition_context(xd->above_txfm_context + blk_col,
+                                         xd->left_txfm_context + blk_row,
+                                         mbmi->sb_type, tx_size);
+  const int txb_size_index =
+      av1_get_txb_size_index(mbmi->sb_type, blk_row, blk_col);
   const int write_txfm_partition =
-      tx_size == mbmi->inter_tx_size[tx_row][tx_col];
+      tx_size == mbmi->inter_tx_size[txb_size_index];
   if (write_txfm_partition) {
     aom_write_symbol(w, 0, ec_ctx->txfm_partition_cdf[ctx], 2);
 
@@ -464,8 +463,6 @@
                             int block, int blk_row, int blk_col,
                             TX_SIZE tx_size, TOKEN_STATS *token_stats) {
   const struct macroblockd_plane *const pd = &xd->plane[plane];
-  const int tx_row = blk_row >> (1 - pd->subsampling_y);
-  const int tx_col = blk_col >> (1 - pd->subsampling_x);
   const int max_blocks_high = max_block_high(xd, plane_bsize, plane);
   const int max_blocks_wide = max_block_wide(xd, plane_bsize, plane);
 
@@ -473,7 +470,8 @@
 
   const TX_SIZE plane_tx_size =
       plane ? av1_get_uv_tx_size(mbmi, pd->subsampling_x, pd->subsampling_y)
-            : mbmi->inter_tx_size[tx_row][tx_col];
+            : mbmi->inter_tx_size[av1_get_txb_size_index(plane_bsize, blk_row,
+                                                         blk_col)];
 
   if (tx_size == plane_tx_size || plane) {
     TOKEN_STATS tmp_token_stats;
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index a3e8616..3751c19 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -113,7 +113,7 @@
   TX_TYPE tx_type;
   TX_SIZE tx_size;
   TX_SIZE min_tx_size;
-  TX_SIZE inter_tx_size[MAX_MIB_SIZE][MAX_MIB_SIZE];
+  TX_SIZE inter_tx_size[INTER_TX_SIZE_BUF_LEN];
   uint8_t blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE * 8];
 #if CONFIG_TXK_SEL
   TX_TYPE txk_type[MAX_SB_SQUARE / (TX_SIZE_W_MIN * TX_SIZE_H_MIN)];
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index a060574..57f5e33 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -390,10 +390,7 @@
     mbmi->tx_size = (TX_SIZE)TXSIZEMAX(mbmi->tx_size, min_tx_size);
   }
   if (is_inter_block(mbmi)) {
-    for (int idy = 0; idy < xd->n8_h; ++idy) {
-      for (int idx = 0; idx < xd->n8_w; ++idx)
-        mbmi->inter_tx_size[idy][idx] = mbmi->tx_size;
-    }
+    memset(mbmi->inter_tx_size, mbmi->tx_size, sizeof(mbmi->inter_tx_size));
     mbmi->min_tx_size = mbmi->tx_size;
   }
 }
@@ -4691,14 +4688,14 @@
                               int blk_row, int blk_col,
                               uint8_t allow_update_cdf) {
   MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
-  const int tx_row = blk_row >> 1;
-  const int tx_col = blk_col >> 1;
-  const int max_blocks_high = max_block_high(xd, mbmi->sb_type, 0);
-  const int max_blocks_wide = max_block_wide(xd, mbmi->sb_type, 0);
+  const BLOCK_SIZE bsize = mbmi->sb_type;
+  const int max_blocks_high = max_block_high(xd, bsize, 0);
+  const int max_blocks_wide = max_block_wide(xd, bsize, 0);
   int ctx = txfm_partition_context(xd->above_txfm_context + blk_col,
                                    xd->left_txfm_context + blk_row,
                                    mbmi->sb_type, tx_size);
-  const TX_SIZE plane_tx_size = mbmi->inter_tx_size[tx_row][tx_col];
+  const int txb_size_index = av1_get_txb_size_index(bsize, blk_row, blk_col);
+  const TX_SIZE plane_tx_size = mbmi->inter_tx_size[txb_size_index];
 
   if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;
   assert(tx_size > TX_4X4);
@@ -4729,7 +4726,7 @@
     ++x->txb_split_count;
 
     if (sub_txs == TX_4X4) {
-      mbmi->inter_tx_size[tx_row][tx_col] = TX_4X4;
+      mbmi->inter_tx_size[txb_size_index] = TX_4X4;
       mbmi->tx_size = TX_4X4;
       txfm_partition_update(xd->above_txfm_context + blk_col,
                             xd->left_txfm_context + blk_row, TX_4X4, tx_size);
@@ -4774,11 +4771,11 @@
 static void set_txfm_context(MACROBLOCKD *xd, TX_SIZE tx_size, int blk_row,
                              int blk_col) {
   MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
-  const int tx_row = blk_row >> 1;
-  const int tx_col = blk_col >> 1;
-  const int max_blocks_high = max_block_high(xd, mbmi->sb_type, 0);
-  const int max_blocks_wide = max_block_wide(xd, mbmi->sb_type, 0);
-  const TX_SIZE plane_tx_size = mbmi->inter_tx_size[tx_row][tx_col];
+  const BLOCK_SIZE bsize = mbmi->sb_type;
+  const int max_blocks_high = max_block_high(xd, bsize, 0);
+  const int max_blocks_wide = max_block_wide(xd, bsize, 0);
+  const int txb_size_index = av1_get_txb_size_index(bsize, blk_row, blk_col);
+  const TX_SIZE plane_tx_size = mbmi->inter_tx_size[txb_size_index];
 
   if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;
 
@@ -4789,7 +4786,7 @@
 
   } else {
     if (tx_size == TX_8X8) {
-      mbmi->inter_tx_size[tx_row][tx_col] = TX_4X4;
+      mbmi->inter_tx_size[txb_size_index] = TX_4X4;
       mbmi->tx_size = TX_4X4;
       txfm_partition_update(xd->above_txfm_context + blk_col,
                             xd->left_txfm_context + blk_row, TX_4X4, tx_size);
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index 967f1fb..429bd2b 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -615,8 +615,6 @@
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
   const struct macroblockd_plane *const pd = &xd->plane[plane];
-  const int tx_row = blk_row >> (1 - pd->subsampling_y);
-  const int tx_col = blk_col >> (1 - pd->subsampling_x);
   const int max_blocks_high = max_block_high(xd, plane_bsize, plane);
   const int max_blocks_wide = max_block_wide(xd, plane_bsize, plane);
 
@@ -624,7 +622,8 @@
 
   const TX_SIZE plane_tx_size =
       plane ? av1_get_uv_tx_size(mbmi, pd->subsampling_x, pd->subsampling_y)
-            : mbmi->inter_tx_size[tx_row][tx_col];
+            : mbmi->inter_tx_size[av1_get_txb_size_index(plane_bsize, blk_row,
+                                                         blk_col)];
 
   if (tx_size == plane_tx_size || plane) {
     encode_block(plane, block, blk_row, blk_col, plane_bsize, tx_size, arg,
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 7f2abc2..2ef8a4b 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -3856,11 +3856,6 @@
   MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
   struct macroblock_plane *const p = &x->plane[plane];
   struct macroblockd_plane *const pd = &xd->plane[plane];
-  const int tx_row = blk_row >> (1 - pd->subsampling_y);
-  const int tx_col = blk_col >> (1 - pd->subsampling_x);
-  TX_SIZE(*const inter_tx_size)
-  [MAX_MIB_SIZE] =
-      (TX_SIZE(*)[MAX_MIB_SIZE]) & mbmi->inter_tx_size[tx_row][tx_col];
   const int max_blocks_high = max_block_high(xd, plane_bsize, plane);
   const int max_blocks_wide = max_block_wide(xd, plane_bsize, plane);
   const int bw = block_size_wide[plane_bsize] >> tx_size_wide_log2[0];
@@ -3908,7 +3903,8 @@
   rd_stats->ref_rdcost = ref_best_rd;
   rd_stats->zero_rate = zero_blk_rate;
   if (cpi->common.tx_mode == TX_MODE_SELECT || tx_size == TX_4X4) {
-    inter_tx_size[0][0] = tx_size;
+    const int index = av1_get_txb_size_index(plane_bsize, blk_row, blk_col);
+    mbmi->inter_tx_size[index] = tx_size;
     av1_tx_block_rd_b(
         cpi, x, tx_size, blk_row, blk_col, plane, block, plane_bsize, pta, ptl,
         rd_stats, fast,
@@ -4104,8 +4100,7 @@
   }
 
   if (this_rd < sum_rd) {
-    int idx, idy;
-    TX_SIZE tx_size_selected = tx_size;
+    const TX_SIZE tx_size_selected = tx_size;
 
 #if CONFIG_LV_MAP
     p->txb_entropy_ctx[block] = tmp_eob;
@@ -4117,10 +4112,14 @@
 
     txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size,
                           tx_size);
-    inter_tx_size[0][0] = tx_size_selected;
-    for (idy = 0; idy < AOMMAX(1, tx_size_high_unit[tx_size] / 2); ++idy)
-      for (idx = 0; idx < AOMMAX(1, tx_size_wide_unit[tx_size] / 2); ++idx)
-        inter_tx_size[idy][idx] = tx_size_selected;
+    for (int idy = 0; idy < tx_size_high_unit[tx_size]; ++idy) {
+      for (int idx = 0; idx < tx_size_wide_unit[tx_size]; ++idx) {
+        const int index =
+            av1_get_txb_size_index(plane_bsize, blk_row + idy, blk_col + idx);
+        mbmi->inter_tx_size[index] = tx_size_selected;
+      }
+    }
+
     mbmi->tx_size = tx_size_selected;
 #if CONFIG_TXK_SEL
     mbmi->txk_type[txk_idx] = best_tx_type;
@@ -4218,7 +4217,6 @@
   int s0 = x->skip_cost[skip_ctx][0];
   int s1 = x->skip_cost[skip_ctx][1];
   int64_t rd;
-  int row, col;
   const int max_blocks_high = max_block_high(xd, bsize, 0);
   const int max_blocks_wide = max_block_wide(xd, bsize, 0);
 
@@ -4237,11 +4235,14 @@
                          rd_info_tree);
   if (rd_stats->rate == INT_MAX) return INT64_MAX;
 
-  mbmi->min_tx_size = mbmi->inter_tx_size[0][0];
-  for (row = 0; row < max_blocks_high / 2; ++row)
-    for (col = 0; col < max_blocks_wide / 2; ++col)
+  mbmi->min_tx_size = mbmi->inter_tx_size[0];
+  for (int row = 0; row < max_blocks_high; ++row) {
+    for (int col = 0; col < max_blocks_wide; ++col) {
+      const int index = av1_get_txb_size_index(bsize, row, col);
       mbmi->min_tx_size =
-          TXSIZEMIN(mbmi->min_tx_size, mbmi->inter_tx_size[row][col]);
+          TXSIZEMIN(mbmi->min_tx_size, mbmi->inter_tx_size[index]);
+    }
+  }
 
   if (fast) {
     // Do a better (non-fast) search with tx sizes already decided.
@@ -4287,8 +4288,6 @@
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
   struct macroblockd_plane *const pd = &xd->plane[plane];
-  const int tx_row = blk_row >> (1 - pd->subsampling_y);
-  const int tx_col = blk_col >> (1 - pd->subsampling_x);
   const int max_blocks_high = max_block_high(xd, plane_bsize, plane);
   const int max_blocks_wide = max_block_wide(xd, plane_bsize, plane);
 
@@ -4298,7 +4297,8 @@
 
   const TX_SIZE plane_tx_size =
       plane ? av1_get_uv_tx_size(mbmi, pd->subsampling_x, pd->subsampling_y)
-            : mbmi->inter_tx_size[tx_row][tx_col];
+            : mbmi->inter_tx_size[av1_get_txb_size_index(plane_bsize, blk_row,
+                                                         blk_col)];
 
   int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
                                    mbmi->sb_type, tx_size);
@@ -4485,9 +4485,8 @@
   tx_rd_info->min_tx_size = mbmi->min_tx_size;
   memcpy(tx_rd_info->blk_skip, x->blk_skip[0],
          sizeof(tx_rd_info->blk_skip[0]) * n4);
-  for (int idy = 0; idy < xd->n8_h; ++idy)
-    for (int idx = 0; idx < xd->n8_w; ++idx)
-      tx_rd_info->inter_tx_size[idy][idx] = mbmi->inter_tx_size[idy][idx];
+  av1_copy(tx_rd_info->inter_tx_size, mbmi->inter_tx_size);
+
 #if CONFIG_TXK_SEL
   av1_copy(tx_rd_info->txk_type, mbmi->txk_type);
 #endif  // CONFIG_TXK_SEL
@@ -4503,9 +4502,8 @@
   mbmi->min_tx_size = tx_rd_info->min_tx_size;
   memcpy(x->blk_skip[0], tx_rd_info->blk_skip,
          sizeof(tx_rd_info->blk_skip[0]) * n4);
-  for (int idy = 0; idy < xd->n8_h; ++idy)
-    for (int idx = 0; idx < xd->n8_w; ++idx)
-      mbmi->inter_tx_size[idy][idx] = tx_rd_info->inter_tx_size[idy][idx];
+  av1_copy(mbmi->inter_tx_size, tx_rd_info->inter_tx_size);
+
 #if CONFIG_TXK_SEL
   av1_copy(mbmi->txk_type, tx_rd_info->txk_type);
 #endif  // CONFIG_TXK_SEL
@@ -4763,9 +4761,7 @@
          sizeof(mbmi->txk_type[0]) *
              (MAX_SB_SQUARE / (TX_SIZE_W_MIN * TX_SIZE_H_MIN)));
 #endif
-  for (int idy = 0; idy < xd->n8_h; ++idy)
-    for (int idx = 0; idx < xd->n8_w; ++idx)
-      mbmi->inter_tx_size[idy][idx] = tx_size;
+  memset(mbmi->inter_tx_size, tx_size, sizeof(mbmi->inter_tx_size));
   mbmi->tx_size = tx_size;
   mbmi->min_tx_size = tx_size;
   memset(x->blk_skip[0], 1, sizeof(uint8_t) * n4);
@@ -4822,7 +4818,7 @@
   int64_t best_rd = INT64_MAX;
   TX_TYPE tx_type, best_tx_type = DCT_DCT;
   const int is_inter = is_inter_block(mbmi);
-  TX_SIZE best_tx_size[MAX_MIB_SIZE][MAX_MIB_SIZE] = { { 0 } };
+  TX_SIZE best_tx_size[INTER_TX_SIZE_BUF_LEN] = { 0 };
   TX_SIZE best_tx = max_txsize_rect_lookup[1][bsize];
   TX_SIZE best_min_tx_size = TX_SIZES_ALL;
   uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE * 8];
@@ -4833,7 +4829,6 @@
   TX_TYPE txk_end = TX_TYPES;
 #endif
   const int n4 = bsize_to_num_blk(bsize);
-  int idx, idy;
   // Get the tx_size 1 level down
   const TX_SIZE min_tx_size =
       sub_tx_size_map[1][max_txsize_rect_lookup[1][bsize]];
@@ -4943,9 +4938,7 @@
       best_min_tx_size = mbmi->min_tx_size;
       memcpy(best_blk_skip, x->blk_skip[0], sizeof(best_blk_skip[0]) * n4);
       found = 1;
-      for (idy = 0; idy < xd->n8_h; ++idy)
-        for (idx = 0; idx < xd->n8_w; ++idx)
-          best_tx_size[idy][idx] = mbmi->inter_tx_size[idy][idx];
+      av1_copy(best_tx_size, mbmi->inter_tx_size);
     }
 
 #if !CONFIG_TXK_SEL
@@ -4968,9 +4961,7 @@
   // We found a candidate transform to use. Copy our results from the "best"
   // array into mbmi.
   mbmi->tx_type = best_tx_type;
-  for (idy = 0; idy < xd->n8_h; ++idy)
-    for (idx = 0; idx < xd->n8_w; ++idx)
-      mbmi->inter_tx_size[idy][idx] = best_tx_size[idy][idx];
+  av1_copy(mbmi->inter_tx_size, best_tx_size);
   mbmi->tx_size = best_tx;
   mbmi->min_tx_size = best_min_tx_size;
   memcpy(x->blk_skip[0], best_blk_skip, sizeof(best_blk_skip[0]) * n4);
@@ -4987,8 +4978,6 @@
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
   struct macroblockd_plane *const pd = &xd->plane[plane];
-  const int tx_row = blk_row >> (1 - pd->subsampling_y);
-  const int tx_col = blk_col >> (1 - pd->subsampling_x);
   const int max_blocks_high = max_block_high(xd, plane_bsize, plane);
   const int max_blocks_wide = max_block_wide(xd, plane_bsize, plane);
 
@@ -4998,7 +4987,8 @@
 
   const TX_SIZE plane_tx_size =
       plane ? av1_get_uv_tx_size(mbmi, pd->subsampling_x, pd->subsampling_y)
-            : mbmi->inter_tx_size[tx_row][tx_col];
+            : mbmi->inter_tx_size[av1_get_txb_size_index(plane_bsize, blk_row,
+                                                         blk_col)];
 
   if (tx_size == plane_tx_size || plane) {
     ENTROPY_CONTEXT *ta = above_ctx + blk_col;
@@ -7902,11 +7892,8 @@
         select_tx_type_yrd(cpi, x, rd_stats_y, bsize, mi_row, mi_col,
                            ref_best_rd);
       } else {
-        int idx, idy;
         super_block_yrd(cpi, x, rd_stats_y, bsize, ref_best_rd);
-        for (idy = 0; idy < xd->n8_h; ++idy)
-          for (idx = 0; idx < xd->n8_w; ++idx)
-            mbmi->inter_tx_size[idy][idx] = mbmi->tx_size;
+        memset(mbmi->inter_tx_size, mbmi->tx_size, sizeof(mbmi->inter_tx_size));
         memset(x->blk_skip[0], rd_stats_y->skip,
                sizeof(uint8_t) * xd->n8_h * xd->n8_w * 4);
       }
@@ -8800,11 +8787,8 @@
       // Intrabc
       select_tx_type_yrd(cpi, x, &rd_stats, bsize, mi_row, mi_col, INT64_MAX);
     } else {
-      int idx, idy;
       super_block_yrd(cpi, x, &rd_stats, bsize, INT64_MAX);
-      for (idy = 0; idy < xd->n8_h; ++idy)
-        for (idx = 0; idx < xd->n8_w; ++idx)
-          mbmi->inter_tx_size[idy][idx] = mbmi->tx_size;
+      memset(mbmi->inter_tx_size, mbmi->tx_size, sizeof(mbmi->inter_tx_size));
       memset(x->blk_skip[0], rd_stats.skip,
              sizeof(uint8_t) * xd->n8_h * xd->n8_w * 4);
     }
@@ -10386,11 +10370,8 @@
                            INT64_MAX);
         assert(rd_stats_y.rate != INT_MAX);
       } else {
-        int idx, idy;
         super_block_yrd(cpi, x, &rd_stats_y, bsize, INT64_MAX);
-        for (idy = 0; idy < xd->n8_h; ++idy)
-          for (idx = 0; idx < xd->n8_w; ++idx)
-            mbmi->inter_tx_size[idy][idx] = mbmi->tx_size;
+        memset(mbmi->inter_tx_size, mbmi->tx_size, sizeof(mbmi->inter_tx_size));
         memset(x->blk_skip[0], rd_stats_y.skip,
                sizeof(uint8_t) * xd->n8_h * xd->n8_w * 4);
       }
@@ -10417,17 +10398,12 @@
     if (RDCOST(x->rdmult, best_rate_y + best_rate_uv, rd_cost->dist) >
         RDCOST(x->rdmult, rd_stats_y.rate + rd_stats_uv.rate,
                (rd_stats_y.dist + rd_stats_uv.dist))) {
-      int idx, idy;
       best_mbmode.tx_type = mbmi->tx_type;
       best_mbmode.tx_size = mbmi->tx_size;
-      for (idy = 0; idy < xd->n8_h; ++idy)
-        for (idx = 0; idx < xd->n8_w; ++idx)
-          best_mbmode.inter_tx_size[idy][idx] = mbmi->inter_tx_size[idy][idx];
-
+      av1_copy(best_mbmode.inter_tx_size, mbmi->inter_tx_size);
       for (i = 0; i < num_planes; ++i)
         memcpy(ctx->blk_skip[i], x->blk_skip[i],
                sizeof(uint8_t) * ctx->num_4x4_blk);
-
       best_mbmode.min_tx_size = mbmi->min_tx_size;
 #if CONFIG_TXK_SEL
       av1_copy(best_mbmode.txk_type, mbmi->txk_type);
@@ -10562,13 +10538,8 @@
       best_mbmode.tx_size = block_signals_txsize(bsize)
                                 ? tx_size_from_tx_mode(bsize, cm->tx_mode, 1)
                                 : max_txsize_rect_lookup[1][bsize];
-      {
-        const int width = block_size_wide[bsize] >> tx_size_wide_log2[0];
-        const int height = block_size_high[bsize] >> tx_size_high_log2[0];
-        for (int idy = 0; idy < height; ++idy)
-          for (int idx = 0; idx < width; ++idx)
-            best_mbmode.inter_tx_size[idy >> 1][idx >> 1] = best_mbmode.tx_size;
-      }
+      memset(best_mbmode.inter_tx_size, best_mbmode.tx_size,
+             sizeof(best_mbmode.inter_tx_size));
       best_mbmode.min_tx_size = best_mbmode.tx_size;
       set_txfm_ctxs(best_mbmode.tx_size, xd->n8_w, xd->n8_h, best_mbmode.skip,
                     xd);
diff --git a/av1/encoder/tokenize.c b/av1/encoder/tokenize.c
index ad0eddb..c823a84 100644
--- a/av1/encoder/tokenize.c
+++ b/av1/encoder/tokenize.c
@@ -486,8 +486,6 @@
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
   const struct macroblockd_plane *const pd = &xd->plane[plane];
-  const int tx_row = blk_row >> (1 - pd->subsampling_y);
-  const int tx_col = blk_col >> (1 - pd->subsampling_x);
   const int max_blocks_high = max_block_high(xd, plane_bsize, plane);
   const int max_blocks_wide = max_block_wide(xd, plane_bsize, plane);
 
@@ -495,7 +493,8 @@
 
   const TX_SIZE plane_tx_size =
       plane ? av1_get_uv_tx_size(mbmi, pd->subsampling_x, pd->subsampling_y)
-            : mbmi->inter_tx_size[tx_row][tx_col];
+            : mbmi->inter_tx_size[av1_get_txb_size_index(plane_bsize, blk_row,
+                                                         blk_col)];
 
   if (tx_size == plane_tx_size || plane) {
     plane_bsize = get_plane_block_size(mbmi->sb_type, pd);