Guided CNN restoration: signal unit size.

An index `unit_index` is signaled at frame level to select one of two
unit sizes:
- unit_index = 0 means unit_size = 256 for low-res and 512 for high-res.
- unit_index = 1 means unit_size = 512 for low-res and 1024 for high-res.
Here, low-res means width * height <= 1280 * 720.
For RDO, a loop is added to the encoder-side function
`restore_cnn_quadtree_encode_img_tflite_highbd` to pick the best unit
size.
diff --git a/av1/common/cnn_tflite.cc b/av1/common/cnn_tflite.cc
index 61d813f..e99b8ed 100644
--- a/av1/common/cnn_tflite.cc
+++ b/av1/common/cnn_tflite.cc
@@ -993,18 +993,18 @@
RDCOST_DBL_WITH_NATIVE_BD_DIST(rdmult, bitrate >> 4, sse, bit_depth);
}
-// Given intermediate restoration 'interm' and 'src', computes the best
-// partitioning out of NONE, SPLIT, HORZ and VERT based on RD cost, and uses it
-// to restore the unit starting at 'row' and 'col' inside 'dgd'.
-// Also, stores the split decisions in 'split' and a0,a1 pairs in 'A'.
+// Given intermediate restoration 'interm', source 'src' and degraded frame
+// 'dgd', computes the best partitioning out of NONE, SPLIT, HORZ and VERT based
+// on RD cost for the unit at ('start_row', 'start_col') in the width x height
+// frame. Split decisions are stored in 'split' and a0,a1 pairs in 'A'.
static void select_quadtree_partitioning(
const std::vector<std::vector<std::vector<double>>> &interm,
const uint16_t *src, int src_stride, int start_row, int start_col,
int width, int height, int quadtree_max_size, const int *quadtset,
int rdmult, const std::pair<int, int> &prev_A, const int *splitcosts,
- const int norestorecosts[2], int bit_depth, uint16_t *dgd, int dgd_stride,
- std::vector<int> &split, std::vector<std::pair<int, int>> &A,
- double *rdcost) {
+ const int norestorecosts[2], int bit_depth, const uint16_t *dgd,
+ int dgd_stride, std::vector<int> &split,
+ std::vector<std::pair<int, int>> &A, double *rdcost) {
const int end_row = AOMMIN(start_row + quadtree_max_size, height);
const int end_col = AOMMIN(start_col + quadtree_max_size, width);
const bool is_partial_unit = (start_row + quadtree_max_size > height) ||
@@ -1044,13 +1044,6 @@
// Save RDCost.
*rdcost = best_rdcost;
- // Restore this unit in 'dgd'.
- for (int row = start_row; row < end_row; ++row) {
- for (int col = start_col; col < end_col; ++col) {
- dgd[row * dgd_stride + col] = best_out[row - start_row][col - start_col];
- }
- }
-
// Save a0, a1 pairs.
for (auto &a0a1 : best_A) {
A.push_back(a0a1);
@@ -1083,6 +1076,13 @@
}
}
+static void apply_quadtree_partitioning(
+ const std::vector<std::vector<std::vector<double>>> &interm, int start_row,
+ int start_col, int width, int height, int quadtree_max_size,
+ const int *quadtset, int bit_depth, const std::vector<int> &split,
+ size_t &split_index, const std::vector<std::pair<int, int>> &A,
+ size_t &A_index, uint16_t *dgd, int dgd_stride);
+
// Top-level function to apply guided restoration on encoder side.
static int restore_cnn_quadtree_encode_img_tflite_highbd(
YV12_BUFFER_CONFIG *source_frame, AV1_COMMON *cm, int superres_denom,
@@ -1108,9 +1108,6 @@
// Initialization.
const uint16_t *src = CONVERT_TO_SHORTPTR(source_frame->y_buffer);
const int src_stride = source_frame->y_stride;
- const int unit_index = quad_tree_get_unit_index(width, height);
- const int quadtree_max_size =
- quad_tree_get_unit_size(width, height, unit_index);
const int *quadtset = get_quadparm_from_qindex(
qindex, superres_denom, is_intra_only, is_luma, cnn_index);
const int A0_min = quadtset[2];
@@ -1121,41 +1118,70 @@
const int *this_norestorecosts =
norestore_ctx == -1 ? null_norestorecosts : norestorecosts[norestore_ctx];
- // For each quadtree unit, compute the best partitioning out of
- // NONE, SPLIT, HORZ and VERT based on RD cost.
- std::vector<int> split; // selected partitioning options.
- std::vector<std::pair<int, int>> A; // selected a0, a1 weight pairs.
- // Previous a0, a1 pair is mid-point of the range by default.
- std::pair<int, int> prev_A = std::make_pair(8 + A0_min, 8 + A1_min);
- *rdcost = 0;
- // TODO(urvang): Include padded area in a unit if it's < unit size / 2?
- // If so, need to modify / replace quad_tree_get_unit_info_length().
- // Also double check: quad_tree_get_split_info_length().
- for (int row = 0; row < height; row += quadtree_max_size) {
- for (int col = 0; col < width; col += quadtree_max_size) {
- double this_rdcost;
- select_quadtree_partitioning(
- interm, src, src_stride, row, col, width, height, quadtree_max_size,
- quadtset, rdmult, prev_A, splitcosts, this_norestorecosts, bit_depth,
- dgd, dgd_stride, split, A, &this_rdcost);
- // updates.
- *rdcost += this_rdcost;
- prev_A = A.back();
+ // Try all possible quadtree unit sizes.
+ int best_unit_index = -1;
+ std::vector<int> best_split; // selected partitioning options.
+ std::vector<std::pair<int, int>> best_A; // selected a0, a1 weight pairs.
+ double best_rdcost_total = DBL_MAX;
+ for (int this_unit_index = 0; this_unit_index <= 1; ++this_unit_index) {
+ const int quadtree_max_size =
+ quad_tree_get_unit_size(width, height, this_unit_index);
+ // For each quadtree unit, compute the best partitioning out of
+ // NONE, SPLIT, HORZ and VERT based on RD cost.
+ std::vector<int> this_split; // selected partitioning options.
+ std::vector<std::pair<int, int>> this_A; // selected a0, a1 weight pairs.
+ double this_rdcost_total = 0.0;
+ // Previous a0, a1 pair is mid-point of the range by default.
+ std::pair<int, int> prev_A = std::make_pair(8 + A0_min, 8 + A1_min);
+ // TODO(urvang): Include padded area in a unit if it's < unit size / 2?
+ // If so, need to modify / replace quad_tree_get_unit_info_length().
+ // Also double check: quad_tree_get_split_info_length().
+ for (int row = 0; row < height; row += quadtree_max_size) {
+ for (int col = 0; col < width; col += quadtree_max_size) {
+ double this_rdcost;
+ select_quadtree_partitioning(
+ interm, src, src_stride, row, col, width, height, quadtree_max_size,
+ quadtset, rdmult, prev_A, splitcosts, this_norestorecosts,
+ bit_depth, dgd, dgd_stride, this_split, this_A, &this_rdcost);
+ // updates.
+ this_rdcost_total += this_rdcost;
+ prev_A = this_A.back();
+ }
+ }
+ // Update best options.
+ if (this_rdcost_total < best_rdcost_total) {
+ best_unit_index = this_unit_index;
+ best_split = this_split;
+ best_A = this_A;
+ best_rdcost_total = this_rdcost_total;
}
}
- // Fill in the decisions.
- quad_info->unit_index = unit_index;
- quad_info->split_info_length = (int)split.size();
- quad_info->unit_info_length = (int)A.size();
+ // Fill in the best options.
+ quad_info->unit_index = best_unit_index;
+ quad_info->split_info_length = (int)best_split.size();
+ quad_info->unit_info_length = (int)best_A.size();
av1_alloc_quadtree_struct(cm, quad_info);
- for (unsigned int i = 0; i < split.size(); ++i) {
- quad_info->split_info[i].split = split[i];
+ for (unsigned int i = 0; i < best_split.size(); ++i) {
+ quad_info->split_info[i].split = best_split[i];
}
- for (unsigned int i = 0; i < A.size(); ++i) {
- quad_info->unit_info[i].xqd[0] = A[i].first;
- quad_info->unit_info[i].xqd[1] = A[i].second;
+ for (unsigned int i = 0; i < best_A.size(); ++i) {
+ quad_info->unit_info[i].xqd[0] = best_A[i].first;
+ quad_info->unit_info[i].xqd[1] = best_A[i].second;
}
+ *rdcost = best_rdcost_total;
+
+ // Apply guided restoration to 'dgd' using best options above.
+ size_t split_index = 0;
+ size_t A_index = 0;
+ for (int row = 0; row < height; row += quad_info->unit_size) {
+ for (int col = 0; col < width; col += quad_info->unit_size) {
+ apply_quadtree_partitioning(
+ interm, row, col, width, height, quad_info->unit_size, quadtset,
+ bit_depth, best_split, split_index, best_A, A_index, dgd, dgd_stride);
+ }
+ }
+
return 1;
}
diff --git a/av1/common/guided_quadtree.h b/av1/common/guided_quadtree.h
index 2d953de..31550eb 100644
--- a/av1/common/guided_quadtree.h
+++ b/av1/common/guided_quadtree.h
@@ -46,17 +46,13 @@
return 0;
}
-// Get quad tree unit index based on dimensions.
-static INLINE int quad_tree_get_unit_index(int width, int height) {
- return (width * height <= 1280 * 720);
-}
-
// Get quad tree unit size.
static INLINE int quad_tree_get_unit_size(int width, int height,
- int quad_level) {
- (void)width;
- (void)height;
- return 512 >> quad_level;
+ int unit_index) {
+ const bool is_720p_or_smaller = (width * height <= 1280 * 720);
+ const int min_unit_size = is_720p_or_smaller ? 256 : 512;
+ assert(unit_index >= 0 && unit_index <= 1);
+ return min_unit_size << unit_index;
}
// Allocates buffers in 'quad_info' assuming 'quad_info->unit_index',
diff --git a/av1/common/restoration.h b/av1/common/restoration.h
index 4604ad6..0304328 100644
--- a/av1/common/restoration.h
+++ b/av1/common/restoration.h
@@ -450,8 +450,8 @@
/*!\cond */
#if CONFIG_CNN_GUIDED_QUADTREE
typedef struct {
- int unit_index; // unit size index inferred from frame dimensions.
- int unit_size; // inferred from unit_index.
+ int unit_index; // index signaled in bitstream.
+ int unit_size; // inferred from unit_index and frame dimensions.
int split_info_length;
int unit_info_length;
QUADUnitInfo *unit_info;
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index f7628d7..a1088f0 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -2146,6 +2146,15 @@
}
#if CONFIG_CNN_RESTORATION
+
+#if CONFIG_CNN_GUIDED_QUADTREE
+// Reads the quad tree unit index (a single bit) from the bit buffer 'rb'.
+static INLINE int quad_tree_read_unit_index(struct aom_read_bit_buffer *rb) {
+ const int unit_index = aom_rb_read_bit(rb);
+ return unit_index;
+}
+#endif // CONFIG_CNN_GUIDED_QUADTREE
+
static void decode_cnn(AV1_COMMON *cm, struct aom_read_bit_buffer *rb) {
for (int plane = 0; plane < av1_num_planes(cm); ++plane) {
if (av1_allow_cnn_for_plane(cm, plane)) {
@@ -2169,8 +2178,7 @@
#if CONFIG_CNN_GUIDED_QUADTREE
if (cm->use_cnn[0]) {
QUADInfo *qi = &cm->cnn_quad_info;
- qi->unit_index = quad_tree_get_unit_index(cm->superres_upscaled_width,
- cm->superres_upscaled_height);
+ qi->unit_index = quad_tree_read_unit_index(rb);
qi->unit_size =
quad_tree_get_unit_size(cm->superres_upscaled_width,
cm->superres_upscaled_height, qi->unit_index);
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index c9a6216..88817af 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -3120,6 +3120,16 @@
#endif // CONFIG_CNN_GUIDED_QUADTREE
#if CONFIG_CNN_RESTORATION
+
+#if CONFIG_CNN_GUIDED_QUADTREE
+// Writes the quad tree unit index (a single bit) to the bit buffer 'wb'.
+static INLINE void quad_tree_write_unit_index(struct aom_write_bit_buffer *wb,
+ const QUADInfo *const qi) {
+ assert(qi->unit_index >= 0 && qi->unit_index <= 1);
+ aom_wb_write_bit(wb, qi->unit_index);
+}
+#endif // CONFIG_CNN_GUIDED_QUADTREE
+
static void encode_cnn(AV1_COMMON *cm, struct aom_write_bit_buffer *wb) {
for (int plane = 0; plane < av1_num_planes(cm); ++plane) {
if (av1_allow_cnn_for_plane(cm, plane)) {
@@ -3141,9 +3151,11 @@
}
#if CONFIG_CNN_GUIDED_QUADTREE
if (cm->use_cnn[0]) {
- assert(cm->cnn_quad_info.unit_index ==
- quad_tree_get_unit_index(cm->superres_upscaled_width,
- cm->superres_upscaled_height));
+ quad_tree_write_unit_index(wb, &cm->cnn_quad_info);
+ assert(cm->cnn_quad_info.unit_size ==
+ quad_tree_get_unit_size(cm->superres_upscaled_width,
+ cm->superres_upscaled_height,
+ cm->cnn_quad_info.unit_index));
assert(cm->cnn_quad_info.split_info_length ==
quad_tree_get_split_info_length(cm->superres_upscaled_width,
cm->superres_upscaled_height,