new_multisymbol: use cdf-based cost of intra/inter flag

Change-Id: I3df4789de2a8c34f725a770128e2062e01efb3b0
diff --git a/av1/common/pred_common.h b/av1/common/pred_common.h
index a944998..bdf6527 100644
--- a/av1/common/pred_common.h
+++ b/av1/common/pred_common.h
@@ -86,11 +86,6 @@
 
 int av1_get_intra_inter_context(const MACROBLOCKD *xd);
 
-static INLINE aom_prob av1_get_intra_inter_prob(const AV1_COMMON *cm,
-                                                const MACROBLOCKD *xd) {
-  return cm->fc->intra_inter_prob[av1_get_intra_inter_context(xd)];
-}
-
 int av1_get_reference_mode_context(const AV1_COMMON *cm, const MACROBLOCKD *xd);
 
 static INLINE aom_prob av1_get_reference_mode_prob(const AV1_COMMON *cm,
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index f21bf51..18a500e 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -402,12 +402,12 @@
 static void write_is_inter(const AV1_COMMON *cm, const MACROBLOCKD *xd,
                            int segment_id, aom_writer *w, const int is_inter) {
   if (!segfeature_active(&cm->seg, segment_id, SEG_LVL_REF_FRAME)) {
+    const int ctx = av1_get_intra_inter_context(xd);
 #if CONFIG_NEW_MULTISYMBOL
     FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
-    const int ctx = av1_get_intra_inter_context(xd);
     aom_write_symbol(w, is_inter, ec_ctx->intra_inter_cdf[ctx], 2);
 #else
-    aom_write(w, is_inter, av1_get_intra_inter_prob(cm, xd));
+    aom_write(w, is_inter, cm->fc->intra_inter_prob[ctx]);
 #endif
   }
 }
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index a4fe7b5..b6c91e0 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -230,6 +230,8 @@
   av1_coeff_cost token_tail_costs[TX_SIZES];
 
   // mode costs
+  int intra_inter_cost[INTRA_INTER_CONTEXTS][2];
+
   int mbmode_cost[BLOCK_SIZE_GROUPS][INTRA_MODES];
   int newmv_mode_cost[NEWMV_MODE_CONTEXTS][2];
   int zeromv_mode_cost[ZEROMV_MODE_CONTEXTS][2];
diff --git a/av1/encoder/rd.c b/av1/encoder/rd.c
index 6d8533c..3f5e623 100644
--- a/av1/encoder/rd.c
+++ b/av1/encoder/rd.c
@@ -298,6 +298,16 @@
 #endif  // CONFIG_INTRABC
 
   if (!frame_is_intra_only(cm)) {
+    for (i = 0; i < INTRA_INTER_CONTEXTS; ++i) {
+#if CONFIG_NEW_MULTISYMBOL
+      av1_cost_tokens_from_cdf(x->intra_inter_cost[i], fc->intra_inter_cdf[i],
+                               NULL);
+#else
+      x->intra_inter_cost[i][0] = av1_cost_bit(fc->intra_inter_prob[i], 0);
+      x->intra_inter_cost[i][1] = av1_cost_bit(fc->intra_inter_prob[i], 1);
+#endif
+    }
+
     for (i = 0; i < NEWMV_MODE_CONTEXTS; ++i) {
 #if CONFIG_NEW_MULTISYMBOL
       av1_cost_tokens_from_cdf(x->newmv_mode_cost[i], fc->newmv_cdf[i], NULL);
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 5d84497..2e3d90c 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -7096,8 +7096,8 @@
 }
 
 static void estimate_ref_frame_costs(
-    const AV1_COMMON *cm, const MACROBLOCKD *xd, int segment_id,
-    unsigned int *ref_costs_single,
+    const AV1_COMMON *cm, const MACROBLOCKD *xd, const MACROBLOCK *x,
+    int segment_id, unsigned int *ref_costs_single,
 #if CONFIG_EXT_COMP_REFS
     unsigned int (*ref_costs_comp)[TOTAL_REFS_PER_FRAME],
 #else
@@ -7120,7 +7120,7 @@
 
     *comp_mode_p = 128;
   } else {
-    aom_prob intra_inter_p = av1_get_intra_inter_prob(cm, xd);
+    int intra_inter_ctx = av1_get_intra_inter_context(xd);
     aom_prob comp_inter_p = 128;
 
     if (cm->reference_mode == REFERENCE_MODE_SELECT) {
@@ -7130,7 +7130,7 @@
       *comp_mode_p = 128;
     }
 
-    ref_costs_single[INTRA_FRAME] = av1_cost_bit(intra_inter_p, 0);
+    ref_costs_single[INTRA_FRAME] = x->intra_inter_cost[intra_inter_ctx][0];
 
     if (cm->reference_mode != COMPOUND_REFERENCE) {
       aom_prob ref_single_p1 = av1_get_pred_prob_single_ref_p1(cm, xd);
@@ -7140,7 +7140,7 @@
       aom_prob ref_single_p5 = av1_get_pred_prob_single_ref_p5(cm, xd);
       aom_prob ref_single_p6 = av1_get_pred_prob_single_ref_p6(cm, xd);
 
-      unsigned int base_cost = av1_cost_bit(intra_inter_p, 1);
+      unsigned int base_cost = x->intra_inter_cost[intra_inter_ctx][1];
 
       ref_costs_single[LAST_FRAME] = ref_costs_single[LAST2_FRAME] =
           ref_costs_single[LAST3_FRAME] = ref_costs_single[BWDREF_FRAME] =
@@ -7189,7 +7189,7 @@
       aom_prob bwdref_comp_p = av1_get_pred_prob_comp_bwdref_p(cm, xd);
       aom_prob bwdref_comp_p1 = av1_get_pred_prob_comp_bwdref_p1(cm, xd);
 
-      unsigned int base_cost = av1_cost_bit(intra_inter_p, 1);
+      unsigned int base_cost = x->intra_inter_cost[intra_inter_ctx][1];
 
 #if CONFIG_EXT_COMP_REFS
       aom_prob comp_ref_type_p = av1_get_comp_reference_type_prob(cm, xd);
@@ -10583,8 +10583,8 @@
       palette_ctx += (left_mi->mbmi.palette_mode_info.palette_size[0] > 0);
   }
 
-  estimate_ref_frame_costs(cm, xd, segment_id, ref_costs_single, ref_costs_comp,
-                           &comp_mode_p);
+  estimate_ref_frame_costs(cm, xd, x, segment_id, ref_costs_single,
+                           ref_costs_comp, &comp_mode_p);
 
   for (i = 0; i < REFERENCE_MODES; ++i) best_pred_rd[i] = INT64_MAX;
   for (i = 0; i < TX_SIZES_ALL; i++) rate_uv_intra[i] = INT_MAX;
@@ -12163,8 +12163,8 @@
   (void)mi_row;
   (void)mi_col;
 
-  estimate_ref_frame_costs(cm, xd, segment_id, ref_costs_single, ref_costs_comp,
-                           &comp_mode_p);
+  estimate_ref_frame_costs(cm, xd, x, segment_id, ref_costs_single,
+                           ref_costs_comp, &comp_mode_p);
 
   for (i = 0; i < TOTAL_REFS_PER_FRAME; ++i) x->pred_sse[i] = INT_MAX;
   for (i = LAST_FRAME; i < TOTAL_REFS_PER_FRAME; ++i)