Simplify partition_search/pruning criterion

This is a refactoring commit that simplifies the criterion for
performing partition search and partition pruning. Functional
signatures are changed to use only PartitionSearchState when possible.
Some helper functions for PartitionSearchState and PartitionBlkParams
are added to improve the readability of the code.

Furthermore, this commit also removes redundant checks in search
conditions. For example, to check whether partition_4 is allowed, the
codebase used to check for bsize >= BLOCK_8X8, do_rect_split, etc. But
those checks were already done during the assignment of
part4_search_allowed. In the refactored version, we simply check for
terminate_partition_search and part4_search_allowed. Similar
simplifications are done for other partition types as well. The
exception of normal rectangular split, as it would result in a
stats_changed due to interactions with is_active_edge.

No stats_changed is expected, and the encoder is sped up by 0.02%.

Change-Id: I0c8f8a139eada79ab632088798a92e19fae61001
diff --git a/av1/encoder/encodeframe_utils.h b/av1/encoder/encodeframe_utils.h
index 03cc870..07fd6c3 100644
--- a/av1/encoder/encodeframe_utils.h
+++ b/av1/encoder/encodeframe_utils.h
@@ -205,12 +205,27 @@
   int is_split_ctx_is_ready[2];
   int is_rect_ctx_is_ready[NUM_RECT_PARTS];
 
-  // Flags to prune/skip particular partition size evaluation.
+  // If true, skips the rest of partition evaluation at the current bsize level.
   int terminate_partition_search;
+
+  // If false, skips rdopt on PARTITION_NONE.
   int partition_none_allowed;
+
+  // If partition_rect_allowed[HORZ] is false, skips searching PARTITION_HORZ,
+  // PARTITION_HORZ_A, PARTITIO_HORZ_B, PARTITION_HORZ_4. Same holds for VERT.
   int partition_rect_allowed[NUM_RECT_PARTS];
+
+  // If false, skips searching rectangular partition unless some logic related
+  // to edge detection holds.
   int do_rectangular_split;
+
+  // If false, skips searching PARTITION_SPLIT.
   int do_square_split;
+
+  // If true, prunes the corresponding PARTITION_HORZ/PARTITION_VERT. Note that
+  // this does not directly affect the extended partitions, so this can be used
+  // to prune out PARTITION_HORZ/PARTITION_VERT while still allowing rdopt of
+  // PARTITION_HORZ_AB4, etc.
   int prune_rect_part[NUM_RECT_PARTS];
 
   // Chroma subsampling in x and y directions.
@@ -228,6 +243,48 @@
 #endif  // CONFIG_COLLECT_PARTITION_STATS
 } PartitionSearchState;
 
+static AOM_INLINE void av1_disable_square_split_partition(
+    PartitionSearchState *part_state) {
+  part_state->do_square_split = 0;
+}
+
+// Disables all possible rectangular splits. This includes PARTITION_AB4 as they
+// depend on the corresponding partition_rect_allowed.
+static AOM_INLINE void av1_disable_rect_partitions(
+    PartitionSearchState *part_state) {
+  part_state->do_rectangular_split = 0;
+  part_state->partition_rect_allowed[HORZ] = 0;
+  part_state->partition_rect_allowed[VERT] = 0;
+}
+
+// Disables all possible splits so that only PARTITION_NONE *might* be allowed.
+static AOM_INLINE void av1_disable_all_splits(
+    PartitionSearchState *part_state) {
+  av1_disable_square_split_partition(part_state);
+  av1_disable_rect_partitions(part_state);
+}
+
+static AOM_INLINE void av1_set_square_split_only(
+    PartitionSearchState *part_state) {
+  part_state->partition_none_allowed = 0;
+  part_state->do_square_split = 1;
+  av1_disable_rect_partitions(part_state);
+}
+
+static AOM_INLINE bool av1_blk_has_rows_and_cols(
+    const PartitionBlkParams *blk_params) {
+  return blk_params->has_rows && blk_params->has_cols;
+}
+
+static AOM_INLINE bool av1_is_whole_blk_in_frame(
+    const PartitionBlkParams *blk_params,
+    const CommonModeInfoParams *mi_params) {
+  const int mi_row = blk_params->mi_row, mi_col = blk_params->mi_col;
+  const BLOCK_SIZE bsize = blk_params->bsize;
+  return mi_row + mi_size_high[bsize] <= mi_params->mi_rows &&
+         mi_col + mi_size_wide[bsize] <= mi_params->mi_cols;
+}
+
 static AOM_INLINE void update_filter_type_cdf(const MACROBLOCKD *xd,
                                               const MB_MODE_INFO *mbmi,
                                               int dual_filter) {
@@ -251,7 +308,7 @@
                              segment_qindex + cm->quant_params.y_dc_delta_q);
 }
 
-static AOM_INLINE int do_slipt_check(BLOCK_SIZE bsize) {
+static AOM_INLINE int do_split_check(BLOCK_SIZE bsize) {
   return (bsize == BLOCK_16X16 || bsize == BLOCK_32X32);
 }
 
diff --git a/av1/encoder/partition_search.c b/av1/encoder/partition_search.c
index e96c804..e184bc1 100644
--- a/av1/encoder/partition_search.c
+++ b/av1/encoder/partition_search.c
@@ -2251,7 +2251,7 @@
   switch (partition) {
     case PARTITION_NONE:
       pc_tree->none = av1_alloc_pmc(cpi, bsize, &td->shared_coeff_buf);
-      if (cpi->sf.rt_sf.nonrd_check_partition_split && do_slipt_check(bsize) &&
+      if (cpi->sf.rt_sf.nonrd_check_partition_split && do_split_check(bsize) &&
           !frame_is_intra_only(cm)) {
         RD_STATS split_rdc, none_rdc, block_rdc;
         RD_SEARCH_MACROBLOCK_CONTEXT x_ctx;
@@ -2710,21 +2710,20 @@
   part_search_state->terminate_partition_search = 0;
   part_search_state->do_square_split = blk_params->bsize_at_least_8x8;
   part_search_state->do_rectangular_split =
-      cpi->oxcf.part_cfg.enable_rect_partitions;
+      cpi->oxcf.part_cfg.enable_rect_partitions &&
+      blk_params->bsize_at_least_8x8;
   av1_zero(part_search_state->prune_rect_part);
 
   // Initialize allowed partition types for the partition block.
   part_search_state->partition_none_allowed =
-      blk_params->has_rows && blk_params->has_cols;
+      av1_blk_has_rows_and_cols(blk_params);
   part_search_state->partition_rect_allowed[HORZ] =
-      blk_params->has_cols && blk_params->bsize_at_least_8x8 &&
-      cpi->oxcf.part_cfg.enable_rect_partitions &&
+      part_search_state->do_rectangular_split && blk_params->has_cols &&
       get_plane_block_size(get_partition_subsize(bsize, PARTITION_HORZ),
                            part_search_state->ss_x,
                            part_search_state->ss_y) != BLOCK_INVALID;
   part_search_state->partition_rect_allowed[VERT] =
-      blk_params->has_rows && blk_params->bsize_at_least_8x8 &&
-      cpi->oxcf.part_cfg.enable_rect_partitions &&
+      part_search_state->do_rectangular_split && blk_params->has_rows &&
       get_plane_block_size(get_partition_subsize(bsize, PARTITION_VERT),
                            part_search_state->ss_x,
                            part_search_state->ss_y) != BLOCK_INVALID;
@@ -2783,7 +2782,7 @@
       blk_params.bsize_at_least_8x8 &&
       (blk_params.width > blk_params.min_partition_size_1d);
   part_search_state->partition_none_allowed =
-      blk_params.has_rows && blk_params.has_cols &&
+      av1_blk_has_rows_and_cols(&blk_params) &&
       (blk_params.width >= blk_params.min_partition_size_1d);
   part_search_state->partition_rect_allowed[HORZ] =
       blk_params.has_cols && is_rect_part_allowed &&
@@ -2835,15 +2834,16 @@
 
 // Checks if HORZ / VERT partition search is allowed.
 static AOM_INLINE int is_rect_part_allowed(
-    const AV1_COMP *cpi, PartitionSearchState *part_search_state,
-    active_edge_info *active_edge, RECT_PART_TYPE rect_part, const int mi_pos) {
-  PartitionBlkParams blk_params = part_search_state->part_blk_params;
+    const AV1_COMP *cpi, const PartitionSearchState *part_search_state,
+    const active_edge_info *active_edge, RECT_PART_TYPE rect_part,
+    const int mi_pos) {
+  const PartitionBlkParams *blk_params = &part_search_state->part_blk_params;
   const int is_part_allowed =
       (!part_search_state->terminate_partition_search &&
        part_search_state->partition_rect_allowed[rect_part] &&
        !part_search_state->prune_rect_part[rect_part] &&
        (part_search_state->do_rectangular_split ||
-        active_edge[rect_part](cpi, mi_pos, blk_params.mi_step)));
+        active_edge[rect_part](cpi, mi_pos, blk_params->mi_step)));
   return is_part_allowed;
 }
 
@@ -3018,18 +3018,6 @@
   av1_restore_context(x, x_ctx, mi_row, mi_col, bsize, av1_num_planes(cm));
 }
 
-// Check if AB partitions search is allowed.
-static AOM_INLINE int is_ab_part_allowed(
-    PartitionSearchState *part_search_state,
-    const int ab_partitions_allowed[NUM_AB_PARTS], const int ab_part_type) {
-  const int is_horz_ab = (ab_part_type >> 1);
-  const int is_part_allowed =
-      (!part_search_state->terminate_partition_search &&
-       part_search_state->partition_rect_allowed[is_horz_ab] &&
-       ab_partitions_allowed[ab_part_type]);
-  return is_part_allowed;
-}
-
 // Set mode search context.
 static AOM_INLINE void set_mode_search_ctx(
     PC_TREE *pc_tree, const int is_ctx_ready[NUM_AB_PARTS][2],
@@ -3110,16 +3098,15 @@
   const int mi_col = blk_params.mi_col;
   const int bsize = blk_params.bsize;
 
-  int ab_partitions_allowed[NUM_AB_PARTS] = { 1, 1, 1, 1 };
+  if (part_search_state->terminate_partition_search) {
+    return;
+  }
+
+  int ab_partitions_allowed[NUM_AB_PARTS];
   // Prune AB partitions
-  av1_prune_ab_partitions(
-      cpi, x, pc_tree, bsize, mi_row, mi_col, pb_source_variance,
-      best_rdc->rdcost, part_search_state->rect_part_rd,
-      part_search_state->split_rd, rect_part_win_info, ext_partition_allowed,
-      part_search_state->partition_rect_allowed[HORZ],
-      part_search_state->partition_rect_allowed[VERT],
-      &ab_partitions_allowed[HORZ_A], &ab_partitions_allowed[HORZ_B],
-      &ab_partitions_allowed[VERT_A], &ab_partitions_allowed[VERT_B]);
+  av1_prune_ab_partitions(cpi, x, pc_tree, pb_source_variance, best_rdc->rdcost,
+                          rect_part_win_info, ext_partition_allowed,
+                          part_search_state, ab_partitions_allowed);
 
   // Flags to indicate whether the mode search is done.
   const int is_ctx_ready[NUM_AB_PARTS][2] = {
@@ -3176,9 +3163,9 @@
     const PARTITION_TYPE part_type = ab_part_type + PARTITION_HORZ_A;
 
     // Check if the AB partition search is to be performed.
-    if (!is_ab_part_allowed(part_search_state, ab_partitions_allowed,
-                            ab_part_type))
+    if (!ab_partitions_allowed[ab_part_type]) {
       continue;
+    }
 
     blk_params.subsize = get_partition_subsize(bsize, part_type);
     if (cpi->sf.part_sf.reuse_prev_rd_results_for_part_ab) {
@@ -3354,8 +3341,6 @@
     return;
   }
 
-  const int mi_row = blk_params.mi_row;
-  const int mi_col = blk_params.mi_col;
   const int bsize = blk_params.bsize;
   PARTITION_TYPE cur_part[NUM_PART4_TYPES] = { PARTITION_HORZ_4,
                                                PARTITION_VERT_4 };
@@ -3394,11 +3379,9 @@
   if (cpi->sf.part_sf.ml_prune_partition && partition4_allowed &&
       part_search_state->partition_rect_allowed[HORZ] &&
       part_search_state->partition_rect_allowed[VERT]) {
-    av1_ml_prune_4_partition(
-        cpi, x, bsize, pc_tree->partitioning, best_rdc->rdcost,
-        part_search_state->rect_part_rd, part_search_state->split_rd,
-        &part4_search_allowed[HORZ4], &part4_search_allowed[VERT4],
-        pb_source_variance, mi_row, mi_col);
+    av1_ml_prune_4_partition(cpi, x, pc_tree->partitioning, best_rdc->rdcost,
+                             part_search_state, part4_search_allowed,
+                             pb_source_variance);
   }
 
   // Pruning: pruning out 4-way partitions based on the number of horz/vert wins
@@ -3445,8 +3428,7 @@
                                         unsigned int *pb_source_variance) {
   const AV1_COMMON *const cm = &cpi->common;
   MACROBLOCKD *const xd = &x->e_mbd;
-  PartitionBlkParams blk_params = part_search_state->part_blk_params;
-  const CommonModeInfoParams *const mi_params = &cm->mi_params;
+  const PartitionBlkParams blk_params = part_search_state->part_blk_params;
   RD_STATS *this_rdc = &part_search_state->this_rdc;
   const BLOCK_SIZE bsize = blk_params.bsize;
   assert(bsize < BLOCK_SIZES_ALL);
@@ -3459,10 +3441,8 @@
         bsize <= cpi->sf.part_sf.use_square_partition_only_threshold &&
         bsize > BLOCK_4X4 && cpi->sf.part_sf.ml_predict_breakout_level >= 1;
     if (use_ml_based_breakout) {
-      av1_ml_predict_breakout(cpi, bsize, x, this_rdc, blk_params,
-                              *pb_source_variance, xd->bd,
-                              &part_search_state->do_square_split,
-                              &part_search_state->do_rectangular_split);
+      av1_ml_predict_breakout(cpi, x, this_rdc, *pb_source_variance, xd->bd,
+                              part_search_state);
     }
 
     // Adjust dist breakout threshold according to the partition size.
@@ -3490,15 +3470,13 @@
   // decision on early terminating at PARTITION_NONE.
   if (cpi->sf.part_sf.simple_motion_search_early_term_none && cm->show_frame &&
       !frame_is_intra_only(cm) && bsize >= BLOCK_16X16 &&
-      blk_params.mi_row_edge < mi_params->mi_rows &&
-      blk_params.mi_col_edge < mi_params->mi_cols &&
-      this_rdc->rdcost < INT64_MAX && this_rdc->rdcost >= 0 &&
-      this_rdc->rate < INT_MAX && this_rdc->rate >= 0 &&
+      av1_blk_has_rows_and_cols(&blk_params) && this_rdc->rdcost < INT64_MAX &&
+      this_rdc->rdcost >= 0 && this_rdc->rate < INT_MAX &&
+      this_rdc->rate >= 0 &&
       (part_search_state->do_square_split ||
        part_search_state->do_rectangular_split)) {
-    av1_simple_motion_search_early_term_none(
-        cpi, x, sms_tree, blk_params.mi_row, blk_params.mi_col, bsize, this_rdc,
-        &part_search_state->terminate_partition_search);
+    av1_simple_motion_search_early_term_none(cpi, x, sms_tree, this_rdc,
+                                             part_search_state);
   }
 }
 
@@ -3524,9 +3502,8 @@
       (part_search_state->partition_rect_allowed[HORZ] ||
        part_search_state->partition_rect_allowed[VERT])) {
     av1_ml_early_term_after_split(
-        cpi, x, sms_tree, bsize, best_rdc->rdcost, part_none_rd, part_split_rd,
-        part_search_state->split_rd, mi_row, mi_col,
-        &part_search_state->terminate_partition_search);
+        cpi, x, sms_tree, best_rdc->rdcost, part_none_rd, part_split_rd,
+        part_search_state->split_rd, part_search_state);
   }
 
   // Use the rd costs of PARTITION_NONE and subblocks from PARTITION_SPLIT
@@ -3540,11 +3517,9 @@
       !part_search_state->terminate_partition_search) {
     av1_setup_src_planes(x, cpi->source, mi_row, mi_col, av1_num_planes(cm),
                          bsize);
-    av1_ml_prune_rect_partition(cpi, x, bsize, mi_row, mi_col, best_rdc->rdcost,
+    av1_ml_prune_rect_partition(cpi, x, best_rdc->rdcost,
                                 part_search_state->none_rd,
-                                part_search_state->split_rd,
-                                &part_search_state->prune_rect_part[HORZ],
-                                &part_search_state->prune_rect_part[VERT]);
+                                part_search_state->split_rd, part_search_state);
   }
 }
 
@@ -3968,7 +3943,7 @@
   // Override partition costs at the edges of the frame in the same
   // way as in read_partition (see decodeframe.c).
   PartitionBlkParams blk_params = part_search_state.part_blk_params;
-  if (!(blk_params.has_rows && blk_params.has_cols))
+  if (!av1_blk_has_rows_and_cols(&blk_params))
     set_partition_cost_for_edge_blk(cm, &part_search_state);
 
   av1_set_offsets(cpi, tile_info, x, mi_row, mi_col, bsize);
@@ -4298,7 +4273,7 @@
 
   // Override partition costs at the edges of the frame in the same
   // way as in read_partition (see decodeframe.c).
-  if (!(blk_params.has_rows && blk_params.has_cols))
+  if (!av1_blk_has_rows_and_cols(&blk_params))
     set_partition_cost_for_edge_blk(cm, &part_search_state);
 
   // Disable rectangular partitions for inner blocks when the current block is
@@ -4355,24 +4330,13 @@
 #if CONFIG_COLLECT_COMPONENT_TIMING
   start_timing(cpi, av1_prune_partitions_time);
 #endif
-  int *partition_horz_allowed = &part_search_state.partition_rect_allowed[HORZ];
-  int *partition_vert_allowed = &part_search_state.partition_rect_allowed[VERT];
-  int *prune_horz = &part_search_state.prune_rect_part[HORZ];
-  int *prune_vert = &part_search_state.prune_rect_part[VERT];
   // Pruning: before searching any partition type, using source and simple
   // motion search results to prune out unlikely partitions.
-  av1_prune_partitions_before_search(
-      cpi, x, mi_row, mi_col, bsize, sms_tree,
-      &part_search_state.partition_none_allowed, partition_horz_allowed,
-      partition_vert_allowed, &part_search_state.do_rectangular_split,
-      &part_search_state.do_square_split, prune_horz, prune_vert);
+  av1_prune_partitions_before_search(cpi, x, sms_tree, &part_search_state);
 
   // Pruning: eliminating partition types leading to coding block sizes outside
   // the min and max bsize limitations set from the encoder.
-  av1_prune_partitions_by_max_min_bsize(
-      &x->sb_enc, bsize, blk_params.has_rows && blk_params.has_cols,
-      &part_search_state.partition_none_allowed, partition_horz_allowed,
-      partition_vert_allowed, &part_search_state.do_square_split);
+  av1_prune_partitions_by_max_min_bsize(&x->sb_enc, &part_search_state);
 #if CONFIG_COLLECT_COMPONENT_TIMING
   end_timing(cpi, av1_prune_partitions_time);
 #endif
@@ -4480,7 +4444,7 @@
   const int ext_partition_allowed =
       part_search_state.do_rectangular_split &&
       bsize > cpi->sf.part_sf.ext_partition_eval_thresh &&
-      blk_params.has_rows && blk_params.has_cols;
+      av1_blk_has_rows_and_cols(&blk_params);
 #if CONFIG_COLLECT_COMPONENT_TIMING
   start_timing(cpi, ab_partitions_search_time);
 #endif
@@ -4507,9 +4471,7 @@
   assert(IMPLIES(!cpi->oxcf.part_cfg.enable_rect_partitions,
                  !part4_search_allowed[HORZ4]));
   if (!part_search_state.terminate_partition_search &&
-      part4_search_allowed[HORZ4] && blk_params.has_rows &&
-      (part_search_state.do_rectangular_split ||
-       av1_active_h_edge(cpi, mi_row, blk_params.mi_step))) {
+      part4_search_allowed[HORZ4]) {
     const int inc_step[NUM_PART4_TYPES] = { mi_size_high[blk_params.bsize] / 4,
                                             0 };
     // Evaluation of Horz4 partition type.
@@ -4522,9 +4484,7 @@
   assert(IMPLIES(!cpi->oxcf.part_cfg.enable_rect_partitions,
                  !part4_search_allowed[VERT4]));
   if (!part_search_state.terminate_partition_search &&
-      part4_search_allowed[VERT4] && blk_params.has_cols &&
-      (part_search_state.do_rectangular_split ||
-       av1_active_v_edge(cpi, mi_row, blk_params.mi_step))) {
+      part4_search_allowed[VERT4] && blk_params.has_cols) {
     const int inc_step[NUM_PART4_TYPES] = { 0, mi_size_wide[blk_params.bsize] /
                                                    4 };
     // Evaluation of Vert4 partition type.
diff --git a/av1/encoder/partition_strategy.c b/av1/encoder/partition_strategy.c
index facd088..5d01d6f 100644
--- a/av1/encoder/partition_strategy.c
+++ b/av1/encoder/partition_strategy.c
@@ -11,6 +11,7 @@
 
 #include <float.h>
 
+#include "av1/encoder/encodeframe_utils.h"
 #include "config/aom_dsp_rtcd.h"
 
 #include "av1/common/enums.h"
@@ -132,14 +133,13 @@
 //   -- use reconstructed pixels instead of source pixels for padding
 //   -- use chroma pixels in addition to luma pixels
 void av1_intra_mode_cnn_partition(const AV1_COMMON *const cm, MACROBLOCK *x,
-                                  int bsize, int quad_tree_idx,
-                                  int *partition_none_allowed,
-                                  int *partition_horz_allowed,
-                                  int *partition_vert_allowed,
-                                  int *do_rectangular_split,
-                                  int *do_square_split) {
+                                  int quad_tree_idx,
+                                  PartitionSearchState *part_state) {
   assert(cm->seq_params->sb_size >= BLOCK_64X64 &&
          "Invalid sb_size for intra_cnn!");
+  const PartitionBlkParams *blk_params = &part_state->part_blk_params;
+  const BLOCK_SIZE bsize = blk_params->bsize;
+
   const int bsize_idx = convert_bsize_to_idx(bsize);
 
   if (bsize == BLOCK_128X128) {
@@ -315,23 +315,22 @@
   }
 
   if (logits[0] > split_only_thresh) {
-    *partition_none_allowed = 0;
-    *partition_horz_allowed = 0;
-    *partition_vert_allowed = 0;
-    *do_rectangular_split = 0;
+    av1_set_square_split_only(part_state);
   }
 
   if (logits[0] < no_split_thresh) {
-    *do_square_split = 0;
+    av1_disable_square_split_partition(part_state);
   }
 }
 
-void av1_simple_motion_search_based_split(
-    AV1_COMP *const cpi, MACROBLOCK *x, SIMPLE_MOTION_DATA_TREE *sms_tree,
-    int mi_row, int mi_col, BLOCK_SIZE bsize, int *partition_none_allowed,
-    int *partition_horz_allowed, int *partition_vert_allowed,
-    int *do_rectangular_split, int *do_square_split) {
+void av1_simple_motion_search_based_split(AV1_COMP *const cpi, MACROBLOCK *x,
+                                          SIMPLE_MOTION_DATA_TREE *sms_tree,
+                                          PartitionSearchState *part_state) {
   const AV1_COMMON *const cm = &cpi->common;
+  const PartitionBlkParams *blk_params = &part_state->part_blk_params;
+  const int mi_row = blk_params->mi_row, mi_col = blk_params->mi_col;
+  const BLOCK_SIZE bsize = blk_params->bsize;
+
   const int bsize_idx = convert_bsize_to_idx(bsize);
   const int is_720p_or_larger = AOMMIN(cm->width, cm->height) >= 720;
   const int is_480p_or_larger = AOMMIN(cm->width, cm->height) >= 480;
@@ -369,8 +368,10 @@
   // Note: it is intended to not normalize the features here, to keep it
   // consistent for all features collected and passed to the external model.
   if (ext_ml_model_decision_before_none(
-          cpi, features, partition_none_allowed, partition_horz_allowed,
-          partition_vert_allowed, do_rectangular_split, do_square_split)) {
+          cpi, features, &part_state->partition_none_allowed,
+          &part_state->partition_rect_allowed[HORZ],
+          &part_state->partition_rect_allowed[VERT],
+          &part_state->do_rectangular_split, &part_state->do_square_split)) {
     return;
   }
 
@@ -383,15 +384,12 @@
   av1_nn_predict(features, nn_config, 1, &score);
 
   if (score > split_only_thresh) {
-    *partition_none_allowed = 0;
-    *partition_horz_allowed = 0;
-    *partition_vert_allowed = 0;
-    *do_rectangular_split = 0;
+    av1_set_square_split_only(part_state);
   }
 
   if (cpi->sf.part_sf.simple_motion_search_split >= 2 &&
       score < no_split_thresh) {
-    *do_square_split = 0;
+    av1_disable_square_split_partition(part_state);
   }
 
   // If the score is very low, prune rectangular split since it is unlikely to
@@ -402,7 +400,9 @@
         scale * av1_simple_motion_search_no_split_thresh
                     [cpi->sf.part_sf.simple_motion_search_rect_split][res_idx]
                     [bsize_idx];
-    if (score < rect_split_thresh) *do_rectangular_split = 0;
+    if (score < rect_split_thresh) {
+      part_state->do_rectangular_split = 0;
+    }
   }
 }
 
@@ -491,6 +491,7 @@
   const int w_mi = mi_size_wide[bsize];
   const int h_mi = mi_size_high[bsize];
   assert(mi_size_wide[bsize] == mi_size_high[bsize]);
+  assert(bsize >= BLOCK_8X8);
   assert(cpi->ref_frame_flags & av1_ref_frame_flag_list[LAST_FRAME] ||
          cpi->ref_frame_flags & av1_ref_frame_flag_list[ALTREF_FRAME]);
 
@@ -598,11 +599,14 @@
   features[f_idx++] = (float)mi_size_high_log2[left_bsize];
 }
 
-void av1_simple_motion_search_prune_rect(
-    AV1_COMP *const cpi, MACROBLOCK *x, SIMPLE_MOTION_DATA_TREE *sms_tree,
-    int mi_row, int mi_col, BLOCK_SIZE bsize, int partition_horz_allowed,
-    int partition_vert_allowed, int *prune_horz, int *prune_vert) {
+void av1_simple_motion_search_prune_rect(AV1_COMP *const cpi, MACROBLOCK *x,
+                                         SIMPLE_MOTION_DATA_TREE *sms_tree,
+                                         PartitionSearchState *part_state) {
   const AV1_COMMON *const cm = &cpi->common;
+  const PartitionBlkParams *blk_params = &part_state->part_blk_params;
+  const int mi_row = blk_params->mi_row, mi_col = blk_params->mi_col;
+  const BLOCK_SIZE bsize = blk_params->bsize;
+
   const int bsize_idx = convert_bsize_to_idx(bsize);
   const int is_720p_or_larger = AOMMIN(cm->width, cm->height) >= 720;
   const int is_480p_or_larger = AOMMIN(cm->width, cm->height) >= 480;
@@ -639,15 +643,17 @@
   // consistent for all features collected and passed to the external model.
   if (cpi->sf.part_sf.simple_motion_search_prune_rect &&
       !frame_is_intra_only(cm) &&
-      (partition_horz_allowed || partition_vert_allowed) &&
+      (part_state->partition_rect_allowed[HORZ] ||
+       part_state->partition_rect_allowed[VERT]) &&
       bsize >= BLOCK_8X8 && !av1_superres_scaled(cm)) {
     // Write features to file
     write_features_to_file(
         cpi->oxcf.partition_info_path, cpi->ext_part_controller.test_mode,
         features, FEATURE_SIZE_SMS_PRUNE_PART, 1, bsize, mi_row, mi_col);
 
-    if (ext_ml_model_decision_before_none_part2(cpi, features, prune_horz,
-                                                prune_vert)) {
+    if (ext_ml_model_decision_before_none_part2(
+            cpi, features, &part_state->prune_rect_part[HORZ],
+            &part_state->prune_rect_part[VERT])) {
       return;
     }
   }
@@ -668,12 +674,11 @@
   av1_nn_softmax(scores, probs, num_classes);
 
   // Determine if we should prune rectangular partitions.
-  if (cpi->sf.part_sf.simple_motion_search_prune_rect &&
-      !frame_is_intra_only(cm) &&
-      (partition_horz_allowed || partition_vert_allowed) &&
-      bsize >= BLOCK_8X8 && !av1_superres_scaled(cm)) {
-    *prune_horz = probs[PARTITION_HORZ] <= prune_thresh;
-    *prune_vert = probs[PARTITION_VERT] <= prune_thresh;
+  if (probs[PARTITION_HORZ] <= prune_thresh) {
+    part_state->prune_rect_part[HORZ] = 1;
+  }
+  if (probs[PARTITION_VERT] <= prune_thresh) {
+    part_state->prune_rect_part[VERT] = 1;
   }
 }
 
@@ -685,10 +690,11 @@
 //  - blk_row + blk_height/2 < total_rows and blk_col + blk_width/2 < total_cols
 void av1_simple_motion_search_early_term_none(
     AV1_COMP *const cpi, MACROBLOCK *x, SIMPLE_MOTION_DATA_TREE *sms_tree,
-    int mi_row, int mi_col, BLOCK_SIZE bsize, const RD_STATS *none_rdc,
-    int *early_terminate) {
-  // TODO(chiyotsai@google.com): There are other features we can extract from
-  // PARTITION_NONE. Play with this later.
+    const RD_STATS *none_rdc, PartitionSearchState *part_state) {
+  const PartitionBlkParams *blk_params = &part_state->part_blk_params;
+  const int mi_row = blk_params->mi_row, mi_col = blk_params->mi_col;
+  const BLOCK_SIZE bsize = blk_params->bsize;
+
   float features[FEATURE_SIZE_SMS_TERM_NONE] = { 0.0f };
   simple_motion_search_prune_part_features(cpi, x, sms_tree, mi_row, mi_col,
                                            bsize, features,
@@ -730,7 +736,8 @@
                          cpi->ext_part_controller.test_mode, features,
                          FEATURE_SIZE_SMS_TERM_NONE, 3, bsize, mi_row, mi_col);
 
-  if (ext_ml_model_decision_after_none_part2(cpi, features, early_terminate)) {
+  if (ext_ml_model_decision_after_none_part2(
+          cpi, features, &part_state->terminate_partition_search)) {
     return;
   }
 
@@ -743,7 +750,7 @@
     score += ml_model[FEATURE_SIZE_SMS_TERM_NONE];
 
     if (score >= 0.0f) {
-      *early_terminate = 1;
+      part_state->terminate_partition_search = 1;
     }
   }
 }
@@ -962,12 +969,16 @@
 #define FEATURES 31
 void av1_ml_early_term_after_split(AV1_COMP *const cpi, MACROBLOCK *const x,
                                    SIMPLE_MOTION_DATA_TREE *const sms_tree,
-                                   BLOCK_SIZE bsize, int64_t best_rd,
-                                   int64_t part_none_rd, int64_t part_split_rd,
-                                   int64_t *split_block_rd, int mi_row,
-                                   int mi_col,
-                                   int *const terminate_partition_search) {
-  if (best_rd <= 0 || best_rd == INT64_MAX || *terminate_partition_search)
+                                   int64_t best_rd, int64_t part_none_rd,
+                                   int64_t part_split_rd,
+                                   int64_t *split_block_rd,
+                                   PartitionSearchState *part_state) {
+  const PartitionBlkParams *blk_params = &part_state->part_blk_params;
+  const int mi_row = blk_params->mi_row, mi_col = blk_params->mi_col;
+  const BLOCK_SIZE bsize = blk_params->bsize;
+
+  if (best_rd <= 0 || best_rd == INT64_MAX ||
+      part_state->terminate_partition_search)
     return;
 
   const AV1_COMMON *const cm = &cpi->common;
@@ -1046,24 +1057,28 @@
                          cpi->ext_part_controller.test_mode, features, FEATURES,
                          4, bsize, mi_row, mi_col);
 
-  if (ext_ml_model_decision_after_split(cpi, features,
-                                        terminate_partition_search)) {
+  if (ext_ml_model_decision_after_split(
+          cpi, features, &part_state->terminate_partition_search)) {
     return;
   }
 
   float score = 0.0f;
   av1_nn_predict(features, nn_config, 1, &score);
   // Score is indicator of confidence that we should NOT terminate.
-  if (score < thresh) *terminate_partition_search = 1;
+  if (score < thresh) {
+    part_state->terminate_partition_search = 1;
+  }
 }
 #undef FEATURES
 
 void av1_ml_prune_rect_partition(AV1_COMP *const cpi, const MACROBLOCK *const x,
-                                 BLOCK_SIZE bsize, const int mi_row,
-                                 const int mi_col, int64_t best_rd,
-                                 int64_t none_rd, int64_t *split_rd,
-                                 int *const dst_prune_horz,
-                                 int *const dst_prune_vert) {
+                                 int64_t best_rd, int64_t none_rd,
+                                 const int64_t *split_rd,
+                                 PartitionSearchState *part_state) {
+  const PartitionBlkParams *blk_params = &part_state->part_blk_params;
+  const int mi_row = blk_params->mi_row, mi_col = blk_params->mi_col;
+  const BLOCK_SIZE bsize = blk_params->bsize;
+
   if (bsize < BLOCK_8X8 || best_rd >= 1000000000) return;
   best_rd = AOMMAX(best_rd, 1);
   const NN_CONFIG *nn_config = NULL;
@@ -1145,7 +1160,8 @@
 
   if (ext_ml_model_decision_after_split_part2(
           &cpi->ext_part_controller, frame_is_intra_only(&cpi->common),
-          features, dst_prune_horz, dst_prune_vert)) {
+          features, &part_state->prune_rect_part[HORZ],
+          &part_state->prune_rect_part[VERT])) {
     return;
   }
 
@@ -1157,19 +1173,21 @@
 
   // probs[0] is the probability of the fact that both rectangular partitions
   // are worse than current best_rd
-  if (probs[1] <= cur_thresh) (*dst_prune_horz) = 1;
-  if (probs[2] <= cur_thresh) (*dst_prune_vert) = 1;
+  if (probs[1] <= cur_thresh) part_state->prune_rect_part[HORZ] = 1;
+  if (probs[2] <= cur_thresh) part_state->prune_rect_part[VERT] = 1;
 }
 
 // Use a ML model to predict if horz_a, horz_b, vert_a, and vert_b should be
 // considered.
-void av1_ml_prune_ab_partition(
-    AV1_COMP *const cpi, BLOCK_SIZE bsize, const int mi_row, const int mi_col,
-    int part_ctx, int var_ctx, int64_t best_rd,
-    int64_t horz_rd[SUB_PARTITIONS_RECT], int64_t vert_rd[SUB_PARTITIONS_RECT],
-    int64_t split_rd[SUB_PARTITIONS_SPLIT], int *const horza_partition_allowed,
-    int *const horzb_partition_allowed, int *const verta_partition_allowed,
-    int *const vertb_partition_allowed) {
+void av1_ml_prune_ab_partition(AV1_COMP *const cpi, int part_ctx, int var_ctx,
+                               int64_t best_rd,
+                               PartitionSearchState *part_state,
+                               int *ab_partitions_allowed) {
+  const PartitionBlkParams blk_params = part_state->part_blk_params;
+  const int mi_row = blk_params.mi_row;
+  const int mi_col = blk_params.mi_col;
+  const int bsize = blk_params.bsize;
+
   if (bsize < BLOCK_8X8 || best_rd >= 1000000000) return;
   const NN_CONFIG *nn_config = NULL;
   switch (bsize) {
@@ -1191,16 +1209,19 @@
   int sub_block_rdcost[8] = { 0 };
   int rd_index = 0;
   for (int i = 0; i < SUB_PARTITIONS_RECT; ++i) {
+    const int64_t *horz_rd = part_state->rect_part_rd[HORZ];
     if (horz_rd[i] > 0 && horz_rd[i] < 1000000000)
       sub_block_rdcost[rd_index] = (int)horz_rd[i];
     ++rd_index;
   }
   for (int i = 0; i < SUB_PARTITIONS_RECT; ++i) {
+    const int64_t *vert_rd = part_state->rect_part_rd[VERT];
     if (vert_rd[i] > 0 && vert_rd[i] < 1000000000)
       sub_block_rdcost[rd_index] = (int)vert_rd[i];
     ++rd_index;
   }
   for (int i = 0; i < SUB_PARTITIONS_SPLIT; ++i) {
+    const int64_t *split_rd = part_state->split_rd;
     if (split_rd[i] > 0 && split_rd[i] < 1000000000)
       sub_block_rdcost[rd_index] = (int)split_rd[i];
     ++rd_index;
@@ -1223,8 +1244,9 @@
 
   if (ext_ml_model_decision_after_rect(
           &cpi->ext_part_controller, frame_is_intra_only(&cpi->common),
-          features, horza_partition_allowed, horzb_partition_allowed,
-          verta_partition_allowed, vertb_partition_allowed)) {
+          features, &ab_partitions_allowed[HORZ_A],
+          &ab_partitions_allowed[HORZ_B], &ab_partitions_allowed[VERT_A],
+          &ab_partitions_allowed[VERT_B])) {
     return;
   }
 
@@ -1245,16 +1267,13 @@
     case BLOCK_32X32: thresh -= 100; break;
     default: break;
   }
-  *horza_partition_allowed = 0;
-  *horzb_partition_allowed = 0;
-  *verta_partition_allowed = 0;
-  *vertb_partition_allowed = 0;
+  av1_zero_array(ab_partitions_allowed, NUM_AB_PARTS);
   for (int i = 0; i < 16; ++i) {
     if (int_score[i] >= thresh) {
-      if ((i >> 0) & 1) *horza_partition_allowed = 1;
-      if ((i >> 1) & 1) *horzb_partition_allowed = 1;
-      if ((i >> 2) & 1) *verta_partition_allowed = 1;
-      if ((i >> 3) & 1) *vertb_partition_allowed = 1;
+      if ((i >> 0) & 1) ab_partitions_allowed[HORZ_A] = 1;
+      if ((i >> 1) & 1) ab_partitions_allowed[HORZ_B] = 1;
+      if ((i >> 2) & 1) ab_partitions_allowed[VERT_A] = 1;
+      if ((i >> 3) & 1) ab_partitions_allowed[VERT_B] = 1;
     }
   }
 }
@@ -1262,21 +1281,27 @@
 #define FEATURES 18
 #define LABELS 4
 // Use a ML model to predict if horz4 and vert4 should be considered.
-void av1_ml_prune_4_partition(
-    AV1_COMP *const cpi, MACROBLOCK *const x, BLOCK_SIZE bsize, int part_ctx,
-    int64_t best_rd, int64_t rect_part_rd[NUM_RECT_PARTS][SUB_PARTITIONS_RECT],
-    int64_t split_rd[SUB_PARTITIONS_SPLIT], int *const partition_horz4_allowed,
-    int *const partition_vert4_allowed, unsigned int pb_source_variance,
-    int mi_row, int mi_col) {
+void av1_ml_prune_4_partition(AV1_COMP *const cpi, MACROBLOCK *const x,
+                              int part_ctx, int64_t best_rd,
+                              PartitionSearchState *part_state,
+                              int *part4_allowed,
+                              unsigned int pb_source_variance) {
+  const PartitionBlkParams blk_params = part_state->part_blk_params;
+  const int mi_row = blk_params.mi_row;
+  const int mi_col = blk_params.mi_col;
+  const int bsize = blk_params.bsize;
+
+  int64_t(*rect_part_rd)[SUB_PARTITIONS_RECT] = part_state->rect_part_rd;
+  int64_t *split_rd = part_state->split_rd;
   if (ext_ml_model_decision_after_part_ab(
           cpi, x, bsize, part_ctx, best_rd, rect_part_rd, split_rd,
-          partition_horz4_allowed, partition_vert4_allowed, pb_source_variance,
+          &part4_allowed[HORZ4], &part4_allowed[VERT4], pb_source_variance,
           mi_row, mi_col))
     return;
 
   if (best_rd >= 1000000000) return;
-  int64_t *horz_rd = rect_part_rd[HORZ];
-  int64_t *vert_rd = rect_part_rd[VERT];
+  int64_t *horz_rd = rect_part_rd[HORZ4];
+  int64_t *vert_rd = rect_part_rd[VERT4];
   const NN_CONFIG *nn_config = NULL;
   switch (bsize) {
     case BLOCK_16X16: nn_config = &av1_4_partition_nnconfig_16; break;
@@ -1396,12 +1421,11 @@
     case BLOCK_64X64: thresh -= 200; break;
     default: break;
   }
-  *partition_horz4_allowed = 0;
-  *partition_vert4_allowed = 0;
+  av1_zero_array(part4_allowed, NUM_PART4_TYPES);
   for (int i = 0; i < LABELS; ++i) {
     if (int_score[i] >= thresh) {
-      if ((i >> 0) & 1) *partition_horz4_allowed = 1;
-      if ((i >> 1) & 1) *partition_vert4_allowed = 1;
+      if ((i >> 0) & 1) part4_allowed[HORZ4] = 1;
+      if ((i >> 1) & 1) part4_allowed[VERT4] = 1;
     }
   }
 }
@@ -1409,12 +1433,14 @@
 #undef LABELS
 
 #define FEATURES 4
-void av1_ml_predict_breakout(AV1_COMP *const cpi, BLOCK_SIZE bsize,
-                             const MACROBLOCK *const x,
+void av1_ml_predict_breakout(AV1_COMP *const cpi, const MACROBLOCK *const x,
                              const RD_STATS *const rd_stats,
-                             const PartitionBlkParams blk_params,
                              unsigned int pb_source_variance, int bit_depth,
-                             int *do_square_split, int *do_rectangular_split) {
+                             PartitionSearchState *part_state) {
+  const PartitionBlkParams *blk_params = &part_state->part_blk_params;
+  const int mi_row = blk_params->mi_row, mi_col = blk_params->mi_col;
+  const BLOCK_SIZE bsize = blk_params->bsize;
+
   const NN_CONFIG *nn_config = NULL;
   int thresh = 0;
   switch (bsize) {
@@ -1470,12 +1496,12 @@
   // Write features to file
   write_features_to_file(cpi->oxcf.partition_info_path,
                          cpi->ext_part_controller.test_mode, features, FEATURES,
-                         2, blk_params.bsize, blk_params.mi_row,
-                         blk_params.mi_col);
+                         2, bsize, mi_row, mi_col);
 
-  if (ext_ml_model_decision_after_none(
-          &cpi->ext_part_controller, frame_is_intra_only(&cpi->common),
-          features, do_square_split, do_rectangular_split)) {
+  if (ext_ml_model_decision_after_none(&cpi->ext_part_controller,
+                                       frame_is_intra_only(&cpi->common),
+                                       features, &part_state->do_square_split,
+                                       &part_state->do_rectangular_split)) {
     return;
   }
 
@@ -1485,21 +1511,22 @@
 
   // Make decision.
   if ((int)(score * 100) >= thresh) {
-    *do_square_split = 0;
-    *do_rectangular_split = 0;
+    part_state->do_square_split = 0;
+    part_state->do_rectangular_split = 0;
   }
 }
 #undef FEATURES
 
-void av1_prune_partitions_before_search(
-    AV1_COMP *const cpi, MACROBLOCK *const x, int mi_row, int mi_col,
-    BLOCK_SIZE bsize, SIMPLE_MOTION_DATA_TREE *const sms_tree,
-    int *partition_none_allowed, int *partition_horz_allowed,
-    int *partition_vert_allowed, int *do_rectangular_split,
-    int *do_square_split, int *prune_horz, int *prune_vert) {
+void av1_prune_partitions_before_search(AV1_COMP *const cpi,
+                                        MACROBLOCK *const x,
+                                        SIMPLE_MOTION_DATA_TREE *const sms_tree,
+                                        PartitionSearchState *part_state) {
   const AV1_COMMON *const cm = &cpi->common;
   const CommonModeInfoParams *const mi_params = &cm->mi_params;
 
+  const PartitionBlkParams *blk_params = &part_state->part_blk_params;
+  const BLOCK_SIZE bsize = blk_params->bsize;
+
   // Prune rectangular, AB and 4-way partition based on q index and block size
   if (cpi->sf.part_sf.prune_rectangular_split_based_on_qidx) {
     // Enumeration difference between two square partitions
@@ -1515,9 +1542,7 @@
     // qidx 86 to 170: prune bsize below BLOCK_16X16
     // qidx 171 to 255: prune bsize below BLOCK_8X8
     if (bsize < max_prune_bsize) {
-      *do_rectangular_split = 0;
-      *partition_horz_allowed = 0;
-      *partition_vert_allowed = 0;
+      av1_disable_rect_partitions(part_state);
     }
   }
 
@@ -1536,9 +1561,7 @@
       }
     }
     if (prune_sub_8x8) {
-      *partition_horz_allowed = 0;
-      *partition_vert_allowed = 0;
-      *do_square_split = 0;
+      av1_disable_all_splits(part_state);
     }
   }
 
@@ -1548,49 +1571,49 @@
       !cpi->use_screen_content_tools && frame_is_intra_only(cm) &&
       cpi->sf.part_sf.intra_cnn_split &&
       cm->seq_params->sb_size >= BLOCK_64X64 && bsize <= BLOCK_64X64 &&
-      bsize >= BLOCK_8X8 &&
-      mi_row + mi_size_high[bsize] <= mi_params->mi_rows &&
-      mi_col + mi_size_wide[bsize] <= mi_params->mi_cols;
+      blk_params->bsize_at_least_8x8 &&
+      av1_is_whole_blk_in_frame(blk_params, mi_params);
 
   if (try_intra_cnn_split) {
-    av1_intra_mode_cnn_partition(
-        &cpi->common, x, bsize, x->part_search_info.quad_tree_idx,
-        partition_none_allowed, partition_horz_allowed, partition_vert_allowed,
-        do_rectangular_split, do_square_split);
+    av1_intra_mode_cnn_partition(&cpi->common, x,
+                                 x->part_search_info.quad_tree_idx, part_state);
   }
 
   // Use simple motion search to prune out split or non-split partitions. This
   // must be done prior to PARTITION_SPLIT to propagate the initial mvs to a
   // smaller blocksize.
   const int try_split_only =
-      cpi->sf.part_sf.simple_motion_search_split && *do_square_split &&
-      bsize >= BLOCK_8X8 &&
-      mi_row + mi_size_high[bsize] <= mi_params->mi_rows &&
-      mi_col + mi_size_wide[bsize] <= mi_params->mi_cols &&
+      cpi->sf.part_sf.simple_motion_search_split &&
+      part_state->do_square_split && blk_params->bsize_at_least_8x8 &&
+      av1_is_whole_blk_in_frame(blk_params, mi_params) &&
       !frame_is_intra_only(cm) && !av1_superres_scaled(cm);
 
   if (try_split_only) {
-    av1_simple_motion_search_based_split(
-        cpi, x, sms_tree, mi_row, mi_col, bsize, partition_none_allowed,
-        partition_horz_allowed, partition_vert_allowed, do_rectangular_split,
-        do_square_split);
+    av1_simple_motion_search_based_split(cpi, x, sms_tree, part_state);
   }
 
   // Use simple motion search to prune out rectangular partition in some
   // direction. The results are stored in prune_horz and prune_vert in order to
   // bypass future related pruning checks if a pruning decision has been made.
-  const int try_prune_rect =
-      cpi->sf.part_sf.simple_motion_search_prune_rect &&
-      !frame_is_intra_only(cm) && *do_rectangular_split &&
-      (*do_square_split || *partition_none_allowed ||
-       (*prune_horz && *prune_vert)) &&
-      (*partition_horz_allowed || *partition_vert_allowed) &&
-      bsize >= BLOCK_8X8;
+
+  // We want to search at least one partition mode, so don't prune if NONE and
+  // SPLIT are disabled.
+  const int non_rect_part_allowed =
+      part_state->do_square_split || part_state->partition_none_allowed;
+  // Only run the model if the partitions are not already pruned.
+  const int rect_part_allowed = part_state->do_rectangular_split &&
+                                ((part_state->partition_rect_allowed[HORZ] &&
+                                  !part_state->prune_rect_part[HORZ]) ||
+                                 (part_state->partition_rect_allowed[VERT] &&
+                                  !part_state->prune_rect_part[VERT]));
+
+  const int try_prune_rect = cpi->sf.part_sf.simple_motion_search_prune_rect &&
+                             !frame_is_intra_only(cm) &&
+                             non_rect_part_allowed && rect_part_allowed &&
+                             !av1_superres_scaled(cm);
 
   if (try_prune_rect) {
-    av1_simple_motion_search_prune_rect(
-        cpi, x, sms_tree, mi_row, mi_col, bsize, *partition_horz_allowed,
-        *partition_vert_allowed, prune_horz, prune_vert);
+    av1_simple_motion_search_prune_rect(cpi, x, sms_tree, part_state);
   }
 }
 
@@ -1600,13 +1623,13 @@
 }
 #endif  // NDEBUG
 
-void av1_prune_partitions_by_max_min_bsize(
-    SuperBlockEnc *sb_enc, BLOCK_SIZE bsize, int is_not_edge_block,
-    int *partition_none_allowed, int *partition_horz_allowed,
-    int *partition_vert_allowed, int *do_square_split) {
+void av1_prune_partitions_by_max_min_bsize(SuperBlockEnc *sb_enc,
+                                           PartitionSearchState *part_state) {
   assert(is_bsize_square(sb_enc->max_partition_size));
   assert(is_bsize_square(sb_enc->min_partition_size));
   assert(sb_enc->min_partition_size <= sb_enc->max_partition_size);
+  const PartitionBlkParams *blk_params = &part_state->part_blk_params;
+  const BLOCK_SIZE bsize = blk_params->bsize;
   assert(is_bsize_square(bsize));
   const int max_partition_size_1d = block_size_wide[sb_enc->max_partition_size];
   const int min_partition_size_1d = block_size_wide[sb_enc->min_partition_size];
@@ -1616,19 +1639,18 @@
   const int is_gt_max_sq_part = bsize_1d > max_partition_size_1d;
   if (is_gt_max_sq_part) {
     // If current block size is larger than max, only allow split.
-    *partition_none_allowed = 0;
-    *partition_horz_allowed = 0;
-    *partition_vert_allowed = 0;
-    *do_square_split = 1;
+    av1_set_square_split_only(part_state);
   } else if (is_le_min_sq_part) {
     // If current block size is less or equal to min, only allow none if valid
     // block large enough; only allow split otherwise.
-    *partition_horz_allowed = 0;
-    *partition_vert_allowed = 0;
+    av1_disable_rect_partitions(part_state);
+
     // only disable square split when current block is not at the picture
     // boundary. otherwise, inherit the square split flag from previous logic
-    if (is_not_edge_block) *do_square_split = 0;
-    *partition_none_allowed = !(*do_square_split);
+    if (!av1_blk_has_rows_and_cols(blk_params)) {
+      part_state->do_square_split = 0;
+    }
+    part_state->partition_none_allowed = !(part_state->do_square_split);
   }
 }
 
@@ -1666,25 +1688,25 @@
   return 1;
 }
 
-void av1_prune_ab_partitions(
-    AV1_COMP *cpi, const MACROBLOCK *x, const PC_TREE *pc_tree,
-    BLOCK_SIZE bsize, const int mi_row, const int mi_col,
-    int pb_source_variance, int64_t best_rdcost,
-    int64_t rect_part_rd[NUM_RECT_PARTS][SUB_PARTITIONS_RECT],
-    int64_t split_rd[SUB_PARTITIONS_SPLIT],
-    const RD_RECT_PART_WIN_INFO *rect_part_win_info, int ext_partition_allowed,
-    int partition_horz_allowed, int partition_vert_allowed,
-    int *horza_partition_allowed, int *horzb_partition_allowed,
-    int *verta_partition_allowed, int *vertb_partition_allowed) {
-  int64_t *horz_rd = rect_part_rd[HORZ];
-  int64_t *vert_rd = rect_part_rd[VERT];
+void av1_prune_ab_partitions(AV1_COMP *cpi, const MACROBLOCK *x,
+                             const PC_TREE *pc_tree, int pb_source_variance,
+                             int64_t best_rdcost,
+                             const RD_RECT_PART_WIN_INFO *rect_part_win_info,
+                             bool ext_partition_allowed,
+                             PartitionSearchState *part_state,
+                             int *ab_partitions_allowed) {
+  int64_t *horz_rd = part_state->rect_part_rd[HORZ];
+  int64_t *vert_rd = part_state->rect_part_rd[VERT];
+  int64_t *split_rd = part_state->split_rd;
   const PartitionCfg *const part_cfg = &cpi->oxcf.part_cfg;
   // The standard AB partitions are allowed initially if ext-partition-types are
   // allowed.
-  int horzab_partition_allowed =
-      ext_partition_allowed & part_cfg->enable_ab_partitions;
-  int vertab_partition_allowed =
-      ext_partition_allowed & part_cfg->enable_ab_partitions;
+  int horzab_partition_allowed = ext_partition_allowed &&
+                                 part_cfg->enable_ab_partitions &&
+                                 part_state->partition_rect_allowed[HORZ];
+  int vertab_partition_allowed = ext_partition_allowed &&
+                                 part_cfg->enable_ab_partitions &&
+                                 part_state->partition_rect_allowed[VERT];
 
   // Pruning: pruning out AB partitions on one main direction based on the
   // current best partition and source variance.
@@ -1719,20 +1741,20 @@
   // Pruning: pruning out horz_a or horz_b if the combined rdcost of its
   // subblocks estimated from previous partitions is much higher than the best
   // rd so far.
-  *horza_partition_allowed = horzab_partition_allowed;
-  *horzb_partition_allowed = horzab_partition_allowed;
+  ab_partitions_allowed[HORZ_A] = horzab_partition_allowed;
+  ab_partitions_allowed[HORZ_B] = horzab_partition_allowed;
   if (cpi->sf.part_sf.prune_ext_partition_types_search_level) {
     const int64_t horz_a_rd = horz_rd[1] + split_rd[0] + split_rd[1];
     const int64_t horz_b_rd = horz_rd[0] + split_rd[2] + split_rd[3];
     switch (cpi->sf.part_sf.prune_ext_partition_types_search_level) {
       case 1:
-        *horza_partition_allowed &= (horz_a_rd / 16 * 14 < best_rdcost);
-        *horzb_partition_allowed &= (horz_b_rd / 16 * 14 < best_rdcost);
+        ab_partitions_allowed[HORZ_A] &= (horz_a_rd / 16 * 14 < best_rdcost);
+        ab_partitions_allowed[HORZ_B] &= (horz_b_rd / 16 * 14 < best_rdcost);
         break;
       case 2:
       default:
-        *horza_partition_allowed &= (horz_a_rd / 16 * 15 < best_rdcost);
-        *horzb_partition_allowed &= (horz_b_rd / 16 * 15 < best_rdcost);
+        ab_partitions_allowed[HORZ_A] &= (horz_a_rd / 16 * 15 < best_rdcost);
+        ab_partitions_allowed[HORZ_B] &= (horz_b_rd / 16 * 15 < best_rdcost);
         break;
     }
   }
@@ -1740,20 +1762,20 @@
   // Pruning: pruning out vert_a or vert_b if the combined rdcost of its
   // subblocks estimated from previous partitions is much higher than the best
   // rd so far.
-  *verta_partition_allowed = vertab_partition_allowed;
-  *vertb_partition_allowed = vertab_partition_allowed;
+  ab_partitions_allowed[VERT_A] = vertab_partition_allowed;
+  ab_partitions_allowed[VERT_B] = vertab_partition_allowed;
   if (cpi->sf.part_sf.prune_ext_partition_types_search_level) {
     const int64_t vert_a_rd = vert_rd[1] + split_rd[0] + split_rd[2];
     const int64_t vert_b_rd = vert_rd[0] + split_rd[1] + split_rd[3];
     switch (cpi->sf.part_sf.prune_ext_partition_types_search_level) {
       case 1:
-        *verta_partition_allowed &= (vert_a_rd / 16 * 14 < best_rdcost);
-        *vertb_partition_allowed &= (vert_b_rd / 16 * 14 < best_rdcost);
+        ab_partitions_allowed[VERT_A] &= (vert_a_rd / 16 * 14 < best_rdcost);
+        ab_partitions_allowed[VERT_B] &= (vert_b_rd / 16 * 14 < best_rdcost);
         break;
       case 2:
       default:
-        *verta_partition_allowed &= (vert_a_rd / 16 * 15 < best_rdcost);
-        *vertb_partition_allowed &= (vert_b_rd / 16 * 15 < best_rdcost);
+        ab_partitions_allowed[VERT_A] &= (vert_a_rd / 16 * 15 < best_rdcost);
+        ab_partitions_allowed[VERT_B] &= (vert_b_rd / 16 * 15 < best_rdcost);
         break;
     }
   }
@@ -1761,43 +1783,36 @@
   // Pruning: pruning out some ab partitions using a DNN taking rd costs of
   // sub-blocks from previous basic partition types.
   if (cpi->sf.part_sf.ml_prune_partition && ext_partition_allowed &&
-      partition_horz_allowed && partition_vert_allowed) {
+      part_state->partition_rect_allowed[HORZ] &&
+      part_state->partition_rect_allowed[VERT]) {
     // TODO(huisu@google.com): x->source_variance may not be the current
     // block's variance. The correct one to use is pb_source_variance. Need to
     // re-train the model to fix it.
-    av1_ml_prune_ab_partition(cpi, bsize, mi_row, mi_col, pc_tree->partitioning,
+    av1_ml_prune_ab_partition(cpi, pc_tree->partitioning,
                               get_unsigned_bits(x->source_variance),
-                              best_rdcost, horz_rd, vert_rd, split_rd,
-                              horza_partition_allowed, horzb_partition_allowed,
-                              verta_partition_allowed, vertb_partition_allowed);
+                              best_rdcost, part_state, ab_partitions_allowed);
   }
 
-  // Disable ab partitions if they are disabled by the encoder parameter.
-  *horza_partition_allowed &= part_cfg->enable_ab_partitions;
-  *horzb_partition_allowed &= part_cfg->enable_ab_partitions;
-  *verta_partition_allowed &= part_cfg->enable_ab_partitions;
-  *vertb_partition_allowed &= part_cfg->enable_ab_partitions;
-
   // Pruning: pruning AB partitions based on the number of horz/vert wins
   // in the current block and sub-blocks in PARTITION_SPLIT.
   if (cpi->sf.part_sf.prune_ext_part_using_split_info >= 2 &&
-      *horza_partition_allowed) {
-    *horza_partition_allowed &= evaluate_ab_partition_based_on_split(
+      ab_partitions_allowed[HORZ_A]) {
+    ab_partitions_allowed[HORZ_A] &= evaluate_ab_partition_based_on_split(
         pc_tree, PARTITION_HORZ, rect_part_win_info, x->qindex, 0, 1);
   }
   if (cpi->sf.part_sf.prune_ext_part_using_split_info >= 2 &&
-      *horzb_partition_allowed) {
-    *horzb_partition_allowed &= evaluate_ab_partition_based_on_split(
+      ab_partitions_allowed[HORZ_B]) {
+    ab_partitions_allowed[HORZ_B] &= evaluate_ab_partition_based_on_split(
         pc_tree, PARTITION_HORZ, rect_part_win_info, x->qindex, 2, 3);
   }
   if (cpi->sf.part_sf.prune_ext_part_using_split_info >= 2 &&
-      *verta_partition_allowed) {
-    *verta_partition_allowed &= evaluate_ab_partition_based_on_split(
+      ab_partitions_allowed[VERT_A]) {
+    ab_partitions_allowed[VERT_A] &= evaluate_ab_partition_based_on_split(
         pc_tree, PARTITION_VERT, rect_part_win_info, x->qindex, 0, 2);
   }
   if (cpi->sf.part_sf.prune_ext_part_using_split_info >= 2 &&
-      *vertb_partition_allowed) {
-    *vertb_partition_allowed &= evaluate_ab_partition_based_on_split(
+      ab_partitions_allowed[VERT_B]) {
+    ab_partitions_allowed[VERT_B] &= evaluate_ab_partition_based_on_split(
         pc_tree, PARTITION_VERT, rect_part_win_info, x->qindex, 1, 3);
   }
 }
diff --git a/av1/encoder/partition_strategy.h b/av1/encoder/partition_strategy.h
index 23a94d7..8030f69 100644
--- a/av1/encoder/partition_strategy.h
+++ b/av1/encoder/partition_strategy.h
@@ -18,29 +18,22 @@
 #include "av1/encoder/encoder.h"
 
 void av1_intra_mode_cnn_partition(const AV1_COMMON *const cm, MACROBLOCK *x,
-                                  int bsize, int label_idx,
-                                  int *partition_none_allowed,
-                                  int *partition_horz_allowed,
-                                  int *partition_vert_allowed,
-                                  int *do_rectangular_split,
-                                  int *do_square_split);
+                                  int label_idx,
+                                  PartitionSearchState *part_state);
 
 // Performs a simple_motion_search with a single reference frame and extract
 // the variance of residues. Then use the features to determine whether we want
 // to go straight to splitting without trying PARTITION_NONE
-void av1_simple_motion_search_based_split(
-    AV1_COMP *const cpi, MACROBLOCK *x, SIMPLE_MOTION_DATA_TREE *sms_tree,
-    int mi_row, int mi_col, BLOCK_SIZE bsize, int *partition_none_allowed,
-    int *partition_horz_allowed, int *partition_vert_allowed,
-    int *do_rectangular_split, int *do_square_split);
+void av1_simple_motion_search_based_split(AV1_COMP *const cpi, MACROBLOCK *x,
+                                          SIMPLE_MOTION_DATA_TREE *sms_tree,
+                                          PartitionSearchState *part_state);
 
 // Performs a simple_motion_search with two reference frames and extract
 // the variance of residues. Then use the features to determine whether we want
 // to prune some partitions.
-void av1_simple_motion_search_prune_rect(
-    AV1_COMP *const cpi, MACROBLOCK *x, SIMPLE_MOTION_DATA_TREE *sms_tree,
-    int mi_row, int mi_col, BLOCK_SIZE bsize, int partition_horz_allowed,
-    int partition_vert_allowed, int *prune_horz, int *prune_vert);
+void av1_simple_motion_search_prune_rect(AV1_COMP *const cpi, MACROBLOCK *x,
+                                         SIMPLE_MOTION_DATA_TREE *sms_tree,
+                                         PartitionSearchState *part_state);
 
 #if !CONFIG_REALTIME_ONLY
 // Early terminates PARTITION_NONE using simple_motion_search features and the
@@ -49,10 +42,11 @@
 //  - The frame is not intra only
 //  - The current bsize is > BLOCK_8X8
 //  - blk_row + blk_height/2 < total_rows and blk_col + blk_width/2 < total_cols
-void av1_simple_motion_search_early_term_none(
-    AV1_COMP *const cpi, MACROBLOCK *x, SIMPLE_MOTION_DATA_TREE *sms_tree,
-    int mi_row, int mi_col, BLOCK_SIZE bsize, const RD_STATS *none_rdc,
-    int *early_terminate);
+void av1_simple_motion_search_early_term_none(AV1_COMP *const cpi,
+                                              MACROBLOCK *x,
+                                              SIMPLE_MOTION_DATA_TREE *sms_tree,
+                                              const RD_STATS *none_rdc,
+                                              PartitionSearchState *part_state);
 
 // Get the features for selecting the max and min partition size. Currently this
 // performs simple_motion_search on 16X16 subblocks of the current superblock,
@@ -69,11 +63,10 @@
 // Attempts an early termination after PARTITION_SPLIT.
 void av1_ml_early_term_after_split(AV1_COMP *const cpi, MACROBLOCK *const x,
                                    SIMPLE_MOTION_DATA_TREE *const sms_tree,
-                                   BLOCK_SIZE bsize, int64_t best_rd,
-                                   int64_t part_none_rd, int64_t part_split_rd,
-                                   int64_t *split_block_rd, int mi_row,
-                                   int mi_col,
-                                   int *const terminate_partition_search);
+                                   int64_t best_rd, int64_t part_none_rd,
+                                   int64_t part_split_rd,
+                                   int64_t *split_block_rd,
+                                   PartitionSearchState *part_state);
 
 // Use the rdcost ratio and source var ratio to prune PARTITION_HORZ and
 // PARTITION_VERT.
@@ -82,47 +75,37 @@
 // that we can get better performance by adding in q_index and rectangular
 // sse/var from SMS. We should retrain and tune this model later.
 void av1_ml_prune_rect_partition(AV1_COMP *const cpi, const MACROBLOCK *const x,
-                                 BLOCK_SIZE bsize, const int mi_row,
-                                 const int mi_col, int64_t best_rd,
-                                 int64_t none_rd, int64_t *split_rd,
-                                 int *const dst_prune_horz,
-                                 int *const dst_prune_vert);
+                                 int64_t best_rd, int64_t none_rd,
+                                 const int64_t *split_rd,
+                                 PartitionSearchState *part_state);
 
 // Use a ML model to predict if horz_a, horz_b, vert_a, and vert_b should be
 // considered.
-void av1_ml_prune_ab_partition(
-    AV1_COMP *const cpi, BLOCK_SIZE bsize, const int mi_row, const int mi_col,
-    int part_ctx, int var_ctx, int64_t best_rd,
-    int64_t horz_rd[SUB_PARTITIONS_RECT], int64_t vert_rd[SUB_PARTITIONS_RECT],
-    int64_t split_rd[SUB_PARTITIONS_SPLIT], int *const horza_partition_allowed,
-    int *const horzb_partition_allowed, int *const verta_partition_allowed,
-    int *const vertb_partition_allowed);
+void av1_ml_prune_ab_partition(AV1_COMP *const cpi, int part_ctx, int var_ctx,
+                               int64_t best_rd,
+                               PartitionSearchState *part_state,
+                               int *ab_partitions_allowed);
 
 // Use a ML model to predict if horz4 and vert4 should be considered.
-void av1_ml_prune_4_partition(
-    AV1_COMP *const cpi, MACROBLOCK *const x, BLOCK_SIZE bsize, int part_ctx,
-    int64_t best_rd, int64_t rect_part_rd[NUM_RECT_PARTS][SUB_PARTITIONS_RECT],
-    int64_t split_rd[SUB_PARTITIONS_SPLIT], int *const partition_horz4_allowed,
-    int *const partition_vert4_allowed, unsigned int pb_source_variance,
-    int mi_row, int mi_col);
+void av1_ml_prune_4_partition(AV1_COMP *const cpi, MACROBLOCK *const x,
+                              int part_ctx, int64_t best_rd,
+                              PartitionSearchState *part_state,
+                              int *part4_allowed,
+                              unsigned int pb_source_variance);
 
 // ML-based partition search breakout after PARTITION_NONE.
-void av1_ml_predict_breakout(AV1_COMP *const cpi, BLOCK_SIZE bsize,
-                             const MACROBLOCK *const x,
+void av1_ml_predict_breakout(AV1_COMP *const cpi, const MACROBLOCK *const x,
                              const RD_STATS *const rd_stats,
-                             const PartitionBlkParams blk_params,
                              unsigned int pb_source_variance, int bit_depth,
-                             int *do_square_split, int *do_rectangular_split);
+                             PartitionSearchState *part_state);
 
 // The first round of partition pruning determined before any partition
 // has been tested. The decisions will be updated and passed back
 // to the partition search function.
-void av1_prune_partitions_before_search(
-    AV1_COMP *const cpi, MACROBLOCK *const x, int mi_row, int mi_col,
-    BLOCK_SIZE bsize, SIMPLE_MOTION_DATA_TREE *const sms_tree,
-    int *partition_none_allowed, int *partition_horz_allowed,
-    int *partition_vert_allowed, int *do_rectangular_split,
-    int *do_square_split, int *prune_horz, int *prune_vert);
+void av1_prune_partitions_before_search(AV1_COMP *const cpi,
+                                        MACROBLOCK *const x,
+                                        SIMPLE_MOTION_DATA_TREE *const sms_tree,
+                                        PartitionSearchState *part_state);
 
 // Prune out partitions that lead to coding block sizes outside the min and max
 // bsizes set by the encoder. Max and min square partition levels are defined as
@@ -130,23 +113,18 @@
 // reach. To implement this: only PARTITION_NONE is allowed if the current node
 // equals max_partition_size, only PARTITION_SPLIT is allowed if the current
 // node exceeds max_partition_size.
-void av1_prune_partitions_by_max_min_bsize(
-    SuperBlockEnc *sb_enc, BLOCK_SIZE bsize, int is_not_edge_block,
-    int *partition_none_allowed, int *partition_horz_allowed,
-    int *partition_vert_allowed, int *do_square_split);
+void av1_prune_partitions_by_max_min_bsize(SuperBlockEnc *sb_enc,
+                                           PartitionSearchState *part_state);
 
 // Prune out AB partitions based on rd decisions made from testing the
 // basic partitions.
-void av1_prune_ab_partitions(
-    AV1_COMP *cpi, const MACROBLOCK *x, const PC_TREE *pc_tree,
-    BLOCK_SIZE bsize, const int mi_row, const int mi_col,
-    int pb_source_variance, int64_t best_rdcost,
-    int64_t rect_part_rd[NUM_RECT_PARTS][SUB_PARTITIONS_RECT],
-    int64_t split_rd[SUB_PARTITIONS_SPLIT],
-    const RD_RECT_PART_WIN_INFO *rect_part_win_info, int ext_partition_allowed,
-    int partition_horz_allowed, int partition_vert_allowed,
-    int *horza_partition_allowed, int *horzb_partition_allowed,
-    int *verta_partition_allowed, int *vertb_partition_allowed);
+void av1_prune_ab_partitions(AV1_COMP *cpi, const MACROBLOCK *x,
+                             const PC_TREE *pc_tree, int pb_source_variance,
+                             int64_t best_rdcost,
+                             const RD_RECT_PART_WIN_INFO *rect_part_win_info,
+                             bool ext_partition_allowed,
+                             PartitionSearchState *part_state,
+                             int *ab_partitions_allowed);
 
 void av1_collect_motion_search_features_sb(AV1_COMP *const cpi, ThreadData *td,
                                            const int mi_row, const int mi_col,