Validate position and dimensions of TplBlockStats
Also added convenience functions ok(), operator*(), and operator->() to
Status/StatusOr.
Now that CreateTplFrameDepStatsWithoutPropagation crops blocks to
their overlap with the frame, it seems a good idea to check that the
dimensions are valid before cropping.
Change-Id: I7d0c968cfc2d3bd5925938dfe33f74a432b3e977
diff --git a/av1/ratectrl_qmode.cc b/av1/ratectrl_qmode.cc
index 83e09cb..e12f52e 100644
--- a/av1/ratectrl_qmode.cc
+++ b/av1/ratectrl_qmode.cc
@@ -903,10 +903,40 @@
return dep_stats;
}
-TplFrameDepStats CreateTplFrameDepStatsWithoutPropagation(
+namespace {
+Status ValidateBlockStats(const TplFrameStats &frame_stats,
+ const TplBlockStats &block_stats,
+ int min_block_size) {
+ if (block_stats.col >= frame_stats.frame_width ||
+ block_stats.row >= frame_stats.frame_height) {
+ std::ostringstream error_message;
+ error_message << "Block position (" << block_stats.col << ", "
+ << block_stats.row
+ << ") is out of range; frame dimensions are "
+ << frame_stats.frame_width << " x "
+ << frame_stats.frame_height;
+ return { AOM_CODEC_INVALID_PARAM, error_message.str() };
+ }
+ if (block_stats.col % min_block_size != 0 ||
+ block_stats.row % min_block_size != 0 ||
+ block_stats.width % min_block_size != 0 ||
+ block_stats.height % min_block_size != 0) {
+ std::ostringstream error_message;
+ error_message
+ << "Invalid block position or dimension, must be a multiple of "
+ << min_block_size << "; col = " << block_stats.col
+ << ", row = " << block_stats.row << ", width = " << block_stats.width
+ << ", height = " << block_stats.height;
+ return { AOM_CODEC_INVALID_PARAM, error_message.str() };
+ }
+ return { AOM_CODEC_OK, "" };
+}
+} // namespace
+
+StatusOr<TplFrameDepStats> CreateTplFrameDepStatsWithoutPropagation(
const TplFrameStats &frame_stats) {
if (frame_stats.block_stats_list.empty()) {
- return {};
+ return TplFrameDepStats();
}
const int min_block_size = frame_stats.min_block_size;
const int unit_rows =
@@ -916,6 +946,11 @@
TplFrameDepStats frame_dep_stats = CreateTplFrameDepStats(
frame_stats.frame_height, frame_stats.frame_width, min_block_size);
for (const TplBlockStats &block_stats : frame_stats.block_stats_list) {
+ Status status =
+ ValidateBlockStats(frame_stats, block_stats, min_block_size);
+ if (!status.ok()) {
+ return status;
+ }
const int block_unit_row = block_stats.row / min_block_size;
const int block_unit_col = block_stats.col / min_block_size;
// The block must start within the frame boundaries, but it may extend past
@@ -1101,7 +1136,7 @@
return ref_frame_table_list;
}
-TplGopDepStats ComputeTplGopDepStats(
+StatusOr<TplGopDepStats> ComputeTplGopDepStats(
const TplGopStats &tpl_gop_stats,
const std::vector<RefFrameTable> &ref_frame_table_list) {
const int frame_count =
@@ -1109,10 +1144,16 @@
// Create the struct to store TPL dependency stats
TplGopDepStats tpl_gop_dep_stats;
+ tpl_gop_dep_stats.frame_dep_stats_list.reserve(frame_count);
for (int coding_idx = 0; coding_idx < frame_count; coding_idx++) {
- tpl_gop_dep_stats.frame_dep_stats_list.push_back(
+ const StatusOr<TplFrameDepStats> tpl_frame_dep_stats =
CreateTplFrameDepStatsWithoutPropagation(
- tpl_gop_stats.frame_stats_list[coding_idx]));
+ tpl_gop_stats.frame_stats_list[coding_idx]);
+ if (!tpl_frame_dep_stats.ok()) {
+ return tpl_frame_dep_stats.status();
+ }
+ tpl_gop_dep_stats.frame_dep_stats_list.push_back(
+ std::move(*tpl_frame_dep_stats));
}
// Back propagation
@@ -1149,8 +1190,11 @@
GopEncodeInfo gop_encode_info;
gop_encode_info.final_snapshot = ref_frame_table_list.back();
- TplGopDepStats gop_dep_stats =
+ StatusOr<TplGopDepStats> gop_dep_stats =
ComputeTplGopDepStats(tpl_gop_stats, ref_frame_table_list);
+ if (!gop_dep_stats.ok()) {
+ return gop_dep_stats.status();
+ }
const int frame_count =
static_cast<int>(tpl_gop_stats.frame_stats_list.size());
for (int i = 0; i < frame_count; i++) {
@@ -1162,7 +1206,7 @@
param.q_index = rc_param_.base_q_index;
} else {
const TplFrameDepStats &frame_dep_stats =
- gop_dep_stats.frame_dep_stats_list[i];
+ gop_dep_stats->frame_dep_stats_list[i];
const double cost_without_propagation =
TplFrameDepStatsAccumulateIntraCost(frame_dep_stats);
const double cost_with_propagation =
diff --git a/av1/ratectrl_qmode.h b/av1/ratectrl_qmode.h
index bca86c2..7a59687 100644
--- a/av1/ratectrl_qmode.h
+++ b/av1/ratectrl_qmode.h
@@ -69,7 +69,7 @@
TplUnitDepStats TplBlockStatsToDepStats(const TplBlockStats &block_stats,
int unit_count);
-TplFrameDepStats CreateTplFrameDepStatsWithoutPropagation(
+StatusOr<TplFrameDepStats> CreateTplFrameDepStatsWithoutPropagation(
const TplFrameStats &frame_stats);
std::vector<int> GetKeyFrameList(const FirstpassInfo &first_pass_info);
@@ -85,7 +85,7 @@
int GetBlockOverlapArea(int r0, int c0, int r1, int c1, int size);
-TplGopDepStats ComputeTplGopDepStats(
+StatusOr<TplGopDepStats> ComputeTplGopDepStats(
const TplGopStats &tpl_gop_stats,
const std::vector<RefFrameTable> &ref_frame_table_list);
diff --git a/av1/ratectrl_qmode_interface.h b/av1/ratectrl_qmode_interface.h
index 5d5d830..316a699 100644
--- a/av1/ratectrl_qmode_interface.h
+++ b/av1/ratectrl_qmode_interface.h
@@ -85,6 +85,7 @@
struct Status {
aom_codec_err_t code;
std::string message; // Empty if code == AOM_CODEC_OK.
+ bool ok() const { return code == AOM_CODEC_OK; }
};
// A very simple imitation of absl::StatusOr, this is conceptually a union of a
@@ -99,11 +100,38 @@
StatusOr(Status status) : status_(std::move(status)) {
assert(status_.code != AOM_CODEC_OK);
}
- const T &value() const & { return value_; }
- T &value() & { return value_; }
- const T &&value() const && { return value_; }
- T &&value() && { return std::move(value_); }
+
const Status &status() const { return status_; }
+ bool ok() const { return status().ok(); }
+
+ // operator* returns the value; it should only be called after checking that
+ // ok() returns true.
+ const T &operator*() const & { return value_; }
+ T &operator*() & { return value_; }
+ const T &&operator*() const && { return value_; }
+ T &&operator*() && { return std::move(value_); }
+
+ // sor->field is equivalent to (*sor).field.
+ const T *operator->() const & { return &value_; }
+ T *operator->() & { return &value_; }
+
+ // value() is equivalent to operator*, but asserts that ok() is true.
+ const T &value() const & {
+ assert(ok());
+ return value_;
+ }
+ T &value() & {
+ assert(ok());
+ return value_;
+ }
+ const T &&value() const && {
+ assert(ok());
+ return value_;
+ }
+ T &&value() && {
+ assert(ok());
+ return std::move(value_);
+ }
private:
T value_; // This could be std::optional<T> if it were available.
diff --git a/test/ratectrl_qmode_test.cc b/test/ratectrl_qmode_test.cc
index 13d9d50..ae516e9 100644
--- a/test/ratectrl_qmode_test.cc
+++ b/test/ratectrl_qmode_test.cc
@@ -29,8 +29,16 @@
namespace {
+using ::testing::HasSubstr;
+
constexpr int kRefFrameTableSize = 7;
+MATCHER(IsOkStatus, "") {
+ *result_listener << "with code " << arg.code
+ << " and message: " << arg.message;
+ return arg.ok();
+}
+
// Reads a whitespace-delimited string from stream, and parses it as a double.
// Returns an empty string if the entire string was successfully parsed as a
// double, or an error messaage if not.
@@ -431,21 +439,38 @@
TEST(RateControlQModeTest, CreateTplFrameDepStats) {
TplFrameStats frame_stats = CreateToyTplFrameStatsWithDiffSizes(8, 16);
- TplFrameDepStats frame_dep_stats =
+ StatusOr<TplFrameDepStats> frame_dep_stats =
CreateTplFrameDepStatsWithoutPropagation(frame_stats);
- EXPECT_EQ(frame_stats.min_block_size, frame_dep_stats.unit_size);
- const int unit_rows = static_cast<int>(frame_dep_stats.unit_stats.size());
- const int unit_cols = static_cast<int>(frame_dep_stats.unit_stats[0].size());
- EXPECT_EQ(frame_stats.frame_height, unit_rows * frame_dep_stats.unit_size);
- EXPECT_EQ(frame_stats.frame_width, unit_cols * frame_dep_stats.unit_size);
+ ASSERT_THAT(frame_dep_stats.status(), IsOkStatus());
+ EXPECT_EQ(frame_stats.min_block_size, frame_dep_stats->unit_size);
+ const int unit_rows = static_cast<int>(frame_dep_stats->unit_stats.size());
+ const int unit_cols = static_cast<int>(frame_dep_stats->unit_stats[0].size());
+ EXPECT_EQ(frame_stats.frame_height, unit_rows * frame_dep_stats->unit_size);
+ EXPECT_EQ(frame_stats.frame_width, unit_cols * frame_dep_stats->unit_size);
const double intra_cost_sum =
- TplFrameDepStatsAccumulateIntraCost(frame_dep_stats);
+ TplFrameDepStatsAccumulateIntraCost(*frame_dep_stats);
const double expected_intra_cost_sum =
TplFrameStatsAccumulateIntraCost(frame_stats);
EXPECT_NEAR(intra_cost_sum, expected_intra_cost_sum, kErrorEpsilon);
}
+TEST(RateControlQModeTest, BlockRowNotAMultipleOfMinBlockSizeError) {
+ TplFrameStats frame_stats = CreateToyTplFrameStatsWithDiffSizes(8, 16);
+ frame_stats.block_stats_list.back().row = 1;
+ auto result = CreateTplFrameDepStatsWithoutPropagation(frame_stats);
+ EXPECT_FALSE(result.ok());
+ EXPECT_THAT(result.status().message, HasSubstr("must be a multiple of 8"));
+}
+
+TEST(RateControlQModeTest, BlockPositionOutOfRangeError) {
+ TplFrameStats frame_stats = CreateToyTplFrameStatsWithDiffSizes(8, 16);
+ frame_stats.block_stats_list.back().row += 8;
+ auto result = CreateTplFrameDepStatsWithoutPropagation(frame_stats);
+ EXPECT_FALSE(result.ok());
+ EXPECT_THAT(result.status().message, HasSubstr("out of range"));
+}
+
TEST(RateControlQModeTest, GetBlockOverlapArea) {
const int size = 8;
const int r0 = 8;
@@ -490,9 +515,10 @@
gop_dep_stats.frame_dep_stats_list.push_back(frame_dep_stats0);
// cur frame with coding_idx 1
- const TplFrameDepStats frame_dep_stats1 =
+ const StatusOr<TplFrameDepStats> frame_dep_stats1 =
CreateTplFrameDepStatsWithoutPropagation(frame_stats);
- gop_dep_stats.frame_dep_stats_list.push_back(frame_dep_stats1);
+ ASSERT_THAT(frame_dep_stats1.status(), IsOkStatus());
+ gop_dep_stats.frame_dep_stats_list.push_back(std::move(*frame_dep_stats1));
const RefFrameTable ref_frame_table = CreateToyRefFrameTable(frame_count);
TplFrameDepStatsPropagate(/*coding_idx=*/1, ref_frame_table, &gop_dep_stats);
@@ -532,9 +558,10 @@
gop_dep_stats.frame_dep_stats_list.push_back(frame_dep_stats1);
// cur frame with coding_idx 2
- const TplFrameDepStats frame_dep_stats2 =
+ const StatusOr<TplFrameDepStats> frame_dep_stats2 =
CreateTplFrameDepStatsWithoutPropagation(frame_stats);
- gop_dep_stats.frame_dep_stats_list.push_back(frame_dep_stats2);
+ ASSERT_THAT(frame_dep_stats2.status(), IsOkStatus());
+ gop_dep_stats.frame_dep_stats_list.push_back(std::move(*frame_dep_stats2));
const RefFrameTable ref_frame_table = CreateToyRefFrameTable(frame_count);
TplFrameDepStatsPropagate(/*coding_idx=*/2, ref_frame_table, &gop_dep_stats);
@@ -578,8 +605,10 @@
frame_stats.min_block_size));
// cur frame with coding_idx 1
- gop_dep_stats.frame_dep_stats_list.push_back(
- CreateTplFrameDepStatsWithoutPropagation(frame_stats));
+ const StatusOr<TplFrameDepStats> frame_dep_stats =
+ CreateTplFrameDepStatsWithoutPropagation(frame_stats);
+ ASSERT_THAT(frame_dep_stats.status(), IsOkStatus());
+ gop_dep_stats.frame_dep_stats_list.push_back(std::move(*frame_dep_stats));
const RefFrameTable ref_frame_table = CreateToyRefFrameTable(frame_count);
TplFrameDepStatsPropagate(/*coding_idx=*/1, ref_frame_table, &gop_dep_stats);
@@ -627,8 +656,9 @@
ref_frame_table_list.push_back(CreateToyRefFrameTable(i));
}
- const TplGopDepStats &gop_dep_stats =
+ const StatusOr<TplGopDepStats> gop_dep_stats =
ComputeTplGopDepStats(tpl_gop_stats, ref_frame_table_list);
+ ASSERT_THAT(gop_dep_stats.status(), IsOkStatus());
double expected_sum = 0;
for (int i = 2; i >= 0; i--) {
@@ -637,7 +667,7 @@
expected_sum +=
TplFrameStatsAccumulateIntraCost(tpl_gop_stats.frame_stats_list[i]);
const double sum =
- TplFrameDepStatsAccumulate(gop_dep_stats.frame_dep_stats_list[i]);
+ TplFrameDepStatsAccumulate(gop_dep_stats->frame_dep_stats_list[i]);
EXPECT_NEAR(sum, expected_sum, kErrorEpsilon);
break;
}
@@ -909,12 +939,6 @@
return frame;
}
-MATCHER(IsOkStatus, "") {
- *result_listener << "with code " << arg.code
- << " and message: " << arg.message;
- return arg.code == AOM_CODEC_OK;
-}
-
TEST(RateControlQModeTest, TestSetRcParamErrorChecking) {
// Default constructed RateControlParam should not be valid.
RateControlParam rc_param = {};
@@ -1005,7 +1029,7 @@
const auto gop_info = rc.DetermineGopInfo(firstpass_info);
ASSERT_THAT(gop_info.status(), IsOkStatus());
std::vector<int> gop_interval_list;
- std::transform(gop_info.value().begin(), gop_info.value().end(),
+ std::transform(gop_info->begin(), gop_info->end(),
std::back_inserter(gop_interval_list),
[](GopStruct const &x) { return x.show_frame_count; });
EXPECT_THAT(gop_interval_list,
@@ -1028,7 +1052,7 @@
ASSERT_THAT(rc.SetRcParam(rc_param), IsOkStatus());
const auto gop_info = rc.DetermineGopInfo(firstpass_info);
ASSERT_THAT(gop_info.status(), IsOkStatus());
- const GopStructList &gop_list = gop_info.value();
+ const GopStructList &gop_list = *gop_info;
// Read TPL stats
std::vector<TplGopStats> tpl_gop_list;
ASSERT_NO_FATAL_FAILURE(ReadTplInfo("tpl_stats", gop_list, &tpl_gop_list));
@@ -1044,10 +1068,10 @@
const auto gop_encode_info = rc.GetGopEncodeInfo(
gop_list[gop_idx], tpl_gop_list[tpl_gop_idx], ref_frame_table);
ASSERT_THAT(gop_encode_info.status(), IsOkStatus());
- for (auto &frame_param : gop_encode_info.value().param_list) {
+ for (auto &frame_param : gop_encode_info->param_list) {
std::cout << frame_param.q_index << std::endl;
}
- ref_frame_table = gop_encode_info.value().final_snapshot;
+ ref_frame_table = gop_encode_info->final_snapshot;
}
}