Optimize highbd 4x16 and 16x4 inv_txfm

Enabled sse4_1 optimizations for tx_sizes 4x16 and 16x4.

Module-level gains:
Tx_size    Gain w.r.t. C
4x16         6.01x
16x4         7.36x

When tested on 20 frames of crowd_run_360p_10 at 1 Mbps with the
speed=1 preset, an encoder time reduction of ~0.5% was observed.

Change-Id: I0c5d6530c666150be0de062a7084c7a2bf61410f
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index a6d5138..033eeb7 100755
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -135,6 +135,10 @@
 specialize qw/av1_highbd_inv_txfm_add_4x8 sse4_1/;
 add_proto qw/void av1_highbd_inv_txfm_add_8x4/,  "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
 specialize qw/av1_highbd_inv_txfm_add_8x4 sse4_1/;
+add_proto qw/void av1_highbd_inv_txfm_add_4x16/,  "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
+specialize qw/av1_highbd_inv_txfm_add_4x16 sse4_1/;
+add_proto qw/void av1_highbd_inv_txfm_add_16x4/,  "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
+specialize qw/av1_highbd_inv_txfm_add_16x4 sse4_1/;
 
 add_proto qw/void av1_highbd_iwht4x4_1_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride, int bd";
 add_proto qw/void av1_highbd_iwht4x4_16_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride, int bd";
diff --git a/av1/common/idct.c b/av1/common/idct.c
index 0261f6f..55925a5 100644
--- a/av1/common/idct.c
+++ b/av1/common/idct.c
@@ -86,15 +86,15 @@
                              txfm_param->tx_type, txfm_param->bd);
 }
 
-void av1_highbd_inv_txfm_add_16x4(const tran_low_t *input, uint8_t *dest,
-                                  int stride, const TxfmParam *txfm_param) {
+void av1_highbd_inv_txfm_add_16x4_c(const tran_low_t *input, uint8_t *dest,
+                                    int stride, const TxfmParam *txfm_param) {
   const int32_t *src = cast_to_int32(input);
   av1_inv_txfm2d_add_16x4_c(src, CONVERT_TO_SHORTPTR(dest), stride,
                             txfm_param->tx_type, txfm_param->bd);
 }
 
-void av1_highbd_inv_txfm_add_4x16(const tran_low_t *input, uint8_t *dest,
-                                  int stride, const TxfmParam *txfm_param) {
+void av1_highbd_inv_txfm_add_4x16_c(const tran_low_t *input, uint8_t *dest,
+                                    int stride, const TxfmParam *txfm_param) {
   const int32_t *src = cast_to_int32(input);
   av1_inv_txfm2d_add_4x16_c(src, CONVERT_TO_SHORTPTR(dest), stride,
                             txfm_param->tx_type, txfm_param->bd);
@@ -263,10 +263,10 @@
       av1_highbd_inv_txfm_add_4x4_c(input, dest, stride, txfm_param);
       break;
     case TX_16X4:
-      av1_highbd_inv_txfm_add_16x4(input, dest, stride, txfm_param);
+      av1_highbd_inv_txfm_add_16x4_c(input, dest, stride, txfm_param);
       break;
     case TX_4X16:
-      av1_highbd_inv_txfm_add_4x16(input, dest, stride, txfm_param);
+      av1_highbd_inv_txfm_add_4x16_c(input, dest, stride, txfm_param);
       break;
     case TX_8X32:
       av1_highbd_inv_txfm_add_8x32_c(input, dest, stride, txfm_param);
diff --git a/av1/common/idct.h b/av1/common/idct.h
index 00a55f9..004d25d 100644
--- a/av1/common/idct.h
+++ b/av1/common/idct.h
@@ -44,12 +44,6 @@
   return (const int32_t *)input;
 }
 
-typedef void(highbd_inv_txfm_add)(const tran_low_t *input, uint8_t *dest,
-                                  int stride, const TxfmParam *param);
-
-highbd_inv_txfm_add av1_highbd_inv_txfm_add_16x4;
-highbd_inv_txfm_add av1_highbd_inv_txfm_add_4x16;
-
 #ifdef __cplusplus
 }  // extern "C"
 #endif
diff --git a/av1/common/x86/highbd_inv_txfm_avx2.c b/av1/common/x86/highbd_inv_txfm_avx2.c
index 7e53525..9a1224c 100644
--- a/av1/common/x86/highbd_inv_txfm_avx2.c
+++ b/av1/common/x86/highbd_inv_txfm_avx2.c
@@ -1320,10 +1320,10 @@
       av1_highbd_inv_txfm_add_4x4_sse4_1(input, dest, stride, txfm_param);
       break;
     case TX_16X4:
-      av1_highbd_inv_txfm_add_16x4(input, dest, stride, txfm_param);
+      av1_highbd_inv_txfm_add_16x4_sse4_1(input, dest, stride, txfm_param);
       break;
     case TX_4X16:
-      av1_highbd_inv_txfm_add_4x16(input, dest, stride, txfm_param);
+      av1_highbd_inv_txfm_add_4x16_sse4_1(input, dest, stride, txfm_param);
       break;
     case TX_8X32:
       av1_highbd_inv_txfm_add_8x32_sse4_1(input, dest, stride, txfm_param);
diff --git a/av1/common/x86/highbd_inv_txfm_sse4.c b/av1/common/x86/highbd_inv_txfm_sse4.c
index 254abcd..41afae0 100644
--- a/av1/common/x86/highbd_inv_txfm_sse4.c
+++ b/av1/common/x86/highbd_inv_txfm_sse4.c
@@ -5480,6 +5480,128 @@
                                  txfm_size_row, bd);
 }
 
+static void highbd_inv_txfm2d_add_4x16_sse4_1(const int32_t *input,
+                                              uint16_t *output, int stride,
+                                              TX_TYPE tx_type, TX_SIZE tx_size,
+                                              int eob, const int bd) {
+  (void)eob;
+  __m128i buf1[16];
+  const int8_t *shift = inv_txfm_shift_ls[tx_size];
+  const int txw_idx = get_txw_idx(tx_size);
+  const int txh_idx = get_txh_idx(tx_size);
+  const int txfm_size_col = tx_size_wide[tx_size];
+  const int txfm_size_row = tx_size_high[tx_size];
+  const int buf_size_h_div8 = txfm_size_row >> 2;
+  const transform_1d_sse4_1 row_txfm =
+      highbd_txfm_all_1d_zeros_w8_arr[txw_idx][hitx_1d_tab[tx_type]][0];
+  const transform_1d_sse4_1 col_txfm =
+      highbd_txfm_all_1d_zeros_w8_arr[txh_idx][vitx_1d_tab[tx_type]][2];
+  const int input_stride = AOMMIN(32, txfm_size_col);
+
+  assert(col_txfm != NULL);
+  assert(row_txfm != NULL);
+  int ud_flip, lr_flip;
+  get_flip_cfg(tx_type, &ud_flip, &lr_flip);
+
+  // 1st stage: row transform
+  __m128i buf0[16];
+  const int32_t *input_row = input;
+  __m128i *buf0_cur = buf0;
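+  // Load the 4x16 coefficient block, one 4-sample row per vector.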
+  load_buffer_32bit_input(input_row, input_stride, buf0_cur, txfm_size_row);
+  for (int i = 0; i < (txfm_size_row >> 2); i++) {
+    row_txfm(buf0 + (i << 2), buf0 + (i << 2),
+             inv_cos_bit_row[txw_idx][txh_idx], 0, bd, -shift[0]);
+  }
+
+  av1_round_shift_array_32_sse4_1(buf0, buf0, txfm_size_row, -shift[0]);
+
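+  // Transpose each 4x4 sub-block ahead of the 16-point column transform;
+  // the input vector order is reversed for left-right flipped tx types.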
+  if (lr_flip) {
+    for (int j = 0; j < buf_size_h_div8; ++j) {
+      TRANSPOSE_4X4(buf0[4 * j + 3], buf0[4 * j + 2], buf0[4 * j + 1],
+                    buf0[4 * j], buf1[4 * j], buf1[4 * j + 1], buf1[4 * j + 2],
+                    buf1[4 * j + 3]);
+    }
+  } else {
+    for (int j = 0; j < buf_size_h_div8; ++j) {
+      TRANSPOSE_4X4(buf0[4 * j], buf0[4 * j + 1], buf0[4 * j + 2],
+                    buf0[4 * j + 3], buf1[4 * j], buf1[4 * j + 1],
+                    buf1[4 * j + 2], buf1[4 * j + 3]);
+    }
+  }
+
+  // 2nd stage: column transform
+  col_txfm(buf1, buf1, inv_cos_bit_col[txw_idx][txh_idx], 1, bd, 0);
+
+  av1_round_shift_array_32_sse4_1(buf1, buf1, txfm_size_row, -shift[1]);
+
+  // write to buffer
+  highbd_write_buffer_4xn_sse4_1(buf1, output, stride, ud_flip, txfm_size_row,
+                                 bd);
+}
+
+static void highbd_inv_txfm2d_add_16x4_sse4_1(const int32_t *input,
+                                              uint16_t *output, int stride,
+                                              TX_TYPE tx_type, TX_SIZE tx_size,
+                                              int eob, const int bd) {
+  (void)eob;
+  __m128i buf1[16];
+  const int8_t *shift = inv_txfm_shift_ls[tx_size];
+  const int txw_idx = get_txw_idx(tx_size);
+  const int txh_idx = get_txh_idx(tx_size);
+  const int txfm_size_col = tx_size_wide[tx_size];
+  const int txfm_size_row = tx_size_high[tx_size];
+  const int buf_size_w_div8 = txfm_size_col >> 2;
+  const transform_1d_sse4_1 row_txfm =
+      highbd_txfm_all_1d_zeros_w8_arr[txw_idx][hitx_1d_tab[tx_type]][2];
+  const transform_1d_sse4_1 col_txfm =
+      highbd_txfm_all_1d_zeros_w8_arr[txh_idx][vitx_1d_tab[tx_type]][0];
+
+  assert(col_txfm != NULL);
+  assert(row_txfm != NULL);
+  int ud_flip, lr_flip;
+  get_flip_cfg(tx_type, &ud_flip, &lr_flip);
+
+  // 1st stage: row transform
+  __m128i buf0[16];
+  const int32_t *input_row = input;
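+  // Load the 16x4 coefficient block; each 16-sample row spans four vectors.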
+  load_buffer_32bit_input(input_row, 4, buf0, txfm_size_col);
+
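+  // Transpose the four 16-sample rows so that each vector holds one column;
+  // the 16-point row transform then processes all four rows in parallel, one
+  // row per lane.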
+  for (int j = 0; j < buf_size_w_div8; j++) {
+    TRANSPOSE_4X4(buf0[j], buf0[j + 4], buf0[j + 8], buf0[j + 12], buf1[4 * j],
+                  buf1[4 * j + 1], buf1[4 * j + 2], buf1[4 * j + 3]);
+  }
+  row_txfm(buf1, buf0, inv_cos_bit_row[txw_idx][txh_idx], 0, bd, -shift[0]);
+
+  __m128i *buf1_ptr;
+  if (lr_flip) {
+    flip_buf_sse2(buf0, buf1, txfm_size_col);
+    buf1_ptr = buf1;
+  } else {
+    buf1_ptr = buf0;
+  }
+
+  // 2nd stage: column transform
+  for (int i = 0; i < buf_size_w_div8; i++) {
+    col_txfm(buf1_ptr + i * txfm_size_row, buf1_ptr + i * txfm_size_row,
+             inv_cos_bit_col[txw_idx][txh_idx], 1, bd, 0);
+  }
+  av1_round_shift_array_32_sse4_1(buf1_ptr, buf1_ptr, txfm_size_col, -shift[1]);
+
+  // write to buffer
+  for (int i = 0; i < (txfm_size_col >> 3); i++) {
+    highbd_write_buffer_8xn_sse4_1(buf1_ptr + i * txfm_size_row * 2,
+                                   output + 8 * i, stride, ud_flip,
+                                   txfm_size_row, bd);
+  }
+}
+
 void av1_highbd_inv_txfm2d_add_universe_sse4_1(const int32_t *input,
                                                uint8_t *output, int stride,
                                                TX_TYPE tx_type, TX_SIZE tx_size,
@@ -5558,6 +5680,64 @@
   }
 }
 
+void av1_highbd_inv_txfm_add_4x16_sse4_1(const tran_low_t *input, uint8_t *dest,
+                                         int stride,
+                                         const TxfmParam *txfm_param) {
+  int bd = txfm_param->bd;
+  const TX_TYPE tx_type = txfm_param->tx_type;
+  const TX_SIZE tx_size = txfm_param->tx_size;
+  const int32_t *src = cast_to_int32(input);
+  int eob = txfm_param->eob;
+  switch (tx_type) {
+      // Assembly version doesn't support some transform types, so use C version
+      // for those.
+    case V_DCT:
+    case H_DCT:
+    case V_ADST:
+    case H_ADST:
+    case V_FLIPADST:
+    case H_FLIPADST:
+    case IDTX:
+      av1_inv_txfm2d_add_4x16_c(src, CONVERT_TO_SHORTPTR(dest), stride, tx_type,
+                                bd);
+      break;
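+    // The remaining DCT/ADST/FLIPADST combinations use the SSE4.1 path.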
+    default:
+      highbd_inv_txfm2d_add_4x16_sse4_1(input, CONVERT_TO_SHORTPTR(dest),
+                                        stride, tx_type, tx_size, eob, bd);
+      break;
+  }
+}
+
+void av1_highbd_inv_txfm_add_16x4_sse4_1(const tran_low_t *input, uint8_t *dest,
+                                         int stride,
+                                         const TxfmParam *txfm_param) {
+  int bd = txfm_param->bd;
+  const TX_TYPE tx_type = txfm_param->tx_type;
+  const TX_SIZE tx_size = txfm_param->tx_size;
+  const int32_t *src = cast_to_int32(input);
+  int eob = txfm_param->eob;
+  switch (tx_type) {
+      // Assembly version doesn't support some transform types, so use C version
+      // for those.
+    case V_DCT:
+    case H_DCT:
+    case V_ADST:
+    case H_ADST:
+    case V_FLIPADST:
+    case H_FLIPADST:
+    case IDTX:
+      av1_inv_txfm2d_add_16x4_c(src, CONVERT_TO_SHORTPTR(dest), stride, tx_type,
+                                bd);
+      break;
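+    // The remaining DCT/ADST/FLIPADST combinations use the SSE4.1 path.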
+    default:
+      highbd_inv_txfm2d_add_16x4_sse4_1(input, CONVERT_TO_SHORTPTR(dest),
+                                        stride, tx_type, tx_size, eob, bd);
+      break;
+  }
+}
+
 void av1_highbd_inv_txfm_add_sse4_1(const tran_low_t *input, uint8_t *dest,
                                     int stride, const TxfmParam *txfm_param) {
   assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
@@ -5594,10 +5774,10 @@
       av1_highbd_inv_txfm_add_4x4_sse4_1(input, dest, stride, txfm_param);
       break;
     case TX_16X4:
-      av1_highbd_inv_txfm_add_16x4(input, dest, stride, txfm_param);
+      av1_highbd_inv_txfm_add_16x4_sse4_1(input, dest, stride, txfm_param);
       break;
     case TX_4X16:
-      av1_highbd_inv_txfm_add_4x16(input, dest, stride, txfm_param);
+      av1_highbd_inv_txfm_add_4x16_sse4_1(input, dest, stride, txfm_param);
       break;
     case TX_8X32:
       av1_highbd_inv_txfm_add_8x32_sse4_1(input, dest, stride, txfm_param);