Speed improvement in convolve 2d_sr and x_sr neon for 32-bit
Observed gains w.r.t. existing neon code in unit test:
av1_convolve_2d_sr_neon - 20%
av1_convolve_x_sr_neon - 30%
Change-Id: Ieeb5c952fa33c1a15fb4f366b1d64f8c8f5e7ffa
diff --git a/av1/common/arm/convolve_neon.c b/av1/common/arm/convolve_neon.c
index b2e45c0..93dccc8 100644
--- a/av1/common/arm/convolve_neon.c
+++ b/av1/common/arm/convolve_neon.c
@@ -68,6 +68,31 @@
return vqmovun_s16(sum);
}
+static INLINE uint8x8_t convolve8_horiz_4x1(
+ const int16x4_t s0, const int16x4_t s1, const int16x4_t s2,
+ const int16x4_t s3, const int16x4_t s4, const int16x4_t s5,
+ const int16x4_t s6, const int16x4_t s7, const int16_t *filter,
+ const int16x4_t shift_round_0, const int16x4_t shift_by_bits) {
+ int16x4_t sum;
+
+ sum = vmul_n_s16(s0, filter[0]);
+ sum = vmla_n_s16(sum, s1, filter[1]);
+ sum = vmla_n_s16(sum, s2, filter[2]);
+ sum = vmla_n_s16(sum, s5, filter[5]);
+ sum = vmla_n_s16(sum, s6, filter[6]);
+ sum = vmla_n_s16(sum, s7, filter[7]);
+ /* filter[3] can take a max value of 128. So the max value of the result :
+ * 128*255 + sum > 16 bits
+ */
+ sum = vqadd_s16(sum, vmul_n_s16(s3, filter[3]));
+ sum = vqadd_s16(sum, vmul_n_s16(s4, filter[4]));
+
+ sum = vqrshl_s16(sum, shift_round_0);
+ sum = vqrshl_s16(sum, shift_by_bits);
+
+ return vqmovun_s16(vcombine_s16(sum, sum));
+}
+
static INLINE uint8x8_t convolve8_vert_8x4(
const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
const int16x8_t s3, const int16x8_t s4, const int16x8_t s5,
@@ -175,7 +200,10 @@
(void)conv_params;
(void)filter_params_y;
- uint8x8_t t0, t1, t2, t3;
+ uint8x8_t t0;
+#if defined(__aarch64__)
+ uint8x8_t t1, t2, t3;
+#endif
assert(bits >= 0);
assert((FILTER_BITS - conv_params->round_1) >= 0 ||
@@ -188,7 +216,7 @@
const int16x8_t shift_by_bits = vdupq_n_s16(-bits);
src -= horiz_offset;
-
+#if defined(__aarch64__)
if (h == 4) {
uint8x8_t d01, d23;
int16x4_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, d0, d1, d2, d3;
@@ -275,12 +303,18 @@
w -= 4;
} while (w > 0);
} else {
+#endif
int width;
const uint8_t *s;
+ int16x8_t s0, s1, s2, s3, s4, s5, s6, s7;
+
+#if defined(__aarch64__)
+ int16x8_t s8, s9, s10;
uint8x8_t t4, t5, t6, t7;
- int16x8_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10;
+#endif
if (w <= 4) {
+#if defined(__aarch64__)
do {
load_u8_8x8(src, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
transpose_u8_8x8(&t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
@@ -387,10 +421,49 @@
}
h -= 8;
} while (h > 0);
+#else
+ int16x8_t tt0;
+ int16x4_t x0, x1, x2, x3, x4, x5, x6, x7;
+ const int16x4_t shift_round_0_low = vget_low_s16(shift_round_0);
+ const int16x4_t shift_by_bits_low = vget_low_s16(shift_by_bits);
+ do {
+ t0 = vld1_u8(src); // a0 a1 a2 a3 a4 a5 a6 a7
+ tt0 = vreinterpretq_s16_u16(vmovl_u8(t0));
+ x0 = vget_low_s16(tt0); // a0 a1 a2 a3
+ x4 = vget_high_s16(tt0); // a4 a5 a6 a7
+
+ t0 = vld1_u8(src + 8); // a8 a9 a10 a11 a12 a13 a14 a15
+ tt0 = vreinterpretq_s16_u16(vmovl_u8(t0));
+ x7 = vget_low_s16(tt0); // a8 a9 a10 a11
+
+ x1 = vext_s16(x0, x4, 1); // a1 a2 a3 a4
+ x2 = vext_s16(x0, x4, 2); // a2 a3 a4 a5
+ x3 = vext_s16(x0, x4, 3); // a3 a4 a5 a6
+ x5 = vext_s16(x4, x7, 1); // a5 a6 a7 a8
+ x6 = vext_s16(x4, x7, 2); // a6 a7 a8 a9
+ x7 = vext_s16(x4, x7, 3); // a7 a8 a9 a10
+
+ src += src_stride;
+
+ t0 = convolve8_horiz_4x1(x0, x1, x2, x3, x4, x5, x6, x7, x_filter,
+ shift_round_0_low, shift_by_bits_low);
+
+ if (w == 4) {
+ vst1_lane_u32((uint32_t *)dst, vreinterpret_u32_u8(t0),
+ 0); // 00 01 02 03
+ dst += dst_stride;
+ } else if (w == 2) {
+ vst1_lane_u16((uint16_t *)dst, vreinterpret_u16_u8(t0), 0); // 00 01
+ dst += dst_stride;
+ }
+ h -= 1;
+ } while (h > 0);
+#endif
} else {
uint8_t *d;
- int16x8_t s11, s12, s13, s14;
-
+ int16x8_t s11;
+#if defined(__aarch64__)
+ int16x8_t s12, s13, s14;
do {
__builtin_prefetch(src + 0 * src_stride);
__builtin_prefetch(src + 1 * src_stride);
@@ -479,8 +552,47 @@
dst += 8 * dst_stride;
h -= 8;
} while (h > 0);
+#else
+ do {
+ t0 = vld1_u8(src); // a0 a1 a2 a3 a4 a5 a6 a7
+ s0 = vreinterpretq_s16_u16(vmovl_u8(t0));
+
+ width = w;
+ s = src + 8;
+ d = dst;
+ __builtin_prefetch(dst);
+
+ do {
+ t0 = vld1_u8(s); // a8 a9 a10 a11 a12 a13 a14 a15
+ s7 = vreinterpretq_s16_u16(vmovl_u8(t0));
+ s11 = s0;
+ s0 = s7;
+
+ s1 = vextq_s16(s11, s7, 1); // a1 a2 a3 a4 a5 a6 a7 a8
+ s2 = vextq_s16(s11, s7, 2); // a2 a3 a4 a5 a6 a7 a8 a9
+ s3 = vextq_s16(s11, s7, 3); // a3 a4 a5 a6 a7 a8 a9 a10
+ s4 = vextq_s16(s11, s7, 4); // a4 a5 a6 a7 a8 a9 a10 a11
+ s5 = vextq_s16(s11, s7, 5); // a5 a6 a7 a8 a9 a10 a11 a12
+ s6 = vextq_s16(s11, s7, 6); // a6 a7 a8 a9 a10 a11 a12 a13
+ s7 = vextq_s16(s11, s7, 7); // a7 a8 a9 a10 a11 a12 a13 a14
+
+ t0 = convolve8_horiz_8x8(s11, s1, s2, s3, s4, s5, s6, s7, x_filter,
+ shift_round_0, shift_by_bits);
+ vst1_u8(d, t0);
+
+ s += 8;
+ d += 8;
+ width -= 8;
+ } while (width > 0);
+ src += src_stride;
+ dst += dst_stride;
+ h -= 1;
+ } while (h > 0);
+#endif
}
+#if defined(__aarch64__)
}
+#endif
}
void av1_convolve_y_sr_neon(const uint8_t *src, int src_stride, uint8_t *dst,
@@ -736,7 +848,10 @@
ConvolveParams *conv_params) {
int im_dst_stride;
int width, height;
- uint8x8_t t0, t1, t2, t3, t4, t5, t6, t7;
+ uint8x8_t t0;
+#if defined(__aarch64__)
+ uint8x8_t t1, t2, t3, t4, t5, t6, t7;
+#endif
DECLARE_ALIGNED(16, int16_t,
im_block[(MAX_SB_SIZE + HORIZ_EXTRA_ROWS) * MAX_SB_SIZE]);
@@ -774,13 +889,18 @@
assert(conv_params->round_0 > 0);
if (w <= 4) {
- int16x4_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, d0, d1, d2, d3;
+ int16x4_t s0, s1, s2, s3, s4, s5, s6, s7, d0;
+#if defined(__aarch64__)
+ int16x4_t s8, s9, s10, d1, d2, d3;
+#endif
const int16x4_t horiz_const = vdup_n_s16((1 << (bd + FILTER_BITS - 2)));
const int16x4_t shift_round_0 = vdup_n_s16(-(conv_params->round_0 - 1));
do {
s = src_ptr;
+
+#if defined(__aarch64__)
__builtin_prefetch(s + 0 * src_stride);
__builtin_prefetch(s + 1 * src_stride);
__builtin_prefetch(s + 2 * src_stride);
@@ -839,16 +959,56 @@
src_ptr += 4 * src_stride;
dst_ptr += 4 * im_dst_stride;
height -= 4;
+#else
+ int16x8_t tt0;
+
+ __builtin_prefetch(s);
+
+ t0 = vld1_u8(s); // a0 a1 a2 a3 a4 a5 a6 a7
+ tt0 = vreinterpretq_s16_u16(vmovl_u8(t0));
+ s0 = vget_low_s16(tt0);
+ s4 = vget_high_s16(tt0);
+
+ __builtin_prefetch(dst_ptr);
+ s += 8;
+
+ t0 = vld1_u8(s); // a8 a9 a10 a11 a12 a13 a14 a15
+ s7 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t0)));
+
+ s1 = vext_s16(s0, s4, 1); // a1 a2 a3 a4
+ s2 = vext_s16(s0, s4, 2); // a2 a3 a4 a5
+ s3 = vext_s16(s0, s4, 3); // a3 a4 a5 a6
+ s5 = vext_s16(s4, s7, 1); // a5 a6 a7 a8
+ s6 = vext_s16(s4, s7, 2); // a6 a7 a8 a9
+ s7 = vext_s16(s4, s7, 3); // a7 a8 a9 a10
+
+ d0 = convolve8_4x4_s16(s0, s1, s2, s3, s4, s5, s6, s7, x_filter_tmp,
+ horiz_const, shift_round_0);
+
+ if (w == 4) {
+ vst1_s16(dst_ptr, d0);
+ dst_ptr += im_dst_stride;
+ } else if (w == 2) {
+ vst1_lane_u32((uint32_t *)dst_ptr, vreinterpret_u32_s16(d0), 0);
+ dst_ptr += im_dst_stride;
+ }
+
+ src_ptr += src_stride;
+ height -= 1;
+#endif
} while (height > 0);
} else {
int16_t *d_tmp;
+ int16x8_t s0, s1, s2, s3, s4, s5, s6, s7, res0;
+#if defined(__aarch64__)
+ int16x8_t s8, s9, s10, res1, res2, res3, res4, res5, res6, res7;
int16x8_t s11, s12, s13, s14;
- int16x8_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10;
- int16x8_t res0, res1, res2, res3, res4, res5, res6, res7;
+#endif
const int16x8_t horiz_const = vdupq_n_s16((1 << (bd + FILTER_BITS - 2)));
const int16x8_t shift_round_0 = vdupq_n_s16(-(conv_params->round_0 - 1));
+#if defined(__aarch64__)
do {
__builtin_prefetch(src_ptr + 0 * src_stride);
__builtin_prefetch(src_ptr + 1 * src_stride);
@@ -936,6 +1096,45 @@
dst_ptr += 8 * im_dst_stride;
height -= 8;
} while (height > 0);
+#else
+ do {
+ t0 = vld1_u8(src_ptr);
+ s0 = vreinterpretq_s16_u16(vmovl_u8(t0)); // a0 a1 a2 a3 a4 a5 a6 a7
+
+ width = w;
+ s = src_ptr + 8;
+ d_tmp = dst_ptr;
+
+ __builtin_prefetch(dst_ptr);
+
+ do {
+ t0 = vld1_u8(s); // a8 a9 a10 a11 a12 a13 a14 a15
+ s7 = vreinterpretq_s16_u16(vmovl_u8(t0));
+ int16x8_t sum = s0;
+ s0 = s7;
+
+ s1 = vextq_s16(sum, s7, 1); // a1 a2 a3 a4 a5 a6 a7 a8
+ s2 = vextq_s16(sum, s7, 2); // a2 a3 a4 a5 a6 a7 a8 a9
+ s3 = vextq_s16(sum, s7, 3); // a3 a4 a5 a6 a7 a8 a9 a10
+ s4 = vextq_s16(sum, s7, 4); // a4 a5 a6 a7 a8 a9 a10 a11
+ s5 = vextq_s16(sum, s7, 5); // a5 a6 a7 a8 a9 a10 a11 a12
+ s6 = vextq_s16(sum, s7, 6); // a6 a7 a8 a9 a10 a11 a12 a13
+ s7 = vextq_s16(sum, s7, 7); // a7 a8 a9 a10 a11 a12 a13 a14
+
+ res0 = convolve8_8x8_s16(sum, s1, s2, s3, s4, s5, s6, s7, x_filter_tmp,
+ horiz_const, shift_round_0);
+
+ vst1q_s16(d_tmp, res0);
+
+ s += 8;
+ d_tmp += 8;
+ width -= 8;
+ } while (width > 0);
+ src_ptr += src_stride;
+ dst_ptr += im_dst_stride;
+ height -= 1;
+ } while (height > 0);
+#endif
}
// vertical