Propagate constants in Neon av1_dist_wtd_convolve_x functions

Rounding modes and other convolution parameters are known ahead of
time. This patch propagates these values into the Neon code paths for
av1_dist_wtd_convolve_x, enabling some useful simplifications and
optimizations - most notably replacing rounding shifts with cheaper
non-rounding shifts.

Co-authored-by: Jonathan Wright <jonathan.wright@arm.com>

Change-Id: I8e9ac921cc6685f17d45bd6e9742eca64cec2122
diff --git a/av1/common/arm/convolve_neon.h b/av1/common/arm/convolve_neon.h
index cc679fb..557a08a 100644
--- a/av1/common/arm/convolve_neon.h
+++ b/av1/common/arm/convolve_neon.h
@@ -319,35 +319,6 @@
   return sum;
 }
 
-static INLINE int16x8_t convolve8_8_usdot(uint8x16_t samples,
-                                          const int8x8_t filters,
-                                          const uint8x16x3_t permute_tbl,
-                                          const int32x4_t horiz_const,
-                                          const int16x8_t shift_round_0) {
-  uint8x16_t permuted_samples[3];
-  int32x4_t sum0, sum1;
-  int16x8_t sum;
-
-  /* Permute samples ready for dot product. */
-  /* { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 } */
-  permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]);
-  /* { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 } */
-  permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]);
-  /* { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 } */
-  permuted_samples[2] = vqtbl1q_u8(samples, permute_tbl.val[2]);
-
-  /* First 4 output values. */
-  sum0 = vusdotq_lane_s32(horiz_const, permuted_samples[0], filters, 0);
-  sum0 = vusdotq_lane_s32(sum0, permuted_samples[1], filters, 1);
-  /* Second 4 output values. */
-  sum1 = vusdotq_lane_s32(horiz_const, permuted_samples[1], filters, 0);
-  sum1 = vusdotq_lane_s32(sum1, permuted_samples[2], filters, 1);
-
-  /* Narrow and re-pack. */
-  sum = vcombine_s16(vmovn_s32(sum0), vmovn_s32(sum1));
-  return vqrshlq_s16(sum, shift_round_0);
-}
-
 #elif defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
 
 static INLINE int16x8_t convolve8_horiz_8_sdot(uint8x16_t samples,
diff --git a/av1/common/arm/jnt_convolve_neon.c b/av1/common/arm/jnt_convolve_neon.c
index 8e23f19..7cf8095 100644
--- a/av1/common/arm/jnt_convolve_neon.c
+++ b/av1/common/arm/jnt_convolve_neon.c
@@ -1243,20 +1243,16 @@
   assert(!(h % 4));
 
   const int horiz_offset = filter_params_x->taps / 2 - 1;
-  const int bits = FILTER_BITS - conv_params->round_1;
   const int bd = 8;
-  const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
-  const int round_offset = (1 << (offset_bits - conv_params->round_1)) +
-                           (1 << (offset_bits - conv_params->round_1 - 1));
-  const int round_bits =
-      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
+  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
+  const int round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
+                           (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
+  const int round_bits = FILTER_BITS - ROUND0_BITS;
   const uint16_t fwd_offset = conv_params->fwd_offset;
   const uint16_t bck_offset = conv_params->bck_offset;
   const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
   const int16x4_t round_offset64 = vdup_n_s16(round_offset);
   const int16x8_t round_offset128 = vdupq_n_s16(round_offset);
-  const int16x8_t shift_round_0 = vdupq_n_s16(-conv_params->round_0 + 1);
-  const int16x8_t horiz_const = vdupq_n_s16(bits);
 
   // Horizontal filter.
   const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
@@ -1264,6 +1260,10 @@
   // Filter values are even, so downshift by 1 to reduce intermediate precision
   // requirements.
   const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
+  // This shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
+  // shifts - which are generally faster than rounding shifts on modern CPUs.
+  // The outermost -1 is needed because we halved the filter values.
+  const int32x4_t horiz_const = vdupq_n_s32(1 << ((ROUND0_BITS - 1) - 1));
 
   const uint8_t *src_ptr = src - horiz_offset;
   CONV_BUF_TYPE *dst = conv_params->dst;
@@ -1285,19 +1285,17 @@
 
       load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);
 
-      d0 = convolve8_4_usdot(s0, x_filter, permute_tbl, vdupq_n_s32(0));
-      d1 = convolve8_4_usdot(s1, x_filter, permute_tbl, vdupq_n_s32(0));
-      d2 = convolve8_4_usdot(s2, x_filter, permute_tbl, vdupq_n_s32(0));
-      d3 = convolve8_4_usdot(s3, x_filter, permute_tbl, vdupq_n_s32(0));
+      d0 = convolve8_4_usdot(s0, x_filter, permute_tbl, horiz_const);
+      d1 = convolve8_4_usdot(s1, x_filter, permute_tbl, horiz_const);
+      d2 = convolve8_4_usdot(s2, x_filter, permute_tbl, horiz_const);
+      d3 = convolve8_4_usdot(s3, x_filter, permute_tbl, horiz_const);
 
       d01 = vcombine_s16(vmovn_s32(d0), vmovn_s32(d1));
       d23 = vcombine_s16(vmovn_s32(d2), vmovn_s32(d3));
 
-      d01 = vqrshlq_s16(d01, shift_round_0);
-      d23 = vqrshlq_s16(d23, shift_round_0);
-
-      d01 = vrshlq_s16(d01, horiz_const);
-      d23 = vrshlq_s16(d23, horiz_const);
+      // We halved the convolution filter values so -1 from the right shift.
+      d01 = vshrq_n_s16(d01, ROUND0_BITS - 1);
+      d23 = vshrq_n_s16(d23, ROUND0_BITS - 1);
 
       d01 = vaddq_s16(d01, round_offset128);
       d23 = vaddq_s16(d23, round_offset128);
@@ -1347,19 +1345,10 @@
 
         load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
 
-        d0 = convolve8_8_usdot(s0, x_filter, permute_tbl, vdupq_n_s32(0),
-                               shift_round_0);
-        d1 = convolve8_8_usdot(s1, x_filter, permute_tbl, vdupq_n_s32(0),
-                               shift_round_0);
-        d2 = convolve8_8_usdot(s2, x_filter, permute_tbl, vdupq_n_s32(0),
-                               shift_round_0);
-        d3 = convolve8_8_usdot(s3, x_filter, permute_tbl, vdupq_n_s32(0),
-                               shift_round_0);
-
-        d0 = vrshlq_s16(d0, horiz_const);
-        d1 = vrshlq_s16(d1, horiz_const);
-        d2 = vrshlq_s16(d2, horiz_const);
-        d3 = vrshlq_s16(d3, horiz_const);
+        d0 = convolve8_horiz_8_usdot(s0, x_filter, permute_tbl, horiz_const);
+        d1 = convolve8_horiz_8_usdot(s1, x_filter, permute_tbl, horiz_const);
+        d2 = convolve8_horiz_8_usdot(s2, x_filter, permute_tbl, horiz_const);
+        d3 = convolve8_horiz_8_usdot(s3, x_filter, permute_tbl, horiz_const);
 
         d0 = vaddq_s16(d0, round_offset128);
         d1 = vaddq_s16(d1, round_offset128);
@@ -1407,20 +1396,16 @@
   assert(!(h % 4));
 
   const int horiz_offset = filter_params_x->taps / 2 - 1;
-  const int bits = FILTER_BITS - conv_params->round_1;
   const int bd = 8;
-  const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
-  const int round_offset = (1 << (offset_bits - conv_params->round_1)) +
-                           (1 << (offset_bits - conv_params->round_1 - 1));
-  const int round_bits =
-      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
+  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
+  const int round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
+                           (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
+  const int round_bits = FILTER_BITS - ROUND0_BITS;
   const uint16_t fwd_offset = conv_params->fwd_offset;
   const uint16_t bck_offset = conv_params->bck_offset;
   const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
   const int16x4_t round_offset64 = vdup_n_s16(round_offset);
   const int16x8_t round_offset128 = vdupq_n_s16(round_offset);
-  const int16x8_t shift_round_0 = vdupq_n_s16(-conv_params->round_0 + 1);
-  const int16x8_t horiz_const = vdupq_n_s16(bits);
 
   // Horizontal filter.
   const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
@@ -1431,7 +1416,11 @@
   // Dot-product constants.
   const uint8x16_t range_limit = vdupq_n_u8(128);
   const int32_t correction_s32 = vaddlvq_s16(vshll_n_s8(x_filter, 7));
-  const int32x4_t correction = vdupq_n_s32(correction_s32);
+  // This shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
+  // shifts - which are generally faster than rounding shifts on modern CPUs.
+  // The outermost -1 is needed because we halved the filter values.
+  int32x4_t correction =
+      vdupq_n_s32(correction_s32 + (1 << ((ROUND0_BITS - 1) - 1)));
 
   const uint8_t *src_ptr = src - horiz_offset;
   CONV_BUF_TYPE *dst = conv_params->dst;
@@ -1461,11 +1450,9 @@
       d01 = vcombine_s16(vmovn_s32(d0), vmovn_s32(d1));
       d23 = vcombine_s16(vmovn_s32(d2), vmovn_s32(d3));
 
-      d01 = vqrshlq_s16(d01, shift_round_0);
-      d23 = vqrshlq_s16(d23, shift_round_0);
-
-      d01 = vrshlq_s16(d01, horiz_const);
-      d23 = vrshlq_s16(d23, horiz_const);
+      // We halved the convolution filter values so -1 from the right shift.
+      d01 = vshrq_n_s16(d01, ROUND0_BITS - 1);
+      d23 = vshrq_n_s16(d23, ROUND0_BITS - 1);
 
       d01 = vaddq_s16(d01, round_offset128);
       d23 = vaddq_s16(d23, round_offset128);
@@ -1515,19 +1502,20 @@
 
         load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
 
-        d0 = convolve8_8_sdot(s0, x_filter, correction, range_limit,
-                              permute_tbl, shift_round_0);
-        d1 = convolve8_8_sdot(s1, x_filter, correction, range_limit,
-                              permute_tbl, shift_round_0);
-        d2 = convolve8_8_sdot(s2, x_filter, correction, range_limit,
-                              permute_tbl, shift_round_0);
-        d3 = convolve8_8_sdot(s3, x_filter, correction, range_limit,
-                              permute_tbl, shift_round_0);
+        d0 = convolve8_horiz_8_sdot(s0, x_filter, correction, range_limit,
+                                    permute_tbl);
+        d1 = convolve8_horiz_8_sdot(s1, x_filter, correction, range_limit,
+                                    permute_tbl);
+        d2 = convolve8_horiz_8_sdot(s2, x_filter, correction, range_limit,
+                                    permute_tbl);
+        d3 = convolve8_horiz_8_sdot(s3, x_filter, correction, range_limit,
+                                    permute_tbl);
 
-        d0 = vrshlq_s16(d0, horiz_const);
-        d1 = vrshlq_s16(d1, horiz_const);
-        d2 = vrshlq_s16(d2, horiz_const);
-        d3 = vrshlq_s16(d3, horiz_const);
+        // We halved the convolution filter values so -1 from the right shift.
+        d0 = vshrq_n_s16(d0, ROUND0_BITS - 1);
+        d1 = vshrq_n_s16(d1, ROUND0_BITS - 1);
+        d2 = vshrq_n_s16(d2, ROUND0_BITS - 1);
+        d3 = vshrq_n_s16(d3, ROUND0_BITS - 1);
 
         d0 = vaddq_s16(d0, round_offset128);
         d1 = vaddq_s16(d1, round_offset128);
@@ -1566,6 +1554,48 @@
 
 #else  // !(defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD))
 
+static INLINE int16x4_t
+convolve8_x_4x4_s16(const int16x4_t s0, const int16x4_t s1, const int16x4_t s2,
+                    const int16x4_t s3, const int16x4_t s4, const int16x4_t s5,
+                    const int16x4_t s6, const int16x4_t s7,
+                    const int16x8_t filter, const int16x4_t horiz_const) {
+  const int16x4_t filter_lo = vget_low_s16(filter);
+  const int16x4_t filter_hi = vget_high_s16(filter);
+  int16x4_t sum = horiz_const;
+
+  sum = vmla_lane_s16(sum, s0, filter_lo, 0);
+  sum = vmla_lane_s16(sum, s1, filter_lo, 1);
+  sum = vmla_lane_s16(sum, s2, filter_lo, 2);
+  sum = vmla_lane_s16(sum, s3, filter_lo, 3);
+  sum = vmla_lane_s16(sum, s4, filter_hi, 0);
+  sum = vmla_lane_s16(sum, s5, filter_hi, 1);
+  sum = vmla_lane_s16(sum, s6, filter_hi, 2);
+  sum = vmla_lane_s16(sum, s7, filter_hi, 3);
+
+  return sum;
+}
+
+static INLINE int16x8_t
+convolve8_x_8x8_s16(const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
+                    const int16x8_t s3, const int16x8_t s4, const int16x8_t s5,
+                    const int16x8_t s6, const int16x8_t s7,
+                    const int16x8_t filter, const int16x8_t horiz_const) {
+  const int16x4_t filter_lo = vget_low_s16(filter);
+  const int16x4_t filter_hi = vget_high_s16(filter);
+  int16x8_t sum = horiz_const;
+
+  sum = vmlaq_lane_s16(sum, s0, filter_lo, 0);
+  sum = vmlaq_lane_s16(sum, s1, filter_lo, 1);
+  sum = vmlaq_lane_s16(sum, s2, filter_lo, 2);
+  sum = vmlaq_lane_s16(sum, s3, filter_lo, 3);
+  sum = vmlaq_lane_s16(sum, s4, filter_hi, 0);
+  sum = vmlaq_lane_s16(sum, s5, filter_hi, 1);
+  sum = vmlaq_lane_s16(sum, s6, filter_hi, 2);
+  sum = vmlaq_lane_s16(sum, s7, filter_hi, 3);
+
+  return sum;
+}
+
 void av1_dist_wtd_convolve_x_neon(const uint8_t *src, int src_stride,
                                   uint8_t *dst8, int dst8_stride, int w, int h,
                                   const InterpFilterParams *filter_params_x,
@@ -1577,13 +1607,11 @@
   CONV_BUF_TYPE *dst = conv_params->dst;
   int dst_stride = conv_params->dst_stride;
   const int horiz_offset = filter_params_x->taps / 2 - 1;
-  const int bits = FILTER_BITS - conv_params->round_1;
   const int bd = 8;
-  const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
-  const int round_offset = (1 << (offset_bits - conv_params->round_1)) +
-                           (1 << (offset_bits - conv_params->round_1 - 1));
-  const int round_bits =
-      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
+  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
+  const int round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
+                           (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
+  const int round_bits = FILTER_BITS - ROUND0_BITS;
   const uint16_t fwd_offset = conv_params->fwd_offset;
   const uint16_t bck_offset = conv_params->bck_offset;
   const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
@@ -1618,17 +1646,20 @@
     int16x8_t tt0;
     uint16x4_t res4;
 #if defined(__aarch64__)
+    const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);
     int16x4_t s8, s9, s10, d1, d2, d3;
-    int16x8_t tt1, tt2, tt3;
+    int16x8_t tt1, tt2, tt3, t01, t23;
     uint16x4_t res5, res6, res7;
     int16x8_t u0, u1;
 #else
+    const int16x4_t round_offset_vec = vdup_n_s16(round_offset);
     int16x4_t temp_0;
 #endif
-    const int16x4_t zero = vdup_n_s16(0);
-    const int16x4_t round_offset_vec = vdup_n_s16(round_offset);
-    const int16x4_t shift_round_0 = vdup_n_s16(-conv_params->round_0 + 1);
-    const int16x4_t horiz_const = vdup_n_s16(bits);
+    // This shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
+    // shifts - which are generally faster than rounding shifts on modern CPUs.
+    // The outermost -1 is needed because we halved the filter values.
+    const int16x4_t horiz_const = vdup_n_s16(1 << ((ROUND0_BITS - 1) - 1));
+
     do {
       s = src_ptr;
       d = dst_ptr;
@@ -1670,22 +1701,29 @@
         s9 = vget_high_s16(u0);
         s10 = vget_high_s16(u1);
 
-        d0 = convolve8_4x4_s16(s0, s1, s2, s3, s4, s5, s6, s7, x_filter, zero,
-                               shift_round_0);
-        d0 = vrshl_s16(d0, horiz_const);
-        d0 = vadd_s16(d0, round_offset_vec);
-        d1 = convolve8_4x4_s16(s1, s2, s3, s4, s5, s6, s7, s8, x_filter, zero,
-                               shift_round_0);
-        d1 = vrshl_s16(d1, horiz_const);
-        d1 = vadd_s16(d1, round_offset_vec);
-        d2 = convolve8_4x4_s16(s2, s3, s4, s5, s6, s7, s8, s9, x_filter, zero,
-                               shift_round_0);
-        d2 = vrshl_s16(d2, horiz_const);
-        d2 = vadd_s16(d2, round_offset_vec);
-        d3 = convolve8_4x4_s16(s3, s4, s5, s6, s7, s8, s9, s10, x_filter, zero,
-                               shift_round_0);
-        d3 = vrshl_s16(d3, horiz_const);
-        d3 = vadd_s16(d3, round_offset_vec);
+        d0 = convolve8_x_4x4_s16(s0, s1, s2, s3, s4, s5, s6, s7, x_filter,
+                                 horiz_const);
+        d1 = convolve8_x_4x4_s16(s1, s2, s3, s4, s5, s6, s7, s8, x_filter,
+                                 horiz_const);
+        d2 = convolve8_x_4x4_s16(s2, s3, s4, s5, s6, s7, s8, s9, x_filter,
+                                 horiz_const);
+        d3 = convolve8_x_4x4_s16(s3, s4, s5, s6, s7, s8, s9, s10, x_filter,
+                                 horiz_const);
+
+        t01 = vcombine_s16(d0, d1);
+        t23 = vcombine_s16(d2, d3);
+
+        // We halved the convolution filter values so -1 from the right shift.
+        t01 = vshrq_n_s16(t01, ROUND0_BITS - 1);
+        t23 = vshrq_n_s16(t23, ROUND0_BITS - 1);
+
+        t01 = vaddq_s16(t01, round_offset_vec);
+        t23 = vaddq_s16(t23, round_offset_vec);
+
+        d0 = vget_low_s16(t01);
+        d1 = vget_high_s16(t01);
+        d2 = vget_low_s16(t23);
+        d3 = vget_high_s16(t23);
 
         transpose_s16_4x4d(&d0, &d1, &d2, &d3);
 
@@ -1705,8 +1743,8 @@
           compute_avg_4x4(res4, res5, res6, res7, vreinterpret_u16_s16(d0),
                           vreinterpret_u16_s16(d1), vreinterpret_u16_s16(d2),
                           vreinterpret_u16_s16(d3), fwd_offset, bck_offset,
-                          round_offset_vec, round_bits, use_dist_wtd_comp_avg,
-                          &t0, &t1);
+                          vget_low_s16(round_offset_vec), round_bits,
+                          use_dist_wtd_comp_avg, &t0, &t1);
 
           store_u8_4x1(d_u8 + 0 * dst8_stride, t0, 0);
           store_u8_4x1(d_u8 + 1 * dst8_stride, t0, 1);
@@ -1755,9 +1793,10 @@
         s6 = vext_s16(s4, s7, 2);  // a6 a7 a8 a9
         s7 = vext_s16(s4, s7, 3);  // a7 a8 a9 a10
 
-        d0 = convolve8_4x4_s16(s0, s1, s2, s3, s4, s5, s6, s7, x_filter, zero,
-                               shift_round_0);
-        d0 = vrshl_s16(d0, horiz_const);
+        d0 = convolve8_x_4x4_s16(s0, s1, s2, s3, s4, s5, s6, s7, x_filter,
+                                 horiz_const);
+        // We halved the convolution filter values so -1 from the right shift.
+        d0 = vshr_n_s16(d0, ROUND0_BITS - 1);
         d0 = vadd_s16(d0, round_offset_vec);
         s0 = s4;
         s4 = temp_0;
@@ -1795,9 +1834,10 @@
     uint16x8_t res8;
     const int16x8_t round_offset128 = vdupq_n_s16(round_offset);
     const int16x4_t round_offset64 = vdup_n_s16(round_offset);
-    const int16x8_t shift_round_0 = vdupq_n_s16(-conv_params->round_0 + 1);
-    const int16x8_t horiz_const = vdupq_n_s16(bits);
-    const int16x8_t zero = vdupq_n_s16(0);
+    // This shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
+    // shifts - which are generally faster than rounding shifts on modern CPUs.
+    // The outermost -1 is needed because we halved the filter values.
+    const int16x8_t horiz_const = vdupq_n_s16(1 << ((ROUND0_BITS - 1) - 1));
 
     d = dst_ptr = dst;
     d_u8 = dst_u8_ptr = dst8;
@@ -1854,39 +1894,40 @@
         s13 = vreinterpretq_s16_u16(vmovl_u8(t6));
         s14 = vreinterpretq_s16_u16(vmovl_u8(t7));
 
-        res0 = convolve8_8x8_s16(s0, s1, s2, s3, s4, s5, s6, s7, x_filter, zero,
-                                 shift_round_0);
+        res0 = convolve8_x_8x8_s16(s0, s1, s2, s3, s4, s5, s6, s7, x_filter,
+                                   horiz_const);
+        res1 = convolve8_x_8x8_s16(s1, s2, s3, s4, s5, s6, s7, s8, x_filter,
+                                   horiz_const);
+        res2 = convolve8_x_8x8_s16(s2, s3, s4, s5, s6, s7, s8, s9, x_filter,
+                                   horiz_const);
+        res3 = convolve8_x_8x8_s16(s3, s4, s5, s6, s7, s8, s9, s10, x_filter,
+                                   horiz_const);
+        res4 = convolve8_x_8x8_s16(s4, s5, s6, s7, s8, s9, s10, s11, x_filter,
+                                   horiz_const);
+        res5 = convolve8_x_8x8_s16(s5, s6, s7, s8, s9, s10, s11, s12, x_filter,
+                                   horiz_const);
+        res6 = convolve8_x_8x8_s16(s6, s7, s8, s9, s10, s11, s12, s13, x_filter,
+                                   horiz_const);
+        res7 = convolve8_x_8x8_s16(s7, s8, s9, s10, s11, s12, s13, s14,
+                                   x_filter, horiz_const);
 
-        res0 = vrshlq_s16(res0, horiz_const);
+        // We halved the convolution filter values so -1 from the right shift.
+        res0 = vshrq_n_s16(res0, ROUND0_BITS - 1);
+        res1 = vshrq_n_s16(res1, ROUND0_BITS - 1);
+        res2 = vshrq_n_s16(res2, ROUND0_BITS - 1);
+        res3 = vshrq_n_s16(res3, ROUND0_BITS - 1);
+        res4 = vshrq_n_s16(res4, ROUND0_BITS - 1);
+        res5 = vshrq_n_s16(res5, ROUND0_BITS - 1);
+        res6 = vshrq_n_s16(res6, ROUND0_BITS - 1);
+        res7 = vshrq_n_s16(res7, ROUND0_BITS - 1);
+
         res0 = vaddq_s16(res0, round_offset128);
-
-        res1 = convolve8_8x8_s16(s1, s2, s3, s4, s5, s6, s7, s8, x_filter, zero,
-                                 shift_round_0);
-        res1 = vrshlq_s16(res1, horiz_const);
         res1 = vaddq_s16(res1, round_offset128);
-        res2 = convolve8_8x8_s16(s2, s3, s4, s5, s6, s7, s8, s9, x_filter, zero,
-                                 shift_round_0);
-        res2 = vrshlq_s16(res2, horiz_const);
         res2 = vaddq_s16(res2, round_offset128);
-        res3 = convolve8_8x8_s16(s3, s4, s5, s6, s7, s8, s9, s10, x_filter,
-                                 zero, shift_round_0);
-        res3 = vrshlq_s16(res3, horiz_const);
         res3 = vaddq_s16(res3, round_offset128);
-        res4 = convolve8_8x8_s16(s4, s5, s6, s7, s8, s9, s10, s11, x_filter,
-                                 zero, shift_round_0);
-        res4 = vrshlq_s16(res4, horiz_const);
         res4 = vaddq_s16(res4, round_offset128);
-        res5 = convolve8_8x8_s16(s5, s6, s7, s8, s9, s10, s11, s12, x_filter,
-                                 zero, shift_round_0);
-        res5 = vrshlq_s16(res5, horiz_const);
         res5 = vaddq_s16(res5, round_offset128);
-        res6 = convolve8_8x8_s16(s6, s7, s8, s9, s10, s11, s12, s13, x_filter,
-                                 zero, shift_round_0);
-        res6 = vrshlq_s16(res6, horiz_const);
         res6 = vaddq_s16(res6, round_offset128);
-        res7 = convolve8_8x8_s16(s7, s8, s9, s10, s11, s12, s13, s14, x_filter,
-                                 zero, shift_round_0);
-        res7 = vrshlq_s16(res7, horiz_const);
         res7 = vaddq_s16(res7, round_offset128);
 
         transpose_s16_8x8(&res0, &res1, &res2, &res3, &res4, &res5, &res6,
@@ -1974,10 +2015,10 @@
         s6 = vextq_s16(temp_0, s7, 6);  // a6 a7 a8 a9 a10 a11 a12 a13
         s7 = vextq_s16(temp_0, s7, 7);  // a7 a8 a9 a10 a11 a12 a13 a14
 
-        res0 = convolve8_8x8_s16(temp_0, s1, s2, s3, s4, s5, s6, s7, x_filter,
-                                 zero, shift_round_0);
-
-        res0 = vrshlq_s16(res0, horiz_const);
+        res0 = convolve8_x_8x8_s16(temp_0, s1, s2, s3, s4, s5, s6, s7, x_filter,
+                                   horiz_const);
+        // We halved the convolution filter values so -1 from the right shift.
+        res0 = vshrq_n_s16(res0, ROUND0_BITS - 1);
         res0 = vaddq_s16(res0, round_offset128);
 
         if (conv_params->do_average) {