Add AVX2 SAD functions for more high-bitdepth block sizes

Adds AVX2 implementations of aom_highbd_sad{16x4,32x8,16x64,64x16},
their _avg variants, and the corresponding x4d versions; wires them up
in aom_dsp_rtcd_defs.pl and extends sad_test.cc to cover the new paths.

Change-Id: I2b656dbc5467434180abb0e4a7768d530f304b2b
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index 8c51e7f..ad16f01 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -683,8 +683,8 @@
   specialize qw/aom_dist_wtd_sad16x4_avg     ssse3/;
   specialize qw/aom_dist_wtd_sad8x32_avg     ssse3/;
   specialize qw/aom_dist_wtd_sad32x8_avg     ssse3/;
-  specialize qw/aom_dist_wtd_sad16x64_avg     ssse3/;
-  specialize qw/aom_dist_wtd_sad64x16_avg     ssse3/;
+  specialize qw/aom_dist_wtd_sad16x64_avg    ssse3/;
+  specialize qw/aom_dist_wtd_sad64x16_avg    ssse3/;
 
   add_proto qw/unsigned int/, "aom_sad4xh", "const uint8_t *a, int a_stride, const uint8_t *b, int b_stride, int width, int height";
   add_proto qw/unsigned int/, "aom_sad8xh", "const uint8_t *a, int a_stride, const uint8_t *b, int b_stride, int width, int height";
@@ -737,17 +737,17 @@
     specialize qw/aom_highbd_sad16x8_avg    avx2 sse2/;
     specialize qw/aom_highbd_sad8x4_avg     sse2/;
 
-    specialize qw/aom_highbd_sad16x4       sse2/;
-    specialize qw/aom_highbd_sad8x32       sse2/;
-    specialize qw/aom_highbd_sad32x8       sse2/;
-    specialize qw/aom_highbd_sad16x64      sse2/;
-    specialize qw/aom_highbd_sad64x16      sse2/;
+    specialize qw/aom_highbd_sad16x4        avx2 sse2/;
+    specialize qw/aom_highbd_sad8x32        sse2/;
+    specialize qw/aom_highbd_sad32x8        avx2 sse2/;
+    specialize qw/aom_highbd_sad16x64       avx2 sse2/;
+    specialize qw/aom_highbd_sad64x16       avx2 sse2/;
 
-    specialize qw/aom_highbd_sad16x4_avg   sse2/;
-    specialize qw/aom_highbd_sad8x32_avg   sse2/;
-    specialize qw/aom_highbd_sad32x8_avg   sse2/;
-    specialize qw/aom_highbd_sad16x64_avg  sse2/;
-    specialize qw/aom_highbd_sad64x16_avg  sse2/;
+    specialize qw/aom_highbd_sad16x4_avg    avx2 sse2/;
+    specialize qw/aom_highbd_sad8x32_avg    sse2/;
+    specialize qw/aom_highbd_sad32x8_avg    avx2 sse2/;
+    specialize qw/aom_highbd_sad16x64_avg   avx2 sse2/;
+    specialize qw/aom_highbd_sad64x16_avg   avx2 sse2/;
 
   #
   # Masked SAD
@@ -846,12 +846,12 @@
   specialize qw/aom_highbd_sad4x8x4d     sse2/;
   specialize qw/aom_highbd_sad4x4x4d     sse2/;
 
-  specialize qw/aom_highbd_sad4x16x4d  sse2/;
-  specialize qw/aom_highbd_sad16x4x4d  sse2/;
-  specialize qw/aom_highbd_sad8x32x4d  sse2/;
-  specialize qw/aom_highbd_sad32x8x4d  sse2/;
-  specialize qw/aom_highbd_sad16x64x4d sse2/;
-  specialize qw/aom_highbd_sad64x16x4d sse2/;
+  specialize qw/aom_highbd_sad4x16x4d         sse2/;
+  specialize qw/aom_highbd_sad16x4x4d    avx2 sse2/;
+  specialize qw/aom_highbd_sad8x32x4d         sse2/;
+  specialize qw/aom_highbd_sad32x8x4d    avx2 sse2/;
+  specialize qw/aom_highbd_sad16x64x4d   avx2 sse2/;
+  specialize qw/aom_highbd_sad64x16x4d   avx2 sse2/;
 
   #
   # Avg
diff --git a/aom_dsp/x86/sad_highbd_avx2.c b/aom_dsp/x86/sad_highbd_avx2.c
index b506d46..eba442c 100644
--- a/aom_dsp/x86/sad_highbd_avx2.c
+++ b/aom_dsp/x86/sad_highbd_avx2.c
@@ -229,6 +229,23 @@
   return get_sad_from_mm256_epi32(&sad);
 }
 
+unsigned int aom_highbd_sad32x8_avx2(const uint8_t *src, int src_stride,
+                                     const uint8_t *ref, int ref_stride) {
+  __m256i sad = _mm256_setzero_si256();
+  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
+  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
+  const int left_shift = 2;
+  int row_section = 0;
+
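+  // Two passes of sad32x4() (each covering 4 rows of 32 pixels) sum all
+  // 8 rows; left_shift = 2 advances srcp/refp by 4 rows per pass.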
+  while (row_section < 2) {
+    sad32x4(srcp, src_stride, refp, ref_stride, NULL, &sad);
+    srcp += src_stride << left_shift;
+    refp += ref_stride << left_shift;
+    row_section += 1;
+  }
+  return get_sad_from_mm256_epi32(&sad);
+}
+
 unsigned int aom_highbd_sad16x32_avx2(const uint8_t *src, int src_stride,
                                       const uint8_t *ref, int ref_stride) {
   uint32_t sum = aom_highbd_sad16x16_avx2(src, src_stride, ref, ref_stride);
@@ -352,6 +369,23 @@
   return get_sad_from_mm256_epi32(&sad);
 }
 
+unsigned int aom_highbd_sad64x16_avx2(const uint8_t *src, int src_stride,
+                                      const uint8_t *ref, int ref_stride) {
+  __m256i sad = _mm256_setzero_si256();
+  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
+  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
+  const int left_shift = 1;
+  int row_section = 0;
+
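+  // Eight passes of sad64x2() (each covering 2 rows of 64 pixels) sum all
+  // 16 rows; left_shift = 1 advances srcp/refp by 2 rows per pass.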
+  while (row_section < 8) {
+    sad64x2(srcp, src_stride, refp, ref_stride, NULL, &sad);
+    srcp += src_stride << left_shift;
+    refp += ref_stride << left_shift;
+    row_section += 1;
+  }
+  return get_sad_from_mm256_epi32(&sad);
+}
+
 unsigned int aom_highbd_sad64x64_avx2(const uint8_t *src, int src_stride,
                                       const uint8_t *ref, int ref_stride) {
   uint32_t sum = aom_highbd_sad64x32_avx2(src, src_stride, ref, ref_stride);
@@ -520,6 +554,27 @@
   *sad_acc = _mm256_add_epi32(*sad_acc, r0);
 }
 
+unsigned int aom_highbd_sad16x4_avx2(const uint8_t *src, int src_stride,
+                                     const uint8_t *ref, int ref_stride) {
+  __m256i sad = _mm256_setzero_si256();
+  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
+  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
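+  // A single sad16x4() pass covers the entire 16x4 block.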
+  sad16x4(srcp, src_stride, refp, ref_stride, NULL, &sad);
+  return get_sad_from_mm256_epi32(&sad);
+}
+
+unsigned int aom_highbd_sad16x4_avg_avx2(const uint8_t *src, int src_stride,
+                                         const uint8_t *ref, int ref_stride,
+                                         const uint8_t *second_pred) {
+  __m256i sad = _mm256_setzero_si256();
+  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
+  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
+  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
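+  // One sad16x4() pass; secp supplies the second prediction for averaging.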
+  sad16x4(srcp, src_stride, refp, ref_stride, secp, &sad);
+
+  return get_sad_from_mm256_epi32(&sad);
+}
+
 unsigned int aom_highbd_sad16x8_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
@@ -566,6 +621,50 @@
   return sum;
 }
 
+unsigned int aom_highbd_sad16x64_avx2(const uint8_t *src, int src_stride,
+                                      const uint8_t *ref, int ref_stride) {
+  const int left_shift = 5;
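+  // 16x64 = two stacked 16x32 halves; left_shift = 5 skips 32 rows
+  // between the two passes.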
+  uint32_t sum = aom_highbd_sad16x32_avx2(src, src_stride, ref, ref_stride);
+  src += src_stride << left_shift;
+  ref += ref_stride << left_shift;
+  sum += aom_highbd_sad16x32_avx2(src, src_stride, ref, ref_stride);
+  return sum;
+}
+
+unsigned int aom_highbd_sad16x64_avg_avx2(const uint8_t *src, int src_stride,
+                                          const uint8_t *ref, int ref_stride,
+                                          const uint8_t *second_pred) {
+  const int left_shift = 5;
+  uint32_t sum = aom_highbd_sad16x32_avg_avx2(src, src_stride, ref, ref_stride,
+                                              second_pred);
+  src += src_stride << left_shift;
+  ref += ref_stride << left_shift;
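+  // Advance the 16-wide comp pred buffer past its first 32 rows
+  // (16 << 5 = 512 samples).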
+  second_pred += 16 << left_shift;
+  sum += aom_highbd_sad16x32_avg_avx2(src, src_stride, ref, ref_stride,
+                                      second_pred);
+  return sum;
+}
+
+unsigned int aom_highbd_sad32x8_avg_avx2(const uint8_t *src, int src_stride,
+                                         const uint8_t *ref, int ref_stride,
+                                         const uint8_t *second_pred) {
+  __m256i sad = _mm256_setzero_si256();
+  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
+  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
+  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
+  const int left_shift = 2;
+  int row_section = 0;
+
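+  // Same two-pass loop as aom_highbd_sad32x8_avx2(); secp advances by
+  // 4 rows of 32 comp-pred samples per pass.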
+  while (row_section < 2) {
+    sad32x4(srcp, src_stride, refp, ref_stride, secp, &sad);
+    srcp += src_stride << left_shift;
+    refp += ref_stride << left_shift;
+    secp += 32 << left_shift;
+    row_section += 1;
+  }
+  return get_sad_from_mm256_epi32(&sad);
+}
+
 unsigned int aom_highbd_sad32x16_avg_avx2(const uint8_t *src, int src_stride,
                                           const uint8_t *ref, int ref_stride,
                                           const uint8_t *second_pred) {
@@ -614,6 +713,26 @@
   return sum;
 }
 
+unsigned int aom_highbd_sad64x16_avg_avx2(const uint8_t *src, int src_stride,
+                                          const uint8_t *ref, int ref_stride,
+                                          const uint8_t *second_pred) {
+  __m256i sad = _mm256_setzero_si256();
+  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
+  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
+  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
+  const int left_shift = 1;
+  int row_section = 0;
+
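+  // Same eight-pass loop as aom_highbd_sad64x16_avx2(); secp advances by
+  // 2 rows of 64 comp-pred samples per pass.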
+  while (row_section < 8) {
+    sad64x2(srcp, src_stride, refp, ref_stride, secp, &sad);
+    srcp += src_stride << left_shift;
+    refp += ref_stride << left_shift;
+    secp += 64 << left_shift;
+    row_section += 1;
+  }
+  return get_sad_from_mm256_epi32(&sad);
+}
+
 unsigned int aom_highbd_sad64x32_avg_avx2(const uint8_t *src, int src_stride,
                                           const uint8_t *ref, int ref_stride,
                                           const uint8_t *second_pred) {
@@ -752,6 +871,25 @@
   s[3] = _mm256_setzero_si256();
 }
 
+void aom_highbd_sad16x4x4d_avx2(const uint8_t *src, int src_stride,
+                                const uint8_t *const ref_array[],
+                                int ref_stride, uint32_t *sad_array) {
+  __m256i sad_vec[4];
+  const uint16_t *refp[4];
+  const uint16_t *keep = CONVERT_TO_SHORTPTR(src);
+  const uint16_t *srcp;
+  int i;
+
+  init_sad(sad_vec);
+  convert_pointers(ref_array, refp);
+
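+  // One sad16x4() pass per reference, with srcp reset to the block
+  // start (keep) each iteration.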
+  for (i = 0; i < 4; ++i) {
+    srcp = keep;
+    sad16x4(srcp, src_stride, refp[i], ref_stride, NULL, &sad_vec[i]);
+  }
+  get_4d_sad_from_mm256_epi32(sad_vec, sad_array);
+}
+
 void aom_highbd_sad16x8x4d_avx2(const uint8_t *src, int src_stride,
                                 const uint8_t *const ref_array[],
                                 int ref_stride, uint32_t *sad_array) {
@@ -827,6 +965,59 @@
   sad_array[3] = first_half[3] + second_half[3];
 }
 
+void aom_highbd_sad16x64x4d_avx2(const uint8_t *src, int src_stride,
+                                 const uint8_t *const ref_array[],
+                                 int ref_stride, uint32_t *sad_array) {
+  uint32_t first_half[4];
+  uint32_t second_half[4];
+  const uint8_t *ref[4];
+  const int shift_for_rows = 5;
+
+  ref[0] = ref_array[0];
+  ref[1] = ref_array[1];
+  ref[2] = ref_array[2];
+  ref[3] = ref_array[3];
+
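+  // Two 16x32x4d passes: top half, then bottom half after advancing
+  // src and all four refs by 32 rows.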
+  aom_highbd_sad16x32x4d_avx2(src, src_stride, ref, ref_stride, first_half);
+  src += src_stride << shift_for_rows;
+  ref[0] += ref_stride << shift_for_rows;
+  ref[1] += ref_stride << shift_for_rows;
+  ref[2] += ref_stride << shift_for_rows;
+  ref[3] += ref_stride << shift_for_rows;
+  aom_highbd_sad16x32x4d_avx2(src, src_stride, ref, ref_stride, second_half);
+  sad_array[0] = first_half[0] + second_half[0];
+  sad_array[1] = first_half[1] + second_half[1];
+  sad_array[2] = first_half[2] + second_half[2];
+  sad_array[3] = first_half[3] + second_half[3];
+}
+
+void aom_highbd_sad32x8x4d_avx2(const uint8_t *src, int src_stride,
+                                const uint8_t *const ref_array[],
+                                int ref_stride, uint32_t *sad_array) {
+  __m256i sad_vec[4];
+  const uint16_t *refp[4];
+  const uint16_t *keep = CONVERT_TO_SHORTPTR(src);
+  const uint16_t *srcp;
+  const int shift_for_4_rows = 2;
+  int i;
+  int rows_section;
+
+  init_sad(sad_vec);
+  convert_pointers(ref_array, refp);
+
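+  // For each reference, reset srcp to the block start (keep) and run
+  // two sad32x4() passes to cover the 8 rows.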
+  for (i = 0; i < 4; ++i) {
+    srcp = keep;
+    rows_section = 0;
+    while (rows_section < 2) {
+      sad32x4(srcp, src_stride, refp[i], ref_stride, NULL, &sad_vec[i]);
+      srcp += src_stride << shift_for_4_rows;
+      refp[i] += ref_stride << shift_for_4_rows;
+      rows_section++;
+    }
+  }
+  get_4d_sad_from_mm256_epi32(sad_vec, sad_array);
+}
+
 void aom_highbd_sad32x16x4d_avx2(const uint8_t *src, int src_stride,
                                  const uint8_t *const ref_array[],
                                  int ref_stride, uint32_t *sad_array) {
@@ -906,6 +1097,33 @@
   sad_array[3] = first_half[3] + second_half[3];
 }
 
+void aom_highbd_sad64x16x4d_avx2(const uint8_t *src, int src_stride,
+                                 const uint8_t *const ref_array[],
+                                 int ref_stride, uint32_t *sad_array) {
+  __m256i sad_vec[4];
+  const uint16_t *refp[4];
+  const uint16_t *keep = CONVERT_TO_SHORTPTR(src);
+  const uint16_t *srcp;
+  const int shift_for_rows = 1;
+  int i;
+  int rows_section;
+
+  init_sad(sad_vec);
+  convert_pointers(ref_array, refp);
+
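+  // For each reference, reset srcp and run eight sad64x2() passes to
+  // cover the 16 rows.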
+  for (i = 0; i < 4; ++i) {
+    srcp = keep;
+    rows_section = 0;
+    while (rows_section < 8) {
+      sad64x2(srcp, src_stride, refp[i], ref_stride, NULL, &sad_vec[i]);
+      srcp += src_stride << shift_for_rows;
+      refp[i] += ref_stride << shift_for_rows;
+      rows_section++;
+    }
+  }
+  get_4d_sad_from_mm256_epi32(sad_vec, sad_array);
+}
+
 void aom_highbd_sad64x32x4d_avx2(const uint8_t *src, int src_stride,
                                  const uint8_t *const ref_array[],
                                  int ref_stride, uint32_t *sad_array) {
diff --git a/test/sad_test.cc b/test/sad_test.cc
index 3bb6150..a4fd08d 100644
--- a/test/sad_test.cc
+++ b/test/sad_test.cc
@@ -1582,6 +1582,19 @@
   make_tuple(16, 8, &aom_highbd_sad16x8_avx2, 8),
   make_tuple(16, 8, &aom_highbd_sad16x8_avx2, 10),
   make_tuple(16, 8, &aom_highbd_sad16x8_avx2, 12),
+
+  make_tuple(64, 16, &aom_highbd_sad64x16_avx2, 8),
+  make_tuple(64, 16, &aom_highbd_sad64x16_avx2, 10),
+  make_tuple(64, 16, &aom_highbd_sad64x16_avx2, 12),
+  make_tuple(16, 64, &aom_highbd_sad16x64_avx2, 8),
+  make_tuple(16, 64, &aom_highbd_sad16x64_avx2, 10),
+  make_tuple(16, 64, &aom_highbd_sad16x64_avx2, 12),
+  make_tuple(32, 8, &aom_highbd_sad32x8_avx2, 8),
+  make_tuple(32, 8, &aom_highbd_sad32x8_avx2, 10),
+  make_tuple(32, 8, &aom_highbd_sad32x8_avx2, 12),
+  make_tuple(16, 4, &aom_highbd_sad16x4_avx2, 8),
+  make_tuple(16, 4, &aom_highbd_sad16x4_avx2, 10),
+  make_tuple(16, 4, &aom_highbd_sad16x4_avx2, 12),
 };
 INSTANTIATE_TEST_CASE_P(AVX2, SADTest, ::testing::ValuesIn(avx2_tests));
 
@@ -1627,6 +1640,19 @@
   make_tuple(16, 8, &aom_highbd_sad16x8_avg_avx2, 8),
   make_tuple(16, 8, &aom_highbd_sad16x8_avg_avx2, 10),
   make_tuple(16, 8, &aom_highbd_sad16x8_avg_avx2, 12),
+
+  make_tuple(64, 16, &aom_highbd_sad64x16_avg_avx2, 8),
+  make_tuple(64, 16, &aom_highbd_sad64x16_avg_avx2, 10),
+  make_tuple(64, 16, &aom_highbd_sad64x16_avg_avx2, 12),
+  make_tuple(16, 64, &aom_highbd_sad16x64_avg_avx2, 8),
+  make_tuple(16, 64, &aom_highbd_sad16x64_avg_avx2, 10),
+  make_tuple(16, 64, &aom_highbd_sad16x64_avg_avx2, 12),
+  make_tuple(32, 8, &aom_highbd_sad32x8_avg_avx2, 8),
+  make_tuple(32, 8, &aom_highbd_sad32x8_avg_avx2, 10),
+  make_tuple(32, 8, &aom_highbd_sad32x8_avg_avx2, 12),
+  make_tuple(16, 4, &aom_highbd_sad16x4_avg_avx2, 8),
+  make_tuple(16, 4, &aom_highbd_sad16x4_avg_avx2, 10),
+  make_tuple(16, 4, &aom_highbd_sad16x4_avg_avx2, 12),
 };
 INSTANTIATE_TEST_CASE_P(AVX2, SADavgTest, ::testing::ValuesIn(avg_avx2_tests));
 
@@ -1671,6 +1697,19 @@
   make_tuple(16, 8, &aom_highbd_sad16x8x4d_avx2, 8),
   make_tuple(16, 8, &aom_highbd_sad16x8x4d_avx2, 10),
   make_tuple(16, 8, &aom_highbd_sad16x8x4d_avx2, 12),
+
+  make_tuple(16, 64, &aom_highbd_sad16x64x4d_avx2, 8),
+  make_tuple(16, 64, &aom_highbd_sad16x64x4d_avx2, 10),
+  make_tuple(16, 64, &aom_highbd_sad16x64x4d_avx2, 12),
+  make_tuple(64, 16, &aom_highbd_sad64x16x4d_avx2, 8),
+  make_tuple(64, 16, &aom_highbd_sad64x16x4d_avx2, 10),
+  make_tuple(64, 16, &aom_highbd_sad64x16x4d_avx2, 12),
+  make_tuple(32, 8, &aom_highbd_sad32x8x4d_avx2, 8),
+  make_tuple(32, 8, &aom_highbd_sad32x8x4d_avx2, 10),
+  make_tuple(32, 8, &aom_highbd_sad32x8x4d_avx2, 12),
+  make_tuple(16, 4, &aom_highbd_sad16x4x4d_avx2, 8),
+  make_tuple(16, 4, &aom_highbd_sad16x4x4d_avx2, 10),
+  make_tuple(16, 4, &aom_highbd_sad16x4x4d_avx2, 12),
 };
 INSTANTIATE_TEST_CASE_P(AVX2, SADx4Test, ::testing::ValuesIn(x4d_avx2_tests));
 #endif  // HAVE_AVX2