Allow first pass to encode with other square block sizes

Traditionally, a block size of 16x16 has been hard-coded for first pass
encoding.

This change removes the hard-coded constraint and allows the first pass
to use other square block sizes: 4x4, 8x8 and 16x16.

Changes include:
* Remove the FP_MIB_SIZE and FP_MIB_SIZE_LOG2 macros that hard-code the
  16x16 block size.
* Add a first pass block size field (fp_block_size) to cpi, so that the
  multi-threaded code is aware of the unit size.
* Add auxiliary functions to get size info, and pass the block size as
  an input parameter so that various block sizes can be supported.
* Other miscellaneous changes to make this work.

Change-Id: I302aa805736d3c6c26273ac064ef94247b16c364
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index b4bc58d..7b85348 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -2620,6 +2620,11 @@
    * Number of frames left to be encoded, is 0 if limit is not set.
    */
   int frames_left;
+
+  /*!
+   * Block size of first pass encoding
+   */
+  BLOCK_SIZE fp_block_size;
 } AV1_COMP;
 
 /*!
diff --git a/av1/encoder/ethread.c b/av1/encoder/ethread.c
index 730eedb..43776f9 100644
--- a/av1/encoder/ethread.c
+++ b/av1/encoder/ethread.c
@@ -299,7 +299,8 @@
 
 static AOM_INLINE void switch_tile_and_get_next_job(
     AV1_COMMON *const cm, TileDataEnc *const tile_data, int *cur_tile_id,
-    int *current_mi_row, int *end_of_frame, int is_firstpass) {
+    int *current_mi_row, int *end_of_frame, int is_firstpass,
+    const BLOCK_SIZE fp_block_size) {
   const int tile_cols = cm->tiles.cols;
   const int tile_rows = cm->tiles.rows;
 
@@ -320,11 +321,13 @@
           av1_get_sb_cols_in_tile(cm, this_tile->tile_info);
 #else
       int num_b_rows_in_tile =
-          is_firstpass ? av1_get_mb_rows_in_tile(this_tile->tile_info)
-                       : av1_get_sb_rows_in_tile(cm, this_tile->tile_info);
+          is_firstpass
+              ? av1_get_unit_rows_in_tile(this_tile->tile_info, fp_block_size)
+              : av1_get_sb_rows_in_tile(cm, this_tile->tile_info);
       int num_b_cols_in_tile =
-          is_firstpass ? av1_get_mb_cols_in_tile(this_tile->tile_info)
-                       : av1_get_sb_cols_in_tile(cm, this_tile->tile_info);
+          is_firstpass
+              ? av1_get_unit_cols_in_tile(this_tile->tile_info, fp_block_size)
+              : av1_get_sb_cols_in_tile(cm, this_tile->tile_info);
 #endif
       int theoretical_limit_on_threads =
           AOMMIN((num_b_cols_in_tile + 1) >> 1, num_b_rows_in_tile);
@@ -361,8 +364,9 @@
     // Update the current tile id to the tile id that will be processed next,
     // which will be the least processed tile.
     *cur_tile_id = tile_id;
+    const int unit_height = mi_size_high[fp_block_size];
     get_next_job(&tile_data[tile_id], current_mi_row,
-                 is_firstpass ? FP_MIB_SIZE : cm->seq_params.mib_size);
+                 is_firstpass ? unit_height : cm->seq_params.mib_size);
   }
 }
 
@@ -381,6 +385,8 @@
 
   assert(cur_tile_id != -1);
 
+  const BLOCK_SIZE fp_block_size = cpi->fp_block_size;
+  const int unit_height = mi_size_high[fp_block_size];
   int end_of_frame = 0;
   while (1) {
     int current_mi_row = -1;
@@ -388,11 +394,12 @@
     pthread_mutex_lock(enc_row_mt_mutex_);
 #endif
     if (!get_next_job(&cpi->tile_data[cur_tile_id], &current_mi_row,
-                      FP_MIB_SIZE)) {
+                      unit_height)) {
       // No jobs are available for the current tile. Query for the status of
       // other tiles and get the next job if available
       switch_tile_and_get_next_job(cm, cpi->tile_data, &cur_tile_id,
-                                   &current_mi_row, &end_of_frame, 1);
+                                   &current_mi_row, &end_of_frame, 1,
+                                   fp_block_size);
     }
 #if CONFIG_MULTITHREAD
     pthread_mutex_unlock(enc_row_mt_mutex_);
@@ -406,7 +413,9 @@
     assert(current_mi_row != -1 &&
            current_mi_row <= this_tile->tile_info.mi_row_end);
 
-    av1_first_pass_row(cpi, td, this_tile, current_mi_row >> FP_MIB_SIZE_LOG2);
+    const int unit_height_log2 = mi_size_high_log2[fp_block_size];
+    av1_first_pass_row(cpi, td, this_tile, current_mi_row >> unit_height_log2,
+                       fp_block_size);
 #if CONFIG_MULTITHREAD
     pthread_mutex_lock(enc_row_mt_mutex_);
 #endif
@@ -434,6 +443,7 @@
 
   assert(cur_tile_id != -1);
 
+  const BLOCK_SIZE fp_block_size = cpi->fp_block_size;
   int end_of_frame = 0;
   while (1) {
     int current_mi_row = -1;
@@ -445,7 +455,8 @@
       // No jobs are available for the current tile. Query for the status of
       // other tiles and get the next job if available
       switch_tile_and_get_next_job(cm, cpi->tile_data, &cur_tile_id,
-                                   &current_mi_row, &end_of_frame, 0);
+                                   &current_mi_row, &end_of_frame, 0,
+                                   fp_block_size);
     }
 #if CONFIG_MULTITHREAD
     pthread_mutex_unlock(enc_row_mt_mutex_);
@@ -945,8 +956,10 @@
   for (int row = 0; row < tile_rows; row++) {
     for (int col = 0; col < tile_cols; col++) {
       av1_tile_init(&tile_info, cm, row, col);
-      const int num_mb_rows_in_tile = av1_get_mb_rows_in_tile(tile_info);
-      const int num_mb_cols_in_tile = av1_get_mb_cols_in_tile(tile_info);
+      const int num_mb_rows_in_tile =
+          av1_get_unit_rows_in_tile(tile_info, cpi->fp_block_size);
+      const int num_mb_cols_in_tile =
+          av1_get_unit_cols_in_tile(tile_info, cpi->fp_block_size);
       total_num_threads_row_mt +=
           AOMMIN((num_mb_cols_in_tile + 1) >> 1, num_mb_rows_in_tile);
     }
@@ -956,8 +969,9 @@
 
 // Computes the maximum number of mb_rows for row multi-threading of firstpass
 // stage
-static AOM_INLINE int fp_compute_max_mb_rows(
-    const AV1_COMMON *const cm, const TileDataEnc *const tile_data) {
+static AOM_INLINE int fp_compute_max_mb_rows(const AV1_COMMON *const cm,
+                                             const TileDataEnc *const tile_data,
+                                             const BLOCK_SIZE fp_block_size) {
   const int tile_cols = cm->tiles.cols;
   const int tile_rows = cm->tiles.rows;
   int max_mb_rows = 0;
@@ -965,7 +979,8 @@
     for (int col = 0; col < tile_cols; col++) {
       const int tile_index = row * cm->tiles.cols + col;
       TileInfo tile_info = tile_data[tile_index].tile_info;
-      const int num_mb_rows_in_tile = av1_get_mb_rows_in_tile(tile_info);
+      const int num_mb_rows_in_tile =
+          av1_get_unit_rows_in_tile(tile_info, fp_block_size);
       max_mb_rows = AOMMAX(max_mb_rows, num_mb_rows_in_tile);
     }
   }
@@ -1066,7 +1081,8 @@
 
   av1_init_tile_data(cpi);
 
-  max_mb_rows = fp_compute_max_mb_rows(cm, cpi->tile_data);
+  const BLOCK_SIZE fp_block_size = cpi->fp_block_size;
+  max_mb_rows = fp_compute_max_mb_rows(cm, cpi->tile_data, fp_block_size);
 
   // TODO(ravi.chaudhary@ittiam.com): Currently the percentage of
   // post-processing stages in encoder is quiet low, so limiting the number of
diff --git a/av1/encoder/firstpass.c b/av1/encoder/firstpass.c
index 24cb6c4..8a827ea 100644
--- a/av1/encoder/firstpass.c
+++ b/av1/encoder/firstpass.c
@@ -137,6 +137,45 @@
   section->duration += frame->duration;
 }
 
+static int get_unit_rows(const BLOCK_SIZE fp_block_size, const int mb_rows) {
+  const int height_mi_log2 = mi_size_high_log2[fp_block_size];
+  const int mb_height_mi_log2 = mi_size_high_log2[BLOCK_16X16];
+  if (height_mi_log2 > mb_height_mi_log2) {
+    return mb_rows >> (height_mi_log2 - mb_height_mi_log2);
+  }
+
+  return mb_rows << (mb_height_mi_log2 - height_mi_log2);
+}
+
+static int get_unit_cols(const BLOCK_SIZE fp_block_size, const int mb_cols) {
+  const int width_mi_log2 = mi_size_wide_log2[fp_block_size];
+  const int mb_width_mi_log2 = mi_size_wide_log2[BLOCK_16X16];
+  if (width_mi_log2 > mb_width_mi_log2) {
+    return mb_cols >> (width_mi_log2 - mb_width_mi_log2);
+  }
+
+  return mb_cols << (mb_width_mi_log2 - width_mi_log2);
+}
+
+// TODO(chengchen): can we simplify it even if resize has to be considered?
+static int get_num_mbs(const BLOCK_SIZE fp_block_size,
+                       const int num_mbs_16X16) {
+  const int width_mi_log2 = mi_size_wide_log2[fp_block_size];
+  const int height_mi_log2 = mi_size_high_log2[fp_block_size];
+  const int mb_width_mi_log2 = mi_size_wide_log2[BLOCK_16X16];
+  const int mb_height_mi_log2 = mi_size_high_log2[BLOCK_16X16];
+  // TODO(chengchen): Now this function assumes a square block is used.
+  // It does not support rectangular block sizes.
+  assert(width_mi_log2 == height_mi_log2);
+  if (width_mi_log2 > mb_width_mi_log2) {
+    return num_mbs_16X16 >> ((width_mi_log2 - mb_width_mi_log2) +
+                             (height_mi_log2 - mb_height_mi_log2));
+  }
+
+  return num_mbs_16X16 << ((mb_width_mi_log2 - width_mi_log2) +
+                           (mb_height_mi_log2 - height_mi_log2));
+}
+
 void av1_end_first_pass(AV1_COMP *cpi) {
   if (cpi->twopass.stats_buf_ctx->total_stats)
     output_stats(cpi->twopass.stats_buf_ctx->total_stats, cpi->output_pkt_list);
@@ -256,18 +295,35 @@
 }
 
 static BLOCK_SIZE get_bsize(const CommonModeInfoParams *const mi_params,
-                            int mb_row, int mb_col) {
-  if (mi_size_wide[BLOCK_16X16] * mb_col + mi_size_wide[BLOCK_8X8] <
-      mi_params->mi_cols) {
-    return mi_size_wide[BLOCK_16X16] * mb_row + mi_size_wide[BLOCK_8X8] <
-                   mi_params->mi_rows
-               ? BLOCK_16X16
-               : BLOCK_16X8;
+                            const BLOCK_SIZE fp_block_size, const int unit_row,
+                            const int unit_col) {
+  const int unit_width = mi_size_wide[fp_block_size];
+  const int unit_height = mi_size_high[fp_block_size];
+  const int is_half_width =
+      unit_width * unit_col + unit_width / 2 >= mi_params->mi_cols;
+  const int is_half_height =
+      unit_height * unit_row + unit_height / 2 >= mi_params->mi_rows;
+  const int max_dimension =
+      AOMMAX(block_size_wide[fp_block_size], block_size_high[fp_block_size]);
+  int square_block_size = 0;
+  // Index of the square block size: 4X4, 8X8, 16X16, 32X32, 64X64, 128X128
+  switch (max_dimension) {
+    case 4: square_block_size = 0; break;
+    case 8: square_block_size = 1; break;
+    case 16: square_block_size = 2; break;
+    case 32: square_block_size = 3; break;
+    case 64: square_block_size = 4; break;
+    case 128: square_block_size = 5; break;
+    default: assert(0 && "First pass block size is not supported!"); break;
+  }
+  if (is_half_width && is_half_height) {
+    return subsize_lookup[PARTITION_SPLIT][square_block_size];
+  } else if (is_half_width) {
+    return subsize_lookup[PARTITION_VERT][square_block_size];
+  } else if (is_half_height) {
+    return subsize_lookup[PARTITION_HORZ][square_block_size];
   } else {
-    return mi_size_wide[BLOCK_16X16] * mb_row + mi_size_wide[BLOCK_8X8] <
-                   mi_params->mi_rows
-               ? BLOCK_8X16
-               : BLOCK_8X8;
+    return fp_block_size;
   }
 }
 
@@ -307,8 +363,8 @@
 //   cpi: the encoder setting. Only a few params in it will be used.
 //   this_frame: the current frame buffer.
 //   tile: tile information (not used in first pass, already init to zero)
-//   mb_row: row index in the unit of first pass block size.
-//   mb_col: column index in the unit of first pass block size.
+//   unit_row: row index in the unit of first pass block size.
+//   unit_col: column index in the unit of first pass block size.
 //   y_offset: the offset of y frame buffer, indicating the starting point of
 //             the current block.
 //   uv_offset: the offset of u and v frame buffer, indicating the starting
@@ -327,7 +383,7 @@
 //   this_intra_error.
 static int firstpass_intra_prediction(
     AV1_COMP *cpi, ThreadData *td, YV12_BUFFER_CONFIG *const this_frame,
-    const TileInfo *const tile, const int mb_row, const int mb_col,
+    const TileInfo *const tile, const int unit_row, const int unit_col,
     const int y_offset, const int uv_offset, const BLOCK_SIZE fp_block_size,
     const int qindex, FRAME_STATS *const stats) {
   const AV1_COMMON *const cm = &cpi->common;
@@ -335,28 +391,28 @@
   const SequenceHeader *const seq_params = &cm->seq_params;
   MACROBLOCK *const x = &td->mb;
   MACROBLOCKD *const xd = &x->e_mbd;
-  const int mb_scale = mi_size_wide[fp_block_size];
-  const int use_dc_pred = (mb_col || mb_row) && (!mb_col || !mb_row);
+  const int unit_scale = mi_size_wide[fp_block_size];
+  const int use_dc_pred = (unit_col || unit_row) && (!unit_col || !unit_row);
   const int num_planes = av1_num_planes(cm);
-  const BLOCK_SIZE bsize = get_bsize(mi_params, mb_row, mb_col);
+  const BLOCK_SIZE bsize =
+      get_bsize(mi_params, fp_block_size, unit_row, unit_col);
 
   aom_clear_system_state();
-  set_mi_offsets(mi_params, xd, mb_row * mb_scale, mb_col * mb_scale);
+  set_mi_offsets(mi_params, xd, unit_row * unit_scale, unit_col * unit_scale);
   xd->plane[0].dst.buf = this_frame->y_buffer + y_offset;
   xd->plane[1].dst.buf = this_frame->u_buffer + uv_offset;
   xd->plane[2].dst.buf = this_frame->v_buffer + uv_offset;
-  xd->left_available = (mb_col != 0);
+  xd->left_available = (unit_col != 0);
   xd->mi[0]->bsize = bsize;
   xd->mi[0]->ref_frame[0] = INTRA_FRAME;
-  set_mi_row_col(xd, tile, mb_row * mb_scale, mi_size_high[bsize],
-                 mb_col * mb_scale, mi_size_wide[bsize], mi_params->mi_rows,
+  set_mi_row_col(xd, tile, unit_row * unit_scale, mi_size_high[bsize],
+                 unit_col * unit_scale, mi_size_wide[bsize], mi_params->mi_rows,
                  mi_params->mi_cols);
   set_plane_n4(xd, mi_size_wide[bsize], mi_size_high[bsize], num_planes);
   xd->mi[0]->segment_id = 0;
   xd->lossless[xd->mi[0]->segment_id] = (qindex == 0);
   xd->mi[0]->mode = DC_PRED;
-  xd->mi[0]->tx_size =
-      use_dc_pred ? (bsize >= fp_block_size ? TX_16X16 : TX_8X8) : TX_4X4;
+  xd->mi[0]->tx_size = use_dc_pred ? max_txsize_lookup[bsize] : TX_4X4;
 
   av1_encode_intra_block_plane(cpi, x, bsize, 0, DRY_RUN_NORMAL, 0);
   int this_intra_error = aom_get_mb_ss(x->plane[0].src_diff);
@@ -375,8 +431,8 @@
 
   if (this_intra_error < UL_INTRA_THRESH) {
     ++stats->intra_skip_count;
-  } else if ((mb_col > 0) && (stats->image_data_start_row == INVALID_ROW)) {
-    stats->image_data_start_row = mb_row;
+  } else if ((unit_col > 0) && (stats->image_data_start_row == INVALID_ROW)) {
+    stats->image_data_start_row = unit_row;
   }
 
   aom_clear_system_state();
@@ -506,8 +562,8 @@
 //   last_frame: the frame buffer of the last frame.
 //   golden_frame: the frame buffer of the golden frame.
 //   alt_ref_frame: the frame buffer of the alt ref frame.
-//   mb_row: row index in the unit of first pass block size.
-//   mb_col: column index in the unit of first pass block size.
+//   unit_row: row index in the unit of first pass block size.
+//   unit_col: column index in the unit of first pass block size.
 //   recon_yoffset: the y offset of the reconstructed  frame buffer,
 //                  indicating the starting point of the current block.
 //   recont_uvoffset: the u/v offset of the reconstructed frame buffer,
@@ -531,8 +587,8 @@
 static int firstpass_inter_prediction(
     AV1_COMP *cpi, ThreadData *td, const YV12_BUFFER_CONFIG *const last_frame,
     const YV12_BUFFER_CONFIG *const golden_frame,
-    const YV12_BUFFER_CONFIG *const alt_ref_frame, const int mb_row,
-    const int mb_col, const int recon_yoffset, const int recon_uvoffset,
+    const YV12_BUFFER_CONFIG *const alt_ref_frame, const int unit_row,
+    const int unit_col, const int recon_yoffset, const int recon_uvoffset,
     const int src_yoffset, const int alt_ref_frame_yoffset,
     const BLOCK_SIZE fp_block_size, const int this_intra_error,
     const int raw_motion_err_counts, int *raw_motion_err_list, MV *best_ref_mv,
@@ -545,17 +601,21 @@
   MACROBLOCKD *const xd = &x->e_mbd;
   const int is_high_bitdepth = is_cur_buf_hbd(xd);
   const int bitdepth = xd->bd;
-  const int mb_scale = mi_size_wide[fp_block_size];
-  const BLOCK_SIZE bsize = get_bsize(mi_params, mb_row, mb_col);
+  const int unit_scale = mi_size_wide[fp_block_size];
+  const BLOCK_SIZE bsize =
+      get_bsize(mi_params, fp_block_size, unit_row, unit_col);
   const int fp_block_size_height = block_size_wide[fp_block_size];
+  const int unit_width = mi_size_wide[fp_block_size];
+  const int unit_rows = get_unit_rows(fp_block_size, mi_params->mb_rows);
+  const int unit_cols = get_unit_cols(fp_block_size, mi_params->mb_cols);
   // Assume 0,0 motion with no mv overhead.
   FULLPEL_MV mv = kZeroFullMv;
   FULLPEL_MV tmp_mv = kZeroFullMv;
   xd->plane[0].pre[0].buf = last_frame->y_buffer + recon_yoffset;
   // Set up limit values for motion vectors to prevent them extending
   // outside the UMV borders.
-  av1_set_mv_col_limits(mi_params, &x->mv_limits, (mb_col << FP_MIB_SIZE_LOG2),
-                        (fp_block_size_height >> MI_SIZE_LOG2),
+  av1_set_mv_col_limits(mi_params, &x->mv_limits, unit_col * unit_width,
+                        fp_block_size_height >> MI_SIZE_LOG2,
                         cpi->oxcf.border_in_pixels);
 
   int motion_error =
@@ -681,8 +741,9 @@
     xd->mi[0]->tx_size = TX_4X4;
     xd->mi[0]->ref_frame[0] = LAST_FRAME;
     xd->mi[0]->ref_frame[1] = NONE_FRAME;
-    av1_enc_build_inter_predictor(cm, xd, mb_row * mb_scale, mb_col * mb_scale,
-                                  NULL, bsize, AOM_PLANE_Y, AOM_PLANE_Y);
+    av1_enc_build_inter_predictor(cm, xd, unit_row * unit_scale,
+                                  unit_col * unit_scale, NULL, bsize,
+                                  AOM_PLANE_Y, AOM_PLANE_Y);
     av1_encode_sby_pass1(cpi, x, bsize);
     stats->sum_mvr += best_mv.row;
     stats->sum_mvr_abs += abs(best_mv.row);
@@ -693,8 +754,8 @@
     ++stats->inter_count;
 
     *best_ref_mv = best_mv;
-    accumulate_mv_stats(best_mv, mv, mb_row, mb_col, mi_params->mb_rows,
-                        mi_params->mb_cols, last_mv, stats);
+    accumulate_mv_stats(best_mv, mv, unit_row, unit_col, unit_rows, unit_cols,
+                        last_mv, stats);
   }
 
   return this_inter_error;
@@ -718,7 +779,8 @@
                                    const FRAME_STATS *const stats,
                                    const double raw_err_stdev,
                                    const int frame_number,
-                                   const int64_t ts_duration) {
+                                   const int64_t ts_duration,
+                                   const BLOCK_SIZE fp_block_size) {
   TWO_PASS *twopass = &cpi->twopass;
   AV1_COMMON *const cm = &cpi->common;
   const CommonModeInfoParams *const mi_params = &cm->mi_params;
@@ -729,9 +791,12 @@
   // where the typical "real" energy per MB also falls.
   // Initial estimate here uses sqrt(mbs) to define the min_err, where the
   // number of mbs is proportional to the image area.
-  const int num_mbs = (cpi->oxcf.resize_cfg.resize_mode != RESIZE_NONE)
-                          ? cpi->initial_mbs
-                          : mi_params->MBs;
+  const int num_mbs_16X16 = (cpi->oxcf.resize_cfg.resize_mode != RESIZE_NONE)
+                                ? cpi->initial_mbs
+                                : mi_params->MBs;
+  // Number of actual units used in the first pass; the unit can be a square
+  // block size other than 16X16.
+  const int num_mbs = get_num_mbs(fp_block_size, num_mbs_16X16);
   const double min_err = 200 * sqrt(num_mbs);
 
   fps.weight = stats->intra_factor * stats->brightness_factor;
@@ -858,18 +923,18 @@
 }
 
 static void setup_firstpass_data(AV1_COMMON *const cm,
-                                 FirstPassData *firstpass_data) {
-  const CommonModeInfoParams *const mi_params = &cm->mi_params;
+                                 FirstPassData *firstpass_data,
+                                 const int unit_rows, const int unit_cols) {
   CHECK_MEM_ERROR(cm, firstpass_data->raw_motion_err_list,
-                  aom_calloc(mi_params->mb_rows * mi_params->mb_cols,
+                  aom_calloc(unit_rows * unit_cols,
                              sizeof(*firstpass_data->raw_motion_err_list)));
-  CHECK_MEM_ERROR(cm, firstpass_data->mb_stats,
-                  aom_calloc(mi_params->mb_rows * mi_params->mb_cols,
-                             sizeof(*firstpass_data->mb_stats)));
-  for (int j = 0; j < mi_params->mb_rows; j++) {
-    for (int i = 0; i < mi_params->mb_cols; i++) {
-      firstpass_data->mb_stats[j * mi_params->mb_cols + i]
-          .image_data_start_row = INVALID_ROW;
+  CHECK_MEM_ERROR(
+      cm, firstpass_data->mb_stats,
+      aom_calloc(unit_rows * unit_cols, sizeof(*firstpass_data->mb_stats)));
+  for (int j = 0; j < unit_rows; j++) {
+    for (int i = 0; i < unit_cols; i++) {
+      firstpass_data->mb_stats[j * unit_cols + i].image_data_start_row =
+          INVALID_ROW;
     }
   }
 }
@@ -879,33 +944,39 @@
   aom_free(firstpass_data->mb_stats);
 }
 
-int av1_get_mb_rows_in_tile(TileInfo tile) {
-  int mi_rows_aligned_to_mb =
-      ALIGN_POWER_OF_TWO(tile.mi_row_end - tile.mi_row_start, FP_MIB_SIZE_LOG2);
-  int mb_rows = mi_rows_aligned_to_mb >> FP_MIB_SIZE_LOG2;
+int av1_get_unit_rows_in_tile(TileInfo tile, const BLOCK_SIZE fp_block_size) {
+  const int unit_height_log2 = mi_size_high_log2[fp_block_size];
+  const int mi_rows_aligned_to_unit =
+      ALIGN_POWER_OF_TWO(tile.mi_row_end - tile.mi_row_start, unit_height_log2);
+  const int unit_rows = mi_rows_aligned_to_unit >> unit_height_log2;
 
-  return mb_rows;
+  return unit_rows;
 }
 
-int av1_get_mb_cols_in_tile(TileInfo tile) {
-  int mi_cols_aligned_to_mb =
-      ALIGN_POWER_OF_TWO(tile.mi_col_end - tile.mi_col_start, FP_MIB_SIZE_LOG2);
-  int mb_cols = mi_cols_aligned_to_mb >> FP_MIB_SIZE_LOG2;
+int av1_get_unit_cols_in_tile(TileInfo tile, const BLOCK_SIZE fp_block_size) {
+  const int unit_width_log2 = mi_size_wide_log2[fp_block_size];
+  const int mi_cols_aligned_to_unit =
+      ALIGN_POWER_OF_TWO(tile.mi_col_end - tile.mi_col_start, unit_width_log2);
+  const int unit_cols = mi_cols_aligned_to_unit >> unit_width_log2;
 
-  return mb_cols;
+  return unit_cols;
 }
 
 #define FIRST_PASS_ALT_REF_DISTANCE 16
 static void first_pass_tile(AV1_COMP *cpi, ThreadData *td,
-                            TileDataEnc *tile_data) {
+                            TileDataEnc *tile_data,
+                            const BLOCK_SIZE fp_block_size) {
   TileInfo *tile = &tile_data->tile_info;
+  const int unit_height = mi_size_high[fp_block_size];
+  const int unit_height_log2 = mi_size_high_log2[fp_block_size];
   for (int mi_row = tile->mi_row_start; mi_row < tile->mi_row_end;
-       mi_row += FP_MIB_SIZE) {
-    av1_first_pass_row(cpi, td, tile_data, mi_row >> FP_MIB_SIZE_LOG2);
+       mi_row += unit_height) {
+    av1_first_pass_row(cpi, td, tile_data, mi_row >> unit_height_log2,
+                       fp_block_size);
   }
 }
 
-static void first_pass_tiles(AV1_COMP *cpi) {
+static void first_pass_tiles(AV1_COMP *cpi, const BLOCK_SIZE fp_block_size) {
   AV1_COMMON *const cm = &cpi->common;
   const int tile_cols = cm->tiles.cols;
   const int tile_rows = cm->tiles.rows;
@@ -913,13 +984,13 @@
     for (int tile_col = 0; tile_col < tile_cols; ++tile_col) {
       TileDataEnc *const tile_data =
           &cpi->tile_data[tile_row * tile_cols + tile_col];
-      first_pass_tile(cpi, &cpi->td, tile_data);
+      first_pass_tile(cpi, &cpi->td, tile_data, fp_block_size);
     }
   }
 }
 
 void av1_first_pass_row(AV1_COMP *cpi, ThreadData *td, TileDataEnc *tile_data,
-                        int mb_row) {
+                        const int unit_row, const BLOCK_SIZE fp_block_size) {
   MACROBLOCK *const x = &td->mb;
   AV1_COMMON *const cm = &cpi->common;
   const CommonModeInfoParams *const mi_params = &cm->mi_params;
@@ -929,14 +1000,16 @@
   MACROBLOCKD *const xd = &x->e_mbd;
   TileInfo *tile = &tile_data->tile_info;
   const int qindex = find_fp_qindex(seq_params->bit_depth);
-  // First pass coding proceeds in raster scan order with unit size of 16x16.
-  const BLOCK_SIZE fp_block_size = BLOCK_16X16;
   const int fp_block_size_width = block_size_high[fp_block_size];
   const int fp_block_size_height = block_size_wide[fp_block_size];
+  const int unit_width = mi_size_wide[fp_block_size];
+  const int unit_width_log2 = mi_size_wide_log2[fp_block_size];
+  const int unit_height_log2 = mi_size_high_log2[fp_block_size];
+  const int unit_cols = mi_params->mb_cols * 4 / unit_width;
   int raw_motion_err_counts = 0;
-  int mb_row_in_tile = mb_row - (tile->mi_row_start >> FP_MIB_SIZE_LOG2);
-  int mb_col_start = tile->mi_col_start >> FP_MIB_SIZE_LOG2;
-  int mb_cols_in_tile = av1_get_mb_cols_in_tile(*tile);
+  int unit_row_in_tile = unit_row - (tile->mi_row_start >> unit_height_log2);
+  int unit_col_start = tile->mi_col_start >> unit_width_log2;
+  int unit_cols_in_tile = av1_get_unit_cols_in_tile(*tile, fp_block_size);
   MultiThreadInfo *const mt_info = &cpi->mt_info;
   AV1EncRowMultiThreadInfo *const enc_row_mt = &mt_info->enc_row_mt;
   AV1EncRowMultiThreadSync *const row_mt_sync = &tile_data->row_mt_sync;
@@ -961,9 +1034,9 @@
 
   PICK_MODE_CONTEXT *ctx = td->firstpass_ctx;
   FRAME_STATS *mb_stats =
-      cpi->firstpass_data.mb_stats + mb_row * mi_params->mb_cols + mb_col_start;
+      cpi->firstpass_data.mb_stats + unit_row * unit_cols + unit_col_start;
   int *raw_motion_err_list = cpi->firstpass_data.raw_motion_err_list +
-                             mb_row * mi_params->mb_cols + mb_col_start;
+                             unit_row * unit_cols + unit_col_start;
   MV *first_top_mv = &tile_data->firstpass_top_mv;
 
   for (int i = 0; i < num_planes; ++i) {
@@ -984,26 +1057,26 @@
   MV last_mv;
 
   // Reset above block coeffs.
-  xd->up_available = (mb_row_in_tile != 0);
-  int recon_yoffset = (mb_row * recon_y_stride * fp_block_size_height) +
-                      (mb_col_start * fp_block_size_width);
-  int src_yoffset = (mb_row * src_y_stride * fp_block_size_height) +
-                    (mb_col_start * fp_block_size_width);
-  int recon_uvoffset =
-      (mb_row * recon_uv_stride * uv_mb_height) + (mb_col_start * uv_mb_height);
+  xd->up_available = (unit_row_in_tile != 0);
+  int recon_yoffset = (unit_row * recon_y_stride * fp_block_size_height) +
+                      (unit_col_start * fp_block_size_width);
+  int src_yoffset = (unit_row * src_y_stride * fp_block_size_height) +
+                    (unit_col_start * fp_block_size_width);
+  int recon_uvoffset = (unit_row * recon_uv_stride * uv_mb_height) +
+                       (unit_col_start * uv_mb_height);
   int alt_ref_frame_yoffset =
       (alt_ref_frame != NULL)
-          ? (mb_row * alt_ref_frame->y_stride * fp_block_size_height) +
-                (mb_col_start * fp_block_size_width)
+          ? (unit_row * alt_ref_frame->y_stride * fp_block_size_height) +
+                (unit_col_start * fp_block_size_width)
           : -1;
 
   // Set up limit values for motion vectors to prevent them extending
   // outside the UMV borders.
-  av1_set_mv_row_limits(mi_params, &x->mv_limits, (mb_row << FP_MIB_SIZE_LOG2),
-                        (fp_block_size_height >> MI_SIZE_LOG2),
-                        cpi->oxcf.border_in_pixels);
+  av1_set_mv_row_limits(
+      mi_params, &x->mv_limits, (unit_row << unit_height_log2),
+      (fp_block_size_height >> MI_SIZE_LOG2), cpi->oxcf.border_in_pixels);
 
-  av1_setup_src_planes(x, cpi->source, mb_row << FP_MIB_SIZE_LOG2,
+  av1_setup_src_planes(x, cpi->source, unit_row << unit_height_log2,
                        tile->mi_col_start, num_planes, fp_block_size);
 
   // Fix - zero the 16x16 block first. This ensures correct this_intra_error for
@@ -1011,26 +1084,27 @@
   av1_zero_array(x->plane[0].src_diff, 256);
 
   for (int mi_col = tile->mi_col_start; mi_col < tile->mi_col_end;
-       mi_col += FP_MIB_SIZE) {
-    int mb_col = mi_col >> FP_MIB_SIZE_LOG2;
-    int mb_col_in_tile = mb_col - mb_col_start;
+       mi_col += unit_width) {
+    const int unit_col = mi_col >> unit_width_log2;
+    const int unit_col_in_tile = unit_col - unit_col_start;
 
-    (*(enc_row_mt->sync_read_ptr))(row_mt_sync, mb_row_in_tile, mb_col_in_tile);
+    (*(enc_row_mt->sync_read_ptr))(row_mt_sync, unit_row_in_tile,
+                                   unit_col_in_tile);
 
-    if (mb_col_in_tile == 0) {
+    if (unit_col_in_tile == 0) {
       last_mv = *first_top_mv;
     }
     int this_intra_error = firstpass_intra_prediction(
-        cpi, td, this_frame, tile, mb_row, mb_col, recon_yoffset,
+        cpi, td, this_frame, tile, unit_row, unit_col, recon_yoffset,
         recon_uvoffset, fp_block_size, qindex, mb_stats);
 
     if (!frame_is_intra_only(cm)) {
       const int this_inter_error = firstpass_inter_prediction(
-          cpi, td, last_frame, golden_frame, alt_ref_frame, mb_row, mb_col,
+          cpi, td, last_frame, golden_frame, alt_ref_frame, unit_row, unit_col,
           recon_yoffset, recon_uvoffset, src_yoffset, alt_ref_frame_yoffset,
           fp_block_size, this_intra_error, raw_motion_err_counts,
           raw_motion_err_list, &best_ref_mv, &last_mv, mb_stats);
-      if (mb_col_in_tile == 0) {
+      if (unit_col_in_tile == 0) {
         *first_top_mv = last_mv;
       }
       mb_stats->coded_error += this_inter_error;
@@ -1052,8 +1126,8 @@
     alt_ref_frame_yoffset += fp_block_size_width;
     mb_stats++;
 
-    (*(enc_row_mt->sync_write_ptr))(row_mt_sync, mb_row_in_tile, mb_col_in_tile,
-                                    mb_cols_in_tile);
+    (*(enc_row_mt->sync_write_ptr))(row_mt_sync, unit_row_in_tile,
+                                    unit_col_in_tile, unit_cols_in_tile);
   }
 }
 
@@ -1071,10 +1145,17 @@
     FeatureFlags *const features = &cm->features;
     av1_set_screen_content_options(cpi, features);
   }
-  // First pass coding proceeds in raster scan order with unit size of 16x16.
+  // Unit size for the first pass encoding.
   const BLOCK_SIZE fp_block_size = BLOCK_16X16;
+  // Number of rows and columns in units of the first pass block size.
+  // Note mi_params->mb_rows and mi_params->mb_cols are in units of 16x16.
+  const int unit_rows = get_unit_rows(fp_block_size, mi_params->mb_rows);
+  const int unit_cols = get_unit_cols(fp_block_size, mi_params->mb_cols);
 
-  setup_firstpass_data(cm, &cpi->firstpass_data);
+  // Set fp_block_size, for the convenience of multi-thread usage.
+  cpi->fp_block_size = fp_block_size;
+
+  setup_firstpass_data(cm, &cpi->firstpass_data, unit_rows, unit_cols);
   int *raw_motion_err_list = cpi->firstpass_data.raw_motion_err_list;
   FRAME_STATS *mb_stats = cpi->firstpass_data.mb_stats;
 
@@ -1141,39 +1222,42 @@
     enc_row_mt->sync_write_ptr = av1_row_mt_sync_write;
     av1_fp_encode_tiles_row_mt(cpi);
   } else {
-    first_pass_tiles(cpi);
+    first_pass_tiles(cpi, fp_block_size);
   }
 
-  FRAME_STATS stats =
-      accumulate_frame_stats(mb_stats, mi_params->mb_rows, mi_params->mb_cols);
+  FRAME_STATS stats = accumulate_frame_stats(mb_stats, unit_rows, unit_cols);
   int total_raw_motion_err_count =
-      frame_is_intra_only(cm) ? 0 : mi_params->mb_rows * mi_params->mb_cols;
+      frame_is_intra_only(cm) ? 0 : unit_rows * unit_cols;
   const double raw_err_stdev =
       raw_motion_error_stdev(raw_motion_err_list, total_raw_motion_err_count);
   free_firstpass_data(&cpi->firstpass_data);
 
   // Clamp the image start to rows/2. This number of rows is discarded top
   // and bottom as dead data so rows / 2 means the frame is blank.
-  if ((stats.image_data_start_row > mi_params->mb_rows / 2) ||
+  if ((stats.image_data_start_row > unit_rows / 2) ||
       (stats.image_data_start_row == INVALID_ROW)) {
-    stats.image_data_start_row = mi_params->mb_rows / 2;
+    stats.image_data_start_row = unit_rows / 2;
   }
   // Exclude any image dead zone
   if (stats.image_data_start_row > 0) {
     stats.intra_skip_count =
         AOMMAX(0, stats.intra_skip_count -
-                      (stats.image_data_start_row * mi_params->mb_cols * 2));
+                      (stats.image_data_start_row * unit_cols * 2));
   }
 
   TWO_PASS *twopass = &cpi->twopass;
-  const int num_mbs = (cpi->oxcf.resize_cfg.resize_mode != RESIZE_NONE)
-                          ? cpi->initial_mbs
-                          : mi_params->MBs;
+  const int num_mbs_16X16 = (cpi->oxcf.resize_cfg.resize_mode != RESIZE_NONE)
+                                ? cpi->initial_mbs
+                                : mi_params->MBs;
+  // Number of actual units used in the first pass; the unit can be a square
+  // block size other than 16X16.
+  const int num_mbs = get_num_mbs(fp_block_size, num_mbs_16X16);
   stats.intra_factor = stats.intra_factor / (double)num_mbs;
   stats.brightness_factor = stats.brightness_factor / (double)num_mbs;
   FIRSTPASS_STATS *this_frame_stats = twopass->stats_buf_ctx->stats_in_end;
   update_firstpass_stats(cpi, &stats, raw_err_stdev,
-                         current_frame->frame_number, ts_duration);
+                         current_frame->frame_number, ts_duration,
+                         fp_block_size);
 
   // Copy the previous Last Frame back into gf buffer if the prediction is good
   // enough... but also don't allow it to lag too far.
diff --git a/av1/encoder/firstpass.h b/av1/encoder/firstpass.h
index 8764e77..22969e8 100644
--- a/av1/encoder/firstpass.h
+++ b/av1/encoder/firstpass.h
@@ -30,10 +30,6 @@
 
 #define VLOW_MOTION_THRESHOLD 950
 
-// size of firstpass macroblocks in terms of MIs.
-#define FP_MIB_SIZE 4
-#define FP_MIB_SIZE_LOG2 2
-
 /*!
  * \brief The stucture of acummulated frame stats in the first pass.
  */
@@ -331,12 +327,13 @@
 struct AV1EncoderConfig;
 struct TileDataEnc;
 
-int av1_get_mb_rows_in_tile(TileInfo tile);
-int av1_get_mb_cols_in_tile(TileInfo tile);
+int av1_get_unit_rows_in_tile(TileInfo tile, const BLOCK_SIZE fp_block_size);
+int av1_get_unit_cols_in_tile(TileInfo tile, const BLOCK_SIZE fp_block_size);
 
 void av1_rc_get_first_pass_params(struct AV1_COMP *cpi);
 void av1_first_pass_row(struct AV1_COMP *cpi, struct ThreadData *td,
-                        struct TileDataEnc *tile_data, int mb_row);
+                        struct TileDataEnc *tile_data, const int mb_row,
+                        const BLOCK_SIZE fp_block_size);
 void av1_end_first_pass(struct AV1_COMP *cpi);
 
 void av1_twopass_zero_stats(FIRSTPASS_STATS *section);