JNT_COMP: divide compound modes into two groups Divide compound inter prediction modes into two groups: Group A: jnt_comp, compound_average Group B: interintra, compound_segment, wedge Change-Id: I1142da2e3dfadf382d6b8183a87bde95119cf1b7
diff --git a/av1/common/blockd.h b/av1/common/blockd.h index eb434eb..5e5fa29 100644 --- a/av1/common/blockd.h +++ b/av1/common/blockd.h
@@ -367,6 +367,7 @@ #if CONFIG_JNT_COMP int compound_idx; + int comp_group_idx; #endif } MB_MODE_INFO;
diff --git a/av1/common/entropymode.c b/av1/common/entropymode.c index 49a4501..c0368d4 100644 --- a/av1/common/entropymode.c +++ b/av1/common/entropymode.c
@@ -1473,9 +1473,20 @@ { AOM_ICDF(16384), AOM_ICDF(32768), 0 }, { AOM_ICDF(8192), AOM_ICDF(32768), 0 }, }; + +static const aom_cdf_prob + default_comp_group_idx_cdfs[COMP_GROUP_IDX_CONTEXTS][CDF_SIZE(2)] = { + { AOM_ICDF(24576), AOM_ICDF(32768), 0 }, + { AOM_ICDF(16384), AOM_ICDF(32768), 0 }, + { AOM_ICDF(8192), AOM_ICDF(32768), 0 }, + }; static const aom_prob default_compound_idx_probs[COMP_INDEX_CONTEXTS] = { 192, 128, 64, 192, 128, 64 }; + +static const aom_prob default_comp_group_idx_probs[COMP_GROUP_IDX_CONTEXTS] = { + 192, 128, 64 +}; #endif // CONFIG_JNT_COMP #if CONFIG_FILTER_INTRA @@ -2929,7 +2940,9 @@ av1_copy(fc->txfm_partition_cdf, default_txfm_partition_cdf); #if CONFIG_JNT_COMP av1_copy(fc->compound_index_cdf, default_compound_idx_cdfs); + av1_copy(fc->comp_group_idx_cdf, default_comp_group_idx_cdfs); av1_copy(fc->compound_index_probs, default_compound_idx_probs); + av1_copy(fc->comp_group_idx_probs, default_comp_group_idx_probs); #endif // CONFIG_JNT_COMP av1_copy(fc->newmv_cdf, default_newmv_cdf); av1_copy(fc->zeromv_cdf, default_zeromv_cdf); @@ -3044,6 +3057,9 @@ for (i = 0; i < COMP_INDEX_CONTEXTS; ++i) fc->compound_index_probs[i] = av1_mode_mv_merge_probs( pre_fc->compound_index_probs[i], counts->compound_index[i]); + for (i = 0; i < COMP_GROUP_IDX_CONTEXTS; ++i) + fc->comp_group_idx_probs[i] = av1_mode_mv_merge_probs( + pre_fc->comp_group_idx_probs[i], counts->comp_group_idx[i]); #endif // CONFIG_JNT_COMP }
diff --git a/av1/common/entropymode.h b/av1/common/entropymode.h index ec17d1c..adf461d 100644 --- a/av1/common/entropymode.h +++ b/av1/common/entropymode.h
@@ -245,7 +245,9 @@ aom_cdf_prob txfm_partition_cdf[TXFM_PARTITION_CONTEXTS][CDF_SIZE(2)]; #if CONFIG_JNT_COMP aom_cdf_prob compound_index_cdf[COMP_INDEX_CONTEXTS][CDF_SIZE(2)]; + aom_cdf_prob comp_group_idx_cdf[COMP_GROUP_IDX_CONTEXTS][CDF_SIZE(2)]; aom_prob compound_index_probs[COMP_INDEX_CONTEXTS]; + aom_prob comp_group_idx_probs[COMP_GROUP_IDX_CONTEXTS]; #endif // CONFIG_JNT_COMP #if CONFIG_EXT_SKIP aom_cdf_prob skip_mode_cdfs[SKIP_CONTEXTS][CDF_SIZE(2)]; @@ -395,6 +397,7 @@ unsigned int skip[SKIP_CONTEXTS][2]; #if CONFIG_JNT_COMP unsigned int compound_index[COMP_INDEX_CONTEXTS][2]; + unsigned int comp_group_idx[COMP_GROUP_IDX_CONTEXTS][2]; #endif // CONFIG_JNT_COMP unsigned int delta_q[DELTA_Q_PROBS][2]; #if CONFIG_EXT_DELTA_Q
diff --git a/av1/common/enums.h b/av1/common/enums.h index ddcb342..667cc9f 100644 --- a/av1/common/enums.h +++ b/av1/common/enums.h
@@ -592,6 +592,7 @@ #if CONFIG_JNT_COMP #define COMP_INDEX_CONTEXTS 6 +#define COMP_GROUP_IDX_CONTEXTS 3 #endif // CONFIG_JNT_COMP #define NMV_CONTEXTS 3
diff --git a/av1/common/pred_common.h b/av1/common/pred_common.h index 5eaf1fa..8047769 100644 --- a/av1/common/pred_common.h +++ b/av1/common/pred_common.h
@@ -122,6 +122,29 @@ return above_ctx + left_ctx + 3 * offset; } + +static INLINE int get_comp_group_idx_context(const MACROBLOCKD *xd) { + const MODE_INFO *const above_mi = xd->above_mi; + const MODE_INFO *const left_mi = xd->left_mi; + int above_ctx = 0, left_ctx = 0; + + if (above_mi) { + const MB_MODE_INFO *above_mbmi = &above_mi->mbmi; + if (has_second_ref(above_mbmi)) + above_ctx = above_mbmi->comp_group_idx; + else if (above_mbmi->ref_frame[0] == ALTREF_FRAME) + above_ctx = 1; + } + if (left_mi) { + const MB_MODE_INFO *left_mbmi = &left_mi->mbmi; + if (has_second_ref(left_mbmi)) + left_ctx = left_mbmi->comp_group_idx; + else if (left_mbmi->ref_frame[0] == ALTREF_FRAME) + left_ctx = 1; + } + + return above_ctx + left_ctx; +} #endif // CONFIG_JNT_COMP static INLINE aom_cdf_prob *av1_get_pred_cdf_seg_id(
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c index 577640c..5917301 100644 --- a/av1/decoder/decodemv.c +++ b/av1/decoder/decodemv.c
@@ -1856,15 +1856,9 @@ read_ref_frames(cm, xd, r, mbmi->segment_id, mbmi->ref_frame); const int is_compound = has_second_ref(mbmi); -#if CONFIG_JNT_COMP - if (is_compound) { - const int comp_index_ctx = get_comp_index_context(cm, xd); - mbmi->compound_idx = aom_read_symbol( - r, ec_ctx->compound_index_cdf[comp_index_ctx], 2, ACCT_STR); - if (xd->counts) - ++xd->counts->compound_index[comp_index_ctx][mbmi->compound_idx]; - } -#endif // CONFIG_JNT_COMP +#if CONFIG_EXT_SKIP +// TODO(zoeliu): To work with JNT_COMP +#endif // CONFIG_EXT_SKIP for (int ref = 0; ref < 1 + is_compound; ++ref) { MV_REFERENCE_FRAME frame = mbmi->ref_frame[ref]; @@ -2188,6 +2182,68 @@ mbmi->motion_mode = read_motion_mode(cm, xd, mi, r); #endif // CONFIG_EXT_WARPED_MOTION +#if CONFIG_JNT_COMP + // init + mbmi->comp_group_idx = 1; + mbmi->compound_idx = 1; + mbmi->interinter_compound_type = COMPOUND_AVERAGE; + + // read idx to indicate current compound inter prediction mode group + if (has_second_ref(mbmi)) { + const int ctx_comp_group_idx = get_comp_group_idx_context(xd); + mbmi->comp_group_idx = aom_read_symbol( + r, ec_ctx->comp_group_idx_cdf[ctx_comp_group_idx], 2, ACCT_STR); + if (xd->counts) + ++xd->counts->comp_group_idx[ctx_comp_group_idx][mbmi->comp_group_idx]; + + if (mbmi->comp_group_idx == 0) { + const int comp_index_ctx = get_comp_index_context(cm, xd); + mbmi->compound_idx = aom_read_symbol( + r, ec_ctx->compound_index_cdf[comp_index_ctx], 2, ACCT_STR); + + if (xd->counts) + ++xd->counts->compound_index[comp_index_ctx][mbmi->compound_idx]; + + if (mbmi->compound_idx) mbmi->interinter_compound_type = COMPOUND_AVERAGE; + } else { + // compound_segment, wedge + mbmi->interinter_compound_type = COMPOUND_AVERAGE; + if (cm->reference_mode != SINGLE_REFERENCE && + is_inter_compound_mode(mbmi->mode) && + mbmi->motion_mode == SIMPLE_TRANSLATION +#if CONFIG_EXT_SKIP + && !mbmi->skip_mode +#endif // CONFIG_EXT_SKIP + && mbmi->comp_group_idx) { + if (is_any_masked_compound_used(bsize)) { + if (cm->allow_masked_compound) { + if (is_interinter_compound_used(COMPOUND_WEDGE, bsize)) + mbmi->interinter_compound_type = + aom_read_symbol(r, ec_ctx->compound_type_cdf[bsize], + COMPOUND_TYPES, ACCT_STR); + else + mbmi->interinter_compound_type = COMPOUND_SEG; + + if (mbmi->interinter_compound_type == COMPOUND_WEDGE) { + assert(is_interinter_compound_used(COMPOUND_WEDGE, bsize)); + mbmi->wedge_index = + aom_read_literal(r, get_wedge_bits_lookup(bsize), ACCT_STR); + mbmi->wedge_sign = aom_read_bit(r, ACCT_STR); + } + if (mbmi->interinter_compound_type == COMPOUND_SEG) { + mbmi->mask_type = + aom_read_literal(r, MAX_SEG_MASK_BITS, ACCT_STR); + } + } + } + + if (xd->counts) + xd->counts + ->compound_interinter[bsize][mbmi->interinter_compound_type]++; + } + } + } +#else // CONFIG_JNT_COMP mbmi->interinter_compound_type = COMPOUND_AVERAGE; if (cm->reference_mode != SINGLE_REFERENCE && is_inter_compound_mode(mbmi->mode) && @@ -2225,6 +2281,7 @@ if (xd->counts) xd->counts->compound_interinter[bsize][mbmi->interinter_compound_type]++; } +#endif // CONFIG_JNT_COMP read_mb_interp_filter(cm, xd, mbmi, r);
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c index cb4d9b4..4f05d99 100644 --- a/av1/encoder/bitstream.c +++ b/av1/encoder/bitstream.c
@@ -1372,14 +1372,6 @@ int16_t mode_ctx; write_ref_frames(cm, xd, w); -#if CONFIG_JNT_COMP - if (has_second_ref(mbmi)) { - const int comp_index_ctx = get_comp_index_context(cm, xd); - aom_write_symbol(w, mbmi->compound_idx, - ec_ctx->compound_index_cdf[comp_index_ctx], 2); - } -#endif // CONFIG_JNT_COMP - if (is_compound) mode_ctx = mbmi_ext->compound_mode_context[mbmi->ref_frame[0]]; else @@ -1454,16 +1446,53 @@ if (mbmi->ref_frame[1] != INTRA_FRAME) write_motion_mode(cm, xd, mi, w); +#if CONFIG_JNT_COMP + // First write idx to indicate current compound inter prediction mode group + // Group A (0): jnt_comp, compound_average + // Group B (1): interintra, compound_segment, wedge + if (has_second_ref(mbmi)) { + const int ctx_comp_group_idx = get_comp_group_idx_context(xd); + aom_write_symbol(w, mbmi->comp_group_idx, + ec_ctx->comp_group_idx_cdf[ctx_comp_group_idx], 2); + + if (mbmi->comp_group_idx == 0) { + if (mbmi->compound_idx) + assert(mbmi->interinter_compound_type == COMPOUND_AVERAGE); + + const int comp_index_ctx = get_comp_index_context(cm, xd); + aom_write_symbol(w, mbmi->compound_idx, + ec_ctx->compound_index_cdf[comp_index_ctx], 2); + } else { + assert(mbmi->interinter_compound_type == COMPOUND_WEDGE || + mbmi->interinter_compound_type == COMPOUND_SEG); + // compound_segment, wedge + if (cpi->common.reference_mode != SINGLE_REFERENCE && + is_inter_compound_mode(mbmi->mode) && + mbmi->motion_mode == SIMPLE_TRANSLATION && + is_any_masked_compound_used(bsize) && cm->allow_masked_compound && + mbmi->comp_group_idx) { + if (is_interinter_compound_used(COMPOUND_WEDGE, bsize)) + aom_write_symbol(w, mbmi->interinter_compound_type, + ec_ctx->compound_type_cdf[bsize], COMPOUND_TYPES); + + if (is_interinter_compound_used(COMPOUND_WEDGE, bsize) && + mbmi->interinter_compound_type == COMPOUND_WEDGE) { + aom_write_literal(w, mbmi->wedge_index, + get_wedge_bits_lookup(bsize)); + aom_write_bit(w, mbmi->wedge_sign); + } + if (mbmi->interinter_compound_type == COMPOUND_SEG) { + aom_write_literal(w, mbmi->mask_type, MAX_SEG_MASK_BITS); + } + } + } + } +#else // CONFIG_JNT_COMP if (cpi->common.reference_mode != SINGLE_REFERENCE && is_inter_compound_mode(mbmi->mode) && mbmi->motion_mode == SIMPLE_TRANSLATION && is_any_masked_compound_used(bsize)) { -#if CONFIG_JNT_COMP - if (cm->allow_masked_compound && mbmi->compound_idx) -#else - if (cm->allow_masked_compound) -#endif // CONFIG_JNT_COMP - { + if (cm->allow_masked_compound) { if (!is_interinter_compound_used(COMPOUND_WEDGE, bsize)) aom_write_bit(w, mbmi->interinter_compound_type == COMPOUND_AVERAGE); else @@ -1479,6 +1508,7 @@ } } } +#endif // CONFIG_JNT_COMP write_mb_interp_filter(cpi, xd, w); }
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c index a22217f..be77347 100644 --- a/av1/encoder/encodeframe.c +++ b/av1/encoder/encodeframe.c
@@ -641,9 +641,6 @@ const int x_mis = AOMMIN(bw, cm->mi_cols - mi_col); const int y_mis = AOMMIN(bh, cm->mi_rows - mi_row); av1_copy_frame_mvs(cm, mi, mi_row, mi_col, x_mis, y_mis); - -#if CONFIG_JNT_COMP -#endif // CONFIG_JNT_COMP } #if NC_MODE_INFO @@ -1258,9 +1255,30 @@ } } +#if CONFIG_JNT_COMP + if (has_second_ref(mbmi)) { + const int comp_group_idx_ctx = get_comp_group_idx_context(xd); + ++counts->comp_group_idx[comp_group_idx_ctx][mbmi->comp_group_idx]; + if (allow_update_cdf) + update_cdf(fc->comp_group_idx_cdf[comp_group_idx_ctx], + mbmi->comp_group_idx, 2); + + if (mbmi->comp_group_idx == 0) { + const int comp_index_ctx = get_comp_index_context(cm, xd); + ++counts->compound_index[comp_index_ctx][mbmi->compound_idx]; + if (allow_update_cdf) + update_cdf(fc->compound_index_cdf[comp_index_ctx], + mbmi->compound_idx, 2); + } + } +#endif // CONFIG_JNT_COMP + if (cm->reference_mode != SINGLE_REFERENCE && - is_inter_compound_mode(mbmi->mode) && - mbmi->motion_mode == SIMPLE_TRANSLATION) { + is_inter_compound_mode(mbmi->mode) +#if CONFIG_JNT_COMP + && mbmi->comp_group_idx +#endif // CONFIG_JNT_COMP + && mbmi->motion_mode == SIMPLE_TRANSLATION) { if (is_interinter_compound_used(COMPOUND_WEDGE, bsize)) { counts ->compound_interinter[bsize][mbmi->interinter_compound_type]++; @@ -1319,16 +1337,6 @@ } } } - -#if CONFIG_JNT_COMP - if (has_second_ref(mbmi)) { - const int comp_index_ctx = get_comp_index_context(cm, xd); - ++counts->compound_index[comp_index_ctx][mbmi->compound_idx]; - if (allow_update_cdf) - update_cdf(fc->compound_index_cdf[comp_index_ctx], mbmi->compound_idx, - 2); - } -#endif // CONFIG_JNT_COMP } } } @@ -1484,6 +1492,15 @@ mbmi->current_delta_lf_from_base = xd->prev_delta_lf_from_base; } #endif +#if CONFIG_JNT_COMP + if (has_second_ref(mbmi)) { + if (mbmi->compound_idx == 0 || + mbmi->interinter_compound_type == COMPOUND_AVERAGE) + mbmi->comp_group_idx = 0; + else + mbmi->comp_group_idx = 1; + } +#endif update_stats(&cpi->common, tile_data, td, mi_row, mi_col); } }