Propagate constants in Neon dist_wtd_convolve_2d_horiz functions

Rounding modes and other convolution parameters are known ahead of
time. This patch propagates these values into the Neon code paths for
dist_wtd_convolve_2d_horiz_neon, enabling some useful simplifications
and optimizations.
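
As a sketch of the main simplification (stated in terms of libaom's
ROUND_POWER_OF_TWO macro, not code from this patch): a rounding right
shift by n is a truncating shift with a pre-added bias,

  ROUND_POWER_OF_TWO(x, n) == (x + (1 << (n - 1))) >> n

so folding 1 << (n - 1) into the horizontal offset constant lets a
variable-distance rounding shift become an immediate non-rounding
shift. Since the filter values are halved, n is ROUND0_BITS - 1 here:

  // Before: rounding shift by a runtime-variable distance.
  d0 = vqrshl_s16(vmovn_s32(t0), vdup_n_s16(-round_0));
  // After: bias folded into horiz_const; immediate truncating shift.
  d0 = vshrn_n_s32(t0, ROUND0_BITS - 1);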

Co-authored-by: Jonathan Wright <jonathan.wright@arm.com>

Change-Id: Ic5b6a80def2dbe0fc254bfec1bd8fc0b73c70d46
diff --git a/av1/common/arm/convolve_neon.c b/av1/common/arm/convolve_neon.c
index f12bf07..5b11112 100644
--- a/av1/common/arm/convolve_neon.c
+++ b/av1/common/arm/convolve_neon.c
@@ -136,37 +136,6 @@
 
 #elif defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
 
-static INLINE int16x8_t convolve8_horiz_8_sdot(uint8x16_t samples,
-                                               const int8x8_t filters,
-                                               const int32x4_t correction,
-                                               const uint8x16_t range_limit,
-                                               const uint8x16x3_t permute_tbl) {
-  int8x16_t clamped_samples, permuted_samples[3];
-  int32x4_t sum[2];
-
-  /* Clamp sample range to [-128, 127] for 8-bit signed dot product. */
-  clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));
-
-  /* Permute samples ready for dot product. */
-  /* { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 } */
-  permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
-  /* { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 } */
-  permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
-  /* { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 } */
-  permuted_samples[2] = vqtbl1q_s8(clamped_samples, permute_tbl.val[2]);
-
-  /* Accumulate dot product into 'correction' to account for range clamp. */
-  /* First 4 output values. */
-  sum[0] = vdotq_lane_s32(correction, permuted_samples[0], filters, 0);
-  sum[0] = vdotq_lane_s32(sum[0], permuted_samples[1], filters, 1);
-  /* Second 4 output values. */
-  sum[1] = vdotq_lane_s32(correction, permuted_samples[1], filters, 0);
-  sum[1] = vdotq_lane_s32(sum[1], permuted_samples[2], filters, 1);
-
-  /* Narrow and re-pack. */
-  return vcombine_s16(vmovn_s32(sum[0]), vmovn_s32(sum[1]));
-}
-
 static INLINE int16x4_t convolve12_horiz_4_sdot(
     uint8x16_t samples, const int8x16_t filters, const int32x4_t correction,
     const uint8x16_t range_limit, const uint8x16x3_t permute_tbl) {
@@ -2786,34 +2755,6 @@
 
 #if defined(__aarch64__) && defined(__ARM_FEATURE_MATMUL_INT8)
 
-static INLINE int16x8_t convolve8_horiz_8_usdot(uint8x16_t samples,
-                                                const int8x8_t filters,
-                                                const uint8x16x3_t permute_tbl,
-                                                const int32x4_t horiz_const) {
-  uint8x16_t permuted_samples[3];
-  int32x4_t sum[2];
-
-  /* Permute samples ready for dot product. */
-  /* { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 } */
-  permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]);
-  /* { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 } */
-  permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]);
-  /* { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 } */
-  permuted_samples[2] = vqtbl1q_u8(samples, permute_tbl.val[2]);
-
-  /* First 4 output values. */
-  sum[0] = vusdotq_lane_s32(horiz_const, permuted_samples[0], filters, 0);
-  sum[0] = vusdotq_lane_s32(sum[0], permuted_samples[1], filters, 1);
-  /* Second 4 output values. */
-  sum[1] = vusdotq_lane_s32(horiz_const, permuted_samples[1], filters, 0);
-  sum[1] = vusdotq_lane_s32(sum[1], permuted_samples[2], filters, 1);
-
-  /* Narrow and re-pack. */
-  // We halved the convolution filter values so -1 from the right shift.
-  return vcombine_s16(vshrn_n_s32(sum[0], ROUND0_BITS - 1),
-                      vshrn_n_s32(sum[1], ROUND0_BITS - 1));
-}
-
 static INLINE void convolve_2d_sr_horiz_8tap_neon(
     const uint8_t *src, int src_stride, int16_t *im_block, int im_stride, int w,
     int im_h, const int16x8_t x_filter_s16) {
@@ -3104,52 +3045,6 @@
 
 #else  // !(defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD))
 
-static INLINE int16x4_t convolve8_horiz_4x4_s16(
-    const int16x4_t s0, const int16x4_t s1, const int16x4_t s2,
-    const int16x4_t s3, const int16x4_t s4, const int16x4_t s5,
-    const int16x4_t s6, const int16x4_t s7, const int16x8_t filter,
-    const int16x4_t horiz_const) {
-  const int16x4_t filter_lo = vget_low_s16(filter);
-  const int16x4_t filter_hi = vget_high_s16(filter);
-  int16x4_t sum;
-
-  sum = horiz_const;
-  sum = vmla_lane_s16(sum, s0, filter_lo, 0);
-  sum = vmla_lane_s16(sum, s1, filter_lo, 1);
-  sum = vmla_lane_s16(sum, s2, filter_lo, 2);
-  sum = vmla_lane_s16(sum, s3, filter_lo, 3);
-  sum = vmla_lane_s16(sum, s4, filter_hi, 0);
-  sum = vmla_lane_s16(sum, s5, filter_hi, 1);
-  sum = vmla_lane_s16(sum, s6, filter_hi, 2);
-  sum = vmla_lane_s16(sum, s7, filter_hi, 3);
-
-  // We halved the convolution filter values so -1 from the right shift.
-  return vshr_n_s16(sum, ROUND0_BITS - 1);
-}
-
-static INLINE int16x8_t convolve8_horiz_8x8_s16(
-    const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
-    const int16x8_t s3, const int16x8_t s4, const int16x8_t s5,
-    const int16x8_t s6, const int16x8_t s7, const int16x8_t filter,
-    const int16x8_t horiz_const) {
-  const int16x4_t filter_lo = vget_low_s16(filter);
-  const int16x4_t filter_hi = vget_high_s16(filter);
-  int16x8_t sum;
-
-  sum = horiz_const;
-  sum = vmlaq_lane_s16(sum, s0, filter_lo, 0);
-  sum = vmlaq_lane_s16(sum, s1, filter_lo, 1);
-  sum = vmlaq_lane_s16(sum, s2, filter_lo, 2);
-  sum = vmlaq_lane_s16(sum, s3, filter_lo, 3);
-  sum = vmlaq_lane_s16(sum, s4, filter_hi, 0);
-  sum = vmlaq_lane_s16(sum, s5, filter_hi, 1);
-  sum = vmlaq_lane_s16(sum, s6, filter_hi, 2);
-  sum = vmlaq_lane_s16(sum, s7, filter_hi, 3);
-
-  // We halved the convolution filter values so -1 from the right shift.
-  return vshrq_n_s16(sum, ROUND0_BITS - 1);
-}
-
 // Horizontal filtering for convolve_2d_sr for width multiple of 8
 // Processes one row at a time
 static INLINE void horiz_filter_w8_single_row(const uint8_t *src_ptr,
diff --git a/av1/common/arm/convolve_neon.h b/av1/common/arm/convolve_neon.h
index b8eac71..59c77b0 100644
--- a/av1/common/arm/convolve_neon.h
+++ b/av1/common/arm/convolve_neon.h
@@ -245,6 +245,34 @@
 
 #if defined(__aarch64__) && defined(__ARM_FEATURE_MATMUL_INT8)
 
+static INLINE int16x8_t convolve8_horiz_8_usdot(uint8x16_t samples,
+                                                const int8x8_t filters,
+                                                const uint8x16x3_t permute_tbl,
+                                                const int32x4_t horiz_const) {
+  uint8x16_t permuted_samples[3];
+  int32x4_t sum[2];
+
+  /* Permute samples ready for dot product. */
+  /* { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 } */
+  permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]);
+  /* { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 } */
+  permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]);
+  /* { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 } */
+  permuted_samples[2] = vqtbl1q_u8(samples, permute_tbl.val[2]);
+
+  /* First 4 output values. */
+  sum[0] = vusdotq_lane_s32(horiz_const, permuted_samples[0], filters, 0);
+  sum[0] = vusdotq_lane_s32(sum[0], permuted_samples[1], filters, 1);
+  /* Second 4 output values. */
+  sum[1] = vusdotq_lane_s32(horiz_const, permuted_samples[1], filters, 0);
+  sum[1] = vusdotq_lane_s32(sum[1], permuted_samples[2], filters, 1);
+
+  /* Narrow and re-pack. */
+  // We halved the convolution filter values so -1 from the right shift.
+  return vcombine_s16(vshrn_n_s32(sum[0], ROUND0_BITS - 1),
+                      vshrn_n_s32(sum[1], ROUND0_BITS - 1));
+}
+
 static INLINE int32x4_t convolve8_4_usdot(uint8x16_t samples,
                                           const int8x8_t filters,
                                           const uint8x16x2_t permute_tbl,
@@ -297,6 +325,37 @@
 
 #elif defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
 
+static INLINE int16x8_t convolve8_horiz_8_sdot(uint8x16_t samples,
+                                               const int8x8_t filters,
+                                               const int32x4_t correction,
+                                               const uint8x16_t range_limit,
+                                               const uint8x16x3_t permute_tbl) {
+  int8x16_t clamped_samples, permuted_samples[3];
+  int32x4_t sum[2];
+
+  /* Clamp sample range to [-128, 127] for 8-bit signed dot product. */
+  clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));
+
+  /* Permute samples ready for dot product. */
+  /* { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 } */
+  permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
+  /* { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 } */
+  permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
+  /* { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 } */
+  permuted_samples[2] = vqtbl1q_s8(clamped_samples, permute_tbl.val[2]);
+
+  /* Accumulate dot product into 'correction' to account for range clamp. */
+  /* First 4 output values. */
+  sum[0] = vdotq_lane_s32(correction, permuted_samples[0], filters, 0);
+  sum[0] = vdotq_lane_s32(sum[0], permuted_samples[1], filters, 1);
+  /* Second 4 output values. */
+  sum[1] = vdotq_lane_s32(correction, permuted_samples[1], filters, 0);
+  sum[1] = vdotq_lane_s32(sum[1], permuted_samples[2], filters, 1);
+
+  /* Narrow and re-pack. */
+  return vcombine_s16(vmovn_s32(sum[0]), vmovn_s32(sum[1]));
+}
+
 static INLINE int32x4_t convolve8_4_sdot(uint8x16_t samples,
                                          const int8x8_t filters,
                                          const int32x4_t correction,
@@ -530,4 +589,54 @@
   return vcombine_u16(vqmovun_s32(sum0), vqmovun_s32(sum1));
 }
 
+#if !(defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD))
+
+static INLINE int16x4_t convolve8_horiz_4x4_s16(
+    const int16x4_t s0, const int16x4_t s1, const int16x4_t s2,
+    const int16x4_t s3, const int16x4_t s4, const int16x4_t s5,
+    const int16x4_t s6, const int16x4_t s7, const int16x8_t filter,
+    const int16x4_t horiz_const) {
+  const int16x4_t filter_lo = vget_low_s16(filter);
+  const int16x4_t filter_hi = vget_high_s16(filter);
+  int16x4_t sum;
+
+  sum = horiz_const;
+  sum = vmla_lane_s16(sum, s0, filter_lo, 0);
+  sum = vmla_lane_s16(sum, s1, filter_lo, 1);
+  sum = vmla_lane_s16(sum, s2, filter_lo, 2);
+  sum = vmla_lane_s16(sum, s3, filter_lo, 3);
+  sum = vmla_lane_s16(sum, s4, filter_hi, 0);
+  sum = vmla_lane_s16(sum, s5, filter_hi, 1);
+  sum = vmla_lane_s16(sum, s6, filter_hi, 2);
+  sum = vmla_lane_s16(sum, s7, filter_hi, 3);
+
+  // We halved the convolution filter values so -1 from the right shift.
+  return vshr_n_s16(sum, ROUND0_BITS - 1);
+}
+
+static INLINE int16x8_t convolve8_horiz_8x8_s16(
+    const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
+    const int16x8_t s3, const int16x8_t s4, const int16x8_t s5,
+    const int16x8_t s6, const int16x8_t s7, const int16x8_t filter,
+    const int16x8_t horiz_const) {
+  const int16x4_t filter_lo = vget_low_s16(filter);
+  const int16x4_t filter_hi = vget_high_s16(filter);
+  int16x8_t sum;
+
+  sum = horiz_const;
+  sum = vmlaq_lane_s16(sum, s0, filter_lo, 0);
+  sum = vmlaq_lane_s16(sum, s1, filter_lo, 1);
+  sum = vmlaq_lane_s16(sum, s2, filter_lo, 2);
+  sum = vmlaq_lane_s16(sum, s3, filter_lo, 3);
+  sum = vmlaq_lane_s16(sum, s4, filter_hi, 0);
+  sum = vmlaq_lane_s16(sum, s5, filter_hi, 1);
+  sum = vmlaq_lane_s16(sum, s6, filter_hi, 2);
+  sum = vmlaq_lane_s16(sum, s7, filter_hi, 3);
+
+  // We halved the convolution filter values so -1 from the right shift.
+  return vshrq_n_s16(sum, ROUND0_BITS - 1);
+}
+
+#endif  // !(defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD))
+
 #endif  // AOM_AV1_COMMON_ARM_CONVOLVE_NEON_H_
diff --git a/av1/common/arm/jnt_convolve_neon.c b/av1/common/arm/jnt_convolve_neon.c
index ae437da..bd868ea 100644
--- a/av1/common/arm/jnt_convolve_neon.c
+++ b/av1/common/arm/jnt_convolve_neon.c
@@ -285,7 +285,7 @@
 
 static INLINE void dist_wtd_convolve_2d_horiz_neon(
     const uint8_t *src, int src_stride, int16_t *im_block, const int im_stride,
-    const int16x8_t x_filter_s16, const int im_h, int w, const int round_0) {
+    const int16x8_t x_filter_s16, const int im_h, int w) {
   const int bd = 8;
   int16_t *dst_ptr = im_block;
   int dst_stride = im_stride;
@@ -293,11 +293,14 @@
   int height = im_h;
 
   const int8x8_t x_filter = vmovn_s16(x_filter_s16);
-  const int32x4_t horiz_const = vdupq_n_s32(1 << (bd + FILTER_BITS - 2));
+  // This shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
+  // shifts - which are generally faster than rounding shifts on modern CPUs.
+  // The outermost -1 is needed because we halved the filter values.
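+  // For 8-bit input this evaluates to (1 << 13) + (1 << 1), since
+  // FILTER_BITS == 7 and ROUND0_BITS == 3.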
+  const int32x4_t horiz_const = vdupq_n_s32((1 << (bd + FILTER_BITS - 2)) +
+                                            (1 << ((ROUND0_BITS - 1) - 1)));
 
   if (w == 4) {
     const uint8x16x2_t permute_tbl = vld1q_u8_x2(dot_prod_permute_tbl);
-    const int16x4_t shift_round_0 = vdup_n_s16(-(round_0));
     uint8x16_t s0, s1, s2, s3;
     int32x4_t t0, t1, t2, t3;
     int16x4_t d0, d1, d2, d3;
@@ -310,10 +313,11 @@
       t2 = convolve8_4_usdot(s2, x_filter, permute_tbl, horiz_const);
       t3 = convolve8_4_usdot(s3, x_filter, permute_tbl, horiz_const);
 
-      d0 = vqrshl_s16(vmovn_s32(t0), shift_round_0);
-      d1 = vqrshl_s16(vmovn_s32(t1), shift_round_0);
-      d2 = vqrshl_s16(vmovn_s32(t2), shift_round_0);
-      d3 = vqrshl_s16(vmovn_s32(t3), shift_round_0);
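+      // With the rounding bias folded into horiz_const, these truncating
+      // shifts are equivalent to the rounding shifts they replace.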
+      // We halved the convolution filter values so -1 from the right shift.
+      d0 = vshrn_n_s32(t0, ROUND0_BITS - 1);
+      d1 = vshrn_n_s32(t1, ROUND0_BITS - 1);
+      d2 = vshrn_n_s32(t2, ROUND0_BITS - 1);
+      d3 = vshrn_n_s32(t3, ROUND0_BITS - 1);
 
       store_s16_4x4(dst_ptr, dst_stride, d0, d1, d2, d3);
 
@@ -323,7 +327,6 @@
     } while (height > 0);
   } else {
     const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
-    const int16x8_t shift_round_0 = vdupq_n_s16(-(round_0));
     const uint8_t *s;
     int16_t *d;
     uint8x16_t s0, s1, s2, s3;
@@ -337,14 +340,10 @@
       do {
         load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
 
-        d0 = convolve8_8_usdot(s0, x_filter, permute_tbl, horiz_const,
-                               shift_round_0);
-        d1 = convolve8_8_usdot(s1, x_filter, permute_tbl, horiz_const,
-                               shift_round_0);
-        d2 = convolve8_8_usdot(s2, x_filter, permute_tbl, horiz_const,
-                               shift_round_0);
-        d3 = convolve8_8_usdot(s3, x_filter, permute_tbl, horiz_const,
-                               shift_round_0);
+        d0 = convolve8_horiz_8_usdot(s0, x_filter, permute_tbl, horiz_const);
+        d1 = convolve8_horiz_8_usdot(s1, x_filter, permute_tbl, horiz_const);
+        d2 = convolve8_horiz_8_usdot(s2, x_filter, permute_tbl, horiz_const);
+        d3 = convolve8_horiz_8_usdot(s3, x_filter, permute_tbl, horiz_const);
 
         store_s16_8x4(d, dst_stride, d0, d1, d2, d3);
 
@@ -364,7 +363,7 @@
 
 static INLINE void dist_wtd_convolve_2d_horiz_neon(
     const uint8_t *src, int src_stride, int16_t *im_block, const int im_stride,
-    const int16x8_t x_filter_s16, const int im_h, int w, const int round_0) {
+    const int16x8_t x_filter_s16, const int im_h, int w) {
   const int bd = 8;
   int16_t *dst_ptr = im_block;
   int dst_stride = im_stride;
@@ -375,13 +374,15 @@
   const int32_t horiz_const = (1 << (bd + FILTER_BITS - 2));
   // Dot product constants.
   const int16x8_t correct_tmp = vshlq_n_s16(x_filter_s16, 7);
-  const int32x4_t correction =
-      vdupq_n_s32(vaddlvq_s16(correct_tmp) + horiz_const);
+  // This shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
+  // shifts - which are generally faster than rounding shifts on modern CPUs.
+  // The outermost -1 is needed because we halved the filter values.
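+  // 'correction' also folds in 128 * sum(x_filter), compensating for the
+  // range_limit value (128) subtracted from every sample before the signed
+  // dot product.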
+  const int32x4_t correction = vdupq_n_s32(
+      vaddlvq_s16(correct_tmp) + horiz_const + (1 << ((ROUND0_BITS - 1) - 1)));
   const uint8x16_t range_limit = vdupq_n_u8(128);
 
   if (w == 4) {
     const uint8x16x2_t permute_tbl = vld1q_u8_x2(dot_prod_permute_tbl);
-    const int16x4_t shift_round_0 = vdup_n_s16(-(round_0));
     uint8x16_t s0, s1, s2, s3;
     int32x4_t t0, t1, t2, t3;
     int16x4_t d0, d1, d2, d3;
@@ -394,10 +395,11 @@
       t2 = convolve8_4_sdot(s2, x_filter, correction, range_limit, permute_tbl);
       t3 = convolve8_4_sdot(s3, x_filter, correction, range_limit, permute_tbl);
 
-      d0 = vqrshl_s16(vmovn_s32(t0), shift_round_0);
-      d1 = vqrshl_s16(vmovn_s32(t1), shift_round_0);
-      d2 = vqrshl_s16(vmovn_s32(t2), shift_round_0);
-      d3 = vqrshl_s16(vmovn_s32(t3), shift_round_0);
+      // We halved the convolution filter values so -1 from the right shift.
+      d0 = vshrn_n_s32(t0, ROUND0_BITS - 1);
+      d1 = vshrn_n_s32(t1, ROUND0_BITS - 1);
+      d2 = vshrn_n_s32(t2, ROUND0_BITS - 1);
+      d3 = vshrn_n_s32(t3, ROUND0_BITS - 1);
 
       store_s16_4x4(dst_ptr, dst_stride, d0, d1, d2, d3);
 
@@ -407,7 +409,6 @@
     } while (height > 0);
   } else {
     const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
-    const int16x8_t shift_round_0 = vdupq_n_s16(-(round_0));
     const uint8_t *s;
     int16_t *d;
     uint8x16_t s0, s1, s2, s3;
@@ -421,14 +422,20 @@
       do {
         load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
 
-        d0 = convolve8_8_sdot(s0, x_filter, correction, range_limit,
-                              permute_tbl, shift_round_0);
-        d1 = convolve8_8_sdot(s1, x_filter, correction, range_limit,
-                              permute_tbl, shift_round_0);
-        d2 = convolve8_8_sdot(s2, x_filter, correction, range_limit,
-                              permute_tbl, shift_round_0);
-        d3 = convolve8_8_sdot(s3, x_filter, correction, range_limit,
-                              permute_tbl, shift_round_0);
+        d0 = convolve8_horiz_8_sdot(s0, x_filter, correction, range_limit,
+                                    permute_tbl);
+        d1 = convolve8_horiz_8_sdot(s1, x_filter, correction, range_limit,
+                                    permute_tbl);
+        d2 = convolve8_horiz_8_sdot(s2, x_filter, correction, range_limit,
+                                    permute_tbl);
+        d3 = convolve8_horiz_8_sdot(s3, x_filter, correction, range_limit,
+                                    permute_tbl);
+
+        // We halved the convolution filter values so -1 from the right shift.
+        d0 = vshrq_n_s16(d0, ROUND0_BITS - 1);
+        d1 = vshrq_n_s16(d1, ROUND0_BITS - 1);
+        d2 = vshrq_n_s16(d2, ROUND0_BITS - 1);
+        d3 = vshrq_n_s16(d3, ROUND0_BITS - 1);
 
         store_s16_8x4(d, dst_stride, d0, d1, d2, d3);
 
@@ -448,7 +455,7 @@
 
 static INLINE void dist_wtd_convolve_2d_horiz_neon(
     const uint8_t *src, int src_stride, int16_t *im_block, const int im_stride,
-    const int16x8_t x_filter, const int im_h, int w, const int round_0) {
+    const int16x8_t x_filter, const int im_h, int w) {
   const int bd = 8;
   const uint8_t *s;
   int16_t *dst_ptr;
@@ -465,8 +472,11 @@
     int16x8_t tt0;
     uint8x8_t t0;
 
-    const int16x4_t horiz_const = vdup_n_s16((1 << (bd + FILTER_BITS - 2)));
-    const int16x4_t shift_round_0 = vdup_n_s16(-(round_0));
+    // This shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
+    // shifts - which are generally faster than rounding shifts on modern CPUs.
+    // The outermost -1 is needed because we halved the filter values.
+    const int16x4_t horiz_const = vdup_n_s16((1 << (bd + FILTER_BITS - 2)) +
+                                             (1 << ((ROUND0_BITS - 1) - 1)));
 
 #if defined(__aarch64__)
     int16x4_t s8, s9, s10, d1, d2, d3;
@@ -511,14 +521,14 @@
       s9 = vget_low_s16(tt2);
       s10 = vget_low_s16(tt3);
 
-      d0 = convolve8_4x4_s16(s0, s1, s2, s3, s4, s5, s6, s7, x_filter,
-                             horiz_const, shift_round_0);
-      d1 = convolve8_4x4_s16(s1, s2, s3, s4, s5, s6, s7, s8, x_filter,
-                             horiz_const, shift_round_0);
-      d2 = convolve8_4x4_s16(s2, s3, s4, s5, s6, s7, s8, s9, x_filter,
-                             horiz_const, shift_round_0);
-      d3 = convolve8_4x4_s16(s3, s4, s5, s6, s7, s8, s9, s10, x_filter,
-                             horiz_const, shift_round_0);
+      d0 = convolve8_horiz_4x4_s16(s0, s1, s2, s3, s4, s5, s6, s7, x_filter,
+                                   horiz_const);
+      d1 = convolve8_horiz_4x4_s16(s1, s2, s3, s4, s5, s6, s7, s8, x_filter,
+                                   horiz_const);
+      d2 = convolve8_horiz_4x4_s16(s2, s3, s4, s5, s6, s7, s8, s9, x_filter,
+                                   horiz_const);
+      d3 = convolve8_horiz_4x4_s16(s3, s4, s5, s6, s7, s8, s9, s10, x_filter,
+                                   horiz_const);
 
       transpose_s16_4x4d(&d0, &d1, &d2, &d3);
 
@@ -546,8 +556,8 @@
       s6 = vext_s16(s4, s7, 2);  // a6 a7 a8 a9
       s7 = vext_s16(s4, s7, 3);  // a7 a8 a9 a10
 
-      d0 = convolve8_4x4_s16(s0, s1, s2, s3, s4, s5, s6, s7, x_filter,
-                             horiz_const, shift_round_0);
+      d0 = convolve8_horiz_4x4_s16(s0, s1, s2, s3, s4, s5, s6, s7, x_filter,
+                                   horiz_const);
 
       vst1_s16(dst_ptr, d0);
 
@@ -562,8 +572,11 @@
     int16x8_t res0;
     uint8x8_t t0;
 
-    const int16x8_t horiz_const = vdupq_n_s16((1 << (bd + FILTER_BITS - 2)));
-    const int16x8_t shift_round_0 = vdupq_n_s16(-(round_0));
+    // This shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
+    // shifts - which are generally faster than rounding shifts on modern CPUs.
+    // The outermost -1 is needed because we halved the filter values.
+    const int16x8_t horiz_const = vdupq_n_s16((1 << (bd + FILTER_BITS - 2)) +
+                                              (1 << ((ROUND0_BITS - 1) - 1)));
     do {
 #if defined(__aarch64__)
       uint8x8_t t1, t2, t3, t4, t5, t6, t7;
@@ -611,22 +624,22 @@
         s13 = vreinterpretq_s16_u16(vmovl_u8(t6));
         s14 = vreinterpretq_s16_u16(vmovl_u8(t7));
 
-        res0 = convolve8_8x8_s16(s0, s1, s2, s3, s4, s5, s6, s7, x_filter,
-                                 horiz_const, shift_round_0);
-        res1 = convolve8_8x8_s16(s1, s2, s3, s4, s5, s6, s7, s8, x_filter,
-                                 horiz_const, shift_round_0);
-        res2 = convolve8_8x8_s16(s2, s3, s4, s5, s6, s7, s8, s9, x_filter,
-                                 horiz_const, shift_round_0);
-        res3 = convolve8_8x8_s16(s3, s4, s5, s6, s7, s8, s9, s10, x_filter,
-                                 horiz_const, shift_round_0);
-        res4 = convolve8_8x8_s16(s4, s5, s6, s7, s8, s9, s10, s11, x_filter,
-                                 horiz_const, shift_round_0);
-        res5 = convolve8_8x8_s16(s5, s6, s7, s8, s9, s10, s11, s12, x_filter,
-                                 horiz_const, shift_round_0);
-        res6 = convolve8_8x8_s16(s6, s7, s8, s9, s10, s11, s12, s13, x_filter,
-                                 horiz_const, shift_round_0);
-        res7 = convolve8_8x8_s16(s7, s8, s9, s10, s11, s12, s13, s14, x_filter,
-                                 horiz_const, shift_round_0);
+        res0 = convolve8_horiz_8x8_s16(s0, s1, s2, s3, s4, s5, s6, s7, x_filter,
+                                       horiz_const);
+        res1 = convolve8_horiz_8x8_s16(s1, s2, s3, s4, s5, s6, s7, s8, x_filter,
+                                       horiz_const);
+        res2 = convolve8_horiz_8x8_s16(s2, s3, s4, s5, s6, s7, s8, s9, x_filter,
+                                       horiz_const);
+        res3 = convolve8_horiz_8x8_s16(s3, s4, s5, s6, s7, s8, s9, s10,
+                                       x_filter, horiz_const);
+        res4 = convolve8_horiz_8x8_s16(s4, s5, s6, s7, s8, s9, s10, s11,
+                                       x_filter, horiz_const);
+        res5 = convolve8_horiz_8x8_s16(s5, s6, s7, s8, s9, s10, s11, s12,
+                                       x_filter, horiz_const);
+        res6 = convolve8_horiz_8x8_s16(s6, s7, s8, s9, s10, s11, s12, s13,
+                                       x_filter, horiz_const);
+        res7 = convolve8_horiz_8x8_s16(s7, s8, s9, s10, s11, s12, s13, s14,
+                                       x_filter, horiz_const);
 
         transpose_s16_8x8(&res0, &res1, &res2, &res3, &res4, &res5, &res6,
                           &res7);
@@ -671,8 +684,8 @@
         s6 = vextq_s16(temp_0, s7, 6);  // a6 a7 a8 a9 a10 a11 a12 a13
         s7 = vextq_s16(temp_0, s7, 7);  // a7 a8 a9 a10 a11 a12 a13 a14
 
-        res0 = convolve8_8x8_s16(temp_0, s1, s2, s3, s4, s5, s6, s7, x_filter,
-                                 horiz_const, shift_round_0);
+        res0 = convolve8_horiz_8x8_s16(temp_0, s1, s2, s3, s4, s5, s6, s7,
+                                       x_filter, horiz_const);
         vst1q_s16(d_tmp, res0);
 
         s += 8;
@@ -1105,7 +1118,6 @@
   const int im_stride = MAX_SB_SIZE;
   const int vert_offset = filter_params_y->taps / 2 - 1;
   const int horiz_offset = filter_params_x->taps / 2 - 1;
-  const int round_0 = conv_params->round_0 - 1;
   const uint8_t *src_ptr = src - vert_offset * src_stride - horiz_offset;
   const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
       filter_params_x, subpel_x_qn & SUBPEL_MASK);
@@ -1118,7 +1130,7 @@
   const int16x8_t y_filter = vld1q_s16(y_filter_ptr);
 
   dist_wtd_convolve_2d_horiz_neon(src_ptr, src_stride, im_block, im_stride,
-                                  x_filter, im_h, w, round_0);
+                                  x_filter, im_h, w);
 
   if (clamped_y_taps == 6) {
     dist_wtd_convolve_2d_vert_6tap_neon(im_block + im_stride, im_stride, dst8,