[CFL] Uniform Q3 alpha grid with extent [-2, 2]
Expand the range of alpha to [-2, 2] in Q3.
Jointly signal the signs, including zeros.
Use the signs to give context for each quadrant
and half-axis. The (0, 0) point is excluded.
Symmetry about the line alpha_u == alpha_v reduces these to 6 contexts.
Results on Subset1 (Compared to 9136ab7d with CFL enabled)
PSNR | PSNR Cb | PSNR Cr | PSNR HVS | SSIM | MS SSIM | CIEDE 2000
-0.0792 | -0.7535 | -0.7574 | -0.0639 | -0.0843 | -0.0665 | -0.3324
Change-Id: I250369692e92a91d9c8d174a203d441217d15063
Signed-off-by: David Michael Barr <b@rr-dav.id.au>
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 43b2521..bd3591c 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1642,23 +1642,18 @@
}
#if CONFIG_CFL
-static void write_cfl_alphas(FRAME_CONTEXT *const frame_ctx, int ind,
- const CFL_SIGN_TYPE signs[CFL_SIGNS],
- aom_writer *w) {
- // Check for uninitialized signs
- if (cfl_alpha_codes[ind][CFL_PRED_U] == 0)
- assert(signs[CFL_PRED_U] == CFL_SIGN_POS);
- if (cfl_alpha_codes[ind][CFL_PRED_V] == 0)
- assert(signs[CFL_PRED_V] == CFL_SIGN_POS);
-
- // Write a symbol representing a combination of alpha Cb and alpha Cr.
- aom_write_symbol(w, ind, frame_ctx->cfl_alpha_cdf, CFL_ALPHABET_SIZE);
-
- // Signs are only signaled for nonzero codes.
- if (cfl_alpha_codes[ind][CFL_PRED_U] != 0)
- aom_write_bit(w, signs[CFL_PRED_U]);
- if (cfl_alpha_codes[ind][CFL_PRED_V] != 0)
- aom_write_bit(w, signs[CFL_PRED_V]);
+static void write_cfl_alphas(FRAME_CONTEXT *const ec_ctx, int idx,
+ int joint_sign, aom_writer *w) {
+ aom_write_symbol(w, joint_sign, ec_ctx->cfl_sign_cdf, CFL_JOINT_SIGNS);  // The joint U/V sign symbol is always written first.
+ // A plane's index symbol is written only when that plane's sign is not CFL_SIGN_ZERO.
+ if (CFL_SIGN_U(joint_sign) != CFL_SIGN_ZERO) {
+ aom_cdf_prob *cdf_u = ec_ctx->cfl_alpha_cdf[CFL_CONTEXT_U(joint_sign)];  // The U CDF is selected by a context derived from the joint sign.
+ aom_write_symbol(w, CFL_IDX_U(idx), cdf_u, CFL_ALPHABET_SIZE);  // Writes the U part of idx.
+ }
+ if (CFL_SIGN_V(joint_sign) != CFL_SIGN_ZERO) {
+ aom_cdf_prob *cdf_v = ec_ctx->cfl_alpha_cdf[CFL_CONTEXT_V(joint_sign)];  // The V CDF is selected by a context derived from the joint sign.
+ aom_write_symbol(w, CFL_IDX_V(idx), cdf_v, CFL_ALPHABET_SIZE);  // Writes the V part of idx.
+ }
}
#endif
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 7f468ea..c8e5811 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -5685,19 +5685,25 @@
}
static inline void cfl_update_costs(CFL_CTX *cfl, FRAME_CONTEXT *ec_ctx) {
- assert(ec_ctx->cfl_alpha_cdf[CFL_ALPHABET_SIZE - 1] ==
- AOM_ICDF(CDF_PROB_TOP));
-
- aom_cdf_prob prev_cdf = 0;
-
- for (int c = 0; c < CFL_ALPHABET_SIZE; c++) {
- const int sign_bit_cost = (cfl_alpha_codes[c][CFL_PRED_U] != 0) +
- (cfl_alpha_codes[c][CFL_PRED_V] != 0);
-
- aom_cdf_prob prob = AOM_ICDF(ec_ctx->cfl_alpha_cdf[c]) - prev_cdf;
- prev_cdf = AOM_ICDF(ec_ctx->cfl_alpha_cdf[c]);
-
- cfl->costs[c] = av1_cost_symbol(prob) + av1_cost_literal(sign_bit_cost);
+ int sign_cost[CFL_JOINT_SIGNS];  // One entry per joint sign value.
+ av1_cost_tokens_from_cdf(sign_cost, ec_ctx->cfl_sign_cdf, NULL);  // Fills sign_cost from the joint-sign CDF.
+ for (int joint_sign = 0; joint_sign < CFL_JOINT_SIGNS; joint_sign++) {  // Builds the U and V cost tables for each joint sign.
+ const aom_cdf_prob *cdf_u =
+ ec_ctx->cfl_alpha_cdf[CFL_CONTEXT_U(joint_sign)];
+ const aom_cdf_prob *cdf_v =
+ ec_ctx->cfl_alpha_cdf[CFL_CONTEXT_V(joint_sign)];
+ int *cost_u = cfl->costs[joint_sign][CFL_PRED_U];
+ int *cost_v = cfl->costs[joint_sign][CFL_PRED_V];
+ if (CFL_SIGN_U(joint_sign) == CFL_SIGN_ZERO)  // With a zero U sign the U table is cleared instead of derived from a CDF.
+ memset(cost_u, 0, CFL_ALPHABET_SIZE * sizeof(*cost_u));
+ else
+ av1_cost_tokens_from_cdf(cost_u, cdf_u, NULL);
+ if (CFL_SIGN_V(joint_sign) == CFL_SIGN_ZERO)  // Same handling for the V table.
+ memset(cost_v, 0, CFL_ALPHABET_SIZE * sizeof(*cost_v));
+ else
+ av1_cost_tokens_from_cdf(cost_v, cdf_v, NULL);
+ for (int u = 0; u < CFL_ALPHABET_SIZE; u++)  // The sign cost goes into the U table only, so cost_u[u] + cost_v[v] counts it once.
+ cost_u[u] += sign_cost[joint_sign];
}
}
@@ -5722,8 +5728,6 @@
const int *y_averages_q3 = cfl->y_averages_q3;
const uint8_t *y_pix = cfl->y_down_pix;
- CFL_SIGN_TYPE *signs = mbmi->cfl_alpha_signs;
-
cfl_update_costs(cfl, ec_ctx);
int64_t sse[CFL_PRED_PLANES][CFL_MAGS_SIZE];
@@ -5734,47 +5738,54 @@
cfl_alpha_dist(y_pix, MAX_SB_SIZE, y_averages_q3, src_v, src_stride_v,
width, height, tx_size, dc_pred_v, 0, NULL);
- for (int m = 1; m < CFL_MAGS_SIZE; m += 2) {
- assert(cfl_alpha_mags_q3[m + 1] == -cfl_alpha_mags_q3[m]);
+ for (int c = 0; c < CFL_ALPHABET_SIZE; c++) {
+ const int m = c * 2 + 1;
+ const int abs_alpha_q3 = c + 1;
sse[CFL_PRED_U][m] = cfl_alpha_dist(
y_pix, MAX_SB_SIZE, y_averages_q3, src_u, src_stride_u, width, height,
- tx_size, dc_pred_u, cfl_alpha_mags_q3[m], &sse[CFL_PRED_U][m + 1]);
+ tx_size, dc_pred_u, abs_alpha_q3, &sse[CFL_PRED_U][m + 1]);
sse[CFL_PRED_V][m] = cfl_alpha_dist(
y_pix, MAX_SB_SIZE, y_averages_q3, src_v, src_stride_v, width, height,
- tx_size, dc_pred_v, cfl_alpha_mags_q3[m], &sse[CFL_PRED_V][m + 1]);
+ tx_size, dc_pred_v, abs_alpha_q3, &sse[CFL_PRED_V][m + 1]);
}
int64_t dist;
int64_t cost;
- int64_t best_cost;
+ int64_t best_cost = INT64_MAX;
+ int best_rate = 0;
// Compute least squares parameter of the entire block
int ind = 0;
- signs[CFL_PRED_U] = CFL_SIGN_POS;
- signs[CFL_PRED_V] = CFL_SIGN_POS;
- best_cost = INT64_MAX;
+ int signs = 0;
- for (int c = 0; c < CFL_ALPHABET_SIZE; c++) {
- const int idx_u = cfl_alpha_codes[c][CFL_PRED_U];
- const int idx_v = cfl_alpha_codes[c][CFL_PRED_V];
- for (CFL_SIGN_TYPE sign_u = idx_u == 0; sign_u < CFL_SIGNS; sign_u++) {
- for (CFL_SIGN_TYPE sign_v = idx_v == 0; sign_v < CFL_SIGNS; sign_v++) {
+ for (int joint_sign = 0; joint_sign < CFL_JOINT_SIGNS; joint_sign++) {
+ const int sign_u = CFL_SIGN_U(joint_sign);
+ const int sign_v = CFL_SIGN_V(joint_sign);
+ const int size_u = (sign_u == CFL_SIGN_ZERO) ? 1 : CFL_ALPHABET_SIZE;
+ const int size_v = (sign_v == CFL_SIGN_ZERO) ? 1 : CFL_ALPHABET_SIZE;
+ for (int u = 0; u < size_u; u++) {
+ const int idx_u = (sign_u == CFL_SIGN_ZERO) ? 0 : u * 2 + 1;
+ for (int v = 0; v < size_v; v++) {
+ const int idx_v = (sign_v == CFL_SIGN_ZERO) ? 0 : v * 2 + 1;
dist = sse[CFL_PRED_U][idx_u + (sign_u == CFL_SIGN_NEG)] +
sse[CFL_PRED_V][idx_v + (sign_v == CFL_SIGN_NEG)];
dist *= 16;
- cost = RDCOST(x->rdmult, cfl->costs[c], dist);
+ const int rate = cfl->costs[joint_sign][CFL_PRED_U][u] +
+ cfl->costs[joint_sign][CFL_PRED_V][v];
+ cost = RDCOST(x->rdmult, rate, dist);
if (cost < best_cost) {
best_cost = cost;
- ind = c;
- signs[CFL_PRED_U] = sign_u;
- signs[CFL_PRED_V] = sign_v;
+ best_rate = rate;
+ ind = (u << CFL_ALPHABET_SIZE_LOG2) + v;
+ signs = joint_sign;
}
}
}
}
mbmi->cfl_alpha_idx = ind;
- return cfl->costs[ind];
+ mbmi->cfl_alpha_signs = signs;
+ return best_rate;
}
#endif // CONFIG_CFL