Implement Neon variance functions using UDOT instruction

Accelerate the Neon variance helper functions by computing the sum of
squares with a single Armv8.4-A UDOT (unsigned dot product) instruction
instead of four MLAs.
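
As a rough illustration (a sketch only; the helper names below are
invented for this message and do not appear in the patch), the
per-16-pixel sum-of-squares kernel replaces four MLAs:

    #include <arm_neon.h>

    // Before: widen the differences to 16 bits, then square and
    // accumulate with four signed multiply-accumulate instructions.
    static inline uint32x4_t sse_mla_sketch(uint32x4_t acc,
                                            uint8x16_t s, uint8x16_t r) {
      // Unsigned subtract-long wraps mod 2^16; reinterpreting as s16
      // recovers the signed difference since |s - r| <= 255.
      int16x8_t d_lo = vreinterpretq_s16_u16(
          vsubl_u8(vget_low_u8(s), vget_low_u8(r)));
      int16x8_t d_hi = vreinterpretq_s16_u16(
          vsubl_u8(vget_high_u8(s), vget_high_u8(r)));
      int32x4_t a = vreinterpretq_s32_u32(acc);
      a = vmlal_s16(a, vget_low_s16(d_lo), vget_low_s16(d_lo));
      a = vmlal_s16(a, vget_high_s16(d_lo), vget_high_s16(d_lo));
      a = vmlal_s16(a, vget_low_s16(d_hi), vget_low_s16(d_hi));
      a = vmlal_s16(a, vget_high_s16(d_hi), vget_high_s16(d_hi));
      return vreinterpretq_u32_s32(a);
    }

with a single dot product:

    // After: |s - r| fits in 8 bits, so one UDOT squares and
    // accumulates all 16 pixels. Requires a dot-product-enabled
    // target, e.g. -march=armv8.2-a+dotprod.
    static inline uint32x4_t sse_udot_sketch(uint32x4_t acc,
                                             uint8x16_t s, uint8x16_t r) {
      uint8x16_t abs_diff = vabdq_u8(s, r);
      return vdotq_u32(acc, abs_diff, abs_diff);
    }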

The previous implementation is retained for use on systems that do not
support the Armv8.4-A dot-product instructions.
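
The selection between the two paths happens at compile time via the
compiler's feature macro (a sketch of the guard structure used in the
patch below):

    /* __ARM_FEATURE_DOTPROD is predefined by the compiler when the
       dot-product extension is enabled, e.g. with
       -march=armv8.2-a+dotprod (optional from Armv8.2-A, mandatory
       from Armv8.4-A). */
    #if defined(__ARM_FEATURE_DOTPROD)
    /* UDOT-based helpers (added by this change) */
    #else
    /* existing MLA-based helpers */
    #endif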

Change-Id: I21689fd084c5ed9448fb1a630dc880ed45fac1df
diff --git a/aom_dsp/arm/variance_neon.c b/aom_dsp/arm/variance_neon.c
index 8505d1b..e8e3a53 100644
--- a/aom_dsp/arm/variance_neon.c
+++ b/aom_dsp/arm/variance_neon.c
@@ -18,6 +18,154 @@
 #include "aom/aom_integer.h"
 #include "aom_ports/mem.h"
 
+#if defined(__ARM_FEATURE_DOTPROD)
+
+static INLINE void variance_4xh_neon(const uint8_t *src, int src_stride,
+                                     const uint8_t *ref, int ref_stride, int h,
+                                     uint32_t *sse, int *sum) {
+  uint32x4_t src_sum = vdupq_n_u32(0);
+  uint32x4_t ref_sum = vdupq_n_u32(0);
+  uint32x4_t sse_u32 = vdupq_n_u32(0);
+
+  int i = 0;
+  do {
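+    // Load four rows of four pixels each into a single 16-byte vector.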
+    uint8x16_t s = load_unaligned_u8q(src, src_stride);
+    uint8x16_t r = load_unaligned_u8q(ref, ref_stride);
+
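+    // Dotting with a vector of ones widens and sums each group of four
+    // bytes, accumulating the per-plane pixel sums in a single step.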
+    src_sum = vdotq_u32(src_sum, s, vdupq_n_u8(1));
+    ref_sum = vdotq_u32(ref_sum, r, vdupq_n_u8(1));
+
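+    // abs(s - r) fits in 8 bits, so UDOT both squares and accumulates.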
+    uint8x16_t abs_diff = vabdq_u8(s, r);
+    sse_u32 = vdotq_u32(sse_u32, abs_diff, abs_diff);
+
+    src += 4 * src_stride;
+    ref += 4 * ref_stride;
+    i += 4;
+  } while (i < h);
+
+  int32x4_t sum_diff =
+      vsubq_s32(vreinterpretq_s32_u32(src_sum), vreinterpretq_s32_u32(ref_sum));
+  *sum = horizontal_add_s32x4(sum_diff);
+  *sse = horizontal_add_u32x4(sse_u32);
+}
+
+static INLINE void variance_8xh_neon(const uint8_t *src, int src_stride,
+                                     const uint8_t *ref, int ref_stride, int h,
+                                     uint32_t *sse, int *sum) {
+  uint32x4_t src_sum = vdupq_n_u32(0);
+  uint32x4_t ref_sum = vdupq_n_u32(0);
+  uint32x4_t sse_u32 = vdupq_n_u32(0);
+
+  int i = 0;
+  do {
+    uint8x16_t s = vcombine_u8(vld1_u8(src), vld1_u8(src + src_stride));
+    uint8x16_t r = vcombine_u8(vld1_u8(ref), vld1_u8(ref + ref_stride));
+
+    src_sum = vdotq_u32(src_sum, s, vdupq_n_u8(1));
+    ref_sum = vdotq_u32(ref_sum, r, vdupq_n_u8(1));
+
+    uint8x16_t abs_diff = vabdq_u8(s, r);
+    sse_u32 = vdotq_u32(sse_u32, abs_diff, abs_diff);
+
+    src += 2 * src_stride;
+    ref += 2 * ref_stride;
+    i += 2;
+  } while (i < h);
+
+  int32x4_t sum_diff =
+      vsubq_s32(vreinterpretq_s32_u32(src_sum), vreinterpretq_s32_u32(ref_sum));
+  *sum = horizontal_add_s32x4(sum_diff);
+  *sse = horizontal_add_u32x4(sse_u32);
+}
+
+static INLINE void variance_16xh_neon(const uint8_t *src, int src_stride,
+                                      const uint8_t *ref, int ref_stride, int h,
+                                      uint32_t *sse, int *sum) {
+  uint32x4_t src_sum = vdupq_n_u32(0);
+  uint32x4_t ref_sum = vdupq_n_u32(0);
+  uint32x4_t sse_u32 = vdupq_n_u32(0);
+
+  int i = 0;
+  do {
+    uint8x16_t s = vld1q_u8(src);
+    uint8x16_t r = vld1q_u8(ref);
+
+    src_sum = vdotq_u32(src_sum, s, vdupq_n_u8(1));
+    ref_sum = vdotq_u32(ref_sum, r, vdupq_n_u8(1));
+
+    uint8x16_t abs_diff = vabdq_u8(s, r);
+    sse_u32 = vdotq_u32(sse_u32, abs_diff, abs_diff);
+
+    src += src_stride;
+    ref += ref_stride;
+    i++;
+  } while (i < h);
+
+  int32x4_t sum_diff =
+      vsubq_s32(vreinterpretq_s32_u32(src_sum), vreinterpretq_s32_u32(ref_sum));
+  *sum = horizontal_add_s32x4(sum_diff);
+  *sse = horizontal_add_u32x4(sse_u32);
+}
+
+static INLINE void variance_large_neon(const uint8_t *src, int src_stride,
+                                       const uint8_t *ref, int ref_stride,
+                                       int w, int h, uint32_t *sse, int *sum) {
+  uint32x4_t src_sum = vdupq_n_u32(0);
+  uint32x4_t ref_sum = vdupq_n_u32(0);
+  uint32x4_t sse_u32 = vdupq_n_u32(0);
+
+  int i = 0;
+  do {
+    int j = 0;
+    do {
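+      // Consume the row 16 pixels at a time.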
+      uint8x16_t s = vld1q_u8(src + j);
+      uint8x16_t r = vld1q_u8(ref + j);
+
+      src_sum = vdotq_u32(src_sum, s, vdupq_n_u8(1));
+      ref_sum = vdotq_u32(ref_sum, r, vdupq_n_u8(1));
+
+      uint8x16_t abs_diff = vabdq_u8(s, r);
+      sse_u32 = vdotq_u32(sse_u32, abs_diff, abs_diff);
+
+      j += 16;
+    } while (j < w);
+
+    src += src_stride;
+    ref += ref_stride;
+    i++;
+  } while (i < h);
+
+  int32x4_t sum_diff =
+      vsubq_s32(vreinterpretq_s32_u32(src_sum), vreinterpretq_s32_u32(ref_sum));
+  *sum = horizontal_add_s32x4(sum_diff);
+  *sse = horizontal_add_u32x4(sse_u32);
+}
+
+static INLINE void variance_32xh_neon(const uint8_t *src, int src_stride,
+                                      const uint8_t *ref, int ref_stride, int h,
+                                      uint32_t *sse, int *sum) {
+  variance_large_neon(src, src_stride, ref, ref_stride, 32, h, sse, sum);
+}
+
+static INLINE void variance_64xh_neon(const uint8_t *src, int src_stride,
+                                      const uint8_t *ref, int ref_stride, int h,
+                                      uint32_t *sse, int *sum) {
+  variance_large_neon(src, src_stride, ref, ref_stride, 64, h, sse, sum);
+}
+
+static INLINE void variance_128xh_neon(const uint8_t *src, int src_stride,
+                                       const uint8_t *ref, int ref_stride,
+                                       int h, uint32_t *sse, int *sum) {
+  variance_large_neon(src, src_stride, ref, ref_stride, 128, h, sse, sum);
+}
+
+#else  // !defined(__ARM_FEATURE_DOTPROD)
+
 static INLINE void variance_4xh_neon(const uint8_t *src, int src_stride,
                                      const uint8_t *ref, int ref_stride, int h,
                                      uint32_t *sse, int *sum) {
@@ -194,6 +342,8 @@
   variance_large_neon(src, src_stride, ref, ref_stride, 128, h, 16, sse, sum);
 }
 
+#endif  // defined(__ARM_FEATURE_DOTPROD)
+
 #define VARIANCE_WXH_NEON(w, h, shift)                                        \
   unsigned int aom_variance##w##x##h##_neon(                                  \
       const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \