Optimize Neon implementation of aom_int_pro_col
Unroll to operate on 4 rows per iteration. This unrolling also allows
us to use a 4D reduction that uses pairwise-add instructions, rather
than (slower) full-vector reduction instructions.
Change-Id: Ibd31c8c62a5b47c0fbe9d0ec1c241e236846967f
diff --git a/aom_dsp/arm/avg_neon.c b/aom_dsp/arm/avg_neon.c
index fa9a141..dadf373 100644
--- a/aom_dsp/arm/avg_neon.c
+++ b/aom_dsp/arm/avg_neon.c
@@ -9,6 +9,7 @@
*/
#include <arm_neon.h>
+#include <assert.h>
#include "config/aom_dsp_rtcd.h"
#include "aom/aom_integer.h"
@@ -120,24 +121,33 @@
void aom_int_pro_col_neon(int16_t *vbuf, const uint8_t *ref,
const int ref_stride, const int width,
const int height, int norm_factor) {
- for (int ht = 0; ht < height; ++ht) {
- uint16x8_t sum = vdupq_n_u16(0);
- for (int wd = 0; wd < width; wd += 16) {
- const uint8x16_t vec = vld1q_u8(ref + wd);
- sum = vaddq_u16(sum, vpaddlq_u8(vec));
+ assert(width % 16 == 0);
+ assert(height % 4 == 0);
+
+ const int16x4_t neg_norm_factor = vdup_n_s16(-norm_factor);
+ uint16x8_t sum[4];
+
+ int h = 0;
+ do {
+ sum[0] = vpaddlq_u8(vld1q_u8(ref + 0 * ref_stride));
+ sum[1] = vpaddlq_u8(vld1q_u8(ref + 1 * ref_stride));
+ sum[2] = vpaddlq_u8(vld1q_u8(ref + 2 * ref_stride));
+ sum[3] = vpaddlq_u8(vld1q_u8(ref + 3 * ref_stride));
+
+ for (int w = 16; w < width; w += 16) {
+ sum[0] = vpadalq_u8(sum[0], vld1q_u8(ref + 0 * ref_stride + w));
+ sum[1] = vpadalq_u8(sum[1], vld1q_u8(ref + 1 * ref_stride + w));
+ sum[2] = vpadalq_u8(sum[2], vld1q_u8(ref + 2 * ref_stride + w));
+ sum[3] = vpadalq_u8(sum[3], vld1q_u8(ref + 3 * ref_stride + w));
}
-#if defined(__aarch64__)
- vbuf[ht] = ((int16_t)vaddvq_u16(sum)) >> norm_factor;
-#else
- const uint32x4_t a = vpaddlq_u16(sum);
- const uint64x2_t b = vpaddlq_u32(a);
- const uint32x2_t c = vadd_u32(vreinterpret_u32_u64(vget_low_u64(b)),
- vreinterpret_u32_u64(vget_high_u64(b)));
- vbuf[ht] = ((int16_t)vget_lane_u32(c, 0)) >> norm_factor;
-#endif
- ref += ref_stride;
- }
+ uint16x4_t sum_4d = vmovn_u32(horizontal_add_4d_u16x8(sum));
+ int16x4_t avg = vshl_s16(vreinterpret_s16_u16(sum_4d), neg_norm_factor);
+ vst1_s16(vbuf + h, avg);
+
+ ref += 4 * ref_stride;
+ h += 4;
+ } while (h < height);
}
// coeff: 16 bits, dynamic range [-32640, 32640].