Refactor Neon implementations of SAD4D functions

Factor out common code into sad<w>xhx4d_neon() helper functions. Each
helper performs a widening reduction only when necessary - i.e. just
before the 16-bit accumulators would overflow, or when computing the
final result.

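For example, with w == 128 each uint16 lane of a sum_lo/sum_hi
accumulator gathers 8 absolute differences per row (four pairwise
vpadalq_u8 accumulations of 2 bytes each), so a row adds at most
8 * 255 = 2040 per lane, and the partial sums must be widened to
32 bits every 32 rows (32 * 2040 = 65280 <= 65535). With w == 64
the per-row, per-lane maximum is 4 * 255 = 1020, allowing 64 rows
between reductions.
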
Change-Id: I32a0d78c0aedfd418661fe5054850d4c10826eb9
diff --git a/aom_dsp/arm/sad4d_neon.c b/aom_dsp/arm/sad4d_neon.c
index b62628e..bbc0507 100644
--- a/aom_dsp/arm/sad4d_neon.c
+++ b/aom_dsp/arm/sad4d_neon.c
@@ -17,550 +17,320 @@
 #include "aom/aom_integer.h"
 #include "aom_dsp/arm/sum_neon.h"
 
-// Calculate the absolute difference of 64 bytes from vec_src_00, vec_src_16,
-// vec_src_32, vec_src_48 and ref. Accumulate partial sums in vec_sum_ref_lo
-// and vec_sum_ref_hi.
-static void sad_neon_64(const uint8x16_t vec_src_00,
-                        const uint8x16_t vec_src_16,
-                        const uint8x16_t vec_src_32,
-                        const uint8x16_t vec_src_48, const uint8_t *ref,
-                        uint16x8_t *vec_sum_ref_lo,
-                        uint16x8_t *vec_sum_ref_hi) {
-  const uint8x16_t vec_ref_00 = vld1q_u8(ref);
-  const uint8x16_t vec_ref_16 = vld1q_u8(ref + 16);
-  const uint8x16_t vec_ref_32 = vld1q_u8(ref + 32);
-  const uint8x16_t vec_ref_48 = vld1q_u8(ref + 48);
-
-  *vec_sum_ref_lo = vabal_u8(*vec_sum_ref_lo, vget_low_u8(vec_src_00),
-                             vget_low_u8(vec_ref_00));
-  *vec_sum_ref_hi = vabal_u8(*vec_sum_ref_hi, vget_high_u8(vec_src_00),
-                             vget_high_u8(vec_ref_00));
-  *vec_sum_ref_lo = vabal_u8(*vec_sum_ref_lo, vget_low_u8(vec_src_16),
-                             vget_low_u8(vec_ref_16));
-  *vec_sum_ref_hi = vabal_u8(*vec_sum_ref_hi, vget_high_u8(vec_src_16),
-                             vget_high_u8(vec_ref_16));
-  *vec_sum_ref_lo = vabal_u8(*vec_sum_ref_lo, vget_low_u8(vec_src_32),
-                             vget_low_u8(vec_ref_32));
-  *vec_sum_ref_hi = vabal_u8(*vec_sum_ref_hi, vget_high_u8(vec_src_32),
-                             vget_high_u8(vec_ref_32));
-  *vec_sum_ref_lo = vabal_u8(*vec_sum_ref_lo, vget_low_u8(vec_src_48),
-                             vget_low_u8(vec_ref_48));
-  *vec_sum_ref_hi = vabal_u8(*vec_sum_ref_hi, vget_high_u8(vec_src_48),
-                             vget_high_u8(vec_ref_48));
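+// Accumulate the absolute differences of 16 bytes from src and ref into
+// *sad_sum, widening from uint8 to uint16 with a pairwise accumulate.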
+static INLINE void sad16_neon(uint8x16_t src, uint8x16_t ref,
+                              uint16x8_t *const sad_sum) {
+  uint8x16_t abs_diff = vabdq_u8(src, ref);
+  *sad_sum = vpadalq_u8(*sad_sum, abs_diff);
 }
 
-// Calculate the absolute difference of 32 bytes from vec_src_00, vec_src_16,
-// and ref. Accumulate partial sums in vec_sum_ref_lo and vec_sum_ref_hi.
-static void sad_neon_32(const uint8x16_t vec_src_00,
-                        const uint8x16_t vec_src_16, const uint8_t *ref,
-                        uint16x8_t *vec_sum_ref_lo,
-                        uint16x8_t *vec_sum_ref_hi) {
-  const uint8x16_t vec_ref_00 = vld1q_u8(ref);
-  const uint8x16_t vec_ref_16 = vld1q_u8(ref + 16);
+static INLINE void sad128xhx4d_neon(const uint8_t *src, int src_stride,
+                                    const uint8_t *const ref[4], int ref_stride,
+                                    uint32_t res[4], int h) {
+  vst1q_u32(res, vdupq_n_u32(0));
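+  // Accumulate in uint16 for at most 32 rows at a time: each lane gathers 8
+  // absolute differences per row, and 32 * 8 * 255 = 65280 <= UINT16_MAX.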
+  int h_tmp = h > 32 ? 32 : h;
 
-  *vec_sum_ref_lo = vabal_u8(*vec_sum_ref_lo, vget_low_u8(vec_src_00),
-                             vget_low_u8(vec_ref_00));
-  *vec_sum_ref_hi = vabal_u8(*vec_sum_ref_hi, vget_high_u8(vec_src_00),
-                             vget_high_u8(vec_ref_00));
-  *vec_sum_ref_lo = vabal_u8(*vec_sum_ref_lo, vget_low_u8(vec_src_16),
-                             vget_low_u8(vec_ref_16));
-  *vec_sum_ref_hi = vabal_u8(*vec_sum_ref_hi, vget_high_u8(vec_src_16),
-                             vget_high_u8(vec_ref_16));
+  int i = 0;
+  do {
+    uint16x8_t sum_lo[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+                             vdupq_n_u16(0) };
+    uint16x8_t sum_hi[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+                             vdupq_n_u16(0) };
+
+    do {
+      const uint8x16_t s0 = vld1q_u8(src + i * src_stride);
+      sad16_neon(s0, vld1q_u8(ref[0] + i * ref_stride), &sum_lo[0]);
+      sad16_neon(s0, vld1q_u8(ref[1] + i * ref_stride), &sum_lo[1]);
+      sad16_neon(s0, vld1q_u8(ref[2] + i * ref_stride), &sum_lo[2]);
+      sad16_neon(s0, vld1q_u8(ref[3] + i * ref_stride), &sum_lo[3]);
+
+      const uint8x16_t s1 = vld1q_u8(src + i * src_stride + 16);
+      sad16_neon(s1, vld1q_u8(ref[0] + i * ref_stride + 16), &sum_hi[0]);
+      sad16_neon(s1, vld1q_u8(ref[1] + i * ref_stride + 16), &sum_hi[1]);
+      sad16_neon(s1, vld1q_u8(ref[2] + i * ref_stride + 16), &sum_hi[2]);
+      sad16_neon(s1, vld1q_u8(ref[3] + i * ref_stride + 16), &sum_hi[3]);
+
+      const uint8x16_t s2 = vld1q_u8(src + i * src_stride + 32);
+      sad16_neon(s2, vld1q_u8(ref[0] + i * ref_stride + 32), &sum_lo[0]);
+      sad16_neon(s2, vld1q_u8(ref[1] + i * ref_stride + 32), &sum_lo[1]);
+      sad16_neon(s2, vld1q_u8(ref[2] + i * ref_stride + 32), &sum_lo[2]);
+      sad16_neon(s2, vld1q_u8(ref[3] + i * ref_stride + 32), &sum_lo[3]);
+
+      const uint8x16_t s3 = vld1q_u8(src + i * src_stride + 48);
+      sad16_neon(s3, vld1q_u8(ref[0] + i * ref_stride + 48), &sum_hi[0]);
+      sad16_neon(s3, vld1q_u8(ref[1] + i * ref_stride + 48), &sum_hi[1]);
+      sad16_neon(s3, vld1q_u8(ref[2] + i * ref_stride + 48), &sum_hi[2]);
+      sad16_neon(s3, vld1q_u8(ref[3] + i * ref_stride + 48), &sum_hi[3]);
+
+      const uint8x16_t s4 = vld1q_u8(src + i * src_stride + 64);
+      sad16_neon(s4, vld1q_u8(ref[0] + i * ref_stride + 64), &sum_lo[0]);
+      sad16_neon(s4, vld1q_u8(ref[1] + i * ref_stride + 64), &sum_lo[1]);
+      sad16_neon(s4, vld1q_u8(ref[2] + i * ref_stride + 64), &sum_lo[2]);
+      sad16_neon(s4, vld1q_u8(ref[3] + i * ref_stride + 64), &sum_lo[3]);
+
+      const uint8x16_t s5 = vld1q_u8(src + i * src_stride + 80);
+      sad16_neon(s5, vld1q_u8(ref[0] + i * ref_stride + 80), &sum_hi[0]);
+      sad16_neon(s5, vld1q_u8(ref[1] + i * ref_stride + 80), &sum_hi[1]);
+      sad16_neon(s5, vld1q_u8(ref[2] + i * ref_stride + 80), &sum_hi[2]);
+      sad16_neon(s5, vld1q_u8(ref[3] + i * ref_stride + 80), &sum_hi[3]);
+
+      const uint8x16_t s6 = vld1q_u8(src + i * src_stride + 96);
+      sad16_neon(s6, vld1q_u8(ref[0] + i * ref_stride + 96), &sum_lo[0]);
+      sad16_neon(s6, vld1q_u8(ref[1] + i * ref_stride + 96), &sum_lo[1]);
+      sad16_neon(s6, vld1q_u8(ref[2] + i * ref_stride + 96), &sum_lo[2]);
+      sad16_neon(s6, vld1q_u8(ref[3] + i * ref_stride + 96), &sum_lo[3]);
+
+      const uint8x16_t s7 = vld1q_u8(src + i * src_stride + 112);
+      sad16_neon(s7, vld1q_u8(ref[0] + i * ref_stride + 112), &sum_hi[0]);
+      sad16_neon(s7, vld1q_u8(ref[1] + i * ref_stride + 112), &sum_hi[1]);
+      sad16_neon(s7, vld1q_u8(ref[2] + i * ref_stride + 112), &sum_hi[2]);
+      sad16_neon(s7, vld1q_u8(ref[3] + i * ref_stride + 112), &sum_hi[3]);
+
+      i++;
+    } while (i < h_tmp);
+
+    res[0] += horizontal_long_add_u16x8(sum_lo[0], sum_hi[0]);
+    res[1] += horizontal_long_add_u16x8(sum_lo[1], sum_hi[1]);
+    res[2] += horizontal_long_add_u16x8(sum_lo[2], sum_hi[2]);
+    res[3] += horizontal_long_add_u16x8(sum_lo[3], sum_hi[3]);
+
+    h_tmp += 32;
+  } while (i < h);
+}
+
+static INLINE void sad64xhx4d_neon(const uint8_t *src, int src_stride,
+                                   const uint8_t *const ref[4], int ref_stride,
+                                   uint32_t res[4], int h) {
+  vst1q_u32(res, vdupq_n_u32(0));
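+  // Each uint16 lane gathers 4 absolute differences per row, so the partial
+  // sums can be widened every 64 rows (64 * 4 * 255 = 65280 <= UINT16_MAX).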
+  int h_tmp = h > 64 ? 64 : h;
+
+  int i = 0;
+  do {
+    uint16x8_t sum_lo[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+                             vdupq_n_u16(0) };
+    uint16x8_t sum_hi[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+                             vdupq_n_u16(0) };
+
+    do {
+      const uint8x16_t s0 = vld1q_u8(src + i * src_stride);
+      sad16_neon(s0, vld1q_u8(ref[0] + i * ref_stride), &sum_lo[0]);
+      sad16_neon(s0, vld1q_u8(ref[1] + i * ref_stride), &sum_lo[1]);
+      sad16_neon(s0, vld1q_u8(ref[2] + i * ref_stride), &sum_lo[2]);
+      sad16_neon(s0, vld1q_u8(ref[3] + i * ref_stride), &sum_lo[3]);
+
+      const uint8x16_t s1 = vld1q_u8(src + i * src_stride + 16);
+      sad16_neon(s1, vld1q_u8(ref[0] + i * ref_stride + 16), &sum_hi[0]);
+      sad16_neon(s1, vld1q_u8(ref[1] + i * ref_stride + 16), &sum_hi[1]);
+      sad16_neon(s1, vld1q_u8(ref[2] + i * ref_stride + 16), &sum_hi[2]);
+      sad16_neon(s1, vld1q_u8(ref[3] + i * ref_stride + 16), &sum_hi[3]);
+
+      const uint8x16_t s2 = vld1q_u8(src + i * src_stride + 32);
+      sad16_neon(s2, vld1q_u8(ref[0] + i * ref_stride + 32), &sum_lo[0]);
+      sad16_neon(s2, vld1q_u8(ref[1] + i * ref_stride + 32), &sum_lo[1]);
+      sad16_neon(s2, vld1q_u8(ref[2] + i * ref_stride + 32), &sum_lo[2]);
+      sad16_neon(s2, vld1q_u8(ref[3] + i * ref_stride + 32), &sum_lo[3]);
+
+      const uint8x16_t s3 = vld1q_u8(src + i * src_stride + 48);
+      sad16_neon(s3, vld1q_u8(ref[0] + i * ref_stride + 48), &sum_hi[0]);
+      sad16_neon(s3, vld1q_u8(ref[1] + i * ref_stride + 48), &sum_hi[1]);
+      sad16_neon(s3, vld1q_u8(ref[2] + i * ref_stride + 48), &sum_hi[2]);
+      sad16_neon(s3, vld1q_u8(ref[3] + i * ref_stride + 48), &sum_hi[3]);
+
+      i++;
+    } while (i < h_tmp);
+
+    res[0] += horizontal_long_add_u16x8(sum_lo[0], sum_hi[0]);
+    res[1] += horizontal_long_add_u16x8(sum_lo[1], sum_hi[1]);
+    res[2] += horizontal_long_add_u16x8(sum_lo[2], sum_hi[2]);
+    res[3] += horizontal_long_add_u16x8(sum_lo[3], sum_hi[3]);
+
+    h_tmp += 64;
+  } while (i < h);
+}
+
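+// For 32-wide blocks h is at most 64, so each uint16 lane accumulates at most
+// 64 * 2 * 255 = 32640 and no intermediate reduction is needed.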
+static INLINE void sad32xhx4d_neon(const uint8_t *src, int src_stride,
+                                   const uint8_t *const ref[4], int ref_stride,
+                                   uint32_t res[4], int h) {
+  uint16x8_t sum_lo[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+                           vdupq_n_u16(0) };
+  uint16x8_t sum_hi[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+                           vdupq_n_u16(0) };
+
+  int i = 0;
+  do {
+    const uint8x16_t s0 = vld1q_u8(src + i * src_stride);
+    sad16_neon(s0, vld1q_u8(ref[0] + i * ref_stride), &sum_lo[0]);
+    sad16_neon(s0, vld1q_u8(ref[1] + i * ref_stride), &sum_lo[1]);
+    sad16_neon(s0, vld1q_u8(ref[2] + i * ref_stride), &sum_lo[2]);
+    sad16_neon(s0, vld1q_u8(ref[3] + i * ref_stride), &sum_lo[3]);
+
+    const uint8x16_t s1 = vld1q_u8(src + i * src_stride + 16);
+    sad16_neon(s1, vld1q_u8(ref[0] + i * ref_stride + 16), &sum_hi[0]);
+    sad16_neon(s1, vld1q_u8(ref[1] + i * ref_stride + 16), &sum_hi[1]);
+    sad16_neon(s1, vld1q_u8(ref[2] + i * ref_stride + 16), &sum_hi[2]);
+    sad16_neon(s1, vld1q_u8(ref[3] + i * ref_stride + 16), &sum_hi[3]);
+
+    i++;
+  } while (i < h);
+
+  res[0] = horizontal_long_add_u16x8(sum_lo[0], sum_hi[0]);
+  res[1] = horizontal_long_add_u16x8(sum_lo[1], sum_hi[1]);
+  res[2] = horizontal_long_add_u16x8(sum_lo[2], sum_hi[2]);
+  res[3] = horizontal_long_add_u16x8(sum_lo[3], sum_hi[3]);
+}
+
+static INLINE void sad16xhx4d_neon(const uint8_t *src, int src_stride,
+                                   const uint8_t *const ref[4], int ref_stride,
+                                   uint32_t res[4], int h) {
+  uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+                        vdupq_n_u16(0) };
+
+  int i = 0;
+  do {
+    const uint8x16_t s = vld1q_u8(src + i * src_stride);
+    sad16_neon(s, vld1q_u8(ref[0] + i * ref_stride), &sum[0]);
+    sad16_neon(s, vld1q_u8(ref[1] + i * ref_stride), &sum[1]);
+    sad16_neon(s, vld1q_u8(ref[2] + i * ref_stride), &sum[2]);
+    sad16_neon(s, vld1q_u8(ref[3] + i * ref_stride), &sum[3]);
+
+    i++;
+  } while (i < h);
+
+  res[0] = horizontal_add_u16x8(sum[0]);
+  res[1] = horizontal_add_u16x8(sum[1]);
+  res[2] = horizontal_add_u16x8(sum[2]);
+  res[3] = horizontal_add_u16x8(sum[3]);
+}
+
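+// Accumulate the absolute differences of 8 bytes from src and ref into
+// *sad_sum, widening from uint8 to uint16.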
+static INLINE void sad8_neon(uint8x8_t src, uint8x8_t ref,
+                             uint16x8_t *const sad_sum) {
+  uint8x8_t abs_diff = vabd_u8(src, ref);
+  *sad_sum = vaddw_u8(*sad_sum, abs_diff);
+}
+
+static INLINE void sad8xhx4d_neon(const uint8_t *src, int src_stride,
+                                  const uint8_t *const ref[4], int ref_stride,
+                                  uint32_t res[4], int h) {
+  uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+                        vdupq_n_u16(0) };
+
+  int i = 0;
+  do {
+    const uint8x8_t s = vld1_u8(src + i * src_stride);
+    sad8_neon(s, vld1_u8(ref[0] + i * ref_stride), &sum[0]);
+    sad8_neon(s, vld1_u8(ref[1] + i * ref_stride), &sum[1]);
+    sad8_neon(s, vld1_u8(ref[2] + i * ref_stride), &sum[2]);
+    sad8_neon(s, vld1_u8(ref[3] + i * ref_stride), &sum[3]);
+
+    i++;
+  } while (i < h);
+
+  res[0] = horizontal_add_u16x8(sum[0]);
+  res[1] = horizontal_add_u16x8(sum[1]);
+  res[2] = horizontal_add_u16x8(sum[2]);
+  res[3] = horizontal_add_u16x8(sum[3]);
+}
+
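+// Process 4-wide blocks two rows at a time so that each iteration operates
+// on a full 8-byte vector.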
+static INLINE void sad4xhx4d_neon(const uint8_t *src, int src_stride,
+                                  const uint8_t *const ref[4], int ref_stride,
+                                  uint32_t res[4], int h) {
+  uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+                        vdupq_n_u16(0) };
+
+  int i = 0;
+  do {
+    uint32x2_t s, r0, r1, r2, r3;
+    uint32_t s_lo, s_hi, r0_lo, r0_hi, r1_lo, r1_hi, r2_lo, r2_hi, r3_lo, r3_hi;
+
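+    // Load each 4-byte row with memcpy rather than a vector load to avoid
+    // reading past the end of the row.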
+    memcpy(&s_lo, src + i * src_stride, 4);
+    memcpy(&r0_lo, ref[0] + i * ref_stride, 4);
+    memcpy(&r1_lo, ref[1] + i * ref_stride, 4);
+    memcpy(&r2_lo, ref[2] + i * ref_stride, 4);
+    memcpy(&r3_lo, ref[3] + i * ref_stride, 4);
+    s = vdup_n_u32(s_lo);
+    r0 = vdup_n_u32(r0_lo);
+    r1 = vdup_n_u32(r1_lo);
+    r2 = vdup_n_u32(r2_lo);
+    r3 = vdup_n_u32(r3_lo);
+
+    memcpy(&s_hi, src + (i + 1) * src_stride, 4);
+    memcpy(&r0_hi, ref[0] + (i + 1) * ref_stride, 4);
+    memcpy(&r1_hi, ref[1] + (i + 1) * ref_stride, 4);
+    memcpy(&r2_hi, ref[2] + (i + 1) * ref_stride, 4);
+    memcpy(&r3_hi, ref[3] + (i + 1) * ref_stride, 4);
+    s = vset_lane_u32(s_hi, s, 1);
+    r0 = vset_lane_u32(r0_hi, r0, 1);
+    r1 = vset_lane_u32(r1_hi, r1, 1);
+    r2 = vset_lane_u32(r2_hi, r2, 1);
+    r3 = vset_lane_u32(r3_hi, r3, 1);
+
+    sad8_neon(vreinterpret_u8_u32(s), vreinterpret_u8_u32(r0), &sum[0]);
+    sad8_neon(vreinterpret_u8_u32(s), vreinterpret_u8_u32(r1), &sum[1]);
+    sad8_neon(vreinterpret_u8_u32(s), vreinterpret_u8_u32(r2), &sum[2]);
+    sad8_neon(vreinterpret_u8_u32(s), vreinterpret_u8_u32(r3), &sum[3]);
+
+    i += 2;
+  } while (i < h);
+
+  res[0] = horizontal_add_u16x8(sum[0]);
+  res[1] = horizontal_add_u16x8(sum[1]);
+  res[2] = horizontal_add_u16x8(sum[2]);
+  res[3] = horizontal_add_u16x8(sum[3]);
 }
 
 void aom_sad64x64x4d_neon(const uint8_t *src, int src_stride,
                           const uint8_t *const ref[4], int ref_stride,
                           uint32_t res[4]) {
-  int i;
-  uint16x8_t vec_sum_ref0_lo = vdupq_n_u16(0);
-  uint16x8_t vec_sum_ref0_hi = vdupq_n_u16(0);
-  uint16x8_t vec_sum_ref1_lo = vdupq_n_u16(0);
-  uint16x8_t vec_sum_ref1_hi = vdupq_n_u16(0);
-  uint16x8_t vec_sum_ref2_lo = vdupq_n_u16(0);
-  uint16x8_t vec_sum_ref2_hi = vdupq_n_u16(0);
-  uint16x8_t vec_sum_ref3_lo = vdupq_n_u16(0);
-  uint16x8_t vec_sum_ref3_hi = vdupq_n_u16(0);
-  const uint8_t *ref0, *ref1, *ref2, *ref3;
-  ref0 = ref[0];
-  ref1 = ref[1];
-  ref2 = ref[2];
-  ref3 = ref[3];
-
-  for (i = 0; i < 64; ++i) {
-    const uint8x16_t vec_src_00 = vld1q_u8(src);
-    const uint8x16_t vec_src_16 = vld1q_u8(src + 16);
-    const uint8x16_t vec_src_32 = vld1q_u8(src + 32);
-    const uint8x16_t vec_src_48 = vld1q_u8(src + 48);
-
-    sad_neon_64(vec_src_00, vec_src_16, vec_src_32, vec_src_48, ref0,
-                &vec_sum_ref0_lo, &vec_sum_ref0_hi);
-    sad_neon_64(vec_src_00, vec_src_16, vec_src_32, vec_src_48, ref1,
-                &vec_sum_ref1_lo, &vec_sum_ref1_hi);
-    sad_neon_64(vec_src_00, vec_src_16, vec_src_32, vec_src_48, ref2,
-                &vec_sum_ref2_lo, &vec_sum_ref2_hi);
-    sad_neon_64(vec_src_00, vec_src_16, vec_src_32, vec_src_48, ref3,
-                &vec_sum_ref3_lo, &vec_sum_ref3_hi);
-
-    src += src_stride;
-    ref0 += ref_stride;
-    ref1 += ref_stride;
-    ref2 += ref_stride;
-    ref3 += ref_stride;
-  }
-
-  res[0] = horizontal_long_add_u16x8(vec_sum_ref0_lo, vec_sum_ref0_hi);
-  res[1] = horizontal_long_add_u16x8(vec_sum_ref1_lo, vec_sum_ref1_hi);
-  res[2] = horizontal_long_add_u16x8(vec_sum_ref2_lo, vec_sum_ref2_hi);
-  res[3] = horizontal_long_add_u16x8(vec_sum_ref3_lo, vec_sum_ref3_hi);
+  sad64xhx4d_neon(src, src_stride, ref, ref_stride, res, 64);
 }
 
 void aom_sad32x32x4d_neon(const uint8_t *src, int src_stride,
                           const uint8_t *const ref[4], int ref_stride,
                           uint32_t res[4]) {
-  int i;
-  uint16x8_t vec_sum_ref0_lo = vdupq_n_u16(0);
-  uint16x8_t vec_sum_ref0_hi = vdupq_n_u16(0);
-  uint16x8_t vec_sum_ref1_lo = vdupq_n_u16(0);
-  uint16x8_t vec_sum_ref1_hi = vdupq_n_u16(0);
-  uint16x8_t vec_sum_ref2_lo = vdupq_n_u16(0);
-  uint16x8_t vec_sum_ref2_hi = vdupq_n_u16(0);
-  uint16x8_t vec_sum_ref3_lo = vdupq_n_u16(0);
-  uint16x8_t vec_sum_ref3_hi = vdupq_n_u16(0);
-  const uint8_t *ref0, *ref1, *ref2, *ref3;
-  ref0 = ref[0];
-  ref1 = ref[1];
-  ref2 = ref[2];
-  ref3 = ref[3];
-
-  for (i = 0; i < 32; ++i) {
-    const uint8x16_t vec_src_00 = vld1q_u8(src);
-    const uint8x16_t vec_src_16 = vld1q_u8(src + 16);
-
-    sad_neon_32(vec_src_00, vec_src_16, ref0, &vec_sum_ref0_lo,
-                &vec_sum_ref0_hi);
-    sad_neon_32(vec_src_00, vec_src_16, ref1, &vec_sum_ref1_lo,
-                &vec_sum_ref1_hi);
-    sad_neon_32(vec_src_00, vec_src_16, ref2, &vec_sum_ref2_lo,
-                &vec_sum_ref2_hi);
-    sad_neon_32(vec_src_00, vec_src_16, ref3, &vec_sum_ref3_lo,
-                &vec_sum_ref3_hi);
-
-    src += src_stride;
-    ref0 += ref_stride;
-    ref1 += ref_stride;
-    ref2 += ref_stride;
-    ref3 += ref_stride;
-  }
-
-  res[0] = horizontal_long_add_u16x8(vec_sum_ref0_lo, vec_sum_ref0_hi);
-  res[1] = horizontal_long_add_u16x8(vec_sum_ref1_lo, vec_sum_ref1_hi);
-  res[2] = horizontal_long_add_u16x8(vec_sum_ref2_lo, vec_sum_ref2_hi);
-  res[3] = horizontal_long_add_u16x8(vec_sum_ref3_lo, vec_sum_ref3_hi);
+  sad32xhx4d_neon(src, src_stride, ref, ref_stride, res, 32);
 }
 
 void aom_sad16x16x4d_neon(const uint8_t *src, int src_stride,
                           const uint8_t *const ref[4], int ref_stride,
                           uint32_t res[4]) {
-  int i;
-  uint16x8_t vec_sum_ref0_lo = vdupq_n_u16(0);
-  uint16x8_t vec_sum_ref0_hi = vdupq_n_u16(0);
-  uint16x8_t vec_sum_ref1_lo = vdupq_n_u16(0);
-  uint16x8_t vec_sum_ref1_hi = vdupq_n_u16(0);
-  uint16x8_t vec_sum_ref2_lo = vdupq_n_u16(0);
-  uint16x8_t vec_sum_ref2_hi = vdupq_n_u16(0);
-  uint16x8_t vec_sum_ref3_lo = vdupq_n_u16(0);
-  uint16x8_t vec_sum_ref3_hi = vdupq_n_u16(0);
-  const uint8_t *ref0, *ref1, *ref2, *ref3;
-  ref0 = ref[0];
-  ref1 = ref[1];
-  ref2 = ref[2];
-  ref3 = ref[3];
-
-  for (i = 0; i < 16; ++i) {
-    const uint8x16_t vec_src = vld1q_u8(src);
-    const uint8x16_t vec_ref0 = vld1q_u8(ref0);
-    const uint8x16_t vec_ref1 = vld1q_u8(ref1);
-    const uint8x16_t vec_ref2 = vld1q_u8(ref2);
-    const uint8x16_t vec_ref3 = vld1q_u8(ref3);
-
-    vec_sum_ref0_lo =
-        vabal_u8(vec_sum_ref0_lo, vget_low_u8(vec_src), vget_low_u8(vec_ref0));
-    vec_sum_ref0_hi = vabal_u8(vec_sum_ref0_hi, vget_high_u8(vec_src),
-                               vget_high_u8(vec_ref0));
-    vec_sum_ref1_lo =
-        vabal_u8(vec_sum_ref1_lo, vget_low_u8(vec_src), vget_low_u8(vec_ref1));
-    vec_sum_ref1_hi = vabal_u8(vec_sum_ref1_hi, vget_high_u8(vec_src),
-                               vget_high_u8(vec_ref1));
-    vec_sum_ref2_lo =
-        vabal_u8(vec_sum_ref2_lo, vget_low_u8(vec_src), vget_low_u8(vec_ref2));
-    vec_sum_ref2_hi = vabal_u8(vec_sum_ref2_hi, vget_high_u8(vec_src),
-                               vget_high_u8(vec_ref2));
-    vec_sum_ref3_lo =
-        vabal_u8(vec_sum_ref3_lo, vget_low_u8(vec_src), vget_low_u8(vec_ref3));
-    vec_sum_ref3_hi = vabal_u8(vec_sum_ref3_hi, vget_high_u8(vec_src),
-                               vget_high_u8(vec_ref3));
-
-    src += src_stride;
-    ref0 += ref_stride;
-    ref1 += ref_stride;
-    ref2 += ref_stride;
-    ref3 += ref_stride;
-  }
-
-  res[0] = horizontal_long_add_u16x8(vec_sum_ref0_lo, vec_sum_ref0_hi);
-  res[1] = horizontal_long_add_u16x8(vec_sum_ref1_lo, vec_sum_ref1_hi);
-  res[2] = horizontal_long_add_u16x8(vec_sum_ref2_lo, vec_sum_ref2_hi);
-  res[3] = horizontal_long_add_u16x8(vec_sum_ref3_lo, vec_sum_ref3_hi);
+  sad16xhx4d_neon(src, src_stride, ref, ref_stride, res, 16);
 }
 
-static void sad_row4_neon(uint16x4_t *vec_src, const uint8x8_t q0,
-                          const uint8x8_t ref) {
-  uint8x8_t q2 = vabd_u8(q0, ref);
-  *vec_src = vpadal_u8(*vec_src, q2);
-}
-
-static void sad_row8_neon(uint16x4_t *vec_src, const uint8x8_t *q0,
-                          const uint8_t *ref_ptr) {
-  uint8x8_t q1 = vld1_u8(ref_ptr);
-  uint8x8_t q2 = vabd_u8(*q0, q1);
-  *vec_src = vpadal_u8(*vec_src, q2);
-}
-
-static void sad_row16_neon(uint16x8_t *vec_src, const uint8x16_t *q0,
-                           const uint8_t *ref_ptr) {
-  uint8x16_t q1 = vld1q_u8(ref_ptr);
-  uint8x16_t q2 = vabdq_u8(*q0, q1);
-  *vec_src = vpadalq_u8(*vec_src, q2);
-}
-
-void aom_sadMxNx4d_neon(int width, int height, const uint8_t *src,
-                        int src_stride, const uint8_t *const ref[4],
-                        int ref_stride, uint32_t res[4]) {
-  const uint8_t *ref0, *ref1, *ref2, *ref3;
-
-  ref0 = ref[0];
-  ref1 = ref[1];
-  ref2 = ref[2];
-  ref3 = ref[3];
-
-  res[0] = 0;
-  res[1] = 0;
-  res[2] = 0;
-  res[3] = 0;
-
-  switch (width) {
-    case 4: {
-      uint32_t src4, ref40, ref41, ref42, ref43;
-      uint32x2_t q8 = vdup_n_u32(0);
-      uint32x2_t q4 = vdup_n_u32(0);
-      uint32x2_t q5 = vdup_n_u32(0);
-      uint32x2_t q6 = vdup_n_u32(0);
-      uint32x2_t q7 = vdup_n_u32(0);
-
-      for (int i = 0; i < height / 2; i++) {
-        uint16x4_t q0 = vdup_n_u16(0);
-        uint16x4_t q1 = vdup_n_u16(0);
-        uint16x4_t q2 = vdup_n_u16(0);
-        uint16x4_t q3 = vdup_n_u16(0);
-
-        memcpy(&src4, src, 4);
-        memcpy(&ref40, ref0, 4);
-        memcpy(&ref41, ref1, 4);
-        memcpy(&ref42, ref2, 4);
-        memcpy(&ref43, ref3, 4);
-
-        src += src_stride;
-        ref0 += ref_stride;
-        ref1 += ref_stride;
-        ref2 += ref_stride;
-        ref3 += ref_stride;
-
-        q8 = vset_lane_u32(src4, q8, 0);
-        q4 = vset_lane_u32(ref40, q4, 0);
-        q5 = vset_lane_u32(ref41, q5, 0);
-        q6 = vset_lane_u32(ref42, q6, 0);
-        q7 = vset_lane_u32(ref43, q7, 0);
-
-        memcpy(&src4, src, 4);
-        memcpy(&ref40, ref0, 4);
-        memcpy(&ref41, ref1, 4);
-        memcpy(&ref42, ref2, 4);
-        memcpy(&ref43, ref3, 4);
-
-        src += src_stride;
-        ref0 += ref_stride;
-        ref1 += ref_stride;
-        ref2 += ref_stride;
-        ref3 += ref_stride;
-
-        q8 = vset_lane_u32(src4, q8, 1);
-        q4 = vset_lane_u32(ref40, q4, 1);
-        q5 = vset_lane_u32(ref41, q5, 1);
-        q6 = vset_lane_u32(ref42, q6, 1);
-        q7 = vset_lane_u32(ref43, q7, 1);
-
-        sad_row4_neon(&q0, vreinterpret_u8_u32(q8), vreinterpret_u8_u32(q4));
-        sad_row4_neon(&q1, vreinterpret_u8_u32(q8), vreinterpret_u8_u32(q5));
-        sad_row4_neon(&q2, vreinterpret_u8_u32(q8), vreinterpret_u8_u32(q6));
-        sad_row4_neon(&q3, vreinterpret_u8_u32(q8), vreinterpret_u8_u32(q7));
-
-        res[0] += horizontal_add_u16x4(q0);
-        res[1] += horizontal_add_u16x4(q1);
-        res[2] += horizontal_add_u16x4(q2);
-        res[3] += horizontal_add_u16x4(q3);
-      }
-      break;
-    }
-    case 8: {
-      for (int i = 0; i < height; i++) {
-        uint16x4_t q0 = vdup_n_u16(0);
-        uint16x4_t q1 = vdup_n_u16(0);
-        uint16x4_t q2 = vdup_n_u16(0);
-        uint16x4_t q3 = vdup_n_u16(0);
-
-        uint8x8_t q5 = vld1_u8(src);
-
-        sad_row8_neon(&q0, &q5, ref0);
-        sad_row8_neon(&q1, &q5, ref1);
-        sad_row8_neon(&q2, &q5, ref2);
-        sad_row8_neon(&q3, &q5, ref3);
-
-        src += src_stride;
-        ref0 += ref_stride;
-        ref1 += ref_stride;
-        ref2 += ref_stride;
-        ref3 += ref_stride;
-
-        res[0] += horizontal_add_u16x4(q0);
-        res[1] += horizontal_add_u16x4(q1);
-        res[2] += horizontal_add_u16x4(q2);
-        res[3] += horizontal_add_u16x4(q3);
-      }
-      break;
-    }
-    case 16: {
-      for (int i = 0; i < height; i++) {
-        uint16x8_t q0 = vdupq_n_u16(0);
-        uint16x8_t q1 = vdupq_n_u16(0);
-        uint16x8_t q2 = vdupq_n_u16(0);
-        uint16x8_t q3 = vdupq_n_u16(0);
-
-        uint8x16_t q4 = vld1q_u8(src);
-
-        sad_row16_neon(&q0, &q4, ref0);
-        sad_row16_neon(&q1, &q4, ref1);
-        sad_row16_neon(&q2, &q4, ref2);
-        sad_row16_neon(&q3, &q4, ref3);
-
-        src += src_stride;
-        ref0 += ref_stride;
-        ref1 += ref_stride;
-        ref2 += ref_stride;
-        ref3 += ref_stride;
-
-        res[0] += horizontal_add_u16x8(q0);
-        res[1] += horizontal_add_u16x8(q1);
-        res[2] += horizontal_add_u16x8(q2);
-        res[3] += horizontal_add_u16x8(q3);
-      }
-      break;
-    }
-    case 32: {
-      for (int i = 0; i < height; i++) {
-        uint16x8_t q0 = vdupq_n_u16(0);
-        uint16x8_t q1 = vdupq_n_u16(0);
-        uint16x8_t q2 = vdupq_n_u16(0);
-        uint16x8_t q3 = vdupq_n_u16(0);
-
-        uint8x16_t q4 = vld1q_u8(src);
-
-        sad_row16_neon(&q0, &q4, ref0);
-        sad_row16_neon(&q1, &q4, ref1);
-        sad_row16_neon(&q2, &q4, ref2);
-        sad_row16_neon(&q3, &q4, ref3);
-
-        q4 = vld1q_u8(src + 16);
-
-        sad_row16_neon(&q0, &q4, ref0 + 16);
-        sad_row16_neon(&q1, &q4, ref1 + 16);
-        sad_row16_neon(&q2, &q4, ref2 + 16);
-        sad_row16_neon(&q3, &q4, ref3 + 16);
-
-        src += src_stride;
-        ref0 += ref_stride;
-        ref1 += ref_stride;
-        ref2 += ref_stride;
-        ref3 += ref_stride;
-
-        res[0] += horizontal_add_u16x8(q0);
-        res[1] += horizontal_add_u16x8(q1);
-        res[2] += horizontal_add_u16x8(q2);
-        res[3] += horizontal_add_u16x8(q3);
-      }
-      break;
-    }
-    case 64: {
-      for (int i = 0; i < height; i++) {
-        uint16x8_t q0 = vdupq_n_u16(0);
-        uint16x8_t q1 = vdupq_n_u16(0);
-        uint16x8_t q2 = vdupq_n_u16(0);
-        uint16x8_t q3 = vdupq_n_u16(0);
-
-        uint8x16_t q4 = vld1q_u8(src);
-
-        sad_row16_neon(&q0, &q4, ref0);
-        sad_row16_neon(&q1, &q4, ref1);
-        sad_row16_neon(&q2, &q4, ref2);
-        sad_row16_neon(&q3, &q4, ref3);
-
-        q4 = vld1q_u8(src + 16);
-
-        sad_row16_neon(&q0, &q4, ref0 + 16);
-        sad_row16_neon(&q1, &q4, ref1 + 16);
-        sad_row16_neon(&q2, &q4, ref2 + 16);
-        sad_row16_neon(&q3, &q4, ref3 + 16);
-
-        q4 = vld1q_u8(src + 32);
-
-        sad_row16_neon(&q0, &q4, ref0 + 32);
-        sad_row16_neon(&q1, &q4, ref1 + 32);
-        sad_row16_neon(&q2, &q4, ref2 + 32);
-        sad_row16_neon(&q3, &q4, ref3 + 32);
-
-        q4 = vld1q_u8(src + 48);
-
-        sad_row16_neon(&q0, &q4, ref0 + 48);
-        sad_row16_neon(&q1, &q4, ref1 + 48);
-        sad_row16_neon(&q2, &q4, ref2 + 48);
-        sad_row16_neon(&q3, &q4, ref3 + 48);
-
-        src += src_stride;
-        ref0 += ref_stride;
-        ref1 += ref_stride;
-        ref2 += ref_stride;
-        ref3 += ref_stride;
-
-        res[0] += horizontal_add_u16x8(q0);
-        res[1] += horizontal_add_u16x8(q1);
-        res[2] += horizontal_add_u16x8(q2);
-        res[3] += horizontal_add_u16x8(q3);
-      }
-      break;
-    }
-    case 128: {
-      for (int i = 0; i < height; i++) {
-        uint16x8_t q0 = vdupq_n_u16(0);
-        uint16x8_t q1 = vdupq_n_u16(0);
-        uint16x8_t q2 = vdupq_n_u16(0);
-        uint16x8_t q3 = vdupq_n_u16(0);
-
-        uint8x16_t q4 = vld1q_u8(src);
-
-        sad_row16_neon(&q0, &q4, ref0);
-        sad_row16_neon(&q1, &q4, ref1);
-        sad_row16_neon(&q2, &q4, ref2);
-        sad_row16_neon(&q3, &q4, ref3);
-
-        q4 = vld1q_u8(src + 16);
-
-        sad_row16_neon(&q0, &q4, ref0 + 16);
-        sad_row16_neon(&q1, &q4, ref1 + 16);
-        sad_row16_neon(&q2, &q4, ref2 + 16);
-        sad_row16_neon(&q3, &q4, ref3 + 16);
-
-        q4 = vld1q_u8(src + 32);
-
-        sad_row16_neon(&q0, &q4, ref0 + 32);
-        sad_row16_neon(&q1, &q4, ref1 + 32);
-        sad_row16_neon(&q2, &q4, ref2 + 32);
-        sad_row16_neon(&q3, &q4, ref3 + 32);
-
-        q4 = vld1q_u8(src + 48);
-
-        sad_row16_neon(&q0, &q4, ref0 + 48);
-        sad_row16_neon(&q1, &q4, ref1 + 48);
-        sad_row16_neon(&q2, &q4, ref2 + 48);
-        sad_row16_neon(&q3, &q4, ref3 + 48);
-
-        q4 = vld1q_u8(src + 64);
-
-        sad_row16_neon(&q0, &q4, ref0 + 64);
-        sad_row16_neon(&q1, &q4, ref1 + 64);
-        sad_row16_neon(&q2, &q4, ref2 + 64);
-        sad_row16_neon(&q3, &q4, ref3 + 64);
-
-        q4 = vld1q_u8(src + 80);
-
-        sad_row16_neon(&q0, &q4, ref0 + 80);
-        sad_row16_neon(&q1, &q4, ref1 + 80);
-        sad_row16_neon(&q2, &q4, ref2 + 80);
-        sad_row16_neon(&q3, &q4, ref3 + 80);
-
-        q4 = vld1q_u8(src + 96);
-
-        sad_row16_neon(&q0, &q4, ref0 + 96);
-        sad_row16_neon(&q1, &q4, ref1 + 96);
-        sad_row16_neon(&q2, &q4, ref2 + 96);
-        sad_row16_neon(&q3, &q4, ref3 + 96);
-
-        q4 = vld1q_u8(src + 112);
-
-        sad_row16_neon(&q0, &q4, ref0 + 112);
-        sad_row16_neon(&q1, &q4, ref1 + 112);
-        sad_row16_neon(&q2, &q4, ref2 + 112);
-        sad_row16_neon(&q3, &q4, ref3 + 112);
-
-        src += src_stride;
-        ref0 += ref_stride;
-        ref1 += ref_stride;
-        ref2 += ref_stride;
-        ref3 += ref_stride;
-
-        res[0] += horizontal_add_u16x8(q0);
-        res[1] += horizontal_add_u16x8(q1);
-        res[2] += horizontal_add_u16x8(q2);
-        res[3] += horizontal_add_u16x8(q3);
-      }
-    }
-  }
-}
-
-#define SAD_SKIP_MXN_NEON(m, n)                                             \
-  void aom_sad_skip_##m##x##n##x4d_neon(const uint8_t *src, int src_stride, \
+#define SAD_SKIP_WXH_4D_NEON(w, h)                                          \
+  void aom_sad_skip_##w##x##h##x4d_neon(const uint8_t *src, int src_stride, \
                                         const uint8_t *const ref[4],        \
                                         int ref_stride, uint32_t res[4]) {  \
-    aom_sadMxNx4d_neon(m, ((n) >> 1), src, 2 * src_stride, ref,             \
-                       2 * ref_stride, res);                                \
+    sad##w##xhx4d_neon(src, 2 * src_stride, ref, 2 * ref_stride, res,       \
+                       ((h) >> 1));                                         \
     res[0] <<= 1;                                                           \
     res[1] <<= 1;                                                           \
     res[2] <<= 1;                                                           \
     res[3] <<= 1;                                                           \
   }
 
-SAD_SKIP_MXN_NEON(4, 8)
-SAD_SKIP_MXN_NEON(4, 16)
-SAD_SKIP_MXN_NEON(4, 32)
+SAD_SKIP_WXH_4D_NEON(4, 8)
+SAD_SKIP_WXH_4D_NEON(4, 16)
+SAD_SKIP_WXH_4D_NEON(4, 32)
 
-SAD_SKIP_MXN_NEON(8, 8)
-SAD_SKIP_MXN_NEON(8, 16)
-SAD_SKIP_MXN_NEON(8, 32)
+SAD_SKIP_WXH_4D_NEON(8, 8)
+SAD_SKIP_WXH_4D_NEON(8, 16)
+SAD_SKIP_WXH_4D_NEON(8, 32)
 
-SAD_SKIP_MXN_NEON(16, 8)
-SAD_SKIP_MXN_NEON(16, 16)
-SAD_SKIP_MXN_NEON(16, 32)
-SAD_SKIP_MXN_NEON(16, 64)
+SAD_SKIP_WXH_4D_NEON(16, 8)
+SAD_SKIP_WXH_4D_NEON(16, 16)
+SAD_SKIP_WXH_4D_NEON(16, 32)
+SAD_SKIP_WXH_4D_NEON(16, 64)
 
-SAD_SKIP_MXN_NEON(32, 8)
-SAD_SKIP_MXN_NEON(32, 16)
-SAD_SKIP_MXN_NEON(32, 32)
-SAD_SKIP_MXN_NEON(32, 64)
+SAD_SKIP_WXH_4D_NEON(32, 8)
+SAD_SKIP_WXH_4D_NEON(32, 16)
+SAD_SKIP_WXH_4D_NEON(32, 32)
+SAD_SKIP_WXH_4D_NEON(32, 64)
 
-SAD_SKIP_MXN_NEON(64, 16)
-SAD_SKIP_MXN_NEON(64, 32)
-SAD_SKIP_MXN_NEON(64, 64)
-SAD_SKIP_MXN_NEON(64, 128)
+SAD_SKIP_WXH_4D_NEON(64, 16)
+SAD_SKIP_WXH_4D_NEON(64, 32)
+SAD_SKIP_WXH_4D_NEON(64, 64)
+SAD_SKIP_WXH_4D_NEON(64, 128)
 
-SAD_SKIP_MXN_NEON(128, 64)
-SAD_SKIP_MXN_NEON(128, 128)
+SAD_SKIP_WXH_4D_NEON(128, 64)
+SAD_SKIP_WXH_4D_NEON(128, 128)
 
-#undef SAD_SKIP_MXN_NEON
+#undef SAD_SKIP_WXH_4D_NEON