Optimize av1_convolve_x_sr_neon using SDOT instruction

Add an alternative AArch64 implementation of av1_convolve_x_sr_neon
for targets that implement the Armv8.4-A SDOT (signed dot-product)
instruction.

The existing MLA implementation of av1_convolve_x_sr_neon is retained
for use on target CPUs that do not implement the SDOT instruction (or
that execute in AArch32 mode). The availability of the SDOT
instruction is indicated by the feature macro __ARM_FEATURE_DOTPROD.
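
The SDOT instruction operates on signed 8-bit inputs, so the new kernels
shift each unsigned sample into the signed range by subtracting 128 and
add a precomputed correction of 128 * sum(filter taps) back into the
dot-product accumulator. A minimal scalar sketch of that identity
(illustrative only; the helper name is hypothetical and not part of the
patch):

    #include <stdint.h>

    /* sum_k f[k]*s[k] == sum_k f[k]*(s[k] - 128) + 128*sum_k f[k] */
    static int32_t convolve8_clamped(const uint8_t *s, const int16_t *f) {
      int32_t correction = 0;
      for (int k = 0; k < 8; ++k) correction += 128 * f[k];

      int32_t sum = correction;
      for (int k = 0; k < 8; ++k) sum += f[k] * ((int32_t)s[k] - 128);
      return sum;  /* identical to a plain sum of f[k]*s[k] */
    }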

Change-Id: I5bbd56af03ce0c0a27c7ea1f221a7e61e4249e1a
diff --git a/av1/common/arm/convolve_neon.c b/av1/common/arm/convolve_neon.c
index b12c279..2d35669 100644
--- a/av1/common/arm/convolve_neon.c
+++ b/av1/common/arm/convolve_neon.c
@@ -44,30 +44,6 @@
   return sum;
 }
 
-static INLINE uint8x8_t convolve8_horiz_8x8(
-    const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
-    const int16x8_t s3, const int16x8_t s4, const int16x8_t s5,
-    const int16x8_t s6, const int16x8_t s7, const int16x8_t filter,
-    const int16x8_t shift_round_0, const int16x8_t shift_by_bits) {
-  const int16x4_t filter_lo = vget_low_s16(filter);
-  const int16x4_t filter_hi = vget_high_s16(filter);
-  int16x8_t sum;
-
-  sum = vmulq_lane_s16(s0, filter_lo, 0);
-  sum = vmlaq_lane_s16(sum, s1, filter_lo, 1);
-  sum = vmlaq_lane_s16(sum, s2, filter_lo, 2);
-  sum = vmlaq_lane_s16(sum, s3, filter_lo, 3);
-  sum = vmlaq_lane_s16(sum, s4, filter_hi, 0);
-  sum = vmlaq_lane_s16(sum, s5, filter_hi, 1);
-  sum = vmlaq_lane_s16(sum, s6, filter_hi, 2);
-  sum = vmlaq_lane_s16(sum, s7, filter_hi, 3);
-
-  sum = vqrshlq_s16(sum, shift_round_0);
-  sum = vqrshlq_s16(sum, shift_by_bits);
-
-  return vqmovun_s16(sum);
-}
-
 #if !defined(__aarch64__)
 static INLINE uint8x8_t convolve8_horiz_4x1(
     const int16x4_t s0, const int16x4_t s1, const int16x4_t s2,
@@ -182,6 +158,228 @@
   return vqmovun_s16(res);
 }
 
+#if defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
+
+static INLINE int32x4_t convolve8_4_dot(uint8x16_t samples,
+                                        const int8x8_t filters,
+                                        const int32x4_t correction,
+                                        const uint8x16_t range_limit,
+                                        const uint8x16x2_t permute_tbl) {
+  int8x16_t clamped_samples, permuted_samples[2];
+  int32x4_t sum;
+
+  /* Clamp sample range to [-128, 127] for 8-bit signed dot product. */
+  clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));
+
+  /* Permute samples ready for dot product. */
+  /* { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 } */
+  permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
+  /* { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 } */
+  permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
+
+  /* Accumulate dot product into 'correction' to account for range clamp. */
+  sum = vdotq_lane_s32(correction, permuted_samples[0], filters, 0);
+  sum = vdotq_lane_s32(sum, permuted_samples[1], filters, 1);
+
+  /* Narrowing and packing are performed by the caller. */
+  return sum;
+}
+
+static INLINE uint8x8_t convolve8_8_dot(
+    uint8x16_t samples, const int8x8_t filters, const int32x4_t correction,
+    const uint8x16_t range_limit, const uint8x16x3_t permute_tbl,
+    const int16x8_t shift_round_0, const int16x8_t shift_by_bits) {
+  int8x16_t clamped_samples, permuted_samples[3];
+  int32x4_t sum0, sum1;
+  int16x8_t sum;
+
+  /* Clamp sample range to [-128, 127] for 8-bit signed dot product. */
+  clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));
+
+  /* Permute samples ready for dot product. */
+  /* { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 } */
+  permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
+  /* { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 } */
+  permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
+  /* { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 } */
+  permuted_samples[2] = vqtbl1q_s8(clamped_samples, permute_tbl.val[2]);
+
+  /* Accumulate dot product into 'correction' to account for range clamp. */
+  /* First 4 output values. */
+  sum0 = vdotq_lane_s32(correction, permuted_samples[0], filters, 0);
+  sum0 = vdotq_lane_s32(sum0, permuted_samples[1], filters, 1);
+  /* Second 4 output values. */
+  sum1 = vdotq_lane_s32(correction, permuted_samples[1], filters, 0);
+  sum1 = vdotq_lane_s32(sum1, permuted_samples[2], filters, 1);
+
+  /* Narrow and re-pack. */
+  sum = vcombine_s16(vmovn_s32(sum0), vmovn_s32(sum1));
+  sum = vqrshlq_s16(sum, shift_round_0);
+  sum = vqrshlq_s16(sum, shift_by_bits);
+  return vqmovun_s16(sum);
+}
+
+void av1_convolve_x_sr_neon(const uint8_t *src, int src_stride, uint8_t *dst,
+                            int dst_stride, int w, int h,
+                            const InterpFilterParams *filter_params_x,
+                            const int subpel_x_qn,
+                            ConvolveParams *conv_params) {
+  if (filter_params_x->taps > 8) {
+    av1_convolve_x_sr_c(src, src_stride, dst, dst_stride, w, h, filter_params_x,
+                        subpel_x_qn, conv_params);
+    return;
+  }
+  const uint8_t horiz_offset = filter_params_x->taps / 2 - 1;
+  const int8_t bits = FILTER_BITS - conv_params->round_0;
+
+  assert(bits >= 0);
+  assert((FILTER_BITS - conv_params->round_1) >= 0 ||
+         ((conv_params->round_0 + conv_params->round_1) == 2 * FILTER_BITS));
+
+  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
+      filter_params_x, subpel_x_qn & SUBPEL_MASK);
+  // Filter values are even, so downshift by 1 to reduce intermediate precision
+  // requirements.
+  const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
+  // Dot product constants.
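+  // 'correction' compensates for the 128 subtracted from every sample in the
+  // convolve helpers: it equals 128 * (sum of the halved filter taps).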
+  const int16x8_t correct_tmp = vshll_n_s8(x_filter, 7);
+  const int32x4_t correction = vdupq_n_s32(vaddlvq_s16(correct_tmp));
+  const uint8x16_t range_limit = vdupq_n_u8(128);
+
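+  // Halving the filter taps above means the round_0 shift is one bit smaller.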
+  const int16x8_t shift_round_0 = vdupq_n_s16(-conv_params->round_0 + 1);
+  const int16x8_t shift_by_bits = vdupq_n_s16(-bits);
+
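+  // Back up src so the loads cover the filter taps to the left of each
+  // output pixel.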
+  src -= horiz_offset;
+
+  if (w <= 4) {
+    const uint8x16x2_t permute_tbl = vld1q_u8_x2(dot_prod_permute_tbl);
+    uint8x16_t s0, s1, s2, s3;
+    int32x4_t t0, t1, t2, t3;
+    int16x8_t t01, t23;
+    uint8x8_t d01, d23;
+
+    do {
+      s0 = vld1q_u8(src + 0 * src_stride);
+      s1 = vld1q_u8(src + 1 * src_stride);
+      s2 = vld1q_u8(src + 2 * src_stride);
+      s3 = vld1q_u8(src + 3 * src_stride);
+
+      t0 = convolve8_4_dot(s0, x_filter, correction, range_limit, permute_tbl);
+      t1 = convolve8_4_dot(s1, x_filter, correction, range_limit, permute_tbl);
+      t2 = convolve8_4_dot(s2, x_filter, correction, range_limit, permute_tbl);
+      t3 = convolve8_4_dot(s3, x_filter, correction, range_limit, permute_tbl);
+
+      t01 = vcombine_s16(vmovn_s32(t0), vmovn_s32(t1));
+      t23 = vcombine_s16(vmovn_s32(t2), vmovn_s32(t3));
+
+      t01 = vqrshlq_s16(t01, shift_round_0);
+      t23 = vqrshlq_s16(t23, shift_round_0);
+
+      t01 = vqrshlq_s16(t01, shift_by_bits);
+      t23 = vqrshlq_s16(t23, shift_by_bits);
+
+      d01 = vqmovun_s16(t01);
+      d23 = vqmovun_s16(t23);
+
+      if (w == 2) {
+        vst1_lane_u16((uint16_t *)(dst + 0 * dst_stride),
+                      vreinterpret_u16_u8(d01), 0);
+        vst1_lane_u16((uint16_t *)(dst + 1 * dst_stride),
+                      vreinterpret_u16_u8(d01), 2);
+        if (h != 2) {
+          vst1_lane_u16((uint16_t *)(dst + 2 * dst_stride),
+                        vreinterpret_u16_u8(d23), 0);
+          vst1_lane_u16((uint16_t *)(dst + 3 * dst_stride),
+                        vreinterpret_u16_u8(d23), 2);
+        }
+      } else {
+        vst1_lane_u32((uint32_t *)(dst + 0 * dst_stride),
+                      vreinterpret_u32_u8(d01), 0);
+        vst1_lane_u32((uint32_t *)(dst + 1 * dst_stride),
+                      vreinterpret_u32_u8(d01), 1);
+        if (h != 2) {
+          vst1_lane_u32((uint32_t *)(dst + 2 * dst_stride),
+                        vreinterpret_u32_u8(d23), 0);
+          vst1_lane_u32((uint32_t *)(dst + 3 * dst_stride),
+                        vreinterpret_u32_u8(d23), 1);
+        }
+      }
+
+      h -= 4;
+      src += 4 * src_stride;
+      dst += 4 * dst_stride;
+    } while (h > 0);
+
+  } else {
+    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
+    uint8x16_t s0, s1, s2, s3;
+    uint8x8_t d0, d1, d2, d3;
+
+    do {
+      int width = w;
+      const uint8_t *s = src;
+      uint8_t *d = dst;
+
+      do {
+        s0 = vld1q_u8(s + 0 * src_stride);
+        s1 = vld1q_u8(s + 1 * src_stride);
+        s2 = vld1q_u8(s + 2 * src_stride);
+        s3 = vld1q_u8(s + 3 * src_stride);
+
+        d0 = convolve8_8_dot(s0, x_filter, correction, range_limit, permute_tbl,
+                             shift_round_0, shift_by_bits);
+        d1 = convolve8_8_dot(s1, x_filter, correction, range_limit, permute_tbl,
+                             shift_round_0, shift_by_bits);
+        d2 = convolve8_8_dot(s2, x_filter, correction, range_limit, permute_tbl,
+                             shift_round_0, shift_by_bits);
+        d3 = convolve8_8_dot(s3, x_filter, correction, range_limit, permute_tbl,
+                             shift_round_0, shift_by_bits);
+
+        vst1_u8(d + 0 * dst_stride, d0);
+        vst1_u8(d + 1 * dst_stride, d1);
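+        // For h == 2 the last two computed rows fall outside the block, so
+        // do not store them.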
+        if (h != 2) {
+          vst1_u8(d + 2 * dst_stride, d2);
+          vst1_u8(d + 3 * dst_stride, d3);
+        }
+
+        s += 8;
+        d += 8;
+        width -= 8;
+      } while (width > 0);
+
+      src += 4 * src_stride;
+      dst += 4 * dst_stride;
+      h -= 4;
+    } while (h > 0);
+  }
+}
+
+#else  // !(defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD))
+
+static INLINE uint8x8_t convolve8_horiz_8x8(
+    const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
+    const int16x8_t s3, const int16x8_t s4, const int16x8_t s5,
+    const int16x8_t s6, const int16x8_t s7, const int16x8_t filter,
+    const int16x8_t shift_round_0, const int16x8_t shift_by_bits) {
+  const int16x4_t filter_lo = vget_low_s16(filter);
+  const int16x4_t filter_hi = vget_high_s16(filter);
+  int16x8_t sum;
+
+  sum = vmulq_lane_s16(s0, filter_lo, 0);
+  sum = vmlaq_lane_s16(sum, s1, filter_lo, 1);
+  sum = vmlaq_lane_s16(sum, s2, filter_lo, 2);
+  sum = vmlaq_lane_s16(sum, s3, filter_lo, 3);
+  sum = vmlaq_lane_s16(sum, s4, filter_hi, 0);
+  sum = vmlaq_lane_s16(sum, s5, filter_hi, 1);
+  sum = vmlaq_lane_s16(sum, s6, filter_hi, 2);
+  sum = vmlaq_lane_s16(sum, s7, filter_hi, 3);
+
+  sum = vqrshlq_s16(sum, shift_round_0);
+  sum = vqrshlq_s16(sum, shift_by_bits);
+
+  return vqmovun_s16(sum);
+}
+
 void av1_convolve_x_sr_neon(const uint8_t *src, int src_stride, uint8_t *dst,
                             int dst_stride, int w, int h,
                             const InterpFilterParams *filter_params_x,
@@ -602,6 +800,8 @@
 #endif
 }
 
+#endif  // defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
+
 void av1_convolve_y_sr_neon(const uint8_t *src, int src_stride, uint8_t *dst,
                             int dst_stride, int w, int h,
                             const InterpFilterParams *filter_params_y,