ExtPart: Properly handle tpl stats collection
Use a flag to indicate if tpl stats are available.
Relax the psnr diff.
Change-Id: I2c2986175457fc1da08c12ca07757ee3563c3428
diff --git a/aom/aom_external_partition.h b/aom/aom_external_partition.h
index 45e7a4d..ad2214c 100644
--- a/aom/aom_external_partition.h
+++ b/aom/aom_external_partition.h
@@ -176,6 +176,7 @@
* 128x128 / (16x16) = 64. Change it if the tpl process changes.
*/
typedef struct aom_sb_tpl_features {
+ int available; /**< If tpl stats are available */
int tpl_unit_length; /**< The block length of tpl process */
int num_units; /**< The number of units inside the current superblock */
int64_t intra_cost[64]; /**< The intra cost of each unit */
diff --git a/av1/encoder/partition_search.c b/av1/encoder/partition_search.c
index 22b0ff2..d2d373d 100644
--- a/av1/encoder/partition_search.c
+++ b/av1/encoder/partition_search.c
@@ -102,6 +102,12 @@
TplParams *const tpl_data = &cpi->ppi->tpl_data;
TplDepFrame *tpl_frame = &tpl_data->tpl_frame[cpi->gf_frame_index];
TplDepStats *tpl_stats = tpl_frame->tpl_stats_ptr;
+ // If tpl stats is not established, early return
+ if (!tpl_data->ready || gf_group->max_layer_depth_allowed == 0) {
+ features->sb_features.tpl_features.available = 0;
+ return;
+ }
+
const int tpl_stride = tpl_frame->stride;
const int step = 1 << tpl_data->tpl_stats_block_mis_log2;
const int mi_width =
@@ -159,6 +165,7 @@
}
fclose(pfile);
} else {
+ features->sb_features.tpl_features.available = 1;
features->sb_features.tpl_features.tpl_unit_length = tpl_data->tpl_bsize_1d;
features->sb_features.tpl_features.num_units = num_blocks;
int count = 0;
@@ -4096,11 +4103,6 @@
static void prepare_sb_features_before_search(
AV1_COMP *const cpi, ThreadData *td, int mi_row, int mi_col,
const BLOCK_SIZE bsize, aom_partition_features_t *features) {
- // TODO(chengchen): properly handle feature collection for unit tests.
- // Also take care of cases where tpl stats are not available.
- // Now in unit test, this function causes failures, due to tpl stats
- // not ready.
- return;
av1_collect_motion_search_features_sb(cpi, td, mi_row, mi_col, bsize,
features);
collect_tpl_stats_sb(cpi, bsize, mi_row, mi_col, features);
diff --git a/test/av1_external_partition_test.cc b/test/av1_external_partition_test.cc
index bea0445..9cbf491 100644
--- a/test/av1_external_partition_test.cc
+++ b/test/av1_external_partition_test.cc
@@ -389,8 +389,6 @@
ASSERT_NO_FATAL_FAILURE(RunLoop(&video));
const double psnr2 = GetAveragePsnr();
- printf("psnr %.5f\n", psnr);
-
EXPECT_DOUBLE_EQ(psnr, psnr2);
}
@@ -408,7 +406,7 @@
ASSERT_NO_FATAL_FAILURE(RunLoop(&video));
const double psnr2 = GetAveragePsnr();
- const double psnr_thresh = 0.001;
+ const double psnr_thresh = 0.02;
EXPECT_NEAR(psnr, psnr2, psnr_thresh);
}