Optimize highbd inv_txfm modules

Add SSE4_1 variants for horizontal identity txfm types
(V_DCT, V_ADST and V_FLIPADST).

Module level gains:
Tx_size   Gain w.r.t. C
8x8         7.53x
8x16        12.58x
16x8        7.31x
16x16       14.15x

When tested across multiple test cases, a 1.0% average
reduction in encoder time was observed for the speed = 1 preset.

Change-Id: Ida4e806a6a6a2463953b440d43ae0ded2808352c
diff --git a/av1/common/x86/highbd_inv_txfm_avx2.c b/av1/common/x86/highbd_inv_txfm_avx2.c
index 5418057..2fe5777 100644
--- a/av1/common/x86/highbd_inv_txfm_avx2.c
+++ b/av1/common/x86/highbd_inv_txfm_avx2.c
@@ -4322,16 +4322,20 @@
   switch (tx_type) {
       // Assembly version doesn't support some transform types, so use C version
       // for those.
-    case V_DCT:
     case H_DCT:
-    case V_ADST:
     case H_ADST:
-    case V_FLIPADST:
     case H_FLIPADST:
     case IDTX:
       av1_inv_txfm2d_add_16x16_c(src, CONVERT_TO_SHORTPTR(dest), stride,
                                  tx_type, bd);
       break;
+    case V_DCT:
+    case V_ADST:
+    case V_FLIPADST:
+      av1_highbd_inv_txfm2d_add_universe_sse4_1(input, dest, stride, tx_type,
+                                                txfm_param->tx_size,
+                                                txfm_param->eob, bd);
+      break;
     default:
       av1_highbd_inv_txfm2d_add_universe_avx2(input, dest, stride, tx_type,
                                               txfm_param->tx_size,
@@ -4409,16 +4413,20 @@
   switch (tx_type) {
       // Assembly version doesn't support some transform types, so use C version
       // for those.
-    case V_DCT:
     case H_DCT:
-    case V_ADST:
     case H_ADST:
-    case V_FLIPADST:
     case H_FLIPADST:
     case IDTX:
       av1_inv_txfm2d_add_8x8_c(src, CONVERT_TO_SHORTPTR(dest), stride, tx_type,
                                bd);
       break;
+    case V_DCT:
+    case V_ADST:
+    case V_FLIPADST:
+      av1_highbd_inv_txfm2d_add_universe_sse4_1(input, dest, stride, tx_type,
+                                                txfm_param->tx_size,
+                                                txfm_param->eob, bd);
+      break;
     default:
       av1_highbd_inv_txfm2d_add_universe_avx2(input, dest, stride, tx_type,
                                               txfm_param->tx_size,
@@ -4474,16 +4482,20 @@
   switch (tx_type) {
       // Assembly version doesn't support some transform types, so use C version
       // for those.
-    case V_DCT:
     case H_DCT:
-    case V_ADST:
     case H_ADST:
-    case V_FLIPADST:
     case H_FLIPADST:
     case IDTX:
       av1_inv_txfm2d_add_16x8_c(src, CONVERT_TO_SHORTPTR(dest), stride,
                                 txfm_param->tx_type, txfm_param->bd);
       break;
+    case V_DCT:
+    case V_ADST:
+    case V_FLIPADST:
+      av1_highbd_inv_txfm2d_add_universe_sse4_1(input, dest, stride, tx_type,
+                                                txfm_param->tx_size,
+                                                txfm_param->eob, bd);
+      break;
     default:
       av1_highbd_inv_txfm2d_add_universe_avx2(input, dest, stride, tx_type,
                                               txfm_param->tx_size,
@@ -4501,16 +4513,20 @@
   switch (tx_type) {
       // Assembly version doesn't support some transform types, so use C version
       // for those.
-    case V_DCT:
     case H_DCT:
-    case V_ADST:
     case H_ADST:
-    case V_FLIPADST:
     case H_FLIPADST:
     case IDTX:
       av1_inv_txfm2d_add_8x16_c(src, CONVERT_TO_SHORTPTR(dest), stride,
                                 txfm_param->tx_type, txfm_param->bd);
       break;
+    case V_DCT:
+    case V_ADST:
+    case V_FLIPADST:
+      av1_highbd_inv_txfm2d_add_universe_sse4_1(input, dest, stride, tx_type,
+                                                txfm_param->tx_size,
+                                                txfm_param->eob, bd);
+      break;
     default:
       av1_highbd_inv_txfm2d_add_universe_avx2(input, dest, stride, tx_type,
                                               txfm_param->tx_size,
diff --git a/av1/common/x86/highbd_inv_txfm_sse4.c b/av1/common/x86/highbd_inv_txfm_sse4.c
index 12c6350..f546adb 100644
--- a/av1/common/x86/highbd_inv_txfm_sse4.c
+++ b/av1/common/x86/highbd_inv_txfm_sse4.c
@@ -1116,6 +1116,80 @@
                      &clamp_hi_out, out_shift);
   }
 }
+static void highbd_clamp_epi32_sse4_1(const __m128i *in, __m128i *out,
+                                      const __m128i *clamp_lo,
+                                      const __m128i *clamp_hi, int size) {
+  __m128i a0, a1;
+  for (int i = 0; i < size; i += 4) {
+    a0 = _mm_max_epi32(in[i], *clamp_lo);
+    out[i] = _mm_min_epi32(a0, *clamp_hi);
+
+    a1 = _mm_max_epi32(in[i + 1], *clamp_lo);
+    out[i + 1] = _mm_min_epi32(a1, *clamp_hi);
+
+    a0 = _mm_max_epi32(in[i + 2], *clamp_lo);
+    out[i + 2] = _mm_min_epi32(a0, *clamp_hi);
+
+    a1 = _mm_max_epi32(in[i + 3], *clamp_lo);
+    out[i + 3] = _mm_min_epi32(a1, *clamp_hi);
+  }
+}
+
+static void shift_sse4_1(const __m128i *in, __m128i *out,
+                         const __m128i *clamp_lo, const __m128i *clamp_hi,
+                         int shift, int size) {
+  __m128i offset = _mm_set1_epi32((1 << shift) >> 1);
+  __m128i shift_vec = _mm_cvtsi32_si128(shift);
+  __m128i a0, a1;
+  for (int i = 0; i < size; i += 4) {
+    a0 = _mm_add_epi32(in[i], offset);
+    a1 = _mm_add_epi32(in[i + 1], offset);
+    a0 = _mm_sra_epi32(a0, shift_vec);
+    a1 = _mm_sra_epi32(a1, shift_vec);
+    a0 = _mm_max_epi32(a0, *clamp_lo);
+    a1 = _mm_max_epi32(a1, *clamp_lo);
+    out[i] = _mm_min_epi32(a0, *clamp_hi);
+    out[i + 1] = _mm_min_epi32(a1, *clamp_hi);
+
+    a0 = _mm_add_epi32(in[i + 2], offset);
+    a1 = _mm_add_epi32(in[i + 3], offset);
+    a0 = _mm_sra_epi32(a0, shift_vec);
+    a1 = _mm_sra_epi32(a1, shift_vec);
+    a0 = _mm_max_epi32(a0, *clamp_lo);
+    a1 = _mm_max_epi32(a1, *clamp_lo);
+    out[i + 2] = _mm_min_epi32(a0, *clamp_hi);
+    out[i + 3] = _mm_min_epi32(a1, *clamp_hi);
+  }
+}
+
+static void iidentity8_sse4_1(__m128i *in, __m128i *out, int bit, int do_cols,
+                              int bd, int out_shift) {
+  (void)bit;
+  const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8));
+  const __m128i clamp_lo = _mm_set1_epi32(-(1 << (log_range - 1)));
+  const __m128i clamp_hi = _mm_set1_epi32((1 << (log_range - 1)) - 1);
+  __m128i v[8];
+  v[0] = _mm_add_epi32(in[0], in[0]);
+  v[1] = _mm_add_epi32(in[1], in[1]);
+  v[2] = _mm_add_epi32(in[2], in[2]);
+  v[3] = _mm_add_epi32(in[3], in[3]);
+  v[4] = _mm_add_epi32(in[4], in[4]);
+  v[5] = _mm_add_epi32(in[5], in[5]);
+  v[6] = _mm_add_epi32(in[6], in[6]);
+  v[7] = _mm_add_epi32(in[7], in[7]);
+
+  if (!do_cols) {
+    const int log_range_out = AOMMAX(16, bd + 6);
+    const __m128i clamp_lo_out = _mm_set1_epi32(AOMMAX(
+        -(1 << (log_range_out - 1)), -(1 << (log_range - 1 - out_shift))));
+    const __m128i clamp_hi_out = _mm_set1_epi32(AOMMIN(
+        (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift))));
+
+    shift_sse4_1(v, out, &clamp_lo_out, &clamp_hi_out, out_shift, 8);
+  } else {
+    highbd_clamp_epi32_sse4_1(v, out, &clamp_lo, &clamp_hi, 8);
+  }
+}
 
 static void round_shift_8x8(__m128i *in, int shift) {
   round_shift_4x4(&in[0], shift);
@@ -3000,7 +3074,59 @@
     }
   }
 }
+static void iidentity16_sse4_1(__m128i *in, __m128i *out, int bit, int do_cols,
+                               int bd, int out_shift) {
+  (void)bit;
+  const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8));
+  const __m128i clamp_lo = _mm_set1_epi32(-(1 << (log_range - 1)));
+  const __m128i clamp_hi = _mm_set1_epi32((1 << (log_range - 1)) - 1);
+  __m128i v[16];
+  __m128i fact = _mm_set1_epi32(2 * NewSqrt2);
+  __m128i offset = _mm_set1_epi32(1 << (NewSqrt2Bits - 1));
+  __m128i a0, a1, a2, a3;
 
+  for (int i = 0; i < 16; i += 8) {
+    a0 = _mm_mullo_epi32(in[i], fact);
+    a1 = _mm_mullo_epi32(in[i + 1], fact);
+    a0 = _mm_add_epi32(a0, offset);
+    a1 = _mm_add_epi32(a1, offset);
+    v[i] = _mm_srai_epi32(a0, NewSqrt2Bits);
+    v[i + 1] = _mm_srai_epi32(a1, NewSqrt2Bits);
+
+    a2 = _mm_mullo_epi32(in[i + 2], fact);
+    a3 = _mm_mullo_epi32(in[i + 3], fact);
+    a2 = _mm_add_epi32(a2, offset);
+    a3 = _mm_add_epi32(a3, offset);
+    v[i + 2] = _mm_srai_epi32(a2, NewSqrt2Bits);
+    v[i + 3] = _mm_srai_epi32(a3, NewSqrt2Bits);
+
+    a0 = _mm_mullo_epi32(in[i + 4], fact);
+    a1 = _mm_mullo_epi32(in[i + 5], fact);
+    a0 = _mm_add_epi32(a0, offset);
+    a1 = _mm_add_epi32(a1, offset);
+    v[i + 4] = _mm_srai_epi32(a0, NewSqrt2Bits);
+    v[i + 5] = _mm_srai_epi32(a1, NewSqrt2Bits);
+
+    a2 = _mm_mullo_epi32(in[i + 6], fact);
+    a3 = _mm_mullo_epi32(in[i + 7], fact);
+    a2 = _mm_add_epi32(a2, offset);
+    a3 = _mm_add_epi32(a3, offset);
+    v[i + 6] = _mm_srai_epi32(a2, NewSqrt2Bits);
+    v[i + 7] = _mm_srai_epi32(a3, NewSqrt2Bits);
+  }
+
+  if (!do_cols) {
+    const int log_range_out = AOMMAX(16, bd + 6);
+    const __m128i clamp_lo_out = _mm_set1_epi32(AOMMAX(
+        -(1 << (log_range_out - 1)), -(1 << (log_range - 1 - out_shift))));
+    const __m128i clamp_hi_out = _mm_set1_epi32(AOMMIN(
+        (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift))));
+
+    shift_sse4_1(v, out, &clamp_lo_out, &clamp_hi_out, out_shift, 16);
+  } else {
+    highbd_clamp_epi32_sse4_1(v, out, &clamp_lo, &clamp_hi, 16);
+  }
+}
 static INLINE void idct64_stage8_sse4_1(
     __m128i *u, const __m128i *cospim32, const __m128i *cospi32,
     const __m128i *cospim16, const __m128i *cospi48, const __m128i *cospi16,
@@ -5022,16 +5148,20 @@
   switch (tx_type) {
       // Assembly version doesn't support some transform types, so use C version
       // for those.
-    case V_DCT:
     case H_DCT:
-    case V_ADST:
     case H_ADST:
-    case V_FLIPADST:
     case H_FLIPADST:
     case IDTX:
       av1_inv_txfm2d_add_8x8_c(src, CONVERT_TO_SHORTPTR(dest), stride, tx_type,
                                bd);
       break;
+    case V_DCT:
+    case V_ADST:
+    case V_FLIPADST:
+      av1_highbd_inv_txfm2d_add_universe_sse4_1(input, dest, stride, tx_type,
+                                                txfm_param->tx_size,
+                                                txfm_param->eob, bd);
+      break;
     default:
       av1_inv_txfm2d_add_8x8_sse4_1(src, CONVERT_TO_SHORTPTR(dest), stride,
                                     tx_type, bd);
@@ -5048,11 +5178,8 @@
   switch (tx_type) {
       // Assembly version doesn't support some transform types, so use C version
       // for those.
-    case V_DCT:
     case H_DCT:
-    case V_ADST:
     case H_ADST:
-    case V_FLIPADST:
     case H_FLIPADST:
     case IDTX:
       av1_inv_txfm2d_add_16x8_c(src, CONVERT_TO_SHORTPTR(dest), stride,
@@ -5075,11 +5202,8 @@
   switch (tx_type) {
       // Assembly version doesn't support some transform types, so use C version
       // for those.
-    case V_DCT:
     case H_DCT:
-    case V_ADST:
     case H_ADST:
-    case V_FLIPADST:
     case H_FLIPADST:
     case IDTX:
       av1_inv_txfm2d_add_8x16_c(src, CONVERT_TO_SHORTPTR(dest), stride,
@@ -5102,11 +5226,8 @@
   switch (tx_type) {
       // Assembly version doesn't support some transform types, so use C version
       // for those.
-    case V_DCT:
     case H_DCT:
-    case V_ADST:
     case H_ADST:
-    case V_FLIPADST:
     case H_FLIPADST:
     case IDTX:
       av1_inv_txfm2d_add_16x16_c(src, CONVERT_TO_SHORTPTR(dest), stride,
@@ -5264,13 +5385,13 @@
       },
       { { idct8x8_low1_sse4_1, idct8x8_new_sse4_1, NULL, NULL },
         { iadst8x8_low1_sse4_1, iadst8x8_new_sse4_1, NULL, NULL },
-        { NULL, NULL, NULL, NULL } },
+        { iidentity8_sse4_1, NULL, NULL, NULL } },
       {
           { idct16x16_low1_sse4_1, idct16x16_low8_sse4_1, idct16x16_sse4_1,
             NULL },
           { iadst16x16_low1_sse4_1, iadst16x16_low8_sse4_1, iadst16x16_sse4_1,
             NULL },
-          { NULL, NULL, NULL, NULL },
+          { iidentity16_sse4_1, NULL, NULL, NULL },
       },
       { { idct32x32_low1_sse4_1, idct32x32_low8_sse4_1, idct32x32_low16_sse4_1,
           idct32x32_sse4_1 },
@@ -5281,7 +5402,68 @@
         { NULL, NULL, NULL, NULL },
         { NULL, NULL, NULL, NULL } }
     };
+static void highbd_inv_txfm2d_add_h_identity_ssse41(const int32_t *input,
+                                                    uint16_t *output,
+                                                    int stride, TX_TYPE tx_type,
+                                                    TX_SIZE tx_size, int eob,
+                                                    const int bd) {
+  __m128i buf1[64];
+  int eobx, eoby;
+  get_eobx_eoby_scan_v_identity(&eobx, &eoby, tx_size, eob);
+  const int8_t *shift = inv_txfm_shift_ls[tx_size];
+  const int txw_idx = get_txw_idx(tx_size);
+  const int txh_idx = get_txh_idx(tx_size);
+  const int txfm_size_col = tx_size_wide[tx_size];
+  const int txfm_size_row = tx_size_high[tx_size];
+  const int input_stride = AOMMIN(32, txfm_size_col);
+  const int buf_size_w_div4 = input_stride >> 2;
+  const int buf_size_h_div8 = (eoby + 8) >> 3;
+  const int rect_type = get_rect_tx_log_ratio(txfm_size_col, txfm_size_row);
+  const int fun_idx = lowbd_txfm_all_1d_zeros_idx[eoby];
+  const transform_1d_sse4_1 row_txfm =
+      highbd_txfm_all_1d_zeros_w8_arr[txw_idx][hitx_1d_tab[tx_type]][0];
+  const transform_1d_sse4_1 col_txfm =
+      highbd_txfm_all_1d_zeros_w8_arr[txh_idx][vitx_1d_tab[tx_type]][fun_idx];
+  int ud_flip, lr_flip;
+  get_flip_cfg(tx_type, &ud_flip, &lr_flip);
 
+  for (int i = 0; i < (buf_size_h_div8 << 1); ++i) {
+    __m128i buf0[16];
+    const int32_t *input_row = input + i * input_stride * 4;
+    for (int j = 0; j < buf_size_w_div4; ++j) {
+      __m128i *buf0_cur = buf0 + j * 4;
+      load_buffer_32bit_input(input_row + j * 4, input_stride, buf0_cur, 4);
+    }
+    if (rect_type == 1 || rect_type == -1) {
+      av1_round_shift_rect_array_32_sse4_1(buf0, buf0, input_stride, 0,
+                                           NewInvSqrt2);
+    }
+    row_txfm(buf0, buf0, inv_cos_bit_row[txw_idx][txh_idx], 0, bd, -shift[0]);
+
+    __m128i *_buf1 = buf1 + i * 4;
+
+    for (int j = 0; j < buf_size_w_div4; ++j) {
+      _buf1[j * txfm_size_row + 0] = buf0[j * 4 + 0];
+      _buf1[j * txfm_size_row + 1] = buf0[j * 4 + 1];
+      _buf1[j * txfm_size_row + 2] = buf0[j * 4 + 2];
+      _buf1[j * txfm_size_row + 3] = buf0[j * 4 + 3];
+    }
+  }
+  for (int i = 0; i < buf_size_w_div4; i++) {
+    col_txfm(buf1 + i * txfm_size_row, buf1 + i * txfm_size_row,
+             inv_cos_bit_col[txw_idx][txh_idx], 1, bd, 0);
+
+    av1_round_shift_array_32_sse4_1(buf1 + i * txfm_size_row,
+                                    buf1 + i * txfm_size_row, txfm_size_row,
+                                    -shift[1]);
+  }
+
+  // write to buffer
+  for (int i = 0; i < (txfm_size_col >> 3); i++) {
+    highbd_write_buffer_8xn_sse4_1(buf1 + i * txfm_size_row * 2, output + 8 * i,
+                                   stride, ud_flip, txfm_size_row, bd);
+  }
+}
 static void highbd_inv_txfm2d_add_no_identity_sse41(const int32_t *input,
                                                     uint16_t *output,
                                                     int stride, TX_TYPE tx_type,
@@ -5613,6 +5795,13 @@
           input, CONVERT_TO_SHORTPTR(output), stride, tx_type, tx_size, eob,
           bd);
       break;
+    case V_DCT:
+    case V_ADST:
+    case V_FLIPADST:
+      highbd_inv_txfm2d_add_h_identity_ssse41(
+          input, CONVERT_TO_SHORTPTR(output), stride, tx_type, tx_size, eob,
+          bd);
+      break;
     default: assert(0); break;
   }
 }