Refactor rounding shims in averaging Neon convolution functions

In many cases when computing a sum using Neon, adding a constant shim
can avoid a 'complex' rounding right shift (which is usually slower
than a 'simple' truncating right shift). Such a shim was added when
averaging in the Neon compound convolution functions. However, since
we are now using a 'complex' saturating-convert-to-unsigned right
shift (vqshrun_n_s16) in this sequence anyway, there's no extra cost
to using the rounding-saturating-convert-to-unsigned right shift
instruction (vqrshrun_n_s16).
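
As a minimal sketch of the equivalence (illustrative only, not part of
this patch, and assuming the usual libaom values FILTER_BITS = 7 and
ROUND0_BITS = 3), the two helpers below produce the same output
whenever adding the shim does not overflow int16:

    #include <arm_neon.h>

    #define FILTER_BITS 7 /* as in aom_dsp/aom_filter.h */
    #define ROUND0_BITS 3 /* as in av1/common/convolve.h */

    /* Old pattern: fold a rounding shim into the sum, then use the
     * truncating saturating narrow. */
    static inline uint8x8_t narrow_with_shim(int16x8_t sum) {
      const int16x8_t shim =
          vdupq_n_s16(1 << (FILTER_BITS - ROUND0_BITS - 1));
      return vqshrun_n_s16(vaddq_s16(sum, shim),
                           FILTER_BITS - ROUND0_BITS);
    }

    /* New pattern: the rounding saturating narrow adds the same
     * 1 << (FILTER_BITS - ROUND0_BITS - 1) internally, so the shim
     * constant is no longer needed. */
    static inline uint8x8_t narrow_with_rounding(int16x8_t sum) {
      return vqrshrun_n_s16(sum, FILTER_BITS - ROUND0_BITS);
    }

In the convolution kernels the shim was folded into the subtracted
offset constant rather than added explicitly, which is why removing it
simplifies the offset computation below.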

This patch removes the rounding constant shims in favour of using the
rounding-saturating-convert-to-unsigned right shift instruction to
compute the average in Neon compound convolution functions. It also
makes a number of cosmetic changes to keep variable names consistent
with the C reference and other Neon implementations.

Change-Id: I10efbcc687cbbaa1ed1aed09e4d57b3de515cf70
diff --git a/av1/common/arm/jnt_convolve_neon.c b/av1/common/arm/jnt_convolve_neon.c
index 4eefb71..4129c43 100644
--- a/av1/common/arm/jnt_convolve_neon.c
+++ b/av1/common/arm/jnt_convolve_neon.c
@@ -23,79 +23,79 @@
 #include "av1/common/arm/convolve_neon.h"
 
 #if !defined(__aarch64__)
-static INLINE void compute_avg_4x1(uint16x4_t res0, uint16x4_t d0,
+static INLINE void compute_avg_4x1(uint16x4_t dd0, uint16x4_t d0,
                                    const uint16_t fwd_offset,
                                    const uint16_t bck_offset,
-                                   const int16x4_t sub_const,
+                                   const int16x4_t round_offset,
                                    const int use_dist_wtd_comp_avg,
-                                   uint8x8_t *t0) {
+                                   uint8x8_t *d0_u8) {
   uint16x4_t avg0;
 
   if (use_dist_wtd_comp_avg) {
     uint32x4_t blend0;
-    blend0 = vmull_n_u16(res0, fwd_offset);
+    blend0 = vmull_n_u16(dd0, fwd_offset);
     blend0 = vmlal_n_u16(blend0, d0, bck_offset);
 
     avg0 = vshrn_n_u32(blend0, DIST_PRECISION_BITS);
   } else {
-    avg0 = vhadd_u16(res0, d0);
+    avg0 = vhadd_u16(dd0, d0);
   }
 
-  int16x4_t dst0 = vsub_s16(vreinterpret_s16_u16(avg0), sub_const);
+  int16x4_t dst0 = vsub_s16(vreinterpret_s16_u16(avg0), round_offset);
 
   int16x8_t dst0q = vcombine_s16(dst0, vdup_n_s16(0));
 
-  *t0 = vqshrun_n_s16(dst0q, FILTER_BITS - ROUND0_BITS);
+  *d0_u8 = vqrshrun_n_s16(dst0q, FILTER_BITS - ROUND0_BITS);
 }
 
-static INLINE void compute_avg_8x1(uint16x8_t res0, uint16x8_t d0,
+static INLINE void compute_avg_8x1(uint16x8_t dd0, uint16x8_t d0,
                                    const uint16_t fwd_offset,
                                    const uint16_t bck_offset,
-                                   const int16x8_t sub_const,
+                                   const int16x8_t round_offset,
                                    const int use_dist_wtd_comp_avg,
-                                   uint8x8_t *t0) {
+                                   uint8x8_t *d0_u8) {
   uint16x8_t avg0;
 
   if (use_dist_wtd_comp_avg) {
     uint32x4_t blend0_lo, blend0_hi;
 
-    blend0_lo = vmull_n_u16(vget_low_u16(res0), fwd_offset);
+    blend0_lo = vmull_n_u16(vget_low_u16(dd0), fwd_offset);
     blend0_lo = vmlal_n_u16(blend0_lo, vget_low_u16(d0), bck_offset);
-    blend0_hi = vmull_n_u16(vget_high_u16(res0), fwd_offset);
+    blend0_hi = vmull_n_u16(vget_high_u16(dd0), fwd_offset);
     blend0_hi = vmlal_n_u16(blend0_hi, vget_high_u16(d0), bck_offset);
 
     avg0 = vcombine_u16(vshrn_n_u32(blend0_lo, DIST_PRECISION_BITS),
                         vshrn_n_u32(blend0_hi, DIST_PRECISION_BITS));
   } else {
-    avg0 = vhaddq_u16(res0, d0);
+    avg0 = vhaddq_u16(dd0, d0);
   }
 
-  int16x8_t dst0 = vsubq_s16(vreinterpretq_s16_u16(avg0), sub_const);
+  int16x8_t dst0 = vsubq_s16(vreinterpretq_s16_u16(avg0), round_offset);
 
-  *t0 = vqshrun_n_s16(dst0, FILTER_BITS - ROUND0_BITS);
+  *d0_u8 = vqrshrun_n_s16(dst0, FILTER_BITS - ROUND0_BITS);
 }
 #endif  // !defined(__arch64__)
 
-static INLINE void compute_avg_4x4(uint16x4_t res0, uint16x4_t res1,
-                                   uint16x4_t res2, uint16x4_t res3,
+static INLINE void compute_avg_4x4(uint16x4_t dd0, uint16x4_t dd1,
+                                   uint16x4_t dd2, uint16x4_t dd3,
                                    uint16x4_t d0, uint16x4_t d1, uint16x4_t d2,
                                    uint16x4_t d3, const uint16_t fwd_offset,
                                    const uint16_t bck_offset,
-                                   const int16x8_t sub_const,
+                                   const int16x8_t round_offset,
                                    const int use_dist_wtd_comp_avg,
-                                   uint8x8_t *t0, uint8x8_t *t1) {
+                                   uint8x8_t *d01_u8, uint8x8_t *d23_u8) {
   uint16x4_t avg0, avg1, avg2, avg3;
 
   if (use_dist_wtd_comp_avg) {
     uint32x4_t blend0, blend1, blend2, blend3;
 
-    blend0 = vmull_n_u16(res0, fwd_offset);
+    blend0 = vmull_n_u16(dd0, fwd_offset);
     blend0 = vmlal_n_u16(blend0, d0, bck_offset);
-    blend1 = vmull_n_u16(res1, fwd_offset);
+    blend1 = vmull_n_u16(dd1, fwd_offset);
     blend1 = vmlal_n_u16(blend1, d1, bck_offset);
-    blend2 = vmull_n_u16(res2, fwd_offset);
+    blend2 = vmull_n_u16(dd2, fwd_offset);
     blend2 = vmlal_n_u16(blend2, d2, bck_offset);
-    blend3 = vmull_n_u16(res3, fwd_offset);
+    blend3 = vmull_n_u16(dd3, fwd_offset);
     blend3 = vmlal_n_u16(blend3, d3, bck_offset);
 
     avg0 = vshrn_n_u32(blend0, DIST_PRECISION_BITS);
@@ -103,52 +103,52 @@
     avg2 = vshrn_n_u32(blend2, DIST_PRECISION_BITS);
     avg3 = vshrn_n_u32(blend3, DIST_PRECISION_BITS);
   } else {
-    avg0 = vhadd_u16(res0, d0);
-    avg1 = vhadd_u16(res1, d1);
-    avg2 = vhadd_u16(res2, d2);
-    avg3 = vhadd_u16(res3, d3);
+    avg0 = vhadd_u16(dd0, d0);
+    avg1 = vhadd_u16(dd1, d1);
+    avg2 = vhadd_u16(dd2, d2);
+    avg3 = vhadd_u16(dd3, d3);
   }
 
   int16x8_t dst_01 = vreinterpretq_s16_u16(vcombine_u16(avg0, avg1));
   int16x8_t dst_23 = vreinterpretq_s16_u16(vcombine_u16(avg2, avg3));
 
-  dst_01 = vsubq_s16(dst_01, sub_const);
-  dst_23 = vsubq_s16(dst_23, sub_const);
+  dst_01 = vsubq_s16(dst_01, round_offset);
+  dst_23 = vsubq_s16(dst_23, round_offset);
 
-  *t0 = vqshrun_n_s16(dst_01, FILTER_BITS - ROUND0_BITS);
-  *t1 = vqshrun_n_s16(dst_23, FILTER_BITS - ROUND0_BITS);
+  *d01_u8 = vqrshrun_n_s16(dst_01, FILTER_BITS - ROUND0_BITS);
+  *d23_u8 = vqrshrun_n_s16(dst_23, FILTER_BITS - ROUND0_BITS);
 }
 
 static INLINE void compute_avg_8x4(
-    uint16x8_t res0, uint16x8_t res1, uint16x8_t res2, uint16x8_t res3,
+    uint16x8_t dd0, uint16x8_t dd1, uint16x8_t dd2, uint16x8_t dd3,
     uint16x8_t d0, uint16x8_t d1, uint16x8_t d2, uint16x8_t d3,
     const uint16_t fwd_offset, const uint16_t bck_offset,
-    const int16x8_t sub_const, const int use_dist_wtd_comp_avg, uint8x8_t *t0,
-    uint8x8_t *t1, uint8x8_t *t2, uint8x8_t *t3) {
+    const int16x8_t round_offset, const int use_dist_wtd_comp_avg,
+    uint8x8_t *d0_u8, uint8x8_t *d1_u8, uint8x8_t *d2_u8, uint8x8_t *d3_u8) {
   uint16x8_t avg0, avg1, avg2, avg3;
 
   if (use_dist_wtd_comp_avg) {
     uint32x4_t blend0_lo, blend1_lo, blend2_lo, blend3_lo;
     uint32x4_t blend0_hi, blend1_hi, blend2_hi, blend3_hi;
 
-    blend0_lo = vmull_n_u16(vget_low_u16(res0), fwd_offset);
+    blend0_lo = vmull_n_u16(vget_low_u16(dd0), fwd_offset);
     blend0_lo = vmlal_n_u16(blend0_lo, vget_low_u16(d0), bck_offset);
-    blend0_hi = vmull_n_u16(vget_high_u16(res0), fwd_offset);
+    blend0_hi = vmull_n_u16(vget_high_u16(dd0), fwd_offset);
     blend0_hi = vmlal_n_u16(blend0_hi, vget_high_u16(d0), bck_offset);
 
-    blend1_lo = vmull_n_u16(vget_low_u16(res1), fwd_offset);
+    blend1_lo = vmull_n_u16(vget_low_u16(dd1), fwd_offset);
     blend1_lo = vmlal_n_u16(blend1_lo, vget_low_u16(d1), bck_offset);
-    blend1_hi = vmull_n_u16(vget_high_u16(res1), fwd_offset);
+    blend1_hi = vmull_n_u16(vget_high_u16(dd1), fwd_offset);
     blend1_hi = vmlal_n_u16(blend1_hi, vget_high_u16(d1), bck_offset);
 
-    blend2_lo = vmull_n_u16(vget_low_u16(res2), fwd_offset);
+    blend2_lo = vmull_n_u16(vget_low_u16(dd2), fwd_offset);
     blend2_lo = vmlal_n_u16(blend2_lo, vget_low_u16(d2), bck_offset);
-    blend2_hi = vmull_n_u16(vget_high_u16(res2), fwd_offset);
+    blend2_hi = vmull_n_u16(vget_high_u16(dd2), fwd_offset);
     blend2_hi = vmlal_n_u16(blend2_hi, vget_high_u16(d2), bck_offset);
 
-    blend3_lo = vmull_n_u16(vget_low_u16(res3), fwd_offset);
+    blend3_lo = vmull_n_u16(vget_low_u16(dd3), fwd_offset);
     blend3_lo = vmlal_n_u16(blend3_lo, vget_low_u16(d3), bck_offset);
-    blend3_hi = vmull_n_u16(vget_high_u16(res3), fwd_offset);
+    blend3_hi = vmull_n_u16(vget_high_u16(dd3), fwd_offset);
     blend3_hi = vmlal_n_u16(blend3_hi, vget_high_u16(d3), bck_offset);
 
     avg0 = vcombine_u16(vshrn_n_u32(blend0_lo, DIST_PRECISION_BITS),
@@ -160,21 +160,21 @@
     avg3 = vcombine_u16(vshrn_n_u32(blend3_lo, DIST_PRECISION_BITS),
                         vshrn_n_u32(blend3_hi, DIST_PRECISION_BITS));
   } else {
-    avg0 = vhaddq_u16(res0, d0);
-    avg1 = vhaddq_u16(res1, d1);
-    avg2 = vhaddq_u16(res2, d2);
-    avg3 = vhaddq_u16(res3, d3);
+    avg0 = vhaddq_u16(dd0, d0);
+    avg1 = vhaddq_u16(dd1, d1);
+    avg2 = vhaddq_u16(dd2, d2);
+    avg3 = vhaddq_u16(dd3, d3);
   }
 
-  int16x8_t dst0 = vsubq_s16(vreinterpretq_s16_u16(avg0), sub_const);
-  int16x8_t dst1 = vsubq_s16(vreinterpretq_s16_u16(avg1), sub_const);
-  int16x8_t dst2 = vsubq_s16(vreinterpretq_s16_u16(avg2), sub_const);
-  int16x8_t dst3 = vsubq_s16(vreinterpretq_s16_u16(avg3), sub_const);
+  int16x8_t dst0 = vsubq_s16(vreinterpretq_s16_u16(avg0), round_offset);
+  int16x8_t dst1 = vsubq_s16(vreinterpretq_s16_u16(avg1), round_offset);
+  int16x8_t dst2 = vsubq_s16(vreinterpretq_s16_u16(avg2), round_offset);
+  int16x8_t dst3 = vsubq_s16(vreinterpretq_s16_u16(avg3), round_offset);
 
-  *t0 = vqshrun_n_s16(dst0, FILTER_BITS - ROUND0_BITS);
-  *t1 = vqshrun_n_s16(dst1, FILTER_BITS - ROUND0_BITS);
-  *t2 = vqshrun_n_s16(dst2, FILTER_BITS - ROUND0_BITS);
-  *t3 = vqshrun_n_s16(dst3, FILTER_BITS - ROUND0_BITS);
+  *d0_u8 = vqrshrun_n_s16(dst0, FILTER_BITS - ROUND0_BITS);
+  *d1_u8 = vqrshrun_n_s16(dst1, FILTER_BITS - ROUND0_BITS);
+  *d2_u8 = vqrshrun_n_s16(dst2, FILTER_BITS - ROUND0_BITS);
+  *d3_u8 = vqrshrun_n_s16(dst3, FILTER_BITS - ROUND0_BITS);
 }
 
 #if defined(__aarch64__) && defined(__ARM_FEATURE_MATMUL_INT8)
@@ -606,15 +606,10 @@
 
   const int bd = 8;
   const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
-  // Subtracting (1 << (FILTER_BITS - ROUND0_BITS - 1)) allows us to use
-  // non-rounding shifts when computing the average.
-  const int16_t sub_const = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
-                            (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1)) -
-                            (1 << (FILTER_BITS - ROUND0_BITS - 1));
-
-  const int offset = bd + 2 * FILTER_BITS - ROUND0_BITS;
-  const int32x4_t offset_const = vdupq_n_s32(1 << offset);
-  const int16x8_t sub_const_vec = vdupq_n_s16(sub_const);
+  const int32x4_t offset_const = vdupq_n_s32(1 << offset_bits);
+  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
+                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
+  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);
   const uint16_t fwd_offset = conv_params->fwd_offset;
   const uint16_t bck_offset = conv_params->bck_offset;
   const int do_average = conv_params->do_average;
@@ -647,7 +642,7 @@
         load_u16_4x4(dst_ptr, dst_stride, &dd0, &dd1, &dd2, &dd3);
 
         compute_avg_4x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
-                        bck_offset, sub_const_vec, use_dist_wtd_comp_avg,
+                        bck_offset, round_offset_vec, use_dist_wtd_comp_avg,
                         &d01_u8, &d23_u8);
 
         store_u8_4x1(dst8_ptr + 0 * dst8_stride, d01_u8, 0);
@@ -676,7 +671,7 @@
         dd0 = vld1_u16(dst_ptr);
 
         compute_avg_4x1(dd0, d0, fwd_offset, bck_offset,
-                        vget_low_s16(sub_const_vec), use_dist_wtd_comp_avg,
+                        vget_low_s16(round_offset_vec), use_dist_wtd_comp_avg,
                         &d01_u8);
 
         store_u8_4x1(dst8_ptr, d01_u8, 0);
@@ -729,7 +724,7 @@
           load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);
 
           compute_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
-                          bck_offset, sub_const_vec, use_dist_wtd_comp_avg,
+                          bck_offset, round_offset_vec, use_dist_wtd_comp_avg,
                           &d0_u8, &d1_u8, &d2_u8, &d3_u8);
 
           store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);
@@ -754,7 +749,7 @@
         if (do_average) {
           dd0 = vld1q_u16(d);
 
-          compute_avg_8x1(dd0, d0, fwd_offset, bck_offset, sub_const_vec,
+          compute_avg_8x1(dd0, d0, fwd_offset, bck_offset, round_offset_vec,
                           use_dist_wtd_comp_avg, &d0_u8);
 
           vst1_u8(d_u8, d0_u8);
@@ -791,15 +786,10 @@
 
   const int bd = 8;
   const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
-  // Subtracting (1 << (FILTER_BITS - ROUND0_BITS - 1)) allows us to use
-  // non-rounding shifts when computing the average.
-  const int16_t sub_const = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
-                            (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1)) -
-                            (1 << (FILTER_BITS - ROUND0_BITS - 1));
-
-  const int offset = bd + 2 * FILTER_BITS - ROUND0_BITS;
-  const int32x4_t offset_const = vdupq_n_s32(1 << offset);
-  const int16x8_t sub_const_vec = vdupq_n_s16(sub_const);
+  const int32x4_t offset_const = vdupq_n_s32(1 << offset_bits);
+  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
+                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
+  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);
   const uint16_t fwd_offset = conv_params->fwd_offset;
   const uint16_t bck_offset = conv_params->bck_offset;
   const int do_average = conv_params->do_average;
@@ -836,7 +826,7 @@
         load_u16_4x4(dst_ptr, dst_stride, &dd0, &dd1, &dd2, &dd3);
 
         compute_avg_4x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
-                        bck_offset, sub_const_vec, use_dist_wtd_comp_avg,
+                        bck_offset, round_offset_vec, use_dist_wtd_comp_avg,
                         &d01_u8, &d23_u8);
 
         store_u8_4x1(dst8_ptr + 0 * dst8_stride, d01_u8, 0);
@@ -868,7 +858,7 @@
         dd0 = vld1_u16(dst_ptr);
 
         compute_avg_4x1(dd0, d0, fwd_offset, bck_offset,
-                        vget_low_s16(sub_const_vec), use_dist_wtd_comp_avg,
+                        vget_low_s16(round_offset_vec), use_dist_wtd_comp_avg,
                         &d01_u8);
 
         store_u8_4x1(dst8_ptr, d01_u8, 0);
@@ -927,7 +917,7 @@
           load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);
 
           compute_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
-                          bck_offset, sub_const_vec, use_dist_wtd_comp_avg,
+                          bck_offset, round_offset_vec, use_dist_wtd_comp_avg,
                           &d0_u8, &d1_u8, &d2_u8, &d3_u8);
 
           store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);
@@ -955,7 +945,7 @@
         if (do_average) {
           dd0 = vld1q_u16(d);
 
-          compute_avg_8x1(dd0, d0, fwd_offset, bck_offset, sub_const_vec,
+          compute_avg_8x1(dd0, d0, fwd_offset, bck_offset, round_offset_vec,
                           use_dist_wtd_comp_avg, &d0_u8);
 
           vst1_u8(d_u8, d0_u8);
@@ -1035,16 +1025,12 @@
   const int dst_stride = conv_params->dst_stride;
   const int bd = 8;
   const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
-  const int round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
-                           (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
-  // Subtracting (1 << (FILTER_BITS - ROUND0_BITS - 1)) allows us to use
-  // non-rounding shifts when computing the average.
-  const int16x8_t sub_const_vec = vdupq_n_s16(
-      (int16_t)round_offset - (1 << (FILTER_BITS - ROUND0_BITS - 1)));
+  const uint16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
+                                (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
+  const uint16x8_t round_offset_vec = vdupq_n_u16(round_offset);
   const uint8x8_t shift_by_bits = vdup_n_u8(1 << (FILTER_BITS - ROUND0_BITS));
 
   if (w >= 8) {
-    const uint16x8_t round_offset_vec = vdupq_n_u16((uint16_t)round_offset);
     uint8x8_t s0, s1, s2, s3, dd0, dd1, dd2, dd3;
     uint16x8_t d0, d1, d2, d3, t0, t1, t2, t3;
     int height = h / 4;
@@ -1066,10 +1052,10 @@
         if (conv_params->do_average) {
           load_u16_8x4(d_conv_buf, dst_stride, &t0, &t1, &t2, &t3);
 
-          compute_avg_8x4(t0, t1, t2, t3, d0, d1, d2, d3,
-                          conv_params->fwd_offset, conv_params->bck_offset,
-                          sub_const_vec, conv_params->use_dist_wtd_comp_avg,
-                          &dd0, &dd1, &dd2, &dd3);
+          compute_avg_8x4(
+              t0, t1, t2, t3, d0, d1, d2, d3, conv_params->fwd_offset,
+              conv_params->bck_offset, vreinterpretq_s16_u16(round_offset_vec),
+              conv_params->use_dist_wtd_comp_avg, &dd0, &dd1, &dd2, &dd3);
 
           store_u8_8x4(d_u8, dst8_stride, dd0, dd1, dd2, dd3);
         } else {
@@ -1085,7 +1071,6 @@
       dst8 += 4 * dst8_stride;
     } while (--height != 0);
   } else {
-    const uint16x8_t round_offset_vec = vdupq_n_u16((uint16_t)round_offset);
     uint8x8_t s0, s1, s2, s3, d01, d23;
     uint16x4_t d0, d1, d2, d3, t0, t1, t2, t3;
     int height = h / 4;
@@ -1102,7 +1087,8 @@
         load_u16_4x4(dst, dst_stride, &t0, &t1, &t2, &t3);
 
         compute_avg_4x4(t0, t1, t2, t3, d0, d1, d2, d3, conv_params->fwd_offset,
-                        conv_params->bck_offset, sub_const_vec,
+                        conv_params->bck_offset,
+                        vreinterpretq_s16_u16(round_offset_vec),
                         conv_params->use_dist_wtd_comp_avg, &d01, &d23);
 
         store_u8_4x1(dst8 + 0 * dst8_stride, d01, 0);
@@ -1133,16 +1119,12 @@
   const int horiz_offset = filter_params_x->taps / 2 - 1;
   const int bd = 8;
   const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
-  const int round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
-                           (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
+  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
+                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
+  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);
   const uint16_t fwd_offset = conv_params->fwd_offset;
   const uint16_t bck_offset = conv_params->bck_offset;
   const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
-  // Subtracting (1 << (FILTER_BITS - ROUND0_BITS - 1)) allows us to use
-  // non-rounding shifts when computing the average.
-  const int16x8_t sub_const_vec =
-      vdupq_n_s16(round_offset - (1 << (FILTER_BITS - ROUND0_BITS - 1)));
-  const int16x8_t round_offset128 = vdupq_n_s16(round_offset);
 
   // Horizontal filter.
   const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
@@ -1187,8 +1169,8 @@
       d01 = vshrq_n_s16(d01, ROUND0_BITS - 1);
       d23 = vshrq_n_s16(d23, ROUND0_BITS - 1);
 
-      d01 = vaddq_s16(d01, round_offset128);
-      d23 = vaddq_s16(d23, round_offset128);
+      d01 = vaddq_s16(d01, round_offset_vec);
+      d23 = vaddq_s16(d23, round_offset_vec);
 
       if (conv_params->do_average) {
         load_u16_4x4(dst_ptr, dst_stride, &dd0, &dd1, &dd2, &dd3);
@@ -1198,7 +1180,7 @@
             vreinterpret_u16_s16(vget_high_s16(d01)),
             vreinterpret_u16_s16(vget_low_s16(d23)),
             vreinterpret_u16_s16(vget_high_s16(d23)), fwd_offset, bck_offset,
-            sub_const_vec, use_dist_wtd_comp_avg, &d01_u8, &d23_u8);
+            round_offset_vec, use_dist_wtd_comp_avg, &d01_u8, &d23_u8);
 
         store_u8_4x1(dst_u8_ptr + 0 * dst8_stride, d01_u8, 0);
         store_u8_4x1(dst_u8_ptr + 1 * dst8_stride, d01_u8, 1);
@@ -1239,10 +1221,10 @@
         d2 = convolve8_horiz_8_usdot(s2, x_filter, permute_tbl, horiz_const);
         d3 = convolve8_horiz_8_usdot(s3, x_filter, permute_tbl, horiz_const);
 
-        d0 = vaddq_s16(d0, round_offset128);
-        d1 = vaddq_s16(d1, round_offset128);
-        d2 = vaddq_s16(d2, round_offset128);
-        d3 = vaddq_s16(d3, round_offset128);
+        d0 = vaddq_s16(d0, round_offset_vec);
+        d1 = vaddq_s16(d1, round_offset_vec);
+        d2 = vaddq_s16(d2, round_offset_vec);
+        d3 = vaddq_s16(d3, round_offset_vec);
 
         if (conv_params->do_average) {
           load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);
@@ -1250,8 +1232,8 @@
           compute_avg_8x4(dd0, dd1, dd2, dd3, vreinterpretq_u16_s16(d0),
                           vreinterpretq_u16_s16(d1), vreinterpretq_u16_s16(d2),
                           vreinterpretq_u16_s16(d3), fwd_offset, bck_offset,
-                          sub_const_vec, use_dist_wtd_comp_avg, &d0_u8, &d1_u8,
-                          &d2_u8, &d3_u8);
+                          round_offset_vec, use_dist_wtd_comp_avg, &d0_u8,
+                          &d1_u8, &d2_u8, &d3_u8);
 
           store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);
         } else {
@@ -1287,16 +1269,12 @@
   const int horiz_offset = filter_params_x->taps / 2 - 1;
   const int bd = 8;
   const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
-  const int round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
-                           (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
+  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
+                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
+  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);
   const uint16_t fwd_offset = conv_params->fwd_offset;
   const uint16_t bck_offset = conv_params->bck_offset;
   const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
-  // Subtracting (1 << (FILTER_BITS - ROUND0_BITS - 1)) allows us to use
-  // non-rounding shifts when computing the average.
-  const int16x8_t sub_const_vec =
-      vdupq_n_s16(round_offset - (1 << (FILTER_BITS - ROUND0_BITS - 1)));
-  const int16x8_t round_offset128 = vdupq_n_s16(round_offset);
 
   // Horizontal filter.
   const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
@@ -1345,8 +1323,8 @@
       d01 = vshrq_n_s16(d01, ROUND0_BITS - 1);
       d23 = vshrq_n_s16(d23, ROUND0_BITS - 1);
 
-      d01 = vaddq_s16(d01, round_offset128);
-      d23 = vaddq_s16(d23, round_offset128);
+      d01 = vaddq_s16(d01, round_offset_vec);
+      d23 = vaddq_s16(d23, round_offset_vec);
 
       if (conv_params->do_average) {
         load_u16_4x4(dst_ptr, dst_stride, &dd0, &dd1, &dd2, &dd3);
@@ -1356,7 +1334,7 @@
             vreinterpret_u16_s16(vget_high_s16(d01)),
             vreinterpret_u16_s16(vget_low_s16(d23)),
             vreinterpret_u16_s16(vget_high_s16(d23)), fwd_offset, bck_offset,
-            sub_const_vec, use_dist_wtd_comp_avg, &d01_u8, &d23_u8);
+            round_offset_vec, use_dist_wtd_comp_avg, &d01_u8, &d23_u8);
 
         store_u8_4x1(dst_u8_ptr + 0 * dst8_stride, d01_u8, 0);
         store_u8_4x1(dst_u8_ptr + 1 * dst8_stride, d01_u8, 1);
@@ -1407,10 +1385,10 @@
         d2 = vshrq_n_s16(d2, ROUND0_BITS - 1);
         d3 = vshrq_n_s16(d3, ROUND0_BITS - 1);
 
-        d0 = vaddq_s16(d0, round_offset128);
-        d1 = vaddq_s16(d1, round_offset128);
-        d2 = vaddq_s16(d2, round_offset128);
-        d3 = vaddq_s16(d3, round_offset128);
+        d0 = vaddq_s16(d0, round_offset_vec);
+        d1 = vaddq_s16(d1, round_offset_vec);
+        d2 = vaddq_s16(d2, round_offset_vec);
+        d3 = vaddq_s16(d3, round_offset_vec);
 
         if (conv_params->do_average) {
           load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);
@@ -1418,8 +1396,8 @@
           compute_avg_8x4(dd0, dd1, dd2, dd3, vreinterpretq_u16_s16(d0),
                           vreinterpretq_u16_s16(d1), vreinterpretq_u16_s16(d2),
                           vreinterpretq_u16_s16(d3), fwd_offset, bck_offset,
-                          sub_const_vec, use_dist_wtd_comp_avg, &d0_u8, &d1_u8,
-                          &d2_u8, &d3_u8);
+                          round_offset_vec, use_dist_wtd_comp_avg, &d0_u8,
+                          &d1_u8, &d2_u8, &d3_u8);
 
           store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);
         } else {
@@ -1499,15 +1477,12 @@
   const int horiz_offset = filter_params_x->taps / 2 - 1;
   const int bd = 8;
   const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
-  const int round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
-                           (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
+  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
+                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
+  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);
   const uint16_t fwd_offset = conv_params->fwd_offset;
   const uint16_t bck_offset = conv_params->bck_offset;
   const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
-  // Subtracting (1 << (FILTER_BITS - ROUND0_BITS - 1)) allows us to use
-  // non-rounding shifts when computing the average.
-  const int16x8_t sub_const_vec =
-      vdupq_n_s16(round_offset - (1 << (FILTER_BITS - ROUND0_BITS - 1)));
 
   // horizontal filter
   const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
@@ -1539,13 +1514,11 @@
     int16x8_t tt0;
     uint16x4_t res4;
 #if defined(__aarch64__)
-    const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);
     int16x4_t s8, s9, s10, d1, d2, d3;
     int16x8_t tt1, tt2, tt3, t01, t23;
     uint16x4_t res5, res6, res7;
     int16x8_t u0, u1;
 #else   // !defined(__aarch64__)
-    const int16x4_t round_offset_vec = vdup_n_s16(round_offset);
     int16x4_t temp_0;
 #endif  // defined(__aarch64__)
     // This shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
@@ -1636,7 +1609,7 @@
           compute_avg_4x4(res4, res5, res6, res7, vreinterpret_u16_s16(d0),
                           vreinterpret_u16_s16(d1), vreinterpret_u16_s16(d2),
                           vreinterpret_u16_s16(d3), fwd_offset, bck_offset,
-                          sub_const_vec, use_dist_wtd_comp_avg, &t0, &t1);
+                          round_offset_vec, use_dist_wtd_comp_avg, &t0, &t1);
 
           store_u8_4x1(d_u8 + 0 * dst8_stride, t0, 0);
           store_u8_4x1(d_u8 + 1 * dst8_stride, t0, 1);
@@ -1689,7 +1662,7 @@
                                  horiz_const);
         // We halved the convolution filter values so -1 from the right shift.
         d0 = vshr_n_s16(d0, ROUND0_BITS - 1);
-        d0 = vadd_s16(d0, round_offset_vec);
+        d0 = vadd_s16(d0, vget_low_s16(round_offset_vec));
         s0 = s4;
         s4 = temp_0;
         if (conv_params->do_average) {
@@ -1699,7 +1672,7 @@
           res4 = vld1_u16(d);
 
           compute_avg_4x1(res4, vreinterpret_u16_s16(d0), fwd_offset,
-                          bck_offset, vget_low_s16(sub_const_vec),
+                          bck_offset, vget_low_s16(round_offset_vec),
                           use_dist_wtd_comp_avg, &t0);
 
           store_u8_4x1(d_u8, t0, 0);
@@ -1724,7 +1697,6 @@
     int16x8_t s0, s1, s2, s3, s4, s5, s6, s7;
     int16x8_t res0;
     uint16x8_t res8;
-    const int16x8_t round_offset128 = vdupq_n_s16(round_offset);
     // This shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
     // shifts - which are generally faster than rounding shifts on modern CPUs.
     // The outermost -1 is needed because we halved the filter values.
@@ -1812,14 +1784,14 @@
         res6 = vshrq_n_s16(res6, ROUND0_BITS - 1);
         res7 = vshrq_n_s16(res7, ROUND0_BITS - 1);
 
-        res0 = vaddq_s16(res0, round_offset128);
-        res1 = vaddq_s16(res1, round_offset128);
-        res2 = vaddq_s16(res2, round_offset128);
-        res3 = vaddq_s16(res3, round_offset128);
-        res4 = vaddq_s16(res4, round_offset128);
-        res5 = vaddq_s16(res5, round_offset128);
-        res6 = vaddq_s16(res6, round_offset128);
-        res7 = vaddq_s16(res7, round_offset128);
+        res0 = vaddq_s16(res0, round_offset_vec);
+        res1 = vaddq_s16(res1, round_offset_vec);
+        res2 = vaddq_s16(res2, round_offset_vec);
+        res3 = vaddq_s16(res3, round_offset_vec);
+        res4 = vaddq_s16(res4, round_offset_vec);
+        res5 = vaddq_s16(res5, round_offset_vec);
+        res6 = vaddq_s16(res6, round_offset_vec);
+        res7 = vaddq_s16(res7, round_offset_vec);
 
         transpose_s16_8x8(&res0, &res1, &res2, &res3, &res4, &res5, &res6,
                           &res7);
@@ -1832,7 +1804,7 @@
               res8, res9, res10, res11, vreinterpretq_u16_s16(res0),
               vreinterpretq_u16_s16(res1), vreinterpretq_u16_s16(res2),
               vreinterpretq_u16_s16(res3), fwd_offset, bck_offset,
-              sub_const_vec, use_dist_wtd_comp_avg, &t0, &t1, &t2, &t3);
+              round_offset_vec, use_dist_wtd_comp_avg, &t0, &t1, &t2, &t3);
 
           store_u8_8x4(d_u8, dst8_stride, t0, t1, t2, t3);
           d_u8 += 4 * dst8_stride;
@@ -1844,7 +1816,7 @@
               res8, res9, res10, res11, vreinterpretq_u16_s16(res4),
               vreinterpretq_u16_s16(res5), vreinterpretq_u16_s16(res6),
               vreinterpretq_u16_s16(res7), fwd_offset, bck_offset,
-              sub_const_vec, use_dist_wtd_comp_avg, &t0, &t1, &t2, &t3);
+              round_offset_vec, use_dist_wtd_comp_avg, &t0, &t1, &t2, &t3);
 
           store_u8_8x4(d_u8, dst8_stride, t0, t1, t2, t3);
           d_u8 += 4 * dst8_stride;
@@ -1908,14 +1880,14 @@
                                    horiz_const);
         // We halved the convolution filter values so -1 from the right shift.
         res0 = vshrq_n_s16(res0, ROUND0_BITS - 1);
-        res0 = vaddq_s16(res0, round_offset128);
+        res0 = vaddq_s16(res0, round_offset_vec);
 
         if (conv_params->do_average) {
           res8 = vld1q_u16(d_tmp);
           d_tmp += dst_stride;
 
           compute_avg_8x1(res8, vreinterpretq_u16_s16(res0), fwd_offset,
-                          bck_offset, sub_const_vec, use_dist_wtd_comp_avg,
+                          bck_offset, round_offset_vec, use_dist_wtd_comp_avg,
                           &t0);
 
           vst1_u8(d_u8, t0);
@@ -2032,12 +2004,9 @@
   const int dst_stride = conv_params->dst_stride;
   const int bd = 8;
   const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
-  const int round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
-                           (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
-  // Subtracting (1 << (FILTER_BITS - ROUND0_BITS - 1)) allows us to use
-  // non-rounding shifts when computing the average.
-  const int16x8_t sub_const_vec =
-      vdupq_n_s16(round_offset - (1 << (FILTER_BITS - ROUND0_BITS - 1)));
+  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
+                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
+  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);
   const uint16_t fwd_offset = conv_params->fwd_offset;
   const uint16_t bck_offset = conv_params->bck_offset;
   const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
@@ -2060,7 +2029,6 @@
     // shifts - which are generally faster than rounding shifts on modern CPUs.
     // The outermost -1 is needed because we halved the filter values.
     const int16x4_t vert_const = vdup_n_s16(1 << ((ROUND0_BITS - 1) - 1));
-    const int16x4_t round_offset64 = vdup_n_s16(round_offset);
     int width = w;
 
     do {
@@ -2099,10 +2067,10 @@
         d2 = convolve6_y_4x4_s16(s2, s3, s4, s5, s6, s7, y_filter, vert_const);
         d3 = convolve6_y_4x4_s16(s3, s4, s5, s6, s7, s8, y_filter, vert_const);
 
-        d0 = vadd_s16(d0, round_offset64);
-        d1 = vadd_s16(d1, round_offset64);
-        d2 = vadd_s16(d2, round_offset64);
-        d3 = vadd_s16(d3, round_offset64);
+        d0 = vadd_s16(d0, vget_low_s16(round_offset_vec));
+        d1 = vadd_s16(d1, vget_low_s16(round_offset_vec));
+        d2 = vadd_s16(d2, vget_low_s16(round_offset_vec));
+        d3 = vadd_s16(d3, vget_low_s16(round_offset_vec));
 
         if (conv_params->do_average) {
           load_u16_4x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);
@@ -2110,7 +2078,7 @@
           compute_avg_4x4(dd0, dd1, dd2, dd3, vreinterpret_u16_s16(d0),
                           vreinterpret_u16_s16(d1), vreinterpret_u16_s16(d2),
                           vreinterpret_u16_s16(d3), fwd_offset, bck_offset,
-                          sub_const_vec, use_dist_wtd_comp_avg, &d01, &d23);
+                          round_offset_vec, use_dist_wtd_comp_avg, &d01, &d23);
 
           store_u8_4x1(d_u8 + 0 * dst8_stride, d01, 0);
           store_u8_4x1(d_u8 + 1 * dst8_stride, d01, 1);
@@ -2137,13 +2105,13 @@
         s5 = vget_low_s16(tt0);
 
         d0 = convolve6_y_4x4_s16(s0, s1, s2, s3, s4, s5, y_filter, vert_const);
-        d0 = vadd_s16(d0, round_offset64);
+        d0 = vadd_s16(d0, vget_low_s16(round_offset_vec));
 
         if (conv_params->do_average) {
           dd0 = vld1_u16(d);
 
           compute_avg_4x1(dd0, vreinterpret_u16_s16(d0), fwd_offset, bck_offset,
-                          vget_low_s16(sub_const_vec), use_dist_wtd_comp_avg,
+                          vget_low_s16(round_offset_vec), use_dist_wtd_comp_avg,
                           &d01);
 
           store_u8_4x1(d_u8, d01, 0);
@@ -2176,7 +2144,6 @@
     // shifts - which are generally faster than rounding shifts on modern CPUs.
     // The outermost -1 is needed because we halved the filter values.
     const int16x8_t vert_const = vdupq_n_s16(1 << ((ROUND0_BITS - 1) - 1));
-    const int16x8_t round_offset128 = vdupq_n_s16(round_offset);
 #if defined(__aarch64__)
     int16x8_t s6, s7, s8, s9, s10, s11, s12, d1, d2, d3, d4, d5, d6, d7;
     uint16x8_t d9, d10, d11;
@@ -2222,14 +2189,14 @@
         d7 = convolve6_y_8x4_s16(s7, s8, s9, s10, s11, s12, y_filter,
                                  vert_const);
 
-        d0 = vaddq_s16(d0, round_offset128);
-        d1 = vaddq_s16(d1, round_offset128);
-        d2 = vaddq_s16(d2, round_offset128);
-        d3 = vaddq_s16(d3, round_offset128);
-        d4 = vaddq_s16(d4, round_offset128);
-        d5 = vaddq_s16(d5, round_offset128);
-        d6 = vaddq_s16(d6, round_offset128);
-        d7 = vaddq_s16(d7, round_offset128);
+        d0 = vaddq_s16(d0, round_offset_vec);
+        d1 = vaddq_s16(d1, round_offset_vec);
+        d2 = vaddq_s16(d2, round_offset_vec);
+        d3 = vaddq_s16(d3, round_offset_vec);
+        d4 = vaddq_s16(d4, round_offset_vec);
+        d5 = vaddq_s16(d5, round_offset_vec);
+        d6 = vaddq_s16(d6, round_offset_vec);
+        d7 = vaddq_s16(d7, round_offset_vec);
 
         if (conv_params->do_average) {
           load_u16_8x4(d, dst_stride, &d8, &d9, &d10, &d11);
@@ -2238,8 +2205,8 @@
           compute_avg_8x4(d8, d9, d10, d11, vreinterpretq_u16_s16(d0),
                           vreinterpretq_u16_s16(d1), vreinterpretq_u16_s16(d2),
                           vreinterpretq_u16_s16(d3), fwd_offset, bck_offset,
-                          sub_const_vec, use_dist_wtd_comp_avg, &t0, &t1, &t2,
-                          &t3);
+                          round_offset_vec, use_dist_wtd_comp_avg, &t0, &t1,
+                          &t2, &t3);
 
           store_u8_8x4(d_u8, dst8_stride, t0, t1, t2, t3);
           d_u8 += 4 * dst8_stride;
@@ -2250,8 +2217,8 @@
           compute_avg_8x4(d8, d9, d10, d11, vreinterpretq_u16_s16(d4),
                           vreinterpretq_u16_s16(d5), vreinterpretq_u16_s16(d6),
                           vreinterpretq_u16_s16(d7), fwd_offset, bck_offset,
-                          sub_const_vec, use_dist_wtd_comp_avg, &t0, &t1, &t2,
-                          &t3);
+                          round_offset_vec, use_dist_wtd_comp_avg, &t0, &t1,
+                          &t2, &t3);
 
           store_u8_8x4(d_u8, dst8_stride, t0, t1, t2, t3);
           d_u8 += 4 * dst8_stride;
@@ -2275,7 +2242,7 @@
         s5 = vreinterpretq_s16_u16(vmovl_u8(vld1_u8(s)));
 
         d0 = convolve6_y_8x4_s16(s0, s1, s2, s3, s4, s5, y_filter, vert_const);
-        d0 = vaddq_s16(d0, round_offset128);
+        d0 = vaddq_s16(d0, round_offset_vec);
 
         s0 = s1;
         s1 = s2;
@@ -2288,7 +2255,7 @@
           d += dst_stride;
 
           compute_avg_8x1(d8, vreinterpretq_u16_s16(d0), fwd_offset, bck_offset,
-                          sub_const_vec, use_dist_wtd_comp_avg, &t0);
+                          round_offset_vec, use_dist_wtd_comp_avg, &t0);
 
           vst1_u8(d_u8, t0);
           d_u8 += dst8_stride;
@@ -2319,12 +2286,9 @@
   const int dst_stride = conv_params->dst_stride;
   const int bd = 8;
   const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
-  const int round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
-                           (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
-  // Subtracting (1 << (FILTER_BITS - ROUND0_BITS - 1)) allows us to use
-  // non-rounding shifts when computing the average.
-  const int16x8_t sub_const_vec =
-      vdupq_n_s16(round_offset - (1 << (FILTER_BITS - ROUND0_BITS - 1)));
+  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
+                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
+  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);
   const uint16_t fwd_offset = conv_params->fwd_offset;
   const uint16_t bck_offset = conv_params->bck_offset;
   const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
@@ -2340,13 +2304,10 @@
     uint8x8_t d01;
 
 #if defined(__aarch64__)
-    const int16x8_t round_offset64 = vdupq_n_s16(round_offset);
     int16x4_t s8, s9, s10, d1, d2, d3;
     uint16x4_t dd1, dd2, dd3;
     int16x8_t t01, t23;
     uint8x8_t d23;
-#else  // !defined(__aarch64__)
-    const int16x4_t round_offset64 = vdup_n_s16(round_offset);
 #endif
     // This shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
     // shifts - which are generally faster than rounding shifts on modern CPUs.
@@ -2414,8 +2375,8 @@
         t01 = vshrq_n_s16(t01, ROUND0_BITS - 1);
         t23 = vshrq_n_s16(t23, ROUND0_BITS - 1);
 
-        t01 = vaddq_s16(t01, round_offset64);
-        t23 = vaddq_s16(t23, round_offset64);
+        t01 = vaddq_s16(t01, round_offset_vec);
+        t23 = vaddq_s16(t23, round_offset_vec);
 
         d0 = vget_low_s16(t01);
         d1 = vget_high_s16(t01);
@@ -2438,7 +2399,7 @@
           compute_avg_4x4(dd0, dd1, dd2, dd3, vreinterpret_u16_s16(d0),
                           vreinterpret_u16_s16(d1), vreinterpret_u16_s16(d2),
                           vreinterpret_u16_s16(d3), fwd_offset, bck_offset,
-                          sub_const_vec, use_dist_wtd_comp_avg, &d01, &d23);
+                          round_offset_vec, use_dist_wtd_comp_avg, &d01, &d23);
 
           store_u8_4x1(d_u8 + 0 * dst8_stride, d01, 0);
           store_u8_4x1(d_u8 + 1 * dst8_stride, d01, 1);
@@ -2470,7 +2431,7 @@
                                  vert_const);
         // We halved the convolution filter values so -1 from the right shift.
         d0 = vshr_n_s16(d0, ROUND0_BITS - 1);
-        d0 = vadd_s16(d0, round_offset64);
+        d0 = vadd_s16(d0, vget_low_s16(round_offset_vec));
 
         if (conv_params->do_average) {
           __builtin_prefetch(d);
@@ -2478,7 +2439,7 @@
           dd0 = vld1_u16(d);
 
           compute_avg_4x1(dd0, vreinterpret_u16_s16(d0), fwd_offset, bck_offset,
-                          vget_low_s16(sub_const_vec), use_dist_wtd_comp_avg,
+                          vget_low_s16(round_offset_vec), use_dist_wtd_comp_avg,
                           &d01);
 
           store_u8_4x1(d_u8, d01, 0);
@@ -2505,7 +2466,6 @@
       width -= 4;
     } while (width > 0);
   } else {
-    const int16x8_t round_offset128 = vdupq_n_s16(round_offset);
     // This shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
     // shifts - which are generally faster than rounding shifts on modern CPUs.
     // The outermost -1 is needed because we halved the filter values.
@@ -2582,14 +2542,14 @@
         d7 = convolve8_y_8x8_s16(s7, s8, s9, s10, s11, s12, s13, s14, y_filter,
                                  vert_const);
 
-        d0 = vaddq_s16(d0, round_offset128);
-        d1 = vaddq_s16(d1, round_offset128);
-        d2 = vaddq_s16(d2, round_offset128);
-        d3 = vaddq_s16(d3, round_offset128);
-        d4 = vaddq_s16(d4, round_offset128);
-        d5 = vaddq_s16(d5, round_offset128);
-        d6 = vaddq_s16(d6, round_offset128);
-        d7 = vaddq_s16(d7, round_offset128);
+        d0 = vaddq_s16(d0, round_offset_vec);
+        d1 = vaddq_s16(d1, round_offset_vec);
+        d2 = vaddq_s16(d2, round_offset_vec);
+        d3 = vaddq_s16(d3, round_offset_vec);
+        d4 = vaddq_s16(d4, round_offset_vec);
+        d5 = vaddq_s16(d5, round_offset_vec);
+        d6 = vaddq_s16(d6, round_offset_vec);
+        d7 = vaddq_s16(d7, round_offset_vec);
 
         if (conv_params->do_average) {
           __builtin_prefetch(d + 0 * dst8_stride);
@@ -2603,8 +2563,8 @@
           compute_avg_8x4(dd0, dd1, dd2, dd3, vreinterpretq_u16_s16(d0),
                           vreinterpretq_u16_s16(d1), vreinterpretq_u16_s16(d2),
                           vreinterpretq_u16_s16(d3), fwd_offset, bck_offset,
-                          sub_const_vec, use_dist_wtd_comp_avg, &t0, &t1, &t2,
-                          &t3);
+                          round_offset_vec, use_dist_wtd_comp_avg, &t0, &t1,
+                          &t2, &t3);
 
           store_u8_8x4(d_u8, dst8_stride, t0, t1, t2, t3);
           d_u8 += 4 * dst8_stride;
@@ -2615,8 +2575,8 @@
           compute_avg_8x4(dd0, dd1, dd2, dd3, vreinterpretq_u16_s16(d4),
                           vreinterpretq_u16_s16(d5), vreinterpretq_u16_s16(d6),
                           vreinterpretq_u16_s16(d7), fwd_offset, bck_offset,
-                          sub_const_vec, use_dist_wtd_comp_avg, &t0, &t1, &t2,
-                          &t3);
+                          round_offset_vec, use_dist_wtd_comp_avg, &t0, &t1,
+                          &t2, &t3);
 
           store_u8_8x4(d_u8, dst8_stride, t0, t1, t2, t3);
           d_u8 += 4 * dst8_stride;
@@ -2645,7 +2605,7 @@
 
         d0 = convolve8_y_8x8_s16(s0, s1, s2, s3, s4, s5, s6, s7, y_filter,
                                  vert_const);
-        d0 = vaddq_s16(d0, round_offset128);
+        d0 = vaddq_s16(d0, round_offset_vec);
 
         s0 = s1;
         s1 = s2;
@@ -2662,7 +2622,7 @@
           d += dst_stride;
 
           compute_avg_8x1(dd0, vreinterpretq_u16_s16(d0), fwd_offset,
-                          bck_offset, sub_const_vec, use_dist_wtd_comp_avg,
+                          bck_offset, round_offset_vec, use_dist_wtd_comp_avg,
                           &t0);
 
           vst1_u8(d_u8, t0);