Unify HBD and 8-bit Joint convolve functions

HBD Joint convolve functions are modified similar
to 8-bit functions.
Unit test updated for new functions.

Change-Id: Id0a292e1366084b711c3a72a653786b6edaee026
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 5f11b07..139b184 100755
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -537,6 +537,12 @@
 if (aom_config("CONFIG_JNT_COMP") eq "yes") {
   add_proto qw/void av1_highbd_jnt_convolve_2d/, "const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride, int w, int h, InterpFilterParams *filter_params_x, InterpFilterParams *filter_params_y, const int subpel_x_q4, const int subpel_y_q4, ConvolveParams *conv_params, int bd";
   specialize qw/av1_highbd_jnt_convolve_2d sse4_1/;
+
+  add_proto qw/void av1_highbd_jnt_convolve_x/, "const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride, int w, int h, InterpFilterParams *filter_params_x, InterpFilterParams *filter_params_y, const int subpel_x_q4, const int subpel_y_q4, ConvolveParams *conv_params, int bd";
+  
+  add_proto qw/void av1_highbd_jnt_convolve_y/, "const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride, int w, int h, InterpFilterParams *filter_params_x, InterpFilterParams *filter_params_y, const int subpel_x_q4, const int subpel_y_q4, ConvolveParams *conv_params, int bd";
+
+  add_proto qw/void av1_highbd_jnt_convolve_2d_copy/, "const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride, int w, int h, InterpFilterParams *filter_params_x, InterpFilterParams *filter_params_y, const int subpel_x_q4, const int subpel_y_q4, ConvolveParams *conv_params, int bd";
 }
 
 # INTRA_EDGE functions
diff --git a/av1/common/convolve.c b/av1/common/convolve.c
index c80b3cb..2a96da5 100644
--- a/av1/common/convolve.c
+++ b/av1/common/convolve.c
@@ -1352,6 +1352,128 @@
     }
   }
 }
+
+void av1_highbd_jnt_convolve_x_c(const uint16_t *src, int src_stride,
+                                 uint16_t *dst0, int dst_stride0, int w, int h,
+                                 InterpFilterParams *filter_params_x,
+                                 InterpFilterParams *filter_params_y,
+                                 const int subpel_x_q4, const int subpel_y_q4,
+                                 ConvolveParams *conv_params, int bd) {
+  CONV_BUF_TYPE *dst = conv_params->dst;
+  int dst_stride = conv_params->dst_stride;
+  const int fo_horiz = filter_params_x->taps / 2 - 1;
+  const int bits = FILTER_BITS - conv_params->round_1;
+  (void)filter_params_y;
+  (void)subpel_y_q4;
+  (void)dst0;
+  (void)dst_stride0;
+  (void)bd;
+
+  // horizontal filter
+  const int16_t *x_filter = av1_get_interp_filter_subpel_kernel(
+      *filter_params_x, subpel_x_q4 & SUBPEL_MASK);
+  for (int y = 0; y < h; ++y) {
+    for (int x = 0; x < w; ++x) {
+      CONV_BUF_TYPE res = 0;
+      for (int k = 0; k < filter_params_x->taps; ++k) {
+        res += x_filter[k] * src[y * src_stride + x - fo_horiz + k];
+      }
+      res = (1 << bits) * ROUND_POWER_OF_TWO(res, conv_params->round_0);
+      if (conv_params->use_jnt_comp_avg) {
+        if (conv_params->do_average) {
+          dst[y * dst_stride + x] += res * conv_params->bck_offset;
+        } else {
+          dst[y * dst_stride + x] = res * conv_params->fwd_offset;
+        }
+      } else {
+        if (conv_params->do_average)
+          dst[y * dst_stride + x] += res;
+        else
+          dst[y * dst_stride + x] = res;
+      }
+    }
+  }
+}
+
+void av1_highbd_jnt_convolve_y_c(const uint16_t *src, int src_stride,
+                                 uint16_t *dst0, int dst_stride0, int w, int h,
+                                 InterpFilterParams *filter_params_x,
+                                 InterpFilterParams *filter_params_y,
+                                 const int subpel_x_q4, const int subpel_y_q4,
+                                 ConvolveParams *conv_params, int bd) {
+  CONV_BUF_TYPE *dst = conv_params->dst;
+  int dst_stride = conv_params->dst_stride;
+  const int fo_vert = filter_params_y->taps / 2 - 1;
+  const int bits = FILTER_BITS - conv_params->round_0;
+  (void)filter_params_x;
+  (void)subpel_x_q4;
+  (void)dst0;
+  (void)dst_stride0;
+  (void)bd;
+
+  // vertical filter
+  const int16_t *y_filter = av1_get_interp_filter_subpel_kernel(
+      *filter_params_y, subpel_y_q4 & SUBPEL_MASK);
+  for (int y = 0; y < h; ++y) {
+    for (int x = 0; x < w; ++x) {
+      CONV_BUF_TYPE res = 0;
+      for (int k = 0; k < filter_params_y->taps; ++k) {
+        res += y_filter[k] * src[(y - fo_vert + k) * src_stride + x];
+      }
+      res *= (1 << bits);
+      res = ROUND_POWER_OF_TWO(res, conv_params->round_1);
+      if (conv_params->use_jnt_comp_avg) {
+        if (conv_params->do_average) {
+          dst[y * dst_stride + x] += res * conv_params->bck_offset;
+        } else {
+          dst[y * dst_stride + x] = res * conv_params->fwd_offset;
+        }
+      } else {
+        if (conv_params->do_average)
+          dst[y * dst_stride + x] += res;
+        else
+          dst[y * dst_stride + x] = res;
+      }
+    }
+  }
+}
+
+void av1_highbd_jnt_convolve_2d_copy_c(
+    const uint16_t *src, int src_stride, uint16_t *dst0, int dst_stride0, int w,
+    int h, InterpFilterParams *filter_params_x,
+    InterpFilterParams *filter_params_y, const int subpel_x_q4,
+    const int subpel_y_q4, ConvolveParams *conv_params, int bd) {
+  CONV_BUF_TYPE *dst = conv_params->dst;
+  int dst_stride = conv_params->dst_stride;
+  const int bits =
+      FILTER_BITS * 2 - conv_params->round_1 - conv_params->round_0;
+
+  (void)filter_params_x;
+  (void)filter_params_y;
+  (void)subpel_x_q4;
+  (void)subpel_y_q4;
+  (void)dst0;
+  (void)dst_stride0;
+  (void)bd;
+
+  for (int y = 0; y < h; ++y) {
+    for (int x = 0; x < w; ++x) {
+      CONV_BUF_TYPE res = src[y * src_stride + x] << bits;
+      if (conv_params->use_jnt_comp_avg) {
+        if (conv_params->do_average) {
+          dst[y * dst_stride + x] += res * conv_params->bck_offset;
+        } else {
+          dst[y * dst_stride + x] = res * conv_params->fwd_offset;
+        }
+      } else {
+        if (conv_params->do_average)
+          dst[y * dst_stride + x] += res;
+        else
+          dst[y * dst_stride + x] = res;
+      }
+    }
+  }
+}
 #endif  // CONFIG_JNT_COMP
 
 void av1_highbd_convolve_2d_scale_c(const uint16_t *src, int src_stride,
diff --git a/av1/common/scale.c b/av1/common/scale.c
index 1b3db93..71ac0bc 100644
--- a/av1/common/scale.c
+++ b/av1/common/scale.c
@@ -218,11 +218,11 @@
   sf->highbd_convolve[1][1][0] = av1_highbd_convolve_2d_sr;
 #if CONFIG_JNT_COMP
   // subpel_x_q4 == 0 && subpel_y_q4 == 0
-  sf->highbd_convolve[0][0][1] = av1_highbd_jnt_convolve_2d;
+  sf->highbd_convolve[0][0][1] = av1_highbd_jnt_convolve_2d_copy;
   // subpel_x_q4 == 0
-  sf->highbd_convolve[0][1][1] = av1_highbd_jnt_convolve_2d;
+  sf->highbd_convolve[0][1][1] = av1_highbd_jnt_convolve_y;
   // subpel_y_q4 == 0
-  sf->highbd_convolve[1][0][1] = av1_highbd_jnt_convolve_2d;
+  sf->highbd_convolve[1][0][1] = av1_highbd_jnt_convolve_x;
   // subpel_x_q4 != 0 && subpel_y_q4 != 0
   sf->highbd_convolve[1][1][1] = av1_highbd_jnt_convolve_2d;
 #else
diff --git a/test/av1_convolve_2d_test.cc b/test/av1_convolve_2d_test.cc
index 688d8dc..b8b5971 100644
--- a/test/av1_convolve_2d_test.cc
+++ b/test/av1_convolve_2d_test.cc
@@ -178,6 +178,18 @@
 INSTANTIATE_TEST_CASE_P(SSE4_1, AV1HighbdJntConvolve2DTest,
                         libaom_test::AV1HighbdConvolve2D::BuildParams(
                             av1_highbd_jnt_convolve_2d_sse4_1, 1, 1, 1));
+
+INSTANTIATE_TEST_CASE_P(C_X, AV1HighbdJntConvolve2DTest,
+                        libaom_test::AV1HighbdConvolve2D::BuildParams(
+                            av1_highbd_jnt_convolve_x_c, 1, 0, 1));
+
+INSTANTIATE_TEST_CASE_P(C_Y, AV1HighbdJntConvolve2DTest,
+                        libaom_test::AV1HighbdConvolve2D::BuildParams(
+                            av1_highbd_jnt_convolve_y_c, 0, 1, 1));
+
+INSTANTIATE_TEST_CASE_P(C_COPY, AV1HighbdJntConvolve2DTest,
+                        libaom_test::AV1HighbdConvolve2D::BuildParams(
+                            av1_highbd_jnt_convolve_2d_copy_c, 0, 0, 1));
 #endif  // CONFIG_JNT_COMP
 #endif
 
diff --git a/test/av1_convolve_2d_test_util.cc b/test/av1_convolve_2d_test_util.cc
index 888afaf..b166630 100644
--- a/test/av1_convolve_2d_test_util.cc
+++ b/test/av1_convolve_2d_test_util.cc
@@ -577,6 +577,8 @@
     highbd_convolve_2d_func test_impl) {
   const int w = kMaxSize, h = kMaxSize;
   const int bd = GET_PARAM(0);
+  const int has_subx = GET_PARAM(2);
+  const int has_suby = GET_PARAM(3);
   int hfilter, vfilter, subx, suby;
   uint16_t input[kMaxSize * kMaxSize];
   DECLARE_ALIGNED(32, CONV_BUF_TYPE, output[MAX_SB_SQUARE]);
@@ -608,8 +610,10 @@
           conv_params1.use_jnt_comp_avg = 0;
           conv_params2.use_jnt_comp_avg = 0;
 
-          for (subx = 0; subx < 16; ++subx) {
-            for (suby = 0; suby < 16; ++suby) {
+          const int subx_range = has_subx ? 16 : 1;
+          const int suby_range = has_suby ? 16 : 1;
+          for (subx = 0; subx < subx_range; ++subx) {
+            for (suby = 0; suby < suby_range; ++suby) {
               // Choose random locations within the source block
               const int offset_r = 3 + rnd_.PseudoUniform(h - out_h - 7);
               const int offset_c = 3 + rnd_.PseudoUniform(w - out_w - 7);
@@ -644,8 +648,10 @@
               conv_params2.fwd_offset = quant_dist_lookup_table[k][l][0];
               conv_params2.bck_offset = quant_dist_lookup_table[k][l][1];
 
-              for (subx = 0; subx < 16; ++subx) {
-                for (suby = 0; suby < 16; ++suby) {
+              const int subx_range = has_subx ? 16 : 1;
+              const int suby_range = has_suby ? 16 : 1;
+              for (subx = 0; subx < subx_range; ++subx) {
+                for (suby = 0; suby < suby_range; ++suby) {
                   // Choose random locations within the source block
                   const int offset_r = 3 + rnd_.PseudoUniform(h - out_h - 7);
                   const int offset_c = 3 + rnd_.PseudoUniform(w - out_w - 7);