Specialize HBD Neon Wiener convolution vert. pass by bitdepth

The narrowing shift values used in the vertical pass of Wiener
convolution differ depending on the bitdepth. Since we can eliminate
2 relatively expensive instructions from the inner loop of the
convolution kernel if we supply compile-time constants, specialize
the path by bitdepth. (Bitdepths 8 and 10 use the same shift values
so we only actually need one extra path.)

Change-Id: Iad02a42821bdb1324a10e8c0e7b41af280ccdecf
diff --git a/av1/common/arm/highbd_wiener_convolve_neon.c b/av1/common/arm/highbd_wiener_convolve_neon.c
index aeb4cda..da1af97 100644
--- a/av1/common/arm/highbd_wiener_convolve_neon.c
+++ b/av1/common/arm/highbd_wiener_convolve_neon.c
@@ -78,89 +78,94 @@
 
 #undef HBD_WIENER_7TAP_HORIZ
 
-static INLINE uint16x8_t highbd_wiener_convolve7_8_2d_v(
-    const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
-    const int16x8_t s3, const int16x8_t s4, const int16x8_t s5,
-    const int16x8_t s6, const int16x4_t y_filter, const int32x4_t round_vec,
-    const int32x4_t shift, const uint16x8_t res_max_val) {
-  const int32x2_t y_filter_lo = vget_low_s32(vmovl_s16(y_filter));
-  const int32x2_t y_filter_hi = vget_high_s32(vmovl_s16(y_filter));
-  // Since the Wiener filter is symmetric about the middle tap (tap 3) add
-  // mirrored source elements before multiplying by filter coefficients.
-  int32x4_t s06_lo = vaddl_s16(vget_low_s16(s0), vget_low_s16(s6));
-  int32x4_t s15_lo = vaddl_s16(vget_low_s16(s1), vget_low_s16(s5));
-  int32x4_t s24_lo = vaddl_s16(vget_low_s16(s2), vget_low_s16(s4));
+#define HBD_WIENER_7TAP_VERT(name, shift)                                      \
+  static INLINE uint16x8_t name##_wiener_convolve7_8_2d_v(                     \
+      const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,              \
+      const int16x8_t s3, const int16x8_t s4, const int16x8_t s5,              \
+      const int16x8_t s6, const int16x4_t y_filter, const int32x4_t round_vec, \
+      const uint16x8_t res_max_val) {                                          \
+    const int32x2_t y_filter_lo = vget_low_s32(vmovl_s16(y_filter));           \
+    const int32x2_t y_filter_hi = vget_high_s32(vmovl_s16(y_filter));          \
+    /* Wiener filter is symmetric so add mirrored source elements. */          \
+    int32x4_t s06_lo = vaddl_s16(vget_low_s16(s0), vget_low_s16(s6));          \
+    int32x4_t s15_lo = vaddl_s16(vget_low_s16(s1), vget_low_s16(s5));          \
+    int32x4_t s24_lo = vaddl_s16(vget_low_s16(s2), vget_low_s16(s4));          \
+                                                                               \
+    int32x4_t sum_lo = vmlaq_lane_s32(round_vec, s06_lo, y_filter_lo, 0);      \
+    sum_lo = vmlaq_lane_s32(sum_lo, s15_lo, y_filter_lo, 1);                   \
+    sum_lo = vmlaq_lane_s32(sum_lo, s24_lo, y_filter_hi, 0);                   \
+    sum_lo =                                                                   \
+        vmlaq_lane_s32(sum_lo, vmovl_s16(vget_low_s16(s3)), y_filter_hi, 1);   \
+                                                                               \
+    int32x4_t s06_hi = vaddl_s16(vget_high_s16(s0), vget_high_s16(s6));        \
+    int32x4_t s15_hi = vaddl_s16(vget_high_s16(s1), vget_high_s16(s5));        \
+    int32x4_t s24_hi = vaddl_s16(vget_high_s16(s2), vget_high_s16(s4));        \
+                                                                               \
+    int32x4_t sum_hi = vmlaq_lane_s32(round_vec, s06_hi, y_filter_lo, 0);      \
+    sum_hi = vmlaq_lane_s32(sum_hi, s15_hi, y_filter_lo, 1);                   \
+    sum_hi = vmlaq_lane_s32(sum_hi, s24_hi, y_filter_hi, 0);                   \
+    sum_hi =                                                                   \
+        vmlaq_lane_s32(sum_hi, vmovl_s16(vget_high_s16(s3)), y_filter_hi, 1);  \
+                                                                               \
+    uint16x4_t res_lo = vqrshrun_n_s32(sum_lo, shift);                         \
+    uint16x4_t res_hi = vqrshrun_n_s32(sum_hi, shift);                         \
+                                                                               \
+    return vminq_u16(vcombine_u16(res_lo, res_hi), res_max_val);               \
+  }                                                                            \
+                                                                               \
+  static INLINE void name##_convolve_add_src_vert_hip(                         \
+      const uint16_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr,        \
+      ptrdiff_t dst_stride, int w, int h, const int16x4_t y_filter,            \
+      const int32x4_t round_vec, const uint16x8_t res_max_val) {               \
+    do {                                                                       \
+      const int16_t *s = (int16_t *)src_ptr;                                   \
+      uint16_t *d = dst_ptr;                                                   \
+      int height = h;                                                          \
+                                                                               \
+      while (height > 3) {                                                     \
+        int16x8_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9;                      \
+        load_s16_8x10(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7,   \
+                      &s8, &s9);                                               \
+                                                                               \
+        uint16x8_t d0 = name##_wiener_convolve7_8_2d_v(                        \
+            s0, s1, s2, s3, s4, s5, s6, y_filter, round_vec, res_max_val);     \
+        uint16x8_t d1 = name##_wiener_convolve7_8_2d_v(                        \
+            s1, s2, s3, s4, s5, s6, s7, y_filter, round_vec, res_max_val);     \
+        uint16x8_t d2 = name##_wiener_convolve7_8_2d_v(                        \
+            s2, s3, s4, s5, s6, s7, s8, y_filter, round_vec, res_max_val);     \
+        uint16x8_t d3 = name##_wiener_convolve7_8_2d_v(                        \
+            s3, s4, s5, s6, s7, s8, s9, y_filter, round_vec, res_max_val);     \
+                                                                               \
+        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);                          \
+                                                                               \
+        s += 4 * src_stride;                                                   \
+        d += 4 * dst_stride;                                                   \
+        height -= 4;                                                           \
+      }                                                                        \
+                                                                               \
+      while (height-- != 0) {                                                  \
+        int16x8_t s0, s1, s2, s3, s4, s5, s6;                                  \
+        load_s16_8x7(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6);        \
+                                                                               \
+        uint16x8_t d0 = name##_wiener_convolve7_8_2d_v(                        \
+            s0, s1, s2, s3, s4, s5, s6, y_filter, round_vec, res_max_val);     \
+                                                                               \
+        vst1q_u16(d, d0);                                                      \
+                                                                               \
+        s += src_stride;                                                       \
+        d += dst_stride;                                                       \
+      }                                                                        \
+                                                                               \
+      src_ptr += 8;                                                            \
+      dst_ptr += 8;                                                            \
+      w -= 8;                                                                  \
+    } while (w != 0);                                                          \
+  }
 
-  int32x4_t sum_lo = vmlaq_lane_s32(round_vec, s06_lo, y_filter_lo, 0);
-  sum_lo = vmlaq_lane_s32(sum_lo, s15_lo, y_filter_lo, 1);
-  sum_lo = vmlaq_lane_s32(sum_lo, s24_lo, y_filter_hi, 0);
-  sum_lo = vmlaq_lane_s32(sum_lo, vmovl_s16(vget_low_s16(s3)), y_filter_hi, 1);
+HBD_WIENER_7TAP_VERT(highbd, 2 * FILTER_BITS - WIENER_ROUND0_BITS)
+HBD_WIENER_7TAP_VERT(highbd_12, 2 * FILTER_BITS - WIENER_ROUND0_BITS - 2)
 
-  int32x4_t s06_hi = vaddl_s16(vget_high_s16(s0), vget_high_s16(s6));
-  int32x4_t s15_hi = vaddl_s16(vget_high_s16(s1), vget_high_s16(s5));
-  int32x4_t s24_hi = vaddl_s16(vget_high_s16(s2), vget_high_s16(s4));
-
-  int32x4_t sum_hi = vmlaq_lane_s32(round_vec, s06_hi, y_filter_lo, 0);
-  sum_hi = vmlaq_lane_s32(sum_hi, s15_hi, y_filter_lo, 1);
-  sum_hi = vmlaq_lane_s32(sum_hi, s24_hi, y_filter_hi, 0);
-  sum_hi = vmlaq_lane_s32(sum_hi, vmovl_s16(vget_high_s16(s3)), y_filter_hi, 1);
-
-  sum_lo = vqrshlq_s32(sum_lo, shift);
-  sum_hi = vqrshlq_s32(sum_hi, shift);
-
-  uint16x8_t res = vcombine_u16(vqmovun_s32(sum_lo), vqmovun_s32(sum_hi));
-  return vminq_u16(res, res_max_val);
-}
-
-static INLINE void highbd_convolve_add_src_vert_hip(
-    const uint16_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr,
-    ptrdiff_t dst_stride, int w, int h, const int16x4_t y_filter,
-    const int32x4_t round_vec, const int32x4_t shift,
-    const uint16x8_t res_max_val) {
-  do {
-    const int16_t *s = (int16_t *)src_ptr;
-    uint16_t *d = dst_ptr;
-    int height = h;
-
-    while (height > 3) {
-      int16x8_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9;
-      load_s16_8x10(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7, &s8,
-                    &s9);
-
-      uint16x8_t d0 = highbd_wiener_convolve7_8_2d_v(
-          s0, s1, s2, s3, s4, s5, s6, y_filter, round_vec, shift, res_max_val);
-      uint16x8_t d1 = highbd_wiener_convolve7_8_2d_v(
-          s1, s2, s3, s4, s5, s6, s7, y_filter, round_vec, shift, res_max_val);
-      uint16x8_t d2 = highbd_wiener_convolve7_8_2d_v(
-          s2, s3, s4, s5, s6, s7, s8, y_filter, round_vec, shift, res_max_val);
-      uint16x8_t d3 = highbd_wiener_convolve7_8_2d_v(
-          s3, s4, s5, s6, s7, s8, s9, y_filter, round_vec, shift, res_max_val);
-
-      store_u16_8x4(d, dst_stride, d0, d1, d2, d3);
-
-      s += 4 * src_stride;
-      d += 4 * dst_stride;
-      height -= 4;
-    }
-
-    while (height-- != 0) {
-      int16x8_t s0, s1, s2, s3, s4, s5, s6;
-      load_s16_8x7(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6);
-
-      uint16x8_t d0 = highbd_wiener_convolve7_8_2d_v(
-          s0, s1, s2, s3, s4, s5, s6, y_filter, round_vec, shift, res_max_val);
-
-      vst1q_u16(d, d0);
-
-      s += src_stride;
-      d += dst_stride;
-    }
-
-    src_ptr += 8;
-    dst_ptr += 8;
-    w -= 8;
-  } while (w != 0);
-}
+#undef HBD_WIENER_7TAP_VERT
 
 void av1_highbd_wiener_convolve_add_src_neon(
     const uint8_t *src8, ptrdiff_t src_stride, uint8_t *dst8,
@@ -195,7 +200,6 @@
   const int32x4_t horiz_round_vec = vdupq_n_s32(1 << (bd + FILTER_BITS - 1));
 
   const uint16x8_t res_max_val = vdupq_n_u16((1 << bd) - 1);
-  const int32x4_t vert_shift = vdupq_n_s32(-conv_params->round_1);
   const int32x4_t vert_round_vec =
       vdupq_n_s32(-(1 << (bd + conv_params->round_1 - 1)));
 
@@ -206,13 +210,14 @@
     highbd_12_convolve_add_src_horiz_hip(
         src - horiz_offset - vert_offset, src_stride, im_block, im_stride, w,
         im_h, x_filter_s16, horiz_round_vec, im_max_val);
+    highbd_12_convolve_add_src_vert_hip(im_block, im_stride, dst, dst_stride, w,
+                                        h, y_filter_s16, vert_round_vec,
+                                        res_max_val);
   } else {
     highbd_convolve_add_src_horiz_hip(
         src - horiz_offset - vert_offset, src_stride, im_block, im_stride, w,
         im_h, x_filter_s16, horiz_round_vec, im_max_val);
+    highbd_convolve_add_src_vert_hip(im_block, im_stride, dst, dst_stride, w, h,
+                                     y_filter_s16, vert_round_vec, res_max_val);
   }
-
-  highbd_convolve_add_src_vert_hip(im_block, im_stride, dst, dst_stride, w, h,
-                                   y_filter_s16, vert_round_vec, vert_shift,
-                                   res_max_val);
 }