Use CDFs to calculate cost for the skip bit

Change-Id: I262d9b538988ddcbcac13a217c786fa5df17f8a4
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 3264c82..a4fe7b5 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -219,6 +219,8 @@
   int skip_chroma_rd;
 #endif
 
+  int skip_cost[SKIP_CONTEXTS][2];
+
 #if CONFIG_LV_MAP
   LV_MAP_COEFF_COST coeff_costs[TX_SIZES][PLANE_TYPES];
   uint16_t cb_offset;
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 3558c19..2e7d05e 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -1029,6 +1029,17 @@
   int super_block_upper_left =
       ((mi_row & MAX_MIB_MASK) == 0) && ((mi_col & MAX_MIB_MASK) == 0);
 
+  const int seg_ref_active =
+      segfeature_active(&cm->seg, mbmi->segment_id, SEG_LVL_REF_FRAME);
+
+  if (!seg_ref_active) {
+    const int skip_ctx = av1_get_skip_context(xd);
+    td->counts->skip[skip_ctx][mbmi->skip]++;
+#if CONFIG_NEW_MULTISYMBOL
+    update_cdf(fc->skip_cdfs[skip_ctx], mbmi->skip, 2);
+#endif  // CONFIG_NEW_MULTISYMBOL
+  }
+
   if (cm->delta_q_present_flag && (bsize != cm->sb_size || !mbmi->skip) &&
       super_block_upper_left) {
     const int dq = (mbmi->current_q_index - xd->prev_qindex) / cm->delta_q_res;
@@ -1087,8 +1098,6 @@
     FRAME_COUNTS *const counts = td->counts;
     RD_COUNTS *rdc = &td->rd_counts;
     const int inter_block = is_inter_block(mbmi);
-    const int seg_ref_active =
-        segfeature_active(&cm->seg, mbmi->segment_id, SEG_LVL_REF_FRAME);
     if (!seg_ref_active) {
       counts->intra_inter[av1_get_intra_inter_context(xd)][inter_block]++;
 #if CONFIG_NEW_MULTISYMBOL
diff --git a/av1/encoder/rd.c b/av1/encoder/rd.c
index 7adf8ae..6d8533c 100644
--- a/av1/encoder/rd.c
+++ b/av1/encoder/rd.c
@@ -120,6 +120,15 @@
 #endif  // CONFIG_UNPOISON_PARTITION_CTX
   }
 
+  for (i = 0; i < SKIP_CONTEXTS; ++i) {
+#if CONFIG_NEW_MULTISYMBOL
+    av1_cost_tokens_from_cdf(x->skip_cost[i], fc->skip_cdfs[i], NULL);
+#else
+    x->skip_cost[i][0] = av1_cost_bit(fc->skip_probs[i], 0);
+    x->skip_cost[i][1] = av1_cost_bit(fc->skip_probs[i], 1);
+#endif  // CONFIG_NEW_MULTISYMBOL
+  }
+
 #if CONFIG_KF_CTX
   for (i = 0; i < KF_MODE_CONTEXTS; ++i)
     for (j = 0; j < KF_MODE_CONTEXTS; ++j)
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 0c071d3..5d84497 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -2500,7 +2500,7 @@
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
   int64_t rd = INT64_MAX;
-  aom_prob skip_prob = av1_get_skip_prob(cm, xd);
+  const int skip_ctx = av1_get_skip_context(xd);
   int s0, s1;
   const int is_inter = is_inter_block(mbmi);
   const int tx_select =
@@ -2511,13 +2511,12 @@
 #if CONFIG_PVQ
   assert(tx_size >= TX_4X4);
 #endif  // CONFIG_PVQ
-  assert(skip_prob > 0);
 #if CONFIG_EXT_TX && CONFIG_RECT_TX
   assert(IMPLIES(is_rect_tx(tx_size), is_rect_tx_allowed_bsize(bs)));
 #endif  // CONFIG_EXT_TX && CONFIG_RECT_TX
 
-  s0 = av1_cost_bit(skip_prob, 0);
-  s1 = av1_cost_bit(skip_prob, 1);
+  s0 = x->skip_cost[skip_ctx][0];
+  s1 = x->skip_cost[skip_ctx][1];
 
   mbmi->tx_type = tx_type;
   mbmi->tx_size = tx_size;
@@ -2632,9 +2631,9 @@
   MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
   TX_TYPE tx_type, best_tx_type = DCT_DCT;
   int64_t this_rd, best_rd = INT64_MAX;
-  aom_prob skip_prob = av1_get_skip_prob(cm, xd);
-  int s0 = av1_cost_bit(skip_prob, 0);
-  int s1 = av1_cost_bit(skip_prob, 1);
+  const int skip_ctx = av1_get_skip_context(xd);
+  int s0 = x->skip_cost[skip_ctx][0];
+  int s1 = x->skip_cost[skip_ctx][1];
   const int is_inter = is_inter_block(mbmi);
   int prune = 0;
   const int plane = 0;
@@ -5324,14 +5323,16 @@
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
   const int is_inter = is_inter_block(mbmi);
-  aom_prob skip_prob = av1_get_skip_prob(cm, xd);
-  int s0 = av1_cost_bit(skip_prob, 0);
-  int s1 = av1_cost_bit(skip_prob, 1);
+  const int skip_ctx = av1_get_skip_context(xd);
+  int s0 = x->skip_cost[skip_ctx][0];
+  int s1 = x->skip_cost[skip_ctx][1];
   int64_t rd;
   int row, col;
   const int max_blocks_high = max_block_high(xd, bsize, 0);
   const int max_blocks_wide = max_block_wide(xd, bsize, 0);
 
+  (void)cm;
+
   mbmi->tx_type = tx_type;
   inter_block_yrd(cpi, x, rd_stats, bsize, ref_best_rd);
   mbmi->min_tx_size = get_min_tx_size(mbmi->inter_tx_size[0][0]);
@@ -9113,28 +9114,29 @@
       mbmi->rd_stats = *rd_stats;
 #endif  // CONFIG_RD_DEBUG
 #if CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION
+      const int skip_ctx = av1_get_skip_context(xd);
       if (rd_stats->skip) {
         rd_stats->rate -= rd_stats_uv->rate + rd_stats_y->rate;
         rd_stats_y->rate = 0;
         rd_stats_uv->rate = 0;
-        rd_stats->rate += av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
+        rd_stats->rate += x->skip_cost[skip_ctx][1];
         mbmi->skip = 0;
         // here mbmi->skip temporarily plays a role as what this_skip2 does
       } else if (!xd->lossless[mbmi->segment_id] &&
                  (RDCOST(x->rdmult,
                          rd_stats_y->rate + rd_stats_uv->rate +
-                             av1_cost_bit(av1_get_skip_prob(cm, xd), 0),
-                         rd_stats->dist) >=
-                  RDCOST(x->rdmult, av1_cost_bit(av1_get_skip_prob(cm, xd), 1),
-                         rd_stats->sse))) {
+                             x->skip_cost[skip_ctx][0],
+                         rd_stats->dist) >= RDCOST(x->rdmult,
+                                                   x->skip_cost[skip_ctx][1],
+                                                   rd_stats->sse))) {
         rd_stats->rate -= rd_stats_uv->rate + rd_stats_y->rate;
-        rd_stats->rate += av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
+        rd_stats->rate += x->skip_cost[skip_ctx][1];
         rd_stats->dist = rd_stats->sse;
         rd_stats_y->rate = 0;
         rd_stats_uv->rate = 0;
         mbmi->skip = 1;
       } else {
-        rd_stats->rate += av1_cost_bit(av1_get_skip_prob(cm, xd), 0);
+        rd_stats->rate += x->skip_cost[skip_ctx][0];
         mbmi->skip = 0;
       }
       *disable_skip = 0;
@@ -9148,7 +9150,7 @@
 #if CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION
       mbmi->skip = 0;
 #endif  // CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION
-      rd_stats->rate += av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
+      rd_stats->rate += x->skip_cost[av1_get_skip_context(xd)][1];
 
       rd_stats->dist = *skip_sse_sb;
       rd_stats->sse = *skip_sse_sb;
@@ -10027,12 +10029,12 @@
     mbmi->rd_stats = rd_stats;
 #endif
 
-    const aom_prob skip_prob = av1_get_skip_prob(cm, xd);
+    const int skip_ctx = av1_get_skip_context(xd);
 
     RD_STATS rdc_noskip;
     av1_init_rd_stats(&rdc_noskip);
     rdc_noskip.rate =
-        rate_mode + rate_mv + rd_stats.rate + av1_cost_bit(skip_prob, 0);
+        rate_mode + rate_mv + rd_stats.rate + x->skip_cost[skip_ctx][0];
     rdc_noskip.dist = rd_stats.dist;
     rdc_noskip.rdcost = RDCOST(x->rdmult, rdc_noskip.rate, rdc_noskip.dist);
     if (rdc_noskip.rdcost < best_rd) {
@@ -10046,7 +10048,7 @@
     mbmi->skip = 1;
     RD_STATS rdc_skip;
     av1_init_rd_stats(&rdc_skip);
-    rdc_skip.rate = rate_mode + rate_mv + av1_cost_bit(skip_prob, 1);
+    rdc_skip.rate = rate_mode + rate_mv + x->skip_cost[skip_ctx][1];
     rdc_skip.dist = rd_stats.sse;
     rdc_skip.rdcost = RDCOST(x->rdmult, rdc_skip.rate, rdc_skip.dist);
     if (rdc_skip.rdcost < best_rd) {
@@ -10076,6 +10078,8 @@
   TX_SIZE max_uv_tx_size;
   const int unify_bsize = CONFIG_CB4X4;
 
+  (void)cm;
+
   ctx->skip = 0;
   mbmi->ref_frame[0] = INTRA_FRAME;
   mbmi->ref_frame[1] = NONE_FRAME;
@@ -10127,11 +10131,11 @@
 
     if (y_skip && (uv_skip || x->skip_chroma_rd)) {
       rd_cost->rate = rate_y + rate_uv - rate_y_tokenonly - rate_uv_tokenonly +
-                      av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
+                      x->skip_cost[av1_get_skip_context(xd)][1];
       rd_cost->dist = dist_y + dist_uv;
     } else {
       rd_cost->rate =
-          rate_y + rate_uv + av1_cost_bit(av1_get_skip_prob(cm, xd), 0);
+          rate_y + rate_uv + x->skip_cost[av1_get_skip_context(xd)][0];
       rd_cost->dist = dist_y + dist_uv;
     }
     rd_cost->rdcost = RDCOST(x->rdmult, rd_cost->rate, rd_cost->dist);
@@ -10396,9 +10400,9 @@
     rate2 -= (rate_y + rate_uv);
     rate_y = 0;
     rate_uv = 0;
-    rate2 += av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
+    rate2 += x->skip_cost[av1_get_skip_context(xd)][1];
   } else {
-    rate2 += av1_cost_bit(av1_get_skip_prob(cm, xd), 0);
+    rate2 += x->skip_cost[av1_get_skip_context(xd)][0];
   }
   this_rd = RDCOST(x->rdmult, rate2, distortion2);
 
@@ -10470,8 +10474,9 @@
   int64_t best_pred_diff[REFERENCE_MODES];
   int64_t best_pred_rd[REFERENCE_MODES];
   MB_MODE_INFO best_mbmode;
-  int rate_skip0 = av1_cost_bit(av1_get_skip_prob(cm, xd), 0);
-  int rate_skip1 = av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
+  const int skip_ctx = av1_get_skip_context(xd);
+  int rate_skip0 = x->skip_cost[skip_ctx][0];
+  int rate_skip1 = x->skip_cost[skip_ctx][1];
   int best_mode_skippable = 0;
   int midx, best_mode_index = -1;
   unsigned int ref_costs_single[TOTAL_REFS_PER_FRAME];
@@ -11333,12 +11338,12 @@
           if (RDCOST(x->rdmult, rate_y + rate_uv, distortion2) <
               RDCOST(x->rdmult, 0, total_sse))
             tmp_ref_rd = RDCOST(
-                x->rdmult, rate2 + av1_cost_bit(av1_get_skip_prob(cm, xd), 0),
+                x->rdmult, rate2 + x->skip_cost[av1_get_skip_context(xd)][0],
                 distortion2);
           else
             tmp_ref_rd =
                 RDCOST(x->rdmult,
-                       rate2 + av1_cost_bit(av1_get_skip_prob(cm, xd), 1) -
+                       rate2 + x->skip_cost[av1_get_skip_context(xd)][1] -
                            rate_y - rate_uv,
                        total_sse);
         }
@@ -11494,16 +11499,15 @@
             if (RDCOST(x->rdmult, tmp_rd_stats_y.rate + tmp_rd_stats_uv.rate,
                        tmp_rd_stats.dist) <
                 RDCOST(x->rdmult, 0, tmp_rd_stats.sse))
-              tmp_alt_rd =
-                  RDCOST(x->rdmult,
-                         tmp_rd_stats.rate +
-                             av1_cost_bit(av1_get_skip_prob(cm, xd), 0),
-                         tmp_rd_stats.dist);
+              tmp_alt_rd = RDCOST(
+                  x->rdmult,
+                  tmp_rd_stats.rate + x->skip_cost[av1_get_skip_context(xd)][0],
+                  tmp_rd_stats.dist);
             else
               tmp_alt_rd =
                   RDCOST(x->rdmult,
                          tmp_rd_stats.rate +
-                             av1_cost_bit(av1_get_skip_prob(cm, xd), 1) -
+                             x->skip_cost[av1_get_skip_context(xd)][1] -
                              tmp_rd_stats_y.rate - tmp_rd_stats_uv.rate,
                          tmp_rd_stats.sse);
 #endif  // CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION
@@ -11586,15 +11590,15 @@
         rate_y = 0;
         rate_uv = 0;
         // Cost the skip mb case
-        rate2 += av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
+        rate2 += x->skip_cost[av1_get_skip_context(xd)][1];
       } else if (ref_frame != INTRA_FRAME && !xd->lossless[mbmi->segment_id]) {
         if (RDCOST(x->rdmult, rate_y + rate_uv + rate_skip0, distortion2) <
             RDCOST(x->rdmult, rate_skip1, total_sse)) {
           // Add in the cost of the no skip flag.
-          rate2 += av1_cost_bit(av1_get_skip_prob(cm, xd), 0);
+          rate2 += x->skip_cost[av1_get_skip_context(xd)][0];
         } else {
           // FIXME(rbultje) make this work for splitmv also
-          rate2 += av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
+          rate2 += x->skip_cost[av1_get_skip_context(xd)][1];
           distortion2 = total_sse;
           assert(total_sse >= 0);
           rate2 -= (rate_y + rate_uv);
@@ -11604,7 +11608,7 @@
         }
       } else {
         // Add in the cost of the no skip flag.
-        rate2 += av1_cost_bit(av1_get_skip_prob(cm, xd), 0);
+        rate2 += x->skip_cost[av1_get_skip_context(xd)][0];
       }
 
       // Calculate the final RD estimate for this mode.
@@ -11660,8 +11664,9 @@
         best_mbmode = *mbmi;
         best_skip2 = this_skip2;
         best_mode_skippable = skippable;
-        best_rate_y = rate_y + av1_cost_bit(av1_get_skip_prob(cm, xd),
-                                            this_skip2 || skippable);
+        best_rate_y =
+            rate_y +
+            x->skip_cost[av1_get_skip_context(xd)][this_skip2 || skippable];
         best_rate_uv = rate_uv;
 #if CONFIG_VAR_TX
         for (i = 0; i < MAX_MB_PLANE; ++i)
@@ -11769,13 +11774,13 @@
                (rd_stats_y.dist + rd_stats_uv.dist)) >
         RDCOST(x->rdmult, 0, (rd_stats_y.sse + rd_stats_uv.sse))) {
       skip_blk = 1;
-      rd_stats_y.rate = av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
+      rd_stats_y.rate = x->skip_cost[av1_get_skip_context(xd)][1];
       rd_stats_uv.rate = 0;
       rd_stats_y.dist = rd_stats_y.sse;
       rd_stats_uv.dist = rd_stats_uv.sse;
     } else {
       skip_blk = 0;
-      rd_stats_y.rate += av1_cost_bit(av1_get_skip_prob(cm, xd), 0);
+      rd_stats_y.rate += x->skip_cost[av1_get_skip_context(xd)][0];
     }
 
     if (RDCOST(x->rdmult, best_rate_y + best_rate_uv, rd_cost->dist) >
@@ -11873,9 +11878,9 @@
 
     if (skippable) {
       rate2 -= (rd_stats_y.rate + rate_uv_tokenonly[uv_tx]);
-      rate2 += av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
+      rate2 += x->skip_cost[av1_get_skip_context(xd)][1];
     } else {
-      rate2 += av1_cost_bit(av1_get_skip_prob(cm, xd), 0);
+      rate2 += x->skip_cost[av1_get_skip_context(xd)][0];
     }
     this_rd = RDCOST(x->rdmult, rate2, distortion2);
     if (this_rd < best_rd) {
@@ -12526,8 +12531,9 @@
   int ref, skip_blk, backup_skip = x->skip;
   int64_t rd_causal;
   RD_STATS rd_stats_y, rd_stats_uv;
-  int rate_skip0 = av1_cost_bit(av1_get_skip_prob(cm, xd), 0);
-  int rate_skip1 = av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
+  const int skip_ctx = av1_get_skip_context(xd);
+  int rate_skip0 = x->skip_cost[skip_ctx][0];
+  int rate_skip1 = x->skip_cost[skip_ctx][1];
 
   // Recompute the best causal predictor and rd
   mbmi->motion_mode = SIMPLE_TRANSLATION;
@@ -12664,8 +12670,9 @@
       xd->mi[0]);
 #endif  // CONFIG_NCOBMC_ADAPT_WEIGHT && CONFIG_WARPED_MOTION
   RD_STATS rd_stats_y, rd_stats_uv;
-  int rate_skip0 = av1_cost_bit(av1_get_skip_prob(cm, xd), 0);
-  int rate_skip1 = av1_cost_bit(av1_get_skip_prob(cm, xd), 1);
+  const int skip_ctx = av1_get_skip_context(xd);
+  int rate_skip0 = x->skip_cost[skip_ctx][0];
+  int rate_skip1 = x->skip_cost[skip_ctx][1];
   int64_t this_rd;
   int ref;