External partition: Handle invalid decisions.

(1). If the external model returns an invalid decision, use the
baseline's decision.
(2). Set the condition of calling the external model the same
as the baseline, such that we expect the coding result is the same
when the external model's decision is invalid.

Change-Id: I4f8efece0bdb3814cfa661a6a4cc2a3ba298c6ea
diff --git a/av1/encoder/partition_search.c b/av1/encoder/partition_search.c
index ddd8995..bfa1dfd 100644
--- a/av1/encoder/partition_search.c
+++ b/av1/encoder/partition_search.c
@@ -3283,8 +3283,8 @@
                                         PartitionSearchState *part_search_state,
                                         RD_STATS *best_rdc,
                                         unsigned int *pb_source_variance) {
-  if (av1_ext_ml_model_decision_after_none(cpi, x, sms_tree, part_search_state,
-                                           *pb_source_variance)) {
+  if (av1_ext_ml_model_decision_after_none(
+          cpi, x, sms_tree, ctx_none, part_search_state, *pb_source_variance)) {
     return;
   }
 
diff --git a/av1/encoder/partition_strategy.c b/av1/encoder/partition_strategy.c
index f4d0758..9f56538 100644
--- a/av1/encoder/partition_strategy.c
+++ b/av1/encoder/partition_strategy.c
@@ -1702,10 +1702,13 @@
 // none partition is searched.
 static void prepare_features_after_part_none(
     AV1_COMP *const cpi, AV1_COMMON *const cm, MACROBLOCK *const x,
-    SIMPLE_MOTION_DATA_TREE *sms_tree, PartitionSearchState *part_search_state,
+    SIMPLE_MOTION_DATA_TREE *sms_tree, PICK_MODE_CONTEXT *ctx_none,
+    PartitionSearchState *part_search_state,
     const unsigned int pb_source_variance,
     aom_partition_features_t *const features) {
   aom_clear_system_state();
+  const CommonModeInfoParams *const mi_params = &cm->mi_params;
+  const MACROBLOCKD *const xd = &x->e_mbd;
   PartitionBlkParams blk_params = part_search_state->part_blk_params;
   RD_STATS *this_rdc = &part_search_state->this_rdc;
   const BLOCK_SIZE bsize = blk_params.bsize;
@@ -1721,27 +1724,52 @@
   const int dc_q = (int)x->plane[0].dequant_QTX[0] >> (bit_depth - 8);
   const float dc_q_f = (float)(dc_q * dc_q) / 256.0f;
 
-  features->after_part_none.rate = rate_f;
-  features->after_part_none.dist = dist_f;
-  // TODO(chengchen): is this normalized variance?
-  features->after_part_none.source_variance = (float)pb_source_variance;
-  features->after_part_none.q = dc_q_f;
+  if (!frame_is_intra_only(cm) &&
+      (part_search_state->do_square_split ||
+       part_search_state->do_rectangular_split) &&
+      !x->e_mbd.lossless[xd->mi[0]->segment_id] && ctx_none->skippable &&
+      bsize <= cpi->sf.part_sf.use_square_partition_only_threshold &&
+      bsize > BLOCK_4X4 && cpi->sf.part_sf.ml_predict_breakout_level >= 1) {
+    features->after_part_none.rate = rate_f;
+    features->after_part_none.dist = dist_f;
+    // TODO(chengchen): is this normalized variance?
+    features->after_part_none.source_variance = (float)pb_source_variance;
+    features->after_part_none.q = dc_q_f;
+  } else {
+    features->after_part_none.rate = -1;
+    features->after_part_none.dist = -1;
+    features->after_part_none.source_variance = -1;
+    features->after_part_none.q = -1;
+  }
 
-  // features below are used to decide "terminate_partition_search"
-  // defined in av1_simple_motion_search_early_term_none().
-  float features_terminate[FEATURE_SIZE_SMS_TERM_NONE] = { 0.0f };
-  simple_motion_search_prune_part_features(
-      cpi, x, sms_tree, blk_params.mi_row, blk_params.mi_col, bsize,
-      features_terminate, FEATURE_SMS_PRUNE_PART_FLAG);
-  int f_idx = FEATURE_SIZE_SMS_PRUNE_PART;
+  if (cpi->sf.part_sf.simple_motion_search_early_term_none && cm->show_frame &&
+      !frame_is_intra_only(cm) && bsize >= BLOCK_16X16 &&
+      blk_params.mi_row_edge < mi_params->mi_rows &&
+      blk_params.mi_col_edge < mi_params->mi_cols &&
+      this_rdc->rdcost < INT64_MAX && this_rdc->rdcost >= 0 &&
+      this_rdc->rate < INT_MAX && this_rdc->rate >= 0 &&
+      (part_search_state->do_square_split ||
+       part_search_state->do_rectangular_split)) {
+    // features below are used to decide "terminate_partition_search"
+    // defined in av1_simple_motion_search_early_term_none().
+    float features_terminate[FEATURE_SIZE_SMS_TERM_NONE] = { 0.0f };
+    simple_motion_search_prune_part_features(
+        cpi, x, sms_tree, blk_params.mi_row, blk_params.mi_col, bsize,
+        features_terminate, FEATURE_SMS_PRUNE_PART_FLAG);
+    int f_idx = FEATURE_SIZE_SMS_PRUNE_PART;
 
-  features_terminate[f_idx++] = logf(1.0f + (float)this_rdc->rate);
-  features_terminate[f_idx++] = logf(1.0f + (float)this_rdc->dist);
-  features_terminate[f_idx++] = logf(1.0f + (float)this_rdc->rdcost);
+    features_terminate[f_idx++] = logf(1.0f + (float)this_rdc->rate);
+    features_terminate[f_idx++] = logf(1.0f + (float)this_rdc->dist);
+    features_terminate[f_idx++] = logf(1.0f + (float)this_rdc->rdcost);
 
-  assert(f_idx == FEATURE_SIZE_SMS_TERM_NONE);
-  for (int i = 0; i < FEATURE_SIZE_SMS_TERM_NONE; ++i) {
-    features->after_part_none.f_terminate[i] = features_terminate[i];
+    assert(f_idx == FEATURE_SIZE_SMS_TERM_NONE);
+    for (int i = 0; i < FEATURE_SIZE_SMS_TERM_NONE; ++i) {
+      features->after_part_none.f_terminate[i] = features_terminate[i];
+    }
+  } else {
+    for (int i = 0; i < FEATURE_SIZE_SMS_TERM_NONE; ++i) {
+      features->after_part_none.f_terminate[i] = -1;
+    }
   }
 }
 
@@ -1765,97 +1793,124 @@
 
   // 31 features in av1_ml_early_term_after_split()
   int feature_idx = 0;
-  features->after_part_split.f_terminate[feature_idx++] =
-      logf(1.0f + (float)dc_q / 4.0f);
-  features->after_part_split.f_terminate[feature_idx++] =
-      logf(1.0f + (float)best_rdc->rdcost / bs / bs / 1024.0f);
+  if (cpi->sf.part_sf.ml_early_term_after_part_split_level &&
+      !frame_is_intra_only(cm) &&
+      !part_search_state->terminate_partition_search &&
+      part_search_state->do_rectangular_split &&
+      (part_search_state->partition_rect_allowed[HORZ] ||
+       part_search_state->partition_rect_allowed[VERT])) {
+    features->after_part_split.f_terminate[feature_idx++] =
+        logf(1.0f + (float)dc_q / 4.0f);
+    features->after_part_split.f_terminate[feature_idx++] =
+        logf(1.0f + (float)best_rdc->rdcost / bs / bs / 1024.0f);
 
-  add_rd_feature(partition_none_rdcost, best_rdc->rdcost,
-                 features->after_part_split.f_terminate, &feature_idx);
-  add_rd_feature(partition_split_rdcost, best_rdc->rdcost,
-                 features->after_part_split.f_terminate, &feature_idx);
-
-  for (int i = 0; i < SUB_PARTITIONS_SPLIT; ++i) {
-    add_rd_feature(split_block_rdcost[i], best_rdc->rdcost,
+    add_rd_feature(partition_none_rdcost, best_rdc->rdcost,
                    features->after_part_split.f_terminate, &feature_idx);
-    int min_bw = MAX_SB_SIZE_LOG2;
-    int min_bh = MAX_SB_SIZE_LOG2;
-    get_min_bsize(sms_tree->split[i], &min_bw, &min_bh);
-    features->after_part_split.f_terminate[feature_idx++] = (float)min_bw;
-    features->after_part_split.f_terminate[feature_idx++] = (float)min_bh;
-  }
+    add_rd_feature(partition_split_rdcost, best_rdc->rdcost,
+                   features->after_part_split.f_terminate, &feature_idx);
 
-  simple_motion_search_prune_part_features(cpi, x, sms_tree, blk_params.mi_row,
-                                           blk_params.mi_col, bsize, NULL,
-                                           FEATURE_SMS_PRUNE_PART_FLAG);
+    for (int i = 0; i < SUB_PARTITIONS_SPLIT; ++i) {
+      add_rd_feature(split_block_rdcost[i], best_rdc->rdcost,
+                     features->after_part_split.f_terminate, &feature_idx);
+      int min_bw = MAX_SB_SIZE_LOG2;
+      int min_bh = MAX_SB_SIZE_LOG2;
+      get_min_bsize(sms_tree->split[i], &min_bw, &min_bh);
+      features->after_part_split.f_terminate[feature_idx++] = (float)min_bw;
+      features->after_part_split.f_terminate[feature_idx++] = (float)min_bh;
+    }
 
-  features->after_part_split.f_terminate[feature_idx++] =
-      logf(1.0f + (float)sms_tree->sms_none_feat[1]);
+    simple_motion_search_prune_part_features(
+        cpi, x, sms_tree, blk_params.mi_row, blk_params.mi_col, bsize, NULL,
+        FEATURE_SMS_PRUNE_PART_FLAG);
 
-  features->after_part_split.f_terminate[feature_idx++] =
-      logf(1.0f + (float)sms_tree->split[0]->sms_none_feat[1]);
-  features->after_part_split.f_terminate[feature_idx++] =
-      logf(1.0f + (float)sms_tree->split[1]->sms_none_feat[1]);
-  features->after_part_split.f_terminate[feature_idx++] =
-      logf(1.0f + (float)sms_tree->split[2]->sms_none_feat[1]);
-  features->after_part_split.f_terminate[feature_idx++] =
-      logf(1.0f + (float)sms_tree->split[3]->sms_none_feat[1]);
+    features->after_part_split.f_terminate[feature_idx++] =
+        logf(1.0f + (float)sms_tree->sms_none_feat[1]);
 
-  features->after_part_split.f_terminate[feature_idx++] =
-      logf(1.0f + (float)sms_tree->sms_rect_feat[1]);
-  features->after_part_split.f_terminate[feature_idx++] =
-      logf(1.0f + (float)sms_tree->sms_rect_feat[3]);
-  features->after_part_split.f_terminate[feature_idx++] =
-      logf(1.0f + (float)sms_tree->sms_rect_feat[5]);
-  features->after_part_split.f_terminate[feature_idx++] =
-      logf(1.0f + (float)sms_tree->sms_rect_feat[7]);
+    features->after_part_split.f_terminate[feature_idx++] =
+        logf(1.0f + (float)sms_tree->split[0]->sms_none_feat[1]);
+    features->after_part_split.f_terminate[feature_idx++] =
+        logf(1.0f + (float)sms_tree->split[1]->sms_none_feat[1]);
+    features->after_part_split.f_terminate[feature_idx++] =
+        logf(1.0f + (float)sms_tree->split[2]->sms_none_feat[1]);
+    features->after_part_split.f_terminate[feature_idx++] =
+        logf(1.0f + (float)sms_tree->split[3]->sms_none_feat[1]);
 
-  assert(feature_idx == 31);
+    features->after_part_split.f_terminate[feature_idx++] =
+        logf(1.0f + (float)sms_tree->sms_rect_feat[1]);
+    features->after_part_split.f_terminate[feature_idx++] =
+        logf(1.0f + (float)sms_tree->sms_rect_feat[3]);
+    features->after_part_split.f_terminate[feature_idx++] =
+        logf(1.0f + (float)sms_tree->sms_rect_feat[5]);
+    features->after_part_split.f_terminate[feature_idx++] =
+        logf(1.0f + (float)sms_tree->sms_rect_feat[7]);
 
-  // 9 features av1_ml_prune_rect_partition()
-  for (int i = 0; i < 5; i++) features->after_part_split.f_prune_rect[i] = 1.0f;
-  if (part_search_state->none_rd > 0 && part_search_state->none_rd < 1000000000)
-    features->after_part_split.f_prune_rect[0] =
-        (float)part_search_state->none_rd / (float)best_rdc->rdcost;
-  for (int i = 0; i < SUB_PARTITIONS_SPLIT; i++) {
-    if (part_search_state->split_rd[i] > 0 &&
-        part_search_state->split_rd[i] < 1000000000)
-      features->after_part_split.f_prune_rect[1 + i] =
-          (float)part_search_state->split_rd[i] / (float)best_rdc->rdcost;
-  }
-
-  // Variance ratios
-  const MACROBLOCKD *const xd = &x->e_mbd;
-  int whole_block_variance;
-  if (is_cur_buf_hbd(xd)) {
-    whole_block_variance = av1_high_get_sby_perpixel_variance(
-        cpi, &x->plane[0].src, bsize, xd->bd);
+    assert(feature_idx == 31);
   } else {
-    whole_block_variance =
-        av1_get_sby_perpixel_variance(cpi, &x->plane[0].src, bsize);
-  }
-  whole_block_variance = AOMMAX(whole_block_variance, 1);
-
-  int split_variance[SUB_PARTITIONS_SPLIT];
-  const BLOCK_SIZE subsize = get_partition_subsize(bsize, PARTITION_SPLIT);
-  struct buf_2d buf;
-  buf.stride = x->plane[0].src.stride;
-  const int bw = block_size_wide[bsize];
-  for (int i = 0; i < SUB_PARTITIONS_SPLIT; ++i) {
-    const int x_idx = (i & 1) * bw / 2;
-    const int y_idx = (i >> 1) * bw / 2;
-    buf.buf = x->plane[0].src.buf + x_idx + y_idx * buf.stride;
-    if (is_cur_buf_hbd(xd)) {
-      split_variance[i] =
-          av1_high_get_sby_perpixel_variance(cpi, &buf, subsize, xd->bd);
-    } else {
-      split_variance[i] = av1_get_sby_perpixel_variance(cpi, &buf, subsize);
+    for (int i = 0; i < 31; ++i) {
+      features->after_part_split.f_terminate[i] = -1;
     }
   }
 
-  for (int i = 0; i < SUB_PARTITIONS_SPLIT; i++)
-    features->after_part_split.f_prune_rect[5 + i] =
-        (float)split_variance[i] / (float)whole_block_variance;
+  // 9 features av1_ml_prune_rect_partition()
+  if (!cpi->sf.part_sf.ml_early_term_after_part_split_level &&
+      cpi->sf.part_sf.ml_prune_partition && !frame_is_intra_only(cm) &&
+      (part_search_state->partition_rect_allowed[HORZ] ||
+       part_search_state->partition_rect_allowed[VERT]) &&
+      !(part_search_state->prune_rect_part[HORZ] ||
+        part_search_state->prune_rect_part[VERT]) &&
+      !part_search_state->terminate_partition_search) {
+    av1_setup_src_planes(x, cpi->source, blk_params.mi_row, blk_params.mi_col,
+                         av1_num_planes(cm), bsize);
+    for (int i = 0; i < 5; i++)
+      features->after_part_split.f_prune_rect[i] = 1.0f;
+    if (part_search_state->none_rd > 0 &&
+        part_search_state->none_rd < 1000000000)
+      features->after_part_split.f_prune_rect[0] =
+          (float)part_search_state->none_rd / (float)best_rdc->rdcost;
+    for (int i = 0; i < SUB_PARTITIONS_SPLIT; i++) {
+      if (part_search_state->split_rd[i] > 0 &&
+          part_search_state->split_rd[i] < 1000000000)
+        features->after_part_split.f_prune_rect[1 + i] =
+            (float)part_search_state->split_rd[i] / (float)best_rdc->rdcost;
+    }
+
+    // Variance ratios
+    const MACROBLOCKD *const xd = &x->e_mbd;
+    int whole_block_variance;
+    if (is_cur_buf_hbd(xd)) {
+      whole_block_variance = av1_high_get_sby_perpixel_variance(
+          cpi, &x->plane[0].src, bsize, xd->bd);
+    } else {
+      whole_block_variance =
+          av1_get_sby_perpixel_variance(cpi, &x->plane[0].src, bsize);
+    }
+    whole_block_variance = AOMMAX(whole_block_variance, 1);
+
+    int split_variance[SUB_PARTITIONS_SPLIT];
+    const BLOCK_SIZE subsize = get_partition_subsize(bsize, PARTITION_SPLIT);
+    struct buf_2d buf;
+    buf.stride = x->plane[0].src.stride;
+    const int bw = block_size_wide[bsize];
+    for (int i = 0; i < SUB_PARTITIONS_SPLIT; ++i) {
+      const int x_idx = (i & 1) * bw / 2;
+      const int y_idx = (i >> 1) * bw / 2;
+      buf.buf = x->plane[0].src.buf + x_idx + y_idx * buf.stride;
+      if (is_cur_buf_hbd(xd)) {
+        split_variance[i] =
+            av1_high_get_sby_perpixel_variance(cpi, &buf, subsize, xd->bd);
+      } else {
+        split_variance[i] = av1_get_sby_perpixel_variance(cpi, &buf, subsize);
+      }
+    }
+
+    for (int i = 0; i < SUB_PARTITIONS_SPLIT; i++)
+      features->after_part_split.f_prune_rect[5 + i] =
+          (float)split_variance[i] / (float)whole_block_variance;
+  } else {
+    for (int i = 0; i < 9; i++) {
+      features->after_part_split.f_prune_rect[i] = -1;
+    }
+  }
 }
 
 // Prepare features for the external model. Specifically, features after
@@ -2015,10 +2070,8 @@
     int *partition_none_allowed, int *partition_horz_allowed,
     int *partition_vert_allowed, int *do_rectangular_split,
     int *do_square_split) {
-  // If we do not let the external partition model make decisions, return false.
-  return false;
-
   ExtPartController *const ext_part_controller = &cpi->ext_part_controller;
+  if (!ext_part_controller->ready) return false;
 
   // Setup features.
   aom_partition_features_t features;
@@ -2032,7 +2085,9 @@
 
   // Get partition decisions from the external model.
   aom_partition_decision_t decision;
-  av1_ext_part_get_partition_decision(ext_part_controller, &decision);
+  const bool valid_decision =
+      av1_ext_part_get_partition_decision(ext_part_controller, &decision);
+  if (!valid_decision) return false;
 
   // Populate decisions
   *partition_none_allowed = decision.partition_none_allowed;
@@ -2052,10 +2107,8 @@
     AV1_COMP *cpi,
     const float features_from_motion[FEATURE_SIZE_SMS_PRUNE_PART],
     int *prune_horz, int *prune_vert) {
-  // If we do not let the external partition model make decisions, return false.
-  return false;
-
   ExtPartController *const ext_part_controller = &cpi->ext_part_controller;
+  if (!ext_part_controller->ready) return false;
 
   // Setup features.
   aom_partition_features_t features;
@@ -2069,7 +2122,9 @@
 
   // Get partition decisions from the external model.
   aom_partition_decision_t decision;
-  av1_ext_part_get_partition_decision(ext_part_controller, &decision);
+  const bool valid_decision =
+      av1_ext_part_get_partition_decision(ext_part_controller, &decision);
+  if (!valid_decision) return false;
 
   // Populate decisions
   *prune_horz = decision.prune_rect_part[HORZ];
@@ -2085,11 +2140,8 @@
 // terminate_partition_search
 bool av1_ext_ml_model_decision_after_none(
     AV1_COMP *const cpi, MACROBLOCK *const x, SIMPLE_MOTION_DATA_TREE *sms_tree,
-    PartitionSearchState *part_search_state,
+    PICK_MODE_CONTEXT *ctx_none, PartitionSearchState *part_search_state,
     const unsigned int pb_source_variance) {
-  // If we do not let the external partition model make decisions, return false.
-  return false;
-
   AV1_COMMON *const cm = &cpi->common;
   ExtPartController *const ext_part_controller = &cpi->ext_part_controller;
 
@@ -2097,15 +2149,18 @@
     // Setup features.
     aom_partition_features_t features;
     features.id = FEATURE_AFTER_PART_NONE;
-    prepare_features_after_part_none(cpi, cm, x, sms_tree, part_search_state,
-                                     pb_source_variance, &features);
+    prepare_features_after_part_none(cpi, cm, x, sms_tree, ctx_none,
+                                     part_search_state, pb_source_variance,
+                                     &features);
 
     // Send necessary features to the external model.
     av1_ext_part_send_features(ext_part_controller, &features);
 
     // Get partition decisions from the external model.
     aom_partition_decision_t decision;
-    av1_ext_part_get_partition_decision(ext_part_controller, &decision);
+    const bool valid_decision =
+        av1_ext_part_get_partition_decision(ext_part_controller, &decision);
+    if (!valid_decision) return false;
 
     // Populate decisions
     part_search_state->do_square_split = decision.do_square_split;
@@ -2129,9 +2184,6 @@
     PartitionSearchState *part_search_state, RD_STATS *best_rdc,
     const int64_t partition_none_rdcost, const int64_t partition_split_rdcost,
     const int64_t *split_block_rdcost) {
-  // If we do not let the external partition model make decisions, return false.
-  return false;
-
   const AV1_COMMON *const cm = &cpi->common;
   ExtPartController *const ext_part_controller = &cpi->ext_part_controller;
 
@@ -2148,7 +2200,9 @@
 
     // Get partition decisions from the external model.
     aom_partition_decision_t decision;
-    av1_ext_part_get_partition_decision(ext_part_controller, &decision);
+    const bool valid_decision =
+        av1_ext_part_get_partition_decision(ext_part_controller, &decision);
+    if (!valid_decision) return false;
 
     // Populate decisions
     part_search_state->terminate_partition_search =
@@ -2176,9 +2230,6 @@
     int partition_horz_allowed, int partition_vert_allowed,
     int *horza_partition_allowed, int *horzb_partition_allowed,
     int *verta_partition_allowed, int *vertb_partition_allowed) {
-  // If we do not let the external partition model make decisions, return false.
-  return false;
-
   (void)pb_source_variance;
   const AV1_COMMON *const cm = &cpi->common;
   ExtPartController *const ext_part_controller = &cpi->ext_part_controller;
@@ -2202,7 +2253,9 @@
 
     // Get partition decisions from the external model.
     aom_partition_decision_t decision;
-    av1_ext_part_get_partition_decision(ext_part_controller, &decision);
+    const bool valid_decision =
+        av1_ext_part_get_partition_decision(ext_part_controller, &decision);
+    if (!valid_decision) return false;
 
     // Populate decisions
     *horza_partition_allowed = decision.horza_partition_allowed;
@@ -2226,9 +2279,6 @@
     int64_t split_rd[SUB_PARTITIONS_SPLIT], int *const partition_horz4_allowed,
     int *const partition_vert4_allowed, unsigned int pb_source_variance,
     int mi_row, int mi_col) {
-  // If we do not let the external partition model make decisions, return false.
-  return false;
-
   const AV1_COMMON *const cm = &cpi->common;
   ExtPartController *const ext_part_controller = &cpi->ext_part_controller;
 
@@ -2245,7 +2295,9 @@
 
     // Get partition decisions from the external model.
     aom_partition_decision_t decision;
-    av1_ext_part_get_partition_decision(ext_part_controller, &decision);
+    const bool valid_decision =
+        av1_ext_part_get_partition_decision(ext_part_controller, &decision);
+    if (!valid_decision) return false;
 
     // Populate decisions
     *partition_horz4_allowed = decision.partition_horz4_allowed;
diff --git a/av1/encoder/partition_strategy.h b/av1/encoder/partition_strategy.h
index f744f87..abd0067 100644
--- a/av1/encoder/partition_strategy.h
+++ b/av1/encoder/partition_strategy.h
@@ -276,7 +276,7 @@
 
 bool av1_ext_ml_model_decision_after_none(
     AV1_COMP *const cpi, MACROBLOCK *const x, SIMPLE_MOTION_DATA_TREE *sms_tree,
-    PartitionSearchState *part_search_state,
+    PICK_MODE_CONTEXT *ctx_none, PartitionSearchState *part_search_state,
     const unsigned int pb_source_variance);
 
 bool av1_ext_ml_model_decision_after_split(