JNT_COMP: Simplify logic on inter-inter comp modes
This patch simplies the checking criteria for the two groups of
compound modes. It also makes the encoder side cdf update inside the
RD loop consistent with that in the bitstream.
Experimental results on Google test sets (30 frames of lowres and
midres) confirm this patch obtains identical coding performance.
Change-Id: I170eea91f7d2be2170df544cfc2c692b09aa82d6
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index 51ee501..60932d6 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -2196,23 +2196,21 @@
#if CONFIG_JNT_COMP
// init
- mbmi->comp_group_idx = 1;
+ mbmi->comp_group_idx = 0;
mbmi->compound_idx = 1;
mbmi->interinter_compound_type = COMPOUND_AVERAGE;
- // read idx to indicate current compound inter prediction mode group
- int masked_compound_used = is_any_masked_compound_used(bsize);
- masked_compound_used = masked_compound_used && cm->allow_masked_compound;
-
if (has_second_ref(mbmi)) {
+ // Read idx to indicate current compound inter prediction mode group
+ const int masked_compound_used =
+ is_any_masked_compound_used(bsize) && cm->allow_masked_compound;
+
if (masked_compound_used) {
const int ctx_comp_group_idx = get_comp_group_idx_context(xd);
mbmi->comp_group_idx = aom_read_symbol(
r, ec_ctx->comp_group_idx_cdf[ctx_comp_group_idx], 2, ACCT_STR);
if (xd->counts)
++xd->counts->comp_group_idx[ctx_comp_group_idx][mbmi->comp_group_idx];
- } else {
- mbmi->comp_group_idx = 0;
}
if (mbmi->comp_group_idx == 0) {
@@ -2222,44 +2220,33 @@
if (xd->counts)
++xd->counts->compound_index[comp_index_ctx][mbmi->compound_idx];
-
- if (mbmi->compound_idx) mbmi->interinter_compound_type = COMPOUND_AVERAGE;
} else {
+ assert(cm->reference_mode != SINGLE_REFERENCE &&
+ is_inter_compound_mode(mbmi->mode) &&
+ mbmi->motion_mode == SIMPLE_TRANSLATION);
+ assert(masked_compound_used);
+
// compound_segment, wedge
- mbmi->interinter_compound_type = COMPOUND_AVERAGE;
- if (cm->reference_mode != SINGLE_REFERENCE &&
- is_inter_compound_mode(mbmi->mode) &&
- mbmi->motion_mode == SIMPLE_TRANSLATION
-#if CONFIG_EXT_SKIP
- && !mbmi->skip_mode
-#endif // CONFIG_EXT_SKIP
- ) {
- if (is_any_masked_compound_used(bsize)) {
- if (cm->allow_masked_compound) {
- if (is_interinter_compound_used(COMPOUND_WEDGE, bsize))
- mbmi->interinter_compound_type =
- 1 + aom_read_symbol(r, ec_ctx->compound_type_cdf[bsize],
- COMPOUND_TYPES - 1, ACCT_STR);
- else
- mbmi->interinter_compound_type = COMPOUND_SEG;
+ if (is_interinter_compound_used(COMPOUND_WEDGE, bsize))
+ mbmi->interinter_compound_type =
+ 1 + aom_read_symbol(r, ec_ctx->compound_type_cdf[bsize],
+ COMPOUND_TYPES - 1, ACCT_STR);
+ else
+ mbmi->interinter_compound_type = COMPOUND_SEG;
- if (mbmi->interinter_compound_type == COMPOUND_WEDGE) {
- assert(is_interinter_compound_used(COMPOUND_WEDGE, bsize));
- mbmi->wedge_index =
- aom_read_literal(r, get_wedge_bits_lookup(bsize), ACCT_STR);
- mbmi->wedge_sign = aom_read_bit(r, ACCT_STR);
- }
- if (mbmi->interinter_compound_type == COMPOUND_SEG) {
- mbmi->mask_type =
- aom_read_literal(r, MAX_SEG_MASK_BITS, ACCT_STR);
- }
- }
- }
-
- if (xd->counts)
- xd->counts->compound_interinter[bsize]
- [mbmi->interinter_compound_type - 1]++;
+ if (mbmi->interinter_compound_type == COMPOUND_WEDGE) {
+ assert(is_interinter_compound_used(COMPOUND_WEDGE, bsize));
+ mbmi->wedge_index =
+ aom_read_literal(r, get_wedge_bits_lookup(bsize), ACCT_STR);
+ mbmi->wedge_sign = aom_read_bit(r, ACCT_STR);
+ } else {
+ assert(mbmi->interinter_compound_type == COMPOUND_SEG);
+ mbmi->mask_type = aom_read_literal(r, MAX_SEG_MASK_BITS, ACCT_STR);
}
+
+ if (xd->counts)
+ xd->counts
+ ->compound_interinter[bsize][mbmi->interinter_compound_type - 1]++;
}
}
#else // CONFIG_JNT_COMP
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 26bc975..607357d 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1460,16 +1460,16 @@
// First write idx to indicate current compound inter prediction mode group
// Group A (0): jnt_comp, compound_average
// Group B (1): interintra, compound_segment, wedge
- int masked_compound_used = is_any_masked_compound_used(bsize);
- masked_compound_used = masked_compound_used && cm->allow_masked_compound;
-
if (has_second_ref(mbmi)) {
- if (masked_compound_used) {
- assert(mbmi->comp_group_idx == 0);
+ const int masked_compound_used =
+ is_any_masked_compound_used(bsize) && cm->allow_masked_compound;
+ if (masked_compound_used) {
const int ctx_comp_group_idx = get_comp_group_idx_context(xd);
aom_write_symbol(w, mbmi->comp_group_idx,
ec_ctx->comp_group_idx_cdf[ctx_comp_group_idx], 2);
+ } else {
+ assert(mbmi->comp_group_idx == 0);
}
if (mbmi->comp_group_idx == 0) {
@@ -1480,27 +1480,26 @@
aom_write_symbol(w, mbmi->compound_idx,
ec_ctx->compound_index_cdf[comp_index_ctx], 2);
} else {
+ assert(cpi->common.reference_mode != SINGLE_REFERENCE &&
+ is_inter_compound_mode(mbmi->mode) &&
+ mbmi->motion_mode == SIMPLE_TRANSLATION);
+ assert(masked_compound_used);
+ // compound_segment, wedge
assert(mbmi->interinter_compound_type == COMPOUND_WEDGE ||
mbmi->interinter_compound_type == COMPOUND_SEG);
- // compound_segment, wedge
- if (cpi->common.reference_mode != SINGLE_REFERENCE &&
- is_inter_compound_mode(mbmi->mode) &&
- mbmi->motion_mode == SIMPLE_TRANSLATION &&
- is_any_masked_compound_used(bsize) && cm->allow_masked_compound) {
- if (is_interinter_compound_used(COMPOUND_WEDGE, bsize))
- aom_write_symbol(w, mbmi->interinter_compound_type - 1,
- ec_ctx->compound_type_cdf[bsize],
- COMPOUND_TYPES - 1);
- if (is_interinter_compound_used(COMPOUND_WEDGE, bsize) &&
- mbmi->interinter_compound_type == COMPOUND_WEDGE) {
- aom_write_literal(w, mbmi->wedge_index,
- get_wedge_bits_lookup(bsize));
- aom_write_bit(w, mbmi->wedge_sign);
- }
- if (mbmi->interinter_compound_type == COMPOUND_SEG) {
- aom_write_literal(w, mbmi->mask_type, MAX_SEG_MASK_BITS);
- }
+ if (is_interinter_compound_used(COMPOUND_WEDGE, bsize))
+ aom_write_symbol(w, mbmi->interinter_compound_type - 1,
+ ec_ctx->compound_type_cdf[bsize],
+ COMPOUND_TYPES - 1);
+
+ if (mbmi->interinter_compound_type == COMPOUND_WEDGE) {
+ assert(is_interinter_compound_used(COMPOUND_WEDGE, bsize));
+ aom_write_literal(w, mbmi->wedge_index, get_wedge_bits_lookup(bsize));
+ aom_write_bit(w, mbmi->wedge_sign);
+ } else {
+ assert(mbmi->interinter_compound_type == COMPOUND_SEG);
+ aom_write_literal(w, mbmi->mask_type, MAX_SEG_MASK_BITS);
}
}
}
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index f3575b1..17715b3 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -1125,7 +1125,7 @@
else
// This flag is also updated for 4x4 blocks
rdc->single_ref_used_flag = 1;
- if (is_comp_ref_allowed(mbmi->sb_type)) {
+ if (is_comp_ref_allowed(bsize)) {
counts->comp_inter[av1_get_reference_mode_context(cm, xd)]
[has_second_ref(mbmi)]++;
if (allow_update_cdf)
@@ -1237,31 +1237,39 @@
if (mbmi->mode == NEARESTMV) wm_ctx = 2;
}
- counts->motion_mode[wm_ctx][mbmi->sb_type][mbmi->motion_mode]++;
+ counts->motion_mode[wm_ctx][bsize][mbmi->motion_mode]++;
if (allow_update_cdf)
- update_cdf(fc->motion_mode_cdf[wm_ctx][mbmi->sb_type],
- mbmi->motion_mode, MOTION_MODES);
+ update_cdf(fc->motion_mode_cdf[wm_ctx][bsize], mbmi->motion_mode,
+ MOTION_MODES);
#else
- counts->motion_mode[mbmi->sb_type][mbmi->motion_mode]++;
+ counts->motion_mode[bsize][mbmi->motion_mode]++;
if (allow_update_cdf)
- update_cdf(fc->motion_mode_cdf[mbmi->sb_type], mbmi->motion_mode,
+ update_cdf(fc->motion_mode_cdf[bsize], mbmi->motion_mode,
MOTION_MODES);
#endif // CONFIG_EXT_WARPED_MOTION
} else if (motion_allowed == OBMC_CAUSAL) {
- counts->obmc[mbmi->sb_type][mbmi->motion_mode == OBMC_CAUSAL]++;
+ counts->obmc[bsize][mbmi->motion_mode == OBMC_CAUSAL]++;
if (allow_update_cdf)
- update_cdf(fc->obmc_cdf[mbmi->sb_type],
- mbmi->motion_mode == OBMC_CAUSAL, 2);
+ update_cdf(fc->obmc_cdf[bsize], mbmi->motion_mode == OBMC_CAUSAL,
+ 2);
}
}
#if CONFIG_JNT_COMP
if (has_second_ref(mbmi)) {
- const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
- ++counts->comp_group_idx[comp_group_idx_ctx][mbmi->comp_group_idx];
- if (allow_update_cdf)
- update_cdf(fc->comp_group_idx_cdf[comp_group_idx_ctx],
- mbmi->comp_group_idx, 2);
+ assert(cm->reference_mode != SINGLE_REFERENCE &&
+ is_inter_compound_mode(mbmi->mode) &&
+ mbmi->motion_mode == SIMPLE_TRANSLATION);
+
+ const int masked_compound_used =
+ is_any_masked_compound_used(bsize) && cm->allow_masked_compound;
+ if (masked_compound_used) {
+ const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
+ ++counts->comp_group_idx[comp_group_idx_ctx][mbmi->comp_group_idx];
+ if (allow_update_cdf)
+ update_cdf(fc->comp_group_idx_cdf[comp_group_idx_ctx],
+ mbmi->comp_group_idx, 2);
+ }
if (mbmi->comp_group_idx == 0) {
const int comp_index_ctx = get_comp_index_context(cm, xd);
@@ -1269,23 +1277,19 @@
if (allow_update_cdf)
update_cdf(fc->compound_index_cdf[comp_index_ctx],
mbmi->compound_idx, 2);
- }
- }
-
- if (cm->reference_mode != SINGLE_REFERENCE &&
- is_inter_compound_mode(mbmi->mode) && mbmi->comp_group_idx &&
- mbmi->motion_mode == SIMPLE_TRANSLATION) {
- if (is_interinter_compound_used(COMPOUND_WEDGE, bsize)) {
- counts->compound_interinter[bsize]
- [mbmi->interinter_compound_type - 1]++;
- if (allow_update_cdf)
- update_cdf(fc->compound_type_cdf[bsize],
- mbmi->interinter_compound_type - 1,
- COMPOUND_TYPES - 1);
+ } else {
+ assert(masked_compound_used);
+ if (is_interinter_compound_used(COMPOUND_WEDGE, bsize)) {
+ counts->compound_interinter[bsize]
+ [mbmi->interinter_compound_type - 1]++;
+ if (allow_update_cdf)
+ update_cdf(fc->compound_type_cdf[bsize],
+ mbmi->interinter_compound_type - 1,
+ COMPOUND_TYPES - 1);
+ }
}
}
#else // CONFIG_JNT_COMP
-
if (cm->reference_mode != SINGLE_REFERENCE &&
is_inter_compound_mode(mbmi->mode) &&
mbmi->motion_mode == SIMPLE_TRANSLATION) {