Optimize highbd inv_txfm modules

Added AVX2 variants for the 8x8, 8x16, 16x8, 8x32 and 32x8 txfm block sizes.

When tested across multiple test cases, an average reduction of 0.9%
in encoder time was observed for the speed=1 preset.

Module-level gains improved by a factor of ~1.66 on average
w.r.t. the SSE4_1 modules.
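
For reference, the per-pixel operation that the new highbd_get_recon_8x8_avx2
and highbd_write_buffer_8xn_avx2 helpers vectorize is the standard high
bit-depth reconstruction: add the inverse-transform residual to the predicted
samples and clamp the result to the valid range for the bit depth. A minimal
scalar sketch follows (illustrative only; the names clip_pixel_highbd_ref and
recon_8xn_scalar_ref are hypothetical and not part of this patch, and the
flipud handling done by the AVX2 helper is omitted):

  #include <stdint.h>

  /* Clamp a reconstructed sample to [0, (1 << bd) - 1]. */
  static inline uint16_t clip_pixel_highbd_ref(int32_t v, int bd) {
    const int32_t max = (1 << bd) - 1;
    return (uint16_t)(v < 0 ? 0 : (v > max ? max : v));
  }

  /* Scalar reference: for each 8-wide row, dst = clip(dst + residual). */
  static void recon_8xn_scalar_ref(const int32_t *res, uint16_t *dst,
                                   int stride, int height, int bd) {
    for (int r = 0; r < height; ++r) {
      for (int c = 0; c < 8; ++c) {
        const int32_t sum = (int32_t)dst[r * stride + c] + res[r * 8 + c];
        dst[r * stride + c] = clip_pixel_highbd_ref(sum, bd);
      }
    }
  }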

Change-Id: I061439c206aa4ae7cb39c2d77594e2dce4946c9b
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index d6b9fef..7049f16 100755
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -116,11 +116,11 @@
 add_proto qw/void av1_highbd_inv_txfm_add_4x4/,  "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
 specialize qw/av1_highbd_inv_txfm_add_4x4 sse4_1/;
 add_proto qw/void av1_highbd_inv_txfm_add_8x8/,  "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
-specialize qw/av1_highbd_inv_txfm_add_8x8 sse4_1/;
+specialize qw/av1_highbd_inv_txfm_add_8x8 sse4_1 avx2/;
 add_proto qw/void av1_highbd_inv_txfm_add_16x8/,  "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
-specialize qw/av1_highbd_inv_txfm_add_16x8 sse4_1/;
+specialize qw/av1_highbd_inv_txfm_add_16x8 sse4_1 avx2/;
 add_proto qw/void av1_highbd_inv_txfm_add_8x16/,  "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
-specialize qw/av1_highbd_inv_txfm_add_8x16 sse4_1/;
+specialize qw/av1_highbd_inv_txfm_add_8x16 sse4_1 avx2/;
 add_proto qw/void av1_highbd_inv_txfm_add_16x16/,  "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
 specialize qw/av1_highbd_inv_txfm_add_16x16 sse4_1 avx2/;
 add_proto qw/void av1_highbd_inv_txfm_add_32x32/,  "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
@@ -130,9 +130,9 @@
 add_proto qw/void av1_highbd_inv_txfm_add_32x16/,  "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
 specialize qw/av1_highbd_inv_txfm_add_32x16 sse4_1 avx2/;
 add_proto qw/void av1_highbd_inv_txfm_add_8x32/,  "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
-specialize qw/av1_highbd_inv_txfm_add_8x32 sse4_1/;
+specialize qw/av1_highbd_inv_txfm_add_8x32 sse4_1 avx2/;
 add_proto qw/void av1_highbd_inv_txfm_add_32x8/,  "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
-specialize qw/av1_highbd_inv_txfm_add_32x8 sse4_1/;
+specialize qw/av1_highbd_inv_txfm_add_32x8 sse4_1 avx2/;
 add_proto qw/void av1_highbd_inv_txfm_add_4x8/,  "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
 specialize qw/av1_highbd_inv_txfm_add_4x8 sse4_1/;
 add_proto qw/void av1_highbd_inv_txfm_add_8x4/,  "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
diff --git a/av1/common/x86/highbd_inv_txfm_avx2.c b/av1/common/x86/highbd_inv_txfm_avx2.c
index abd892e..1d1ae06 100644
--- a/av1/common/x86/highbd_inv_txfm_avx2.c
+++ b/av1/common/x86/highbd_inv_txfm_avx2.c
@@ -73,7 +73,30 @@
     _mm256_storeu_si256((__m256i *)(output + i * stride), u);
   }
 }
+static INLINE __m256i highbd_get_recon_8x8_avx2(const __m256i pred, __m256i res,
+                                                const int bd) {
+  __m256i x0 = pred;
+  x0 = _mm256_add_epi32(res, x0);
+  x0 = _mm256_packus_epi32(x0, x0);
+  x0 = _mm256_permute4x64_epi64(x0, 0xd8);
+  x0 = highbd_clamp_epi16_avx2(x0, bd);
+  return x0;
+}
 
+static INLINE void highbd_write_buffer_8xn_avx2(__m256i *in, uint16_t *output,
+                                                int stride, int flipud,
+                                                int height, const int bd) {
+  int j = flipud ? (height - 1) : 0;
+  __m128i temp;
+  const int step = flipud ? -1 : 1;
+  for (int i = 0; i < height; ++i, j += step) {
+    temp = _mm_loadu_si128((__m128i const *)(output + i * stride));
+    __m256i v = _mm256_cvtepi16_epi32(temp);
+    __m256i u = highbd_get_recon_8x8_avx2(v, in[j], bd);
+    __m128i u1 = _mm256_castsi256_si128(u);
+    _mm_storeu_si128((__m128i *)(output + i * stride), u1);
+  }
+}
 static void neg_shift_avx2(const __m256i in0, const __m256i in1, __m256i *out0,
                            __m256i *out1, const __m256i *clamp_lo,
                            const __m256i *clamp_hi, int shift) {
@@ -2481,7 +2504,422 @@
     }
   }
 }
+static void idct8x8_low1_avx2(__m256i *in, __m256i *out, int bit, int do_cols,
+                              int bd, int out_shift) {
+  const int32_t *cospi = cospi_arr(bit);
+  const __m256i cospi32 = _mm256_set1_epi32(cospi[32]);
+  const __m256i rnding = _mm256_set1_epi32(1 << (bit - 1));
+  const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8));
+  __m256i x;
 
+  // stage 0
+  // stage 1
+  // stage 2
+  // stage 3
+  x = _mm256_mullo_epi32(in[0], cospi32);
+  x = _mm256_add_epi32(x, rnding);
+  x = _mm256_srai_epi32(x, bit);
+
+  // stage 4
+  // stage 5
+  if (!do_cols) {
+    const int log_range_out = AOMMAX(16, bd + 6);
+    const __m256i clamp_lo_out = _mm256_set1_epi32(AOMMAX(
+        -(1 << (log_range_out - 1)), -(1 << (log_range - 1 - out_shift))));
+    const __m256i clamp_hi_out = _mm256_set1_epi32(AOMMIN(
+        (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift))));
+
+    __m256i offset = _mm256_set1_epi32((1 << out_shift) >> 1);
+    x = _mm256_add_epi32(x, offset);
+    x = _mm256_sra_epi32(x, _mm_cvtsi32_si128(out_shift));
+    x = _mm256_max_epi32(x, clamp_lo_out);
+    x = _mm256_min_epi32(x, clamp_hi_out);
+  }
+
+  out[0] = x;
+  out[1] = x;
+  out[2] = x;
+  out[3] = x;
+  out[4] = x;
+  out[5] = x;
+  out[6] = x;
+  out[7] = x;
+}
+static void idct8x8_avx2(__m256i *in, __m256i *out, int bit, int do_cols,
+                         int bd, int out_shift) {
+  const int32_t *cospi = cospi_arr(bit);
+  const __m256i cospi56 = _mm256_set1_epi32(cospi[56]);
+  const __m256i cospim8 = _mm256_set1_epi32(-cospi[8]);
+  const __m256i cospi24 = _mm256_set1_epi32(cospi[24]);
+  const __m256i cospim40 = _mm256_set1_epi32(-cospi[40]);
+  const __m256i cospi40 = _mm256_set1_epi32(cospi[40]);
+  const __m256i cospi8 = _mm256_set1_epi32(cospi[8]);
+  const __m256i cospi32 = _mm256_set1_epi32(cospi[32]);
+  const __m256i cospi48 = _mm256_set1_epi32(cospi[48]);
+  const __m256i cospim16 = _mm256_set1_epi32(-cospi[16]);
+  const __m256i cospi16 = _mm256_set1_epi32(cospi[16]);
+  const __m256i rnding = _mm256_set1_epi32(1 << (bit - 1));
+  const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8));
+  const __m256i clamp_lo = _mm256_set1_epi32(-(1 << (log_range - 1)));
+  const __m256i clamp_hi = _mm256_set1_epi32((1 << (log_range - 1)) - 1);
+  __m256i u0, u1, u2, u3, u4, u5, u6, u7;
+  __m256i v0, v1, v2, v3, v4, v5, v6, v7;
+  __m256i x, y;
+
+  // stage 0
+  // stage 1
+  // stage 2
+  u0 = in[0];
+  u1 = in[4];
+  u2 = in[2];
+  u3 = in[6];
+
+  x = _mm256_mullo_epi32(in[1], cospi56);
+  y = _mm256_mullo_epi32(in[7], cospim8);
+  u4 = _mm256_add_epi32(x, y);
+  u4 = _mm256_add_epi32(u4, rnding);
+  u4 = _mm256_srai_epi32(u4, bit);
+
+  x = _mm256_mullo_epi32(in[1], cospi8);
+  y = _mm256_mullo_epi32(in[7], cospi56);
+  u7 = _mm256_add_epi32(x, y);
+  u7 = _mm256_add_epi32(u7, rnding);
+  u7 = _mm256_srai_epi32(u7, bit);
+
+  x = _mm256_mullo_epi32(in[5], cospi24);
+  y = _mm256_mullo_epi32(in[3], cospim40);
+  u5 = _mm256_add_epi32(x, y);
+  u5 = _mm256_add_epi32(u5, rnding);
+  u5 = _mm256_srai_epi32(u5, bit);
+
+  x = _mm256_mullo_epi32(in[5], cospi40);
+  y = _mm256_mullo_epi32(in[3], cospi24);
+  u6 = _mm256_add_epi32(x, y);
+  u6 = _mm256_add_epi32(u6, rnding);
+  u6 = _mm256_srai_epi32(u6, bit);
+
+  // stage 3
+  x = _mm256_mullo_epi32(u0, cospi32);
+  y = _mm256_mullo_epi32(u1, cospi32);
+  v0 = _mm256_add_epi32(x, y);
+  v0 = _mm256_add_epi32(v0, rnding);
+  v0 = _mm256_srai_epi32(v0, bit);
+
+  v1 = _mm256_sub_epi32(x, y);
+  v1 = _mm256_add_epi32(v1, rnding);
+  v1 = _mm256_srai_epi32(v1, bit);
+
+  x = _mm256_mullo_epi32(u2, cospi48);
+  y = _mm256_mullo_epi32(u3, cospim16);
+  v2 = _mm256_add_epi32(x, y);
+  v2 = _mm256_add_epi32(v2, rnding);
+  v2 = _mm256_srai_epi32(v2, bit);
+
+  x = _mm256_mullo_epi32(u2, cospi16);
+  y = _mm256_mullo_epi32(u3, cospi48);
+  v3 = _mm256_add_epi32(x, y);
+  v3 = _mm256_add_epi32(v3, rnding);
+  v3 = _mm256_srai_epi32(v3, bit);
+
+  addsub_avx2(u4, u5, &v4, &v5, &clamp_lo, &clamp_hi);
+  addsub_avx2(u7, u6, &v7, &v6, &clamp_lo, &clamp_hi);
+
+  // stage 4
+  addsub_avx2(v0, v3, &u0, &u3, &clamp_lo, &clamp_hi);
+  addsub_avx2(v1, v2, &u1, &u2, &clamp_lo, &clamp_hi);
+  u4 = v4;
+  u7 = v7;
+
+  x = _mm256_mullo_epi32(v5, cospi32);
+  y = _mm256_mullo_epi32(v6, cospi32);
+  u6 = _mm256_add_epi32(y, x);
+  u6 = _mm256_add_epi32(u6, rnding);
+  u6 = _mm256_srai_epi32(u6, bit);
+
+  u5 = _mm256_sub_epi32(y, x);
+  u5 = _mm256_add_epi32(u5, rnding);
+  u5 = _mm256_srai_epi32(u5, bit);
+
+  // stage 5
+  if (do_cols) {
+    addsub_no_clamp_avx2(u0, u7, out + 0, out + 7);
+    addsub_no_clamp_avx2(u1, u6, out + 1, out + 6);
+    addsub_no_clamp_avx2(u2, u5, out + 2, out + 5);
+    addsub_no_clamp_avx2(u3, u4, out + 3, out + 4);
+  } else {
+    const int log_range_out = AOMMAX(16, bd + 6);
+    const __m256i clamp_lo_out = _mm256_set1_epi32(AOMMAX(
+        -(1 << (log_range_out - 1)), -(1 << (log_range - 1 - out_shift))));
+    const __m256i clamp_hi_out = _mm256_set1_epi32(AOMMIN(
+        (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift))));
+    addsub_shift_avx2(u0, u7, out + 0, out + 7, &clamp_lo_out, &clamp_hi_out,
+                      out_shift);
+    addsub_shift_avx2(u1, u6, out + 1, out + 6, &clamp_lo_out, &clamp_hi_out,
+                      out_shift);
+    addsub_shift_avx2(u2, u5, out + 2, out + 5, &clamp_lo_out, &clamp_hi_out,
+                      out_shift);
+    addsub_shift_avx2(u3, u4, out + 3, out + 4, &clamp_lo_out, &clamp_hi_out,
+                      out_shift);
+  }
+}
+static void iadst8x8_low1_avx2(__m256i *in, __m256i *out, int bit, int do_cols,
+                               int bd, int out_shift) {
+  const int32_t *cospi = cospi_arr(bit);
+  const __m256i cospi4 = _mm256_set1_epi32(cospi[4]);
+  const __m256i cospi60 = _mm256_set1_epi32(cospi[60]);
+  const __m256i cospi16 = _mm256_set1_epi32(cospi[16]);
+  const __m256i cospi48 = _mm256_set1_epi32(cospi[48]);
+  const __m256i cospi32 = _mm256_set1_epi32(cospi[32]);
+  const __m256i rnding = _mm256_set1_epi32(1 << (bit - 1));
+  const __m256i kZero = _mm256_setzero_si256();
+  __m256i u[8], x;
+
+  // stage 0
+  // stage 1
+  // stage 2
+
+  x = _mm256_mullo_epi32(in[0], cospi60);
+  u[0] = _mm256_add_epi32(x, rnding);
+  u[0] = _mm256_srai_epi32(u[0], bit);
+
+  x = _mm256_mullo_epi32(in[0], cospi4);
+  u[1] = _mm256_sub_epi32(kZero, x);
+  u[1] = _mm256_add_epi32(u[1], rnding);
+  u[1] = _mm256_srai_epi32(u[1], bit);
+
+  // stage 3
+  // stage 4
+  __m256i temp1, temp2;
+  temp1 = _mm256_mullo_epi32(u[0], cospi16);
+  x = _mm256_mullo_epi32(u[1], cospi48);
+  temp1 = _mm256_add_epi32(temp1, x);
+  temp1 = _mm256_add_epi32(temp1, rnding);
+  temp1 = _mm256_srai_epi32(temp1, bit);
+  u[4] = temp1;
+
+  temp2 = _mm256_mullo_epi32(u[0], cospi48);
+  x = _mm256_mullo_epi32(u[1], cospi16);
+  u[5] = _mm256_sub_epi32(temp2, x);
+  u[5] = _mm256_add_epi32(u[5], rnding);
+  u[5] = _mm256_srai_epi32(u[5], bit);
+
+  // stage 5
+  // stage 6
+  temp1 = _mm256_mullo_epi32(u[0], cospi32);
+  x = _mm256_mullo_epi32(u[1], cospi32);
+  u[2] = _mm256_add_epi32(temp1, x);
+  u[2] = _mm256_add_epi32(u[2], rnding);
+  u[2] = _mm256_srai_epi32(u[2], bit);
+
+  u[3] = _mm256_sub_epi32(temp1, x);
+  u[3] = _mm256_add_epi32(u[3], rnding);
+  u[3] = _mm256_srai_epi32(u[3], bit);
+
+  temp1 = _mm256_mullo_epi32(u[4], cospi32);
+  x = _mm256_mullo_epi32(u[5], cospi32);
+  u[6] = _mm256_add_epi32(temp1, x);
+  u[6] = _mm256_add_epi32(u[6], rnding);
+  u[6] = _mm256_srai_epi32(u[6], bit);
+
+  u[7] = _mm256_sub_epi32(temp1, x);
+  u[7] = _mm256_add_epi32(u[7], rnding);
+  u[7] = _mm256_srai_epi32(u[7], bit);
+
+  // stage 7
+  if (do_cols) {
+    out[0] = u[0];
+    out[1] = _mm256_sub_epi32(kZero, u[4]);
+    out[2] = u[6];
+    out[3] = _mm256_sub_epi32(kZero, u[2]);
+    out[4] = u[3];
+    out[5] = _mm256_sub_epi32(kZero, u[7]);
+    out[6] = u[5];
+    out[7] = _mm256_sub_epi32(kZero, u[1]);
+  } else {
+    const int log_range_out = AOMMAX(16, bd + 6);
+    const __m256i clamp_lo_out = _mm256_set1_epi32(-(1 << (log_range_out - 1)));
+    const __m256i clamp_hi_out =
+        _mm256_set1_epi32((1 << (log_range_out - 1)) - 1);
+
+    neg_shift_avx2(u[0], u[4], out + 0, out + 1, &clamp_lo_out, &clamp_hi_out,
+                   out_shift);
+    neg_shift_avx2(u[6], u[2], out + 2, out + 3, &clamp_lo_out, &clamp_hi_out,
+                   out_shift);
+    neg_shift_avx2(u[3], u[7], out + 4, out + 5, &clamp_lo_out, &clamp_hi_out,
+                   out_shift);
+    neg_shift_avx2(u[5], u[1], out + 6, out + 7, &clamp_lo_out, &clamp_hi_out,
+                   out_shift);
+  }
+}
+
+static void iadst8x8_avx2(__m256i *in, __m256i *out, int bit, int do_cols,
+                          int bd, int out_shift) {
+  const int32_t *cospi = cospi_arr(bit);
+  const __m256i cospi4 = _mm256_set1_epi32(cospi[4]);
+  const __m256i cospi60 = _mm256_set1_epi32(cospi[60]);
+  const __m256i cospi20 = _mm256_set1_epi32(cospi[20]);
+  const __m256i cospi44 = _mm256_set1_epi32(cospi[44]);
+  const __m256i cospi36 = _mm256_set1_epi32(cospi[36]);
+  const __m256i cospi28 = _mm256_set1_epi32(cospi[28]);
+  const __m256i cospi52 = _mm256_set1_epi32(cospi[52]);
+  const __m256i cospi12 = _mm256_set1_epi32(cospi[12]);
+  const __m256i cospi16 = _mm256_set1_epi32(cospi[16]);
+  const __m256i cospi48 = _mm256_set1_epi32(cospi[48]);
+  const __m256i cospim48 = _mm256_set1_epi32(-cospi[48]);
+  const __m256i cospi32 = _mm256_set1_epi32(cospi[32]);
+  const __m256i rnding = _mm256_set1_epi32(1 << (bit - 1));
+  const __m256i kZero = _mm256_setzero_si256();
+  const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8));
+  const __m256i clamp_lo = _mm256_set1_epi32(-(1 << (log_range - 1)));
+  const __m256i clamp_hi = _mm256_set1_epi32((1 << (log_range - 1)) - 1);
+  __m256i u[8], v[8], x;
+
+  // stage 0
+  // stage 1
+  // stage 2
+
+  u[0] = _mm256_mullo_epi32(in[7], cospi4);
+  x = _mm256_mullo_epi32(in[0], cospi60);
+  u[0] = _mm256_add_epi32(u[0], x);
+  u[0] = _mm256_add_epi32(u[0], rnding);
+  u[0] = _mm256_srai_epi32(u[0], bit);
+
+  u[1] = _mm256_mullo_epi32(in[7], cospi60);
+  x = _mm256_mullo_epi32(in[0], cospi4);
+  u[1] = _mm256_sub_epi32(u[1], x);
+  u[1] = _mm256_add_epi32(u[1], rnding);
+  u[1] = _mm256_srai_epi32(u[1], bit);
+
+  u[2] = _mm256_mullo_epi32(in[5], cospi20);
+  x = _mm256_mullo_epi32(in[2], cospi44);
+  u[2] = _mm256_add_epi32(u[2], x);
+  u[2] = _mm256_add_epi32(u[2], rnding);
+  u[2] = _mm256_srai_epi32(u[2], bit);
+
+  u[3] = _mm256_mullo_epi32(in[5], cospi44);
+  x = _mm256_mullo_epi32(in[2], cospi20);
+  u[3] = _mm256_sub_epi32(u[3], x);
+  u[3] = _mm256_add_epi32(u[3], rnding);
+  u[3] = _mm256_srai_epi32(u[3], bit);
+
+  u[4] = _mm256_mullo_epi32(in[3], cospi36);
+  x = _mm256_mullo_epi32(in[4], cospi28);
+  u[4] = _mm256_add_epi32(u[4], x);
+  u[4] = _mm256_add_epi32(u[4], rnding);
+  u[4] = _mm256_srai_epi32(u[4], bit);
+
+  u[5] = _mm256_mullo_epi32(in[3], cospi28);
+  x = _mm256_mullo_epi32(in[4], cospi36);
+  u[5] = _mm256_sub_epi32(u[5], x);
+  u[5] = _mm256_add_epi32(u[5], rnding);
+  u[5] = _mm256_srai_epi32(u[5], bit);
+
+  u[6] = _mm256_mullo_epi32(in[1], cospi52);
+  x = _mm256_mullo_epi32(in[6], cospi12);
+  u[6] = _mm256_add_epi32(u[6], x);
+  u[6] = _mm256_add_epi32(u[6], rnding);
+  u[6] = _mm256_srai_epi32(u[6], bit);
+
+  u[7] = _mm256_mullo_epi32(in[1], cospi12);
+  x = _mm256_mullo_epi32(in[6], cospi52);
+  u[7] = _mm256_sub_epi32(u[7], x);
+  u[7] = _mm256_add_epi32(u[7], rnding);
+  u[7] = _mm256_srai_epi32(u[7], bit);
+
+  // stage 3
+  addsub_avx2(u[0], u[4], &v[0], &v[4], &clamp_lo, &clamp_hi);
+  addsub_avx2(u[1], u[5], &v[1], &v[5], &clamp_lo, &clamp_hi);
+  addsub_avx2(u[2], u[6], &v[2], &v[6], &clamp_lo, &clamp_hi);
+  addsub_avx2(u[3], u[7], &v[3], &v[7], &clamp_lo, &clamp_hi);
+
+  // stage 4
+  u[0] = v[0];
+  u[1] = v[1];
+  u[2] = v[2];
+  u[3] = v[3];
+
+  u[4] = _mm256_mullo_epi32(v[4], cospi16);
+  x = _mm256_mullo_epi32(v[5], cospi48);
+  u[4] = _mm256_add_epi32(u[4], x);
+  u[4] = _mm256_add_epi32(u[4], rnding);
+  u[4] = _mm256_srai_epi32(u[4], bit);
+
+  u[5] = _mm256_mullo_epi32(v[4], cospi48);
+  x = _mm256_mullo_epi32(v[5], cospi16);
+  u[5] = _mm256_sub_epi32(u[5], x);
+  u[5] = _mm256_add_epi32(u[5], rnding);
+  u[5] = _mm256_srai_epi32(u[5], bit);
+
+  u[6] = _mm256_mullo_epi32(v[6], cospim48);
+  x = _mm256_mullo_epi32(v[7], cospi16);
+  u[6] = _mm256_add_epi32(u[6], x);
+  u[6] = _mm256_add_epi32(u[6], rnding);
+  u[6] = _mm256_srai_epi32(u[6], bit);
+
+  u[7] = _mm256_mullo_epi32(v[6], cospi16);
+  x = _mm256_mullo_epi32(v[7], cospim48);
+  u[7] = _mm256_sub_epi32(u[7], x);
+  u[7] = _mm256_add_epi32(u[7], rnding);
+  u[7] = _mm256_srai_epi32(u[7], bit);
+
+  // stage 5
+  addsub_avx2(u[0], u[2], &v[0], &v[2], &clamp_lo, &clamp_hi);
+  addsub_avx2(u[1], u[3], &v[1], &v[3], &clamp_lo, &clamp_hi);
+  addsub_avx2(u[4], u[6], &v[4], &v[6], &clamp_lo, &clamp_hi);
+  addsub_avx2(u[5], u[7], &v[5], &v[7], &clamp_lo, &clamp_hi);
+
+  // stage 6
+  u[0] = v[0];
+  u[1] = v[1];
+  u[4] = v[4];
+  u[5] = v[5];
+
+  v[0] = _mm256_mullo_epi32(v[2], cospi32);
+  x = _mm256_mullo_epi32(v[3], cospi32);
+  u[2] = _mm256_add_epi32(v[0], x);
+  u[2] = _mm256_add_epi32(u[2], rnding);
+  u[2] = _mm256_srai_epi32(u[2], bit);
+
+  u[3] = _mm256_sub_epi32(v[0], x);
+  u[3] = _mm256_add_epi32(u[3], rnding);
+  u[3] = _mm256_srai_epi32(u[3], bit);
+
+  v[0] = _mm256_mullo_epi32(v[6], cospi32);
+  x = _mm256_mullo_epi32(v[7], cospi32);
+  u[6] = _mm256_add_epi32(v[0], x);
+  u[6] = _mm256_add_epi32(u[6], rnding);
+  u[6] = _mm256_srai_epi32(u[6], bit);
+
+  u[7] = _mm256_sub_epi32(v[0], x);
+  u[7] = _mm256_add_epi32(u[7], rnding);
+  u[7] = _mm256_srai_epi32(u[7], bit);
+
+  // stage 7
+  if (do_cols) {
+    out[0] = u[0];
+    out[1] = _mm256_sub_epi32(kZero, u[4]);
+    out[2] = u[6];
+    out[3] = _mm256_sub_epi32(kZero, u[2]);
+    out[4] = u[3];
+    out[5] = _mm256_sub_epi32(kZero, u[7]);
+    out[6] = u[5];
+    out[7] = _mm256_sub_epi32(kZero, u[1]);
+  } else {
+    const int log_range_out = AOMMAX(16, bd + 6);
+    const __m256i clamp_lo_out = _mm256_set1_epi32(-(1 << (log_range_out - 1)));
+    const __m256i clamp_hi_out =
+        _mm256_set1_epi32((1 << (log_range_out - 1)) - 1);
+
+    neg_shift_avx2(u[0], u[4], out + 0, out + 1, &clamp_lo_out, &clamp_hi_out,
+                   out_shift);
+    neg_shift_avx2(u[6], u[2], out + 2, out + 3, &clamp_lo_out, &clamp_hi_out,
+                   out_shift);
+    neg_shift_avx2(u[3], u[7], out + 4, out + 5, &clamp_lo_out, &clamp_hi_out,
+                   out_shift);
+    neg_shift_avx2(u[5], u[1], out + 6, out + 7, &clamp_lo_out, &clamp_hi_out,
+                   out_shift);
+  }
+}
 typedef void (*transform_1d_avx2)(__m256i *in, __m256i *out, int bit,
                                   int do_cols, int bd, int out_shift);
 
@@ -2493,8 +2931,8 @@
           { NULL, NULL, NULL, NULL },
       },
       {
-          { NULL, NULL, NULL, NULL },
-          { NULL, NULL, NULL, NULL },
+          { idct8x8_low1_avx2, idct8x8_avx2, NULL, NULL },
+          { iadst8x8_low1_avx2, iadst8x8_avx2, NULL, NULL },
           { NULL, NULL, NULL, NULL },
       },
       {
@@ -2580,12 +3018,15 @@
   }
 
   // write to buffer
-  {
+  if (txfm_size_col >= 16) {
     for (int i = 0; i < (txfm_size_col >> 4); i++) {
       highbd_write_buffer_16xn_avx2(buf1 + i * txfm_size_row * 2,
                                     output + 16 * i, stride, ud_flip,
                                     txfm_size_row, bd);
     }
+  } else if (txfm_size_col == 8) {
+    highbd_write_buffer_8xn_avx2(buf1, output, stride, ud_flip, txfm_size_row,
+                                 bd);
   }
 }
 
@@ -2698,7 +3139,123 @@
     default: assert(0);
   }
 }
+void av1_highbd_inv_txfm_add_8x8_avx2(const tran_low_t *input, uint8_t *dest,
+                                      int stride, const TxfmParam *txfm_param) {
+  int bd = txfm_param->bd;
+  const TX_TYPE tx_type = txfm_param->tx_type;
+  const int32_t *src = cast_to_int32(input);
+  switch (tx_type) {
+      // Assembly version doesn't support some transform types, so use C version
+      // for those.
+    case V_DCT:
+    case H_DCT:
+    case V_ADST:
+    case H_ADST:
+    case V_FLIPADST:
+    case H_FLIPADST:
+    case IDTX:
+      av1_inv_txfm2d_add_8x8_c(src, CONVERT_TO_SHORTPTR(dest), stride, tx_type,
+                               bd);
+      break;
+    default:
+      av1_highbd_inv_txfm2d_add_universe_avx2(input, dest, stride, tx_type,
+                                              txfm_param->tx_size,
+                                              txfm_param->eob, bd);
+      break;
+  }
+}
+void av1_highbd_inv_txfm_add_8x32_avx2(const tran_low_t *input, uint8_t *dest,
+                                       int stride,
+                                       const TxfmParam *txfm_param) {
+  int bd = txfm_param->bd;
+  const TX_TYPE tx_type = txfm_param->tx_type;
+  const int32_t *src = cast_to_int32(input);
+  switch (tx_type) {
+    case DCT_DCT:
+      av1_highbd_inv_txfm2d_add_universe_avx2(input, dest, stride, tx_type,
+                                              txfm_param->tx_size,
+                                              txfm_param->eob, bd);
+      break;
+    case IDTX:
+      av1_inv_txfm2d_add_8x32_c(src, CONVERT_TO_SHORTPTR(dest), stride,
+                                txfm_param->tx_type, txfm_param->bd);
+      break;
+    default: assert(0);
+  }
+}
 
+void av1_highbd_inv_txfm_add_32x8_avx2(const tran_low_t *input, uint8_t *dest,
+                                       int stride,
+                                       const TxfmParam *txfm_param) {
+  int bd = txfm_param->bd;
+  const TX_TYPE tx_type = txfm_param->tx_type;
+  const int32_t *src = cast_to_int32(input);
+  switch (tx_type) {
+    case DCT_DCT:
+      av1_highbd_inv_txfm2d_add_universe_avx2(input, dest, stride, tx_type,
+                                              txfm_param->tx_size,
+                                              txfm_param->eob, bd);
+      break;
+    case IDTX:
+      av1_inv_txfm2d_add_32x8_c(src, CONVERT_TO_SHORTPTR(dest), stride,
+                                txfm_param->tx_type, txfm_param->bd);
+      break;
+    default: assert(0);
+  }
+}
+void av1_highbd_inv_txfm_add_16x8_avx2(const tran_low_t *input, uint8_t *dest,
+                                       int stride,
+                                       const TxfmParam *txfm_param) {
+  int bd = txfm_param->bd;
+  const TX_TYPE tx_type = txfm_param->tx_type;
+  const int32_t *src = cast_to_int32(input);
+  switch (tx_type) {
+      // Assembly version doesn't support some transform types, so use C version
+      // for those.
+    case V_DCT:
+    case H_DCT:
+    case V_ADST:
+    case H_ADST:
+    case V_FLIPADST:
+    case H_FLIPADST:
+    case IDTX:
+      av1_inv_txfm2d_add_16x8_c(src, CONVERT_TO_SHORTPTR(dest), stride,
+                                txfm_param->tx_type, txfm_param->bd);
+      break;
+    default:
+      av1_highbd_inv_txfm2d_add_universe_avx2(input, dest, stride, tx_type,
+                                              txfm_param->tx_size,
+                                              txfm_param->eob, bd);
+      break;
+  }
+}
+
+void av1_highbd_inv_txfm_add_8x16_avx2(const tran_low_t *input, uint8_t *dest,
+                                       int stride,
+                                       const TxfmParam *txfm_param) {
+  int bd = txfm_param->bd;
+  const TX_TYPE tx_type = txfm_param->tx_type;
+  const int32_t *src = cast_to_int32(input);
+  switch (tx_type) {
+      // Assembly version doesn't support some transform types, so use C version
+      // for those.
+    case V_DCT:
+    case H_DCT:
+    case V_ADST:
+    case H_ADST:
+    case V_FLIPADST:
+    case H_FLIPADST:
+    case IDTX:
+      av1_inv_txfm2d_add_8x16_c(src, CONVERT_TO_SHORTPTR(dest), stride,
+                                txfm_param->tx_type, txfm_param->bd);
+      break;
+    default:
+      av1_highbd_inv_txfm2d_add_universe_avx2(input, dest, stride, tx_type,
+                                              txfm_param->tx_size,
+                                              txfm_param->eob, bd);
+      break;
+  }
+}
 void av1_highbd_inv_txfm_add_avx2(const tran_low_t *input, uint8_t *dest,
                                   int stride, const TxfmParam *txfm_param) {
   assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
@@ -2711,7 +3268,7 @@
       av1_highbd_inv_txfm_add_16x16_avx2(input, dest, stride, txfm_param);
       break;
     case TX_8X8:
-      av1_highbd_inv_txfm_add_8x8_sse4_1(input, dest, stride, txfm_param);
+      av1_highbd_inv_txfm_add_8x8_avx2(input, dest, stride, txfm_param);
       break;
     case TX_4X8:
       av1_highbd_inv_txfm_add_4x8_sse4_1(input, dest, stride, txfm_param);
@@ -2720,10 +3277,10 @@
       av1_highbd_inv_txfm_add_8x4_sse4_1(input, dest, stride, txfm_param);
       break;
     case TX_8X16:
-      av1_highbd_inv_txfm_add_8x16_sse4_1(input, dest, stride, txfm_param);
+      av1_highbd_inv_txfm_add_8x16_avx2(input, dest, stride, txfm_param);
       break;
     case TX_16X8:
-      av1_highbd_inv_txfm_add_16x8_sse4_1(input, dest, stride, txfm_param);
+      av1_highbd_inv_txfm_add_16x8_avx2(input, dest, stride, txfm_param);
       break;
     case TX_16X32:
       av1_highbd_inv_txfm_add_16x32_avx2(input, dest, stride, txfm_param);
@@ -2741,10 +3298,10 @@
       av1_highbd_inv_txfm_add_4x16_sse4_1(input, dest, stride, txfm_param);
       break;
     case TX_8X32:
-      av1_highbd_inv_txfm_add_8x32_sse4_1(input, dest, stride, txfm_param);
+      av1_highbd_inv_txfm_add_8x32_avx2(input, dest, stride, txfm_param);
       break;
     case TX_32X8:
-      av1_highbd_inv_txfm_add_32x8_sse4_1(input, dest, stride, txfm_param);
+      av1_highbd_inv_txfm_add_32x8_avx2(input, dest, stride, txfm_param);
       break;
     case TX_64X64:
     case TX_32X64: