Fix bug in highbd inv txfm modules

In iidentity4_sse4_1(), iidentity16_sse4_1() and iadst4x4_sse4_1() functions
intermediate precision going beyond 32 bit was causing mismatch.
Made changes to handle it appropriately.

BUG=aomedia:2350

Change-Id: Ib6c8c8ee818ad61a5c671de3b163a456188737c1
diff --git a/av1/common/x86/highbd_inv_txfm_sse4.c b/av1/common/x86/highbd_inv_txfm_sse4.c
index 70be018..c6bf917 100644
--- a/av1/common/x86/highbd_inv_txfm_sse4.c
+++ b/av1/common/x86/highbd_inv_txfm_sse4.c
@@ -439,7 +439,10 @@
                             int bd, int out_shift) {
   (void)out_shift;
   const int32_t *sinpi = sinpi_arr(bit);
-  const __m128i rnding = _mm_set1_epi32(1 << (bit - 1));
+  const __m128i zero = _mm_set1_epi32(0);
+  __m128i rnding = _mm_set1_epi32(1 << (bit + 4 - 1));
+  rnding = _mm_unpacklo_epi32(rnding, zero);
+  const __m128i mul = _mm_set1_epi32(1 << 4);
   const __m128i sinpi1 = _mm_set1_epi32((int)sinpi[1]);
   const __m128i sinpi2 = _mm_set1_epi32((int)sinpi[2]);
   const __m128i sinpi3 = _mm_set1_epi32((int)sinpi[3]);
@@ -449,6 +452,8 @@
   __m128i x0, x1, x2, x3;
   __m128i u0, u1, u2, u3;
   __m128i v0, v1, v2, v3;
+  __m128i u0_low, u1_low, u2_low, u3_low;
+  __m128i u0_high, u1_high, u2_high, u3_high;
 
   v0 = _mm_unpacklo_epi32(in[0], in[1]);
   v1 = _mm_unpackhi_epi32(in[0], in[1]);
@@ -483,17 +488,65 @@
   t = _mm_add_epi32(s0, s1);
   u3 = _mm_sub_epi32(t, s3);
 
-  u0 = _mm_add_epi32(u0, rnding);
-  u0 = _mm_srai_epi32(u0, bit);
+  // u0
+  u0_low = _mm_mul_epi32(u0, mul);
+  u0_low = _mm_add_epi64(u0_low, rnding);
 
-  u1 = _mm_add_epi32(u1, rnding);
-  u1 = _mm_srai_epi32(u1, bit);
+  u0 = _mm_srli_si128(u0, 4);
+  u0_high = _mm_mul_epi32(u0, mul);
+  u0_high = _mm_add_epi64(u0_high, rnding);
 
-  u2 = _mm_add_epi32(u2, rnding);
-  u2 = _mm_srai_epi32(u2, bit);
+  u0_low = _mm_srli_si128(u0_low, 2);
+  u0_high = _mm_srli_si128(u0_high, 2);
 
-  u3 = _mm_add_epi32(u3, rnding);
-  u3 = _mm_srai_epi32(u3, bit);
+  u0 = _mm_unpacklo_epi32(u0_low, u0_high);
+  u0_high = _mm_unpackhi_epi32(u0_low, u0_high);
+  u0 = _mm_unpacklo_epi64(u0, u0_high);
+
+  // u1
+  u1_low = _mm_mul_epi32(u1, mul);
+  u1_low = _mm_add_epi64(u1_low, rnding);
+
+  u1 = _mm_srli_si128(u1, 4);
+  u1_high = _mm_mul_epi32(u1, mul);
+  u1_high = _mm_add_epi64(u1_high, rnding);
+
+  u1_low = _mm_srli_si128(u1_low, 2);
+  u1_high = _mm_srli_si128(u1_high, 2);
+
+  u1 = _mm_unpacklo_epi32(u1_low, u1_high);
+  u1_high = _mm_unpackhi_epi32(u1_low, u1_high);
+  u1 = _mm_unpacklo_epi64(u1, u1_high);
+
+  // u2
+  u2_low = _mm_mul_epi32(u2, mul);
+  u2_low = _mm_add_epi64(u2_low, rnding);
+
+  u2 = _mm_srli_si128(u2, 4);
+  u2_high = _mm_mul_epi32(u2, mul);
+  u2_high = _mm_add_epi64(u2_high, rnding);
+
+  u2_low = _mm_srli_si128(u2_low, 2);
+  u2_high = _mm_srli_si128(u2_high, 2);
+
+  u2 = _mm_unpacklo_epi32(u2_low, u2_high);
+  u2_high = _mm_unpackhi_epi32(u2_low, u2_high);
+  u2 = _mm_unpacklo_epi64(u2, u2_high);
+
+  // u3
+  u3_low = _mm_mul_epi32(u3, mul);
+  u3_low = _mm_add_epi64(u3_low, rnding);
+
+  u3 = _mm_srli_si128(u3, 4);
+  u3_high = _mm_mul_epi32(u3, mul);
+  u3_high = _mm_add_epi64(u3_high, rnding);
+
+  u3_low = _mm_srli_si128(u3_low, 2);
+  u3_high = _mm_srli_si128(u3_high, 2);
+
+  u3 = _mm_unpacklo_epi32(u3_low, u3_high);
+  u3_high = _mm_unpackhi_epi32(u3_low, u3_high);
+  u3 = _mm_unpacklo_epi64(u3, u3_high);
 
   if (!do_cols) {
     const int log_range = AOMMAX(16, bd + 6);
@@ -606,23 +659,28 @@
   (void)bit;
   (void)out_shift;
   __m128i v[4];
+  __m128i zero = _mm_set1_epi32(0);
   __m128i fact = _mm_set1_epi32(NewSqrt2);
   __m128i offset = _mm_set1_epi32(1 << (NewSqrt2Bits - 1));
-  __m128i a0, a1;
+  __m128i a0_low, a1_low;
+  __m128i a0_high, a1_high;
 
-  a0 = _mm_mullo_epi32(in[0], fact);
-  a1 = _mm_mullo_epi32(in[1], fact);
-  a0 = _mm_add_epi32(a0, offset);
-  a1 = _mm_add_epi32(a1, offset);
-  out[0] = _mm_srai_epi32(a0, NewSqrt2Bits);
-  out[1] = _mm_srai_epi32(a1, NewSqrt2Bits);
+  offset = _mm_unpacklo_epi32(offset, zero);
 
-  a0 = _mm_mullo_epi32(in[2], fact);
-  a1 = _mm_mullo_epi32(in[3], fact);
-  a0 = _mm_add_epi32(a0, offset);
-  a1 = _mm_add_epi32(a1, offset);
-  out[2] = _mm_srai_epi32(a0, NewSqrt2Bits);
-  out[3] = _mm_srai_epi32(a1, NewSqrt2Bits);
+  for (int i = 0; i < 4; i++) {
+    a0_low = _mm_mul_epi32(in[i], fact);
+    a0_low = _mm_add_epi32(a0_low, offset);
+    a0_low = _mm_srli_epi64(a0_low, NewSqrt2Bits);
+
+    a0_high = _mm_srli_si128(in[i], 4);
+    a0_high = _mm_mul_epi32(a0_high, fact);
+    a0_high = _mm_add_epi32(a0_high, offset);
+    a0_high = _mm_srli_epi64(a0_high, NewSqrt2Bits);
+
+    a1_low = _mm_unpacklo_epi32(a0_low, a0_high);
+    a1_high = _mm_unpackhi_epi32(a0_low, a0_high);
+    out[i] = _mm_unpacklo_epi64(a1_low, a1_high);
+  }
 
   if (!do_cols) {
     const int log_range = AOMMAX(16, bd + 6);
@@ -3173,36 +3231,23 @@
   __m128i v[16];
   __m128i fact = _mm_set1_epi32(2 * NewSqrt2);
   __m128i offset = _mm_set1_epi32(1 << (NewSqrt2Bits - 1));
-  __m128i a0, a1, a2, a3;
+  __m128i a0_low, a0_high, a1_low, a1_high;
+  __m128i zero = _mm_set1_epi32(0);
+  offset = _mm_unpacklo_epi32(offset, zero);
 
-  for (int i = 0; i < 16; i += 8) {
-    a0 = _mm_mullo_epi32(in[i], fact);
-    a1 = _mm_mullo_epi32(in[i + 1], fact);
-    a0 = _mm_add_epi32(a0, offset);
-    a1 = _mm_add_epi32(a1, offset);
-    v[i] = _mm_srai_epi32(a0, NewSqrt2Bits);
-    v[i + 1] = _mm_srai_epi32(a1, NewSqrt2Bits);
+  for (int i = 0; i < 16; i++) {
+    a0_low = _mm_mul_epi32(in[i], fact);
+    a0_low = _mm_add_epi32(a0_low, offset);
+    a0_low = _mm_srli_epi64(a0_low, NewSqrt2Bits);
 
-    a2 = _mm_mullo_epi32(in[i + 2], fact);
-    a3 = _mm_mullo_epi32(in[i + 3], fact);
-    a2 = _mm_add_epi32(a2, offset);
-    a3 = _mm_add_epi32(a3, offset);
-    v[i + 2] = _mm_srai_epi32(a2, NewSqrt2Bits);
-    v[i + 3] = _mm_srai_epi32(a3, NewSqrt2Bits);
+    a0_high = _mm_srli_si128(in[i], 4);
+    a0_high = _mm_mul_epi32(a0_high, fact);
+    a0_high = _mm_add_epi32(a0_high, offset);
+    a0_high = _mm_srli_epi64(a0_high, NewSqrt2Bits);
 
-    a0 = _mm_mullo_epi32(in[i + 4], fact);
-    a1 = _mm_mullo_epi32(in[i + 5], fact);
-    a0 = _mm_add_epi32(a0, offset);
-    a1 = _mm_add_epi32(a1, offset);
-    v[i + 4] = _mm_srai_epi32(a0, NewSqrt2Bits);
-    v[i + 5] = _mm_srai_epi32(a1, NewSqrt2Bits);
-
-    a2 = _mm_mullo_epi32(in[i + 6], fact);
-    a3 = _mm_mullo_epi32(in[i + 7], fact);
-    a2 = _mm_add_epi32(a2, offset);
-    a3 = _mm_add_epi32(a3, offset);
-    v[i + 6] = _mm_srai_epi32(a2, NewSqrt2Bits);
-    v[i + 7] = _mm_srai_epi32(a3, NewSqrt2Bits);
+    a1_low = _mm_unpacklo_epi32(a0_low, a0_high);
+    a1_high = _mm_unpackhi_epi32(a0_low, a0_high);
+    v[i] = _mm_unpacklo_epi64(a1_low, a1_high);
   }
 
   if (!do_cols) {