[CFL] Alpha choice by rate-distortion cost
Measure SSE for all possible alphas.
Estimate rates for alpha signalling.
Change-Id: Idf1e3c632925cd306090fc38cf5b95eff7ee5c1c
Signed-off-by: David Michael Barr <b@rr-dav.id.au>
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index a9fe189..347926f 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -1436,8 +1436,8 @@
uint8_t *dst =
&pd->dst.buf[(blk_row * dst_stride + blk_col) << tx_size_wide_log2[0]];
#if CONFIG_CFL
- av1_predict_intra_block_encoder_facade(xd, plane, block, blk_col, blk_row,
- tx_size);
+ av1_predict_intra_block_encoder_facade(x, plane, block, blk_col, blk_row,
+ tx_size, plane_bsize);
#else
av1_predict_intra_block_facade(xd, plane, block, blk_col, blk_row, tx_size);
#endif
@@ -1520,14 +1520,151 @@
}
#if CONFIG_CFL
-void av1_predict_intra_block_encoder_facade(MACROBLOCKD *xd, int plane,
- int block_idx, int blk_col,
- int blk_row, TX_SIZE tx_size) {
- MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
- mbmi->cfl_alpha_ind = 0;
- mbmi->cfl_alpha_signs[CFL_PRED_U] = CFL_SIGN_POS;
- mbmi->cfl_alpha_signs[CFL_PRED_V] = CFL_SIGN_POS;
+static int cfl_alpha_dist(const uint8_t *y_pix, int y_stride, double y_avg,  // Returns the sum of squared errors between src and the prediction built with +alpha.
+ const uint8_t *src, int src_stride, int blk_width,  // src: reference pixels, blk_width x blk_height, rows src_stride apart.
+ int blk_height, double dc_pred, double alpha,  // Per-pixel prediction: (int)(alpha * (y_pix[i] - y_avg) + dc_pred + 0.5).
+ int *dist_neg_out) {  // Optional (may be NULL): receives the matching sum for the negated alpha.
+ const double dc_pred_bias = dc_pred + 0.5;  // The (int) casts truncate toward zero, so +0.5 rounds non-negative values to nearest.
+ int dist = 0;  // Squared-error accumulator for +alpha.
+ int diff;
+ if (alpha == 0.0) {  // Shortcut: prediction is the constant rounded dc_pred; y_pix is never read here.
+ const int dc_pred_i = (int)dc_pred_bias;
+ for (int j = 0; j < blk_height; j++) {
+ for (int i = 0; i < blk_width; i++) {
+ diff = src[i] - dc_pred_i;
+ dist += diff * diff;
+ }
+ src += src_stride;
+ }
+ // With alpha == 0 the sign makes no difference, so both results are the same value.
+ if (dist_neg_out) *dist_neg_out = dist;
+
+ return dist;
+ }
+
+ int dist_neg = 0;  // Squared-error accumulator for -alpha.
+ for (int j = 0; j < blk_height; j++) {
+ for (int i = 0; i < blk_width; i++) {
+ const double scaled_luma = alpha * (y_pix[i] - y_avg);  // Luma deviation from its average, scaled by alpha.
+ const int uv = src[i];
+ diff = uv - (int)(scaled_luma + dc_pred_bias);
+ dist += diff * diff;
+ diff = uv + (int)(scaled_luma - dc_pred_bias);  // Same as uv - (int)(-scaled_luma + dc_pred_bias), since truncation is symmetric about zero.
+ dist_neg += diff * diff;
+ }
+ y_pix += y_stride;
+ src += src_stride;
+ }
+ // NOTE(review): both sums are plain int; presumably block sizes keep them in range - confirm.
+ if (dist_neg_out) *dist_neg_out = dist_neg;
+
+ return dist;
+}
+
+static int cfl_compute_alpha_ind(MACROBLOCK *const x, const CFL_CTX *const cfl,  // Returns the index into cfl_alpha_codes with the lowest RDCOST over all codes and sign pairs.
+ BLOCK_SIZE bsize, int *const cfl_cost,  // cfl_cost[c]: rate term handed to RDCOST for code index c.
+ CFL_SIGN_TYPE *signs) {  // Out: signs[CFL_PRED_U] and signs[CFL_PRED_V] of the winning candidate.
+ const struct macroblock_plane *const p_u = &x->plane[AOM_PLANE_U];
+ const struct macroblock_plane *const p_v = &x->plane[AOM_PLANE_V];
+ const uint8_t *const src_u = p_u->src.buf;
+ const uint8_t *const src_v = p_v->src.buf;
+ const int src_stride_u = p_u->src.stride;
+ const int src_stride_v = p_v->src.stride;
+ const int block_width = block_size_wide[bsize];
+ const int block_height = block_size_high[bsize];
+ const double dc_pred_u = cfl->dc_pred[CFL_PRED_U];
+ const double dc_pred_v = cfl->dc_pred[CFL_PRED_V];
+
+ // Scratch buffer (row stride MAX_SB_SIZE) that cfl_load fills with the luma
+ // pixels used as y_pix by every cfl_alpha_dist call below.
+ uint8_t tmp_pix[MAX_SB_SQUARE];
+ // Fill tmp_pix for the whole block; cfl_load's return value is used as the luma average.
+ const double y_avg =
+ cfl_load(cfl, tmp_pix, MAX_SB_SIZE, 0, 0, max_txsize_lookup[bsize]);
+
+ int dist_u, dist_v;
+ int dist_u_neg, dist_v_neg;
+ int64_t dist;  // 64-bit so that summing both planes and scaling by 16 cannot overflow signed int.
+ int64_t cost;
+ int64_t best_cost;
+
+ // Seed the search with code 0, then try every other code and sign pair.
+ // IMPORTANT: We assume that the first code is 0,0
+ int ind = 0;
+ signs[CFL_PRED_U] = CFL_SIGN_POS;
+ signs[CFL_PRED_V] = CFL_SIGN_POS;
+
+ dist = (int64_t)cfl_alpha_dist(tmp_pix, MAX_SB_SIZE, y_avg, src_u, src_stride_u,
+ block_width, block_height, dc_pred_u, 0, NULL) +
+ cfl_alpha_dist(tmp_pix, MAX_SB_SIZE, y_avg, src_v, src_stride_v,
+ block_width, block_height, dc_pred_v, 0, NULL);
+ dist *= 16;  // NOTE(review): presumably matches the distortion scale RDCOST expects - confirm.
+ best_cost = RDCOST(x->rdmult, x->rddiv, *cfl_cost, dist);
+
+ for (int c = 1; c < CFL_ALPHABET_SIZE; c++) {
+ dist_u = cfl_alpha_dist(tmp_pix, MAX_SB_SIZE, y_avg, src_u, src_stride_u,  // One call yields both the +alpha and -alpha distortion.
+ block_width, block_height, dc_pred_u,
+ cfl_alpha_codes[c][CFL_PRED_U], &dist_u_neg);
+ dist_v = cfl_alpha_dist(tmp_pix, MAX_SB_SIZE, y_avg, src_v, src_stride_v,
+ block_width, block_height, dc_pred_v,
+ cfl_alpha_codes[c][CFL_PRED_V], &dist_v_neg);
+ for (int sign_u = cfl_alpha_codes[c][CFL_PRED_U] == 0.0; sign_u < CFL_SIGNS;  // A zero alpha starts the loop at 1, so sign value 0 is skipped.
+ sign_u++) {
+ for (int sign_v = cfl_alpha_codes[c][CFL_PRED_V] == 0.0;
+ sign_v < CFL_SIGNS; sign_v++) {
+ dist = (int64_t)(sign_u == CFL_SIGN_POS ? dist_u : dist_u_neg) +
+ (sign_v == CFL_SIGN_POS ? dist_v : dist_v_neg);
+ dist *= 16;
+ cost = RDCOST(x->rdmult, x->rddiv, cfl_cost[c], dist);
+ if (cost < best_cost) {  // Strict '<': on a tie the earlier candidate is kept.
+ best_cost = cost;
+ ind = c;
+ signs[CFL_PRED_U] = sign_u;
+ signs[CFL_PRED_V] = sign_v;
+ }
+ }
+ }
+ }
+
+ return ind;
+}
+
+void av1_predict_intra_block_encoder_facade(MACROBLOCK *x, int plane,  // Encoder-side wrapper: selects the CfL alpha parameters when required, then runs av1_predict_intra_block_facade.
+ int block_idx, int blk_col,
+ int blk_row, TX_SIZE tx_size,
+ BLOCK_SIZE plane_bsize) {  // plane_bsize is forwarded to cfl_dc_pred and cfl_compute_alpha_ind.
+ MACROBLOCKD *const xd = &x->e_mbd;
+ MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
+ if (blk_col == 0 && blk_row == 0 && plane == AOM_PLANE_Y) {  // Luma block at (0, 0): reset the CfL fields to index 0 with positive signs.
+ mbmi->cfl_alpha_ind = 0;
+ mbmi->cfl_alpha_signs[CFL_PRED_U] = CFL_SIGN_POS;
+ mbmi->cfl_alpha_signs[CFL_PRED_V] = CFL_SIGN_POS;
+ }
+ if (plane != AOM_PLANE_Y && mbmi->uv_mode == DC_PRED) {
+ if (blk_col == 0 && blk_row == 0 && plane == AOM_PLANE_U) {  // The search runs only for the U block at (0, 0); its result is stored in mbmi.
+#if !CONFIG_EC_ADAPT
+#error "CfL rate estimation requires ec_adapt."
+#endif
+ FRAME_CONTEXT *const ec_ctx = xd->tile_ctx;
+ assert(ec_ctx->cfl_alpha_cdf[CFL_ALPHABET_SIZE - 1] == AOM_ICDF(32768U));  // The last CDF entry must correspond to the 32768 denominator used below.
+ const int prob_den = 32768U;
+
+ CFL_CTX *const cfl = xd->cfl;
+ int cfl_costs[CFL_ALPHABET_SIZE];  // Per-code rate estimate passed to the alpha search.
+ for (int c = 0; c < CFL_ALPHABET_SIZE; c++) {
+ int sign_bit_cost = (cfl_alpha_codes[c][CFL_PRED_U] != 0.0) +  // One sign bit for each non-zero alpha in the code.
+ (cfl_alpha_codes[c][CFL_PRED_V] != 0.0);
+ int prob_num = AOM_ICDF(ec_ctx->cfl_alpha_cdf[c]);
+ if (c > 0) prob_num -= AOM_ICDF(ec_ctx->cfl_alpha_cdf[c - 1]);  // Difference of consecutive CDF entries gives the numerator for symbol c.
+ cfl_costs[c] = av1_cost_zero(get_prob(prob_num, prob_den)) +
+ av1_cost_literal(sign_bit_cost);
+ }
+ cfl_dc_pred(xd, plane_bsize, tx_size);  // Called before the search, which reads cfl->dc_pred.
+ mbmi->cfl_alpha_ind = cfl_compute_alpha_ind(
+ x, cfl, plane_bsize, cfl_costs, mbmi->cfl_alpha_signs);
+ }
+ }
av1_predict_intra_block_facade(xd, plane, block_idx, blk_col, blk_row,  // The regular intra prediction always runs, whether or not the search above did.
tx_size);
}