JNT_COMP: Simplify logic on inter-inter comp modes

This patch simplies the checking criteria for the two groups of
compound modes. It also makes the encoder side cdf update inside the
RD loop consistent with that in the bitstream.

Experimental results on Google test sets (30 frames of lowres and
midres) confirm this patch obtains identical coding performance.

Change-Id: I170eea91f7d2be2170df544cfc2c692b09aa82d6
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index 51ee501..60932d6 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -2196,23 +2196,21 @@
 
 #if CONFIG_JNT_COMP
   // init
-  mbmi->comp_group_idx = 1;
+  mbmi->comp_group_idx = 0;
   mbmi->compound_idx = 1;
   mbmi->interinter_compound_type = COMPOUND_AVERAGE;
 
-  // read idx to indicate current compound inter prediction mode group
-  int masked_compound_used = is_any_masked_compound_used(bsize);
-  masked_compound_used = masked_compound_used && cm->allow_masked_compound;
-
   if (has_second_ref(mbmi)) {
+    // Read idx to indicate current compound inter prediction mode group
+    const int masked_compound_used =
+        is_any_masked_compound_used(bsize) && cm->allow_masked_compound;
+
     if (masked_compound_used) {
       const int ctx_comp_group_idx = get_comp_group_idx_context(xd);
       mbmi->comp_group_idx = aom_read_symbol(
           r, ec_ctx->comp_group_idx_cdf[ctx_comp_group_idx], 2, ACCT_STR);
       if (xd->counts)
         ++xd->counts->comp_group_idx[ctx_comp_group_idx][mbmi->comp_group_idx];
-    } else {
-      mbmi->comp_group_idx = 0;
     }
 
     if (mbmi->comp_group_idx == 0) {
@@ -2222,44 +2220,33 @@
 
       if (xd->counts)
         ++xd->counts->compound_index[comp_index_ctx][mbmi->compound_idx];
-
-      if (mbmi->compound_idx) mbmi->interinter_compound_type = COMPOUND_AVERAGE;
     } else {
+      assert(cm->reference_mode != SINGLE_REFERENCE &&
+             is_inter_compound_mode(mbmi->mode) &&
+             mbmi->motion_mode == SIMPLE_TRANSLATION);
+      assert(masked_compound_used);
+
       // compound_segment, wedge
-      mbmi->interinter_compound_type = COMPOUND_AVERAGE;
-      if (cm->reference_mode != SINGLE_REFERENCE &&
-          is_inter_compound_mode(mbmi->mode) &&
-          mbmi->motion_mode == SIMPLE_TRANSLATION
-#if CONFIG_EXT_SKIP
-          && !mbmi->skip_mode
-#endif  // CONFIG_EXT_SKIP
-          ) {
-        if (is_any_masked_compound_used(bsize)) {
-          if (cm->allow_masked_compound) {
-            if (is_interinter_compound_used(COMPOUND_WEDGE, bsize))
-              mbmi->interinter_compound_type =
-                  1 + aom_read_symbol(r, ec_ctx->compound_type_cdf[bsize],
-                                      COMPOUND_TYPES - 1, ACCT_STR);
-            else
-              mbmi->interinter_compound_type = COMPOUND_SEG;
+      if (is_interinter_compound_used(COMPOUND_WEDGE, bsize))
+        mbmi->interinter_compound_type =
+            1 + aom_read_symbol(r, ec_ctx->compound_type_cdf[bsize],
+                                COMPOUND_TYPES - 1, ACCT_STR);
+      else
+        mbmi->interinter_compound_type = COMPOUND_SEG;
 
-            if (mbmi->interinter_compound_type == COMPOUND_WEDGE) {
-              assert(is_interinter_compound_used(COMPOUND_WEDGE, bsize));
-              mbmi->wedge_index =
-                  aom_read_literal(r, get_wedge_bits_lookup(bsize), ACCT_STR);
-              mbmi->wedge_sign = aom_read_bit(r, ACCT_STR);
-            }
-            if (mbmi->interinter_compound_type == COMPOUND_SEG) {
-              mbmi->mask_type =
-                  aom_read_literal(r, MAX_SEG_MASK_BITS, ACCT_STR);
-            }
-          }
-        }
-
-        if (xd->counts)
-          xd->counts->compound_interinter[bsize]
-                                         [mbmi->interinter_compound_type - 1]++;
+      if (mbmi->interinter_compound_type == COMPOUND_WEDGE) {
+        assert(is_interinter_compound_used(COMPOUND_WEDGE, bsize));
+        mbmi->wedge_index =
+            aom_read_literal(r, get_wedge_bits_lookup(bsize), ACCT_STR);
+        mbmi->wedge_sign = aom_read_bit(r, ACCT_STR);
+      } else {
+        assert(mbmi->interinter_compound_type == COMPOUND_SEG);
+        mbmi->mask_type = aom_read_literal(r, MAX_SEG_MASK_BITS, ACCT_STR);
       }
+
+      if (xd->counts)
+        xd->counts
+            ->compound_interinter[bsize][mbmi->interinter_compound_type - 1]++;
     }
   }
 #else  // CONFIG_JNT_COMP
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 26bc975..607357d 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1460,16 +1460,16 @@
     // First write idx to indicate current compound inter prediction mode group
     // Group A (0): jnt_comp, compound_average
     // Group B (1): interintra, compound_segment, wedge
-    int masked_compound_used = is_any_masked_compound_used(bsize);
-    masked_compound_used = masked_compound_used && cm->allow_masked_compound;
-
     if (has_second_ref(mbmi)) {
-      if (masked_compound_used) {
-        assert(mbmi->comp_group_idx == 0);
+      const int masked_compound_used =
+          is_any_masked_compound_used(bsize) && cm->allow_masked_compound;
 
+      if (masked_compound_used) {
         const int ctx_comp_group_idx = get_comp_group_idx_context(xd);
         aom_write_symbol(w, mbmi->comp_group_idx,
                          ec_ctx->comp_group_idx_cdf[ctx_comp_group_idx], 2);
+      } else {
+        assert(mbmi->comp_group_idx == 0);
       }
 
       if (mbmi->comp_group_idx == 0) {
@@ -1480,27 +1480,26 @@
         aom_write_symbol(w, mbmi->compound_idx,
                          ec_ctx->compound_index_cdf[comp_index_ctx], 2);
       } else {
+        assert(cpi->common.reference_mode != SINGLE_REFERENCE &&
+               is_inter_compound_mode(mbmi->mode) &&
+               mbmi->motion_mode == SIMPLE_TRANSLATION);
+        assert(masked_compound_used);
+        // compound_segment, wedge
         assert(mbmi->interinter_compound_type == COMPOUND_WEDGE ||
                mbmi->interinter_compound_type == COMPOUND_SEG);
-        // compound_segment, wedge
-        if (cpi->common.reference_mode != SINGLE_REFERENCE &&
-            is_inter_compound_mode(mbmi->mode) &&
-            mbmi->motion_mode == SIMPLE_TRANSLATION &&
-            is_any_masked_compound_used(bsize) && cm->allow_masked_compound) {
-          if (is_interinter_compound_used(COMPOUND_WEDGE, bsize))
-            aom_write_symbol(w, mbmi->interinter_compound_type - 1,
-                             ec_ctx->compound_type_cdf[bsize],
-                             COMPOUND_TYPES - 1);
 
-          if (is_interinter_compound_used(COMPOUND_WEDGE, bsize) &&
-              mbmi->interinter_compound_type == COMPOUND_WEDGE) {
-            aom_write_literal(w, mbmi->wedge_index,
-                              get_wedge_bits_lookup(bsize));
-            aom_write_bit(w, mbmi->wedge_sign);
-          }
-          if (mbmi->interinter_compound_type == COMPOUND_SEG) {
-            aom_write_literal(w, mbmi->mask_type, MAX_SEG_MASK_BITS);
-          }
+        if (is_interinter_compound_used(COMPOUND_WEDGE, bsize))
+          aom_write_symbol(w, mbmi->interinter_compound_type - 1,
+                           ec_ctx->compound_type_cdf[bsize],
+                           COMPOUND_TYPES - 1);
+
+        if (mbmi->interinter_compound_type == COMPOUND_WEDGE) {
+          assert(is_interinter_compound_used(COMPOUND_WEDGE, bsize));
+          aom_write_literal(w, mbmi->wedge_index, get_wedge_bits_lookup(bsize));
+          aom_write_bit(w, mbmi->wedge_sign);
+        } else {
+          assert(mbmi->interinter_compound_type == COMPOUND_SEG);
+          aom_write_literal(w, mbmi->mask_type, MAX_SEG_MASK_BITS);
         }
       }
     }
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index f3575b1..17715b3 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -1125,7 +1125,7 @@
           else
             // This flag is also updated for 4x4 blocks
             rdc->single_ref_used_flag = 1;
-          if (is_comp_ref_allowed(mbmi->sb_type)) {
+          if (is_comp_ref_allowed(bsize)) {
             counts->comp_inter[av1_get_reference_mode_context(cm, xd)]
                               [has_second_ref(mbmi)]++;
             if (allow_update_cdf)
@@ -1237,31 +1237,39 @@
               if (mbmi->mode == NEARESTMV) wm_ctx = 2;
             }
 
-            counts->motion_mode[wm_ctx][mbmi->sb_type][mbmi->motion_mode]++;
+            counts->motion_mode[wm_ctx][bsize][mbmi->motion_mode]++;
             if (allow_update_cdf)
-              update_cdf(fc->motion_mode_cdf[wm_ctx][mbmi->sb_type],
-                         mbmi->motion_mode, MOTION_MODES);
+              update_cdf(fc->motion_mode_cdf[wm_ctx][bsize], mbmi->motion_mode,
+                         MOTION_MODES);
 #else
-            counts->motion_mode[mbmi->sb_type][mbmi->motion_mode]++;
+            counts->motion_mode[bsize][mbmi->motion_mode]++;
             if (allow_update_cdf)
-              update_cdf(fc->motion_mode_cdf[mbmi->sb_type], mbmi->motion_mode,
+              update_cdf(fc->motion_mode_cdf[bsize], mbmi->motion_mode,
                          MOTION_MODES);
 #endif  // CONFIG_EXT_WARPED_MOTION
           } else if (motion_allowed == OBMC_CAUSAL) {
-            counts->obmc[mbmi->sb_type][mbmi->motion_mode == OBMC_CAUSAL]++;
+            counts->obmc[bsize][mbmi->motion_mode == OBMC_CAUSAL]++;
             if (allow_update_cdf)
-              update_cdf(fc->obmc_cdf[mbmi->sb_type],
-                         mbmi->motion_mode == OBMC_CAUSAL, 2);
+              update_cdf(fc->obmc_cdf[bsize], mbmi->motion_mode == OBMC_CAUSAL,
+                         2);
           }
         }
 
 #if CONFIG_JNT_COMP
         if (has_second_ref(mbmi)) {
-          const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
-          ++counts->comp_group_idx[comp_group_idx_ctx][mbmi->comp_group_idx];
-          if (allow_update_cdf)
-            update_cdf(fc->comp_group_idx_cdf[comp_group_idx_ctx],
-                       mbmi->comp_group_idx, 2);
+          assert(cm->reference_mode != SINGLE_REFERENCE &&
+                 is_inter_compound_mode(mbmi->mode) &&
+                 mbmi->motion_mode == SIMPLE_TRANSLATION);
+
+          const int masked_compound_used =
+              is_any_masked_compound_used(bsize) && cm->allow_masked_compound;
+          if (masked_compound_used) {
+            const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
+            ++counts->comp_group_idx[comp_group_idx_ctx][mbmi->comp_group_idx];
+            if (allow_update_cdf)
+              update_cdf(fc->comp_group_idx_cdf[comp_group_idx_ctx],
+                         mbmi->comp_group_idx, 2);
+          }
 
           if (mbmi->comp_group_idx == 0) {
             const int comp_index_ctx = get_comp_index_context(cm, xd);
@@ -1269,23 +1277,19 @@
             if (allow_update_cdf)
               update_cdf(fc->compound_index_cdf[comp_index_ctx],
                          mbmi->compound_idx, 2);
-          }
-        }
-
-        if (cm->reference_mode != SINGLE_REFERENCE &&
-            is_inter_compound_mode(mbmi->mode) && mbmi->comp_group_idx &&
-            mbmi->motion_mode == SIMPLE_TRANSLATION) {
-          if (is_interinter_compound_used(COMPOUND_WEDGE, bsize)) {
-            counts->compound_interinter[bsize]
-                                       [mbmi->interinter_compound_type - 1]++;
-            if (allow_update_cdf)
-              update_cdf(fc->compound_type_cdf[bsize],
-                         mbmi->interinter_compound_type - 1,
-                         COMPOUND_TYPES - 1);
+          } else {
+            assert(masked_compound_used);
+            if (is_interinter_compound_used(COMPOUND_WEDGE, bsize)) {
+              counts->compound_interinter[bsize]
+                                         [mbmi->interinter_compound_type - 1]++;
+              if (allow_update_cdf)
+                update_cdf(fc->compound_type_cdf[bsize],
+                           mbmi->interinter_compound_type - 1,
+                           COMPOUND_TYPES - 1);
+            }
           }
         }
 #else   // CONFIG_JNT_COMP
-
         if (cm->reference_mode != SINGLE_REFERENCE &&
             is_inter_compound_mode(mbmi->mode) &&
             mbmi->motion_mode == SIMPLE_TRANSLATION) {