Clean up SIMD for IDIF to match C function logic
diff --git a/aom_dsp/x86/intrapred_avx2.c b/aom_dsp/x86/intrapred_avx2.c index 39dac25..049a272 100644 --- a/aom_dsp/x86/intrapred_avx2.c +++ b/aom_dsp/x86/intrapred_avx2.c
@@ -2983,8 +2983,7 @@ static AOM_FORCE_INLINE void highbd_dr_prediction_z1_4xN_internal_idif_avx2( int N, __m128i *dst, const uint16_t *above, int dx, int mrl_index, int bd) { const int frac_bits = 6; - // max base for the 4-tap filter is on the last ref sample (+1 to re-use code) - const int max_base_x = ((N + 4) + (mrl_index << 1)); + const int max_base_x = ((N + 4) - 1 + (mrl_index << 1)); assert(dx > 0); __m256i a0, a1, a2, a3; @@ -2994,8 +2993,8 @@ __m256i rnding = _mm256_set1_epi16(1 << (POWER_DR_INTERP_FILTER - 1)); - a_mbase_x = _mm_set1_epi16(above[max_base_x - 1]); - max_base_x128 = _mm_set1_epi16(max_base_x); + a_mbase_x = _mm_set1_epi16(above[max_base_x]); + max_base_x128 = _mm_set1_epi16(max_base_x + 1); int shift_i; int x = dx * (1 + mrl_index); @@ -3003,7 +3002,7 @@ __m128i res1; int base = x >> frac_bits; - if (base >= max_base_x) { + if (base > max_base_x) { for (int i = r; i < N; ++i) { dst[i] = a_mbase_x; // save 4 values } @@ -3050,7 +3049,7 @@ int dx, int mrl_index, int bd) { const int frac_bits = 6; - const int max_base_x = ((N + 4) + (mrl_index << 1)); + const int max_base_x = ((N + 4) - 1 + (mrl_index << 1)); assert(dx > 0); __m256i a0, a1, a2, a3; @@ -3060,8 +3059,8 @@ __m256i rnding = _mm256_set1_epi32(1 << (POWER_DR_INTERP_FILTER - 1)); - a_mbase_x = _mm_set1_epi16(above[max_base_x - 1]); - max_base_x128 = _mm_set1_epi32(max_base_x); + a_mbase_x = _mm_set1_epi16(above[max_base_x]); + max_base_x128 = _mm_set1_epi32(max_base_x + 1); int x = dx * (1 + mrl_index); int shift_i; @@ -3069,7 +3068,7 @@ __m128i res1; int base = x >> frac_bits; - if (base >= max_base_x) { + if (base > max_base_x) { for (int i = r; i < N; ++i) { dst[i] = a_mbase_x; // save 4 values } @@ -3145,7 +3144,7 @@ static AOM_FORCE_INLINE void highbd_dr_prediction_z1_8xN_internal_idif_avx2( int N, __m128i *dst, const uint16_t *above, int dx, int mrl_index, int bd) { const int frac_bits = 6; - const int max_base_x = ((N + 8) + (mrl_index << 1)); + const int max_base_x = ((N + 
8) - 1 + (mrl_index << 1)); assert(dx > 0); __m256i a0, a1, a2, a3; @@ -3155,8 +3154,8 @@ __m256i rnding = _mm256_set1_epi16(1 << (POWER_DR_INTERP_FILTER - 1)); - a_mbase_x = _mm256_set1_epi16(above[max_base_x - 1]); - max_base_x256 = _mm256_set1_epi16(max_base_x); + a_mbase_x = _mm256_set1_epi16(above[max_base_x]); + max_base_x256 = _mm256_set1_epi16(max_base_x + 1); int shift_i; int x = dx * (1 + mrl_index); @@ -3164,7 +3163,7 @@ __m256i res1; int base = x >> frac_bits; - if (base >= max_base_x) { + if (base > max_base_x) { for (int i = r; i < N; ++i) { dst[i] = _mm256_castsi256_si128(a_mbase_x); // save 8 values } @@ -3212,7 +3211,7 @@ int dx, int mrl_index, int bd) { const int frac_bits = 6; - const int max_base_x = ((N + 8) + (mrl_index << 1)); + const int max_base_x = ((N + 8) - 1 + (mrl_index << 1)); assert(dx > 0); __m256i a0, a1, a2, a3; @@ -3222,8 +3221,8 @@ __m256i rnding = _mm256_set1_epi32(1 << (POWER_DR_INTERP_FILTER - 1)); - a_mbase_x = _mm256_set1_epi16(above[max_base_x - 1]); - max_base_x256 = _mm256_set1_epi32(max_base_x); + a_mbase_x = _mm256_set1_epi16(above[max_base_x]); + max_base_x256 = _mm256_set1_epi32(max_base_x + 1); int shift_i; int x = dx * (1 + mrl_index); @@ -3231,7 +3230,7 @@ __m256i res1; int base = x >> frac_bits; - if (base >= max_base_x) { + if (base > max_base_x) { for (int i = r; i < N; ++i) { dst[i] = _mm256_castsi256_si128(a_mbase_x); // save 8 values } @@ -3314,7 +3313,7 @@ int N, __m256i *dstvec, const uint16_t *above, int dx, int mrl_index, int bd) { const int frac_bits = 6; - const int max_base_x = ((16 + N) + (mrl_index << 1)); + const int max_base_x = ((16 + N) - 1 + (mrl_index << 1)); __m256i a_mbase_x, max_base_x256, base_inc256, mask256; @@ -3324,14 +3323,14 @@ __m256i rnding = _mm256_set1_epi16(1 << (POWER_DR_INTERP_FILTER - 1)); - a_mbase_x = _mm256_set1_epi16(above[max_base_x - 1]); - max_base_x256 = _mm256_set1_epi16(max_base_x); + a_mbase_x = _mm256_set1_epi16(above[max_base_x]); + max_base_x256 = 
_mm256_set1_epi16(max_base_x + 1); int shift_i; int x = dx * (1 + mrl_index); for (int r = 0; r < N; r++) { int base = x >> frac_bits; - if (base >= max_base_x) { + if (base > max_base_x) { for (int i = r; i < N; ++i) { dstvec[i] = a_mbase_x; // save 16 values } @@ -3379,7 +3378,7 @@ int dx, int mrl_index, int bd) { const int frac_bits = 6; - const int max_base_x = ((16 + N) + (mrl_index << 1)); + const int max_base_x = ((16 + N) - 1 + (mrl_index << 1)); __m256i a0, a1, a2, a3; __m256i val0, val1; __m256i f0, f1, f2, f3; @@ -3387,8 +3386,8 @@ __m256i rnding = _mm256_set1_epi32(1 << (POWER_DR_INTERP_FILTER - 1)); - a_mbase_x = _mm256_set1_epi16(above[max_base_x - 1]); - max_base_x256 = _mm256_set1_epi16(max_base_x); + a_mbase_x = _mm256_set1_epi16(above[max_base_x]); + max_base_x256 = _mm256_set1_epi16(max_base_x + 1); int shift_i; int x = dx * (1 + mrl_index); @@ -3396,7 +3395,7 @@ __m256i res[2], res1; int base = x >> frac_bits; - if (base >= max_base_x) { + if (base > max_base_x) { for (int i = r; i < N; ++i) { dstvec[i] = a_mbase_x; // save 16 values } @@ -3430,7 +3429,7 @@ res[0] = _mm256_packus_epi32( val0, _mm256_castsi128_si256(_mm256_extracti128_si256(val0, 1))); - int mdif = max_base_x - base; + int mdif = max_base_x + 1 - base; if (mdif > 8) { a0 = _mm256_cvtepu16_epi32(_mm_loadu_si128((__m128i *)(above + base + 7))); @@ -3496,7 +3495,7 @@ int N, __m256i *dstvec, const uint16_t *above, int dx, int mrl_index, int bd) { const int frac_bits = 6; - const int max_base_x = ((32 + N) + (mrl_index << 1)); + const int max_base_x = ((32 + N) - 1 + (mrl_index << 1)); __m256i a_mbase_x, max_base_x256, base_inc256, mask256; @@ -3506,8 +3505,8 @@ __m256i rnding = _mm256_set1_epi16(1 << (POWER_DR_INTERP_FILTER - 1)); - a_mbase_x = _mm256_set1_epi16(above[max_base_x - 1]); - max_base_x256 = _mm256_set1_epi16(max_base_x); + a_mbase_x = _mm256_set1_epi16(above[max_base_x]); + max_base_x256 = _mm256_set1_epi16(max_base_x + 1); int shift_i; int x = dx * (1 + mrl_index); @@ 
-3515,7 +3514,7 @@ __m256i res; int base = x >> frac_bits; - if (base >= max_base_x) { + if (base > max_base_x) { for (int i = r; i < N; ++i) { dstvec[i] = a_mbase_x; // save 32 values dstvec[i + N] = a_mbase_x; @@ -3531,7 +3530,7 @@ f3 = _mm256_set1_epi16(av1_dr_interp_filter[shift_i][3]); for (int j = 0; j < 32; j += 16) { - int mdif = max_base_x - (base + j); + int mdif = max_base_x + 1 - (base + j); if (mdif <= 0) { res = a_mbase_x; } else { @@ -3577,7 +3576,7 @@ int dx, int mrl_index, int bd) { const int frac_bits = 6; - const int max_base_x = ((32 + N) + (mrl_index << 1)); + const int max_base_x = ((32 + N) - 1 + (mrl_index << 1)); __m256i a_mbase_x, max_base_x256, base_inc256, mask256; @@ -3587,8 +3586,8 @@ __m256i rnding = _mm256_set1_epi32(1 << (POWER_DR_INTERP_FILTER - 1)); - a_mbase_x = _mm256_set1_epi16(above[max_base_x - 1]); - max_base_x256 = _mm256_set1_epi16(max_base_x); + a_mbase_x = _mm256_set1_epi16(above[max_base_x]); + max_base_x256 = _mm256_set1_epi16(max_base_x + 1); int shift_i; int x = dx * (1 + mrl_index); @@ -3596,7 +3595,7 @@ __m256i res[2], res1; int base = x >> frac_bits; - if (base >= max_base_x) { + if (base > max_base_x) { for (int i = r; i < N; ++i) { dstvec[i] = a_mbase_x; // save 32 values dstvec[i + N] = a_mbase_x; @@ -3612,7 +3611,7 @@ f3 = _mm256_set1_epi32(av1_dr_interp_filter[shift_i][3]); for (int j = 0; j < 32; j += 16) { - int mdif = max_base_x - (base + j); + int mdif = max_base_x + 1 - (base + j); if (mdif <= 0) { res1 = a_mbase_x; } else { @@ -3711,7 +3710,7 @@ int N, uint16_t *dst, ptrdiff_t stride, const uint16_t *above, int dx, int mrl_index, int bd) { const int frac_bits = 6; - const int max_base_x = ((64 + N) + (mrl_index << 1)); + const int max_base_x = ((64 + N) - 1 + (mrl_index << 1)); __m256i a_mbase_x, max_base_x256, base_inc256, mask256; @@ -3721,8 +3720,8 @@ __m256i rnding = _mm256_set1_epi16(1 << (POWER_DR_INTERP_FILTER - 1)); - a_mbase_x = _mm256_set1_epi16(above[max_base_x - 1]); - max_base_x256 = 
_mm256_set1_epi16(max_base_x); + a_mbase_x = _mm256_set1_epi16(above[max_base_x]); + max_base_x256 = _mm256_set1_epi16(max_base_x + 1); int shift_i; int x = dx * (1 + mrl_index); @@ -3730,7 +3729,7 @@ __m256i res; int base = x >> frac_bits; - if (base >= max_base_x) { + if (base > max_base_x) { for (int i = r; i < N; ++i) { _mm256_storeu_si256((__m256i *)dst, a_mbase_x); // save 32 values _mm256_storeu_si256((__m256i *)(dst + 16), a_mbase_x); @@ -3749,7 +3748,7 @@ f3 = _mm256_set1_epi16(av1_dr_interp_filter[shift_i][3]); for (int j = 0; j < 64; j += 16) { - int mdif = max_base_x - (base + j); + int mdif = max_base_x + 1 - (base + j); if (mdif <= 0) { _mm256_storeu_si256((__m256i *)(dst + j), a_mbase_x); } else { @@ -3789,7 +3788,7 @@ int N, uint16_t *dst, ptrdiff_t stride, const uint16_t *above, int dx, int mrl_index, int bd) { const int frac_bits = 6; - const int max_base_x = ((64 + N) + (mrl_index << 1)); + const int max_base_x = ((64 + N) - 1 + (mrl_index << 1)); __m256i a0, a1, a2, a3; @@ -3800,8 +3799,8 @@ __m256i rnding = _mm256_set1_epi32(1 << (POWER_DR_INTERP_FILTER - 1)); - a_mbase_x = _mm256_set1_epi16(above[max_base_x - 1]); - max_base_x256 = _mm256_set1_epi16(max_base_x); + a_mbase_x = _mm256_set1_epi16(above[max_base_x]); + max_base_x256 = _mm256_set1_epi16(max_base_x + 1); int shift_i; int x = dx * (1 + mrl_index); @@ -3809,7 +3808,7 @@ __m256i res[2], res1; int base = x >> frac_bits; - if (base >= max_base_x) { + if (base > max_base_x) { for (int i = r; i < N; ++i) { _mm256_storeu_si256((__m256i *)dst, a_mbase_x); // save 32 values _mm256_storeu_si256((__m256i *)(dst + 16), a_mbase_x); @@ -3828,7 +3827,7 @@ f3 = _mm256_set1_epi32(av1_dr_interp_filter[shift_i][3]); for (int j = 0; j < 64; j += 16) { - int mdif = max_base_x - (base + j); + int mdif = max_base_x + 1 - (base + j); if (mdif <= 0) { _mm256_storeu_si256((__m256i *)(dst + j), a_mbase_x); } else {