Split av1_convolve_2d_sr_neon into horizontal/vertical helpers

Split av1_convolve_2d_sr_neon into separate helper functions for the
horizontal and vertical convolution passes. A faster dot-product
implementation of the horizontal convolution will be added in a
subsequent patch.
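
After this change av1_convolve_2d_sr_neon only selects the filters and
sets up the intermediate block, then delegates each pass. Roughly (as in
the new code below):

  // Horizontal pass: filter into the 16-bit intermediate block.
  av1_convolve_2d_sr_horiz_neon(src_ptr, src_stride, im_block, im_stride,
                                w, im_h, x_filter, conv_params->round_0);
  // Vertical pass: filter the intermediate block and write the final
  // 8-bit destination.
  av1_convolve_2d_sr_vert_neon(im_block, im_stride, dst, dst_stride,
                               w, h, y_filter, conv_params);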

Change-Id: Ia33ed4a4727d36c274f406c42cc052646df1cfb6
diff --git a/av1/common/arm/convolve_neon.c b/av1/common/arm/convolve_neon.c
index ef49e29..7387a13 100644
--- a/av1/common/arm/convolve_neon.c
+++ b/av1/common/arm/convolve_neon.c
@@ -935,59 +935,31 @@
   } while (height > 0);
 }
 
-void av1_convolve_2d_sr_neon(const uint8_t *src, int src_stride, uint8_t *dst,
-                             int dst_stride, int w, int h,
-                             const InterpFilterParams *filter_params_x,
-                             const InterpFilterParams *filter_params_y,
-                             const int subpel_x_qn, const int subpel_y_qn,
-                             ConvolveParams *conv_params) {
-  if (filter_params_x->taps > 8) {
-    av1_convolve_2d_sr_c(src, src_stride, dst, dst_stride, w, h,
-                         filter_params_x, filter_params_y, subpel_x_qn,
-                         subpel_y_qn, conv_params);
-    return;
-  }
-
-#if defined(__aarch64__)
-  uint8x8_t t0, t1, t2, t3, t4, t5, t6, t7;
-#endif
-
-  DECLARE_ALIGNED(16, int16_t,
-                  im_block[(MAX_SB_SIZE + HORIZ_EXTRA_ROWS) * MAX_SB_SIZE]);
-
+static INLINE void av1_convolve_2d_sr_horiz_neon(
+    const uint8_t *src, int src_stride, int16_t *im_block, int im_stride, int w,
+    int im_h, const int16x8_t x_filter_s16, const int round_0) {
   const int bd = 8;
-  const int im_h = h + filter_params_y->taps - 1;
-  const int im_stride = MAX_SB_SIZE;
-  const int vert_offset = filter_params_y->taps / 2 - 1;
-  const int horiz_offset = filter_params_x->taps / 2 - 1;
 
-  const uint8_t *src_ptr = src - vert_offset * src_stride - horiz_offset;
+  const uint8_t *src_ptr = src;
   int16_t *dst_ptr = im_block;
+  int dst_stride = im_stride;
 
-  int im_dst_stride = im_stride;
-  int width = w;
   int height = im_h;
 
-  const int16_t round_bits =
-      FILTER_BITS * 2 - conv_params->round_0 - conv_params->round_1;
-  const int16x8_t vec_round_bits = vdupq_n_s16(-round_bits);
-  const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
-  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
-      filter_params_x, subpel_x_qn & SUBPEL_MASK);
-
   // Filter values are even, so downshift by 1 to reduce intermediate precision
   // requirements.
-  const int16x8_t x_filter = vshrq_n_s16(vld1q_s16(x_filter_ptr), 1);
+  const int16x8_t x_filter = vshrq_n_s16(x_filter_s16, 1);
 
-  assert(conv_params->round_0 > 0);
+  assert(round_0 > 0);
 
   if (w <= 4) {
     const int16x4_t horiz_const = vdup_n_s16((1 << (bd + FILTER_BITS - 2)));
-    const int16x4_t shift_round_0 = vdup_n_s16(-(conv_params->round_0 - 1));
+    const int16x4_t shift_round_0 = vdup_n_s16(-(round_0 - 1));
 
 #if defined(__aarch64__)
     do {
       int16x4_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, d0, d1, d2, d3;
+      uint8x8_t t0, t1, t2, t3;
       const uint8_t *s = src_ptr;
 
       assert(height >= 4);
@@ -1025,50 +997,53 @@
       transpose_s16_4x4d(&d0, &d1, &d2, &d3);
 
       if (w == 2) {
-        vst1_lane_u32((uint32_t *)(dst_ptr + 0 * im_dst_stride),
+        vst1_lane_u32((uint32_t *)(dst_ptr + 0 * dst_stride),
                       vreinterpret_u32_s16(d0), 0);
-        vst1_lane_u32((uint32_t *)(dst_ptr + 1 * im_dst_stride),
+        vst1_lane_u32((uint32_t *)(dst_ptr + 1 * dst_stride),
                       vreinterpret_u32_s16(d1), 0);
-        vst1_lane_u32((uint32_t *)(dst_ptr + 2 * im_dst_stride),
+        vst1_lane_u32((uint32_t *)(dst_ptr + 2 * dst_stride),
                       vreinterpret_u32_s16(d2), 0);
-        vst1_lane_u32((uint32_t *)(dst_ptr + 3 * im_dst_stride),
+        vst1_lane_u32((uint32_t *)(dst_ptr + 3 * dst_stride),
                       vreinterpret_u32_s16(d3), 0);
       } else {
-        vst1_s16((dst_ptr + 0 * im_dst_stride), d0);
-        vst1_s16((dst_ptr + 1 * im_dst_stride), d1);
-        vst1_s16((dst_ptr + 2 * im_dst_stride), d2);
-        vst1_s16((dst_ptr + 3 * im_dst_stride), d3);
+        vst1_s16((dst_ptr + 0 * dst_stride), d0);
+        vst1_s16((dst_ptr + 1 * dst_stride), d1);
+        vst1_s16((dst_ptr + 2 * dst_stride), d2);
+        vst1_s16((dst_ptr + 3 * dst_stride), d3);
       }
 
       src_ptr += 4 * src_stride;
-      dst_ptr += 4 * im_dst_stride;
+      dst_ptr += 4 * dst_stride;
       height -= 4;
     } while (height >= 4);
 
     if (height) {
       assert(height < 4);
-      horiz_filter_w4_single_row(src_ptr, src_stride, dst_ptr, im_dst_stride, w,
+      horiz_filter_w4_single_row(src_ptr, src_stride, dst_ptr, dst_stride, w,
                                  height, x_filter, horiz_const, shift_round_0);
     }
 
 #else   // !defined(__aarch64__)
-    horiz_filter_w4_single_row(src_ptr, src_stride, dst_ptr, im_dst_stride, w,
+    horiz_filter_w4_single_row(src_ptr, src_stride, dst_ptr, dst_stride, w,
                                height, x_filter, horiz_const, shift_round_0);
 #endif  // defined(__aarch64__)
 
   } else {
     const int16x8_t horiz_const = vdupq_n_s16((1 << (bd + FILTER_BITS - 2)));
-    const int16x8_t shift_round_0 = vdupq_n_s16(-(conv_params->round_0 - 1));
+    const int16x8_t shift_round_0 = vdupq_n_s16(-(round_0 - 1));
 
 #if defined(__aarch64__)
 
     for (; height >= 8; height -= 8) {
       int16x8_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14,
           d0, d1, d2, d3, d4, d5, d6, d7;
-      const uint8_t *s;
-      int16_t *d;
+      uint8x8_t t0, t1, t2, t3, t4, t5, t6, t7;
 
-      load_u8_8x8(src_ptr, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
+      const uint8_t *s = src_ptr;
+      int16_t *d = dst_ptr;
+      int width = w;
+
+      load_u8_8x8(s, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
 
       transpose_u8_8x8(&t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
 
@@ -1080,9 +1055,7 @@
       s5 = vreinterpretq_s16_u16(vmovl_u8(t5));
       s6 = vreinterpretq_s16_u16(vmovl_u8(t6));
 
-      width = w;
-      s = src_ptr + 7;
-      d = dst_ptr;
+      s += 7;
 
       do {
         load_u8_8x8(s, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
@@ -1117,7 +1090,7 @@
 
         transpose_s16_8x8(&d0, &d1, &d2, &d3, &d4, &d5, &d6, &d7);
 
-        store_s16_8x8(d, im_dst_stride, d0, d1, d2, d3, d4, d5, d6, d7);
+        store_s16_8x8(d, dst_stride, d0, d1, d2, d3, d4, d5, d6, d7);
 
         s0 = s8;
         s1 = s9;
@@ -1132,15 +1105,18 @@
       } while (width > 0);
 
       src_ptr += 8 * src_stride;
-      dst_ptr += 8 * im_dst_stride;
+      dst_ptr += 8 * dst_stride;
     }
 
     for (; height >= 4; height -= 4) {
       int16x4_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14,
           dd0, dd1, dd2, dd3, dd4, dd5, dd6, dd7;
       int16x8_t d0, d1, d2, d3;
-      const uint8_t *s;
-      int16_t *d;
+      uint8x8_t t0, t1, t2, t3;
+
+      const uint8_t *s = src_ptr;
+      int16_t *d = dst_ptr;
+      int width = w;
 
       load_u8_8x4(src_ptr, src_stride, &t0, &t1, &t2, &t3);
       transpose_u8_8x4(&t0, &t1, &t2, &t3);
@@ -1153,9 +1129,7 @@
       s5 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t1)));
       s6 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t2)));
 
-      s = src_ptr + 7;
-      d = dst_ptr;
-      width = w;
+      s += 7;
 
       do {
         load_u8_8x4(s, src_stride, &t0, &t1, &t2, &t3);
@@ -1192,7 +1166,7 @@
         d2 = vqrshlq_s16(d2, shift_round_0);
         d3 = vqrshlq_s16(d3, shift_round_0);
 
-        store_s16_8x4(d, im_dst_stride, d0, d1, d2, d3);
+        store_s16_8x4(d, dst_stride, d0, d1, d2, d3);
 
         s0 = s8;
         s1 = s9;
@@ -1207,103 +1181,179 @@
       } while (width > 0);
 
       src_ptr += 4 * src_stride;
-      dst_ptr += 4 * im_dst_stride;
+      dst_ptr += 4 * dst_stride;
     }
 
     if (height) {
       assert(height < 4);
-      horiz_filter_w8_single_row(src_ptr, src_stride, dst_ptr, im_dst_stride, w,
+      horiz_filter_w8_single_row(src_ptr, src_stride, dst_ptr, dst_stride, w,
                                  height, x_filter, horiz_const, shift_round_0);
     }
 
 #else   // !defined(__aarch64__)
-    horiz_filter_w8_single_row(src_ptr, src_stride, dst_ptr, im_dst_stride, w,
+    horiz_filter_w8_single_row(src_ptr, src_stride, dst_ptr, dst_stride, w,
                                height, x_filter, horiz_const, shift_round_0);
 #endif  // defined(__aarch64__)
   }
+}
 
-  // vertical
-  {
-    const int32_t sub_const = (1 << (offset_bits - conv_params->round_1)) +
-                              (1 << (offset_bits - conv_params->round_1 - 1));
-    const int16_t *y_filter_ptr = av1_get_interp_filter_subpel_kernel(
-        filter_params_y, subpel_y_qn & SUBPEL_MASK);
-    const int16x8_t y_filter = vld1q_s16(y_filter_ptr);
+static INLINE void av1_convolve_2d_sr_vert_neon(
+    int16_t *src_ptr, int src_stride, uint8_t *dst_ptr, int dst_stride, int w,
+    int h, const int16x8_t y_filter, ConvolveParams *conv_params) {
+  const int bd = 8;
+  const int16_t round_bits =
+      FILTER_BITS * 2 - conv_params->round_0 - conv_params->round_1;
+  const int16x8_t vec_round_bits = vdupq_n_s16(-round_bits);
+  const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
 
-    const int32x4_t round_shift_vec = vdupq_n_s32(-(conv_params->round_1));
-    const int32x4_t offset_const = vdupq_n_s32(1 << offset_bits);
-    const int32x4_t sub_const_vec = vdupq_n_s32(sub_const);
+  const int32_t sub_const = (1 << (offset_bits - conv_params->round_1)) +
+                            (1 << (offset_bits - conv_params->round_1 - 1));
 
-    src_stride = im_stride;
-    int16_t *v_src_ptr = im_block;
-    uint8_t *v_dst_ptr = dst;
+  const int32x4_t round_shift_vec = vdupq_n_s32(-(conv_params->round_1));
+  const int32x4_t offset_const = vdupq_n_s32(1 << offset_bits);
+  const int32x4_t sub_const_vec = vdupq_n_s32(sub_const);
 
-    height = h;
-    width = w;
-
-    if (width <= 4) {
-      int16x4_t s0, s1, s2, s3, s4, s5, s6, s7, d0;
-      int16x8_t dd0;
-      uint8x8_t d01;
+  if (w <= 4) {
+    int16x4_t s0, s1, s2, s3, s4, s5, s6, s7, d0;
+    int16x8_t dd0;
+    uint8x8_t d01;
 
 #if defined(__aarch64__)
-      int16x4_t s8, s9, s10, d1, d2, d3;
-      int16x8_t dd1;
-      uint8x8_t d23;
+    int16x4_t s8, s9, s10, d1, d2, d3;
+    int16x8_t dd1;
+    uint8x8_t d23;
 #endif  // defined(__aarch64__)
 
-      int16_t *s = v_src_ptr;
-      uint8_t *d = v_dst_ptr;
+    int16_t *s = src_ptr;
+    uint8_t *d = dst_ptr;
 
-      load_s16_4x8(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);
+    load_s16_4x8(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);
+    s += (7 * src_stride);
+
+    do {
+#if defined(__aarch64__)
+      load_s16_4x4(s, src_stride, &s7, &s8, &s9, &s10);
+      s += (4 * src_stride);
+
+      d0 = convolve8_vert_4x4_s32(s0, s1, s2, s3, s4, s5, s6, s7, y_filter,
+                                  round_shift_vec, offset_const, sub_const_vec);
+      d1 = convolve8_vert_4x4_s32(s1, s2, s3, s4, s5, s6, s7, s8, y_filter,
+                                  round_shift_vec, offset_const, sub_const_vec);
+      d2 = convolve8_vert_4x4_s32(s2, s3, s4, s5, s6, s7, s8, s9, y_filter,
+                                  round_shift_vec, offset_const, sub_const_vec);
+      d3 = convolve8_vert_4x4_s32(s3, s4, s5, s6, s7, s8, s9, s10, y_filter,
+                                  round_shift_vec, offset_const, sub_const_vec);
+
+      dd0 = vqrshlq_s16(vcombine_s16(d0, d1), vec_round_bits);
+      dd1 = vqrshlq_s16(vcombine_s16(d2, d3), vec_round_bits);
+
+      d01 = vqmovun_s16(dd0);
+      d23 = vqmovun_s16(dd1);
+
+      if (w == 4) {
+        vst1_lane_u32((uint32_t *)d, vreinterpret_u32_u8(d01), 0);
+        d += dst_stride;
+        vst1_lane_u32((uint32_t *)d, vreinterpret_u32_u8(d01), 1);
+        d += dst_stride;
+        if (h != 2) {
+          vst1_lane_u32((uint32_t *)d, vreinterpret_u32_u8(d23), 0);
+          d += dst_stride;
+          vst1_lane_u32((uint32_t *)d, vreinterpret_u32_u8(d23), 1);
+          d += dst_stride;
+        }
+      } else {
+        vst1_lane_u16((uint16_t *)d, vreinterpret_u16_u8(d01), 0);
+        d += dst_stride;
+        vst1_lane_u16((uint16_t *)d, vreinterpret_u16_u8(d01), 2);
+        d += dst_stride;
+        if (h != 2) {
+          vst1_lane_u16((uint16_t *)d, vreinterpret_u16_u8(d23), 0);
+          d += dst_stride;
+          vst1_lane_u16((uint16_t *)d, vreinterpret_u16_u8(d23), 2);
+          d += dst_stride;
+        }
+      }
+
+      s0 = s4;
+      s1 = s5;
+      s2 = s6;
+      s3 = s7;
+      s4 = s8;
+      s5 = s9;
+      s6 = s10;
+      h -= 4;
+#else   // !defined(__aarch64__)
+      s7 = vld1_s16(s);
+      s += src_stride;
+
+      d0 = convolve8_vert_4x4_s32(s0, s1, s2, s3, s4, s5, s6, s7, y_filter,
+                                  round_shift_vec, offset_const, sub_const_vec);
+
+      dd0 = vqrshlq_s16(vcombine_s16(d0, d0), vec_round_bits);
+      d01 = vqmovun_s16(dd0);
+
+      if (w == 2) {
+        vst1_lane_u16((uint16_t *)d, vreinterpret_u16_u8(d01), 0);
+        d += dst_stride;
+      } else {
+        vst1_lane_u32((uint32_t *)d, vreinterpret_u32_u8(d01), 0);
+        d += dst_stride;
+      }
+
+      s0 = s1;
+      s1 = s2;
+      s2 = s3;
+      s3 = s4;
+      s4 = s5;
+      s5 = s6;
+      s6 = s7;
+      h--;
+#endif  // defined(__aarch64__)
+    } while (h > 0);
+  } else {
+    // if width is a multiple of 8 & height is a multiple of 4
+    int16x8_t s0, s1, s2, s3, s4, s5, s6, s7;
+    uint8x8_t d0;
+#if defined(__aarch64__)
+    int16x8_t s8, s9, s10;
+    uint8x8_t d1, d2, d3;
+#endif  // defined(__aarch64__)
+
+    do {
+      int height = h;
+      int16_t *s = src_ptr;
+      uint8_t *d = dst_ptr;
+
+      load_s16_8x8(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);
       s += (7 * src_stride);
 
       do {
 #if defined(__aarch64__)
-        load_s16_4x4(s, src_stride, &s7, &s8, &s9, &s10);
+        load_s16_8x4(s, src_stride, &s7, &s8, &s9, &s10);
         s += (4 * src_stride);
 
-        d0 = convolve8_vert_4x4_s32(s0, s1, s2, s3, s4, s5, s6, s7, y_filter,
+        d0 = convolve8_vert_8x4_s32(s0, s1, s2, s3, s4, s5, s6, s7, y_filter,
                                     round_shift_vec, offset_const,
-                                    sub_const_vec);
-        d1 = convolve8_vert_4x4_s32(s1, s2, s3, s4, s5, s6, s7, s8, y_filter,
+                                    sub_const_vec, vec_round_bits);
+        d1 = convolve8_vert_8x4_s32(s1, s2, s3, s4, s5, s6, s7, s8, y_filter,
                                     round_shift_vec, offset_const,
-                                    sub_const_vec);
-        d2 = convolve8_vert_4x4_s32(s2, s3, s4, s5, s6, s7, s8, s9, y_filter,
+                                    sub_const_vec, vec_round_bits);
+        d2 = convolve8_vert_8x4_s32(s2, s3, s4, s5, s6, s7, s8, s9, y_filter,
                                     round_shift_vec, offset_const,
-                                    sub_const_vec);
-        d3 = convolve8_vert_4x4_s32(s3, s4, s5, s6, s7, s8, s9, s10, y_filter,
+                                    sub_const_vec, vec_round_bits);
+        d3 = convolve8_vert_8x4_s32(s3, s4, s5, s6, s7, s8, s9, s10, y_filter,
                                     round_shift_vec, offset_const,
-                                    sub_const_vec);
+                                    sub_const_vec, vec_round_bits);
 
-        dd0 = vqrshlq_s16(vcombine_s16(d0, d1), vec_round_bits);
-        dd1 = vqrshlq_s16(vcombine_s16(d2, d3), vec_round_bits);
-
-        d01 = vqmovun_s16(dd0);
-        d23 = vqmovun_s16(dd1);
-
-        if (w == 4) {
-          vst1_lane_u32((uint32_t *)d, vreinterpret_u32_u8(d01), 0);
+        vst1_u8(d, d0);
+        d += dst_stride;
+        vst1_u8(d, d1);
+        d += dst_stride;
+        if (h != 2) {
+          vst1_u8(d, d2);
           d += dst_stride;
-          vst1_lane_u32((uint32_t *)d, vreinterpret_u32_u8(d01), 1);
+          vst1_u8(d, d3);
           d += dst_stride;
-          if (h != 2) {
-            vst1_lane_u32((uint32_t *)d, vreinterpret_u32_u8(d23), 0);
-            d += dst_stride;
-            vst1_lane_u32((uint32_t *)d, vreinterpret_u32_u8(d23), 1);
-            d += dst_stride;
-          }
-        } else {
-          vst1_lane_u16((uint16_t *)d, vreinterpret_u16_u8(d01), 0);
-          d += dst_stride;
-          vst1_lane_u16((uint16_t *)d, vreinterpret_u16_u8(d01), 2);
-          d += dst_stride;
-          if (h != 2) {
-            vst1_lane_u16((uint16_t *)d, vreinterpret_u16_u8(d23), 0);
-            d += dst_stride;
-            vst1_lane_u16((uint16_t *)d, vreinterpret_u16_u8(d23), 2);
-            d += dst_stride;
-          }
         }
 
         s0 = s4;
@@ -1315,23 +1365,15 @@
         s6 = s10;
         height -= 4;
 #else   // !defined(__aarch64__)
-        s7 = vld1_s16(s);
+        s7 = vld1q_s16(s);
         s += src_stride;
 
-        d0 = convolve8_vert_4x4_s32(s0, s1, s2, s3, s4, s5, s6, s7, y_filter,
+        d0 = convolve8_vert_8x4_s32(s0, s1, s2, s3, s4, s5, s6, s7, y_filter,
                                     round_shift_vec, offset_const,
-                                    sub_const_vec);
+                                    sub_const_vec, vec_round_bits);
 
-        dd0 = vqrshlq_s16(vcombine_s16(d0, d0), vec_round_bits);
-        d01 = vqmovun_s16(dd0);
-
-        if (w == 2) {
-          vst1_lane_u16((uint16_t *)d, vreinterpret_u16_u8(d01), 0);
-          d += dst_stride;
-        } else {
-          vst1_lane_u32((uint32_t *)d, vreinterpret_u32_u8(d01), 0);
-          d += dst_stride;
-        }
+        vst1_u8(d, d0);
+        d += dst_stride;
 
         s0 = s1;
         s1 = s2;
@@ -1343,88 +1385,46 @@
         height--;
 #endif  // defined(__aarch64__)
       } while (height > 0);
-    } else {
-      // if width is a multiple of 8 & height is a multiple of 4
-      int16x8_t s0, s1, s2, s3, s4, s5, s6, s7;
-      uint8x8_t d0;
-#if defined(__aarch64__)
-      int16x8_t s8, s9, s10;
-      uint8x8_t d1, d2, d3;
-#endif  // defined(__aarch64__)
 
-      do {
-        int16_t *s = v_src_ptr;
-        uint8_t *d = v_dst_ptr;
+      src_ptr += 8;
+      dst_ptr += 8;
+      w -= 8;
+    } while (w > 0);
+  }
+}
 
-        load_s16_8x8(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);
-        s += (7 * src_stride);
+void av1_convolve_2d_sr_neon(const uint8_t *src, int src_stride, uint8_t *dst,
+                             int dst_stride, int w, int h,
+                             const InterpFilterParams *filter_params_x,
+                             const InterpFilterParams *filter_params_y,
+                             const int subpel_x_qn, const int subpel_y_qn,
+                             ConvolveParams *conv_params) {
+  if (filter_params_x->taps > 8) {
+    av1_convolve_2d_sr_c(src, src_stride, dst, dst_stride, w, h,
+                         filter_params_x, filter_params_y, subpel_x_qn,
+                         subpel_y_qn, conv_params);
+  } else {
+    DECLARE_ALIGNED(16, int16_t,
+                    im_block[(MAX_SB_SIZE + HORIZ_EXTRA_ROWS) * MAX_SB_SIZE]);
 
-        height = h;
+    const int im_h = h + filter_params_y->taps - 1;
+    const int im_stride = MAX_SB_SIZE;
+    const int vert_offset = filter_params_y->taps / 2 - 1;
+    const int horiz_offset = filter_params_x->taps / 2 - 1;
+    const uint8_t *src_ptr = src - vert_offset * src_stride - horiz_offset;
 
-        do {
-#if defined(__aarch64__)
-          load_s16_8x4(s, src_stride, &s7, &s8, &s9, &s10);
-          s += (4 * src_stride);
+    const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
+        filter_params_x, subpel_x_qn & SUBPEL_MASK);
+    const int16_t *y_filter_ptr = av1_get_interp_filter_subpel_kernel(
+        filter_params_y, subpel_y_qn & SUBPEL_MASK);
+    const int16x8_t x_filter = vld1q_s16(x_filter_ptr);
+    const int16x8_t y_filter = vld1q_s16(y_filter_ptr);
 
-          d0 = convolve8_vert_8x4_s32(s0, s1, s2, s3, s4, s5, s6, s7, y_filter,
-                                      round_shift_vec, offset_const,
-                                      sub_const_vec, vec_round_bits);
-          d1 = convolve8_vert_8x4_s32(s1, s2, s3, s4, s5, s6, s7, s8, y_filter,
-                                      round_shift_vec, offset_const,
-                                      sub_const_vec, vec_round_bits);
-          d2 = convolve8_vert_8x4_s32(s2, s3, s4, s5, s6, s7, s8, s9, y_filter,
-                                      round_shift_vec, offset_const,
-                                      sub_const_vec, vec_round_bits);
-          d3 = convolve8_vert_8x4_s32(s3, s4, s5, s6, s7, s8, s9, s10, y_filter,
-                                      round_shift_vec, offset_const,
-                                      sub_const_vec, vec_round_bits);
+    av1_convolve_2d_sr_horiz_neon(src_ptr, src_stride, im_block, im_stride, w,
+                                  im_h, x_filter, conv_params->round_0);
 
-          vst1_u8(d, d0);
-          d += dst_stride;
-          vst1_u8(d, d1);
-          d += dst_stride;
-          if (h != 2) {
-            vst1_u8(d, d2);
-            d += dst_stride;
-            vst1_u8(d, d3);
-            d += dst_stride;
-          }
-
-          s0 = s4;
-          s1 = s5;
-          s2 = s6;
-          s3 = s7;
-          s4 = s8;
-          s5 = s9;
-          s6 = s10;
-          height -= 4;
-#else   // !defined(__aarch64__)
-          s7 = vld1q_s16(s);
-          s += src_stride;
-
-          d0 = convolve8_vert_8x4_s32(s0, s1, s2, s3, s4, s5, s6, s7, y_filter,
-                                      round_shift_vec, offset_const,
-                                      sub_const_vec, vec_round_bits);
-
-          vst1_u8(d, d0);
-          d += dst_stride;
-
-          s0 = s1;
-          s1 = s2;
-          s2 = s3;
-          s3 = s4;
-          s4 = s5;
-          s5 = s6;
-          s6 = s7;
-          height--;
-#endif  // defined(__aarch64__)
-        } while (height > 0);
-
-        v_src_ptr += 8;
-        v_dst_ptr += 8;
-        w -= 8;
-      } while (w > 0);
-    }
+    av1_convolve_2d_sr_vert_neon(im_block, im_stride, dst, dst_stride, w, h,
+                                 y_filter, conv_params);
   }
 }