Fix bug in highbd inv txfm modules
In the iidentity4_sse4_1(), iidentity16_sse4_1() and iadst4x4_sse4_1() functions,
intermediate values that exceeded 32-bit precision were causing a mismatch.
The affected multiplications and rounding are now carried out on 64-bit intermediates.
BUG=aomedia:2350
Change-Id: Ib6c8c8ee818ad61a5c671de3b163a456188737c1
diff --git a/av1/common/x86/highbd_inv_txfm_sse4.c b/av1/common/x86/highbd_inv_txfm_sse4.c
index 70be018..c6bf917 100644
--- a/av1/common/x86/highbd_inv_txfm_sse4.c
+++ b/av1/common/x86/highbd_inv_txfm_sse4.c
@@ -439,7 +439,10 @@
int bd, int out_shift) {
(void)out_shift;
const int32_t *sinpi = sinpi_arr(bit);
- const __m128i rnding = _mm_set1_epi32(1 << (bit - 1));
+ const __m128i zero = _mm_set1_epi32(0);
+ __m128i rnding = _mm_set1_epi32(1 << (bit + 4 - 1));
+ rnding = _mm_unpacklo_epi32(rnding, zero);
+ const __m128i mul = _mm_set1_epi32(1 << 4);
const __m128i sinpi1 = _mm_set1_epi32((int)sinpi[1]);
const __m128i sinpi2 = _mm_set1_epi32((int)sinpi[2]);
const __m128i sinpi3 = _mm_set1_epi32((int)sinpi[3]);
@@ -449,6 +452,8 @@
__m128i x0, x1, x2, x3;
__m128i u0, u1, u2, u3;
__m128i v0, v1, v2, v3;
+ __m128i u0_low, u1_low, u2_low, u3_low;
+ __m128i u0_high, u1_high, u2_high, u3_high;
v0 = _mm_unpacklo_epi32(in[0], in[1]);
v1 = _mm_unpackhi_epi32(in[0], in[1]);
@@ -483,17 +488,65 @@
t = _mm_add_epi32(s0, s1);
u3 = _mm_sub_epi32(t, s3);
- u0 = _mm_add_epi32(u0, rnding);
- u0 = _mm_srai_epi32(u0, bit);
+ // u0
+ u0_low = _mm_mul_epi32(u0, mul);
+ u0_low = _mm_add_epi64(u0_low, rnding);
- u1 = _mm_add_epi32(u1, rnding);
- u1 = _mm_srai_epi32(u1, bit);
+ u0 = _mm_srli_si128(u0, 4);
+ u0_high = _mm_mul_epi32(u0, mul);
+ u0_high = _mm_add_epi64(u0_high, rnding);
- u2 = _mm_add_epi32(u2, rnding);
- u2 = _mm_srai_epi32(u2, bit);
+ u0_low = _mm_srli_si128(u0_low, 2);
+ u0_high = _mm_srli_si128(u0_high, 2);
- u3 = _mm_add_epi32(u3, rnding);
- u3 = _mm_srai_epi32(u3, bit);
+ u0 = _mm_unpacklo_epi32(u0_low, u0_high);
+ u0_high = _mm_unpackhi_epi32(u0_low, u0_high);
+ u0 = _mm_unpacklo_epi64(u0, u0_high);
+
+ // u1
+ u1_low = _mm_mul_epi32(u1, mul);
+ u1_low = _mm_add_epi64(u1_low, rnding);
+
+ u1 = _mm_srli_si128(u1, 4);
+ u1_high = _mm_mul_epi32(u1, mul);
+ u1_high = _mm_add_epi64(u1_high, rnding);
+
+ u1_low = _mm_srli_si128(u1_low, 2);
+ u1_high = _mm_srli_si128(u1_high, 2);
+
+ u1 = _mm_unpacklo_epi32(u1_low, u1_high);
+ u1_high = _mm_unpackhi_epi32(u1_low, u1_high);
+ u1 = _mm_unpacklo_epi64(u1, u1_high);
+
+ // u2
+ u2_low = _mm_mul_epi32(u2, mul);
+ u2_low = _mm_add_epi64(u2_low, rnding);
+
+ u2 = _mm_srli_si128(u2, 4);
+ u2_high = _mm_mul_epi32(u2, mul);
+ u2_high = _mm_add_epi64(u2_high, rnding);
+
+ u2_low = _mm_srli_si128(u2_low, 2);
+ u2_high = _mm_srli_si128(u2_high, 2);
+
+ u2 = _mm_unpacklo_epi32(u2_low, u2_high);
+ u2_high = _mm_unpackhi_epi32(u2_low, u2_high);
+ u2 = _mm_unpacklo_epi64(u2, u2_high);
+
+ // u3
+ u3_low = _mm_mul_epi32(u3, mul);
+ u3_low = _mm_add_epi64(u3_low, rnding);
+
+ u3 = _mm_srli_si128(u3, 4);
+ u3_high = _mm_mul_epi32(u3, mul);
+ u3_high = _mm_add_epi64(u3_high, rnding);
+
+ u3_low = _mm_srli_si128(u3_low, 2);
+ u3_high = _mm_srli_si128(u3_high, 2);
+
+ u3 = _mm_unpacklo_epi32(u3_low, u3_high);
+ u3_high = _mm_unpackhi_epi32(u3_low, u3_high);
+ u3 = _mm_unpacklo_epi64(u3, u3_high);
if (!do_cols) {
const int log_range = AOMMAX(16, bd + 6);
@@ -606,23 +659,28 @@
(void)bit;
(void)out_shift;
__m128i v[4];
+ __m128i zero = _mm_set1_epi32(0);
__m128i fact = _mm_set1_epi32(NewSqrt2);
__m128i offset = _mm_set1_epi32(1 << (NewSqrt2Bits - 1));
- __m128i a0, a1;
+ __m128i a0_low, a1_low;
+ __m128i a0_high, a1_high;
- a0 = _mm_mullo_epi32(in[0], fact);
- a1 = _mm_mullo_epi32(in[1], fact);
- a0 = _mm_add_epi32(a0, offset);
- a1 = _mm_add_epi32(a1, offset);
- out[0] = _mm_srai_epi32(a0, NewSqrt2Bits);
- out[1] = _mm_srai_epi32(a1, NewSqrt2Bits);
+ offset = _mm_unpacklo_epi32(offset, zero);
- a0 = _mm_mullo_epi32(in[2], fact);
- a1 = _mm_mullo_epi32(in[3], fact);
- a0 = _mm_add_epi32(a0, offset);
- a1 = _mm_add_epi32(a1, offset);
- out[2] = _mm_srai_epi32(a0, NewSqrt2Bits);
- out[3] = _mm_srai_epi32(a1, NewSqrt2Bits);
+ for (int i = 0; i < 4; i++) {
+ a0_low = _mm_mul_epi32(in[i], fact);
+ a0_low = _mm_add_epi32(a0_low, offset);
+ a0_low = _mm_srli_epi64(a0_low, NewSqrt2Bits);
+
+ a0_high = _mm_srli_si128(in[i], 4);
+ a0_high = _mm_mul_epi32(a0_high, fact);
+ a0_high = _mm_add_epi32(a0_high, offset);
+ a0_high = _mm_srli_epi64(a0_high, NewSqrt2Bits);
+
+ a1_low = _mm_unpacklo_epi32(a0_low, a0_high);
+ a1_high = _mm_unpackhi_epi32(a0_low, a0_high);
+ out[i] = _mm_unpacklo_epi64(a1_low, a1_high);
+ }
if (!do_cols) {
const int log_range = AOMMAX(16, bd + 6);
@@ -3173,36 +3231,23 @@
__m128i v[16];
__m128i fact = _mm_set1_epi32(2 * NewSqrt2);
__m128i offset = _mm_set1_epi32(1 << (NewSqrt2Bits - 1));
- __m128i a0, a1, a2, a3;
+ __m128i a0_low, a0_high, a1_low, a1_high;
+ __m128i zero = _mm_set1_epi32(0);
+ offset = _mm_unpacklo_epi32(offset, zero);
- for (int i = 0; i < 16; i += 8) {
- a0 = _mm_mullo_epi32(in[i], fact);
- a1 = _mm_mullo_epi32(in[i + 1], fact);
- a0 = _mm_add_epi32(a0, offset);
- a1 = _mm_add_epi32(a1, offset);
- v[i] = _mm_srai_epi32(a0, NewSqrt2Bits);
- v[i + 1] = _mm_srai_epi32(a1, NewSqrt2Bits);
+ for (int i = 0; i < 16; i++) {
+ a0_low = _mm_mul_epi32(in[i], fact);
+ a0_low = _mm_add_epi32(a0_low, offset);
+ a0_low = _mm_srli_epi64(a0_low, NewSqrt2Bits);
- a2 = _mm_mullo_epi32(in[i + 2], fact);
- a3 = _mm_mullo_epi32(in[i + 3], fact);
- a2 = _mm_add_epi32(a2, offset);
- a3 = _mm_add_epi32(a3, offset);
- v[i + 2] = _mm_srai_epi32(a2, NewSqrt2Bits);
- v[i + 3] = _mm_srai_epi32(a3, NewSqrt2Bits);
+ a0_high = _mm_srli_si128(in[i], 4);
+ a0_high = _mm_mul_epi32(a0_high, fact);
+ a0_high = _mm_add_epi32(a0_high, offset);
+ a0_high = _mm_srli_epi64(a0_high, NewSqrt2Bits);
- a0 = _mm_mullo_epi32(in[i + 4], fact);
- a1 = _mm_mullo_epi32(in[i + 5], fact);
- a0 = _mm_add_epi32(a0, offset);
- a1 = _mm_add_epi32(a1, offset);
- v[i + 4] = _mm_srai_epi32(a0, NewSqrt2Bits);
- v[i + 5] = _mm_srai_epi32(a1, NewSqrt2Bits);
-
- a2 = _mm_mullo_epi32(in[i + 6], fact);
- a3 = _mm_mullo_epi32(in[i + 7], fact);
- a2 = _mm_add_epi32(a2, offset);
- a3 = _mm_add_epi32(a3, offset);
- v[i + 6] = _mm_srai_epi32(a2, NewSqrt2Bits);
- v[i + 7] = _mm_srai_epi32(a3, NewSqrt2Bits);
+ a1_low = _mm_unpacklo_epi32(a0_low, a0_high);
+ a1_high = _mm_unpackhi_epi32(a0_low, a0_high);
+ v[i] = _mm_unpacklo_epi64(a1_low, a1_high);
}
if (!do_cols) {