Use CDFs to calculate the cost of the skip bit
Change-Id: I262d9b538988ddcbcac13a217c786fa5df17f8a4
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 3264c82..a4fe7b5 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -219,6 +219,8 @@
int skip_chroma_rd;
#endif
+ int skip_cost[SKIP_CONTEXTS][2];
+
#if CONFIG_LV_MAP
LV_MAP_COEFF_COST coeff_costs[TX_SIZES][PLANE_TYPES];
uint16_t cb_offset;
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 3558c19..2e7d05e 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -1029,6 +1029,17 @@
int super_block_upper_left =
((mi_row & MAX_MIB_MASK) == 0) && ((mi_col & MAX_MIB_MASK) == 0);
+ const int seg_ref_active =
+ segfeature_active(&cm->seg, mbmi->segment_id, SEG_LVL_REF_FRAME);
+
+ if (!seg_ref_active) {
+ const int skip_ctx = av1_get_skip_context(xd);
+ td->counts->skip[skip_ctx][mbmi->skip]++;
+#if CONFIG_NEW_MULTISYMBOL
+ update_cdf(fc->skip_cdfs[skip_ctx], mbmi->skip, 2);
+#endif // CONFIG_NEW_MULTISYMBOL
+ }
+
if (cm->delta_q_present_flag && (bsize != cm->sb_size || !mbmi->skip) &&
super_block_upper_left) {
const int dq = (mbmi->current_q_index - xd->prev_qindex) / cm->delta_q_res;
@@ -1087,8 +1098,6 @@
FRAME_COUNTS *const counts = td->counts;
RD_COUNTS *rdc = &td->rd_counts;
const int inter_block = is_inter_block(mbmi);
- const int seg_ref_active =
- segfeature_active(&cm->seg, mbmi->segment_id, SEG_LVL_REF_FRAME);
if (!seg_ref_active) {
counts->intra_inter[av1_get_intra_inter_context(xd)][inter_block]++;
#if CONFIG_NEW_MULTISYMBOL
diff --git a/av1/encoder/rd.c b/av1/encoder/rd.c
index 7adf8ae..6d8533c 100644
--- a/av1/encoder/rd.c
+++ b/av1/encoder/rd.c
@@ -120,6 +120,15 @@
#endif // CONFIG_UNPOISON_PARTITION_CTX
}
+ for (i = 0; i < SKIP_CONTEXTS; ++i) {
+#if CONFIG_NEW_MULTISYMBOL
+ av1_cost_tokens_from_cdf(x->skip_cost[i], fc->skip_cdfs[i], NULL);
+#else
+ x->skip_cost[i][0] = av1_cost_bit(fc->skip_probs[i], 0);
+ x->skip_cost[i][1] = av1_cost_bit(fc->skip_probs[i], 1);
+#endif // CONFIG_NEW_MULTISYMBOL
+ }
+
#if CONFIG_KF_CTX
for (i = 0; i < KF_MODE_CONTEXTS; ++i)
for (j = 0; j < KF_MODE_CONTEXTS; ++j)
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 0c071d3..5d84497 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -2500,7 +2500,7 @@
MACROBLOCKD *const xd = &x->e_mbd;
MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
int64_t rd = INT64_MAX;
- aom_prob skip_prob = av1_get_skip_prob(cm, xd);
+ const int skip_ctx = av1_get_skip_context(xd);
int s0, s1;
const int is_inter = is_inter_block(mbmi);
const int tx_select =
@@ -2511,13 +2511,12 @@
#if CONFIG_PVQ
assert(tx_size >= TX_4X4);
#endif // CONFIG_PVQ
- assert(skip_prob > 0);
#if CONFIG_EXT_TX && CONFIG_RECT_TX
assert(IMPLIES(is_rect_tx(tx_size), is_rect_tx_allowed_bsize(bs)));
#endif // CONFIG_EXT_TX && CONFIG_RECT_TX
- s0 = av1_cost_bit(skip_prob, 0);
- s1 = av1_cost_bit(skip_prob, 1);
+ s0 = x->skip_cost[skip_ctx][0];
+ s1 = x->skip_cost[skip_ctx][1];
mbmi->tx_type = tx_type;
mbmi->tx_size = tx_size;
@@ -2632,9 +2631,9 @@
MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
TX_TYPE tx_type, best_tx_type = DCT_DCT;
int64_t this_rd, best_rd = INT64_MAX;
- aom_prob skip_prob = av1_get_skip_prob(cm, xd);
- int s0 = av1_cost_bit(skip_prob, 0);
- int s1 = av1_cost_bit(skip_prob, 1);
+ const int skip_ctx = av1_get_skip_context(xd);
+ int s0 = x->skip_cost[skip_ctx][0];
+ int s1 = x->skip_cost[skip_ctx][1];
const int is_inter = is_inter_block(mbmi);
int prune = 0;
const int plane = 0;
@@ -5324,14 +5323,16 @@
MACROBLOCKD *const xd = &x->e_mbd;
MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
const int is_inter = is_inter_block(mbmi);
- aom_prob skip_prob = av1_get_skip_prob(cm, xd);
- int s0 = av1_cost_bit(skip_prob, 0);
- int s1 = av1_cost_bit(skip_prob, 1);
+ const int skip_ctx = av1_get_skip_context(xd);
+ int s0 = x->skip_cost[skip_ctx][0];
+ int s1 = x->skip_cost[skip_ctx][1];
int64_t rd;
int row, col;
const int max_blocks_high = max_block_high(xd, bsize, 0);
const int max_blocks_wide = max_block_wide(xd, bsize, 0);
+ (void)cm;
+
mbmi->tx_type = tx_type;
inter_block_yrd(cpi, x, rd_stats, bsize, ref_best_rd);
mbmi->min_tx_size = get_min_tx_size(mbmi->inter_tx_size[0][0]);
@@ -9113,28 +9114,29 @@
mbmi->rd_stats = *rd_stats;
#endif // CONFIG_RD_DEBUG
#if CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION
+ const int skip_ctx = av1_get_skip_context(xd);
if (rd_stats->skip) {
rd_stats->rate -= rd_stats_uv->rate + rd_stats_y->rate;
rd_stats_y->rate = 0;
rd_stats_uv->rate = 0;
- rd_stats->rate += av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
+ rd_stats->rate += x->skip_cost[skip_ctx][1];
mbmi->skip = 0;
// here mbmi->skip temporarily plays a role as what this_skip2 does
} else if (!xd->lossless[mbmi->segment_id] &&
(RDCOST(x->rdmult,
rd_stats_y->rate + rd_stats_uv->rate +
- av1_cost_bit(av1_get_skip_prob(cm, xd), 0),
- rd_stats->dist) >=
- RDCOST(x->rdmult, av1_cost_bit(av1_get_skip_prob(cm, xd), 1),
- rd_stats->sse))) {
+ x->skip_cost[skip_ctx][0],
+ rd_stats->dist) >= RDCOST(x->rdmult,
+ x->skip_cost[skip_ctx][1],
+ rd_stats->sse))) {
rd_stats->rate -= rd_stats_uv->rate + rd_stats_y->rate;
- rd_stats->rate += av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
+ rd_stats->rate += x->skip_cost[skip_ctx][1];
rd_stats->dist = rd_stats->sse;
rd_stats_y->rate = 0;
rd_stats_uv->rate = 0;
mbmi->skip = 1;
} else {
- rd_stats->rate += av1_cost_bit(av1_get_skip_prob(cm, xd), 0);
+ rd_stats->rate += x->skip_cost[skip_ctx][0];
mbmi->skip = 0;
}
*disable_skip = 0;
@@ -9148,7 +9150,7 @@
#if CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION
mbmi->skip = 0;
#endif // CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION
- rd_stats->rate += av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
+ rd_stats->rate += x->skip_cost[av1_get_skip_context(xd)][1];
rd_stats->dist = *skip_sse_sb;
rd_stats->sse = *skip_sse_sb;
@@ -10027,12 +10029,12 @@
mbmi->rd_stats = rd_stats;
#endif
- const aom_prob skip_prob = av1_get_skip_prob(cm, xd);
+ const int skip_ctx = av1_get_skip_context(xd);
RD_STATS rdc_noskip;
av1_init_rd_stats(&rdc_noskip);
rdc_noskip.rate =
- rate_mode + rate_mv + rd_stats.rate + av1_cost_bit(skip_prob, 0);
+ rate_mode + rate_mv + rd_stats.rate + x->skip_cost[skip_ctx][0];
rdc_noskip.dist = rd_stats.dist;
rdc_noskip.rdcost = RDCOST(x->rdmult, rdc_noskip.rate, rdc_noskip.dist);
if (rdc_noskip.rdcost < best_rd) {
@@ -10046,7 +10048,7 @@
mbmi->skip = 1;
RD_STATS rdc_skip;
av1_init_rd_stats(&rdc_skip);
- rdc_skip.rate = rate_mode + rate_mv + av1_cost_bit(skip_prob, 1);
+ rdc_skip.rate = rate_mode + rate_mv + x->skip_cost[skip_ctx][1];
rdc_skip.dist = rd_stats.sse;
rdc_skip.rdcost = RDCOST(x->rdmult, rdc_skip.rate, rdc_skip.dist);
if (rdc_skip.rdcost < best_rd) {
@@ -10076,6 +10078,8 @@
TX_SIZE max_uv_tx_size;
const int unify_bsize = CONFIG_CB4X4;
+ (void)cm;
+
ctx->skip = 0;
mbmi->ref_frame[0] = INTRA_FRAME;
mbmi->ref_frame[1] = NONE_FRAME;
@@ -10127,11 +10131,11 @@
if (y_skip && (uv_skip || x->skip_chroma_rd)) {
rd_cost->rate = rate_y + rate_uv - rate_y_tokenonly - rate_uv_tokenonly +
- av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
+ x->skip_cost[av1_get_skip_context(xd)][1];
rd_cost->dist = dist_y + dist_uv;
} else {
rd_cost->rate =
- rate_y + rate_uv + av1_cost_bit(av1_get_skip_prob(cm, xd), 0);
+ rate_y + rate_uv + x->skip_cost[av1_get_skip_context(xd)][0];
rd_cost->dist = dist_y + dist_uv;
}
rd_cost->rdcost = RDCOST(x->rdmult, rd_cost->rate, rd_cost->dist);
@@ -10396,9 +10400,9 @@
rate2 -= (rate_y + rate_uv);
rate_y = 0;
rate_uv = 0;
- rate2 += av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
+ rate2 += x->skip_cost[av1_get_skip_context(xd)][1];
} else {
- rate2 += av1_cost_bit(av1_get_skip_prob(cm, xd), 0);
+ rate2 += x->skip_cost[av1_get_skip_context(xd)][0];
}
this_rd = RDCOST(x->rdmult, rate2, distortion2);
@@ -10470,8 +10474,9 @@
int64_t best_pred_diff[REFERENCE_MODES];
int64_t best_pred_rd[REFERENCE_MODES];
MB_MODE_INFO best_mbmode;
- int rate_skip0 = av1_cost_bit(av1_get_skip_prob(cm, xd), 0);
- int rate_skip1 = av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
+ const int skip_ctx = av1_get_skip_context(xd);
+ int rate_skip0 = x->skip_cost[skip_ctx][0];
+ int rate_skip1 = x->skip_cost[skip_ctx][1];
int best_mode_skippable = 0;
int midx, best_mode_index = -1;
unsigned int ref_costs_single[TOTAL_REFS_PER_FRAME];
@@ -11333,12 +11338,12 @@
if (RDCOST(x->rdmult, rate_y + rate_uv, distortion2) <
RDCOST(x->rdmult, 0, total_sse))
tmp_ref_rd = RDCOST(
- x->rdmult, rate2 + av1_cost_bit(av1_get_skip_prob(cm, xd), 0),
+ x->rdmult, rate2 + x->skip_cost[av1_get_skip_context(xd)][0],
distortion2);
else
tmp_ref_rd =
RDCOST(x->rdmult,
- rate2 + av1_cost_bit(av1_get_skip_prob(cm, xd), 1) -
+ rate2 + x->skip_cost[av1_get_skip_context(xd)][1] -
rate_y - rate_uv,
total_sse);
}
@@ -11494,16 +11499,15 @@
if (RDCOST(x->rdmult, tmp_rd_stats_y.rate + tmp_rd_stats_uv.rate,
tmp_rd_stats.dist) <
RDCOST(x->rdmult, 0, tmp_rd_stats.sse))
- tmp_alt_rd =
- RDCOST(x->rdmult,
- tmp_rd_stats.rate +
- av1_cost_bit(av1_get_skip_prob(cm, xd), 0),
- tmp_rd_stats.dist);
+ tmp_alt_rd = RDCOST(
+ x->rdmult,
+ tmp_rd_stats.rate + x->skip_cost[av1_get_skip_context(xd)][0],
+ tmp_rd_stats.dist);
else
tmp_alt_rd =
RDCOST(x->rdmult,
tmp_rd_stats.rate +
- av1_cost_bit(av1_get_skip_prob(cm, xd), 1) -
+ x->skip_cost[av1_get_skip_context(xd)][1] -
tmp_rd_stats_y.rate - tmp_rd_stats_uv.rate,
tmp_rd_stats.sse);
#endif // CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION
@@ -11586,15 +11590,15 @@
rate_y = 0;
rate_uv = 0;
// Cost the skip mb case
- rate2 += av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
+ rate2 += x->skip_cost[av1_get_skip_context(xd)][1];
} else if (ref_frame != INTRA_FRAME && !xd->lossless[mbmi->segment_id]) {
if (RDCOST(x->rdmult, rate_y + rate_uv + rate_skip0, distortion2) <
RDCOST(x->rdmult, rate_skip1, total_sse)) {
// Add in the cost of the no skip flag.
- rate2 += av1_cost_bit(av1_get_skip_prob(cm, xd), 0);
+ rate2 += x->skip_cost[av1_get_skip_context(xd)][0];
} else {
// FIXME(rbultje) make this work for splitmv also
- rate2 += av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
+ rate2 += x->skip_cost[av1_get_skip_context(xd)][1];
distortion2 = total_sse;
assert(total_sse >= 0);
rate2 -= (rate_y + rate_uv);
@@ -11604,7 +11608,7 @@
}
} else {
// Add in the cost of the no skip flag.
- rate2 += av1_cost_bit(av1_get_skip_prob(cm, xd), 0);
+ rate2 += x->skip_cost[av1_get_skip_context(xd)][0];
}
// Calculate the final RD estimate for this mode.
@@ -11660,8 +11664,9 @@
best_mbmode = *mbmi;
best_skip2 = this_skip2;
best_mode_skippable = skippable;
- best_rate_y = rate_y + av1_cost_bit(av1_get_skip_prob(cm, xd),
- this_skip2 || skippable);
+ best_rate_y =
+ rate_y +
+ x->skip_cost[av1_get_skip_context(xd)][this_skip2 || skippable];
best_rate_uv = rate_uv;
#if CONFIG_VAR_TX
for (i = 0; i < MAX_MB_PLANE; ++i)
@@ -11769,13 +11774,13 @@
(rd_stats_y.dist + rd_stats_uv.dist)) >
RDCOST(x->rdmult, 0, (rd_stats_y.sse + rd_stats_uv.sse))) {
skip_blk = 1;
- rd_stats_y.rate = av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
+ rd_stats_y.rate = x->skip_cost[av1_get_skip_context(xd)][1];
rd_stats_uv.rate = 0;
rd_stats_y.dist = rd_stats_y.sse;
rd_stats_uv.dist = rd_stats_uv.sse;
} else {
skip_blk = 0;
- rd_stats_y.rate += av1_cost_bit(av1_get_skip_prob(cm, xd), 0);
+ rd_stats_y.rate += x->skip_cost[av1_get_skip_context(xd)][0];
}
if (RDCOST(x->rdmult, best_rate_y + best_rate_uv, rd_cost->dist) >
@@ -11873,9 +11878,9 @@
if (skippable) {
rate2 -= (rd_stats_y.rate + rate_uv_tokenonly[uv_tx]);
- rate2 += av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
+ rate2 += x->skip_cost[av1_get_skip_context(xd)][1];
} else {
- rate2 += av1_cost_bit(av1_get_skip_prob(cm, xd), 0);
+ rate2 += x->skip_cost[av1_get_skip_context(xd)][0];
}
this_rd = RDCOST(x->rdmult, rate2, distortion2);
if (this_rd < best_rd) {
@@ -12526,8 +12531,9 @@
int ref, skip_blk, backup_skip = x->skip;
int64_t rd_causal;
RD_STATS rd_stats_y, rd_stats_uv;
- int rate_skip0 = av1_cost_bit(av1_get_skip_prob(cm, xd), 0);
- int rate_skip1 = av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
+ const int skip_ctx = av1_get_skip_context(xd);
+ int rate_skip0 = x->skip_cost[skip_ctx][0];
+ int rate_skip1 = x->skip_cost[skip_ctx][1];
// Recompute the best causal predictor and rd
mbmi->motion_mode = SIMPLE_TRANSLATION;
@@ -12664,8 +12670,9 @@
xd->mi[0]);
#endif // CONFIG_NCOBMC_ADAPT_WEIGHT && CONFIG_WARPED_MOTION
RD_STATS rd_stats_y, rd_stats_uv;
- int rate_skip0 = av1_cost_bit(av1_get_skip_prob(cm, xd), 0);
- int rate_skip1 = av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
+ const int skip_ctx = av1_get_skip_context(xd);
+ int rate_skip0 = x->skip_cost[skip_ctx][0];
+ int rate_skip1 = x->skip_cost[skip_ctx][1];
int64_t this_rd;
int ref;