Add AVX2 variant for sum_squares_2d

Achieved module level gains improved by ~71%
on average for width >= 16 and observed no drop
in gains for remaining widths.

Change-Id: Ica295d97b9854a2a22e3af9cb8671f91376e2562
diff --git a/aom_dsp/aom_dsp.cmake b/aom_dsp/aom_dsp.cmake
index a76f88f..e2bf87b 100644
--- a/aom_dsp/aom_dsp.cmake
+++ b/aom_dsp/aom_dsp.cmake
@@ -67,7 +67,8 @@
             "${AOM_ROOT}/aom_dsp/x86/lpf_common_sse2.h"
             "${AOM_ROOT}/aom_dsp/x86/mem_sse2.h"
             "${AOM_ROOT}/aom_dsp/x86/transpose_sse2.h"
-            "${AOM_ROOT}/aom_dsp/x86/txfm_common_sse2.h")
+            "${AOM_ROOT}/aom_dsp/x86/txfm_common_sse2.h"
+            "${AOM_ROOT}/aom_dsp/x86/sum_squares_sse2.h")
 
 list(APPEND AOM_DSP_COMMON_ASM_SSSE3
             "${AOM_ROOT}/aom_dsp/x86/aom_subpixel_8t_ssse3.asm"
@@ -201,7 +202,8 @@
               "${AOM_ROOT}/aom_dsp/x86/sse_avx2.c"
               "${AOM_ROOT}/aom_dsp/x86/variance_impl_avx2.c"
               "${AOM_ROOT}/aom_dsp/x86/obmc_sad_avx2.c"
-              "${AOM_ROOT}/aom_dsp/x86/obmc_variance_avx2.c")
+              "${AOM_ROOT}/aom_dsp/x86/obmc_variance_avx2.c"
+              "${AOM_ROOT}/aom_dsp/x86/sum_squares_avx2.c")
 
   list(APPEND AOM_DSP_ENCODER_ASM_SSSE3_X86_64
               "${AOM_ROOT}/aom_dsp/x86/quantize_ssse3_x86_64.asm")
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index 27e95b9..50af6fd 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -580,7 +580,7 @@
     # Sum of Squares
     #
     add_proto qw/uint64_t aom_sum_squares_2d_i16/, "const int16_t *src, int stride, int width, int height";
-    specialize qw/aom_sum_squares_2d_i16 sse2/;
+    specialize qw/aom_sum_squares_2d_i16 sse2 avx2/;
 
     add_proto qw/uint64_t aom_sum_squares_i16/, "const int16_t *src, uint32_t N";
     specialize qw/aom_sum_squares_i16 sse2/;
diff --git a/aom_dsp/x86/sum_squares_avx2.c b/aom_dsp/x86/sum_squares_avx2.c
new file mode 100644
index 0000000..4c00b15
--- /dev/null
+++ b/aom_dsp/x86/sum_squares_avx2.c
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) 2018, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include <immintrin.h>
+#include <smmintrin.h>
+
+#include "aom_dsp/x86/synonyms.h"
+#include "aom_dsp/x86/sum_squares_sse2.h"
+#include "config/aom_dsp_rtcd.h"
+
+static uint64_t aom_sum_squares_2d_i16_nxn_avx2(const int16_t *src, int stride,
+                                                int width, int height) {
+  uint64_t result;
+  __m256i v_acc_q = _mm256_setzero_si256();
+  const __m256i v_zext_mask_q = _mm256_set1_epi64x(0xffffffff);
+  for (int col = 0; col < height; col += 4) {
+    __m256i v_acc_d = _mm256_setzero_si256();
+    for (int row = 0; row < width; row += 16) {
+      const int16_t *tempsrc = src + row;
+      const __m256i v_val_0_w =
+          _mm256_loadu_si256((const __m256i *)(tempsrc + 0 * stride));
+      const __m256i v_val_1_w =
+          _mm256_loadu_si256((const __m256i *)(tempsrc + 1 * stride));
+      const __m256i v_val_2_w =
+          _mm256_loadu_si256((const __m256i *)(tempsrc + 2 * stride));
+      const __m256i v_val_3_w =
+          _mm256_loadu_si256((const __m256i *)(tempsrc + 3 * stride));
+
+      const __m256i v_sq_0_d = _mm256_madd_epi16(v_val_0_w, v_val_0_w);
+      const __m256i v_sq_1_d = _mm256_madd_epi16(v_val_1_w, v_val_1_w);
+      const __m256i v_sq_2_d = _mm256_madd_epi16(v_val_2_w, v_val_2_w);
+      const __m256i v_sq_3_d = _mm256_madd_epi16(v_val_3_w, v_val_3_w);
+
+      const __m256i v_sum_01_d = _mm256_add_epi32(v_sq_0_d, v_sq_1_d);
+      const __m256i v_sum_23_d = _mm256_add_epi32(v_sq_2_d, v_sq_3_d);
+      const __m256i v_sum_0123_d = _mm256_add_epi32(v_sum_01_d, v_sum_23_d);
+
+      v_acc_d = _mm256_add_epi32(v_acc_d, v_sum_0123_d);
+    }
+    v_acc_q =
+        _mm256_add_epi64(v_acc_q, _mm256_and_si256(v_acc_d, v_zext_mask_q));
+    v_acc_q = _mm256_add_epi64(v_acc_q, _mm256_srli_epi64(v_acc_d, 32));
+    src += 4 * stride;
+  }
+  __m128i lower_64_2_Value = _mm256_castsi256_si128(v_acc_q);
+  __m128i higher_64_2_Value = _mm256_extracti128_si256(v_acc_q, 1);
+  __m128i result_64_2_int = _mm_add_epi64(lower_64_2_Value, higher_64_2_Value);
+
+  result_64_2_int = _mm_add_epi64(
+      result_64_2_int, _mm_unpackhi_epi64(result_64_2_int, result_64_2_int));
+
+  xx_storel_64(&result, result_64_2_int);
+
+  return result;
+}
+
+uint64_t aom_sum_squares_2d_i16_avx2(const int16_t *src, int stride, int width,
+                                     int height) {
+  if (LIKELY(width == 4 && height == 4)) {
+    return aom_sum_squares_2d_i16_4x4_sse2(src, stride);
+  } else if (LIKELY(width == 4 && (height & 3) == 0)) {
+    return aom_sum_squares_2d_i16_4xn_sse2(src, stride, height);
+  } else if (LIKELY(width == 8 && (height & 3) == 0)) {
+    return aom_sum_squares_2d_i16_nxn_sse2(src, stride, width, height);
+  } else if (LIKELY(((width & 15) == 0) && ((height & 3) == 0))) {
+    return aom_sum_squares_2d_i16_nxn_avx2(src, stride, width, height);
+  } else {
+    return aom_sum_squares_2d_i16_c(src, stride, width, height);
+  }
+}
diff --git a/aom_dsp/x86/sum_squares_sse2.c b/aom_dsp/x86/sum_squares_sse2.c
index a79f22d..22d7739 100644
--- a/aom_dsp/x86/sum_squares_sse2.c
+++ b/aom_dsp/x86/sum_squares_sse2.c
@@ -14,6 +14,7 @@
 #include <stdio.h>
 
 #include "aom_dsp/x86/synonyms.h"
+#include "aom_dsp/x86/sum_squares_sse2.h"
 #include "config/aom_dsp_rtcd.h"
 
 static INLINE __m128i xx_loadh_64(__m128i a, const void *b) {
@@ -44,8 +45,7 @@
   return _mm_add_epi32(v_sq_01_d, v_sq_23_d);
 }
 
-static uint64_t aom_sum_squares_2d_i16_4x4_sse2(const int16_t *src,
-                                                int stride) {
+uint64_t aom_sum_squares_2d_i16_4x4_sse2(const int16_t *src, int stride) {
   const __m128i v_sum_0123_d = sum_squares_i16_4x4_sse2(src, stride);
   __m128i v_sum_d =
       _mm_add_epi32(v_sum_0123_d, _mm_srli_epi64(v_sum_0123_d, 32));
@@ -53,8 +53,8 @@
   return (uint64_t)_mm_cvtsi128_si32(v_sum_d);
 }
 
-static uint64_t aom_sum_squares_2d_i16_4xn_sse2(const int16_t *src, int stride,
-                                                int height) {
+uint64_t aom_sum_squares_2d_i16_4xn_sse2(const int16_t *src, int stride,
+                                         int height) {
   int r = 0;
   __m128i v_acc_q = _mm_setzero_si128();
   do {
@@ -76,7 +76,7 @@
 // maintenance instructions in the common case of 4x4.
 __attribute__((noinline))
 #endif
-static uint64_t
+uint64_t
 aom_sum_squares_2d_i16_nxn_sse2(const int16_t *src, int stride, int width,
                                 int height) {
   int r = 0;
diff --git a/aom_dsp/x86/sum_squares_sse2.h b/aom_dsp/x86/sum_squares_sse2.h
new file mode 100644
index 0000000..491e31c
--- /dev/null
+++ b/aom_dsp/x86/sum_squares_sse2.h
@@ -0,0 +1,22 @@
+/*
+ * Copyright (c) 2018, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#ifndef AOM_DSP_X86_SUM_SQUARES_SSE2_H_
+#define AOM_DSP_X86_SUM_SQUARES_SSE2_H_
+
+uint64_t aom_sum_squares_2d_i16_nxn_sse2(const int16_t *src, int stride,
+                                         int width, int height);
+
+uint64_t aom_sum_squares_2d_i16_4xn_sse2(const int16_t *src, int stride,
+                                         int height);
+uint64_t aom_sum_squares_2d_i16_4x4_sse2(const int16_t *src, int stride);
+
+#endif  // AOM_DSP_X86_SUM_SQUARES_SSE2_H_
diff --git a/test/sum_squares_test.cc b/test/sum_squares_test.cc
index c03ebad..f109984 100644
--- a/test/sum_squares_test.cc
+++ b/test/sum_squares_test.cc
@@ -52,6 +52,7 @@
     aom_free(src_);
   }
   void RunTest(int isRandom);
+  void RunSpeedTest();
 
   void GenRandomData(int width, int height, int stride) {
     const int msb = 11;  // Up to 12 bit input
@@ -108,6 +109,38 @@
   }
 }
 
+void SumSquaresTest::RunSpeedTest() {
+  for (int block = BLOCK_4X4; block < BLOCK_SIZES_ALL; block++) {
+    const int width = block_size_wide[block];   // Up to 128x128
+    const int height = block_size_high[block];  // Up to 128x128
+    int stride = 4 << rnd_(7);                  // Up to 256 stride
+    while (stride < width) {                    // Make sure it's valid
+      stride = 4 << rnd_(7);
+    }
+    GenExtremeData(width, height, stride);
+    const int num_loops = 1000000000 / (width + height);
+    aom_usec_timer timer;
+    aom_usec_timer_start(&timer);
+
+    for (int i = 0; i < num_loops; ++i)
+      params_.ref_func(src_, stride, width, height);
+
+    aom_usec_timer_mark(&timer);
+    const int elapsed_time = static_cast<int>(aom_usec_timer_elapsed(&timer));
+    printf("SumSquaresTest C %3dx%-3d: %7.2f ns\n", width, height,
+           1000.0 * elapsed_time / num_loops);
+
+    aom_usec_timer timer1;
+    aom_usec_timer_start(&timer1);
+    for (int i = 0; i < num_loops; ++i)
+      params_.tst_func(src_, stride, width, height);
+    aom_usec_timer_mark(&timer1);
+    const int elapsed_time1 = static_cast<int>(aom_usec_timer_elapsed(&timer1));
+    printf("SumSquaresTest Test %3dx%-3d: %7.2f ns\n", width, height,
+           1000.0 * elapsed_time1 / num_loops);
+  }
+}
+
 TEST_P(SumSquaresTest, OperationCheck) {
   RunTest(1);  // GenRandomData
 }
@@ -116,6 +149,8 @@
   RunTest(0);  // GenExtremeData
 }
 
+TEST_P(SumSquaresTest, DISABLED_Speed) { RunSpeedTest(); }
+
 #if HAVE_SSE2
 
 INSTANTIATE_TEST_CASE_P(
@@ -125,6 +160,13 @@
 
 #endif  // HAVE_SSE2
 
+#if HAVE_AVX2
+INSTANTIATE_TEST_CASE_P(
+    AVX2, SumSquaresTest,
+    ::testing::Values(TestFuncs(&aom_sum_squares_2d_i16_c,
+                                &aom_sum_squares_2d_i16_avx2)));
+#endif  // HAVE_AVX2
+
 //////////////////////////////////////////////////////////////////////////////
 // 1D version
 //////////////////////////////////////////////////////////////////////////////