new_multisymbol: use cdf-based cost of intra/inter flag
Change-Id: I3df4789de2a8c34f725a770128e2062e01efb3b0
diff --git a/av1/common/pred_common.h b/av1/common/pred_common.h
index a944998..bdf6527 100644
--- a/av1/common/pred_common.h
+++ b/av1/common/pred_common.h
@@ -86,11 +86,6 @@
int av1_get_intra_inter_context(const MACROBLOCKD *xd);
-static INLINE aom_prob av1_get_intra_inter_prob(const AV1_COMMON *cm,
- const MACROBLOCKD *xd) {
- return cm->fc->intra_inter_prob[av1_get_intra_inter_context(xd)];
-}
-
int av1_get_reference_mode_context(const AV1_COMMON *cm, const MACROBLOCKD *xd);
static INLINE aom_prob av1_get_reference_mode_prob(const AV1_COMMON *cm,
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index f21bf51..18a500e 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -402,12 +402,12 @@
static void write_is_inter(const AV1_COMMON *cm, const MACROBLOCKD *xd,
int segment_id, aom_writer *w, const int is_inter) {
if (!segfeature_active(&cm->seg, segment_id, SEG_LVL_REF_FRAME)) {
+ const int ctx = av1_get_intra_inter_context(xd);
#if CONFIG_NEW_MULTISYMBOL
FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
- const int ctx = av1_get_intra_inter_context(xd);
aom_write_symbol(w, is_inter, ec_ctx->intra_inter_cdf[ctx], 2);
#else
- aom_write(w, is_inter, av1_get_intra_inter_prob(cm, xd));
+ aom_write(w, is_inter, cm->fc->intra_inter_prob[ctx]);
#endif
}
}
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index a4fe7b5..b6c91e0 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -230,6 +230,8 @@
av1_coeff_cost token_tail_costs[TX_SIZES];
// mode costs
+ int intra_inter_cost[INTRA_INTER_CONTEXTS][2];
+
int mbmode_cost[BLOCK_SIZE_GROUPS][INTRA_MODES];
int newmv_mode_cost[NEWMV_MODE_CONTEXTS][2];
int zeromv_mode_cost[ZEROMV_MODE_CONTEXTS][2];
diff --git a/av1/encoder/rd.c b/av1/encoder/rd.c
index 6d8533c..3f5e623 100644
--- a/av1/encoder/rd.c
+++ b/av1/encoder/rd.c
@@ -298,6 +298,16 @@
#endif // CONFIG_INTRABC
if (!frame_is_intra_only(cm)) {
+ for (i = 0; i < INTRA_INTER_CONTEXTS; ++i) {
+#if CONFIG_NEW_MULTISYMBOL
+ av1_cost_tokens_from_cdf(x->intra_inter_cost[i], fc->intra_inter_cdf[i],
+ NULL);
+#else
+ x->intra_inter_cost[i][0] = av1_cost_bit(fc->intra_inter_prob[i], 0);
+ x->intra_inter_cost[i][1] = av1_cost_bit(fc->intra_inter_prob[i], 1);
+#endif
+ }
+
for (i = 0; i < NEWMV_MODE_CONTEXTS; ++i) {
#if CONFIG_NEW_MULTISYMBOL
av1_cost_tokens_from_cdf(x->newmv_mode_cost[i], fc->newmv_cdf[i], NULL);
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 5d84497..2e3d90c 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -7096,8 +7096,8 @@
}
static void estimate_ref_frame_costs(
- const AV1_COMMON *cm, const MACROBLOCKD *xd, int segment_id,
- unsigned int *ref_costs_single,
+ const AV1_COMMON *cm, const MACROBLOCKD *xd, const MACROBLOCK *x,
+ int segment_id, unsigned int *ref_costs_single,
#if CONFIG_EXT_COMP_REFS
unsigned int (*ref_costs_comp)[TOTAL_REFS_PER_FRAME],
#else
@@ -7120,7 +7120,7 @@
*comp_mode_p = 128;
} else {
- aom_prob intra_inter_p = av1_get_intra_inter_prob(cm, xd);
+ int intra_inter_ctx = av1_get_intra_inter_context(xd);
aom_prob comp_inter_p = 128;
if (cm->reference_mode == REFERENCE_MODE_SELECT) {
@@ -7130,7 +7130,7 @@
*comp_mode_p = 128;
}
- ref_costs_single[INTRA_FRAME] = av1_cost_bit(intra_inter_p, 0);
+ ref_costs_single[INTRA_FRAME] = x->intra_inter_cost[intra_inter_ctx][0];
if (cm->reference_mode != COMPOUND_REFERENCE) {
aom_prob ref_single_p1 = av1_get_pred_prob_single_ref_p1(cm, xd);
@@ -7140,7 +7140,7 @@
aom_prob ref_single_p5 = av1_get_pred_prob_single_ref_p5(cm, xd);
aom_prob ref_single_p6 = av1_get_pred_prob_single_ref_p6(cm, xd);
- unsigned int base_cost = av1_cost_bit(intra_inter_p, 1);
+ unsigned int base_cost = x->intra_inter_cost[intra_inter_ctx][1];
ref_costs_single[LAST_FRAME] = ref_costs_single[LAST2_FRAME] =
ref_costs_single[LAST3_FRAME] = ref_costs_single[BWDREF_FRAME] =
@@ -7189,7 +7189,7 @@
aom_prob bwdref_comp_p = av1_get_pred_prob_comp_bwdref_p(cm, xd);
aom_prob bwdref_comp_p1 = av1_get_pred_prob_comp_bwdref_p1(cm, xd);
- unsigned int base_cost = av1_cost_bit(intra_inter_p, 1);
+ unsigned int base_cost = x->intra_inter_cost[intra_inter_ctx][1];
#if CONFIG_EXT_COMP_REFS
aom_prob comp_ref_type_p = av1_get_comp_reference_type_prob(cm, xd);
@@ -10583,8 +10583,8 @@
palette_ctx += (left_mi->mbmi.palette_mode_info.palette_size[0] > 0);
}
- estimate_ref_frame_costs(cm, xd, segment_id, ref_costs_single, ref_costs_comp,
- &comp_mode_p);
+ estimate_ref_frame_costs(cm, xd, x, segment_id, ref_costs_single,
+ ref_costs_comp, &comp_mode_p);
for (i = 0; i < REFERENCE_MODES; ++i) best_pred_rd[i] = INT64_MAX;
for (i = 0; i < TX_SIZES_ALL; i++) rate_uv_intra[i] = INT_MAX;
@@ -12163,8 +12163,8 @@
(void)mi_row;
(void)mi_col;
- estimate_ref_frame_costs(cm, xd, segment_id, ref_costs_single, ref_costs_comp,
- &comp_mode_p);
+ estimate_ref_frame_costs(cm, xd, x, segment_id, ref_costs_single,
+ ref_costs_comp, &comp_mode_p);
for (i = 0; i < TOTAL_REFS_PER_FRAME; ++i) x->pred_sse[i] = INT_MAX;
for (i = LAST_FRAME; i < TOTAL_REFS_PER_FRAME; ++i)