JNT_COMP: divide compound modes into two groups
Divide compound inter prediction modes into two groups, signaled per block
by the new comp_group_idx symbol:
Group A (comp_group_idx = 0): jnt_comp, compound_average
Group B (comp_group_idx = 1): interintra, compound_segment, wedge
Change-Id: I1142da2e3dfadf382d6b8183a87bde95119cf1b7
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index eb434eb..5e5fa29 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -367,6 +367,7 @@
#if CONFIG_JNT_COMP
int compound_idx;
+ int comp_group_idx;  // 0: jnt_comp / compound_average; 1: wedge / compound_segment
#endif
} MB_MODE_INFO;
diff --git a/av1/common/entropymode.c b/av1/common/entropymode.c
index 49a4501..c0368d4 100644
--- a/av1/common/entropymode.c
+++ b/av1/common/entropymode.c
@@ -1473,9 +1473,20 @@
{ AOM_ICDF(16384), AOM_ICDF(32768), 0 },
{ AOM_ICDF(8192), AOM_ICDF(32768), 0 },
};
+// Default CDFs for the binary comp_group_idx symbol, one row per context.
+static const aom_cdf_prob
+ default_comp_group_idx_cdfs[COMP_GROUP_IDX_CONTEXTS][CDF_SIZE(2)] = {
+ { AOM_ICDF(24576), AOM_ICDF(32768), 0 },
+ { AOM_ICDF(16384), AOM_ICDF(32768), 0 },
+ { AOM_ICDF(8192), AOM_ICDF(32768), 0 },
+ };
static const aom_prob default_compound_idx_probs[COMP_INDEX_CONTEXTS] = {
192, 128, 64, 192, 128, 64
};
+// Default comp_group_idx probabilities; same split as the CDFs (192/256 == 24576/32768).
+static const aom_prob default_comp_group_idx_probs[COMP_GROUP_IDX_CONTEXTS] = {
+ 192, 128, 64
+};
#endif // CONFIG_JNT_COMP
#if CONFIG_FILTER_INTRA
@@ -2929,7 +2940,9 @@
av1_copy(fc->txfm_partition_cdf, default_txfm_partition_cdf);
#if CONFIG_JNT_COMP
av1_copy(fc->compound_index_cdf, default_compound_idx_cdfs);
+ av1_copy(fc->comp_group_idx_cdf, default_comp_group_idx_cdfs);
av1_copy(fc->compound_index_probs, default_compound_idx_probs);
+ av1_copy(fc->comp_group_idx_probs, default_comp_group_idx_probs);
#endif // CONFIG_JNT_COMP
av1_copy(fc->newmv_cdf, default_newmv_cdf);
av1_copy(fc->zeromv_cdf, default_zeromv_cdf);
@@ -3044,6 +3057,9 @@
for (i = 0; i < COMP_INDEX_CONTEXTS; ++i)
fc->compound_index_probs[i] = av1_mode_mv_merge_probs(
pre_fc->compound_index_probs[i], counts->compound_index[i]);
+ for (i = 0; i < COMP_GROUP_IDX_CONTEXTS; ++i)
+ fc->comp_group_idx_probs[i] = av1_mode_mv_merge_probs(
+ pre_fc->comp_group_idx_probs[i], counts->comp_group_idx[i]);
#endif // CONFIG_JNT_COMP
}
diff --git a/av1/common/entropymode.h b/av1/common/entropymode.h
index ec17d1c..adf461d 100644
--- a/av1/common/entropymode.h
+++ b/av1/common/entropymode.h
@@ -245,7 +245,9 @@
aom_cdf_prob txfm_partition_cdf[TXFM_PARTITION_CONTEXTS][CDF_SIZE(2)];
#if CONFIG_JNT_COMP
aom_cdf_prob compound_index_cdf[COMP_INDEX_CONTEXTS][CDF_SIZE(2)];
+ aom_cdf_prob comp_group_idx_cdf[COMP_GROUP_IDX_CONTEXTS][CDF_SIZE(2)];  // binary
aom_prob compound_index_probs[COMP_INDEX_CONTEXTS];
+ aom_prob comp_group_idx_probs[COMP_GROUP_IDX_CONTEXTS];  // adapted from counts
#endif // CONFIG_JNT_COMP
#if CONFIG_EXT_SKIP
aom_cdf_prob skip_mode_cdfs[SKIP_CONTEXTS][CDF_SIZE(2)];
@@ -395,6 +397,7 @@
unsigned int skip[SKIP_CONTEXTS][2];
#if CONFIG_JNT_COMP
unsigned int compound_index[COMP_INDEX_CONTEXTS][2];
+ unsigned int comp_group_idx[COMP_GROUP_IDX_CONTEXTS][2];  // [context][symbol]
#endif // CONFIG_JNT_COMP
unsigned int delta_q[DELTA_Q_PROBS][2];
#if CONFIG_EXT_DELTA_Q
diff --git a/av1/common/enums.h b/av1/common/enums.h
index ddcb342..667cc9f 100644
--- a/av1/common/enums.h
+++ b/av1/common/enums.h
@@ -592,6 +592,7 @@
#if CONFIG_JNT_COMP
#define COMP_INDEX_CONTEXTS 6
+#define COMP_GROUP_IDX_CONTEXTS 3  // above + left contributions, each 0 or 1
#endif // CONFIG_JNT_COMP
#define NMV_CONTEXTS 3
diff --git a/av1/common/pred_common.h b/av1/common/pred_common.h
index 5eaf1fa..8047769 100644
--- a/av1/common/pred_common.h
+++ b/av1/common/pred_common.h
@@ -122,6 +122,29 @@
return above_ctx + left_ctx + 3 * offset;
}
+// Context for comp_group_idx: sum of the above and left neighbor contributions.
+static INLINE int get_comp_group_idx_context(const MACROBLOCKD *xd) {
+ const MODE_INFO *const above_mi = xd->above_mi;
+ const MODE_INFO *const left_mi = xd->left_mi;
+ int above_ctx = 0, left_ctx = 0;
+ // Compound neighbor -> its comp_group_idx; single-ref ALTREF neighbor -> 1.
+ if (above_mi) {
+ const MB_MODE_INFO *above_mbmi = &above_mi->mbmi;
+ if (has_second_ref(above_mbmi))
+ above_ctx = above_mbmi->comp_group_idx;
+ else if (above_mbmi->ref_frame[0] == ALTREF_FRAME)
+ above_ctx = 1;
+ }
+ if (left_mi) {
+ const MB_MODE_INFO *left_mbmi = &left_mi->mbmi;
+ if (has_second_ref(left_mbmi))
+ left_ctx = left_mbmi->comp_group_idx;
+ else if (left_mbmi->ref_frame[0] == ALTREF_FRAME)
+ left_ctx = 1;
+ }
+ // Unavailable or other neighbors contribute 0, so the result is in [0, 2].
+ return above_ctx + left_ctx;
+}
#endif // CONFIG_JNT_COMP
static INLINE aom_cdf_prob *av1_get_pred_cdf_seg_id(
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index 577640c..5917301 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -1856,15 +1856,9 @@
read_ref_frames(cm, xd, r, mbmi->segment_id, mbmi->ref_frame);
const int is_compound = has_second_ref(mbmi);
-#if CONFIG_JNT_COMP
- if (is_compound) {
- const int comp_index_ctx = get_comp_index_context(cm, xd);
- mbmi->compound_idx = aom_read_symbol(
- r, ec_ctx->compound_index_cdf[comp_index_ctx], 2, ACCT_STR);
- if (xd->counts)
- ++xd->counts->compound_index[comp_index_ctx][mbmi->compound_idx];
- }
-#endif // CONFIG_JNT_COMP
+#if CONFIG_EXT_SKIP
+// TODO(zoeliu): Make CONFIG_EXT_SKIP work together with JNT_COMP.
+#endif // CONFIG_EXT_SKIP
for (int ref = 0; ref < 1 + is_compound; ++ref) {
MV_REFERENCE_FRAME frame = mbmi->ref_frame[ref];
@@ -2188,6 +2182,68 @@
mbmi->motion_mode = read_motion_mode(cm, xd, mi, r);
#endif // CONFIG_EXT_WARPED_MOTION
+#if CONFIG_JNT_COMP
+ // Defaults; single-reference blocks keep these since nothing is read for them.
+ mbmi->comp_group_idx = 1;
+ mbmi->compound_idx = 1;
+ mbmi->interinter_compound_type = COMPOUND_AVERAGE;
+
+ // read idx to indicate current compound inter prediction mode group
+ if (has_second_ref(mbmi)) {
+ const int ctx_comp_group_idx = get_comp_group_idx_context(xd);
+ mbmi->comp_group_idx = aom_read_symbol(
+ r, ec_ctx->comp_group_idx_cdf[ctx_comp_group_idx], 2, ACCT_STR);
+ if (xd->counts)
+ ++xd->counts->comp_group_idx[ctx_comp_group_idx][mbmi->comp_group_idx];
+ // Group A: compound_idx picks jnt_comp vs. compound_average (1 => COMPOUND_AVERAGE).
+ if (mbmi->comp_group_idx == 0) {
+ const int comp_index_ctx = get_comp_index_context(cm, xd);
+ mbmi->compound_idx = aom_read_symbol(
+ r, ec_ctx->compound_index_cdf[comp_index_ctx], 2, ACCT_STR);
+
+ if (xd->counts)
+ ++xd->counts->compound_index[comp_index_ctx][mbmi->compound_idx];
+ // NOTE(review): the next assignment is redundant; COMPOUND_AVERAGE is already the default.
+ if (mbmi->compound_idx) mbmi->interinter_compound_type = COMPOUND_AVERAGE;
+ } else {
+ // Group B: masked compound (wedge or compound_segment).
+ mbmi->interinter_compound_type = COMPOUND_AVERAGE;  // NOTE(review): repeats the default
+ if (cm->reference_mode != SINGLE_REFERENCE &&
+ is_inter_compound_mode(mbmi->mode) &&
+ mbmi->motion_mode == SIMPLE_TRANSLATION
+#if CONFIG_EXT_SKIP
+ && !mbmi->skip_mode
+#endif // CONFIG_EXT_SKIP
+ && mbmi->comp_group_idx) {  // NOTE(review): always true in this branch
+ if (is_any_masked_compound_used(bsize)) {
+ if (cm->allow_masked_compound) {
+ if (is_interinter_compound_used(COMPOUND_WEDGE, bsize))
+ mbmi->interinter_compound_type =
+ aom_read_symbol(r, ec_ctx->compound_type_cdf[bsize],
+ COMPOUND_TYPES, ACCT_STR);
+ else
+ mbmi->interinter_compound_type = COMPOUND_SEG;
+
+ if (mbmi->interinter_compound_type == COMPOUND_WEDGE) {
+ assert(is_interinter_compound_used(COMPOUND_WEDGE, bsize));
+ mbmi->wedge_index =
+ aom_read_literal(r, get_wedge_bits_lookup(bsize), ACCT_STR);
+ mbmi->wedge_sign = aom_read_bit(r, ACCT_STR);
+ }
+ if (mbmi->interinter_compound_type == COMPOUND_SEG) {
+ mbmi->mask_type =
+ aom_read_literal(r, MAX_SEG_MASK_BITS, ACCT_STR);
+ }
+ }
+ }
+ // NOTE(review): if the masked checks above fail, group 1 keeps COMPOUND_AVERAGE, which the encoder asserts never happens.
+ if (xd->counts)
+ xd->counts
+ ->compound_interinter[bsize][mbmi->interinter_compound_type]++;
+ }
+ }
+ }
+#else // CONFIG_JNT_COMP
mbmi->interinter_compound_type = COMPOUND_AVERAGE;
if (cm->reference_mode != SINGLE_REFERENCE &&
is_inter_compound_mode(mbmi->mode) &&
@@ -2225,6 +2281,7 @@
if (xd->counts)
xd->counts->compound_interinter[bsize][mbmi->interinter_compound_type]++;
}
+#endif // CONFIG_JNT_COMP
read_mb_interp_filter(cm, xd, mbmi, r);
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index cb4d9b4..4f05d99 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1372,14 +1372,6 @@
int16_t mode_ctx;
write_ref_frames(cm, xd, w);
-#if CONFIG_JNT_COMP
- if (has_second_ref(mbmi)) {
- const int comp_index_ctx = get_comp_index_context(cm, xd);
- aom_write_symbol(w, mbmi->compound_idx,
- ec_ctx->compound_index_cdf[comp_index_ctx], 2);
- }
-#endif // CONFIG_JNT_COMP
-
if (is_compound)
mode_ctx = mbmi_ext->compound_mode_context[mbmi->ref_frame[0]];
else
@@ -1454,16 +1446,53 @@
if (mbmi->ref_frame[1] != INTRA_FRAME) write_motion_mode(cm, xd, mi, w);
+#if CONFIG_JNT_COMP
+ // First write idx to indicate current compound inter prediction mode group
+ // Group A (0): jnt_comp, compound_average
+ // Group B (1): interintra, compound_segment, wedge
+ if (has_second_ref(mbmi)) {
+ const int ctx_comp_group_idx = get_comp_group_idx_context(xd);
+ aom_write_symbol(w, mbmi->comp_group_idx,
+ ec_ctx->comp_group_idx_cdf[ctx_comp_group_idx], 2);
+
+ if (mbmi->comp_group_idx == 0) {
+ if (mbmi->compound_idx)
+ assert(mbmi->interinter_compound_type == COMPOUND_AVERAGE);
+ // compound_idx is only coded for group A.
+ const int comp_index_ctx = get_comp_index_context(cm, xd);
+ aom_write_symbol(w, mbmi->compound_idx,
+ ec_ctx->compound_index_cdf[comp_index_ctx], 2);
+ } else {
+ assert(mbmi->interinter_compound_type == COMPOUND_WEDGE ||
+ mbmi->interinter_compound_type == COMPOUND_SEG);
+ // Group B. NOTE(review): the decoder also requires !skip_mode under CONFIG_EXT_SKIP; confirm symmetry.
+ if (cpi->common.reference_mode != SINGLE_REFERENCE &&
+ is_inter_compound_mode(mbmi->mode) &&
+ mbmi->motion_mode == SIMPLE_TRANSLATION &&
+ is_any_masked_compound_used(bsize) && cm->allow_masked_compound &&
+ mbmi->comp_group_idx) {  // NOTE(review): always true in this branch
+ if (is_interinter_compound_used(COMPOUND_WEDGE, bsize))
+ aom_write_symbol(w, mbmi->interinter_compound_type,
+ ec_ctx->compound_type_cdf[bsize], COMPOUND_TYPES);
+ // Without wedge support the decoder infers COMPOUND_SEG, so no type symbol is written.
+ if (is_interinter_compound_used(COMPOUND_WEDGE, bsize) &&
+ mbmi->interinter_compound_type == COMPOUND_WEDGE) {
+ aom_write_literal(w, mbmi->wedge_index,
+ get_wedge_bits_lookup(bsize));
+ aom_write_bit(w, mbmi->wedge_sign);
+ }
+ if (mbmi->interinter_compound_type == COMPOUND_SEG) {
+ aom_write_literal(w, mbmi->mask_type, MAX_SEG_MASK_BITS);
+ }
+ }
+ }
+ }
+#else // CONFIG_JNT_COMP
if (cpi->common.reference_mode != SINGLE_REFERENCE &&
is_inter_compound_mode(mbmi->mode) &&
mbmi->motion_mode == SIMPLE_TRANSLATION &&
is_any_masked_compound_used(bsize)) {
-#if CONFIG_JNT_COMP
- if (cm->allow_masked_compound && mbmi->compound_idx)
-#else
- if (cm->allow_masked_compound)
-#endif // CONFIG_JNT_COMP
- {
+ if (cm->allow_masked_compound) {
if (!is_interinter_compound_used(COMPOUND_WEDGE, bsize))
aom_write_bit(w, mbmi->interinter_compound_type == COMPOUND_AVERAGE);
else
@@ -1479,6 +1508,7 @@
}
}
}
+#endif // CONFIG_JNT_COMP
write_mb_interp_filter(cpi, xd, w);
}
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index a22217f..be77347 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -641,9 +641,6 @@
const int x_mis = AOMMIN(bw, cm->mi_cols - mi_col);
const int y_mis = AOMMIN(bh, cm->mi_rows - mi_row);
av1_copy_frame_mvs(cm, mi, mi_row, mi_col, x_mis, y_mis);
-
-#if CONFIG_JNT_COMP
-#endif // CONFIG_JNT_COMP
}
#if NC_MODE_INFO
@@ -1258,9 +1255,30 @@
}
}
+#if CONFIG_JNT_COMP
+ if (has_second_ref(mbmi)) {
+ const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
+ ++counts->comp_group_idx[comp_group_idx_ctx][mbmi->comp_group_idx];
+ if (allow_update_cdf)
+ update_cdf(fc->comp_group_idx_cdf[comp_group_idx_ctx],
+ mbmi->comp_group_idx, 2);
+ // compound_idx is only coded for group A, so it is only counted there.
+ if (mbmi->comp_group_idx == 0) {
+ const int comp_index_ctx = get_comp_index_context(cm, xd);
+ ++counts->compound_index[comp_index_ctx][mbmi->compound_idx];
+ if (allow_update_cdf)
+ update_cdf(fc->compound_index_cdf[comp_index_ctx],
+ mbmi->compound_idx, 2);
+ }
+ }
+#endif // CONFIG_JNT_COMP
+// With CONFIG_JNT_COMP, compound_interinter is only counted for group B blocks.
if (cm->reference_mode != SINGLE_REFERENCE &&
- is_inter_compound_mode(mbmi->mode) &&
- mbmi->motion_mode == SIMPLE_TRANSLATION) {
+ is_inter_compound_mode(mbmi->mode)
+#if CONFIG_JNT_COMP
+ && mbmi->comp_group_idx
+#endif // CONFIG_JNT_COMP
+ && mbmi->motion_mode == SIMPLE_TRANSLATION) {
if (is_interinter_compound_used(COMPOUND_WEDGE, bsize)) {
counts
->compound_interinter[bsize][mbmi->interinter_compound_type]++;
@@ -1319,16 +1337,6 @@
}
}
}
-
-#if CONFIG_JNT_COMP
- if (has_second_ref(mbmi)) {
- const int comp_index_ctx = get_comp_index_context(cm, xd);
- ++counts->compound_index[comp_index_ctx][mbmi->compound_idx];
- if (allow_update_cdf)
- update_cdf(fc->compound_index_cdf[comp_index_ctx], mbmi->compound_idx,
- 2);
- }
-#endif // CONFIG_JNT_COMP
}
}
}
@@ -1484,6 +1492,15 @@
mbmi->current_delta_lf_from_base = xd->prev_delta_lf_from_base;
}
#endif
+#if CONFIG_JNT_COMP
+ if (has_second_ref(mbmi)) {  // group 1 only for compound_idx == 1 with a non-average type
+ if (mbmi->compound_idx == 0 ||
+ mbmi->interinter_compound_type == COMPOUND_AVERAGE)
+ mbmi->comp_group_idx = 0;
+ else
+ mbmi->comp_group_idx = 1;
+ }
+#endif // CONFIG_JNT_COMP
update_stats(&cpi->common, tile_data, td, mi_row, mi_col);
}
}