Add Neon implementation of aom_sum_sse_2d_i16

Add an Arm Neon implementation of aom_sum_sse_2d_i16 and use it instead
of the scalar C implementation.

Also add test coverage for the new Neon implementation.

Change-Id: I9f4982223cc267bdbbece0aa10ff5255cd7994b1
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index 8349af7..fc9dab1 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -764,7 +764,7 @@
   }
 
   add_proto qw/uint64_t aom_sum_sse_2d_i16/, "const int16_t *src, int src_stride, int width, int height, int *sum";
-  specialize qw/aom_sum_sse_2d_i16 sse2 avx2/;
+  specialize qw/aom_sum_sse_2d_i16 avx2 neon     sse2/;
   specialize qw/aom_sad128x128    avx2 neon     sse2/;
   specialize qw/aom_sad128x64     avx2 neon     sse2/;
   specialize qw/aom_sad64x128     avx2 neon     sse2/;
diff --git a/aom_dsp/arm/sum_squares_neon.c b/aom_dsp/arm/sum_squares_neon.c
index 524b098..bf212a9 100644
--- a/aom_dsp/arm/sum_squares_neon.c
+++ b/aom_dsp/arm/sum_squares_neon.c
@@ -108,3 +108,121 @@
     return aom_sum_squares_2d_i16_c(src, stride, width, height);
   }
 }
+
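+// 4x4 block: return the sum of squares and accumulate the sum into *sum.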
+static INLINE uint64_t aom_sum_sse_2d_i16_4x4_neon(const int16_t *src,
+                                                   int stride, int *sum) {
+  int16x4_t s0 = vld1_s16(src + 0 * stride);
+  int16x4_t s1 = vld1_s16(src + 1 * stride);
+  int16x4_t s2 = vld1_s16(src + 2 * stride);
+  int16x4_t s3 = vld1_s16(src + 3 * stride);
+
+  int32x4_t sse = vmull_s16(s0, s0);
+  sse = vmlal_s16(sse, s1, s1);
+  sse = vmlal_s16(sse, s2, s2);
+  sse = vmlal_s16(sse, s3, s3);
+
+  int32x4_t sum_01 = vaddl_s16(s0, s1);
+  int32x4_t sum_23 = vaddl_s16(s2, s3);
+  *sum += horizontal_add_s32x4(vaddq_s32(sum_01, sum_23));
+
+  return horizontal_long_add_u32x4(vreinterpretq_u32_s32(sse));
+}
+
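+// As above, for blocks of width 4 and height a multiple of 4.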
+static INLINE uint64_t aom_sum_sse_2d_i16_4xn_neon(const int16_t *src,
+                                                   int stride, int height,
+                                                   int *sum) {
+  int32x4_t sse[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
+  int32x2_t sum_acc[2] = { vdup_n_s32(0), vdup_n_s32(0) };
+
+  int h = 0;
+  do {
+    int16x4_t s0 = vld1_s16(src + 0 * stride);
+    int16x4_t s1 = vld1_s16(src + 1 * stride);
+    int16x4_t s2 = vld1_s16(src + 2 * stride);
+    int16x4_t s3 = vld1_s16(src + 3 * stride);
+
+    sse[0] = vmlal_s16(sse[0], s0, s0);
+    sse[0] = vmlal_s16(sse[0], s1, s1);
+    sse[1] = vmlal_s16(sse[1], s2, s2);
+    sse[1] = vmlal_s16(sse[1], s3, s3);
+
+    sum_acc[0] = vpadal_s16(sum_acc[0], s0);
+    sum_acc[0] = vpadal_s16(sum_acc[0], s1);
+    sum_acc[1] = vpadal_s16(sum_acc[1], s2);
+    sum_acc[1] = vpadal_s16(sum_acc[1], s3);
+
+    src += 4 * stride;
+    h += 4;
+  } while (h < height);
+
+  *sum += horizontal_add_s32x4(vcombine_s32(sum_acc[0], sum_acc[1]));
+  return horizontal_long_add_u32x4(
+      vreinterpretq_u32_s32(vaddq_s32(sse[0], sse[1])));
+}
+
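+// As above, for blocks with width a multiple of 8 and height a multiple of 4.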
+static INLINE uint64_t aom_sum_sse_2d_i16_nxn_neon(const int16_t *src,
+                                                   int stride, int width,
+                                                   int height, int *sum) {
+  uint64x2_t sse = vdupq_n_u64(0);
+  int32x4_t sum_acc = vdupq_n_s32(0);
+
+  int h = 0;
+  do {
+    int32x4_t sse_row[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
+    int w = 0;
+    do {
+      const int16_t *s = src + w;
+      int16x8_t s0 = vld1q_s16(s + 0 * stride);
+      int16x8_t s1 = vld1q_s16(s + 1 * stride);
+      int16x8_t s2 = vld1q_s16(s + 2 * stride);
+      int16x8_t s3 = vld1q_s16(s + 3 * stride);
+
+      sse_row[0] = vmlal_s16(sse_row[0], vget_low_s16(s0), vget_low_s16(s0));
+      sse_row[0] = vmlal_s16(sse_row[0], vget_low_s16(s1), vget_low_s16(s1));
+      sse_row[0] = vmlal_s16(sse_row[0], vget_low_s16(s2), vget_low_s16(s2));
+      sse_row[0] = vmlal_s16(sse_row[0], vget_low_s16(s3), vget_low_s16(s3));
+      sse_row[1] = vmlal_s16(sse_row[1], vget_high_s16(s0), vget_high_s16(s0));
+      sse_row[1] = vmlal_s16(sse_row[1], vget_high_s16(s1), vget_high_s16(s1));
+      sse_row[1] = vmlal_s16(sse_row[1], vget_high_s16(s2), vget_high_s16(s2));
+      sse_row[1] = vmlal_s16(sse_row[1], vget_high_s16(s3), vget_high_s16(s3));
+
+      sum_acc = vpadalq_s16(sum_acc, s0);
+      sum_acc = vpadalq_s16(sum_acc, s1);
+      sum_acc = vpadalq_s16(sum_acc, s2);
+      sum_acc = vpadalq_s16(sum_acc, s3);
+
+      w += 8;
+    } while (w < width);
+
+    sse = vpadalq_u32(sse,
+                      vreinterpretq_u32_s32(vaddq_s32(sse_row[0], sse_row[1])));
+
+    src += 4 * stride;
+    h += 4;
+  } while (h < height);
+
+  *sum += horizontal_add_s32x4(sum_acc);
+  return horizontal_add_u64x2(sse);
+}
+
+uint64_t aom_sum_sse_2d_i16_neon(const int16_t *src, int stride, int width,
+                                 int height, int *sum) {
+  uint64_t sse;
+
+  if (LIKELY(width == 4 && height == 4)) {
+    sse = aom_sum_sse_2d_i16_4x4_neon(src, stride, sum);
+  } else if (LIKELY(width == 4 && (height & 3) == 0)) {
+    // width = 4, height is a multiple of 4.
+    sse = aom_sum_sse_2d_i16_4xn_neon(src, stride, height, sum);
+  } else if (LIKELY((width & 7) == 0 && (height & 3) == 0)) {
+    // Generic case - width is a multiple of 8, height is a multiple of 4.
+    sse = aom_sum_sse_2d_i16_nxn_neon(src, stride, width, height, sum);
+  } else {
+    sse = aom_sum_sse_2d_i16_c(src, stride, width, height, sum);
+  }
+
+  return sse;
+}
diff --git a/test/sse_sum_test.cc b/test/sse_sum_test.cc
index e7c32e6..68355ec 100644
--- a/test/sse_sum_test.cc
+++ b/test/sse_sum_test.cc
@@ -160,6 +160,13 @@
                              &aom_sum_sse_2d_i16_c, &aom_sum_sse_2d_i16_sse2)));
 
 #endif  // HAVE_SSE2
+
+#if HAVE_NEON
+INSTANTIATE_TEST_SUITE_P(NEON, SumSSETest,
+                         ::testing::Values(TestFuncs(
+                             &aom_sum_sse_2d_i16_c, &aom_sum_sse_2d_i16_neon)));
+#endif  // HAVE_NEON
+
 #if HAVE_AVX2
 INSTANTIATE_TEST_SUITE_P(AVX2, SumSSETest,
                          ::testing::Values(TestFuncs(