Add Neon implementation of high bitdepth MSE functions
Add an Armv8.0 and Armv8.4 Neon implementation of high bitdepth MSE
functions for each 8-, 10- and 12-bit data and use it instead of
falling back to the scalar C function.
The implementation of the functions is a backport of this libvpx
change[1].
Add the corresponding tests as well.
[1]https://chromium-review.googlesource.com/c/webm/libvpx/+/4295398
Change-Id: I444e6285bba9e25e2e328b6824978018cc0572f0
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index e8aff91..e02d244 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -1684,28 +1684,34 @@
add_proto qw/unsigned int aom_highbd_8_variance4x4/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse";
add_proto qw/unsigned int aom_highbd_8_mse16x16/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int recon_stride, unsigned int *sse";
- specialize qw/aom_highbd_8_mse16x16 sse2/;
+ specialize qw/aom_highbd_8_mse16x16 sse2 neon/;
add_proto qw/unsigned int aom_highbd_8_mse16x8/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int recon_stride, unsigned int *sse";
+ specialize qw/aom_highbd_8_mse16x8 neon/;
add_proto qw/unsigned int aom_highbd_8_mse8x16/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int recon_stride, unsigned int *sse";
+ specialize qw/aom_highbd_8_mse8x16 neon/;
add_proto qw/unsigned int aom_highbd_8_mse8x8/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int recon_stride, unsigned int *sse";
- specialize qw/aom_highbd_8_mse8x8 sse2/;
+ specialize qw/aom_highbd_8_mse8x8 sse2 neon/;
add_proto qw/unsigned int aom_highbd_10_mse16x16/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int recon_stride, unsigned int *sse";
- specialize qw/aom_highbd_10_mse16x16 sse2/;
+ specialize qw/aom_highbd_10_mse16x16 sse2 neon/;
add_proto qw/unsigned int aom_highbd_10_mse16x8/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int recon_stride, unsigned int *sse";
+ specialize qw/aom_highbd_10_mse16x8 neon/;
add_proto qw/unsigned int aom_highbd_10_mse8x16/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int recon_stride, unsigned int *sse";
+ specialize qw/aom_highbd_10_mse8x16 neon/;
add_proto qw/unsigned int aom_highbd_10_mse8x8/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int recon_stride, unsigned int *sse";
- specialize qw/aom_highbd_10_mse8x8 sse2/;
+ specialize qw/aom_highbd_10_mse8x8 sse2 neon/;
add_proto qw/unsigned int aom_highbd_12_mse16x16/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int recon_stride, unsigned int *sse";
- specialize qw/aom_highbd_12_mse16x16 sse2/;
+ specialize qw/aom_highbd_12_mse16x16 sse2 neon/;
add_proto qw/unsigned int aom_highbd_12_mse16x8/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int recon_stride, unsigned int *sse";
+ specialize qw/aom_highbd_12_mse16x8 neon/;
add_proto qw/unsigned int aom_highbd_12_mse8x16/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int recon_stride, unsigned int *sse";
+ specialize qw/aom_highbd_12_mse8x16 neon/;
add_proto qw/unsigned int aom_highbd_12_mse8x8/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int recon_stride, unsigned int *sse";
- specialize qw/aom_highbd_12_mse8x8 sse2/;
+ specialize qw/aom_highbd_12_mse8x8 sse2 neon/;
add_proto qw/void aom_highbd_comp_avg_pred/, "uint8_t *comp_pred8, const uint8_t *pred8, int width, int height, const uint8_t *ref8, int ref_stride";
diff --git a/aom_dsp/arm/highbd_variance_neon.c b/aom_dsp/arm/highbd_variance_neon.c
index 3c3877a..3b88430 100644
--- a/aom_dsp/arm/highbd_variance_neon.c
+++ b/aom_dsp/arm/highbd_variance_neon.c
@@ -1,4 +1,5 @@
/*
+ * Copyright (c) 2023 The WebM project authors. All Rights Reserved.
* Copyright (c) 2022, Alliance for Open Media. All rights reserved
*
* This source code is subject to the terms of the BSD 2 Clause License and
@@ -169,3 +170,153 @@
#endif // !CONFIG_REALTIME_ONLY
#undef VAR_FN
+
+static INLINE uint32_t highbd_mse_wxh_neon(const uint16_t *src_ptr,
+ int src_stride,
+ const uint16_t *ref_ptr,
+ int ref_stride, int w, int h,
+ unsigned int *sse) {
+ uint32x4_t sse_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
+
+ int i = h;
+ do {
+ int j = 0;
+ do {
+ uint16x8_t s = vld1q_u16(src_ptr + j);
+ uint16x8_t r = vld1q_u16(ref_ptr + j);
+
+ uint16x8_t diff = vabdq_u16(s, r);
+
+ sse_u32[0] =
+ vmlal_u16(sse_u32[0], vget_low_u16(diff), vget_low_u16(diff));
+ sse_u32[1] =
+ vmlal_u16(sse_u32[1], vget_high_u16(diff), vget_high_u16(diff));
+
+ j += 8;
+ } while (j < w);
+
+ src_ptr += src_stride;
+ ref_ptr += ref_stride;
+ } while (--i != 0);
+
+ *sse = horizontal_add_u32x4(vaddq_u32(sse_u32[0], sse_u32[1]));
+ return *sse;
+}
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+static INLINE uint32_t highbd_mse8_8xh_neon(const uint16_t *src_ptr,
+ int src_stride,
+ const uint16_t *ref_ptr,
+ int ref_stride, int h,
+ unsigned int *sse) {
+ uint32x4_t sse_u32 = vdupq_n_u32(0);
+
+ int i = h / 2;
+ do {
+ uint16x8_t s0 = vld1q_u16(src_ptr);
+ src_ptr += src_stride;
+ uint16x8_t s1 = vld1q_u16(src_ptr);
+ src_ptr += src_stride;
+ uint16x8_t r0 = vld1q_u16(ref_ptr);
+ ref_ptr += ref_stride;
+ uint16x8_t r1 = vld1q_u16(ref_ptr);
+ ref_ptr += ref_stride;
+
+ uint8x16_t s = vcombine_u8(vmovn_u16(s0), vmovn_u16(s1));
+ uint8x16_t r = vcombine_u8(vmovn_u16(r0), vmovn_u16(r1));
+
+ uint8x16_t diff = vabdq_u8(s, r);
+ sse_u32 = vdotq_u32(sse_u32, diff, diff);
+ } while (--i != 0);
+
+ *sse = horizontal_add_u32x4(sse_u32);
+ return *sse;
+}
+
+static INLINE uint32_t highbd_mse8_16xh_neon(const uint16_t *src_ptr,
+ int src_stride,
+ const uint16_t *ref_ptr,
+ int ref_stride, int h,
+ unsigned int *sse) {
+ uint32x4_t sse_u32 = vdupq_n_u32(0);
+
+ int i = h;
+ do {
+ uint16x8_t s0 = vld1q_u16(src_ptr);
+ uint16x8_t s1 = vld1q_u16(src_ptr + 8);
+ uint16x8_t r0 = vld1q_u16(ref_ptr);
+ uint16x8_t r1 = vld1q_u16(ref_ptr + 8);
+
+ uint8x16_t s = vcombine_u8(vmovn_u16(s0), vmovn_u16(s1));
+ uint8x16_t r = vcombine_u8(vmovn_u16(r0), vmovn_u16(r1));
+
+ uint8x16_t diff = vabdq_u8(s, r);
+ sse_u32 = vdotq_u32(sse_u32, diff, diff);
+
+ src_ptr += src_stride;
+ ref_ptr += ref_stride;
+ } while (--i != 0);
+
+ *sse = horizontal_add_u32x4(sse_u32);
+ return *sse;
+}
+
+#else // !defined(__ARM_FEATURE_DOTPROD)
+
+static INLINE uint32_t highbd_mse8_8xh_neon(const uint16_t *src_ptr,
+ int src_stride,
+ const uint16_t *ref_ptr,
+ int ref_stride, int h,
+ unsigned int *sse) {
+ return highbd_mse_wxh_neon(src_ptr, src_stride, ref_ptr, ref_stride, 8, h,
+ sse);
+}
+
+static INLINE uint32_t highbd_mse8_16xh_neon(const uint16_t *src_ptr,
+ int src_stride,
+ const uint16_t *ref_ptr,
+ int ref_stride, int h,
+ unsigned int *sse) {
+ return highbd_mse_wxh_neon(src_ptr, src_stride, ref_ptr, ref_stride, 16, h,
+ sse);
+}
+
+#endif // defined(__ARM_FEATURE_DOTPROD)
+
+#define HIGHBD_MSE_WXH_NEON(w, h) \
+ uint32_t aom_highbd_8_mse##w##x##h##_neon( \
+ const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
+ int ref_stride, uint32_t *sse) { \
+ uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr); \
+ uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr); \
+ highbd_mse8_##w##xh_neon(src, src_stride, ref, ref_stride, h, sse); \
+ return *sse; \
+ } \
+ \
+ uint32_t aom_highbd_10_mse##w##x##h##_neon( \
+ const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
+ int ref_stride, uint32_t *sse) { \
+ uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr); \
+ uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr); \
+ highbd_mse_wxh_neon(src, src_stride, ref, ref_stride, w, h, sse); \
+ *sse = ROUND_POWER_OF_TWO(*sse, 4); \
+ return *sse; \
+ } \
+ \
+ uint32_t aom_highbd_12_mse##w##x##h##_neon( \
+ const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
+ int ref_stride, uint32_t *sse) { \
+ uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr); \
+ uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr); \
+ highbd_mse_wxh_neon(src, src_stride, ref, ref_stride, w, h, sse); \
+ *sse = ROUND_POWER_OF_TWO(*sse, 8); \
+ return *sse; \
+ }
+
+HIGHBD_MSE_WXH_NEON(16, 16)
+HIGHBD_MSE_WXH_NEON(16, 8)
+HIGHBD_MSE_WXH_NEON(8, 16)
+HIGHBD_MSE_WXH_NEON(8, 8)
+
+#undef HIGHBD_MSE_WXH_NEON
diff --git a/test/variance_test.cc b/test/variance_test.cc
index c2af7d5..8db54fc 100644
--- a/test/variance_test.cc
+++ b/test/variance_test.cc
@@ -2103,6 +2103,23 @@
MseParams(3, 4, &aom_highbd_8_mse8x16_c, 8),
MseParams(3, 3, &aom_highbd_8_mse8x8_c, 8)));
+#if HAVE_NEON
+INSTANTIATE_TEST_SUITE_P(
+ NEON, AvxHBDMseTest,
+ ::testing::Values(MseParams(4, 4, &aom_highbd_12_mse16x16_neon, 12),
+ MseParams(4, 3, &aom_highbd_12_mse16x8_neon, 12),
+ MseParams(3, 4, &aom_highbd_12_mse8x16_neon, 12),
+ MseParams(3, 3, &aom_highbd_12_mse8x8_neon, 12),
+ MseParams(4, 4, &aom_highbd_10_mse16x16_neon, 10),
+ MseParams(4, 3, &aom_highbd_10_mse16x8_neon, 10),
+ MseParams(3, 4, &aom_highbd_10_mse8x16_neon, 10),
+ MseParams(3, 3, &aom_highbd_10_mse8x8_neon, 10),
+ MseParams(4, 4, &aom_highbd_8_mse16x16_neon, 8),
+ MseParams(4, 3, &aom_highbd_8_mse16x8_neon, 8),
+ MseParams(3, 4, &aom_highbd_8_mse8x16_neon, 8),
+ MseParams(3, 3, &aom_highbd_8_mse8x8_neon, 8)));
+#endif // HAVE_NEON
+
const VarianceParams kArrayHBDVariance_c[] = {
VarianceParams(7, 7, &aom_highbd_12_variance128x128_c, 12),
VarianceParams(7, 6, &aom_highbd_12_variance128x64_c, 12),