Optimize av1_convolve_2d_sr_horiz_12tap_neon using SDOT

Add an alternative AArch64 implementation of
av1_convolve_2d_sr_horiz_12tap_neon for targets that implement the
Armv8.4-A SDOT (signed dot-product) instruction.

The existing MLA-based implementation of
av1_convolve_2d_sr_horiz_12tap_neon is retained for use on target
CPUs that do not implement the SDOT instruction (or CPUs executing in
AArch32 mode). The availability of the SDOT instruction is indicated
by the feature macro __ARM_FEATURE_DOTPROD.
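
Since SDOT computes signed 8-bit dot products while pixel data is
unsigned, each sample is biased down by 128 before the dot product,
and a per-output correction of 128 * sum(filter taps) is pre-loaded
into the accumulator. A minimal scalar sketch of the identity being
exploited (illustrative only; these names are not part of the patch):

  #include <stdint.h>

  static int32_t dot_with_correction(const uint8_t *x,
                                     const int8_t *taps, int num_taps) {
    // Bias each unsigned sample into the signed 8-bit range, then
    // cancel the bias with a single per-output constant.
    int32_t correction = 0;
    for (int k = 0; k < num_taps; k++) correction += 128 * taps[k];

    int32_t sum = correction;
    for (int k = 0; k < num_taps; k++) sum += (x[k] - 128) * taps[k];
    return sum;  // == the sum over k of x[k] * taps[k]
  }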

Change-Id: I6db861f41c7763c59f7176dbcd0404f146cde395
diff --git a/av1/common/arm/convolve_neon.c b/av1/common/arm/convolve_neon.c
index eceab99..a01a04a 100644
--- a/av1/common/arm/convolve_neon.c
+++ b/av1/common/arm/convolve_neon.c
@@ -158,36 +158,6 @@
   return vqmovun_s16(res);
 }
 
-static INLINE int16x4_t convolve12_horiz_4x4_s16(
-    const int16x4_t s0, const int16x4_t s1, const int16x4_t s2,
-    const int16x4_t s3, const int16x4_t s4, const int16x4_t s5,
-    const int16x4_t s6, const int16x4_t s7, const int16x4_t s8,
-    const int16x4_t s9, const int16x4_t s10, const int16x4_t s11,
-    const int16x8_t x_filter_0_7, const int16x4_t x_filter_8_11,
-    const int32x4_t horiz_const, const int32x4_t shift_round_0) {
-  const int16x4_t x_filter_0_3 = vget_low_s16(x_filter_0_7);
-  const int16x4_t x_filter_4_7 = vget_high_s16(x_filter_0_7);
-  int32x4_t sum;
-
-  sum = horiz_const;
-  sum = vmlal_lane_s16(sum, s0, x_filter_0_3, 0);
-  sum = vmlal_lane_s16(sum, s1, x_filter_0_3, 1);
-  sum = vmlal_lane_s16(sum, s2, x_filter_0_3, 2);
-  sum = vmlal_lane_s16(sum, s3, x_filter_0_3, 3);
-  sum = vmlal_lane_s16(sum, s4, x_filter_4_7, 0);
-  sum = vmlal_lane_s16(sum, s5, x_filter_4_7, 1);
-  sum = vmlal_lane_s16(sum, s6, x_filter_4_7, 2);
-  sum = vmlal_lane_s16(sum, s7, x_filter_4_7, 3);
-  sum = vmlal_lane_s16(sum, s8, x_filter_8_11, 0);
-  sum = vmlal_lane_s16(sum, s9, x_filter_8_11, 1);
-  sum = vmlal_lane_s16(sum, s10, x_filter_8_11, 2);
-  sum = vmlal_lane_s16(sum, s11, x_filter_8_11, 3);
-
-  sum = vqrshlq_s32(sum, shift_round_0);
-
-  return vmovn_s32(sum);
-}
-
 static INLINE int16x4_t convolve12_vert_4x4_s32(
     const int16x4_t s0, const int16x4_t s1, const int16x4_t s2,
     const int16x4_t s3, const int16x4_t s4, const int16x4_t s5,
@@ -1158,6 +1128,318 @@
   }
 }
 
+#if defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
+
+static INLINE int16x4_t convolve12_4_dot(uint8x16_t samples,
+                                         const int8x16_t filters,
+                                         const int32x4_t correction,
+                                         const uint8x16_t range_limit,
+                                         const uint8x16x3_t permute_tbl,
+                                         const int32x4_t shift_round_0) {
+  int8x16_t clamped_samples, permuted_samples[3];
+  int32x4_t sum;
+
+  /* Clamp sample range to [-128, 127] for 8-bit signed dot product. */
+  clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));
+
+  /* Permute samples ready for dot product. */
+  /* { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 } */
+  permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
+  /* { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 } */
+  permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
+  /* { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 } */
+  permuted_samples[2] = vqtbl1q_s8(clamped_samples, permute_tbl.val[2]);
+
+  /* Accumulate dot product into 'correction' to account for range clamp. */
+  /* First 4 output values. */
+  sum = vdotq_laneq_s32(correction, permuted_samples[0], filters, 0);
+  sum = vdotq_laneq_s32(sum, permuted_samples[1], filters, 1);
+  sum = vdotq_laneq_s32(sum, permuted_samples[2], filters, 2);
+
+  /* Narrow and re-pack. */
+  sum = vqrshlq_s32(sum, shift_round_0);
+
+  return vmovn_s32(sum);
+}
+
+static INLINE int16x8_t convolve12_8_dot(
+    uint8x16_t samples0, uint8x16_t samples1, const int8x16_t filters,
+    const int32x4_t correction, const uint8x16_t range_limit,
+    const uint8x16x3_t permute_tbl, const int32x4_t shift_round_0) {
+  int8x16_t clamped_samples[2], permuted_samples[4];
+  int32x4_t sum[2];
+
+  /* Clamp sample range to [-128, 127] for 8-bit signed dot product. */
+  clamped_samples[0] = vreinterpretq_s8_u8(vsubq_u8(samples0, range_limit));
+  clamped_samples[1] = vreinterpretq_s8_u8(vsubq_u8(samples1, range_limit));
+
+  /* Permute samples ready for dot product. */
+  /* { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 } */
+  permuted_samples[0] = vqtbl1q_s8(clamped_samples[0], permute_tbl.val[0]);
+  /* { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 } */
+  permuted_samples[1] = vqtbl1q_s8(clamped_samples[0], permute_tbl.val[1]);
+  /* { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 } */
+  permuted_samples[2] = vqtbl1q_s8(clamped_samples[0], permute_tbl.val[2]);
+  /* {12, 13, 14, 15, 13, 14, 15, 16, 14, 15, 16, 17, 15, 16, 17, 18 } */
+  permuted_samples[3] = vqtbl1q_s8(clamped_samples[1], permute_tbl.val[2]);
+
+  /* Accumulate dot product into 'correction' to account for range clamp. */
+  /* First 4 output values. */
+  sum[0] = vdotq_laneq_s32(correction, permuted_samples[0], filters, 0);
+  sum[0] = vdotq_laneq_s32(sum[0], permuted_samples[1], filters, 1);
+  sum[0] = vdotq_laneq_s32(sum[0], permuted_samples[2], filters, 2);
+  /* Second 4 output values. */
+  sum[1] = vdotq_laneq_s32(correction, permuted_samples[1], filters, 0);
+  sum[1] = vdotq_laneq_s32(sum[1], permuted_samples[2], filters, 1);
+  sum[1] = vdotq_laneq_s32(sum[1], permuted_samples[3], filters, 2);
+
+  /* Narrow and re-pack. */
+  sum[0] = vqrshlq_s32(sum[0], shift_round_0);
+  sum[1] = vqrshlq_s32(sum[1], shift_round_0);
+
+  return vcombine_s16(vmovn_s32(sum[0]), vmovn_s32(sum[1]));
+}
+
+static INLINE void av1_convolve_2d_sr_horiz_12tap_neon(
+    const uint8_t *src_ptr, int src_stride, int16_t *dst_ptr,
+    const int dst_stride, int w, int h, const int16x8_t x_filter_0_7,
+    const int16x4_t x_filter_8_11, const int round_0) {
+  const int bd = 8;
+
+  // Special case the following no-op filter as 128 won't fit into the
+  // 8-bit signed dot-product instruction:
+  // { 0, 0, 0, 0, 0, 128, 0, 0, 0, 0, 0, 0 }
+  if (vgetq_lane_s16(x_filter_0_7, 5) == 128) {
+    const int16x8_t horiz_const = vdupq_n_s16((1 << (bd - 1)));
+    const int16x8_t shift_round_0 = vdupq_n_s16(FILTER_BITS - round_0);
+    // Undo the horizontal offset in the calling function.
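+    // (The caller applied an offset of taps / 2 - 1 = 5.)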
+    src_ptr += 5;
+
+    for (int i = 0; i < h; i++) {
+      for (int j = 0; j < w; j += 8) {
+        uint8x8_t s0 = vld1_u8(src_ptr + i * src_stride + j);
+        uint16x8_t t0 = vaddw_u8(vreinterpretq_u16_s16(horiz_const), s0);
+        int16x8_t d0 = vqrshlq_s16(vreinterpretq_s16_u16(t0), shift_round_0);
+        if (w == 2) {
+          vst1q_lane_s32((int32_t *)(dst_ptr + i * dst_stride),
+                         vreinterpretq_s32_s16(d0), 0);
+        } else if (w == 4) {
+          vst1_s16(dst_ptr + i * dst_stride, vget_low_s16(d0));
+        } else {
+          vst1q_s16(dst_ptr + i * dst_stride + j, d0);
+        }
+      }
+    }
+  } else {
+    const int32x4_t shift_round_0 = vdupq_n_s32(-round_0);
+
+    // Narrow filter values to 8-bit.
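+    // (Safe here: the only 12-tap filter with a tap that does not fit in
+    // int8_t is the 128 no-op filter special-cased above.)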
+    const int16x8x2_t x_filter_s16 = {
+      { x_filter_0_7, vcombine_s16(x_filter_8_11, vdup_n_s16(0)) }
+    };
+    const int8x16_t x_filter = vcombine_s8(vmovn_s16(x_filter_s16.val[0]),
+                                           vmovn_s16(x_filter_s16.val[1]));
+
+    // Dot product constants.
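+    // 'correction' pre-loads each accumulator with 128 * sum(filter taps)
+    // plus the rounding offset 'horiz_const', cancelling the -128 that is
+    // subtracted from every sample in convolve12_4_dot/convolve12_8_dot.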
+    const int32_t horiz_const = (1 << (bd + FILTER_BITS - 1));
+    const int32x4_t correct_tmp =
+        vaddq_s32(vpaddlq_s16(vshlq_n_s16(x_filter_s16.val[0], 7)),
+                  vpaddlq_s16(vshlq_n_s16(x_filter_s16.val[1], 7)));
+    const int32x4_t correction =
+        vdupq_n_s32(vaddvq_s32(correct_tmp) + horiz_const);
+    const uint8x16_t range_limit = vdupq_n_u8(128);
+    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
+
+    if (w <= 4) {
+      do {
+        const uint8_t *s = src_ptr;
+        int16_t *d = dst_ptr;
+        int width = w;
+
+        do {
+          uint8x16_t s0, s1, s2, s3;
+          int16x4_t d0, d1, d2, d3;
+
+          s0 = vld1q_u8(s + 0 * src_stride);
+          s1 = vld1q_u8(s + 1 * src_stride);
+          s2 = vld1q_u8(s + 2 * src_stride);
+          s3 = vld1q_u8(s + 3 * src_stride);
+
+          d0 = convolve12_4_dot(s0, x_filter, correction, range_limit,
+                                permute_tbl, shift_round_0);
+          d1 = convolve12_4_dot(s1, x_filter, correction, range_limit,
+                                permute_tbl, shift_round_0);
+          d2 = convolve12_4_dot(s2, x_filter, correction, range_limit,
+                                permute_tbl, shift_round_0);
+          d3 = convolve12_4_dot(s3, x_filter, correction, range_limit,
+                                permute_tbl, shift_round_0);
+
+          if (w == 2) {
+            vst1_lane_s32((int32_t *)(d + 0 * dst_stride),
+                          vreinterpret_s32_s16(d0), 0);
+            vst1_lane_s32((int32_t *)(d + 1 * dst_stride),
+                          vreinterpret_s32_s16(d1), 0);
+            vst1_lane_s32((int32_t *)(d + 2 * dst_stride),
+                          vreinterpret_s32_s16(d2), 0);
+            vst1_lane_s32((int32_t *)(d + 3 * dst_stride),
+                          vreinterpret_s32_s16(d3), 0);
+          } else {
+            vst1_s16(d + 0 * dst_stride, d0);
+            vst1_s16(d + 1 * dst_stride, d1);
+            vst1_s16(d + 2 * dst_stride, d2);
+            vst1_s16(d + 3 * dst_stride, d3);
+          }
+
+          s += 4;
+          d += 4;
+          width -= 4;
+        } while (width > 0);
+
+        src_ptr += 4 * src_stride;
+        dst_ptr += 4 * dst_stride;
+        h -= 4;
+      } while (h >= 4);
+
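+      // Process any remaining rows one at a time.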
+      for (; h > 0; h--) {
+        const uint8_t *s = src_ptr;
+        int16_t *d = dst_ptr;
+        int width = w;
+
+        do {
+          uint8x16_t s0;
+          int16x4_t d0;
+
+          s0 = vld1q_u8(s);
+
+          d0 = convolve12_4_dot(s0, x_filter, correction, range_limit,
+                                permute_tbl, shift_round_0);
+
+          if (w == 2) {
+            vst1_lane_s32((int32_t *)d, vreinterpret_s32_s16(d0), 0);
+          } else {
+            vst1_s16(d, d0);
+          }
+
+          s += 4;
+          d += 4;
+          width -= 4;
+        } while (width > 0);
+
+        src_ptr += src_stride;
+        dst_ptr += dst_stride;
+      }
+    } else {
+      do {
+        const uint8_t *s = src_ptr;
+        int16_t *d = dst_ptr;
+        int width = w;
+
+        do {
+          uint8x16_t s0[2], s1[2], s2[2], s3[2];
+          int16x8_t d0, d1, d2, d3;
+
+          s0[0] = vld1q_u8(s + 0 * src_stride);
+          s1[0] = vld1q_u8(s + 1 * src_stride);
+          s2[0] = vld1q_u8(s + 2 * src_stride);
+          s3[0] = vld1q_u8(s + 3 * src_stride);
+          s0[1] = vld1q_u8(s + 0 * src_stride + 4);
+          s1[1] = vld1q_u8(s + 1 * src_stride + 4);
+          s2[1] = vld1q_u8(s + 2 * src_stride + 4);
+          s3[1] = vld1q_u8(s + 3 * src_stride + 4);
+
+          d0 = convolve12_8_dot(s0[0], s0[1], x_filter, correction, range_limit,
+                                permute_tbl, shift_round_0);
+          d1 = convolve12_8_dot(s1[0], s1[1], x_filter, correction, range_limit,
+                                permute_tbl, shift_round_0);
+          d2 = convolve12_8_dot(s2[0], s2[1], x_filter, correction, range_limit,
+                                permute_tbl, shift_round_0);
+          d3 = convolve12_8_dot(s3[0], s3[1], x_filter, correction, range_limit,
+                                permute_tbl, shift_round_0);
+
+          vst1q_s16(d + 0 * dst_stride, d0);
+          vst1q_s16(d + 1 * dst_stride, d1);
+          vst1q_s16(d + 2 * dst_stride, d2);
+          vst1q_s16(d + 3 * dst_stride, d3);
+
+          s += 8;
+          d += 8;
+          width -= 8;
+        } while (width > 0);
+
+        src_ptr += 4 * src_stride;
+        dst_ptr += 4 * dst_stride;
+        h -= 4;
+      } while (h >= 4);
+
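+      // Process any remaining rows one at a time.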
+      for (; h > 0; h--) {
+        const uint8_t *s = src_ptr;
+        int16_t *d = dst_ptr;
+        int width = w;
+
+        do {
+          uint8x16_t s0[2];
+          int16x8_t d0;
+
+          s0[0] = vld1q_u8(s);
+          s0[1] = vld1q_u8(s + 4);
+
+          d0 = convolve12_8_dot(s0[0], s0[1], x_filter, correction, range_limit,
+                                permute_tbl, shift_round_0);
+
+          vst1q_s16(d, d0);
+
+          s += 8;
+          d += 8;
+          width -= 8;
+        } while (width > 0);
+
+        src_ptr += src_stride;
+        dst_ptr += dst_stride;
+      }
+    }
+  }
+}
+
+#else  // !(defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD))
+
+static INLINE int16x4_t convolve12_horiz_4x4_s16(
+    const int16x4_t s0, const int16x4_t s1, const int16x4_t s2,
+    const int16x4_t s3, const int16x4_t s4, const int16x4_t s5,
+    const int16x4_t s6, const int16x4_t s7, const int16x4_t s8,
+    const int16x4_t s9, const int16x4_t s10, const int16x4_t s11,
+    const int16x8_t x_filter_0_7, const int16x4_t x_filter_8_11,
+    const int32x4_t horiz_const, const int32x4_t shift_round_0) {
+  const int16x4_t x_filter_0_3 = vget_low_s16(x_filter_0_7);
+  const int16x4_t x_filter_4_7 = vget_high_s16(x_filter_0_7);
+  int32x4_t sum;
+
+  sum = horiz_const;
+  sum = vmlal_lane_s16(sum, s0, x_filter_0_3, 0);
+  sum = vmlal_lane_s16(sum, s1, x_filter_0_3, 1);
+  sum = vmlal_lane_s16(sum, s2, x_filter_0_3, 2);
+  sum = vmlal_lane_s16(sum, s3, x_filter_0_3, 3);
+  sum = vmlal_lane_s16(sum, s4, x_filter_4_7, 0);
+  sum = vmlal_lane_s16(sum, s5, x_filter_4_7, 1);
+  sum = vmlal_lane_s16(sum, s6, x_filter_4_7, 2);
+  sum = vmlal_lane_s16(sum, s7, x_filter_4_7, 3);
+  sum = vmlal_lane_s16(sum, s8, x_filter_8_11, 0);
+  sum = vmlal_lane_s16(sum, s9, x_filter_8_11, 1);
+  sum = vmlal_lane_s16(sum, s10, x_filter_8_11, 2);
+  sum = vmlal_lane_s16(sum, s11, x_filter_8_11, 3);
+
+  sum = vqrshlq_s32(sum, shift_round_0);
+
+  return vmovn_s32(sum);
+}
+
 // 4 column per iteration horizontal filtering for 12-tap convolve_2d_sr.
 // Processes one row at a time.
 static INLINE void horiz_filter_12tap_w4_single_row(
@@ -1329,6 +1603,8 @@
 #endif  // defined(__aarch64__)
 }
 
+#endif  // defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
+
 static INLINE void av1_convolve_2d_sr_vert_12tap_neon(
     int16_t *src_ptr, int src_stride, uint8_t *dst_ptr, int dst_stride, int w,
     int h, const int16x8_t y_filter_0_7, const int16x4_t y_filter_8_11,