Move all ML partition models from encodeframe.c to partition_strategy.c

BUG=aomedia:2343

Change-Id: I6189fe7887b89d20551af37bdf6b8fe1d399627e
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 2bf8193..e629b27 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -68,10 +68,6 @@
                               ThreadData *td, TOKENEXTRA **t, RUN_TYPE dry_run,
                               int mi_row, int mi_col, BLOCK_SIZE bsize,
                               int *rate);
-static int ml_predict_breakout(const AV1_COMP *const cpi, BLOCK_SIZE bsize,
-                               const MACROBLOCK *const x,
-                               const RD_STATS *const rd_stats,
-                               unsigned int pb_source_variance);
 
 // This is used as a reference when computing the source variance for the
 //  purposes of activity masking.
@@ -2363,8 +2359,8 @@
           // values as in rd_pick_partition. Retraining the model and tuning the
           // threshold values might be helpful to improve the speed.
           if (use_ml_based_breakout) {
-            if (ml_predict_breakout(cpi, bsize, x, &this_rdc,
-                                    x->source_variance)) {
+            if (av1_ml_predict_breakout(cpi, bsize, x, &this_rdc,
+                                        x->source_variance)) {
               do_square_split = 0;
             }
           }
@@ -2509,488 +2505,6 @@
   }
 }
 
-// split_score indicates confidence of picking split partition;
-// none_score indicates confidence of picking none partition;
-#define FEATURE_SIZE 19
-static int ml_prune_2pass_split_partition(const PC_TREE_STATS *pc_tree_stats,
-                                          BLOCK_SIZE bsize, int *split_score,
-                                          int *none_score) {
-  if (!pc_tree_stats->valid) return 0;
-  const float *split_weights = NULL;
-  const float *none_weights = NULL;
-  switch (bsize) {
-    case BLOCK_4X4: break;
-    case BLOCK_8X8:
-      split_weights = av1_2pass_split_partition_weights_8;
-      none_weights = av1_2pass_none_partition_weights_8;
-      break;
-    case BLOCK_16X16:
-      split_weights = av1_2pass_split_partition_weights_16;
-      none_weights = av1_2pass_none_partition_weights_16;
-      break;
-    case BLOCK_32X32:
-      split_weights = av1_2pass_split_partition_weights_32;
-      none_weights = av1_2pass_none_partition_weights_32;
-      break;
-    case BLOCK_64X64:
-      split_weights = av1_2pass_split_partition_weights_64;
-      none_weights = av1_2pass_none_partition_weights_64;
-      break;
-    case BLOCK_128X128:
-      split_weights = av1_2pass_split_partition_weights_128;
-      none_weights = av1_2pass_none_partition_weights_128;
-      break;
-    default: assert(0 && "Unexpected bsize.");
-  }
-  if (!split_weights || !none_weights) return 0;
-
-  aom_clear_system_state();
-
-  float features[FEATURE_SIZE];
-  int feature_index = 0;
-  features[feature_index++] = (float)pc_tree_stats->split;
-  features[feature_index++] = (float)pc_tree_stats->skip;
-  const int rdcost = (int)AOMMIN(INT_MAX, pc_tree_stats->rdcost);
-  const int rd_valid = rdcost > 0 && rdcost < 1000000000;
-  features[feature_index++] = (float)rd_valid;
-  for (int i = 0; i < 4; ++i) {
-    features[feature_index++] = (float)pc_tree_stats->sub_block_split[i];
-    features[feature_index++] = (float)pc_tree_stats->sub_block_skip[i];
-    const int sub_rdcost =
-        (int)AOMMIN(INT_MAX, pc_tree_stats->sub_block_rdcost[i]);
-    const int sub_rd_valid = sub_rdcost > 0 && sub_rdcost < 1000000000;
-    features[feature_index++] = (float)sub_rd_valid;
-    // Ratio between the sub-block RD and the whole-block RD.
-    float rd_ratio = 1.0f;
-    if (rd_valid && sub_rd_valid && sub_rdcost < rdcost)
-      rd_ratio = (float)sub_rdcost / (float)rdcost;
-    features[feature_index++] = rd_ratio;
-  }
-  assert(feature_index == FEATURE_SIZE);
-
-  float score_1 = split_weights[FEATURE_SIZE];
-  float score_2 = none_weights[FEATURE_SIZE];
-  for (int i = 0; i < FEATURE_SIZE; ++i) {
-    score_1 += features[i] * split_weights[i];
-    score_2 += features[i] * none_weights[i];
-  }
-  *split_score = (int)(score_1 * 100);
-  *none_score = (int)(score_2 * 100);
-  return 1;
-}
-#undef FEATURE_SIZE
-
-static void ml_prune_rect_partition(const AV1_COMP *const cpi,
-                                    const MACROBLOCK *const x, BLOCK_SIZE bsize,
-                                    int64_t best_rd, int64_t none_rd,
-                                    int64_t *split_rd,
-                                    int *const dst_prune_horz,
-                                    int *const dst_prune_vert) {
-  if (bsize < BLOCK_8X8 || best_rd >= 1000000000) return;
-  best_rd = AOMMAX(best_rd, 1);
-  const NN_CONFIG *nn_config = NULL;
-  const float prob_thresholds[5] = { 0.01f, 0.01f, 0.004f, 0.002f, 0.002f };
-  float cur_thresh = 0.0f;
-  switch (bsize) {
-    case BLOCK_8X8:
-      nn_config = &av1_rect_partition_nnconfig_8;
-      cur_thresh = prob_thresholds[0];
-      break;
-    case BLOCK_16X16:
-      nn_config = &av1_rect_partition_nnconfig_16;
-      cur_thresh = prob_thresholds[1];
-      break;
-    case BLOCK_32X32:
-      nn_config = &av1_rect_partition_nnconfig_32;
-      cur_thresh = prob_thresholds[2];
-      break;
-    case BLOCK_64X64:
-      nn_config = &av1_rect_partition_nnconfig_64;
-      cur_thresh = prob_thresholds[3];
-      break;
-    case BLOCK_128X128:
-      nn_config = &av1_rect_partition_nnconfig_128;
-      cur_thresh = prob_thresholds[4];
-      break;
-    default: assert(0 && "Unexpected bsize.");
-  }
-  if (!nn_config) return;
-  aom_clear_system_state();
-
-  // 1. Compute input features
-  float features[9];
-
-  // RD cost ratios
-  for (int i = 0; i < 5; i++) features[i] = 1.0f;
-  if (none_rd > 0 && none_rd < 1000000000)
-    features[0] = (float)none_rd / (float)best_rd;
-  for (int i = 0; i < 4; i++) {
-    if (split_rd[i] > 0 && split_rd[i] < 1000000000)
-      features[1 + i] = (float)split_rd[i] / (float)best_rd;
-  }
-
-  // Variance ratios
-  const MACROBLOCKD *const xd = &x->e_mbd;
-  int whole_block_variance;
-  if (is_cur_buf_hbd(xd)) {
-    whole_block_variance = av1_high_get_sby_perpixel_variance(
-        cpi, &x->plane[0].src, bsize, xd->bd);
-  } else {
-    whole_block_variance =
-        av1_get_sby_perpixel_variance(cpi, &x->plane[0].src, bsize);
-  }
-  whole_block_variance = AOMMAX(whole_block_variance, 1);
-
-  int split_variance[4];
-  const BLOCK_SIZE subsize = get_partition_subsize(bsize, PARTITION_SPLIT);
-  struct buf_2d buf;
-  buf.stride = x->plane[0].src.stride;
-  const int bw = block_size_wide[bsize];
-  for (int i = 0; i < 4; ++i) {
-    const int x_idx = (i & 1) * bw / 2;
-    const int y_idx = (i >> 1) * bw / 2;
-    buf.buf = x->plane[0].src.buf + x_idx + y_idx * buf.stride;
-    if (is_cur_buf_hbd(xd)) {
-      split_variance[i] =
-          av1_high_get_sby_perpixel_variance(cpi, &buf, subsize, xd->bd);
-    } else {
-      split_variance[i] = av1_get_sby_perpixel_variance(cpi, &buf, subsize);
-    }
-  }
-
-  for (int i = 0; i < 4; i++)
-    features[5 + i] = (float)split_variance[i] / (float)whole_block_variance;
-
-  // 2. Do the prediction and prune 0-2 partitions based on their probabilities
-  float raw_scores[3] = { 0.0f };
-  av1_nn_predict(features, nn_config, raw_scores);
-  aom_clear_system_state();
-  float probs[3] = { 0.0f };
-  av1_nn_softmax(raw_scores, probs, 3);
-
-  // probs[0] is the probability of the fact that both rectangular partitions
-  // are worse than current best_rd
-  if (probs[1] <= cur_thresh) (*dst_prune_horz) = 1;
-  if (probs[2] <= cur_thresh) (*dst_prune_vert) = 1;
-}
-
-// Use a ML model to predict if horz_a, horz_b, vert_a, and vert_b should be
-// considered.
-static void ml_prune_ab_partition(BLOCK_SIZE bsize, int part_ctx, int var_ctx,
-                                  int64_t best_rd, int64_t horz_rd[2],
-                                  int64_t vert_rd[2], int64_t split_rd[4],
-                                  int *const horza_partition_allowed,
-                                  int *const horzb_partition_allowed,
-                                  int *const verta_partition_allowed,
-                                  int *const vertb_partition_allowed) {
-  if (bsize < BLOCK_8X8 || best_rd >= 1000000000) return;
-  const NN_CONFIG *nn_config = NULL;
-  switch (bsize) {
-    case BLOCK_8X8: nn_config = NULL; break;
-    case BLOCK_16X16: nn_config = &av1_ab_partition_nnconfig_16; break;
-    case BLOCK_32X32: nn_config = &av1_ab_partition_nnconfig_32; break;
-    case BLOCK_64X64: nn_config = &av1_ab_partition_nnconfig_64; break;
-    case BLOCK_128X128: nn_config = &av1_ab_partition_nnconfig_128; break;
-    default: assert(0 && "Unexpected bsize.");
-  }
-  if (!nn_config) return;
-
-  aom_clear_system_state();
-
-  // Generate features.
-  float features[10];
-  int feature_index = 0;
-  features[feature_index++] = (float)part_ctx;
-  features[feature_index++] = (float)var_ctx;
-  const int rdcost = (int)AOMMIN(INT_MAX, best_rd);
-  int sub_block_rdcost[8] = { 0 };
-  int rd_index = 0;
-  for (int i = 0; i < 2; ++i) {
-    if (horz_rd[i] > 0 && horz_rd[i] < 1000000000)
-      sub_block_rdcost[rd_index] = (int)horz_rd[i];
-    ++rd_index;
-  }
-  for (int i = 0; i < 2; ++i) {
-    if (vert_rd[i] > 0 && vert_rd[i] < 1000000000)
-      sub_block_rdcost[rd_index] = (int)vert_rd[i];
-    ++rd_index;
-  }
-  for (int i = 0; i < 4; ++i) {
-    if (split_rd[i] > 0 && split_rd[i] < 1000000000)
-      sub_block_rdcost[rd_index] = (int)split_rd[i];
-    ++rd_index;
-  }
-  for (int i = 0; i < 8; ++i) {
-    // Ratio between the sub-block RD and the whole-block RD.
-    float rd_ratio = 1.0f;
-    if (sub_block_rdcost[i] > 0 && sub_block_rdcost[i] < rdcost)
-      rd_ratio = (float)sub_block_rdcost[i] / (float)rdcost;
-    features[feature_index++] = rd_ratio;
-  }
-  assert(feature_index == 10);
-
-  // Calculate scores using the NN model.
-  float score[16] = { 0.0f };
-  av1_nn_predict(features, nn_config, score);
-  aom_clear_system_state();
-  int int_score[16];
-  int max_score = -1000;
-  for (int i = 0; i < 16; ++i) {
-    int_score[i] = (int)(100 * score[i]);
-    max_score = AOMMAX(int_score[i], max_score);
-  }
-
-  // Make decisions based on the model scores.
-  int thresh = max_score;
-  switch (bsize) {
-    case BLOCK_16X16: thresh -= 150; break;
-    case BLOCK_32X32: thresh -= 100; break;
-    default: break;
-  }
-  *horza_partition_allowed = 0;
-  *horzb_partition_allowed = 0;
-  *verta_partition_allowed = 0;
-  *vertb_partition_allowed = 0;
-  for (int i = 0; i < 16; ++i) {
-    if (int_score[i] >= thresh) {
-      if ((i >> 0) & 1) *horza_partition_allowed = 1;
-      if ((i >> 1) & 1) *horzb_partition_allowed = 1;
-      if ((i >> 2) & 1) *verta_partition_allowed = 1;
-      if ((i >> 3) & 1) *vertb_partition_allowed = 1;
-    }
-  }
-}
-
-#define FEATURES 18
-#define LABELS 4
-// Use a ML model to predict if horz4 and vert4 should be considered.
-static void ml_prune_4_partition(const AV1_COMP *const cpi, MACROBLOCK *const x,
-                                 BLOCK_SIZE bsize, int part_ctx,
-                                 int64_t best_rd, int64_t horz_rd[2],
-                                 int64_t vert_rd[2], int64_t split_rd[4],
-                                 int *const partition_horz4_allowed,
-                                 int *const partition_vert4_allowed,
-                                 unsigned int pb_source_variance, int mi_row,
-                                 int mi_col) {
-  if (best_rd >= 1000000000) return;
-  const NN_CONFIG *nn_config = NULL;
-  switch (bsize) {
-    case BLOCK_16X16: nn_config = &av1_4_partition_nnconfig_16; break;
-    case BLOCK_32X32: nn_config = &av1_4_partition_nnconfig_32; break;
-    case BLOCK_64X64: nn_config = &av1_4_partition_nnconfig_64; break;
-    default: assert(0 && "Unexpected bsize.");
-  }
-  if (!nn_config) return;
-
-  aom_clear_system_state();
-
-  // Generate features.
-  float features[FEATURES];
-  int feature_index = 0;
-  features[feature_index++] = (float)part_ctx;
-  features[feature_index++] = (float)get_unsigned_bits(pb_source_variance);
-
-  const int rdcost = (int)AOMMIN(INT_MAX, best_rd);
-  int sub_block_rdcost[8] = { 0 };
-  int rd_index = 0;
-  for (int i = 0; i < 2; ++i) {
-    if (horz_rd[i] > 0 && horz_rd[i] < 1000000000)
-      sub_block_rdcost[rd_index] = (int)horz_rd[i];
-    ++rd_index;
-  }
-  for (int i = 0; i < 2; ++i) {
-    if (vert_rd[i] > 0 && vert_rd[i] < 1000000000)
-      sub_block_rdcost[rd_index] = (int)vert_rd[i];
-    ++rd_index;
-  }
-  for (int i = 0; i < 4; ++i) {
-    if (split_rd[i] > 0 && split_rd[i] < 1000000000)
-      sub_block_rdcost[rd_index] = (int)split_rd[i];
-    ++rd_index;
-  }
-  for (int i = 0; i < 8; ++i) {
-    // Ratio between the sub-block RD and the whole-block RD.
-    float rd_ratio = 1.0f;
-    if (sub_block_rdcost[i] > 0 && sub_block_rdcost[i] < rdcost)
-      rd_ratio = (float)sub_block_rdcost[i] / (float)rdcost;
-    features[feature_index++] = rd_ratio;
-  }
-
-  // Get variance of the 1:4 and 4:1 sub-blocks.
-  unsigned int horz_4_source_var[4] = { 0 };
-  unsigned int vert_4_source_var[4] = { 0 };
-  {
-    BLOCK_SIZE horz_4_bs = get_partition_subsize(bsize, PARTITION_HORZ_4);
-    BLOCK_SIZE vert_4_bs = get_partition_subsize(bsize, PARTITION_VERT_4);
-    av1_setup_src_planes(x, cpi->source, mi_row, mi_col,
-                         av1_num_planes(&cpi->common), bsize);
-    const int src_stride = x->plane[0].src.stride;
-    const uint8_t *src = x->plane[0].src.buf;
-    const MACROBLOCKD *const xd = &x->e_mbd;
-    for (int i = 0; i < 4; ++i) {
-      const uint8_t *horz_src =
-          src + i * block_size_high[horz_4_bs] * src_stride;
-      const uint8_t *vert_src = src + i * block_size_wide[vert_4_bs];
-      unsigned int horz_var, vert_var, sse;
-      if (is_cur_buf_hbd(xd)) {
-        switch (xd->bd) {
-          case 10:
-            horz_var = cpi->fn_ptr[horz_4_bs].vf(
-                horz_src, src_stride, CONVERT_TO_BYTEPTR(AV1_HIGH_VAR_OFFS_10),
-                0, &sse);
-            vert_var = cpi->fn_ptr[vert_4_bs].vf(
-                vert_src, src_stride, CONVERT_TO_BYTEPTR(AV1_HIGH_VAR_OFFS_10),
-                0, &sse);
-            break;
-          case 12:
-            horz_var = cpi->fn_ptr[horz_4_bs].vf(
-                horz_src, src_stride, CONVERT_TO_BYTEPTR(AV1_HIGH_VAR_OFFS_12),
-                0, &sse);
-            vert_var = cpi->fn_ptr[vert_4_bs].vf(
-                vert_src, src_stride, CONVERT_TO_BYTEPTR(AV1_HIGH_VAR_OFFS_12),
-                0, &sse);
-            break;
-          case 8:
-          default:
-            horz_var = cpi->fn_ptr[horz_4_bs].vf(
-                horz_src, src_stride, CONVERT_TO_BYTEPTR(AV1_HIGH_VAR_OFFS_8),
-                0, &sse);
-            vert_var = cpi->fn_ptr[vert_4_bs].vf(
-                vert_src, src_stride, CONVERT_TO_BYTEPTR(AV1_HIGH_VAR_OFFS_8),
-                0, &sse);
-            break;
-        }
-        horz_4_source_var[i] =
-            ROUND_POWER_OF_TWO(horz_var, num_pels_log2_lookup[horz_4_bs]);
-        vert_4_source_var[i] =
-            ROUND_POWER_OF_TWO(vert_var, num_pels_log2_lookup[vert_4_bs]);
-      } else {
-        horz_var = cpi->fn_ptr[horz_4_bs].vf(horz_src, src_stride, AV1_VAR_OFFS,
-                                             0, &sse);
-        vert_var = cpi->fn_ptr[vert_4_bs].vf(vert_src, src_stride, AV1_VAR_OFFS,
-                                             0, &sse);
-        horz_4_source_var[i] =
-            ROUND_POWER_OF_TWO(horz_var, num_pels_log2_lookup[horz_4_bs]);
-        vert_4_source_var[i] =
-            ROUND_POWER_OF_TWO(vert_var, num_pels_log2_lookup[vert_4_bs]);
-      }
-    }
-  }
-
-  const float denom = (float)(pb_source_variance + 1);
-  const float low_b = 0.1f;
-  const float high_b = 10.0f;
-  for (int i = 0; i < 4; ++i) {
-    // Ratio between the 4:1 sub-block variance and the whole-block variance.
-    float var_ratio = (float)(horz_4_source_var[i] + 1) / denom;
-    if (var_ratio < low_b) var_ratio = low_b;
-    if (var_ratio > high_b) var_ratio = high_b;
-    features[feature_index++] = var_ratio;
-  }
-  for (int i = 0; i < 4; ++i) {
-    // Ratio between the 1:4 sub-block RD and the whole-block RD.
-    float var_ratio = (float)(vert_4_source_var[i] + 1) / denom;
-    if (var_ratio < low_b) var_ratio = low_b;
-    if (var_ratio > high_b) var_ratio = high_b;
-    features[feature_index++] = var_ratio;
-  }
-  assert(feature_index == FEATURES);
-
-  // Calculate scores using the NN model.
-  float score[LABELS] = { 0.0f };
-  av1_nn_predict(features, nn_config, score);
-  aom_clear_system_state();
-  int int_score[LABELS];
-  int max_score = -1000;
-  for (int i = 0; i < LABELS; ++i) {
-    int_score[i] = (int)(100 * score[i]);
-    max_score = AOMMAX(int_score[i], max_score);
-  }
-
-  // Make decisions based on the model scores.
-  int thresh = max_score;
-  switch (bsize) {
-    case BLOCK_16X16: thresh -= 500; break;
-    case BLOCK_32X32: thresh -= 500; break;
-    case BLOCK_64X64: thresh -= 200; break;
-    default: break;
-  }
-  *partition_horz4_allowed = 0;
-  *partition_vert4_allowed = 0;
-  for (int i = 0; i < LABELS; ++i) {
-    if (int_score[i] >= thresh) {
-      if ((i >> 0) & 1) *partition_horz4_allowed = 1;
-      if ((i >> 1) & 1) *partition_vert4_allowed = 1;
-    }
-  }
-}
-#undef FEATURES
-#undef LABELS
-
-#define FEATURES 4
-// ML-based partition search breakout.
-static int ml_predict_breakout(const AV1_COMP *const cpi, BLOCK_SIZE bsize,
-                               const MACROBLOCK *const x,
-                               const RD_STATS *const rd_stats,
-                               unsigned int pb_source_variance) {
-  const NN_CONFIG *nn_config = NULL;
-  int thresh = 0;
-  switch (bsize) {
-    case BLOCK_8X8:
-      nn_config = &av1_partition_breakout_nnconfig_8;
-      thresh = cpi->sf.ml_partition_search_breakout_thresh[0];
-      break;
-    case BLOCK_16X16:
-      nn_config = &av1_partition_breakout_nnconfig_16;
-      thresh = cpi->sf.ml_partition_search_breakout_thresh[1];
-      break;
-    case BLOCK_32X32:
-      nn_config = &av1_partition_breakout_nnconfig_32;
-      thresh = cpi->sf.ml_partition_search_breakout_thresh[2];
-      break;
-    case BLOCK_64X64:
-      nn_config = &av1_partition_breakout_nnconfig_64;
-      thresh = cpi->sf.ml_partition_search_breakout_thresh[3];
-      break;
-    case BLOCK_128X128:
-      nn_config = &av1_partition_breakout_nnconfig_128;
-      thresh = cpi->sf.ml_partition_search_breakout_thresh[4];
-      break;
-    default: assert(0 && "Unexpected bsize.");
-  }
-  if (!nn_config || thresh < 0) return 0;
-
-  // Generate feature values.
-  float features[FEATURES];
-  int feature_index = 0;
-  aom_clear_system_state();
-
-  const int num_pels_log2 = num_pels_log2_lookup[bsize];
-  float rate_f = (float)AOMMIN(rd_stats->rate, INT_MAX);
-  rate_f = ((float)x->rdmult / 128.0f / 512.0f / (float)(1 << num_pels_log2)) *
-           rate_f;
-  features[feature_index++] = rate_f;
-
-  const float dist_f =
-      (float)(AOMMIN(rd_stats->dist, INT_MAX) >> num_pels_log2);
-  features[feature_index++] = dist_f;
-
-  features[feature_index++] = (float)pb_source_variance;
-
-  const int dc_q = (int)x->plane[0].dequant_QTX[0];
-  features[feature_index++] = (float)(dc_q * dc_q) / 256.0f;
-  assert(feature_index == FEATURES);
-
-  // Calculate score using the NN model.
-  float score = 0.0f;
-  av1_nn_predict(features, nn_config, &score);
-  aom_clear_system_state();
-
-  // Make decision.
-  return (int)(score * 100) >= thresh;
-}
-#undef FEATURES
-
 // Record the ref frames that have been selected by square partition blocks.
 static void update_picked_ref_frames_mask(MACROBLOCK *const x, int ref_type,
                                           BLOCK_SIZE bsize, int mib_size,
@@ -3139,7 +2653,7 @@
   if (bsize > BLOCK_4X4 && x->use_cb_search_range) {
     int split_score = 0;
     int none_score = 0;
-    const int score_valid = ml_prune_2pass_split_partition(
+    const int score_valid = av1_ml_prune_2pass_split_partition(
         &pc_tree->pc_tree_stats, bsize, &split_score, &none_score);
     if (score_valid) {
       {
@@ -3350,8 +2864,8 @@
               bsize <= cpi->sf.use_square_partition_only_threshold &&
               bsize > BLOCK_4X4 && xd->bd == 8;
           if (use_ml_based_breakout) {
-            if (ml_predict_breakout(cpi, bsize, x, &this_rdc,
-                                    pb_source_variance)) {
+            if (av1_ml_predict_breakout(cpi, bsize, x, &this_rdc,
+                                        pb_source_variance)) {
               do_square_split = 0;
               do_rectangular_split = 0;
             }
@@ -3480,8 +2994,8 @@
       (partition_horz_allowed || partition_vert_allowed) &&
       !(prune_horz || prune_vert) && !terminate_partition_search) {
     av1_setup_src_planes(x, cpi->source, mi_row, mi_col, num_planes, bsize);
-    ml_prune_rect_partition(cpi, x, bsize, best_rdc.rdcost, cur_none_rd,
-                            split_rd, &prune_horz, &prune_vert);
+    av1_ml_prune_rect_partition(cpi, x, bsize, best_rdc.rdcost, cur_none_rd,
+                                split_rd, &prune_horz, &prune_vert);
   }
 
   // PARTITION_HORZ
@@ -3774,11 +3288,11 @@
     // TODO(huisu@google.com): x->source_variance may not be the current
     // block's variance. The correct one to use is pb_source_variance. Need to
     // re-train the model to fix it.
-    ml_prune_ab_partition(bsize, pc_tree->partitioning,
-                          get_unsigned_bits(x->source_variance),
-                          best_rdc.rdcost, horz_rd, vert_rd, split_rd,
-                          &horza_partition_allowed, &horzb_partition_allowed,
-                          &verta_partition_allowed, &vertb_partition_allowed);
+    av1_ml_prune_ab_partition(
+        bsize, pc_tree->partitioning, get_unsigned_bits(x->source_variance),
+        best_rdc.rdcost, horz_rd, vert_rd, split_rd, &horza_partition_allowed,
+        &horzb_partition_allowed, &verta_partition_allowed,
+        &vertb_partition_allowed);
   }
 
   horza_partition_allowed &= cpi->oxcf.enable_ab_partitions;
@@ -3980,10 +3494,10 @@
   }
   if (cpi->sf.ml_prune_4_partition && partition4_allowed &&
       partition_horz_allowed && partition_vert_allowed) {
-    ml_prune_4_partition(cpi, x, bsize, pc_tree->partitioning, best_rdc.rdcost,
-                         horz_rd, vert_rd, split_rd, &partition_horz4_allowed,
-                         &partition_vert4_allowed, pb_source_variance, mi_row,
-                         mi_col);
+    av1_ml_prune_4_partition(cpi, x, bsize, pc_tree->partitioning,
+                             best_rdc.rdcost, horz_rd, vert_rd, split_rd,
+                             &partition_horz4_allowed, &partition_vert4_allowed,
+                             pb_source_variance, mi_row, mi_col);
   }
 
 #if CONFIG_DIST_8X8
diff --git a/av1/encoder/partition_strategy.c b/av1/encoder/partition_strategy.c
index e979a49..0a96914 100644
--- a/av1/encoder/partition_strategy.c
+++ b/av1/encoder/partition_strategy.c
@@ -949,3 +949,454 @@
   if (score < thresh) *terminate_partition_search = 1;
 }
 #undef FEATURES
+
+#define FEATURE_SIZE 19
+int av1_ml_prune_2pass_split_partition(const PC_TREE_STATS *pc_tree_stats,
+                                       BLOCK_SIZE bsize, int *split_score,
+                                       int *none_score) {
+  if (!pc_tree_stats->valid) return 0;
+  const float *split_weights = NULL;
+  const float *none_weights = NULL;
+  switch (bsize) {
+    case BLOCK_4X4: break;
+    case BLOCK_8X8:
+      split_weights = av1_2pass_split_partition_weights_8;
+      none_weights = av1_2pass_none_partition_weights_8;
+      break;
+    case BLOCK_16X16:
+      split_weights = av1_2pass_split_partition_weights_16;
+      none_weights = av1_2pass_none_partition_weights_16;
+      break;
+    case BLOCK_32X32:
+      split_weights = av1_2pass_split_partition_weights_32;
+      none_weights = av1_2pass_none_partition_weights_32;
+      break;
+    case BLOCK_64X64:
+      split_weights = av1_2pass_split_partition_weights_64;
+      none_weights = av1_2pass_none_partition_weights_64;
+      break;
+    case BLOCK_128X128:
+      split_weights = av1_2pass_split_partition_weights_128;
+      none_weights = av1_2pass_none_partition_weights_128;
+      break;
+    default: assert(0 && "Unexpected bsize.");
+  }
+  if (!split_weights || !none_weights) return 0;
+
+  aom_clear_system_state();
+
+  float features[FEATURE_SIZE];
+  int feature_index = 0;
+  features[feature_index++] = (float)pc_tree_stats->split;
+  features[feature_index++] = (float)pc_tree_stats->skip;
+  const int rdcost = (int)AOMMIN(INT_MAX, pc_tree_stats->rdcost);
+  const int rd_valid = rdcost > 0 && rdcost < 1000000000;
+  features[feature_index++] = (float)rd_valid;
+  for (int i = 0; i < 4; ++i) {
+    features[feature_index++] = (float)pc_tree_stats->sub_block_split[i];
+    features[feature_index++] = (float)pc_tree_stats->sub_block_skip[i];
+    const int sub_rdcost =
+        (int)AOMMIN(INT_MAX, pc_tree_stats->sub_block_rdcost[i]);
+    const int sub_rd_valid = sub_rdcost > 0 && sub_rdcost < 1000000000;
+    features[feature_index++] = (float)sub_rd_valid;
+    // Ratio between the sub-block RD and the whole-block RD.
+    float rd_ratio = 1.0f;
+    if (rd_valid && sub_rd_valid && sub_rdcost < rdcost)
+      rd_ratio = (float)sub_rdcost / (float)rdcost;
+    features[feature_index++] = rd_ratio;
+  }
+  assert(feature_index == FEATURE_SIZE);
+
+  float score_1 = split_weights[FEATURE_SIZE];
+  float score_2 = none_weights[FEATURE_SIZE];
+  for (int i = 0; i < FEATURE_SIZE; ++i) {
+    score_1 += features[i] * split_weights[i];
+    score_2 += features[i] * none_weights[i];
+  }
+  *split_score = (int)(score_1 * 100);
+  *none_score = (int)(score_2 * 100);
+  return 1;
+}
+#undef FEATURE_SIZE
+
+void av1_ml_prune_rect_partition(const AV1_COMP *const cpi,
+                                 const MACROBLOCK *const x, BLOCK_SIZE bsize,
+                                 int64_t best_rd, int64_t none_rd,
+                                 int64_t *split_rd, int *const dst_prune_horz,
+                                 int *const dst_prune_vert) {
+  if (bsize < BLOCK_8X8 || best_rd >= 1000000000) return;
+  best_rd = AOMMAX(best_rd, 1);
+  const NN_CONFIG *nn_config = NULL;
+  const float prob_thresholds[5] = { 0.01f, 0.01f, 0.004f, 0.002f, 0.002f };
+  float cur_thresh = 0.0f;
+  switch (bsize) {
+    case BLOCK_8X8:
+      nn_config = &av1_rect_partition_nnconfig_8;
+      cur_thresh = prob_thresholds[0];
+      break;
+    case BLOCK_16X16:
+      nn_config = &av1_rect_partition_nnconfig_16;
+      cur_thresh = prob_thresholds[1];
+      break;
+    case BLOCK_32X32:
+      nn_config = &av1_rect_partition_nnconfig_32;
+      cur_thresh = prob_thresholds[2];
+      break;
+    case BLOCK_64X64:
+      nn_config = &av1_rect_partition_nnconfig_64;
+      cur_thresh = prob_thresholds[3];
+      break;
+    case BLOCK_128X128:
+      nn_config = &av1_rect_partition_nnconfig_128;
+      cur_thresh = prob_thresholds[4];
+      break;
+    default: assert(0 && "Unexpected bsize.");
+  }
+  if (!nn_config) return;
+  aom_clear_system_state();
+
+  // 1. Compute input features
+  float features[9];
+
+  // RD cost ratios
+  for (int i = 0; i < 5; i++) features[i] = 1.0f;
+  if (none_rd > 0 && none_rd < 1000000000)
+    features[0] = (float)none_rd / (float)best_rd;
+  for (int i = 0; i < 4; i++) {
+    if (split_rd[i] > 0 && split_rd[i] < 1000000000)
+      features[1 + i] = (float)split_rd[i] / (float)best_rd;
+  }
+
+  // Variance ratios
+  const MACROBLOCKD *const xd = &x->e_mbd;
+  int whole_block_variance;
+  if (is_cur_buf_hbd(xd)) {
+    whole_block_variance = av1_high_get_sby_perpixel_variance(
+        cpi, &x->plane[0].src, bsize, xd->bd);
+  } else {
+    whole_block_variance =
+        av1_get_sby_perpixel_variance(cpi, &x->plane[0].src, bsize);
+  }
+  whole_block_variance = AOMMAX(whole_block_variance, 1);
+
+  int split_variance[4];
+  const BLOCK_SIZE subsize = get_partition_subsize(bsize, PARTITION_SPLIT);
+  struct buf_2d buf;
+  buf.stride = x->plane[0].src.stride;
+  const int bw = block_size_wide[bsize];
+  for (int i = 0; i < 4; ++i) {
+    const int x_idx = (i & 1) * bw / 2;
+    const int y_idx = (i >> 1) * bw / 2;
+    buf.buf = x->plane[0].src.buf + x_idx + y_idx * buf.stride;
+    if (is_cur_buf_hbd(xd)) {
+      split_variance[i] =
+          av1_high_get_sby_perpixel_variance(cpi, &buf, subsize, xd->bd);
+    } else {
+      split_variance[i] = av1_get_sby_perpixel_variance(cpi, &buf, subsize);
+    }
+  }
+
+  for (int i = 0; i < 4; i++)
+    features[5 + i] = (float)split_variance[i] / (float)whole_block_variance;
+
+  // 2. Do the prediction and prune 0-2 partitions based on their probabilities
+  float raw_scores[3] = { 0.0f };
+  av1_nn_predict(features, nn_config, raw_scores);
+  aom_clear_system_state();
+  float probs[3] = { 0.0f };
+  av1_nn_softmax(raw_scores, probs, 3);
+
+  // probs[0] is the probability that both rectangular partitions are worse
+  // than the current best_rd.
+  if (probs[1] <= cur_thresh) (*dst_prune_horz) = 1;
+  if (probs[2] <= cur_thresh) (*dst_prune_vert) = 1;
+}
+
+// Use an ML model to predict whether horz_a, horz_b, vert_a, and vert_b should
+// be considered.
+void av1_ml_prune_ab_partition(BLOCK_SIZE bsize, int part_ctx, int var_ctx,
+                               int64_t best_rd, int64_t horz_rd[2],
+                               int64_t vert_rd[2], int64_t split_rd[4],
+                               int *const horza_partition_allowed,
+                               int *const horzb_partition_allowed,
+                               int *const verta_partition_allowed,
+                               int *const vertb_partition_allowed) {
+  if (bsize < BLOCK_8X8 || best_rd >= 1000000000) return;
+  const NN_CONFIG *nn_config = NULL;
+  switch (bsize) {
+    case BLOCK_8X8: nn_config = NULL; break;
+    case BLOCK_16X16: nn_config = &av1_ab_partition_nnconfig_16; break;
+    case BLOCK_32X32: nn_config = &av1_ab_partition_nnconfig_32; break;
+    case BLOCK_64X64: nn_config = &av1_ab_partition_nnconfig_64; break;
+    case BLOCK_128X128: nn_config = &av1_ab_partition_nnconfig_128; break;
+    default: assert(0 && "Unexpected bsize.");
+  }
+  if (!nn_config) return;
+
+  aom_clear_system_state();
+
+  // Generate features.
+  float features[10];
+  int feature_index = 0;
+  features[feature_index++] = (float)part_ctx;
+  features[feature_index++] = (float)var_ctx;
+  const int rdcost = (int)AOMMIN(INT_MAX, best_rd);
+  int sub_block_rdcost[8] = { 0 };
+  int rd_index = 0;
+  for (int i = 0; i < 2; ++i) {
+    if (horz_rd[i] > 0 && horz_rd[i] < 1000000000)
+      sub_block_rdcost[rd_index] = (int)horz_rd[i];
+    ++rd_index;
+  }
+  for (int i = 0; i < 2; ++i) {
+    if (vert_rd[i] > 0 && vert_rd[i] < 1000000000)
+      sub_block_rdcost[rd_index] = (int)vert_rd[i];
+    ++rd_index;
+  }
+  for (int i = 0; i < 4; ++i) {
+    if (split_rd[i] > 0 && split_rd[i] < 1000000000)
+      sub_block_rdcost[rd_index] = (int)split_rd[i];
+    ++rd_index;
+  }
+  for (int i = 0; i < 8; ++i) {
+    // Ratio between the sub-block RD and the whole-block RD.
+    float rd_ratio = 1.0f;
+    if (sub_block_rdcost[i] > 0 && sub_block_rdcost[i] < rdcost)
+      rd_ratio = (float)sub_block_rdcost[i] / (float)rdcost;
+    features[feature_index++] = rd_ratio;
+  }
+  assert(feature_index == 10);
+
+  // Calculate scores using the NN model.
+  float score[16] = { 0.0f };
+  av1_nn_predict(features, nn_config, score);
+  aom_clear_system_state();
+  int int_score[16];
+  int max_score = -1000;
+  for (int i = 0; i < 16; ++i) {
+    int_score[i] = (int)(100 * score[i]);
+    max_score = AOMMAX(int_score[i], max_score);
+  }
+
+  // Make decisions based on the model scores.
+  int thresh = max_score;
+  switch (bsize) {
+    case BLOCK_16X16: thresh -= 150; break;
+    case BLOCK_32X32: thresh -= 100; break;
+    default: break;
+  }
+  *horza_partition_allowed = 0;
+  *horzb_partition_allowed = 0;
+  *verta_partition_allowed = 0;
+  *vertb_partition_allowed = 0;
+  for (int i = 0; i < 16; ++i) {
+    if (int_score[i] >= thresh) {
+      if ((i >> 0) & 1) *horza_partition_allowed = 1;
+      if ((i >> 1) & 1) *horzb_partition_allowed = 1;
+      if ((i >> 2) & 1) *verta_partition_allowed = 1;
+      if ((i >> 3) & 1) *vertb_partition_allowed = 1;
+    }
+  }
+}
+
+#define FEATURES 18
+#define LABELS 4
+// Use an ML model to predict if horz4 and vert4 should be considered.
+//
+// The model is fed FEATURES values, in this order: part_ctx,
+// get_unsigned_bits(pb_source_variance), eight sub-block RD / best_rd ratios
+// (2 horz, 2 vert, 4 split) and eight sub-block / whole-block source variance
+// ratios (4 for PARTITION_HORZ_4, 4 for PARTITION_VERT_4). It produces LABELS
+// scores; bit 0 of a label index stands for horz4 and bit 1 for vert4.
+//
+// *partition_horz4_allowed and *partition_vert4_allowed are left untouched
+// when best_rd >= 1000000000 or when no model exists for bsize; otherwise both
+// are overwritten with 0 or 1.
+void av1_ml_prune_4_partition(const AV1_COMP *const cpi, MACROBLOCK *const x,
+                              BLOCK_SIZE bsize, int part_ctx, int64_t best_rd,
+                              int64_t horz_rd[2], int64_t vert_rd[2],
+                              int64_t split_rd[4],
+                              int *const partition_horz4_allowed,
+                              int *const partition_vert4_allowed,
+                              unsigned int pb_source_variance, int mi_row,
+                              int mi_col) {
+  if (best_rd >= 1000000000) return;
+
+  // Select the model together with the margin (in units of 1/100 of a score)
+  // below the best label score within which a label is still accepted.
+  const NN_CONFIG *nn_config = NULL;
+  int score_margin = 0;
+  switch (bsize) {
+    case BLOCK_16X16:
+      nn_config = &av1_4_partition_nnconfig_16;
+      score_margin = 500;
+      break;
+    case BLOCK_32X32:
+      nn_config = &av1_4_partition_nnconfig_32;
+      score_margin = 500;
+      break;
+    case BLOCK_64X64:
+      nn_config = &av1_4_partition_nnconfig_64;
+      score_margin = 200;
+      break;
+    default: assert(0 && "Unexpected bsize.");
+  }
+  if (!nn_config) return;
+
+  aom_clear_system_state();
+
+  // Context features.
+  float features[FEATURES];
+  int feature_index = 0;
+  features[feature_index++] = (float)part_ctx;
+  features[feature_index++] = (float)get_unsigned_bits(pb_source_variance);
+
+  // Gather the sub-block RD costs (horz, vert, split) into one array. A slot
+  // stays 0 when its cost is not inside the open range (0, 1000000000).
+  const int rdcost = (int)AOMMIN(INT_MAX, best_rd);
+  const int64_t *const rd_lists[3] = { horz_rd, vert_rd, split_rd };
+  const int rd_list_len[3] = { 2, 2, 4 };
+  int sub_block_rdcost[8] = { 0 };
+  int slot = 0;
+  for (int list = 0; list < 3; ++list) {
+    for (int k = 0; k < rd_list_len[list]; ++k, ++slot) {
+      const int64_t rd = rd_lists[list][k];
+      if (rd > 0 && rd < 1000000000) sub_block_rdcost[slot] = (int)rd;
+    }
+  }
+
+  // Ratio between the sub-block RD and the whole-block RD; falls back to 1
+  // when the sub-block cost is unset or not below the whole-block cost.
+  for (int k = 0; k < 8; ++k) {
+    const int sub_rd = sub_block_rdcost[k];
+    if (sub_rd > 0 && sub_rd < rdcost) {
+      features[feature_index++] = (float)sub_rd / (float)rdcost;
+    } else {
+      features[feature_index++] = 1.0f;
+    }
+  }
+
+  // Get variance of the 1:4 and 4:1 sub-blocks.
+  unsigned int horz_4_source_var[4] = { 0 };
+  unsigned int vert_4_source_var[4] = { 0 };
+  {
+    const BLOCK_SIZE horz_4_bs = get_partition_subsize(bsize, PARTITION_HORZ_4);
+    const BLOCK_SIZE vert_4_bs = get_partition_subsize(bsize, PARTITION_VERT_4);
+    av1_setup_src_planes(x, cpi->source, mi_row, mi_col,
+                         av1_num_planes(&cpi->common), bsize);
+    const MACROBLOCKD *const xd = &x->e_mbd;
+    const int src_stride = x->plane[0].src.stride;
+    uint8_t *const src_base = x->plane[0].src.buf;
+
+    struct buf_2d horz_4_src, vert_4_src;
+    horz_4_src.stride = src_stride;
+    vert_4_src.stride = src_stride;
+
+    for (int k = 0; k < 4; ++k) {
+      // The k-th HORZ_4 strip starts k strip-heights down, the k-th VERT_4
+      // strip starts k strip-widths to the right.
+      horz_4_src.buf = src_base + k * block_size_high[horz_4_bs] * src_stride;
+      vert_4_src.buf = src_base + k * block_size_wide[vert_4_bs];
+
+      if (is_cur_buf_hbd(xd)) {
+        horz_4_source_var[k] = av1_high_get_sby_perpixel_variance(
+            cpi, &horz_4_src, horz_4_bs, xd->bd);
+        vert_4_source_var[k] = av1_high_get_sby_perpixel_variance(
+            cpi, &vert_4_src, vert_4_bs, xd->bd);
+      } else {
+        horz_4_source_var[k] =
+            av1_get_sby_perpixel_variance(cpi, &horz_4_src, horz_4_bs);
+        vert_4_source_var[k] =
+            av1_get_sby_perpixel_variance(cpi, &vert_4_src, vert_4_bs);
+      }
+    }
+  }
+
+  // Ratio between each sub-block variance and the whole-block variance (both
+  // offset by 1), clamped to [low_b, high_b]. HORZ_4 ratios come first, then
+  // VERT_4 ratios.
+  const float denom = (float)(pb_source_variance + 1);
+  const float low_b = 0.1f;
+  const float high_b = 10.0f;
+  const unsigned int *const var_lists[2] = { horz_4_source_var,
+                                             vert_4_source_var };
+  for (int list = 0; list < 2; ++list) {
+    for (int k = 0; k < 4; ++k) {
+      float var_ratio = (float)(var_lists[list][k] + 1) / denom;
+      if (var_ratio < low_b) {
+        var_ratio = low_b;
+      } else if (var_ratio > high_b) {
+        var_ratio = high_b;
+      }
+      features[feature_index++] = var_ratio;
+    }
+  }
+  assert(feature_index == FEATURES);
+
+  // Calculate scores using the NN model and scale them to integers.
+  float score[LABELS] = { 0.0f };
+  av1_nn_predict(features, nn_config, score);
+  aom_clear_system_state();
+  int int_score[LABELS];
+  int max_score = -1000;
+  for (int label = 0; label < LABELS; ++label) {
+    int_score[label] = (int)(100 * score[label]);
+    if (int_score[label] > max_score) max_score = int_score[label];
+  }
+
+  // Make decisions based on the model scores: a partition type is allowed when
+  // any label containing its bit scores within score_margin of the best label.
+  const int thresh = max_score - score_margin;
+  *partition_horz4_allowed = 0;
+  *partition_vert4_allowed = 0;
+  for (int label = 0; label < LABELS; ++label) {
+    if (int_score[label] < thresh) continue;
+    if (label & 1) *partition_horz4_allowed = 1;
+    if (label & 2) *partition_vert4_allowed = 1;
+  }
+}
+#undef FEATURES
+#undef LABELS
+
+#define FEATURES 4
+// ML-based partition search breakout.
+//
+// Feeds four features to the NN model selected by bsize: the rate scaled by
+// x->rdmult and the block's num_pels_log2_lookup entry, the distortion shifted
+// right by that same entry, pb_source_variance, and the squared DC dequant
+// step of plane 0 divided by 256.
+//
+// Returns 1 when 100 times the model score (truncated to int) reaches the
+// speed-feature threshold for bsize, 0 otherwise. Always returns 0 when bsize
+// has no model or when the threshold is negative.
+int av1_ml_predict_breakout(const AV1_COMP *const cpi, BLOCK_SIZE bsize,
+                            const MACROBLOCK *const x,
+                            const RD_STATS *const rd_stats,
+                            unsigned int pb_source_variance) {
+  // Pick the model and the slot of its threshold in
+  // cpi->sf.ml_partition_search_breakout_thresh[]. The slot stays -1 for
+  // block sizes without a model.
+  const NN_CONFIG *nn_config = NULL;
+  int thresh_slot = -1;
+  switch (bsize) {
+    case BLOCK_8X8:
+      nn_config = &av1_partition_breakout_nnconfig_8;
+      thresh_slot = 0;
+      break;
+    case BLOCK_16X16:
+      nn_config = &av1_partition_breakout_nnconfig_16;
+      thresh_slot = 1;
+      break;
+    case BLOCK_32X32:
+      nn_config = &av1_partition_breakout_nnconfig_32;
+      thresh_slot = 2;
+      break;
+    case BLOCK_64X64:
+      nn_config = &av1_partition_breakout_nnconfig_64;
+      thresh_slot = 3;
+      break;
+    case BLOCK_128X128:
+      nn_config = &av1_partition_breakout_nnconfig_128;
+      thresh_slot = 4;
+      break;
+    default: assert(0 && "Unexpected bsize.");
+  }
+  int thresh = 0;
+  if (thresh_slot >= 0) {
+    thresh = cpi->sf.ml_partition_search_breakout_thresh[thresh_slot];
+  }
+  if (!nn_config || thresh < 0) return 0;
+
+  // Generate feature values.
+  float features[FEATURES];
+  int feature_index = 0;
+  aom_clear_system_state();
+
+  const int num_pels_log2 = num_pels_log2_lookup[bsize];
+
+  // Feature 0: scaled rate.
+  features[feature_index++] =
+      ((float)x->rdmult / 128.0f / 512.0f / (float)(1 << num_pels_log2)) *
+      (float)AOMMIN(rd_stats->rate, INT_MAX);
+
+  // Feature 1: distortion, capped at INT_MAX, shifted by num_pels_log2.
+  features[feature_index++] =
+      (float)(AOMMIN(rd_stats->dist, INT_MAX) >> num_pels_log2);
+
+  // Feature 2: source variance as passed in by the caller.
+  features[feature_index++] = (float)pb_source_variance;
+
+  // Feature 3: squared DC dequant value of plane 0, divided by 256.
+  const int dc_q = (int)x->plane[0].dequant_QTX[0];
+  features[feature_index++] = (float)(dc_q * dc_q) / 256.0f;
+  assert(feature_index == FEATURES);
+
+  // Calculate score using the NN model.
+  float score = 0.0f;
+  av1_nn_predict(features, nn_config, &score);
+  aom_clear_system_state();
+
+  // Make decision.
+  const int int_score = (int)(score * 100);
+  return int_score >= thresh;
+}
+#undef FEATURES
diff --git a/av1/encoder/partition_strategy.h b/av1/encoder/partition_strategy.h
index 064f530..fbe832d 100644
--- a/av1/encoder/partition_strategy.h
+++ b/av1/encoder/partition_strategy.h
@@ -85,6 +85,61 @@
 BLOCK_SIZE av1_predict_max_partition(AV1_COMP *const cpi, MACROBLOCK *const x,
                                      const float *features);
 
+// Attempts an early termination after PARTITION_SPLIT.
+void av1_ml_early_term_after_split(AV1_COMP *const cpi, MACROBLOCK *const x,
+                                   PC_TREE *const pc_tree, BLOCK_SIZE bsize,
+                                   int64_t best_rd, int64_t part_none_rd,
+                                   int64_t part_split_rd,
+                                   int64_t *split_block_rd, int mi_row,
+                                   int mi_col,
+                                   int *const terminate_partition_search);
+
+// Use data from the first partition pass to emit split_score and none_score.
+// Returns 0 if the first-pass data is not valid, 1 otherwise.
+// split_score indicates confidence of picking split partition;
+// none_score indicates confidence of picking none partition;
+int av1_ml_prune_2pass_split_partition(const PC_TREE_STATS *pc_tree_stats,
+                                       BLOCK_SIZE bsize, int *split_score,
+                                       int *none_score);
+
+// Use the rdcost ratio and source var ratio to prune PARTITION_HORZ and
+// PARTITION_VERT.
+// TODO(chiyotsai@google.com): Currently this model does not use q value and has
+// no information about rectangular partitions. Preliminary experiments suggest
+// that we can get better performance by adding in q_index and rectangular
+// sse/var from SMS. We should retrain and tune this model later.
+void av1_ml_prune_rect_partition(const AV1_COMP *const cpi,
+                                 const MACROBLOCK *const x, BLOCK_SIZE bsize,
+                                 int64_t best_rd, int64_t none_rd,
+                                 int64_t *split_rd, int *const dst_prune_horz,
+                                 int *const dst_prune_vert);
+
+// Use an ML model to predict if horz_a, horz_b, vert_a, and vert_b should be
+// considered.
+void av1_ml_prune_ab_partition(BLOCK_SIZE bsize, int part_ctx, int var_ctx,
+                               int64_t best_rd, int64_t horz_rd[2],
+                               int64_t vert_rd[2], int64_t split_rd[4],
+                               int *const horza_partition_allowed,
+                               int *const horzb_partition_allowed,
+                               int *const verta_partition_allowed,
+                               int *const vertb_partition_allowed);
+
+// Use an ML model to predict if horz4 and vert4 should be considered.
+void av1_ml_prune_4_partition(const AV1_COMP *const cpi, MACROBLOCK *const x,
+                              BLOCK_SIZE bsize, int part_ctx, int64_t best_rd,
+                              int64_t horz_rd[2], int64_t vert_rd[2],
+                              int64_t split_rd[4],
+                              int *const partition_horz4_allowed,
+                              int *const partition_vert4_allowed,
+                              unsigned int pb_source_variance, int mi_row,
+                              int mi_col);
+
+// ML-based partition search breakout after PARTITION_NONE
+int av1_ml_predict_breakout(const AV1_COMP *const cpi, BLOCK_SIZE bsize,
+                            const MACROBLOCK *const x,
+                            const RD_STATS *const rd_stats,
+                            unsigned int pb_source_variance);
+
 // A simplified version of set_offsets meant to be used for
 // simple_motion_search.
 static INLINE void set_offsets_for_motion_search(const AV1_COMP *const cpi,
@@ -168,12 +223,4 @@
              INTNL_OVERLAY_UPDATE;
 }
 
-void av1_ml_early_term_after_split(AV1_COMP *const cpi, MACROBLOCK *const x,
-                                   PC_TREE *const pc_tree, BLOCK_SIZE bsize,
-                                   int64_t best_rd, int64_t part_none_rd,
-                                   int64_t part_split_rd,
-                                   int64_t *split_block_rd, int mi_row,
-                                   int mi_col,
-                                   int *const terminate_partition_search);
-
 #endif  // AOM_AV1_ENCODER_PARTITION_STRATEGY_H_