Implement 8-tap aom_convolve8_horiz* using Neon I8MM
Decomposing the 8-tap filter to a 7-tap filter followed by a 1-tap
filter enables us to use the Neon I8MM USMMLA instructions - which is
faster than the existing USDOT approach.
Change-Id: I8f79394a23571f3d5b901b4b7cdb85e9c7734abf
diff --git a/aom_dsp/arm/aom_convolve8_neon_i8mm.c b/aom_dsp/arm/aom_convolve8_neon_i8mm.c
index 5f53e69..b0bb2fc 100644
--- a/aom_dsp/arm/aom_convolve8_neon_i8mm.c
+++ b/aom_dsp/arm/aom_convolve8_neon_i8mm.c
@@ -25,17 +25,24 @@
#include "aom_dsp/arm/transpose_neon.h"
#include "aom_ports/mem.h"
-DECLARE_ALIGNED(16, static const uint8_t, kMatMulPermuteTbl[32]) = {
+DECLARE_ALIGNED(16, static const uint8_t, kMatMul6PermuteTbl[32]) = {
// clang-format off
0, 1, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, 6, 7, 8, 9,
4, 5, 6, 7, 8, 9, 10, 11, 6, 7, 8, 9, 10, 11, 12, 13
// clang-format on
};
-DECLARE_ALIGNED(16, static const uint8_t, kDotProdPermuteTbl[48]) = {
- 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6,
- 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10,
- 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14
+DECLARE_ALIGNED(16, static const uint8_t, kMatMul8PermuteTbl[32]) = {
+ // clang-format off
+ 1, 2, 3, 4, 5, 6, 7, 8, 3, 4, 5, 6, 7, 8, 9, 10,
+ 5, 6, 7, 8, 9, 10, 11, 12, 7, 8, 9, 10, 11, 12, 13, 14
+ // clang-format on
+};
+
+DECLARE_ALIGNED(16, static const uint8_t, kMatMul8FilterPermuteTbl[16]) = {
+ // clang-format off
+ 1, 2, 3, 4, 5, 6, 7, 16, 16, 1, 2, 3, 4, 5, 6, 7
+ // clang-format on
};
DECLARE_ALIGNED(16, static const uint8_t, kDotProdMergeBlockTbl[48]) = {
@@ -48,44 +55,41 @@
};
static inline int16x4_t convolve8_4_h(const uint8x16_t samples,
- const int8x8_t filters,
- const uint8x16x2_t permute_tbl) {
- // Permute samples ready for dot product.
- // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 }
- // { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 }
- uint8x16_t permuted_samples[2] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
- vqtbl1q_u8(samples, permute_tbl.val[1]) };
+ const int8x16_t filters,
+ const uint8x16_t permute_tbl) {
+ // Permute samples ready for matrix multiply.
+ // { 1, 2, 3, 4, 5, 6, 7, 8, 3, 4, 5, 6, 7, 8, 9, 10 }
+ uint8x16_t perm_samples = vqtbl1q_u8(samples, permute_tbl);
- int32x4_t sum =
- vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[0], filters, 0);
- sum = vusdotq_lane_s32(sum, permuted_samples[1], filters, 1);
+ // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
+ // (filter), destructively accumulating into the destination register.
+ int32x4_t sum = vusmmlaq_s32(vdupq_n_s32(0), perm_samples, filters);
- // Further narrowing and packing is performed by the caller.
+ // Tap 0, as well as further narrowing and packing, is applied by the caller.
return vmovn_s32(sum);
}
static inline uint8x8_t convolve8_8_h(const uint8x16_t samples,
- const int8x8_t filters,
- const uint8x16x3_t permute_tbl) {
- // Permute samples ready for dot product.
- // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 }
- // { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 }
- // { 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
- uint8x16_t permuted_samples[3] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
- vqtbl1q_u8(samples, permute_tbl.val[1]),
- vqtbl1q_u8(samples, permute_tbl.val[2]) };
+ const int8x16_t filters,
+ const uint8x8_t f0,
+ const uint8x16x2_t permute_tbl) {
+ // Permute samples ready for matrix multiply.
+ // { 1, 2, 3, 4, 5, 6, 7, 8, 3, 4, 5, 6, 7, 8, 9, 10 }
+ // { 5, 6, 7, 8, 9, 10, 11, 12, 7, 8, 9, 10, 11, 12, 13, 14 }
+ uint8x16_t perm_samples[2] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
+ vqtbl1q_u8(samples, permute_tbl.val[1]) };
- // First 4 output values.
- int32x4_t sum0 =
- vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[0], filters, 0);
- sum0 = vusdotq_lane_s32(sum0, permuted_samples[1], filters, 1);
- // Second 4 output values.
- int32x4_t sum1 =
- vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[1], filters, 0);
- sum1 = vusdotq_lane_s32(sum1, permuted_samples[2], filters, 1);
+ // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
+ // (filter), destructively accumulating into the destination register.
+ int32x4_t sum0123 = vusmmlaq_s32(vdupq_n_s32(0), perm_samples[0], filters);
+ int32x4_t sum4567 = vusmmlaq_s32(vdupq_n_s32(0), perm_samples[1], filters);
// Narrow and re-pack.
- int16x8_t sum = vcombine_s16(vmovn_s32(sum0), vmovn_s32(sum1));
+ int16x8_t sum = vcombine_s16(vmovn_s32(sum0123), vmovn_s32(sum4567));
+ // Apply tap 0 and accumulate.
+ sum = vreinterpretq_s16_u16(
+ vmlsl_u8(vreinterpretq_u16_s16(sum), vget_low_u8(samples), f0));
+
// We halved the filter values so -1 from right shift.
return vqrshrun_n_s16(sum, FILTER_BITS - 1);
}
@@ -94,21 +98,40 @@
const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst,
ptrdiff_t dst_stride, const int16_t *filter_x, int w, int h) {
// Filter values are even, so halve to reduce intermediate precision reqs.
- const int8x8_t filter = vshrn_n_s16(vld1q_s16(filter_x), 1);
+ const int8x8_t filter_s8 = vshrn_n_s16(vld1q_s16(filter_x), 1);
+ // Stagger the filter for use with the matrix multiply instructions.
+ // { f1, f2, f3, f4, f5, f6, f7, 0, 0, f1, f2, f3, f4, f5, f6, f7 }
+ const uint8x16_t filter_idx = vld1q_u8(kMatMul8FilterPermuteTbl);
+ const int8x16_t filter =
+ vqtbl1q_s8(vcombine_s8(filter_s8, vdup_n_s8(0)), filter_idx);
+
+ // Since f0 is always negative and samples are unsigned, subtract (unsigned)
+ // s0 * -f0 to avoid signed overflow.
+ const uint8x8_t f0 = vdup_n_u8(-filter_x[0] >> 1);
if (w == 4) {
- const uint8x16x2_t perm_tbl = vld1q_u8_x2(kDotProdPermuteTbl);
+ const uint8x16_t perm_tbl = vld1q_u8(kMatMul8PermuteTbl);
+
do {
uint8x16_t s0, s1, s2, s3;
load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);
+ uint8x8_t s01 = load_u8_4x2(src + 0 * src_stride, src_stride);
+ uint8x8_t s23 = load_u8_4x2(src + 2 * src_stride, src_stride);
- int16x4_t d0 = convolve8_4_h(s0, filter, perm_tbl);
- int16x4_t d1 = convolve8_4_h(s1, filter, perm_tbl);
- int16x4_t d2 = convolve8_4_h(s2, filter, perm_tbl);
- int16x4_t d3 = convolve8_4_h(s3, filter, perm_tbl);
- // We halved the filter values so -1 from right shift.
- uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS - 1);
- uint8x8_t d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS - 1);
+ int16x4_t t0 = convolve8_4_h(s0, filter, perm_tbl);
+ int16x4_t t1 = convolve8_4_h(s1, filter, perm_tbl);
+ int16x4_t t2 = convolve8_4_h(s2, filter, perm_tbl);
+ int16x4_t t3 = convolve8_4_h(s3, filter, perm_tbl);
+ // Apply tap 0 and accumulate.
+ int16x8_t t01 = vcombine_s16(t0, t1);
+ int16x8_t t23 = vcombine_s16(t2, t3);
+ t01 =
+ vreinterpretq_s16_u16(vmlsl_u8(vreinterpretq_u16_s16(t01), s01, f0));
+ t23 =
+ vreinterpretq_s16_u16(vmlsl_u8(vreinterpretq_u16_s16(t23), s23, f0));
+ // We halved the filter values to -1 from right shift.
+ uint8x8_t d01 = vqrshrun_n_s16(t01, FILTER_BITS - 1);
+ uint8x8_t d23 = vqrshrun_n_s16(t23, FILTER_BITS - 1);
store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01);
store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23);
@@ -118,7 +141,7 @@
h -= 4;
} while (h > 0);
} else {
- const uint8x16x3_t perm_tbl = vld1q_u8_x3(kDotProdPermuteTbl);
+ const uint8x16x2_t perm_tbl = vld1q_u8_x2(kMatMul8PermuteTbl);
do {
int width = w;
@@ -128,10 +151,10 @@
uint8x16_t s0, s1, s2, s3;
load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
- uint8x8_t d0 = convolve8_8_h(s0, filter, perm_tbl);
- uint8x8_t d1 = convolve8_8_h(s1, filter, perm_tbl);
- uint8x8_t d2 = convolve8_8_h(s2, filter, perm_tbl);
- uint8x8_t d3 = convolve8_8_h(s3, filter, perm_tbl);
+ uint8x8_t d0 = convolve8_8_h(s0, filter, f0, perm_tbl);
+ uint8x8_t d1 = convolve8_8_h(s1, filter, f0, perm_tbl);
+ uint8x8_t d2 = convolve8_8_h(s2, filter, f0, perm_tbl);
+ uint8x8_t d3 = convolve8_8_h(s3, filter, f0, perm_tbl);
store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
@@ -192,7 +215,7 @@
vcombine_s8(vext_s8(x_filter, x_filter, 1), x_filter);
if (width == 4) {
- const uint8x16_t perm_tbl = vld1q_u8(kMatMulPermuteTbl);
+ const uint8x16_t perm_tbl = vld1q_u8(kMatMul6PermuteTbl);
do {
uint8x16_t s0, s1, s2, s3;
load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);
@@ -213,7 +236,7 @@
height -= 4;
} while (height > 0);
} else {
- const uint8x16x2_t perm_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
+ const uint8x16x2_t perm_tbl = vld1q_u8_x2(kMatMul6PermuteTbl);
do {
int w = width;
@@ -314,7 +337,7 @@
src += 7 * src_stride;
// This operation combines a conventional transpose and the sample permute
- // (see horizontal case) required before computing the dot product.
+ // required before computing the dot product.
uint8x16_t s0123, s1234, s2345, s3456;
transpose_concat_elems_u8_4x4(s0, s1, s2, s3, &s0123);
transpose_concat_elems_u8_4x4(s1, s2, s3, s4, &s1234);
@@ -368,7 +391,7 @@
s += 7 * src_stride;
// This operation combines a conventional transpose and the sample permute
- // (see horizontal case) required before computing the dot product.
+ // required before computing the dot product.
uint8x16_t s0123_lo, s0123_hi, s1234_lo, s1234_hi, s2345_lo, s2345_hi,
s3456_lo, s3456_hi;
transpose_concat_elems_u8_8x4(s0, s1, s2, s3, &s0123_lo, &s0123_hi);
@@ -471,7 +494,7 @@
src += 4 * src_stride;
// This operation combines a conventional transpose and the sample permute
- // (see horizontal case) required before computing the dot product.
+ // required before computing the dot product.
uint8x16_t s0123;
transpose_concat_elems_u8_4x4(s0, s1, s2, s3, &s0123);
@@ -519,7 +542,7 @@
s += 4 * src_stride;
// This operation combines a conventional transpose and the sample permute
- // (see horizontal case) required before computing the dot product.
+ // required before computing the dot product.
uint8x16_t s0123_lo, s0123_hi;
transpose_concat_elems_u8_8x4(s0, s1, s2, s3, &s0123_lo, &s0123_hi);