Add clamping to inverse 64-point Tx (SSE4)

Clamp the intermediate result of every add/sub butterfly stage in
idct64x64_sse4_1() so it stays within the signed bd + 8 bit range for
the row transform and the signed bd + 6 bit range for the column
transform, never narrower than 16 bits.

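In scalar terms each butterfly output is now saturated as in the
sketch below; clamp_value is a hypothetical helper shown only for
illustration, and clamp_lo/clamp_hi stand for the
-(1 << (log_range - 1)) and (1 << (log_range - 1)) - 1 constants set
up in the patch:

  #include <stdint.h>

  static int32_t clamp_value(int32_t x, int32_t lo, int32_t hi) {
    return x < lo ? lo : (x > hi ? hi : x);
  }

  // Each addsub butterfly then computes:
  //   *out0 = clamp_value(in0 + in1, clamp_lo, clamp_hi);
  //   *out1 = clamp_value(in0 - in1, clamp_lo, clamp_hi);
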
BUG=aomedia:1751

Change-Id: I9b2cb51bdf8b69c572e6b9dd8576b731a205920f
diff --git a/av1/common/x86/highbd_inv_txfm_sse4.c b/av1/common/x86/highbd_inv_txfm_sse4.c
index de40739..c57b542 100644
--- a/av1/common/x86/highbd_inv_txfm_sse4.c
+++ b/av1/common/x86/highbd_inv_txfm_sse4.c
@@ -1761,10 +1761,29 @@
   write_buffer_32x32(in32x32, rightDown, stride, fliplr, flipud, shift, bd);
 }
 
-static void idct64x64_sse4_1(__m128i *in, __m128i *out, int bit, int do_cols) {
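+// Butterfly: *out0 = in0 + in1, *out1 = in0 - in1, with both results
+// clamped to [clamp_lo, clamp_hi].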
+static void addsub_sse4_1(const __m128i in0, const __m128i in1, __m128i *out0,
+                          __m128i *out1, const __m128i clamp_lo,
+                          const __m128i clamp_hi) {
+  __m128i a0 = _mm_add_epi32(in0, in1);
+  __m128i a1 = _mm_sub_epi32(in0, in1);
+
+  a0 = _mm_max_epi32(a0, clamp_lo);
+  a0 = _mm_min_epi32(a0, clamp_hi);
+  a1 = _mm_max_epi32(a1, clamp_lo);
+  a1 = _mm_min_epi32(a1, clamp_hi);
+
+  *out0 = a0;
+  *out1 = a1;
+}
+
+static void idct64x64_sse4_1(__m128i *in, __m128i *out, int bit, int do_cols,
+                             int bd) {
   int i, j;
   const int32_t *cospi = cospi_arr(bit);
   const __m128i rnding = _mm_set1_epi32(1 << (bit - 1));
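+  // Clamp intermediates to the signed log_range-bit range: bd + 8 bits
+  // for the row transform (do_cols == 0), bd + 6 bits for the column
+  // transform, at least 16 bits in either case.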
+  const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8));
+  const __m128i clamp_lo = _mm_set1_epi32(-(1 << (log_range - 1)));
+  const __m128i clamp_hi = _mm_set1_epi32((1 << (log_range - 1)) - 1);
   int col;
 
   const __m128i cospi1 = _mm_set1_epi32(cospi[1]);
@@ -1941,10 +1960,10 @@
     u[31] = half_btf_0_sse4_1(&cospi2, &v[16], &rnding, bit);
 
     for (i = 32; i < 64; i += 4) {
-      u[i + 0] = _mm_add_epi32(v[i + 0], v[i + 1]);
-      u[i + 1] = _mm_sub_epi32(v[i + 0], v[i + 1]);
-      u[i + 2] = _mm_sub_epi32(v[i + 3], v[i + 2]);
-      u[i + 3] = _mm_add_epi32(v[i + 3], v[i + 2]);
+      addsub_sse4_1(v[i + 0], v[i + 1], &u[i + 0], &u[i + 1], clamp_lo,
+                    clamp_hi);
+      addsub_sse4_1(v[i + 3], v[i + 2], &u[i + 3], &u[i + 2], clamp_lo,
+                    clamp_hi);
     }
 
     // stage 4
@@ -1958,10 +1977,10 @@
     v[15] = half_btf_0_sse4_1(&cospi4, &u[8], &rnding, bit);
 
     for (i = 16; i < 32; i += 4) {
-      v[i + 0] = _mm_add_epi32(u[i + 0], u[i + 1]);
-      v[i + 1] = _mm_sub_epi32(u[i + 0], u[i + 1]);
-      v[i + 2] = _mm_sub_epi32(u[i + 3], u[i + 2]);
-      v[i + 3] = _mm_add_epi32(u[i + 3], u[i + 2]);
+      addsub_sse4_1(u[i + 0], u[i + 1], &v[i + 0], &v[i + 1], clamp_lo,
+                    clamp_hi);
+      addsub_sse4_1(u[i + 3], u[i + 2], &v[i + 3], &v[i + 2], clamp_lo,
+                    clamp_hi);
     }
 
     for (i = 32; i < 64; i += 4) {
@@ -1993,10 +2012,10 @@
     u[7] = half_btf_0_sse4_1(&cospi8, &v[4], &rnding, bit);
 
     for (i = 8; i < 16; i += 4) {
-      u[i + 0] = _mm_add_epi32(v[i + 0], v[i + 1]);
-      u[i + 1] = _mm_sub_epi32(v[i + 0], v[i + 1]);
-      u[i + 2] = _mm_sub_epi32(v[i + 3], v[i + 2]);
-      u[i + 3] = _mm_add_epi32(v[i + 3], v[i + 2]);
+      addsub_sse4_1(v[i + 0], v[i + 1], &u[i + 0], &u[i + 1], clamp_lo,
+                    clamp_hi);
+      addsub_sse4_1(v[i + 3], v[i + 2], &u[i + 3], &u[i + 2], clamp_lo,
+                    clamp_hi);
     }
 
     for (i = 16; i < 32; i += 4) {
@@ -2014,15 +2033,15 @@
     u[30] = half_btf_sse4_1(&cospi56, &v[17], &cospi8, &v[30], &rnding, bit);
 
     for (i = 32; i < 64; i += 8) {
-      u[i + 0] = _mm_add_epi32(v[i + 0], v[i + 3]);
-      u[i + 1] = _mm_add_epi32(v[i + 1], v[i + 2]);
-      u[i + 2] = _mm_sub_epi32(v[i + 1], v[i + 2]);
-      u[i + 3] = _mm_sub_epi32(v[i + 0], v[i + 3]);
+      addsub_sse4_1(v[i + 0], v[i + 3], &u[i + 0], &u[i + 3], clamp_lo,
+                    clamp_hi);
+      addsub_sse4_1(v[i + 1], v[i + 2], &u[i + 1], &u[i + 2], clamp_lo,
+                    clamp_hi);
 
-      u[i + 4] = _mm_sub_epi32(v[i + 7], v[i + 4]);
-      u[i + 5] = _mm_sub_epi32(v[i + 6], v[i + 5]);
-      u[i + 6] = _mm_add_epi32(v[i + 6], v[i + 5]);
-      u[i + 7] = _mm_add_epi32(v[i + 7], v[i + 4]);
+      addsub_sse4_1(v[i + 7], v[i + 4], &u[i + 7], &u[i + 4], clamp_lo,
+                    clamp_hi);
+      addsub_sse4_1(v[i + 6], v[i + 5], &u[i + 6], &u[i + 5], clamp_lo,
+                    clamp_hi);
     }
 
     // stage 6
@@ -2031,10 +2050,8 @@
     v[2] = half_btf_0_sse4_1(&cospi48, &u[2], &rnding, bit);
     v[3] = half_btf_0_sse4_1(&cospi16, &u[2], &rnding, bit);
 
-    v[4] = _mm_add_epi32(u[4], u[5]);
-    v[5] = _mm_sub_epi32(u[4], u[5]);
-    v[6] = _mm_sub_epi32(u[7], u[6]);
-    v[7] = _mm_add_epi32(u[7], u[6]);
+    addsub_sse4_1(u[4], u[5], &v[4], &v[5], clamp_lo, clamp_hi);
+    addsub_sse4_1(u[7], u[6], &v[7], &v[6], clamp_lo, clamp_hi);
 
     for (i = 8; i < 16; i += 4) {
       v[i + 0] = u[i + 0];
@@ -2047,15 +2064,15 @@
     v[14] = half_btf_sse4_1(&cospi48, &u[9], &cospi16, &u[14], &rnding, bit);
 
     for (i = 16; i < 32; i += 8) {
-      v[i + 0] = _mm_add_epi32(u[i + 0], u[i + 3]);
-      v[i + 1] = _mm_add_epi32(u[i + 1], u[i + 2]);
-      v[i + 2] = _mm_sub_epi32(u[i + 1], u[i + 2]);
-      v[i + 3] = _mm_sub_epi32(u[i + 0], u[i + 3]);
+      addsub_sse4_1(u[i + 0], u[i + 3], &v[i + 0], &v[i + 3], clamp_lo,
+                    clamp_hi);
+      addsub_sse4_1(u[i + 1], u[i + 2], &v[i + 1], &v[i + 2], clamp_lo,
+                    clamp_hi);
 
-      v[i + 4] = _mm_sub_epi32(u[i + 7], u[i + 4]);
-      v[i + 5] = _mm_sub_epi32(u[i + 6], u[i + 5]);
-      v[i + 6] = _mm_add_epi32(u[i + 6], u[i + 5]);
-      v[i + 7] = _mm_add_epi32(u[i + 7], u[i + 4]);
+      addsub_sse4_1(u[i + 7], u[i + 4], &v[i + 7], &v[i + 4], clamp_lo,
+                    clamp_hi);
+      addsub_sse4_1(u[i + 6], u[i + 5], &v[i + 6], &v[i + 5], clamp_lo,
+                    clamp_hi);
     }
 
     for (i = 32; i < 64; i += 8) {
@@ -2083,24 +2100,18 @@
     v[61] = half_btf_sse4_1(&cospi56, &u[34], &cospi8, &u[61], &rnding, bit);
 
     // stage 7
-    u[0] = _mm_add_epi32(v[0], v[3]);
-    u[1] = _mm_add_epi32(v[1], v[2]);
-    u[2] = _mm_sub_epi32(v[1], v[2]);
-    u[3] = _mm_sub_epi32(v[0], v[3]);
+    addsub_sse4_1(v[0], v[3], &u[0], &u[3], clamp_lo, clamp_hi);
+    addsub_sse4_1(v[1], v[2], &u[1], &u[2], clamp_lo, clamp_hi);
 
     u[4] = v[4];
     u[7] = v[7];
     u[5] = half_btf_sse4_1(&cospim32, &v[5], &cospi32, &v[6], &rnding, bit);
     u[6] = half_btf_sse4_1(&cospi32, &v[5], &cospi32, &v[6], &rnding, bit);
 
-    u[8] = _mm_add_epi32(v[8], v[11]);
-    u[9] = _mm_add_epi32(v[9], v[10]);
-    u[10] = _mm_sub_epi32(v[9], v[10]);
-    u[11] = _mm_sub_epi32(v[8], v[11]);
-    u[12] = _mm_sub_epi32(v[15], v[12]);
-    u[13] = _mm_sub_epi32(v[14], v[13]);
-    u[14] = _mm_add_epi32(v[14], v[13]);
-    u[15] = _mm_add_epi32(v[15], v[12]);
+    addsub_sse4_1(v[8], v[11], &u[8], &u[11], clamp_lo, clamp_hi);
+    addsub_sse4_1(v[9], v[10], &u[9], &u[10], clamp_lo, clamp_hi);
+    addsub_sse4_1(v[15], v[12], &u[15], &u[12], clamp_lo, clamp_hi);
+    addsub_sse4_1(v[14], v[13], &u[14], &u[13], clamp_lo, clamp_hi);
 
     for (i = 16; i < 32; i += 8) {
       u[i + 0] = v[i + 0];
@@ -2120,17 +2131,15 @@
 
     for (i = 32; i < 64; i += 16) {
       for (j = i; j < i + 4; j++) {
-        u[j] = _mm_add_epi32(v[j], v[j ^ 7]);
-        u[j ^ 7] = _mm_sub_epi32(v[j], v[j ^ 7]);
-        u[j ^ 8] = _mm_sub_epi32(v[j ^ 15], v[j ^ 8]);
-        u[j ^ 15] = _mm_add_epi32(v[j ^ 15], v[j ^ 8]);
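+        // ^ 7 mirrors the lane within its group of 8; ^ 8 and ^ 15 pick
+        // the direct and mirrored lanes of the neighboring group.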
+        addsub_sse4_1(v[j], v[j ^ 7], &u[j], &u[j ^ 7], clamp_lo, clamp_hi);
+        addsub_sse4_1(v[j ^ 15], v[j ^ 8], &u[j ^ 15], &u[j ^ 8], clamp_lo,
+                      clamp_hi);
       }
     }
 
     // stage 8
     for (i = 0; i < 4; ++i) {
-      v[i] = _mm_add_epi32(u[i], u[7 - i]);
-      v[7 - i] = _mm_sub_epi32(u[i], u[7 - i]);
+      addsub_sse4_1(u[i], u[7 - i], &v[i], &v[7 - i], clamp_lo, clamp_hi);
     }
 
     v[8] = u[8];
@@ -2144,10 +2153,9 @@
     v[13] = half_btf_sse4_1(&cospi32, &u[10], &cospi32, &u[13], &rnding, bit);
 
     for (i = 16; i < 20; ++i) {
-      v[i] = _mm_add_epi32(u[i], u[i ^ 7]);
-      v[i ^ 7] = _mm_sub_epi32(u[i], u[i ^ 7]);
-      v[i ^ 8] = _mm_sub_epi32(u[i ^ 15], u[i ^ 8]);
-      v[i ^ 15] = _mm_add_epi32(u[i ^ 15], u[i ^ 8]);
+      addsub_sse4_1(u[i], u[i ^ 7], &v[i], &v[i ^ 7], clamp_lo, clamp_hi);
+      addsub_sse4_1(u[i ^ 15], u[i ^ 8], &v[i ^ 15], &v[i ^ 8], clamp_lo,
+                    clamp_hi);
     }
 
     for (i = 32; i < 36; ++i) {
@@ -2176,8 +2184,7 @@
 
     // stage 9
     for (i = 0; i < 8; ++i) {
-      u[i] = _mm_add_epi32(v[i], v[15 - i]);
-      u[15 - i] = _mm_sub_epi32(v[i], v[15 - i]);
+      addsub_sse4_1(v[i], v[15 - i], &u[i], &u[15 - i], clamp_lo, clamp_hi);
     }
 
     for (i = 16; i < 20; ++i) {
@@ -2195,19 +2202,16 @@
     u[27] = half_btf_sse4_1(&cospi32, &v[20], &cospi32, &v[27], &rnding, bit);
 
     for (i = 32; i < 40; i++) {
-      u[i] = _mm_add_epi32(v[i], v[i ^ 15]);
-      u[i ^ 15] = _mm_sub_epi32(v[i], v[i ^ 15]);
+      addsub_sse4_1(v[i], v[i ^ 15], &u[i], &u[i ^ 15], clamp_lo, clamp_hi);
     }
 
     for (i = 48; i < 56; i++) {
-      u[i] = _mm_sub_epi32(v[i ^ 15], v[i]);
-      u[i ^ 15] = _mm_add_epi32(v[i ^ 15], v[i]);
+      addsub_sse4_1(v[i ^ 15], v[i], &u[i ^ 15], &u[i], clamp_lo, clamp_hi);
     }
 
     // stage 10
     for (i = 0; i < 16; i++) {
-      v[i] = _mm_add_epi32(u[i], u[31 - i]);
-      v[31 - i] = _mm_sub_epi32(u[i], u[31 - i]);
+      addsub_sse4_1(u[i], u[31 - i], &v[i], &v[31 - i], clamp_lo, clamp_hi);
     }
 
     for (i = 32; i < 40; i++) v[i] = u[i];
@@ -2233,8 +2237,8 @@
 
     // stage 11
     for (i = 0; i < 32; i++) {
-      out[16 * (i) + col] = _mm_add_epi32(v[i], v[63 - i]);
-      out[16 * (63 - i) + col] = _mm_sub_epi32(v[i], v[63 - i]);
+      addsub_sse4_1(v[i], v[63 - i], &out[16 * (i) + col],
+                    &out[16 * (63 - i) + col], clamp_lo, clamp_hi);
     }
   }
 }
@@ -2250,11 +2254,11 @@
     case DCT_DCT:
       load_buffer_64x64_lower_32x32(coeff, in);
       transpose_64x64(in, out, 0);
-      idct64x64_sse4_1(out, in, inv_cos_bit_row[txw_idx][txh_idx], 0);
+      idct64x64_sse4_1(out, in, inv_cos_bit_row[txw_idx][txh_idx], 0, bd);
       // transpose before shift, so shift can apply to 512 contiguous values
       transpose_64x64(in, out, 1);
       round_shift_64x64(out, -shift[0]);
-      idct64x64_sse4_1(out, in, inv_cos_bit_col[txw_idx][txh_idx], 1);
+      idct64x64_sse4_1(out, in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd);
       write_buffer_64x64(in, output, stride, 0, 0, -shift[1], bd);
       break;