Add a model to skip intra mode evaluation
Added an ML model to skip intra mode evaluation. This is currently enabled
only for <=midres. Support for other resolutions will be added.
Ran midres set borg test. Here are the results.
avg_psnr: ovr_psnr: ssim: avg_speedup over midres set:
speed 1: 0.009 0.002 -0.008 1.3%
speed 2: 0.005 0.013 -0.002 1.2%
speed 3: 0.006 0.003 0.010 0.7%
STATS_CHANGED
Change-Id: If18c64da015b00fbf2737a44beda3639681caa9d
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index bb8d386..7e50dfa 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -462,6 +462,16 @@
int cost_stride;
};
+// Number of MC_FLOW_BSIZE_1D (16) square TPL blocks in a full SB; -1 if invalid.
+static INLINE int tpl_blocks_in_sb(BLOCK_SIZE bsize) {
+ switch (bsize) {
+ case BLOCK_64X64: return 16;  // (64 / 16) * (64 / 16)
+ case BLOCK_128X128: return 64;  // (128 / 16) * (128 / 16)
+ default: assert(0);  // only full 64x64 / 128x128 superblocks are expected
+ }
+ return -1;  // reached only when asserts are compiled out
+}
+
static INLINE int is_rect_tx_allowed_bsize(BLOCK_SIZE bsize) {
static const char LUT[BLOCK_SIZES_ALL] = {
0, // BLOCK_4X4
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 8557d50..e1fd083 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -3646,16 +3646,16 @@
assert(cpi->oxcf.enable_tpl_model == 2);
return 0;
}
-
+ if (cpi->oxcf.superres_mode != SUPERRES_NONE) return 0;
if (cpi->common.current_frame.frame_type == KEY_FRAME) return 0;
const FRAME_UPDATE_TYPE update_type = get_frame_update_type(&cpi->gf_group);
if (update_type == INTNL_OVERLAY_UPDATE || update_type == OVERLAY_UPDATE)
return 0;
- const int gf_group_index = cpi->gf_group.index;
- if (gf_group_index <= 0 || cpi->gf_group.index >= cpi->gf_group.size)
- return 0;
+ assert(IMPLIES(cpi->gf_group.size > 0,
+ cpi->gf_group.index < cpi->gf_group.size));
AV1_COMMON *const cm = &cpi->common;
+ const int gf_group_index = cpi->gf_group.index;
TplDepFrame *tpl_frame = &cpi->tpl_frame[gf_group_index];
TplDepStats *tpl_stats = tpl_frame->tpl_stats_ptr;
int tpl_stride = tpl_frame->stride;
@@ -3663,6 +3663,7 @@
const int mi_high = mi_size_high[bsize];
if (tpl_frame->is_valid == 0) return 0;
+ if (gf_group_index >= MAX_LAG_BUFFERS) return 0;
int mi_count = 0;
const int mi_col_sr =
@@ -3671,6 +3672,9 @@
mi_col + mi_wide, cm->superres_scale_denominator);
const int mi_cols_sr = av1_pixels_to_mi(cm->superres_upscaled_width);
+ // TPL store unit size is not the same as the motion estimation unit size.
+ // Here always use motion estimation size to avoid getting repetitive inter/
+ // intra cost.
const BLOCK_SIZE tpl_bsize = convert_length_to_bsize(MC_FLOW_BSIZE_1D);
const int step = mi_size_wide[tpl_bsize];
assert(mi_size_wide[tpl_bsize] == mi_size_high[tpl_bsize]);
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 83866b6..0e8b90d 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -59,6 +59,7 @@
#include "av1/encoder/rdopt.h"
#include "av1/encoder/reconinter_enc.h"
#include "av1/encoder/tokenize.h"
+#include "av1/encoder/tpl_model.h"
#include "av1/encoder/tx_prune_model_weights.h"
// Set this macro as 1 to collect data about tx size selection.
@@ -12589,6 +12590,37 @@
INT64_MAX, INT64_MAX, INT64_MAX,
INT64_MAX, INT64_MAX };
const int skip_ctx = av1_get_skip_context(xd);
+
+ // Prepared stats used later to check if we could skip intra mode eval.
+ int64_t inter_cost = -1;
+ int64_t intra_cost = -1;
+ // Now only use this for <=480p. Will try other resolutions.
+ if (sf->skip_intra_in_interframe && AOMMIN(cm->width, cm->height) <= 480) {
+ // Only consider full SB.
+ int len = tpl_blocks_in_sb(cm->seq_params.sb_size);
+ if (len == x->valid_cost_b) {
+ const BLOCK_SIZE tpl_bsize = convert_length_to_bsize(MC_FLOW_BSIZE_1D);
+ const int tplw = mi_size_wide[tpl_bsize];
+ const int tplh = mi_size_high[tpl_bsize];
+ const int nw = mi_size_wide[bsize] / tplw;
+ const int nh = mi_size_high[bsize] / tplh;
+ if (nw >= 1 && nh >= 1) {
+ const int of_h = mi_row % mi_size_high[cm->seq_params.sb_size];
+ const int of_w = mi_col % mi_size_wide[cm->seq_params.sb_size];
+ const int start = of_h / tplh * x->cost_stride + of_w / tplw;
+
+ for (int k = 0; k < nh; k++) {
+ for (int l = 0; l < nw; l++) {
+ inter_cost += x->inter_cost_b[start + k * x->cost_stride + l];
+ intra_cost += x->intra_cost_b[start + k * x->cost_stride + l];
+ }
+ }
+ inter_cost /= nw * nh;
+ intra_cost /= nw * nh;
+ }
+ }
+ }
+
for (int midx = 0; midx < MAX_MODES; ++midx) {
// After we done with single reference modes, find the 2nd best RD
// for a reference frame. Only search compound modes that have a reference
@@ -12897,9 +12929,34 @@
// Gate intra mode evaluation if best of inter is skip except when source
// variance is extremely low
- if ((search_state.best_mbmode.skip) && (sf->skip_intra_in_interframe >= 2) &&
- (x->source_variance > sf->src_var_thresh_intra_skip))
- search_state.skip_intra_modes = 1;
+ if (sf->skip_intra_in_interframe &&
+ (x->source_variance > sf->src_var_thresh_intra_skip)) {
+ if (inter_cost >= 0 && intra_cost >= 0) {
+ aom_clear_system_state();
+ const NN_CONFIG *nn_config = &av1_intrap_nn_config;
+ float features[6];
+ float scores[2] = { 0.0f };
+ float probs[2] = { 0.0f };
+
+ features[0] = (float)search_state.best_mbmode.skip;
+ features[1] = (float)mi_size_wide_log2[bsize];
+ features[2] = (float)mi_size_high_log2[bsize];
+ features[3] = (float)intra_cost;
+ features[4] = (float)inter_cost;
+ const int ac_q = av1_ac_quant_QTX(x->qindex, 0, xd->bd);
+ const int ac_q_max = av1_ac_quant_QTX(255, 0, xd->bd);
+ features[5] = (float)(ac_q_max / ac_q);
+
+ av1_nn_predict(features, nn_config, 1, scores);
+ aom_clear_system_state();
+ av1_nn_softmax(scores, probs, 2);
+
+ if (probs[1] > 0.8) search_state.skip_intra_modes = 1;
+ } else if ((search_state.best_mbmode.skip) &&
+ (sf->skip_intra_in_interframe >= 2)) {
+ search_state.skip_intra_modes = 1;
+ }
+ }
const int intra_ref_frame_cost = ref_costs_single[INTRA_FRAME];
for (int j = 0; j < intra_mode_num; ++j) {
diff --git a/av1/encoder/tx_prune_model_weights.h b/av1/encoder/tx_prune_model_weights.h
index 76efe93..7693507 100644
--- a/av1/encoder/tx_prune_model_weights.h
+++ b/av1/encoder/tx_prune_model_weights.h
@@ -3313,6 +3313,89 @@
&av1_tx_split_nnconfig_16x64, // TX_64X16,
};
+// TODO(yunqing): put intra mode skipping model here temporarily. Will move to
+// a new file once it is all done.
+#define NUM_HIDDEN_LAYERS_12 1
+#define NUM_FEATURES_12 6
+#define NUM_LAYER_0_UNITS_12 24
+#define NUM_LOGITS_12 2
+
+static const float av1_intrap_hiddenlayer_0_kernel_12[] = {
+ 7.28372f, -1.3333898f, -1.3180022f, -0.007156151f, -0.40799126f,
+ -0.57538104f, -31.81647f, 6.7057495f, 6.351472f, -0.029544508f,
+ 0.026801195f, 1.12863f, -0.70769817f, -0.24183524f, 0.0649113f,
+ -0.7189517f, 0.21791299f, 0.12840256f, -0.56424767f, 0.16924907f,
+ 0.4605501f, -0.170895f, -0.60358995f, -0.15383226f, -4.0523643f,
+ 0.6961917f, 1.3100256f, -0.4189354f, 0.37264112f, -0.14555685f,
+ 10.628014f, 8.184437f, 8.941916f, -0.011731001f, -0.45127156f,
+ 0.42704004f, 36.84277f, 8.988796f, 8.844238f, 0.00030091056f,
+ -0.022038324f, 1.3566176f, -8.863219f, -0.84811693f, -1.0908632f,
+ 0.00023130262f, -1.0698471f, -6.755927f, 7.1711984f, 4.7216063f,
+ 3.5099216f, -0.6650184f, 0.5935173f, -0.6696286f, 11.8595295f,
+ 0.3001874f, 0.29822728f, 0.04319222f, -1.203178f, 1.1210147f,
+ 0.035045594f, -0.20559944f, -0.015388541f, -0.7857941f, -0.94100875f,
+ -0.1278549f, -19.22603f, 7.9466896f, 6.5048656f, -0.22195444f,
+ 0.19061874f, 1.3927288f, -8.896529f, -0.48146892f, -1.6098932f,
+ -0.0030235797f, -0.6533787f, -2.1333003f, -22.256454f, -4.934058f,
+ -4.4707212f, -0.015831878f, -0.4243649f, -2.776269f, -0.23762038f,
+ 0.1820098f, -0.51865315f, -1.1893421f, 0.34969202f, 0.10636194f,
+ 14.545696f, 1.3849198f, 2.6815193f, -0.5145498f, 0.45948258f,
+ -0.8842355f, -0.9111363f, -0.39652422f, 0.077266276f, -0.68084997f,
+ 0.4593515f, -0.28872707f, -6.936231f, 1.12253f, 1.7616503f,
+ -0.014069137f, -0.0052156276f, -4.5095444f, 6.2076726f, -0.058755957f,
+ -0.4675936f, -0.13039507f, 0.12094394f, -0.07285393f, 68.26125f,
+ 7.4893136f, 8.770954f, 0.020274093f, -0.027877754f, 1.6579602f,
+ -0.1825479f, 0.34832543f, 0.07472531f, -0.44812247f, -1.0941806f,
+ -0.16749863f, 1.1394324f, 0.47983396f, -0.99983627f, -0.00064249727f,
+ -1.3345739f, -0.057157427f, -18.14875f, 16.506035f, 15.539248f,
+ 0.013191509f, -0.021674965f, -25.006235f, 0.51220596f, 0.7334426f,
+ 0.81836903f, -1.0443225f, 0.4459505f, -1.2045046f
+};
+
+static const float av1_intrap_hiddenlayer_0_bias_12[] = {
+ -4.154915f, 14.33833f, 0.0f, 0.0f, 2.0440118f, 12.40922f,
+ -16.77514f, 0.5879813f, 3.2305415f, 0.8303539f, 0.0f, 14.488708f,
+ 2.94393f, 1.874383f, 0.0f, -0.53140444f, 0.0f, 1.8456234f,
+ -0.55427986f, -19.856262f, 0.0f, 0.17281002f, 48.31631f, 0.0f
+};
+
+static const float av1_intrap_logits_kernel_12[] = {
+ 0.26843873f, -0.09576241f, 0.34427166f, 0.09914787f, -0.10275399f,
+ 0.02999484f, -0.1467772f, 0.11594324f, 0.29200763f, 0.0067976206f,
+ 0.050393578f, -0.018694371f, 0.3333476f, 0.2127221f, 0.35128218f,
+ 0.19968672f, 0.08099991f, 0.084850654f, -0.16045967f, 0.30286232f,
+ 0.6164765f, -0.27140254f, 0.08210814f, 0.34852806f, 0.25028184f,
+ -0.12188078f, 0.16310331f, 0.31253803f, -0.10792341f, 0.065858394f,
+ -0.1349708f, 0.08948815f, 0.31905392f, 0.03680656f, -0.05040944f,
+ -0.051539157f, 0.3211852f, 0.2137136f, 0.45037416f, 0.22748767f,
+ -0.10978614f, 0.06475646f, -0.16954158f, 0.32831904f, 0.16479677f,
+ -0.30020145f, 0.066221856f, 0.37213042f
+};
+
+static const float av1_intrap_logits_bias_12[] = { 0.95783f, -0.95823103f };
+
+static const NN_CONFIG av1_intrap_nn_config = {  // intra mode skipping model
+ NUM_FEATURES_12,  // 6 input features
+ NUM_LOGITS_12,  // 2 output logits
+ NUM_HIDDEN_LAYERS_12,  // 1 hidden layer
+ {
+ NUM_LAYER_0_UNITS_12,  // 24 units in hidden layer 0
+ },
+ {
+ av1_intrap_hiddenlayer_0_kernel_12,  // hidden layer 0 weights, 6 * 24 values
+ av1_intrap_logits_kernel_12,  // logits layer weights, 24 * 2 values
+ },
+ {
+ av1_intrap_hiddenlayer_0_bias_12,  // hidden layer 0 biases, 24 values
+ av1_intrap_logits_bias_12,  // logits layer biases, 2 values
+ },
+};
+
+#undef NUM_HIDDEN_LAYERS_12
+#undef NUM_FEATURES_12
+#undef NUM_LAYER_0_UNITS_12
+#undef NUM_LOGITS_12
+
#ifdef __cplusplus
} // extern "C"
#endif