Prune palette rd search

Speed feature prune_palette_search_level has been
introduced to perform 2-way palette search.

STATS_CHANGED

Change-Id: I89c4e43e45355c011be1eb8fb3370b7e933939ad
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 196a34b..562169f 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -4581,7 +4581,7 @@
     uint8_t *best_palette_color_map, int64_t *best_rd, int64_t *best_model_rd,
     int *rate, int *rate_tokenonly, int64_t *distortion, int *skippable,
     int *beat_best_rd, PICK_MODE_CONTEXT *ctx, uint8_t *blk_skip,
-    uint8_t *tx_type_map) {
+    uint8_t *tx_type_map, int *beat_best_pallette_rd) {
   optimize_palette_colors(color_cache, n_cache, n, 1, centroids);
   int k = av1_remove_duplicates(centroids, n);
   if (k < PALETTE_MIN_SIZE) {
@@ -4638,8 +4638,74 @@
     if (rate_tokenonly) *rate_tokenonly = tokenonly_rd_stats.rate;
     if (distortion) *distortion = tokenonly_rd_stats.dist;
     if (skippable) *skippable = tokenonly_rd_stats.skip;
+    if (beat_best_pallette_rd) *beat_best_pallette_rd = 1;
   }
 }
+// Perform palette search for top colors from minimum palette colors (/maximum)
+// with a step-size of 1 (/-1)
+static AOM_INLINE int perform_top_color_palette_search(
+    const AV1_COMP *const cpi, MACROBLOCK *x, MB_MODE_INFO *mbmi,
+    BLOCK_SIZE bsize, int dc_mode_cost, const int *data, int *top_colors,
+    int start_n, int end_n, int step_size, uint16_t *color_cache, int n_cache,
+    MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map, int64_t *best_rd,
+    int64_t *best_model_rd, int *rate, int *rate_tokenonly, int64_t *distortion,
+    int *skippable, int *beat_best_rd, PICK_MODE_CONTEXT *ctx,
+    uint8_t *best_blk_skip, uint8_t *tx_type_map) {
+  int centroids[PALETTE_MAX_SIZE];
+  int n = start_n;
+  assert((step_size == -1) || (step_size == 1));
+  assert(IMPLIES(step_size == -1, start_n > end_n));
+  assert(IMPLIES(step_size == 1, start_n < end_n));
+  while (1) {
+    int beat_best_pallette_rd = 0;
+    for (int i = 0; i < n; ++i) centroids[i] = top_colors[i];
+    palette_rd_y(cpi, x, mbmi, bsize, dc_mode_cost, data, centroids, n,
+                 color_cache, n_cache, best_mbmi, best_palette_color_map,
+                 best_rd, best_model_rd, rate, rate_tokenonly, distortion,
+                 skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map,
+                 &beat_best_pallette_rd);
+    // Break if current palette colors is not winning
+    if (cpi->sf.prune_palette_search_level && !beat_best_pallette_rd) return n;
+    n += step_size;
+    if (n == end_n) break;
+  }
+  return n;
+}
+// Perform k-means based palette search from minimum palette colors (/maximum)
+// with a step-size of 1 (/-1)
+static AOM_INLINE int perform_k_means_palette_search(
+    const AV1_COMP *const cpi, MACROBLOCK *x, MB_MODE_INFO *mbmi,
+    BLOCK_SIZE bsize, int dc_mode_cost, const int *data, int lb, int ub,
+    int start_n, int end_n, int step_size, uint16_t *color_cache, int n_cache,
+    MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map, int64_t *best_rd,
+    int64_t *best_model_rd, int *rate, int *rate_tokenonly, int64_t *distortion,
+    int *skippable, int *beat_best_rd, PICK_MODE_CONTEXT *ctx,
+    uint8_t *best_blk_skip, uint8_t *tx_type_map, uint8_t *color_map,
+    int data_points) {
+  int centroids[PALETTE_MAX_SIZE];
+  const int max_itr = 50;
+  int n = start_n;
+  assert((step_size == -1) || (step_size == 1));
+  assert(IMPLIES(step_size == -1, start_n > end_n));
+  assert(IMPLIES(step_size == 1, start_n < end_n));
+  while (1) {
+    int beat_best_pallette_rd = 0;
+    for (int i = 0; i < n; ++i) {
+      centroids[i] = lb + (2 * i + 1) * (ub - lb) / n / 2;
+    }
+    av1_k_means(data, centroids, color_map, data_points, n, 1, max_itr);
+    palette_rd_y(cpi, x, mbmi, bsize, dc_mode_cost, data, centroids, n,
+                 color_cache, n_cache, best_mbmi, best_palette_color_map,
+                 best_rd, best_model_rd, rate, rate_tokenonly, distortion,
+                 skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map,
+                 &beat_best_pallette_rd);
+    // Break if current palette colors is not winning
+    if (cpi->sf.prune_palette_search_level && !beat_best_pallette_rd) return n;
+    n += step_size;
+    if (n == end_n) break;
+  }
+  return n;
+}
 
 static void rd_pick_palette_intra_sby(
     const AV1_COMP *const cpi, MACROBLOCK *x, BLOCK_SIZE bsize,
@@ -4724,38 +4790,57 @@
       count_buf[top_colors[i]] = 0;
     }
 
-    int n;
-
     // Try the dominant colors directly.
     // TODO(huisu@google.com): Try to avoid duplicate computation in cases
     // where the dominant colors and the k-means results are similar.
-    for (n = AOMMIN(colors, PALETTE_MAX_SIZE); n >= 2; --n) {
-      for (int i = 0; i < n; ++i) centroids[i] = top_colors[i];
-      palette_rd_y(cpi, x, mbmi, bsize, dc_mode_cost, data, centroids, n,
-                   color_cache, n_cache, best_mbmi, best_palette_color_map,
-                   best_rd, best_model_rd, rate, rate_tokenonly, distortion,
-                   skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map);
-    }
+    const int start_n = AOMMIN(colors, PALETTE_MAX_SIZE),
+              end_n = PALETTE_MIN_SIZE;
 
+    // Perform top color palette search from start_n
+    const int top_color_winner = perform_top_color_palette_search(
+        cpi, x, mbmi, bsize, dc_mode_cost, data, top_colors, start_n, end_n - 1,
+        -1, color_cache, n_cache, best_mbmi, best_palette_color_map, best_rd,
+        best_model_rd, rate, rate_tokenonly, distortion, skippable,
+        beat_best_rd, ctx, best_blk_skip, tx_type_map);
+
+    if (top_color_winner > end_n) {
+      // Perform top color palette search in reverse order for the remaining
+      // colors
+      perform_top_color_palette_search(
+          cpi, x, mbmi, bsize, dc_mode_cost, data, top_colors, end_n,
+          top_color_winner, 1, color_cache, n_cache, best_mbmi,
+          best_palette_color_map, best_rd, best_model_rd, rate, rate_tokenonly,
+          distortion, skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map);
+    }
     // K-means clustering.
-    const int max_itr = 50;
-    for (n = AOMMIN(colors, PALETTE_MAX_SIZE); n >= 2; --n) {
-      if (colors == PALETTE_MIN_SIZE) {
-        // Special case: These colors automatically become the centroids.
-        assert(colors == n);
-        assert(colors == 2);
-        centroids[0] = lb;
-        centroids[1] = ub;
-      } else {
-        for (int i = 0; i < n; ++i) {
-          centroids[i] = lb + (2 * i + 1) * (ub - lb) / n / 2;
-        }
-        av1_k_means(data, centroids, color_map, rows * cols, n, 1, max_itr);
-      }
-      palette_rd_y(cpi, x, mbmi, bsize, dc_mode_cost, data, centroids, n,
+    if (colors == PALETTE_MIN_SIZE) {
+      // Special case: These colors automatically become the centroids.
+      assert(colors == 2);
+      centroids[0] = lb;
+      centroids[1] = ub;
+      palette_rd_y(cpi, x, mbmi, bsize, dc_mode_cost, data, centroids, colors,
                    color_cache, n_cache, best_mbmi, best_palette_color_map,
                    best_rd, best_model_rd, rate, rate_tokenonly, distortion,
-                   skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map);
+                   skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map,
+                   NULL);
+    } else {
+      // Perform k-means palette search from start_n
+      const int k_means_winner = perform_k_means_palette_search(
+          cpi, x, mbmi, bsize, dc_mode_cost, data, lb, ub, start_n, end_n - 1,
+          -1, color_cache, n_cache, best_mbmi, best_palette_color_map, best_rd,
+          best_model_rd, rate, rate_tokenonly, distortion, skippable,
+          beat_best_rd, ctx, best_blk_skip, tx_type_map, color_map,
+          rows * cols);
+      if (k_means_winner > end_n) {
+        // Perform k-means palette search in reverse order for the remaining
+        // colors
+        perform_k_means_palette_search(
+            cpi, x, mbmi, bsize, dc_mode_cost, data, lb, ub, end_n,
+            k_means_winner, 1, color_cache, n_cache, best_mbmi,
+            best_palette_color_map, best_rd, best_model_rd, rate,
+            rate_tokenonly, distortion, skippable, beat_best_rd, ctx,
+            best_blk_skip, tx_type_map, color_map, rows * cols);
+      }
     }
   }
 
diff --git a/av1/encoder/speed_features.c b/av1/encoder/speed_features.c
index 98b3e94..dff9905 100644
--- a/av1/encoder/speed_features.c
+++ b/av1/encoder/speed_features.c
@@ -415,6 +415,9 @@
                                                                          : 2;
     sf->tx_type_search.use_skip_flag_prediction =
         cm->allow_screen_content_tools ? 1 : 2;
+    // TODO(any): Experiment with binary search and extend for all frame types
+    // and speed = 1 and 2
+    sf->prune_palette_search_level = frame_is_intra_only(&cpi->common) ? 0 : 1;
   }
 
   if (speed >= 4) {
@@ -473,6 +476,7 @@
     sf->simple_motion_search_prune_agg = 2;
     sf->use_interp_filter = 2;
     sf->prune_ref_mv_idx_search = 1;
+    sf->prune_palette_search_level = 1;
   }
 
   if (speed >= 5) {
@@ -865,6 +869,7 @@
   sf->force_tx_search_off = 0;
   sf->motion_mode_for_winner_cand = 0;
   sf->num_inter_modes_for_tx_search = INT_MAX;
+  sf->prune_palette_search_level = 0;
 
   for (i = 0; i < TX_SIZES; i++) {
     sf->intra_y_mode_mask[i] = INTRA_ALL;
diff --git a/av1/encoder/speed_features.h b/av1/encoder/speed_features.h
index 2d6aff7..bfbbef8 100644
--- a/av1/encoder/speed_features.h
+++ b/av1/encoder/speed_features.h
@@ -864,6 +864,13 @@
 
   // If set forces interpolation filter to EIGHTTAP_REGULAR
   int skip_interp_filter_search;
+
+  // prune palette search
+  // 0: No pruning
+  // 1: Perform 2 way palette search from max colors to min colors (and min
+  // colors to remaining colors) and terminate the search if current number of
+  // palette colors is not the winner.
+  int prune_palette_search_level;
 } SPEED_FEATURES;
 
 struct AV1_COMP;