Add SSE2 variant for mse_wxh_16bit_highbd in cdef

Added an SSE2 variant of the mse_wxh_16bit_highbd function
and a corresponding unit test (MseHBDWxHTest).

Module-level gains:
BLOCKSIZE    Gain w.r.t. C
8x8             3.3x
8x4             3.0x
4x8             3.3x
4x4             3.2x
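
For reference, the new kernels vectorize a plain sum of squared
differences over the w x h block of 16-bit samples; a minimal scalar
sketch of that computation is shown below (the helper name is
illustrative only, not part of this change):

  static uint64_t mse_wxh_16bit_highbd_ref(uint16_t *dst, int dstride,
                                           uint16_t *src, int sstride,
                                           int w, int h) {
    uint64_t sum = 0;
    for (int i = 0; i < h; i++) {
      for (int j = 0; j < w; j++) {
        // Squared error per sample; the sign of e is irrelevant, and for
        // <= 12-bit input e * e fits comfortably in a 32-bit int.
        const int e = (int)dst[i * dstride + j] - (int)src[i * sstride + j];
        sum += (uint64_t)(e * e);
      }
    }
    return sum;
  }

The SSE2 kernels compute the same sums by squaring the 16-bit
differences with _mm_madd_epi16 and accumulating in 64-bit lanes.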

Change-Id: I9ad6851588e212a5417a0ab95d8f31e8656df1c1
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index ae25b00..2cef014 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -1132,9 +1132,6 @@
   add_proto qw/uint64_t/, "aom_mse_wxh_16bit", "uint8_t *dst, int dstride,uint16_t *src, int sstride, int w, int h";
   specialize qw/aom_mse_wxh_16bit  sse2 avx2/;
 
-  add_proto qw/uint64_t/, "aom_mse_wxh_16bit_highbd", "uint16_t *dst, int dstride,uint16_t *src, int sstride, int w, int h";
-  specialize qw/aom_mse_wxh_16bit_highbd   avx2/;
-
   foreach (@block_sizes) {
     ($w, $h) = @$_;
     add_proto qw/unsigned int/, "aom_variance${w}x${h}", "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse";
@@ -1528,6 +1525,9 @@
 
     add_proto qw/void aom_highbd_dist_wtd_comp_avg_pred/, "uint8_t *comp_pred8, const uint8_t *pred8, int width, int height, const uint8_t *ref8, int ref_stride, const DIST_WTD_COMP_PARAMS *jcp_param";
     specialize qw/aom_highbd_dist_wtd_comp_avg_pred sse2/;
+
+    add_proto qw/uint64_t/, "aom_mse_wxh_16bit_highbd", "uint16_t *dst, int dstride,uint16_t *src, int sstride, int w, int h";
+    specialize qw/aom_mse_wxh_16bit_highbd   sse2 avx2/;
   }
     #
     # Subpixel Variance
diff --git a/aom_dsp/x86/highbd_variance_sse2.c b/aom_dsp/x86/highbd_variance_sse2.c
index b7d15f9..d1bd7d4 100644
--- a/aom_dsp/x86/highbd_variance_sse2.c
+++ b/aom_dsp/x86/highbd_variance_sse2.c
@@ -840,3 +840,100 @@
     pred += 8;
   }
 }
+
+uint64_t aom_mse_4xh_16bit_highbd_sse2(uint16_t *dst, int dstride,
+                                       uint16_t *src, int sstride, int h) {
+  uint64_t sum = 0;
+  __m128i reg0_4x16, reg1_4x16;
+  __m128i src_8x16;
+  __m128i dst_8x16;
+  __m128i res0_4x32, res1_4x32, res0_4x64, res1_4x64, res2_4x64, res3_4x64;
+  __m128i sub_result_8x16;
+  const __m128i zeros = _mm_setzero_si128();
+  __m128i square_result = _mm_setzero_si128();
+  for (int i = 0; i < h; i += 2) {
+    reg0_4x16 = _mm_loadl_epi64((__m128i const *)(&dst[(i + 0) * dstride]));
+    reg1_4x16 = _mm_loadl_epi64((__m128i const *)(&dst[(i + 1) * dstride]));
+    dst_8x16 = _mm_unpacklo_epi64(reg0_4x16, reg1_4x16);
+
+    reg0_4x16 = _mm_loadl_epi64((__m128i const *)(&src[(i + 0) * sstride]));
+    reg1_4x16 = _mm_loadl_epi64((__m128i const *)(&src[(i + 1) * sstride]));
+    src_8x16 = _mm_unpacklo_epi64(reg0_4x16, reg1_4x16);
+
+    sub_result_8x16 = _mm_sub_epi16(src_8x16, dst_8x16);
+
+    res0_4x32 = _mm_unpacklo_epi16(sub_result_8x16, zeros);
+    res1_4x32 = _mm_unpackhi_epi16(sub_result_8x16, zeros);
+
+    res0_4x32 = _mm_madd_epi16(res0_4x32, res0_4x32);
+    res1_4x32 = _mm_madd_epi16(res1_4x32, res1_4x32);
+
+    res0_4x64 = _mm_unpacklo_epi32(res0_4x32, zeros);
+    res1_4x64 = _mm_unpackhi_epi32(res0_4x32, zeros);
+    res2_4x64 = _mm_unpacklo_epi32(res1_4x32, zeros);
+    res3_4x64 = _mm_unpackhi_epi32(res1_4x32, zeros);
+
+    square_result = _mm_add_epi64(
+        square_result,
+        _mm_add_epi64(
+            _mm_add_epi64(_mm_add_epi64(res0_4x64, res1_4x64), res2_4x64),
+            res3_4x64));
+  }
+
+  const __m128i sum_1x64 =
+      _mm_add_epi64(square_result, _mm_srli_si128(square_result, 8));
+  xx_storel_64(&sum, sum_1x64);
+  return sum;
+}
+
+uint64_t aom_mse_8xh_16bit_highbd_sse2(uint16_t *dst, int dstride,
+                                       uint16_t *src, int sstride, int h) {
+  uint64_t sum = 0;
+  __m128i src_8x16;
+  __m128i dst_8x16;
+  __m128i res0_4x32, res1_4x32, res0_4x64, res1_4x64, res2_4x64, res3_4x64;
+  __m128i sub_result_8x16;
+  const __m128i zeros = _mm_setzero_si128();
+  __m128i square_result = _mm_setzero_si128();
+
+  for (int i = 0; i < h; i++) {
+    dst_8x16 = _mm_loadu_si128((__m128i *)&dst[i * dstride]);
+    src_8x16 = _mm_loadu_si128((__m128i *)&src[i * sstride]);
+
+    sub_result_8x16 = _mm_sub_epi16(src_8x16, dst_8x16);
+
+    res0_4x32 = _mm_unpacklo_epi16(sub_result_8x16, zeros);
+    res1_4x32 = _mm_unpackhi_epi16(sub_result_8x16, zeros);
+
+    res0_4x32 = _mm_madd_epi16(res0_4x32, res0_4x32);
+    res1_4x32 = _mm_madd_epi16(res1_4x32, res1_4x32);
+
+    res0_4x64 = _mm_unpacklo_epi32(res0_4x32, zeros);
+    res1_4x64 = _mm_unpackhi_epi32(res0_4x32, zeros);
+    res2_4x64 = _mm_unpacklo_epi32(res1_4x32, zeros);
+    res3_4x64 = _mm_unpackhi_epi32(res1_4x32, zeros);
+
+    square_result = _mm_add_epi64(
+        square_result,
+        _mm_add_epi64(
+            _mm_add_epi64(_mm_add_epi64(res0_4x64, res1_4x64), res2_4x64),
+            res3_4x64));
+  }
+
+  const __m128i sum_1x64 =
+      _mm_add_epi64(square_result, _mm_srli_si128(square_result, 8));
+  xx_storel_64(&sum, sum_1x64);
+  return sum;
+}
+
+uint64_t aom_mse_wxh_16bit_highbd_sse2(uint16_t *dst, int dstride,
+                                       uint16_t *src, int sstride, int w,
+                                       int h) {
+  assert((w == 8 || w == 4) && (h == 8 || h == 4) &&
+         "w must be 4 or 8 and h must be 4 or 8");
+  switch (w) {
+    case 4: return aom_mse_4xh_16bit_highbd_sse2(dst, dstride, src, sstride, h);
+    case 8: return aom_mse_8xh_16bit_highbd_sse2(dst, dstride, src, sstride, h);
+    default: assert(0 && "unsupported width"); return -1;
+  }
+}
diff --git a/av1/encoder/pickcdef.c b/av1/encoder/pickcdef.c
index 1d9a315..ac293bc 100644
--- a/av1/encoder/pickcdef.c
+++ b/av1/encoder/pickcdef.c
@@ -226,6 +226,7 @@
                                         BLOCK_SIZE bsize, int coeff_shift,
                                         int row, int col);
 
+#if CONFIG_AV1_HIGHBITDEPTH
 static void copy_sb16_16_highbd(uint16_t *dst, int dstride, const void *src,
                                 int src_voffset, int src_hoffset, int sstride,
                                 int vsize, int hsize) {
@@ -235,6 +236,7 @@
   for (r = 0; r < vsize; r++)
     memcpy(dst + r * dstride, base + r * sstride, hsize * sizeof(*base));
 }
+#endif
 
 static void copy_sb16_16(uint16_t *dst, int dstride, const void *src,
                          int src_voffset, int src_hoffset, int sstride,
@@ -256,7 +258,7 @@
   *width_log2 = MI_SIZE_LOG2 + mi_size_wide_log2[bsize];
   *height_log2 = MI_SIZE_LOG2 + mi_size_wide_log2[bsize];
 }
-
+#if CONFIG_AV1_HIGHBITDEPTH
 /* Compute MSE only on the blocks we filtered. */
 static uint64_t compute_cdef_dist_highbd(void *dst, int dstride, uint16_t *src,
                                          cdef_list *dlist, int cdef_count,
@@ -280,7 +282,7 @@
   }
   return sum >> 2 * coeff_shift;
 }
-
+#endif
 static uint64_t compute_cdef_dist(void *dst, int dstride, uint16_t *src,
                                   cdef_list *dlist, int cdef_count,
                                   BLOCK_SIZE bsize, int coeff_shift, int row,
@@ -422,7 +424,7 @@
 
   copy_fn_t copy_fn;
   compute_cdef_dist_t compute_cdef_dist_fn;
-
+#if CONFIG_AV1_HIGHBITDEPTH
   if (cm->seq_params.use_highbitdepth) {
     copy_fn = copy_sb16_16_highbd;
     compute_cdef_dist_fn = compute_cdef_dist_highbd;
@@ -430,6 +432,10 @@
     copy_fn = copy_sb16_16;
     compute_cdef_dist_fn = compute_cdef_dist;
   }
+#else
+  copy_fn = copy_sb16_16;
+  compute_cdef_dist_fn = compute_cdef_dist;
+#endif
 
   DECLARE_ALIGNED(32, uint16_t, inbuf[CDEF_INBUF_SIZE]);
   uint16_t *const in = inbuf + CDEF_VBORDER * CDEF_BSTRIDE + CDEF_HBORDER;
diff --git a/test/variance_test.cc b/test/variance_test.cc
index e3c060e..72fc754 100644
--- a/test/variance_test.cc
+++ b/test/variance_test.cc
@@ -2002,6 +2002,15 @@
                                 0)));
 
 #if CONFIG_AV1_HIGHBITDEPTH
+#if HAVE_SSE2
+INSTANTIATE_TEST_SUITE_P(
+    SSE2, MseHBDWxHTest,
+    ::testing::Values(MseHBDWxHParams(3, 3, &aom_mse_wxh_16bit_highbd_sse2, 10),
+                      MseHBDWxHParams(3, 2, &aom_mse_wxh_16bit_highbd_sse2, 10),
+                      MseHBDWxHParams(2, 3, &aom_mse_wxh_16bit_highbd_sse2, 10),
+                      MseHBDWxHParams(2, 2, &aom_mse_wxh_16bit_highbd_sse2,
+                                      10)));
+#endif  // HAVE_SSE2
 #if HAVE_SSE4_1
 INSTANTIATE_TEST_SUITE_P(
     SSE4_1, AvxSubpelVarianceTest,