Add Neon implementations of SAD3D functions

Add Armv8.0 and Armv8.4 (dot-product) Neon implementations of
aom_sad<w>x<h>x3d functions, as well as the corresponding tests.
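
The Armv8.4 path uses the UDOT instruction (vdotq_u32 against a vector
of ones) to accumulate absolute differences directly into 32-bit
lanes. The Armv8.0 path accumulates into 16-bit lanes and widens to
32 bits periodically; for the large block sizes the widening interval
is chosen so that the 16-bit accumulators cannot overflow.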

Change-Id: Ia62c4d6cb43d620a48a98c6c2e8c817ebad26f56
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index 2feed0b..19925d5 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -1046,21 +1046,29 @@
   specialize qw/aom_sad_skip_4x8x4d          sse2 neon/;
   specialize qw/aom_sad_skip_4x4x4d               neon/;
 
-  specialize qw/aom_sad128x128x3d avx2/;
-  specialize qw/aom_sad128x64x3d  avx2/;
-  specialize qw/aom_sad64x128x3d  avx2/;
-  specialize qw/aom_sad64x64x3d   avx2/;
-  specialize qw/aom_sad64x32x3d   avx2/;
-  specialize qw/aom_sad32x64x3d   avx2/;
-  specialize qw/aom_sad32x32x3d   avx2/;
-  specialize qw/aom_sad32x16x3d   avx2/;
-  specialize qw/aom_sad16x32x3d   avx2/;
-  specialize qw/aom_sad16x16x3d   avx2/;
-  specialize qw/aom_sad16x8x3d    avx2/;
+  specialize qw/aom_sad128x128x3d neon avx2/;
+  specialize qw/aom_sad128x64x3d  neon avx2/;
+  specialize qw/aom_sad64x128x3d  neon avx2/;
+  specialize qw/aom_sad64x64x3d   neon avx2/;
+  specialize qw/aom_sad64x32x3d   neon avx2/;
+  specialize qw/aom_sad32x64x3d   neon avx2/;
+  specialize qw/aom_sad32x32x3d   neon avx2/;
+  specialize qw/aom_sad32x16x3d   neon avx2/;
+  specialize qw/aom_sad16x32x3d   neon avx2/;
+  specialize qw/aom_sad16x16x3d   neon avx2/;
+  specialize qw/aom_sad16x8x3d    neon avx2/;
+  specialize qw/aom_sad8x16x3d    neon/;
+  specialize qw/aom_sad8x8x3d     neon/;
+  specialize qw/aom_sad8x4x3d     neon/;
+  specialize qw/aom_sad4x8x3d     neon/;
+  specialize qw/aom_sad4x4x3d     neon/;
 
-  specialize qw/aom_sad64x16x3d   avx2/;
-  specialize qw/aom_sad32x8x3d    avx2/;
-  specialize qw/aom_sad16x64x3d   avx2/;
+  specialize qw/aom_sad64x16x3d   neon avx2/;
+  specialize qw/aom_sad32x8x3d    neon avx2/;
+  specialize qw/aom_sad16x64x3d   neon avx2/;
+  specialize qw/aom_sad16x4x3d    neon/;
+  specialize qw/aom_sad8x32x3d    neon/;
+  specialize qw/aom_sad4x16x3d    neon/;
 
   specialize qw/aom_masked_sad128x128x4d  ssse3/;
   specialize qw/aom_masked_sad128x64x4d   ssse3/;
diff --git a/aom_dsp/arm/sadxd_neon.c b/aom_dsp/arm/sadxd_neon.c
index a3c02c5..81803b1 100644
--- a/aom_dsp/arm/sadxd_neon.c
+++ b/aom_dsp/arm/sadxd_neon.c
@@ -26,6 +26,321 @@
   *sad_sum = vdotq_u32(*sad_sum, abs_diff, vdupq_n_u8(1));
 }
 
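+// Sum the absolute differences against the first three reference blocks,
+// 32 source bytes per inner iteration; sad16_neon() uses UDOT to accumulate
+// each 16-byte chunk of absolute differences directly into 32-bit lanes.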
+static INLINE void sadwxhx3d_large_neon(const uint8_t *src, int src_stride,
+                                        const uint8_t *const ref[3],
+                                        int ref_stride, uint32_t res[3], int w,
+                                        int h) {
+  uint32x4_t sum_lo[3] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0) };
+  uint32x4_t sum_hi[3] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0) };
+
+  int ref_offset = 0;
+  int i = h;
+  do {
+    int j = 0;
+    do {
+      const uint8x16_t s0 = vld1q_u8(src + j);
+      sad16_neon(s0, vld1q_u8(ref[0] + ref_offset + j), &sum_lo[0]);
+      sad16_neon(s0, vld1q_u8(ref[1] + ref_offset + j), &sum_lo[1]);
+      sad16_neon(s0, vld1q_u8(ref[2] + ref_offset + j), &sum_lo[2]);
+
+      const uint8x16_t s1 = vld1q_u8(src + j + 16);
+      sad16_neon(s1, vld1q_u8(ref[0] + ref_offset + j + 16), &sum_hi[0]);
+      sad16_neon(s1, vld1q_u8(ref[1] + ref_offset + j + 16), &sum_hi[1]);
+      sad16_neon(s1, vld1q_u8(ref[2] + ref_offset + j + 16), &sum_hi[2]);
+
+      j += 32;
+    } while (j < w);
+
+    src += src_stride;
+    ref_offset += ref_stride;
+  } while (--i != 0);
+
+  res[0] = horizontal_add_u32x4(vaddq_u32(sum_lo[0], sum_hi[0]));
+  res[1] = horizontal_add_u32x4(vaddq_u32(sum_lo[1], sum_hi[1]));
+  res[2] = horizontal_add_u32x4(vaddq_u32(sum_lo[2], sum_hi[2]));
+}
+
+static INLINE void sad128xhx3d_neon(const uint8_t *src, int src_stride,
+                                    const uint8_t *const ref[3], int ref_stride,
+                                    uint32_t res[3], int h) {
+  sadwxhx3d_large_neon(src, src_stride, ref, ref_stride, res, 128, h);
+}
+
+static INLINE void sad64xhx3d_neon(const uint8_t *src, int src_stride,
+                                   const uint8_t *const ref[3], int ref_stride,
+                                   uint32_t res[3], int h) {
+  sadwxhx3d_large_neon(src, src_stride, ref, ref_stride, res, 64, h);
+}
+
+static INLINE void sad32xhx3d_neon(const uint8_t *src, int src_stride,
+                                   const uint8_t *const ref[3], int ref_stride,
+                                   uint32_t res[3], int h) {
+  sadwxhx3d_large_neon(src, src_stride, ref, ref_stride, res, 32, h);
+}
+
+static INLINE void sad16xhx3d_neon(const uint8_t *src, int src_stride,
+                                   const uint8_t *const ref[3], int ref_stride,
+                                   uint32_t res[3], int h) {
+  uint32x4_t sum[3] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0) };
+
+  int ref_offset = 0;
+  int i = h;
+  do {
+    const uint8x16_t s = vld1q_u8(src);
+    sad16_neon(s, vld1q_u8(ref[0] + ref_offset), &sum[0]);
+    sad16_neon(s, vld1q_u8(ref[1] + ref_offset), &sum[1]);
+    sad16_neon(s, vld1q_u8(ref[2] + ref_offset), &sum[2]);
+
+    src += src_stride;
+    ref_offset += ref_stride;
+  } while (--i != 0);
+
+  res[0] = horizontal_add_u32x4(sum[0]);
+  res[1] = horizontal_add_u32x4(sum[1]);
+  res[2] = horizontal_add_u32x4(sum[2]);
+}
+
+#else  // !(defined(__ARM_FEATURE_DOTPROD))
+
+static INLINE void sad16_neon(uint8x16_t src, uint8x16_t ref,
+                              uint16x8_t *const sad_sum) {
+  uint8x16_t abs_diff = vabdq_u8(src, ref);
+  *sad_sum = vpadalq_u8(*sad_sum, abs_diff);
+}
+
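+// Accumulate absolute differences into 16-bit lanes, widening to 32 bits
+// every h_overflow rows, the most a uint16_t lane can absorb without risk of
+// overflow (for w == 128: 32 rows * 4 chunks * 2 * 255 = 65280 <= UINT16_MAX).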
+static INLINE void sadwxhx3d_large_neon(const uint8_t *src, int src_stride,
+                                        const uint8_t *const ref[3],
+                                        int ref_stride, uint32_t res[3], int w,
+                                        int h, int h_overflow) {
+  uint32x4_t sum[3] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0) };
+  int h_limit = h > h_overflow ? h_overflow : h;
+
+  int ref_offset = 0;
+  int i = 0;
+  do {
+    uint16x8_t sum_lo[3] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0) };
+    uint16x8_t sum_hi[3] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0) };
+
+    do {
+      int j = 0;
+      do {
+        const uint8x16_t s0 = vld1q_u8(src + j);
+        sad16_neon(s0, vld1q_u8(ref[0] + ref_offset + j), &sum_lo[0]);
+        sad16_neon(s0, vld1q_u8(ref[1] + ref_offset + j), &sum_lo[1]);
+        sad16_neon(s0, vld1q_u8(ref[2] + ref_offset + j), &sum_lo[2]);
+
+        const uint8x16_t s1 = vld1q_u8(src + j + 16);
+        sad16_neon(s1, vld1q_u8(ref[0] + ref_offset + j + 16), &sum_hi[0]);
+        sad16_neon(s1, vld1q_u8(ref[1] + ref_offset + j + 16), &sum_hi[1]);
+        sad16_neon(s1, vld1q_u8(ref[2] + ref_offset + j + 16), &sum_hi[2]);
+
+        j += 32;
+      } while (j < w);
+
+      src += src_stride;
+      ref_offset += ref_stride;
+    } while (++i < h_limit);
+
+    sum[0] = vpadalq_u16(sum[0], sum_lo[0]);
+    sum[0] = vpadalq_u16(sum[0], sum_hi[0]);
+    sum[1] = vpadalq_u16(sum[1], sum_lo[1]);
+    sum[1] = vpadalq_u16(sum[1], sum_hi[1]);
+    sum[2] = vpadalq_u16(sum[2], sum_lo[2]);
+    sum[2] = vpadalq_u16(sum[2], sum_hi[2]);
+
+    h_limit += h_overflow;
+  } while (i < h);
+
+  res[0] = horizontal_add_u32x4(sum[0]);
+  res[1] = horizontal_add_u32x4(sum[1]);
+  res[2] = horizontal_add_u32x4(sum[2]);
+}
+
+static INLINE void sad128xhx3d_neon(const uint8_t *src, int src_stride,
+                                    const uint8_t *const ref[3], int ref_stride,
+                                    uint32_t res[3], int h) {
+  sadwxhx3d_large_neon(src, src_stride, ref, ref_stride, res, 128, h, 32);
+}
+
+static INLINE void sad64xhx3d_neon(const uint8_t *src, int src_stride,
+                                   const uint8_t *const ref[3], int ref_stride,
+                                   uint32_t res[3], int h) {
+  sadwxhx3d_large_neon(src, src_stride, ref, ref_stride, res, 64, h, 64);
+}
+
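+// 32-wide blocks are at most 64 rows tall, so the 16-bit accumulators cannot
+// overflow (64 rows * 2 * 255 = 32640) and a single widening pass suffices.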
+static INLINE void sad32xhx3d_neon(const uint8_t *src, int src_stride,
+                                   const uint8_t *const ref[3], int ref_stride,
+                                   uint32_t res[3], int h) {
+  uint16x8_t sum_lo[3] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0) };
+  uint16x8_t sum_hi[3] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0) };
+
+  int ref_offset = 0;
+  int i = h;
+  do {
+    const uint8x16_t s0 = vld1q_u8(src);
+    sad16_neon(s0, vld1q_u8(ref[0] + ref_offset), &sum_lo[0]);
+    sad16_neon(s0, vld1q_u8(ref[1] + ref_offset), &sum_lo[1]);
+    sad16_neon(s0, vld1q_u8(ref[2] + ref_offset), &sum_lo[2]);
+
+    const uint8x16_t s1 = vld1q_u8(src + 16);
+    sad16_neon(s1, vld1q_u8(ref[0] + ref_offset + 16), &sum_hi[0]);
+    sad16_neon(s1, vld1q_u8(ref[1] + ref_offset + 16), &sum_hi[1]);
+    sad16_neon(s1, vld1q_u8(ref[2] + ref_offset + 16), &sum_hi[2]);
+
+    src += src_stride;
+    ref_offset += ref_stride;
+  } while (--i != 0);
+
+  res[0] = horizontal_long_add_u16x8(sum_lo[0], sum_hi[0]);
+  res[1] = horizontal_long_add_u16x8(sum_lo[1], sum_hi[1]);
+  res[2] = horizontal_long_add_u16x8(sum_lo[2], sum_hi[2]);
+}
+
+static INLINE void sad16xhx3d_neon(const uint8_t *src, int src_stride,
+                                   const uint8_t *const ref[3], int ref_stride,
+                                   uint32_t res[3], int h) {
+  uint16x8_t sum[3] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0) };
+
+  int ref_offset = 0;
+  int i = h;
+  do {
+    const uint8x16_t s = vld1q_u8(src);
+    sad16_neon(s, vld1q_u8(ref[0] + ref_offset), &sum[0]);
+    sad16_neon(s, vld1q_u8(ref[1] + ref_offset), &sum[1]);
+    sad16_neon(s, vld1q_u8(ref[2] + ref_offset), &sum[2]);
+
+    src += src_stride;
+    ref_offset += ref_stride;
+  } while (--i != 0);
+
+  res[0] = horizontal_add_u16x8(sum[0]);
+  res[1] = horizontal_add_u16x8(sum[1]);
+  res[2] = horizontal_add_u16x8(sum[2]);
+}
+
+#endif  // defined(__ARM_FEATURE_DOTPROD)
+
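+// The first row initializes the three accumulators directly via vabdl_u8,
+// avoiding an explicit zeroing step.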
+static INLINE void sad8xhx3d_neon(const uint8_t *src, int src_stride,
+                                  const uint8_t *const ref[3], int ref_stride,
+                                  uint32_t res[3], int h) {
+  uint16x8_t sum[3];
+
+  uint8x8_t s = vld1_u8(src);
+  sum[0] = vabdl_u8(s, vld1_u8(ref[0]));
+  sum[1] = vabdl_u8(s, vld1_u8(ref[1]));
+  sum[2] = vabdl_u8(s, vld1_u8(ref[2]));
+
+  src += src_stride;
+  int ref_offset = ref_stride;
+  int i = h - 1;
+  do {
+    s = vld1_u8(src);
+    sum[0] = vabal_u8(sum[0], s, vld1_u8(ref[0] + ref_offset));
+    sum[1] = vabal_u8(sum[1], s, vld1_u8(ref[1] + ref_offset));
+    sum[2] = vabal_u8(sum[2], s, vld1_u8(ref[2] + ref_offset));
+
+    src += src_stride;
+    ref_offset += ref_stride;
+  } while (--i != 0);
+
+  res[0] = horizontal_add_u16x8(sum[0]);
+  res[1] = horizontal_add_u16x8(sum[1]);
+  res[2] = horizontal_add_u16x8(sum[2]);
+}
+
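+// Process two rows per iteration: load_unaligned_u8() packs two 4-byte rows
+// into a single 8-byte vector, hence the assertion that h is even.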
+static INLINE void sad4xhx3d_neon(const uint8_t *src, int src_stride,
+                                  const uint8_t *const ref[3], int ref_stride,
+                                  uint32_t res[3], int h) {
+  assert(h % 2 == 0);
+  uint16x8_t sum[3];
+
+  uint8x8_t s = load_unaligned_u8(src, src_stride);
+  uint8x8_t r0 = load_unaligned_u8(ref[0], ref_stride);
+  uint8x8_t r1 = load_unaligned_u8(ref[1], ref_stride);
+  uint8x8_t r2 = load_unaligned_u8(ref[2], ref_stride);
+
+  sum[0] = vabdl_u8(s, r0);
+  sum[1] = vabdl_u8(s, r1);
+  sum[2] = vabdl_u8(s, r2);
+
+  src += 2 * src_stride;
+  int ref_offset = 2 * ref_stride;
+  int i = (h / 2) - 1;
+  do {
+    s = load_unaligned_u8(src, src_stride);
+    r0 = load_unaligned_u8(ref[0] + ref_offset, ref_stride);
+    r1 = load_unaligned_u8(ref[1] + ref_offset, ref_stride);
+    r2 = load_unaligned_u8(ref[2] + ref_offset, ref_stride);
+
+    sum[0] = vabal_u8(sum[0], s, r0);
+    sum[1] = vabal_u8(sum[1], s, r1);
+    sum[2] = vabal_u8(sum[2], s, r2);
+
+    src += 2 * src_stride;
+    ref_offset += 2 * ref_stride;
+  } while (--i != 0);
+
+  res[0] = horizontal_add_u16x8(sum[0]);
+  res[1] = horizontal_add_u16x8(sum[1]);
+  res[2] = horizontal_add_u16x8(sum[2]);
+}
+
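+// Note that ref[] and res[] are declared with four entries to match the x4d
+// prototype, but the x3d kernels only use the first three.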
+#define SAD_WXH_3D_NEON(w, h)                                                  \
+  void aom_sad##w##x##h##x3d_neon(const uint8_t *src, int src_stride,          \
+                                  const uint8_t *const ref[4], int ref_stride, \
+                                  uint32_t res[4]) {                           \
+    sad##w##xhx3d_neon(src, src_stride, ref, ref_stride, res, (h));            \
+  }
+
+SAD_WXH_3D_NEON(4, 4)
+SAD_WXH_3D_NEON(4, 8)
+
+SAD_WXH_3D_NEON(8, 4)
+SAD_WXH_3D_NEON(8, 8)
+SAD_WXH_3D_NEON(8, 16)
+
+SAD_WXH_3D_NEON(16, 8)
+SAD_WXH_3D_NEON(16, 16)
+SAD_WXH_3D_NEON(16, 32)
+
+SAD_WXH_3D_NEON(32, 16)
+SAD_WXH_3D_NEON(32, 32)
+SAD_WXH_3D_NEON(32, 64)
+
+SAD_WXH_3D_NEON(64, 32)
+SAD_WXH_3D_NEON(64, 64)
+SAD_WXH_3D_NEON(64, 128)
+
+SAD_WXH_3D_NEON(128, 64)
+SAD_WXH_3D_NEON(128, 128)
+
+#if !CONFIG_REALTIME_ONLY
+SAD_WXH_3D_NEON(4, 16)
+SAD_WXH_3D_NEON(8, 32)
+SAD_WXH_3D_NEON(16, 4)
+SAD_WXH_3D_NEON(16, 64)
+SAD_WXH_3D_NEON(32, 8)
+SAD_WXH_3D_NEON(64, 16)
+#endif  // !CONFIG_REALTIME_ONLY
+
+#undef SAD_WXH_3D_NEON
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
 static INLINE void sadwxhx4d_large_neon(const uint8_t *src, int src_stride,
                                         const uint8_t *const ref[4],
                                         int ref_stride, uint32_t res[4], int w,
@@ -110,12 +425,6 @@
 
 #else  // !(defined(__ARM_FEATURE_DOTPROD))
 
-static INLINE void sad16_neon(uint8x16_t src, uint8x16_t ref,
-                              uint16x8_t *const sad_sum) {
-  uint8x16_t abs_diff = vabdq_u8(src, ref);
-  *sad_sum = vpadalq_u8(*sad_sum, abs_diff);
-}
-
 static INLINE void sadwxhx4d_large_neon(const uint8_t *src, int src_stride,
                                         const uint8_t *const ref[4],
                                         int ref_stride, uint32_t res[4], int w,
diff --git a/test/sad_test.cc b/test/sad_test.cc
index c10e929..2740305 100644
--- a/test/sad_test.cc
+++ b/test/sad_test.cc
@@ -1958,6 +1958,35 @@
 };
 INSTANTIATE_TEST_SUITE_P(NEON, SADavgTest, ::testing::ValuesIn(avg_neon_tests));
 
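+// Each entry is (width, height, function, bit depth); -1 means 8-bit.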
+const SadMxNx4Param x3d_neon_tests[] = {
+  make_tuple(128, 128, &aom_sad128x128x3d_neon, -1),
+  make_tuple(128, 64, &aom_sad128x64x3d_neon, -1),
+  make_tuple(64, 128, &aom_sad64x128x3d_neon, -1),
+  make_tuple(64, 64, &aom_sad64x64x3d_neon, -1),
+  make_tuple(64, 32, &aom_sad64x32x3d_neon, -1),
+  make_tuple(32, 64, &aom_sad32x64x3d_neon, -1),
+  make_tuple(32, 32, &aom_sad32x32x3d_neon, -1),
+  make_tuple(32, 16, &aom_sad32x16x3d_neon, -1),
+  make_tuple(16, 32, &aom_sad16x32x3d_neon, -1),
+  make_tuple(16, 16, &aom_sad16x16x3d_neon, -1),
+  make_tuple(16, 8, &aom_sad16x8x3d_neon, -1),
+  make_tuple(8, 16, &aom_sad8x16x3d_neon, -1),
+  make_tuple(8, 8, &aom_sad8x8x3d_neon, -1),
+  make_tuple(8, 4, &aom_sad8x4x3d_neon, -1),
+  make_tuple(4, 8, &aom_sad4x8x3d_neon, -1),
+  make_tuple(4, 4, &aom_sad4x4x3d_neon, -1),
+#if !CONFIG_REALTIME_ONLY
+  make_tuple(64, 16, &aom_sad64x16x3d_neon, -1),
+  make_tuple(32, 8, &aom_sad32x8x3d_neon, -1),
+  make_tuple(16, 64, &aom_sad16x64x3d_neon, -1),
+  make_tuple(16, 4, &aom_sad16x4x3d_neon, -1),
+  make_tuple(8, 32, &aom_sad8x32x3d_neon, -1),
+  make_tuple(4, 16, &aom_sad4x16x3d_neon, -1),
+#endif  // !CONFIG_REALTIME_ONLY
+};
+INSTANTIATE_TEST_SUITE_P(NEON, SADx3Test, ::testing::ValuesIn(x3d_neon_tests));
+
 #endif  // HAVE_NEON
 
 //------------------------------------------------------------------------------