JNT_COMP: divide compound modes into two groups

Divide compound inter prediction modes into two groups:
Group A: jnt_comp, compound_average
Group B: interintra, compound_segment, wedge

Change-Id: I1142da2e3dfadf382d6b8183a87bde95119cf1b7
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index eb434eb..5e5fa29 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -367,6 +367,7 @@
 
 #if CONFIG_JNT_COMP
   int compound_idx;
+  int comp_group_idx;
 #endif
 } MB_MODE_INFO;
 
diff --git a/av1/common/entropymode.c b/av1/common/entropymode.c
index 49a4501..c0368d4 100644
--- a/av1/common/entropymode.c
+++ b/av1/common/entropymode.c
@@ -1473,9 +1473,20 @@
       { AOM_ICDF(16384), AOM_ICDF(32768), 0 },
       { AOM_ICDF(8192), AOM_ICDF(32768), 0 },
     };
+
+static const aom_cdf_prob
+    default_comp_group_idx_cdfs[COMP_GROUP_IDX_CONTEXTS][CDF_SIZE(2)] = {
+      { AOM_ICDF(24576), AOM_ICDF(32768), 0 },  // context 0
+      { AOM_ICDF(16384), AOM_ICDF(32768), 0 },  // context 1
+      { AOM_ICDF(8192), AOM_ICDF(32768), 0 },   // context 2
+    };
 static const aom_prob default_compound_idx_probs[COMP_INDEX_CONTEXTS] = {
   192, 128, 64, 192, 128, 64
 };
+
+static const aom_prob default_comp_group_idx_probs[COMP_GROUP_IDX_CONTEXTS] = {
+  192, 128, 64  // one entry per comp_group_idx context, in context order
+};
 #endif  // CONFIG_JNT_COMP
 
 #if CONFIG_FILTER_INTRA
@@ -2929,7 +2940,9 @@
   av1_copy(fc->txfm_partition_cdf, default_txfm_partition_cdf);
 #if CONFIG_JNT_COMP
   av1_copy(fc->compound_index_cdf, default_compound_idx_cdfs);
+  av1_copy(fc->comp_group_idx_cdf, default_comp_group_idx_cdfs);
   av1_copy(fc->compound_index_probs, default_compound_idx_probs);
+  av1_copy(fc->comp_group_idx_probs, default_comp_group_idx_probs);
 #endif  // CONFIG_JNT_COMP
   av1_copy(fc->newmv_cdf, default_newmv_cdf);
   av1_copy(fc->zeromv_cdf, default_zeromv_cdf);
@@ -3044,6 +3057,9 @@
   for (i = 0; i < COMP_INDEX_CONTEXTS; ++i)
     fc->compound_index_probs[i] = av1_mode_mv_merge_probs(
         pre_fc->compound_index_probs[i], counts->compound_index[i]);
+  for (i = 0; i < COMP_GROUP_IDX_CONTEXTS; ++i)
+    fc->comp_group_idx_probs[i] = av1_mode_mv_merge_probs(
+        pre_fc->comp_group_idx_probs[i], counts->comp_group_idx[i]);
 #endif  // CONFIG_JNT_COMP
 }
 
diff --git a/av1/common/entropymode.h b/av1/common/entropymode.h
index ec17d1c..adf461d 100644
--- a/av1/common/entropymode.h
+++ b/av1/common/entropymode.h
@@ -245,7 +245,9 @@
   aom_cdf_prob txfm_partition_cdf[TXFM_PARTITION_CONTEXTS][CDF_SIZE(2)];
 #if CONFIG_JNT_COMP
   aom_cdf_prob compound_index_cdf[COMP_INDEX_CONTEXTS][CDF_SIZE(2)];
+  aom_cdf_prob comp_group_idx_cdf[COMP_GROUP_IDX_CONTEXTS][CDF_SIZE(2)];
   aom_prob compound_index_probs[COMP_INDEX_CONTEXTS];
+  aom_prob comp_group_idx_probs[COMP_GROUP_IDX_CONTEXTS];
 #endif  // CONFIG_JNT_COMP
 #if CONFIG_EXT_SKIP
   aom_cdf_prob skip_mode_cdfs[SKIP_CONTEXTS][CDF_SIZE(2)];
@@ -395,6 +397,7 @@
   unsigned int skip[SKIP_CONTEXTS][2];
 #if CONFIG_JNT_COMP
   unsigned int compound_index[COMP_INDEX_CONTEXTS][2];
+  unsigned int comp_group_idx[COMP_GROUP_IDX_CONTEXTS][2];
 #endif  // CONFIG_JNT_COMP
   unsigned int delta_q[DELTA_Q_PROBS][2];
 #if CONFIG_EXT_DELTA_Q
diff --git a/av1/common/enums.h b/av1/common/enums.h
index ddcb342..667cc9f 100644
--- a/av1/common/enums.h
+++ b/av1/common/enums.h
@@ -592,6 +592,7 @@
 
 #if CONFIG_JNT_COMP
 #define COMP_INDEX_CONTEXTS 6
+#define COMP_GROUP_IDX_CONTEXTS 3
 #endif  // CONFIG_JNT_COMP
 
 #define NMV_CONTEXTS 3
diff --git a/av1/common/pred_common.h b/av1/common/pred_common.h
index 5eaf1fa..8047769 100644
--- a/av1/common/pred_common.h
+++ b/av1/common/pred_common.h
@@ -122,6 +122,29 @@
 
   return above_ctx + left_ctx + 3 * offset;
 }
+
+static INLINE int get_comp_group_idx_context(const MACROBLOCKD *xd) {
+  const MODE_INFO *const above_mi = xd->above_mi;
+  const MODE_INFO *const left_mi = xd->left_mi;
+  int above_ctx = 0, left_ctx = 0;  // stay 0 if neither case below applies
+  // Context for comp_group_idx: sum of the above and left contributions.
+  if (above_mi) {
+    const MB_MODE_INFO *above_mbmi = &above_mi->mbmi;
+    if (has_second_ref(above_mbmi))
+      above_ctx = above_mbmi->comp_group_idx;  // neighbor's own group index
+    else if (above_mbmi->ref_frame[0] == ALTREF_FRAME)
+      above_ctx = 1;  // no second ref, but first ref is ALTREF_FRAME
+  }
+  if (left_mi) {
+    const MB_MODE_INFO *left_mbmi = &left_mi->mbmi;
+    if (has_second_ref(left_mbmi))
+      left_ctx = left_mbmi->comp_group_idx;  // neighbor's own group index
+    else if (left_mbmi->ref_frame[0] == ALTREF_FRAME)
+      left_ctx = 1;  // no second ref, but first ref is ALTREF_FRAME
+  }
+
+  return above_ctx + left_ctx;
+}
 #endif  // CONFIG_JNT_COMP
 
 static INLINE aom_cdf_prob *av1_get_pred_cdf_seg_id(
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index 577640c..5917301 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -1856,15 +1856,9 @@
   read_ref_frames(cm, xd, r, mbmi->segment_id, mbmi->ref_frame);
   const int is_compound = has_second_ref(mbmi);
 
-#if CONFIG_JNT_COMP
-  if (is_compound) {
-    const int comp_index_ctx = get_comp_index_context(cm, xd);
-    mbmi->compound_idx = aom_read_symbol(
-        r, ec_ctx->compound_index_cdf[comp_index_ctx], 2, ACCT_STR);
-    if (xd->counts)
-      ++xd->counts->compound_index[comp_index_ctx][mbmi->compound_idx];
-  }
-#endif  // CONFIG_JNT_COMP
+#if CONFIG_EXT_SKIP
+// TODO(zoeliu): To work with JNT_COMP
+#endif  // CONFIG_EXT_SKIP
 
   for (int ref = 0; ref < 1 + is_compound; ++ref) {
     MV_REFERENCE_FRAME frame = mbmi->ref_frame[ref];
@@ -2188,6 +2182,68 @@
     mbmi->motion_mode = read_motion_mode(cm, xd, mi, r);
 #endif  // CONFIG_EXT_WARPED_MOTION
 
+#if CONFIG_JNT_COMP
+  // Defaults; they are only overwritten below for blocks with a second ref.
+  mbmi->comp_group_idx = 1;
+  mbmi->compound_idx = 1;
+  mbmi->interinter_compound_type = COMPOUND_AVERAGE;
+
+  // Read the index that selects the compound inter prediction mode group.
+  if (has_second_ref(mbmi)) {
+    const int ctx_comp_group_idx = get_comp_group_idx_context(xd);
+    mbmi->comp_group_idx = aom_read_symbol(
+        r, ec_ctx->comp_group_idx_cdf[ctx_comp_group_idx], 2, ACCT_STR);
+    if (xd->counts)
+      ++xd->counts->comp_group_idx[ctx_comp_group_idx][mbmi->comp_group_idx];
+
+    if (mbmi->comp_group_idx == 0) {
+      const int comp_index_ctx = get_comp_index_context(cm, xd);
+      mbmi->compound_idx = aom_read_symbol(
+          r, ec_ctx->compound_index_cdf[comp_index_ctx], 2, ACCT_STR);
+
+      if (xd->counts)
+        ++xd->counts->compound_index[comp_index_ctx][mbmi->compound_idx];
+
+      if (mbmi->compound_idx) mbmi->interinter_compound_type = COMPOUND_AVERAGE;
+    } else {
+      // Group B (compound_segment, wedge); stays average if no type is read.
+      mbmi->interinter_compound_type = COMPOUND_AVERAGE;
+      if (cm->reference_mode != SINGLE_REFERENCE &&
+          is_inter_compound_mode(mbmi->mode) &&
+          mbmi->motion_mode == SIMPLE_TRANSLATION
+#if CONFIG_EXT_SKIP
+          && !mbmi->skip_mode
+#endif  // CONFIG_EXT_SKIP
+          && mbmi->comp_group_idx) {
+        if (is_any_masked_compound_used(bsize)) {
+          if (cm->allow_masked_compound) {
+            if (is_interinter_compound_used(COMPOUND_WEDGE, bsize))
+              mbmi->interinter_compound_type =
+                  aom_read_symbol(r, ec_ctx->compound_type_cdf[bsize],
+                                  COMPOUND_TYPES, ACCT_STR);
+            else
+              mbmi->interinter_compound_type = COMPOUND_SEG;
+
+            if (mbmi->interinter_compound_type == COMPOUND_WEDGE) {
+              assert(is_interinter_compound_used(COMPOUND_WEDGE, bsize));
+              mbmi->wedge_index =
+                  aom_read_literal(r, get_wedge_bits_lookup(bsize), ACCT_STR);
+              mbmi->wedge_sign = aom_read_bit(r, ACCT_STR);
+            }
+            if (mbmi->interinter_compound_type == COMPOUND_SEG) {
+              mbmi->mask_type =
+                  aom_read_literal(r, MAX_SEG_MASK_BITS, ACCT_STR);
+            }
+          }
+        }
+
+        if (xd->counts)
+          xd->counts
+              ->compound_interinter[bsize][mbmi->interinter_compound_type]++;
+      }
+    }
+  }
+#else  // CONFIG_JNT_COMP
   mbmi->interinter_compound_type = COMPOUND_AVERAGE;
   if (cm->reference_mode != SINGLE_REFERENCE &&
       is_inter_compound_mode(mbmi->mode) &&
@@ -2225,6 +2281,7 @@
     if (xd->counts)
       xd->counts->compound_interinter[bsize][mbmi->interinter_compound_type]++;
   }
+#endif  // CONFIG_JNT_COMP
 
   read_mb_interp_filter(cm, xd, mbmi, r);
 
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index cb4d9b4..4f05d99 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1372,14 +1372,6 @@
     int16_t mode_ctx;
     write_ref_frames(cm, xd, w);
 
-#if CONFIG_JNT_COMP
-    if (has_second_ref(mbmi)) {
-      const int comp_index_ctx = get_comp_index_context(cm, xd);
-      aom_write_symbol(w, mbmi->compound_idx,
-                       ec_ctx->compound_index_cdf[comp_index_ctx], 2);
-    }
-#endif  // CONFIG_JNT_COMP
-
     if (is_compound)
       mode_ctx = mbmi_ext->compound_mode_context[mbmi->ref_frame[0]];
     else
@@ -1454,16 +1446,53 @@
 
     if (mbmi->ref_frame[1] != INTRA_FRAME) write_motion_mode(cm, xd, mi, w);
 
+#if CONFIG_JNT_COMP
+    // First write idx to indicate current compound inter prediction mode group
+    // Group A (0): jnt_comp, compound_average
+    // Group B (1): interintra, compound_segment, wedge
+    if (has_second_ref(mbmi)) {
+      const int ctx_comp_group_idx = get_comp_group_idx_context(xd);
+      aom_write_symbol(w, mbmi->comp_group_idx,
+                       ec_ctx->comp_group_idx_cdf[ctx_comp_group_idx], 2);
+
+      if (mbmi->comp_group_idx == 0) {
+        if (mbmi->compound_idx)
+          assert(mbmi->interinter_compound_type == COMPOUND_AVERAGE);
+
+        const int comp_index_ctx = get_comp_index_context(cm, xd);
+        aom_write_symbol(w, mbmi->compound_idx,
+                         ec_ctx->compound_index_cdf[comp_index_ctx], 2);
+      } else {
+        assert(mbmi->interinter_compound_type == COMPOUND_WEDGE ||
+               mbmi->interinter_compound_type == COMPOUND_SEG);
+        // compound_segment, wedge
+        if (cpi->common.reference_mode != SINGLE_REFERENCE &&
+            is_inter_compound_mode(mbmi->mode) &&
+            mbmi->motion_mode == SIMPLE_TRANSLATION &&
+            is_any_masked_compound_used(bsize) && cm->allow_masked_compound &&
+            mbmi->comp_group_idx) {
+          if (is_interinter_compound_used(COMPOUND_WEDGE, bsize))
+            aom_write_symbol(w, mbmi->interinter_compound_type,
+                             ec_ctx->compound_type_cdf[bsize], COMPOUND_TYPES);
+
+          if (is_interinter_compound_used(COMPOUND_WEDGE, bsize) &&
+              mbmi->interinter_compound_type == COMPOUND_WEDGE) {
+            aom_write_literal(w, mbmi->wedge_index,
+                              get_wedge_bits_lookup(bsize));
+            aom_write_bit(w, mbmi->wedge_sign);
+          }
+          if (mbmi->interinter_compound_type == COMPOUND_SEG) {
+            aom_write_literal(w, mbmi->mask_type, MAX_SEG_MASK_BITS);
+          }
+        }
+      }
+    }
+#else   // CONFIG_JNT_COMP
     if (cpi->common.reference_mode != SINGLE_REFERENCE &&
         is_inter_compound_mode(mbmi->mode) &&
         mbmi->motion_mode == SIMPLE_TRANSLATION &&
         is_any_masked_compound_used(bsize)) {
-#if CONFIG_JNT_COMP
-      if (cm->allow_masked_compound && mbmi->compound_idx)
-#else
-      if (cm->allow_masked_compound)
-#endif  // CONFIG_JNT_COMP
-      {
+      if (cm->allow_masked_compound) {
         if (!is_interinter_compound_used(COMPOUND_WEDGE, bsize))
           aom_write_bit(w, mbmi->interinter_compound_type == COMPOUND_AVERAGE);
         else
@@ -1479,6 +1508,7 @@
         }
       }
     }
+#endif  // CONFIG_JNT_COMP
 
     write_mb_interp_filter(cpi, xd, w);
   }
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index a22217f..be77347 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -641,9 +641,6 @@
   const int x_mis = AOMMIN(bw, cm->mi_cols - mi_col);
   const int y_mis = AOMMIN(bh, cm->mi_rows - mi_row);
   av1_copy_frame_mvs(cm, mi, mi_row, mi_col, x_mis, y_mis);
-
-#if CONFIG_JNT_COMP
-#endif  // CONFIG_JNT_COMP
 }
 
 #if NC_MODE_INFO
@@ -1258,9 +1255,30 @@
           }
         }
 
+#if CONFIG_JNT_COMP
+        if (has_second_ref(mbmi)) {
+          const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
+          ++counts->comp_group_idx[comp_group_idx_ctx][mbmi->comp_group_idx];
+          if (allow_update_cdf)
+            update_cdf(fc->comp_group_idx_cdf[comp_group_idx_ctx],
+                       mbmi->comp_group_idx, 2);
+
+          if (mbmi->comp_group_idx == 0) {
+            const int comp_index_ctx = get_comp_index_context(cm, xd);
+            ++counts->compound_index[comp_index_ctx][mbmi->compound_idx];
+            if (allow_update_cdf)
+              update_cdf(fc->compound_index_cdf[comp_index_ctx],
+                         mbmi->compound_idx, 2);
+          }
+        }
+#endif  // CONFIG_JNT_COMP
+
         if (cm->reference_mode != SINGLE_REFERENCE &&
-            is_inter_compound_mode(mbmi->mode) &&
-            mbmi->motion_mode == SIMPLE_TRANSLATION) {
+            is_inter_compound_mode(mbmi->mode)
+#if CONFIG_JNT_COMP
+            && mbmi->comp_group_idx
+#endif  // CONFIG_JNT_COMP
+            && mbmi->motion_mode == SIMPLE_TRANSLATION) {
           if (is_interinter_compound_used(COMPOUND_WEDGE, bsize)) {
             counts
                 ->compound_interinter[bsize][mbmi->interinter_compound_type]++;
@@ -1319,16 +1337,6 @@
           }
         }
       }
-
-#if CONFIG_JNT_COMP
-      if (has_second_ref(mbmi)) {
-        const int comp_index_ctx = get_comp_index_context(cm, xd);
-        ++counts->compound_index[comp_index_ctx][mbmi->compound_idx];
-        if (allow_update_cdf)
-          update_cdf(fc->compound_index_cdf[comp_index_ctx], mbmi->compound_idx,
-                     2);
-      }
-#endif  // CONFIG_JNT_COMP
     }
   }
 }
@@ -1484,6 +1492,15 @@
       mbmi->current_delta_lf_from_base = xd->prev_delta_lf_from_base;
     }
 #endif
+#if CONFIG_JNT_COMP
+    if (has_second_ref(mbmi)) {
+      if (mbmi->compound_idx == 0 ||
+          mbmi->interinter_compound_type == COMPOUND_AVERAGE)
+        mbmi->comp_group_idx = 0;
+      else
+        mbmi->comp_group_idx = 1;
+    }
+#endif
     update_stats(&cpi->common, tile_data, td, mi_row, mi_col);
   }
 }