Add NEON version of av1_fwht4x4 function
AVG gain = 1.4 via NEON/Trans4x4WHT.DISABLED_Speed
Change-Id: I778bc07eb9d5a74a697d77e56c8e95092ad50012
diff --git a/av1/av1.cmake b/av1/av1.cmake
index 4bef55e..f280ccf 100644
--- a/av1/av1.cmake
+++ b/av1/av1.cmake
@@ -383,6 +383,7 @@
"${AOM_ROOT}/av1/encoder/arm/neon/rdopt_neon.c"
"${AOM_ROOT}/av1/encoder/arm/neon/av1_error_neon.c"
"${AOM_ROOT}/av1/encoder/arm/neon/encodetxb_neon.c"
+ "${AOM_ROOT}/av1/encoder/arm/neon/hybrid_fwd_txfm_neon.c"
"${AOM_ROOT}/av1/encoder/arm/neon/av1_fwd_txfm2d_neon.c"
"${AOM_ROOT}/av1/encoder/arm/neon/highbd_fwd_txfm_neon.c")
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 2a94ef3..901203e 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -298,6 +298,7 @@
# fdct functions
add_proto qw/void av1_fwht4x4/, "const int16_t *input, tran_low_t *output, int stride";
+ specialize qw/av1_fwht4x4 neon/;
#fwd txfm
add_proto qw/void av1_lowbd_fwd_txfm/, "const int16_t *src_diff, tran_low_t *coeff, int diff_stride, TxfmParam *txfm_param";
@@ -364,6 +365,7 @@
}
add_proto qw/void av1_highbd_fwht4x4/, "const int16_t *input, tran_low_t *output, int stride";
+ specialize qw/av1_highbd_fwht4x4 neon/;
# End av1_high encoder functions
diff --git a/av1/encoder/arm/neon/hybrid_fwd_txfm_neon.c b/av1/encoder/arm/neon/hybrid_fwd_txfm_neon.c
new file mode 100644
index 0000000..0ad1131
--- /dev/null
+++ b/av1/encoder/arm/neon/hybrid_fwd_txfm_neon.c
@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) 2020, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include <arm_neon.h>
+
+#include "aom_dsp/txfm_common.h"
+
+static void transpose4x4(int16x8_t in[2], int16x4_t out[4]) {
+ int32x4x2_t b0 =
+ vtrnq_s32(vreinterpretq_s32_s16(in[0]), vreinterpretq_s32_s16(in[1]));
+ int16x4x2_t c0 = vtrn_s16(vreinterpret_s16_s32(vget_low_s32(b0.val[0])),
+ vreinterpret_s16_s32(vget_high_s32(b0.val[0])));
+ int16x4x2_t c1 = vtrn_s16(vreinterpret_s16_s32(vget_low_s32(b0.val[1])),
+ vreinterpret_s16_s32(vget_high_s32(b0.val[1])));
+ out[0] = c0.val[0];
+ out[1] = c0.val[1];
+ out[2] = c1.val[0];
+ out[3] = c1.val[1];
+}
+
+void av1_fwht4x4_neon(const int16_t *input, tran_low_t *output, int stride) {
+ // Load the 4x4 source in transposed form.
+ int16x4_t a1, b1, c1, d1, e;
+ a1 = vld1_s16(&input[0]);
+ b1 = vld1_s16(&input[1 * stride]);
+ c1 = vld1_s16(&input[2 * stride]);
+ d1 = vld1_s16(&input[3 * stride]);
+
+ // WHT.
+
+ // Row transforms.
+ a1 = vadd_s16(a1, b1);
+ d1 = vsub_s16(d1, c1);
+ e = vhsub_s16(a1, d1);
+ b1 = vsub_s16(e, b1);
+ c1 = vsub_s16(e, c1);
+ a1 = vsub_s16(a1, c1);
+ d1 = vadd_s16(d1, b1);
+
+ int16x8_t x[2];
+ x[0] = vcombine_s16(a1, c1);
+ x[1] = vcombine_s16(d1, b1);
+
+ int16x4_t s[4];
+ transpose4x4(x, s);
+
+ a1 = s[0];
+ b1 = s[1];
+ c1 = s[2];
+ d1 = s[3];
+
+ // Row transforms.
+ a1 = vadd_s16(a1, b1);
+ d1 = vsub_s16(d1, c1);
+ e = vhsub_s16(a1, d1);
+ b1 = vsub_s16(e, b1);
+ c1 = vsub_s16(e, c1);
+ a1 = vsub_s16(a1, c1);
+ d1 = vadd_s16(d1, b1);
+
+ x[0] = vcombine_s16(a1, c1);
+ x[1] = vcombine_s16(d1, b1);
+
+ transpose4x4(x, s);
+
+ vst1q_s32(&output[0], vshll_n_s16(s[0], UNIT_QUANT_SHIFT));
+ vst1q_s32(&output[4], vshll_n_s16(s[1], UNIT_QUANT_SHIFT));
+ vst1q_s32(&output[8], vshll_n_s16(s[2], UNIT_QUANT_SHIFT));
+ vst1q_s32(&output[12], vshll_n_s16(s[3], UNIT_QUANT_SHIFT));
+}
+
+void av1_highbd_fwht4x4_neon(const int16_t *input, tran_low_t *output,
+ int stride) {
+ av1_fwht4x4_neon(input, output, stride);
+}
diff --git a/test/fwht4x4_test.cc b/test/fwht4x4_test.cc
index d2f77b8..b600d26 100644
--- a/test/fwht4x4_test.cc
+++ b/test/fwht4x4_test.cc
@@ -37,7 +37,7 @@
using libaom_test::FhtFunc;
-typedef std::tuple<FdctFunc, IdctFunc, TX_TYPE, aom_bit_depth_t, int>
+typedef std::tuple<FdctFunc, IdctFunc, TX_TYPE, aom_bit_depth_t, int, FdctFunc>
Dct4x4Param;
void fwht4x4_ref(const int16_t *in, tran_low_t *out, int stride,
@@ -67,6 +67,7 @@
bit_depth_ = GET_PARAM(3);
mask_ = (1 << bit_depth_) - 1;
num_coeffs_ = GET_PARAM(4);
+ fwd_txfm_c_ = GET_PARAM(5);
}
virtual void TearDown() { libaom_test::ClearSystemState(); }
@@ -77,9 +78,89 @@
void RunInvTxfm(const tran_low_t *out, uint8_t *dst, int stride) {
inv_txfm_(out, dst, stride);
}
+ void RunSpeedTest() {
+ if (!fwd_txfm_c_) {
+ GTEST_SKIP();
+ } else {
+ ACMRandom rnd(ACMRandom::DeterministicSeed());
+ const int count_test_block = 10;
+ const int numIter = 5000;
+
+ int c_sum_time = 0;
+ int simd_sum_time = 0;
+
+ int stride = 96;
+
+ int16_t *input_block = reinterpret_cast<int16_t *>(
+ aom_memalign(16, sizeof(int16_t) * stride * height_));
+ tran_low_t *output_ref_block = reinterpret_cast<tran_low_t *>(
+ aom_memalign(16, sizeof(output_ref_block[0]) * num_coeffs_));
+ tran_low_t *output_block = reinterpret_cast<tran_low_t *>(
+ aom_memalign(16, sizeof(output_block[0]) * num_coeffs_));
+
+ for (int i = 0; i < count_test_block; ++i) {
+ int j, k;
+ for (j = 0; j < height_; ++j) {
+ for (k = 0; k < pitch_; ++k) {
+ int in_idx = j * stride + k;
+ int out_idx = j * pitch_ + k;
+ input_block[in_idx] =
+ (rnd.Rand16() & mask_) - (rnd.Rand16() & mask_);
+ if (bit_depth_ == AOM_BITS_8) {
+ output_block[out_idx] = output_ref_block[out_idx] = rnd.Rand8();
+ } else {
+ output_block[out_idx] = output_ref_block[out_idx] =
+ rnd.Rand16() & mask_;
+ }
+ }
+ }
+
+ aom_usec_timer c_timer_;
+ aom_usec_timer_start(&c_timer_);
+ for (int i = 0; i < numIter; i++) {
+ ASM_REGISTER_STATE_CHECK(
+ fwd_txfm_c_(input_block, output_ref_block, stride));
+ }
+ aom_usec_timer_mark(&c_timer_);
+
+ aom_usec_timer simd_timer_;
+ aom_usec_timer_start(&simd_timer_);
+
+ for (int i = 0; i < numIter; i++) {
+ ASM_REGISTER_STATE_CHECK(
+ fwd_txfm_(input_block, output_block, stride));
+ }
+ aom_usec_timer_mark(&simd_timer_);
+
+ c_sum_time += static_cast<int>(aom_usec_timer_elapsed(&c_timer_));
+ simd_sum_time += static_cast<int>(aom_usec_timer_elapsed(&simd_timer_));
+
+ // The minimum quant value is 4.
+ for (j = 0; j < height_; ++j) {
+ for (k = 0; k < pitch_; ++k) {
+ int out_idx = j * pitch_ + k;
+ ASSERT_EQ(output_block[out_idx], output_ref_block[out_idx])
+ << "Error: not bit-exact result at index: " << out_idx
+ << " at test block: " << i;
+ }
+ }
+ }
+
+ printf(
+ "c_time = %d \t simd_time = %d \t Gain = %4.2f \n", c_sum_time,
+ simd_sum_time,
+ (static_cast<float>(c_sum_time) / static_cast<float>(simd_sum_time)));
+
+ aom_free(input_block);
+ aom_free(output_ref_block);
+ aom_free(output_block);
+ }
+ }
FdctFunc fwd_txfm_;
IdctFunc inv_txfm_;
+
+ FdctFunc fwd_txfm_c_; // C version of forward transform for speed test.
};
TEST_P(Trans4x4WHT, AccuracyCheck) { RunAccuracyCheck(0, 0.00001); }
@@ -89,12 +170,27 @@
TEST_P(Trans4x4WHT, MemCheck) { RunMemCheck(); }
TEST_P(Trans4x4WHT, InvAccuracyCheck) { RunInvAccuracyCheck(0); }
+
+TEST_P(Trans4x4WHT, DISABLED_Speed) { RunSpeedTest(); }
+
using std::make_tuple;
INSTANTIATE_TEST_SUITE_P(
C, Trans4x4WHT,
::testing::Values(make_tuple(&av1_highbd_fwht4x4_c, &iwht4x4_10, DCT_DCT,
- AOM_BITS_10, 16),
+ AOM_BITS_10, 16, static_cast<FdctFunc>(NULL)),
make_tuple(&av1_highbd_fwht4x4_c, &iwht4x4_12, DCT_DCT,
- AOM_BITS_12, 16)));
+ AOM_BITS_12, 16,
+ static_cast<FdctFunc>(NULL))));
+#if HAVE_NEON
+
+INSTANTIATE_TEST_SUITE_P(
+ NEON, Trans4x4WHT,
+ ::testing::Values(make_tuple(&av1_highbd_fwht4x4_neon, &iwht4x4_10, DCT_DCT,
+ AOM_BITS_10, 16, &av1_highbd_fwht4x4_c),
+ make_tuple(&av1_highbd_fwht4x4_neon, &iwht4x4_12, DCT_DCT,
+ AOM_BITS_12, 16, &av1_highbd_fwht4x4_c)));
+
+#endif // HAVE_NEON
+
} // namespace