AV1FwdTxfm2d.RunFwdAccuracyCheck: Add rect txfms.
- Added all 1:2 and 1:4 transforms through a dynamically generated list.
- Reworked the code to support testing these rectangular transforms.
BUG=aomedia:1114
Change-Id: I7e83b48f02a530716d5e30103780c5c4f450cbbd
diff --git a/test/av1_fwd_txfm2d_test.cc b/test/av1_fwd_txfm2d_test.cc
index 5ef3917..cf89df6 100644
--- a/test/av1_fwd_txfm2d_test.cc
+++ b/test/av1_fwd_txfm2d_test.cc
@@ -12,6 +12,7 @@
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
+#include <vector>
#include "test/acm_random.h"
#include "test/util.h"
@@ -26,6 +27,8 @@
using libaom_test::Fwd_Txfm2d_Func;
using libaom_test::TYPE_TXFM;
+using std::vector;
+
namespace {
#if CONFIG_HIGHBITDEPTH
// tx_type_, tx_size_, max_error_, max_avg_error_
@@ -41,19 +44,30 @@
count_ = 500;
TXFM_2D_FLIP_CFG fwd_txfm_flip_cfg;
av1_get_fwd_txfm_cfg(tx_type_, tx_size_, &fwd_txfm_flip_cfg);
- // TODO(sarahparker) this test will need to be updated when these
- // functions are extended to support rectangular transforms
- int amplify_bit = fwd_txfm_flip_cfg.row_cfg->shift[0] +
- fwd_txfm_flip_cfg.row_cfg->shift[1] +
- fwd_txfm_flip_cfg.row_cfg->shift[2];
+ tx_width_ = fwd_txfm_flip_cfg.row_cfg->txfm_size;
+ tx_height_ = fwd_txfm_flip_cfg.col_cfg->txfm_size;
+ const int8_t *shift = (tx_width_ > tx_height_)
+ ? fwd_txfm_flip_cfg.row_cfg->shift
+ : fwd_txfm_flip_cfg.col_cfg->shift;
+ const int amplify_bit = shift[0] + shift[1] + shift[2];
ud_flip_ = fwd_txfm_flip_cfg.ud_flip;
lr_flip_ = fwd_txfm_flip_cfg.lr_flip;
amplify_factor_ =
amplify_bit >= 0 ? (1 << amplify_bit) : (1.0 / (1 << -amplify_bit));
+ // For rectangular transforms, we need to multiply by an extra factor.
+ const int rect_type = get_rect_tx_log_ratio(tx_width_, tx_height_);
+ if (abs(rect_type) == 1) {
+ amplify_factor_ *= pow(2, 0.5);
+ } else if (abs(rect_type) == 2) {
+ const int tx_max_dim = AOMMAX(tx_width_, tx_height_);
+ const int rect_type2_shift =
+ tx_max_dim == 64 ? 3 : tx_max_dim == 32 ? 2 : 1;
+ amplify_factor_ *= pow(2, rect_type2_shift);
+ }
+
fwd_txfm_ = libaom_test::fwd_txfm_func_ls[tx_size_];
- txfm1d_size_ = libaom_test::get_txfm1d_size(tx_size_);
- txfm2d_size_ = txfm1d_size_ * txfm1d_size_;
+ txfm2d_size_ = tx_width_ * tx_height_;
get_txfm1d_type(tx_type_, &type0_, &type1_);
input_ = reinterpret_cast<int16_t *>(
aom_memalign(16, sizeof(input_[0]) * txfm2d_size_));
@@ -76,33 +90,40 @@
ref_output_[ni] = 0;
}
- fwd_txfm_(input_, output_, txfm1d_size_, tx_type_, bd);
+ fwd_txfm_(input_, output_, tx_width_, tx_type_, bd);
- if (lr_flip_ && ud_flip_)
- libaom_test::fliplrud(ref_input_, txfm1d_size_, txfm1d_size_);
- else if (lr_flip_)
- libaom_test::fliplr(ref_input_, txfm1d_size_, txfm1d_size_);
- else if (ud_flip_)
- libaom_test::flipud(ref_input_, txfm1d_size_, txfm1d_size_);
+ if (lr_flip_ && ud_flip_) {
+ libaom_test::fliplrud(ref_input_, tx_width_, tx_height_, tx_width_);
+ } else if (lr_flip_) {
+ libaom_test::fliplr(ref_input_, tx_width_, tx_height_, tx_width_);
+ } else if (ud_flip_) {
+ libaom_test::flipud(ref_input_, tx_width_, tx_height_, tx_width_);
+ }
- reference_hybrid_2d(ref_input_, ref_output_, txfm1d_size_, type0_,
- type1_);
+ reference_hybrid_2d(ref_input_, ref_output_, tx_width_, tx_height_,
+ type0_, type1_);
+ double actual_max_error = 0;
for (int ni = 0; ni < txfm2d_size_; ++ni) {
ref_output_[ni] = round(ref_output_[ni] * amplify_factor_);
- EXPECT_GE(max_error_,
- fabs(output_[ni] - ref_output_[ni]) / amplify_factor_);
+ const double this_error =
+ fabs(output_[ni] - ref_output_[ni]) / amplify_factor_;
+ actual_max_error = AOMMAX(actual_max_error, this_error);
}
+ EXPECT_GE(max_error_, actual_max_error)
+ << "tx_size = " << tx_size_ << ", tx_type = " << tx_type_;
+ if (actual_max_error > max_error_) { // exit early.
+ break;
+ }
+
avg_abs_error += compute_avg_abs_error<int32_t, double>(
output_, ref_output_, txfm2d_size_);
}
avg_abs_error /= amplify_factor_;
avg_abs_error /= count_;
- // max_abs_avg_error comes from upper bound of avg_abs_error
- // printf("type0: %d type1: %d txfm_size: %d accuracy_avg_abs_error:
- // %f\n", type0_, type1_, txfm1d_size_, avg_abs_error);
- EXPECT_GE(max_avg_error_, avg_abs_error);
+ EXPECT_GE(max_avg_error_, avg_abs_error)
+ << "tx_size = " << tx_size_ << ", tx_type = " << tx_type_;
}
virtual void TearDown() {
@@ -119,7 +140,8 @@
double amplify_factor_;
TX_TYPE tx_type_;
TX_SIZE tx_size_;
- int txfm1d_size_;
+ int tx_width_;
+ int tx_height_;
int txfm2d_size_;
Fwd_Txfm2d_Func fwd_txfm_;
TYPE_TXFM type0_;
@@ -132,51 +154,44 @@
int lr_flip_; // flip left to right
};
-TEST_P(AV1FwdTxfm2d, RunFwdAccuracyCheck) { RunFwdAccuracyCheck(); }
-const AV1FwdTxfm2dParam av1_fwd_txfm2d_param_c[] = {
- AV1FwdTxfm2dParam(FLIPADST_DCT, TX_4X4, 2, 0.2),
- AV1FwdTxfm2dParam(DCT_FLIPADST, TX_4X4, 2, 0.2),
- AV1FwdTxfm2dParam(FLIPADST_FLIPADST, TX_4X4, 2, 0.2),
- AV1FwdTxfm2dParam(ADST_FLIPADST, TX_4X4, 2, 0.2),
- AV1FwdTxfm2dParam(FLIPADST_ADST, TX_4X4, 2, 0.2),
- AV1FwdTxfm2dParam(FLIPADST_DCT, TX_8X8, 5, 0.6),
- AV1FwdTxfm2dParam(DCT_FLIPADST, TX_8X8, 5, 0.6),
- AV1FwdTxfm2dParam(FLIPADST_FLIPADST, TX_8X8, 5, 0.6),
- AV1FwdTxfm2dParam(ADST_FLIPADST, TX_8X8, 5, 0.6),
- AV1FwdTxfm2dParam(FLIPADST_ADST, TX_8X8, 5, 0.6),
- AV1FwdTxfm2dParam(FLIPADST_DCT, TX_16X16, 11, 1.5),
- AV1FwdTxfm2dParam(DCT_FLIPADST, TX_16X16, 11, 1.5),
- AV1FwdTxfm2dParam(FLIPADST_FLIPADST, TX_16X16, 11, 1.5),
- AV1FwdTxfm2dParam(ADST_FLIPADST, TX_16X16, 11, 1.5),
- AV1FwdTxfm2dParam(FLIPADST_ADST, TX_16X16, 11, 1.5),
- AV1FwdTxfm2dParam(FLIPADST_DCT, TX_32X32, 70, 7),
- AV1FwdTxfm2dParam(DCT_FLIPADST, TX_32X32, 70, 7),
- AV1FwdTxfm2dParam(FLIPADST_FLIPADST, TX_32X32, 70, 7),
- AV1FwdTxfm2dParam(ADST_FLIPADST, TX_32X32, 70, 7),
- AV1FwdTxfm2dParam(FLIPADST_ADST, TX_32X32, 70, 7),
- AV1FwdTxfm2dParam(DCT_DCT, TX_4X4, 2, 0.2),
- AV1FwdTxfm2dParam(ADST_DCT, TX_4X4, 2, 0.2),
- AV1FwdTxfm2dParam(DCT_ADST, TX_4X4, 2, 0.2),
- AV1FwdTxfm2dParam(ADST_ADST, TX_4X4, 2, 0.2),
- AV1FwdTxfm2dParam(DCT_DCT, TX_8X8, 5, 0.6),
- AV1FwdTxfm2dParam(ADST_DCT, TX_8X8, 5, 0.6),
- AV1FwdTxfm2dParam(DCT_ADST, TX_8X8, 5, 0.6),
- AV1FwdTxfm2dParam(ADST_ADST, TX_8X8, 5, 0.6),
- AV1FwdTxfm2dParam(DCT_DCT, TX_16X16, 11, 1.5),
- AV1FwdTxfm2dParam(ADST_DCT, TX_16X16, 11, 1.5),
- AV1FwdTxfm2dParam(DCT_ADST, TX_16X16, 11, 1.5),
- AV1FwdTxfm2dParam(ADST_ADST, TX_16X16, 11, 1.5),
- AV1FwdTxfm2dParam(DCT_DCT, TX_32X32, 70, 7),
- AV1FwdTxfm2dParam(ADST_DCT, TX_32X32, 70, 7),
- AV1FwdTxfm2dParam(DCT_ADST, TX_32X32, 70, 7),
- AV1FwdTxfm2dParam(ADST_ADST, TX_32X32, 70, 7),
+vector<AV1FwdTxfm2dParam> GetTxfm2dParamList() {
+ vector<AV1FwdTxfm2dParam> param_list;
+ for (int t = 0; t <= FLIPADST_ADST; ++t) {
+ const TX_TYPE tx_type = static_cast<TX_TYPE>(t);
+ param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_4X4, 2, 0.2));
+ param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_8X8, 5, 0.6));
+ param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_16X16, 11, 1.5));
+ param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_32X32, 70, 7));
+
+ param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_4X8, 2.5, 0.4));
+ param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_8X4, 2.5, 0.4));
+ param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_8X16, 6, 1));
+ param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_16X8, 6, 1));
+ param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_16X32, 30, 7));
+ param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_32X16, 30, 7));
+
+ param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_4X16, 5, 0.6));
+ param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_16X4, 5, 0.6));
+ param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_8X32, 11, 1.6));
+ param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_32X8, 11, 1.6));
+
#if CONFIG_TX64X64
- AV1FwdTxfm2dParam(DCT_DCT, TX_64X64, 70, 7),
+ if (tx_type == DCT_DCT) { // Other types not supported by these tx sizes.
+ param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_64X64, 70, 7));
+ param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_32X64, 136, 7));
+ param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_64X32, 136, 7));
+ param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_16X64, 16, 1.6));
+ param_list.push_back(AV1FwdTxfm2dParam(tx_type, TX_64X16, 16, 1.6));
+ }
#endif // CONFIG_TX64X64
-};
+ }
+ return param_list;
+}
INSTANTIATE_TEST_CASE_P(C, AV1FwdTxfm2d,
- ::testing::ValuesIn(av1_fwd_txfm2d_param_c));
+ ::testing::ValuesIn(GetTxfm2dParamList()));
+
+TEST_P(AV1FwdTxfm2d, RunFwdAccuracyCheck) { RunFwdAccuracyCheck(); }
TEST(AV1FwdTxfm2d, CfgTest) {
for (int bd_idx = 0; bd_idx < BD_NUM; ++bd_idx) {
diff --git a/test/av1_txfm_test.cc b/test/av1_txfm_test.cc
index b4c6bfa..6872443 100644
--- a/test/av1_txfm_test.cc
+++ b/test/av1_txfm_test.cc
@@ -106,35 +106,40 @@
reference_adst_1d(in, out, size);
}
-void reference_hybrid_2d(double *in, double *out, int size, int type0,
- int type1) {
- double *tempOut = new double[size * size];
+void reference_hybrid_2d(double *in, double *out, int tx_width, int tx_height,
+ int type0, int type1) {
+ double *const temp_in = new double[AOMMAX(tx_width, tx_height)];
+ double *const temp_out = new double[AOMMAX(tx_width, tx_height)];
+ double *const out_interm = new double[tx_width * tx_height];
+ const int stride = tx_width;
- for (int r = 0; r < size; r++) {
- // out ->tempOut
- for (int c = 0; c < size; c++) {
- tempOut[r * size + c] = in[c * size + r];
+ // Transform columns.
+ for (int c = 0; c < tx_width; ++c) {
+ for (int r = 0; r < tx_height; ++r) {
+ temp_in[r] = in[r * stride + c];
+ }
+ reference_hybrid_1d(temp_in, temp_out, tx_height, type0);
+ for (int r = 0; r < tx_height; ++r) {
+ out_interm[r * stride + c] = temp_out[r];
}
}
- // dct each row: in -> out
- for (int r = 0; r < size; r++) {
- reference_hybrid_1d(tempOut + r * size, out + r * size, size, type0);
+ // Transform rows.
+ for (int r = 0; r < tx_height; ++r) {
+ reference_hybrid_1d(out_interm + r * stride, out + r * stride, tx_width,
+ type1);
}
- for (int r = 0; r < size; r++) {
- // out ->tempOut
- for (int c = 0; c < size; c++) {
- tempOut[r * size + c] = out[c * size + r];
- }
- }
-
- for (int r = 0; r < size; r++) {
- reference_hybrid_1d(tempOut + r * size, out + r * size, size, type1);
- }
+ delete[] temp_in;
+ delete[] temp_out;
+ delete[] out_interm;
#if CONFIG_TX64X64
- if (size == 64) { // tx_size == TX64X64
+ // These transforms use an approximate 2D DCT transform, by only keeping the
+ // top-left quarter of the coefficients, and repacking them in the first
+ // quarter indices.
+ // TODO(urvang): Refactor this code.
+ if (tx_width == 64 && tx_height == 64) { // tx_size == TX_64X64
// Zero out top-right 32x32 area.
for (int row = 0; row < 32; ++row) {
memset(out + row * 64 + 32, 0, 32 * sizeof(*out));
@@ -145,50 +150,72 @@
for (int row = 1; row < 32; ++row) {
memcpy(out + row * 32, out + row * 64, 32 * sizeof(*out));
}
+ } else if (tx_width == 32 && tx_height == 64) { // tx_size == TX_32X64
+ // Zero out the bottom 32x32 area.
+ memset(out + 32 * 32, 0, 32 * 32 * sizeof(*out));
+ // Note: no repacking needed here.
+ } else if (tx_width == 64 && tx_height == 32) { // tx_size == TX_64X32
+ // Zero out right 32x32 area.
+ for (int row = 0; row < 32; ++row) {
+ memset(out + row * 64 + 32, 0, 32 * sizeof(*out));
+ }
+ // Re-pack non-zero coeffs in the first 32x32 indices.
+ for (int row = 1; row < 32; ++row) {
+ memcpy(out + row * 32, out + row * 64, 32 * sizeof(*out));
+ }
+ } else if (tx_width == 16 && tx_height == 64) { // tx_size == TX_16X64
+ // Zero out the bottom 16x32 area.
+ memset(out + 16 * 32, 0, 16 * 32 * sizeof(*out));
+ // Note: no repacking needed here.
+ } else if (tx_width == 64 && tx_height == 16) { // tx_size == TX_64X16
+ // Zero out right 32x16 area.
+ for (int row = 0; row < 16; ++row) {
+ memset(out + row * 64 + 32, 0, 32 * sizeof(*out));
+ }
+ // Re-pack non-zero coeffs in the first 32x16 indices.
+ for (int row = 1; row < 16; ++row) {
+ memcpy(out + row * 32, out + row * 64, 32 * sizeof(*out));
+ }
}
#endif // CONFIG_TX_64X64
- delete[] tempOut;
}
template <typename Type>
-void fliplr(Type *dest, int stride, int length) {
- int i, j;
- for (i = 0; i < length; ++i) {
- for (j = 0; j < length / 2; ++j) {
- const Type tmp = dest[i * stride + j];
- dest[i * stride + j] = dest[i * stride + length - 1 - j];
- dest[i * stride + length - 1 - j] = tmp;
+void fliplr(Type *dest, int width, int height, int stride) {
+ for (int r = 0; r < height; ++r) {
+ for (int c = 0; c < width / 2; ++c) {
+ const Type tmp = dest[r * stride + c];
+ dest[r * stride + c] = dest[r * stride + width - 1 - c];
+ dest[r * stride + width - 1 - c] = tmp;
}
}
}
template <typename Type>
-void flipud(Type *dest, int stride, int length) {
- int i, j;
- for (j = 0; j < length; ++j) {
- for (i = 0; i < length / 2; ++i) {
- const Type tmp = dest[i * stride + j];
- dest[i * stride + j] = dest[(length - 1 - i) * stride + j];
- dest[(length - 1 - i) * stride + j] = tmp;
+void flipud(Type *dest, int width, int height, int stride) {
+ for (int c = 0; c < width; ++c) {
+ for (int r = 0; r < height / 2; ++r) {
+ const Type tmp = dest[r * stride + c];
+ dest[r * stride + c] = dest[(height - 1 - r) * stride + c];
+ dest[(height - 1 - r) * stride + c] = tmp;
}
}
}
template <typename Type>
-void fliplrud(Type *dest, int stride, int length) {
- int i, j;
- for (i = 0; i < length / 2; ++i) {
- for (j = 0; j < length; ++j) {
- const Type tmp = dest[i * stride + j];
- dest[i * stride + j] = dest[(length - 1 - i) * stride + length - 1 - j];
- dest[(length - 1 - i) * stride + length - 1 - j] = tmp;
+void fliplrud(Type *dest, int width, int height, int stride) {
+ for (int r = 0; r < height / 2; ++r) {
+ for (int c = 0; c < width; ++c) {
+ const Type tmp = dest[r * stride + c];
+ dest[r * stride + c] = dest[(height - 1 - r) * stride + width - 1 - c];
+ dest[(height - 1 - r) * stride + width - 1 - c] = tmp;
}
}
}
-template void fliplr<double>(double *dest, int stride, int length);
-template void flipud<double>(double *dest, int stride, int length);
-template void fliplrud<double>(double *dest, int stride, int length);
+template void fliplr<double>(double *dest, int width, int height, int stride);
+template void flipud<double>(double *dest, int width, int height, int stride);
+template void fliplrud<double>(double *dest, int width, int height, int stride);
int bd_arr[BD_NUM] = { 8, 10, 12 };
diff --git a/test/av1_txfm_test.h b/test/av1_txfm_test.h
index eb5bdf9..d2d40e7 100644
--- a/test/av1_txfm_test.h
+++ b/test/av1_txfm_test.h
@@ -46,8 +46,8 @@
void reference_hybrid_1d(double *in, double *out, int size, int type);
-void reference_hybrid_2d(double *in, double *out, int size, int type0,
- int type1);
+void reference_hybrid_2d(double *in, double *out, int tx_width, int tx_height,
+ int type0, int type1);
template <typename Type1, typename Type2>
static double compute_avg_abs_error(const Type1 *a, const Type2 *b,
const int size) {
@@ -60,13 +60,13 @@
}
template <typename Type>
-void fliplr(Type *dest, int stride, int length);
+void fliplr(Type *dest, int width, int height, int stride);
template <typename Type>
-void flipud(Type *dest, int stride, int length);
+void flipud(Type *dest, int width, int height, int stride);
template <typename Type>
-void fliplrud(Type *dest, int stride, int length);
+void fliplrud(Type *dest, int width, int height, int stride);
typedef void (*TxfmFunc)(const int32_t *in, int32_t *out, const int8_t *cos_bit,
const int8_t *range_bit);
@@ -81,27 +81,21 @@
#if CONFIG_AV1_ENCODER
static const Fwd_Txfm2d_Func fwd_txfm_func_ls[TX_SIZES_ALL] = {
- av1_fwd_txfm2d_4x4_c,
- av1_fwd_txfm2d_8x8_c,
- av1_fwd_txfm2d_16x16_c,
+ av1_fwd_txfm2d_4x4_c, av1_fwd_txfm2d_8x8_c, av1_fwd_txfm2d_16x16_c,
av1_fwd_txfm2d_32x32_c,
#if CONFIG_TX64X64
av1_fwd_txfm2d_64x64_c,
#endif // CONFIG_TX64X64
- av1_fwd_txfm2d_4x8_c,
- av1_fwd_txfm2d_8x4_c,
- av1_fwd_txfm2d_8x16_c,
- av1_fwd_txfm2d_16x8_c,
- av1_fwd_txfm2d_16x32_c,
- av1_fwd_txfm2d_32x16_c,
+ av1_fwd_txfm2d_4x8_c, av1_fwd_txfm2d_8x4_c, av1_fwd_txfm2d_8x16_c,
+ av1_fwd_txfm2d_16x8_c, av1_fwd_txfm2d_16x32_c, av1_fwd_txfm2d_32x16_c,
#if CONFIG_TX64X64
- av1_fwd_txfm2d_32x64_c,
- av1_fwd_txfm2d_64x32_c,
+ av1_fwd_txfm2d_32x64_c, av1_fwd_txfm2d_64x32_c,
#endif // CONFIG_TX64X64
- NULL,
- NULL,
- NULL,
- NULL,
+ av1_fwd_txfm2d_4x16_c, av1_fwd_txfm2d_16x4_c, av1_fwd_txfm2d_8x32_c,
+ av1_fwd_txfm2d_32x8_c,
+#if CONFIG_TX64X64
+ av1_fwd_txfm2d_16x64_c, av1_fwd_txfm2d_64x16_c,
+#endif // CONFIG_TX64X64
};
#endif