Prune palette rd search for speed=1,2

Speed feature prune_palette_search_level has been
introduced to perform coarse search to prune the
palette colors. For winner colors,neighbors are
also evaluated using a finer search. This speed
feature is applicable for speed=1,2.
For speed>=3, 2 way palette search would be used.

           Instruction Count
cpu-used       Reduction        Quality Loss
   2             1.86%             0.03%
   1             1.16%            -0.01%

STATS_CHANGED

Change-Id: I92258e751f5166ca4585124df3b0893081ad7167
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 2896db4..a3ce861 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -4700,6 +4700,67 @@
     if (beat_best_pallette_rd) *beat_best_pallette_rd = 1;
   }
 }
+
+static AOM_INLINE int perform_top_color_coarse_palette_search(
+    const AV1_COMP *const cpi, MACROBLOCK *x, MB_MODE_INFO *mbmi,
+    BLOCK_SIZE bsize, int dc_mode_cost, const int *data,
+    const int *const top_colors, int start_n, int end_n, int step_size,
+    uint16_t *color_cache, int n_cache, MB_MODE_INFO *best_mbmi,
+    uint8_t *best_palette_color_map, int64_t *best_rd, int64_t *best_model_rd,
+    int *rate, int *rate_tokenonly, int64_t *distortion, int *skippable,
+    int *beat_best_rd, PICK_MODE_CONTEXT *ctx, uint8_t *best_blk_skip,
+    uint8_t *tx_type_map) {
+  int centroids[PALETTE_MAX_SIZE];
+  int n = start_n;
+  int top_color_winner = end_n + 1;
+  while (1) {
+    int beat_best_pallette_rd = 0;
+    for (int i = 0; i < n; ++i) centroids[i] = top_colors[i];
+    palette_rd_y(cpi, x, mbmi, bsize, dc_mode_cost, data, centroids, n,
+                 color_cache, n_cache, best_mbmi, best_palette_color_map,
+                 best_rd, best_model_rd, rate, rate_tokenonly, distortion,
+                 skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map,
+                 &beat_best_pallette_rd);
+    // Break if current palette colors is not winning
+    if (beat_best_pallette_rd) top_color_winner = n;
+    n += step_size;
+    if (n > end_n) break;
+  }
+  return top_color_winner;
+}
+
+static AOM_INLINE int perform_k_means_coarse_palette_search(
+    const AV1_COMP *const cpi, MACROBLOCK *x, MB_MODE_INFO *mbmi,
+    BLOCK_SIZE bsize, int dc_mode_cost, const int *data, int lb, int ub,
+    int start_n, int end_n, int step_size, uint16_t *color_cache, int n_cache,
+    MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map, int64_t *best_rd,
+    int64_t *best_model_rd, int *rate, int *rate_tokenonly, int64_t *distortion,
+    int *skippable, int *beat_best_rd, PICK_MODE_CONTEXT *ctx,
+    uint8_t *best_blk_skip, uint8_t *tx_type_map, uint8_t *color_map,
+    int data_points) {
+  int centroids[PALETTE_MAX_SIZE];
+  const int max_itr = 50;
+  int n = start_n;
+  int k_means_winner = end_n + 1;
+  while (1) {
+    int beat_best_pallette_rd = 0;
+    for (int i = 0; i < n; ++i) {
+      centroids[i] = lb + (2 * i + 1) * (ub - lb) / n / 2;
+    }
+    av1_k_means(data, centroids, color_map, data_points, n, 1, max_itr);
+    palette_rd_y(cpi, x, mbmi, bsize, dc_mode_cost, data, centroids, n,
+                 color_cache, n_cache, best_mbmi, best_palette_color_map,
+                 best_rd, best_model_rd, rate, rate_tokenonly, distortion,
+                 skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map,
+                 &beat_best_pallette_rd);
+    // Break if current palette colors is not winning
+    if (beat_best_pallette_rd) k_means_winner = n;
+    n += step_size;
+    if (n > end_n) break;
+  }
+  return k_means_winner;
+}
+
 // Perform palette search for top colors from minimum palette colors (/maximum)
 // with a step-size of 1 (/-1)
 static AOM_INLINE int perform_top_color_palette_search(
@@ -4712,7 +4773,8 @@
     uint8_t *best_blk_skip, uint8_t *tx_type_map) {
   int centroids[PALETTE_MAX_SIZE];
   int n = start_n;
-  assert((step_size == -1) || (step_size == 1));
+  assert((step_size == -1) || (step_size == 1) || (step_size == 0) ||
+         (step_size == 2));
   assert(IMPLIES(step_size == -1, start_n > end_n));
   assert(IMPLIES(step_size == 1, start_n < end_n));
   while (1) {
@@ -4724,7 +4786,8 @@
                  skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map,
                  &beat_best_pallette_rd);
     // Break if current palette colors is not winning
-    if (cpi->sf.prune_palette_search_level && !beat_best_pallette_rd) return n;
+    if ((cpi->sf.prune_palette_search_level == 2) && !beat_best_pallette_rd)
+      return n;
     n += step_size;
     if (n == end_n) break;
   }
@@ -4744,7 +4807,8 @@
   int centroids[PALETTE_MAX_SIZE];
   const int max_itr = 50;
   int n = start_n;
-  assert((step_size == -1) || (step_size == 1));
+  assert((step_size == -1) || (step_size == 1) || (step_size == 0) ||
+         (step_size == 2));
   assert(IMPLIES(step_size == -1, start_n > end_n));
   assert(IMPLIES(step_size == 1, start_n < end_n));
   while (1) {
@@ -4759,13 +4823,52 @@
                  skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map,
                  &beat_best_pallette_rd);
     // Break if current palette colors is not winning
-    if (cpi->sf.prune_palette_search_level && !beat_best_pallette_rd) return n;
+    if ((cpi->sf.prune_palette_search_level == 2) && !beat_best_pallette_rd)
+      return n;
     n += step_size;
     if (n == end_n) break;
   }
   return n;
 }
 
+#define START_N_STAGE2(x)                         \
+  ((x == PALETTE_MIN_SIZE) ? PALETTE_MIN_SIZE + 1 \
+                           : AOMMAX(x - 1, PALETTE_MIN_SIZE));
+#define END_N_STAGE2(x, end_n) \
+  ((x == end_n) ? x - 1 : AOMMIN(x + 1, PALETTE_MAX_SIZE));
+
+static AOM_INLINE void update_start_end_stage_2(int *start_n_stage2,
+                                                int *end_n_stage2,
+                                                int *step_size_stage2,
+                                                int winner, int end_n) {
+  *start_n_stage2 = START_N_STAGE2(winner);
+  *end_n_stage2 = END_N_STAGE2(winner, end_n);
+  *step_size_stage2 = *end_n_stage2 - *start_n_stage2;
+}
+
+// Start index and step size below are chosen to evaluate unique
+// candidates in neighbor search, in case a winner candidate is found in
+// coarse search. Example,
+// 1) 8 colors (end_n = 8): 2,3,4,5,6,7,8. start_n is chosen as 2 and step
+// size is chosen as 3. Therefore, coarse search will evaluate 2, 5 and 8.
+// If winner is found at 5, then 4 and 6 are evaluated. Similarly, for 2
+// (3) and 8 (7).
+// 2) 7 colors (end_n = 7): 2,3,4,5,6,7. If start_n is chosen as 2 (same
+// as for 8 colors) then step size should also be 2, to cover all
+// candidates. Coarse search will evaluate 2, 4 and 6. If winner is either
+// 2 or 4, 3 will be evaluated. Instead, if start_n=3 and step_size=3,
+// coarse search will evaluate 3 and 6. For the winner, unique neighbors
+// (3: 2,4 or 6: 5,7) would be evaluated.
+
+// start index for coarse palette search for dominant colors and k-means
+static const uint8_t start_n_lookup_table[PALETTE_MAX_SIZE + 1] = { 0, 0, 0,
+                                                                    3, 3, 2,
+                                                                    3, 3, 2 };
+// step size for coarse palette search for dominant colors and k-means
+static const uint8_t step_size_lookup_table[PALETTE_MAX_SIZE + 1] = { 0, 0, 0,
+                                                                      3, 3, 3,
+                                                                      3, 3, 3 };
+
 static void rd_pick_palette_intra_sby(
     const AV1_COMP *const cpi, MACROBLOCK *x, BLOCK_SIZE bsize,
     int dc_mode_cost, MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map,
@@ -4852,53 +4955,106 @@
     // Try the dominant colors directly.
     // TODO(huisu@google.com): Try to avoid duplicate computation in cases
     // where the dominant colors and the k-means results are similar.
-    const int start_n = AOMMIN(colors, PALETTE_MAX_SIZE),
-              end_n = PALETTE_MIN_SIZE;
-
-    // Perform top color palette search from start_n
-    const int top_color_winner = perform_top_color_palette_search(
-        cpi, x, mbmi, bsize, dc_mode_cost, data, top_colors, start_n, end_n - 1,
-        -1, color_cache, n_cache, best_mbmi, best_palette_color_map, best_rd,
-        best_model_rd, rate, rate_tokenonly, distortion, skippable,
-        beat_best_rd, ctx, best_blk_skip, tx_type_map);
-
-    if (top_color_winner > end_n) {
-      // Perform top color palette search in reverse order for the remaining
-      // colors
-      perform_top_color_palette_search(
-          cpi, x, mbmi, bsize, dc_mode_cost, data, top_colors, end_n,
-          top_color_winner, 1, color_cache, n_cache, best_mbmi,
-          best_palette_color_map, best_rd, best_model_rd, rate, rate_tokenonly,
-          distortion, skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map);
-    }
-    // K-means clustering.
-    if (colors == PALETTE_MIN_SIZE) {
-      // Special case: These colors automatically become the centroids.
-      assert(colors == 2);
-      centroids[0] = lb;
-      centroids[1] = ub;
-      palette_rd_y(cpi, x, mbmi, bsize, dc_mode_cost, data, centroids, colors,
-                   color_cache, n_cache, best_mbmi, best_palette_color_map,
-                   best_rd, best_model_rd, rate, rate_tokenonly, distortion,
-                   skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map,
-                   NULL);
-    } else {
-      // Perform k-means palette search from start_n
-      const int k_means_winner = perform_k_means_palette_search(
-          cpi, x, mbmi, bsize, dc_mode_cost, data, lb, ub, start_n, end_n - 1,
-          -1, color_cache, n_cache, best_mbmi, best_palette_color_map, best_rd,
-          best_model_rd, rate, rate_tokenonly, distortion, skippable,
+    if ((cpi->sf.prune_palette_search_level == 1) &&
+        (colors > PALETTE_MIN_SIZE)) {
+      const int end_n = AOMMIN(colors, PALETTE_MAX_SIZE);
+      assert(PALETTE_MAX_SIZE == 8);
+      assert(PALETTE_MIN_SIZE == 2);
+      // Choose the start index and step size for coarse search based on number
+      // of colors
+      const int start_n = start_n_lookup_table[end_n];
+      const int step_size = step_size_lookup_table[end_n];
+      // Perform top color coarse palette search to find the winner candidate
+      const int top_color_winner = perform_top_color_coarse_palette_search(
+          cpi, x, mbmi, bsize, dc_mode_cost, data, top_colors, start_n, end_n,
+          step_size, color_cache, n_cache, best_mbmi, best_palette_color_map,
+          best_rd, best_model_rd, rate, rate_tokenonly, distortion, skippable,
+          beat_best_rd, ctx, best_blk_skip, tx_type_map);
+      // Evaluate neighbors for the winner color (if winner is found) in the
+      // above coarse search for dominant colors
+      if (top_color_winner <= end_n) {
+        int start_n_stage2, end_n_stage2, step_size_stage2;
+        update_start_end_stage_2(&start_n_stage2, &end_n_stage2,
+                                 &step_size_stage2, top_color_winner, end_n);
+        // perform finer search for the winner candidate
+        perform_top_color_palette_search(
+            cpi, x, mbmi, bsize, dc_mode_cost, data, top_colors, start_n_stage2,
+            end_n_stage2 + step_size_stage2, step_size_stage2, color_cache,
+            n_cache, best_mbmi, best_palette_color_map, best_rd, best_model_rd,
+            rate, rate_tokenonly, distortion, skippable, beat_best_rd, ctx,
+            best_blk_skip, tx_type_map);
+      }
+      // K-means clustering.
+      // Perform k-means coarse palette search to find the winner candidate
+      const int k_means_winner = perform_k_means_coarse_palette_search(
+          cpi, x, mbmi, bsize, dc_mode_cost, data, lb, ub, start_n, end_n,
+          step_size, color_cache, n_cache, best_mbmi, best_palette_color_map,
+          best_rd, best_model_rd, rate, rate_tokenonly, distortion, skippable,
           beat_best_rd, ctx, best_blk_skip, tx_type_map, color_map,
           rows * cols);
-      if (k_means_winner > end_n) {
-        // Perform k-means palette search in reverse order for the remaining
-        // colors
+      // Evaluate neighbors for the winner color (if winner is found) in the
+      // above coarse search for k-means
+      if (k_means_winner <= end_n) {
+        int start_n_stage2, end_n_stage2, step_size_stage2;
+        update_start_end_stage_2(&start_n_stage2, &end_n_stage2,
+                                 &step_size_stage2, k_means_winner, end_n);
+        // perform finer search for the winner candidate
         perform_k_means_palette_search(
-            cpi, x, mbmi, bsize, dc_mode_cost, data, lb, ub, end_n,
-            k_means_winner, 1, color_cache, n_cache, best_mbmi,
+            cpi, x, mbmi, bsize, dc_mode_cost, data, lb, ub, start_n_stage2,
+            end_n_stage2 + step_size_stage2, step_size_stage2, color_cache,
+            n_cache, best_mbmi, best_palette_color_map, best_rd, best_model_rd,
+            rate, rate_tokenonly, distortion, skippable, beat_best_rd, ctx,
+            best_blk_skip, tx_type_map, color_map, rows * cols);
+      }
+    } else {
+      const int start_n = AOMMIN(colors, PALETTE_MAX_SIZE),
+                end_n = PALETTE_MIN_SIZE;
+      // Perform top color palette search from start_n
+      const int top_color_winner = perform_top_color_palette_search(
+          cpi, x, mbmi, bsize, dc_mode_cost, data, top_colors, start_n,
+          end_n - 1, -1, color_cache, n_cache, best_mbmi,
+          best_palette_color_map, best_rd, best_model_rd, rate, rate_tokenonly,
+          distortion, skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map);
+
+      if (top_color_winner > end_n) {
+        // Perform top color palette search in reverse order for the remaining
+        // colors
+        perform_top_color_palette_search(
+            cpi, x, mbmi, bsize, dc_mode_cost, data, top_colors, end_n,
+            top_color_winner, 1, color_cache, n_cache, best_mbmi,
             best_palette_color_map, best_rd, best_model_rd, rate,
             rate_tokenonly, distortion, skippable, beat_best_rd, ctx,
-            best_blk_skip, tx_type_map, color_map, rows * cols);
+            best_blk_skip, tx_type_map);
+      }
+      // K-means clustering.
+      if (colors == PALETTE_MIN_SIZE) {
+        // Special case: These colors automatically become the centroids.
+        assert(colors == 2);
+        centroids[0] = lb;
+        centroids[1] = ub;
+        palette_rd_y(cpi, x, mbmi, bsize, dc_mode_cost, data, centroids, colors,
+                     color_cache, n_cache, best_mbmi, best_palette_color_map,
+                     best_rd, best_model_rd, rate, rate_tokenonly, distortion,
+                     skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map,
+                     NULL);
+      } else {
+        // Perform k-means palette search from start_n
+        const int k_means_winner = perform_k_means_palette_search(
+            cpi, x, mbmi, bsize, dc_mode_cost, data, lb, ub, start_n, end_n - 1,
+            -1, color_cache, n_cache, best_mbmi, best_palette_color_map,
+            best_rd, best_model_rd, rate, rate_tokenonly, distortion, skippable,
+            beat_best_rd, ctx, best_blk_skip, tx_type_map, color_map,
+            rows * cols);
+        if (k_means_winner > end_n) {
+          // Perform k-means palette search in reverse order for the remaining
+          // colors
+          perform_k_means_palette_search(
+              cpi, x, mbmi, bsize, dc_mode_cost, data, lb, ub, end_n,
+              k_means_winner, 1, color_cache, n_cache, best_mbmi,
+              best_palette_color_map, best_rd, best_model_rd, rate,
+              rate_tokenonly, distortion, skippable, beat_best_rd, ctx,
+              best_blk_skip, tx_type_map, color_map, rows * cols);
+        }
       }
     }
   }
diff --git a/av1/encoder/speed_features.c b/av1/encoder/speed_features.c
index 71c2b0d..6569085 100644
--- a/av1/encoder/speed_features.c
+++ b/av1/encoder/speed_features.c
@@ -343,6 +343,7 @@
     sf->reduce_inter_modes = boosted ? 1 : 2;
     sf->tx_type_search.prune_mode = PRUNE_2D_FAST;
     sf->prune_comp_type_by_model_rd = boosted ? 0 : 1;
+    sf->prune_palette_search_level = 1;
   }
 
   if (speed >= 2) {
@@ -417,9 +418,7 @@
                                                                          : 2;
     sf->tx_type_search.use_skip_flag_prediction =
         cm->allow_screen_content_tools ? 1 : 2;
-    // TODO(any): Experiment with binary search and extend for all frame types
-    // and speed = 1 and 2
-    sf->prune_palette_search_level = 1;
+    sf->prune_palette_search_level = 2;
   }
 
   if (speed >= 4) {
diff --git a/av1/encoder/speed_features.h b/av1/encoder/speed_features.h
index 88e1d03..eac4697 100644
--- a/av1/encoder/speed_features.h
+++ b/av1/encoder/speed_features.h
@@ -874,7 +874,9 @@
 
   // prune palette search
   // 0: No pruning
-  // 1: Perform 2 way palette search from max colors to min colors (and min
+  // 1: Perform coarse search to prune the palette colors. For winner colors,
+  // neighbors are also evaluated using a finer search.
+  // 2: Perform 2 way palette search from max colors to min colors (and min
   // colors to remaining colors) and terminate the search if current number of
   // palette colors is not the winner.
   int prune_palette_search_level;