Add AVX2 variant for sum_squares_2d
Achieved module level gains improved by ~71%
on average for width >= 16 and observed no drop
in gains for remaining widths.
Change-Id: Ica295d97b9854a2a22e3af9cb8671f91376e2562
diff --git a/aom_dsp/aom_dsp.cmake b/aom_dsp/aom_dsp.cmake
index a76f88f..e2bf87b 100644
--- a/aom_dsp/aom_dsp.cmake
+++ b/aom_dsp/aom_dsp.cmake
@@ -67,7 +67,8 @@
"${AOM_ROOT}/aom_dsp/x86/lpf_common_sse2.h"
"${AOM_ROOT}/aom_dsp/x86/mem_sse2.h"
"${AOM_ROOT}/aom_dsp/x86/transpose_sse2.h"
- "${AOM_ROOT}/aom_dsp/x86/txfm_common_sse2.h")
+ "${AOM_ROOT}/aom_dsp/x86/txfm_common_sse2.h"
+ "${AOM_ROOT}/aom_dsp/x86/sum_squares_sse2.h")
list(APPEND AOM_DSP_COMMON_ASM_SSSE3
"${AOM_ROOT}/aom_dsp/x86/aom_subpixel_8t_ssse3.asm"
@@ -201,7 +202,8 @@
"${AOM_ROOT}/aom_dsp/x86/sse_avx2.c"
"${AOM_ROOT}/aom_dsp/x86/variance_impl_avx2.c"
"${AOM_ROOT}/aom_dsp/x86/obmc_sad_avx2.c"
- "${AOM_ROOT}/aom_dsp/x86/obmc_variance_avx2.c")
+ "${AOM_ROOT}/aom_dsp/x86/obmc_variance_avx2.c"
+ "${AOM_ROOT}/aom_dsp/x86/sum_squares_avx2.c")
list(APPEND AOM_DSP_ENCODER_ASM_SSSE3_X86_64
"${AOM_ROOT}/aom_dsp/x86/quantize_ssse3_x86_64.asm")
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index 27e95b9..50af6fd 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -580,7 +580,7 @@
# Sum of Squares
#
add_proto qw/uint64_t aom_sum_squares_2d_i16/, "const int16_t *src, int stride, int width, int height";
- specialize qw/aom_sum_squares_2d_i16 sse2/;
+ specialize qw/aom_sum_squares_2d_i16 sse2 avx2/;
add_proto qw/uint64_t aom_sum_squares_i16/, "const int16_t *src, uint32_t N";
specialize qw/aom_sum_squares_i16 sse2/;
diff --git a/aom_dsp/x86/sum_squares_avx2.c b/aom_dsp/x86/sum_squares_avx2.c
new file mode 100644
index 0000000..4c00b15
--- /dev/null
+++ b/aom_dsp/x86/sum_squares_avx2.c
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) 2018, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include <immintrin.h>
+#include <smmintrin.h>
+
+#include "aom_dsp/x86/synonyms.h"
+#include "aom_dsp/x86/sum_squares_sse2.h"
+#include "config/aom_dsp_rtcd.h"
+
+static uint64_t aom_sum_squares_2d_i16_nxn_avx2(const int16_t *src, int stride,
+ int width, int height) {
+ uint64_t result;
+ __m256i v_acc_q = _mm256_setzero_si256();
+ const __m256i v_zext_mask_q = _mm256_set1_epi64x(0xffffffff);
+ for (int col = 0; col < height; col += 4) {
+ __m256i v_acc_d = _mm256_setzero_si256();
+ for (int row = 0; row < width; row += 16) {
+ const int16_t *tempsrc = src + row;
+ const __m256i v_val_0_w =
+ _mm256_loadu_si256((const __m256i *)(tempsrc + 0 * stride));
+ const __m256i v_val_1_w =
+ _mm256_loadu_si256((const __m256i *)(tempsrc + 1 * stride));
+ const __m256i v_val_2_w =
+ _mm256_loadu_si256((const __m256i *)(tempsrc + 2 * stride));
+ const __m256i v_val_3_w =
+ _mm256_loadu_si256((const __m256i *)(tempsrc + 3 * stride));
+
+ const __m256i v_sq_0_d = _mm256_madd_epi16(v_val_0_w, v_val_0_w);
+ const __m256i v_sq_1_d = _mm256_madd_epi16(v_val_1_w, v_val_1_w);
+ const __m256i v_sq_2_d = _mm256_madd_epi16(v_val_2_w, v_val_2_w);
+ const __m256i v_sq_3_d = _mm256_madd_epi16(v_val_3_w, v_val_3_w);
+
+ const __m256i v_sum_01_d = _mm256_add_epi32(v_sq_0_d, v_sq_1_d);
+ const __m256i v_sum_23_d = _mm256_add_epi32(v_sq_2_d, v_sq_3_d);
+ const __m256i v_sum_0123_d = _mm256_add_epi32(v_sum_01_d, v_sum_23_d);
+
+ v_acc_d = _mm256_add_epi32(v_acc_d, v_sum_0123_d);
+ }
+ v_acc_q =
+ _mm256_add_epi64(v_acc_q, _mm256_and_si256(v_acc_d, v_zext_mask_q));
+ v_acc_q = _mm256_add_epi64(v_acc_q, _mm256_srli_epi64(v_acc_d, 32));
+ src += 4 * stride;
+ }
+ __m128i lower_64_2_Value = _mm256_castsi256_si128(v_acc_q);
+ __m128i higher_64_2_Value = _mm256_extracti128_si256(v_acc_q, 1);
+ __m128i result_64_2_int = _mm_add_epi64(lower_64_2_Value, higher_64_2_Value);
+
+ result_64_2_int = _mm_add_epi64(
+ result_64_2_int, _mm_unpackhi_epi64(result_64_2_int, result_64_2_int));
+
+ xx_storel_64(&result, result_64_2_int);
+
+ return result;
+}
+
+uint64_t aom_sum_squares_2d_i16_avx2(const int16_t *src, int stride, int width,
+ int height) {
+ if (LIKELY(width == 4 && height == 4)) {
+ return aom_sum_squares_2d_i16_4x4_sse2(src, stride);
+ } else if (LIKELY(width == 4 && (height & 3) == 0)) {
+ return aom_sum_squares_2d_i16_4xn_sse2(src, stride, height);
+ } else if (LIKELY(width == 8 && (height & 3) == 0)) {
+ return aom_sum_squares_2d_i16_nxn_sse2(src, stride, width, height);
+ } else if (LIKELY(((width & 15) == 0) && ((height & 3) == 0))) {
+ return aom_sum_squares_2d_i16_nxn_avx2(src, stride, width, height);
+ } else {
+ return aom_sum_squares_2d_i16_c(src, stride, width, height);
+ }
+}
diff --git a/aom_dsp/x86/sum_squares_sse2.c b/aom_dsp/x86/sum_squares_sse2.c
index a79f22d..22d7739 100644
--- a/aom_dsp/x86/sum_squares_sse2.c
+++ b/aom_dsp/x86/sum_squares_sse2.c
@@ -14,6 +14,7 @@
#include <stdio.h>
#include "aom_dsp/x86/synonyms.h"
+#include "aom_dsp/x86/sum_squares_sse2.h"
#include "config/aom_dsp_rtcd.h"
static INLINE __m128i xx_loadh_64(__m128i a, const void *b) {
@@ -44,8 +45,7 @@
return _mm_add_epi32(v_sq_01_d, v_sq_23_d);
}
-static uint64_t aom_sum_squares_2d_i16_4x4_sse2(const int16_t *src,
- int stride) {
+uint64_t aom_sum_squares_2d_i16_4x4_sse2(const int16_t *src, int stride) {
const __m128i v_sum_0123_d = sum_squares_i16_4x4_sse2(src, stride);
__m128i v_sum_d =
_mm_add_epi32(v_sum_0123_d, _mm_srli_epi64(v_sum_0123_d, 32));
@@ -53,8 +53,8 @@
return (uint64_t)_mm_cvtsi128_si32(v_sum_d);
}
-static uint64_t aom_sum_squares_2d_i16_4xn_sse2(const int16_t *src, int stride,
- int height) {
+uint64_t aom_sum_squares_2d_i16_4xn_sse2(const int16_t *src, int stride,
+ int height) {
int r = 0;
__m128i v_acc_q = _mm_setzero_si128();
do {
@@ -76,7 +76,7 @@
// maintenance instructions in the common case of 4x4.
__attribute__((noinline))
#endif
-static uint64_t
+uint64_t
aom_sum_squares_2d_i16_nxn_sse2(const int16_t *src, int stride, int width,
int height) {
int r = 0;
diff --git a/aom_dsp/x86/sum_squares_sse2.h b/aom_dsp/x86/sum_squares_sse2.h
new file mode 100644
index 0000000..491e31c
--- /dev/null
+++ b/aom_dsp/x86/sum_squares_sse2.h
@@ -0,0 +1,22 @@
+/*
+ * Copyright (c) 2018, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#ifndef AOM_DSP_X86_SUM_SQUARES_SSE2_H_
+#define AOM_DSP_X86_SUM_SQUARES_SSE2_H_
+
+uint64_t aom_sum_squares_2d_i16_nxn_sse2(const int16_t *src, int stride,
+ int width, int height);
+
+uint64_t aom_sum_squares_2d_i16_4xn_sse2(const int16_t *src, int stride,
+ int height);
+uint64_t aom_sum_squares_2d_i16_4x4_sse2(const int16_t *src, int stride);
+
+#endif // AOM_DSP_X86_SUM_SQUARES_SSE2_H_
diff --git a/test/sum_squares_test.cc b/test/sum_squares_test.cc
index c03ebad..f109984 100644
--- a/test/sum_squares_test.cc
+++ b/test/sum_squares_test.cc
@@ -52,6 +52,7 @@
aom_free(src_);
}
void RunTest(int isRandom);
+ void RunSpeedTest();
void GenRandomData(int width, int height, int stride) {
const int msb = 11; // Up to 12 bit input
@@ -108,6 +109,38 @@
}
}
+void SumSquaresTest::RunSpeedTest() {
+ for (int block = BLOCK_4X4; block < BLOCK_SIZES_ALL; block++) {
+ const int width = block_size_wide[block]; // Up to 128x128
+ const int height = block_size_high[block]; // Up to 128x128
+ int stride = 4 << rnd_(7); // Up to 256 stride
+ while (stride < width) { // Make sure it's valid
+ stride = 4 << rnd_(7);
+ }
+ GenExtremeData(width, height, stride);
+ const int num_loops = 1000000000 / (width + height);
+ aom_usec_timer timer;
+ aom_usec_timer_start(&timer);
+
+ for (int i = 0; i < num_loops; ++i)
+ params_.ref_func(src_, stride, width, height);
+
+ aom_usec_timer_mark(&timer);
+ const int elapsed_time = static_cast<int>(aom_usec_timer_elapsed(&timer));
+ printf("SumSquaresTest C %3dx%-3d: %7.2f ns\n", width, height,
+ 1000.0 * elapsed_time / num_loops);
+
+ aom_usec_timer timer1;
+ aom_usec_timer_start(&timer1);
+ for (int i = 0; i < num_loops; ++i)
+ params_.tst_func(src_, stride, width, height);
+ aom_usec_timer_mark(&timer1);
+ const int elapsed_time1 = static_cast<int>(aom_usec_timer_elapsed(&timer1));
+ printf("SumSquaresTest Test %3dx%-3d: %7.2f ns\n", width, height,
+ 1000.0 * elapsed_time1 / num_loops);
+ }
+}
+
TEST_P(SumSquaresTest, OperationCheck) {
RunTest(1); // GenRandomData
}
@@ -116,6 +149,8 @@
RunTest(0); // GenExtremeData
}
+TEST_P(SumSquaresTest, DISABLED_Speed) { RunSpeedTest(); }
+
#if HAVE_SSE2
INSTANTIATE_TEST_CASE_P(
@@ -125,6 +160,13 @@
#endif // HAVE_SSE2
+#if HAVE_AVX2
+INSTANTIATE_TEST_CASE_P(
+ AVX2, SumSquaresTest,
+ ::testing::Values(TestFuncs(&aom_sum_squares_2d_i16_c,
+ &aom_sum_squares_2d_i16_avx2)));
+#endif // HAVE_AVX2
+
//////////////////////////////////////////////////////////////////////////////
// 1D version
//////////////////////////////////////////////////////////////////////////////