Clean up AVX2 IDIF z1 prediction to match the C function's max_base_x logic
diff --git a/aom_dsp/x86/intrapred_avx2.c b/aom_dsp/x86/intrapred_avx2.c
index 39dac25..049a272 100644
--- a/aom_dsp/x86/intrapred_avx2.c
+++ b/aom_dsp/x86/intrapred_avx2.c
@@ -2983,8 +2983,7 @@
static AOM_FORCE_INLINE void highbd_dr_prediction_z1_4xN_internal_idif_avx2(
int N, __m128i *dst, const uint16_t *above, int dx, int mrl_index, int bd) {
const int frac_bits = 6;
- // max base for the 4-tap filter is on the last ref sample (+1 to re-use code)
- const int max_base_x = ((N + 4) + (mrl_index << 1));
+ const int max_base_x = ((N + 4) - 1 + (mrl_index << 1));
assert(dx > 0);
__m256i a0, a1, a2, a3;
@@ -2994,8 +2993,8 @@
__m256i rnding = _mm256_set1_epi16(1 << (POWER_DR_INTERP_FILTER - 1));
- a_mbase_x = _mm_set1_epi16(above[max_base_x - 1]);
- max_base_x128 = _mm_set1_epi16(max_base_x);
+ a_mbase_x = _mm_set1_epi16(above[max_base_x]);
+ max_base_x128 = _mm_set1_epi16(max_base_x + 1);
int shift_i;
int x = dx * (1 + mrl_index);
@@ -3003,7 +3002,7 @@
__m128i res1;
int base = x >> frac_bits;
- if (base >= max_base_x) {
+ if (base > max_base_x) {
for (int i = r; i < N; ++i) {
dst[i] = a_mbase_x; // save 4 values
}
@@ -3050,7 +3049,7 @@
int dx, int mrl_index,
int bd) {
const int frac_bits = 6;
- const int max_base_x = ((N + 4) + (mrl_index << 1));
+ const int max_base_x = ((N + 4) - 1 + (mrl_index << 1));
assert(dx > 0);
__m256i a0, a1, a2, a3;
@@ -3060,8 +3059,8 @@
__m256i rnding = _mm256_set1_epi32(1 << (POWER_DR_INTERP_FILTER - 1));
- a_mbase_x = _mm_set1_epi16(above[max_base_x - 1]);
- max_base_x128 = _mm_set1_epi32(max_base_x);
+ a_mbase_x = _mm_set1_epi16(above[max_base_x]);
+ max_base_x128 = _mm_set1_epi32(max_base_x + 1);
int x = dx * (1 + mrl_index);
int shift_i;
@@ -3069,7 +3068,7 @@
__m128i res1;
int base = x >> frac_bits;
- if (base >= max_base_x) {
+ if (base > max_base_x) {
for (int i = r; i < N; ++i) {
dst[i] = a_mbase_x; // save 4 values
}
@@ -3145,7 +3144,7 @@
static AOM_FORCE_INLINE void highbd_dr_prediction_z1_8xN_internal_idif_avx2(
int N, __m128i *dst, const uint16_t *above, int dx, int mrl_index, int bd) {
const int frac_bits = 6;
- const int max_base_x = ((N + 8) + (mrl_index << 1));
+ const int max_base_x = ((N + 8) - 1 + (mrl_index << 1));
assert(dx > 0);
__m256i a0, a1, a2, a3;
@@ -3155,8 +3154,8 @@
__m256i rnding = _mm256_set1_epi16(1 << (POWER_DR_INTERP_FILTER - 1));
- a_mbase_x = _mm256_set1_epi16(above[max_base_x - 1]);
- max_base_x256 = _mm256_set1_epi16(max_base_x);
+ a_mbase_x = _mm256_set1_epi16(above[max_base_x]);
+ max_base_x256 = _mm256_set1_epi16(max_base_x + 1);
int shift_i;
int x = dx * (1 + mrl_index);
@@ -3164,7 +3163,7 @@
__m256i res1;
int base = x >> frac_bits;
- if (base >= max_base_x) {
+ if (base > max_base_x) {
for (int i = r; i < N; ++i) {
dst[i] = _mm256_castsi256_si128(a_mbase_x); // save 8 values
}
@@ -3212,7 +3211,7 @@
int dx, int mrl_index,
int bd) {
const int frac_bits = 6;
- const int max_base_x = ((N + 8) + (mrl_index << 1));
+ const int max_base_x = ((N + 8) - 1 + (mrl_index << 1));
assert(dx > 0);
__m256i a0, a1, a2, a3;
@@ -3222,8 +3221,8 @@
__m256i rnding = _mm256_set1_epi32(1 << (POWER_DR_INTERP_FILTER - 1));
- a_mbase_x = _mm256_set1_epi16(above[max_base_x - 1]);
- max_base_x256 = _mm256_set1_epi32(max_base_x);
+ a_mbase_x = _mm256_set1_epi16(above[max_base_x]);
+ max_base_x256 = _mm256_set1_epi32(max_base_x + 1);
int shift_i;
int x = dx * (1 + mrl_index);
@@ -3231,7 +3230,7 @@
__m256i res1;
int base = x >> frac_bits;
- if (base >= max_base_x) {
+ if (base > max_base_x) {
for (int i = r; i < N; ++i) {
dst[i] = _mm256_castsi256_si128(a_mbase_x); // save 8 values
}
@@ -3314,7 +3313,7 @@
int N, __m256i *dstvec, const uint16_t *above, int dx, int mrl_index,
int bd) {
const int frac_bits = 6;
- const int max_base_x = ((16 + N) + (mrl_index << 1));
+ const int max_base_x = ((16 + N) - 1 + (mrl_index << 1));
__m256i a_mbase_x, max_base_x256, base_inc256, mask256;
@@ -3324,14 +3323,14 @@
__m256i rnding = _mm256_set1_epi16(1 << (POWER_DR_INTERP_FILTER - 1));
- a_mbase_x = _mm256_set1_epi16(above[max_base_x - 1]);
- max_base_x256 = _mm256_set1_epi16(max_base_x);
+ a_mbase_x = _mm256_set1_epi16(above[max_base_x]);
+ max_base_x256 = _mm256_set1_epi16(max_base_x + 1);
int shift_i;
int x = dx * (1 + mrl_index);
for (int r = 0; r < N; r++) {
int base = x >> frac_bits;
- if (base >= max_base_x) {
+ if (base > max_base_x) {
for (int i = r; i < N; ++i) {
dstvec[i] = a_mbase_x; // save 16 values
}
@@ -3379,7 +3378,7 @@
int dx, int mrl_index,
int bd) {
const int frac_bits = 6;
- const int max_base_x = ((16 + N) + (mrl_index << 1));
+ const int max_base_x = ((16 + N) - 1 + (mrl_index << 1));
__m256i a0, a1, a2, a3;
__m256i val0, val1;
__m256i f0, f1, f2, f3;
@@ -3387,8 +3386,8 @@
__m256i rnding = _mm256_set1_epi32(1 << (POWER_DR_INTERP_FILTER - 1));
- a_mbase_x = _mm256_set1_epi16(above[max_base_x - 1]);
- max_base_x256 = _mm256_set1_epi16(max_base_x);
+ a_mbase_x = _mm256_set1_epi16(above[max_base_x]);
+ max_base_x256 = _mm256_set1_epi16(max_base_x + 1);
int shift_i;
int x = dx * (1 + mrl_index);
@@ -3396,7 +3395,7 @@
__m256i res[2], res1;
int base = x >> frac_bits;
- if (base >= max_base_x) {
+ if (base > max_base_x) {
for (int i = r; i < N; ++i) {
dstvec[i] = a_mbase_x; // save 16 values
}
@@ -3430,7 +3429,7 @@
res[0] = _mm256_packus_epi32(
val0, _mm256_castsi128_si256(_mm256_extracti128_si256(val0, 1)));
- int mdif = max_base_x - base;
+ int mdif = max_base_x + 1 - base;
if (mdif > 8) {
a0 =
_mm256_cvtepu16_epi32(_mm_loadu_si128((__m128i *)(above + base + 7)));
@@ -3496,7 +3495,7 @@
int N, __m256i *dstvec, const uint16_t *above, int dx, int mrl_index,
int bd) {
const int frac_bits = 6;
- const int max_base_x = ((32 + N) + (mrl_index << 1));
+ const int max_base_x = ((32 + N) - 1 + (mrl_index << 1));
__m256i a_mbase_x, max_base_x256, base_inc256, mask256;
@@ -3506,8 +3505,8 @@
__m256i rnding = _mm256_set1_epi16(1 << (POWER_DR_INTERP_FILTER - 1));
- a_mbase_x = _mm256_set1_epi16(above[max_base_x - 1]);
- max_base_x256 = _mm256_set1_epi16(max_base_x);
+ a_mbase_x = _mm256_set1_epi16(above[max_base_x]);
+ max_base_x256 = _mm256_set1_epi16(max_base_x + 1);
int shift_i;
int x = dx * (1 + mrl_index);
@@ -3515,7 +3514,7 @@
__m256i res;
int base = x >> frac_bits;
- if (base >= max_base_x) {
+ if (base > max_base_x) {
for (int i = r; i < N; ++i) {
dstvec[i] = a_mbase_x; // save 32 values
dstvec[i + N] = a_mbase_x;
@@ -3531,7 +3530,7 @@
f3 = _mm256_set1_epi16(av1_dr_interp_filter[shift_i][3]);
for (int j = 0; j < 32; j += 16) {
- int mdif = max_base_x - (base + j);
+ int mdif = max_base_x + 1 - (base + j);
if (mdif <= 0) {
res = a_mbase_x;
} else {
@@ -3577,7 +3576,7 @@
int dx, int mrl_index,
int bd) {
const int frac_bits = 6;
- const int max_base_x = ((32 + N) + (mrl_index << 1));
+ const int max_base_x = ((32 + N) - 1 + (mrl_index << 1));
__m256i a_mbase_x, max_base_x256, base_inc256, mask256;
@@ -3587,8 +3586,8 @@
__m256i rnding = _mm256_set1_epi32(1 << (POWER_DR_INTERP_FILTER - 1));
- a_mbase_x = _mm256_set1_epi16(above[max_base_x - 1]);
- max_base_x256 = _mm256_set1_epi16(max_base_x);
+ a_mbase_x = _mm256_set1_epi16(above[max_base_x]);
+ max_base_x256 = _mm256_set1_epi16(max_base_x + 1);
int shift_i;
int x = dx * (1 + mrl_index);
@@ -3596,7 +3595,7 @@
__m256i res[2], res1;
int base = x >> frac_bits;
- if (base >= max_base_x) {
+ if (base > max_base_x) {
for (int i = r; i < N; ++i) {
dstvec[i] = a_mbase_x; // save 32 values
dstvec[i + N] = a_mbase_x;
@@ -3612,7 +3611,7 @@
f3 = _mm256_set1_epi32(av1_dr_interp_filter[shift_i][3]);
for (int j = 0; j < 32; j += 16) {
- int mdif = max_base_x - (base + j);
+ int mdif = max_base_x + 1 - (base + j);
if (mdif <= 0) {
res1 = a_mbase_x;
} else {
@@ -3711,7 +3710,7 @@
int N, uint16_t *dst, ptrdiff_t stride, const uint16_t *above, int dx,
int mrl_index, int bd) {
const int frac_bits = 6;
- const int max_base_x = ((64 + N) + (mrl_index << 1));
+ const int max_base_x = ((64 + N) - 1 + (mrl_index << 1));
__m256i a_mbase_x, max_base_x256, base_inc256, mask256;
@@ -3721,8 +3720,8 @@
__m256i rnding = _mm256_set1_epi16(1 << (POWER_DR_INTERP_FILTER - 1));
- a_mbase_x = _mm256_set1_epi16(above[max_base_x - 1]);
- max_base_x256 = _mm256_set1_epi16(max_base_x);
+ a_mbase_x = _mm256_set1_epi16(above[max_base_x]);
+ max_base_x256 = _mm256_set1_epi16(max_base_x + 1);
int shift_i;
int x = dx * (1 + mrl_index);
@@ -3730,7 +3729,7 @@
__m256i res;
int base = x >> frac_bits;
- if (base >= max_base_x) {
+ if (base > max_base_x) {
for (int i = r; i < N; ++i) {
_mm256_storeu_si256((__m256i *)dst, a_mbase_x); // save 32 values
_mm256_storeu_si256((__m256i *)(dst + 16), a_mbase_x);
@@ -3749,7 +3748,7 @@
f3 = _mm256_set1_epi16(av1_dr_interp_filter[shift_i][3]);
for (int j = 0; j < 64; j += 16) {
- int mdif = max_base_x - (base + j);
+ int mdif = max_base_x + 1 - (base + j);
if (mdif <= 0) {
_mm256_storeu_si256((__m256i *)(dst + j), a_mbase_x);
} else {
@@ -3789,7 +3788,7 @@
int N, uint16_t *dst, ptrdiff_t stride, const uint16_t *above, int dx,
int mrl_index, int bd) {
const int frac_bits = 6;
- const int max_base_x = ((64 + N) + (mrl_index << 1));
+ const int max_base_x = ((64 + N) - 1 + (mrl_index << 1));
__m256i a0, a1, a2, a3;
@@ -3800,8 +3799,8 @@
__m256i rnding = _mm256_set1_epi32(1 << (POWER_DR_INTERP_FILTER - 1));
- a_mbase_x = _mm256_set1_epi16(above[max_base_x - 1]);
- max_base_x256 = _mm256_set1_epi16(max_base_x);
+ a_mbase_x = _mm256_set1_epi16(above[max_base_x]);
+ max_base_x256 = _mm256_set1_epi16(max_base_x + 1);
int shift_i;
int x = dx * (1 + mrl_index);
@@ -3809,7 +3808,7 @@
__m256i res[2], res1;
int base = x >> frac_bits;
- if (base >= max_base_x) {
+ if (base > max_base_x) {
for (int i = r; i < N; ++i) {
_mm256_storeu_si256((__m256i *)dst, a_mbase_x); // save 32 values
_mm256_storeu_si256((__m256i *)(dst + 16), a_mbase_x);
@@ -3828,7 +3827,7 @@
f3 = _mm256_set1_epi32(av1_dr_interp_filter[shift_i][3]);
for (int j = 0; j < 64; j += 16) {
- int mdif = max_base_x - (base + j);
+ int mdif = max_base_x + 1 - (base + j);
if (mdif <= 0) {
_mm256_storeu_si256((__m256i *)(dst + j), a_mbase_x);
} else {