Change comp_group index context and save sending comp_group

Extend context model for comp_group_idx.
Save sending comp_group_idx when masked_compound is not allowed.

Change-Id: Ia7ae53958c9e1c8fe07be4b14a425d9b8648082d
diff --git a/av1/common/entropymode.c b/av1/common/entropymode.c
index d145f514..1aa0667 100644
--- a/av1/common/entropymode.c
+++ b/av1/common/entropymode.c
@@ -1504,16 +1504,20 @@
 
 static const aom_cdf_prob
     default_comp_group_idx_cdfs[COMP_GROUP_IDX_CONTEXTS][CDF_SIZE(2)] = {
+      { AOM_ICDF(29491), AOM_ICDF(32768), 0 },
       { AOM_ICDF(24576), AOM_ICDF(32768), 0 },
       { AOM_ICDF(16384), AOM_ICDF(32768), 0 },
-      { AOM_ICDF(8192), AOM_ICDF(32768), 0 },
+      { AOM_ICDF(24576), AOM_ICDF(32768), 0 },
+      { AOM_ICDF(16384), AOM_ICDF(32768), 0 },
+      { AOM_ICDF(13107), AOM_ICDF(32768), 0 },
+      { AOM_ICDF(13107), AOM_ICDF(32768), 0 },
     };
 static const aom_prob default_compound_idx_probs[COMP_INDEX_CONTEXTS] = {
   192, 128, 64, 192, 128, 64
 };
 
 static const aom_prob default_comp_group_idx_probs[COMP_GROUP_IDX_CONTEXTS] = {
-  192, 128, 64
+  192, 128, 64, 192, 128, 64, 128
 };
 #endif  // CONFIG_JNT_COMP
 
diff --git a/av1/common/enums.h b/av1/common/enums.h
index 667cc9f..7a82c5c 100644
--- a/av1/common/enums.h
+++ b/av1/common/enums.h
@@ -592,7 +592,7 @@
 
 #if CONFIG_JNT_COMP
 #define COMP_INDEX_CONTEXTS 6
-#define COMP_GROUP_IDX_CONTEXTS 3
+#define COMP_GROUP_IDX_CONTEXTS 7
 #endif  // CONFIG_JNT_COMP
 
 #define NMV_CONTEXTS 3
diff --git a/av1/common/pred_common.h b/av1/common/pred_common.h
index 8047769..bf3f091 100644
--- a/av1/common/pred_common.h
+++ b/av1/common/pred_common.h
@@ -133,14 +133,14 @@
     if (has_second_ref(above_mbmi))
       above_ctx = above_mbmi->comp_group_idx;
     else if (above_mbmi->ref_frame[0] == ALTREF_FRAME)
-      above_ctx = 1;
+      above_ctx = 3;
   }
   if (left_mi) {
     const MB_MODE_INFO *left_mbmi = &left_mi->mbmi;
     if (has_second_ref(left_mbmi))
       left_ctx = left_mbmi->comp_group_idx;
     else if (left_mbmi->ref_frame[0] == ALTREF_FRAME)
-      left_ctx = 1;
+      left_ctx = 3;
   }
 
   return above_ctx + left_ctx;
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index 3e7039e..532471f 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -2189,12 +2189,19 @@
   mbmi->interinter_compound_type = COMPOUND_AVERAGE;
 
   // read idx to indicate current compound inter prediction mode group
+  int masked_compound_used = is_any_masked_compound_used(bsize);
+  masked_compound_used = masked_compound_used && cm->allow_masked_compound;
+
   if (has_second_ref(mbmi)) {
-    const int ctx_comp_group_idx = get_comp_group_idx_context(xd);
-    mbmi->comp_group_idx = aom_read_symbol(
-        r, ec_ctx->comp_group_idx_cdf[ctx_comp_group_idx], 2, ACCT_STR);
-    if (xd->counts)
-      ++xd->counts->comp_group_idx[ctx_comp_group_idx][mbmi->comp_group_idx];
+    if (masked_compound_used) {
+      const int ctx_comp_group_idx = get_comp_group_idx_context(xd);
+      mbmi->comp_group_idx = aom_read_symbol(
+          r, ec_ctx->comp_group_idx_cdf[ctx_comp_group_idx], 2, ACCT_STR);
+      if (xd->counts)
+        ++xd->counts->comp_group_idx[ctx_comp_group_idx][mbmi->comp_group_idx];
+    } else {
+      mbmi->comp_group_idx = 0;
+    }
 
     if (mbmi->comp_group_idx == 0) {
       const int comp_index_ctx = get_comp_index_context(cm, xd);
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index c669a4c..e4f1fec 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1454,10 +1454,17 @@
     // First write idx to indicate current compound inter prediction mode group
     // Group A (0): jnt_comp, compound_average
     // Group B (1): interintra, compound_segment, wedge
+    int masked_compound_used = is_any_masked_compound_used(bsize);
+    masked_compound_used = masked_compound_used && cm->allow_masked_compound;
+
     if (has_second_ref(mbmi)) {
-      const int ctx_comp_group_idx = get_comp_group_idx_context(xd);
-      aom_write_symbol(w, mbmi->comp_group_idx,
-                       ec_ctx->comp_group_idx_cdf[ctx_comp_group_idx], 2);
+      if (masked_compound_used) {
+        assert(mbmi->comp_group_idx == 0);
+
+        const int ctx_comp_group_idx = get_comp_group_idx_context(xd);
+        aom_write_symbol(w, mbmi->comp_group_idx,
+                         ec_ctx->comp_group_idx_cdf[ctx_comp_group_idx], 2);
+      }
 
       if (mbmi->comp_group_idx == 0) {
         if (mbmi->compound_idx)
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index ca14064..1ec0cca 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -5661,7 +5661,6 @@
   }
 }
 
-#if !CONFIG_JNT_COMP
 static int get_interinter_compound_type_bits(BLOCK_SIZE bsize,
                                              COMPOUND_TYPE comp_type) {
   (void)bsize;
@@ -5672,7 +5671,6 @@
     default: assert(0); return 0;
   }
 }
-#endif
 
 typedef struct {
   int eobs;
@@ -8142,12 +8140,16 @@
 #if CONFIG_JNT_COMP
   if (is_comp_pred) {
     if (mbmi->compound_idx == 0) {
-      mbmi->comp_group_idx = 0;
-      const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
-      rd_stats->rate += x->comp_group_idx_cost[comp_group_idx_ctx][0];
+      int masked_compound_used = is_any_masked_compound_used(bsize);
+      masked_compound_used = masked_compound_used && cm->allow_masked_compound;
+
+      if (masked_compound_used) {
+        const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
+        rd_stats->rate += x->comp_group_idx_cost[comp_group_idx_ctx][0];
+      }
 
       const int comp_index_ctx = get_comp_index_context(cm, xd);
-      rd_stats->rate += x->comp_idx_cost[comp_index_ctx][0];
+      rd_stats->rate += x->comp_idx_cost[comp_index_ctx][mbmi->compound_idx];
     }
   }
 #endif  // CONFIG_JNT_COMP
@@ -8316,26 +8318,27 @@
       best_rd_cur = INT64_MAX;
       mbmi->interinter_compound_type = cur_type;
 #if CONFIG_JNT_COMP
-      const int ctx_comp_group_idx = get_comp_group_idx_context(xd);
-      if (cur_type == COMPOUND_AVERAGE) {
-        mbmi->comp_group_idx = 0;
-        rs2 = x->comp_group_idx_cost[ctx_comp_group_idx][0];
+      const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
+      int masked_type_cost = 0;
+      if (masked_compound_used) {
+        if (cur_type == COMPOUND_AVERAGE) {
+          masked_type_cost += x->comp_group_idx_cost[comp_group_idx_ctx][0];
 
-        const int comp_index_ctx = get_comp_index_context(cm, xd);
-        rs2 += x->comp_idx_cost[comp_index_ctx][1];
-      } else {
-        mbmi->comp_group_idx = 1;
-        rs2 = x->comp_group_idx_cost[ctx_comp_group_idx][1];
+          const int comp_index_ctx = get_comp_index_context(cm, xd);
+          masked_type_cost += x->comp_idx_cost[comp_index_ctx][1];
+        } else {
+          masked_type_cost += x->comp_group_idx_cost[comp_group_idx_ctx][1];
 
-        int masked_type_cost = 0;
-        if (masked_compound_used) {
-          if (is_interinter_compound_used(COMPOUND_WEDGE, bsize))
-            masked_type_cost +=
-                x->compound_type_cost[bsize]
-                                     [mbmi->interinter_compound_type - 1];
+          masked_type_cost +=
+              x->compound_type_cost[bsize][mbmi->interinter_compound_type - 1];
         }
-        rs2 += masked_type_cost;
+      } else {
+        const int comp_index_ctx = get_comp_index_context(cm, xd);
+        masked_type_cost += x->comp_idx_cost[comp_index_ctx][1];
       }
+      rs2 = av1_cost_literal(get_interinter_compound_type_bits(
+                bsize, mbmi->interinter_compound_type)) +
+            masked_type_cost;
 #else
       int masked_type_cost = 0;
       if (masked_compound_used) {