palette: add one more method for base color selection

In addition to the k-means clustering results, also try using the block's
dominant (most frequently occurring) colors directly as the palette base
colors.

Improves keyframe coding performance by about 1% on the screen_content test
set.

Change-Id: I08a932c322cfe36fb8def778d14f96d71c1017db
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 33093a5..e458896 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -3790,13 +3790,15 @@
   const int limit = 4;
   for (int r = 0; r + blk_h <= height; r += blk_h) {
     for (int c = 0; c + blk_w <= width; c += blk_w) {
+      int count_buf[1 << 12];  // Maximum (1 << 12) color levels.
       const int n_colors =
 #if CONFIG_HIGHBITDEPTH
           use_hbd ? av1_count_colors_highbd(src + r * stride + c, stride, blk_w,
-                                            blk_h, bd)
+                                            blk_h, bd, count_buf)
                   :
 #endif  // CONFIG_HIGHBITDEPTH
-                  av1_count_colors(src + r * stride + c, stride, blk_w, blk_h);
+                  av1_count_colors(src + r * stride + c, stride, blk_w, blk_h,
+                                   count_buf);
       if (n_colors > 1 && n_colors <= limit) counts++;
     }
   }
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 75e2898..6352c83 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -1829,16 +1829,19 @@
                                   visible_rows);
 }
 
-int av1_count_colors(const uint8_t *src, int stride, int rows, int cols) {
-  int val_count[256];
-  memset(val_count, 0, sizeof(val_count));
+int av1_count_colors(const uint8_t *src, int stride, int rows, int cols,
+                     int *val_count) {
+  const int max_pix_val = 1 << 8;
+  memset(val_count, 0, max_pix_val * sizeof(val_count[0]));
   for (int r = 0; r < rows; ++r) {
     for (int c = 0; c < cols; ++c) {
-      ++val_count[src[r * stride + c]];
+      const int this_val = src[r * stride + c];
+      assert(this_val < max_pix_val);
+      ++val_count[this_val];
     }
   }
   int n = 0;
-  for (int i = 0; i < 256; ++i) {
+  for (int i = 0; i < max_pix_val; ++i) {
     if (val_count[i]) ++n;
   }
   return n;
@@ -1846,18 +1849,20 @@
 
 #if CONFIG_HIGHBITDEPTH
 int av1_count_colors_highbd(const uint8_t *src8, int stride, int rows, int cols,
-                            int bit_depth) {
+                            int bit_depth, int *val_count) {
   assert(bit_depth <= 12);
+  const int max_pix_val = 1 << bit_depth;
   const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
-  int val_count[1 << 12];
-  memset(val_count, 0, (1 << 12) * sizeof(val_count[0]));
+  memset(val_count, 0, max_pix_val * sizeof(val_count[0]));
   for (int r = 0; r < rows; ++r) {
     for (int c = 0; c < cols; ++c) {
-      ++val_count[src[r * stride + c]];
+      const int this_val = src[r * stride + c];
+      assert(this_val < max_pix_val);
+      ++val_count[this_val];
     }
   }
   int n = 0;
-  for (int i = 0; i < (1 << bit_depth); ++i) {
+  for (int i = 0; i < max_pix_val; ++i) {
     if (val_count[i]) ++n;
   }
   return n;
@@ -2850,6 +2855,86 @@
 }
 #endif  // CONFIG_PALETTE_DELTA_ENCODING
 
+// Given the base colors as specified in centroids[], calculate the RD cost
+// of palette mode.
+static void palette_rd_y(const AV1_COMP *const cpi, MACROBLOCK *x,
+                         MB_MODE_INFO *mbmi, BLOCK_SIZE bsize, int palette_ctx,
+                         int dc_mode_cost, const float *data, float *centroids,
+                         int n,
+#if CONFIG_PALETTE_DELTA_ENCODING
+                         uint16_t *color_cache, int n_cache,
+#endif
+                         MB_MODE_INFO *best_mbmi,
+                         uint8_t *best_palette_color_map, int64_t *best_rd,
+                         int64_t *best_model_rd, int *rate, int *rate_tokenonly,
+                         int *rate_overhead, int64_t *distortion,
+                         int *skippable) {
+#if CONFIG_PALETTE_DELTA_ENCODING
+  optimize_palette_colors(color_cache, n_cache, n, 1, centroids);
+#endif  // CONFIG_PALETTE_DELTA_ENCODING
+  int k = av1_remove_duplicates(centroids, n);
+  if (k < PALETTE_MIN_SIZE) {
+    // Too few unique colors to create a palette. And DC_PRED will work
+    // well for that case anyway. So skip.
+    return;
+  }
+  PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
+#if CONFIG_HIGHBITDEPTH
+  if (cpi->common.use_highbitdepth)
+    for (int i = 0; i < k; ++i)
+      pmi->palette_colors[i] =
+          clip_pixel_highbd((int)centroids[i], cpi->common.bit_depth);
+  else
+#endif  // CONFIG_HIGHBITDEPTH
+    for (int i = 0; i < k; ++i)
+      pmi->palette_colors[i] = clip_pixel((int)centroids[i]);
+  pmi->palette_size[0] = k;
+  MACROBLOCKD *const xd = &x->e_mbd;
+  uint8_t *const color_map = xd->plane[0].color_index_map;
+  int block_width, block_height, rows, cols;
+  av1_get_block_dimensions(bsize, 0, xd, &block_width, &block_height, &rows,
+                           &cols);
+  av1_calc_indices(data, centroids, color_map, rows * cols, k, 1);
+  extend_palette_color_map(color_map, cols, rows, block_width, block_height);
+  int palette_mode_cost =
+      dc_mode_cost +
+      x->palette_y_size_cost[bsize - BLOCK_8X8][k - PALETTE_MIN_SIZE] +
+      write_uniform_cost(k, color_map[0]) +
+      x->palette_y_mode_cost[bsize - BLOCK_8X8][palette_ctx][1];
+  palette_mode_cost += av1_palette_color_cost_y(pmi,
+#if CONFIG_PALETTE_DELTA_ENCODING
+                                                color_cache, n_cache,
+#endif  // CONFIG_PALETTE_DELTA_ENCODING
+                                                cpi->common.bit_depth);
+  palette_mode_cost +=
+      av1_cost_color_map(x, 0, 0, bsize, mbmi->tx_size, PALETTE_MAP);
+  int64_t this_model_rd = intra_model_yrd(cpi, x, bsize, palette_mode_cost);
+  if (*best_model_rd != INT64_MAX &&
+      this_model_rd > *best_model_rd + (*best_model_rd >> 1))
+    return;
+  if (this_model_rd < *best_model_rd) *best_model_rd = this_model_rd;
+  RD_STATS tokenonly_rd_stats;
+  super_block_yrd(cpi, x, &tokenonly_rd_stats, bsize, *best_rd);
+  if (tokenonly_rd_stats.rate == INT_MAX) return;
+  int this_rate = tokenonly_rd_stats.rate + palette_mode_cost;
+  int64_t this_rd = RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
+  if (!xd->lossless[mbmi->segment_id] && block_signals_txsize(mbmi->sb_type)) {
+    tokenonly_rd_stats.rate -=
+        tx_size_cost(&cpi->common, x, bsize, mbmi->tx_size);
+  }
+  if (this_rd < *best_rd) {
+    *best_rd = this_rd;
+    memcpy(best_palette_color_map, color_map,
+           block_width * block_height * sizeof(color_map[0]));
+    *best_mbmi = *mbmi;
+    *rate_overhead = this_rate - tokenonly_rd_stats.rate;
+    if (rate) *rate = this_rate;
+    if (rate_tokenonly) *rate_tokenonly = tokenonly_rd_stats.rate;
+    if (distortion) *distortion = tokenonly_rd_stats.dist;
+    if (skippable) *skippable = tokenonly_rd_stats.skip;
+  }
+}
+
 static int rd_pick_palette_intra_sby(const AV1_COMP *const cpi, MACROBLOCK *x,
                                      BLOCK_SIZE bsize, int palette_ctx,
                                      int dc_mode_cost, MB_MODE_INFO *best_mbmi,
@@ -2863,7 +2948,7 @@
   MB_MODE_INFO *const mbmi = &mic->mbmi;
   assert(!is_inter_block(mbmi));
   assert(bsize >= BLOCK_8X8);
-  int this_rate, colors, n;
+  int colors, n;
   const int src_stride = x->plane[0].src.stride;
   const uint8_t *const src = x->plane[0].src.buf;
   uint8_t *const color_map = xd->plane[0].color_index_map;
@@ -2873,27 +2958,25 @@
 
   assert(cpi->common.allow_screen_content_tools);
 
+  int count_buf[1 << 12];  // Maximum (1 << 12) color levels.
 #if CONFIG_HIGHBITDEPTH
   if (cpi->common.use_highbitdepth)
     colors = av1_count_colors_highbd(src, src_stride, rows, cols,
-                                     cpi->common.bit_depth);
+                                     cpi->common.bit_depth, count_buf);
   else
 #endif  // CONFIG_HIGHBITDEPTH
-    colors = av1_count_colors(src, src_stride, rows, cols);
+    colors = av1_count_colors(src, src_stride, rows, cols, count_buf);
 #if CONFIG_FILTER_INTRA
   mbmi->filter_intra_mode_info.use_filter_intra_mode[0] = 0;
 #endif  // CONFIG_FILTER_INTRA
 
   if (colors > 1 && colors <= 64) {
     aom_clear_system_state();
-    int r, c, i, k, palette_mode_cost;
+    int r, c, i;
     const int max_itr = 50;
     float *const data = x->palette_buffer->kmeans_data_buf;
     float centroids[PALETTE_MAX_SIZE];
     float lb, ub, val;
-    RD_STATS tokenonly_rd_stats;
-    int64_t this_rd, this_model_rd;
-    PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
 #if CONFIG_HIGHBITDEPTH
     uint16_t *src16 = CONVERT_TO_SHORTPTR(src);
     if (cpi->common.use_highbitdepth)
@@ -2942,82 +3025,56 @@
     const int n_cache = av1_get_palette_cache(xd, 0, color_cache);
 #endif  // CONFIG_PALETTE_DELTA_ENCODING
 
-    for (n = colors > PALETTE_MAX_SIZE ? PALETTE_MAX_SIZE : colors; n >= 2;
-         --n) {
+    // Find the dominant colors, stored in top_colors[].
+    int top_colors[PALETTE_MAX_SIZE] = { 0 };
+    for (i = 0; i < AOMMIN(colors, PALETTE_MAX_SIZE); ++i) {
+      int max_count = 0;
+      for (int j = 0; j < (1 << cpi->common.bit_depth); ++j) {
+        if (count_buf[j] > max_count) {
+          max_count = count_buf[j];
+          top_colors[i] = j;
+        }
+      }
+      assert(max_count > 0);
+      count_buf[top_colors[i]] = 0;
+    }
+
+    // Try the dominant colors directly.
+    // TODO(huisu@google.com): Try to avoid duplicate computation in cases
+    // where the dominant colors and the k-means results are similar.
+    for (n = AOMMIN(colors, PALETTE_MAX_SIZE); n >= 2; --n) {
+      for (i = 0; i < n; ++i) centroids[i] = top_colors[i];
+      palette_rd_y(cpi, x, mbmi, bsize, palette_ctx, dc_mode_cost, data,
+                   centroids, n,
+#if CONFIG_PALETTE_DELTA_ENCODING
+                   color_cache, n_cache,
+#endif  // CONFIG_PALETTE_DELTA_ENCODING
+                   best_mbmi, best_palette_color_map, best_rd, best_model_rd,
+                   rate, rate_tokenonly, &rate_overhead, distortion, skippable);
+    }
+
+    // K-means clustering.
+    for (n = AOMMIN(colors, PALETTE_MAX_SIZE); n >= 2; --n) {
       if (colors == PALETTE_MIN_SIZE) {
         // Special case: These colors automatically become the centroids.
         assert(colors == n);
         assert(colors == 2);
         centroids[0] = lb;
         centroids[1] = ub;
-        k = 2;
       } else {
         for (i = 0; i < n; ++i) {
           centroids[i] = lb + (2 * i + 1) * (ub - lb) / n / 2;
         }
         av1_k_means(data, centroids, color_map, rows * cols, n, 1, max_itr);
-#if CONFIG_PALETTE_DELTA_ENCODING
-        optimize_palette_colors(color_cache, n_cache, n, 1, centroids);
-#endif  // CONFIG_PALETTE_DELTA_ENCODING
-        k = av1_remove_duplicates(centroids, n);
-        if (k < PALETTE_MIN_SIZE) {
-          // Too few unique colors to create a palette. And DC_PRED will work
-          // well for that case anyway. So skip.
-          continue;
-        }
       }
 
-#if CONFIG_HIGHBITDEPTH
-      if (cpi->common.use_highbitdepth)
-        for (i = 0; i < k; ++i)
-          pmi->palette_colors[i] =
-              clip_pixel_highbd((int)centroids[i], cpi->common.bit_depth);
-      else
-#endif  // CONFIG_HIGHBITDEPTH
-        for (i = 0; i < k; ++i)
-          pmi->palette_colors[i] = clip_pixel((int)centroids[i]);
-      pmi->palette_size[0] = k;
-
-      av1_calc_indices(data, centroids, color_map, rows * cols, k, 1);
-      extend_palette_color_map(color_map, cols, rows, block_width,
-                               block_height);
-      palette_mode_cost =
-          dc_mode_cost +
-          x->palette_y_size_cost[bsize - BLOCK_8X8][k - PALETTE_MIN_SIZE] +
-          write_uniform_cost(k, color_map[0]) +
-          x->palette_y_mode_cost[bsize - BLOCK_8X8][palette_ctx][1];
-      palette_mode_cost += av1_palette_color_cost_y(pmi,
+      palette_rd_y(cpi, x, mbmi, bsize, palette_ctx, dc_mode_cost, data,
+                   centroids, n,
 #if CONFIG_PALETTE_DELTA_ENCODING
-                                                    color_cache, n_cache,
+                   color_cache, n_cache,
 #endif  // CONFIG_PALETTE_DELTA_ENCODING
-                                                    cpi->common.bit_depth);
-      palette_mode_cost +=
-          av1_cost_color_map(x, 0, 0, bsize, mbmi->tx_size, PALETTE_MAP);
-      this_model_rd = intra_model_yrd(cpi, x, bsize, palette_mode_cost);
-      if (*best_model_rd != INT64_MAX &&
-          this_model_rd > *best_model_rd + (*best_model_rd >> 1))
-        continue;
-      if (this_model_rd < *best_model_rd) *best_model_rd = this_model_rd;
-      super_block_yrd(cpi, x, &tokenonly_rd_stats, bsize, *best_rd);
-      if (tokenonly_rd_stats.rate == INT_MAX) continue;
-      this_rate = tokenonly_rd_stats.rate + palette_mode_cost;
-      this_rd = RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
-      if (!xd->lossless[mbmi->segment_id] &&
-          block_signals_txsize(mbmi->sb_type)) {
-        tokenonly_rd_stats.rate -=
-            tx_size_cost(&cpi->common, x, bsize, mbmi->tx_size);
-      }
-      if (this_rd < *best_rd) {
-        *best_rd = this_rd;
-        memcpy(best_palette_color_map, color_map,
-               block_width * block_height * sizeof(color_map[0]));
-        *best_mbmi = *mbmi;
-        rate_overhead = this_rate - tokenonly_rd_stats.rate;
-        if (rate) *rate = this_rate;
-        if (rate_tokenonly) *rate_tokenonly = tokenonly_rd_stats.rate;
-        if (distortion) *distortion = tokenonly_rd_stats.dist;
-        if (skippable) *skippable = tokenonly_rd_stats.skip;
-      }
+                   best_mbmi, best_palette_color_map, best_rd, best_model_rd,
+                   rate, rate_tokenonly, &rate_overhead, distortion, skippable);
     }
   }
 
@@ -5125,16 +5182,17 @@
   mbmi->filter_intra_mode_info.use_filter_intra_mode[1] = 0;
 #endif  // CONFIG_FILTER_INTRA
 
+  int count_buf[1 << 12];  // Maximum (1 << 12) color levels.
 #if CONFIG_HIGHBITDEPTH
   if (cpi->common.use_highbitdepth) {
     colors_u = av1_count_colors_highbd(src_u, src_stride, rows, cols,
-                                       cpi->common.bit_depth);
+                                       cpi->common.bit_depth, count_buf);
     colors_v = av1_count_colors_highbd(src_v, src_stride, rows, cols,
-                                       cpi->common.bit_depth);
+                                       cpi->common.bit_depth, count_buf);
   } else {
 #endif  // CONFIG_HIGHBITDEPTH
-    colors_u = av1_count_colors(src_u, src_stride, rows, cols);
-    colors_v = av1_count_colors(src_v, src_stride, rows, cols);
+    colors_u = av1_count_colors(src_u, src_stride, rows, cols, count_buf);
+    colors_v = av1_count_colors(src_v, src_stride, rows, cols, count_buf);
 #if CONFIG_HIGHBITDEPTH
   }
 #endif  // CONFIG_HIGHBITDEPTH
diff --git a/av1/encoder/rdopt.h b/av1/encoder/rdopt.h
index 8e6de9e..970621c 100644
--- a/av1/encoder/rdopt.h
+++ b/av1/encoder/rdopt.h
@@ -63,11 +63,12 @@
 } OUTPUT_STATUS;
 
 // Returns the number of colors in 'src'.
-int av1_count_colors(const uint8_t *src, int stride, int rows, int cols);
+int av1_count_colors(const uint8_t *src, int stride, int rows, int cols,
+                     int *val_count);
 #if CONFIG_HIGHBITDEPTH
 // Same as av1_count_colors(), but for high-bitdepth mode.
 int av1_count_colors_highbd(const uint8_t *src8, int stride, int rows, int cols,
-                            int bit_depth);
+                            int bit_depth, int *val_count);
 #endif  // CONFIG_HIGHBITDEPTH
 
 void av1_dist_block(const AV1_COMP *cpi, MACROBLOCK *x, int plane,