Allow Gop look ahead in ratectrl_qmode
This is a rework of 161321 to allow arbitrary number of lookahead gops.
BUG=b/243687356
Change-Id: Iaa6036743db945d1b6fa2fda7f3cf7037dc9ff99
diff --git a/av1/ratectrl_qmode.cc b/av1/ratectrl_qmode.cc
index 8a54e15..9e2aae2 100644
--- a/av1/ratectrl_qmode.cc
+++ b/av1/ratectrl_qmode.cc
@@ -1153,7 +1153,9 @@
}
std::vector<RefFrameTable> AV1RateControlQMode::GetRefFrameTableList(
- const GopStruct &gop_struct, RefFrameTable ref_frame_table) {
+ const GopStruct &gop_struct,
+ const std::vector<LookaheadStats> &lookahead_stats,
+ RefFrameTable ref_frame_table) {
if (gop_struct.global_coding_idx_offset == 0) {
// For the first GOP, ref_frame_table need not be initialized. This is fine,
// because the first frame (a key frame) will fully initialize it.
@@ -1180,14 +1182,46 @@
}
ref_frame_table_list.push_back(ref_frame_table);
}
+
+ int gop_size_offset = static_cast<int>(gop_struct.gop_frame_list.size());
+
+ for (const auto &lookahead_stat : lookahead_stats) {
+ for (GopFrame gop_frame : lookahead_stat.gop_struct->gop_frame_list) {
+ if (gop_frame.is_key_frame) {
+ ref_frame_table.assign(rc_param_.ref_frame_table_size, gop_frame);
+ } else if (gop_frame.update_ref_idx != -1) {
+ assert(gop_frame.update_ref_idx <
+ static_cast<int>(ref_frame_table.size()));
+ gop_frame.coding_idx += gop_size_offset;
+ ref_frame_table[gop_frame.update_ref_idx] = gop_frame;
+ }
+ ref_frame_table_list.push_back(ref_frame_table);
+ }
+ gop_size_offset +=
+ static_cast<int>(lookahead_stat.gop_struct->gop_frame_list.size());
+ }
+
return ref_frame_table_list;
}
StatusOr<TplGopDepStats> ComputeTplGopDepStats(
const TplGopStats &tpl_gop_stats,
+ const std::vector<LookaheadStats> &lookahead_stats,
const std::vector<RefFrameTable> &ref_frame_table_list) {
+ std::vector<const TplFrameStats *> tpl_frame_stats_list_with_lookahead;
+ for (const auto &tpl_frame_stats : tpl_gop_stats.frame_stats_list) {
+ tpl_frame_stats_list_with_lookahead.push_back(&tpl_frame_stats);
+ }
+ for (auto &lookahead_stat : lookahead_stats) {
+ for (const auto &tpl_frame_stats :
+ lookahead_stat.tpl_gop_stats->frame_stats_list) {
+ tpl_frame_stats_list_with_lookahead.push_back(&tpl_frame_stats);
+ }
+ }
+
const int frame_count =
- static_cast<int>(tpl_gop_stats.frame_stats_list.size());
+ static_cast<int>(tpl_frame_stats_list_with_lookahead.size());
+
// Create the struct to store TPL dependency stats
TplGopDepStats tpl_gop_dep_stats;
@@ -1195,7 +1229,7 @@
for (int coding_idx = 0; coding_idx < frame_count; coding_idx++) {
const StatusOr<TplFrameDepStats> tpl_frame_dep_stats =
CreateTplFrameDepStatsWithoutPropagation(
- tpl_gop_stats.frame_stats_list[coding_idx]);
+ *tpl_frame_stats_list_with_lookahead[coding_idx]);
if (!tpl_frame_dep_stats.ok()) {
return tpl_frame_dep_stats.status();
}
@@ -1233,20 +1267,26 @@
const GopStruct &gop_struct, const TplGopStats &tpl_gop_stats,
const std::vector<LookaheadStats> &lookahead_stats,
const RefFrameTable &ref_frame_table_snapshot_init) {
- // TODO(b/242892473): Use lookahead_stats.
- (void)lookahead_stats;
Status status = ValidateTplStats(gop_struct, tpl_gop_stats);
if (!status.ok()) {
return status;
}
- const std::vector<RefFrameTable> ref_frame_table_list =
- GetRefFrameTableList(gop_struct, ref_frame_table_snapshot_init);
+ for (auto &lookahead_stat : lookahead_stats) {
+ Status status = ValidateTplStats(*lookahead_stat.gop_struct,
+ *lookahead_stat.tpl_gop_stats);
+ if (!status.ok()) {
+ return status;
+ }
+ }
+
+ const std::vector<RefFrameTable> ref_frame_table_list = GetRefFrameTableList(
+ gop_struct, lookahead_stats, ref_frame_table_snapshot_init);
GopEncodeInfo gop_encode_info;
gop_encode_info.final_snapshot = ref_frame_table_list.back();
- StatusOr<TplGopDepStats> gop_dep_stats =
- ComputeTplGopDepStats(tpl_gop_stats, ref_frame_table_list);
+ StatusOr<TplGopDepStats> gop_dep_stats = ComputeTplGopDepStats(
+ tpl_gop_stats, lookahead_stats, ref_frame_table_list);
if (!gop_dep_stats.ok()) {
return gop_dep_stats.status();
}
diff --git a/av1/ratectrl_qmode.h b/av1/ratectrl_qmode.h
index 93bd834..5f7dd86 100644
--- a/av1/ratectrl_qmode.h
+++ b/av1/ratectrl_qmode.h
@@ -87,6 +87,7 @@
StatusOr<TplGopDepStats> ComputeTplGopDepStats(
const TplGopStats &tpl_gop_stats,
+ const std::vector<LookaheadStats> &lookahead_stats,
const std::vector<RefFrameTable> &ref_frame_table_list);
class AV1RateControlQMode : public AV1RateControlQModeInterface {
@@ -105,7 +106,9 @@
// If this is first GOP, ref_frame_table is ignored and all refs are assumed
// invalid; otherwise ref_frame_table is used as the initial state.
std::vector<RefFrameTable> GetRefFrameTableList(
- const GopStruct &gop_struct, RefFrameTable ref_frame_table);
+ const GopStruct &gop_struct,
+ const std::vector<LookaheadStats> &lookahead_stats,
+ RefFrameTable ref_frame_table);
private:
RateControlParam rc_param_;
diff --git a/av1/ratectrl_qmode_interface.h b/av1/ratectrl_qmode_interface.h
index 446e2ac..ab2e4e0 100644
--- a/av1/ratectrl_qmode_interface.h
+++ b/av1/ratectrl_qmode_interface.h
@@ -254,13 +254,6 @@
const TplGopStats *tpl_gop_stats; // Not owned, may not be nullptr.
};
-// Now that there are no more references to the old three-argument
-// GetGopEncodeInfo, it can be brought back to life as an alias for
-// GetGopEncodeInfoWithLookahead.
-// TODO(b/242892473): Remove this #define after replacing all references to
-// GetGopEncodeInfoWithLookahead with GetGopEncodeInfo.
-#define GetGopEncodeInfo GetGopEncodeInfoWithLookahead
-
class AV1RateControlQModeInterface {
public:
AV1RateControlQModeInterface();
diff --git a/test/ratectrl_qmode_test.cc b/test/ratectrl_qmode_test.cc
index 29bc52d..dd29ae9 100644
--- a/test/ratectrl_qmode_test.cc
+++ b/test/ratectrl_qmode_test.cc
@@ -585,6 +585,7 @@
}
}
+// TODO(jianj): Add tests for non empty lookahead stats.
TEST_F(RateControlQModeTest, ComputeTplGopDepStats) {
TplGopStats tpl_gop_stats;
std::vector<RefFrameTable> ref_frame_table_list;
@@ -602,7 +603,7 @@
ref_frame_table_list.push_back(CreateToyRefFrameTable(i));
}
const StatusOr<TplGopDepStats> gop_dep_stats =
- ComputeTplGopDepStats(tpl_gop_stats, ref_frame_table_list);
+ ComputeTplGopDepStats(tpl_gop_stats, {}, ref_frame_table_list);
ASSERT_THAT(gop_dep_stats.status(), IsOkStatus());
double expected_sum = 0;
@@ -970,7 +971,7 @@
// For the first GOP only, GetRefFrameTableList can be passed a
// default-constructed RefFrameTable (because it's all going to be
// replaced by the key frame anyway).
- rc.GetRefFrameTableList(gop_struct, RefFrameTable()),
+ rc.GetRefFrameTableList(gop_struct, {}, RefFrameTable()),
ElementsAre(
ElementsAre(matches_invalid, matches_invalid, matches_invalid),
ElementsAre(matches_frame0, matches_frame0, matches_frame0),
@@ -1000,7 +1001,7 @@
gop_struct.global_coding_idx_offset = 5; // This is not the first GOP.
gop_struct.gop_frame_list = { frame0, frame1, frame2 };
ASSERT_THAT(
- rc.GetRefFrameTableList(gop_struct, RefFrameTable(3, previous)),
+ rc.GetRefFrameTableList(gop_struct, {}, RefFrameTable(3, previous)),
ElementsAre(
ElementsAre(matches_previous, matches_previous, matches_previous),
ElementsAre(matches_previous, matches_previous, matches_frame0),