Add SSE2 variant for mse_wxh_16bit in cdef

Added SSE2 variant for mse_wxh_16bit function
and unit test (MseWxHTest).

Module level gains:
BLOCKSIZE    Gain w.r.t. C
8x8             2.7x
8x4             2.5x
4x8             2.5x
4x4             2.2x

Change-Id: I9cdc4fc609ed074626e6533e14b5385564253523
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index 79fba0a..d6b27a6 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -1130,7 +1130,7 @@
   add_proto qw/unsigned int/, "aom_variance4x2", "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse";
 
   add_proto qw/uint64_t/, "aom_mse_wxh_16bit", "uint8_t *dst, int dstride,uint16_t *src, int sstride, int w, int h";
-  specialize qw/aom_mse_wxh_16bit   avx2/;
+  specialize qw/aom_mse_wxh_16bit  sse2 avx2/;
 
   foreach (@block_sizes) {
     ($w, $h) = @$_;
diff --git a/aom_dsp/x86/variance_sse2.c b/aom_dsp/x86/variance_sse2.c
index 97f71fc..1a24a37 100644
--- a/aom_dsp/x86/variance_sse2.c
+++ b/aom_dsp/x86/variance_sse2.c
@@ -756,3 +756,98 @@
     } while (i < height);
   }
 }
+
+uint64_t aom_mse_4xh_16bit_sse2(uint8_t *dst, int dstride, uint16_t *src,
+                                int sstride, int h) {
+  uint64_t sum = 0;
+  __m128i dst0_8x8, dst1_8x8, dst_16x8;
+  __m128i src0_16x4, src1_16x4, src_16x8;
+  __m128i res0_32x4, res1_32x4, res0_64x4, res1_64x4, res2_64x4, res3_64x4;
+  __m128i sub_result_16x8;
+  const __m128i zeros = _mm_setzero_si128();
+  __m128i square_result = _mm_setzero_si128();
+  for (int i = 0; i < h; i += 2) {
+    dst0_8x8 = _mm_cvtsi32_si128(*(uint32_t const *)(&dst[(i + 0) * dstride]));
+    dst1_8x8 = _mm_cvtsi32_si128(*(uint32_t const *)(&dst[(i + 1) * dstride]));
+    dst_16x8 = _mm_unpacklo_epi8(_mm_unpacklo_epi32(dst0_8x8, dst1_8x8), zeros);
+
+    src0_16x4 = _mm_loadl_epi64((__m128i const *)(&src[(i + 0) * sstride]));
+    src1_16x4 = _mm_loadl_epi64((__m128i const *)(&src[(i + 1) * sstride]));
+    src_16x8 = _mm_unpacklo_epi64(src0_16x4, src1_16x4);
+
+    sub_result_16x8 = _mm_sub_epi16(src_16x8, dst_16x8);
+
+    res0_32x4 = _mm_unpacklo_epi16(sub_result_16x8, zeros);
+    res1_32x4 = _mm_unpackhi_epi16(sub_result_16x8, zeros);
+
+    res0_32x4 = _mm_madd_epi16(res0_32x4, res0_32x4);
+    res1_32x4 = _mm_madd_epi16(res1_32x4, res1_32x4);
+
+    res0_64x4 = _mm_unpacklo_epi32(res0_32x4, zeros);
+    res1_64x4 = _mm_unpackhi_epi32(res0_32x4, zeros);
+    res2_64x4 = _mm_unpacklo_epi32(res1_32x4, zeros);
+    res3_64x4 = _mm_unpackhi_epi32(res1_32x4, zeros);
+
+    square_result = _mm_add_epi64(
+        square_result,
+        _mm_add_epi64(
+            _mm_add_epi64(_mm_add_epi64(res0_64x4, res1_64x4), res2_64x4),
+            res3_64x4));
+  }
+  const __m128i sum_1x64 =
+      _mm_add_epi64(square_result, _mm_srli_si128(square_result, 8));
+  xx_storel_64(&sum, sum_1x64);
+  return sum;
+}
+
+uint64_t aom_mse_8xh_16bit_sse2(uint8_t *dst, int dstride, uint16_t *src,
+                                int sstride, int h) {
+  uint64_t sum = 0;
+  __m128i dst_8x8, dst_16x8;
+  __m128i src_16x8;
+  __m128i res0_32x4, res1_32x4, res0_64x4, res1_64x4, res2_64x4, res3_64x4;
+  __m128i sub_result_16x8;
+  const __m128i zeros = _mm_setzero_si128();
+  __m128i square_result = _mm_setzero_si128();
+
+  for (int i = 0; i < h; i++) {
+    dst_8x8 = _mm_loadl_epi64((__m128i const *)(&dst[(i + 0) * dstride]));
+    dst_16x8 = _mm_unpacklo_epi8(dst_8x8, zeros);
+
+    src_16x8 = _mm_loadu_si128((__m128i *)&src[i * sstride]);
+
+    sub_result_16x8 = _mm_sub_epi16(src_16x8, dst_16x8);
+
+    res0_32x4 = _mm_unpacklo_epi16(sub_result_16x8, zeros);
+    res1_32x4 = _mm_unpackhi_epi16(sub_result_16x8, zeros);
+
+    res0_32x4 = _mm_madd_epi16(res0_32x4, res0_32x4);
+    res1_32x4 = _mm_madd_epi16(res1_32x4, res1_32x4);
+
+    res0_64x4 = _mm_unpacklo_epi32(res0_32x4, zeros);
+    res1_64x4 = _mm_unpackhi_epi32(res0_32x4, zeros);
+    res2_64x4 = _mm_unpacklo_epi32(res1_32x4, zeros);
+    res3_64x4 = _mm_unpackhi_epi32(res1_32x4, zeros);
+
+    square_result = _mm_add_epi64(
+        square_result,
+        _mm_add_epi64(
+            _mm_add_epi64(_mm_add_epi64(res0_64x4, res1_64x4), res2_64x4),
+            res3_64x4));
+  }
+  const __m128i sum_1x64 =
+      _mm_add_epi64(square_result, _mm_srli_si128(square_result, 8));
+  xx_storel_64(&sum, sum_1x64);
+  return sum;
+}
+
+uint64_t aom_mse_wxh_16bit_sse2(uint8_t *dst, int dstride, uint16_t *src,
+                                int sstride, int w, int h) {
+  assert((w == 8 || w == 4) && (h == 8 || h == 4) &&
+         "w=8/4 and h=8/4 must satisfy");
+  switch (w) {
+    case 4: return aom_mse_4xh_16bit_sse2(dst, dstride, src, sstride, h);
+    case 8: return aom_mse_8xh_16bit_sse2(dst, dstride, src, sstride, h);
+    default: assert(0 && "unsupported width"); return -1;
+  }
+}
diff --git a/test/variance_test.cc b/test/variance_test.cc
index 4c016af..9932bef 100644
--- a/test/variance_test.cc
+++ b/test/variance_test.cc
@@ -1172,7 +1172,7 @@
          params_.bit_depth, elapsed_time);
 }
 
-typedef MseWxHTestClass<MseWxH16bitFunc> AvxMseWxHTest;
+typedef MseWxHTestClass<MseWxH16bitFunc> MseWxHTest;
 typedef MainTestClass<Get4x4SseFunc> AvxSseTest;
 typedef MainTestClass<VarianceMxNFunc> AvxMseTest;
 typedef MainTestClass<VarianceMxNFunc> AvxVarianceTest;
@@ -1181,11 +1181,12 @@
 typedef SubpelVarianceTest<DistWtdSubpixAvgVarMxNFunc>
     AvxDistWtdSubpelAvgVarianceTest;
 typedef ObmcVarianceTest<ObmcSubpelVarFunc> AvxObmcSubpelVarianceTest;
+typedef TestParams<MseWxH16bitFunc> MseWxHParams;
 
 TEST_P(AvxSseTest, RefSse) { RefTestSse(); }
 TEST_P(AvxSseTest, MaxSse) { MaxTestSse(); }
-TEST_P(AvxMseWxHTest, RefMse) { RefMatchTestMse(); }
-TEST_P(AvxMseWxHTest, DISABLED_SpeedMse) { SpeedTest(); }
+TEST_P(MseWxHTest, RefMse) { RefMatchTestMse(); }
+TEST_P(MseWxHTest, DISABLED_SpeedMse) { SpeedTest(); }
 TEST_P(AvxMseTest, RefMse) { RefTestMse(); }
 TEST_P(AvxMseTest, MaxMse) { MaxTestMse(); }
 TEST_P(AvxVarianceTest, Zero) { ZeroTest(); }
@@ -1795,6 +1796,13 @@
 #endif  // CONFIG_AV1_HIGHBITDEPTH
 
 #if HAVE_SSE2
+INSTANTIATE_TEST_SUITE_P(
+    SSE2, MseWxHTest,
+    ::testing::Values(MseWxHParams(3, 3, &aom_mse_wxh_16bit_sse2, 8),
+                      MseWxHParams(3, 2, &aom_mse_wxh_16bit_sse2, 8),
+                      MseWxHParams(2, 3, &aom_mse_wxh_16bit_sse2, 8),
+                      MseWxHParams(2, 2, &aom_mse_wxh_16bit_sse2, 8)));
+
 INSTANTIATE_TEST_SUITE_P(SSE2, SumOfSquaresTest,
                          ::testing::Values(aom_get_mb_ss_sse2));
 
@@ -2380,9 +2388,8 @@
 
 #if HAVE_AVX2
 
-typedef TestParams<MseWxH16bitFunc> MseWxHParams;
 INSTANTIATE_TEST_SUITE_P(
-    AVX2, AvxMseWxHTest,
+    AVX2, MseWxHTest,
     ::testing::Values(MseWxHParams(3, 3, &aom_mse_wxh_16bit_avx2, 8),
                       MseWxHParams(3, 2, &aom_mse_wxh_16bit_avx2, 8),
                       MseWxHParams(2, 3, &aom_mse_wxh_16bit_avx2, 8),