External partition: collect features and pass to external model

Use the external model's decision to overwrite the current partition
decision.

Change-Id: I64b93165f14defe5d42be9f35cd189d6a6d979f8
diff --git a/av1/encoder/partition_search.c b/av1/encoder/partition_search.c
index d238809..ddd8995 100644
--- a/av1/encoder/partition_search.c
+++ b/av1/encoder/partition_search.c
@@ -3283,6 +3283,11 @@
                                         PartitionSearchState *part_search_state,
                                         RD_STATS *best_rdc,
                                         unsigned int *pb_source_variance) {
+  if (av1_ext_ml_model_decision_after_none(cpi, x, sms_tree, part_search_state,
+                                           *pb_source_variance)) {
+    return;
+  }
+
   const AV1_COMMON *const cm = &cpi->common;
   MACROBLOCKD *const xd = &x->e_mbd;
   PartitionBlkParams blk_params = part_search_state->part_blk_params;
@@ -3349,6 +3354,12 @@
     AV1_COMP *const cpi, MACROBLOCK *x, SIMPLE_MOTION_DATA_TREE *sms_tree,
     PartitionSearchState *part_search_state, RD_STATS *best_rdc,
     int64_t part_none_rd, int64_t part_split_rd) {
+  if (av1_ext_ml_model_decision_after_split(
+          cpi, x, sms_tree, part_search_state, best_rdc, part_none_rd,
+          part_split_rd, part_search_state->split_rd)) {
+    return;
+  }
+
   const AV1_COMMON *const cm = &cpi->common;
   PartitionBlkParams blk_params = part_search_state->part_blk_params;
   const int mi_row = blk_params.mi_row;
diff --git a/av1/encoder/partition_strategy.c b/av1/encoder/partition_strategy.c
index 1f67d9b..351bc13 100644
--- a/av1/encoder/partition_strategy.c
+++ b/av1/encoder/partition_strategy.c
@@ -35,6 +35,33 @@
     int mi_row, int mi_col, BLOCK_SIZE bsize, float *features,
     int features_to_get);
 
+static bool ext_ml_model_decision_before_none(
+    AV1_COMP *cpi, const float features_from_motion[FEATURE_SIZE_SMS_SPLIT],
+    int *partition_none_allowed, int *partition_horz_allowed,
+    int *partition_vert_allowed, int *do_rectangular_split,
+    int *do_square_split);
+
+static bool ext_ml_model_decision_before_none_part2(
+    AV1_COMP *cpi,
+    const float features_from_motion[FEATURE_SIZE_SMS_PRUNE_PART],
+    int *prune_horz, int *prune_vert);
+
+static bool ext_ml_model_decision_after_rect(
+    AV1_COMP *cpi, const MACROBLOCK *x, const PC_TREE *pc_tree,
+    BLOCK_SIZE bsize, int pb_source_variance, int64_t best_rdcost,
+    int64_t rect_part_rd[NUM_RECT_PARTS][SUB_PARTITIONS_RECT],
+    int64_t split_rd[SUB_PARTITIONS_SPLIT], int ext_partition_allowed,
+    int partition_horz_allowed, int partition_vert_allowed,
+    int *horza_partition_allowed, int *horzb_partition_allowed,
+    int *verta_partition_allowed, int *vertb_partition_allowed);
+
+static bool ext_ml_model_decision_after_part_ab(
+    AV1_COMP *const cpi, MACROBLOCK *const x, BLOCK_SIZE bsize, int part_ctx,
+    int64_t best_rd, int64_t rect_part_rd[NUM_RECT_PARTS][SUB_PARTITIONS_RECT],
+    int64_t split_rd[SUB_PARTITIONS_SPLIT], int *const partition_horz4_allowed,
+    int *const partition_vert4_allowed, unsigned int pb_source_variance,
+    int mi_row, int mi_col);
+
 static INLINE int convert_bsize_to_idx(BLOCK_SIZE bsize) {
   switch (bsize) {
     case BLOCK_128X128: return 0;
@@ -284,6 +311,15 @@
   simple_motion_search_prune_part_features(cpi, x, sms_tree, mi_row, mi_col,
                                            bsize, features,
                                            FEATURE_SMS_SPLIT_MODEL_FLAG);
+
+  // Note: the features are intentionally left unnormalized here, so that all
+  // features collected and passed to the external model stay consistent.
+  if (ext_ml_model_decision_before_none(
+          cpi, features, partition_none_allowed, partition_horz_allowed,
+          partition_vert_allowed, do_rectangular_split, do_square_split)) {
+    return;
+  }
+
   for (int idx = 0; idx < FEATURE_SIZE_SMS_SPLIT; idx++) {
     features[idx] = (features[idx] - ml_mean[idx]) / ml_std[idx];
   }
@@ -543,6 +579,19 @@
   simple_motion_search_prune_part_features(cpi, x, sms_tree, mi_row, mi_col,
                                            bsize, features,
                                            FEATURE_SMS_PRUNE_PART_FLAG);
+
+  // Note: the features are intentionally left unnormalized here, so that all
+  // features collected and passed to the external model stay consistent.
+  if (cpi->sf.part_sf.simple_motion_search_prune_rect &&
+      !frame_is_intra_only(cm) &&
+      (partition_horz_allowed || partition_vert_allowed) &&
+      bsize >= BLOCK_8X8 && !av1_superres_scaled(cm)) {
+    if (ext_ml_model_decision_before_none_part2(cpi, features, prune_horz,
+                                                prune_vert)) {
+      return;
+    }
+  }
+
   for (int f_idx = 0; f_idx < FEATURE_SIZE_SMS_PRUNE_PART; f_idx++) {
     features[f_idx] = (features[f_idx] - ml_mean[f_idx]) / ml_std[f_idx];
   }
@@ -1120,12 +1169,17 @@
 #define LABELS 4
 // Use a ML model to predict if horz4 and vert4 should be considered.
 void av1_ml_prune_4_partition(
-    const AV1_COMP *const cpi, MACROBLOCK *const x, BLOCK_SIZE bsize,
-    int part_ctx, int64_t best_rd,
-    int64_t rect_part_rd[NUM_RECT_PARTS][SUB_PARTITIONS_RECT],
+    AV1_COMP *const cpi, MACROBLOCK *const x, BLOCK_SIZE bsize, int part_ctx,
+    int64_t best_rd, int64_t rect_part_rd[NUM_RECT_PARTS][SUB_PARTITIONS_RECT],
     int64_t split_rd[SUB_PARTITIONS_SPLIT], int *const partition_horz4_allowed,
     int *const partition_vert4_allowed, unsigned int pb_source_variance,
     int mi_row, int mi_col) {
+  if (ext_ml_model_decision_after_part_ab(
+          cpi, x, bsize, part_ctx, best_rd, rect_part_rd, split_rd,
+          partition_horz4_allowed, partition_vert4_allowed, pb_source_variance,
+          mi_row, mi_col))
+    return;
+
   if (best_rd >= 1000000000) return;
   int64_t *horz_rd = rect_part_rd[HORZ];
   int64_t *vert_rd = rect_part_rd[VERT];
@@ -1502,7 +1556,7 @@
 }
 
 void av1_prune_ab_partitions(
-    const AV1_COMP *cpi, const MACROBLOCK *x, const PC_TREE *pc_tree,
+    AV1_COMP *cpi, const MACROBLOCK *x, const PC_TREE *pc_tree,
     BLOCK_SIZE bsize, int pb_source_variance, int64_t best_rdcost,
     int64_t rect_part_rd[NUM_RECT_PARTS][SUB_PARTITIONS_RECT],
     int64_t split_rd[SUB_PARTITIONS_SPLIT],
@@ -1510,6 +1564,14 @@
     int partition_horz_allowed, int partition_vert_allowed,
     int *horza_partition_allowed, int *horzb_partition_allowed,
     int *verta_partition_allowed, int *vertb_partition_allowed) {
+  if (ext_ml_model_decision_after_rect(
+          cpi, x, pc_tree, bsize, pb_source_variance, best_rdcost, rect_part_rd,
+          split_rd, ext_partition_allowed, partition_horz_allowed,
+          partition_vert_allowed, horza_partition_allowed,
+          horzb_partition_allowed, verta_partition_allowed,
+          vertb_partition_allowed))
+    return;
+
   int64_t *horz_rd = rect_part_rd[HORZ];
   int64_t *vert_rd = rect_part_rd[VERT];
   const PartitionCfg *const part_cfg = &cpi->oxcf.part_cfg;
@@ -1636,4 +1698,557 @@
   }
 }
 
+// Fills features->after_part_none for the external partition model. Call it
+// after the none partition is searched; it reads part_search_state->this_rdc.
+static void prepare_features_after_part_none(
+    AV1_COMP *const cpi, AV1_COMMON *const cm, MACROBLOCK *const x,
+    SIMPLE_MOTION_DATA_TREE *sms_tree, PartitionSearchState *part_search_state,
+    const unsigned int pb_source_variance,
+    aom_partition_features_t *const features) {
+  aom_clear_system_state();
+  PartitionBlkParams blk_params = part_search_state->part_blk_params;
+  RD_STATS *this_rdc = &part_search_state->this_rdc;
+  const BLOCK_SIZE bsize = blk_params.bsize;
+  const int bit_depth = cm->seq_params->bit_depth;
+
+  // 4 features defined in av1_ml_predict_breakout(): rate, dist, variance, q.
+  const int num_pels_log2 = num_pels_log2_lookup[bsize];
+  float rate_f = (float)AOMMIN(this_rdc->rate, INT_MAX);
+  rate_f = ((float)x->rdmult / 128.0f / 512.0f / (float)(1 << num_pels_log2)) *
+           rate_f;
+  const float dist_f =
+      (float)(AOMMIN(this_rdc->dist, INT_MAX) >> num_pels_log2);
+  const int dc_q = (int)x->plane[0].dequant_QTX[0] >> (bit_depth - 8);
+  const float dc_q_f = (float)(dc_q * dc_q) / 256.0f;
+
+  features->after_part_none.rate = rate_f;
+  features->after_part_none.dist = dist_f;
+  // TODO(chengchen): is this normalized variance?
+  features->after_part_none.source_variance = (float)pb_source_variance;
+  features->after_part_none.q = dc_q_f;
+
+  // Features below are used to decide "terminate_partition_search", as
+  // defined in av1_simple_motion_search_early_term_none().
+  float features_terminate[FEATURE_SIZE_SMS_TERM_NONE] = { 0.0f };
+  simple_motion_search_prune_part_features(
+      cpi, x, sms_tree, blk_params.mi_row, blk_params.mi_col, bsize,
+      features_terminate, FEATURE_SMS_PRUNE_PART_FLAG);
+  int f_idx = FEATURE_SIZE_SMS_PRUNE_PART;  // RD stats are appended from here
+
+  features_terminate[f_idx++] = logf(1.0f + (float)this_rdc->rate);
+  features_terminate[f_idx++] = logf(1.0f + (float)this_rdc->dist);
+  features_terminate[f_idx++] = logf(1.0f + (float)this_rdc->rdcost);
+
+  assert(f_idx == FEATURE_SIZE_SMS_TERM_NONE);
+  for (int i = 0; i < FEATURE_SIZE_SMS_TERM_NONE; ++i) {
+    features->after_part_none.f_terminate[i] = features_terminate[i];
+  }
+}
+
+// Fills features->after_part_split (f_terminate: 31 values, f_prune_rect: 9)
+// for the external model, using RD costs known once split has been searched.
+static void prepare_features_after_split(
+    AV1_COMP *const cpi, MACROBLOCK *x, SIMPLE_MOTION_DATA_TREE *sms_tree,
+    PartitionSearchState *part_search_state, RD_STATS *best_rdc,
+    const int64_t partition_none_rdcost, const int64_t partition_split_rdcost,
+    const int64_t *split_block_rdcost,
+    aom_partition_features_t *const features) {
+  aom_clear_system_state();
+  const AV1_COMMON *const cm = &cpi->common;
+  PartitionBlkParams blk_params = part_search_state->part_blk_params;
+  const BLOCK_SIZE bsize = blk_params.bsize;
+  const int bs = block_size_wide[bsize];
+  const int bit_depth = cm->seq_params->bit_depth;
+  const int dc_q =
+      av1_dc_quant_QTX(cm->quant_params.base_qindex, 0, bit_depth) >>
+      (bit_depth - 8);
+
+  // 31 features in av1_ml_early_term_after_split()
+  int feature_idx = 0;
+  features->after_part_split.f_terminate[feature_idx++] =
+      logf(1.0f + (float)dc_q / 4.0f);
+  features->after_part_split.f_terminate[feature_idx++] =
+      logf(1.0f + (float)best_rdc->rdcost / bs / bs / 1024.0f);
+
+  add_rd_feature(partition_none_rdcost, best_rdc->rdcost,
+                 features->after_part_split.f_terminate, &feature_idx);
+  add_rd_feature(partition_split_rdcost, best_rdc->rdcost,
+                 features->after_part_split.f_terminate, &feature_idx);
+
+  for (int i = 0; i < SUB_PARTITIONS_SPLIT; ++i) {
+    add_rd_feature(split_block_rdcost[i], best_rdc->rdcost,
+                   features->after_part_split.f_terminate, &feature_idx);
+    int min_bw = MAX_SB_SIZE_LOG2;
+    int min_bh = MAX_SB_SIZE_LOG2;
+    get_min_bsize(sms_tree->split[i], &min_bw, &min_bh);
+    features->after_part_split.f_terminate[feature_idx++] = (float)min_bw;
+    features->after_part_split.f_terminate[feature_idx++] = (float)min_bh;
+  }
+
+  simple_motion_search_prune_part_features(cpi, x, sms_tree, blk_params.mi_row,
+                                           blk_params.mi_col, bsize, NULL,
+                                           FEATURE_SMS_PRUNE_PART_FLAG);
+
+  features->after_part_split.f_terminate[feature_idx++] =
+      logf(1.0f + (float)sms_tree->sms_none_feat[1]);
+
+  features->after_part_split.f_terminate[feature_idx++] =
+      logf(1.0f + (float)sms_tree->split[0]->sms_none_feat[1]);
+  features->after_part_split.f_terminate[feature_idx++] =
+      logf(1.0f + (float)sms_tree->split[1]->sms_none_feat[1]);
+  features->after_part_split.f_terminate[feature_idx++] =
+      logf(1.0f + (float)sms_tree->split[2]->sms_none_feat[1]);
+  features->after_part_split.f_terminate[feature_idx++] =
+      logf(1.0f + (float)sms_tree->split[3]->sms_none_feat[1]);
+
+  features->after_part_split.f_terminate[feature_idx++] =
+      logf(1.0f + (float)sms_tree->sms_rect_feat[1]);
+  features->after_part_split.f_terminate[feature_idx++] =
+      logf(1.0f + (float)sms_tree->sms_rect_feat[3]);
+  features->after_part_split.f_terminate[feature_idx++] =
+      logf(1.0f + (float)sms_tree->sms_rect_feat[5]);
+  features->after_part_split.f_terminate[feature_idx++] =
+      logf(1.0f + (float)sms_tree->sms_rect_feat[7]);
+
+  assert(feature_idx == 31);
+
+  // 9 features in av1_ml_prune_rect_partition(); rdcost is assumed nonzero.
+  for (int i = 0; i < 5; i++) features->after_part_split.f_prune_rect[i] = 1.0f;
+  if (part_search_state->none_rd > 0 && part_search_state->none_rd < 1000000000)
+    features->after_part_split.f_prune_rect[0] =
+        (float)part_search_state->none_rd / (float)best_rdc->rdcost;
+  for (int i = 0; i < SUB_PARTITIONS_SPLIT; i++) {
+    if (part_search_state->split_rd[i] > 0 &&
+        part_search_state->split_rd[i] < 1000000000)
+      features->after_part_split.f_prune_rect[1 + i] =
+          (float)part_search_state->split_rd[i] / (float)best_rdc->rdcost;
+  }
+
+  // Variance ratios: each split sub-block over the whole block (floored at 1).
+  const MACROBLOCKD *const xd = &x->e_mbd;
+  int whole_block_variance;
+  if (is_cur_buf_hbd(xd)) {
+    whole_block_variance = av1_high_get_sby_perpixel_variance(
+        cpi, &x->plane[0].src, bsize, xd->bd);
+  } else {
+    whole_block_variance =
+        av1_get_sby_perpixel_variance(cpi, &x->plane[0].src, bsize);
+  }
+  whole_block_variance = AOMMAX(whole_block_variance, 1);
+
+  int split_variance[SUB_PARTITIONS_SPLIT];
+  const BLOCK_SIZE subsize = get_partition_subsize(bsize, PARTITION_SPLIT);
+  struct buf_2d buf;
+  buf.stride = x->plane[0].src.stride;
+  const int bw = block_size_wide[bsize];
+  for (int i = 0; i < SUB_PARTITIONS_SPLIT; ++i) {
+    const int x_idx = (i & 1) * bw / 2;
+    const int y_idx = (i >> 1) * bw / 2;
+    buf.buf = x->plane[0].src.buf + x_idx + y_idx * buf.stride;
+    if (is_cur_buf_hbd(xd)) {
+      split_variance[i] =
+          av1_high_get_sby_perpixel_variance(cpi, &buf, subsize, xd->bd);
+    } else {
+      split_variance[i] = av1_get_sby_perpixel_variance(cpi, &buf, subsize);
+    }
+  }
+
+  for (int i = 0; i < SUB_PARTITIONS_SPLIT; i++)
+    features->after_part_split.f_prune_rect[5 + i] =
+        (float)split_variance[i] / (float)whole_block_variance;
+}
+
+// Fills features->after_part_rect.f (10 values) for the external model after
+// the rectangular partitions are searched. No-op when bsize < BLOCK_8X8.
+static void prepare_features_after_part_rect(
+    const AV1_COMP *cpi, const MACROBLOCK *x, const PC_TREE *pc_tree,
+    BLOCK_SIZE bsize, int64_t best_rdcost,
+    int64_t rect_part_rd[NUM_RECT_PARTS][SUB_PARTITIONS_RECT],
+    int64_t split_rd[SUB_PARTITIONS_SPLIT],
+    aom_partition_features_t *const features) {
+  (void)cpi;
+  if (bsize < BLOCK_8X8) return;
+
+  int64_t *horz_rd = rect_part_rd[HORZ];
+  int64_t *vert_rd = rect_part_rd[VERT];
+
+  int feature_index = 0;
+  features->after_part_rect.f[feature_index++] = (float)pc_tree->partitioning;
+  features->after_part_rect.f[feature_index++] =
+      (float)get_unsigned_bits(x->source_variance);
+  const int rdcost = (int)AOMMIN(INT_MAX, best_rdcost);
+  int sub_block_rdcost[8] = { 0 };
+  int rd_index = 0;
+  for (int i = 0; i < SUB_PARTITIONS_RECT; ++i) {
+    if (horz_rd[i] > 0 && horz_rd[i] < 1000000000)
+      sub_block_rdcost[rd_index] = (int)horz_rd[i];
+    ++rd_index;
+  }
+  for (int i = 0; i < SUB_PARTITIONS_RECT; ++i) {
+    if (vert_rd[i] > 0 && vert_rd[i] < 1000000000)
+      sub_block_rdcost[rd_index] = (int)vert_rd[i];
+    ++rd_index;
+  }
+  for (int i = 0; i < SUB_PARTITIONS_SPLIT; ++i) {
+    if (split_rd[i] > 0 && split_rd[i] < 1000000000)
+      sub_block_rdcost[rd_index] = (int)split_rd[i];
+    ++rd_index;
+  }
+  for (int i = 0; i < 8; ++i) {
+    // Ratio between the sub-block RD and the whole-block RD (1.0 if unusable).
+    float rd_ratio = 1.0f;
+    if (sub_block_rdcost[i] > 0 && sub_block_rdcost[i] < rdcost)
+      rd_ratio = (float)sub_block_rdcost[i] / (float)rdcost;
+    features->after_part_rect.f[feature_index++] = rd_ratio;
+  }
+  assert(feature_index == 10);
+}
+
+// Fills features->after_part_ab.f (18 values) for the external model after the
+// AB partitions are searched: context, variance bits, RD and variance ratios.
+static void prepare_features_after_part_ab(
+    const AV1_COMP *const cpi, MACROBLOCK *const x, BLOCK_SIZE bsize,
+    int part_ctx, int64_t best_rd,
+    int64_t rect_part_rd[NUM_RECT_PARTS][SUB_PARTITIONS_RECT],
+    int64_t split_rd[SUB_PARTITIONS_SPLIT], unsigned int pb_source_variance,
+    int mi_row, int mi_col, aom_partition_features_t *const features) {
+  int64_t *horz_rd = rect_part_rd[HORZ];
+  int64_t *vert_rd = rect_part_rd[VERT];
+
+  aom_clear_system_state();
+
+  // Generate features.
+  int feature_index = 0;
+  features->after_part_ab.f[feature_index++] = (float)part_ctx;
+  features->after_part_ab.f[feature_index++] =
+      (float)get_unsigned_bits(pb_source_variance);
+
+  const int rdcost = (int)AOMMIN(INT_MAX, best_rd);
+  int sub_block_rdcost[8] = { 0 };
+  int rd_index = 0;
+  for (int i = 0; i < SUB_PARTITIONS_RECT; ++i) {
+    if (horz_rd[i] > 0 && horz_rd[i] < 1000000000)
+      sub_block_rdcost[rd_index] = (int)horz_rd[i];
+    ++rd_index;
+  }
+  for (int i = 0; i < SUB_PARTITIONS_RECT; ++i) {
+    if (vert_rd[i] > 0 && vert_rd[i] < 1000000000)
+      sub_block_rdcost[rd_index] = (int)vert_rd[i];
+    ++rd_index;
+  }
+  for (int i = 0; i < SUB_PARTITIONS_SPLIT; ++i) {
+    if (split_rd[i] > 0 && split_rd[i] < 1000000000)
+      sub_block_rdcost[rd_index] = (int)split_rd[i];
+    ++rd_index;
+  }
+  for (int i = 0; i < 8; ++i) {
+    // Ratio between the sub-block RD and the whole-block RD.
+    float rd_ratio = 1.0f;
+    if (sub_block_rdcost[i] > 0 && sub_block_rdcost[i] < rdcost)
+      rd_ratio = (float)sub_block_rdcost[i] / (float)rdcost;
+    features->after_part_ab.f[feature_index++] = rd_ratio;
+  }
+
+  // Get variance of the 1:4 and 4:1 sub-blocks.
+  unsigned int horz_4_source_var[SUB_PARTITIONS_PART4] = { 0 };
+  unsigned int vert_4_source_var[SUB_PARTITIONS_PART4] = { 0 };
+  {
+    BLOCK_SIZE horz_4_bs = get_partition_subsize(bsize, PARTITION_HORZ_4);
+    BLOCK_SIZE vert_4_bs = get_partition_subsize(bsize, PARTITION_VERT_4);
+    av1_setup_src_planes(x, cpi->source, mi_row, mi_col,
+                         av1_num_planes(&cpi->common), bsize);
+    const int src_stride = x->plane[0].src.stride;
+    uint8_t *src = x->plane[0].src.buf;
+    const MACROBLOCKD *const xd = &x->e_mbd;
+
+    struct buf_2d horz_4_src, vert_4_src;
+    horz_4_src.stride = src_stride;
+    vert_4_src.stride = src_stride;
+
+    for (int i = 0; i < SUB_PARTITIONS_PART4; ++i) {
+      horz_4_src.buf = src + i * block_size_high[horz_4_bs] * src_stride;
+      vert_4_src.buf = src + i * block_size_wide[vert_4_bs];
+
+      if (is_cur_buf_hbd(xd)) {
+        horz_4_source_var[i] = av1_high_get_sby_perpixel_variance(
+            cpi, &horz_4_src, horz_4_bs, xd->bd);
+        vert_4_source_var[i] = av1_high_get_sby_perpixel_variance(
+            cpi, &vert_4_src, vert_4_bs, xd->bd);
+      } else {
+        horz_4_source_var[i] =
+            av1_get_sby_perpixel_variance(cpi, &horz_4_src, horz_4_bs);
+        vert_4_source_var[i] =
+            av1_get_sby_perpixel_variance(cpi, &vert_4_src, vert_4_bs);
+      }
+    }
+  }
+
+  const float denom = (float)(pb_source_variance + 1);
+  const float low_b = 0.1f;
+  const float high_b = 10.0f;
+  for (int i = 0; i < SUB_PARTITIONS_PART4; ++i) {
+    // Ratio between the 4:1 sub-block variance and the whole-block variance.
+    float var_ratio = (float)(horz_4_source_var[i] + 1) / denom;
+    if (var_ratio < low_b) var_ratio = low_b;
+    if (var_ratio > high_b) var_ratio = high_b;
+    features->after_part_ab.f[feature_index++] = var_ratio;
+  }
+  for (int i = 0; i < SUB_PARTITIONS_PART4; ++i) {
+    // Ratio between the 1:4 sub-block variance and the whole-block variance.
+    float var_ratio = (float)(vert_4_source_var[i] + 1) / denom;
+    if (var_ratio < low_b) var_ratio = low_b;
+    if (var_ratio > high_b) var_ratio = high_b;
+    features->after_part_ab.f[feature_index++] = var_ratio;
+  }
+  assert(feature_index == 18);
+}
+
+// If the external partition model is used, let it make the partition decisions
+// before partition none and return true. It sets these output parameters:
+// partition_none_allowed
+// partition_horz_allowed
+// partition_vert_allowed
+// do_rectangular_split
+// do_square_split
+static bool ext_ml_model_decision_before_none(
+    AV1_COMP *cpi, const float features_from_motion[FEATURE_SIZE_SMS_SPLIT],
+    int *partition_none_allowed, int *partition_horz_allowed,
+    int *partition_vert_allowed, int *do_rectangular_split,
+    int *do_square_split) {
+  // External model is disabled here: return false (code below is unreachable).
+  return false;
+
+  ExtPartController *const ext_part_controller = &cpi->ext_part_controller;
+
+  // Setup features.
+  aom_partition_features_t features;
+  for (int i = 0; i < FEATURE_SIZE_SMS_SPLIT; ++i) {
+    features.before_part_none.f[i] = features_from_motion[i];
+  }
+
+  // Send necessary features to the external model.
+  av1_ext_part_send_features(ext_part_controller, &features);
+
+  // Get partition decisions from the external model.
+  aom_partition_decision_t decision;
+  av1_ext_part_get_partition_decision(ext_part_controller, &decision);
+
+  // Populate decisions
+  *partition_none_allowed = decision.partition_none_allowed;
+  *partition_horz_allowed = decision.partition_rect_allowed[HORZ];
+  *partition_vert_allowed = decision.partition_rect_allowed[VERT];
+  *do_rectangular_split = decision.do_rectangular_split;
+  *do_square_split = decision.do_square_split;
+
+  return true;
+}
+
+// If the external partition model is used, let it make the partition decisions
+// before partition none and return true. It sets these output parameters:
+// prune_horz
+// prune_vert
+static bool ext_ml_model_decision_before_none_part2(
+    AV1_COMP *cpi,
+    const float features_from_motion[FEATURE_SIZE_SMS_PRUNE_PART],
+    int *prune_horz, int *prune_vert) {
+  // External model is disabled here: return false (code below is unreachable).
+  return false;
+
+  ExtPartController *const ext_part_controller = &cpi->ext_part_controller;
+
+  // Setup features. NOTE(review): reuses before_part_none.f; confirm intent.
+  aom_partition_features_t features;
+  for (int i = 0; i < FEATURE_SIZE_SMS_PRUNE_PART; ++i) {
+    features.before_part_none.f[i] = features_from_motion[i];
+  }
+
+  // Send necessary features to the external model.
+  av1_ext_part_send_features(ext_part_controller, &features);
+
+  // Get partition decisions from the external model.
+  aom_partition_decision_t decision;
+  av1_ext_part_get_partition_decision(ext_part_controller, &decision);
+
+  // Populate decisions
+  *prune_horz = decision.prune_rect_part[HORZ];
+  *prune_vert = decision.prune_rect_part[VERT];
+
+  return true;
+}
+
+// If the external partition model is used, let it make the partition decisions
+// after none partition and return true. It sets, in part_search_state:
+// do_square_split
+// do_rectangular_split
+// terminate_partition_search
+bool av1_ext_ml_model_decision_after_none(
+    AV1_COMP *const cpi, MACROBLOCK *const x, SIMPLE_MOTION_DATA_TREE *sms_tree,
+    PartitionSearchState *part_search_state,
+    const unsigned int pb_source_variance) {
+  // External model is disabled here: return false (code below is unreachable).
+  return false;
+
+  AV1_COMMON *const cm = &cpi->common;
+  ExtPartController *const ext_part_controller = &cpi->ext_part_controller;
+
+  if (!frame_is_intra_only(cm) && ext_part_controller->ready) {
+    // Setup features.
+    aom_partition_features_t features;
+    prepare_features_after_part_none(cpi, cm, x, sms_tree, part_search_state,
+                                     pb_source_variance, &features);
+
+    // Send necessary features to the external model.
+    av1_ext_part_send_features(ext_part_controller, &features);
+
+    // Get partition decisions from the external model.
+    aom_partition_decision_t decision;
+    av1_ext_part_get_partition_decision(ext_part_controller, &decision);
+
+    // Populate decisions
+    part_search_state->do_square_split = decision.do_square_split;
+    part_search_state->do_rectangular_split = decision.do_rectangular_split;
+    part_search_state->terminate_partition_search =
+        decision.terminate_partition_search;
+
+    return true;
+  }
+
+  return false;
+}
+
+// If the external partition model is used, let it make the partition decisions
+// after split partition and return true. It sets, in part_search_state:
+// terminate_partition_search
+// prune_rect_part[HORZ]
+// prune_rect_part[VERT]
+bool av1_ext_ml_model_decision_after_split(
+    AV1_COMP *const cpi, MACROBLOCK *x, SIMPLE_MOTION_DATA_TREE *sms_tree,
+    PartitionSearchState *part_search_state, RD_STATS *best_rdc,
+    const int64_t partition_none_rdcost, const int64_t partition_split_rdcost,
+    const int64_t *split_block_rdcost) {
+  // External model is disabled here: return false (code below is unreachable).
+  return false;
+
+  const AV1_COMMON *const cm = &cpi->common;
+  ExtPartController *const ext_part_controller = &cpi->ext_part_controller;
+
+  if (!frame_is_intra_only(cm) && ext_part_controller->ready) {
+    // Setup features.
+    aom_partition_features_t features;
+    prepare_features_after_split(cpi, x, sms_tree, part_search_state, best_rdc,
+                                 partition_none_rdcost, partition_split_rdcost,
+                                 split_block_rdcost, &features);
+
+    // Send necessary features to the external model.
+    av1_ext_part_send_features(ext_part_controller, &features);
+
+    // Get partition decisions from the external model.
+    aom_partition_decision_t decision;
+    av1_ext_part_get_partition_decision(ext_part_controller, &decision);
+
+    // Populate decisions
+    part_search_state->terminate_partition_search =
+        decision.terminate_partition_search;
+    part_search_state->prune_rect_part[HORZ] = decision.prune_rect_part[HORZ];
+    part_search_state->prune_rect_part[VERT] = decision.prune_rect_part[VERT];
+
+    return true;
+  }
+
+  return false;
+}
+
+// If the external partition model is used, let it make the partition decisions
+// after rectangular partition and return true. It sets these outputs:
+// horza_partition_allowed
+// horzb_partition_allowed
+// verta_partition_allowed
+// vertb_partition_allowed
+static bool ext_ml_model_decision_after_rect(
+    AV1_COMP *cpi, const MACROBLOCK *x, const PC_TREE *pc_tree,
+    BLOCK_SIZE bsize, int pb_source_variance, int64_t best_rdcost,
+    int64_t rect_part_rd[NUM_RECT_PARTS][SUB_PARTITIONS_RECT],
+    int64_t split_rd[SUB_PARTITIONS_SPLIT], int ext_partition_allowed,
+    int partition_horz_allowed, int partition_vert_allowed,
+    int *horza_partition_allowed, int *horzb_partition_allowed,
+    int *verta_partition_allowed, int *vertb_partition_allowed) {
+  // External model is disabled here: return false (code below is unreachable).
+  return false;
+
+  (void)pb_source_variance;
+  const AV1_COMMON *const cm = &cpi->common;
+  ExtPartController *const ext_part_controller = &cpi->ext_part_controller;
+  const PartitionCfg *const part_cfg = &cpi->oxcf.part_cfg;
+  // The standard AB partitions are allowed initially if ext-partition-types are
+  // allowed.
+  int ab_partition_allowed =
+      ext_partition_allowed & part_cfg->enable_ab_partitions;
+
+  if (!frame_is_intra_only(cm) && ext_part_controller->ready &&
+      partition_horz_allowed && partition_vert_allowed &&
+      ab_partition_allowed) {
+    // Setup features.
+    aom_partition_features_t features;
+    prepare_features_after_part_rect(cpi, x, pc_tree, bsize, best_rdcost,
+                                     rect_part_rd, split_rd, &features);
+
+    // Send necessary features to the external model.
+    av1_ext_part_send_features(ext_part_controller, &features);
+
+    // Get partition decisions from the external model.
+    aom_partition_decision_t decision;
+    av1_ext_part_get_partition_decision(ext_part_controller, &decision);
+
+    // Populate decisions
+    *horza_partition_allowed = decision.horza_partition_allowed;
+    *horzb_partition_allowed = decision.horzb_partition_allowed;
+    *verta_partition_allowed = decision.verta_partition_allowed;
+    *vertb_partition_allowed = decision.vertb_partition_allowed;
+
+    return true;
+  }
+
+  return false;
+}
+
+// If the external partition model is used, let it make the partition decisions
+// after AB partition and return true. It sets these output parameters:
+// partition_vert4_allowed
+// partition_horz4_allowed
+static bool ext_ml_model_decision_after_part_ab(
+    AV1_COMP *const cpi, MACROBLOCK *const x, BLOCK_SIZE bsize, int part_ctx,
+    int64_t best_rd, int64_t rect_part_rd[NUM_RECT_PARTS][SUB_PARTITIONS_RECT],
+    int64_t split_rd[SUB_PARTITIONS_SPLIT], int *const partition_horz4_allowed,
+    int *const partition_vert4_allowed, unsigned int pb_source_variance,
+    int mi_row, int mi_col) {
+  // External model is disabled here: return false (code below is unreachable).
+  return false;
+
+  const AV1_COMMON *const cm = &cpi->common;
+  ExtPartController *const ext_part_controller = &cpi->ext_part_controller;
+
+  if (!frame_is_intra_only(cm) && ext_part_controller->ready) {
+    // Setup features.
+    aom_partition_features_t features;
+    prepare_features_after_part_ab(cpi, x, bsize, part_ctx, best_rd,
+                                   rect_part_rd, split_rd, pb_source_variance,
+                                   mi_row, mi_col, &features);
+
+    // Send necessary features to the external model.
+    av1_ext_part_send_features(ext_part_controller, &features);
+
+    // Get partition decisions from the external model.
+    aom_partition_decision_t decision;
+    av1_ext_part_get_partition_decision(ext_part_controller, &decision);
+
+    // Populate decisions
+    *partition_horz4_allowed = decision.partition_horz4_allowed;
+    *partition_vert4_allowed = decision.partition_vert4_allowed;
+
+    return true;
+  }
+
+  return false;
+}
+
 #endif  // !CONFIG_REALTIME_ONLY
diff --git a/av1/encoder/partition_strategy.h b/av1/encoder/partition_strategy.h
index 621edae..f744f87 100644
--- a/av1/encoder/partition_strategy.h
+++ b/av1/encoder/partition_strategy.h
@@ -98,9 +98,8 @@
 
 // Use a ML model to predict if horz4 and vert4 should be considered.
 void av1_ml_prune_4_partition(
-    const AV1_COMP *const cpi, MACROBLOCK *const x, BLOCK_SIZE bsize,
-    int part_ctx, int64_t best_rd,
-    int64_t rect_part_rd[NUM_RECT_PARTS][SUB_PARTITIONS_RECT],
+    AV1_COMP *const cpi, MACROBLOCK *const x, BLOCK_SIZE bsize, int part_ctx,
+    int64_t best_rd, int64_t rect_part_rd[NUM_RECT_PARTS][SUB_PARTITIONS_RECT],
     int64_t split_rd[SUB_PARTITIONS_SPLIT], int *const partition_horz4_allowed,
     int *const partition_vert4_allowed, unsigned int pb_source_variance,
     int mi_row, int mi_col);
@@ -135,7 +134,7 @@
 // Prune out AB partitions based on rd decisions made from testing the
 // basic partitions.
 void av1_prune_ab_partitions(
-    const AV1_COMP *cpi, const MACROBLOCK *x, const PC_TREE *pc_tree,
+    AV1_COMP *cpi, const MACROBLOCK *x, const PC_TREE *pc_tree,
     BLOCK_SIZE bsize, int pb_source_variance, int64_t best_rdcost,
     int64_t rect_part_rd[NUM_RECT_PARTS][SUB_PARTITIONS_RECT],
     int64_t split_rd[SUB_PARTITIONS_SPLIT],
@@ -274,5 +273,16 @@
                sb_enc->min_partition_size);
   }
 }
+
+bool av1_ext_ml_model_decision_after_none(
+    AV1_COMP *const cpi, MACROBLOCK *const x, SIMPLE_MOTION_DATA_TREE *sms_tree,
+    PartitionSearchState *part_search_state,
+    const unsigned int pb_source_variance);
+
+bool av1_ext_ml_model_decision_after_split(
+    AV1_COMP *const cpi, MACROBLOCK *x, SIMPLE_MOTION_DATA_TREE *sms_tree,
+    PartitionSearchState *part_search_state, RD_STATS *best_rdc,
+    const int64_t partition_none_rdcost, const int64_t partition_split_rdcost,
+    const int64_t *split_block_rdcost);
 #endif  // !CONFIG_REALTIME_ONLY
 #endif  // AOM_AV1_ENCODER_PARTITION_STRATEGY_H_