Add a few AVX2 functions for SAD
Change-Id: I2b656dbc5467434180abb0e4a7768d530f304b2b
diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index 8c51e7f..ad16f01 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -683,8 +683,8 @@
specialize qw/aom_dist_wtd_sad16x4_avg ssse3/;
specialize qw/aom_dist_wtd_sad8x32_avg ssse3/;
specialize qw/aom_dist_wtd_sad32x8_avg ssse3/;
- specialize qw/aom_dist_wtd_sad16x64_avg ssse3/;
- specialize qw/aom_dist_wtd_sad64x16_avg ssse3/;
+ specialize qw/aom_dist_wtd_sad16x64_avg ssse3/;
+ specialize qw/aom_dist_wtd_sad64x16_avg ssse3/;
add_proto qw/unsigned int/, "aom_sad4xh", "const uint8_t *a, int a_stride, const uint8_t *b, int b_stride, int width, int height";
add_proto qw/unsigned int/, "aom_sad8xh", "const uint8_t *a, int a_stride, const uint8_t *b, int b_stride, int width, int height";
@@ -737,17 +737,17 @@
specialize qw/aom_highbd_sad16x8_avg avx2 sse2/;
specialize qw/aom_highbd_sad8x4_avg sse2/;
- specialize qw/aom_highbd_sad16x4 sse2/;
- specialize qw/aom_highbd_sad8x32 sse2/;
- specialize qw/aom_highbd_sad32x8 sse2/;
- specialize qw/aom_highbd_sad16x64 sse2/;
- specialize qw/aom_highbd_sad64x16 sse2/;
+ specialize qw/aom_highbd_sad16x4 avx2 sse2/;
+ specialize qw/aom_highbd_sad8x32 sse2/;
+ specialize qw/aom_highbd_sad32x8 avx2 sse2/;
+ specialize qw/aom_highbd_sad16x64 avx2 sse2/;
+ specialize qw/aom_highbd_sad64x16 avx2 sse2/;
- specialize qw/aom_highbd_sad16x4_avg sse2/;
- specialize qw/aom_highbd_sad8x32_avg sse2/;
- specialize qw/aom_highbd_sad32x8_avg sse2/;
- specialize qw/aom_highbd_sad16x64_avg sse2/;
- specialize qw/aom_highbd_sad64x16_avg sse2/;
+ specialize qw/aom_highbd_sad16x4_avg avx2 sse2/;
+ specialize qw/aom_highbd_sad8x32_avg sse2/;
+ specialize qw/aom_highbd_sad32x8_avg avx2 sse2/;
+ specialize qw/aom_highbd_sad16x64_avg avx2 sse2/;
+ specialize qw/aom_highbd_sad64x16_avg avx2 sse2/;
#
# Masked SAD
@@ -846,12 +846,12 @@
specialize qw/aom_highbd_sad4x8x4d sse2/;
specialize qw/aom_highbd_sad4x4x4d sse2/;
- specialize qw/aom_highbd_sad4x16x4d sse2/;
- specialize qw/aom_highbd_sad16x4x4d sse2/;
- specialize qw/aom_highbd_sad8x32x4d sse2/;
- specialize qw/aom_highbd_sad32x8x4d sse2/;
- specialize qw/aom_highbd_sad16x64x4d sse2/;
- specialize qw/aom_highbd_sad64x16x4d sse2/;
+ specialize qw/aom_highbd_sad4x16x4d sse2/;
+ specialize qw/aom_highbd_sad16x4x4d avx2 sse2/;
+ specialize qw/aom_highbd_sad8x32x4d sse2/;
+ specialize qw/aom_highbd_sad32x8x4d avx2 sse2/;
+ specialize qw/aom_highbd_sad16x64x4d avx2 sse2/;
+ specialize qw/aom_highbd_sad64x16x4d avx2 sse2/;
#
# Avg
diff --git a/aom_dsp/x86/sad_highbd_avx2.c b/aom_dsp/x86/sad_highbd_avx2.c
index b506d46..eba442c 100644
--- a/aom_dsp/x86/sad_highbd_avx2.c
+++ b/aom_dsp/x86/sad_highbd_avx2.c
@@ -229,6 +229,23 @@
return get_sad_from_mm256_epi32(&sad);
}
+unsigned int aom_highbd_sad32x8_avx2(const uint8_t *src, int src_stride,
+ const uint8_t *ref, int ref_stride) {
+ __m256i sad = _mm256_setzero_si256();
+ uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
+ uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
+ const int left_shift = 2;
+ int row_section = 0;
+
+ while (row_section < 2) {
+ sad32x4(srcp, src_stride, refp, ref_stride, NULL, &sad);
+ srcp += src_stride << left_shift;
+ refp += ref_stride << left_shift;
+ row_section += 1;
+ }
+ return get_sad_from_mm256_epi32(&sad);
+}
+
unsigned int aom_highbd_sad16x32_avx2(const uint8_t *src, int src_stride,
const uint8_t *ref, int ref_stride) {
uint32_t sum = aom_highbd_sad16x16_avx2(src, src_stride, ref, ref_stride);
@@ -352,6 +369,23 @@
return get_sad_from_mm256_epi32(&sad);
}
+unsigned int aom_highbd_sad64x16_avx2(const uint8_t *src, int src_stride,
+ const uint8_t *ref, int ref_stride) {
+ __m256i sad = _mm256_setzero_si256();
+ uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
+ uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
+ const int left_shift = 1;
+ int row_section = 0;
+
+ while (row_section < 8) {
+ sad64x2(srcp, src_stride, refp, ref_stride, NULL, &sad);
+ srcp += src_stride << left_shift;
+ refp += ref_stride << left_shift;
+ row_section += 1;
+ }
+ return get_sad_from_mm256_epi32(&sad);
+}
+
unsigned int aom_highbd_sad64x64_avx2(const uint8_t *src, int src_stride,
const uint8_t *ref, int ref_stride) {
uint32_t sum = aom_highbd_sad64x32_avx2(src, src_stride, ref, ref_stride);
@@ -520,6 +554,27 @@
*sad_acc = _mm256_add_epi32(*sad_acc, r0);
}
+unsigned int aom_highbd_sad16x4_avx2(const uint8_t *src, int src_stride,
+ const uint8_t *ref, int ref_stride) {
+ __m256i sad = _mm256_setzero_si256();
+ uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
+ uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
+ sad16x4(srcp, src_stride, refp, ref_stride, NULL, &sad);
+ return get_sad_from_mm256_epi32(&sad);
+}
+
+unsigned int aom_highbd_sad16x4_avg_avx2(const uint8_t *src, int src_stride,
+ const uint8_t *ref, int ref_stride,
+ const uint8_t *second_pred) {
+ __m256i sad = _mm256_setzero_si256();
+ uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
+ uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
+ uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
+ sad16x4(srcp, src_stride, refp, ref_stride, secp, &sad);
+
+ return get_sad_from_mm256_epi32(&sad);
+}
+
unsigned int aom_highbd_sad16x8_avg_avx2(const uint8_t *src, int src_stride,
const uint8_t *ref, int ref_stride,
const uint8_t *second_pred) {
@@ -566,6 +621,50 @@
return sum;
}
+unsigned int aom_highbd_sad16x64_avx2(const uint8_t *src, int src_stride,
+ const uint8_t *ref, int ref_stride) {
+ const int left_shift = 5;
+ uint32_t sum = aom_highbd_sad16x32_avx2(src, src_stride, ref, ref_stride);
+ src += src_stride << left_shift;
+ ref += ref_stride << left_shift;
+ sum += aom_highbd_sad16x32_avx2(src, src_stride, ref, ref_stride);
+ return sum;
+}
+
+unsigned int aom_highbd_sad16x64_avg_avx2(const uint8_t *src, int src_stride,
+ const uint8_t *ref, int ref_stride,
+ const uint8_t *second_pred) {
+ const int left_shift = 5;
+ uint32_t sum = aom_highbd_sad16x32_avg_avx2(src, src_stride, ref, ref_stride,
+ second_pred);
+ src += src_stride << left_shift;
+ ref += ref_stride << left_shift;
+ second_pred += 16 << left_shift;
+ sum += aom_highbd_sad16x32_avg_avx2(src, src_stride, ref, ref_stride,
+ second_pred);
+ return sum;
+}
+
+unsigned int aom_highbd_sad32x8_avg_avx2(const uint8_t *src, int src_stride,
+ const uint8_t *ref, int ref_stride,
+ const uint8_t *second_pred) {
+ __m256i sad = _mm256_setzero_si256();
+ uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
+ uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
+ uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
+ const int left_shift = 2;
+ int row_section = 0;
+
+ while (row_section < 2) {
+ sad32x4(srcp, src_stride, refp, ref_stride, secp, &sad);
+ srcp += src_stride << left_shift;
+ refp += ref_stride << left_shift;
+ secp += 32 << left_shift;
+ row_section += 1;
+ }
+ return get_sad_from_mm256_epi32(&sad);
+}
+
unsigned int aom_highbd_sad32x16_avg_avx2(const uint8_t *src, int src_stride,
const uint8_t *ref, int ref_stride,
const uint8_t *second_pred) {
@@ -614,6 +713,26 @@
return sum;
}
+unsigned int aom_highbd_sad64x16_avg_avx2(const uint8_t *src, int src_stride,
+ const uint8_t *ref, int ref_stride,
+ const uint8_t *second_pred) {
+ __m256i sad = _mm256_setzero_si256();
+ uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
+ uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
+ uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
+ const int left_shift = 1;
+ int row_section = 0;
+
+ while (row_section < 8) {
+ sad64x2(srcp, src_stride, refp, ref_stride, secp, &sad);
+ srcp += src_stride << left_shift;
+ refp += ref_stride << left_shift;
+ secp += 64 << left_shift;
+ row_section += 1;
+ }
+ return get_sad_from_mm256_epi32(&sad);
+}
+
unsigned int aom_highbd_sad64x32_avg_avx2(const uint8_t *src, int src_stride,
const uint8_t *ref, int ref_stride,
const uint8_t *second_pred) {
@@ -752,6 +871,25 @@
s[3] = _mm256_setzero_si256();
}
+void aom_highbd_sad16x4x4d_avx2(const uint8_t *src, int src_stride,
+ const uint8_t *const ref_array[],
+ int ref_stride, uint32_t *sad_array) {
+ __m256i sad_vec[4];
+ const uint16_t *refp[4];
+ const uint16_t *keep = CONVERT_TO_SHORTPTR(src);
+ const uint16_t *srcp;
+ int i;
+
+ init_sad(sad_vec);
+ convert_pointers(ref_array, refp);
+
+ for (i = 0; i < 4; ++i) {
+ srcp = keep;
+ sad16x4(srcp, src_stride, refp[i], ref_stride, NULL, &sad_vec[i]);
+ }
+ get_4d_sad_from_mm256_epi32(sad_vec, sad_array);
+}
+
void aom_highbd_sad16x8x4d_avx2(const uint8_t *src, int src_stride,
const uint8_t *const ref_array[],
int ref_stride, uint32_t *sad_array) {
@@ -827,6 +965,59 @@
sad_array[3] = first_half[3] + second_half[3];
}
+void aom_highbd_sad16x64x4d_avx2(const uint8_t *src, int src_stride,
+ const uint8_t *const ref_array[],
+ int ref_stride, uint32_t *sad_array) {
+ uint32_t first_half[4];
+ uint32_t second_half[4];
+ const uint8_t *ref[4];
+ const int shift_for_rows = 5;
+
+ ref[0] = ref_array[0];
+ ref[1] = ref_array[1];
+ ref[2] = ref_array[2];
+ ref[3] = ref_array[3];
+
+ aom_highbd_sad16x32x4d_avx2(src, src_stride, ref, ref_stride, first_half);
+ src += src_stride << shift_for_rows;
+ ref[0] += ref_stride << shift_for_rows;
+ ref[1] += ref_stride << shift_for_rows;
+ ref[2] += ref_stride << shift_for_rows;
+ ref[3] += ref_stride << shift_for_rows;
+ aom_highbd_sad16x32x4d_avx2(src, src_stride, ref, ref_stride, second_half);
+ sad_array[0] = first_half[0] + second_half[0];
+ sad_array[1] = first_half[1] + second_half[1];
+ sad_array[2] = first_half[2] + second_half[2];
+ sad_array[3] = first_half[3] + second_half[3];
+}
+
+void aom_highbd_sad32x8x4d_avx2(const uint8_t *src, int src_stride,
+ const uint8_t *const ref_array[],
+ int ref_stride, uint32_t *sad_array) {
+ __m256i sad_vec[4];
+ const uint16_t *refp[4];
+ const uint16_t *keep = CONVERT_TO_SHORTPTR(src);
+ const uint16_t *srcp;
+ const int shift_for_4_rows = 2;
+ int i;
+ int rows_section;
+
+ init_sad(sad_vec);
+ convert_pointers(ref_array, refp);
+
+ for (i = 0; i < 4; ++i) {
+ srcp = keep;
+ rows_section = 0;
+ while (rows_section < 2) {
+ sad32x4(srcp, src_stride, refp[i], ref_stride, NULL, &sad_vec[i]);
+ srcp += src_stride << shift_for_4_rows;
+ refp[i] += ref_stride << shift_for_4_rows;
+ rows_section++;
+ }
+ }
+ get_4d_sad_from_mm256_epi32(sad_vec, sad_array);
+}
+
void aom_highbd_sad32x16x4d_avx2(const uint8_t *src, int src_stride,
const uint8_t *const ref_array[],
int ref_stride, uint32_t *sad_array) {
@@ -906,6 +1097,33 @@
sad_array[3] = first_half[3] + second_half[3];
}
+void aom_highbd_sad64x16x4d_avx2(const uint8_t *src, int src_stride,
+ const uint8_t *const ref_array[],
+ int ref_stride, uint32_t *sad_array) {
+ __m256i sad_vec[4];
+ const uint16_t *refp[4];
+ const uint16_t *keep = CONVERT_TO_SHORTPTR(src);
+ const uint16_t *srcp;
+ const int shift_for_rows = 1;
+ int i;
+ int rows_section;
+
+ init_sad(sad_vec);
+ convert_pointers(ref_array, refp);
+
+ for (i = 0; i < 4; ++i) {
+ srcp = keep;
+ rows_section = 0;
+ while (rows_section < 8) {
+ sad64x2(srcp, src_stride, refp[i], ref_stride, NULL, &sad_vec[i]);
+ srcp += src_stride << shift_for_rows;
+ refp[i] += ref_stride << shift_for_rows;
+ rows_section++;
+ }
+ }
+ get_4d_sad_from_mm256_epi32(sad_vec, sad_array);
+}
+
void aom_highbd_sad64x32x4d_avx2(const uint8_t *src, int src_stride,
const uint8_t *const ref_array[],
int ref_stride, uint32_t *sad_array) {
diff --git a/test/sad_test.cc b/test/sad_test.cc
index 3bb6150..a4fd08d 100644
--- a/test/sad_test.cc
+++ b/test/sad_test.cc
@@ -1582,6 +1582,19 @@
make_tuple(16, 8, &aom_highbd_sad16x8_avx2, 8),
make_tuple(16, 8, &aom_highbd_sad16x8_avx2, 10),
make_tuple(16, 8, &aom_highbd_sad16x8_avx2, 12),
+
+ make_tuple(64, 16, &aom_highbd_sad64x16_avx2, 8),
+ make_tuple(64, 16, &aom_highbd_sad64x16_avx2, 10),
+ make_tuple(64, 16, &aom_highbd_sad64x16_avx2, 12),
+ make_tuple(16, 64, &aom_highbd_sad16x64_avx2, 8),
+ make_tuple(16, 64, &aom_highbd_sad16x64_avx2, 10),
+ make_tuple(16, 64, &aom_highbd_sad16x64_avx2, 12),
+ make_tuple(32, 8, &aom_highbd_sad32x8_avx2, 8),
+ make_tuple(32, 8, &aom_highbd_sad32x8_avx2, 10),
+ make_tuple(32, 8, &aom_highbd_sad32x8_avx2, 12),
+ make_tuple(16, 4, &aom_highbd_sad16x4_avx2, 8),
+ make_tuple(16, 4, &aom_highbd_sad16x4_avx2, 10),
+ make_tuple(16, 4, &aom_highbd_sad16x4_avx2, 12),
};
INSTANTIATE_TEST_CASE_P(AVX2, SADTest, ::testing::ValuesIn(avx2_tests));
@@ -1627,6 +1640,19 @@
make_tuple(16, 8, &aom_highbd_sad16x8_avg_avx2, 8),
make_tuple(16, 8, &aom_highbd_sad16x8_avg_avx2, 10),
make_tuple(16, 8, &aom_highbd_sad16x8_avg_avx2, 12),
+
+ make_tuple(64, 16, &aom_highbd_sad64x16_avg_avx2, 8),
+ make_tuple(64, 16, &aom_highbd_sad64x16_avg_avx2, 10),
+ make_tuple(64, 16, &aom_highbd_sad64x16_avg_avx2, 12),
+ make_tuple(16, 64, &aom_highbd_sad16x64_avg_avx2, 8),
+ make_tuple(16, 64, &aom_highbd_sad16x64_avg_avx2, 10),
+ make_tuple(16, 64, &aom_highbd_sad16x64_avg_avx2, 12),
+ make_tuple(32, 8, &aom_highbd_sad32x8_avg_avx2, 8),
+ make_tuple(32, 8, &aom_highbd_sad32x8_avg_avx2, 10),
+ make_tuple(32, 8, &aom_highbd_sad32x8_avg_avx2, 12),
+ make_tuple(16, 4, &aom_highbd_sad16x4_avg_avx2, 8),
+ make_tuple(16, 4, &aom_highbd_sad16x4_avg_avx2, 10),
+ make_tuple(16, 4, &aom_highbd_sad16x4_avg_avx2, 12),
};
INSTANTIATE_TEST_CASE_P(AVX2, SADavgTest, ::testing::ValuesIn(avg_avx2_tests));
@@ -1671,6 +1697,19 @@
make_tuple(16, 8, &aom_highbd_sad16x8x4d_avx2, 8),
make_tuple(16, 8, &aom_highbd_sad16x8x4d_avx2, 10),
make_tuple(16, 8, &aom_highbd_sad16x8x4d_avx2, 12),
+
+ make_tuple(16, 64, &aom_highbd_sad16x64x4d_avx2, 8),
+ make_tuple(16, 64, &aom_highbd_sad16x64x4d_avx2, 10),
+ make_tuple(16, 64, &aom_highbd_sad16x64x4d_avx2, 12),
+ make_tuple(64, 16, &aom_highbd_sad64x16x4d_avx2, 8),
+ make_tuple(64, 16, &aom_highbd_sad64x16x4d_avx2, 10),
+ make_tuple(64, 16, &aom_highbd_sad64x16x4d_avx2, 12),
+ make_tuple(32, 8, &aom_highbd_sad32x8x4d_avx2, 8),
+ make_tuple(32, 8, &aom_highbd_sad32x8x4d_avx2, 10),
+ make_tuple(32, 8, &aom_highbd_sad32x8x4d_avx2, 12),
+ make_tuple(16, 4, &aom_highbd_sad16x4x4d_avx2, 8),
+ make_tuple(16, 4, &aom_highbd_sad16x4x4d_avx2, 10),
+ make_tuple(16, 4, &aom_highbd_sad16x4x4d_avx2, 12),
};
INSTANTIATE_TEST_CASE_P(AVX2, SADx4Test, ::testing::ValuesIn(x4d_avx2_tests));
#endif // HAVE_AVX2