Optimize highbd inv_txfm modules

Added SSE4_1 variants of the horizontal, vertical and identity txfm types
for txfm block sizes 4x4, 4x8, 8x4, 4x16 and 16x4.

Module level gains:
Tx_size   Gain w.r.t. C
4x4        5.58x
4x8        4.65x
8x4        5.00x
4x16       4.63x
16x4       5.70x

When tested across multiple test cases, an average reduction of 1.38%
in encoder time was observed for the speed = 1 preset.

Change-Id: Ie6cb27dc8710a73ed48dd948ba6ba06e67067e41
diff --git a/av1/common/x86/highbd_inv_txfm_avx2.c b/av1/common/x86/highbd_inv_txfm_avx2.c
index 4de31b6..cc9e90b 100644
--- a/av1/common/x86/highbd_inv_txfm_avx2.c
+++ b/av1/common/x86/highbd_inv_txfm_avx2.c
@@ -4318,14 +4318,8 @@
                                         const TxfmParam *txfm_param) {
   int bd = txfm_param->bd;
   const TX_TYPE tx_type = txfm_param->tx_type;
-  const int32_t *src = cast_to_int32(input);
   switch (tx_type) {
-      // Assembly version doesn't support some transform types, so use C version
-      // for those.
     case IDTX:
-      av1_inv_txfm2d_add_16x16_c(src, CONVERT_TO_SHORTPTR(dest), stride,
-                                 tx_type, bd);
-      break;
     case H_DCT:
     case H_ADST:
     case H_FLIPADST:
@@ -4349,17 +4343,16 @@
                                         const TxfmParam *txfm_param) {
   const int bd = txfm_param->bd;
   const TX_TYPE tx_type = txfm_param->tx_type;
-  const int32_t *src = cast_to_int32(input);
   switch (tx_type) {
     case DCT_DCT:
       av1_highbd_inv_txfm2d_add_universe_avx2(input, dest, stride, tx_type,
                                               txfm_param->tx_size,
                                               txfm_param->eob, bd);
       break;
-      // Assembly version doesn't support IDTX, so use C version for it.
     case IDTX:
-      av1_inv_txfm2d_add_32x32_c(src, CONVERT_TO_SHORTPTR(dest), stride,
-                                 tx_type, bd);
+      av1_highbd_inv_txfm2d_add_universe_sse4_1(input, dest, stride, tx_type,
+                                                txfm_param->tx_size,
+                                                txfm_param->eob, bd);
       break;
 
     default: assert(0);
@@ -4371,7 +4364,6 @@
                                         const TxfmParam *txfm_param) {
   int bd = txfm_param->bd;
   const TX_TYPE tx_type = txfm_param->tx_type;
-  const int32_t *src = cast_to_int32(input);
   switch (tx_type) {
     case DCT_DCT:
       av1_highbd_inv_txfm2d_add_universe_avx2(input, dest, stride, tx_type,
@@ -4379,8 +4371,9 @@
                                               txfm_param->eob, bd);
       break;
     case IDTX:
-      av1_inv_txfm2d_add_16x32_c(src, CONVERT_TO_SHORTPTR(dest), stride,
-                                 txfm_param->tx_type, txfm_param->bd);
+      av1_highbd_inv_txfm2d_add_universe_sse4_1(input, dest, stride, tx_type,
+                                                txfm_param->tx_size,
+                                                txfm_param->eob, bd);
       break;
     default: assert(0);
   }
@@ -4391,7 +4384,6 @@
                                         const TxfmParam *txfm_param) {
   int bd = txfm_param->bd;
   const TX_TYPE tx_type = txfm_param->tx_type;
-  const int32_t *src = cast_to_int32(input);
   switch (tx_type) {
     case DCT_DCT:
       av1_highbd_inv_txfm2d_add_universe_avx2(input, dest, stride, tx_type,
@@ -4399,8 +4391,9 @@
                                               txfm_param->eob, bd);
       break;
     case IDTX:
-      av1_inv_txfm2d_add_32x16_c(src, CONVERT_TO_SHORTPTR(dest), stride,
-                                 txfm_param->tx_type, txfm_param->bd);
+      av1_highbd_inv_txfm2d_add_universe_sse4_1(input, dest, stride, tx_type,
+                                                txfm_param->tx_size,
+                                                txfm_param->eob, bd);
       break;
     default: assert(0);
   }
@@ -4409,14 +4402,8 @@
                                       int stride, const TxfmParam *txfm_param) {
   int bd = txfm_param->bd;
   const TX_TYPE tx_type = txfm_param->tx_type;
-  const int32_t *src = cast_to_int32(input);
   switch (tx_type) {
-      // Assembly version doesn't support some transform types, so use C version
-      // for those.
     case IDTX:
-      av1_inv_txfm2d_add_8x8_c(src, CONVERT_TO_SHORTPTR(dest), stride, tx_type,
-                               bd);
-      break;
     case H_DCT:
     case H_ADST:
     case H_FLIPADST:
@@ -4439,7 +4426,6 @@
                                        const TxfmParam *txfm_param) {
   int bd = txfm_param->bd;
   const TX_TYPE tx_type = txfm_param->tx_type;
-  const int32_t *src = cast_to_int32(input);
   switch (tx_type) {
     case DCT_DCT:
       av1_highbd_inv_txfm2d_add_universe_avx2(input, dest, stride, tx_type,
@@ -4447,8 +4433,9 @@
                                               txfm_param->eob, bd);
       break;
     case IDTX:
-      av1_inv_txfm2d_add_8x32_c(src, CONVERT_TO_SHORTPTR(dest), stride,
-                                txfm_param->tx_type, txfm_param->bd);
+      av1_highbd_inv_txfm2d_add_universe_sse4_1(input, dest, stride, tx_type,
+                                                txfm_param->tx_size,
+                                                txfm_param->eob, bd);
       break;
     default: assert(0);
   }
@@ -4459,7 +4446,6 @@
                                        const TxfmParam *txfm_param) {
   int bd = txfm_param->bd;
   const TX_TYPE tx_type = txfm_param->tx_type;
-  const int32_t *src = cast_to_int32(input);
   switch (tx_type) {
     case DCT_DCT:
       av1_highbd_inv_txfm2d_add_universe_avx2(input, dest, stride, tx_type,
@@ -4467,8 +4453,9 @@
                                               txfm_param->eob, bd);
       break;
     case IDTX:
-      av1_inv_txfm2d_add_32x8_c(src, CONVERT_TO_SHORTPTR(dest), stride,
-                                txfm_param->tx_type, txfm_param->bd);
+      av1_highbd_inv_txfm2d_add_universe_sse4_1(input, dest, stride, tx_type,
+                                                txfm_param->tx_size,
+                                                txfm_param->eob, bd);
       break;
     default: assert(0);
   }
@@ -4478,14 +4465,8 @@
                                        const TxfmParam *txfm_param) {
   int bd = txfm_param->bd;
   const TX_TYPE tx_type = txfm_param->tx_type;
-  const int32_t *src = cast_to_int32(input);
   switch (tx_type) {
-      // Assembly version doesn't support some transform types, so use C version
-      // for those.
     case IDTX:
-      av1_inv_txfm2d_add_16x8_c(src, CONVERT_TO_SHORTPTR(dest), stride,
-                                txfm_param->tx_type, txfm_param->bd);
-      break;
     case H_DCT:
     case H_ADST:
     case H_FLIPADST:
@@ -4509,14 +4490,8 @@
                                        const TxfmParam *txfm_param) {
   int bd = txfm_param->bd;
   const TX_TYPE tx_type = txfm_param->tx_type;
-  const int32_t *src = cast_to_int32(input);
   switch (tx_type) {
-      // Assembly version doesn't support some transform types, so use C version
-      // for those.
     case IDTX:
-      av1_inv_txfm2d_add_8x16_c(src, CONVERT_TO_SHORTPTR(dest), stride,
-                                txfm_param->tx_type, txfm_param->bd);
-      break;
     case H_DCT:
     case H_ADST:
     case H_FLIPADST:
diff --git a/av1/common/x86/highbd_inv_txfm_sse4.c b/av1/common/x86/highbd_inv_txfm_sse4.c
index b38bbe9..39f3548 100644
--- a/av1/common/x86/highbd_inv_txfm_sse4.c
+++ b/av1/common/x86/highbd_inv_txfm_sse4.c
@@ -583,7 +583,66 @@
   _mm_storel_epi64((__m128i *)(output + 2 * stride), v2);
   _mm_storel_epi64((__m128i *)(output + 3 * stride), v3);
 }
+static void highbd_clamp_epi32_sse4_1(const __m128i *in, __m128i *out,
+                                      const __m128i *clamp_lo,
+                                      const __m128i *clamp_hi, int size) {
+  __m128i a0, a1;
+  for (int i = 0; i < size; i += 4) {
+    a0 = _mm_max_epi32(in[i], *clamp_lo);
+    out[i] = _mm_min_epi32(a0, *clamp_hi);
 
+    a1 = _mm_max_epi32(in[i + 1], *clamp_lo);
+    out[i + 1] = _mm_min_epi32(a1, *clamp_hi);
+
+    a0 = _mm_max_epi32(in[i + 2], *clamp_lo);
+    out[i + 2] = _mm_min_epi32(a0, *clamp_hi);
+
+    a1 = _mm_max_epi32(in[i + 3], *clamp_lo);
+    out[i + 3] = _mm_min_epi32(a1, *clamp_hi);
+  }
+}
+static void iidentity4_sse4_1(__m128i *in, __m128i *out, int bit, int do_cols,
+                              int bd, int out_shift) {
+  (void)bit;
+  (void)out_shift;
+  __m128i v[4];
+  __m128i fact = _mm_set1_epi32(NewSqrt2);
+  __m128i offset = _mm_set1_epi32(1 << (NewSqrt2Bits - 1));
+  __m128i a0, a1;
+
+  a0 = _mm_mullo_epi32(in[0], fact);
+  a1 = _mm_mullo_epi32(in[1], fact);
+  a0 = _mm_add_epi32(a0, offset);
+  a1 = _mm_add_epi32(a1, offset);
+  out[0] = _mm_srai_epi32(a0, NewSqrt2Bits);
+  out[1] = _mm_srai_epi32(a1, NewSqrt2Bits);
+
+  a0 = _mm_mullo_epi32(in[2], fact);
+  a1 = _mm_mullo_epi32(in[3], fact);
+  a0 = _mm_add_epi32(a0, offset);
+  a1 = _mm_add_epi32(a1, offset);
+  out[2] = _mm_srai_epi32(a0, NewSqrt2Bits);
+  out[3] = _mm_srai_epi32(a1, NewSqrt2Bits);
+
+  if (!do_cols) {
+    const int log_range = AOMMAX(16, bd + 6);
+    const __m128i clamp_lo = _mm_set1_epi32(-(1 << (log_range - 1)));
+    const __m128i clamp_hi = _mm_set1_epi32((1 << (log_range - 1)) - 1);
+
+    highbd_clamp_epi32_sse4_1(out, out, &clamp_lo, &clamp_hi, 4);
+  }
+
+  // Transpose for 4x4
+  v[0] = _mm_unpacklo_epi32(out[0], out[1]);
+  v[1] = _mm_unpackhi_epi32(out[0], out[1]);
+  v[2] = _mm_unpacklo_epi32(out[2], out[3]);
+  v[3] = _mm_unpackhi_epi32(out[2], out[3]);
+
+  out[0] = _mm_unpacklo_epi64(v[0], v[2]);
+  out[1] = _mm_unpackhi_epi64(v[0], v[2]);
+  out[2] = _mm_unpacklo_epi64(v[1], v[3]);
+  out[3] = _mm_unpackhi_epi64(v[1], v[3]);
+}
 void av1_inv_txfm2d_add_4x4_sse4_1(const int32_t *coeff, uint16_t *output,
                                    int stride, TX_TYPE tx_type, int bd) {
   __m128i in[4];
@@ -646,6 +705,48 @@
       iadst4x4_sse4_1(in, in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd, 0);
       write_buffer_4x4(in, output, stride, 0, 1, -shift[1], bd);
       break;
+    case IDTX:
+      load_buffer_4x4(coeff, in);
+      iidentity4_sse4_1(in, in, inv_cos_bit_row[txw_idx][txh_idx], 0, bd, 0);
+      iidentity4_sse4_1(in, in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd, 0);
+      write_buffer_4x4(in, output, stride, 0, 0, -shift[1], bd);
+      break;
+    case V_DCT:
+      load_buffer_4x4(coeff, in);
+      iidentity4_sse4_1(in, in, inv_cos_bit_row[txw_idx][txh_idx], 0, bd, 0);
+      idct4x4_sse4_1(in, in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd, 0);
+      write_buffer_4x4(in, output, stride, 0, 0, -shift[1], bd);
+      break;
+    case H_DCT:
+      load_buffer_4x4(coeff, in);
+      idct4x4_sse4_1(in, in, inv_cos_bit_row[txw_idx][txh_idx], 0, bd, 0);
+      iidentity4_sse4_1(in, in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd, 0);
+      write_buffer_4x4(in, output, stride, 0, 0, -shift[1], bd);
+      break;
+    case V_ADST:
+      load_buffer_4x4(coeff, in);
+      iidentity4_sse4_1(in, in, inv_cos_bit_row[txw_idx][txh_idx], 0, bd, 0);
+      iadst4x4_sse4_1(in, in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd, 0);
+      write_buffer_4x4(in, output, stride, 0, 0, -shift[1], bd);
+      break;
+    case H_ADST:
+      load_buffer_4x4(coeff, in);
+      iadst4x4_sse4_1(in, in, inv_cos_bit_row[txw_idx][txh_idx], 0, bd, 0);
+      iidentity4_sse4_1(in, in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd, 0);
+      write_buffer_4x4(in, output, stride, 0, 0, -shift[1], bd);
+      break;
+    case V_FLIPADST:
+      load_buffer_4x4(coeff, in);
+      iidentity4_sse4_1(in, in, inv_cos_bit_row[txw_idx][txh_idx], 0, bd, 0);
+      iadst4x4_sse4_1(in, in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd, 0);
+      write_buffer_4x4(in, output, stride, 0, 1, -shift[1], bd);
+      break;
+    case H_FLIPADST:
+      load_buffer_4x4(coeff, in);
+      iadst4x4_sse4_1(in, in, inv_cos_bit_row[txw_idx][txh_idx], 0, bd, 0);
+      iidentity4_sse4_1(in, in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd, 0);
+      write_buffer_4x4(in, output, stride, 1, 0, -shift[1], bd);
+      break;
     default: assert(0);
   }
 }
@@ -1116,25 +1217,6 @@
                      &clamp_hi_out, out_shift);
   }
 }
-static void highbd_clamp_epi32_sse4_1(const __m128i *in, __m128i *out,
-                                      const __m128i *clamp_lo,
-                                      const __m128i *clamp_hi, int size) {
-  __m128i a0, a1;
-  for (int i = 0; i < size; i += 4) {
-    a0 = _mm_max_epi32(in[i], *clamp_lo);
-    out[i] = _mm_min_epi32(a0, *clamp_hi);
-
-    a1 = _mm_max_epi32(in[i + 1], *clamp_lo);
-    out[i + 1] = _mm_min_epi32(a1, *clamp_hi);
-
-    a0 = _mm_max_epi32(in[i + 2], *clamp_lo);
-    out[i + 2] = _mm_min_epi32(a0, *clamp_hi);
-
-    a1 = _mm_max_epi32(in[i + 3], *clamp_lo);
-    out[i + 3] = _mm_min_epi32(a1, *clamp_hi);
-  }
-}
-
 static void shift_sse4_1(const __m128i *in, __m128i *out,
                          const __m128i *clamp_lo, const __m128i *clamp_hi,
                          int shift, int size) {
@@ -5146,12 +5228,7 @@
   const TX_TYPE tx_type = txfm_param->tx_type;
   const int32_t *src = cast_to_int32(input);
   switch (tx_type) {
-      // Assembly version doesn't support some transform types, so use C version
-      // for those.
     case IDTX:
-      av1_inv_txfm2d_add_8x8_c(src, CONVERT_TO_SHORTPTR(dest), stride, tx_type,
-                               bd);
-      break;
     case H_DCT:
     case H_ADST:
     case H_FLIPADST:
@@ -5174,20 +5251,8 @@
                                          const TxfmParam *txfm_param) {
   int bd = txfm_param->bd;
   const TX_TYPE tx_type = txfm_param->tx_type;
-  const int32_t *src = cast_to_int32(input);
-  switch (tx_type) {
-      // Assembly version doesn't support some transform types, so use C version
-      // for those.
-    case IDTX:
-      av1_inv_txfm2d_add_16x8_c(src, CONVERT_TO_SHORTPTR(dest), stride,
-                                txfm_param->tx_type, txfm_param->bd);
-      break;
-    default:
-      av1_highbd_inv_txfm2d_add_universe_sse4_1(input, dest, stride, tx_type,
-                                                txfm_param->tx_size,
-                                                txfm_param->eob, bd);
-      break;
-  }
+  av1_highbd_inv_txfm2d_add_universe_sse4_1(
+      input, dest, stride, tx_type, txfm_param->tx_size, txfm_param->eob, bd);
 }
 
 void av1_highbd_inv_txfm_add_8x16_sse4_1(const tran_low_t *input, uint8_t *dest,
@@ -5195,20 +5260,8 @@
                                          const TxfmParam *txfm_param) {
   int bd = txfm_param->bd;
   const TX_TYPE tx_type = txfm_param->tx_type;
-  const int32_t *src = cast_to_int32(input);
-  switch (tx_type) {
-      // Assembly version doesn't support some transform types, so use C version
-      // for those.
-    case IDTX:
-      av1_inv_txfm2d_add_8x16_c(src, CONVERT_TO_SHORTPTR(dest), stride,
-                                txfm_param->tx_type, txfm_param->bd);
-      break;
-    default:
-      av1_highbd_inv_txfm2d_add_universe_sse4_1(input, dest, stride, tx_type,
-                                                txfm_param->tx_size,
-                                                txfm_param->eob, bd);
-      break;
-  }
+  av1_highbd_inv_txfm2d_add_universe_sse4_1(
+      input, dest, stride, tx_type, txfm_param->tx_size, txfm_param->eob, bd);
 }
 
 void av1_highbd_inv_txfm_add_16x16_sse4_1(const tran_low_t *input,
@@ -5216,20 +5269,8 @@
                                           const TxfmParam *txfm_param) {
   int bd = txfm_param->bd;
   const TX_TYPE tx_type = txfm_param->tx_type;
-  const int32_t *src = cast_to_int32(input);
-  switch (tx_type) {
-      // Assembly version doesn't support some transform types, so use C version
-      // for those.
-    case IDTX:
-      av1_inv_txfm2d_add_16x16_c(src, CONVERT_TO_SHORTPTR(dest), stride,
-                                 tx_type, bd);
-      break;
-    default:
-      av1_highbd_inv_txfm2d_add_universe_sse4_1(input, dest, stride, tx_type,
-                                                txfm_param->tx_size,
-                                                txfm_param->eob, bd);
-      break;
-  }
+  av1_highbd_inv_txfm2d_add_universe_sse4_1(
+      input, dest, stride, tx_type, txfm_param->tx_size, txfm_param->eob, bd);
 }
 
 void av1_highbd_inv_txfm_add_32x32_sse4_1(const tran_low_t *input,
@@ -5237,20 +5278,8 @@
                                           const TxfmParam *txfm_param) {
   int bd = txfm_param->bd;
   const TX_TYPE tx_type = txfm_param->tx_type;
-  const int32_t *src = cast_to_int32(input);
-  switch (tx_type) {
-    case DCT_DCT:
-      av1_highbd_inv_txfm2d_add_universe_sse4_1(input, dest, stride, tx_type,
-                                                txfm_param->tx_size,
-                                                txfm_param->eob, bd);
-      break;
-      // Assembly version doesn't support IDTX, so use C version for it.
-    case IDTX:
-      av1_inv_txfm2d_add_32x32_c(src, CONVERT_TO_SHORTPTR(dest), stride,
-                                 tx_type, bd);
-      break;
-    default: assert(0);
-  }
+  av1_highbd_inv_txfm2d_add_universe_sse4_1(
+      input, dest, stride, tx_type, txfm_param->tx_size, txfm_param->eob, bd);
 }
 
 void av1_highbd_inv_txfm_add_16x32_sse4_1(const tran_low_t *input,
@@ -5258,19 +5287,8 @@
                                           const TxfmParam *txfm_param) {
   int bd = txfm_param->bd;
   const TX_TYPE tx_type = txfm_param->tx_type;
-  const int32_t *src = cast_to_int32(input);
-  switch (tx_type) {
-    case DCT_DCT:
-      av1_highbd_inv_txfm2d_add_universe_sse4_1(input, dest, stride, tx_type,
-                                                txfm_param->tx_size,
-                                                txfm_param->eob, bd);
-      break;
-    case IDTX:
-      av1_inv_txfm2d_add_16x32_c(src, CONVERT_TO_SHORTPTR(dest), stride,
-                                 txfm_param->tx_type, txfm_param->bd);
-      break;
-    default: assert(0);
-  }
+  av1_highbd_inv_txfm2d_add_universe_sse4_1(
+      input, dest, stride, tx_type, txfm_param->tx_size, txfm_param->eob, bd);
 }
 
 void av1_highbd_inv_txfm_add_32x16_sse4_1(const tran_low_t *input,
@@ -5278,19 +5296,8 @@
                                           const TxfmParam *txfm_param) {
   int bd = txfm_param->bd;
   const TX_TYPE tx_type = txfm_param->tx_type;
-  const int32_t *src = cast_to_int32(input);
-  switch (tx_type) {
-    case DCT_DCT:
-      av1_highbd_inv_txfm2d_add_universe_sse4_1(input, dest, stride, tx_type,
-                                                txfm_param->tx_size,
-                                                txfm_param->eob, bd);
-      break;
-    case IDTX:
-      av1_inv_txfm2d_add_32x16_c(src, CONVERT_TO_SHORTPTR(dest), stride,
-                                 txfm_param->tx_type, txfm_param->bd);
-      break;
-    default: assert(0);
-  }
+  av1_highbd_inv_txfm2d_add_universe_sse4_1(
+      input, dest, stride, tx_type, txfm_param->tx_size, txfm_param->eob, bd);
 }
 
 void av1_highbd_inv_txfm_add_8x32_sse4_1(const tran_low_t *input, uint8_t *dest,
@@ -5298,19 +5305,8 @@
                                          const TxfmParam *txfm_param) {
   int bd = txfm_param->bd;
   const TX_TYPE tx_type = txfm_param->tx_type;
-  const int32_t *src = cast_to_int32(input);
-  switch (tx_type) {
-    case DCT_DCT:
-      av1_highbd_inv_txfm2d_add_universe_sse4_1(input, dest, stride, tx_type,
-                                                txfm_param->tx_size,
-                                                txfm_param->eob, bd);
-      break;
-    case IDTX:
-      av1_inv_txfm2d_add_8x32_c(src, CONVERT_TO_SHORTPTR(dest), stride,
-                                txfm_param->tx_type, txfm_param->bd);
-      break;
-    default: assert(0);
-  }
+  av1_highbd_inv_txfm2d_add_universe_sse4_1(
+      input, dest, stride, tx_type, txfm_param->tx_size, txfm_param->eob, bd);
 }
 
 void av1_highbd_inv_txfm_add_32x8_sse4_1(const tran_low_t *input, uint8_t *dest,
@@ -5318,19 +5314,8 @@
                                          const TxfmParam *txfm_param) {
   int bd = txfm_param->bd;
   const TX_TYPE tx_type = txfm_param->tx_type;
-  const int32_t *src = cast_to_int32(input);
-  switch (tx_type) {
-    case DCT_DCT:
-      av1_highbd_inv_txfm2d_add_universe_sse4_1(input, dest, stride, tx_type,
-                                                txfm_param->tx_size,
-                                                txfm_param->eob, bd);
-      break;
-    case IDTX:
-      av1_inv_txfm2d_add_32x8_c(src, CONVERT_TO_SHORTPTR(dest), stride,
-                                txfm_param->tx_type, txfm_param->bd);
-      break;
-    default: assert(0);
-  }
+  av1_highbd_inv_txfm2d_add_universe_sse4_1(
+      input, dest, stride, tx_type, txfm_param->tx_size, txfm_param->eob, bd);
 }
 
 void av1_highbd_inv_txfm_add_4x4_sse4_1(const tran_low_t *input, uint8_t *dest,
@@ -5347,47 +5332,67 @@
     av1_highbd_iwht4x4_add(input, dest, stride, eob, bd);
     return;
   }
-  switch (tx_type) {
-      // Assembly version doesn't support some transform types, so use C version
-      // for those.
-    case V_DCT:
-    case H_DCT:
-    case V_ADST:
-    case H_ADST:
-    case V_FLIPADST:
-    case H_FLIPADST:
-    case IDTX:
-      av1_inv_txfm2d_add_4x4_c(src, CONVERT_TO_SHORTPTR(dest), stride, tx_type,
-                               bd);
-      break;
-    default:
-      av1_inv_txfm2d_add_4x4_sse4_1(src, CONVERT_TO_SHORTPTR(dest), stride,
-                                    tx_type, bd);
-      break;
+  av1_inv_txfm2d_add_4x4_sse4_1(src, CONVERT_TO_SHORTPTR(dest), stride, tx_type,
+                                bd);
+}
+static void iidentity32_sse4_1(__m128i *in, __m128i *out, int bit, int do_cols,
+                               int bd, int out_shift) {
+  (void)bit;
+  const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8));
+  const __m128i clamp_lo = _mm_set1_epi32(-(1 << (log_range - 1)));
+  const __m128i clamp_hi = _mm_set1_epi32((1 << (log_range - 1)) - 1);
+  __m128i v[32];
+  for (int i = 0; i < 32; i += 16) {
+    v[i] = _mm_slli_epi32(in[i], 2);
+    v[i + 1] = _mm_slli_epi32(in[i + 1], 2);
+    v[i + 2] = _mm_slli_epi32(in[i + 2], 2);
+    v[i + 3] = _mm_slli_epi32(in[i + 3], 2);
+    v[i + 4] = _mm_slli_epi32(in[i + 4], 2);
+    v[i + 5] = _mm_slli_epi32(in[i + 5], 2);
+    v[i + 6] = _mm_slli_epi32(in[i + 6], 2);
+    v[i + 7] = _mm_slli_epi32(in[i + 7], 2);
+    v[i + 8] = _mm_slli_epi32(in[i + 8], 2);
+    v[i + 9] = _mm_slli_epi32(in[i + 9], 2);
+    v[i + 10] = _mm_slli_epi32(in[i + 10], 2);
+    v[i + 11] = _mm_slli_epi32(in[i + 11], 2);
+    v[i + 12] = _mm_slli_epi32(in[i + 12], 2);
+    v[i + 13] = _mm_slli_epi32(in[i + 13], 2);
+    v[i + 14] = _mm_slli_epi32(in[i + 14], 2);
+    v[i + 15] = _mm_slli_epi32(in[i + 15], 2);
+  }
+
+  if (!do_cols) {
+    const int log_range_out = AOMMAX(16, bd + 6);
+    const __m128i clamp_lo_out = _mm_set1_epi32(AOMMAX(
+        -(1 << (log_range_out - 1)), -(1 << (log_range - 1 - out_shift))));
+    const __m128i clamp_hi_out = _mm_set1_epi32(AOMMIN(
+        (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift))));
+    shift_sse4_1(v, out, &clamp_lo_out, &clamp_hi_out, out_shift, 32);
+  } else {
+    highbd_clamp_epi32_sse4_1(v, out, &clamp_lo, &clamp_hi, 32);
   }
 }
-
 static const transform_1d_sse4_1
     highbd_txfm_all_1d_zeros_w8_arr[TX_SIZES][ITX_TYPES_1D][4] = {
       {
           { idct4x4_sse4_1, NULL, NULL, NULL },
           { iadst4x4_sse4_1, NULL, NULL, NULL },
-          { NULL, NULL, NULL, NULL },
+          { iidentity4_sse4_1, iidentity4_sse4_1, iidentity4_sse4_1, NULL },
       },
       { { idct8x8_low1_sse4_1, idct8x8_new_sse4_1, NULL, NULL },
         { iadst8x8_low1_sse4_1, iadst8x8_new_sse4_1, NULL, NULL },
-        { iidentity8_sse4_1, NULL, NULL, NULL } },
+        { iidentity8_sse4_1, iidentity8_sse4_1, NULL, NULL } },
       {
           { idct16x16_low1_sse4_1, idct16x16_low8_sse4_1, idct16x16_sse4_1,
             NULL },
           { iadst16x16_low1_sse4_1, iadst16x16_low8_sse4_1, iadst16x16_sse4_1,
             NULL },
-          { iidentity16_sse4_1, NULL, NULL, NULL },
+          { iidentity16_sse4_1, NULL, iidentity16_sse4_1, NULL },
       },
       { { idct32x32_low1_sse4_1, idct32x32_low8_sse4_1, idct32x32_low16_sse4_1,
           idct32x32_sse4_1 },
         { NULL, NULL, NULL, NULL },
-        { NULL, NULL, NULL, NULL } },
+        { iidentity32_sse4_1, NULL, NULL, NULL } },
       { { idct64x64_low1_sse4_1, idct64x64_low8_sse4_1, idct64x64_low16_sse4_1,
           idct64x64_sse4_1 },
         { NULL, NULL, NULL, NULL },
@@ -5534,6 +5539,64 @@
     }
   }
 }
+static void highbd_inv_txfm2d_add_idtx_ssse41(const int32_t *input,
+                                              uint16_t *output, int stride,
+                                              TX_TYPE tx_type, TX_SIZE tx_size,
+                                              int eob, const int bd) {
+  (void)eob;
+  __m128i buf1[64 * 4];
+  const int8_t *shift = inv_txfm_shift_ls[tx_size];
+  const int txw_idx = get_txw_idx(tx_size);
+  const int txh_idx = get_txh_idx(tx_size);
+  const int txfm_size_col = tx_size_wide[tx_size];
+  const int txfm_size_row = tx_size_high[tx_size];
+  const int input_stride = AOMMIN(32, txfm_size_col);
+  const int row_max = AOMMIN(32, txfm_size_row);
+  const int rect_type = get_rect_tx_log_ratio(txfm_size_col, txfm_size_row);
+  const transform_1d_sse4_1 row_txfm =
+      highbd_txfm_all_1d_zeros_w8_arr[txw_idx][hitx_1d_tab[tx_type]][0];
+  const transform_1d_sse4_1 col_txfm =
+      highbd_txfm_all_1d_zeros_w8_arr[txh_idx][vitx_1d_tab[tx_type]][0];
+
+  for (int i = 0; i < (row_max >> 2); ++i) {
+    __m128i buf0[32];
+    const int32_t *input_row = input + i * input_stride * 4;
+    for (int j = 0; j < (input_stride >> 2); ++j) {
+      __m128i *buf0_cur = buf0 + j * 4;
+      load_buffer_32bit_input(input_row + j * 4, input_stride, buf0_cur, 4);
+    }
+    if (rect_type == 1 || rect_type == -1) {
+      av1_round_shift_rect_array_32_sse4_1(buf0, buf0, input_stride, 0,
+                                           NewInvSqrt2);
+    }
+    row_txfm(buf0, buf0, inv_cos_bit_row[txw_idx][txh_idx], 0, bd, -shift[0]);
+
+    __m128i *_buf1 = buf1 + i * 4;
+    for (int j = 0; j < (input_stride >> 2); ++j) {
+      _buf1[j * txfm_size_row + 0] = buf0[j * 4 + 0];
+      _buf1[j * txfm_size_row + 1] = buf0[j * 4 + 1];
+      _buf1[j * txfm_size_row + 2] = buf0[j * 4 + 2];
+      _buf1[j * txfm_size_row + 3] = buf0[j * 4 + 3];
+    }
+  }
+  for (int i = 0; i < (input_stride >> 2); i++) {
+    col_txfm(buf1 + i * txfm_size_row, buf1 + i * txfm_size_row,
+             inv_cos_bit_col[txw_idx][txh_idx], 1, bd, 0);
+
+    av1_round_shift_array_32_sse4_1(buf1 + i * txfm_size_row,
+                                    buf1 + i * txfm_size_row, txfm_size_row,
+                                    -shift[1]);
+  }
+
+  // write to buffer
+  {
+    for (int i = 0; i < (txfm_size_col >> 3); i++) {
+      highbd_write_buffer_8xn_sse4_1(buf1 + i * txfm_size_row * 2,
+                                     output + 8 * i, stride, 0, txfm_size_row,
+                                     bd);
+    }
+  }
+}
 static void highbd_inv_txfm2d_add_no_identity_sse41(const int32_t *input,
                                                     uint16_t *output,
                                                     int stride, TX_TYPE tx_type,
@@ -5879,6 +5942,10 @@
           input, CONVERT_TO_SHORTPTR(output), stride, tx_type, tx_size, eob,
           bd);
       break;
+    case IDTX:
+      highbd_inv_txfm2d_add_idtx_ssse41(input, CONVERT_TO_SHORTPTR(output),
+                                        stride, tx_type, tx_size, eob, bd);
+      break;
     default: assert(0); break;
   }
 }
@@ -5889,26 +5956,9 @@
   int bd = txfm_param->bd;
   const TX_TYPE tx_type = txfm_param->tx_type;
   const TX_SIZE tx_size = txfm_param->tx_size;
-  const int32_t *src = cast_to_int32(input);
   int eob = txfm_param->eob;
-  switch (tx_type) {
-      // Assembly version doesn't support some transform types, so use C version
-      // for those.
-    case V_DCT:
-    case H_DCT:
-    case V_ADST:
-    case H_ADST:
-    case V_FLIPADST:
-    case H_FLIPADST:
-    case IDTX:
-      av1_inv_txfm2d_add_4x8_c(src, CONVERT_TO_SHORTPTR(dest), stride, tx_type,
-                               bd);
-      break;
-    default:
-      highbd_inv_txfm2d_add_4x8_sse41(input, CONVERT_TO_SHORTPTR(dest), stride,
-                                      tx_type, tx_size, eob, bd);
-      break;
-  }
+  highbd_inv_txfm2d_add_4x8_sse41(input, CONVERT_TO_SHORTPTR(dest), stride,
+                                  tx_type, tx_size, eob, bd);
 }
 
 void av1_highbd_inv_txfm_add_8x4_sse4_1(const tran_low_t *input, uint8_t *dest,
@@ -5917,26 +5967,9 @@
   int bd = txfm_param->bd;
   const TX_TYPE tx_type = txfm_param->tx_type;
   const TX_SIZE tx_size = txfm_param->tx_size;
-  const int32_t *src = cast_to_int32(input);
   int eob = txfm_param->eob;
-  switch (tx_type) {
-      // Assembly version doesn't support some transform types, so use C version
-      // for those.
-    case V_DCT:
-    case H_DCT:
-    case V_ADST:
-    case H_ADST:
-    case V_FLIPADST:
-    case H_FLIPADST:
-    case IDTX:
-      av1_inv_txfm2d_add_8x4_c(src, CONVERT_TO_SHORTPTR(dest), stride, tx_type,
-                               bd);
-      break;
-    default:
-      highbd_inv_txfm2d_add_8x4_sse41(input, CONVERT_TO_SHORTPTR(dest), stride,
-                                      tx_type, tx_size, eob, bd);
-      break;
-  }
+  highbd_inv_txfm2d_add_8x4_sse41(input, CONVERT_TO_SHORTPTR(dest), stride,
+                                  tx_type, tx_size, eob, bd);
 }
 
 void av1_highbd_inv_txfm_add_4x16_sse4_1(const tran_low_t *input, uint8_t *dest,
@@ -5945,26 +5978,9 @@
   int bd = txfm_param->bd;
   const TX_TYPE tx_type = txfm_param->tx_type;
   const TX_SIZE tx_size = txfm_param->tx_size;
-  const int32_t *src = cast_to_int32(input);
   int eob = txfm_param->eob;
-  switch (tx_type) {
-      // Assembly version doesn't support some transform types, so use C version
-      // for those.
-    case V_DCT:
-    case H_DCT:
-    case V_ADST:
-    case H_ADST:
-    case V_FLIPADST:
-    case H_FLIPADST:
-    case IDTX:
-      av1_inv_txfm2d_add_4x16_c(src, CONVERT_TO_SHORTPTR(dest), stride, tx_type,
-                                bd);
-      break;
-    default:
-      highbd_inv_txfm2d_add_4x16_sse4_1(input, CONVERT_TO_SHORTPTR(dest),
-                                        stride, tx_type, tx_size, eob, bd);
-      break;
-  }
+  highbd_inv_txfm2d_add_4x16_sse4_1(input, CONVERT_TO_SHORTPTR(dest), stride,
+                                    tx_type, tx_size, eob, bd);
 }
 
 void av1_highbd_inv_txfm_add_16x4_sse4_1(const tran_low_t *input, uint8_t *dest,
@@ -5973,26 +5989,9 @@
   int bd = txfm_param->bd;
   const TX_TYPE tx_type = txfm_param->tx_type;
   const TX_SIZE tx_size = txfm_param->tx_size;
-  const int32_t *src = cast_to_int32(input);
   int eob = txfm_param->eob;
-  switch (tx_type) {
-      // Assembly version doesn't support some transform types, so use C version
-      // for those.
-    case V_DCT:
-    case H_DCT:
-    case V_ADST:
-    case H_ADST:
-    case V_FLIPADST:
-    case H_FLIPADST:
-    case IDTX:
-      av1_inv_txfm2d_add_16x4_c(src, CONVERT_TO_SHORTPTR(dest), stride, tx_type,
-                                bd);
-      break;
-    default:
-      highbd_inv_txfm2d_add_16x4_sse4_1(input, CONVERT_TO_SHORTPTR(dest),
-                                        stride, tx_type, tx_size, eob, bd);
-      break;
-  }
+  highbd_inv_txfm2d_add_16x4_sse4_1(input, CONVERT_TO_SHORTPTR(dest), stride,
+                                    tx_type, tx_size, eob, bd);
 }
 
 void av1_highbd_inv_txfm_add_sse4_1(const tran_low_t *input, uint8_t *dest,