Enable compound motion search in the TPL model
STATS_CHANGED
Change-Id: I71bca71ba41ded81f546d05befcdabe1e27f1449
diff --git a/av1/common/convolve.h b/av1/common/convolve.h
index 490d778..5f3e596 100644
--- a/av1/common/convolve.h
+++ b/av1/common/convolve.h
@@ -68,6 +68,7 @@
assert(IMPLIES(cmp_index, is_compound));
conv_params.is_compound = is_compound;
+ conv_params.use_dist_wtd_comp_avg = 0;
conv_params.round_0 = ROUND0_BITS;
conv_params.round_1 = is_compound ? COMPOUND_ROUND1_BITS
: 2 * FILTER_BITS - conv_params.round_0;
diff --git a/av1/encoder/motion_search_facade.c b/av1/encoder/motion_search_facade.c
index 6294895..50cbdc7 100644
--- a/av1/encoder/motion_search_facade.c
+++ b/av1/encoder/motion_search_facade.c
@@ -535,53 +535,15 @@
xd, cm, &ms_params, start_mv, &best_mv.as_mv, &dis, &sse, NULL);
if (try_second) {
- struct macroblockd_plane *p = xd->plane;
- const BUFFER_SET orig_dst = {
- { p[0].dst.buf, p[1].dst.buf, p[2].dst.buf },
- { p[0].dst.stride, p[1].dst.stride, p[2].dst.stride },
- };
- mbmi->mv[id].as_mv = best_mv.as_mv;
- mbmi->mv[!id].as_mv = cur_mv[!id].as_mv;
-
- xd->plane[plane].pre[0] = ref_yv12[0];
- xd->plane[plane].pre[1] = ref_yv12[1];
-
- av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, &orig_dst, bsize,
- 0, 0);
- av1_subtract_plane(x, bsize, 0);
- RD_STATS this_rd_stats;
- av1_init_rd_stats(&this_rd_stats);
- av1_estimate_txfm_yrd(cpi, x, &this_rd_stats, INT64_MAX, bsize,
- max_txsize_rect_lookup[bsize]);
- int this_mv_rate = av1_mv_bit_cost(
- &best_mv.as_mv, &ref_mv[id].as_mv, mv_costs->nmv_joint_cost,
- mv_costs->mv_cost_stack, MV_COST_WEIGHT);
- int64_t rd = RDCOST(x->rdmult, this_mv_rate + this_rd_stats.rate,
- this_rd_stats.dist);
-
MV this_best_mv;
MV subpel_start_mv = get_mv_from_fullmv(&second_best_mv.as_fullmv);
if (av1_is_subpelmv_in_range(&ms_params.mv_limits, subpel_start_mv)) {
- const int this_var = cpi->mv_search_params.find_fractional_mv_step(
+ const int thissme = cpi->mv_search_params.find_fractional_mv_step(
xd, cm, &ms_params, subpel_start_mv, &this_best_mv, &dis, &sse,
NULL);
- mbmi->mv[id].as_mv = this_best_mv;
- mbmi->mv[!id].as_mv = cur_mv[!id].as_mv;
- av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, &orig_dst,
- bsize, 0, 0);
- av1_subtract_plane(x, bsize, 0);
- RD_STATS tmp_rd_stats;
- av1_init_rd_stats(&tmp_rd_stats);
- av1_estimate_txfm_yrd(cpi, x, &tmp_rd_stats, INT64_MAX, bsize,
- max_txsize_rect_lookup[bsize]);
- int tmp_mv_rate = av1_mv_bit_cost(
- &this_best_mv, &ref_mv[id].as_mv, mv_costs->nmv_joint_cost,
- mv_costs->mv_cost_stack, MV_COST_WEIGHT);
- int64_t tmp_rd = RDCOST(x->rdmult, tmp_rd_stats.rate + tmp_mv_rate,
- tmp_rd_stats.dist);
- if (tmp_rd < rd) {
+ if (thissme < bestsme) {
best_mv.as_mv = this_best_mv;
- bestsme = AOMMIN(bestsme, this_var);
+ bestsme = thissme;
}
}
}
diff --git a/av1/encoder/reconinter_enc.c b/av1/encoder/reconinter_enc.c
index 20da822..6020b94 100644
--- a/av1/encoder/reconinter_enc.c
+++ b/av1/encoder/reconinter_enc.c
@@ -77,7 +77,8 @@
InterPredParams *inter_pred_params) {
av1_build_one_inter_predictor(
dst, dst_stride, src_mv, inter_pred_params, NULL /* xd */, 0 /* mi_x */,
- 0 /* mi_y */, 0 /* ref */, NULL /* mc_buf */, enc_calc_subpel_params);
+ 0 /* mi_y */, inter_pred_params->conv_params.do_average /* ref */,
+ NULL /* mc_buf */, enc_calc_subpel_params);
}
static void enc_build_inter_predictors(const AV1_COMMON *cm, MACROBLOCKD *xd,
diff --git a/av1/encoder/speed_features.c b/av1/encoder/speed_features.c
index 5344dc4..807cd6c 100644
--- a/av1/encoder/speed_features.c
+++ b/av1/encoder/speed_features.c
@@ -571,6 +571,7 @@
sf->tpl_sf.reduce_first_step_size = 6;
sf->tpl_sf.subpel_force_stop = QUARTER_PEL;
sf->tpl_sf.search_method = DIAMOND;
+ sf->tpl_sf.allow_compound_pred = 0;
sf->tx_sf.adaptive_txb_search_level = boosted ? 2 : 3;
sf->tx_sf.tx_type_search.use_skip_flag_prediction = 2;
@@ -1064,6 +1065,7 @@
tpl_sf->search_method = NSTEP;
tpl_sf->disable_filtered_key_tpl = 0;
tpl_sf->prune_ref_frames_in_tpl = 0;
+ tpl_sf->allow_compound_pred = 1;
}
static AOM_INLINE void init_gm_sf(GLOBAL_MOTION_SPEED_FEATURES *gm_sf) {
diff --git a/av1/encoder/speed_features.h b/av1/encoder/speed_features.h
index 4d11f14..21cee75 100644
--- a/av1/encoder/speed_features.h
+++ b/av1/encoder/speed_features.h
@@ -369,6 +369,9 @@
// Prune reference frames in TPL.
int prune_ref_frames_in_tpl;
+
+ // Allow compound prediction in the TPL model (0: disabled, 1: enabled).
+ int allow_compound_pred;
} TPL_SPEED_FEATURES;
typedef struct GLOBAL_MOTION_SPEED_FEATURES {
diff --git a/av1/encoder/tpl_model.c b/av1/encoder/tpl_model.c
index 821cbea..5949128 100644
--- a/av1/encoder/tpl_model.c
+++ b/av1/encoder/tpl_model.c
@@ -29,6 +29,7 @@
#include "av1/encoder/encodeframe_utils.h"
#include "av1/encoder/encode_strategy.h"
#include "av1/encoder/hybrid_fwd_txfm.h"
+#include "av1/encoder/motion_search_facade.h"
#include "av1/encoder/rd.h"
#include "av1/encoder/rdopt.h"
#include "av1/encoder/reconinter_enc.h"
@@ -434,16 +435,19 @@
// Motion compensated prediction
xd->mi[0]->ref_frame[0] = INTRA_FRAME;
xd->mi[0]->ref_frame[1] = NONE_FRAME;
+ xd->mi[0]->compound_idx = 1;
int best_rf_idx = -1;
int_mv best_mv;
int64_t inter_cost;
int64_t best_inter_cost = INT64_MAX;
int rf_idx;
+ int_mv single_mv[INTER_REFS_PER_FRAME];
best_mv.as_int = INVALID_MV;
for (rf_idx = 0; rf_idx < INTER_REFS_PER_FRAME; ++rf_idx) {
+ single_mv[rf_idx].as_int = INVALID_MV;
if (tpl_data->ref_frame[rf_idx] == NULL ||
tpl_data->src_ref_frame[rf_idx] == NULL) {
tpl_stats->mv[rf_idx].as_int = INVALID_MV;
@@ -535,6 +539,7 @@
}
tpl_stats->mv[rf_idx].as_int = best_rfidx_mv.as_int;
+ single_mv[rf_idx] = best_rfidx_mv;
struct buf_2d ref_buf = { NULL, ref_frame_ptr->y_buffer,
ref_frame_ptr->y_width, ref_frame_ptr->y_height,
@@ -567,6 +572,68 @@
}
}
+ int comp_ref_frames[3][2] = {
+ { 0, 4 },
+ { 0, 6 },
+ { 3, 6 },
+ };
+
+ xd->mi_row = mi_row;
+ xd->mi_col = mi_col;
+ for (int cmp_rf_idx = 0; cmp_rf_idx < 3; ++cmp_rf_idx) {
+ int rf_idx0 = comp_ref_frames[cmp_rf_idx][0];
+ int rf_idx1 = comp_ref_frames[cmp_rf_idx][1];
+
+ if (tpl_data->ref_frame[rf_idx0] == NULL ||
+ tpl_data->src_ref_frame[rf_idx0] == NULL ||
+ tpl_data->ref_frame[rf_idx1] == NULL ||
+ tpl_data->src_ref_frame[rf_idx1] == NULL) {
+ continue;
+ }
+
+ const YV12_BUFFER_CONFIG *ref_frame_ptr[2] = {
+ tpl_data->src_ref_frame[rf_idx0],
+ tpl_data->src_ref_frame[rf_idx1],
+ };
+
+ xd->mi[0]->ref_frame[0] = LAST_FRAME;
+ xd->mi[0]->ref_frame[1] = ALTREF_FRAME;
+
+ struct buf_2d yv12_mb[2][MAX_MB_PLANE];
+ for (int i = 0; i < 2; ++i) {
+ av1_setup_pred_block(xd, yv12_mb[i], ref_frame_ptr[i],
+ xd->block_ref_scale_factors[i],
+ xd->block_ref_scale_factors[i], MAX_MB_PLANE);
+ for (int plane = 0; plane < MAX_MB_PLANE; ++plane) {
+ xd->plane[plane].pre[i] = yv12_mb[i][plane];
+ }
+ }
+
+ int_mv tmp_mv[2] = { single_mv[rf_idx0], single_mv[rf_idx1] };
+ int rate_mv;
+ av1_joint_motion_search(cpi, x, bsize, tmp_mv, NULL, 0, &rate_mv);
+
+ for (int ref = 0; ref < 2; ++ref) {
+ struct buf_2d ref_buf = { NULL, ref_frame_ptr[ref]->y_buffer,
+ ref_frame_ptr[ref]->y_width,
+ ref_frame_ptr[ref]->y_height,
+ ref_frame_ptr[ref]->y_stride };
+ InterPredParams inter_pred_params;
+ av1_init_inter_params(&inter_pred_params, bw, bh, mi_row * MI_SIZE,
+ mi_col * MI_SIZE, 0, 0, xd->bd, is_cur_buf_hbd(xd),
+ 0, &tpl_data->sf, &ref_buf, kernel);
+ av1_init_comp_mode(&inter_pred_params);
+
+ inter_pred_params.conv_params = get_conv_params_no_round(
+ ref, 0, xd->tmp_conv_dst, MAX_SB_SIZE, 1, xd->bd);
+
+ av1_enc_build_one_inter_predictor(predictor, bw, &tmp_mv[ref].as_mv,
+ &inter_pred_params);
+ }
+ tpl_get_satd_cost(x, src_diff, bw, src_mb_buffer, src_stride, predictor, bw,
+ coeff, bw, bh, tx_size);
+ }
+
if (best_inter_cost < INT64_MAX) {
const YV12_BUFFER_CONFIG *ref_frame_ptr =
tpl_data->src_ref_frame[best_rf_idx];
@@ -1257,6 +1324,9 @@
tpl_row_mt->sync_read_ptr = av1_tpl_row_mt_sync_read_dummy;
tpl_row_mt->sync_write_ptr = av1_tpl_row_mt_sync_write_dummy;
+ av1_setup_scale_factors_for_frame(&cm->sf_identity, cm->width, cm->height,
+ cm->width, cm->height);
+
// Backward propagation from tpl_group_frames to 1.
for (int frame_idx = gf_group->index; frame_idx < tpl_gf_group_frames;
++frame_idx) {
@@ -1265,7 +1335,7 @@
continue;
init_mc_flow_dispenser(cpi, frame_idx, pframe_qindex);
- if (mt_info->num_workers > 1) {
+ if (mt_info->num_workers > 1 && !cpi->sf.tpl_sf.allow_compound_pred) {
tpl_row_mt->sync_read_ptr = av1_tpl_row_mt_sync_read;
tpl_row_mt->sync_write_ptr = av1_tpl_row_mt_sync_write;
av1_mc_flow_dispenser_mt(cpi);