Improve horiz 8-tap Neon I8MM convolve_2d_sr implementation

Improve the horiz 8-tap pass of the Neon I8MM implementation of
convolve_2d_sr by replacing the USDOT instructions with a USMMLA plus an
extra UMLSL.

Change-Id: Ifaf3234f2379bf2110cdd46820983eca65f450fa
diff --git a/av1/common/arm/convolve_neon_i8mm.c b/av1/common/arm/convolve_neon_i8mm.c
index ef3a8d3..20b9af4 100644
--- a/av1/common/arm/convolve_neon_i8mm.c
+++ b/av1/common/arm/convolve_neon_i8mm.c
@@ -890,42 +890,53 @@
 }
 
 static inline int16x8_t convolve8_8_2d_h(uint8x16_t samples,
-                                         const int8x8_t filters,
-                                         const uint8x16x3_t permute_tbl,
-                                         const int32x4_t horiz_const) {
-  // Permute samples ready for dot product.
-  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
-  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
-  // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
-  uint8x16_t perm_samples[3] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
-                                 vqtbl1q_u8(samples, permute_tbl.val[1]),
-                                 vqtbl1q_u8(samples, permute_tbl.val[2]) };
+                                         const int8x16_t x_filter,
+                                         const uint8x8_t f0,
+                                         const uint8x16x2_t permute_tbl,
+                                         const int16x8_t horiz_const) {
+  // Permute samples ready for matrix multiply.
+  // { 1,  2,  3,  4,  5,  6,  7,  8,  3,  4,  5,  6,  7,  8,  9, 10 }
+  // { 5,  6,  7,  8,  9, 10, 11, 12,  7,  8,  9, 10, 11, 12, 13, 14 }
+  uint8x16_t perm_samples[2] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
+                                 vqtbl1q_u8(samples, permute_tbl.val[1]) };
 
-  int32x4_t sum0123 =
-      vusdotq_lane_s32(horiz_const, perm_samples[0], filters, 0);
-  sum0123 = vusdotq_lane_s32(sum0123, perm_samples[1], filters, 1);
+  // Calculate partial 7-tap convolution.
+  int32x4_t sum0123 = vusmmlaq_s32(vdupq_n_s32(0), perm_samples[0], x_filter);
+  int32x4_t sum4567 = vusmmlaq_s32(vdupq_n_s32(0), perm_samples[1], x_filter);
+  int16x8_t sum = vcombine_s16(vmovn_s32(sum0123), vmovn_s32(sum4567));
 
-  int32x4_t sum4567 =
-      vusdotq_lane_s32(horiz_const, perm_samples[1], filters, 0);
-  sum4567 = vusdotq_lane_s32(sum4567, perm_samples[2], filters, 1);
+  // Apply tap 0 and accumulate.
+  sum = vreinterpretq_s16_u16(
+      vmlsl_u8(vreinterpretq_u16_s16(sum), vget_low_u8(samples), f0));
 
-  // Narrow and re-pack.
+  sum = vaddq_s16(sum, horiz_const);
+
   // We halved the convolution filter values so -1 from the right shift.
-  return vcombine_s16(vshrn_n_s32(sum0123, ROUND0_BITS - 1),
-                      vshrn_n_s32(sum4567, ROUND0_BITS - 1));
+  return vshrq_n_s16(sum, ROUND0_BITS - 1);
 }
 
 static inline void convolve_2d_sr_horiz_8tap_neon_i8mm(
     const uint8_t *src, int src_stride, int16_t *im_block, int im_stride, int w,
     int im_h, const int16_t *x_filter_ptr) {
   // Filter values are even, so halve to reduce intermediate precision reqs.
-  const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
+  const int8x8_t x_filter_s8 = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
+
+  // Stagger the filter for use with the matrix multiply instructions.
+  // { f1, f2, f3, f4, f5, f6, f7, 0, 0, f1, f2, f3, f4, f5, f6, f7 }
+  const uint8x16_t filter_idx = vld1q_u8(kFilterPermuteTbl);
+  const int8x16_t x_filter =
+      vqtbl1q_s8(vcombine_s8(x_filter_s8, vdup_n_s8(0)), filter_idx);
+
+  // Since f0 is always negative and s0 is unsigned, subtract (unsigned) s0 *
+  // -f0 to avoid signed overflow.
+  const uint8x8_t f0 = vdup_n_u8(-x_filter_ptr[0] >> 1);
+  const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMul8PermuteTbl);
 
   const int bd = 8;
   // This shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
   // shifts - which are generally faster than rounding shifts on modern CPUs.
   // The outermost -1 is needed because we halved the filter values.
-  const int32x4_t horiz_const = vdupq_n_s32((1 << (bd + FILTER_BITS - 2)) +
+  const int16x8_t horiz_const = vdupq_n_s16((1 << (bd + FILTER_BITS - 2)) +
                                             (1 << ((ROUND0_BITS - 1) - 1)));
 
   const uint8_t *src_ptr = src;
@@ -933,7 +944,6 @@
   int dst_stride = im_stride;
   int height = im_h;
 
-  const uint8x16x3_t permute_tbl = vld1q_u8_x3(kDotProdPermuteTbl);
   do {
     const uint8_t *s = src_ptr;
     int16_t *d = dst_ptr;
@@ -943,10 +953,14 @@
       uint8x16_t s0, s1, s2, s3;
       load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
 
-      int16x8_t d0 = convolve8_8_2d_h(s0, x_filter, permute_tbl, horiz_const);
-      int16x8_t d1 = convolve8_8_2d_h(s1, x_filter, permute_tbl, horiz_const);
-      int16x8_t d2 = convolve8_8_2d_h(s2, x_filter, permute_tbl, horiz_const);
-      int16x8_t d3 = convolve8_8_2d_h(s3, x_filter, permute_tbl, horiz_const);
+      int16x8_t d0 =
+          convolve8_8_2d_h(s0, x_filter, f0, permute_tbl, horiz_const);
+      int16x8_t d1 =
+          convolve8_8_2d_h(s1, x_filter, f0, permute_tbl, horiz_const);
+      int16x8_t d2 =
+          convolve8_8_2d_h(s2, x_filter, f0, permute_tbl, horiz_const);
+      int16x8_t d3 =
+          convolve8_8_2d_h(s3, x_filter, f0, permute_tbl, horiz_const);
 
       store_s16_8x4(d, dst_stride, d0, d1, d2, d3);
 
@@ -966,7 +980,8 @@
 
     do {
       uint8x16_t s0 = vld1q_u8(s);
-      int16x8_t d0 = convolve8_8_2d_h(s0, x_filter, permute_tbl, horiz_const);
+      int16x8_t d0 =
+          convolve8_8_2d_h(s0, x_filter, f0, permute_tbl, horiz_const);
       vst1q_s16(d, d0);
 
       s += 8;