Optimize highbd fwd_txfm module
Added AVX2 variant for 8x8 txfm blk_size.
When tested for multiple test cases observed 0.4%
average reduction in encoder time for speed = 1 preset.
Module level gains improved by a factor of ~1.5
on average w.r.t to SSE4_1 module.
Change-Id: Iee170dc60916c34a841679154f1dd585cf9d7190
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 163a0a7..4ab76a1 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -230,7 +230,7 @@
add_proto qw/void av1_fwd_txfm2d_4x4/, "const int16_t *input, int32_t *output, int stride, TX_TYPE tx_type, int bd";
specialize qw/av1_fwd_txfm2d_4x4 sse4_1/;
add_proto qw/void av1_fwd_txfm2d_8x8/, "const int16_t *input, int32_t *output, int stride, TX_TYPE tx_type, int bd";
- specialize qw/av1_fwd_txfm2d_8x8 sse4_1/;
+ specialize qw/av1_fwd_txfm2d_8x8 sse4_1 avx2/;
add_proto qw/void av1_fwd_txfm2d_16x16/, "const int16_t *input, int32_t *output, int stride, TX_TYPE tx_type, int bd";
specialize qw/av1_fwd_txfm2d_16x16 sse4_1 avx2/;
add_proto qw/void av1_fwd_txfm2d_32x32/, "const int16_t *input, int32_t *output, int stride, TX_TYPE tx_type, int bd";
diff --git a/av1/encoder/x86/highbd_fwd_txfm_avx2.c b/av1/encoder/x86/highbd_fwd_txfm_avx2.c
index 0cdbebd..019ab66 100644
--- a/av1/encoder/x86/highbd_fwd_txfm_avx2.c
+++ b/av1/encoder/x86/highbd_fwd_txfm_avx2.c
@@ -19,6 +19,80 @@
#include "aom_ports/mem.h"
#include "aom_dsp/x86/txfm_common_sse2.h"
+static INLINE void av1_load_buffer_8x8_avx2(const int16_t *input, __m256i *out,
+ int stride, int flipud, int fliplr,
+ int shift) {
+ __m128i out1[8];
+ if (!flipud) {
+ out1[0] = _mm_load_si128((const __m128i *)(input + 0 * stride));
+ out1[1] = _mm_load_si128((const __m128i *)(input + 1 * stride));
+ out1[2] = _mm_load_si128((const __m128i *)(input + 2 * stride));
+ out1[3] = _mm_load_si128((const __m128i *)(input + 3 * stride));
+ out1[4] = _mm_load_si128((const __m128i *)(input + 4 * stride));
+ out1[5] = _mm_load_si128((const __m128i *)(input + 5 * stride));
+ out1[6] = _mm_load_si128((const __m128i *)(input + 6 * stride));
+ out1[7] = _mm_load_si128((const __m128i *)(input + 7 * stride));
+
+ } else {
+ out1[7] = _mm_load_si128((const __m128i *)(input + 0 * stride));
+ out1[6] = _mm_load_si128((const __m128i *)(input + 1 * stride));
+ out1[5] = _mm_load_si128((const __m128i *)(input + 2 * stride));
+ out1[4] = _mm_load_si128((const __m128i *)(input + 3 * stride));
+ out1[3] = _mm_load_si128((const __m128i *)(input + 4 * stride));
+ out1[2] = _mm_load_si128((const __m128i *)(input + 5 * stride));
+ out1[1] = _mm_load_si128((const __m128i *)(input + 6 * stride));
+ out1[0] = _mm_load_si128((const __m128i *)(input + 7 * stride));
+ }
+ if (!fliplr) {
+ out[0] = _mm256_cvtepi16_epi32(out1[0]);
+ out[1] = _mm256_cvtepi16_epi32(out1[1]);
+ out[2] = _mm256_cvtepi16_epi32(out1[2]);
+ out[3] = _mm256_cvtepi16_epi32(out1[3]);
+ out[4] = _mm256_cvtepi16_epi32(out1[4]);
+ out[5] = _mm256_cvtepi16_epi32(out1[5]);
+ out[6] = _mm256_cvtepi16_epi32(out1[6]);
+ out[7] = _mm256_cvtepi16_epi32(out1[7]);
+
+ } else {
+ out[0] = _mm256_cvtepi16_epi32(mm_reverse_epi16(out1[0]));
+ out[1] = _mm256_cvtepi16_epi32(mm_reverse_epi16(out1[1]));
+ out[2] = _mm256_cvtepi16_epi32(mm_reverse_epi16(out1[2]));
+ out[3] = _mm256_cvtepi16_epi32(mm_reverse_epi16(out1[3]));
+ out[4] = _mm256_cvtepi16_epi32(mm_reverse_epi16(out1[4]));
+ out[5] = _mm256_cvtepi16_epi32(mm_reverse_epi16(out1[5]));
+ out[6] = _mm256_cvtepi16_epi32(mm_reverse_epi16(out1[6]));
+ out[7] = _mm256_cvtepi16_epi32(mm_reverse_epi16(out1[7]));
+ }
+ out[0] = _mm256_slli_epi32(out[0], shift);
+ out[1] = _mm256_slli_epi32(out[1], shift);
+ out[2] = _mm256_slli_epi32(out[2], shift);
+ out[3] = _mm256_slli_epi32(out[3], shift);
+ out[4] = _mm256_slli_epi32(out[4], shift);
+ out[5] = _mm256_slli_epi32(out[5], shift);
+ out[6] = _mm256_slli_epi32(out[6], shift);
+ out[7] = _mm256_slli_epi32(out[7], shift);
+}
+static INLINE void col_txfm_8x8_rounding(__m256i *in, int shift) {
+ const __m256i rounding = _mm256_set1_epi32(1 << (shift - 1));
+
+ in[0] = _mm256_add_epi32(in[0], rounding);
+ in[1] = _mm256_add_epi32(in[1], rounding);
+ in[2] = _mm256_add_epi32(in[2], rounding);
+ in[3] = _mm256_add_epi32(in[3], rounding);
+ in[4] = _mm256_add_epi32(in[4], rounding);
+ in[5] = _mm256_add_epi32(in[5], rounding);
+ in[6] = _mm256_add_epi32(in[6], rounding);
+ in[7] = _mm256_add_epi32(in[7], rounding);
+
+ in[0] = _mm256_srai_epi32(in[0], shift);
+ in[1] = _mm256_srai_epi32(in[1], shift);
+ in[2] = _mm256_srai_epi32(in[2], shift);
+ in[3] = _mm256_srai_epi32(in[3], shift);
+ in[4] = _mm256_srai_epi32(in[4], shift);
+ in[5] = _mm256_srai_epi32(in[5], shift);
+ in[6] = _mm256_srai_epi32(in[6], shift);
+ in[7] = _mm256_srai_epi32(in[7], shift);
+}
static INLINE void av1_load_buffer_16xn_avx2(const int16_t *input, __m256i *out,
int stride, int height,
int outstride, int flipud,
@@ -163,6 +237,457 @@
const int8_t cos_bit,
const int8_t *stage_range, int instride,
int outstride);
+static void av1_fdct8_avx2(__m256i *in, __m256i *out, int bit,
+ const int col_num) {
+ const int32_t *cospi = cospi_arr(bit);
+ const __m256i cospi32 = _mm256_set1_epi32(cospi[32]);
+ const __m256i cospim32 = _mm256_set1_epi32(-cospi[32]);
+ const __m256i cospi48 = _mm256_set1_epi32(cospi[48]);
+ const __m256i cospi16 = _mm256_set1_epi32(cospi[16]);
+ const __m256i cospi56 = _mm256_set1_epi32(cospi[56]);
+ const __m256i cospi8 = _mm256_set1_epi32(cospi[8]);
+ const __m256i cospi24 = _mm256_set1_epi32(cospi[24]);
+ const __m256i cospi40 = _mm256_set1_epi32(cospi[40]);
+ const __m256i rnding = _mm256_set1_epi32(1 << (bit - 1));
+ __m256i u[8], v[8];
+
+ int startidx = 0;
+ int endidx = 7;
+
+ // stage 0
+ // stage 1
+ u[0] = _mm256_add_epi32(in[startidx], in[endidx]);
+ v[7] = _mm256_sub_epi32(in[startidx], in[endidx]);
+ startidx += col_num;
+ endidx -= col_num;
+ u[1] = _mm256_add_epi32(in[startidx], in[endidx]);
+ u[6] = _mm256_sub_epi32(in[startidx], in[endidx]);
+ startidx += col_num;
+ endidx -= col_num;
+ u[2] = _mm256_add_epi32(in[startidx], in[endidx]);
+ u[5] = _mm256_sub_epi32(in[startidx], in[endidx]);
+ startidx += col_num;
+ endidx -= col_num;
+ u[3] = _mm256_add_epi32(in[startidx], in[endidx]);
+ v[4] = _mm256_sub_epi32(in[startidx], in[endidx]);
+
+ // stage 2
+ v[0] = _mm256_add_epi32(u[0], u[3]);
+ v[3] = _mm256_sub_epi32(u[0], u[3]);
+ v[1] = _mm256_add_epi32(u[1], u[2]);
+ v[2] = _mm256_sub_epi32(u[1], u[2]);
+
+ v[5] = _mm256_mullo_epi32(u[5], cospim32);
+ v[6] = _mm256_mullo_epi32(u[6], cospi32);
+ v[5] = _mm256_add_epi32(v[5], v[6]);
+ v[5] = _mm256_add_epi32(v[5], rnding);
+ v[5] = _mm256_srai_epi32(v[5], bit);
+
+ u[0] = _mm256_mullo_epi32(u[5], cospi32);
+ v[6] = _mm256_mullo_epi32(u[6], cospim32);
+ v[6] = _mm256_sub_epi32(u[0], v[6]);
+ v[6] = _mm256_add_epi32(v[6], rnding);
+ v[6] = _mm256_srai_epi32(v[6], bit);
+
+ // stage 3
+ // type 0
+ v[0] = _mm256_mullo_epi32(v[0], cospi32);
+ v[1] = _mm256_mullo_epi32(v[1], cospi32);
+ u[0] = _mm256_add_epi32(v[0], v[1]);
+ u[0] = _mm256_add_epi32(u[0], rnding);
+ u[0] = _mm256_srai_epi32(u[0], bit);
+
+ u[1] = _mm256_sub_epi32(v[0], v[1]);
+ u[1] = _mm256_add_epi32(u[1], rnding);
+ u[1] = _mm256_srai_epi32(u[1], bit);
+
+ // type 1
+ v[0] = _mm256_mullo_epi32(v[2], cospi48);
+ v[1] = _mm256_mullo_epi32(v[3], cospi16);
+ u[2] = _mm256_add_epi32(v[0], v[1]);
+ u[2] = _mm256_add_epi32(u[2], rnding);
+ u[2] = _mm256_srai_epi32(u[2], bit);
+
+ v[0] = _mm256_mullo_epi32(v[2], cospi16);
+ v[1] = _mm256_mullo_epi32(v[3], cospi48);
+ u[3] = _mm256_sub_epi32(v[1], v[0]);
+ u[3] = _mm256_add_epi32(u[3], rnding);
+ u[3] = _mm256_srai_epi32(u[3], bit);
+
+ u[4] = _mm256_add_epi32(v[4], v[5]);
+ u[5] = _mm256_sub_epi32(v[4], v[5]);
+ u[6] = _mm256_sub_epi32(v[7], v[6]);
+ u[7] = _mm256_add_epi32(v[7], v[6]);
+
+ // stage 4
+ // stage 5
+ v[0] = _mm256_mullo_epi32(u[4], cospi56);
+ v[1] = _mm256_mullo_epi32(u[7], cospi8);
+ v[0] = _mm256_add_epi32(v[0], v[1]);
+ v[0] = _mm256_add_epi32(v[0], rnding);
+ out[1] = _mm256_srai_epi32(v[0], bit);
+
+ v[0] = _mm256_mullo_epi32(u[4], cospi8);
+ v[1] = _mm256_mullo_epi32(u[7], cospi56);
+ v[0] = _mm256_sub_epi32(v[1], v[0]);
+ v[0] = _mm256_add_epi32(v[0], rnding);
+ out[7] = _mm256_srai_epi32(v[0], bit);
+
+ v[0] = _mm256_mullo_epi32(u[5], cospi24);
+ v[1] = _mm256_mullo_epi32(u[6], cospi40);
+ v[0] = _mm256_add_epi32(v[0], v[1]);
+ v[0] = _mm256_add_epi32(v[0], rnding);
+ out[5] = _mm256_srai_epi32(v[0], bit);
+
+ v[0] = _mm256_mullo_epi32(u[5], cospi40);
+ v[1] = _mm256_mullo_epi32(u[6], cospi24);
+ v[0] = _mm256_sub_epi32(v[1], v[0]);
+ v[0] = _mm256_add_epi32(v[0], rnding);
+ out[3] = _mm256_srai_epi32(v[0], bit);
+
+ out[0] = u[0];
+ out[4] = u[1];
+ out[2] = u[2];
+ out[6] = u[3];
+}
+static void av1_fadst8_avx2(__m256i *in, __m256i *out, int bit,
+ const int col_num) {
+ (void)col_num;
+ const int32_t *cospi = cospi_arr(bit);
+ const __m256i cospi32 = _mm256_set1_epi32(cospi[32]);
+ const __m256i cospi16 = _mm256_set1_epi32(cospi[16]);
+ const __m256i cospim16 = _mm256_set1_epi32(-cospi[16]);
+ const __m256i cospi48 = _mm256_set1_epi32(cospi[48]);
+ const __m256i cospim48 = _mm256_set1_epi32(-cospi[48]);
+ const __m256i cospi4 = _mm256_set1_epi32(cospi[4]);
+ const __m256i cospim4 = _mm256_set1_epi32(-cospi[4]);
+ const __m256i cospi60 = _mm256_set1_epi32(cospi[60]);
+ const __m256i cospi20 = _mm256_set1_epi32(cospi[20]);
+ const __m256i cospim20 = _mm256_set1_epi32(-cospi[20]);
+ const __m256i cospi44 = _mm256_set1_epi32(cospi[44]);
+ const __m256i cospi28 = _mm256_set1_epi32(cospi[28]);
+ const __m256i cospi36 = _mm256_set1_epi32(cospi[36]);
+ const __m256i cospim36 = _mm256_set1_epi32(-cospi[36]);
+ const __m256i cospi52 = _mm256_set1_epi32(cospi[52]);
+ const __m256i cospim52 = _mm256_set1_epi32(-cospi[52]);
+ const __m256i cospi12 = _mm256_set1_epi32(cospi[12]);
+ const __m256i rnding = _mm256_set1_epi32(1 << (bit - 1));
+ const __m256i zero = _mm256_setzero_si256();
+ __m256i u0, u1, u2, u3, u4, u5, u6, u7;
+ __m256i v0, v1, v2, v3, v4, v5, v6, v7;
+ __m256i x, y;
+
+ // stage 0
+ // stage 1
+ u0 = in[0];
+ u1 = _mm256_sub_epi32(zero, in[7]);
+ u2 = _mm256_sub_epi32(zero, in[3]);
+ u3 = in[4];
+ u4 = _mm256_sub_epi32(zero, in[1]);
+ u5 = in[6];
+ u6 = in[2];
+ u7 = _mm256_sub_epi32(zero, in[5]);
+
+ // stage 2
+ v0 = u0;
+ v1 = u1;
+
+ x = _mm256_mullo_epi32(u2, cospi32);
+ y = _mm256_mullo_epi32(u3, cospi32);
+ v2 = _mm256_add_epi32(x, y);
+ v2 = _mm256_add_epi32(v2, rnding);
+ v2 = _mm256_srai_epi32(v2, bit);
+
+ v3 = _mm256_sub_epi32(x, y);
+ v3 = _mm256_add_epi32(v3, rnding);
+ v3 = _mm256_srai_epi32(v3, bit);
+
+ v4 = u4;
+ v5 = u5;
+
+ x = _mm256_mullo_epi32(u6, cospi32);
+ y = _mm256_mullo_epi32(u7, cospi32);
+ v6 = _mm256_add_epi32(x, y);
+ v6 = _mm256_add_epi32(v6, rnding);
+ v6 = _mm256_srai_epi32(v6, bit);
+
+ v7 = _mm256_sub_epi32(x, y);
+ v7 = _mm256_add_epi32(v7, rnding);
+ v7 = _mm256_srai_epi32(v7, bit);
+
+ // stage 3
+ u0 = _mm256_add_epi32(v0, v2);
+ u1 = _mm256_add_epi32(v1, v3);
+ u2 = _mm256_sub_epi32(v0, v2);
+ u3 = _mm256_sub_epi32(v1, v3);
+ u4 = _mm256_add_epi32(v4, v6);
+ u5 = _mm256_add_epi32(v5, v7);
+ u6 = _mm256_sub_epi32(v4, v6);
+ u7 = _mm256_sub_epi32(v5, v7);
+
+ // stage 4
+ v0 = u0;
+ v1 = u1;
+ v2 = u2;
+ v3 = u3;
+
+ x = _mm256_mullo_epi32(u4, cospi16);
+ y = _mm256_mullo_epi32(u5, cospi48);
+ v4 = _mm256_add_epi32(x, y);
+ v4 = _mm256_add_epi32(v4, rnding);
+ v4 = _mm256_srai_epi32(v4, bit);
+
+ x = _mm256_mullo_epi32(u4, cospi48);
+ y = _mm256_mullo_epi32(u5, cospim16);
+ v5 = _mm256_add_epi32(x, y);
+ v5 = _mm256_add_epi32(v5, rnding);
+ v5 = _mm256_srai_epi32(v5, bit);
+
+ x = _mm256_mullo_epi32(u6, cospim48);
+ y = _mm256_mullo_epi32(u7, cospi16);
+ v6 = _mm256_add_epi32(x, y);
+ v6 = _mm256_add_epi32(v6, rnding);
+ v6 = _mm256_srai_epi32(v6, bit);
+
+ x = _mm256_mullo_epi32(u6, cospi16);
+ y = _mm256_mullo_epi32(u7, cospi48);
+ v7 = _mm256_add_epi32(x, y);
+ v7 = _mm256_add_epi32(v7, rnding);
+ v7 = _mm256_srai_epi32(v7, bit);
+
+ // stage 5
+ u0 = _mm256_add_epi32(v0, v4);
+ u1 = _mm256_add_epi32(v1, v5);
+ u2 = _mm256_add_epi32(v2, v6);
+ u3 = _mm256_add_epi32(v3, v7);
+ u4 = _mm256_sub_epi32(v0, v4);
+ u5 = _mm256_sub_epi32(v1, v5);
+ u6 = _mm256_sub_epi32(v2, v6);
+ u7 = _mm256_sub_epi32(v3, v7);
+
+ // stage 6
+ x = _mm256_mullo_epi32(u0, cospi4);
+ y = _mm256_mullo_epi32(u1, cospi60);
+ v0 = _mm256_add_epi32(x, y);
+ v0 = _mm256_add_epi32(v0, rnding);
+ v0 = _mm256_srai_epi32(v0, bit);
+
+ x = _mm256_mullo_epi32(u0, cospi60);
+ y = _mm256_mullo_epi32(u1, cospim4);
+ v1 = _mm256_add_epi32(x, y);
+ v1 = _mm256_add_epi32(v1, rnding);
+ v1 = _mm256_srai_epi32(v1, bit);
+
+ x = _mm256_mullo_epi32(u2, cospi20);
+ y = _mm256_mullo_epi32(u3, cospi44);
+ v2 = _mm256_add_epi32(x, y);
+ v2 = _mm256_add_epi32(v2, rnding);
+ v2 = _mm256_srai_epi32(v2, bit);
+
+ x = _mm256_mullo_epi32(u2, cospi44);
+ y = _mm256_mullo_epi32(u3, cospim20);
+ v3 = _mm256_add_epi32(x, y);
+ v3 = _mm256_add_epi32(v3, rnding);
+ v3 = _mm256_srai_epi32(v3, bit);
+
+ x = _mm256_mullo_epi32(u4, cospi36);
+ y = _mm256_mullo_epi32(u5, cospi28);
+ v4 = _mm256_add_epi32(x, y);
+ v4 = _mm256_add_epi32(v4, rnding);
+ v4 = _mm256_srai_epi32(v4, bit);
+
+ x = _mm256_mullo_epi32(u4, cospi28);
+ y = _mm256_mullo_epi32(u5, cospim36);
+ v5 = _mm256_add_epi32(x, y);
+ v5 = _mm256_add_epi32(v5, rnding);
+ v5 = _mm256_srai_epi32(v5, bit);
+
+ x = _mm256_mullo_epi32(u6, cospi52);
+ y = _mm256_mullo_epi32(u7, cospi12);
+ v6 = _mm256_add_epi32(x, y);
+ v6 = _mm256_add_epi32(v6, rnding);
+ v6 = _mm256_srai_epi32(v6, bit);
+
+ x = _mm256_mullo_epi32(u6, cospi12);
+ y = _mm256_mullo_epi32(u7, cospim52);
+ v7 = _mm256_add_epi32(x, y);
+ v7 = _mm256_add_epi32(v7, rnding);
+ v7 = _mm256_srai_epi32(v7, bit);
+
+ // stage 7
+ out[0] = v1;
+ out[1] = v6;
+ out[2] = v3;
+ out[3] = v4;
+ out[4] = v5;
+ out[5] = v2;
+ out[6] = v7;
+ out[7] = v0;
+}
+static void av1_idtx8_avx2(__m256i *in, __m256i *out, int bit, int col_num) {
+ (void)bit;
+ (void)col_num;
+ out[0] = _mm256_add_epi32(in[0], in[0]);
+ out[1] = _mm256_add_epi32(in[1], in[1]);
+ out[2] = _mm256_add_epi32(in[2], in[2]);
+ out[3] = _mm256_add_epi32(in[3], in[3]);
+ out[4] = _mm256_add_epi32(in[4], in[4]);
+ out[5] = _mm256_add_epi32(in[5], in[5]);
+ out[6] = _mm256_add_epi32(in[6], in[6]);
+ out[7] = _mm256_add_epi32(in[7], in[7]);
+}
+void av1_fwd_txfm2d_8x8_avx2(const int16_t *input, int32_t *coeff, int stride,
+ TX_TYPE tx_type, int bd) {
+ __m256i in[8], out[8];
+ const TX_SIZE tx_size = TX_8X8;
+ const int8_t *shift = fwd_txfm_shift_ls[tx_size];
+ const int txw_idx = get_txw_idx(tx_size);
+ const int txh_idx = get_txh_idx(tx_size);
+ const int width = tx_size_wide[tx_size];
+ const int height = tx_size_high[tx_size];
+ const int width_div8 = (width >> 3);
+
+ switch (tx_type) {
+ case DCT_DCT:
+ av1_load_buffer_8x8_avx2(input, in, stride, 0, 0, shift[0]);
+ av1_fdct8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+ col_txfm_8x8_rounding(out, -shift[1]);
+ av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+ av1_fdct8_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+ av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+ av1_store_buffer_avx2(in, coeff, 8, 8);
+ break;
+ case ADST_DCT:
+ av1_load_buffer_8x8_avx2(input, in, stride, 0, 0, shift[0]);
+ av1_fadst8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+ col_txfm_8x8_rounding(out, -shift[1]);
+ av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+ av1_fdct8_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+ av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+ av1_store_buffer_avx2(in, coeff, 8, 8);
+ break;
+ case DCT_ADST:
+ av1_load_buffer_8x8_avx2(input, in, stride, 0, 0, shift[0]);
+ av1_fdct8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+ col_txfm_8x8_rounding(out, -shift[1]);
+ av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+ av1_fadst8_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+ av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+ av1_store_buffer_avx2(in, coeff, 8, 8);
+ break;
+ case ADST_ADST:
+ av1_load_buffer_8x8_avx2(input, in, stride, 0, 0, shift[0]);
+ av1_fadst8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+ col_txfm_8x8_rounding(out, -shift[1]);
+ av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+ av1_fadst8_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+ av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+ av1_store_buffer_avx2(in, coeff, 8, 8);
+ break;
+ case FLIPADST_DCT:
+ av1_load_buffer_8x8_avx2(input, in, stride, 1, 0, shift[0]);
+ av1_fadst8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+ col_txfm_8x8_rounding(out, -shift[1]);
+ av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+ av1_fdct8_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+ av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+ av1_store_buffer_avx2(in, coeff, 8, 8);
+ break;
+ case DCT_FLIPADST:
+ av1_load_buffer_8x8_avx2(input, in, stride, 0, 1, shift[0]);
+ av1_fdct8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+ col_txfm_8x8_rounding(out, -shift[1]);
+ av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+ av1_fadst8_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+ av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+ av1_store_buffer_avx2(in, coeff, 8, 8);
+ break;
+ case FLIPADST_FLIPADST:
+ av1_load_buffer_8x8_avx2(input, in, stride, 1, 1, shift[0]);
+ av1_fadst8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+ col_txfm_8x8_rounding(out, -shift[1]);
+ av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+ av1_fadst8_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+ av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+ av1_store_buffer_avx2(in, coeff, 8, 8);
+ break;
+ case ADST_FLIPADST:
+ av1_load_buffer_8x8_avx2(input, in, stride, 0, 1, shift[0]);
+ av1_fadst8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+ col_txfm_8x8_rounding(out, -shift[1]);
+ av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+ av1_fadst8_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], width_div8);
+ av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+ av1_store_buffer_avx2(in, coeff, 8, 8);
+ break;
+ case FLIPADST_ADST:
+ av1_load_buffer_8x8_avx2(input, in, stride, 1, 0, shift[0]);
+ av1_fadst8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+ col_txfm_8x8_rounding(out, -shift[1]);
+ av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+ av1_fadst8_avx2(in, out, fwd_cos_bit_row[txw_idx][txh_idx], 1);
+ av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+ av1_store_buffer_avx2(in, coeff, 8, 8);
+ break;
+ case IDTX:
+ av1_load_buffer_8x8_avx2(input, in, stride, 0, 0, shift[0]);
+ av1_idtx8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], height);
+ col_txfm_8x8_rounding(out, -shift[1]);
+ av1_idtx8_avx2(out, in, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+ av1_store_buffer_avx2(in, coeff, 8, 8);
+ break;
+ case V_DCT:
+ av1_load_buffer_8x8_avx2(input, in, stride, 0, 0, shift[0]);
+ av1_fdct8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+ col_txfm_8x8_rounding(out, -shift[1]);
+ av1_idtx8_avx2(out, in, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+ av1_store_buffer_avx2(in, coeff, 8, 8);
+ break;
+ case H_DCT:
+ av1_load_buffer_8x8_avx2(input, in, stride, 0, 0, shift[0]);
+ av1_idtx8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+ col_txfm_8x8_rounding(out, -shift[1]);
+ av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+ av1_fdct8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+ av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+ av1_store_buffer_avx2(in, coeff, 8, 8);
+ break;
+ case V_ADST:
+ av1_load_buffer_8x8_avx2(input, in, stride, 0, 0, shift[0]);
+ av1_fadst8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+ col_txfm_8x8_rounding(out, -shift[1]);
+ av1_idtx8_avx2(out, in, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+ av1_store_buffer_avx2(in, coeff, 8, 8);
+ break;
+ case H_ADST:
+ av1_load_buffer_8x8_avx2(input, in, stride, 0, 0, shift[0]);
+ av1_idtx8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+ col_txfm_8x8_rounding(out, -shift[1]);
+ av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+ av1_fadst8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+ av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+ av1_store_buffer_avx2(in, coeff, 8, 8);
+ break;
+ case V_FLIPADST:
+ av1_load_buffer_8x8_avx2(input, in, stride, 1, 0, shift[0]);
+ av1_fadst8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+ col_txfm_8x8_rounding(out, -shift[1]);
+ av1_idtx8_avx2(out, in, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+ av1_store_buffer_avx2(in, coeff, 8, 8);
+ break;
+ case H_FLIPADST:
+ av1_load_buffer_8x8_avx2(input, in, stride, 0, 1, shift[0]);
+ av1_idtx8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+ col_txfm_8x8_rounding(out, -shift[1]);
+ av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+ av1_fadst8_avx2(in, out, fwd_cos_bit_col[txw_idx][txh_idx], width_div8);
+ av1_fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
+ av1_store_buffer_avx2(in, coeff, 8, 8);
+ break;
+ default: assert(0);
+ }
+ (void)bd;
+}
static void av1_fdct16_avx2(__m256i *in, __m256i *out, int bit,
const int col_num) {
const int32_t *cospi = cospi_arr(bit);
diff --git a/test/av1_fwd_txfm2d_test.cc b/test/av1_fwd_txfm2d_test.cc
index eb3455b..04c6ddb 100644
--- a/test/av1_fwd_txfm2d_test.cc
+++ b/test/av1_fwd_txfm2d_test.cc
@@ -571,7 +571,8 @@
Values(av1_highbd_fwd_txfm)));
#endif // HAVE_SSE4_1
#if HAVE_AVX2
-static TX_SIZE Highbd_fwd_txfm_for_avx2[] = { TX_16X16, TX_32X32, TX_64X64 };
+static TX_SIZE Highbd_fwd_txfm_for_avx2[] = { TX_8X8, TX_16X16, TX_32X32,
+ TX_64X64 };
INSTANTIATE_TEST_CASE_P(AVX2, AV1HighbdFwdTxfm2dTest,
Combine(ValuesIn(Highbd_fwd_txfm_for_avx2),