Optimize highbd fwd_txfm module

Added an AVX2 variant for the 8x8 transform block size.

When tested across multiple test cases, an average reduction of
about 0.4% in encoder time was observed for the speed=1 preset.

Module-level gains improved by a factor of ~1.5 on average
w.r.t. the SSE4_1 module.
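
The new kernel is reached through the usual rtcd dispatch. As a rough
sketch of the expected behavior (the existing AV1HighbdFwdTxfm2dTest
already exercises this), the AVX2 path should stay bit-exact with the
C reference declared in av1_rtcd_defs.pl:

  DECLARE_ALIGNED(32, int16_t, input[8 * 8]) = { 0 };
  DECLARE_ALIGNED(32, int32_t, out_c[8 * 8]);
  DECLARE_ALIGNED(32, int32_t, out_avx2[8 * 8]);
  av1_fwd_txfm2d_8x8_c(input, out_c, /*stride=*/8, DCT_DCT, /*bd=*/10);
  av1_fwd_txfm2d_8x8_avx2(input, out_avx2, /*stride=*/8, DCT_DCT, /*bd=*/10);
  // out_c and out_avx2 should match element for element.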

Change-Id: Iee170dc60916c34a841679154f1dd585cf9d7190
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 163a0a7..4ab76a1 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -230,7 +230,7 @@
   add_proto qw/void av1_fwd_txfm2d_4x4/, "const int16_t *input, int32_t *output, int stride, TX_TYPE tx_type, int bd";
   specialize qw/av1_fwd_txfm2d_4x4 sse4_1/;
   add_proto qw/void av1_fwd_txfm2d_8x8/, "const int16_t *input, int32_t *output, int stride, TX_TYPE tx_type, int bd";
-  specialize qw/av1_fwd_txfm2d_8x8 sse4_1/;
+  specialize qw/av1_fwd_txfm2d_8x8 sse4_1 avx2/;
   add_proto qw/void av1_fwd_txfm2d_16x16/, "const int16_t *input, int32_t *output, int stride, TX_TYPE tx_type, int bd";
   specialize qw/av1_fwd_txfm2d_16x16 sse4_1 avx2/;
   add_proto qw/void av1_fwd_txfm2d_32x32/, "const int16_t *input, int32_t *output, int stride, TX_TYPE tx_type, int bd";
diff --git a/av1/encoder/x86/highbd_fwd_txfm_avx2.c b/av1/encoder/x86/highbd_fwd_txfm_avx2.c
index 0cdbebd..019ab66 100644
--- a/av1/encoder/x86/highbd_fwd_txfm_avx2.c
+++ b/av1/encoder/x86/highbd_fwd_txfm_avx2.c
@@ -19,6 +19,80 @@
 #include "aom_ports/mem.h"
 #include "aom_dsp/x86/txfm_common_sse2.h"
 
+static INLINE void av1_load_buffer_8x8_avx2(const int16_t *input, __m256i *out,
+                                            int stride, int flipud, int fliplr,
+                                            int shift) {
+  __m128i out1[8];
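+  // Load 8 rows of 8 int16_t samples; flipud reverses the row order and
+  // fliplr reverses the samples within each row (mm_reverse_epi16) before
+  // each row is sign-extended to 32-bit lanes and left-shifted by 'shift'.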
+  if (!flipud) {
+    out1[0] = _mm_load_si128((const __m128i *)(input + 0 * stride));
+    out1[1] = _mm_load_si128((const __m128i *)(input + 1 * stride));
+    out1[2] = _mm_load_si128((const __m128i *)(input + 2 * stride));
+    out1[3] = _mm_load_si128((const __m128i *)(input + 3 * stride));
+    out1[4] = _mm_load_si128((const __m128i *)(input + 4 * stride));
+    out1[5] = _mm_load_si128((const __m128i *)(input + 5 * stride));
+    out1[6] = _mm_load_si128((const __m128i *)(input + 6 * stride));
+    out1[7] = _mm_load_si128((const __m128i *)(input + 7 * stride));
+
+  } else {
+    out1[7] = _mm_load_si128((const __m128i *)(input + 0 * stride));
+    out1[6] = _mm_load_si128((const __m128i *)(input + 1 * stride));
+    out1[5] = _mm_load_si128((const __m128i *)(input + 2 * stride));
+    out1[4] = _mm_load_si128((const __m128i *)(input + 3 * stride));
+    out1[3] = _mm_load_si128((const __m128i *)(input + 4 * stride));
+    out1[2] = _mm_load_si128((const __m128i *)(input + 5 * stride));
+    out1[1] = _mm_load_si128((const __m128i *)(input + 6 * stride));
+    out1[0] = _mm_load_si128((const __m128i *)(input + 7 * stride));
+  }
+  if (!fliplr) {
+    out[0] = _mm256_cvtepi16_epi32(out1[0]);
+    out[1] = _mm256_cvtepi16_epi32(out1[1]);
+    out[2] = _mm256_cvtepi16_epi32(out1[2]);
+    out[3] = _mm256_cvtepi16_epi32(out1[3]);
+    out[4] = _mm256_cvtepi16_epi32(out1[4]);
+    out[5] = _mm256_cvtepi16_epi32(out1[5]);
+    out[6] = _mm256_cvtepi16_epi32(out1[6]);
+    out[7] = _mm256_cvtepi16_epi32(out1[7]);
+
+  } else {
+    out[0] = _mm256_cvtepi16_epi32(mm_reverse_epi16(out1[0]));
+    out[1] = _mm256_cvtepi16_epi32(mm_reverse_epi16(out1[1]));
+    out[2] = _mm256_cvtepi16_epi32(mm_reverse_epi16(out1[2]));
+    out[3] = _mm256_cvtepi16_epi32(mm_reverse_epi16(out1[3]));
+    out[4] = _mm256_cvtepi16_epi32(mm_reverse_epi16(out1[4]));
+    out[5] = _mm256_cvtepi16_epi32(mm_reverse_epi16(out1[5]));
+    out[6] = _mm256_cvtepi16_epi32(mm_reverse_epi16(out1[6]));
+    out[7] = _mm256_cvtepi16_epi32(mm_reverse_epi16(out1[7]));
+  }
+  out[0] = _mm256_slli_epi32(out[0], shift);
+  out[1] = _mm256_slli_epi32(out[1], shift);
+  out[2] = _mm256_slli_epi32(out[2], shift);
+  out[3] = _mm256_slli_epi32(out[3], shift);
+  out[4] = _mm256_slli_epi32(out[4], shift);
+  out[5] = _mm256_slli_epi32(out[5], shift);
+  out[6] = _mm256_slli_epi32(out[6], shift);
+  out[7] = _mm256_slli_epi32(out[7], shift);
+}
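+
+// Adds 1 << (shift - 1) and arithmetic-shifts each of the 8 vectors right
+// by 'shift'; used for the intermediate rounding between the column and
+// row passes (called with -shift[1]).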
+static INLINE void col_txfm_8x8_rounding(__m256i *in, int shift) {
+  const __m256i rounding = _mm256_set1_epi32(1 << (shift - 1));
+
+  in[0] = _mm256_add_epi32(in[0], rounding);
+  in[1] = _mm256_add_epi32(in[1], rounding);
+  in[2] = _mm256_add_epi32(in[2], rounding);
+  in[3] = _mm256_add_epi32(in[3], rounding);
+  in[4] = _mm256_add_epi32(in[4], rounding);
+  in[5] = _mm256_add_epi32(in[5], rounding);
+  in[6] = _mm256_add_epi32(in[6], rounding);
+  in[7] = _mm256_add_epi32(in[7], rounding);
+
+  in[0] = _mm256_srai_epi32(in[0], shift);
+  in[1] = _mm256_srai_epi32(in[1], shift);
+  in[2] = _mm256_srai_epi32(in[2], shift);
+  in[3] = _mm256_srai_epi32(in[3], shift);
+  in[4] = _mm256_srai_epi32(in[4], shift);
+  in[5] = _mm256_srai_epi32(in[5], shift);
+  in[6] = _mm256_srai_epi32(in[6], shift);
+  in[7] = _mm256_srai_epi32(in[7], shift);
+}
+
 static INLINE void av1_load_buffer_16xn_avx2(const int16_t *input, __m256i *out,
                                              int stride, int height,
                                              int outstride, int flipud,
@@ -163,6 +237,457 @@
                                   const int8_t cos_bit,
                                   const int8_t *stage_range, int instride,
                                   int outstride);
+static void av1_fdct8_avx2(__m256i *in, __m256i *out, int bit,
+                           const int col_num) {
+  const int32_t *cospi = cospi_arr(bit);
+  const __m256i cospi32 = _mm256_set1_epi32(cospi[32]);
+  const __m256i cospim32 = _mm256_set1_epi32(-cospi[32]);
+  const __m256i cospi48 = _mm256_set1_epi32(cospi[48]);
+  const __m256i cospi16 = _mm256_set1_epi32(cospi[16]);
+  const __m256i cospi56 = _mm256_set1_epi32(cospi[56]);
+  const __m256i cospi8 = _mm256_set1_epi32(cospi[8]);
+  const __m256i cospi24 = _mm256_set1_epi32(cospi[24]);
+  const __m256i cospi40 = _mm256_set1_epi32(cospi[40]);
+  const __m256i rnding = _mm256_set1_epi32(1 << (bit - 1));
+  __m256i u[8], v[8];
+
+  int startidx = 0;
+  int endidx = 7;
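+  // in[] is walked from both ends in steps of col_num, so the kernel can
+  // also be reused when a row spans more than one 256-bit vector; every
+  // cospi multiply below is rounded and arithmetic-shifted by 'bit'.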
+
+  // stage 0
+  // stage 1
+  u[0] = _mm256_add_epi32(in[startidx], in[endidx]);
+  v[7] = _mm256_sub_epi32(in[startidx], in[endidx]);
+  startidx += col_num;
+  endidx -= col_num;
+  u[1] = _mm256_add_epi32(in[startidx], in[endidx]);
+  u[6] = _mm256_sub_epi32(in[startidx], in[endidx]);
+  startidx += col_num;
+  endidx -= col_num;
+  u[2] = _mm256_add_epi32(in[startidx], in[endidx]);
+  u[5] = _mm256_sub_epi32(in[startidx], in[endidx]);
+  startidx += col_num;
+  endidx -= col_num;
+  u[3] = _mm256_add_epi32(in[startidx], in[endidx]);
+  v[4] = _mm256_sub_epi32(in[startidx], in[endidx]);
+
+  // stage 2
+  v[0] = _mm256_add_epi32(u[0], u[3]);
+  v[3] = _mm256_sub_epi32(u[0], u[3]);
+  v[1] = _mm256_add_epi32(u[1], u[2]);
+  v[2] = _mm256_sub_epi32(u[1], u[2]);
+
+  v[5] = _mm256_mullo_epi32(u[5], cospim32);
+  v[6] = _mm256_mullo_epi32(u[6], cospi32);
+  v[5] = _mm256_add_epi32(v[5], v[6]);
+  v[5] = _mm256_add_epi32(v[5], rnding);
+  v[5] = _mm256_srai_epi32(v[5], bit);
+
+  u[0] = _mm256_mullo_epi32(u[5], cospi32);
+  v[6] = _mm256_mullo_epi32(u[6], cospim32);
+  v[6] = _mm256_sub_epi32(u[0], v[6]);
+  v[6] = _mm256_add_epi32(v[6], rnding);
+  v[6] = _mm256_srai_epi32(v[6], bit);
+
+  // stage 3
+  // type 0
+  v[0] = _mm256_mullo_epi32(v[0], cospi32);
+  v[1] = _mm256_mullo_epi32(v[1], cospi32);
+  u[0] = _mm256_add_epi32(v[0], v[1]);
+  u[0] = _mm256_add_epi32(u[0], rnding);
+  u[0] = _mm256_srai_epi32(u[0], bit);
+
+  u[1] = _mm256_sub_epi32(v[0], v[1]);
+  u[1] = _mm256_add_epi32(u[1], rnding);
+  u[1] = _mm256_srai_epi32(u[1], bit);
+
+  // type 1
+  v[0] = _mm256_mullo_epi32(v[2], cospi48);
+  v[1] = _mm256_mullo_epi32(v[3], cospi16);
+  u[2] = _mm256_add_epi32(v[0], v[1]);
+  u[2] = _mm256_add_epi32(u[2], rnding);
+  u[2] = _mm256_srai_epi32(u[2], bit);
+
+  v[0] = _mm256_mullo_epi32(v[2], cospi16);
+  v[1] = _mm256_mullo_epi32(v[3], cospi48);
+  u[3] = _mm256_sub_epi32(v[1], v[0]);
+  u[3] = _mm256_add_epi32(u[3], rnding);
+  u[3] = _mm256_srai_epi32(u[3], bit);
+
+  u[4] = _mm256_add_epi32(v[4], v[5]);
+  u[5] = _mm256_sub_epi32(v[4], v[5]);
+  u[6] = _mm256_sub_epi32(v[7], v[6]);
+  u[7] = _mm256_add_epi32(v[7], v[6]);
+
+  // stage 4
+  // stage 5
+  v[0] = _mm256_mullo_epi32(u[4], cospi56);
+  v[1] = _mm256_mullo_epi32(u[7], cospi8);
+  v[0] = _mm256_add_epi32(v[0], v[1]);
+  v[0] = _mm256_add_epi32(v[0], rnding);
+  out[1] = _mm256_srai_epi32(v[0], bit);
+
+  v[0] = _mm256_mullo_epi32(u[4], cospi8);
+  v[1] = _mm256_mullo_epi32(u[7], cospi56);
+  v[0] = _mm256_sub_epi32(v[1], v[0]);
+  v[0] = _mm256_add_epi32(v[0], rnding);
+  out[7] = _mm256_srai_epi32(v[0], bit);
+
+  v[0] = _mm256_mullo_epi32(u[5], cospi24);
+  v[1] = _mm256_mullo_epi32(u[6], cospi40);
+  v[0] = _mm256_add_epi32(v[0], v[1]);
+  v[0] = _mm256_add_epi32(v[0], rnding);
+  out[5] = _mm256_srai_epi32(v[0], bit);
+
+  v[0] = _mm256_mullo_epi32(u[5], cospi40);
+  v[1] = _mm256_mullo_epi32(u[6], cospi24);
+  v[0] = _mm256_sub_epi32(v[1], v[0]);
+  v[0] = _mm256_add_epi32(v[0], rnding);
+  out[3] = _mm256_srai_epi32(v[0], bit);
+
+  out[0] = u[0];
+  out[4] = u[1];
+  out[2] = u[2];
+  out[6] = u[3];
+}
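+
+// 8-point forward ADST; col_num is unused since a single call covers all
+// 8 rows of the 8x8 block.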
+static void av1_fadst8_avx2(__m256i *in, __m256i *out, int bit,
+                            const int col_num) {
+  (void)col_num;
+  const int32_t *cospi = cospi_arr(bit);
+  const __m256i cospi32 = _mm256_set1_epi32(cospi[32]);
+  const __m256i cospi16 = _mm256_set1_epi32(cospi[16]);
+  const __m256i cospim16 = _mm256_set1_epi32(-cospi[16]);
+  const __m256i cospi48 = _mm256_set1_epi32(cospi[48]);
+  const __m256i cospim48 = _mm256_set1_epi32(-cospi[48]);
+  const __m256i cospi4 = _mm256_set1_epi32(cospi[4]);
+  const __m256i cospim4 = _mm256_set1_epi32(-cospi[4]);
+  const __m256i cospi60 = _mm256_set1_epi32(cospi[60]);
+  const __m256i cospi20 = _mm256_set1_epi32(cospi[20]);
+  const __m256i cospim20 = _mm256_set1_epi32(-cospi[20]);
+  const __m256i cospi44 = _mm256_set1_epi32(cospi[44]);
+  const __m256i cospi28 = _mm256_set1_epi32(cospi[28]);
+  const __m256i cospi36 = _mm256_set1_epi32(cospi[36]);
+  const __m256i cospim36 = _mm256_set1_epi32(-cospi[36]);
+  const __m256i cospi52 = _mm256_set1_epi32(cospi[52]);
+  const __m256i cospim52 = _mm256_set1_epi32(-cospi[52]);
+  const __m256i cospi12 = _mm256_set1_epi32(cospi[12]);
+  const __m256i rnding = _mm256_set1_epi32(1 << (bit - 1));
+  const __m256i zero = _mm256_setzero_si256();
+  __m256i u0, u1, u2, u3, u4, u5, u6, u7;
+  __m256i v0, v1, v2, v3, v4, v5, v6, v7;
+  __m256i x, y;
+
+  // stage 0
+  // stage 1
+  u0 = in[0];
+  u1 = _mm256_sub_epi32(zero, in[7]);
+  u2 = _mm256_sub_epi32(zero, in[3]);
+  u3 = in[4];
+  u4 = _mm256_sub_epi32(zero, in[1]);
+  u5 = in[6];
+  u6 = in[2];
+  u7 = _mm256_sub_epi32(zero, in[5]);
+
+  // stage 2
+  v0 = u0;
+  v1 = u1;
+
+  x = _mm256_mullo_epi32(u2, cospi32);
+  y = _mm256_mullo_epi32(u3, cospi32);
+  v2 = _mm256_add_epi32(x, y);
+  v2 = _mm256_add_epi32(v2, rnding);
+  v2 = _mm256_srai_epi32(v2, bit);
+
+  v3 = _mm256_sub_epi32(x, y);
+  v3 = _mm256_add_epi32(v3, rnding);
+  v3 = _mm256_srai_epi32(v3, bit);
+
+  v4 = u4;
+  v5 = u5;
+
+  x = _mm256_mullo_epi32(u6, cospi32);
+  y = _mm256_mullo_epi32(u7, cospi32);
+  v6 = _mm256_add_epi32(x, y);
+  v6 = _mm256_add_epi32(v6, rnding);
+  v6 = _mm256_srai_epi32(v6, bit);
+
+  v7 = _mm256_sub_epi32(x, y);
+  v7 = _mm256_add_epi32(v7, rnding);
+  v7 = _mm256_srai_epi32(v7, bit);
+
+  // stage 3
+  u0 = _mm256_add_epi32(v0, v2);
+  u1 = _mm256_add_epi32(v1, v3);
+  u2 = _mm256_sub_epi32(v0, v2);
+  u3 = _mm256_sub_epi32(v1, v3);
+  u4 = _mm256_add_epi32(v4, v6);
+  u5 = _mm256_add_epi32(v5, v7);
+  u6 = _mm256_sub_epi32(v4, v6);
+  u7 = _mm256_sub_epi32(v5, v7);
+
+  // stage 4
+  v0 = u0;
+  v1 = u1;
+  v2 = u2;
+  v3 = u3;
+
+  x = _mm256_mullo_epi32(u4, cospi16);
+  y = _mm256_mullo_epi32(u5, cospi48);
+  v4 = _mm256_add_epi32(x, y);
+  v4 = _mm256_add_epi32(v4, rnding);
+  v4 = _mm256_srai_epi32(v4, bit);
+
+  x = _mm256_mullo_epi32(u4, cospi48);
+  y = _mm256_mullo_epi32(u5, cospim16);
+  v5 = _mm256_add_epi32(x, y);
+  v5 = _mm256_add_epi32(v5, rnding);
+  v5 = _mm256_srai_epi32(v5, bit);
+
+  x = _mm256_mullo_epi32(u6, cospim48);
+  y = _mm256_mullo_epi32(u7, cospi16);
+  v6 = _mm256_add_epi32(x, y);
+  v6 = _mm256_add_epi32(v6, rnding);
+  v6 = _mm256_srai_epi32(v6, bit);
+
+  x = _mm256_mullo_epi32(u6, cospi16);
+  y = _mm256_mullo_epi32(u7, cospi48);
+  v7 = _mm256_add_epi32(x, y);
+  v7 = _mm256_add_epi32(v7, rnding);
+  v7 = _mm256_srai_epi32(v7, bit);
+
+  // stage 5
+  u0 = _mm256_add_epi32(v0, v4);
+  u1 = _mm256_add_epi32(v1, v5);
+  u2 = _mm256_add_epi32(v2, v6);
+  u3 = _mm256_add_epi32(v3, v7);
+  u4 = _mm256_sub_epi32(v0, v4);
+  u5 = _mm256_sub_epi32(v1, v5);
+  u6 = _mm256_sub_epi32(v2, v6);
+  u7 = _mm256_sub_epi32(v3, v7);
+
+  // stage 6
+  x = _mm256_mullo_epi32(u0, cospi4);
+  y = _mm256_mullo_epi32(u1, cospi60);
+  v0 = _mm256_add_epi32(x, y);
+  v0 = _mm256_add_epi32(v0, rnding);
+  v0 = _mm256_srai_epi32(v0, bit);
+
+  x = _mm256_mullo_epi32(u0, cospi60);
+  y = _mm256_mullo_epi32(u1, cospim4);
+  v1 = _mm256_add_epi32(x, y);
+  v1 = _mm256_add_epi32(v1, rnding);
+  v1 = _mm256_srai_epi32(v1, bit);
+
+  x = _mm256_mullo_epi32(u2, cospi20);
+  y = _mm256_mullo_epi32(u3, cospi44);
+  v2 = _mm256_add_epi32(x, y);
+  v2 = _mm256_add_epi32(v2, rnding);
+  v2 = _mm256_srai_epi32(v2, bit);
+
+  x = _mm256_mullo_epi32(u2, cospi44);
+  y = _mm256_mullo_epi32(u3, cospim20);
+  v3 = _mm256_add_epi32(x, y);
+  v3 = _mm256_add_epi32(v3, rnding);
+  v3 = _mm256_srai_epi32(v3, bit);
+
+  x = _mm256_mullo_epi32(u4, cospi36);
+  y = _mm256_mullo_epi32(u5, cospi28);
+  v4 = _mm256_add_epi32(x, y);
+  v4 = _mm256_add_epi32(v4, rnding);
+  v4 = _mm256_srai_epi32(v4, bit);
+
+  x = _mm256_mullo_epi32(u4, cospi28);
+  y = _mm256_mullo_epi32(u5, cospim36);
+  v5 = _mm256_add_epi32(x, y);
+  v5 = _mm256_add_epi32(v5, rnding);
+  v5 = _mm256_srai_epi32(v5, bit);
+
+  x = _mm256_mullo_epi32(u6, cospi52);
+  y = _mm256_mullo_epi32(u7, cospi12);
+  v6 = _mm256_add_epi32(x, y);
+  v6 = _mm256_add_epi32(v6, rnding);
+  v6 = _mm256_srai_epi32(v6, bit);
+
+  x = _mm256_mullo_epi32(u6, cospi12);
+  y = _mm256_mullo_epi32(u7, cospim52);
+  v7 = _mm256_add_epi32(x, y);
+  v7 = _mm256_add_epi32(v7, rnding);
+  v7 = _mm256_srai_epi32(v7, bit);
+
+  // stage 7
+  out[0] = v1;
+  out[1] = v6;
+  out[2] = v3;
+  out[3] = v4;
+  out[4] = v5;
+  out[5] = v2;
+  out[6] = v7;
+  out[7] = v0;
+}
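+
+// 8-point identity transform: each coefficient is simply doubled
+// (in + in), so neither 'bit' nor 'col_num' is used.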
+static void av1_idtx8_avx2(__m256i *in, __m256i *out, int bit, int col_num) {
+  (void)bit;
+  (void)col_num;
+  out[0] = _mm256_add_epi32(in[0], in[0]);
+  out[1] = _mm256_add_epi32(in[1], in[1]);
+  out[2] = _mm256_add_epi32(in[2], in[2]);
+  out[3] = _mm256_add_epi32(in[3], in[3]);
+  out[4] = _mm256_add_epi32(in[4], in[4]);
+  out[5] = _mm256_add_epi32(in[5], in[5]);
+  out[6] = _mm256_add_epi32(in[6], in[6]);
+  out[7] = _mm256_add_epi32(in[7], in[7]);
+}
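+
+// 2-D highbd 8x8 forward transform: load the input block (with optional
+// vertical/horizontal flips), run the column transform, apply the
+// intermediate rounding, transpose, run the row transform, transpose back
+// and store the 32-bit coefficients. When the row transform is the
+// identity, the transposes are skipped since it scales each element in
+// place.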
+void av1_fwd_txfm2d_8x8_avx2(const int16_t *input, int32_t *coeff, int stride,
+                             TX_TYPE tx_type, int bd) {
+  __m256i in[8], out[8];
+  const TX_SIZE tx_size = TX_8X8;
+  const int8_t *shift = fwd_txfm_shift_ls[tx_size];
+  const int txw_idx = get_txw_idx(tx_size);
+  const int txh_idx = get_txh_idx(tx_size);
+  const int width = tx_size_wide[tx_size];
+  const int height = tx_size_high[tx_size];
+  const int width_div8 = (width >> 3);
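+  // For 8x8, width_div8 == 1: each row of eight 32-bit values fits in one
+  // __m256i, so the 1-D kernels step through in[] with a stride of 1.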
+
+  switch (tx_type) {
+    case DCT_DCT:
+      av1_load_buffer_8x8_avx2(input, in, stride, 0, 0, shift[0]);
+      av1_fdct8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      col_txfm_8x8_rounding(out, -shift[1]);
+      av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+      av1_fdct8_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+      av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+      av1_store_buffer_avx2(in, coeff, 8, 8);
+      break;
+    case ADST_DCT:
+      av1_load_buffer_8x8_avx2(input, in, stride, 0, 0, shift[0]);
+      av1_fadst8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      col_txfm_8x8_rounding(out, -shift[1]);
+      av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+      av1_fdct8_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+      av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+      av1_store_buffer_avx2(in, coeff, 8, 8);
+      break;
+    case DCT_ADST:
+      av1_load_buffer_8x8_avx2(input, in, stride, 0, 0, shift[0]);
+      av1_fdct8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      col_txfm_8x8_rounding(out, -shift[1]);
+      av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+      av1_fadst8_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+      av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+      av1_store_buffer_avx2(in, coeff, 8, 8);
+      break;
+    case ADST_ADST:
+      av1_load_buffer_8x8_avx2(input, in, stride, 0, 0, shift[0]);
+      av1_fadst8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      col_txfm_8x8_rounding(out, -shift[1]);
+      av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+      av1_fadst8_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+      av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+      av1_store_buffer_avx2(in, coeff, 8, 8);
+      break;
+    case FLIPADST_DCT:
+      av1_load_buffer_8x8_avx2(input, in, stride, 1, 0, shift[0]);
+      av1_fadst8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      col_txfm_8x8_rounding(out, -shift[1]);
+      av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+      av1_fdct8_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+      av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+      av1_store_buffer_avx2(in, coeff, 8, 8);
+      break;
+    case DCT_FLIPADST:
+      av1_load_buffer_8x8_avx2(input, in, stride, 0, 1, shift[0]);
+      av1_fdct8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      col_txfm_8x8_rounding(out, -shift[1]);
+      av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+      av1_fadst8_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+      av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+      av1_store_buffer_avx2(in, coeff, 8, 8);
+      break;
+    case FLIPADST_FLIPADST:
+      av1_load_buffer_8x8_avx2(input, in, stride, 1, 1, shift[0]);
+      av1_fadst8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      col_txfm_8x8_rounding(out, -shift[1]);
+      av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+      av1_fadst8_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+      av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+      av1_store_buffer_avx2(in, coeff, 8, 8);
+      break;
+    case ADST_FLIPADST:
+      av1_load_buffer_8x8_avx2(input, in, stride, 0, 1, shift[0]);
+      av1_fadst8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      col_txfm_8x8_rounding(out, -shift[1]);
+      av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+      av1_fadst8_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+      av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+      av1_store_buffer_avx2(in, coeff, 8, 8);
+      break;
+    case FLIPADST_ADST:
+      av1_load_buffer_8x8_avx2(input, in, stride, 1, 0, shift[0]);
+      av1_fadst8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      col_txfm_8x8_rounding(out, -shift[1]);
+      av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+      av1_fadst8_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+      av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+      av1_store_buffer_avx2(in, coeff, 8, 8);
+      break;
+    case IDTX:
+      av1_load_buffer_8x8_avx2(input, in, stride, 0, 0, shift[0]);
+      av1_idtx8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], height);
+      col_txfm_8x8_rounding(out, -shift[1]);
+      av1_idtx8_avx2(out, in, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      av1_store_buffer_avx2(in, coeff, 8, 8);
+      break;
+    case V_DCT:
+      av1_load_buffer_8x8_avx2(input, in, stride, 0, 0, shift[0]);
+      av1_fdct8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      col_txfm_8x8_rounding(out, -shift[1]);
+      av1_idtx8_avx2(out, in, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      av1_store_buffer_avx2(in, coeff, 8, 8);
+      break;
+    case H_DCT:
+      av1_load_buffer_8x8_avx2(input, in, stride, 0, 0, shift[0]);
+      av1_idtx8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      col_txfm_8x8_rounding(out, -shift[1]);
+      av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+      av1_fdct8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+      av1_store_buffer_avx2(in, coeff, 8, 8);
+      break;
+    case V_ADST:
+      av1_load_buffer_8x8_avx2(input, in, stride, 0, 0, shift[0]);
+      av1_fadst8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      col_txfm_8x8_rounding(out, -shift[1]);
+      av1_idtx8_avx2(out, in, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      av1_store_buffer_avx2(in, coeff, 8, 8);
+      break;
+    case H_ADST:
+      av1_load_buffer_8x8_avx2(input, in, stride, 0, 0, shift[0]);
+      av1_idtx8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      col_txfm_8x8_rounding(out, -shift[1]);
+      av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+      av1_fadst8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+      av1_store_buffer_avx2(in, coeff, 8, 8);
+      break;
+    case V_FLIPADST:
+      av1_load_buffer_8x8_avx2(input, in, stride, 1, 0, shift[0]);
+      av1_fadst8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      col_txfm_8x8_rounding(out, -shift[1]);
+      av1_idtx8_avx2(out, in, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      av1_store_buffer_avx2(in, coeff, 8, 8);
+      break;
+    case H_FLIPADST:
+      av1_load_buffer_8x8_avx2(input, in, stride, 0, 1, shift[0]);
+      av1_idtx8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      col_txfm_8x8_rounding(out, -shift[1]);
+      av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+      av1_fadst8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+      av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+      av1_store_buffer_avx2(in, coeff, 8, 8);
+      break;
+    default: assert(0);
+  }
+  (void)bd;
+}
+
 static void av1_fdct16_avx2(__m256i *in, __m256i *out, int bit,
                             const int col_num) {
   const int32_t *cospi = cospi_arr(bit);
diff --git a/test/av1_fwd_txfm2d_test.cc b/test/av1_fwd_txfm2d_test.cc
index eb3455b..04c6ddb 100644
--- a/test/av1_fwd_txfm2d_test.cc
+++ b/test/av1_fwd_txfm2d_test.cc
@@ -571,7 +571,8 @@
                                 Values(av1_highbd_fwd_txfm)));
 #endif  // HAVE_SSE4_1
 #if HAVE_AVX2
-static TX_SIZE Highbd_fwd_txfm_for_avx2[] = { TX_16X16, TX_32X32, TX_64X64 };
+static TX_SIZE Highbd_fwd_txfm_for_avx2[] = { TX_8X8, TX_16X16, TX_32X32,
+                                              TX_64X64 };
 
 INSTANTIATE_TEST_CASE_P(AVX2, AV1HighbdFwdTxfm2dTest,
                         Combine(ValuesIn(Highbd_fwd_txfm_for_avx2),