Add Neon implementation of aom_var_2d_u8

Add Neon implementation of aom_var_2d_u8 as well as the corresponding
tests.

Change-Id: I3326011a6105843f3d77c5e22127768728399eac
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index de6573d..02122c6 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -750,7 +750,7 @@
     specialize qw/aom_sum_squares_i16 sse2 neon/;
 
     add_proto qw/uint64_t aom_var_2d_u8/, "uint8_t *src, int src_stride, int width, int height";
-    specialize qw/aom_var_2d_u8 sse2 avx2/;
+    specialize qw/aom_var_2d_u8 sse2 avx2 neon/;
 
     add_proto qw/uint64_t aom_var_2d_u16/, "uint8_t *src, int src_stride, int width, int height";
     specialize qw/aom_var_2d_u16 sse2 avx2/;
diff --git a/aom_dsp/arm/sum_neon.h b/aom_dsp/arm/sum_neon.h
index 9101979..7d2f18b 100644
--- a/aom_dsp/arm/sum_neon.h
+++ b/aom_dsp/arm/sum_neon.h
@@ -168,6 +168,15 @@
 #endif
 }
 
+static INLINE uint64_t horizontal_long_add_u32x2(const uint32x2_t a) {
+#if defined(__aarch64__)
+  return vaddlv_u32(a);
+#else
+  const uint64x1_t b = vpaddl_u32(a);
+  return vget_lane_u64(b, 0);
+#endif
+}
+
 static INLINE uint32_t horizontal_add_u16x4(const uint16x4_t a) {
 #if defined(__aarch64__)
   return vaddlv_u16(a);
diff --git a/aom_dsp/arm/sum_squares_neon.c b/aom_dsp/arm/sum_squares_neon.c
index 20c9c0d..c0d009d 100644
--- a/aom_dsp/arm/sum_squares_neon.c
+++ b/aom_dsp/arm/sum_squares_neon.c
@@ -286,3 +286,316 @@
   }
   return aom_sum_squares_i16_c(src, n);
 }
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+static INLINE uint64_t aom_var_2d_u8_4xh_neon(uint8_t *src, int src_stride,
+                                              int width, int height) {
+  uint64_t sum = 0;
+  uint64_t sse = 0;
+  uint32x2_t sum_u32 = vdup_n_u32(0);
+  uint32x2_t sse_u32 = vdup_n_u32(0);
+
+  int h = height / 2;
+  do {
+    int w = width;
+    uint8_t *src_ptr = src;
+    do {
+      uint8x8_t s0 = load_unaligned_u8(src_ptr, src_stride);
+
+      sum_u32 = vdot_u32(sum_u32, s0, vdup_n_u8(1));
+
+      sse_u32 = vdot_u32(sse_u32, s0, s0);
+
+      src_ptr += 8;
+      w -= 8;
+    } while (w >= 8);
+
+    // Process remaining columns in the row using C.
+    while (w > 0) {
+      int idx = width - w;
+      const uint8_t v = src[idx];
+      sum += v;
+      sse += v * v;
+      w--;
+    }
+
+    src += 2 * src_stride;
+  } while (--h != 0);
+
+  sum += horizontal_long_add_u32x2(sum_u32);
+  sse += horizontal_long_add_u32x2(sse_u32);
+
+  return sse - sum * sum / (width * height);
+}
+
+static INLINE uint64_t aom_var_2d_u8_8xh_neon(uint8_t *src, int src_stride,
+                                              int width, int height) {
+  uint64_t sum = 0;
+  uint64_t sse = 0;
+  uint32x2_t sum_u32 = vdup_n_u32(0);
+  uint32x2_t sse_u32 = vdup_n_u32(0);
+
+  int h = height;
+  do {
+    int w = width;
+    uint8_t *src_ptr = src;
+    do {
+      uint8x8_t s0 = vld1_u8(src_ptr);
+
+      sum_u32 = vdot_u32(sum_u32, s0, vdup_n_u8(1));
+
+      sse_u32 = vdot_u32(sse_u32, s0, s0);
+
+      src_ptr += 8;
+      w -= 8;
+    } while (w >= 8);
+
+    // Process remaining columns in the row using C.
+    while (w > 0) {
+      int idx = width - w;
+      const uint8_t v = src[idx];
+      sum += v;
+      sse += v * v;
+      w--;
+    }
+
+    src += src_stride;
+  } while (--h != 0);
+
+  sum += horizontal_long_add_u32x2(sum_u32);
+  sse += horizontal_long_add_u32x2(sse_u32);
+
+  return sse - sum * sum / (width * height);
+}
+
+static INLINE uint64_t aom_var_2d_u8_16xh_neon(uint8_t *src, int src_stride,
+                                               int width, int height) {
+  uint64_t sum = 0;
+  uint64_t sse = 0;
+  uint32x4_t sum_u32 = vdupq_n_u32(0);
+  uint32x4_t sse_u32 = vdupq_n_u32(0);
+
+  int h = height;
+  do {
+    int w = width;
+    uint8_t *src_ptr = src;
+    do {
+      uint8x16_t s0 = vld1q_u8(src_ptr);
+
+      sum_u32 = vdotq_u32(sum_u32, s0, vdupq_n_u8(1));
+
+      sse_u32 = vdotq_u32(sse_u32, s0, s0);
+
+      src_ptr += 16;
+      w -= 16;
+    } while (w >= 16);
+
+    // Process remaining columns in the row using C.
+    while (w > 0) {
+      int idx = width - w;
+      const uint8_t v = src[idx];
+      sum += v;
+      sse += v * v;
+      w--;
+    }
+
+    src += src_stride;
+  } while (--h != 0);
+
+  sum += horizontal_long_add_u32x4(sum_u32);
+  sse += horizontal_long_add_u32x4(sse_u32);
+
+  return sse - sum * sum / (width * height);
+}
+
+#else  //  !defined(__ARM_FEATURE_DOTPROD)
+
+static INLINE uint64_t aom_var_2d_u8_4xh_neon(uint8_t *src, int src_stride,
+                                              int width, int height) {
+  uint64_t sum = 0;
+  uint64_t sse = 0;
+  uint32x2_t sum_u32 = vdup_n_u32(0);
+  uint32x4_t sse_u32 = vdupq_n_u32(0);
+
+  // 255*256 = 65280, so we can accumulate up to 256 8-bit elements in a 16-bit
+  // element before we need to accumulate to 32-bit elements. Since we're
+  // accumulating in uint16x4_t vectors, this means we can accumulate up to 4
+  // rows of 256 elements. Therefore the limit can be computed as: h_limit = (4
+  // * 256) / width.
+  int h_limit = (4 * 256) / width;
+  int h_tmp = height > h_limit ? h_limit : height;
+
+  int h = 0;
+  do {
+    uint16x4_t sum_u16 = vdup_n_u16(0);
+    do {
+      uint8_t *src_ptr = src;
+      int w = width;
+      do {
+        uint8x8_t s0 = load_unaligned_u8(src_ptr, src_stride);
+
+        sum_u16 = vpadal_u8(sum_u16, s0);
+
+        uint16x8_t sse_u16 = vmull_u8(s0, s0);
+
+        sse_u32 = vpadalq_u16(sse_u32, sse_u16);
+
+        src_ptr += 8;
+        w -= 8;
+      } while (w >= 8);
+
+      // Process remaining columns in the row using C.
+      while (w > 0) {
+        int idx = width - w;
+        const uint8_t v = src[idx];
+        sum += v;
+        sse += v * v;
+        w--;
+      }
+
+      src += 2 * src_stride;
+      h += 2;
+    } while (h < h_tmp && h < height);
+
+    sum_u32 = vpadal_u16(sum_u32, sum_u16);
+    h_tmp += h_limit;
+  } while (h < height);
+
+  sum += horizontal_long_add_u32x2(sum_u32);
+  sse += horizontal_long_add_u32x4(sse_u32);
+
+  return sse - sum * sum / (width * height);
+}
+
+static INLINE uint64_t aom_var_2d_u8_8xh_neon(uint8_t *src, int src_stride,
+                                              int width, int height) {
+  uint64_t sum = 0;
+  uint64_t sse = 0;
+  uint32x2_t sum_u32 = vdup_n_u32(0);
+  uint32x4_t sse_u32 = vdupq_n_u32(0);
+
+  // 255*256 = 65280, so we can accumulate up to 256 8-bit elements in a 16-bit
+  // element before we need to accumulate to 32-bit elements. Since we're
+  // accumulating in uint16x4_t vectors, this means we can accumulate up to 4
+  // rows of 256 elements. Therefore the limit can be computed as: h_limit = (4
+  // * 256) / width.
+  int h_limit = (4 * 256) / width;
+  int h_tmp = height > h_limit ? h_limit : height;
+
+  int h = 0;
+  do {
+    uint16x4_t sum_u16 = vdup_n_u16(0);
+    do {
+      uint8_t *src_ptr = src;
+      int w = width;
+      do {
+        uint8x8_t s0 = vld1_u8(src_ptr);
+
+        sum_u16 = vpadal_u8(sum_u16, s0);
+
+        uint16x8_t sse_u16 = vmull_u8(s0, s0);
+
+        sse_u32 = vpadalq_u16(sse_u32, sse_u16);
+
+        src_ptr += 8;
+        w -= 8;
+      } while (w >= 8);
+
+      // Process remaining columns in the row using C.
+      while (w > 0) {
+        int idx = width - w;
+        const uint8_t v = src[idx];
+        sum += v;
+        sse += v * v;
+        w--;
+      }
+
+      src += src_stride;
+      ++h;
+    } while (h < h_tmp && h < height);
+
+    sum_u32 = vpadal_u16(sum_u32, sum_u16);
+    h_tmp += h_limit;
+  } while (h < height);
+
+  sum += horizontal_long_add_u32x2(sum_u32);
+  sse += horizontal_long_add_u32x4(sse_u32);
+
+  return sse - sum * sum / (width * height);
+}
+
+static INLINE uint64_t aom_var_2d_u8_16xh_neon(uint8_t *src, int src_stride,
+                                               int width, int height) {
+  uint64_t sum = 0;
+  uint64_t sse = 0;
+  uint32x4_t sum_u32 = vdupq_n_u32(0);
+  uint32x4_t sse_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
+
+  // 255*256 = 65280, so we can accumulate up to 256 8-bit elements in a 16-bit
+  // element before we need to accumulate to 32-bit elements. Since we're
+  // accumulating in uint16x8_t vectors, this means we can accumulate up to 8
+  // rows of 256 elements. Therefore the limit can be computed as: h_limit = (8
+  // * 256) / width.
+  int h_limit = (8 * 256) / width;
+  int h_tmp = height > h_limit ? h_limit : height;
+
+  int h = 0;
+  do {
+    uint16x8_t sum_u16 = vdupq_n_u16(0);
+    do {
+      int w = width;
+      uint8_t *src_ptr = src;
+      do {
+        uint8x16_t s0 = vld1q_u8(src_ptr);
+
+        sum_u16 = vpadalq_u8(sum_u16, s0);
+
+        uint16x8_t sse_u16_lo = vmull_u8(vget_low_u8(s0), vget_low_u8(s0));
+        uint16x8_t sse_u16_hi = vmull_u8(vget_high_u8(s0), vget_high_u8(s0));
+
+        sse_u32[0] = vpadalq_u16(sse_u32[0], sse_u16_lo);
+        sse_u32[1] = vpadalq_u16(sse_u32[1], sse_u16_hi);
+
+        src_ptr += 16;
+        w -= 16;
+      } while (w >= 16);
+
+      // Process remaining columns in the row using C.
+      while (w > 0) {
+        int idx = width - w;
+        const uint8_t v = src[idx];
+        sum += v;
+        sse += v * v;
+        w--;
+      }
+
+      src += src_stride;
+      ++h;
+    } while (h < h_tmp && h < height);
+
+    sum_u32 = vpadalq_u16(sum_u32, sum_u16);
+    h_tmp += h_limit;
+  } while (h < height);
+
+  sum += horizontal_long_add_u32x4(sum_u32);
+  sse += horizontal_long_add_u32x4(vaddq_u32(sse_u32[0], sse_u32[1]));
+
+  return sse - sum * sum / (width * height);
+}
+
+#endif  // defined(__ARM_FEATURE_DOTPROD)
+
+uint64_t aom_var_2d_u8_neon(uint8_t *src, int src_stride, int width,
+                            int height) {
+  if (width >= 16) {
+    return aom_var_2d_u8_16xh_neon(src, src_stride, width, height);
+  }
+  if (width >= 8) {
+    return aom_var_2d_u8_8xh_neon(src, src_stride, width, height);
+  }
+  if (width >= 4 && height % 2 == 0) {
+    return aom_var_2d_u8_4xh_neon(src, src_stride, width, height);
+  }
+  return aom_var_2d_u8_c(src, src_stride, width, height);
+}
diff --git a/test/sum_squares_test.cc b/test/sum_squares_test.cc
index a89e58c..b38c308 100644
--- a/test/sum_squares_test.cc
+++ b/test/sum_squares_test.cc
@@ -715,6 +715,14 @@
 
 #endif  // HAVE_SSE2
 
+#if HAVE_NEON
+
+INSTANTIATE_TEST_SUITE_P(NEON, Lowbd2dVarTest,
+                         ::testing::Values(TestFuncVar2D(&aom_var_2d_u8_c,
+                                                         &aom_var_2d_u8_neon)));
+
+#endif  // HAVE_NEON
+
 class Highbd2dVarTest : public ::testing::TestWithParam<TestFuncVar2D> {
  public:
   virtual ~Highbd2dVarTest() {}