Move skip_intra out of av1_rd_pick_inter_mode()
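This is a refactoring CL: the TPL cost computation and the ML-based intra mode pruning are extracted into calculate_cost_from_tpl_data() and skip_intra_modes_in_interframe(). No behavior change is intended.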
Change-Id: Ibc1254b27976e1be3483415d297622cb4c94614e
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 4ea08e4..afb740a 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -5271,6 +5271,84 @@
}
}
+#if !CONFIG_REALTIME_ONLY
+// Adds the TPL inter/intra costs covered by this block to *inter_cost and
+// *intra_cost, then divides each by the TPL unit count (ML pruning features).
+static AOM_INLINE void calculate_cost_from_tpl_data(
+ const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bsize, int mi_row,
+ int mi_col, int64_t *inter_cost, int64_t *intra_cost) {
+ const AV1_COMMON *const cm = &cpi->common;
+ // Only use TPL stats when the SB holds a full set (tpl_data_count == len).
+ const BLOCK_SIZE sb_size = cm->seq_params->sb_size;
+ const int tpl_bsize_1d = cpi->ppi->tpl_data.tpl_bsize_1d;
+ const int len = (block_size_wide[sb_size] / tpl_bsize_1d) *
+ (block_size_high[sb_size] / tpl_bsize_1d);
+ SuperBlockEnc *sb_enc = &x->sb_enc;
+ if (sb_enc->tpl_data_count == len) {  // Otherwise outputs are untouched.
+ const BLOCK_SIZE tpl_bsize = convert_length_to_bsize(tpl_bsize_1d);
+ const int tpl_stride = sb_enc->tpl_stride;
+ const int tplw = mi_size_wide[tpl_bsize];
+ const int tplh = mi_size_high[tpl_bsize];
+ const int nw = mi_size_wide[bsize] / tplw;  // TPL units per block row.
+ const int nh = mi_size_high[bsize] / tplh;  // TPL units per block column.
+ if (nw >= 1 && nh >= 1) {  // Block spans at least one TPL unit.
+ const int of_h = mi_row % mi_size_high[sb_size];  // Row offset in SB.
+ const int of_w = mi_col % mi_size_wide[sb_size];  // Column offset in SB.
+ const int start = of_h / tplh * tpl_stride + of_w / tplw;  // First entry.
+
+ for (int k = 0; k < nh; k++) {
+ for (int l = 0; l < nw; l++) {
+ *inter_cost += sb_enc->tpl_inter_cost[start + k * tpl_stride + l];
+ *intra_cost += sb_enc->tpl_intra_cost[start + k * tpl_stride + l];
+ }
+ }
+ *inter_cost /= nw * nh;  // NOTE: the caller's initial value is in the sum.
+ *intra_cost /= nw * nh;
+ }
+ }
+}
+#endif // !CONFIG_REALTIME_ONLY
+
+// Sets search_state->intra_search_state.skip_intra_modes when the NN (given
+// non-negative costs) or the skip_txfm fallback (level >= 2) says to prune.
+static AOM_INLINE void skip_intra_modes_in_interframe(
+ AV1_COMMON *const cm, struct macroblock *x, BLOCK_SIZE bsize,
+ InterModeSearchState *search_state, int64_t inter_cost, int64_t intra_cost,
+ int skip_intra_in_interframe) {
+ MACROBLOCKD *const xd = &x->e_mbd;
+ if (inter_cost >= 0 && intra_cost >= 0) {  // NN path needs both costs >= 0.
+ aom_clear_system_state();
+ const NN_CONFIG *nn_config = (AOMMIN(cm->width, cm->height) <= 480)
+ ? &av1_intrap_nn_config
+ : &av1_intrap_hd_nn_config;  // Used when min dimension > 480.
+ float nn_features[6];
+ float scores[2] = { 0.0f };
+
+ nn_features[0] = (float)search_state->best_mbmode.skip_txfm;
+ nn_features[1] = (float)mi_size_wide_log2[bsize];
+ nn_features[2] = (float)mi_size_high_log2[bsize];
+ nn_features[3] = (float)intra_cost;
+ nn_features[4] = (float)inter_cost;
+ const int ac_q = av1_ac_quant_QTX(x->qindex, 0, xd->bd);
+ const int ac_q_max = av1_ac_quant_QTX(255, 0, xd->bd);
+ nn_features[5] = (float)(ac_q_max / ac_q);  // Integer division, then cast.
+
+ av1_nn_predict(nn_features, nn_config, 1, scores);
+ aom_clear_system_state();
+
+ // With two outputs, max softmax prob = 1.0 / (1.0 + e^(-|diff_score|)), so
+ // compare raw scores and skip av1_nn_softmax. thresh[] is indexed by
+ // skip_intra_in_interframe - 1, so the level must be 1 or 2 here.
+ const float thresh[2] = { 1.4f, 1.4f };
+ if (scores[1] > scores[0] + thresh[skip_intra_in_interframe - 1]) {
+ search_state->intra_search_state.skip_intra_modes = 1;
+ }
+ } else if ((search_state->best_mbmode.skip_txfm) &&
+ (skip_intra_in_interframe >= 2)) {  // Fallback when a cost is negative.
+ search_state->intra_search_state.skip_intra_modes = 1;
+ }
+}
+
// TODO(chiyotsai@google.com): See the todo for av1_rd_pick_intra_mode_sb.
void av1_rd_pick_inter_mode(struct AV1_COMP *cpi, struct TileDataEnc *tile_data,
struct macroblock *x, struct RD_STATS *rd_cost,
@@ -5420,36 +5498,9 @@
const int do_pruning =
(AOMMIN(cm->width, cm->height) > 480 && cpi->speed <= 1) ? 0 : 1;
if (do_pruning && sf->intra_sf.skip_intra_in_interframe &&
- cpi->oxcf.algo_cfg.enable_tpl_model) {
- // Only consider full SB.
- const BLOCK_SIZE sb_size = cm->seq_params->sb_size;
- const int tpl_bsize_1d = cpi->ppi->tpl_data.tpl_bsize_1d;
- const int len = (block_size_wide[sb_size] / tpl_bsize_1d) *
- (block_size_high[sb_size] / tpl_bsize_1d);
- SuperBlockEnc *sb_enc = &x->sb_enc;
- if (sb_enc->tpl_data_count == len) {
- const BLOCK_SIZE tpl_bsize = convert_length_to_bsize(tpl_bsize_1d);
- const int tpl_stride = sb_enc->tpl_stride;
- const int tplw = mi_size_wide[tpl_bsize];
- const int tplh = mi_size_high[tpl_bsize];
- const int nw = mi_size_wide[bsize] / tplw;
- const int nh = mi_size_high[bsize] / tplh;
- if (nw >= 1 && nh >= 1) {
- const int of_h = mi_row % mi_size_high[sb_size];
- const int of_w = mi_col % mi_size_wide[sb_size];
- const int start = of_h / tplh * tpl_stride + of_w / tplw;
-
- for (int k = 0; k < nh; k++) {
- for (int l = 0; l < nw; l++) {
- inter_cost += sb_enc->tpl_inter_cost[start + k * tpl_stride + l];
- intra_cost += sb_enc->tpl_intra_cost[start + k * tpl_stride + l];
- }
- }
- inter_cost /= nw * nh;
- intra_cost /= nw * nh;
- }
- }
- }
+ cpi->oxcf.algo_cfg.enable_tpl_model)
+ calculate_cost_from_tpl_data(cpi, x, bsize, mi_row, mi_col, &inter_cost,
+ &intra_cost);
#endif // !CONFIG_REALTIME_ONLY
// Initialize best mode stats for winner mode processing
@@ -5634,39 +5685,9 @@
const unsigned int src_var_thresh_intra_skip = 1;
const int skip_intra_in_interframe = sf->intra_sf.skip_intra_in_interframe;
if (skip_intra_in_interframe &&
- (x->source_variance > src_var_thresh_intra_skip)) {
- if (inter_cost >= 0 && intra_cost >= 0) {
- aom_clear_system_state();
- const NN_CONFIG *nn_config = (AOMMIN(cm->width, cm->height) <= 480)
- ? &av1_intrap_nn_config
- : &av1_intrap_hd_nn_config;
- float nn_features[6];
- float scores[2] = { 0.0f };
-
- nn_features[0] = (float)search_state.best_mbmode.skip_txfm;
- nn_features[1] = (float)mi_size_wide_log2[bsize];
- nn_features[2] = (float)mi_size_high_log2[bsize];
- nn_features[3] = (float)intra_cost;
- nn_features[4] = (float)inter_cost;
- const int ac_q = av1_ac_quant_QTX(x->qindex, 0, xd->bd);
- const int ac_q_max = av1_ac_quant_QTX(255, 0, xd->bd);
- nn_features[5] = (float)(ac_q_max / ac_q);
-
- av1_nn_predict(nn_features, nn_config, 1, scores);
- aom_clear_system_state();
-
- // For two parameters, the max prob returned from av1_nn_softmax equals
- // 1.0 / (1.0 + e^(-|diff_score|)). Here use scores directly to avoid the
- // calling of av1_nn_softmax.
- const float thresh[2] = { 1.4f, 1.4f };
- if (scores[1] > scores[0] + thresh[skip_intra_in_interframe - 1]) {
- search_state.intra_search_state.skip_intra_modes = 1;
- }
- } else if ((search_state.best_mbmode.skip_txfm) &&
- (skip_intra_in_interframe >= 2)) {
- search_state.intra_search_state.skip_intra_modes = 1;
- }
- }
+ (x->source_variance > src_var_thresh_intra_skip))
+ skip_intra_modes_in_interframe(cm, x, bsize, &search_state, inter_cost,
+ intra_cost, skip_intra_in_interframe);
const unsigned int intra_ref_frame_cost = ref_costs_single[INTRA_FRAME];
search_intra_modes_in_interframe(&search_state, cpi, x, rd_cost, bsize, ctx,