Add Neon implementation of aom_var_2d_u16

Add Neon implementation of aom_var_2d_u16 as well as the corresponding
tests.

Change-Id: I8a5f0e4b5c73c60d846520f265d507728efebe72
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index 02122c6..e8aff91 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -753,7 +753,7 @@
     specialize qw/aom_var_2d_u8 sse2 avx2 neon/;
 
     add_proto qw/uint64_t aom_var_2d_u16/, "uint8_t *src, int src_stride, int width, int height";
-    specialize qw/aom_var_2d_u16 sse2 avx2/;
+    specialize qw/aom_var_2d_u16 sse2 avx2 neon/;
   }
 
   #
diff --git a/aom_dsp/arm/sum_squares_neon.c b/aom_dsp/arm/sum_squares_neon.c
index c0d009d..626cf21 100644
--- a/aom_dsp/arm/sum_squares_neon.c
+++ b/aom_dsp/arm/sum_squares_neon.c
@@ -599,3 +599,102 @@
   }
   return aom_var_2d_u8_c(src, src_stride, width, height);
 }
+
+// Returns sse - sum * sum / (width * height) over a block of 16-bit samples.
+// Steps 4 columns at a time; the do/while loops need width >= 4, height >= 1.
+static INLINE uint64_t aom_var_2d_u16_4xh_neon(uint8_t *src, int src_stride,
+                                               int width, int height) {
+  uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src);
+  uint64_t sum = 0;
+  uint64_t sse = 0;
+  uint32x2_t sum_u32 = vdup_n_u32(0);
+  uint64x2_t sse_u64 = vdupq_n_u64(0);
+
+  int h = height;
+  do {
+    int w = width;
+    uint16_t *src_ptr = src_u16;
+    do {
+      uint16x4_t s0 = vld1_u16(src_ptr);
+      // Pairwise-add adjacent samples into the 32-bit running sum.
+      sum_u32 = vpadal_u16(sum_u32, s0);
+      // Square as unsigned 32-bit products, then accumulate pairwise in 64 bits.
+      uint32x4_t sse_u32 = vmull_u16(s0, s0);
+      sse_u64 = vpadalq_u32(sse_u64, sse_u32);
+      src_ptr += 4;
+      w -= 4;
+    } while (w >= 4);
+
+    // Process remaining columns in the row using C.
+    while (w > 0) {
+      int idx = width - w;
+      const uint16_t v = src_u16[idx];
+      sum += v;
+      // Widen before squaring: uint16_t * uint16_t promotes to (signed) int.
+      sse += (uint64_t)v * v;
+      w--;
+    }
+
+    src_u16 += src_stride;
+  } while (--h != 0);
+
+  sum += horizontal_long_add_u32x2(sum_u32);
+  sse += horizontal_add_u64x2(sse_u64);
+  return sse - sum * sum / (width * height);
+}
+
+// Returns sse - sum * sum / (width * height) over a block of 16-bit samples.
+// Steps 8 columns at a time; the do/while loops need width >= 8, height >= 1.
+static INLINE uint64_t aom_var_2d_u16_8xh_neon(uint8_t *src, int src_stride,
+                                               int width, int height) {
+  uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src);
+  uint64_t sum = 0;
+  uint64_t sse = 0;
+  uint32x4_t sum_u32 = vdupq_n_u32(0);
+  uint64x2_t sse_u64[2] = { vdupq_n_u64(0), vdupq_n_u64(0) };
+
+  int h = height;
+  do {
+    int w = width;
+    uint16_t *src_ptr = src_u16;
+    do {
+      uint16x8_t s0 = vld1q_u16(src_ptr);
+      // Pairwise-add adjacent samples into the 32-bit running sum.
+      sum_u32 = vpadalq_u16(sum_u32, s0);
+      // Square each half as unsigned 32-bit products; accumulate in 64 bits.
+      uint32x4_t sse_u32_lo = vmull_u16(vget_low_u16(s0), vget_low_u16(s0));
+      uint32x4_t sse_u32_hi = vmull_u16(vget_high_u16(s0), vget_high_u16(s0));
+      sse_u64[0] = vpadalq_u32(sse_u64[0], sse_u32_lo);
+      sse_u64[1] = vpadalq_u32(sse_u64[1], sse_u32_hi);
+      src_ptr += 8;
+      w -= 8;
+    } while (w >= 8);
+
+    // Process remaining columns in the row using C.
+    while (w > 0) {
+      int idx = width - w;
+      const uint16_t v = src_u16[idx];
+      sum += v;
+      // Widen before squaring: uint16_t * uint16_t promotes to (signed) int.
+      sse += (uint64_t)v * v;
+      w--;
+    }
+
+    src_u16 += src_stride;
+  } while (--h != 0);
+
+  sum += horizontal_long_add_u32x4(sum_u32);
+  sse += horizontal_add_u64x2(vaddq_u64(sse_u64[0], sse_u64[1]));
+  return sse - sum * sum / (width * height);
+}
+
+uint64_t aom_var_2d_u16_neon(uint8_t *src, int src_stride, int width,
+                             int height) {
+  // Widths below 4 go to the C version; otherwise use the widest Neon kernel.
+  if (width < 4) {
+    return aom_var_2d_u16_c(src, src_stride, width, height);
+  }
+  return width >= 8
+             ? aom_var_2d_u16_8xh_neon(src, src_stride, width, height)
+             : aom_var_2d_u16_4xh_neon(src, src_stride, width, height);
+}
diff --git a/test/sum_squares_test.cc b/test/sum_squares_test.cc
index b38c308..91f172d 100644
--- a/test/sum_squares_test.cc
+++ b/test/sum_squares_test.cc
@@ -852,4 +852,12 @@
     ::testing::Values(TestFuncVar2D(&aom_var_2d_u16_c, &aom_var_2d_u16_avx2)));
 
 #endif  // HAVE_SSE2
+
+#if HAVE_NEON
+// Registers the C/Neon aom_var_2d_u16 function pair with Highbd2dVarTest.
+INSTANTIATE_TEST_SUITE_P(
+    NEON, Highbd2dVarTest,
+    ::testing::Values(TestFuncVar2D(&aom_var_2d_u16_c, &aom_var_2d_u16_neon)));
+
+#endif  // HAVE_NEON
 }  // namespace