Fix lowbd inv_txfm's mismatch between C and SIMD

BUG=aomedia:2350

* avx2 av1_{highbd,}_inv_txfm_add
When the intput value of round_shift_16bit_w16_avx2 is too large,
such as 32767, this avx2 function produce different result with
c version round_shift.

* lowbd_inv_txfm2d_add_{4x16,16x4}_ssse3
This function missmatch with C version when the values in
dqcoeff is too large, which maybe not prouduced by forward
transform in AV1 encoder.

* Update unitest AV1LbdInvTxfm2d.gt_int16 to replay this
issue.

Change-Id: I7c249de60657f26c205090b5d00b92a902c04bb5
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index aca5ec7..c4de2ee 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -111,9 +111,7 @@
 
 #inv txfm
 add_proto qw/void av1_inv_txfm_add/, "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
-# TODO(http://crbug.com/aomedia/2350): avx2 is disabled due to test vector
-# mismatches.
-specialize qw/av1_inv_txfm_add ssse3 neon/;
+specialize qw/av1_inv_txfm_add ssse3 avx2 neon/;
 
 add_proto qw/void av1_highbd_inv_txfm_add/, "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
 # TODO(http://crbug.com/aomedia/2350): avx2 is disabled due to test vector
diff --git a/av1/common/x86/av1_inv_txfm_avx2.c b/av1/common/x86/av1_inv_txfm_avx2.c
index cf1f947..3f5ad89 100644
--- a/av1/common/x86/av1_inv_txfm_avx2.c
+++ b/av1/common/x86/av1_inv_txfm_avx2.c
@@ -1638,6 +1638,7 @@
   assert(row_txfm != NULL);
   int ud_flip, lr_flip;
   get_flip_cfg(tx_type, &ud_flip, &lr_flip);
+  const __m256i scale0 = _mm256_set1_epi16(1 << (15 + shift[0]));
   for (int i = 0; i < buf_size_nonzero_h_div16; i++) {
     __m256i buf0[64];
     const int32_t *input_row = input + (i << 4) * input_stride;
@@ -1652,7 +1653,9 @@
       round_shift_avx2(buf0, buf0, input_stride);  // rect special code
     }
     row_txfm(buf0, buf0, cos_bit_row);
-    round_shift_16bit_w16_avx2(buf0, txfm_size_col, shift[0]);
+    for (int j = 0; j < txfm_size_col; ++j) {
+      buf0[j] = _mm256_mulhrs_epi16(buf0[j], scale0);
+    }
 
     __m256i *buf1_cur = buf1 + (i << 4);
     if (lr_flip) {
@@ -1668,10 +1671,13 @@
       }
     }
   }
+  const __m256i scale1 = _mm256_set1_epi16(1 << (15 + shift[1]));
   for (int i = 0; i < buf_size_w_div16; i++) {
     __m256i *buf1_cur = buf1 + i * txfm_size_row;
     col_txfm(buf1_cur, buf1_cur, cos_bit_col);
-    round_shift_16bit_w16_avx2(buf1_cur, txfm_size_row, shift[1]);
+    for (int j = 0; j < txfm_size_row; ++j) {
+      buf1_cur[j] = _mm256_mulhrs_epi16(buf1_cur[j], scale1);
+    }
   }
   for (int i = 0; i < buf_size_w_div16; i++) {
     lowbd_write_buffer_16xn_avx2(buf1 + i * txfm_size_row, output + 16 * i,
diff --git a/av1/common/x86/av1_inv_txfm_ssse3.c b/av1/common/x86/av1_inv_txfm_ssse3.c
index de0a561..2208a91 100644
--- a/av1/common/x86/av1_inv_txfm_ssse3.c
+++ b/av1/common/x86/av1_inv_txfm_ssse3.c
@@ -2820,8 +2820,22 @@
     load_buffer_32bit_to_16bit_w4(input_cur, txfm_size_col, buf_cur,
                                   row_one_loop);
     transpose_16bit_4x8(buf_cur, buf_cur);
-    row_txfm(buf_cur, buf_cur, cos_bit_row);
-    round_shift_16bit_ssse3(buf_cur, row_one_loop, shift[0]);
+    if (row_txfm == iidentity4_new_ssse3) {
+      const __m128i scale = pair_set_epi16(NewSqrt2, 3 << (NewSqrt2Bits - 1));
+      const __m128i ones = _mm_set1_epi16(1);
+      for (int j = 0; j < 4; ++j) {
+        const __m128i buf_lo = _mm_unpacklo_epi16(buf_cur[j], ones);
+        const __m128i buf_hi = _mm_unpackhi_epi16(buf_cur[j], ones);
+        const __m128i buf_32_lo =
+            _mm_srai_epi32(_mm_madd_epi16(buf_lo, scale), (NewSqrt2Bits + 1));
+        const __m128i buf_32_hi =
+            _mm_srai_epi32(_mm_madd_epi16(buf_hi, scale), (NewSqrt2Bits + 1));
+        buf_cur[j] = _mm_packs_epi32(buf_32_lo, buf_32_hi);
+      }
+    } else {
+      row_txfm(buf_cur, buf_cur, cos_bit_row);
+      round_shift_16bit_ssse3(buf_cur, row_one_loop, shift[0]);
+    }
     if (lr_flip) {
       __m128i temp[8];
       flip_buf_sse2(buf_cur, temp, txfm_size_col);
@@ -2867,8 +2881,22 @@
                                txfm_size_row);
     transpose_16bit_8x4(buf_cur, buf_cur);
   }
-  row_txfm(buf, buf, cos_bit_row);
-  round_shift_16bit_ssse3(buf, txfm_size_col, shift[0]);
+  if (row_txfm == iidentity16_new_ssse3) {
+    const __m128i scale = pair_set_epi16(2 * NewSqrt2, 3 << (NewSqrt2Bits - 1));
+    const __m128i ones = _mm_set1_epi16(1);
+    for (int j = 0; j < 16; ++j) {
+      const __m128i buf_lo = _mm_unpacklo_epi16(buf[j], ones);
+      const __m128i buf_hi = _mm_unpackhi_epi16(buf[j], ones);
+      const __m128i buf_32_lo =
+          _mm_srai_epi32(_mm_madd_epi16(buf_lo, scale), (NewSqrt2Bits + 1));
+      const __m128i buf_32_hi =
+          _mm_srai_epi32(_mm_madd_epi16(buf_hi, scale), (NewSqrt2Bits + 1));
+      buf[j] = _mm_packs_epi32(buf_32_lo, buf_32_hi);
+    }
+  } else {
+    row_txfm(buf, buf, cos_bit_row);
+    round_shift_16bit_ssse3(buf, txfm_size_col, shift[0]);
+  }
   if (lr_flip) {
     __m128i temp[16];
     flip_buf_sse2(buf, temp, 16);
@@ -2916,22 +2944,14 @@
       break;
   }
 }
+
 void av1_inv_txfm_add_ssse3(const tran_low_t *dqcoeff, uint8_t *dst, int stride,
                             const TxfmParam *txfm_param) {
-  const TX_TYPE tx_type = txfm_param->tx_type;
   if (!txfm_param->lossless) {
-    switch (txfm_param->tx_size) {
-      case TX_4X16:
-      case TX_16X4:
-        // TODO(http://crbug.com/aomedia/2350): the ssse3 versions cause test
-        // vector mismatches.
-        av1_inv_txfm_add_c(dqcoeff, dst, stride, txfm_param);
-        break;
-      default:
-        av1_lowbd_inv_txfm2d_add_ssse3(dqcoeff, dst, stride, tx_type,
-                                       txfm_param->tx_size, txfm_param->eob);
-        break;
-    }
+    const TX_TYPE tx_type = txfm_param->tx_type;
+    av1_lowbd_inv_txfm2d_add_ssse3(dqcoeff, dst, stride, tx_type,
+                                   txfm_param->tx_size, txfm_param->eob);
+
   } else {
     av1_inv_txfm_add_c(dqcoeff, dst, stride, txfm_param);
   }
diff --git a/test/av1_inv_txfm2d_test.cc b/test/av1_inv_txfm2d_test.cc
index 5432130..46cf068 100644
--- a/test/av1_inv_txfm2d_test.cc
+++ b/test/av1_inv_txfm2d_test.cc
@@ -38,6 +38,25 @@
 
 namespace {
 
+static const char *tx_type_name[] = {
+  "DCT_DCT",
+  "ADST_DCT",
+  "DCT_ADST",
+  "ADST_ADST",
+  "FLIPADST_DCT",
+  "DCT_FLIPADST",
+  "FLIPADST_FLIPADST",
+  "ADST_FLIPADST",
+  "FLIPADST_ADST",
+  "IDTX",
+  "V_DCT",
+  "H_DCT",
+  "V_ADST",
+  "H_ADST",
+  "V_FLIPADST",
+  "H_FLIPADST",
+};
+
 // AV1InvTxfm2dParam argument list:
 // tx_type_, tx_size_, max_error_, max_avg_error_
 typedef ::testing::tuple<TX_TYPE, TX_SIZE, int, double> AV1InvTxfm2dParam;
@@ -243,14 +262,15 @@
 class AV1LbdInvTxfm2d : public ::testing::TestWithParam<AV1LbdInvTxfm2dParam> {
  public:
   virtual void SetUp() { target_func_ = GET_PARAM(0); }
-  void RunAV1InvTxfm2dTest(TX_TYPE tx_type, TX_SIZE tx_size, int run_times);
+  void RunAV1InvTxfm2dTest(TX_TYPE tx_type, TX_SIZE tx_size, int run_times,
+                           int gt_int16 = 0);
 
  private:
   LbdInvTxfm2dFunc target_func_;
 };
 
 void AV1LbdInvTxfm2d::RunAV1InvTxfm2dTest(TX_TYPE tx_type, TX_SIZE tx_size,
-                                          int run_times) {
+                                          int run_times, int gt_int16) {
   FwdTxfm2dFunc fwd_func_ = libaom_test::fwd_txfm_func_ls[tx_size];
   InvTxfm2dFunc ref_func_ = libaom_test::inv_txfm_func_ls[tx_size];
   if (fwd_func_ == NULL || ref_func_ == NULL || target_func_ == NULL) {
@@ -275,6 +295,7 @@
   const int16_t eobmax = rows_nonezero * cols_nonezero;
   ACMRandom rnd(ACMRandom::DeterministicSeed());
   int randTimes = run_times == 1 ? (eobmax + 500) : 1;
+
   for (int cnt = 0; cnt < randTimes; ++cnt) {
     const int16_t max_in = (1 << (bd)) - 1;
     for (int r = 0; r < BLK_WIDTH; ++r) {
@@ -291,7 +312,9 @@
     for (int i = eob; i < eobmax; i++) {
       inv_input[scan[i]] = 0;
     }
-
+    if (gt_int16) {
+      inv_input[scan[eob - 1]] = ((int32_t)INT16_MAX * 100 / 141);
+    }
     aom_usec_timer timer;
     aom_usec_timer_start(&timer);
     for (int i = 0; i < run_times; ++i) {
@@ -313,10 +336,13 @@
     for (int r = 0; r < rows; ++r) {
       for (int c = 0; c < cols; ++c) {
         uint8_t ref_value = static_cast<uint8_t>(ref_output[r * stride + c]);
+        if (ref_value != output[r * stride + c]) {
+          printf(" ");
+        }
         ASSERT_EQ(ref_value, output[r * stride + c])
             << "[" << r << "," << c << "] " << cnt
             << " tx_size: " << static_cast<int>(tx_size)
-            << " tx_type: " << tx_type << " eob " << eob;
+            << " tx_type: " << tx_type_name[tx_type] << " eob " << eob;
       }
     }
   }
@@ -334,8 +360,23 @@
   }
 }
 
-TEST_P(AV1LbdInvTxfm2d, DISABLED_Speed) {
+TEST_P(AV1LbdInvTxfm2d, gt_int16) {
+  static const TX_TYPE types[] = {
+    DCT_DCT, ADST_DCT, FLIPADST_DCT, IDTX, V_DCT, H_DCT, H_ADST, H_FLIPADST
+  };
   for (int j = 0; j < (int)(TX_SIZES_ALL); ++j) {
+    const TX_SIZE sz = static_cast<TX_SIZE>(j);
+    for (uint8_t i = 0; i < sizeof(types) / sizeof(TX_TYPE); ++i) {
+      const TX_TYPE tp = types[i];
+      if (libaom_test::IsTxSizeTypeValid(sz, tp)) {
+        RunAV1InvTxfm2dTest(tp, sz, 1, 1);
+      }
+    }
+  }
+}
+
+TEST_P(AV1LbdInvTxfm2d, DISABLED_Speed) {
+  for (int j = 1; j < (int)(TX_SIZES_ALL); ++j) {
     for (int i = 0; i < (int)TX_TYPES; ++i) {
       if (libaom_test::IsTxSizeTypeValid(static_cast<TX_SIZE>(j),
                                          static_cast<TX_TYPE>(i))) {