Do not propagate arf frames in 2-pass tpl.
Gives slight gains (~0.1%).
Change-Id: Iea5f818501c17239221b518860a0e1c9a735a56d
diff --git a/av1/qmode_rc/ratectrl_qmode.cc b/av1/qmode_rc/ratectrl_qmode.cc
index afb1516..2c7f101 100644
--- a/av1/qmode_rc/ratectrl_qmode.cc
+++ b/av1/qmode_rc/ratectrl_qmode.cc
@@ -1187,7 +1187,7 @@
ModifyDivisor(unit_dep_stats.intra_cost);
}
-void TplFrameDepStatsBackTrace(int coding_idx,
+void TplFrameDepStatsBackTrace(int coding_idx, GopFrameType update_type,
const RefFrameTable &ref_frame_table,
TplGopDepStats *tpl_gop_dep_stats) {
assert(!tpl_gop_dep_stats->frame_dep_stats_list.empty());
@@ -1197,6 +1197,10 @@
if (frame_dep_stats->unit_stats.empty()) return;
if (frame_dep_stats->alt_unit_stats.empty()) return;
+ const bool ignore_inter = update_type == GopFrameType::kRegularKey ||
+ update_type == GopFrameType::kRegularArf ||
+ update_type == GopFrameType::kRegularGolden;
+
const int unit_size = frame_dep_stats->unit_size;
const int frame_unit_rows =
static_cast<int>(frame_dep_stats->unit_stats.size());
@@ -1250,14 +1254,14 @@
ref_frame_dep_stats
.alt_unit_stats[ref_unit_row][ref_unit_col];
alt_ref_unit_stats.propagation_cost +=
- (alt_unit_dep_stats.inter_cost +
+ ((ignore_inter ? 0.0 : alt_unit_dep_stats.inter_cost) +
alt_unit_dep_stats.propagation_cost) *
propagation_ratio;
TplUnitDepStats &ref_unit_stats =
ref_frame_dep_stats.unit_stats[ref_unit_row][ref_unit_col];
ref_unit_stats.propagation_cost +=
- (unit_dep_stats.inter_cost +
+ ((ignore_inter ? 0.0 : unit_dep_stats.inter_cost) +
unit_dep_stats.propagation_cost) *
propagation_ratio;
}
@@ -1390,7 +1394,7 @@
}
StatusOr<TplGopDepStats> ComputeTplGopDepStats(
- const TplGopStats &tpl_gop_stats,
+ const GopStruct &gop_struct, const TplGopStats &tpl_gop_stats,
const std::vector<LookaheadStats> &lookahead_stats,
const std::vector<RefFrameTable> &ref_frame_table_list) {
std::vector<const TplFrameStats *> tpl_frame_stats_list_with_lookahead;
@@ -1435,7 +1439,19 @@
&tpl_gop_dep_stats);
} else {
// Two pass TPL runs
- TplFrameDepStatsBackTrace(coding_idx, ref_frame_table,
+
+ int first_gop_size = static_cast<int>(gop_struct.gop_frame_list.size());
+ GopFrameType update_type;
+ if (coding_idx >= first_gop_size) {
+ update_type = lookahead_stats[0]
+ .gop_struct[0]
+ .gop_frame_list[coding_idx - first_gop_size]
+ .update_type;
+ } else {
+ update_type = gop_struct.gop_frame_list[coding_idx].update_type;
+ }
+
+ TplFrameDepStatsBackTrace(coding_idx, update_type, ref_frame_table,
&tpl_gop_dep_stats);
}
}
@@ -1969,7 +1985,7 @@
gop_encode_info.final_snapshot = ref_frame_table_list[frame_count];
StatusOr<TplGopDepStats> gop_dep_stats = ComputeTplGopDepStats(
- tpl_gop_stats, lookahead_stats, ref_frame_table_list);
+ gop_struct, tpl_gop_stats, lookahead_stats, ref_frame_table_list);
if (!gop_dep_stats.ok()) {
return gop_dep_stats.status();
}
diff --git a/av1/qmode_rc/ratectrl_qmode.h b/av1/qmode_rc/ratectrl_qmode.h
index 89fac61..d8d3ec8 100644
--- a/av1/qmode_rc/ratectrl_qmode.h
+++ b/av1/qmode_rc/ratectrl_qmode.h
@@ -107,7 +107,7 @@
}
StatusOr<TplGopDepStats> ComputeTplGopDepStats(
- const TplGopStats &tpl_gop_stats,
+ const GopStruct &gop_struct, const TplGopStats &tpl_gop_stats,
const std::vector<LookaheadStats> &lookahead_stats,
const std::vector<RefFrameTable> &ref_frame_table_list);
diff --git a/test/ratectrl_qmode_test.cc b/test/ratectrl_qmode_test.cc
index 6c4cd66..760ae70 100644
--- a/test/ratectrl_qmode_test.cc
+++ b/test/ratectrl_qmode_test.cc
@@ -620,8 +620,8 @@
ref_frame_table_list.push_back(CreateToyRefFrameTable(i));
}
- const StatusOr<TplGopDepStats> gop_dep_stats =
- ComputeTplGopDepStats(tpl_gop_stats, {}, ref_frame_table_list);
+ const StatusOr<TplGopDepStats> gop_dep_stats = ComputeTplGopDepStats(
+ gop_struct, tpl_gop_stats, {}, ref_frame_table_list);
ASSERT_THAT(gop_dep_stats.status(), IsOkStatus());
double expected_sum = 0;