JNT_COMP: Refactor code

The refactoring serves two purposes:
1. Separate the code paths for jnt_comp and the original compound average
computation. It provides a function interface for jnt_comp while leaving
the original compound average computation unchanged. In the near future,
SIMD functions can be added for jnt_comp using this interface.

2. The previous implementation used a hack on second_pred, which may cause
a segmentation fault when the test clip is small, as reported in Issue
944. This refactoring removes the hack and makes it possible to address
the segmentation fault problem in the future.

Change-Id: Idd2cb99f6c77dae03d32ccfa1f9cbed1d7eed067
diff --git a/av1/encoder/encoder.c b/av1/encoder/encoder.c
index 1990e70..5d6ab41 100644
--- a/av1/encoder/encoder.c
+++ b/av1/encoder/encoder.c
@@ -2758,6 +2758,19 @@
   av1_set_speed_features_framesize_independent(cpi);
   av1_set_speed_features_framesize_dependent(cpi);
 
+#if CONFIG_JNT_COMP
+#define BFP(BT, SDF, SDAF, VF, SVF, SVAF, SDX3F, SDX8F, SDX4DF, JSDAF, JSVAF) \
+  cpi->fn_ptr[BT].sdf = SDF;                                                  \
+  cpi->fn_ptr[BT].sdaf = SDAF;                                                \
+  cpi->fn_ptr[BT].vf = VF;                                                    \
+  cpi->fn_ptr[BT].svf = SVF;                                                  \
+  cpi->fn_ptr[BT].svaf = SVAF;                                                \
+  cpi->fn_ptr[BT].sdx3f = SDX3F;                                              \
+  cpi->fn_ptr[BT].sdx8f = SDX8F;                                              \
+  cpi->fn_ptr[BT].sdx4df = SDX4DF;                                            \
+  cpi->fn_ptr[BT].jsdaf = JSDAF;                                              \
+  cpi->fn_ptr[BT].jsvaf = JSVAF;
+#else  // CONFIG_JNT_COMP
 #define BFP(BT, SDF, SDAF, VF, SVF, SVAF, SDX3F, SDX8F, SDX4DF) \
   cpi->fn_ptr[BT].sdf = SDF;                                    \
   cpi->fn_ptr[BT].sdaf = SDAF;                                  \
@@ -2767,7 +2780,142 @@
   cpi->fn_ptr[BT].sdx3f = SDX3F;                                \
   cpi->fn_ptr[BT].sdx8f = SDX8F;                                \
   cpi->fn_ptr[BT].sdx4df = SDX4DF;
+#endif  // CONFIG_JNT_COMP
 
+#if CONFIG_JNT_COMP
+#if CONFIG_EXT_PARTITION_TYPES
+  BFP(BLOCK_4X16, aom_sad4x16, aom_sad4x16_avg, aom_variance4x16,
+      aom_sub_pixel_variance4x16, aom_sub_pixel_avg_variance4x16, NULL, NULL,
+      aom_sad4x16x4d, aom_jnt_sad4x16_avg_c,
+      aom_jnt_sub_pixel_avg_variance4x16_c)
+
+  BFP(BLOCK_16X4, aom_sad16x4, aom_sad16x4_avg, aom_variance16x4,
+      aom_sub_pixel_variance16x4, aom_sub_pixel_avg_variance16x4, NULL, NULL,
+      aom_sad16x4x4d, aom_jnt_sad16x4_avg_c,
+      aom_jnt_sub_pixel_avg_variance16x4_c)
+
+  BFP(BLOCK_8X32, aom_sad8x32, aom_sad8x32_avg, aom_variance8x32,
+      aom_sub_pixel_variance8x32, aom_sub_pixel_avg_variance8x32, NULL, NULL,
+      aom_sad8x32x4d, aom_jnt_sad8x32_avg_c,
+      aom_jnt_sub_pixel_avg_variance8x32_c)
+
+  BFP(BLOCK_32X8, aom_sad32x8, aom_sad32x8_avg, aom_variance32x8,
+      aom_sub_pixel_variance32x8, aom_sub_pixel_avg_variance32x8, NULL, NULL,
+      aom_sad32x8x4d, aom_jnt_sad32x8_avg_c,
+      aom_jnt_sub_pixel_avg_variance32x8_c)
+
+  BFP(BLOCK_16X64, aom_sad16x64, aom_sad16x64_avg, aom_variance16x64,
+      aom_sub_pixel_variance16x64, aom_sub_pixel_avg_variance16x64, NULL, NULL,
+      aom_sad16x64x4d, aom_jnt_sad16x64_avg_c,
+      aom_jnt_sub_pixel_avg_variance16x64_c)
+
+  BFP(BLOCK_64X16, aom_sad64x16, aom_sad64x16_avg, aom_variance64x16,
+      aom_sub_pixel_variance64x16, aom_sub_pixel_avg_variance64x16, NULL, NULL,
+      aom_sad64x16x4d, aom_jnt_sad64x16_avg_c,
+      aom_jnt_sub_pixel_avg_variance64x16_c)
+
+#if CONFIG_EXT_PARTITION
+  BFP(BLOCK_32X128, aom_sad32x128, aom_sad32x128_avg, aom_variance32x128,
+      aom_sub_pixel_variance32x128, aom_sub_pixel_avg_variance32x128, NULL,
+      NULL, aom_sad32x128x4d, aom_jnt_sad32x128_avg_c,
+      aom_jnt_sub_pixel_avg_variance32x128_c)
+
+  BFP(BLOCK_128X32, aom_sad128x32, aom_sad128x32_avg, aom_variance128x32,
+      aom_sub_pixel_variance128x32, aom_sub_pixel_avg_variance128x32, NULL,
+      NULL, aom_sad128x32x4d, aom_jnt_sad128x32_avg_c,
+      aom_jnt_sub_pixel_avg_variance128x32_c)
+#endif  // CONFIG_EXT_PARTITION
+#endif  // CONFIG_EXT_PARTITION_TYPES
+
+#if CONFIG_EXT_PARTITION
+  BFP(BLOCK_128X128, aom_sad128x128, aom_sad128x128_avg, aom_variance128x128,
+      aom_sub_pixel_variance128x128, aom_sub_pixel_avg_variance128x128,
+      aom_sad128x128x3, aom_sad128x128x8, aom_sad128x128x4d,
+      aom_jnt_sad128x128_avg_c, aom_jnt_sub_pixel_avg_variance128x128_c)
+
+  BFP(BLOCK_128X64, aom_sad128x64, aom_sad128x64_avg, aom_variance128x64,
+      aom_sub_pixel_variance128x64, aom_sub_pixel_avg_variance128x64, NULL,
+      NULL, aom_sad128x64x4d, aom_jnt_sad128x64_avg_c,
+      aom_jnt_sub_pixel_avg_variance128x64_c)
+
+  BFP(BLOCK_64X128, aom_sad64x128, aom_sad64x128_avg, aom_variance64x128,
+      aom_sub_pixel_variance64x128, aom_sub_pixel_avg_variance64x128, NULL,
+      NULL, aom_sad64x128x4d, aom_jnt_sad64x128_avg_c,
+      aom_jnt_sub_pixel_avg_variance64x128_c)
+#endif  // CONFIG_EXT_PARTITION
+
+  BFP(BLOCK_32X16, aom_sad32x16, aom_sad32x16_avg, aom_variance32x16,
+      aom_sub_pixel_variance32x16, aom_sub_pixel_avg_variance32x16, NULL, NULL,
+      aom_sad32x16x4d, aom_jnt_sad32x16_avg_c,
+      aom_jnt_sub_pixel_avg_variance32x16_c)
+
+  BFP(BLOCK_16X32, aom_sad16x32, aom_sad16x32_avg, aom_variance16x32,
+      aom_sub_pixel_variance16x32, aom_sub_pixel_avg_variance16x32, NULL, NULL,
+      aom_sad16x32x4d, aom_jnt_sad16x32_avg_c,
+      aom_jnt_sub_pixel_avg_variance16x32_c)
+
+  BFP(BLOCK_64X32, aom_sad64x32, aom_sad64x32_avg, aom_variance64x32,
+      aom_sub_pixel_variance64x32, aom_sub_pixel_avg_variance64x32, NULL, NULL,
+      aom_sad64x32x4d, aom_jnt_sad64x32_avg_c,
+      aom_jnt_sub_pixel_avg_variance64x32_c)
+
+  BFP(BLOCK_32X64, aom_sad32x64, aom_sad32x64_avg, aom_variance32x64,
+      aom_sub_pixel_variance32x64, aom_sub_pixel_avg_variance32x64, NULL, NULL,
+      aom_sad32x64x4d, aom_jnt_sad32x64_avg_c,
+      aom_jnt_sub_pixel_avg_variance32x64_c)
+
+  BFP(BLOCK_32X32, aom_sad32x32, aom_sad32x32_avg, aom_variance32x32,
+      aom_sub_pixel_variance32x32, aom_sub_pixel_avg_variance32x32,
+      aom_sad32x32x3, aom_sad32x32x8, aom_sad32x32x4d, aom_jnt_sad32x32_avg_c,
+      aom_jnt_sub_pixel_avg_variance32x32_c)
+
+  BFP(BLOCK_64X64, aom_sad64x64, aom_sad64x64_avg, aom_variance64x64,
+      aom_sub_pixel_variance64x64, aom_sub_pixel_avg_variance64x64,
+      aom_sad64x64x3, aom_sad64x64x8, aom_sad64x64x4d, aom_jnt_sad64x64_avg_c,
+      aom_jnt_sub_pixel_avg_variance64x64_c)
+
+  BFP(BLOCK_16X16, aom_sad16x16, aom_sad16x16_avg, aom_variance16x16,
+      aom_sub_pixel_variance16x16, aom_sub_pixel_avg_variance16x16,
+      aom_sad16x16x3, aom_sad16x16x8, aom_sad16x16x4d, aom_jnt_sad16x16_avg_c,
+      aom_jnt_sub_pixel_avg_variance16x16_c)
+
+  BFP(BLOCK_16X8, aom_sad16x8, aom_sad16x8_avg, aom_variance16x8,
+      aom_sub_pixel_variance16x8, aom_sub_pixel_avg_variance16x8, aom_sad16x8x3,
+      aom_sad16x8x8, aom_sad16x8x4d, aom_jnt_sad16x8_avg_c,
+      aom_jnt_sub_pixel_avg_variance16x8_c)
+
+  BFP(BLOCK_8X16, aom_sad8x16, aom_sad8x16_avg, aom_variance8x16,
+      aom_sub_pixel_variance8x16, aom_sub_pixel_avg_variance8x16, aom_sad8x16x3,
+      aom_sad8x16x8, aom_sad8x16x4d, aom_jnt_sad8x16_avg_c,
+      aom_jnt_sub_pixel_avg_variance8x16_c)
+
+  BFP(BLOCK_8X8, aom_sad8x8, aom_sad8x8_avg, aom_variance8x8,
+      aom_sub_pixel_variance8x8, aom_sub_pixel_avg_variance8x8, aom_sad8x8x3,
+      aom_sad8x8x8, aom_sad8x8x4d, aom_jnt_sad8x8_avg_c,
+      aom_jnt_sub_pixel_avg_variance8x8_c)
+
+  BFP(BLOCK_8X4, aom_sad8x4, aom_sad8x4_avg, aom_variance8x4,
+      aom_sub_pixel_variance8x4, aom_sub_pixel_avg_variance8x4, NULL,
+      aom_sad8x4x8, aom_sad8x4x4d, aom_jnt_sad8x4_avg_c,
+      aom_jnt_sub_pixel_avg_variance8x4_c)
+
+  BFP(BLOCK_4X8, aom_sad4x8, aom_sad4x8_avg, aom_variance4x8,
+      aom_sub_pixel_variance4x8, aom_sub_pixel_avg_variance4x8, NULL,
+      aom_sad4x8x8, aom_sad4x8x4d, aom_jnt_sad4x8_avg_c,
+      aom_jnt_sub_pixel_avg_variance4x8_c)
+
+  BFP(BLOCK_4X4, aom_sad4x4, aom_sad4x4_avg, aom_variance4x4,
+      aom_sub_pixel_variance4x4, aom_sub_pixel_avg_variance4x4, aom_sad4x4x3,
+      aom_sad4x4x8, aom_sad4x4x4d, aom_jnt_sad4x4_avg_c,
+      aom_jnt_sub_pixel_avg_variance4x4_c)
+
+  BFP(BLOCK_2X2, NULL, NULL, aom_variance2x2, NULL, NULL, NULL, NULL, NULL,
+      NULL, NULL)
+  BFP(BLOCK_2X4, NULL, NULL, aom_variance2x4, NULL, NULL, NULL, NULL, NULL,
+      NULL, NULL)
+  BFP(BLOCK_4X2, NULL, NULL, aom_variance4x2, NULL, NULL, NULL, NULL, NULL,
+      NULL, NULL)
+#else  // CONFIG_JNT_COMP
 #if CONFIG_EXT_PARTITION_TYPES
   BFP(BLOCK_4X16, aom_sad4x16, aom_sad4x16_avg, aom_variance4x16,
       aom_sub_pixel_variance4x16, aom_sub_pixel_avg_variance4x16, NULL, NULL,
@@ -2818,59 +2966,6 @@
       NULL, aom_sad64x128x4d)
 #endif  // CONFIG_EXT_PARTITION
 
-#if CONFIG_JNT_COMP
-  BFP(BLOCK_32X16, aom_sad32x16, aom_sad32x16_avg_c, aom_variance32x16,
-      aom_sub_pixel_variance32x16, aom_sub_pixel_avg_variance32x16, NULL, NULL,
-      aom_sad32x16x4d)
-
-  BFP(BLOCK_16X32, aom_sad16x32, aom_sad16x32_avg_c, aom_variance16x32,
-      aom_sub_pixel_variance16x32, aom_sub_pixel_avg_variance16x32, NULL, NULL,
-      aom_sad16x32x4d)
-
-  BFP(BLOCK_64X32, aom_sad64x32, aom_sad64x32_avg_c, aom_variance64x32,
-      aom_sub_pixel_variance64x32, aom_sub_pixel_avg_variance64x32, NULL, NULL,
-      aom_sad64x32x4d)
-
-  BFP(BLOCK_32X64, aom_sad32x64, aom_sad32x64_avg_c, aom_variance32x64,
-      aom_sub_pixel_variance32x64, aom_sub_pixel_avg_variance32x64, NULL, NULL,
-      aom_sad32x64x4d)
-
-  BFP(BLOCK_32X32, aom_sad32x32, aom_sad32x32_avg_c, aom_variance32x32,
-      aom_sub_pixel_variance32x32, aom_sub_pixel_avg_variance32x32,
-      aom_sad32x32x3, aom_sad32x32x8, aom_sad32x32x4d)
-
-  BFP(BLOCK_64X64, aom_sad64x64, aom_sad64x64_avg_c, aom_variance64x64,
-      aom_sub_pixel_variance64x64, aom_sub_pixel_avg_variance64x64,
-      aom_sad64x64x3, aom_sad64x64x8, aom_sad64x64x4d)
-
-  BFP(BLOCK_16X16, aom_sad16x16, aom_sad16x16_avg_c, aom_variance16x16,
-      aom_sub_pixel_variance16x16, aom_sub_pixel_avg_variance16x16,
-      aom_sad16x16x3, aom_sad16x16x8, aom_sad16x16x4d)
-
-  BFP(BLOCK_16X8, aom_sad16x8, aom_sad16x8_avg_c, aom_variance16x8,
-      aom_sub_pixel_variance16x8, aom_sub_pixel_avg_variance16x8, aom_sad16x8x3,
-      aom_sad16x8x8, aom_sad16x8x4d)
-
-  BFP(BLOCK_8X16, aom_sad8x16, aom_sad8x16_avg_c, aom_variance8x16,
-      aom_sub_pixel_variance8x16, aom_sub_pixel_avg_variance8x16, aom_sad8x16x3,
-      aom_sad8x16x8, aom_sad8x16x4d)
-
-  BFP(BLOCK_8X8, aom_sad8x8, aom_sad8x8_avg_c, aom_variance8x8,
-      aom_sub_pixel_variance8x8, aom_sub_pixel_avg_variance8x8, aom_sad8x8x3,
-      aom_sad8x8x8, aom_sad8x8x4d)
-
-  BFP(BLOCK_8X4, aom_sad8x4, aom_sad8x4_avg_c, aom_variance8x4,
-      aom_sub_pixel_variance8x4, aom_sub_pixel_avg_variance8x4, NULL,
-      aom_sad8x4x8, aom_sad8x4x4d)
-
-  BFP(BLOCK_4X8, aom_sad4x8, aom_sad4x8_avg_c, aom_variance4x8,
-      aom_sub_pixel_variance4x8, aom_sub_pixel_avg_variance4x8, NULL,
-      aom_sad4x8x8, aom_sad4x8x4d)
-
-  BFP(BLOCK_4X4, aom_sad4x4, aom_sad4x4_avg_c, aom_variance4x4,
-      aom_sub_pixel_variance4x4, aom_sub_pixel_avg_variance4x4, aom_sad4x4x3,
-      aom_sad4x4x8, aom_sad4x4x4d)
-#else
   BFP(BLOCK_32X16, aom_sad32x16, aom_sad32x16_avg, aom_variance32x16,
       aom_sub_pixel_variance32x16, aom_sub_pixel_avg_variance32x16, NULL, NULL,
       aom_sad32x16x4d)
@@ -2922,11 +3017,11 @@
   BFP(BLOCK_4X4, aom_sad4x4, aom_sad4x4_avg, aom_variance4x4,
       aom_sub_pixel_variance4x4, aom_sub_pixel_avg_variance4x4, aom_sad4x4x3,
       aom_sad4x4x8, aom_sad4x4x4d)
-#endif  // CONFIG_JNT_COMP
 
   BFP(BLOCK_2X2, NULL, NULL, aom_variance2x2, NULL, NULL, NULL, NULL, NULL)
   BFP(BLOCK_2X4, NULL, NULL, aom_variance2x4, NULL, NULL, NULL, NULL, NULL)
   BFP(BLOCK_4X2, NULL, NULL, aom_variance4x2, NULL, NULL, NULL, NULL, NULL)
+#endif  // CONFIG_JNT_COMP
 
 #define OBFP(BT, OSDF, OVF, OSVF) \
   cpi->fn_ptr[BT].osdf = OSDF;    \
diff --git a/av1/encoder/mcomp.c b/av1/encoder/mcomp.c
index f4d132e..63f8105 100644
--- a/av1/encoder/mcomp.c
+++ b/av1/encoder/mcomp.c
@@ -176,6 +176,39 @@
 }
 
 /* checks if (r, c) has better score than previous best */
+#if CONFIG_JNT_COMP
+#define CHECK_BETTER(v, r, c)                                                \
+  if (c >= minc && c <= maxc && r >= minr && r <= maxr) {                    \
+    MV this_mv = { r, c };                                                   \
+    v = mv_err_cost(&this_mv, ref_mv, mvjcost, mvcost, error_per_bit);       \
+    if (second_pred == NULL) {                                               \
+      thismse = vfp->svf(pre(y, y_stride, r, c), y_stride, sp(c), sp(r),     \
+                         src_address, src_stride, &sse);                     \
+    } else if (mask) {                                                       \
+      thismse = vfp->msvf(pre(y, y_stride, r, c), y_stride, sp(c), sp(r),    \
+                          src_address, src_stride, second_pred, mask,        \
+                          mask_stride, invert_mask, &sse);                   \
+    } else {                                                                 \
+      if (xd->jcp_param.fwd_offset != -1 && xd->jcp_param.bck_offset != -1)  \
+        thismse = vfp->jsvaf(pre(y, y_stride, r, c), y_stride, sp(c), sp(r), \
+                             src_address, src_stride, &sse, second_pred,     \
+                             &xd->jcp_param);                                \
+      else                                                                   \
+        thismse = vfp->svaf(pre(y, y_stride, r, c), y_stride, sp(c), sp(r),  \
+                            src_address, src_stride, &sse, second_pred);     \
+    }                                                                        \
+    v += thismse;                                                            \
+    if (v < besterr) {                                                       \
+      besterr = v;                                                           \
+      br = r;                                                                \
+      bc = c;                                                                \
+      *distortion = thismse;                                                 \
+      *sse1 = sse;                                                           \
+    }                                                                        \
+  } else {                                                                   \
+    v = INT_MAX;                                                             \
+  }
+#else  // CONFIG_JNT_COMP
 #define CHECK_BETTER(v, r, c)                                             \
   if (c >= minc && c <= maxc && r >= minr && r <= maxr) {                 \
     MV this_mv = { r, c };                                                \
@@ -201,6 +234,7 @@
   } else {                                                                \
     v = INT_MAX;                                                          \
   }
+#endif  // CONFIG_JNT_COMP
 
 #define CHECK_BETTER0(v, r, c) CHECK_BETTER(v, r, c)
 
@@ -345,15 +379,18 @@
           vfp->vf(CONVERT_TO_BYTEPTR(comp_pred16), w, src, src_stride, sse1);
     } else {
       DECLARE_ALIGNED(16, uint8_t, comp_pred[MAX_SB_SQUARE]);
-      if (mask)
+      if (mask) {
         aom_comp_mask_pred(comp_pred, second_pred, w, h, y + offset, y_stride,
                            mask, mask_stride, invert_mask);
-      else
+      } else {
 #if CONFIG_JNT_COMP
-        aom_comp_avg_pred_c(comp_pred, second_pred, w, h, y + offset, y_stride);
-#else
-        aom_comp_avg_pred(comp_pred, second_pred, w, h, y + offset, y_stride);
+        if (xd->jcp_param.fwd_offset != -1 && xd->jcp_param.bck_offset != -1)
+          aom_jnt_comp_avg_pred(comp_pred, second_pred, w, h, y + offset,
+                                y_stride, &xd->jcp_param);
+        else
 #endif  // CONFIG_JNT_COMP
+          aom_comp_avg_pred(comp_pred, second_pred, w, h, y + offset, y_stride);
+      }
       besterr = vfp->vf(comp_pred, w, src, src_stride, sse1);
     }
   } else {
@@ -365,15 +402,18 @@
   (void)xd;
   if (second_pred != NULL) {
     DECLARE_ALIGNED(16, uint8_t, comp_pred[MAX_SB_SQUARE]);
-    if (mask)
+    if (mask) {
       aom_comp_mask_pred(comp_pred, second_pred, w, h, y + offset, y_stride,
                          mask, mask_stride, invert_mask);
-    else
+    } else {
 #if CONFIG_JNT_COMP
-      aom_comp_avg_pred_c(comp_pred, second_pred, w, h, y + offset, y_stride);
-#else
-      aom_comp_avg_pred(comp_pred, second_pred, w, h, y + offset, y_stride);
+      if (xd->jcp_param.fwd_offset != -1 && xd->jcp_param.bck_offset != -1)
+        aom_jnt_comp_avg_pred(comp_pred, second_pred, w, h, y + offset,
+                              y_stride, &xd->jcp_param);
+      else
 #endif  // CONFIG_JNT_COMP
+        aom_comp_avg_pred(comp_pred, second_pred, w, h, y + offset, y_stride);
+    }
     besterr = vfp->vf(comp_pred, w, src, src_stride, sse1);
   } else {
     besterr = vfp->vf(y + offset, y_stride, src, src_stride, sse1);
@@ -666,18 +706,21 @@
   (void)xd;
 #endif  // CONFIG_HIGHBITDEPTH
     if (second_pred != NULL) {
-      if (mask)
+      if (mask) {
         aom_comp_mask_upsampled_pred(pred, second_pred, w, h, subpel_x_q3,
                                      subpel_y_q3, y, y_stride, mask,
                                      mask_stride, invert_mask);
-      else
+      } else {
 #if CONFIG_JNT_COMP
-        aom_comp_avg_upsampled_pred_c(pred, second_pred, w, h, subpel_x_q3,
-                                      subpel_y_q3, y, y_stride);
-#else
-      aom_comp_avg_upsampled_pred(pred, second_pred, w, h, subpel_x_q3,
-                                  subpel_y_q3, y, y_stride);
+        if (xd->jcp_param.fwd_offset != -1 && xd->jcp_param.bck_offset != -1)
+          aom_jnt_comp_avg_upsampled_pred(pred, second_pred, w, h, subpel_x_q3,
+                                          subpel_y_q3, y, y_stride,
+                                          &xd->jcp_param);
+        else
 #endif  // CONFIG_JNT_COMP
+          aom_comp_avg_upsampled_pred(pred, second_pred, w, h, subpel_x_q3,
+                                      subpel_y_q3, y, y_stride);
+      }
     } else {
       aom_upsampled_pred(pred, w, h, subpel_x_q3, subpel_y_q3, y, y_stride);
     }
@@ -771,16 +814,25 @@
                                          mask_stride, invert_mask, w, h, &sse);
         } else {
           const uint8_t *const pre_address = pre(y, y_stride, tr, tc);
-          if (second_pred == NULL)
+          if (second_pred == NULL) {
             thismse = vfp->svf(pre_address, y_stride, sp(tc), sp(tr),
                                src_address, src_stride, &sse);
-          else if (mask)
+          } else if (mask) {
             thismse = vfp->msvf(pre_address, y_stride, sp(tc), sp(tr),
                                 src_address, src_stride, second_pred, mask,
                                 mask_stride, invert_mask, &sse);
-          else
-            thismse = vfp->svaf(pre_address, y_stride, sp(tc), sp(tr),
-                                src_address, src_stride, &sse, second_pred);
+          } else {
+#if CONFIG_JNT_COMP
+            if (xd->jcp_param.fwd_offset != -1 &&
+                xd->jcp_param.bck_offset != -1)
+              thismse =
+                  vfp->jsvaf(pre_address, y_stride, sp(tc), sp(tr), src_address,
+                             src_stride, &sse, second_pred, &xd->jcp_param);
+            else
+#endif  // CONFIG_JNT_COMP
+              thismse = vfp->svaf(pre_address, y_stride, sp(tc), sp(tr),
+                                  src_address, src_stride, &sse, second_pred);
+          }
         }
 
         cost_array[idx] = thismse + mv_err_cost(&this_mv, ref_mv, mvjcost,
@@ -814,16 +866,24 @@
       } else {
         const uint8_t *const pre_address = pre(y, y_stride, tr, tc);
 
-        if (second_pred == NULL)
+        if (second_pred == NULL) {
           thismse = vfp->svf(pre_address, y_stride, sp(tc), sp(tr), src_address,
                              src_stride, &sse);
-        else if (mask)
+        } else if (mask) {
           thismse = vfp->msvf(pre_address, y_stride, sp(tc), sp(tr),
                               src_address, src_stride, second_pred, mask,
                               mask_stride, invert_mask, &sse);
-        else
-          thismse = vfp->svaf(pre_address, y_stride, sp(tc), sp(tr),
-                              src_address, src_stride, &sse, second_pred);
+        } else {
+#if CONFIG_JNT_COMP
+          if (xd->jcp_param.fwd_offset != -1 && xd->jcp_param.bck_offset != -1)
+            thismse =
+                vfp->jsvaf(pre_address, y_stride, sp(tc), sp(tr), src_address,
+                           src_stride, &sse, second_pred, &xd->jcp_param);
+          else
+#endif  // CONFIG_JNT_COMP
+            thismse = vfp->svaf(pre_address, y_stride, sp(tc), sp(tr),
+                                src_address, src_stride, &sse, second_pred);
+        }
       }
 
       cost_array[4] = thismse + mv_err_cost(&this_mv, ref_mv, mvjcost, mvcost,
@@ -1397,11 +1457,21 @@
   const MV mv = { best_mv->row * 8, best_mv->col * 8 };
   unsigned int unused;
 
-  return vfp->svaf(get_buf_from_mv(in_what, best_mv), in_what->stride, 0, 0,
-                   what->buf, what->stride, &unused, second_pred) +
-         (use_mvcost ? mv_err_cost(&mv, center_mv, x->nmvjointcost, x->mvcost,
-                                   x->errorperbit)
-                     : 0);
+#if CONFIG_JNT_COMP
+  if (xd->jcp_param.fwd_offset != -1 && xd->jcp_param.bck_offset != -1)
+    return vfp->jsvaf(get_buf_from_mv(in_what, best_mv), in_what->stride, 0, 0,
+                      what->buf, what->stride, &unused, second_pred,
+                      &xd->jcp_param) +
+           (use_mvcost ? mv_err_cost(&mv, center_mv, x->nmvjointcost, x->mvcost,
+                                     x->errorperbit)
+                       : 0);
+  else
+#endif  // CONFIG_JNT_COMP
+    return vfp->svaf(get_buf_from_mv(in_what, best_mv), in_what->stride, 0, 0,
+                     what->buf, what->stride, &unused, second_pred) +
+           (use_mvcost ? mv_err_cost(&mv, center_mv, x->nmvjointcost, x->mvcost,
+                                     x->errorperbit)
+                       : 0);
 }
 
 int av1_get_mvpred_mask_var(const MACROBLOCK *x, const MV *best_mv,
@@ -2405,16 +2475,25 @@
 
   clamp_mv(best_mv, x->mv_limits.col_min, x->mv_limits.col_max,
            x->mv_limits.row_min, x->mv_limits.row_max);
-  if (mask)
+  if (mask) {
     best_sad = fn_ptr->msdf(what->buf, what->stride,
                             get_buf_from_mv(in_what, best_mv), in_what->stride,
                             second_pred, mask, mask_stride, invert_mask) +
                mvsad_err_cost(x, best_mv, &fcenter_mv, error_per_bit);
-  else
-    best_sad =
-        fn_ptr->sdaf(what->buf, what->stride, get_buf_from_mv(in_what, best_mv),
-                     in_what->stride, second_pred) +
-        mvsad_err_cost(x, best_mv, &fcenter_mv, error_per_bit);
+  } else {
+#if CONFIG_JNT_COMP
+    if (xd->jcp_param.fwd_offset != -1 && xd->jcp_param.bck_offset != -1)
+      best_sad = fn_ptr->jsdaf(what->buf, what->stride,
+                               get_buf_from_mv(in_what, best_mv),
+                               in_what->stride, second_pred, &xd->jcp_param) +
+                 mvsad_err_cost(x, best_mv, &fcenter_mv, error_per_bit);
+    else
+#endif  // CONFIG_JNT_COMP
+      best_sad = fn_ptr->sdaf(what->buf, what->stride,
+                              get_buf_from_mv(in_what, best_mv),
+                              in_what->stride, second_pred) +
+                 mvsad_err_cost(x, best_mv, &fcenter_mv, error_per_bit);
+  }
 
   for (i = 0; i < search_range; ++i) {
     int best_site = -1;
@@ -2425,14 +2504,22 @@
 
       if (is_mv_in(&x->mv_limits, &mv)) {
         unsigned int sad;
-        if (mask)
+        if (mask) {
           sad = fn_ptr->msdf(what->buf, what->stride,
                              get_buf_from_mv(in_what, &mv), in_what->stride,
                              second_pred, mask, mask_stride, invert_mask);
-        else
-          sad = fn_ptr->sdaf(what->buf, what->stride,
-                             get_buf_from_mv(in_what, &mv), in_what->stride,
-                             second_pred);
+        } else {
+#if CONFIG_JNT_COMP
+          if (xd->jcp_param.fwd_offset != -1 && xd->jcp_param.bck_offset != -1)
+            sad = fn_ptr->jsdaf(what->buf, what->stride,
+                                get_buf_from_mv(in_what, &mv), in_what->stride,
+                                second_pred, &xd->jcp_param);
+          else
+#endif  // CONFIG_JNT_COMP
+            sad = fn_ptr->sdaf(what->buf, what->stride,
+                               get_buf_from_mv(in_what, &mv), in_what->stride,
+                               second_pred);
+        }
         if (sad < best_sad) {
           sad += mvsad_err_cost(x, &mv, &fcenter_mv, error_per_bit);
           if (sad < best_sad) {
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 36c3aed..92249a1 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -5789,56 +5789,6 @@
   return 1;
 }
 
-#if CONFIG_JNT_COMP
-static void jnt_comp_weight_assign(const AV1_COMMON *cm,
-                                   const MB_MODE_INFO *mbmi, int order_idx,
-                                   uint8_t *second_pred) {
-  if (mbmi->compound_idx) {
-    second_pred[4096] = -1;
-    second_pred[4097] = -1;
-  } else {
-    int bck_idx = cm->frame_refs[mbmi->ref_frame[0] - LAST_FRAME].idx;
-    int fwd_idx = cm->frame_refs[mbmi->ref_frame[1] - LAST_FRAME].idx;
-    int bck_frame_index = 0, fwd_frame_index = 0;
-    int cur_frame_index = cm->cur_frame->cur_frame_offset;
-
-    if (bck_idx >= 0) {
-      bck_frame_index = cm->buffer_pool->frame_bufs[bck_idx].cur_frame_offset;
-    }
-
-    if (fwd_idx >= 0) {
-      fwd_frame_index = cm->buffer_pool->frame_bufs[fwd_idx].cur_frame_offset;
-    }
-
-    const double fwd = abs(fwd_frame_index - cur_frame_index);
-    const double bck = abs(cur_frame_index - bck_frame_index);
-    int order;
-    double ratio;
-
-    if (COMPOUND_WEIGHT_MODE == DIST) {
-      if (fwd > bck) {
-        ratio = (bck != 0) ? fwd / bck : 5.0;
-        order = 0;
-      } else {
-        ratio = (fwd != 0) ? bck / fwd : 5.0;
-        order = 1;
-      }
-      int quant_dist_idx;
-      for (quant_dist_idx = 0; quant_dist_idx < 4; ++quant_dist_idx) {
-        if (ratio < quant_dist_category[quant_dist_idx]) break;
-      }
-      second_pred[4096] =
-          quant_dist_lookup_table[order_idx][quant_dist_idx][order];
-      second_pred[4097] =
-          quant_dist_lookup_table[order_idx][quant_dist_idx][1 - order];
-    } else {
-      second_pred[4096] = (DIST_PRECISION >> 1);
-      second_pred[4097] = (DIST_PRECISION >> 1);
-    }
-  }
-}
-#endif  // CONFIG_JNT_COMP
-
 static void joint_motion_search(const AV1_COMP *cpi, MACROBLOCK *x,
                                 BLOCK_SIZE bsize, int_mv *frame_mv,
 #if CONFIG_COMPOUND_SINGLEREF
@@ -5901,13 +5851,8 @@
 
 // Prediction buffer from second frame.
 #if CONFIG_HIGHBITDEPTH
-#if CONFIG_JNT_COMP
-  DECLARE_ALIGNED(16, uint16_t, second_pred_alloc_16[MAX_SB_SQUARE + 2]);
-  uint8_t *second_pred;
-#else
   DECLARE_ALIGNED(16, uint16_t, second_pred_alloc_16[MAX_SB_SQUARE]);
   uint8_t *second_pred;
-#endif  // CONFIG_JNT_COMP
 #else   // CONFIG_HIGHBITDEPTH
   DECLARE_ALIGNED(16, uint8_t, second_pred[MAX_SB_SQUARE]);
 #endif  // CONFIG_HIGHBITDEPTH
@@ -6046,7 +5991,8 @@
 
 #if CONFIG_JNT_COMP
     const int order_idx = id != 0;
-    jnt_comp_weight_assign(cm, mbmi, order_idx, second_pred);
+    av1_jnt_comp_weight_assign(cm, mbmi, order_idx, &xd->jcp_param.fwd_offset,
+                               &xd->jcp_param.bck_offset, 1);
 #endif  // CONFIG_JNT_COMP
 
     // Do compound motion search on the current reference frame.
@@ -6761,7 +6707,8 @@
 #endif  // CONFIG_HIGHBITDEPTH
 
 #if CONFIG_JNT_COMP
-  jnt_comp_weight_assign(cm, mbmi, 0, second_pred);
+  av1_jnt_comp_weight_assign(cm, mbmi, 0, &xd->jcp_param.fwd_offset,
+                             &xd->jcp_param.bck_offset, 1);
 #endif  // CONFIG_JNT_COMP
 
   if (scaled_ref_frame) {
@@ -6930,11 +6877,7 @@
 
 // Prediction buffer from second frame.
 #if CONFIG_HIGHBITDEPTH
-#if CONFIG_JNT_COMP
-  DECLARE_ALIGNED(16, uint16_t, second_pred_alloc_16[MAX_SB_SQUARE + 2]);
-#else
   DECLARE_ALIGNED(16, uint16_t, second_pred_alloc_16[MAX_SB_SQUARE]);
-#endif  // CONFIG_JNT_COMP
   uint8_t *second_pred;
   if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
     second_pred = CONVERT_TO_BYTEPTR(second_pred_alloc_16);