Streamline compound types
This patch includes various refactoring towards integrating
dist_wtd compound type search with the other compound modes.
The current patch keeps separate rd loops for distwtd compound
and other compound modes through the macro SEPARATE_COMP_DISTWTD_RD.
Currently it is set as 1 so that no change is introduced in
the coding stats.
Experiments are in progress for turning this macro off and
cleaning up the code.
Change-Id: Ibc29fd4d088f977053cdd38fc03ea2005b073c07
diff --git a/av1/common/entropy.c b/av1/common/entropy.c
index 4f95ef6..f63ac98 100644
--- a/av1/common/entropy.c
+++ b/av1/common/entropy.c
@@ -101,7 +101,7 @@
RESET_CDF_COUNTER(fc->refmv_cdf, 2);
RESET_CDF_COUNTER(fc->drl_cdf, 2);
RESET_CDF_COUNTER(fc->inter_compound_mode_cdf, INTER_COMPOUND_MODES);
- RESET_CDF_COUNTER(fc->compound_type_cdf, COMPOUND_TYPES - 1);
+ RESET_CDF_COUNTER(fc->compound_type_cdf, MASKED_COMPOUND_TYPES);
RESET_CDF_COUNTER(fc->wedge_idx_cdf, 16);
RESET_CDF_COUNTER(fc->interintra_cdf, 2);
RESET_CDF_COUNTER(fc->wedge_interintra_cdf, 2);
diff --git a/av1/common/entropymode.c b/av1/common/entropymode.c
index 8e7e952..90702ac 100644
--- a/av1/common/entropymode.c
+++ b/av1/common/entropymode.c
@@ -488,17 +488,17 @@
{ AOM_CDF2(16384) }
};
-static const aom_cdf_prob
- default_compound_type_cdf[BLOCK_SIZES_ALL][CDF_SIZE(COMPOUND_TYPES - 1)] = {
- { AOM_CDF2(16384) }, { AOM_CDF2(16384) }, { AOM_CDF2(16384) },
- { AOM_CDF2(23431) }, { AOM_CDF2(13171) }, { AOM_CDF2(11470) },
- { AOM_CDF2(9770) }, { AOM_CDF2(9100) }, { AOM_CDF2(8233) },
- { AOM_CDF2(6172) }, { AOM_CDF2(16384) }, { AOM_CDF2(16384) },
- { AOM_CDF2(16384) }, { AOM_CDF2(16384) }, { AOM_CDF2(16384) },
- { AOM_CDF2(16384) }, { AOM_CDF2(16384) }, { AOM_CDF2(16384) },
- { AOM_CDF2(11820) }, { AOM_CDF2(7701) }, { AOM_CDF2(16384) },
- { AOM_CDF2(16384) }
- };
+static const aom_cdf_prob default_compound_type_cdf[BLOCK_SIZES_ALL][CDF_SIZE(
+ MASKED_COMPOUND_TYPES)] = {
+ { AOM_CDF2(16384) }, { AOM_CDF2(16384) }, { AOM_CDF2(16384) },
+ { AOM_CDF2(23431) }, { AOM_CDF2(13171) }, { AOM_CDF2(11470) },
+ { AOM_CDF2(9770) }, { AOM_CDF2(9100) }, { AOM_CDF2(8233) },
+ { AOM_CDF2(6172) }, { AOM_CDF2(16384) }, { AOM_CDF2(16384) },
+ { AOM_CDF2(16384) }, { AOM_CDF2(16384) }, { AOM_CDF2(16384) },
+ { AOM_CDF2(16384) }, { AOM_CDF2(16384) }, { AOM_CDF2(16384) },
+ { AOM_CDF2(11820) }, { AOM_CDF2(7701) }, { AOM_CDF2(16384) },
+ { AOM_CDF2(16384) }
+};
static const aom_cdf_prob default_wedge_idx_cdf[BLOCK_SIZES_ALL][CDF_SIZE(16)] =
{ { AOM_CDF16(2048, 4096, 6144, 8192, 10240, 12288, 14336, 16384, 18432,
diff --git a/av1/common/entropymode.h b/av1/common/entropymode.h
index 7047f34..69b5218 100644
--- a/av1/common/entropymode.h
+++ b/av1/common/entropymode.h
@@ -92,7 +92,8 @@
aom_cdf_prob inter_compound_mode_cdf[INTER_MODE_CONTEXTS]
[CDF_SIZE(INTER_COMPOUND_MODES)];
- aom_cdf_prob compound_type_cdf[BLOCK_SIZES_ALL][CDF_SIZE(COMPOUND_TYPES - 1)];
+ aom_cdf_prob compound_type_cdf[BLOCK_SIZES_ALL]
+ [CDF_SIZE(MASKED_COMPOUND_TYPES)];
aom_cdf_prob wedge_idx_cdf[BLOCK_SIZES_ALL][CDF_SIZE(16)];
aom_cdf_prob interintra_cdf[BLOCK_SIZE_GROUPS][CDF_SIZE(2)];
aom_cdf_prob wedge_interintra_cdf[BLOCK_SIZES_ALL][CDF_SIZE(2)];
diff --git a/av1/common/enums.h b/av1/common/enums.h
index 9af7b0f..97fa682 100644
--- a/av1/common/enums.h
+++ b/av1/common/enums.h
@@ -455,9 +455,11 @@
enum {
COMPOUND_AVERAGE,
+ COMPOUND_DISTWTD,
COMPOUND_WEDGE,
COMPOUND_DIFFWTD,
COMPOUND_TYPES,
+ MASKED_COMPOUND_TYPES = 2,
} UENUM1BYTE(COMPOUND_TYPE);
enum {
diff --git a/av1/common/reconinter.h b/av1/common/reconinter.h
index 1385be5..4d62991 100644
--- a/av1/common/reconinter.h
+++ b/av1/common/reconinter.h
@@ -167,6 +167,7 @@
const int comp_allowed = is_comp_ref_allowed(sb_type);
switch (type) {
case COMPOUND_AVERAGE:
+ case COMPOUND_DISTWTD:
case COMPOUND_DIFFWTD: return comp_allowed;
case COMPOUND_WEDGE:
return comp_allowed && wedge_params_lookup[sb_type].bits > 0;
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index b7431f2..b8fcb5d 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -1421,9 +1421,12 @@
const int comp_index_ctx = get_comp_index_context(cm, xd);
mbmi->compound_idx = aom_read_symbol(
r, ec_ctx->compound_index_cdf[comp_index_ctx], 2, ACCT_STR);
+ mbmi->interinter_comp.type =
+ mbmi->compound_idx ? COMPOUND_AVERAGE : COMPOUND_DISTWTD;
} else {
// Distance-weighted compound is disabled, so always use average
mbmi->compound_idx = 1;
+ mbmi->interinter_comp.type = COMPOUND_AVERAGE;
}
} else {
assert(cm->current_frame.reference_mode != SINGLE_REFERENCE &&
@@ -1434,8 +1437,9 @@
// compound_diffwtd, wedge
if (is_interinter_compound_used(COMPOUND_WEDGE, bsize))
mbmi->interinter_comp.type =
- 1 + aom_read_symbol(r, ec_ctx->compound_type_cdf[bsize],
- COMPOUND_TYPES - 1, ACCT_STR);
+ COMPOUND_WEDGE + aom_read_symbol(r,
+ ec_ctx->compound_type_cdf[bsize],
+ MASKED_COMPOUND_TYPES, ACCT_STR);
else
mbmi->interinter_comp.type = COMPOUND_DIFFWTD;
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 019ecc4..de43624 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1162,9 +1162,9 @@
mbmi->interinter_comp.type == COMPOUND_DIFFWTD);
if (is_interinter_compound_used(COMPOUND_WEDGE, bsize))
- aom_write_symbol(w, mbmi->interinter_comp.type - 1,
+ aom_write_symbol(w, mbmi->interinter_comp.type - COMPOUND_WEDGE,
ec_ctx->compound_type_cdf[bsize],
- COMPOUND_TYPES - 1);
+ MASKED_COMPOUND_TYPES);
if (mbmi->interinter_comp.type == COMPOUND_WEDGE) {
assert(is_interinter_compound_used(COMPOUND_WEDGE, bsize));
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index e96ca16..fa0ab80 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -351,7 +351,7 @@
// BWDREF_FRAME) in bidir-comp mode.
int comp_bwdref_cost[REF_CONTEXTS][BWD_REFS - 1][2];
int inter_compound_mode_cost[INTER_MODE_CONTEXTS][INTER_COMPOUND_MODES];
- int compound_type_cost[BLOCK_SIZES_ALL][COMPOUND_TYPES - 1];
+ int compound_type_cost[BLOCK_SIZES_ALL][MASKED_COMPOUND_TYPES];
int wedge_idx_cost[BLOCK_SIZES_ALL][16];
int interintra_cost[BLOCK_SIZE_GROUPS][2];
int wedge_interintra_cost[BLOCK_SIZES_ALL][2];
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 5b10714..94bd4f0 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -1290,11 +1290,13 @@
assert(masked_compound_used);
if (is_interinter_compound_used(COMPOUND_WEDGE, bsize)) {
#if CONFIG_ENTROPY_STATS
- ++counts->compound_type[bsize][mbmi->interinter_comp.type - 1];
+ ++counts->compound_type[bsize][mbmi->interinter_comp.type -
+ COMPOUND_WEDGE];
#endif
if (allow_update_cdf) {
update_cdf(fc->compound_type_cdf[bsize],
- mbmi->interinter_comp.type - 1, COMPOUND_TYPES - 1);
+ mbmi->interinter_comp.type - COMPOUND_WEDGE,
+ MASKED_COMPOUND_TYPES);
}
}
}
@@ -5431,7 +5433,7 @@
AVERAGE_CDF(ctx_left->inter_compound_mode_cdf,
ctx_tr->inter_compound_mode_cdf, INTER_COMPOUND_MODES);
AVERAGE_CDF(ctx_left->compound_type_cdf, ctx_tr->compound_type_cdf,
- COMPOUND_TYPES - 1);
+ MASKED_COMPOUND_TYPES);
AVERAGE_CDF(ctx_left->wedge_idx_cdf, ctx_tr->wedge_idx_cdf, 16);
AVERAGE_CDF(ctx_left->interintra_cdf, ctx_tr->interintra_cdf, 2);
AVERAGE_CDF(ctx_left->wedge_interintra_cdf, ctx_tr->wedge_interintra_cdf, 2);
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index c0f50bf..7b0af8f 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -419,7 +419,7 @@
unsigned int interintra[BLOCK_SIZE_GROUPS][2];
unsigned int interintra_mode[BLOCK_SIZE_GROUPS][INTERINTRA_MODES];
unsigned int wedge_interintra[BLOCK_SIZES_ALL][2];
- unsigned int compound_type[BLOCK_SIZES_ALL][COMPOUND_TYPES - 1];
+ unsigned int compound_type[BLOCK_SIZES_ALL][MASKED_COMPOUND_TYPES];
unsigned int motion_mode[BLOCK_SIZES_ALL][MOTION_MODES];
unsigned int obmc[BLOCK_SIZES_ALL][2];
unsigned int intra_inter[INTRA_INTER_CONTEXTS][2];
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 9e6d6b1..7598d7c 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -8433,16 +8433,20 @@
}
// Store the stats for compound average
- comp_rate[0] = st->rate[0];
- comp_dist[0] = st->dist[0];
+ comp_rate[COMPOUND_AVERAGE] = st->rate[COMPOUND_AVERAGE];
+ comp_dist[COMPOUND_AVERAGE] = st->dist[COMPOUND_AVERAGE];
+ comp_rate[COMPOUND_DISTWTD] = st->rate[COMPOUND_DISTWTD];
+ comp_dist[COMPOUND_DISTWTD] = st->dist[COMPOUND_DISTWTD];
// For compound wedge/segment, reuse data only if NEWMV is not present in
// either of the directions
if ((!have_newmv_in_inter_mode(mi->mode) &&
!have_newmv_in_inter_mode(st->mode)) ||
(cpi->sf.disable_interinter_wedge_newmv_search)) {
- memcpy(&comp_rate[1], &st->rate[1], sizeof(comp_rate[1]) * 2);
- memcpy(&comp_dist[1], &st->dist[1], sizeof(comp_dist[1]) * 2);
+ memcpy(&comp_rate[COMPOUND_WEDGE], &st->rate[COMPOUND_WEDGE],
+ sizeof(comp_rate[COMPOUND_WEDGE]) * 2);
+ memcpy(&comp_dist[COMPOUND_WEDGE], &st->dist[COMPOUND_WEDGE],
+ sizeof(comp_dist[COMPOUND_WEDGE]) * 2);
}
return 1;
}
@@ -9720,14 +9724,12 @@
uint8_t *tmp_best_mask_buf; // backup of the best segmentation mask
} CompoundTypeRdBuffers;
-static int compound_type_rd(const AV1_COMP *const cpi, MACROBLOCK *x,
- BLOCK_SIZE bsize, int mi_col, int mi_row,
- int_mv *cur_mv, int masked_compound_used,
- const BUFFER_SET *orig_dst,
- const BUFFER_SET *tmp_dst,
- CompoundTypeRdBuffers *buffers, int *rate_mv,
- int64_t *rd, RD_STATS *rd_stats,
- int64_t ref_best_rd, int *is_luma_interp_done) {
+static int compound_type_rd(
+ const AV1_COMP *const cpi, MACROBLOCK *x, BLOCK_SIZE bsize, int mi_col,
+ int mi_row, int_mv *cur_mv, int do_comp_distwtd, int masked_compound_used,
+ const BUFFER_SET *orig_dst, const BUFFER_SET *tmp_dst,
+ CompoundTypeRdBuffers *buffers, int *rate_mv, int64_t *rd,
+ RD_STATS *rd_stats, int64_t ref_best_rd, int *is_luma_interp_done) {
const AV1_COMMON *cm = &cpi->common;
MACROBLOCKD *xd = &x->e_mbd;
MB_MODE_INFO *mbmi = xd->mi[0];
@@ -9747,16 +9749,24 @@
COMPOUND_TYPE cur_type;
int best_compmode_interinter_cost = 0;
int calc_pred_masked_compound = 1;
- int64_t comp_dist[COMPOUND_TYPES] = { INT64_MAX, INT64_MAX, INT64_MAX };
- int32_t comp_rate[COMPOUND_TYPES] = { INT_MAX, INT_MAX, INT_MAX };
+ int64_t comp_dist[COMPOUND_TYPES] = { INT64_MAX, INT64_MAX, INT64_MAX,
+ INT64_MAX };
+ int32_t comp_rate[COMPOUND_TYPES] = { INT_MAX, INT_MAX, INT_MAX, INT_MAX };
+ // TODO(debargha): Remove the code related to comp_rd_stats since it is
+ // not used.
const int match_found =
find_comp_rd_in_stats(cpi, x, mbmi, comp_rate, comp_dist);
best_mv[0].as_int = cur_mv[0].as_int;
best_mv[1].as_int = cur_mv[1].as_int;
*rd = INT64_MAX;
for (cur_type = COMPOUND_AVERAGE; cur_type < COMPOUND_TYPES; cur_type++) {
- if (cur_type != COMPOUND_AVERAGE && !masked_compound_used) break;
+ if (cur_type >= COMPOUND_WEDGE && !masked_compound_used) break;
if (!is_interinter_compound_used(cur_type, bsize)) continue;
+ if (cur_type == COMPOUND_DISTWTD &&
+ (!do_comp_distwtd ||
+ !cm->seq_params.order_hint_info.enable_dist_wtd_comp ||
+ cpi->sf.use_dist_wtd_comp_flag == DIST_WTD_COMP_DISABLED))
+ continue;
tmp_rate_mv = *rate_mv;
int64_t best_rd_cur = INT64_MAX;
mbmi->interinter_comp.type = cur_type;
@@ -9764,48 +9774,52 @@
const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
const int comp_index_ctx = get_comp_index_context(cm, xd);
- mbmi->compound_idx = 1;
- if (cur_type == COMPOUND_AVERAGE) {
+ if (cur_type == COMPOUND_AVERAGE || cur_type == COMPOUND_DISTWTD) {
mbmi->comp_group_idx = 0;
+ mbmi->compound_idx = (cur_type == COMPOUND_AVERAGE);
if (masked_compound_used) {
- masked_type_cost += x->comp_group_idx_cost[comp_group_idx_ctx][0];
+ masked_type_cost +=
+ x->comp_group_idx_cost[comp_group_idx_ctx][mbmi->comp_group_idx];
}
- masked_type_cost += x->comp_idx_cost[comp_index_ctx][1];
+ masked_type_cost += x->comp_idx_cost[comp_index_ctx][mbmi->compound_idx];
rs2 = masked_type_cost;
const int64_t mode_rd = RDCOST(x->rdmult, rs2 + rd_stats->rate, 0);
if (mode_rd < ref_best_rd) {
// Reuse data if matching record is found
- if (comp_rate[0] == INT_MAX) {
+ if (comp_rate[cur_type] == INT_MAX) {
av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
AOM_PLANE_Y, AOM_PLANE_Y);
- *is_luma_interp_done = 1;
+ if (cur_type == COMPOUND_AVERAGE) *is_luma_interp_done = 1;
RD_STATS est_rd_stats;
const int64_t est_rd =
estimate_yrd_for_sb(cpi, bsize, x, INT64_MAX, &est_rd_stats);
- if (comp_rate[0] != INT_MAX) {
- assert(comp_rate[0] == est_rd_stats.rate);
- assert(comp_dist[0] == est_rd_stats.dist);
+ if (comp_rate[cur_type] != INT_MAX) {
+ assert(comp_rate[cur_type] == est_rd_stats.rate);
+ assert(comp_dist[cur_type] == est_rd_stats.dist);
}
if (est_rd != INT64_MAX) {
best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
est_rd_stats.dist);
// Backup rate and distortion for future reuse
- comp_rate[0] = est_rd_stats.rate;
- comp_dist[0] = est_rd_stats.dist;
+ comp_rate[cur_type] = est_rd_stats.rate;
+ comp_dist[cur_type] = est_rd_stats.dist;
}
} else {
// Calculate RD cost based on stored stats
- assert(comp_dist[0] != INT64_MAX);
- best_rd_cur =
- RDCOST(x->rdmult, rs2 + *rate_mv + comp_rate[0], comp_dist[0]);
+ assert(comp_dist[cur_type] != INT64_MAX);
+ best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + comp_rate[cur_type],
+ comp_dist[cur_type]);
}
}
// use spare buffer for following compound type try
restore_dst_buf(xd, *tmp_dst, 1);
} else {
mbmi->comp_group_idx = 1;
- masked_type_cost += x->comp_group_idx_cost[comp_group_idx_ctx][1];
- masked_type_cost += x->compound_type_cost[bsize][cur_type - 1];
+ mbmi->compound_idx = 1;
+ masked_type_cost +=
+ x->comp_group_idx_cost[comp_group_idx_ctx][mbmi->comp_group_idx];
+ masked_type_cost +=
+ x->compound_type_cost[bsize][cur_type - COMPOUND_WEDGE];
rs2 = masked_type_cost;
if (((*rd / cpi->max_comp_type_rd_threshold_div) *
@@ -9846,8 +9860,8 @@
mbmi->mv[1].as_int = cur_mv[1].as_int;
}
if (mbmi->interinter_comp.type != best_compound_data.type) {
- mbmi->comp_group_idx =
- (best_compound_data.type == COMPOUND_AVERAGE) ? 0 : 1;
+ mbmi->comp_group_idx = (best_compound_data.type < COMPOUND_WEDGE) ? 0 : 1;
+ mbmi->compound_idx = !(best_compound_data.type == COMPOUND_DISTWTD);
mbmi->interinter_comp = best_compound_data;
memcpy(xd->seg_mask, buffers->tmp_best_mask_buf, mask_len);
}
@@ -9902,6 +9916,8 @@
int_mv mv;
} inter_mode_info;
+#define SEPARATE_COMP_DISTWTD_RD 1
+
static int64_t handle_inter_mode(
const AV1_COMP *const cpi, TileDataEnc *tile_data, MACROBLOCK *x,
BLOCK_SIZE bsize, RD_STATS *rd_stats, RD_STATS *rd_stats_y,
@@ -9955,10 +9971,14 @@
inter_mode_info mode_info[MAX_REF_MV_SERCH];
int comp_idx;
+#if SEPARATE_COMP_DISTWTD_RD
const int search_dist_wtd_comp =
is_comp_pred & cm->seq_params.order_hint_info.enable_dist_wtd_comp &
(mbmi->mode != GLOBAL_GLOBALMV) &
(cpi->sf.use_dist_wtd_comp_flag != DIST_WTD_COMP_DISABLED);
+#else
+ const int search_dist_wtd_comp = 0;
+#endif // SEPARATE_COMP_DISTWTD_RD
// TODO(jingning): This should be deprecated shortly.
const int has_nearmv = have_nearmv_in_inter_mode(mbmi->mode) ? 1 : 0;
@@ -10022,13 +10042,14 @@
mbmi->compound_idx = comp_idx;
if (is_comp_pred && comp_idx == 0) {
*rd_stats = backup_rd_stats;
- mbmi->interinter_comp.type = COMPOUND_AVERAGE;
+ mbmi->interinter_comp.type = COMPOUND_DISTWTD;
mbmi->num_proj_ref = 0;
mbmi->motion_mode = SIMPLE_TRANSLATION;
mbmi->comp_group_idx = 0;
const int comp_index_ctx = get_comp_index_context(cm, xd);
- compmode_interinter_cost += x->comp_idx_cost[comp_index_ctx][0];
+ compmode_interinter_cost +=
+ x->comp_idx_cost[comp_index_ctx][mbmi->compound_idx];
}
int_mv cur_mv[2];
@@ -10171,10 +10192,11 @@
}
int64_t best_rd_compound;
+ const int do_comp_distwtd = !SEPARATE_COMP_DISTWTD_RD;
compmode_interinter_cost = compound_type_rd(
- cpi, x, bsize, mi_col, mi_row, cur_mv, masked_compound_used,
- &orig_dst, &tmp_dst, rd_buffers, &rate_mv, &best_rd_compound,
- rd_stats, ref_best_rd, &is_luma_interp_done);
+ cpi, x, bsize, mi_col, mi_row, cur_mv, do_comp_distwtd,
+ masked_compound_used, &orig_dst, &tmp_dst, rd_buffers, &rate_mv,
+ &best_rd_compound, rd_stats, ref_best_rd, &is_luma_interp_done);
if (ref_best_rd < INT64_MAX &&
(best_rd_compound >> 4) * 13 > ref_best_rd) {
restore_dst_buf(xd, orig_dst, num_planes);
@@ -10184,6 +10206,7 @@
// COMPOUND_AVERAGE is selected because it is the first
// candidate in compound_type_rd, and the following
// compound types searching uses tmp_dst buffer
+
if (mbmi->interinter_comp.type == COMPOUND_AVERAGE &&
is_luma_interp_done) {
if (num_planes > 1) {