SSE2 k_means implementation Speed test: av1_calc_indices_dim1_sse2() : 10.54x av1_calc_indices_dim2_sse2() : 8.5x Change-Id: I0beb82013ee0294d3c7380c5cea25af8936ec35b
diff --git a/av1/av1.cmake b/av1/av1.cmake index 204be68..b48c614 100644 --- a/av1/av1.cmake +++ b/av1/av1.cmake
@@ -386,6 +386,7 @@ "${AOM_ROOT}/av1/encoder/x86/encodetxb_sse2.c" "${AOM_ROOT}/av1/encoder/x86/highbd_block_error_intrin_sse2.c" "${AOM_ROOT}/av1/encoder/x86/temporal_filter_sse2.c" + "${AOM_ROOT}/av1/encoder/x86/av1_k_means_sse2.c" "${AOM_ROOT}/av1/encoder/x86/highbd_temporal_filter_sse2.c" "${AOM_ROOT}/av1/encoder/x86/wedge_utils_sse2.c")
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl index 01d035b..adf4fb1 100644 --- a/av1/common/av1_rtcd_defs.pl +++ b/av1/common/av1_rtcd_defs.pl
@@ -366,10 +366,10 @@ ##Krishna SSE2 TODO add_proto qw/void av1_calc_indices_dim1/, "const int *data, const int *centroids, uint8_t *indices, int n, int k"; - specialize qw/av1_calc_indices_dim1 avx2/; + specialize qw/av1_calc_indices_dim1 sse2 avx2/; add_proto qw/void av1_calc_indices_dim2/, "const int *data, const int *centroids, uint8_t *indices, int n, int k"; - specialize qw/av1_calc_indices_dim2 avx2/; + specialize qw/av1_calc_indices_dim2 sse2 avx2/; # ENCODEMB INVOKE if (aom_config("CONFIG_AV1_HIGHBITDEPTH") eq "yes") {
diff --git a/av1/encoder/x86/av1_k_means_sse2.c b/av1/encoder/x86/av1_k_means_sse2.c new file mode 100644 index 0000000..10efc9c --- /dev/null +++ b/av1/encoder/x86/av1_k_means_sse2.c
@@ -0,0 +1,96 @@ +/* + * Copyright (c) 2021, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include <emmintrin.h> // SSE2 + +#include "config/aom_dsp_rtcd.h" +#include "aom_dsp/x86/synonyms.h" + +void av1_calc_indices_dim1_sse2(const int *data, const int *centroids, + uint8_t *indices, int n, int k) { + const __m128i v_zero = _mm_setzero_si128(); + int l = 1; + __m128i dist[PALETTE_MAX_SIZE]; + __m128i ind[2]; + + for (int i = 0; i < n; i += 4) { + l = (l == 0) ? 1 : 0; + ind[l] = _mm_loadu_si128((__m128i *)data); + for (int j = 0; j < k; j++) { + __m128i cent = _mm_set1_epi32((uint32_t)centroids[j]); + __m128i d1 = _mm_sub_epi32(ind[l], cent); + dist[j] = _mm_madd_epi16(d1, d1); + } + + ind[l] = _mm_setzero_si128(); + for (int j = 1; j < k; j++) { + __m128i cmp = _mm_cmpgt_epi32(dist[0], dist[j]); + __m128i dist1 = _mm_andnot_si128(cmp, dist[0]); + __m128i dist2 = _mm_and_si128(cmp, dist[j]); + dist[0] = _mm_or_si128(dist1, dist2); + __m128i ind1 = _mm_set1_epi32(j); + ind[l] = + _mm_or_si128(_mm_andnot_si128(cmp, ind[l]), _mm_and_si128(cmp, ind1)); + ind[l] = _mm_packus_epi16(ind[l], v_zero); + } + if (l == 1) { + __m128i p2 = _mm_packus_epi16(_mm_unpacklo_epi64(ind[0], ind[1]), v_zero); + _mm_storel_epi64((__m128i *)indices, p2); + indices += 8; + } + data += 4; + } +} + +void av1_calc_indices_dim2_sse2(const int *data, const int *centroids, + uint8_t *indices, int n, int k) { + const __m128i v_zero = _mm_setzero_si128(); + int l = 1; + __m128i dist[PALETTE_MAX_SIZE]; + __m128i ind[2]; + + for (int i = 0; i < n; i += 4) { + l = (l == 0) ? 1 : 0; + __m128i ind1 = _mm_loadu_si128((__m128i *)data); + __m128i ind2 = _mm_loadu_si128((__m128i *)(data + 4)); + __m128i indl = _mm_unpacklo_epi32(ind1, ind2); + __m128i indh = _mm_unpackhi_epi32(ind1, ind2); + ind1 = _mm_unpacklo_epi32(indl, indh); + ind2 = _mm_unpackhi_epi32(indl, indh); + for (int j = 0; j < k; j++) { + __m128i cent0 = _mm_set1_epi32(centroids[2 * j]); + __m128i cent1 = _mm_set1_epi32(centroids[2 * j + 1]); + __m128i d1 = _mm_sub_epi32(ind1, cent0); + __m128i d2 = _mm_sub_epi32(ind2, cent1); + __m128i d3 = _mm_madd_epi16(d1, d1); + __m128i d4 = _mm_madd_epi16(d2, d2); + dist[j] = _mm_add_epi32(d3, d4); + } + + ind[l] = _mm_setzero_si128(); + for (int j = 1; j < k; j++) { + __m128i cmp = _mm_cmpgt_epi32(dist[0], dist[j]); + __m128i dist1 = _mm_andnot_si128(cmp, dist[0]); + __m128i dist2 = _mm_and_si128(cmp, dist[j]); + dist[0] = _mm_or_si128(dist1, dist2); + ind1 = _mm_set1_epi32(j); + ind[l] = + _mm_or_si128(_mm_andnot_si128(cmp, ind[l]), _mm_and_si128(cmp, ind1)); + ind[l] = _mm_packus_epi16(ind[l], v_zero); + } + if (l == 1) { + __m128i p2 = _mm_packus_epi16(_mm_unpacklo_epi64(ind[0], ind[1]), v_zero); + _mm_storel_epi64((__m128i *)indices, p2); + indices += 8; + } + data += 8; + } +}
diff --git a/test/av1_k_means_test.cc b/test/av1_k_means_test.cc index 754a2da..e5ffda6 100644 --- a/test/av1_k_means_test.cc +++ b/test/av1_k_means_test.cc
@@ -254,13 +254,13 @@ RunSpeedTest(GET_PARAM(0), GET_PARAM(1), 8); } -#if HAVE_AVX2 const BLOCK_SIZE kValidBlockSize[] = { BLOCK_8X8, BLOCK_8X16, BLOCK_8X32, BLOCK_16X8, BLOCK_16X16, BLOCK_16X32, BLOCK_32X8, BLOCK_32X16, BLOCK_32X32, BLOCK_32X64, BLOCK_64X32, BLOCK_64X64, BLOCK_16X64, BLOCK_64X16 }; +#if HAVE_AVX2 INSTANTIATE_TEST_SUITE_P( AVX2, AV1KmeansTest1, ::testing::Combine(::testing::Values(&av1_calc_indices_dim1_avx2), @@ -271,4 +271,16 @@ ::testing::ValuesIn(kValidBlockSize))); #endif +#if HAVE_SSE2 + +INSTANTIATE_TEST_SUITE_P( + SSE2, AV1KmeansTest1, + ::testing::Combine(::testing::Values(&av1_calc_indices_dim1_avx2), + ::testing::ValuesIn(kValidBlockSize))); +INSTANTIATE_TEST_SUITE_P( + SSE2, AV1KmeansTest2, + ::testing::Combine(::testing::Values(&av1_calc_indices_dim2_avx2), + ::testing::ValuesIn(kValidBlockSize))); +#endif + } // namespace AV1Kmeans