Move intra related elements in InterModeSearchState to substruct

All intra related elements inside InterModeSearchState moved
to a different sub-structure IntraModeSearchState.

Change-Id: I128bbeeac985f7b231ae75fab3f0cb8962fcc779
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index ea1cec4..4c13788 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -502,24 +502,11 @@
   int valid;
 } SingleInterModeState;
 
-typedef struct InterModeSearchState {
-  int64_t best_rd;
-  MB_MODE_INFO best_mbmode;
-  int best_rate_y;
-  int best_rate_uv;
-  int best_mode_skippable;
-  int best_skip2;
-  THR_MODES best_mode_index;
+typedef struct IntraModeSearchState {
   int skip_intra_modes;
-  int num_available_refs;
-  int64_t dist_refs[REF_FRAMES];
-  int dist_order_refs[REF_FRAMES];
-  int64_t mode_threshold[MAX_MODES];
   PREDICTION_MODE best_intra_mode;
-  int64_t best_intra_rd;
   int angle_stats_ready;
   uint8_t directional_mode_skip_mask[INTRA_MODES];
-  unsigned int best_pred_sse;
   int rate_uv_intra;
   int rate_uv_tokenonly;
   int64_t dist_uvs;
@@ -528,6 +515,22 @@
   PALETTE_MODE_INFO pmi_uv;
   int8_t uv_angle_delta;
   int64_t best_pred_rd[REFERENCE_MODES];
+} IntraModeSearchState;
+
+typedef struct InterModeSearchState {
+  int64_t best_rd;
+  MB_MODE_INFO best_mbmode;
+  int best_rate_y;
+  int best_rate_uv;
+  int best_mode_skippable;
+  int best_skip2;
+  THR_MODES best_mode_index;
+  int num_available_refs;
+  int64_t dist_refs[REF_FRAMES];
+  int dist_order_refs[REF_FRAMES];
+  int64_t mode_threshold[MAX_MODES];
+  int64_t best_intra_rd;
+  unsigned int best_pred_sse;
   int64_t best_pred_diff[REFERENCE_MODES];
   // Save a set of single_newmv for each checked ref_mv.
   int_mv single_newmv[MAX_REF_MV_SEARCH][REF_FRAMES];
@@ -544,6 +547,7 @@
                                             [FWD_REFS];
   int single_state_modelled_cnt[2][SINGLE_INTER_MODE_NUM];
   MV_REFERENCE_FRAME single_rd_order[2][SINGLE_INTER_MODE_NUM][FWD_REFS];
+  IntraModeSearchState intra_search_state;
 } InterModeSearchState;
 
 void av1_inter_mode_data_init(TileDataEnc *tile_data) {
@@ -5760,16 +5764,16 @@
   x->comp_rd_stats_idx = 0;
 }
 
-static AOM_INLINE void search_palette_mode(
-    const AV1_COMP *cpi, MACROBLOCK *x, RD_STATS *rd_cost,
+static AOM_INLINE int search_palette_mode(
+    const AV1_COMP *cpi, MACROBLOCK *x, RD_STATS *this_rd_cost,
     PICK_MODE_CONTEXT *ctx, BLOCK_SIZE bsize, MB_MODE_INFO *const mbmi,
     PALETTE_MODE_INFO *const pmi, unsigned int *ref_costs_single,
-    InterModeSearchState *search_state) {
+    IntraModeSearchState *intra_search_state, int64_t best_rd) {
   const AV1_COMMON *const cm = &cpi->common;
   const int num_planes = av1_num_planes(cm);
   MACROBLOCKD *const xd = &x->e_mbd;
   int rate2 = 0;
-  int64_t distortion2 = 0, best_rd_palette = search_state->best_rd, this_rd,
+  int64_t distortion2 = 0, best_rd_palette = best_rd, this_rd,
           best_model_rd_palette = INT64_MAX;
   int skippable = 0;
   TX_SIZE uv_tx = TX_4X4;
@@ -5794,7 +5798,10 @@
       best_palette_color_map, &best_rd_palette, &best_model_rd_palette,
       &rd_stats_y.rate, NULL, &rd_stats_y.dist, &rd_stats_y.skip, NULL, ctx,
       best_blk_skip, best_tx_type_map);
-  if (rd_stats_y.rate == INT_MAX || pmi->palette_size[0] == 0) return;
+  if (rd_stats_y.rate == INT_MAX || pmi->palette_size[0] == 0) {
+    this_rd_cost->rdcost = INT64_MAX;
+    return skippable;
+  }
 
   memcpy(x->blk_skip, best_blk_skip,
          sizeof(best_blk_skip[0]) * bsize_to_num_blk(bsize));
@@ -5807,54 +5814,58 @@
   rate2 = rd_stats_y.rate + ref_costs_single[INTRA_FRAME];
   if (num_planes > 1) {
     uv_tx = av1_get_tx_size(AOM_PLANE_U, xd);
-    if (search_state->rate_uv_intra == INT_MAX) {
-      choose_intra_uv_mode(cpi, x, bsize, uv_tx, &search_state->rate_uv_intra,
-                           &search_state->rate_uv_tokenonly,
-                           &search_state->dist_uvs, &search_state->skip_uvs,
-                           &search_state->mode_uv);
-      search_state->pmi_uv = *pmi;
-      search_state->uv_angle_delta = mbmi->angle_delta[PLANE_TYPE_UV];
+    if (intra_search_state->rate_uv_intra == INT_MAX) {
+      choose_intra_uv_mode(
+          cpi, x, bsize, uv_tx, &intra_search_state->rate_uv_intra,
+          &intra_search_state->rate_uv_tokenonly, &intra_search_state->dist_uvs,
+          &intra_search_state->skip_uvs, &intra_search_state->mode_uv);
+      intra_search_state->pmi_uv = *pmi;
+      intra_search_state->uv_angle_delta = mbmi->angle_delta[PLANE_TYPE_UV];
     }
-    mbmi->uv_mode = search_state->mode_uv;
-    pmi->palette_size[1] = search_state->pmi_uv.palette_size[1];
+    mbmi->uv_mode = intra_search_state->mode_uv;
+    pmi->palette_size[1] = intra_search_state->pmi_uv.palette_size[1];
     if (pmi->palette_size[1] > 0) {
       memcpy(pmi->palette_colors + PALETTE_MAX_SIZE,
-             search_state->pmi_uv.palette_colors + PALETTE_MAX_SIZE,
+             intra_search_state->pmi_uv.palette_colors + PALETTE_MAX_SIZE,
              2 * PALETTE_MAX_SIZE * sizeof(pmi->palette_colors[0]));
     }
-    mbmi->angle_delta[PLANE_TYPE_UV] = search_state->uv_angle_delta;
-    skippable = skippable && search_state->skip_uvs;
-    distortion2 += search_state->dist_uvs;
-    rate2 += search_state->rate_uv_intra;
+    mbmi->angle_delta[PLANE_TYPE_UV] = intra_search_state->uv_angle_delta;
+    skippable = skippable && intra_search_state->skip_uvs;
+    distortion2 += intra_search_state->dist_uvs;
+    rate2 += intra_search_state->rate_uv_intra;
   }
 
   if (skippable) {
     rate2 -= rd_stats_y.rate;
-    if (num_planes > 1) rate2 -= search_state->rate_uv_tokenonly;
+    if (num_planes > 1) rate2 -= intra_search_state->rate_uv_tokenonly;
     rate2 += x->skip_cost[av1_get_skip_context(xd)][1];
   } else {
     rate2 += x->skip_cost[av1_get_skip_context(xd)][0];
   }
   this_rd = RDCOST(x->rdmult, rate2, distortion2);
-  if (this_rd < search_state->best_rd) {
-    search_state->best_mode_index = THR_DC;
-    mbmi->mv[0].as_int = 0;
-    rd_cost->rate = rate2;
-    rd_cost->dist = distortion2;
-    rd_cost->rdcost = this_rd;
-    search_state->best_rd = this_rd;
-    search_state->best_mbmode = *mbmi;
-    search_state->best_skip2 = 0;
-    search_state->best_mode_skippable = skippable;
-    memcpy(ctx->blk_skip, x->blk_skip,
-           sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
-    av1_copy_array(ctx->tx_type_map, xd->tx_type_map, ctx->num_4x4_blk);
-  }
+  this_rd_cost->rate = rate2;
+  this_rd_cost->dist = distortion2;
+  this_rd_cost->rdcost = this_rd;
+  return skippable;
+}
+
+static AOM_INLINE void init_intra_mode_search_state(
+    IntraModeSearchState *intra_search_state) {
+  intra_search_state->skip_intra_modes = 0;
+  intra_search_state->best_intra_mode = DC_PRED;
+  intra_search_state->angle_stats_ready = 0;
+  av1_zero(intra_search_state->directional_mode_skip_mask);
+  intra_search_state->rate_uv_intra = INT_MAX;
+  av1_zero(intra_search_state->pmi_uv);
+  for (int i = 0; i < REFERENCE_MODES; ++i)
+    intra_search_state->best_pred_rd[i] = INT64_MAX;
 }
 
 static AOM_INLINE void init_inter_mode_search_state(
     InterModeSearchState *search_state, const AV1_COMP *cpi,
     const MACROBLOCK *x, BLOCK_SIZE bsize, int64_t best_rd_so_far) {
+  init_intra_mode_search_state(&search_state->intra_search_state);
+
   search_state->best_rd = best_rd_so_far;
 
   av1_zero(search_state->best_mbmode);
@@ -5873,8 +5884,6 @@
   const MB_MODE_INFO *const mbmi = xd->mi[0];
   const unsigned char segment_id = mbmi->segment_id;
 
-  search_state->skip_intra_modes = 0;
-
   search_state->num_available_refs = 0;
   memset(search_state->dist_refs, -1, sizeof(search_state->dist_refs));
   memset(search_state->dist_order_refs, -1,
@@ -5888,19 +5897,9 @@
         ((int64_t)rd_threshes[i] * x->thresh_freq_fact[bsize][i]) >>
         RD_THRESH_FAC_FRAC_BITS;
 
-  search_state->best_intra_mode = DC_PRED;
   search_state->best_intra_rd = INT64_MAX;
 
-  search_state->angle_stats_ready = 0;
-  av1_zero(search_state->directional_mode_skip_mask);
-
   search_state->best_pred_sse = UINT_MAX;
-  search_state->rate_uv_intra = INT_MAX;
-
-  av1_zero(search_state->pmi_uv);
-
-  for (int i = 0; i < REFERENCE_MODES; ++i)
-    search_state->best_pred_rd[i] = INT64_MAX;
 
   av1_zero(search_state->single_newmv);
   av1_zero(search_state->single_newmv_rate);
@@ -6103,12 +6102,14 @@
   set_default_interp_filters(mbmi, cm->interp_filter);
 }
 
-static int64_t handle_intra_mode(InterModeSearchState *search_state,
+static int64_t handle_intra_mode(IntraModeSearchState *intra_search_state,
                                  const AV1_COMP *cpi, MACROBLOCK *x,
                                  BLOCK_SIZE bsize, int ref_frame_cost,
                                  const PICK_MODE_CONTEXT *ctx, int disable_skip,
                                  RD_STATS *rd_stats, RD_STATS *rd_stats_y,
-                                 RD_STATS *rd_stats_uv) {
+                                 RD_STATS *rd_stats_uv, int64_t best_rd,
+                                 int64_t *best_intra_rd,
+                                 int8_t best_mbmode_skip) {
   const AV1_COMMON *cm = &cpi->common;
   const SPEED_FEATURES *const sf = &cpi->sf;
   MACROBLOCKD *const xd = &x->e_mbd;
@@ -6126,8 +6127,8 @@
   if (mode != DC_PRED && mode != PAETH_PRED) known_rate += intra_cost_penalty;
   known_rate += AOMMIN(x->skip_cost[skip_ctx][0], x->skip_cost[skip_ctx][1]);
   const int64_t known_rd = RDCOST(x->rdmult, known_rate, 0);
-  if (known_rd > search_state->best_rd) {
-    search_state->skip_intra_modes = 1;
+  if (known_rd > best_rd) {
+    intra_search_state->skip_intra_modes = 1;
     return INT64_MAX;
   }
 
@@ -6135,24 +6136,24 @@
   if (is_directional_mode && av1_use_angle_delta(bsize) &&
       cpi->oxcf.enable_angle_delta) {
     if (sf->intra_sf.intra_pruning_with_hog &&
-        !search_state->angle_stats_ready) {
+        !intra_search_state->angle_stats_ready) {
       prune_intra_mode_with_hog(x, bsize,
                                 cpi->sf.intra_sf.intra_pruning_with_hog_thresh,
-                                search_state->directional_mode_skip_mask);
-      search_state->angle_stats_ready = 1;
+                                intra_search_state->directional_mode_skip_mask);
+      intra_search_state->angle_stats_ready = 1;
     }
-    if (search_state->directional_mode_skip_mask[mode]) return INT64_MAX;
+    if (intra_search_state->directional_mode_skip_mask[mode]) return INT64_MAX;
     av1_init_rd_stats(rd_stats_y);
     rd_stats_y->rate = INT_MAX;
     int64_t model_rd = INT64_MAX;
     int rate_dummy;
     rd_pick_intra_angle_sby(cpi, x, &rate_dummy, rd_stats_y, bsize, mode_cost,
-                            search_state->best_rd, &model_rd, 0);
+                            best_rd, &model_rd, 0);
 
   } else {
     av1_init_rd_stats(rd_stats_y);
     mbmi->angle_delta[PLANE_TYPE_Y] = 0;
-    super_block_yrd(cpi, x, rd_stats_y, bsize, search_state->best_rd);
+    super_block_yrd(cpi, x, rd_stats_y, bsize, best_rd);
   }
 
   // Pick filter intra modes.
@@ -6163,9 +6164,9 @@
       const int tmp_rate =
           rd_stats_y->rate + x->filter_intra_cost[bsize][0] + mode_cost;
       best_rd_so_far = RDCOST(x->rdmult, tmp_rate, rd_stats_y->dist);
-      try_filter_intra = (best_rd_so_far / 2) <= search_state->best_rd;
+      try_filter_intra = (best_rd_so_far / 2) <= best_rd;
     } else {
-      try_filter_intra = !search_state->best_mbmode.skip;
+      try_filter_intra = !best_mbmode_skip;
     }
 
     if (try_filter_intra) {
@@ -6182,7 +6183,7 @@
       for (FILTER_INTRA_MODE fi_mode = FILTER_DC_PRED;
            fi_mode < FILTER_INTRA_MODES; ++fi_mode) {
         mbmi->filter_intra_mode_info.filter_intra_mode = fi_mode;
-        super_block_yrd(cpi, x, &rd_stats_y_fi, bsize, search_state->best_rd);
+        super_block_yrd(cpi, x, &rd_stats_y_fi, bsize, best_rd);
         if (rd_stats_y_fi.rate == INT_MAX) continue;
         const int this_rate_tmp =
             rd_stats_y_fi.rate +
@@ -6190,8 +6191,7 @@
         const int64_t this_rd_tmp =
             RDCOST(x->rdmult, this_rate_tmp, rd_stats_y_fi.dist);
 
-        if (this_rd_tmp != INT64_MAX &&
-            this_rd_tmp / 2 > search_state->best_rd) {
+        if (this_rd_tmp != INT64_MAX && this_rd_tmp / 2 > best_rd) {
           break;
         }
         if (this_rd_tmp < best_rd_so_far) {
@@ -6233,44 +6233,43 @@
         cpi->oxcf.enable_palette &&
         av1_allow_palette(cm->allow_screen_content_tools, mbmi->sb_type);
     const TX_SIZE uv_tx = av1_get_tx_size(AOM_PLANE_U, xd);
-    if (search_state->rate_uv_intra == INT_MAX) {
+    if (intra_search_state->rate_uv_intra == INT_MAX) {
       const int rate_y =
           rd_stats_y->skip ? x->skip_cost[skip_ctx][1] : rd_stats_y->rate;
       const int64_t rdy =
           RDCOST(x->rdmult, rate_y + mode_cost_y, rd_stats_y->dist);
-      if (search_state->best_rd < (INT64_MAX / 2) &&
-          rdy > (search_state->best_rd + (search_state->best_rd >> 2))) {
-        search_state->skip_intra_modes = 1;
+      if (best_rd < (INT64_MAX / 2) && rdy > (best_rd + (best_rd >> 2))) {
+        intra_search_state->skip_intra_modes = 1;
         return INT64_MAX;
       }
-      choose_intra_uv_mode(cpi, x, bsize, uv_tx, &search_state->rate_uv_intra,
-                           &search_state->rate_uv_tokenonly,
-                           &search_state->dist_uvs, &search_state->skip_uvs,
-                           &search_state->mode_uv);
-      if (try_palette) search_state->pmi_uv = *pmi;
-      search_state->uv_angle_delta = mbmi->angle_delta[PLANE_TYPE_UV];
+      choose_intra_uv_mode(
+          cpi, x, bsize, uv_tx, &intra_search_state->rate_uv_intra,
+          &intra_search_state->rate_uv_tokenonly, &intra_search_state->dist_uvs,
+          &intra_search_state->skip_uvs, &intra_search_state->mode_uv);
+      if (try_palette) intra_search_state->pmi_uv = *pmi;
+      intra_search_state->uv_angle_delta = mbmi->angle_delta[PLANE_TYPE_UV];
 
-      const int uv_rate = search_state->rate_uv_tokenonly;
-      const int64_t uv_dist = search_state->dist_uvs;
+      const int uv_rate = intra_search_state->rate_uv_tokenonly;
+      const int64_t uv_dist = intra_search_state->dist_uvs;
       const int64_t uv_rd = RDCOST(x->rdmult, uv_rate, uv_dist);
-      if (uv_rd > search_state->best_rd) {
-        search_state->skip_intra_modes = 1;
+      if (uv_rd > best_rd) {
+        intra_search_state->skip_intra_modes = 1;
         return INT64_MAX;
       }
     }
 
-    rd_stats_uv->rate = search_state->rate_uv_tokenonly;
-    rd_stats_uv->dist = search_state->dist_uvs;
-    rd_stats_uv->skip = search_state->skip_uvs;
+    rd_stats_uv->rate = intra_search_state->rate_uv_tokenonly;
+    rd_stats_uv->dist = intra_search_state->dist_uvs;
+    rd_stats_uv->skip = intra_search_state->skip_uvs;
     rd_stats->skip = rd_stats_y->skip && rd_stats_uv->skip;
-    mbmi->uv_mode = search_state->mode_uv;
+    mbmi->uv_mode = intra_search_state->mode_uv;
     if (try_palette) {
-      pmi->palette_size[1] = search_state->pmi_uv.palette_size[1];
+      pmi->palette_size[1] = intra_search_state->pmi_uv.palette_size[1];
       memcpy(pmi->palette_colors + PALETTE_MAX_SIZE,
-             search_state->pmi_uv.palette_colors + PALETTE_MAX_SIZE,
+             intra_search_state->pmi_uv.palette_colors + PALETTE_MAX_SIZE,
              2 * PALETTE_MAX_SIZE * sizeof(pmi->palette_colors[0]));
     }
-    mbmi->angle_delta[PLANE_TYPE_UV] = search_state->uv_angle_delta;
+    mbmi->angle_delta[PLANE_TYPE_UV] = intra_search_state->uv_angle_delta;
   }
 
   rd_stats->rate = rd_stats_y->rate + mode_cost_y;
@@ -6300,21 +6299,20 @@
   // Calculate the final RD estimate for this mode.
   const int64_t this_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
   // Keep record of best intra rd
-  if (this_rd < search_state->best_intra_rd) {
-    search_state->best_intra_rd = this_rd;
-    search_state->best_intra_mode = mode;
+  if (this_rd < *best_intra_rd) {
+    *best_intra_rd = this_rd;
+    intra_search_state->best_intra_mode = mode;
   }
 
   if (sf->intra_sf.skip_intra_in_interframe) {
-    if (search_state->best_rd < (INT64_MAX / 2) &&
-        this_rd > (search_state->best_rd + (search_state->best_rd >> 1)))
-      search_state->skip_intra_modes = 1;
+    if (best_rd < (INT64_MAX / 2) && this_rd > (best_rd + (best_rd >> 1)))
+      intra_search_state->skip_intra_modes = 1;
   }
 
   if (!disable_skip) {
     for (int i = 0; i < REFERENCE_MODES; ++i) {
-      search_state->best_pred_rd[i] =
-          AOMMIN(search_state->best_pred_rd[i], this_rd);
+      intra_search_state->best_pred_rd[i] =
+          AOMMIN(intra_search_state->best_pred_rd[i], this_rd);
     }
   }
   return this_rd;
@@ -7045,14 +7043,20 @@
       hybrid_rd = RDCOST(x->rdmult, hybrid_rate, rd_stats.dist);
 
       if (!comp_pred) {
-        if (single_rd < search_state.best_pred_rd[SINGLE_REFERENCE])
-          search_state.best_pred_rd[SINGLE_REFERENCE] = single_rd;
+        if (single_rd <
+            search_state.intra_search_state.best_pred_rd[SINGLE_REFERENCE])
+          search_state.intra_search_state.best_pred_rd[SINGLE_REFERENCE] =
+              single_rd;
       } else {
-        if (single_rd < search_state.best_pred_rd[COMPOUND_REFERENCE])
-          search_state.best_pred_rd[COMPOUND_REFERENCE] = single_rd;
+        if (single_rd <
+            search_state.intra_search_state.best_pred_rd[COMPOUND_REFERENCE])
+          search_state.intra_search_state.best_pred_rd[COMPOUND_REFERENCE] =
+              single_rd;
       }
-      if (hybrid_rd < search_state.best_pred_rd[REFERENCE_MODE_SELECT])
-        search_state.best_pred_rd[REFERENCE_MODE_SELECT] = hybrid_rd;
+      if (hybrid_rd <
+          search_state.intra_search_state.best_pred_rd[REFERENCE_MODE_SELECT])
+        search_state.intra_search_state.best_pred_rd[REFERENCE_MODE_SELECT] =
+            hybrid_rd;
     }
 
     // TODO(anyone): evaluate the quality and speed trade-off of the early
@@ -7180,16 +7184,17 @@
       aom_clear_system_state();
       av1_nn_softmax(scores, probs, 2);
 
-      if (probs[1] > 0.8) search_state.skip_intra_modes = 1;
+      if (probs[1] > 0.8) search_state.intra_search_state.skip_intra_modes = 1;
     } else if ((search_state.best_mbmode.skip) &&
                (sf->intra_sf.skip_intra_in_interframe >= 2)) {
-      search_state.skip_intra_modes = 1;
+      search_state.intra_search_state.skip_intra_modes = 1;
     }
   }
 
   const int intra_ref_frame_cost = ref_costs_single[INTRA_FRAME];
   for (int j = 0; j < intra_mode_num; ++j) {
-    if (sf->intra_sf.skip_intra_in_interframe && search_state.skip_intra_modes)
+    if (sf->intra_sf.skip_intra_in_interframe &&
+        search_state.intra_search_state.skip_intra_modes)
       break;
     const THR_MODES mode_enum = intra_mode_idx_ls[j];
     const MODE_DEFINITION *mode_def = &av1_mode_defs[mode_enum];
@@ -7210,15 +7215,18 @@
           continue;
       }
       if (sf->rt_sf.mode_search_skip_flags & FLAG_SKIP_INTRA_DIRMISMATCH) {
-        if (conditional_skipintra(this_mode, search_state.best_intra_mode))
+        if (conditional_skipintra(
+                this_mode, search_state.intra_search_state.best_intra_mode))
           continue;
       }
     }
 
     RD_STATS intra_rd_stats, intra_rd_stats_y, intra_rd_stats_uv;
     intra_rd_stats.rdcost = handle_intra_mode(
-        &search_state, cpi, x, bsize, intra_ref_frame_cost, ctx, 0,
-        &intra_rd_stats, &intra_rd_stats_y, &intra_rd_stats_uv);
+        &search_state.intra_search_state, cpi, x, bsize, intra_ref_frame_cost,
+        ctx, 0, &intra_rd_stats, &intra_rd_stats_y, &intra_rd_stats_uv,
+        search_state.best_rd, &search_state.best_intra_rd,
+        search_state.best_mbmode.skip);
     // Collect mode stats for multiwinner mode processing
     const int txfm_search_done = 1;
     store_winner_mode_stats(
@@ -7254,9 +7262,26 @@
       av1_allow_palette(cm->allow_screen_content_tools, mbmi->sb_type) &&
       !is_inter_mode(search_state.best_mbmode.mode);
   PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
+  RD_STATS this_rd_cost;
+  int this_skippable = 0;
   if (try_palette) {
-    search_palette_mode(cpi, x, rd_cost, ctx, bsize, mbmi, pmi,
-                        ref_costs_single, &search_state);
+    this_skippable = search_palette_mode(
+        cpi, x, &this_rd_cost, ctx, bsize, mbmi, pmi, ref_costs_single,
+        &search_state.intra_search_state, search_state.best_rd);
+    if (this_rd_cost.rdcost < search_state.best_rd) {
+      search_state.best_mode_index = THR_DC;
+      mbmi->mv[0].as_int = 0;
+      rd_cost->rate = this_rd_cost.rate;
+      rd_cost->dist = this_rd_cost.dist;
+      rd_cost->rdcost = this_rd_cost.rdcost;
+      search_state.best_rd = rd_cost->rdcost;
+      search_state.best_mbmode = *mbmi;
+      search_state.best_skip2 = 0;
+      search_state.best_mode_skippable = this_skippable;
+      memcpy(ctx->blk_skip, x->blk_skip,
+             sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
+      av1_copy_array(ctx->tx_type_map, xd->tx_type_map, ctx->num_4x4_blk);
+    }
   }
 
   search_state.best_mbmode.skip_mode = 0;
@@ -7317,11 +7342,12 @@
   }
 
   for (i = 0; i < REFERENCE_MODES; ++i) {
-    if (search_state.best_pred_rd[i] == INT64_MAX) {
+    if (search_state.intra_search_state.best_pred_rd[i] == INT64_MAX) {
       search_state.best_pred_diff[i] = INT_MIN;
     } else {
       search_state.best_pred_diff[i] =
-          search_state.best_rd - search_state.best_pred_rd[i];
+          search_state.best_rd -
+          search_state.intra_search_state.best_pred_rd[i];
     }
   }