Frame level flags to ctrl ext-inter compound modes

Change-Id: I904283119d8f2c1099e6ec2953ea1c10c5e3b280
diff --git a/av1/common/entropymode.c b/av1/common/entropymode.c
index f703bc6..6aa51e2 100644
--- a/av1/common/entropymode.c
+++ b/av1/common/entropymode.c
@@ -4738,30 +4738,34 @@
                          fc->inter_singleref_comp_mode_probs[i]);
 #endif  // CONFIG_COMPOUND_SINGLEREF
 #if CONFIG_INTERINTRA
-  for (i = 0; i < BLOCK_SIZE_GROUPS; ++i) {
-    if (is_interintra_allowed_bsize_group(i))
-      fc->interintra_prob[i] = av1_mode_mv_merge_probs(
-          pre_fc->interintra_prob[i], counts->interintra[i]);
-  }
-  for (i = 0; i < BLOCK_SIZE_GROUPS; i++) {
-    aom_tree_merge_probs(
-        av1_interintra_mode_tree, pre_fc->interintra_mode_prob[i],
-        counts->interintra_mode[i], fc->interintra_mode_prob[i]);
-  }
+  if (cm->allow_interintra_compound) {
+    for (i = 0; i < BLOCK_SIZE_GROUPS; ++i) {
+      if (is_interintra_allowed_bsize_group(i))
+        fc->interintra_prob[i] = av1_mode_mv_merge_probs(
+            pre_fc->interintra_prob[i], counts->interintra[i]);
+    }
+    for (i = 0; i < BLOCK_SIZE_GROUPS; i++) {
+      aom_tree_merge_probs(
+          av1_interintra_mode_tree, pre_fc->interintra_mode_prob[i],
+          counts->interintra_mode[i], fc->interintra_mode_prob[i]);
+    }
 #if CONFIG_WEDGE
-  for (i = 0; i < BLOCK_SIZES; ++i) {
-    if (is_interintra_allowed_bsize(i) && is_interintra_wedge_used(i))
-      fc->wedge_interintra_prob[i] = av1_mode_mv_merge_probs(
-          pre_fc->wedge_interintra_prob[i], counts->wedge_interintra[i]);
-  }
+    for (i = 0; i < BLOCK_SIZES; ++i) {
+      if (is_interintra_allowed_bsize(i) && is_interintra_wedge_used(i))
+        fc->wedge_interintra_prob[i] = av1_mode_mv_merge_probs(
+            pre_fc->wedge_interintra_prob[i], counts->wedge_interintra[i]);
+    }
 #endif  // CONFIG_WEDGE
+  }
 #endif  // CONFIG_INTERINTRA
 
 #if CONFIG_COMPOUND_SEGMENT || CONFIG_WEDGE
-  for (i = 0; i < BLOCK_SIZES; ++i) {
-    aom_tree_merge_probs(av1_compound_type_tree, pre_fc->compound_type_prob[i],
-                         counts->compound_interinter[i],
-                         fc->compound_type_prob[i]);
+  if (cm->allow_masked_compound) {
+    for (i = 0; i < BLOCK_SIZES; ++i) {
+      aom_tree_merge_probs(
+          av1_compound_type_tree, pre_fc->compound_type_prob[i],
+          counts->compound_interinter[i], fc->compound_type_prob[i]);
+    }
   }
 #endif  // CONFIG_COMPOUND_SEGMENT || CONFIG_WEDGE
 #endif  // CONFIG_EXT_INTER
diff --git a/av1/common/onyxc_int.h b/av1/common/onyxc_int.h
index 80a3b0a..3dc2a48 100644
--- a/av1/common/onyxc_int.h
+++ b/av1/common/onyxc_int.h
@@ -211,6 +211,14 @@
 #if CONFIG_PALETTE || CONFIG_INTRABC
   int allow_screen_content_tools;
 #endif  // CONFIG_PALETTE || CONFIG_INTRABC
+#if CONFIG_EXT_INTER
+#if CONFIG_INTERINTRA
+  int allow_interintra_compound;
+#endif  // CONFIG_INTERINTRA
+#if CONFIG_WEDGE || CONFIG_COMPOUND_SEGMENT
+  int allow_masked_compound;
+#endif  // CONFIG_WEDGE || CONFIG_COMPOUND_SEGMENT
+#endif  // CONFIG_EXT_INTER
 
   // Flag signaling which frame contexts should be reset to default values.
   RESET_FRAME_CONTEXT_MODE reset_frame_context;
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index 0b39bc7..1a12cd5 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -4057,6 +4057,28 @@
 }
 #endif
 
+#if CONFIG_EXT_INTER
+static void read_compound_tools(AV1_COMMON *cm,
+                                struct aom_read_bit_buffer *rb) {
+  (void)cm;
+  (void)rb;
+#if CONFIG_INTERINTRA
+  if (!frame_is_intra_only(cm) && cm->reference_mode != COMPOUND_REFERENCE) {
+    cm->allow_interintra_compound = aom_rb_read_bit(rb);
+  } else {
+    cm->allow_interintra_compound = 0;
+  }
+#endif  // CONFIG_INTERINTRA
+#if CONFIG_WEDGE || CONFIG_COMPOUND_SEGMENT
+  if (!frame_is_intra_only(cm) && cm->reference_mode != SINGLE_REFERENCE) {
+    cm->allow_masked_compound = aom_rb_read_bit(rb);
+  } else {
+    cm->allow_masked_compound = 0;
+  }
+#endif  // CONFIG_WEDGE || CONFIG_COMPOUND_SEGMENT
+}
+#endif  // CONFIG_EXT_INTER
+
 static size_t read_uncompressed_header(AV1Decoder *pbi,
                                        struct aom_read_bit_buffer *rb) {
   AV1_COMMON *const cm = &pbi->common;
@@ -4467,6 +4489,9 @@
   setup_segmentation_dequant(cm);
   cm->tx_mode = read_tx_mode(cm, xd, rb);
   cm->reference_mode = read_frame_reference_mode(cm, rb);
+#if CONFIG_EXT_INTER
+  read_compound_tools(cm, rb);
+#endif  // CONFIG_EXT_INTER
 
 #if CONFIG_EXT_TX
   cm->reduced_tx_set_used = aom_rb_read_bit(rb);
@@ -4762,8 +4787,8 @@
 #if CONFIG_EXT_INTER
     read_inter_compound_mode_probs(fc, &r);
 #if CONFIG_INTERINTRA
-    if (cm->reference_mode != COMPOUND_REFERENCE) {
-#if CONFIG_INTERINTRA
+    if (cm->reference_mode != COMPOUND_REFERENCE &&
+        cm->allow_interintra_compound) {
       for (i = 0; i < BLOCK_SIZE_GROUPS; i++) {
         if (is_interintra_allowed_bsize_group(i)) {
           av1_diff_update_prob(&r, &fc->interintra_prob[i], ACCT_STR);
@@ -4780,11 +4805,10 @@
         }
       }
 #endif  // CONFIG_WEDGE
-#endif  // CONFIG_INTERINTRA
     }
 #endif  // CONFIG_INTERINTRA
 #if CONFIG_COMPOUND_SEGMENT || CONFIG_WEDGE
-    if (cm->reference_mode != SINGLE_REFERENCE) {
+    if (cm->reference_mode != SINGLE_REFERENCE && cm->allow_masked_compound) {
       for (i = 0; i < BLOCK_SIZES; i++) {
         for (j = 0; j < COMPOUND_TYPES - 1; j++) {
           av1_diff_update_prob(&r, &fc->compound_type_prob[i][j], ACCT_STR);
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index 81d3dcc..32096c4 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -2029,7 +2029,7 @@
 #if CONFIG_SUPERTX
       !supertx_enabled &&
 #endif
-      is_interintra_allowed(mbmi)) {
+      cm->allow_interintra_compound && is_interintra_allowed(mbmi)) {
     const int bsize_group = size_group_lookup[bsize];
     const int interintra =
         aom_read(r, cm->fc->interintra_prob[bsize_group], ACCT_STR);
@@ -2108,22 +2108,24 @@
       ) {
     if (is_any_masked_compound_used(bsize)) {
 #if CONFIG_COMPOUND_SEGMENT || CONFIG_WEDGE
-      mbmi->interinter_compound_type =
-          aom_read_tree(r, av1_compound_type_tree,
-                        cm->fc->compound_type_prob[bsize], ACCT_STR);
-#endif  // CONFIG_COMPOUND_SEGMENT || CONFIG_WEDGE
+      if (cm->allow_masked_compound) {
+        mbmi->interinter_compound_type =
+            aom_read_tree(r, av1_compound_type_tree,
+                          cm->fc->compound_type_prob[bsize], ACCT_STR);
 #if CONFIG_WEDGE
-      if (mbmi->interinter_compound_type == COMPOUND_WEDGE) {
-        mbmi->wedge_index =
-            aom_read_literal(r, get_wedge_bits_lookup(bsize), ACCT_STR);
-        mbmi->wedge_sign = aom_read_bit(r, ACCT_STR);
-      }
+        if (mbmi->interinter_compound_type == COMPOUND_WEDGE) {
+          mbmi->wedge_index =
+              aom_read_literal(r, get_wedge_bits_lookup(bsize), ACCT_STR);
+          mbmi->wedge_sign = aom_read_bit(r, ACCT_STR);
+        }
 #endif  // CONFIG_WEDGE
 #if CONFIG_COMPOUND_SEGMENT
-      if (mbmi->interinter_compound_type == COMPOUND_SEG) {
-        mbmi->mask_type = aom_read_literal(r, MAX_SEG_MASK_BITS, ACCT_STR);
-      }
+        if (mbmi->interinter_compound_type == COMPOUND_SEG) {
+          mbmi->mask_type = aom_read_literal(r, MAX_SEG_MASK_BITS, ACCT_STR);
+        }
 #endif  // CONFIG_COMPOUND_SEGMENT
+      }
+#endif  // CONFIG_COMPOUND_SEGMENT || CONFIG_WEDGE
     } else {
       mbmi->interinter_compound_type = COMPOUND_AVERAGE;
     }
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index ea0136c..7a6a824 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1962,7 +1962,7 @@
 #if CONFIG_SUPERTX
         !supertx_enabled &&
 #endif  // CONFIG_SUPERTX
-        is_interintra_allowed(mbmi)) {
+        cpi->common.allow_interintra_compound && is_interintra_allowed(mbmi)) {
       const int interintra = mbmi->ref_frame[1] == INTRA_FRAME;
       const int bsize_group = size_group_lookup[bsize];
       aom_write(w, interintra, cm->fc->interintra_prob[bsize_group]);
@@ -2000,21 +2000,23 @@
 #endif  // CONFIG_MOTION_VAR
         && is_any_masked_compound_used(bsize)) {
 #if CONFIG_COMPOUND_SEGMENT || CONFIG_WEDGE
-      av1_write_token(w, av1_compound_type_tree,
-                      cm->fc->compound_type_prob[bsize],
-                      &compound_type_encodings[mbmi->interinter_compound_type]);
-#endif  // CONFIG_COMPOUND_SEGMENT || CONFIG_WEDGE
+      if (cm->allow_masked_compound) {
+        av1_write_token(
+            w, av1_compound_type_tree, cm->fc->compound_type_prob[bsize],
+            &compound_type_encodings[mbmi->interinter_compound_type]);
 #if CONFIG_WEDGE
-      if (mbmi->interinter_compound_type == COMPOUND_WEDGE) {
-        aom_write_literal(w, mbmi->wedge_index, get_wedge_bits_lookup(bsize));
-        aom_write_bit(w, mbmi->wedge_sign);
-      }
+        if (mbmi->interinter_compound_type == COMPOUND_WEDGE) {
+          aom_write_literal(w, mbmi->wedge_index, get_wedge_bits_lookup(bsize));
+          aom_write_bit(w, mbmi->wedge_sign);
+        }
 #endif  // CONFIG_WEDGE
 #if CONFIG_COMPOUND_SEGMENT
-      if (mbmi->interinter_compound_type == COMPOUND_SEG) {
-        aom_write_literal(w, mbmi->mask_type, MAX_SEG_MASK_BITS);
-      }
+        if (mbmi->interinter_compound_type == COMPOUND_SEG) {
+          aom_write_literal(w, mbmi->mask_type, MAX_SEG_MASK_BITS);
+        }
 #endif  // CONFIG_COMPOUND_SEGMENT
+      }
+#endif  // CONFIG_COMPOUND_SEGMENT || CONFIG_WEDGE
     }
 #endif  // CONFIG_EXT_INTER
 
@@ -4204,6 +4206,28 @@
 }
 #endif
 
+#if CONFIG_EXT_INTER
+static void write_compound_tools(const AV1_COMMON *cm,
+                                 struct aom_write_bit_buffer *wb) {
+  (void)cm;
+  (void)wb;
+#if CONFIG_INTERINTRA
+  if (!frame_is_intra_only(cm) && cm->reference_mode != COMPOUND_REFERENCE) {
+    aom_wb_write_bit(wb, cm->allow_interintra_compound);
+  } else {
+    assert(cm->allow_interintra_compound == 0);
+  }
+#endif  // CONFIG_INTERINTRA
+#if CONFIG_WEDGE || CONFIG_COMPOUND_SEGMENT
+  if (!frame_is_intra_only(cm) && cm->reference_mode != SINGLE_REFERENCE) {
+    aom_wb_write_bit(wb, cm->allow_masked_compound);
+  } else {
+    assert(cm->allow_masked_compound == 0);
+  }
+#endif  // CONFIG_WEDGE || CONFIG_COMPOUND_SEGMENT
+}
+#endif  // CONFIG_EXT_INTER
+
 static void write_uncompressed_header(AV1_COMP *cpi,
                                       struct aom_write_bit_buffer *wb) {
   AV1_COMMON *const cm = &cpi->common;
@@ -4461,6 +4485,9 @@
     if (!use_hybrid_pred) aom_wb_write_bit(wb, use_compound_pred);
 #endif  // !CONFIG_REF_ADAPT
   }
+#if CONFIG_EXT_INTER
+  write_compound_tools(cm, wb);
+#endif  // CONFIG_EXT_INTER
 
 #if CONFIG_EXT_TX
   aom_wb_write_bit(wb, cm->reduced_tx_set_used);
@@ -4682,8 +4709,8 @@
 #if CONFIG_EXT_INTER
     update_inter_compound_mode_probs(cm, probwt, header_bc);
 #if CONFIG_INTERINTRA
-    if (cm->reference_mode != COMPOUND_REFERENCE) {
-#if CONFIG_INTERINTRA
+    if (cm->reference_mode != COMPOUND_REFERENCE &&
+        cm->allow_interintra_compound) {
       for (i = 0; i < BLOCK_SIZE_GROUPS; i++) {
         if (is_interintra_allowed_bsize_group(i)) {
           av1_cond_prob_diff_update(header_bc, &fc->interintra_prob[i],
@@ -4702,11 +4729,10 @@
                                     cm->counts.wedge_interintra[i], probwt);
       }
 #endif  // CONFIG_WEDGE
-#endif  // CONFIG_INTERINTRA
     }
 #endif  // CONFIG_INTERINTRA
 #if CONFIG_COMPOUND_SEGMENT || CONFIG_WEDGE
-    if (cm->reference_mode != SINGLE_REFERENCE) {
+    if (cm->reference_mode != SINGLE_REFERENCE && cm->allow_masked_compound) {
       for (i = 0; i < BLOCK_SIZES; i++)
         prob_diff_update(av1_compound_type_tree, fc->compound_type_prob[i],
                          cm->counts.compound_interinter[i], COMPOUND_TYPES,
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 884c0f0..485c977 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -2176,7 +2176,7 @@
 #if CONFIG_SUPERTX
             !supertx_enabled &&
 #endif
-            is_interintra_allowed(mbmi)) {
+            cm->allow_interintra_compound && is_interintra_allowed(mbmi)) {
           const int bsize_group = size_group_lookup[bsize];
           if (mbmi->ref_frame[1] == INTRA_FRAME) {
             counts->interintra[bsize_group][1]++;
@@ -5409,6 +5409,20 @@
 #endif
 }
 
+#if CONFIG_EXT_INTER
+static void make_consistent_compound_tools(AV1_COMMON *cm) {
+  (void)cm;
+#if CONFIG_INTERINTRA
+  if (frame_is_intra_only(cm) || cm->reference_mode == COMPOUND_REFERENCE)
+    cm->allow_interintra_compound = 0;
+#endif  // CONFIG_INTERINTRA
+#if CONFIG_COMPOUND_SEGMENT || CONFIG_WEDGE
+  if (frame_is_intra_only(cm) || cm->reference_mode == SINGLE_REFERENCE)
+    cm->allow_masked_compound = 0;
+#endif  // CONFIG_COMPOUND_SEGMENT || CONFIG_WEDGE
+}
+#endif  // CONFIG_EXT_INTER
+
 void av1_encode_frame(AV1_COMP *cpi) {
   AV1_COMMON *const cm = &cpi->common;
 #if CONFIG_EXT_TX
@@ -5496,6 +5510,9 @@
     cm->interp_filter = SWITCHABLE;
 #endif
 
+#if CONFIG_EXT_INTER
+    make_consistent_compound_tools(cm);
+#endif  // CONFIG_EXT_INTER
     encode_frame_internal(cpi);
 
     for (i = 0; i < REFERENCE_MODES; ++i)
@@ -5696,6 +5713,9 @@
     }
 #endif
   } else {
+#if CONFIG_EXT_INTER
+    make_consistent_compound_tools(cm);
+#endif  // CONFIG_EXT_INTER
     encode_frame_internal(cpi);
   }
 }
diff --git a/av1/encoder/encoder.c b/av1/encoder/encoder.c
index 8901f0a..ed7bb41 100644
--- a/av1/encoder/encoder.c
+++ b/av1/encoder/encoder.c
@@ -1913,6 +1913,18 @@
                   aom_calloc(cm->mi_rows * cm->mi_cols, 1));
 }
 
+#if CONFIG_EXT_INTER
+void set_compound_tools(AV1_COMMON *cm) {
+  (void)cm;
+#if CONFIG_INTERINTRA
+  cm->allow_interintra_compound = 1;
+#endif  // CONFIG_INTERINTRA
+#if CONFIG_WEDGE || CONFIG_COMPOUND_SEGMENT
+  cm->allow_masked_compound = 1;
+#endif  // CONFIG_WEDGE || CONFIG_COMPOUND_SEGMENT
+}
+#endif  // CONFIG_EXT_INTER
+
 void av1_change_config(struct AV1_COMP *cpi, const AV1EncoderConfig *oxcf) {
   AV1_COMMON *const cm = &cpi->common;
   RATE_CONTROL *const rc = &cpi->rc;
@@ -1965,7 +1977,9 @@
     av1_setup_pc_tree(&cpi->common, &cpi->td);
   }
 #endif  // CONFIG_PALETTE
-
+#if CONFIG_EXT_INTER
+  set_compound_tools(cm);
+#endif  // CONFIG_EXT_INTER
   av1_reset_segment_features(cm);
   av1_set_high_precision_mv(cpi, 0);
 
@@ -3694,6 +3708,9 @@
   av1_set_rd_speed_thresholds(cpi);
   av1_set_rd_speed_thresholds_sub8x8(cpi);
   cpi->common.interp_filter = cpi->sf.default_interp_filter;
+#if CONFIG_EXT_INTER
+  if (!frame_is_intra_only(&cpi->common)) set_compound_tools(&cpi->common);
+#endif  // CONFIG_EXT_INTER
 }
 
 static void set_size_dependent_vars(AV1_COMP *cpi, int *q, int *bottom_index,
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 993e15c..13cfaa2 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -7776,6 +7776,7 @@
   int wedge_sign = 0;
 
   assert(is_interinter_compound_used(COMPOUND_WEDGE, bsize));
+  assert(cpi->common.allow_masked_compound);
 
   if (cpi->sf.fast_wedge_sign_estimate) {
     wedge_sign = estimate_wedge_sign(cpi, x, bsize, p0, bw, p1, bw);
@@ -7889,6 +7890,7 @@
   int wedge_index = -1;
 
   assert(is_interintra_wedge_used(bsize));
+  assert(cpi->common.allow_interintra_compound);
 
   rd = pick_wedge_fixed_sign(cpi, x, bsize, p0, p1, 0, &wedge_index);
 
@@ -8690,6 +8692,11 @@
   *args->compmode_interinter_cost = 0;
   mbmi->interinter_compound_type = COMPOUND_AVERAGE;
 
+#if CONFIG_INTERINTRA
+  if (!cm->allow_interintra_compound && is_comp_interintra_pred)
+    return INT64_MAX;
+#endif  // CONFIG_INTERINTRA
+
   // is_comp_interintra_pred implies !is_comp_pred
   assert(!is_comp_interintra_pred || (!is_comp_pred));
   // is_comp_interintra_pred implies is_interintra_allowed(mbmi->sb_type)
@@ -8898,6 +8905,9 @@
     int strides[1] = { bw };
     int tmp_rate_mv;
     int masked_compound_used = is_any_masked_compound_used(bsize);
+#if CONFIG_COMPOUND_SEGMENT || CONFIG_WEDGE
+    masked_compound_used = masked_compound_used && cm->allow_masked_compound;
+#endif  // CONFIG_COMPOUND_SEGMENT || CONFIG_WEDGE
     COMPOUND_TYPE cur_type;
 
     best_mv[0].as_int = cur_mv[0].as_int;
@@ -8919,6 +8929,7 @@
     }
 
     for (cur_type = COMPOUND_AVERAGE; cur_type < COMPOUND_TYPES; cur_type++) {
+      if (cur_type != COMPOUND_AVERAGE && !masked_compound_used) break;
       if (!is_interinter_compound_used(cur_type, bsize)) break;
       tmp_rate_mv = rate_mv;
       best_rd_cur = INT64_MAX;