Improve 8-tap Neon I8MM convolve_x_sr implementation

Improve the 8-tap Neon I8MM implementation of convolve_x_sr by replacing
the USDOT dot-product instructions with USMMLA matrix multiplies for
filter taps 1-7, plus an extra UMLSL to apply filter tap 0.
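
USMMLA computes a 2x8 (unsigned) by 8x2 (signed) widening matrix
multiply accumulating into a 2x2 int32 result, i.e. four 8-wide dot
products per instruction. Staggering the halved filter taps into both
rows of the signed operand:

  { f1, f2, f3, f4, f5, f6, f7, 0, 0, f1, f2, f3, f4, f5, f6, f7 }

and permuting the samples into rows { s1..s8 } and { s3..s10 } gives
the 7-tap partial sums for four consecutive output pixels from a
single USMMLA. Tap 0 is applied separately: it is always negative, so
a UMLSL of the unsigned product s0 * -f0 subtracts its contribution.
(Index 16 in kFilterPermuteTbl is out of range for TBL and reads as
zero, which supplies the zero padding in the staggered filter.)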

Change-Id: Ibe77d8d6ae8504a30c5d0fbd9ac6763728c53470
diff --git a/av1/common/arm/convolve_neon_i8mm.c b/av1/common/arm/convolve_neon_i8mm.c
index 0ecf6b2..ef3a8d3 100644
--- a/av1/common/arm/convolve_neon_i8mm.c
+++ b/av1/common/arm/convolve_neon_i8mm.c
@@ -32,6 +32,19 @@
   3, 16, 17, 18, 7, 20, 21, 22, 11, 24, 25, 26, 15, 28, 29, 30
 };
 
+DECLARE_ALIGNED(16, static const uint8_t, kMatMul8PermuteTbl[32]) = {
+  // clang-format off
+  1,  2,  3,  4,  5,  6,  7,  8,  3,  4,  5,  6,  7,  8,  9, 10,
+  5,  6,  7,  8,  9, 10, 11, 12,  7,  8,  9, 10, 11, 12, 13, 14
+  // clang-format on
+};
+
+DECLARE_ALIGNED(16, static const uint8_t, kFilterPermuteTbl[16]) = {
+  // clang-format off
+  1,  2,  3,  4,  5,  6,  7, 16, 16,  1,  2,  3,  4,  5,  6,  7
+  // clang-format on
+};
+
 static inline int16x4_t convolve12_4_x(uint8x16_t samples[2],
                                        const int8x16_t filter[2],
                                        const uint8x16_t permute_tbl,
@@ -105,7 +118,7 @@
   const int32x4_t horiz_const = vdupq_n_s32(1 << (ROUND0_BITS - 1));
 
   if (w <= 4) {
-    const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);
+    const uint8x16_t permute_tbl = vld1q_u8(kMatMul6PermuteTbl);
 
     do {
       uint8x16_t s0[2], s1[2], s2[2], s3[2];
@@ -128,7 +141,7 @@
       h -= 4;
     } while (h != 0);
   } else {
-    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
+    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMul6PermuteTbl);
 
     do {
       const uint8_t *s = src;
@@ -158,35 +171,49 @@
   }
 }
 
-static inline uint8x8_t convolve8_8_x(uint8x16_t samples, const int8x8_t filter,
-                                      const uint8x16x3_t permute_tbl,
-                                      const int32x4_t horiz_const) {
-  // Permute samples ready for dot product.
-  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
-  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
-  // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
-  uint8x16_t perm_samples[3] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
-                                 vqtbl1q_u8(samples, permute_tbl.val[1]),
-                                 vqtbl1q_u8(samples, permute_tbl.val[2]) };
+static inline uint8x8_t convolve8_8_x(uint8x16_t samples,
+                                      const int8x16_t filter,
+                                      const uint8x8_t f0,
+                                      const uint8x16x2_t permute_tbl,
+                                      const int16x8_t horiz_const) {
+  // Permute samples ready for matrix multiply.
+  // { 1,  2,  3,  4,  5,  6,  7,  8,  3,  4,  5,  6,  7,  8,  9, 10 }
+  // { 5,  6,  7,  8,  9, 10, 11, 12,  7,  8,  9, 10, 11, 12, 13, 14 }
+  uint8x16_t perm_samples[2] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
+                                 vqtbl1q_u8(samples, permute_tbl.val[1]) };
 
-  int32x4_t sum0123 = vusdotq_lane_s32(horiz_const, perm_samples[0], filter, 0);
-  sum0123 = vusdotq_lane_s32(sum0123, perm_samples[1], filter, 1);
+  // Calculate partial 7-tap convolution.
+  int32x4_t sum0123 = vusmmlaq_s32(vdupq_n_s32(0), perm_samples[0], filter);
+  int32x4_t sum4567 = vusmmlaq_s32(vdupq_n_s32(0), perm_samples[1], filter);
+  int16x8_t sum = vcombine_s16(vmovn_s32(sum0123), vmovn_s32(sum4567));
 
-  int32x4_t sum4567 = vusdotq_lane_s32(horiz_const, perm_samples[1], filter, 0);
-  sum4567 = vusdotq_lane_s32(sum4567, perm_samples[2], filter, 1);
+  // Apply tap 0 and accumulate: f0 holds the negated tap, hence the UMLSL.
+  sum = vreinterpretq_s16_u16(
+      vmlsl_u8(vreinterpretq_u16_s16(sum), vget_low_u8(samples), f0));
 
-  int16x8_t sum_s16 = vcombine_s16(vmovn_s32(sum0123), vmovn_s32(sum4567));
+  sum = vaddq_s16(sum, horiz_const);
+
   // We halved the convolution filter values so -1 from the right shift.
-  return vqrshrun_n_s16(sum_s16, FILTER_BITS - 1);
+  return vqrshrun_n_s16(sum, FILTER_BITS - 1);
 }
 
 static inline void convolve_x_sr_8tap_neon_i8mm(
     const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst,
     ptrdiff_t dst_stride, int width, int height, const int16_t *filter_x,
-    const int32x4_t horiz_const) {
+    const int16x8_t horiz_const) {
   // Filter values are even, so halve to reduce intermediate precision reqs.
-  const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(filter_x), 1);
-  const uint8x16x3_t permute_tbl = vld1q_u8_x3(kDotProdPermuteTbl);
+  const int8x8_t x_filter_s8 = vshrn_n_s16(vld1q_s16(filter_x), 1);
+
+  // Stagger the filter for use with the matrix multiply instructions.
+  // { f1, f2, f3, f4, f5, f6, f7, 0, 0, f1, f2, f3, f4, f5, f6, f7 }
+  const uint8x16_t filter_idx = vld1q_u8(kFilterPermuteTbl);
+  const int8x16_t x_filter =
+      vqtbl1q_s8(vcombine_s8(x_filter_s8, vdup_n_s8(0)), filter_idx);
+
+  // Since f0 is always negative and s0 is unsigned, subtract (unsigned) s0 *
+  // -f0 to avoid signed overflow.
+  const uint8x8_t f0 = vdup_n_u8(-filter_x[0] >> 1);
+  const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMul8PermuteTbl);
 
   do {
     const uint8_t *s = src;
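
For cross-checking, the per-pixel arithmetic that convolve8_8_x() must
reproduce can be modelled in scalar C. A minimal sketch, assuming
FILTER_BITS == 7 and ROUND0_BITS == 3 as in libaom, with s pointing at
the first sample under the filter window (convolve8_x_scalar is a
hypothetical helper, not part of this patch):

    #include <stdint.h>

    // One output pixel with halved taps: rounding shim plus
    // sum(f[k]/2 * s[k]), then a single rounding shift by
    // FILTER_BITS - 1 with unsigned saturation, the scalar
    // equivalent of vqrshrun_n_s16(sum, FILTER_BITS - 1).
    static uint8_t convolve8_x_scalar(const uint8_t *s, const int16_t *f) {
      int32_t sum = (1 << (3 - 1)) / 2;  // halved shim, ROUND0_BITS == 3
      for (int k = 0; k < 8; ++k) {
        sum += (f[k] >> 1) * s[k];  // taps are even, so >> 1 is exact
      }
      int32_t res = (sum + (1 << 5)) >> 6;  // rounding shift by 6
      return (uint8_t)(res < 0 ? 0 : (res > 255 ? 255 : res));
    }
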
@@ -197,10 +224,10 @@
       uint8x16_t s0, s1, s2, s3;
       load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
 
-      uint8x8_t d0 = convolve8_8_x(s0, x_filter, permute_tbl, horiz_const);
-      uint8x8_t d1 = convolve8_8_x(s1, x_filter, permute_tbl, horiz_const);
-      uint8x8_t d2 = convolve8_8_x(s2, x_filter, permute_tbl, horiz_const);
-      uint8x8_t d3 = convolve8_8_x(s3, x_filter, permute_tbl, horiz_const);
+      uint8x8_t d0 = convolve8_8_x(s0, x_filter, f0, permute_tbl, horiz_const);
+      uint8x8_t d1 = convolve8_8_x(s1, x_filter, f0, permute_tbl, horiz_const);
+      uint8x8_t d2 = convolve8_8_x(s2, x_filter, f0, permute_tbl, horiz_const);
+      uint8x8_t d3 = convolve8_8_x(s3, x_filter, f0, permute_tbl, horiz_const);
 
       store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
 
@@ -262,7 +289,7 @@
       vcombine_s8(vext_s8(x_filter_s8, x_filter_s8, 1), x_filter_s8);
 
   if (width == 4) {
-    const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);
+    const uint8x16_t permute_tbl = vld1q_u8(kMatMul6PermuteTbl);
     do {
       uint8x16_t s0, s1, s2, s3;
       load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);
@@ -283,7 +310,7 @@
       height -= 4;
     } while (height != 0);
   } else {
-    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
+    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMul6PermuteTbl);
     do {
       const uint8_t *s = src;
       uint8_t *d = dst;
@@ -335,11 +362,12 @@
   // right shift by FILTER_BITS instead of two rounding right shifts: first by
   // ROUND0_BITS, and then subsequently by FILTER_BITS - ROUND0_BITS.
   // Halve the total because we will halve the filter values.
-  const int32x4_t horiz_const = vdupq_n_s32((1 << ((ROUND0_BITS - 1)) / 2));
+  const int32x4_t horiz_const_s32 = vdupq_n_s32((1 << (ROUND0_BITS - 1)) / 2);
+  const int16x8_t horiz_const_s16 = vdupq_n_s16((1 << (ROUND0_BITS - 1)) / 2);
 
   if (filter_taps <= 6) {
     convolve_x_sr_6tap_neon_i8mm(src + 1, src_stride, dst, dst_stride, w, h,
-                                 x_filter_ptr, horiz_const);
+                                 x_filter_ptr, horiz_const_s32);
     return;
   }
 
@@ -350,7 +378,7 @@
   }
 
   convolve_x_sr_8tap_neon_i8mm(src, src_stride, dst, dst_stride, w, h,
-                               x_filter_ptr, horiz_const);
+                               x_filter_ptr, horiz_const_s16);
 }
 
 static inline int16x4_t convolve12_4_y(const uint8x16_t s0, const uint8x16_t s1,
@@ -1131,7 +1159,7 @@
   const int32x4_t horiz_const = vdupq_n_s32((1 << (bd + FILTER_BITS - 2)) +
                                             (1 << ((ROUND0_BITS - 1) - 1)));
   const int16x8_t vert_const = vdupq_n_s16(1 << (bd - 1));
-  const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
+  const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMul6PermuteTbl);
 
   do {
     const uint8_t *s = src;
@@ -1208,7 +1236,7 @@
   const int16x8_t vert_const = vdupq_n_s16(1 << (bd - 1));
 
   if (w == 4) {
-    const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);
+    const uint8x16_t permute_tbl = vld1q_u8(kMatMul6PermuteTbl);
     uint8x16_t h_s0, h_s1, h_s2;
     load_u8_16x3(src, src_stride, &h_s0, &h_s1, &h_s2);
 
@@ -1251,7 +1279,7 @@
       h -= 4;
     } while (h != 0);
   } else {
-    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
+    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMul6PermuteTbl);
 
     do {
       int height = h;
diff --git a/av1/common/arm/convolve_neon_i8mm.h b/av1/common/arm/convolve_neon_i8mm.h
index 71b7461..78dadec 100644
--- a/av1/common/arm/convolve_neon_i8mm.h
+++ b/av1/common/arm/convolve_neon_i8mm.h
@@ -29,7 +29,7 @@
   8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14
 };
 
-DECLARE_ALIGNED(16, static const uint8_t, kMatMulPermuteTbl[32]) = {
+DECLARE_ALIGNED(16, static const uint8_t, kMatMul6PermuteTbl[32]) = {
   // clang-format off
   0,  1,  2,  3,  4,  5,  6,  7,  2,  3,  4,  5,  6,  7,  8,  9,
   4,  5,  6,  7,  8,  9, 10, 11,  6,  7,  8,  9, 10, 11, 12, 13
@@ -110,7 +110,7 @@
       vdupq_n_s32((1 << (bd + FILTER_BITS - 1)) + (1 << (ROUND0_BITS - 1)));
 
   if (w <= 4) {
-    const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);
+    const uint8x16_t permute_tbl = vld1q_u8(kMatMul6PermuteTbl);
 
     do {
       uint8x16_t s0[2], s1[2], s2[2], s3[2];
@@ -141,7 +141,7 @@
     } while (--h != 0);
 
   } else {
-    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
+    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMul6PermuteTbl);
 
     do {
       const uint8_t *s = src_ptr;
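
The lane ordering of the USMMLA decomposition can be sanity-checked in
isolation with the standalone sketch below (not part of this patch; tap
and sample values are arbitrary, and an AArch64 toolchain with I8MM,
e.g. -march=armv8.2-a+i8mm, is assumed). It computes the 7-tap partial
sums for four pixels with vusmmlaq_s32 and compares them to scalar sums:

    #include <arm_neon.h>
    #include <stdio.h>

    int main(void) {
      const uint8_t smp[16] = { 3, 1, 4, 1, 5, 9, 2, 6,
                                5, 3, 5, 8, 9, 7, 9, 3 };
      // Arbitrary halved taps; only f[1..7] feed the matrix multiply.
      const int8_t f[8] = { -1, 2, -6, 30, 40, -8, 3, -1 };

      // Staggered filter rows { f1..f7, 0 } and { 0, f1..f7 }.
      const int8_t fmat[16] = { f[1], f[2], f[3], f[4], f[5], f[6], f[7], 0,
                                0, f[1], f[2], f[3], f[4], f[5], f[6], f[7] };

      // Permuted sample rows { s1..s8 } and { s3..s10 }, matching
      // kMatMul8PermuteTbl.
      uint8_t pmat[16];
      for (int i = 0; i < 8; ++i) {
        pmat[i] = smp[1 + i];
        pmat[8 + i] = smp[3 + i];
      }

      int32x4_t sum =
          vusmmlaq_s32(vdupq_n_s32(0), vld1q_u8(pmat), vld1q_s8(fmat));
      int32_t out[4];
      vst1q_s32(out, sum);

      // Lane p holds the 7-tap partial sum for output pixel p.
      for (int p = 0; p < 4; ++p) {
        int32_t ref = 0;
        for (int k = 1; k <= 7; ++k) ref += f[k] * smp[p + k];
        printf("pixel %d: usmmla %d scalar %d %s\n", p, out[p], ref,
               out[p] == ref ? "OK" : "MISMATCH");
      }
      return 0;
    }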