Refactor aom_sse_neon to use width-specific helper functions

Refactor Neon SSE function to use width-specific helper functions of
the form sse_<w>xh_neon. A faster Armv8.4-A dot-product
implementation of each helper function will be added in a subsequent
patch.

Change-Id: I5cca24099d709b9dddf39772e94aff83562f8f9a
diff --git a/aom_dsp/arm/sse_neon.c b/aom_dsp/arm/sse_neon.c
index a69dfb5..767a7e9 100644
--- a/aom_dsp/arm/sse_neon.c
+++ b/aom_dsp/arm/sse_neon.c
@@ -16,141 +16,202 @@
 #include "aom_dsp/arm/sum_neon.h"
 #include "aom_dsp/arm/transpose_neon.h"
 
-static INLINE void sse_w16_neon(uint32x4_t *sum, const uint8_t *a,
-                                const uint8_t *b) {
-  const uint8x16_t v_a0 = vld1q_u8(a);
-  const uint8x16_t v_b0 = vld1q_u8(b);
-  const uint8x16_t diff = vabdq_u8(v_a0, v_b0);
-  const uint8x8_t diff_lo = vget_low_u8(diff);
-  const uint8x8_t diff_hi = vget_high_u8(diff);
-  *sum = vpadalq_u16(*sum, vmull_u8(diff_lo, diff_lo));
-  *sum = vpadalq_u16(*sum, vmull_u8(diff_hi, diff_hi));
+static INLINE void sse_16x1_neon(const uint8_t *src, const uint8_t *ref,
+                                 uint32x4_t *sse) {
+  uint8x16_t s = vld1q_u8(src);
+  uint8x16_t r = vld1q_u8(ref);
+
+  uint8x16_t abs_diff = vabdq_u8(s, r);
+  uint8x8_t abs_diff_lo = vget_low_u8(abs_diff);
+  uint8x8_t abs_diff_hi = vget_high_u8(abs_diff);
+
+  *sse = vpadalq_u16(*sse, vmull_u8(abs_diff_lo, abs_diff_lo));
+  *sse = vpadalq_u16(*sse, vmull_u8(abs_diff_hi, abs_diff_hi));
 }
-static INLINE void aom_sse4x2_neon(const uint8_t *a, int a_stride,
-                                   const uint8_t *b, int b_stride,
-                                   uint32x4_t *sum) {
-  uint8x8_t v_a0, v_b0;
-  v_a0 = v_b0 = vcreate_u8(0);
-  // above line is only to shadow [-Werror=uninitialized]
-  v_a0 = vreinterpret_u8_u32(
-      vld1_lane_u32((uint32_t *)a, vreinterpret_u32_u8(v_a0), 0));
-  v_a0 = vreinterpret_u8_u32(
-      vld1_lane_u32((uint32_t *)(a + a_stride), vreinterpret_u32_u8(v_a0), 1));
-  v_b0 = vreinterpret_u8_u32(
-      vld1_lane_u32((uint32_t *)b, vreinterpret_u32_u8(v_b0), 0));
-  v_b0 = vreinterpret_u8_u32(
-      vld1_lane_u32((uint32_t *)(b + b_stride), vreinterpret_u32_u8(v_b0), 1));
-  const uint8x8_t v_a_w = vabd_u8(v_a0, v_b0);
-  *sum = vpadalq_u16(*sum, vmull_u8(v_a_w, v_a_w));
+
+static INLINE void sse_8x1_neon(const uint8_t *src, const uint8_t *ref,
+                                uint32x4_t *sse) {
+  uint8x8_t s = vld1_u8(src);
+  uint8x8_t r = vld1_u8(ref);
+
+  uint8x8_t abs_diff = vabd_u8(s, r);
+
+  *sse = vpadalq_u16(*sse, vmull_u8(abs_diff, abs_diff));
 }
-static INLINE void aom_sse8_neon(const uint8_t *a, const uint8_t *b,
-                                 uint32x4_t *sum) {
-  const uint8x8_t v_a_w = vld1_u8(a);
-  const uint8x8_t v_b_w = vld1_u8(b);
-  const uint8x8_t v_d_w = vabd_u8(v_a_w, v_b_w);
-  *sum = vpadalq_u16(*sum, vmull_u8(v_d_w, v_d_w));
+
+static INLINE void sse_4x2_neon(const uint8_t *src, int src_stride,
+                                const uint8_t *ref, int ref_stride,
+                                uint32x4_t *sse) {
+  uint8x8_t s = load_unaligned_u8(src, src_stride);
+  uint8x8_t r = load_unaligned_u8(ref, ref_stride);
+
+  uint8x8_t abs_diff = vabd_u8(s, r);
+
+  *sse = vpadalq_u16(*sse, vmull_u8(abs_diff, abs_diff));
 }
-int64_t aom_sse_neon(const uint8_t *a, int a_stride, const uint8_t *b,
-                     int b_stride, int width, int height) {
-  int y = 0;
-  int64_t sse = 0;
-  uint32x4_t sum = vdupq_n_u32(0);
-  switch (width) {
-    case 4:
+
+static INLINE uint32_t sse_128xh_neon(const uint8_t *src, int src_stride,
+                                      const uint8_t *ref, int ref_stride,
+                                      int height) {
+  uint32x4_t sse = vdupq_n_u32(0);
+
+  int i = 0;
+  do {
+    sse_16x1_neon(src, ref, &sse);
+    sse_16x1_neon(src + 16, ref + 16, &sse);
+    sse_16x1_neon(src + 32, ref + 32, &sse);
+    sse_16x1_neon(src + 48, ref + 48, &sse);
+    sse_16x1_neon(src + 64, ref + 64, &sse);
+    sse_16x1_neon(src + 80, ref + 80, &sse);
+    sse_16x1_neon(src + 96, ref + 96, &sse);
+    sse_16x1_neon(src + 112, ref + 112, &sse);
+
+    src += src_stride;
+    ref += ref_stride;
+    i++;
+  } while (i < height);
+
+  return horizontal_add_u32x4(sse);
+}
+
+static INLINE uint32_t sse_64xh_neon(const uint8_t *src, int src_stride,
+                                     const uint8_t *ref, int ref_stride,
+                                     int height) {
+  uint32x4_t sse = vdupq_n_u32(0);
+
+  int i = 0;
+  do {
+    sse_16x1_neon(src, ref, &sse);
+    sse_16x1_neon(src + 16, ref + 16, &sse);
+    sse_16x1_neon(src + 32, ref + 32, &sse);
+    sse_16x1_neon(src + 48, ref + 48, &sse);
+
+    src += src_stride;
+    ref += ref_stride;
+    i++;
+  } while (i < height);
+
+  return horizontal_add_u32x4(sse);
+}
+
+static INLINE uint32_t sse_32xh_neon(const uint8_t *src, int src_stride,
+                                     const uint8_t *ref, int ref_stride,
+                                     int height) {
+  uint32x4_t sse = vdupq_n_u32(0);
+
+  int i = 0;
+  do {
+    sse_16x1_neon(src, ref, &sse);
+    sse_16x1_neon(src + 16, ref + 16, &sse);
+
+    src += src_stride;
+    ref += ref_stride;
+    i++;
+  } while (i < height);
+
+  return horizontal_add_u32x4(sse);
+}
+
+static INLINE uint32_t sse_16xh_neon(const uint8_t *src, int src_stride,
+                                     const uint8_t *ref, int ref_stride,
+                                     int height) {
+  uint32x4_t sse = vdupq_n_u32(0);
+
+  int i = 0;
+  do {
+    sse_16x1_neon(src, ref, &sse);
+
+    src += src_stride;
+    ref += ref_stride;
+    i++;
+  } while (i < height);
+
+  return horizontal_add_u32x4(sse);
+}
+
+static INLINE uint32_t sse_8xh_neon(const uint8_t *src, int src_stride,
+                                    const uint8_t *ref, int ref_stride,
+                                    int height) {
+  uint32x4_t sse = vdupq_n_u32(0);
+
+  int i = 0;
+  do {
+    sse_8x1_neon(src, ref, &sse);
+
+    src += src_stride;
+    ref += ref_stride;
+    i++;
+  } while (i < height);
+
+  return horizontal_add_u32x4(sse);
+}
+
+static INLINE uint32_t sse_4xh_neon(const uint8_t *src, int src_stride,
+                                    const uint8_t *ref, int ref_stride,
+                                    int height) {
+  uint32x4_t sse = vdupq_n_u32(0);
+
+  int i = 0;
+  do {
+    sse_4x2_neon(src, src_stride, ref, ref_stride, &sse);
+
+    src += 2 * src_stride;
+    ref += 2 * ref_stride;
+    i += 2;
+  } while (i < height);
+
+  return horizontal_add_u32x4(sse);
+}
+
+static INLINE uint32_t sse_wxh_neon(const uint8_t *src, int src_stride,
+                                    const uint8_t *ref, int ref_stride,
+                                    int width, int height) {
+  uint32x4_t sse = vdupq_n_u32(0);
+
+  if ((width & 0x07) && ((width & 0x07) < 5)) {
+    int i = 0;
+    do {
+      int j = 0;
       do {
-        aom_sse4x2_neon(a, a_stride, b, b_stride, &sum);
-        a += a_stride << 1;
-        b += b_stride << 1;
-        y += 2;
-      } while (y < height);
-      sse = horizontal_add_s32x4(vreinterpretq_s32_u32(sum));
-      break;
-    case 8:
+        sse_8x1_neon(src + j, ref + j, &sse);
+        sse_8x1_neon(src + j + src_stride, ref + j + ref_stride, &sse);
+        j += 8;
+      } while (j + 4 < width);
+
+      sse_4x2_neon(src + j, src_stride, ref + j, ref_stride, &sse);
+      src += 2 * src_stride;
+      ref += 2 * ref_stride;
+      i += 2;
+    } while (i < height);
+  } else {
+    int i = 0;
+    do {
+      int j = 0;
       do {
-        aom_sse8_neon(a, b, &sum);
-        a += a_stride;
-        b += b_stride;
-        y += 1;
-      } while (y < height);
-      sse = horizontal_add_s32x4(vreinterpretq_s32_u32(sum));
-      break;
-    case 16:
-      do {
-        sse_w16_neon(&sum, a, b);
-        a += a_stride;
-        b += b_stride;
-        y += 1;
-      } while (y < height);
-      sse = horizontal_add_s32x4(vreinterpretq_s32_u32(sum));
-      break;
-    case 32:
-      do {
-        sse_w16_neon(&sum, a, b);
-        sse_w16_neon(&sum, a + 16, b + 16);
-        a += a_stride;
-        b += b_stride;
-        y += 1;
-      } while (y < height);
-      sse = horizontal_add_s32x4(vreinterpretq_s32_u32(sum));
-      break;
-    case 64:
-      do {
-        sse_w16_neon(&sum, a, b);
-        sse_w16_neon(&sum, a + 16 * 1, b + 16 * 1);
-        sse_w16_neon(&sum, a + 16 * 2, b + 16 * 2);
-        sse_w16_neon(&sum, a + 16 * 3, b + 16 * 3);
-        a += a_stride;
-        b += b_stride;
-        y += 1;
-      } while (y < height);
-      sse = horizontal_add_s32x4(vreinterpretq_s32_u32(sum));
-      break;
-    case 128:
-      do {
-        sse_w16_neon(&sum, a, b);
-        sse_w16_neon(&sum, a + 16 * 1, b + 16 * 1);
-        sse_w16_neon(&sum, a + 16 * 2, b + 16 * 2);
-        sse_w16_neon(&sum, a + 16 * 3, b + 16 * 3);
-        sse_w16_neon(&sum, a + 16 * 4, b + 16 * 4);
-        sse_w16_neon(&sum, a + 16 * 5, b + 16 * 5);
-        sse_w16_neon(&sum, a + 16 * 6, b + 16 * 6);
-        sse_w16_neon(&sum, a + 16 * 7, b + 16 * 7);
-        a += a_stride;
-        b += b_stride;
-        y += 1;
-      } while (y < height);
-      sse = horizontal_add_s32x4(vreinterpretq_s32_u32(sum));
-      break;
-    default:
-      if (width & 0x07) {
-        do {
-          int i = 0;
-          do {
-            aom_sse8_neon(a + i, b + i, &sum);
-            aom_sse8_neon(a + i + a_stride, b + i + b_stride, &sum);
-            i += 8;
-          } while (i + 4 < width);
-          aom_sse4x2_neon(a + i, a_stride, b + i, b_stride, &sum);
-          a += (a_stride << 1);
-          b += (b_stride << 1);
-          y += 2;
-        } while (y < height);
-      } else {
-        do {
-          int i = 0;
-          do {
-            aom_sse8_neon(a + i, b + i, &sum);
-            i += 8;
-          } while (i < width);
-          a += a_stride;
-          b += b_stride;
-          y += 1;
-        } while (y < height);
-      }
-      sse = horizontal_add_s32x4(vreinterpretq_s32_u32(sum));
-      break;
+        sse_8x1_neon(src + j, ref + j, &sse);
+        j += 8;
+      } while (j < width);
+
+      src += src_stride;
+      ref += ref_stride;
+      i++;
+    } while (i < height);
   }
-  return sse;
+  return horizontal_add_u32x4(sse);
+}
+
+int64_t aom_sse_neon(const uint8_t *src, int src_stride, const uint8_t *ref,
+                     int ref_stride, int width, int height) {
+  switch (width) {
+    case 4: return sse_4xh_neon(src, src_stride, ref, ref_stride, height);
+    case 8: return sse_8xh_neon(src, src_stride, ref, ref_stride, height);
+    case 16: return sse_16xh_neon(src, src_stride, ref, ref_stride, height);
+    case 32: return sse_32xh_neon(src, src_stride, ref, ref_stride, height);
+    case 64: return sse_64xh_neon(src, src_stride, ref, ref_stride, height);
+    case 128: return sse_128xh_neon(src, src_stride, ref, ref_stride, height);
+    default:
+      return sse_wxh_neon(src, src_stride, ref, ref_stride, width, height);
+  }
 }
 
 #if CONFIG_AV1_HIGHBITDEPTH