Fix unit tests for TX64X64.
All tests are now passing.
Change-Id: Ifc1a0f3ff69f5730722a27eed092395595127e8e
diff --git a/av1/common/av1_inv_txfm2d.c b/av1/common/av1_inv_txfm2d.c
index 94f1ad0..6450e26 100644
--- a/av1/common/av1_inv_txfm2d.c
+++ b/av1/common/av1_inv_txfm2d.c
@@ -129,7 +129,7 @@
return cfg;
}
-TXFM_2D_FLIP_CFG av1_get_inv_txfm_32x64_cfg(int tx_type) {
+TXFM_2D_FLIP_CFG av1_get_inv_txfm_32x64_cfg(TX_TYPE tx_type) {
TXFM_2D_FLIP_CFG cfg = { 0, 0, NULL, NULL };
switch (tx_type) {
case DCT_DCT:
@@ -142,7 +142,7 @@
return cfg;
}
-TXFM_2D_FLIP_CFG av1_get_inv_txfm_64x32_cfg(int tx_type) {
+TXFM_2D_FLIP_CFG av1_get_inv_txfm_64x32_cfg(TX_TYPE tx_type) {
TXFM_2D_FLIP_CFG cfg = { 0, 0, NULL, NULL };
switch (tx_type) {
case DCT_DCT:
diff --git a/av1/common/av1_txfm.h b/av1/common/av1_txfm.h
index ba5aa0c..44b5bd0 100644
--- a/av1/common/av1_txfm.h
+++ b/av1/common/av1_txfm.h
@@ -367,12 +367,15 @@
int bd);
TXFM_2D_FLIP_CFG av1_get_fwd_txfm_cfg(TX_TYPE tx_type, TX_SIZE tx_size);
+TXFM_2D_FLIP_CFG av1_get_inv_txfm_cfg(TX_TYPE tx_type, TX_SIZE tx_size);
#if CONFIG_TX64X64
TXFM_2D_FLIP_CFG av1_get_fwd_txfm_64x64_cfg(TX_TYPE tx_type);
TXFM_2D_FLIP_CFG av1_get_fwd_txfm_64x32_cfg(TX_TYPE tx_type);
TXFM_2D_FLIP_CFG av1_get_fwd_txfm_32x64_cfg(TX_TYPE tx_type);
+TXFM_2D_FLIP_CFG av1_get_inv_txfm_64x64_cfg(TX_TYPE tx_type);
+TXFM_2D_FLIP_CFG av1_get_inv_txfm_64x32_cfg(TX_TYPE tx_type);
+TXFM_2D_FLIP_CFG av1_get_inv_txfm_32x64_cfg(TX_TYPE tx_type);
#endif // CONFIG_TX64X64
-TXFM_2D_FLIP_CFG av1_get_inv_txfm_cfg(TX_TYPE tx_type, TX_SIZE tx_size);
#ifdef __cplusplus
}
#endif // __cplusplus
diff --git a/test/av1_fwd_txfm2d_test.cc b/test/av1_fwd_txfm2d_test.cc
index 33dc3a8..7d6bc72 100644
--- a/test/av1_fwd_txfm2d_test.cc
+++ b/test/av1_fwd_txfm2d_test.cc
@@ -183,8 +183,17 @@
// TODO(angiebird): include rect txfm in this test
for (int tx_size = 0; tx_size < TX_SIZES; ++tx_size) {
for (int tx_type = 0; tx_type < TX_TYPES; ++tx_type) {
- TXFM_2D_FLIP_CFG cfg = av1_get_fwd_txfm_cfg(
- static_cast<TX_TYPE>(tx_type), static_cast<TX_SIZE>(tx_size));
+#if CONFIG_TX64X64
+ if (tx_size == TX_64X64 && tx_type != DCT_DCT) continue;
+#endif // CONFIG_TX64X64
+ const TXFM_2D_FLIP_CFG cfg =
+#if CONFIG_TX64X64
+ (tx_size == TX_64X64)
+ ? av1_get_fwd_txfm_64x64_cfg(static_cast<TX_TYPE>(tx_type))
+ :
+#endif // CONFIG_TX64X64
+ av1_get_fwd_txfm_cfg(static_cast<TX_TYPE>(tx_type),
+ static_cast<TX_SIZE>(tx_size));
int8_t stage_range_col[MAX_TXFM_STAGE_NUM];
int8_t stage_range_row[MAX_TXFM_STAGE_NUM];
av1_gen_fwd_stage_range(stage_range_col, stage_range_row, &cfg, bd);
diff --git a/test/av1_inv_txfm1d_test.cc b/test/av1_inv_txfm1d_test.cc
index b44c041..54d19f9 100644
--- a/test/av1_inv_txfm1d_test.cc
+++ b/test/av1_inv_txfm1d_test.cc
@@ -21,7 +21,12 @@
namespace {
const int txfm_type_num = 2;
-const int txfm_size_ls[5] = { 4, 8, 16, 32, 64 };
+const int txfm_size_ls[] = {
+ 4, 8, 16, 32,
+#if CONFIG_TX64X64
+ 64,
+#endif // CONFIG_TX64X64
+};
const TxfmFunc fwd_txfm_func_ls[][2] = {
{ av1_fdct4_new, av1_fadst4_new },
@@ -54,8 +59,11 @@
double output[64];
libaom_test::reference_idct_1d(input, output, size);
- for (int i = 0; i < size; ++i)
+ for (int i = 0; i < size; ++i) {
+ ASSERT_GE(output[i], INT32_MIN);
+ ASSERT_LE(output[i], INT32_MAX);
out[i] = static_cast<int32_t>(round(output[i]));
+ }
}
void random_matrix(int32_t *dst, int len, ACMRandom *rnd) {
@@ -73,24 +81,42 @@
TEST(av1_inv_txfm1d, InvAccuracyCheck) {
ACMRandom rnd(ACMRandom::DeterministicSeed());
const int count_test_block = 20000;
- const int max_error[] = { 6, 10, 19, 28 };
+ const int max_error[] = {
+ 6,
+ 10,
+ 19,
+ 31,
+#if CONFIG_TX64X64
+ 40,
+#endif // CONFIG_TX64X64
+ };
+ ASSERT_EQ(NELEMENTS(max_error), TX_SIZES);
+ ASSERT_EQ(NELEMENTS(inv_txfm_func_ls), TX_SIZES);
for (int k = 0; k < count_test_block; ++k) {
// choose a random transform to test
- const int txfm_type = rnd.Rand8() % NELEMENTS(inv_txfm_func_ls);
- const int txfm_size = txfm_size_ls[txfm_type];
- const TxfmFunc txfm_func = inv_txfm_func_ls[txfm_type][0];
+ const TX_SIZE tx_size = static_cast<TX_SIZE>(rnd.Rand8() % TX_SIZES);
+ const int tx_size_pix = txfm_size_ls[tx_size];
+ const TxfmFunc inv_txfm_func = inv_txfm_func_ls[tx_size][0];
int32_t input[64];
- random_matrix(input, txfm_size, &rnd);
+ random_matrix(input, tx_size_pix, &rnd);
+
+#if CONFIG_TX64X64
+ // 64x64 transform assumes last 32 values are zero.
+ memset(input + 32, 0, 32 * sizeof(input[0]));
+#endif // CONFIG_TX64X64
int32_t ref_output[64];
- reference_idct_1d_int(input, ref_output, txfm_size);
+ reference_idct_1d_int(input, ref_output, tx_size_pix);
int32_t output[64];
- txfm_func(input, output, cos_bit, range_bit);
+ inv_txfm_func(input, output, cos_bit, range_bit);
- for (int i = 0; i < txfm_size; ++i) {
- EXPECT_LE(abs(output[i] - ref_output[i]), max_error[txfm_type]);
+ for (int i = 0; i < tx_size_pix; ++i) {
+ EXPECT_LE(abs(output[i] - ref_output[i]), max_error[tx_size])
+ << "tx_size = " << tx_size << ", i = " << i
+ << ", output[i] = " << output[i]
+ << ", ref_output[i] = " << ref_output[i];
}
}
}
diff --git a/test/av1_inv_txfm2d_test.cc b/test/av1_inv_txfm2d_test.cc
index 4aa943f..8baaf9f 100644
--- a/test/av1_inv_txfm2d_test.cc
+++ b/test/av1_inv_txfm2d_test.cc
@@ -190,8 +190,17 @@
// TODO(angiebird): include rect txfm in this test
for (int tx_size = 0; tx_size < TX_SIZES; ++tx_size) {
for (int tx_type = 0; tx_type < TX_TYPES; ++tx_type) {
- TXFM_2D_FLIP_CFG cfg = av1_get_inv_txfm_cfg(
- static_cast<TX_TYPE>(tx_type), static_cast<TX_SIZE>(tx_size));
+#if CONFIG_TX64X64
+ if (tx_size == TX_64X64 && tx_type != DCT_DCT) continue;
+#endif // CONFIG_TX64X64
+ const TXFM_2D_FLIP_CFG cfg =
+#if CONFIG_TX64X64
+ (tx_size == TX_64X64)
+ ? av1_get_inv_txfm_64x64_cfg(static_cast<TX_TYPE>(tx_type))
+ :
+#endif // CONFIG_TX64X64
+ av1_get_inv_txfm_cfg(static_cast<TX_TYPE>(tx_type),
+ static_cast<TX_SIZE>(tx_size));
int8_t stage_range_col[MAX_TXFM_STAGE_NUM];
int8_t stage_range_row[MAX_TXFM_STAGE_NUM];
av1_gen_inv_stage_range(stage_range_col, stage_range_row, &cfg,
diff --git a/test/av1_txfm_test.cc b/test/av1_txfm_test.cc
index 235872c..800f216 100644
--- a/test/av1_txfm_test.cc
+++ b/test/av1_txfm_test.cc
@@ -176,8 +176,15 @@
template void fliplrud<double>(double *dest, int stride, int length);
int bd_arr[BD_NUM] = { 8, 10, 12 };
+
+#if CONFIG_TX64X64
+int8_t low_range_arr[BD_NUM] = { 18, 32, 32 };
+// TODO(urvang): Try reducing cos bit by 1 for TX64X64 to get it back to 32.
+int8_t high_range_arr[BD_NUM] = { 33, 33, 33 };
+#else
int8_t low_range_arr[BD_NUM] = { 16, 32, 32 };
int8_t high_range_arr[BD_NUM] = { 32, 32, 32 };
+#endif // CONFIG_TX64X64
void txfm_stage_range_check(const int8_t *stage_range, int stage_num,
const int8_t *cos_bit, int low_range,