Add Neon dist_wtd_avg helpers to a new dist_wtd_avg_neon.h file
Add Neon dist_wtd_avg_<size> helper functions to dist_wtd_avg_neon.h
and use them to reduce code duplication in avg_pred_neon.c and
highbd_variance_neon.c.
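For reference, each helper computes the per-lane rounded weighted
average shown in the scalar sketch below. This is illustrative only
and not part of the patch; dist_wtd_avg_scalar is a hypothetical name.

    #include <stdint.h>

    // DIST_PRECISION_BITS (= 4) is defined in av1/common/enums.h.
    #define DIST_PRECISION_BITS 4

    // Scalar equivalent of one lane of the Neon helpers: a widening
    // multiply-accumulate followed by a rounding narrowing shift,
    // matching what vmull_u8/vmlal_u8 + vrshrn_n_u16 compute per lane.
    static inline uint8_t dist_wtd_avg_scalar(uint8_t a, uint8_t b,
                                              uint8_t wta, uint8_t wtb) {
      const uint16_t wtd_sum = (uint16_t)(a * wta + b * wtb);
      return (uint8_t)((wtd_sum + (1 << (DIST_PRECISION_BITS - 1))) >>
                       DIST_PRECISION_BITS);
    }

The AV1 fwd/bck offset pairs sum to 1 << DIST_PRECISION_BITS, so the
rounded result always fits back in the source element type and the
non-saturating narrowing shift is safe.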
Change-Id: If5f1104b7f7cd15f7acd409d1850f19cf88558da
diff --git a/aom_dsp/arm/avg_pred_neon.c b/aom_dsp/arm/avg_pred_neon.c
index 8fe151d..b17f7fc 100644
--- a/aom_dsp/arm/avg_pred_neon.c
+++ b/aom_dsp/arm/avg_pred_neon.c
@@ -14,8 +14,9 @@
#include "config/aom_dsp_rtcd.h"
-#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/blend_neon.h"
+#include "aom_dsp/arm/dist_wtd_avg_neon.h"
+#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/blend.h"
void aom_comp_avg_pred_neon(uint8_t *comp_pred, const uint8_t *pred, int width,
@@ -94,22 +95,8 @@
const uint8x16_t p = vld1q_u8(pred_ptr);
const uint8x16_t r = vld1q_u8(ref_ptr);
- uint16x8_t wtd_sum_lo =
- vmull_u8(vget_low_u8(p), vget_low_u8(bck_offset));
- uint16x8_t wtd_sum_hi =
- vmull_u8(vget_high_u8(p), vget_high_u8(bck_offset));
-
- wtd_sum_lo =
- vmlal_u8(wtd_sum_lo, vget_low_u8(r), vget_low_u8(fwd_offset));
- wtd_sum_hi =
- vmlal_u8(wtd_sum_hi, vget_high_u8(r), vget_high_u8(fwd_offset));
-
- const uint8x8_t wtd_avg_lo =
- vrshrn_n_u16(wtd_sum_lo, DIST_PRECISION_BITS);
- const uint8x8_t wtd_avg_hi =
- vrshrn_n_u16(wtd_sum_hi, DIST_PRECISION_BITS);
-
- const uint8x16_t wtd_avg = vcombine_u8(wtd_avg_lo, wtd_avg_hi);
+ const uint8x16_t wtd_avg =
+ dist_wtd_avg_u8x16(r, p, fwd_offset, bck_offset);
vst1q_u8(comp_pred_ptr, wtd_avg);
@@ -130,21 +117,8 @@
const uint8x16_t p = vld1q_u8(pred);
const uint8x16_t r = load_u8_8x2(ref, ref_stride);
- uint16x8_t wtd_sum_lo = vmull_u8(vget_low_u8(p), vget_low_u8(bck_offset));
- uint16x8_t wtd_sum_hi =
- vmull_u8(vget_high_u8(p), vget_high_u8(bck_offset));
-
- wtd_sum_lo =
- vmlal_u8(wtd_sum_lo, vget_low_u8(r), vget_low_u8(fwd_offset));
- wtd_sum_hi =
- vmlal_u8(wtd_sum_hi, vget_high_u8(r), vget_high_u8(fwd_offset));
-
- const uint8x8_t wtd_avg_lo =
- vrshrn_n_u16(wtd_sum_lo, DIST_PRECISION_BITS);
- const uint8x8_t wtd_avg_hi =
- vrshrn_n_u16(wtd_sum_hi, DIST_PRECISION_BITS);
-
- const uint8x16_t wtd_avg = vcombine_u8(wtd_avg_lo, wtd_avg_hi);
+ const uint8x16_t wtd_avg =
+ dist_wtd_avg_u8x16(r, p, fwd_offset, bck_offset);
vst1q_u8(comp_pred, wtd_avg);
@@ -160,10 +134,8 @@
const uint8x8_t p = vld1_u8(pred);
const uint8x8_t r = load_unaligned_u8_4x2(ref, ref_stride);
- uint16x8_t wtd_sum = vmull_u8(p, vget_low_u8(bck_offset));
- wtd_sum = vmlal_u8(wtd_sum, r, vget_low_u8(fwd_offset));
-
- const uint8x8_t wtd_avg = vrshrn_n_u16(wtd_sum, DIST_PRECISION_BITS);
+ const uint8x8_t wtd_avg = dist_wtd_avg_u8x8(r, p, vget_low_u8(fwd_offset),
+ vget_low_u8(bck_offset));
vst1_u8(comp_pred, wtd_avg);
diff --git a/aom_dsp/arm/dist_wtd_avg_neon.h b/aom_dsp/arm/dist_wtd_avg_neon.h
new file mode 100644
index 0000000..19c9b04
--- /dev/null
+++ b/aom_dsp/arm/dist_wtd_avg_neon.h
@@ -0,0 +1,65 @@
+/*
+ * Copyright (c) 2023, Alliance for Open Media. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+
+#ifndef AOM_AOM_DSP_ARM_DIST_WTD_AVG_NEON_H_
+#define AOM_AOM_DSP_ARM_DIST_WTD_AVG_NEON_H_
+
+#include <arm_neon.h>
+
+#include "aom_dsp/aom_dsp_common.h"
+#include "av1/common/enums.h"
+
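+// Each helper below computes the rounded, distance-weighted average of
+// two input vectors, per lane:
+//   avg = (a * wta + b * wtb + (1 << (DIST_PRECISION_BITS - 1)))
+//             >> DIST_PRECISION_BITS
+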
+static INLINE uint8x8_t dist_wtd_avg_u8x8(uint8x8_t a, uint8x8_t b,
+ uint8x8_t wta, uint8x8_t wtb) {
+ uint16x8_t wtd_sum = vmull_u8(a, wta);
+
+ wtd_sum = vmlal_u8(wtd_sum, b, wtb);
+
+ return vrshrn_n_u16(wtd_sum, DIST_PRECISION_BITS);
+}
+
+static INLINE uint16x4_t dist_wtd_avg_u16x4(uint16x4_t a, uint16x4_t b,
+ uint16x4_t wta, uint16x4_t wtb) {
+ uint32x4_t wtd_sum = vmull_u16(a, wta);
+
+ wtd_sum = vmlal_u16(wtd_sum, b, wtb);
+
+ return vrshrn_n_u32(wtd_sum, DIST_PRECISION_BITS);
+}
+
+static INLINE uint8x16_t dist_wtd_avg_u8x16(uint8x16_t a, uint8x16_t b,
+ uint8x16_t wta, uint8x16_t wtb) {
+ uint16x8_t wtd_sum_lo = vmull_u8(vget_low_u8(a), vget_low_u8(wta));
+ uint16x8_t wtd_sum_hi = vmull_u8(vget_high_u8(a), vget_high_u8(wta));
+
+ wtd_sum_lo = vmlal_u8(wtd_sum_lo, vget_low_u8(b), vget_low_u8(wtb));
+ wtd_sum_hi = vmlal_u8(wtd_sum_hi, vget_high_u8(b), vget_high_u8(wtb));
+
+ uint8x8_t wtd_avg_lo = vrshrn_n_u16(wtd_sum_lo, DIST_PRECISION_BITS);
+ uint8x8_t wtd_avg_hi = vrshrn_n_u16(wtd_sum_hi, DIST_PRECISION_BITS);
+
+ return vcombine_u8(wtd_avg_lo, wtd_avg_hi);
+}
+
+static INLINE uint16x8_t dist_wtd_avg_u16x8(uint16x8_t a, uint16x8_t b,
+ uint16x8_t wta, uint16x8_t wtb) {
+ uint32x4_t wtd_sum_lo = vmull_u16(vget_low_u16(a), vget_low_u16(wta));
+ uint32x4_t wtd_sum_hi = vmull_u16(vget_high_u16(a), vget_high_u16(wta));
+
+ wtd_sum_lo = vmlal_u16(wtd_sum_lo, vget_low_u16(b), vget_low_u16(wtb));
+ wtd_sum_hi = vmlal_u16(wtd_sum_hi, vget_high_u16(b), vget_high_u16(wtb));
+
+ uint16x4_t wtd_avg_lo = vrshrn_n_u32(wtd_sum_lo, DIST_PRECISION_BITS);
+ uint16x4_t wtd_avg_hi = vrshrn_n_u32(wtd_sum_hi, DIST_PRECISION_BITS);
+
+ return vcombine_u16(wtd_avg_lo, wtd_avg_hi);
+}
+
+#endif // AOM_AOM_DSP_ARM_DIST_WTD_AVG_NEON_H_
diff --git a/aom_dsp/arm/highbd_variance_neon.c b/aom_dsp/arm/highbd_variance_neon.c
index 4dbca61..0aa0bb8 100644
--- a/aom_dsp/arm/highbd_variance_neon.c
+++ b/aom_dsp/arm/highbd_variance_neon.c
@@ -15,10 +15,11 @@
#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"
-#include "aom_dsp/variance.h"
#include "aom_dsp/aom_filter.h"
+#include "aom_dsp/arm/dist_wtd_avg_neon.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"
+#include "aom_dsp/variance.h"
// Process a block of width 4 two rows at a time.
static INLINE void highbd_variance_4xh_neon(const uint16_t *src_ptr,
@@ -474,8 +475,8 @@
const uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8);
- const uint16x4_t fwd_offset_u16 = vdup_n_u16(jcp_param->fwd_offset);
- const uint16x4_t bck_offset_u16 = vdup_n_u16(jcp_param->bck_offset);
+ const uint16x8_t fwd_offset_u16 = vdupq_n_u16(jcp_param->fwd_offset);
+ const uint16x8_t bck_offset_u16 = vdupq_n_u16(jcp_param->bck_offset);
int i = height;
if (width > 8) {
@@ -485,12 +486,9 @@
const uint16x8_t p = vld1q_u16(pred + j);
const uint16x8_t r = vld1q_u16(ref + j);
- uint32x4_t cp0 = vmull_u16(vget_low_u16(p), bck_offset_u16);
- uint32x4_t cp1 = vmull_u16(vget_high_u16(p), bck_offset_u16);
- cp0 = vmlal_u16(cp0, vget_low_u16(r), fwd_offset_u16);
- cp1 = vmlal_u16(cp1, vget_high_u16(r), fwd_offset_u16);
- uint16x8_t avg = vcombine_u16(vrshrn_n_u32(cp0, DIST_PRECISION_BITS),
- vrshrn_n_u32(cp1, DIST_PRECISION_BITS));
+ const uint16x8_t avg =
+ dist_wtd_avg_u16x8(r, p, fwd_offset_u16, bck_offset_u16);
+
vst1q_u16(comp_pred + j, avg);
j += 8;
@@ -505,12 +503,9 @@
const uint16x8_t p = vld1q_u16(pred);
const uint16x8_t r = vld1q_u16(ref);
- uint32x4_t cp0 = vmull_u16(vget_low_u16(p), bck_offset_u16);
- uint32x4_t cp1 = vmull_u16(vget_high_u16(p), bck_offset_u16);
- cp0 = vmlal_u16(cp0, vget_low_u16(r), fwd_offset_u16);
- cp1 = vmlal_u16(cp1, vget_high_u16(r), fwd_offset_u16);
- uint16x8_t avg = vcombine_u16(vrshrn_n_u32(cp0, DIST_PRECISION_BITS),
- vrshrn_n_u32(cp1, DIST_PRECISION_BITS));
+ const uint16x8_t avg =
+ dist_wtd_avg_u16x8(r, p, fwd_offset_u16, bck_offset_u16);
+
vst1q_u16(comp_pred, avg);
comp_pred += width;
@@ -523,9 +518,9 @@
const uint16x4_t p = vld1_u16(pred);
const uint16x4_t r = vld1_u16(ref);
- uint32x4_t cp = vmull_u16(p, bck_offset_u16);
- cp = vmlal_u16(cp, r, fwd_offset_u16);
- uint16x4_t avg = vrshrn_n_u32(cp, DIST_PRECISION_BITS);
+ const uint16x4_t avg = dist_wtd_avg_u16x4(
+ r, p, vget_low_u16(fwd_offset_u16), vget_low_u16(bck_offset_u16));
+
vst1_u16(comp_pred, avg);
comp_pred += width;