Refactor inv txfm ssse3 for tx_type with identity

Add dedicated functions for inv txfms that include an identity type:
lowbd_inv_txfm2d_add_{idtx,h_identity,v_identity}_ssse3
These functions are faster than the general one, saving
many instructions by dropping the transpose and dropping
some load or store operations.

The unit test shows this CL is 1.9x ~ 7.0x faster for those
inv txfms that include an identity type.
* 2.3x ~ 7.0x tx_size (64x64, 64x32, 32x64)
* 1.9x ~ 3.1x other cases

Change-Id: Ib8b0341c25088b5d424a52901853db198343279f
diff --git a/av1/common/x86/av1_inv_txfm_ssse3.c b/av1/common/x86/av1_inv_txfm_ssse3.c
index e785e61..3618370 100644
--- a/av1/common/x86/av1_inv_txfm_ssse3.c
+++ b/av1/common/x86/av1_inv_txfm_ssse3.c
@@ -2118,36 +2118,14 @@
   output[15] = _mm_subs_epi16(__zero, x8[1]);
 }
 
-static void iidentity4_new_sse2(const __m128i *input, __m128i *output,
-                                int8_t cos_bit) {
+static void iidentity4_new_ssse3(const __m128i *input, __m128i *output,
+                                 int8_t cos_bit) {
   (void)cos_bit;
-  const __m128i scale = _mm_set1_epi16(NewSqrt2);
-  const __m128i rounding = _mm_set1_epi16(1 << (NewSqrt2Bits - 1));
-  const __m128i one = _mm_set1_epi16(1);
-  const __m128i scale_rounding = _mm_unpacklo_epi16(scale, rounding);
+  const int16_t scale_fractional = (NewSqrt2 - (1 << NewSqrt2Bits));
+  const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
   for (int i = 0; i < 4; ++i) {
-    __m128i a_lo = _mm_unpacklo_epi16(input[i], one);
-    __m128i a_hi = _mm_unpackhi_epi16(input[i], one);
-    __m128i b_lo = _mm_madd_epi16(a_lo, scale_rounding);
-    __m128i b_hi = _mm_madd_epi16(a_hi, scale_rounding);
-    __m128i c_lo = _mm_srai_epi32(b_lo, NewSqrt2Bits);
-    __m128i c_hi = _mm_srai_epi32(b_hi, NewSqrt2Bits);
-    output[i] = _mm_packs_epi32(c_lo, c_hi);
-  }
-}
-
-static void iidentity4_w4_new_sse2(const __m128i *input, __m128i *output,
-                                   int8_t cos_bit) {
-  (void)cos_bit;
-  const __m128i scale = _mm_set1_epi16(NewSqrt2);
-  const __m128i rounding = _mm_set1_epi16(1 << (NewSqrt2Bits - 1));
-  const __m128i one = _mm_set1_epi16(1);
-  const __m128i scale_rounding = _mm_unpacklo_epi16(scale, rounding);
-  for (int i = 0; i < 4; ++i) {
-    __m128i a_lo = _mm_unpacklo_epi16(input[i], one);
-    __m128i b_lo = _mm_madd_epi16(a_lo, scale_rounding);
-    __m128i c_lo = _mm_srai_epi32(b_lo, NewSqrt2Bits);
-    output[i] = _mm_packs_epi32(c_lo, c_lo);
+    __m128i x = _mm_mulhrs_epi16(input[i], scale);
+    output[i] = _mm_adds_epi16(x, input[i]);
   }
 }
 
@@ -2159,69 +2137,15 @@
   }
 }
 
-static void iidentity16_new_sse2(const __m128i *input, __m128i *output,
-                                 int8_t cos_bit) {
+static void iidentity16_new_ssse3(const __m128i *input, __m128i *output,
+                                  int8_t cos_bit) {
   (void)cos_bit;
-  const __m128i scale = _mm_set1_epi16(2 * NewSqrt2);
-  const __m128i rounding = _mm_set1_epi16(1 << (NewSqrt2Bits - 1));
-  const __m128i one = _mm_set1_epi16(1);
-  const __m128i scale_rounding = _mm_unpacklo_epi16(scale, rounding);
+  const int16_t scale_fractional = 2 * (NewSqrt2 - (1 << NewSqrt2Bits));
+  const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
   for (int i = 0; i < 16; ++i) {
-    __m128i a_lo = _mm_unpacklo_epi16(input[i], one);
-    __m128i a_hi = _mm_unpackhi_epi16(input[i], one);
-    __m128i b_lo = _mm_madd_epi16(a_lo, scale_rounding);
-    __m128i b_hi = _mm_madd_epi16(a_hi, scale_rounding);
-    __m128i c_lo = _mm_srai_epi32(b_lo, NewSqrt2Bits);
-    __m128i c_hi = _mm_srai_epi32(b_hi, NewSqrt2Bits);
-    output[i] = _mm_packs_epi32(c_lo, c_hi);
-  }
-}
-
-static void iidentity16_w4_new_sse2(const __m128i *input, __m128i *output,
-                                    int8_t cos_bit) {
-  (void)cos_bit;
-  const __m128i scale = _mm_set1_epi16(2 * NewSqrt2);
-  const __m128i rounding = _mm_set1_epi16(1 << (NewSqrt2Bits - 1));
-  const __m128i one = _mm_set1_epi16(1);
-  const __m128i scale_rounding = _mm_unpacklo_epi16(scale, rounding);
-  for (int i = 0; i < 16; ++i) {
-    __m128i a_lo = _mm_unpacklo_epi16(input[i], one);
-    __m128i b_lo = _mm_madd_epi16(a_lo, scale_rounding);
-    __m128i c_lo = _mm_srai_epi32(b_lo, NewSqrt2Bits);
-    output[i] = _mm_packs_epi32(c_lo, c_lo);
-  }
-}
-
-static void iidentity32_new_sse2(const __m128i *input, __m128i *output,
-                                 int8_t cos_bit) {
-  (void)cos_bit;
-  for (int i = 0; i < 32; ++i) {
-    output[i] = _mm_adds_epi16(input[i], input[i]);
-    output[i] = _mm_adds_epi16(output[i], output[i]);
-  }
-}
-
-static void iidentity64_low32_new_sse2(const __m128i *input, __m128i *output,
-                                       int8_t cos_bit) {
-  (void)cos_bit;
-  const __m128i scale = _mm_set1_epi16(4 * NewSqrt2);
-  const __m128i rounding = _mm_set1_epi16(1 << (NewSqrt2Bits - 1));
-  const __m128i one = _mm_set1_epi16(1);
-  const __m128i scale_rounding = _mm_unpacklo_epi16(scale, rounding);
-  for (int i = 0; i < 32; ++i) {
-    __m128i a_lo = _mm_unpacklo_epi16(input[i], one);
-    __m128i a_hi = _mm_unpackhi_epi16(input[i], one);
-    __m128i b_lo = _mm_madd_epi16(a_lo, scale_rounding);
-    __m128i b_hi = _mm_madd_epi16(a_hi, scale_rounding);
-    __m128i c_lo = _mm_srai_epi32(b_lo, NewSqrt2Bits);
-    __m128i c_hi = _mm_srai_epi32(b_hi, NewSqrt2Bits);
-    output[i] = _mm_packs_epi32(c_lo, c_hi);
-  }
-  // TODO(binpengsmail@gmail.com):
-  // Potential optimization to drop this store to output
-  // by adding dedicate functions for inv txfm include identity type
-  for (int i = 32; i < 64; ++i) {
-    output[i] = _mm_setzero_si128();
+    __m128i x = _mm_mulhrs_epi16(input[i], scale);
+    __m128i srcx2 = _mm_adds_epi16(input[i], input[i]);
+    output[i] = _mm_adds_epi16(x, srcx2);
   }
 }
 
@@ -2284,24 +2208,381 @@
 // 1D functions process process 8 pixels at one time.
 static const transform_1d_ssse3
     lowbd_txfm_all_1d_w8_arr[TX_SIZES][ITX_TYPES_1D] = {
-      { idct4_new_sse2, iadst4_new_sse2, iidentity4_new_sse2 },
+      { idct4_new_sse2, iadst4_new_sse2, iidentity4_new_ssse3 },
       { idct8_new_sse2, iadst8_new_sse2, iidentity8_new_sse2 },
-      { idct16_new_sse2, iadst16_new_sse2, iidentity16_new_sse2 },
-      { idct32_new_sse2, NULL, iidentity32_new_sse2 },
-      { idct64_low32_new_ssse3, NULL, iidentity64_low32_new_sse2 },
+      { idct16_new_sse2, iadst16_new_sse2, iidentity16_new_ssse3 },
+      { idct32_new_sse2, NULL, NULL },
+      { idct64_low32_new_ssse3, NULL, NULL },
     };
 
 // 1D functions process process 4 pixels at one time.
 // used in 4x4, 4x8, 4x16, 8x4, 16x4
 static const transform_1d_ssse3
     lowbd_txfm_all_1d_w4_arr[TX_SIZES][ITX_TYPES_1D] = {
-      { idct4_w4_new_sse2, iadst4_w4_new_sse2, iidentity4_w4_new_sse2 },
+      { idct4_w4_new_sse2, iadst4_w4_new_sse2, iidentity4_new_ssse3 },
       { idct8_w4_new_sse2, iadst8_w4_new_sse2, iidentity8_new_sse2 },
-      { idct16_w4_new_sse2, iadst16_w4_new_sse2, iidentity16_w4_new_sse2 },
+      { idct16_w4_new_sse2, iadst16_w4_new_sse2, iidentity16_new_ssse3 },
       { NULL, NULL, NULL },
       { NULL, NULL, NULL },
     };
 
+static INLINE void iidentity4_row_8xn_ssse3(__m128i *out, const int32_t *input,
+                                            int stride, int shift, int height) {
+  const int32_t *input_row = input;
+  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
+  const int16_t scale_fractional = NewSqrt2 - (1 << NewSqrt2Bits);
+  const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
+  for (int h = 0; h < height; ++h) {
+    __m128i src = load_32bit_to_16bit(input_row);
+    input_row += stride;
+    __m128i x = _mm_mulhrs_epi16(src, scale);
+    x = _mm_adds_epi16(x, src);
+    out[h] = _mm_mulhrs_epi16(x, mshift);
+  }
+}
+
+static INLINE void iidentity4_row_rect_8xn_ssse3(__m128i *out,
+                                                 const int32_t *input,
+                                                 int stride, int shift,
+                                                 int height) {
+  const int32_t *input_row = input;
+  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
+  const int16_t scale_fractional = NewSqrt2 - (1 << NewSqrt2Bits);
+  const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
+  const __m128i rect_scale = _mm_set1_epi16(NewInvSqrt2 << (15 - NewSqrt2Bits));
+  for (int h = 0; h < height; ++h) {
+    __m128i src = load_32bit_to_16bit(input_row);
+    input_row += stride;
+    src = _mm_mulhrs_epi16(src, rect_scale);
+    __m128i x = _mm_mulhrs_epi16(src, scale);
+    x = _mm_adds_epi16(x, src);
+    out[h] = _mm_mulhrs_epi16(x, mshift);
+  }
+}
+
+static INLINE void iidentity4_col_8xn_ssse3(uint8_t *output, int stride,
+                                            __m128i *buf, int shift,
+                                            int height) {
+  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
+  const int16_t scale_fractional = NewSqrt2 - (1 << NewSqrt2Bits);
+  const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
+  const __m128i zero = _mm_setzero_si128();
+  for (int h = 0; h < height; ++h) {
+    __m128i x = _mm_mulhrs_epi16(buf[h], scale);
+    x = _mm_adds_epi16(x, buf[h]);
+    x = _mm_mulhrs_epi16(x, mshift);
+    const __m128i pred =
+        _mm_loadl_epi64((__m128i const *)(output + h * stride));
+    x = _mm_adds_epi16(x, _mm_unpacklo_epi8(pred, zero));
+    __m128i u = _mm_packus_epi16(x, x);
+    _mm_storel_epi64((__m128i *)(output + h * stride), u);
+  }
+}
+
+static INLINE void iidentity8_row_8xn_ssse3(__m128i *out, const int32_t *input,
+                                            int stride, int shift, int height) {
+  const int32_t *input_row = input;
+  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
+  for (int h = 0; h < height; ++h) {
+    __m128i src0 = _mm_load_si128((__m128i *)(input_row));
+    __m128i src1 = _mm_load_si128((__m128i *)(input_row + 4));
+    input_row += stride;
+    __m128i x = _mm_packs_epi32(src0, src1);
+    x = _mm_adds_epi16(x, x);
+    out[h] = _mm_mulhrs_epi16(x, mshift);
+  }
+}
+
+static INLINE void iidentity8_row_rect_8xn_ssse3(__m128i *out,
+                                                 const int32_t *input,
+                                                 int stride, int shift,
+                                                 int height) {
+  const int32_t *input_row = input;
+  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
+  const __m128i rect_scale = _mm_set1_epi16(NewInvSqrt2 * 8);
+  for (int h = 0; h < height; ++h) {
+    __m128i src0 = _mm_load_si128((__m128i *)(input_row));
+    __m128i src1 = _mm_load_si128((__m128i *)(input_row + 4));
+    input_row += stride;
+    __m128i x = _mm_packs_epi32(src0, src1);
+    x = _mm_mulhrs_epi16(x, rect_scale);
+    x = _mm_adds_epi16(x, x);
+    out[h] = _mm_mulhrs_epi16(x, mshift);
+  }
+}
+
+static INLINE void iidentity8_col_8xn_ssse3(uint8_t *output, int stride,
+                                            __m128i *buf, int shift,
+                                            int height) {
+  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
+  const __m128i zero = _mm_setzero_si128();
+  for (int h = 0; h < height; ++h) {
+    __m128i x = _mm_adds_epi16(buf[h], buf[h]);
+    x = _mm_mulhrs_epi16(x, mshift);
+    const __m128i pred =
+        _mm_loadl_epi64((__m128i const *)(output + h * stride));
+    x = _mm_adds_epi16(x, _mm_unpacklo_epi8(pred, zero));
+    __m128i u = _mm_packus_epi16(x, x);
+    _mm_storel_epi64((__m128i *)(output + h * stride), u);
+  }
+}
+
+static INLINE void iidentity16_row_8xn_ssse3(__m128i *out, const int32_t *input,
+                                             int stride, int shift,
+                                             int height) {
+  const int32_t *input_row = input;
+  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
+  const int16_t scale_fractional = 2 * NewSqrt2 - (2 << NewSqrt2Bits);
+  const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
+  for (int h = 0; h < height; ++h) {
+    __m128i src = load_32bit_to_16bit(input_row);
+    input_row += stride;
+    __m128i x = _mm_mulhrs_epi16(src, scale);
+    __m128i srcx2 = _mm_adds_epi16(src, src);
+    x = _mm_adds_epi16(x, srcx2);
+    out[h] = _mm_mulhrs_epi16(x, mshift);
+  }
+}
+
+static INLINE void iidentity16_row_rect_8xn_ssse3(__m128i *out,
+                                                  const int32_t *input,
+                                                  int stride, int shift,
+                                                  int height) {
+  const int32_t *input_row = input;
+  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
+  const int16_t scale_fractional = 2 * NewSqrt2 - (2 << NewSqrt2Bits);
+  const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
+  const __m128i rect_scale = _mm_set1_epi16(NewInvSqrt2 << (15 - NewSqrt2Bits));
+  for (int h = 0; h < height; ++h) {
+    __m128i src = load_32bit_to_16bit(input_row);
+    input_row += stride;
+    src = _mm_mulhrs_epi16(src, rect_scale);
+    __m128i x = _mm_mulhrs_epi16(src, scale);
+    __m128i srcx2 = _mm_adds_epi16(src, src);
+    x = _mm_adds_epi16(x, srcx2);
+    out[h] = _mm_mulhrs_epi16(x, mshift);
+  }
+}
+
+static INLINE void iidentity16_col_8xn_ssse3(uint8_t *output, int stride,
+                                             __m128i *buf, int shift,
+                                             int height) {
+  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
+  const int16_t scale_fractional = 2 * NewSqrt2 - (2 << NewSqrt2Bits);
+  const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
+  const __m128i zero = _mm_setzero_si128();
+  for (int h = 0; h < height; ++h) {
+    __m128i x = _mm_mulhrs_epi16(buf[h], scale);
+    __m128i srcx2 = _mm_adds_epi16(buf[h], buf[h]);
+    x = _mm_adds_epi16(x, srcx2);
+    x = _mm_mulhrs_epi16(x, mshift);
+    const __m128i pred =
+        _mm_loadl_epi64((__m128i const *)(output + h * stride));
+    x = _mm_adds_epi16(x, _mm_unpacklo_epi8(pred, zero));
+    __m128i u = _mm_packus_epi16(x, x);
+    _mm_storel_epi64((__m128i *)(output + h * stride), u);
+  }
+}
+
+static INLINE void iidentity32_row_8xn_ssse3(__m128i *out, const int32_t *input,
+                                             int stride, int shift,
+                                             int height) {
+  const int32_t *input_row = input;
+  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
+  for (int h = 0; h < height; ++h) {
+    __m128i src0 = _mm_load_si128((__m128i *)(input_row));
+    __m128i src1 = _mm_load_si128((__m128i *)(input_row + 4));
+    input_row += stride;
+    __m128i x = _mm_packs_epi32(src0, src1);
+    x = _mm_adds_epi16(x, x);
+    x = _mm_adds_epi16(x, x);
+    out[h] = _mm_mulhrs_epi16(x, mshift);
+  }
+}
+
+static INLINE void iidentity32_row_rect_8xn_ssse3(__m128i *out,
+                                                  const int32_t *input,
+                                                  int stride, int shift,
+                                                  int height) {
+  const int32_t *input_row = input;
+  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
+  const __m128i rect_scale = _mm_set1_epi16(NewInvSqrt2 * 8);
+  for (int h = 0; h < height; ++h) {
+    __m128i src0 = _mm_load_si128((__m128i *)(input_row));
+    __m128i src1 = _mm_load_si128((__m128i *)(input_row + 4));
+    input_row += stride;
+    __m128i x = _mm_packs_epi32(src0, src1);
+    x = _mm_mulhrs_epi16(x, rect_scale);
+    x = _mm_adds_epi16(x, x);
+    x = _mm_adds_epi16(x, x);
+    out[h] = _mm_mulhrs_epi16(x, mshift);
+  }
+}
+
+static INLINE void iidentity32_col_8xn_ssse3(uint8_t *output, int stride,
+                                             __m128i *buf, int shift,
+                                             int height) {
+  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
+  const __m128i zero = _mm_setzero_si128();
+  for (int h = 0; h < height; ++h) {
+    __m128i x = _mm_adds_epi16(buf[h], buf[h]);
+    x = _mm_adds_epi16(x, x);
+    x = _mm_mulhrs_epi16(x, mshift);
+    const __m128i pred =
+        _mm_loadl_epi64((__m128i const *)(output + h * stride));
+    x = _mm_adds_epi16(x, _mm_unpacklo_epi8(pred, zero));
+    __m128i u = _mm_packus_epi16(x, x);
+    _mm_storel_epi64((__m128i *)(output + h * stride), u);
+  }
+}
+
+static INLINE void iidentity64_row_8xn_ssse3(__m128i *out, const int32_t *input,
+                                             int stride, int shift,
+                                             int height) {
+  const int32_t *input_row = input;
+  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
+  const int16_t scale_fractional = 4 * NewSqrt2 - (5 << NewSqrt2Bits);
+  const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
+  for (int h = 0; h < height; ++h) {
+    __m128i src = load_32bit_to_16bit(input_row);
+    input_row += stride;
+    __m128i x = _mm_mulhrs_epi16(src, scale);
+    __m128i srcx5 = _mm_adds_epi16(src, src);
+    srcx5 = _mm_adds_epi16(srcx5, srcx5);
+    srcx5 = _mm_adds_epi16(srcx5, src);
+    x = _mm_adds_epi16(x, srcx5);
+    out[h] = _mm_mulhrs_epi16(x, mshift);
+  }
+}
+
+static INLINE void iidentity64_row_rect_8xn_ssse3(__m128i *out,
+                                                  const int32_t *input,
+                                                  int stride, int shift,
+                                                  int height) {
+  const int32_t *input_row = input;
+  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
+  const int16_t scale_fractional = 4 * NewSqrt2 - (5 << NewSqrt2Bits);
+  const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
+  const __m128i rect_scale = _mm_set1_epi16(NewInvSqrt2 << (15 - NewSqrt2Bits));
+  for (int h = 0; h < height; ++h) {
+    __m128i src = load_32bit_to_16bit(input_row);
+    input_row += stride;
+    src = _mm_mulhrs_epi16(src, rect_scale);
+    __m128i x = _mm_mulhrs_epi16(src, scale);
+    __m128i srcx5 = _mm_adds_epi16(src, src);
+    srcx5 = _mm_adds_epi16(srcx5, srcx5);
+    srcx5 = _mm_adds_epi16(srcx5, src);
+    x = _mm_adds_epi16(x, srcx5);
+    out[h] = _mm_mulhrs_epi16(x, mshift);
+  }
+}
+
+static INLINE void iidentity64_col_8xn_ssse3(uint8_t *output, int stride,
+                                             __m128i *buf, int shift,
+                                             int height) {
+  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
+  const int16_t scale_fractional = 4 * NewSqrt2 - (5 << NewSqrt2Bits);
+  const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
+  const __m128i zero = _mm_setzero_si128();
+  for (int h = 0; h < height; ++h) {
+    __m128i x = _mm_mulhrs_epi16(buf[h], scale);
+    __m128i srcx5 = _mm_adds_epi16(buf[h], buf[h]);
+    srcx5 = _mm_adds_epi16(srcx5, srcx5);
+    srcx5 = _mm_adds_epi16(srcx5, buf[h]);
+    x = _mm_adds_epi16(x, srcx5);
+    x = _mm_mulhrs_epi16(x, mshift);
+    const __m128i pred =
+        _mm_loadl_epi64((__m128i const *)(output + h * stride));
+    x = _mm_adds_epi16(x, _mm_unpacklo_epi8(pred, zero));
+    __m128i u = _mm_packus_epi16(x, x);
+    _mm_storel_epi64((__m128i *)(output + h * stride), u);
+  }
+}
+
+static INLINE void identity_row_8xn_ssse3(__m128i *out, const int32_t *input,
+                                          int stride, int shift, int height,
+                                          int txw_idx, int rect_type) {
+  if (rect_type != 1 && rect_type != -1) {
+    switch (txw_idx) {
+      case 0:
+        iidentity4_row_8xn_ssse3(out, input, stride, shift, height);
+        break;
+      case 1:
+        iidentity8_row_8xn_ssse3(out, input, stride, shift, height);
+        break;
+      case 2:
+        iidentity16_row_8xn_ssse3(out, input, stride, shift, height);
+        break;
+      case 3:
+        iidentity32_row_8xn_ssse3(out, input, stride, shift, height);
+        break;
+      case 4:
+        iidentity64_row_8xn_ssse3(out, input, stride, shift, height);
+        break;
+      default: break;
+    }
+  } else {
+    switch (txw_idx) {
+      case 0:
+        iidentity4_row_rect_8xn_ssse3(out, input, stride, shift, height);
+        break;
+      case 1:
+        iidentity8_row_rect_8xn_ssse3(out, input, stride, shift, height);
+        break;
+      case 2:
+        iidentity16_row_rect_8xn_ssse3(out, input, stride, shift, height);
+        break;
+      case 3:
+        iidentity32_row_rect_8xn_ssse3(out, input, stride, shift, height);
+        break;
+      case 4:
+        iidentity64_row_rect_8xn_ssse3(out, input, stride, shift, height);
+        break;
+      default: break;
+    }
+  }
+}
+
+static INLINE void identity_col_8xn_ssse3(uint8_t *output, int stride,
+                                          __m128i *buf, int shift, int height,
+                                          int txh_idx) {
+  switch (txh_idx) {
+    case 0: iidentity4_col_8xn_ssse3(output, stride, buf, shift, height); break;
+    case 1: iidentity8_col_8xn_ssse3(output, stride, buf, shift, height); break;
+    case 2:
+      iidentity16_col_8xn_ssse3(output, stride, buf, shift, height);
+      break;
+    case 3:
+      iidentity32_col_8xn_ssse3(output, stride, buf, shift, height);
+      break;
+    case 4:
+      iidentity64_col_8xn_ssse3(output, stride, buf, shift, height);
+      break;
+    default: break;
+  }
+}
+
+static INLINE void lowbd_inv_txfm2d_add_idtx_ssse3(const int32_t *input,
+                                                   uint8_t *output, int stride,
+                                                   TX_SIZE tx_size) {
+  const int8_t *shift = inv_txfm_shift_ls[tx_size];
+  const int txw_idx = get_txw_idx(tx_size);
+  const int txh_idx = get_txh_idx(tx_size);
+  const int txfm_size_col = tx_size_wide[tx_size];
+  const int txfm_size_row = tx_size_high[tx_size];
+  const int input_stride = AOMMIN(32, txfm_size_col);
+  const int row_max = AOMMIN(32, txfm_size_row);
+  const int rect_type = get_rect_tx_log_ratio(txfm_size_col, txfm_size_row);
+  __m128i buf[32];
+
+  for (int i = 0; i < input_stride >> 3; ++i) {
+    identity_row_8xn_ssse3(buf, input + 8 * i, input_stride, shift[0], row_max,
+                           txw_idx, rect_type);
+    identity_col_8xn_ssse3(output + 8 * i, stride, buf, shift[1], row_max,
+                           txh_idx);
+  }
+}
+
 void av1_lowbd_inv_txfm2d_add_4x4_ssse3(const int32_t *input, uint8_t *output,
                                         int stride, TX_TYPE tx_type, int bd) {
   (void)bd;
@@ -2383,8 +2664,9 @@
   return _mm_packus_epi16(x0, x1);
 }
 
-static void lowbd_write_buffer_16xn_sse2(__m128i *in, uint8_t *output,
-                                         int stride, int flipud, int height) {
+static INLINE void lowbd_write_buffer_16xn_sse2(__m128i *in, uint8_t *output,
+                                                int stride, int flipud,
+                                                int height) {
   int j = flipud ? (height - 1) : 0;
   const int step = flipud ? -1 : 1;
   for (int i = 0; i < height; ++i, j += step) {
@@ -2394,28 +2676,19 @@
   }
 }
 
-static INLINE void round_shift_sse2(const __m128i *input, __m128i *output,
-                                    int size) {
-  const __m128i scale = _mm_set1_epi16(NewInvSqrt2);
-  const __m128i rounding = _mm_set1_epi16(1 << (NewSqrt2Bits - 1));
-  const __m128i one = _mm_set1_epi16(1);
-  const __m128i scale_rounding = _mm_unpacklo_epi16(scale, rounding);
+static INLINE void round_shift_ssse3(const __m128i *input, __m128i *output,
+                                     int size) {
+  const __m128i scale = _mm_set1_epi16(NewInvSqrt2 * 8);
   for (int i = 0; i < size; ++i) {
-    __m128i a_lo = _mm_unpacklo_epi16(input[i], one);
-    __m128i a_hi = _mm_unpackhi_epi16(input[i], one);
-    __m128i b_lo = _mm_madd_epi16(a_lo, scale_rounding);
-    __m128i b_hi = _mm_madd_epi16(a_hi, scale_rounding);
-    __m128i c_lo = _mm_srai_epi32(b_lo, NewSqrt2Bits);
-    __m128i c_hi = _mm_srai_epi32(b_hi, NewSqrt2Bits);
-    output[i] = _mm_packs_epi32(c_lo, c_hi);
+    output[i] = _mm_mulhrs_epi16(input[i], scale);
   }
 }
 
-static INLINE void lowbd_inv_txfm2d_add_internal_ssse3(const int32_t *input,
-                                                       uint8_t *output,
-                                                       int stride,
-                                                       TX_TYPE tx_type,
-                                                       TX_SIZE tx_size) {
+static INLINE void lowbd_inv_txfm2d_add_no_identity_ssse3(const int32_t *input,
+                                                          uint8_t *output,
+                                                          int stride,
+                                                          TX_TYPE tx_type,
+                                                          TX_SIZE tx_size) {
   __m128i buf1[64 * 8];
   const int8_t *shift = inv_txfm_shift_ls[tx_size];
   const int txw_idx = get_txw_idx(tx_size);
@@ -2447,7 +2720,7 @@
       transpose_16bit_8x8(buf0_cur, buf0_cur);
     }
     if (rect_type == 1 || rect_type == -1) {
-      round_shift_sse2(buf0, buf0, input_stride);  // rect special code
+      round_shift_ssse3(buf0, buf0, input_stride);  // rect special code
     }
     row_txfm(buf0, buf0, cos_bit_row);
     round_shift_16bit(buf0, txfm_size_col, shift[0]);
@@ -2481,25 +2754,152 @@
   }
 }
 
+static INLINE void lowbd_inv_txfm2d_add_h_identity_ssse3(const int32_t *input,
+                                                         uint8_t *output,
+                                                         int stride,
+                                                         TX_TYPE tx_type,
+                                                         TX_SIZE tx_size) {
+  const int8_t *shift = inv_txfm_shift_ls[tx_size];
+  const int txw_idx = get_txw_idx(tx_size);
+  const int txh_idx = get_txh_idx(tx_size);
+  const int cos_bit_col = inv_cos_bit_col[txw_idx][txh_idx];
+  const int txfm_size_col = tx_size_wide[tx_size];
+  const int txfm_size_row = tx_size_high[tx_size];
+  const int txfm_size_col_notzero = AOMMIN(32, txfm_size_col);
+  const int txfm_size_row_notzero = AOMMIN(32, txfm_size_row);
+  const int buf_size_w_div8 = txfm_size_col >> 3;
+  const int input_stride = txfm_size_col_notzero;
+  const int rect_type = get_rect_tx_log_ratio(txfm_size_col, txfm_size_row);
+
+  const transform_1d_ssse3 col_txfm =
+      lowbd_txfm_all_1d_w8_arr[txh_idx][vitx_1d_tab[tx_type]];
+
+  assert(col_txfm != NULL);
+
+  int ud_flip, lr_flip;
+  get_flip_cfg(tx_type, &ud_flip, &lr_flip);
+  for (int i = 0; i < AOMMIN(4, buf_size_w_div8); i++) {
+    __m128i buf0[64];
+    identity_row_8xn_ssse3(buf0, input + 8 * i, input_stride, shift[0],
+                           txfm_size_row_notzero, txw_idx, rect_type);
+    col_txfm(buf0, buf0, cos_bit_col);
+    __m128i mshift = _mm_set1_epi16(1 << (15 + shift[1]));
+    int k = ud_flip ? (txfm_size_row - 1) : 0;
+    const int step = ud_flip ? -1 : 1;
+    for (int j = 0; j < txfm_size_row; ++j, k += step) {
+      const __m128i v =
+          _mm_loadl_epi64((__m128i const *)(output + 8 * i + j * stride));
+      __m128i res = _mm_mulhrs_epi16(buf0[k], mshift);
+      const __m128i u = lowbd_get_recon_8x8_sse2(v, res);
+      _mm_storel_epi64((__m128i *)(output + 8 * i + j * stride), u);
+    }
+  }
+}
+
+static INLINE void lowbd_inv_txfm2d_add_v_identity_ssse3(const int32_t *input,
+                                                         uint8_t *output,
+                                                         int stride,
+                                                         TX_TYPE tx_type,
+                                                         TX_SIZE tx_size) {
+  __m128i buf1[64];
+  const int8_t *shift = inv_txfm_shift_ls[tx_size];
+  const int txw_idx = get_txw_idx(tx_size);
+  const int txh_idx = get_txh_idx(tx_size);
+  const int cos_bit_row = inv_cos_bit_row[txw_idx][txh_idx];
+  const int txfm_size_col = tx_size_wide[tx_size];
+  const int txfm_size_row = tx_size_high[tx_size];
+  const int buf_size_w_div8 = txfm_size_col >> 3;
+  const int buf_size_h_div8 = AOMMIN(32, txfm_size_row) >> 3;
+  const int input_stride = AOMMIN(32, txfm_size_col);
+  const int rect_type = get_rect_tx_log_ratio(txfm_size_col, txfm_size_row);
+
+  const transform_1d_ssse3 row_txfm =
+      lowbd_txfm_all_1d_w8_arr[txw_idx][hitx_1d_tab[tx_type]];
+
+  assert(row_txfm != NULL);
+  int ud_flip, lr_flip;
+  get_flip_cfg(tx_type, &ud_flip, &lr_flip);
+  for (int i = 0; i < buf_size_h_div8; i++) {
+    __m128i buf0[64];
+    const int32_t *input_row = input + i * input_stride * 8;
+    for (int j = 0; j < AOMMIN(4, buf_size_w_div8); ++j) {
+      __m128i *buf0_cur = buf0 + j * 8;
+      load_buffer_32bit_to_16bit(input_row + j * 8, input_stride, buf0_cur, 8);
+      transpose_16bit_8x8(buf0_cur, buf0_cur);
+    }
+    if (rect_type == 1 || rect_type == -1) {
+      round_shift_ssse3(buf0, buf0, input_stride);  // rect special code
+    }
+    row_txfm(buf0, buf0, cos_bit_row);
+    round_shift_16bit(buf0, txfm_size_col, shift[0]);
+    __m128i *_buf1 = buf1;
+    if (lr_flip) {
+      for (int j = 0; j < buf_size_w_div8; ++j) {
+        __m128i temp[8];
+        flip_buf_sse2(buf0 + 8 * j, temp, 8);
+        transpose_16bit_8x8(temp, _buf1 + 8 * (buf_size_w_div8 - 1 - j));
+      }
+    } else {
+      for (int j = 0; j < buf_size_w_div8; ++j) {
+        transpose_16bit_8x8(buf0 + 8 * j, _buf1 + 8 * j);
+      }
+    }
+
+    for (int j = 0; j < buf_size_w_div8; ++j) {
+      identity_col_8xn_ssse3(output + i * 8 * stride + j * 8, stride,
+                             buf1 + j * 8, shift[1], 8, txh_idx);
+    }
+  }
+}
+
+// for 32x32,32x64,64x32,64x64,32x8,8x32,16x32,32x16,64x16,16x64
+static INLINE void lowbd_inv_txfm2d_add_ssse3(const int32_t *input,
+                                              uint8_t *output, int stride,
+                                              TX_TYPE tx_type,
+                                              TX_SIZE tx_size) {
+  switch (tx_type) {
+    case DCT_DCT:
+      lowbd_inv_txfm2d_add_no_identity_ssse3(input, output, stride, tx_type,
+                                             tx_size);
+      break;
+    case IDTX:
+      lowbd_inv_txfm2d_add_idtx_ssse3(input, output, stride, tx_size);
+      break;
+    case V_DCT:
+    case V_ADST:
+    case V_FLIPADST:
+      lowbd_inv_txfm2d_add_h_identity_ssse3(input, output, stride, tx_type,
+                                            tx_size);
+      break;
+    case H_DCT:
+    case H_ADST:
+    case H_FLIPADST:
+      lowbd_inv_txfm2d_add_v_identity_ssse3(input, output, stride, tx_type,
+                                            tx_size);
+      break;
+    default:
+      lowbd_inv_txfm2d_add_no_identity_ssse3(input, output, stride, tx_type,
+                                             tx_size);
+      break;
+  }
+}
+
 void av1_lowbd_inv_txfm2d_add_16x16_ssse3(const int32_t *input, uint8_t *output,
                                           int stride, TX_TYPE tx_type, int bd) {
   (void)bd;
-  lowbd_inv_txfm2d_add_internal_ssse3(input, output, stride, tx_type, TX_16X16);
+  lowbd_inv_txfm2d_add_ssse3(input, output, stride, tx_type, TX_16X16);
 }
 
 void av1_lowbd_inv_txfm2d_add_32x32_ssse3(const int32_t *input, uint8_t *output,
                                           int stride, TX_TYPE tx_type, int bd) {
   (void)bd;
-  lowbd_inv_txfm2d_add_internal_ssse3(input, output, stride, tx_type, TX_32X32);
+  lowbd_inv_txfm2d_add_ssse3(input, output, stride, tx_type, TX_32X32);
 }
 
 void av1_lowbd_inv_txfm2d_add_64x64_ssse3(const int32_t *input, uint8_t *output,
                                           int stride, TX_TYPE tx_type, int bd) {
   (void)bd;
-  // TODO(binpengsmail@gmail.com):
-  // To add dedicate functions for inv txfm include identity type
-  // Should be simpler and faster then the general one
-  lowbd_inv_txfm2d_add_internal_ssse3(input, output, stride, tx_type, TX_64X64);
+  lowbd_inv_txfm2d_add_ssse3(input, output, stride, tx_type, TX_64X64);
 }
 
 void av1_lowbd_inv_txfm2d_add_4x8_ssse3(const int32_t *input, uint8_t *output,
@@ -2524,7 +2924,7 @@
   get_flip_cfg(tx_type, &ud_flip, &lr_flip);
   load_buffer_32bit_to_16bit_w4(input, txfm_size_col, buf, txfm_size_row);
   transpose_16bit_4x8(buf, buf);
-  round_shift_sse2(buf, buf, txfm_size_col);  // rect special code
+  round_shift_ssse3(buf, buf, txfm_size_col);  // rect special code
   row_txfm(buf, buf, cos_bit_row);
   // round_shift_16bit(buf, txfm_size_col, shift[0]);// shift[0] is 0
   if (lr_flip) {
@@ -2561,7 +2961,7 @@
   get_flip_cfg(tx_type, &ud_flip, &lr_flip);
   load_buffer_32bit_to_16bit(input, txfm_size_col, buf, txfm_size_row);
   transpose_16bit_8x4(buf, buf);
-  round_shift_sse2(buf, buf, txfm_size_col);  // rect special code
+  round_shift_ssse3(buf, buf, txfm_size_col);  // rect special code
   row_txfm(buf, buf, cos_bit_row);
   // round_shift_16bit(buf, txfm_size_col, shift[0]); // shift[0] is 0
   if (lr_flip) {
@@ -2579,37 +2979,37 @@
 void av1_lowbd_inv_txfm2d_add_8x16_ssse3(const int32_t *input, uint8_t *output,
                                          int stride, TX_TYPE tx_type, int bd) {
   (void)bd;
-  lowbd_inv_txfm2d_add_internal_ssse3(input, output, stride, tx_type, TX_8X16);
+  lowbd_inv_txfm2d_add_ssse3(input, output, stride, tx_type, TX_8X16);
 }
 
 void av1_lowbd_inv_txfm2d_add_16x8_ssse3(const int32_t *input, uint8_t *output,
                                          int stride, TX_TYPE tx_type, int bd) {
   (void)bd;
-  lowbd_inv_txfm2d_add_internal_ssse3(input, output, stride, tx_type, TX_16X8);
+  lowbd_inv_txfm2d_add_ssse3(input, output, stride, tx_type, TX_16X8);
 }
 
 void av1_lowbd_inv_txfm2d_add_16x32_ssse3(const int32_t *input, uint8_t *output,
                                           int stride, TX_TYPE tx_type, int bd) {
   (void)bd;
-  lowbd_inv_txfm2d_add_internal_ssse3(input, output, stride, tx_type, TX_16X32);
+  lowbd_inv_txfm2d_add_ssse3(input, output, stride, tx_type, TX_16X32);
 }
 
 void av1_lowbd_inv_txfm2d_add_32x16_ssse3(const int32_t *input, uint8_t *output,
                                           int stride, TX_TYPE tx_type, int bd) {
   (void)bd;
-  lowbd_inv_txfm2d_add_internal_ssse3(input, output, stride, tx_type, TX_32X16);
+  lowbd_inv_txfm2d_add_ssse3(input, output, stride, tx_type, TX_32X16);
 }
 
 void av1_lowbd_inv_txfm2d_add_32x64_ssse3(const int32_t *input, uint8_t *output,
                                           int stride, TX_TYPE tx_type, int bd) {
   (void)bd;
-  lowbd_inv_txfm2d_add_internal_ssse3(input, output, stride, tx_type, TX_32X64);
+  lowbd_inv_txfm2d_add_ssse3(input, output, stride, tx_type, TX_32X64);
 }
 
 void av1_lowbd_inv_txfm2d_add_64x32_ssse3(const int32_t *input, uint8_t *output,
                                           int stride, TX_TYPE tx_type, int bd) {
   (void)bd;
-  lowbd_inv_txfm2d_add_internal_ssse3(input, output, stride, tx_type, TX_64X32);
+  lowbd_inv_txfm2d_add_ssse3(input, output, stride, tx_type, TX_64X32);
 }
 
 void av1_lowbd_inv_txfm2d_add_4x16_ssse3(const int32_t *input, uint8_t *output,
@@ -2706,25 +3106,25 @@
 void av1_lowbd_inv_txfm2d_add_8x32_ssse3(const int32_t *input, uint8_t *output,
                                          int stride, TX_TYPE tx_type, int bd) {
   (void)bd;
-  lowbd_inv_txfm2d_add_internal_ssse3(input, output, stride, tx_type, TX_8X32);
+  lowbd_inv_txfm2d_add_ssse3(input, output, stride, tx_type, TX_8X32);
 }
 
 void av1_lowbd_inv_txfm2d_add_32x8_ssse3(const int32_t *input, uint8_t *output,
                                          int stride, TX_TYPE tx_type, int bd) {
   (void)bd;
-  lowbd_inv_txfm2d_add_internal_ssse3(input, output, stride, tx_type, TX_32X8);
+  lowbd_inv_txfm2d_add_ssse3(input, output, stride, tx_type, TX_32X8);
 }
 
 void av1_lowbd_inv_txfm2d_add_16x64_ssse3(const int32_t *input, uint8_t *output,
                                           int stride, TX_TYPE tx_type, int bd) {
   (void)bd;
-  lowbd_inv_txfm2d_add_internal_ssse3(input, output, stride, tx_type, TX_16X64);
+  lowbd_inv_txfm2d_add_ssse3(input, output, stride, tx_type, TX_16X64);
 }
 
 void av1_lowbd_inv_txfm2d_add_64x16_ssse3(const int32_t *input, uint8_t *output,
                                           int stride, TX_TYPE tx_type, int bd) {
   (void)bd;
-  lowbd_inv_txfm2d_add_internal_ssse3(input, output, stride, tx_type, TX_64X16);
+  lowbd_inv_txfm2d_add_ssse3(input, output, stride, tx_type, TX_64X16);
 }
 
 typedef void (*inv_txfm_func)(const int32_t *input, uint8_t *output, int stride,