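Add av1_get_palette_mode_ctx()

In av1/encoder/rdopt.c, replace the two duplicated inline computations of the
luma palette mode context (counting whether the above and left blocks use a
Y palette) with calls to av1_get_palette_mode_ctx(xd). Compute
av1_get_palette_bsize_ctx(bsize) once per function instead of inside the
per-mode DC_PRED branch. Rename palette_ctx / palette_y_mode_ctx to
palette_mode_ctx throughout.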

Change-Id: I6cd366d929d689217f292db07cbeaf1fd35c2055
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index caeef4d..2697a51 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -2613,16 +2613,14 @@
 
 // Given the base colors as specified in centroids[], calculate the RD cost
 // of palette mode.
-static void palette_rd_y(const AV1_COMP *const cpi, MACROBLOCK *x,
-                         MB_MODE_INFO *mbmi, BLOCK_SIZE bsize, int palette_ctx,
-                         int dc_mode_cost, const int *data, int *centroids,
-                         int n, uint16_t *color_cache, int n_cache,
-                         MB_MODE_INFO *best_mbmi,
-                         uint8_t *best_palette_color_map, int64_t *best_rd,
-                         int64_t *best_model_rd, int *rate, int *rate_tokenonly,
-                         int *rate_overhead, int64_t *distortion,
-                         int *skippable, PICK_MODE_CONTEXT *ctx,
-                         uint8_t *blk_skip) {
+static void palette_rd_y(
+    const AV1_COMP *const cpi, MACROBLOCK *x, MB_MODE_INFO *mbmi,
+    BLOCK_SIZE bsize, int palette_mode_ctx, int dc_mode_cost, const int *data,
+    int *centroids, int n, uint16_t *color_cache, int n_cache,
+    MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map, int64_t *best_rd,
+    int64_t *best_model_rd, int *rate, int *rate_tokenonly, int *rate_overhead,
+    int64_t *distortion, int *skippable, PICK_MODE_CONTEXT *ctx,
+    uint8_t *blk_skip) {
   optimize_palette_colors(color_cache, n_cache, n, 1, centroids);
   int k = av1_remove_duplicates(centroids, n);
   if (k < PALETTE_MIN_SIZE) {
@@ -2650,7 +2648,7 @@
   int palette_mode_cost =
       dc_mode_cost + x->palette_y_size_cost[bsize_ctx][k - PALETTE_MIN_SIZE] +
       write_uniform_cost(k, color_map[0]) +
-      x->palette_y_mode_cost[bsize_ctx][palette_ctx][1];
+      x->palette_y_mode_cost[bsize_ctx][palette_mode_ctx][1];
   palette_mode_cost += av1_palette_color_cost_y(pmi, color_cache, n_cache,
                                                 cpi->common.bit_depth);
   palette_mode_cost +=
@@ -2684,11 +2682,11 @@
 }
 
 static int rd_pick_palette_intra_sby(
-    const AV1_COMP *const cpi, MACROBLOCK *x, BLOCK_SIZE bsize, int palette_ctx,
-    int dc_mode_cost, MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map,
-    int64_t *best_rd, int64_t *best_model_rd, int *rate, int *rate_tokenonly,
-    int64_t *distortion, int *skippable, PICK_MODE_CONTEXT *ctx,
-    uint8_t *best_blk_skip) {
+    const AV1_COMP *const cpi, MACROBLOCK *x, BLOCK_SIZE bsize,
+    int palette_mode_ctx, int dc_mode_cost, MB_MODE_INFO *best_mbmi,
+    uint8_t *best_palette_color_map, int64_t *best_rd, int64_t *best_model_rd,
+    int *rate, int *rate_tokenonly, int64_t *distortion, int *skippable,
+    PICK_MODE_CONTEXT *ctx, uint8_t *best_blk_skip) {
   int rate_overhead = 0;
   MACROBLOCKD *const xd = &x->e_mbd;
   MODE_INFO *const mic = xd->mi[0];
@@ -2776,7 +2774,7 @@
     // where the dominant colors and the k-means results are similar.
     for (n = AOMMIN(colors, PALETTE_MAX_SIZE); n >= 2; --n) {
       for (i = 0; i < n; ++i) centroids[i] = top_colors[i];
-      palette_rd_y(cpi, x, mbmi, bsize, palette_ctx, dc_mode_cost, data,
+      palette_rd_y(cpi, x, mbmi, bsize, palette_mode_ctx, dc_mode_cost, data,
                    centroids, n, color_cache, n_cache, best_mbmi,
                    best_palette_color_map, best_rd, best_model_rd, rate,
                    rate_tokenonly, &rate_overhead, distortion, skippable, ctx,
@@ -2797,7 +2795,7 @@
         }
         av1_k_means(data, centroids, color_map, rows * cols, n, 1, max_itr);
       }
-      palette_rd_y(cpi, x, mbmi, bsize, palette_ctx, dc_mode_cost, data,
+      palette_rd_y(cpi, x, mbmi, bsize, palette_mode_ctx, dc_mode_cost, data,
                    centroids, n, color_cache, n_cache, best_mbmi,
                    best_palette_color_map, best_rd, best_model_rd, rate,
                    rate_tokenonly, &rate_overhead, distortion, skippable, ctx,
@@ -3287,9 +3285,11 @@
 #endif  // CONFIG_FILTER_INTRA
   const int *bmode_costs;
   PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
-  int palette_y_mode_ctx = 0;
   const int try_palette =
       av1_allow_palette(cpi->common.allow_screen_content_tools, mbmi->sb_type);
+  const int palette_mode_ctx = try_palette ? av1_get_palette_mode_ctx(xd) : 0;
+  const int palette_bsize_ctx =
+      try_palette ? av1_get_palette_bsize_ctx(bsize) : 0;
   uint8_t *best_palette_color_map =
       try_palette ? x->palette_buffer->best_palette_color_map : NULL;
   const MODE_INFO *above_mi = xd->above_mi;
@@ -3318,16 +3318,6 @@
   mbmi->filter_intra_mode_info.use_filter_intra = 0;
 #endif  // CONFIG_FILTER_INTRA
   pmi->palette_size[0] = 0;
-  if (try_palette) {
-    if (above_mi) {
-      palette_y_mode_ctx +=
-          (above_mi->mbmi.palette_mode_info.palette_size[0] > 0);
-    }
-    if (left_mi) {
-      palette_y_mode_ctx +=
-          (left_mi->mbmi.palette_mode_info.palette_size[0] > 0);
-    }
-  }
 
   if (cpi->sf.tx_type_search.fast_intra_tx_type_search)
     x->use_default_intra_tx_type = 1;
@@ -3380,8 +3370,8 @@
           tx_size_cost(&cpi->common, x, bsize, mbmi->tx_size);
     }
     if (try_palette && mbmi->mode == DC_PRED) {
-      const int bsize_ctx = av1_get_palette_bsize_ctx(bsize);
-      this_rate += x->palette_y_mode_cost[bsize_ctx][palette_y_mode_ctx][0];
+      this_rate +=
+          x->palette_y_mode_cost[palette_bsize_ctx][palette_mode_ctx][0];
     }
 #if CONFIG_FILTER_INTRA
     if (mbmi->mode == DC_PRED && av1_filter_intra_allowed_txsize(mbmi->tx_size))
@@ -3424,7 +3414,7 @@
 
   if (try_palette) {
     rd_pick_palette_intra_sby(
-        cpi, x, bsize, palette_y_mode_ctx, bmode_costs[DC_PRED], &best_mbmi,
+        cpi, x, bsize, palette_mode_ctx, bmode_costs[DC_PRED], &best_mbmi,
         best_palette_color_map, &best_rd, &best_model_rd, rate, rate_tokenonly,
         distortion, skippable, ctx, ctx->blk_skip[0]);
   }
@@ -3444,7 +3434,7 @@
   if (x->use_default_intra_tx_type) {
     *mbmi = best_mbmi;
     x->use_default_intra_tx_type = 0;
-    intra_block_yrd(cpi, x, bsize, bmode_costs, palette_y_mode_ctx, &best_rd,
+    intra_block_yrd(cpi, x, bsize, bmode_costs, palette_mode_ctx, &best_rd,
                     rate, rate_tokenonly, distortion, skippable, &best_mbmi,
                     ctx);
   }
@@ -9064,6 +9054,9 @@
   MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
   const int try_palette =
       av1_allow_palette(cm->allow_screen_content_tools, mbmi->sb_type);
+  const int palette_mode_ctx = try_palette ? av1_get_palette_mode_ctx(xd) : 0;
+  const int palette_bsize_ctx =
+      try_palette ? av1_get_palette_bsize_ctx(bsize) : 0;
   PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
   MB_MODE_INFO_EXT *const mbmi_ext = x->mbmi_ext;
   const struct segmentation *const seg = &cm->seg;
@@ -9140,9 +9133,6 @@
 
   const int rows = block_size_high[bsize];
   const int cols = block_size_wide[bsize];
-  int palette_ctx = 0;
-  const MODE_INFO *above_mi = xd->above_mi;
-  const MODE_INFO *left_mi = xd->left_mi;
   int dst_width1[MAX_MB_PLANE] = { MAX_SB_SIZE, MAX_SB_SIZE, MAX_SB_SIZE };
   int dst_width2[MAX_MB_PLANE] = { MAX_SB_SIZE, MAX_SB_SIZE, MAX_SB_SIZE };
   int dst_height1[MAX_MB_PLANE] = { MAX_SB_SIZE, MAX_SB_SIZE, MAX_SB_SIZE };
@@ -9176,14 +9166,7 @@
   }
 
   av1_zero(best_mbmode);
-
   av1_zero(pmi_uv);
-  if (try_palette) {
-    if (above_mi)
-      palette_ctx += (above_mi->mbmi.palette_mode_info.palette_size[0] > 0);
-    if (left_mi)
-      palette_ctx += (left_mi->mbmi.palette_mode_info.palette_size[0] > 0);
-  }
 
   estimate_ref_frame_costs(cm, xd, x, segment_id, ref_costs_single,
                            ref_costs_comp);
@@ -9781,10 +9764,8 @@
                  x->intra_uv_mode_cost[mbmi->mode][mbmi->uv_mode];
 #endif
 
-      if (try_palette && mbmi->mode == DC_PRED) {
-        const int bsize_ctx = av1_get_palette_bsize_ctx(bsize);
-        rate2 += x->palette_y_mode_cost[bsize_ctx][palette_ctx][0];
-      }
+      if (try_palette && mbmi->mode == DC_PRED)
+        rate2 += x->palette_y_mode_cost[palette_bsize_ctx][palette_mode_ctx][0];
 
       if (!xd->lossless[mbmi->segment_id] && block_signals_txsize(bsize)) {
         // super_block_yrd above includes the cost of the tx_size in the
@@ -10359,7 +10340,7 @@
     mbmi->ref_frame[0] = INTRA_FRAME;
     mbmi->ref_frame[1] = NONE_FRAME;
     rate_overhead_palette = rd_pick_palette_intra_sby(
-        cpi, x, bsize, palette_ctx, intra_mode_cost[DC_PRED],
+        cpi, x, bsize, palette_mode_ctx, intra_mode_cost[DC_PRED],
         &best_mbmi_palette, best_palette_color_map, &best_rd_palette,
         &best_model_rd_palette, NULL, NULL, NULL, NULL, ctx, best_blk_skip);