Optimize highbd inv_txfm modules

Added AVX2 variants for 16x16, 32x16 and 16x32 txfm block sizes.

When tested over multiple test cases, an average reduction of 0.97%
in encoder time was observed for the speed=1 preset.

Module-level gains improved by a factor of ~1.77
on average w.r.t. the corresponding SSE4_1 modules.

Cleaned up functions that are common to both fwd_txfm and inv_txfm.

Change-Id: I8c1baf9ebb954c31b440c0371c953cfb42e60300
diff --git a/aom_dsp/x86/txfm_common_avx2.h b/aom_dsp/x86/txfm_common_avx2.h
index b1611ba..8a40508 100644
--- a/aom_dsp/x86/txfm_common_avx2.h
+++ b/aom_dsp/x86/txfm_common_avx2.h
@@ -20,9 +20,6 @@
 extern "C" {
 #endif
 
-typedef void (*transform_1d_avx2)(const __m256i *input, __m256i *output,
-                                  int8_t cos_bit);
-
 static INLINE __m256i pair_set_w16_epi16(int16_t a, int16_t b) {
   return _mm256_set1_epi32(
       (int32_t)(((uint16_t)(a)) | (((uint32_t)(b)) << 16)));
@@ -192,6 +189,53 @@
   }
 }
 
+static INLINE __m256i av1_round_shift_32_avx2(__m256i vec, int bit) {
+  __m256i tmp, round;
+  round = _mm256_set1_epi32(1 << (bit - 1));
+  tmp = _mm256_add_epi32(vec, round);
+  return _mm256_srai_epi32(tmp, bit);
+}
+
+static INLINE void av1_round_shift_array_32_avx2(__m256i *input,
+                                                 __m256i *output,
+                                                 const int size,
+                                                 const int bit) {
+  if (bit > 0) {
+    int i;
+    for (i = 0; i < size; i++) {
+      output[i] = av1_round_shift_32_avx2(input[i], bit);
+    }
+  } else {
+    int i;
+    for (i = 0; i < size; i++) {
+      output[i] = _mm256_slli_epi32(input[i], -bit);
+    }
+  }
+}
+
+static INLINE void av1_round_shift_rect_array_32_avx2(__m256i *input,
+                                                      __m256i *output,
+                                                      const int size,
+                                                      const int bit,
+                                                      const int val) {
+  const __m256i sqrt2 = _mm256_set1_epi32(val);
+  if (bit > 0) {
+    int i;
+    for (i = 0; i < size; i++) {
+      const __m256i r0 = av1_round_shift_32_avx2(input[i], bit);
+      const __m256i r1 = _mm256_mullo_epi32(sqrt2, r0);
+      output[i] = av1_round_shift_32_avx2(r1, NewSqrt2Bits);
+    }
+  } else {
+    int i;
+    for (i = 0; i < size; i++) {
+      const __m256i r0 = _mm256_slli_epi32(input[i], -bit);
+      const __m256i r1 = _mm256_mullo_epi32(sqrt2, r0);
+      output[i] = av1_round_shift_32_avx2(r1, NewSqrt2Bits);
+    }
+  }
+}
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index c167c65..7eb6cce 100755
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -122,13 +122,13 @@
 add_proto qw/void av1_highbd_inv_txfm_add_8x16/,  "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
 specialize qw/av1_highbd_inv_txfm_add_8x16 sse4_1/;
 add_proto qw/void av1_highbd_inv_txfm_add_16x16/,  "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
-specialize qw/av1_highbd_inv_txfm_add_16x16 sse4_1/;
+specialize qw/av1_highbd_inv_txfm_add_16x16 sse4_1 avx2/;
 add_proto qw/void av1_highbd_inv_txfm_add_32x32/,  "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
 specialize qw/av1_highbd_inv_txfm_add_32x32 sse4_1 avx2/;
 add_proto qw/void av1_highbd_inv_txfm_add_16x32/,  "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
-specialize qw/av1_highbd_inv_txfm_add_16x32 sse4_1/;
+specialize qw/av1_highbd_inv_txfm_add_16x32 sse4_1 avx2/;
 add_proto qw/void av1_highbd_inv_txfm_add_32x16/,  "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
-specialize qw/av1_highbd_inv_txfm_add_32x16 sse4_1/;
+specialize qw/av1_highbd_inv_txfm_add_32x16 sse4_1 avx2/;
 add_proto qw/void av1_highbd_inv_txfm_add_8x32/,  "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
 specialize qw/av1_highbd_inv_txfm_add_8x32 sse4_1/;
 add_proto qw/void av1_highbd_inv_txfm_add_32x8/,  "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
diff --git a/av1/common/x86/av1_inv_txfm_avx2.c b/av1/common/x86/av1_inv_txfm_avx2.c
index 5db2ccf..cf1f947 100644
--- a/av1/common/x86/av1_inv_txfm_avx2.c
+++ b/av1/common/x86/av1_inv_txfm_avx2.c
@@ -1577,6 +1577,9 @@
   idct64_stage11_avx2(output, x);
 }
 
+typedef void (*transform_1d_avx2)(const __m256i *input, __m256i *output,
+                                  int8_t cos_bit);
+
 // 1D functions process 16 pixels at one time.
 static const transform_1d_avx2
     lowbd_txfm_all_1d_zeros_w16_arr[TX_SIZES][ITX_TYPES_1D][4] = {
diff --git a/av1/common/x86/highbd_inv_txfm_avx2.c b/av1/common/x86/highbd_inv_txfm_avx2.c
index 9a1224c..abd892e 100644
--- a/av1/common/x86/highbd_inv_txfm_avx2.c
+++ b/av1/common/x86/highbd_inv_txfm_avx2.c
@@ -18,6 +18,7 @@
 #include "av1/common/idct.h"
 #include "av1/common/x86/av1_inv_txfm_ssse3.h"
 #include "av1/common/x86/highbd_txfm_utility_sse4.h"
+#include "aom_dsp/x86/txfm_common_avx2.h"
 
 // Note:
 //  Total 32x4 registers to represent 32x32 block coefficients.
@@ -73,28 +74,23 @@
   }
 }
 
-static INLINE __m256i av1_round_shift_32_avx2(__m256i vec, int bit) {
-  __m256i tmp, round;
-  round = _mm256_set1_epi32(1 << (bit - 1));
-  tmp = _mm256_add_epi32(vec, round);
-  return _mm256_srai_epi32(tmp, bit);
-}
+static void neg_shift_avx2(const __m256i in0, const __m256i in1, __m256i *out0,
+                           __m256i *out1, const __m256i *clamp_lo,
+                           const __m256i *clamp_hi, int shift) {
+  __m256i offset = _mm256_set1_epi32((1 << shift) >> 1);
+  __m256i a0 = _mm256_add_epi32(offset, in0);
+  __m256i a1 = _mm256_sub_epi32(offset, in1);
 
-static INLINE void av1_round_shift_array_32_avx2(__m256i *input,
-                                                 __m256i *output,
-                                                 const int size,
-                                                 const int bit) {
-  if (bit > 0) {
-    int i;
-    for (i = 0; i < size; i++) {
-      output[i] = av1_round_shift_32_avx2(input[i], bit);
-    }
-  } else {
-    int i;
-    for (i = 0; i < size; i++) {
-      output[i] = _mm256_slli_epi32(input[i], -bit);
-    }
-  }
+  a0 = _mm256_sra_epi32(a0, _mm_cvtsi32_si128(shift));
+  a1 = _mm256_sra_epi32(a1, _mm_cvtsi32_si128(shift));
+
+  a0 = _mm256_max_epi32(a0, *clamp_lo);
+  a0 = _mm256_min_epi32(a0, *clamp_hi);
+  a1 = _mm256_max_epi32(a1, *clamp_lo);
+  a1 = _mm256_min_epi32(a1, *clamp_hi);
+
+  *out0 = a0;
+  *out1 = a1;
 }
 
 static void transpose_8x8_avx2(const __m256i *in, __m256i *out) {
@@ -134,6 +130,43 @@
   out[7] = _mm256_permute2f128_si256(x0, x1, 0x31);
 }
 
+static void transpose_8x8_flip_avx2(const __m256i *in, __m256i *out) {
+  __m256i u0, u1, u2, u3, u4, u5, u6, u7;
+  __m256i x0, x1;
+
+  u0 = _mm256_unpacklo_epi32(in[7], in[6]);
+  u1 = _mm256_unpackhi_epi32(in[7], in[6]);
+
+  u2 = _mm256_unpacklo_epi32(in[5], in[4]);
+  u3 = _mm256_unpackhi_epi32(in[5], in[4]);
+
+  u4 = _mm256_unpacklo_epi32(in[3], in[2]);
+  u5 = _mm256_unpackhi_epi32(in[3], in[2]);
+
+  u6 = _mm256_unpacklo_epi32(in[1], in[0]);
+  u7 = _mm256_unpackhi_epi32(in[1], in[0]);
+
+  x0 = _mm256_unpacklo_epi64(u0, u2);
+  x1 = _mm256_unpacklo_epi64(u4, u6);
+  out[0] = _mm256_permute2f128_si256(x0, x1, 0x20);
+  out[4] = _mm256_permute2f128_si256(x0, x1, 0x31);
+
+  x0 = _mm256_unpackhi_epi64(u0, u2);
+  x1 = _mm256_unpackhi_epi64(u4, u6);
+  out[1] = _mm256_permute2f128_si256(x0, x1, 0x20);
+  out[5] = _mm256_permute2f128_si256(x0, x1, 0x31);
+
+  x0 = _mm256_unpacklo_epi64(u1, u3);
+  x1 = _mm256_unpacklo_epi64(u5, u7);
+  out[2] = _mm256_permute2f128_si256(x0, x1, 0x20);
+  out[6] = _mm256_permute2f128_si256(x0, x1, 0x31);
+
+  x0 = _mm256_unpackhi_epi64(u1, u3);
+  x1 = _mm256_unpackhi_epi64(u5, u7);
+  out[3] = _mm256_permute2f128_si256(x0, x1, 0x20);
+  out[7] = _mm256_permute2f128_si256(x0, x1, 0x31);
+}
+
 static void load_buffer_32x32(const int32_t *coeff, __m256i *in,
                               int input_stiride, int size) {
   int i;
@@ -1153,6 +1186,1301 @@
     }
   }
 }
+static void idct16_low1_avx2(__m256i *in, __m256i *out, int bit, int do_cols,
+                             int bd, int out_shift) {
+  const int32_t *cospi = cospi_arr(bit);
+  const __m256i cospi32 = _mm256_set1_epi32(cospi[32]);
+  const __m256i rnding = _mm256_set1_epi32(1 << (bit - 1));
+  const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8));
+  const __m256i clamp_lo = _mm256_set1_epi32(-(1 << (log_range - 1)));
+  const __m256i clamp_hi = _mm256_set1_epi32((1 << (log_range - 1)) - 1);
+
+  {
+    // stage 0
+    // stage 1
+    // stage 2
+    // stage 3
+    // stage 4
+    in[0] = _mm256_mullo_epi32(in[0], cospi32);
+    in[0] = _mm256_add_epi32(in[0], rnding);
+    in[0] = _mm256_srai_epi32(in[0], bit);
+
+    // stage 5
+    // stage 6
+    // stage 7
+    if (do_cols) {
+      in[0] = _mm256_max_epi32(in[0], clamp_lo);
+      in[0] = _mm256_min_epi32(in[0], clamp_hi);
+    } else {
+      const int log_range_out = AOMMAX(16, bd + 6);
+      const __m256i clamp_lo_out = _mm256_set1_epi32(AOMMAX(
+          -(1 << (log_range_out - 1)), -(1 << (log_range - 1 - out_shift))));
+      const __m256i clamp_hi_out = _mm256_set1_epi32(AOMMIN(
+          (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift))));
+      __m256i offset = _mm256_set1_epi32((1 << out_shift) >> 1);
+      in[0] = _mm256_add_epi32(in[0], offset);
+      in[0] = _mm256_sra_epi32(in[0], _mm_cvtsi32_si128(out_shift));
+      in[0] = _mm256_max_epi32(in[0], clamp_lo_out);
+      in[0] = _mm256_min_epi32(in[0], clamp_hi_out);
+    }
+
+    out[0] = in[0];
+    out[1] = in[0];
+    out[2] = in[0];
+    out[3] = in[0];
+    out[4] = in[0];
+    out[5] = in[0];
+    out[6] = in[0];
+    out[7] = in[0];
+    out[8] = in[0];
+    out[9] = in[0];
+    out[10] = in[0];
+    out[11] = in[0];
+    out[12] = in[0];
+    out[13] = in[0];
+    out[14] = in[0];
+    out[15] = in[0];
+  }
+}
+
+static void idct16_low8_avx2(__m256i *in, __m256i *out, int bit, int do_cols,
+                             int bd, int out_shift) {
+  const int32_t *cospi = cospi_arr(bit);
+  const __m256i cospi60 = _mm256_set1_epi32(cospi[60]);
+  const __m256i cospi28 = _mm256_set1_epi32(cospi[28]);
+  const __m256i cospi44 = _mm256_set1_epi32(cospi[44]);
+  const __m256i cospi20 = _mm256_set1_epi32(cospi[20]);
+  const __m256i cospi12 = _mm256_set1_epi32(cospi[12]);
+  const __m256i cospi4 = _mm256_set1_epi32(cospi[4]);
+  const __m256i cospi56 = _mm256_set1_epi32(cospi[56]);
+  const __m256i cospi24 = _mm256_set1_epi32(cospi[24]);
+  const __m256i cospim40 = _mm256_set1_epi32(-cospi[40]);
+  const __m256i cospi8 = _mm256_set1_epi32(cospi[8]);
+  const __m256i cospi32 = _mm256_set1_epi32(cospi[32]);
+  const __m256i cospi48 = _mm256_set1_epi32(cospi[48]);
+  const __m256i cospi16 = _mm256_set1_epi32(cospi[16]);
+  const __m256i cospim16 = _mm256_set1_epi32(-cospi[16]);
+  const __m256i cospim48 = _mm256_set1_epi32(-cospi[48]);
+  const __m256i cospim36 = _mm256_set1_epi32(-cospi[36]);
+  const __m256i cospim52 = _mm256_set1_epi32(-cospi[52]);
+  const __m256i rnding = _mm256_set1_epi32(1 << (bit - 1));
+  const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8));
+  const __m256i clamp_lo = _mm256_set1_epi32(-(1 << (log_range - 1)));
+  const __m256i clamp_hi = _mm256_set1_epi32((1 << (log_range - 1)) - 1);
+  __m256i u[16], x, y;
+
+  {
+    // stage 0
+    // stage 1
+    u[0] = in[0];
+    u[2] = in[4];
+    u[4] = in[2];
+    u[6] = in[6];
+    u[8] = in[1];
+    u[10] = in[5];
+    u[12] = in[3];
+    u[14] = in[7];
+
+    // stage 2
+    u[15] = half_btf_0_avx2(&cospi4, &u[8], &rnding, bit);
+    u[8] = half_btf_0_avx2(&cospi60, &u[8], &rnding, bit);
+
+    u[9] = half_btf_0_avx2(&cospim36, &u[14], &rnding, bit);
+    u[14] = half_btf_0_avx2(&cospi28, &u[14], &rnding, bit);
+
+    u[13] = half_btf_0_avx2(&cospi20, &u[10], &rnding, bit);
+    u[10] = half_btf_0_avx2(&cospi44, &u[10], &rnding, bit);
+
+    u[11] = half_btf_0_avx2(&cospim52, &u[12], &rnding, bit);
+    u[12] = half_btf_0_avx2(&cospi12, &u[12], &rnding, bit);
+
+    // stage 3
+    u[7] = half_btf_0_avx2(&cospi8, &u[4], &rnding, bit);
+    u[4] = half_btf_0_avx2(&cospi56, &u[4], &rnding, bit);
+    u[5] = half_btf_0_avx2(&cospim40, &u[6], &rnding, bit);
+    u[6] = half_btf_0_avx2(&cospi24, &u[6], &rnding, bit);
+
+    addsub_avx2(u[8], u[9], &u[8], &u[9], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[11], u[10], &u[11], &u[10], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[12], u[13], &u[12], &u[13], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[15], u[14], &u[15], &u[14], &clamp_lo, &clamp_hi);
+
+    // stage 4
+    x = _mm256_mullo_epi32(u[0], cospi32);
+    u[0] = _mm256_add_epi32(x, rnding);
+    u[0] = _mm256_srai_epi32(u[0], bit);
+    u[1] = u[0];
+
+    u[3] = half_btf_0_avx2(&cospi16, &u[2], &rnding, bit);
+    u[2] = half_btf_0_avx2(&cospi48, &u[2], &rnding, bit);
+
+    addsub_avx2(u[4], u[5], &u[4], &u[5], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[7], u[6], &u[7], &u[6], &clamp_lo, &clamp_hi);
+
+    x = half_btf_avx2(&cospim16, &u[9], &cospi48, &u[14], &rnding, bit);
+    u[14] = half_btf_avx2(&cospi48, &u[9], &cospi16, &u[14], &rnding, bit);
+    u[9] = x;
+    y = half_btf_avx2(&cospim48, &u[10], &cospim16, &u[13], &rnding, bit);
+    u[13] = half_btf_avx2(&cospim16, &u[10], &cospi48, &u[13], &rnding, bit);
+    u[10] = y;
+
+    // stage 5
+    addsub_avx2(u[0], u[3], &u[0], &u[3], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[1], u[2], &u[1], &u[2], &clamp_lo, &clamp_hi);
+
+    x = _mm256_mullo_epi32(u[5], cospi32);
+    y = _mm256_mullo_epi32(u[6], cospi32);
+    u[5] = _mm256_sub_epi32(y, x);
+    u[5] = _mm256_add_epi32(u[5], rnding);
+    u[5] = _mm256_srai_epi32(u[5], bit);
+
+    u[6] = _mm256_add_epi32(y, x);
+    u[6] = _mm256_add_epi32(u[6], rnding);
+    u[6] = _mm256_srai_epi32(u[6], bit);
+
+    addsub_avx2(u[8], u[11], &u[8], &u[11], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[9], u[10], &u[9], &u[10], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[15], u[12], &u[15], &u[12], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[14], u[13], &u[14], &u[13], &clamp_lo, &clamp_hi);
+
+    // stage 6
+    addsub_avx2(u[0], u[7], &u[0], &u[7], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[1], u[6], &u[1], &u[6], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[2], u[5], &u[2], &u[5], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[3], u[4], &u[3], &u[4], &clamp_lo, &clamp_hi);
+
+    x = _mm256_mullo_epi32(u[10], cospi32);
+    y = _mm256_mullo_epi32(u[13], cospi32);
+    u[10] = _mm256_sub_epi32(y, x);
+    u[10] = _mm256_add_epi32(u[10], rnding);
+    u[10] = _mm256_srai_epi32(u[10], bit);
+
+    u[13] = _mm256_add_epi32(x, y);
+    u[13] = _mm256_add_epi32(u[13], rnding);
+    u[13] = _mm256_srai_epi32(u[13], bit);
+
+    x = _mm256_mullo_epi32(u[11], cospi32);
+    y = _mm256_mullo_epi32(u[12], cospi32);
+    u[11] = _mm256_sub_epi32(y, x);
+    u[11] = _mm256_add_epi32(u[11], rnding);
+    u[11] = _mm256_srai_epi32(u[11], bit);
+
+    u[12] = _mm256_add_epi32(x, y);
+    u[12] = _mm256_add_epi32(u[12], rnding);
+    u[12] = _mm256_srai_epi32(u[12], bit);
+    // stage 7
+    if (do_cols) {
+      addsub_no_clamp_avx2(u[0], u[15], out + 0, out + 15);
+      addsub_no_clamp_avx2(u[1], u[14], out + 1, out + 14);
+      addsub_no_clamp_avx2(u[2], u[13], out + 2, out + 13);
+      addsub_no_clamp_avx2(u[3], u[12], out + 3, out + 12);
+      addsub_no_clamp_avx2(u[4], u[11], out + 4, out + 11);
+      addsub_no_clamp_avx2(u[5], u[10], out + 5, out + 10);
+      addsub_no_clamp_avx2(u[6], u[9], out + 6, out + 9);
+      addsub_no_clamp_avx2(u[7], u[8], out + 7, out + 8);
+    } else {
+      const int log_range_out = AOMMAX(16, bd + 6);
+      const __m256i clamp_lo_out = _mm256_set1_epi32(AOMMAX(
+          -(1 << (log_range_out - 1)), -(1 << (log_range - 1 - out_shift))));
+      const __m256i clamp_hi_out = _mm256_set1_epi32(AOMMIN(
+          (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift))));
+
+      addsub_shift_avx2(u[0], u[15], out + 0, out + 15, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+      addsub_shift_avx2(u[1], u[14], out + 1, out + 14, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+      addsub_shift_avx2(u[2], u[13], out + 2, out + 13, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+      addsub_shift_avx2(u[3], u[12], out + 3, out + 12, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+      addsub_shift_avx2(u[4], u[11], out + 4, out + 11, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+      addsub_shift_avx2(u[5], u[10], out + 5, out + 10, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+      addsub_shift_avx2(u[6], u[9], out + 6, out + 9, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+      addsub_shift_avx2(u[7], u[8], out + 7, out + 8, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    }
+  }
+}
+
+static void idct16_avx2(__m256i *in, __m256i *out, int bit, int do_cols, int bd,
+                        int out_shift) {
+  const int32_t *cospi = cospi_arr(bit);
+  const __m256i cospi60 = _mm256_set1_epi32(cospi[60]);
+  const __m256i cospim4 = _mm256_set1_epi32(-cospi[4]);
+  const __m256i cospi28 = _mm256_set1_epi32(cospi[28]);
+  const __m256i cospim36 = _mm256_set1_epi32(-cospi[36]);
+  const __m256i cospi44 = _mm256_set1_epi32(cospi[44]);
+  const __m256i cospi20 = _mm256_set1_epi32(cospi[20]);
+  const __m256i cospim20 = _mm256_set1_epi32(-cospi[20]);
+  const __m256i cospi12 = _mm256_set1_epi32(cospi[12]);
+  const __m256i cospim52 = _mm256_set1_epi32(-cospi[52]);
+  const __m256i cospi52 = _mm256_set1_epi32(cospi[52]);
+  const __m256i cospi36 = _mm256_set1_epi32(cospi[36]);
+  const __m256i cospi4 = _mm256_set1_epi32(cospi[4]);
+  const __m256i cospi56 = _mm256_set1_epi32(cospi[56]);
+  const __m256i cospim8 = _mm256_set1_epi32(-cospi[8]);
+  const __m256i cospi24 = _mm256_set1_epi32(cospi[24]);
+  const __m256i cospim40 = _mm256_set1_epi32(-cospi[40]);
+  const __m256i cospi40 = _mm256_set1_epi32(cospi[40]);
+  const __m256i cospi8 = _mm256_set1_epi32(cospi[8]);
+  const __m256i cospi32 = _mm256_set1_epi32(cospi[32]);
+  const __m256i cospi48 = _mm256_set1_epi32(cospi[48]);
+  const __m256i cospi16 = _mm256_set1_epi32(cospi[16]);
+  const __m256i cospim16 = _mm256_set1_epi32(-cospi[16]);
+  const __m256i cospim48 = _mm256_set1_epi32(-cospi[48]);
+  const __m256i rnding = _mm256_set1_epi32(1 << (bit - 1));
+  const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8));
+  const __m256i clamp_lo = _mm256_set1_epi32(-(1 << (log_range - 1)));
+  const __m256i clamp_hi = _mm256_set1_epi32((1 << (log_range - 1)) - 1);
+  __m256i u[16], v[16], x, y;
+
+  {
+    // stage 0
+    // stage 1
+    u[0] = in[0];
+    u[1] = in[8];
+    u[2] = in[4];
+    u[3] = in[12];
+    u[4] = in[2];
+    u[5] = in[10];
+    u[6] = in[6];
+    u[7] = in[14];
+    u[8] = in[1];
+    u[9] = in[9];
+    u[10] = in[5];
+    u[11] = in[13];
+    u[12] = in[3];
+    u[13] = in[11];
+    u[14] = in[7];
+    u[15] = in[15];
+
+    // stage 2
+    v[0] = u[0];
+    v[1] = u[1];
+    v[2] = u[2];
+    v[3] = u[3];
+    v[4] = u[4];
+    v[5] = u[5];
+    v[6] = u[6];
+    v[7] = u[7];
+
+    v[8] = half_btf_avx2(&cospi60, &u[8], &cospim4, &u[15], &rnding, bit);
+    v[9] = half_btf_avx2(&cospi28, &u[9], &cospim36, &u[14], &rnding, bit);
+    v[10] = half_btf_avx2(&cospi44, &u[10], &cospim20, &u[13], &rnding, bit);
+    v[11] = half_btf_avx2(&cospi12, &u[11], &cospim52, &u[12], &rnding, bit);
+    v[12] = half_btf_avx2(&cospi52, &u[11], &cospi12, &u[12], &rnding, bit);
+    v[13] = half_btf_avx2(&cospi20, &u[10], &cospi44, &u[13], &rnding, bit);
+    v[14] = half_btf_avx2(&cospi36, &u[9], &cospi28, &u[14], &rnding, bit);
+    v[15] = half_btf_avx2(&cospi4, &u[8], &cospi60, &u[15], &rnding, bit);
+
+    // stage 3
+    u[0] = v[0];
+    u[1] = v[1];
+    u[2] = v[2];
+    u[3] = v[3];
+    u[4] = half_btf_avx2(&cospi56, &v[4], &cospim8, &v[7], &rnding, bit);
+    u[5] = half_btf_avx2(&cospi24, &v[5], &cospim40, &v[6], &rnding, bit);
+    u[6] = half_btf_avx2(&cospi40, &v[5], &cospi24, &v[6], &rnding, bit);
+    u[7] = half_btf_avx2(&cospi8, &v[4], &cospi56, &v[7], &rnding, bit);
+    addsub_avx2(v[8], v[9], &u[8], &u[9], &clamp_lo, &clamp_hi);
+    addsub_avx2(v[11], v[10], &u[11], &u[10], &clamp_lo, &clamp_hi);
+    addsub_avx2(v[12], v[13], &u[12], &u[13], &clamp_lo, &clamp_hi);
+    addsub_avx2(v[15], v[14], &u[15], &u[14], &clamp_lo, &clamp_hi);
+
+    // stage 4
+    x = _mm256_mullo_epi32(u[0], cospi32);
+    y = _mm256_mullo_epi32(u[1], cospi32);
+    v[0] = _mm256_add_epi32(x, y);
+    v[0] = _mm256_add_epi32(v[0], rnding);
+    v[0] = _mm256_srai_epi32(v[0], bit);
+
+    v[1] = _mm256_sub_epi32(x, y);
+    v[1] = _mm256_add_epi32(v[1], rnding);
+    v[1] = _mm256_srai_epi32(v[1], bit);
+
+    v[2] = half_btf_avx2(&cospi48, &u[2], &cospim16, &u[3], &rnding, bit);
+    v[3] = half_btf_avx2(&cospi16, &u[2], &cospi48, &u[3], &rnding, bit);
+    addsub_avx2(u[4], u[5], &v[4], &v[5], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[7], u[6], &v[7], &v[6], &clamp_lo, &clamp_hi);
+    v[8] = u[8];
+    v[9] = half_btf_avx2(&cospim16, &u[9], &cospi48, &u[14], &rnding, bit);
+    v[10] = half_btf_avx2(&cospim48, &u[10], &cospim16, &u[13], &rnding, bit);
+    v[11] = u[11];
+    v[12] = u[12];
+    v[13] = half_btf_avx2(&cospim16, &u[10], &cospi48, &u[13], &rnding, bit);
+    v[14] = half_btf_avx2(&cospi48, &u[9], &cospi16, &u[14], &rnding, bit);
+    v[15] = u[15];
+
+    // stage 5
+    addsub_avx2(v[0], v[3], &u[0], &u[3], &clamp_lo, &clamp_hi);
+    addsub_avx2(v[1], v[2], &u[1], &u[2], &clamp_lo, &clamp_hi);
+    u[4] = v[4];
+
+    x = _mm256_mullo_epi32(v[5], cospi32);
+    y = _mm256_mullo_epi32(v[6], cospi32);
+    u[5] = _mm256_sub_epi32(y, x);
+    u[5] = _mm256_add_epi32(u[5], rnding);
+    u[5] = _mm256_srai_epi32(u[5], bit);
+
+    u[6] = _mm256_add_epi32(y, x);
+    u[6] = _mm256_add_epi32(u[6], rnding);
+    u[6] = _mm256_srai_epi32(u[6], bit);
+
+    u[7] = v[7];
+    addsub_avx2(v[8], v[11], &u[8], &u[11], &clamp_lo, &clamp_hi);
+    addsub_avx2(v[9], v[10], &u[9], &u[10], &clamp_lo, &clamp_hi);
+    addsub_avx2(v[15], v[12], &u[15], &u[12], &clamp_lo, &clamp_hi);
+    addsub_avx2(v[14], v[13], &u[14], &u[13], &clamp_lo, &clamp_hi);
+
+    // stage 6
+    addsub_avx2(u[0], u[7], &v[0], &v[7], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[1], u[6], &v[1], &v[6], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[2], u[5], &v[2], &v[5], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[3], u[4], &v[3], &v[4], &clamp_lo, &clamp_hi);
+    v[8] = u[8];
+    v[9] = u[9];
+
+    x = _mm256_mullo_epi32(u[10], cospi32);
+    y = _mm256_mullo_epi32(u[13], cospi32);
+    v[10] = _mm256_sub_epi32(y, x);
+    v[10] = _mm256_add_epi32(v[10], rnding);
+    v[10] = _mm256_srai_epi32(v[10], bit);
+
+    v[13] = _mm256_add_epi32(x, y);
+    v[13] = _mm256_add_epi32(v[13], rnding);
+    v[13] = _mm256_srai_epi32(v[13], bit);
+
+    x = _mm256_mullo_epi32(u[11], cospi32);
+    y = _mm256_mullo_epi32(u[12], cospi32);
+    v[11] = _mm256_sub_epi32(y, x);
+    v[11] = _mm256_add_epi32(v[11], rnding);
+    v[11] = _mm256_srai_epi32(v[11], bit);
+
+    v[12] = _mm256_add_epi32(x, y);
+    v[12] = _mm256_add_epi32(v[12], rnding);
+    v[12] = _mm256_srai_epi32(v[12], bit);
+
+    v[14] = u[14];
+    v[15] = u[15];
+
+    // stage 7
+    if (do_cols) {
+      addsub_no_clamp_avx2(v[0], v[15], out + 0, out + 15);
+      addsub_no_clamp_avx2(v[1], v[14], out + 1, out + 14);
+      addsub_no_clamp_avx2(v[2], v[13], out + 2, out + 13);
+      addsub_no_clamp_avx2(v[3], v[12], out + 3, out + 12);
+      addsub_no_clamp_avx2(v[4], v[11], out + 4, out + 11);
+      addsub_no_clamp_avx2(v[5], v[10], out + 5, out + 10);
+      addsub_no_clamp_avx2(v[6], v[9], out + 6, out + 9);
+      addsub_no_clamp_avx2(v[7], v[8], out + 7, out + 8);
+    } else {
+      const int log_range_out = AOMMAX(16, bd + 6);
+      const __m256i clamp_lo_out = _mm256_set1_epi32(AOMMAX(
+          -(1 << (log_range_out - 1)), -(1 << (log_range - 1 - out_shift))));
+      const __m256i clamp_hi_out = _mm256_set1_epi32(AOMMIN(
+          (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift))));
+
+      addsub_shift_avx2(v[0], v[15], out + 0, out + 15, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+      addsub_shift_avx2(v[1], v[14], out + 1, out + 14, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+      addsub_shift_avx2(v[2], v[13], out + 2, out + 13, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+      addsub_shift_avx2(v[3], v[12], out + 3, out + 12, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+      addsub_shift_avx2(v[4], v[11], out + 4, out + 11, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+      addsub_shift_avx2(v[5], v[10], out + 5, out + 10, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+      addsub_shift_avx2(v[6], v[9], out + 6, out + 9, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+      addsub_shift_avx2(v[7], v[8], out + 7, out + 8, &clamp_lo_out,
+                        &clamp_hi_out, out_shift);
+    }
+  }
+}
+
+static void iadst16_low1_avx2(__m256i *in, __m256i *out, int bit, int do_cols,
+                              int bd, int out_shift) {
+  const int32_t *cospi = cospi_arr(bit);
+  const __m256i cospi2 = _mm256_set1_epi32(cospi[2]);
+  const __m256i cospi62 = _mm256_set1_epi32(cospi[62]);
+  const __m256i cospi8 = _mm256_set1_epi32(cospi[8]);
+  const __m256i cospi56 = _mm256_set1_epi32(cospi[56]);
+  const __m256i cospi48 = _mm256_set1_epi32(cospi[48]);
+  const __m256i cospi16 = _mm256_set1_epi32(cospi[16]);
+  const __m256i cospi32 = _mm256_set1_epi32(cospi[32]);
+  const __m256i rnding = _mm256_set1_epi32(1 << (bit - 1));
+  const __m256i zero = _mm256_setzero_si256();
+  __m256i v[16], x, y, temp1, temp2;
+
+  // Calculate the column 0, 1, 2, 3
+  {
+    // stage 0
+    // stage 1
+    // stage 2
+    x = _mm256_mullo_epi32(in[0], cospi62);
+    v[0] = _mm256_add_epi32(x, rnding);
+    v[0] = _mm256_srai_epi32(v[0], bit);
+
+    x = _mm256_mullo_epi32(in[0], cospi2);
+    v[1] = _mm256_sub_epi32(zero, x);
+    v[1] = _mm256_add_epi32(v[1], rnding);
+    v[1] = _mm256_srai_epi32(v[1], bit);
+
+    // stage 3
+    v[8] = v[0];
+    v[9] = v[1];
+
+    // stage 4
+    temp1 = _mm256_mullo_epi32(v[8], cospi8);
+    x = _mm256_mullo_epi32(v[9], cospi56);
+    temp1 = _mm256_add_epi32(temp1, x);
+    temp1 = _mm256_add_epi32(temp1, rnding);
+    temp1 = _mm256_srai_epi32(temp1, bit);
+
+    temp2 = _mm256_mullo_epi32(v[8], cospi56);
+    x = _mm256_mullo_epi32(v[9], cospi8);
+    temp2 = _mm256_sub_epi32(temp2, x);
+    temp2 = _mm256_add_epi32(temp2, rnding);
+    temp2 = _mm256_srai_epi32(temp2, bit);
+    v[8] = temp1;
+    v[9] = temp2;
+
+    // stage 5
+    v[4] = v[0];
+    v[5] = v[1];
+    v[12] = v[8];
+    v[13] = v[9];
+
+    // stage 6
+    temp1 = _mm256_mullo_epi32(v[4], cospi16);
+    x = _mm256_mullo_epi32(v[5], cospi48);
+    temp1 = _mm256_add_epi32(temp1, x);
+    temp1 = _mm256_add_epi32(temp1, rnding);
+    temp1 = _mm256_srai_epi32(temp1, bit);
+
+    temp2 = _mm256_mullo_epi32(v[4], cospi48);
+    x = _mm256_mullo_epi32(v[5], cospi16);
+    temp2 = _mm256_sub_epi32(temp2, x);
+    temp2 = _mm256_add_epi32(temp2, rnding);
+    temp2 = _mm256_srai_epi32(temp2, bit);
+    v[4] = temp1;
+    v[5] = temp2;
+
+    temp1 = _mm256_mullo_epi32(v[12], cospi16);
+    x = _mm256_mullo_epi32(v[13], cospi48);
+    temp1 = _mm256_add_epi32(temp1, x);
+    temp1 = _mm256_add_epi32(temp1, rnding);
+    temp1 = _mm256_srai_epi32(temp1, bit);
+
+    temp2 = _mm256_mullo_epi32(v[12], cospi48);
+    x = _mm256_mullo_epi32(v[13], cospi16);
+    temp2 = _mm256_sub_epi32(temp2, x);
+    temp2 = _mm256_add_epi32(temp2, rnding);
+    temp2 = _mm256_srai_epi32(temp2, bit);
+    v[12] = temp1;
+    v[13] = temp2;
+
+    // stage 7
+    v[2] = v[0];
+    v[3] = v[1];
+    v[6] = v[4];
+    v[7] = v[5];
+    v[10] = v[8];
+    v[11] = v[9];
+    v[14] = v[12];
+    v[15] = v[13];
+
+    // stage 8
+    y = _mm256_mullo_epi32(v[2], cospi32);
+    x = _mm256_mullo_epi32(v[3], cospi32);
+    v[2] = _mm256_add_epi32(y, x);
+    v[2] = _mm256_add_epi32(v[2], rnding);
+    v[2] = _mm256_srai_epi32(v[2], bit);
+
+    v[3] = _mm256_sub_epi32(y, x);
+    v[3] = _mm256_add_epi32(v[3], rnding);
+    v[3] = _mm256_srai_epi32(v[3], bit);
+
+    y = _mm256_mullo_epi32(v[6], cospi32);
+    x = _mm256_mullo_epi32(v[7], cospi32);
+    v[6] = _mm256_add_epi32(y, x);
+    v[6] = _mm256_add_epi32(v[6], rnding);
+    v[6] = _mm256_srai_epi32(v[6], bit);
+
+    v[7] = _mm256_sub_epi32(y, x);
+    v[7] = _mm256_add_epi32(v[7], rnding);
+    v[7] = _mm256_srai_epi32(v[7], bit);
+
+    y = _mm256_mullo_epi32(v[10], cospi32);
+    x = _mm256_mullo_epi32(v[11], cospi32);
+    v[10] = _mm256_add_epi32(y, x);
+    v[10] = _mm256_add_epi32(v[10], rnding);
+    v[10] = _mm256_srai_epi32(v[10], bit);
+
+    v[11] = _mm256_sub_epi32(y, x);
+    v[11] = _mm256_add_epi32(v[11], rnding);
+    v[11] = _mm256_srai_epi32(v[11], bit);
+
+    y = _mm256_mullo_epi32(v[14], cospi32);
+    x = _mm256_mullo_epi32(v[15], cospi32);
+    v[14] = _mm256_add_epi32(y, x);
+    v[14] = _mm256_add_epi32(v[14], rnding);
+    v[14] = _mm256_srai_epi32(v[14], bit);
+
+    v[15] = _mm256_sub_epi32(y, x);
+    v[15] = _mm256_add_epi32(v[15], rnding);
+    v[15] = _mm256_srai_epi32(v[15], bit);
+
+    // stage 9
+    if (do_cols) {
+      out[0] = v[0];
+      out[1] = _mm256_sub_epi32(_mm256_setzero_si256(), v[8]);
+      out[2] = v[12];
+      out[3] = _mm256_sub_epi32(_mm256_setzero_si256(), v[4]);
+      out[4] = v[6];
+      out[5] = _mm256_sub_epi32(_mm256_setzero_si256(), v[14]);
+      out[6] = v[10];
+      out[7] = _mm256_sub_epi32(_mm256_setzero_si256(), v[2]);
+      out[8] = v[3];
+      out[9] = _mm256_sub_epi32(_mm256_setzero_si256(), v[11]);
+      out[10] = v[15];
+      out[11] = _mm256_sub_epi32(_mm256_setzero_si256(), v[7]);
+      out[12] = v[5];
+      out[13] = _mm256_sub_epi32(_mm256_setzero_si256(), v[13]);
+      out[14] = v[9];
+      out[15] = _mm256_sub_epi32(_mm256_setzero_si256(), v[1]);
+    } else {
+      const int log_range_out = AOMMAX(16, bd + 6);
+      const __m256i clamp_lo_out =
+          _mm256_set1_epi32(-(1 << (log_range_out - 1)));
+      const __m256i clamp_hi_out =
+          _mm256_set1_epi32((1 << (log_range_out - 1)) - 1);
+
+      neg_shift_avx2(v[0], v[8], out + 0, out + 1, &clamp_lo_out, &clamp_hi_out,
+                     out_shift);
+      neg_shift_avx2(v[12], v[4], out + 2, out + 3, &clamp_lo_out,
+                     &clamp_hi_out, out_shift);
+      neg_shift_avx2(v[6], v[14], out + 4, out + 5, &clamp_lo_out,
+                     &clamp_hi_out, out_shift);
+      neg_shift_avx2(v[10], v[2], out + 6, out + 7, &clamp_lo_out,
+                     &clamp_hi_out, out_shift);
+      neg_shift_avx2(v[3], v[11], out + 8, out + 9, &clamp_lo_out,
+                     &clamp_hi_out, out_shift);
+      neg_shift_avx2(v[15], v[7], out + 10, out + 11, &clamp_lo_out,
+                     &clamp_hi_out, out_shift);
+      neg_shift_avx2(v[5], v[13], out + 12, out + 13, &clamp_lo_out,
+                     &clamp_hi_out, out_shift);
+      neg_shift_avx2(v[9], v[1], out + 14, out + 15, &clamp_lo_out,
+                     &clamp_hi_out, out_shift);
+    }
+  }
+}
+
+static void iadst16_low8_avx2(__m256i *in, __m256i *out, int bit, int do_cols,
+                              int bd, int out_shift) {
+  const int32_t *cospi = cospi_arr(bit);
+  const __m256i cospi2 = _mm256_set1_epi32(cospi[2]);
+  const __m256i cospi62 = _mm256_set1_epi32(cospi[62]);
+  const __m256i cospi10 = _mm256_set1_epi32(cospi[10]);
+  const __m256i cospi54 = _mm256_set1_epi32(cospi[54]);
+  const __m256i cospi18 = _mm256_set1_epi32(cospi[18]);
+  const __m256i cospi46 = _mm256_set1_epi32(cospi[46]);
+  const __m256i cospi26 = _mm256_set1_epi32(cospi[26]);
+  const __m256i cospi38 = _mm256_set1_epi32(cospi[38]);
+  const __m256i cospi34 = _mm256_set1_epi32(cospi[34]);
+  const __m256i cospi30 = _mm256_set1_epi32(cospi[30]);
+  const __m256i cospi42 = _mm256_set1_epi32(cospi[42]);
+  const __m256i cospi22 = _mm256_set1_epi32(cospi[22]);
+  const __m256i cospi50 = _mm256_set1_epi32(cospi[50]);
+  const __m256i cospi14 = _mm256_set1_epi32(cospi[14]);
+  const __m256i cospi58 = _mm256_set1_epi32(cospi[58]);
+  const __m256i cospi6 = _mm256_set1_epi32(cospi[6]);
+  const __m256i cospi8 = _mm256_set1_epi32(cospi[8]);
+  const __m256i cospi56 = _mm256_set1_epi32(cospi[56]);
+  const __m256i cospi40 = _mm256_set1_epi32(cospi[40]);
+  const __m256i cospi24 = _mm256_set1_epi32(cospi[24]);
+  const __m256i cospim56 = _mm256_set1_epi32(-cospi[56]);
+  const __m256i cospim24 = _mm256_set1_epi32(-cospi[24]);
+  const __m256i cospi48 = _mm256_set1_epi32(cospi[48]);
+  const __m256i cospi16 = _mm256_set1_epi32(cospi[16]);
+  const __m256i cospim48 = _mm256_set1_epi32(-cospi[48]);
+  const __m256i cospi32 = _mm256_set1_epi32(cospi[32]);
+  const __m256i rnding = _mm256_set1_epi32(1 << (bit - 1));
+  const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8));
+  const __m256i clamp_lo = _mm256_set1_epi32(-(1 << (log_range - 1)));
+  const __m256i clamp_hi = _mm256_set1_epi32((1 << (log_range - 1)) - 1);
+  __m256i u[16], x, y;
+
+  {
+    // stage 0
+    // stage 1
+    // stage 2
+    __m256i zero = _mm256_setzero_si256();
+    x = _mm256_mullo_epi32(in[0], cospi62);
+    u[0] = _mm256_add_epi32(x, rnding);
+    u[0] = _mm256_srai_epi32(u[0], bit);
+
+    x = _mm256_mullo_epi32(in[0], cospi2);
+    u[1] = _mm256_sub_epi32(zero, x);
+    u[1] = _mm256_add_epi32(u[1], rnding);
+    u[1] = _mm256_srai_epi32(u[1], bit);
+
+    x = _mm256_mullo_epi32(in[2], cospi54);
+    u[2] = _mm256_add_epi32(x, rnding);
+    u[2] = _mm256_srai_epi32(u[2], bit);
+
+    x = _mm256_mullo_epi32(in[2], cospi10);
+    u[3] = _mm256_sub_epi32(zero, x);
+    u[3] = _mm256_add_epi32(u[3], rnding);
+    u[3] = _mm256_srai_epi32(u[3], bit);
+
+    x = _mm256_mullo_epi32(in[4], cospi46);
+    u[4] = _mm256_add_epi32(x, rnding);
+    u[4] = _mm256_srai_epi32(u[4], bit);
+
+    x = _mm256_mullo_epi32(in[4], cospi18);
+    u[5] = _mm256_sub_epi32(zero, x);
+    u[5] = _mm256_add_epi32(u[5], rnding);
+    u[5] = _mm256_srai_epi32(u[5], bit);
+
+    x = _mm256_mullo_epi32(in[6], cospi38);
+    u[6] = _mm256_add_epi32(x, rnding);
+    u[6] = _mm256_srai_epi32(u[6], bit);
+
+    x = _mm256_mullo_epi32(in[6], cospi26);
+    u[7] = _mm256_sub_epi32(zero, x);
+    u[7] = _mm256_add_epi32(u[7], rnding);
+    u[7] = _mm256_srai_epi32(u[7], bit);
+
+    u[8] = _mm256_mullo_epi32(in[7], cospi34);
+    u[8] = _mm256_add_epi32(u[8], rnding);
+    u[8] = _mm256_srai_epi32(u[8], bit);
+
+    u[9] = _mm256_mullo_epi32(in[7], cospi30);
+    u[9] = _mm256_add_epi32(u[9], rnding);
+    u[9] = _mm256_srai_epi32(u[9], bit);
+
+    u[10] = _mm256_mullo_epi32(in[5], cospi42);
+    u[10] = _mm256_add_epi32(u[10], rnding);
+    u[10] = _mm256_srai_epi32(u[10], bit);
+
+    u[11] = _mm256_mullo_epi32(in[5], cospi22);
+    u[11] = _mm256_add_epi32(u[11], rnding);
+    u[11] = _mm256_srai_epi32(u[11], bit);
+
+    u[12] = _mm256_mullo_epi32(in[3], cospi50);
+    u[12] = _mm256_add_epi32(u[12], rnding);
+    u[12] = _mm256_srai_epi32(u[12], bit);
+
+    u[13] = _mm256_mullo_epi32(in[3], cospi14);
+    u[13] = _mm256_add_epi32(u[13], rnding);
+    u[13] = _mm256_srai_epi32(u[13], bit);
+
+    u[14] = _mm256_mullo_epi32(in[1], cospi58);
+    u[14] = _mm256_add_epi32(u[14], rnding);
+    u[14] = _mm256_srai_epi32(u[14], bit);
+
+    u[15] = _mm256_mullo_epi32(in[1], cospi6);
+    u[15] = _mm256_add_epi32(u[15], rnding);
+    u[15] = _mm256_srai_epi32(u[15], bit);
+
+    // stage 3
+    addsub_avx2(u[0], u[8], &u[0], &u[8], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[1], u[9], &u[1], &u[9], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[2], u[10], &u[2], &u[10], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[3], u[11], &u[3], &u[11], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[4], u[12], &u[4], &u[12], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[5], u[13], &u[5], &u[13], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[6], u[14], &u[6], &u[14], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[7], u[15], &u[7], &u[15], &clamp_lo, &clamp_hi);
+
+    // stage 4
+    y = _mm256_mullo_epi32(u[8], cospi56);
+    x = _mm256_mullo_epi32(u[9], cospi56);
+    u[8] = _mm256_mullo_epi32(u[8], cospi8);
+    u[8] = _mm256_add_epi32(u[8], x);
+    u[8] = _mm256_add_epi32(u[8], rnding);
+    u[8] = _mm256_srai_epi32(u[8], bit);
+
+    x = _mm256_mullo_epi32(u[9], cospi8);
+    u[9] = _mm256_sub_epi32(y, x);
+    u[9] = _mm256_add_epi32(u[9], rnding);
+    u[9] = _mm256_srai_epi32(u[9], bit);
+
+    x = _mm256_mullo_epi32(u[11], cospi24);
+    y = _mm256_mullo_epi32(u[10], cospi24);
+    u[10] = _mm256_mullo_epi32(u[10], cospi40);
+    u[10] = _mm256_add_epi32(u[10], x);
+    u[10] = _mm256_add_epi32(u[10], rnding);
+    u[10] = _mm256_srai_epi32(u[10], bit);
+
+    x = _mm256_mullo_epi32(u[11], cospi40);
+    u[11] = _mm256_sub_epi32(y, x);
+    u[11] = _mm256_add_epi32(u[11], rnding);
+    u[11] = _mm256_srai_epi32(u[11], bit);
+
+    x = _mm256_mullo_epi32(u[13], cospi8);
+    y = _mm256_mullo_epi32(u[12], cospi8);
+    u[12] = _mm256_mullo_epi32(u[12], cospim56);
+    u[12] = _mm256_add_epi32(u[12], x);
+    u[12] = _mm256_add_epi32(u[12], rnding);
+    u[12] = _mm256_srai_epi32(u[12], bit);
+
+    x = _mm256_mullo_epi32(u[13], cospim56);
+    u[13] = _mm256_sub_epi32(y, x);
+    u[13] = _mm256_add_epi32(u[13], rnding);
+    u[13] = _mm256_srai_epi32(u[13], bit);
+
+    x = _mm256_mullo_epi32(u[15], cospi40);
+    y = _mm256_mullo_epi32(u[14], cospi40);
+    u[14] = _mm256_mullo_epi32(u[14], cospim24);
+    u[14] = _mm256_add_epi32(u[14], x);
+    u[14] = _mm256_add_epi32(u[14], rnding);
+    u[14] = _mm256_srai_epi32(u[14], bit);
+
+    x = _mm256_mullo_epi32(u[15], cospim24);
+    u[15] = _mm256_sub_epi32(y, x);
+    u[15] = _mm256_add_epi32(u[15], rnding);
+    u[15] = _mm256_srai_epi32(u[15], bit);
+
+    // stage 5
+    addsub_avx2(u[0], u[4], &u[0], &u[4], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[1], u[5], &u[1], &u[5], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[2], u[6], &u[2], &u[6], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[3], u[7], &u[3], &u[7], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[8], u[12], &u[8], &u[12], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[9], u[13], &u[9], &u[13], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[10], u[14], &u[10], &u[14], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[11], u[15], &u[11], &u[15], &clamp_lo, &clamp_hi);
+
+    // stage 6
+    x = _mm256_mullo_epi32(u[5], cospi48);
+    y = _mm256_mullo_epi32(u[4], cospi48);
+    u[4] = _mm256_mullo_epi32(u[4], cospi16);
+    u[4] = _mm256_add_epi32(u[4], x);
+    u[4] = _mm256_add_epi32(u[4], rnding);
+    u[4] = _mm256_srai_epi32(u[4], bit);
+
+    x = _mm256_mullo_epi32(u[5], cospi16);
+    u[5] = _mm256_sub_epi32(y, x);
+    u[5] = _mm256_add_epi32(u[5], rnding);
+    u[5] = _mm256_srai_epi32(u[5], bit);
+
+    x = _mm256_mullo_epi32(u[7], cospi16);
+    y = _mm256_mullo_epi32(u[6], cospi16);
+    u[6] = _mm256_mullo_epi32(u[6], cospim48);
+    u[6] = _mm256_add_epi32(u[6], x);
+    u[6] = _mm256_add_epi32(u[6], rnding);
+    u[6] = _mm256_srai_epi32(u[6], bit);
+
+    x = _mm256_mullo_epi32(u[7], cospim48);
+    u[7] = _mm256_sub_epi32(y, x);
+    u[7] = _mm256_add_epi32(u[7], rnding);
+    u[7] = _mm256_srai_epi32(u[7], bit);
+
+    x = _mm256_mullo_epi32(u[13], cospi48);
+    y = _mm256_mullo_epi32(u[12], cospi48);
+    u[12] = _mm256_mullo_epi32(u[12], cospi16);
+    u[12] = _mm256_add_epi32(u[12], x);
+    u[12] = _mm256_add_epi32(u[12], rnding);
+    u[12] = _mm256_srai_epi32(u[12], bit);
+
+    x = _mm256_mullo_epi32(u[13], cospi16);
+    u[13] = _mm256_sub_epi32(y, x);
+    u[13] = _mm256_add_epi32(u[13], rnding);
+    u[13] = _mm256_srai_epi32(u[13], bit);
+
+    x = _mm256_mullo_epi32(u[15], cospi16);
+    y = _mm256_mullo_epi32(u[14], cospi16);
+    u[14] = _mm256_mullo_epi32(u[14], cospim48);
+    u[14] = _mm256_add_epi32(u[14], x);
+    u[14] = _mm256_add_epi32(u[14], rnding);
+    u[14] = _mm256_srai_epi32(u[14], bit);
+
+    x = _mm256_mullo_epi32(u[15], cospim48);
+    u[15] = _mm256_sub_epi32(y, x);
+    u[15] = _mm256_add_epi32(u[15], rnding);
+    u[15] = _mm256_srai_epi32(u[15], bit);
+
+    // stage 7
+    addsub_avx2(u[0], u[2], &u[0], &u[2], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[1], u[3], &u[1], &u[3], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[4], u[6], &u[4], &u[6], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[5], u[7], &u[5], &u[7], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[8], u[10], &u[8], &u[10], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[9], u[11], &u[9], &u[11], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[12], u[14], &u[12], &u[14], &clamp_lo, &clamp_hi);
+    addsub_avx2(u[13], u[15], &u[13], &u[15], &clamp_lo, &clamp_hi);
+
+    // stage 8
+    y = _mm256_mullo_epi32(u[2], cospi32);
+    x = _mm256_mullo_epi32(u[3], cospi32);
+    u[2] = _mm256_add_epi32(y, x);
+    u[2] = _mm256_add_epi32(u[2], rnding);
+    u[2] = _mm256_srai_epi32(u[2], bit);
+
+    u[3] = _mm256_sub_epi32(y, x);
+    u[3] = _mm256_add_epi32(u[3], rnding);
+    u[3] = _mm256_srai_epi32(u[3], bit);
+    y = _mm256_mullo_epi32(u[6], cospi32);
+    x = _mm256_mullo_epi32(u[7], cospi32);
+    u[6] = _mm256_add_epi32(y, x);
+    u[6] = _mm256_add_epi32(u[6], rnding);
+    u[6] = _mm256_srai_epi32(u[6], bit);
+
+    u[7] = _mm256_sub_epi32(y, x);
+    u[7] = _mm256_add_epi32(u[7], rnding);
+    u[7] = _mm256_srai_epi32(u[7], bit);
+
+    y = _mm256_mullo_epi32(u[10], cospi32);
+    x = _mm256_mullo_epi32(u[11], cospi32);
+    u[10] = _mm256_add_epi32(y, x);
+    u[10] = _mm256_add_epi32(u[10], rnding);
+    u[10] = _mm256_srai_epi32(u[10], bit);
+
+    u[11] = _mm256_sub_epi32(y, x);
+    u[11] = _mm256_add_epi32(u[11], rnding);
+    u[11] = _mm256_srai_epi32(u[11], bit);
+
+    y = _mm256_mullo_epi32(u[14], cospi32);
+    x = _mm256_mullo_epi32(u[15], cospi32);
+    u[14] = _mm256_add_epi32(y, x);
+    u[14] = _mm256_add_epi32(u[14], rnding);
+    u[14] = _mm256_srai_epi32(u[14], bit);
+
+    u[15] = _mm256_sub_epi32(y, x);
+    u[15] = _mm256_add_epi32(u[15], rnding);
+    u[15] = _mm256_srai_epi32(u[15], bit);
+
+    // stage 9
+    if (do_cols) {
+      out[0] = u[0];
+      out[1] = _mm256_sub_epi32(_mm256_setzero_si256(), u[8]);
+      out[2] = u[12];
+      out[3] = _mm256_sub_epi32(_mm256_setzero_si256(), u[4]);
+      out[4] = u[6];
+      out[5] = _mm256_sub_epi32(_mm256_setzero_si256(), u[14]);
+      out[6] = u[10];
+      out[7] = _mm256_sub_epi32(_mm256_setzero_si256(), u[2]);
+      out[8] = u[3];
+      out[9] = _mm256_sub_epi32(_mm256_setzero_si256(), u[11]);
+      out[10] = u[15];
+      out[11] = _mm256_sub_epi32(_mm256_setzero_si256(), u[7]);
+      out[12] = u[5];
+      out[13] = _mm256_sub_epi32(_mm256_setzero_si256(), u[13]);
+      out[14] = u[9];
+      out[15] = _mm256_sub_epi32(_mm256_setzero_si256(), u[1]);
+    } else {
+      const int log_range_out = AOMMAX(16, bd + 6);
+      const __m256i clamp_lo_out =
+          _mm256_set1_epi32(-(1 << (log_range_out - 1)));
+      const __m256i clamp_hi_out =
+          _mm256_set1_epi32((1 << (log_range_out - 1)) - 1);
+
+      neg_shift_avx2(u[0], u[8], out + 0, out + 1, &clamp_lo_out, &clamp_hi_out,
+                     out_shift);
+      neg_shift_avx2(u[12], u[4], out + 2, out + 3, &clamp_lo_out,
+                     &clamp_hi_out, out_shift);
+      neg_shift_avx2(u[6], u[14], out + 4, out + 5, &clamp_lo_out,
+                     &clamp_hi_out, out_shift);
+      neg_shift_avx2(u[10], u[2], out + 6, out + 7, &clamp_lo_out,
+                     &clamp_hi_out, out_shift);
+      neg_shift_avx2(u[3], u[11], out + 8, out + 9, &clamp_lo_out,
+                     &clamp_hi_out, out_shift);
+      neg_shift_avx2(u[15], u[7], out + 10, out + 11, &clamp_lo_out,
+                     &clamp_hi_out, out_shift);
+      neg_shift_avx2(u[5], u[13], out + 12, out + 13, &clamp_lo_out,
+                     &clamp_hi_out, out_shift);
+      neg_shift_avx2(u[9], u[1], out + 14, out + 15, &clamp_lo_out,
+                     &clamp_hi_out, out_shift);
+    }
+  }
+}
+
+static void iadst16_avx2(__m256i *in, __m256i *out, int bit, int do_cols,
+                         int bd, int out_shift) {
+  const int32_t *cospi = cospi_arr(bit);
+  const __m256i cospi2 = _mm256_set1_epi32(cospi[2]);
+  const __m256i cospi62 = _mm256_set1_epi32(cospi[62]);
+  const __m256i cospi10 = _mm256_set1_epi32(cospi[10]);
+  const __m256i cospi54 = _mm256_set1_epi32(cospi[54]);
+  const __m256i cospi18 = _mm256_set1_epi32(cospi[18]);
+  const __m256i cospi46 = _mm256_set1_epi32(cospi[46]);
+  const __m256i cospi26 = _mm256_set1_epi32(cospi[26]);
+  const __m256i cospi38 = _mm256_set1_epi32(cospi[38]);
+  const __m256i cospi34 = _mm256_set1_epi32(cospi[34]);
+  const __m256i cospi30 = _mm256_set1_epi32(cospi[30]);
+  const __m256i cospi42 = _mm256_set1_epi32(cospi[42]);
+  const __m256i cospi22 = _mm256_set1_epi32(cospi[22]);
+  const __m256i cospi50 = _mm256_set1_epi32(cospi[50]);
+  const __m256i cospi14 = _mm256_set1_epi32(cospi[14]);
+  const __m256i cospi58 = _mm256_set1_epi32(cospi[58]);
+  const __m256i cospi6 = _mm256_set1_epi32(cospi[6]);
+  const __m256i cospi8 = _mm256_set1_epi32(cospi[8]);
+  const __m256i cospi56 = _mm256_set1_epi32(cospi[56]);
+  const __m256i cospi40 = _mm256_set1_epi32(cospi[40]);
+  const __m256i cospi24 = _mm256_set1_epi32(cospi[24]);
+  const __m256i cospim56 = _mm256_set1_epi32(-cospi[56]);
+  const __m256i cospim24 = _mm256_set1_epi32(-cospi[24]);
+  const __m256i cospi48 = _mm256_set1_epi32(cospi[48]);
+  const __m256i cospi16 = _mm256_set1_epi32(cospi[16]);
+  const __m256i cospim48 = _mm256_set1_epi32(-cospi[48]);
+  const __m256i cospi32 = _mm256_set1_epi32(cospi[32]);
+  const __m256i rnding = _mm256_set1_epi32(1 << (bit - 1));
+  const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8));
+  const __m256i clamp_lo = _mm256_set1_epi32(-(1 << (log_range - 1)));
+  const __m256i clamp_hi = _mm256_set1_epi32((1 << (log_range - 1)) - 1);
+  __m256i u[16], v[16], x, y;
+
+  {
+    // stage 0
+    // stage 1
+    // stage 2
+    v[0] = _mm256_mullo_epi32(in[15], cospi2);
+    x = _mm256_mullo_epi32(in[0], cospi62);
+    v[0] = _mm256_add_epi32(v[0], x);
+    v[0] = _mm256_add_epi32(v[0], rnding);
+    v[0] = _mm256_srai_epi32(v[0], bit);
+
+    v[1] = _mm256_mullo_epi32(in[15], cospi62);
+    x = _mm256_mullo_epi32(in[0], cospi2);
+    v[1] = _mm256_sub_epi32(v[1], x);
+    v[1] = _mm256_add_epi32(v[1], rnding);
+    v[1] = _mm256_srai_epi32(v[1], bit);
+
+    v[2] = _mm256_mullo_epi32(in[13], cospi10);
+    x = _mm256_mullo_epi32(in[2], cospi54);
+    v[2] = _mm256_add_epi32(v[2], x);
+    v[2] = _mm256_add_epi32(v[2], rnding);
+    v[2] = _mm256_srai_epi32(v[2], bit);
+
+    v[3] = _mm256_mullo_epi32(in[13], cospi54);
+    x = _mm256_mullo_epi32(in[2], cospi10);
+    v[3] = _mm256_sub_epi32(v[3], x);
+    v[3] = _mm256_add_epi32(v[3], rnding);
+    v[3] = _mm256_srai_epi32(v[3], bit);
+
+    v[4] = _mm256_mullo_epi32(in[11], cospi18);
+    x = _mm256_mullo_epi32(in[4], cospi46);
+    v[4] = _mm256_add_epi32(v[4], x);
+    v[4] = _mm256_add_epi32(v[4], rnding);
+    v[4] = _mm256_srai_epi32(v[4], bit);
+
+    v[5] = _mm256_mullo_epi32(in[11], cospi46);
+    x = _mm256_mullo_epi32(in[4], cospi18);
+    v[5] = _mm256_sub_epi32(v[5], x);
+    v[5] = _mm256_add_epi32(v[5], rnding);
+    v[5] = _mm256_srai_epi32(v[5], bit);
+
+    v[6] = _mm256_mullo_epi32(in[9], cospi26);
+    x = _mm256_mullo_epi32(in[6], cospi38);
+    v[6] = _mm256_add_epi32(v[6], x);
+    v[6] = _mm256_add_epi32(v[6], rnding);
+    v[6] = _mm256_srai_epi32(v[6], bit);
+
+    v[7] = _mm256_mullo_epi32(in[9], cospi38);
+    x = _mm256_mullo_epi32(in[6], cospi26);
+    v[7] = _mm256_sub_epi32(v[7], x);
+    v[7] = _mm256_add_epi32(v[7], rnding);
+    v[7] = _mm256_srai_epi32(v[7], bit);
+
+    v[8] = _mm256_mullo_epi32(in[7], cospi34);
+    x = _mm256_mullo_epi32(in[8], cospi30);
+    v[8] = _mm256_add_epi32(v[8], x);
+    v[8] = _mm256_add_epi32(v[8], rnding);
+    v[8] = _mm256_srai_epi32(v[8], bit);
+
+    v[9] = _mm256_mullo_epi32(in[7], cospi30);
+    x = _mm256_mullo_epi32(in[8], cospi34);
+    v[9] = _mm256_sub_epi32(v[9], x);
+    v[9] = _mm256_add_epi32(v[9], rnding);
+    v[9] = _mm256_srai_epi32(v[9], bit);
+
+    v[10] = _mm256_mullo_epi32(in[5], cospi42);
+    x = _mm256_mullo_epi32(in[10], cospi22);
+    v[10] = _mm256_add_epi32(v[10], x);
+    v[10] = _mm256_add_epi32(v[10], rnding);
+    v[10] = _mm256_srai_epi32(v[10], bit);
+
+    v[11] = _mm256_mullo_epi32(in[5], cospi22);
+    x = _mm256_mullo_epi32(in[10], cospi42);
+    v[11] = _mm256_sub_epi32(v[11], x);
+    v[11] = _mm256_add_epi32(v[11], rnding);
+    v[11] = _mm256_srai_epi32(v[11], bit);
+
+    v[12] = _mm256_mullo_epi32(in[3], cospi50);
+    x = _mm256_mullo_epi32(in[12], cospi14);
+    v[12] = _mm256_add_epi32(v[12], x);
+    v[12] = _mm256_add_epi32(v[12], rnding);
+    v[12] = _mm256_srai_epi32(v[12], bit);
+
+    v[13] = _mm256_mullo_epi32(in[3], cospi14);
+    x = _mm256_mullo_epi32(in[12], cospi50);
+    v[13] = _mm256_sub_epi32(v[13], x);
+    v[13] = _mm256_add_epi32(v[13], rnding);
+    v[13] = _mm256_srai_epi32(v[13], bit);
+
+    v[14] = _mm256_mullo_epi32(in[1], cospi58);
+    x = _mm256_mullo_epi32(in[14], cospi6);
+    v[14] = _mm256_add_epi32(v[14], x);
+    v[14] = _mm256_add_epi32(v[14], rnding);
+    v[14] = _mm256_srai_epi32(v[14], bit);
+
+    v[15] = _mm256_mullo_epi32(in[1], cospi6);
+    x = _mm256_mullo_epi32(in[14], cospi58);
+    v[15] = _mm256_sub_epi32(v[15], x);
+    v[15] = _mm256_add_epi32(v[15], rnding);
+    v[15] = _mm256_srai_epi32(v[15], bit);
+
+    // stage 3
+    addsub_avx2(v[0], v[8], &u[0], &u[8], &clamp_lo, &clamp_hi);
+    addsub_avx2(v[1], v[9], &u[1], &u[9], &clamp_lo, &clamp_hi);
+    addsub_avx2(v[2], v[10], &u[2], &u[10], &clamp_lo, &clamp_hi);
+    addsub_avx2(v[3], v[11], &u[3], &u[11], &clamp_lo, &clamp_hi);
+    addsub_avx2(v[4], v[12], &u[4], &u[12], &clamp_lo, &clamp_hi);
+    addsub_avx2(v[5], v[13], &u[5], &u[13], &clamp_lo, &clamp_hi);
+    addsub_avx2(v[6], v[14], &u[6], &u[14], &clamp_lo, &clamp_hi);
+    addsub_avx2(v[7], v[15], &u[7], &u[15], &clamp_lo, &clamp_hi);
+
+    // stage 4
+    v[0] = u[0];
+    v[1] = u[1];
+    v[2] = u[2];
+    v[3] = u[3];
+    v[4] = u[4];
+    v[5] = u[5];
+    v[6] = u[6];
+    v[7] = u[7];
+
+    v[8] = _mm256_mullo_epi32(u[8], cospi8);
+    x = _mm256_mullo_epi32(u[9], cospi56);
+    v[8] = _mm256_add_epi32(v[8], x);
+    v[8] = _mm256_add_epi32(v[8], rnding);
+    v[8] = _mm256_srai_epi32(v[8], bit);
+
+    v[9] = _mm256_mullo_epi32(u[8], cospi56);
+    x = _mm256_mullo_epi32(u[9], cospi8);
+    v[9] = _mm256_sub_epi32(v[9], x);
+    v[9] = _mm256_add_epi32(v[9], rnding);
+    v[9] = _mm256_srai_epi32(v[9], bit);
+
+    v[10] = _mm256_mullo_epi32(u[10], cospi40);
+    x = _mm256_mullo_epi32(u[11], cospi24);
+    v[10] = _mm256_add_epi32(v[10], x);
+    v[10] = _mm256_add_epi32(v[10], rnding);
+    v[10] = _mm256_srai_epi32(v[10], bit);
+
+    v[11] = _mm256_mullo_epi32(u[10], cospi24);
+    x = _mm256_mullo_epi32(u[11], cospi40);
+    v[11] = _mm256_sub_epi32(v[11], x);
+    v[11] = _mm256_add_epi32(v[11], rnding);
+    v[11] = _mm256_srai_epi32(v[11], bit);
+
+    v[12] = _mm256_mullo_epi32(u[12], cospim56);
+    x = _mm256_mullo_epi32(u[13], cospi8);
+    v[12] = _mm256_add_epi32(v[12], x);
+    v[12] = _mm256_add_epi32(v[12], rnding);
+    v[12] = _mm256_srai_epi32(v[12], bit);
+
+    v[13] = _mm256_mullo_epi32(u[12], cospi8);
+    x = _mm256_mullo_epi32(u[13], cospim56);
+    v[13] = _mm256_sub_epi32(v[13], x);
+    v[13] = _mm256_add_epi32(v[13], rnding);
+    v[13] = _mm256_srai_epi32(v[13], bit);
+
+    v[14] = _mm256_mullo_epi32(u[14], cospim24);
+    x = _mm256_mullo_epi32(u[15], cospi40);
+    v[14] = _mm256_add_epi32(v[14], x);
+    v[14] = _mm256_add_epi32(v[14], rnding);
+    v[14] = _mm256_srai_epi32(v[14], bit);
+
+    v[15] = _mm256_mullo_epi32(u[14], cospi40);
+    x = _mm256_mullo_epi32(u[15], cospim24);
+    v[15] = _mm256_sub_epi32(v[15], x);
+    v[15] = _mm256_add_epi32(v[15], rnding);
+    v[15] = _mm256_srai_epi32(v[15], bit);
+
+    // stage 5
+    addsub_avx2(v[0], v[4], &u[0], &u[4], &clamp_lo, &clamp_hi);
+    addsub_avx2(v[1], v[5], &u[1], &u[5], &clamp_lo, &clamp_hi);
+    addsub_avx2(v[2], v[6], &u[2], &u[6], &clamp_lo, &clamp_hi);
+    addsub_avx2(v[3], v[7], &u[3], &u[7], &clamp_lo, &clamp_hi);
+    addsub_avx2(v[8], v[12], &u[8], &u[12], &clamp_lo, &clamp_hi);
+    addsub_avx2(v[9], v[13], &u[9], &u[13], &clamp_lo, &clamp_hi);
+    addsub_avx2(v[10], v[14], &u[10], &u[14], &clamp_lo, &clamp_hi);
+    addsub_avx2(v[11], v[15], &u[11], &u[15], &clamp_lo, &clamp_hi);
+
+    // stage 6
+    v[0] = u[0];
+    v[1] = u[1];
+    v[2] = u[2];
+    v[3] = u[3];
+
+    v[4] = _mm256_mullo_epi32(u[4], cospi16);
+    x = _mm256_mullo_epi32(u[5], cospi48);
+    v[4] = _mm256_add_epi32(v[4], x);
+    v[4] = _mm256_add_epi32(v[4], rnding);
+    v[4] = _mm256_srai_epi32(v[4], bit);
+
+    v[5] = _mm256_mullo_epi32(u[4], cospi48);
+    x = _mm256_mullo_epi32(u[5], cospi16);
+    v[5] = _mm256_sub_epi32(v[5], x);
+    v[5] = _mm256_add_epi32(v[5], rnding);
+    v[5] = _mm256_srai_epi32(v[5], bit);
+
+    v[6] = _mm256_mullo_epi32(u[6], cospim48);
+    x = _mm256_mullo_epi32(u[7], cospi16);
+    v[6] = _mm256_add_epi32(v[6], x);
+    v[6] = _mm256_add_epi32(v[6], rnding);
+    v[6] = _mm256_srai_epi32(v[6], bit);
+
+    v[7] = _mm256_mullo_epi32(u[6], cospi16);
+    x = _mm256_mullo_epi32(u[7], cospim48);
+    v[7] = _mm256_sub_epi32(v[7], x);
+    v[7] = _mm256_add_epi32(v[7], rnding);
+    v[7] = _mm256_srai_epi32(v[7], bit);
+
+    v[8] = u[8];
+    v[9] = u[9];
+    v[10] = u[10];
+    v[11] = u[11];
+
+    v[12] = _mm256_mullo_epi32(u[12], cospi16);
+    x = _mm256_mullo_epi32(u[13], cospi48);
+    v[12] = _mm256_add_epi32(v[12], x);
+    v[12] = _mm256_add_epi32(v[12], rnding);
+    v[12] = _mm256_srai_epi32(v[12], bit);
+
+    v[13] = _mm256_mullo_epi32(u[12], cospi48);
+    x = _mm256_mullo_epi32(u[13], cospi16);
+    v[13] = _mm256_sub_epi32(v[13], x);
+    v[13] = _mm256_add_epi32(v[13], rnding);
+    v[13] = _mm256_srai_epi32(v[13], bit);
+
+    v[14] = _mm256_mullo_epi32(u[14], cospim48);
+    x = _mm256_mullo_epi32(u[15], cospi16);
+    v[14] = _mm256_add_epi32(v[14], x);
+    v[14] = _mm256_add_epi32(v[14], rnding);
+    v[14] = _mm256_srai_epi32(v[14], bit);
+
+    v[15] = _mm256_mullo_epi32(u[14], cospi16);
+    x = _mm256_mullo_epi32(u[15], cospim48);
+    v[15] = _mm256_sub_epi32(v[15], x);
+    v[15] = _mm256_add_epi32(v[15], rnding);
+    v[15] = _mm256_srai_epi32(v[15], bit);
+
+    // stage 7
+    addsub_avx2(v[0], v[2], &u[0], &u[2], &clamp_lo, &clamp_hi);
+    addsub_avx2(v[1], v[3], &u[1], &u[3], &clamp_lo, &clamp_hi);
+    addsub_avx2(v[4], v[6], &u[4], &u[6], &clamp_lo, &clamp_hi);
+    addsub_avx2(v[5], v[7], &u[5], &u[7], &clamp_lo, &clamp_hi);
+    addsub_avx2(v[8], v[10], &u[8], &u[10], &clamp_lo, &clamp_hi);
+    addsub_avx2(v[9], v[11], &u[9], &u[11], &clamp_lo, &clamp_hi);
+    addsub_avx2(v[12], v[14], &u[12], &u[14], &clamp_lo, &clamp_hi);
+    addsub_avx2(v[13], v[15], &u[13], &u[15], &clamp_lo, &clamp_hi);
+
+    // stage 8
+    v[0] = u[0];
+    v[1] = u[1];
+
+    y = _mm256_mullo_epi32(u[2], cospi32);
+    x = _mm256_mullo_epi32(u[3], cospi32);
+    v[2] = _mm256_add_epi32(y, x);
+    v[2] = _mm256_add_epi32(v[2], rnding);
+    v[2] = _mm256_srai_epi32(v[2], bit);
+
+    v[3] = _mm256_sub_epi32(y, x);
+    v[3] = _mm256_add_epi32(v[3], rnding);
+    v[3] = _mm256_srai_epi32(v[3], bit);
+
+    v[4] = u[4];
+    v[5] = u[5];
+
+    y = _mm256_mullo_epi32(u[6], cospi32);
+    x = _mm256_mullo_epi32(u[7], cospi32);
+    v[6] = _mm256_add_epi32(y, x);
+    v[6] = _mm256_add_epi32(v[6], rnding);
+    v[6] = _mm256_srai_epi32(v[6], bit);
+
+    v[7] = _mm256_sub_epi32(y, x);
+    v[7] = _mm256_add_epi32(v[7], rnding);
+    v[7] = _mm256_srai_epi32(v[7], bit);
+
+    v[8] = u[8];
+    v[9] = u[9];
+
+    y = _mm256_mullo_epi32(u[10], cospi32);
+    x = _mm256_mullo_epi32(u[11], cospi32);
+    v[10] = _mm256_add_epi32(y, x);
+    v[10] = _mm256_add_epi32(v[10], rnding);
+    v[10] = _mm256_srai_epi32(v[10], bit);
+
+    v[11] = _mm256_sub_epi32(y, x);
+    v[11] = _mm256_add_epi32(v[11], rnding);
+    v[11] = _mm256_srai_epi32(v[11], bit);
+
+    v[12] = u[12];
+    v[13] = u[13];
+
+    y = _mm256_mullo_epi32(u[14], cospi32);
+    x = _mm256_mullo_epi32(u[15], cospi32);
+    v[14] = _mm256_add_epi32(y, x);
+    v[14] = _mm256_add_epi32(v[14], rnding);
+    v[14] = _mm256_srai_epi32(v[14], bit);
+
+    v[15] = _mm256_sub_epi32(y, x);
+    v[15] = _mm256_add_epi32(v[15], rnding);
+    v[15] = _mm256_srai_epi32(v[15], bit);
+
+    // stage 9
+    if (do_cols) {
+      out[0] = v[0];
+      out[1] = _mm256_sub_epi32(_mm256_setzero_si256(), v[8]);
+      out[2] = v[12];
+      out[3] = _mm256_sub_epi32(_mm256_setzero_si256(), v[4]);
+      out[4] = v[6];
+      out[5] = _mm256_sub_epi32(_mm256_setzero_si256(), v[14]);
+      out[6] = v[10];
+      out[7] = _mm256_sub_epi32(_mm256_setzero_si256(), v[2]);
+      out[8] = v[3];
+      out[9] = _mm256_sub_epi32(_mm256_setzero_si256(), v[11]);
+      out[10] = v[15];
+      out[11] = _mm256_sub_epi32(_mm256_setzero_si256(), v[7]);
+      out[12] = v[5];
+      out[13] = _mm256_sub_epi32(_mm256_setzero_si256(), v[13]);
+      out[14] = v[9];
+      out[15] = _mm256_sub_epi32(_mm256_setzero_si256(), v[1]);
+    } else {
+      const int log_range_out = AOMMAX(16, bd + 6);
+      const __m256i clamp_lo_out =
+          _mm256_set1_epi32(-(1 << (log_range_out - 1)));
+      const __m256i clamp_hi_out =
+          _mm256_set1_epi32((1 << (log_range_out - 1)) - 1);
+
+      neg_shift_avx2(v[0], v[8], out + 0, out + 1, &clamp_lo_out, &clamp_hi_out,
+                     out_shift);
+      neg_shift_avx2(v[12], v[4], out + 2, out + 3, &clamp_lo_out,
+                     &clamp_hi_out, out_shift);
+      neg_shift_avx2(v[6], v[14], out + 4, out + 5, &clamp_lo_out,
+                     &clamp_hi_out, out_shift);
+      neg_shift_avx2(v[10], v[2], out + 6, out + 7, &clamp_lo_out,
+                     &clamp_hi_out, out_shift);
+      neg_shift_avx2(v[3], v[11], out + 8, out + 9, &clamp_lo_out,
+                     &clamp_hi_out, out_shift);
+      neg_shift_avx2(v[15], v[7], out + 10, out + 11, &clamp_lo_out,
+                     &clamp_hi_out, out_shift);
+      neg_shift_avx2(v[5], v[13], out + 12, out + 13, &clamp_lo_out,
+                     &clamp_hi_out, out_shift);
+      neg_shift_avx2(v[9], v[1], out + 14, out + 15, &clamp_lo_out,
+                     &clamp_hi_out, out_shift);
+    }
+  }
+}
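Each mullo/add/srai group in stages 6 and 8 above implements the fixed-point rotation (in0 * w0 + in1 * w1 + (1 << (bit - 1))) >> bit, with `rnding` holding the rounding constant. A scalar sketch for reference, using an illustrative helper name and 64-bit accumulation for clarity; the vector code keeps intermediates in 32 bits, which the earlier clamp stages make safe:

  static INLINE int32_t rotate_round_shift(int32_t in0, int32_t w0, int32_t in1,
                                           int32_t w1, int bit) {
    // (in0 * w0 + in1 * w1 + round) >> bit, matching one mullo/add/srai group.
    const int64_t sum = (int64_t)in0 * w0 + (int64_t)in1 * w1;
    return (int32_t)((sum + ((int64_t)1 << (bit - 1))) >> bit);
  }
  // e.g. stage 6 with scalar cospi values:
  //   v[4] = rotate_round_shift(u[4], cospi16, u[5], cospi48, bit);
  //   v[5] = rotate_round_shift(u[4], cospi48, u[5], -cospi16, bit);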
 
 typedef void (*transform_1d_avx2)(__m256i *in, __m256i *out, int bit,
                                   int do_cols, int bd, int out_shift);
@@ -1164,14 +2492,16 @@
           { NULL, NULL, NULL, NULL },
           { NULL, NULL, NULL, NULL },
       },
-      { { NULL, NULL, NULL, NULL },
-        { NULL, NULL, NULL, NULL },
-        { NULL, NULL, NULL, NULL } },
       {
           { NULL, NULL, NULL, NULL },
           { NULL, NULL, NULL, NULL },
           { NULL, NULL, NULL, NULL },
       },
+      {
+          { idct16_low1_avx2, idct16_low8_avx2, idct16_avx2, NULL },
+          { iadst16_low1_avx2, iadst16_low8_avx2, iadst16_avx2, NULL },
+          { NULL, NULL, NULL, NULL },
+      },
       { { idct32_low1_avx2, idct32_low8_avx2, idct32_low16_avx2, idct32_avx2 },
         { NULL, NULL, NULL, NULL },
         { NULL, NULL, NULL, NULL } },
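For context, each per-size block of this table is ordered by 1-D transform class (DCT, ADST/FLIPADST, identity) and, within a class, by an eob bucket (low1 / low8 / low16 / full). A sketch of the lookup shape, with illustrative identifiers rather than the file's actual names:

  // Illustrative selection of a 1-D kernel from a table laid out like the one
  // above (the real code derives the indices from tx size, tx_type and eob):
  //   type_idx : 0 = DCT, 1 = ADST/FLIPADST, 2 = identity
  //   eob_idx  : 0..3 for the low1 / low8 / low16 / full variants
  static INLINE transform_1d_avx2 pick_kernel(
      const transform_1d_avx2 tab[][3][4], int size_idx, int type_idx,
      int eob_idx) {
    return tab[size_idx][type_idx][eob_idx];
  }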
@@ -1198,7 +2528,7 @@
   const int buf_size_nonzero_w_div8 = (eobx + 8) >> 3;
   const int buf_size_nonzero_h_div8 = (eoby + 8) >> 3;
   const int input_stride = AOMMIN(32, txfm_size_col);
-
+  const int rect_type = get_rect_tx_log_ratio(txfm_size_col, txfm_size_row);
   const int fun_idx_x = lowbd_txfm_all_1d_zeros_idx[eobx];
   const int fun_idx_y = lowbd_txfm_all_1d_zeros_idx[eoby];
   const transform_1d_avx2 row_txfm =
@@ -1221,12 +2551,22 @@
 
       transpose_8x8_avx2(&buf0_cur[0], &buf0_cur[0]);
     }
-
+    if (rect_type == 1 || rect_type == -1) {
+      av1_round_shift_rect_array_32_avx2(
+          buf0, buf0, buf_size_nonzero_w_div8 << 3, 0, NewInvSqrt2);
+    }
     row_txfm(buf0, buf0, inv_cos_bit_row[txw_idx][txh_idx], 0, bd, -shift[0]);
 
     __m256i *_buf1 = buf1 + i * 8;
-    for (int j = 0; j < buf_size_w_div8; ++j) {
-      transpose_8x8_avx2(&buf0[j * 8], &_buf1[j * txfm_size_row]);
+    if (lr_flip) {
+      for (int j = 0; j < buf_size_w_div8; ++j) {
+        transpose_8x8_flip_avx2(
+            &buf0[j * 8], &_buf1[(buf_size_w_div8 - 1 - j) * txfm_size_row]);
+      }
+    } else {
+      for (int j = 0; j < buf_size_w_div8; ++j) {
+        transpose_8x8_avx2(&buf0[j * 8], &_buf1[j * txfm_size_row]);
+      }
     }
   }
   // 2nd stage: column transform
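The new pre-scaling applies only to the rectangular sizes (rect_type of 1 or -1, i.e. 16x32 and 32x16) and multiplies each row-pass input by 1/sqrt(2) in fixed point. With bit == 0 the call reduces, per lane, to the sketch below; the helper name is illustrative, and the NewInvSqrt2 = 2896, NewSqrt2Bits = 12 values are the usual libaom constants, assumed here:

  // Per-lane effect of
  // av1_round_shift_rect_array_32_avx2(buf, buf, n, 0, NewInvSqrt2):
  static INLINE int32_t inv_rect_scale(int32_t in) {
    const int64_t prod = (int64_t)in * 2896;     // 2896 ~= 2^12 / sqrt(2)
    return (int32_t)((prod + (1 << 11)) >> 12);  // round to nearest, drop Q12
  }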
@@ -1255,6 +2595,14 @@
                                              int eob, const int bd) {
   switch (tx_type) {
     case DCT_DCT:
+    case ADST_DCT:
+    case DCT_ADST:
+    case ADST_ADST:
+    case FLIPADST_DCT:
+    case DCT_FLIPADST:
+    case FLIPADST_FLIPADST:
+    case ADST_FLIPADST:
+    case FLIPADST_ADST:
       highbd_inv_txfm2d_add_no_identity_avx2(input, CONVERT_TO_SHORTPTR(output),
                                              stride, tx_type, tx_size, eob, bd);
       break;
@@ -1262,6 +2610,33 @@
   }
 }
 
+void av1_highbd_inv_txfm_add_16x16_avx2(const tran_low_t *input, uint8_t *dest,
+                                        int stride,
+                                        const TxfmParam *txfm_param) {
+  int bd = txfm_param->bd;
+  const TX_TYPE tx_type = txfm_param->tx_type;
+  const int32_t *src = cast_to_int32(input);
+  switch (tx_type) {
+      // Assembly version doesn't support some transform types, so use C version
+      // for those.
+    case V_DCT:
+    case H_DCT:
+    case V_ADST:
+    case H_ADST:
+    case V_FLIPADST:
+    case H_FLIPADST:
+    case IDTX:
+      av1_inv_txfm2d_add_16x16_c(src, CONVERT_TO_SHORTPTR(dest), stride,
+                                 tx_type, bd);
+      break;
+    default:
+      av1_highbd_inv_txfm2d_add_universe_avx2(input, dest, stride, tx_type,
+                                              txfm_param->tx_size,
+                                              txfm_param->eob, bd);
+      break;
+  }
+}
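The switch above only separates the transform types that have no AVX2 16x16 kernels in the lookup table (the one-dimensional V_*/H_* types and IDTX) from everything routed through the universe function. The same split written as an illustrative predicate, not part of the patch:

  static INLINE int highbd_16x16_needs_c_fallback(TX_TYPE tx_type) {
    return tx_type == V_DCT || tx_type == H_DCT || tx_type == V_ADST ||
           tx_type == H_ADST || tx_type == V_FLIPADST ||
           tx_type == H_FLIPADST || tx_type == IDTX;
  }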
+
 void av1_highbd_inv_txfm_add_32x32_avx2(const tran_low_t *input, uint8_t *dest,
                                         int stride,
                                         const TxfmParam *txfm_param) {
@@ -1284,6 +2659,46 @@
   }
 }
 
+void av1_highbd_inv_txfm_add_16x32_avx2(const tran_low_t *input, uint8_t *dest,
+                                        int stride,
+                                        const TxfmParam *txfm_param) {
+  int bd = txfm_param->bd;
+  const TX_TYPE tx_type = txfm_param->tx_type;
+  const int32_t *src = cast_to_int32(input);
+  switch (tx_type) {
+    case DCT_DCT:
+      av1_highbd_inv_txfm2d_add_universe_avx2(input, dest, stride, tx_type,
+                                              txfm_param->tx_size,
+                                              txfm_param->eob, bd);
+      break;
+    case IDTX:
+      av1_inv_txfm2d_add_16x32_c(src, CONVERT_TO_SHORTPTR(dest), stride,
+                                 tx_type, bd);
+      break;
+    default: assert(0);
+  }
+}
+
+void av1_highbd_inv_txfm_add_32x16_avx2(const tran_low_t *input, uint8_t *dest,
+                                        int stride,
+                                        const TxfmParam *txfm_param) {
+  int bd = txfm_param->bd;
+  const TX_TYPE tx_type = txfm_param->tx_type;
+  const int32_t *src = cast_to_int32(input);
+  switch (tx_type) {
+    case DCT_DCT:
+      av1_highbd_inv_txfm2d_add_universe_avx2(input, dest, stride, tx_type,
+                                              txfm_param->tx_size,
+                                              txfm_param->eob, bd);
+      break;
+    case IDTX:
+      av1_inv_txfm2d_add_32x16_c(src, CONVERT_TO_SHORTPTR(dest), stride,
+                                 tx_type, bd);
+      break;
+    default: assert(0);
+  }
+}
+
 void av1_highbd_inv_txfm_add_avx2(const tran_low_t *input, uint8_t *dest,
                                   int stride, const TxfmParam *txfm_param) {
   assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
@@ -1293,7 +2708,7 @@
       av1_highbd_inv_txfm_add_32x32_avx2(input, dest, stride, txfm_param);
       break;
     case TX_16X16:
-      av1_highbd_inv_txfm_add_16x16_sse4_1(input, dest, stride, txfm_param);
+      av1_highbd_inv_txfm_add_16x16_avx2(input, dest, stride, txfm_param);
       break;
     case TX_8X8:
       av1_highbd_inv_txfm_add_8x8_sse4_1(input, dest, stride, txfm_param);
@@ -1311,10 +2726,10 @@
       av1_highbd_inv_txfm_add_16x8_sse4_1(input, dest, stride, txfm_param);
       break;
     case TX_16X32:
-      av1_highbd_inv_txfm_add_16x32_sse4_1(input, dest, stride, txfm_param);
+      av1_highbd_inv_txfm_add_16x32_avx2(input, dest, stride, txfm_param);
       break;
     case TX_32X16:
-      av1_highbd_inv_txfm_add_32x16_sse4_1(input, dest, stride, txfm_param);
+      av1_highbd_inv_txfm_add_32x16_avx2(input, dest, stride, txfm_param);
       break;
     case TX_4X4:
       av1_highbd_inv_txfm_add_4x4_sse4_1(input, dest, stride, txfm_param);
diff --git a/av1/encoder/x86/av1_fwd_txfm2d_avx2.c b/av1/encoder/x86/av1_fwd_txfm2d_avx2.c
index 592462e..13982cc 100644
--- a/av1/encoder/x86/av1_fwd_txfm2d_avx2.c
+++ b/av1/encoder/x86/av1_fwd_txfm2d_avx2.c
@@ -1436,45 +1436,6 @@
   }
 }
 
-static INLINE void av1_round_shift_array_32_avx2(__m256i *input,
-                                                 __m256i *output,
-                                                 const int size,
-                                                 const int bit) {
-  if (bit > 0) {
-    int i;
-    for (i = 0; i < size; i++) {
-      output[i] = av1_round_shift_32_avx2(input[i], bit);
-    }
-  } else {
-    int i;
-    for (i = 0; i < size; i++) {
-      output[i] = _mm256_slli_epi32(input[i], -bit);
-    }
-  }
-}
-
-static INLINE void av1_round_shift_rect_array_32_avx2(__m256i *input,
-                                                      __m256i *output,
-                                                      const int size,
-                                                      const int bit) {
-  const __m256i sqrt2 = _mm256_set1_epi32(NewSqrt2);
-  if (bit > 0) {
-    int i;
-    for (i = 0; i < size; i++) {
-      const __m256i r0 = av1_round_shift_32_avx2(input[i], bit);
-      const __m256i r1 = _mm256_mullo_epi32(sqrt2, r0);
-      output[i] = av1_round_shift_32_avx2(r1, NewSqrt2Bits);
-    }
-  } else {
-    int i;
-    for (i = 0; i < size; i++) {
-      const __m256i r0 = _mm256_slli_epi32(input[i], -bit);
-      const __m256i r1 = _mm256_mullo_epi32(sqrt2, r0);
-      output[i] = av1_round_shift_32_avx2(r1, NewSqrt2Bits);
-    }
-  }
-}
-
 static INLINE void transpose_32_8x8_avx2(int stride, const __m256i *inputA,
                                          __m256i *output) {
   __m256i temp0 = _mm256_unpacklo_epi32(inputA[0], inputA[2]);
@@ -1540,6 +1501,9 @@
   }
 }
 
+typedef void (*transform_1d_avx2)(const __m256i *input, __m256i *output,
+                                  int8_t cos_bit);
+
 static const transform_1d_avx2 col_txfm16x32_arr[TX_TYPES] = {
   fdct16x32_new_avx2,       // DCT_DCT
   NULL,                     // ADST_DCT
@@ -1885,8 +1849,8 @@
     }
     av1_fdct64_new_avx2(bufA, bufA, cos_bit_row);
     av1_fdct64_new_avx2(bufB, bufB, cos_bit_row);
-    av1_round_shift_rect_array_32_avx2(bufA, bufA, 32, -shift[2]);
-    av1_round_shift_rect_array_32_avx2(bufB, bufB, 32, -shift[2]);
+    av1_round_shift_rect_array_32_avx2(bufA, bufA, 32, -shift[2], NewSqrt2);
+    av1_round_shift_rect_array_32_avx2(bufB, bufB, 32, -shift[2], NewSqrt2);
 
     int32_t *output8 = output + 16 * 32 * i;
     for (int j = 0; j < 4; ++j) {
@@ -1935,8 +1899,8 @@
     }
     av1_fdct32_new_avx2(bufA, bufA, cos_bit_row);
     av1_fdct32_new_avx2(bufB, bufB, cos_bit_row);
-    av1_round_shift_rect_array_32_avx2(bufA, bufA, 32, -shift[2]);
-    av1_round_shift_rect_array_32_avx2(bufB, bufB, 32, -shift[2]);
+    av1_round_shift_rect_array_32_avx2(bufA, bufA, 32, -shift[2], NewSqrt2);
+    av1_round_shift_rect_array_32_avx2(bufB, bufB, 32, -shift[2], NewSqrt2);
 
     int32_t *output8 = output + 16 * 32 * i;
     for (int j = 0; j < 4; ++j) {
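With the extra val argument, the single helper now shared via txfm_common_avx2.h covers both directions: the forward 32/64-point rectangular paths keep scaling by sqrt(2) (NewSqrt2), while the new inverse rectangular paths pass NewInvSqrt2 to scale by 1/sqrt(2). The two call shapes as they appear in this patch:

  // forward rectangular blocks (this file): scale by sqrt(2)
  av1_round_shift_rect_array_32_avx2(bufA, bufA, 32, -shift[2], NewSqrt2);
  // inverse rectangular blocks (highbd inv txfm): scale by 1/sqrt(2)
  av1_round_shift_rect_array_32_avx2(buf0, buf0, buf_size_nonzero_w_div8 << 3,
                                     0, NewInvSqrt2);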
diff --git a/av1/encoder/x86/av1_fwd_txfm_avx2.h b/av1/encoder/x86/av1_fwd_txfm_avx2.h
index 3870713..aaad76e 100644
--- a/av1/encoder/x86/av1_fwd_txfm_avx2.h
+++ b/av1/encoder/x86/av1_fwd_txfm_avx2.h
@@ -13,13 +13,6 @@
 #define AOM_AV1_ENCODER_X86_AV1_FWD_TXFM_AVX2_H_
 #include <immintrin.h>
 
-static INLINE __m256i av1_round_shift_32_avx2(__m256i vec, int bit) {
-  __m256i tmp, round;
-  round = _mm256_set1_epi32(1 << (bit - 1));
-  tmp = _mm256_add_epi32(vec, round);
-  return _mm256_srai_epi32(tmp, bit);
-}
-
 // out0 = in0*w0 + in1*w1
 // out1 = -in1*w0 + in0*w1
 static INLINE void btf_32_avx2_type0(const int32_t w0, const int32_t w1,