Refactor and optimize HBD Neon Wiener convolution

Refactor the high bitdepth Neon path for Wiener convolution. The
biggest change is removing a needless gather-load and subsequent
transpose. Additionally make use of the fact that Wiener filters are
symmetrical, adding mirrored source elements to reduce the number of
multiply-accumulate instructions.

Change-Id: Ifb48f14baca2dd31d2b8bab602980e0f2329f1c5
diff --git a/av1/common/arm/highbd_convolve_neon.h b/av1/common/arm/highbd_convolve_neon.h
index b534358..08b2bda 100644
--- a/av1/common/arm/highbd_convolve_neon.h
+++ b/av1/common/arm/highbd_convolve_neon.h
@@ -145,40 +145,4 @@
   return vqmovun_s32(sum);
 }
 
-static INLINE int32x4_t highbd_convolve8_horiz4x8_s32(
-    const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
-    const int16x8_t s3, const int16x8_t x_filter_0_7, const int32x4_t offset) {
-  int16x4_t s_lo[] = { vget_low_s16(s0), vget_low_s16(s1), vget_low_s16(s2),
-                       vget_low_s16(s3) };
-  int16x4_t s_hi[] = { vget_high_s16(s0), vget_high_s16(s1), vget_high_s16(s2),
-                       vget_high_s16(s3) };
-
-  transpose_array_inplace_u16_4x4((uint16x4_t *)s_lo);
-  transpose_array_inplace_u16_4x4((uint16x4_t *)s_hi);
-  const int16x4_t x_filter_0_3 = vget_low_s16(x_filter_0_7);
-  const int16x4_t x_filter_4_7 = vget_high_s16(x_filter_0_7);
-
-  int32x4_t sum = vmlal_lane_s16(offset, s_lo[0], x_filter_0_3, 0);
-  sum = vmlal_lane_s16(sum, s_lo[1], x_filter_0_3, 1);
-  sum = vmlal_lane_s16(sum, s_lo[2], x_filter_0_3, 2);
-  sum = vmlal_lane_s16(sum, s_lo[3], x_filter_0_3, 3);
-  sum = vmlal_lane_s16(sum, s_hi[0], x_filter_4_7, 0);
-  sum = vmlal_lane_s16(sum, s_hi[1], x_filter_4_7, 1);
-  sum = vmlal_lane_s16(sum, s_hi[2], x_filter_4_7, 2);
-  sum = vmlal_lane_s16(sum, s_hi[3], x_filter_4_7, 3);
-
-  return sum;
-}
-
-static INLINE uint16x4_t highbd_convolve8_horiz4x8_s32_s16(
-    const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
-    const int16x8_t s3, const int16x8_t x_filters_0_7,
-    const int32x4_t shift_s32, const int32x4_t offset) {
-  int32x4_t sum =
-      highbd_convolve8_horiz4x8_s32(s0, s1, s2, s3, x_filters_0_7, offset);
-
-  sum = vqrshlq_s32(sum, shift_s32);
-  return vqmovun_s32(sum);
-}
-
 #endif  // AOM_AV1_COMMON_ARM_HIGHBD_CONVOLVE_NEON_H_
diff --git a/av1/common/arm/highbd_wiener_convolve_neon.c b/av1/common/arm/highbd_wiener_convolve_neon.c
index 7ceaffb..4cec1b2 100644
--- a/av1/common/arm/highbd_wiener_convolve_neon.c
+++ b/av1/common/arm/highbd_wiener_convolve_neon.c
@@ -10,198 +10,198 @@
  */
 
 #include <arm_neon.h>
+#include <assert.h>
 
+#include "aom_dsp/arm/mem_neon.h"
+#include "av1/common/convolve.h"
 #include "config/aom_config.h"
 #include "config/av1_rtcd.h"
 
-#include "aom_dsp/arm/mem_neon.h"
-#include "aom_dsp/arm/transpose_neon.h"
-#include "av1/common/convolve.h"
-#include "av1/common/arm/highbd_convolve_neon.h"
+static INLINE uint16x8_t highbd_wiener_convolve7_8_2d_h(
+    const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
+    const int16x8_t s3, const int16x8_t s4, const int16x8_t s5,
+    const int16x8_t s6, const int16x4_t x_filter, const int32x4_t round_vec,
+    const int32x4_t shift, const uint16x8_t im_max_val) {
+  // Since the Wiener filter is symmetric about the middle tap (tap 3) add
+  // mirrored source elements before multiplying by filter coefficients.
+  int16x8_t s06 = vaddq_s16(s0, s6);
+  int16x8_t s15 = vaddq_s16(s1, s5);
+  int16x8_t s24 = vaddq_s16(s2, s4);
 
-static void highbd_convolve_add_src_horiz_hip(
-    const uint16_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr,
-    ptrdiff_t dst_stride, const int16_t *x_filter_ptr, int x_step_q4, int w,
-    int h, int round0_bits, int bd) {
-  const int extraprec_clamp_limit = WIENER_CLAMP_LIMIT(round0_bits, bd);
+  int32x4_t sum_lo = vmlal_lane_s16(round_vec, vget_low_s16(s06), x_filter, 0);
+  sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s15), x_filter, 1);
+  sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s24), x_filter, 2);
+  sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s3), x_filter, 3);
 
-  static const int32_t kIdx[4] = { 0, 1, 2, 3 };
-  const int32x4_t idx = vld1q_s32(kIdx);
-  const int32x4_t shift_s32 = vdupq_n_s32(-round0_bits);
-  const uint16x4_t max = vdup_n_u16(extraprec_clamp_limit - 1);
-  const int32x4_t rounding0 = vdupq_n_s32(1 << (bd + FILTER_BITS - 1));
+  int32x4_t sum_hi = vmlal_lane_s16(round_vec, vget_high_s16(s06), x_filter, 0);
+  sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s15), x_filter, 1);
+  sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s24), x_filter, 2);
+  sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s3), x_filter, 3);
 
-  int height = h;
-  do {
-    int width = w;
-    int x_q4 = 0;
-    uint16_t *d = dst_ptr;
-    const uint16_t *s = src_ptr;
+  sum_lo = vqrshlq_s32(sum_lo, shift);
+  sum_hi = vqrshlq_s32(sum_hi, shift);
 
-    do {
-      // Load 4 src vectors at a time, they might be the same, but we have to
-      // calculate the indices anyway. Doing it in SIMD and then storing the
-      // indices is faster than having to calculate the expression
-      // &src_ptr[((x_q4 + i*x_step_q4) >> SUBPEL_BITS)] 4 times
-      // Ideally this should be a gather using the indices, but NEON does not
-      // have that, so have to emulate
-      const int32x4_t xq4_idx = vmlaq_n_s32(vdupq_n_s32(x_q4), idx, x_step_q4);
-      // We have to multiply x2 to get the actual pointer as sizeof(uint16_t)
-      // = 2
-      const int32x4_t src_idx =
-          vshlq_n_s32(vshrq_n_s32(xq4_idx, SUBPEL_BITS), 1);
-
-#if AOM_ARCH_AARCH64
-      uint64x2_t tmp4[2];
-      tmp4[0] = vreinterpretq_u64_s64(
-          vaddw_s32(vdupq_n_s64((const int64_t)s), vget_low_s32(src_idx)));
-      tmp4[1] = vreinterpretq_u64_s64(
-          vaddw_s32(vdupq_n_s64((const int64_t)s), vget_high_s32(src_idx)));
-      int16_t *src4_ptr[4];
-      uint64_t *tmp_ptr = (uint64_t *)&src4_ptr;
-      vst1q_u64(tmp_ptr, tmp4[0]);
-      vst1q_u64(tmp_ptr + 2, tmp4[1]);
-#else
-      uint32x4_t tmp4;
-      tmp4 = vreinterpretq_u32_s32(
-          vaddq_s32(vdupq_n_s32((const int32_t)s), src_idx));
-      int16_t *src4_ptr[4];
-      uint32_t *tmp_ptr = (uint32_t *)&src4_ptr;
-      vst1q_u32(tmp_ptr, tmp4);
-#endif  // AOM_ARCH_AARCH64
-      // Load source
-      int16x8_t s0 = vld1q_s16(src4_ptr[0]);
-      int16x8_t s1 = vld1q_s16(src4_ptr[1]);
-      int16x8_t s2 = vld1q_s16(src4_ptr[2]);
-      int16x8_t s3 = vld1q_s16(src4_ptr[3]);
-
-      // Actually load the filters
-      const int16x8_t x_filter = vld1q_s16(x_filter_ptr);
-
-      const int16_t *rounding_ptr = (const int16_t *)src4_ptr[0];
-      int16x4_t rounding_s16 = vld1_s16(&rounding_ptr[SUBPEL_TAPS / 2 - 1]);
-      int32x4_t rounding = vshlq_n_s32(vmovl_s16(rounding_s16), FILTER_BITS);
-      rounding = vaddq_s32(rounding, rounding0);
-
-      uint16x4_t d0 = highbd_convolve8_horiz4x8_s32_s16(
-          s0, s1, s2, s3, x_filter, shift_s32, rounding);
-      d0 = vmin_u16(d0, max);
-      vst1_u16(d, d0);
-
-      x_q4 += 4 * x_step_q4;
-      d += 4;
-      width -= 4;
-    } while (width > 0);
-
-    src_ptr += src_stride;
-    dst_ptr += dst_stride;
-    height--;
-  } while (height > 0);
+  uint16x8_t res = vcombine_u16(vqmovun_s32(sum_lo), vqmovun_s32(sum_hi));
+  return vminq_u16(res, im_max_val);
 }
 
-static void highbd_convolve_add_src_vert_hip(
+static INLINE void highbd_convolve_add_src_horiz_hip(
     const uint16_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr,
-    ptrdiff_t dst_stride, const int16_t *y_filter_ptr, int y_step_q4, int w,
-    int h, int round1_bits, int bd) {
-  static const int32_t kIdx[4] = { 0, 1, 2, 3 };
-  const int32x4_t idx = vld1q_s32(kIdx);
-  const int32x4_t shift_s32 = vdupq_n_s32(-round1_bits);
-  const uint16x4_t max = vdup_n_u16((1 << bd) - 1);
-  const int32x4_t rounding0 = vdupq_n_s32(1 << (bd + round1_bits - 1));
-
-  int width = w;
+    ptrdiff_t dst_stride, int w, int h, const int16x4_t x_filter,
+    const int32x4_t round_vec, const int32x4_t shift,
+    const uint16x8_t im_max_val) {
   do {
-    int height = h;
-    int y_q4 = 0;
+    const int16_t *s = (int16_t *)src_ptr;
     uint16_t *d = dst_ptr;
-    const uint16_t *s = src_ptr;
+    int width = w;
 
     do {
-      // Load 4 src vectors at a time, they might be the same, but we have to
-      // calculate the indices anyway. Doing it in SIMD and then storing the
-      // indices is faster than having to calculate the expression
-      // &src_ptr[((x_q4 + i*x_step_q4) >> SUBPEL_BITS)] 4 times
-      // Ideally this should be a gather using the indices, but NEON does not
-      // have that, so have to emulate
-      const int32x4_t yq4_idx = vmlaq_n_s32(vdupq_n_s32(y_q4), idx, y_step_q4);
-      // We have to multiply x2 to get the actual pointer as sizeof(uint16_t)
-      // = 2
-      const int32x4_t src_idx =
-          vshlq_n_s32(vshrq_n_s32(yq4_idx, SUBPEL_BITS), 1);
-#if AOM_ARCH_AARCH64
-      uint64x2_t tmp4[2];
-      tmp4[0] = vreinterpretq_u64_s64(
-          vaddw_s32(vdupq_n_s64((const int64_t)s), vget_low_s32(src_idx)));
-      tmp4[1] = vreinterpretq_u64_s64(
-          vaddw_s32(vdupq_n_s64((const int64_t)s), vget_high_s32(src_idx)));
-      const int16_t *src4_ptr[4];
-      uint64_t *tmp_ptr = (uint64_t *)&src4_ptr;
-      vst1q_u64(tmp_ptr, tmp4[0]);
-      vst1q_u64(tmp_ptr + 2, tmp4[1]);
-#else
-      uint32x4_t tmp4;
-      tmp4 = vreinterpretq_u32_s32(
-          vaddq_s32(vdupq_n_s32((const int32_t)s), src_idx));
-      int16_t *src4_ptr[4];
-      uint32_t *tmp_ptr = (uint32_t *)&src4_ptr;
-      vst1q_u32(tmp_ptr, tmp4);
-#endif  // AOM_ARCH_AARCH64
+      int16x8_t s0, s1, s2, s3, s4, s5, s6;
+      load_s16_8x7(s, 1, &s0, &s1, &s2, &s3, &s4, &s5, &s6);
 
-      // Load source
-      int16x4_t s0, s1, s2, s3, s4, s5, s6, s7;
-      load_s16_4x8(src4_ptr[0], src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6,
-                   &s7);
+      uint16x8_t d0 = highbd_wiener_convolve7_8_2d_h(
+          s0, s1, s2, s3, s4, s5, s6, x_filter, round_vec, shift, im_max_val);
 
-      // Actually load the filters
-      const int16x8_t y_filter = vld1q_s16(y_filter_ptr);
+      vst1q_u16(d, d0);
 
-      const int16_t *rounding_ptr = (const int16_t *)src4_ptr[0];
-      int16x4_t rounding_s16 =
-          vld1_s16(&rounding_ptr[(SUBPEL_TAPS / 2 - 1) * src_stride]);
-      int32x4_t rounding = vshlq_n_s32(vmovl_s16(rounding_s16), FILTER_BITS);
-      rounding = vsubq_s32(rounding, rounding0);
+      s += 8;
+      d += 8;
+      width -= 8;
+    } while (width != 0);
+    src_ptr += src_stride;
+    dst_ptr += dst_stride;
+  } while (--h != 0);
+}
 
-      // Run the convolution
-      uint16x4_t d0 = highbd_convolve8_4_sr_s32_s16(
-          s0, s1, s2, s3, s4, s5, s6, s7, y_filter, shift_s32, rounding);
-      d0 = vmin_u16(d0, max);
-      vst1_u16(d, d0);
+static INLINE uint16x8_t highbd_wiener_convolve7_8_2d_v(
+    const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
+    const int16x8_t s3, const int16x8_t s4, const int16x8_t s5,
+    const int16x8_t s6, const int16x4_t y_filter, const int32x4_t round_vec,
+    const int32x4_t shift, const uint16x8_t res_max_val) {
+  const int32x2_t y_filter_lo = vget_low_s32(vmovl_s16(y_filter));
+  const int32x2_t y_filter_hi = vget_high_s32(vmovl_s16(y_filter));
+  // Since the Wiener filter is symmetric about the middle tap (tap 3) add
+  // mirrored source elements before multiplying by filter coefficients.
+  int32x4_t s06_lo = vaddl_s16(vget_low_s16(s0), vget_low_s16(s6));
+  int32x4_t s15_lo = vaddl_s16(vget_low_s16(s1), vget_low_s16(s5));
+  int32x4_t s24_lo = vaddl_s16(vget_low_s16(s2), vget_low_s16(s4));
+
+  int32x4_t sum_lo = vmlaq_lane_s32(round_vec, s06_lo, y_filter_lo, 0);
+  sum_lo = vmlaq_lane_s32(sum_lo, s15_lo, y_filter_lo, 1);
+  sum_lo = vmlaq_lane_s32(sum_lo, s24_lo, y_filter_hi, 0);
+  sum_lo = vmlaq_lane_s32(sum_lo, vmovl_s16(vget_low_s16(s3)), y_filter_hi, 1);
+
+  int32x4_t s06_hi = vaddl_s16(vget_high_s16(s0), vget_high_s16(s6));
+  int32x4_t s15_hi = vaddl_s16(vget_high_s16(s1), vget_high_s16(s5));
+  int32x4_t s24_hi = vaddl_s16(vget_high_s16(s2), vget_high_s16(s4));
+
+  int32x4_t sum_hi = vmlaq_lane_s32(round_vec, s06_hi, y_filter_lo, 0);
+  sum_hi = vmlaq_lane_s32(sum_hi, s15_hi, y_filter_lo, 1);
+  sum_hi = vmlaq_lane_s32(sum_hi, s24_hi, y_filter_hi, 0);
+  sum_hi = vmlaq_lane_s32(sum_hi, vmovl_s16(vget_high_s16(s3)), y_filter_hi, 1);
+
+  sum_lo = vqrshlq_s32(sum_lo, shift);
+  sum_hi = vqrshlq_s32(sum_hi, shift);
+
+  uint16x8_t res = vcombine_u16(vqmovun_s32(sum_lo), vqmovun_s32(sum_hi));
+  return vminq_u16(res, res_max_val);
+}
+
+static INLINE void highbd_convolve_add_src_vert_hip(
+    const uint16_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr,
+    ptrdiff_t dst_stride, int w, int h, const int16x4_t y_filter,
+    const int32x4_t round_vec, const int32x4_t shift,
+    const uint16x8_t res_max_val) {
+  do {
+    const int16_t *s = (int16_t *)src_ptr;
+    uint16_t *d = dst_ptr;
+    int height = h;
+
+    while (height > 3) {
+      int16x8_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9;
+      load_s16_8x10(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7, &s8,
+                    &s9);
+
+      uint16x8_t d0 = highbd_wiener_convolve7_8_2d_v(
+          s0, s1, s2, s3, s4, s5, s6, y_filter, round_vec, shift, res_max_val);
+      uint16x8_t d1 = highbd_wiener_convolve7_8_2d_v(
+          s1, s2, s3, s4, s5, s6, s7, y_filter, round_vec, shift, res_max_val);
+      uint16x8_t d2 = highbd_wiener_convolve7_8_2d_v(
+          s2, s3, s4, s5, s6, s7, s8, y_filter, round_vec, shift, res_max_val);
+      uint16x8_t d3 = highbd_wiener_convolve7_8_2d_v(
+          s3, s4, s5, s6, s7, s8, s9, y_filter, round_vec, shift, res_max_val);
+
+      store_u16_8x4(d, dst_stride, d0, d1, d2, d3);
+
+      s += 4 * src_stride;
+      d += 4 * dst_stride;
+      height -= 4;
+    }
+
+    while (height-- != 0) {
+      int16x8_t s0, s1, s2, s3, s4, s5, s6;
+      load_s16_8x7(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6);
+
+      uint16x8_t d0 = highbd_wiener_convolve7_8_2d_v(
+          s0, s1, s2, s3, s4, s5, s6, y_filter, round_vec, shift, res_max_val);
+
+      vst1q_u16(d, d0);
 
       s += src_stride;
       d += dst_stride;
-      height--;
-    } while (height > 0);
+    }
 
-    y_q4 += 4 * y_step_q4;
-    src_ptr += 4;
-    dst_ptr += 4;
-    width -= 4;
-  } while (width > 0);
+    src_ptr += 8;
+    dst_ptr += 8;
+    w -= 8;
+  } while (w != 0);
 }
 
-#define WIENER_MAX_EXT_SIZE 263
-
 void av1_highbd_wiener_convolve_add_src_neon(
     const uint8_t *src8, ptrdiff_t src_stride, uint8_t *dst8,
-    ptrdiff_t dst_stride, const int16_t *x_filter_ptr, int x_step_q4,
-    const int16_t *y_filter_ptr, int y_step_q4, int w, int h,
+    ptrdiff_t dst_stride, const int16_t *x_filter, int x_step_q4,
+    const int16_t *y_filter, int y_step_q4, int w, int h,
     const WienerConvolveParams *conv_params, int bd) {
-  assert(x_step_q4 == 16 && y_step_q4 == 16);
+  (void)x_step_q4;
+  (void)y_step_q4;
 
-  DECLARE_ALIGNED(16, uint16_t, im_block[WIENER_MAX_EXT_SIZE * MAX_SB_SIZE]);
-  const int im_h = (((h - 1) * y_step_q4) >> SUBPEL_BITS) + SUBPEL_TAPS;
+  assert(w % 8 == 0);
+  assert(w <= MAX_SB_SIZE && h <= MAX_SB_SIZE);
+  assert(x_step_q4 == 16 && y_step_q4 == 16);
+  assert(x_filter[7] == 0 && y_filter[7] == 0);
+
+  DECLARE_ALIGNED(16, uint16_t,
+                  im_block[(MAX_SB_SIZE + WIENER_WIN - 1) * MAX_SB_SIZE]);
+
+  int16x4_t x_filter_s16 = vld1_s16(x_filter);
+  int16x4_t y_filter_s16 = vld1_s16(y_filter);
+  // Add 128 to tap 3. (Needed for rounding.)
+  x_filter_s16 = vadd_s16(x_filter_s16, vcreate_s16(128ULL << 48));
+  y_filter_s16 = vadd_s16(y_filter_s16, vcreate_s16(128ULL << 48));
+
   const int im_stride = MAX_SB_SIZE;
-  const int vert_offset = SUBPEL_TAPS / 2 - 1;
-  const int horiz_offset = SUBPEL_TAPS / 2 - 1;
+  const int im_h = h + WIENER_WIN - 1;
+  const int horiz_offset = WIENER_HALFWIN;
+  const int vert_offset = WIENER_HALFWIN * (int)src_stride;
+
+  const int extraprec_clamp_limit =
+      WIENER_CLAMP_LIMIT(conv_params->round_0, bd);
+  const uint16x8_t im_max_val = vdupq_n_u16(extraprec_clamp_limit - 1);
+  const int32x4_t horiz_shift = vdupq_n_s32(-conv_params->round_0);
+  const int32x4_t horiz_round_vec = vdupq_n_s32(1 << (bd + FILTER_BITS - 1));
+
+  const uint16x8_t res_max_val = vdupq_n_u16((1 << bd) - 1);
+  const int32x4_t vert_shift = vdupq_n_s32(-conv_params->round_1);
+  const int32x4_t vert_round_vec =
+      vdupq_n_s32(-(1 << (bd + conv_params->round_1 - 1)));
 
   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
   uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
-  const uint16_t *src_ptr = src - vert_offset * src_stride - horiz_offset;
 
-  highbd_convolve_add_src_horiz_hip(src_ptr, src_stride, im_block, im_stride,
-                                    x_filter_ptr, x_step_q4, w, im_h,
-                                    conv_params->round_0, bd);
-  highbd_convolve_add_src_vert_hip(im_block, im_stride, dst, dst_stride,
-                                   y_filter_ptr, y_step_q4, w, h,
-                                   conv_params->round_1, bd);
+  highbd_convolve_add_src_horiz_hip(
+      src - horiz_offset - vert_offset, src_stride, im_block, im_stride, w,
+      im_h, x_filter_s16, horiz_round_vec, horiz_shift, im_max_val);
+  highbd_convolve_add_src_vert_hip(im_block, im_stride, dst, dst_stride, w, h,
+                                   y_filter_s16, vert_round_vec, vert_shift,
+                                   res_max_val);
 }