Add av1_accumulate_tpl_txfm_stats()
Add unit test for its basic mechanics.
Change-Id: Ib30909628484fb92f9d84a2cebaffcd6202c6240
diff --git a/av1/encoder/ethread.c b/av1/encoder/ethread.c
index f7a361e..0911ead 100644
--- a/av1/encoder/ethread.c
+++ b/av1/encoder/ethread.c
@@ -1415,21 +1415,17 @@
}
// Accumulate transform stats after tpl.
-static void tpl_accumulate_txfm_stats(AV1_COMP *cpi, int num_workers) {
- double *total_abs_coeff_sum = cpi->td.tpl_txfm_stats.abs_coeff_sum;
- int *txfm_block_count = &cpi->td.tpl_txfm_stats.txfm_block_count;
- TplParams *tpl_data = &cpi->ppi->tpl_data;
- int coeff_num = tpl_data->tpl_frame[tpl_data->frame_idx].coeff_num;
+static void tpl_accumulate_txfm_stats(ThreadData *main_td,
+ const MultiThreadInfo *mt_info,
+ int num_workers) {
+ TplTxfmStats *accumulated_stats = &main_td->tpl_txfm_stats;
for (int i = num_workers - 1; i >= 0; i--) {
- AVxWorker *const worker = &cpi->mt_info.workers[i];
+ AVxWorker *const worker = &mt_info->workers[i];
EncWorkerData *const thread_data = (EncWorkerData *)worker->data1;
ThreadData *td = thread_data->td;
- if (td != &cpi->td) {
- TplTxfmStats *tpl_txfm_stats = &td->tpl_txfm_stats;
- *txfm_block_count += tpl_txfm_stats->txfm_block_count;
- for (int j = 0; j < coeff_num; j++) {
- total_abs_coeff_sum[j] += tpl_txfm_stats->abs_coeff_sum[j];
- }
+ if (td != main_td) {
+ const TplTxfmStats *tpl_txfm_stats = &td->tpl_txfm_stats;
+ av1_accumulate_tpl_txfm_stats(tpl_txfm_stats, accumulated_stats);
}
}
}
@@ -1458,7 +1454,7 @@
prepare_tpl_workers(cpi, tpl_worker_hook, num_workers);
launch_workers(&cpi->mt_info, num_workers);
sync_enc_workers(&cpi->mt_info, cm, num_workers);
- tpl_accumulate_txfm_stats(cpi, num_workers);
+ tpl_accumulate_txfm_stats(&cpi->td, &cpi->mt_info, num_workers);
}
// Deallocate memory for temporal filter multi-thread synchronization.
diff --git a/av1/encoder/tpl_model.c b/av1/encoder/tpl_model.c
index 64318cd..de43044 100644
--- a/av1/encoder/tpl_model.c
+++ b/av1/encoder/tpl_model.c
@@ -42,6 +42,14 @@
sizeof(tpl_txfm_stats->abs_coeff_sum[0]) * tpl_txfm_stats->coeff_num);
}
+void av1_accumulate_tpl_txfm_stats(const TplTxfmStats *sub_stats,
+ TplTxfmStats *accumulated_stats) {
+ accumulated_stats->txfm_block_count += sub_stats->txfm_block_count;
+ for (int i = 0; i < accumulated_stats->coeff_num; ++i) {
+ accumulated_stats->abs_coeff_sum[i] += sub_stats->abs_coeff_sum[i];
+ }
+}
+
static AOM_INLINE void tpl_stats_record_txfm_block(TplTxfmStats *tpl_txfm_stats,
const tran_low_t *coeff,
int coeff_num) {
diff --git a/av1/encoder/tpl_model.h b/av1/encoder/tpl_model.h
index 5062ec1..b2b5393 100644
--- a/av1/encoder/tpl_model.h
+++ b/av1/encoder/tpl_model.h
@@ -328,6 +328,17 @@
*/
void av1_init_tpl_txfm_stats(TplTxfmStats *tpl_txfm_stats);
+/*
+ *!\brief Accumulate TplTxfmStats
+ *
+ * \param[in] sub_stats a structure for storing sub transform stats
+ * \param[out] accumulated_stats a structure for storing accumulated transform
+ *stats
+ *
+ */
+void av1_accumulate_tpl_txfm_stats(const TplTxfmStats *sub_stats,
+ TplTxfmStats *accumulated_stats);
+
/*!\brief Init data structure storing transform stats
*
*\ingroup tpl_modelling
diff --git a/test/tpl_model_test.cc b/test/tpl_model_test.cc
index d1e6050..111a7a2 100644
--- a/test/tpl_model_test.cc
+++ b/test/tpl_model_test.cc
@@ -182,4 +182,26 @@
}
}
+TEST(TPLModelTest, TxfmStatsAccumulateTest) {
+ TplTxfmStats sub_stats;
+ av1_init_tpl_txfm_stats(&sub_stats);
+ sub_stats.txfm_block_count = 17;
+ for (int i = 0; i < sub_stats.coeff_num; ++i) {
+ sub_stats.abs_coeff_sum[i] = i;
+ }
+
+ TplTxfmStats accumulated_stats;
+ av1_init_tpl_txfm_stats(&accumulated_stats);
+ accumulated_stats.txfm_block_count = 13;
+ for (int i = 0; i < accumulated_stats.coeff_num; ++i) {
+ accumulated_stats.abs_coeff_sum[i] = 5 * i;
+ }
+
+ av1_accumulate_tpl_txfm_stats(&sub_stats, &accumulated_stats);
+ EXPECT_DOUBLE_EQ(accumulated_stats.txfm_block_count, 30);
+ for (int i = 0; i < accumulated_stats.coeff_num; ++i) {
+ EXPECT_DOUBLE_EQ(accumulated_stats.abs_coeff_sum[i], 6 * i);
+ }
+}
+
} // namespace