Revert "JNT_COMP: turn off for one_sided_compound"

This reverts commit 060e192bfc5e0b8d1e96ff78b206c11e55e005c6.

Change-Id: I5700d351a3cbb682ec49a0efb9cca4d0e83f9a3a
diff --git a/av1/common/onyxc_int.h b/av1/common/onyxc_int.h
index f25758a..024ad99 100644
--- a/av1/common/onyxc_int.h
+++ b/av1/common/onyxc_int.h
@@ -1385,28 +1385,6 @@
 #endif  // CONFIG_NEW_MULTISYMBOL
 }
 
-#if CONFIG_JNT_COMP
-static INLINE int has_two_sided_comp_refs(const AV1_COMMON *cm,
-                                          const MB_MODE_INFO *mbmi) {
-  if (!has_second_ref(mbmi)) return 0;
-
-  const int bck_idx = cm->frame_refs[mbmi->ref_frame[0] - LAST_FRAME].idx;
-  const int fwd_idx = cm->frame_refs[mbmi->ref_frame[1] - LAST_FRAME].idx;
-  if (bck_idx < 0 || fwd_idx < 0) return 0;
-
-  const int cur_frame_index = cm->cur_frame->cur_frame_offset;
-  const int bck_frame_index =
-      cm->buffer_pool->frame_bufs[bck_idx].cur_frame_offset;
-  const int fwd_frame_index =
-      cm->buffer_pool->frame_bufs[fwd_idx].cur_frame_offset;
-
-  return ((bck_frame_index > cur_frame_index) &&
-          (fwd_frame_index < cur_frame_index)) ||
-         ((bck_frame_index < cur_frame_index) &&
-          (fwd_frame_index > cur_frame_index));
-}
-#endif  // CONFIG_JNT_COMP
-
 #ifdef __cplusplus
 }  // extern "C"
 #endif
diff --git a/av1/common/pred_common.h b/av1/common/pred_common.h
index 5d22dac..5002769 100644
--- a/av1/common/pred_common.h
+++ b/av1/common/pred_common.h
@@ -115,7 +115,7 @@
 
   if (above_mi) {
     const MB_MODE_INFO *above_mbmi = &above_mi->mbmi;
-    if (has_two_sided_comp_refs(cm, above_mbmi))
+    if (has_second_ref(above_mbmi))
       above_ctx = above_mbmi->compound_idx;
     else if (above_mbmi->ref_frame[0] == ALTREF_FRAME)
       above_ctx = 1;
@@ -123,7 +123,7 @@
 
   if (left_mi) {
     const MB_MODE_INFO *left_mbmi = &left_mi->mbmi;
-    if (has_two_sided_comp_refs(cm, left_mbmi))
+    if (has_second_ref(left_mbmi))
       left_ctx = left_mbmi->compound_idx;
     else if (left_mbmi->ref_frame[0] == ALTREF_FRAME)
       left_ctx = 1;
diff --git a/av1/common/reconinter.c b/av1/common/reconinter.c
index 471ba13..4479a98 100644
--- a/av1/common/reconinter.c
+++ b/av1/common/reconinter.c
@@ -933,11 +933,11 @@
 
 #if CONFIG_JNT_COMP
 void av1_jnt_comp_weight_assign(const AV1_COMMON *cm, const MB_MODE_INFO *mbmi,
-                                int order_idx, int *fwd_offset,
-                                int *bck_offset) {
+                                int order_idx, int *fwd_offset, int *bck_offset,
+                                int is_compound) {
   assert(fwd_offset != NULL && bck_offset != NULL);
 
-  if (has_two_sided_comp_refs(cm, mbmi)) {
+  if (is_compound) {
     int bck_idx = cm->frame_refs[mbmi->ref_frame[0] - LAST_FRAME].idx;
     int fwd_idx = cm->frame_refs[mbmi->ref_frame[1] - LAST_FRAME].idx;
     int bck_frame_index = 0, fwd_frame_index = 0;
@@ -1298,7 +1298,7 @@
         get_conv_params_no_round(ref, ref, plane, tmp_dst, MAX_SB_SIZE);
 #if CONFIG_JNT_COMP
     av1_jnt_comp_weight_assign(cm, &mi->mbmi, 0, &conv_params.fwd_offset,
-                               &conv_params.bck_offset);
+                               &conv_params.bck_offset, is_compound);
 #endif  // CONFIG_JNT_COMP
 
 #else
diff --git a/av1/common/reconinter.h b/av1/common/reconinter.h
index 3e3260a..272e4f6 100644
--- a/av1/common/reconinter.h
+++ b/av1/common/reconinter.h
@@ -559,10 +559,11 @@
                                               int ext_dst_stride0[3],
                                               uint8_t *ext_dst1[3],
                                               int ext_dst_stride1[3]);
+
 #if CONFIG_JNT_COMP
 void av1_jnt_comp_weight_assign(const AV1_COMMON *cm, const MB_MODE_INFO *mbmi,
-                                int order_idx, int *fwd_offset,
-                                int *bck_offset);
+                                int order_idx, int *fwd_offset, int *bck_offset,
+                                int is_compound);
 #endif  // CONFIG_JNT_COMP
 
 #ifdef __cplusplus
diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c
index 7752e49..d885269 100644
--- a/av1/decoder/decodemv.c
+++ b/av1/decoder/decodemv.c
@@ -2155,7 +2155,7 @@
   is_compound = has_second_ref(mbmi);
 
 #if CONFIG_JNT_COMP
-  if (has_two_sided_comp_refs(cm, mbmi)) {
+  if (is_compound) {
     const int comp_index_ctx = get_comp_index_context(cm, xd);
 #if CONFIG_NEW_MULTISYMBOL
     mbmi->compound_idx = aom_read_symbol(
@@ -2166,8 +2166,6 @@
 #endif  // CONFIG_NEW_MULTISYMBOL
     if (xd->counts)
       ++xd->counts->compound_index[comp_index_ctx][mbmi->compound_idx];
-  } else {
-    mbmi->compound_idx = 1;
   }
 #endif  // CONFIG_JNT_COMP
 
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index ada55a6..550afde 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1499,13 +1499,13 @@
 
 #if CONFIG_JNT_COMP
 #if CONFIG_NEW_MULTISYMBOL
-    if (has_two_sided_comp_refs(cm, mbmi)) {
+    if (has_second_ref(mbmi)) {
       const int comp_index_ctx = get_comp_index_context(cm, xd);
       aom_write_symbol(w, mbmi->compound_idx,
                        ec_ctx->compound_index_cdf[comp_index_ctx], 2);
     }
 #else
-    if (has_two_sided_comp_refs(cm, mbmi)) {
+    if (has_second_ref(mbmi)) {
       const int comp_index_ctx = get_comp_index_context(cm, xd);
       aom_write(w, mbmi->compound_idx,
                 ec_ctx->compound_index_probs[comp_index_ctx]);
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 399fa85..8e6aac2 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -660,7 +660,7 @@
 
 #if CONFIG_JNT_COMP
 #if !CONFIG_NEW_MULTISYMBOL
-  if (has_two_sided_comp_refs(cm, mbmi)) {
+  if (has_second_ref(mbmi)) {
     const int comp_index_ctx = get_comp_index_context(cm, xd);
     ++td->counts->compound_index[comp_index_ctx][mbmi->compound_idx];
   }
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 02408b6..abfa631 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -5990,7 +5990,7 @@
 #if CONFIG_JNT_COMP
     const int order_idx = id != 0;
     av1_jnt_comp_weight_assign(cm, mbmi, order_idx, &xd->jcp_param.fwd_offset,
-                               &xd->jcp_param.bck_offset);
+                               &xd->jcp_param.bck_offset, 1);
 #endif  // CONFIG_JNT_COMP
 
     // Do compound motion search on the current reference frame.
@@ -6706,7 +6706,7 @@
 
 #if CONFIG_JNT_COMP
   av1_jnt_comp_weight_assign(cm, mbmi, 0, &xd->jcp_param.fwd_offset,
-                             &xd->jcp_param.bck_offset);
+                             &xd->jcp_param.bck_offset, 1);
 #endif  // CONFIG_JNT_COMP
 
   if (scaled_ref_frame) {
@@ -8245,7 +8245,7 @@
 #endif  // CONFIG_COMPOUND_SINGLEREF
 
 #if CONFIG_JNT_COMP
-  if (has_two_sided_comp_refs(cm, mbmi)) {
+  if (is_comp_pred) {
     const int comp_index_ctx = get_comp_index_context(cm, xd);
     rd_stats->rate += x->comp_idx_cost[comp_index_ctx][mbmi->compound_idx];
   }
@@ -9975,7 +9975,6 @@
 #if CONFIG_JNT_COMP
       {
         int cum_rate = rate2;
-        mbmi->compound_idx = 1;
         MB_MODE_INFO backup_mbmi = *mbmi;
 
         int_mv backup_frame_mv[MB_MODE_COUNT][TOTAL_REFS_PER_FRAME];
@@ -9993,7 +9992,6 @@
 
         for (int comp_idx = 0; comp_idx < 1 + has_second_ref(mbmi);
              ++comp_idx) {
-          if (comp_idx == 0 && !has_two_sided_comp_refs(cm, mbmi)) continue;
           RD_STATS rd_stats, rd_stats_y, rd_stats_uv;
           av1_init_rd_stats(&rd_stats);
           av1_init_rd_stats(&rd_stats_y);
@@ -10146,8 +10144,6 @@
           ref_idx = sidx;
           if (has_second_ref(mbmi)) ref_idx /= 2;
           mbmi->compound_idx = sidx % 2;
-          if (mbmi->compound_idx == 0 && !has_two_sided_comp_refs(cm, mbmi))
-            continue;
 #endif  // CONFIG_JNT_COMP
 
           av1_invalid_rd_stats(&tmp_rd_stats);