Use two tpl passes to estimate the propagation factor
Change-Id: I808033427cae547987c2c3205aefd92035710b75
diff --git a/av1/qmode_rc/ratectrl_qmode.cc b/av1/qmode_rc/ratectrl_qmode.cc
index b11a609..8a02d2e 100644
--- a/av1/qmode_rc/ratectrl_qmode.cc
+++ b/av1/qmode_rc/ratectrl_qmode.cc
@@ -899,7 +899,8 @@
}
TplFrameDepStats CreateTplFrameDepStats(int frame_height, int frame_width,
- int min_block_size) {
+ int min_block_size,
+ bool has_alt_stats) {
const int unit_rows = (frame_height + min_block_size - 1) / min_block_size;
const int unit_cols = (frame_width + min_block_size - 1) / min_block_size;
TplFrameDepStats frame_dep_stats;
@@ -908,6 +909,13 @@
for (auto &row : frame_dep_stats.unit_stats) {
row.resize(unit_cols);
}
+
+ if (has_alt_stats) {
+ frame_dep_stats.alt_unit_stats.resize(unit_rows);
+ for (auto &row : frame_dep_stats.alt_unit_stats) {
+ row.resize(unit_cols);
+ }
+ }
return frame_dep_stats;
}
@@ -919,7 +927,10 @@
// In rare case, inter_cost may be greater than intra_cost.
// If so, we need to modify inter_cost such that inter_cost <= intra_cost
// because it is required by GetPropagationFraction()
- dep_stats.inter_cost = std::min(dep_stats.intra_cost, dep_stats.inter_cost);
+ if (block_stats.ref_frame_index[0] >= 0)
+ dep_stats.inter_cost = std::min(dep_stats.intra_cost, dep_stats.inter_cost);
+ else
+ dep_stats.inter_cost = dep_stats.intra_cost;
dep_stats.mv = block_stats.mv;
dep_stats.ref_frame_index = block_stats.ref_frame_index;
return dep_stats;
@@ -985,19 +996,15 @@
}
} // namespace
-StatusOr<TplFrameDepStats> CreateTplFrameDepStatsWithoutPropagation(
- const TplFrameStats &frame_stats) {
- if (frame_stats.block_stats_list.empty()) {
- return TplFrameDepStats();
- }
+Status FillTplUnitDepStats(TplFrameDepStats &frame_dep_stats,
+ const TplFrameStats &frame_stats,
+ const std::vector<TplBlockStats> &block_stats_list) {
const int min_block_size = frame_stats.min_block_size;
const int unit_rows =
(frame_stats.frame_height + min_block_size - 1) / min_block_size;
const int unit_cols =
(frame_stats.frame_width + min_block_size - 1) / min_block_size;
- TplFrameDepStats frame_dep_stats = CreateTplFrameDepStats(
- frame_stats.frame_height, frame_stats.frame_width, min_block_size);
- for (const TplBlockStats &block_stats : frame_stats.block_stats_list) {
+ for (const TplBlockStats &block_stats : block_stats_list) {
Status status =
ValidateBlockStats(frame_stats, block_stats, min_block_size);
if (!status.ok()) {
@@ -1022,8 +1029,32 @@
}
}
}
+ return { AOM_CODEC_OK, "" };
+}
- frame_dep_stats.rdcost = TplFrameDepStatsAccumulateInterCost(frame_dep_stats);
+StatusOr<TplFrameDepStats> CreateTplFrameDepStatsWithoutPropagation(
+ const TplFrameStats &frame_stats) {
+ if (frame_stats.block_stats_list.empty()) {
+ return TplFrameDepStats();
+ }
+ const int min_block_size = frame_stats.min_block_size;
+ TplFrameDepStats frame_dep_stats = CreateTplFrameDepStats(
+ frame_stats.frame_height, frame_stats.frame_width, min_block_size,
+ !frame_stats.alternate_block_stats_list.empty());
+
+ Status status = FillTplUnitDepStats(frame_dep_stats, frame_stats,
+ frame_stats.block_stats_list);
+ if (!status.ok()) return status;
+
+ if (!frame_stats.alternate_block_stats_list.empty()) {
+ status = FillTplUnitDepStats(frame_dep_stats, frame_stats,
+ frame_stats.alternate_block_stats_list);
+ if (!status.ok()) return status;
+ frame_dep_stats.rdcost =
+ TplFrameDepStatsAccumulateInterCost(frame_dep_stats.unit_stats);
+ frame_dep_stats.alt_rdcost =
+ TplFrameDepStatsAccumulateInterCost(frame_dep_stats.alt_unit_stats);
+ }
return frame_dep_stats;
}
@@ -1070,12 +1101,12 @@
}
double TplFrameDepStatsAccumulateInterCost(
- const TplFrameDepStats &frame_dep_stats) {
+ const std::vector<std::vector<TplUnitDepStats>> &unit_stats) {
auto getInterCost = [](double sum, const TplUnitDepStats &unit) {
return sum + unit.inter_cost;
};
double sum = 0;
- for (const auto &row : frame_dep_stats.unit_stats) {
+ for (const auto &row : unit_stats) {
sum = std::accumulate(row.begin(), row.end(), sum, getInterCost);
}
return std::max(sum, 1.0);
@@ -1427,10 +1458,7 @@
if (gop_frame.update_type == GopFrameType::kRegularGolden ||
gop_frame.update_type == GopFrameType::kRegularKey ||
gop_frame.update_type == GopFrameType::kRegularArf) {
- double qstep_ratio = 1 / 3.0;
- param.q_index = av1_get_q_index_from_qstep_ratio(
- rc_param_.base_q_index, qstep_ratio, AOM_BITS_8);
- if (rc_param_.base_q_index) param.q_index = AOMMAX(param.q_index, 1);
+ param.q_index = 5;
}
}
gop_encode_info.param_list.push_back(param);
@@ -1464,6 +1492,36 @@
static_cast<int>(tpl_gop_stats.frame_stats_list.size());
const int active_worst_quality = rc_param_.base_q_index;
int active_best_quality = rc_param_.base_q_index;
+
+ double base_rdcost = 1.0; // baseline total rdcost
+ double hqr_rdcost = 0; // high quality reference total rdcost
+ double arf_rdcost_high = 1.0;
+ double arf_rdcost_low = 0;
+
+ bool kf_arf_seen = false;
+
+ for (int i = 0; i < frame_count; ++i) {
+ FrameEncodeParameters param;
+ const GopFrame &gop_frame = gop_struct.gop_frame_list[i];
+ const TplFrameDepStats &frame_dep_stats =
+ gop_dep_stats->frame_dep_stats_list[i];
+ if (gop_frame.update_type == GopFrameType::kRegularGolden ||
+ gop_frame.update_type == GopFrameType::kRegularKey ||
+ gop_frame.update_type == GopFrameType::kRegularArf) {
+ if (!kf_arf_seen) {
+ arf_rdcost_high += frame_dep_stats.rdcost;
+ arf_rdcost_low += frame_dep_stats.alt_rdcost;
+ }
+ kf_arf_seen = 1;
+ } else {
+ base_rdcost += frame_dep_stats.rdcost;
+ hqr_rdcost += frame_dep_stats.alt_rdcost;
+ }
+ }
+
+ double tp_frame_importance =
+ 1.0 + fabs((base_rdcost - hqr_rdcost) / arf_rdcost_high);
+
for (int i = 0; i < frame_count; i++) {
FrameEncodeParameters param;
const GopFrame &gop_frame = gop_struct.gop_frame_list[i];
@@ -1481,8 +1539,15 @@
TplFrameDepStatsAccumulateIntraCost(frame_dep_stats);
const double cost_with_propagation =
TplFrameDepStatsAccumulate(frame_dep_stats);
- const double frame_importance =
+ double frame_importance =
cost_with_propagation / cost_without_propagation;
+
+ // TODO(jingning): Temporarily make the switch between single and
+ // two TPL passes depending on the availability. This part of code
+ // needs further modifications to support SB level calculation.
+ if (rc_param_.tpl_pass_count == TplPassCount::kTwoTplPasses)
+ frame_importance = tp_frame_importance;
+
// Imitate the behavior of av1_tpl_get_qstep_ratio()
const double qstep_ratio = sqrt(1 / frame_importance);
param.q_index = av1_get_q_index_from_qstep_ratio(rc_param_.base_q_index,
diff --git a/av1/qmode_rc/ratectrl_qmode.h b/av1/qmode_rc/ratectrl_qmode.h
index f60000e..70e112a 100644
--- a/av1/qmode_rc/ratectrl_qmode.h
+++ b/av1/qmode_rc/ratectrl_qmode.h
@@ -35,9 +35,11 @@
};
struct TplFrameDepStats {
- int unit_size; // equivalent to min_block_size
- double rdcost; // overall rate-distortion cost
+ int unit_size; // equivalent to min_block_size
+ double rdcost; // overall rate-distortion cost
+ double alt_rdcost; // rate-distortion cost in the second tpl pass
std::vector<std::vector<TplUnitDepStats>> unit_stats;
+ std::vector<std::vector<TplUnitDepStats>> alt_unit_stats;
};
struct TplGopDepStats {
@@ -66,11 +68,15 @@
// and blocks along the bottom or right edge of the frame may extend beyond the
// edges of the frame.
TplFrameDepStats CreateTplFrameDepStats(int frame_height, int frame_width,
- int min_block_size);
+ int min_block_size, bool has_alt_stats);
TplUnitDepStats TplBlockStatsToDepStats(const TplBlockStats &block_stats,
int unit_count);
+Status FillTplUnitDepStats(TplFrameDepStats &frame_dep_stats,
+ const TplFrameStats &frame_stats,
+ const std::vector<TplBlockStats> &block_stats_list);
+
StatusOr<TplFrameDepStats> CreateTplFrameDepStatsWithoutPropagation(
const TplFrameStats &frame_stats);
@@ -80,7 +86,7 @@
const TplFrameDepStats &frame_dep_stats);
double TplFrameDepStatsAccumulateInterCost(
- const TplFrameDepStats &frame_dep_stats);
+ const std::vector<std::vector<TplUnitDepStats>> &unit_stats);
double TplFrameDepStatsAccumulate(const TplFrameDepStats &frame_dep_stats);
diff --git a/test/ratectrl_qmode_test.cc b/test/ratectrl_qmode_test.cc
index 3c947b0..dcf72d2 100644
--- a/test/ratectrl_qmode_test.cc
+++ b/test/ratectrl_qmode_test.cc
@@ -314,7 +314,7 @@
static TplFrameStats CreateToyTplFrameStatsWithDiffSizes(int min_block_size,
int max_block_size) {
- TplFrameStats frame_stats;
+ TplFrameStats frame_stats = {};
const int max_h = max_block_size;
const int max_w = max_h;
const int count = max_block_size / min_block_size;
@@ -455,7 +455,7 @@
// ref frame with coding_idx 0
TplFrameDepStats frame_dep_stats0 =
CreateTplFrameDepStats(frame_stats.frame_height, frame_stats.frame_width,
- frame_stats.min_block_size);
+ frame_stats.min_block_size, false);
gop_dep_stats.frame_dep_stats_list.push_back(frame_dep_stats0);
// cur frame with coding_idx 1
@@ -492,13 +492,13 @@
// ref frame with coding_idx 0
const TplFrameDepStats frame_dep_stats0 =
CreateTplFrameDepStats(frame_stats.frame_height, frame_stats.frame_width,
- frame_stats.min_block_size);
+ frame_stats.min_block_size, false);
gop_dep_stats.frame_dep_stats_list.push_back(frame_dep_stats0);
// ref frame with coding_idx 1
const TplFrameDepStats frame_dep_stats1 =
CreateTplFrameDepStats(frame_stats.frame_height, frame_stats.frame_width,
- frame_stats.min_block_size);
+ frame_stats.min_block_size, false);
gop_dep_stats.frame_dep_stats_list.push_back(frame_dep_stats1);
// cur frame with coding_idx 2
@@ -546,7 +546,7 @@
// ref frame with coding_idx 0
gop_dep_stats.frame_dep_stats_list.push_back(
CreateTplFrameDepStats(frame_stats.frame_height, frame_stats.frame_width,
- frame_stats.min_block_size));
+ frame_stats.min_block_size, false));
// cur frame with coding_idx 1
const StatusOr<TplFrameDepStats> frame_dep_stats =