JNT_COMP: Refactor code

The refactoring serves two purposes:
1. Separate the code paths for jnt_comp and the original compound average
computation. This provides a function interface for jnt_comp while leaving
the original compound average computation unchanged. In the near future,
SIMD functions can be added for jnt_comp using this interface.

2. The previous implementation used a hack on second_pred (the weights
were stored in second_pred[4096] and second_pred[4097]), which may cause
a segmentation fault when the test clip is small, as reported in Issue
944. This refactoring removes the hack and makes it possible to address
the seg fault problem in the future.

Change-Id: Idd2cb99f6c77dae03d32ccfa1f9cbed1d7eed067
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 36c3aed..92249a1 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -5789,56 +5789,6 @@
   return 1;
 }
 
-#if CONFIG_JNT_COMP
-static void jnt_comp_weight_assign(const AV1_COMMON *cm,
-                                   const MB_MODE_INFO *mbmi, int order_idx,
-                                   uint8_t *second_pred) {
-  if (mbmi->compound_idx) {
-    second_pred[4096] = -1;
-    second_pred[4097] = -1;
-  } else {
-    int bck_idx = cm->frame_refs[mbmi->ref_frame[0] - LAST_FRAME].idx;
-    int fwd_idx = cm->frame_refs[mbmi->ref_frame[1] - LAST_FRAME].idx;
-    int bck_frame_index = 0, fwd_frame_index = 0;
-    int cur_frame_index = cm->cur_frame->cur_frame_offset;
-
-    if (bck_idx >= 0) {
-      bck_frame_index = cm->buffer_pool->frame_bufs[bck_idx].cur_frame_offset;
-    }
-
-    if (fwd_idx >= 0) {
-      fwd_frame_index = cm->buffer_pool->frame_bufs[fwd_idx].cur_frame_offset;
-    }
-
-    const double fwd = abs(fwd_frame_index - cur_frame_index);
-    const double bck = abs(cur_frame_index - bck_frame_index);
-    int order;
-    double ratio;
-
-    if (COMPOUND_WEIGHT_MODE == DIST) {
-      if (fwd > bck) {
-        ratio = (bck != 0) ? fwd / bck : 5.0;
-        order = 0;
-      } else {
-        ratio = (fwd != 0) ? bck / fwd : 5.0;
-        order = 1;
-      }
-      int quant_dist_idx;
-      for (quant_dist_idx = 0; quant_dist_idx < 4; ++quant_dist_idx) {
-        if (ratio < quant_dist_category[quant_dist_idx]) break;
-      }
-      second_pred[4096] =
-          quant_dist_lookup_table[order_idx][quant_dist_idx][order];
-      second_pred[4097] =
-          quant_dist_lookup_table[order_idx][quant_dist_idx][1 - order];
-    } else {
-      second_pred[4096] = (DIST_PRECISION >> 1);
-      second_pred[4097] = (DIST_PRECISION >> 1);
-    }
-  }
-}
-#endif  // CONFIG_JNT_COMP
-
 static void joint_motion_search(const AV1_COMP *cpi, MACROBLOCK *x,
                                 BLOCK_SIZE bsize, int_mv *frame_mv,
 #if CONFIG_COMPOUND_SINGLEREF
@@ -5901,13 +5851,8 @@
 
 // Prediction buffer from second frame.
 #if CONFIG_HIGHBITDEPTH
-#if CONFIG_JNT_COMP
-  DECLARE_ALIGNED(16, uint16_t, second_pred_alloc_16[MAX_SB_SQUARE + 2]);
-  uint8_t *second_pred;
-#else
   DECLARE_ALIGNED(16, uint16_t, second_pred_alloc_16[MAX_SB_SQUARE]);
   uint8_t *second_pred;
-#endif  // CONFIG_JNT_COMP
 #else   // CONFIG_HIGHBITDEPTH
   DECLARE_ALIGNED(16, uint8_t, second_pred[MAX_SB_SQUARE]);
 #endif  // CONFIG_HIGHBITDEPTH
@@ -6046,7 +5991,8 @@
 
 #if CONFIG_JNT_COMP
     const int order_idx = id != 0;
-    jnt_comp_weight_assign(cm, mbmi, order_idx, second_pred);
+    av1_jnt_comp_weight_assign(cm, mbmi, order_idx, &xd->jcp_param.fwd_offset,
+                               &xd->jcp_param.bck_offset, 1);
 #endif  // CONFIG_JNT_COMP
 
     // Do compound motion search on the current reference frame.
@@ -6761,7 +6707,8 @@
 #endif  // CONFIG_HIGHBITDEPTH
 
 #if CONFIG_JNT_COMP
-  jnt_comp_weight_assign(cm, mbmi, 0, second_pred);
+  av1_jnt_comp_weight_assign(cm, mbmi, 0, &xd->jcp_param.fwd_offset,
+                             &xd->jcp_param.bck_offset, 1);
 #endif  // CONFIG_JNT_COMP
 
   if (scaled_ref_frame) {
@@ -6930,11 +6877,7 @@
 
 // Prediction buffer from second frame.
 #if CONFIG_HIGHBITDEPTH
-#if CONFIG_JNT_COMP
-  DECLARE_ALIGNED(16, uint16_t, second_pred_alloc_16[MAX_SB_SQUARE + 2]);
-#else
   DECLARE_ALIGNED(16, uint16_t, second_pred_alloc_16[MAX_SB_SQUARE]);
-#endif  // CONFIG_JNT_COMP
   uint8_t *second_pred;
   if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
     second_pred = CONVERT_TO_BYTEPTR(second_pred_alloc_16);