Refactor try_average_and_distwtd_comp case
Moved try_average_and_distwtd_comp case out of
the core loop in compound_type_rd and refactored
it to remove redundant code
Change-Id: I822a6c70e908e0c717eedaab25d37c8032cd8c06
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index d700a27..24a55ef 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -9870,6 +9870,62 @@
// Populates masked_type_cost local array for the 4 compound types
calc_masked_type_cost(x, bsize, comp_group_idx_ctx, comp_index_ctx,
masked_compound_used, masked_type_cost);
+
+ int64_t comp_model_rd_cur = INT64_MAX;
+ int64_t best_rd_cur = INT64_MAX;
+
+ // Special case of COMPOUND_AVERAGE and COMPOUND_DISTWTD search
+ if (try_average_and_distwtd_comp) {
+ int est_rate[2];
+ int64_t est_dist[2], est_rd[2];
+ COMPOUND_TYPE best_type;
+
+ // Calculate model_rd for COMPOUND_AVERAGE and COMPOUND_DISTWTD
+ for (int comp_type = COMPOUND_AVERAGE; comp_type <= COMPOUND_DISTWTD;
+ comp_type++) {
+ update_mbmi_for_compound_type(mbmi, comp_type);
+ av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
+ AOM_PLANE_Y, AOM_PLANE_Y);
+ model_rd_sb_fn[MODELRD_CURVFIT](
+ cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &est_rate[comp_type],
+ &est_dist[comp_type], NULL, NULL, NULL, NULL, NULL);
+ est_rate[comp_type] += masked_type_cost[comp_type];
+ est_rd[comp_type] = RDCOST(x->rdmult, est_rate[comp_type] + *rate_mv,
+ est_dist[comp_type]);
+ if (comp_type == COMPOUND_AVERAGE) {
+ *is_luma_interp_done = 1;
+ restore_dst_buf(xd, *tmp_dst, 1);
+ }
+ }
+ // Choose the better of the two based on modeled cost and call
+ // estimate_yrd_for_sb() for that one.
+ best_type = (est_rd[COMPOUND_AVERAGE] <= est_rd[COMPOUND_DISTWTD])
+ ? COMPOUND_AVERAGE
+ : COMPOUND_DISTWTD;
+ update_mbmi_for_compound_type(mbmi, best_type);
+ if (best_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *orig_dst, 1);
+ rs2 = masked_type_cost[best_type];
+ RD_STATS est_rd_stats;
+ const int64_t mode_rd = RDCOST(x->rdmult, rs2 + *rate_mv, 0);
+ const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd;
+ const int64_t est_rd_ =
+ estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &est_rd_stats);
+ if (est_rd_ != INT64_MAX) {
+ best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
+ est_rd_stats.dist);
+ // Backup rate and distortion for future reuse
+ comp_rate[best_type] = est_rd_stats.rate;
+ comp_dist[best_type] = est_rd_stats.dist;
+ comp_model_rd[best_type] = est_rd[best_type];
+ comp_model_rd_cur = est_rd[best_type];
+ }
+ if (best_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
+ if (best_rd_cur < *rd) {
+ update_best_info(mbmi, rd, &best_type_stats, best_rd_cur,
+ comp_model_rd_cur, rs2);
+ }
+ }
+
for (cur_type = COMPOUND_AVERAGE; cur_type < COMPOUND_TYPES; cur_type++) {
if (((1 << cur_type) & mode_search_mask) == 0) {
if (cur_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
@@ -9877,144 +9933,73 @@
}
if (!is_interinter_compound_used(cur_type, bsize)) continue;
if (cur_type >= COMPOUND_WEDGE && !masked_compound_used) break;
- if (cur_type == COMPOUND_DISTWTD && !try_distwtd_comp) continue;
+ if (cur_type == COMPOUND_DISTWTD &&
+ (!try_distwtd_comp || try_average_and_distwtd_comp))
+ continue;
if (cur_type == COMPOUND_AVERAGE && try_average_and_distwtd_comp) continue;
- int64_t comp_model_rd_cur = INT64_MAX;
+ comp_model_rd_cur = INT64_MAX;
tmp_rate_mv = *rate_mv;
- int64_t best_rd_cur = INT64_MAX;
+ best_rd_cur = INT64_MAX;
- if (cur_type == COMPOUND_DISTWTD && try_average_and_distwtd_comp) {
- int est_rate[2];
- int64_t est_dist[2], est_rd[2];
+ if (cur_type == COMPOUND_AVERAGE || cur_type == COMPOUND_DISTWTD) {
+ update_mbmi_for_compound_type(mbmi, cur_type);
+ rs2 = masked_type_cost[cur_type];
+ const int64_t mode_rd = RDCOST(x->rdmult, rs2 + rd_stats->rate, 0);
+ if (mode_rd < ref_best_rd) {
+ // Reuse data if matching record is found
+ if (comp_rate[cur_type] == INT_MAX) {
+ av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
+ AOM_PLANE_Y, AOM_PLANE_Y);
+ if (cur_type == COMPOUND_AVERAGE) *is_luma_interp_done = 1;
+ RD_STATS est_rd_stats;
+ const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd;
+ const int64_t est_rd =
+ estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &est_rd_stats);
+ if (est_rd != INT64_MAX) {
+ best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
+ est_rd_stats.dist);
+ model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
+ cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &rate_sum, &dist_sum,
+ &tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL, NULL);
+ comp_model_rd_cur =
+ RDCOST(x->rdmult, rs2 + *rate_mv + rate_sum, dist_sum);
- // First find the modeled rd cost for COMPOUND_AVERAGE
- update_mbmi_for_compound_type(mbmi, COMPOUND_AVERAGE);
- av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
- AOM_PLANE_Y, AOM_PLANE_Y);
- *is_luma_interp_done = 1;
- model_rd_sb_fn[MODELRD_CURVFIT](
- cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &est_rate[COMPOUND_AVERAGE],
- &est_dist[COMPOUND_AVERAGE], NULL, NULL, NULL, NULL, NULL);
- est_rate[COMPOUND_AVERAGE] += masked_type_cost[COMPOUND_AVERAGE];
- est_rd[COMPOUND_AVERAGE] =
- RDCOST(x->rdmult, est_rate[COMPOUND_AVERAGE] + *rate_mv,
- est_dist[COMPOUND_AVERAGE]);
- restore_dst_buf(xd, *tmp_dst, 1);
-
- // Next find the modeled rd cost for COMPOUND_DISTWTD
- update_mbmi_for_compound_type(mbmi, COMPOUND_DISTWTD);
- av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
- AOM_PLANE_Y, AOM_PLANE_Y);
- model_rd_sb_fn[MODELRD_CURVFIT](
- cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &est_rate[COMPOUND_DISTWTD],
- &est_dist[COMPOUND_DISTWTD], NULL, NULL, NULL, NULL, NULL);
- est_rate[COMPOUND_DISTWTD] += masked_type_cost[COMPOUND_DISTWTD];
- est_rd[COMPOUND_DISTWTD] =
- RDCOST(x->rdmult, est_rate[COMPOUND_DISTWTD] + *rate_mv,
- est_dist[COMPOUND_DISTWTD]);
-
- // Choose the better of the two based on modeled cost and call
- // estimate_yrd_for_sb() for that one.
- if (est_rd[COMPOUND_AVERAGE] <= est_rd[COMPOUND_DISTWTD]) {
- update_mbmi_for_compound_type(mbmi, COMPOUND_AVERAGE);
- restore_dst_buf(xd, *orig_dst, 1);
- rs2 = masked_type_cost[COMPOUND_AVERAGE];
- RD_STATS est_rd_stats;
- const int64_t mode_rd = RDCOST(x->rdmult, rs2 + *rate_mv, 0);
- const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd;
- const int64_t est_rd_ =
- estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &est_rd_stats);
- if (est_rd_ != INT64_MAX) {
- best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
- est_rd_stats.dist);
- restore_dst_buf(xd, *tmp_dst, 1);
- comp_rate[COMPOUND_AVERAGE] = est_rd_stats.rate;
- comp_dist[COMPOUND_AVERAGE] = est_rd_stats.dist;
- comp_model_rd[COMPOUND_AVERAGE] = est_rd[COMPOUND_AVERAGE];
- comp_model_rd_cur = est_rd[COMPOUND_AVERAGE];
- }
- restore_dst_buf(xd, *tmp_dst, 1);
- } else {
- rs2 = masked_type_cost[COMPOUND_DISTWTD];
- RD_STATS est_rd_stats;
- const int64_t mode_rd = RDCOST(x->rdmult, rs2 + *rate_mv, 0);
- const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd;
- const int64_t est_rd_ =
- estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &est_rd_stats);
-
- if (est_rd_ != INT64_MAX) {
- best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
- est_rd_stats.dist);
- comp_rate[COMPOUND_DISTWTD] = est_rd_stats.rate;
- comp_dist[COMPOUND_DISTWTD] = est_rd_stats.dist;
- comp_model_rd[COMPOUND_DISTWTD] = est_rd[COMPOUND_DISTWTD];
- comp_model_rd_cur = est_rd[COMPOUND_DISTWTD];
+ // Backup rate and distortion for future reuse
+ comp_rate[cur_type] = est_rd_stats.rate;
+ comp_dist[cur_type] = est_rd_stats.dist;
+ comp_model_rd[cur_type] = comp_model_rd_cur;
+ }
+ } else {
+ // Calculate RD cost based on stored stats
+ assert(comp_dist[cur_type] != INT64_MAX);
+ best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + comp_rate[cur_type],
+ comp_dist[cur_type]);
+ comp_model_rd_cur = comp_model_rd[cur_type];
}
}
+ // use spare buffer for following compound type try
+ if (cur_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
} else {
- if (cur_type == COMPOUND_AVERAGE || cur_type == COMPOUND_DISTWTD) {
- update_mbmi_for_compound_type(mbmi, cur_type);
- rs2 = masked_type_cost[cur_type];
- const int64_t mode_rd = RDCOST(x->rdmult, rs2 + rd_stats->rate, 0);
- if (mode_rd < ref_best_rd) {
- // Reuse data if matching record is found
- if (comp_rate[cur_type] == INT_MAX) {
- av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst,
- bsize, AOM_PLANE_Y, AOM_PLANE_Y);
- if (cur_type == COMPOUND_AVERAGE) *is_luma_interp_done = 1;
- RD_STATS est_rd_stats;
- const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd;
- const int64_t est_rd = estimate_yrd_for_sb(
- cpi, bsize, x, tmp_rd_thresh, &est_rd_stats);
- if (est_rd != INT64_MAX) {
- best_rd_cur =
- RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
- est_rd_stats.dist);
- model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
- cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &rate_sum, &dist_sum,
- &tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL, NULL);
- comp_model_rd_cur =
- RDCOST(x->rdmult, rs2 + *rate_mv + rate_sum, dist_sum);
+ update_mbmi_for_compound_type(mbmi, cur_type);
+ rs2 = masked_type_cost[cur_type];
+ int64_t approx_rd = ((*rd / cpi->max_comp_type_rd_threshold_div) *
+ cpi->max_comp_type_rd_threshold_mul);
- // Backup rate and distortion for future reuse
- comp_rate[cur_type] = est_rd_stats.rate;
- comp_dist[cur_type] = est_rd_stats.dist;
- comp_model_rd[cur_type] = comp_model_rd_cur;
- }
- } else {
- // Calculate RD cost based on stored stats
- assert(comp_dist[cur_type] != INT64_MAX);
- best_rd_cur =
- RDCOST(x->rdmult, rs2 + *rate_mv + comp_rate[cur_type],
- comp_dist[cur_type]);
- comp_model_rd_cur = comp_model_rd[cur_type];
- }
- }
- // use spare buffer for following compound type try
- if (cur_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
- } else {
- update_mbmi_for_compound_type(mbmi, cur_type);
- rs2 = masked_type_cost[cur_type];
- int64_t approx_rd = ((*rd / cpi->max_comp_type_rd_threshold_div) *
- cpi->max_comp_type_rd_threshold_mul);
+ if (approx_rd < ref_best_rd) {
+ int8_t enable_wedge = ((cur_type == COMPOUND_WEDGE) &&
+ enable_wedge_interinter_search(x, cpi));
+ int8_t enable_diffwtd =
+ ((cur_type == COMPOUND_DIFFWTD) && cpi->oxcf.enable_diff_wtd_comp);
- if (approx_rd < ref_best_rd) {
- int8_t enable_wedge = ((cur_type == COMPOUND_WEDGE) &&
- enable_wedge_interinter_search(x, cpi));
- int8_t enable_diffwtd = ((cur_type == COMPOUND_DIFFWTD) &&
- cpi->oxcf.enable_diff_wtd_comp);
-
- if (enable_wedge || enable_diffwtd) {
- const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh);
- best_rd_cur = build_and_cost_compound_type(
- cpi, x, cur_mv, bsize, this_mode, &rs2, *rate_mv, orig_dst,
- &tmp_rate_mv, preds0, preds1, buffers->residual1,
- buffers->diff10, strides, mi_row, mi_col, rd_stats->rate,
- tmp_rd_thresh, &calc_pred_masked_compound, comp_rate, comp_dist,
- comp_model_rd, best_type_stats.comp_best_model_rd,
- &comp_model_rd_cur);
- }
+ if (enable_wedge || enable_diffwtd) {
+ const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh);
+ best_rd_cur = build_and_cost_compound_type(
+ cpi, x, cur_mv, bsize, this_mode, &rs2, *rate_mv, orig_dst,
+ &tmp_rate_mv, preds0, preds1, buffers->residual1, buffers->diff10,
+ strides, mi_row, mi_col, rd_stats->rate, tmp_rd_thresh,
+ &calc_pred_masked_compound, comp_rate, comp_dist, comp_model_rd,
+ best_type_stats.comp_best_model_rd, &comp_model_rd_cur);
}
}
}