Unify mask type cost computation in comp type rd
Moved the masked_type_cost computation outside the loop in
compound_type_rd()
Change-Id: I25db618b3eefadf028782523840247ac7b261794
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 0c8d0aa..4b3e09b 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -9706,6 +9706,29 @@
return cost;
}
+// Calculates the cost for compound type mask
+static INLINE void calc_masked_type_cost(MACROBLOCK *x, BLOCK_SIZE bsize,
+ int comp_group_idx_ctx,
+ int comp_index_ctx,
+ int masked_compound_used,
+ int *masked_type_cost) {
+ av1_zero_array(masked_type_cost, 4);
+ // Account for group index cost when wedge and/or diffwtd prediction are
+ // enabled
+ if (masked_compound_used) {
+ masked_type_cost[0] += x->comp_group_idx_cost[comp_group_idx_ctx][0];
+ masked_type_cost[1] += masked_type_cost[0];
+ masked_type_cost[2] += x->comp_group_idx_cost[comp_group_idx_ctx][1];
+ masked_type_cost[3] += masked_type_cost[2];
+ }
+
+ // Compute the cost to signal compound type
+ masked_type_cost[0] += x->comp_idx_cost[comp_index_ctx][1];
+ masked_type_cost[1] += x->comp_idx_cost[comp_index_ctx][0];
+ masked_type_cost[2] += x->compound_type_cost[bsize][0];
+ masked_type_cost[3] += x->compound_type_cost[bsize][1];
+}
+
static int compound_type_rd(
const AV1_COMP *const cpi, MACROBLOCK *x, BLOCK_SIZE bsize, int mi_col,
int mi_row, int_mv *cur_mv, int mode_search_mask, int masked_compound_used,
@@ -9730,6 +9753,8 @@
const int num_pix = 1 << num_pels_log2_lookup[bsize];
const int mask_len = 2 * num_pix * sizeof(uint8_t);
COMPOUND_TYPE cur_type;
+ // Local array to store the mask cost for different compound types
+ int masked_type_cost[4];
int best_compmode_interinter_cost = 0;
int calc_pred_masked_compound = 1;
int64_t comp_dist[COMPOUND_TYPES] = { INT64_MAX, INT64_MAX, INT64_MAX,
@@ -9759,6 +9784,13 @@
try_average_comp && try_distwtd_comp &&
comp_rate[COMPOUND_AVERAGE] == INT_MAX &&
comp_rate[COMPOUND_DISTWTD] == INT_MAX;
+ // The following context indices are independent of compound type
+ const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
+ const int comp_index_ctx = get_comp_index_context(cm, xd);
+
+ // Populates masked_type_cost local array for the 4 compound types
+ calc_masked_type_cost(x, bsize, comp_group_idx_ctx, comp_index_ctx,
+ masked_compound_used, masked_type_cost);
for (cur_type = COMPOUND_AVERAGE; cur_type < COMPOUND_TYPES; cur_type++) {
if (((1 << cur_type) & mode_search_mask) == 0) {
if (cur_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
@@ -9772,25 +9804,16 @@
int64_t comp_model_rd_cur = INT64_MAX;
tmp_rate_mv = *rate_mv;
int64_t best_rd_cur = INT64_MAX;
- const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
- const int comp_index_ctx = get_comp_index_context(cm, xd);
if (cur_type == COMPOUND_DISTWTD && try_average_and_distwtd_comp) {
int est_rate[2];
int64_t est_dist[2], est_rd[2];
- int masked_type_cost[2] = { 0, 0 };
mbmi->comp_group_idx = 0;
// First find the modeled rd cost for COMPOUND_AVERAGE
mbmi->interinter_comp.type = COMPOUND_AVERAGE;
mbmi->compound_idx = 1;
- if (masked_compound_used) {
- masked_type_cost[COMPOUND_AVERAGE] +=
- x->comp_group_idx_cost[comp_group_idx_ctx][mbmi->comp_group_idx];
- }
- masked_type_cost[COMPOUND_AVERAGE] +=
- x->comp_idx_cost[comp_index_ctx][mbmi->compound_idx];
av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
AOM_PLANE_Y, AOM_PLANE_Y);
*is_luma_interp_done = 1;
@@ -9806,12 +9829,6 @@
// Next find the modeled rd cost for COMPOUND_DISTWTD
mbmi->interinter_comp.type = COMPOUND_DISTWTD;
mbmi->compound_idx = 0;
- if (masked_compound_used) {
- masked_type_cost[COMPOUND_DISTWTD] +=
- x->comp_group_idx_cost[comp_group_idx_ctx][mbmi->comp_group_idx];
- }
- masked_type_cost[COMPOUND_DISTWTD] +=
- x->comp_idx_cost[comp_index_ctx][mbmi->compound_idx];
av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
AOM_PLANE_Y, AOM_PLANE_Y);
model_rd_sb_fn[MODELRD_CURVFIT](
@@ -9863,17 +9880,10 @@
}
} else {
mbmi->interinter_comp.type = cur_type;
- int masked_type_cost = 0;
if (cur_type == COMPOUND_AVERAGE || cur_type == COMPOUND_DISTWTD) {
mbmi->comp_group_idx = 0;
mbmi->compound_idx = (cur_type == COMPOUND_AVERAGE);
- if (masked_compound_used) {
- masked_type_cost +=
- x->comp_group_idx_cost[comp_group_idx_ctx][mbmi->comp_group_idx];
- }
- masked_type_cost +=
- x->comp_idx_cost[comp_index_ctx][mbmi->compound_idx];
- rs2 = masked_type_cost;
+ rs2 = masked_type_cost[cur_type];
const int64_t mode_rd = RDCOST(x->rdmult, rs2 + rd_stats->rate, 0);
if (mode_rd < ref_best_rd) {
// Reuse data if matching record is found
@@ -9918,11 +9928,7 @@
} else {
mbmi->comp_group_idx = 1;
mbmi->compound_idx = 1;
- masked_type_cost +=
- x->comp_group_idx_cost[comp_group_idx_ctx][mbmi->comp_group_idx];
- masked_type_cost +=
- x->compound_type_cost[bsize][cur_type - COMPOUND_WEDGE];
- rs2 = masked_type_cost;
+ rs2 = masked_type_cost[cur_type];
if (((*rd / cpi->max_comp_type_rd_threshold_div) *
cpi->max_comp_type_rd_threshold_mul) < ref_best_rd) {