Implement aom_sse_neon helpers using UDOT instruction

Accelerate the aom_sse_neon helper functions using the Armv8.4-A UDOT
instruction instead of multiple UMULL and UADALP instructions.

The previous implementation is retained for builds that do not enable
the Armv8.4-A dot-product instructions (i.e. when
__ARM_FEATURE_DOTPROD is not defined at compile time).

Change-Id: Ia0c14a4bb5d95f68988d9d5d848f93b81bf7fa87
diff --git a/aom_dsp/arm/sse_neon.c b/aom_dsp/arm/sse_neon.c
index 767a7e9..4370146 100644
--- a/aom_dsp/arm/sse_neon.c
+++ b/aom_dsp/arm/sse_neon.c
@@ -16,6 +16,115 @@
 #include "aom_dsp/arm/sum_neon.h"
 #include "aom_dsp/arm/transpose_neon.h"
 
+#if defined(__ARM_FEATURE_DOTPROD)
+
+static INLINE void sse_16x1_neon(const uint8_t *src, const uint8_t *ref,
+                                 uint32x4_t *sse) {  // sse: in/out accumulator
+  uint8x16_t s = vld1q_u8(src);  // 16 bytes of src
+  uint8x16_t r = vld1q_u8(ref);  // 16 bytes of ref
+
+  uint8x16_t abs_diff = vabdq_u8(s, r);  // per-byte |s - r|
+
+  *sse = vdotq_u32(*sse, abs_diff, abs_diff);  // UDOT: += sums of abs_diff^2
+}
+
+static INLINE void sse_8x1_neon(const uint8_t *src, const uint8_t *ref,
+                                uint32x2_t *sse) {  // sse: in/out accumulator
+  uint8x8_t s = vld1_u8(src);  // 8 bytes of src
+  uint8x8_t r = vld1_u8(ref);  // 8 bytes of ref
+
+  uint8x8_t abs_diff = vabd_u8(s, r);  // per-byte |s - r|
+
+  *sse = vdot_u32(*sse, abs_diff, abs_diff);  // UDOT: += sums of abs_diff^2
+}
+
+static INLINE void sse_4x2_neon(const uint8_t *src, int src_stride,
+                                const uint8_t *ref, int ref_stride,
+                                uint32x2_t *sse) {  // sse: in/out accumulator
+  uint8x8_t s = load_unaligned_u8(src, src_stride);  // presumably 2 rows x 4
+  uint8x8_t r = load_unaligned_u8(ref, ref_stride);  // bytes; helper not shown
+
+  uint8x8_t abs_diff = vabd_u8(s, r);  // per-byte |s - r|
+
+  *sse = vdot_u32(*sse, abs_diff, abs_diff);  // UDOT: += sums of abs_diff^2
+}
+
+static INLINE uint32_t sse_8xh_neon(const uint8_t *src, int src_stride,
+                                    const uint8_t *ref, int ref_stride,
+                                    int height) {
+  uint32x2_t sse[2] = { vdup_n_u32(0), vdup_n_u32(0) };  // one acc per row
+
+  int i = 0;
+  do {  // rows are consumed in pairs; an odd height would read an extra row
+    sse_8x1_neon(src, ref, &sse[0]);  // first row of the pair
+    src += src_stride;
+    ref += ref_stride;
+    sse_8x1_neon(src, ref, &sse[1]);  // second row of the pair
+    src += src_stride;
+    ref += ref_stride;
+    i += 2;
+  } while (i < height);  // do-while: at least two rows are always processed
+
+  return horizontal_add_u32x4(vcombine_u32(sse[0], sse[1]));  // merge + reduce
+}
+
+static INLINE uint32_t sse_4xh_neon(const uint8_t *src, int src_stride,
+                                    const uint8_t *ref, int ref_stride,
+                                    int height) {
+  uint32x2_t sse = vdup_n_u32(0);  // running accumulator
+
+  int i = 0;
+  do {  // two rows per iteration; an odd height would read an extra row
+    sse_4x2_neon(src, src_stride, ref, ref_stride, &sse);
+
+    src += 2 * src_stride;
+    ref += 2 * ref_stride;
+    i += 2;
+  } while (i < height);  // do-while: at least two rows are always processed
+
+  return horizontal_add_u32x2(sse);  // reduce both lanes to a scalar
+}
+
+static INLINE uint32_t sse_wxh_neon(const uint8_t *src, int src_stride,
+                                    const uint8_t *ref, int ref_stride,
+                                    int width, int height) {
+  uint32x2_t sse[2] = { vdup_n_u32(0), vdup_n_u32(0) };  // one acc per row
+
+  if ((width & 0x07) && ((width & 0x07) < 5)) {  // width % 8 in 1..4
+    int i = 0;
+    do {  // two rows per iteration
+      int j = 0;
+      do {  // 8-wide steps; always runs once, even if width < 8
+        sse_8x1_neon(src + j, ref + j, &sse[0]);
+        sse_8x1_neon(src + j + src_stride, ref + j + ref_stride, &sse[1]);
+        j += 8;
+      } while (j + 4 < width);  // stop when at most 4 columns remain
+
+      sse_4x2_neon(src + j, src_stride, ref + j, ref_stride, &sse[0]);  // tail
+      src += 2 * src_stride;
+      ref += 2 * ref_stride;
+      i += 2;
+    } while (i < height);
+  } else {  // width % 8 is 0 or 5..7: 8-wide steps only
+    int i = 0;
+    do {  // two rows per iteration
+      int j = 0;
+      do {  // NOTE(review): covers columns past width unless width % 8 == 0
+        sse_8x1_neon(src + j, ref + j, &sse[0]);
+        sse_8x1_neon(src + j + src_stride, ref + j + ref_stride, &sse[1]);
+        j += 8;
+      } while (j < width);
+
+      src += 2 * src_stride;
+      ref += 2 * ref_stride;
+      i += 2;
+    } while (i < height);
+  }
+  return horizontal_add_u32x4(vcombine_u32(sse[0], sse[1]));  // merge + reduce
+}
+
+#else  // !defined(__ARM_FEATURE_DOTPROD)
+
 static INLINE void sse_16x1_neon(const uint8_t *src, const uint8_t *ref,
                                  uint32x4_t *sse) {
   uint8x16_t s = vld1q_u8(src);
@@ -50,85 +159,6 @@
   *sse = vpadalq_u16(*sse, vmull_u8(abs_diff, abs_diff));
 }
 
-static INLINE uint32_t sse_128xh_neon(const uint8_t *src, int src_stride,
-                                      const uint8_t *ref, int ref_stride,
-                                      int height) {
-  uint32x4_t sse = vdupq_n_u32(0);
-
-  int i = 0;
-  do {
-    sse_16x1_neon(src, ref, &sse);
-    sse_16x1_neon(src + 16, ref + 16, &sse);
-    sse_16x1_neon(src + 32, ref + 32, &sse);
-    sse_16x1_neon(src + 48, ref + 48, &sse);
-    sse_16x1_neon(src + 64, ref + 64, &sse);
-    sse_16x1_neon(src + 80, ref + 80, &sse);
-    sse_16x1_neon(src + 96, ref + 96, &sse);
-    sse_16x1_neon(src + 112, ref + 112, &sse);
-
-    src += src_stride;
-    ref += ref_stride;
-    i++;
-  } while (i < height);
-
-  return horizontal_add_u32x4(sse);
-}
-
-static INLINE uint32_t sse_64xh_neon(const uint8_t *src, int src_stride,
-                                     const uint8_t *ref, int ref_stride,
-                                     int height) {
-  uint32x4_t sse = vdupq_n_u32(0);
-
-  int i = 0;
-  do {
-    sse_16x1_neon(src, ref, &sse);
-    sse_16x1_neon(src + 16, ref + 16, &sse);
-    sse_16x1_neon(src + 32, ref + 32, &sse);
-    sse_16x1_neon(src + 48, ref + 48, &sse);
-
-    src += src_stride;
-    ref += ref_stride;
-    i++;
-  } while (i < height);
-
-  return horizontal_add_u32x4(sse);
-}
-
-static INLINE uint32_t sse_32xh_neon(const uint8_t *src, int src_stride,
-                                     const uint8_t *ref, int ref_stride,
-                                     int height) {
-  uint32x4_t sse = vdupq_n_u32(0);
-
-  int i = 0;
-  do {
-    sse_16x1_neon(src, ref, &sse);
-    sse_16x1_neon(src + 16, ref + 16, &sse);
-
-    src += src_stride;
-    ref += ref_stride;
-    i++;
-  } while (i < height);
-
-  return horizontal_add_u32x4(sse);
-}
-
-static INLINE uint32_t sse_16xh_neon(const uint8_t *src, int src_stride,
-                                     const uint8_t *ref, int ref_stride,
-                                     int height) {
-  uint32x4_t sse = vdupq_n_u32(0);
-
-  int i = 0;
-  do {
-    sse_16x1_neon(src, ref, &sse);
-
-    src += src_stride;
-    ref += ref_stride;
-    i++;
-  } while (i < height);
-
-  return horizontal_add_u32x4(sse);
-}
-
 static INLINE uint32_t sse_8xh_neon(const uint8_t *src, int src_stride,
                                     const uint8_t *ref, int ref_stride,
                                     int height) {
@@ -200,6 +230,89 @@
   return horizontal_add_u32x4(sse);
 }
 
+#endif  // defined(__ARM_FEATURE_DOTPROD)
+
+static INLINE uint32_t sse_128xh_neon(const uint8_t *src, int src_stride,
+                                      const uint8_t *ref, int ref_stride,
+                                      int height) {
+  uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };  // two accumulators
+
+  int i = 0;
+  do {  // one 128-byte row per iteration, as eight 16-byte chunks
+    sse_16x1_neon(src, ref, &sse[0]);  // chunks alternate between sse[0]
+    sse_16x1_neon(src + 16, ref + 16, &sse[1]);  // and sse[1]
+    sse_16x1_neon(src + 32, ref + 32, &sse[0]);
+    sse_16x1_neon(src + 48, ref + 48, &sse[1]);
+    sse_16x1_neon(src + 64, ref + 64, &sse[0]);
+    sse_16x1_neon(src + 80, ref + 80, &sse[1]);
+    sse_16x1_neon(src + 96, ref + 96, &sse[0]);
+    sse_16x1_neon(src + 112, ref + 112, &sse[1]);
+
+    src += src_stride;
+    ref += ref_stride;
+    i++;
+  } while (i < height);  // do-while: at least one row is always processed
+
+  return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1]));  // merge + reduce
+}
+
+static INLINE uint32_t sse_64xh_neon(const uint8_t *src, int src_stride,
+                                     const uint8_t *ref, int ref_stride,
+                                     int height) {
+  uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };  // two accumulators
+
+  int i = 0;
+  do {  // one 64-byte row per iteration, as four 16-byte chunks
+    sse_16x1_neon(src, ref, &sse[0]);  // chunks alternate between sse[0]
+    sse_16x1_neon(src + 16, ref + 16, &sse[1]);  // and sse[1]
+    sse_16x1_neon(src + 32, ref + 32, &sse[0]);
+    sse_16x1_neon(src + 48, ref + 48, &sse[1]);
+
+    src += src_stride;
+    ref += ref_stride;
+    i++;
+  } while (i < height);  // do-while: at least one row is always processed
+
+  return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1]));  // merge + reduce
+}
+
+static INLINE uint32_t sse_32xh_neon(const uint8_t *src, int src_stride,
+                                     const uint8_t *ref, int ref_stride,
+                                     int height) {
+  uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };  // two accumulators
+
+  int i = 0;
+  do {  // one 32-byte row per iteration, as two 16-byte chunks
+    sse_16x1_neon(src, ref, &sse[0]);  // low half of the row
+    sse_16x1_neon(src + 16, ref + 16, &sse[1]);  // high half of the row
+
+    src += src_stride;
+    ref += ref_stride;
+    i++;
+  } while (i < height);  // do-while: at least one row is always processed
+
+  return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1]));  // merge + reduce
+}
+
+static INLINE uint32_t sse_16xh_neon(const uint8_t *src, int src_stride,
+                                     const uint8_t *ref, int ref_stride,
+                                     int height) {
+  uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };  // one acc per row
+
+  int i = 0;
+  do {  // rows are consumed in pairs; an odd height would read an extra row
+    sse_16x1_neon(src, ref, &sse[0]);  // first row of the pair
+    src += src_stride;
+    ref += ref_stride;
+    sse_16x1_neon(src, ref, &sse[1]);  // second row of the pair
+    src += src_stride;
+    ref += ref_stride;
+    i += 2;
+  } while (i < height);  // do-while: at least two rows are always processed
+
+  return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1]));  // merge + reduce
+}
+
 int64_t aom_sse_neon(const uint8_t *src, int src_stride, const uint8_t *ref,
                      int ref_stride, int width, int height) {
   switch (width) {
diff --git a/aom_dsp/arm/sum_neon.h b/aom_dsp/arm/sum_neon.h
index 0cf110a..855edf6 100644
--- a/aom_dsp/arm/sum_neon.h
+++ b/aom_dsp/arm/sum_neon.h
@@ -94,6 +94,15 @@
 #endif
 }
 
+static INLINE uint32_t horizontal_add_u32x2(const uint32x2_t a) {
+#if defined(__aarch64__)
+  return vaddv_u32(a);  // across-vector add of both lanes
+#else
+  const uint64x1_t b = vpaddl_u32(a);  // widening pairwise add: 64-bit sum
+  return vget_lane_u32(vreinterpret_u32_u64(b), 0);  // keep 32 bits of the sum
+#endif
+}
+
 static INLINE uint32_t horizontal_add_u16x4(const uint16x4_t a) {
 #if defined(__aarch64__)
   return vaddlv_u16(a);