ExtPart: encode with decisions from ml model

This is the easy approach. We assume the ml model provides the whole
decision tree for the superblock.

Change-Id: Icff72e55a4dfd5d7cc5f656ab682656c15e4ee3a
diff --git a/aom/aom_external_partition.h b/aom/aom_external_partition.h
index 60318d8..0a14260 100644
--- a/aom/aom_external_partition.h
+++ b/aom/aom_external_partition.h
@@ -231,8 +231,9 @@
  */
 typedef struct aom_partition_decision {
   // Decisions for directly set partition types
-  int is_final_decision;       /**< The flag whether it is the final decision */
-  int partition_decision[256]; /**< Partition decisions */
+  int is_final_decision; /**< The flag whether it is the final decision */
+  int num_nodes;         /**< The number of leaf nodes */
+  int partition_decision[2048]; /**< Partition decisions */
 
   // Decisions for partition type pruning
   int terminate_partition_search; /**< Terminate further partition search */
diff --git a/av1/encoder/partition_search.c b/av1/encoder/partition_search.c
index eec232d..9958f81 100644
--- a/av1/encoder/partition_search.c
+++ b/av1/encoder/partition_search.c
@@ -4082,22 +4082,87 @@
   collect_tpl_stats_sb(cpi, bsize, mi_row, mi_col, features);
 }
 
+static void update_partition_stats(const RD_STATS this_rdcost,
+                                   aom_partition_stats_t *stats) {
+  stats->rate = this_rdcost.rate;
+  stats->dist = this_rdcost.dist;
+  stats->rdcost = this_rdcost.rdcost;
+}
+
+static void build_pc_tree_from_part_decision(
+    const aom_partition_decision_t *partition_decision,
+    const BLOCK_SIZE this_bsize, PC_TREE *pc_tree) {
+  BLOCK_SIZE bsize = this_bsize;
+  int num_nodes = partition_decision->num_nodes;
+  PC_TREE *tree_node_queue[NUM_NODES] = { NULL };
+  int last_idx = 1;
+  int q_idx = 0;
+  tree_node_queue[q_idx] = pc_tree;
+  while (num_nodes > 0) {
+    const int partitioning = partition_decision->partition_decision[q_idx];
+    assert(partitioning >= PARTITION_NONE &&
+           partitioning < EXT_PARTITION_TYPES);
+    PC_TREE *node = tree_node_queue[q_idx];
+    if (node != NULL) node->partitioning = partitioning;
+    if (partitioning == PARTITION_SPLIT) {
+      const BLOCK_SIZE subsize = get_partition_subsize(bsize, PARTITION_SPLIT);
+      for (int i = 0; i < 4; ++i) {
+        if (node != NULL) {  // Suppress warning
+          node->split[i] = av1_alloc_pc_tree_node(subsize);
+          node->split[i]->index = i;
+          tree_node_queue[last_idx] = node->split[i];
+          ++last_idx;
+        }
+      }
+      bsize = subsize;
+    }
+    --num_nodes;
+    ++q_idx;
+  }
+}
+
 static bool ml_partition_search(AV1_COMP *const cpi, ThreadData *td,
                                 TileDataEnc *tile_data, TokenExtra **tp,
                                 SIMPLE_MOTION_DATA_TREE *sms_root, int mi_row,
                                 int mi_col, const BLOCK_SIZE bsize) {
-  (void)tile_data;
-  (void)tp;
-  (void)sms_root;
+  AV1_COMMON *const cm = &cpi->common;
+  MACROBLOCK *const x = &td->mb;
   ExtPartController *const ext_part_controller = &cpi->ext_part_controller;
   aom_partition_features_t features;
   prepare_sb_features_before_search(cpi, td, mi_row, mi_col, bsize, &features);
   av1_ext_part_send_features(ext_part_controller, &features);
+  PC_TREE *pc_tree;
 
-  // TODO(chengchen): implement the main body (APIs) of the the function:
-  // (1). Get partition decision from external ml model.
-  // (2). Encode with the given decision.
-  // (3). Send stats back to external ml model.
+  // rd mode search (dry run) for a valid partition decision from the ml model.
+  aom_partition_decision_t partition_decision;
+  do {
+    const bool valid_decision = av1_ext_part_get_partition_decision(
+        ext_part_controller, &partition_decision);
+    if (!valid_decision) return false;
+
+    // First, let's take the easy approach.
+    // We require that the ml model has to provide partition decisions for the
+    // whole superblock.
+    pc_tree = av1_alloc_pc_tree_node(bsize);
+    build_pc_tree_from_part_decision(&partition_decision, bsize, pc_tree);
+
+    const RD_STATS this_rdcost = rd_search_for_fixed_partition(
+        cpi, td, tile_data, tp, sms_root, mi_row, mi_col, bsize, pc_tree);
+    aom_partition_stats_t stats;
+    update_partition_stats(this_rdcost, &stats);
+    av1_ext_part_send_partition_stats(ext_part_controller, &stats);
+    if (!partition_decision.is_final_decision) {
+      av1_free_pc_tree_recursive(pc_tree, av1_num_planes(cm), 0, 0);
+    }
+  } while (!partition_decision.is_final_decision);
+
+  // Encode with the selected mode and partition.
+  set_cb_offsets(x->cb_offset, 0, 0);
+  encode_sb(cpi, td, tile_data, tp, mi_row, mi_col, OUTPUT_ENABLED, bsize,
+            pc_tree, NULL);
+
+  av1_free_pc_tree_recursive(pc_tree, av1_num_planes(cm), 0, 0);
+
   return true;
 }