Add sse4_1 variant for highbd inv_txfm 32x32

Coded different variants of idct32x32_sse4_1 based on eobx logic.

Achieved module level gains of 7.7x on an average over all eob values.

Change-Id: I209c6d7e0c44b5c0c8b7ebda890c7698cda61bb8
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index c605906..dee1f1c 100755
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -122,7 +122,7 @@
 add_proto qw/void av1_highbd_inv_txfm_add_16x16/,  "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
 specialize qw/av1_highbd_inv_txfm_add_16x16 sse4_1/;
 add_proto qw/void av1_highbd_inv_txfm_add_32x32/,  "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
-specialize qw/av1_highbd_inv_txfm_add_32x32 avx2/;
+specialize qw/av1_highbd_inv_txfm_add_32x32 sse4_1 avx2/;
 
 add_proto qw/void av1_highbd_iwht4x4_1_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride, int bd";
 add_proto qw/void av1_highbd_iwht4x4_16_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride, int bd";
diff --git a/av1/common/x86/highbd_inv_txfm_sse4.c b/av1/common/x86/highbd_inv_txfm_sse4.c
index d30b9be..e29e0ba 100644
--- a/av1/common/x86/highbd_inv_txfm_sse4.c
+++ b/av1/common/x86/highbd_inv_txfm_sse4.c
@@ -121,6 +121,219 @@
   *out1 = a1;
 }
 
+static INLINE void idct32_stage4_sse4_1(
+    __m128i *bf1, const __m128i *cospim8, const __m128i *cospi56,
+    const __m128i *cospi8, const __m128i *cospim56, const __m128i *cospim40,
+    const __m128i *cospi24, const __m128i *cospi40, const __m128i *cospim24,
+    const __m128i *rounding, int bit) {
+  __m128i temp1, temp2;
+  temp1 = half_btf_sse4_1(cospim8, &bf1[17], cospi56, &bf1[30], rounding, bit);
+  bf1[30] = half_btf_sse4_1(cospi56, &bf1[17], cospi8, &bf1[30], rounding, bit);
+  bf1[17] = temp1;
+
+  temp2 = half_btf_sse4_1(cospim56, &bf1[18], cospim8, &bf1[29], rounding, bit);
+  bf1[29] =
+      half_btf_sse4_1(cospim8, &bf1[18], cospi56, &bf1[29], rounding, bit);
+  bf1[18] = temp2;
+
+  temp1 = half_btf_sse4_1(cospim40, &bf1[21], cospi24, &bf1[26], rounding, bit);
+  bf1[26] =
+      half_btf_sse4_1(cospi24, &bf1[21], cospi40, &bf1[26], rounding, bit);
+  bf1[21] = temp1;
+
+  temp2 =
+      half_btf_sse4_1(cospim24, &bf1[22], cospim40, &bf1[25], rounding, bit);
+  bf1[25] =
+      half_btf_sse4_1(cospim40, &bf1[22], cospi24, &bf1[25], rounding, bit);
+  bf1[22] = temp2;
+}
+
+static INLINE void idct32_stage5_sse4_1(
+    __m128i *bf1, const __m128i *cospim16, const __m128i *cospi48,
+    const __m128i *cospi16, const __m128i *cospim48, const __m128i *clamp_lo,
+    const __m128i *clamp_hi, const __m128i *rounding, int bit) {
+  __m128i temp1, temp2;
+  temp1 = half_btf_sse4_1(cospim16, &bf1[9], cospi48, &bf1[14], rounding, bit);
+  bf1[14] = half_btf_sse4_1(cospi48, &bf1[9], cospi16, &bf1[14], rounding, bit);
+  bf1[9] = temp1;
+
+  temp2 =
+      half_btf_sse4_1(cospim48, &bf1[10], cospim16, &bf1[13], rounding, bit);
+  bf1[13] =
+      half_btf_sse4_1(cospim16, &bf1[10], cospi48, &bf1[13], rounding, bit);
+  bf1[10] = temp2;
+
+  addsub_sse4_1(bf1[16], bf1[19], bf1 + 16, bf1 + 19, clamp_lo, clamp_hi);
+  addsub_sse4_1(bf1[17], bf1[18], bf1 + 17, bf1 + 18, clamp_lo, clamp_hi);
+  addsub_sse4_1(bf1[23], bf1[20], bf1 + 23, bf1 + 20, clamp_lo, clamp_hi);
+  addsub_sse4_1(bf1[22], bf1[21], bf1 + 22, bf1 + 21, clamp_lo, clamp_hi);
+  addsub_sse4_1(bf1[24], bf1[27], bf1 + 24, bf1 + 27, clamp_lo, clamp_hi);
+  addsub_sse4_1(bf1[25], bf1[26], bf1 + 25, bf1 + 26, clamp_lo, clamp_hi);
+  addsub_sse4_1(bf1[31], bf1[28], bf1 + 31, bf1 + 28, clamp_lo, clamp_hi);
+  addsub_sse4_1(bf1[30], bf1[29], bf1 + 30, bf1 + 29, clamp_lo, clamp_hi);
+}
+
+static INLINE void idct32_stage6_sse4_1(
+    __m128i *bf1, const __m128i *cospim32, const __m128i *cospi32,
+    const __m128i *cospim16, const __m128i *cospi48, const __m128i *cospi16,
+    const __m128i *cospim48, const __m128i *clamp_lo, const __m128i *clamp_hi,
+    const __m128i *rounding, int bit) {
+  __m128i temp1, temp2;
+  temp1 = half_btf_sse4_1(cospim32, &bf1[5], cospi32, &bf1[6], rounding, bit);
+  bf1[6] = half_btf_sse4_1(cospi32, &bf1[5], cospi32, &bf1[6], rounding, bit);
+  bf1[5] = temp1;
+
+  addsub_sse4_1(bf1[8], bf1[11], bf1 + 8, bf1 + 11, clamp_lo, clamp_hi);
+  addsub_sse4_1(bf1[9], bf1[10], bf1 + 9, bf1 + 10, clamp_lo, clamp_hi);
+  addsub_sse4_1(bf1[15], bf1[12], bf1 + 15, bf1 + 12, clamp_lo, clamp_hi);
+  addsub_sse4_1(bf1[14], bf1[13], bf1 + 14, bf1 + 13, clamp_lo, clamp_hi);
+
+  temp1 = half_btf_sse4_1(cospim16, &bf1[18], cospi48, &bf1[29], rounding, bit);
+  bf1[29] =
+      half_btf_sse4_1(cospi48, &bf1[18], cospi16, &bf1[29], rounding, bit);
+  bf1[18] = temp1;
+  temp2 = half_btf_sse4_1(cospim16, &bf1[19], cospi48, &bf1[28], rounding, bit);
+  bf1[28] =
+      half_btf_sse4_1(cospi48, &bf1[19], cospi16, &bf1[28], rounding, bit);
+  bf1[19] = temp2;
+  temp1 =
+      half_btf_sse4_1(cospim48, &bf1[20], cospim16, &bf1[27], rounding, bit);
+  bf1[27] =
+      half_btf_sse4_1(cospim16, &bf1[20], cospi48, &bf1[27], rounding, bit);
+  bf1[20] = temp1;
+  temp2 =
+      half_btf_sse4_1(cospim48, &bf1[21], cospim16, &bf1[26], rounding, bit);
+  bf1[26] =
+      half_btf_sse4_1(cospim16, &bf1[21], cospi48, &bf1[26], rounding, bit);
+  bf1[21] = temp2;
+}
+
+static INLINE void idct32_stage7_sse4_1(__m128i *bf1, const __m128i *cospim32,
+                                        const __m128i *cospi32,
+                                        const __m128i *clamp_lo,
+                                        const __m128i *clamp_hi,
+                                        const __m128i *rounding, int bit) {
+  __m128i temp1, temp2;
+  addsub_sse4_1(bf1[0], bf1[7], bf1 + 0, bf1 + 7, clamp_lo, clamp_hi);
+  addsub_sse4_1(bf1[1], bf1[6], bf1 + 1, bf1 + 6, clamp_lo, clamp_hi);
+  addsub_sse4_1(bf1[2], bf1[5], bf1 + 2, bf1 + 5, clamp_lo, clamp_hi);
+  addsub_sse4_1(bf1[3], bf1[4], bf1 + 3, bf1 + 4, clamp_lo, clamp_hi);
+
+  temp1 = half_btf_sse4_1(cospim32, &bf1[10], cospi32, &bf1[13], rounding, bit);
+  bf1[13] =
+      half_btf_sse4_1(cospi32, &bf1[10], cospi32, &bf1[13], rounding, bit);
+  bf1[10] = temp1;
+  temp2 = half_btf_sse4_1(cospim32, &bf1[11], cospi32, &bf1[12], rounding, bit);
+  bf1[12] =
+      half_btf_sse4_1(cospi32, &bf1[11], cospi32, &bf1[12], rounding, bit);
+  bf1[11] = temp2;
+
+  addsub_sse4_1(bf1[16], bf1[23], bf1 + 16, bf1 + 23, clamp_lo, clamp_hi);
+  addsub_sse4_1(bf1[17], bf1[22], bf1 + 17, bf1 + 22, clamp_lo, clamp_hi);
+  addsub_sse4_1(bf1[18], bf1[21], bf1 + 18, bf1 + 21, clamp_lo, clamp_hi);
+  addsub_sse4_1(bf1[19], bf1[20], bf1 + 19, bf1 + 20, clamp_lo, clamp_hi);
+  addsub_sse4_1(bf1[31], bf1[24], bf1 + 31, bf1 + 24, clamp_lo, clamp_hi);
+  addsub_sse4_1(bf1[30], bf1[25], bf1 + 30, bf1 + 25, clamp_lo, clamp_hi);
+  addsub_sse4_1(bf1[29], bf1[26], bf1 + 29, bf1 + 26, clamp_lo, clamp_hi);
+  addsub_sse4_1(bf1[28], bf1[27], bf1 + 28, bf1 + 27, clamp_lo, clamp_hi);
+}
+
+static INLINE void idct32_stage8_sse4_1(__m128i *bf1, const __m128i *cospim32,
+                                        const __m128i *cospi32,
+                                        const __m128i *clamp_lo,
+                                        const __m128i *clamp_hi,
+                                        const __m128i *rounding, int bit) {
+  __m128i temp1, temp2;
+  addsub_sse4_1(bf1[0], bf1[15], bf1 + 0, bf1 + 15, clamp_lo, clamp_hi);
+  addsub_sse4_1(bf1[1], bf1[14], bf1 + 1, bf1 + 14, clamp_lo, clamp_hi);
+  addsub_sse4_1(bf1[2], bf1[13], bf1 + 2, bf1 + 13, clamp_lo, clamp_hi);
+  addsub_sse4_1(bf1[3], bf1[12], bf1 + 3, bf1 + 12, clamp_lo, clamp_hi);
+  addsub_sse4_1(bf1[4], bf1[11], bf1 + 4, bf1 + 11, clamp_lo, clamp_hi);
+  addsub_sse4_1(bf1[5], bf1[10], bf1 + 5, bf1 + 10, clamp_lo, clamp_hi);
+  addsub_sse4_1(bf1[6], bf1[9], bf1 + 6, bf1 + 9, clamp_lo, clamp_hi);
+  addsub_sse4_1(bf1[7], bf1[8], bf1 + 7, bf1 + 8, clamp_lo, clamp_hi);
+
+  temp1 = half_btf_sse4_1(cospim32, &bf1[20], cospi32, &bf1[27], rounding, bit);
+  bf1[27] =
+      half_btf_sse4_1(cospi32, &bf1[20], cospi32, &bf1[27], rounding, bit);
+  bf1[20] = temp1;
+  temp2 = half_btf_sse4_1(cospim32, &bf1[21], cospi32, &bf1[26], rounding, bit);
+  bf1[26] =
+      half_btf_sse4_1(cospi32, &bf1[21], cospi32, &bf1[26], rounding, bit);
+  bf1[21] = temp2;
+  temp1 = half_btf_sse4_1(cospim32, &bf1[22], cospi32, &bf1[25], rounding, bit);
+  bf1[25] =
+      half_btf_sse4_1(cospi32, &bf1[22], cospi32, &bf1[25], rounding, bit);
+  bf1[22] = temp1;
+  temp2 = half_btf_sse4_1(cospim32, &bf1[23], cospi32, &bf1[24], rounding, bit);
+  bf1[24] =
+      half_btf_sse4_1(cospi32, &bf1[23], cospi32, &bf1[24], rounding, bit);
+  bf1[23] = temp2;
+}
+
+static INLINE void idct32_stage9_sse4_1(__m128i *bf1, __m128i *out,
+                                        const int do_cols, const int bd,
+                                        const int out_shift,
+                                        const int log_range) {
+  if (do_cols) {
+    addsub_no_clamp_sse4_1(bf1[0], bf1[31], out + 0, out + 31);
+    addsub_no_clamp_sse4_1(bf1[1], bf1[30], out + 1, out + 30);
+    addsub_no_clamp_sse4_1(bf1[2], bf1[29], out + 2, out + 29);
+    addsub_no_clamp_sse4_1(bf1[3], bf1[28], out + 3, out + 28);
+    addsub_no_clamp_sse4_1(bf1[4], bf1[27], out + 4, out + 27);
+    addsub_no_clamp_sse4_1(bf1[5], bf1[26], out + 5, out + 26);
+    addsub_no_clamp_sse4_1(bf1[6], bf1[25], out + 6, out + 25);
+    addsub_no_clamp_sse4_1(bf1[7], bf1[24], out + 7, out + 24);
+    addsub_no_clamp_sse4_1(bf1[8], bf1[23], out + 8, out + 23);
+    addsub_no_clamp_sse4_1(bf1[9], bf1[22], out + 9, out + 22);
+    addsub_no_clamp_sse4_1(bf1[10], bf1[21], out + 10, out + 21);
+    addsub_no_clamp_sse4_1(bf1[11], bf1[20], out + 11, out + 20);
+    addsub_no_clamp_sse4_1(bf1[12], bf1[19], out + 12, out + 19);
+    addsub_no_clamp_sse4_1(bf1[13], bf1[18], out + 13, out + 18);
+    addsub_no_clamp_sse4_1(bf1[14], bf1[17], out + 14, out + 17);
+    addsub_no_clamp_sse4_1(bf1[15], bf1[16], out + 15, out + 16);
+  } else {
+    const int log_range_out = AOMMAX(16, bd + 6);
+    const __m128i clamp_lo_out = _mm_set1_epi32(AOMMAX(
+        -(1 << (log_range_out - 1)), -(1 << (log_range - 1 - out_shift))));
+    const __m128i clamp_hi_out = _mm_set1_epi32(AOMMIN(
+        (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift))));
+
+    addsub_shift_sse4_1(bf1[0], bf1[31], out + 0, out + 31, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf1[1], bf1[30], out + 1, out + 30, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf1[2], bf1[29], out + 2, out + 29, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf1[3], bf1[28], out + 3, out + 28, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf1[4], bf1[27], out + 4, out + 27, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf1[5], bf1[26], out + 5, out + 26, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf1[6], bf1[25], out + 6, out + 25, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf1[7], bf1[24], out + 7, out + 24, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf1[8], bf1[23], out + 8, out + 23, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf1[9], bf1[22], out + 9, out + 22, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf1[10], bf1[21], out + 10, out + 21, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf1[11], bf1[20], out + 11, out + 20, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf1[12], bf1[19], out + 12, out + 19, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf1[13], bf1[18], out + 13, out + 18, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf1[14], bf1[17], out + 14, out + 17, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf1[15], bf1[16], out + 15, out + 16, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+  }
+}
+
 static void neg_shift_sse4_1(const __m128i in0, const __m128i in1,
                              __m128i *out0, __m128i *out1,
                              const __m128i *clamp_lo, const __m128i *clamp_hi,
@@ -4027,6 +4240,751 @@
   }
 }
 
+static void idct32x32_low1_sse4_1(__m128i *in, __m128i *out, int bit,
+                                  int do_cols, int bd, int out_shift) {
+  const int32_t *cospi = cospi_arr(bit);
+  const __m128i cospi32 = _mm_set1_epi32(cospi[32]);
+  const __m128i rounding = _mm_set1_epi32(1 << (bit - 1));
+  const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8));
+  const __m128i clamp_lo = _mm_set1_epi32(-(1 << (log_range - 1)));
+  const __m128i clamp_hi = _mm_set1_epi32((1 << (log_range - 1)) - 1);
+  __m128i bf1;
+
+  // stage 0
+  // stage 1
+  bf1 = in[0];
+
+  // stage 2
+  // stage 3
+  // stage 4
+  // stage 5
+  bf1 = half_btf_0_sse4_1(&cospi32, &bf1, &rounding, bit);
+
+  // stage 6
+  // stage 7
+  // stage 8
+  // stage 9
+  if (do_cols) {
+    bf1 = _mm_max_epi32(bf1, clamp_lo);
+    bf1 = _mm_min_epi32(bf1, clamp_hi);
+  } else {
+    const int log_range_out = AOMMAX(16, bd + 6);
+    const __m128i clamp_lo_out = _mm_set1_epi32(AOMMAX(
+        -(1 << (log_range_out - 1)), -(1 << (log_range - 1 - out_shift))));
+    const __m128i clamp_hi_out = _mm_set1_epi32(AOMMIN(
+        (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift))));
+
+    __m128i offset = _mm_set1_epi32((1 << out_shift) >> 1);
+    bf1 = _mm_add_epi32(bf1, offset);
+    bf1 = _mm_sra_epi32(bf1, _mm_cvtsi32_si128(out_shift));
+    bf1 = _mm_max_epi32(bf1, clamp_lo_out);
+    bf1 = _mm_min_epi32(bf1, clamp_hi_out);
+  }
+  out[0] = bf1;
+  out[1] = bf1;
+  out[2] = bf1;
+  out[3] = bf1;
+  out[4] = bf1;
+  out[5] = bf1;
+  out[6] = bf1;
+  out[7] = bf1;
+  out[8] = bf1;
+  out[9] = bf1;
+  out[10] = bf1;
+  out[11] = bf1;
+  out[12] = bf1;
+  out[13] = bf1;
+  out[14] = bf1;
+  out[15] = bf1;
+  out[16] = bf1;
+  out[17] = bf1;
+  out[18] = bf1;
+  out[19] = bf1;
+  out[20] = bf1;
+  out[21] = bf1;
+  out[22] = bf1;
+  out[23] = bf1;
+  out[24] = bf1;
+  out[25] = bf1;
+  out[26] = bf1;
+  out[27] = bf1;
+  out[28] = bf1;
+  out[29] = bf1;
+  out[30] = bf1;
+  out[31] = bf1;
+}
+
+static void idct32x32_low8_sse4_1(__m128i *in, __m128i *out, int bit,
+                                  int do_cols, int bd, int out_shift) {
+  const int32_t *cospi = cospi_arr(bit);
+  const __m128i cospi62 = _mm_set1_epi32(cospi[62]);
+  const __m128i cospi14 = _mm_set1_epi32(cospi[14]);
+  const __m128i cospi54 = _mm_set1_epi32(cospi[54]);
+  const __m128i cospi6 = _mm_set1_epi32(cospi[6]);
+  const __m128i cospi10 = _mm_set1_epi32(cospi[10]);
+  const __m128i cospi2 = _mm_set1_epi32(cospi[2]);
+  const __m128i cospim58 = _mm_set1_epi32(-cospi[58]);
+  const __m128i cospim50 = _mm_set1_epi32(-cospi[50]);
+  const __m128i cospi60 = _mm_set1_epi32(cospi[60]);
+  const __m128i cospi12 = _mm_set1_epi32(cospi[12]);
+  const __m128i cospi4 = _mm_set1_epi32(cospi[4]);
+  const __m128i cospim52 = _mm_set1_epi32(-cospi[52]);
+  const __m128i cospi56 = _mm_set1_epi32(cospi[56]);
+  const __m128i cospi24 = _mm_set1_epi32(cospi[24]);
+  const __m128i cospi40 = _mm_set1_epi32(cospi[40]);
+  const __m128i cospi8 = _mm_set1_epi32(cospi[8]);
+  const __m128i cospim40 = _mm_set1_epi32(-cospi[40]);
+  const __m128i cospim8 = _mm_set1_epi32(-cospi[8]);
+  const __m128i cospim56 = _mm_set1_epi32(-cospi[56]);
+  const __m128i cospim24 = _mm_set1_epi32(-cospi[24]);
+  const __m128i cospi32 = _mm_set1_epi32(cospi[32]);
+  const __m128i cospim32 = _mm_set1_epi32(-cospi[32]);
+  const __m128i cospi48 = _mm_set1_epi32(cospi[48]);
+  const __m128i cospim48 = _mm_set1_epi32(-cospi[48]);
+  const __m128i cospi16 = _mm_set1_epi32(cospi[16]);
+  const __m128i cospim16 = _mm_set1_epi32(-cospi[16]);
+  const __m128i rounding = _mm_set1_epi32(1 << (bit - 1));
+  const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8));
+  const __m128i clamp_lo = _mm_set1_epi32(-(1 << (log_range - 1)));
+  const __m128i clamp_hi = _mm_set1_epi32((1 << (log_range - 1)) - 1);
+  __m128i bf1[32];
+
+  // stage 0
+  // stage 1
+  bf1[0] = in[0];
+  bf1[4] = in[4];
+  bf1[8] = in[2];
+  bf1[12] = in[6];
+  bf1[16] = in[1];
+  bf1[20] = in[5];
+  bf1[24] = in[3];
+  bf1[28] = in[7];
+
+  // stage 2
+  bf1[31] = half_btf_0_sse4_1(&cospi2, &bf1[16], &rounding, bit);
+  bf1[16] = half_btf_0_sse4_1(&cospi62, &bf1[16], &rounding, bit);
+  bf1[19] = half_btf_0_sse4_1(&cospim50, &bf1[28], &rounding, bit);
+  bf1[28] = half_btf_0_sse4_1(&cospi14, &bf1[28], &rounding, bit);
+  bf1[27] = half_btf_0_sse4_1(&cospi10, &bf1[20], &rounding, bit);
+  bf1[20] = half_btf_0_sse4_1(&cospi54, &bf1[20], &rounding, bit);
+  bf1[23] = half_btf_0_sse4_1(&cospim58, &bf1[24], &rounding, bit);
+  bf1[24] = half_btf_0_sse4_1(&cospi6, &bf1[24], &rounding, bit);
+
+  // stage 3
+  bf1[15] = half_btf_0_sse4_1(&cospi4, &bf1[8], &rounding, bit);
+  bf1[8] = half_btf_0_sse4_1(&cospi60, &bf1[8], &rounding, bit);
+
+  bf1[11] = half_btf_0_sse4_1(&cospim52, &bf1[12], &rounding, bit);
+  bf1[12] = half_btf_0_sse4_1(&cospi12, &bf1[12], &rounding, bit);
+  bf1[17] = bf1[16];
+  bf1[18] = bf1[19];
+  bf1[21] = bf1[20];
+  bf1[22] = bf1[23];
+  bf1[25] = bf1[24];
+  bf1[26] = bf1[27];
+  bf1[29] = bf1[28];
+  bf1[30] = bf1[31];
+
+  // stage 4 :
+  bf1[7] = half_btf_0_sse4_1(&cospi8, &bf1[4], &rounding, bit);
+  bf1[4] = half_btf_0_sse4_1(&cospi56, &bf1[4], &rounding, bit);
+
+  bf1[9] = bf1[8];
+  bf1[10] = bf1[11];
+  bf1[13] = bf1[12];
+  bf1[14] = bf1[15];
+
+  idct32_stage4_sse4_1(bf1, &cospim8, &cospi56, &cospi8, &cospim56, &cospim40,
+                       &cospi24, &cospi40, &cospim24, &rounding, bit);
+
+  // stage 5
+  bf1[0] = half_btf_0_sse4_1(&cospi32, &bf1[0], &rounding, bit);
+  bf1[1] = bf1[0];
+  bf1[5] = bf1[4];
+  bf1[6] = bf1[7];
+
+  idct32_stage5_sse4_1(bf1, &cospim16, &cospi48, &cospi16, &cospim48, &clamp_lo,
+                       &clamp_hi, &rounding, bit);
+
+  // stage 6
+  bf1[3] = bf1[0];
+  bf1[2] = bf1[1];
+
+  idct32_stage6_sse4_1(bf1, &cospim32, &cospi32, &cospim16, &cospi48, &cospi16,
+                       &cospim48, &clamp_lo, &clamp_hi, &rounding, bit);
+
+  // stage 7
+  idct32_stage7_sse4_1(bf1, &cospim32, &cospi32, &clamp_lo, &clamp_hi,
+                       &rounding, bit);
+
+  // stage 8
+  idct32_stage8_sse4_1(bf1, &cospim32, &cospi32, &clamp_lo, &clamp_hi,
+                       &rounding, bit);
+
+  // stage 9
+  idct32_stage9_sse4_1(bf1, out, do_cols, bd, out_shift, log_range);
+}
+
+static void idct32x32_low16_sse4_1(__m128i *in, __m128i *out, int bit,
+                                   int do_cols, int bd, int out_shift) {
+  const int32_t *cospi = cospi_arr(bit);
+  const __m128i cospi62 = _mm_set1_epi32(cospi[62]);
+  const __m128i cospi30 = _mm_set1_epi32(cospi[30]);
+  const __m128i cospi46 = _mm_set1_epi32(cospi[46]);
+  const __m128i cospi14 = _mm_set1_epi32(cospi[14]);
+  const __m128i cospi54 = _mm_set1_epi32(cospi[54]);
+  const __m128i cospi22 = _mm_set1_epi32(cospi[22]);
+  const __m128i cospi38 = _mm_set1_epi32(cospi[38]);
+  const __m128i cospi6 = _mm_set1_epi32(cospi[6]);
+  const __m128i cospi26 = _mm_set1_epi32(cospi[26]);
+  const __m128i cospi10 = _mm_set1_epi32(cospi[10]);
+  const __m128i cospi18 = _mm_set1_epi32(cospi[18]);
+  const __m128i cospi2 = _mm_set1_epi32(cospi[2]);
+  const __m128i cospim58 = _mm_set1_epi32(-cospi[58]);
+  const __m128i cospim42 = _mm_set1_epi32(-cospi[42]);
+  const __m128i cospim50 = _mm_set1_epi32(-cospi[50]);
+  const __m128i cospim34 = _mm_set1_epi32(-cospi[34]);
+  const __m128i cospi60 = _mm_set1_epi32(cospi[60]);
+  const __m128i cospi28 = _mm_set1_epi32(cospi[28]);
+  const __m128i cospi44 = _mm_set1_epi32(cospi[44]);
+  const __m128i cospi12 = _mm_set1_epi32(cospi[12]);
+  const __m128i cospi20 = _mm_set1_epi32(cospi[20]);
+  const __m128i cospi4 = _mm_set1_epi32(cospi[4]);
+  const __m128i cospim52 = _mm_set1_epi32(-cospi[52]);
+  const __m128i cospim36 = _mm_set1_epi32(-cospi[36]);
+  const __m128i cospi56 = _mm_set1_epi32(cospi[56]);
+  const __m128i cospi24 = _mm_set1_epi32(cospi[24]);
+  const __m128i cospi40 = _mm_set1_epi32(cospi[40]);
+  const __m128i cospi8 = _mm_set1_epi32(cospi[8]);
+  const __m128i cospim40 = _mm_set1_epi32(-cospi[40]);
+  const __m128i cospim8 = _mm_set1_epi32(-cospi[8]);
+  const __m128i cospim56 = _mm_set1_epi32(-cospi[56]);
+  const __m128i cospim24 = _mm_set1_epi32(-cospi[24]);
+  const __m128i cospi32 = _mm_set1_epi32(cospi[32]);
+  const __m128i cospim32 = _mm_set1_epi32(-cospi[32]);
+  const __m128i cospi48 = _mm_set1_epi32(cospi[48]);
+  const __m128i cospim48 = _mm_set1_epi32(-cospi[48]);
+  const __m128i cospi16 = _mm_set1_epi32(cospi[16]);
+  const __m128i cospim16 = _mm_set1_epi32(-cospi[16]);
+  const __m128i rounding = _mm_set1_epi32(1 << (bit - 1));
+  const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8));
+  const __m128i clamp_lo = _mm_set1_epi32(-(1 << (log_range - 1)));
+  const __m128i clamp_hi = _mm_set1_epi32((1 << (log_range - 1)) - 1);
+  __m128i bf1[32];
+
+  // stage 0
+  // stage 1
+
+  bf1[0] = in[0];
+  bf1[2] = in[8];
+  bf1[4] = in[4];
+  bf1[6] = in[12];
+  bf1[8] = in[2];
+  bf1[10] = in[10];
+  bf1[12] = in[6];
+  bf1[14] = in[14];
+  bf1[16] = in[1];
+  bf1[18] = in[9];
+  bf1[20] = in[5];
+  bf1[22] = in[13];
+  bf1[24] = in[3];
+  bf1[26] = in[11];
+  bf1[28] = in[7];
+  bf1[30] = in[15];
+
+  // stage 2
+  bf1[31] = half_btf_0_sse4_1(&cospi2, &bf1[16], &rounding, bit);
+  bf1[16] = half_btf_0_sse4_1(&cospi62, &bf1[16], &rounding, bit);
+  bf1[17] = half_btf_0_sse4_1(&cospim34, &bf1[30], &rounding, bit);
+  bf1[30] = half_btf_0_sse4_1(&cospi30, &bf1[30], &rounding, bit);
+  bf1[29] = half_btf_0_sse4_1(&cospi18, &bf1[18], &rounding, bit);
+  bf1[18] = half_btf_0_sse4_1(&cospi46, &bf1[18], &rounding, bit);
+  bf1[19] = half_btf_0_sse4_1(&cospim50, &bf1[28], &rounding, bit);
+  bf1[28] = half_btf_0_sse4_1(&cospi14, &bf1[28], &rounding, bit);
+  bf1[27] = half_btf_0_sse4_1(&cospi10, &bf1[20], &rounding, bit);
+  bf1[20] = half_btf_0_sse4_1(&cospi54, &bf1[20], &rounding, bit);
+  bf1[21] = half_btf_0_sse4_1(&cospim42, &bf1[26], &rounding, bit);
+  bf1[26] = half_btf_0_sse4_1(&cospi22, &bf1[26], &rounding, bit);
+  bf1[25] = half_btf_0_sse4_1(&cospi26, &bf1[22], &rounding, bit);
+  bf1[22] = half_btf_0_sse4_1(&cospi38, &bf1[22], &rounding, bit);
+  bf1[23] = half_btf_0_sse4_1(&cospim58, &bf1[24], &rounding, bit);
+  bf1[24] = half_btf_0_sse4_1(&cospi6, &bf1[24], &rounding, bit);
+
+  // stage 3
+  bf1[15] = half_btf_0_sse4_1(&cospi4, &bf1[8], &rounding, bit);
+  bf1[8] = half_btf_0_sse4_1(&cospi60, &bf1[8], &rounding, bit);
+  bf1[9] = half_btf_0_sse4_1(&cospim36, &bf1[14], &rounding, bit);
+  bf1[14] = half_btf_0_sse4_1(&cospi28, &bf1[14], &rounding, bit);
+  bf1[13] = half_btf_0_sse4_1(&cospi20, &bf1[10], &rounding, bit);
+  bf1[10] = half_btf_0_sse4_1(&cospi44, &bf1[10], &rounding, bit);
+  bf1[11] = half_btf_0_sse4_1(&cospim52, &bf1[12], &rounding, bit);
+  bf1[12] = half_btf_0_sse4_1(&cospi12, &bf1[12], &rounding, bit);
+
+  addsub_sse4_1(bf1[16], bf1[17], bf1 + 16, bf1 + 17, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf1[19], bf1[18], bf1 + 19, bf1 + 18, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf1[20], bf1[21], bf1 + 20, bf1 + 21, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf1[23], bf1[22], bf1 + 23, bf1 + 22, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf1[24], bf1[25], bf1 + 24, bf1 + 25, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf1[27], bf1[26], bf1 + 27, bf1 + 26, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf1[28], bf1[29], bf1 + 28, bf1 + 29, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf1[31], bf1[30], bf1 + 31, bf1 + 30, &clamp_lo, &clamp_hi);
+  // stage 4
+  bf1[7] = half_btf_0_sse4_1(&cospi8, &bf1[4], &rounding, bit);
+  bf1[4] = half_btf_0_sse4_1(&cospi56, &bf1[4], &rounding, bit);
+  bf1[5] = half_btf_0_sse4_1(&cospim40, &bf1[6], &rounding, bit);
+  bf1[6] = half_btf_0_sse4_1(&cospi24, &bf1[6], &rounding, bit);
+
+  addsub_sse4_1(bf1[8], bf1[9], bf1 + 8, bf1 + 9, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf1[11], bf1[10], bf1 + 11, bf1 + 10, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf1[12], bf1[13], bf1 + 12, bf1 + 13, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf1[15], bf1[14], bf1 + 15, bf1 + 14, &clamp_lo, &clamp_hi);
+
+  idct32_stage4_sse4_1(bf1, &cospim8, &cospi56, &cospi8, &cospim56, &cospim40,
+                       &cospi24, &cospi40, &cospim24, &rounding, bit);
+
+  // stage 5
+  bf1[0] = half_btf_0_sse4_1(&cospi32, &bf1[0], &rounding, bit);
+  bf1[1] = bf1[0];
+  bf1[3] = half_btf_0_sse4_1(&cospi16, &bf1[2], &rounding, bit);
+  bf1[2] = half_btf_0_sse4_1(&cospi48, &bf1[2], &rounding, bit);
+
+  addsub_sse4_1(bf1[4], bf1[5], bf1 + 4, bf1 + 5, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf1[7], bf1[6], bf1 + 7, bf1 + 6, &clamp_lo, &clamp_hi);
+
+  idct32_stage5_sse4_1(bf1, &cospim16, &cospi48, &cospi16, &cospim48, &clamp_lo,
+                       &clamp_hi, &rounding, bit);
+
+  // stage 6
+  addsub_sse4_1(bf1[0], bf1[3], bf1 + 0, bf1 + 3, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf1[1], bf1[2], bf1 + 1, bf1 + 2, &clamp_lo, &clamp_hi);
+
+  idct32_stage6_sse4_1(bf1, &cospim32, &cospi32, &cospim16, &cospi48, &cospi16,
+                       &cospim48, &clamp_lo, &clamp_hi, &rounding, bit);
+
+  // stage 7
+  idct32_stage7_sse4_1(bf1, &cospim32, &cospi32, &clamp_lo, &clamp_hi,
+                       &rounding, bit);
+
+  // stage 8
+  idct32_stage8_sse4_1(bf1, &cospim32, &cospi32, &clamp_lo, &clamp_hi,
+                       &rounding, bit);
+
+  // stage 9
+  idct32_stage9_sse4_1(bf1, out, do_cols, bd, out_shift, log_range);
+}
+
+static void idct32x32_sse4_1(__m128i *in, __m128i *out, int bit, int do_cols,
+                             int bd, int out_shift) {
+  const int32_t *cospi = cospi_arr(bit);
+  const __m128i cospi62 = _mm_set1_epi32(cospi[62]);
+  const __m128i cospi30 = _mm_set1_epi32(cospi[30]);
+  const __m128i cospi46 = _mm_set1_epi32(cospi[46]);
+  const __m128i cospi14 = _mm_set1_epi32(cospi[14]);
+  const __m128i cospi54 = _mm_set1_epi32(cospi[54]);
+  const __m128i cospi22 = _mm_set1_epi32(cospi[22]);
+  const __m128i cospi38 = _mm_set1_epi32(cospi[38]);
+  const __m128i cospi6 = _mm_set1_epi32(cospi[6]);
+  const __m128i cospi58 = _mm_set1_epi32(cospi[58]);
+  const __m128i cospi26 = _mm_set1_epi32(cospi[26]);
+  const __m128i cospi42 = _mm_set1_epi32(cospi[42]);
+  const __m128i cospi10 = _mm_set1_epi32(cospi[10]);
+  const __m128i cospi50 = _mm_set1_epi32(cospi[50]);
+  const __m128i cospi18 = _mm_set1_epi32(cospi[18]);
+  const __m128i cospi34 = _mm_set1_epi32(cospi[34]);
+  const __m128i cospi2 = _mm_set1_epi32(cospi[2]);
+  const __m128i cospim58 = _mm_set1_epi32(-cospi[58]);
+  const __m128i cospim26 = _mm_set1_epi32(-cospi[26]);
+  const __m128i cospim42 = _mm_set1_epi32(-cospi[42]);
+  const __m128i cospim10 = _mm_set1_epi32(-cospi[10]);
+  const __m128i cospim50 = _mm_set1_epi32(-cospi[50]);
+  const __m128i cospim18 = _mm_set1_epi32(-cospi[18]);
+  const __m128i cospim34 = _mm_set1_epi32(-cospi[34]);
+  const __m128i cospim2 = _mm_set1_epi32(-cospi[2]);
+  const __m128i cospi60 = _mm_set1_epi32(cospi[60]);
+  const __m128i cospi28 = _mm_set1_epi32(cospi[28]);
+  const __m128i cospi44 = _mm_set1_epi32(cospi[44]);
+  const __m128i cospi12 = _mm_set1_epi32(cospi[12]);
+  const __m128i cospi52 = _mm_set1_epi32(cospi[52]);
+  const __m128i cospi20 = _mm_set1_epi32(cospi[20]);
+  const __m128i cospi36 = _mm_set1_epi32(cospi[36]);
+  const __m128i cospi4 = _mm_set1_epi32(cospi[4]);
+  const __m128i cospim52 = _mm_set1_epi32(-cospi[52]);
+  const __m128i cospim20 = _mm_set1_epi32(-cospi[20]);
+  const __m128i cospim36 = _mm_set1_epi32(-cospi[36]);
+  const __m128i cospim4 = _mm_set1_epi32(-cospi[4]);
+  const __m128i cospi56 = _mm_set1_epi32(cospi[56]);
+  const __m128i cospi24 = _mm_set1_epi32(cospi[24]);
+  const __m128i cospi40 = _mm_set1_epi32(cospi[40]);
+  const __m128i cospi8 = _mm_set1_epi32(cospi[8]);
+  const __m128i cospim40 = _mm_set1_epi32(-cospi[40]);
+  const __m128i cospim8 = _mm_set1_epi32(-cospi[8]);
+  const __m128i cospim56 = _mm_set1_epi32(-cospi[56]);
+  const __m128i cospim24 = _mm_set1_epi32(-cospi[24]);
+  const __m128i cospi32 = _mm_set1_epi32(cospi[32]);
+  const __m128i cospim32 = _mm_set1_epi32(-cospi[32]);
+  const __m128i cospi48 = _mm_set1_epi32(cospi[48]);
+  const __m128i cospim48 = _mm_set1_epi32(-cospi[48]);
+  const __m128i cospi16 = _mm_set1_epi32(cospi[16]);
+  const __m128i cospim16 = _mm_set1_epi32(-cospi[16]);
+  const __m128i rounding = _mm_set1_epi32(1 << (bit - 1));
+  const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8));
+  const __m128i clamp_lo = _mm_set1_epi32(-(1 << (log_range - 1)));
+  const __m128i clamp_hi = _mm_set1_epi32((1 << (log_range - 1)) - 1);
+  __m128i bf1[32], bf0[32];
+
+  // stage 0
+  // stage 1
+  bf1[0] = in[0];
+  bf1[1] = in[16];
+  bf1[2] = in[8];
+  bf1[3] = in[24];
+  bf1[4] = in[4];
+  bf1[5] = in[20];
+  bf1[6] = in[12];
+  bf1[7] = in[28];
+  bf1[8] = in[2];
+  bf1[9] = in[18];
+  bf1[10] = in[10];
+  bf1[11] = in[26];
+  bf1[12] = in[6];
+  bf1[13] = in[22];
+  bf1[14] = in[14];
+  bf1[15] = in[30];
+  bf1[16] = in[1];
+  bf1[17] = in[17];
+  bf1[18] = in[9];
+  bf1[19] = in[25];
+  bf1[20] = in[5];
+  bf1[21] = in[21];
+  bf1[22] = in[13];
+  bf1[23] = in[29];
+  bf1[24] = in[3];
+  bf1[25] = in[19];
+  bf1[26] = in[11];
+  bf1[27] = in[27];
+  bf1[28] = in[7];
+  bf1[29] = in[23];
+  bf1[30] = in[15];
+  bf1[31] = in[31];
+
+  // stage 2
+  bf0[0] = bf1[0];
+  bf0[1] = bf1[1];
+  bf0[2] = bf1[2];
+  bf0[3] = bf1[3];
+  bf0[4] = bf1[4];
+  bf0[5] = bf1[5];
+  bf0[6] = bf1[6];
+  bf0[7] = bf1[7];
+  bf0[8] = bf1[8];
+  bf0[9] = bf1[9];
+  bf0[10] = bf1[10];
+  bf0[11] = bf1[11];
+  bf0[12] = bf1[12];
+  bf0[13] = bf1[13];
+  bf0[14] = bf1[14];
+  bf0[15] = bf1[15];
+  bf0[16] =
+      half_btf_sse4_1(&cospi62, &bf1[16], &cospim2, &bf1[31], &rounding, bit);
+  bf0[17] =
+      half_btf_sse4_1(&cospi30, &bf1[17], &cospim34, &bf1[30], &rounding, bit);
+  bf0[18] =
+      half_btf_sse4_1(&cospi46, &bf1[18], &cospim18, &bf1[29], &rounding, bit);
+  bf0[19] =
+      half_btf_sse4_1(&cospi14, &bf1[19], &cospim50, &bf1[28], &rounding, bit);
+  bf0[20] =
+      half_btf_sse4_1(&cospi54, &bf1[20], &cospim10, &bf1[27], &rounding, bit);
+  bf0[21] =
+      half_btf_sse4_1(&cospi22, &bf1[21], &cospim42, &bf1[26], &rounding, bit);
+  bf0[22] =
+      half_btf_sse4_1(&cospi38, &bf1[22], &cospim26, &bf1[25], &rounding, bit);
+  bf0[23] =
+      half_btf_sse4_1(&cospi6, &bf1[23], &cospim58, &bf1[24], &rounding, bit);
+  bf0[24] =
+      half_btf_sse4_1(&cospi58, &bf1[23], &cospi6, &bf1[24], &rounding, bit);
+  bf0[25] =
+      half_btf_sse4_1(&cospi26, &bf1[22], &cospi38, &bf1[25], &rounding, bit);
+  bf0[26] =
+      half_btf_sse4_1(&cospi42, &bf1[21], &cospi22, &bf1[26], &rounding, bit);
+  bf0[27] =
+      half_btf_sse4_1(&cospi10, &bf1[20], &cospi54, &bf1[27], &rounding, bit);
+  bf0[28] =
+      half_btf_sse4_1(&cospi50, &bf1[19], &cospi14, &bf1[28], &rounding, bit);
+  bf0[29] =
+      half_btf_sse4_1(&cospi18, &bf1[18], &cospi46, &bf1[29], &rounding, bit);
+  bf0[30] =
+      half_btf_sse4_1(&cospi34, &bf1[17], &cospi30, &bf1[30], &rounding, bit);
+  bf0[31] =
+      half_btf_sse4_1(&cospi2, &bf1[16], &cospi62, &bf1[31], &rounding, bit);
+
+  // stage 3
+  bf1[0] = bf0[0];
+  bf1[1] = bf0[1];
+  bf1[2] = bf0[2];
+  bf1[3] = bf0[3];
+  bf1[4] = bf0[4];
+  bf1[5] = bf0[5];
+  bf1[6] = bf0[6];
+  bf1[7] = bf0[7];
+  bf1[8] =
+      half_btf_sse4_1(&cospi60, &bf0[8], &cospim4, &bf0[15], &rounding, bit);
+  bf1[9] =
+      half_btf_sse4_1(&cospi28, &bf0[9], &cospim36, &bf0[14], &rounding, bit);
+  bf1[10] =
+      half_btf_sse4_1(&cospi44, &bf0[10], &cospim20, &bf0[13], &rounding, bit);
+  bf1[11] =
+      half_btf_sse4_1(&cospi12, &bf0[11], &cospim52, &bf0[12], &rounding, bit);
+  bf1[12] =
+      half_btf_sse4_1(&cospi52, &bf0[11], &cospi12, &bf0[12], &rounding, bit);
+  bf1[13] =
+      half_btf_sse4_1(&cospi20, &bf0[10], &cospi44, &bf0[13], &rounding, bit);
+  bf1[14] =
+      half_btf_sse4_1(&cospi36, &bf0[9], &cospi28, &bf0[14], &rounding, bit);
+  bf1[15] =
+      half_btf_sse4_1(&cospi4, &bf0[8], &cospi60, &bf0[15], &rounding, bit);
+
+  addsub_sse4_1(bf0[16], bf0[17], bf1 + 16, bf1 + 17, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf0[19], bf0[18], bf1 + 19, bf1 + 18, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf0[20], bf0[21], bf1 + 20, bf1 + 21, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf0[23], bf0[22], bf1 + 23, bf1 + 22, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf0[24], bf0[25], bf1 + 24, bf1 + 25, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf0[27], bf0[26], bf1 + 27, bf1 + 26, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf0[28], bf0[29], bf1 + 28, bf1 + 29, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf0[31], bf0[30], bf1 + 31, bf1 + 30, &clamp_lo, &clamp_hi);
+
+  // stage 4
+  bf0[0] = bf1[0];
+  bf0[1] = bf1[1];
+  bf0[2] = bf1[2];
+  bf0[3] = bf1[3];
+  bf0[4] =
+      half_btf_sse4_1(&cospi56, &bf1[4], &cospim8, &bf1[7], &rounding, bit);
+  bf0[5] =
+      half_btf_sse4_1(&cospi24, &bf1[5], &cospim40, &bf1[6], &rounding, bit);
+  bf0[6] =
+      half_btf_sse4_1(&cospi40, &bf1[5], &cospi24, &bf1[6], &rounding, bit);
+  bf0[7] = half_btf_sse4_1(&cospi8, &bf1[4], &cospi56, &bf1[7], &rounding, bit);
+
+  addsub_sse4_1(bf1[8], bf1[9], bf0 + 8, bf0 + 9, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf1[11], bf1[10], bf0 + 11, bf0 + 10, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf1[12], bf1[13], bf0 + 12, bf0 + 13, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf1[15], bf1[14], bf0 + 15, bf0 + 14, &clamp_lo, &clamp_hi);
+
+  bf0[16] = bf1[16];
+  bf0[17] =
+      half_btf_sse4_1(&cospim8, &bf1[17], &cospi56, &bf1[30], &rounding, bit);
+  bf0[18] =
+      half_btf_sse4_1(&cospim56, &bf1[18], &cospim8, &bf1[29], &rounding, bit);
+  bf0[19] = bf1[19];
+  bf0[20] = bf1[20];
+  bf0[21] =
+      half_btf_sse4_1(&cospim40, &bf1[21], &cospi24, &bf1[26], &rounding, bit);
+  bf0[22] =
+      half_btf_sse4_1(&cospim24, &bf1[22], &cospim40, &bf1[25], &rounding, bit);
+  bf0[23] = bf1[23];
+  bf0[24] = bf1[24];
+  bf0[25] =
+      half_btf_sse4_1(&cospim40, &bf1[22], &cospi24, &bf1[25], &rounding, bit);
+  bf0[26] =
+      half_btf_sse4_1(&cospi24, &bf1[21], &cospi40, &bf1[26], &rounding, bit);
+  bf0[27] = bf1[27];
+  bf0[28] = bf1[28];
+  bf0[29] =
+      half_btf_sse4_1(&cospim8, &bf1[18], &cospi56, &bf1[29], &rounding, bit);
+  bf0[30] =
+      half_btf_sse4_1(&cospi56, &bf1[17], &cospi8, &bf1[30], &rounding, bit);
+  bf0[31] = bf1[31];
+
+  // stage 5
+  bf1[0] =
+      half_btf_sse4_1(&cospi32, &bf0[0], &cospi32, &bf0[1], &rounding, bit);
+  bf1[1] =
+      half_btf_sse4_1(&cospi32, &bf0[0], &cospim32, &bf0[1], &rounding, bit);
+  bf1[2] =
+      half_btf_sse4_1(&cospi48, &bf0[2], &cospim16, &bf0[3], &rounding, bit);
+  bf1[3] =
+      half_btf_sse4_1(&cospi16, &bf0[2], &cospi48, &bf0[3], &rounding, bit);
+  addsub_sse4_1(bf0[4], bf0[5], bf1 + 4, bf1 + 5, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf0[7], bf0[6], bf1 + 7, bf1 + 6, &clamp_lo, &clamp_hi);
+  bf1[8] = bf0[8];
+  bf1[9] =
+      half_btf_sse4_1(&cospim16, &bf0[9], &cospi48, &bf0[14], &rounding, bit);
+  bf1[10] =
+      half_btf_sse4_1(&cospim48, &bf0[10], &cospim16, &bf0[13], &rounding, bit);
+  bf1[11] = bf0[11];
+  bf1[12] = bf0[12];
+  bf1[13] =
+      half_btf_sse4_1(&cospim16, &bf0[10], &cospi48, &bf0[13], &rounding, bit);
+  bf1[14] =
+      half_btf_sse4_1(&cospi48, &bf0[9], &cospi16, &bf0[14], &rounding, bit);
+  bf1[15] = bf0[15];
+  addsub_sse4_1(bf0[16], bf0[19], bf1 + 16, bf1 + 19, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf0[17], bf0[18], bf1 + 17, bf1 + 18, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf0[23], bf0[20], bf1 + 23, bf1 + 20, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf0[22], bf0[21], bf1 + 22, bf1 + 21, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf0[24], bf0[27], bf1 + 24, bf1 + 27, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf0[25], bf0[26], bf1 + 25, bf1 + 26, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf0[31], bf0[28], bf1 + 31, bf1 + 28, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf0[30], bf0[29], bf1 + 30, bf1 + 29, &clamp_lo, &clamp_hi);
+
+  // stage 6
+  addsub_sse4_1(bf1[0], bf1[3], bf0 + 0, bf0 + 3, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf1[1], bf1[2], bf0 + 1, bf0 + 2, &clamp_lo, &clamp_hi);
+  bf0[4] = bf1[4];
+  bf0[5] =
+      half_btf_sse4_1(&cospim32, &bf1[5], &cospi32, &bf1[6], &rounding, bit);
+  bf0[6] =
+      half_btf_sse4_1(&cospi32, &bf1[5], &cospi32, &bf1[6], &rounding, bit);
+  bf0[7] = bf1[7];
+  addsub_sse4_1(bf1[8], bf1[11], bf0 + 8, bf0 + 11, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf1[9], bf1[10], bf0 + 9, bf0 + 10, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf1[15], bf1[12], bf0 + 15, bf0 + 12, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf1[14], bf1[13], bf0 + 14, bf0 + 13, &clamp_lo, &clamp_hi);
+  bf0[16] = bf1[16];
+  bf0[17] = bf1[17];
+  bf0[18] =
+      half_btf_sse4_1(&cospim16, &bf1[18], &cospi48, &bf1[29], &rounding, bit);
+  bf0[19] =
+      half_btf_sse4_1(&cospim16, &bf1[19], &cospi48, &bf1[28], &rounding, bit);
+  bf0[20] =
+      half_btf_sse4_1(&cospim48, &bf1[20], &cospim16, &bf1[27], &rounding, bit);
+  bf0[21] =
+      half_btf_sse4_1(&cospim48, &bf1[21], &cospim16, &bf1[26], &rounding, bit);
+  bf0[22] = bf1[22];
+  bf0[23] = bf1[23];
+  bf0[24] = bf1[24];
+  bf0[25] = bf1[25];
+  bf0[26] =
+      half_btf_sse4_1(&cospim16, &bf1[21], &cospi48, &bf1[26], &rounding, bit);
+  bf0[27] =
+      half_btf_sse4_1(&cospim16, &bf1[20], &cospi48, &bf1[27], &rounding, bit);
+  bf0[28] =
+      half_btf_sse4_1(&cospi48, &bf1[19], &cospi16, &bf1[28], &rounding, bit);
+  bf0[29] =
+      half_btf_sse4_1(&cospi48, &bf1[18], &cospi16, &bf1[29], &rounding, bit);
+  bf0[30] = bf1[30];
+  bf0[31] = bf1[31];
+
+  // stage 7
+  addsub_sse4_1(bf0[0], bf0[7], bf1 + 0, bf1 + 7, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf0[1], bf0[6], bf1 + 1, bf1 + 6, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf0[2], bf0[5], bf1 + 2, bf1 + 5, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf0[3], bf0[4], bf1 + 3, bf1 + 4, &clamp_lo, &clamp_hi);
+  bf1[8] = bf0[8];
+  bf1[9] = bf0[9];
+  bf1[10] =
+      half_btf_sse4_1(&cospim32, &bf0[10], &cospi32, &bf0[13], &rounding, bit);
+  bf1[11] =
+      half_btf_sse4_1(&cospim32, &bf0[11], &cospi32, &bf0[12], &rounding, bit);
+  bf1[12] =
+      half_btf_sse4_1(&cospi32, &bf0[11], &cospi32, &bf0[12], &rounding, bit);
+  bf1[13] =
+      half_btf_sse4_1(&cospi32, &bf0[10], &cospi32, &bf0[13], &rounding, bit);
+  bf1[14] = bf0[14];
+  bf1[15] = bf0[15];
+  addsub_sse4_1(bf0[16], bf0[23], bf1 + 16, bf1 + 23, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf0[17], bf0[22], bf1 + 17, bf1 + 22, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf0[18], bf0[21], bf1 + 18, bf1 + 21, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf0[19], bf0[20], bf1 + 19, bf1 + 20, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf0[31], bf0[24], bf1 + 31, bf1 + 24, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf0[30], bf0[25], bf1 + 30, bf1 + 25, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf0[29], bf0[26], bf1 + 29, bf1 + 26, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf0[28], bf0[27], bf1 + 28, bf1 + 27, &clamp_lo, &clamp_hi);
+
+  // stage 8
+  addsub_sse4_1(bf1[0], bf1[15], bf0 + 0, bf0 + 15, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf1[1], bf1[14], bf0 + 1, bf0 + 14, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf1[2], bf1[13], bf0 + 2, bf0 + 13, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf1[3], bf1[12], bf0 + 3, bf0 + 12, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf1[4], bf1[11], bf0 + 4, bf0 + 11, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf1[5], bf1[10], bf0 + 5, bf0 + 10, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf1[6], bf1[9], bf0 + 6, bf0 + 9, &clamp_lo, &clamp_hi);
+  addsub_sse4_1(bf1[7], bf1[8], bf0 + 7, bf0 + 8, &clamp_lo, &clamp_hi);
+  bf0[16] = bf1[16];
+  bf0[17] = bf1[17];
+  bf0[18] = bf1[18];
+  bf0[19] = bf1[19];
+  bf0[20] =
+      half_btf_sse4_1(&cospim32, &bf1[20], &cospi32, &bf1[27], &rounding, bit);
+  bf0[21] =
+      half_btf_sse4_1(&cospim32, &bf1[21], &cospi32, &bf1[26], &rounding, bit);
+  bf0[22] =
+      half_btf_sse4_1(&cospim32, &bf1[22], &cospi32, &bf1[25], &rounding, bit);
+  bf0[23] =
+      half_btf_sse4_1(&cospim32, &bf1[23], &cospi32, &bf1[24], &rounding, bit);
+  bf0[24] =
+      half_btf_sse4_1(&cospi32, &bf1[23], &cospi32, &bf1[24], &rounding, bit);
+  bf0[25] =
+      half_btf_sse4_1(&cospi32, &bf1[22], &cospi32, &bf1[25], &rounding, bit);
+  bf0[26] =
+      half_btf_sse4_1(&cospi32, &bf1[21], &cospi32, &bf1[26], &rounding, bit);
+  bf0[27] =
+      half_btf_sse4_1(&cospi32, &bf1[20], &cospi32, &bf1[27], &rounding, bit);
+  bf0[28] = bf1[28];
+  bf0[29] = bf1[29];
+  bf0[30] = bf1[30];
+  bf0[31] = bf1[31];
+
+  // stage 9
+  if (do_cols) {
+    addsub_no_clamp_sse4_1(bf0[0], bf0[31], out + 0, out + 31);
+    addsub_no_clamp_sse4_1(bf0[1], bf0[30], out + 1, out + 30);
+    addsub_no_clamp_sse4_1(bf0[2], bf0[29], out + 2, out + 29);
+    addsub_no_clamp_sse4_1(bf0[3], bf0[28], out + 3, out + 28);
+    addsub_no_clamp_sse4_1(bf0[4], bf0[27], out + 4, out + 27);
+    addsub_no_clamp_sse4_1(bf0[5], bf0[26], out + 5, out + 26);
+    addsub_no_clamp_sse4_1(bf0[6], bf0[25], out + 6, out + 25);
+    addsub_no_clamp_sse4_1(bf0[7], bf0[24], out + 7, out + 24);
+    addsub_no_clamp_sse4_1(bf0[8], bf0[23], out + 8, out + 23);
+    addsub_no_clamp_sse4_1(bf0[9], bf0[22], out + 9, out + 22);
+    addsub_no_clamp_sse4_1(bf0[10], bf0[21], out + 10, out + 21);
+    addsub_no_clamp_sse4_1(bf0[11], bf0[20], out + 11, out + 20);
+    addsub_no_clamp_sse4_1(bf0[12], bf0[19], out + 12, out + 19);
+    addsub_no_clamp_sse4_1(bf0[13], bf0[18], out + 13, out + 18);
+    addsub_no_clamp_sse4_1(bf0[14], bf0[17], out + 14, out + 17);
+    addsub_no_clamp_sse4_1(bf0[15], bf0[16], out + 15, out + 16);
+  } else {
+    const int log_range_out = AOMMAX(16, bd + 6);
+    const __m128i clamp_lo_out = _mm_set1_epi32(AOMMAX(
+        -(1 << (log_range_out - 1)), -(1 << (log_range - 1 - out_shift))));
+    const __m128i clamp_hi_out = _mm_set1_epi32(AOMMIN(
+        (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift))));
+
+    addsub_shift_sse4_1(bf0[0], bf0[31], out + 0, out + 31, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf0[1], bf0[30], out + 1, out + 30, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf0[2], bf0[29], out + 2, out + 29, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf0[3], bf0[28], out + 3, out + 28, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf0[4], bf0[27], out + 4, out + 27, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf0[5], bf0[26], out + 5, out + 26, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf0[6], bf0[25], out + 6, out + 25, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf0[7], bf0[24], out + 7, out + 24, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf0[8], bf0[23], out + 8, out + 23, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf0[9], bf0[22], out + 9, out + 22, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf0[10], bf0[21], out + 10, out + 21, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf0[11], bf0[20], out + 11, out + 20, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf0[12], bf0[19], out + 12, out + 19, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf0[13], bf0[18], out + 13, out + 18, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf0[14], bf0[17], out + 14, out + 17, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    addsub_shift_sse4_1(bf0[15], bf0[16], out + 15, out + 16, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+  }
+}
+
 void av1_highbd_inv_txfm_add_8x8_sse4_1(const tran_low_t *input, uint8_t *dest,
                                         int stride,
                                         const TxfmParam *txfm_param) {
@@ -4134,6 +5092,27 @@
   }
 }
 
+void av1_highbd_inv_txfm_add_32x32_sse4_1(const tran_low_t *input,
+                                          uint8_t *dest, int stride,
+                                          const TxfmParam *txfm_param) {
+  int bd = txfm_param->bd;
+  const TX_TYPE tx_type = txfm_param->tx_type;
+  const int32_t *src = cast_to_int32(input);
+  switch (tx_type) {
+    case DCT_DCT:
+      av1_highbd_inv_txfm2d_add_universe_sse4_1(input, dest, stride, tx_type,
+                                                txfm_param->tx_size,
+                                                txfm_param->eob, bd);
+      break;
+      // Assembly version doesn't support IDTX, so use C version for it.
+    case IDTX:
+      av1_inv_txfm2d_add_32x32_c(src, CONVERT_TO_SHORTPTR(dest), stride,
+                                 tx_type, bd);
+      break;
+    default: assert(0);
+  }
+}
+
 void av1_highbd_inv_txfm_add_4x4_sse4_1(const tran_low_t *input, uint8_t *dest,
                                         int stride,
                                         const TxfmParam *txfm_param) {
@@ -4185,7 +5164,8 @@
             NULL },
           { NULL, NULL, NULL, NULL },
       },
-      { { NULL, NULL, NULL, NULL },
+      { { idct32x32_low1_sse4_1, idct32x32_low8_sse4_1, idct32x32_low16_sse4_1,
+          idct32x32_sse4_1 },
         { NULL, NULL, NULL, NULL },
         { NULL, NULL, NULL, NULL } },
       { { idct64x64_low1_sse4_1, idct64x64_low8_sse4_1, idct64x64_low16_sse4_1,
@@ -4309,7 +5289,7 @@
   const TX_SIZE tx_size = txfm_param->tx_size;
   switch (tx_size) {
     case TX_32X32:
-      av1_highbd_inv_txfm_add_32x32_c(input, dest, stride, txfm_param);
+      av1_highbd_inv_txfm_add_32x32_sse4_1(input, dest, stride, txfm_param);
       break;
     case TX_16X16:
       av1_highbd_inv_txfm_add_16x16_sse4_1(input, dest, stride, txfm_param);