Implement aom_sse_neon helpers using UDOT instruction

Accelerate the aom_sse_neon helper functions using the Armv8.4-A UDOT
instruction, which multiplies and accumulates a widening 8-bit dot
product in a single instruction, instead of sequences of UMULL and
UADALP instructions.

The previous implementation is retained for use on target CPUs that do
not support the Armv8.4-A dot-product instructions.
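
A minimal standalone sketch of the idea (illustrative only, not part of
this change; assumes an AArch64 compiler with the dot-product extension
enabled, e.g. -march=armv8.2-a+dotprod). UDOT sums the products of each
group of four unsigned bytes into a 32-bit lane, so dotting the
absolute-difference vector with itself yields the sum of squared
differences directly:

  #include <arm_neon.h>
  #include <stdint.h>
  #include <stdio.h>

  static uint32_t sse_16_udot(const uint8_t *src, const uint8_t *ref) {
    uint8x16_t s = vld1q_u8(src);
    uint8x16_t r = vld1q_u8(ref);
    uint8x16_t abs_diff = vabdq_u8(s, r);
    /* One UDOT replaces the UMULL/UMULL2 and UADALP pairs needed per
     * 16 bytes in the pre-Armv8.4-A code path. */
    uint32x4_t sse = vdotq_u32(vdupq_n_u32(0), abs_diff, abs_diff);
    return vaddvq_u32(sse);
  }

  int main(void) {
    uint8_t src[16], ref[16];
    uint32_t expect = 0;
    for (int i = 0; i < 16; i++) {
      src[i] = (uint8_t)(3 * i);
      ref[i] = (uint8_t)(5 * i + 7);
      int d = (int)src[i] - (int)ref[i];
      expect += (uint32_t)(d * d);
    }
    printf("udot: %u scalar: %u\n", sse_16_udot(src, ref), expect);
    return 0;
  }
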
Change-Id: Ia0c14a4bb5d95f68988d9d5d848f93b81bf7fa87
diff --git a/aom_dsp/arm/sse_neon.c b/aom_dsp/arm/sse_neon.c
index 767a7e9..4370146 100644
--- a/aom_dsp/arm/sse_neon.c
+++ b/aom_dsp/arm/sse_neon.c
@@ -16,6 +16,115 @@
#include "aom_dsp/arm/sum_neon.h"
#include "aom_dsp/arm/transpose_neon.h"
+#if defined(__ARM_FEATURE_DOTPROD)
+
+static INLINE void sse_16x1_neon(const uint8_t *src, const uint8_t *ref,
+ uint32x4_t *sse) {
+ uint8x16_t s = vld1q_u8(src);
+ uint8x16_t r = vld1q_u8(ref);
+
+ uint8x16_t abs_diff = vabdq_u8(s, r);
+
+ *sse = vdotq_u32(*sse, abs_diff, abs_diff);
+}
+
+static INLINE void sse_8x1_neon(const uint8_t *src, const uint8_t *ref,
+ uint32x2_t *sse) {
+ uint8x8_t s = vld1_u8(src);
+ uint8x8_t r = vld1_u8(ref);
+
+ uint8x8_t abs_diff = vabd_u8(s, r);
+
+ *sse = vdot_u32(*sse, abs_diff, abs_diff);
+}
+
+static INLINE void sse_4x2_neon(const uint8_t *src, int src_stride,
+ const uint8_t *ref, int ref_stride,
+ uint32x2_t *sse) {
+ uint8x8_t s = load_unaligned_u8(src, src_stride);
+ uint8x8_t r = load_unaligned_u8(ref, ref_stride);
+
+ uint8x8_t abs_diff = vabd_u8(s, r);
+
+ *sse = vdot_u32(*sse, abs_diff, abs_diff);
+}
+
+static INLINE uint32_t sse_8xh_neon(const uint8_t *src, int src_stride,
+ const uint8_t *ref, int ref_stride,
+ int height) {
+ uint32x2_t sse[2] = { vdup_n_u32(0), vdup_n_u32(0) };
+
+ int i = 0;
+ do {
+ sse_8x1_neon(src, ref, &sse[0]);
+ src += src_stride;
+ ref += ref_stride;
+ sse_8x1_neon(src, ref, &sse[1]);
+ src += src_stride;
+ ref += ref_stride;
+ i += 2;
+ } while (i < height);
+
+ return horizontal_add_u32x4(vcombine_u32(sse[0], sse[1]));
+}
+
+static INLINE uint32_t sse_4xh_neon(const uint8_t *src, int src_stride,
+ const uint8_t *ref, int ref_stride,
+ int height) {
+ uint32x2_t sse = vdup_n_u32(0);
+
+ int i = 0;
+ do {
+ sse_4x2_neon(src, src_stride, ref, ref_stride, &sse);
+
+ src += 2 * src_stride;
+ ref += 2 * ref_stride;
+ i += 2;
+ } while (i < height);
+
+ return horizontal_add_u32x2(sse);
+}
+
+static INLINE uint32_t sse_wxh_neon(const uint8_t *src, int src_stride,
+ const uint8_t *ref, int ref_stride,
+ int width, int height) {
+ uint32x2_t sse[2] = { vdup_n_u32(0), vdup_n_u32(0) };
+
+ if ((width & 0x07) && ((width & 0x07) < 5)) {
+ int i = 0;
+ do {
+ int j = 0;
+ do {
+ sse_8x1_neon(src + j, ref + j, &sse[0]);
+ sse_8x1_neon(src + j + src_stride, ref + j + ref_stride, &sse[1]);
+ j += 8;
+ } while (j + 4 < width);
+
+ sse_4x2_neon(src + j, src_stride, ref + j, ref_stride, &sse[0]);
+ src += 2 * src_stride;
+ ref += 2 * ref_stride;
+ i += 2;
+ } while (i < height);
+ } else {
+ int i = 0;
+ do {
+ int j = 0;
+ do {
+ sse_8x1_neon(src + j, ref + j, &sse[0]);
+ sse_8x1_neon(src + j + src_stride, ref + j + ref_stride, &sse[1]);
+ j += 8;
+ } while (j < width);
+
+ src += 2 * src_stride;
+ ref += 2 * ref_stride;
+ i += 2;
+ } while (i < height);
+ }
+ return horizontal_add_u32x4(vcombine_u32(sse[0], sse[1]));
+}
+
+#else // !defined(__ARM_FEATURE_DOTPROD)
+
static INLINE void sse_16x1_neon(const uint8_t *src, const uint8_t *ref,
uint32x4_t *sse) {
uint8x16_t s = vld1q_u8(src);
@@ -50,85 +159,6 @@
*sse = vpadalq_u16(*sse, vmull_u8(abs_diff, abs_diff));
}
-static INLINE uint32_t sse_128xh_neon(const uint8_t *src, int src_stride,
- const uint8_t *ref, int ref_stride,
- int height) {
- uint32x4_t sse = vdupq_n_u32(0);
-
- int i = 0;
- do {
- sse_16x1_neon(src, ref, &sse);
- sse_16x1_neon(src + 16, ref + 16, &sse);
- sse_16x1_neon(src + 32, ref + 32, &sse);
- sse_16x1_neon(src + 48, ref + 48, &sse);
- sse_16x1_neon(src + 64, ref + 64, &sse);
- sse_16x1_neon(src + 80, ref + 80, &sse);
- sse_16x1_neon(src + 96, ref + 96, &sse);
- sse_16x1_neon(src + 112, ref + 112, &sse);
-
- src += src_stride;
- ref += ref_stride;
- i++;
- } while (i < height);
-
- return horizontal_add_u32x4(sse);
-}
-
-static INLINE uint32_t sse_64xh_neon(const uint8_t *src, int src_stride,
- const uint8_t *ref, int ref_stride,
- int height) {
- uint32x4_t sse = vdupq_n_u32(0);
-
- int i = 0;
- do {
- sse_16x1_neon(src, ref, &sse);
- sse_16x1_neon(src + 16, ref + 16, &sse);
- sse_16x1_neon(src + 32, ref + 32, &sse);
- sse_16x1_neon(src + 48, ref + 48, &sse);
-
- src += src_stride;
- ref += ref_stride;
- i++;
- } while (i < height);
-
- return horizontal_add_u32x4(sse);
-}
-
-static INLINE uint32_t sse_32xh_neon(const uint8_t *src, int src_stride,
- const uint8_t *ref, int ref_stride,
- int height) {
- uint32x4_t sse = vdupq_n_u32(0);
-
- int i = 0;
- do {
- sse_16x1_neon(src, ref, &sse);
- sse_16x1_neon(src + 16, ref + 16, &sse);
-
- src += src_stride;
- ref += ref_stride;
- i++;
- } while (i < height);
-
- return horizontal_add_u32x4(sse);
-}
-
-static INLINE uint32_t sse_16xh_neon(const uint8_t *src, int src_stride,
- const uint8_t *ref, int ref_stride,
- int height) {
- uint32x4_t sse = vdupq_n_u32(0);
-
- int i = 0;
- do {
- sse_16x1_neon(src, ref, &sse);
-
- src += src_stride;
- ref += ref_stride;
- i++;
- } while (i < height);
-
- return horizontal_add_u32x4(sse);
-}
-
static INLINE uint32_t sse_8xh_neon(const uint8_t *src, int src_stride,
const uint8_t *ref, int ref_stride,
int height) {
@@ -200,6 +230,89 @@
return horizontal_add_u32x4(sse);
}
+#endif // defined(__ARM_FEATURE_DOTPROD)
+
+static INLINE uint32_t sse_128xh_neon(const uint8_t *src, int src_stride,
+ const uint8_t *ref, int ref_stride,
+ int height) {
+ uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
+
+ int i = 0;
+ do {
+ sse_16x1_neon(src, ref, &sse[0]);
+ sse_16x1_neon(src + 16, ref + 16, &sse[1]);
+ sse_16x1_neon(src + 32, ref + 32, &sse[0]);
+ sse_16x1_neon(src + 48, ref + 48, &sse[1]);
+ sse_16x1_neon(src + 64, ref + 64, &sse[0]);
+ sse_16x1_neon(src + 80, ref + 80, &sse[1]);
+ sse_16x1_neon(src + 96, ref + 96, &sse[0]);
+ sse_16x1_neon(src + 112, ref + 112, &sse[1]);
+
+ src += src_stride;
+ ref += ref_stride;
+ i++;
+ } while (i < height);
+
+ return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1]));
+}
+
+static INLINE uint32_t sse_64xh_neon(const uint8_t *src, int src_stride,
+ const uint8_t *ref, int ref_stride,
+ int height) {
+ uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
+
+ int i = 0;
+ do {
+ sse_16x1_neon(src, ref, &sse[0]);
+ sse_16x1_neon(src + 16, ref + 16, &sse[1]);
+ sse_16x1_neon(src + 32, ref + 32, &sse[0]);
+ sse_16x1_neon(src + 48, ref + 48, &sse[1]);
+
+ src += src_stride;
+ ref += ref_stride;
+ i++;
+ } while (i < height);
+
+ return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1]));
+}
+
+static INLINE uint32_t sse_32xh_neon(const uint8_t *src, int src_stride,
+ const uint8_t *ref, int ref_stride,
+ int height) {
+ uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
+
+ int i = 0;
+ do {
+ sse_16x1_neon(src, ref, &sse[0]);
+ sse_16x1_neon(src + 16, ref + 16, &sse[1]);
+
+ src += src_stride;
+ ref += ref_stride;
+ i++;
+ } while (i < height);
+
+ return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1]));
+}
+
+static INLINE uint32_t sse_16xh_neon(const uint8_t *src, int src_stride,
+ const uint8_t *ref, int ref_stride,
+ int height) {
+ uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
+
+ int i = 0;
+ do {
+ sse_16x1_neon(src, ref, &sse[0]);
+ src += src_stride;
+ ref += ref_stride;
+ sse_16x1_neon(src, ref, &sse[1]);
+ src += src_stride;
+ ref += ref_stride;
+ i += 2;
+ } while (i < height);
+
+ return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1]));
+}
+
int64_t aom_sse_neon(const uint8_t *src, int src_stride, const uint8_t *ref,
int ref_stride, int width, int height) {
switch (width) {
diff --git a/aom_dsp/arm/sum_neon.h b/aom_dsp/arm/sum_neon.h
index 0cf110a..855edf6 100644
--- a/aom_dsp/arm/sum_neon.h
+++ b/aom_dsp/arm/sum_neon.h
@@ -94,6 +94,15 @@
#endif
}
+static INLINE uint32_t horizontal_add_u32x2(const uint32x2_t a) {
+#if defined(__aarch64__)
+ return vaddv_u32(a);
+#else
+ const uint64x1_t b = vpaddl_u32(a);
+ return vget_lane_u32(vreinterpret_u32_u64(b), 0);
+#endif
+}
+
static INLINE uint32_t horizontal_add_u16x4(const uint16x4_t a) {
#if defined(__aarch64__)
return vaddlv_u16(a);
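
For reference, the two paths of the new horizontal_add_u32x2 helper
agree modulo 2^32: vpaddl_u32 widens and adds the two 32-bit lanes into
one 64-bit lane, and its low 32 bits match the wrapping sum that
vaddv_u32 returns on AArch64. A minimal standalone check (illustrative
only, not part of this change):

  #include <arm_neon.h>
  #include <stdint.h>
  #include <stdio.h>

  static uint32_t horizontal_add_u32x2_fallback(uint32x2_t a) {
    /* AArch32 path: pairwise widening add, then keep the low 32 bits. */
    const uint64x1_t b = vpaddl_u32(a);
    return vget_lane_u32(vreinterpret_u32_u64(b), 0);
  }

  int main(void) {
    const uint32_t vals[2] = { 0x80000000u, 0x80000001u };
    /* 0x80000000 + 0x80000001 wraps to 1 in 32 bits. */
    printf("%u\n", horizontal_add_u32x2_fallback(vld1_u32(vals)));
    return 0;
  }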