Use prediction error instead of RD cost if available
When TplFramStats::rate_dist_present is set, it indicates that
intra_pred_err/inter_pred_err fields are valid. In this case, use
those fields instead of intra_cost/inter_cost for TPL propagation.
Bug: b/261621000
Change-Id: I670f19afbe3cf463cb1538eeaba999742f88960e
diff --git a/av1/qmode_rc/ratectrl_qmode.cc b/av1/qmode_rc/ratectrl_qmode.cc
index d98669f..f10000d 100644
--- a/av1/qmode_rc/ratectrl_qmode.cc
+++ b/av1/qmode_rc/ratectrl_qmode.cc
@@ -961,10 +961,18 @@
}
TplUnitDepStats TplBlockStatsToDepStats(const TplBlockStats &block_stats,
- int unit_count) {
+ int unit_count,
+ bool rate_dist_present) {
TplUnitDepStats dep_stats = {};
- dep_stats.intra_cost = block_stats.intra_cost * 1.0 / unit_count;
- dep_stats.inter_cost = block_stats.inter_cost * 1.0 / unit_count;
+ if (rate_dist_present) {
+ dep_stats.intra_cost = block_stats.intra_pred_err;
+ dep_stats.inter_cost = block_stats.inter_pred_err;
+ } else {
+ dep_stats.intra_cost = block_stats.intra_cost;
+ dep_stats.inter_cost = block_stats.inter_cost;
+ }
+ dep_stats.intra_cost *= 1.0 / unit_count;
+ dep_stats.inter_cost *= 1.0 / unit_count;
// In rare case, inter_cost may be greater than intra_cost.
// If so, we need to modify inter_cost such that inter_cost <= intra_cost
// because it is required by GetPropagationFraction()
@@ -1062,8 +1070,8 @@
const int block_unit_cols = std::min(block_stats.width / min_block_size,
unit_cols - block_unit_col);
const int unit_count = block_unit_rows * block_unit_cols;
- TplUnitDepStats this_unit_stats =
- TplBlockStatsToDepStats(block_stats, unit_count);
+ TplUnitDepStats this_unit_stats = TplBlockStatsToDepStats(
+ block_stats, unit_count, frame_stats.rate_dist_present);
for (int r = 0; r < block_unit_rows; r++) {
for (int c = 0; c < block_unit_cols; c++) {
unit_stats[block_unit_row + r][block_unit_col + c] = this_unit_stats;
diff --git a/av1/qmode_rc/ratectrl_qmode.h b/av1/qmode_rc/ratectrl_qmode.h
index 29623d6..89fac61 100644
--- a/av1/qmode_rc/ratectrl_qmode.h
+++ b/av1/qmode_rc/ratectrl_qmode.h
@@ -73,7 +73,7 @@
int min_block_size, bool has_alt_stats);
TplUnitDepStats TplBlockStatsToDepStats(const TplBlockStats &block_stats,
- int unit_count);
+ int unit_count, bool rate_dist_present);
Status FillTplUnitDepStats(TplFrameDepStats &frame_dep_stats,
const TplFrameStats &frame_stats,
diff --git a/test/ratectrl_qmode_test.cc b/test/ratectrl_qmode_test.cc
index f1da003..6c4cd66 100644
--- a/test/ratectrl_qmode_test.cc
+++ b/test/ratectrl_qmode_test.cc
@@ -437,7 +437,8 @@
const int unit_count = 2;
TplBlockStats block_stats =
CreateToyTplBlockStats(8, 4, 0, 0, intra_cost, inter_cost);
- TplUnitDepStats unit_stats = TplBlockStatsToDepStats(block_stats, unit_count);
+ TplUnitDepStats unit_stats = TplBlockStatsToDepStats(
+ block_stats, unit_count, /*rate_dist_present=*/false);
double expected_intra_cost = intra_cost * 1.0 / unit_count;
EXPECT_NEAR(unit_stats.intra_cost, expected_intra_cost, kErrorEpsilon);
// When inter_cost >= intra_cost in block_stats, in unit_stats,
@@ -445,6 +446,23 @@
EXPECT_LE(unit_stats.inter_cost, unit_stats.intra_cost);
}
+TEST_F(RateControlQModeTest, TplBlockStatsToDepStatsUsingPredErr) {
+ const int intra_cost = 100;
+ const int inter_cost = 120;
+ const int unit_count = 2;
+ TplBlockStats block_stats =
+ CreateToyTplBlockStats(8, 4, 0, 0, intra_cost, inter_cost);
+ block_stats.intra_pred_err = 40;
+ block_stats.inter_pred_err = 50;
+ TplUnitDepStats unit_stats = TplBlockStatsToDepStats(
+ block_stats, unit_count, /*rate_dist_present=*/true);
+ double expected_intra_cost = block_stats.intra_pred_err * 1.0 / unit_count;
+ EXPECT_NEAR(unit_stats.intra_cost, expected_intra_cost, kErrorEpsilon);
+ // When inter_cost >= intra_cost in block_stats, in unit_stats,
+ // the inter_cost will be modified so that it's upper-bounded by intra_cost.
+ EXPECT_LE(unit_stats.inter_cost, unit_stats.intra_cost);
+}
+
TEST_F(RateControlQModeTest, TplFrameDepStatsPropagateSingleZeroMotion) {
// cur frame with coding_idx 1 use ref frame with coding_idx 0
const std::array<int, kBlockRefCount> ref_frame_index = { 0, -1 };