Specialize av1_dist_wtd_convolve_y_neon for 6-tap filter

av1_dist_wtd_convolve_y_neon can be called with 4-, 6- and 8-tap filters,
with the 4- and 6-tap filters being padded out with zeros so that the same
8-tap code path can be re-used. This is inefficient, since we end up
spending a lot of time multiplying by, and adding, zero; it is especially
wasteful given that the most common filter size used with this function is
6 taps.

This patch adds an av1_dist_wtd_convolve_y_neon code path specialized
for 6-tap filters. The new path is used for both 4-tap and 6-tap
filters, reducing the amount of redundant work.
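
As a rough scalar illustration of the saving (not part of this patch;
the helper names below are hypothetical), a 6-tap kernel padded out to
8 taps performs two multiply-accumulates against zero coefficients for
every output pixel, which a dedicated 6-tap loop simply skips:

  #include <stdint.h>

  // Padded path: filter[0] and filter[7] are 0 for a 6-tap kernel.
  static int32_t convolve8_padded(const uint8_t *src, const int16_t *filter) {
    int32_t sum = 0;
    for (int k = 0; k < 8; ++k) sum += filter[k] * src[k];
    return sum;
  }

  // Specialized path: only the six non-zero taps (indices 1..6) are used.
  static int32_t convolve6(const uint8_t *src, const int16_t *filter) {
    int32_t sum = 0;
    for (int k = 1; k <= 6; ++k) sum += filter[k] * src[k];
    return sum;
  }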

Change-Id: I730d6d1ff2e6316aedbc123c24bc26dc5a203dd1
diff --git a/aom_dsp/arm/mem_neon.h b/aom_dsp/arm/mem_neon.h
index 994a636..af11b1f 100644
--- a/aom_dsp/arm/mem_neon.h
+++ b/aom_dsp/arm/mem_neon.h
@@ -654,6 +654,14 @@
   *tu1 = load_unaligned_u8_4x2(buf, stride);
 }
 
+static INLINE void load_unaligned_u8_3x8(const uint8_t *buf, int stride,
+                                         uint8x8_t *tu0, uint8x8_t *tu1,
+                                         uint8x8_t *tu2) {
+  load_unaligned_u8_4x4(buf, stride, tu0, tu1);
+  buf += 4 * stride;
+  *tu2 = load_unaligned_u8_4x2(buf, stride);
+}
+
 static INLINE void load_unaligned_u8_4x8(const uint8_t *buf, int stride,
                                          uint8x8_t *tu0, uint8x8_t *tu1,
                                          uint8x8_t *tu2, uint8x8_t *tu3) {
diff --git a/av1/common/arm/convolve_neon.c b/av1/common/arm/convolve_neon.c
index 8ee3203..f12bf07 100644
--- a/av1/common/arm/convolve_neon.c
+++ b/av1/common/arm/convolve_neon.c
@@ -1536,44 +1536,6 @@
 
 #endif  // defined(__aarch64__) && defined(__ARM_FEATURE_MATMUL_INT8)
 
-static INLINE int16x4_t convolve6_4x4(const int16x4_t s0, const int16x4_t s1,
-                                      const int16x4_t s2, const int16x4_t s3,
-                                      const int16x4_t s4, const int16x4_t s5,
-                                      const int16x8_t y_filter_0_7) {
-  const int16x4_t y_filter_0_3 = vget_low_s16(y_filter_0_7);
-  const int16x4_t y_filter_4_7 = vget_high_s16(y_filter_0_7);
-  int16x4_t sum;
-
-  // Filter values at indices 0 and 7 are 0.
-  sum = vmul_lane_s16(s0, y_filter_0_3, 1);
-  sum = vmla_lane_s16(sum, s1, y_filter_0_3, 2);
-  sum = vmla_lane_s16(sum, s2, y_filter_0_3, 3);
-  sum = vmla_lane_s16(sum, s3, y_filter_4_7, 0);
-  sum = vmla_lane_s16(sum, s4, y_filter_4_7, 1);
-  sum = vmla_lane_s16(sum, s5, y_filter_4_7, 2);
-
-  return sum;
-}
-
-static INLINE int16x8_t convolve6_8x4(const int16x8_t s0, const int16x8_t s1,
-                                      const int16x8_t s2, const int16x8_t s3,
-                                      const int16x8_t s4, const int16x8_t s5,
-                                      const int16x8_t y_filters) {
-  const int16x4_t y_filter_lo = vget_low_s16(y_filters);
-  const int16x4_t y_filter_hi = vget_high_s16(y_filters);
-  int16x8_t sum;
-
-  // Filter values at indices 0 and 7 are 0.
-  sum = vmulq_lane_s16(s0, y_filter_lo, 1);
-  sum = vmlaq_lane_s16(sum, s1, y_filter_lo, 2);
-  sum = vmlaq_lane_s16(sum, s2, y_filter_lo, 3);
-  sum = vmlaq_lane_s16(sum, s3, y_filter_hi, 0);
-  sum = vmlaq_lane_s16(sum, s4, y_filter_hi, 1);
-  sum = vmlaq_lane_s16(sum, s5, y_filter_hi, 2);
-
-  return sum;
-}
-
 static INLINE void convolve_y_sr_6tap_neon(const uint8_t *src_ptr,
                                            int src_stride, uint8_t *dst_ptr,
                                            const int dst_stride, int w, int h,
diff --git a/av1/common/arm/convolve_neon.h b/av1/common/arm/convolve_neon.h
index 3f10661..b8eac71 100644
--- a/av1/common/arm/convolve_neon.h
+++ b/av1/common/arm/convolve_neon.h
@@ -382,6 +382,44 @@
   return sum;
 }
 
+static INLINE int16x4_t convolve6_4x4(const int16x4_t s0, const int16x4_t s1,
+                                      const int16x4_t s2, const int16x4_t s3,
+                                      const int16x4_t s4, const int16x4_t s5,
+                                      const int16x8_t y_filter_0_7) {
+  const int16x4_t y_filter_0_3 = vget_low_s16(y_filter_0_7);
+  const int16x4_t y_filter_4_7 = vget_high_s16(y_filter_0_7);
+  int16x4_t sum;
+
+  // Filter values at indices 0 and 7 are 0.
+  sum = vmul_lane_s16(s0, y_filter_0_3, 1);
+  sum = vmla_lane_s16(sum, s1, y_filter_0_3, 2);
+  sum = vmla_lane_s16(sum, s2, y_filter_0_3, 3);
+  sum = vmla_lane_s16(sum, s3, y_filter_4_7, 0);
+  sum = vmla_lane_s16(sum, s4, y_filter_4_7, 1);
+  sum = vmla_lane_s16(sum, s5, y_filter_4_7, 2);
+
+  return sum;
+}
+
+static INLINE int16x8_t convolve6_8x4(const int16x8_t s0, const int16x8_t s1,
+                                      const int16x8_t s2, const int16x8_t s3,
+                                      const int16x8_t s4, const int16x8_t s5,
+                                      const int16x8_t y_filters) {
+  const int16x4_t y_filter_lo = vget_low_s16(y_filters);
+  const int16x4_t y_filter_hi = vget_high_s16(y_filters);
+  int16x8_t sum;
+
+  // Filter values at indices 0 and 7 are 0.
+  sum = vmulq_lane_s16(s0, y_filter_lo, 1);
+  sum = vmlaq_lane_s16(sum, s1, y_filter_lo, 2);
+  sum = vmlaq_lane_s16(sum, s2, y_filter_lo, 3);
+  sum = vmlaq_lane_s16(sum, s3, y_filter_hi, 0);
+  sum = vmlaq_lane_s16(sum, s4, y_filter_hi, 1);
+  sum = vmlaq_lane_s16(sum, s5, y_filter_hi, 2);
+
+  return sum;
+}
+
 static INLINE uint16x4_t convolve6_4_s32(const int16x4_t s0, const int16x4_t s1,
                                          const int16x4_t s2, const int16x4_t s3,
                                          const int16x4_t s4, const int16x4_t s5,
diff --git a/av1/common/arm/jnt_convolve_neon.c b/av1/common/arm/jnt_convolve_neon.c
index 6c1e38e..ae437da 100644
--- a/av1/common/arm/jnt_convolve_neon.c
+++ b/av1/common/arm/jnt_convolve_neon.c
@@ -2053,6 +2053,306 @@
   return vqrshlq_s16(sum, shift_round_0);
 }
 
+static INLINE void dist_wtd_convolve_y_6tap_neon(
+    const uint8_t *src_ptr, int src_stride, uint8_t *dst8_ptr,
+    const int dst8_stride, int w, int h, const int16x8_t y_filter,
+    ConvolveParams *conv_params) {
+  CONV_BUF_TYPE *dst_ptr = conv_params->dst;
+  const int dst_stride = conv_params->dst_stride;
+  const int bits = FILTER_BITS - conv_params->round_0;
+  const int bd = 8;
+  const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
+  const int round_offset = (1 << (offset_bits - conv_params->round_1)) +
+                           (1 << (offset_bits - conv_params->round_1 - 1));
+  const int round_bits =
+      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
+  const uint16_t fwd_offset = conv_params->fwd_offset;
+  const uint16_t bck_offset = conv_params->bck_offset;
+  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
+  const int shift_value = (conv_params->round_1 - 1 - bits);
+
+  // The rounding shift by shift_value below absorbs the multiplication of
+  // the vertical filter output sum by (1 << bits).
+  assert((conv_params->round_1 - 2) >= bits);
+
+  if (w <= 4 || h == 4) {
+    int16x4_t s0, s1, s2, s3, s4, s5, d0;
+    uint16x4_t dd0;
+    uint8x8_t t0 = vdup_n_u8(0);
+    uint8x8_t t1 = vdup_n_u8(0);
+    uint8x8_t t2 = vdup_n_u8(0);
+    int16x8_t tt0, tt1, tt2;
+    uint8x8_t d01;
+#if defined(__aarch64__)
+    int16x4_t s6, s7, s8, d1, d2, d3;
+    uint16x4_t dd1, dd2, dd3;
+    uint8x8_t d23;
+#endif
+
+    const int16x4_t round_offset64 = vdup_n_s16(round_offset);
+    const int16x4_t shift_vec = vdup_n_s16(-shift_value);
+    int width = w;
+
+    do {
+      const uint8_t *s = src_ptr;
+      CONV_BUF_TYPE *d = dst_ptr;
+      uint8_t *d_u8 = dst8_ptr;
+      int height = h;
+
+      load_unaligned_u8_3x8(s, src_stride, &t0, &t1, &t2);
+
+      tt0 = vreinterpretq_s16_u16(vmovl_u8(t0));
+      tt1 = vreinterpretq_s16_u16(vmovl_u8(t1));
+      tt2 = vreinterpretq_s16_u16(vmovl_u8(t2));
+
+      s0 = vget_low_s16(tt0);
+      s1 = vget_high_s16(tt0);
+      s2 = vget_low_s16(tt1);
+      s3 = vget_high_s16(tt1);
+      s4 = vget_low_s16(tt2);
+
+      s += 5 * src_stride;
+      do {
+#if defined(__aarch64__)
+        load_unaligned_u8_4x4(s, src_stride, &t0, &t1);
+
+        tt0 = vreinterpretq_s16_u16(vmovl_u8(t0));
+        tt1 = vreinterpretq_s16_u16(vmovl_u8(t1));
+
+        s5 = vget_low_s16(tt0);
+        s6 = vget_high_s16(tt0);
+        s7 = vget_low_s16(tt1);
+        s8 = vget_high_s16(tt1);
+
+        d0 = convolve6_4x4(s0, s1, s2, s3, s4, s5, y_filter);
+        d1 = convolve6_4x4(s1, s2, s3, s4, s5, s6, y_filter);
+        d2 = convolve6_4x4(s2, s3, s4, s5, s6, s7, y_filter);
+        d3 = convolve6_4x4(s3, s4, s5, s6, s7, s8, y_filter);
+
+        d0 = vqrshl_s16(d0, shift_vec);
+        d1 = vqrshl_s16(d1, shift_vec);
+        d2 = vqrshl_s16(d2, shift_vec);
+        d3 = vqrshl_s16(d3, shift_vec);
+
+        d0 = vadd_s16(d0, round_offset64);
+        d1 = vadd_s16(d1, round_offset64);
+        d2 = vadd_s16(d2, round_offset64);
+        d3 = vadd_s16(d3, round_offset64);
+
+        if (conv_params->do_average) {
+          load_u16_4x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);
+
+          compute_avg_4x4(dd0, dd1, dd2, dd3, vreinterpret_u16_s16(d0),
+                          vreinterpret_u16_s16(d1), vreinterpret_u16_s16(d2),
+                          vreinterpret_u16_s16(d3), fwd_offset, bck_offset,
+                          round_offset64, round_bits, use_dist_wtd_comp_avg,
+                          &d01, &d23);
+
+          store_u8_4x1(d_u8 + 0 * dst8_stride, d01, 0);
+          store_u8_4x1(d_u8 + 1 * dst8_stride, d01, 1);
+          store_u8_4x1(d_u8 + 2 * dst8_stride, d23, 0);
+          store_u8_4x1(d_u8 + 3 * dst8_stride, d23, 1);
+        } else {
+          store_u16_4x4(d, dst_stride, vreinterpret_u16_s16(d0),
+                        vreinterpret_u16_s16(d1), vreinterpret_u16_s16(d2),
+                        vreinterpret_u16_s16(d3));
+        }
+
+        s0 = s4;
+        s1 = s5;
+        s2 = s6;
+        s3 = s7;
+        s4 = s8;
+        s += 4 * src_stride;
+        d += 4 * dst_stride;
+        d_u8 += 4 * dst8_stride;
+        height -= 4;
+#else   // !defined(__aarch64__)
+        t0 = load_unaligned_u8_4x1(s);
+        tt0 = vreinterpretq_s16_u16(vmovl_u8(t0));
+        s5 = vget_low_s16(tt0);
+
+        d0 = convolve6_4x4(s0, s1, s2, s3, s4, s5, y_filter);
+        d0 = vqrshl_s16(d0, shift_vec);
+        d0 = vadd_s16(d0, round_offset64);
+
+        if (conv_params->do_average) {
+          dd0 = vld1_u16(d);
+
+          compute_avg_4x1(dd0, vreinterpret_u16_s16(d0), fwd_offset, bck_offset,
+                          round_offset64, round_bits, use_dist_wtd_comp_avg,
+                          &d01);
+
+          store_u8_4x1(d_u8, d01, 0);
+        } else {
+          vst1_u16(d, vreinterpret_u16_s16(d0));
+        }
+
+        s0 = s1;
+        s1 = s2;
+        s2 = s3;
+        s3 = s4;
+        s4 = s5;
+        s += src_stride;
+        d += dst_stride;
+        d_u8 += dst8_stride;
+        height--;
+#endif  // defined(__aarch64__)
+      } while (height > 0);
+      src_ptr += 4;
+      dst_ptr += 4;
+      dst8_ptr += 4;
+      width -= 4;
+    } while (width > 0);
+  } else {
+    int16x8_t s0, s1, s2, s3, s4, s5, d0;
+    uint16x8_t d8;
+    uint8x8_t t0, t1, t2, t3, t4;
+
+    const int16x8_t round_offset128 = vdupq_n_s16(round_offset);
+    const int16x8_t shift_vec = vdupq_n_s16(-shift_value);
+    const int16x4_t round_offset64 = vdup_n_s16(round_offset);
+#if defined(__aarch64__)
+    int16x8_t s6, s7, s8, s9, s10, s11, s12, d1, d2, d3, d4, d5, d6, d7;
+    uint16x8_t d9, d10, d11;
+    uint8x8_t t5, t6, t7;
+#endif
+    int width = w;
+
+    do {
+      const uint8_t *s = src_ptr + (5 * src_stride);
+      CONV_BUF_TYPE *d = dst_ptr;
+      uint8_t *d_u8 = dst8_ptr;
+      int height = h;
+
+      load_u8_8x5(src_ptr, src_stride, &t0, &t1, &t2, &t3, &t4);
+
+      s0 = vreinterpretq_s16_u16(vmovl_u8(t0));
+      s1 = vreinterpretq_s16_u16(vmovl_u8(t1));
+      s2 = vreinterpretq_s16_u16(vmovl_u8(t2));
+      s3 = vreinterpretq_s16_u16(vmovl_u8(t3));
+      s4 = vreinterpretq_s16_u16(vmovl_u8(t4));
+
+      do {
+#if defined(__aarch64__)
+        load_u8_8x8(s, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
+
+        s5 = vreinterpretq_s16_u16(vmovl_u8(t0));
+        s6 = vreinterpretq_s16_u16(vmovl_u8(t1));
+        s7 = vreinterpretq_s16_u16(vmovl_u8(t2));
+        s8 = vreinterpretq_s16_u16(vmovl_u8(t3));
+        s9 = vreinterpretq_s16_u16(vmovl_u8(t4));
+        s10 = vreinterpretq_s16_u16(vmovl_u8(t5));
+        s11 = vreinterpretq_s16_u16(vmovl_u8(t6));
+        s12 = vreinterpretq_s16_u16(vmovl_u8(t7));
+
+        d0 = convolve6_8x4(s0, s1, s2, s3, s4, s5, y_filter);
+        d1 = convolve6_8x4(s1, s2, s3, s4, s5, s6, y_filter);
+        d2 = convolve6_8x4(s2, s3, s4, s5, s6, s7, y_filter);
+        d3 = convolve6_8x4(s3, s4, s5, s6, s7, s8, y_filter);
+        d4 = convolve6_8x4(s4, s5, s6, s7, s8, s9, y_filter);
+        d5 = convolve6_8x4(s5, s6, s7, s8, s9, s10, y_filter);
+        d6 = convolve6_8x4(s6, s7, s8, s9, s10, s11, y_filter);
+        d7 = convolve6_8x4(s7, s8, s9, s10, s11, s12, y_filter);
+
+        d0 = vqrshlq_s16(d0, shift_vec);
+        d1 = vqrshlq_s16(d1, shift_vec);
+        d2 = vqrshlq_s16(d2, shift_vec);
+        d3 = vqrshlq_s16(d3, shift_vec);
+        d4 = vqrshlq_s16(d4, shift_vec);
+        d5 = vqrshlq_s16(d5, shift_vec);
+        d6 = vqrshlq_s16(d6, shift_vec);
+        d7 = vqrshlq_s16(d7, shift_vec);
+
+        d0 = vaddq_s16(d0, round_offset128);
+        d1 = vaddq_s16(d1, round_offset128);
+        d2 = vaddq_s16(d2, round_offset128);
+        d3 = vaddq_s16(d3, round_offset128);
+        d4 = vaddq_s16(d4, round_offset128);
+        d5 = vaddq_s16(d5, round_offset128);
+        d6 = vaddq_s16(d6, round_offset128);
+        d7 = vaddq_s16(d7, round_offset128);
+
+        if (conv_params->do_average) {
+          load_u16_8x4(d, dst_stride, &d8, &d9, &d10, &d11);
+          d += 4 * dst_stride;
+
+          compute_avg_8x4(d8, d9, d10, d11, vreinterpretq_u16_s16(d0),
+                          vreinterpretq_u16_s16(d1), vreinterpretq_u16_s16(d2),
+                          vreinterpretq_u16_s16(d3), fwd_offset, bck_offset,
+                          round_offset64, round_bits, use_dist_wtd_comp_avg,
+                          &t0, &t1, &t2, &t3);
+
+          store_u8_8x4(d_u8, dst8_stride, t0, t1, t2, t3);
+          d_u8 += 4 * dst8_stride;
+
+          load_u16_8x4(d, dst_stride, &d8, &d9, &d10, &d11);
+          d += 4 * dst_stride;
+
+          compute_avg_8x4(d8, d9, d10, d11, vreinterpretq_u16_s16(d4),
+                          vreinterpretq_u16_s16(d5), vreinterpretq_u16_s16(d6),
+                          vreinterpretq_u16_s16(d7), fwd_offset, bck_offset,
+                          round_offset64, round_bits, use_dist_wtd_comp_avg,
+                          &t0, &t1, &t2, &t3);
+
+          store_u8_8x4(d_u8, dst8_stride, t0, t1, t2, t3);
+          d_u8 += 4 * dst8_stride;
+        } else {
+          store_u16_8x8(d, dst_stride, vreinterpretq_u16_s16(d0),
+                        vreinterpretq_u16_s16(d1), vreinterpretq_u16_s16(d2),
+                        vreinterpretq_u16_s16(d3), vreinterpretq_u16_s16(d4),
+                        vreinterpretq_u16_s16(d5), vreinterpretq_u16_s16(d6),
+                        vreinterpretq_u16_s16(d7));
+          d += 8 * dst_stride;
+        }
+
+        s0 = s8;
+        s1 = s9;
+        s2 = s10;
+        s3 = s11;
+        s4 = s12;
+        s += 8 * src_stride;
+        height -= 8;
+#else   // !defined(__aarch64__)
+        s5 = vreinterpretq_s16_u16(vmovl_u8(vld1_u8(s)));
+
+        d0 = convolve6_8x4(s0, s1, s2, s3, s4, s5, y_filter);
+        d0 = vqrshlq_s16(d0, shift_vec);
+        d0 = vaddq_s16(d0, round_offset128);
+
+        s0 = s1;
+        s1 = s2;
+        s2 = s3;
+        s3 = s4;
+        s4 = s5;
+
+        if (conv_params->do_average) {
+          d8 = vld1q_u16(d);
+          d += dst_stride;
+
+          compute_avg_8x1(d8, vreinterpretq_u16_s16(d0), fwd_offset, bck_offset,
+                          round_offset64, round_bits, use_dist_wtd_comp_avg,
+                          &t0);
+
+          vst1_u8(d_u8, t0);
+          d_u8 += dst8_stride;
+        } else {
+          vst1q_u16(d, vreinterpretq_u16_s16(d0));
+          d += dst_stride;
+        }
+
+        s += src_stride;
+        height--;
+#endif  // defined(__aarch64__)
+      } while (height > 0);
+      src_ptr += 8;
+      dst_ptr += 8;
+      dst8_ptr += 8;
+      width -= 8;
+    } while (width > 0);
+  }
+}
+
 void av1_dist_wtd_convolve_y_neon(const uint8_t *src, int src_stride,
                                   uint8_t *dst8, int dst8_stride, int w, int h,
                                   const InterpFilterParams *filter_params_y,
@@ -2071,6 +2371,12 @@
   const int vert_offset = filter_params_y->taps / 2 - 1;
   const uint8_t *src_ptr = src - (vert_offset * src_stride);
 
+  if (get_filter_tap(filter_params_y, subpel_y_qn) <= 6) {
+    dist_wtd_convolve_y_6tap_neon(src_ptr + src_stride, src_stride, dst8,
+                                  dst8_stride, w, h, y_filter, conv_params);
+    return;
+  }
+
   CONV_BUF_TYPE *dst_ptr = conv_params->dst;
   const int dst_stride = conv_params->dst_stride;
   const int bits = FILTER_BITS - conv_params->round_0;