Refactor handle_intra_mode()

Change-Id: I8a72cf817c1bcdbfc981530e37e72c4824fe1771
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index e853c3f..da12066 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -1064,6 +1064,7 @@
 
 static INLINE int av1_allow_palette(int allow_screen_content_tools,
                                     BLOCK_SIZE sb_type) {
+  assert(sb_type < BLOCK_SIZES_ALL);
   return allow_screen_content_tools && block_size_wide[sb_type] <= 64 &&
          block_size_high[sb_type] <= 64 && sb_type >= BLOCK_8X8;
 }
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 48e0300..447d0b4 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -11921,22 +11921,15 @@
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = xd->mi[0];
   assert(mbmi->ref_frame[0] == INTRA_FRAME);
-  PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
-  const int try_palette =
-      cpi->oxcf.enable_palette &&
-      av1_allow_palette(cm->allow_screen_content_tools, mbmi->sb_type);
-  const int *const intra_mode_cost = x->mbmode_cost[size_group_lookup[bsize]];
+  const PREDICTION_MODE mode = mbmi->mode;
+  const int mode_cost = x->mbmode_cost[size_group_lookup[bsize]][mode];
   const int intra_cost_penalty = av1_get_intra_cost_penalty(
       cm->base_qindex, cm->y_dc_delta_q, cm->seq_params.bit_depth);
-  const int rows = block_size_high[bsize];
-  const int cols = block_size_wide[bsize];
-  const int num_planes = av1_num_planes(cm);
   const int skip_ctx = av1_get_skip_context(xd);
 
-  int known_rate = intra_mode_cost[mbmi->mode];
+  int known_rate = mode_cost;
   known_rate += ref_frame_cost;
-  if (mbmi->mode != DC_PRED && mbmi->mode != PAETH_PRED)
-    known_rate += intra_cost_penalty;
+  if (mode != DC_PRED && mode != PAETH_PRED) known_rate += intra_cost_penalty;
   known_rate += AOMMIN(x->skip_cost[skip_ctx][0], x->skip_cost[skip_ctx][1]);
   const int64_t known_rd = RDCOST(x->rdmult, known_rate, 0);
   if (known_rd > search_state->best_rd) {
@@ -11944,107 +11937,116 @@
     return INT64_MAX;
   }
 
-  TX_SIZE uv_tx;
-  int is_directional_mode = av1_is_directional_mode(mbmi->mode);
+  const int is_directional_mode = av1_is_directional_mode(mode);
   if (is_directional_mode && av1_use_angle_delta(bsize) &&
       cpi->oxcf.enable_angle_delta) {
-    int rate_dummy;
-    int64_t model_rd = INT64_MAX;
     if (sf->intra_angle_estimation && !search_state->angle_stats_ready) {
       const int src_stride = x->plane[0].src.stride;
       const uint8_t *src = x->plane[0].src.buf;
+      const int rows = block_size_high[bsize];
+      const int cols = block_size_wide[bsize];
       angle_estimation(src, src_stride, rows, cols, bsize, is_cur_buf_hbd(xd),
                        search_state->directional_mode_skip_mask);
       search_state->angle_stats_ready = 1;
     }
-    if (search_state->directional_mode_skip_mask[mbmi->mode]) return INT64_MAX;
+    if (search_state->directional_mode_skip_mask[mode]) return INT64_MAX;
     av1_init_rd_stats(rd_stats_y);
     rd_stats_y->rate = INT_MAX;
+    int64_t model_rd = INT64_MAX;
+    int rate_dummy;
     rd_pick_intra_angle_sby(cpi, x, mi_row, mi_col, &rate_dummy, rd_stats_y,
-                            bsize, intra_mode_cost[mbmi->mode],
-                            search_state->best_rd, &model_rd, 0);
+                            bsize, mode_cost, search_state->best_rd, &model_rd,
+                            0);
+
   } else {
     av1_init_rd_stats(rd_stats_y);
     mbmi->angle_delta[PLANE_TYPE_Y] = 0;
     super_block_yrd(cpi, x, rd_stats_y, bsize, search_state->best_rd);
   }
-  uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
-  memcpy(best_blk_skip, x->blk_skip,
-         sizeof(best_blk_skip[0]) * ctx->num_4x4_blk);
-  int try_filter_intra = 0;
-  int64_t best_rd_tmp = INT64_MAX;
-  if (mbmi->mode == DC_PRED && av1_filter_intra_allowed_bsize(cm, bsize)) {
+
+  if (mode == DC_PRED && av1_filter_intra_allowed_bsize(cm, bsize)) {
+    int try_filter_intra = 0;
+    int64_t best_rd_so_far = INT64_MAX;
     if (rd_stats_y->rate != INT_MAX) {
-      const int tmp_rate = rd_stats_y->rate + x->filter_intra_cost[bsize][0] +
-                           intra_mode_cost[mbmi->mode];
-      best_rd_tmp = RDCOST(x->rdmult, tmp_rate, rd_stats_y->dist);
-      try_filter_intra = !((best_rd_tmp / 2) > search_state->best_rd);
+      const int tmp_rate =
+          rd_stats_y->rate + x->filter_intra_cost[bsize][0] + mode_cost;
+      best_rd_so_far = RDCOST(x->rdmult, tmp_rate, rd_stats_y->dist);
+      try_filter_intra = (best_rd_so_far / 2) <= search_state->best_rd;
     } else {
-      try_filter_intra = !(search_state->best_mbmode.skip);
-    }
-  }
-  if (try_filter_intra) {
-    RD_STATS rd_stats_y_fi;
-    int filter_intra_selected_flag = 0;
-    TX_SIZE best_tx_size = mbmi->tx_size;
-    TX_TYPE best_txk_type[TXK_TYPE_BUF_LEN];
-    memcpy(best_txk_type, mbmi->txk_type,
-           sizeof(*best_txk_type) * TXK_TYPE_BUF_LEN);
-    FILTER_INTRA_MODE best_fi_mode = FILTER_DC_PRED;
-
-    mbmi->filter_intra_mode_info.use_filter_intra = 1;
-    for (FILTER_INTRA_MODE fi_mode = FILTER_DC_PRED;
-         fi_mode < FILTER_INTRA_MODES; ++fi_mode) {
-      int64_t this_rd_tmp;
-      mbmi->filter_intra_mode_info.filter_intra_mode = fi_mode;
-      super_block_yrd(cpi, x, &rd_stats_y_fi, bsize, search_state->best_rd);
-      if (rd_stats_y_fi.rate == INT_MAX) {
-        continue;
-      }
-      const int this_rate_tmp =
-          rd_stats_y_fi.rate +
-          intra_mode_info_cost_y(cpi, x, mbmi, bsize,
-                                 intra_mode_cost[mbmi->mode]);
-      this_rd_tmp = RDCOST(x->rdmult, this_rate_tmp, rd_stats_y_fi.dist);
-
-      if (this_rd_tmp != INT64_MAX && this_rd_tmp / 2 > search_state->best_rd) {
-        break;
-      }
-      if (this_rd_tmp < best_rd_tmp) {
-        best_tx_size = mbmi->tx_size;
-        memcpy(best_txk_type, mbmi->txk_type,
-               sizeof(*best_txk_type) * TXK_TYPE_BUF_LEN);
-        memcpy(best_blk_skip, x->blk_skip,
-               sizeof(best_blk_skip[0]) * ctx->num_4x4_blk);
-        best_fi_mode = fi_mode;
-        *rd_stats_y = rd_stats_y_fi;
-        filter_intra_selected_flag = 1;
-        best_rd_tmp = this_rd_tmp;
-      }
+      try_filter_intra = !search_state->best_mbmode.skip;
     }
 
-    mbmi->tx_size = best_tx_size;
-    memcpy(mbmi->txk_type, best_txk_type,
-           sizeof(*best_txk_type) * TXK_TYPE_BUF_LEN);
-    memcpy(x->blk_skip, best_blk_skip,
-           sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
+    if (try_filter_intra) {
+      RD_STATS rd_stats_y_fi;
+      int filter_intra_selected_flag = 0;
+      TX_SIZE best_tx_size = mbmi->tx_size;
+      TX_TYPE best_txk_type[TXK_TYPE_BUF_LEN];
+      memcpy(best_txk_type, mbmi->txk_type,
+             sizeof(*best_txk_type) * TXK_TYPE_BUF_LEN);
+      FILTER_INTRA_MODE best_fi_mode = FILTER_DC_PRED;
+      uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
+      memcpy(best_blk_skip, x->blk_skip,
+             sizeof(best_blk_skip[0]) * ctx->num_4x4_blk);
 
-    if (filter_intra_selected_flag) {
       mbmi->filter_intra_mode_info.use_filter_intra = 1;
-      mbmi->filter_intra_mode_info.filter_intra_mode = best_fi_mode;
-    } else {
-      mbmi->filter_intra_mode_info.use_filter_intra = 0;
+      for (FILTER_INTRA_MODE fi_mode = FILTER_DC_PRED;
+           fi_mode < FILTER_INTRA_MODES; ++fi_mode) {
+        mbmi->filter_intra_mode_info.filter_intra_mode = fi_mode;
+        super_block_yrd(cpi, x, &rd_stats_y_fi, bsize, search_state->best_rd);
+        if (rd_stats_y_fi.rate == INT_MAX) continue;
+        const int this_rate_tmp =
+            rd_stats_y_fi.rate +
+            intra_mode_info_cost_y(cpi, x, mbmi, bsize, mode_cost);
+        const int64_t this_rd_tmp =
+            RDCOST(x->rdmult, this_rate_tmp, rd_stats_y_fi.dist);
+
+        if (this_rd_tmp != INT64_MAX &&
+            this_rd_tmp / 2 > search_state->best_rd) {
+          break;
+        }
+        if (this_rd_tmp < best_rd_so_far) {
+          best_tx_size = mbmi->tx_size;
+          memcpy(best_txk_type, mbmi->txk_type,
+                 sizeof(*best_txk_type) * TXK_TYPE_BUF_LEN);
+          memcpy(best_blk_skip, x->blk_skip,
+                 sizeof(best_blk_skip[0]) * ctx->num_4x4_blk);
+          best_fi_mode = fi_mode;
+          *rd_stats_y = rd_stats_y_fi;
+          filter_intra_selected_flag = 1;
+          best_rd_so_far = this_rd_tmp;
+        }
+      }
+
+      mbmi->tx_size = best_tx_size;
+      memcpy(mbmi->txk_type, best_txk_type,
+             sizeof(*best_txk_type) * TXK_TYPE_BUF_LEN);
+      memcpy(x->blk_skip, best_blk_skip,
+             sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
+
+      if (filter_intra_selected_flag) {
+        mbmi->filter_intra_mode_info.use_filter_intra = 1;
+        mbmi->filter_intra_mode_info.filter_intra_mode = best_fi_mode;
+      } else {
+        mbmi->filter_intra_mode_info.use_filter_intra = 0;
+      }
     }
   }
+
   if (rd_stats_y->rate == INT_MAX) return INT64_MAX;
+
   const int mode_cost_y =
-      intra_mode_info_cost_y(cpi, x, mbmi, bsize, intra_mode_cost[mbmi->mode]);
+      intra_mode_info_cost_y(cpi, x, mbmi, bsize, mode_cost);
   av1_init_rd_stats(rd_stats);
   av1_init_rd_stats(rd_stats_uv);
+  const int num_planes = av1_num_planes(cm);
   if (num_planes > 1) {
-    uv_tx = av1_get_tx_size(AOM_PLANE_U, xd);
+    PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
+    const int try_palette =
+        cpi->oxcf.enable_palette &&
+        av1_allow_palette(cm->allow_screen_content_tools, mbmi->sb_type);
+    const TX_SIZE uv_tx = av1_get_tx_size(AOM_PLANE_U, xd);
     if (search_state->rate_uv_intra[uv_tx] == INT_MAX) {
-      int rate_y =
+      const int rate_y =
           rd_stats_y->skip ? x->skip_cost[skip_ctx][1] : rd_stats_y->rate;
       const int64_t rdy =
           RDCOST(x->rdmult, rate_y + mode_cost_y, rd_stats_y->dist);
@@ -12083,6 +12085,7 @@
     }
     mbmi->angle_delta[PLANE_TYPE_UV] = search_state->uv_angle_delta[uv_tx];
   }
+
   rd_stats->rate = rd_stats_y->rate + mode_cost_y;
   if (!xd->lossless[mbmi->segment_id] && block_signals_txsize(bsize)) {
     // super_block_yrd above includes the cost of the tx_size in the
@@ -12093,13 +12096,14 @@
   }
   if (num_planes > 1 && !x->skip_chroma_rd) {
     const int uv_mode_cost =
-        x->intra_uv_mode_cost[is_cfl_allowed(xd)][mbmi->mode][mbmi->uv_mode];
+        x->intra_uv_mode_cost[is_cfl_allowed(xd)][mode][mbmi->uv_mode];
     rd_stats->rate +=
         rd_stats_uv->rate +
         intra_mode_info_cost_uv(cpi, x, mbmi, bsize, uv_mode_cost);
   }
-  if (mbmi->mode != DC_PRED && mbmi->mode != PAETH_PRED)
+  if (mode != DC_PRED && mode != PAETH_PRED) {
     rd_stats->rate += intra_cost_penalty;
+  }
   rd_stats->dist = rd_stats_y->dist + rd_stats_uv->dist;
 
   // Estimate the reference frame signaling cost and add it
@@ -12121,7 +12125,7 @@
   // Keep record of best intra rd
   if (this_rd < search_state->best_intra_rd) {
     search_state->best_intra_rd = this_rd;
-    search_state->best_intra_mode = mbmi->mode;
+    search_state->best_intra_mode = mode;
   }
 
   if (sf->skip_intra_in_interframe) {
@@ -12131,9 +12135,10 @@
   }
 
   if (!disable_skip) {
-    for (int i = 0; i < REFERENCE_MODES; ++i)
+    for (int i = 0; i < REFERENCE_MODES; ++i) {
       search_state->best_pred_rd[i] =
           AOMMIN(search_state->best_pred_rd[i], this_rd);
+    }
   }
   return this_rd;
 }
@@ -12595,6 +12600,7 @@
   int64_t ref_frame_rd[REF_FRAMES] = { INT64_MAX, INT64_MAX, INT64_MAX,
                                        INT64_MAX, INT64_MAX, INT64_MAX,
                                        INT64_MAX, INT64_MAX };
+  const int skip_ctx = av1_get_skip_context(xd);
   for (int midx = 0; midx < MAX_MODES; ++midx) {
     if (inter_mode_compatible_skip(cpi, x, bsize, midx)) continue;
 
@@ -12793,8 +12799,7 @@
           // Therfore, we should avoid updating best_rate_y and best_rate_uv
           // here. These two values will be updated when txfm_search is called
           search_state.best_rate_y =
-              rate_y +
-              x->skip_cost[av1_get_skip_context(xd)][this_skip2 || skippable];
+              rate_y + x->skip_cost[skip_ctx][this_skip2 || skippable];
           search_state.best_rate_uv = rate_uv;
         }
         memcpy(ctx->blk_skip, x->blk_skip,
@@ -12888,7 +12893,6 @@
                          search_state.best_rd)) {
           continue;
         } else if (cpi->sf.inter_mode_rd_model_estimation == 1) {
-          const int skip_ctx = av1_get_skip_context(xd);
           inter_mode_data_push(tile_data, mbmi->sb_type, rd_stats.sse,
                                rd_stats.dist,
                                rd_stats_y.rate + rd_stats_uv.rate +
@@ -12910,7 +12914,7 @@
         search_state.best_mode_skippable = rd_stats.skip;
         search_state.best_rate_y =
             rd_stats_y.rate +
-            x->skip_cost[av1_get_skip_context(xd)][rd_stats.skip || mbmi->skip];
+            x->skip_cost[skip_ctx][rd_stats.skip || mbmi->skip];
         search_state.best_rate_uv = rd_stats_uv.rate;
         memcpy(ctx->blk_skip, x->blk_skip,
                sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
@@ -12924,28 +12928,19 @@
 #if CONFIG_COLLECT_COMPONENT_TIMING
   start_timing(cpi, handle_intra_mode_time);
 #endif
+  const int intra_ref_frame_cost = ref_costs_single[INTRA_FRAME];
   for (int j = 0; j < intra_mode_num; ++j) {
-    const int mode_index = intra_mode_idx_ls[j];
-    const MV_REFERENCE_FRAME ref_frame =
-        av1_mode_order[mode_index].ref_frame[0];
-    assert(av1_mode_order[mode_index].ref_frame[1] == NONE_FRAME);
-    assert(ref_frame == INTRA_FRAME);
     if (sf->skip_intra_in_interframe && search_state.skip_intra_modes) break;
+    const int mode_index = intra_mode_idx_ls[j];
+    assert(av1_mode_order[mode_index].ref_frame[0] == INTRA_FRAME);
+    assert(av1_mode_order[mode_index].ref_frame[1] == NONE_FRAME);
     init_mbmi(mbmi, mode_index, cm);
     x->skip = 0;
-    set_ref_ptrs(cm, xd, INTRA_FRAME, NONE_FRAME);
-
-    // Select prediction reference frames.
-    for (i = 0; i < num_planes; i++) {
-      xd->plane[i].pre[0] = yv12_mb[ref_frame][i];
-    }
 
     RD_STATS intra_rd_stats, intra_rd_stats_y, intra_rd_stats_uv;
-
-    const int ref_frame_cost = ref_costs_single[ref_frame];
     intra_rd_stats.rdcost = handle_intra_mode(
-        &search_state, cpi, x, bsize, mi_row, mi_col, ref_frame_cost, ctx, 0,
-        &intra_rd_stats, &intra_rd_stats_y, &intra_rd_stats_uv);
+        &search_state, cpi, x, bsize, mi_row, mi_col, intra_ref_frame_cost, ctx,
+        0, &intra_rd_stats, &intra_rd_stats_y, &intra_rd_stats_uv);
     if (intra_rd_stats.rdcost < search_state.best_rd) {
       search_state.best_rd = intra_rd_stats.rdcost;
       // Note index of best mode so far
@@ -12956,8 +12951,7 @@
       search_state.best_skip2 = 0;
       search_state.best_mode_skippable = intra_rd_stats.skip;
       search_state.best_rate_y =
-          intra_rd_stats_y.rate +
-          x->skip_cost[av1_get_skip_context(xd)][intra_rd_stats.skip];
+          intra_rd_stats_y.rate + x->skip_cost[skip_ctx][intra_rd_stats.skip];
       search_state.best_rate_uv = intra_rd_stats_uv.rate;
       memcpy(ctx->blk_skip, x->blk_skip,
              sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);