Refactor rd_pick_intra_sby_mode()

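Simplify the code: track the best luma intra candidate in a single
MB_MODE_INFO (best_mbmi) instead of separate variables for the best
mode, tx size, tx type, angle delta, intra filter, filter-intra info
and palette, and restore it with one struct copy at the end.
rd_pick_palette_intra_sby() now takes that MB_MODE_INFO and itself
restores mbmi and the luma color index map; likewise,
rd_pick_palette_intra_sbuv() now copies the winning chroma palette
colors and color index map.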

Change-Id: Ifa65ea66e55c52ab79f32de1fc27121ddf088fc3
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 00bfc61..d70fdbe 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -1851,14 +1851,15 @@
 }
 
 #if CONFIG_PALETTE
-static int rd_pick_palette_intra_sby(
-    const AV1_COMP *const cpi, MACROBLOCK *x, BLOCK_SIZE bsize, int palette_ctx,
-    int dc_mode_cost, PALETTE_MODE_INFO *palette_mode_info,
-    uint8_t *best_palette_color_map, TX_SIZE *best_tx, TX_TYPE *best_tx_type,
-    PREDICTION_MODE *mode_selected, int64_t *best_rd) {
+static int rd_pick_palette_intra_sby(const AV1_COMP *const cpi, MACROBLOCK *x,
+                                     BLOCK_SIZE bsize, int palette_ctx,
+                                     int dc_mode_cost, MB_MODE_INFO *best_mbmi,
+                                     uint8_t *best_palette_color_map,
+                                     int64_t *best_rd) {
   int rate_overhead = 0;
   MACROBLOCKD *const xd = &x->e_mbd;
   MODE_INFO *const mic = xd->mi[0];
+  MB_MODE_INFO *const mbmi = &mic->mbmi;
   const int rows = block_size_high[bsize];
   const int cols = block_size_wide[bsize];
   int this_rate, colors, n;
@@ -1866,6 +1867,7 @@
   int64_t this_rd;
   const int src_stride = x->plane[0].src.stride;
   const uint8_t *const src = x->plane[0].src.buf;
+  uint8_t *const color_map = xd->plane[0].color_index_map;
 
   assert(cpi->common.allow_screen_content_tools);
 
@@ -1876,9 +1878,8 @@
   else
 #endif  // CONFIG_AOM_HIGHBITDEPTH
     colors = av1_count_colors(src, src_stride, rows, cols);
-  palette_mode_info->palette_size[0] = 0;
 #if CONFIG_FILTER_INTRA
-  mic->mbmi.filter_intra_mode_info.use_filter_intra_mode[0] = 0;
+  mbmi->filter_intra_mode_info.use_filter_intra_mode[0] = 0;
 #endif  // CONFIG_FILTER_INTRA
 
   if (colors > 1 && colors <= 64) {
@@ -1887,9 +1888,7 @@
     uint8_t color_order[PALETTE_MAX_SIZE];
     float *const data = x->palette_buffer->kmeans_data_buf;
     float centroids[PALETTE_MAX_SIZE];
-    uint8_t *const color_map = xd->plane[0].color_index_map;
     float lb, ub, val;
-    MB_MODE_INFO *const mbmi = &mic->mbmi;
     PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
 #if CONFIG_AOM_HIGHBITDEPTH
     uint16_t *src16 = CONVERT_TO_SHORTPTR(src);
@@ -1978,16 +1977,19 @@
 
       if (this_rd < *best_rd) {
         *best_rd = this_rd;
-        *palette_mode_info = *pmi;
         memcpy(best_palette_color_map, color_map,
                rows * cols * sizeof(color_map[0]));
-        *mode_selected = DC_PRED;
-        *best_tx = mbmi->tx_size;
-        *best_tx_type = mbmi->tx_type;
+        *best_mbmi = *mbmi;
         rate_overhead = this_rate - tokenonly_rd_stats.rate;
       }
     }
   }
+
+  if (best_mbmi->palette_mode_info.palette_size[0] > 0) {
+    memcpy(color_map, best_palette_color_map,
+           rows * cols * sizeof(best_palette_color_map[0]));
+  }
+  *mbmi = *best_mbmi;
   return rate_overhead;
 }
 #endif  // CONFIG_PALETTE
@@ -2836,36 +2838,30 @@
                                       int64_t *distortion, int *skippable,
                                       BLOCK_SIZE bsize, int64_t best_rd) {
   uint8_t mode_idx;
-  PREDICTION_MODE mode_selected = DC_PRED;
   MACROBLOCKD *const xd = &x->e_mbd;
   MODE_INFO *const mic = xd->mi[0];
+  MB_MODE_INFO *const mbmi = &mic->mbmi;
+  MB_MODE_INFO best_mbmi = *mbmi;
   int this_rate, this_rate_tokenonly, s;
   int64_t this_distortion, this_rd;
-  TX_SIZE best_tx = TX_4X4;
-#if CONFIG_EXT_INTRA || CONFIG_PALETTE
+#if CONFIG_EXT_INTRA
   const int rows = block_size_high[bsize];
   const int cols = block_size_wide[bsize];
-#endif  // CONFIG_EXT_INTRA || CONFIG_PALETTE
-#if CONFIG_EXT_INTRA
 #if CONFIG_INTRA_INTERP
   const int intra_filter_ctx = av1_get_pred_context_intra_interp(xd);
-  INTRA_FILTER best_filter = INTRA_FILTER_LINEAR;
 #endif  // CONFIG_INTRA_INTERP
-  int is_directional_mode, best_angle_delta = 0;
+  int is_directional_mode;
   uint8_t directional_mode_skip_mask[INTRA_MODES];
   const int src_stride = x->plane[0].src.stride;
   const uint8_t *src = x->plane[0].src.buf;
 #endif  // CONFIG_EXT_INTRA
 #if CONFIG_FILTER_INTRA
   int beat_best_rd = 0;
-  FILTER_INTRA_MODE_INFO filter_intra_mode_info;
   uint16_t filter_intra_mode_skip_mask = (1 << FILTER_INTRA_MODES) - 1;
 #endif  // CONFIG_FILTER_INTRA
-  TX_TYPE best_tx_type = DCT_DCT;
   const int *bmode_costs;
 #if CONFIG_PALETTE
-  PALETTE_MODE_INFO palette_mode_info;
-  PALETTE_MODE_INFO *const pmi = &mic->mbmi.palette_mode_info;
+  PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
   uint8_t *best_palette_color_map =
       cpi->common.allow_screen_content_tools
           ? x->palette_buffer->best_palette_color_map
@@ -2887,7 +2883,7 @@
   bmode_costs = cpi->y_mode_costs[A][L];
 
 #if CONFIG_EXT_INTRA
-  mic->mbmi.angle_delta[0] = 0;
+  mbmi->angle_delta[0] = 0;
   memset(directional_mode_skip_mask, 0,
          sizeof(directional_mode_skip_mask[0]) * INTRA_MODES);
 #if CONFIG_AOM_HIGHBITDEPTH
@@ -2899,11 +2895,9 @@
     angle_estimation(src, src_stride, rows, cols, directional_mode_skip_mask);
 #endif  // CONFIG_EXT_INTRA
 #if CONFIG_FILTER_INTRA
-  filter_intra_mode_info.use_filter_intra_mode[0] = 0;
-  mic->mbmi.filter_intra_mode_info.use_filter_intra_mode[0] = 0;
+  mbmi->filter_intra_mode_info.use_filter_intra_mode[0] = 0;
 #endif  // CONFIG_FILTER_INTRA
 #if CONFIG_PALETTE
-  palette_mode_info.palette_size[0] = 0;
   pmi->palette_size[0] = 0;
   if (above_mi)
     palette_ctx += (above_mi->mbmi.palette_mode_info.palette_size[0] > 0);
@@ -2921,25 +2915,24 @@
     RD_STATS this_rd_stats;
     if (mode_idx == FINAL_MODE_SEARCH) {
       if (x->use_default_intra_tx_type == 0) break;
-      mic->mbmi.mode = mode_selected;
+      mbmi->mode = best_mbmi.mode;
       x->use_default_intra_tx_type = 0;
     } else {
-      mic->mbmi.mode = mode_idx;
+      mbmi->mode = mode_idx;
     }
 #if CONFIG_PVQ
     od_encode_rollback(&x->daala_enc, &pre_buf);
 #endif
 #if CONFIG_EXT_INTRA
-    is_directional_mode = av1_is_directional_mode(mic->mbmi.mode, bsize);
-    if (is_directional_mode && directional_mode_skip_mask[mic->mbmi.mode])
-      continue;
+    is_directional_mode = av1_is_directional_mode(mbmi->mode, bsize);
+    if (is_directional_mode && directional_mode_skip_mask[mbmi->mode]) continue;
     if (is_directional_mode) {
       this_rd_stats.rate = INT_MAX;
       this_rd =
           rd_pick_intra_angle_sby(cpi, x, &this_rate, &this_rd_stats, bsize,
-                                  bmode_costs[mic->mbmi.mode], best_rd);
+                                  bmode_costs[mbmi->mode], best_rd);
     } else {
-      mic->mbmi.angle_delta[0] = 0;
+      mbmi->angle_delta[0] = 0;
       super_block_yrd(cpi, x, &this_rd_stats, bsize, best_rd);
     }
 #else
@@ -2951,63 +2944,53 @@
 
     if (this_rate_tokenonly == INT_MAX) continue;
 
-    this_rate = this_rate_tokenonly + bmode_costs[mic->mbmi.mode];
+    this_rate = this_rate_tokenonly + bmode_costs[mbmi->mode];
 
-    if (!xd->lossless[xd->mi[0]->mbmi.segment_id] &&
-        mic->mbmi.sb_type >= BLOCK_8X8) {
+    if (!xd->lossless[mbmi->segment_id] && mbmi->sb_type >= BLOCK_8X8) {
       // super_block_yrd above includes the cost of the tx_size in the
       // tokenonly rate, but for intra blocks, tx_size is always coded
       // (prediction granularity), so we account for it in the full rate,
       // not the tokenonly rate.
       this_rate_tokenonly -=
           cpi->tx_size_cost[max_tx_size - TX_8X8][get_tx_size_context(xd)]
-                           [tx_size_to_depth(mic->mbmi.tx_size)];
+                           [tx_size_to_depth(mbmi->tx_size)];
     }
 #if CONFIG_PALETTE
-    if (cpi->common.allow_screen_content_tools && mic->mbmi.mode == DC_PRED)
+    if (cpi->common.allow_screen_content_tools && mbmi->mode == DC_PRED)
       this_rate += av1_cost_bit(
           av1_default_palette_y_mode_prob[bsize - BLOCK_8X8][palette_ctx], 0);
 #endif  // CONFIG_PALETTE
 #if CONFIG_FILTER_INTRA
-    if (mic->mbmi.mode == DC_PRED)
+    if (mbmi->mode == DC_PRED)
       this_rate += av1_cost_bit(cpi->common.fc->filter_intra_probs[0], 0);
 #endif  // CONFIG_FILTER_INTRA
 #if CONFIG_EXT_INTRA
     if (is_directional_mode) {
       const int max_angle_delta = av1_get_max_angle_delta(bsize, 0);
 #if CONFIG_INTRA_INTERP
-      const int p_angle =
-          mode_to_angle_map[mic->mbmi.mode] +
-          mic->mbmi.angle_delta[0] * av1_get_angle_step(bsize, 0);
+      const int p_angle = mode_to_angle_map[mbmi->mode] +
+                          mbmi->angle_delta[0] * av1_get_angle_step(bsize, 0);
       if (av1_is_intra_filter_switchable(p_angle))
         this_rate +=
-            cpi->intra_filter_cost[intra_filter_ctx][mic->mbmi.intra_filter];
+            cpi->intra_filter_cost[intra_filter_ctx][mbmi->intra_filter];
 #endif  // CONFIG_INTRA_INTERP
-      this_rate += write_uniform_cost(
-          2 * max_angle_delta + 1, max_angle_delta + mic->mbmi.angle_delta[0]);
+      this_rate += write_uniform_cost(2 * max_angle_delta + 1,
+                                      max_angle_delta + mbmi->angle_delta[0]);
     }
 #endif  // CONFIG_EXT_INTRA
     this_rd = RDCOST(x->rdmult, x->rddiv, this_rate, this_distortion);
 #if CONFIG_FILTER_INTRA
     if (best_rd == INT64_MAX || this_rd - best_rd < (best_rd >> 4)) {
-      filter_intra_mode_skip_mask ^= (1 << mic->mbmi.mode);
+      filter_intra_mode_skip_mask ^= (1 << mbmi->mode);
     }
 #endif  // CONFIG_FILTER_INTRA
 
     if (this_rd < best_rd) {
-      mode_selected = mic->mbmi.mode;
+      best_mbmi = *mbmi;
       best_rd = this_rd;
-      best_tx = mic->mbmi.tx_size;
-#if CONFIG_EXT_INTRA
-      best_angle_delta = mic->mbmi.angle_delta[0];
-#if CONFIG_INTRA_INTERP
-      best_filter = mic->mbmi.intra_filter;
-#endif  // CONFIG_INTRA_INTERP
-#endif  // CONFIG_EXT_INTRA
 #if CONFIG_FILTER_INTRA
       beat_best_rd = 1;
 #endif  // CONFIG_FILTER_INTRA
-      best_tx_type = mic->mbmi.tx_type;
       *rate = this_rate;
       *rate_tokenonly = this_rate_tokenonly;
       *distortion = this_distortion;
@@ -3023,11 +3006,10 @@
 #endif
 
 #if CONFIG_PALETTE
-  if (cpi->common.allow_screen_content_tools)
+  if (cpi->common.allow_screen_content_tools) {
     rd_pick_palette_intra_sby(cpi, x, bsize, palette_ctx, bmode_costs[DC_PRED],
-                              &palette_mode_info, best_palette_color_map,
-                              &best_tx, &best_tx_type, &mode_selected,
-                              &best_rd);
+                              &best_mbmi, best_palette_color_map, &best_rd);
+  }
 #endif  // CONFIG_PALETTE
 
 #if CONFIG_FILTER_INTRA
@@ -3035,43 +3017,12 @@
     if (rd_pick_filter_intra_sby(cpi, x, rate, rate_tokenonly, distortion,
                                  skippable, bsize, bmode_costs[DC_PRED],
                                  &best_rd, filter_intra_mode_skip_mask)) {
-      mode_selected = mic->mbmi.mode;
-      best_tx = mic->mbmi.tx_size;
-      filter_intra_mode_info = mic->mbmi.filter_intra_mode_info;
-      best_tx_type = mic->mbmi.tx_type;
+      best_mbmi = *mbmi;
     }
   }
-
-  mic->mbmi.filter_intra_mode_info.use_filter_intra_mode[0] =
-      filter_intra_mode_info.use_filter_intra_mode[0];
-  if (filter_intra_mode_info.use_filter_intra_mode[0]) {
-    mic->mbmi.filter_intra_mode_info.filter_intra_mode[0] =
-        filter_intra_mode_info.filter_intra_mode[0];
-#if CONFIG_PALETTE
-    palette_mode_info.palette_size[0] = 0;
-#endif  // CONFIG_PALETTE
-  }
 #endif  // CONFIG_FILTER_INTRA
 
-  mic->mbmi.mode = mode_selected;
-  mic->mbmi.tx_size = best_tx;
-#if CONFIG_EXT_INTRA
-  mic->mbmi.angle_delta[0] = best_angle_delta;
-#if CONFIG_INTRA_INTERP
-  mic->mbmi.intra_filter = best_filter;
-#endif  // CONFIG_INTRA_INTERP
-#endif  // CONFIG_EXT_INTRA
-  mic->mbmi.tx_type = best_tx_type;
-#if CONFIG_PALETTE
-  pmi->palette_size[0] = palette_mode_info.palette_size[0];
-  if (palette_mode_info.palette_size[0] > 0) {
-    memcpy(pmi->palette_colors, palette_mode_info.palette_colors,
-           PALETTE_MAX_SIZE * sizeof(palette_mode_info.palette_colors[0]));
-    memcpy(xd->plane[0].color_index_map, best_palette_color_map,
-           rows * cols * sizeof(best_palette_color_map[0]));
-  }
-#endif  // CONFIG_PALETTE
-
+  *mbmi = best_mbmi;
   return best_rd;
 }
 
@@ -3777,6 +3728,7 @@
     int *rate_tokenonly, int64_t *distortion, int *skippable) {
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
+  PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
   const BLOCK_SIZE bsize = mbmi->sb_type;
   const int rows = block_size_high[bsize] >> (xd->plane[1].subsampling_y);
   const int cols = block_size_wide[bsize] >> (xd->plane[1].subsampling_x);
@@ -3786,6 +3738,7 @@
   const int src_stride = x->plane[1].src.stride;
   const uint8_t *const src_u = x->plane[1].src.buf;
   const uint8_t *const src_v = x->plane[2].src.buf;
+  uint8_t *const color_map = xd->plane[1].color_index_map;
   RD_STATS tokenonly_rd_stats;
 
   if (rows * cols > PALETTE_MAX_BLOCK_SIZE) return;
@@ -3817,8 +3770,6 @@
     float lb_v, ub_v, val_v;
     float *const data = x->palette_buffer->kmeans_data_buf;
     float centroids[2 * PALETTE_MAX_SIZE];
-    uint8_t *const color_map = xd->plane[1].color_index_map;
-    PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
 
 #if CONFIG_AOM_HIGHBITDEPTH
     uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src_u);
@@ -3925,6 +3876,13 @@
       }
     }
   }
+  if (palette_mode_info->palette_size[1] > 0) {
+    memcpy(pmi->palette_colors + PALETTE_MAX_SIZE,
+           palette_mode_info->palette_colors + PALETTE_MAX_SIZE,
+           2 * PALETTE_MAX_SIZE * sizeof(palette_mode_info->palette_colors[0]));
+    memcpy(color_map, best_palette_color_map,
+           rows * cols * sizeof(best_palette_color_map[0]));
+  }
 }
 #endif  // CONFIG_PALETTE
 
@@ -4085,10 +4043,8 @@
   od_encode_checkpoint(&x->daala_enc, &buf);
 #endif
 #if CONFIG_PALETTE
-  const int rows = block_size_high[bsize] >> (xd->plane[1].subsampling_y);
-  const int cols = block_size_wide[bsize] >> (xd->plane[1].subsampling_x);
   PALETTE_MODE_INFO palette_mode_info;
-  PALETTE_MODE_INFO *const pmi = &xd->mi[0]->mbmi.palette_mode_info;
+  PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
   uint8_t *best_palette_color_map = NULL;
 #endif  // CONFIG_PALETTE
 #if CONFIG_EXT_INTRA
@@ -4219,13 +4175,6 @@
   mbmi->uv_mode = mode_selected;
 #if CONFIG_PALETTE
   pmi->palette_size[1] = palette_mode_info.palette_size[1];
-  if (palette_mode_info.palette_size[1] > 0) {
-    memcpy(pmi->palette_colors + PALETTE_MAX_SIZE,
-           palette_mode_info.palette_colors + PALETTE_MAX_SIZE,
-           2 * PALETTE_MAX_SIZE * sizeof(palette_mode_info.palette_colors[0]));
-    memcpy(xd->plane[1].color_index_map, best_palette_color_map,
-           rows * cols * sizeof(best_palette_color_map[0]));
-  }
 #endif  // CONFIG_PALETTE
 
   return best_rd;
@@ -9942,7 +9891,6 @@
 #if CONFIG_PALETTE
   // Only try palette mode when the best mode so far is an intra mode.
   if (cm->allow_screen_content_tools && !is_inter_mode(best_mbmode.mode)) {
-    PREDICTION_MODE mode_selected;
     int rate2 = 0;
 #if CONFIG_SUPERTX
     int best_rate_nocoef;
@@ -9950,31 +9898,26 @@
     int64_t distortion2 = 0, dummy_rd = best_rd, this_rd;
     int skippable = 0, rate_overhead_palette = 0;
     RD_STATS rd_stats_y;
-    TX_SIZE best_tx_size, uv_tx;
-    TX_TYPE best_tx_type;
-    PALETTE_MODE_INFO palette_mode_info;
+    TX_SIZE uv_tx;
     uint8_t *const best_palette_color_map =
         x->palette_buffer->best_palette_color_map;
     uint8_t *const color_map = xd->plane[0].color_index_map;
+    MB_MODE_INFO best_mbmi_palette;
 
     mbmi->mode = DC_PRED;
     mbmi->uv_mode = DC_PRED;
     mbmi->ref_frame[0] = INTRA_FRAME;
     mbmi->ref_frame[1] = NONE;
-    palette_mode_info.palette_size[0] = 0;
+    // rd_pick_palette_intra_sby() ends with *mbmi = *best_mbmi even when no
+    // palette beats dummy_rd, so seed it with a defined, palette-free state.
+    best_mbmi_palette = *mbmi;
+    best_mbmi_palette.palette_mode_info.palette_size[0] = 0;
     rate_overhead_palette = rd_pick_palette_intra_sby(
-        cpi, x, bsize, palette_ctx, intra_mode_cost[DC_PRED],
-        &palette_mode_info, best_palette_color_map, &best_tx_size,
-        &best_tx_type, &mode_selected, &dummy_rd);
-    if (palette_mode_info.palette_size[0] == 0) goto PALETTE_EXIT;
-
-    pmi->palette_size[0] = palette_mode_info.palette_size[0];
-    if (palette_mode_info.palette_size[0] > 0) {
-      memcpy(pmi->palette_colors, palette_mode_info.palette_colors,
-             PALETTE_MAX_SIZE * sizeof(palette_mode_info.palette_colors[0]));
-      memcpy(color_map, best_palette_color_map,
-             rows * cols * sizeof(best_palette_color_map[0]));
-    }
+        cpi, x, bsize, palette_ctx, intra_mode_cost[DC_PRED],
+        &best_mbmi_palette, best_palette_color_map, &dummy_rd);
+    if (pmi->palette_size[0] == 0) goto PALETTE_EXIT;
+    memcpy(color_map, best_palette_color_map,
+           rows * cols * sizeof(best_palette_color_map[0]));
     super_block_yrd(cpi, x, &rd_stats_y, bsize, best_rd);
     if (rd_stats_y.rate == INT_MAX) goto PALETTE_EXIT;
     uv_tx = uv_txsize_lookup[bsize][mbmi->tx_size][xd->plane[1].subsampling_x]
@@ -9993,10 +9932,11 @@
     }
     mbmi->uv_mode = mode_uv[uv_tx];
     pmi->palette_size[1] = pmi_uv[uv_tx].palette_size[1];
-    if (pmi->palette_size[1] > 0)
+    if (pmi->palette_size[1] > 0) {
       memcpy(pmi->palette_colors + PALETTE_MAX_SIZE,
              pmi_uv[uv_tx].palette_colors + PALETTE_MAX_SIZE,
              2 * PALETTE_MAX_SIZE * sizeof(pmi->palette_colors[0]));
+    }
 #if CONFIG_EXT_INTRA
     mbmi->angle_delta[1] = uv_angle_delta[uv_tx];
 #endif  // CONFIG_EXT_INTRA
@@ -10049,7 +9989,7 @@
   // avoid a unit test failure
   if (!xd->lossless[mbmi->segment_id] &&
 #if CONFIG_PALETTE
-      mbmi->palette_mode_info.palette_size[0] == 0 &&
+      pmi->palette_size[0] == 0 &&
 #endif  // CONFIG_PALETTE
       !dc_skipped && best_mode_index >= 0 &&
       best_intra_rd < (best_rd + (best_rd >> 3))) {