Refactor aom_sse_neon to use width-specific helper functions
Refactor Neon SSE function to use width-specific helper functions of
the form sse_<w>xh_neon. A faster Armv8.4-A dot-product
implementation of each helper function will be added in a subsequent
patch.
Change-Id: I5cca24099d709b9dddf39772e94aff83562f8f9a
diff --git a/aom_dsp/arm/sse_neon.c b/aom_dsp/arm/sse_neon.c
index a69dfb5..767a7e9 100644
--- a/aom_dsp/arm/sse_neon.c
+++ b/aom_dsp/arm/sse_neon.c
@@ -16,141 +16,202 @@
#include "aom_dsp/arm/sum_neon.h"
#include "aom_dsp/arm/transpose_neon.h"
-static INLINE void sse_w16_neon(uint32x4_t *sum, const uint8_t *a,
- const uint8_t *b) {
- const uint8x16_t v_a0 = vld1q_u8(a);
- const uint8x16_t v_b0 = vld1q_u8(b);
- const uint8x16_t diff = vabdq_u8(v_a0, v_b0);
- const uint8x8_t diff_lo = vget_low_u8(diff);
- const uint8x8_t diff_hi = vget_high_u8(diff);
- *sum = vpadalq_u16(*sum, vmull_u8(diff_lo, diff_lo));
- *sum = vpadalq_u16(*sum, vmull_u8(diff_hi, diff_hi));
+static INLINE void sse_16x1_neon(const uint8_t *src, const uint8_t *ref,
+ uint32x4_t *sse) {  // accumulates SSE of one 16-pixel row into *sse
+ uint8x16_t s = vld1q_u8(src);  // 16 source pixels
+ uint8x16_t r = vld1q_u8(ref);  // 16 reference pixels
+
+ uint8x16_t abs_diff = vabdq_u8(s, r);  // per-byte |src - ref|
+ uint8x8_t abs_diff_lo = vget_low_u8(abs_diff);
+ uint8x8_t abs_diff_hi = vget_high_u8(abs_diff);
+
+ *sse = vpadalq_u16(*sse, vmull_u8(abs_diff_lo, abs_diff_lo));  // square, widen, accumulate
+ *sse = vpadalq_u16(*sse, vmull_u8(abs_diff_hi, abs_diff_hi));
}
-static INLINE void aom_sse4x2_neon(const uint8_t *a, int a_stride,
- const uint8_t *b, int b_stride,
- uint32x4_t *sum) {
- uint8x8_t v_a0, v_b0;
- v_a0 = v_b0 = vcreate_u8(0);
- // above line is only to shadow [-Werror=uninitialized]
- v_a0 = vreinterpret_u8_u32(
- vld1_lane_u32((uint32_t *)a, vreinterpret_u32_u8(v_a0), 0));
- v_a0 = vreinterpret_u8_u32(
- vld1_lane_u32((uint32_t *)(a + a_stride), vreinterpret_u32_u8(v_a0), 1));
- v_b0 = vreinterpret_u8_u32(
- vld1_lane_u32((uint32_t *)b, vreinterpret_u32_u8(v_b0), 0));
- v_b0 = vreinterpret_u8_u32(
- vld1_lane_u32((uint32_t *)(b + b_stride), vreinterpret_u32_u8(v_b0), 1));
- const uint8x8_t v_a_w = vabd_u8(v_a0, v_b0);
- *sum = vpadalq_u16(*sum, vmull_u8(v_a_w, v_a_w));
+
+static INLINE void sse_8x1_neon(const uint8_t *src, const uint8_t *ref,
+ uint32x4_t *sse) {  // accumulates SSE of one 8-pixel row into *sse
+ uint8x8_t s = vld1_u8(src);
+ uint8x8_t r = vld1_u8(ref);
+
+ uint8x8_t abs_diff = vabd_u8(s, r);  // per-byte |src - ref|
+
+ *sse = vpadalq_u16(*sse, vmull_u8(abs_diff, abs_diff));  // square, widen, accumulate
}
-static INLINE void aom_sse8_neon(const uint8_t *a, const uint8_t *b,
- uint32x4_t *sum) {
- const uint8x8_t v_a_w = vld1_u8(a);
- const uint8x8_t v_b_w = vld1_u8(b);
- const uint8x8_t v_d_w = vabd_u8(v_a_w, v_b_w);
- *sum = vpadalq_u16(*sum, vmull_u8(v_d_w, v_d_w));
+
+static INLINE void sse_4x2_neon(const uint8_t *src, int src_stride,
+ const uint8_t *ref, int ref_stride,
+ uint32x4_t *sse) {  // accumulates SSE of a 4x2 region into *sse
+ uint8x8_t s = load_unaligned_u8(src, src_stride);  // presumably packs two 4-px rows into one d-reg; see mem_neon.h
+ uint8x8_t r = load_unaligned_u8(ref, ref_stride);
+
+ uint8x8_t abs_diff = vabd_u8(s, r);
+
+ *sse = vpadalq_u16(*sse, vmull_u8(abs_diff, abs_diff));
}
-int64_t aom_sse_neon(const uint8_t *a, int a_stride, const uint8_t *b,
- int b_stride, int width, int height) {
- int y = 0;
- int64_t sse = 0;
- uint32x4_t sum = vdupq_n_u32(0);
- switch (width) {
- case 4:
+
+static INLINE uint32_t sse_128xh_neon(const uint8_t *src, int src_stride,
+ const uint8_t *ref, int ref_stride,
+ int height) {
+ uint32x4_t sse = vdupq_n_u32(0);
+
+ int i = 0;
+ do {
+ sse_16x1_neon(src, ref, &sse);  // 8 chunks of 16 px = one 128-px row
+ sse_16x1_neon(src + 16, ref + 16, &sse);
+ sse_16x1_neon(src + 32, ref + 32, &sse);
+ sse_16x1_neon(src + 48, ref + 48, &sse);
+ sse_16x1_neon(src + 64, ref + 64, &sse);
+ sse_16x1_neon(src + 80, ref + 80, &sse);
+ sse_16x1_neon(src + 96, ref + 96, &sse);
+ sse_16x1_neon(src + 112, ref + 112, &sse);
+
+ src += src_stride;
+ ref += ref_stride;
+ i++;
+ } while (i < height);
+
+ return horizontal_add_u32x4(sse);  // reduce 4 lanes to scalar SSE
+}
+
+static INLINE uint32_t sse_64xh_neon(const uint8_t *src, int src_stride,
+ const uint8_t *ref, int ref_stride,
+ int height) {
+ uint32x4_t sse = vdupq_n_u32(0);
+
+ int i = 0;
+ do {
+ sse_16x1_neon(src, ref, &sse);  // 4 chunks of 16 px = one 64-px row
+ sse_16x1_neon(src + 16, ref + 16, &sse);
+ sse_16x1_neon(src + 32, ref + 32, &sse);
+ sse_16x1_neon(src + 48, ref + 48, &sse);
+
+ src += src_stride;
+ ref += ref_stride;
+ i++;
+ } while (i < height);
+
+ return horizontal_add_u32x4(sse);  // reduce 4 lanes to scalar SSE
+}
+
+static INLINE uint32_t sse_32xh_neon(const uint8_t *src, int src_stride,
+ const uint8_t *ref, int ref_stride,
+ int height) {
+ uint32x4_t sse = vdupq_n_u32(0);
+
+ int i = 0;
+ do {
+ sse_16x1_neon(src, ref, &sse);  // 2 chunks of 16 px = one 32-px row
+ sse_16x1_neon(src + 16, ref + 16, &sse);
+
+ src += src_stride;
+ ref += ref_stride;
+ i++;
+ } while (i < height);
+
+ return horizontal_add_u32x4(sse);  // reduce 4 lanes to scalar SSE
+}
+
+static INLINE uint32_t sse_16xh_neon(const uint8_t *src, int src_stride,
+ const uint8_t *ref, int ref_stride,
+ int height) {
+ uint32x4_t sse = vdupq_n_u32(0);
+
+ int i = 0;
+ do {
+ sse_16x1_neon(src, ref, &sse);  // one 16-px row per iteration
+
+ src += src_stride;
+ ref += ref_stride;
+ i++;
+ } while (i < height);
+
+ return horizontal_add_u32x4(sse);  // reduce 4 lanes to scalar SSE
+}
+
+static INLINE uint32_t sse_8xh_neon(const uint8_t *src, int src_stride,
+ const uint8_t *ref, int ref_stride,
+ int height) {
+ uint32x4_t sse = vdupq_n_u32(0);
+
+ int i = 0;
+ do {
+ sse_8x1_neon(src, ref, &sse);  // one 8-px row per iteration
+
+ src += src_stride;
+ ref += ref_stride;
+ i++;
+ } while (i < height);
+
+ return horizontal_add_u32x4(sse);  // reduce 4 lanes to scalar SSE
+}
+
+static INLINE uint32_t sse_4xh_neon(const uint8_t *src, int src_stride,
+ const uint8_t *ref, int ref_stride,
+ int height) {
+ uint32x4_t sse = vdupq_n_u32(0);
+
+ int i = 0;
+ do {
+ sse_4x2_neon(src, src_stride, ref, ref_stride, &sse);  // two 4-px rows per iteration
+
+ src += 2 * src_stride;
+ ref += 2 * ref_stride;
+ i += 2;  // step of 2: height assumed even (matches the pre-refactor case 4 path)
+ } while (i < height);
+
+ return horizontal_add_u32x4(sse);  // reduce 4 lanes to scalar SSE
+}
+
+static INLINE uint32_t sse_wxh_neon(const uint8_t *src, int src_stride,
+ const uint8_t *ref, int ref_stride,
+ int width, int height) {
+ uint32x4_t sse = vdupq_n_u32(0);
+
+ if ((width & 0x07) && ((width & 0x07) < 5)) {  // width % 8 in 1..4: needs a 4-wide tail
+ int i = 0;
+ do {
+ int j = 0;
do {
- aom_sse4x2_neon(a, a_stride, b, b_stride, &sum);
- a += a_stride << 1;
- b += b_stride << 1;
- y += 2;
- } while (y < height);
- sse = horizontal_add_s32x4(vreinterpretq_s32_u32(sum));
- break;
- case 8:
+ sse_8x1_neon(src + j, ref + j, &sse);
+ sse_8x1_neon(src + j + src_stride, ref + j + ref_stride, &sse);
+ j += 8;
+ } while (j + 4 < width);  // stop 4 px short; tail handled below
+
+ sse_4x2_neon(src + j, src_stride, ref + j, ref_stride, &sse);  // 4-wide tail, two rows at once
+ src += 2 * src_stride;
+ ref += 2 * ref_stride;
+ i += 2;  // step of 2: height assumed even in this branch
+ } while (i < height);
+ } else {
+ int i = 0;
+ do {
+ int j = 0;
do {
- aom_sse8_neon(a, b, &sum);
- a += a_stride;
- b += b_stride;
- y += 1;
- } while (y < height);
- sse = horizontal_add_s32x4(vreinterpretq_s32_u32(sum));
- break;
- case 16:
- do {
- sse_w16_neon(&sum, a, b);
- a += a_stride;
- b += b_stride;
- y += 1;
- } while (y < height);
- sse = horizontal_add_s32x4(vreinterpretq_s32_u32(sum));
- break;
- case 32:
- do {
- sse_w16_neon(&sum, a, b);
- sse_w16_neon(&sum, a + 16, b + 16);
- a += a_stride;
- b += b_stride;
- y += 1;
- } while (y < height);
- sse = horizontal_add_s32x4(vreinterpretq_s32_u32(sum));
- break;
- case 64:
- do {
- sse_w16_neon(&sum, a, b);
- sse_w16_neon(&sum, a + 16 * 1, b + 16 * 1);
- sse_w16_neon(&sum, a + 16 * 2, b + 16 * 2);
- sse_w16_neon(&sum, a + 16 * 3, b + 16 * 3);
- a += a_stride;
- b += b_stride;
- y += 1;
- } while (y < height);
- sse = horizontal_add_s32x4(vreinterpretq_s32_u32(sum));
- break;
- case 128:
- do {
- sse_w16_neon(&sum, a, b);
- sse_w16_neon(&sum, a + 16 * 1, b + 16 * 1);
- sse_w16_neon(&sum, a + 16 * 2, b + 16 * 2);
- sse_w16_neon(&sum, a + 16 * 3, b + 16 * 3);
- sse_w16_neon(&sum, a + 16 * 4, b + 16 * 4);
- sse_w16_neon(&sum, a + 16 * 5, b + 16 * 5);
- sse_w16_neon(&sum, a + 16 * 6, b + 16 * 6);
- sse_w16_neon(&sum, a + 16 * 7, b + 16 * 7);
- a += a_stride;
- b += b_stride;
- y += 1;
- } while (y < height);
- sse = horizontal_add_s32x4(vreinterpretq_s32_u32(sum));
- break;
- default:
- if (width & 0x07) {
- do {
- int i = 0;
- do {
- aom_sse8_neon(a + i, b + i, &sum);
- aom_sse8_neon(a + i + a_stride, b + i + b_stride, &sum);
- i += 8;
- } while (i + 4 < width);
- aom_sse4x2_neon(a + i, a_stride, b + i, b_stride, &sum);
- a += (a_stride << 1);
- b += (b_stride << 1);
- y += 2;
- } while (y < height);
- } else {
- do {
- int i = 0;
- do {
- aom_sse8_neon(a + i, b + i, &sum);
- i += 8;
- } while (i < width);
- a += a_stride;
- b += b_stride;
- y += 1;
- } while (y < height);
- }
- sse = horizontal_add_s32x4(vreinterpretq_s32_u32(sum));
- break;
+ sse_8x1_neon(src + j, ref + j, &sse);  // width is a multiple of 8 here
+ j += 8;
+ } while (j < width);
+
+ src += src_stride;
+ ref += ref_stride;
+ i++;
+ } while (i < height);
}
- return sse;
+ return horizontal_add_u32x4(sse);  // reduce 4 lanes to scalar SSE
+}
+
+int64_t aom_sse_neon(const uint8_t *src, int src_stride, const uint8_t *ref,
+ int ref_stride, int width, int height) {
+ switch (width) {  // dispatch to a width-specialized helper
+ case 4: return sse_4xh_neon(src, src_stride, ref, ref_stride, height);
+ case 8: return sse_8xh_neon(src, src_stride, ref, ref_stride, height);
+ case 16: return sse_16xh_neon(src, src_stride, ref, ref_stride, height);
+ case 32: return sse_32xh_neon(src, src_stride, ref, ref_stride, height);
+ case 64: return sse_64xh_neon(src, src_stride, ref, ref_stride, height);
+ case 128: return sse_128xh_neon(src, src_stride, ref, ref_stride, height);
+ default:
+ return sse_wxh_neon(src, src_stride, ref, ref_stride, width, height);  // other widths (replaces old switch default)
+ }
}
#if CONFIG_AV1_HIGHBITDEPTH