Hook in AVX2 inv txfm
1. Add av1_lowbd_inv_txfm2d_add_avx2.
1.1 For sizes < 16, the SSSE3 version is still used.
1.2 For sizes >= 16, the new AVX2 version is used.
The unit test shows the AVX2 version is 1.25x ~ 2.0x faster than the
SSSE3 version.
2. Hook in AVX2 inv txfm functions.
Change-Id: Ib99b20264d127eac3a5fb8eb30e0d55ea423d7ba
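
A minimal caller-side sketch of how the new hook is reached (hypothetical
values; dqcoeff/dst/stride/eob are assumed to be in scope, and TxfmParam is
the struct already taken by av1_inv_txfm_add): the RTCD layer resolves
av1_inv_txfm_add to av1_inv_txfm_add_avx2 at runtime on AVX2-capable CPUs.

    TxfmParam param;
    param.tx_type = DCT_DCT;   // DCT in both directions
    param.tx_size = TX_16X16;  // both dims >= 16, so the AVX2 path is taken
    param.lossless = 0;        // lossless blocks stay on the C path
    param.eob = eob;           // end-of-block position from the decoder
    av1_inv_txfm_add(dqcoeff, dst, stride, &param);
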
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index e92fe99..5246fd9 100755
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -213,7 +213,7 @@
#inv txfm
add_proto qw/void av1_inv_txfm_add/, "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
-specialize qw/av1_inv_txfm_add ssse3/;
+specialize qw/av1_inv_txfm_add ssse3 avx2/;
add_proto qw/void av1_inv_txfm2d_add_4x8/, "const int32_t *input, uint16_t *output, int stride, TX_TYPE tx_type, int bd";
add_proto qw/void av1_inv_txfm2d_add_8x4/, "const int32_t *input, uint16_t *output, int stride, TX_TYPE tx_type, int bd";
diff --git a/av1/common/x86/av1_inv_txfm_avx2.c b/av1/common/x86/av1_inv_txfm_avx2.c
index 136cbcb..34dd41c 100644
--- a/av1/common/x86/av1_inv_txfm_avx2.c
+++ b/av1/common/x86/av1_inv_txfm_avx2.c
@@ -784,6 +784,85 @@
btf_16_adds_subs_out_avx2(x1[31], x1[32], output[31], output[32]);
}
+// 1D functions process 16 pixels at a time.
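+// A NULL entry means this table does not serve that (size, type) pair:
+// widths below 16 fall back to the SSSE3 path, and identity or otherwise
+// unsupported 1D types at 32/64 points are handled by dedicated paths.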
+static const transform_1d_avx2
+ lowbd_txfm_all_1d_w16_arr[TX_SIZES][ITX_TYPES_1D] = {
+ { NULL, NULL, NULL },
+ { NULL, NULL, NULL },
+ { idct16_new_avx2, iadst16_new_avx2, iidentity16_new_avx2 },
+ { idct32_new_avx2, NULL, NULL },
+ { idct64_low32_new_avx2, NULL, NULL },
+ };
+
+// Only processes blocks with w >= 16 and h >= 16.
+static INLINE void lowbd_inv_txfm2d_add_no_identity_avx2(const int32_t *input,
+ uint8_t *output,
+ int stride,
+ TX_TYPE tx_type,
+ TX_SIZE tx_size) {
+ __m256i buf1[64 * 16];
+ const int8_t *shift = inv_txfm_shift_ls[tx_size];
+ const int txw_idx = get_txw_idx(tx_size);
+ const int txh_idx = get_txh_idx(tx_size);
+ const int cos_bit_col = inv_cos_bit_col[txw_idx][txh_idx];
+ const int cos_bit_row = inv_cos_bit_row[txw_idx][txh_idx];
+ const int txfm_size_col = tx_size_wide[tx_size];
+ const int txfm_size_row = tx_size_high[tx_size];
+ const int buf_size_w_div16 = txfm_size_col >> 4;
+ const int buf_size_h = AOMMIN(32, txfm_size_row);
+ const int input_stride = AOMMIN(32, txfm_size_col);
+ const int rect_type = get_rect_tx_log_ratio(txfm_size_col, txfm_size_row);
+
+ const transform_1d_avx2 row_txfm =
+ lowbd_txfm_all_1d_w16_arr[txw_idx][hitx_1d_tab[tx_type]];
+ const transform_1d_avx2 col_txfm =
+ lowbd_txfm_all_1d_w16_arr[txh_idx][vitx_1d_tab[tx_type]];
+
+ assert(col_txfm != NULL);
+ assert(row_txfm != NULL);
+ int ud_flip, lr_flip;
+ get_flip_cfg(tx_type, &ud_flip, &lr_flip);
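+  // Row pass: load each 16x16 tile of 32-bit coefficients as 16-bit,
+  // transpose it so rows land in vector lanes, then transform in place.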
+ for (int i = 0; i < buf_size_h; i += 16) {
+ __m256i buf0[64];
+ const int32_t *input_row = input + i * input_stride;
+ for (int j = 0; j < AOMMIN(2, buf_size_w_div16); ++j) {
+ __m256i *buf0_cur = buf0 + j * 16;
+ const int32_t *input_cur = input_row + j * 16;
+ load_buffer_32bit_to_16bit_w16_avx2(input_cur, input_stride, buf0_cur,
+ 16);
+ transpose_16bit_16x16_avx2(buf0_cur, buf0_cur);
+ }
+ if (rect_type == 1 || rect_type == -1) {
+      round_shift_avx2(buf0, buf0, input_stride);  // 1/sqrt(2) rect scaling
+ }
+ row_txfm(buf0, buf0, cos_bit_row);
+ round_shift_16bit_w16_avx2(buf0, txfm_size_col, shift[0]);
+
+ __m256i *buf1_cur = buf1 + i;
+ if (lr_flip) {
+ for (int j = 0; j < buf_size_w_div16; ++j) {
+ __m256i temp[16];
+ flip_buf_av2(buf0 + 16 * j, temp, 16);
+ int offset = txfm_size_row * (buf_size_w_div16 - 1 - j);
+ transpose_16bit_16x16_avx2(temp, buf1_cur + offset);
+ }
+ } else {
+ for (int j = 0; j < buf_size_w_div16; ++j) {
+ transpose_16bit_16x16_avx2(buf0 + 16 * j, buf1_cur + txfm_size_row * j);
+ }
+ }
+ }
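+  // Column pass: after the transposes, buf1 holds the block column-major in
+  // 16-wide strips, so each strip is transformed and rounded in place.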
+ for (int i = 0; i < buf_size_w_div16; i++) {
+ __m256i *buf1_cur = buf1 + i * txfm_size_row;
+ col_txfm(buf1_cur, buf1_cur, cos_bit_col);
+ round_shift_16bit_w16_avx2(buf1_cur, txfm_size_row, shift[1]);
+ }
+ for (int i = 0; i < buf_size_w_div16; i++) {
+ lowbd_write_buffer_16xn_avx2(buf1 + i * txfm_size_row, output + 16 * i,
+ stride, ud_flip, txfm_size_row);
+ }
+}
+
static INLINE void iidentity16_row_16xn_avx2(__m256i *out, const int32_t *input,
int stride, int shift, int height,
int rect) {
@@ -961,3 +1040,201 @@
default: break;
}
}
+
+static INLINE void lowbd_inv_txfm2d_add_idtx_avx2(const int32_t *input,
+ uint8_t *output, int stride,
+ TX_SIZE tx_size) {
+ const int8_t *shift = inv_txfm_shift_ls[tx_size];
+ const int txw_idx = get_txw_idx(tx_size);
+ const int txh_idx = get_txh_idx(tx_size);
+ const int txfm_size_col = tx_size_wide[tx_size];
+ const int txfm_size_row = tx_size_high[tx_size];
+ const int input_stride = AOMMIN(32, txfm_size_col);
+ const int row_max = AOMMIN(32, txfm_size_row);
+ const int rect_type = get_rect_tx_log_ratio(txfm_size_col, txfm_size_row);
+ __m256i buf[32];
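+  // Identity in both passes: no butterflies are needed, just the per-pass
+  // scaling, applied to each 16-wide strip and written straight to output.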
+ for (int i = 0; i < input_stride; i += 16) {
+ identity_row_16xn_avx2(buf, input + i, input_stride, shift[0], row_max,
+ txw_idx, rect_type);
+ identity_col_16xn_avx2(output + i, stride, buf, shift[1], row_max, txh_idx);
+ }
+}
+
+static INLINE void lowbd_inv_txfm2d_add_h_identity_avx2(const int32_t *input,
+ uint8_t *output,
+ int stride,
+ TX_TYPE tx_type,
+ TX_SIZE tx_size) {
+ const int8_t *shift = inv_txfm_shift_ls[tx_size];
+ const int txw_idx = get_txw_idx(tx_size);
+ const int txh_idx = get_txh_idx(tx_size);
+ const int cos_bit_col = inv_cos_bit_col[txw_idx][txh_idx];
+ const int txfm_size_col = tx_size_wide[tx_size];
+ const int txfm_size_row = tx_size_high[tx_size];
+ const int txfm_size_col_notzero = AOMMIN(32, txfm_size_col);
+ const int txfm_size_row_notzero = AOMMIN(32, txfm_size_row);
+ const int input_stride = txfm_size_col_notzero;
+ const int rect_type = get_rect_tx_log_ratio(txfm_size_col, txfm_size_row);
+
+ const transform_1d_avx2 col_txfm =
+ lowbd_txfm_all_1d_w16_arr[txh_idx][vitx_1d_tab[tx_type]];
+
+ assert(col_txfm != NULL);
+
+ int ud_flip, lr_flip;
+ get_flip_cfg(tx_type, &ud_flip, &lr_flip);
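+  // The horizontal (row) pass is identity, so only the column transform is
+  // computed; identity_row_16xn_avx2 applies the row scaling while loading.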
+ for (int i = 0; i < txfm_size_col_notzero; i += 16) {
+ __m256i buf0[64];
+ identity_row_16xn_avx2(buf0, input + i, input_stride, shift[0],
+ txfm_size_row_notzero, txw_idx, rect_type);
+ col_txfm(buf0, buf0, cos_bit_col);
+ __m256i mshift = _mm256_set1_epi16(1 << (15 + shift[1]));
+ int k = ud_flip ? (txfm_size_row - 1) : 0;
+ const int step = ud_flip ? -1 : 1;
+ for (int j = 0; j < txfm_size_row; ++j, k += step) {
+ __m256i res = _mm256_mulhrs_epi16(buf0[k], mshift);
+ write_recon_w16_avx2(res, output + i + j * stride);
+ }
+ }
+}
+
+static INLINE void lowbd_inv_txfm2d_add_v_identity_avx2(const int32_t *input,
+ uint8_t *output,
+ int stride,
+ TX_TYPE tx_type,
+ TX_SIZE tx_size) {
+ __m256i buf1[64];
+ const int8_t *shift = inv_txfm_shift_ls[tx_size];
+ const int txw_idx = get_txw_idx(tx_size);
+ const int txh_idx = get_txh_idx(tx_size);
+ const int cos_bit_row = inv_cos_bit_row[txw_idx][txh_idx];
+ const int txfm_size_col = tx_size_wide[tx_size];
+ const int txfm_size_row = tx_size_high[tx_size];
+ const int buf_size_w_div16 = txfm_size_col >> 4;
+ const int buf_size_h_div16 = AOMMIN(32, txfm_size_row) >> 4;
+ const int input_stride = AOMMIN(32, txfm_size_col);
+ const int rect_type = get_rect_tx_log_ratio(txfm_size_col, txfm_size_row);
+
+ const transform_1d_avx2 row_txfm =
+ lowbd_txfm_all_1d_w16_arr[txw_idx][hitx_1d_tab[tx_type]];
+
+ assert(row_txfm != NULL);
+ int ud_flip, lr_flip;
+ get_flip_cfg(tx_type, &ud_flip, &lr_flip);
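+  // The vertical (column) pass is identity: run the real row transform,
+  // then identity_col_16xn_avx2 only scales and writes the reconstruction.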
+ for (int i = 0; i < buf_size_h_div16; i++) {
+ __m256i buf0[64];
+ const int32_t *input_row = input + i * input_stride * 16;
+ for (int j = 0; j < AOMMIN(4, buf_size_w_div16); ++j) {
+ __m256i *buf0_cur = buf0 + j * 16;
+ load_buffer_32bit_to_16bit_w16_avx2(input_row + j * 16, input_stride,
+ buf0_cur, 16);
+ transpose_16bit_16x16_avx2(buf0_cur, buf0_cur);
+ }
+ if (rect_type == 1 || rect_type == -1) {
+      round_shift_avx2(buf0, buf0, input_stride);  // 1/sqrt(2) rect scaling
+ }
+ row_txfm(buf0, buf0, cos_bit_row);
+ round_shift_16bit_w16_avx2(buf0, txfm_size_col, shift[0]);
+ __m256i *_buf1 = buf1;
+ if (lr_flip) {
+ for (int j = 0; j < buf_size_w_div16; ++j) {
+ __m256i temp[16];
+ flip_buf_av2(buf0 + 16 * j, temp, 16);
+ transpose_16bit_16x16_avx2(temp,
+ _buf1 + 16 * (buf_size_w_div16 - 1 - j));
+ }
+ } else {
+ for (int j = 0; j < buf_size_w_div16; ++j) {
+ transpose_16bit_16x16_avx2(buf0 + 16 * j, _buf1 + 16 * j);
+ }
+ }
+ for (int j = 0; j < buf_size_w_div16; ++j) {
+ identity_col_16xn_avx2(output + i * 16 * stride + j * 16, stride,
+ buf1 + j * 16, shift[1], 16, txh_idx);
+ }
+ }
+}
+
+// for 16x16, 16x32, 32x16, 32x32, 32x64, 64x32, 64x64, 16x64, 64x16
+static INLINE void lowbd_inv_txfm2d_add_universe_avx2(
+ const int32_t *input, uint8_t *output, int stride, TX_TYPE tx_type,
+ TX_SIZE tx_size, int eob) {
+ (void)eob;
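+  // eob is not yet used to skip zeroed regions here; it is only forwarded
+  // to the SSSE3 fallback in the default case.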
+ switch (tx_type) {
+ case DCT_DCT:
+ case ADST_DCT: // ADST in vertical, DCT in horizontal
+ case DCT_ADST: // DCT in vertical, ADST in horizontal
+ case ADST_ADST: // ADST in both directions
+ case FLIPADST_DCT:
+ case DCT_FLIPADST:
+ case FLIPADST_FLIPADST:
+ case ADST_FLIPADST:
+ case FLIPADST_ADST:
+ lowbd_inv_txfm2d_add_no_identity_avx2(input, output, stride, tx_type,
+ tx_size);
+ break;
+ case IDTX:
+ lowbd_inv_txfm2d_add_idtx_avx2(input, output, stride, tx_size);
+ break;
+ case V_DCT:
+ case V_ADST:
+ case V_FLIPADST:
+ lowbd_inv_txfm2d_add_h_identity_avx2(input, output, stride, tx_type,
+ tx_size);
+ break;
+ case H_DCT:
+ case H_ADST:
+ case H_FLIPADST:
+ lowbd_inv_txfm2d_add_v_identity_avx2(input, output, stride, tx_type,
+ tx_size);
+ break;
+ default:
+ av1_lowbd_inv_txfm2d_add_ssse3(input, output, stride, tx_type, tx_size,
+ eob);
+ break;
+ }
+}
+
+void av1_lowbd_inv_txfm2d_add_avx2(const int32_t *input, uint8_t *output,
+ int stride, TX_TYPE tx_type, TX_SIZE tx_size,
+ int eob) {
+ switch (tx_size) {
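+    // One or both dimensions < 16: keep using the SSSE3 implementation.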
+ case TX_4X4:
+ case TX_8X8: // 8x8 transform
+ case TX_4X8:
+ case TX_8X4:
+ case TX_8X16: // 8x16 transform
+ case TX_16X8: // 16x8 transform
+ case TX_4X16:
+ case TX_16X4:
+ case TX_8X32: // 8x32 transform
+ case TX_32X8: // 32x8 transform
+ av1_lowbd_inv_txfm2d_add_ssse3(input, output, stride, tx_type, tx_size,
+ eob);
+ break;
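+    // Both dimensions >= 16: covered by the new AVX2 implementations.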
+ case TX_16X16: // 16x16 transform
+ case TX_32X32: // 32x32 transform
+ case TX_64X64: // 64x64 transform
+ case TX_16X32: // 16x32 transform
+ case TX_32X16: // 32x16 transform
+ case TX_32X64: // 32x64 transform
+ case TX_64X32: // 64x32 transform
+ case TX_16X64: // 16x64 transform
+ case TX_64X16: // 64x16 transform
+ default:
+ lowbd_inv_txfm2d_add_universe_avx2(input, output, stride, tx_type,
+ tx_size, eob);
+ break;
+ }
+}
+
+void av1_inv_txfm_add_avx2(const tran_low_t *dqcoeff, uint8_t *dst, int stride,
+ const TxfmParam *txfm_param) {
+ const TX_TYPE tx_type = txfm_param->tx_type;
+ if (!txfm_param->lossless) {
+ av1_lowbd_inv_txfm2d_add_avx2(dqcoeff, dst, stride, tx_type,
+ txfm_param->tx_size, txfm_param->eob);
+ } else {
+ av1_inv_txfm_add_c(dqcoeff, dst, stride, txfm_param);
+ }
+}
diff --git a/av1/common/x86/av1_inv_txfm_ssse3.c b/av1/common/x86/av1_inv_txfm_ssse3.c
index b07d28e..8a03e6d 100644
--- a/av1/common/x86/av1_inv_txfm_ssse3.c
+++ b/av1/common/x86/av1_inv_txfm_ssse3.c
@@ -2172,29 +2172,6 @@
}
}
-// 1D itx types
-typedef enum ATTRIBUTE_PACKED {
- IDCT_1D,
- IADST_1D,
- IFLIPADST_1D = IADST_1D,
- IIDENTITY_1D,
- ITX_TYPES_1D,
-} ITX_TYPE_1D;
-
-static const ITX_TYPE_1D vitx_1d_tab[TX_TYPES] = {
- IDCT_1D, IADST_1D, IDCT_1D, IADST_1D,
- IFLIPADST_1D, IDCT_1D, IFLIPADST_1D, IADST_1D,
- IFLIPADST_1D, IIDENTITY_1D, IDCT_1D, IIDENTITY_1D,
- IADST_1D, IIDENTITY_1D, IFLIPADST_1D, IIDENTITY_1D,
-};
-
-static const ITX_TYPE_1D hitx_1d_tab[TX_TYPES] = {
- IDCT_1D, IDCT_1D, IADST_1D, IADST_1D,
- IDCT_1D, IFLIPADST_1D, IFLIPADST_1D, IFLIPADST_1D,
- IADST_1D, IIDENTITY_1D, IIDENTITY_1D, IDCT_1D,
- IIDENTITY_1D, IADST_1D, IIDENTITY_1D, IFLIPADST_1D,
-};
-
// 1D functions process 8 pixels at a time.
static const transform_1d_ssse3
lowbd_txfm_all_1d_w8_arr[TX_SIZES][ITX_TYPES_1D] = {
diff --git a/av1/common/x86/av1_inv_txfm_ssse3.h b/av1/common/x86/av1_inv_txfm_ssse3.h
index 8ef480c..e0be404 100644
--- a/av1/common/x86/av1_inv_txfm_ssse3.h
+++ b/av1/common/x86/av1_inv_txfm_ssse3.h
@@ -31,6 +31,30 @@
#ifdef __cplusplus
extern "C" {
#endif
+
+// 1D itx types
+typedef enum ATTRIBUTE_PACKED {
+ IDCT_1D,
+ IADST_1D,
+ IFLIPADST_1D = IADST_1D,
+ IIDENTITY_1D,
+ ITX_TYPES_1D,
+} ITX_TYPE_1D;
+
+static const ITX_TYPE_1D vitx_1d_tab[TX_TYPES] = {
+ IDCT_1D, IADST_1D, IDCT_1D, IADST_1D,
+ IFLIPADST_1D, IDCT_1D, IFLIPADST_1D, IADST_1D,
+ IFLIPADST_1D, IIDENTITY_1D, IDCT_1D, IIDENTITY_1D,
+ IADST_1D, IIDENTITY_1D, IFLIPADST_1D, IIDENTITY_1D,
+};
+
+static const ITX_TYPE_1D hitx_1d_tab[TX_TYPES] = {
+ IDCT_1D, IDCT_1D, IADST_1D, IADST_1D,
+ IDCT_1D, IFLIPADST_1D, IFLIPADST_1D, IFLIPADST_1D,
+ IADST_1D, IIDENTITY_1D, IIDENTITY_1D, IDCT_1D,
+ IIDENTITY_1D, IADST_1D, IIDENTITY_1D, IFLIPADST_1D,
+};
+
typedef void (*transform_1d_ssse3)(const __m128i *input, __m128i *output,
int8_t cos_bit);
diff --git a/test/av1_inv_txfm2d_test.cc b/test/av1_inv_txfm2d_test.cc
index ec2ca4f..73feb16 100644
--- a/test/av1_inv_txfm2d_test.cc
+++ b/test/av1_inv_txfm2d_test.cc
@@ -319,10 +319,17 @@
#if HAVE_SSSE3
#if defined(_MSC_VER) || defined(__SSSE3__)
#include "av1/common/x86/av1_inv_txfm_ssse3.h"
-
INSTANTIATE_TEST_CASE_P(SSSE3, AV1LbdInvTxfm2d,
::testing::Values(av1_lowbd_inv_txfm2d_add_ssse3));
#endif // _MSC_VER || __SSSE3__
-#endif // HAVE_SSE2
+#endif // HAVE_SSSE3
+
+#if HAVE_AVX2
+#if defined(_MSC_VER) || defined(__AVX2__)
+#include "av1/common/x86/av1_inv_txfm_avx2.h"
+INSTANTIATE_TEST_CASE_P(AVX2, AV1LbdInvTxfm2d,
+ ::testing::Values(av1_lowbd_inv_txfm2d_add_avx2));
+#endif  // _MSC_VER || __AVX2__
+#endif // HAVE_AVX2
} // namespace