Optimize implementation of aom_sum_squares_2d_i16_neon

Optimize the implementation of aom_sum_squares_2d_i16_neon by
refactoring the sum-of-squares helper functions. The main change is
to use vector accumulators that do not require a horizontal
(pairwise-add) reduction on every loop iteration.

Also add new reduction helpers to sum_neon.h that work on both AArch32
and AArch64, and use them in sum_squares_neon.c in place of
per-architecture #if defined(__aarch64__) blocks.
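
A minimal sketch of the accumulator pattern (illustrative only;
sum_i16_squares is a made-up example function, and the input is
assumed small enough that the 32-bit lanes cannot overflow):

  #include <arm_neon.h>

  // Sum the squares of len int16 values, len a multiple of 4.
  static uint64_t sum_i16_squares(const int16_t *src, int len) {
    int32x4_t acc = vdupq_n_s32(0);
    for (int i = 0; i < len; i += 4) {
      int16x4_t s = vld1_s16(src + i);
      // Widening multiply-accumulate keeps every lane independent,
      // so no horizontal reduction is needed inside the loop.
      acc = vmlal_s16(acc, s, s);
    }
    // Reduce to a scalar exactly once, using the helper added below.
    return horizontal_long_add_u32x4(vreinterpretq_u32_s32(acc));
  }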

Change-Id: Icd7cab01a57621f567414b3637e5dea10a56e960
diff --git a/aom_dsp/arm/sum_neon.h b/aom_dsp/arm/sum_neon.h
index 4116509..0cf110a 100644
--- a/aom_dsp/arm/sum_neon.h
+++ b/aom_dsp/arm/sum_neon.h
@@ -37,6 +37,23 @@
 #endif
 }
 
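+// Sum the two unsigned 64-bit lanes of a vector to a scalar.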
+static INLINE uint64_t horizontal_add_u64x2(const uint64x2_t a) {
+#if defined(__aarch64__)
+  return vaddvq_u64(a);
+#else
+  return vgetq_lane_u64(a, 0) + vgetq_lane_u64(a, 1);
+#endif
+}
+
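+// Widen the four unsigned 32-bit lanes to 64 bits and sum them to a scalar.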
+static INLINE uint64_t horizontal_long_add_u32x4(const uint32x4_t a) {
+#if defined(__aarch64__)
+  return vaddlvq_u32(a);
+#else
+  const uint64x2_t b = vpaddlq_u32(a);
+  return vgetq_lane_u64(b, 0) + vgetq_lane_u64(b, 1);
+#endif
+}
+
 static INLINE unsigned int horizontal_add_u32x4(const uint32x4_t a) {
 #if defined(__aarch64__)
   return vaddvq_u32(a);
diff --git a/aom_dsp/arm/sum_squares_neon.c b/aom_dsp/arm/sum_squares_neon.c
index 0b7337a..524b098 100644
--- a/aom_dsp/arm/sum_squares_neon.c
+++ b/aom_dsp/arm/sum_squares_neon.c
@@ -13,111 +13,83 @@
 #include <assert.h>
 
 #include "aom_dsp/arm/mem_neon.h"
+#include "aom_dsp/arm/sum_neon.h"
 #include "config/aom_dsp_rtcd.h"
 
-static INLINE uint32x4_t sum_squares_i16_4x4_neon(const int16_t *src,
-                                                  int stride) {
-  const int16x4_t v_val_01_lo = vld1_s16(src + 0 * stride);
-  const int16x4_t v_val_01_hi = vld1_s16(src + 1 * stride);
-  const int16x4_t v_val_23_lo = vld1_s16(src + 2 * stride);
-  const int16x4_t v_val_23_hi = vld1_s16(src + 3 * stride);
-  int32x4_t v_sq_01_d = vmull_s16(v_val_01_lo, v_val_01_lo);
-  v_sq_01_d = vmlal_s16(v_sq_01_d, v_val_01_hi, v_val_01_hi);
-  int32x4_t v_sq_23_d = vmull_s16(v_val_23_lo, v_val_23_lo);
-  v_sq_23_d = vmlal_s16(v_sq_23_d, v_val_23_hi, v_val_23_hi);
-#if defined(__aarch64__)
-  return vreinterpretq_u32_s32(vpaddq_s32(v_sq_01_d, v_sq_23_d));
-#else
-  return vreinterpretq_u32_s32(vcombine_s32(
-      vqmovn_s64(vpaddlq_s32(v_sq_01_d)), vqmovn_s64(vpaddlq_s32(v_sq_23_d))));
-#endif
+static INLINE uint64_t aom_sum_squares_2d_i16_4x4_neon(const int16_t *src,
+                                                       int stride) {
+  int16x4_t s0 = vld1_s16(src + 0 * stride);
+  int16x4_t s1 = vld1_s16(src + 1 * stride);
+  int16x4_t s2 = vld1_s16(src + 2 * stride);
+  int16x4_t s3 = vld1_s16(src + 3 * stride);
+
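+  // Square and accumulate in 32-bit lanes; a single widening reduction at
+  // the end produces the 64-bit total.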
+  int32x4_t sum_squares = vmull_s16(s0, s0);
+  sum_squares = vmlal_s16(sum_squares, s1, s1);
+  sum_squares = vmlal_s16(sum_squares, s2, s2);
+  sum_squares = vmlal_s16(sum_squares, s3, s3);
+
+  return horizontal_long_add_u32x4(vreinterpretq_u32_s32(sum_squares));
 }
 
-uint64_t aom_sum_squares_2d_i16_4x4_neon(const int16_t *src, int stride) {
-  const uint32x4_t v_sum_0123_d = sum_squares_i16_4x4_neon(src, stride);
-#if defined(__aarch64__)
-  return (uint64_t)vaddvq_u32(v_sum_0123_d);
-#else
-  uint64x2_t v_sum_d = vpaddlq_u32(v_sum_0123_d);
-  v_sum_d = vaddq_u64(v_sum_d, vextq_u64(v_sum_d, v_sum_d, 1));
-  return vgetq_lane_u64(v_sum_d, 0);
-#endif
-}
+static INLINE uint64_t aom_sum_squares_2d_i16_4xn_neon(const int16_t *src,
+                                                       int stride, int height) {
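+  // A pair of 32-bit vector accumulators lets adjacent row pairs be squared
+  // and summed independently, with no per-iteration horizontal reduction.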
+  int32x4_t sum_squares[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
 
-uint64_t aom_sum_squares_2d_i16_4xn_neon(const int16_t *src, int stride,
-                                         int height) {
-  int r = 0;
-  uint32x4_t v_acc_q = vdupq_n_u32(0);
+  int h = 0;
   do {
-    const uint32x4_t v_acc_d = sum_squares_i16_4x4_neon(src, stride);
-    v_acc_q = vaddq_u32(v_acc_q, v_acc_d);
-    src += stride << 2;
-    r += 4;
-  } while (r < height);
+    int16x4_t s0 = vld1_s16(src + 0 * stride);
+    int16x4_t s1 = vld1_s16(src + 1 * stride);
+    int16x4_t s2 = vld1_s16(src + 2 * stride);
+    int16x4_t s3 = vld1_s16(src + 3 * stride);
 
-  uint64x2_t v_acc_64 = vpaddlq_u32(v_acc_q);
-#if defined(__aarch64__)
-  return vaddvq_u64(v_acc_64);
-#else
-  v_acc_64 = vaddq_u64(v_acc_64, vextq_u64(v_acc_64, v_acc_64, 1));
-  return vgetq_lane_u64(v_acc_64, 0);
-#endif
-}
-
-uint64_t aom_sum_squares_2d_i16_nxn_neon(const int16_t *src, int stride,
-                                         int width, int height) {
-  int r = 0;
-  const int32x4_t zero = vdupq_n_s32(0);
-  uint64x2_t v_acc_q = vreinterpretq_u64_s32(zero);
-  do {
-    int32x4_t v_sum = zero;
-    int c = 0;
-    do {
-      const int16_t *b = src + c;
-      const int16x8_t v_val_0 = vld1q_s16(b + 0 * stride);
-      const int16x8_t v_val_1 = vld1q_s16(b + 1 * stride);
-      const int16x8_t v_val_2 = vld1q_s16(b + 2 * stride);
-      const int16x8_t v_val_3 = vld1q_s16(b + 3 * stride);
-      const int16x4_t v_val_0_lo = vget_low_s16(v_val_0);
-      const int16x4_t v_val_1_lo = vget_low_s16(v_val_1);
-      const int16x4_t v_val_2_lo = vget_low_s16(v_val_2);
-      const int16x4_t v_val_3_lo = vget_low_s16(v_val_3);
-      int32x4_t v_sum_01 = vmull_s16(v_val_0_lo, v_val_0_lo);
-      v_sum_01 = vmlal_s16(v_sum_01, v_val_1_lo, v_val_1_lo);
-      int32x4_t v_sum_23 = vmull_s16(v_val_2_lo, v_val_2_lo);
-      v_sum_23 = vmlal_s16(v_sum_23, v_val_3_lo, v_val_3_lo);
-#if defined(__aarch64__)
-      v_sum_01 = vmlal_high_s16(v_sum_01, v_val_0, v_val_0);
-      v_sum_01 = vmlal_high_s16(v_sum_01, v_val_1, v_val_1);
-      v_sum_23 = vmlal_high_s16(v_sum_23, v_val_2, v_val_2);
-      v_sum_23 = vmlal_high_s16(v_sum_23, v_val_3, v_val_3);
-      v_sum = vaddq_s32(v_sum, vpaddq_s32(v_sum_01, v_sum_23));
-#else
-      const int16x4_t v_val_0_hi = vget_high_s16(v_val_0);
-      const int16x4_t v_val_1_hi = vget_high_s16(v_val_1);
-      const int16x4_t v_val_2_hi = vget_high_s16(v_val_2);
-      const int16x4_t v_val_3_hi = vget_high_s16(v_val_3);
-      v_sum_01 = vmlal_s16(v_sum_01, v_val_0_hi, v_val_0_hi);
-      v_sum_01 = vmlal_s16(v_sum_01, v_val_1_hi, v_val_1_hi);
-      v_sum_23 = vmlal_s16(v_sum_23, v_val_2_hi, v_val_2_hi);
-      v_sum_23 = vmlal_s16(v_sum_23, v_val_3_hi, v_val_3_hi);
-      v_sum = vaddq_s32(v_sum, vcombine_s32(vqmovn_s64(vpaddlq_s32(v_sum_01)),
-                                            vqmovn_s64(vpaddlq_s32(v_sum_23))));
-#endif
-      c += 8;
-    } while (c < width);
-
-    v_acc_q = vpadalq_u32(v_acc_q, vreinterpretq_u32_s32(v_sum));
+    sum_squares[0] = vmlal_s16(sum_squares[0], s0, s0);
+    sum_squares[0] = vmlal_s16(sum_squares[0], s1, s1);
+    sum_squares[1] = vmlal_s16(sum_squares[1], s2, s2);
+    sum_squares[1] = vmlal_s16(sum_squares[1], s3, s3);
 
     src += 4 * stride;
-    r += 4;
-  } while (r < height);
-#if defined(__aarch64__)
-  return vaddvq_u64(v_acc_q);
-#else
-  v_acc_q = vaddq_u64(v_acc_q, vextq_u64(v_acc_q, v_acc_q, 1));
-  return vgetq_lane_u64(v_acc_q, 0);
-#endif
+    h += 4;
+  } while (h < height);
+
+  return horizontal_long_add_u32x4(
+      vreinterpretq_u32_s32(vaddq_s32(sum_squares[0], sum_squares[1])));
+}
+
+static INLINE uint64_t aom_sum_squares_2d_i16_nxn_neon(const int16_t *src,
+                                                       int stride, int width,
+                                                       int height) {
+  uint64x2_t sum_squares = vdupq_n_u64(0);
+
+  int h = 0;
+  do {
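+    // 32-bit accumulators for the squares of the current four-row strip.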
+    int32x4_t ss_row[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
+    int w = 0;
+    do {
+      const int16_t *s = src + w;
+      int16x8_t s0 = vld1q_s16(s + 0 * stride);
+      int16x8_t s1 = vld1q_s16(s + 1 * stride);
+      int16x8_t s2 = vld1q_s16(s + 2 * stride);
+      int16x8_t s3 = vld1q_s16(s + 3 * stride);
+
+      ss_row[0] = vmlal_s16(ss_row[0], vget_low_s16(s0), vget_low_s16(s0));
+      ss_row[0] = vmlal_s16(ss_row[0], vget_low_s16(s1), vget_low_s16(s1));
+      ss_row[0] = vmlal_s16(ss_row[0], vget_low_s16(s2), vget_low_s16(s2));
+      ss_row[0] = vmlal_s16(ss_row[0], vget_low_s16(s3), vget_low_s16(s3));
+      ss_row[1] = vmlal_s16(ss_row[1], vget_high_s16(s0), vget_high_s16(s0));
+      ss_row[1] = vmlal_s16(ss_row[1], vget_high_s16(s1), vget_high_s16(s1));
+      ss_row[1] = vmlal_s16(ss_row[1], vget_high_s16(s2), vget_high_s16(s2));
+      ss_row[1] = vmlal_s16(ss_row[1], vget_high_s16(s3), vget_high_s16(s3));
+      w += 8;
+    } while (w < width);
+
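+    // Pairwise-widen and fold the strip sums into the 64-bit accumulator,
+    // once per four-row strip rather than on every inner-loop iteration.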
+    sum_squares = vpadalq_u32(
+        sum_squares, vreinterpretq_u32_s32(vaddq_s32(ss_row[0], ss_row[1])));
+
+    src += 4 * stride;
+    h += 4;
+  } while (h < height);
+
+  return horizontal_add_u64x2(sum_squares);
 }
 
 uint64_t aom_sum_squares_2d_i16_neon(const int16_t *src, int stride, int width,