Guided CNN Restoration: further improvements to unit sizes
Includes following changes:
* Similar to the traditional loop restoration framework: if the rows/cols
remaining near a boundary are less than half the unit height/width, the previous unit is extended to cover those rows/cols.
* For narrow units near a boundary, allow splitting only along the bigger dimension (i.e. only the HORZ or VERT type, as applicable).
* Signal one of 4 unit sizes at the frame level: the largest allowed size is
min(2048, next power of 2 of the largest frame dimension), and the next 3 sizes are max_size/2, max_size/4 and max_size/8.
Also includes some related code cleanups.
diff --git a/av1/common/cnn_tflite.cc b/av1/common/cnn_tflite.cc
index b94792b..6ee12cc 100644
--- a/av1/common/cnn_tflite.cc
+++ b/av1/common/cnn_tflite.cc
@@ -618,30 +618,19 @@
return 1;
}
-typedef enum {
- GUIDED_QT_NONE,
- GUIDED_QT_SPLIT,
- GUIDED_QT_HORZ,
- GUIDED_QT_VERT,
- GUIDED_QT_TYPES,
- GUIDED_QT_INVALID = -1
-} GuidedQuadTreePartitionType;
-
// Get unit width and height based on max size and partition type.
-static void get_unit_size(int quadtree_max_size,
+static void get_unit_size(int max_unit_width, int max_unit_height,
GuidedQuadTreePartitionType partition_type,
int *unit_width, int *unit_height) {
assert(partition_type >= 0 && partition_type < GUIDED_QT_TYPES);
- const int full_size = quadtree_max_size;
- const int half_size = quadtree_max_size >> 1;
*unit_width =
(partition_type == GUIDED_QT_NONE || partition_type == GUIDED_QT_HORZ)
- ? full_size
- : half_size;
+ ? max_unit_width
+ : max_unit_width >> 1;
*unit_height =
(partition_type == GUIDED_QT_NONE || partition_type == GUIDED_QT_VERT)
- ? full_size
- : half_size;
+ ? max_unit_height
+ : max_unit_height >> 1;
}
// ------------------- Guided Quadtree: Encoder ------------------------------//
@@ -968,17 +957,24 @@
const std::vector<std::vector<std::vector<double>>> &interm,
GuidedQuadTreePartitionType partition_type, const uint16_t *src,
int src_stride, const uint16_t *dgd, int dgd_stride, int start_row,
- int end_row, int start_col, int end_col, int quadtree_max_size,
- const int *quadtset, int rdmult, const std::pair<int, int> &prev_A,
- const int *splitcosts, const int *norestorecosts, int bit_depth,
- bool is_partial_unit, double *this_rdcost,
- std::vector<std::vector<uint16_t>> &out,
+ int end_row, int start_col, int end_col, int max_unit_width,
+ int max_unit_height, const int *quadtset, int rdmult,
+ const std::pair<int, int> &prev_A, const int *quad_split_costs,
+ const int *binary_split_costs, const int *norestorecosts, int bit_depth,
+ bool is_horz_partitioning_allowed, int is_vert_partitioning_allowed,
+ double *this_rdcost, std::vector<std::vector<uint16_t>> &out,
std::vector<std::pair<int, int>> &A) {
- assert(IMPLIES(is_partial_unit, partition_type == GUIDED_QT_NONE));
+ assert(IMPLIES(
+ !is_horz_partitioning_allowed,
+ partition_type == GUIDED_QT_NONE || partition_type == GUIDED_QT_VERT));
+ assert(IMPLIES(
+ !is_vert_partitioning_allowed,
+ partition_type == GUIDED_QT_NONE || partition_type == GUIDED_QT_HORZ));
// Get unit width and height based on partition type.
int unit_width;
int unit_height;
- get_unit_size(quadtree_max_size, partition_type, &unit_width, &unit_height);
+ get_unit_size(max_unit_width, max_unit_height, partition_type, &unit_width,
+ &unit_height);
// Compute restored unit, a0 and a1.
generate_linear_combination(interm, src, src_stride, dgd, dgd_stride,
@@ -995,11 +991,16 @@
compute_sse(out, src, src_stride, start_row, end_row, start_col, end_col);
// Compute Rate.
- const int num_bits_for_a = compute_rate(A, prev_A, quadtset, norestorecosts);
- // Partition is implied to be NONE in case of partial unit.
+ const int a_signaling_cost =
+ compute_rate(A, prev_A, quadtset, norestorecosts);
+ // Partition signaling cost depending on 1, 2 or 4 possible partition types.
const int partition_signaling_cost =
- is_partial_unit ? 0 : splitcosts[partition_type];
- const int bitrate = num_bits_for_a + partition_signaling_cost;
+ is_horz_partitioning_allowed && is_vert_partitioning_allowed
+ ? quad_split_costs[partition_type]
+ : (is_horz_partitioning_allowed || is_vert_partitioning_allowed)
+ ? binary_split_costs[partition_type]
+ : 0;
+ const int bitrate = a_signaling_cost + partition_signaling_cost;
// Compute RDCost.
*this_rdcost =
@@ -1013,39 +1014,54 @@
static void select_quadtree_partitioning(
const std::vector<std::vector<std::vector<double>>> &interm,
const uint16_t *src, int src_stride, int start_row, int start_col,
- int width, int height, int quadtree_max_size, const int *quadtset,
- int rdmult, const std::pair<int, int> &prev_A, const int *splitcosts,
- const int norestorecosts[2], int bit_depth, const uint16_t *dgd,
- int dgd_stride, std::vector<int> &split,
+ int width, int height, int quadtree_max_size, int max_unit_width,
+ int max_unit_height, const int *quadtset, int rdmult,
+ const std::pair<int, int> &prev_A, const int *quad_split_costs,
+ const int *binary_split_costs, const int norestorecosts[2], int bit_depth,
+ const uint16_t *dgd, int dgd_stride, std::vector<int> &split,
std::vector<std::pair<int, int>> &A, double *rdcost) {
- const int end_row = AOMMIN(start_row + quadtree_max_size, height);
- const int end_col = AOMMIN(start_col + quadtree_max_size, width);
- const bool is_partial_unit = (start_row + quadtree_max_size > height) ||
- (start_col + quadtree_max_size > width);
+ const int end_row = AOMMIN(start_row + max_unit_height, height);
+ const int end_col = AOMMIN(start_col + max_unit_width, width);
+ // Check for special cases near boundary.
+ const bool is_horz_partitioning_allowed =
+ (max_unit_height >= quadtree_max_size);
+ const bool is_vert_partitioning_allowed =
+ (max_unit_width >= quadtree_max_size);
+ const bool is_split_partitioning_allowed =
+ is_horz_partitioning_allowed && is_vert_partitioning_allowed;
auto best_rdcost = DBL_MAX;
std::vector<std::pair<int, int>> best_A;
std::vector<std::vector<uint16_t>> best_out(
- quadtree_max_size, std::vector<uint16_t>(quadtree_max_size));
+ max_unit_height, std::vector<uint16_t>(max_unit_width));
GuidedQuadTreePartitionType best_partition_type = GUIDED_QT_INVALID;
for (int type = 0; type < GUIDED_QT_TYPES; ++type) {
const auto this_partition_type = (GuidedQuadTreePartitionType)type;
- // Special case: if only partial unit is within boundary, we implicitly
- // use NONE partitioning and do not try the splitting options.
- if (is_partial_unit && (this_partition_type != GUIDED_QT_NONE)) {
+ // Check for special cases near boundary.
+ if (!is_horz_partitioning_allowed &&
+ (this_partition_type == GUIDED_QT_HORZ)) {
continue;
}
-
+ if (!is_vert_partitioning_allowed &&
+ (this_partition_type == GUIDED_QT_VERT)) {
+ continue;
+ }
+ if (!is_split_partitioning_allowed &&
+ (this_partition_type == GUIDED_QT_SPLIT)) {
+ continue;
+ }
+ // Try this partition type.
double this_rdcost;
std::vector<std::pair<int, int>> this_A;
std::vector<std::vector<uint16_t>> this_out(
- quadtree_max_size, std::vector<uint16_t>(quadtree_max_size));
- try_one_partition(interm, this_partition_type, src, src_stride, dgd,
- dgd_stride, start_row, end_row, start_col, end_col,
- quadtree_max_size, quadtset, rdmult, prev_A, splitcosts,
- norestorecosts, bit_depth, is_partial_unit, &this_rdcost,
- this_out, this_A);
+ max_unit_height, std::vector<uint16_t>(max_unit_width));
+ try_one_partition(
+ interm, this_partition_type, src, src_stride, dgd, dgd_stride,
+ start_row, end_row, start_col, end_col, max_unit_width, max_unit_height,
+ quadtset, rdmult, prev_A, quad_split_costs, binary_split_costs,
+ norestorecosts, bit_depth, is_horz_partitioning_allowed,
+ is_vert_partitioning_allowed, &this_rdcost, this_out, this_A);
if (this_rdcost < best_rdcost) {
best_rdcost = this_rdcost;
best_A = this_A;
@@ -1063,45 +1079,29 @@
}
// Save split decision.
- if (is_partial_unit) {
+ if (!is_horz_partitioning_allowed && !is_vert_partitioning_allowed) {
// Nothing should be added to 'split' array.
assert(best_partition_type == GUIDED_QT_NONE);
return;
}
- switch (best_partition_type) {
- case GUIDED_QT_NONE:
- split.push_back(0);
- split.push_back(0);
- break;
- case GUIDED_QT_SPLIT:
- split.push_back(0);
- split.push_back(1);
- break;
- case GUIDED_QT_HORZ:
- split.push_back(1);
- split.push_back(1);
- break;
- case GUIDED_QT_VERT:
- split.push_back(1);
- split.push_back(0);
- break;
- default: assert(0 && "Wrong partition type"); break;
- }
+ assert(best_partition_type >= 0 && best_partition_type < GUIDED_QT_TYPES);
+ split.push_back(best_partition_type);
}
static void apply_quadtree_partitioning(
const std::vector<std::vector<std::vector<double>>> &interm, int start_row,
int start_col, int width, int height, int quadtree_max_size,
- const int *quadtset, int bit_depth, const std::vector<int> &split,
- size_t &split_index, const std::vector<std::pair<int, int>> &A,
- size_t &A_index, uint16_t *dgd, int dgd_stride);
+ int max_unit_width, int max_unit_height, const int *quadtset, int bit_depth,
+ const std::vector<int> &split, size_t &split_index,
+ const std::vector<std::pair<int, int>> &A, size_t &A_index, uint16_t *dgd,
+ int dgd_stride);
// Top-level function to apply guided restoration on encoder side.
static int restore_cnn_quadtree_encode_img_tflite_highbd(
YV12_BUFFER_CONFIG *source_frame, AV1_COMMON *cm, int superres_denom,
- int rdmult, const int *splitcosts, int (*norestorecosts)[2],
- int num_threads, int bit_depth, int is_intra_only, int is_luma,
- int cnn_index, QUADInfo *quad_info, double *rdcost) {
+ int rdmult, const int *quad_split_costs, const int *binary_split_costs,
+ int (*norestorecosts)[2], int num_threads, int bit_depth, int is_intra_only,
+ int is_luma, int cnn_index, QUADInfo *quad_info, double *rdcost) {
YV12_BUFFER_CONFIG *dgd_buf = &cm->cur_frame->buf;
uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd_buf->y_buffer);
const int dgd_stride = dgd_buf->y_stride;
@@ -1136,7 +1136,8 @@
std::vector<int> best_split; // selected partitioning options.
std::vector<std::pair<int, int>> best_A; // selected a0, a1 weight pairs.
double best_rdcost_total = DBL_MAX;
- for (int this_unit_index = 0; this_unit_index <= 1; ++this_unit_index) {
+ for (int this_unit_index = 0; this_unit_index < GUIDED_QT_UNIT_SIZES;
+ ++this_unit_index) {
const int quadtree_max_size =
quad_tree_get_unit_size(width, height, this_unit_index);
// For each quadtree unit, compute the best partitioning out of
@@ -1147,20 +1148,27 @@
// Previous a0, a1 pair is mid-point of the range by default.
std::pair<int, int> prev_A =
std::make_pair(GUIDED_A_MID + A0_min, GUIDED_A_MID + A1_min);
- // TODO(urvang): Include padded area in a unit if it's < unit size / 2?
- // If so, need to modify / replace quad_tree_get_unit_info_length().
- // Also double check: quad_tree_get_split_info_length().
- for (int row = 0; row < height; row += quadtree_max_size) {
- for (int col = 0; col < width; col += quadtree_max_size) {
+ const int ext_size = quadtree_max_size * 3 / 2;
+ for (int row = 0; row < height;) {
+ const int remaining_height = height - row;
+ const int this_unit_height =
+ (remaining_height < ext_size) ? remaining_height : quadtree_max_size;
+ for (int col = 0; col < width;) {
+ const int remaining_width = width - col;
+ const int this_unit_width =
+ (remaining_width < ext_size) ? remaining_width : quadtree_max_size;
double this_rdcost;
select_quadtree_partitioning(
interm, src, src_stride, row, col, width, height, quadtree_max_size,
- quadtset, rdmult, prev_A, splitcosts, this_norestorecosts,
+ this_unit_width, this_unit_height, quadtset, rdmult, prev_A,
+ quad_split_costs, binary_split_costs, this_norestorecosts,
bit_depth, dgd, dgd_stride, this_split, this_A, &this_rdcost);
// updates.
this_rdcost_total += this_rdcost;
prev_A = this_A.back();
+ col += this_unit_width;
}
+ row += this_unit_height;
}
// Update best options.
if (this_rdcost_total < best_rdcost_total) {
@@ -1188,22 +1196,33 @@
// Apply guided restoration to 'dgd' using best options above.
size_t split_index = 0;
size_t A_index = 0;
- for (int row = 0; row < height; row += quad_info->unit_size) {
- for (int col = 0; col < width; col += quad_info->unit_size) {
+ const int quadtree_max_size = quad_info->unit_size;
+ const int ext_size = quadtree_max_size * 3 / 2;
+ for (int row = 0; row < height;) {
+ const int remaining_height = height - row;
+ const int this_unit_height =
+ (remaining_height < ext_size) ? remaining_height : quadtree_max_size;
+ for (int col = 0; col < width;) {
+ const int remaining_width = width - col;
+ const int this_unit_width =
+ (remaining_width < ext_size) ? remaining_width : quadtree_max_size;
apply_quadtree_partitioning(
- interm, row, col, width, height, quad_info->unit_size, quadtset,
- bit_depth, best_split, split_index, best_A, A_index, dgd, dgd_stride);
+ interm, row, col, width, height, quadtree_max_size, this_unit_width,
+ this_unit_height, quadtset, bit_depth, best_split, split_index,
+ best_A, A_index, dgd, dgd_stride);
+ col += this_unit_width;
}
+ row += this_unit_height;
}
return 1;
}
extern "C" int av1_restore_cnn_quadtree_encode_tflite(
- AV1_COMMON *cm, YV12_BUFFER_CONFIG *source_frame, int RDMULT,
- int *splitcosts, int (*norestorecosts)[2], int num_threads,
- const int apply_cnn[MAX_MB_PLANE], const int cnn_indices[MAX_MB_PLANE],
- QUADInfo *quad_info, double *rdcost) {
+ struct AV1Common *cm, YV12_BUFFER_CONFIG *source_frame, int RDMULT,
+ int *quad_split_costs, int *binary_split_costs, int (*norestorecosts)[2],
+ int num_threads, const int apply_cnn[MAX_MB_PLANE],
+ const int cnn_indices[MAX_MB_PLANE], QUADInfo *quad_info, double *rdcost) {
YV12_BUFFER_CONFIG *buf = &cm->cur_frame->buf;
const int is_intra_only = frame_is_intra_only(cm);
for (int plane = 0; plane < av1_num_planes(cm); ++plane) {
@@ -1217,8 +1236,9 @@
case AOM_PLANE_Y:
ret = restore_cnn_quadtree_encode_img_tflite_highbd(
source_frame, cm, cm->superres_scale_denominator, RDMULT,
- splitcosts, norestorecosts, num_threads, cm->seq_params.bit_depth,
- is_intra_only, is_luma, cnn_index, quad_info, rdcost);
+ quad_split_costs, binary_split_costs, norestorecosts, num_threads,
+ cm->seq_params.bit_depth, is_intra_only, is_luma, cnn_index,
+ quad_info, rdcost);
if (ret == 0) return ret;
break;
case AOM_PLANE_U:
@@ -1290,42 +1310,30 @@
static void apply_quadtree_partitioning(
const std::vector<std::vector<std::vector<double>>> &interm, int start_row,
int start_col, int width, int height, int quadtree_max_size,
- const int *quadtset, int bit_depth, const std::vector<int> &split,
- size_t &split_index, const std::vector<std::pair<int, int>> &A,
- size_t &A_index, uint16_t *dgd, int dgd_stride) {
- const int end_row = AOMMIN(start_row + quadtree_max_size, height);
- const int end_col = AOMMIN(start_col + quadtree_max_size, width);
- const bool is_partial_unit = (start_row + quadtree_max_size > height) ||
- (start_col + quadtree_max_size > width);
+ int max_unit_width, int max_unit_height, const int *quadtset, int bit_depth,
+ const std::vector<int> &split, size_t &split_index,
+ const std::vector<std::pair<int, int>> &A, size_t &A_index, uint16_t *dgd,
+ int dgd_stride) {
+ const int end_row = AOMMIN(start_row + max_unit_height, height);
+ const int end_col = AOMMIN(start_col + max_unit_width, width);
+ // Check for special cases near boundary.
+ const bool is_horz_partitioning_allowed =
+ (max_unit_height >= quadtree_max_size);
+ const bool is_vert_partitioning_allowed =
+ (max_unit_width >= quadtree_max_size);
// Get partition type.
GuidedQuadTreePartitionType partition_type = GUIDED_QT_NONE;
- if (!is_partial_unit) {
- const int spl1 = split[split_index++];
- const int spl2 = split[split_index++];
- if (spl1 == 0) {
- if (spl2 == 0) {
- partition_type = GUIDED_QT_NONE; // (0, 0)
- } else {
- assert(spl2 == 1);
- partition_type = GUIDED_QT_SPLIT; // (0, 1)
- }
- } else {
- assert(spl1 == 1);
- if (spl2 == 1) {
- partition_type = GUIDED_QT_HORZ; // (1, 1)
- } else {
- assert(spl2 == 0);
- partition_type = GUIDED_QT_VERT; // (1, 0)
- }
- }
+ if (is_horz_partitioning_allowed || is_vert_partitioning_allowed) {
+ partition_type = (GuidedQuadTreePartitionType)split[split_index++];
}
assert(partition_type >= 0 && partition_type < GUIDED_QT_TYPES);
// Get unit width and height based on partition type.
int unit_width;
int unit_height;
- get_unit_size(quadtree_max_size, partition_type, &unit_width, &unit_height);
+ get_unit_size(max_unit_width, max_unit_height, partition_type, &unit_width,
+ &unit_height);
// Compute restored unit, a0 and a1 with given A parameters.
apply_linear_combination(interm, start_row, end_row, start_col, end_col,
@@ -1377,12 +1385,22 @@
// For each quadtree unit, apply given quadtree partitioning.
size_t split_index = 0;
size_t A_index = 0;
- for (int row = 0; row < height; row += quadtree_max_size) {
- for (int col = 0; col < width; col += quadtree_max_size) {
+ const int ext_size = quadtree_max_size * 3 / 2;
+ for (int row = 0; row < height;) {
+ const int remaining_height = height - row;
+ const int this_unit_height =
+ (remaining_height < ext_size) ? remaining_height : quadtree_max_size;
+ for (int col = 0; col < width;) {
+ const int remaining_width = width - col;
+ const int this_unit_width =
+ (remaining_width < ext_size) ? remaining_width : quadtree_max_size;
apply_quadtree_partitioning(interm, row, col, width, height,
- quadtree_max_size, quadtset, bit_depth, split,
+ quadtree_max_size, this_unit_width,
+ this_unit_height, quadtset, bit_depth, split,
split_index, A, A_index, dgd, dgd_stride);
+ col += this_unit_width;
}
+ row += this_unit_height;
}
assert(split_index == split.size());
assert(A_index == A.size());
diff --git a/av1/common/cnn_tflite.h b/av1/common/cnn_tflite.h
index f74165c..3ed3682 100644
--- a/av1/common/cnn_tflite.h
+++ b/av1/common/cnn_tflite.h
@@ -114,9 +114,9 @@
// Apply Guided CNN restoration on encoder side.
int av1_restore_cnn_quadtree_encode_tflite(
struct AV1Common *cm, YV12_BUFFER_CONFIG *source_frame, int RDMULT,
- int *splitcosts, int (*norestorecosts)[2], int num_threads,
- const int apply_cnn[MAX_MB_PLANE], const int cnn_indices[MAX_MB_PLANE],
- QUADInfo *quad_info, double *rdcost);
+ int *quad_split_costs, int *binary_split_costs, int (*norestorecosts)[2],
+ int num_threads, const int apply_cnn[MAX_MB_PLANE],
+ const int cnn_indices[MAX_MB_PLANE], QUADInfo *quad_info, double *rdcost);
// Apply Guided CNN restoration on decoder side.
int av1_restore_cnn_quadtree_decode_tflite(struct AV1Common *cm,
diff --git a/av1/common/entropy.c b/av1/common/entropy.c
index 89a47ff..1e63d77 100644
--- a/av1/common/entropy.c
+++ b/av1/common/entropy.c
@@ -241,6 +241,7 @@
#endif // CONFIG_RST_MERGECOEFFS
#if CONFIG_CNN_GUIDED_QUADTREE
RESET_CDF_COUNTER(fc->cnn_guided_quad_cdf, 4);
+ RESET_CDF_COUNTER(fc->cnn_guided_binary_cdf, 2);
RESET_CDF_COUNTER(fc->cnn_guided_norestore_cdf, 2);
#endif // CONFIG_CNN_GUIDED_QUADTREE
#if CONFIG_AIMC
diff --git a/av1/common/entropymode.c b/av1/common/entropymode.c
index 0d15acb..def65a8 100644
--- a/av1/common/entropymode.c
+++ b/av1/common/entropymode.c
@@ -1560,6 +1560,9 @@
static const aom_cdf_prob default_cnn_guided_quad_cdf[CDF_SIZE(4)] = {
AOM_CDF4(23552, 24576, 28672),
};
+static const aom_cdf_prob default_cnn_guided_binary_cdf[CDF_SIZE(2)] = {
+ AOM_CDF2(23552),
+};
static const aom_cdf_prob
default_cnn_guided_norestore_cdf[GUIDED_NORESTORE_CONTEXTS][CDF_SIZE(2)] = {
{ AOM_CDF2(16384) }, { AOM_CDF2(24576) }
@@ -1967,6 +1970,7 @@
#endif // CONFIG_PC_WIENER
#if CONFIG_CNN_GUIDED_QUADTREE
av1_copy(fc->cnn_guided_quad_cdf, default_cnn_guided_quad_cdf);
+ av1_copy(fc->cnn_guided_binary_cdf, default_cnn_guided_binary_cdf);
av1_copy(fc->cnn_guided_norestore_cdf, default_cnn_guided_norestore_cdf);
#endif // CONFIG_CNN_GUIDED_QUADTREE
#if CONFIG_AIMC
diff --git a/av1/common/entropymode.h b/av1/common/entropymode.h
index 3b4e443..e277f52 100644
--- a/av1/common/entropymode.h
+++ b/av1/common/entropymode.h
@@ -240,7 +240,12 @@
aom_cdf_prob merged_param_cdf[CDF_SIZE(2)];
#endif // CONFIG_RST_MERGECOEFFS
#if CONFIG_CNN_GUIDED_QUADTREE
+ // CDF to signal one of GUIDED_QT_TYPES partition types for each unit, when
+ // all types are allowed.
aom_cdf_prob cnn_guided_quad_cdf[CDF_SIZE(4)];
+ // CDF to signal one of 2 partition types for each unit: GUIDED_QT_NONE or
+ // GUIDED_QT_HORZ/VERT, when only 2 types are allowed.
+ aom_cdf_prob cnn_guided_binary_cdf[CDF_SIZE(2)];
aom_cdf_prob cnn_guided_norestore_cdf[GUIDED_NORESTORE_CONTEXTS][CDF_SIZE(2)];
#endif // CONFIG_CNN_GUIDED_QUADTREE
#if !CONFIG_AIMC
diff --git a/av1/common/guided_quadtree.c b/av1/common/guided_quadtree.c
index e296c0b..d9c4400 100644
--- a/av1/common/guided_quadtree.c
+++ b/av1/common/guided_quadtree.c
@@ -275,52 +275,63 @@
dst->signaled = src->signaled;
}
-// Returns (int)floor(x / y),
-#define DIVIDE_WITH_FLOOR(x, y) ((x) / (y))
-// Returns (int)ceil(x / y),
-#define DIVIDE_WITH_CEILING(x, y) (((x) + (y)-1) / (y))
-
-int quad_tree_get_unit_info_length(int width, int height, int unit_length,
- const QUADSplitInfo *split_info,
- int split_info_length) {
- // We can compute total units as follows:
- // (1) regular units: they may / may not be split. So, compute length of
- // regular unit info by going through the split_info array. (2) unregular
- // units (blocks near boundary that are NOT unit_length in size): they are
- // never split. So, length of unregular unit info is same as number of
- // unregular units.
- const int regular_units = DIVIDE_WITH_FLOOR(width, unit_length) *
- DIVIDE_WITH_FLOOR(height, unit_length);
- assert(regular_units * 2 == split_info_length);
- const int total_units = DIVIDE_WITH_CEILING(width, unit_length) *
- DIVIDE_WITH_CEILING(height, unit_length);
- const int unregular_unit_info_len = total_units - regular_units;
-
- int regular_unit_info_len = 0;
- for (int i = 0; i < split_info_length; i += 2) {
- if (split_info == NULL ||
- (split_info[i].split == 0 && split_info[i + 1].split == 1)) {
- regular_unit_info_len += 4; // Split
- } else if (split_info[i].split == 1 && split_info[i + 1].split == 1) {
- regular_unit_info_len += 2; // Horz
- } else if (split_info[i].split == 1 && split_info[i + 1].split == 0) {
- regular_unit_info_len += 2; // Vert
- } else {
- assert(split_info[i].split == 0 && split_info[i + 1].split == 0);
- regular_unit_info_len += 1; // No split
+int quad_tree_get_max_unit_info_length(int width, int height, int unit_length) {
+ int unit_info_length = 0;
+ const int ext_size = unit_length * 3 / 2;
+ for (int row = 0; row < height;) {
+ const int remaining_height = height - row;
+ const int this_unit_height =
+ (remaining_height < ext_size) ? remaining_height : unit_length;
+ for (int col = 0; col < width;) {
+ const int remaining_width = width - col;
+ const int this_unit_width =
+ (remaining_width < ext_size) ? remaining_width : unit_length;
+ // Check for special cases near boundary.
+ const bool is_horz_partitioning_allowed =
+ (this_unit_height >= unit_length);
+ const bool is_vert_partitioning_allowed =
+ (this_unit_width >= unit_length);
+ if (!is_horz_partitioning_allowed && !is_vert_partitioning_allowed) {
+ // Implicitly no split, so single unit info will be signaled.
+ ++unit_info_length;
+ } else {
+ // Assume maximum possible sub-units.
+ const int max_sub_units =
+ is_horz_partitioning_allowed && is_vert_partitioning_allowed ? 4
+ : 2;
+ unit_info_length += max_sub_units;
+ }
+ col += this_unit_width;
}
+ row += this_unit_height;
}
-
- return regular_unit_info_len + unregular_unit_info_len;
+ return unit_info_length;
}
int quad_tree_get_split_info_length(int width, int height, int unit_length) {
- // Split info only signaled for units of full size. Blocks near boundaries are
- // never split, so no info is signaled for those.
- const int num_split_info_wide = DIVIDE_WITH_FLOOR(width, unit_length);
- const int num_split_info_high = DIVIDE_WITH_FLOOR(height, unit_length);
- // 2 bits signaled for each split info.
- return num_split_info_wide * num_split_info_high * 2;
+ int split_info_len = 0;
+ const int ext_size = unit_length * 3 / 2;
+ for (int row = 0; row < height;) {
+ const int remaining_height = height - row;
+ const int this_unit_height =
+ (remaining_height < ext_size) ? remaining_height : unit_length;
+ for (int col = 0; col < width;) {
+ const int remaining_width = width - col;
+ const int this_unit_width =
+ (remaining_width < ext_size) ? remaining_width : unit_length;
+ // Check for special cases near boundary.
+ const bool is_horz_partitioning_allowed =
+ (this_unit_height >= unit_length);
+ const bool is_vert_partitioning_allowed =
+ (this_unit_width >= unit_length);
+ if (is_horz_partitioning_allowed || is_vert_partitioning_allowed) {
+ ++split_info_len;
+ }
+ col += this_unit_width;
+ }
+ row += this_unit_height;
+ }
+ return split_info_len;
}
void av1_alloc_quadtree_struct(struct AV1Common *cm, QUADInfo *quad_info) {
diff --git a/av1/common/guided_quadtree.h b/av1/common/guided_quadtree.h
index 4fef1e6..f51d9fe 100644
--- a/av1/common/guided_quadtree.h
+++ b/av1/common/guided_quadtree.h
@@ -30,17 +30,27 @@
#define GUIDED_A_MID (GUIDED_A_NUM_VALUES >> 1)
#define GUIDED_A_RANGE (GUIDED_A_NUM_VALUES - 1)
#define GUIDED_A_PAIR_BITS (GUIDED_A_BITS * 2 - 1)
+#define GUIDED_QT_UNIT_SIZES_LOG2 2
+#define GUIDED_QT_UNIT_SIZES (1 << GUIDED_QT_UNIT_SIZES_LOG2)
+
+typedef enum {
+ GUIDED_QT_NONE,
+ GUIDED_QT_SPLIT,
+ GUIDED_QT_HORZ,
+ GUIDED_QT_VERT,
+ GUIDED_QT_TYPES,
+ GUIDED_QT_INVALID = -1
+} GuidedQuadTreePartitionType;
int *get_quadparm_from_qindex(int qindex, int superres_denom, int is_intra_only,
int is_luma, int cnn_index);
void quad_copy(const QUADInfo *src, QUADInfo *dst, struct AV1Common *cm);
-// Get the length of unit info array based on dimensions and split info.
-// If split_info == NULL, assumes each block uses split, thereby returning
-// longest possible unit info length.
-int quad_tree_get_unit_info_length(int width, int height, int unit_length,
- const QUADSplitInfo *split_info,
- int split_info_length);
+
+// Get the maximum possible length of unit info array based on dimensions,
+// assuming each block uses split.
+int quad_tree_get_max_unit_info_length(int width, int height, int unit_length);
+
// Get the length of split info array based on dimensions.
int quad_tree_get_split_info_length(int width, int height, int unit_length);
@@ -56,10 +66,12 @@
// Get quad tree unit size.
static INLINE int quad_tree_get_unit_size(int width, int height,
int unit_index) {
- const bool is_720p_or_smaller = (width * height <= 1280 * 720);
- const int min_unit_size = is_720p_or_smaller ? 256 : 512;
- assert(unit_index >= 0 && unit_index <= 1);
- return min_unit_size << unit_index;
+ const int max_dim = AOMMAX(width, height);
+ const int max_dim_pow_2_bits = 1 + get_msb(max_dim);
+ const int max_dim_pow_2 = 1 << max_dim_pow_2_bits;
+ const int max_unit_size = AOMMAX(AOMMIN(max_dim_pow_2, 2048), 256);
+ assert(unit_index >= 0 && unit_index < GUIDED_QT_UNIT_SIZES);
+ return max_unit_size >> unit_index;
}
// Allocates buffers in 'quad_info' assuming 'quad_info->unit_index',
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index a71b6a6..a736131 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -137,7 +137,8 @@
#if CONFIG_CNN_GUIDED_QUADTREE
static AOM_INLINE void read_filter_quadtree(FRAME_CONTEXT *ctx, int QP,
- int cnn_index, int superres_denom,
+ int cnn_index, int width,
+ int height, int superres_denom,
int is_intra_only, QUADInfo *qi,
aom_reader *rb);
#endif // CONFIG_CNN_GUIDED_QUADTREE
@@ -1822,19 +1823,81 @@
}
}
#if CONFIG_CNN_GUIDED_QUADTREE
+static void read_quadtree_split_info(FRAME_CONTEXT *ctx, int width, int height,
+ QUADInfo *qi, aom_reader *rb) {
+ int unit_info_length = 0;
+ int split_info_index = 0;
+ const int ext_size = qi->unit_size * 3 / 2;
+ for (int row = 0; row < height;) {
+ const int remaining_height = height - row;
+ const int this_unit_height =
+ (remaining_height < ext_size) ? remaining_height : qi->unit_size;
+ for (int col = 0; col < width;) {
+ const int remaining_width = width - col;
+ const int this_unit_width =
+ (remaining_width < ext_size) ? remaining_width : qi->unit_size;
+ // Check for special cases near boundary.
+ const bool is_horz_partitioning_allowed =
+ (this_unit_height >= qi->unit_size);
+ const bool is_vert_partitioning_allowed =
+ (this_unit_width >= qi->unit_size);
+ if (!is_horz_partitioning_allowed && !is_vert_partitioning_allowed) {
+ // Implicitly no split, and single unit info will be signaled.
+ ++unit_info_length;
+ } else {
+ assert(split_info_index < qi->split_info_length);
+ // Read split info, depending on how many partition types are allowed.
+ const bool all_partition_types_allowed =
+ (is_horz_partitioning_allowed && is_vert_partitioning_allowed);
+ if (all_partition_types_allowed) {
+ qi->split_info[split_info_index].split =
+ aom_read_symbol(rb, ctx->cnn_guided_quad_cdf, 4, ACCT_STR);
+ } else {
+ const int do_split =
+ aom_read_symbol(rb, ctx->cnn_guided_binary_cdf, 2, ACCT_STR);
+ if (do_split) {
+ qi->split_info[split_info_index].split =
+ is_horz_partitioning_allowed ? GUIDED_QT_HORZ : GUIDED_QT_VERT;
+ } else {
+ qi->split_info[split_info_index].split = GUIDED_QT_NONE;
+ }
+ }
+ // Look at the split info to determine number of (sub)units.
+ const GuidedQuadTreePartitionType partition_type =
+ qi->split_info[split_info_index].split;
+ switch (partition_type) {
+ case GUIDED_QT_NONE: unit_info_length += 1; break;
+ case GUIDED_QT_HORZ:
+ case GUIDED_QT_VERT: unit_info_length += 2; break;
+ case GUIDED_QT_SPLIT: unit_info_length += 4; break;
+ default: assert(0 && "Wrong guided quadtree split type."); break;
+ }
+ ++split_info_index;
+ }
+ col += this_unit_width;
+ }
+ row += this_unit_height;
+ }
+ assert(qi->split_info_length == split_info_index);
+ qi->unit_info_length = unit_info_length;
+}
+
static AOM_INLINE void read_filter_quadtree(FRAME_CONTEXT *ctx, int QP,
- int cnn_index, int superres_denom,
+ int cnn_index, int width,
+ int height, int superres_denom,
int is_intra_only, QUADInfo *qi,
aom_reader *rb) {
- int A0_min, A1_min;
- int *quadtset;
- quadtset =
+ // Read partitioning info.
+ read_quadtree_split_info(ctx, width, height, qi, rb);
+
+ // Read weight parameters 'a'.
+ const int *quadtset =
get_quadparm_from_qindex(QP, superres_denom, is_intra_only, 1, cnn_index);
const int norestore_ctx =
get_guided_norestore_ctx(QP, superres_denom, is_intra_only);
- A0_min = quadtset[2];
- A1_min = quadtset[3];
+ const int A0_min = quadtset[2];
+ const int A1_min = quadtset[3];
int ref_0 = GUIDED_A_MID;
int ref_1 = GUIDED_A_MID;
@@ -1861,8 +1924,6 @@
ref_0 = qi->unit_info[i].xqd[0] - A0_min;
ref_1 = qi->unit_info[i].xqd[1] - A1_min;
}
- // printf("a0:%d a1:%d\n", qi->unit_info[i].xqd[0],
- // qi->unit_info[i].xqd[1]);
}
}
#endif // CONFIG_CNN_GUIDED_QUADTREE
@@ -1922,18 +1983,10 @@
#if CONFIG_CNN_GUIDED_QUADTREE
if (cm->use_cnn[0] && !cm->cnn_quad_info.signaled) {
QUADInfo *qi = (QUADInfo *)&cm->cnn_quad_info;
- for (int s = 0; s < qi->split_info_length; s += 2) {
- const int split_index = aom_read_symbol(
- reader, xd->tile_ctx->cnn_guided_quad_cdf, 4, ACCT_STR);
- qi->split_info[s].split = split_index >> 1;
- qi->split_info[s + 1].split = split_index & 1;
- }
- qi->unit_info_length = quad_tree_get_unit_info_length(
+ read_filter_quadtree(
+ xd->tile_ctx, cm->quant_params.base_qindex, cm->cnn_indices[0],
cm->superres_upscaled_width, cm->superres_upscaled_height,
- qi->unit_size, qi->split_info, qi->split_info_length);
- read_filter_quadtree(xd->tile_ctx, cm->quant_params.base_qindex,
- cm->cnn_indices[0], cm->superres_scale_denominator,
- frame_is_intra_only(cm), qi, reader);
+ cm->superres_scale_denominator, frame_is_intra_only(cm), qi, reader);
cm->cnn_quad_info.signaled = 1;
}
#endif // CONFIG_CNN_GUIDED_QUADTREE
@@ -2154,7 +2207,7 @@
#if CONFIG_CNN_GUIDED_QUADTREE
// Read quad tree unit index.
static INLINE int quad_tree_read_unit_index(struct aom_read_bit_buffer *rb) {
- const int unit_index = aom_rb_read_bit(rb);
+ const int unit_index = aom_rb_read_literal(rb, GUIDED_QT_UNIT_SIZES_LOG2);
return unit_index;
}
#endif // CONFIG_CNN_GUIDED_QUADTREE
@@ -2191,9 +2244,9 @@
qi->unit_size);
// We allocate unit info assuming maximum number of possible units for now.
// Actual length will be set later after actually reading split info.
- qi->unit_info_length = quad_tree_get_unit_info_length(
+ qi->unit_info_length = quad_tree_get_max_unit_info_length(
cm->superres_upscaled_width, cm->superres_upscaled_height,
- qi->unit_size, NULL, qi->split_info_length);
+ qi->unit_size);
av1_alloc_quadtree_struct(cm, qi);
}
#endif // CONFIG_CNN_GUIDED_QUADTREE
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 0c0c827..88d4b7e 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -85,8 +85,9 @@
#if CONFIG_CNN_GUIDED_QUADTREE
// Write crlc coeffs for one frame
static void write_filter_quadtree(FRAME_CONTEXT *ctx, int qp, int cnn_index,
- int superres_denom, int is_intra_only,
- const QUADInfo *ci, aom_writer *wb);
+ int width, int height, int superres_denom,
+ int is_intra_only, const QUADInfo *qi,
+ aom_writer *wb);
#endif // CONFIG_CNN_GUIDED_QUADTREE
#if CONFIG_IBC_SR_EXT
@@ -2334,22 +2335,10 @@
#if CONFIG_CNN_GUIDED_QUADTREE
if (cm->use_cnn[0] && !cm->cnn_quad_info.signaled) {
QUADInfo *qi = (QUADInfo *)&cm->cnn_quad_info;
- assert(qi->split_info_length ==
- quad_tree_get_split_info_length(cm->superres_upscaled_width,
- cm->superres_upscaled_height,
- qi->unit_size));
- for (int s = 0; s < qi->split_info_length; s += 2) {
- const int split_index =
- qi->split_info[s].split * 2 + qi->split_info[s + 1].split;
- aom_write_symbol(w, split_index, xd->tile_ctx->cnn_guided_quad_cdf, 4);
- }
- assert(qi->unit_info_length ==
- quad_tree_get_unit_info_length(
- cm->superres_upscaled_width, cm->superres_upscaled_height,
- qi->unit_size, qi->split_info, qi->split_info_length));
- write_filter_quadtree(xd->tile_ctx, cm->quant_params.base_qindex,
- cm->cnn_indices[0], cm->superres_scale_denominator,
- frame_is_intra_only(cm), qi, w);
+ write_filter_quadtree(
+ xd->tile_ctx, cm->quant_params.base_qindex, cm->cnn_indices[0],
+ cm->superres_upscaled_width, cm->superres_upscaled_height,
+ cm->superres_scale_denominator, frame_is_intra_only(cm), qi, w);
qi->signaled = 1;
}
#endif // CONFIG_CNN_GUIDED_QUADTREE
@@ -3062,9 +3051,75 @@
#endif // !CONFIG_NEW_DF
#if CONFIG_CNN_GUIDED_QUADTREE
+static void write_quadtree_split_info(FRAME_CONTEXT *ctx, int width, int height,
+ const QUADInfo *qi, aom_writer *wb) {
+#ifndef NDEBUG
+ int unit_info_length = 0;
+#endif // NDEBUG
+ int split_info_index = 0;
+ const int ext_size = qi->unit_size * 3 / 2;
+ for (int row = 0; row < height;) {
+ const int remaining_height = height - row;
+ const int this_unit_height =
+ (remaining_height < ext_size) ? remaining_height : qi->unit_size;
+ for (int col = 0; col < width;) {
+ const int remaining_width = width - col;
+ const int this_unit_width =
+ (remaining_width < ext_size) ? remaining_width : qi->unit_size;
+ // Check for special cases near boundary.
+ const bool is_horz_partitioning_allowed =
+ (this_unit_height >= qi->unit_size);
+ const bool is_vert_partitioning_allowed =
+ (this_unit_width >= qi->unit_size);
+ if (!is_horz_partitioning_allowed && !is_vert_partitioning_allowed) {
+ // Implicitly no split, and single unit info will be signaled.
+#ifndef NDEBUG
+ ++unit_info_length;
+#endif // NDEBUG
+ } else {
+ assert(split_info_index < qi->split_info_length);
+ // Write split info, depending on how many partition types are allowed.
+ const bool all_partition_types_allowed =
+ (is_horz_partitioning_allowed && is_vert_partitioning_allowed);
+ const GuidedQuadTreePartitionType partition_type =
+ qi->split_info[split_info_index].split;
+ if (all_partition_types_allowed) {
+ aom_write_symbol(wb, partition_type, ctx->cnn_guided_quad_cdf,
+ GUIDED_QT_TYPES);
+ } else {
+ aom_write_symbol(wb, partition_type != GUIDED_QT_NONE,
+ ctx->cnn_guided_binary_cdf, 2);
+ }
+#ifndef NDEBUG
+ // Look at the split info to determine number of (sub)units.
+ switch (partition_type) {
+ case GUIDED_QT_NONE: unit_info_length += 1; break;
+ case GUIDED_QT_HORZ:
+ case GUIDED_QT_VERT: unit_info_length += 2; break;
+ case GUIDED_QT_SPLIT: unit_info_length += 4; break;
+ default: assert(0 && "Wrong guided quadtree split type."); break;
+ }
+#endif // NDEBUG
+ ++split_info_index;
+ }
+ col += this_unit_width;
+ }
+ row += this_unit_height;
+ }
+ assert(qi->split_info_length == split_info_index);
+#ifndef NDEBUG
+ assert(qi->unit_info_length == unit_info_length);
+#endif // NDEBUG
+}
+
static void write_filter_quadtree(FRAME_CONTEXT *ctx, int QP, int cnn_index,
- int superres_denom, int is_intra_only,
- const QUADInfo *ci, aom_writer *wb) {
+ int width, int height, int superres_denom,
+ int is_intra_only, const QUADInfo *qi,
+ aom_writer *wb) {
+ // Write partitioning info.
+ write_quadtree_split_info(ctx, width, height, qi, wb);
+
+ // Write weight parameters 'a'.
const int *const quadtset =
get_quadparm_from_qindex(QP, superres_denom, is_intra_only, 1, cnn_index);
const int norestore_ctx =
@@ -3073,9 +3128,9 @@
const int A1_min = quadtset[3];
int ref_0 = GUIDED_A_MID;
int ref_1 = GUIDED_A_MID;
- for (int i = 0; i < ci->unit_info_length; i++) {
- const int a0 = ci->unit_info[i].xqd[0];
- const int a1 = ci->unit_info[i].xqd[1];
+ for (int i = 0; i < qi->unit_info_length; i++) {
+ const int a0 = qi->unit_info[i].xqd[0];
+ const int a1 = qi->unit_info[i].xqd[1];
int norestore;
if (norestore_ctx != -1) {
norestore = (a0 == 0 && a1 == 0);
@@ -3107,8 +3162,8 @@
// Write quad tree unit index.
static INLINE void quad_tree_write_unit_index(struct aom_write_bit_buffer *wb,
const QUADInfo *const qi) {
- assert(qi->unit_index >= 0 && qi->unit_index <= 1);
- aom_wb_write_bit(wb, qi->unit_index);
+ assert(qi->unit_index >= 0 && qi->unit_index < GUIDED_QT_UNIT_SIZES);
+ aom_wb_write_literal(wb, qi->unit_index, GUIDED_QT_UNIT_SIZES_LOG2);
}
#endif // CONFIG_CNN_GUIDED_QUADTREE
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 02cc6a3..d740110 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -884,6 +884,10 @@
*/
int cnn_guided_quad_cost[4];
/*!
+ * cnn_guided binary split cost
+ */
+ int cnn_guided_binary_cost[2];
+ /*!
* cnn_guided quad norestore cost
*/
int cnn_guided_norestore_cost[2][2];
diff --git a/av1/encoder/encodeframe_utils.c b/av1/encoder/encodeframe_utils.c
index 8f5780b..20e1191 100644
--- a/av1/encoder/encodeframe_utils.c
+++ b/av1/encoder/encodeframe_utils.c
@@ -1271,6 +1271,8 @@
#endif // CONFIG_RST_MERGECEOFFS
#if CONFIG_CNN_GUIDED_QUADTREE
AVERAGE_CDF(ctx_left->cnn_guided_quad_cdf, ctx_tr->cnn_guided_quad_cdf, 4);
+ AVERAGE_CDF(ctx_left->cnn_guided_binary_cdf, ctx_tr->cnn_guided_binary_cdf,
+ 2);
AVERAGE_CDF(ctx_left->cnn_guided_norestore_cdf,
ctx_tr->cnn_guided_norestore_cdf, 2);
#endif // CONFIG_CNN_GUIDED_QUADTREE
diff --git a/av1/encoder/encoder.c b/av1/encoder/encoder.c
index 0b7e69a..e47994d 100644
--- a/av1/encoder/encoder.c
+++ b/av1/encoder/encoder.c
@@ -2339,6 +2339,7 @@
if (!av1_restore_cnn_quadtree_encode_tflite(
cm, cpi->source, cpi->rd.RDMULT,
x->mode_costs.cnn_guided_quad_cost,
+ x->mode_costs.cnn_guided_binary_cost,
x->mode_costs.cnn_guided_norestore_cost,
cpi->mt_info.num_workers, quadtree_cnn, curr_cnn_indices,
&quad_info, &curr_cnn_rdcosts)) {
diff --git a/av1/encoder/rd.c b/av1/encoder/rd.c
index ebb64df..d8c2073 100644
--- a/av1/encoder/rd.c
+++ b/av1/encoder/rd.c
@@ -455,6 +455,8 @@
#if CONFIG_CNN_GUIDED_QUADTREE
av1_cost_tokens_from_cdf(mode_costs->cnn_guided_quad_cost,
fc->cnn_guided_quad_cdf, NULL);
+ av1_cost_tokens_from_cdf(mode_costs->cnn_guided_binary_cost,
+ fc->cnn_guided_binary_cdf, NULL);
for (int i = 0; i < GUIDED_NORESTORE_CONTEXTS; ++i)
av1_cost_tokens_from_cdf(mode_costs->cnn_guided_norestore_cost[i],
fc->cnn_guided_norestore_cdf[i], NULL);
@@ -502,6 +504,8 @@
#if CONFIG_CNN_GUIDED_QUADTREE
av1_cost_tokens_from_cdf(mode_costs->cnn_guided_quad_cost,
fc->cnn_guided_quad_cdf, NULL);
+ av1_cost_tokens_from_cdf(mode_costs->cnn_guided_binary_cost,
+ fc->cnn_guided_binary_cdf, NULL);
for (int c = 0; c < GUIDED_NORESTORE_CONTEXTS; ++c)
av1_cost_tokens_from_cdf(mode_costs->cnn_guided_norestore_cost[c],
fc->cnn_guided_norestore_cdf[c], NULL);