Change comp_group_idx context and skip sending comp_group_idx

Extend context model for comp_group_idx.
Skip sending comp_group_idx when masked compound is not allowed.

Change-Id: Ia7ae53958c9e1c8fe07be4b14a425d9b8648082d
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index ca14064..1ec0cca 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -5661,7 +5661,6 @@
   }
 }
 
-#if !CONFIG_JNT_COMP
 static int get_interinter_compound_type_bits(BLOCK_SIZE bsize,
                                              COMPOUND_TYPE comp_type) {
   (void)bsize;
@@ -5672,7 +5671,6 @@
     default: assert(0); return 0;
   }
 }
-#endif
 
 typedef struct {
   int eobs;
@@ -8142,12 +8140,16 @@
 #if CONFIG_JNT_COMP
   if (is_comp_pred) {
     if (mbmi->compound_idx == 0) {
-      mbmi->comp_group_idx = 0;
-      const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
-      rd_stats->rate += x->comp_group_idx_cost[comp_group_idx_ctx][0];
+      int masked_compound_used = is_any_masked_compound_used(bsize);
+      masked_compound_used = masked_compound_used && cm->allow_masked_compound;
+
+      if (masked_compound_used) {
+        const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
+        rd_stats->rate += x->comp_group_idx_cost[comp_group_idx_ctx][0];
+      }
 
       const int comp_index_ctx = get_comp_index_context(cm, xd);
-      rd_stats->rate += x->comp_idx_cost[comp_index_ctx][0];
+      rd_stats->rate += x->comp_idx_cost[comp_index_ctx][mbmi->compound_idx];
     }
   }
 #endif  // CONFIG_JNT_COMP
@@ -8316,26 +8318,27 @@
       best_rd_cur = INT64_MAX;
       mbmi->interinter_compound_type = cur_type;
 #if CONFIG_JNT_COMP
-      const int ctx_comp_group_idx = get_comp_group_idx_context(xd);
-      if (cur_type == COMPOUND_AVERAGE) {
-        mbmi->comp_group_idx = 0;
-        rs2 = x->comp_group_idx_cost[ctx_comp_group_idx][0];
+      const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
+      int masked_type_cost = 0;
+      if (masked_compound_used) {
+        if (cur_type == COMPOUND_AVERAGE) {
+          masked_type_cost += x->comp_group_idx_cost[comp_group_idx_ctx][0];
 
-        const int comp_index_ctx = get_comp_index_context(cm, xd);
-        rs2 += x->comp_idx_cost[comp_index_ctx][1];
-      } else {
-        mbmi->comp_group_idx = 1;
-        rs2 = x->comp_group_idx_cost[ctx_comp_group_idx][1];
+          const int comp_index_ctx = get_comp_index_context(cm, xd);
+          masked_type_cost += x->comp_idx_cost[comp_index_ctx][1];
+        } else {
+          masked_type_cost += x->comp_group_idx_cost[comp_group_idx_ctx][1];
 
-        int masked_type_cost = 0;
-        if (masked_compound_used) {
-          if (is_interinter_compound_used(COMPOUND_WEDGE, bsize))
-            masked_type_cost +=
-                x->compound_type_cost[bsize]
-                                     [mbmi->interinter_compound_type - 1];
+          masked_type_cost +=
+              x->compound_type_cost[bsize][mbmi->interinter_compound_type - 1];
         }
-        rs2 += masked_type_cost;
+      } else {
+        const int comp_index_ctx = get_comp_index_context(cm, xd);
+        masked_type_cost += x->comp_idx_cost[comp_index_ctx][1];
       }
+      rs2 = av1_cost_literal(get_interinter_compound_type_bits(
+                bsize, mbmi->interinter_compound_type)) +
+            masked_type_cost;
 #else
       int masked_type_cost = 0;
       if (masked_compound_used) {