Add ability to have multiple compound modes for interinter

This is currently just a refactor and creates no change in performance.
It allows new compound types to be added easily in the future to
facilitate experiments with segmentation masks.

Change-Id: If48fed216d482454fabb45a304b4220ada0dbdee
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index 378708f..84de8e4 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -280,7 +280,7 @@
   int use_wedge_interintra;
   int interintra_wedge_index;
   int interintra_wedge_sign;
-  int use_wedge_interinter;
+  COMPOUND_TYPE interinter_compound;
   int interinter_wedge_index;
   int interinter_wedge_sign;
 #endif  // CONFIG_EXT_INTER
diff --git a/av1/common/entropymode.c b/av1/common/entropymode.c
index 405c983..b9dc812 100644
--- a/av1/common/entropymode.c
+++ b/av1/common/entropymode.c
@@ -291,6 +291,18 @@
       { 25, 29, 50, 192, 64, 192, 128, 180, 180 },   // 6 = two intra neighbours
     };
 
+static const aom_prob default_compound_type_probs[BLOCK_SIZES]
+                                                 [COMPOUND_TYPES - 1] = {
+                                                   { 208 }, { 208 }, { 208 },
+                                                   { 208 }, { 208 }, { 208 },
+                                                   { 216 }, { 216 }, { 216 },
+                                                   { 224 }, { 224 }, { 240 },
+                                                   { 240 },
+#if CONFIG_EXT_PARTITION
+                                                   { 255 }, { 255 }, { 255 },
+#endif  // CONFIG_EXT_PARTITION
+                                                 };
+
 static const aom_prob default_interintra_prob[BLOCK_SIZE_GROUPS] = {
   208, 208, 208, 208,
 };
@@ -309,13 +321,6 @@
   208, 208, 208
 #endif  // CONFIG_EXT_PARTITION
 };
-
-static const aom_prob default_wedge_interinter_prob[BLOCK_SIZES] = {
-  208, 208, 208, 208, 208, 208, 216, 216, 216, 224, 224, 224, 240,
-#if CONFIG_EXT_PARTITION
-  255, 255, 255
-#endif  // CONFIG_EXT_PARTITION
-};
 #endif  // CONFIG_EXT_INTER
 
 // Change this section appropriately once warped motion is supported
@@ -428,6 +433,10 @@
   -INTER_COMPOUND_OFFSET(NEAREST_NEWMV), -INTER_COMPOUND_OFFSET(NEW_NEARESTMV),
   -INTER_COMPOUND_OFFSET(NEAR_NEWMV), -INTER_COMPOUND_OFFSET(NEW_NEARMV)
 };
+
+const aom_tree_index av1_compound_type_tree[TREE_SIZE(COMPOUND_TYPES)] = {
+  -COMPOUND_AVERAGE, -COMPOUND_WEDGE
+};
 /* clang-format on */
 #endif  // CONFIG_EXT_INTER
 
@@ -1470,10 +1479,10 @@
 #endif  // CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION
 #if CONFIG_EXT_INTER
   av1_copy(fc->inter_compound_mode_probs, default_inter_compound_mode_probs);
+  av1_copy(fc->compound_type_prob, default_compound_type_probs);
   av1_copy(fc->interintra_prob, default_interintra_prob);
   av1_copy(fc->interintra_mode_prob, default_interintra_mode_prob);
   av1_copy(fc->wedge_interintra_prob, default_wedge_interintra_prob);
-  av1_copy(fc->wedge_interinter_prob, default_wedge_interinter_prob);
 #endif  // CONFIG_EXT_INTER
 #if CONFIG_SUPERTX
   av1_copy(fc->supertx_prob, default_supertx_prob);
@@ -1676,10 +1685,12 @@
       fc->wedge_interintra_prob[i] = av1_mode_mv_merge_probs(
           pre_fc->wedge_interintra_prob[i], counts->wedge_interintra[i]);
   }
+
   for (i = 0; i < BLOCK_SIZES; ++i) {
     if (is_interinter_wedge_used(i))
-      fc->wedge_interinter_prob[i] = av1_mode_mv_merge_probs(
-          pre_fc->wedge_interinter_prob[i], counts->wedge_interinter[i]);
+      aom_tree_merge_probs(
+          av1_compound_type_tree, pre_fc->compound_type_prob[i],
+          counts->compound_interinter[i], fc->compound_type_prob[i]);
   }
 #endif  // CONFIG_EXT_INTER
 
diff --git a/av1/common/entropymode.h b/av1/common/entropymode.h
index c480613..6478bcf 100644
--- a/av1/common/entropymode.h
+++ b/av1/common/entropymode.h
@@ -108,10 +108,10 @@
 #if CONFIG_EXT_INTER
   aom_prob inter_compound_mode_probs[INTER_MODE_CONTEXTS]
                                     [INTER_COMPOUND_MODES - 1];
+  aom_prob compound_type_prob[BLOCK_SIZES][COMPOUND_TYPES - 1];
   aom_prob interintra_prob[BLOCK_SIZE_GROUPS];
   aom_prob interintra_mode_prob[BLOCK_SIZE_GROUPS][INTERINTRA_MODES - 1];
   aom_prob wedge_interintra_prob[BLOCK_SIZES];
-  aom_prob wedge_interinter_prob[BLOCK_SIZES];
 #endif  // CONFIG_EXT_INTER
 #if CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION
   aom_prob motion_mode_prob[BLOCK_SIZES][MOTION_MODES - 1];
@@ -221,7 +221,7 @@
   unsigned int interintra[BLOCK_SIZE_GROUPS][2];
   unsigned int interintra_mode[BLOCK_SIZE_GROUPS][INTERINTRA_MODES];
   unsigned int wedge_interintra[BLOCK_SIZES][2];
-  unsigned int wedge_interinter[BLOCK_SIZES][2];
+  unsigned int compound_interinter[BLOCK_SIZES][COMPOUND_TYPES];
 #endif  // CONFIG_EXT_INTER
 #if CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION
   unsigned int motion_mode[BLOCK_SIZES][MOTION_MODES];
@@ -313,6 +313,7 @@
     av1_interintra_mode_tree[TREE_SIZE(INTERINTRA_MODES)];
 extern const aom_tree_index
     av1_inter_compound_mode_tree[TREE_SIZE(INTER_COMPOUND_MODES)];
+extern const aom_tree_index av1_compound_type_tree[TREE_SIZE(COMPOUND_TYPES)];
 #endif  // CONFIG_EXT_INTER
 extern const aom_tree_index av1_partition_tree[TREE_SIZE(PARTITION_TYPES)];
 #if CONFIG_EXT_PARTITION_TYPES
diff --git a/av1/common/enums.h b/av1/common/enums.h
index ebf520a..4ed01ac 100644
--- a/av1/common/enums.h
+++ b/av1/common/enums.h
@@ -335,6 +335,11 @@
   INTERINTRA_MODES
 } INTERINTRA_MODE;
 
+typedef enum {
+  COMPOUND_AVERAGE = 0,
+  COMPOUND_WEDGE,
+  COMPOUND_TYPES,
+} COMPOUND_TYPE;
 #endif  // CONFIG_EXT_INTER
 
 #if CONFIG_FILTER_INTRA
diff --git a/av1/common/reconinter.c b/av1/common/reconinter.c
index f0ba7fd..c1a9f9f 100644
--- a/av1/common/reconinter.c
+++ b/av1/common/reconinter.c
@@ -632,7 +632,7 @@
 
 #if CONFIG_EXT_INTER
           if (ref && is_interinter_wedge_used(mi->mbmi.sb_type) &&
-              mi->mbmi.use_wedge_interinter)
+              mi->mbmi.interinter_compound)
             av1_make_masked_inter_predictor(
                 pre, pre_buf->stride, dst, dst_buf->stride, subpel_x, subpel_y,
                 sf, w, h, mi->mbmi.interp_filter, xs, ys,
@@ -698,7 +698,7 @@
 
 #if CONFIG_EXT_INTER
     if (ref && is_interinter_wedge_used(mi->mbmi.sb_type) &&
-        mi->mbmi.use_wedge_interinter)
+        mi->mbmi.interinter_compound)
       av1_make_masked_inter_predictor(pre, pre_buf->stride, dst,
                                       dst_buf->stride, subpel_x, subpel_y, sf,
                                       w, h, mi->mbmi.interp_filter, xs, ys,
@@ -1283,8 +1283,8 @@
   if (is_interintra_pred(mbmi)) {
     mbmi->ref_frame[1] = NONE;
   } else if (has_second_ref(mbmi) && is_interinter_wedge_used(mbmi->sb_type) &&
-             mbmi->use_wedge_interinter) {
-    mbmi->use_wedge_interinter = 0;
+             mbmi->interinter_compound) {
+    mbmi->interinter_compound = COMPOUND_AVERAGE;
     mbmi->ref_frame[1] = NONE;
   }
   return;
@@ -2048,7 +2048,7 @@
   uint8_t *const dst = dst_buf->buf + dst_buf->stride * y + x;
 
   if (is_compound && is_interinter_wedge_used(mbmi->sb_type) &&
-      mbmi->use_wedge_interinter) {
+      mbmi->interinter_compound) {
 #if CONFIG_AOM_HIGHBITDEPTH
     if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
       build_masked_compound_wedge_highbd(
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index 38b7261..4870899 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -4196,7 +4196,9 @@
     if (cm->reference_mode != SINGLE_REFERENCE) {
       for (i = 0; i < BLOCK_SIZES; i++) {
         if (is_interinter_wedge_used(i)) {
-          av1_diff_update_prob(&r, &fc->wedge_interinter_prob[i], ACCT_STR);
+          for (j = 0; j < COMPOUND_TYPES - 1; j++) {
+            av1_diff_update_prob(&r, &fc->compound_type_prob[i][j], ACCT_STR);
+          }
         }
       }
     }
@@ -4289,8 +4291,9 @@
                  sizeof(cm->counts.interintra)));
   assert(!memcmp(cm->counts.wedge_interintra, zero_counts.wedge_interintra,
                  sizeof(cm->counts.wedge_interintra)));
-  assert(!memcmp(cm->counts.wedge_interinter, zero_counts.wedge_interinter,
-                 sizeof(cm->counts.wedge_interinter)));
+  assert(!memcmp(cm->counts.compound_interinter,
+                 zero_counts.compound_interinter,
+                 sizeof(cm->counts.compound_interinter)));
 #endif  // CONFIG_EXT_INTER
 #if CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION
   assert(!memcmp(cm->counts.motion_mode, zero_counts.motion_mode,
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index b560f02..56311e7 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -1761,7 +1761,7 @@
 #endif  // CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION
 
 #if CONFIG_EXT_INTER
-  mbmi->use_wedge_interinter = 0;
+  mbmi->interinter_compound = COMPOUND_AVERAGE;
   if (cm->reference_mode != SINGLE_REFERENCE &&
       is_inter_compound_mode(mbmi->mode) &&
 #if CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION
@@ -1769,11 +1769,11 @@
         mbmi->motion_mode != SIMPLE_TRANSLATION) &&
 #endif  // CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION
       is_interinter_wedge_used(bsize)) {
-    mbmi->use_wedge_interinter =
-        aom_read(r, cm->fc->wedge_interinter_prob[bsize], ACCT_STR);
+    mbmi->interinter_compound = aom_read_tree(
+        r, av1_compound_type_tree, cm->fc->compound_type_prob[bsize], ACCT_STR);
     if (xd->counts)
-      xd->counts->wedge_interinter[bsize][mbmi->use_wedge_interinter]++;
-    if (mbmi->use_wedge_interinter) {
+      xd->counts->compound_interinter[bsize][mbmi->interinter_compound]++;
+    if (mbmi->interinter_compound) {
       mbmi->interinter_wedge_index =
           aom_read_literal(r, get_wedge_bits_lookup(bsize), ACCT_STR);
       mbmi->interinter_wedge_sign = aom_read_bit(r, ACCT_STR);
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 4c8c2fa..7030210 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -111,6 +111,7 @@
 #endif  // CONFIG_EXT_INTRA
 #if CONFIG_EXT_INTER
 static struct av1_token interintra_mode_encodings[INTERINTRA_MODES];
+static struct av1_token compound_type_encodings[COMPOUND_TYPES];
 #endif  // CONFIG_EXT_INTER
 #if CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION
 static struct av1_token motion_mode_encodings[MOTION_MODES];
@@ -155,6 +156,7 @@
 #endif  // CONFIG_EXT_INTRA
 #if CONFIG_EXT_INTER
   av1_tokens_from_tree(interintra_mode_encodings, av1_interintra_mode_tree);
+  av1_tokens_from_tree(compound_type_encodings, av1_compound_type_tree);
 #endif  // CONFIG_EXT_INTER
 #if CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION
   av1_tokens_from_tree(motion_mode_encodings, av1_motion_mode_tree);
@@ -1602,9 +1604,10 @@
           mbmi->motion_mode != SIMPLE_TRANSLATION) &&
 #endif  // CONFIG_MOTION_VAR
         is_interinter_wedge_used(bsize)) {
-      aom_write(w, mbmi->use_wedge_interinter,
-                cm->fc->wedge_interinter_prob[bsize]);
-      if (mbmi->use_wedge_interinter) {
+      av1_write_token(w, av1_compound_type_tree,
+                      cm->fc->compound_type_prob[bsize],
+                      &compound_type_encodings[mbmi->interinter_compound]);
+      if (mbmi->interinter_compound) {
         aom_write_literal(w, mbmi->interinter_wedge_index,
                           get_wedge_bits_lookup(bsize));
         aom_write_bit(w, mbmi->interinter_wedge_sign);
@@ -4133,8 +4136,9 @@
     if (cm->reference_mode != SINGLE_REFERENCE) {
       for (i = 0; i < BLOCK_SIZES; i++)
         if (is_interinter_wedge_used(i))
-          av1_cond_prob_diff_update(header_bc, &fc->wedge_interinter_prob[i],
-                                    cm->counts.wedge_interinter[i], probwt);
+          prob_diff_update(av1_compound_type_tree, fc->compound_type_prob[i],
+                           cm->counts.compound_interinter[i], COMPOUND_TYPES,
+                           probwt, header_bc);
     }
 #endif  // CONFIG_EXT_INTER
 
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 657c86e..8922417 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -1955,7 +1955,7 @@
               mbmi->motion_mode != SIMPLE_TRANSLATION) &&
 #endif  // CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION
             is_interinter_wedge_used(bsize)) {
-          counts->wedge_interinter[bsize][mbmi->use_wedge_interinter]++;
+          counts->compound_interinter[bsize][mbmi->interinter_compound]++;
         }
 #endif  // CONFIG_EXT_INTER
       }
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 49e8b2d..f293a04 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -6955,7 +6955,7 @@
   *compmode_interintra_cost = 0;
   mbmi->use_wedge_interintra = 0;
   *compmode_wedge_cost = 0;
-  mbmi->use_wedge_interinter = 0;
+  mbmi->interinter_compound = COMPOUND_AVERAGE;
 
   // is_comp_interintra_pred implies !is_comp_pred
   assert(!is_comp_interintra_pred || (!is_comp_pred));
@@ -7360,9 +7360,12 @@
     int64_t best_rd_wedge = INT64_MAX;
     int tmp_skip_txfm_sb;
     int64_t tmp_skip_sse_sb;
+    int compound_type_cost[COMPOUND_TYPES];
 
-    rs2 = av1_cost_bit(cm->fc->wedge_interinter_prob[bsize], 0);
-    mbmi->use_wedge_interinter = 0;
+    mbmi->interinter_compound = COMPOUND_AVERAGE;
+    av1_cost_tokens(compound_type_cost, cm->fc->compound_type_prob[bsize],
+                    av1_compound_type_tree);
+    rs2 = compound_type_cost[mbmi->interinter_compound];
     av1_build_inter_predictors_sby(xd, mi_row, mi_col, bsize);
     av1_subtract_plane(x, bsize, 0);
     rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum,
@@ -7380,9 +7383,9 @@
       uint8_t *preds1[1] = { pred1 };
       int strides[1] = { bw };
 
-      mbmi->use_wedge_interinter = 1;
+      mbmi->interinter_compound = COMPOUND_WEDGE;
       rs2 = av1_cost_literal(get_interinter_wedge_bits(bsize)) +
-            av1_cost_bit(cm->fc->wedge_interinter_prob[bsize], 1);
+            compound_type_cost[mbmi->interinter_compound];
 
       av1_build_inter_predictors_for_planes_single_buf(
           xd, bsize, 0, 0, mi_row, mi_col, 0, preds0, strides);
@@ -7443,13 +7446,13 @@
         best_rd_wedge = rd;
 
         if (best_rd_wedge < best_rd_nowedge) {
-          mbmi->use_wedge_interinter = 1;
+          mbmi->interinter_compound = COMPOUND_WEDGE;
           xd->mi[0]->bmi[0].as_mv[0].as_int = mbmi->mv[0].as_int;
           xd->mi[0]->bmi[0].as_mv[1].as_int = mbmi->mv[1].as_int;
           rd_stats->rate += tmp_rate_mv - rate_mv;
           rate_mv = tmp_rate_mv;
         } else {
-          mbmi->use_wedge_interinter = 0;
+          mbmi->interinter_compound = COMPOUND_AVERAGE;
           mbmi->mv[0].as_int = cur_mv[0].as_int;
           mbmi->mv[1].as_int = cur_mv[1].as_int;
           xd->mi[0]->bmi[0].as_mv[0].as_int = mbmi->mv[0].as_int;
@@ -7466,9 +7469,9 @@
           rd = RDCOST(x->rdmult, x->rddiv, rs2 + rate_mv + rate_sum, dist_sum);
         best_rd_wedge = rd;
         if (best_rd_wedge < best_rd_nowedge) {
-          mbmi->use_wedge_interinter = 1;
+          mbmi->interinter_compound = COMPOUND_WEDGE;
         } else {
-          mbmi->use_wedge_interinter = 0;
+          mbmi->interinter_compound = COMPOUND_AVERAGE;
         }
       }
     }
@@ -7478,13 +7481,11 @@
 
     pred_exists = 0;
 
-    if (mbmi->use_wedge_interinter)
-      *compmode_wedge_cost =
-          av1_cost_literal(get_interinter_wedge_bits(bsize)) +
-          av1_cost_bit(cm->fc->wedge_interinter_prob[bsize], 1);
-    else
-      *compmode_wedge_cost =
-          av1_cost_bit(cm->fc->wedge_interinter_prob[bsize], 0);
+    *compmode_wedge_cost = compound_type_cost[mbmi->interinter_compound];
+
+    if (mbmi->interinter_compound)
+      *compmode_wedge_cost +=
+          av1_cost_literal(get_interinter_wedge_bits(bsize));
   }
 
   if (is_comp_interintra_pred) {
@@ -10170,7 +10171,7 @@
 #endif  // CONFIG_FILTER_INTRA
   mbmi->motion_mode = SIMPLE_TRANSLATION;
 #if CONFIG_EXT_INTER
-  mbmi->use_wedge_interinter = 0;
+  mbmi->interinter_compound = COMPOUND_AVERAGE;
   mbmi->use_wedge_interintra = 0;
 #endif  // CONFIG_EXT_INTER