Allocate mbmi on BLOCK_8X8 level for 4K+ videos

This commit changes the allocation of cm->mi to one entry per BLOCK_8X8
when the resolution is 4K or above.

Since most of the codebase assumes that the allocation is done for each
BLOCK_4X4, we do not change how cm->mi_grid_base is allocated; i.e.,
mi_grid_base is still allocated for each BLOCK_4X4, and some of its
entries are aliased to the same mi.

Total Memory Reduction:
2.6 GB => 2.5 GB, or ~4% reduction.

BUG=aomedia:2453

Change-Id: Ib3aa59761afffd063e1cff57becaf073ff7daff1
diff --git a/av1/common/alloccommon.c b/av1/common/alloccommon.c
index f0412fc..0823db1 100644
--- a/av1/common/alloccommon.c
+++ b/av1/common/alloccommon.c
@@ -31,60 +31,6 @@
   return mb_rows * mb_cols;
 }
 
-#if CONFIG_LPF_MASK
-static int alloc_loop_filter_mask(AV1_COMMON *cm) {
-  aom_free(cm->lf.lfm);
-  cm->lf.lfm = NULL;
-
-  // Each lfm holds bit masks for all the 4x4 blocks in a max
-  // 64x64 (128x128 for ext_partitions) region.  The stride
-  // and rows are rounded up / truncated to a multiple of 16
-  // (32 for ext_partition).
-  cm->lf.lfm_stride = (cm->mi_cols + (MI_SIZE_64X64 - 1)) >> MIN_MIB_SIZE_LOG2;
-  cm->lf.lfm_num = ((cm->mi_rows + (MI_SIZE_64X64 - 1)) >> MIN_MIB_SIZE_LOG2) *
-                   cm->lf.lfm_stride;
-  cm->lf.lfm =
-      (LoopFilterMask *)aom_calloc(cm->lf.lfm_num, sizeof(*cm->lf.lfm));
-  if (!cm->lf.lfm) return 1;
-
-  unsigned int i;
-  for (i = 0; i < cm->lf.lfm_num; ++i) av1_zero(cm->lf.lfm[i]);
-
-  return 0;
-}
-
-static void free_loop_filter_mask(AV1_COMMON *cm) {
-  if (cm->lf.lfm == NULL) return;
-
-  aom_free(cm->lf.lfm);
-  cm->lf.lfm = NULL;
-  cm->lf.lfm_num = 0;
-  cm->lf.lfm_stride = 0;
-}
-#endif
-
-void av1_set_mb_mi(AV1_COMMON *cm, int width, int height) {
-  // Ensure that the decoded width and height are both multiples of
-  // 8 luma pixels (note: this may only be a multiple of 4 chroma pixels if
-  // subsampling is used).
-  // This simplifies the implementation of various experiments,
-  // eg. cdef, which operates on units of 8x8 luma pixels.
-  const int aligned_width = ALIGN_POWER_OF_TWO(width, 3);
-  const int aligned_height = ALIGN_POWER_OF_TWO(height, 3);
-
-  cm->mi_cols = aligned_width >> MI_SIZE_LOG2;
-  cm->mi_rows = aligned_height >> MI_SIZE_LOG2;
-  cm->mi_stride = calc_mi_size(cm->mi_cols);
-
-  cm->mb_cols = (cm->mi_cols + 2) >> 2;
-  cm->mb_rows = (cm->mi_rows + 2) >> 2;
-  cm->MBs = cm->mb_rows * cm->mb_cols;
-
-#if CONFIG_LPF_MASK
-  alloc_loop_filter_mask(cm);
-#endif
-}
-
 void av1_free_ref_frame_buffers(BufferPool *pool) {
   int i;
 
@@ -272,20 +218,15 @@
 }
 
 int av1_alloc_context_buffers(AV1_COMMON *cm, int width, int height) {
-  int new_mi_size;
+  cm->set_mb_mi(cm, width, height);
 
-  av1_set_mb_mi(cm, width, height);
-  new_mi_size = cm->mi_stride * calc_mi_size(cm->mi_rows);
-  if (cm->mi_alloc_size < new_mi_size) {
-    cm->free_mi(cm);
-    if (cm->alloc_mi(cm, new_mi_size)) goto fail;
-  }
+  if (cm->alloc_mi(cm)) goto fail;
 
   return 0;
 
 fail:
   // clear the mi_* values to force a realloc on resync
-  av1_set_mb_mi(cm, 0, 0);
+  cm->set_mb_mi(cm, 0, 0);
   av1_free_context_buffers(cm);
   return 1;
 }
diff --git a/av1/common/alloccommon.h b/av1/common/alloccommon.h
index 8e58969..6f487ad 100644
--- a/av1/common/alloccommon.h
+++ b/av1/common/alloccommon.h
@@ -14,6 +14,8 @@
 
 #define INVALID_IDX -1  // Invalid buffer index.
 
+#include "config/aom_config.h"
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -38,9 +40,40 @@
 int av1_alloc_state_buffers(struct AV1Common *cm, int width, int height);
 void av1_free_state_buffers(struct AV1Common *cm);
 
-void av1_set_mb_mi(struct AV1Common *cm, int width, int height);
 int av1_get_MBs(int width, int height);
 
+#if CONFIG_LPF_MASK
+static INLINE int alloc_loop_filter_mask(AV1_COMMON *cm) {
+  aom_free(cm->lf.lfm);
+  cm->lf.lfm = NULL;
+
+  // Each lfm holds bit masks for all the 4x4 blocks in a max
+  // 64x64 (128x128 for ext_partitions) region.  The stride
+  // and rows are rounded up / truncated to a multiple of 16
+  // (32 for ext_partition).
+  cm->lf.lfm_stride = (cm->mi_cols + (MI_SIZE_64X64 - 1)) >> MIN_MIB_SIZE_LOG2;
+  cm->lf.lfm_num = ((cm->mi_rows + (MI_SIZE_64X64 - 1)) >> MIN_MIB_SIZE_LOG2) *
+                   cm->lf.lfm_stride;
+  cm->lf.lfm =
+      (LoopFilterMask *)aom_calloc(cm->lf.lfm_num, sizeof(*cm->lf.lfm));
+  if (!cm->lf.lfm) return 1;
+
+  unsigned int i;
+  for (i = 0; i < cm->lf.lfm_num; ++i) av1_zero(cm->lf.lfm[i]);
+
+  return 0;
+}
+
+static INLINE void free_loop_filter_mask(AV1_COMMON *cm) {
+  if (cm->lf.lfm == NULL) return;
+
+  aom_free(cm->lf.lfm);
+  cm->lf.lfm = NULL;
+  cm->lf.lfm_num = 0;
+  cm->lf.lfm_stride = 0;
+}
+#endif
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
diff --git a/av1/common/onyxc_int.h b/av1/common/onyxc_int.h
index 9347e0c..38b76cd 100644
--- a/av1/common/onyxc_int.h
+++ b/av1/common/onyxc_int.h
@@ -427,13 +427,21 @@
 
   /* We allocate a MB_MODE_INFO struct for each macroblock, together with
      an extra row on top and column on the left to simplify prediction. */
-  int mi_alloc_size;
+  int mi_alloc_size, mi_grid_size;
   MB_MODE_INFO *mi;  /* Corresponds to upper left visible macroblock */
 
+  // The minimum size each allocated mi can correspond to.
+  // For decoder, this is always BLOCK_4X4.
+  // For encoder, this is currently set to BLOCK_4X4 for resolution below 4k,
+  // and BLOCK_8X8 for resolution above 4k
+  BLOCK_SIZE mi_alloc_bsize;
+  int mi_alloc_rows, mi_alloc_cols, mi_alloc_stride;
+
   // Separate mi functions between encoder and decoder.
-  int (*alloc_mi)(struct AV1Common *cm, int mi_size);
+  int (*alloc_mi)(struct AV1Common *cm);
   void (*free_mi)(struct AV1Common *cm);
   void (*setup_mi)(struct AV1Common *cm);
+  void (*set_mb_mi)(struct AV1Common *cm, int width, int height);
 
   // Grid of pointers to 4x4 MB_MODE_INFO structs. Any 4x4 not in the visible
   // area will be NULL.
@@ -1186,6 +1194,29 @@
   set_txfm_ctx(xd->left_txfm_context, bh, n4_h);
 }
 
+static INLINE int get_mi_grid_idx(const AV1_COMMON *cm, int mi_row,
+                                  int mi_col) {
+  return mi_row * cm->mi_stride + mi_col;
+}
+
+static INLINE int get_alloc_mi_idx(const AV1_COMMON *cm, int mi_row,
+                                   int mi_col) {
+  const int mi_alloc_size_1d = mi_size_wide[cm->mi_alloc_bsize];
+  const int mi_alloc_row = mi_row / mi_alloc_size_1d;
+  const int mi_alloc_col = mi_col / mi_alloc_size_1d;
+
+  return mi_alloc_row * cm->mi_alloc_stride + mi_alloc_col;
+}
+
+static INLINE int get_mi_ext_idx(const AV1_COMMON *cm, int mi_row,
+                                 int mi_col) {
+  const int mi_alloc_size_1d = mi_size_wide[cm->mi_alloc_bsize];
+  const int mi_alloc_row = mi_row / mi_alloc_size_1d;
+  const int mi_alloc_col = mi_col / mi_alloc_size_1d;
+
+  return mi_alloc_row * cm->mi_alloc_cols + mi_alloc_col;
+}
+
 static INLINE void txfm_partition_update(TXFM_CONTEXT *above_ctx,
                                          TXFM_CONTEXT *left_ctx,
                                          TX_SIZE tx_size, TX_SIZE txb_size) {
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index ac304c5..7584679 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -340,11 +340,10 @@
                         int bh, int x_mis, int y_mis) {
   const int num_planes = av1_num_planes(cm);
 
-  const int offset = mi_row * cm->mi_stride + mi_col;
   const TileInfo *const tile = &xd->tile;
 
-  xd->mi = cm->mi_grid_base + offset;
-  xd->mi[0] = &cm->mi[offset];
+  xd->mi = cm->mi_grid_base + get_mi_grid_idx(cm, mi_row, mi_col);
+  xd->mi[0] = &cm->mi[get_alloc_mi_idx(cm, mi_row, mi_col)];
   // TODO(slavarnway): Generate sb_type based on bwl and bhl, instead of
   // passing bsize from decode_partition().
   xd->mi[0]->sb_type = bsize;
@@ -2153,7 +2152,7 @@
                            "Failed to allocate context buffers");
       }
     } else {
-      av1_set_mb_mi(cm, width, height);
+      cm->set_mb_mi(cm, width, height);
     }
     av1_init_context_buffers(cm);
     cm->width = width;
diff --git a/av1/decoder/decoder.c b/av1/decoder/decoder.c
index 50eace3..1c63ec4 100644
--- a/av1/decoder/decoder.c
+++ b/av1/decoder/decoder.c
@@ -45,19 +45,56 @@
   av1_init_wedge_masks();
 }
 
-static void dec_setup_mi(AV1_COMMON *cm) {
-  cm->mi_grid_base = cm->mi_grid_base;
-  memset(cm->mi_grid_base, 0,
-         cm->mi_stride * cm->mi_rows * sizeof(*cm->mi_grid_base));
+static void dec_set_mb_mi(AV1_COMMON *cm, int width, int height) {
+  // Ensure that the decoded width and height are both multiples of
+  // 8 luma pixels (note: this may only be a multiple of 4 chroma pixels if
+  // subsampling is used).
+  // This simplifies the implementation of various experiments,
+  // eg. cdef, which operates on units of 8x8 luma pixels.
+  const int aligned_width = ALIGN_POWER_OF_TWO(width, 3);
+  const int aligned_height = ALIGN_POWER_OF_TWO(height, 3);
+
+  cm->mi_cols = aligned_width >> MI_SIZE_LOG2;
+  cm->mi_rows = aligned_height >> MI_SIZE_LOG2;
+  cm->mi_stride = calc_mi_size(cm->mi_cols);
+
+  cm->mb_cols = (cm->mi_cols + 2) >> 2;
+  cm->mb_rows = (cm->mi_rows + 2) >> 2;
+  cm->MBs = cm->mb_rows * cm->mb_cols;
+
+  cm->mi_alloc_bsize = BLOCK_4X4;
+  cm->mi_alloc_rows = cm->mi_rows;
+  cm->mi_alloc_cols = cm->mi_cols;
+  cm->mi_alloc_stride = cm->mi_stride;
+
+  assert(mi_size_wide[cm->mi_alloc_bsize] == mi_size_high[cm->mi_alloc_bsize]);
+
+#if CONFIG_LPF_MASK
+  alloc_loop_filter_mask(cm);
+#endif
 }
 
-static int dec_alloc_mi(AV1_COMMON *cm, int mi_size) {
-  cm->mi = aom_calloc(mi_size, sizeof(*cm->mi));
-  if (!cm->mi) return 1;
-  cm->mi_alloc_size = mi_size;
-  cm->mi_grid_base =
-      (MB_MODE_INFO **)aom_calloc(mi_size, sizeof(MB_MODE_INFO *));
-  if (!cm->mi_grid_base) return 1;
+static void dec_setup_mi(AV1_COMMON *cm) {
+  const int mi_grid_size = cm->mi_stride * calc_mi_size(cm->mi_rows);
+  memset(cm->mi_grid_base, 0, mi_grid_size * sizeof(*cm->mi_grid_base));
+}
+
+static int dec_alloc_mi(AV1_COMMON *cm) {
+  const int mi_grid_size = cm->mi_stride * calc_mi_size(cm->mi_rows);
+
+  if (cm->mi_alloc_size < mi_grid_size || cm->mi_grid_size < mi_grid_size) {
+    cm->free_mi(cm);
+
+    cm->mi = aom_calloc(mi_grid_size, sizeof(*cm->mi));
+    if (!cm->mi) return 1;
+    cm->mi_alloc_size = mi_grid_size;
+
+    cm->mi_grid_base =
+        (MB_MODE_INFO **)aom_calloc(mi_grid_size, sizeof(MB_MODE_INFO *));
+    if (!cm->mi_grid_base) return 1;
+    cm->mi_grid_size = mi_grid_size;
+  }
+
   return 0;
 }
 
@@ -113,6 +150,7 @@
   cm->alloc_mi = dec_alloc_mi;
   cm->free_mi = dec_free_mi;
   cm->setup_mi = dec_setup_mi;
+  cm->set_mb_mi = dec_set_mb_mi;
 
   av1_loop_filter_init(cm);
 
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 818c4d6..b13d16f 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1321,11 +1321,8 @@
   AV1_COMMON *const cm = &cpi->common;
   const MB_MODE_INFO *const *mbmi =
       *(cm->mi_grid_base + (mi_row * cm->mi_stride + mi_col));
-  const int mi_alloc_size_1d = cpi->mi_alloc_size_1d;
-  const int mi_alloc_row = (mi_row + mi_alloc_size_1d - 1) / mi_alloc_size_1d;
-  const int mi_alloc_col = (mi_col + mi_alloc_size_1d - 1) / mi_alloc_size_1d;
   const MB_MODE_INFO_EXT *const *mbmi_ext =
-      cpi->mbmi_ext_base + (mi_alloc_row * cpi->mi_alloc_cols + mi_alloc_col);
+      cpi->mbmi_ext_base + get_mi_ext_idx(cm, mi_row, mi_col);
   if (is_inter_block(mbmi)) {
 #define FRAME_TO_CHECK 11
     if (cm->current_frame.frame_number == FRAME_TO_CHECK &&
@@ -1511,12 +1508,7 @@
   const AV1_COMMON *cm = &cpi->common;
   MACROBLOCKD *xd = &cpi->td.mb.e_mbd;
   xd->mi = cm->mi_grid_base + (mi_row * cm->mi_stride + mi_col);
-
-  const int mi_alloc_size_1d = cpi->mi_alloc_size_1d;
-  const int mi_alloc_row = (mi_row + mi_alloc_size_1d - 1) / mi_alloc_size_1d;
-  const int mi_alloc_col = (mi_col + mi_alloc_size_1d - 1) / mi_alloc_size_1d;
-  cpi->td.mb.mbmi_ext =
-      cpi->mbmi_ext_base + (mi_alloc_row * cpi->mi_alloc_cols + mi_alloc_col);
+  cpi->td.mb.mbmi_ext = cpi->mbmi_ext_base + get_mi_ext_idx(cm, mi_row, mi_col);
 
   const MB_MODE_INFO *mbmi = xd->mi[0];
   const BLOCK_SIZE bsize = mbmi->sb_type;
diff --git a/av1/encoder/encode_strategy.c b/av1/encoder/encode_strategy.c
index 06e575f..6f686a4 100644
--- a/av1/encoder/encode_strategy.c
+++ b/av1/encoder/encode_strategy.c
@@ -1046,8 +1046,8 @@
     av1_init_context_buffers(cm);
     setup_mi(cpi, frame_input->source);
     av1_init_macroblockd(cm, xd, NULL);
-    const int alloc_mi_size = cpi->mi_alloc_rows * cpi->mi_alloc_cols;
-    memset(cpi->mbmi_ext_base, 0, alloc_mi_size * sizeof(*cpi->mbmi_ext_base));
+    const int ext_mi_size = cm->mi_alloc_rows * cm->mi_alloc_cols;
+    memset(cpi->mbmi_ext_base, 0, ext_mi_size * sizeof(*cpi->mbmi_ext_base));
 
     av1_set_speed_features_framesize_independent(cpi, oxcf->speed);
     av1_set_speed_features_framesize_dependent(cpi, oxcf->speed);
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index dfe5d77..630ab96 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -1685,10 +1685,13 @@
   const int mi_rows_remaining = tile->mi_row_end - mi_row;
   const int mi_cols_remaining = tile->mi_col_end - mi_col;
   int block_row, block_col;
-  MB_MODE_INFO *const mi_upper_left = cm->mi + mi_row * cm->mi_stride + mi_col;
+  MB_MODE_INFO *const mi_upper_left =
+      cm->mi + get_alloc_mi_idx(cm, mi_row, mi_col);
   int bh = mi_size_high[bsize];
   int bw = mi_size_wide[bsize];
 
+  assert(bsize >= cm->mi_alloc_bsize &&
+         "Attempted to use bsize < cm->mi_alloc_bsize");
   assert((mi_rows_remaining > 0) && (mi_cols_remaining > 0));
 
   // Apply the requested partition size to the SB if it is all "in image"
@@ -1697,9 +1700,10 @@
     for (block_row = 0; block_row < cm->seq_params.mib_size; block_row += bh) {
       for (block_col = 0; block_col < cm->seq_params.mib_size;
            block_col += bw) {
-        int index = block_row * cm->mi_stride + block_col;
-        mib[index] = mi_upper_left + index;
-        mib[index]->sb_type = bsize;
+        const int grid_index = get_mi_grid_idx(cm, block_row, block_col);
+        const int mi_index = get_alloc_mi_idx(cm, block_row, block_col);
+        mib[grid_index] = mi_upper_left + mi_index;
+        mib[grid_index]->sb_type = bsize;
       }
     }
   } else {
@@ -3670,19 +3674,17 @@
     const int frame_lf_count =
         av1_num_planes(cm) > 1 ? FRAME_LF_COUNT : FRAME_LF_COUNT - 2;
     const int mib_size = cm->seq_params.mib_size;
-    const int mi_stide = cm->mi_stride;
-    int mi_index_base = mi_row * mi_stide + mi_col;
 
     // pre-set the delta lf for loop filter. Note that this value is set
     // before mi is assigned for each block in current superblock
     for (int j = 0; j < AOMMIN(mib_size, cm->mi_rows - mi_row); j++) {
       for (int k = 0; k < AOMMIN(mib_size, cm->mi_cols - mi_col); k++) {
-        cm->mi[mi_index_base + k].delta_lf_from_base = delta_lf;
+        const int mi_idx = get_alloc_mi_idx(cm, mi_row + j, mi_col + k);
+        cm->mi[mi_idx].delta_lf_from_base = delta_lf;
         for (int lf_id = 0; lf_id < frame_lf_count; ++lf_id) {
-          cm->mi[mi_index_base + k].delta_lf[lf_id] = delta_lf;
+          cm->mi[mi_idx].delta_lf[lf_id] = delta_lf;
         }
       }
-      mi_index_base += mi_stide;
     }
   }
 }
@@ -4010,8 +4012,7 @@
 
     td->mb.cb_coef_buff = av1_get_cb_coeff_buffer(cpi, mi_row, mi_col);
 
-    const int idx_str = cm->mi_stride * mi_row + mi_col;
-    MB_MODE_INFO **mi = cm->mi_grid_base + idx_str;
+    MB_MODE_INFO **mi = cm->mi_grid_base + get_mi_grid_idx(cm, mi_row, mi_col);
     x->source_variance = UINT_MAX;
     x->simple_motion_pred_sse = UINT_MAX;
     const struct segmentation *const seg = &cm->seg;
@@ -4569,7 +4570,7 @@
     av1_generate_block_2x2_hash_value(cpi->source, block_hash_values[0],
                                       is_block_same[0], &cpi->td.mb);
     const int max_size = 128, min_size = 4;
-    const int min_alloc_size = block_size_wide[cpi->mi_alloc_bsize];
+    const int min_alloc_size = block_size_wide[cm->mi_alloc_bsize];
     int src_idx = 0;
     for (int size = min_size; size <= max_size; size *= 2, src_idx = !src_idx) {
       const int dst_idx = !src_idx;
diff --git a/av1/encoder/encoder.c b/av1/encoder/encoder.c
index f59bd34..8449756 100644
--- a/av1/encoder/encoder.c
+++ b/av1/encoder/encoder.c
@@ -13,7 +13,6 @@
 #include <math.h>
 #include <stdio.h>
 
-#include "av1/common/enums.h"
 #include "config/aom_config.h"
 #include "config/aom_dsp_rtcd.h"
 #include "config/aom_scale_rtcd.h"
@@ -378,22 +377,64 @@
   cpi->vaq_refresh = 0;
 }
 
-static void enc_setup_mi(AV1_COMMON *cm) {
-  int mi_rows_sb_aligned = calc_mi_size(cm->mi_rows);
-  memset(cm->mi, 0, cm->mi_stride * mi_rows_sb_aligned * sizeof(*cm->mi));
+static void enc_set_mb_mi(AV1_COMMON *cm, int width, int height) {
+  // Ensure that the decoded width and height are both multiples of
+  // 8 luma pixels (note: this may only be a multiple of 4 chroma pixels if
+  // subsampling is used).
+  // This simplifies the implementation of various experiments,
+  // eg. cdef, which operates on units of 8x8 luma pixels.
+  const int aligned_width = ALIGN_POWER_OF_TWO(width, 3);
+  const int aligned_height = ALIGN_POWER_OF_TWO(height, 3);
 
-  memset(cm->mi_grid_base, 0,
-         cm->mi_stride * mi_rows_sb_aligned * sizeof(*cm->mi_grid_base));
+  cm->mi_cols = aligned_width >> MI_SIZE_LOG2;
+  cm->mi_rows = aligned_height >> MI_SIZE_LOG2;
+  cm->mi_stride = calc_mi_size(cm->mi_cols);
+
+  cm->mb_cols = (cm->mi_cols + 2) >> 2;
+  cm->mb_rows = (cm->mi_rows + 2) >> 2;
+  cm->MBs = cm->mb_rows * cm->mb_cols;
+
+  const int is_4k_or_larger = AOMMIN(width, height) >= 2160;
+
+  cm->mi_alloc_bsize = is_4k_or_larger ? BLOCK_8X8 : BLOCK_4X4;
+  const int mi_alloc_size_1d = mi_size_wide[cm->mi_alloc_bsize];
+  cm->mi_alloc_rows = (cm->mi_rows + mi_alloc_size_1d - 1) / mi_alloc_size_1d;
+  cm->mi_alloc_cols = (cm->mi_cols + mi_alloc_size_1d - 1) / mi_alloc_size_1d;
+  cm->mi_alloc_stride =
+      (cm->mi_stride + mi_alloc_size_1d - 1) / mi_alloc_size_1d;
+
+  assert(mi_size_wide[cm->mi_alloc_bsize] == mi_size_high[cm->mi_alloc_bsize]);
+
+#if CONFIG_LPF_MASK
+  alloc_loop_filter_mask(cm);
+#endif
 }
 
-static int enc_alloc_mi(AV1_COMMON *cm, int mi_size) {
-  cm->mi = aom_calloc(mi_size, sizeof(*cm->mi));
-  if (!cm->mi) return 1;
-  cm->mi_alloc_size = mi_size;
+static void enc_setup_mi(AV1_COMMON *cm) {
+  const int mi_grid_size = cm->mi_stride * calc_mi_size(cm->mi_rows);
+  memset(cm->mi, 0, cm->mi_alloc_size * sizeof(*cm->mi));
 
-  cm->mi_grid_base =
-      (MB_MODE_INFO **)aom_calloc(mi_size, sizeof(MB_MODE_INFO *));
-  if (!cm->mi_grid_base) return 1;
+  memset(cm->mi_grid_base, 0, mi_grid_size * sizeof(*cm->mi_grid_base));
+}
+
+static int enc_alloc_mi(AV1_COMMON *cm) {
+  const int mi_grid_size = cm->mi_stride * calc_mi_size(cm->mi_rows);
+  const int alloc_size_1d = mi_size_wide[cm->mi_alloc_bsize];
+  const int alloc_mi_size =
+      cm->mi_alloc_stride * (calc_mi_size(cm->mi_rows) / alloc_size_1d);
+
+  if (cm->mi_alloc_size < alloc_mi_size || cm->mi_grid_size < mi_grid_size) {
+    cm->free_mi(cm);
+
+    cm->mi = aom_calloc(alloc_mi_size, sizeof(*cm->mi));
+    if (!cm->mi) return 1;
+    cm->mi_alloc_size = alloc_mi_size;
+
+    cm->mi_grid_base =
+        (MB_MODE_INFO **)aom_calloc(mi_grid_size, sizeof(MB_MODE_INFO *));
+    if (!cm->mi_grid_base) return 1;
+    cm->mi_grid_size = mi_grid_size;
+  }
 
   return 0;
 }
@@ -425,23 +466,14 @@
 
 static void alloc_context_buffers_ext(AV1_COMP *cpi) {
   AV1_COMMON *cm = &cpi->common;
-  const int is_4k_or_larger = AOMMIN(cm->width, cm->height) >= 2160;
+  const int new_ext_mi_size = cm->mi_alloc_rows * cm->mi_alloc_cols;
 
-  cpi->mi_alloc_bsize = is_4k_or_larger ? BLOCK_8X8 : BLOCK_4X4;
-  cpi->mi_alloc_size_1d = mi_size_wide[cpi->mi_alloc_bsize];
-  cpi->mi_alloc_rows =
-      (cm->mi_rows + cpi->mi_alloc_size_1d - 1) / cpi->mi_alloc_size_1d;
-  cpi->mi_alloc_cols =
-      (cm->mi_cols + cpi->mi_alloc_size_1d - 1) / cpi->mi_alloc_size_1d;
-
-  assert(mi_size_wide[cpi->mi_alloc_bsize] ==
-         mi_size_high[cpi->mi_alloc_bsize]);
-
-  const int alloc_mi_size = cpi->mi_alloc_rows * cpi->mi_alloc_cols;
-
-  dealloc_context_buffers_ext(cpi);
-  CHECK_MEM_ERROR(cm, cpi->mbmi_ext_base,
-                  aom_calloc(alloc_mi_size, sizeof(*cpi->mbmi_ext_base)));
+  if (new_ext_mi_size > cpi->mi_ext_alloc_size) {
+    dealloc_context_buffers_ext(cpi);
+    CHECK_MEM_ERROR(cm, cpi->mbmi_ext_base,
+                    aom_calloc(new_ext_mi_size, sizeof(*cpi->mbmi_ext_base)));
+    cpi->mi_ext_alloc_size = new_ext_mi_size;
+  }
 }
 
 static void reset_film_grain_chroma_params(aom_film_grain_t *pars) {
@@ -843,7 +875,10 @@
   AV1_COMMON *cm = &cpi->common;
   const int num_planes = av1_num_planes(cm);
 
-  av1_alloc_context_buffers(cm, cm->width, cm->height);
+  if (av1_alloc_context_buffers(cm, cm->width, cm->height)) {
+    aom_internal_error(&cm->error, AOM_CODEC_MEM_ERROR,
+                       "Failed to allocate context buffers");
+  }
 
   int mi_rows_aligned_to_sb =
       ALIGN_POWER_OF_TWO(cm->mi_rows, cm->seq_params.mib_size_log2);
@@ -942,12 +977,18 @@
   AV1_COMMON *const cm = &cpi->common;
   MACROBLOCKD *const xd = &cpi->td.mb.e_mbd;
 
-  av1_set_mb_mi(cm, cm->width, cm->height);
+  // We need to reallocate the context buffers here in case we need more mis.
+  if (av1_alloc_context_buffers(cm, cm->width, cm->height)) {
+    aom_internal_error(&cm->error, AOM_CODEC_MEM_ERROR,
+                       "Failed to allocate context buffers");
+  }
   av1_init_context_buffers(cm);
+
   av1_init_macroblockd(cm, xd, NULL);
 
-  const int alloc_mi_size = cpi->mi_alloc_rows * cpi->mi_alloc_cols;
-  memset(cpi->mbmi_ext_base, 0, alloc_mi_size * sizeof(*cpi->mbmi_ext_base));
+  const int ext_mi_size = cm->mi_alloc_rows * cm->mi_alloc_cols;
+  alloc_context_buffers_ext(cpi);
+  memset(cpi->mbmi_ext_base, 0, ext_mi_size * sizeof(*cpi->mbmi_ext_base));
   set_tile_info(cpi);
 }
 
@@ -2637,6 +2678,9 @@
   cm->alloc_mi = enc_alloc_mi;
   cm->free_mi = enc_free_mi;
   cm->setup_mi = enc_setup_mi;
+  cm->set_mb_mi = enc_set_mb_mi;
+
+  cm->mi_alloc_bsize = BLOCK_4X4;
 
   CHECK_MEM_ERROR(cm, cm->fc,
                   (FRAME_CONTEXT *)aom_memalign(32, sizeof(*cm->fc)));
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index de153e2..87f70de 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -753,13 +753,8 @@
   struct lookahead_entry *alt_ref_source;
   int no_show_kf;
 
-  // The minimum size each allocateed mi_ext can correspond to. Currently set to
-  // BLOCK_4X4 for resolution below 4k, and BLOCK_8X8 for resolution above 4k
-  BLOCK_SIZE mi_alloc_bsize;
-  int mi_alloc_size_1d;  // Number of 4x4 blocks in an allocated mi_ext
-  int mi_alloc_rows, mi_alloc_cols;
-
   int optimize_seg_arr[MAX_SEGMENTS];
+  int mi_ext_alloc_size;
 
   YV12_BUFFER_CONFIG *source;
   YV12_BUFFER_CONFIG *last_source;  // NULL for first frame and alt_ref frames
@@ -1363,15 +1358,14 @@
                                          MACROBLOCKD *const xd, int mi_row,
                                          int mi_col) {
   const AV1_COMMON *const cm = &cpi->common;
-  const int idx_str = xd->mi_stride * mi_row + mi_col;
-  xd->mi = cm->mi_grid_base + idx_str;
-  xd->mi[0] = cm->mi + idx_str;
+  const int grid_idx = get_mi_grid_idx(cm, mi_row, mi_col);
+  const int mi_idx = get_alloc_mi_idx(cm, mi_row, mi_col);
+  const int ext_idx = get_mi_ext_idx(cm, mi_row, mi_col);
 
-  const int mi_alloc_size_1d = cpi->mi_alloc_size_1d;
-  const int mi_alloc_row = (mi_row + mi_alloc_size_1d - 1) / mi_alloc_size_1d;
-  const int mi_alloc_col = (mi_col + mi_alloc_size_1d - 1) / mi_alloc_size_1d;
-  x->mbmi_ext =
-      cpi->mbmi_ext_base + (mi_alloc_row * cpi->mi_alloc_cols + mi_alloc_col);
+  xd->mi = cm->mi_grid_base + grid_idx;
+  xd->mi[0] = cm->mi + mi_idx;
+
+  x->mbmi_ext = cpi->mbmi_ext_base + ext_idx;
 }
 
 // Check to see if the given partition size is allowed for a specified number
diff --git a/av1/encoder/firstpass.c b/av1/encoder/firstpass.c
index 9eda3ee..3d3810f 100644
--- a/av1/encoder/firstpass.c
+++ b/av1/encoder/firstpass.c
@@ -463,9 +463,12 @@
 
       aom_clear_system_state();
 
-      const int idx_str = xd->mi_stride * mb_row * mb_scale + mb_col * mb_scale;
-      xd->mi = cm->mi_grid_base + idx_str;
-      xd->mi[0] = cm->mi + idx_str;
+      const int grid_idx =
+          get_mi_grid_idx(cm, mb_row * mb_scale, mb_col * mb_scale);
+      const int mi_idx =
+          get_alloc_mi_idx(cm, mb_row * mb_scale, mb_col * mb_scale);
+      xd->mi = cm->mi_grid_base + grid_idx;
+      xd->mi[0] = cm->mi + mi_idx;
       xd->plane[0].dst.buf = new_yv12->y_buffer + recon_yoffset;
       xd->plane[1].dst.buf = new_yv12->u_buffer + recon_uvoffset;
       xd->plane[2].dst.buf = new_yv12->v_buffer + recon_uvoffset;