Refactor Neon implementations of SAD functions

Use Neon vector accumulators in the sad<w>xh_neon() helper functions and
perform a single reduction at the end of each function, rather than
reducing on every loop iteration.

Implement the aom_sad<w>x<h>_neon() functions using the sad<w>xh_neon()
helpers to reduce code duplication.
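
A minimal sketch of the accumulate-then-reduce pattern, mirroring the
sad8xh_neon() helper in this patch (illustrative only; the wider helpers
use wider or multiple accumulators to avoid overflow):

    uint16x8_t sum = vdupq_n_u16(0);
    for (int i = 0; i < h; i++) {
      uint8x8_t s = vld1_u8(src_ptr);
      uint8x8_t r = vld1_u8(ref_ptr);
      // Widening absolute-difference accumulate: sum += |s - r| per lane.
      sum = vabal_u8(sum, s, r);
      src_ptr += src_stride;
      ref_ptr += ref_stride;
    }
    return horizontal_add_u16x8(sum);  // Single reduction per block.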

Change-Id: If6c073aaeeba211e2dfb0a5602be59c586688841
diff --git a/aom_dsp/arm/sad_neon.c b/aom_dsp/arm/sad_neon.c
index acd2c54..df66275 100644
--- a/aom_dsp/arm/sad_neon.c
+++ b/aom_dsp/arm/sad_neon.c
@@ -15,447 +15,261 @@
 #include "aom/aom_integer.h"
 #include "aom_dsp/arm/sum_neon.h"
 
-unsigned int aom_sad8x16_neon(const uint8_t *src_ptr, int src_stride,
-                              const uint8_t *ref_ptr, int ref_stride) {
-  uint8x8_t d0, d8;
-  uint16x8_t q12;
-  uint32x4_t q1;
-  uint64x2_t q3;
-  uint32x2_t d5;
-  int i;
-
-  d0 = vld1_u8(src_ptr);
-  src_ptr += src_stride;
-  d8 = vld1_u8(ref_ptr);
-  ref_ptr += ref_stride;
-  q12 = vabdl_u8(d0, d8);
-
-  for (i = 0; i < 15; i++) {
-    d0 = vld1_u8(src_ptr);
-    src_ptr += src_stride;
-    d8 = vld1_u8(ref_ptr);
-    ref_ptr += ref_stride;
-    q12 = vabal_u8(q12, d0, d8);
-  }
-
-  q1 = vpaddlq_u16(q12);
-  q3 = vpaddlq_u32(q1);
-  d5 = vadd_u32(vreinterpret_u32_u64(vget_low_u64(q3)),
-                vreinterpret_u32_u64(vget_high_u64(q3)));
-
-  return vget_lane_u32(d5, 0);
-}
-
-unsigned int aom_sad4x4_neon(const uint8_t *src_ptr, int src_stride,
-                             const uint8_t *ref_ptr, int ref_stride) {
-  uint8x8_t d0, d8;
-  uint16x8_t q12;
-  uint32x2_t d1;
-  uint64x1_t d3;
-  int i;
-
-  d0 = vld1_u8(src_ptr);
-  src_ptr += src_stride;
-  d8 = vld1_u8(ref_ptr);
-  ref_ptr += ref_stride;
-  q12 = vabdl_u8(d0, d8);
-
-  for (i = 0; i < 3; i++) {
-    d0 = vld1_u8(src_ptr);
-    src_ptr += src_stride;
-    d8 = vld1_u8(ref_ptr);
-    ref_ptr += ref_stride;
-    q12 = vabal_u8(q12, d0, d8);
-  }
-
-  d1 = vpaddl_u16(vget_low_u16(q12));
-  d3 = vpaddl_u32(d1);
-
-  return vget_lane_u32(vreinterpret_u32_u64(d3), 0);
-}
-
-unsigned int aom_sad16x8_neon(const uint8_t *src_ptr, int src_stride,
-                              const uint8_t *ref_ptr, int ref_stride) {
-  uint8x16_t q0, q4;
-  uint16x8_t q12, q13;
-  uint32x4_t q1;
-  uint64x2_t q3;
-  uint32x2_t d5;
-  int i;
-
-  q0 = vld1q_u8(src_ptr);
-  src_ptr += src_stride;
-  q4 = vld1q_u8(ref_ptr);
-  ref_ptr += ref_stride;
-  q12 = vabdl_u8(vget_low_u8(q0), vget_low_u8(q4));
-  q13 = vabdl_u8(vget_high_u8(q0), vget_high_u8(q4));
-
-  for (i = 0; i < 7; i++) {
-    q0 = vld1q_u8(src_ptr);
-    src_ptr += src_stride;
-    q4 = vld1q_u8(ref_ptr);
-    ref_ptr += ref_stride;
-    q12 = vabal_u8(q12, vget_low_u8(q0), vget_low_u8(q4));
-    q13 = vabal_u8(q13, vget_high_u8(q0), vget_high_u8(q4));
-  }
-
-  q12 = vaddq_u16(q12, q13);
-  q1 = vpaddlq_u16(q12);
-  q3 = vpaddlq_u32(q1);
-  d5 = vadd_u32(vreinterpret_u32_u64(vget_low_u64(q3)),
-                vreinterpret_u32_u64(vget_high_u64(q3)));
-
-  return vget_lane_u32(d5, 0);
-}
-
-unsigned int aom_sad64x64_neon(const uint8_t *src, int src_stride,
-                               const uint8_t *ref, int ref_stride) {
-  int i;
-  uint16x8_t vec_accum_lo = vdupq_n_u16(0);
-  uint16x8_t vec_accum_hi = vdupq_n_u16(0);
-  for (i = 0; i < 64; ++i) {
-    const uint8x16_t vec_src_00 = vld1q_u8(src);
-    const uint8x16_t vec_src_16 = vld1q_u8(src + 16);
-    const uint8x16_t vec_src_32 = vld1q_u8(src + 32);
-    const uint8x16_t vec_src_48 = vld1q_u8(src + 48);
-    const uint8x16_t vec_ref_00 = vld1q_u8(ref);
-    const uint8x16_t vec_ref_16 = vld1q_u8(ref + 16);
-    const uint8x16_t vec_ref_32 = vld1q_u8(ref + 32);
-    const uint8x16_t vec_ref_48 = vld1q_u8(ref + 48);
-    src += src_stride;
-    ref += ref_stride;
-    vec_accum_lo = vabal_u8(vec_accum_lo, vget_low_u8(vec_src_00),
-                            vget_low_u8(vec_ref_00));
-    vec_accum_hi = vabal_u8(vec_accum_hi, vget_high_u8(vec_src_00),
-                            vget_high_u8(vec_ref_00));
-    vec_accum_lo = vabal_u8(vec_accum_lo, vget_low_u8(vec_src_16),
-                            vget_low_u8(vec_ref_16));
-    vec_accum_hi = vabal_u8(vec_accum_hi, vget_high_u8(vec_src_16),
-                            vget_high_u8(vec_ref_16));
-    vec_accum_lo = vabal_u8(vec_accum_lo, vget_low_u8(vec_src_32),
-                            vget_low_u8(vec_ref_32));
-    vec_accum_hi = vabal_u8(vec_accum_hi, vget_high_u8(vec_src_32),
-                            vget_high_u8(vec_ref_32));
-    vec_accum_lo = vabal_u8(vec_accum_lo, vget_low_u8(vec_src_48),
-                            vget_low_u8(vec_ref_48));
-    vec_accum_hi = vabal_u8(vec_accum_hi, vget_high_u8(vec_src_48),
-                            vget_high_u8(vec_ref_48));
-  }
-  return horizontal_long_add_u16x8(vec_accum_lo, vec_accum_hi);
-}
-
-unsigned int aom_sad128x128_neon(const uint8_t *src, int src_stride,
-                                 const uint8_t *ref, int ref_stride) {
-  uint16x8_t vec_accum_lo, vec_accum_hi;
-  uint32x4_t vec_accum_32lo = vdupq_n_u32(0);
-  uint32x4_t vec_accum_32hi = vdupq_n_u32(0);
-  uint16x8_t tmp;
-  for (int i = 0; i < 128; ++i) {
-    const uint8x16_t vec_src_00 = vld1q_u8(src);
-    const uint8x16_t vec_src_16 = vld1q_u8(src + 16);
-    const uint8x16_t vec_src_32 = vld1q_u8(src + 32);
-    const uint8x16_t vec_src_48 = vld1q_u8(src + 48);
-    const uint8x16_t vec_src_64 = vld1q_u8(src + 64);
-    const uint8x16_t vec_src_80 = vld1q_u8(src + 80);
-    const uint8x16_t vec_src_96 = vld1q_u8(src + 96);
-    const uint8x16_t vec_src_112 = vld1q_u8(src + 112);
-    const uint8x16_t vec_ref_00 = vld1q_u8(ref);
-    const uint8x16_t vec_ref_16 = vld1q_u8(ref + 16);
-    const uint8x16_t vec_ref_32 = vld1q_u8(ref + 32);
-    const uint8x16_t vec_ref_48 = vld1q_u8(ref + 48);
-    const uint8x16_t vec_ref_64 = vld1q_u8(ref + 64);
-    const uint8x16_t vec_ref_80 = vld1q_u8(ref + 80);
-    const uint8x16_t vec_ref_96 = vld1q_u8(ref + 96);
-    const uint8x16_t vec_ref_112 = vld1q_u8(ref + 112);
-    src += src_stride;
-    ref += ref_stride;
-    vec_accum_lo = vdupq_n_u16(0);
-    vec_accum_hi = vdupq_n_u16(0);
-    vec_accum_lo = vabal_u8(vec_accum_lo, vget_low_u8(vec_src_00),
-                            vget_low_u8(vec_ref_00));
-    vec_accum_hi = vabal_u8(vec_accum_hi, vget_high_u8(vec_src_00),
-                            vget_high_u8(vec_ref_00));
-    vec_accum_lo = vabal_u8(vec_accum_lo, vget_low_u8(vec_src_16),
-                            vget_low_u8(vec_ref_16));
-    vec_accum_hi = vabal_u8(vec_accum_hi, vget_high_u8(vec_src_16),
-                            vget_high_u8(vec_ref_16));
-    vec_accum_lo = vabal_u8(vec_accum_lo, vget_low_u8(vec_src_32),
-                            vget_low_u8(vec_ref_32));
-    vec_accum_hi = vabal_u8(vec_accum_hi, vget_high_u8(vec_src_32),
-                            vget_high_u8(vec_ref_32));
-    vec_accum_lo = vabal_u8(vec_accum_lo, vget_low_u8(vec_src_48),
-                            vget_low_u8(vec_ref_48));
-    vec_accum_hi = vabal_u8(vec_accum_hi, vget_high_u8(vec_src_48),
-                            vget_high_u8(vec_ref_48));
-    vec_accum_lo = vabal_u8(vec_accum_lo, vget_low_u8(vec_src_64),
-                            vget_low_u8(vec_ref_64));
-    vec_accum_hi = vabal_u8(vec_accum_hi, vget_high_u8(vec_src_64),
-                            vget_high_u8(vec_ref_64));
-    vec_accum_lo = vabal_u8(vec_accum_lo, vget_low_u8(vec_src_80),
-                            vget_low_u8(vec_ref_80));
-    vec_accum_hi = vabal_u8(vec_accum_hi, vget_high_u8(vec_src_80),
-                            vget_high_u8(vec_ref_80));
-    vec_accum_lo = vabal_u8(vec_accum_lo, vget_low_u8(vec_src_96),
-                            vget_low_u8(vec_ref_96));
-    vec_accum_hi = vabal_u8(vec_accum_hi, vget_high_u8(vec_src_96),
-                            vget_high_u8(vec_ref_96));
-    vec_accum_lo = vabal_u8(vec_accum_lo, vget_low_u8(vec_src_112),
-                            vget_low_u8(vec_ref_112));
-    vec_accum_hi = vabal_u8(vec_accum_hi, vget_high_u8(vec_src_112),
-                            vget_high_u8(vec_ref_112));
-
-    tmp = vaddq_u16(vec_accum_lo, vec_accum_hi);
-    vec_accum_32lo = vaddw_u16(vec_accum_32lo, vget_low_u16(tmp));
-    vec_accum_32hi = vaddw_u16(vec_accum_32hi, vget_high_u16(tmp));
-  }
-  const uint32x4_t a = vaddq_u32(vec_accum_32lo, vec_accum_32hi);
-  const uint64x2_t b = vpaddlq_u32(a);
-  const uint32x2_t c = vadd_u32(vreinterpret_u32_u64(vget_low_u64(b)),
-                                vreinterpret_u32_u64(vget_high_u64(b)));
-  return vget_lane_u32(c, 0);
-}
-
-unsigned int aom_sad32x32_neon(const uint8_t *src, int src_stride,
-                               const uint8_t *ref, int ref_stride) {
-  int i;
-  uint16x8_t vec_accum_lo = vdupq_n_u16(0);
-  uint16x8_t vec_accum_hi = vdupq_n_u16(0);
-
-  for (i = 0; i < 32; ++i) {
-    const uint8x16_t vec_src_00 = vld1q_u8(src);
-    const uint8x16_t vec_src_16 = vld1q_u8(src + 16);
-    const uint8x16_t vec_ref_00 = vld1q_u8(ref);
-    const uint8x16_t vec_ref_16 = vld1q_u8(ref + 16);
-    src += src_stride;
-    ref += ref_stride;
-    vec_accum_lo = vabal_u8(vec_accum_lo, vget_low_u8(vec_src_00),
-                            vget_low_u8(vec_ref_00));
-    vec_accum_hi = vabal_u8(vec_accum_hi, vget_high_u8(vec_src_00),
-                            vget_high_u8(vec_ref_00));
-    vec_accum_lo = vabal_u8(vec_accum_lo, vget_low_u8(vec_src_16),
-                            vget_low_u8(vec_ref_16));
-    vec_accum_hi = vabal_u8(vec_accum_hi, vget_high_u8(vec_src_16),
-                            vget_high_u8(vec_ref_16));
-  }
-  return horizontal_add_u16x8(vaddq_u16(vec_accum_lo, vec_accum_hi));
-}
-
-unsigned int aom_sad16x16_neon(const uint8_t *src, int src_stride,
-                               const uint8_t *ref, int ref_stride) {
-  int i;
-  uint16x8_t vec_accum_lo = vdupq_n_u16(0);
-  uint16x8_t vec_accum_hi = vdupq_n_u16(0);
-
-  for (i = 0; i < 16; ++i) {
-    const uint8x16_t vec_src = vld1q_u8(src);
-    const uint8x16_t vec_ref = vld1q_u8(ref);
-    src += src_stride;
-    ref += ref_stride;
-    vec_accum_lo =
-        vabal_u8(vec_accum_lo, vget_low_u8(vec_src), vget_low_u8(vec_ref));
-    vec_accum_hi =
-        vabal_u8(vec_accum_hi, vget_high_u8(vec_src), vget_high_u8(vec_ref));
-  }
-  return horizontal_add_u16x8(vaddq_u16(vec_accum_lo, vec_accum_hi));
-}
-
-unsigned int aom_sad8x8_neon(const uint8_t *src, int src_stride,
-                             const uint8_t *ref, int ref_stride) {
-  int i;
-  uint16x8_t vec_accum = vdupq_n_u16(0);
-
-  for (i = 0; i < 8; ++i) {
-    const uint8x8_t vec_src = vld1_u8(src);
-    const uint8x8_t vec_ref = vld1_u8(ref);
-    src += src_stride;
-    ref += ref_stride;
-    vec_accum = vabal_u8(vec_accum, vec_src, vec_ref);
-  }
-  return horizontal_add_u16x8(vec_accum);
-}
-
 static INLINE unsigned int sad128xh_neon(const uint8_t *src_ptr, int src_stride,
                                          const uint8_t *ref_ptr, int ref_stride,
                                          int h) {
-  int sum = 0;
-  for (int i = 0; i < h; i++) {
-    uint16x8_t q3 = vdupq_n_u16(0);
+  // We use 8 accumulators to prevent overflow for large values of 'h', as
+  // well as to enable optimal UADALP instruction throughput on CPUs that have
+  // either 2 or 4 Neon pipes.
+  uint16x8_t sum[8] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+                        vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+                        vdupq_n_u16(0), vdupq_n_u16(0) };
 
-    uint8x16_t q0 = vld1q_u8(src_ptr);
-    uint8x16_t q1 = vld1q_u8(ref_ptr);
-    uint8x16_t q2 = vabdq_u8(q0, q1);
-    q3 = vpadalq_u8(q3, q2);
+  int i = 0;
+  do {
+    uint8x16_t s0, s1, s2, s3, s4, s5, s6, s7;
+    uint8x16_t r0, r1, r2, r3, r4, r5, r6, r7;
+    uint8x16_t diff0, diff1, diff2, diff3, diff4, diff5, diff6, diff7;
 
-    q0 = vld1q_u8(src_ptr + 16);
-    q1 = vld1q_u8(ref_ptr + 16);
-    q2 = vabdq_u8(q0, q1);
-    q3 = vpadalq_u8(q3, q2);
+    s0 = vld1q_u8(src_ptr);
+    r0 = vld1q_u8(ref_ptr);
+    diff0 = vabdq_u8(s0, r0);
+    sum[0] = vpadalq_u8(sum[0], diff0);
 
-    q0 = vld1q_u8(src_ptr + 32);
-    q1 = vld1q_u8(ref_ptr + 32);
-    q2 = vabdq_u8(q0, q1);
-    q3 = vpadalq_u8(q3, q2);
+    s1 = vld1q_u8(src_ptr + 16);
+    r1 = vld1q_u8(ref_ptr + 16);
+    diff1 = vabdq_u8(s1, r1);
+    sum[1] = vpadalq_u8(sum[1], diff1);
 
-    q0 = vld1q_u8(src_ptr + 48);
-    q1 = vld1q_u8(ref_ptr + 48);
-    q2 = vabdq_u8(q0, q1);
-    q3 = vpadalq_u8(q3, q2);
+    s2 = vld1q_u8(src_ptr + 32);
+    r2 = vld1q_u8(ref_ptr + 32);
+    diff2 = vabdq_u8(s2, r2);
+    sum[2] = vpadalq_u8(sum[2], diff2);
 
-    q0 = vld1q_u8(src_ptr + 64);
-    q1 = vld1q_u8(ref_ptr + 64);
-    q2 = vabdq_u8(q0, q1);
-    q3 = vpadalq_u8(q3, q2);
+    s3 = vld1q_u8(src_ptr + 48);
+    r3 = vld1q_u8(ref_ptr + 48);
+    diff3 = vabdq_u8(s3, r3);
+    sum[3] = vpadalq_u8(sum[3], diff3);
 
-    q0 = vld1q_u8(src_ptr + 80);
-    q1 = vld1q_u8(ref_ptr + 80);
-    q2 = vabdq_u8(q0, q1);
-    q3 = vpadalq_u8(q3, q2);
+    s4 = vld1q_u8(src_ptr + 64);
+    r4 = vld1q_u8(ref_ptr + 64);
+    diff4 = vabdq_u8(s4, r4);
+    sum[4] = vpadalq_u8(sum[4], diff4);
 
-    q0 = vld1q_u8(src_ptr + 96);
-    q1 = vld1q_u8(ref_ptr + 96);
-    q2 = vabdq_u8(q0, q1);
-    q3 = vpadalq_u8(q3, q2);
+    s5 = vld1q_u8(src_ptr + 80);
+    r5 = vld1q_u8(ref_ptr + 80);
+    diff5 = vabdq_u8(s5, r5);
+    sum[5] = vpadalq_u8(sum[5], diff5);
 
-    q0 = vld1q_u8(src_ptr + 112);
-    q1 = vld1q_u8(ref_ptr + 112);
-    q2 = vabdq_u8(q0, q1);
-    q3 = vpadalq_u8(q3, q2);
+    s6 = vld1q_u8(src_ptr + 96);
+    r6 = vld1q_u8(ref_ptr + 96);
+    diff6 = vabdq_u8(s6, r6);
+    sum[6] = vpadalq_u8(sum[6], diff6);
+
+    s7 = vld1q_u8(src_ptr + 112);
+    r7 = vld1q_u8(ref_ptr + 112);
+    diff7 = vabdq_u8(s7, r7);
+    sum[7] = vpadalq_u8(sum[7], diff7);
 
     src_ptr += src_stride;
     ref_ptr += ref_stride;
+    i++;
+  } while (i < h);
 
-    sum += horizontal_add_u16x8(q3);
-  }
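+  // Accumulate the eight 16-bit partial sums into a single 32-bit vector and
+  // reduce to a scalar only once, after the loop.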
+  uint32x4_t sum_u32 = vpaddlq_u16(sum[0]);
+  sum_u32 = vpadalq_u16(sum_u32, sum[1]);
+  sum_u32 = vpadalq_u16(sum_u32, sum[2]);
+  sum_u32 = vpadalq_u16(sum_u32, sum[3]);
+  sum_u32 = vpadalq_u16(sum_u32, sum[4]);
+  sum_u32 = vpadalq_u16(sum_u32, sum[5]);
+  sum_u32 = vpadalq_u16(sum_u32, sum[6]);
+  sum_u32 = vpadalq_u16(sum_u32, sum[7]);
 
-  return sum;
+  return horizontal_add_u32x4(sum_u32);
 }
 
 static INLINE unsigned int sad64xh_neon(const uint8_t *src_ptr, int src_stride,
                                         const uint8_t *ref_ptr, int ref_stride,
                                         int h) {
-  int sum = 0;
-  for (int i = 0; i < h; i++) {
-    uint16x8_t q3 = vdupq_n_u16(0);
+  uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+                        vdupq_n_u16(0) };
 
-    uint8x16_t q0 = vld1q_u8(src_ptr);
-    uint8x16_t q1 = vld1q_u8(ref_ptr);
-    uint8x16_t q2 = vabdq_u8(q0, q1);
-    q3 = vpadalq_u8(q3, q2);
+  int i = 0;
+  do {
+    uint8x16_t s0, s1, s2, s3, r0, r1, r2, r3;
+    uint8x16_t diff0, diff1, diff2, diff3;
 
-    q0 = vld1q_u8(src_ptr + 16);
-    q1 = vld1q_u8(ref_ptr + 16);
-    q2 = vabdq_u8(q0, q1);
-    q3 = vpadalq_u8(q3, q2);
+    s0 = vld1q_u8(src_ptr);
+    r0 = vld1q_u8(ref_ptr);
+    diff0 = vabdq_u8(s0, r0);
+    sum[0] = vpadalq_u8(sum[0], diff0);
 
-    q0 = vld1q_u8(src_ptr + 32);
-    q1 = vld1q_u8(ref_ptr + 32);
-    q2 = vabdq_u8(q0, q1);
-    q3 = vpadalq_u8(q3, q2);
+    s1 = vld1q_u8(src_ptr + 16);
+    r1 = vld1q_u8(ref_ptr + 16);
+    diff1 = vabdq_u8(s1, r1);
+    sum[1] = vpadalq_u8(sum[1], diff1);
 
-    q0 = vld1q_u8(src_ptr + 48);
-    q1 = vld1q_u8(ref_ptr + 48);
-    q2 = vabdq_u8(q0, q1);
-    q3 = vpadalq_u8(q3, q2);
+    s2 = vld1q_u8(src_ptr + 32);
+    r2 = vld1q_u8(ref_ptr + 32);
+    diff2 = vabdq_u8(s2, r2);
+    sum[2] = vpadalq_u8(sum[2], diff2);
+
+    s3 = vld1q_u8(src_ptr + 48);
+    r3 = vld1q_u8(ref_ptr + 48);
+    diff3 = vabdq_u8(s3, r3);
+    sum[3] = vpadalq_u8(sum[3], diff3);
 
     src_ptr += src_stride;
     ref_ptr += ref_stride;
+    i++;
+  } while (i < h);
 
-    sum += horizontal_add_u16x8(q3);
-  }
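+  // Accumulate the four 16-bit partial sums into a single 32-bit vector and
+  // reduce to a scalar only once, after the loop.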
+  uint32x4_t sum_u32 = vpaddlq_u16(sum[0]);
+  sum_u32 = vpadalq_u16(sum_u32, sum[1]);
+  sum_u32 = vpadalq_u16(sum_u32, sum[2]);
+  sum_u32 = vpadalq_u16(sum_u32, sum[3]);
 
-  return sum;
+  return horizontal_add_u32x4(sum_u32);
 }
 
 static INLINE unsigned int sad32xh_neon(const uint8_t *src_ptr, int src_stride,
                                         const uint8_t *ref_ptr, int ref_stride,
                                         int h) {
-  int sum = 0;
-  for (int i = 0; i < h; i++) {
-    uint16x8_t q3 = vdupq_n_u16(0);
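+  // Widen to a 32-bit accumulator on every row so the sum cannot overflow,
+  // regardless of 'h'.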
+  uint32x4_t sum = vdupq_n_u32(0);
 
-    uint8x16_t q0 = vld1q_u8(src_ptr);
-    uint8x16_t q1 = vld1q_u8(ref_ptr);
-    uint8x16_t q2 = vabdq_u8(q0, q1);
-    q3 = vpadalq_u8(q3, q2);
+  int i = 0;
+  do {
+    uint8x16_t s0 = vld1q_u8(src_ptr);
+    uint8x16_t r0 = vld1q_u8(ref_ptr);
+    uint8x16_t diff0 = vabdq_u8(s0, r0);
+    uint16x8_t sum0 = vpaddlq_u8(diff0);
 
-    q0 = vld1q_u8(src_ptr + 16);
-    q1 = vld1q_u8(ref_ptr + 16);
-    q2 = vabdq_u8(q0, q1);
-    q3 = vpadalq_u8(q3, q2);
+    uint8x16_t s1 = vld1q_u8(src_ptr + 16);
+    uint8x16_t r1 = vld1q_u8(ref_ptr + 16);
+    uint8x16_t diff1 = vabdq_u8(s1, r1);
+    uint16x8_t sum1 = vpaddlq_u8(diff1);
 
-    sum += horizontal_add_u16x8(q3);
+    sum = vpadalq_u16(sum, sum0);
+    sum = vpadalq_u16(sum, sum1);
 
     src_ptr += src_stride;
     ref_ptr += ref_stride;
-  }
+    i++;
+  } while (i < h);
 
-  return sum;
+  return horizontal_add_u32x4(sum);
 }
 
 static INLINE unsigned int sad16xh_neon(const uint8_t *src_ptr, int src_stride,
                                         const uint8_t *ref_ptr, int ref_stride,
                                         int h) {
-  int sum = 0;
-  for (int i = 0; i < h; i++) {
-    uint8x8_t q0 = vld1_u8(src_ptr);
-    uint8x8_t q1 = vld1_u8(ref_ptr);
-    sum += vget_lane_u16(vpaddl_u8(vabd_u8(q0, q1)), 0);
-    sum += vget_lane_u16(vpaddl_u8(vabd_u8(q0, q1)), 1);
-    sum += vget_lane_u16(vpaddl_u8(vabd_u8(q0, q1)), 2);
-    sum += vget_lane_u16(vpaddl_u8(vabd_u8(q0, q1)), 3);
-    q0 = vld1_u8(src_ptr + 8);
-    q1 = vld1_u8(ref_ptr + 8);
-    sum += vget_lane_u16(vpaddl_u8(vabd_u8(q0, q1)), 0);
-    sum += vget_lane_u16(vpaddl_u8(vabd_u8(q0, q1)), 1);
-    sum += vget_lane_u16(vpaddl_u8(vabd_u8(q0, q1)), 2);
-    sum += vget_lane_u16(vpaddl_u8(vabd_u8(q0, q1)), 3);
+  uint16x8_t sum = vdupq_n_u16(0);
+
+  int i = 0;
+  do {
+    uint8x16_t s = vld1q_u8(src_ptr);
+    uint8x16_t r = vld1q_u8(ref_ptr);
+
+    uint8x16_t diff = vabdq_u8(s, r);
+    sum = vpadalq_u8(sum, diff);
 
     src_ptr += src_stride;
     ref_ptr += ref_stride;
-  }
+    i++;
+  } while (i < h);
 
-  return sum;
+  return horizontal_add_u16x8(sum);
 }
 
 static INLINE unsigned int sad8xh_neon(const uint8_t *src_ptr, int src_stride,
                                        const uint8_t *ref_ptr, int ref_stride,
                                        int h) {
-  uint16x8_t q3 = vdupq_n_u16(0);
-  for (int y = 0; y < h; y++) {
-    uint8x8_t q0 = vld1_u8(src_ptr);
-    uint8x8_t q1 = vld1_u8(ref_ptr);
+  uint16x8_t sum = vdupq_n_u16(0);
+
+  int i = 0;
+  do {
+    uint8x8_t s = vld1_u8(src_ptr);
+    uint8x8_t r = vld1_u8(ref_ptr);
+
+    sum = vabal_u8(sum, s, r);
+
     src_ptr += src_stride;
     ref_ptr += ref_stride;
-    q3 = vabal_u8(q3, q0, q1);
-  }
-  return horizontal_add_u16x8(q3);
+    i++;
+  } while (i < h);
+
+  return horizontal_add_u16x8(sum);
 }
 
 static INLINE unsigned int sad4xh_neon(const uint8_t *src_ptr, int src_stride,
                                        const uint8_t *ref_ptr, int ref_stride,
                                        int h) {
-  uint16x8_t q3 = vdupq_n_u16(0);
-  uint32x2_t q0 = vdup_n_u32(0);
-  uint32x2_t q1 = vdup_n_u32(0);
-  uint32_t src4, ref4;
-  for (int y = 0; y < h / 2; y++) {
-    memcpy(&src4, src_ptr, 4);
-    memcpy(&ref4, ref_ptr, 4);
+  uint16x8_t sum = vdupq_n_u16(0);
+
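+  // Process two rows per loop iteration, packing each pair of 4-byte rows
+  // into a single 8-byte vector.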
+  int i = 0;
+  do {
+    uint32x2_t s, r;
+    uint32_t s0, s1, r0, r1;
+
+    memcpy(&s0, src_ptr, 4);
+    memcpy(&r0, ref_ptr, 4);
+    s = vdup_n_u32(s0);
+    r = vdup_n_u32(r0);
     src_ptr += src_stride;
     ref_ptr += ref_stride;
-    q0 = vset_lane_u32(src4, q0, 0);
-    q1 = vset_lane_u32(ref4, q1, 0);
 
-    memcpy(&src4, src_ptr, 4);
-    memcpy(&ref4, ref_ptr, 4);
+    memcpy(&s1, src_ptr, 4);
+    memcpy(&r1, ref_ptr, 4);
+    s = vset_lane_u32(s1, s, 1);
+    r = vset_lane_u32(r1, r, 1);
     src_ptr += src_stride;
     ref_ptr += ref_stride;
-    q0 = vset_lane_u32(src4, q0, 1);
-    q1 = vset_lane_u32(ref4, q1, 1);
 
-    q3 = vabal_u8(q3, vreinterpret_u8_u32(q0), vreinterpret_u8_u32(q1));
-  }
-  return horizontal_add_u16x8(q3);
+    sum = vabal_u8(sum, vreinterpret_u8_u32(s), vreinterpret_u8_u32(r));
+    i++;
+  } while (i < h / 2);
+
+  return horizontal_add_u16x8(sum);
+}
+
+unsigned int aom_sad128x128_neon(const uint8_t *src, int src_stride,
+                                 const uint8_t *ref, int ref_stride) {
+  return sad128xh_neon(src, src_stride, ref, ref_stride, 128);
+}
+
+unsigned int aom_sad64x64_neon(const uint8_t *src, int src_stride,
+                               const uint8_t *ref, int ref_stride) {
+  return sad64xh_neon(src, src_stride, ref, ref_stride, 64);
+}
+
+unsigned int aom_sad32x32_neon(const uint8_t *src, int src_stride,
+                               const uint8_t *ref, int ref_stride) {
+  return sad32xh_neon(src, src_stride, ref, ref_stride, 32);
+}
+
+unsigned int aom_sad16x16_neon(const uint8_t *src, int src_stride,
+                               const uint8_t *ref, int ref_stride) {
+  return sad16xh_neon(src, src_stride, ref, ref_stride, 16);
+}
+
+unsigned int aom_sad16x8_neon(const uint8_t *src, int src_stride,
+                              const uint8_t *ref, int ref_stride) {
+  return sad16xh_neon(src, src_stride, ref, ref_stride, 8);
+}
+
+unsigned int aom_sad8x16_neon(const uint8_t *src, int src_stride,
+                              const uint8_t *ref, int ref_stride) {
+  return sad8xh_neon(src, src_stride, ref, ref_stride, 16);
+}
+
+unsigned int aom_sad8x8_neon(const uint8_t *src, int src_stride,
+                             const uint8_t *ref, int ref_stride) {
+  return sad8xh_neon(src, src_stride, ref, ref_stride, 8);
+}
+
+unsigned int aom_sad4x4_neon(const uint8_t *src, int src_stride,
+                             const uint8_t *ref, int ref_stride) {
+  return sad4xh_neon(src, src_stride, ref, ref_stride, 4);
 }
 
 #define FSADS128_H(h)                                                    \
diff --git a/aom_dsp/arm/sum_neon.h b/aom_dsp/arm/sum_neon.h
index a118f3c..4116509 100644
--- a/aom_dsp/arm/sum_neon.h
+++ b/aom_dsp/arm/sum_neon.h
@@ -37,6 +37,17 @@
 #endif
 }
 
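+// Sum the four 32-bit lanes of 'a' and return the result as a scalar.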
+static INLINE unsigned int horizontal_add_u32x4(const uint32x4_t a) {
+#if defined(__aarch64__)
+  return vaddvq_u32(a);
+#else
+  const uint64x2_t b = vpaddlq_u32(a);
+  const uint32x2_t c = vadd_u32(vreinterpret_u32_u64(vget_low_u64(b)),
+                                vreinterpret_u32_u64(vget_high_u64(b)));
+  return vget_lane_u32(c, 0);
+#endif
+}
+
 static INLINE uint32_t horizontal_long_add_u16x8(const uint16x8_t vec_lo,
                                                  const uint16x8_t vec_hi) {
 #if defined(__aarch64__)