Add Neon implementations of SAD3D functions
Add Armv8.0 and Armv8.4 (dot-product) Neon implementations of
aom_sad<w>x<h>x3d functions, as well as the corresponding tests.
Change-Id: Ia62c4d6cb43d620a48a98c6c2e8c817ebad26f56
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index 2feed0b..19925d5 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -1046,21 +1046,29 @@
specialize qw/aom_sad_skip_4x8x4d sse2 neon/;
specialize qw/aom_sad_skip_4x4x4d neon/;
- specialize qw/aom_sad128x128x3d avx2/;
- specialize qw/aom_sad128x64x3d avx2/;
- specialize qw/aom_sad64x128x3d avx2/;
- specialize qw/aom_sad64x64x3d avx2/;
- specialize qw/aom_sad64x32x3d avx2/;
- specialize qw/aom_sad32x64x3d avx2/;
- specialize qw/aom_sad32x32x3d avx2/;
- specialize qw/aom_sad32x16x3d avx2/;
- specialize qw/aom_sad16x32x3d avx2/;
- specialize qw/aom_sad16x16x3d avx2/;
- specialize qw/aom_sad16x8x3d avx2/;
+ specialize qw/aom_sad128x128x3d neon avx2/;
+ specialize qw/aom_sad128x64x3d neon avx2/;
+ specialize qw/aom_sad64x128x3d neon avx2/;
+ specialize qw/aom_sad64x64x3d neon avx2/;
+ specialize qw/aom_sad64x32x3d neon avx2/;
+ specialize qw/aom_sad32x64x3d neon avx2/;
+ specialize qw/aom_sad32x32x3d neon avx2/;
+ specialize qw/aom_sad32x16x3d neon avx2/;
+ specialize qw/aom_sad16x32x3d neon avx2/;
+ specialize qw/aom_sad16x16x3d neon avx2/;
+ specialize qw/aom_sad16x8x3d neon avx2/;
+ specialize qw/aom_sad8x16x3d neon/;
+ specialize qw/aom_sad8x8x3d neon/;
+ specialize qw/aom_sad8x4x3d neon/;
+ specialize qw/aom_sad4x8x3d neon/;
+ specialize qw/aom_sad4x4x3d neon/;
- specialize qw/aom_sad64x16x3d avx2/;
- specialize qw/aom_sad32x8x3d avx2/;
- specialize qw/aom_sad16x64x3d avx2/;
+ specialize qw/aom_sad64x16x3d neon avx2/;
+ specialize qw/aom_sad32x8x3d neon avx2/;
+ specialize qw/aom_sad16x64x3d neon avx2/;
+ specialize qw/aom_sad16x4x3d neon/;
+ specialize qw/aom_sad8x32x3d neon/;
+ specialize qw/aom_sad4x16x3d neon/;
specialize qw/aom_masked_sad128x128x4d ssse3/;
specialize qw/aom_masked_sad128x64x4d ssse3/;
diff --git a/aom_dsp/arm/sadxd_neon.c b/aom_dsp/arm/sadxd_neon.c
index a3c02c5..81803b1 100644
--- a/aom_dsp/arm/sadxd_neon.c
+++ b/aom_dsp/arm/sadxd_neon.c
@@ -26,6 +26,307 @@
*sad_sum = vdotq_u32(*sad_sum, abs_diff, vdupq_n_u8(1));
}
+static INLINE void sadwxhx3d_large_neon(const uint8_t *src, int src_stride,  // Dot-product build: SAD of one w x h src block against three reference blocks.
+ const uint8_t *const ref[4],  // declared [4], but only ref[0..2] are read here
+ int ref_stride, uint32_t res[4], int w,  // only res[0..2] are written here
+ int h) {
+ uint32x4_t sum_lo[3] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0) };  // per-reference totals for bytes [j, j+16) of each step
+ uint32x4_t sum_hi[3] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0) };  // per-reference totals for bytes [j+16, j+32) of each step
+
+ int ref_offset = 0;  // one running row offset shared by all three ref pointers
+ int i = h;  // rows remaining; the do/while form requires h >= 1
+ do {
+ int j = 0;
+ do {
+ const uint8x16_t s0 = vld1q_u8(src + j);  // src bytes are loaded once and reused for all three references
+ sad16_neon(s0, vld1q_u8(ref[0] + ref_offset + j), &sum_lo[0]);
+ sad16_neon(s0, vld1q_u8(ref[1] + ref_offset + j), &sum_lo[1]);
+ sad16_neon(s0, vld1q_u8(ref[2] + ref_offset + j), &sum_lo[2]);
+
+ const uint8x16_t s1 = vld1q_u8(src + j + 16);
+ sad16_neon(s1, vld1q_u8(ref[0] + ref_offset + j + 16), &sum_hi[0]);
+ sad16_neon(s1, vld1q_u8(ref[1] + ref_offset + j + 16), &sum_hi[1]);
+ sad16_neon(s1, vld1q_u8(ref[2] + ref_offset + j + 16), &sum_hi[2]);
+
+ j += 32;  // a row is consumed in 32-byte steps; the wrappers in this file pass w = 32, 64 or 128
+ } while (j < w);
+
+ src += src_stride;
+ ref_offset += ref_stride;
+ } while (--i != 0);
+
+ res[0] = horizontal_add_u32x4(vaddq_u32(sum_lo[0], sum_hi[0]));  // merge both halves, then reduce to a scalar (helper is defined outside this hunk)
+ res[1] = horizontal_add_u32x4(vaddq_u32(sum_lo[1], sum_hi[1]));
+ res[2] = horizontal_add_u32x4(vaddq_u32(sum_lo[2], sum_hi[2]));
+}
+
+static INLINE void sad128xhx3d_neon(const uint8_t *src, int src_stride,  // Dot-product build, 128-wide entry: fixes w = 128 and forwards everything else unchanged.
+ const uint8_t *const ref[4], int ref_stride,
+ uint32_t res[4], int h) {
+ sadwxhx3d_large_neon(src, src_stride, ref, ref_stride, res, 128, h);
+}
+
+static INLINE void sad64xhx3d_neon(const uint8_t *src, int src_stride,  // Dot-product build, 64-wide entry: fixes w = 64 and forwards everything else unchanged.
+ const uint8_t *const ref[4], int ref_stride,
+ uint32_t res[4], int h) {
+ sadwxhx3d_large_neon(src, src_stride, ref, ref_stride, res, 64, h);
+}
+
+static INLINE void sad32xhx3d_neon(const uint8_t *src, int src_stride,  // Dot-product build, 32-wide entry: fixes w = 32, so the inner column loop runs once per row.
+ const uint8_t *const ref[4], int ref_stride,
+ uint32_t res[4], int h) {
+ sadwxhx3d_large_neon(src, src_stride, ref, ref_stride, res, 32, h);
+}
+
+static INLINE void sad16xhx3d_neon(const uint8_t *src, int src_stride,  // Dot-product build: SAD of one 16 x h src block against three reference blocks.
+ const uint8_t *const ref[4], int ref_stride,  // declared [4], but only ref[0..2] are read here
+ uint32_t res[4], int h) {  // only res[0..2] are written here
+ uint32x4_t sum[3] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0) };  // one 32-bit-lane accumulator per reference
+
+ int ref_offset = 0;  // one running row offset shared by all three ref pointers
+ int i = h;  // rows remaining; the do/while form requires h >= 1
+ do {
+ const uint8x16_t s = vld1q_u8(src);  // the src row is loaded once and reused for all three references
+ sad16_neon(s, vld1q_u8(ref[0] + ref_offset), &sum[0]);
+ sad16_neon(s, vld1q_u8(ref[1] + ref_offset), &sum[1]);
+ sad16_neon(s, vld1q_u8(ref[2] + ref_offset), &sum[2]);
+
+ src += src_stride;
+ ref_offset += ref_stride;
+ } while (--i != 0);
+
+ res[0] = horizontal_add_u32x4(sum[0]);  // reduce the vector to a scalar (helper is defined outside this hunk)
+ res[1] = horizontal_add_u32x4(sum[1]);
+ res[2] = horizontal_add_u32x4(sum[2]);
+}
+
+#else // !(defined(__ARM_FEATURE_DOTPROD))
+
+static INLINE void sad16_neon(uint8x16_t src, uint8x16_t ref,  // Non-dot-product build: adds the absolute differences of 16 byte pairs into *sad_sum.
+ uint16x8_t *const sad_sum) {
+ uint8x16_t abs_diff = vabdq_u8(src, ref);  // per-byte |src - ref|
+ *sad_sum = vpadalq_u8(*sad_sum, abs_diff);  // adjacent byte pairs are summed and accumulated, so each 16-bit lane grows by at most 2 * 255 per call
+}
+
+static INLINE void sadwxhx3d_large_neon(const uint8_t *src, int src_stride,  // Non-dot-product build: SAD of one w x h src block against three reference blocks.
+ const uint8_t *const ref[3],
+ int ref_stride, uint32_t res[3], int w,
+ int h, int h_overflow) {  // h_overflow: number of rows accumulated in 16-bit lanes before widening into the 32-bit totals
+ uint32x4_t sum[3] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0) };  // 32-bit running totals, one per reference
+ int h_limit = h > h_overflow ? h_overflow : h;  // row index at which the first batch ends: min(h, h_overflow)
+
+ int ref_offset = 0;  // one running row offset shared by all three ref pointers
+ int i = 0;  // rows processed so far; deliberately not reset between batches
+ do {
+ uint16x8_t sum_lo[3] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0) };  // 16-bit batch accumulators, re-zeroed for every batch
+ uint16x8_t sum_hi[3] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0) };
+
+ do {
+ int j = 0;
+ do {
+ const uint8x16_t s0 = vld1q_u8(src + j);  // src bytes are loaded once and reused for all three references
+ sad16_neon(s0, vld1q_u8(ref[0] + ref_offset + j), &sum_lo[0]);
+ sad16_neon(s0, vld1q_u8(ref[1] + ref_offset + j), &sum_lo[1]);
+ sad16_neon(s0, vld1q_u8(ref[2] + ref_offset + j), &sum_lo[2]);
+
+ const uint8x16_t s1 = vld1q_u8(src + j + 16);
+ sad16_neon(s1, vld1q_u8(ref[0] + ref_offset + j + 16), &sum_hi[0]);
+ sad16_neon(s1, vld1q_u8(ref[1] + ref_offset + j + 16), &sum_hi[1]);
+ sad16_neon(s1, vld1q_u8(ref[2] + ref_offset + j + 16), &sum_hi[2]);
+
+ j += 32;  // a row is consumed in 32-byte steps, i.e. w / 32 sad16_neon calls per accumulator per row
+ } while (j < w);
+
+ src += src_stride;
+ ref_offset += ref_stride;
+ } while (++i < h_limit);
+
+ sum[0] = vpadalq_u16(sum[0], sum_lo[0]);  // widen the finished batch into the 32-bit totals
+ sum[0] = vpadalq_u16(sum[0], sum_hi[0]);
+ sum[1] = vpadalq_u16(sum[1], sum_lo[1]);
+ sum[1] = vpadalq_u16(sum[1], sum_hi[1]);
+ sum[2] = vpadalq_u16(sum[2], sum_lo[2]);
+ sum[2] = vpadalq_u16(sum[2], sum_hi[2]);
+
+ h_limit += h_overflow;  // NOTE(review): not clamped to h; safe only when h <= h_overflow or h is a multiple of h_overflow, as with the callers in this file
+ } while (i < h);
+
+ res[0] = horizontal_add_u32x4(sum[0]);  // reduce each vector to a scalar (helper is defined outside this hunk)
+ res[1] = horizontal_add_u32x4(sum[1]);
+ res[2] = horizontal_add_u32x4(sum[2]);
+}
+
+static INLINE void sad128xhx3d_neon(const uint8_t *src, int src_stride,  // Non-dot-product build, 128-wide entry.
+ const uint8_t *const ref[3], int ref_stride,
+ uint32_t res[3], int h) {
+ sadwxhx3d_large_neon(src, src_stride, ref, ref_stride, res, 128, h, 32);  // 32-row batches: 4 calls/row * 510 * 32 rows = 65280, which fits a 16-bit lane
+}
+
+static INLINE void sad64xhx3d_neon(const uint8_t *src, int src_stride,  // Non-dot-product build, 64-wide entry.
+ const uint8_t *const ref[3], int ref_stride,
+ uint32_t res[3], int h) {
+ sadwxhx3d_large_neon(src, src_stride, ref, ref_stride, res, 64, h, 64);  // 64-row batches: 2 calls/row * 510 * 64 rows = 65280, which fits a 16-bit lane
+}
+
+static INLINE void sad32xhx3d_neon(const uint8_t *src, int src_stride,  // Non-dot-product build: SAD of one 32 x h src block against three reference blocks.
+ const uint8_t *const ref[3], int ref_stride,
+ uint32_t res[3], int h) {
+ uint16x8_t sum_lo[3] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0) };  // 16-bit accumulators for bytes 0..15 of each row, one per reference
+ uint16x8_t sum_hi[3] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0) };  // 16-bit accumulators for bytes 16..31 of each row
+
+ int ref_offset = 0;  // one running row offset shared by all three ref pointers
+ int i = h;  // rows remaining; requires h >= 1. No mid-loop widening: the tallest 32-wide size instantiated in this file is h = 64, i.e. 64 * 510 = 32640 per lane
+ do {
+ const uint8x16_t s0 = vld1q_u8(src);  // src bytes are loaded once and reused for all three references
+ sad16_neon(s0, vld1q_u8(ref[0] + ref_offset), &sum_lo[0]);
+ sad16_neon(s0, vld1q_u8(ref[1] + ref_offset), &sum_lo[1]);
+ sad16_neon(s0, vld1q_u8(ref[2] + ref_offset), &sum_lo[2]);
+
+ const uint8x16_t s1 = vld1q_u8(src + 16);
+ sad16_neon(s1, vld1q_u8(ref[0] + ref_offset + 16), &sum_hi[0]);
+ sad16_neon(s1, vld1q_u8(ref[1] + ref_offset + 16), &sum_hi[1]);
+ sad16_neon(s1, vld1q_u8(ref[2] + ref_offset + 16), &sum_hi[2]);
+
+ src += src_stride;
+ ref_offset += ref_stride;
+ } while (--i != 0);
+
+ res[0] = horizontal_long_add_u16x8(sum_lo[0], sum_hi[0]);  // helper defined outside this hunk; presumably widens and sums the lanes of both vectors - confirm
+ res[1] = horizontal_long_add_u16x8(sum_lo[1], sum_hi[1]);
+ res[2] = horizontal_long_add_u16x8(sum_lo[2], sum_hi[2]);
+}
+
+static INLINE void sad16xhx3d_neon(const uint8_t *src, int src_stride,  // Non-dot-product build: SAD of one 16 x h src block against three reference blocks.
+ const uint8_t *const ref[3], int ref_stride,
+ uint32_t res[3], int h) {
+ uint16x8_t sum[3] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0) };  // one 16-bit-lane accumulator per reference
+
+ int ref_offset = 0;  // one running row offset shared by all three ref pointers
+ int i = h;  // rows remaining; requires h >= 1. The tallest 16-wide size instantiated in this file is h = 64, i.e. 64 * 510 = 32640 per lane
+ do {
+ const uint8x16_t s = vld1q_u8(src);  // the src row is loaded once and reused for all three references
+ sad16_neon(s, vld1q_u8(ref[0] + ref_offset), &sum[0]);
+ sad16_neon(s, vld1q_u8(ref[1] + ref_offset), &sum[1]);
+ sad16_neon(s, vld1q_u8(ref[2] + ref_offset), &sum[2]);
+
+ src += src_stride;
+ ref_offset += ref_stride;
+ } while (--i != 0);
+
+ res[0] = horizontal_add_u16x8(sum[0]);  // reduce the vector to a scalar (helper is defined outside this hunk)
+ res[1] = horizontal_add_u16x8(sum[1]);
+ res[2] = horizontal_add_u16x8(sum[2]);
+}
+
+#endif // defined(__ARM_FEATURE_DOTPROD)
+
+static INLINE void sad8xhx3d_neon(const uint8_t *src, int src_stride,  // SAD of one 8 x h src block against three reference blocks; compiled for both build variants.
+ const uint8_t *const ref[3], int ref_stride,
+ uint32_t res[3], int h) {
+ uint16x8_t sum[3];  // left uninitialised on purpose: the peeled first row below assigns every element
+
+ uint8x8_t s = vld1_u8(src);
+ sum[0] = vabdl_u8(s, vld1_u8(ref[0]));  // first row: widening absolute difference initialises the accumulators
+ sum[1] = vabdl_u8(s, vld1_u8(ref[1]));
+ sum[2] = vabdl_u8(s, vld1_u8(ref[2]));
+
+ src += src_stride;
+ int ref_offset = ref_stride;  // starts at row 1 because row 0 was handled above
+ int i = h - 1;  // remaining rows; the do/while form requires h >= 2
+ do {
+ s = vld1_u8(src);
+ sum[0] = vabal_u8(sum[0], s, vld1_u8(ref[0] + ref_offset));  // each 16-bit lane grows by at most 255 per row
+ sum[1] = vabal_u8(sum[1], s, vld1_u8(ref[1] + ref_offset));
+ sum[2] = vabal_u8(sum[2], s, vld1_u8(ref[2] + ref_offset));
+
+ src += src_stride;
+ ref_offset += ref_stride;
+ } while (--i != 0);
+
+ res[0] = horizontal_add_u16x8(sum[0]);  // reduce the vector to a scalar (helper is defined outside this hunk)
+ res[1] = horizontal_add_u16x8(sum[1]);
+ res[2] = horizontal_add_u16x8(sum[2]);
+}
+
+static INLINE void sad4xhx3d_neon(const uint8_t *src, int src_stride,  // SAD of one 4 x h src block against three reference blocks, two rows per step.
+ const uint8_t *const ref[3], int ref_stride,
+ uint32_t res[3], int h) {
+ assert(h % 2 == 0);  // rows are consumed in pairs
+ uint16x8_t sum[3];  // left uninitialised on purpose: the peeled first row pair below assigns every element
+
+ uint8x8_t s = load_unaligned_u8(src, src_stride);  // helper defined outside this hunk; presumably packs two 4-byte rows into one vector - confirm
+ uint8x8_t r0 = load_unaligned_u8(ref[0], ref_stride);
+ uint8x8_t r1 = load_unaligned_u8(ref[1], ref_stride);
+ uint8x8_t r2 = load_unaligned_u8(ref[2], ref_stride);
+
+ sum[0] = vabdl_u8(s, r0);  // first row pair: widening absolute difference initialises the accumulators
+ sum[1] = vabdl_u8(s, r1);
+ sum[2] = vabdl_u8(s, r2);
+
+ src += 2 * src_stride;
+ int ref_offset = 2 * ref_stride;  // starts at row 2 because rows 0 and 1 were handled above
+ int i = (h / 2) - 1;  // remaining row pairs; the do/while form requires h >= 4
+ do {
+ s = load_unaligned_u8(src, src_stride);
+ r0 = load_unaligned_u8(ref[0] + ref_offset, ref_stride);
+ r1 = load_unaligned_u8(ref[1] + ref_offset, ref_stride);
+ r2 = load_unaligned_u8(ref[2] + ref_offset, ref_stride);
+
+ sum[0] = vabal_u8(sum[0], s, r0);  // each 16-bit lane grows by at most 255 per row pair
+ sum[1] = vabal_u8(sum[1], s, r1);
+ sum[2] = vabal_u8(sum[2], s, r2);
+
+ src += 2 * src_stride;
+ ref_offset += 2 * ref_stride;
+ } while (--i != 0);
+
+ res[0] = horizontal_add_u16x8(sum[0]);  // reduce the vector to a scalar (helper is defined outside this hunk)
+ res[1] = horizontal_add_u16x8(sum[1]);
+ res[2] = horizontal_add_u16x8(sum[2]);
+}
+
+#define SAD_WXH_3D_NEON(w, h) \
+ void aom_sad##w##x##h##x3d_neon(const uint8_t *src, int src_stride, \
+ const uint8_t *const ref[4], int ref_stride, \
+ uint32_t res[4]) { \
+ sad##w##xhx3d_neon(src, src_stride, ref, ref_stride, res, (h)); \
+ }  // SAD_WXH_3D_NEON(w, h) defines aom_sad<w>x<h>x3d_neon, which forwards to the width-specific sad<w>xhx3d_neon helper with the height fixed to h
+
+SAD_WXH_3D_NEON(4, 4)  // 4-wide sizes
+SAD_WXH_3D_NEON(4, 8)
+
+SAD_WXH_3D_NEON(8, 4)  // 8-wide sizes
+SAD_WXH_3D_NEON(8, 8)
+SAD_WXH_3D_NEON(8, 16)
+
+SAD_WXH_3D_NEON(16, 8)  // 16-wide sizes
+SAD_WXH_3D_NEON(16, 16)
+SAD_WXH_3D_NEON(16, 32)
+
+SAD_WXH_3D_NEON(32, 16)  // 32-wide sizes
+SAD_WXH_3D_NEON(32, 32)
+SAD_WXH_3D_NEON(32, 64)
+
+SAD_WXH_3D_NEON(64, 32)  // 64-wide sizes
+SAD_WXH_3D_NEON(64, 64)
+SAD_WXH_3D_NEON(64, 128)
+
+SAD_WXH_3D_NEON(128, 64)  // 128-wide sizes
+SAD_WXH_3D_NEON(128, 128)
+
+#if !CONFIG_REALTIME_ONLY  // the sizes below are compiled out of realtime-only builds
+SAD_WXH_3D_NEON(4, 16)
+SAD_WXH_3D_NEON(8, 32)
+SAD_WXH_3D_NEON(16, 4)
+SAD_WXH_3D_NEON(16, 64)
+SAD_WXH_3D_NEON(32, 8)
+SAD_WXH_3D_NEON(64, 16)
+#endif // !CONFIG_REALTIME_ONLY
+
+#undef SAD_WXH_3D_NEON  // the generator macro is not needed past this point
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
static INLINE void sadwxhx4d_large_neon(const uint8_t *src, int src_stride,
const uint8_t *const ref[4],
int ref_stride, uint32_t res[4], int w,
@@ -110,12 +411,6 @@
#else // !(defined(__ARM_FEATURE_DOTPROD))
-static INLINE void sad16_neon(uint8x16_t src, uint8x16_t ref,
- uint16x8_t *const sad_sum) {
- uint8x16_t abs_diff = vabdq_u8(src, ref);
- *sad_sum = vpadalq_u8(*sad_sum, abs_diff);
-}
-
static INLINE void sadwxhx4d_large_neon(const uint8_t *src, int src_stride,
const uint8_t *const ref[4],
int ref_stride, uint32_t res[4], int w,
diff --git a/test/sad_test.cc b/test/sad_test.cc
index c10e929..2740305 100644
--- a/test/sad_test.cc
+++ b/test/sad_test.cc
@@ -1958,6 +1958,34 @@
};
INSTANTIATE_TEST_SUITE_P(NEON, SADavgTest, ::testing::ValuesIn(avg_neon_tests));
+const SadMxNx4Param x3d_neon_tests[] = {  // one entry per aom_sad<w>x<h>x3d_neon function: (width, height, function, -1); NOTE(review): meaning of the trailing -1 is not visible in this hunk - confirm against SadMxNx4Param
+ make_tuple(128, 128, &aom_sad128x128x3d_neon, -1),
+ make_tuple(128, 64, &aom_sad128x64x3d_neon, -1),
+ make_tuple(64, 128, &aom_sad64x128x3d_neon, -1),
+ make_tuple(64, 64, &aom_sad64x64x3d_neon, -1),
+ make_tuple(64, 32, &aom_sad64x32x3d_neon, -1),
+ make_tuple(32, 64, &aom_sad32x64x3d_neon, -1),
+ make_tuple(32, 32, &aom_sad32x32x3d_neon, -1),
+ make_tuple(32, 16, &aom_sad32x16x3d_neon, -1),
+ make_tuple(16, 32, &aom_sad16x32x3d_neon, -1),
+ make_tuple(16, 16, &aom_sad16x16x3d_neon, -1),
+ make_tuple(16, 8, &aom_sad16x8x3d_neon, -1),
+ make_tuple(8, 16, &aom_sad8x16x3d_neon, -1),
+ make_tuple(8, 8, &aom_sad8x8x3d_neon, -1),
+ make_tuple(8, 4, &aom_sad8x4x3d_neon, -1),
+ make_tuple(4, 8, &aom_sad4x8x3d_neon, -1),
+ make_tuple(4, 4, &aom_sad4x4x3d_neon, -1),
+#if !CONFIG_REALTIME_ONLY  // these six sizes are the ones the C source also guards with !CONFIG_REALTIME_ONLY
+ make_tuple(64, 16, &aom_sad64x16x3d_neon, -1),
+ make_tuple(32, 8, &aom_sad32x8x3d_neon, -1),
+ make_tuple(16, 64, &aom_sad16x64x3d_neon, -1),
+ make_tuple(16, 4, &aom_sad16x4x3d_neon, -1),
+ make_tuple(8, 32, &aom_sad8x32x3d_neon, -1),
+ make_tuple(4, 16, &aom_sad4x16x3d_neon, -1),
+#endif // !CONFIG_REALTIME_ONLY
+};
+INSTANTIATE_TEST_SUITE_P(NEON, SADx3Test, ::testing::ValuesIn(x3d_neon_tests));  // runs the SADx3Test fixture once per table entry, under the NEON prefix
+
#endif // HAVE_NEON
//------------------------------------------------------------------------------