Code cleanup and optimization in av1_txfm_search()

BUG=aomedia:2617

Change-Id: I3bc4255b5e195bb837184064b19bab2e0670e0f0
diff --git a/av1/encoder/tx_search.c b/av1/encoder/tx_search.c
index e315b0c..4a1e985 100644
--- a/av1/encoder/tx_search.c
+++ b/av1/encoder/tx_search.c
@@ -3467,23 +3467,17 @@
   }
 }
 
+// This function combines y and uv planes' transform search processes together
+// for inter-predicted blocks (including IntraBC), when the prediction is
+// already generated. It first does subtraction to obtain the prediction error.
+// Then it calls av1_pick_tx_size_type_yrd/av1_super_block_yrd and
+// av1_super_block_uvrd sequentially and handles the early terminations
+// happening in those functions. At the end, it computes the
+// rd_stats/_y/_uv accordingly.
 int av1_txfm_search(const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bsize,
                     RD_STATS *rd_stats, RD_STATS *rd_stats_y,
                     RD_STATS *rd_stats_uv, int mode_rate, int64_t ref_best_rd) {
-  /*
-   * This function combines y and uv planes' transform search processes
-   * together, when the prediction is generated. It first does subtraction to
-   * obtain the prediction error. Then it calls
-   * av1_pick_tx_size_type_yrd/av1_super_block_yrd and av1_super_block_uvrd
-   * sequentially and handles the early terminations happening in those
-   * functions. At the end, it computes the rd_stats/_y/_uv accordingly.
-   */
-  const AV1_COMMON *cm = &cpi->common;
   MACROBLOCKD *const xd = &x->e_mbd;
-  MB_MODE_INFO *const mbmi = xd->mi[0];
-  const int64_t mode_rd = RDCOST(x->rdmult, mode_rate, 0);
-  const int64_t rd_thresh =
-      ref_best_rd == INT64_MAX ? INT64_MAX : ref_best_rd - mode_rd;
   const int skip_ctx = av1_get_skip_context(xd);
   const int skip_flag_cost[2] = { x->skip_cost[skip_ctx][0],
                                   x->skip_cost[skip_ctx][1] };
@@ -3492,12 +3486,16 @@
   // Account for minimum skip and non_skip rd.
   // Eventually either one of them will be added to mode_rate
   const int64_t min_header_rd_possible = RDCOST(x->rdmult, min_header_rate, 0);
-
   if (min_header_rd_possible > ref_best_rd) {
     av1_invalid_rd_stats(rd_stats_y);
     return 0;
   }
 
+  const AV1_COMMON *cm = &cpi->common;
+  MB_MODE_INFO *const mbmi = xd->mi[0];
+  const int64_t mode_rd = RDCOST(x->rdmult, mode_rate, 0);
+  const int64_t rd_thresh =
+      ref_best_rd == INT64_MAX ? INT64_MAX : ref_best_rd - mode_rd;
   av1_init_rd_stats(rd_stats);
   av1_init_rd_stats(rd_stats_y);
   rd_stats->rate = mode_rate;
@@ -3555,32 +3553,27 @@
     av1_merge_rd_stats(rd_stats, rd_stats_uv);
   }
 
-  if (rd_stats->skip) {
-    rd_stats->rate -= rd_stats_uv->rate + rd_stats_y->rate;
+  int choose_skip = rd_stats->skip;
+  if (!choose_skip && !xd->lossless[mbmi->segment_id]) {
+    const int64_t rdcost_no_skip = RDCOST(
+        x->rdmult, rd_stats_y->rate + rd_stats_uv->rate + skip_flag_cost[0],
+        rd_stats->dist);
+    const int64_t rdcost_skip =
+        RDCOST(x->rdmult, skip_flag_cost[1], rd_stats->sse);
+    if (rdcost_no_skip >= rdcost_skip) choose_skip = 1;
+  }
+  if (choose_skip) {
     rd_stats_y->rate = 0;
     rd_stats_uv->rate = 0;
+    rd_stats->rate = mode_rate + skip_flag_cost[1];
     rd_stats->dist = rd_stats->sse;
     rd_stats_y->dist = rd_stats_y->sse;
     rd_stats_uv->dist = rd_stats_uv->sse;
-    rd_stats->rate += skip_flag_cost[1];
     mbmi->skip = 1;
-    // here mbmi->skip temporarily plays a role as what this_skip2 does
-
-    const int64_t tmprd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
-    if (tmprd > ref_best_rd) return 0;
-  } else if (!xd->lossless[mbmi->segment_id] &&
-             (RDCOST(x->rdmult,
-                     rd_stats_y->rate + rd_stats_uv->rate + skip_flag_cost[0],
-                     rd_stats->dist) >=
-              RDCOST(x->rdmult, skip_flag_cost[1], rd_stats->sse))) {
-    rd_stats->rate -= rd_stats_uv->rate + rd_stats_y->rate;
-    rd_stats->rate += skip_flag_cost[1];
-    rd_stats->dist = rd_stats->sse;
-    rd_stats_y->dist = rd_stats_y->sse;
-    rd_stats_uv->dist = rd_stats_uv->sse;
-    rd_stats_y->rate = 0;
-    rd_stats_uv->rate = 0;
-    mbmi->skip = 1;
+    if (rd_stats->skip) {
+      const int64_t tmprd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
+      if (tmprd > ref_best_rd) return 0;
+    }
   } else {
     rd_stats->rate += skip_flag_cost[0];
     mbmi->skip = 0;