ExtPart: Unit test for recursive ml partition

We compare the psnr values to two encoding runs.

The first run is encoding with the baseline, by restricting block sizes
to the super block size.

The second run is partition decisions from a toy model that provides
matching partition decisions as the first run.

The psnr values differ by a tiny amount (less than 0.001 db).

Change-Id: I677e616f284c744a9d7735f5c4777a57de0ac816
diff --git a/aom/aom_external_partition.h b/aom/aom_external_partition.h
index daf8fe4..a5220ac 100644
--- a/aom/aom_external_partition.h
+++ b/aom/aom_external_partition.h
@@ -235,6 +235,7 @@
   int mi_col;                    /**< Mi_col position of the block */
   int frame_width;               /**< Frame width */
   int frame_height;              /**< Frame height */
+  int block_size;                /**< As "BLOCK_SIZE" in av1/common/enums.h */
 } aom_partition_features_t;
 
 /*!\brief Partition decisions received from the external model.
diff --git a/av1/encoder/partition_search.c b/av1/encoder/partition_search.c
index 8a1e2aa..5e8e5f5 100644
--- a/av1/encoder/partition_search.c
+++ b/av1/encoder/partition_search.c
@@ -4161,6 +4161,7 @@
   features.mi_col = mi_col;
   features.frame_width = cpi->frame_info.frame_width;
   features.frame_height = cpi->frame_info.frame_height;
+  features.block_size = bsize;
   av1_ext_part_send_features(ext_part_controller, &features);
   PC_TREE *pc_tree;
 
@@ -4210,6 +4211,13 @@
   }
   aom_partition_decision_t partition_decision;
   do {
+    aom_partition_features_t features;
+    features.mi_row = mi_row;
+    features.mi_col = mi_col;
+    features.frame_width = cpi->frame_info.frame_width;
+    features.frame_height = cpi->frame_info.frame_height;
+    features.block_size = bsize;
+    av1_ext_part_send_features(ext_part_controller, &features);
     const bool valid_decision = av1_ext_part_get_partition_decision(
         ext_part_controller, &partition_decision);
     if (!valid_decision) return false;
@@ -4221,6 +4229,11 @@
     // all possible partition types.
     init_partition_search_state_params(x, cpi, &part_search_state, mi_row,
                                        mi_col, bsize);
+    // Override partition costs at the edges of the frame in the same
+    // way as in read_partition (see decodeframe.c).
+    PartitionBlkParams blk_params = part_search_state.part_blk_params;
+    if (!av1_blk_has_rows_and_cols(&blk_params))
+      set_partition_cost_for_edge_blk(cm, &part_search_state);
 
     av1_init_rd_stats(this_rdcost);
     if (partition_decision.current_decision == PARTITION_SPLIT) {
@@ -4233,6 +4246,9 @@
           pc_tree->split[i] = av1_alloc_pc_tree_node(subsize);
         pc_tree->split[i]->index = i;
       }
+      const int orig_rdmult = x->rdmult;
+      setup_block_rdmult(cpi, x, mi_row, mi_col, bsize, NO_AQ, NULL);
+      (void)orig_rdmult;
       // TODO(chengchen): check boundary conditions
       // top-left
       recursive_partition(cpi, td, tile_data, tp, sms_root, pc_tree->split[0],
@@ -4251,11 +4267,13 @@
                           mi_col + mi_size_wide[subsize], subsize,
                           &split_rdc[3]);
       this_rdcost->rate += part_search_state.partition_cost[PARTITION_SPLIT];
+      // problem is here, the rdmult is different from the rdmult in sub block.
       for (int i = 0; i < SUB_PARTITIONS_SPLIT; ++i) {
         this_rdcost->rate += split_rdc[i].rate;
         this_rdcost->dist += split_rdc[i].dist;
         av1_rd_cost_update(x->rdmult, this_rdcost);
       }
+      x->rdmult = orig_rdmult;
     } else {
       *this_rdcost = rd_search_for_fixed_partition(
           cpi, td, tile_data, tp, sms_root, mi_row, mi_col, bsize, pc_tree);
@@ -4291,6 +4309,11 @@
   ExtPartController *const ext_part_controller = &cpi->ext_part_controller;
   aom_partition_features_t features;
   prepare_sb_features_before_search(cpi, td, mi_row, mi_col, bsize, &features);
+  features.mi_row = mi_row;
+  features.mi_col = mi_col;
+  features.frame_width = cpi->frame_info.frame_width;
+  features.frame_height = cpi->frame_info.frame_height;
+  features.block_size = bsize;
   av1_ext_part_send_features(ext_part_controller, &features);
   PC_TREE *pc_tree;
   pc_tree = av1_alloc_pc_tree_node(bsize);
diff --git a/test/av1_external_partition_test.cc b/test/av1_external_partition_test.cc
index 20329de..d6a7507 100644
--- a/test/av1_external_partition_test.cc
+++ b/test/av1_external_partition_test.cc
@@ -42,6 +42,7 @@
   int mi_col;
   int frame_width;
   int frame_height;
+  BLOCK_SIZE block_size;
 } ToyModel;
 
 // Note:
@@ -71,6 +72,7 @@
   toy_model->mi_col = part_features->mi_col;
   toy_model->frame_width = part_features->frame_width;
   toy_model->frame_height = part_features->frame_height;
+  toy_model->block_size = static_cast<BLOCK_SIZE>(part_features->block_size);
   return AOM_EXT_PART_OK;
 }
 
@@ -181,6 +183,43 @@
   return AOM_EXT_PART_OK;
 }
 
+aom_ext_part_status_t ext_part_get_partition_decision_recursive(
+    aom_ext_part_model_t ext_part_model,
+    aom_partition_decision_t *ext_part_decision) {
+  ext_part_decision->current_decision = PARTITION_NONE;
+  ext_part_decision->is_final_decision = 1;
+  ToyModel *toy_model = static_cast<ToyModel *>(ext_part_model);
+  // Note: super block size is fixed to BLOCK_64X64 for the
+  // input video. It is determined inside the encoder, see the
+  // check in "ext_part_create_model".
+  const int is_last_sb_col =
+      toy_model->mi_col * 4 + 64 > toy_model->frame_width;
+  const int is_last_sb_row =
+      toy_model->mi_row * 4 + 64 > toy_model->frame_height;
+  if (is_last_sb_row && is_last_sb_col) {
+    if (block_size_wide[toy_model->block_size] == 64) {
+      ext_part_decision->current_decision = PARTITION_SPLIT;
+    } else {
+      ext_part_decision->current_decision = PARTITION_NONE;
+    }
+  } else if (is_last_sb_row) {
+    if (block_size_wide[toy_model->block_size] == 64) {
+      ext_part_decision->current_decision = PARTITION_SPLIT;
+    } else {
+      ext_part_decision->current_decision = PARTITION_NONE;
+    }
+  } else if (is_last_sb_col) {
+    if (block_size_wide[toy_model->block_size] == 64) {
+      ext_part_decision->current_decision = PARTITION_SPLIT;
+    } else {
+      ext_part_decision->current_decision = PARTITION_NONE;
+    }
+  } else {
+    ext_part_decision->current_decision = PARTITION_NONE;
+  }
+  return AOM_EXT_PART_OK;
+}
+
 aom_ext_part_status_t ext_part_send_partition_stats(
     aom_ext_part_model_t ext_part_model,
     const aom_partition_stats_t *ext_part_stats) {
@@ -240,39 +279,79 @@
 
   void SetPartitionControlMode(int mode) { partition_control_mode_ = mode; }
 
+  void SetDecisionMode(aom_ext_part_decision_mode_t mode) {
+    decision_mode_ = mode;
+  }
+
   virtual void PreEncodeFrameHook(::libaom_test::VideoSource *video,
                                   ::libaom_test::Encoder *encoder) {
     if (video->frame() == 0) {
-      aom_ext_part_funcs_t ext_part_funcs;
-      ext_part_funcs.priv = reinterpret_cast<void *>(&test_data_);
-      ext_part_funcs.decision_mode = WHOLE_TREE_DECISION;
-      ext_part_funcs.create_model = ext_part_create_model;
-      ext_part_funcs.send_features = ext_part_send_features;
-      ext_part_funcs.get_partition_decision =
-          ext_part_get_partition_decision_whole_tree;
-      ext_part_funcs.send_partition_stats = ext_part_send_partition_stats;
-      ext_part_funcs.delete_model = ext_part_delete_model;
+      if (decision_mode_ == WHOLE_TREE_DECISION) {
+        aom_ext_part_funcs_t ext_part_funcs;
+        ext_part_funcs.priv = reinterpret_cast<void *>(&test_data_);
+        ext_part_funcs.decision_mode = WHOLE_TREE_DECISION;
+        ext_part_funcs.create_model = ext_part_create_model;
+        ext_part_funcs.send_features = ext_part_send_features;
+        ext_part_funcs.get_partition_decision =
+            ext_part_get_partition_decision_whole_tree;
+        ext_part_funcs.send_partition_stats = ext_part_send_partition_stats;
+        ext_part_funcs.delete_model = ext_part_delete_model;
 
-      encoder->Control(AOME_SET_CPUUSED, cpu_used_);
-      encoder->Control(AOME_SET_ENABLEAUTOALTREF, 1);
-      if (use_external_partition_) {
-        encoder->Control(AV1E_SET_EXTERNAL_PARTITION, &ext_part_funcs);
-      }
-      if (partition_control_mode_ == -1) {
-        encoder->Control(AV1E_SET_MAX_PARTITION_SIZE, 128);
-        encoder->Control(AV1E_SET_MIN_PARTITION_SIZE, 4);
-      } else {
-        switch (partition_control_mode_) {
-          case 1:
-            encoder->Control(AV1E_SET_MAX_PARTITION_SIZE, 64);
-            encoder->Control(AV1E_SET_MIN_PARTITION_SIZE, 64);
-            break;
-          case 2:
-            encoder->Control(AV1E_SET_MAX_PARTITION_SIZE, 4);
-            encoder->Control(AV1E_SET_MIN_PARTITION_SIZE, 4);
-            break;
-          default: assert(0 && "Invalid partition control mode."); break;
+        encoder->Control(AOME_SET_CPUUSED, cpu_used_);
+        encoder->Control(AOME_SET_ENABLEAUTOALTREF, 1);
+        if (use_external_partition_) {
+          encoder->Control(AV1E_SET_EXTERNAL_PARTITION, &ext_part_funcs);
         }
+        if (partition_control_mode_ == -1) {
+          encoder->Control(AV1E_SET_MAX_PARTITION_SIZE, 128);
+          encoder->Control(AV1E_SET_MIN_PARTITION_SIZE, 4);
+        } else {
+          switch (partition_control_mode_) {
+            case 1:
+              encoder->Control(AV1E_SET_MAX_PARTITION_SIZE, 64);
+              encoder->Control(AV1E_SET_MIN_PARTITION_SIZE, 64);
+              break;
+            case 2:
+              encoder->Control(AV1E_SET_MAX_PARTITION_SIZE, 4);
+              encoder->Control(AV1E_SET_MIN_PARTITION_SIZE, 4);
+              break;
+            default: assert(0 && "Invalid partition control mode."); break;
+          }
+        }
+      } else if (decision_mode_ == RECURSIVE_DECISION) {
+        aom_ext_part_funcs_t ext_part_funcs;
+        ext_part_funcs.priv = reinterpret_cast<void *>(&test_data_);
+        ext_part_funcs.decision_mode = RECURSIVE_DECISION;
+        ext_part_funcs.create_model = ext_part_create_model;
+        ext_part_funcs.send_features = ext_part_send_features;
+        ext_part_funcs.get_partition_decision =
+            ext_part_get_partition_decision_recursive;
+        ext_part_funcs.send_partition_stats = ext_part_send_partition_stats;
+        ext_part_funcs.delete_model = ext_part_delete_model;
+
+        encoder->Control(AOME_SET_CPUUSED, cpu_used_);
+        encoder->Control(AOME_SET_ENABLEAUTOALTREF, 1);
+        if (use_external_partition_) {
+          encoder->Control(AV1E_SET_EXTERNAL_PARTITION, &ext_part_funcs);
+        }
+        if (partition_control_mode_ == -1) {
+          encoder->Control(AV1E_SET_MAX_PARTITION_SIZE, 128);
+          encoder->Control(AV1E_SET_MIN_PARTITION_SIZE, 4);
+        } else {
+          switch (partition_control_mode_) {
+            case 1:
+              encoder->Control(AV1E_SET_MAX_PARTITION_SIZE, 64);
+              encoder->Control(AV1E_SET_MIN_PARTITION_SIZE, 64);
+              break;
+            case 2:
+              encoder->Control(AV1E_SET_MAX_PARTITION_SIZE, 4);
+              encoder->Control(AV1E_SET_MIN_PARTITION_SIZE, 4);
+              break;
+            default: assert(0 && "Invalid partition control mode."); break;
+          }
+        }
+      } else {
+        assert(0 && "Invalid decision mode.");
       }
     }
   }
@@ -285,6 +364,7 @@
   bool use_external_partition_ = false;
   TestData test_data_;
   int partition_control_mode_ = -1;
+  aom_ext_part_decision_mode_t decision_mode_;
 };
 
 // Encode twice and expect the same psnr value.
@@ -298,17 +378,39 @@
   ::libaom_test::Y4mVideoSource video("paris_352_288_30.y4m", 0, kFrameNum);
   SetExternalPartition(false);
   SetPartitionControlMode(2);
+  SetDecisionMode(WHOLE_TREE_DECISION);
   ASSERT_NO_FATAL_FAILURE(RunLoop(&video));
   const double psnr = GetAveragePsnr();
 
   SetExternalPartition(true);
   SetPartitionControlMode(2);
+  SetDecisionMode(WHOLE_TREE_DECISION);
   ASSERT_NO_FATAL_FAILURE(RunLoop(&video));
   const double psnr2 = GetAveragePsnr();
 
+  printf("psnr %.5f\n", psnr);
+
   EXPECT_DOUBLE_EQ(psnr, psnr2);
 }
 
+TEST_P(ExternalPartitionTestAPI, RecursivePartition) {
+  ::libaom_test::Y4mVideoSource video("paris_352_288_30.y4m", 0, kFrameNum);
+  SetExternalPartition(false);
+  SetPartitionControlMode(1);
+  SetDecisionMode(RECURSIVE_DECISION);
+  ASSERT_NO_FATAL_FAILURE(RunLoop(&video));
+  const double psnr = GetAveragePsnr();
+
+  SetExternalPartition(true);
+  SetPartitionControlMode(1);
+  SetDecisionMode(RECURSIVE_DECISION);
+  ASSERT_NO_FATAL_FAILURE(RunLoop(&video));
+  const double psnr2 = GetAveragePsnr();
+
+  const double psnr_thresh = 0.001;
+  EXPECT_NEAR(psnr, psnr2, psnr_thresh);
+}
+
 AV1_INSTANTIATE_TEST_SUITE(ExternalPartitionTestAPI,
                            ::testing::Values(::libaom_test::kTwoPassGood),
                            ::testing::Values(4));  // cpu_used