Propagate constants in Neon dist_wtd_convolve_2d_vert functions

The rounding shifts and other convolution parameters are known at
compile time. This patch propagates these constants into the Neon code
paths for dist_wtd_convolve_2d_vert_[6/8]tap_neon, enabling some useful
simplifications and optimizations.
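
For illustration, a minimal sketch of the core rewrite (the wrapper
function names below are hypothetical; the intrinsics are the ones the
patch swaps between, and COMPOUND_ROUND1_BITS comes from
av1/common/convolve.h):

  #include <arm_neon.h>

  // Before: round_1 arrives at run time, so rounding and narrowing
  // take two instructions (SQRSHL + SQXTUN), plus a vector register
  // holding the negated shift amount.
  static inline uint16x4_t round_narrow_var(int32x4_t sum,
                                            int32x4_t round_shift_vec) {
    sum = vqrshlq_s32(sum, round_shift_vec);
    return vqmovun_s32(sum);
  }

  // After: with round_1 fixed at COMPOUND_ROUND1_BITS, both steps fuse
  // into a single saturating rounding narrowing shift (SQRSHRUN).
  static inline uint16x4_t round_narrow_const(int32x4_t sum) {
    return vqrshrun_n_s32(sum, COMPOUND_ROUND1_BITS);
  }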

Co-authored-by: Jonathan Wright <jonathan.wright@arm.com>

Change-Id: Ie3753db63a1362a51f67cbdc73d2cacd8516a948
diff --git a/av1/common/arm/convolve_neon.h b/av1/common/arm/convolve_neon.h
index 59c77b0..ee12d13 100644
--- a/av1/common/arm/convolve_neon.h
+++ b/av1/common/arm/convolve_neon.h
@@ -483,7 +483,6 @@
                                          const int16x4_t s2, const int16x4_t s3,
                                          const int16x4_t s4, const int16x4_t s5,
                                          const int16x8_t y_filter,
-                                         const int32x4_t round_shift_vec,
                                          const int32x4_t offset_const) {
   const int16x4_t y_filter_lo = vget_low_s16(y_filter);
   const int16x4_t y_filter_hi = vget_high_s16(y_filter);
@@ -496,15 +495,13 @@
   sum = vmlal_lane_s16(sum, s4, y_filter_hi, 1);
   sum = vmlal_lane_s16(sum, s5, y_filter_hi, 2);
 
-  sum = vqrshlq_s32(sum, round_shift_vec);
-  return vqmovun_s32(sum);
+  return vqrshrun_n_s32(sum, COMPOUND_ROUND1_BITS);
 }
 
 static INLINE uint16x8_t convolve6_8_s32(const int16x8_t s0, const int16x8_t s1,
                                          const int16x8_t s2, const int16x8_t s3,
                                          const int16x8_t s4, const int16x8_t s5,
                                          const int16x8_t y_filter,
-                                         const int32x4_t round_shift_vec,
                                          const int32x4_t offset_const) {
   const int16x4_t y_filter_lo = vget_low_s16(y_filter);
   const int16x4_t y_filter_hi = vget_high_s16(y_filter);
@@ -525,9 +522,8 @@
   sum1 = vmlal_lane_s16(sum1, vget_high_s16(s4), y_filter_hi, 1);
   sum1 = vmlal_lane_s16(sum1, vget_high_s16(s5), y_filter_hi, 2);
 
-  sum0 = vqrshlq_s32(sum0, round_shift_vec);
-  sum1 = vqrshlq_s32(sum1, round_shift_vec);
-  return vcombine_u16(vqmovun_s32(sum0), vqmovun_s32(sum1));
+  return vcombine_u16(vqrshrun_n_s32(sum0, COMPOUND_ROUND1_BITS),
+                      vqrshrun_n_s32(sum1, COMPOUND_ROUND1_BITS));
 }
 
 static INLINE uint16x4_t convolve8_4_s32(const int16x4_t s0, const int16x4_t s1,
@@ -535,7 +531,6 @@
                                          const int16x4_t s4, const int16x4_t s5,
                                          const int16x4_t s6, const int16x4_t s7,
                                          const int16x8_t y_filter,
-                                         const int32x4_t round_shift_vec,
                                          const int32x4_t offset_const) {
   const int16x4_t y_filter_lo = vget_low_s16(y_filter);
   const int16x4_t y_filter_hi = vget_high_s16(y_filter);
@@ -550,8 +545,7 @@
   sum = vmlal_lane_s16(sum, s6, y_filter_hi, 2);
   sum = vmlal_lane_s16(sum, s7, y_filter_hi, 3);
 
-  sum = vqrshlq_s32(sum, round_shift_vec);
-  return vqmovun_s32(sum);
+  return vqrshrun_n_s32(sum, COMPOUND_ROUND1_BITS);
 }
 
 static INLINE uint16x8_t convolve8_8_s32(const int16x8_t s0, const int16x8_t s1,
@@ -559,7 +553,6 @@
                                          const int16x8_t s4, const int16x8_t s5,
                                          const int16x8_t s6, const int16x8_t s7,
                                          const int16x8_t y_filter,
-                                         const int32x4_t round_shift_vec,
                                          const int32x4_t offset_const) {
   const int16x4_t y_filter_lo = vget_low_s16(y_filter);
   const int16x4_t y_filter_hi = vget_high_s16(y_filter);
@@ -584,9 +577,8 @@
   sum1 = vmlal_lane_s16(sum1, vget_high_s16(s6), y_filter_hi, 2);
   sum1 = vmlal_lane_s16(sum1, vget_high_s16(s7), y_filter_hi, 3);
 
-  sum0 = vqrshlq_s32(sum0, round_shift_vec);
-  sum1 = vqrshlq_s32(sum1, round_shift_vec);
-  return vcombine_u16(vqmovun_s32(sum0), vqmovun_s32(sum1));
+  return vcombine_u16(vqrshrun_n_s32(sum0, COMPOUND_ROUND1_BITS),
+                      vqrshrun_n_s32(sum1, COMPOUND_ROUND1_BITS));
 }
 
 #if !(defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD))
diff --git a/av1/common/arm/jnt_convolve_neon.c b/av1/common/arm/jnt_convolve_neon.c
index bd868ea..67d5eee 100644
--- a/av1/common/arm/jnt_convolve_neon.c
+++ b/av1/common/arm/jnt_convolve_neon.c
@@ -709,14 +709,13 @@
   const int dst_stride = conv_params->dst_stride;
 
   const int bd = 8;
-  const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
-  const int16_t sub_const = (1 << (offset_bits - conv_params->round_1)) +
-                            (1 << (offset_bits - conv_params->round_1 - 1));
+  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
+  const int16_t sub_const = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
+                            (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
 
   const int16_t round_bits =
-      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
-  const int offset = bd + 2 * FILTER_BITS - conv_params->round_0;
-  const int32x4_t round_shift_vec = vdupq_n_s32(-(conv_params->round_1));
+      2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS;
+  const int offset = bd + 2 * FILTER_BITS - ROUND0_BITS;
   const int32x4_t offset_const = vdupq_n_s32(1 << offset);
   const int16x4_t sub_const_vec = vdup_n_s16(sub_const);
   const uint16_t fwd_offset = conv_params->fwd_offset;
@@ -742,14 +741,10 @@
 #if defined(__aarch64__)
       load_s16_4x4(src_ptr, src_stride, &s5, &s6, &s7, &s8);
 
-      d0 = convolve6_4_s32(s0, s1, s2, s3, s4, s5, y_filter, round_shift_vec,
-                           offset_const);
-      d1 = convolve6_4_s32(s1, s2, s3, s4, s5, s6, y_filter, round_shift_vec,
-                           offset_const);
-      d2 = convolve6_4_s32(s2, s3, s4, s5, s6, s7, y_filter, round_shift_vec,
-                           offset_const);
-      d3 = convolve6_4_s32(s3, s4, s5, s6, s7, s8, y_filter, round_shift_vec,
-                           offset_const);
+      d0 = convolve6_4_s32(s0, s1, s2, s3, s4, s5, y_filter, offset_const);
+      d1 = convolve6_4_s32(s1, s2, s3, s4, s5, s6, y_filter, offset_const);
+      d2 = convolve6_4_s32(s2, s3, s4, s5, s6, s7, y_filter, offset_const);
+      d3 = convolve6_4_s32(s3, s4, s5, s6, s7, s8, y_filter, offset_const);
 
       if (do_average) {
         load_u16_4x4(dst_ptr, dst_stride, &dd0, &dd1, &dd2, &dd3);
@@ -778,8 +773,7 @@
 #else
       s5 = vld1_s16(src_ptr);
 
-      d0 = convolve6_4_s32(s0, s1, s2, s3, s4, s5, y_filter, round_shift_vec,
-                           offset_const);
+      d0 = convolve6_4_s32(s0, s1, s2, s3, s4, s5, y_filter, offset_const);
 
       if (do_average) {
         dd0 = vld1_u16(dst_ptr);
@@ -828,14 +822,10 @@
 #if defined(__aarch64__)
         load_s16_8x4(s, src_stride, &s5, &s6, &s7, &s8);
 
-        d0 = convolve6_8_s32(s0, s1, s2, s3, s4, s5, y_filter, round_shift_vec,
-                             offset_const);
-        d1 = convolve6_8_s32(s1, s2, s3, s4, s5, s6, y_filter, round_shift_vec,
-                             offset_const);
-        d2 = convolve6_8_s32(s2, s3, s4, s5, s6, s7, y_filter, round_shift_vec,
-                             offset_const);
-        d3 = convolve6_8_s32(s3, s4, s5, s6, s7, s8, y_filter, round_shift_vec,
-                             offset_const);
+        d0 = convolve6_8_s32(s0, s1, s2, s3, s4, s5, y_filter, offset_const);
+        d1 = convolve6_8_s32(s1, s2, s3, s4, s5, s6, y_filter, offset_const);
+        d2 = convolve6_8_s32(s2, s3, s4, s5, s6, s7, y_filter, offset_const);
+        d3 = convolve6_8_s32(s3, s4, s5, s6, s7, s8, y_filter, offset_const);
 
         if (do_average) {
           load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);
@@ -862,8 +852,7 @@
 #else
         s5 = vld1q_s16(s);
 
-        d0 = convolve6_8_s32(s0, s1, s2, s3, s4, s5, y_filter, round_shift_vec,
-                             offset_const);
+        d0 = convolve6_8_s32(s0, s1, s2, s3, s4, s5, y_filter, offset_const);
 
         if (do_average) {
           dd0 = vld1q_u16(d);
@@ -904,14 +893,13 @@
   const int dst_stride = conv_params->dst_stride;
 
   const int bd = 8;
-  const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
-  const int16_t sub_const = (1 << (offset_bits - conv_params->round_1)) +
-                            (1 << (offset_bits - conv_params->round_1 - 1));
+  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
+  const int16_t sub_const = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
+                            (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
 
   const int16_t round_bits =
-      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
-  const int offset = bd + 2 * FILTER_BITS - conv_params->round_0;
-  const int32x4_t round_shift_vec = vdupq_n_s32(-(conv_params->round_1));
+      2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS;
+  const int offset = bd + 2 * FILTER_BITS - ROUND0_BITS;
   const int32x4_t offset_const = vdupq_n_s32(1 << offset);
   const int16x4_t sub_const_vec = vdup_n_s16(sub_const);
   const uint16_t fwd_offset = conv_params->fwd_offset;
@@ -938,13 +926,13 @@
       load_s16_4x4(src_ptr, src_stride, &s7, &s8, &s9, &s10);
 
       d0 = convolve8_4_s32(s0, s1, s2, s3, s4, s5, s6, s7, y_filter,
-                           round_shift_vec, offset_const);
+                           offset_const);
       d1 = convolve8_4_s32(s1, s2, s3, s4, s5, s6, s7, s8, y_filter,
-                           round_shift_vec, offset_const);
+                           offset_const);
       d2 = convolve8_4_s32(s2, s3, s4, s5, s6, s7, s8, s9, y_filter,
-                           round_shift_vec, offset_const);
+                           offset_const);
       d3 = convolve8_4_s32(s3, s4, s5, s6, s7, s8, s9, s10, y_filter,
-                           round_shift_vec, offset_const);
+                           offset_const);
 
       if (do_average) {
         load_u16_4x4(dst_ptr, dst_stride, &dd0, &dd1, &dd2, &dd3);
@@ -976,7 +964,7 @@
       s7 = vld1_s16(src_ptr);
 
       d0 = convolve8_4_s32(s0, s1, s2, s3, s4, s5, s6, s7, y_filter,
-                           round_shift_vec, offset_const);
+                           offset_const);
 
       if (do_average) {
         dd0 = vld1_u16(dst_ptr);
@@ -1028,13 +1016,13 @@
         load_s16_8x4(s, src_stride, &s7, &s8, &s9, &s10);
 
         d0 = convolve8_8_s32(s0, s1, s2, s3, s4, s5, s6, s7, y_filter,
-                             round_shift_vec, offset_const);
+                             offset_const);
         d1 = convolve8_8_s32(s1, s2, s3, s4, s5, s6, s7, s8, y_filter,
-                             round_shift_vec, offset_const);
+                             offset_const);
         d2 = convolve8_8_s32(s2, s3, s4, s5, s6, s7, s8, s9, y_filter,
-                             round_shift_vec, offset_const);
+                             offset_const);
         d3 = convolve8_8_s32(s3, s4, s5, s6, s7, s8, s9, s10, y_filter,
-                             round_shift_vec, offset_const);
+                             offset_const);
 
         if (do_average) {
           load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);
@@ -1064,7 +1052,7 @@
         s7 = vld1q_s16(s);
 
         d0 = convolve8_8_s32(s0, s1, s2, s3, s4, s5, s6, s7, y_filter,
-                             round_shift_vec, offset_const);
+                             offset_const);
 
         if (do_average) {
           dd0 = vld1q_u16(d);