JNT_COMP: 4. add context and entropy read/write
Change-Id: I0e6f7ab981e31f7120105515f6204568b6dc82d3
diff --git a/av1/common/entropymode.c b/av1/common/entropymode.c
index f30728d..0d2d1ed 100644
--- a/av1/common/entropymode.c
+++ b/av1/common/entropymode.c
@@ -2354,6 +2354,12 @@
static const aom_prob default_skip_probs[SKIP_CONTEXTS] = { 192, 128, 64 };
#endif // CONFIG_NEW_MULTISYMBOL
+#if CONFIG_JNT_COMP
+static const aom_prob default_compound_idx_probs[COMP_INDEX_CONTEXTS] = {
+ 192, 128, 64, 192, 128, 64, 192, 128, 64,
+};
+#endif // CONFIG_JNT_COMP
+
#if CONFIG_LGT_FROM_PRED
static const aom_prob default_intra_lgt_prob[LGT_SIZES][INTRA_MODES] = {
{ 255, 208, 208, 180, 230, 208, 194, 214, 220, 255,
@@ -6224,6 +6230,9 @@
#if CONFIG_NEW_MULTISYMBOL
av1_copy(fc->txfm_partition_cdf, default_txfm_partition_cdf);
#endif
+#if CONFIG_JNT_COMP
+ av1_copy(fc->compound_index_probs, default_compound_idx_probs);
+#endif // CONFIG_JNT_COMP
av1_copy(fc->newmv_prob, default_newmv_prob);
av1_copy(fc->zeromv_prob, default_zeromv_prob);
av1_copy(fc->refmv_prob, default_refmv_prob);
@@ -6455,6 +6464,12 @@
}
}
#endif // CONFIG_COMPOUND_SEGMENT || CONFIG_WEDGE
+
+#if CONFIG_JNT_COMP
+ for (i = 0; i < COMP_INDEX_CONTEXTS; ++i)
+ fc->compound_index_probs[i] = av1_mode_mv_merge_probs(
+ pre_fc->compound_index_probs[i], counts->compound_index[i]);
+#endif // CONFIG_JNT_COMP
}
void av1_adapt_intra_frame_probs(AV1_COMMON *cm) {
diff --git a/av1/common/entropymode.h b/av1/common/entropymode.h
index bc4b5b9..d278238 100644
--- a/av1/common/entropymode.h
+++ b/av1/common/entropymode.h
@@ -292,6 +292,9 @@
#if CONFIG_NEW_MULTISYMBOL
aom_cdf_prob txfm_partition_cdf[TXFM_PARTITION_CONTEXTS][CDF_SIZE(2)];
#endif
+#if CONFIG_JNT_COMP
+ aom_prob compound_index_probs[COMP_INDEX_CONTEXTS];
+#endif // CONFIG_JNT_COMP
#if CONFIG_NEW_MULTISYMBOL
aom_cdf_prob skip_cdfs[SKIP_CONTEXTS][CDF_SIZE(2)];
aom_cdf_prob intra_inter_cdf[INTRA_INTER_CONTEXTS][CDF_SIZE(2)];
@@ -477,6 +480,9 @@
unsigned int txfm_partition[TXFM_PARTITION_CONTEXTS][2];
unsigned int skip[SKIP_CONTEXTS][2];
nmv_context_counts mv[NMV_CONTEXTS];
+#if CONFIG_JNT_COMP
+ unsigned int compound_index[COMP_INDEX_CONTEXTS][2];
+#endif // CONFIG_JNT_COMP
#if CONFIG_INTRABC
unsigned int intrabc[2];
nmv_context_counts dv;
diff --git a/av1/common/enums.h b/av1/common/enums.h
index 70c90da..2a3368a 100644
--- a/av1/common/enums.h
+++ b/av1/common/enums.h
@@ -640,6 +640,10 @@
#define SKIP_CONTEXTS 3
+#if CONFIG_JNT_COMP
+#define COMP_INDEX_CONTEXTS 9
+#endif // CONFIG_JNT_COMP
+
#define NMV_CONTEXTS 3
#define NEWMV_MODE_CONTEXTS 7
diff --git a/av1/common/pred_common.h b/av1/common/pred_common.h
index bdf6527..169035b 100644
--- a/av1/common/pred_common.h
+++ b/av1/common/pred_common.h
@@ -54,6 +54,53 @@
return segp->pred_probs[av1_get_pred_context_seg_id(xd)];
}
+#if CONFIG_JNT_COMP
+static INLINE int get_comp_index_context(const AV1_COMMON *cm,
+ const MACROBLOCKD *xd) {
+ MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
+ int bck_idx = cm->frame_refs[mbmi->ref_frame[0] - LAST_FRAME].idx;
+ int fwd_idx = cm->frame_refs[mbmi->ref_frame[1] - LAST_FRAME].idx;
+ int bck_frame_index = 0, fwd_frame_index = 0;
+ int cur_frame_index = cm->cur_frame->cur_frame_offset;
+
+ if (bck_idx >= 0)
+ bck_frame_index = cm->buffer_pool->frame_bufs[bck_idx].cur_frame_offset;
+
+ if (fwd_idx >= 0)
+ fwd_frame_index = cm->buffer_pool->frame_bufs[fwd_idx].cur_frame_offset;
+ int fwd = abs(fwd_frame_index - cur_frame_index);
+ int bck = abs(cur_frame_index - bck_frame_index);
+
+ const MODE_INFO *const above_mi = xd->above_mi;
+ const MODE_INFO *const left_mi = xd->left_mi;
+
+ int above_ctx = 0, left_ctx = 0;
+ int offset = (fwd > bck) ? 0 : 1;
+
+ if (fwd < (bck >> 1) + bck && bck < (fwd >> 1) + fwd) {
+ offset = 2;
+ }
+
+ if (above_mi) {
+ const MB_MODE_INFO *above_mbmi = &above_mi->mbmi;
+ if (has_second_ref(above_mbmi))
+ above_ctx = above_mbmi->compound_idx;
+ else if (above_mbmi->ref_frame[0] == ALTREF_FRAME)
+ above_ctx = 1;
+ }
+
+ if (left_mi) {
+ const MB_MODE_INFO *left_mbmi = &left_mi->mbmi;
+ if (has_second_ref(left_mbmi))
+ left_ctx = left_mbmi->compound_idx;
+ else if (left_mbmi->ref_frame[0] == ALTREF_FRAME)
+ left_ctx = 1;
+ }
+
+ return above_ctx + left_ctx + 3 * offset;
+}
+#endif // CONFIG_JNT_COMP
+
#if CONFIG_NEW_MULTISYMBOL
static INLINE aom_cdf_prob *av1_get_pred_cdf_seg_id(
struct segmentation_probs *segp, const MACROBLOCKD *xd) {
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index d7d7059..f507b93 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -3531,6 +3531,11 @@
av1_diff_update_prob(&r, &fc->txfm_partition_prob[i], ACCT_STR);
for (int i = 0; i < SKIP_CONTEXTS; ++i)
av1_diff_update_prob(&r, &fc->skip_probs[i], ACCT_STR);
+
+#if CONFIG_JNT_COMP
+ for (int i = 0; i < COMP_INDEX_CONTEXTS; ++i)
+ av1_diff_update_prob(&r, &fc->compound_index_probs[i], ACCT_STR);
+#endif // CONFIG_JNT_COMP
#endif
if (!frame_is_intra_only(cm)) {
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index d2b0acb..a22f3ec 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -2221,6 +2221,16 @@
read_ref_frames(cm, xd, r, mbmi->segment_id, mbmi->ref_frame);
is_compound = has_second_ref(mbmi);
+#if CONFIG_JNT_COMP
+ if (is_compound) {
+ const int comp_index_ctx = get_comp_index_context(cm, xd);
+ mbmi->compound_idx =
+ aom_read(r, ec_ctx->compound_index_probs[comp_index_ctx], ACCT_STR);
+ if (xd->counts)
+ ++xd->counts->compound_index[comp_index_ctx][mbmi->compound_idx];
+ }
+#endif // CONFIG_JNT_COMP
+
#if CONFIG_EXT_COMP_REFS
#if !USE_UNI_COMP_REFS
// NOTE: uni-directional comp refs disabled
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index bbe80d5..59fb670 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1605,6 +1605,14 @@
int16_t mode_ctx;
write_ref_frames(cm, xd, w);
+#if CONFIG_JNT_COMP
+ if (has_second_ref(mbmi)) {
+ const int comp_index_ctx = get_comp_index_context(cm, xd);
+ aom_write(w, mbmi->compound_idx,
+ ec_ctx->compound_index_probs[comp_index_ctx]);
+ }
+#endif // CONFIG_JNT_COMP
+
#if CONFIG_COMPOUND_SINGLEREF
if (!segfeature_active(seg, segment_id, SEG_LVL_REF_FRAME)) {
// NOTE: Handle single ref comp mode
@@ -4570,6 +4578,11 @@
#if !CONFIG_NEW_MULTISYMBOL
update_skip_probs(cm, header_bc, counts);
+#if CONFIG_JNT_COMP
+ for (int k = 0; k < COMP_INDEX_CONTEXTS; ++k)
+ av1_cond_prob_diff_update(header_bc, &cm->fc->compound_index_probs[k],
+ counts->compound_index[k], probwt);
+#endif // CONFIG_JNT_COMP
#endif
if (!frame_is_intra_only(cm)) {
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 5403318..9099737 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -590,6 +590,13 @@
const int x_mis = AOMMIN(bw, cm->mi_cols - mi_col);
const int y_mis = AOMMIN(bh, cm->mi_rows - mi_row);
av1_copy_frame_mvs(cm, mi, mi_row, mi_col, x_mis, y_mis);
+
+#if CONFIG_JNT_COMP
+ if (has_second_ref(mbmi)) {
+ const int comp_index_ctx = get_comp_index_context(cm, xd);
+ ++td->counts->compound_index[comp_index_ctx][mbmi->compound_idx];
+ }
+#endif // CONFIG_JNT_COMP
}
#if CONFIG_MOTION_VAR && NC_MODE_INFO
@@ -3443,7 +3450,7 @@
av1_setup_block_planes(xd, cm->subsampling_x, cm->subsampling_y);
}
-#if !CONFIG_REF_ADAPT
+#if !CONFIG_REF_ADAPT && !CONFIG_JNT_COMP
static int check_dual_ref_flags(AV1_COMP *cpi) {
const int ref_flags = cpi->ref_frame_flags;
@@ -4162,12 +4169,14 @@
if (is_alt_ref || !cpi->allow_comp_inter_inter)
#endif // CONFIG_BGSPRITE
cm->reference_mode = SINGLE_REFERENCE;
+#if !CONFIG_JNT_COMP
else if (mode_thrs[COMPOUND_REFERENCE] > mode_thrs[SINGLE_REFERENCE] &&
mode_thrs[COMPOUND_REFERENCE] > mode_thrs[REFERENCE_MODE_SELECT] &&
check_dual_ref_flags(cpi) && cpi->static_mb_pct == 100)
cm->reference_mode = COMPOUND_REFERENCE;
else if (mode_thrs[SINGLE_REFERENCE] > mode_thrs[REFERENCE_MODE_SELECT])
cm->reference_mode = SINGLE_REFERENCE;
+#endif // CONFIG_JNT_COMP
else
cm->reference_mode = REFERENCE_MODE_SELECT;
#endif // CONFIG_REF_ADAPT
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index ce4c03d..f4d001c 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -8462,6 +8462,14 @@
}
#endif // CONFIG_COMPOUND_SINGLEREF
+#if CONFIG_JNT_COMP
+ if (is_comp_pred) {
+ const int comp_index_ctx = get_comp_index_context(cm, xd);
+ rd_stats->rate += av1_cost_bit(cm->fc->compound_index_probs[comp_index_ctx],
+ mbmi->compound_idx);
+ }
+#endif // CONFIG_JNT_COMP
+
if (this_mode == NEAREST_NEARESTMV) {
if (mbmi_ext->ref_mv_count[ref_frame_type] > 0) {
cur_mv[0] = mbmi_ext->ref_mv_stack[ref_frame_type][0].this_mv;