Add Neon implementation of high bitdepth MSE functions

Add an Armv8.0 and Armv8.4 Neon implementation of high bitdepth MSE
functions for each 8-, 10- and 12-bit data and use it instead of
falling back to the scalar C function.

The implementation of the functions is a backport of this libvpx
change[1].

Add the corresponding tests as well.

[1]https://chromium-review.googlesource.com/c/webm/libvpx/+/4295398

Change-Id: I444e6285bba9e25e2e328b6824978018cc0572f0
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index e8aff91..e02d244 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -1684,28 +1684,34 @@
     add_proto qw/unsigned int aom_highbd_8_variance4x4/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse";
 
     add_proto qw/unsigned int aom_highbd_8_mse16x16/, "const uint8_t *src_ptr, int  source_stride, const uint8_t *ref_ptr, int  recon_stride, unsigned int *sse";
-    specialize qw/aom_highbd_8_mse16x16 sse2/;
+    specialize qw/aom_highbd_8_mse16x16 sse2 neon/;
 
     add_proto qw/unsigned int aom_highbd_8_mse16x8/, "const uint8_t *src_ptr, int  source_stride, const uint8_t *ref_ptr, int  recon_stride, unsigned int *sse";
+    specialize qw/aom_highbd_8_mse16x8 neon/;
     add_proto qw/unsigned int aom_highbd_8_mse8x16/, "const uint8_t *src_ptr, int  source_stride, const uint8_t *ref_ptr, int  recon_stride, unsigned int *sse";
+    specialize qw/aom_highbd_8_mse8x16 neon/;
     add_proto qw/unsigned int aom_highbd_8_mse8x8/, "const uint8_t *src_ptr, int  source_stride, const uint8_t *ref_ptr, int  recon_stride, unsigned int *sse";
-    specialize qw/aom_highbd_8_mse8x8 sse2/;
+    specialize qw/aom_highbd_8_mse8x8 sse2 neon/;
 
     add_proto qw/unsigned int aom_highbd_10_mse16x16/, "const uint8_t *src_ptr, int  source_stride, const uint8_t *ref_ptr, int  recon_stride, unsigned int *sse";
-    specialize qw/aom_highbd_10_mse16x16 sse2/;
+    specialize qw/aom_highbd_10_mse16x16 sse2 neon/;
 
     add_proto qw/unsigned int aom_highbd_10_mse16x8/, "const uint8_t *src_ptr, int  source_stride, const uint8_t *ref_ptr, int  recon_stride, unsigned int *sse";
+    specialize qw/aom_highbd_10_mse16x8 neon/;
     add_proto qw/unsigned int aom_highbd_10_mse8x16/, "const uint8_t *src_ptr, int  source_stride, const uint8_t *ref_ptr, int  recon_stride, unsigned int *sse";
+    specialize qw/aom_highbd_10_mse8x16 neon/;
     add_proto qw/unsigned int aom_highbd_10_mse8x8/, "const uint8_t *src_ptr, int  source_stride, const uint8_t *ref_ptr, int  recon_stride, unsigned int *sse";
-    specialize qw/aom_highbd_10_mse8x8 sse2/;
+    specialize qw/aom_highbd_10_mse8x8 sse2 neon/;
 
     add_proto qw/unsigned int aom_highbd_12_mse16x16/, "const uint8_t *src_ptr, int  source_stride, const uint8_t *ref_ptr, int  recon_stride, unsigned int *sse";
-    specialize qw/aom_highbd_12_mse16x16 sse2/;
+    specialize qw/aom_highbd_12_mse16x16 sse2 neon/;
 
     add_proto qw/unsigned int aom_highbd_12_mse16x8/, "const uint8_t *src_ptr, int  source_stride, const uint8_t *ref_ptr, int  recon_stride, unsigned int *sse";
+    specialize qw/aom_highbd_12_mse16x8 neon/;
     add_proto qw/unsigned int aom_highbd_12_mse8x16/, "const uint8_t *src_ptr, int  source_stride, const uint8_t *ref_ptr, int  recon_stride, unsigned int *sse";
+    specialize qw/aom_highbd_12_mse8x16 neon/;
     add_proto qw/unsigned int aom_highbd_12_mse8x8/, "const uint8_t *src_ptr, int  source_stride, const uint8_t *ref_ptr, int  recon_stride, unsigned int *sse";
-    specialize qw/aom_highbd_12_mse8x8 sse2/;
+    specialize qw/aom_highbd_12_mse8x8 sse2 neon/;
 
     add_proto qw/void aom_highbd_comp_avg_pred/, "uint8_t *comp_pred8, const uint8_t *pred8, int width, int height, const uint8_t *ref8, int ref_stride";
 
diff --git a/aom_dsp/arm/highbd_variance_neon.c b/aom_dsp/arm/highbd_variance_neon.c
index 3c3877a..3b88430 100644
--- a/aom_dsp/arm/highbd_variance_neon.c
+++ b/aom_dsp/arm/highbd_variance_neon.c
@@ -1,4 +1,5 @@
 /*
+ * Copyright (c) 2023 The WebM project authors. All Rights Reserved.
  * Copyright (c) 2022, Alliance for Open Media. All rights reserved
  *
  * This source code is subject to the terms of the BSD 2 Clause License and
@@ -169,3 +170,153 @@
 #endif  // !CONFIG_REALTIME_ONLY
 
 #undef VAR_FN
+
+static INLINE uint32_t highbd_mse_wxh_neon(const uint16_t *src_ptr,
+                                           int src_stride,
+                                           const uint16_t *ref_ptr,
+                                           int ref_stride, int w, int h,
+                                           unsigned int *sse) {
+  uint32x4_t sse_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
+
+  int i = h;
+  do {
+    int j = 0;
+    do {
+      uint16x8_t s = vld1q_u16(src_ptr + j);
+      uint16x8_t r = vld1q_u16(ref_ptr + j);
+
+      uint16x8_t diff = vabdq_u16(s, r);
+
+      sse_u32[0] =
+          vmlal_u16(sse_u32[0], vget_low_u16(diff), vget_low_u16(diff));
+      sse_u32[1] =
+          vmlal_u16(sse_u32[1], vget_high_u16(diff), vget_high_u16(diff));
+
+      j += 8;
+    } while (j < w);
+
+    src_ptr += src_stride;
+    ref_ptr += ref_stride;
+  } while (--i != 0);
+
+  *sse = horizontal_add_u32x4(vaddq_u32(sse_u32[0], sse_u32[1]));
+  return *sse;
+}
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+static INLINE uint32_t highbd_mse8_8xh_neon(const uint16_t *src_ptr,
+                                            int src_stride,
+                                            const uint16_t *ref_ptr,
+                                            int ref_stride, int h,
+                                            unsigned int *sse) {
+  uint32x4_t sse_u32 = vdupq_n_u32(0);
+
+  int i = h / 2;
+  do {
+    uint16x8_t s0 = vld1q_u16(src_ptr);
+    src_ptr += src_stride;
+    uint16x8_t s1 = vld1q_u16(src_ptr);
+    src_ptr += src_stride;
+    uint16x8_t r0 = vld1q_u16(ref_ptr);
+    ref_ptr += ref_stride;
+    uint16x8_t r1 = vld1q_u16(ref_ptr);
+    ref_ptr += ref_stride;
+
+    uint8x16_t s = vcombine_u8(vmovn_u16(s0), vmovn_u16(s1));
+    uint8x16_t r = vcombine_u8(vmovn_u16(r0), vmovn_u16(r1));
+
+    uint8x16_t diff = vabdq_u8(s, r);
+    sse_u32 = vdotq_u32(sse_u32, diff, diff);
+  } while (--i != 0);
+
+  *sse = horizontal_add_u32x4(sse_u32);
+  return *sse;
+}
+
+static INLINE uint32_t highbd_mse8_16xh_neon(const uint16_t *src_ptr,
+                                             int src_stride,
+                                             const uint16_t *ref_ptr,
+                                             int ref_stride, int h,
+                                             unsigned int *sse) {
+  uint32x4_t sse_u32 = vdupq_n_u32(0);
+
+  int i = h;
+  do {
+    uint16x8_t s0 = vld1q_u16(src_ptr);
+    uint16x8_t s1 = vld1q_u16(src_ptr + 8);
+    uint16x8_t r0 = vld1q_u16(ref_ptr);
+    uint16x8_t r1 = vld1q_u16(ref_ptr + 8);
+
+    uint8x16_t s = vcombine_u8(vmovn_u16(s0), vmovn_u16(s1));
+    uint8x16_t r = vcombine_u8(vmovn_u16(r0), vmovn_u16(r1));
+
+    uint8x16_t diff = vabdq_u8(s, r);
+    sse_u32 = vdotq_u32(sse_u32, diff, diff);
+
+    src_ptr += src_stride;
+    ref_ptr += ref_stride;
+  } while (--i != 0);
+
+  *sse = horizontal_add_u32x4(sse_u32);
+  return *sse;
+}
+
+#else  // !defined(__ARM_FEATURE_DOTPROD)
+
+static INLINE uint32_t highbd_mse8_8xh_neon(const uint16_t *src_ptr,
+                                            int src_stride,
+                                            const uint16_t *ref_ptr,
+                                            int ref_stride, int h,
+                                            unsigned int *sse) {
+  return highbd_mse_wxh_neon(src_ptr, src_stride, ref_ptr, ref_stride, 8, h,
+                             sse);
+}
+
+static INLINE uint32_t highbd_mse8_16xh_neon(const uint16_t *src_ptr,
+                                             int src_stride,
+                                             const uint16_t *ref_ptr,
+                                             int ref_stride, int h,
+                                             unsigned int *sse) {
+  return highbd_mse_wxh_neon(src_ptr, src_stride, ref_ptr, ref_stride, 16, h,
+                             sse);
+}
+
+#endif  // defined(__ARM_FEATURE_DOTPROD)
+
+#define HIGHBD_MSE_WXH_NEON(w, h)                                       \
+  uint32_t aom_highbd_8_mse##w##x##h##_neon(                            \
+      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,   \
+      int ref_stride, uint32_t *sse) {                                  \
+    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                       \
+    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                       \
+    highbd_mse8_##w##xh_neon(src, src_stride, ref, ref_stride, h, sse); \
+    return *sse;                                                        \
+  }                                                                     \
+                                                                        \
+  uint32_t aom_highbd_10_mse##w##x##h##_neon(                           \
+      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,   \
+      int ref_stride, uint32_t *sse) {                                  \
+    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                       \
+    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                       \
+    highbd_mse_wxh_neon(src, src_stride, ref, ref_stride, w, h, sse);   \
+    *sse = ROUND_POWER_OF_TWO(*sse, 4);                                 \
+    return *sse;                                                        \
+  }                                                                     \
+                                                                        \
+  uint32_t aom_highbd_12_mse##w##x##h##_neon(                           \
+      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,   \
+      int ref_stride, uint32_t *sse) {                                  \
+    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                       \
+    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                       \
+    highbd_mse_wxh_neon(src, src_stride, ref, ref_stride, w, h, sse);   \
+    *sse = ROUND_POWER_OF_TWO(*sse, 8);                                 \
+    return *sse;                                                        \
+  }
+
+HIGHBD_MSE_WXH_NEON(16, 16)
+HIGHBD_MSE_WXH_NEON(16, 8)
+HIGHBD_MSE_WXH_NEON(8, 16)
+HIGHBD_MSE_WXH_NEON(8, 8)
+
+#undef HIGHBD_MSE_WXH_NEON
diff --git a/test/variance_test.cc b/test/variance_test.cc
index c2af7d5..8db54fc 100644
--- a/test/variance_test.cc
+++ b/test/variance_test.cc
@@ -2103,6 +2103,23 @@
                       MseParams(3, 4, &aom_highbd_8_mse8x16_c, 8),
                       MseParams(3, 3, &aom_highbd_8_mse8x8_c, 8)));
 
+#if HAVE_NEON
+INSTANTIATE_TEST_SUITE_P(
+    NEON, AvxHBDMseTest,
+    ::testing::Values(MseParams(4, 4, &aom_highbd_12_mse16x16_neon, 12),
+                      MseParams(4, 3, &aom_highbd_12_mse16x8_neon, 12),
+                      MseParams(3, 4, &aom_highbd_12_mse8x16_neon, 12),
+                      MseParams(3, 3, &aom_highbd_12_mse8x8_neon, 12),
+                      MseParams(4, 4, &aom_highbd_10_mse16x16_neon, 10),
+                      MseParams(4, 3, &aom_highbd_10_mse16x8_neon, 10),
+                      MseParams(3, 4, &aom_highbd_10_mse8x16_neon, 10),
+                      MseParams(3, 3, &aom_highbd_10_mse8x8_neon, 10),
+                      MseParams(4, 4, &aom_highbd_8_mse16x16_neon, 8),
+                      MseParams(4, 3, &aom_highbd_8_mse16x8_neon, 8),
+                      MseParams(3, 4, &aom_highbd_8_mse8x16_neon, 8),
+                      MseParams(3, 3, &aom_highbd_8_mse8x8_neon, 8)));
+#endif  // HAVE_NEON
+
 const VarianceParams kArrayHBDVariance_c[] = {
   VarianceParams(7, 7, &aom_highbd_12_variance128x128_c, 12),
   VarianceParams(7, 6, &aom_highbd_12_variance128x64_c, 12),