JNT_COMP: 4. add context and entropy read/write

Change-Id: I0e6f7ab981e31f7120105515f6204568b6dc82d3
diff --git a/av1/common/entropymode.c b/av1/common/entropymode.c
index f30728d..0d2d1ed 100644
--- a/av1/common/entropymode.c
+++ b/av1/common/entropymode.c
@@ -2354,6 +2354,12 @@
 static const aom_prob default_skip_probs[SKIP_CONTEXTS] = { 192, 128, 64 };
 #endif  // CONFIG_NEW_MULTISYMBOL
 
+#if CONFIG_JNT_COMP
+static const aom_prob default_compound_idx_probs[COMP_INDEX_CONTEXTS] = {
+  192, 128, 64, 192, 128, 64, 192, 128, 64,
+};
+#endif  // CONFIG_JNT_COMP
+
 #if CONFIG_LGT_FROM_PRED
 static const aom_prob default_intra_lgt_prob[LGT_SIZES][INTRA_MODES] = {
   { 255, 208, 208, 180, 230, 208, 194, 214, 220, 255,
@@ -6224,6 +6230,9 @@
 #if CONFIG_NEW_MULTISYMBOL
   av1_copy(fc->txfm_partition_cdf, default_txfm_partition_cdf);
 #endif
+#if CONFIG_JNT_COMP
+  av1_copy(fc->compound_index_probs, default_compound_idx_probs);
+#endif  // CONFIG_JNT_COMP
   av1_copy(fc->newmv_prob, default_newmv_prob);
   av1_copy(fc->zeromv_prob, default_zeromv_prob);
   av1_copy(fc->refmv_prob, default_refmv_prob);
@@ -6455,6 +6464,12 @@
     }
   }
 #endif  // CONFIG_COMPOUND_SEGMENT || CONFIG_WEDGE
+
+#if CONFIG_JNT_COMP
+  for (i = 0; i < COMP_INDEX_CONTEXTS; ++i)
+    fc->compound_index_probs[i] = av1_mode_mv_merge_probs(
+        pre_fc->compound_index_probs[i], counts->compound_index[i]);
+#endif  // CONFIG_JNT_COMP
 }
 
 void av1_adapt_intra_frame_probs(AV1_COMMON *cm) {
diff --git a/av1/common/entropymode.h b/av1/common/entropymode.h
index bc4b5b9..d278238 100644
--- a/av1/common/entropymode.h
+++ b/av1/common/entropymode.h
@@ -292,6 +292,9 @@
 #if CONFIG_NEW_MULTISYMBOL
   aom_cdf_prob txfm_partition_cdf[TXFM_PARTITION_CONTEXTS][CDF_SIZE(2)];
 #endif
+#if CONFIG_JNT_COMP
+  aom_prob compound_index_probs[COMP_INDEX_CONTEXTS];
+#endif  // CONFIG_JNT_COMP
 #if CONFIG_NEW_MULTISYMBOL
   aom_cdf_prob skip_cdfs[SKIP_CONTEXTS][CDF_SIZE(2)];
   aom_cdf_prob intra_inter_cdf[INTRA_INTER_CONTEXTS][CDF_SIZE(2)];
@@ -477,6 +480,9 @@
   unsigned int txfm_partition[TXFM_PARTITION_CONTEXTS][2];
   unsigned int skip[SKIP_CONTEXTS][2];
   nmv_context_counts mv[NMV_CONTEXTS];
+#if CONFIG_JNT_COMP
+  unsigned int compound_index[COMP_INDEX_CONTEXTS][2];
+#endif  // CONFIG_JNT_COMP
 #if CONFIG_INTRABC
   unsigned int intrabc[2];
   nmv_context_counts dv;
diff --git a/av1/common/enums.h b/av1/common/enums.h
index 70c90da..2a3368a 100644
--- a/av1/common/enums.h
+++ b/av1/common/enums.h
@@ -640,6 +640,10 @@
 
 #define SKIP_CONTEXTS 3
 
+#if CONFIG_JNT_COMP
+#define COMP_INDEX_CONTEXTS 9
+#endif  // CONFIG_JNT_COMP
+
 #define NMV_CONTEXTS 3
 
 #define NEWMV_MODE_CONTEXTS 7
diff --git a/av1/common/pred_common.h b/av1/common/pred_common.h
index bdf6527..169035b 100644
--- a/av1/common/pred_common.h
+++ b/av1/common/pred_common.h
@@ -54,6 +54,53 @@
   return segp->pred_probs[av1_get_pred_context_seg_id(xd)];
 }
 
+#if CONFIG_JNT_COMP
+static INLINE int get_comp_index_context(const AV1_COMMON *cm,
+                                         const MACROBLOCKD *xd) {
+  MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
+  int bck_idx = cm->frame_refs[mbmi->ref_frame[0] - LAST_FRAME].idx;
+  int fwd_idx = cm->frame_refs[mbmi->ref_frame[1] - LAST_FRAME].idx;
+  int bck_frame_index = 0, fwd_frame_index = 0;
+  int cur_frame_index = cm->cur_frame->cur_frame_offset;
+
+  if (bck_idx >= 0)
+    bck_frame_index = cm->buffer_pool->frame_bufs[bck_idx].cur_frame_offset;
+
+  if (fwd_idx >= 0)
+    fwd_frame_index = cm->buffer_pool->frame_bufs[fwd_idx].cur_frame_offset;
+  int fwd = abs(fwd_frame_index - cur_frame_index);
+  int bck = abs(cur_frame_index - bck_frame_index);
+
+  const MODE_INFO *const above_mi = xd->above_mi;
+  const MODE_INFO *const left_mi = xd->left_mi;
+
+  int above_ctx = 0, left_ctx = 0;
+  int offset = (fwd > bck) ? 0 : 1;
+
+  if (fwd < (bck >> 1) + bck && bck < (fwd >> 1) + fwd) {
+    offset = 2;
+  }
+
+  if (above_mi) {
+    const MB_MODE_INFO *above_mbmi = &above_mi->mbmi;
+    if (has_second_ref(above_mbmi))
+      above_ctx = above_mbmi->compound_idx;
+    else if (above_mbmi->ref_frame[0] == ALTREF_FRAME)
+      above_ctx = 1;
+  }
+
+  if (left_mi) {
+    const MB_MODE_INFO *left_mbmi = &left_mi->mbmi;
+    if (has_second_ref(left_mbmi))
+      left_ctx = left_mbmi->compound_idx;
+    else if (left_mbmi->ref_frame[0] == ALTREF_FRAME)
+      left_ctx = 1;
+  }
+
+  return above_ctx + left_ctx + 3 * offset;
+}
+#endif  // CONFIG_JNT_COMP
+
 #if CONFIG_NEW_MULTISYMBOL
 static INLINE aom_cdf_prob *av1_get_pred_cdf_seg_id(
     struct segmentation_probs *segp, const MACROBLOCKD *xd) {
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index d7d7059..f507b93 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -3531,6 +3531,11 @@
       av1_diff_update_prob(&r, &fc->txfm_partition_prob[i], ACCT_STR);
   for (int i = 0; i < SKIP_CONTEXTS; ++i)
     av1_diff_update_prob(&r, &fc->skip_probs[i], ACCT_STR);
+
+#if CONFIG_JNT_COMP
+  for (int i = 0; i < COMP_INDEX_CONTEXTS; ++i)
+    av1_diff_update_prob(&r, &fc->compound_index_probs[i], ACCT_STR);
+#endif  // CONFIG_JNT_COMP
 #endif
 
   if (!frame_is_intra_only(cm)) {
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index d2b0acb..a22f3ec 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -2221,6 +2221,16 @@
   read_ref_frames(cm, xd, r, mbmi->segment_id, mbmi->ref_frame);
   is_compound = has_second_ref(mbmi);
 
+#if CONFIG_JNT_COMP
+  if (is_compound) {
+    const int comp_index_ctx = get_comp_index_context(cm, xd);
+    mbmi->compound_idx =
+        aom_read(r, ec_ctx->compound_index_probs[comp_index_ctx], ACCT_STR);
+    if (xd->counts)
+      ++xd->counts->compound_index[comp_index_ctx][mbmi->compound_idx];
+  }
+#endif  // CONFIG_JNT_COMP
+
 #if CONFIG_EXT_COMP_REFS
 #if !USE_UNI_COMP_REFS
   // NOTE: uni-directional comp refs disabled
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index bbe80d5..59fb670 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1605,6 +1605,14 @@
     int16_t mode_ctx;
     write_ref_frames(cm, xd, w);
 
+#if CONFIG_JNT_COMP
+    if (has_second_ref(mbmi)) {
+      const int comp_index_ctx = get_comp_index_context(cm, xd);
+      aom_write(w, mbmi->compound_idx,
+                ec_ctx->compound_index_probs[comp_index_ctx]);
+    }
+#endif  // CONFIG_JNT_COMP
+
 #if CONFIG_COMPOUND_SINGLEREF
     if (!segfeature_active(seg, segment_id, SEG_LVL_REF_FRAME)) {
       // NOTE: Handle single ref comp mode
@@ -4570,6 +4578,11 @@
 
 #if !CONFIG_NEW_MULTISYMBOL
   update_skip_probs(cm, header_bc, counts);
+#if CONFIG_JNT_COMP
+  for (int k = 0; k < COMP_INDEX_CONTEXTS; ++k)
+    av1_cond_prob_diff_update(header_bc, &cm->fc->compound_index_probs[k],
+                              counts->compound_index[k], probwt);
+#endif  // CONFIG_JNT_COMP
 #endif
 
   if (!frame_is_intra_only(cm)) {
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 5403318..9099737 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -590,6 +590,13 @@
   const int x_mis = AOMMIN(bw, cm->mi_cols - mi_col);
   const int y_mis = AOMMIN(bh, cm->mi_rows - mi_row);
   av1_copy_frame_mvs(cm, mi, mi_row, mi_col, x_mis, y_mis);
+
+#if CONFIG_JNT_COMP
+  if (has_second_ref(mbmi)) {
+    const int comp_index_ctx = get_comp_index_context(cm, xd);
+    ++td->counts->compound_index[comp_index_ctx][mbmi->compound_idx];
+  }
+#endif  // CONFIG_JNT_COMP
 }
 
 #if CONFIG_MOTION_VAR && NC_MODE_INFO
@@ -3443,7 +3450,7 @@
   av1_setup_block_planes(xd, cm->subsampling_x, cm->subsampling_y);
 }
 
-#if !CONFIG_REF_ADAPT
+#if !CONFIG_REF_ADAPT && !CONFIG_JNT_COMP
 static int check_dual_ref_flags(AV1_COMP *cpi) {
   const int ref_flags = cpi->ref_frame_flags;
 
@@ -4162,12 +4169,14 @@
     if (is_alt_ref || !cpi->allow_comp_inter_inter)
 #endif  // CONFIG_BGSPRITE
       cm->reference_mode = SINGLE_REFERENCE;
+#if !CONFIG_JNT_COMP
     else if (mode_thrs[COMPOUND_REFERENCE] > mode_thrs[SINGLE_REFERENCE] &&
              mode_thrs[COMPOUND_REFERENCE] > mode_thrs[REFERENCE_MODE_SELECT] &&
              check_dual_ref_flags(cpi) && cpi->static_mb_pct == 100)
       cm->reference_mode = COMPOUND_REFERENCE;
     else if (mode_thrs[SINGLE_REFERENCE] > mode_thrs[REFERENCE_MODE_SELECT])
       cm->reference_mode = SINGLE_REFERENCE;
+#endif  // CONFIG_JNT_COMP
     else
       cm->reference_mode = REFERENCE_MODE_SELECT;
 #endif  // CONFIG_REF_ADAPT
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index ce4c03d..f4d001c 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -8462,6 +8462,14 @@
   }
 #endif  // CONFIG_COMPOUND_SINGLEREF
 
+#if CONFIG_JNT_COMP
+  if (is_comp_pred) {
+    const int comp_index_ctx = get_comp_index_context(cm, xd);
+    rd_stats->rate += av1_cost_bit(cm->fc->compound_index_probs[comp_index_ctx],
+                                   mbmi->compound_idx);
+  }
+#endif  // CONFIG_JNT_COMP
+
   if (this_mode == NEAREST_NEARESTMV) {
     if (mbmi_ext->ref_mv_count[ref_frame_type] > 0) {
       cur_mv[0] = mbmi_ext->ref_mv_stack[ref_frame_type][0].this_mv;