Improve horiz 8-tap Neon I8MM convolve_2d_sr implementation
Improve the horizontal 8-tap pass of the Neon I8MM implementation of
convolve_2d_sr by replacing the USDOT dot products with USMMLA matrix
multiplies plus an extra UMLSL to apply the remaining filter tap.
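
Each USMMLA computes four 8-element dot products, so the main filter
pass needs two matrix multiplies where the USDOT version needed four
dot-product instructions. Staggering an 8-tap filter across the two
8-wide rows of the filter matrix would require nine columns, so only
taps 1-7 go into the matrix; tap 0 is applied separately with UMLSL.
Since tap 0 is always negative and the samples are unsigned, the code
multiplies by -f0 and subtracts to avoid signed overflow.

A minimal scalar model of what each vusmmlaq_s32(acc, a, b) call
computes (illustrative only; the helper name is hypothetical and not
part of this change):

    #include <stdint.h>

    // 'a' is a 2x8 row-major matrix of unsigned samples; 'b' is a 2x8
    // row-major matrix of signed taps, implicitly transposed. The four
    // accumulators form the 2x2 row-major result.
    static void usmmla_model(int32_t acc[4], const uint8_t a[2][8],
                             const int8_t b[2][8]) {
      for (int i = 0; i < 2; ++i) {
        for (int j = 0; j < 2; ++j) {
          for (int k = 0; k < 8; ++k) {
            acc[2 * i + j] += (int32_t)a[i][k] * (int32_t)b[j][k];
          }
        }
      }
    }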
Change-Id: Ifaf3234f2379bf2110cdd46820983eca65f450fa
diff --git a/av1/common/arm/convolve_neon_i8mm.c b/av1/common/arm/convolve_neon_i8mm.c
index ef3a8d3..20b9af4 100644
--- a/av1/common/arm/convolve_neon_i8mm.c
+++ b/av1/common/arm/convolve_neon_i8mm.c
@@ -890,42 +890,53 @@
}
static inline int16x8_t convolve8_8_2d_h(uint8x16_t samples,
- const int8x8_t filters,
- const uint8x16x3_t permute_tbl,
- const int32x4_t horiz_const) {
- // Permute samples ready for dot product.
- // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 }
- // { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 }
- // { 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
- uint8x16_t perm_samples[3] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
- vqtbl1q_u8(samples, permute_tbl.val[1]),
- vqtbl1q_u8(samples, permute_tbl.val[2]) };
+ const int8x16_t x_filter,
+ const uint8x8_t f0,
+ const uint8x16x2_t permute_tbl,
+ const int16x8_t horiz_const) {
+ // Permute samples ready for matrix multiply.
+ // { 1, 2, 3, 4, 5, 6, 7, 8, 3, 4, 5, 6, 7, 8, 9, 10 }
+ // { 5, 6, 7, 8, 9, 10, 11, 12, 7, 8, 9, 10, 11, 12, 13, 14 }
+ uint8x16_t perm_samples[2] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
+ vqtbl1q_u8(samples, permute_tbl.val[1]) };
- int32x4_t sum0123 =
- vusdotq_lane_s32(horiz_const, perm_samples[0], filters, 0);
- sum0123 = vusdotq_lane_s32(sum0123, perm_samples[1], filters, 1);
+ // Calculate partial 7-tap convolution.
+ int32x4_t sum0123 = vusmmlaq_s32(vdupq_n_s32(0), perm_samples[0], x_filter);
+ int32x4_t sum4567 = vusmmlaq_s32(vdupq_n_s32(0), perm_samples[1], x_filter);
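+ // Halving the filter taps keeps these partial sums within int16 range,
+ // so the truncating narrow below is exact.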
+ int16x8_t sum = vcombine_s16(vmovn_s32(sum0123), vmovn_s32(sum4567));
- int32x4_t sum4567 =
- vusdotq_lane_s32(horiz_const, perm_samples[1], filters, 0);
- sum4567 = vusdotq_lane_s32(sum4567, perm_samples[2], filters, 1);
+ // Apply tap 0 and accumulate.
+ sum = vreinterpretq_s16_u16(
+ vmlsl_u8(vreinterpretq_u16_s16(sum), vget_low_u8(samples), f0));
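+ // The s16/u16 reinterprets compile to nothing; two's-complement
+ // wraparound makes the unsigned multiply-subtract give the correct
+ // signed result.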
- // Narrow and re-pack.
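+ // Add the rounding shim once in the 16-bit domain rather than seeding
+ // each 32-bit accumulator with it, as the USDOT version did.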
+ sum = vaddq_s16(sum, horiz_const);
+
// We halved the convolution filter values so -1 from the right shift.
- return vcombine_s16(vshrn_n_s32(sum0123, ROUND0_BITS - 1),
- vshrn_n_s32(sum4567, ROUND0_BITS - 1));
+ return vshrq_n_s16(sum, ROUND0_BITS - 1);
}
static inline void convolve_2d_sr_horiz_8tap_neon_i8mm(
const uint8_t *src, int src_stride, int16_t *im_block, int im_stride, int w,
int im_h, const int16_t *x_filter_ptr) {
// Filter values are even, so halve to reduce intermediate precision reqs.
- const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
+ const int8x8_t x_filter_s8 = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
+
+ // Stagger the filter for use with the matrix multiply instructions.
+ // { f1, f2, f3, f4, f5, f6, f7, 0, 0, f1, f2, f3, f4, f5, f6, f7 }
+ const uint8x16_t filter_idx = vld1q_u8(kFilterPermuteTbl);
+ const int8x16_t x_filter =
+ vqtbl1q_s8(vcombine_s8(x_filter_s8, vdup_n_s8(0)), filter_idx);
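+ // Staggering all 8 taps across two 8-wide rows would need nine
+ // columns, so tap 0 is left out here and applied separately below.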
+
+ // Since f0 is always negative and s0 is unsigned, subtract (unsigned) s0 *
+ // -f0 to avoid signed overflow.
+ const uint8x8_t f0 = vdup_n_u8(-x_filter_ptr[0] >> 1);
+ const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMul8PermuteTbl);
const int bd = 8;
// This shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
// shifts - which are generally faster than rounding shifts on modern CPUs.
// The outermost -1 is needed because we halved the filter values.
- const int32x4_t horiz_const = vdupq_n_s32((1 << (bd + FILTER_BITS - 2)) +
+ const int16x8_t horiz_const = vdupq_n_s16((1 << (bd + FILTER_BITS - 2)) +
(1 << ((ROUND0_BITS - 1) - 1)));
const uint8_t *src_ptr = src;
@@ -933,7 +944,6 @@
int dst_stride = im_stride;
int height = im_h;
- const uint8x16x3_t permute_tbl = vld1q_u8_x3(kDotProdPermuteTbl);
do {
const uint8_t *s = src_ptr;
int16_t *d = dst_ptr;
@@ -943,10 +953,14 @@
uint8x16_t s0, s1, s2, s3;
load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
- int16x8_t d0 = convolve8_8_2d_h(s0, x_filter, permute_tbl, horiz_const);
- int16x8_t d1 = convolve8_8_2d_h(s1, x_filter, permute_tbl, horiz_const);
- int16x8_t d2 = convolve8_8_2d_h(s2, x_filter, permute_tbl, horiz_const);
- int16x8_t d3 = convolve8_8_2d_h(s3, x_filter, permute_tbl, horiz_const);
+ int16x8_t d0 =
+ convolve8_8_2d_h(s0, x_filter, f0, permute_tbl, horiz_const);
+ int16x8_t d1 =
+ convolve8_8_2d_h(s1, x_filter, f0, permute_tbl, horiz_const);
+ int16x8_t d2 =
+ convolve8_8_2d_h(s2, x_filter, f0, permute_tbl, horiz_const);
+ int16x8_t d3 =
+ convolve8_8_2d_h(s3, x_filter, f0, permute_tbl, horiz_const);
store_s16_8x4(d, dst_stride, d0, d1, d2, d3);
@@ -966,7 +980,8 @@
do {
uint8x16_t s0 = vld1q_u8(s);
- int16x8_t d0 = convolve8_8_2d_h(s0, x_filter, permute_tbl, horiz_const);
+ int16x8_t d0 =
+ convolve8_8_2d_h(s0, x_filter, f0, permute_tbl, horiz_const);
vst1q_s16(d, d0);
s += 8;