AVX2 ver of highb dr prediction - Z1 bug fixed
Extracted from https://aomedia-review.googlesource.com/c/aom/+/75642.
BUG=aomedia:2259,aomedia:2260,oss-fuzz:11517
Change-Id: Ifc4d8e4703080a70aac1f97e5dbefb2ab2b8e33b
diff --git a/aom_dsp/x86/intrapred_avx2.c b/aom_dsp/x86/intrapred_avx2.c
index 4525f95..5f3e7bb 100644
--- a/aom_dsp/x86/intrapred_avx2.c
+++ b/aom_dsp/x86/intrapred_avx2.c
@@ -1260,7 +1260,6 @@
const uint16_t *above,
int upsample_above, int dx) {
__m256i dstvec[64];
-
highbd_dr_prediction_z1_16xN_internal_avx2(N, dstvec, above, upsample_above,
dx);
for (int i = 0; i < N; i++) {
@@ -1307,7 +1306,7 @@
for (int j = 0; j < 32; j += 16) {
int mdif = max_base_x - (base + j);
- if (mdif == 0) {
+ if (mdif <= 0) {
res1 = a_mbase_x;
} else {
a0 = _mm256_cvtepu16_epi32(
@@ -1325,17 +1324,17 @@
res[0] = _mm256_packus_epi32(
res[0],
_mm256_castsi128_si256(_mm256_extracti128_si256(res[0], 1)));
-
- a0_1 = _mm256_cvtepu16_epi32(
- _mm_loadu_si128((__m128i *)(above + base + 8 + j)));
- a1_1 = _mm256_cvtepu16_epi32(
- _mm_loadu_si128((__m128i *)(above + base + 9 + j)));
-
- diff = _mm256_sub_epi32(a1_1, a0_1); // a[x+1] - a[x]
- a32 = _mm256_slli_epi32(a0_1, 5); // a[x] * 32
- a32 = _mm256_add_epi32(a32, a16); // a[x] * 32 + 16
- b = _mm256_mullo_epi32(diff, shift);
if (mdif > 8) {
+ a0_1 = _mm256_cvtepu16_epi32(
+ _mm_loadu_si128((__m128i *)(above + base + 8 + j)));
+ a1_1 = _mm256_cvtepu16_epi32(
+ _mm_loadu_si128((__m128i *)(above + base + 9 + j)));
+
+ diff = _mm256_sub_epi32(a1_1, a0_1); // a[x+1] - a[x]
+ a32 = _mm256_slli_epi32(a0_1, 5); // a[x] * 32
+ a32 = _mm256_add_epi32(a32, a16); // a[x] * 32 + 16
+ b = _mm256_mullo_epi32(diff, shift);
+
res[1] = _mm256_add_epi32(a32, b);
res[1] = _mm256_srli_epi32(res[1], 5);
res[1] = _mm256_packus_epi32(
@@ -1424,7 +1423,7 @@
__m128i a0_128, a0_1_128, a1_128, a1_1_128;
for (int j = 0; j < 64; j += 16) {
int mdif = max_base_x - (base + j);
- if (mdif == 0) {
+ if (mdif <= 0) {
_mm256_storeu_si256((__m256i *)(dst + j), a_mbase_x);
} else {
a0_128 = _mm_loadu_si128((__m128i *)(above + base + j));