Add NEON version of av1_lowbd_pixel_proj_error
Speed-up over the C implementation, measured via
NEON/PixelProjErrorTest.DISABLED_Speed:
r0=0 r1=0: 7.62x
r0=1 r1=0: 4.03x
r0=0 r1=1: 4.03x
r0=1 r1=1: 3.73x
Change-Id: I317ec9e8725f1a088aa33682d0099f0b052918f5
diff --git a/av1/av1.cmake b/av1/av1.cmake
index 0099a21..94d2102 100644
--- a/av1/av1.cmake
+++ b/av1/av1.cmake
@@ -376,6 +376,7 @@
list(APPEND AOM_AV1_ENCODER_INTRIN_NEON
"${AOM_ROOT}/av1/encoder/arm/neon/quantize_neon.c"
+ "${AOM_ROOT}/av1/encoder/arm/neon/picksrt_neon.c"
"${AOM_ROOT}/av1/encoder/arm/neon/av1_error_neon.c"
"${AOM_ROOT}/av1/encoder/arm/neon/encodetxb_neon.c"
"${AOM_ROOT}/av1/encoder/arm/neon/av1_fwd_txfm2d_neon.c"
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index a9b2942..3869447 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -396,7 +396,7 @@
specialize qw/av1_calc_proj_params avx2/;
add_proto qw/int64_t av1_lowbd_pixel_proj_error/, " const uint8_t *src8, int width, int height, int src_stride, const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride, int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params";
- specialize qw/av1_lowbd_pixel_proj_error sse4_1 avx2/;
+ specialize qw/av1_lowbd_pixel_proj_error sse4_1 avx2 neon/;
if (aom_config("CONFIG_AV1_HIGHBITDEPTH") eq "yes") {
add_proto qw/int64_t av1_highbd_pixel_proj_error/, " const uint8_t *src8, int width, int height, int src_stride, const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride, int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params";
diff --git a/av1/encoder/arm/neon/picksrt_neon.c b/av1/encoder/arm/neon/picksrt_neon.c
new file mode 100644
index 0000000..34595e8
--- /dev/null
+++ b/av1/encoder/arm/neon/picksrt_neon.c
@@ -0,0 +1,151 @@
+/*
+ * Copyright (c) 2020, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include <arm_neon.h>
+#include <math.h>
+
+#include "aom/aom_integer.h"
+#include "aom_mem/aom_mem.h"
+#include "aom_ports/mem.h"
+#include "av1/common/restoration.h"
+#include "common/tools_common.h"
+#include "config/aom_config.h"
+#include "config/aom_dsp_rtcd.h"
+
+int64_t av1_lowbd_pixel_proj_error_neon(  // NEON port: sum of squared self-guided projection errors over a width x height block.
+    const uint8_t *src8, int width, int height, int src_stride,
+    const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
+    int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params) {
+  int i, j, k;
+  const int32_t shift = SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS;  // == 11; the _n_ shift intrinsics below need a literal, hence the hard-coded 11 there.
+  const int32x4_t zero = vdupq_n_s32(0);
+  uint64x2_t sum64 = vreinterpretq_u64_s32(zero);  // 64-bit accumulator fed from per-row 32-bit partial sums.
+  const uint8_t *src = src8;
+  const uint8_t *dat = dat8;
+
+  int64_t err = 0;  // scalar accumulator for the tail columns (width not a multiple of 8/16).
+  if (params->r[0] > 0 && params->r[1] > 0) {  // both filters active: e = ((xq0*(flt0-u) + xq1*(flt1-u) + r) >> 11) + dat - src.
+    for (i = 0; i < height; ++i) {
+      int32x4_t err0 = zero;  // per-row e^2 accumulator; spilled to sum64 each row so 32-bit lanes don't overflow.
+      for (j = 0; j <= width - 8; j += 8) {
+        const uint8x8_t d0 = vld1_u8(&dat[j]);
+        const uint8x8_t s0 = vld1_u8(&src[j]);
+        const int16x8_t flt0_16b =  // saturating narrow of flt0 to s16 — assumes filter output fits in 16 bits; TODO(review): confirm vs SSE4 version.
+            vcombine_s16(vqmovn_s32(vld1q_s32(&flt0[j])),
+                         vqmovn_s32(vld1q_s32(&flt0[j + 4])));
+        const int16x8_t flt1_16b =
+            vcombine_s16(vqmovn_s32(vld1q_s32(&flt1[j])),
+                         vqmovn_s32(vld1q_s32(&flt1[j + 4])));
+        const int16x8_t u0 =  // u = dat << SGRPROJ_RST_BITS, widened to 16 bits.
+            vreinterpretq_s16_u16(vshll_n_u8(d0, SGRPROJ_RST_BITS));
+        const int16x8_t flt0_0_sub_u = vsubq_s16(flt0_16b, u0);
+        const int16x8_t flt1_0_sub_u = vsubq_s16(flt1_16b, u0);
+        const int16x4_t flt0_16b_sub_u_lo = vget_low_s16(flt0_0_sub_u);
+        const int16x4_t flt0_16b_sub_u_hi = vget_high_s16(flt0_0_sub_u);
+        const int16x4_t flt1_16b_sub_u_lo = vget_low_s16(flt1_0_sub_u);
+        const int16x4_t flt1_16b_sub_u_hi = vget_high_s16(flt1_0_sub_u);
+
+        int32x4_t v0 = vmull_n_s16(flt0_16b_sub_u_lo, (int16_t)xq[0]);  // v = xq0*(flt0-u) + xq1*(flt1-u), widened to 32 bits.
+        v0 = vmlal_n_s16(v0, flt1_16b_sub_u_lo, (int16_t)xq[1]);
+        int32x4_t v1 = vmull_n_s16(flt0_16b_sub_u_hi, (int16_t)xq[0]);
+        v1 = vmlal_n_s16(v1, flt1_16b_sub_u_hi, (int16_t)xq[1]);
+        const int16x4_t vr0 = vqrshrn_n_s32(v0, 11);  // rounding shift by 11 (== shift), matching ROUND_POWER_OF_TWO in the tail loop.
+        const int16x4_t vr1 = vqrshrn_n_s32(v1, 11);
+        const int16x8_t e0 = vaddq_s16(vcombine_s16(vr0, vr1),  // e = v' + (dat - src); u8 wraparound reinterpreted as s16 gives the signed diff.
+                                       vreinterpretq_s16_u16(vsubl_u8(d0, s0)));
+        const int16x4_t e0_lo = vget_low_s16(e0);
+        const int16x4_t e0_hi = vget_high_s16(e0);
+        err0 = vmlal_s16(err0, e0_lo, e0_lo);  // accumulate e^2.
+        err0 = vmlal_s16(err0, e0_hi, e0_hi);
+      }
+      for (k = j; k < width; ++k) {  // scalar tail for the leftover width % 8 columns.
+        const int32_t u = dat[k] << SGRPROJ_RST_BITS;
+        int32_t v = xq[0] * (flt0[k] - u) + xq[1] * (flt1[k] - u);
+        const int32_t e = ROUND_POWER_OF_TWO(v, 11) + dat[k] - src[k];
+        err += e * e;
+      }
+      dat += dat_stride;
+      src += src_stride;
+      flt0 += flt0_stride;
+      flt1 += flt1_stride;
+      sum64 = vpadalq_u32(sum64, vreinterpretq_u32_s32(err0));  // row spill: widen-and-add the four 32-bit lanes into sum64.
+    }
+
+  } else if (params->r[0] > 0 || params->r[1] > 0) {  // exactly one filter active.
+    const int xq_active = (params->r[0] > 0) ? xq[0] : xq[1];
+    const int32_t *flt = (params->r[0] > 0) ? flt0 : flt1;
+    const int flt_stride = (params->r[0] > 0) ? flt0_stride : flt1_stride;
+    for (i = 0; i < height; ++i) {
+      int32x4_t err0 = zero;
+      for (j = 0; j <= width - 8; j += 8) {
+        const uint8x8_t d0 = vld1_u8(&dat[j]);
+        const uint8x8_t s0 = vld1_u8(&src[j]);
+        const uint16x8_t d0s0 = vsubl_u8(d0, s0);
+        const uint16x8x2_t d0w =  // interleave with zeros: zero-extends each u16 to a u32 lane (little-endian layout).
+            vzipq_u16(vmovl_u8(d0), vreinterpretq_u16_s32(zero));
+
+        const int32x4_t flt_16b_lo = vld1q_s32(&flt[j]);  // full 32-bit flt here — no s16 narrowing in this branch.
+        const int32x4_t flt_16b_hi = vld1q_s32(&flt[j + 4]);
+
+        int32x4_t v0 = vmulq_n_s32(flt_16b_lo, xq_active);  // v = xq*flt - (xq << RST_BITS)*dat == xq*(flt - (dat << RST_BITS)).
+        v0 = vmlsq_n_s32(v0, vreinterpretq_s32_u16(d0w.val[0]),
+                         xq_active << SGRPROJ_RST_BITS);
+        int32x4_t v1 = vmulq_n_s32(flt_16b_hi, xq_active);
+        v1 = vmlsq_n_s32(v1, vreinterpretq_s32_u16(d0w.val[1]),
+                         xq_active << SGRPROJ_RST_BITS);
+        const int16x4_t vr0 = vqrshrn_n_s32(v0, 11);  // literal 11 == shift; _n_ intrinsic requires a constant.
+        const int16x4_t vr1 = vqrshrn_n_s32(v1, 11);
+        const int16x8_t e0 =
+            vaddq_s16(vcombine_s16(vr0, vr1), vreinterpretq_s16_u16(d0s0));
+        const int16x4_t e0_lo = vget_low_s16(e0);
+        const int16x4_t e0_hi = vget_high_s16(e0);
+        err0 = vmlal_s16(err0, e0_lo, e0_lo);
+        err0 = vmlal_s16(err0, e0_hi, e0_hi);
+      }
+      for (k = j; k < width; ++k) {  // scalar tail; here the shift variable can be used directly.
+        const int32_t u = dat[k] << SGRPROJ_RST_BITS;
+        int32_t v = xq_active * (flt[k] - u);
+        const int32_t e = ROUND_POWER_OF_TWO(v, shift) + dat[k] - src[k];
+        err += e * e;
+      }
+      dat += dat_stride;
+      src += src_stride;
+      flt += flt_stride;
+      sum64 = vpadalq_u32(sum64, vreinterpretq_u32_s32(err0));
+    }
+  } else {  // no filter active: plain sum of squared (dat - src) differences, 16 pixels per iteration.
+    uint32x4_t err0 = vreinterpretq_u32_s32(zero);  // whole-image u32 accumulator; max block size keeps this under 2^32 — TODO(review): confirm bound.
+    for (i = 0; i < height; ++i) {
+      for (j = 0; j <= width - 16; j += 16) {
+        const uint8x16_t d = vld1q_u8(&dat[j]);
+        const uint8x16_t s = vld1q_u8(&src[j]);
+        const uint8x16_t diff = vabdq_u8(d, s);  // |d - s|; squaring makes the sign irrelevant.
+        const uint8x8_t diff0 = vget_low_u8(diff);
+        const uint8x8_t diff1 = vget_high_u8(diff);
+        err0 = vpadalq_u16(err0, vmull_u8(diff0, diff0));
+        err0 = vpadalq_u16(err0, vmull_u8(diff1, diff1));
+      }
+      for (k = j; k < width; ++k) {  // scalar tail for width % 16 columns.
+        const int32_t e = dat[k] - src[k];
+        err += e * e;
+      }
+      dat += dat_stride;
+      src += src_stride;
+    }
+    sum64 = vpaddlq_u32(err0);  // plain assignment (not accumulate) is fine: sum64 is still zero on this path.
+  }
+#if defined(__aarch64__)
+  err += vaddvq_u64(sum64);  // horizontal add of both 64-bit lanes.
+#else
+  err += vget_lane_u64(vadd_u64(vget_low_u64(sum64), vget_high_u64(sum64)), 0);
+#endif  // __aarch64__
+  return err;
+}
diff --git a/test/pickrst_test.cc b/test/pickrst_test.cc
index 9a2c5bc..77c6b47 100644
--- a/test/pickrst_test.cc
+++ b/test/pickrst_test.cc
@@ -188,6 +188,12 @@
::testing::Values(av1_lowbd_pixel_proj_error_avx2));
#endif // HAVE_AVX2
+#if HAVE_NEON
+
+INSTANTIATE_TEST_SUITE_P(NEON, PixelProjErrorTest,
+ ::testing::Values(av1_lowbd_pixel_proj_error_neon));
+#endif // HAVE_NEON
+
} // namespace pickrst_test_lowbd
#if CONFIG_AV1_HIGHBITDEPTH