Optimize highbd fwd_txfm modules

Enabled sse4_1 optimizations for tx_sizes 16x64,
64x16, 32x64 and 64x32

Module level gains:
Tx_size    Gain w.r.t. C
16x64        4.4x
64x16        4.0x
32x64        3.3x
64x32        3.5x

When tested for 10 frames of crowd_run_1080p_10 at 6 mbps
for speed=1 preset, observed ~1.00% reduction in encoder time.

Change-Id: I5a1a504a9915bf6f81f3df95c8bdd21689636e38
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 481d6b8..aabb001 100755
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -243,9 +243,13 @@
   add_proto qw/void av1_fwd_txfm2d_64x64/, "const int16_t *input, int32_t *output, int stride, TX_TYPE tx_type, int bd";
   specialize qw/av1_fwd_txfm2d_64x64 sse4_1/;
   add_proto qw/void av1_fwd_txfm2d_32x64/, "const int16_t *input, int32_t *output, int stride, TX_TYPE tx_type, int bd";
+  specialize qw/av1_fwd_txfm2d_32x64 sse4_1/;
   add_proto qw/void av1_fwd_txfm2d_64x32/, "const int16_t *input, int32_t *output, int stride, TX_TYPE tx_type, int bd";
+  specialize qw/av1_fwd_txfm2d_64x32 sse4_1/;
   add_proto qw/void av1_fwd_txfm2d_16x64/, "const int16_t *input, int32_t *output, int stride, TX_TYPE tx_type, int bd";
+  specialize qw/av1_fwd_txfm2d_16x64 sse4_1/;
   add_proto qw/void av1_fwd_txfm2d_64x16/, "const int16_t *input, int32_t *output, int stride, TX_TYPE tx_type, int bd";
+  specialize qw/av1_fwd_txfm2d_64x16 sse4_1/;
 
   #
   # Motion search
diff --git a/av1/encoder/hybrid_fwd_txfm.c b/av1/encoder/hybrid_fwd_txfm.c
index d3ff3eb..1b0e90f 100644
--- a/av1/encoder/hybrid_fwd_txfm.c
+++ b/av1/encoder/hybrid_fwd_txfm.c
@@ -367,7 +367,8 @@
   assert(txfm_param->tx_type == DCT_DCT);
   int32_t *dst_coeff = (int32_t *)coeff;
   const int bd = txfm_param->bd;
-  av1_fwd_txfm2d_32x64_c(src_diff, dst_coeff, diff_stride, DCT_DCT, bd);
+  av1_fwd_txfm2d_32x64(src_diff, dst_coeff, diff_stride, txfm_param->tx_type,
+                       bd);
 }
 
 static void highbd_fwd_txfm_64x32(const int16_t *src_diff, tran_low_t *coeff,
@@ -375,7 +376,8 @@
   assert(txfm_param->tx_type == DCT_DCT);
   int32_t *dst_coeff = (int32_t *)coeff;
   const int bd = txfm_param->bd;
-  av1_fwd_txfm2d_64x32_c(src_diff, dst_coeff, diff_stride, DCT_DCT, bd);
+  av1_fwd_txfm2d_64x32(src_diff, dst_coeff, diff_stride, txfm_param->tx_type,
+                       bd);
 }
 
 static void highbd_fwd_txfm_16x64(const int16_t *src_diff, tran_low_t *coeff,
@@ -383,7 +385,7 @@
   assert(txfm_param->tx_type == DCT_DCT);
   int32_t *dst_coeff = (int32_t *)coeff;
   const int bd = txfm_param->bd;
-  av1_fwd_txfm2d_16x64_c(src_diff, dst_coeff, diff_stride, DCT_DCT, bd);
+  av1_fwd_txfm2d_16x64(src_diff, dst_coeff, diff_stride, DCT_DCT, bd);
 }
 
 static void highbd_fwd_txfm_64x16(const int16_t *src_diff, tran_low_t *coeff,
@@ -391,7 +393,7 @@
   assert(txfm_param->tx_type == DCT_DCT);
   int32_t *dst_coeff = (int32_t *)coeff;
   const int bd = txfm_param->bd;
-  av1_fwd_txfm2d_64x16_c(src_diff, dst_coeff, diff_stride, DCT_DCT, bd);
+  av1_fwd_txfm2d_64x16(src_diff, dst_coeff, diff_stride, DCT_DCT, bd);
 }
 
 static void highbd_fwd_txfm_64x64(const int16_t *src_diff, tran_low_t *coeff,
diff --git a/av1/encoder/x86/highbd_fwd_txfm_sse4.c b/av1/encoder/x86/highbd_fwd_txfm_sse4.c
index a9516ca..9f28c4a 100644
--- a/av1/encoder/x86/highbd_fwd_txfm_sse4.c
+++ b/av1/encoder/x86/highbd_fwd_txfm_sse4.c
@@ -2023,6 +2023,86 @@
   (void)bd;
 }
 
+void av1_fwd_txfm2d_32x64_sse4_1(const int16_t *input, int32_t *coeff,
+                                 int stride, TX_TYPE tx_type, int bd) {
+  (void)tx_type;
+  __m128i in[512];
+  __m128i *outcoef128 = (__m128i *)coeff;
+  const int8_t *shift = fwd_txfm_shift_ls[TX_32X64];
+  const int txw_idx = get_txw_idx(TX_32X64);
+  const int txh_idx = get_txh_idx(TX_32X64);
+  const int txfm_size_col = tx_size_wide[TX_32X64];
+  const int txfm_size_row = tx_size_high[TX_32X64];
+  int bitcol = fwd_cos_bit_col[txw_idx][txh_idx];
+  int bitrow = fwd_cos_bit_row[txw_idx][txh_idx];
+  const int num_row = txfm_size_row >> 2;
+  const int num_col = txfm_size_col >> 2;
+
+  // column transform
+  load_buffer_32x8n(input, in, stride, 0, 0, shift[0], txfm_size_row);
+  for (int i = 0; i < num_col; i++) {
+    av1_fdct64_new_sse4_1((in + i), (in + i), bitcol, num_col, num_col);
+  }
+  for (int i = 0; i < num_col; i++) {
+    col_txfm_16x16_rounding((in + i * txfm_size_row), -shift[1]);
+  }
+  transpose_8nx8n(in, outcoef128, txfm_size_col, txfm_size_row);
+
+  // row transform
+  for (int i = 0; i < num_row; i++) {
+    av1_fdct32_new_sse4_1((outcoef128 + i), (in + i), bitrow, num_row);
+  }
+  transpose_8nx8n(in, outcoef128, txfm_size_row, txfm_size_col);
+  av1_round_shift_rect_array_32_sse4_1(outcoef128, outcoef128, 512, -shift[2],
+                                       NewSqrt2);
+  (void)bd;
+}
+
+void av1_fwd_txfm2d_64x32_sse4_1(const int16_t *input, int32_t *coeff,
+                                 int stride, TX_TYPE tx_type, int bd) {
+  (void)tx_type;
+  __m128i in[512];
+  __m128i *outcoef128 = (__m128i *)coeff;
+  const int8_t *shift = fwd_txfm_shift_ls[TX_64X32];
+  const int txw_idx = get_txw_idx(TX_64X32);
+  const int txh_idx = get_txh_idx(TX_64X32);
+  const int txfm_size_col = tx_size_wide[TX_64X32];
+  const int txfm_size_row = tx_size_high[TX_64X32];
+  int bitcol = fwd_cos_bit_col[txw_idx][txh_idx];
+  int bitrow = fwd_cos_bit_row[txw_idx][txh_idx];
+  const int num_row = txfm_size_row >> 2;
+  const int num_col = txfm_size_col >> 2;
+
+  // column transform
+  for (int i = 0; i < 32; i++) {
+    load_buffer_4x4(input + 0 + i * stride, in + 0 + i * 16, 4, 0, 0, shift[0]);
+    load_buffer_4x4(input + 16 + i * stride, in + 4 + i * 16, 4, 0, 0,
+                    shift[0]);
+    load_buffer_4x4(input + 32 + i * stride, in + 8 + i * 16, 4, 0, 0,
+                    shift[0]);
+    load_buffer_4x4(input + 48 + i * stride, in + 12 + i * 16, 4, 0, 0,
+                    shift[0]);
+  }
+
+  for (int i = 0; i < num_col; i++) {
+    av1_fdct32_new_sse4_1((in + i), (in + i), bitcol, num_col);
+  }
+
+  for (int i = 0; i < num_row; i++) {
+    col_txfm_16x16_rounding((in + i * txfm_size_col), -shift[1]);
+  }
+  transpose_8nx8n(in, outcoef128, txfm_size_col, txfm_size_row);
+
+  // row transform
+  for (int i = 0; i < num_row; i++) {
+    av1_fdct64_new_sse4_1((outcoef128 + i), (in + i), bitrow, num_row, num_row);
+  }
+  transpose_8nx8n(in, outcoef128, txfm_size_row, txfm_size_col >> 1);
+  av1_round_shift_rect_array_32_sse4_1(outcoef128, outcoef128, 512 >> 1,
+                                       -shift[2], NewSqrt2);
+  (void)bd;
+}
+
 void av1_fwd_txfm2d_32x16_sse4_1(const int16_t *input, int32_t *coeff,
                                  int stride, TX_TYPE tx_type, int bd) {
   assert(DCT_DCT == tx_type);
@@ -2182,3 +2262,84 @@
   transpose_8nx8n(in, outcoeff128, txfm_size_row, txfm_size_col);
   (void)bd;
 }
+
+void av1_fwd_txfm2d_16x64_sse4_1(const int16_t *input, int32_t *coeff,
+                                 int stride, TX_TYPE tx_type, int bd) {
+  __m128i in[256];
+  __m128i *outcoeff128 = (__m128i *)coeff;
+  const int8_t *shift = fwd_txfm_shift_ls[TX_16X64];
+  const int txw_idx = get_txw_idx(TX_16X64);
+  const int txh_idx = get_txh_idx(TX_16X64);
+  const int txfm_size_col = tx_size_wide[TX_16X64];
+  const int txfm_size_row = tx_size_high[TX_16X64];
+  int bitcol = fwd_cos_bit_col[txw_idx][txh_idx];
+  int bitrow = fwd_cos_bit_row[txw_idx][txh_idx];
+  int ud_flip, lr_flip;
+  get_flip_cfg(tx_type, &ud_flip, &lr_flip);
+  const int num_col = txfm_size_col >> 2;
+  // col transform
+  for (int i = 0; i < txfm_size_row; i += num_col) {
+    load_buffer_4x4(input + (i + 0) * stride, in + (i + 0) * num_col, num_col,
+                    ud_flip, lr_flip, shift[0]);
+    load_buffer_4x4(input + (i + 1) * stride, in + (i + 1) * num_col, num_col,
+                    ud_flip, lr_flip, shift[0]);
+    load_buffer_4x4(input + (i + 2) * stride, in + (i + 2) * num_col, num_col,
+                    ud_flip, lr_flip, shift[0]);
+    load_buffer_4x4(input + (i + 3) * stride, in + (i + 3) * num_col, num_col,
+                    ud_flip, lr_flip, shift[0]);
+  }
+
+  for (int i = 0; i < num_col; i++) {
+    av1_fdct64_new_sse4_1(in + i, outcoeff128 + i, bitcol, num_col, num_col);
+  }
+
+  col_txfm_16x16_rounding(outcoeff128, -shift[1]);
+  col_txfm_16x16_rounding(outcoeff128 + 64, -shift[1]);
+  col_txfm_16x16_rounding(outcoeff128 + 128, -shift[1]);
+  col_txfm_16x16_rounding(outcoeff128 + 192, -shift[1]);
+
+  transpose_8nx8n(outcoeff128, in, txfm_size_col, 32);
+  fdct16x16_sse4_1(in, in, bitrow, 8);
+  transpose_8nx8n(in, outcoeff128, 32, txfm_size_col);
+  memset(coeff + txfm_size_col * 32, 0, txfm_size_col * 32 * sizeof(*coeff));
+  (void)bd;
+}
+
+void av1_fwd_txfm2d_64x16_sse4_1(const int16_t *input, int32_t *coeff,
+                                 int stride, TX_TYPE tx_type, int bd) {
+  __m128i in[256];
+  __m128i *outcoeff128 = (__m128i *)coeff;
+  const int8_t *shift = fwd_txfm_shift_ls[TX_64X16];
+  const int txw_idx = get_txw_idx(TX_64X16);
+  const int txh_idx = get_txh_idx(TX_64X16);
+  const int txfm_size_col = tx_size_wide[TX_64X16];
+  const int txfm_size_row = tx_size_high[TX_64X16];
+  int bitcol = fwd_cos_bit_col[txw_idx][txh_idx];
+  int bitrow = fwd_cos_bit_row[txw_idx][txh_idx];
+  int ud_flip, lr_flip;
+  get_flip_cfg(tx_type, &ud_flip, &lr_flip);
+  // col transform
+  for (int i = 0; i < txfm_size_row; i++) {
+    load_buffer_4x4(input + 0 + i * stride, in + 0 + i * txfm_size_row, 4,
+                    ud_flip, lr_flip, shift[0]);
+    load_buffer_4x4(input + 16 + i * stride, in + 4 + i * txfm_size_row, 4,
+                    ud_flip, lr_flip, shift[0]);
+    load_buffer_4x4(input + 32 + i * stride, in + 8 + i * txfm_size_row, 4,
+                    ud_flip, lr_flip, shift[0]);
+    load_buffer_4x4(input + 48 + i * stride, in + 12 + i * txfm_size_row, 4,
+                    ud_flip, lr_flip, shift[0]);
+  }
+
+  fdct16x16_sse4_1(in, outcoeff128, bitcol, txfm_size_row);
+  col_txfm_16x16_rounding(outcoeff128, -shift[1]);
+  col_txfm_16x16_rounding(outcoeff128 + 64, -shift[1]);
+  col_txfm_16x16_rounding(outcoeff128 + 128, -shift[1]);
+  col_txfm_16x16_rounding(outcoeff128 + 192, -shift[1]);
+
+  transpose_8nx8n(outcoeff128, in, txfm_size_col, txfm_size_row);
+  for (int i = 0; i < 4; i++) {
+    av1_fdct64_new_sse4_1(in + i, in + i, bitrow, 4, 4);
+  }
+  transpose_8nx8n(in, outcoeff128, txfm_size_row, 32);
+  (void)bd;
+}