Hook in AVX2 inv txfm

1. Add av1_lowbd_inv_txfm2d_add_avx2.
1.1 For size < 16, still using ssse3 version
1.2 For size >= 16, use new AVX2 version
Unit tests show the AVX2 version is 1.25x ~ 2.0x faster than the ssse3 version.

2. Hook in AVX2 inv txfm functions.

Change-Id: Ib99b20264d127eac3a5fb8eb30e0d55ea423d7ba
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index e92fe99..5246fd9 100755
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -213,7 +213,7 @@
 
 #inv txfm
 add_proto qw/void av1_inv_txfm_add/, "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
-specialize qw/av1_inv_txfm_add ssse3/;
+specialize qw/av1_inv_txfm_add ssse3 avx2/;
 
 add_proto qw/void av1_inv_txfm2d_add_4x8/, "const int32_t *input, uint16_t *output, int stride, TX_TYPE tx_type, int bd";
 add_proto qw/void av1_inv_txfm2d_add_8x4/, "const int32_t *input, uint16_t *output, int stride, TX_TYPE tx_type, int bd";
diff --git a/av1/common/x86/av1_inv_txfm_avx2.c b/av1/common/x86/av1_inv_txfm_avx2.c
index 136cbcb..34dd41c 100644
--- a/av1/common/x86/av1_inv_txfm_avx2.c
+++ b/av1/common/x86/av1_inv_txfm_avx2.c
@@ -784,6 +784,85 @@
   btf_16_adds_subs_out_avx2(x1[31], x1[32], output[31], output[32]);
 }
 
+// 1D functions process 16 pixels at one time.
+static const transform_1d_avx2
+    lowbd_txfm_all_1d_w16_arr[TX_SIZES][ITX_TYPES_1D] = {
+      { NULL, NULL, NULL },
+      { NULL, NULL, NULL },
+      { idct16_new_avx2, iadst16_new_avx2, iidentity16_new_avx2 },
+      { idct32_new_avx2, NULL, NULL },
+      { idct64_low32_new_avx2, NULL, NULL },
+    };
+
+// Only handles blocks with w >= 16 and h >= 16.
+static INLINE void lowbd_inv_txfm2d_add_no_identity_avx2(const int32_t *input,
+                                                         uint8_t *output,
+                                                         int stride,
+                                                         TX_TYPE tx_type,
+                                                         TX_SIZE tx_size) {
+  __m256i buf1[64 * 16];
+  const int8_t *shift = inv_txfm_shift_ls[tx_size];
+  const int txw_idx = get_txw_idx(tx_size);
+  const int txh_idx = get_txh_idx(tx_size);
+  const int cos_bit_col = inv_cos_bit_col[txw_idx][txh_idx];
+  const int cos_bit_row = inv_cos_bit_row[txw_idx][txh_idx];
+  const int txfm_size_col = tx_size_wide[tx_size];
+  const int txfm_size_row = tx_size_high[tx_size];
+  const int buf_size_w_div16 = txfm_size_col >> 4;
+  const int buf_size_h = AOMMIN(32, txfm_size_row);
+  const int input_stride = AOMMIN(32, txfm_size_col);
+  const int rect_type = get_rect_tx_log_ratio(txfm_size_col, txfm_size_row);
+
+  const transform_1d_avx2 row_txfm =
+      lowbd_txfm_all_1d_w16_arr[txw_idx][hitx_1d_tab[tx_type]];
+  const transform_1d_avx2 col_txfm =
+      lowbd_txfm_all_1d_w16_arr[txh_idx][vitx_1d_tab[tx_type]];
+
+  assert(col_txfm != NULL);
+  assert(row_txfm != NULL);
+  int ud_flip, lr_flip;
+  get_flip_cfg(tx_type, &ud_flip, &lr_flip);
+  for (int i = 0; i < buf_size_h; i += 16) {
+    __m256i buf0[64];
+    const int32_t *input_row = input + i * input_stride;
+    for (int j = 0; j < AOMMIN(2, buf_size_w_div16); ++j) {
+      __m256i *buf0_cur = buf0 + j * 16;
+      const int32_t *input_cur = input_row + j * 16;
+      load_buffer_32bit_to_16bit_w16_avx2(input_cur, input_stride, buf0_cur,
+                                          16);
+      transpose_16bit_16x16_avx2(buf0_cur, buf0_cur);
+    }
+    if (rect_type == 1 || rect_type == -1) {
+      round_shift_avx2(buf0, buf0, input_stride);  // rect special code
+    }
+    row_txfm(buf0, buf0, cos_bit_row);
+    round_shift_16bit_w16_avx2(buf0, txfm_size_col, shift[0]);
+
+    __m256i *buf1_cur = buf1 + i;
+    if (lr_flip) {
+      for (int j = 0; j < buf_size_w_div16; ++j) {
+        __m256i temp[16];
+        flip_buf_av2(buf0 + 16 * j, temp, 16);
+        int offset = txfm_size_row * (buf_size_w_div16 - 1 - j);
+        transpose_16bit_16x16_avx2(temp, buf1_cur + offset);
+      }
+    } else {
+      for (int j = 0; j < buf_size_w_div16; ++j) {
+        transpose_16bit_16x16_avx2(buf0 + 16 * j, buf1_cur + txfm_size_row * j);
+      }
+    }
+  }
+  for (int i = 0; i < buf_size_w_div16; i++) {
+    __m256i *buf1_cur = buf1 + i * txfm_size_row;
+    col_txfm(buf1_cur, buf1_cur, cos_bit_col);
+    round_shift_16bit_w16_avx2(buf1_cur, txfm_size_row, shift[1]);
+  }
+  for (int i = 0; i < buf_size_w_div16; i++) {
+    lowbd_write_buffer_16xn_avx2(buf1 + i * txfm_size_row, output + 16 * i,
+                                 stride, ud_flip, txfm_size_row);
+  }
+}
+
 static INLINE void iidentity16_row_16xn_avx2(__m256i *out, const int32_t *input,
                                              int stride, int shift, int height,
                                              int rect) {
@@ -961,3 +1040,201 @@
     default: break;
   }
 }
+
+static INLINE void lowbd_inv_txfm2d_add_idtx_avx2(const int32_t *input,
+                                                  uint8_t *output, int stride,
+                                                  TX_SIZE tx_size) {
+  const int8_t *shift = inv_txfm_shift_ls[tx_size];
+  const int txw_idx = get_txw_idx(tx_size);
+  const int txh_idx = get_txh_idx(tx_size);
+  const int txfm_size_col = tx_size_wide[tx_size];
+  const int txfm_size_row = tx_size_high[tx_size];
+  const int input_stride = AOMMIN(32, txfm_size_col);
+  const int row_max = AOMMIN(32, txfm_size_row);
+  const int rect_type = get_rect_tx_log_ratio(txfm_size_col, txfm_size_row);
+  __m256i buf[32];
+  for (int i = 0; i < input_stride; i += 16) {
+    identity_row_16xn_avx2(buf, input + i, input_stride, shift[0], row_max,
+                           txw_idx, rect_type);
+    identity_col_16xn_avx2(output + i, stride, buf, shift[1], row_max, txh_idx);
+  }
+}
+
+static INLINE void lowbd_inv_txfm2d_add_h_identity_avx2(const int32_t *input,
+                                                        uint8_t *output,
+                                                        int stride,
+                                                        TX_TYPE tx_type,
+                                                        TX_SIZE tx_size) {
+  const int8_t *shift = inv_txfm_shift_ls[tx_size];
+  const int txw_idx = get_txw_idx(tx_size);
+  const int txh_idx = get_txh_idx(tx_size);
+  const int cos_bit_col = inv_cos_bit_col[txw_idx][txh_idx];
+  const int txfm_size_col = tx_size_wide[tx_size];
+  const int txfm_size_row = tx_size_high[tx_size];
+  const int txfm_size_col_notzero = AOMMIN(32, txfm_size_col);
+  const int txfm_size_row_notzero = AOMMIN(32, txfm_size_row);
+  const int input_stride = txfm_size_col_notzero;
+  const int rect_type = get_rect_tx_log_ratio(txfm_size_col, txfm_size_row);
+
+  const transform_1d_avx2 col_txfm =
+      lowbd_txfm_all_1d_w16_arr[txh_idx][vitx_1d_tab[tx_type]];
+
+  assert(col_txfm != NULL);
+
+  int ud_flip, lr_flip;
+  get_flip_cfg(tx_type, &ud_flip, &lr_flip);
+  for (int i = 0; i < txfm_size_col_notzero; i += 16) {
+    __m256i buf0[64];
+    identity_row_16xn_avx2(buf0, input + i, input_stride, shift[0],
+                           txfm_size_row_notzero, txw_idx, rect_type);
+    col_txfm(buf0, buf0, cos_bit_col);
+    __m256i mshift = _mm256_set1_epi16(1 << (15 + shift[1]));
+    int k = ud_flip ? (txfm_size_row - 1) : 0;
+    const int step = ud_flip ? -1 : 1;
+    for (int j = 0; j < txfm_size_row; ++j, k += step) {
+      __m256i res = _mm256_mulhrs_epi16(buf0[k], mshift);
+      write_recon_w16_avx2(res, output + i + j * stride);
+    }
+  }
+}
+
+static INLINE void lowbd_inv_txfm2d_add_v_identity_avx2(const int32_t *input,
+                                                        uint8_t *output,
+                                                        int stride,
+                                                        TX_TYPE tx_type,
+                                                        TX_SIZE tx_size) {
+  __m256i buf1[64];
+  const int8_t *shift = inv_txfm_shift_ls[tx_size];
+  const int txw_idx = get_txw_idx(tx_size);
+  const int txh_idx = get_txh_idx(tx_size);
+  const int cos_bit_row = inv_cos_bit_row[txw_idx][txh_idx];
+  const int txfm_size_col = tx_size_wide[tx_size];
+  const int txfm_size_row = tx_size_high[tx_size];
+  const int buf_size_w_div16 = txfm_size_col >> 4;
+  const int buf_size_h_div16 = AOMMIN(32, txfm_size_row) >> 4;
+  const int input_stride = AOMMIN(32, txfm_size_col);
+  const int rect_type = get_rect_tx_log_ratio(txfm_size_col, txfm_size_row);
+
+  const transform_1d_avx2 row_txfm =
+      lowbd_txfm_all_1d_w16_arr[txw_idx][hitx_1d_tab[tx_type]];
+
+  assert(row_txfm != NULL);
+  int ud_flip, lr_flip;
+  get_flip_cfg(tx_type, &ud_flip, &lr_flip);
+  for (int i = 0; i < buf_size_h_div16; i++) {
+    __m256i buf0[64];
+    const int32_t *input_row = input + i * input_stride * 16;
+    for (int j = 0; j < AOMMIN(4, buf_size_w_div16); ++j) {
+      __m256i *buf0_cur = buf0 + j * 16;
+      load_buffer_32bit_to_16bit_w16_avx2(input_row + j * 16, input_stride,
+                                          buf0_cur, 16);
+      transpose_16bit_16x16_avx2(buf0_cur, buf0_cur);
+    }
+    if (rect_type == 1 || rect_type == -1) {
+      round_shift_avx2(buf0, buf0, input_stride);  // rect special code
+    }
+    row_txfm(buf0, buf0, cos_bit_row);
+    round_shift_16bit_w16_avx2(buf0, txfm_size_col, shift[0]);
+    __m256i *_buf1 = buf1;
+    if (lr_flip) {
+      for (int j = 0; j < buf_size_w_div16; ++j) {
+        __m256i temp[16];
+        flip_buf_av2(buf0 + 16 * j, temp, 16);
+        transpose_16bit_16x16_avx2(temp,
+                                   _buf1 + 16 * (buf_size_w_div16 - 1 - j));
+      }
+    } else {
+      for (int j = 0; j < buf_size_w_div16; ++j) {
+        transpose_16bit_16x16_avx2(buf0 + 16 * j, _buf1 + 16 * j);
+      }
+    }
+    for (int j = 0; j < buf_size_w_div16; ++j) {
+      identity_col_16xn_avx2(output + i * 16 * stride + j * 16, stride,
+                             buf1 + j * 16, shift[1], 16, txh_idx);
+    }
+  }
+}
+
+// for 16x16,32x32,32x64,64x32,64x64,16x32,32x16,64x16,16x64
+static INLINE void lowbd_inv_txfm2d_add_universe_avx2(
+    const int32_t *input, uint8_t *output, int stride, TX_TYPE tx_type,
+    TX_SIZE tx_size, int eob) {
+  (void)eob;
+  switch (tx_type) {
+    case DCT_DCT:
+    case ADST_DCT:   // ADST in vertical, DCT in horizontal
+    case DCT_ADST:   // DCT  in vertical, ADST in horizontal
+    case ADST_ADST:  // ADST in both directions
+    case FLIPADST_DCT:
+    case DCT_FLIPADST:
+    case FLIPADST_FLIPADST:
+    case ADST_FLIPADST:
+    case FLIPADST_ADST:
+      lowbd_inv_txfm2d_add_no_identity_avx2(input, output, stride, tx_type,
+                                            tx_size);
+      break;
+    case IDTX:
+      lowbd_inv_txfm2d_add_idtx_avx2(input, output, stride, tx_size);
+      break;
+    case V_DCT:
+    case V_ADST:
+    case V_FLIPADST:
+      lowbd_inv_txfm2d_add_h_identity_avx2(input, output, stride, tx_type,
+                                           tx_size);
+      break;
+    case H_DCT:
+    case H_ADST:
+    case H_FLIPADST:
+      lowbd_inv_txfm2d_add_v_identity_avx2(input, output, stride, tx_type,
+                                           tx_size);
+      break;
+    default:
+      av1_lowbd_inv_txfm2d_add_ssse3(input, output, stride, tx_type, tx_size,
+                                     eob);
+      break;
+  }
+}
+
+void av1_lowbd_inv_txfm2d_add_avx2(const int32_t *input, uint8_t *output,
+                                   int stride, TX_TYPE tx_type, TX_SIZE tx_size,
+                                   int eob) {
+  switch (tx_size) {
+    case TX_4X4:
+    case TX_8X8:  // 8x8 transform
+    case TX_4X8:
+    case TX_8X4:
+    case TX_8X16:  // 8x16 transform
+    case TX_16X8:  // 16x8 transform
+    case TX_4X16:
+    case TX_16X4:
+    case TX_8X32:  // 8x32 transform
+    case TX_32X8:  // 32x8 transform
+      av1_lowbd_inv_txfm2d_add_ssse3(input, output, stride, tx_type, tx_size,
+                                     eob);
+      break;
+    case TX_16X16:  // 16x16 transform
+    case TX_32X32:  // 32x32 transform
+    case TX_64X64:  // 64x64 transform
+    case TX_16X32:  // 16x32 transform
+    case TX_32X16:  // 32x16 transform
+    case TX_32X64:  // 32x64 transform
+    case TX_64X32:  // 64x32 transform
+    case TX_16X64:  // 16x64 transform
+    case TX_64X16:  // 64x16 transform
+    default:
+      lowbd_inv_txfm2d_add_universe_avx2(input, output, stride, tx_type,
+                                         tx_size, eob);
+      break;
+  }
+}
+
+void av1_inv_txfm_add_avx2(const tran_low_t *dqcoeff, uint8_t *dst, int stride,
+                           const TxfmParam *txfm_param) {
+  const TX_TYPE tx_type = txfm_param->tx_type;
+  if (!txfm_param->lossless) {
+    av1_lowbd_inv_txfm2d_add_avx2(dqcoeff, dst, stride, tx_type,
+                                  txfm_param->tx_size, txfm_param->eob);
+  } else {
+    av1_inv_txfm_add_c(dqcoeff, dst, stride, txfm_param);
+  }
+}
diff --git a/av1/common/x86/av1_inv_txfm_ssse3.c b/av1/common/x86/av1_inv_txfm_ssse3.c
index b07d28e..8a03e6d 100644
--- a/av1/common/x86/av1_inv_txfm_ssse3.c
+++ b/av1/common/x86/av1_inv_txfm_ssse3.c
@@ -2172,29 +2172,6 @@
   }
 }
 
-// 1D itx types
-typedef enum ATTRIBUTE_PACKED {
-  IDCT_1D,
-  IADST_1D,
-  IFLIPADST_1D = IADST_1D,
-  IIDENTITY_1D,
-  ITX_TYPES_1D,
-} ITX_TYPE_1D;
-
-static const ITX_TYPE_1D vitx_1d_tab[TX_TYPES] = {
-  IDCT_1D,      IADST_1D,     IDCT_1D,      IADST_1D,
-  IFLIPADST_1D, IDCT_1D,      IFLIPADST_1D, IADST_1D,
-  IFLIPADST_1D, IIDENTITY_1D, IDCT_1D,      IIDENTITY_1D,
-  IADST_1D,     IIDENTITY_1D, IFLIPADST_1D, IIDENTITY_1D,
-};
-
-static const ITX_TYPE_1D hitx_1d_tab[TX_TYPES] = {
-  IDCT_1D,      IDCT_1D,      IADST_1D,     IADST_1D,
-  IDCT_1D,      IFLIPADST_1D, IFLIPADST_1D, IFLIPADST_1D,
-  IADST_1D,     IIDENTITY_1D, IIDENTITY_1D, IDCT_1D,
-  IIDENTITY_1D, IADST_1D,     IIDENTITY_1D, IFLIPADST_1D,
-};
-
 // 1D functions process process 8 pixels at one time.
 static const transform_1d_ssse3
     lowbd_txfm_all_1d_w8_arr[TX_SIZES][ITX_TYPES_1D] = {
diff --git a/av1/common/x86/av1_inv_txfm_ssse3.h b/av1/common/x86/av1_inv_txfm_ssse3.h
index 8ef480c..e0be404 100644
--- a/av1/common/x86/av1_inv_txfm_ssse3.h
+++ b/av1/common/x86/av1_inv_txfm_ssse3.h
@@ -31,6 +31,30 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
+
+// 1D itx types
+typedef enum ATTRIBUTE_PACKED {
+  IDCT_1D,
+  IADST_1D,
+  IFLIPADST_1D = IADST_1D,
+  IIDENTITY_1D,
+  ITX_TYPES_1D,
+} ITX_TYPE_1D;
+
+static const ITX_TYPE_1D vitx_1d_tab[TX_TYPES] = {
+  IDCT_1D,      IADST_1D,     IDCT_1D,      IADST_1D,
+  IFLIPADST_1D, IDCT_1D,      IFLIPADST_1D, IADST_1D,
+  IFLIPADST_1D, IIDENTITY_1D, IDCT_1D,      IIDENTITY_1D,
+  IADST_1D,     IIDENTITY_1D, IFLIPADST_1D, IIDENTITY_1D,
+};
+
+static const ITX_TYPE_1D hitx_1d_tab[TX_TYPES] = {
+  IDCT_1D,      IDCT_1D,      IADST_1D,     IADST_1D,
+  IDCT_1D,      IFLIPADST_1D, IFLIPADST_1D, IFLIPADST_1D,
+  IADST_1D,     IIDENTITY_1D, IIDENTITY_1D, IDCT_1D,
+  IIDENTITY_1D, IADST_1D,     IIDENTITY_1D, IFLIPADST_1D,
+};
+
 typedef void (*transform_1d_ssse3)(const __m128i *input, __m128i *output,
                                    int8_t cos_bit);
 
diff --git a/test/av1_inv_txfm2d_test.cc b/test/av1_inv_txfm2d_test.cc
index ec2ca4f..73feb16 100644
--- a/test/av1_inv_txfm2d_test.cc
+++ b/test/av1_inv_txfm2d_test.cc
@@ -319,10 +319,17 @@
 #if HAVE_SSSE3
 #if defined(_MSC_VER) || defined(__SSSE3__)
 #include "av1/common/x86/av1_inv_txfm_ssse3.h"
-
 INSTANTIATE_TEST_CASE_P(SSSE3, AV1LbdInvTxfm2d,
                         ::testing::Values(av1_lowbd_inv_txfm2d_add_ssse3));
 #endif  // _MSC_VER || __SSSE3__
-#endif  // HAVE_SSE2
+#endif  // HAVE_SSSE3
+
+#if HAVE_AVX2
+#if defined(_MSC_VER) || defined(__AVX2__)
+#include "av1/common/x86/av1_inv_txfm_avx2.h"
+INSTANTIATE_TEST_CASE_P(AVX2, AV1LbdInvTxfm2d,
+                        ::testing::Values(av1_lowbd_inv_txfm2d_add_avx2));
+#endif  // (_MSC_VER) || (__AVX2__)
+#endif  // HAVE_AVX2
 
 }  // namespace