Sync inv identity txfm ssse3/avx2 with C version

Due to the change in CL 50009
(https://aomedia-review.googlesource.com/c/aom/+/50009),
the behaviour of the ssse3 and avx2 versions of the identity txfm
does not match the C version.

Change-Id: I68886f3f37f586cf587b3c3cd31de04eab6b5e4a
diff --git a/av1/common/x86/av1_inv_txfm_avx2.c b/av1/common/x86/av1_inv_txfm_avx2.c
index b36819d..899d068 100644
--- a/av1/common/x86/av1_inv_txfm_avx2.c
+++ b/av1/common/x86/av1_inv_txfm_avx2.c
@@ -863,184 +863,70 @@
   }
 }
 
-static INLINE void iidentity16_row_16xn_avx2(__m256i *out, const int32_t *input,
-                                             int stride, int shift, int height,
-                                             int rect) {
+static INLINE void iidentity_row_16xn_avx2(__m256i *out, const int32_t *input,
+                                           int stride, int shift, int height,
+                                           int txw_idx, int rect_type) {
   const int32_t *input_row = input;
-  const __m256i mshift = _mm256_set1_epi16(1 << (15 + shift));
-  const int16_t scale_fractional = 2 * NewSqrt2 - (2 << NewSqrt2Bits);
-  const __m256i scale =
-      _mm256_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
-  if (!rect) {
-    for (int h = 0; h < height; ++h) {
+  const __m256i scale = _mm256_set1_epi16(NewSqrt2list[txw_idx]);
+  const __m256i rounding = _mm256_set1_epi16((1 << (NewSqrt2Bits - 1)) +
+                                             (1 << (NewSqrt2Bits - shift - 1)));
+  const __m256i one = _mm256_set1_epi16(1);
+  const __m256i scale_rounding = _mm256_unpacklo_epi16(scale, rounding);
+  if (rect_type != 1 && rect_type != -1) {
+    for (int i = 0; i < height; ++i) {
       __m256i src = load_32bit_to_16bit_w16_avx2(input_row);
       input_row += stride;
-      __m256i x = _mm256_mulhrs_epi16(src, scale);
-      __m256i srcx2 = _mm256_adds_epi16(src, src);
-      x = _mm256_adds_epi16(x, srcx2);
-      out[h] = _mm256_mulhrs_epi16(x, mshift);
+      __m256i lo = _mm256_unpacklo_epi16(src, one);
+      __m256i hi = _mm256_unpackhi_epi16(src, one);
+      lo = _mm256_madd_epi16(lo, scale_rounding);
+      hi = _mm256_madd_epi16(hi, scale_rounding);
+      lo = _mm256_srai_epi32(lo, NewSqrt2Bits - shift);
+      hi = _mm256_srai_epi32(hi, NewSqrt2Bits - shift);
+      out[i] = _mm256_packs_epi32(lo, hi);
     }
   } else {
     const __m256i rect_scale =
         _mm256_set1_epi16(NewInvSqrt2 << (15 - NewSqrt2Bits));
-    for (int h = 0; h < height; ++h) {
+    for (int i = 0; i < height; ++i) {
       __m256i src = load_32bit_to_16bit_w16_avx2(input_row);
-      input_row += stride;
       src = _mm256_mulhrs_epi16(src, rect_scale);
-      __m256i x = _mm256_mulhrs_epi16(src, scale);
-      __m256i srcx2 = _mm256_adds_epi16(src, src);
-      x = _mm256_adds_epi16(x, srcx2);
-      out[h] = _mm256_mulhrs_epi16(x, mshift);
+      input_row += stride;
+      __m256i lo = _mm256_unpacklo_epi16(src, one);
+      __m256i hi = _mm256_unpackhi_epi16(src, one);
+      lo = _mm256_madd_epi16(lo, scale_rounding);
+      hi = _mm256_madd_epi16(hi, scale_rounding);
+      lo = _mm256_srai_epi32(lo, NewSqrt2Bits - shift);
+      hi = _mm256_srai_epi32(hi, NewSqrt2Bits - shift);
+      out[i] = _mm256_packs_epi32(lo, hi);
     }
   }
 }
 
-static INLINE void iidentity16_col_16xn_avx2(uint8_t *output, int stride,
-                                             __m256i *buf, int shift,
-                                             int height) {
-  const __m256i mshift = _mm256_set1_epi16(1 << (15 + shift));
-  const int16_t scale_fractional = 2 * NewSqrt2 - (2 << NewSqrt2Bits);
-  const __m256i scale =
-      _mm256_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
+static INLINE void iidentity_col_16xn_avx2(uint8_t *output, int stride,
+                                           __m256i *buf, int shift, int height,
+                                           int txh_idx) {
+  const __m256i scale = _mm256_set1_epi16(NewSqrt2list[txh_idx]);
+  const __m256i scale_rounding = _mm256_set1_epi16(1 << (NewSqrt2Bits - 1));
+  const __m256i shift_rounding = _mm256_set1_epi32(1 << (-shift - 1));
+  const __m256i one = _mm256_set1_epi16(1);
+  const __m256i scale_coeff = _mm256_unpacklo_epi16(scale, scale_rounding);
   for (int h = 0; h < height; ++h) {
-    __m256i x = _mm256_mulhrs_epi16(buf[h], scale);
-    __m256i srcx2 = _mm256_adds_epi16(buf[h], buf[h]);
-    x = _mm256_adds_epi16(x, srcx2);
-    x = _mm256_mulhrs_epi16(x, mshift);
+    __m256i lo = _mm256_unpacklo_epi16(buf[h], one);
+    __m256i hi = _mm256_unpackhi_epi16(buf[h], one);
+    lo = _mm256_madd_epi16(lo, scale_coeff);
+    hi = _mm256_madd_epi16(hi, scale_coeff);
+    lo = _mm256_srai_epi32(lo, NewSqrt2Bits);
+    hi = _mm256_srai_epi32(hi, NewSqrt2Bits);
+    lo = _mm256_add_epi32(lo, shift_rounding);
+    hi = _mm256_add_epi32(hi, shift_rounding);
+    lo = _mm256_srai_epi32(lo, -shift);
+    hi = _mm256_srai_epi32(hi, -shift);
+    __m256i x = _mm256_packs_epi32(lo, hi);
     write_recon_w16_avx2(x, output);
     output += stride;
   }
 }
 
-static INLINE void iidentity32_row_16xn_avx2(__m256i *out, const int32_t *input,
-                                             int stride, int shift, int height,
-                                             int rect) {
-  const int32_t *input_row = input;
-  const __m256i mshift = _mm256_set1_epi16(1 << (15 + shift));
-  if (!rect) {
-    for (int h = 0; h < height; ++h) {
-      __m256i x = load_32bit_to_16bit_w16_avx2(input_row);
-      input_row += stride;
-      x = _mm256_adds_epi16(x, x);
-      x = _mm256_adds_epi16(x, x);
-      out[h] = _mm256_mulhrs_epi16(x, mshift);
-    }
-  } else {
-    const __m256i rect_scale = _mm256_set1_epi16(NewInvSqrt2 * 8);
-    for (int h = 0; h < height; ++h) {
-      __m256i x = load_32bit_to_16bit_w16_avx2(input_row);
-      input_row += stride;
-      x = _mm256_mulhrs_epi16(x, rect_scale);
-      x = _mm256_adds_epi16(x, x);
-      x = _mm256_adds_epi16(x, x);
-      out[h] = _mm256_mulhrs_epi16(x, mshift);
-    }
-  }
-}
-
-static INLINE void iidentity32_col_16xn_avx2(uint8_t *output, int stride,
-                                             __m256i *buf, int shift,
-                                             int height) {
-  const __m256i mshift = _mm256_set1_epi16(1 << (15 + shift));
-  for (int h = 0; h < height; ++h) {
-    __m256i x = _mm256_adds_epi16(buf[h], buf[h]);
-    x = _mm256_adds_epi16(x, x);
-    x = _mm256_mulhrs_epi16(x, mshift);
-    write_recon_w16_avx2(x, output);
-    output += stride;
-  }
-}
-
-static INLINE void iidentity64_row_16xn_avx2(__m256i *out, const int32_t *input,
-                                             int stride, int shift, int height,
-                                             int rect) {
-  const int32_t *input_row = input;
-  const __m256i mshift = _mm256_set1_epi16(1 << (15 + shift));
-  const int16_t scale_fractional = 4 * NewSqrt2 - (5 << NewSqrt2Bits);
-  const __m256i scale =
-      _mm256_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
-  if (!rect) {
-    for (int h = 0; h < height; ++h) {
-      __m256i src = load_32bit_to_16bit_w16_avx2(input_row);
-      input_row += stride;
-      __m256i x = _mm256_mulhrs_epi16(src, scale);
-      __m256i srcx5 = _mm256_adds_epi16(src, src);
-      srcx5 = _mm256_adds_epi16(srcx5, srcx5);
-      srcx5 = _mm256_adds_epi16(srcx5, src);
-      x = _mm256_adds_epi16(x, srcx5);
-      out[h] = _mm256_mulhrs_epi16(x, mshift);
-    }
-  } else {
-    const __m256i rect_scale =
-        _mm256_set1_epi16(NewInvSqrt2 << (15 - NewSqrt2Bits));
-    for (int h = 0; h < height; ++h) {
-      __m256i src = load_32bit_to_16bit_w16_avx2(input_row);
-      input_row += stride;
-      src = _mm256_mulhrs_epi16(src, rect_scale);
-      __m256i x = _mm256_mulhrs_epi16(src, scale);
-      __m256i srcx5 = _mm256_adds_epi16(src, src);
-      srcx5 = _mm256_adds_epi16(srcx5, srcx5);
-      srcx5 = _mm256_adds_epi16(srcx5, src);
-      x = _mm256_adds_epi16(x, srcx5);
-      out[h] = _mm256_mulhrs_epi16(x, mshift);
-    }
-  }
-}
-
-static INLINE void iidentity64_col_16xn_avx2(uint8_t *output, int stride,
-                                             __m256i *buf, int shift,
-                                             int height) {
-  const __m256i mshift = _mm256_set1_epi16(1 << (15 + shift));
-  const int16_t scale_fractional = 4 * NewSqrt2 - (5 << NewSqrt2Bits);
-  const __m256i scale =
-      _mm256_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
-  for (int h = 0; h < height; ++h) {
-    __m256i x = _mm256_mulhrs_epi16(buf[h], scale);
-    __m256i srcx5 = _mm256_adds_epi16(buf[h], buf[h]);
-    srcx5 = _mm256_adds_epi16(srcx5, srcx5);
-    srcx5 = _mm256_adds_epi16(srcx5, buf[h]);
-    x = _mm256_adds_epi16(x, srcx5);
-    x = _mm256_mulhrs_epi16(x, mshift);
-    write_recon_w16_avx2(x, output);
-    output += stride;
-  }
-}
-
-static INLINE void identity_row_16xn_avx2(__m256i *out, const int32_t *input,
-                                          int stride, int shift, int height,
-                                          int txw_idx, int rect_type) {
-  int rect = (rect_type != 1 && rect_type != -1) ? 0 : 1;
-  switch (txw_idx) {
-    case 2:
-      iidentity16_row_16xn_avx2(out, input, stride, shift, height, rect);
-      break;
-    case 3:
-      iidentity32_row_16xn_avx2(out, input, stride, shift, height, rect);
-      break;
-    case 4:
-      iidentity64_row_16xn_avx2(out, input, stride, shift, height, rect);
-      break;
-    default: break;
-  }
-}
-
-static INLINE void identity_col_16xn_avx2(uint8_t *output, int stride,
-                                          __m256i *buf, int shift, int height,
-                                          int txh_idx) {
-  switch (txh_idx) {
-    case 2:
-      iidentity16_col_16xn_avx2(output, stride, buf, shift, height);
-      break;
-    case 3:
-      iidentity32_col_16xn_avx2(output, stride, buf, shift, height);
-      break;
-    case 4:
-      iidentity64_col_16xn_avx2(output, stride, buf, shift, height);
-      break;
-    default: break;
-  }
-}
-
 static INLINE void lowbd_inv_txfm2d_add_idtx_avx2(const int32_t *input,
                                                   uint8_t *output, int stride,
                                                   TX_SIZE tx_size) {
@@ -1054,9 +940,10 @@
   const int rect_type = get_rect_tx_log_ratio(txfm_size_col, txfm_size_row);
   __m256i buf[32];
   for (int i = 0; i < input_stride; i += 16) {
-    identity_row_16xn_avx2(buf, input + i, input_stride, shift[0], row_max,
-                           txw_idx, rect_type);
-    identity_col_16xn_avx2(output + i, stride, buf, shift[1], row_max, txh_idx);
+    iidentity_row_16xn_avx2(buf, input + i, input_stride, shift[0], row_max,
+                            txw_idx, rect_type);
+    iidentity_col_16xn_avx2(output + i, stride, buf, shift[1], row_max,
+                            txh_idx);
   }
 }
 
@@ -1085,8 +972,8 @@
   get_flip_cfg(tx_type, &ud_flip, &lr_flip);
   for (int i = 0; i < txfm_size_col_notzero; i += 16) {
     __m256i buf0[64];
-    identity_row_16xn_avx2(buf0, input + i, input_stride, shift[0],
-                           txfm_size_row_notzero, txw_idx, rect_type);
+    iidentity_row_16xn_avx2(buf0, input + i, input_stride, shift[0],
+                            txfm_size_row_notzero, txw_idx, rect_type);
     col_txfm(buf0, buf0, cos_bit_col);
     __m256i mshift = _mm256_set1_epi16(1 << (15 + shift[1]));
     int k = ud_flip ? (txfm_size_row - 1) : 0;
@@ -1149,8 +1036,8 @@
       }
     }
     for (int j = 0; j < buf_size_w_div16; ++j) {
-      identity_col_16xn_avx2(output + i * 16 * stride + j * 16, stride,
-                             buf1 + j * 16, shift[1], 16, txh_idx);
+      iidentity_col_16xn_avx2(output + i * 16 * stride + j * 16, stride,
+                              buf1 + j * 16, shift[1], 16, txh_idx);
     }
   }
 }
diff --git a/av1/common/x86/av1_inv_txfm_ssse3.c b/av1/common/x86/av1_inv_txfm_ssse3.c
index 25cac55..b9706e9 100644
--- a/av1/common/x86/av1_inv_txfm_ssse3.c
+++ b/av1/common/x86/av1_inv_txfm_ssse3.c
@@ -1563,339 +1563,72 @@
       { NULL, NULL, NULL },
     };
 
-static INLINE void iidentity4_row_8xn_ssse3(__m128i *out, const int32_t *input,
-                                            int stride, int shift, int height) {
+static INLINE void iidentity_row_8xn_ssse3(__m128i *out, const int32_t *input,
+                                           int stride, int shift, int height,
+                                           int txw_idx, int rect_type) {
   const int32_t *input_row = input;
-  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
-  const int16_t scale_fractional = NewSqrt2 - (1 << NewSqrt2Bits);
-  const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
-  for (int h = 0; h < height; ++h) {
-    __m128i src = load_32bit_to_16bit(input_row);
-    input_row += stride;
-    __m128i x = _mm_mulhrs_epi16(src, scale);
-    x = _mm_adds_epi16(x, src);
-    out[h] = _mm_mulhrs_epi16(x, mshift);
-  }
-}
-
-static INLINE void iidentity4_row_rect_8xn_ssse3(__m128i *out,
-                                                 const int32_t *input,
-                                                 int stride, int shift,
-                                                 int height) {
-  const int32_t *input_row = input;
-  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
-  const int16_t scale_fractional = NewSqrt2 - (1 << NewSqrt2Bits);
-  const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
-  const __m128i rect_scale = _mm_set1_epi16(NewInvSqrt2 << (15 - NewSqrt2Bits));
-  for (int h = 0; h < height; ++h) {
-    __m128i src = load_32bit_to_16bit(input_row);
-    input_row += stride;
-    src = _mm_mulhrs_epi16(src, rect_scale);
-    __m128i x = _mm_mulhrs_epi16(src, scale);
-    x = _mm_adds_epi16(x, src);
-    out[h] = _mm_mulhrs_epi16(x, mshift);
-  }
-}
-
-static INLINE void iidentity4_col_8xn_ssse3(uint8_t *output, int stride,
-                                            __m128i *buf, int shift,
-                                            int height) {
-  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
-  const int16_t scale_fractional = NewSqrt2 - (1 << NewSqrt2Bits);
-  const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
-  const __m128i zero = _mm_setzero_si128();
-  for (int h = 0; h < height; ++h) {
-    __m128i x = _mm_mulhrs_epi16(buf[h], scale);
-    x = _mm_adds_epi16(x, buf[h]);
-    x = _mm_mulhrs_epi16(x, mshift);
-    const __m128i pred = _mm_loadl_epi64((__m128i const *)(output));
-    x = _mm_adds_epi16(x, _mm_unpacklo_epi8(pred, zero));
-    __m128i u = _mm_packus_epi16(x, x);
-    _mm_storel_epi64((__m128i *)(output), u);
-    output += stride;
-  }
-}
-
-static INLINE void iidentity8_row_8xn_ssse3(__m128i *out, const int32_t *input,
-                                            int stride, int shift, int height) {
-  const int32_t *input_row = input;
-  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
-  for (int h = 0; h < height; ++h) {
-    __m128i src0 = _mm_load_si128((__m128i *)(input_row));
-    __m128i src1 = _mm_load_si128((__m128i *)(input_row + 4));
-    input_row += stride;
-    __m128i x = _mm_packs_epi32(src0, src1);
-    x = _mm_adds_epi16(x, x);
-    out[h] = _mm_mulhrs_epi16(x, mshift);
-  }
-}
-
-static INLINE void iidentity8_row_rect_8xn_ssse3(__m128i *out,
-                                                 const int32_t *input,
-                                                 int stride, int shift,
-                                                 int height) {
-  const int32_t *input_row = input;
-  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
-  const __m128i rect_scale = _mm_set1_epi16(NewInvSqrt2 * 8);
-  for (int h = 0; h < height; ++h) {
-    __m128i src0 = _mm_load_si128((__m128i *)(input_row));
-    __m128i src1 = _mm_load_si128((__m128i *)(input_row + 4));
-    input_row += stride;
-    __m128i x = _mm_packs_epi32(src0, src1);
-    x = _mm_mulhrs_epi16(x, rect_scale);
-    x = _mm_adds_epi16(x, x);
-    out[h] = _mm_mulhrs_epi16(x, mshift);
-  }
-}
-
-static INLINE void iidentity8_col_8xn_ssse3(uint8_t *output, int stride,
-                                            __m128i *buf, int shift,
-                                            int height) {
-  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
-  const __m128i zero = _mm_setzero_si128();
-  for (int h = 0; h < height; ++h) {
-    __m128i x = _mm_adds_epi16(buf[h], buf[h]);
-    x = _mm_mulhrs_epi16(x, mshift);
-    const __m128i pred = _mm_loadl_epi64((__m128i const *)(output));
-    x = _mm_adds_epi16(x, _mm_unpacklo_epi8(pred, zero));
-    __m128i u = _mm_packus_epi16(x, x);
-    _mm_storel_epi64((__m128i *)(output), u);
-    output += stride;
-  }
-}
-
-static INLINE void iidentity16_row_8xn_ssse3(__m128i *out, const int32_t *input,
-                                             int stride, int shift,
-                                             int height) {
-  const int32_t *input_row = input;
-  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
-  const int16_t scale_fractional = 2 * NewSqrt2 - (2 << NewSqrt2Bits);
-  const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
-  for (int h = 0; h < height; ++h) {
-    __m128i src = load_32bit_to_16bit(input_row);
-    input_row += stride;
-    __m128i x = _mm_mulhrs_epi16(src, scale);
-    __m128i srcx2 = _mm_adds_epi16(src, src);
-    x = _mm_adds_epi16(x, srcx2);
-    out[h] = _mm_mulhrs_epi16(x, mshift);
-  }
-}
-
-static INLINE void iidentity16_row_rect_8xn_ssse3(__m128i *out,
-                                                  const int32_t *input,
-                                                  int stride, int shift,
-                                                  int height) {
-  const int32_t *input_row = input;
-  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
-  const int16_t scale_fractional = 2 * NewSqrt2 - (2 << NewSqrt2Bits);
-  const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
-  const __m128i rect_scale = _mm_set1_epi16(NewInvSqrt2 << (15 - NewSqrt2Bits));
-  for (int h = 0; h < height; ++h) {
-    __m128i src = load_32bit_to_16bit(input_row);
-    input_row += stride;
-    src = _mm_mulhrs_epi16(src, rect_scale);
-    __m128i x = _mm_mulhrs_epi16(src, scale);
-    __m128i srcx2 = _mm_adds_epi16(src, src);
-    x = _mm_adds_epi16(x, srcx2);
-    out[h] = _mm_mulhrs_epi16(x, mshift);
-  }
-}
-
-static INLINE void iidentity16_col_8xn_ssse3(uint8_t *output, int stride,
-                                             __m128i *buf, int shift,
-                                             int height) {
-  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
-  const int16_t scale_fractional = 2 * NewSqrt2 - (2 << NewSqrt2Bits);
-  const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
-  const __m128i zero = _mm_setzero_si128();
-  for (int h = 0; h < height; ++h) {
-    __m128i x = _mm_mulhrs_epi16(buf[h], scale);
-    __m128i srcx2 = _mm_adds_epi16(buf[h], buf[h]);
-    x = _mm_adds_epi16(x, srcx2);
-    x = _mm_mulhrs_epi16(x, mshift);
-    const __m128i pred = _mm_loadl_epi64((__m128i const *)(output));
-    x = _mm_adds_epi16(x, _mm_unpacklo_epi8(pred, zero));
-    __m128i u = _mm_packus_epi16(x, x);
-    _mm_storel_epi64((__m128i *)(output), u);
-    output += stride;
-  }
-}
-
-static INLINE void iidentity32_row_8xn_ssse3(__m128i *out, const int32_t *input,
-                                             int stride, int shift,
-                                             int height) {
-  const int32_t *input_row = input;
-  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
-  for (int h = 0; h < height; ++h) {
-    __m128i src0 = _mm_load_si128((__m128i *)(input_row));
-    __m128i src1 = _mm_load_si128((__m128i *)(input_row + 4));
-    input_row += stride;
-    __m128i x = _mm_packs_epi32(src0, src1);
-    x = _mm_adds_epi16(x, x);
-    x = _mm_adds_epi16(x, x);
-    out[h] = _mm_mulhrs_epi16(x, mshift);
-  }
-}
-
-static INLINE void iidentity32_row_rect_8xn_ssse3(__m128i *out,
-                                                  const int32_t *input,
-                                                  int stride, int shift,
-                                                  int height) {
-  const int32_t *input_row = input;
-  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
-  const __m128i rect_scale = _mm_set1_epi16(NewInvSqrt2 * 8);
-  for (int h = 0; h < height; ++h) {
-    __m128i src0 = _mm_load_si128((__m128i *)(input_row));
-    __m128i src1 = _mm_load_si128((__m128i *)(input_row + 4));
-    input_row += stride;
-    __m128i x = _mm_packs_epi32(src0, src1);
-    x = _mm_mulhrs_epi16(x, rect_scale);
-    x = _mm_adds_epi16(x, x);
-    x = _mm_adds_epi16(x, x);
-    out[h] = _mm_mulhrs_epi16(x, mshift);
-  }
-}
-
-static INLINE void iidentity32_col_8xn_ssse3(uint8_t *output, int stride,
-                                             __m128i *buf, int shift,
-                                             int height) {
-  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
-  const __m128i zero = _mm_setzero_si128();
-  for (int h = 0; h < height; ++h) {
-    __m128i x = _mm_adds_epi16(buf[h], buf[h]);
-    x = _mm_adds_epi16(x, x);
-    x = _mm_mulhrs_epi16(x, mshift);
-    const __m128i pred = _mm_loadl_epi64((__m128i const *)(output));
-    x = _mm_adds_epi16(x, _mm_unpacklo_epi8(pred, zero));
-    __m128i u = _mm_packus_epi16(x, x);
-    _mm_storel_epi64((__m128i *)(output), u);
-    output += stride;
-  }
-}
-
-static INLINE void iidentity64_row_8xn_ssse3(__m128i *out, const int32_t *input,
-                                             int stride, int shift,
-                                             int height) {
-  const int32_t *input_row = input;
-  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
-  const int16_t scale_fractional = 4 * NewSqrt2 - (5 << NewSqrt2Bits);
-  const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
-  for (int h = 0; h < height; ++h) {
-    __m128i src = load_32bit_to_16bit(input_row);
-    input_row += stride;
-    __m128i x = _mm_mulhrs_epi16(src, scale);
-    __m128i srcx5 = _mm_adds_epi16(src, src);
-    srcx5 = _mm_adds_epi16(srcx5, srcx5);
-    srcx5 = _mm_adds_epi16(srcx5, src);
-    x = _mm_adds_epi16(x, srcx5);
-    out[h] = _mm_mulhrs_epi16(x, mshift);
-  }
-}
-
-static INLINE void iidentity64_row_rect_8xn_ssse3(__m128i *out,
-                                                  const int32_t *input,
-                                                  int stride, int shift,
-                                                  int height) {
-  const int32_t *input_row = input;
-  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
-  const int16_t scale_fractional = 4 * NewSqrt2 - (5 << NewSqrt2Bits);
-  const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
-  const __m128i rect_scale = _mm_set1_epi16(NewInvSqrt2 << (15 - NewSqrt2Bits));
-  for (int h = 0; h < height; ++h) {
-    __m128i src = load_32bit_to_16bit(input_row);
-    input_row += stride;
-    src = _mm_mulhrs_epi16(src, rect_scale);
-    __m128i x = _mm_mulhrs_epi16(src, scale);
-    __m128i srcx5 = _mm_adds_epi16(src, src);
-    srcx5 = _mm_adds_epi16(srcx5, srcx5);
-    srcx5 = _mm_adds_epi16(srcx5, src);
-    x = _mm_adds_epi16(x, srcx5);
-    out[h] = _mm_mulhrs_epi16(x, mshift);
-  }
-}
-
-static INLINE void iidentity64_col_8xn_ssse3(uint8_t *output, int stride,
-                                             __m128i *buf, int shift,
-                                             int height) {
-  const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
-  const int16_t scale_fractional = 4 * NewSqrt2 - (5 << NewSqrt2Bits);
-  const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
-  const __m128i zero = _mm_setzero_si128();
-  for (int h = 0; h < height; ++h) {
-    __m128i x = _mm_mulhrs_epi16(buf[h], scale);
-    __m128i srcx5 = _mm_adds_epi16(buf[h], buf[h]);
-    srcx5 = _mm_adds_epi16(srcx5, srcx5);
-    srcx5 = _mm_adds_epi16(srcx5, buf[h]);
-    x = _mm_adds_epi16(x, srcx5);
-    x = _mm_mulhrs_epi16(x, mshift);
-    const __m128i pred = _mm_loadl_epi64((__m128i const *)(output));
-    x = _mm_adds_epi16(x, _mm_unpacklo_epi8(pred, zero));
-    __m128i u = _mm_packus_epi16(x, x);
-    _mm_storel_epi64((__m128i *)(output), u);
-    output += stride;
-  }
-}
-
-static INLINE void identity_row_8xn_ssse3(__m128i *out, const int32_t *input,
-                                          int stride, int shift, int height,
-                                          int txw_idx, int rect_type) {
+  const __m128i scale = _mm_set1_epi16(NewSqrt2list[txw_idx]);
+  const __m128i rounding = _mm_set1_epi16((1 << (NewSqrt2Bits - 1)) +
+                                          (1 << (NewSqrt2Bits - shift - 1)));
+  const __m128i one = _mm_set1_epi16(1);
+  const __m128i scale_rounding = _mm_unpacklo_epi16(scale, rounding);
   if (rect_type != 1 && rect_type != -1) {
-    switch (txw_idx) {
-      case 0:
-        iidentity4_row_8xn_ssse3(out, input, stride, shift, height);
-        break;
-      case 1:
-        iidentity8_row_8xn_ssse3(out, input, stride, shift, height);
-        break;
-      case 2:
-        iidentity16_row_8xn_ssse3(out, input, stride, shift, height);
-        break;
-      case 3:
-        iidentity32_row_8xn_ssse3(out, input, stride, shift, height);
-        break;
-      case 4:
-        iidentity64_row_8xn_ssse3(out, input, stride, shift, height);
-        break;
-      default: break;
+    for (int i = 0; i < height; ++i) {
+      __m128i src = load_32bit_to_16bit(input_row);
+      input_row += stride;
+      __m128i lo = _mm_unpacklo_epi16(src, one);
+      __m128i hi = _mm_unpackhi_epi16(src, one);
+      lo = _mm_madd_epi16(lo, scale_rounding);
+      hi = _mm_madd_epi16(hi, scale_rounding);
+      lo = _mm_srai_epi32(lo, NewSqrt2Bits - shift);
+      hi = _mm_srai_epi32(hi, NewSqrt2Bits - shift);
+      out[i] = _mm_packs_epi32(lo, hi);
     }
   } else {
-    switch (txw_idx) {
-      case 0:
-        iidentity4_row_rect_8xn_ssse3(out, input, stride, shift, height);
-        break;
-      case 1:
-        iidentity8_row_rect_8xn_ssse3(out, input, stride, shift, height);
-        break;
-      case 2:
-        iidentity16_row_rect_8xn_ssse3(out, input, stride, shift, height);
-        break;
-      case 3:
-        iidentity32_row_rect_8xn_ssse3(out, input, stride, shift, height);
-        break;
-      case 4:
-        iidentity64_row_rect_8xn_ssse3(out, input, stride, shift, height);
-        break;
-      default: break;
+    const __m128i rect_scale =
+        _mm_set1_epi16(NewInvSqrt2 << (15 - NewSqrt2Bits));
+    for (int i = 0; i < height; ++i) {
+      __m128i src = load_32bit_to_16bit(input_row);
+      src = _mm_mulhrs_epi16(src, rect_scale);
+      input_row += stride;
+      __m128i lo = _mm_unpacklo_epi16(src, one);
+      __m128i hi = _mm_unpackhi_epi16(src, one);
+      lo = _mm_madd_epi16(lo, scale_rounding);
+      hi = _mm_madd_epi16(hi, scale_rounding);
+      lo = _mm_srai_epi32(lo, NewSqrt2Bits - shift);
+      hi = _mm_srai_epi32(hi, NewSqrt2Bits - shift);
+      out[i] = _mm_packs_epi32(lo, hi);
     }
   }
 }
 
-static INLINE void identity_col_8xn_ssse3(uint8_t *output, int stride,
-                                          __m128i *buf, int shift, int height,
-                                          int txh_idx) {
-  switch (txh_idx) {
-    case 0: iidentity4_col_8xn_ssse3(output, stride, buf, shift, height); break;
-    case 1: iidentity8_col_8xn_ssse3(output, stride, buf, shift, height); break;
-    case 2:
-      iidentity16_col_8xn_ssse3(output, stride, buf, shift, height);
-      break;
-    case 3:
-      iidentity32_col_8xn_ssse3(output, stride, buf, shift, height);
-      break;
-    case 4:
-      iidentity64_col_8xn_ssse3(output, stride, buf, shift, height);
-      break;
-    default: break;
+static INLINE void iidentity_col_8xn_ssse3(uint8_t *output, int stride,
+                                           __m128i *buf, int shift, int height,
+                                           int txh_idx) {
+  const __m128i scale = _mm_set1_epi16(NewSqrt2list[txh_idx]);
+  const __m128i scale_rounding = _mm_set1_epi16(1 << (NewSqrt2Bits - 1));
+  const __m128i shift_rounding = _mm_set1_epi32(1 << (-shift - 1));
+  const __m128i one = _mm_set1_epi16(1);
+  const __m128i scale_coeff = _mm_unpacklo_epi16(scale, scale_rounding);
+  const __m128i zero = _mm_setzero_si128();
+  for (int h = 0; h < height; ++h) {
+    __m128i lo = _mm_unpacklo_epi16(buf[h], one);
+    __m128i hi = _mm_unpackhi_epi16(buf[h], one);
+    lo = _mm_madd_epi16(lo, scale_coeff);
+    hi = _mm_madd_epi16(hi, scale_coeff);
+    lo = _mm_srai_epi32(lo, NewSqrt2Bits);
+    hi = _mm_srai_epi32(hi, NewSqrt2Bits);
+    lo = _mm_add_epi32(lo, shift_rounding);
+    hi = _mm_add_epi32(hi, shift_rounding);
+    lo = _mm_srai_epi32(lo, -shift);
+    hi = _mm_srai_epi32(hi, -shift);
+    __m128i x = _mm_packs_epi32(lo, hi);
+
+    const __m128i pred = _mm_loadl_epi64((__m128i const *)(output));
+    x = _mm_adds_epi16(x, _mm_unpacklo_epi8(pred, zero));
+    __m128i u = _mm_packus_epi16(x, x);
+    _mm_storel_epi64((__m128i *)(output), u);
+    output += stride;
   }
 }
 
@@ -1913,10 +1646,10 @@
   __m128i buf[32];
 
   for (int i = 0; i<input_stride>> 3; ++i) {
-    identity_row_8xn_ssse3(buf, input + 8 * i, input_stride, shift[0], row_max,
-                           txw_idx, rect_type);
-    identity_col_8xn_ssse3(output + 8 * i, stride, buf, shift[1], row_max,
-                           txh_idx);
+    iidentity_row_8xn_ssse3(buf, input + 8 * i, input_stride, shift[0], row_max,
+                            txw_idx, rect_type);
+    iidentity_col_8xn_ssse3(output + 8 * i, stride, buf, shift[1], row_max,
+                            txh_idx);
   }
 }
 
@@ -2083,8 +1816,8 @@
   get_flip_cfg(tx_type, &ud_flip, &lr_flip);
   for (int i = 0; i < AOMMIN(4, buf_size_w_div8); i++) {
     __m128i buf0[64];
-    identity_row_8xn_ssse3(buf0, input + 8 * i, input_stride, shift[0],
-                           txfm_size_row_notzero, txw_idx, rect_type);
+    iidentity_row_8xn_ssse3(buf0, input + 8 * i, input_stride, shift[0],
+                            txfm_size_row_notzero, txw_idx, rect_type);
     col_txfm(buf0, buf0, cos_bit_col);
     __m128i mshift = _mm_set1_epi16(1 << (15 + shift[1]));
     int k = ud_flip ? (txfm_size_row - 1) : 0;
@@ -2149,8 +1882,8 @@
     }
 
     for (int j = 0; j < buf_size_w_div8; ++j) {
-      identity_col_8xn_ssse3(output + i * 8 * stride + j * 8, stride,
-                             buf1 + j * 8, shift[1], 8, txh_idx);
+      iidentity_col_8xn_ssse3(output + i * 8 * stride + j * 8, stride,
+                              buf1 + j * 8, shift[1], 8, txh_idx);
     }
   }
 }
diff --git a/av1/common/x86/av1_inv_txfm_ssse3.h b/av1/common/x86/av1_inv_txfm_ssse3.h
index 96dc0d6..ccdb006 100644
--- a/av1/common/x86/av1_inv_txfm_ssse3.h
+++ b/av1/common/x86/av1_inv_txfm_ssse3.h
@@ -53,6 +53,10 @@
     out1 = _mm_subs_epi16(_in0, _in1);                  \
   } while (0)
 
+#ifdef __cplusplus
+extern "C" {
+#endif
+
 static INLINE void round_shift_16bit_ssse3(__m128i *in, int size, int bit) {
   if (bit < 0) {
     const __m128i scale = _mm_set1_epi16(1 << (15 + bit));
@@ -66,10 +70,6 @@
   }
 }
 
-#ifdef __cplusplus
-extern "C" {
-#endif
-
 // 1D itx types
 typedef enum ATTRIBUTE_PACKED {
   IDCT_1D,
@@ -93,6 +93,10 @@
   IIDENTITY_1D, IADST_1D,     IIDENTITY_1D, IFLIPADST_1D,
 };
 
+// Fixed-point Sqrt2^1 .. Sqrt2^5, with 4096 representing 1.0.
+static int32_t NewSqrt2list[TX_SIZES] = { 5793, 2 * 4096, 2 * 5793, 4 * 4096,
+                                          4 * 5793 };
+
 typedef void (*transform_1d_ssse3)(const __m128i *input, __m128i *output,
                                    int8_t cos_bit);
 
diff --git a/test/av1_inv_txfm2d_test.cc b/test/av1_inv_txfm2d_test.cc
index 73feb16..c07ce09 100644
--- a/test/av1_inv_txfm2d_test.cc
+++ b/test/av1_inv_txfm2d_test.cc
@@ -325,11 +325,13 @@
 #endif  // HAVE_SSSE3
 
 #if HAVE_AVX2
-#if defined(_MSC_VER) || defined(__AVX2__)
-#include "av1/common/x86/av1_inv_txfm_avx2.h"
+extern "C" void av1_lowbd_inv_txfm2d_add_avx2(const int32_t *input,
+                                              uint8_t *output, int stride,
+                                              TX_TYPE tx_type, TX_SIZE tx_size,
+                                              int eob);
+
 INSTANTIATE_TEST_CASE_P(AVX2, AV1LbdInvTxfm2d,
                         ::testing::Values(av1_lowbd_inv_txfm2d_add_avx2));
-#endif  // (_MSC_VER) || (__AVX2__)
 #endif  // HAVE_AVX2
 
 }  // namespace