ExtPart: encode with decisions from ml model
This is the simple first approach: we assume the ML model provides the
complete partition decision tree for the whole superblock.
Change-Id: Icff72e55a4dfd5d7cc5f656ab682656c15e4ee3a
diff --git a/aom/aom_external_partition.h b/aom/aom_external_partition.h
index 60318d8..0a14260 100644
--- a/aom/aom_external_partition.h
+++ b/aom/aom_external_partition.h
@@ -231,8 +231,9 @@
*/
typedef struct aom_partition_decision {
// Decisions for directly set partition types
- int is_final_decision; /**< The flag whether it is the final decision */
- int partition_decision[256]; /**< Partition decisions */
+ int is_final_decision; /**< The flag whether it is the final decision */
+ int num_nodes; /**< The number of leaf nodes */
+ int partition_decision[2048]; /**< Partition decisions */
// Decisions for partition type pruning
int terminate_partition_search; /**< Terminate further partition search */
diff --git a/av1/encoder/partition_search.c b/av1/encoder/partition_search.c
index eec232d..9958f81 100644
--- a/av1/encoder/partition_search.c
+++ b/av1/encoder/partition_search.c
@@ -4082,22 +4082,87 @@
collect_tpl_stats_sb(cpi, bsize, mi_row, mi_col, features);
}
+// Copies the rate, distortion and RD cost held in this_rdcost into *stats.
+// The caller in this file forwards *stats to the external partition model.
+static void update_partition_stats(const RD_STATS this_rdcost,
+                                   aom_partition_stats_t *stats) {
+  stats->rdcost = this_rdcost.rdcost;
+  stats->dist = this_rdcost.dist;
+  stats->rate = this_rdcost.rate;
+}
+
+// Builds the partition tree rooted at pc_tree from the flat list of decisions
+// supplied by the external model.
+//
+// Decisions are consumed in breadth-first order: entry k of
+// partition_decision->partition_decision applies to the k-th node visited,
+// starting at the root, and every PARTITION_SPLIT node appends its four
+// children to the visiting queue. Note that num_nodes therefore counts every
+// visited node, including split (non-leaf) nodes.
+//
+// partition_decision: decisions returned by the external model.
+// this_bsize:         block size of the root node.
+// pc_tree:            root node; if NULL nothing is built.
+//
+// Children that are allocated but never reached (because the decisions ran
+// out or the queue is full) keep the state av1_alloc_pc_tree_node() gave them.
+static void build_pc_tree_from_part_decision(
+    const aom_partition_decision_t *partition_decision,
+    const BLOCK_SIZE this_bsize, PC_TREE *pc_tree) {
+  // num_nodes comes from outside the encoder: never read past the end of the
+  // decision array, whatever value it holds.
+  const int max_decisions =
+      (int)(sizeof(partition_decision->partition_decision) /
+            sizeof(partition_decision->partition_decision[0]));
+  int num_nodes = partition_decision->num_nodes;
+  if (num_nodes > max_decisions) num_nodes = max_decisions;
+
+  // Breadth-first queue of nodes awaiting a decision. bsize_queue[k] is the
+  // block size of tree_node_queue[k]. The size must be tracked per node: with
+  // a single running size, a sibling visited after an earlier split would be
+  // given children that are one level too small.
+  PC_TREE *tree_node_queue[NUM_NODES] = { NULL };
+  BLOCK_SIZE bsize_queue[NUM_NODES];
+  int last_idx = 1;
+  tree_node_queue[0] = pc_tree;
+  bsize_queue[0] = this_bsize;
+
+  // Stop when the decisions are exhausted or no queued node is left.
+  for (int q_idx = 0; q_idx < num_nodes && q_idx < last_idx; ++q_idx) {
+    const int partitioning = partition_decision->partition_decision[q_idx];
+    assert(partitioning >= PARTITION_NONE &&
+           partitioning < EXT_PARTITION_TYPES);
+    PC_TREE *node = tree_node_queue[q_idx];
+    // A NULL entry (NULL root or failed allocation) still consumes its
+    // decision so that the remaining entries stay aligned with their nodes.
+    if (node == NULL) continue;
+    node->partitioning = partitioning;
+    if (partitioning != PARTITION_SPLIT) continue;
+
+    const BLOCK_SIZE subsize =
+        get_partition_subsize(bsize_queue[q_idx], PARTITION_SPLIT);
+    for (int i = 0; i < 4; ++i) {
+      PC_TREE *child = av1_alloc_pc_tree_node(subsize);
+      node->split[i] = child;
+      if (child != NULL) child->index = i;
+      // Do not write past the end of the queue; a child that does not fit is
+      // still attached to its parent but receives no decision.
+      if (last_idx < NUM_NODES) {
+        tree_node_queue[last_idx] = child;
+        bsize_queue[last_idx] = subsize;
+        ++last_idx;
+      }
+    }
+  }
+}
+
+// Searches and encodes one superblock using partition decisions obtained from
+// the external partition controller (cpi->ext_part_controller).
+//
+// Steps, as implemented below:
+//   1. Prepare the superblock features and send them to the controller.
+//   2. Repeatedly request a partition decision, build a PC_TREE from it, run
+//      rd_search_for_fixed_partition() on that tree and send the resulting
+//      rate/dist/rdcost back, until a decision has is_final_decision set.
+//   3. Call encode_sb() with OUTPUT_ENABLED on the tree of the final decision.
+//
+// Returns false as soon as av1_ext_part_get_partition_decision() reports that
+// no valid decision is available (encode_sb() is not called in that case),
+// and true after the superblock has been encoded.
static bool ml_partition_search(AV1_COMP *const cpi, ThreadData *td,
TileDataEnc *tile_data, TokenExtra **tp,
SIMPLE_MOTION_DATA_TREE *sms_root, int mi_row,
int mi_col, const BLOCK_SIZE bsize) {
- (void)tile_data;
- (void)tp;
- (void)sms_root;
+ AV1_COMMON *const cm = &cpi->common;
+ MACROBLOCK *const x = &td->mb;
ExtPartController *const ext_part_controller = &cpi->ext_part_controller;
aom_partition_features_t features;
prepare_sb_features_before_search(cpi, td, mi_row, mi_col, bsize, &features);
av1_ext_part_send_features(ext_part_controller, &features);
+ // Assigned inside the loop below before any use.
+ PC_TREE *pc_tree;
- // TODO(chengchen): implement the main body (APIs) of the the function:
- // (1). Get partition decision from external ml model.
- // (2). Encode with the given decision.
- // (3). Send stats back to external ml model.
+ // rd mode search (dry run) for a valid partition decision from the ml model.
+ aom_partition_decision_t partition_decision;
+ do {
+ const bool valid_decision = av1_ext_part_get_partition_decision(
+ ext_part_controller, &partition_decision);
+ // No tree is held at this point: the tree of a previous non-final
+ // iteration was already freed at the end of that iteration.
+ if (!valid_decision) return false;
+
+ // First, let's take the easy approach.
+ // We require that the ml model has to provide partition decisions for the
+ // whole superblock.
+ pc_tree = av1_alloc_pc_tree_node(bsize);
+ build_pc_tree_from_part_decision(&partition_decision, bsize, pc_tree);
+
+ const RD_STATS this_rdcost = rd_search_for_fixed_partition(
+ cpi, td, tile_data, tp, sms_root, mi_row, mi_col, bsize, pc_tree);
+ // Report rate/dist/rdcost of this decision back to the controller.
+ aom_partition_stats_t stats;
+ update_partition_stats(this_rdcost, &stats);
+ av1_ext_part_send_partition_stats(ext_part_controller, &stats);
+ // A non-final tree is discarded; the next iteration allocates a new one.
+ // The tree of the final decision is kept for encode_sb() below.
+ if (!partition_decision.is_final_decision) {
+ av1_free_pc_tree_recursive(pc_tree, av1_num_planes(cm), 0, 0);
+ }
+ } while (!partition_decision.is_final_decision);
+
+ // Encode with the selected mode and partition.
+ set_cb_offsets(x->cb_offset, 0, 0);
+ encode_sb(cpi, td, tile_data, tp, mi_row, mi_col, OUTPUT_ENABLED, bsize,
+ pc_tree, NULL);
+
+ // Release the tree of the final decision once it has been encoded.
+ av1_free_pc_tree_recursive(pc_tree, av1_num_planes(cm), 0, 0);
+
return true;
}