Optimize highbd fwd_txfm module

Added an AVX2 variant for the 16x16 txfm block size.

When tested across multiple test cases, an average reduction of
0.95% in encoder time was observed for the speed=1 preset.

Module-level gains improved by a factor of ~1.9 on average
w.r.t. the SSE4_1 module.

Change-Id: Ia6d9467c748c8c9d9f6c0ff3b5786d3b5921b87d
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 6189cc2..4b8ea99 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -233,7 +233,7 @@
   add_proto qw/void av1_fwd_txfm2d_8x8/, "const int16_t *input, int32_t *output, int stride, TX_TYPE tx_type, int bd";
   specialize qw/av1_fwd_txfm2d_8x8 sse4_1/;
   add_proto qw/void av1_fwd_txfm2d_16x16/, "const int16_t *input, int32_t *output, int stride, TX_TYPE tx_type, int bd";
-  specialize qw/av1_fwd_txfm2d_16x16 sse4_1/;
+  specialize qw/av1_fwd_txfm2d_16x16 sse4_1 avx2/;
   add_proto qw/void av1_fwd_txfm2d_32x32/, "const int16_t *input, int32_t *output, int stride, TX_TYPE tx_type, int bd";
   specialize qw/av1_fwd_txfm2d_32x32 sse4_1 avx2/;
 
diff --git a/av1/encoder/x86/highbd_fwd_txfm_avx2.c b/av1/encoder/x86/highbd_fwd_txfm_avx2.c
index 85fb3be..0cdbebd 100644
--- a/av1/encoder/x86/highbd_fwd_txfm_avx2.c
+++ b/av1/encoder/x86/highbd_fwd_txfm_avx2.c
@@ -17,21 +17,40 @@
 #include "av1/encoder/av1_fwd_txfm1d_cfg.h"
 #include "aom_dsp/txfm_common.h"
 #include "aom_ports/mem.h"
+#include "aom_dsp/x86/txfm_common_sse2.h"
 
 static INLINE void av1_load_buffer_16xn_avx2(const int16_t *input, __m256i *out,
                                              int stride, int height,
-                                             int outstride) {
+                                             int outstride, int flipud,
+                                             int fliplr) {
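+  // flipud reverses the row order; fliplr reverses the samples within a row.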
   __m256i out1[64];
-  for (int i = 0; i < height; i++) {
-    out1[i] = _mm256_loadu_si256((const __m256i *)(input + i * stride));
+  if (!flipud) {
+    for (int i = 0; i < height; i++) {
+      out1[i] = _mm256_loadu_si256((const __m256i *)(input + i * stride));
+    }
+  } else {
+    for (int i = 0; i < height; i++) {
+      out1[(height - 1) - i] =
+          _mm256_loadu_si256((const __m256i *)(input + i * stride));
+    }
   }
-
-  for (int i = 0; i < height; i++) {
-    out[i * outstride] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(out1[i]));
-    out[i * outstride + 1] =
-        _mm256_cvtepi16_epi32(_mm256_extractf128_si256(out1[i], 1));
+  if (!fliplr) {
+    for (int i = 0; i < height; i++) {
+      out[i * outstride] =
+          _mm256_cvtepi16_epi32(_mm256_castsi256_si128(out1[i]));
+      out[i * outstride + 1] =
+          _mm256_cvtepi16_epi32(_mm256_extractf128_si256(out1[i], 1));
+    }
+  } else {
+    for (int i = 0; i < height; i++) {
+      out[i * outstride + 1] = _mm256_cvtepi16_epi32(
+          mm_reverse_epi16(_mm256_castsi256_si128(out1[i])));
+      out[i * outstride + 0] = _mm256_cvtepi16_epi32(
+          mm_reverse_epi16(_mm256_extractf128_si256(out1[i], 1)));
+    }
   }
 }
+
 static void av1_fwd_txfm_transpose_8x8_avx2(const __m256i *in, __m256i *out,
                                             const int instride,
                                             const int outstride) {
@@ -92,6 +111,26 @@
     out += stride;
   }
 }
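+// Transpose a 16x16 block (two __m256i per row) via four 8x8 transposes.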
+static INLINE void av1_fwd_txfm_transpose_16x16_avx2(const __m256i *in,
+                                                     __m256i *out) {
+  av1_fwd_txfm_transpose_8x8_avx2(&in[0], &out[0], 2, 2);
+  av1_fwd_txfm_transpose_8x8_avx2(&in[1], &out[16], 2, 2);
+  av1_fwd_txfm_transpose_8x8_avx2(&in[16], &out[1], 2, 2);
+  av1_fwd_txfm_transpose_8x8_avx2(&in[17], &out[17], 2, 2);
+}
+
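+// Computes round_shift(w0 * n0 + w1 * n1, bit), i.e. one half-butterfly.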
+static INLINE __m256i av1_half_btf_avx2(const __m256i *w0, const __m256i *n0,
+                                        const __m256i *w1, const __m256i *n1,
+                                        const __m256i *rounding, int bit) {
+  __m256i x, y;
+
+  x = _mm256_mullo_epi32(*w0, *n0);
+  y = _mm256_mullo_epi32(*w1, *n1);
+  x = _mm256_add_epi32(x, y);
+  x = _mm256_add_epi32(x, *rounding);
+  x = _mm256_srai_epi32(x, bit);
+  return x;
+}
 #define btf_32_avx2_type0(w0, w1, in0, in1, out0, out1, bit) \
   do {                                                       \
     const __m256i ww0 = _mm256_set1_epi32(w0);               \
@@ -124,7 +163,725 @@
                                   const int8_t cos_bit,
                                   const int8_t *stage_range, int instride,
                                   int outstride);
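+// 1D 16-point forward DCT applied to col_num groups of eight 32-bit lanes.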
+static void av1_fdct16_avx2(__m256i *in, __m256i *out, int bit,
+                            const int col_num) {
+  const int32_t *cospi = cospi_arr(bit);
+  const __m256i cospi32 = _mm256_set1_epi32(cospi[32]);
+  const __m256i cospim32 = _mm256_set1_epi32(-cospi[32]);
+  const __m256i cospi48 = _mm256_set1_epi32(cospi[48]);
+  const __m256i cospi16 = _mm256_set1_epi32(cospi[16]);
+  const __m256i cospim48 = _mm256_set1_epi32(-cospi[48]);
+  const __m256i cospim16 = _mm256_set1_epi32(-cospi[16]);
+  const __m256i cospi56 = _mm256_set1_epi32(cospi[56]);
+  const __m256i cospi8 = _mm256_set1_epi32(cospi[8]);
+  const __m256i cospi24 = _mm256_set1_epi32(cospi[24]);
+  const __m256i cospi40 = _mm256_set1_epi32(cospi[40]);
+  const __m256i cospi60 = _mm256_set1_epi32(cospi[60]);
+  const __m256i cospi4 = _mm256_set1_epi32(cospi[4]);
+  const __m256i cospi28 = _mm256_set1_epi32(cospi[28]);
+  const __m256i cospi36 = _mm256_set1_epi32(cospi[36]);
+  const __m256i cospi44 = _mm256_set1_epi32(cospi[44]);
+  const __m256i cospi20 = _mm256_set1_epi32(cospi[20]);
+  const __m256i cospi12 = _mm256_set1_epi32(cospi[12]);
+  const __m256i cospi52 = _mm256_set1_epi32(cospi[52]);
+  const __m256i rnding = _mm256_set1_epi32(1 << (bit - 1));
+  __m256i u[16], v[16], x;
+  int col;
 
+  // Process each group of eight 32-bit columns.
+  for (col = 0; col < col_num; ++col) {
+    // stage 0
+    // stage 1
+    u[0] = _mm256_add_epi32(in[0 * col_num + col], in[15 * col_num + col]);
+    u[15] = _mm256_sub_epi32(in[0 * col_num + col], in[15 * col_num + col]);
+    u[1] = _mm256_add_epi32(in[1 * col_num + col], in[14 * col_num + col]);
+    u[14] = _mm256_sub_epi32(in[1 * col_num + col], in[14 * col_num + col]);
+    u[2] = _mm256_add_epi32(in[2 * col_num + col], in[13 * col_num + col]);
+    u[13] = _mm256_sub_epi32(in[2 * col_num + col], in[13 * col_num + col]);
+    u[3] = _mm256_add_epi32(in[3 * col_num + col], in[12 * col_num + col]);
+    u[12] = _mm256_sub_epi32(in[3 * col_num + col], in[12 * col_num + col]);
+    u[4] = _mm256_add_epi32(in[4 * col_num + col], in[11 * col_num + col]);
+    u[11] = _mm256_sub_epi32(in[4 * col_num + col], in[11 * col_num + col]);
+    u[5] = _mm256_add_epi32(in[5 * col_num + col], in[10 * col_num + col]);
+    u[10] = _mm256_sub_epi32(in[5 * col_num + col], in[10 * col_num + col]);
+    u[6] = _mm256_add_epi32(in[6 * col_num + col], in[9 * col_num + col]);
+    u[9] = _mm256_sub_epi32(in[6 * col_num + col], in[9 * col_num + col]);
+    u[7] = _mm256_add_epi32(in[7 * col_num + col], in[8 * col_num + col]);
+    u[8] = _mm256_sub_epi32(in[7 * col_num + col], in[8 * col_num + col]);
+
+    // stage 2
+    v[0] = _mm256_add_epi32(u[0], u[7]);
+    v[7] = _mm256_sub_epi32(u[0], u[7]);
+    v[1] = _mm256_add_epi32(u[1], u[6]);
+    v[6] = _mm256_sub_epi32(u[1], u[6]);
+    v[2] = _mm256_add_epi32(u[2], u[5]);
+    v[5] = _mm256_sub_epi32(u[2], u[5]);
+    v[3] = _mm256_add_epi32(u[3], u[4]);
+    v[4] = _mm256_sub_epi32(u[3], u[4]);
+    v[8] = u[8];
+    v[9] = u[9];
+
+    v[10] = _mm256_mullo_epi32(u[10], cospim32);
+    x = _mm256_mullo_epi32(u[13], cospi32);
+    v[10] = _mm256_add_epi32(v[10], x);
+    v[10] = _mm256_add_epi32(v[10], rnding);
+    v[10] = _mm256_srai_epi32(v[10], bit);
+
+    v[13] = _mm256_mullo_epi32(u[10], cospi32);
+    x = _mm256_mullo_epi32(u[13], cospim32);
+    v[13] = _mm256_sub_epi32(v[13], x);
+    v[13] = _mm256_add_epi32(v[13], rnding);
+    v[13] = _mm256_srai_epi32(v[13], bit);
+
+    v[11] = _mm256_mullo_epi32(u[11], cospim32);
+    x = _mm256_mullo_epi32(u[12], cospi32);
+    v[11] = _mm256_add_epi32(v[11], x);
+    v[11] = _mm256_add_epi32(v[11], rnding);
+    v[11] = _mm256_srai_epi32(v[11], bit);
+
+    v[12] = _mm256_mullo_epi32(u[11], cospi32);
+    x = _mm256_mullo_epi32(u[12], cospim32);
+    v[12] = _mm256_sub_epi32(v[12], x);
+    v[12] = _mm256_add_epi32(v[12], rnding);
+    v[12] = _mm256_srai_epi32(v[12], bit);
+    v[14] = u[14];
+    v[15] = u[15];
+
+    // stage 3
+    u[0] = _mm256_add_epi32(v[0], v[3]);
+    u[3] = _mm256_sub_epi32(v[0], v[3]);
+    u[1] = _mm256_add_epi32(v[1], v[2]);
+    u[2] = _mm256_sub_epi32(v[1], v[2]);
+    u[4] = v[4];
+
+    u[5] = _mm256_mullo_epi32(v[5], cospim32);
+    x = _mm256_mullo_epi32(v[6], cospi32);
+    u[5] = _mm256_add_epi32(u[5], x);
+    u[5] = _mm256_add_epi32(u[5], rnding);
+    u[5] = _mm256_srai_epi32(u[5], bit);
+
+    u[6] = _mm256_mullo_epi32(v[5], cospi32);
+    x = _mm256_mullo_epi32(v[6], cospim32);
+    u[6] = _mm256_sub_epi32(u[6], x);
+    u[6] = _mm256_add_epi32(u[6], rnding);
+    u[6] = _mm256_srai_epi32(u[6], bit);
+
+    u[7] = v[7];
+    u[8] = _mm256_add_epi32(v[8], v[11]);
+    u[11] = _mm256_sub_epi32(v[8], v[11]);
+    u[9] = _mm256_add_epi32(v[9], v[10]);
+    u[10] = _mm256_sub_epi32(v[9], v[10]);
+    u[12] = _mm256_sub_epi32(v[15], v[12]);
+    u[15] = _mm256_add_epi32(v[15], v[12]);
+    u[13] = _mm256_sub_epi32(v[14], v[13]);
+    u[14] = _mm256_add_epi32(v[14], v[13]);
+
+    // stage 4
+    u[0] = _mm256_mullo_epi32(u[0], cospi32);
+    u[1] = _mm256_mullo_epi32(u[1], cospi32);
+    v[0] = _mm256_add_epi32(u[0], u[1]);
+    v[0] = _mm256_add_epi32(v[0], rnding);
+    v[0] = _mm256_srai_epi32(v[0], bit);
+
+    v[1] = _mm256_sub_epi32(u[0], u[1]);
+    v[1] = _mm256_add_epi32(v[1], rnding);
+    v[1] = _mm256_srai_epi32(v[1], bit);
+
+    v[2] = _mm256_mullo_epi32(u[2], cospi48);
+    x = _mm256_mullo_epi32(u[3], cospi16);
+    v[2] = _mm256_add_epi32(v[2], x);
+    v[2] = _mm256_add_epi32(v[2], rnding);
+    v[2] = _mm256_srai_epi32(v[2], bit);
+
+    v[3] = _mm256_mullo_epi32(u[2], cospi16);
+    x = _mm256_mullo_epi32(u[3], cospi48);
+    v[3] = _mm256_sub_epi32(x, v[3]);
+    v[3] = _mm256_add_epi32(v[3], rnding);
+    v[3] = _mm256_srai_epi32(v[3], bit);
+
+    v[4] = _mm256_add_epi32(u[4], u[5]);
+    v[5] = _mm256_sub_epi32(u[4], u[5]);
+    v[6] = _mm256_sub_epi32(u[7], u[6]);
+    v[7] = _mm256_add_epi32(u[7], u[6]);
+    v[8] = u[8];
+
+    v[9] = _mm256_mullo_epi32(u[9], cospim16);
+    x = _mm256_mullo_epi32(u[14], cospi48);
+    v[9] = _mm256_add_epi32(v[9], x);
+    v[9] = _mm256_add_epi32(v[9], rnding);
+    v[9] = _mm256_srai_epi32(v[9], bit);
+
+    v[14] = _mm256_mullo_epi32(u[9], cospi48);
+    x = _mm256_mullo_epi32(u[14], cospim16);
+    v[14] = _mm256_sub_epi32(v[14], x);
+    v[14] = _mm256_add_epi32(v[14], rnding);
+    v[14] = _mm256_srai_epi32(v[14], bit);
+
+    v[10] = _mm256_mullo_epi32(u[10], cospim48);
+    x = _mm256_mullo_epi32(u[13], cospim16);
+    v[10] = _mm256_add_epi32(v[10], x);
+    v[10] = _mm256_add_epi32(v[10], rnding);
+    v[10] = _mm256_srai_epi32(v[10], bit);
+
+    v[13] = _mm256_mullo_epi32(u[10], cospim16);
+    x = _mm256_mullo_epi32(u[13], cospim48);
+    v[13] = _mm256_sub_epi32(v[13], x);
+    v[13] = _mm256_add_epi32(v[13], rnding);
+    v[13] = _mm256_srai_epi32(v[13], bit);
+
+    v[11] = u[11];
+    v[12] = u[12];
+    v[15] = u[15];
+
+    // stage 5
+    u[0] = v[0];
+    u[1] = v[1];
+    u[2] = v[2];
+    u[3] = v[3];
+
+    u[4] = _mm256_mullo_epi32(v[4], cospi56);
+    x = _mm256_mullo_epi32(v[7], cospi8);
+    u[4] = _mm256_add_epi32(u[4], x);
+    u[4] = _mm256_add_epi32(u[4], rnding);
+    u[4] = _mm256_srai_epi32(u[4], bit);
+
+    u[7] = _mm256_mullo_epi32(v[4], cospi8);
+    x = _mm256_mullo_epi32(v[7], cospi56);
+    u[7] = _mm256_sub_epi32(x, u[7]);
+    u[7] = _mm256_add_epi32(u[7], rnding);
+    u[7] = _mm256_srai_epi32(u[7], bit);
+
+    u[5] = _mm256_mullo_epi32(v[5], cospi24);
+    x = _mm256_mullo_epi32(v[6], cospi40);
+    u[5] = _mm256_add_epi32(u[5], x);
+    u[5] = _mm256_add_epi32(u[5], rnding);
+    u[5] = _mm256_srai_epi32(u[5], bit);
+
+    u[6] = _mm256_mullo_epi32(v[5], cospi40);
+    x = _mm256_mullo_epi32(v[6], cospi24);
+    u[6] = _mm256_sub_epi32(x, u[6]);
+    u[6] = _mm256_add_epi32(u[6], rnding);
+    u[6] = _mm256_srai_epi32(u[6], bit);
+
+    u[8] = _mm256_add_epi32(v[8], v[9]);
+    u[9] = _mm256_sub_epi32(v[8], v[9]);
+    u[10] = _mm256_sub_epi32(v[11], v[10]);
+    u[11] = _mm256_add_epi32(v[11], v[10]);
+    u[12] = _mm256_add_epi32(v[12], v[13]);
+    u[13] = _mm256_sub_epi32(v[12], v[13]);
+    u[14] = _mm256_sub_epi32(v[15], v[14]);
+    u[15] = _mm256_add_epi32(v[15], v[14]);
+
+    // stage 6
+    v[0] = u[0];
+    v[1] = u[1];
+    v[2] = u[2];
+    v[3] = u[3];
+    v[4] = u[4];
+    v[5] = u[5];
+    v[6] = u[6];
+    v[7] = u[7];
+
+    v[8] = _mm256_mullo_epi32(u[8], cospi60);
+    x = _mm256_mullo_epi32(u[15], cospi4);
+    v[8] = _mm256_add_epi32(v[8], x);
+    v[8] = _mm256_add_epi32(v[8], rnding);
+    v[8] = _mm256_srai_epi32(v[8], bit);
+
+    v[15] = _mm256_mullo_epi32(u[8], cospi4);
+    x = _mm256_mullo_epi32(u[15], cospi60);
+    v[15] = _mm256_sub_epi32(x, v[15]);
+    v[15] = _mm256_add_epi32(v[15], rnding);
+    v[15] = _mm256_srai_epi32(v[15], bit);
+
+    v[9] = _mm256_mullo_epi32(u[9], cospi28);
+    x = _mm256_mullo_epi32(u[14], cospi36);
+    v[9] = _mm256_add_epi32(v[9], x);
+    v[9] = _mm256_add_epi32(v[9], rnding);
+    v[9] = _mm256_srai_epi32(v[9], bit);
+
+    v[14] = _mm256_mullo_epi32(u[9], cospi36);
+    x = _mm256_mullo_epi32(u[14], cospi28);
+    v[14] = _mm256_sub_epi32(x, v[14]);
+    v[14] = _mm256_add_epi32(v[14], rnding);
+    v[14] = _mm256_srai_epi32(v[14], bit);
+
+    v[10] = _mm256_mullo_epi32(u[10], cospi44);
+    x = _mm256_mullo_epi32(u[13], cospi20);
+    v[10] = _mm256_add_epi32(v[10], x);
+    v[10] = _mm256_add_epi32(v[10], rnding);
+    v[10] = _mm256_srai_epi32(v[10], bit);
+
+    v[13] = _mm256_mullo_epi32(u[10], cospi20);
+    x = _mm256_mullo_epi32(u[13], cospi44);
+    v[13] = _mm256_sub_epi32(x, v[13]);
+    v[13] = _mm256_add_epi32(v[13], rnding);
+    v[13] = _mm256_srai_epi32(v[13], bit);
+
+    v[11] = _mm256_mullo_epi32(u[11], cospi12);
+    x = _mm256_mullo_epi32(u[12], cospi52);
+    v[11] = _mm256_add_epi32(v[11], x);
+    v[11] = _mm256_add_epi32(v[11], rnding);
+    v[11] = _mm256_srai_epi32(v[11], bit);
+
+    v[12] = _mm256_mullo_epi32(u[11], cospi52);
+    x = _mm256_mullo_epi32(u[12], cospi12);
+    v[12] = _mm256_sub_epi32(x, v[12]);
+    v[12] = _mm256_add_epi32(v[12], rnding);
+    v[12] = _mm256_srai_epi32(v[12], bit);
+
+    out[0 * col_num + col] = v[0];
+    out[1 * col_num + col] = v[8];
+    out[2 * col_num + col] = v[4];
+    out[3 * col_num + col] = v[12];
+    out[4 * col_num + col] = v[2];
+    out[5 * col_num + col] = v[10];
+    out[6 * col_num + col] = v[6];
+    out[7 * col_num + col] = v[14];
+    out[8 * col_num + col] = v[1];
+    out[9 * col_num + col] = v[9];
+    out[10 * col_num + col] = v[5];
+    out[11 * col_num + col] = v[13];
+    out[12 * col_num + col] = v[3];
+    out[13 * col_num + col] = v[11];
+    out[14 * col_num + col] = v[7];
+    out[15 * col_num + col] = v[15];
+  }
+}
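+// 1D 16-point forward ADST applied to num_cols groups of eight 32-bit lanes.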
+static void av1_fadst16_avx2(__m256i *in, __m256i *out, int bit,
+                             const int num_cols) {
+  const int32_t *cospi = cospi_arr(bit);
+  const __m256i cospi32 = _mm256_set1_epi32(cospi[32]);
+  const __m256i cospi48 = _mm256_set1_epi32(cospi[48]);
+  const __m256i cospi16 = _mm256_set1_epi32(cospi[16]);
+  const __m256i cospim16 = _mm256_set1_epi32(-cospi[16]);
+  const __m256i cospim48 = _mm256_set1_epi32(-cospi[48]);
+  const __m256i cospi8 = _mm256_set1_epi32(cospi[8]);
+  const __m256i cospi56 = _mm256_set1_epi32(cospi[56]);
+  const __m256i cospim56 = _mm256_set1_epi32(-cospi[56]);
+  const __m256i cospim8 = _mm256_set1_epi32(-cospi[8]);
+  const __m256i cospi24 = _mm256_set1_epi32(cospi[24]);
+  const __m256i cospim24 = _mm256_set1_epi32(-cospi[24]);
+  const __m256i cospim40 = _mm256_set1_epi32(-cospi[40]);
+  const __m256i cospi40 = _mm256_set1_epi32(cospi[40]);
+  const __m256i cospi2 = _mm256_set1_epi32(cospi[2]);
+  const __m256i cospi62 = _mm256_set1_epi32(cospi[62]);
+  const __m256i cospim2 = _mm256_set1_epi32(-cospi[2]);
+  const __m256i cospi10 = _mm256_set1_epi32(cospi[10]);
+  const __m256i cospi54 = _mm256_set1_epi32(cospi[54]);
+  const __m256i cospim10 = _mm256_set1_epi32(-cospi[10]);
+  const __m256i cospi18 = _mm256_set1_epi32(cospi[18]);
+  const __m256i cospi46 = _mm256_set1_epi32(cospi[46]);
+  const __m256i cospim18 = _mm256_set1_epi32(-cospi[18]);
+  const __m256i cospi26 = _mm256_set1_epi32(cospi[26]);
+  const __m256i cospi38 = _mm256_set1_epi32(cospi[38]);
+  const __m256i cospim26 = _mm256_set1_epi32(-cospi[26]);
+  const __m256i cospi34 = _mm256_set1_epi32(cospi[34]);
+  const __m256i cospi30 = _mm256_set1_epi32(cospi[30]);
+  const __m256i cospim34 = _mm256_set1_epi32(-cospi[34]);
+  const __m256i cospi42 = _mm256_set1_epi32(cospi[42]);
+  const __m256i cospi22 = _mm256_set1_epi32(cospi[22]);
+  const __m256i cospim42 = _mm256_set1_epi32(-cospi[42]);
+  const __m256i cospi50 = _mm256_set1_epi32(cospi[50]);
+  const __m256i cospi14 = _mm256_set1_epi32(cospi[14]);
+  const __m256i cospim50 = _mm256_set1_epi32(-cospi[50]);
+  const __m256i cospi58 = _mm256_set1_epi32(cospi[58]);
+  const __m256i cospi6 = _mm256_set1_epi32(cospi[6]);
+  const __m256i cospim58 = _mm256_set1_epi32(-cospi[58]);
+  const __m256i rnding = _mm256_set1_epi32(1 << (bit - 1));
+  const __m256i zero = _mm256_setzero_si256();
+
+  __m256i u[16], v[16], x, y;
+  int col;
+
+  for (col = 0; col < num_cols; ++col) {
+    // stage 0
+    // stage 1
+    u[0] = in[0 * num_cols + col];
+    u[1] = _mm256_sub_epi32(zero, in[15 * num_cols + col]);
+    u[2] = _mm256_sub_epi32(zero, in[7 * num_cols + col]);
+    u[3] = in[8 * num_cols + col];
+    u[4] = _mm256_sub_epi32(zero, in[3 * num_cols + col]);
+    u[5] = in[12 * num_cols + col];
+    u[6] = in[4 * num_cols + col];
+    u[7] = _mm256_sub_epi32(zero, in[11 * num_cols + col]);
+    u[8] = _mm256_sub_epi32(zero, in[1 * num_cols + col]);
+    u[9] = in[14 * num_cols + col];
+    u[10] = in[6 * num_cols + col];
+    u[11] = _mm256_sub_epi32(zero, in[9 * num_cols + col]);
+    u[12] = in[2 * num_cols + col];
+    u[13] = _mm256_sub_epi32(zero, in[13 * num_cols + col]);
+    u[14] = _mm256_sub_epi32(zero, in[5 * num_cols + col]);
+    u[15] = in[10 * num_cols + col];
+
+    // stage 2
+    v[0] = u[0];
+    v[1] = u[1];
+
+    x = _mm256_mullo_epi32(u[2], cospi32);
+    y = _mm256_mullo_epi32(u[3], cospi32);
+    v[2] = _mm256_add_epi32(x, y);
+    v[2] = _mm256_add_epi32(v[2], rnding);
+    v[2] = _mm256_srai_epi32(v[2], bit);
+
+    v[3] = _mm256_sub_epi32(x, y);
+    v[3] = _mm256_add_epi32(v[3], rnding);
+    v[3] = _mm256_srai_epi32(v[3], bit);
+
+    v[4] = u[4];
+    v[5] = u[5];
+
+    x = _mm256_mullo_epi32(u[6], cospi32);
+    y = _mm256_mullo_epi32(u[7], cospi32);
+    v[6] = _mm256_add_epi32(x, y);
+    v[6] = _mm256_add_epi32(v[6], rnding);
+    v[6] = _mm256_srai_epi32(v[6], bit);
+
+    v[7] = _mm256_sub_epi32(x, y);
+    v[7] = _mm256_add_epi32(v[7], rnding);
+    v[7] = _mm256_srai_epi32(v[7], bit);
+
+    v[8] = u[8];
+    v[9] = u[9];
+
+    x = _mm256_mullo_epi32(u[10], cospi32);
+    y = _mm256_mullo_epi32(u[11], cospi32);
+    v[10] = _mm256_add_epi32(x, y);
+    v[10] = _mm256_add_epi32(v[10], rnding);
+    v[10] = _mm256_srai_epi32(v[10], bit);
+
+    v[11] = _mm256_sub_epi32(x, y);
+    v[11] = _mm256_add_epi32(v[11], rnding);
+    v[11] = _mm256_srai_epi32(v[11], bit);
+
+    v[12] = u[12];
+    v[13] = u[13];
+
+    x = _mm256_mullo_epi32(u[14], cospi32);
+    y = _mm256_mullo_epi32(u[15], cospi32);
+    v[14] = _mm256_add_epi32(x, y);
+    v[14] = _mm256_add_epi32(v[14], rnding);
+    v[14] = _mm256_srai_epi32(v[14], bit);
+
+    v[15] = _mm256_sub_epi32(x, y);
+    v[15] = _mm256_add_epi32(v[15], rnding);
+    v[15] = _mm256_srai_epi32(v[15], bit);
+
+    // stage 3
+    u[0] = _mm256_add_epi32(v[0], v[2]);
+    u[1] = _mm256_add_epi32(v[1], v[3]);
+    u[2] = _mm256_sub_epi32(v[0], v[2]);
+    u[3] = _mm256_sub_epi32(v[1], v[3]);
+    u[4] = _mm256_add_epi32(v[4], v[6]);
+    u[5] = _mm256_add_epi32(v[5], v[7]);
+    u[6] = _mm256_sub_epi32(v[4], v[6]);
+    u[7] = _mm256_sub_epi32(v[5], v[7]);
+    u[8] = _mm256_add_epi32(v[8], v[10]);
+    u[9] = _mm256_add_epi32(v[9], v[11]);
+    u[10] = _mm256_sub_epi32(v[8], v[10]);
+    u[11] = _mm256_sub_epi32(v[9], v[11]);
+    u[12] = _mm256_add_epi32(v[12], v[14]);
+    u[13] = _mm256_add_epi32(v[13], v[15]);
+    u[14] = _mm256_sub_epi32(v[12], v[14]);
+    u[15] = _mm256_sub_epi32(v[13], v[15]);
+
+    // stage 4
+    v[0] = u[0];
+    v[1] = u[1];
+    v[2] = u[2];
+    v[3] = u[3];
+    v[4] = av1_half_btf_avx2(&cospi16, &u[4], &cospi48, &u[5], &rnding, bit);
+    v[5] = av1_half_btf_avx2(&cospi48, &u[4], &cospim16, &u[5], &rnding, bit);
+    v[6] = av1_half_btf_avx2(&cospim48, &u[6], &cospi16, &u[7], &rnding, bit);
+    v[7] = av1_half_btf_avx2(&cospi16, &u[6], &cospi48, &u[7], &rnding, bit);
+    v[8] = u[8];
+    v[9] = u[9];
+    v[10] = u[10];
+    v[11] = u[11];
+    v[12] = av1_half_btf_avx2(&cospi16, &u[12], &cospi48, &u[13], &rnding, bit);
+    v[13] =
+        av1_half_btf_avx2(&cospi48, &u[12], &cospim16, &u[13], &rnding, bit);
+    v[14] =
+        av1_half_btf_avx2(&cospim48, &u[14], &cospi16, &u[15], &rnding, bit);
+    v[15] = av1_half_btf_avx2(&cospi16, &u[14], &cospi48, &u[15], &rnding, bit);
+
+    // stage 5
+    u[0] = _mm256_add_epi32(v[0], v[4]);
+    u[1] = _mm256_add_epi32(v[1], v[5]);
+    u[2] = _mm256_add_epi32(v[2], v[6]);
+    u[3] = _mm256_add_epi32(v[3], v[7]);
+    u[4] = _mm256_sub_epi32(v[0], v[4]);
+    u[5] = _mm256_sub_epi32(v[1], v[5]);
+    u[6] = _mm256_sub_epi32(v[2], v[6]);
+    u[7] = _mm256_sub_epi32(v[3], v[7]);
+    u[8] = _mm256_add_epi32(v[8], v[12]);
+    u[9] = _mm256_add_epi32(v[9], v[13]);
+    u[10] = _mm256_add_epi32(v[10], v[14]);
+    u[11] = _mm256_add_epi32(v[11], v[15]);
+    u[12] = _mm256_sub_epi32(v[8], v[12]);
+    u[13] = _mm256_sub_epi32(v[9], v[13]);
+    u[14] = _mm256_sub_epi32(v[10], v[14]);
+    u[15] = _mm256_sub_epi32(v[11], v[15]);
+
+    // stage 6
+    v[0] = u[0];
+    v[1] = u[1];
+    v[2] = u[2];
+    v[3] = u[3];
+    v[4] = u[4];
+    v[5] = u[5];
+    v[6] = u[6];
+    v[7] = u[7];
+    v[8] = av1_half_btf_avx2(&cospi8, &u[8], &cospi56, &u[9], &rnding, bit);
+    v[9] = av1_half_btf_avx2(&cospi56, &u[8], &cospim8, &u[9], &rnding, bit);
+    v[10] = av1_half_btf_avx2(&cospi40, &u[10], &cospi24, &u[11], &rnding, bit);
+    v[11] =
+        av1_half_btf_avx2(&cospi24, &u[10], &cospim40, &u[11], &rnding, bit);
+    v[12] = av1_half_btf_avx2(&cospim56, &u[12], &cospi8, &u[13], &rnding, bit);
+    v[13] = av1_half_btf_avx2(&cospi8, &u[12], &cospi56, &u[13], &rnding, bit);
+    v[14] =
+        av1_half_btf_avx2(&cospim24, &u[14], &cospi40, &u[15], &rnding, bit);
+    v[15] = av1_half_btf_avx2(&cospi40, &u[14], &cospi24, &u[15], &rnding, bit);
+
+    // stage 7
+    u[0] = _mm256_add_epi32(v[0], v[8]);
+    u[1] = _mm256_add_epi32(v[1], v[9]);
+    u[2] = _mm256_add_epi32(v[2], v[10]);
+    u[3] = _mm256_add_epi32(v[3], v[11]);
+    u[4] = _mm256_add_epi32(v[4], v[12]);
+    u[5] = _mm256_add_epi32(v[5], v[13]);
+    u[6] = _mm256_add_epi32(v[6], v[14]);
+    u[7] = _mm256_add_epi32(v[7], v[15]);
+    u[8] = _mm256_sub_epi32(v[0], v[8]);
+    u[9] = _mm256_sub_epi32(v[1], v[9]);
+    u[10] = _mm256_sub_epi32(v[2], v[10]);
+    u[11] = _mm256_sub_epi32(v[3], v[11]);
+    u[12] = _mm256_sub_epi32(v[4], v[12]);
+    u[13] = _mm256_sub_epi32(v[5], v[13]);
+    u[14] = _mm256_sub_epi32(v[6], v[14]);
+    u[15] = _mm256_sub_epi32(v[7], v[15]);
+
+    // stage 8
+    v[0] = av1_half_btf_avx2(&cospi2, &u[0], &cospi62, &u[1], &rnding, bit);
+    v[1] = av1_half_btf_avx2(&cospi62, &u[0], &cospim2, &u[1], &rnding, bit);
+    v[2] = av1_half_btf_avx2(&cospi10, &u[2], &cospi54, &u[3], &rnding, bit);
+    v[3] = av1_half_btf_avx2(&cospi54, &u[2], &cospim10, &u[3], &rnding, bit);
+    v[4] = av1_half_btf_avx2(&cospi18, &u[4], &cospi46, &u[5], &rnding, bit);
+    v[5] = av1_half_btf_avx2(&cospi46, &u[4], &cospim18, &u[5], &rnding, bit);
+    v[6] = av1_half_btf_avx2(&cospi26, &u[6], &cospi38, &u[7], &rnding, bit);
+    v[7] = av1_half_btf_avx2(&cospi38, &u[6], &cospim26, &u[7], &rnding, bit);
+    v[8] = av1_half_btf_avx2(&cospi34, &u[8], &cospi30, &u[9], &rnding, bit);
+    v[9] = av1_half_btf_avx2(&cospi30, &u[8], &cospim34, &u[9], &rnding, bit);
+    v[10] = av1_half_btf_avx2(&cospi42, &u[10], &cospi22, &u[11], &rnding, bit);
+    v[11] =
+        av1_half_btf_avx2(&cospi22, &u[10], &cospim42, &u[11], &rnding, bit);
+    v[12] = av1_half_btf_avx2(&cospi50, &u[12], &cospi14, &u[13], &rnding, bit);
+    v[13] =
+        av1_half_btf_avx2(&cospi14, &u[12], &cospim50, &u[13], &rnding, bit);
+    v[14] = av1_half_btf_avx2(&cospi58, &u[14], &cospi6, &u[15], &rnding, bit);
+    v[15] = av1_half_btf_avx2(&cospi6, &u[14], &cospim58, &u[15], &rnding, bit);
+
+    // stage 9
+    out[0 * num_cols + col] = v[1];
+    out[1 * num_cols + col] = v[14];
+    out[2 * num_cols + col] = v[3];
+    out[3 * num_cols + col] = v[12];
+    out[4 * num_cols + col] = v[5];
+    out[5 * num_cols + col] = v[10];
+    out[6 * num_cols + col] = v[7];
+    out[7 * num_cols + col] = v[8];
+    out[8 * num_cols + col] = v[9];
+    out[9 * num_cols + col] = v[6];
+    out[10 * num_cols + col] = v[11];
+    out[11 * num_cols + col] = v[4];
+    out[12 * num_cols + col] = v[13];
+    out[13 * num_cols + col] = v[2];
+    out[14 * num_cols + col] = v[15];
+    out[15 * num_cols + col] = v[0];
+  }
+}
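+// 16-point identity transform: scales each lane by 2 * sqrt(2) with rounding.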
+static void av1_idtx16_avx2(__m256i *in, __m256i *out, int bit, int col_num) {
+  (void)bit;
+  __m256i fact = _mm256_set1_epi32(2 * NewSqrt2);
+  __m256i offset = _mm256_set1_epi32(1 << (NewSqrt2Bits - 1));
+  __m256i a_low;
+
+  int num_iters = 16 * col_num;
+  for (int i = 0; i < num_iters; i++) {
+    a_low = _mm256_mullo_epi32(in[i], fact);
+    a_low = _mm256_add_epi32(a_low, offset);
+    out[i] = _mm256_srai_epi32(a_low, NewSqrt2Bits);
+  }
+}
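+// 2D 16x16 highbd forward transform: column/row 1D transforms per tx_type.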
+void av1_fwd_txfm2d_16x16_avx2(const int16_t *input, int32_t *coeff, int stride,
+                               TX_TYPE tx_type, int bd) {
+  __m256i in[32], out[32];
+  const TX_SIZE tx_size = TX_16X16;
+  const int8_t *shift = fwd_txfm_shift_ls[tx_size];
+  const int txw_idx = get_txw_idx(tx_size);
+  const int txh_idx = get_txh_idx(tx_size);
+  const int width = tx_size_wide[tx_size];
+  const int height = tx_size_high[tx_size];
+  const int width_div8 = (width >> 3);
+  const int width_div16 = (width >> 4);
+  const int size = (height << 1);
+  switch (tx_type) {
+    case DCT_DCT:
+      av1_load_buffer_16xn_avx2(input, in, stride, height, width_div8, 0, 0);
+      av1_round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
+      av1_fdct16_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      av1_round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
+      av1_fwd_txfm_transpose_16x16_avx2(out, in);
+      av1_fdct16_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+      av1_fwd_txfm_transpose_16x16_avx2(out, in);
+      av1_store_buffer_avx2(in, coeff, 8, 32);
+      break;
+    case ADST_DCT:
+      av1_load_buffer_16xn_avx2(input, in, stride, height, width_div8, 0, 0);
+      av1_round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
+      av1_fadst16_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      av1_round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
+      av1_fwd_txfm_transpose_16x16_avx2(out, in);
+      av1_fdct16_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+      av1_fwd_txfm_transpose_16x16_avx2(out, in);
+      av1_store_buffer_avx2(in, coeff, 8, 32);
+      break;
+    case DCT_ADST:
+      av1_load_buffer_16xn_avx2(input, in, stride, height, width_div8, 0, 0);
+      av1_round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
+      av1_fdct16_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      av1_round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
+      av1_fwd_txfm_transpose_16x16_avx2(out, in);
+      av1_fadst16_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+      av1_fwd_txfm_transpose_16x16_avx2(out, in);
+      av1_store_buffer_avx2(in, coeff, 8, 32);
+      break;
+    case ADST_ADST:
+      av1_load_buffer_16xn_avx2(input, in, stride, height, width_div8, 0, 0);
+      av1_round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
+      av1_fadst16_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      av1_round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
+      av1_fwd_txfm_transpose_16x16_avx2(out, in);
+      av1_fadst16_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+      av1_fwd_txfm_transpose_16x16_avx2(out, in);
+      av1_store_buffer_avx2(in, coeff, 8, 32);
+      break;
+    case FLIPADST_DCT:
+      av1_load_buffer_16xn_avx2(input, in, stride, height, width_div8, 1, 0);
+      av1_round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
+      av1_fadst16_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      av1_round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
+      av1_fwd_txfm_transpose_16x16_avx2(out, in);
+      av1_fdct16_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+      av1_fwd_txfm_transpose_16x16_avx2(out, in);
+      av1_store_buffer_avx2(in, coeff, 8, 32);
+      break;
+    case DCT_FLIPADST:
+      av1_load_buffer_16xn_avx2(input, in, stride, height, width_div8, 0, 1);
+      av1_round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
+      av1_fdct16_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      av1_round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
+      av1_fwd_txfm_transpose_16x16_avx2(out, in);
+      av1_fadst16_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+      av1_fwd_txfm_transpose_16x16_avx2(out, in);
+      av1_store_buffer_avx2(in, coeff, 8, 32);
+      break;
+    case FLIPADST_FLIPADST:
+      av1_load_buffer_16xn_avx2(input, in, stride, height, width_div8, 1, 1);
+      av1_round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
+      av1_fadst16_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      av1_round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
+      av1_fwd_txfm_transpose_16x16_avx2(out, in);
+      av1_fadst16_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+      av1_fwd_txfm_transpose_16x16_avx2(out, in);
+      av1_store_buffer_avx2(in, coeff, 8, 32);
+      break;
+    case ADST_FLIPADST:
+      av1_load_buffer_16xn_avx2(input, in, stride, height, width_div8, 0, 1);
+      av1_round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
+      av1_fadst16_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      av1_round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
+      av1_fwd_txfm_transpose_16x16_avx2(out, in);
+      av1_fadst16_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+      av1_fwd_txfm_transpose_16x16_avx2(out, in);
+      av1_store_buffer_avx2(in, coeff, 8, 32);
+      break;
+    case FLIPADST_ADST:
+      av1_load_buffer_16xn_avx2(input, in, stride, height, width_div8, 1, 0);
+      av1_round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
+      av1_fadst16_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      av1_round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
+      av1_fwd_txfm_transpose_16x16_avx2(out, in);
+      av1_fadst16_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+      av1_fwd_txfm_transpose_16x16_avx2(out, in);
+      av1_store_buffer_avx2(in, coeff, 8, 32);
+      break;
+    case IDTX:
+      av1_load_buffer_16xn_avx2(input, in, stride, height, width_div8, 0, 0);
+      av1_round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
+      av1_idtx16_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      av1_round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
+      av1_idtx16_avx2(out, in, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+      av1_store_buffer_avx2(in, coeff, 8, 32);
+      break;
+    case V_DCT:
+      av1_load_buffer_16xn_avx2(input, in, stride, height, width_div8, 0, 0);
+      av1_round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
+      av1_fdct16_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      av1_round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
+      av1_idtx16_avx2(out, in, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+      av1_store_buffer_avx2(in, coeff, 8, 32);
+      break;
+    case H_DCT:
+      av1_load_buffer_16xn_avx2(input, in, stride, height, width_div8, 0, 0);
+      av1_round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
+      av1_idtx16_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      av1_round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
+      av1_fwd_txfm_transpose_16x16_avx2(out, in);
+      av1_fdct16_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+      av1_fwd_txfm_transpose_16x16_avx2(out, in);
+      av1_store_buffer_avx2(in, coeff, 8, 32);
+      break;
+    case V_ADST:
+      av1_load_buffer_16xn_avx2(input, in, stride, height, width_div8, 0, 0);
+      av1_round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
+      av1_fadst16_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      av1_round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
+      av1_idtx16_avx2(out, in, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+      av1_store_buffer_avx2(in, coeff, 8, 32);
+      break;
+    case H_ADST:
+      av1_load_buffer_16xn_avx2(input, in, stride, height, width_div8, 0, 0);
+      av1_round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
+      av1_idtx16_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      av1_round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
+      av1_fwd_txfm_transpose_16x16_avx2(out, in);
+      av1_fadst16_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+      av1_fwd_txfm_transpose_16x16_avx2(out, in);
+      av1_store_buffer_avx2(in, coeff, 8, 32);
+      break;
+    case V_FLIPADST:
+      av1_load_buffer_16xn_avx2(input, in, stride, height, width_div8, 1, 0);
+      av1_round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
+      av1_fadst16_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      av1_round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
+      av1_idtx16_avx2(out, in, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+      av1_store_buffer_avx2(in, coeff, 8, 32);
+      break;
+    case H_FLIPADST:
+      av1_load_buffer_16xn_avx2(input, in, stride, height, width_div8, 0, 1);
+      av1_round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
+      av1_idtx16_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      av1_round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
+      av1_fwd_txfm_transpose_16x16_avx2(out, in);
+      av1_fadst16_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+      av1_fwd_txfm_transpose_16x16_avx2(out, in);
+      av1_store_buffer_avx2(in, coeff, 8, 32);
+      break;
+    default: assert(0);
+  }
+  (void)bd;
+}
 static INLINE void av1_fdct32_avx2(__m256i *input, __m256i *output,
                                    const int8_t cos_bit,
                                    const int8_t *stage_range,
@@ -597,7 +1354,7 @@
 
   for (int i = 0; i < width_div16; i++) {
     av1_load_buffer_16xn_avx2(input + (i << 4), &buf0[(i << 1)], stride, height,
-                              width_div8);
+                              width_div8, 0, 0);
     av1_round_shift_32_8xn_avx2(&buf0[(i << 1)], height, shift[0], width_div8);
     av1_round_shift_32_8xn_avx2(&buf0[(i << 1) + 1], height, shift[0],
                                 width_div8);
@@ -1668,7 +2425,7 @@
   int r, c;
   for (int i = 0; i < width_div16; i++) {
     av1_load_buffer_16xn_avx2(input + (i << 4), &buf0[i << 1], stride, height,
-                              width_div8);
+                              width_div8, 0, 0);
     av1_round_shift_32_8xn_avx2(&buf0[i << 1], height, shift[0], width_div8);
     av1_round_shift_32_8xn_avx2(&buf0[(i << 1) + 1], height, shift[0],
                                 width_div8);
diff --git a/test/av1_fwd_txfm2d_test.cc b/test/av1_fwd_txfm2d_test.cc
index 285983d..eb3455b 100644
--- a/test/av1_fwd_txfm2d_test.cc
+++ b/test/av1_fwd_txfm2d_test.cc
@@ -571,7 +571,7 @@
                                 Values(av1_highbd_fwd_txfm)));
 #endif  // HAVE_SSE4_1
 #if HAVE_AVX2
-static TX_SIZE Highbd_fwd_txfm_for_avx2[] = { TX_32X32, TX_64X64 };
+static TX_SIZE Highbd_fwd_txfm_for_avx2[] = { TX_16X16, TX_32X32, TX_64X64 };
 
 INSTANTIATE_TEST_CASE_P(AVX2, AV1HighbdFwdTxfm2dTest,
                         Combine(ValuesIn(Highbd_fwd_txfm_for_avx2),