Extend av1_ml_predict_breakout() to hbd encoding

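This CL enables the use of av1_ml_predict_breakout() for high bit-depth
(hbd) encoding, so that partition search can be terminated early based
on the PARTITION_NONE results. A new speed feature,
ml_predict_breakout_level, scales the breakout threshold, and the dc_q
feature is shifted down by (bit_depth - 8) before it is used.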

         Instruction Count        BD-Rate Loss(%)
cpu-used    Reduction(%)     avg.psnr  ovr.psnr   ssim
   1          2.934           0.0312    0.0313    0.0061
   2          2.231           0.0218    0.0222    0.0078
   3          0.558           0.0201    0.0223    0.0410
   4          0.667           0.0122    0.0125    0.0060
   5          0.285           0.0140    0.0218    0.0391
   6          0.127           0.0147    0.0149    0.0105

STATS_CHANGED for hbd encoding

Change-Id: I9f7847d734e755fa0cb17622be7aa35a72725c67
diff --git a/av1/encoder/partition_search.c b/av1/encoder/partition_search.c
index 412e761..a56aba1 100644
--- a/av1/encoder/partition_search.c
+++ b/av1/encoder/partition_search.c
@@ -3051,10 +3051,10 @@
       !x->e_mbd.lossless[xd->mi[0]->segment_id] && ctx_none->skippable) {
     const int use_ml_based_breakout =
         bsize <= cpi->sf.part_sf.use_square_partition_only_threshold &&
-        bsize > BLOCK_4X4 && xd->bd == 8;
+        bsize > BLOCK_4X4 && cpi->sf.part_sf.ml_predict_breakout_level >= 1;
     if (use_ml_based_breakout) {
-      if (av1_ml_predict_breakout(cpi, bsize, x, this_rdc,
-                                  *pb_source_variance)) {
+      if (av1_ml_predict_breakout(cpi, bsize, x, this_rdc, *pb_source_variance,
+                                  xd->bd)) {
         part_search_state->do_square_split = 0;
         part_search_state->do_rectangular_split = 0;
       }
diff --git a/av1/encoder/partition_strategy.c b/av1/encoder/partition_strategy.c
index 3f12e1f..686623b 100644
--- a/av1/encoder/partition_strategy.c
+++ b/av1/encoder/partition_strategy.c
@@ -1230,7 +1230,7 @@
 int av1_ml_predict_breakout(const AV1_COMP *const cpi, BLOCK_SIZE bsize,
                             const MACROBLOCK *const x,
                             const RD_STATS *const rd_stats,
-                            unsigned int pb_source_variance) {
+                            unsigned int pb_source_variance, int bit_depth) {
   const NN_CONFIG *nn_config = NULL;
   int thresh = 0;
   switch (bsize) {
@@ -1258,6 +1258,12 @@
   }
   if (!nn_config || thresh < 0) return 0;
 
+  const float ml_predict_breakout_thresh_scale[3] = { 1.15f, 1.05f, 1.0f };
+  thresh =
+      (int)((float)thresh *
+            ml_predict_breakout_thresh_scale[cpi->sf.part_sf
+                                                 .ml_predict_breakout_level]);
+
   // Generate feature values.
   float features[FEATURES];
   int feature_index = 0;
@@ -1275,7 +1281,7 @@
 
   features[feature_index++] = (float)pb_source_variance;
 
-  const int dc_q = (int)x->plane[0].dequant_QTX[0];
+  const int dc_q = (int)x->plane[0].dequant_QTX[0] >> (bit_depth - 8);
   features[feature_index++] = (float)(dc_q * dc_q) / 256.0f;
   assert(feature_index == FEATURES);
 
diff --git a/av1/encoder/partition_strategy.h b/av1/encoder/partition_strategy.h
index 1c83551..0527a94 100644
--- a/av1/encoder/partition_strategy.h
+++ b/av1/encoder/partition_strategy.h
@@ -157,7 +157,7 @@
 int av1_ml_predict_breakout(const AV1_COMP *const cpi, BLOCK_SIZE bsize,
                             const MACROBLOCK *const x,
                             const RD_STATS *const rd_stats,
-                            unsigned int pb_source_variance);
+                            unsigned int pb_source_variance, int bit_depth);
 
 // The first round of partition pruning determined before any partition
 // has been tested. The decisions will be updated and passed back
diff --git a/av1/encoder/speed_features.c b/av1/encoder/speed_features.c
index f0cab24..47b0fc3 100644
--- a/av1/encoder/speed_features.c
+++ b/av1/encoder/speed_features.c
@@ -325,6 +325,7 @@
       boosted || gf_group->update_type[gf_group->index] == INTNL_ARF_UPDATE;
   const int allow_screen_content_tools =
       cm->features.allow_screen_content_tools;
+  const int use_hbd = cpi->oxcf.use_highbitdepth;
   if (!cpi->oxcf.tile_cfg.enable_large_scale_tile) {
     sf->hl_sf.high_precision_mv_usage = LAST_MV_DATA;
   }
@@ -339,6 +340,7 @@
   sf->part_sf.ml_prune_rect_partition = 1;
   sf->part_sf.prune_ext_partition_types_search_level = 1;
   sf->part_sf.simple_motion_search_prune_rect = 1;
+  sf->part_sf.ml_predict_breakout_level = use_hbd ? 0 : 2;
 
   sf->inter_sf.disable_wedge_search_var_thresh = 0;
   // TODO(debargha): Test, tweak and turn on either 1 or 2
@@ -388,6 +390,7 @@
     // simple_motion_search_split in partition search function and set the
     // speed feature accordingly
     sf->part_sf.simple_motion_search_split = allow_screen_content_tools ? 1 : 2;
+    sf->part_sf.ml_predict_breakout_level = use_hbd ? 1 : 2;
 
     sf->mv_sf.exhaustive_searches_thresh <<= 1;
     sf->mv_sf.obmc_full_pixel_search_level = 1;
@@ -544,6 +547,7 @@
     sf->part_sf.simple_motion_search_reduce_search_steps = 4;
     sf->part_sf.prune_ab_partition_using_split_info = 1;
     sf->part_sf.early_term_after_none_split = 1;
+    sf->part_sf.ml_predict_breakout_level = 2;
 
     sf->inter_sf.alt_ref_search_fp = 1;
     sf->inter_sf.txfm_rd_gate_level = boosted ? 0 : 4;
@@ -1014,6 +1018,7 @@
   part_sf->prune_4_partition_using_split_info = 0;
   part_sf->prune_ab_partition_using_split_info = 0;
   part_sf->early_term_after_none_split = 0;
+  part_sf->ml_predict_breakout_level = 0;
 }
 
 static AOM_INLINE void init_mv_sf(MV_SPEED_FEATURES *mv_sf) {
diff --git a/av1/encoder/speed_features.h b/av1/encoder/speed_features.h
index 7a6051a..20e8056 100644
--- a/av1/encoder/speed_features.h
+++ b/av1/encoder/speed_features.h
@@ -473,6 +473,11 @@
   // Terminate partition search for child partition,
   // when NONE and SPLIT partition rd_costs are INT64_MAX.
   int early_term_after_none_split;
+
+  // Level that scales the threshold used by av1_ml_predict_breakout(). Lower
+  // levels use a more conservative threshold. Partition search skips the ML
+  // breakout at level 0; level 2 leaves the lbd thresholds unadjusted.
+  int ml_predict_breakout_level;
 } PARTITION_SPEED_FEATURES;
 
 typedef struct MV_SPEED_FEATURES {