Introduce early_term_luma_palette_size_search speed feature

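The speed feature (sf) early_term_luma_palette_size_search is
introduced to terminate the rd evaluation of luma palette_size early,
based on the already available header/mode cost, during the coarse
palette_size search when prune_palette_search_level = 1. Once the
header rd cost of a palette_size alone exceeds best_rd, the remaining
larger palette sizes of that coarse search are not evaluated. This sf
is enabled for speeds 1 and 2 in allintra encoding mode.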

For allintra video encoding (on the screen content set):

          Instruction Count        BD-Rate Loss(%)
cpu-used     Reduction(%)   avg.psnr  ovr.psnr    ssim
   1           1.069        0.0037    0.0036      0.0089
   2           2.765        0.0096    0.0091      0.0016

For AVIF still-image encoding:

          Instruction Count    BD-Rate Loss(%)
cpu-used     Reduction(%)      psnr       ssim
   1           1.134           0.0014     0.0002
   2           1.291           0.0012     0.0002

BUG=aomedia:2959

STATS_CHANGED

Change-Id: I2d696b5e58ee38867e50acf48a235cb9a54d972f
diff --git a/av1/encoder/palette.c b/av1/encoder/palette.c
index 9078b5c..04d2390 100644
--- a/av1/encoder/palette.c
+++ b/av1/encoder/palette.c
@@ -217,11 +217,13 @@
 static AOM_INLINE void palette_rd_y(
     const AV1_COMP *const cpi, MACROBLOCK *x, MB_MODE_INFO *mbmi,
     BLOCK_SIZE bsize, int dc_mode_cost, const int *data, int *centroids, int n,
-    uint16_t *color_cache, int n_cache, MB_MODE_INFO *best_mbmi,
-    uint8_t *best_palette_color_map, int64_t *best_rd, int *rate,
-    int *rate_tokenonly, int64_t *distortion, int *skippable, int *beat_best_rd,
-    PICK_MODE_CONTEXT *ctx, uint8_t *blk_skip, uint8_t *tx_type_map,
-    int *beat_best_palette_rd) {
+    uint16_t *color_cache, int n_cache, bool do_header_rd_based_gating,
+    MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map, int64_t *best_rd,
+    int *rate, int *rate_tokenonly, int64_t *distortion, int *skippable,
+    int *beat_best_rd, PICK_MODE_CONTEXT *ctx, uint8_t *blk_skip,
+    uint8_t *tx_type_map, int *beat_best_palette_rd,
+    bool *do_header_rd_based_breakout) {
+  if (do_header_rd_based_breakout != NULL) *do_header_rd_based_breakout = false;
   optimize_palette_colors(color_cache, n_cache, n, 1, centroids,
                           cpi->common.seq_params->bit_depth);
   const int num_unique_colors = av1_remove_duplicates(centroids, n);
@@ -252,12 +254,31 @@
   extend_palette_color_map(color_map, cols, rows, block_width, block_height);
 
   RD_STATS tokenonly_rd_stats;
-  av1_pick_uniform_tx_size_type_yrd(cpi, x, &tokenonly_rd_stats, bsize,
-                                    *best_rd);
-  if (tokenonly_rd_stats.rate == INT_MAX) return;
-  const int palette_mode_cost =
-      intra_mode_info_cost_y(cpi, x, mbmi, bsize, dc_mode_cost);
-  int this_rate = tokenonly_rd_stats.rate + palette_mode_cost;
+  int this_rate;
+
+  if (do_header_rd_based_gating) {
+    assert(do_header_rd_based_breakout != NULL);
+    const int palette_mode_rate =
+        intra_mode_info_cost_y(cpi, x, mbmi, bsize, dc_mode_cost);
+    const int64_t header_rd = RDCOST(x->rdmult, palette_mode_rate, 0);
+    // Terminate further palette_size search, if the header cost corresponding
+    // to lower palette_size is more than best_rd.
+    if (header_rd > *best_rd) {
+      *do_header_rd_based_breakout = true;
+      return;
+    }
+    av1_pick_uniform_tx_size_type_yrd(cpi, x, &tokenonly_rd_stats, bsize,
+                                      *best_rd);
+    if (tokenonly_rd_stats.rate == INT_MAX) return;
+    this_rate = tokenonly_rd_stats.rate + palette_mode_rate;
+  } else {
+    av1_pick_uniform_tx_size_type_yrd(cpi, x, &tokenonly_rd_stats, bsize,
+                                      *best_rd);
+    if (tokenonly_rd_stats.rate == INT_MAX) return;
+    this_rate = tokenonly_rd_stats.rate +
+                intra_mode_info_cost_y(cpi, x, mbmi, bsize, dc_mode_cost);
+  }
+
   int64_t this_rd = RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
   if (!xd->lossless[mbmi->segment_id] && block_signals_txsize(mbmi->bsize)) {
     tokenonly_rd_stats.rate -= tx_size_cost(x, bsize, mbmi->tx_size);
@@ -298,11 +319,12 @@
 static AOM_INLINE int perform_top_color_palette_search(
     const AV1_COMP *const cpi, MACROBLOCK *x, MB_MODE_INFO *mbmi,
     BLOCK_SIZE bsize, int dc_mode_cost, const int *data, int *top_colors,
-    int start_n, int end_n, int step_size, int *last_n_searched,
-    uint16_t *color_cache, int n_cache, MB_MODE_INFO *best_mbmi,
-    uint8_t *best_palette_color_map, int64_t *best_rd, int *rate,
-    int *rate_tokenonly, int64_t *distortion, int *skippable, int *beat_best_rd,
-    PICK_MODE_CONTEXT *ctx, uint8_t *best_blk_skip, uint8_t *tx_type_map) {
+    int start_n, int end_n, int step_size, bool do_header_rd_based_gating,
+    int *last_n_searched, uint16_t *color_cache, int n_cache,
+    MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map, int64_t *best_rd,
+    int *rate, int *rate_tokenonly, int64_t *distortion, int *skippable,
+    int *beat_best_rd, PICK_MODE_CONTEXT *ctx, uint8_t *best_blk_skip,
+    uint8_t *tx_type_map) {
   int centroids[PALETTE_MAX_SIZE];
   int n = start_n;
   int top_color_winner = end_n;
@@ -312,13 +334,16 @@
   assert(IMPLIES(step_size > 0, start_n < end_n));
   while (!is_iter_over(n, end_n, step_size)) {
     int beat_best_palette_rd = 0;
+    bool do_header_rd_based_breakout = false;
     memcpy(centroids, top_colors, n * sizeof(top_colors[0]));
     palette_rd_y(cpi, x, mbmi, bsize, dc_mode_cost, data, centroids, n,
-                 color_cache, n_cache, best_mbmi, best_palette_color_map,
-                 best_rd, rate, rate_tokenonly, distortion, skippable,
-                 beat_best_rd, ctx, best_blk_skip, tx_type_map,
-                 &beat_best_palette_rd);
+                 color_cache, n_cache, do_header_rd_based_gating, best_mbmi,
+                 best_palette_color_map, best_rd, rate, rate_tokenonly,
+                 distortion, skippable, beat_best_rd, ctx, best_blk_skip,
+                 tx_type_map, &beat_best_palette_rd,
+                 &do_header_rd_based_breakout);
     *last_n_searched = n;
+    if (do_header_rd_based_breakout) break;
     if (beat_best_palette_rd) {
       top_color_winner = n;
     } else if (cpi->sf.intra_sf.prune_palette_search_level == 2) {
@@ -338,11 +363,12 @@
     const AV1_COMP *const cpi, MACROBLOCK *x, MB_MODE_INFO *mbmi,
     BLOCK_SIZE bsize, int dc_mode_cost, const int *data, int lower_bound,
     int upper_bound, int start_n, int end_n, int step_size,
-    int *last_n_searched, uint16_t *color_cache, int n_cache,
-    MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map, int64_t *best_rd,
-    int *rate, int *rate_tokenonly, int64_t *distortion, int *skippable,
-    int *beat_best_rd, PICK_MODE_CONTEXT *ctx, uint8_t *best_blk_skip,
-    uint8_t *tx_type_map, uint8_t *color_map, int data_points) {
+    bool do_header_rd_based_gating, int *last_n_searched, uint16_t *color_cache,
+    int n_cache, MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map,
+    int64_t *best_rd, int *rate, int *rate_tokenonly, int64_t *distortion,
+    int *skippable, int *beat_best_rd, PICK_MODE_CONTEXT *ctx,
+    uint8_t *best_blk_skip, uint8_t *tx_type_map, uint8_t *color_map,
+    int data_points) {
   int centroids[PALETTE_MAX_SIZE];
   const int max_itr = 50;
   int n = start_n;
@@ -353,17 +379,20 @@
   assert(IMPLIES(step_size > 0, start_n < end_n));
   while (!is_iter_over(n, end_n, step_size)) {
     int beat_best_palette_rd = 0;
+    bool do_header_rd_based_breakout = false;
     for (int i = 0; i < n; ++i) {
       centroids[i] =
           lower_bound + (2 * i + 1) * (upper_bound - lower_bound) / n / 2;
     }
     av1_k_means(data, centroids, color_map, data_points, n, 1, max_itr);
     palette_rd_y(cpi, x, mbmi, bsize, dc_mode_cost, data, centroids, n,
-                 color_cache, n_cache, best_mbmi, best_palette_color_map,
-                 best_rd, rate, rate_tokenonly, distortion, skippable,
-                 beat_best_rd, ctx, best_blk_skip, tx_type_map,
-                 &beat_best_palette_rd);
+                 color_cache, n_cache, do_header_rd_based_gating, best_mbmi,
+                 best_palette_color_map, best_rd, rate, rate_tokenonly,
+                 distortion, skippable, beat_best_rd, ctx, best_blk_skip,
+                 tx_type_map, &beat_best_palette_rd,
+                 &do_header_rd_based_breakout);
     *last_n_searched = n;
+    if (do_header_rd_based_breakout) break;
     if (beat_best_palette_rd) {
       top_color_winner = n;
     } else if (cpi->sf.intra_sf.prune_palette_search_level == 2) {
@@ -521,13 +550,18 @@
       const int min_n = start_n_lookup_table[max_n];
       const int step_size = step_size_lookup_table[max_n];
       assert(min_n >= PALETTE_MIN_SIZE);
+      // Header rdcost based early gating is currently enabled only for coarse
+      // palette size search. For all other cases, the do_header_rd_based_gating
+      // is explicitly passed as 'false'.
+      const bool do_header_rd_based_gating =
+          cpi->sf.intra_sf.early_term_luma_palette_size_search != 0;
 
       // Perform top color coarse palette search to find the winner candidate
       const int top_color_winner = perform_top_color_palette_search(
           cpi, x, mbmi, bsize, dc_mode_cost, data, top_colors, min_n, max_n + 1,
-          step_size, &unused, color_cache, n_cache, best_mbmi,
-          best_palette_color_map, best_rd, rate, rate_tokenonly, distortion,
-          skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map);
+          step_size, do_header_rd_based_gating, &unused, color_cache, n_cache,
+          best_mbmi, best_palette_color_map, best_rd, rate, rate_tokenonly,
+          distortion, skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map);
       // Evaluate neighbors for the winner color (if winner is found) in the
       // above coarse search for dominant colors
       if (top_color_winner <= max_n) {
@@ -537,7 +571,8 @@
         // perform finer search for the winner candidate
         perform_top_color_palette_search(
             cpi, x, mbmi, bsize, dc_mode_cost, data, top_colors, stage2_min_n,
-            stage2_max_n + 1, stage2_step_size, &unused, color_cache, n_cache,
+            stage2_max_n + 1, stage2_step_size,
+            /*do_header_rd_based_gating=*/false, &unused, color_cache, n_cache,
             best_mbmi, best_palette_color_map, best_rd, rate, rate_tokenonly,
             distortion, skippable, beat_best_rd, ctx, best_blk_skip,
             tx_type_map);
@@ -546,10 +581,10 @@
       // Perform k-means coarse palette search to find the winner candidate
       const int k_means_winner = perform_k_means_palette_search(
           cpi, x, mbmi, bsize, dc_mode_cost, data, lower_bound, upper_bound,
-          min_n, max_n + 1, step_size, &unused, color_cache, n_cache, best_mbmi,
-          best_palette_color_map, best_rd, rate, rate_tokenonly, distortion,
-          skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map, color_map,
-          rows * cols);
+          min_n, max_n + 1, step_size, do_header_rd_based_gating, &unused,
+          color_cache, n_cache, best_mbmi, best_palette_color_map, best_rd,
+          rate, rate_tokenonly, distortion, skippable, beat_best_rd, ctx,
+          best_blk_skip, tx_type_map, color_map, rows * cols);
       // Evaluate neighbors for the winner color (if winner is found) in the
       // above coarse search for k-means
       if (k_means_winner <= max_n) {
@@ -559,10 +594,11 @@
         // perform finer search for the winner candidate
         perform_k_means_palette_search(
             cpi, x, mbmi, bsize, dc_mode_cost, data, lower_bound, upper_bound,
-            start_n_stage2, end_n_stage2 + 1, step_size_stage2, &unused,
-            color_cache, n_cache, best_mbmi, best_palette_color_map, best_rd,
-            rate, rate_tokenonly, distortion, skippable, beat_best_rd, ctx,
-            best_blk_skip, tx_type_map, color_map, rows * cols);
+            start_n_stage2, end_n_stage2 + 1, step_size_stage2,
+            /*do_header_rd_based_gating=*/false, &unused, color_cache, n_cache,
+            best_mbmi, best_palette_color_map, best_rd, rate, rate_tokenonly,
+            distortion, skippable, beat_best_rd, ctx, best_blk_skip,
+            tx_type_map, color_map, rows * cols);
       }
     } else {
       const int max_n = AOMMIN(colors, PALETTE_MAX_SIZE),
@@ -571,17 +607,19 @@
       int last_n_searched = max_n;
       perform_top_color_palette_search(
           cpi, x, mbmi, bsize, dc_mode_cost, data, top_colors, max_n, min_n - 1,
-          -1, &last_n_searched, color_cache, n_cache, best_mbmi,
-          best_palette_color_map, best_rd, rate, rate_tokenonly, distortion,
-          skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map);
+          -1, /*do_header_rd_based_gating=*/false, &last_n_searched,
+          color_cache, n_cache, best_mbmi, best_palette_color_map, best_rd,
+          rate, rate_tokenonly, distortion, skippable, beat_best_rd, ctx,
+          best_blk_skip, tx_type_map);
 
       if (last_n_searched > min_n) {
         // Search in ascending order until we get to the previous best
         perform_top_color_palette_search(
             cpi, x, mbmi, bsize, dc_mode_cost, data, top_colors, min_n,
-            last_n_searched, 1, &unused, color_cache, n_cache, best_mbmi,
-            best_palette_color_map, best_rd, rate, rate_tokenonly, distortion,
-            skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map);
+            last_n_searched, 1, /*do_header_rd_based_gating=*/false, &unused,
+            color_cache, n_cache, best_mbmi, best_palette_color_map, best_rd,
+            rate, rate_tokenonly, distortion, skippable, beat_best_rd, ctx,
+            best_blk_skip, tx_type_map);
       }
       // K-means clustering.
       if (colors == PALETTE_MIN_SIZE) {
@@ -590,26 +628,29 @@
         centroids[0] = lower_bound;
         centroids[1] = upper_bound;
         palette_rd_y(cpi, x, mbmi, bsize, dc_mode_cost, data, centroids, colors,
-                     color_cache, n_cache, best_mbmi, best_palette_color_map,
-                     best_rd, rate, rate_tokenonly, distortion, skippable,
-                     beat_best_rd, ctx, best_blk_skip, tx_type_map, NULL);
+                     color_cache, n_cache, /*do_header_rd_based_gating=*/false,
+                     best_mbmi, best_palette_color_map, best_rd, rate,
+                     rate_tokenonly, distortion, skippable, beat_best_rd, ctx,
+                     best_blk_skip, tx_type_map, NULL, NULL);
       } else {
         // Perform k-means palette search in descending order
         last_n_searched = max_n;
         perform_k_means_palette_search(
             cpi, x, mbmi, bsize, dc_mode_cost, data, lower_bound, upper_bound,
-            max_n, min_n - 1, -1, &last_n_searched, color_cache, n_cache,
-            best_mbmi, best_palette_color_map, best_rd, rate, rate_tokenonly,
-            distortion, skippable, beat_best_rd, ctx, best_blk_skip,
-            tx_type_map, color_map, rows * cols);
+            max_n, min_n - 1, -1, /*do_header_rd_based_gating=*/false,
+            &last_n_searched, color_cache, n_cache, best_mbmi,
+            best_palette_color_map, best_rd, rate, rate_tokenonly, distortion,
+            skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map, color_map,
+            rows * cols);
         if (last_n_searched > min_n) {
           // Search in ascending order until we get to the previous best
           perform_k_means_palette_search(
               cpi, x, mbmi, bsize, dc_mode_cost, data, lower_bound, upper_bound,
-              min_n, last_n_searched, 1, &unused, color_cache, n_cache,
-              best_mbmi, best_palette_color_map, best_rd, rate, rate_tokenonly,
-              distortion, skippable, beat_best_rd, ctx, best_blk_skip,
-              tx_type_map, color_map, rows * cols);
+              min_n, last_n_searched, 1, /*do_header_rd_based_gating=*/false,
+              &unused, color_cache, n_cache, best_mbmi, best_palette_color_map,
+              best_rd, rate, rate_tokenonly, distortion, skippable,
+              beat_best_rd, ctx, best_blk_skip, tx_type_map, color_map,
+              rows * cols);
         }
       }
     }
diff --git a/av1/encoder/speed_features.c b/av1/encoder/speed_features.c
index b39bf33..3dc91a5 100644
--- a/av1/encoder/speed_features.c
+++ b/av1/encoder/speed_features.c
@@ -345,6 +345,7 @@
     sf->mv_sf.exhaustive_searches_thresh <<= 1;
 
     sf->intra_sf.prune_palette_search_level = 1;
+    sf->intra_sf.early_term_luma_palette_size_search = 1;
     sf->intra_sf.top_intra_model_count_allowed = 3;
 
     sf->tx_sf.adaptive_txb_search_level = 2;
@@ -1692,6 +1693,7 @@
   intra_sf->intra_pruning_with_hog = 0;
   intra_sf->chroma_intra_pruning_with_hog = 0;
   intra_sf->prune_palette_search_level = 0;
+  intra_sf->early_term_luma_palette_size_search = 0;
 
   for (int i = 0; i < TX_SIZES; i++) {
     intra_sf->intra_y_mode_mask[i] = INTRA_ALL;
diff --git a/av1/encoder/speed_features.h b/av1/encoder/speed_features.h
index ef9238b..a2eab3d 100644
--- a/av1/encoder/speed_features.h
+++ b/av1/encoder/speed_features.h
@@ -936,6 +936,17 @@
   // palette colors is not the winner.
   int prune_palette_search_level;
 
+  // Terminate early in luma palette_size search.
+  // 0: No early termination
+  // 1: Skip the remaining (higher) luma palette_size candidates once the
+  // header rd cost of the current (lower) palette_size exceeds best_rd.
+  // For allintra encode, this sf reduces instruction count by 1.07% and 2.76%
+  // for speed 1 and 2 on screen content set with coding performance change less
+  // than 0.01%. For AVIF image encode, this sf reduces instruction count
+  // by 1.13% and 1.29% for speed 1 and 2 on a typical image dataset with coding
+  // performance change less than 0.01%.
+  int early_term_luma_palette_size_search;
+
   // Prune chroma intra modes based on luma intra mode winner.
   // 0: No pruning
   // 1: Prune chroma intra modes other than UV_DC_PRED, UV_SMOOTH_PRED,