Add Neon implementation of aom_var_2d_u16
Add Neon implementation of aom_var_2d_u16 as well as the corresponding
tests.
Change-Id: I8a5f0e4b5c73c60d846520f265d507728efebe72
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index 02122c6..e8aff91 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -753,7 +753,7 @@
specialize qw/aom_var_2d_u8 sse2 avx2 neon/;
add_proto qw/uint64_t aom_var_2d_u16/, "uint8_t *src, int src_stride, int width, int height";
- specialize qw/aom_var_2d_u16 sse2 avx2/;
+ specialize qw/aom_var_2d_u16 sse2 avx2 neon/;
}
#
diff --git a/aom_dsp/arm/sum_squares_neon.c b/aom_dsp/arm/sum_squares_neon.c
index c0d009d..626cf21 100644
--- a/aom_dsp/arm/sum_squares_neon.c
+++ b/aom_dsp/arm/sum_squares_neon.c
@@ -599,3 +599,102 @@
}
return aom_var_2d_u8_c(src, src_stride, width, height);
}
+
+static INLINE uint64_t aom_var_2d_u16_4xh_neon(uint8_t *src, int src_stride,
+ int width, int height) {
+ uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src);
+ uint64_t sum = 0;
+ uint64_t sse = 0;
+ uint32x2_t sum_u32 = vdup_n_u32(0);
+ uint64x2_t sse_u64 = vdupq_n_u64(0);
+
+ int h = height;
+ do {
+ int w = width;
+ uint16_t *src_ptr = src_u16;
+ do {
+ uint16x4_t s0 = vld1_u16(src_ptr);
+
+ sum_u32 = vpadal_u16(sum_u32, s0);
+
+ uint32x4_t sse_u32 = vmull_u16(s0, s0);
+
+ sse_u64 = vpadalq_u32(sse_u64, sse_u32);
+
+ src_ptr += 4;
+ w -= 4;
+ } while (w >= 4);
+
+ // Process remaining columns in the row using C.
+ while (w > 0) {
+ int idx = width - w;
+ const uint16_t v = src_u16[idx];
+ sum += v;
+ sse += v * v;
+ w--;
+ }
+
+ src_u16 += src_stride;
+ } while (--h != 0);
+
+ sum += horizontal_long_add_u32x2(sum_u32);
+ sse += horizontal_add_u64x2(sse_u64);
+
+ return sse - sum * sum / (width * height);
+}
+
+static INLINE uint64_t aom_var_2d_u16_8xh_neon(uint8_t *src, int src_stride,
+ int width, int height) {
+ uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src);
+ uint64_t sum = 0;
+ uint64_t sse = 0;
+ uint32x4_t sum_u32 = vdupq_n_u32(0);
+ uint64x2_t sse_u64[2] = { vdupq_n_u64(0), vdupq_n_u64(0) };
+
+ int h = height;
+ do {
+ int w = width;
+ uint16_t *src_ptr = src_u16;
+ do {
+ uint16x8_t s0 = vld1q_u16(src_ptr);
+
+ sum_u32 = vpadalq_u16(sum_u32, s0);
+
+ uint32x4_t sse_u32_lo = vmull_u16(vget_low_u16(s0), vget_low_u16(s0));
+ uint32x4_t sse_u32_hi = vmull_u16(vget_high_u16(s0), vget_high_u16(s0));
+
+ sse_u64[0] = vpadalq_u32(sse_u64[0], sse_u32_lo);
+ sse_u64[1] = vpadalq_u32(sse_u64[1], sse_u32_hi);
+
+ src_ptr += 8;
+ w -= 8;
+ } while (w >= 8);
+
+ // Process remaining columns in the row using C.
+ while (w > 0) {
+ int idx = width - w;
+ const uint16_t v = src_u16[idx];
+ sum += v;
+ sse += v * v;
+ w--;
+ }
+
+ src_u16 += src_stride;
+ } while (--h != 0);
+
+ sum += horizontal_long_add_u32x4(sum_u32);
+ sse += horizontal_add_u64x2(vaddq_u64(sse_u64[0], sse_u64[1]));
+
+ return sse - sum * sum / (width * height);
+}
+
+uint64_t aom_var_2d_u16_neon(uint8_t *src, int src_stride, int width,
+ int height) {
+ if (width >= 8) {
+ return aom_var_2d_u16_8xh_neon(src, src_stride, width, height);
+ }
+ if (width >= 4) {
+ return aom_var_2d_u16_4xh_neon(src, src_stride, width, height);
+ }
+ return aom_var_2d_u16_c(src, src_stride, width, height);
+}
diff --git a/test/sum_squares_test.cc b/test/sum_squares_test.cc
index b38c308..91f172d 100644
--- a/test/sum_squares_test.cc
+++ b/test/sum_squares_test.cc
@@ -852,4 +852,12 @@
::testing::Values(TestFuncVar2D(&aom_var_2d_u16_c, &aom_var_2d_u16_avx2)));
#endif // HAVE_SSE2
+
+#if HAVE_NEON
+
+INSTANTIATE_TEST_SUITE_P(
+ NEON, Highbd2dVarTest,
+ ::testing::Values(TestFuncVar2D(&aom_var_2d_u16_c, &aom_var_2d_u16_neon)));
+
+#endif // HAVE_NEON
} // namespace