Rework the tpl model construction
Support forward building and backward synthesis to build the
temporal dependency model.
Change-Id: Id04d838647eaf366c3a2b2a25028cdcdadcb7d87
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index 611c71e..df29b54 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -166,6 +166,9 @@
int64_t inter_cost;
int64_t mc_flow;
int64_t mc_dep_cost;
+ int_mv mv;
+ int ref_frame_index;
+ double quant_ratio;
#if !USE_TPL_CLASSIC_MODEL
int64_t mc_count;
int64_t mc_saved;
diff --git a/av1/encoder/tpl_model.c b/av1/encoder/tpl_model.c
index 7cbc824..a7fadf7 100644
--- a/av1/encoder/tpl_model.c
+++ b/av1/encoder/tpl_model.c
@@ -139,8 +139,7 @@
int frame_idx, int16_t *src_diff, tran_low_t *coeff, tran_low_t *qcoeff,
tran_low_t *dqcoeff, int use_satd, int mi_row, int mi_col, BLOCK_SIZE bsize,
TX_SIZE tx_size, const YV12_BUFFER_CONFIG *ref_frame[], uint8_t *predictor,
- int64_t *recon_error, int64_t *sse, TplDepStats *tpl_stats,
- int *ref_frame_index, int_mv *mv) {
+ int64_t *recon_error, int64_t *sse, TplDepStats *tpl_stats) {
AV1_COMMON *cm = &cpi->common;
const GF_GROUP *gf_group = &cpi->gf_group;
@@ -323,8 +322,9 @@
tpl_stats->intra_cost = best_intra_cost << TPL_DEP_COST_SCALE_LOG2;
if (frame_idx && best_rf_idx != -1) {
- *ref_frame_index = cpi->tpl_frame[frame_idx].ref_map_index[best_rf_idx];
- mv->as_int = best_mv.as_int;
+ tpl_stats->mv.as_int = best_mv.as_int;
+ tpl_stats->ref_frame_index =
+ cpi->tpl_frame[frame_idx].ref_map_index[best_rf_idx];
}
}
@@ -383,14 +383,12 @@
static AOM_INLINE void tpl_model_update_b(AV1_COMP *cpi, TplDepFrame *tpl_frame,
TplDepStats *tpl_stats_ptr,
int mi_row, int mi_col,
- double quant_ratio,
- const BLOCK_SIZE bsize,
- int ref_frame_index, int_mv mv) {
- TplDepFrame *ref_tpl_frame = &tpl_frame[ref_frame_index];
+ const BLOCK_SIZE bsize) {
+ TplDepFrame *ref_tpl_frame = &tpl_frame[tpl_stats_ptr->ref_frame_index];
TplDepStats *ref_stats_ptr = ref_tpl_frame->tpl_stats_ptr;
- const int ref_pos_row = mi_row * MI_SIZE + (mv.as_mv.row >> 3);
- const int ref_pos_col = mi_col * MI_SIZE + (mv.as_mv.col >> 3);
+ const int ref_pos_row = mi_row * MI_SIZE + (tpl_stats_ptr->mv.as_mv.row >> 3);
+ const int ref_pos_col = mi_col * MI_SIZE + (tpl_stats_ptr->mv.as_mv.col >> 3);
const int bw = 4 << mi_size_wide_log2[bsize];
const int bh = 4 << mi_size_high_log2[bsize];
@@ -416,8 +414,11 @@
const double iiratio_nl = iiratio_nonlinear(
(double)tpl_stats_ptr->inter_cost / tpl_stats_ptr->intra_cost);
- int64_t mc_flow = (int64_t)(quant_ratio * tpl_stats_ptr->mc_dep_cost *
- (1.0 - iiratio_nl));
+ tpl_stats_ptr->mc_dep_cost =
+ tpl_stats_ptr->intra_cost + tpl_stats_ptr->mc_flow;
+ int64_t mc_flow =
+ (int64_t)(tpl_stats_ptr->quant_ratio * tpl_stats_ptr->mc_dep_cost *
+ (1.0 - iiratio_nl));
#if !USE_TPL_CLASSIC_MODEL
int64_t mc_saved = tpl_stats_ptr->intra_cost - tpl_stats_ptr->inter_cost;
#endif // #if !USE_TPL_CLASSIC_MODEL
@@ -440,9 +441,7 @@
static AOM_INLINE void tpl_model_update(AV1_COMP *cpi, TplDepFrame *tpl_frame,
TplDepStats *tpl_stats_ptr, int mi_row,
- int mi_col, double quant_ratio,
- const BLOCK_SIZE bsize,
- int ref_frame_index, int_mv mv) {
+ int mi_col, const BLOCK_SIZE bsize) {
const int mi_height = mi_size_high[bsize];
const int mi_width = mi_size_wide[bsize];
const int step = 1 << cpi->tpl_stats_block_mis_log2;
@@ -454,7 +453,7 @@
TplDepStats *tpl_ptr = &tpl_stats_ptr[av1_tpl_ptr_pos(
cpi, mi_row + idy, mi_col + idx, tpl_frame->stride)];
tpl_model_update_b(cpi, tpl_frame, tpl_ptr, mi_row + idy, mi_col + idx,
- quant_ratio, tpl_block_size, ref_frame_index, mv);
+ tpl_block_size);
}
}
}
@@ -469,19 +468,18 @@
int64_t intra_cost = src_stats->intra_cost / (mi_height * mi_width);
int64_t inter_cost = src_stats->inter_cost / (mi_height * mi_width);
-
- TplDepStats *tpl_ptr;
-
intra_cost = AOMMAX(1, intra_cost);
inter_cost = AOMMAX(1, inter_cost);
for (int idy = 0; idy < mi_height; idy += step) {
- tpl_ptr =
+ TplDepStats *tpl_ptr =
&tpl_stats_ptr[av1_tpl_ptr_pos(cpi, mi_row + idy, mi_col, stride)];
for (int idx = 0; idx < mi_width; idx += step) {
tpl_ptr->intra_cost = intra_cost;
tpl_ptr->inter_cost = inter_cost;
- tpl_ptr->mc_dep_cost = tpl_ptr->intra_cost + tpl_ptr->mc_flow;
+ tpl_ptr->quant_ratio = src_stats->quant_ratio;
+ tpl_ptr->mv.as_int = src_stats->mv.as_int;
+ tpl_ptr->ref_frame_index = src_stats->ref_frame_index;
++tpl_ptr;
}
}
@@ -608,10 +606,6 @@
xd->mb_to_bottom_edge = ((cm->mi_rows - mi_height - mi_row) * MI_SIZE) * 8;
for (mi_col = 0; mi_col < cm->mi_cols; mi_col += mi_width) {
TplDepStats tpl_stats;
- int ref_frame_index = -1;
- int_mv mv;
-
- mv.as_int = 0;
// Motion estimation column boundary
x->mv_limits.col_min =
@@ -622,17 +616,13 @@
xd->mb_to_right_edge = ((cm->mi_cols - mi_width - mi_col) * MI_SIZE) * 8;
mode_estimation(cpi, x, xd, &sf, frame_idx, src_diff, coeff, qcoeff,
dqcoeff, 1, mi_row, mi_col, bsize, tx_size, ref_frame,
- predictor, &recon_error, &sse, &tpl_stats,
- &ref_frame_index, &mv);
+ predictor, &recon_error, &sse, &tpl_stats);
// Motion flow dependency dispenser.
+ double quant_ratio = (double)recon_error / sse;
+ tpl_stats.quant_ratio = quant_ratio;
tpl_model_store(cpi, tpl_frame->tpl_stats_ptr, mi_row, mi_col, bsize,
tpl_frame->stride, &tpl_stats);
- double quant_ratio = (double)recon_error / sse;
- if (frame_idx) {
- tpl_model_update(cpi, cpi->tpl_frame, tpl_frame->tpl_stats_ptr, mi_row,
- mi_col, quant_ratio, bsize, ref_frame_index, mv);
- }
}
}
@@ -640,6 +630,28 @@
xd->mi = backup_mi_grid;
}
+static void mc_flow_synthesizer(AV1_COMP *cpi, int frame_idx) {
+ AV1_COMMON *cm = &cpi->common;
+
+ const GF_GROUP *gf_group = &cpi->gf_group;
+ if (frame_idx == gf_group->size) return;
+
+ TplDepFrame *tpl_frame = &cpi->tpl_frame[frame_idx];
+
+ const BLOCK_SIZE bsize = convert_length_to_bsize(MC_FLOW_BSIZE_1D);
+ const int mi_height = mi_size_high[bsize];
+ const int mi_width = mi_size_wide[bsize];
+
+ for (int mi_row = 0; mi_row < cm->mi_rows; mi_row += mi_height) {
+ for (int mi_col = 0; mi_col < cm->mi_cols; mi_col += mi_width) {
+ if (frame_idx) {
+ tpl_model_update(cpi, cpi->tpl_frame, tpl_frame->tpl_stats_ptr, mi_row,
+ mi_col, bsize);
+ }
+ }
+ }
+}
+
static AOM_INLINE void init_gop_frames_for_tpl(
AV1_COMP *cpi, const EncodeFrameParams *const init_frame_params,
GF_GROUP *gf_group, int *tpl_group_frames,
@@ -832,13 +844,22 @@
if (cpi->oxcf.enable_tpl_model == 1) {
// Backward propagation from tpl_group_frames to 1.
+ for (int frame_idx = gf_group->index; frame_idx < cpi->tpl_gf_group_frames;
+ ++frame_idx) {
+ if (gf_group->update_type[frame_idx] == INTNL_OVERLAY_UPDATE ||
+ gf_group->update_type[frame_idx] == OVERLAY_UPDATE)
+ continue;
+
+ mc_flow_dispenser(cpi, frame_idx);
+ }
+
for (int frame_idx = cpi->tpl_gf_group_frames - 1;
frame_idx >= gf_group->index; --frame_idx) {
if (gf_group->update_type[frame_idx] == INTNL_OVERLAY_UPDATE ||
gf_group->update_type[frame_idx] == OVERLAY_UPDATE)
continue;
- mc_flow_dispenser(cpi, frame_idx);
+ mc_flow_synthesizer(cpi, frame_idx);
}
}