Refactor Neon compound convolution functions 3/7

Refactor the Armv8.4 and Armv8.6 dot-product implementations of
av1_dist_wtd_convolve_x_neon:

Move the final right shift into the convolution inline functions and
return unsigned types, removing the need for a lot of bulky casting of
the result vectors.
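
For reference (a sketch in pseudo-notation, not code from the patch), the
value each helper now returns per output pixel is

    (dot(src, x_filter/2) + (round_offset << (ROUND0_BITS - 1)) +
     (1 << ((ROUND0_BITS - 1) - 1))) >> (ROUND0_BITS - 1)
      == ((dot(src, x_filter) + (1 << (ROUND0_BITS - 1))) >> ROUND0_BITS) +
         round_offset

which, given the compound round_offset, is non-negative and fits in 16
bits, so the helpers can return uint16x4_t/uint16x8_t and the callers pass
the results straight to the compute_avg_*x4 and store_u16_*x4 helpers
without any reinterpret casts.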

Change-Id: Ic22c88815c9b31af4f3dcabda9e59530e67c1bbd
diff --git a/av1/common/arm/jnt_convolve_neon.c b/av1/common/arm/jnt_convolve_neon.c
index 177a837..243df36 100644
--- a/av1/common/arm/jnt_convolve_neon.c
+++ b/av1/common/arm/jnt_convolve_neon.c
@@ -1102,23 +1102,79 @@
 
 #if defined(__aarch64__) && defined(__ARM_FEATURE_MATMUL_INT8)
 
+static INLINE uint16x4_t convolve8_4_x(uint8x16_t samples,
+                                       const int8x8_t x_filter,
+                                       const uint8x16x2_t permute_tbl,
+                                       const int32x4_t round_offset) {
+  uint8x16_t permuted_samples[2];
+  int32x4_t sum;
+
+  // Permute samples ready for dot product.
+  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
+  permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]);
+  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
+  permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]);
+
+  // First 4 output values.
+  sum = vusdotq_lane_s32(round_offset, permuted_samples[0], x_filter, 0);
+  sum = vusdotq_lane_s32(sum, permuted_samples[1], x_filter, 1);
+
+  // We halved the convolution filter values so -1 from the right shift.
+  return vreinterpret_u16_s16(vshrn_n_s32(sum, ROUND0_BITS - 1));
+}
+
+static INLINE uint16x8_t convolve8_8_x(uint8x16_t samples,
+                                       const int8x8_t x_filter,
+                                       const uint8x16x3_t permute_tbl,
+                                       const int32x4_t round_offset) {
+  uint8x16_t permuted_samples[3];
+  int32x4_t sum[2];
+
+  // Permute samples ready for dot product.
+  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
+  permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]);
+  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
+  permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]);
+  // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
+  permuted_samples[2] = vqtbl1q_u8(samples, permute_tbl.val[2]);
+
+  // First 4 output values.
+  sum[0] = vusdotq_lane_s32(round_offset, permuted_samples[0], x_filter, 0);
+  sum[0] = vusdotq_lane_s32(sum[0], permuted_samples[1], x_filter, 1);
+  // Second 4 output values.
+  sum[1] = vusdotq_lane_s32(round_offset, permuted_samples[1], x_filter, 0);
+  sum[1] = vusdotq_lane_s32(sum[1], permuted_samples[2], x_filter, 1);
+
+  // Narrow and re-pack.
+  // We halved the convolution filter values so -1 from the right shift.
+  int16x8_t res = vcombine_s16(vshrn_n_s32(sum[0], ROUND0_BITS - 1),
+                               vshrn_n_s32(sum[1], ROUND0_BITS - 1));
+  return vreinterpretq_u16_s16(res);
+}
+
 void av1_dist_wtd_convolve_x_neon(const uint8_t *src, int src_stride,
                                   uint8_t *dst8, int dst8_stride, int w, int h,
                                   const InterpFilterParams *filter_params_x,
                                   const int subpel_x_qn,
                                   ConvolveParams *conv_params) {
-  assert(!(w % 4));
-  assert(!(h % 4));
+  assert(w % 4 == 0);
+  assert(h % 4 == 0);
 
-  const int horiz_offset = filter_params_x->taps / 2 - 1;
   const int bd = 8;
   const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
   const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                                (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
   const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);
+  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
+  // shifts - which are generally faster than rounding shifts on modern CPUs.
+  // (The extra -1 is needed because we halved the filter values.)
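+  // In other words: for the value ranges used here, (x + (1 << (n - 1))) >> n
+  // equals the round-to-nearest shift that vrshrn_n_s32(x, n) would perform,
+  // so pre-adding half of the shift divisor lets convolve8_4_x/convolve8_8_x
+  // use the plain (truncating) vshrn_n_s32 instead.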
+  const int32x4_t round_offset_shim = vdupq_n_s32(
+      (round_offset << (ROUND0_BITS - 1)) + (1 << ((ROUND0_BITS - 1) - 1)));
+
+  const int do_average = conv_params->do_average;
+  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
   const uint16_t fwd_offset = conv_params->fwd_offset;
   const uint16_t bck_offset = conv_params->bck_offset;
-  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
 
   // Horizontal filter.
   const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
@@ -1126,18 +1182,12 @@
   // Filter values are even, so downshift by 1 to reduce intermediate precision
   // requirements.
   const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
-  // This shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
-  // shifts - which are generally faster than rounding shifts on modern CPUs.
-  // The outermost -1 is needed because we halved the filter values.
-  const int32x4_t horiz_const = vdupq_n_s32(
-      (round_offset << (ROUND0_BITS - 1)) + (1 << ((ROUND0_BITS - 1) - 1)));
 
+  const int horiz_offset = filter_params_x->taps / 2 - 1;
   const uint8_t *src_ptr = src - horiz_offset;
-  CONV_BUF_TYPE *dst = conv_params->dst;
-  CONV_BUF_TYPE *dst_ptr = dst;
-  uint8_t *dst_u8_ptr = dst8;
+  CONV_BUF_TYPE *dst_ptr = conv_params->dst;
+  uint8_t *dst8_ptr = dst8;
   int dst_stride = conv_params->dst_stride;
-  int width = w;
   int height = h;
 
   if (w == 4) {
@@ -1145,122 +1195,161 @@
 
     do {
       uint8x16_t s0, s1, s2, s3;
-      int32x4_t d0, d1, d2, d3;
-      int16x8_t d01, d23;
-      uint16x4_t dd0, dd1, dd2, dd3;
+      uint16x4_t d0, d1, d2, d3, dd0, dd1, dd2, dd3;
       uint8x8_t d01_u8, d23_u8;
 
       load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);
 
-      d0 = convolve8_4_usdot(s0, x_filter, permute_tbl, horiz_const);
-      d1 = convolve8_4_usdot(s1, x_filter, permute_tbl, horiz_const);
-      d2 = convolve8_4_usdot(s2, x_filter, permute_tbl, horiz_const);
-      d3 = convolve8_4_usdot(s3, x_filter, permute_tbl, horiz_const);
+      d0 = convolve8_4_x(s0, x_filter, permute_tbl, round_offset_shim);
+      d1 = convolve8_4_x(s1, x_filter, permute_tbl, round_offset_shim);
+      d2 = convolve8_4_x(s2, x_filter, permute_tbl, round_offset_shim);
+      d3 = convolve8_4_x(s3, x_filter, permute_tbl, round_offset_shim);
 
-      // We halved the convolution filter values so -1 from the right shift.
-      d01 = vcombine_s16(vshrn_n_s32(d0, ROUND0_BITS - 1),
-                         vshrn_n_s32(d1, ROUND0_BITS - 1));
-      d23 = vcombine_s16(vshrn_n_s32(d2, ROUND0_BITS - 1),
-                         vshrn_n_s32(d3, ROUND0_BITS - 1));
-
-      if (conv_params->do_average) {
+      if (do_average) {
         load_u16_4x4(dst_ptr, dst_stride, &dd0, &dd1, &dd2, &dd3);
 
-        compute_avg_4x4(
-            dd0, dd1, dd2, dd3, vreinterpret_u16_s16(vget_low_s16(d01)),
-            vreinterpret_u16_s16(vget_high_s16(d01)),
-            vreinterpret_u16_s16(vget_low_s16(d23)),
-            vreinterpret_u16_s16(vget_high_s16(d23)), fwd_offset, bck_offset,
-            round_offset_vec, use_dist_wtd_comp_avg, &d01_u8, &d23_u8);
+        compute_avg_4x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
+                        bck_offset, round_offset_vec, use_dist_wtd_comp_avg,
+                        &d01_u8, &d23_u8);
 
-        store_u8_4x1(dst_u8_ptr + 0 * dst8_stride, d01_u8, 0);
-        store_u8_4x1(dst_u8_ptr + 1 * dst8_stride, d01_u8, 1);
-        store_u8_4x1(dst_u8_ptr + 2 * dst8_stride, d23_u8, 0);
-        store_u8_4x1(dst_u8_ptr + 3 * dst8_stride, d23_u8, 1);
+        store_u8_4x1(dst8_ptr + 0 * dst8_stride, d01_u8, 0);
+        store_u8_4x1(dst8_ptr + 1 * dst8_stride, d01_u8, 1);
+        store_u8_4x1(dst8_ptr + 2 * dst8_stride, d23_u8, 0);
+        store_u8_4x1(dst8_ptr + 3 * dst8_stride, d23_u8, 1);
       } else {
-        store_u16_4x4(dst_ptr, dst_stride,
-                      vreinterpret_u16_s16(vget_low_s16(d01)),
-                      vreinterpret_u16_s16(vget_high_s16(d01)),
-                      vreinterpret_u16_s16(vget_low_s16(d23)),
-                      vreinterpret_u16_s16(vget_high_s16(d23)));
+        store_u16_4x4(dst_ptr, dst_stride, d0, d1, d2, d3);
       }
 
       src_ptr += 4 * src_stride;
       dst_ptr += 4 * dst_stride;
-      dst_u8_ptr += 4 * dst8_stride;
+      dst8_ptr += 4 * dst8_stride;
       height -= 4;
-    } while (height > 0);
+    } while (height != 0);
   } else {
     const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
 
     do {
       const uint8_t *s = src_ptr;
       CONV_BUF_TYPE *d = dst_ptr;
-      uint8_t *d_u8 = dst_u8_ptr;
-      width = w;
+      uint8_t *d_u8 = dst8_ptr;
+      int width = w;
 
       do {
         uint8x16_t s0, s1, s2, s3;
-        int16x8_t d0, d1, d2, d3;
-        uint16x8_t dd0, dd1, dd2, dd3;
+        uint16x8_t d0, d1, d2, d3, dd0, dd1, dd2, dd3;
         uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
 
         load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
 
-        d0 = convolve8_horiz_8_usdot(s0, x_filter, permute_tbl, horiz_const);
-        d1 = convolve8_horiz_8_usdot(s1, x_filter, permute_tbl, horiz_const);
-        d2 = convolve8_horiz_8_usdot(s2, x_filter, permute_tbl, horiz_const);
-        d3 = convolve8_horiz_8_usdot(s3, x_filter, permute_tbl, horiz_const);
+        d0 = convolve8_8_x(s0, x_filter, permute_tbl, round_offset_shim);
+        d1 = convolve8_8_x(s1, x_filter, permute_tbl, round_offset_shim);
+        d2 = convolve8_8_x(s2, x_filter, permute_tbl, round_offset_shim);
+        d3 = convolve8_8_x(s3, x_filter, permute_tbl, round_offset_shim);
 
-        if (conv_params->do_average) {
+        if (do_average) {
           load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);
 
-          compute_avg_8x4(dd0, dd1, dd2, dd3, vreinterpretq_u16_s16(d0),
-                          vreinterpretq_u16_s16(d1), vreinterpretq_u16_s16(d2),
-                          vreinterpretq_u16_s16(d3), fwd_offset, bck_offset,
-                          round_offset_vec, use_dist_wtd_comp_avg, &d0_u8,
-                          &d1_u8, &d2_u8, &d3_u8);
+          compute_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
+                          bck_offset, round_offset_vec, use_dist_wtd_comp_avg,
+                          &d0_u8, &d1_u8, &d2_u8, &d3_u8);
 
           store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);
         } else {
-          store_u16_8x4(d, dst_stride, vreinterpretq_u16_s16(d0),
-                        vreinterpretq_u16_s16(d1), vreinterpretq_u16_s16(d2),
-                        vreinterpretq_u16_s16(d3));
+          store_u16_8x4(d, dst_stride, d0, d1, d2, d3);
         }
 
         s += 8;
         d += 8;
         d_u8 += 8;
         width -= 8;
-      } while (width > 0);
-
+      } while (width != 0);
       src_ptr += 4 * src_stride;
       dst_ptr += 4 * dst_stride;
-      dst_u8_ptr += 4 * dst8_stride;
+      dst8_ptr += 4 * dst8_stride;
       height -= 4;
-    } while (height > 0);
+    } while (height != 0);
   }
 }
 
 #elif defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
 
+static INLINE uint16x4_t convolve8_4_x(uint8x16_t samples,
+                                       const int8x8_t x_filter,
+                                       const int32x4_t correction,
+                                       const uint8x16_t range_limit,
+                                       const uint8x16x2_t permute_tbl) {
+  int8x16_t clamped_samples, permuted_samples[2];
+  int32x4_t sum;
+
+  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
+  clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));
+
+  // Permute samples ready for dot product.
+  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
+  permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
+  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
+  permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
+
+  // Accumulate dot product into 'correction' to account for range clamp.
+  sum = vdotq_lane_s32(correction, permuted_samples[0], x_filter, 0);
+  sum = vdotq_lane_s32(sum, permuted_samples[1], x_filter, 1);
+
+  // We halved the convolution filter values so -1 from the right shift.
+  return vreinterpret_u16_s16(vshrn_n_s32(sum, ROUND0_BITS - 1));
+}
+
+static INLINE uint16x8_t convolve8_8_x(uint8x16_t samples,
+                                       const int8x8_t x_filter,
+                                       const int32x4_t correction,
+                                       const uint8x16_t range_limit,
+                                       const uint8x16x3_t permute_tbl) {
+  int8x16_t clamped_samples, permuted_samples[3];
+  int32x4_t sum[2];
+
+  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
+  clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));
+
+  // Permute samples ready for dot product.
+  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
+  permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
+  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
+  permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
+  // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
+  permuted_samples[2] = vqtbl1q_s8(clamped_samples, permute_tbl.val[2]);
+
+  // Accumulate dot product into 'correction' to account for range clamp.
+  // First 4 output values.
+  sum[0] = vdotq_lane_s32(correction, permuted_samples[0], x_filter, 0);
+  sum[0] = vdotq_lane_s32(sum[0], permuted_samples[1], x_filter, 1);
+  // Second 4 output values.
+  sum[1] = vdotq_lane_s32(correction, permuted_samples[1], x_filter, 0);
+  sum[1] = vdotq_lane_s32(sum[1], permuted_samples[2], x_filter, 1);
+
+  // Narrow and re-pack.
+  // We halved the convolution filter values so -1 from the right shift.
+  int16x8_t res = vcombine_s16(vshrn_n_s32(sum[0], ROUND0_BITS - 1),
+                               vshrn_n_s32(sum[1], ROUND0_BITS - 1));
+  return vreinterpretq_u16_s16(res);
+}
+
 void av1_dist_wtd_convolve_x_neon(const uint8_t *src, int src_stride,
                                   uint8_t *dst8, int dst8_stride, int w, int h,
                                   const InterpFilterParams *filter_params_x,
                                   const int subpel_x_qn,
                                   ConvolveParams *conv_params) {
-  assert(!(w % 4));
-  assert(!(h % 4));
+  assert(w % 4 == 0);
+  assert(h % 4 == 0);
 
-  const int horiz_offset = filter_params_x->taps / 2 - 1;
   const int bd = 8;
   const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
   const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                                (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
   const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);
+
+  const int do_average = conv_params->do_average;
+  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
   const uint16_t fwd_offset = conv_params->fwd_offset;
   const uint16_t bck_offset = conv_params->bck_offset;
-  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
 
   // Horizontal filter.
   const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
@@ -1268,22 +1357,23 @@
   // Filter values are even, so downshift by 1 to reduce intermediate precision
   // requirements.
   const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
-  // Dot-product constants.
+
+  // Dot-product constants and other shims.
   const uint8x16_t range_limit = vdupq_n_u8(128);
   const int32_t correction_s32 = vaddlvq_s16(vshll_n_s8(x_filter, 7));
-  // This shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
-  // shifts - which are generally faster than rounding shifts on modern CPUs.
-  // The outermost -1 is needed because we halved the filter values.
+  // Fold round_offset into the dot-product filter correction constant. The
+  // additional shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-
+  // rounding shifts - which are generally faster than rounding shifts on
+  // modern CPUs. (The extra -1 is needed because we halved the filter values.)
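+  // (correction_s32 is 128 * the sum of the halved filter taps, which cancels
+  //  the -128 bias applied to the samples in convolve8_4_x/convolve8_8_x:
+  //  dot(s, f) == dot(s - 128, f) + 128 * sum(f).)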
   int32x4_t correction =
       vdupq_n_s32(correction_s32 + (round_offset << (ROUND0_BITS - 1)) +
                   (1 << ((ROUND0_BITS - 1) - 1)));
 
+  const int horiz_offset = filter_params_x->taps / 2 - 1;
   const uint8_t *src_ptr = src - horiz_offset;
-  CONV_BUF_TYPE *dst = conv_params->dst;
-  CONV_BUF_TYPE *dst_ptr = dst;
-  uint8_t *dst_u8_ptr = dst8;
+  CONV_BUF_TYPE *dst_ptr = conv_params->dst;
+  uint8_t *dst8_ptr = dst8;
   int dst_stride = conv_params->dst_stride;
-  int width = w;
   int height = h;
 
   if (w == 4) {
@@ -1291,104 +1381,79 @@
 
     do {
       uint8x16_t s0, s1, s2, s3;
-      int32x4_t d0, d1, d2, d3;
-      int16x8_t d01, d23;
-      uint16x4_t dd0, dd1, dd2, dd3;
+      uint16x4_t d0, d1, d2, d3, dd0, dd1, dd2, dd3;
       uint8x8_t d01_u8, d23_u8;
 
       load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);
 
-      d0 = convolve8_4_sdot(s0, x_filter, correction, range_limit, permute_tbl);
-      d1 = convolve8_4_sdot(s1, x_filter, correction, range_limit, permute_tbl);
-      d2 = convolve8_4_sdot(s2, x_filter, correction, range_limit, permute_tbl);
-      d3 = convolve8_4_sdot(s3, x_filter, correction, range_limit, permute_tbl);
+      d0 = convolve8_4_x(s0, x_filter, correction, range_limit, permute_tbl);
+      d1 = convolve8_4_x(s1, x_filter, correction, range_limit, permute_tbl);
+      d2 = convolve8_4_x(s2, x_filter, correction, range_limit, permute_tbl);
+      d3 = convolve8_4_x(s3, x_filter, correction, range_limit, permute_tbl);
 
-      // We halved the convolution filter values so -1 from the right shift.
-      d01 = vcombine_s16(vshrn_n_s32(d0, ROUND0_BITS - 1),
-                         vshrn_n_s32(d1, ROUND0_BITS - 1));
-      d23 = vcombine_s16(vshrn_n_s32(d2, ROUND0_BITS - 1),
-                         vshrn_n_s32(d3, ROUND0_BITS - 1));
-
-      if (conv_params->do_average) {
+      if (do_average) {
         load_u16_4x4(dst_ptr, dst_stride, &dd0, &dd1, &dd2, &dd3);
 
-        compute_avg_4x4(
-            dd0, dd1, dd2, dd3, vreinterpret_u16_s16(vget_low_s16(d01)),
-            vreinterpret_u16_s16(vget_high_s16(d01)),
-            vreinterpret_u16_s16(vget_low_s16(d23)),
-            vreinterpret_u16_s16(vget_high_s16(d23)), fwd_offset, bck_offset,
-            round_offset_vec, use_dist_wtd_comp_avg, &d01_u8, &d23_u8);
+        compute_avg_4x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
+                        bck_offset, round_offset_vec, use_dist_wtd_comp_avg,
+                        &d01_u8, &d23_u8);
 
-        store_u8_4x1(dst_u8_ptr + 0 * dst8_stride, d01_u8, 0);
-        store_u8_4x1(dst_u8_ptr + 1 * dst8_stride, d01_u8, 1);
-        store_u8_4x1(dst_u8_ptr + 2 * dst8_stride, d23_u8, 0);
-        store_u8_4x1(dst_u8_ptr + 3 * dst8_stride, d23_u8, 1);
+        store_u8_4x1(dst8_ptr + 0 * dst8_stride, d01_u8, 0);
+        store_u8_4x1(dst8_ptr + 1 * dst8_stride, d01_u8, 1);
+        store_u8_4x1(dst8_ptr + 2 * dst8_stride, d23_u8, 0);
+        store_u8_4x1(dst8_ptr + 3 * dst8_stride, d23_u8, 1);
       } else {
-        store_u16_4x4(dst_ptr, dst_stride,
-                      vreinterpret_u16_s16(vget_low_s16(d01)),
-                      vreinterpret_u16_s16(vget_high_s16(d01)),
-                      vreinterpret_u16_s16(vget_low_s16(d23)),
-                      vreinterpret_u16_s16(vget_high_s16(d23)));
+        store_u16_4x4(dst_ptr, dst_stride, d0, d1, d2, d3);
       }
 
       src_ptr += 4 * src_stride;
       dst_ptr += 4 * dst_stride;
-      dst_u8_ptr += 4 * dst8_stride;
+      dst8_ptr += 4 * dst8_stride;
       height -= 4;
-    } while (height > 0);
+    } while (height != 0);
   } else {
     const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
 
     do {
       const uint8_t *s = src_ptr;
       CONV_BUF_TYPE *d = dst_ptr;
-      uint8_t *d_u8 = dst_u8_ptr;
-      width = w;
+      uint8_t *d_u8 = dst8_ptr;
+      int width = w;
 
       do {
         uint8x16_t s0, s1, s2, s3;
-        int16x8_t d0, d1, d2, d3;
-        uint16x8_t dd0, dd1, dd2, dd3;
+        uint16x8_t d0, d1, d2, d3, dd0, dd1, dd2, dd3;
         uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
 
         load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
 
-        d0 = convolve8_horiz_8_sdot(s0, x_filter, correction, range_limit,
-                                    permute_tbl);
-        d1 = convolve8_horiz_8_sdot(s1, x_filter, correction, range_limit,
-                                    permute_tbl);
-        d2 = convolve8_horiz_8_sdot(s2, x_filter, correction, range_limit,
-                                    permute_tbl);
-        d3 = convolve8_horiz_8_sdot(s3, x_filter, correction, range_limit,
-                                    permute_tbl);
+        d0 = convolve8_8_x(s0, x_filter, correction, range_limit, permute_tbl);
+        d1 = convolve8_8_x(s1, x_filter, correction, range_limit, permute_tbl);
+        d2 = convolve8_8_x(s2, x_filter, correction, range_limit, permute_tbl);
+        d3 = convolve8_8_x(s3, x_filter, correction, range_limit, permute_tbl);
 
-        if (conv_params->do_average) {
+        if (do_average) {
           load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);
 
-          compute_avg_8x4(dd0, dd1, dd2, dd3, vreinterpretq_u16_s16(d0),
-                          vreinterpretq_u16_s16(d1), vreinterpretq_u16_s16(d2),
-                          vreinterpretq_u16_s16(d3), fwd_offset, bck_offset,
-                          round_offset_vec, use_dist_wtd_comp_avg, &d0_u8,
-                          &d1_u8, &d2_u8, &d3_u8);
+          compute_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
+                          bck_offset, round_offset_vec, use_dist_wtd_comp_avg,
+                          &d0_u8, &d1_u8, &d2_u8, &d3_u8);
 
           store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);
         } else {
-          store_u16_8x4(d, dst_stride, vreinterpretq_u16_s16(d0),
-                        vreinterpretq_u16_s16(d1), vreinterpretq_u16_s16(d2),
-                        vreinterpretq_u16_s16(d3));
+          store_u16_8x4(d, dst_stride, d0, d1, d2, d3);
         }
 
         s += 8;
         d += 8;
         d_u8 += 8;
         width -= 8;
-      } while (width > 0);
-
+      } while (width != 0);
       src_ptr += 4 * src_stride;
       dst_ptr += 4 * dst_stride;
-      dst_u8_ptr += 4 * dst8_stride;
+      dst8_ptr += 4 * dst8_stride;
       height -= 4;
-    } while (height > 0);
+    } while (height != 0);
   }
 }