External partition: Prepare superblock features before search

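Collect the superblock's motion search and TPL features into an
aom_partition_features_t and send them to the external partition model
before the partition search starts. When no feature struct is passed
(features == NULL), the collectors keep writing the stats to text files.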

Change-Id: I11205825ea06833d76fdc8686cc09738a941a1be
diff --git a/av1/encoder/partition_search.c b/av1/encoder/partition_search.c
index d0fd8df..bb7ea7e 100644
--- a/av1/encoder/partition_search.c
+++ b/av1/encoder/partition_search.c
@@ -78,8 +78,10 @@
 }
 
 #if !CONFIG_REALTIME_ONLY
-// Write tpl stats to text file for each super block.
+// If input |features| is NULL, write tpl stats to file for each super block.
+// Otherwise, store tpl stats to |features|.
 // The tpl stats is computed in the unit of tpl_bsize_1d (16x16).
+// When writing to text file:
 // The first row contains super block position, super block size,
 // tpl unit length, number of units in the super block.
 // The second row contains the intra prediction cost for each unit.
@@ -87,7 +89,8 @@
 // The forth row contains the motion compensated dependency cost for each unit.
 static void collect_tpl_stats_sb(const AV1_COMP *const cpi,
                                  const BLOCK_SIZE bsize, const int mi_row,
-                                 const int mi_col) {
+                                 const int mi_col,
+                                 aom_partition_features_t *features) {
   const AV1_COMMON *const cm = &cpi->common;
   GF_GROUP *gf_group = &cpi->ppi->gf_group;
   if (gf_group->update_type[cpi->gf_frame_index] == INTNL_OVERLAY_UPDATE ||
@@ -108,51 +111,73 @@
   const int row_steps = mi_height / step;
   const int num_blocks = col_steps * row_steps;
 
-  char filename[256];
-  snprintf(filename, sizeof(filename), "%s/tpl_feature_sb%d",
-           cpi->oxcf.partition_info_path, cpi->sb_counter);
-  FILE *pfile = fopen(filename, "w");
-  fprintf(pfile, "%d,%d,%d,%d,%d\n", mi_row, mi_col, bsize,
-          tpl_data->tpl_bsize_1d, num_blocks);
-  int count = 0;
-  for (int row = 0; row < mi_height; row += step) {
-    for (int col = 0; col < mi_width; col += step) {
-      TplDepStats *this_stats =
-          &tpl_stats[av1_tpl_ptr_pos(mi_row + row, mi_col + col, tpl_stride,
-                                     tpl_data->tpl_stats_block_mis_log2)];
-      fprintf(pfile, "%.0f", (double)this_stats->intra_cost);
-      if (count < num_blocks - 1) fprintf(pfile, ",");
-      ++count;
+  if (features == NULL) {
+    char filename[256];
+    snprintf(filename, sizeof(filename), "%s/tpl_feature_sb%d",
+             cpi->oxcf.partition_info_path, cpi->sb_counter);
+    FILE *pfile = fopen(filename, "w");
+    fprintf(pfile, "%d,%d,%d,%d,%d\n", mi_row, mi_col, bsize,
+            tpl_data->tpl_bsize_1d, num_blocks);
+    int count = 0;
+    for (int row = 0; row < mi_height; row += step) {
+      for (int col = 0; col < mi_width; col += step) {
+        TplDepStats *this_stats =
+            &tpl_stats[av1_tpl_ptr_pos(mi_row + row, mi_col + col, tpl_stride,
+                                       tpl_data->tpl_stats_block_mis_log2)];
+        fprintf(pfile, "%.0f", (double)this_stats->intra_cost);
+        if (count < num_blocks - 1) fprintf(pfile, ",");
+        ++count;
+      }
+    }
+    fprintf(pfile, "\n");
+    count = 0;
+    for (int row = 0; row < mi_height; row += step) {
+      for (int col = 0; col < mi_width; col += step) {
+        TplDepStats *this_stats =
+            &tpl_stats[av1_tpl_ptr_pos(mi_row + row, mi_col + col, tpl_stride,
+                                       tpl_data->tpl_stats_block_mis_log2)];
+        fprintf(pfile, "%.0f", (double)this_stats->inter_cost);
+        if (count < num_blocks - 1) fprintf(pfile, ",");
+        ++count;
+      }
+    }
+    fprintf(pfile, "\n");
+    count = 0;
+    for (int row = 0; row < mi_height; row += step) {
+      for (int col = 0; col < mi_width; col += step) {
+        TplDepStats *this_stats =
+            &tpl_stats[av1_tpl_ptr_pos(mi_row + row, mi_col + col, tpl_stride,
+                                       tpl_data->tpl_stats_block_mis_log2)];
+        const int64_t mc_dep_delta =
+            RDCOST(tpl_frame->base_rdmult, this_stats->mc_dep_rate,
+                   this_stats->mc_dep_dist);
+        fprintf(pfile, "%.0f", (double)mc_dep_delta);
+        if (count < num_blocks - 1) fprintf(pfile, ",");
+        ++count;
+      }
+    }
+    fclose(pfile);
+  } else {
+    features->sb_features.tpl_features.tpl_unit_length = tpl_data->tpl_bsize_1d;
+    features->sb_features.tpl_features.num_units = num_blocks;
+    int count = 0;
+    for (int row = 0; row < mi_height; row += step) {
+      for (int col = 0; col < mi_width; col += step) {
+        TplDepStats *this_stats =
+            &tpl_stats[av1_tpl_ptr_pos(mi_row + row, mi_col + col, tpl_stride,
+                                       tpl_data->tpl_stats_block_mis_log2)];
+        const int64_t mc_dep_delta =
+            RDCOST(tpl_frame->base_rdmult, this_stats->mc_dep_rate,
+                   this_stats->mc_dep_dist);
+        features->sb_features.tpl_features.intra_cost[count] =
+            this_stats->intra_cost;
+        features->sb_features.tpl_features.inter_cost[count] =
+            this_stats->inter_cost;
+        features->sb_features.tpl_features.mc_dep_cost[count] = mc_dep_delta;
+        ++count;
+      }
     }
   }
-  fprintf(pfile, "\n");
-  count = 0;
-  for (int row = 0; row < mi_height; row += step) {
-    for (int col = 0; col < mi_width; col += step) {
-      TplDepStats *this_stats =
-          &tpl_stats[av1_tpl_ptr_pos(mi_row + row, mi_col + col, tpl_stride,
-                                     tpl_data->tpl_stats_block_mis_log2)];
-      fprintf(pfile, "%.0f", (double)this_stats->inter_cost);
-      if (count < num_blocks - 1) fprintf(pfile, ",");
-      ++count;
-    }
-  }
-  fprintf(pfile, "\n");
-  count = 0;
-  for (int row = 0; row < mi_height; row += step) {
-    for (int col = 0; col < mi_width; col += step) {
-      TplDepStats *this_stats =
-          &tpl_stats[av1_tpl_ptr_pos(mi_row + row, mi_col + col, tpl_stride,
-                                     tpl_data->tpl_stats_block_mis_log2)];
-      const int64_t mc_dep_delta =
-          RDCOST(tpl_frame->base_rdmult, this_stats->mc_dep_rate,
-                 this_stats->mc_dep_dist);
-      fprintf(pfile, "%.0f", (double)mc_dep_delta);
-      if (count < num_blocks - 1) fprintf(pfile, ",");
-      ++count;
-    }
-  }
-  fclose(pfile);
 }
 #endif  // !CONFIG_REALTIME_ONLY
 
@@ -4053,11 +4078,65 @@
   return best_rdc;
 }
 
+// Collects the superblock features that are available before the partition
+// search starts and stores them in |features|: motion search features first,
+// then tpl stats.
+// |features| must not be NULL; both collectors treat NULL as a request to
+// write the stats to a file instead.
+static void prepare_sb_features_before_search(
+    AV1_COMP *const cpi, ThreadData *td, int mi_row, int mi_col,
+    const BLOCK_SIZE bsize, aom_partition_features_t *features) {
+  assert(features != NULL);
+  av1_collect_motion_search_features_sb(cpi, td, mi_row, mi_col, bsize,
+                                        features);
+  collect_tpl_stats_sb(cpi, bsize, mi_row, mi_col, features);
+}
+
+// Partition search driven by an external ML model for the block of size
+// |bsize| at (|mi_row|, |mi_col|). For now it only collects the pre-search
+// features and sends them to the external model; |tile_data|, |tp| and
+// |sms_root| are unused until the remaining steps are implemented.
+// Always returns true at present.
+static bool ml_partition_search(AV1_COMP *const cpi, ThreadData *td,
+                                TileDataEnc *tile_data, TokenExtra **tp,
+                                SIMPLE_MOTION_DATA_TREE *sms_root, int mi_row,
+                                int mi_col, const BLOCK_SIZE bsize) {
+  (void)tile_data;
+  (void)tp;
+  (void)sms_root;
+  ExtPartController *const ext_part_controller = &cpi->ext_part_controller;
+  // Zero-initialize: the collectors only assign the motion and tpl members,
+  // so any other part of the struct would otherwise be sent to the external
+  // model with indeterminate contents.
+  aom_partition_features_t features = { 0 };
+  prepare_sb_features_before_search(cpi, td, mi_row, mi_col, bsize, &features);
+  av1_ext_part_send_features(ext_part_controller, &features);
+
+  // TODO(chengchen): implement the main body (APIs) of the function:
+  // (1). Get partition decision from external ml model.
+  // (2). Encode with the given decision.
+  // (3). Send stats back to external ml model.
+  return true;
+}
+
 bool av1_rd_partition_search(AV1_COMP *const cpi, ThreadData *td,
                              TileDataEnc *tile_data, TokenExtra **tp,
                              SIMPLE_MOTION_DATA_TREE *sms_root, int mi_row,
                              int mi_col, const BLOCK_SIZE bsize,
                              RD_STATS *best_rd_cost) {
+  // An external partition model, when ready, takes over the search for this
+  // block; the built-in search below is skipped.
+  if (cpi->ext_part_controller.ready) {
+    const bool valid_search = ml_partition_search(
+        cpi, td, tile_data, tp, sms_root, mi_row, mi_col, bsize);
+    if (!valid_search) {
+      assert(0 && "Invalid search from ML model, partition search failed.");
+      // Exit with a failure status; exit(0) would report success.
+      exit(EXIT_FAILURE);
+    }
+    return true;
+  }
+
   AV1_COMMON *const cm = &cpi->common;
   MACROBLOCK *const x = &td->mb;
   int best_idx = 0;
@@ -4224,8 +4303,9 @@
   // av1_get_max_min_partition_features().
   if (COLLECT_MOTION_SEARCH_FEATURE_SB && !frame_is_intra_only(cm) &&
       bsize == cm->seq_params->sb_size) {
-    av1_collect_motion_search_features_sb(cpi, td, mi_row, mi_col, bsize);
-    collect_tpl_stats_sb(cpi, bsize, mi_row, mi_col);
+    av1_collect_motion_search_features_sb(cpi, td, mi_row, mi_col, bsize,
+                                          /*features=*/NULL);
+    collect_tpl_stats_sb(cpi, bsize, mi_row, mi_col, /*features=*/NULL);
   }
 
   // Update rd cost of the bound using the current multiplier.
diff --git a/av1/encoder/partition_strategy.c b/av1/encoder/partition_strategy.c
index 1825e98..facd088 100644
--- a/av1/encoder/partition_strategy.c
+++ b/av1/encoder/partition_strategy.c
@@ -2246,8 +2246,8 @@
   snprintf(filename, sizeof(filename), "%s/motion_search_feature_sb%d", path,
            sb_counter);
   FILE *pfile = fopen(filename, "w");
-  fprintf(pfile, "%d,%d,%d,%d,%d\n", mi_row, mi_col, bsize, fixed_block_size,
-          num_blocks);
+  fprintf(pfile, "%d,%d,%d,%d,%d\n", mi_row, mi_col, bsize,
+          block_size_wide[fixed_block_size], num_blocks);
   for (int i = 0; i < num_blocks; ++i) {
     fprintf(pfile, "%d", block_sse[i]);
     if (i < num_blocks - 1) fprintf(pfile, ",");
@@ -2263,7 +2263,8 @@
 
 void av1_collect_motion_search_features_sb(AV1_COMP *const cpi, ThreadData *td,
                                            const int mi_row, const int mi_col,
-                                           const BLOCK_SIZE bsize) {
+                                           const BLOCK_SIZE bsize,
+                                           aom_partition_features_t *features) {
   const AV1_COMMON *const cm = &cpi->common;
   MACROBLOCK *const x = &td->mb;
   const BLOCK_SIZE fixed_block_size = BLOCK_16X16;
@@ -2294,9 +2295,19 @@
       ++idx;
     }
   }
-  write_motion_feature_to_file(cpi->oxcf.partition_info_path, cpi->sb_counter,
-                               block_sse, block_var, idx, bsize,
-                               fixed_block_size, mi_row, mi_col);
+  if (features == NULL) {
+    write_motion_feature_to_file(cpi->oxcf.partition_info_path, cpi->sb_counter,
+                                 block_sse, block_var, idx, bsize,
+                                 fixed_block_size, mi_row, mi_col);
+  } else {
+    features->sb_features.motion_features.unit_length =
+        block_size_wide[fixed_block_size];
+    features->sb_features.motion_features.num_units = idx;
+    for (int i = 0; i < idx; ++i) {
+      features->sb_features.motion_features.block_sse[i] = block_sse[i];
+      features->sb_features.motion_features.block_var[i] = block_var[i];
+    }
+  }
 
   aom_free(block_sse);
   aom_free(block_var);
diff --git a/av1/encoder/partition_strategy.h b/av1/encoder/partition_strategy.h
index 6b87bfa..23a94d7 100644
--- a/av1/encoder/partition_strategy.h
+++ b/av1/encoder/partition_strategy.h
@@ -150,7 +150,8 @@
 
 void av1_collect_motion_search_features_sb(AV1_COMP *const cpi, ThreadData *td,
                                            const int mi_row, const int mi_col,
-                                           const BLOCK_SIZE bsize);
+                                           const BLOCK_SIZE bsize,
+                                           aom_partition_features_t *features);
 #endif  // !CONFIG_REALTIME_ONLY
 
 // A simplified version of set_offsets meant to be used for