Unify parameters of hbd/lbd compound pred function

Update the parameter list of following functions,

aom_highbd_upsampled_pred,
aom_highbd_comp_avg_upsampled_pred,
aom_highbd_jnt_comp_avg_upsampled_pred,
aom_highbd_comp_avg_pred,
aom_highbd_jnt_comp_avg_pred,
aom_highbd_comp_mask_pred

Same as other variance functions, always pass
"uint8_t* comp_pred" in, recover when actually
using it. So that the function prototype can
match with lbd version.

Change-Id: I159057fbe470f3c760a2fd010d60d052ac2ca4f9
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index 2639e39..8bd8911 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -905,18 +905,17 @@
 
 
   add_proto qw/void aom_highbd_upsampled_pred/, "MACROBLOCKD *xd, const struct AV1Common *const cm, int mi_row, int mi_col,
-                                                 const MV *const mv, uint16_t *comp_pred, int width, int height, int subpel_x_q3,
+                                                 const MV *const mv, uint8_t *comp_pred8, int width, int height, int subpel_x_q3,
                                                  int subpel_y_q3, const uint8_t *ref8, int ref_stride, int bd, int subpel_search";
   specialize qw/aom_highbd_upsampled_pred sse2/;
 
   add_proto qw/void aom_highbd_comp_avg_upsampled_pred/, "MACROBLOCKD *xd, const struct AV1Common *const cm, int mi_row, int mi_col,
-                                                          const MV *const mv, uint16_t *comp_pred, const uint8_t *pred8, int width,
-                                                          int height, int subpel_x_q3, int subpel_y_q3, const uint8_t *ref8, int ref_stride,
-                                                          int bd, int subpel_search";
+                                                          const MV *const mv, uint8_t *comp_pred8, const uint8_t *pred8, int width,
+                                                          int height, int subpel_x_q3, int subpel_y_q3, const uint8_t *ref8, int ref_stride, int bd, int subpel_search";
   specialize qw/aom_highbd_comp_avg_upsampled_pred sse2/;
 
   add_proto qw/void aom_highbd_jnt_comp_avg_upsampled_pred/, "MACROBLOCKD *xd, const struct AV1Common *const cm, int mi_row, int mi_col,
-                                                              const MV *const mv, uint16_t *comp_pred, const uint8_t *pred8, int width,
+                                                              const MV *const mv, uint8_t *comp_pred8, const uint8_t *pred8, int width,
                                                               int height, int subpel_x_q3, int subpel_y_q3, const uint8_t *ref8,
                                                               int ref_stride, int bd, const JNT_COMP_PARAMS *jcp_param, int subpel_search";
   specialize qw/aom_highbd_jnt_comp_avg_upsampled_pred sse2/;
@@ -1337,9 +1336,9 @@
     add_proto qw/unsigned int aom_highbd_12_mse8x8/, "const uint8_t *src_ptr, int  source_stride, const uint8_t *ref_ptr, int  recon_stride, unsigned int *sse";
     specialize qw/aom_highbd_12_mse8x8 sse2/;
 
-    add_proto qw/void aom_highbd_comp_avg_pred/, "uint16_t *comp_pred, const uint8_t *pred8, int width, int height, const uint8_t *ref8, int ref_stride";
+    add_proto qw/void aom_highbd_comp_avg_pred/, "uint8_t *comp_pred8, const uint8_t *pred8, int width, int height, const uint8_t *ref8, int ref_stride";
 
-    add_proto qw/void aom_highbd_jnt_comp_avg_pred/, "uint16_t *comp_pred, const uint8_t *pred8, int width, int height, const uint8_t *ref8, int ref_stride, const JNT_COMP_PARAMS *jcp_param";
+    add_proto qw/void aom_highbd_jnt_comp_avg_pred/, "uint8_t *comp_pred8, const uint8_t *pred8, int width, int height, const uint8_t *ref8, int ref_stride, const JNT_COMP_PARAMS *jcp_param";
     specialize qw/aom_highbd_jnt_comp_avg_pred sse2/;
 
     #
@@ -1566,7 +1565,7 @@
   add_proto qw/void aom_comp_mask_pred/, "uint8_t *comp_pred, const uint8_t *pred, int width, int height, const uint8_t *ref, int ref_stride, const uint8_t *mask, int mask_stride, int invert_mask";
   specialize qw/aom_comp_mask_pred ssse3 avx2/;
 
-  add_proto qw/void aom_highbd_comp_mask_pred/, "uint16_t *comp_pred, const uint8_t *pred8, int width, int height, const uint8_t *ref8, int ref_stride, const uint8_t *mask, int mask_stride, int invert_mask";
+  add_proto qw/void aom_highbd_comp_mask_pred/, "uint8_t *comp_pred, const uint8_t *pred8, int width, int height, const uint8_t *ref8, int ref_stride, const uint8_t *mask, int mask_stride, int invert_mask";
   specialize qw/aom_highbd_comp_mask_pred sse2 avx2/;
 
 }  # CONFIG_AV1_ENCODER
diff --git a/aom_dsp/sad.c b/aom_dsp/sad.c
index ede4c58..1e24df4 100644
--- a/aom_dsp/sad.c
+++ b/aom_dsp/sad.c
@@ -200,15 +200,16 @@
       const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride,  \
       const uint8_t *second_pred) {                                            \
     uint16_t comp_pred[m * n];                                                 \
-    aom_highbd_comp_avg_pred(comp_pred, second_pred, m, n, ref, ref_stride);   \
+    aom_highbd_comp_avg_pred(CONVERT_TO_BYTEPTR(comp_pred), second_pred, m, n, \
+                             ref, ref_stride);                                 \
     return highbd_sadb(src, src_stride, comp_pred, m, m, n);                   \
   }                                                                            \
   unsigned int aom_highbd_jnt_sad##m##x##n##_avg_c(                            \
       const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride,  \
       const uint8_t *second_pred, const JNT_COMP_PARAMS *jcp_param) {          \
     uint16_t comp_pred[m * n];                                                 \
-    aom_highbd_jnt_comp_avg_pred(comp_pred, second_pred, m, n, ref,            \
-                                 ref_stride, jcp_param);                       \
+    aom_highbd_jnt_comp_avg_pred(CONVERT_TO_BYTEPTR(comp_pred), second_pred,   \
+                                 m, n, ref, ref_stride, jcp_param);            \
     return highbd_sadb(src, src_stride, comp_pred, m, m, n);                   \
   }
 
diff --git a/aom_dsp/variance.c b/aom_dsp/variance.c
index 09fe2fd..bb77174 100644
--- a/aom_dsp/variance.c
+++ b/aom_dsp/variance.c
@@ -708,125 +708,125 @@
                                                dst, dst_stride, sse);        \
   }
 
-#define HIGHBD_SUBPIX_AVG_VAR(W, H)                                           \
-  uint32_t aom_highbd_8_sub_pixel_avg_variance##W##x##H##_c(                  \
-      const uint8_t *src, int src_stride, int xoffset, int yoffset,           \
-      const uint8_t *dst, int dst_stride, uint32_t *sse,                      \
-      const uint8_t *second_pred) {                                           \
-    uint16_t fdata3[(H + 1) * W];                                             \
-    uint16_t temp2[H * W];                                                    \
-    DECLARE_ALIGNED(16, uint16_t, temp3[H * W]);                              \
-                                                                              \
-    aom_highbd_var_filter_block2d_bil_first_pass(                             \
-        src, fdata3, src_stride, 1, H + 1, W, bilinear_filters_2t[xoffset]);  \
-    aom_highbd_var_filter_block2d_bil_second_pass(                            \
-        fdata3, temp2, W, W, H, W, bilinear_filters_2t[yoffset]);             \
-                                                                              \
-    aom_highbd_comp_avg_pred_c(temp3, second_pred, W, H,                      \
-                               CONVERT_TO_BYTEPTR(temp2), W);                 \
-                                                                              \
-    return aom_highbd_8_variance##W##x##H##_c(CONVERT_TO_BYTEPTR(temp3), W,   \
-                                              dst, dst_stride, sse);          \
-  }                                                                           \
-                                                                              \
-  uint32_t aom_highbd_10_sub_pixel_avg_variance##W##x##H##_c(                 \
-      const uint8_t *src, int src_stride, int xoffset, int yoffset,           \
-      const uint8_t *dst, int dst_stride, uint32_t *sse,                      \
-      const uint8_t *second_pred) {                                           \
-    uint16_t fdata3[(H + 1) * W];                                             \
-    uint16_t temp2[H * W];                                                    \
-    DECLARE_ALIGNED(16, uint16_t, temp3[H * W]);                              \
-                                                                              \
-    aom_highbd_var_filter_block2d_bil_first_pass(                             \
-        src, fdata3, src_stride, 1, H + 1, W, bilinear_filters_2t[xoffset]);  \
-    aom_highbd_var_filter_block2d_bil_second_pass(                            \
-        fdata3, temp2, W, W, H, W, bilinear_filters_2t[yoffset]);             \
-                                                                              \
-    aom_highbd_comp_avg_pred_c(temp3, second_pred, W, H,                      \
-                               CONVERT_TO_BYTEPTR(temp2), W);                 \
-                                                                              \
-    return aom_highbd_10_variance##W##x##H##_c(CONVERT_TO_BYTEPTR(temp3), W,  \
-                                               dst, dst_stride, sse);         \
-  }                                                                           \
-                                                                              \
-  uint32_t aom_highbd_12_sub_pixel_avg_variance##W##x##H##_c(                 \
-      const uint8_t *src, int src_stride, int xoffset, int yoffset,           \
-      const uint8_t *dst, int dst_stride, uint32_t *sse,                      \
-      const uint8_t *second_pred) {                                           \
-    uint16_t fdata3[(H + 1) * W];                                             \
-    uint16_t temp2[H * W];                                                    \
-    DECLARE_ALIGNED(16, uint16_t, temp3[H * W]);                              \
-                                                                              \
-    aom_highbd_var_filter_block2d_bil_first_pass(                             \
-        src, fdata3, src_stride, 1, H + 1, W, bilinear_filters_2t[xoffset]);  \
-    aom_highbd_var_filter_block2d_bil_second_pass(                            \
-        fdata3, temp2, W, W, H, W, bilinear_filters_2t[yoffset]);             \
-                                                                              \
-    aom_highbd_comp_avg_pred_c(temp3, second_pred, W, H,                      \
-                               CONVERT_TO_BYTEPTR(temp2), W);                 \
-                                                                              \
-    return aom_highbd_12_variance##W##x##H##_c(CONVERT_TO_BYTEPTR(temp3), W,  \
-                                               dst, dst_stride, sse);         \
-  }                                                                           \
-                                                                              \
-  uint32_t aom_highbd_8_jnt_sub_pixel_avg_variance##W##x##H##_c(              \
-      const uint8_t *src, int src_stride, int xoffset, int yoffset,           \
-      const uint8_t *dst, int dst_stride, uint32_t *sse,                      \
-      const uint8_t *second_pred, const JNT_COMP_PARAMS *jcp_param) {         \
-    uint16_t fdata3[(H + 1) * W];                                             \
-    uint16_t temp2[H * W];                                                    \
-    DECLARE_ALIGNED(16, uint16_t, temp3[H * W]);                              \
-                                                                              \
-    aom_highbd_var_filter_block2d_bil_first_pass(                             \
-        src, fdata3, src_stride, 1, H + 1, W, bilinear_filters_2t[xoffset]);  \
-    aom_highbd_var_filter_block2d_bil_second_pass(                            \
-        fdata3, temp2, W, W, H, W, bilinear_filters_2t[yoffset]);             \
-                                                                              \
-    aom_highbd_jnt_comp_avg_pred(temp3, second_pred, W, H,                    \
-                                 CONVERT_TO_BYTEPTR(temp2), W, jcp_param);    \
-                                                                              \
-    return aom_highbd_8_variance##W##x##H(CONVERT_TO_BYTEPTR(temp3), W, dst,  \
-                                          dst_stride, sse);                   \
-  }                                                                           \
-                                                                              \
-  uint32_t aom_highbd_10_jnt_sub_pixel_avg_variance##W##x##H##_c(             \
-      const uint8_t *src, int src_stride, int xoffset, int yoffset,           \
-      const uint8_t *dst, int dst_stride, uint32_t *sse,                      \
-      const uint8_t *second_pred, const JNT_COMP_PARAMS *jcp_param) {         \
-    uint16_t fdata3[(H + 1) * W];                                             \
-    uint16_t temp2[H * W];                                                    \
-    DECLARE_ALIGNED(16, uint16_t, temp3[H * W]);                              \
-                                                                              \
-    aom_highbd_var_filter_block2d_bil_first_pass(                             \
-        src, fdata3, src_stride, 1, H + 1, W, bilinear_filters_2t[xoffset]);  \
-    aom_highbd_var_filter_block2d_bil_second_pass(                            \
-        fdata3, temp2, W, W, H, W, bilinear_filters_2t[yoffset]);             \
-                                                                              \
-    aom_highbd_jnt_comp_avg_pred(temp3, second_pred, W, H,                    \
-                                 CONVERT_TO_BYTEPTR(temp2), W, jcp_param);    \
-                                                                              \
-    return aom_highbd_10_variance##W##x##H(CONVERT_TO_BYTEPTR(temp3), W, dst, \
-                                           dst_stride, sse);                  \
-  }                                                                           \
-                                                                              \
-  uint32_t aom_highbd_12_jnt_sub_pixel_avg_variance##W##x##H##_c(             \
-      const uint8_t *src, int src_stride, int xoffset, int yoffset,           \
-      const uint8_t *dst, int dst_stride, uint32_t *sse,                      \
-      const uint8_t *second_pred, const JNT_COMP_PARAMS *jcp_param) {         \
-    uint16_t fdata3[(H + 1) * W];                                             \
-    uint16_t temp2[H * W];                                                    \
-    DECLARE_ALIGNED(16, uint16_t, temp3[H * W]);                              \
-                                                                              \
-    aom_highbd_var_filter_block2d_bil_first_pass(                             \
-        src, fdata3, src_stride, 1, H + 1, W, bilinear_filters_2t[xoffset]);  \
-    aom_highbd_var_filter_block2d_bil_second_pass(                            \
-        fdata3, temp2, W, W, H, W, bilinear_filters_2t[yoffset]);             \
-                                                                              \
-    aom_highbd_jnt_comp_avg_pred(temp3, second_pred, W, H,                    \
-                                 CONVERT_TO_BYTEPTR(temp2), W, jcp_param);    \
-                                                                              \
-    return aom_highbd_12_variance##W##x##H(CONVERT_TO_BYTEPTR(temp3), W, dst, \
-                                           dst_stride, sse);                  \
+#define HIGHBD_SUBPIX_AVG_VAR(W, H)                                            \
+  uint32_t aom_highbd_8_sub_pixel_avg_variance##W##x##H##_c(                   \
+      const uint8_t *src, int src_stride, int xoffset, int yoffset,            \
+      const uint8_t *dst, int dst_stride, uint32_t *sse,                       \
+      const uint8_t *second_pred) {                                            \
+    uint16_t fdata3[(H + 1) * W];                                              \
+    uint16_t temp2[H * W];                                                     \
+    DECLARE_ALIGNED(16, uint16_t, temp3[H * W]);                               \
+                                                                               \
+    aom_highbd_var_filter_block2d_bil_first_pass(                              \
+        src, fdata3, src_stride, 1, H + 1, W, bilinear_filters_2t[xoffset]);   \
+    aom_highbd_var_filter_block2d_bil_second_pass(                             \
+        fdata3, temp2, W, W, H, W, bilinear_filters_2t[yoffset]);              \
+                                                                               \
+    aom_highbd_comp_avg_pred_c(CONVERT_TO_BYTEPTR(temp3), second_pred, W, H,   \
+                               CONVERT_TO_BYTEPTR(temp2), W);                  \
+                                                                               \
+    return aom_highbd_8_variance##W##x##H##_c(CONVERT_TO_BYTEPTR(temp3), W,    \
+                                              dst, dst_stride, sse);           \
+  }                                                                            \
+                                                                               \
+  uint32_t aom_highbd_10_sub_pixel_avg_variance##W##x##H##_c(                  \
+      const uint8_t *src, int src_stride, int xoffset, int yoffset,            \
+      const uint8_t *dst, int dst_stride, uint32_t *sse,                       \
+      const uint8_t *second_pred) {                                            \
+    uint16_t fdata3[(H + 1) * W];                                              \
+    uint16_t temp2[H * W];                                                     \
+    DECLARE_ALIGNED(16, uint16_t, temp3[H * W]);                               \
+                                                                               \
+    aom_highbd_var_filter_block2d_bil_first_pass(                              \
+        src, fdata3, src_stride, 1, H + 1, W, bilinear_filters_2t[xoffset]);   \
+    aom_highbd_var_filter_block2d_bil_second_pass(                             \
+        fdata3, temp2, W, W, H, W, bilinear_filters_2t[yoffset]);              \
+                                                                               \
+    aom_highbd_comp_avg_pred_c(CONVERT_TO_BYTEPTR(temp3), second_pred, W, H,   \
+                               CONVERT_TO_BYTEPTR(temp2), W);                  \
+                                                                               \
+    return aom_highbd_10_variance##W##x##H##_c(CONVERT_TO_BYTEPTR(temp3), W,   \
+                                               dst, dst_stride, sse);          \
+  }                                                                            \
+                                                                               \
+  uint32_t aom_highbd_12_sub_pixel_avg_variance##W##x##H##_c(                  \
+      const uint8_t *src, int src_stride, int xoffset, int yoffset,            \
+      const uint8_t *dst, int dst_stride, uint32_t *sse,                       \
+      const uint8_t *second_pred) {                                            \
+    uint16_t fdata3[(H + 1) * W];                                              \
+    uint16_t temp2[H * W];                                                     \
+    DECLARE_ALIGNED(16, uint16_t, temp3[H * W]);                               \
+                                                                               \
+    aom_highbd_var_filter_block2d_bil_first_pass(                              \
+        src, fdata3, src_stride, 1, H + 1, W, bilinear_filters_2t[xoffset]);   \
+    aom_highbd_var_filter_block2d_bil_second_pass(                             \
+        fdata3, temp2, W, W, H, W, bilinear_filters_2t[yoffset]);              \
+                                                                               \
+    aom_highbd_comp_avg_pred_c(CONVERT_TO_BYTEPTR(temp3), second_pred, W, H,   \
+                               CONVERT_TO_BYTEPTR(temp2), W);                  \
+                                                                               \
+    return aom_highbd_12_variance##W##x##H##_c(CONVERT_TO_BYTEPTR(temp3), W,   \
+                                               dst, dst_stride, sse);          \
+  }                                                                            \
+                                                                               \
+  uint32_t aom_highbd_8_jnt_sub_pixel_avg_variance##W##x##H##_c(               \
+      const uint8_t *src, int src_stride, int xoffset, int yoffset,            \
+      const uint8_t *dst, int dst_stride, uint32_t *sse,                       \
+      const uint8_t *second_pred, const JNT_COMP_PARAMS *jcp_param) {          \
+    uint16_t fdata3[(H + 1) * W];                                              \
+    uint16_t temp2[H * W];                                                     \
+    DECLARE_ALIGNED(16, uint16_t, temp3[H * W]);                               \
+                                                                               \
+    aom_highbd_var_filter_block2d_bil_first_pass(                              \
+        src, fdata3, src_stride, 1, H + 1, W, bilinear_filters_2t[xoffset]);   \
+    aom_highbd_var_filter_block2d_bil_second_pass(                             \
+        fdata3, temp2, W, W, H, W, bilinear_filters_2t[yoffset]);              \
+                                                                               \
+    aom_highbd_jnt_comp_avg_pred(CONVERT_TO_BYTEPTR(temp3), second_pred, W, H, \
+                                 CONVERT_TO_BYTEPTR(temp2), W, jcp_param);     \
+                                                                               \
+    return aom_highbd_8_variance##W##x##H(CONVERT_TO_BYTEPTR(temp3), W, dst,   \
+                                          dst_stride, sse);                    \
+  }                                                                            \
+                                                                               \
+  uint32_t aom_highbd_10_jnt_sub_pixel_avg_variance##W##x##H##_c(              \
+      const uint8_t *src, int src_stride, int xoffset, int yoffset,            \
+      const uint8_t *dst, int dst_stride, uint32_t *sse,                       \
+      const uint8_t *second_pred, const JNT_COMP_PARAMS *jcp_param) {          \
+    uint16_t fdata3[(H + 1) * W];                                              \
+    uint16_t temp2[H * W];                                                     \
+    DECLARE_ALIGNED(16, uint16_t, temp3[H * W]);                               \
+                                                                               \
+    aom_highbd_var_filter_block2d_bil_first_pass(                              \
+        src, fdata3, src_stride, 1, H + 1, W, bilinear_filters_2t[xoffset]);   \
+    aom_highbd_var_filter_block2d_bil_second_pass(                             \
+        fdata3, temp2, W, W, H, W, bilinear_filters_2t[yoffset]);              \
+                                                                               \
+    aom_highbd_jnt_comp_avg_pred(CONVERT_TO_BYTEPTR(temp3), second_pred, W, H, \
+                                 CONVERT_TO_BYTEPTR(temp2), W, jcp_param);     \
+                                                                               \
+    return aom_highbd_10_variance##W##x##H(CONVERT_TO_BYTEPTR(temp3), W, dst,  \
+                                           dst_stride, sse);                   \
+  }                                                                            \
+                                                                               \
+  uint32_t aom_highbd_12_jnt_sub_pixel_avg_variance##W##x##H##_c(              \
+      const uint8_t *src, int src_stride, int xoffset, int yoffset,            \
+      const uint8_t *dst, int dst_stride, uint32_t *sse,                       \
+      const uint8_t *second_pred, const JNT_COMP_PARAMS *jcp_param) {          \
+    uint16_t fdata3[(H + 1) * W];                                              \
+    uint16_t temp2[H * W];                                                     \
+    DECLARE_ALIGNED(16, uint16_t, temp3[H * W]);                               \
+                                                                               \
+    aom_highbd_var_filter_block2d_bil_first_pass(                              \
+        src, fdata3, src_stride, 1, H + 1, W, bilinear_filters_2t[xoffset]);   \
+    aom_highbd_var_filter_block2d_bil_second_pass(                             \
+        fdata3, temp2, W, W, H, W, bilinear_filters_2t[yoffset]);              \
+                                                                               \
+    aom_highbd_jnt_comp_avg_pred(CONVERT_TO_BYTEPTR(temp3), second_pred, W, H, \
+                                 CONVERT_TO_BYTEPTR(temp2), W, jcp_param);     \
+                                                                               \
+    return aom_highbd_12_variance##W##x##H(CONVERT_TO_BYTEPTR(temp3), W, dst,  \
+                                           dst_stride, sse);                   \
   }
 
 /* All three forms of the variance are available in the same sizes. */
@@ -869,12 +869,13 @@
 HIGHBD_MSE(8, 16)
 HIGHBD_MSE(8, 8)
 
-void aom_highbd_comp_avg_pred_c(uint16_t *comp_pred, const uint8_t *pred8,
+void aom_highbd_comp_avg_pred_c(uint8_t *comp_pred8, const uint8_t *pred8,
                                 int width, int height, const uint8_t *ref8,
                                 int ref_stride) {
   int i, j;
   uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
+  uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8);
   for (i = 0; i < height; ++i) {
     for (j = 0; j < width; ++j) {
       const int tmp = pred[j] + ref[j];
@@ -889,7 +890,7 @@
 void aom_highbd_upsampled_pred_c(MACROBLOCKD *xd,
                                  const struct AV1Common *const cm, int mi_row,
                                  int mi_col, const MV *const mv,
-                                 uint16_t *comp_pred, int width, int height,
+                                 uint8_t *comp_pred8, int width, int height,
                                  int subpel_x_q3, int subpel_y_q3,
                                  const uint8_t *ref8, int ref_stride, int bd,
                                  int subpel_search) {
@@ -905,8 +906,6 @@
     if (is_scaled) {
       // Note: This is mostly a copy from the >=8X8 case in
       // build_inter_predictors() function, with some small tweaks.
-      uint8_t *comp_pred8 = CONVERT_TO_BYTEPTR(comp_pred);
-
       // Some assumptions.
       const int plane = 0;
 
@@ -983,10 +982,9 @@
           : av1_get_interp_filter_params_with_block_size(EIGHTTAP_REGULAR, 8);
 
   if (!subpel_x_q3 && !subpel_y_q3) {
-    const uint16_t *ref;
-    int i;
-    ref = CONVERT_TO_SHORTPTR(ref8);
-    for (i = 0; i < height; i++) {
+    const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
+    uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8);
+    for (int i = 0; i < height; i++) {
       memcpy(comp_pred, ref, width * sizeof(*comp_pred));
       comp_pred += width;
       ref += ref_stride;
@@ -994,13 +992,13 @@
   } else if (!subpel_y_q3) {
     const int16_t *const kernel =
         av1_get_interp_filter_subpel_kernel(filter, subpel_x_q3 << 1);
-    aom_highbd_convolve8_horiz(ref8, ref_stride, CONVERT_TO_BYTEPTR(comp_pred),
-                               width, kernel, 16, NULL, -1, width, height, bd);
+    aom_highbd_convolve8_horiz(ref8, ref_stride, comp_pred8, width, kernel, 16,
+                               NULL, -1, width, height, bd);
   } else if (!subpel_x_q3) {
     const int16_t *const kernel =
         av1_get_interp_filter_subpel_kernel(filter, subpel_y_q3 << 1);
-    aom_highbd_convolve8_vert(ref8, ref_stride, CONVERT_TO_BYTEPTR(comp_pred),
-                              width, NULL, -1, kernel, 16, width, height, bd);
+    aom_highbd_convolve8_vert(ref8, ref_stride, comp_pred8, width, NULL, -1,
+                              kernel, 16, width, height, bd);
   } else {
     DECLARE_ALIGNED(16, uint16_t,
                     temp[((MAX_SB_SIZE + 16) + 16) * MAX_SB_SIZE]);
@@ -1017,20 +1015,21 @@
                                intermediate_height, bd);
     aom_highbd_convolve8_vert(
         CONVERT_TO_BYTEPTR(temp + MAX_SB_SIZE * ((filter->taps >> 1) - 1)),
-        MAX_SB_SIZE, CONVERT_TO_BYTEPTR(comp_pred), width, NULL, -1, kernel_y,
-        16, width, height, bd);
+        MAX_SB_SIZE, comp_pred8, width, NULL, -1, kernel_y, 16, width, height,
+        bd);
   }
 }
 
 void aom_highbd_comp_avg_upsampled_pred_c(
     MACROBLOCKD *xd, const struct AV1Common *const cm, int mi_row, int mi_col,
-    const MV *const mv, uint16_t *comp_pred, const uint8_t *pred8, int width,
+    const MV *const mv, uint8_t *comp_pred8, const uint8_t *pred8, int width,
     int height, int subpel_x_q3, int subpel_y_q3, const uint8_t *ref8,
     int ref_stride, int bd, int subpel_search) {
   int i, j;
 
   const uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
-  aom_highbd_upsampled_pred(xd, cm, mi_row, mi_col, mv, comp_pred, width,
+  uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8);
+  aom_highbd_upsampled_pred(xd, cm, mi_row, mi_col, mv, comp_pred8, width,
                             height, subpel_x_q3, subpel_y_q3, ref8, ref_stride,
                             bd, subpel_search);
   for (i = 0; i < height; ++i) {
@@ -1042,7 +1041,7 @@
   }
 }
 
-void aom_highbd_jnt_comp_avg_pred_c(uint16_t *comp_pred, const uint8_t *pred8,
+void aom_highbd_jnt_comp_avg_pred_c(uint8_t *comp_pred8, const uint8_t *pred8,
                                     int width, int height, const uint8_t *ref8,
                                     int ref_stride,
                                     const JNT_COMP_PARAMS *jcp_param) {
@@ -1051,6 +1050,7 @@
   const int bck_offset = jcp_param->bck_offset;
   uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
+  uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8);
 
   for (i = 0; i < height; ++i) {
     for (j = 0; j < width; ++j) {
@@ -1066,7 +1066,7 @@
 
 void aom_highbd_jnt_comp_avg_upsampled_pred_c(
     MACROBLOCKD *xd, const struct AV1Common *const cm, int mi_row, int mi_col,
-    const MV *const mv, uint16_t *comp_pred, const uint8_t *pred8, int width,
+    const MV *const mv, uint8_t *comp_pred8, const uint8_t *pred8, int width,
     int height, int subpel_x_q3, int subpel_y_q3, const uint8_t *ref8,
     int ref_stride, int bd, const JNT_COMP_PARAMS *jcp_param,
     int subpel_search) {
@@ -1074,8 +1074,8 @@
   const int fwd_offset = jcp_param->fwd_offset;
   const int bck_offset = jcp_param->bck_offset;
   const uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
-
-  aom_highbd_upsampled_pred(xd, cm, mi_row, mi_col, mv, comp_pred, width,
+  uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8);
+  aom_highbd_upsampled_pred(xd, cm, mi_row, mi_col, mv, comp_pred8, width,
                             height, subpel_x_q3, subpel_y_q3, ref8, ref_stride,
                             bd, subpel_search);
 
@@ -1172,13 +1172,14 @@
 MASK_SUBPIX_VAR(16, 64)
 MASK_SUBPIX_VAR(64, 16)
 
-void aom_highbd_comp_mask_pred_c(uint16_t *comp_pred, const uint8_t *pred8,
+void aom_highbd_comp_mask_pred_c(uint8_t *comp_pred8, const uint8_t *pred8,
                                  int width, int height, const uint8_t *ref8,
                                  int ref_stride, const uint8_t *mask,
                                  int mask_stride, int invert_mask) {
   int i, j;
   uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
+  uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8);
   for (i = 0; i < height; ++i) {
     for (j = 0; j < width; ++j) {
       if (!invert_mask)
@@ -1195,16 +1196,15 @@
 
 void aom_highbd_comp_mask_upsampled_pred(
     MACROBLOCKD *xd, const struct AV1Common *const cm, int mi_row, int mi_col,
-    const MV *const mv, uint16_t *comp_pred, const uint8_t *pred8, int width,
+    const MV *const mv, uint8_t *comp_pred8, const uint8_t *pred8, int width,
     int height, int subpel_x_q3, int subpel_y_q3, const uint8_t *ref8,
     int ref_stride, const uint8_t *mask, int mask_stride, int invert_mask,
     int bd, int subpel_search) {
-  aom_highbd_upsampled_pred(xd, cm, mi_row, mi_col, mv, comp_pred, width,
+  aom_highbd_upsampled_pred(xd, cm, mi_row, mi_col, mv, comp_pred8, width,
                             height, subpel_x_q3, subpel_y_q3, ref8, ref_stride,
                             bd, subpel_search);
-  aom_highbd_comp_mask_pred(comp_pred, pred8, width, height,
-                            CONVERT_TO_BYTEPTR(comp_pred), width, mask,
-                            mask_stride, invert_mask);
+  aom_highbd_comp_mask_pred(comp_pred8, pred8, width, height, comp_pred8, width,
+                            mask, mask_stride, invert_mask);
 }
 
 #define HIGHBD_MASK_SUBPIX_VAR(W, H)                                           \
@@ -1222,7 +1222,7 @@
     aom_highbd_var_filter_block2d_bil_second_pass(                             \
         fdata3, temp2, W, W, H, W, bilinear_filters_2t[yoffset]);              \
                                                                                \
-    aom_highbd_comp_mask_pred_c(temp3, second_pred, W, H,                      \
+    aom_highbd_comp_mask_pred_c(CONVERT_TO_BYTEPTR(temp3), second_pred, W, H,  \
                                 CONVERT_TO_BYTEPTR(temp2), W, msk, msk_stride, \
                                 invert_mask);                                  \
                                                                                \
@@ -1244,7 +1244,7 @@
     aom_highbd_var_filter_block2d_bil_second_pass(                             \
         fdata3, temp2, W, W, H, W, bilinear_filters_2t[yoffset]);              \
                                                                                \
-    aom_highbd_comp_mask_pred_c(temp3, second_pred, W, H,                      \
+    aom_highbd_comp_mask_pred_c(CONVERT_TO_BYTEPTR(temp3), second_pred, W, H,  \
                                 CONVERT_TO_BYTEPTR(temp2), W, msk, msk_stride, \
                                 invert_mask);                                  \
                                                                                \
@@ -1266,7 +1266,7 @@
     aom_highbd_var_filter_block2d_bil_second_pass(                             \
         fdata3, temp2, W, W, H, W, bilinear_filters_2t[yoffset]);              \
                                                                                \
-    aom_highbd_comp_mask_pred_c(temp3, second_pred, W, H,                      \
+    aom_highbd_comp_mask_pred_c(CONVERT_TO_BYTEPTR(temp3), second_pred, W, H,  \
                                 CONVERT_TO_BYTEPTR(temp2), W, msk, msk_stride, \
                                 invert_mask);                                  \
                                                                                \
diff --git a/aom_dsp/variance.h b/aom_dsp/variance.h
index a3e74b9..d29c6e4 100644
--- a/aom_dsp/variance.h
+++ b/aom_dsp/variance.h
@@ -79,7 +79,7 @@
 
 void aom_highbd_comp_mask_upsampled_pred(
     MACROBLOCKD *xd, const struct AV1Common *const cm, int mi_row, int mi_col,
-    const MV *const mv, uint16_t *comp_pred, const uint8_t *pred8, int width,
+    const MV *const mv, uint8_t *comp_pred8, const uint8_t *pred8, int width,
     int height, int subpel_x_q3, int subpel_y_q3, const uint8_t *ref8,
     int ref_stride, const uint8_t *mask, int mask_stride, int invert_mask,
     int bd, int subpel_search);
diff --git a/aom_dsp/x86/highbd_variance_sse2.c b/aom_dsp/x86/highbd_variance_sse2.c
index f855934..47b052a 100644
--- a/aom_dsp/x86/highbd_variance_sse2.c
+++ b/aom_dsp/x86/highbd_variance_sse2.c
@@ -593,7 +593,7 @@
 void aom_highbd_upsampled_pred_sse2(MACROBLOCKD *xd,
                                     const struct AV1Common *const cm,
                                     int mi_row, int mi_col, const MV *const mv,
-                                    uint16_t *comp_pred, int width, int height,
+                                    uint8_t *comp_pred8, int width, int height,
                                     int subpel_x_q3, int subpel_y_q3,
                                     const uint8_t *ref8, int ref_stride, int bd,
                                     int subpel_search) {
@@ -609,8 +609,6 @@
     if (is_scaled) {
       // Note: This is mostly a copy from the >=8X8 case in
       // build_inter_predictors() function, with some small tweaks.
-      uint8_t *comp_pred8 = CONVERT_TO_BYTEPTR(comp_pred);
-
       // Some assumptions.
       const int plane = 0;
 
@@ -686,6 +684,7 @@
 
   if (!subpel_x_q3 && !subpel_y_q3) {
     uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
+    uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8);
     if (width >= 8) {
       int i;
       assert(!(width & 7));
@@ -716,13 +715,13 @@
   } else if (!subpel_y_q3) {
     const int16_t *const kernel =
         av1_get_interp_filter_subpel_kernel(filter, subpel_x_q3 << 1);
-    aom_highbd_convolve8_horiz(ref8, ref_stride, CONVERT_TO_BYTEPTR(comp_pred),
-                               width, kernel, 16, NULL, -1, width, height, bd);
+    aom_highbd_convolve8_horiz(ref8, ref_stride, comp_pred8, width, kernel, 16,
+                               NULL, -1, width, height, bd);
   } else if (!subpel_x_q3) {
     const int16_t *const kernel =
         av1_get_interp_filter_subpel_kernel(filter, subpel_y_q3 << 1);
-    aom_highbd_convolve8_vert(ref8, ref_stride, CONVERT_TO_BYTEPTR(comp_pred),
-                              width, NULL, -1, kernel, 16, width, height, bd);
+    aom_highbd_convolve8_vert(ref8, ref_stride, comp_pred8, width, NULL, -1,
+                              kernel, 16, width, height, bd);
   } else {
     DECLARE_ALIGNED(16, uint16_t,
                     temp[((MAX_SB_SIZE + 16) + 16) * MAX_SB_SIZE]);
@@ -739,30 +738,29 @@
                                intermediate_height, bd);
     aom_highbd_convolve8_vert(
         CONVERT_TO_BYTEPTR(temp + MAX_SB_SIZE * ((filter->taps >> 1) - 1)),
-        MAX_SB_SIZE, CONVERT_TO_BYTEPTR(comp_pred), width, NULL, -1, kernel_y,
-        16, width, height, bd);
+        MAX_SB_SIZE, comp_pred8, width, NULL, -1, kernel_y, 16, width, height,
+        bd);
   }
 }
 
 void aom_highbd_comp_avg_upsampled_pred_sse2(
     MACROBLOCKD *xd, const struct AV1Common *const cm, int mi_row, int mi_col,
-    const MV *const mv, uint16_t *comp_pred, const uint8_t *pred8, int width,
+    const MV *const mv, uint8_t *comp_pred8, const uint8_t *pred8, int width,
     int height, int subpel_x_q3, int subpel_y_q3, const uint8_t *ref8,
     int ref_stride, int bd, int subpel_search) {
-  uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
-  int n;
-  int i;
-  aom_highbd_upsampled_pred(xd, cm, mi_row, mi_col, mv, comp_pred, width,
+  aom_highbd_upsampled_pred(xd, cm, mi_row, mi_col, mv, comp_pred8, width,
                             height, subpel_x_q3, subpel_y_q3, ref8, ref_stride,
                             bd, subpel_search);
+  uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
+  uint16_t *comp_pred16 = CONVERT_TO_SHORTPTR(comp_pred8);
   /*The total number of pixels must be a multiple of 8 (e.g., 4x4).*/
   assert(!(width * height & 7));
-  n = width * height >> 3;
-  for (i = 0; i < n; i++) {
-    __m128i s0 = _mm_loadu_si128((const __m128i *)comp_pred);
+  int n = width * height >> 3;
+  for (int i = 0; i < n; i++) {
+    __m128i s0 = _mm_loadu_si128((const __m128i *)comp_pred16);
     __m128i p0 = _mm_loadu_si128((const __m128i *)pred);
-    _mm_storeu_si128((__m128i *)comp_pred, _mm_avg_epu16(s0, p0));
-    comp_pred += 8;
+    _mm_storeu_si128((__m128i *)comp_pred16, _mm_avg_epu16(s0, p0));
+    comp_pred16 += 8;
     pred += 8;
   }
 }
@@ -782,7 +780,7 @@
   xx_storeu_128(result, shift);
 }
 
-void aom_highbd_jnt_comp_avg_pred_sse2(uint16_t *comp_pred,
+void aom_highbd_jnt_comp_avg_pred_sse2(uint8_t *comp_pred8,
                                        const uint8_t *pred8, int width,
                                        int height, const uint8_t *ref8,
                                        int ref_stride,
@@ -797,6 +795,7 @@
       _mm_set_epi16(round, round, round, round, round, round, round, round);
   uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
+  uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8);
 
   if (width >= 8) {
     // Read 8 pixels one row at a time
@@ -835,14 +834,14 @@
 
 void aom_highbd_jnt_comp_avg_upsampled_pred_sse2(
     MACROBLOCKD *xd, const struct AV1Common *const cm, int mi_row, int mi_col,
-    const MV *const mv, uint16_t *comp_pred, const uint8_t *pred8, int width,
+    const MV *const mv, uint8_t *comp_pred8, const uint8_t *pred8, int width,
     int height, int subpel_x_q3, int subpel_y_q3, const uint8_t *ref8,
     int ref_stride, int bd, const JNT_COMP_PARAMS *jcp_param,
     int subpel_search) {
   uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
   int n;
   int i;
-  aom_highbd_upsampled_pred(xd, cm, mi_row, mi_col, mv, comp_pred, width,
+  aom_highbd_upsampled_pred(xd, cm, mi_row, mi_col, mv, comp_pred8, width,
                             height, subpel_x_q3, subpel_y_q3, ref8, ref_stride,
                             bd, subpel_search);
   assert(!(width * height & 7));
@@ -856,13 +855,14 @@
   const __m128i r =
       _mm_set_epi16(round, round, round, round, round, round, round, round);
 
+  uint16_t *comp_pred16 = CONVERT_TO_SHORTPTR(comp_pred8);
   for (i = 0; i < n; i++) {
-    __m128i p0 = xx_loadu_128(comp_pred);
+    __m128i p0 = xx_loadu_128(comp_pred16);
     __m128i p1 = xx_loadu_128(pred);
 
-    highbd_compute_jnt_comp_avg(&p0, &p1, &w0, &w1, &r, comp_pred);
+    highbd_compute_jnt_comp_avg(&p0, &p1, &w0, &w1, &r, comp_pred16);
 
-    comp_pred += 8;
+    comp_pred16 += 8;
     pred += 8;
   }
 }
diff --git a/aom_dsp/x86/highbd_variance_sse4.c b/aom_dsp/x86/highbd_variance_sse4.c
index 6c247a9..df5449a 100644
--- a/aom_dsp/x86/highbd_variance_sse4.c
+++ b/aom_dsp/x86/highbd_variance_sse4.c
@@ -168,8 +168,8 @@
   aom_highbd_var_filter_block2d_bil_second_pass(fdata3, temp2, 4, 4, 4, 4,
                                                 bilinear_filters_2t[yoffset]);
 
-  aom_highbd_comp_avg_pred(temp3, second_pred, 4, 4, CONVERT_TO_BYTEPTR(temp2),
-                           4);
+  aom_highbd_comp_avg_pred(CONVERT_TO_BYTEPTR(temp3), second_pred, 4, 4,
+                           CONVERT_TO_BYTEPTR(temp2), 4);
 
   return aom_highbd_8_variance4x4(CONVERT_TO_BYTEPTR(temp3), 4, dst, dst_stride,
                                   sse);
@@ -188,8 +188,8 @@
   aom_highbd_var_filter_block2d_bil_second_pass(fdata3, temp2, 4, 4, 4, 4,
                                                 bilinear_filters_2t[yoffset]);
 
-  aom_highbd_comp_avg_pred(temp3, second_pred, 4, 4, CONVERT_TO_BYTEPTR(temp2),
-                           4);
+  aom_highbd_comp_avg_pred(CONVERT_TO_BYTEPTR(temp3), second_pred, 4, 4,
+                           CONVERT_TO_BYTEPTR(temp2), 4);
 
   return aom_highbd_10_variance4x4(CONVERT_TO_BYTEPTR(temp3), 4, dst,
                                    dst_stride, sse);
@@ -208,8 +208,8 @@
   aom_highbd_var_filter_block2d_bil_second_pass(fdata3, temp2, 4, 4, 4, 4,
                                                 bilinear_filters_2t[yoffset]);
 
-  aom_highbd_comp_avg_pred(temp3, second_pred, 4, 4, CONVERT_TO_BYTEPTR(temp2),
-                           4);
+  aom_highbd_comp_avg_pred(CONVERT_TO_BYTEPTR(temp3), second_pred, 4, 4,
+                           CONVERT_TO_BYTEPTR(temp2), 4);
 
   return aom_highbd_12_variance4x4(CONVERT_TO_BYTEPTR(temp3), 4, dst,
                                    dst_stride, sse);
diff --git a/aom_dsp/x86/variance_avx2.c b/aom_dsp/x86/variance_avx2.c
index a7ac2c9..800aef1 100644
--- a/aom_dsp/x86/variance_avx2.c
+++ b/aom_dsp/x86/variance_avx2.c
@@ -433,13 +433,14 @@
   return comp;
 }
 
-void aom_highbd_comp_mask_pred_avx2(uint16_t *comp_pred, const uint8_t *pred8,
+void aom_highbd_comp_mask_pred_avx2(uint8_t *comp_pred8, const uint8_t *pred8,
                                     int width, int height, const uint8_t *ref8,
                                     int ref_stride, const uint8_t *mask,
                                     int mask_stride, int invert_mask) {
   int i = 0;
   uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
+  uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8);
   const uint16_t *src0 = invert_mask ? pred : ref;
   const uint16_t *src1 = invert_mask ? ref : pred;
   const int stride0 = invert_mask ? width : ref_stride;
diff --git a/aom_dsp/x86/variance_sse2.c b/aom_dsp/x86/variance_sse2.c
index a5a0acc..1a27fd2 100644
--- a/aom_dsp/x86/variance_sse2.c
+++ b/aom_dsp/x86/variance_sse2.c
@@ -694,11 +694,12 @@
   return comp;
 }
 
-void aom_highbd_comp_mask_pred_sse2(uint16_t *comp_pred, const uint8_t *pred8,
+void aom_highbd_comp_mask_pred_sse2(uint8_t *comp_pred8, const uint8_t *pred8,
                                     int width, int height, const uint8_t *ref8,
                                     int ref_stride, const uint8_t *mask,
                                     int mask_stride, int invert_mask) {
   int i = 0;
+  uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8);
   uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
   const uint16_t *src0 = invert_mask ? pred : ref;
diff --git a/av1/encoder/mcomp.c b/av1/encoder/mcomp.c
index ba66bae..6094aa8 100644
--- a/av1/encoder/mcomp.c
+++ b/av1/encoder/mcomp.c
@@ -343,19 +343,19 @@
   if (second_pred != NULL) {
     if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
       DECLARE_ALIGNED(16, uint16_t, comp_pred16[MAX_SB_SQUARE]);
+      uint8_t *comp_pred = CONVERT_TO_BYTEPTR(comp_pred16);
       if (mask) {
-        aom_highbd_comp_mask_pred(comp_pred16, second_pred, w, h, y + offset,
+        aom_highbd_comp_mask_pred(comp_pred, second_pred, w, h, y + offset,
                                   y_stride, mask, mask_stride, invert_mask);
       } else {
         if (xd->jcp_param.use_jnt_comp_avg)
-          aom_highbd_jnt_comp_avg_pred(comp_pred16, second_pred, w, h,
-                                       y + offset, y_stride, &xd->jcp_param);
+          aom_highbd_jnt_comp_avg_pred(comp_pred, second_pred, w, h, y + offset,
+                                       y_stride, &xd->jcp_param);
         else
-          aom_highbd_comp_avg_pred(comp_pred16, second_pred, w, h, y + offset,
+          aom_highbd_comp_avg_pred(comp_pred, second_pred, w, h, y + offset,
                                    y_stride);
       }
-      besterr =
-          vfp->vf(CONVERT_TO_BYTEPTR(comp_pred16), w, src, src_stride, sse1);
+      besterr = vfp->vf(comp_pred, w, src, src_stride, sse1);
     } else {
       DECLARE_ALIGNED(16, uint8_t, comp_pred[MAX_SB_SQUARE]);
       if (mask) {
@@ -653,30 +653,29 @@
   unsigned int besterr;
   if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
     DECLARE_ALIGNED(16, uint16_t, pred16[MAX_SB_SQUARE]);
+    uint8_t *pred8 = CONVERT_TO_BYTEPTR(pred16);
     if (second_pred != NULL) {
       if (mask) {
         aom_highbd_comp_mask_upsampled_pred(
-            xd, cm, mi_row, mi_col, mv, pred16, second_pred, w, h, subpel_x_q3,
+            xd, cm, mi_row, mi_col, mv, pred8, second_pred, w, h, subpel_x_q3,
             subpel_y_q3, y, y_stride, mask, mask_stride, invert_mask, xd->bd,
             subpel_search);
       } else {
         if (xd->jcp_param.use_jnt_comp_avg)
           aom_highbd_jnt_comp_avg_upsampled_pred(
-              xd, cm, mi_row, mi_col, mv, pred16, second_pred, w, h,
-              subpel_x_q3, subpel_y_q3, y, y_stride, xd->bd, &xd->jcp_param,
-              subpel_search);
+              xd, cm, mi_row, mi_col, mv, pred8, second_pred, w, h, subpel_x_q3,
+              subpel_y_q3, y, y_stride, xd->bd, &xd->jcp_param, subpel_search);
         else
           aom_highbd_comp_avg_upsampled_pred(
-              xd, cm, mi_row, mi_col, mv, pred16, second_pred, w, h,
-              subpel_x_q3, subpel_y_q3, y, y_stride, xd->bd, subpel_search);
+              xd, cm, mi_row, mi_col, mv, pred8, second_pred, w, h, subpel_x_q3,
+              subpel_y_q3, y, y_stride, xd->bd, subpel_search);
       }
     } else {
-      aom_highbd_upsampled_pred(xd, cm, mi_row, mi_col, mv, pred16, w, h,
+      aom_highbd_upsampled_pred(xd, cm, mi_row, mi_col, mv, pred8, w, h,
                                 subpel_x_q3, subpel_y_q3, y, y_stride, xd->bd,
                                 subpel_search);
     }
-
-    besterr = vfp->vf(CONVERT_TO_BYTEPTR(pred16), w, src, src_stride, sse);
+    besterr = vfp->vf(pred8, w, src, src_stride, sse);
   } else {
     DECLARE_ALIGNED(16, uint8_t, pred[MAX_SB_SQUARE]);
     if (second_pred != NULL) {
@@ -2344,15 +2343,15 @@
     int subpel_x_q3, int subpel_y_q3, int w, int h, unsigned int *sse,
     int subpel_search) {
   unsigned int besterr;
+
+  DECLARE_ALIGNED(16, uint8_t, pred[2 * MAX_SB_SQUARE]);
   if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
-    DECLARE_ALIGNED(16, uint16_t, pred16[MAX_SB_SQUARE]);
-    aom_highbd_upsampled_pred(xd, cm, mi_row, mi_col, mv, pred16, w, h,
+    uint8_t *pred8 = CONVERT_TO_BYTEPTR(pred);
+    aom_highbd_upsampled_pred(xd, cm, mi_row, mi_col, mv, pred8, w, h,
                               subpel_x_q3, subpel_y_q3, y, y_stride, xd->bd,
                               subpel_search);
-
-    besterr = vfp->ovf(CONVERT_TO_BYTEPTR(pred16), w, wsrc, mask, sse);
+    besterr = vfp->ovf(pred8, w, wsrc, mask, sse);
   } else {
-    DECLARE_ALIGNED(16, uint8_t, pred[MAX_SB_SQUARE]);
     aom_upsampled_pred(xd, cm, mi_row, mi_col, mv, pred, w, h, subpel_x_q3,
                        subpel_y_q3, y, y_stride, subpel_search);
 
diff --git a/test/comp_avg_pred_test.cc b/test/comp_avg_pred_test.cc
index 8bd826e..9ad8973 100644
--- a/test/comp_avg_pred_test.cc
+++ b/test/comp_avg_pred_test.cc
@@ -50,9 +50,9 @@
 TEST_P(AV1HighBDJNTCOMPAVGTest, CheckOutput) { RunCheckOutput(GET_PARAM(1)); }
 
 #if HAVE_SSE2
-INSTANTIATE_TEST_CASE_P(
-    SSE2, AV1HighBDJNTCOMPAVGTest,
-    libaom_test::AV1JNTCOMPAVG::BuildParams(aom_highbd_jnt_comp_avg_pred_sse2));
+INSTANTIATE_TEST_CASE_P(SSE2, AV1HighBDJNTCOMPAVGTest,
+                        libaom_test::AV1JNTCOMPAVG::BuildParams(
+                            aom_highbd_jnt_comp_avg_pred_sse2, 1));
 #endif
 
 TEST_P(AV1HighBDJNTCOMPAVGUPSAMPLEDTest, DISABLED_Speed) {
diff --git a/test/comp_avg_pred_test.h b/test/comp_avg_pred_test.h
index 7028d22..fe7c0ca 100644
--- a/test/comp_avg_pred_test.h
+++ b/test/comp_avg_pred_test.h
@@ -38,14 +38,9 @@
     int height, int subpel_x_q3, int subpel_y_q3, const uint8_t *ref,
     int ref_stride, const JNT_COMP_PARAMS *jcp_param, int subpel_search);
 
-typedef void (*highbdjntcompavg_func)(uint16_t *comp_pred, const uint8_t *pred8,
-                                      int width, int height,
-                                      const uint8_t *ref8, int ref_stride,
-                                      const JNT_COMP_PARAMS *jcp_param);
-
 typedef void (*highbdjntcompavgupsampled_func)(
     MACROBLOCKD *xd, const struct AV1Common *const cm, int mi_row, int mi_col,
-    const MV *const mv, uint16_t *comp_pred, const uint8_t *pred8, int width,
+    const MV *const mv, uint8_t *comp_pred8, const uint8_t *pred8, int width,
     int height, int subpel_x_q3, int subpel_y_q3, const uint8_t *ref8,
     int ref_stride, int bd, const JNT_COMP_PARAMS *jcp_param,
     int subpel_search);
@@ -55,7 +50,7 @@
 typedef ::testing::tuple<jntcompavgupsampled_func, BLOCK_SIZE>
     JNTCOMPAVGUPSAMPLEDParam;
 
-typedef ::testing::tuple<int, highbdjntcompavg_func, BLOCK_SIZE>
+typedef ::testing::tuple<int, jntcompavg_func, BLOCK_SIZE>
     HighbdJNTCOMPAVGParam;
 
 typedef ::testing::tuple<int, highbdjntcompavgupsampled_func, BLOCK_SIZE>
@@ -74,7 +69,8 @@
 }
 
 ::testing::internal::ParamGenerator<HighbdJNTCOMPAVGParam> BuildParams(
-    highbdjntcompavg_func filter) {
+    jntcompavg_func filter, int is_hbd) {
+  (void)is_hbd;
   return ::testing::Combine(::testing::Range(8, 13, 2),
                             ::testing::Values(filter),
                             ::testing::Range(BLOCK_4X4, BLOCK_SIZES_ALL));
@@ -324,7 +320,7 @@
   void TearDown() { libaom_test::ClearSystemState(); }
 
  protected:
-  void RunCheckOutput(highbdjntcompavg_func test_impl) {
+  void RunCheckOutput(jntcompavg_func test_impl) {
     const int w = kMaxSize, h = kMaxSize;
     const int block_idx = GET_PARAM(2);
     const int bd = GET_PARAM(0);
@@ -352,13 +348,14 @@
         const int offset_r = 3 + rnd_.PseudoUniform(h - in_h - 7);
         const int offset_c = 3 + rnd_.PseudoUniform(w - in_w - 7);
         aom_highbd_jnt_comp_avg_pred_c(
-            output, CONVERT_TO_BYTEPTR(pred8) + offset_r * w + offset_c, in_w,
-            in_h, CONVERT_TO_BYTEPTR(ref8) + offset_r * w + offset_c, in_w,
+            CONVERT_TO_BYTEPTR(output),
+            CONVERT_TO_BYTEPTR(pred8) + offset_r * w + offset_c, in_w, in_h,
+            CONVERT_TO_BYTEPTR(ref8) + offset_r * w + offset_c, in_w,
             &jnt_comp_params);
-        test_impl(output2, CONVERT_TO_BYTEPTR(pred8) + offset_r * w + offset_c,
-                  in_w, in_h,
-                  CONVERT_TO_BYTEPTR(ref8) + offset_r * w + offset_c, in_w,
-                  &jnt_comp_params);
+        test_impl(CONVERT_TO_BYTEPTR(output2),
+                  CONVERT_TO_BYTEPTR(pred8) + offset_r * w + offset_c, in_w,
+                  in_h, CONVERT_TO_BYTEPTR(ref8) + offset_r * w + offset_c,
+                  in_w, &jnt_comp_params);
 
         for (int i = 0; i < in_h; ++i) {
           for (int j = 0; j < in_w; ++j) {
@@ -372,7 +369,7 @@
       }
     }
   }
-  void RunSpeedTest(highbdjntcompavg_func test_impl) {
+  void RunSpeedTest(jntcompavg_func test_impl) {
     const int w = kMaxSize, h = kMaxSize;
     const int block_idx = GET_PARAM(2);
     const int bd = GET_PARAM(0);
@@ -400,9 +397,9 @@
     aom_usec_timer_start(&timer);
 
     for (int i = 0; i < num_loops; ++i)
-      aom_highbd_jnt_comp_avg_pred_c(output, CONVERT_TO_BYTEPTR(pred8), in_w,
-                                     in_h, CONVERT_TO_BYTEPTR(ref8), in_w,
-                                     &jnt_comp_params);
+      aom_highbd_jnt_comp_avg_pred_c(
+          CONVERT_TO_BYTEPTR(output), CONVERT_TO_BYTEPTR(pred8), in_w, in_h,
+          CONVERT_TO_BYTEPTR(ref8), in_w, &jnt_comp_params);
 
     aom_usec_timer_mark(&timer);
     const int elapsed_time = static_cast<int>(aom_usec_timer_elapsed(&timer));
@@ -413,8 +410,8 @@
     aom_usec_timer_start(&timer1);
 
     for (int i = 0; i < num_loops; ++i)
-      test_impl(output2, CONVERT_TO_BYTEPTR(pred8), in_w, in_h,
-                CONVERT_TO_BYTEPTR(ref8), in_w, &jnt_comp_params);
+      test_impl(CONVERT_TO_BYTEPTR(output2), CONVERT_TO_BYTEPTR(pred8), in_w,
+                in_h, CONVERT_TO_BYTEPTR(ref8), in_w, &jnt_comp_params);
 
     aom_usec_timer_mark(&timer1);
     const int elapsed_time1 = static_cast<int>(aom_usec_timer_elapsed(&timer1));
@@ -466,12 +463,12 @@
               const int offset_c = 3 + rnd_.PseudoUniform(w - in_w - 7);
 
               aom_highbd_jnt_comp_avg_upsampled_pred_c(
-                  NULL, NULL, 0, 0, NULL, output,
+                  NULL, NULL, 0, 0, NULL, CONVERT_TO_BYTEPTR(output),
                   CONVERT_TO_BYTEPTR(pred8) + offset_r * w + offset_c, in_w,
                   in_h, sub_x_q3, sub_y_q3,
                   CONVERT_TO_BYTEPTR(ref8) + offset_r * w + offset_c, in_w, bd,
                   &jnt_comp_params, subpel_search);
-              test_impl(NULL, NULL, 0, 0, NULL, output2,
+              test_impl(NULL, NULL, 0, 0, NULL, CONVERT_TO_BYTEPTR(output2),
                         CONVERT_TO_BYTEPTR(pred8) + offset_r * w + offset_c,
                         in_w, in_h, sub_x_q3, sub_y_q3,
                         CONVERT_TO_BYTEPTR(ref8) + offset_r * w + offset_c,
@@ -525,9 +522,9 @@
     int subpel_search = 2;  // set to 1 to test 4-tap filter.
     for (int i = 0; i < num_loops; ++i)
       aom_highbd_jnt_comp_avg_upsampled_pred_c(
-          NULL, NULL, 0, 0, NULL, output, CONVERT_TO_BYTEPTR(pred8), in_w, in_h,
-          sub_x_q3, sub_y_q3, CONVERT_TO_BYTEPTR(ref8), in_w, bd,
-          &jnt_comp_params, subpel_search);
+          NULL, NULL, 0, 0, NULL, CONVERT_TO_BYTEPTR(output),
+          CONVERT_TO_BYTEPTR(pred8), in_w, in_h, sub_x_q3, sub_y_q3,
+          CONVERT_TO_BYTEPTR(ref8), in_w, bd, &jnt_comp_params, subpel_search);
 
     aom_usec_timer_mark(&timer);
     const int elapsed_time = static_cast<int>(aom_usec_timer_elapsed(&timer));
@@ -538,9 +535,10 @@
     aom_usec_timer_start(&timer1);
 
     for (int i = 0; i < num_loops; ++i)
-      test_impl(NULL, NULL, 0, 0, NULL, output2, CONVERT_TO_BYTEPTR(pred8),
-                in_w, in_h, sub_x_q3, sub_y_q3, CONVERT_TO_BYTEPTR(ref8), in_w,
-                bd, &jnt_comp_params, subpel_search);
+      test_impl(NULL, NULL, 0, 0, NULL, CONVERT_TO_BYTEPTR(output2),
+                CONVERT_TO_BYTEPTR(pred8), in_w, in_h, sub_x_q3, sub_y_q3,
+                CONVERT_TO_BYTEPTR(ref8), in_w, bd, &jnt_comp_params,
+                subpel_search);
 
     aom_usec_timer_mark(&timer1);
     const int elapsed_time1 = static_cast<int>(aom_usec_timer_elapsed(&timer1));
diff --git a/test/comp_mask_variance_test.cc b/test/comp_mask_variance_test.cc
index b2ab496..e663469 100644
--- a/test/comp_mask_variance_test.cc
+++ b/test/comp_mask_variance_test.cc
@@ -276,7 +276,7 @@
 
 #endif  // ifndef aom_comp_mask_pred
 
-typedef void (*highbd_comp_mask_pred_func)(uint16_t *comp_pred,
+typedef void (*highbd_comp_mask_pred_func)(uint8_t *comp_pred8,
                                            const uint8_t *pred8, int width,
                                            int height, const uint8_t *ref8,
                                            int ref_stride, const uint8_t *mask,
@@ -362,11 +362,11 @@
   for (int wedge_index = 0; wedge_index < wedge_types; ++wedge_index) {
     const uint8_t *mask = av1_get_contiguous_soft_mask(wedge_index, 1, bsize);
 
-    aom_highbd_comp_mask_pred_c(comp_pred1_, CONVERT_TO_BYTEPTR(pred_), w, h,
-                                CONVERT_TO_BYTEPTR(ref_), MAX_SB_SIZE, mask, w,
-                                inv);
+    aom_highbd_comp_mask_pred_c(
+        CONVERT_TO_BYTEPTR(comp_pred1_), CONVERT_TO_BYTEPTR(pred_), w, h,
+        CONVERT_TO_BYTEPTR(ref_), MAX_SB_SIZE, mask, w, inv);
 
-    test_impl(comp_pred2_, CONVERT_TO_BYTEPTR(pred_), w, h,
+    test_impl(CONVERT_TO_BYTEPTR(comp_pred2_), CONVERT_TO_BYTEPTR(pred_), w, h,
               CONVERT_TO_BYTEPTR(ref_), MAX_SB_SIZE, mask, w, inv);
 
     ASSERT_EQ(CheckResult(w, h), true)
@@ -402,7 +402,7 @@
     aom_usec_timer_start(&timer);
     highbd_comp_mask_pred_func func = funcs[i];
     for (int j = 0; j < num_loops; ++j) {
-      func(comp_pred1_, CONVERT_TO_BYTEPTR(pred_), w, h,
+      func(CONVERT_TO_BYTEPTR(comp_pred1_), CONVERT_TO_BYTEPTR(pred_), w, h,
            CONVERT_TO_BYTEPTR(ref_), MAX_SB_SIZE, mask, w, 0);
     }
     aom_usec_timer_mark(&timer);
@@ -482,15 +482,17 @@
 
         aom_highbd_comp_mask_pred = aom_highbd_comp_mask_pred_c;  // ref
         aom_highbd_comp_mask_upsampled_pred(
-            NULL, NULL, 0, 0, NULL, comp_pred1_, CONVERT_TO_BYTEPTR(pred_), w,
-            h, subx, suby, CONVERT_TO_BYTEPTR(ref_), MAX_SB_SIZE, mask, w, inv,
-            bd_, subpel_search);
+            NULL, NULL, 0, 0, NULL, CONVERT_TO_BYTEPTR(comp_pred1_),
+            CONVERT_TO_BYTEPTR(pred_), w, h, subx, suby,
+            CONVERT_TO_BYTEPTR(ref_), MAX_SB_SIZE, mask, w, inv, bd_,
+            subpel_search);
 
         aom_highbd_comp_mask_pred = test_impl;  // test
         aom_highbd_comp_mask_upsampled_pred(
-            NULL, NULL, 0, 0, NULL, comp_pred2_, CONVERT_TO_BYTEPTR(pred_), w,
-            h, subx, suby, CONVERT_TO_BYTEPTR(ref_), MAX_SB_SIZE, mask, w, inv,
-            bd_, subpel_search);
+            NULL, NULL, 0, 0, NULL, CONVERT_TO_BYTEPTR(comp_pred2_),
+            CONVERT_TO_BYTEPTR(pred_), w, h, subx, suby,
+            CONVERT_TO_BYTEPTR(ref_), MAX_SB_SIZE, mask, w, inv, bd_,
+            subpel_search);
         ASSERT_EQ(CheckResult(w, h), true)
             << " wedge " << wedge_index << " inv " << inv << "sub (" << subx
             << "," << suby << ")";
@@ -529,9 +531,9 @@
     int subpel_search = 2;  // set to 1 to test 4-tap filter.
     for (int j = 0; j < num_loops; ++j) {
       aom_highbd_comp_mask_upsampled_pred(
-          NULL, NULL, 0, 0, NULL, comp_pred1_, CONVERT_TO_BYTEPTR(pred_), w, h,
-          subx, suby, CONVERT_TO_BYTEPTR(ref_), MAX_SB_SIZE, mask, w, 0, bd_,
-          subpel_search);
+          NULL, NULL, 0, 0, NULL, CONVERT_TO_BYTEPTR(comp_pred1_),
+          CONVERT_TO_BYTEPTR(pred_), w, h, subx, suby, CONVERT_TO_BYTEPTR(ref_),
+          MAX_SB_SIZE, mask, w, 0, bd_, subpel_search);
     }
     aom_usec_timer_mark(&timer);
     double time = static_cast<double>(aom_usec_timer_elapsed(&timer));