Fix lowbd inv_txfm's mismatch between C and SIMD
BUG=aomedia:2350
* avx2 av1_{highbd,}_inv_txfm_add
When the intput value of round_shift_16bit_w16_avx2 is too large,
such as 32767, this avx2 function produce different result with
c version round_shift.
* lowbd_inv_txfm2d_add_{4x16,16x4}_ssse3
This function missmatch with C version when the values in
dqcoeff is too large, which maybe not prouduced by forward
transform in AV1 encoder.
* Update unitest AV1LbdInvTxfm2d.gt_int16 to replay this
issue.
Change-Id: I7c249de60657f26c205090b5d00b92a902c04bb5
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index aca5ec7..c4de2ee 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -111,9 +111,7 @@
#inv txfm
add_proto qw/void av1_inv_txfm_add/, "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
-# TODO(http://crbug.com/aomedia/2350): avx2 is disabled due to test vector
-# mismatches.
-specialize qw/av1_inv_txfm_add ssse3 neon/;
+specialize qw/av1_inv_txfm_add ssse3 avx2 neon/;
add_proto qw/void av1_highbd_inv_txfm_add/, "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
# TODO(http://crbug.com/aomedia/2350): avx2 is disabled due to test vector
diff --git a/av1/common/x86/av1_inv_txfm_avx2.c b/av1/common/x86/av1_inv_txfm_avx2.c
index cf1f947..3f5ad89 100644
--- a/av1/common/x86/av1_inv_txfm_avx2.c
+++ b/av1/common/x86/av1_inv_txfm_avx2.c
@@ -1638,6 +1638,7 @@
assert(row_txfm != NULL);
int ud_flip, lr_flip;
get_flip_cfg(tx_type, &ud_flip, &lr_flip);
+ const __m256i scale0 = _mm256_set1_epi16(1 << (15 + shift[0]));
for (int i = 0; i < buf_size_nonzero_h_div16; i++) {
__m256i buf0[64];
const int32_t *input_row = input + (i << 4) * input_stride;
@@ -1652,7 +1653,9 @@
round_shift_avx2(buf0, buf0, input_stride); // rect special code
}
row_txfm(buf0, buf0, cos_bit_row);
- round_shift_16bit_w16_avx2(buf0, txfm_size_col, shift[0]);
+ for (int j = 0; j < txfm_size_col; ++j) {
+ buf0[j] = _mm256_mulhrs_epi16(buf0[j], scale0);
+ }
__m256i *buf1_cur = buf1 + (i << 4);
if (lr_flip) {
@@ -1668,10 +1671,13 @@
}
}
}
+ const __m256i scale1 = _mm256_set1_epi16(1 << (15 + shift[1]));
for (int i = 0; i < buf_size_w_div16; i++) {
__m256i *buf1_cur = buf1 + i * txfm_size_row;
col_txfm(buf1_cur, buf1_cur, cos_bit_col);
- round_shift_16bit_w16_avx2(buf1_cur, txfm_size_row, shift[1]);
+ for (int j = 0; j < txfm_size_row; ++j) {
+ buf1_cur[j] = _mm256_mulhrs_epi16(buf1_cur[j], scale1);
+ }
}
for (int i = 0; i < buf_size_w_div16; i++) {
lowbd_write_buffer_16xn_avx2(buf1 + i * txfm_size_row, output + 16 * i,
diff --git a/av1/common/x86/av1_inv_txfm_ssse3.c b/av1/common/x86/av1_inv_txfm_ssse3.c
index de0a561..2208a91 100644
--- a/av1/common/x86/av1_inv_txfm_ssse3.c
+++ b/av1/common/x86/av1_inv_txfm_ssse3.c
@@ -2820,8 +2820,22 @@
load_buffer_32bit_to_16bit_w4(input_cur, txfm_size_col, buf_cur,
row_one_loop);
transpose_16bit_4x8(buf_cur, buf_cur);
- row_txfm(buf_cur, buf_cur, cos_bit_row);
- round_shift_16bit_ssse3(buf_cur, row_one_loop, shift[0]);
+ if (row_txfm == iidentity4_new_ssse3) {
+ const __m128i scale = pair_set_epi16(NewSqrt2, 3 << (NewSqrt2Bits - 1));
+ const __m128i ones = _mm_set1_epi16(1);
+ for (int j = 0; j < 4; ++j) {
+ const __m128i buf_lo = _mm_unpacklo_epi16(buf_cur[j], ones);
+ const __m128i buf_hi = _mm_unpackhi_epi16(buf_cur[j], ones);
+ const __m128i buf_32_lo =
+ _mm_srai_epi32(_mm_madd_epi16(buf_lo, scale), (NewSqrt2Bits + 1));
+ const __m128i buf_32_hi =
+ _mm_srai_epi32(_mm_madd_epi16(buf_hi, scale), (NewSqrt2Bits + 1));
+ buf_cur[j] = _mm_packs_epi32(buf_32_lo, buf_32_hi);
+ }
+ } else {
+ row_txfm(buf_cur, buf_cur, cos_bit_row);
+ round_shift_16bit_ssse3(buf_cur, row_one_loop, shift[0]);
+ }
if (lr_flip) {
__m128i temp[8];
flip_buf_sse2(buf_cur, temp, txfm_size_col);
@@ -2867,8 +2881,22 @@
txfm_size_row);
transpose_16bit_8x4(buf_cur, buf_cur);
}
- row_txfm(buf, buf, cos_bit_row);
- round_shift_16bit_ssse3(buf, txfm_size_col, shift[0]);
+ if (row_txfm == iidentity16_new_ssse3) {
+ const __m128i scale = pair_set_epi16(2 * NewSqrt2, 3 << (NewSqrt2Bits - 1));
+ const __m128i ones = _mm_set1_epi16(1);
+ for (int j = 0; j < 16; ++j) {
+ const __m128i buf_lo = _mm_unpacklo_epi16(buf[j], ones);
+ const __m128i buf_hi = _mm_unpackhi_epi16(buf[j], ones);
+ const __m128i buf_32_lo =
+ _mm_srai_epi32(_mm_madd_epi16(buf_lo, scale), (NewSqrt2Bits + 1));
+ const __m128i buf_32_hi =
+ _mm_srai_epi32(_mm_madd_epi16(buf_hi, scale), (NewSqrt2Bits + 1));
+ buf[j] = _mm_packs_epi32(buf_32_lo, buf_32_hi);
+ }
+ } else {
+ row_txfm(buf, buf, cos_bit_row);
+ round_shift_16bit_ssse3(buf, txfm_size_col, shift[0]);
+ }
if (lr_flip) {
__m128i temp[16];
flip_buf_sse2(buf, temp, 16);
@@ -2916,22 +2944,14 @@
break;
}
}
+
void av1_inv_txfm_add_ssse3(const tran_low_t *dqcoeff, uint8_t *dst, int stride,
const TxfmParam *txfm_param) {
- const TX_TYPE tx_type = txfm_param->tx_type;
if (!txfm_param->lossless) {
- switch (txfm_param->tx_size) {
- case TX_4X16:
- case TX_16X4:
- // TODO(http://crbug.com/aomedia/2350): the ssse3 versions cause test
- // vector mismatches.
- av1_inv_txfm_add_c(dqcoeff, dst, stride, txfm_param);
- break;
- default:
- av1_lowbd_inv_txfm2d_add_ssse3(dqcoeff, dst, stride, tx_type,
- txfm_param->tx_size, txfm_param->eob);
- break;
- }
+ const TX_TYPE tx_type = txfm_param->tx_type;
+ av1_lowbd_inv_txfm2d_add_ssse3(dqcoeff, dst, stride, tx_type,
+ txfm_param->tx_size, txfm_param->eob);
+
} else {
av1_inv_txfm_add_c(dqcoeff, dst, stride, txfm_param);
}
diff --git a/test/av1_inv_txfm2d_test.cc b/test/av1_inv_txfm2d_test.cc
index 5432130..46cf068 100644
--- a/test/av1_inv_txfm2d_test.cc
+++ b/test/av1_inv_txfm2d_test.cc
@@ -38,6 +38,25 @@
namespace {
+static const char *tx_type_name[] = {
+ "DCT_DCT",
+ "ADST_DCT",
+ "DCT_ADST",
+ "ADST_ADST",
+ "FLIPADST_DCT",
+ "DCT_FLIPADST",
+ "FLIPADST_FLIPADST",
+ "ADST_FLIPADST",
+ "FLIPADST_ADST",
+ "IDTX",
+ "V_DCT",
+ "H_DCT",
+ "V_ADST",
+ "H_ADST",
+ "V_FLIPADST",
+ "H_FLIPADST",
+};
+
// AV1InvTxfm2dParam argument list:
// tx_type_, tx_size_, max_error_, max_avg_error_
typedef ::testing::tuple<TX_TYPE, TX_SIZE, int, double> AV1InvTxfm2dParam;
@@ -243,14 +262,15 @@
class AV1LbdInvTxfm2d : public ::testing::TestWithParam<AV1LbdInvTxfm2dParam> {
public:
virtual void SetUp() { target_func_ = GET_PARAM(0); }
- void RunAV1InvTxfm2dTest(TX_TYPE tx_type, TX_SIZE tx_size, int run_times);
+ void RunAV1InvTxfm2dTest(TX_TYPE tx_type, TX_SIZE tx_size, int run_times,
+ int gt_int16 = 0);
private:
LbdInvTxfm2dFunc target_func_;
};
void AV1LbdInvTxfm2d::RunAV1InvTxfm2dTest(TX_TYPE tx_type, TX_SIZE tx_size,
- int run_times) {
+ int run_times, int gt_int16) {
FwdTxfm2dFunc fwd_func_ = libaom_test::fwd_txfm_func_ls[tx_size];
InvTxfm2dFunc ref_func_ = libaom_test::inv_txfm_func_ls[tx_size];
if (fwd_func_ == NULL || ref_func_ == NULL || target_func_ == NULL) {
@@ -275,6 +295,7 @@
const int16_t eobmax = rows_nonezero * cols_nonezero;
ACMRandom rnd(ACMRandom::DeterministicSeed());
int randTimes = run_times == 1 ? (eobmax + 500) : 1;
+
for (int cnt = 0; cnt < randTimes; ++cnt) {
const int16_t max_in = (1 << (bd)) - 1;
for (int r = 0; r < BLK_WIDTH; ++r) {
@@ -291,7 +312,9 @@
for (int i = eob; i < eobmax; i++) {
inv_input[scan[i]] = 0;
}
-
+ if (gt_int16) {
+ inv_input[scan[eob - 1]] = ((int32_t)INT16_MAX * 100 / 141);
+ }
aom_usec_timer timer;
aom_usec_timer_start(&timer);
for (int i = 0; i < run_times; ++i) {
@@ -313,10 +336,13 @@
for (int r = 0; r < rows; ++r) {
for (int c = 0; c < cols; ++c) {
uint8_t ref_value = static_cast<uint8_t>(ref_output[r * stride + c]);
+ if (ref_value != output[r * stride + c]) {
+ printf(" ");
+ }
ASSERT_EQ(ref_value, output[r * stride + c])
<< "[" << r << "," << c << "] " << cnt
<< " tx_size: " << static_cast<int>(tx_size)
- << " tx_type: " << tx_type << " eob " << eob;
+ << " tx_type: " << tx_type_name[tx_type] << " eob " << eob;
}
}
}
@@ -334,8 +360,23 @@
}
}
-TEST_P(AV1LbdInvTxfm2d, DISABLED_Speed) {
+TEST_P(AV1LbdInvTxfm2d, gt_int16) {
+ static const TX_TYPE types[] = {
+ DCT_DCT, ADST_DCT, FLIPADST_DCT, IDTX, V_DCT, H_DCT, H_ADST, H_FLIPADST
+ };
for (int j = 0; j < (int)(TX_SIZES_ALL); ++j) {
+ const TX_SIZE sz = static_cast<TX_SIZE>(j);
+ for (uint8_t i = 0; i < sizeof(types) / sizeof(TX_TYPE); ++i) {
+ const TX_TYPE tp = types[i];
+ if (libaom_test::IsTxSizeTypeValid(sz, tp)) {
+ RunAV1InvTxfm2dTest(tp, sz, 1, 1);
+ }
+ }
+ }
+}
+
+TEST_P(AV1LbdInvTxfm2d, DISABLED_Speed) {
+ for (int j = 1; j < (int)(TX_SIZES_ALL); ++j) {
for (int i = 0; i < (int)TX_TYPES; ++i) {
if (libaom_test::IsTxSizeTypeValid(static_cast<TX_SIZE>(j),
static_cast<TX_TYPE>(i))) {