Add unit test to av1_record_tpl_txfm_block()
Change-Id: I6997c787030cbc41a84b90581df5458196bad5ea
diff --git a/av1/encoder/tpl_model.c b/av1/encoder/tpl_model.c
index de43044..baa2e3f 100644
--- a/av1/encoder/tpl_model.c
+++ b/av1/encoder/tpl_model.c
@@ -50,13 +50,12 @@
}
}
-static AOM_INLINE void tpl_stats_record_txfm_block(TplTxfmStats *tpl_txfm_stats,
- const tran_low_t *coeff,
- int coeff_num) {
+void av1_record_tpl_txfm_block(TplTxfmStats *tpl_txfm_stats,
+ const tran_low_t *coeff) {
// For transform larger than 16x16, the scale of coeff need to be adjusted.
// It's not LOSSLESS_Q_STEP.
- assert(coeff_num <= 256);
- for (int i = 0; i < coeff_num; ++i) {
+ assert(tpl_txfm_stats->coeff_num <= 256);
+ for (int i = 0; i < tpl_txfm_stats->coeff_num; ++i) {
tpl_txfm_stats->abs_coeff_sum[i] += abs(coeff[i]) / (double)LOSSLESS_Q_STEP;
}
++tpl_txfm_stats->txfm_block_count;
@@ -799,7 +798,7 @@
rec_stride_pool, tx_size, best_mode, mi_row, mi_col,
use_y_only_rate_distortion);
- tpl_stats_record_txfm_block(tpl_txfm_stats, coeff, tpl_frame->coeff_num);
+ av1_record_tpl_txfm_block(tpl_txfm_stats, coeff);
tpl_stats->recrf_dist = recon_error << (TPL_DEP_COST_SCALE_LOG2);
tpl_stats->recrf_rate = rate_cost << TPL_DEP_COST_SCALE_LOG2;
diff --git a/av1/encoder/tpl_model.h b/av1/encoder/tpl_model.h
index b2b5393..e6489be 100644
--- a/av1/encoder/tpl_model.h
+++ b/av1/encoder/tpl_model.h
@@ -339,6 +339,17 @@
void av1_accumulate_tpl_txfm_stats(const TplTxfmStats *sub_stats,
TplTxfmStats *accumulated_stats);
+/*
+ *!\brief Record a transform block into TplTxfmStats
+ *
+ * \param[in] tpl_txfm_stats A structure for storing transform stats
+ * \param[out] coeff An array of transform coefficients. Its size
+ * should equal to tpl_txfm_stats.coeff_num.
+ *
+ */
+void av1_record_tpl_txfm_block(TplTxfmStats *tpl_txfm_stats,
+ const tran_low_t *coeff);
+
/*!\brief Init data structure storing transform stats
*
*\ingroup tpl_modelling
diff --git a/test/tpl_model_test.cc b/test/tpl_model_test.cc
index 111a7a2..83845ee 100644
--- a/test/tpl_model_test.cc
+++ b/test/tpl_model_test.cc
@@ -204,4 +204,29 @@
}
}
+TEST(TPLModelTest, TxfmStatsRecordTest) {
+ TplTxfmStats stats1;
+ TplTxfmStats stats2;
+ av1_init_tpl_txfm_stats(&stats1);
+ av1_init_tpl_txfm_stats(&stats2);
+
+ tran_low_t coeff[256];
+ for (int i = 0; i < 256; ++i) {
+ coeff[i] = i;
+ }
+ av1_record_tpl_txfm_block(&stats1, coeff);
+ EXPECT_EQ(stats1.txfm_block_count, 1);
+
+ // we record the same transform block twice for testing purpose
+ av1_record_tpl_txfm_block(&stats2, coeff);
+ av1_record_tpl_txfm_block(&stats2, coeff);
+ EXPECT_EQ(stats2.txfm_block_count, 2);
+
+ EXPECT_EQ(stats1.coeff_num, 256);
+ EXPECT_EQ(stats2.coeff_num, 256);
+ for (int i = 0; i < 256; ++i) {
+ EXPECT_DOUBLE_EQ(stats2.abs_coeff_sum[i], 2 * stats1.abs_coeff_sum[i]);
+ }
+}
+
} // namespace