Optimize Neon implementation of averaging convolution functions

This patch propagates constants into the compute_avg_* helpers used by
the dist_wtd_convolve_* functions and simplifies the computation of both
the basic and the distance-weighted compound averages.
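
The key simplification is folding the rounding bias of the final shift
into the subtracted offset, so a non-rounding saturating narrow shift
can replace the rounding shift + narrow pair. A scalar sketch of the
idea (illustrative only, not part of the patch; clip_u8() stands in for
the saturating narrow, and only the basic-average path is shown):

  // Before: subtract the full compound offset, then apply a rounding
  // shift (vqrshl) followed by a saturating narrow (vqmovun).
  uint8_t avg_before(uint16_t res, uint16_t d, int16_t round_offset) {
    int32_t avg = (res + d) >> 1;  // vhadd
    const int n = FILTER_BITS - ROUND0_BITS;
    int32_t v = avg - round_offset;
    return clip_u8((v + (1 << (n - 1))) >> n);
  }

  // After: pre-add the rounding bias by subtracting it from the offset,
  // so one saturating non-rounding narrow shift (vqshrun_n) suffices.
  uint8_t avg_after(uint16_t res, uint16_t d, int16_t round_offset) {
    int32_t avg = (res + d) >> 1;  // vhadd
    const int n = FILTER_BITS - ROUND0_BITS;
    int32_t v = avg - (round_offset - (1 << (n - 1)));
    return clip_u8(v >> n);
  }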

Change-Id: Id278eb673ea3c9417bffe01ca02f5f5d736c280b
diff --git a/av1/common/arm/jnt_convolve_neon.c b/av1/common/arm/jnt_convolve_neon.c
index 8cb17fc..4eefb71 100644
--- a/av1/common/arm/jnt_convolve_neon.c
+++ b/av1/common/arm/jnt_convolve_neon.c
@@ -23,262 +23,158 @@
 #include "av1/common/arm/convolve_neon.h"
 
 #if !defined(__aarch64__)
-static INLINE void compute_avg_4x1(
-    uint16x4_t res0, uint16x4_t d0, const uint16_t fwd_offset,
-    const uint16_t bck_offset, const int16x4_t sub_const_vec,
-    const int16_t round_bits, const int use_dist_wtd_comp_avg, uint8x8_t *t0) {
-  int16x4_t tmp0;
-  uint16x4_t tmp_u0;
-  uint32x4_t sum0;
-  int32x4_t dst0;
-  int16x8_t tmp4;
+static INLINE void compute_avg_4x1(uint16x4_t res0, uint16x4_t d0,
+                                   const uint16_t fwd_offset,
+                                   const uint16_t bck_offset,
+                                   const int16x4_t sub_const,
+                                   const int use_dist_wtd_comp_avg,
+                                   uint8x8_t *t0) {
+  uint16x4_t avg0;
 
   if (use_dist_wtd_comp_avg) {
-    const int32x4_t round_bits_vec = vdupq_n_s32((int32_t)(-round_bits));
+    uint32x4_t blend0;
+    blend0 = vmull_n_u16(res0, fwd_offset);
+    blend0 = vmlal_n_u16(blend0, d0, bck_offset);
 
-    sum0 = vmull_n_u16(res0, fwd_offset);
-    sum0 = vmlal_n_u16(sum0, d0, bck_offset);
-
-    sum0 = vshrq_n_u32(sum0, DIST_PRECISION_BITS);
-
-    dst0 = vsubq_s32(vreinterpretq_s32_u32(sum0), vmovl_s16(sub_const_vec));
-
-    dst0 = vqrshlq_s32(dst0, round_bits_vec);
-
-    tmp0 = vmovn_s32(dst0);
-    tmp4 = vcombine_s16(tmp0, tmp0);
-
-    *t0 = vqmovun_s16(tmp4);
+    avg0 = vshrn_n_u32(blend0, DIST_PRECISION_BITS);
   } else {
-    const int16x4_t round_bits_vec = vdup_n_s16(-round_bits);
-    tmp_u0 = vhadd_u16(res0, d0);
-
-    tmp0 = vsub_s16(vreinterpret_s16_u16(tmp_u0), sub_const_vec);
-
-    tmp0 = vqrshl_s16(tmp0, round_bits_vec);
-
-    tmp4 = vcombine_s16(tmp0, vdup_n_s16(0));
-
-    *t0 = vqmovun_s16(tmp4);
+    avg0 = vhadd_u16(res0, d0);
   }
+
+  int16x4_t dst0 = vsub_s16(vreinterpret_s16_u16(avg0), sub_const);
+
+  int16x8_t dst0q = vcombine_s16(dst0, vdup_n_s16(0));
+
+  *t0 = vqshrun_n_s16(dst0q, FILTER_BITS - ROUND0_BITS);
 }
 
-static INLINE void compute_avg_8x1(
-    uint16x8_t res0, uint16x8_t d0, const uint16_t fwd_offset,
-    const uint16_t bck_offset, const int16x4_t sub_const,
-    const int16_t round_bits, const int use_dist_wtd_comp_avg, uint8x8_t *t0) {
-  int16x8_t f0;
-  uint32x4_t sum0, sum2;
-  int32x4_t dst0, dst2;
-
-  uint16x8_t tmp_u0;
+static INLINE void compute_avg_8x1(uint16x8_t res0, uint16x8_t d0,
+                                   const uint16_t fwd_offset,
+                                   const uint16_t bck_offset,
+                                   const int16x8_t sub_const,
+                                   const int use_dist_wtd_comp_avg,
+                                   uint8x8_t *t0) {
+  uint16x8_t avg0;
 
   if (use_dist_wtd_comp_avg) {
-    const int32x4_t sub_const_vec = vmovl_s16(sub_const);
-    const int32x4_t round_bits_vec = vdupq_n_s32(-(int32_t)round_bits);
+    uint32x4_t blend0_lo, blend0_hi;
 
-    sum0 = vmull_n_u16(vget_low_u16(res0), fwd_offset);
-    sum0 = vmlal_n_u16(sum0, vget_low_u16(d0), bck_offset);
-    sum0 = vshrq_n_u32(sum0, DIST_PRECISION_BITS);
+    blend0_lo = vmull_n_u16(vget_low_u16(res0), fwd_offset);
+    blend0_lo = vmlal_n_u16(blend0_lo, vget_low_u16(d0), bck_offset);
+    blend0_hi = vmull_n_u16(vget_high_u16(res0), fwd_offset);
+    blend0_hi = vmlal_n_u16(blend0_hi, vget_high_u16(d0), bck_offset);
 
-    sum2 = vmull_n_u16(vget_high_u16(res0), fwd_offset);
-    sum2 = vmlal_n_u16(sum2, vget_high_u16(d0), bck_offset);
-    sum2 = vshrq_n_u32(sum2, DIST_PRECISION_BITS);
-
-    dst0 = vsubq_s32(vreinterpretq_s32_u32(sum0), sub_const_vec);
-    dst2 = vsubq_s32(vreinterpretq_s32_u32(sum2), sub_const_vec);
-
-    dst0 = vqrshlq_s32(dst0, round_bits_vec);
-    dst2 = vqrshlq_s32(dst2, round_bits_vec);
-
-    f0 = vcombine_s16(vmovn_s32(dst0), vmovn_s32(dst2));
-
-    *t0 = vqmovun_s16(f0);
-
+    avg0 = vcombine_u16(vshrn_n_u32(blend0_lo, DIST_PRECISION_BITS),
+                        vshrn_n_u32(blend0_hi, DIST_PRECISION_BITS));
   } else {
-    const int16x8_t sub_const_vec = vcombine_s16(sub_const, sub_const);
-    const int16x8_t round_bits_vec = vdupq_n_s16(-round_bits);
-
-    tmp_u0 = vhaddq_u16(res0, d0);
-
-    f0 = vsubq_s16(vreinterpretq_s16_u16(tmp_u0), sub_const_vec);
-
-    f0 = vqrshlq_s16(f0, round_bits_vec);
-
-    *t0 = vqmovun_s16(f0);
+    avg0 = vhaddq_u16(res0, d0);
   }
+
+  int16x8_t dst0 = vsubq_s16(vreinterpretq_s16_u16(avg0), sub_const);
+
+  *t0 = vqshrun_n_s16(dst0, FILTER_BITS - ROUND0_BITS);
 }
 #endif  // !defined(__aarch64__)
 
-static INLINE void compute_avg_4x4(
-    uint16x4_t res0, uint16x4_t res1, uint16x4_t res2, uint16x4_t res3,
-    uint16x4_t d0, uint16x4_t d1, uint16x4_t d2, uint16x4_t d3,
-    const uint16_t fwd_offset, const uint16_t bck_offset,
-    const int16x4_t sub_const_vec, const int16_t round_bits,
-    const int use_dist_wtd_comp_avg, uint8x8_t *t0, uint8x8_t *t1) {
-  int16x4_t tmp0, tmp1, tmp2, tmp3;
-  uint16x4_t tmp_u0, tmp_u1, tmp_u2, tmp_u3;
-  uint32x4_t sum0, sum1, sum2, sum3;
-
-  int32x4_t dst0, dst1, dst2, dst3;
-  int16x8_t tmp4, tmp5;
+static INLINE void compute_avg_4x4(uint16x4_t res0, uint16x4_t res1,
+                                   uint16x4_t res2, uint16x4_t res3,
+                                   uint16x4_t d0, uint16x4_t d1, uint16x4_t d2,
+                                   uint16x4_t d3, const uint16_t fwd_offset,
+                                   const uint16_t bck_offset,
+                                   const int16x8_t sub_const,
+                                   const int use_dist_wtd_comp_avg,
+                                   uint8x8_t *t0, uint8x8_t *t1) {
+  uint16x4_t avg0, avg1, avg2, avg3;
 
   if (use_dist_wtd_comp_avg) {
-    const int32x4_t round_bits_vec = vdupq_n_s32((int32_t)(-round_bits));
-    const int32x4_t const_vec = vmovl_s16(sub_const_vec);
+    uint32x4_t blend0, blend1, blend2, blend3;
 
-    sum0 = vmull_n_u16(res0, fwd_offset);
-    sum0 = vmlal_n_u16(sum0, d0, bck_offset);
-    sum1 = vmull_n_u16(res1, fwd_offset);
-    sum1 = vmlal_n_u16(sum1, d1, bck_offset);
-    sum2 = vmull_n_u16(res2, fwd_offset);
-    sum2 = vmlal_n_u16(sum2, d2, bck_offset);
-    sum3 = vmull_n_u16(res3, fwd_offset);
-    sum3 = vmlal_n_u16(sum3, d3, bck_offset);
+    blend0 = vmull_n_u16(res0, fwd_offset);
+    blend0 = vmlal_n_u16(blend0, d0, bck_offset);
+    blend1 = vmull_n_u16(res1, fwd_offset);
+    blend1 = vmlal_n_u16(blend1, d1, bck_offset);
+    blend2 = vmull_n_u16(res2, fwd_offset);
+    blend2 = vmlal_n_u16(blend2, d2, bck_offset);
+    blend3 = vmull_n_u16(res3, fwd_offset);
+    blend3 = vmlal_n_u16(blend3, d3, bck_offset);
 
-    sum0 = vshrq_n_u32(sum0, DIST_PRECISION_BITS);
-    sum1 = vshrq_n_u32(sum1, DIST_PRECISION_BITS);
-    sum2 = vshrq_n_u32(sum2, DIST_PRECISION_BITS);
-    sum3 = vshrq_n_u32(sum3, DIST_PRECISION_BITS);
-
-    dst0 = vsubq_s32(vreinterpretq_s32_u32(sum0), const_vec);
-    dst1 = vsubq_s32(vreinterpretq_s32_u32(sum1), const_vec);
-    dst2 = vsubq_s32(vreinterpretq_s32_u32(sum2), const_vec);
-    dst3 = vsubq_s32(vreinterpretq_s32_u32(sum3), const_vec);
-
-    dst0 = vqrshlq_s32(dst0, round_bits_vec);
-    dst1 = vqrshlq_s32(dst1, round_bits_vec);
-    dst2 = vqrshlq_s32(dst2, round_bits_vec);
-    dst3 = vqrshlq_s32(dst3, round_bits_vec);
-
-    tmp4 = vcombine_s16(vmovn_s32(dst0), vmovn_s32(dst1));
-    tmp5 = vcombine_s16(vmovn_s32(dst2), vmovn_s32(dst3));
-
-    *t0 = vqmovun_s16(tmp4);
-    *t1 = vqmovun_s16(tmp5);
+    avg0 = vshrn_n_u32(blend0, DIST_PRECISION_BITS);
+    avg1 = vshrn_n_u32(blend1, DIST_PRECISION_BITS);
+    avg2 = vshrn_n_u32(blend2, DIST_PRECISION_BITS);
+    avg3 = vshrn_n_u32(blend3, DIST_PRECISION_BITS);
   } else {
-    const int16x4_t round_bits_vec = vdup_n_s16(-round_bits);
-    tmp_u0 = vhadd_u16(res0, d0);
-    tmp_u1 = vhadd_u16(res1, d1);
-    tmp_u2 = vhadd_u16(res2, d2);
-    tmp_u3 = vhadd_u16(res3, d3);
-
-    tmp0 = vsub_s16(vreinterpret_s16_u16(tmp_u0), sub_const_vec);
-    tmp1 = vsub_s16(vreinterpret_s16_u16(tmp_u1), sub_const_vec);
-    tmp2 = vsub_s16(vreinterpret_s16_u16(tmp_u2), sub_const_vec);
-    tmp3 = vsub_s16(vreinterpret_s16_u16(tmp_u3), sub_const_vec);
-
-    tmp0 = vqrshl_s16(tmp0, round_bits_vec);
-    tmp1 = vqrshl_s16(tmp1, round_bits_vec);
-    tmp2 = vqrshl_s16(tmp2, round_bits_vec);
-    tmp3 = vqrshl_s16(tmp3, round_bits_vec);
-
-    tmp4 = vcombine_s16(tmp0, tmp1);
-    tmp5 = vcombine_s16(tmp2, tmp3);
-
-    *t0 = vqmovun_s16(tmp4);
-    *t1 = vqmovun_s16(tmp5);
+    avg0 = vhadd_u16(res0, d0);
+    avg1 = vhadd_u16(res1, d1);
+    avg2 = vhadd_u16(res2, d2);
+    avg3 = vhadd_u16(res3, d3);
   }
+
+  int16x8_t dst_01 = vreinterpretq_s16_u16(vcombine_u16(avg0, avg1));
+  int16x8_t dst_23 = vreinterpretq_s16_u16(vcombine_u16(avg2, avg3));
+
+  dst_01 = vsubq_s16(dst_01, sub_const);
+  dst_23 = vsubq_s16(dst_23, sub_const);
+
+  *t0 = vqshrun_n_s16(dst_01, FILTER_BITS - ROUND0_BITS);
+  *t1 = vqshrun_n_s16(dst_23, FILTER_BITS - ROUND0_BITS);
 }
 
 static INLINE void compute_avg_8x4(
     uint16x8_t res0, uint16x8_t res1, uint16x8_t res2, uint16x8_t res3,
     uint16x8_t d0, uint16x8_t d1, uint16x8_t d2, uint16x8_t d3,
     const uint16_t fwd_offset, const uint16_t bck_offset,
-    const int16x4_t sub_const, const int16_t round_bits,
-    const int use_dist_wtd_comp_avg, uint8x8_t *t0, uint8x8_t *t1,
-    uint8x8_t *t2, uint8x8_t *t3) {
-  int16x8_t f0, f1, f2, f3;
-  uint32x4_t sum0, sum1, sum2, sum3;
-  uint32x4_t sum4, sum5, sum6, sum7;
-  int32x4_t dst0, dst1, dst2, dst3;
-  int32x4_t dst4, dst5, dst6, dst7;
-  uint16x8_t tmp_u0, tmp_u1, tmp_u2, tmp_u3;
+    const int16x8_t sub_const, const int use_dist_wtd_comp_avg, uint8x8_t *t0,
+    uint8x8_t *t1, uint8x8_t *t2, uint8x8_t *t3) {
+  uint16x8_t avg0, avg1, avg2, avg3;
 
   if (use_dist_wtd_comp_avg) {
-    const int32x4_t sub_const_vec = vmovl_s16(sub_const);
-    const int32x4_t round_bits_vec = vdupq_n_s32(-(int32_t)round_bits);
+    uint32x4_t blend0_lo, blend1_lo, blend2_lo, blend3_lo;
+    uint32x4_t blend0_hi, blend1_hi, blend2_hi, blend3_hi;
 
-    sum0 = vmull_n_u16(vget_low_u16(res0), fwd_offset);
-    sum0 = vmlal_n_u16(sum0, vget_low_u16(d0), bck_offset);
-    sum1 = vmull_n_u16(vget_low_u16(res1), fwd_offset);
-    sum1 = vmlal_n_u16(sum1, vget_low_u16(d1), bck_offset);
-    sum0 = vshrq_n_u32(sum0, DIST_PRECISION_BITS);
-    sum1 = vshrq_n_u32(sum1, DIST_PRECISION_BITS);
+    blend0_lo = vmull_n_u16(vget_low_u16(res0), fwd_offset);
+    blend0_lo = vmlal_n_u16(blend0_lo, vget_low_u16(d0), bck_offset);
+    blend0_hi = vmull_n_u16(vget_high_u16(res0), fwd_offset);
+    blend0_hi = vmlal_n_u16(blend0_hi, vget_high_u16(d0), bck_offset);
 
-    sum2 = vmull_n_u16(vget_high_u16(res0), fwd_offset);
-    sum2 = vmlal_n_u16(sum2, vget_high_u16(d0), bck_offset);
-    sum3 = vmull_n_u16(vget_high_u16(res1), fwd_offset);
-    sum3 = vmlal_n_u16(sum3, vget_high_u16(d1), bck_offset);
-    sum2 = vshrq_n_u32(sum2, DIST_PRECISION_BITS);
-    sum3 = vshrq_n_u32(sum3, DIST_PRECISION_BITS);
+    blend1_lo = vmull_n_u16(vget_low_u16(res1), fwd_offset);
+    blend1_lo = vmlal_n_u16(blend1_lo, vget_low_u16(d1), bck_offset);
+    blend1_hi = vmull_n_u16(vget_high_u16(res1), fwd_offset);
+    blend1_hi = vmlal_n_u16(blend1_hi, vget_high_u16(d1), bck_offset);
 
-    sum4 = vmull_n_u16(vget_low_u16(res2), fwd_offset);
-    sum4 = vmlal_n_u16(sum4, vget_low_u16(d2), bck_offset);
-    sum5 = vmull_n_u16(vget_low_u16(res3), fwd_offset);
-    sum5 = vmlal_n_u16(sum5, vget_low_u16(d3), bck_offset);
-    sum4 = vshrq_n_u32(sum4, DIST_PRECISION_BITS);
-    sum5 = vshrq_n_u32(sum5, DIST_PRECISION_BITS);
+    blend2_lo = vmull_n_u16(vget_low_u16(res2), fwd_offset);
+    blend2_lo = vmlal_n_u16(blend2_lo, vget_low_u16(d2), bck_offset);
+    blend2_hi = vmull_n_u16(vget_high_u16(res2), fwd_offset);
+    blend2_hi = vmlal_n_u16(blend2_hi, vget_high_u16(d2), bck_offset);
 
-    sum6 = vmull_n_u16(vget_high_u16(res2), fwd_offset);
-    sum6 = vmlal_n_u16(sum6, vget_high_u16(d2), bck_offset);
-    sum7 = vmull_n_u16(vget_high_u16(res3), fwd_offset);
-    sum7 = vmlal_n_u16(sum7, vget_high_u16(d3), bck_offset);
-    sum6 = vshrq_n_u32(sum6, DIST_PRECISION_BITS);
-    sum7 = vshrq_n_u32(sum7, DIST_PRECISION_BITS);
+    blend3_lo = vmull_n_u16(vget_low_u16(res3), fwd_offset);
+    blend3_lo = vmlal_n_u16(blend3_lo, vget_low_u16(d3), bck_offset);
+    blend3_hi = vmull_n_u16(vget_high_u16(res3), fwd_offset);
+    blend3_hi = vmlal_n_u16(blend3_hi, vget_high_u16(d3), bck_offset);
 
-    dst0 = vsubq_s32(vreinterpretq_s32_u32(sum0), sub_const_vec);
-    dst1 = vsubq_s32(vreinterpretq_s32_u32(sum1), sub_const_vec);
-    dst2 = vsubq_s32(vreinterpretq_s32_u32(sum2), sub_const_vec);
-    dst3 = vsubq_s32(vreinterpretq_s32_u32(sum3), sub_const_vec);
-    dst4 = vsubq_s32(vreinterpretq_s32_u32(sum4), sub_const_vec);
-    dst5 = vsubq_s32(vreinterpretq_s32_u32(sum5), sub_const_vec);
-    dst6 = vsubq_s32(vreinterpretq_s32_u32(sum6), sub_const_vec);
-    dst7 = vsubq_s32(vreinterpretq_s32_u32(sum7), sub_const_vec);
-
-    dst0 = vqrshlq_s32(dst0, round_bits_vec);
-    dst1 = vqrshlq_s32(dst1, round_bits_vec);
-    dst2 = vqrshlq_s32(dst2, round_bits_vec);
-    dst3 = vqrshlq_s32(dst3, round_bits_vec);
-    dst4 = vqrshlq_s32(dst4, round_bits_vec);
-    dst5 = vqrshlq_s32(dst5, round_bits_vec);
-    dst6 = vqrshlq_s32(dst6, round_bits_vec);
-    dst7 = vqrshlq_s32(dst7, round_bits_vec);
-
-    f0 = vcombine_s16(vmovn_s32(dst0), vmovn_s32(dst2));
-    f1 = vcombine_s16(vmovn_s32(dst1), vmovn_s32(dst3));
-    f2 = vcombine_s16(vmovn_s32(dst4), vmovn_s32(dst6));
-    f3 = vcombine_s16(vmovn_s32(dst5), vmovn_s32(dst7));
-
-    *t0 = vqmovun_s16(f0);
-    *t1 = vqmovun_s16(f1);
-    *t2 = vqmovun_s16(f2);
-    *t3 = vqmovun_s16(f3);
-
+    avg0 = vcombine_u16(vshrn_n_u32(blend0_lo, DIST_PRECISION_BITS),
+                        vshrn_n_u32(blend0_hi, DIST_PRECISION_BITS));
+    avg1 = vcombine_u16(vshrn_n_u32(blend1_lo, DIST_PRECISION_BITS),
+                        vshrn_n_u32(blend1_hi, DIST_PRECISION_BITS));
+    avg2 = vcombine_u16(vshrn_n_u32(blend2_lo, DIST_PRECISION_BITS),
+                        vshrn_n_u32(blend2_hi, DIST_PRECISION_BITS));
+    avg3 = vcombine_u16(vshrn_n_u32(blend3_lo, DIST_PRECISION_BITS),
+                        vshrn_n_u32(blend3_hi, DIST_PRECISION_BITS));
   } else {
-    const int16x8_t sub_const_vec = vcombine_s16(sub_const, sub_const);
-    const int16x8_t round_bits_vec = vdupq_n_s16(-round_bits);
-
-    tmp_u0 = vhaddq_u16(res0, d0);
-    tmp_u1 = vhaddq_u16(res1, d1);
-    tmp_u2 = vhaddq_u16(res2, d2);
-    tmp_u3 = vhaddq_u16(res3, d3);
-
-    f0 = vsubq_s16(vreinterpretq_s16_u16(tmp_u0), sub_const_vec);
-    f1 = vsubq_s16(vreinterpretq_s16_u16(tmp_u1), sub_const_vec);
-    f2 = vsubq_s16(vreinterpretq_s16_u16(tmp_u2), sub_const_vec);
-    f3 = vsubq_s16(vreinterpretq_s16_u16(tmp_u3), sub_const_vec);
-
-    f0 = vqrshlq_s16(f0, round_bits_vec);
-    f1 = vqrshlq_s16(f1, round_bits_vec);
-    f2 = vqrshlq_s16(f2, round_bits_vec);
-    f3 = vqrshlq_s16(f3, round_bits_vec);
-
-    *t0 = vqmovun_s16(f0);
-    *t1 = vqmovun_s16(f1);
-    *t2 = vqmovun_s16(f2);
-    *t3 = vqmovun_s16(f3);
+    avg0 = vhaddq_u16(res0, d0);
+    avg1 = vhaddq_u16(res1, d1);
+    avg2 = vhaddq_u16(res2, d2);
+    avg3 = vhaddq_u16(res3, d3);
   }
+
+  int16x8_t dst0 = vsubq_s16(vreinterpretq_s16_u16(avg0), sub_const);
+  int16x8_t dst1 = vsubq_s16(vreinterpretq_s16_u16(avg1), sub_const);
+  int16x8_t dst2 = vsubq_s16(vreinterpretq_s16_u16(avg2), sub_const);
+  int16x8_t dst3 = vsubq_s16(vreinterpretq_s16_u16(avg3), sub_const);
+
+  *t0 = vqshrun_n_s16(dst0, FILTER_BITS - ROUND0_BITS);
+  *t1 = vqshrun_n_s16(dst1, FILTER_BITS - ROUND0_BITS);
+  *t2 = vqshrun_n_s16(dst2, FILTER_BITS - ROUND0_BITS);
+  *t3 = vqshrun_n_s16(dst3, FILTER_BITS - ROUND0_BITS);
 }
 
 #if defined(__aarch64__) && defined(__ARM_FEATURE_MATMUL_INT8)
@@ -710,14 +606,15 @@
 
   const int bd = 8;
   const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
+  // Subtracting (1 << (FILTER_BITS - ROUND0_BITS - 1)) allows us to use
+  // non-rounding shifts when computing the average.
   const int16_t sub_const = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
-                            (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
+                            (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1)) -
+                            (1 << (FILTER_BITS - ROUND0_BITS - 1));
 
-  const int16_t round_bits =
-      2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS;
   const int offset = bd + 2 * FILTER_BITS - ROUND0_BITS;
   const int32x4_t offset_const = vdupq_n_s32(1 << offset);
-  const int16x4_t sub_const_vec = vdup_n_s16(sub_const);
+  const int16x8_t sub_const_vec = vdupq_n_s16(sub_const);
   const uint16_t fwd_offset = conv_params->fwd_offset;
   const uint16_t bck_offset = conv_params->bck_offset;
   const int do_average = conv_params->do_average;
@@ -750,8 +647,8 @@
         load_u16_4x4(dst_ptr, dst_stride, &dd0, &dd1, &dd2, &dd3);
 
         compute_avg_4x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
-                        bck_offset, sub_const_vec, round_bits,
-                        use_dist_wtd_comp_avg, &d01_u8, &d23_u8);
+                        bck_offset, sub_const_vec, use_dist_wtd_comp_avg,
+                        &d01_u8, &d23_u8);
 
         store_u8_4x1(dst8_ptr + 0 * dst8_stride, d01_u8, 0);
         store_u8_4x1(dst8_ptr + 1 * dst8_stride, d01_u8, 1);
@@ -778,8 +675,9 @@
       if (do_average) {
         dd0 = vld1_u16(dst_ptr);
 
-        compute_avg_4x1(dd0, d0, fwd_offset, bck_offset, sub_const_vec,
-                        round_bits, use_dist_wtd_comp_avg, &d01_u8);
+        compute_avg_4x1(dd0, d0, fwd_offset, bck_offset,
+                        vget_low_s16(sub_const_vec), use_dist_wtd_comp_avg,
+                        &d01_u8);
 
         store_u8_4x1(dst8_ptr, d01_u8, 0);
         dst8_ptr += dst8_stride;
@@ -831,9 +729,8 @@
           load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);
 
           compute_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
-                          bck_offset, sub_const_vec, round_bits,
-                          use_dist_wtd_comp_avg, &d0_u8, &d1_u8, &d2_u8,
-                          &d3_u8);
+                          bck_offset, sub_const_vec, use_dist_wtd_comp_avg,
+                          &d0_u8, &d1_u8, &d2_u8, &d3_u8);
 
           store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);
           d_u8 += 4 * dst8_stride;
@@ -858,7 +755,7 @@
           dd0 = vld1q_u16(d);
 
           compute_avg_8x1(dd0, d0, fwd_offset, bck_offset, sub_const_vec,
-                          round_bits, use_dist_wtd_comp_avg, &d0_u8);
+                          use_dist_wtd_comp_avg, &d0_u8);
 
           vst1_u8(d_u8, d0_u8);
           d_u8 += dst8_stride;
@@ -894,14 +791,15 @@
 
   const int bd = 8;
   const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
+  // Subtracting (1 << (FILTER_BITS - ROUND0_BITS - 1)) allows us to use
+  // non-rounding shifts when computing the average.
   const int16_t sub_const = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
-                            (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
+                            (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1)) -
+                            (1 << (FILTER_BITS - ROUND0_BITS - 1));
 
-  const int16_t round_bits =
-      2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS;
   const int offset = bd + 2 * FILTER_BITS - ROUND0_BITS;
   const int32x4_t offset_const = vdupq_n_s32(1 << offset);
-  const int16x4_t sub_const_vec = vdup_n_s16(sub_const);
+  const int16x8_t sub_const_vec = vdupq_n_s16(sub_const);
   const uint16_t fwd_offset = conv_params->fwd_offset;
   const uint16_t bck_offset = conv_params->bck_offset;
   const int do_average = conv_params->do_average;
@@ -938,8 +836,8 @@
         load_u16_4x4(dst_ptr, dst_stride, &dd0, &dd1, &dd2, &dd3);
 
         compute_avg_4x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
-                        bck_offset, sub_const_vec, round_bits,
-                        use_dist_wtd_comp_avg, &d01_u8, &d23_u8);
+                        bck_offset, sub_const_vec, use_dist_wtd_comp_avg,
+                        &d01_u8, &d23_u8);
 
         store_u8_4x1(dst8_ptr + 0 * dst8_stride, d01_u8, 0);
         store_u8_4x1(dst8_ptr + 1 * dst8_stride, d01_u8, 1);
@@ -969,8 +867,9 @@
       if (do_average) {
         dd0 = vld1_u16(dst_ptr);
 
-        compute_avg_4x1(dd0, d0, fwd_offset, bck_offset, sub_const_vec,
-                        round_bits, use_dist_wtd_comp_avg, &d01_u8);
+        compute_avg_4x1(dd0, d0, fwd_offset, bck_offset,
+                        vget_low_s16(sub_const_vec), use_dist_wtd_comp_avg,
+                        &d01_u8);
 
         store_u8_4x1(dst8_ptr, d01_u8, 0);
         dst8_ptr += dst8_stride;
@@ -1028,9 +927,8 @@
           load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);
 
           compute_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
-                          bck_offset, sub_const_vec, round_bits,
-                          use_dist_wtd_comp_avg, &d0_u8, &d1_u8, &d2_u8,
-                          &d3_u8);
+                          bck_offset, sub_const_vec, use_dist_wtd_comp_avg,
+                          &d0_u8, &d1_u8, &d2_u8, &d3_u8);
 
           store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);
           d_u8 += 4 * dst8_stride;
@@ -1058,7 +956,7 @@
           dd0 = vld1q_u16(d);
 
           compute_avg_8x1(dd0, d0, fwd_offset, bck_offset, sub_const_vec,
-                          round_bits, use_dist_wtd_comp_avg, &d0_u8);
+                          use_dist_wtd_comp_avg, &d0_u8);
 
           vst1_u8(d_u8, d0_u8);
           d_u8 += dst8_stride;
@@ -1139,7 +1037,10 @@
   const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
   const int round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                            (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
-  const int16x4_t sub_const_vec = vdup_n_s16((int16_t)round_offset);
+  // Subtracting (1 << (FILTER_BITS - ROUND0_BITS - 1)) allows us to use
+  // non-rounding shifts when computing the average.
+  const int16x8_t sub_const_vec = vdupq_n_s16(
+      (int16_t)round_offset - (1 << (FILTER_BITS - ROUND0_BITS - 1)));
   const uint8x8_t shift_by_bits = vdup_n_u8(1 << (FILTER_BITS - ROUND0_BITS));
 
   if (w >= 8) {
@@ -1165,10 +1066,10 @@
         if (conv_params->do_average) {
           load_u16_8x4(d_conv_buf, dst_stride, &t0, &t1, &t2, &t3);
 
-          compute_avg_8x4(
-              t0, t1, t2, t3, d0, d1, d2, d3, conv_params->fwd_offset,
-              conv_params->bck_offset, sub_const_vec, FILTER_BITS - ROUND0_BITS,
-              conv_params->use_dist_wtd_comp_avg, &dd0, &dd1, &dd2, &dd3);
+          compute_avg_8x4(t0, t1, t2, t3, d0, d1, d2, d3,
+                          conv_params->fwd_offset, conv_params->bck_offset,
+                          sub_const_vec, conv_params->use_dist_wtd_comp_avg,
+                          &dd0, &dd1, &dd2, &dd3);
 
           store_u8_8x4(d_u8, dst8_stride, dd0, dd1, dd2, dd3);
         } else {
@@ -1202,7 +1103,6 @@
 
         compute_avg_4x4(t0, t1, t2, t3, d0, d1, d2, d3, conv_params->fwd_offset,
                         conv_params->bck_offset, sub_const_vec,
-                        FILTER_BITS - ROUND0_BITS,
                         conv_params->use_dist_wtd_comp_avg, &d01, &d23);
 
         store_u8_4x1(dst8 + 0 * dst8_stride, d01, 0);
@@ -1235,11 +1135,13 @@
   const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
   const int round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                            (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
-  const int round_bits = FILTER_BITS - ROUND0_BITS;
   const uint16_t fwd_offset = conv_params->fwd_offset;
   const uint16_t bck_offset = conv_params->bck_offset;
   const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
-  const int16x4_t round_offset64 = vdup_n_s16(round_offset);
+  // Subtracting (1 << (FILTER_BITS - ROUND0_BITS - 1)) allows us to use
+  // non-rounding shifts when computing the average.
+  const int16x8_t sub_const_vec =
+      vdupq_n_s16(round_offset - (1 << (FILTER_BITS - ROUND0_BITS - 1)));
   const int16x8_t round_offset128 = vdupq_n_s16(round_offset);
 
   // Horizontal filter.
@@ -1291,13 +1193,12 @@
       if (conv_params->do_average) {
         load_u16_4x4(dst_ptr, dst_stride, &dd0, &dd1, &dd2, &dd3);
 
-        compute_avg_4x4(dd0, dd1, dd2, dd3,
-                        vreinterpret_u16_s16(vget_low_s16(d01)),
-                        vreinterpret_u16_s16(vget_high_s16(d01)),
-                        vreinterpret_u16_s16(vget_low_s16(d23)),
-                        vreinterpret_u16_s16(vget_high_s16(d23)), fwd_offset,
-                        bck_offset, round_offset64, round_bits,
-                        use_dist_wtd_comp_avg, &d01_u8, &d23_u8);
+        compute_avg_4x4(
+            dd0, dd1, dd2, dd3, vreinterpret_u16_s16(vget_low_s16(d01)),
+            vreinterpret_u16_s16(vget_high_s16(d01)),
+            vreinterpret_u16_s16(vget_low_s16(d23)),
+            vreinterpret_u16_s16(vget_high_s16(d23)), fwd_offset, bck_offset,
+            sub_const_vec, use_dist_wtd_comp_avg, &d01_u8, &d23_u8);
 
         store_u8_4x1(dst_u8_ptr + 0 * dst8_stride, d01_u8, 0);
         store_u8_4x1(dst_u8_ptr + 1 * dst8_stride, d01_u8, 1);
@@ -1349,8 +1250,8 @@
           compute_avg_8x4(dd0, dd1, dd2, dd3, vreinterpretq_u16_s16(d0),
                           vreinterpretq_u16_s16(d1), vreinterpretq_u16_s16(d2),
                           vreinterpretq_u16_s16(d3), fwd_offset, bck_offset,
-                          round_offset64, round_bits, use_dist_wtd_comp_avg,
-                          &d0_u8, &d1_u8, &d2_u8, &d3_u8);
+                          sub_const_vec, use_dist_wtd_comp_avg, &d0_u8, &d1_u8,
+                          &d2_u8, &d3_u8);
 
           store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);
         } else {
@@ -1388,11 +1289,13 @@
   const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
   const int round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                            (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
-  const int round_bits = FILTER_BITS - ROUND0_BITS;
   const uint16_t fwd_offset = conv_params->fwd_offset;
   const uint16_t bck_offset = conv_params->bck_offset;
   const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
-  const int16x4_t round_offset64 = vdup_n_s16(round_offset);
+  // Subtracting (1 << (FILTER_BITS - ROUND0_BITS - 1)) allows us to use
+  // non-rounding shifts when computing the average.
+  const int16x8_t sub_const_vec =
+      vdupq_n_s16(round_offset - (1 << (FILTER_BITS - ROUND0_BITS - 1)));
   const int16x8_t round_offset128 = vdupq_n_s16(round_offset);
 
   // Horizontal filter.
@@ -1448,13 +1351,12 @@
       if (conv_params->do_average) {
         load_u16_4x4(dst_ptr, dst_stride, &dd0, &dd1, &dd2, &dd3);
 
-        compute_avg_4x4(dd0, dd1, dd2, dd3,
-                        vreinterpret_u16_s16(vget_low_s16(d01)),
-                        vreinterpret_u16_s16(vget_high_s16(d01)),
-                        vreinterpret_u16_s16(vget_low_s16(d23)),
-                        vreinterpret_u16_s16(vget_high_s16(d23)), fwd_offset,
-                        bck_offset, round_offset64, round_bits,
-                        use_dist_wtd_comp_avg, &d01_u8, &d23_u8);
+        compute_avg_4x4(
+            dd0, dd1, dd2, dd3, vreinterpret_u16_s16(vget_low_s16(d01)),
+            vreinterpret_u16_s16(vget_high_s16(d01)),
+            vreinterpret_u16_s16(vget_low_s16(d23)),
+            vreinterpret_u16_s16(vget_high_s16(d23)), fwd_offset, bck_offset,
+            sub_const_vec, use_dist_wtd_comp_avg, &d01_u8, &d23_u8);
 
         store_u8_4x1(dst_u8_ptr + 0 * dst8_stride, d01_u8, 0);
         store_u8_4x1(dst_u8_ptr + 1 * dst8_stride, d01_u8, 1);
@@ -1516,8 +1418,8 @@
           compute_avg_8x4(dd0, dd1, dd2, dd3, vreinterpretq_u16_s16(d0),
                           vreinterpretq_u16_s16(d1), vreinterpretq_u16_s16(d2),
                           vreinterpretq_u16_s16(d3), fwd_offset, bck_offset,
-                          round_offset64, round_bits, use_dist_wtd_comp_avg,
-                          &d0_u8, &d1_u8, &d2_u8, &d3_u8);
+                          sub_const_vec, use_dist_wtd_comp_avg, &d0_u8, &d1_u8,
+                          &d2_u8, &d3_u8);
 
           store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);
         } else {
@@ -1599,10 +1501,13 @@
   const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
   const int round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                            (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
-  const int round_bits = FILTER_BITS - ROUND0_BITS;
   const uint16_t fwd_offset = conv_params->fwd_offset;
   const uint16_t bck_offset = conv_params->bck_offset;
   const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
+  // Subtracting (1 << (FILTER_BITS - ROUND0_BITS - 1)) allows us to use
+  // non-rounding shifts when computing the average.
+  const int16x8_t sub_const_vec =
+      vdupq_n_s16(round_offset - (1 << (FILTER_BITS - ROUND0_BITS - 1)));
 
   // horizontal filter
   const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
@@ -1731,8 +1636,7 @@
           compute_avg_4x4(res4, res5, res6, res7, vreinterpret_u16_s16(d0),
                           vreinterpret_u16_s16(d1), vreinterpret_u16_s16(d2),
                           vreinterpret_u16_s16(d3), fwd_offset, bck_offset,
-                          vget_low_s16(round_offset_vec), round_bits,
-                          use_dist_wtd_comp_avg, &t0, &t1);
+                          sub_const_vec, use_dist_wtd_comp_avg, &t0, &t1);
 
           store_u8_4x1(d_u8 + 0 * dst8_stride, t0, 0);
           store_u8_4x1(d_u8 + 1 * dst8_stride, t0, 1);
@@ -1795,7 +1699,7 @@
           res4 = vld1_u16(d);
 
           compute_avg_4x1(res4, vreinterpret_u16_s16(d0), fwd_offset,
-                          bck_offset, round_offset_vec, round_bits,
+                          bck_offset, vget_low_s16(sub_const_vec),
                           use_dist_wtd_comp_avg, &t0);
 
           store_u8_4x1(d_u8, t0, 0);
@@ -1821,7 +1725,6 @@
     int16x8_t res0;
     uint16x8_t res8;
     const int16x8_t round_offset128 = vdupq_n_s16(round_offset);
-    const int16x4_t round_offset64 = vdup_n_s16(round_offset);
     // This shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
     // shifts - which are generally faster than rounding shifts on modern CPUs.
     // The outermost -1 is needed because we halved the filter values.
@@ -1925,12 +1828,11 @@
           load_u16_8x4(d_tmp, dst_stride, &res8, &res9, &res10, &res11);
           d_tmp += 4 * dst_stride;
 
-          compute_avg_8x4(res8, res9, res10, res11, vreinterpretq_u16_s16(res0),
-                          vreinterpretq_u16_s16(res1),
-                          vreinterpretq_u16_s16(res2),
-                          vreinterpretq_u16_s16(res3), fwd_offset, bck_offset,
-                          round_offset64, round_bits, use_dist_wtd_comp_avg,
-                          &t0, &t1, &t2, &t3);
+          compute_avg_8x4(
+              res8, res9, res10, res11, vreinterpretq_u16_s16(res0),
+              vreinterpretq_u16_s16(res1), vreinterpretq_u16_s16(res2),
+              vreinterpretq_u16_s16(res3), fwd_offset, bck_offset,
+              sub_const_vec, use_dist_wtd_comp_avg, &t0, &t1, &t2, &t3);
 
           store_u8_8x4(d_u8, dst8_stride, t0, t1, t2, t3);
           d_u8 += 4 * dst8_stride;
@@ -1938,12 +1840,11 @@
           load_u16_8x4(d_tmp, dst_stride, &res8, &res9, &res10, &res11);
           d_tmp += 4 * dst_stride;
 
-          compute_avg_8x4(res8, res9, res10, res11, vreinterpretq_u16_s16(res4),
-                          vreinterpretq_u16_s16(res5),
-                          vreinterpretq_u16_s16(res6),
-                          vreinterpretq_u16_s16(res7), fwd_offset, bck_offset,
-                          round_offset64, round_bits, use_dist_wtd_comp_avg,
-                          &t0, &t1, &t2, &t3);
+          compute_avg_8x4(
+              res8, res9, res10, res11, vreinterpretq_u16_s16(res4),
+              vreinterpretq_u16_s16(res5), vreinterpretq_u16_s16(res6),
+              vreinterpretq_u16_s16(res7), fwd_offset, bck_offset,
+              sub_const_vec, use_dist_wtd_comp_avg, &t0, &t1, &t2, &t3);
 
           store_u8_8x4(d_u8, dst8_stride, t0, t1, t2, t3);
           d_u8 += 4 * dst8_stride;
@@ -2014,8 +1915,8 @@
           d_tmp += dst_stride;
 
           compute_avg_8x1(res8, vreinterpretq_u16_s16(res0), fwd_offset,
-                          bck_offset, round_offset64, round_bits,
-                          use_dist_wtd_comp_avg, &t0);
+                          bck_offset, sub_const_vec, use_dist_wtd_comp_avg,
+                          &t0);
 
           vst1_u8(d_u8, t0);
           d_u8 += dst8_stride;
@@ -2133,7 +2034,10 @@
   const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
   const int round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                            (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
-  const int round_bits = FILTER_BITS - ROUND0_BITS;
+  // Subtracting (1 << (FILTER_BITS - ROUND0_BITS - 1)) allows us to use
+  // non-rounding shifts when computing the average.
+  const int16x8_t sub_const_vec =
+      vdupq_n_s16(round_offset - (1 << (FILTER_BITS - ROUND0_BITS - 1)));
   const uint16_t fwd_offset = conv_params->fwd_offset;
   const uint16_t bck_offset = conv_params->bck_offset;
   const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
@@ -2206,8 +2110,7 @@
           compute_avg_4x4(dd0, dd1, dd2, dd3, vreinterpret_u16_s16(d0),
                           vreinterpret_u16_s16(d1), vreinterpret_u16_s16(d2),
                           vreinterpret_u16_s16(d3), fwd_offset, bck_offset,
-                          round_offset64, round_bits, use_dist_wtd_comp_avg,
-                          &d01, &d23);
+                          sub_const_vec, use_dist_wtd_comp_avg, &d01, &d23);
 
           store_u8_4x1(d_u8 + 0 * dst8_stride, d01, 0);
           store_u8_4x1(d_u8 + 1 * dst8_stride, d01, 1);
@@ -2240,7 +2143,7 @@
           dd0 = vld1_u16(d);
 
           compute_avg_4x1(dd0, vreinterpret_u16_s16(d0), fwd_offset, bck_offset,
-                          round_offset64, round_bits, use_dist_wtd_comp_avg,
+                          vget_low_s16(sub_const_vec), use_dist_wtd_comp_avg,
                           &d01);
 
           store_u8_4x1(d_u8, d01, 0);
@@ -2274,7 +2177,6 @@
     // The outermost -1 is needed because we halved the filter values.
     const int16x8_t vert_const = vdupq_n_s16(1 << ((ROUND0_BITS - 1) - 1));
     const int16x8_t round_offset128 = vdupq_n_s16(round_offset);
-    const int16x4_t round_offset64 = vdup_n_s16(round_offset);
 #if defined(__aarch64__)
     int16x8_t s6, s7, s8, s9, s10, s11, s12, d1, d2, d3, d4, d5, d6, d7;
     uint16x8_t d9, d10, d11;
@@ -2336,8 +2238,8 @@
           compute_avg_8x4(d8, d9, d10, d11, vreinterpretq_u16_s16(d0),
                           vreinterpretq_u16_s16(d1), vreinterpretq_u16_s16(d2),
                           vreinterpretq_u16_s16(d3), fwd_offset, bck_offset,
-                          round_offset64, round_bits, use_dist_wtd_comp_avg,
-                          &t0, &t1, &t2, &t3);
+                          sub_const_vec, use_dist_wtd_comp_avg, &t0, &t1, &t2,
+                          &t3);
 
           store_u8_8x4(d_u8, dst8_stride, t0, t1, t2, t3);
           d_u8 += 4 * dst8_stride;
@@ -2348,8 +2250,8 @@
           compute_avg_8x4(d8, d9, d10, d11, vreinterpretq_u16_s16(d4),
                           vreinterpretq_u16_s16(d5), vreinterpretq_u16_s16(d6),
                           vreinterpretq_u16_s16(d7), fwd_offset, bck_offset,
-                          round_offset64, round_bits, use_dist_wtd_comp_avg,
-                          &t0, &t1, &t2, &t3);
+                          sub_const_vec, use_dist_wtd_comp_avg, &t0, &t1, &t2,
+                          &t3);
 
           store_u8_8x4(d_u8, dst8_stride, t0, t1, t2, t3);
           d_u8 += 4 * dst8_stride;
@@ -2386,8 +2288,7 @@
           d += dst_stride;
 
           compute_avg_8x1(d8, vreinterpretq_u16_s16(d0), fwd_offset, bck_offset,
-                          round_offset64, round_bits, use_dist_wtd_comp_avg,
-                          &t0);
+                          sub_const_vec, use_dist_wtd_comp_avg, &t0);
 
           vst1_u8(d_u8, t0);
           d_u8 += dst8_stride;
@@ -2420,7 +2321,10 @@
   const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
   const int round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                            (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
-  const int round_bits = FILTER_BITS - ROUND0_BITS;
+  // Subtracting (1 << (FILTER_BITS - ROUND0_BITS - 1)) allows us to use
+  // non-rounding shifts when computing the average.
+  const int16x8_t sub_const_vec =
+      vdupq_n_s16(round_offset - (1 << (FILTER_BITS - ROUND0_BITS - 1)));
   const uint16_t fwd_offset = conv_params->fwd_offset;
   const uint16_t bck_offset = conv_params->bck_offset;
   const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
@@ -2441,9 +2345,9 @@
     uint16x4_t dd1, dd2, dd3;
     int16x8_t t01, t23;
     uint8x8_t d23;
-#else   // !defined(__aarch64__)
+#else  // !defined(__aarch64__)
     const int16x4_t round_offset64 = vdup_n_s16(round_offset);
-#endif  // defined(__aarch64__)
+#endif
     // This shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
     // shifts - which are generally faster than rounding shifts on modern CPUs.
     // The outermost -1 is needed because we halved the filter values.
@@ -2534,8 +2438,7 @@
           compute_avg_4x4(dd0, dd1, dd2, dd3, vreinterpret_u16_s16(d0),
                           vreinterpret_u16_s16(d1), vreinterpret_u16_s16(d2),
                           vreinterpret_u16_s16(d3), fwd_offset, bck_offset,
-                          vget_low_s16(round_offset64), round_bits,
-                          use_dist_wtd_comp_avg, &d01, &d23);
+                          sub_const_vec, use_dist_wtd_comp_avg, &d01, &d23);
 
           store_u8_4x1(d_u8 + 0 * dst8_stride, d01, 0);
           store_u8_4x1(d_u8 + 1 * dst8_stride, d01, 1);
@@ -2575,7 +2478,7 @@
           dd0 = vld1_u16(d);
 
           compute_avg_4x1(dd0, vreinterpret_u16_s16(d0), fwd_offset, bck_offset,
-                          round_offset64, round_bits, use_dist_wtd_comp_avg,
+                          vget_low_s16(sub_const_vec), use_dist_wtd_comp_avg,
                           &d01);
 
           store_u8_4x1(d_u8, d01, 0);
@@ -2603,7 +2506,6 @@
     } while (width > 0);
   } else {
     const int16x8_t round_offset128 = vdupq_n_s16(round_offset);
-    const int16x4_t round_offset64 = vdup_n_s16(round_offset);
     // This shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
     // shifts - which are generally faster than rounding shifts on modern CPUs.
     // The outermost -1 is needed because we halved the filter values.
@@ -2701,8 +2603,8 @@
           compute_avg_8x4(dd0, dd1, dd2, dd3, vreinterpretq_u16_s16(d0),
                           vreinterpretq_u16_s16(d1), vreinterpretq_u16_s16(d2),
                           vreinterpretq_u16_s16(d3), fwd_offset, bck_offset,
-                          round_offset64, round_bits, use_dist_wtd_comp_avg,
-                          &t0, &t1, &t2, &t3);
+                          sub_const_vec, use_dist_wtd_comp_avg, &t0, &t1, &t2,
+                          &t3);
 
           store_u8_8x4(d_u8, dst8_stride, t0, t1, t2, t3);
           d_u8 += 4 * dst8_stride;
@@ -2713,8 +2615,8 @@
           compute_avg_8x4(dd0, dd1, dd2, dd3, vreinterpretq_u16_s16(d4),
                           vreinterpretq_u16_s16(d5), vreinterpretq_u16_s16(d6),
                           vreinterpretq_u16_s16(d7), fwd_offset, bck_offset,
-                          round_offset64, round_bits, use_dist_wtd_comp_avg,
-                          &t0, &t1, &t2, &t3);
+                          sub_const_vec, use_dist_wtd_comp_avg, &t0, &t1, &t2,
+                          &t3);
 
           store_u8_8x4(d_u8, dst8_stride, t0, t1, t2, t3);
           d_u8 += 4 * dst8_stride;
@@ -2760,8 +2662,8 @@
           d += dst_stride;
 
           compute_avg_8x1(dd0, vreinterpretq_u16_s16(d0), fwd_offset,
-                          bck_offset, round_offset64, round_bits,
-                          use_dist_wtd_comp_avg, &t0);
+                          bck_offset, sub_const_vec, use_dist_wtd_comp_avg,
+                          &t0);
 
           vst1_u8(d_u8, t0);
           d_u8 += dst8_stride;