Rearrange inter mode rd fitting/checking process
1) Do rd model fitting after encoding each superblock.
Note that the actual fitting process will only be triggered
once the data size reaches the corresponding threshold.
2) Move rd skip check into motion_mode_rd. The check happens
before transform type/size search.
3) Move rd stats collection process into motion_mode_rd.
4) Keep the best estimated rd score and incorporate it into the
skip check.
Change-Id: Ib6b445c04943182ac60e9ce35c5ffe35452dba6f
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 3f0b110..f096644 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -3750,6 +3750,9 @@
pc_root, NULL);
}
}
+#if CONFIG_COLLECT_INTER_MODE_RD_STATS
+ av1_inter_mode_data_fit(x->rdmult);
+#endif
}
}
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 24fc462..b485712 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -512,71 +512,37 @@
double a;
double b;
double dist_mean;
+ int skip_count;
+ int non_skip_count;
+ int bracket_idx;
} InterModeRdModel;
InterModeRdModel inter_mode_rd_models[BLOCK_SIZES_ALL];
-#define INTER_MODE_RD_DATA_OVERALL_SIZE 2000
-static int inter_mode_data_idx = 0;
-static int64_t inter_mode_data_sse[INTER_MODE_RD_DATA_OVERALL_SIZE];
-static int64_t inter_mode_data_dist[INTER_MODE_RD_DATA_OVERALL_SIZE];
-static int inter_mode_data_residue_cost[INTER_MODE_RD_DATA_OVERALL_SIZE];
+#define INTER_MODE_RD_DATA_OVERALL_SIZE 6400  // Capacity, in samples, of each row of the buffers below.
+static int inter_mode_data_idx[4];  // Samples stored so far in each row; rows are selected via inter_mode_data_block_idx().
+static int64_t inter_mode_data_sse[4][INTER_MODE_RD_DATA_OVERALL_SIZE];  // Per-sample sse.
+static int64_t inter_mode_data_dist[4][INTER_MODE_RD_DATA_OVERALL_SIZE];  // Per-sample distortion.
+static int inter_mode_data_residue_cost[4][INTER_MODE_RD_DATA_OVERALL_SIZE];  // Per-sample residue cost; used as a divisor when fitting.
+static int inter_mode_data_all_cost[4][INTER_MODE_RD_DATA_OVERALL_SIZE];  // Per-sample total cost; only read by the disabled diagnostic code.
+static int64_t inter_mode_data_ref_best_rd[4][INTER_MODE_RD_DATA_OVERALL_SIZE];  // Per-sample ref_best_rd; only read by the disabled diagnostic code.
+
+int inter_mode_data_block_idx(BLOCK_SIZE bsize) {  // Maps a block size to its row in the inter_mode_data_* buffers, or -1 if it has none.
+ if (bsize == BLOCK_8X8) return 1;  // Row 0 is never returned by this function.
+ if (bsize == BLOCK_16X16) return 2;
+ if (bsize == BLOCK_32X32) return 3;
+ return -1;  // Every other block size has no sample buffer; callers skip it.
+}
void av1_inter_mode_data_init() {
- inter_mode_data_idx = 0;
for (int i = 0; i < BLOCK_SIZES_ALL; ++i) {
+ const int block_idx = inter_mode_data_block_idx(i);
+ if (block_idx != -1) inter_mode_data_idx[block_idx] = 0;
InterModeRdModel *md = &inter_mode_rd_models[i];
md->ready = 0;
- }
-}
-
-static void inter_mode_data_fit(BLOCK_SIZE bsize) {
- InterModeRdModel *md = &inter_mode_rd_models[bsize];
- double my = 0;
- double mx = 0;
- double dx = 0;
- double dxy = 0;
- double dist_mean = 0;
- for (int i = 0; i < inter_mode_data_idx; ++i) {
- const double sse = inter_mode_data_sse[i];
- const double dist = inter_mode_data_dist[i];
- const double residue_cost = inter_mode_data_residue_cost[i];
- const double ld = (sse - dist) / residue_cost;
- dist_mean += dist;
- my += ld;
- mx += sse;
- dx += sse * sse;
- dxy += sse * ld;
- }
- dist_mean = dist_mean / inter_mode_data_idx;
- my = my / inter_mode_data_idx;
- mx = mx / inter_mode_data_idx;
- dx = sqrt(dx / inter_mode_data_idx);
- dxy = dxy / inter_mode_data_idx;
-
- md->dist_mean = dist_mean;
- md->a = (dxy - mx * my) / (dx * dx - mx * mx);
- md->b = my - md->a * mx;
- md->ready = 1;
-}
-
-static void inter_mode_data_push(BLOCK_SIZE bsize, int64_t sse, int64_t dist,
- int residue_cost) {
- if (residue_cost == 0) return;
- InterModeRdModel *md = &inter_mode_rd_models[bsize];
- if (md->ready) {
- return;
- }
- if (bsize != BLOCK_4X4) return;
- if (inter_mode_data_idx < INTER_MODE_RD_DATA_OVERALL_SIZE) {
- inter_mode_data_sse[inter_mode_data_idx] = sse;
- inter_mode_data_dist[inter_mode_data_idx] = dist;
- inter_mode_data_residue_cost[inter_mode_data_idx] = residue_cost;
- ++inter_mode_data_idx;
- }
- if (inter_mode_data_idx == INTER_MODE_RD_DATA_OVERALL_SIZE) {
- // TODO(angiebird): find an adative way to do the fitting
- inter_mode_data_fit(bsize);
+ md->skip_count = 0;
+ md->non_skip_count = 0;
+ md->bracket_idx = 0;
}
}
@@ -594,6 +560,97 @@
return 0;
}
+#define DATA_BRACKETS 7  // Number of entries in data_num_threshold.
+static const int data_num_threshold[DATA_BRACKETS] = {  // Sample counts at which a model is (re)fitted, indexed by InterModeRdModel::bracket_idx.
+ 200, 400, 800, 1600, 3200, 6400, INT32_MAX
+};
+
+void av1_inter_mode_data_fit(int rdmult) {  // Refits the linear rd model of every block size whose sample count has reached its next threshold.
+ for (int bsize = 0; bsize < BLOCK_SIZES_ALL; ++bsize) {
+ const int block_idx = inter_mode_data_block_idx(bsize);
+ InterModeRdModel *md = &inter_mode_rd_models[bsize];
+ if (block_idx == -1) continue;  // This block size has no sample buffer.
+ int data_num = inter_mode_data_idx[block_idx];
+ if (data_num < data_num_threshold[md->bracket_idx]) {  // Not enough samples for the current bracket yet.
+ continue;
+ }
+ double my = 0;  // Sum, then mean, of ld.
+ double mx = 0;  // Sum, then mean, of sse.
+ double dx = 0;  // Sum of sse^2, then its root mean square.
+ double dxy = 0;  // Sum, then mean, of sse * ld.
+ double dist_mean = 0;
+ const int train_num = data_num;  // All collected samples are used for fitting.
+ for (int i = 0; i < train_num; ++i) {
+ const double sse = inter_mode_data_sse[block_idx][i];
+ const double dist = inter_mode_data_dist[block_idx][i];
+ const double residue_cost = inter_mode_data_residue_cost[block_idx][i];  // Non-zero: inter_mode_data_push() drops zero-cost samples.
+ const double ld = (sse - dist) / residue_cost;  // Distortion reduction per unit of residue cost.
+ dist_mean += dist;
+ my += ld;
+ mx += sse;
+ dx += sse * sse;
+ dxy += sse * ld;
+ }
+ dist_mean = dist_mean / data_num;
+ my = my / train_num;
+ mx = mx / train_num;
+ dx = sqrt(dx / train_num);  // RMS of sse, so dx * dx - mx * mx below is the variance of sse.
+ dxy = dxy / train_num;
+
+ md->dist_mean = dist_mean;
+ md->a = (dxy - mx * my) / (dx * dx - mx * mx);  // Least-squares slope of ld on sse. NOTE(review): the denominator is 0 if every sse sample is equal - confirm this cannot happen.
+ md->b = my - md->a * mx;  // Matching intercept.
+ ++md->bracket_idx;  // The next refit waits for the next, larger threshold.
+ md->ready = 1;  // Marks the model as fitted.
+ assert(md->bracket_idx < DATA_BRACKETS);  // Guards against indexing past the end of data_num_threshold.
+
+ (void)rdmult;  // Only the disabled diagnostic block below uses rdmult.
+#if 0  // Disabled diagnostic: compares the estimated rd against the real rd on the collected samples.
+ int skip_count = 0;
+ int fp_skip_count = 0;
+ double avg_error = 0;
+ const int test_num = data_num;
+ for (int i = 0; i < data_num; ++i) {
+ const int64_t sse = inter_mode_data_sse[block_idx][i];
+ const int64_t dist = inter_mode_data_dist[block_idx][i];
+ const int64_t residue_cost = inter_mode_data_residue_cost[block_idx][i];
+ const int64_t all_cost = inter_mode_data_all_cost[block_idx][i];
+ const int64_t est_rd =
+ get_est_rd(bsize, rdmult, sse, all_cost - residue_cost);
+ const int64_t real_rd = RDCOST(rdmult, all_cost, dist);
+ const int64_t ref_best_rd = inter_mode_data_ref_best_rd[block_idx][i];
+ if (est_rd > ref_best_rd) {  // The estimate would have skipped this candidate.
+ ++skip_count;
+ if (real_rd < ref_best_rd) {  // ...but its real rd beat the reference: a false-positive skip.
+ ++fp_skip_count;
+ }
+ }
+ avg_error += abs(est_rd - real_rd) * 100. / real_rd;  // NOTE(review): abs() takes int but the operands are int64_t - confirm before re-enabling.
+ }
+ avg_error /= test_num;
+ printf("test_num %d bsize %d avg_error %f skip_count %d fp_skip_count %d\n",
+ test_num, bsize, avg_error, skip_count, fp_skip_count);
+#endif
+ }
+}
+
+static void inter_mode_data_push(BLOCK_SIZE bsize, int64_t sse, int64_t dist,  // Appends one rd sample to the buffer row of bsize, if it has one.
+ int residue_cost, int all_cost,
+ int64_t ref_best_rd) {
+ if (residue_cost == 0) return;  // The fit divides by residue_cost, so zero-cost samples are dropped.
+ const int block_idx = inter_mode_data_block_idx(bsize);
+ if (block_idx == -1) return;  // No sample buffer for this block size.
+ if (inter_mode_data_idx[block_idx] < INTER_MODE_RD_DATA_OVERALL_SIZE) {  // Samples arriving after the row is full are silently discarded.
+ const int data_idx = inter_mode_data_idx[block_idx];
+ inter_mode_data_sse[block_idx][data_idx] = sse;
+ inter_mode_data_dist[block_idx][data_idx] = dist;
+ inter_mode_data_residue_cost[block_idx][data_idx] = residue_cost;
+ inter_mode_data_all_cost[block_idx][data_idx] = all_cost;
+ inter_mode_data_ref_best_rd[block_idx][data_idx] = ref_best_rd;
+ ++inter_mode_data_idx[block_idx];
+ }
+}
+
#define INTER_MODE_RD_DATA_SIZE 300
typedef struct InterModeRdData {
@@ -1803,6 +1860,31 @@
*dist <<= 4;
}
+#if CONFIG_COLLECT_INTER_MODE_RD_STATS
+static int64_t get_sse(const AV1_COMP *cpi, const MACROBLOCK *x) {  // Returns the sse between source and dst buffers summed over the evaluated planes, times 16.
+ const AV1_COMMON *cm = &cpi->common;
+ const int num_planes = av1_num_planes(cm);
+ const MACROBLOCKD *xd = &x->e_mbd;
+ const MB_MODE_INFO *mbmi = xd->mi[0];
+ int64_t total_sse = 0;
+ for (int plane = 0; plane < num_planes; ++plane) {
+ const struct macroblock_plane *const p = &x->plane[plane];
+ const struct macroblockd_plane *const pd = &xd->plane[plane];
+ const BLOCK_SIZE bs = get_plane_block_size(mbmi->sb_type, pd->subsampling_x,
+ pd->subsampling_y);
+ unsigned int sse;  // presumably filled in by vf() through its last argument - confirm
+
+ if (x->skip_chroma_rd && plane) continue;  // Planes other than 0 are left out when skip_chroma_rd is set.
+
+ cpi->fn_ptr[bs].vf(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
+ &sse);
+ total_sse += sse;
+ }
+ total_sse <<= 4;  // Scale by 16. NOTE(review): presumably matches the `*dist <<= 4` scaling used just above in this file - confirm.
+ return total_sse;
+}
+#endif
+
static void model_rd_for_sb(const AV1_COMP *const cpi, BLOCK_SIZE bsize,
MACROBLOCK *x, MACROBLOCKD *xd, int plane_from,
int plane_to, int *out_rate_sum,
@@ -7589,7 +7671,12 @@
int *disable_skip, int mi_row, int mi_col,
HandleInterModeArgs *const args,
const int64_t ref_best_rd, const int *refs,
- int rate_mv, BUFFER_SET *orig_dst) {
+ int rate_mv, BUFFER_SET *orig_dst
+#if CONFIG_COLLECT_INTER_MODE_RD_STATS
+ ,
+ int64_t *best_est_rd
+#endif
+) {
const AV1_COMMON *const cm = &cpi->common;
const int num_planes = av1_num_planes(cm);
MACROBLOCKD *xd = &x->e_mbd;
@@ -7898,6 +7985,27 @@
}
}
if (!skip_txfm_sb) {
+#if CONFIG_COLLECT_INTER_MODE_RD_STATS
+#if !INTER_MODE_RD_STATS_DUMP
+ InterModeRdModel *md = &inter_mode_rd_models[mbmi->sb_type];
+ if (md->ready) {
+ const int64_t curr_sse = get_sse(cpi, x);
+ const int64_t est_rd =
+ get_est_rd(mbmi->sb_type, x->rdmult, curr_sse, rd_stats->rate);
+ if (est_rd * 0.8 > *best_est_rd) {
+ ++md->skip_count;
+ mbmi->ref_frame[1] = ref_frame_1;
+ continue;
+ } else {
+ if (est_rd < *best_est_rd) {
+ *best_est_rd = est_rd;
+ }
+ ++md->non_skip_count;
+ }
+ }
+#endif // !INTER_MODE_RD_STATS_DUMP
+#endif // CONFIG_COLLECT_INTER_MODE_RD_STATS
+
int64_t rdcosty = INT64_MAX;
int is_cost_valid_uv = 0;
@@ -8005,6 +8113,12 @@
}
}
+#if CONFIG_COLLECT_INTER_MODE_RD_STATS
+ inter_mode_data_push(mbmi->sb_type, rd_stats->sse, rd_stats->dist,
+ rd_stats_y->rate + rd_stats_uv->rate, rd_stats->rate,
+ ref_best_rd);
+#endif // CONFIG_COLLECT_INTER_MODE_RD_STATS
+
tmp_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
if ((mbmi->motion_mode == SIMPLE_TRANSLATION &&
mbmi->ref_frame[1] != INTRA_FRAME) ||
@@ -8179,7 +8293,8 @@
HandleInterModeArgs *args, int64_t ref_best_rd
#if CONFIG_COLLECT_INTER_MODE_RD_STATS
,
- InterModeRdVector *inter_mode_rd_vector
+ InterModeRdVector *inter_mode_rd_vector,
+ int64_t *best_est_rd
#endif
) {
const AV1_COMMON *cm = &cpi->common;
@@ -8594,18 +8709,6 @@
#if CONFIG_COLLECT_INTER_MODE_RD_STATS
(void)inter_mode_rd_vector;
- const int mv_cost = ref_mv_cost + rate_mv;
- const int curr_cost = mv_cost + args->single_comp_cost +
- args->ref_frame_cost + compmode_interinter_cost;
-#if !INTER_MODE_RD_STATS_DUMP
- const int64_t est_rd =
- get_est_rd(mbmi->sb_type, x->rdmult, skip_sse_sb, curr_cost);
- if (est_rd > ref_best_rd) {
- restore_dst_buf(xd, orig_dst, num_planes);
- early_terminate = INT64_MAX;
- continue;
- }
-#endif // !INTER_MODE_RD_STATS_DUMP
#endif // CONFIG_COLLECT_INTER_MODE_RD_STATS
if (search_jnt_comp && cpi->sf.jnt_comp_fast_tx_search && comp_idx == 0) {
@@ -8622,16 +8725,20 @@
rd_stats_y->dist = plane_dist[0];
rd_stats_uv->dist = plane_dist[1] + plane_dist[2];
} else {
+#if CONFIG_COLLECT_INTER_MODE_RD_STATS
+ ret_val = motion_mode_rd(cpi, x, bsize, rd_stats, rd_stats_y, rd_stats_uv,
+ disable_skip, mi_row, mi_col, args, ref_best_rd,
+ refs, rate_mv, &orig_dst, best_est_rd);
+#else
ret_val = motion_mode_rd(cpi, x, bsize, rd_stats, rd_stats_y, rd_stats_uv,
disable_skip, mi_row, mi_col, args, ref_best_rd,
refs, rate_mv, &orig_dst);
+#endif
}
if (ret_val != INT64_MAX) {
#if CONFIG_COLLECT_INTER_MODE_RD_STATS
- const int residue_cost = rd_stats_y->rate + rd_stats_uv->rate;
- inter_mode_data_push(mbmi->sb_type, rd_stats->sse, rd_stats->dist,
- residue_cost);
#if INTER_MODE_RD_STATS_DUMP
+ const int residue_cost = rd_stats_y->rate + rd_stats_uv->rate;
inter_mode_rd_vector_push(
inter_mode_rd_vector, this_mode, mbmi, mbmi_ext, rd_stats->sse,
rd_stats->dist, args->single_comp_cost + compmode_interinter_cost,
@@ -9772,6 +9879,7 @@
InterModeRdVector inter_mode_rd_vector;
inter_mode_rd_vector_init(&inter_mode_rd_vector, mi_row, mi_col, bsize,
x->rdmult);
+ int64_t best_est_rd = INT64_MAX;
#endif
for (int midx = 0; midx < MAX_MODES; ++midx) {
@@ -10111,7 +10219,8 @@
#if CONFIG_COLLECT_INTER_MODE_RD_STATS
this_rd = handle_inter_mode(cpi, x, bsize, &rd_stats, &rd_stats_y,
&rd_stats_uv, &disable_skip, mi_row, mi_col,
- &args, ref_best_rd, &inter_mode_rd_vector);
+ &args, ref_best_rd, &inter_mode_rd_vector,
+ &best_est_rd);
#else
this_rd = handle_inter_mode(cpi, x, bsize, &rd_stats, &rd_stats_y,
&rd_stats_uv, &disable_skip, mi_row, mi_col,
@@ -10230,7 +10339,7 @@
tmp_alt_rd = handle_inter_mode(
cpi, x, bsize, &tmp_rd_stats, &tmp_rd_stats_y, &tmp_rd_stats_uv,
&dummy_disable_skip, mi_row, mi_col, &args, ref_best_rd,
- &inter_mode_rd_vector);
+ &inter_mode_rd_vector, &best_est_rd);
#else
tmp_alt_rd = handle_inter_mode(
cpi, x, bsize, &tmp_rd_stats, &tmp_rd_stats_y, &tmp_rd_stats_uv,
diff --git a/av1/encoder/rdopt.h b/av1/encoder/rdopt.h
index 02b9ae1..a701898 100644
--- a/av1/encoder/rdopt.h
+++ b/av1/encoder/rdopt.h
@@ -131,6 +131,7 @@
#if CONFIG_COLLECT_INTER_MODE_RD_STATS
void av1_inter_mode_data_init();
+void av1_inter_mode_data_fit(int rdmult);
#endif
#ifdef __cplusplus