ExtPart: Pass valid partition types to the ML model

First determine valide partition types for the curren block according
to the AV1 spec then pass it as a bitmask to the ML model.

Change-Id: Iebfc2c3bae79148e14e573c7e49a96dbcfdb7a46
diff --git a/aom/aom_external_partition.h b/aom/aom_external_partition.h
index 1bb31c4..e39921d 100644
--- a/aom/aom_external_partition.h
+++ b/aom/aom_external_partition.h
@@ -30,7 +30,7 @@
  * types, removing or reassigning enums, adding/removing/rearranging
  * fields to structures.
  */
-#define AOM_EXT_PART_ABI_VERSION 3
+#define AOM_EXT_PART_ABI_VERSION 4
 
 #ifdef __cplusplus
 extern "C" {
@@ -240,6 +240,14 @@
   int frame_width;                ///< Frame width
   int frame_height;               ///< Frame height
   int block_size;                 ///< As "BLOCK_SIZE" in av1/common/enums.h
+  /*!
+   * Valid partition types. A bitmask is used.  "1" represents the
+   * corresponding type is vaild. The bitmask follows the enum order for
+   * PARTITION_TYPE in "enums.h" to represent one partition type at a bit.
+   * For example, 0x01 stands for only PARTITION_NONE is valid,
+   * 0x09 (00...001001) stands for PARTITION_NONE and PARTITION_SPLIT are valid.
+   */
+  int valid_partition_types;
 } aom_partition_features_t;
 
 /*!\brief Partition decisions received from the external model.
diff --git a/av1/encoder/partition_search.c b/av1/encoder/partition_search.c
index 1bc6b32..8598bb4 100644
--- a/av1/encoder/partition_search.c
+++ b/av1/encoder/partition_search.c
@@ -4206,6 +4206,63 @@
   return true;
 }
 
+// Use a bitmask to represent the valid partition types for the current
+// block. "1" represents the corresponding partition type is vaild.
+// The least significant bit represents "PARTITION_NONE", the
+// largest significant bit represents "PARTITION_VERT_4", follow
+// the enum order for PARTITION_TYPE in "enums.h"
+static int get_valid_partition_types(
+    const AV1_COMP *const cpi,
+    const PartitionSearchState *const part_search_state,
+    const BLOCK_SIZE bsize) {
+  const PartitionCfg *const part_cfg = &cpi->oxcf.part_cfg;
+  const PartitionBlkParams blk_params = part_search_state->part_blk_params;
+  int valid_types = 0;
+  // PARTITION_NONE
+  valid_types |= (part_search_state->partition_none_allowed << 0);
+  // PARTITION_HORZ
+  valid_types |= (part_search_state->partition_rect_allowed[HORZ] << 1);
+  // PARTITION_VERT
+  valid_types |= (part_search_state->partition_rect_allowed[VERT] << 2);
+  // PARTITION_SPLIT
+  valid_types |= (part_search_state->do_square_split << 3);
+  // PARTITION_HORZ_A
+  const int ext_partition_allowed = part_search_state->do_rectangular_split &&
+                                    av1_blk_has_rows_and_cols(&blk_params);
+  const int horzab_partition_allowed =
+      ext_partition_allowed && part_cfg->enable_ab_partitions &&
+      part_search_state->partition_rect_allowed[HORZ];
+  valid_types |= (horzab_partition_allowed << 4);
+  // PARTITION_HORZ_B
+  valid_types |= (horzab_partition_allowed << 5);
+  // PARTITION_VERT_A
+  const int vertab_partition_allowed =
+      ext_partition_allowed && part_cfg->enable_ab_partitions &&
+      part_search_state->partition_rect_allowed[VERT];
+  valid_types |= (vertab_partition_allowed << 6);
+  // PARTITION_VERT_B
+  valid_types |= (vertab_partition_allowed << 7);
+  // PARTITION_HORZ_4
+  const int partition4_allowed = part_cfg->enable_1to4_partitions &&
+                                 ext_partition_allowed &&
+                                 bsize != BLOCK_128X128;
+  const int horz4_allowed =
+      partition4_allowed && part_search_state->partition_rect_allowed[HORZ] &&
+      get_plane_block_size(get_partition_subsize(bsize, PARTITION_HORZ_4),
+                           part_search_state->ss_x,
+                           part_search_state->ss_y) != BLOCK_INVALID;
+  valid_types |= (horz4_allowed << 8);
+  // PARTITION_VERT_4
+  const int vert4_allowed =
+      partition4_allowed && part_search_state->partition_rect_allowed[HORZ] &&
+      get_plane_block_size(get_partition_subsize(bsize, PARTITION_VERT_4),
+                           part_search_state->ss_x,
+                           part_search_state->ss_y) != BLOCK_INVALID;
+  valid_types |= (vert4_allowed << 9);
+
+  return valid_types;
+}
+
 static bool recursive_partition(AV1_COMP *const cpi, ThreadData *td,
                                 TileDataEnc *tile_data, TokenExtra **tp,
                                 SIMPLE_MOTION_DATA_TREE *sms_root,
@@ -4219,19 +4276,7 @@
   }
   aom_partition_decision_t partition_decision;
   do {
-    aom_partition_features_t features;
-    features.mi_row = mi_row;
-    features.mi_col = mi_col;
-    features.frame_width = cpi->frame_info.frame_width;
-    features.frame_height = cpi->frame_info.frame_height;
-    features.block_size = bsize;
-    av1_ext_part_send_features(ext_part_controller, &features);
-    const bool valid_decision = av1_ext_part_get_partition_decision(
-        ext_part_controller, &partition_decision);
-    if (!valid_decision) return false;
-    pc_tree->partitioning = partition_decision.current_decision;
     PartitionSearchState part_search_state;
-
     // Initialization of state variables used in partition search.
     // TODO(chengchen): check if there is hidden conditions that don't allow
     // all possible partition types.
@@ -4242,6 +4287,21 @@
     PartitionBlkParams blk_params = part_search_state.part_blk_params;
     if (!av1_blk_has_rows_and_cols(&blk_params))
       set_partition_cost_for_edge_blk(cm, &part_search_state);
+    const int valid_partition_types =
+        get_valid_partition_types(cpi, &part_search_state, bsize);
+
+    aom_partition_features_t features;
+    features.mi_row = mi_row;
+    features.mi_col = mi_col;
+    features.frame_width = cpi->frame_info.frame_width;
+    features.frame_height = cpi->frame_info.frame_height;
+    features.block_size = bsize;
+    features.valid_partition_types = valid_partition_types;
+    av1_ext_part_send_features(ext_part_controller, &features);
+    const bool valid_decision = av1_ext_part_get_partition_decision(
+        ext_part_controller, &partition_decision);
+    if (!valid_decision) return false;
+    pc_tree->partitioning = partition_decision.current_decision;
 
     av1_init_rd_stats(this_rdcost);
     if (partition_decision.current_decision == PARTITION_SPLIT) {