[NEON] Optimize av1_highbd_convolve_horiz_rs()

Gives an extra 1% speedup in superres encoding.

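NEON has no gather load, so the four per-column source and filter
pointers are computed in vector registers, stored out to small arrays
and then used for ordinary loads. As a rough scalar sketch of what the
vectorized index computation produces for one group of 4 output
columns (illustrative only, not part of the patch):

  const int16_t *src4_ptr[4];
  const int16_t *x_filter4_ptr[4];
  for (int i = 0; i < 4; ++i) {
    const int x = x_qn + i * x_step_qn;
    // Integer part of the subpel position selects the source pointer.
    src4_ptr[i] = (const int16_t *)&src_ptr[x >> RS_SCALE_SUBPEL_BITS];
    // Fractional part selects one of the 8-tap upscale filters.
    const int f_idx = (x & RS_SCALE_SUBPEL_MASK) >> RS_SCALE_EXTRA_BITS;
    x_filter4_ptr[i] = &x_filters[f_idx * UPSCALE_NORMATIVE_TAPS];
  }
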
Change-Id: I631e824e74b4864337da56d3bc760f1d8903768e
diff --git a/av1/common/arm/highbd_convolve_neon.c b/av1/common/arm/highbd_convolve_neon.c
index a9cd9b8..d7def2b 100644
--- a/av1/common/arm/highbd_convolve_neon.c
+++ b/av1/common/arm/highbd_convolve_neon.c
@@ -2127,3 +2127,254 @@
     }
   }
 }
+
+#define UPSCALE_NORMATIVE_TAPS 8
+
+void av1_highbd_convolve_horiz_rs_neon(const uint16_t *src, int src_stride,
+                                       uint16_t *dst, int dst_stride, int w,
+                                       int h, const int16_t *x_filters,
+                                       int x0_qn, int x_step_qn, int bd) {
+  const int horiz_offset = UPSCALE_NORMATIVE_TAPS / 2 - 1;
+
+  const int32x4_t idx = { 0, 1, 2, 3 };
+  const int32x4_t subpel_mask = vdupq_n_s32(RS_SCALE_SUBPEL_MASK);
+  const int32x4_t shift_s32 = vdupq_n_s32(-FILTER_BITS);
+  const int32x4_t offset_s32 = vdupq_n_s32(0);
+  const uint16x4_t max = vdup_n_u16((1 << bd) - 1);
+
+  const uint16_t *src_ptr = src - horiz_offset;
+  uint16_t *dst_ptr = dst;
+
+  if (w <= 4) {
+    int height = h;
+    int16x8_t s0, s1, s2, s3;
+    uint16x4_t d0;
+
+    uint16_t *d = dst_ptr;
+    do {
+      int x_qn = x0_qn;
+
+      // Load 4 src vectors at a time; they might be the same, but we have to
+      // calculate the indices anyway. Doing it in SIMD and then storing the
+      // indices is faster than having to calculate the expression
+      // &src_ptr[((x_qn + 0*x_step_qn) >> RS_SCALE_SUBPEL_BITS)] 4 times.
+      // Ideally this would be a gather using the indices, but NEON does not
+      // have one, so it has to be emulated.
+      const int32x4_t xqn_idx = vmlaq_n_s32(vdupq_n_s32(x_qn), idx, x_step_qn);
+      // Multiply by 2 to convert the element indices into byte offsets, since
+      // sizeof(uint16_t) == 2.
+      const int32x4_t src_idx =
+          vshlq_n_s32(vshrq_n_s32(xqn_idx, RS_SCALE_SUBPEL_BITS), 1);
+      // Similarly, calculate the filter indices for the 4 columns. First
+      // compute the indices:
+      // (x_qn & RS_SCALE_SUBPEL_MASK) >> RS_SCALE_EXTRA_BITS
+      // Then compute the actual pointers, multiplying by
+      // UPSCALE_NORMATIVE_TAPS and again shifting left by 1 to convert to a
+      // byte offset.
+      const int32x4_t x_filter4_idx = vshlq_n_s32(
+          vshrq_n_s32(vandq_s32(xqn_idx, subpel_mask), RS_SCALE_EXTRA_BITS), 1);
+      // Even though pointers are unsigned 32/64-bit ints, we do signed
+      // addition. The reason is that x_qn can be negative, leading to
+      // negative offsets. Argon test
+      // profile0_core/streams/test10573_11003.obu was failing because of
+      // this.
+#if defined(__aarch64__)
+      uint64x2_t tmp4[2];
+      tmp4[0] = vreinterpretq_u64_s64(vaddw_s32(
+          vdupq_n_s64((const int64_t)src_ptr), vget_low_s32(src_idx)));
+      tmp4[1] = vreinterpretq_u64_s64(vaddw_s32(
+          vdupq_n_s64((const int64_t)src_ptr), vget_high_s32(src_idx)));
+      int16_t *src4_ptr[4];
+      uint64_t *tmp_ptr = (uint64_t *)&src4_ptr;
+      vst1q_u64(tmp_ptr, tmp4[0]);
+      vst1q_u64(tmp_ptr + 2, tmp4[1]);
+
+      // filter vectors
+      tmp4[0] = vreinterpretq_u64_s64(vmlal_s32(
+          vdupq_n_s64((const int64_t)x_filters), vget_low_s32(x_filter4_idx),
+          vdup_n_s32(UPSCALE_NORMATIVE_TAPS)));
+      tmp4[1] = vreinterpretq_u64_s64(vmlal_s32(
+          vdupq_n_s64((const int64_t)x_filters), vget_high_s32(x_filter4_idx),
+          vdup_n_s32(UPSCALE_NORMATIVE_TAPS)));
+
+      const int16_t *x_filter4_ptr[4];
+      tmp_ptr = (uint64_t *)&x_filter4_ptr;
+      vst1q_u64(tmp_ptr, tmp4[0]);
+      vst1q_u64(tmp_ptr + 2, tmp4[1]);
+#else
+      uint32x4_t tmp4;
+      tmp4 = vreinterpretq_u32_s32(
+          vaddq_s32(vdupq_n_s32((const int32_t)src_ptr), src_idx));
+      int16_t *src4_ptr[4];
+      uint32_t *tmp_ptr = (uint32_t *)&src4_ptr;
+      vst1q_u32(tmp_ptr, tmp4);
+      // filter vectors
+      tmp4 = vreinterpretq_u32_s32(
+          vmlaq_s32(vdupq_n_s32((const int32_t)x_filters), x_filter4_idx,
+                    vdupq_n_s32(UPSCALE_NORMATIVE_TAPS)));
+
+      const int16_t *x_filter4_ptr[4];
+      tmp_ptr = (uint32_t *)&x_filter4_ptr;
+      vst1q_u32(tmp_ptr, tmp4);
+#endif  // defined(__aarch64__)
+      // Load source
+      s0 = vld1q_s16(src4_ptr[0]);
+      s1 = vld1q_s16(src4_ptr[1]);
+      s2 = vld1q_s16(src4_ptr[2]);
+      s3 = vld1q_s16(src4_ptr[3]);
+
+      // Actually load the filters
+      const int16x8_t x_filter0 = vld1q_s16(x_filter4_ptr[0]);
+      const int16x8_t x_filter1 = vld1q_s16(x_filter4_ptr[1]);
+      const int16x8_t x_filter2 = vld1q_s16(x_filter4_ptr[2]);
+      const int16x8_t x_filter3 = vld1q_s16(x_filter4_ptr[3]);
+
+      // Group low and high parts and transpose
+      int16x4_t filters_lo[] = { vget_low_s16(x_filter0),
+                                 vget_low_s16(x_filter1),
+                                 vget_low_s16(x_filter2),
+                                 vget_low_s16(x_filter3) };
+      int16x4_t filters_hi[] = { vget_high_s16(x_filter0),
+                                 vget_high_s16(x_filter1),
+                                 vget_high_s16(x_filter2),
+                                 vget_high_s16(x_filter3) };
+      transpose_u16_4x4((uint16x4_t *)filters_lo);
+      transpose_u16_4x4((uint16x4_t *)filters_hi);
+
+      // Run the 2D Scale convolution
+      d0 = highbd_convolve8_2d_scale_horiz4x8_s32_s16(
+          s0, s1, s2, s3, filters_lo, filters_hi, shift_s32, offset_s32);
+
+      d0 = vmin_u16(d0, max);
+
+      if (w == 2) {
+        store_u16_2x1(d + 0 * dst_stride, d0, 0);
+      } else {
+        vst1_u16(d + 0 * dst_stride, d0);
+      }
+
+      src_ptr += src_stride;
+      d += dst_stride;
+      height--;
+    } while (height > 0);
+  } else {
+    int height = h;
+    int16x8_t s0, s1, s2, s3;
+    uint16x4_t d0;
+
+    do {
+      int width = w;
+      int x_qn = x0_qn;
+      uint16_t *d = dst_ptr;
+      const uint16_t *s = src_ptr;
+
+      do {
+        // Load 4 src vectors at a time; they might be the same, but we have
+        // to calculate the indices anyway. Doing it in SIMD and then storing
+        // the indices is faster than having to calculate the expression
+        // &src_ptr[((x_qn + 0*x_step_qn) >> RS_SCALE_SUBPEL_BITS)] 4 times.
+        // Ideally this would be a gather using the indices, but NEON does
+        // not have one, so it has to be emulated.
+        const int32x4_t xqn_idx =
+            vmlaq_n_s32(vdupq_n_s32(x_qn), idx, x_step_qn);
+        // Multiply by 2 to convert the element indices into byte offsets,
+        // since sizeof(uint16_t) == 2.
+        const int32x4_t src_idx =
+            vshlq_n_s32(vshrq_n_s32(xqn_idx, RS_SCALE_SUBPEL_BITS), 1);
+
+        // Similarly, calculate the filter indices for the 4 columns. First
+        // compute the indices:
+        // (x_qn & RS_SCALE_SUBPEL_MASK) >> RS_SCALE_EXTRA_BITS
+        // Then compute the actual pointers, multiplying by
+        // UPSCALE_NORMATIVE_TAPS and again shifting left by 1 to convert to
+        // a byte offset.
+        const int32x4_t x_filter4_idx = vshlq_n_s32(
+            vshrq_n_s32(vandq_s32(xqn_idx, subpel_mask), RS_SCALE_EXTRA_BITS),
+            1);
+        // Even though pointers are unsigned 32/64-bit ints, we do signed
+        // addition. The reason is that x_qn can be negative, leading to
+        // negative offsets. Argon test
+        // profile0_core/streams/test10573_11003.obu was failing because of
+        // this.
+#if defined(__aarch64__)
+        uint64x2_t tmp4[2];
+        tmp4[0] = vreinterpretq_u64_s64(
+            vaddw_s32(vdupq_n_s64((const int64_t)s), vget_low_s32(src_idx)));
+        tmp4[1] = vreinterpretq_u64_s64(
+            vaddw_s32(vdupq_n_s64((const int64_t)s), vget_high_s32(src_idx)));
+        int16_t *src4_ptr[4];
+        uint64_t *tmp_ptr = (uint64_t *)&src4_ptr;
+        vst1q_u64(tmp_ptr, tmp4[0]);
+        vst1q_u64(tmp_ptr + 2, tmp4[1]);
+
+        // filter vectors
+        tmp4[0] = vreinterpretq_u64_s64(vmlal_s32(
+            vdupq_n_s64((const int64_t)x_filters), vget_low_s32(x_filter4_idx),
+            vdup_n_s32(UPSCALE_NORMATIVE_TAPS)));
+        tmp4[1] = vreinterpretq_u64_s64(vmlal_s32(
+            vdupq_n_s64((const int64_t)x_filters), vget_high_s32(x_filter4_idx),
+            vdup_n_s32(UPSCALE_NORMATIVE_TAPS)));
+
+        const int16_t *x_filter4_ptr[4];
+        tmp_ptr = (uint64_t *)&x_filter4_ptr;
+        vst1q_u64(tmp_ptr, tmp4[0]);
+        vst1q_u64(tmp_ptr + 2, tmp4[1]);
+#else
+        uint32x4_t tmp4;
+        tmp4 = vreinterpretq_u32_s32(
+            vaddq_s32(vdupq_n_s32((const int32_t)s), src_idx));
+        int16_t *src4_ptr[4];
+        uint32_t *tmp_ptr = (uint32_t *)&src4_ptr;
+        vst1q_u32(tmp_ptr, tmp4);
+        // filter vectors
+        tmp4 = vreinterpretq_u32_s32(
+            vmlaq_s32(vdupq_n_s32((const int32_t)x_filters), x_filter4_idx,
+                      vdupq_n_s32(UPSCALE_NORMATIVE_TAPS)));
+
+        const int16_t *x_filter4_ptr[4];
+        tmp_ptr = (uint32_t *)&x_filter4_ptr;
+        vst1q_u32(tmp_ptr, tmp4);
+#endif  // defined(__aarch64__)
+
+        // Load source
+        s0 = vld1q_s16(src4_ptr[0]);
+        s1 = vld1q_s16(src4_ptr[1]);
+        s2 = vld1q_s16(src4_ptr[2]);
+        s3 = vld1q_s16(src4_ptr[3]);
+
+        // Actually load the filters
+        const int16x8_t x_filter0 = vld1q_s16(x_filter4_ptr[0]);
+        const int16x8_t x_filter1 = vld1q_s16(x_filter4_ptr[1]);
+        const int16x8_t x_filter2 = vld1q_s16(x_filter4_ptr[2]);
+        const int16x8_t x_filter3 = vld1q_s16(x_filter4_ptr[3]);
+
+        // Group low and high parts and transpose
+        int16x4_t filters_lo[] = { vget_low_s16(x_filter0),
+                                   vget_low_s16(x_filter1),
+                                   vget_low_s16(x_filter2),
+                                   vget_low_s16(x_filter3) };
+        int16x4_t filters_hi[] = { vget_high_s16(x_filter0),
+                                   vget_high_s16(x_filter1),
+                                   vget_high_s16(x_filter2),
+                                   vget_high_s16(x_filter3) };
+        transpose_u16_4x4((uint16x4_t *)filters_lo);
+        transpose_u16_4x4((uint16x4_t *)filters_hi);
+
+        // Run the 2D Scale X convolution
+        d0 = highbd_convolve8_2d_scale_horiz4x8_s32_s16(
+            s0, s1, s2, s3, filters_lo, filters_hi, shift_s32, offset_s32);
+
+        d0 = vmin_u16(d0, max);
+        vst1_u16(d, d0);
+
+        x_qn += 4 * x_step_qn;
+        d += 4;
+        width -= 4;
+      } while (width > 0);
+
+      src_ptr += src_stride;
+      dst_ptr += dst_stride;
+      height--;
+    } while (height > 0);
+  }
+}
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 1eae7db..4c71827 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -92,7 +92,7 @@
 
 if(aom_config("CONFIG_AV1_HIGHBITDEPTH") eq "yes") {
   add_proto qw/void av1_highbd_convolve_horiz_rs/, "const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride, int w, int h, const int16_t *x_filters, int x0_qn, int x_step_qn, int bd";
-  specialize qw/av1_highbd_convolve_horiz_rs sse4_1/;
+  specialize qw/av1_highbd_convolve_horiz_rs sse4_1 neon/;
 
   add_proto qw/void av1_highbd_wiener_convolve_add_src/, "const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst, ptrdiff_t dst_stride, const int16_t *filter_x, int x_step_q4, const int16_t *filter_y, int y_step_q4, int w, int h, const ConvolveParams *conv_params, int bd";
   specialize qw/av1_highbd_wiener_convolve_add_src ssse3 avx2/;
diff --git a/test/av1_horz_only_frame_superres_test.cc b/test/av1_horz_only_frame_superres_test.cc
index f503b63..28ee534 100644
--- a/test/av1_horz_only_frame_superres_test.cc
+++ b/test/av1_horz_only_frame_superres_test.cc
@@ -299,8 +299,13 @@
 TEST_P(LowBDConvolveHorizRSTest, Correctness) { CorrectnessTest(); }
 TEST_P(LowBDConvolveHorizRSTest, DISABLED_Speed) { SpeedTest(); }
 
+INSTANTIATE_TEST_SUITE_P(C, LowBDConvolveHorizRSTest,
+                         ::testing::Values(av1_convolve_horiz_rs_c));
+
+#if HAVE_SSE4_1
 INSTANTIATE_TEST_SUITE_P(SSE4_1, LowBDConvolveHorizRSTest,
                          ::testing::Values(av1_convolve_horiz_rs_sse4_1));
+#endif
 
 #if CONFIG_AV1_HIGHBITDEPTH
 typedef void (*HighBDConvolveHorizRsFunc)(const uint16_t *src, int src_stride,
@@ -358,9 +363,24 @@
 TEST_P(HighBDConvolveHorizRSTest, DISABLED_Speed) { SpeedTest(); }
 
 INSTANTIATE_TEST_SUITE_P(
+    C, HighBDConvolveHorizRSTest,
+    ::testing::Combine(::testing::Values(av1_highbd_convolve_horiz_rs_c),
+                       ::testing::ValuesIn(kBDs)));
+
+#if HAVE_SSE4_1
+INSTANTIATE_TEST_SUITE_P(
     SSE4_1, HighBDConvolveHorizRSTest,
     ::testing::Combine(::testing::Values(av1_highbd_convolve_horiz_rs_sse4_1),
                        ::testing::ValuesIn(kBDs)));
+#endif  // HAVE_SSE4_1
+
+#if HAVE_NEON
+INSTANTIATE_TEST_SUITE_P(
+    NEON, HighBDConvolveHorizRSTest,
+    ::testing::Combine(::testing::Values(av1_highbd_convolve_horiz_rs_neon),
+                       ::testing::ValuesIn(kBDs)));
+#endif  // HAVE_NEON
+
 #endif  // CONFIG_AV1_HIGHBITDEPTH
 
 }  // namespace
diff --git a/test/test.cmake b/test/test.cmake
index 51b4e47..7c836b0 100644
--- a/test/test.cmake
+++ b/test/test.cmake
@@ -356,7 +356,8 @@
 
   if(HAVE_NEON)
     list(APPEND AOM_UNIT_TEST_ENCODER_SOURCES
-                "${AOM_ROOT}/test/av1_convolve_scale_test.cc")
+                "${AOM_ROOT}/test/av1_convolve_scale_test.cc"
+                "${AOM_ROOT}/test/av1_horz_only_frame_superres_test.cc")
   endif()
 
   if(HAVE_SSE4_2 OR HAVE_ARM_CRC32)