JNT_COMP: remove COMPOUND_AVERAGE from compound_type_cdf

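Remove COMPOUND_AVERAGE from compound_type_cdf, since it is now signaled
through comp_group_idx (comp_group_idx == 0) together with compound_idx.
The compound_type symbol is therefore coded as interinter_compound_type - 1
over an alphabet of COMPOUND_TYPES - 1. The COMPOUND_AVERAGE enum value
itself is still used elsewhere, so it is kept.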

Change-Id: Ie0d460aabf9252e80eb4130cfef9aaf0efc3969d
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 4f05d99..c669a4c 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -70,7 +70,11 @@
 }
 
 static struct av1_token interintra_mode_encodings[INTERINTRA_MODES];
+#if CONFIG_JNT_COMP
+static struct av1_token compound_type_encodings[COMPOUND_TYPES - 1];
+#else
 static struct av1_token compound_type_encodings[COMPOUND_TYPES];
+#endif  // CONFIG_JNT_COMP
 #if CONFIG_LOOP_RESTORATION
 static void loop_restoration_write_sb_coeffs(const AV1_COMMON *const cm,
                                              MACROBLOCKD *xd,
@@ -1469,11 +1473,11 @@
         if (cpi->common.reference_mode != SINGLE_REFERENCE &&
             is_inter_compound_mode(mbmi->mode) &&
             mbmi->motion_mode == SIMPLE_TRANSLATION &&
-            is_any_masked_compound_used(bsize) && cm->allow_masked_compound &&
-            mbmi->comp_group_idx) {
+            is_any_masked_compound_used(bsize) && cm->allow_masked_compound) {
           if (is_interinter_compound_used(COMPOUND_WEDGE, bsize))
-            aom_write_symbol(w, mbmi->interinter_compound_type,
-                             ec_ctx->compound_type_cdf[bsize], COMPOUND_TYPES);
+            aom_write_symbol(w, mbmi->interinter_compound_type - 1,
+                             ec_ctx->compound_type_cdf[bsize],
+                             COMPOUND_TYPES - 1);
 
           if (is_interinter_compound_used(COMPOUND_WEDGE, bsize) &&
               mbmi->interinter_compound_type == COMPOUND_WEDGE) {
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 8de0a78..ec99b38 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -271,7 +271,11 @@
   int drl_mode_cost0[DRL_MODE_CONTEXTS][2];
 
   int inter_compound_mode_cost[INTER_MODE_CONTEXTS][INTER_COMPOUND_MODES];
+#if CONFIG_JNT_COMP
+  int compound_type_cost[BLOCK_SIZES_ALL][COMPOUND_TYPES - 1];
+#else
   int compound_type_cost[BLOCK_SIZES_ALL][COMPOUND_TYPES];
+#endif  // CONFIG_JNT_COMP
   int interintra_cost[BLOCK_SIZE_GROUPS][2];
   int wedge_interintra_cost[BLOCK_SIZES_ALL][2];
   int interintra_mode_cost[BLOCK_SIZE_GROUPS][INTERINTRA_MODES];
@@ -347,6 +351,7 @@
 #endif  // CONFIG_DIST_8X8
 #if CONFIG_JNT_COMP
   int comp_idx_cost[COMP_INDEX_CONTEXTS][2];
+  int comp_group_idx_cost[COMP_GROUP_IDX_CONTEXTS][2];
 #endif  // CONFIG_JNT_COMP
 };
 
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index be77347..f3575b1 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -1271,14 +1271,24 @@
                          mbmi->compound_idx, 2);
           }
         }
-#endif  // CONFIG_JNT_COMP
 
         if (cm->reference_mode != SINGLE_REFERENCE &&
-            is_inter_compound_mode(mbmi->mode)
-#if CONFIG_JNT_COMP
-            && mbmi->comp_group_idx
-#endif  // CONFIG_JNT_COMP
-            && mbmi->motion_mode == SIMPLE_TRANSLATION) {
+            is_inter_compound_mode(mbmi->mode) && mbmi->comp_group_idx &&
+            mbmi->motion_mode == SIMPLE_TRANSLATION) {
+          if (is_interinter_compound_used(COMPOUND_WEDGE, bsize)) {
+            counts->compound_interinter[bsize]
+                                       [mbmi->interinter_compound_type - 1]++;
+            if (allow_update_cdf)
+              update_cdf(fc->compound_type_cdf[bsize],
+                         mbmi->interinter_compound_type - 1,
+                         COMPOUND_TYPES - 1);
+          }
+        }
+#else   // CONFIG_JNT_COMP
+
+        if (cm->reference_mode != SINGLE_REFERENCE &&
+            is_inter_compound_mode(mbmi->mode) &&
+            mbmi->motion_mode == SIMPLE_TRANSLATION) {
           if (is_interinter_compound_used(COMPOUND_WEDGE, bsize)) {
             counts
                 ->compound_interinter[bsize][mbmi->interinter_compound_type]++;
@@ -1287,6 +1297,7 @@
                          mbmi->interinter_compound_type, COMPOUND_TYPES);
           }
         }
+#endif  // CONFIG_JNT_COMP
       }
     }
 
diff --git a/av1/encoder/rd.c b/av1/encoder/rd.c
index 8abee95..3bd3b42 100644
--- a/av1/encoder/rd.c
+++ b/av1/encoder/rd.c
@@ -284,6 +284,10 @@
       av1_cost_tokens_from_cdf(x->comp_idx_cost[i], fc->compound_index_cdf[i],
                                NULL);
     }
+    for (i = 0; i < COMP_GROUP_IDX_CONTEXTS; ++i) {
+      av1_cost_tokens_from_cdf(x->comp_group_idx_cost[i],
+                               fc->comp_group_idx_cdf[i], NULL);
+    }
 #endif  // CONFIG_JNT_COMP
   }
 }
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index d7e722e..ca14064 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -5661,6 +5661,7 @@
   }
 }
 
+#if !CONFIG_JNT_COMP
 static int get_interinter_compound_type_bits(BLOCK_SIZE bsize,
                                              COMPOUND_TYPE comp_type) {
   (void)bsize;
@@ -5671,6 +5672,7 @@
     default: assert(0); return 0;
   }
 }
+#endif
 
 typedef struct {
   int eobs;
@@ -8139,8 +8141,14 @@
 
 #if CONFIG_JNT_COMP
   if (is_comp_pred) {
-    const int comp_index_ctx = get_comp_index_context(cm, xd);
-    rd_stats->rate += x->comp_idx_cost[comp_index_ctx][mbmi->compound_idx];
+    if (mbmi->compound_idx == 0) {
+      mbmi->comp_group_idx = 0;
+      const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
+      rd_stats->rate += x->comp_group_idx_cost[comp_group_idx_ctx][0];
+
+      const int comp_index_ctx = get_comp_index_context(cm, xd);
+      rd_stats->rate += x->comp_idx_cost[comp_index_ctx][0];
+    }
   }
 #endif  // CONFIG_JNT_COMP
 
@@ -8307,6 +8315,28 @@
       tmp_rate_mv = rate_mv;
       best_rd_cur = INT64_MAX;
       mbmi->interinter_compound_type = cur_type;
+#if CONFIG_JNT_COMP
+      const int ctx_comp_group_idx = get_comp_group_idx_context(xd);
+      if (cur_type == COMPOUND_AVERAGE) {
+        mbmi->comp_group_idx = 0;
+        rs2 = x->comp_group_idx_cost[ctx_comp_group_idx][0];
+
+        const int comp_index_ctx = get_comp_index_context(cm, xd);
+        rs2 += x->comp_idx_cost[comp_index_ctx][1];
+      } else {
+        mbmi->comp_group_idx = 1;
+        rs2 = x->comp_group_idx_cost[ctx_comp_group_idx][1];
+
+        int masked_type_cost = 0;
+        if (masked_compound_used) {
+          if (is_interinter_compound_used(COMPOUND_WEDGE, bsize))
+            masked_type_cost +=
+                x->compound_type_cost[bsize]
+                                     [mbmi->interinter_compound_type - 1];
+        }
+        rs2 += masked_type_cost;
+      }
+#else
       int masked_type_cost = 0;
       if (masked_compound_used) {
         if (!is_interinter_compound_used(COMPOUND_WEDGE, bsize))
@@ -8318,6 +8348,7 @@
       rs2 = av1_cost_literal(get_interinter_compound_type_bits(
                 bsize, mbmi->interinter_compound_type)) +
             masked_type_cost;
+#endif  // CONFIG_JNT_COMP
 
       switch (cur_type) {
         case COMPOUND_AVERAGE: