Add a few avx2 functions for sad Change-Id: I2b656dbc5467434180abb0e4a7768d530f304b2b
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl index 8c51e7f..ad16f01 100755 --- a/aom_dsp/aom_dsp_rtcd_defs.pl +++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -683,8 +683,8 @@ specialize qw/aom_dist_wtd_sad16x4_avg ssse3/; specialize qw/aom_dist_wtd_sad8x32_avg ssse3/; specialize qw/aom_dist_wtd_sad32x8_avg ssse3/; - specialize qw/aom_dist_wtd_sad16x64_avg ssse3/; - specialize qw/aom_dist_wtd_sad64x16_avg ssse3/; + specialize qw/aom_dist_wtd_sad16x64_avg ssse3/; + specialize qw/aom_dist_wtd_sad64x16_avg ssse3/; add_proto qw/unsigned int/, "aom_sad4xh", "const uint8_t *a, int a_stride, const uint8_t *b, int b_stride, int width, int height"; add_proto qw/unsigned int/, "aom_sad8xh", "const uint8_t *a, int a_stride, const uint8_t *b, int b_stride, int width, int height"; @@ -737,17 +737,17 @@ specialize qw/aom_highbd_sad16x8_avg avx2 sse2/; specialize qw/aom_highbd_sad8x4_avg sse2/; - specialize qw/aom_highbd_sad16x4 sse2/; - specialize qw/aom_highbd_sad8x32 sse2/; - specialize qw/aom_highbd_sad32x8 sse2/; - specialize qw/aom_highbd_sad16x64 sse2/; - specialize qw/aom_highbd_sad64x16 sse2/; + specialize qw/aom_highbd_sad16x4 avx2 sse2/; + specialize qw/aom_highbd_sad8x32 sse2/; + specialize qw/aom_highbd_sad32x8 avx2 sse2/; + specialize qw/aom_highbd_sad16x64 avx2 sse2/; + specialize qw/aom_highbd_sad64x16 avx2 sse2/; - specialize qw/aom_highbd_sad16x4_avg sse2/; - specialize qw/aom_highbd_sad8x32_avg sse2/; - specialize qw/aom_highbd_sad32x8_avg sse2/; - specialize qw/aom_highbd_sad16x64_avg sse2/; - specialize qw/aom_highbd_sad64x16_avg sse2/; + specialize qw/aom_highbd_sad16x4_avg avx2 sse2/; + specialize qw/aom_highbd_sad8x32_avg sse2/; + specialize qw/aom_highbd_sad32x8_avg avx2 sse2/; + specialize qw/aom_highbd_sad16x64_avg avx2 sse2/; + specialize qw/aom_highbd_sad64x16_avg avx2 sse2/; # # Masked SAD @@ -846,12 +846,12 @@ specialize qw/aom_highbd_sad4x8x4d sse2/; specialize qw/aom_highbd_sad4x4x4d sse2/; - specialize qw/aom_highbd_sad4x16x4d sse2/; - specialize qw/aom_highbd_sad16x4x4d sse2/; - specialize qw/aom_highbd_sad8x32x4d sse2/; - specialize qw/aom_highbd_sad32x8x4d sse2/; - specialize qw/aom_highbd_sad16x64x4d sse2/; - specialize qw/aom_highbd_sad64x16x4d sse2/; + specialize qw/aom_highbd_sad4x16x4d sse2/; + specialize qw/aom_highbd_sad16x4x4d avx2 sse2/; + specialize qw/aom_highbd_sad8x32x4d sse2/; + specialize qw/aom_highbd_sad32x8x4d avx2 sse2/; + specialize qw/aom_highbd_sad16x64x4d avx2 sse2/; + specialize qw/aom_highbd_sad64x16x4d avx2 sse2/; # # Avg
diff --git a/aom_dsp/x86/sad_highbd_avx2.c b/aom_dsp/x86/sad_highbd_avx2.c index b506d46..eba442c 100644 --- a/aom_dsp/x86/sad_highbd_avx2.c +++ b/aom_dsp/x86/sad_highbd_avx2.c
@@ -229,6 +229,23 @@ return get_sad_from_mm256_epi32(&sad); } +unsigned int aom_highbd_sad32x8_avx2(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride) { + __m256i sad = _mm256_setzero_si256(); + uint16_t *srcp = CONVERT_TO_SHORTPTR(src); + uint16_t *refp = CONVERT_TO_SHORTPTR(ref); + const int left_shift = 2; + int row_section = 0; + + while (row_section < 2) { + sad32x4(srcp, src_stride, refp, ref_stride, NULL, &sad); + srcp += src_stride << left_shift; + refp += ref_stride << left_shift; + row_section += 1; + } + return get_sad_from_mm256_epi32(&sad); +} + unsigned int aom_highbd_sad16x32_avx2(const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride) { uint32_t sum = aom_highbd_sad16x16_avx2(src, src_stride, ref, ref_stride); @@ -352,6 +369,23 @@ return get_sad_from_mm256_epi32(&sad); } +unsigned int aom_highbd_sad64x16_avx2(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride) { + __m256i sad = _mm256_setzero_si256(); + uint16_t *srcp = CONVERT_TO_SHORTPTR(src); + uint16_t *refp = CONVERT_TO_SHORTPTR(ref); + const int left_shift = 1; + int row_section = 0; + + while (row_section < 8) { + sad64x2(srcp, src_stride, refp, ref_stride, NULL, &sad); + srcp += src_stride << left_shift; + refp += ref_stride << left_shift; + row_section += 1; + } + return get_sad_from_mm256_epi32(&sad); +} + unsigned int aom_highbd_sad64x64_avx2(const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride) { uint32_t sum = aom_highbd_sad64x32_avx2(src, src_stride, ref, ref_stride); @@ -520,6 +554,27 @@ *sad_acc = _mm256_add_epi32(*sad_acc, r0); } +unsigned int aom_highbd_sad16x4_avx2(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride) { + __m256i sad = _mm256_setzero_si256(); + uint16_t *srcp = CONVERT_TO_SHORTPTR(src); + uint16_t *refp = CONVERT_TO_SHORTPTR(ref); + sad16x4(srcp, src_stride, refp, ref_stride, NULL, &sad); + return get_sad_from_mm256_epi32(&sad); +} + +unsigned int aom_highbd_sad16x4_avg_avx2(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, + const uint8_t *second_pred) { + __m256i sad = _mm256_setzero_si256(); + uint16_t *srcp = CONVERT_TO_SHORTPTR(src); + uint16_t *refp = CONVERT_TO_SHORTPTR(ref); + uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred); + sad16x4(srcp, src_stride, refp, ref_stride, secp, &sad); + + return get_sad_from_mm256_epi32(&sad); +} + unsigned int aom_highbd_sad16x8_avg_avx2(const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, const uint8_t *second_pred) { @@ -566,6 +621,50 @@ return sum; } +unsigned int aom_highbd_sad16x64_avx2(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride) { + const int left_shift = 5; + uint32_t sum = aom_highbd_sad16x32_avx2(src, src_stride, ref, ref_stride); + src += src_stride << left_shift; + ref += ref_stride << left_shift; + sum += aom_highbd_sad16x32_avx2(src, src_stride, ref, ref_stride); + return sum; +} + +unsigned int aom_highbd_sad16x64_avg_avx2(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, + const uint8_t *second_pred) { + const int left_shift = 5; + uint32_t sum = aom_highbd_sad16x32_avg_avx2(src, src_stride, ref, ref_stride, + second_pred); + src += src_stride << left_shift; + ref += ref_stride << left_shift; + second_pred += 16 << left_shift; + sum += aom_highbd_sad16x32_avg_avx2(src, src_stride, ref, ref_stride, + second_pred); + return sum; +} + +unsigned int aom_highbd_sad32x8_avg_avx2(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, + const uint8_t *second_pred) { + __m256i sad = _mm256_setzero_si256(); + uint16_t *srcp = CONVERT_TO_SHORTPTR(src); + uint16_t *refp = CONVERT_TO_SHORTPTR(ref); + uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred); + const int left_shift = 2; + int row_section = 0; + + while (row_section < 2) { + sad32x4(srcp, src_stride, refp, ref_stride, secp, &sad); + srcp += src_stride << left_shift; + refp += ref_stride << left_shift; + secp += 32 << left_shift; + row_section += 1; + } + return get_sad_from_mm256_epi32(&sad); +} + unsigned int aom_highbd_sad32x16_avg_avx2(const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, const uint8_t *second_pred) { @@ -614,6 +713,26 @@ return sum; } +unsigned int aom_highbd_sad64x16_avg_avx2(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, + const uint8_t *second_pred) { + __m256i sad = _mm256_setzero_si256(); + uint16_t *srcp = CONVERT_TO_SHORTPTR(src); + uint16_t *refp = CONVERT_TO_SHORTPTR(ref); + uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred); + const int left_shift = 1; + int row_section = 0; + + while (row_section < 8) { + sad64x2(srcp, src_stride, refp, ref_stride, secp, &sad); + srcp += src_stride << left_shift; + refp += ref_stride << left_shift; + secp += 64 << left_shift; + row_section += 1; + } + return get_sad_from_mm256_epi32(&sad); +} + unsigned int aom_highbd_sad64x32_avg_avx2(const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, const uint8_t *second_pred) { @@ -752,6 +871,25 @@ s[3] = _mm256_setzero_si256(); } +void aom_highbd_sad16x4x4d_avx2(const uint8_t *src, int src_stride, + const uint8_t *const ref_array[], + int ref_stride, uint32_t *sad_array) { + __m256i sad_vec[4]; + const uint16_t *refp[4]; + const uint16_t *keep = CONVERT_TO_SHORTPTR(src); + const uint16_t *srcp; + int i; + + init_sad(sad_vec); + convert_pointers(ref_array, refp); + + for (i = 0; i < 4; ++i) { + srcp = keep; + sad16x4(srcp, src_stride, refp[i], ref_stride, 0, &sad_vec[i]); + } + get_4d_sad_from_mm256_epi32(sad_vec, sad_array); +} + void aom_highbd_sad16x8x4d_avx2(const uint8_t *src, int src_stride, const uint8_t *const ref_array[], int ref_stride, uint32_t *sad_array) { @@ -827,6 +965,59 @@ sad_array[3] = first_half[3] + second_half[3]; } +void aom_highbd_sad16x64x4d_avx2(const uint8_t *src, int src_stride, + const uint8_t *const ref_array[], + int ref_stride, uint32_t *sad_array) { + uint32_t first_half[4]; + uint32_t second_half[4]; + const uint8_t *ref[4]; + const int shift_for_rows = 5; + + ref[0] = ref_array[0]; + ref[1] = ref_array[1]; + ref[2] = ref_array[2]; + ref[3] = ref_array[3]; + + aom_highbd_sad16x32x4d_avx2(src, src_stride, ref, ref_stride, first_half); + src += src_stride << shift_for_rows; + ref[0] += ref_stride << shift_for_rows; + ref[1] += ref_stride << shift_for_rows; + ref[2] += ref_stride << shift_for_rows; + ref[3] += ref_stride << shift_for_rows; + aom_highbd_sad16x32x4d_avx2(src, src_stride, ref, ref_stride, second_half); + sad_array[0] = first_half[0] + second_half[0]; + sad_array[1] = first_half[1] + second_half[1]; + sad_array[2] = first_half[2] + second_half[2]; + sad_array[3] = first_half[3] + second_half[3]; +} + +void aom_highbd_sad32x8x4d_avx2(const uint8_t *src, int src_stride, + const uint8_t *const ref_array[], + int ref_stride, uint32_t *sad_array) { + __m256i sad_vec[4]; + const uint16_t *refp[4]; + const uint16_t *keep = CONVERT_TO_SHORTPTR(src); + const uint16_t *srcp; + const int shift_for_4_rows = 2; + int i; + int rows_section; + + init_sad(sad_vec); + convert_pointers(ref_array, refp); + + for (i = 0; i < 4; ++i) { + srcp = keep; + rows_section = 0; + while (rows_section < 2) { + sad32x4(srcp, src_stride, refp[i], ref_stride, 0, &sad_vec[i]); + srcp += src_stride << shift_for_4_rows; + refp[i] += ref_stride << shift_for_4_rows; + rows_section++; + } + } + get_4d_sad_from_mm256_epi32(sad_vec, sad_array); +} + void aom_highbd_sad32x16x4d_avx2(const uint8_t *src, int src_stride, const uint8_t *const ref_array[], int ref_stride, uint32_t *sad_array) { @@ -906,6 +1097,33 @@ sad_array[3] = first_half[3] + second_half[3]; } +void aom_highbd_sad64x16x4d_avx2(const uint8_t *src, int src_stride, + const uint8_t *const ref_array[], + int ref_stride, uint32_t *sad_array) { + __m256i sad_vec[4]; + const uint16_t *refp[4]; + const uint16_t *keep = CONVERT_TO_SHORTPTR(src); + const uint16_t *srcp; + const int shift_for_rows = 1; + int i; + int rows_section; + + init_sad(sad_vec); + convert_pointers(ref_array, refp); + + for (i = 0; i < 4; ++i) { + srcp = keep; + rows_section = 0; + while (rows_section < 8) { + sad64x2(srcp, src_stride, refp[i], ref_stride, NULL, &sad_vec[i]); + srcp += src_stride << shift_for_rows; + refp[i] += ref_stride << shift_for_rows; + rows_section++; + } + } + get_4d_sad_from_mm256_epi32(sad_vec, sad_array); +} + void aom_highbd_sad64x32x4d_avx2(const uint8_t *src, int src_stride, const uint8_t *const ref_array[], int ref_stride, uint32_t *sad_array) {
diff --git a/test/sad_test.cc b/test/sad_test.cc index 3bb6150..a4fd08d 100644 --- a/test/sad_test.cc +++ b/test/sad_test.cc
@@ -1582,6 +1582,19 @@ make_tuple(16, 8, &aom_highbd_sad16x8_avx2, 8), make_tuple(16, 8, &aom_highbd_sad16x8_avx2, 10), make_tuple(16, 8, &aom_highbd_sad16x8_avx2, 12), + + make_tuple(64, 16, &aom_highbd_sad64x16_avx2, 8), + make_tuple(64, 16, &aom_highbd_sad64x16_avx2, 10), + make_tuple(64, 16, &aom_highbd_sad64x16_avx2, 12), + make_tuple(16, 64, &aom_highbd_sad16x64_avx2, 8), + make_tuple(16, 64, &aom_highbd_sad16x64_avx2, 10), + make_tuple(16, 64, &aom_highbd_sad16x64_avx2, 12), + make_tuple(32, 8, &aom_highbd_sad32x8_avx2, 8), + make_tuple(32, 8, &aom_highbd_sad32x8_avx2, 10), + make_tuple(32, 8, &aom_highbd_sad32x8_avx2, 12), + make_tuple(16, 4, &aom_highbd_sad16x4_avx2, 8), + make_tuple(16, 4, &aom_highbd_sad16x4_avx2, 10), + make_tuple(16, 4, &aom_highbd_sad16x4_avx2, 12), }; INSTANTIATE_TEST_CASE_P(AVX2, SADTest, ::testing::ValuesIn(avx2_tests)); @@ -1627,6 +1640,19 @@ make_tuple(16, 8, &aom_highbd_sad16x8_avg_avx2, 8), make_tuple(16, 8, &aom_highbd_sad16x8_avg_avx2, 10), make_tuple(16, 8, &aom_highbd_sad16x8_avg_avx2, 12), + + make_tuple(64, 16, &aom_highbd_sad64x16_avg_avx2, 8), + make_tuple(64, 16, &aom_highbd_sad64x16_avg_avx2, 10), + make_tuple(64, 16, &aom_highbd_sad64x16_avg_avx2, 12), + make_tuple(16, 64, &aom_highbd_sad16x64_avg_avx2, 8), + make_tuple(16, 64, &aom_highbd_sad16x64_avg_avx2, 10), + make_tuple(16, 64, &aom_highbd_sad16x64_avg_avx2, 12), + make_tuple(32, 8, &aom_highbd_sad32x8_avg_avx2, 8), + make_tuple(32, 8, &aom_highbd_sad32x8_avg_avx2, 10), + make_tuple(32, 8, &aom_highbd_sad32x8_avg_avx2, 12), + make_tuple(16, 4, &aom_highbd_sad16x4_avg_avx2, 8), + make_tuple(16, 4, &aom_highbd_sad16x4_avg_avx2, 10), + make_tuple(16, 4, &aom_highbd_sad16x4_avg_avx2, 12), }; INSTANTIATE_TEST_CASE_P(AVX2, SADavgTest, ::testing::ValuesIn(avg_avx2_tests)); @@ -1671,6 +1697,19 @@ make_tuple(16, 8, &aom_highbd_sad16x8x4d_avx2, 8), make_tuple(16, 8, &aom_highbd_sad16x8x4d_avx2, 10), make_tuple(16, 8, &aom_highbd_sad16x8x4d_avx2, 12), + + make_tuple(16, 64, &aom_highbd_sad16x64x4d_avx2, 8), + make_tuple(16, 64, &aom_highbd_sad16x64x4d_avx2, 10), + make_tuple(16, 64, &aom_highbd_sad16x64x4d_avx2, 12), + make_tuple(64, 16, &aom_highbd_sad64x16x4d_avx2, 8), + make_tuple(64, 16, &aom_highbd_sad64x16x4d_avx2, 10), + make_tuple(64, 16, &aom_highbd_sad64x16x4d_avx2, 12), + make_tuple(32, 8, &aom_highbd_sad32x8x4d_avx2, 8), + make_tuple(32, 8, &aom_highbd_sad32x8x4d_avx2, 10), + make_tuple(32, 8, &aom_highbd_sad32x8x4d_avx2, 12), + make_tuple(16, 4, &aom_highbd_sad16x4x4d_avx2, 8), + make_tuple(16, 4, &aom_highbd_sad16x4x4d_avx2, 10), + make_tuple(16, 4, &aom_highbd_sad16x4x4d_avx2, 12), }; INSTANTIATE_TEST_CASE_P(AVX2, SADx4Test, ::testing::ValuesIn(x4d_avx2_tests)); #endif // HAVE_AVX2