ExtPart: Pass valid partition types to the ML model
First determine valide partition types for the curren block according
to the AV1 spec then pass it as a bitmask to the ML model.
Change-Id: Iebfc2c3bae79148e14e573c7e49a96dbcfdb7a46
diff --git a/aom/aom_external_partition.h b/aom/aom_external_partition.h
index 1bb31c4..e39921d 100644
--- a/aom/aom_external_partition.h
+++ b/aom/aom_external_partition.h
@@ -30,7 +30,7 @@
* types, removing or reassigning enums, adding/removing/rearranging
* fields to structures.
*/
-#define AOM_EXT_PART_ABI_VERSION 3
+#define AOM_EXT_PART_ABI_VERSION 4
#ifdef __cplusplus
extern "C" {
@@ -240,6 +240,14 @@
int frame_width; ///< Frame width
int frame_height; ///< Frame height
int block_size; ///< As "BLOCK_SIZE" in av1/common/enums.h
+ /*!
+ * Valid partition types. A bitmask is used. "1" represents the
+ * corresponding type is vaild. The bitmask follows the enum order for
+ * PARTITION_TYPE in "enums.h" to represent one partition type at a bit.
+ * For example, 0x01 stands for only PARTITION_NONE is valid,
+ * 0x09 (00...001001) stands for PARTITION_NONE and PARTITION_SPLIT are valid.
+ */
+ int valid_partition_types;
} aom_partition_features_t;
/*!\brief Partition decisions received from the external model.
diff --git a/av1/encoder/partition_search.c b/av1/encoder/partition_search.c
index 1bc6b32..8598bb4 100644
--- a/av1/encoder/partition_search.c
+++ b/av1/encoder/partition_search.c
@@ -4206,6 +4206,63 @@
return true;
}
+// Use a bitmask to represent the valid partition types for the current
+// block. "1" represents the corresponding partition type is vaild.
+// The least significant bit represents "PARTITION_NONE", the
+// largest significant bit represents "PARTITION_VERT_4", follow
+// the enum order for PARTITION_TYPE in "enums.h"
+static int get_valid_partition_types(
+ const AV1_COMP *const cpi,
+ const PartitionSearchState *const part_search_state,
+ const BLOCK_SIZE bsize) {
+ const PartitionCfg *const part_cfg = &cpi->oxcf.part_cfg;
+ const PartitionBlkParams blk_params = part_search_state->part_blk_params;
+ int valid_types = 0;
+ // PARTITION_NONE
+ valid_types |= (part_search_state->partition_none_allowed << 0);
+ // PARTITION_HORZ
+ valid_types |= (part_search_state->partition_rect_allowed[HORZ] << 1);
+ // PARTITION_VERT
+ valid_types |= (part_search_state->partition_rect_allowed[VERT] << 2);
+ // PARTITION_SPLIT
+ valid_types |= (part_search_state->do_square_split << 3);
+ // PARTITION_HORZ_A
+ const int ext_partition_allowed = part_search_state->do_rectangular_split &&
+ av1_blk_has_rows_and_cols(&blk_params);
+ const int horzab_partition_allowed =
+ ext_partition_allowed && part_cfg->enable_ab_partitions &&
+ part_search_state->partition_rect_allowed[HORZ];
+ valid_types |= (horzab_partition_allowed << 4);
+ // PARTITION_HORZ_B
+ valid_types |= (horzab_partition_allowed << 5);
+ // PARTITION_VERT_A
+ const int vertab_partition_allowed =
+ ext_partition_allowed && part_cfg->enable_ab_partitions &&
+ part_search_state->partition_rect_allowed[VERT];
+ valid_types |= (vertab_partition_allowed << 6);
+ // PARTITION_VERT_B
+ valid_types |= (vertab_partition_allowed << 7);
+ // PARTITION_HORZ_4
+ const int partition4_allowed = part_cfg->enable_1to4_partitions &&
+ ext_partition_allowed &&
+ bsize != BLOCK_128X128;
+ const int horz4_allowed =
+ partition4_allowed && part_search_state->partition_rect_allowed[HORZ] &&
+ get_plane_block_size(get_partition_subsize(bsize, PARTITION_HORZ_4),
+ part_search_state->ss_x,
+ part_search_state->ss_y) != BLOCK_INVALID;
+ valid_types |= (horz4_allowed << 8);
+ // PARTITION_VERT_4
+ const int vert4_allowed =
+ partition4_allowed && part_search_state->partition_rect_allowed[HORZ] &&
+ get_plane_block_size(get_partition_subsize(bsize, PARTITION_VERT_4),
+ part_search_state->ss_x,
+ part_search_state->ss_y) != BLOCK_INVALID;
+ valid_types |= (vert4_allowed << 9);
+
+ return valid_types;
+}
+
static bool recursive_partition(AV1_COMP *const cpi, ThreadData *td,
TileDataEnc *tile_data, TokenExtra **tp,
SIMPLE_MOTION_DATA_TREE *sms_root,
@@ -4219,19 +4276,7 @@
}
aom_partition_decision_t partition_decision;
do {
- aom_partition_features_t features;
- features.mi_row = mi_row;
- features.mi_col = mi_col;
- features.frame_width = cpi->frame_info.frame_width;
- features.frame_height = cpi->frame_info.frame_height;
- features.block_size = bsize;
- av1_ext_part_send_features(ext_part_controller, &features);
- const bool valid_decision = av1_ext_part_get_partition_decision(
- ext_part_controller, &partition_decision);
- if (!valid_decision) return false;
- pc_tree->partitioning = partition_decision.current_decision;
PartitionSearchState part_search_state;
-
// Initialization of state variables used in partition search.
// TODO(chengchen): check if there is hidden conditions that don't allow
// all possible partition types.
@@ -4242,6 +4287,21 @@
PartitionBlkParams blk_params = part_search_state.part_blk_params;
if (!av1_blk_has_rows_and_cols(&blk_params))
set_partition_cost_for_edge_blk(cm, &part_search_state);
+ const int valid_partition_types =
+ get_valid_partition_types(cpi, &part_search_state, bsize);
+
+ aom_partition_features_t features;
+ features.mi_row = mi_row;
+ features.mi_col = mi_col;
+ features.frame_width = cpi->frame_info.frame_width;
+ features.frame_height = cpi->frame_info.frame_height;
+ features.block_size = bsize;
+ features.valid_partition_types = valid_partition_types;
+ av1_ext_part_send_features(ext_part_controller, &features);
+ const bool valid_decision = av1_ext_part_get_partition_decision(
+ ext_part_controller, &partition_decision);
+ if (!valid_decision) return false;
+ pc_tree->partitioning = partition_decision.current_decision;
av1_init_rd_stats(this_rdcost);
if (partition_decision.current_decision == PARTITION_SPLIT) {