Modularize NONE and SPLIT partition search

Abstracted the code related to PARTITION_NONE and PARTITION_SPLIT
into separate functions.

BUG=aomedia:2687

Change-Id: Id8ffb37db71aaaa1997e69bca3ec7b3dfe5d23bd
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 1cc3369..ccf55a9 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -2768,6 +2768,380 @@
                                      part4_search_allowed);
 }
 
+static bool rd_pick_partition(AV1_COMP *const cpi, ThreadData *td,
+                              TileDataEnc *tile_data, TokenExtra **tp,
+                              int mi_row, int mi_col, BLOCK_SIZE bsize,
+                              RD_STATS *rd_cost, RD_STATS best_rdc,
+                              PC_TREE *pc_tree,
+                              SIMPLE_MOTION_DATA_TREE *sms_tree,
+                              int64_t *none_rd,
+                              SB_MULTI_PASS_MODE multi_pass_mode,
+                              RD_RECT_PART_WIN_INFO *rect_part_win_info);
+
+// Set PARTITION_NONE allowed flag.
+static AOM_INLINE void set_part_none_allowed_flag(
+    AV1_COMP *const cpi, PartitionSearchState *part_search_state) {
+  PartitionBlkParams blk_params = part_search_state->part_blk_params;
+  if ((blk_params.width <= blk_params.min_partition_size_1d) &&
+      blk_params.has_rows && blk_params.has_cols)
+    part_search_state->partition_none_allowed = 1;
+  assert(part_search_state->terminate_partition_search == 0);
+
+  // Set PARTITION_NONE for screen content.
+  if (cpi->is_screen_content_type)
+    part_search_state->partition_none_allowed =
+        blk_params.has_rows && blk_params.has_cols;
+}
+
+// Set params needed for PARTITION_NONE search.
+static AOM_INLINE void set_none_partition_params(
+    const AV1_COMMON *const cm, ThreadData *td, MACROBLOCK *x, PC_TREE *pc_tree,
+    PartitionSearchState *part_search_state, RD_STATS *best_remain_rdcost,
+    RD_STATS *best_rdc, int *pt_cost) {
+  PartitionBlkParams blk_params = part_search_state->part_blk_params;
+  RD_STATS partition_rdcost;
+  // Set PARTITION_NONE context.
+  if (pc_tree->none == NULL)
+    pc_tree->none = av1_alloc_pmc(cm, blk_params.bsize, &td->shared_coeff_buf);
+
+  // Set PARTITION_NONE type cost.
+  if (part_search_state->partition_none_allowed) {
+    if (blk_params.bsize_at_least_8x8) {
+      *pt_cost = part_search_state->partition_cost[PARTITION_NONE] < INT_MAX
+                     ? part_search_state->partition_cost[PARTITION_NONE]
+                     : 0;
+    }
+
+    // Initialize the RD stats structure.
+    av1_init_rd_stats(&partition_rdcost);
+    partition_rdcost.rate = *pt_cost;
+    av1_rd_cost_update(x->rdmult, &partition_rdcost);
+    av1_rd_stats_subtraction(x->rdmult, best_rdc, &partition_rdcost,
+                             best_remain_rdcost);
+  }
+}
+
+// Skip other partitions based on PARTITION_NONE rd cost.
+static AOM_INLINE void prune_partitions_after_none(
+    AV1_COMP *const cpi, MACROBLOCK *x, SIMPLE_MOTION_DATA_TREE *sms_tree,
+    PICK_MODE_CONTEXT *ctx_none, PartitionSearchState *part_search_state,
+    RD_STATS *best_rdc, unsigned int *pb_source_variance) {
+  const AV1_COMMON *const cm = &cpi->common;
+  MACROBLOCKD *const xd = &x->e_mbd;
+  PartitionBlkParams blk_params = part_search_state->part_blk_params;
+  const CommonModeInfoParams *const mi_params = &cm->mi_params;
+  RD_STATS *this_rdc = &part_search_state->this_rdc;
+  const BLOCK_SIZE bsize = blk_params.bsize;
+  assert(bsize < BLOCK_SIZES_ALL);
+
+  if (!frame_is_intra_only(cm) &&
+      (part_search_state->do_square_split ||
+       part_search_state->do_rectangular_split) &&
+      !x->e_mbd.lossless[xd->mi[0]->segment_id] && ctx_none->skippable) {
+    const int use_ml_based_breakout =
+        bsize <= cpi->sf.part_sf.use_square_partition_only_threshold &&
+        bsize > BLOCK_4X4 && xd->bd == 8;
+    if (use_ml_based_breakout) {
+      if (av1_ml_predict_breakout(cpi, bsize, x, this_rdc,
+                                  *pb_source_variance)) {
+        part_search_state->do_square_split = 0;
+        part_search_state->do_rectangular_split = 0;
+      }
+    }
+
+    // Adjust dist breakout threshold according to the partition size.
+    const int64_t dist_breakout_thr =
+        cpi->sf.part_sf.partition_search_breakout_dist_thr >>
+        ((2 * (MAX_SB_SIZE_LOG2 - 2)) -
+         (mi_size_wide_log2[bsize] + mi_size_high_log2[bsize]));
+    const int rate_breakout_thr =
+        cpi->sf.part_sf.partition_search_breakout_rate_thr *
+        num_pels_log2_lookup[bsize];
+    // If all y, u, v transform blocks in this partition are skippable,
+    // and the dist & rate are within the thresholds, the partition
+    // search is terminated for current branch of the partition search
+    // tree. The dist & rate thresholds are set to 0 at speed 0 to
+    // disable the early termination at that speed.
+    if (best_rdc->dist < dist_breakout_thr &&
+        best_rdc->rate < rate_breakout_thr) {
+      part_search_state->do_square_split = 0;
+      part_search_state->do_rectangular_split = 0;
+    }
+  }
+
+  // Early termination: using simple_motion_search features and the
+  // rate, distortion, and rdcost of PARTITION_NONE, a DNN will make a
+  // decision on early terminating at PARTITION_NONE.
+  if (cpi->sf.part_sf.simple_motion_search_early_term_none && cm->show_frame &&
+      !frame_is_intra_only(cm) && bsize >= BLOCK_16X16 &&
+      blk_params.mi_row_edge < mi_params->mi_rows &&
+      blk_params.mi_col_edge < mi_params->mi_cols &&
+      this_rdc->rdcost < INT64_MAX && this_rdc->rdcost >= 0 &&
+      this_rdc->rate < INT_MAX && this_rdc->rate >= 0 &&
+      (part_search_state->do_square_split ||
+       part_search_state->do_rectangular_split)) {
+    av1_simple_motion_search_early_term_none(
+        cpi, x, sms_tree, blk_params.mi_row, blk_params.mi_col, bsize, this_rdc,
+        &part_search_state->terminate_partition_search);
+  }
+}
+
+// Decide early termination and rectangular partition pruning
+// based on PARTITION_NONE and PARTITION_SPLIT costs.
+static AOM_INLINE void prune_partitions_after_split(
+    AV1_COMP *const cpi, MACROBLOCK *x, SIMPLE_MOTION_DATA_TREE *sms_tree,
+    PartitionSearchState *part_search_state, RD_STATS *best_rdc,
+    int64_t part_none_rd, int64_t part_split_rd) {
+  const AV1_COMMON *const cm = &cpi->common;
+  PartitionBlkParams blk_params = part_search_state->part_blk_params;
+  const int mi_row = blk_params.mi_row;
+  const int mi_col = blk_params.mi_col;
+  const BLOCK_SIZE bsize = blk_params.bsize;
+  assert(bsize < BLOCK_SIZES_ALL);
+
+  // Early termination: using the rd costs of PARTITION_NONE and subblocks
+  // from PARTITION_SPLIT to determine an early breakout.
+  if (cpi->sf.part_sf.ml_early_term_after_part_split_level &&
+      !frame_is_intra_only(cm) &&
+      !part_search_state->terminate_partition_search &&
+      part_search_state->do_rectangular_split &&
+      (part_search_state->partition_rect_allowed[HORZ] ||
+       part_search_state->partition_rect_allowed[VERT])) {
+    av1_ml_early_term_after_split(
+        cpi, x, sms_tree, bsize, best_rdc->rdcost, part_none_rd, part_split_rd,
+        part_search_state->split_rd, mi_row, mi_col,
+        &part_search_state->terminate_partition_search);
+  }
+
+  // Use the rd costs of PARTITION_NONE and subblocks from PARTITION_SPLIT
+  // to prune out rectangular partitions in some directions.
+  if (!cpi->sf.part_sf.ml_early_term_after_part_split_level &&
+      cpi->sf.part_sf.ml_prune_rect_partition && !frame_is_intra_only(cm) &&
+      (part_search_state->partition_rect_allowed[HORZ] ||
+       part_search_state->partition_rect_allowed[VERT]) &&
+      !(part_search_state->prune_rect_part[HORZ] ||
+        part_search_state->prune_rect_part[VERT]) &&
+      !part_search_state->terminate_partition_search) {
+    av1_setup_src_planes(x, cpi->source, mi_row, mi_col, av1_num_planes(cm),
+                         bsize);
+    av1_ml_prune_rect_partition(
+        cpi, x, bsize, best_rdc->rdcost, part_search_state->none_rd,
+        part_search_state->split_rd, &part_search_state->prune_rect_part[HORZ],
+        &part_search_state->prune_rect_part[VERT]);
+  }
+}
+
+// PARTITION_NONE search.
+static AOM_INLINE void none_partition_search(
+    AV1_COMP *const cpi, ThreadData *td, TileDataEnc *tile_data, MACROBLOCK *x,
+    PC_TREE *pc_tree, SIMPLE_MOTION_DATA_TREE *sms_tree,
+    RD_SEARCH_MACROBLOCK_CONTEXT *x_ctx,
+    PartitionSearchState *part_search_state, RD_STATS *best_rdc,
+    unsigned int *pb_source_variance, unsigned int *pb_simple_motion_pred_sse,
+    int64_t *none_rd, int64_t *part_none_rd) {
+  const AV1_COMMON *const cm = &cpi->common;
+  PartitionBlkParams blk_params = part_search_state->part_blk_params;
+  RD_STATS *this_rdc = &part_search_state->this_rdc;
+  const int mi_row = blk_params.mi_row;
+  const int mi_col = blk_params.mi_col;
+  const BLOCK_SIZE bsize = blk_params.bsize;
+  assert(bsize < BLOCK_SIZES_ALL);
+
+  // Set PARTITION_NONE allowed flag.
+  set_part_none_allowed_flag(cpi, part_search_state);
+  if (!part_search_state->partition_none_allowed) return;
+
+  int pt_cost = 0;
+  RD_STATS best_remain_rdcost;
+
+  // Set PARTITION_NONE context and cost.
+  set_none_partition_params(cm, td, x, pc_tree, part_search_state,
+                            &best_remain_rdcost, best_rdc, &pt_cost);
+
+#if CONFIG_COLLECT_PARTITION_STATS
+  // Timer start for partition None.
+  if (best_remain_rdcost >= 0) {
+    partition_attempts[PARTITION_NONE] += 1;
+    aom_usec_timer_start(&partition_timer);
+    partition_timer_on = 1;
+  }
+#endif
+  // PARTITION_NONE evaluation and cost update.
+  pick_sb_modes(cpi, tile_data, x, mi_row, mi_col, this_rdc, PARTITION_NONE,
+                bsize, pc_tree->none, best_remain_rdcost, PICK_MODE_RD);
+  av1_rd_cost_update(x->rdmult, this_rdc);
+
+#if CONFIG_COLLECT_PARTITION_STATS
+  // Timer end for partition None.
+  if (partition_timer_on) {
+    aom_usec_timer_mark(&partition_timer);
+    int64_t time = aom_usec_timer_elapsed(&partition_timer);
+    partition_times[PARTITION_NONE] += time;
+    partition_timer_on = 0;
+  }
+#endif
+  *pb_source_variance = x->source_variance;
+  *pb_simple_motion_pred_sse = x->simple_motion_pred_sse;
+  if (none_rd) *none_rd = this_rdc->rdcost;
+  part_search_state->none_rd = this_rdc->rdcost;
+  if (this_rdc->rate != INT_MAX) {
+    // Record picked ref frame to prune ref frames for other partition types.
+    if (cpi->sf.inter_sf.prune_ref_frame_for_rect_partitions) {
+      const int ref_type = av1_ref_frame_type(pc_tree->none->mic.ref_frame);
+      av1_update_picked_ref_frames_mask(
+          x, ref_type, bsize, cm->seq_params.mib_size, mi_row, mi_col);
+    }
+
+    // Calculate the total cost and update the best partition.
+    if (blk_params.bsize_at_least_8x8) {
+      this_rdc->rate += pt_cost;
+      this_rdc->rdcost = RDCOST(x->rdmult, this_rdc->rate, this_rdc->dist);
+    }
+    *part_none_rd = this_rdc->rdcost;
+    if (this_rdc->rdcost < best_rdc->rdcost) {
+      *best_rdc = *this_rdc;
+      part_search_state->found_best_partition = true;
+      if (blk_params.bsize_at_least_8x8) {
+        pc_tree->partitioning = PARTITION_NONE;
+      }
+
+      // Disable split and rectangular partition search
+      // based on PARTITION_NONE cost.
+      prune_partitions_after_none(cpi, x, sms_tree, pc_tree->none,
+                                  part_search_state, best_rdc,
+                                  pb_source_variance);
+    }
+  }
+  av1_restore_context(x, x_ctx, mi_row, mi_col, bsize, av1_num_planes(cm));
+}
+
+// PARTITION_SPLIT search.
+static AOM_INLINE void split_partition_search(
+    AV1_COMP *const cpi, ThreadData *td, TileDataEnc *tile_data,
+    TokenExtra **tp, MACROBLOCK *x, PC_TREE *pc_tree,
+    SIMPLE_MOTION_DATA_TREE *sms_tree, RD_SEARCH_MACROBLOCK_CONTEXT *x_ctx,
+    PartitionSearchState *part_search_state, RD_STATS *best_rdc,
+    SB_MULTI_PASS_MODE multi_pass_mode, int64_t *part_split_rd) {
+  const AV1_COMMON *const cm = &cpi->common;
+  PartitionBlkParams blk_params = part_search_state->part_blk_params;
+  const CommonModeInfoParams *const mi_params = &cm->mi_params;
+  const int mi_row = blk_params.mi_row;
+  const int mi_col = blk_params.mi_col;
+  const int bsize = blk_params.bsize;
+  assert(bsize < BLOCK_SIZES_ALL);
+  RD_STATS sum_rdc = part_search_state->sum_rdc;
+  const BLOCK_SIZE subsize = get_partition_subsize(bsize, PARTITION_SPLIT);
+
+  // Check if partition split is allowed.
+  if (part_search_state->terminate_partition_search ||
+      !part_search_state->do_square_split)
+    return;
+
+  for (int i = 0; i < SUB_PARTITIONS_SPLIT; ++i) {
+    if (pc_tree->split[i] == NULL)
+      pc_tree->split[i] = av1_alloc_pc_tree_node(subsize);
+    pc_tree->split[i]->index = i;
+  }
+
+  // Initialization of this partition RD stats.
+  av1_init_rd_stats(&sum_rdc);
+  sum_rdc.rate = part_search_state->partition_cost[PARTITION_SPLIT];
+  sum_rdc.rdcost = RDCOST(x->rdmult, sum_rdc.rate, 0);
+
+  int idx;
+#if CONFIG_COLLECT_PARTITION_STATS
+  if (best_rdc->rdcost - sum_rdc.rdcost >= 0) {
+    partition_attempts[PARTITION_SPLIT] += 1;
+    aom_usec_timer_start(&partition_timer);
+    partition_timer_on = 1;
+  }
+#endif
+  // Recursive partition search on 4 sub-blocks.
+  for (idx = 0; idx < SUB_PARTITIONS_SPLIT && sum_rdc.rdcost < best_rdc->rdcost;
+       ++idx) {
+    const int x_idx = (idx & 1) * blk_params.mi_step;
+    const int y_idx = (idx >> 1) * blk_params.mi_step;
+
+    if (mi_row + y_idx >= mi_params->mi_rows ||
+        mi_col + x_idx >= mi_params->mi_cols)
+      continue;
+
+    pc_tree->split[idx]->index = idx;
+    int64_t *p_split_rd = &part_search_state->split_rd[idx];
+    RD_STATS best_remain_rdcost;
+    av1_rd_stats_subtraction(x->rdmult, best_rdc, &sum_rdc,
+                             &best_remain_rdcost);
+
+    int curr_quad_tree_idx = 0;
+    if (frame_is_intra_only(cm) && bsize <= BLOCK_64X64) {
+      curr_quad_tree_idx = part_search_state->intra_part_info->quad_tree_idx;
+      part_search_state->intra_part_info->quad_tree_idx =
+          4 * curr_quad_tree_idx + idx + 1;
+    }
+    // Split partition evaluation of corresponding idx.
+    // If the RD cost exceeds the best cost then do not
+    // evaluate other split sub-partitions.
+    if (!rd_pick_partition(
+            cpi, td, tile_data, tp, mi_row + y_idx, mi_col + x_idx, subsize,
+            &part_search_state->this_rdc, best_remain_rdcost,
+            pc_tree->split[idx], sms_tree->split[idx], p_split_rd,
+            multi_pass_mode, &part_search_state->split_part_rect_win[idx])) {
+      av1_invalid_rd_stats(&sum_rdc);
+      break;
+    }
+    if (frame_is_intra_only(cm) && bsize <= BLOCK_64X64) {
+      part_search_state->intra_part_info->quad_tree_idx = curr_quad_tree_idx;
+    }
+
+    sum_rdc.rate += part_search_state->this_rdc.rate;
+    sum_rdc.dist += part_search_state->this_rdc.dist;
+    av1_rd_cost_update(x->rdmult, &sum_rdc);
+
+    // Set split ctx as ready for use.
+    if (idx <= 1 && (bsize <= BLOCK_8X8 ||
+                     pc_tree->split[idx]->partitioning == PARTITION_NONE)) {
+      const MB_MODE_INFO *const mbmi = &pc_tree->split[idx]->none->mic;
+      const PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
+      // Neither palette mode nor cfl predicted.
+      if (pmi->palette_size[0] == 0 && pmi->palette_size[1] == 0) {
+        if (mbmi->uv_mode != UV_CFL_PRED)
+          part_search_state->is_split_ctx_is_ready[idx] = 1;
+      }
+    }
+  }
+#if CONFIG_COLLECT_PARTITION_STATS
+  if (partition_timer_on) {
+    aom_usec_timer_mark(&partition_timer);
+    int64_t time = aom_usec_timer_elapsed(&partition_timer);
+    partition_times[PARTITION_SPLIT] += time;
+    partition_timer_on = 0;
+  }
+#endif
+  const int reached_last_index = (idx == SUB_PARTITIONS_SPLIT);
+
+  // Calculate the total cost and update the best partition.
+  *part_split_rd = sum_rdc.rdcost;
+  if (reached_last_index && sum_rdc.rdcost < best_rdc->rdcost) {
+    sum_rdc.rdcost = RDCOST(x->rdmult, sum_rdc.rate, sum_rdc.dist);
+    if (sum_rdc.rdcost < best_rdc->rdcost) {
+      *best_rdc = sum_rdc;
+      part_search_state->found_best_partition = true;
+      pc_tree->partitioning = PARTITION_SPLIT;
+    }
+  } else if (cpi->sf.part_sf.less_rectangular_check_level > 0) {
+    // Skip rectangular partition test when partition type none gives better
+    // rd than partition type split.
+    if (cpi->sf.part_sf.less_rectangular_check_level == 2 || idx <= 2) {
+      const int partition_none_valid = part_search_state->none_rd > 0;
+      const int partition_none_better =
+          part_search_state->none_rd < sum_rdc.rdcost;
+      part_search_state->do_rectangular_split &=
+          !(partition_none_valid && partition_none_better);
+    }
+  }
+  av1_restore_context(x, x_ctx, mi_row, mi_col, bsize, av1_num_planes(cm));
+}
+
 /*!\brief AV1 block partition search (full search).
  *
  * \ingroup partition_search
@@ -2817,7 +3191,6 @@
                               SB_MULTI_PASS_MODE multi_pass_mode,
                               RD_RECT_PART_WIN_INFO *rect_part_win_info) {
   const AV1_COMMON *const cm = &cpi->common;
-  const CommonModeInfoParams *const mi_params = &cm->mi_params;
   const int num_planes = av1_num_planes(cm);
   TileInfo *const tile_info = &tile_data->tile_info;
   MACROBLOCK *const x = &td->mb;
@@ -2931,281 +3304,21 @@
   unsigned int pb_simple_motion_pred_sse = UINT_MAX;
   (void)pb_simple_motion_pred_sse;
 
-  // PARTITION_NONE
-  if (pc_tree->none == NULL)
-    pc_tree->none = av1_alloc_pmc(cm, bsize, &td->shared_coeff_buf);
-  PICK_MODE_CONTEXT *ctx_none = pc_tree->none;
-  if ((blk_params.width <= blk_params.min_partition_size_1d) &&
-      blk_params.has_rows && blk_params.has_cols)
-    part_search_state.partition_none_allowed = 1;
-  assert(part_search_state.terminate_partition_search == 0);
+  // PARTITION_NONE search stage.
   int64_t part_none_rd = INT64_MAX;
-  if (cpi->is_screen_content_type)
-    part_search_state.partition_none_allowed =
-        blk_params.has_rows && blk_params.has_cols;
-  if (part_search_state.partition_none_allowed) {
-    int pt_cost = 0;
-    if (blk_params.bsize_at_least_8x8) {
-      pt_cost = part_search_state.partition_cost[PARTITION_NONE] < INT_MAX
-                    ? part_search_state.partition_cost[PARTITION_NONE]
-                    : 0;
-    }
-    RD_STATS partition_rdcost;
-    av1_init_rd_stats(&partition_rdcost);
-    partition_rdcost.rate = pt_cost;
-    av1_rd_cost_update(x->rdmult, &partition_rdcost);
-    RD_STATS best_remain_rdcost;
-    av1_rd_stats_subtraction(x->rdmult, &best_rdc, &partition_rdcost,
-                             &best_remain_rdcost);
-#if CONFIG_COLLECT_PARTITION_STATS
-    if (best_remain_rdcost >= 0) {
-      partition_attempts[PARTITION_NONE] += 1;
-      aom_usec_timer_start(&partition_timer);
-      partition_timer_on = 1;
-    }
-#endif
-    pick_sb_modes(cpi, tile_data, x, mi_row, mi_col,
-                  &part_search_state.this_rdc, PARTITION_NONE, bsize, ctx_none,
-                  best_remain_rdcost, PICK_MODE_RD);
-    av1_rd_cost_update(x->rdmult, &part_search_state.this_rdc);
-#if CONFIG_COLLECT_PARTITION_STATS
-    if (partition_timer_on) {
-      aom_usec_timer_mark(&partition_timer);
-      int64_t time = aom_usec_timer_elapsed(&partition_timer);
-      partition_times[PARTITION_NONE] += time;
-      partition_timer_on = 0;
-    }
-#endif
-    pb_source_variance = x->source_variance;
-    pb_simple_motion_pred_sse = x->simple_motion_pred_sse;
-    if (none_rd) *none_rd = part_search_state.this_rdc.rdcost;
-    part_search_state.none_rd = part_search_state.this_rdc.rdcost;
-    if (part_search_state.this_rdc.rate != INT_MAX) {
-      // Record picked ref frame to prune ref frames for other partition types.
-      if (cpi->sf.inter_sf.prune_ref_frame_for_rect_partitions) {
-        const int ref_type = av1_ref_frame_type(ctx_none->mic.ref_frame);
-        av1_update_picked_ref_frames_mask(
-            x, ref_type, bsize, cm->seq_params.mib_size, mi_row, mi_col);
-      }
+  none_partition_search(cpi, td, tile_data, x, pc_tree, sms_tree, &x_ctx,
+                        &part_search_state, &best_rdc, &pb_source_variance,
+                        &pb_simple_motion_pred_sse, none_rd, &part_none_rd);
 
-      // Calculate the total cost and update the best partition.
-      if (blk_params.bsize_at_least_8x8) {
-        part_search_state.this_rdc.rate += pt_cost;
-        part_search_state.this_rdc.rdcost =
-            RDCOST(x->rdmult, part_search_state.this_rdc.rate,
-                   part_search_state.this_rdc.dist);
-      }
-      part_none_rd = part_search_state.this_rdc.rdcost;
-      if (part_search_state.this_rdc.rdcost < best_rdc.rdcost) {
-        // Adjust dist breakout threshold according to the partition size.
-        const int64_t dist_breakout_thr =
-            cpi->sf.part_sf.partition_search_breakout_dist_thr >>
-            ((2 * (MAX_SB_SIZE_LOG2 - 2)) -
-             (mi_size_wide_log2[bsize] + mi_size_high_log2[bsize]));
-        const int rate_breakout_thr =
-            cpi->sf.part_sf.partition_search_breakout_rate_thr *
-            num_pels_log2_lookup[bsize];
-
-        best_rdc = part_search_state.this_rdc;
-        part_search_state.found_best_partition = true;
-        if (blk_params.bsize_at_least_8x8) {
-          pc_tree->partitioning = PARTITION_NONE;
-        }
-
-        // Early termination: if the rd cost is very low, early terminate at
-        // PARTITION_NONE and skip all other partitions.
-        if (!frame_is_intra_only(cm) &&
-            (part_search_state.do_square_split ||
-             part_search_state.do_rectangular_split) &&
-            !x->e_mbd.lossless[xd->mi[0]->segment_id] && ctx_none->skippable) {
-          const int use_ml_based_breakout =
-              bsize <= cpi->sf.part_sf.use_square_partition_only_threshold &&
-              bsize > BLOCK_4X4 && xd->bd == 8;
-          if (use_ml_based_breakout) {
-            if (av1_ml_predict_breakout(cpi, bsize, x,
-                                        &part_search_state.this_rdc,
-                                        pb_source_variance)) {
-              part_search_state.do_square_split = 0;
-              part_search_state.do_rectangular_split = 0;
-            }
-          }
-
-          // If all y, u, v transform blocks in this partition are skippable,
-          // and the dist & rate are within the thresholds, the partition
-          // search is terminated for current branch of the partition search
-          // tree. The dist & rate thresholds are set to 0 at speed 0 to
-          // disable the early termination at that speed.
-          if (best_rdc.dist < dist_breakout_thr &&
-              best_rdc.rate < rate_breakout_thr) {
-            part_search_state.do_square_split = 0;
-            part_search_state.do_rectangular_split = 0;
-          }
-        }
-
-        // Early termination: using simple_motion_search features and the
-        // rate, distortion, and rdcost of PARTITION_NONE, a DNN will make a
-        // decision on early terminating at PARTITION_NONE.
-        if (cpi->sf.part_sf.simple_motion_search_early_term_none &&
-            cm->show_frame && !frame_is_intra_only(cm) &&
-            bsize >= BLOCK_16X16 &&
-            blk_params.mi_row_edge < mi_params->mi_rows &&
-            blk_params.mi_col_edge < mi_params->mi_cols &&
-            part_search_state.this_rdc.rdcost < INT64_MAX &&
-            part_search_state.this_rdc.rdcost >= 0 &&
-            part_search_state.this_rdc.rate < INT_MAX &&
-            part_search_state.this_rdc.rate >= 0 &&
-            (part_search_state.do_square_split ||
-             part_search_state.do_rectangular_split)) {
-          av1_simple_motion_search_early_term_none(
-              cpi, x, sms_tree, mi_row, mi_col, bsize,
-              &part_search_state.this_rdc,
-              &part_search_state.terminate_partition_search);
-        }
-      }
-    }
-
-    av1_restore_context(x, &x_ctx, mi_row, mi_col, bsize, num_planes);
-  }
-
-  // PARTITION_SPLIT
+  // PARTITION_SPLIT search stage.
   int64_t part_split_rd = INT64_MAX;
-  blk_params.subsize = get_partition_subsize(bsize, PARTITION_SPLIT);
-  if ((!part_search_state.terminate_partition_search &&
-       part_search_state.do_square_split)) {
-    for (int i = 0; i < 4; ++i) {
-      if (pc_tree->split[i] == NULL)
-        pc_tree->split[i] = av1_alloc_pc_tree_node(blk_params.subsize);
-      pc_tree->split[i]->index = i;
-    }
-    av1_init_rd_stats(&part_search_state.sum_rdc);
-    part_search_state.sum_rdc.rate =
-        part_search_state.partition_cost[PARTITION_SPLIT];
-    part_search_state.sum_rdc.rdcost =
-        RDCOST(x->rdmult, part_search_state.sum_rdc.rate, 0);
+  split_partition_search(cpi, td, tile_data, tp, x, pc_tree, sms_tree, &x_ctx,
+                         &part_search_state, &best_rdc, multi_pass_mode,
+                         &part_split_rd);
 
-    int idx;
-#if CONFIG_COLLECT_PARTITION_STATS
-    if (best_rdc.rdcost - part_search_state.sum_rdc.rdcost >= 0) {
-      partition_attempts[PARTITION_SPLIT] += 1;
-      aom_usec_timer_start(&partition_timer);
-      partition_timer_on = 1;
-    }
-#endif
-    // Recursive partition search on 4 sub-blocks.
-    for (idx = 0; idx < 4 && part_search_state.sum_rdc.rdcost < best_rdc.rdcost;
-         ++idx) {
-      const int x_idx = (idx & 1) * blk_params.mi_step;
-      const int y_idx = (idx >> 1) * blk_params.mi_step;
-
-      if (mi_row + y_idx >= mi_params->mi_rows ||
-          mi_col + x_idx >= mi_params->mi_cols)
-        continue;
-
-      pc_tree->split[idx]->index = idx;
-      int64_t *p_split_rd = &part_search_state.split_rd[idx];
-
-      RD_STATS best_remain_rdcost;
-      av1_rd_stats_subtraction(x->rdmult, &best_rdc, &part_search_state.sum_rdc,
-                               &best_remain_rdcost);
-
-      int curr_quad_tree_idx = 0;
-      if (frame_is_intra_only(cm) && bsize <= BLOCK_64X64) {
-        curr_quad_tree_idx = part_search_state.intra_part_info->quad_tree_idx;
-        part_search_state.intra_part_info->quad_tree_idx =
-            4 * curr_quad_tree_idx + idx + 1;
-      }
-      if (!rd_pick_partition(cpi, td, tile_data, tp, mi_row + y_idx,
-                             mi_col + x_idx, blk_params.subsize,
-                             &part_search_state.this_rdc, best_remain_rdcost,
-                             pc_tree->split[idx], sms_tree->split[idx],
-                             p_split_rd, multi_pass_mode,
-                             &part_search_state.split_part_rect_win[idx])) {
-        av1_invalid_rd_stats(&part_search_state.sum_rdc);
-        break;
-      }
-      if (frame_is_intra_only(cm) && bsize <= BLOCK_64X64) {
-        part_search_state.intra_part_info->quad_tree_idx = curr_quad_tree_idx;
-      }
-
-      part_search_state.sum_rdc.rate += part_search_state.this_rdc.rate;
-      part_search_state.sum_rdc.dist += part_search_state.this_rdc.dist;
-      av1_rd_cost_update(x->rdmult, &part_search_state.sum_rdc);
-      if (idx <= 1 && (bsize <= BLOCK_8X8 ||
-                       pc_tree->split[idx]->partitioning == PARTITION_NONE)) {
-        const MB_MODE_INFO *const mbmi = &pc_tree->split[idx]->none->mic;
-        const PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
-        // Neither palette mode nor cfl predicted.
-        if (pmi->palette_size[0] == 0 && pmi->palette_size[1] == 0) {
-          if (mbmi->uv_mode != UV_CFL_PRED)
-            part_search_state.is_split_ctx_is_ready[idx] = 1;
-        }
-      }
-    }
-#if CONFIG_COLLECT_PARTITION_STATS
-    if (partition_timer_on) {
-      aom_usec_timer_mark(&partition_timer);
-      int64_t time = aom_usec_timer_elapsed(&partition_timer);
-      partition_times[PARTITION_SPLIT] += time;
-      partition_timer_on = 0;
-    }
-#endif
-    const int reached_last_index = (idx == 4);
-
-    // Calculate the total cost and update the best partition.
-    part_split_rd = part_search_state.sum_rdc.rdcost;
-    if (reached_last_index &&
-        part_search_state.sum_rdc.rdcost < best_rdc.rdcost) {
-      part_search_state.sum_rdc.rdcost =
-          RDCOST(x->rdmult, part_search_state.sum_rdc.rate,
-                 part_search_state.sum_rdc.dist);
-      if (part_search_state.sum_rdc.rdcost < best_rdc.rdcost) {
-        best_rdc = part_search_state.sum_rdc;
-        part_search_state.found_best_partition = true;
-        pc_tree->partitioning = PARTITION_SPLIT;
-      }
-    } else if (cpi->sf.part_sf.less_rectangular_check_level > 0) {
-      // Skip rectangular partition test when partition type none gives better
-      // rd than partition type split.
-      if (cpi->sf.part_sf.less_rectangular_check_level == 2 || idx <= 2) {
-        const int partition_none_valid = part_search_state.none_rd > 0;
-        const int partition_none_better =
-            part_search_state.none_rd < part_search_state.sum_rdc.rdcost;
-        part_search_state.do_rectangular_split &=
-            !(partition_none_valid && partition_none_better);
-      }
-    }
-
-    av1_restore_context(x, &x_ctx, mi_row, mi_col, bsize, num_planes);
-  }  // if (do_split)
-
-  // Early termination: using the rd costs of PARTITION_NONE and subblocks
-  // from PARTITION_SPLIT to determine an early breakout.
-  if (cpi->sf.part_sf.ml_early_term_after_part_split_level &&
-      !frame_is_intra_only(cm) &&
-      !part_search_state.terminate_partition_search &&
-      part_search_state.do_rectangular_split &&
-      (part_search_state.partition_rect_allowed[HORZ] ||
-       part_search_state.partition_rect_allowed[VERT])) {
-    av1_ml_early_term_after_split(
-        cpi, x, sms_tree, bsize, best_rdc.rdcost, part_none_rd, part_split_rd,
-        part_search_state.split_rd, mi_row, mi_col,
-        &part_search_state.terminate_partition_search);
-  }
-
-  // Pruning: using the rd costs of PARTITION_NONE and subblocks from
-  // PARTITION_SPLIT to prune out rectangular partitions in some directions.
-  if (!cpi->sf.part_sf.ml_early_term_after_part_split_level &&
-      cpi->sf.part_sf.ml_prune_rect_partition && !frame_is_intra_only(cm) &&
-      (part_search_state.partition_rect_allowed[HORZ] ||
-       part_search_state.partition_rect_allowed[VERT]) &&
-      !(part_search_state.prune_rect_part[HORZ] ||
-        part_search_state.prune_rect_part[VERT]) &&
-      !part_search_state.terminate_partition_search) {
-    av1_setup_src_planes(x, cpi->source, mi_row, mi_col, num_planes, bsize);
-    av1_ml_prune_rect_partition(
-        cpi, x, bsize, best_rdc.rdcost, part_search_state.none_rd,
-        part_search_state.split_rd, prune_horz, prune_vert);
-  }
+  // Prune partitions based on PARTITION_NONE and PARTITION_SPLIT.
+  prune_partitions_after_split(cpi, x, sms_tree, &part_search_state, &best_rdc,
+                               part_none_rd, part_split_rd);
 
   // Rectangular partitions search stage.
   rectangular_partition_search(cpi, td, tile_data, tp, x, pc_tree, &x_ctx,
diff --git a/av1/encoder/encodeframe_utils.c b/av1/encoder/encodeframe_utils.c
index 8572304..216d748e 100644
--- a/av1/encoder/encodeframe_utils.c
+++ b/av1/encoder/encodeframe_utils.c
@@ -178,6 +178,7 @@
   const MB_MODE_INFO *const mi = &ctx->mic;
   MB_MODE_INFO *const mi_addr = xd->mi[0];
   const struct segmentation *const seg = &cm->seg;
+  assert(bsize < BLOCK_SIZES_ALL);
   const int bw = mi_size_wide[mi->sb_type];
   const int bh = mi_size_high[mi->sb_type];
   const int mis = mi_params->mi_stride;
diff --git a/av1/encoder/partition_strategy.h b/av1/encoder/partition_strategy.h
index 9c30fb4..a5597ae 100644
--- a/av1/encoder/partition_strategy.h
+++ b/av1/encoder/partition_strategy.h
@@ -33,6 +33,9 @@
 #define FEATURE_SMS_SPLIT_MODEL_FLAG \
   (FEATURE_SMS_NONE_FLAG | FEATURE_SMS_SPLIT_FLAG)
 
+// Number of sub-partitions in split partition type.
+#define SUB_PARTITIONS_SPLIT 4
+
 // Number of sub-partitions in AB partition types.
 #define SUB_PARTITIONS_AB 3