ExtPart: encode with ML decisions recursively
This approach is similar to the baseline DFS partition search.
Each time, the ML model makes one decision for the current block;
the encoder then encodes the block and sends the resulting stats back
to the external model.
This process starts from the superblock and proceeds recursively
to smaller blocks whenever a split happens.
Change-Id: Id81a3b747938735567d833625efb986821a72596
diff --git a/aom/aom_external_partition.h b/aom/aom_external_partition.h
index 4d50833..b6ea96e 100644
--- a/aom/aom_external_partition.h
+++ b/aom/aom_external_partition.h
@@ -238,6 +238,7 @@
int is_final_decision; /**< The flag whether it is the final decision */
int num_nodes; /**< The number of leaf nodes */
int partition_decision[2048]; /**< Partition decisions */
+ int current_decision; /**< Partition decision for the current block */
// Decisions for partition type pruning
int terminate_partition_search; /**< Terminate further partition search */
diff --git a/av1/encoder/partition_search.c b/av1/encoder/partition_search.c
index a263f91..ccdd8e5 100644
--- a/av1/encoder/partition_search.c
+++ b/av1/encoder/partition_search.c
@@ -34,6 +34,7 @@
#endif
#define COLLECT_MOTION_SEARCH_FEATURE_SB 0
+#define ML_PARTITION_WHOLE_TREE_DECISION 0
void av1_reset_part_sf(PARTITION_SPEED_FEATURES *part_sf) {
part_sf->partition_search_type = SEARCH_PARTITION;
@@ -4105,11 +4106,11 @@
collect_tpl_stats_sb(cpi, bsize, mi_row, mi_col, features);
}
-static void update_partition_stats(const RD_STATS this_rdcost,
+static void update_partition_stats(const RD_STATS *const this_rdcost,
aom_partition_stats_t *stats) {
- stats->rate = this_rdcost.rate;
- stats->dist = this_rdcost.dist;
- stats->rdcost = this_rdcost.rdcost;
+ stats->rate = this_rdcost->rate;
+ stats->dist = this_rdcost->dist;
+ stats->rdcost = this_rdcost->rdcost;
}
static void build_pc_tree_from_part_decision(
@@ -4144,10 +4145,13 @@
}
}
-static bool ml_partition_search(AV1_COMP *const cpi, ThreadData *td,
- TileDataEnc *tile_data, TokenExtra **tp,
- SIMPLE_MOTION_DATA_TREE *sms_root, int mi_row,
- int mi_col, const BLOCK_SIZE bsize) {
+// The ML model needs to provide the whole decision tree for the superblock.
+static bool ml_partition_search_whole_tree(AV1_COMP *const cpi, ThreadData *td,
+ TileDataEnc *tile_data,
+ TokenExtra **tp,
+ SIMPLE_MOTION_DATA_TREE *sms_root,
+ int mi_row, int mi_col,
+ const BLOCK_SIZE bsize) {
AV1_COMMON *const cm = &cpi->common;
MACROBLOCK *const x = &td->mb;
ExtPartController *const ext_part_controller = &cpi->ext_part_controller;
@@ -4176,7 +4180,7 @@
const RD_STATS this_rdcost = rd_search_for_fixed_partition(
cpi, td, tile_data, tp, sms_root, mi_row, mi_col, bsize, pc_tree);
aom_partition_stats_t stats;
- update_partition_stats(this_rdcost, &stats);
+ update_partition_stats(&this_rdcost, &stats);
av1_ext_part_send_partition_stats(ext_part_controller, &stats);
if (!partition_decision.is_final_decision) {
av1_free_pc_tree_recursive(pc_tree, av1_num_planes(cm), 0, 0);
@@ -4193,14 +4197,136 @@
return true;
}
+// Recursively builds and evaluates the partition tree rooted at |pc_tree|,
+// asking the external ML model for one decision per block.
+//
+// For the block of size |bsize| at (|mi_row|, |mi_col|) the function
+// repeatedly:
+//   1. fetches a decision from the external model and stores it in
+//      |pc_tree->partitioning|,
+//   2. evaluates it (PARTITION_SPLIT recurses into the four quadrants, any
+//      other type goes through rd_search_for_fixed_partition()),
+//   3. sends the resulting rate/dist/rdcost back to the model,
+// until the model flags the decision as final.
+//
+// The stats of the last evaluated decision are written to |this_rdcost|.
+//
+// Returns true on success. Returns false when the block origin is outside the
+// frame, when the model fails to give a valid decision for this block or for
+// any in-frame sub-block, or when a sub-tree node cannot be allocated.
+static bool recursive_partition(AV1_COMP *const cpi, ThreadData *td,
+                                TileDataEnc *tile_data, TokenExtra **tp,
+                                SIMPLE_MOTION_DATA_TREE *sms_root,
+                                PC_TREE *pc_tree, int mi_row, int mi_col,
+                                const BLOCK_SIZE bsize, RD_STATS *this_rdcost) {
+  const AV1_COMMON *const cm = &cpi->common;
+  ExtPartController *const ext_part_controller = &cpi->ext_part_controller;
+  MACROBLOCK *const x = &td->mb;
+  if (mi_row >= cm->mi_params.mi_rows || mi_col >= cm->mi_params.mi_cols) {
+    return false;
+  }
+  aom_partition_decision_t partition_decision;
+  do {
+    const bool valid_decision = av1_ext_part_get_partition_decision(
+        ext_part_controller, &partition_decision);
+    if (!valid_decision) return false;
+    pc_tree->partitioning = partition_decision.current_decision;
+    PartitionSearchState part_search_state;
+
+    // Initialization of state variables used in partition search.
+    // TODO(chengchen): check if there is hidden conditions that don't allow
+    // all possible partition types.
+    init_partition_search_state_params(x, cpi, &part_search_state, mi_row,
+                                       mi_col, bsize);
+
+    av1_init_rd_stats(this_rdcost);
+    if (partition_decision.current_decision == PARTITION_SPLIT) {
+      assert(block_size_wide[bsize] >= 8 && block_size_high[bsize] >= 8);
+      const BLOCK_SIZE subsize = get_partition_subsize(bsize, PARTITION_SPLIT);
+      RD_STATS split_rdc[SUB_PARTITIONS_SPLIT];
+      for (int i = 0; i < SUB_PARTITIONS_SPLIT; ++i) {
+        av1_init_rd_stats(&split_rdc[i]);
+        if (pc_tree->split[i] == NULL) {
+          pc_tree->split[i] = av1_alloc_pc_tree_node(subsize);
+          // Do not dereference a failed allocation; report failure instead.
+          if (pc_tree->split[i] == NULL) return false;
+        }
+        pc_tree->split[i]->index = i;
+      }
+      // TODO(chengchen): check boundary conditions
+      // Quadrant order: 0 top-left, 1 top-right, 2 bottom-left,
+      // 3 bottom-right.
+      for (int i = 0; i < SUB_PARTITIONS_SPLIT; ++i) {
+        const int sub_mi_row = mi_row + (i >> 1) * mi_size_high[subsize];
+        const int sub_mi_col = mi_col + (i & 1) * mi_size_wide[subsize];
+        // A quadrant whose origin is outside the frame is not searched; its
+        // stats keep the values set by av1_init_rd_stats() above.
+        if (sub_mi_row >= cm->mi_params.mi_rows ||
+            sub_mi_col >= cm->mi_params.mi_cols) {
+          continue;
+        }
+        // Propagate failures of in-frame quadrants (e.g. an invalid model
+        // decision) instead of silently continuing with a half-built tree.
+        if (!recursive_partition(cpi, td, tile_data, tp, sms_root,
+                                 pc_tree->split[i], sub_mi_row, sub_mi_col,
+                                 subsize, &split_rdc[i])) {
+          return false;
+        }
+      }
+      this_rdcost->rate += part_search_state.partition_cost[PARTITION_SPLIT];
+      for (int i = 0; i < SUB_PARTITIONS_SPLIT; ++i) {
+        this_rdcost->rate += split_rdc[i].rate;
+        this_rdcost->dist += split_rdc[i].dist;
+        av1_rd_cost_update(x->rdmult, this_rdcost);
+      }
+    } else {
+      *this_rdcost = rd_search_for_fixed_partition(
+          cpi, td, tile_data, tp, sms_root, mi_row, mi_col, bsize, pc_tree);
+    }
+
+    aom_partition_stats_t stats;
+    update_partition_stats(this_rdcost, &stats);
+    av1_ext_part_send_partition_stats(ext_part_controller, &stats);
+    // The model will be asked again for this block: drop the sub-trees built
+    // for the rejected split so the next attempt starts from a clean node.
+    if (!partition_decision.is_final_decision &&
+        partition_decision.current_decision == PARTITION_SPLIT) {
+      for (int i = 0; i < SUB_PARTITIONS_SPLIT; ++i) {
+        if (pc_tree->split[i] != NULL) {
+          av1_free_pc_tree_recursive(pc_tree->split[i], av1_num_planes(cm), 0,
+                                     0);
+          pc_tree->split[i] = NULL;
+        }
+      }
+    }
+  } while (!partition_decision.is_final_decision);
+
+  return true;
+}
+
+
+// The ML model only needs to make decisions for the current block each time.
+//
+// Sends the superblock features to the external model, lets
+// recursive_partition() build the partition tree for the superblock of size
+// |bsize| at (|mi_row|, |mi_col|) one decision at a time, and then encodes
+// the superblock with the resulting tree.
+//
+// Returns true if the superblock was encoded. Returns false if the root tree
+// node could not be allocated or recursive_partition() failed; in that case
+// nothing is encoded. The partition tree is released on every path.
+static bool ml_partition_search_partial(AV1_COMP *const cpi, ThreadData *td,
+                                        TileDataEnc *tile_data, TokenExtra **tp,
+                                        SIMPLE_MOTION_DATA_TREE *sms_root,
+                                        int mi_row, int mi_col,
+                                        const BLOCK_SIZE bsize) {
+  AV1_COMMON *const cm = &cpi->common;
+  MACROBLOCK *const x = &td->mb;
+  ExtPartController *const ext_part_controller = &cpi->ext_part_controller;
+  aom_partition_features_t features;
+  prepare_sb_features_before_search(cpi, td, mi_row, mi_col, bsize, &features);
+  av1_ext_part_send_features(ext_part_controller, &features);
+  PC_TREE *pc_tree = av1_alloc_pc_tree_node(bsize);
+  if (pc_tree == NULL) return false;
+
+  RD_STATS rdcost;
+  const bool valid_partition =
+      recursive_partition(cpi, td, tile_data, tp, sms_root, pc_tree, mi_row,
+                          mi_col, bsize, &rdcost);
+  if (valid_partition) {
+    // Encode with the selected mode and partition.
+    set_cb_offsets(x->cb_offset, 0, 0);
+    encode_sb(cpi, td, tile_data, tp, mi_row, mi_col, OUTPUT_ENABLED, bsize,
+              pc_tree, NULL);
+  }
+
+  // Free the tree on the failure path as well; the early return previously
+  // leaked it.
+  av1_free_pc_tree_recursive(pc_tree, av1_num_planes(cm), 0, 0);
+
+  return valid_partition;
+}
+
+
bool av1_rd_partition_search(AV1_COMP *const cpi, ThreadData *td,
TileDataEnc *tile_data, TokenExtra **tp,
SIMPLE_MOTION_DATA_TREE *sms_root, int mi_row,
int mi_col, const BLOCK_SIZE bsize,
RD_STATS *best_rd_cost) {
if (cpi->ext_part_controller.ready) {
- const bool valid_search = ml_partition_search(
- cpi, td, tile_data, tp, sms_root, mi_row, mi_col, bsize);
+ bool valid_search = true;
+ if (ML_PARTITION_WHOLE_TREE_DECISION) {
+ valid_search = ml_partition_search_whole_tree(
+ cpi, td, tile_data, tp, sms_root, mi_row, mi_col, bsize);
+ } else {
+ valid_search = ml_partition_search_partial(
+ cpi, td, tile_data, tp, sms_root, mi_row, mi_col, bsize);
+ }
if (!valid_search) {
assert(0 && "Invalid search from ML model, partition search failed.");
exit(0);