Optimize av1_dist_wtd_convolve_x_neon using SDOT instruction

Add an alternative AArch64 implementation of
av1_dist_wtd_convolve_x_neon for targets that implement the Armv8.4-A
SDOT (signed dot-product) instruction.

The existing MLA-based implementation of av1_dist_wtd_convolve_x_neon
is retained for use on target CPUs that do not implement the SDOT
instruction (or CPUs executing in AArch32 mode). The availability of
the SDOT instruction is indicated by the feature macro
__ARM_FEATURE_DOTPROD.
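
As a minimal sketch of the approach (illustrative only; the helper name
convolve8_1_sdot is hypothetical and not part of this patch): unsigned
source samples are re-centered to signed values by subtracting 128 so
that SDOT can be used, and a pre-computed correction of
128 * sum(filter) restores the true convolution result.

  #include <arm_neon.h>
  #include <stdint.h>

  #if defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
  // Illustrative only: one 8-tap output computed with SDOT.
  static int32_t convolve8_1_sdot(const uint8_t *src, int8x8_t filter) {
    // Re-center the unsigned samples into the signed range.
    const uint8x8_t samples_u8 = vld1_u8(src);
    const int8x8_t samples_s8 =
        vreinterpret_s8_u8(vsub_u8(samples_u8, vdup_n_u8(128)));

    // SDOT accumulates groups of four products; pad to 16 bytes with zeros.
    const int8x16_t s = vcombine_s8(samples_s8, vdup_n_s8(0));
    const int8x16_t f = vcombine_s8(filter, vdup_n_s8(0));
    const int32x4_t acc = vdotq_s32(vdupq_n_s32(0), s, f);

    // Each sample was reduced by 128, so add back 128 * sum(filter).
    const int32_t correction = 128 * vaddlv_s8(filter);
    return vaddvq_s32(acc) + correction;
  }
  #endif  // defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)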

Change-Id: I9091e25068164330a56997f157cc569234169b8f
diff --git a/av1/common/arm/jnt_convolve_neon.c b/av1/common/arm/jnt_convolve_neon.c
index ddf55bc..dd35fdf 100644
--- a/av1/common/arm/jnt_convolve_neon.c
+++ b/av1/common/arm/jnt_convolve_neon.c
@@ -936,6 +936,198 @@
   }
 }
 
+#if defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
+
+void av1_dist_wtd_convolve_x_neon(const uint8_t *src, int src_stride,
+                                  uint8_t *dst8, int dst8_stride, int w, int h,
+                                  const InterpFilterParams *filter_params_x,
+                                  const int subpel_x_qn,
+                                  ConvolveParams *conv_params) {
+  assert(!(w % 4));
+  assert(!(h % 4));
+
+  const int horiz_offset = filter_params_x->taps / 2 - 1;
+  const int bits = FILTER_BITS - conv_params->round_1;
+  const int bd = 8;
+  const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
+  const int round_offset = (1 << (offset_bits - conv_params->round_1)) +
+                           (1 << (offset_bits - conv_params->round_1 - 1));
+  const int round_bits =
+      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
+  const uint16_t fwd_offset = conv_params->fwd_offset;
+  const uint16_t bck_offset = conv_params->bck_offset;
+  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
+  const int16x4_t round_offset64 = vdup_n_s16(round_offset);
+  const int16x8_t round_offset128 = vdupq_n_s16(round_offset);
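+  // The filter taps are halved (below), so the convolution result is at half
+  // scale; the first rounding shift is therefore by (round_0 - 1), and the
+  // result is shifted back up by (FILTER_BITS - round_1) before the compound
+  // offset is added.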
+  const int16x8_t shift_round_0 = vdupq_n_s16(-conv_params->round_0 + 1);
+  const int16x8_t horiz_const = vdupq_n_s16(bits);
+
+  // Horizontal filter.
+  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
+      filter_params_x, subpel_x_qn & SUBPEL_MASK);
+  // Filter values are even, so downshift by 1 to reduce intermediate precision
+  // requirements.
+  const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
+  // Dot-product constants: SDOT operates on signed 8-bit values, so the
+  // source samples are re-centered by subtracting 128 (range_limit), and a
+  // correction of 128 * sum(filter) is added back to the accumulator.
+  const uint8x16_t range_limit = vdupq_n_u8(128);
+  const int32_t correction_s32 = vaddlvq_s16(vshll_n_s8(x_filter, 7));
+  const int32x4_t correction = vdupq_n_s32(correction_s32);
+
+  const uint8_t *src_ptr = src - horiz_offset;
+  CONV_BUF_TYPE *dst = conv_params->dst;
+  CONV_BUF_TYPE *dst_ptr = dst;
+  uint8_t *dst_u8_ptr = dst8;
+  int dst_stride = conv_params->dst_stride;
+  int width = w;
+  int height = h;
+
+  if (w == 4) {
+    const uint8x16x2_t permute_tbl = vld1q_u8_x2(dot_prod_permute_tbl);
+
+    do {
+      uint8x16_t s0, s1, s2, s3;
+      int32x4_t d0, d1, d2, d3;
+      int16x8_t d01, d23;
+      uint16x4_t dd0, dd1, dd2, dd3;
+      uint8x8_t d01_u8, d23_u8;
+
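+      // Each 16-byte row load covers the 11 samples needed for four adjacent
+      // outputs; the permute tables gather them into the groups of four that
+      // SDOT consumes.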
+      s0 = vld1q_u8(src_ptr + 0 * src_stride);
+      s1 = vld1q_u8(src_ptr + 1 * src_stride);
+      s2 = vld1q_u8(src_ptr + 2 * src_stride);
+      s3 = vld1q_u8(src_ptr + 3 * src_stride);
+
+      d0 = convolve8_4_dot_s16(s0, x_filter, correction, range_limit,
+                               permute_tbl);
+      d1 = convolve8_4_dot_s16(s1, x_filter, correction, range_limit,
+                               permute_tbl);
+      d2 = convolve8_4_dot_s16(s2, x_filter, correction, range_limit,
+                               permute_tbl);
+      d3 = convolve8_4_dot_s16(s3, x_filter, correction, range_limit,
+                               permute_tbl);
+
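+      // Narrow the 32-bit accumulators and pack two 4-pixel rows per 16-bit
+      // vector.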
+      d01 = vcombine_s16(vmovn_s32(d0), vmovn_s32(d1));
+      d23 = vcombine_s16(vmovn_s32(d2), vmovn_s32(d3));
+
+      d01 = vqrshlq_s16(d01, shift_round_0);
+      d23 = vqrshlq_s16(d23, shift_round_0);
+
+      d01 = vrshlq_s16(d01, horiz_const);
+      d23 = vrshlq_s16(d23, horiz_const);
+
+      d01 = vaddq_s16(d01, round_offset128);
+      d23 = vaddq_s16(d23, round_offset128);
+
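+      // Second pass of a compound prediction: average with the values already
+      // in the compound buffer (applying distance weights if enabled) and
+      // write out 8-bit pixels. Otherwise store the 16-bit result for later
+      // averaging.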
+      if (conv_params->do_average) {
+        dd0 = vld1_u16(dst_ptr);
+        dst_ptr += dst_stride;
+        dd1 = vld1_u16(dst_ptr);
+        dst_ptr += dst_stride;
+        dd2 = vld1_u16(dst_ptr);
+        dst_ptr += dst_stride;
+        dd3 = vld1_u16(dst_ptr);
+        dst_ptr += dst_stride;
+
+        compute_avg_4x4(dd0, dd1, dd2, dd3,
+                        vreinterpret_u16_s16(vget_low_s16(d01)),
+                        vreinterpret_u16_s16(vget_high_s16(d01)),
+                        vreinterpret_u16_s16(vget_low_s16(d23)),
+                        vreinterpret_u16_s16(vget_high_s16(d23)), fwd_offset,
+                        bck_offset, round_offset64, round_bits,
+                        use_dist_wtd_comp_avg, &d01_u8, &d23_u8);
+
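+        // d01_u8 and d23_u8 each pack two 4-pixel rows, so store them one
+        // 32-bit lane at a time.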
+        vst1_lane_u32((uint32_t *)dst_u8_ptr, vreinterpret_u32_u8(d01_u8), 0);
+        dst_u8_ptr += dst8_stride;
+        vst1_lane_u32((uint32_t *)dst_u8_ptr, vreinterpret_u32_u8(d01_u8), 1);
+        dst_u8_ptr += dst8_stride;
+        vst1_lane_u32((uint32_t *)dst_u8_ptr, vreinterpret_u32_u8(d23_u8), 0);
+        dst_u8_ptr += dst8_stride;
+        vst1_lane_u32((uint32_t *)dst_u8_ptr, vreinterpret_u32_u8(d23_u8), 1);
+        dst_u8_ptr += dst8_stride;
+      } else {
+        vst1q_lane_u64((uint64_t *)dst_ptr, vreinterpretq_u64_s16(d01), 0);
+        dst_ptr += dst_stride;
+        vst1q_lane_u64((uint64_t *)dst_ptr, vreinterpretq_u64_s16(d01), 1);
+        dst_ptr += dst_stride;
+        vst1q_lane_u64((uint64_t *)dst_ptr, vreinterpretq_u64_s16(d23), 0);
+        dst_ptr += dst_stride;
+        vst1q_lane_u64((uint64_t *)dst_ptr, vreinterpretq_u64_s16(d23), 1);
+        dst_ptr += dst_stride;
+      }
+
+      src_ptr += 4 * src_stride;
+      height -= 4;
+    } while (height > 0);
+  } else {
+    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
+
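+    // Process the block in 8-column tiles, four rows at a time.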
+    do {
+      const uint8_t *s = src_ptr;
+      CONV_BUF_TYPE *d = dst_ptr;
+      uint8_t *d_u8 = dst_u8_ptr;
+      width = w;
+
+      do {
+        uint8x16_t s0, s1, s2, s3;
+        int16x8_t d0, d1, d2, d3;
+        uint16x8_t dd0, dd1, dd2, dd3;
+        uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
+
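+        // Each 16-byte row load covers the 15 samples needed for eight
+        // adjacent outputs.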
+        s0 = vld1q_u8(s + 0 * src_stride);
+        s1 = vld1q_u8(s + 1 * src_stride);
+        s2 = vld1q_u8(s + 2 * src_stride);
+        s3 = vld1q_u8(s + 3 * src_stride);
+
+        d0 = convolve8_8_dot_s16(s0, x_filter, correction, range_limit,
+                                 permute_tbl, shift_round_0);
+        d1 = convolve8_8_dot_s16(s1, x_filter, correction, range_limit,
+                                 permute_tbl, shift_round_0);
+        d2 = convolve8_8_dot_s16(s2, x_filter, correction, range_limit,
+                                 permute_tbl, shift_round_0);
+        d3 = convolve8_8_dot_s16(s3, x_filter, correction, range_limit,
+                                 permute_tbl, shift_round_0);
+
+        d0 = vrshlq_s16(d0, horiz_const);
+        d1 = vrshlq_s16(d1, horiz_const);
+        d2 = vrshlq_s16(d2, horiz_const);
+        d3 = vrshlq_s16(d3, horiz_const);
+
+        d0 = vaddq_s16(d0, round_offset128);
+        d1 = vaddq_s16(d1, round_offset128);
+        d2 = vaddq_s16(d2, round_offset128);
+        d3 = vaddq_s16(d3, round_offset128);
+
+        if (conv_params->do_average) {
+          load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);
+
+          compute_avg_8x4(dd0, dd1, dd2, dd3, vreinterpretq_u16_s16(d0),
+                          vreinterpretq_u16_s16(d1), vreinterpretq_u16_s16(d2),
+                          vreinterpretq_u16_s16(d3), fwd_offset, bck_offset,
+                          round_offset64, round_bits, use_dist_wtd_comp_avg,
+                          &d0_u8, &d1_u8, &d2_u8, &d3_u8);
+
+          store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);
+        } else {
+          store_u16_8x4(d, dst_stride, vreinterpretq_u16_s16(d0),
+                        vreinterpretq_u16_s16(d1), vreinterpretq_u16_s16(d2),
+                        vreinterpretq_u16_s16(d3));
+        }
+
+        s += 8;
+        d += 8;
+        d_u8 += 8;
+        width -= 8;
+      } while (width > 0);
+
+      src_ptr += 4 * src_stride;
+      dst_ptr += 4 * dst_stride;
+      dst_u8_ptr += 4 * dst8_stride;
+      height -= 4;
+    } while (height > 0);
+  }
+}
+
+#else  // !(defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD))
+
 void av1_dist_wtd_convolve_x_neon(const uint8_t *src, int src_stride,
                                   uint8_t *dst8, int dst8_stride, int w, int h,
                                   const InterpFilterParams *filter_params_x,
@@ -1391,6 +1583,8 @@
   }
 }
 
+#endif  // defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
+
 void av1_dist_wtd_convolve_y_neon(const uint8_t *src, int src_stride,
                                   uint8_t *dst8, int dst8_stride, int w, int h,
                                   const InterpFilterParams *filter_params_y,