Add Neon dist_wtd_avg helpers in a new dist_wtd_avg_neon.h file

Add Neon dist_wtd_avg_<size> helper functions to dist_wtd_avg_neon.h
and use them to reduce code duplication in avg_pred_neon.c and
highbd_variance_neon.c.

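For reference, every dist_wtd_avg_<size> helper computes the same
per-lane weighted average. A scalar sketch of one lane is below (this
assumes the fwd/bck offsets sum to 1 << DIST_PRECISION_BITS, as the AV1
distance-weighted compound offsets do, so the narrowing shift cannot
overflow; ROUND_POWER_OF_TWO is the rounding right shift from
aom_dsp/aom_dsp_common.h):

    /* Scalar sketch of one lane of the dist_wtd_avg_<size> helpers:
     * multiply-accumulate the two inputs by their weights, then apply a
     * rounding right shift by DIST_PRECISION_BITS. */
    static INLINE unsigned dist_wtd_avg_scalar(unsigned a, unsigned b,
                                               unsigned wta, unsigned wtb) {
      return ROUND_POWER_OF_TWO(a * wta + b * wtb, DIST_PRECISION_BITS);
    }
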
Change-Id: If5f1104b7f7cd15f7acd409d1850f19cf88558da
diff --git a/aom_dsp/arm/avg_pred_neon.c b/aom_dsp/arm/avg_pred_neon.c
index 8fe151d..b17f7fc 100644
--- a/aom_dsp/arm/avg_pred_neon.c
+++ b/aom_dsp/arm/avg_pred_neon.c
@@ -14,8 +14,9 @@
 
 #include "config/aom_dsp_rtcd.h"
 
-#include "aom_dsp/arm/mem_neon.h"
 #include "aom_dsp/arm/blend_neon.h"
+#include "aom_dsp/arm/dist_wtd_avg_neon.h"
+#include "aom_dsp/arm/mem_neon.h"
 #include "aom_dsp/blend.h"
 
 void aom_comp_avg_pred_neon(uint8_t *comp_pred, const uint8_t *pred, int width,
@@ -94,22 +95,8 @@
         const uint8x16_t p = vld1q_u8(pred_ptr);
         const uint8x16_t r = vld1q_u8(ref_ptr);
 
-        uint16x8_t wtd_sum_lo =
-            vmull_u8(vget_low_u8(p), vget_low_u8(bck_offset));
-        uint16x8_t wtd_sum_hi =
-            vmull_u8(vget_high_u8(p), vget_high_u8(bck_offset));
-
-        wtd_sum_lo =
-            vmlal_u8(wtd_sum_lo, vget_low_u8(r), vget_low_u8(fwd_offset));
-        wtd_sum_hi =
-            vmlal_u8(wtd_sum_hi, vget_high_u8(r), vget_high_u8(fwd_offset));
-
-        const uint8x8_t wtd_avg_lo =
-            vrshrn_n_u16(wtd_sum_lo, DIST_PRECISION_BITS);
-        const uint8x8_t wtd_avg_hi =
-            vrshrn_n_u16(wtd_sum_hi, DIST_PRECISION_BITS);
-
-        const uint8x16_t wtd_avg = vcombine_u8(wtd_avg_lo, wtd_avg_hi);
+        const uint8x16_t wtd_avg =
+            dist_wtd_avg_u8x16(r, p, fwd_offset, bck_offset);
 
         vst1q_u8(comp_pred_ptr, wtd_avg);
 
@@ -130,21 +117,8 @@
       const uint8x16_t p = vld1q_u8(pred);
       const uint8x16_t r = load_u8_8x2(ref, ref_stride);
 
-      uint16x8_t wtd_sum_lo = vmull_u8(vget_low_u8(p), vget_low_u8(bck_offset));
-      uint16x8_t wtd_sum_hi =
-          vmull_u8(vget_high_u8(p), vget_high_u8(bck_offset));
-
-      wtd_sum_lo =
-          vmlal_u8(wtd_sum_lo, vget_low_u8(r), vget_low_u8(fwd_offset));
-      wtd_sum_hi =
-          vmlal_u8(wtd_sum_hi, vget_high_u8(r), vget_high_u8(fwd_offset));
-
-      const uint8x8_t wtd_avg_lo =
-          vrshrn_n_u16(wtd_sum_lo, DIST_PRECISION_BITS);
-      const uint8x8_t wtd_avg_hi =
-          vrshrn_n_u16(wtd_sum_hi, DIST_PRECISION_BITS);
-
-      const uint8x16_t wtd_avg = vcombine_u8(wtd_avg_lo, wtd_avg_hi);
+      const uint8x16_t wtd_avg =
+          dist_wtd_avg_u8x16(r, p, fwd_offset, bck_offset);
 
       vst1q_u8(comp_pred, wtd_avg);
 
@@ -160,10 +134,8 @@
       const uint8x8_t p = vld1_u8(pred);
       const uint8x8_t r = load_unaligned_u8_4x2(ref, ref_stride);
 
-      uint16x8_t wtd_sum = vmull_u8(p, vget_low_u8(bck_offset));
-      wtd_sum = vmlal_u8(wtd_sum, r, vget_low_u8(fwd_offset));
-
-      const uint8x8_t wtd_avg = vrshrn_n_u16(wtd_sum, DIST_PRECISION_BITS);
+      const uint8x8_t wtd_avg = dist_wtd_avg_u8x8(r, p, vget_low_u8(fwd_offset),
+                                                  vget_low_u8(bck_offset));
 
       vst1_u8(comp_pred, wtd_avg);
 
diff --git a/aom_dsp/arm/dist_wtd_avg_neon.h b/aom_dsp/arm/dist_wtd_avg_neon.h
new file mode 100644
index 0000000..19c9b04
--- /dev/null
+++ b/aom_dsp/arm/dist_wtd_avg_neon.h
@@ -0,0 +1,65 @@
+/*
+ *  Copyright (c) 2023, Alliance for Open Media. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#ifndef AOM_AOM_DSP_ARM_DIST_WTD_AVG_NEON_H_
+#define AOM_AOM_DSP_ARM_DIST_WTD_AVG_NEON_H_
+
+#include <arm_neon.h>
+
+#include "aom_dsp/aom_dsp_common.h"
+#include "av1/common/enums.h"
+
+static INLINE uint8x8_t dist_wtd_avg_u8x8(uint8x8_t a, uint8x8_t b,
+                                          uint8x8_t wta, uint8x8_t wtb) {
+  uint16x8_t wtd_sum = vmull_u8(a, wta);
+
+  wtd_sum = vmlal_u8(wtd_sum, b, wtb);
+
+  return vrshrn_n_u16(wtd_sum, DIST_PRECISION_BITS);
+}
+
+static INLINE uint16x4_t dist_wtd_avg_u16x4(uint16x4_t a, uint16x4_t b,
+                                            uint16x4_t wta, uint16x4_t wtb) {
+  uint32x4_t wtd_sum = vmull_u16(a, wta);
+
+  wtd_sum = vmlal_u16(wtd_sum, b, wtb);
+
+  return vrshrn_n_u32(wtd_sum, DIST_PRECISION_BITS);
+}
+
+static INLINE uint8x16_t dist_wtd_avg_u8x16(uint8x16_t a, uint8x16_t b,
+                                            uint8x16_t wta, uint8x16_t wtb) {
+  uint16x8_t wtd_sum_lo = vmull_u8(vget_low_u8(a), vget_low_u8(wta));
+  uint16x8_t wtd_sum_hi = vmull_u8(vget_high_u8(a), vget_high_u8(wta));
+
+  wtd_sum_lo = vmlal_u8(wtd_sum_lo, vget_low_u8(b), vget_low_u8(wtb));
+  wtd_sum_hi = vmlal_u8(wtd_sum_hi, vget_high_u8(b), vget_high_u8(wtb));
+
+  uint8x8_t wtd_avg_lo = vrshrn_n_u16(wtd_sum_lo, DIST_PRECISION_BITS);
+  uint8x8_t wtd_avg_hi = vrshrn_n_u16(wtd_sum_hi, DIST_PRECISION_BITS);
+
+  return vcombine_u8(wtd_avg_lo, wtd_avg_hi);
+}
+
+static INLINE uint16x8_t dist_wtd_avg_u16x8(uint16x8_t a, uint16x8_t b,
+                                            uint16x8_t wta, uint16x8_t wtb) {
+  uint32x4_t wtd_sum_lo = vmull_u16(vget_low_u16(a), vget_low_u16(wta));
+  uint32x4_t wtd_sum_hi = vmull_u16(vget_high_u16(a), vget_high_u16(wta));
+
+  wtd_sum_lo = vmlal_u16(wtd_sum_lo, vget_low_u16(b), vget_low_u16(wtb));
+  wtd_sum_hi = vmlal_u16(wtd_sum_hi, vget_high_u16(b), vget_high_u16(wtb));
+
+  uint16x4_t wtd_avg_lo = vrshrn_n_u32(wtd_sum_lo, DIST_PRECISION_BITS);
+  uint16x4_t wtd_avg_hi = vrshrn_n_u32(wtd_sum_hi, DIST_PRECISION_BITS);
+
+  return vcombine_u16(wtd_avg_lo, wtd_avg_hi);
+}
+
+#endif  // AOM_AOM_DSP_ARM_DIST_WTD_AVG_NEON_H_
diff --git a/aom_dsp/arm/highbd_variance_neon.c b/aom_dsp/arm/highbd_variance_neon.c
index 4dbca61..0aa0bb8 100644
--- a/aom_dsp/arm/highbd_variance_neon.c
+++ b/aom_dsp/arm/highbd_variance_neon.c
@@ -15,10 +15,11 @@
 #include "config/aom_config.h"
 #include "config/aom_dsp_rtcd.h"
 
-#include "aom_dsp/variance.h"
 #include "aom_dsp/aom_filter.h"
+#include "aom_dsp/arm/dist_wtd_avg_neon.h"
 #include "aom_dsp/arm/mem_neon.h"
 #include "aom_dsp/arm/sum_neon.h"
+#include "aom_dsp/variance.h"
 
 // Process a block of width 4 two rows at a time.
 static INLINE void highbd_variance_4xh_neon(const uint16_t *src_ptr,
@@ -474,8 +475,8 @@
   const uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
   const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
   uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8);
-  const uint16x4_t fwd_offset_u16 = vdup_n_u16(jcp_param->fwd_offset);
-  const uint16x4_t bck_offset_u16 = vdup_n_u16(jcp_param->bck_offset);
+  const uint16x8_t fwd_offset_u16 = vdupq_n_u16(jcp_param->fwd_offset);
+  const uint16x8_t bck_offset_u16 = vdupq_n_u16(jcp_param->bck_offset);
 
   int i = height;
   if (width > 8) {
@@ -485,12 +486,9 @@
         const uint16x8_t p = vld1q_u16(pred + j);
         const uint16x8_t r = vld1q_u16(ref + j);
 
-        uint32x4_t cp0 = vmull_u16(vget_low_u16(p), bck_offset_u16);
-        uint32x4_t cp1 = vmull_u16(vget_high_u16(p), bck_offset_u16);
-        cp0 = vmlal_u16(cp0, vget_low_u16(r), fwd_offset_u16);
-        cp1 = vmlal_u16(cp1, vget_high_u16(r), fwd_offset_u16);
-        uint16x8_t avg = vcombine_u16(vrshrn_n_u32(cp0, DIST_PRECISION_BITS),
-                                      vrshrn_n_u32(cp1, DIST_PRECISION_BITS));
+        const uint16x8_t avg =
+            dist_wtd_avg_u16x8(r, p, fwd_offset_u16, bck_offset_u16);
+
         vst1q_u16(comp_pred + j, avg);
 
         j += 8;
@@ -505,12 +503,9 @@
       const uint16x8_t p = vld1q_u16(pred);
       const uint16x8_t r = vld1q_u16(ref);
 
-      uint32x4_t cp0 = vmull_u16(vget_low_u16(p), bck_offset_u16);
-      uint32x4_t cp1 = vmull_u16(vget_high_u16(p), bck_offset_u16);
-      cp0 = vmlal_u16(cp0, vget_low_u16(r), fwd_offset_u16);
-      cp1 = vmlal_u16(cp1, vget_high_u16(r), fwd_offset_u16);
-      uint16x8_t avg = vcombine_u16(vrshrn_n_u32(cp0, DIST_PRECISION_BITS),
-                                    vrshrn_n_u32(cp1, DIST_PRECISION_BITS));
+      const uint16x8_t avg =
+          dist_wtd_avg_u16x8(r, p, fwd_offset_u16, bck_offset_u16);
+
       vst1q_u16(comp_pred, avg);
 
       comp_pred += width;
@@ -523,9 +518,9 @@
       const uint16x4_t p = vld1_u16(pred);
       const uint16x4_t r = vld1_u16(ref);
 
-      uint32x4_t cp = vmull_u16(p, bck_offset_u16);
-      cp = vmlal_u16(cp, r, fwd_offset_u16);
-      uint16x4_t avg = vrshrn_n_u32(cp, DIST_PRECISION_BITS);
+      const uint16x4_t avg = dist_wtd_avg_u16x4(
+          r, p, vget_low_u16(fwd_offset_u16), vget_low_u16(bck_offset_u16));
+
       vst1_u16(comp_pred, avg);
 
       comp_pred += width;