[NEON] Optimize av1_highbd_convolve_2d_scale_neon().

CPU usage drops from 64% to 48%, and with superres-mode=1 the total
encoding time is reduced by 38%.

Change-Id: Ia8dc93a5f31a446beb164a59bf4a883b3b764a0e
diff --git a/av1/common/arm/highbd_convolve_neon.c b/av1/common/arm/highbd_convolve_neon.c
index 5a4b9c8..e0dc9b4 100644
--- a/av1/common/arm/highbd_convolve_neon.c
+++ b/av1/common/arm/highbd_convolve_neon.c
@@ -663,16 +663,16 @@
     do {
       load_s16_4x4(s, src_stride, &s7, &s8, &s9, &s10);
 
-      d0 = highbd_convolve8_4_rs_s32_s16(s0, s1, s2, s3, s4, s5, s6, s7,
+      d0 = highbd_convolve8_4_sr_s32_s16(s0, s1, s2, s3, s4, s5, s6, s7,
                                          y_filter, round1_shift_s32, offset_s32,
                                          correction_s32);
-      d1 = highbd_convolve8_4_rs_s32_s16(s1, s2, s3, s4, s5, s6, s7, s8,
+      d1 = highbd_convolve8_4_sr_s32_s16(s1, s2, s3, s4, s5, s6, s7, s8,
                                          y_filter, round1_shift_s32, offset_s32,
                                          correction_s32);
-      d2 = highbd_convolve8_4_rs_s32_s16(s2, s3, s4, s5, s6, s7, s8, s9,
+      d2 = highbd_convolve8_4_sr_s32_s16(s2, s3, s4, s5, s6, s7, s8, s9,
                                          y_filter, round1_shift_s32, offset_s32,
                                          correction_s32);
-      d3 = highbd_convolve8_4_rs_s32_s16(s3, s4, s5, s6, s7, s8, s9, s10,
+      d3 = highbd_convolve8_4_sr_s32_s16(s3, s4, s5, s6, s7, s8, s9, s10,
                                          y_filter, round1_shift_s32, offset_s32,
                                          correction_s32);
 
@@ -723,16 +723,16 @@
       do {
         load_s16_8x4(s, src_stride, &s7, &s8, &s9, &s10);
 
-        d0 = highbd_convolve8_8_rs_s32_s16(s0, s1, s2, s3, s4, s5, s6, s7,
+        d0 = highbd_convolve8_8_sr_s32_s16(s0, s1, s2, s3, s4, s5, s6, s7,
                                            y_filter, round1_shift_s32,
                                            offset_s32, correction_s32);
-        d1 = highbd_convolve8_8_rs_s32_s16(s1, s2, s3, s4, s5, s6, s7, s8,
+        d1 = highbd_convolve8_8_sr_s32_s16(s1, s2, s3, s4, s5, s6, s7, s8,
                                            y_filter, round1_shift_s32,
                                            offset_s32, correction_s32);
-        d2 = highbd_convolve8_8_rs_s32_s16(s2, s3, s4, s5, s6, s7, s8, s9,
+        d2 = highbd_convolve8_8_sr_s32_s16(s2, s3, s4, s5, s6, s7, s8, s9,
                                            y_filter, round1_shift_s32,
                                            offset_s32, correction_s32);
-        d3 = highbd_convolve8_8_rs_s32_s16(s3, s4, s5, s6, s7, s8, s9, s10,
+        d3 = highbd_convolve8_8_sr_s32_s16(s3, s4, s5, s6, s7, s8, s9, s10,
                                            y_filter, round1_shift_s32,
                                            offset_s32, correction_s32);
 
@@ -792,16 +792,16 @@
     do {
       load_s16_4x4(s, src_stride, &s11, &s12, &s13, &s14);
 
-      d0 = highbd_convolve12_y_4_rs_s32_s16(
+      d0 = highbd_convolve12_y_4_sr_s32_s16(
           s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, y_filter_0_7,
           y_filter_8_11, round1_shift_s32, offset_s32, correction_s32);
-      d1 = highbd_convolve12_y_4_rs_s32_s16(
+      d1 = highbd_convolve12_y_4_sr_s32_s16(
           s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, y_filter_0_7,
           y_filter_8_11, round1_shift_s32, offset_s32, correction_s32);
-      d2 = highbd_convolve12_y_4_rs_s32_s16(
+      d2 = highbd_convolve12_y_4_sr_s32_s16(
           s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, y_filter_0_7,
           y_filter_8_11, round1_shift_s32, offset_s32, correction_s32);
-      d3 = highbd_convolve12_y_4_rs_s32_s16(
+      d3 = highbd_convolve12_y_4_sr_s32_s16(
           s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, y_filter_0_7,
           y_filter_8_11, round1_shift_s32, offset_s32, correction_s32);
 
@@ -858,16 +858,16 @@
       do {
         load_s16_8x4(s, src_stride, &s11, &s12, &s13, &s14);
 
-        d0 = highbd_convolve12_y_8_rs_s32_s16(
+        d0 = highbd_convolve12_y_8_sr_s32_s16(
             s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, y_filter_0_7,
             y_filter_8_11, round1_shift_s32, offset_s32, correction_s32);
-        d1 = highbd_convolve12_y_8_rs_s32_s16(
+        d1 = highbd_convolve12_y_8_sr_s32_s16(
             s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, y_filter_0_7,
             y_filter_8_11, round1_shift_s32, offset_s32, correction_s32);
-        d2 = highbd_convolve12_y_8_rs_s32_s16(
+        d2 = highbd_convolve12_y_8_sr_s32_s16(
             s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, y_filter_0_7,
             y_filter_8_11, round1_shift_s32, offset_s32, correction_s32);
-        d3 = highbd_convolve12_y_8_rs_s32_s16(
+        d3 = highbd_convolve12_y_8_sr_s32_s16(
             s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, y_filter_0_7,
             y_filter_8_11, round1_shift_s32, offset_s32, correction_s32);
 
@@ -1124,3 +1124,484 @@
         bd, y_offset_initial, y_offset_correction_s32);
   }
 }
+
+static INLINE void highbd_convolve_2d_x_scale_8tap_neon(
+    const uint16_t *src_ptr, int src_stride, uint16_t *dst_ptr, int dst_stride,
+    int w, int h, const int subpel_x_qn, const int x_step_qn,
+    const InterpFilterParams *filter_params, ConvolveParams *conv_params,
+    const int offset) {
+  const uint32x4_t idx = { 0, 1, 2, 3 };
+  const uint32x4_t subpel_mask = vdupq_n_u32(SCALE_SUBPEL_MASK);
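+  // round_0 is negated so that vqrshlq_s32 in the convolution helper
+  // performs a saturating rounding right shift.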
+  const int32x4_t shift_s32 = vdupq_n_s32(-conv_params->round_0);
+  const int32x4_t offset_s32 = vdupq_n_s32(offset);
+
+  if (w <= 4) {
+    int height = h;
+    int16x8_t s0, s1, s2, s3;
+    uint16x4_t d0;
+
+    uint16_t *d = dst_ptr;
+
+    do {
+      int x_qn = subpel_x_qn;
+
+      // Load 4 src vectors at a time; they might be the same, but we have to
+      // calculate the indices anyway. Doing it in SIMD and then storing the
+      // indices is faster than calculating the expression
+      // &src_ptr[((x_qn + 0*x_step_qn) >> SCALE_SUBPEL_BITS)] 4 times.
+      // Ideally this would be a gather using the indices, but NEON does not
+      // have one, so we have to emulate it.
+      const uint32x4_t xqn_idx = vmlaq_n_u32(vdupq_n_u32(x_qn), idx, x_step_qn);
+      // We have to multiply the indices by 2 to get byte offsets, since
+      // sizeof(uint16_t) == 2.
+      const uint32x4_t src_idx_u32 =
+          vshlq_n_u32(vshrq_n_u32(xqn_idx, SCALE_SUBPEL_BITS), 1);
+#if defined(__aarch64__)
+      uint64x2_t src4[2];
+      src4[0] = vaddw_u32(vdupq_n_u64((const uint64_t)src_ptr),
+                          vget_low_u32(src_idx_u32));
+      src4[1] = vaddw_u32(vdupq_n_u64((const uint64_t)src_ptr),
+                          vget_high_u32(src_idx_u32));
+      int16_t *src4_ptr[4];
+      uint64_t *tmp_ptr = (uint64_t *)&src4_ptr;
+      vst1q_u64(tmp_ptr, src4[0]);
+      vst1q_u64(tmp_ptr + 2, src4[1]);
+#else
+      uint32x4_t src4;
+      src4 = vaddq_u32(vdupq_n_u32((const uint32_t)src_ptr), src_idx_u32);
+      int16_t *src4_ptr[4];
+      uint32_t *tmp_ptr = (uint32_t *)&src4_ptr;
+      vst1q_u32(tmp_ptr, src4);
+#endif  // defined(__aarch64__)
+      // Same for the filter vectors
+      const int32x4_t filter_idx_s32 = vreinterpretq_s32_u32(
+          vshrq_n_u32(vandq_u32(xqn_idx, subpel_mask), SCALE_EXTRA_BITS));
+      int32_t x_filter4_idx[4];
+      vst1q_s32(x_filter4_idx, filter_idx_s32);
+      const int16_t *x_filter4_ptr[4];
+
+      // Load source
+      s0 = vld1q_s16(src4_ptr[0]);
+      s1 = vld1q_s16(src4_ptr[1]);
+      s2 = vld1q_s16(src4_ptr[2]);
+      s3 = vld1q_s16(src4_ptr[3]);
+
+      // We could easily do this using SIMD as well instead of calling the
+      // inline function 4 times.
+      x_filter4_ptr[0] =
+          av1_get_interp_filter_subpel_kernel(filter_params, x_filter4_idx[0]);
+      x_filter4_ptr[1] =
+          av1_get_interp_filter_subpel_kernel(filter_params, x_filter4_idx[1]);
+      x_filter4_ptr[2] =
+          av1_get_interp_filter_subpel_kernel(filter_params, x_filter4_idx[2]);
+      x_filter4_ptr[3] =
+          av1_get_interp_filter_subpel_kernel(filter_params, x_filter4_idx[3]);
+
+      // Actually load the filters
+      const int16x8_t x_filter0 = vld1q_s16(x_filter4_ptr[0]);
+      const int16x8_t x_filter1 = vld1q_s16(x_filter4_ptr[1]);
+      const int16x8_t x_filter2 = vld1q_s16(x_filter4_ptr[2]);
+      const int16x8_t x_filter3 = vld1q_s16(x_filter4_ptr[3]);
+
+      // Group low and high parts and transpose
+      int16x4_t filters_lo[] = { vget_low_s16(x_filter0),
+                                 vget_low_s16(x_filter1),
+                                 vget_low_s16(x_filter2),
+                                 vget_low_s16(x_filter3) };
+      int16x4_t filters_hi[] = { vget_high_s16(x_filter0),
+                                 vget_high_s16(x_filter1),
+                                 vget_high_s16(x_filter2),
+                                 vget_high_s16(x_filter3) };
+      transpose_u16_4x4((uint16x4_t *)filters_lo);
+      transpose_u16_4x4((uint16x4_t *)filters_hi);
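+      // After the transposes, lane j of filters_lo[k] / filters_hi[k] holds
+      // tap k / tap k + 4 of the filter for output pixel j, which is the
+      // layout expected by highbd_convolve8_2d_scale_horiz4x8_s32_s16().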
+
+      // Run the 2D Scale X convolution
+      d0 = highbd_convolve8_2d_scale_horiz4x8_s32_s16(
+          s0, s1, s2, s3, filters_lo, filters_hi, shift_s32, offset_s32);
+
+      if (w == 2) {
+        store_u16_2x1(d + 0 * dst_stride, d0, 0);
+      } else {
+        vst1_u16(d + 0 * dst_stride, d0);
+      }
+
+      src_ptr += src_stride;
+      d += dst_stride;
+      height--;
+    } while (height > 0);
+  } else {
+    int height = h;
+    int16x8_t s0, s1, s2, s3;
+    uint16x4_t d0;
+
+    do {
+      int width = w;
+      int x_qn = subpel_x_qn;
+      uint16_t *d = dst_ptr;
+      const uint16_t *s = src_ptr;
+
+      do {
+        // Load 4 src vectors at a time; they might be the same, but we have
+        // to calculate the indices anyway. Doing it in SIMD and then storing
+        // the indices is faster than calculating the expression
+        // &src_ptr[((x_qn + 0*x_step_qn) >> SCALE_SUBPEL_BITS)] 4 times.
+        // Ideally this would be a gather using the indices, but NEON does not
+        // have one, so we have to emulate it.
+        const uint32x4_t xqn_idx =
+            vmlaq_n_u32(vdupq_n_u32(x_qn), idx, x_step_qn);
+        // We have to multiply the indices by 2 to get byte offsets, since
+        // sizeof(uint16_t) == 2.
+        const uint32x4_t src_idx_u32 =
+            vshlq_n_u32(vshrq_n_u32(xqn_idx, SCALE_SUBPEL_BITS), 1);
+#if defined(__aarch64__)
+        uint64x2_t src4[2];
+        src4[0] = vaddw_u32(vdupq_n_u64((const uint64_t)s),
+                            vget_low_u32(src_idx_u32));
+        src4[1] = vaddw_u32(vdupq_n_u64((const uint64_t)s),
+                            vget_high_u32(src_idx_u32));
+        int16_t *src4_ptr[4];
+        uint64_t *tmp_ptr = (uint64_t *)&src4_ptr;
+        vst1q_u64(tmp_ptr, src4[0]);
+        vst1q_u64(tmp_ptr + 2, src4[1]);
+#else
+        uint32x4_t src4;
+        src4 = vaddq_u32(vdupq_n_u32((const uint32_t)s), src_idx_u32);
+        int16_t *src4_ptr[4];
+        uint32_t *tmp_ptr = (uint32_t *)&src4_ptr;
+        vst1q_u32(tmp_ptr, src4);
+#endif  // defined(__aarch64__)
+        // Same for the filter vectors
+        const int32x4_t filter_idx_s32 = vreinterpretq_s32_u32(
+            vshrq_n_u32(vandq_u32(xqn_idx, subpel_mask), SCALE_EXTRA_BITS));
+        int32_t x_filter4_idx[4];
+        vst1q_s32(x_filter4_idx, filter_idx_s32);
+        const int16_t *x_filter4_ptr[4];
+
+        // Load source
+        s0 = vld1q_s16(src4_ptr[0]);
+        s1 = vld1q_s16(src4_ptr[1]);
+        s2 = vld1q_s16(src4_ptr[2]);
+        s3 = vld1q_s16(src4_ptr[3]);
+
+        // We could easily do this using SIMD as well instead of calling the
+        // inline function 4 times.
+        x_filter4_ptr[0] = av1_get_interp_filter_subpel_kernel(
+            filter_params, x_filter4_idx[0]);
+        x_filter4_ptr[1] = av1_get_interp_filter_subpel_kernel(
+            filter_params, x_filter4_idx[1]);
+        x_filter4_ptr[2] = av1_get_interp_filter_subpel_kernel(
+            filter_params, x_filter4_idx[2]);
+        x_filter4_ptr[3] = av1_get_interp_filter_subpel_kernel(
+            filter_params, x_filter4_idx[3]);
+
+        // Actually load the filters
+        const int16x8_t x_filter0 = vld1q_s16(x_filter4_ptr[0]);
+        const int16x8_t x_filter1 = vld1q_s16(x_filter4_ptr[1]);
+        const int16x8_t x_filter2 = vld1q_s16(x_filter4_ptr[2]);
+        const int16x8_t x_filter3 = vld1q_s16(x_filter4_ptr[3]);
+
+        // Group low and high parts and transpose
+        int16x4_t filters_lo[] = { vget_low_s16(x_filter0),
+                                   vget_low_s16(x_filter1),
+                                   vget_low_s16(x_filter2),
+                                   vget_low_s16(x_filter3) };
+        int16x4_t filters_hi[] = { vget_high_s16(x_filter0),
+                                   vget_high_s16(x_filter1),
+                                   vget_high_s16(x_filter2),
+                                   vget_high_s16(x_filter3) };
+        transpose_u16_4x4((uint16x4_t *)filters_lo);
+        transpose_u16_4x4((uint16x4_t *)filters_hi);
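+        // See the comment in the w <= 4 path above for the post-transpose
+        // filter layout.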
+
+        // Run the 2D Scale X convolution
+        d0 = highbd_convolve8_2d_scale_horiz4x8_s32_s16(
+            s0, s1, s2, s3, filters_lo, filters_hi, shift_s32, offset_s32);
+
+        vst1_u16(d, d0);
+
+        x_qn += 4 * x_step_qn;
+        d += 4;
+        width -= 4;
+      } while (width > 0);
+
+      src_ptr += src_stride;
+      dst_ptr += dst_stride;
+      height--;
+    } while (height > 0);
+  }
+}
+
+static INLINE void highbd_convolve_2d_y_scale_8tap_neon(
+    const uint16_t *src_ptr, int src_stride, uint16_t *dst_ptr, int dst_stride,
+    int w, int h, const int subpel_y_qn, const int y_step_qn,
+    const InterpFilterParams *filter_params, const int round1_bits,
+    const int offset) {
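+  // The compound offset (1 << offset) is added during this pass; the caller
+  // subtracts the matching correction, so the correction argument passed to
+  // the convolution helpers below is zero.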
+  const int32x4_t offset_s32 = vdupq_n_s32(1 << offset);
+
+  const int32x4_t round1_shift_s32 = vdupq_n_s32(-round1_bits);
+  if (w <= 4) {
+    int height = h;
+    int16x4_t s0, s1, s2, s3, s4, s5, s6, s7;
+    uint16x4_t d0;
+
+    uint16_t *d = dst_ptr;
+
+    int y_qn = subpel_y_qn;
+    do {
+      const int16_t *s =
+          (const int16_t *)&src_ptr[(y_qn >> SCALE_SUBPEL_BITS) * src_stride];
+
+      load_s16_4x8(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);
+
+      const int y_filter_idx = (y_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS;
+      const int16_t *y_filter_ptr =
+          av1_get_interp_filter_subpel_kernel(filter_params, y_filter_idx);
+      const int16x8_t y_filter = vld1q_s16(y_filter_ptr);
+
+      d0 = highbd_convolve8_4_sr_s32_s16(s0, s1, s2, s3, s4, s5, s6, s7,
+                                         y_filter, round1_shift_s32, offset_s32,
+                                         vdupq_n_s32(0));
+
+      if (w == 2) {
+        store_u16_2x1(d, d0, 0);
+      } else {
+        vst1_u16(d, d0);
+      }
+
+      y_qn += y_step_qn;
+      d += dst_stride;
+      height--;
+    } while (height > 0);
+  } else {
+    int width = w;
+    int16x8_t s0, s1, s2, s3, s4, s5, s6, s7;
+    uint16x8_t d0;
+
+    do {
+      int height = h;
+      int y_qn = subpel_y_qn;
+
+      uint16_t *d = dst_ptr;
+
+      do {
+        const int16_t *s =
+            (const int16_t *)&src_ptr[(y_qn >> SCALE_SUBPEL_BITS) * src_stride];
+        load_s16_8x8(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);
+
+        const int y_filter_idx = (y_qn & SCALE_SUBPEL_MASK) >> SCALE_EXTRA_BITS;
+        const int16_t *y_filter_ptr =
+            av1_get_interp_filter_subpel_kernel(filter_params, y_filter_idx);
+        const int16x8_t y_filter = vld1q_s16(y_filter_ptr);
+
+        d0 = highbd_convolve8_8_sr_s32_s16(s0, s1, s2, s3, s4, s5, s6, s7,
+                                           y_filter, round1_shift_s32,
+                                           offset_s32, vdupq_n_s32(0));
+        vst1q_u16(d, d0);
+
+        y_qn += y_step_qn;
+        d += dst_stride;
+        height--;
+      } while (height > 0);
+      src_ptr += 8;
+      dst_ptr += 8;
+      width -= 8;
+    } while (width > 0);
+  }
+}
+
+void av1_highbd_convolve_2d_scale_neon(
+    const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride, int w,
+    int h, const InterpFilterParams *filter_params_x,
+    const InterpFilterParams *filter_params_y, const int subpel_x_qn,
+    const int x_step_qn, const int subpel_y_qn, const int y_step_qn,
+    ConvolveParams *conv_params, int bd) {
+  uint16_t *im_block = (uint16_t *)aom_memalign(
+      16, 2 * sizeof(uint16_t) * MAX_SB_SIZE * (MAX_SB_SIZE + MAX_FILTER_TAP));
+  if (!im_block) return;
+  uint16_t *im_block2 = (uint16_t *)aom_memalign(
+      16, 2 * sizeof(uint16_t) * MAX_SB_SIZE * (MAX_SB_SIZE + MAX_FILTER_TAP));
+  if (!im_block2) {
+    aom_free(im_block);  // free the first block and return.
+    return;
+  }
+
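+  // Intermediate height for the horizontal pass: the last output row reads
+  // filter_params_y->taps source rows starting at row
+  // ((h - 1) * y_step_qn + subpel_y_qn) >> SCALE_SUBPEL_BITS.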
+  int im_h = (((h - 1) * y_step_qn + subpel_y_qn) >> SCALE_SUBPEL_BITS) +
+             filter_params_y->taps;
+  const int im_stride = MAX_SB_SIZE;
+  CONV_BUF_TYPE *dst16 = conv_params->dst;
+  const int dst16_stride = conv_params->dst_stride;
+  const int bits =
+      FILTER_BITS * 2 - conv_params->round_0 - conv_params->round_1;
+  assert(bits >= 0);
+
+  const int vert_offset = filter_params_y->taps / 2 - 1;
+  const int horiz_offset = filter_params_x->taps / 2 - 1;
+  const int x_offset_bits = (1 << (bd + FILTER_BITS - 1));
+  const int y_offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
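+  // Rounding offset accumulated by the two convolution passes; it is
+  // subtracted again before the final shift.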
+  const int y_offset_correction =
+      ((1 << (y_offset_bits - conv_params->round_1)) +
+       (1 << (y_offset_bits - conv_params->round_1 - 1)));
+
+  const int32x4_t final_shift_s32 = vdupq_n_s32(-bits);
+  const int16x4_t y_offset_correction_s16 = vdup_n_s16(y_offset_correction);
+  const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);
+  uint16x4_t fwd_offset_u16 = vdup_n_u16(conv_params->fwd_offset);
+  uint16x4_t bck_offset_u16 = vdup_n_u16(conv_params->bck_offset);
+
+  const uint16_t *src_ptr = src - vert_offset * src_stride - horiz_offset;
+
+  highbd_convolve_2d_x_scale_8tap_neon(
+      src_ptr, src_stride, im_block, im_stride, w, im_h, subpel_x_qn, x_step_qn,
+      filter_params_x, conv_params, x_offset_bits);
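+  // For compound without averaging, the vertical pass can write directly to
+  // the CONV_BUF destination; otherwise it writes to a second intermediate
+  // buffer and the averaging / final rounding is done below.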
+  if (conv_params->is_compound && !conv_params->do_average) {
+    highbd_convolve_2d_y_scale_8tap_neon(
+        im_block, im_stride, dst16, dst16_stride, w, h, subpel_y_qn, y_step_qn,
+        filter_params_y, conv_params->round_1, y_offset_bits);
+  } else {
+    highbd_convolve_2d_y_scale_8tap_neon(
+        im_block, im_stride, im_block2, im_stride, w, h, subpel_y_qn, y_step_qn,
+        filter_params_y, conv_params->round_1, y_offset_bits);
+  }
+
+  // Do the compound averaging outside the loop; this avoids branching within
+  // the main loop.
+  if (conv_params->is_compound) {
+    if (conv_params->do_average) {
+      if (conv_params->use_dist_wtd_comp_avg) {
+        // Weighted averaging
+        if (w <= 4) {
+          for (int y = 0; y < h; ++y) {
+            const uint16x4_t s = vld1_u16(im_block2 + y * im_stride);
+            const uint16x4_t d16 = vld1_u16(dst16 + y * dst16_stride);
+            // We use vmull_u16/vmlal_u16 instead of vmull_s16/vmlal_s16
+            // because the latter sign-extend and the values are non-negative.
+            // However, d0 holds signed integers and we use vqmovun to do
+            // saturated narrowing to unsigned.
+            int32x4_t d0 =
+                vreinterpretq_s32_u32(vmull_u16(d16, fwd_offset_u16));
+            d0 = vreinterpretq_s32_u32(
+                vmlal_u16(vreinterpretq_u32_s32(d0), s, bck_offset_u16));
+            d0 = vshrq_n_s32(d0, DIST_PRECISION_BITS);
+            // Subtract the rounding offset and apply the final rounding shift
+            d0 = vqrshlq_s32(vsubw_s16(d0, y_offset_correction_s16),
+                             final_shift_s32);
+            uint16x4_t d = vqmovun_s32(d0);
+            d = vmin_u16(d, vget_low_u16(max));
+            if (w == 2) {
+              store_u16_2x1(dst + y * dst_stride, d, 0);
+            } else {
+              vst1_u16(dst + y * dst_stride, d);
+            }
+          }
+        } else {
+          for (int y = 0; y < h; ++y) {
+            for (int x = 0; x < w; x += 8) {
+              const uint16x8_t s = vld1q_u16(im_block2 + y * im_stride + x);
+              const uint16x8_t d16 = vld1q_u16(dst16 + y * dst16_stride + x);
+              // We use vmull_u16/vmlal_u16 instead of vmull_s16/vmlal_s16
+              // because the latter sign-extend and the values are
+              // non-negative. However, d0/d1 hold signed integers and we use
+              // vqmovun to do saturated narrowing to unsigned.
+              int32x4_t d0 = vreinterpretq_s32_u32(
+                  vmull_u16(vget_low_u16(d16), fwd_offset_u16));
+              int32x4_t d1 = vreinterpretq_s32_u32(
+                  vmull_u16(vget_high_u16(d16), fwd_offset_u16));
+              d0 = vreinterpretq_s32_u32(vmlal_u16(
+                  vreinterpretq_u32_s32(d0), vget_low_u16(s), bck_offset_u16));
+              d1 = vreinterpretq_s32_u32(vmlal_u16(
+                  vreinterpretq_u32_s32(d1), vget_high_u16(s), bck_offset_u16));
+              d0 = vshrq_n_s32(d0, DIST_PRECISION_BITS);
+              d1 = vshrq_n_s32(d1, DIST_PRECISION_BITS);
+              d0 = vqrshlq_s32(vsubw_s16(d0, y_offset_correction_s16),
+                               final_shift_s32);
+              d1 = vqrshlq_s32(vsubw_s16(d1, y_offset_correction_s16),
+                               final_shift_s32);
+              uint16x8_t d01 = vcombine_u16(vqmovun_s32(d0), vqmovun_s32(d1));
+              d01 = vminq_u16(d01, max);
+              vst1q_u16(dst + y * dst_stride + x, d01);
+            }
+          }
+        }
+      } else {
+        if (w <= 4) {
+          for (int y = 0; y < h; ++y) {
+            const uint16x4_t s = vld1_u16(im_block2 + y * im_stride);
+            const uint16x4_t d16 = vld1_u16(dst16 + y * dst16_stride);
+            int32x4_t s_s32 = vreinterpretq_s32_u32(vmovl_u16(s));
+            int32x4_t d16_s32 = vreinterpretq_s32_u32(vmovl_u16(d16));
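+            // Basic compound average: vhaddq_s32 computes (s + d16) >> 1
+            // without risking intermediate overflow.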
+            int32x4_t d0 = vhaddq_s32(s_s32, d16_s32);
+            d0 = vsubw_s16(d0, y_offset_correction_s16);
+            d0 = vqrshlq_s32(d0, final_shift_s32);
+            uint16x4_t d = vqmovun_s32(d0);
+            d = vmin_u16(d, vget_low_u16(max));
+            if (w == 2) {
+              store_u16_2x1(dst + y * dst_stride, d, 0);
+            } else {
+              vst1_u16(dst + y * dst_stride, d);
+            }
+          }
+        } else {
+          for (int y = 0; y < h; ++y) {
+            for (int x = 0; x < w; x += 8) {
+              const uint16x8_t s = vld1q_u16(im_block2 + y * im_stride + x);
+              const uint16x8_t d16 = vld1q_u16(dst16 + y * dst16_stride + x);
+              int32x4_t s_lo =
+                  vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(s)));
+              int32x4_t s_hi =
+                  vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(s)));
+              int32x4_t d16_lo =
+                  vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(d16)));
+              int32x4_t d16_hi =
+                  vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(d16)));
+              int32x4_t d0 = vhaddq_s32(s_lo, d16_lo);
+              int32x4_t d1 = vhaddq_s32(s_hi, d16_hi);
+              d0 = vsubw_s16(d0, y_offset_correction_s16);
+              d1 = vsubw_s16(d1, y_offset_correction_s16);
+              d0 = vqrshlq_s32(d0, final_shift_s32);
+              d1 = vqrshlq_s32(d1, final_shift_s32);
+              uint16x8_t d01 = vcombine_u16(vqmovun_s32(d0), vqmovun_s32(d1));
+              d01 = vminq_u16(d01, max);
+              vst1q_u16(dst + y * dst_stride + x, d01);
+            }
+          }
+        }
+      }
+    }
+  } else {
+    // No compound averaging
+    if (w <= 4) {
+      for (int y = 0; y < h; ++y) {
+        // Subtract the rounding offset and apply the final rounding shift
+        const int16x4_t s =
+            vld1_s16((const int16_t *)(im_block2) + y * im_stride);
+        const int32x4_t d0 =
+            vqrshlq_s32(vsubl_s16(s, y_offset_correction_s16), final_shift_s32);
+        uint16x4_t d = vqmovun_s32(d0);
+        d = vmin_u16(d, vget_low_u16(max));
+        if (w == 2) {
+          store_u16_2x1(dst + y * dst_stride, d, 0);
+        } else {
+          vst1_u16(dst + y * dst_stride, d);
+        }
+      }
+    } else {
+      for (int y = 0; y < h; ++y) {
+        for (int x = 0; x < w; x += 8) {
+          // Subtract the rounding offset and apply the final rounding shift
+          const int16x8_t s =
+              vld1q_s16((const int16_t *)(im_block2) + y * im_stride + x);
+          const int32x4_t d0 =
+              vqrshlq_s32(vsubl_s16(vget_low_s16(s), y_offset_correction_s16),
+                          final_shift_s32);
+          const int32x4_t d1 =
+              vqrshlq_s32(vsubl_s16(vget_high_s16(s), y_offset_correction_s16),
+                          final_shift_s32);
+          uint16x8_t d01 = vcombine_u16(vqmovun_s32(d0), vqmovun_s32(d1));
+          d01 = vminq_u16(d01, max);
+          vst1q_u16(dst + y * dst_stride + x, d01);
+        }
+      }
+    }
+  }
+  aom_free(im_block);
+  aom_free(im_block2);
+}
diff --git a/av1/common/arm/highbd_convolve_neon.h b/av1/common/arm/highbd_convolve_neon.h
index 7934463..ed33be1 100644
--- a/av1/common/arm/highbd_convolve_neon.h
+++ b/av1/common/arm/highbd_convolve_neon.h
@@ -109,7 +109,7 @@
 }
 
 // Like above but also perform round shifting and subtract correction term
-static INLINE uint16x4_t highbd_convolve8_4_rs_s32_s16(
+static INLINE uint16x4_t highbd_convolve8_4_sr_s32_s16(
     const int16x4_t s0, const int16x4_t s1, const int16x4_t s2,
     const int16x4_t s3, const int16x4_t s4, const int16x4_t s5,
     const int16x4_t s6, const int16x4_t s7, const int16x8_t y_filter,
@@ -164,7 +164,7 @@
 }
 
 // Like above but also perform round shifting and subtract correction term
-static INLINE uint16x8_t highbd_convolve8_8_rs_s32_s16(
+static INLINE uint16x8_t highbd_convolve8_8_sr_s32_s16(
     const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
     const int16x8_t s3, const int16x8_t s4, const int16x8_t s5,
     const int16x8_t s6, const int16x8_t s7, const int16x8_t y_filter,
@@ -222,7 +222,7 @@
 }
 
 // Like above but also perform round shifting and subtract correction term
-static INLINE uint16x4_t highbd_convolve12_y_4_rs_s32_s16(
+static INLINE uint16x4_t highbd_convolve12_y_4_sr_s32_s16(
     const int16x4_t s0, const int16x4_t s1, const int16x4_t s2,
     const int16x4_t s3, const int16x4_t s4, const int16x4_t s5,
     const int16x4_t s6, const int16x4_t s7, const int16x4_t s8,
@@ -292,7 +292,7 @@
 }
 
 // Like above but also perform round shifting and subtract correction term
-static INLINE uint16x8_t highbd_convolve12_y_8_rs_s32_s16(
+static INLINE uint16x8_t highbd_convolve12_y_8_sr_s32_s16(
     const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
     const int16x8_t s3, const int16x8_t s4, const int16x8_t s5,
     const int16x8_t s6, const int16x8_t s7, const int16x8_t s8,
@@ -439,4 +439,40 @@
   return vcombine_u16(vqmovun_s32(sum0), vqmovun_s32(sum1));
 }
 
+static INLINE int32x4_t highbd_convolve8_2d_scale_horiz4x8_s32(
+    const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
+    const int16x8_t s3, const int16x4_t *filters_lo,
+    const int16x4_t *filters_hi, const int32x4_t offset) {
+  int16x4_t s_lo[] = { vget_low_s16(s0), vget_low_s16(s1), vget_low_s16(s2),
+                       vget_low_s16(s3) };
+  int16x4_t s_hi[] = { vget_high_s16(s0), vget_high_s16(s1), vget_high_s16(s2),
+                       vget_high_s16(s3) };
+
+  transpose_u16_4x4((uint16x4_t *)s_lo);
+  transpose_u16_4x4((uint16x4_t *)s_hi);
+
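+  // After the transposes, lane j of s_lo[k] / s_hi[k] holds source sample
+  // k / k + 4 for output pixel j, so each multiply-accumulate below applies
+  // one filter tap to all four output pixels.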
+  int32x4_t sum = vmlal_s16(offset, s_lo[0], filters_lo[0]);
+  sum = vmlal_s16(sum, s_lo[1], filters_lo[1]);
+  sum = vmlal_s16(sum, s_lo[2], filters_lo[2]);
+  sum = vmlal_s16(sum, s_lo[3], filters_lo[3]);
+  sum = vmlal_s16(sum, s_hi[0], filters_hi[0]);
+  sum = vmlal_s16(sum, s_hi[1], filters_hi[1]);
+  sum = vmlal_s16(sum, s_hi[2], filters_hi[2]);
+  sum = vmlal_s16(sum, s_hi[3], filters_hi[3]);
+
+  return sum;
+}
+
+static INLINE uint16x4_t highbd_convolve8_2d_scale_horiz4x8_s32_s16(
+    const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
+    const int16x8_t s3, const int16x4_t *filters_lo,
+    const int16x4_t *filters_hi, const int32x4_t shift_s32,
+    const int32x4_t offset) {
+  int32x4_t sum = highbd_convolve8_2d_scale_horiz4x8_s32(
+      s0, s1, s2, s3, filters_lo, filters_hi, offset);
+
+  sum = vqrshlq_s32(sum, shift_s32);
+  return vqmovun_s32(sum);
+}
+
 #endif  // AOM_AV1_COMMON_ARM_HIGHBD_CONVOLVE_NEON_H_
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 9f9961a..31eb440 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -603,7 +603,7 @@
     specialize qw/av1_highbd_convolve_2d_sr ssse3 avx2 neon/;
     specialize qw/av1_highbd_convolve_x_sr ssse3 avx2 neon/;
     specialize qw/av1_highbd_convolve_y_sr ssse3 avx2 neon/;
-    specialize qw/av1_highbd_convolve_2d_scale sse4_1/;
+    specialize qw/av1_highbd_convolve_2d_scale sse4_1 neon/;
   }
 
 # INTRA_EDGE functions
diff --git a/test/av1_convolve_scale_test.cc b/test/av1_convolve_scale_test.cc
index 3f35025..00f0a09 100644
--- a/test/av1_convolve_scale_test.cc
+++ b/test/av1_convolve_scale_test.cc
@@ -455,11 +455,20 @@
 TEST_P(LowBDConvolveScaleTest, DISABLED_Speed) { SpeedTest(); }
 
 INSTANTIATE_TEST_SUITE_P(
+    C, LowBDConvolveScaleTest,
+    ::testing::Combine(::testing::Values(av1_convolve_2d_scale_c),
+                       ::testing::ValuesIn(kBlockDim),
+                       ::testing::ValuesIn(kNTaps), ::testing::ValuesIn(kNTaps),
+                       ::testing::Bool()));
+
+#if HAVE_SSE4_1
+INSTANTIATE_TEST_SUITE_P(
     SSE4_1, LowBDConvolveScaleTest,
     ::testing::Combine(::testing::Values(av1_convolve_2d_scale_sse4_1),
                        ::testing::ValuesIn(kBlockDim),
                        ::testing::ValuesIn(kNTaps), ::testing::ValuesIn(kNTaps),
                        ::testing::Bool()));
+#endif  // HAVE_SSE4_1
 
 #if CONFIG_AV1_HIGHBITDEPTH
 typedef void (*HighbdConvolveFunc)(const uint16_t *src, int src_stride,
@@ -522,10 +531,30 @@
 TEST_P(HighBDConvolveScaleTest, DISABLED_Speed) { SpeedTest(); }
 
 INSTANTIATE_TEST_SUITE_P(
+    C, HighBDConvolveScaleTest,
+    ::testing::Combine(::testing::Values(av1_highbd_convolve_2d_scale_c),
+                       ::testing::ValuesIn(kBlockDim),
+                       ::testing::ValuesIn(kNTaps), ::testing::ValuesIn(kNTaps),
+                       ::testing::Bool(), ::testing::ValuesIn(kBDs)));
+
+#if HAVE_SSE4_1
+INSTANTIATE_TEST_SUITE_P(
     SSE4_1, HighBDConvolveScaleTest,
     ::testing::Combine(::testing::Values(av1_highbd_convolve_2d_scale_sse4_1),
                        ::testing::ValuesIn(kBlockDim),
                        ::testing::ValuesIn(kNTaps), ::testing::ValuesIn(kNTaps),
                        ::testing::Bool(), ::testing::ValuesIn(kBDs)));
+#endif  // HAVE_SSE4_1
+
+#if HAVE_NEON
+INSTANTIATE_TEST_SUITE_P(
+    NEON, HighBDConvolveScaleTest,
+    ::testing::Combine(::testing::Values(av1_highbd_convolve_2d_scale_neon),
+                       ::testing::ValuesIn(kBlockDim),
+                       ::testing::ValuesIn(kNTaps), ::testing::ValuesIn(kNTaps),
+                       ::testing::Bool(), ::testing::ValuesIn(kBDs)));
+#endif  // HAVE_NEON
+
 #endif  // CONFIG_AV1_HIGHBITDEPTH
 }  // namespace
diff --git a/test/test.cmake b/test/test.cmake
index 8e5cb87..51b4e47 100644
--- a/test/test.cmake
+++ b/test/test.cmake
@@ -354,6 +354,11 @@
 
   endif()
 
+  if(HAVE_NEON)
+    list(APPEND AOM_UNIT_TEST_ENCODER_SOURCES
+                "${AOM_ROOT}/test/av1_convolve_scale_test.cc")
+  endif()
+
   if(HAVE_SSE4_2 OR HAVE_ARM_CRC32)
     list(APPEND AOM_UNIT_TEST_ENCODER_SOURCES "${AOM_ROOT}/test/hash_test.cc")
   endif()