External partition: Prepare superblock features before search
Collect features and send to external model.
Change-Id: I11205825ea06833d76fdc8686cc09738a941a1be
diff --git a/av1/encoder/partition_search.c b/av1/encoder/partition_search.c
index d0fd8df..bb7ea7e 100644
--- a/av1/encoder/partition_search.c
+++ b/av1/encoder/partition_search.c
@@ -78,8 +78,10 @@
}
#if !CONFIG_REALTIME_ONLY
-// Write tpl stats to text file for each super block.
+// If input |features| is NULL, write tpl stats to file for each super block.
+// Otherwise, store tpl stats to |features|.
// The tpl stats is computed in the unit of tpl_bsize_1d (16x16).
+// When writing to text file:
// The first row contains super block position, super block size,
// tpl unit length, number of units in the super block.
// The second row contains the intra prediction cost for each unit.
@@ -87,7 +89,8 @@
// The forth row contains the motion compensated dependency cost for each unit.
static void collect_tpl_stats_sb(const AV1_COMP *const cpi,
const BLOCK_SIZE bsize, const int mi_row,
- const int mi_col) {
+ const int mi_col,
+ aom_partition_features_t *features) {
const AV1_COMMON *const cm = &cpi->common;
GF_GROUP *gf_group = &cpi->ppi->gf_group;
if (gf_group->update_type[cpi->gf_frame_index] == INTNL_OVERLAY_UPDATE ||
@@ -108,51 +111,73 @@
const int row_steps = mi_height / step;
const int num_blocks = col_steps * row_steps;
- char filename[256];
- snprintf(filename, sizeof(filename), "%s/tpl_feature_sb%d",
- cpi->oxcf.partition_info_path, cpi->sb_counter);
- FILE *pfile = fopen(filename, "w");
- fprintf(pfile, "%d,%d,%d,%d,%d\n", mi_row, mi_col, bsize,
- tpl_data->tpl_bsize_1d, num_blocks);
- int count = 0;
- for (int row = 0; row < mi_height; row += step) {
- for (int col = 0; col < mi_width; col += step) {
- TplDepStats *this_stats =
- &tpl_stats[av1_tpl_ptr_pos(mi_row + row, mi_col + col, tpl_stride,
- tpl_data->tpl_stats_block_mis_log2)];
- fprintf(pfile, "%.0f", (double)this_stats->intra_cost);
- if (count < num_blocks - 1) fprintf(pfile, ",");
- ++count;
+ if (features == NULL) {
+ char filename[256];
+ snprintf(filename, sizeof(filename), "%s/tpl_feature_sb%d",
+ cpi->oxcf.partition_info_path, cpi->sb_counter);
+ FILE *pfile = fopen(filename, "w");
+ fprintf(pfile, "%d,%d,%d,%d,%d\n", mi_row, mi_col, bsize,
+ tpl_data->tpl_bsize_1d, num_blocks);
+ int count = 0;
+ for (int row = 0; row < mi_height; row += step) {
+ for (int col = 0; col < mi_width; col += step) {
+ TplDepStats *this_stats =
+ &tpl_stats[av1_tpl_ptr_pos(mi_row + row, mi_col + col, tpl_stride,
+ tpl_data->tpl_stats_block_mis_log2)];
+ fprintf(pfile, "%.0f", (double)this_stats->intra_cost);
+ if (count < num_blocks - 1) fprintf(pfile, ",");
+ ++count;
+ }
+ }
+ fprintf(pfile, "\n");
+ count = 0;
+ for (int row = 0; row < mi_height; row += step) {
+ for (int col = 0; col < mi_width; col += step) {
+ TplDepStats *this_stats =
+ &tpl_stats[av1_tpl_ptr_pos(mi_row + row, mi_col + col, tpl_stride,
+ tpl_data->tpl_stats_block_mis_log2)];
+ fprintf(pfile, "%.0f", (double)this_stats->inter_cost);
+ if (count < num_blocks - 1) fprintf(pfile, ",");
+ ++count;
+ }
+ }
+ fprintf(pfile, "\n");
+ count = 0;
+ for (int row = 0; row < mi_height; row += step) {
+ for (int col = 0; col < mi_width; col += step) {
+ TplDepStats *this_stats =
+ &tpl_stats[av1_tpl_ptr_pos(mi_row + row, mi_col + col, tpl_stride,
+ tpl_data->tpl_stats_block_mis_log2)];
+ const int64_t mc_dep_delta =
+ RDCOST(tpl_frame->base_rdmult, this_stats->mc_dep_rate,
+ this_stats->mc_dep_dist);
+ fprintf(pfile, "%.0f", (double)mc_dep_delta);
+ if (count < num_blocks - 1) fprintf(pfile, ",");
+ ++count;
+ }
+ }
+ fclose(pfile);
+ } else {
+ features->sb_features.tpl_features.tpl_unit_length = tpl_data->tpl_bsize_1d;
+ features->sb_features.tpl_features.num_units = num_blocks;
+ int count = 0;
+ for (int row = 0; row < mi_height; row += step) {
+ for (int col = 0; col < mi_width; col += step) {
+ TplDepStats *this_stats =
+ &tpl_stats[av1_tpl_ptr_pos(mi_row + row, mi_col + col, tpl_stride,
+ tpl_data->tpl_stats_block_mis_log2)];
+ const int64_t mc_dep_delta =
+ RDCOST(tpl_frame->base_rdmult, this_stats->mc_dep_rate,
+ this_stats->mc_dep_dist);
+ features->sb_features.tpl_features.intra_cost[count] =
+ this_stats->intra_cost;
+ features->sb_features.tpl_features.inter_cost[count] =
+ this_stats->inter_cost;
+ features->sb_features.tpl_features.mc_dep_cost[count] = mc_dep_delta;
+ ++count;
+ }
}
}
- fprintf(pfile, "\n");
- count = 0;
- for (int row = 0; row < mi_height; row += step) {
- for (int col = 0; col < mi_width; col += step) {
- TplDepStats *this_stats =
- &tpl_stats[av1_tpl_ptr_pos(mi_row + row, mi_col + col, tpl_stride,
- tpl_data->tpl_stats_block_mis_log2)];
- fprintf(pfile, "%.0f", (double)this_stats->inter_cost);
- if (count < num_blocks - 1) fprintf(pfile, ",");
- ++count;
- }
- }
- fprintf(pfile, "\n");
- count = 0;
- for (int row = 0; row < mi_height; row += step) {
- for (int col = 0; col < mi_width; col += step) {
- TplDepStats *this_stats =
- &tpl_stats[av1_tpl_ptr_pos(mi_row + row, mi_col + col, tpl_stride,
- tpl_data->tpl_stats_block_mis_log2)];
- const int64_t mc_dep_delta =
- RDCOST(tpl_frame->base_rdmult, this_stats->mc_dep_rate,
- this_stats->mc_dep_dist);
- fprintf(pfile, "%.0f", (double)mc_dep_delta);
- if (count < num_blocks - 1) fprintf(pfile, ",");
- ++count;
- }
- }
- fclose(pfile);
}
#endif // !CONFIG_REALTIME_ONLY
@@ -4053,11 +4078,48 @@
return best_rdc;
}
+static void prepare_sb_features_before_search(
+ AV1_COMP *const cpi, ThreadData *td, int mi_row, int mi_col,
+ const BLOCK_SIZE bsize, aom_partition_features_t *features) {
+ av1_collect_motion_search_features_sb(cpi, td, mi_row, mi_col, bsize,
+ features);
+ collect_tpl_stats_sb(cpi, bsize, mi_row, mi_col, features);
+}
+
+static bool ml_partition_search(AV1_COMP *const cpi, ThreadData *td,
+ TileDataEnc *tile_data, TokenExtra **tp,
+ SIMPLE_MOTION_DATA_TREE *sms_root, int mi_row,
+ int mi_col, const BLOCK_SIZE bsize) {
+ (void)tile_data;
+ (void)tp;
+ (void)sms_root;
+ ExtPartController *const ext_part_controller = &cpi->ext_part_controller;
+ aom_partition_features_t features;
+ prepare_sb_features_before_search(cpi, td, mi_row, mi_col, bsize, &features);
+ av1_ext_part_send_features(ext_part_controller, &features);
+
+ // TODO(chengchen): implement the main body (APIs) of the the function:
+ // (1). Get partition decision from external ml model.
+ // (2). Encode with the given decision.
+ // (3). Send stats back to external ml model.
+ return true;
+}
+
bool av1_rd_partition_search(AV1_COMP *const cpi, ThreadData *td,
TileDataEnc *tile_data, TokenExtra **tp,
SIMPLE_MOTION_DATA_TREE *sms_root, int mi_row,
int mi_col, const BLOCK_SIZE bsize,
RD_STATS *best_rd_cost) {
+ if (cpi->ext_part_controller.ready) {
+ const bool valid_search = ml_partition_search(
+ cpi, td, tile_data, tp, sms_root, mi_row, mi_col, bsize);
+ if (!valid_search) {
+ assert(0 && "Invalid search from ML model, partition search failed.");
+ exit(0);
+ }
+ return true;
+ }
+
AV1_COMMON *const cm = &cpi->common;
MACROBLOCK *const x = &td->mb;
int best_idx = 0;
@@ -4224,8 +4286,9 @@
// av1_get_max_min_partition_features().
if (COLLECT_MOTION_SEARCH_FEATURE_SB && !frame_is_intra_only(cm) &&
bsize == cm->seq_params->sb_size) {
- av1_collect_motion_search_features_sb(cpi, td, mi_row, mi_col, bsize);
- collect_tpl_stats_sb(cpi, bsize, mi_row, mi_col);
+ av1_collect_motion_search_features_sb(cpi, td, mi_row, mi_col, bsize,
+ /*features=*/NULL);
+ collect_tpl_stats_sb(cpi, bsize, mi_row, mi_col, /*features=*/NULL);
}
// Update rd cost of the bound using the current multiplier.
diff --git a/av1/encoder/partition_strategy.c b/av1/encoder/partition_strategy.c
index 1825e98..facd088 100644
--- a/av1/encoder/partition_strategy.c
+++ b/av1/encoder/partition_strategy.c
@@ -2246,8 +2246,8 @@
snprintf(filename, sizeof(filename), "%s/motion_search_feature_sb%d", path,
sb_counter);
FILE *pfile = fopen(filename, "w");
- fprintf(pfile, "%d,%d,%d,%d,%d\n", mi_row, mi_col, bsize, fixed_block_size,
- num_blocks);
+ fprintf(pfile, "%d,%d,%d,%d,%d\n", mi_row, mi_col, bsize,
+ block_size_wide[fixed_block_size], num_blocks);
for (int i = 0; i < num_blocks; ++i) {
fprintf(pfile, "%d", block_sse[i]);
if (i < num_blocks - 1) fprintf(pfile, ",");
@@ -2263,7 +2263,8 @@
void av1_collect_motion_search_features_sb(AV1_COMP *const cpi, ThreadData *td,
const int mi_row, const int mi_col,
- const BLOCK_SIZE bsize) {
+ const BLOCK_SIZE bsize,
+ aom_partition_features_t *features) {
const AV1_COMMON *const cm = &cpi->common;
MACROBLOCK *const x = &td->mb;
const BLOCK_SIZE fixed_block_size = BLOCK_16X16;
@@ -2294,9 +2295,19 @@
++idx;
}
}
- write_motion_feature_to_file(cpi->oxcf.partition_info_path, cpi->sb_counter,
- block_sse, block_var, idx, bsize,
- fixed_block_size, mi_row, mi_col);
+ if (features == NULL) {
+ write_motion_feature_to_file(cpi->oxcf.partition_info_path, cpi->sb_counter,
+ block_sse, block_var, idx, bsize,
+ fixed_block_size, mi_row, mi_col);
+ } else {
+ features->sb_features.motion_features.unit_length =
+ block_size_wide[fixed_block_size];
+ features->sb_features.motion_features.num_units = idx;
+ for (int i = 0; i < idx; ++i) {
+ features->sb_features.motion_features.block_sse[i] = block_sse[i];
+ features->sb_features.motion_features.block_var[i] = block_var[i];
+ }
+ }
aom_free(block_sse);
aom_free(block_var);
diff --git a/av1/encoder/partition_strategy.h b/av1/encoder/partition_strategy.h
index 6b87bfa..23a94d7 100644
--- a/av1/encoder/partition_strategy.h
+++ b/av1/encoder/partition_strategy.h
@@ -150,7 +150,8 @@
void av1_collect_motion_search_features_sb(AV1_COMP *const cpi, ThreadData *td,
const int mi_row, const int mi_col,
- const BLOCK_SIZE bsize);
+ const BLOCK_SIZE bsize,
+ aom_partition_features_t *features);
#endif // !CONFIG_REALTIME_ONLY
// A simplified version of set_offsets meant to be used for