JNT_COMP: divide compound modes into two groups

Divide compound inter prediction modes into two groups, signaled by a new
comp_group_idx symbol (compound_idx is now coded only within Group A):
Group A: jnt_comp, compound_average
Group B: interintra, compound_segment, wedge

Change-Id: I1142da2e3dfadf382d6b8183a87bde95119cf1b7
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index cb4d9b4..4f05d99 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1372,14 +1372,6 @@
     int16_t mode_ctx;
     write_ref_frames(cm, xd, w);
 
-#if CONFIG_JNT_COMP
-    if (has_second_ref(mbmi)) {
-      const int comp_index_ctx = get_comp_index_context(cm, xd);
-      aom_write_symbol(w, mbmi->compound_idx,
-                       ec_ctx->compound_index_cdf[comp_index_ctx], 2);
-    }
-#endif  // CONFIG_JNT_COMP
-
     if (is_compound)
       mode_ctx = mbmi_ext->compound_mode_context[mbmi->ref_frame[0]];
     else
@@ -1454,16 +1446,53 @@
 
     if (mbmi->ref_frame[1] != INTRA_FRAME) write_motion_mode(cm, xd, mi, w);
 
+#if CONFIG_JNT_COMP
+    // First write idx to indicate current compound inter prediction mode group
+    // Group A (0): jnt_comp, compound_average
+    // Group B (1): interintra, compound_segment, wedge
+    if (has_second_ref(mbmi)) {
+      const int ctx_comp_group_idx = get_comp_group_idx_context(xd);
+      aom_write_symbol(w, mbmi->comp_group_idx,
+                       ec_ctx->comp_group_idx_cdf[ctx_comp_group_idx], 2);
+
+      if (mbmi->comp_group_idx == 0) {
+        if (mbmi->compound_idx)
+          assert(mbmi->interinter_compound_type == COMPOUND_AVERAGE);
+
+        const int comp_index_ctx = get_comp_index_context(cm, xd);
+        aom_write_symbol(w, mbmi->compound_idx,
+                         ec_ctx->compound_index_cdf[comp_index_ctx], 2);
+      } else {
+        assert(mbmi->interinter_compound_type == COMPOUND_WEDGE ||
+               mbmi->interinter_compound_type == COMPOUND_SEG);
+        // compound_segment, wedge
+        if (cpi->common.reference_mode != SINGLE_REFERENCE &&
+            is_inter_compound_mode(mbmi->mode) &&
+            mbmi->motion_mode == SIMPLE_TRANSLATION &&
+            is_any_masked_compound_used(bsize) && cm->allow_masked_compound &&
+            mbmi->comp_group_idx) {
+          if (is_interinter_compound_used(COMPOUND_WEDGE, bsize))
+            aom_write_symbol(w, mbmi->interinter_compound_type,
+                             ec_ctx->compound_type_cdf[bsize], COMPOUND_TYPES);
+
+          if (is_interinter_compound_used(COMPOUND_WEDGE, bsize) &&
+              mbmi->interinter_compound_type == COMPOUND_WEDGE) {
+            aom_write_literal(w, mbmi->wedge_index,
+                              get_wedge_bits_lookup(bsize));
+            aom_write_bit(w, mbmi->wedge_sign);
+          }
+          if (mbmi->interinter_compound_type == COMPOUND_SEG) {
+            aom_write_literal(w, mbmi->mask_type, MAX_SEG_MASK_BITS);
+          }
+        }
+      }
+    }
+#else   // CONFIG_JNT_COMP
     if (cpi->common.reference_mode != SINGLE_REFERENCE &&
         is_inter_compound_mode(mbmi->mode) &&
         mbmi->motion_mode == SIMPLE_TRANSLATION &&
         is_any_masked_compound_used(bsize)) {
-#if CONFIG_JNT_COMP
-      if (cm->allow_masked_compound && mbmi->compound_idx)
-#else
-      if (cm->allow_masked_compound)
-#endif  // CONFIG_JNT_COMP
-      {
+      if (cm->allow_masked_compound) {
         if (!is_interinter_compound_used(COMPOUND_WEDGE, bsize))
           aom_write_bit(w, mbmi->interinter_compound_type == COMPOUND_AVERAGE);
         else
@@ -1479,6 +1508,7 @@
         }
       }
     }
+#endif  // CONFIG_JNT_COMP
 
     write_mb_interp_filter(cpi, xd, w);
   }