Optimize implementation of aom_sum_squares_2d_i16_neon
Optimize the implementation of aom_sum_squares_2d_i16_neon by
refactoring the sum-of-squares helper functions. The main change is
to use vector accumulators that do not require a horizontal pairwise
reduction on every loop iteration; the accumulators are instead
reduced to a scalar only once, after the loop.
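
As an illustration of the pattern (a minimal standalone sketch; the
sum_squares_i16 helper below is hypothetical and not part of this
patch), the loop body is pure widening multiply-accumulate into
int32x4_t lane accumulators, and the only horizontal reduction
happens after the loop. This assumes the int16_t inputs are bounded
(as prediction residuals are) so the 32-bit lane totals cannot wrap:

  #include <arm_neon.h>
  #include <stdint.h>

  static uint64_t sum_squares_i16(const int16_t *src, int len) {
    // Two accumulators shorten the loop-carried dependency chain.
    int32x4_t acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };

    for (int i = 0; i < len; i += 8) {  // len: a multiple of 8.
      int16x8_t s = vld1q_s16(src + i);
      // Widening multiply-accumulate; no reduction inside the loop.
      acc[0] = vmlal_s16(acc[0], vget_low_s16(s), vget_low_s16(s));
      acc[1] = vmlal_s16(acc[1], vget_high_s16(s), vget_high_s16(s));
    }

    // Reduce the vector accumulators to a scalar exactly once.
    uint32x4_t sum = vreinterpretq_u32_s32(vaddq_s32(acc[0], acc[1]));
  #if defined(__aarch64__)
    return vaddlvq_u32(sum);
  #else
    uint64x2_t t = vpaddlq_u32(sum);
    return vgetq_lane_u64(t, 0) + vgetq_lane_u64(t, 1);
  #endif
  }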
Also add new reduction helpers to sum_neon.h that compile for both
AArch32 and AArch64 targets, and use them in place of the per-call
#ifdef blocks in sum_squares_neon.c.
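
Usage sketch for the new helpers (illustrative only; this driver is
not part of the patch): the AArch64/AArch32 split now lives in one
place in sum_neon.h, and a call site reduces a vector accumulator
with a single call:

  #include <arm_neon.h>
  #include <stdint.h>
  #include "aom_dsp/arm/sum_neon.h"

  int main(void) {
    // 3 + 4 across the two 64-bit lanes; both target paths give 7.
    uint64x2_t acc = vcombine_u64(vdup_n_u64(3), vdup_n_u64(4));
    if (horizontal_add_u64x2(acc) != 7) return 1;

    // The long-add variant widens as it reduces, so the sum of a
    // full uint32x4_t may exceed UINT32_MAX without wrapping.
    uint32x4_t big = vdupq_n_u32(0xffffffffu);
    if (horizontal_long_add_u32x4(big) != 4 * (uint64_t)0xffffffffu)
      return 1;
    return 0;
  }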
Change-Id: Icd7cab01a57621f567414b3637e5dea10a56e960
diff --git a/aom_dsp/arm/sum_neon.h b/aom_dsp/arm/sum_neon.h
index 4116509..0cf110a 100644
--- a/aom_dsp/arm/sum_neon.h
+++ b/aom_dsp/arm/sum_neon.h
@@ -37,6 +37,23 @@
 #endif
 }
 
+static INLINE uint64_t horizontal_add_u64x2(const uint64x2_t a) {
+#if defined(__aarch64__)
+  return vaddvq_u64(a);
+#else
+  return vgetq_lane_u64(a, 0) + vgetq_lane_u64(a, 1);
+#endif
+}
+
+static INLINE uint64_t horizontal_long_add_u32x4(const uint32x4_t a) {
+#if defined(__aarch64__)
+  return vaddlvq_u32(a);
+#else
+  const uint64x2_t b = vpaddlq_u32(a);
+  return vgetq_lane_u64(b, 0) + vgetq_lane_u64(b, 1);
+#endif
+}
+
 static INLINE unsigned int horizontal_add_u32x4(const uint32x4_t a) {
 #if defined(__aarch64__)
   return vaddvq_u32(a);
diff --git a/aom_dsp/arm/sum_squares_neon.c b/aom_dsp/arm/sum_squares_neon.c
index 0b7337a..524b098 100644
--- a/aom_dsp/arm/sum_squares_neon.c
+++ b/aom_dsp/arm/sum_squares_neon.c
@@ -13,111 +13,83 @@
 #include <assert.h>
 
 #include "aom_dsp/arm/mem_neon.h"
+#include "aom_dsp/arm/sum_neon.h"
 #include "config/aom_dsp_rtcd.h"
 
-static INLINE uint32x4_t sum_squares_i16_4x4_neon(const int16_t *src,
-                                                  int stride) {
-  const int16x4_t v_val_01_lo = vld1_s16(src + 0 * stride);
-  const int16x4_t v_val_01_hi = vld1_s16(src + 1 * stride);
-  const int16x4_t v_val_23_lo = vld1_s16(src + 2 * stride);
-  const int16x4_t v_val_23_hi = vld1_s16(src + 3 * stride);
-  int32x4_t v_sq_01_d = vmull_s16(v_val_01_lo, v_val_01_lo);
-  v_sq_01_d = vmlal_s16(v_sq_01_d, v_val_01_hi, v_val_01_hi);
-  int32x4_t v_sq_23_d = vmull_s16(v_val_23_lo, v_val_23_lo);
-  v_sq_23_d = vmlal_s16(v_sq_23_d, v_val_23_hi, v_val_23_hi);
-#if defined(__aarch64__)
-  return vreinterpretq_u32_s32(vpaddq_s32(v_sq_01_d, v_sq_23_d));
-#else
-  return vreinterpretq_u32_s32(vcombine_s32(
-      vqmovn_s64(vpaddlq_s32(v_sq_01_d)), vqmovn_s64(vpaddlq_s32(v_sq_23_d))));
-#endif
+static INLINE uint64_t aom_sum_squares_2d_i16_4x4_neon(const int16_t *src,
+                                                       int stride) {
+  int16x4_t s0 = vld1_s16(src + 0 * stride);
+  int16x4_t s1 = vld1_s16(src + 1 * stride);
+  int16x4_t s2 = vld1_s16(src + 2 * stride);
+  int16x4_t s3 = vld1_s16(src + 3 * stride);
+
+  int32x4_t sum_squares = vmull_s16(s0, s0);
+  sum_squares = vmlal_s16(sum_squares, s1, s1);
+  sum_squares = vmlal_s16(sum_squares, s2, s2);
+  sum_squares = vmlal_s16(sum_squares, s3, s3);
+
+  return horizontal_long_add_u32x4(vreinterpretq_u32_s32(sum_squares));
 }
 
-uint64_t aom_sum_squares_2d_i16_4x4_neon(const int16_t *src, int stride) {
-  const uint32x4_t v_sum_0123_d = sum_squares_i16_4x4_neon(src, stride);
-#if defined(__aarch64__)
-  return (uint64_t)vaddvq_u32(v_sum_0123_d);
-#else
-  uint64x2_t v_sum_d = vpaddlq_u32(v_sum_0123_d);
-  v_sum_d = vaddq_u64(v_sum_d, vextq_u64(v_sum_d, v_sum_d, 1));
-  return vgetq_lane_u64(v_sum_d, 0);
-#endif
-}
+static INLINE uint64_t aom_sum_squares_2d_i16_4xn_neon(const int16_t *src,
+                                                       int stride, int height) {
+  int32x4_t sum_squares[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
 
-uint64_t aom_sum_squares_2d_i16_4xn_neon(const int16_t *src, int stride,
-                                         int height) {
-  int r = 0;
-  uint32x4_t v_acc_q = vdupq_n_u32(0);
+  int h = 0;
   do {
-    const uint32x4_t v_acc_d = sum_squares_i16_4x4_neon(src, stride);
-    v_acc_q = vaddq_u32(v_acc_q, v_acc_d);
-    src += stride << 2;
-    r += 4;
-  } while (r < height);
+    int16x4_t s0 = vld1_s16(src + 0 * stride);
+    int16x4_t s1 = vld1_s16(src + 1 * stride);
+    int16x4_t s2 = vld1_s16(src + 2 * stride);
+    int16x4_t s3 = vld1_s16(src + 3 * stride);
 
-  uint64x2_t v_acc_64 = vpaddlq_u32(v_acc_q);
-#if defined(__aarch64__)
-  return vaddvq_u64(v_acc_64);
-#else
-  v_acc_64 = vaddq_u64(v_acc_64, vextq_u64(v_acc_64, v_acc_64, 1));
-  return vgetq_lane_u64(v_acc_64, 0);
-#endif
-}
-
-uint64_t aom_sum_squares_2d_i16_nxn_neon(const int16_t *src, int stride,
-                                         int width, int height) {
-  int r = 0;
-  const int32x4_t zero = vdupq_n_s32(0);
-  uint64x2_t v_acc_q = vreinterpretq_u64_s32(zero);
-  do {
-    int32x4_t v_sum = zero;
-    int c = 0;
-    do {
-      const int16_t *b = src + c;
-      const int16x8_t v_val_0 = vld1q_s16(b + 0 * stride);
-      const int16x8_t v_val_1 = vld1q_s16(b + 1 * stride);
-      const int16x8_t v_val_2 = vld1q_s16(b + 2 * stride);
-      const int16x8_t v_val_3 = vld1q_s16(b + 3 * stride);
-      const int16x4_t v_val_0_lo = vget_low_s16(v_val_0);
-      const int16x4_t v_val_1_lo = vget_low_s16(v_val_1);
-      const int16x4_t v_val_2_lo = vget_low_s16(v_val_2);
-      const int16x4_t v_val_3_lo = vget_low_s16(v_val_3);
-      int32x4_t v_sum_01 = vmull_s16(v_val_0_lo, v_val_0_lo);
-      v_sum_01 = vmlal_s16(v_sum_01, v_val_1_lo, v_val_1_lo);
-      int32x4_t v_sum_23 = vmull_s16(v_val_2_lo, v_val_2_lo);
-      v_sum_23 = vmlal_s16(v_sum_23, v_val_3_lo, v_val_3_lo);
-#if defined(__aarch64__)
-      v_sum_01 = vmlal_high_s16(v_sum_01, v_val_0, v_val_0);
-      v_sum_01 = vmlal_high_s16(v_sum_01, v_val_1, v_val_1);
-      v_sum_23 = vmlal_high_s16(v_sum_23, v_val_2, v_val_2);
-      v_sum_23 = vmlal_high_s16(v_sum_23, v_val_3, v_val_3);
-      v_sum = vaddq_s32(v_sum, vpaddq_s32(v_sum_01, v_sum_23));
-#else
-      const int16x4_t v_val_0_hi = vget_high_s16(v_val_0);
-      const int16x4_t v_val_1_hi = vget_high_s16(v_val_1);
-      const int16x4_t v_val_2_hi = vget_high_s16(v_val_2);
-      const int16x4_t v_val_3_hi = vget_high_s16(v_val_3);
-      v_sum_01 = vmlal_s16(v_sum_01, v_val_0_hi, v_val_0_hi);
-      v_sum_01 = vmlal_s16(v_sum_01, v_val_1_hi, v_val_1_hi);
-      v_sum_23 = vmlal_s16(v_sum_23, v_val_2_hi, v_val_2_hi);
-      v_sum_23 = vmlal_s16(v_sum_23, v_val_3_hi, v_val_3_hi);
-      v_sum = vaddq_s32(v_sum, vcombine_s32(vqmovn_s64(vpaddlq_s32(v_sum_01)),
-                                            vqmovn_s64(vpaddlq_s32(v_sum_23))));
-#endif
-      c += 8;
-    } while (c < width);
-
-    v_acc_q = vpadalq_u32(v_acc_q, vreinterpretq_u32_s32(v_sum));
+    sum_squares[0] = vmlal_s16(sum_squares[0], s0, s0);
+    sum_squares[0] = vmlal_s16(sum_squares[0], s1, s1);
+    sum_squares[1] = vmlal_s16(sum_squares[1], s2, s2);
+    sum_squares[1] = vmlal_s16(sum_squares[1], s3, s3);
 
     src += 4 * stride;
-    r += 4;
-  } while (r < height);
-#if defined(__aarch64__)
-  return vaddvq_u64(v_acc_q);
-#else
-  v_acc_q = vaddq_u64(v_acc_q, vextq_u64(v_acc_q, v_acc_q, 1));
-  return vgetq_lane_u64(v_acc_q, 0);
-#endif
+    h += 4;
+  } while (h < height);
+
+  return horizontal_long_add_u32x4(
+      vreinterpretq_u32_s32(vaddq_s32(sum_squares[0], sum_squares[1])));
+}
+
+static INLINE uint64_t aom_sum_squares_2d_i16_nxn_neon(const int16_t *src,
+                                                       int stride, int width,
+                                                       int height) {
+  uint64x2_t sum_squares = vdupq_n_u64(0);
+
+  int h = 0;
+  do {
+    int32x4_t ss_row[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
+    int w = 0;
+    do {
+      const int16_t *s = src + w;
+      int16x8_t s0 = vld1q_s16(s + 0 * stride);
+      int16x8_t s1 = vld1q_s16(s + 1 * stride);
+      int16x8_t s2 = vld1q_s16(s + 2 * stride);
+      int16x8_t s3 = vld1q_s16(s + 3 * stride);
+
+      ss_row[0] = vmlal_s16(ss_row[0], vget_low_s16(s0), vget_low_s16(s0));
+      ss_row[0] = vmlal_s16(ss_row[0], vget_low_s16(s1), vget_low_s16(s1));
+      ss_row[0] = vmlal_s16(ss_row[0], vget_low_s16(s2), vget_low_s16(s2));
+      ss_row[0] = vmlal_s16(ss_row[0], vget_low_s16(s3), vget_low_s16(s3));
+      ss_row[1] = vmlal_s16(ss_row[1], vget_high_s16(s0), vget_high_s16(s0));
+      ss_row[1] = vmlal_s16(ss_row[1], vget_high_s16(s1), vget_high_s16(s1));
+      ss_row[1] = vmlal_s16(ss_row[1], vget_high_s16(s2), vget_high_s16(s2));
+      ss_row[1] = vmlal_s16(ss_row[1], vget_high_s16(s3), vget_high_s16(s3));
+      w += 8;
+    } while (w < width);
+
+    sum_squares = vpadalq_u32(
+        sum_squares, vreinterpretq_u32_s32(vaddq_s32(ss_row[0], ss_row[1])));
+
+    src += 4 * stride;
+    h += 4;
+  } while (h < height);
+
+  return horizontal_add_u64x2(sum_squares);
 }
 
 uint64_t aom_sum_squares_2d_i16_neon(const int16_t *src, int stride, int width,