Propagate constants in av1_dist_wtd_convolve_2d_copy_neon

Rounding modes and other convolution parameters are known ahead of
time. This patch propagates the values into the Neon code paths for
av1_dist_wtd_convolve_2d_copy_neon, enabling some useful
simplifications and optimizations.
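
For reference, with the usual 8-bit compound values (FILTER_BITS = 7,
ROUND0_BITS = 3, COMPOUND_ROUND1_BITS = 7) the shift amount collapses
to a small constant:

  bits = 2 * FILTER_BITS - round_1 - round_0
       = 2 * FILTER_BITS - COMPOUND_ROUND1_BITS - ROUND0_BITS
       = 14 - 7 - 3
       = 4
       = FILTER_BITS - ROUND0_BITS

Since (1 << 4) fits comfortably in a uint8 lane, the previous
shift-then-add pair can be fused into a single widening
multiply-accumulate (vmlal_u8) against the round-offset vector.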

Co-authored-by: Jonathan Wright <jonathan.wright@arm.com>

Change-Id: I8677d9c5dfae7eb3c9a6bc77d90bdf280db33291
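
The sketch below is illustrative only and is not part of the patch: it
shows the per-pixel computation the simplified copy path performs,
using the offset_bits/round_offset definitions visible in the diff. It
assumes <stdint.h> and the AV1 constant macros; the real kernel does
this 8 or 4 pixels at a time with vmlal_u8.

  // Hypothetical helper: one pixel of the dist-wtd 2D copy
  // pre-compound stage after constant propagation.
  static uint16_t copy_pixel_sketch(uint8_t src) {
    const int offset_bits = 8 + 2 * FILTER_BITS - ROUND0_BITS;  // 19
    const int round_offset =
        (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +           // 1 << 12
        (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));        // 1 << 11
    // round_offset + (src << (FILTER_BITS - ROUND0_BITS)), written as a
    // multiply-accumulate so it maps directly onto vmlal_u8.
    return (uint16_t)(round_offset + src * (1 << (FILTER_BITS - ROUND0_BITS)));
  }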
diff --git a/av1/common/arm/jnt_convolve_neon.c b/av1/common/arm/jnt_convolve_neon.c
index 6dd4ecb..f970044 100644
--- a/av1/common/arm/jnt_convolve_neon.c
+++ b/av1/common/arm/jnt_convolve_neon.c
@@ -1135,16 +1135,14 @@
                                         int h, ConvolveParams *conv_params) {
   CONV_BUF_TYPE *dst = conv_params->dst;
   const int dst_stride = conv_params->dst_stride;
-  const int16_t bits =
-      FILTER_BITS * 2 - conv_params->round_1 - conv_params->round_0;
   const int bd = 8;
-  const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
-  const int round_offset = (1 << (offset_bits - conv_params->round_1)) +
-                           (1 << (offset_bits - conv_params->round_1 - 1));
+  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
+  const int round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
+                           (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
   const int16x4_t sub_const_vec = vdup_n_s16((int16_t)round_offset);
+  const uint8x8_t shift_by_bits = vdup_n_u8(1 << (FILTER_BITS - ROUND0_BITS));
 
   if (w >= 8) {
-    const int16x8_t shift_by_bits = vdupq_n_s16(bits);
     const uint16x8_t round_offset_vec = vdupq_n_u16((uint16_t)round_offset);
     uint8x8_t s0, s1, s2, s3, dd0, dd1, dd2, dd3;
     uint16x8_t d0, d1, d2, d3, t0, t1, t2, t3;
@@ -1159,22 +1157,17 @@
       do {
         load_u8_8x4(s, src_stride, &s0, &s1, &s2, &s3);
 
-        d0 = vshlq_u16(vmovl_u8(s0), shift_by_bits);
-        d1 = vshlq_u16(vmovl_u8(s1), shift_by_bits);
-        d2 = vshlq_u16(vmovl_u8(s2), shift_by_bits);
-        d3 = vshlq_u16(vmovl_u8(s3), shift_by_bits);
-
-        d0 = vaddq_u16(d0, round_offset_vec);
-        d1 = vaddq_u16(d1, round_offset_vec);
-        d2 = vaddq_u16(d2, round_offset_vec);
-        d3 = vaddq_u16(d3, round_offset_vec);
+        d0 = vmlal_u8(round_offset_vec, s0, shift_by_bits);
+        d1 = vmlal_u8(round_offset_vec, s1, shift_by_bits);
+        d2 = vmlal_u8(round_offset_vec, s2, shift_by_bits);
+        d3 = vmlal_u8(round_offset_vec, s3, shift_by_bits);
 
         if (conv_params->do_average) {
           load_u16_8x4(d_conv_buf, dst_stride, &t0, &t1, &t2, &t3);
 
           compute_avg_8x4(
               t0, t1, t2, t3, d0, d1, d2, d3, conv_params->fwd_offset,
-              conv_params->bck_offset, sub_const_vec, bits,
+              conv_params->bck_offset, sub_const_vec, FILTER_BITS - ROUND0_BITS,
               conv_params->use_dist_wtd_comp_avg, &dd0, &dd1, &dd2, &dd3);
 
           store_u8_8x4(d_u8, dst8_stride, dd0, dd1, dd2, dd3);
@@ -1191,8 +1184,7 @@
       dst8 += 4 * dst8_stride;
     } while (--height != 0);
   } else {
-    const int16x4_t shift_by_bits = vdup_n_s16(bits);
-    const uint16x4_t round_offset_vec = vdup_n_u16((uint16_t)round_offset);
+    const uint16x8_t round_offset_vec = vdupq_n_u16((uint16_t)round_offset);
     uint8x8_t s0, s1, s2, s3, d01, d23;
     uint16x4_t d0, d1, d2, d3, t0, t1, t2, t3;
     int height = h / 4;
@@ -1200,21 +1192,17 @@
     do {
       load_u8_8x4(src, src_stride, &s0, &s1, &s2, &s3);
 
-      d0 = vshl_u16(vget_low_u16(vmovl_u8(s0)), shift_by_bits);
-      d1 = vshl_u16(vget_low_u16(vmovl_u8(s1)), shift_by_bits);
-      d2 = vshl_u16(vget_low_u16(vmovl_u8(s2)), shift_by_bits);
-      d3 = vshl_u16(vget_low_u16(vmovl_u8(s3)), shift_by_bits);
-
-      d0 = vadd_u16(d0, round_offset_vec);
-      d1 = vadd_u16(d1, round_offset_vec);
-      d2 = vadd_u16(d2, round_offset_vec);
-      d3 = vadd_u16(d3, round_offset_vec);
+      d0 = vget_low_u16(vmlal_u8(round_offset_vec, s0, shift_by_bits));
+      d1 = vget_low_u16(vmlal_u8(round_offset_vec, s1, shift_by_bits));
+      d2 = vget_low_u16(vmlal_u8(round_offset_vec, s2, shift_by_bits));
+      d3 = vget_low_u16(vmlal_u8(round_offset_vec, s3, shift_by_bits));
 
       if (conv_params->do_average) {
         load_u16_4x4(dst, dst_stride, &t0, &t1, &t2, &t3);
 
         compute_avg_4x4(t0, t1, t2, t3, d0, d1, d2, d3, conv_params->fwd_offset,
-                        conv_params->bck_offset, sub_const_vec, bits,
+                        conv_params->bck_offset, sub_const_vec,
+                        FILTER_BITS - ROUND0_BITS,
                         conv_params->use_dist_wtd_comp_avg, &d01, &d23);
 
         store_u8_4x1(dst8 + 0 * dst8_stride, d01, 0);