Optimize HBD Neon vertical Wiener convolve for 5-tap filters

Wiener convolution filters can be either 5 or 7 taps. The current Neon
implementation pads 5-tap filters with zeros to treat them as 7-tap
filters. This patch adds a Neon path specialised for 5-tap vertical
Wiener filters, avoiding the redundant work of using the 7-tap path
with a 5-tap filter.

Change-Id: Ic081be65d7a9e6b8c8711cf9778dd6ce4654336d
diff --git a/av1/common/arm/highbd_wiener_convolve_neon.c b/av1/common/arm/highbd_wiener_convolve_neon.c
index 588b4f8..a6bd6d3 100644
--- a/av1/common/arm/highbd_wiener_convolve_neon.c
+++ b/av1/common/arm/highbd_wiener_convolve_neon.c
@@ -136,6 +136,90 @@
 
 #undef HBD_WIENER_7TAP_HORIZ
 
+#define HBD_WIENER_5TAP_VERT(name, shift)                                     \
+  static INLINE uint16x8_t name##_wiener_convolve5_8_2d_v(                    \
+      const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,             \
+      const int16x8_t s3, const int16x8_t s4, const int16x4_t y_filter,       \
+      const int32x4_t round_vec, const uint16x8_t res_max_val) {              \
+    const int32x2_t y_filter_lo = vget_low_s32(vmovl_s16(y_filter));          \
+    const int32x2_t y_filter_hi = vget_high_s32(vmovl_s16(y_filter));         \
+    /* Wiener filter is symmetric so add mirrored source elements. */         \
+    int32x4_t s04_lo = vaddl_s16(vget_low_s16(s0), vget_low_s16(s4));         \
+    int32x4_t s13_lo = vaddl_s16(vget_low_s16(s1), vget_low_s16(s3));         \
+                                                                              \
+    /* y_filter[0] = 0. (5-tap filters are 0-padded to 7 taps.) */            \
+    int32x4_t sum_lo = vmlaq_lane_s32(round_vec, s04_lo, y_filter_lo, 1);     \
+    sum_lo = vmlaq_lane_s32(sum_lo, s13_lo, y_filter_hi, 0);                  \
+    sum_lo =                                                                  \
+        vmlaq_lane_s32(sum_lo, vmovl_s16(vget_low_s16(s2)), y_filter_hi, 1);  \
+                                                                              \
+    int32x4_t s04_hi = vaddl_s16(vget_high_s16(s0), vget_high_s16(s4));       \
+    int32x4_t s13_hi = vaddl_s16(vget_high_s16(s1), vget_high_s16(s3));       \
+                                                                              \
+    int32x4_t sum_hi = vmlaq_lane_s32(round_vec, s04_hi, y_filter_lo, 1);     \
+    sum_hi = vmlaq_lane_s32(sum_hi, s13_hi, y_filter_hi, 0);                  \
+    sum_hi =                                                                  \
+        vmlaq_lane_s32(sum_hi, vmovl_s16(vget_high_s16(s2)), y_filter_hi, 1); \
+                                                                              \
+    uint16x4_t res_lo = vqrshrun_n_s32(sum_lo, shift);                        \
+    uint16x4_t res_hi = vqrshrun_n_s32(sum_hi, shift);                        \
+                                                                              \
+    return vminq_u16(vcombine_u16(res_lo, res_hi), res_max_val);              \
+  }                                                                           \
+                                                                              \
+  static INLINE void name##_convolve_add_src_5tap_vert(                       \
+      const uint16_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr,       \
+      ptrdiff_t dst_stride, int w, int h, const int16x4_t y_filter,           \
+      const int32x4_t round_vec, const uint16x8_t res_max_val) {              \
+    do {                                                                      \
+      const int16_t *s = (int16_t *)src_ptr;                                  \
+      uint16_t *d = dst_ptr;                                                  \
+      int height = h;                                                         \
+                                                                              \
+      while (height > 3) {                                                    \
+        int16x8_t s0, s1, s2, s3, s4, s5, s6, s7;                             \
+        load_s16_8x8(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);  \
+                                                                              \
+        uint16x8_t d0 = name##_wiener_convolve5_8_2d_v(                       \
+            s0, s1, s2, s3, s4, y_filter, round_vec, res_max_val);            \
+        uint16x8_t d1 = name##_wiener_convolve5_8_2d_v(                       \
+            s1, s2, s3, s4, s5, y_filter, round_vec, res_max_val);            \
+        uint16x8_t d2 = name##_wiener_convolve5_8_2d_v(                       \
+            s2, s3, s4, s5, s6, y_filter, round_vec, res_max_val);            \
+        uint16x8_t d3 = name##_wiener_convolve5_8_2d_v(                       \
+            s3, s4, s5, s6, s7, y_filter, round_vec, res_max_val);            \
+                                                                              \
+        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);                         \
+                                                                              \
+        s += 4 * src_stride;                                                  \
+        d += 4 * dst_stride;                                                  \
+        height -= 4;                                                          \
+      }                                                                       \
+                                                                              \
+      while (height-- != 0) {                                                 \
+        int16x8_t s0, s1, s2, s3, s4;                                         \
+        load_s16_8x5(s, src_stride, &s0, &s1, &s2, &s3, &s4);                 \
+                                                                              \
+        uint16x8_t d0 = name##_wiener_convolve5_8_2d_v(                       \
+            s0, s1, s2, s3, s4, y_filter, round_vec, res_max_val);            \
+                                                                              \
+        vst1q_u16(d, d0);                                                     \
+                                                                              \
+        s += src_stride;                                                      \
+        d += dst_stride;                                                      \
+      }                                                                       \
+                                                                              \
+      src_ptr += 8;                                                           \
+      dst_ptr += 8;                                                           \
+      w -= 8;                                                                 \
+    } while (w != 0);                                                         \
+  }
+
+HBD_WIENER_5TAP_VERT(highbd, 2 * FILTER_BITS - WIENER_ROUND0_BITS)
+HBD_WIENER_5TAP_VERT(highbd_12, 2 * FILTER_BITS - WIENER_ROUND0_BITS - 2)
+
+#undef HBD_WIENER_5TAP_VERT
+
 #define HBD_WIENER_7TAP_VERT(name, shift)                                      \
   static INLINE uint16x8_t name##_wiener_convolve7_8_2d_v(                     \
       const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,              \
@@ -171,7 +255,7 @@
     return vminq_u16(vcombine_u16(res_lo, res_hi), res_max_val);               \
   }                                                                            \
                                                                                \
-  static INLINE void name##_convolve_add_src_vert_hip(                         \
+  static INLINE void name##_convolve_add_src_7tap_vert(                        \
       const uint16_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr,        \
       ptrdiff_t dst_stride, int w, int h, const int16x4_t y_filter,            \
       const int32x4_t round_vec, const uint16x8_t res_max_val) {               \
@@ -250,6 +334,7 @@
                   im_block[(MAX_SB_SIZE + WIENER_WIN - 1) * MAX_SB_SIZE]);
 
   const int x_filter_taps = get_wiener_filter_taps(x_filter);
+  const int y_filter_taps = get_wiener_filter_taps(y_filter);
   int16x4_t x_filter_s16 = vld1_s16(x_filter);
   int16x4_t y_filter_s16 = vld1_s16(y_filter);
   // Add 128 to tap 3. (Needed for rounding.)
@@ -257,9 +342,9 @@
   y_filter_s16 = vadd_s16(y_filter_s16, vcreate_s16(128ULL << 48));
 
   const int im_stride = MAX_SB_SIZE;
-  const int im_h = h + WIENER_WIN - 1;
+  const int im_h = h + y_filter_taps - 1;
   const int horiz_offset = x_filter_taps / 2;
-  const int vert_offset = WIENER_HALFWIN * (int)src_stride;
+  const int vert_offset = (y_filter_taps / 2) * (int)src_stride;
 
   const int extraprec_clamp_limit =
       WIENER_CLAMP_LIMIT(conv_params->round_0, bd);
@@ -284,9 +369,16 @@
           im_h, x_filter_s16, horiz_round_vec, im_max_val);
     }
 
-    highbd_12_convolve_add_src_vert_hip(im_block, im_stride, dst, dst_stride, w,
-                                        h, y_filter_s16, vert_round_vec,
-                                        res_max_val);
+    if (y_filter_taps == WIENER_WIN_REDUCED) {
+      highbd_12_convolve_add_src_5tap_vert(im_block, im_stride, dst, dst_stride,
+                                           w, h, y_filter_s16, vert_round_vec,
+                                           res_max_val);
+    } else {
+      highbd_12_convolve_add_src_7tap_vert(im_block, im_stride, dst, dst_stride,
+                                           w, h, y_filter_s16, vert_round_vec,
+                                           res_max_val);
+    }
+
   } else {
     if (x_filter_taps == WIENER_WIN_REDUCED) {
       highbd_convolve_add_src_5tap_horiz(
@@ -298,7 +390,14 @@
           im_h, x_filter_s16, horiz_round_vec, im_max_val);
     }
 
-    highbd_convolve_add_src_vert_hip(im_block, im_stride, dst, dst_stride, w, h,
-                                     y_filter_s16, vert_round_vec, res_max_val);
+    if (y_filter_taps == WIENER_WIN_REDUCED) {
+      highbd_convolve_add_src_5tap_vert(im_block, im_stride, dst, dst_stride, w,
+                                        h, y_filter_s16, vert_round_vec,
+                                        res_max_val);
+    } else {
+      highbd_convolve_add_src_7tap_vert(im_block, im_stride, dst, dst_stride, w,
+                                        h, y_filter_s16, vert_round_vec,
+                                        res_max_val);
+    }
   }
 }