Refactor core loop of compound_type_rd
Simplified the core loop structure and gatings in compound_type_rd
Change-Id: Ib761bdae6ef9453810d61260e12662feff289661
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 24a55ef..88731c3 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -7758,7 +7758,7 @@
}
}
-static int64_t build_and_cost_compound_type(
+static int64_t masked_compound_type_rd(
const AV1_COMP *const cpi, MACROBLOCK *x, const int_mv *const cur_mv,
const BLOCK_SIZE bsize, const PREDICTION_MODE this_mode, int *rs2,
int rate_mv, const BUFFER_SET *ctx, int *out_rate_mv, uint8_t **preds0,
@@ -9804,6 +9804,42 @@
}
}
+// Computes the valid compound_types to be evaluated in core loop
+static INLINE int compute_valid_comp_types(
+ MACROBLOCK *x, const AV1_COMP *const cpi, BLOCK_SIZE bsize,
+ int try_average_comp, int try_distwtd_comp,
+ int try_average_and_distwtd_comp, int masked_compound_used,
+ int mode_search_mask, COMPOUND_TYPE *valid_comp_types) {
+ int valid_type_count = 0;
+ int comp_type, valid_check;
+ int8_t enable_masked_type[MASKED_COMPOUND_TYPES] = { 0, 0 };
+
+ // Check if COMPOUND_AVERAGE and COMPOUND_DISTWTD are valid cases
+ for (comp_type = COMPOUND_AVERAGE; comp_type <= COMPOUND_DISTWTD;
+ comp_type++) {
+ valid_check =
+ (comp_type == COMPOUND_AVERAGE) ? try_average_comp : try_distwtd_comp;
+ if (!try_average_and_distwtd_comp && valid_check &&
+ is_interinter_compound_used(comp_type, bsize))
+ valid_comp_types[valid_type_count++] = comp_type;
+ }
+ // Check if COMPOUND_WEDGE and COMPOUND_DIFFWTD are valid cases
+ if (masked_compound_used) {
+ // enable_masked_type[0] corresponds to COMPOUND_WEDGE
+ // enable_masked_type[1] corresponds to COMPOUND_DIFFWTD
+ enable_masked_type[0] = enable_wedge_interinter_search(x, cpi);
+ enable_masked_type[1] = cpi->oxcf.enable_diff_wtd_comp;
+ for (comp_type = COMPOUND_WEDGE; comp_type <= COMPOUND_DIFFWTD;
+ comp_type++) {
+ if ((mode_search_mask & (1 << comp_type)) &&
+ is_interinter_compound_used(comp_type, bsize) &&
+ enable_masked_type[comp_type - COMPOUND_WEDGE])
+ valid_comp_types[valid_type_count++] = comp_type;
+ }
+ }
+ return valid_type_count;
+}
+
static int compound_type_rd(
const AV1_COMP *const cpi, MACROBLOCK *x, BLOCK_SIZE bsize, int mi_col,
int mi_row, int_mv *cur_mv, int mode_search_mask, int masked_compound_used,
@@ -9910,6 +9946,7 @@
const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd;
const int64_t est_rd_ =
estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &est_rd_stats);
+
if (est_rd_ != INT64_MAX) {
best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
est_rd_stats.dist);
@@ -9926,23 +9963,25 @@
}
}
- for (cur_type = COMPOUND_AVERAGE; cur_type < COMPOUND_TYPES; cur_type++) {
- if (((1 << cur_type) & mode_search_mask) == 0) {
- if (cur_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
- continue;
- }
- if (!is_interinter_compound_used(cur_type, bsize)) continue;
- if (cur_type >= COMPOUND_WEDGE && !masked_compound_used) break;
- if (cur_type == COMPOUND_DISTWTD &&
- (!try_distwtd_comp || try_average_and_distwtd_comp))
- continue;
- if (cur_type == COMPOUND_AVERAGE && try_average_and_distwtd_comp) continue;
+ COMPOUND_TYPE valid_comp_types[COMPOUND_TYPES] = {
+ COMPOUND_AVERAGE, COMPOUND_DISTWTD, COMPOUND_WEDGE, COMPOUND_DIFFWTD
+ };
+ int valid_type_count = 0;
+ // compute_valid_comp_types() returns the number of valid compound types to be
+ // evaluated and populates the same in the local array valid_comp_types[]
+ valid_type_count = compute_valid_comp_types(
+ x, cpi, bsize, try_average_comp, try_distwtd_comp,
+ try_average_and_distwtd_comp, masked_compound_used, mode_search_mask,
+ valid_comp_types);
+ if (valid_comp_types[0] != COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
+ for (int i = 0; i < valid_type_count; i++) {
+ cur_type = valid_comp_types[i];
comp_model_rd_cur = INT64_MAX;
tmp_rate_mv = *rate_mv;
best_rd_cur = INT64_MAX;
- if (cur_type == COMPOUND_AVERAGE || cur_type == COMPOUND_DISTWTD) {
+ if (cur_type < COMPOUND_WEDGE) {
update_mbmi_for_compound_type(mbmi, cur_type);
rs2 = masked_type_cost[cur_type];
const int64_t mode_rd = RDCOST(x->rdmult, rs2 + rd_stats->rate, 0);
@@ -9987,20 +10026,13 @@
cpi->max_comp_type_rd_threshold_mul);
if (approx_rd < ref_best_rd) {
- int8_t enable_wedge = ((cur_type == COMPOUND_WEDGE) &&
- enable_wedge_interinter_search(x, cpi));
- int8_t enable_diffwtd =
- ((cur_type == COMPOUND_DIFFWTD) && cpi->oxcf.enable_diff_wtd_comp);
-
- if (enable_wedge || enable_diffwtd) {
- const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh);
- best_rd_cur = build_and_cost_compound_type(
- cpi, x, cur_mv, bsize, this_mode, &rs2, *rate_mv, orig_dst,
- &tmp_rate_mv, preds0, preds1, buffers->residual1, buffers->diff10,
- strides, mi_row, mi_col, rd_stats->rate, tmp_rd_thresh,
- &calc_pred_masked_compound, comp_rate, comp_dist, comp_model_rd,
- best_type_stats.comp_best_model_rd, &comp_model_rd_cur);
- }
+ const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh);
+ best_rd_cur = masked_compound_type_rd(
+ cpi, x, cur_mv, bsize, this_mode, &rs2, *rate_mv, orig_dst,
+ &tmp_rate_mv, preds0, preds1, buffers->residual1, buffers->diff10,
+ strides, mi_row, mi_col, rd_stats->rate, tmp_rd_thresh,
+ &calc_pred_masked_compound, comp_rate, comp_dist, comp_model_rd,
+ best_type_stats.comp_best_model_rd, &comp_model_rd_cur);
}
}
if (best_rd_cur < *rd) {