Add Neon implementation of aom_var_2d_u8
Add Neon implementation of aom_var_2d_u8 as well as the corresponding
tests.
Change-Id: I3326011a6105843f3d77c5e22127768728399eac
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index de6573d..02122c6 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -750,7 +750,7 @@
specialize qw/aom_sum_squares_i16 sse2 neon/;
add_proto qw/uint64_t aom_var_2d_u8/, "uint8_t *src, int src_stride, int width, int height";
- specialize qw/aom_var_2d_u8 sse2 avx2/;
+ specialize qw/aom_var_2d_u8 sse2 avx2 neon/;
add_proto qw/uint64_t aom_var_2d_u16/, "uint8_t *src, int src_stride, int width, int height";
specialize qw/aom_var_2d_u16 sse2 avx2/;
diff --git a/aom_dsp/arm/sum_neon.h b/aom_dsp/arm/sum_neon.h
index 9101979..7d2f18b 100644
--- a/aom_dsp/arm/sum_neon.h
+++ b/aom_dsp/arm/sum_neon.h
@@ -168,6 +168,15 @@
#endif
}
+// Sums the two 32-bit lanes of a into a 64-bit scalar. The widening add
+// cannot overflow.
+static INLINE uint64_t horizontal_long_add_u32x2(const uint32x2_t a) {
+#if defined(__aarch64__)
+  return vaddlv_u32(a);
+#else
+  // AArch32 has no across-vector add; use a pairwise widening add instead.
+  const uint64x1_t b = vpaddl_u32(a);
+  return vget_lane_u64(b, 0);
+#endif
+}
+
static INLINE uint32_t horizontal_add_u16x4(const uint16x4_t a) {
#if defined(__aarch64__)
return vaddlv_u16(a);
diff --git a/aom_dsp/arm/sum_squares_neon.c b/aom_dsp/arm/sum_squares_neon.c
index 20c9c0d..c0d009d 100644
--- a/aom_dsp/arm/sum_squares_neon.c
+++ b/aom_dsp/arm/sum_squares_neon.c
@@ -286,3 +286,316 @@
}
return aom_sum_squares_i16_c(src, n);
}
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+// Computes the variance numerator (SSE - SUM^2 / N) over a width x height
+// block of 8-bit samples using the Armv8.2 UDOT dot-product instruction.
+// Intended for 4 <= width < 8; two rows are consumed per loop iteration, so
+// height must be even.
+static INLINE uint64_t aom_var_2d_u8_4xh_neon(uint8_t *src, int src_stride,
+                                              int width, int height) {
+  uint64_t sum = 0;
+  uint64_t sse = 0;
+  uint32x2_t sum_u32 = vdup_n_u32(0);
+  uint32x2_t sse_u32 = vdup_n_u32(0);
+
+  // Two rows per iteration.
+  int h = height / 2;
+  do {
+    int w = width;
+    uint8_t *src_ptr = src;
+    do {
+      // Load 4 bytes from each of two consecutive rows (8 samples total).
+      uint8x8_t s0 = load_unaligned_u8(src_ptr, src_stride);
+
+      // Dot product against all-ones computes the horizontal sum.
+      sum_u32 = vdot_u32(sum_u32, s0, vdup_n_u8(1));
+
+      sse_u32 = vdot_u32(sse_u32, s0, s0);
+
+      src_ptr += 8;
+      w -= 8;
+    } while (w >= 8);
+
+    // Process remaining columns in the row using C.
+    // NOTE(review): w is decremented by 8 although only 4 columns (of two
+    // rows) are consumed per vector iteration, so for width 5-7 this tail is
+    // skipped (w <= 0) and it would only read the first of the two rows in
+    // any case. Correct when width == 4 -- confirm callers never pass other
+    // widths to this kernel.
+    while (w > 0) {
+      int idx = width - w;
+      const uint8_t v = src[idx];
+      sum += v;
+      sse += v * v;
+      w--;
+    }
+
+    src += 2 * src_stride;
+  } while (--h != 0);
+
+  sum += horizontal_long_add_u32x2(sum_u32);
+  sse += horizontal_long_add_u32x2(sse_u32);
+
+  // Integer division intentionally matches the C reference implementation.
+  return sse - sum * sum / (width * height);
+}
+
+// Computes the variance numerator (SSE - SUM^2 / N) over a width x height
+// block of 8-bit samples using the Armv8.2 UDOT dot-product instruction.
+// Intended for 8 <= width < 16; one row per iteration, with a scalar tail
+// covering the final width % 8 columns of each row.
+static INLINE uint64_t aom_var_2d_u8_8xh_neon(uint8_t *src, int src_stride,
+                                              int width, int height) {
+  uint64_t sum = 0;
+  uint64_t sse = 0;
+  uint32x2_t sum_u32 = vdup_n_u32(0);
+  uint32x2_t sse_u32 = vdup_n_u32(0);
+
+  int h = height;
+  do {
+    int w = width;
+    uint8_t *src_ptr = src;
+    do {
+      uint8x8_t s0 = vld1_u8(src_ptr);
+
+      // Dot product against all-ones computes the horizontal sum.
+      sum_u32 = vdot_u32(sum_u32, s0, vdup_n_u8(1));
+
+      sse_u32 = vdot_u32(sse_u32, s0, s0);
+
+      src_ptr += 8;
+      w -= 8;
+    } while (w >= 8);
+
+    // Process remaining columns in the row using C.
+    while (w > 0) {
+      int idx = width - w;
+      const uint8_t v = src[idx];
+      sum += v;
+      sse += v * v;
+      w--;
+    }
+
+    src += src_stride;
+  } while (--h != 0);
+
+  sum += horizontal_long_add_u32x2(sum_u32);
+  sse += horizontal_long_add_u32x2(sse_u32);
+
+  // Integer division intentionally matches the C reference implementation.
+  return sse - sum * sum / (width * height);
+}
+
+// Computes the variance numerator (SSE - SUM^2 / N) over a width x height
+// block of 8-bit samples using the Armv8.2 UDOT dot-product instruction.
+// Intended for width >= 16; one row per iteration, with a scalar tail
+// covering the final width % 16 columns of each row.
+static INLINE uint64_t aom_var_2d_u8_16xh_neon(uint8_t *src, int src_stride,
+                                               int width, int height) {
+  uint64_t sum = 0;
+  uint64_t sse = 0;
+  uint32x4_t sum_u32 = vdupq_n_u32(0);
+  uint32x4_t sse_u32 = vdupq_n_u32(0);
+
+  int h = height;
+  do {
+    int w = width;
+    uint8_t *src_ptr = src;
+    do {
+      uint8x16_t s0 = vld1q_u8(src_ptr);
+
+      // Dot product against all-ones computes the horizontal sum.
+      sum_u32 = vdotq_u32(sum_u32, s0, vdupq_n_u8(1));
+
+      sse_u32 = vdotq_u32(sse_u32, s0, s0);
+
+      src_ptr += 16;
+      w -= 16;
+    } while (w >= 16);
+
+    // Process remaining columns in the row using C.
+    while (w > 0) {
+      int idx = width - w;
+      const uint8_t v = src[idx];
+      sum += v;
+      sse += v * v;
+      w--;
+    }
+
+    src += src_stride;
+  } while (--h != 0);
+
+  sum += horizontal_long_add_u32x4(sum_u32);
+  sse += horizontal_long_add_u32x4(sse_u32);
+
+  // Integer division intentionally matches the C reference implementation.
+  return sse - sum * sum / (width * height);
+}
+
+#else // !defined(__ARM_FEATURE_DOTPROD)
+
+// Computes the variance numerator (SSE - SUM^2 / N) over a width x height
+// block of 8-bit samples without the dot-product extension. Row sums are
+// first gathered in 16-bit lanes and periodically widened to 32 bits to
+// avoid overflow. Intended for 4 <= width < 8; two rows are consumed per
+// loop iteration, so height must be even.
+static INLINE uint64_t aom_var_2d_u8_4xh_neon(uint8_t *src, int src_stride,
+                                              int width, int height) {
+  uint64_t sum = 0;
+  uint64_t sse = 0;
+  uint32x2_t sum_u32 = vdup_n_u32(0);
+  uint32x4_t sse_u32 = vdupq_n_u32(0);
+
+  // 255*256 = 65280, so we can accumulate up to 256 8-bit elements in a 16-bit
+  // element before we need to accumulate to 32-bit elements. Since we're
+  // accumulating in uint16x4_t vectors, this means we can accumulate up to 4
+  // rows of 256 elements. Therefore the limit can be computed as: h_limit = (4
+  // * 256) / width.
+  int h_limit = (4 * 256) / width;
+  int h_tmp = height > h_limit ? h_limit : height;
+
+  int h = 0;
+  do {
+    // 16-bit accumulator, flushed into sum_u32 every h_limit rows.
+    uint16x4_t sum_u16 = vdup_n_u16(0);
+    do {
+      uint8_t *src_ptr = src;
+      int w = width;
+      do {
+        // Load 4 bytes from each of two consecutive rows (8 samples total).
+        uint8x8_t s0 = load_unaligned_u8(src_ptr, src_stride);
+
+        sum_u16 = vpadal_u8(sum_u16, s0);
+
+        // Squares fit in 16 bits (max 255*255); widen to 32 bits at once.
+        uint16x8_t sse_u16 = vmull_u8(s0, s0);
+
+        sse_u32 = vpadalq_u16(sse_u32, sse_u16);
+
+        src_ptr += 8;
+        w -= 8;
+      } while (w >= 8);
+
+      // Process remaining columns in the row using C.
+      // NOTE(review): w is decremented by 8 although only 4 columns (of two
+      // rows) are consumed per vector iteration, so for width 5-7 this tail
+      // is skipped (w <= 0) and it would only read the first of the two rows
+      // in any case. Correct when width == 4 -- confirm callers never pass
+      // other widths to this kernel.
+      while (w > 0) {
+        int idx = width - w;
+        const uint8_t v = src[idx];
+        sum += v;
+        sse += v * v;
+        w--;
+      }
+
+      src += 2 * src_stride;
+      h += 2;
+    } while (h < h_tmp && h < height);
+
+    // Widen the 16-bit row sums into the 32-bit accumulator before they can
+    // overflow, then extend the flush deadline by another h_limit rows.
+    sum_u32 = vpadal_u16(sum_u32, sum_u16);
+    h_tmp += h_limit;
+  } while (h < height);
+
+  sum += horizontal_long_add_u32x2(sum_u32);
+  sse += horizontal_long_add_u32x4(sse_u32);
+
+  // Integer division intentionally matches the C reference implementation.
+  return sse - sum * sum / (width * height);
+}
+
+// Computes the variance numerator (SSE - SUM^2 / N) over a width x height
+// block of 8-bit samples without the dot-product extension. Row sums are
+// first gathered in 16-bit lanes and periodically widened to 32 bits to
+// avoid overflow. Intended for 8 <= width < 16; one row per iteration, with
+// a scalar tail covering the final width % 8 columns of each row.
+static INLINE uint64_t aom_var_2d_u8_8xh_neon(uint8_t *src, int src_stride,
+                                              int width, int height) {
+  uint64_t sum = 0;
+  uint64_t sse = 0;
+  uint32x2_t sum_u32 = vdup_n_u32(0);
+  uint32x4_t sse_u32 = vdupq_n_u32(0);
+
+  // 255*256 = 65280, so we can accumulate up to 256 8-bit elements in a 16-bit
+  // element before we need to accumulate to 32-bit elements. Since we're
+  // accumulating in uint16x4_t vectors, this means we can accumulate up to 4
+  // rows of 256 elements. Therefore the limit can be computed as: h_limit = (4
+  // * 256) / width.
+  int h_limit = (4 * 256) / width;
+  int h_tmp = height > h_limit ? h_limit : height;
+
+  int h = 0;
+  do {
+    // 16-bit accumulator, flushed into sum_u32 every h_limit rows.
+    uint16x4_t sum_u16 = vdup_n_u16(0);
+    do {
+      uint8_t *src_ptr = src;
+      int w = width;
+      do {
+        uint8x8_t s0 = vld1_u8(src_ptr);
+
+        sum_u16 = vpadal_u8(sum_u16, s0);
+
+        // Squares fit in 16 bits (max 255*255); widen to 32 bits at once.
+        uint16x8_t sse_u16 = vmull_u8(s0, s0);
+
+        sse_u32 = vpadalq_u16(sse_u32, sse_u16);
+
+        src_ptr += 8;
+        w -= 8;
+      } while (w >= 8);
+
+      // Process remaining columns in the row using C.
+      while (w > 0) {
+        int idx = width - w;
+        const uint8_t v = src[idx];
+        sum += v;
+        sse += v * v;
+        w--;
+      }
+
+      src += src_stride;
+      ++h;
+    } while (h < h_tmp && h < height);
+
+    // Widen the 16-bit row sums into the 32-bit accumulator before they can
+    // overflow, then extend the flush deadline by another h_limit rows.
+    sum_u32 = vpadal_u16(sum_u32, sum_u16);
+    h_tmp += h_limit;
+  } while (h < height);
+
+  sum += horizontal_long_add_u32x2(sum_u32);
+  sse += horizontal_long_add_u32x4(sse_u32);
+
+  // Integer division intentionally matches the C reference implementation.
+  return sse - sum * sum / (width * height);
+}
+
+// Computes the variance numerator (SSE - SUM^2 / N) over a width x height
+// block of 8-bit samples without the dot-product extension. Row sums are
+// first gathered in 16-bit lanes and periodically widened to 32 bits to
+// avoid overflow. Intended for width >= 16; one row per iteration, with a
+// scalar tail covering the final width % 16 columns of each row.
+static INLINE uint64_t aom_var_2d_u8_16xh_neon(uint8_t *src, int src_stride,
+                                               int width, int height) {
+  uint64_t sum = 0;
+  uint64_t sse = 0;
+  uint32x4_t sum_u32 = vdupq_n_u32(0);
+  // Two SSE accumulators (low/high halves of each 16-byte load) to shorten
+  // the dependency chain; combined at the end.
+  uint32x4_t sse_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
+
+  // 255*256 = 65280, so we can accumulate up to 256 8-bit elements in a 16-bit
+  // element before we need to accumulate to 32-bit elements. Since we're
+  // accumulating in uint16x8_t vectors, this means we can accumulate up to 8
+  // rows of 256 elements. Therefore the limit can be computed as: h_limit = (8
+  // * 256) / width.
+  int h_limit = (8 * 256) / width;
+  int h_tmp = height > h_limit ? h_limit : height;
+
+  int h = 0;
+  do {
+    // 16-bit accumulator, flushed into sum_u32 every h_limit rows.
+    uint16x8_t sum_u16 = vdupq_n_u16(0);
+    do {
+      int w = width;
+      uint8_t *src_ptr = src;
+      do {
+        uint8x16_t s0 = vld1q_u8(src_ptr);
+
+        sum_u16 = vpadalq_u8(sum_u16, s0);
+
+        // Squares fit in 16 bits (max 255*255); widen to 32 bits at once.
+        uint16x8_t sse_u16_lo = vmull_u8(vget_low_u8(s0), vget_low_u8(s0));
+        uint16x8_t sse_u16_hi = vmull_u8(vget_high_u8(s0), vget_high_u8(s0));
+
+        sse_u32[0] = vpadalq_u16(sse_u32[0], sse_u16_lo);
+        sse_u32[1] = vpadalq_u16(sse_u32[1], sse_u16_hi);
+
+        src_ptr += 16;
+        w -= 16;
+      } while (w >= 16);
+
+      // Process remaining columns in the row using C.
+      while (w > 0) {
+        int idx = width - w;
+        const uint8_t v = src[idx];
+        sum += v;
+        sse += v * v;
+        w--;
+      }
+
+      src += src_stride;
+      ++h;
+    } while (h < h_tmp && h < height);
+
+    // Widen the 16-bit row sums into the 32-bit accumulator before they can
+    // overflow, then extend the flush deadline by another h_limit rows.
+    sum_u32 = vpadalq_u16(sum_u32, sum_u16);
+    h_tmp += h_limit;
+  } while (h < height);
+
+  sum += horizontal_long_add_u32x4(sum_u32);
+  sse += horizontal_long_add_u32x4(vaddq_u32(sse_u32[0], sse_u32[1]));
+
+  // Integer division intentionally matches the C reference implementation.
+  return sse - sum * sum / (width * height);
+}
+
+#endif // defined(__ARM_FEATURE_DOTPROD)
+
+// Dispatches to the widest Neon kernel the block width allows, falling back
+// to the C implementation for shapes the vector kernels cannot handle.
+uint64_t aom_var_2d_u8_neon(uint8_t *src, int src_stride, int width,
+                            int height) {
+  if (width >= 16) {
+    return aom_var_2d_u8_16xh_neon(src, src_stride, width, height);
+  }
+  if (width >= 8) {
+    return aom_var_2d_u8_8xh_neon(src, src_stride, width, height);
+  }
+  // The 4xh kernel consumes two rows of exactly 4 columns per iteration and
+  // its scalar tail does not cover widths 5-7 (it decrements the column
+  // counter by 8 per 4 columns), so only use it for width == 4 with an even
+  // height; widths 5-7 take the C path instead.
+  if (width == 4 && height % 2 == 0) {
+    return aom_var_2d_u8_4xh_neon(src, src_stride, width, height);
+  }
+  return aom_var_2d_u8_c(src, src_stride, width, height);
+}
diff --git a/test/sum_squares_test.cc b/test/sum_squares_test.cc
index a89e58c..b38c308 100644
--- a/test/sum_squares_test.cc
+++ b/test/sum_squares_test.cc
@@ -715,6 +715,14 @@
#endif // HAVE_SSE2
+#if HAVE_NEON
+
+// Check the Neon implementation of aom_var_2d_u8 against the C reference.
+INSTANTIATE_TEST_SUITE_P(NEON, Lowbd2dVarTest,
+                         ::testing::Values(TestFuncVar2D(&aom_var_2d_u8_c,
+                                                         &aom_var_2d_u8_neon)));
+
+#endif  // HAVE_NEON
+
class Highbd2dVarTest : public ::testing::TestWithParam<TestFuncVar2D> {
public:
virtual ~Highbd2dVarTest() {}