External partition: collect features and pass to external model
Use the external model's decision to override the encoder's own partition
decision.
Change-Id: I64b93165f14defe5d42be9f35cd189d6a6d979f8
diff --git a/av1/encoder/partition_search.c b/av1/encoder/partition_search.c
index d238809..ddd8995 100644
--- a/av1/encoder/partition_search.c
+++ b/av1/encoder/partition_search.c
@@ -3283,6 +3283,11 @@
PartitionSearchState *part_search_state,
RD_STATS *best_rdc,
unsigned int *pb_source_variance) {
+ if (av1_ext_ml_model_decision_after_none(cpi, x, sms_tree, part_search_state,
+ *pb_source_variance)) {
+ return;
+ }
+
const AV1_COMMON *const cm = &cpi->common;
MACROBLOCKD *const xd = &x->e_mbd;
PartitionBlkParams blk_params = part_search_state->part_blk_params;
@@ -3349,6 +3354,12 @@
AV1_COMP *const cpi, MACROBLOCK *x, SIMPLE_MOTION_DATA_TREE *sms_tree,
PartitionSearchState *part_search_state, RD_STATS *best_rdc,
int64_t part_none_rd, int64_t part_split_rd) {
+ if (av1_ext_ml_model_decision_after_split(
+ cpi, x, sms_tree, part_search_state, best_rdc, part_none_rd,
+ part_split_rd, part_search_state->split_rd)) {
+ return;
+ }
+
const AV1_COMMON *const cm = &cpi->common;
PartitionBlkParams blk_params = part_search_state->part_blk_params;
const int mi_row = blk_params.mi_row;
diff --git a/av1/encoder/partition_strategy.c b/av1/encoder/partition_strategy.c
index 1f67d9b..351bc13 100644
--- a/av1/encoder/partition_strategy.c
+++ b/av1/encoder/partition_strategy.c
@@ -35,6 +35,33 @@
int mi_row, int mi_col, BLOCK_SIZE bsize, float *features,
int features_to_get);
+static bool ext_ml_model_decision_before_none(
+ AV1_COMP *cpi, const float features_from_motion[FEATURE_SIZE_SMS_SPLIT],
+ int *partition_none_allowed, int *partition_horz_allowed,
+ int *partition_vert_allowed, int *do_rectangular_split,
+ int *do_square_split);
+
+static bool ext_ml_model_decision_before_none_part2(
+ AV1_COMP *cpi,
+ const float features_from_motion[FEATURE_SIZE_SMS_PRUNE_PART],
+ int *prune_horz, int *prune_vert);
+
+static bool ext_ml_model_decision_after_rect(
+ AV1_COMP *cpi, const MACROBLOCK *x, const PC_TREE *pc_tree,
+ BLOCK_SIZE bsize, int pb_source_variance, int64_t best_rdcost,
+ int64_t rect_part_rd[NUM_RECT_PARTS][SUB_PARTITIONS_RECT],
+ int64_t split_rd[SUB_PARTITIONS_SPLIT], int ext_partition_allowed,
+ int partition_horz_allowed, int partition_vert_allowed,
+ int *horza_partition_allowed, int *horzb_partition_allowed,
+ int *verta_partition_allowed, int *vertb_partition_allowed);
+
+static bool ext_ml_model_decision_after_part_ab(
+ AV1_COMP *const cpi, MACROBLOCK *const x, BLOCK_SIZE bsize, int part_ctx,
+ int64_t best_rd, int64_t rect_part_rd[NUM_RECT_PARTS][SUB_PARTITIONS_RECT],
+ int64_t split_rd[SUB_PARTITIONS_SPLIT], int *const partition_horz4_allowed,
+ int *const partition_vert4_allowed, unsigned int pb_source_variance,
+ int mi_row, int mi_col);
+
static INLINE int convert_bsize_to_idx(BLOCK_SIZE bsize) {
switch (bsize) {
case BLOCK_128X128: return 0;
@@ -284,6 +311,15 @@
simple_motion_search_prune_part_features(cpi, x, sms_tree, mi_row, mi_col,
bsize, features,
FEATURE_SMS_SPLIT_MODEL_FLAG);
+
+ // Note: the features are intentionally left unnormalized here, so that all
+ // features collected and passed to the external model stay consistent.
+ if (ext_ml_model_decision_before_none(
+ cpi, features, partition_none_allowed, partition_horz_allowed,
+ partition_vert_allowed, do_rectangular_split, do_square_split)) {
+ return;
+ }
+
for (int idx = 0; idx < FEATURE_SIZE_SMS_SPLIT; idx++) {
features[idx] = (features[idx] - ml_mean[idx]) / ml_std[idx];
}
@@ -543,6 +579,19 @@
simple_motion_search_prune_part_features(cpi, x, sms_tree, mi_row, mi_col,
bsize, features,
FEATURE_SMS_PRUNE_PART_FLAG);
+
+ // Note: the features are intentionally left unnormalized here, so that all
+ // features collected and passed to the external model stay consistent.
+ if (cpi->sf.part_sf.simple_motion_search_prune_rect &&
+ !frame_is_intra_only(cm) &&
+ (partition_horz_allowed || partition_vert_allowed) &&
+ bsize >= BLOCK_8X8 && !av1_superres_scaled(cm)) {
+ if (ext_ml_model_decision_before_none_part2(cpi, features, prune_horz,
+ prune_vert)) {
+ return;
+ }
+ }
+
for (int f_idx = 0; f_idx < FEATURE_SIZE_SMS_PRUNE_PART; f_idx++) {
features[f_idx] = (features[f_idx] - ml_mean[f_idx]) / ml_std[f_idx];
}
@@ -1120,12 +1169,17 @@
#define LABELS 4
// Use a ML model to predict if horz4 and vert4 should be considered.
void av1_ml_prune_4_partition(
- const AV1_COMP *const cpi, MACROBLOCK *const x, BLOCK_SIZE bsize,
- int part_ctx, int64_t best_rd,
- int64_t rect_part_rd[NUM_RECT_PARTS][SUB_PARTITIONS_RECT],
+ AV1_COMP *const cpi, MACROBLOCK *const x, BLOCK_SIZE bsize, int part_ctx,
+ int64_t best_rd, int64_t rect_part_rd[NUM_RECT_PARTS][SUB_PARTITIONS_RECT],
int64_t split_rd[SUB_PARTITIONS_SPLIT], int *const partition_horz4_allowed,
int *const partition_vert4_allowed, unsigned int pb_source_variance,
int mi_row, int mi_col) {
+ if (ext_ml_model_decision_after_part_ab(
+ cpi, x, bsize, part_ctx, best_rd, rect_part_rd, split_rd,
+ partition_horz4_allowed, partition_vert4_allowed, pb_source_variance,
+ mi_row, mi_col))
+ return;
+
if (best_rd >= 1000000000) return;
int64_t *horz_rd = rect_part_rd[HORZ];
int64_t *vert_rd = rect_part_rd[VERT];
@@ -1502,7 +1556,7 @@
}
void av1_prune_ab_partitions(
- const AV1_COMP *cpi, const MACROBLOCK *x, const PC_TREE *pc_tree,
+ AV1_COMP *cpi, const MACROBLOCK *x, const PC_TREE *pc_tree,
BLOCK_SIZE bsize, int pb_source_variance, int64_t best_rdcost,
int64_t rect_part_rd[NUM_RECT_PARTS][SUB_PARTITIONS_RECT],
int64_t split_rd[SUB_PARTITIONS_SPLIT],
@@ -1510,6 +1564,14 @@
int partition_horz_allowed, int partition_vert_allowed,
int *horza_partition_allowed, int *horzb_partition_allowed,
int *verta_partition_allowed, int *vertb_partition_allowed) {
+ if (ext_ml_model_decision_after_rect(
+ cpi, x, pc_tree, bsize, pb_source_variance, best_rdcost, rect_part_rd,
+ split_rd, ext_partition_allowed, partition_horz_allowed,
+ partition_vert_allowed, horza_partition_allowed,
+ horzb_partition_allowed, verta_partition_allowed,
+ vertb_partition_allowed))
+ return;
+
int64_t *horz_rd = rect_part_rd[HORZ];
int64_t *vert_rd = rect_part_rd[VERT];
const PartitionCfg *const part_cfg = &cpi->oxcf.part_cfg;
@@ -1636,4 +1698,557 @@
}
}
+// Prepare features for the external model. Specifically, features after
+// none partition is searched.
+static void prepare_features_after_part_none(
+ AV1_COMP *const cpi, AV1_COMMON *const cm, MACROBLOCK *const x,
+ SIMPLE_MOTION_DATA_TREE *sms_tree, PartitionSearchState *part_search_state,
+ const unsigned int pb_source_variance,
+ aom_partition_features_t *const features) {
+ aom_clear_system_state();
+ PartitionBlkParams blk_params = part_search_state->part_blk_params;
+ RD_STATS *this_rdc = &part_search_state->this_rdc;
+ const BLOCK_SIZE bsize = blk_params.bsize;
+ const int bit_depth = cm->seq_params->bit_depth;
+
+ // 4 features defined in av1_ml_predict_breakout().
+ const int num_pels_log2 = num_pels_log2_lookup[bsize];
+ float rate_f = (float)AOMMIN(this_rdc->rate, INT_MAX);
+ rate_f = ((float)x->rdmult / 128.0f / 512.0f / (float)(1 << num_pels_log2)) *
+ rate_f;
+ const float dist_f =
+ (float)(AOMMIN(this_rdc->dist, INT_MAX) >> num_pels_log2);
+ const int dc_q = (int)x->plane[0].dequant_QTX[0] >> (bit_depth - 8);
+ const float dc_q_f = (float)(dc_q * dc_q) / 256.0f;
+
+ features->after_part_none.rate = rate_f;
+ features->after_part_none.dist = dist_f;
+ // TODO(chengchen): is this normalized variance?
+ features->after_part_none.source_variance = (float)pb_source_variance;
+ features->after_part_none.q = dc_q_f;
+
+ // features below are used to decide "terminate_partition_search"
+ // defined in av1_simple_motion_search_early_term_none().
+ float features_terminate[FEATURE_SIZE_SMS_TERM_NONE] = { 0.0f };
+ simple_motion_search_prune_part_features(
+ cpi, x, sms_tree, blk_params.mi_row, blk_params.mi_col, bsize,
+ features_terminate, FEATURE_SMS_PRUNE_PART_FLAG);
+ int f_idx = FEATURE_SIZE_SMS_PRUNE_PART;
+
+ features_terminate[f_idx++] = logf(1.0f + (float)this_rdc->rate);
+ features_terminate[f_idx++] = logf(1.0f + (float)this_rdc->dist);
+ features_terminate[f_idx++] = logf(1.0f + (float)this_rdc->rdcost);
+
+ assert(f_idx == FEATURE_SIZE_SMS_TERM_NONE);
+ for (int i = 0; i < FEATURE_SIZE_SMS_TERM_NONE; ++i) {
+ features->after_part_none.f_terminate[i] = features_terminate[i];
+ }
+}
+
+// Prepare features for the external model. Specifically, features after
+// split partition is searched.
+static void prepare_features_after_split(
+ AV1_COMP *const cpi, MACROBLOCK *x, SIMPLE_MOTION_DATA_TREE *sms_tree,
+ PartitionSearchState *part_search_state, RD_STATS *best_rdc,
+ const int64_t partition_none_rdcost, const int64_t partition_split_rdcost,
+ const int64_t *split_block_rdcost,
+ aom_partition_features_t *const features) {
+ aom_clear_system_state();
+ const AV1_COMMON *const cm = &cpi->common;
+ PartitionBlkParams blk_params = part_search_state->part_blk_params;
+ const BLOCK_SIZE bsize = blk_params.bsize;
+ const int bs = block_size_wide[bsize];
+ const int bit_depth = cm->seq_params->bit_depth;
+ const int dc_q =
+ av1_dc_quant_QTX(cm->quant_params.base_qindex, 0, bit_depth) >>
+ (bit_depth - 8);
+
+ // 31 features in av1_ml_early_term_after_split()
+ int feature_idx = 0;
+ features->after_part_split.f_terminate[feature_idx++] =
+ logf(1.0f + (float)dc_q / 4.0f);
+ features->after_part_split.f_terminate[feature_idx++] =
+ logf(1.0f + (float)best_rdc->rdcost / bs / bs / 1024.0f);
+
+ add_rd_feature(partition_none_rdcost, best_rdc->rdcost,
+ features->after_part_split.f_terminate, &feature_idx);
+ add_rd_feature(partition_split_rdcost, best_rdc->rdcost,
+ features->after_part_split.f_terminate, &feature_idx);
+
+ for (int i = 0; i < SUB_PARTITIONS_SPLIT; ++i) {
+ add_rd_feature(split_block_rdcost[i], best_rdc->rdcost,
+ features->after_part_split.f_terminate, &feature_idx);
+ int min_bw = MAX_SB_SIZE_LOG2;
+ int min_bh = MAX_SB_SIZE_LOG2;
+ get_min_bsize(sms_tree->split[i], &min_bw, &min_bh);
+ features->after_part_split.f_terminate[feature_idx++] = (float)min_bw;
+ features->after_part_split.f_terminate[feature_idx++] = (float)min_bh;
+ }
+
+ simple_motion_search_prune_part_features(cpi, x, sms_tree, blk_params.mi_row,
+ blk_params.mi_col, bsize, NULL,
+ FEATURE_SMS_PRUNE_PART_FLAG);
+
+ features->after_part_split.f_terminate[feature_idx++] =
+ logf(1.0f + (float)sms_tree->sms_none_feat[1]);
+
+ features->after_part_split.f_terminate[feature_idx++] =
+ logf(1.0f + (float)sms_tree->split[0]->sms_none_feat[1]);
+ features->after_part_split.f_terminate[feature_idx++] =
+ logf(1.0f + (float)sms_tree->split[1]->sms_none_feat[1]);
+ features->after_part_split.f_terminate[feature_idx++] =
+ logf(1.0f + (float)sms_tree->split[2]->sms_none_feat[1]);
+ features->after_part_split.f_terminate[feature_idx++] =
+ logf(1.0f + (float)sms_tree->split[3]->sms_none_feat[1]);
+
+ features->after_part_split.f_terminate[feature_idx++] =
+ logf(1.0f + (float)sms_tree->sms_rect_feat[1]);
+ features->after_part_split.f_terminate[feature_idx++] =
+ logf(1.0f + (float)sms_tree->sms_rect_feat[3]);
+ features->after_part_split.f_terminate[feature_idx++] =
+ logf(1.0f + (float)sms_tree->sms_rect_feat[5]);
+ features->after_part_split.f_terminate[feature_idx++] =
+ logf(1.0f + (float)sms_tree->sms_rect_feat[7]);
+
+ assert(feature_idx == 31);
+
+ // 9 features in av1_ml_prune_rect_partition()
+ for (int i = 0; i < 5; i++) features->after_part_split.f_prune_rect[i] = 1.0f;
+ if (part_search_state->none_rd > 0 && part_search_state->none_rd < 1000000000)
+ features->after_part_split.f_prune_rect[0] =
+ (float)part_search_state->none_rd / (float)best_rdc->rdcost;
+ for (int i = 0; i < SUB_PARTITIONS_SPLIT; i++) {
+ if (part_search_state->split_rd[i] > 0 &&
+ part_search_state->split_rd[i] < 1000000000)
+ features->after_part_split.f_prune_rect[1 + i] =
+ (float)part_search_state->split_rd[i] / (float)best_rdc->rdcost;
+ }
+
+ // Variance ratios
+ const MACROBLOCKD *const xd = &x->e_mbd;
+ int whole_block_variance;
+ if (is_cur_buf_hbd(xd)) {
+ whole_block_variance = av1_high_get_sby_perpixel_variance(
+ cpi, &x->plane[0].src, bsize, xd->bd);
+ } else {
+ whole_block_variance =
+ av1_get_sby_perpixel_variance(cpi, &x->plane[0].src, bsize);
+ }
+ whole_block_variance = AOMMAX(whole_block_variance, 1);
+
+ int split_variance[SUB_PARTITIONS_SPLIT];
+ const BLOCK_SIZE subsize = get_partition_subsize(bsize, PARTITION_SPLIT);
+ struct buf_2d buf;
+ buf.stride = x->plane[0].src.stride;
+ const int bw = block_size_wide[bsize];
+ for (int i = 0; i < SUB_PARTITIONS_SPLIT; ++i) {
+ const int x_idx = (i & 1) * bw / 2;
+ const int y_idx = (i >> 1) * bw / 2;
+ buf.buf = x->plane[0].src.buf + x_idx + y_idx * buf.stride;
+ if (is_cur_buf_hbd(xd)) {
+ split_variance[i] =
+ av1_high_get_sby_perpixel_variance(cpi, &buf, subsize, xd->bd);
+ } else {
+ split_variance[i] = av1_get_sby_perpixel_variance(cpi, &buf, subsize);
+ }
+ }
+
+ for (int i = 0; i < SUB_PARTITIONS_SPLIT; i++)
+ features->after_part_split.f_prune_rect[5 + i] =
+ (float)split_variance[i] / (float)whole_block_variance;
+}
+
+// Prepare features for the external model. Specifically, features after
+// rectangular partition is searched.
+static void prepare_features_after_part_rect(
+ const AV1_COMP *cpi, const MACROBLOCK *x, const PC_TREE *pc_tree,
+ BLOCK_SIZE bsize, int64_t best_rdcost,
+ int64_t rect_part_rd[NUM_RECT_PARTS][SUB_PARTITIONS_RECT],
+ int64_t split_rd[SUB_PARTITIONS_SPLIT],
+ aom_partition_features_t *const features) {
+ (void)cpi;
+ if (bsize < BLOCK_8X8) return;
+
+ int64_t *horz_rd = rect_part_rd[HORZ];
+ int64_t *vert_rd = rect_part_rd[VERT];
+
+ int feature_index = 0;
+ features->after_part_rect.f[feature_index++] = (float)pc_tree->partitioning;
+ features->after_part_rect.f[feature_index++] =
+ (float)get_unsigned_bits(x->source_variance);
+ const int rdcost = (int)AOMMIN(INT_MAX, best_rdcost);
+ int sub_block_rdcost[8] = { 0 };
+ int rd_index = 0;
+ for (int i = 0; i < SUB_PARTITIONS_RECT; ++i) {
+ if (horz_rd[i] > 0 && horz_rd[i] < 1000000000)
+ sub_block_rdcost[rd_index] = (int)horz_rd[i];
+ ++rd_index;
+ }
+ for (int i = 0; i < SUB_PARTITIONS_RECT; ++i) {
+ if (vert_rd[i] > 0 && vert_rd[i] < 1000000000)
+ sub_block_rdcost[rd_index] = (int)vert_rd[i];
+ ++rd_index;
+ }
+ for (int i = 0; i < SUB_PARTITIONS_SPLIT; ++i) {
+ if (split_rd[i] > 0 && split_rd[i] < 1000000000)
+ sub_block_rdcost[rd_index] = (int)split_rd[i];
+ ++rd_index;
+ }
+ for (int i = 0; i < 8; ++i) {
+ // Ratio between the sub-block RD and the whole-block RD.
+ float rd_ratio = 1.0f;
+ if (sub_block_rdcost[i] > 0 && sub_block_rdcost[i] < rdcost)
+ rd_ratio = (float)sub_block_rdcost[i] / (float)rdcost;
+ features->after_part_rect.f[feature_index++] = rd_ratio;
+ }
+ assert(feature_index == 10);
+}
+
+// Prepare features for the external model. Specifically, features after
+// ab partition is searched.
+static void prepare_features_after_part_ab(
+ const AV1_COMP *const cpi, MACROBLOCK *const x, BLOCK_SIZE bsize,
+ int part_ctx, int64_t best_rd,
+ int64_t rect_part_rd[NUM_RECT_PARTS][SUB_PARTITIONS_RECT],
+ int64_t split_rd[SUB_PARTITIONS_SPLIT], unsigned int pb_source_variance,
+ int mi_row, int mi_col, aom_partition_features_t *const features) {
+ int64_t *horz_rd = rect_part_rd[HORZ];
+ int64_t *vert_rd = rect_part_rd[VERT];
+
+ aom_clear_system_state();
+
+ // Generate features.
+ int feature_index = 0;
+ features->after_part_ab.f[feature_index++] = (float)part_ctx;
+ features->after_part_ab.f[feature_index++] =
+ (float)get_unsigned_bits(pb_source_variance);
+
+ const int rdcost = (int)AOMMIN(INT_MAX, best_rd);
+ int sub_block_rdcost[8] = { 0 };
+ int rd_index = 0;
+ for (int i = 0; i < SUB_PARTITIONS_RECT; ++i) {
+ if (horz_rd[i] > 0 && horz_rd[i] < 1000000000)
+ sub_block_rdcost[rd_index] = (int)horz_rd[i];
+ ++rd_index;
+ }
+ for (int i = 0; i < SUB_PARTITIONS_RECT; ++i) {
+ if (vert_rd[i] > 0 && vert_rd[i] < 1000000000)
+ sub_block_rdcost[rd_index] = (int)vert_rd[i];
+ ++rd_index;
+ }
+ for (int i = 0; i < SUB_PARTITIONS_SPLIT; ++i) {
+ if (split_rd[i] > 0 && split_rd[i] < 1000000000)
+ sub_block_rdcost[rd_index] = (int)split_rd[i];
+ ++rd_index;
+ }
+ for (int i = 0; i < 8; ++i) {
+ // Ratio between the sub-block RD and the whole-block RD.
+ float rd_ratio = 1.0f;
+ if (sub_block_rdcost[i] > 0 && sub_block_rdcost[i] < rdcost)
+ rd_ratio = (float)sub_block_rdcost[i] / (float)rdcost;
+ features->after_part_ab.f[feature_index++] = rd_ratio;
+ }
+
+ // Get variance of the 1:4 and 4:1 sub-blocks.
+ unsigned int horz_4_source_var[SUB_PARTITIONS_PART4] = { 0 };
+ unsigned int vert_4_source_var[SUB_PARTITIONS_PART4] = { 0 };
+ {
+ BLOCK_SIZE horz_4_bs = get_partition_subsize(bsize, PARTITION_HORZ_4);
+ BLOCK_SIZE vert_4_bs = get_partition_subsize(bsize, PARTITION_VERT_4);
+ av1_setup_src_planes(x, cpi->source, mi_row, mi_col,
+ av1_num_planes(&cpi->common), bsize);
+ const int src_stride = x->plane[0].src.stride;
+ uint8_t *src = x->plane[0].src.buf;
+ const MACROBLOCKD *const xd = &x->e_mbd;
+
+ struct buf_2d horz_4_src, vert_4_src;
+ horz_4_src.stride = src_stride;
+ vert_4_src.stride = src_stride;
+
+ for (int i = 0; i < SUB_PARTITIONS_PART4; ++i) {
+ horz_4_src.buf = src + i * block_size_high[horz_4_bs] * src_stride;
+ vert_4_src.buf = src + i * block_size_wide[vert_4_bs];
+
+ if (is_cur_buf_hbd(xd)) {
+ horz_4_source_var[i] = av1_high_get_sby_perpixel_variance(
+ cpi, &horz_4_src, horz_4_bs, xd->bd);
+ vert_4_source_var[i] = av1_high_get_sby_perpixel_variance(
+ cpi, &vert_4_src, vert_4_bs, xd->bd);
+ } else {
+ horz_4_source_var[i] =
+ av1_get_sby_perpixel_variance(cpi, &horz_4_src, horz_4_bs);
+ vert_4_source_var[i] =
+ av1_get_sby_perpixel_variance(cpi, &vert_4_src, vert_4_bs);
+ }
+ }
+ }
+
+ const float denom = (float)(pb_source_variance + 1);
+ const float low_b = 0.1f;
+ const float high_b = 10.0f;
+ for (int i = 0; i < SUB_PARTITIONS_PART4; ++i) {
+ // Ratio between the 4:1 sub-block variance and the whole-block variance.
+ float var_ratio = (float)(horz_4_source_var[i] + 1) / denom;
+ if (var_ratio < low_b) var_ratio = low_b;
+ if (var_ratio > high_b) var_ratio = high_b;
+ features->after_part_ab.f[feature_index++] = var_ratio;
+ }
+ for (int i = 0; i < SUB_PARTITIONS_PART4; ++i) {
+ // Ratio between the 1:4 sub-block variance and the whole-block variance.
+ float var_ratio = (float)(vert_4_source_var[i] + 1) / denom;
+ if (var_ratio < low_b) var_ratio = low_b;
+ if (var_ratio > high_b) var_ratio = high_b;
+ features->after_part_ab.f[feature_index++] = var_ratio;
+ }
+ assert(feature_index == 18);
+}
+
+// If the external partition model is used, we let it determine partition
+// decisions before partition none. Specifically, these parameters:
+// partition_none_allowed
+// partition_horz_allowed
+// partition_vert_allowed
+// do_rectangular_split
+// do_square_split
+static bool ext_ml_model_decision_before_none(
+ AV1_COMP *cpi, const float features_from_motion[FEATURE_SIZE_SMS_SPLIT],
+ int *partition_none_allowed, int *partition_horz_allowed,
+ int *partition_vert_allowed, int *do_rectangular_split,
+ int *do_square_split) {
+ // External partition model decisions are disabled here: always return false.
+ return false;
+
+ ExtPartController *const ext_part_controller = &cpi->ext_part_controller;
+
+ // Setup features.
+ aom_partition_features_t features;
+ for (int i = 0; i < FEATURE_SIZE_SMS_SPLIT; ++i) {
+ features.before_part_none.f[i] = features_from_motion[i];
+ }
+
+ // Send necessary features to the external model.
+ av1_ext_part_send_features(ext_part_controller, &features);
+
+ // Get partition decisions from the external model.
+ aom_partition_decision_t decision;
+ av1_ext_part_get_partition_decision(ext_part_controller, &decision);
+
+ // Populate decisions
+ *partition_none_allowed = decision.partition_none_allowed;
+ *partition_horz_allowed = decision.partition_rect_allowed[HORZ];
+ *partition_vert_allowed = decision.partition_rect_allowed[VERT];
+ *do_rectangular_split = decision.do_rectangular_split;
+ *do_square_split = decision.do_square_split;
+
+ return true;
+}
+
+// If the external partition model is used, we let it determine partition
+// decisions before partition none. Specifically, these parameters:
+// prune_horz
+// prune_vert
+static bool ext_ml_model_decision_before_none_part2(
+ AV1_COMP *cpi,
+ const float features_from_motion[FEATURE_SIZE_SMS_PRUNE_PART],
+ int *prune_horz, int *prune_vert) {
+ // External partition model decisions are disabled here: always return false.
+ return false;
+
+ ExtPartController *const ext_part_controller = &cpi->ext_part_controller;
+
+ // Setup features.
+ aom_partition_features_t features;
+ for (int i = 0; i < FEATURE_SIZE_SMS_PRUNE_PART; ++i) {
+ features.before_part_none.f[i] = features_from_motion[i];
+ }
+
+ // Send necessary features to the external model.
+ av1_ext_part_send_features(ext_part_controller, &features);
+
+ // Get partition decisions from the external model.
+ aom_partition_decision_t decision;
+ av1_ext_part_get_partition_decision(ext_part_controller, &decision);
+
+ // Populate decisions
+ *prune_horz = decision.prune_rect_part[HORZ];
+ *prune_vert = decision.prune_rect_part[VERT];
+
+ return true;
+}
+
+// If the external partition model is used, we let it determine partition
+// decisions after none partition. Specifically, these parameters:
+// do_square_split
+// do_rectangular_split
+// terminate_partition_search
+bool av1_ext_ml_model_decision_after_none(
+ AV1_COMP *const cpi, MACROBLOCK *const x, SIMPLE_MOTION_DATA_TREE *sms_tree,
+ PartitionSearchState *part_search_state,
+ const unsigned int pb_source_variance) {
+ // External partition model decisions are disabled here: always return false.
+ return false;
+
+ AV1_COMMON *const cm = &cpi->common;
+ ExtPartController *const ext_part_controller = &cpi->ext_part_controller;
+
+ if (!frame_is_intra_only(cm) && ext_part_controller->ready) {
+ // Setup features.
+ aom_partition_features_t features;
+ prepare_features_after_part_none(cpi, cm, x, sms_tree, part_search_state,
+ pb_source_variance, &features);
+
+ // Send necessary features to the external model.
+ av1_ext_part_send_features(ext_part_controller, &features);
+
+ // Get partition decisions from the external model.
+ aom_partition_decision_t decision;
+ av1_ext_part_get_partition_decision(ext_part_controller, &decision);
+
+ // Populate decisions
+ part_search_state->do_square_split = decision.do_square_split;
+ part_search_state->do_rectangular_split = decision.do_rectangular_split;
+ part_search_state->terminate_partition_search =
+ decision.terminate_partition_search;
+
+ return true;
+ }
+
+ return false;
+}
+
+// If the external partition model is used, we let it determine partition
+ // decisions after split partition. Specifically, these parameters:
+// terminate_partition_search
+// prune_rect_part[HORZ]
+// prune_rect_part[VERT]
+bool av1_ext_ml_model_decision_after_split(
+ AV1_COMP *const cpi, MACROBLOCK *x, SIMPLE_MOTION_DATA_TREE *sms_tree,
+ PartitionSearchState *part_search_state, RD_STATS *best_rdc,
+ const int64_t partition_none_rdcost, const int64_t partition_split_rdcost,
+ const int64_t *split_block_rdcost) {
+ // External partition model decisions are disabled here: always return false.
+ return false;
+
+ const AV1_COMMON *const cm = &cpi->common;
+ ExtPartController *const ext_part_controller = &cpi->ext_part_controller;
+
+ if (!frame_is_intra_only(cm) && cpi->ext_part_controller.ready) {
+ // Setup features.
+ aom_partition_features_t features;
+ prepare_features_after_split(cpi, x, sms_tree, part_search_state, best_rdc,
+ partition_none_rdcost, partition_split_rdcost,
+ split_block_rdcost, &features);
+
+ // Send necessary features to the external model.
+ av1_ext_part_send_features(ext_part_controller, &features);
+
+ // Get partition decisions from the external model.
+ aom_partition_decision_t decision;
+ av1_ext_part_get_partition_decision(ext_part_controller, &decision);
+
+ // Populate decisions
+ part_search_state->terminate_partition_search =
+ decision.terminate_partition_search;
+ part_search_state->prune_rect_part[HORZ] = decision.prune_rect_part[0];
+ part_search_state->prune_rect_part[VERT] = decision.prune_rect_part[1];
+
+ return true;
+ }
+
+ return false;
+}
+
+// If the external partition model is used, we let it determine partition
+// decisions after rectangular partition. Specifically, these parameters:
+// horza_partition_allowed
+// horzb_partition_allowed
+// verta_partition_allowed
+// vertb_partition_allowed
+static bool ext_ml_model_decision_after_rect(
+ AV1_COMP *cpi, const MACROBLOCK *x, const PC_TREE *pc_tree,
+ BLOCK_SIZE bsize, int pb_source_variance, int64_t best_rdcost,
+ int64_t rect_part_rd[NUM_RECT_PARTS][SUB_PARTITIONS_RECT],
+ int64_t split_rd[SUB_PARTITIONS_SPLIT], int ext_partition_allowed,
+ int partition_horz_allowed, int partition_vert_allowed,
+ int *horza_partition_allowed, int *horzb_partition_allowed,
+ int *verta_partition_allowed, int *vertb_partition_allowed) {
+ // External partition model decisions are disabled here: always return false.
+ return false;
+
+ (void)pb_source_variance;
+ const AV1_COMMON *const cm = &cpi->common;
+ ExtPartController *const ext_part_controller = &cpi->ext_part_controller;
+ const PartitionCfg *const part_cfg = &cpi->oxcf.part_cfg;
+ // The standard AB partitions are allowed initially if ext-partition-types are
+ // allowed.
+ int ab_partition_allowed =
+ ext_partition_allowed & part_cfg->enable_ab_partitions;
+
+ if (!frame_is_intra_only(cm) && ext_part_controller->ready &&
+ partition_horz_allowed && partition_vert_allowed &&
+ ab_partition_allowed) {
+ // Setup features.
+ aom_partition_features_t features;
+ prepare_features_after_part_rect(cpi, x, pc_tree, bsize, best_rdcost,
+ rect_part_rd, split_rd, &features);
+
+ // Send necessary features to the external model.
+ av1_ext_part_send_features(ext_part_controller, &features);
+
+ // Get partition decisions from the external model.
+ aom_partition_decision_t decision;
+ av1_ext_part_get_partition_decision(ext_part_controller, &decision);
+
+ // Populate decisions
+ *horza_partition_allowed = decision.horza_partition_allowed;
+ *horzb_partition_allowed = decision.horzb_partition_allowed;
+ *verta_partition_allowed = decision.verta_partition_allowed;
+ *vertb_partition_allowed = decision.vertb_partition_allowed;
+
+ return true;
+ }
+
+ return false;
+}
+
+// If the external partition model is used, we let it determine partition
+// decisions after AB partition. Specifically, these parameters:
+// partition_vert4_allowed
+// partition_horz4_allowed
+static bool ext_ml_model_decision_after_part_ab(
+ AV1_COMP *const cpi, MACROBLOCK *const x, BLOCK_SIZE bsize, int part_ctx,
+ int64_t best_rd, int64_t rect_part_rd[NUM_RECT_PARTS][SUB_PARTITIONS_RECT],
+ int64_t split_rd[SUB_PARTITIONS_SPLIT], int *const partition_horz4_allowed,
+ int *const partition_vert4_allowed, unsigned int pb_source_variance,
+ int mi_row, int mi_col) {
+ // External partition model decisions are disabled here: always return false.
+ return false;
+
+ const AV1_COMMON *const cm = &cpi->common;
+ ExtPartController *const ext_part_controller = &cpi->ext_part_controller;
+
+ if (!frame_is_intra_only(cm) && ext_part_controller->ready) {
+ // Setup features.
+ aom_partition_features_t features;
+ prepare_features_after_part_ab(cpi, x, bsize, part_ctx, best_rd,
+ rect_part_rd, split_rd, pb_source_variance,
+ mi_row, mi_col, &features);
+
+ // Send necessary features to the external model.
+ av1_ext_part_send_features(ext_part_controller, &features);
+
+ // Get partition decisions from the external model.
+ aom_partition_decision_t decision;
+ av1_ext_part_get_partition_decision(ext_part_controller, &decision);
+
+ // Populate decisions
+ *partition_horz4_allowed = decision.partition_horz4_allowed;
+ *partition_vert4_allowed = decision.partition_vert4_allowed;
+
+ return true;
+ }
+
+ return false;
+}
+
#endif // !CONFIG_REALTIME_ONLY
diff --git a/av1/encoder/partition_strategy.h b/av1/encoder/partition_strategy.h
index 621edae..f744f87 100644
--- a/av1/encoder/partition_strategy.h
+++ b/av1/encoder/partition_strategy.h
@@ -98,9 +98,8 @@
// Use a ML model to predict if horz4 and vert4 should be considered.
void av1_ml_prune_4_partition(
- const AV1_COMP *const cpi, MACROBLOCK *const x, BLOCK_SIZE bsize,
- int part_ctx, int64_t best_rd,
- int64_t rect_part_rd[NUM_RECT_PARTS][SUB_PARTITIONS_RECT],
+ AV1_COMP *const cpi, MACROBLOCK *const x, BLOCK_SIZE bsize, int part_ctx,
+ int64_t best_rd, int64_t rect_part_rd[NUM_RECT_PARTS][SUB_PARTITIONS_RECT],
int64_t split_rd[SUB_PARTITIONS_SPLIT], int *const partition_horz4_allowed,
int *const partition_vert4_allowed, unsigned int pb_source_variance,
int mi_row, int mi_col);
@@ -135,7 +134,7 @@
// Prune out AB partitions based on rd decisions made from testing the
// basic partitions.
void av1_prune_ab_partitions(
- const AV1_COMP *cpi, const MACROBLOCK *x, const PC_TREE *pc_tree,
+ AV1_COMP *cpi, const MACROBLOCK *x, const PC_TREE *pc_tree,
BLOCK_SIZE bsize, int pb_source_variance, int64_t best_rdcost,
int64_t rect_part_rd[NUM_RECT_PARTS][SUB_PARTITIONS_RECT],
int64_t split_rd[SUB_PARTITIONS_SPLIT],
@@ -274,5 +273,16 @@
sb_enc->min_partition_size);
}
}
+
+bool av1_ext_ml_model_decision_after_none(
+ AV1_COMP *const cpi, MACROBLOCK *const x, SIMPLE_MOTION_DATA_TREE *sms_tree,
+ PartitionSearchState *part_search_state,
+ const unsigned int pb_source_variance);
+
+bool av1_ext_ml_model_decision_after_split(
+ AV1_COMP *const cpi, MACROBLOCK *x, SIMPLE_MOTION_DATA_TREE *sms_tree,
+ PartitionSearchState *part_search_state, RD_STATS *best_rdc,
+ const int64_t partition_none_rdcost, const int64_t partition_split_rdcost,
+ const int64_t *split_block_rdcost);
#endif // !CONFIG_REALTIME_ONLY
#endif // AOM_AV1_ENCODER_PARTITION_STRATEGY_H_