JNT_COMP: Support new_multisymbol

Support cdf for jnt_comp read/write.

Change-Id: I2c29277a8b06b3e9f571355946b70ce0d492fbb2
diff --git a/av1/common/entropy.c b/av1/common/entropy.c
index f89eac5..f19643f 100644
--- a/av1/common/entropy.c
+++ b/av1/common/entropy.c
@@ -1824,4 +1824,9 @@
   int j;
   for (j = 0; j < Q_SEGMENT_CDF_COUNT; j++) AVERAGE_TILE_CDFS(seg.q_seg_cdf[j]);
 #endif
+#if CONFIG_JNT_COMP
+#if CONFIG_NEW_MULTISYMBOL
+  AVERAGE_TILE_CDFS(compound_index_cdf);
+#endif  // CONFIG_NEW_MULTISYMBOL
+#endif  // CONFIG_JNT_COMP
 }
diff --git a/av1/common/entropymode.c b/av1/common/entropymode.c
index 5ac7e15..548d7d2 100644
--- a/av1/common/entropymode.c
+++ b/av1/common/entropymode.c
@@ -1652,6 +1652,20 @@
 #endif  // CONFIG_NEW_MULTISYMBOL
 
 #if CONFIG_JNT_COMP
+#if CONFIG_NEW_MULTISYMBOL
+static const aom_cdf_prob
+    default_compound_idx_cdfs[COMP_INDEX_CONTEXTS][CDF_SIZE(2)] = {
+      { AOM_ICDF(24576), AOM_ICDF(32768), 0 },
+      { AOM_ICDF(16384), AOM_ICDF(32768), 0 },
+      { AOM_ICDF(8192), AOM_ICDF(32768), 0 },
+      { AOM_ICDF(24576), AOM_ICDF(32768), 0 },
+      { AOM_ICDF(16384), AOM_ICDF(32768), 0 },
+      { AOM_ICDF(8192), AOM_ICDF(32768), 0 },
+      { AOM_ICDF(24576), AOM_ICDF(32768), 0 },
+      { AOM_ICDF(16384), AOM_ICDF(32768), 0 },
+      { AOM_ICDF(8192), AOM_ICDF(32768), 0 }
+    };
+#endif  // CONFIG_NEW_MULTISYMBOL
 static const aom_prob default_compound_idx_probs[COMP_INDEX_CONTEXTS] = {
   192, 128, 64, 192, 128, 64, 192, 128, 64,
 };
@@ -3446,6 +3460,9 @@
   av1_copy(fc->txfm_partition_cdf, default_txfm_partition_cdf);
 #endif
 #if CONFIG_JNT_COMP
+#if CONFIG_NEW_MULTISYMBOL
+  av1_copy(fc->compound_index_cdf, default_compound_idx_cdfs);
+#endif  // CONFIG_NEW_MULTISYMBOL
   av1_copy(fc->compound_index_probs, default_compound_idx_probs);
 #endif  // CONFIG_JNT_COMP
   av1_copy(fc->newmv_prob, default_newmv_prob);
diff --git a/av1/common/entropymode.h b/av1/common/entropymode.h
index 310e58a..5cd8a84 100644
--- a/av1/common/entropymode.h
+++ b/av1/common/entropymode.h
@@ -266,6 +266,9 @@
   aom_cdf_prob txfm_partition_cdf[TXFM_PARTITION_CONTEXTS][CDF_SIZE(2)];
 #endif
 #if CONFIG_JNT_COMP
+#if CONFIG_NEW_MULTISYMBOL
+  aom_cdf_prob compound_index_cdf[COMP_INDEX_CONTEXTS][CDF_SIZE(2)];
+#endif  // CONFIG_NEW_MULTISYMBOL
   aom_prob compound_index_probs[COMP_INDEX_CONTEXTS];
 #endif  // CONFIG_JNT_COMP
 #if CONFIG_NEW_MULTISYMBOL
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index c724285..8fb19bc 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -2161,8 +2161,13 @@
 #if CONFIG_JNT_COMP
   if (has_two_sided_comp_refs(cm, mbmi)) {
     const int comp_index_ctx = get_comp_index_context(cm, xd);
+#if CONFIG_NEW_MULTISYMBOL
+    mbmi->compound_idx = aom_read_symbol(
+        r, ec_ctx->compound_index_cdf[comp_index_ctx], 2, ACCT_STR);
+#else
     mbmi->compound_idx =
         aom_read(r, ec_ctx->compound_index_probs[comp_index_ctx], ACCT_STR);
+#endif  // CONFIG_NEW_MULTISYMBOL
     if (xd->counts)
       ++xd->counts->compound_index[comp_index_ctx][mbmi->compound_idx];
   } else {
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index a3bd05a..b87fa06 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1497,11 +1497,19 @@
     write_ref_frames(cm, xd, w);
 
 #if CONFIG_JNT_COMP
+#if CONFIG_NEW_MULTISYMBOL
+    if (has_two_sided_comp_refs(cm, mbmi)) {
+      const int comp_index_ctx = get_comp_index_context(cm, xd);
+      aom_write_symbol(w, mbmi->compound_idx,
+                       ec_ctx->compound_index_cdf[comp_index_ctx], 2);
+    }
+#else
     if (has_two_sided_comp_refs(cm, mbmi)) {
       const int comp_index_ctx = get_comp_index_context(cm, xd);
       aom_write(w, mbmi->compound_idx,
                 ec_ctx->compound_index_probs[comp_index_ctx]);
     }
+#endif  // CONFIG_NEW_MULTISYMBOL
 #endif  // CONFIG_JNT_COMP
 
 #if CONFIG_COMPOUND_SINGLEREF
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 065e1a0..bd79d1f 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -301,6 +301,9 @@
   DECLARE_ALIGNED(16, uint8_t, decoded_8x8[8 * 8]);
 #endif
 #endif  // CONFIG_DIST_8X8
+#if CONFIG_JNT_COMP
+  int comp_idx_cost[COMP_INDEX_CONTEXTS][2];
+#endif  // CONFIG_JNT_COMP
 };
 
 #ifdef __cplusplus
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index dac927f..d30f995 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -608,10 +608,12 @@
   av1_copy_frame_mvs(cm, mi, mi_row, mi_col, x_mis, y_mis);
 
 #if CONFIG_JNT_COMP
+#if !CONFIG_NEW_MULTISYMBOL
   if (has_two_sided_comp_refs(cm, mbmi)) {
     const int comp_index_ctx = get_comp_index_context(cm, xd);
     ++td->counts->compound_index[comp_index_ctx][mbmi->compound_idx];
   }
+#endif  // CONFIG_NEW_MULTISYMBOL
 #endif  // CONFIG_JNT_COMP
 }
 
@@ -1284,6 +1286,17 @@
           }
         }
       }
+
+#if CONFIG_JNT_COMP
+#if CONFIG_NEW_MULTISYMBOL
+      if (has_second_ref(mbmi)) {
+        const int comp_index_ctx = get_comp_index_context(cm, xd);
+        ++counts->compound_index[comp_index_ctx][mbmi->compound_idx];
+        update_cdf(fc->compound_index_cdf[comp_index_ctx], mbmi->compound_idx,
+                   2);
+      }
+#endif  // CONFIG_NEW_MULTISYMBOL
+#endif  // CONFIG_JNT_COMP
     }
   }
 }
diff --git a/av1/encoder/rd.c b/av1/encoder/rd.c
index c0c8438..a44bc4a 100644
--- a/av1/encoder/rd.c
+++ b/av1/encoder/rd.c
@@ -379,6 +379,17 @@
       x->motion_mode_cost1[i][1] = av1_cost_bit(fc->obmc_prob[i], 1);
 #endif
     }
+#if CONFIG_JNT_COMP
+    for (i = 0; i < COMP_INDEX_CONTEXTS; ++i) {
+#if CONFIG_NEW_MULTISYMBOL
+      av1_cost_tokens_from_cdf(x->comp_idx_cost[i], fc->compound_index_cdf[i],
+                               NULL);
+#else
+      x->comp_idx_cost[i][0] = av1_cost_bit(fc->compound_index_probs[i], 0);
+      x->comp_idx_cost[i][1] = av1_cost_bit(fc->compound_index_probs[i], 1);
+#endif  // CONFIG_NEW_MULTISYMBOL
+    }
+#endif  // CONFIG_JNT_COMP
   }
 }
 
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 4250f61..52a4a91 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -8249,8 +8249,7 @@
 #if CONFIG_JNT_COMP
   if (has_two_sided_comp_refs(cm, mbmi)) {
     const int comp_index_ctx = get_comp_index_context(cm, xd);
-    rd_stats->rate += av1_cost_bit(cm->fc->compound_index_probs[comp_index_ctx],
-                                   mbmi->compound_idx);
+    rd_stats->rate += x->comp_idx_cost[comp_index_ctx][mbmi->compound_idx];
   }
 #endif  // CONFIG_JNT_COMP