Optimize highbd fwd_txfm modules

Added an sse4_1 variant of the 32-point identity (IDTX) transform.

Enabled sse4_1 optimizations for tx_sizes 32x32, 32x16,
16x32, 32x8, 8x32.

Module-level gains:
Tx_size    Gain w.r.t. C
32x32      1.75x
32x16      2.68x
16x32      2.38x
32x8       2.88x
8x32       2.72x

When tested on 20 frames of crowd_run_360p_10 at 1 Mbps with the
speed=1 preset, observed a ~0.12% reduction in encoder time.

Change-Id: I833c08ae30e84432b5258e1bf2c10307f3379d72
diff --git a/av1/encoder/hybrid_fwd_txfm.c b/av1/encoder/hybrid_fwd_txfm.c
index 35c4dd7..0699085 100644
--- a/av1/encoder/hybrid_fwd_txfm.c
+++ b/av1/encoder/hybrid_fwd_txfm.c
@@ -123,33 +123,15 @@
 static void highbd_fwd_txfm_16x32(const int16_t *src_diff, tran_low_t *coeff,
                                   int diff_stride, TxfmParam *txfm_param) {
   int32_t *dst_coeff = (int32_t *)coeff;
-  switch (txfm_param->tx_type) {
-    case DCT_DCT:
-      av1_fwd_txfm2d_16x32(src_diff, dst_coeff, diff_stride,
-                           txfm_param->tx_type, txfm_param->bd);
-      break;
-    case IDTX:
-      av1_fwd_txfm2d_16x32_c(src_diff, dst_coeff, diff_stride,
-                             txfm_param->tx_type, txfm_param->bd);
-      break;
-    default: assert(0);
-  }
+  av1_fwd_txfm2d_16x32(src_diff, dst_coeff, diff_stride, txfm_param->tx_type,
+                       txfm_param->bd);
 }
 
 static void highbd_fwd_txfm_32x16(const int16_t *src_diff, tran_low_t *coeff,
                                   int diff_stride, TxfmParam *txfm_param) {
   int32_t *dst_coeff = (int32_t *)coeff;
-  switch (txfm_param->tx_type) {
-    case DCT_DCT:
-      av1_fwd_txfm2d_32x16(src_diff, dst_coeff, diff_stride,
-                           txfm_param->tx_type, txfm_param->bd);
-      break;
-    case IDTX:
-      av1_fwd_txfm2d_32x16_c(src_diff, dst_coeff, diff_stride,
-                             txfm_param->tx_type, txfm_param->bd);
-      break;
-    default: assert(0);
-  }
+  av1_fwd_txfm2d_32x16(src_diff, dst_coeff, diff_stride, txfm_param->tx_type,
+                       txfm_param->bd);
 }
 
 static void highbd_fwd_txfm_16x4(const int16_t *src_diff, tran_low_t *coeff,
@@ -169,33 +151,15 @@
 static void highbd_fwd_txfm_32x8(const int16_t *src_diff, tran_low_t *coeff,
                                  int diff_stride, TxfmParam *txfm_param) {
   int32_t *dst_coeff = (int32_t *)coeff;
-  switch (txfm_param->tx_type) {
-    case DCT_DCT:
-      av1_fwd_txfm2d_32x8(src_diff, dst_coeff, diff_stride, txfm_param->tx_type,
-                          txfm_param->bd);
-      break;
-    case IDTX:
-      av1_fwd_txfm2d_32x8_c(src_diff, dst_coeff, diff_stride,
-                            txfm_param->tx_type, txfm_param->bd);
-      break;
-    default: assert(0);
-  }
+  av1_fwd_txfm2d_32x8(src_diff, dst_coeff, diff_stride, txfm_param->tx_type,
+                      txfm_param->bd);
 }
 
 static void highbd_fwd_txfm_8x32(const int16_t *src_diff, tran_low_t *coeff,
                                  int diff_stride, TxfmParam *txfm_param) {
   int32_t *dst_coeff = (int32_t *)coeff;
-  switch (txfm_param->tx_type) {
-    case DCT_DCT:
-      av1_fwd_txfm2d_8x32(src_diff, dst_coeff, diff_stride, txfm_param->tx_type,
-                          txfm_param->bd);
-      break;
-    case IDTX:
-      av1_fwd_txfm2d_8x32_c(src_diff, dst_coeff, diff_stride,
-                            txfm_param->tx_type, txfm_param->bd);
-      break;
-    default: assert(0);
-  }
+  av1_fwd_txfm2d_8x32(src_diff, dst_coeff, diff_stride, txfm_param->tx_type,
+                      txfm_param->bd);
 }
 
 static void highbd_fwd_txfm_8x8(const int16_t *src_diff, tran_low_t *coeff,
@@ -219,21 +183,7 @@
   int32_t *dst_coeff = (int32_t *)coeff;
   const TX_TYPE tx_type = txfm_param->tx_type;
   const int bd = txfm_param->bd;
-  switch (tx_type) {
-    // use the c version for anything including identity for now
-    case V_DCT:
-    case H_DCT:
-    case V_ADST:
-    case H_ADST:
-    case V_FLIPADST:
-    case H_FLIPADST:
-    case IDTX:
-      av1_fwd_txfm2d_32x32_c(src_diff, dst_coeff, diff_stride, tx_type, bd);
-      break;
-    default:
-      av1_fwd_txfm2d_32x32(src_diff, dst_coeff, diff_stride, tx_type, bd);
-      break;
-  }
+  av1_fwd_txfm2d_32x32(src_diff, dst_coeff, diff_stride, tx_type, bd);
 }
 
 static void highbd_fwd_txfm_32x64(const int16_t *src_diff, tran_low_t *coeff,
diff --git a/av1/encoder/x86/av1_fwd_txfm1d_sse4.c b/av1/encoder/x86/av1_fwd_txfm1d_sse4.c
index 5da723f..865ac31 100644
--- a/av1/encoder/x86/av1_fwd_txfm1d_sse4.c
+++ b/av1/encoder/x86/av1_fwd_txfm1d_sse4.c
@@ -1407,3 +1407,11 @@
   output[startidx] = x10[62];
   output[endidx] = x10[1];
 }
+
+void av1_idtx32_new_sse4_1(__m128i *input, __m128i *output, int cos_bit,
+                           const int col_num) {
+  (void)cos_bit;
+  for (int i = 0; i < 32; i++) {
+    output[i * col_num] = _mm_slli_epi32(input[i * col_num], 2);
+  }
+}
diff --git a/av1/encoder/x86/av1_fwd_txfm2d_sse4.c b/av1/encoder/x86/av1_fwd_txfm2d_sse4.c
index c2903d2..193f9d1 100644
--- a/av1/encoder/x86/av1_fwd_txfm2d_sse4.c
+++ b/av1/encoder/x86/av1_fwd_txfm2d_sse4.c
@@ -55,11 +55,20 @@
                           col_num);
   }
 }
+static void idtx32x32_sse4_1(__m128i *input, __m128i *output,
+                             const int8_t cos_bit, const int8_t *stage_range) {
+  (void)stage_range;
+
+  for (int i = 0; i < 8; i++) {
+    av1_idtx32_new_sse4_1(&input[i * 32], &output[i * 32], cos_bit, 1);
+  }
+}
 
 static INLINE TxfmFuncSSE2 fwd_txfm_type_to_func(TXFM_TYPE txfm_type) {
   switch (txfm_type) {
     case TXFM_TYPE_DCT32: return fdct32_new_sse4_1; break;
     case TXFM_TYPE_DCT64: return fdct64_new_sse4_1; break;
+    case TXFM_TYPE_IDENTITY32: return idtx32x32_sse4_1; break;
     default: assert(0);
   }
   return NULL;
diff --git a/av1/encoder/x86/av1_txfm1d_sse4.h b/av1/encoder/x86/av1_txfm1d_sse4.h
index ccdc327..b3d5b22 100644
--- a/av1/encoder/x86/av1_txfm1d_sse4.h
+++ b/av1/encoder/x86/av1_txfm1d_sse4.h
@@ -30,7 +30,6 @@
                            const int stride);
 void av1_fdct64_new_sse4_1(__m128i *input, __m128i *output, int8_t cos_bit,
                            const int instride, const int outstride);
-
 void av1_fadst4_new_sse4_1(const __m128i *input, __m128i *output,
                            const int8_t cos_bit, const int8_t *stage_range);
 void av1_fadst8_new_sse4_1(const __m128i *input, __m128i *output,
@@ -55,6 +54,10 @@
                            const int8_t cos_bit, const int8_t *stage_range);
 void av1_iadst16_new_sse4_1(const __m128i *input, __m128i *output,
                             const int8_t cos_bit, const int8_t *stage_range);
+
+void av1_idtx32_new_sse4_1(__m128i *input, __m128i *output, int cos_bit,
+                           const int col_num);
+
 static INLINE void transpose_32_4x4(int stride, const __m128i *input,
                                     __m128i *output) {
   __m128i temp0 = _mm_unpacklo_epi32(input[0 * stride], input[2 * stride]);
diff --git a/av1/encoder/x86/highbd_fwd_txfm_sse4.c b/av1/encoder/x86/highbd_fwd_txfm_sse4.c
index bef8abc..d105977 100644
--- a/av1/encoder/x86/highbd_fwd_txfm_sse4.c
+++ b/av1/encoder/x86/highbd_fwd_txfm_sse4.c
@@ -827,6 +827,20 @@
     out[7 + 8 * i] = _mm_add_epi32(in[7 + 8 * i], in[7 + 8 * i]);
   }
 }
+static void idtx32x8_sse4_1(__m128i *in, __m128i *out, int bit, int col_num) {
+  (void)bit;
+  (void)col_num;
+  for (int j = 0; j < 2; j++) {
+    out[j + 8 * 0] = _mm_add_epi32(in[j + 8 * 0], in[j + 8 * 0]);
+    out[j + 8 * 1] = _mm_add_epi32(in[j + 8 * 1], in[j + 8 * 1]);
+    out[j + 8 * 2] = _mm_add_epi32(in[j + 8 * 2], in[j + 8 * 2]);
+    out[j + 8 * 3] = _mm_add_epi32(in[j + 8 * 3], in[j + 8 * 3]);
+    out[j + 8 * 4] = _mm_add_epi32(in[j + 8 * 4], in[j + 8 * 4]);
+    out[j + 8 * 5] = _mm_add_epi32(in[j + 8 * 5], in[j + 8 * 5]);
+    out[j + 8 * 6] = _mm_add_epi32(in[j + 8 * 6], in[j + 8 * 6]);
+    out[j + 8 * 7] = _mm_add_epi32(in[j + 8 * 7], in[j + 8 * 7]);
+  }
+}
 void av1_fwd_txfm2d_8x8_sse4_1(const int16_t *input, int32_t *coeff, int stride,
                                TX_TYPE tx_type, int bd) {
   __m128i in[16], out[16];
@@ -1913,7 +1927,24 @@
   fadst8x8_sse4_1,  // V_FLIPADST
   idtx8x8_sse4_1    // H_FLIPADST
 };
-
+static const fwd_transform_1d_sse4_1 row_highbd_txfm32x8_arr[TX_TYPES] = {
+  fdct8x8_sse4_1,   // DCT_DCT
+  NULL,             // ADST_DCT
+  NULL,             // DCT_ADST
+  NULL,             // ADST_ADST
+  NULL,             // FLIPADST_DCT
+  NULL,             // DCT_FLIPADST
+  NULL,             // FLIPADST_FLIPADST
+  NULL,             // ADST_FLIPADST
+  NULL,             // FLIPADST_ADST
+  idtx32x8_sse4_1,  // IDTX
+  NULL,             // V_DCT
+  NULL,             // H_DCT
+  NULL,             // V_ADST
+  NULL,             // H_ADST
+  NULL,             // V_FLIPADST
+  NULL,             // H_FLIPADST
+};
 static const fwd_transform_1d_sse4_1 col_highbd_txfm4x8_arr[TX_TYPES] = {
   fdct4x8_sse4_1,   // DCT_DCT
   fadst8x8_sse4_1,  // ADST_DCT
@@ -2056,7 +2087,7 @@
   NULL,                   // FLIPADST_FLIPADST
   NULL,                   // ADST_FLIPADST
   NULL,                   // FLIPADST_ADST
-  NULL,                   // IDTX
+  av1_idtx32_new_sse4_1,  // IDTX
   NULL,                   // V_DCT
   NULL,                   // H_DCT
   NULL,                   // V_ADST
@@ -2075,7 +2106,7 @@
   NULL,              // FLIPADST_FLIPADST
   NULL,              // ADST_FLIPADST
   NULL,              // FLIPADST_ADST
-  NULL,              // IDTX
+  idtx16x16_sse4_1,  // IDTX
   NULL,              // V_DCT
   NULL,              // H_DCT
   NULL,              // V_ADST
@@ -2209,7 +2240,6 @@
 
 void av1_fwd_txfm2d_16x32_sse4_1(const int16_t *input, int32_t *coeff,
                                  int stride, TX_TYPE tx_type, int bd) {
-  assert(DCT_DCT == tx_type);
   __m128i in[128];
   __m128i *outcoef128 = (__m128i *)coeff;
   const int8_t *shift = fwd_txfm_shift_ls[TX_16X32];
@@ -2321,7 +2351,6 @@
 
 void av1_fwd_txfm2d_32x16_sse4_1(const int16_t *input, int32_t *coeff,
                                  int stride, TX_TYPE tx_type, int bd) {
-  assert(DCT_DCT == tx_type);
   __m128i in[128];
   __m128i *outcoef128 = (__m128i *)coeff;
   const int8_t *shift = fwd_txfm_shift_ls[TX_32X16];
@@ -2351,14 +2380,13 @@
 
 void av1_fwd_txfm2d_8x32_sse4_1(const int16_t *input, int32_t *coeff,
                                 int stride, TX_TYPE tx_type, int bd) {
-  assert(DCT_DCT == tx_type);
   __m128i in[64];
   __m128i *outcoef128 = (__m128i *)coeff;
   const int8_t *shift = fwd_txfm_shift_ls[TX_8X32];
   const int txw_idx = get_txw_idx(TX_8X32);
   const int txh_idx = get_txh_idx(TX_8X32);
   const fwd_transform_1d_sse4_1 col_txfm = col_highbd_txfm8x32_arr[tx_type];
-  const fwd_transform_1d_sse4_1 row_txfm = col_highbd_txfm8x8_arr[tx_type];
+  const fwd_transform_1d_sse4_1 row_txfm = row_highbd_txfm32x8_arr[tx_type];
   int bitcol = fwd_cos_bit_col[txw_idx][txh_idx];
   int bitrow = fwd_cos_bit_row[txw_idx][txh_idx];
 
@@ -2387,13 +2415,12 @@
 
 void av1_fwd_txfm2d_32x8_sse4_1(const int16_t *input, int32_t *coeff,
                                 int stride, TX_TYPE tx_type, int bd) {
-  assert(DCT_DCT == tx_type);
   __m128i in[64];
   __m128i *outcoef128 = (__m128i *)coeff;
   const int8_t *shift = fwd_txfm_shift_ls[TX_32X8];
   const int txw_idx = get_txw_idx(TX_32X8);
   const int txh_idx = get_txh_idx(TX_32X8);
-  const fwd_transform_1d_sse4_1 col_txfm = col_highbd_txfm8x8_arr[tx_type];
+  const fwd_transform_1d_sse4_1 col_txfm = row_highbd_txfm32x8_arr[tx_type];
   const fwd_transform_1d_sse4_1 row_txfm = col_highbd_txfm8x32_arr[tx_type];
   int bitcol = fwd_cos_bit_col[txw_idx][txh_idx];
   int bitrow = fwd_cos_bit_row[txw_idx][txh_idx];