Guided CNN restoration: use macros for constants
diff --git a/av1/common/cnn_tflite.cc b/av1/common/cnn_tflite.cc
index e99b8ed..b94792b 100644
--- a/av1/common/cnn_tflite.cc
+++ b/av1/common/cnn_tflite.cc
@@ -776,8 +776,8 @@
// finer search
double flrA0 = (floor(A0));
double flrA1 = (floor(A1));
- flrA0 = AOMMIN(AOMMAX(flrA0, A0_min), A0_min + 15);
- flrA1 = AOMMIN(AOMMAX(flrA1, A1_min), A1_min + 15);
+ flrA0 = AOMMIN(AOMMAX(flrA0, A0_min), A0_min + GUIDED_A_RANGE);
+ flrA1 = AOMMIN(AOMMAX(flrA1, A1_min), A1_min + GUIDED_A_RANGE);
{
A0 = flrA0;
A1 = flrA1;
@@ -792,17 +792,20 @@
err += diff * diff;
}
}
- // approx RD cost assuming 7 bits per a0, a1 pair
+ // approx RD cost assuming GUIDED_A_PAIR_BITS bits per a0, a1 pair
cost = RDCOST_DBL_WITH_NATIVE_BD_DIST(
- rdmult, (norestorecost[0] + (7 << AV1_PROB_COST_SHIFT)) >> 4, err,
- bit_depth);
+ rdmult,
+ (norestorecost[0] +
+ (GUIDED_A_PAIR_BITS << AV1_PROB_COST_SHIFT)) >>
+ 4,
+ err, bit_depth);
if (cost < bestcost) {
bestA0 = A0;
bestA1 = A1;
bestcost = cost;
}
}
- if (flrA0 < A0_min + 15) {
+ if (flrA0 < A0_min + GUIDED_A_RANGE) {
A0 = flrA0 + 1;
A1 = flrA1;
err = 0;
@@ -816,17 +819,20 @@
err += diff * diff;
}
}
- // approx RD cost assuming 7 bits per a0, a1 pair
+ // approx RD cost assuming GUIDED_A_PAIR_BITS bits per a0, a1 pair
cost = RDCOST_DBL_WITH_NATIVE_BD_DIST(
- rdmult, (norestorecost[0] + (7 << AV1_PROB_COST_SHIFT)) >> 4, err,
- bit_depth);
+ rdmult,
+ (norestorecost[0] +
+ (GUIDED_A_PAIR_BITS << AV1_PROB_COST_SHIFT)) >>
+ 4,
+ err, bit_depth);
if (cost < bestcost) {
bestA0 = A0;
bestA1 = A1;
bestcost = cost;
}
}
- if (flrA1 < A1_min + 15) {
+ if (flrA1 < A1_min + GUIDED_A_RANGE) {
A0 = flrA0;
A1 = flrA1 + 1;
err = 0;
@@ -840,17 +846,21 @@
err += diff * diff;
}
}
- // approx RD cost assuming 7 bits per a0, a1 pair
+ // approx RD cost assuming GUIDED_A_PAIR_BITS bits per a0, a1 pair
cost = RDCOST_DBL_WITH_NATIVE_BD_DIST(
- rdmult, (norestorecost[0] + (7 << AV1_PROB_COST_SHIFT)) >> 4, err,
- bit_depth);
+ rdmult,
+ (norestorecost[0] +
+ (GUIDED_A_PAIR_BITS << AV1_PROB_COST_SHIFT)) >>
+ 4,
+ err, bit_depth);
if (cost < bestcost) {
bestA0 = A0;
bestA1 = A1;
bestcost = cost;
}
}
- if (flrA0 < A0_min + 15 && flrA1 < A1_min + 15) {
+ if (flrA0 < A0_min + GUIDED_A_RANGE &&
+ flrA1 < A1_min + GUIDED_A_RANGE) {
A0 = flrA0 + 1;
A1 = flrA1 + 1;
err = 0;
@@ -864,10 +874,13 @@
err += diff * diff;
}
}
- // approx RD cost assuming 7 bits per a0, a1 pair
+ // approx RD cost assuming GUIDED_A_PAIR_BITS bits per a0, a1 pair
cost = RDCOST_DBL_WITH_NATIVE_BD_DIST(
- rdmult, (norestorecost[0] + (7 << AV1_PROB_COST_SHIFT)) >> 4, err,
- bit_depth);
+ rdmult,
+ (norestorecost[0] +
+ (GUIDED_A_PAIR_BITS << AV1_PROB_COST_SHIFT)) >>
+ 4,
+ err, bit_depth);
if (cost < bestcost) {
bestA0 = A0;
bestA1 = A1;
@@ -879,12 +892,12 @@
} else {
A0 = (round(A0));
A1 = (round(A1));
- A0 = AOMMIN(AOMMAX(A0, A0_min), A0_min + 15);
- A1 = AOMMIN(AOMMAX(A1, A1_min), A1_min + 15);
+ A0 = AOMMIN(AOMMAX(A0, A0_min), A0_min + GUIDED_A_RANGE);
+ A1 = AOMMIN(AOMMAX(A1, A1_min), A1_min + GUIDED_A_RANGE);
}
- A0 = AOMMIN(AOMMAX(A0, A0_min), A0_min + 15);
- A1 = AOMMIN(AOMMAX(A1, A1_min), A1_min + 15);
+ A0 = AOMMIN(AOMMAX(A0, A0_min), A0_min + GUIDED_A_RANGE);
+ A1 = AOMMIN(AOMMAX(A1, A1_min), A1_min + GUIDED_A_RANGE);
A.emplace_back((int)A0, (int)A1);
for (int i = this_start_row; i < this_end_row; i++) {
for (int j = this_start_col; j < this_end_col; j++) {
@@ -929,21 +942,21 @@
const int A0_min = quadtset[2];
const int A1_min = quadtset[3];
int num_bits = 0;
- int ref0 = AOMMIN(AOMMAX(prev_A.first - A0_min, 0), 15);
- int ref1 = AOMMIN(AOMMAX(prev_A.second - A1_min, 0), 15);
+ int ref0 = AOMMIN(AOMMAX(prev_A.first - A0_min, 0), GUIDED_A_RANGE);
+ int ref1 = AOMMIN(AOMMAX(prev_A.second - A1_min, 0), GUIDED_A_RANGE);
for (auto &this_A : A) {
if (this_A.first == 0 && this_A.second == 0) {
num_bits += norestorecosts[1];
} else {
num_bits += norestorecosts[0];
- num_bits += (aom_count_primitive_refsubexpfin(16, 1, ref0,
- this_A.first - A0_min) +
- aom_count_primitive_refsubexpfin(16, 1, ref1,
- this_A.second - A1_min))
+ num_bits += (aom_count_primitive_refsubexpfin(
+ GUIDED_A_NUM_VALUES, 1, ref0, this_A.first - A0_min) +
+ aom_count_primitive_refsubexpfin(
+ GUIDED_A_NUM_VALUES, 1, ref1, this_A.second - A1_min))
<< AV1_PROB_COST_SHIFT;
}
- ref0 = AOMMIN(AOMMAX(this_A.first - A0_min, 0), 15);
- ref1 = AOMMIN(AOMMAX(this_A.second - A1_min, 0), 15);
+ ref0 = AOMMIN(AOMMAX(this_A.first - A0_min, 0), GUIDED_A_RANGE);
+ ref1 = AOMMIN(AOMMAX(this_A.second - A1_min, 0), GUIDED_A_RANGE);
}
return num_bits;
}
@@ -1132,7 +1145,8 @@
std::vector<std::pair<int, int>> this_A; // selected a0, a1 weight pairs.
double this_rdcost_total = 0.0;
// Previous a0, a1 pair is mid-point of the range by default.
- std::pair<int, int> prev_A = std::make_pair(8 + A0_min, 8 + A1_min);
+ std::pair<int, int> prev_A =
+ std::make_pair(GUIDED_A_MID + A0_min, GUIDED_A_MID + A1_min);
// TODO(urvang): Include padded area in a unit if it's < unit size / 2?
// If so, need to modify / replace quad_tree_get_unit_info_length().
// Also double check: quad_tree_get_split_info_length().
diff --git a/av1/common/guided_quadtree.h b/av1/common/guided_quadtree.h
index 31550eb..4fef1e6 100644
--- a/av1/common/guided_quadtree.h
+++ b/av1/common/guided_quadtree.h
@@ -24,6 +24,13 @@
#endif
#if CONFIG_CNN_GUIDED_QUADTREE
+
+#define GUIDED_A_BITS 4
+#define GUIDED_A_NUM_VALUES (1 << GUIDED_A_BITS)
+#define GUIDED_A_MID (GUIDED_A_NUM_VALUES >> 1)
+#define GUIDED_A_RANGE (GUIDED_A_NUM_VALUES - 1)
+#define GUIDED_A_PAIR_BITS (GUIDED_A_BITS * 2 - 1)
+
int *get_quadparm_from_qindex(int qindex, int superres_denom, int is_intra_only,
int is_luma, int cnn_index);
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index a1088f0..a71b6a6 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -1836,8 +1836,8 @@
A0_min = quadtset[2];
A1_min = quadtset[3];
- int ref_0 = 8;
- int ref_1 = 8;
+ int ref_0 = GUIDED_A_MID;
+ int ref_1 = GUIDED_A_MID;
for (int i = 0; i < qi->unit_info_length; i++) {
const int norestore =
norestore_ctx == -1
@@ -1847,13 +1847,17 @@
if (norestore) {
qi->unit_info[i].xqd[0] = 0;
qi->unit_info[i].xqd[1] = 0;
- ref_0 = AOMMAX(A0_min, AOMMIN(A0_min + 15, 0)) - A0_min;
- ref_1 = AOMMAX(A1_min, AOMMIN(A1_min + 15, 0)) - A1_min;
+ ref_0 = AOMMAX(A0_min, AOMMIN(A0_min + GUIDED_A_RANGE, 0)) - A0_min;
+ ref_1 = AOMMAX(A1_min, AOMMIN(A1_min + GUIDED_A_RANGE, 0)) - A1_min;
} else {
qi->unit_info[i].xqd[0] =
- aom_read_primitive_refsubexpfin(rb, 16, 1, ref_0, ACCT_STR) + A0_min;
+ aom_read_primitive_refsubexpfin(rb, GUIDED_A_NUM_VALUES, 1, ref_0,
+ ACCT_STR) +
+ A0_min;
qi->unit_info[i].xqd[1] =
- aom_read_primitive_refsubexpfin(rb, 16, 1, ref_1, ACCT_STR) + A1_min;
+ aom_read_primitive_refsubexpfin(rb, GUIDED_A_NUM_VALUES, 1, ref_1,
+ ACCT_STR) +
+ A1_min;
ref_0 = qi->unit_info[i].xqd[0] - A0_min;
ref_1 = qi->unit_info[i].xqd[1] - A1_min;
}
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 88817af..0c0c827 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -3065,26 +3065,18 @@
static void write_filter_quadtree(FRAME_CONTEXT *ctx, int QP, int cnn_index,
int superres_denom, int is_intra_only,
const QUADInfo *ci, aom_writer *wb) {
- int A0_min, A1_min;
- int *quadtset;
- quadtset =
+ const int *const quadtset =
get_quadparm_from_qindex(QP, superres_denom, is_intra_only, 1, cnn_index);
const int norestore_ctx =
get_guided_norestore_ctx(QP, superres_denom, is_intra_only);
- A0_min = quadtset[2];
- A1_min = quadtset[3];
- int a0;
- int a1;
- int b_a0;
- int b_a1;
- int ref_0 = 8;
- int ref_1 = 8;
+ const int A0_min = quadtset[2];
+ const int A1_min = quadtset[3];
+ int ref_0 = GUIDED_A_MID;
+ int ref_1 = GUIDED_A_MID;
for (int i = 0; i < ci->unit_info_length; i++) {
- a0 = ci->unit_info[i].xqd[0];
- a1 = ci->unit_info[i].xqd[1];
+ const int a0 = ci->unit_info[i].xqd[0];
+ const int a1 = ci->unit_info[i].xqd[1];
int norestore;
-
- // printf("a0:%d a1:%d\n", a0, a1);
if (norestore_ctx != -1) {
norestore = (a0 == 0 && a1 == 0);
aom_write_symbol(wb, norestore,
@@ -3093,27 +3085,17 @@
norestore = 0;
}
if (norestore) {
- ref_0 = AOMMAX(A0_min, AOMMIN(A0_min + 15, 0)) - A0_min;
- ref_1 = AOMMAX(A1_min, AOMMIN(A1_min + 15, 0)) - A1_min;
+ ref_0 = AOMMAX(A0_min, AOMMIN(A0_min + GUIDED_A_RANGE, 0)) - A0_min;
+ ref_1 = AOMMAX(A1_min, AOMMIN(A1_min + GUIDED_A_RANGE, 0)) - A1_min;
} else {
- b_a0 = a0 - A0_min;
- if (b_a0 < 0) {
- b_a0 = 0;
- }
- if (b_a0 > 15) {
- b_a0 = 15;
- }
- b_a1 = a1 - A1_min;
- if (b_a1 < 0) {
- b_a1 = 0;
- }
- if (b_a1 > 15) {
- b_a1 = 15;
- }
- aom_write_primitive_refsubexpfin(wb, 16, 1, ref_0, b_a0);
- aom_write_primitive_refsubexpfin(wb, 16, 1, ref_1, b_a1);
- ref_0 = b_a0;
- ref_1 = b_a1;
+ const int a0_offset = AOMMIN(AOMMAX(a0 - A0_min, 0), GUIDED_A_RANGE);
+ const int a1_offset = AOMMIN(AOMMAX(a1 - A1_min, 0), GUIDED_A_RANGE);
+ aom_write_primitive_refsubexpfin(wb, GUIDED_A_NUM_VALUES, 1, ref_0,
+ a0_offset);
+ aom_write_primitive_refsubexpfin(wb, GUIDED_A_NUM_VALUES, 1, ref_1,
+ a1_offset);
+ ref_0 = a0_offset;
+ ref_1 = a1_offset;
}
}
}