Move skip_intra out of av1_rd_pick_inter_mode()

This is a refactoring CL.

Change-Id: Ibc1254b27976e1be3483415d297622cb4c94614e
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 4ea08e4..afb740a 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -5271,6 +5271,84 @@
   }
 }
 
+#if !CONFIG_REALTIME_ONLY
+// Prepare inter_cost and intra_cost from TPL stats, which are used as ML
+// features in intra mode pruning.
+static AOM_INLINE void calculate_cost_from_tpl_data(
+    const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bsize, int mi_row,
+    int mi_col, int64_t *inter_cost, int64_t *intra_cost) {
+  const AV1_COMMON *const cm = &cpi->common;
+  // Only consider full SB.
+  const BLOCK_SIZE sb_size = cm->seq_params->sb_size;
+  const int tpl_bsize_1d = cpi->ppi->tpl_data.tpl_bsize_1d;
+  const int len = (block_size_wide[sb_size] / tpl_bsize_1d) *
+                  (block_size_high[sb_size] / tpl_bsize_1d);
+  SuperBlockEnc *sb_enc = &x->sb_enc;
+  if (sb_enc->tpl_data_count == len) {
+    const BLOCK_SIZE tpl_bsize = convert_length_to_bsize(tpl_bsize_1d);
+    const int tpl_stride = sb_enc->tpl_stride;
+    const int tplw = mi_size_wide[tpl_bsize];
+    const int tplh = mi_size_high[tpl_bsize];
+    const int nw = mi_size_wide[bsize] / tplw;
+    const int nh = mi_size_high[bsize] / tplh;
+    if (nw >= 1 && nh >= 1) {
+      const int of_h = mi_row % mi_size_high[sb_size];
+      const int of_w = mi_col % mi_size_wide[sb_size];
+      const int start = of_h / tplh * tpl_stride + of_w / tplw;
+
+      for (int k = 0; k < nh; k++) {
+        for (int l = 0; l < nw; l++) {
+          *inter_cost += sb_enc->tpl_inter_cost[start + k * tpl_stride + l];
+          *intra_cost += sb_enc->tpl_intra_cost[start + k * tpl_stride + l];
+        }
+      }
+      *inter_cost /= nw * nh;
+      *intra_cost /= nw * nh;
+    }
+  }
+}
+#endif  // !CONFIG_REALTIME_ONLY
+
+// When the speed feature skip_intra_in_interframe > 0, enable ML model to prune
+// intra mode search.
+static AOM_INLINE void skip_intra_modes_in_interframe(
+    AV1_COMMON *const cm, struct macroblock *x, BLOCK_SIZE bsize,
+    InterModeSearchState *search_state, int64_t inter_cost, int64_t intra_cost,
+    int skip_intra_in_interframe) {
+  MACROBLOCKD *const xd = &x->e_mbd;
+  if (inter_cost >= 0 && intra_cost >= 0) {
+    aom_clear_system_state();
+    const NN_CONFIG *nn_config = (AOMMIN(cm->width, cm->height) <= 480)
+                                     ? &av1_intrap_nn_config
+                                     : &av1_intrap_hd_nn_config;
+    float nn_features[6];
+    float scores[2] = { 0.0f };
+
+    nn_features[0] = (float)search_state->best_mbmode.skip_txfm;
+    nn_features[1] = (float)mi_size_wide_log2[bsize];
+    nn_features[2] = (float)mi_size_high_log2[bsize];
+    nn_features[3] = (float)intra_cost;
+    nn_features[4] = (float)inter_cost;
+    const int ac_q = av1_ac_quant_QTX(x->qindex, 0, xd->bd);
+    const int ac_q_max = av1_ac_quant_QTX(255, 0, xd->bd);
+    nn_features[5] = (float)(ac_q_max / ac_q);
+
+    av1_nn_predict(nn_features, nn_config, 1, scores);
+    aom_clear_system_state();
+
+    // For two parameters, the max prob returned from av1_nn_softmax equals
+    // 1.0 / (1.0 + e^(-|diff_score|)). Here use scores directly to avoid the
+    // calling of av1_nn_softmax.
+    const float thresh[2] = { 1.4f, 1.4f };
+    if (scores[1] > scores[0] + thresh[skip_intra_in_interframe - 1]) {
+      search_state->intra_search_state.skip_intra_modes = 1;
+    }
+  } else if ((search_state->best_mbmode.skip_txfm) &&
+             (skip_intra_in_interframe >= 2)) {
+    search_state->intra_search_state.skip_intra_modes = 1;
+  }
+}
+
 // TODO(chiyotsai@google.com): See the todo for av1_rd_pick_intra_mode_sb.
 void av1_rd_pick_inter_mode(struct AV1_COMP *cpi, struct TileDataEnc *tile_data,
                             struct macroblock *x, struct RD_STATS *rd_cost,
@@ -5420,36 +5498,9 @@
   const int do_pruning =
       (AOMMIN(cm->width, cm->height) > 480 && cpi->speed <= 1) ? 0 : 1;
   if (do_pruning && sf->intra_sf.skip_intra_in_interframe &&
-      cpi->oxcf.algo_cfg.enable_tpl_model) {
-    // Only consider full SB.
-    const BLOCK_SIZE sb_size = cm->seq_params->sb_size;
-    const int tpl_bsize_1d = cpi->ppi->tpl_data.tpl_bsize_1d;
-    const int len = (block_size_wide[sb_size] / tpl_bsize_1d) *
-                    (block_size_high[sb_size] / tpl_bsize_1d);
-    SuperBlockEnc *sb_enc = &x->sb_enc;
-    if (sb_enc->tpl_data_count == len) {
-      const BLOCK_SIZE tpl_bsize = convert_length_to_bsize(tpl_bsize_1d);
-      const int tpl_stride = sb_enc->tpl_stride;
-      const int tplw = mi_size_wide[tpl_bsize];
-      const int tplh = mi_size_high[tpl_bsize];
-      const int nw = mi_size_wide[bsize] / tplw;
-      const int nh = mi_size_high[bsize] / tplh;
-      if (nw >= 1 && nh >= 1) {
-        const int of_h = mi_row % mi_size_high[sb_size];
-        const int of_w = mi_col % mi_size_wide[sb_size];
-        const int start = of_h / tplh * tpl_stride + of_w / tplw;
-
-        for (int k = 0; k < nh; k++) {
-          for (int l = 0; l < nw; l++) {
-            inter_cost += sb_enc->tpl_inter_cost[start + k * tpl_stride + l];
-            intra_cost += sb_enc->tpl_intra_cost[start + k * tpl_stride + l];
-          }
-        }
-        inter_cost /= nw * nh;
-        intra_cost /= nw * nh;
-      }
-    }
-  }
+      cpi->oxcf.algo_cfg.enable_tpl_model)
+    calculate_cost_from_tpl_data(cpi, x, bsize, mi_row, mi_col, &inter_cost,
+                                 &intra_cost);
 #endif  // !CONFIG_REALTIME_ONLY
 
   // Initialize best mode stats for winner mode processing
@@ -5634,39 +5685,9 @@
   const unsigned int src_var_thresh_intra_skip = 1;
   const int skip_intra_in_interframe = sf->intra_sf.skip_intra_in_interframe;
   if (skip_intra_in_interframe &&
-      (x->source_variance > src_var_thresh_intra_skip)) {
-    if (inter_cost >= 0 && intra_cost >= 0) {
-      aom_clear_system_state();
-      const NN_CONFIG *nn_config = (AOMMIN(cm->width, cm->height) <= 480)
-                                       ? &av1_intrap_nn_config
-                                       : &av1_intrap_hd_nn_config;
-      float nn_features[6];
-      float scores[2] = { 0.0f };
-
-      nn_features[0] = (float)search_state.best_mbmode.skip_txfm;
-      nn_features[1] = (float)mi_size_wide_log2[bsize];
-      nn_features[2] = (float)mi_size_high_log2[bsize];
-      nn_features[3] = (float)intra_cost;
-      nn_features[4] = (float)inter_cost;
-      const int ac_q = av1_ac_quant_QTX(x->qindex, 0, xd->bd);
-      const int ac_q_max = av1_ac_quant_QTX(255, 0, xd->bd);
-      nn_features[5] = (float)(ac_q_max / ac_q);
-
-      av1_nn_predict(nn_features, nn_config, 1, scores);
-      aom_clear_system_state();
-
-      // For two parameters, the max prob returned from av1_nn_softmax equals
-      // 1.0 / (1.0 + e^(-|diff_score|)). Here use scores directly to avoid the
-      // calling of av1_nn_softmax.
-      const float thresh[2] = { 1.4f, 1.4f };
-      if (scores[1] > scores[0] + thresh[skip_intra_in_interframe - 1]) {
-        search_state.intra_search_state.skip_intra_modes = 1;
-      }
-    } else if ((search_state.best_mbmode.skip_txfm) &&
-               (skip_intra_in_interframe >= 2)) {
-      search_state.intra_search_state.skip_intra_modes = 1;
-    }
-  }
+      (x->source_variance > src_var_thresh_intra_skip))
+    skip_intra_modes_in_interframe(cm, x, bsize, &search_state, inter_cost,
+                                   intra_cost, skip_intra_in_interframe);
 
   const unsigned int intra_ref_frame_cost = ref_costs_single[INTRA_FRAME];
   search_intra_modes_in_interframe(&search_state, cpi, x, rd_cost, bsize, ctx,