Propagate constants in Neon av1_convolve_2d_sr_horiz functions

Rounding modes and other convolution parameters are known ahead of
time. This patch propagates these values into the Neon code paths for
av1_convolve_2d_sr_horiz, enabling some useful simplifications and
optimizations.
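
The core arithmetic trick (a minimal sketch, not part of the patch): with
ROUND0_BITS known to be 3 at compile time, the rounding bias of a rounding
right shift can be folded into the existing horiz_const offset, so a plain
(and cheaper) non-rounding shift gives the same result - hence the "+4"
shim. The standalone C program below illustrates the equivalence; the names
offset and acc and the loop bounds are hypothetical and purely illustrative.
The halved-filter 8-tap paths shift by ROUND0_BITS - 1 instead, hence their
"+2" shim.

  #include <assert.h>
  #include <stdint.h>

  /* Hypothetical constants mirroring this code path:
   * bd = 8, FILTER_BITS = 7, ROUND0_BITS = 3. */
  enum { BD = 8, FILTER_BITS = 7, ROUND0_BITS = 3 };

  int main(void) {
    const int32_t offset = 1 << (BD + FILTER_BITS - 1);
    /* Keep acc + offset non-negative so the right shifts are well defined. */
    for (int32_t acc = -offset; acc <= offset; ++acc) {
      /* Old: rounding shift, as performed by vqrshlq_s32(sum, -ROUND0_BITS)
       * (saturation is assumed unnecessary for these intermediates). */
      const int32_t rounding =
          (acc + offset + (1 << (ROUND0_BITS - 1))) >> ROUND0_BITS;
      /* New: fold the +4 rounding bias into the constant, then use a plain
       * right shift, as vshrn_n_s32(sum, ROUND0_BITS) does in the patch. */
      const int32_t plain = (acc + (offset + 4)) >> ROUND0_BITS;
      assert(rounding == plain);
    }
    return 0;
  }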

Co-authored-by: Jonathan Wright <jonathan.wright@arm.com>

Change-Id: Ifa16317fca8d805fc5f7332ea5274a53eca0e3f9
diff --git a/av1/common/arm/convolve_neon.c b/av1/common/arm/convolve_neon.c
index 10e60cb..7377531 100644
--- a/av1/common/arm/convolve_neon.c
+++ b/av1/common/arm/convolve_neon.c
@@ -136,6 +136,100 @@
 
 #elif defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
 
+static INLINE int16x8_t convolve8_horiz_8_sdot(uint8x16_t samples,
+                                               const int8x8_t filters,
+                                               const int32x4_t correction,
+                                               const uint8x16_t range_limit,
+                                               const uint8x16x3_t permute_tbl) {
+  int8x16_t clamped_samples, permuted_samples[3];
+  int32x4_t sum[2];
+
+  /* Clamp sample range to [-128, 127] for 8-bit signed dot product. */
+  clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));
+
+  /* Permute samples ready for dot product. */
+  /* { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 } */
+  permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
+  /* { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 } */
+  permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
+  /* { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 } */
+  permuted_samples[2] = vqtbl1q_s8(clamped_samples, permute_tbl.val[2]);
+
+  /* Accumulate dot product into 'correction' to account for range clamp. */
+  /* First 4 output values. */
+  sum[0] = vdotq_lane_s32(correction, permuted_samples[0], filters, 0);
+  sum[0] = vdotq_lane_s32(sum[0], permuted_samples[1], filters, 1);
+  /* Second 4 output values. */
+  sum[1] = vdotq_lane_s32(correction, permuted_samples[1], filters, 0);
+  sum[1] = vdotq_lane_s32(sum[1], permuted_samples[2], filters, 1);
+
+  /* Narrow and re-pack. */
+  return vcombine_s16(vmovn_s32(sum[0]), vmovn_s32(sum[1]));
+}
+
+static INLINE int16x4_t convolve12_horiz_4_sdot(
+    uint8x16_t samples, const int8x16_t filters, const int32x4_t correction,
+    const uint8x16_t range_limit, const uint8x16x3_t permute_tbl) {
+  int8x16_t clamped_samples, permuted_samples[3];
+  int32x4_t sum;
+
+  /* Clamp sample range to [-128, 127] for 8-bit signed dot product. */
+  clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));
+
+  /* Permute samples ready for dot product. */
+  /* { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 } */
+  permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
+  /* { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 } */
+  permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
+  /* { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 } */
+  permuted_samples[2] = vqtbl1q_s8(clamped_samples, permute_tbl.val[2]);
+
+  /* Accumulate dot product into 'correction' to account for range clamp. */
+  /* First 4 output values. */
+  sum = vdotq_laneq_s32(correction, permuted_samples[0], filters, 0);
+  sum = vdotq_laneq_s32(sum, permuted_samples[1], filters, 1);
+  sum = vdotq_laneq_s32(sum, permuted_samples[2], filters, 2);
+
+  /* Narrow and re-pack. */
+  return vshrn_n_s32(sum, ROUND0_BITS);
+}
+
+static INLINE int16x8_t convolve12_horiz_8_sdot(
+    uint8x16_t samples0, uint8x16_t samples1, const int8x16_t filters,
+    const int32x4_t correction, const uint8x16_t range_limit,
+    const uint8x16x3_t permute_tbl) {
+  int8x16_t clamped_samples[2], permuted_samples[4];
+  int32x4_t sum[2];
+
+  /* Clamp sample range to [-128, 127] for 8-bit signed dot product. */
+  clamped_samples[0] = vreinterpretq_s8_u8(vsubq_u8(samples0, range_limit));
+  clamped_samples[1] = vreinterpretq_s8_u8(vsubq_u8(samples1, range_limit));
+
+  /* Permute samples ready for dot product. */
+  /* { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 } */
+  permuted_samples[0] = vqtbl1q_s8(clamped_samples[0], permute_tbl.val[0]);
+  /* { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 } */
+  permuted_samples[1] = vqtbl1q_s8(clamped_samples[0], permute_tbl.val[1]);
+  /* { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 } */
+  permuted_samples[2] = vqtbl1q_s8(clamped_samples[0], permute_tbl.val[2]);
+  /* {12, 13, 14, 15, 13, 14, 15, 16, 14, 15, 16, 17, 15, 16, 17, 18 } */
+  permuted_samples[3] = vqtbl1q_s8(clamped_samples[1], permute_tbl.val[2]);
+
+  /* Accumulate dot product into 'correction' to account for range clamp. */
+  /* First 4 output values. */
+  sum[0] = vdotq_laneq_s32(correction, permuted_samples[0], filters, 0);
+  sum[0] = vdotq_laneq_s32(sum[0], permuted_samples[1], filters, 1);
+  sum[0] = vdotq_laneq_s32(sum[0], permuted_samples[2], filters, 2);
+  /* Second 4 output values. */
+  sum[1] = vdotq_laneq_s32(correction, permuted_samples[1], filters, 0);
+  sum[1] = vdotq_laneq_s32(sum[1], permuted_samples[2], filters, 1);
+  sum[1] = vdotq_laneq_s32(sum[1], permuted_samples[3], filters, 2);
+
+  /* Narrow and re-pack. */
+  return vcombine_s16(vshrn_n_s32(sum[0], ROUND0_BITS),
+                      vshrn_n_s32(sum[1], ROUND0_BITS));
+}
+
 static INLINE int16x4_t convolve12_4_sdot(uint8x16_t samples,
                                           const int8x16_t filters,
                                           const int32x4_t correction,
@@ -2125,10 +2219,66 @@
 
 #if defined(__aarch64__) && defined(__ARM_FEATURE_MATMUL_INT8)
 
-static INLINE void av1_convolve_2d_sr_horiz_12tap_neon(
+static INLINE int16x4_t convolve12_horiz_4_usdot(uint8x16_t samples,
+                                                 const int8x16_t filters,
+                                                 const uint8x16x3_t permute_tbl,
+                                                 int32x4_t horiz_const) {
+  uint8x16_t permuted_samples[3];
+  int32x4_t sum;
+
+  /* Permute samples ready for dot product. */
+  /* { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 } */
+  permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]);
+  /* { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 } */
+  permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]);
+  /* { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 } */
+  permuted_samples[2] = vqtbl1q_u8(samples, permute_tbl.val[2]);
+
+  /* First 4 output values. */
+  sum = vusdotq_laneq_s32(horiz_const, permuted_samples[0], filters, 0);
+  sum = vusdotq_laneq_s32(sum, permuted_samples[1], filters, 1);
+  sum = vusdotq_laneq_s32(sum, permuted_samples[2], filters, 2);
+
+  /* Narrow and re-pack. */
+  return vshrn_n_s32(sum, ROUND0_BITS);
+}
+
+static INLINE int16x8_t convolve12_horiz_8_usdot(uint8x16_t samples0,
+                                                 uint8x16_t samples1,
+                                                 const int8x16_t filters,
+                                                 const uint8x16x3_t permute_tbl,
+                                                 const int32x4_t horiz_const) {
+  uint8x16_t permuted_samples[4];
+  int32x4_t sum[2];
+
+  /* Permute samples ready for dot product. */
+  /* { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 } */
+  permuted_samples[0] = vqtbl1q_u8(samples0, permute_tbl.val[0]);
+  /* { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 } */
+  permuted_samples[1] = vqtbl1q_u8(samples0, permute_tbl.val[1]);
+  /* { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 } */
+  permuted_samples[2] = vqtbl1q_u8(samples0, permute_tbl.val[2]);
+  /* {12, 13, 14, 15, 13, 14, 15, 16, 14, 15, 16, 17, 15, 16, 17, 18 } */
+  permuted_samples[3] = vqtbl1q_u8(samples1, permute_tbl.val[2]);
+
+  /* First 4 output values. */
+  sum[0] = vusdotq_laneq_s32(horiz_const, permuted_samples[0], filters, 0);
+  sum[0] = vusdotq_laneq_s32(sum[0], permuted_samples[1], filters, 1);
+  sum[0] = vusdotq_laneq_s32(sum[0], permuted_samples[2], filters, 2);
+  /* Second 4 output values. */
+  sum[1] = vusdotq_laneq_s32(horiz_const, permuted_samples[1], filters, 0);
+  sum[1] = vusdotq_laneq_s32(sum[1], permuted_samples[2], filters, 1);
+  sum[1] = vusdotq_laneq_s32(sum[1], permuted_samples[3], filters, 2);
+
+  /* Narrow and re-pack. */
+  return vcombine_s16(vshrn_n_s32(sum[0], ROUND0_BITS),
+                      vshrn_n_s32(sum[1], ROUND0_BITS));
+}
+
+static INLINE void convolve_2d_sr_horiz_12tap_neon(
     const uint8_t *src_ptr, int src_stride, int16_t *dst_ptr,
     const int dst_stride, int w, int h, const int16x8_t x_filter_0_7,
-    const int16x4_t x_filter_8_11, const int round_0) {
+    const int16x4_t x_filter_8_11) {
   const int bd = 8;
 
   // Special case the following no-op filter as 128 won't fit into the
@@ -2136,7 +2286,6 @@
   // { 0, 0, 0, 0, 0, 128, 0, 0, 0, 0, 0, 0 }
   if (vgetq_lane_s16(x_filter_0_7, 5) == 128) {
     const int16x8_t horiz_const = vdupq_n_s16((1 << (bd - 1)));
-    const int16x8_t shift_round_0 = vdupq_n_s16(FILTER_BITS - round_0);
     // Undo the horizontal offset in the calling function.
     src_ptr += 5;
 
@@ -2144,7 +2293,8 @@
       for (int j = 0; j < w; j += 8) {
         uint8x8_t s0 = vld1_u8(src_ptr + i * src_stride + j);
         uint16x8_t t0 = vaddw_u8(vreinterpretq_u16_s16(horiz_const), s0);
-        int16x8_t d0 = vqrshlq_s16(vreinterpretq_s16_u16(t0), shift_round_0);
+        int16x8_t d0 =
+            vshlq_n_s16(vreinterpretq_s16_u16(t0), FILTER_BITS - ROUND0_BITS);
         if (w == 2) {
           store_s16_2x1(dst_ptr + i * dst_stride, vget_low_s16(d0), 0);
         } else if (w == 4) {
@@ -2161,9 +2311,10 @@
     };
     const int8x16_t x_filter = vcombine_s8(vmovn_s16(x_filter_s16.val[0]),
                                            vmovn_s16(x_filter_s16.val[1]));
-
-    const int32x4_t horiz_const = vdupq_n_s32((1 << (bd + FILTER_BITS - 1)));
-    const int32x4_t shift_round_0 = vdupq_n_s32(-round_0);
+    // This shim of +4 enables us to use non-rounding shifts - which are
+    // generally faster than rounding shifts on modern CPUs.
+    const int32x4_t horiz_const =
+        vdupq_n_s32((1 << (bd + FILTER_BITS - 1)) + 4);
     const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
 
     if (w <= 4) {
@@ -2178,14 +2329,10 @@
 
           load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
 
-          d0 = convolve12_4_usdot(s0, x_filter, permute_tbl, horiz_const,
-                                  shift_round_0);
-          d1 = convolve12_4_usdot(s1, x_filter, permute_tbl, horiz_const,
-                                  shift_round_0);
-          d2 = convolve12_4_usdot(s2, x_filter, permute_tbl, horiz_const,
-                                  shift_round_0);
-          d3 = convolve12_4_usdot(s3, x_filter, permute_tbl, horiz_const,
-                                  shift_round_0);
+          d0 = convolve12_horiz_4_usdot(s0, x_filter, permute_tbl, horiz_const);
+          d1 = convolve12_horiz_4_usdot(s1, x_filter, permute_tbl, horiz_const);
+          d2 = convolve12_horiz_4_usdot(s2, x_filter, permute_tbl, horiz_const);
+          d3 = convolve12_horiz_4_usdot(s3, x_filter, permute_tbl, horiz_const);
 
           if (w == 2) {
             store_s16_2x1(d + 0 * dst_stride, d0, 0);
@@ -2217,8 +2364,7 @@
 
           s0 = vld1q_u8(s);
 
-          d0 = convolve12_4_usdot(s0, x_filter, permute_tbl, horiz_const,
-                                  shift_round_0);
+          d0 = convolve12_horiz_4_usdot(s0, x_filter, permute_tbl, horiz_const);
 
           if (w == 2) {
             store_s16_2x1(d, d0, 0);
@@ -2247,14 +2393,14 @@
           load_u8_16x4(s, src_stride, &s0[0], &s1[0], &s2[0], &s3[0]);
           load_u8_16x4(s + 4, src_stride, &s0[1], &s1[1], &s2[1], &s3[1]);
 
-          d0 = convolve12_8_usdot(s0[0], s0[1], x_filter, permute_tbl,
-                                  horiz_const, shift_round_0);
-          d1 = convolve12_8_usdot(s1[0], s1[1], x_filter, permute_tbl,
-                                  horiz_const, shift_round_0);
-          d2 = convolve12_8_usdot(s2[0], s2[1], x_filter, permute_tbl,
-                                  horiz_const, shift_round_0);
-          d3 = convolve12_8_usdot(s3[0], s3[1], x_filter, permute_tbl,
-                                  horiz_const, shift_round_0);
+          d0 = convolve12_horiz_8_usdot(s0[0], s0[1], x_filter, permute_tbl,
+                                        horiz_const);
+          d1 = convolve12_horiz_8_usdot(s1[0], s1[1], x_filter, permute_tbl,
+                                        horiz_const);
+          d2 = convolve12_horiz_8_usdot(s2[0], s2[1], x_filter, permute_tbl,
+                                        horiz_const);
+          d3 = convolve12_horiz_8_usdot(s3[0], s3[1], x_filter, permute_tbl,
+                                        horiz_const);
 
           store_s16_8x4(d, dst_stride, d0, d1, d2, d3);
 
@@ -2280,8 +2426,8 @@
           s0[0] = vld1q_u8(s);
           s0[1] = vld1q_u8(s + 4);
 
-          d0 = convolve12_8_usdot(s0[0], s0[1], x_filter, permute_tbl,
-                                  horiz_const, shift_round_0);
+          d0 = convolve12_horiz_8_usdot(s0[0], s0[1], x_filter, permute_tbl,
+                                        horiz_const);
 
           vst1q_s16(d, d0);
 
@@ -2299,10 +2445,10 @@
 
 #elif defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
 
-static INLINE void av1_convolve_2d_sr_horiz_12tap_neon(
+static INLINE void convolve_2d_sr_horiz_12tap_neon(
     const uint8_t *src_ptr, int src_stride, int16_t *dst_ptr,
     const int dst_stride, int w, int h, const int16x8_t x_filter_0_7,
-    const int16x4_t x_filter_8_11, const int round_0) {
+    const int16x4_t x_filter_8_11) {
   const int bd = 8;
 
   // Special case the following no-op filter as 128 won't fit into the
@@ -2310,7 +2456,6 @@
   // { 0, 0, 0, 0, 0, 128, 0, 0, 0, 0, 0, 0 }
   if (vgetq_lane_s16(x_filter_0_7, 5) == 128) {
     const int16x8_t horiz_const = vdupq_n_s16((1 << (bd - 1)));
-    const int16x8_t shift_round_0 = vdupq_n_s16(FILTER_BITS - round_0);
     // Undo the horizontal offset in the calling function.
     src_ptr += 5;
 
@@ -2318,7 +2463,8 @@
       for (int j = 0; j < w; j += 8) {
         uint8x8_t s0 = vld1_u8(src_ptr + i * src_stride + j);
         uint16x8_t t0 = vaddw_u8(vreinterpretq_u16_s16(horiz_const), s0);
-        int16x8_t d0 = vqrshlq_s16(vreinterpretq_s16_u16(t0), shift_round_0);
+        int16x8_t d0 =
+            vshlq_n_s16(vreinterpretq_s16_u16(t0), FILTER_BITS - ROUND0_BITS);
         if (w == 2) {
           store_s16_2x1(dst_ptr + i * dst_stride, vget_low_s16(d0), 0);
         } else if (w == 4) {
@@ -2329,8 +2475,6 @@
       }
     }
   } else {
-    const int32x4_t shift_round_0 = vdupq_n_s32(-round_0);
-
     // Narrow filter values to 8-bit.
     const int16x8x2_t x_filter_s16 = {
       { x_filter_0_7, vcombine_s16(x_filter_8_11, vdup_n_s16(0)) }
@@ -2338,8 +2482,10 @@
     const int8x16_t x_filter = vcombine_s8(vmovn_s16(x_filter_s16.val[0]),
                                            vmovn_s16(x_filter_s16.val[1]));
 
+    // This shim of +4 enables us to use non-rounding shifts - which are
+    // generally faster than rounding shifts on modern CPUs.
+    const int32_t horiz_const = ((1 << (bd + FILTER_BITS - 1)) + 4);
     // Dot product constants.
-    const int32_t horiz_const = (1 << (bd + FILTER_BITS - 1));
     const int32x4_t correct_tmp =
         vaddq_s32(vpaddlq_s16(vshlq_n_s16(x_filter_s16.val[0], 7)),
                   vpaddlq_s16(vshlq_n_s16(x_filter_s16.val[1], 7)));
@@ -2360,14 +2506,14 @@
 
           load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
 
-          d0 = convolve12_4_sdot(s0, x_filter, correction, range_limit,
-                                 permute_tbl, shift_round_0);
-          d1 = convolve12_4_sdot(s1, x_filter, correction, range_limit,
-                                 permute_tbl, shift_round_0);
-          d2 = convolve12_4_sdot(s2, x_filter, correction, range_limit,
-                                 permute_tbl, shift_round_0);
-          d3 = convolve12_4_sdot(s3, x_filter, correction, range_limit,
-                                 permute_tbl, shift_round_0);
+          d0 = convolve12_horiz_4_sdot(s0, x_filter, correction, range_limit,
+                                       permute_tbl);
+          d1 = convolve12_horiz_4_sdot(s1, x_filter, correction, range_limit,
+                                       permute_tbl);
+          d2 = convolve12_horiz_4_sdot(s2, x_filter, correction, range_limit,
+                                       permute_tbl);
+          d3 = convolve12_horiz_4_sdot(s3, x_filter, correction, range_limit,
+                                       permute_tbl);
 
           if (w == 2) {
             store_s16_2x1(d + 0 * dst_stride, d0, 0);
@@ -2399,8 +2545,8 @@
 
           s0 = vld1q_u8(s);
 
-          d0 = convolve12_4_sdot(s0, x_filter, correction, range_limit,
-                                 permute_tbl, shift_round_0);
+          d0 = convolve12_horiz_4_sdot(s0, x_filter, correction, range_limit,
+                                       permute_tbl);
 
           if (w == 2) {
             store_s16_2x1(d, d0, 0);
@@ -2429,14 +2575,14 @@
           load_u8_16x4(s, src_stride, &s0[0], &s1[0], &s2[0], &s3[0]);
           load_u8_16x4(s + 4, src_stride, &s0[1], &s1[1], &s2[1], &s3[1]);
 
-          d0 = convolve12_8_sdot(s0[0], s0[1], x_filter, correction,
-                                 range_limit, permute_tbl, shift_round_0);
-          d1 = convolve12_8_sdot(s1[0], s1[1], x_filter, correction,
-                                 range_limit, permute_tbl, shift_round_0);
-          d2 = convolve12_8_sdot(s2[0], s2[1], x_filter, correction,
-                                 range_limit, permute_tbl, shift_round_0);
-          d3 = convolve12_8_sdot(s3[0], s3[1], x_filter, correction,
-                                 range_limit, permute_tbl, shift_round_0);
+          d0 = convolve12_horiz_8_sdot(s0[0], s0[1], x_filter, correction,
+                                       range_limit, permute_tbl);
+          d1 = convolve12_horiz_8_sdot(s1[0], s1[1], x_filter, correction,
+                                       range_limit, permute_tbl);
+          d2 = convolve12_horiz_8_sdot(s2[0], s2[1], x_filter, correction,
+                                       range_limit, permute_tbl);
+          d3 = convolve12_horiz_8_sdot(s3[0], s3[1], x_filter, correction,
+                                       range_limit, permute_tbl);
 
           store_s16_8x4(d, dst_stride, d0, d1, d2, d3);
 
@@ -2462,8 +2608,8 @@
           s0[0] = vld1q_u8(s);
           s0[1] = vld1q_u8(s + 4);
 
-          d0 = convolve12_8_sdot(s0[0], s0[1], x_filter, correction,
-                                 range_limit, permute_tbl, shift_round_0);
+          d0 = convolve12_horiz_8_sdot(s0[0], s0[1], x_filter, correction,
+                                       range_limit, permute_tbl);
 
           vst1q_s16(d, d0);
 
@@ -2487,7 +2633,7 @@
     const int16x4_t s6, const int16x4_t s7, const int16x4_t s8,
     const int16x4_t s9, const int16x4_t s10, const int16x4_t s11,
     const int16x8_t x_filter_0_7, const int16x4_t x_filter_8_11,
-    const int32x4_t horiz_const, const int32x4_t shift_round_0) {
+    const int32x4_t horiz_const) {
   const int16x4_t x_filter_0_3 = vget_low_s16(x_filter_0_7);
   const int16x4_t x_filter_4_7 = vget_high_s16(x_filter_0_7);
   int32x4_t sum;
@@ -2506,9 +2652,7 @@
   sum = vmlal_lane_s16(sum, s10, x_filter_8_11, 2);
   sum = vmlal_lane_s16(sum, s11, x_filter_8_11, 3);
 
-  sum = vqrshlq_s32(sum, shift_round_0);
-
-  return vmovn_s32(sum);
+  return vshrn_n_s32(sum, ROUND0_BITS);
 }
 
 // 4 column per iteration horizontal filtering for 12-tap convolve_2d_sr.
@@ -2516,8 +2660,7 @@
 static INLINE void horiz_filter_12tap_w4_single_row(
     const uint8_t *src_ptr, int src_stride, int16_t *dst_ptr,
     const int dst_stride, int w, int h, const int16x8_t x_filter_0_7,
-    const int16x4_t x_filter_8_11, const int32x4_t horiz_const,
-    const int32x4_t shift_round_0) {
+    const int16x4_t x_filter_8_11, const int32x4_t horiz_const) {
   do {
     const uint8_t *s = src_ptr;
     int16_t *d = dst_ptr;
@@ -2549,7 +2692,7 @@
 
       d0 = convolve12_horiz_4x4_s16(s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10,
                                     s11, x_filter_0_7, x_filter_8_11,
-                                    horiz_const, shift_round_0);
+                                    horiz_const);
 
       if (w == 2) {
         store_s16_2x1(d, d0, 0);
@@ -2568,13 +2711,14 @@
   } while (h > 0);
 }
 
-static INLINE void av1_convolve_2d_sr_horiz_12tap_neon(
+static INLINE void convolve_2d_sr_horiz_12tap_neon(
     const uint8_t *src_ptr, int src_stride, int16_t *dst_ptr,
     const int dst_stride, int w, int h, const int16x8_t x_filter_0_7,
-    const int16x4_t x_filter_8_11, const int round_0) {
+    const int16x4_t x_filter_8_11) {
   const int bd = 8;
-  const int32x4_t shift_round_0 = vdupq_n_s32(-(round_0));
-  const int32x4_t horiz_const = vdupq_n_s32((1 << (bd + FILTER_BITS - 1)));
+  // This shim of +4 enables us to use non-rounding shifts - which are
+  // generally faster than rounding shifts on modern CPUs.
+  const int32x4_t horiz_const = vdupq_n_s32((1 << (bd + FILTER_BITS - 1)) + 4);
 
 #if defined(__aarch64__)
   do {
@@ -2619,16 +2763,16 @@
 
       d0 = convolve12_horiz_4x4_s16(s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10,
                                     s11, x_filter_0_7, x_filter_8_11,
-                                    horiz_const, shift_round_0);
+                                    horiz_const);
       d1 = convolve12_horiz_4x4_s16(s1, s2, s3, s4, s5, s6, s7, s8, s9, s10,
                                     s11, s12, x_filter_0_7, x_filter_8_11,
-                                    horiz_const, shift_round_0);
+                                    horiz_const);
       d2 = convolve12_horiz_4x4_s16(s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
                                     s12, s13, x_filter_0_7, x_filter_8_11,
-                                    horiz_const, shift_round_0);
+                                    horiz_const);
       d3 = convolve12_horiz_4x4_s16(s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
                                     s13, s14, x_filter_0_7, x_filter_8_11,
-                                    horiz_const, shift_round_0);
+                                    horiz_const);
 
       transpose_s16_4x4d(&d0, &d1, &d2, &d3);
 
@@ -2666,12 +2810,11 @@
   if (h) {
     horiz_filter_12tap_w4_single_row(src_ptr, src_stride, dst_ptr, dst_stride,
                                      w, h, x_filter_0_7, x_filter_8_11,
-                                     horiz_const, shift_round_0);
+                                     horiz_const);
   }
 #else   // !defined(__aarch64__)
   horiz_filter_12tap_w4_single_row(src_ptr, src_stride, dst_ptr, dst_stride, w,
-                                   h, x_filter_0_7, x_filter_8_11, horiz_const,
-                                   shift_round_0);
+                                   h, x_filter_0_7, x_filter_8_11, horiz_const);
 #endif  // defined(__aarch64__)
 }
 
@@ -2679,9 +2822,36 @@
 
 #if defined(__aarch64__) && defined(__ARM_FEATURE_MATMUL_INT8)
 
-static INLINE void av1_convolve_2d_sr_horiz_neon(
+static INLINE int16x8_t convolve8_horiz_8_usdot(uint8x16_t samples,
+                                                const int8x8_t filters,
+                                                const uint8x16x3_t permute_tbl,
+                                                const int32x4_t horiz_const) {
+  uint8x16_t permuted_samples[3];
+  int32x4_t sum[2];
+
+  /* Permute samples ready for dot product. */
+  /* { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 } */
+  permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]);
+  /* { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 } */
+  permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]);
+  /* { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 } */
+  permuted_samples[2] = vqtbl1q_u8(samples, permute_tbl.val[2]);
+
+  /* First 4 output values. */
+  sum[0] = vusdotq_lane_s32(horiz_const, permuted_samples[0], filters, 0);
+  sum[0] = vusdotq_lane_s32(sum[0], permuted_samples[1], filters, 1);
+  /* Second 4 output values. */
+  sum[1] = vusdotq_lane_s32(horiz_const, permuted_samples[1], filters, 0);
+  sum[1] = vusdotq_lane_s32(sum[1], permuted_samples[2], filters, 1);
+
+  /* Narrow and re-pack. */
+  return vcombine_s16(vshrn_n_s32(sum[0], ROUND0_BITS - 1),
+                      vshrn_n_s32(sum[1], ROUND0_BITS - 1));
+}
+
+static INLINE void convolve_2d_sr_horiz_8tap_neon(
     const uint8_t *src, int src_stride, int16_t *im_block, int im_stride, int w,
-    int im_h, const int16x8_t x_filter_s16, const int round_0) {
+    int im_h, const int16x8_t x_filter_s16) {
   const int bd = 8;
 
   const uint8_t *src_ptr = src;
@@ -2693,13 +2863,12 @@
   // Filter values are even, so downshift by 1 to reduce intermediate precision
   // requirements.
   const int8x8_t x_filter = vshrn_n_s16(x_filter_s16, 1);
-  const int32x4_t horiz_const = vdupq_n_s32(1 << (bd + FILTER_BITS - 2));
-
-  assert(round_0 > 0);
+  // This shim of +2 enables us to use non-rounding shifts - which are
+  // generally faster than rounding shifts on modern CPUs.
+  const int32x4_t horiz_const = vdupq_n_s32((1 << (bd + FILTER_BITS - 2)) + 2);
 
   if (w <= 4) {
     const uint8x16x2_t permute_tbl = vld1q_u8_x2(dot_prod_permute_tbl);
-    const int16x4_t shift_round_0 = vdup_n_s16(-(round_0 - 1));
     uint8x16_t s0, s1, s2, s3;
     int32x4_t t0, t1, t2, t3;
     int16x4_t d0, d1, d2, d3;
@@ -2714,10 +2883,10 @@
       t2 = convolve8_4_usdot(s2, x_filter, permute_tbl, horiz_const);
       t3 = convolve8_4_usdot(s3, x_filter, permute_tbl, horiz_const);
 
-      d0 = vqrshl_s16(vmovn_s32(t0), shift_round_0);
-      d1 = vqrshl_s16(vmovn_s32(t1), shift_round_0);
-      d2 = vqrshl_s16(vmovn_s32(t2), shift_round_0);
-      d3 = vqrshl_s16(vmovn_s32(t3), shift_round_0);
+      d0 = vshrn_n_s32(t0, ROUND0_BITS - 1);
+      d1 = vshrn_n_s32(t1, ROUND0_BITS - 1);
+      d2 = vshrn_n_s32(t2, ROUND0_BITS - 1);
+      d3 = vshrn_n_s32(t3, ROUND0_BITS - 1);
 
       if (w == 2) {
         store_s16_2x1(dst_ptr + 0 * dst_stride, d0, 0);
@@ -2739,7 +2908,7 @@
       do {
         s0 = vld1q_u8(src_ptr);
         t0 = convolve8_4_usdot(s0, x_filter, permute_tbl, horiz_const);
-        d0 = vqrshl_s16(vmovn_s32(t0), shift_round_0);
+        d0 = vshrn_n_s32(t0, ROUND0_BITS - 1);
 
         if (w == 2) {
           store_s16_2x1(dst_ptr, d0, 0);
@@ -2754,7 +2923,6 @@
     }
   } else {
     const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
-    const int16x8_t shift_round_0 = vdupq_n_s16(-(round_0 - 1));
     uint8x16_t s0, s1, s2, s3;
     int16x8_t d0, d1, d2, d3;
 
@@ -2768,14 +2936,10 @@
       do {
         load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
 
-        d0 = convolve8_8_usdot(s0, x_filter, permute_tbl, horiz_const,
-                               shift_round_0);
-        d1 = convolve8_8_usdot(s1, x_filter, permute_tbl, horiz_const,
-                               shift_round_0);
-        d2 = convolve8_8_usdot(s2, x_filter, permute_tbl, horiz_const,
-                               shift_round_0);
-        d3 = convolve8_8_usdot(s3, x_filter, permute_tbl, horiz_const,
-                               shift_round_0);
+        d0 = convolve8_horiz_8_usdot(s0, x_filter, permute_tbl, horiz_const);
+        d1 = convolve8_horiz_8_usdot(s1, x_filter, permute_tbl, horiz_const);
+        d2 = convolve8_horiz_8_usdot(s2, x_filter, permute_tbl, horiz_const);
+        d3 = convolve8_horiz_8_usdot(s3, x_filter, permute_tbl, horiz_const);
 
         store_s16_8x4(d, dst_stride, d0, d1, d2, d3);
 
@@ -2799,8 +2963,7 @@
 
         do {
           s0 = vld1q_u8(s);
-          d0 = convolve8_8_usdot(s0, x_filter, permute_tbl, horiz_const,
-                                 shift_round_0);
+          d0 = convolve8_horiz_8_usdot(s0, x_filter, permute_tbl, horiz_const);
           vst1q_s16(d, d0);
 
           s += 8;
@@ -2818,9 +2981,9 @@
 
 #elif defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
 
-static INLINE void av1_convolve_2d_sr_horiz_neon(
+static INLINE void convolve_2d_sr_horiz_8tap_neon(
     const uint8_t *src, int src_stride, int16_t *im_block, int im_stride, int w,
-    int im_h, const int16x8_t x_filter_s16, const int round_0) {
+    int im_h, const int16x8_t x_filter_s16) {
   const int bd = 8;
 
   const uint8_t *src_ptr = src;
@@ -2832,18 +2995,16 @@
   // Filter values are even, so downshift by 1 to reduce intermediate precision
   // requirements.
   const int8x8_t x_filter = vshrn_n_s16(x_filter_s16, 1);
-  const int32_t horiz_const = (1 << (bd + FILTER_BITS - 2));
+  // This shim of +2 enables us to use non-rounding shifts - which are
+  // generally faster than rounding shifts on modern CPUs.
+  const int32_t horiz_const = ((1 << (bd + FILTER_BITS - 2)) + 2);
   // Dot product constants.
   const int16x8_t correct_tmp = vshlq_n_s16(x_filter_s16, 6);
-  const int32x4_t correction =
-      vdupq_n_s32(vaddlvq_s16(correct_tmp) + horiz_const);
+  int32x4_t correction = vdupq_n_s32(vaddlvq_s16(correct_tmp) + horiz_const);
   const uint8x16_t range_limit = vdupq_n_u8(128);
 
-  assert(round_0 > 0);
-
   if (w <= 4) {
     const uint8x16x2_t permute_tbl = vld1q_u8_x2(dot_prod_permute_tbl);
-    const int16x4_t shift_round_0 = vdup_n_s16(-(round_0 - 1));
     uint8x16_t s0, s1, s2, s3;
     int32x4_t t0, t1, t2, t3;
     int16x4_t d0, d1, d2, d3;
@@ -2858,10 +3019,10 @@
       t2 = convolve8_4_sdot(s2, x_filter, correction, range_limit, permute_tbl);
       t3 = convolve8_4_sdot(s3, x_filter, correction, range_limit, permute_tbl);
 
-      d0 = vqrshl_s16(vmovn_s32(t0), shift_round_0);
-      d1 = vqrshl_s16(vmovn_s32(t1), shift_round_0);
-      d2 = vqrshl_s16(vmovn_s32(t2), shift_round_0);
-      d3 = vqrshl_s16(vmovn_s32(t3), shift_round_0);
+      d0 = vshrn_n_s32(t0, ROUND0_BITS - 1);
+      d1 = vshrn_n_s32(t1, ROUND0_BITS - 1);
+      d2 = vshrn_n_s32(t2, ROUND0_BITS - 1);
+      d3 = vshrn_n_s32(t3, ROUND0_BITS - 1);
 
       if (w == 2) {
         store_s16_2x1(dst_ptr + 0 * dst_stride, d0, 0);
@@ -2884,7 +3045,7 @@
         s0 = vld1q_u8(src_ptr);
         t0 = convolve8_4_sdot(s0, x_filter, correction, range_limit,
                               permute_tbl);
-        d0 = vqrshl_s16(vmovn_s32(t0), shift_round_0);
+        d0 = vshrn_n_s32(t0, ROUND0_BITS - 1);
 
         if (w == 2) {
           store_s16_2x1(dst_ptr, d0, 0);
@@ -2899,7 +3060,6 @@
     }
   } else {
     const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
-    const int16x8_t shift_round_0 = vdupq_n_s16(-(round_0 - 1));
     uint8x16_t s0, s1, s2, s3;
     int16x8_t d0, d1, d2, d3;
 
@@ -2913,14 +3073,19 @@
       do {
         load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
 
-        d0 = convolve8_8_sdot(s0, x_filter, correction, range_limit,
-                              permute_tbl, shift_round_0);
-        d1 = convolve8_8_sdot(s1, x_filter, correction, range_limit,
-                              permute_tbl, shift_round_0);
-        d2 = convolve8_8_sdot(s2, x_filter, correction, range_limit,
-                              permute_tbl, shift_round_0);
-        d3 = convolve8_8_sdot(s3, x_filter, correction, range_limit,
-                              permute_tbl, shift_round_0);
+        d0 = convolve8_horiz_8_sdot(s0, x_filter, correction, range_limit,
+                                    permute_tbl);
+        d1 = convolve8_horiz_8_sdot(s1, x_filter, correction, range_limit,
+                                    permute_tbl);
+        d2 = convolve8_horiz_8_sdot(s2, x_filter, correction, range_limit,
+                                    permute_tbl);
+        d3 = convolve8_horiz_8_sdot(s3, x_filter, correction, range_limit,
+                                    permute_tbl);
+
+        d0 = vshrq_n_s16(d0, ROUND0_BITS - 1);
+        d1 = vshrq_n_s16(d1, ROUND0_BITS - 1);
+        d2 = vshrq_n_s16(d2, ROUND0_BITS - 1);
+        d3 = vshrq_n_s16(d3, ROUND0_BITS - 1);
 
         store_s16_8x4(d, dst_stride, d0, d1, d2, d3);
 
@@ -2945,7 +3110,8 @@
         do {
           s0 = vld1q_u8(s);
           d0 = convolve8_8_sdot(s0, x_filter, correction, range_limit,
-                                permute_tbl, shift_round_0);
+                                permute_tbl, vdupq_n_s16(0));
+          d0 = vshrq_n_s16(d0, ROUND0_BITS - 1);
           vst1q_s16(d, d0);
 
           s += 8;
@@ -2963,12 +3129,58 @@
 
 #else  // !(defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD))
 
+static INLINE int16x4_t convolve8_horiz_4x4_s16(
+    const int16x4_t s0, const int16x4_t s1, const int16x4_t s2,
+    const int16x4_t s3, const int16x4_t s4, const int16x4_t s5,
+    const int16x4_t s6, const int16x4_t s7, const int16x8_t filter,
+    const int16x4_t horiz_const) {
+  const int16x4_t filter_lo = vget_low_s16(filter);
+  const int16x4_t filter_hi = vget_high_s16(filter);
+  int16x4_t sum;
+
+  sum = horiz_const;
+  sum = vmla_lane_s16(sum, s0, filter_lo, 0);
+  sum = vmla_lane_s16(sum, s1, filter_lo, 1);
+  sum = vmla_lane_s16(sum, s2, filter_lo, 2);
+  sum = vmla_lane_s16(sum, s3, filter_lo, 3);
+  sum = vmla_lane_s16(sum, s4, filter_hi, 0);
+  sum = vmla_lane_s16(sum, s5, filter_hi, 1);
+  sum = vmla_lane_s16(sum, s6, filter_hi, 2);
+  sum = vmla_lane_s16(sum, s7, filter_hi, 3);
+
+  return vshr_n_s16(sum, ROUND0_BITS - 1);
+}
+
+static INLINE int16x8_t convolve8_horiz_8x8_s16(
+    const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
+    const int16x8_t s3, const int16x8_t s4, const int16x8_t s5,
+    const int16x8_t s6, const int16x8_t s7, const int16x8_t filter,
+    const int16x8_t horiz_const) {
+  const int16x4_t filter_lo = vget_low_s16(filter);
+  const int16x4_t filter_hi = vget_high_s16(filter);
+  int16x8_t sum;
+
+  sum = horiz_const;
+  sum = vmlaq_lane_s16(sum, s0, filter_lo, 0);
+  sum = vmlaq_lane_s16(sum, s1, filter_lo, 1);
+  sum = vmlaq_lane_s16(sum, s2, filter_lo, 2);
+  sum = vmlaq_lane_s16(sum, s3, filter_lo, 3);
+  sum = vmlaq_lane_s16(sum, s4, filter_hi, 0);
+  sum = vmlaq_lane_s16(sum, s5, filter_hi, 1);
+  sum = vmlaq_lane_s16(sum, s6, filter_hi, 2);
+  sum = vmlaq_lane_s16(sum, s7, filter_hi, 3);
+
+  return vshrq_n_s16(sum, ROUND0_BITS - 1);
+}
+
 // Horizontal filtering for convolve_2d_sr for width multiple of 8
 // Processes one row at a time
-static INLINE void horiz_filter_w8_single_row(
-    const uint8_t *src_ptr, int src_stride, int16_t *dst_ptr,
-    const int dst_stride, int width, int height, const int16x8_t x_filter,
-    const int16x8_t horiz_const, const int16x8_t shift_round_0) {
+static INLINE void horiz_filter_w8_single_row(const uint8_t *src_ptr,
+                                              int src_stride, int16_t *dst_ptr,
+                                              const int dst_stride, int width,
+                                              int height,
+                                              const int16x8_t x_filter,
+                                              const int16x8_t horiz_const) {
   int16x8_t s0, s1, s2, s3, s4, s5, s6, s7;
   do {
     uint8x8_t t0 = vld1_u8(src_ptr);
@@ -2994,8 +3206,8 @@
       s6 = vextq_s16(sum, s7, 6);  // a6 a7 a8 a9 a10 a11 a12 a13
       s7 = vextq_s16(sum, s7, 7);  // a7 a8 a9 a10 a11 a12 a13 a14
 
-      int16x8_t res0 = convolve8_8x8_s16(sum, s1, s2, s3, s4, s5, s6, s7,
-                                         x_filter, horiz_const, shift_round_0);
+      int16x8_t res0 = convolve8_horiz_8x8_s16(sum, s1, s2, s3, s4, s5, s6, s7,
+                                               x_filter, horiz_const);
 
       vst1q_s16(dst_tmp, res0);
 
@@ -3011,10 +3223,12 @@
 
 // Horizontal filtering for convolve_2d_sr for width <= 4
 // Processes one row at a time
-static INLINE void horiz_filter_w4_single_row(
-    const uint8_t *src_ptr, int src_stride, int16_t *dst_ptr,
-    const int dst_stride, int width, int height, const int16x8_t x_filter,
-    const int16x4_t horiz_const, const int16x4_t shift_round_0) {
+static INLINE void horiz_filter_w4_single_row(const uint8_t *src_ptr,
+                                              int src_stride, int16_t *dst_ptr,
+                                              const int dst_stride, int width,
+                                              int height,
+                                              const int16x8_t x_filter,
+                                              const int16x4_t horiz_const) {
   int16x4_t s0, s1, s2, s3, s4, s5, s6, s7;
   do {
     const uint8_t *s = src_ptr;
@@ -3039,8 +3253,8 @@
     s6 = vext_s16(s4, s7, 2);  // a6 a7 a8 a9
     s7 = vext_s16(s4, s7, 3);  // a7 a8 a9 a10
 
-    int16x4_t d0 = convolve8_4x4_s16(s0, s1, s2, s3, s4, s5, s6, s7, x_filter,
-                                     horiz_const, shift_round_0);
+    int16x4_t d0 = convolve8_horiz_4x4_s16(s0, s1, s2, s3, s4, s5, s6, s7,
+                                           x_filter, horiz_const);
 
     if (width == 2) {
       store_s16_2x1(dst_ptr, d0, 0);
@@ -3054,9 +3268,9 @@
   } while (height > 0);
 }
 
-static INLINE void av1_convolve_2d_sr_horiz_neon(
+static INLINE void convolve_2d_sr_horiz_8tap_neon(
     const uint8_t *src, int src_stride, int16_t *im_block, int im_stride, int w,
-    int im_h, const int16x8_t x_filter_s16, const int round_0) {
+    int im_h, const int16x8_t x_filter_s16) {
   const int bd = 8;
 
   const uint8_t *src_ptr = src;
@@ -3069,11 +3283,10 @@
   // requirements.
   const int16x8_t x_filter = vshrq_n_s16(x_filter_s16, 1);
 
-  assert(round_0 > 0);
-
   if (w <= 4) {
-    const int16x4_t horiz_const = vdup_n_s16((1 << (bd + FILTER_BITS - 2)));
-    const int16x4_t shift_round_0 = vdup_n_s16(-(round_0 - 1));
+    // This shim of +2 enables us to use non-rounding shifts - which are
+    // generally faster than rounding shifts on modern CPUs.
+    const int16x4_t horiz_const = vdup_n_s16((1 << (bd + FILTER_BITS - 2)) + 2);
 
 #if defined(__aarch64__)
     do {
@@ -3104,14 +3317,14 @@
       s9 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t2)));
       s10 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t3)));
 
-      d0 = convolve8_4x4_s16(s0, s1, s2, s3, s4, s5, s6, s7, x_filter,
-                             horiz_const, shift_round_0);
-      d1 = convolve8_4x4_s16(s1, s2, s3, s4, s5, s6, s7, s8, x_filter,
-                             horiz_const, shift_round_0);
-      d2 = convolve8_4x4_s16(s2, s3, s4, s5, s6, s7, s8, s9, x_filter,
-                             horiz_const, shift_round_0);
-      d3 = convolve8_4x4_s16(s3, s4, s5, s6, s7, s8, s9, s10, x_filter,
-                             horiz_const, shift_round_0);
+      d0 = convolve8_horiz_4x4_s16(s0, s1, s2, s3, s4, s5, s6, s7, x_filter,
+                                   horiz_const);
+      d1 = convolve8_horiz_4x4_s16(s1, s2, s3, s4, s5, s6, s7, s8, x_filter,
+                                   horiz_const);
+      d2 = convolve8_horiz_4x4_s16(s2, s3, s4, s5, s6, s7, s8, s9, x_filter,
+                                   horiz_const);
+      d3 = convolve8_horiz_4x4_s16(s3, s4, s5, s6, s7, s8, s9, s10, x_filter,
+                                   horiz_const);
 
       transpose_s16_4x4d(&d0, &d1, &d2, &d3);
 
@@ -3132,17 +3345,19 @@
     if (height) {
       assert(height < 4);
       horiz_filter_w4_single_row(src_ptr, src_stride, dst_ptr, dst_stride, w,
-                                 height, x_filter, horiz_const, shift_round_0);
+                                 height, x_filter, horiz_const);
     }
 
 #else   // !defined(__aarch64__)
     horiz_filter_w4_single_row(src_ptr, src_stride, dst_ptr, dst_stride, w,
-                               height, x_filter, horiz_const, shift_round_0);
+                               height, x_filter, horiz_const);
 #endif  // defined(__aarch64__)
 
   } else {
-    const int16x8_t horiz_const = vdupq_n_s16((1 << (bd + FILTER_BITS - 2)));
-    const int16x8_t shift_round_0 = vdupq_n_s16(-(round_0 - 1));
+    // This shim of +2 enables us to use non-rounding shifts - which are
+    // generally faster than rounding shifts on modern CPUs.
+    const int16x8_t horiz_const =
+        vdupq_n_s16((1 << (bd + FILTER_BITS - 2)) + 2);
 
 #if defined(__aarch64__)
 
@@ -3183,22 +3398,22 @@
         s13 = vreinterpretq_s16_u16(vmovl_u8(t6));
         s14 = vreinterpretq_s16_u16(vmovl_u8(t7));
 
-        d0 = convolve8_8x8_s16(s0, s1, s2, s3, s4, s5, s6, s7, x_filter,
-                               horiz_const, shift_round_0);
-        d1 = convolve8_8x8_s16(s1, s2, s3, s4, s5, s6, s7, s8, x_filter,
-                               horiz_const, shift_round_0);
-        d2 = convolve8_8x8_s16(s2, s3, s4, s5, s6, s7, s8, s9, x_filter,
-                               horiz_const, shift_round_0);
-        d3 = convolve8_8x8_s16(s3, s4, s5, s6, s7, s8, s9, s10, x_filter,
-                               horiz_const, shift_round_0);
-        d4 = convolve8_8x8_s16(s4, s5, s6, s7, s8, s9, s10, s11, x_filter,
-                               horiz_const, shift_round_0);
-        d5 = convolve8_8x8_s16(s5, s6, s7, s8, s9, s10, s11, s12, x_filter,
-                               horiz_const, shift_round_0);
-        d6 = convolve8_8x8_s16(s6, s7, s8, s9, s10, s11, s12, s13, x_filter,
-                               horiz_const, shift_round_0);
-        d7 = convolve8_8x8_s16(s7, s8, s9, s10, s11, s12, s13, s14, x_filter,
-                               horiz_const, shift_round_0);
+        d0 = convolve8_horiz_8x8_s16(s0, s1, s2, s3, s4, s5, s6, s7, x_filter,
+                                     horiz_const);
+        d1 = convolve8_horiz_8x8_s16(s1, s2, s3, s4, s5, s6, s7, s8, x_filter,
+                                     horiz_const);
+        d2 = convolve8_horiz_8x8_s16(s2, s3, s4, s5, s6, s7, s8, s9, x_filter,
+                                     horiz_const);
+        d3 = convolve8_horiz_8x8_s16(s3, s4, s5, s6, s7, s8, s9, s10, x_filter,
+                                     horiz_const);
+        d4 = convolve8_horiz_8x8_s16(s4, s5, s6, s7, s8, s9, s10, s11, x_filter,
+                                     horiz_const);
+        d5 = convolve8_horiz_8x8_s16(s5, s6, s7, s8, s9, s10, s11, s12,
+                                     x_filter, horiz_const);
+        d6 = convolve8_horiz_8x8_s16(s6, s7, s8, s9, s10, s11, s12, s13,
+                                     x_filter, horiz_const);
+        d7 = convolve8_horiz_8x8_s16(s7, s8, s9, s10, s11, s12, s13, s14,
+                                     x_filter, horiz_const);
 
         transpose_s16_8x8(&d0, &d1, &d2, &d3, &d4, &d5, &d6, &d7);
 
@@ -3273,10 +3488,10 @@
         d2 = vaddq_s16(d2, horiz_const);
         d3 = vaddq_s16(d3, horiz_const);
 
-        d0 = vqrshlq_s16(d0, shift_round_0);
-        d1 = vqrshlq_s16(d1, shift_round_0);
-        d2 = vqrshlq_s16(d2, shift_round_0);
-        d3 = vqrshlq_s16(d3, shift_round_0);
+        d0 = vshrq_n_s16(d0, ROUND0_BITS - 1);
+        d1 = vshrq_n_s16(d1, ROUND0_BITS - 1);
+        d2 = vshrq_n_s16(d2, ROUND0_BITS - 1);
+        d3 = vshrq_n_s16(d3, ROUND0_BITS - 1);
 
         store_s16_8x4(d, dst_stride, d0, d1, d2, d3);
 
@@ -3299,12 +3514,12 @@
     if (height) {
       assert(height < 4);
       horiz_filter_w8_single_row(src_ptr, src_stride, dst_ptr, dst_stride, w,
-                                 height, x_filter, horiz_const, shift_round_0);
+                                 height, x_filter, horiz_const);
     }
 
 #else   // !defined(__aarch64__)
     horiz_filter_w8_single_row(src_ptr, src_stride, dst_ptr, dst_stride, w,
-                               height, x_filter, horiz_const, shift_round_0);
+                               height, x_filter, horiz_const);
 #endif  // defined(__aarch64__)
   }
 }
@@ -3870,6 +4085,7 @@
                              const InterpFilterParams *filter_params_y,
                              const int subpel_x_qn, const int subpel_y_qn,
                              ConvolveParams *conv_params) {
+  (void)conv_params;
   const int y_filter_taps = get_filter_tap(filter_params_y, subpel_y_qn);
   const int clamped_y_taps = y_filter_taps < 6 ? 6 : y_filter_taps;
   const int im_h = h + clamped_y_taps - 1;
@@ -3892,9 +4108,8 @@
     const int16x8_t y_filter_0_7 = vld1q_s16(y_filter_ptr);
     const int16x4_t y_filter_8_11 = vld1_s16(y_filter_ptr + 8);
 
-    av1_convolve_2d_sr_horiz_12tap_neon(src_ptr, src_stride, im_block,
-                                        im_stride, w, im_h, x_filter_0_7,
-                                        x_filter_8_11, conv_params->round_0);
+    convolve_2d_sr_horiz_12tap_neon(src_ptr, src_stride, im_block, im_stride, w,
+                                    im_h, x_filter_0_7, x_filter_8_11);
 
     convolve_2d_sr_vert_12tap_neon(im_block, im_stride, dst, dst_stride, w, h,
                                    y_filter_0_7, y_filter_8_11);
@@ -3905,8 +4120,8 @@
     const int16x8_t x_filter = vld1q_s16(x_filter_ptr);
     const int16x8_t y_filter = vld1q_s16(y_filter_ptr);
 
-    av1_convolve_2d_sr_horiz_neon(src_ptr, src_stride, im_block, im_stride, w,
-                                  im_h, x_filter, conv_params->round_0);
+    convolve_2d_sr_horiz_8tap_neon(src_ptr, src_stride, im_block, im_stride, w,
+                                   im_h, x_filter);
 
     if (clamped_y_taps <= 6) {
       convolve_2d_sr_vert_6tap_neon(im_block, im_stride, dst, dst_stride, w, h,