Refactor the direct partition merging code

Change-Id: I449e1e2eed2dd27d25ae20ebbc099fa4275920f1
diff --git a/av1/encoder/partition_search.c b/av1/encoder/partition_search.c
index 60365cf..e57e544 100644
--- a/av1/encoder/partition_search.c
+++ b/av1/encoder/partition_search.c
@@ -2282,6 +2282,197 @@
 #endif
 }
 
+// Evaluate if the sub-partitions can be merged directly into a large partition
+// without calculating the RD cost.
+static void direct_partition_merging(AV1_COMP *cpi, ThreadData *td,
+                                     TileDataEnc *tile_data, MB_MODE_INFO **mib,
+                                     int mi_row, int mi_col, BLOCK_SIZE bsize) {
+  AV1_COMMON *const cm = &cpi->common;
+  const CommonModeInfoParams *const mi_params = &cm->mi_params;
+  TileInfo *const tile_info = &tile_data->tile_info;
+  MACROBLOCK *const x = &td->mb;
+  MACROBLOCKD *const xd = &x->e_mbd;
+  const int bs = mi_size_wide[bsize];
+  const int hbs = bs / 2;
+  const PARTITION_TYPE partition =
+      (bsize >= BLOCK_8X8) ? get_partition(cm, mi_row, mi_col, bsize)
+                           : PARTITION_NONE;
+  BLOCK_SIZE subsize = get_partition_subsize(bsize, partition);
+
+  MB_MODE_INFO **b0 = mib;
+  MB_MODE_INFO **b1 = mib + hbs;
+  MB_MODE_INFO **b2 = mib + hbs * mi_params->mi_stride;
+  MB_MODE_INFO **b3 = mib + hbs * mi_params->mi_stride + hbs;
+
+  // Check if the following conditions are met. This can be updated
+  // later with more support added.
+  const int further_split = b0[0]->bsize < subsize || b1[0]->bsize < subsize ||
+                            b2[0]->bsize < subsize || b3[0]->bsize < subsize;
+  if (further_split) return;
+
+  const int no_skip = !b0[0]->skip_txfm || !b1[0]->skip_txfm ||
+                      !b2[0]->skip_txfm || !b3[0]->skip_txfm;
+  if (no_skip) return;
+
+  const int compound = (b0[0]->ref_frame[1] != b1[0]->ref_frame[1] ||
+                        b0[0]->ref_frame[1] != b2[0]->ref_frame[1] ||
+                        b0[0]->ref_frame[1] != b3[0]->ref_frame[1] ||
+                        b0[0]->ref_frame[1] > NONE_FRAME);
+  if (compound) return;
+
+  // Intra modes aren't considered here.
+  const int different_ref = (b0[0]->ref_frame[0] != b1[0]->ref_frame[0] ||
+                             b0[0]->ref_frame[0] != b2[0]->ref_frame[0] ||
+                             b0[0]->ref_frame[0] != b3[0]->ref_frame[0] ||
+                             b0[0]->ref_frame[0] <= INTRA_FRAME);
+  if (different_ref) return;
+
+  const int different_mode =
+      (b0[0]->mode != b1[0]->mode || b0[0]->mode != b2[0]->mode ||
+       b0[0]->mode != b3[0]->mode);
+  if (different_mode) return;
+
+  const int unsupported_mode =
+      (b0[0]->mode != NEARESTMV && b0[0]->mode != GLOBALMV);
+  if (unsupported_mode) return;
+
+  const int different_mv = (b0[0]->mv[0].as_int != b1[0]->mv[0].as_int ||
+                            b0[0]->mv[0].as_int != b2[0]->mv[0].as_int ||
+                            b0[0]->mv[0].as_int != b3[0]->mv[0].as_int);
+  if (different_mv) return;
+
+  const int unsupported_motion_mode =
+      (b0[0]->motion_mode != b1[0]->motion_mode ||
+       b0[0]->motion_mode != b2[0]->motion_mode ||
+       b0[0]->motion_mode != b3[0]->motion_mode ||
+       b0[0]->motion_mode != SIMPLE_TRANSLATION);
+  if (unsupported_motion_mode) return;
+
+  const int diffent_filter =
+      (b0[0]->interp_filters.as_int != b1[0]->interp_filters.as_int ||
+       b0[0]->interp_filters.as_int != b2[0]->interp_filters.as_int ||
+       b0[0]->interp_filters.as_int != b3[0]->interp_filters.as_int);
+  if (diffent_filter) return;
+
+  const int different_seg = (b0[0]->segment_id != b1[0]->segment_id ||
+                             b0[0]->segment_id != b2[0]->segment_id ||
+                             b0[0]->segment_id != b3[0]->segment_id);
+  if (different_seg) return;
+
+  // Evaluate the ref_mv.
+  MB_MODE_INFO **this_mi = mib;
+  BLOCK_SIZE orig_bsize = this_mi[0]->bsize;
+  const PARTITION_TYPE orig_partition = this_mi[0]->partition;
+
+  this_mi[0]->bsize = bsize;
+  this_mi[0]->partition = PARTITION_NONE;
+  this_mi[0]->skip_txfm = 1;
+
+  // TODO(yunqing): functions called below can be optimized with
+  // removing unrelated operations.
+  av1_set_offsets_without_segment_id(cpi, &tile_data->tile_info, x, mi_row,
+                                     mi_col, bsize);
+
+  const MV_REFERENCE_FRAME ref_frame = this_mi[0]->ref_frame[0];
+  int_mv frame_mv[MB_MODE_COUNT][REF_FRAMES];
+  struct buf_2d yv12_mb[REF_FRAMES][MAX_MB_PLANE];
+  int force_skip_low_temp_var = 0;
+  int skip_pred_mv = 0;
+
+  for (int i = 0; i < MB_MODE_COUNT; ++i) {
+    for (int j = 0; j < REF_FRAMES; ++j) {
+      frame_mv[i][j].as_int = INVALID_MV;
+    }
+  }
+  x->color_sensitivity[0] = x->color_sensitivity_sb[0];
+  x->color_sensitivity[1] = x->color_sensitivity_sb[1];
+  skip_pred_mv = (x->nonrd_prune_ref_frame_search > 2 &&
+                  x->color_sensitivity[0] != 2 && x->color_sensitivity[1] != 2);
+
+  find_predictors(cpi, x, ref_frame, frame_mv, tile_data, yv12_mb, bsize,
+                  force_skip_low_temp_var, skip_pred_mv);
+
+  int continue_merging = 1;
+  if (frame_mv[NEARESTMV][ref_frame].as_mv.row != b0[0]->mv[0].as_mv.row ||
+      frame_mv[NEARESTMV][ref_frame].as_mv.col != b0[0]->mv[0].as_mv.col)
+    continue_merging = 0;
+
+  if (!continue_merging) {
+    this_mi[0]->bsize = orig_bsize;
+    this_mi[0]->partition = orig_partition;
+
+    // TODO(yunqing): Store the results and restore here instead of
+    // calling find_predictors() again.
+    av1_set_offsets_without_segment_id(cpi, &tile_data->tile_info, x, mi_row,
+                                       mi_col, this_mi[0]->bsize);
+    find_predictors(cpi, x, ref_frame, frame_mv, tile_data, yv12_mb,
+                    this_mi[0]->bsize, force_skip_low_temp_var, skip_pred_mv);
+  } else {
+    struct scale_factors *sf = get_ref_scale_factors(cm, ref_frame);
+    const int is_scaled = av1_is_scaled(sf);
+    const int is_y_subpel_mv = (abs(this_mi[0]->mv[0].as_mv.row) % 8) ||
+                               (abs(this_mi[0]->mv[0].as_mv.col) % 8);
+    const int is_uv_subpel_mv = (abs(this_mi[0]->mv[0].as_mv.row) % 16) ||
+                                (abs(this_mi[0]->mv[0].as_mv.col) % 16);
+
+    if (cpi->ppi->use_svc || is_scaled || is_y_subpel_mv || is_uv_subpel_mv) {
+      const int num_planes = av1_num_planes(cm);
+      set_ref_ptrs(cm, xd, ref_frame, this_mi[0]->ref_frame[1]);
+      const YV12_BUFFER_CONFIG *cfg = get_ref_frame_yv12_buf(cm, ref_frame);
+      av1_setup_pre_planes(xd, 0, cfg, mi_row, mi_col,
+                           xd->block_ref_scale_factors[0], num_planes);
+
+      if (!cpi->ppi->use_svc && !is_scaled && !is_y_subpel_mv) {
+        assert(is_uv_subpel_mv == 1);
+        av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, NULL, bsize, 1,
+                                      num_planes - 1);
+      } else {
+        av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, NULL, bsize, 0,
+                                      num_planes - 1);
+      }
+    }
+
+    // Copy out mbmi_ext information.
+    MB_MODE_INFO_EXT *const mbmi_ext = &x->mbmi_ext;
+    MB_MODE_INFO_EXT_FRAME *mbmi_ext_frame = x->mbmi_ext_frame;
+    av1_copy_mbmi_ext_to_mbmi_ext_frame(
+        mbmi_ext_frame, mbmi_ext, av1_ref_frame_type(this_mi[0]->ref_frame));
+
+    const BLOCK_SIZE this_subsize =
+        get_partition_subsize(bsize, this_mi[0]->partition);
+    // Update partition contexts.
+    update_ext_partition_context(xd, mi_row, mi_col, this_subsize, bsize,
+                                 this_mi[0]->partition);
+
+    const int num_planes = av1_num_planes(cm);
+    av1_reset_entropy_context(xd, bsize, num_planes);
+
+    // Note: use x->txfm_search_params.tx_mode_search_type instead of
+    // cm->features.tx_mode here.
+    TX_SIZE tx_size =
+        tx_size_from_tx_mode(bsize, x->txfm_search_params.tx_mode_search_type);
+    if (xd->lossless[this_mi[0]->segment_id]) tx_size = TX_4X4;
+    this_mi[0]->tx_size = tx_size;
+    memset(this_mi[0]->inter_tx_size, this_mi[0]->tx_size,
+           sizeof(this_mi[0]->inter_tx_size));
+
+    // Update txfm contexts.
+    xd->above_txfm_context =
+        cm->above_contexts.txfm[tile_info->tile_row] + mi_col;
+    xd->left_txfm_context =
+        xd->left_txfm_context_buffer + (mi_row & MAX_MIB_MASK);
+    set_txfm_ctxs(this_mi[0]->tx_size, xd->width, xd->height,
+                  this_mi[0]->skip_txfm && is_inter_block(this_mi[0]), xd);
+
+    // Update mi for this partition block.
+    for (int y = 0; y < bs; y++) {
+      for (int x_idx = 0; x_idx < bs; x_idx++) {
+        this_mi[x_idx + y * mi_params->mi_stride] = this_mi[0];
+      }
+    }
+  }
+}
+
 /*!\brief AV1 block partition application (minimal RD search).
 *
 * \ingroup partition_search
@@ -2601,191 +2792,8 @@
                 mode_costs->partition_cost[pl][PARTITION_SPLIT] &&
             (mi_row + bs <= mi_params->mi_rows) &&
             (mi_col + bs <= mi_params->mi_cols)) {
-          MB_MODE_INFO **b0 = mib;
-          MB_MODE_INFO **b1 = mib + hbs;
-          MB_MODE_INFO **b2 = mib + hbs * mi_params->mi_stride;
-          MB_MODE_INFO **b3 = mib + hbs * mi_params->mi_stride + hbs;
-
-          // Check if the following conditions are met. This can be updated
-          // later with more support added.
-          const int further_split =
-              b0[0]->bsize < subsize || b1[0]->bsize < subsize ||
-              b2[0]->bsize < subsize || b3[0]->bsize < subsize;
-          if (further_split) break;
-
-          const int no_skip = !b0[0]->skip_txfm || !b1[0]->skip_txfm ||
-                              !b2[0]->skip_txfm || !b3[0]->skip_txfm;
-          if (no_skip) break;
-
-          const int compound = (b0[0]->ref_frame[1] != b1[0]->ref_frame[1] ||
-                                b0[0]->ref_frame[1] != b2[0]->ref_frame[1] ||
-                                b0[0]->ref_frame[1] != b3[0]->ref_frame[1] ||
-                                b0[0]->ref_frame[1] > NONE_FRAME);
-          if (compound) break;
-
-          // Intra modes aren't considered here.
-          const int different_ref =
-              (b0[0]->ref_frame[0] != b1[0]->ref_frame[0] ||
-               b0[0]->ref_frame[0] != b2[0]->ref_frame[0] ||
-               b0[0]->ref_frame[0] != b3[0]->ref_frame[0] ||
-               b0[0]->ref_frame[0] <= INTRA_FRAME);
-          if (different_ref) break;
-
-          const int different_mode =
-              (b0[0]->mode != b1[0]->mode || b0[0]->mode != b2[0]->mode ||
-               b0[0]->mode != b3[0]->mode);
-          if (different_mode) break;
-
-          const int unsupported_mode =
-              (b0[0]->mode != NEARESTMV && b0[0]->mode != GLOBALMV);
-          if (unsupported_mode) break;
-
-          const int different_mv =
-              (b0[0]->mv[0].as_int != b1[0]->mv[0].as_int ||
-               b0[0]->mv[0].as_int != b2[0]->mv[0].as_int ||
-               b0[0]->mv[0].as_int != b3[0]->mv[0].as_int);
-          if (different_mv) break;
-
-          const int unsupported_motion_mode =
-              (b0[0]->motion_mode != b1[0]->motion_mode ||
-               b0[0]->motion_mode != b2[0]->motion_mode ||
-               b0[0]->motion_mode != b3[0]->motion_mode ||
-               b0[0]->motion_mode != SIMPLE_TRANSLATION);
-          if (unsupported_motion_mode) break;
-
-          const int diffent_filter =
-              (b0[0]->interp_filters.as_int != b1[0]->interp_filters.as_int ||
-               b0[0]->interp_filters.as_int != b2[0]->interp_filters.as_int ||
-               b0[0]->interp_filters.as_int != b3[0]->interp_filters.as_int);
-          if (diffent_filter) break;
-
-          const int different_seg = (b0[0]->segment_id != b1[0]->segment_id ||
-                                     b0[0]->segment_id != b2[0]->segment_id ||
-                                     b0[0]->segment_id != b3[0]->segment_id);
-          if (different_seg) break;
-
-          // Evaluate the ref_mv.
-          MB_MODE_INFO **this_mi = mib;
-          BLOCK_SIZE orig_bsize = this_mi[0]->bsize;
-          const PARTITION_TYPE orig_partition = this_mi[0]->partition;
-
-          this_mi[0]->bsize = bsize;
-          this_mi[0]->partition = PARTITION_NONE;
-          this_mi[0]->skip_txfm = 1;
-
-          // TODO(yunqing): functions called below can be optimized with
-          // removing unrelated operations.
-          av1_set_offsets_without_segment_id(cpi, &tile_data->tile_info, x,
-                                             mi_row, mi_col, bsize);
-
-          const MV_REFERENCE_FRAME ref_frame = this_mi[0]->ref_frame[0];
-          int_mv frame_mv[MB_MODE_COUNT][REF_FRAMES];
-          struct buf_2d yv12_mb[REF_FRAMES][MAX_MB_PLANE];
-          int force_skip_low_temp_var = 0;
-          int skip_pred_mv = 0;
-
-          for (int i = 0; i < MB_MODE_COUNT; ++i) {
-            for (int j = 0; j < REF_FRAMES; ++j) {
-              frame_mv[i][j].as_int = INVALID_MV;
-            }
-          }
-          x->color_sensitivity[0] = x->color_sensitivity_sb[0];
-          x->color_sensitivity[1] = x->color_sensitivity_sb[1];
-          skip_pred_mv =
-              (x->nonrd_prune_ref_frame_search > 2 &&
-               x->color_sensitivity[0] != 2 && x->color_sensitivity[1] != 2);
-
-          find_predictors(cpi, x, ref_frame, frame_mv, tile_data, yv12_mb,
-                          bsize, force_skip_low_temp_var, skip_pred_mv);
-
-          int continue_merging = 1;
-          if (frame_mv[NEARESTMV][ref_frame].as_mv.row !=
-                  b0[0]->mv[0].as_mv.row ||
-              frame_mv[NEARESTMV][ref_frame].as_mv.col !=
-                  b0[0]->mv[0].as_mv.col)
-            continue_merging = 0;
-
-          if (!continue_merging) {
-            this_mi[0]->bsize = orig_bsize;
-            this_mi[0]->partition = orig_partition;
-
-            // TODO(yunqing): Store the results and restore here instead of
-            // calling find_predictors() again.
-            av1_set_offsets_without_segment_id(cpi, &tile_data->tile_info, x,
-                                               mi_row, mi_col,
-                                               this_mi[0]->bsize);
-            find_predictors(cpi, x, ref_frame, frame_mv, tile_data, yv12_mb,
-                            this_mi[0]->bsize, force_skip_low_temp_var,
-                            skip_pred_mv);
-          } else {
-            struct scale_factors *sf = get_ref_scale_factors(cm, ref_frame);
-            const int is_scaled = av1_is_scaled(sf);
-            const int is_y_subpel_mv = (abs(this_mi[0]->mv[0].as_mv.row) % 8) ||
-                                       (abs(this_mi[0]->mv[0].as_mv.col) % 8);
-            const int is_uv_subpel_mv =
-                (abs(this_mi[0]->mv[0].as_mv.row) % 16) ||
-                (abs(this_mi[0]->mv[0].as_mv.col) % 16);
-
-            if (cpi->ppi->use_svc || is_scaled || is_y_subpel_mv ||
-                is_uv_subpel_mv) {
-              const int num_planes = av1_num_planes(cm);
-              set_ref_ptrs(cm, xd, ref_frame, this_mi[0]->ref_frame[1]);
-              const YV12_BUFFER_CONFIG *cfg =
-                  get_ref_frame_yv12_buf(cm, ref_frame);
-              av1_setup_pre_planes(xd, 0, cfg, mi_row, mi_col,
-                                   xd->block_ref_scale_factors[0], num_planes);
-
-              if (!cpi->ppi->use_svc && !is_scaled && !is_y_subpel_mv) {
-                assert(is_uv_subpel_mv == 1);
-                av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, NULL,
-                                              bsize, 1, num_planes - 1);
-              } else {
-                av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, NULL,
-                                              bsize, 0, num_planes - 1);
-              }
-            }
-
-            // Copy out mbmi_ext information.
-            MB_MODE_INFO_EXT *const mbmi_ext = &x->mbmi_ext;
-            MB_MODE_INFO_EXT_FRAME *mbmi_ext_frame = x->mbmi_ext_frame;
-            av1_copy_mbmi_ext_to_mbmi_ext_frame(
-                mbmi_ext_frame, mbmi_ext,
-                av1_ref_frame_type(this_mi[0]->ref_frame));
-
-            const BLOCK_SIZE this_subsize =
-                get_partition_subsize(bsize, this_mi[0]->partition);
-            // Update partition contexts.
-            update_ext_partition_context(xd, mi_row, mi_col, this_subsize,
-                                         bsize, this_mi[0]->partition);
-
-            const int num_planes = av1_num_planes(cm);
-            av1_reset_entropy_context(xd, bsize, num_planes);
-
-            // Note: use x->txfm_search_params.tx_mode_search_type instead of
-            // cm->features.tx_mode here.
-            TX_SIZE tx_size = tx_size_from_tx_mode(
-                bsize, x->txfm_search_params.tx_mode_search_type);
-            if (xd->lossless[this_mi[0]->segment_id]) tx_size = TX_4X4;
-            this_mi[0]->tx_size = tx_size;
-            memset(this_mi[0]->inter_tx_size, this_mi[0]->tx_size,
-                   sizeof(this_mi[0]->inter_tx_size));
-
-            // Update txfm contexts.
-            xd->above_txfm_context =
-                cm->above_contexts.txfm[tile_info->tile_row] + mi_col;
-            xd->left_txfm_context =
-                xd->left_txfm_context_buffer + ((mi_row)&MAX_MIB_MASK);
-            set_txfm_ctxs(this_mi[0]->tx_size, xd->width, xd->height,
-                          this_mi[0]->skip_txfm && is_inter_block(this_mi[0]),
-                          xd);
-
-            // Update mi for this partition block.
-            for (int y = 0; y < bs; y++) {
-              for (int x_idx = 0; x_idx < bs; x_idx++) {
-                this_mi[x_idx + y * mi_params->mi_stride] = this_mi[0];
-              }
-            }
-          }
+          direct_partition_merging(cpi, td, tile_data, mib, mi_row, mi_col,
+                                   bsize);
         }
       }
       break;