Add some comments for av1_pick_tx_size_type_yrd()

Also includes some code optimizations.

BUG=aomedia:2617

Change-Id: Ib6ff03c3247d938396e1532065b34b93d8d4f8ee
diff --git a/av1/encoder/tx_search.c b/av1/encoder/tx_search.c
index 4a1e985..6967248 100644
--- a/av1/encoder/tx_search.c
+++ b/av1/encoder/tx_search.c
@@ -3189,41 +3189,50 @@
   return rd;
 }
 
-// Search for best transform size and type for luma inter blocks.
+// Return 1 to terminate transform search early. The decision is made based on
+// the comparison with the reference RD cost and the model-estimated RD cost.
+static AOM_INLINE int model_based_tx_search_prune(const AV1_COMP *cpi,
+                                                  MACROBLOCK *x,
+                                                  BLOCK_SIZE bsize,
+                                                  int64_t ref_best_rd) {
+  const int level = cpi->sf.tx_sf.model_based_prune_tx_search_level;
+  assert(level >= 0 && level <= 2);
+  int model_rate;
+  int64_t model_dist;
+  int model_skip;
+  MACROBLOCKD *const xd = &x->e_mbd;
+  model_rd_sb_fn[MODELRD_TYPE_TX_SEARCH_PRUNE](
+      cpi, bsize, x, xd, 0, 0, &model_rate, &model_dist, &model_skip, NULL,
+      NULL, NULL, NULL);
+  if (model_skip) return 0;
+  const int64_t model_rd = RDCOST(x->rdmult, model_rate, model_dist);
+  // TODO(debargha, urvang): Improve the model and make the check below
+  // tighter.
+  static const int prune_factor_by8[] = { 3, 5 };
+  const int factor = prune_factor_by8[level - 1];
+  return ((model_rd * factor) >> 3) > ref_best_rd;
+}
+
+// Search for best transform size and type for luma inter blocks. The best
+// transform size and type, if found, will be saved in the MB_MODE_INFO
+// structure, and the corresponding RD stats will be saved in rd_stats.
 void av1_pick_tx_size_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
                                RD_STATS *rd_stats, BLOCK_SIZE bsize,
                                int64_t ref_best_rd) {
-  const AV1_COMMON *cm = &cpi->common;
   MACROBLOCKD *const xd = &x->e_mbd;
   assert(is_inter_block(xd->mi[0]));
 
   av1_invalid_rd_stats(rd_stats);
 
+  // If modeled RD cost is a lot worse than the best so far, terminate early.
   if (cpi->sf.tx_sf.model_based_prune_tx_search_level &&
       ref_best_rd != INT64_MAX) {
-    int model_rate;
-    int64_t model_dist;
-    int model_skip;
-    model_rd_sb_fn[MODELRD_TYPE_TX_SEARCH_PRUNE](
-        cpi, bsize, x, xd, 0, 0, &model_rate, &model_dist, &model_skip, NULL,
-        NULL, NULL, NULL);
-    const int64_t model_rd = RDCOST(x->rdmult, model_rate, model_dist);
-    // If the modeled rd is a lot worse than the best so far, breakout.
-    // TODO(debargha, urvang): Improve the model and make the check below
-    // tighter.
-    assert(cpi->sf.tx_sf.model_based_prune_tx_search_level >= 0 &&
-           cpi->sf.tx_sf.model_based_prune_tx_search_level <= 2);
-    static const int prune_factor_by8[] = { 3, 5 };
-    if (!model_skip &&
-        ((model_rd *
-          prune_factor_by8[cpi->sf.tx_sf.model_based_prune_tx_search_level -
-                           1]) >>
-         3) > ref_best_rd)
-      return;
+    if (model_based_tx_search_prune(cpi, x, bsize, ref_best_rd)) return;
   }
 
+  // Hashing based speed feature. If the hash of the prediction residue block is
+  // found in the hash table, use previous search results and terminate early.
   uint32_t hash = 0;
-  int32_t match_index = -1;
   MB_RD_RECORD *mb_rd_record = NULL;
   const int mi_row = x->e_mbd.mi_row;
   const int mi_col = x->e_mbd.mi_col;
@@ -3238,7 +3247,7 @@
   if (is_mb_rd_hash_enabled) {
     hash = get_block_residue_hash(x, bsize);
     mb_rd_record = &x->mb_rd_record;
-    match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash);
+    const int match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash);
     if (match_index != -1) {
       MB_RD_INFO *tx_rd_info = &mb_rd_record->tx_rd_info[match_index];
       fetch_tx_rd_info(n4, tx_rd_info, rd_stats, x);
@@ -3250,7 +3259,8 @@
   // context and terminate early.
   int64_t dist;
   if (x->predict_skip_level &&
-      predict_skip_flag(x, bsize, &dist, cm->features.reduced_tx_set_used)) {
+      predict_skip_flag(x, bsize, &dist,
+                        cpi->common.features.reduced_tx_set_used)) {
     set_skip_flag(x, rd_stats, bsize, dist);
     // Save the RD search results into tx_rd_record.
     if (is_mb_rd_hash_enabled)
@@ -3261,8 +3271,9 @@
   ++x->tx_search_count;
 #endif  // CONFIG_SPEED_STATS
 
-  // Precompute residual hashes and find existing or add new RD records to
-  // store and reuse rate and distortion values to speed up TX size search.
+  // Pre-compute residue hashes (transform block level) and find existing or
+  // add new RD records to store and reuse rate and distortion values to speed
+  // up TX size/type search.
   TXB_RD_INFO_NODE matched_rd_info[4 + 16 + 64];
   int found_rd_info = 0;
   if (ref_best_rd != INT64_MAX && within_border &&
@@ -3270,24 +3281,19 @@
     found_rd_info = find_tx_size_rd_records(x, bsize, matched_rd_info);
   }
 
-  int found = 0;
-  RD_STATS this_rd_stats;
-  av1_init_rd_stats(&this_rd_stats);
   const int64_t rd =
-      select_tx_size_and_type(cpi, x, &this_rd_stats, bsize, ref_best_rd,
+      select_tx_size_and_type(cpi, x, rd_stats, bsize, ref_best_rd,
                               found_rd_info ? matched_rd_info : NULL);
 
-  if (rd < INT64_MAX) {
-    *rd_stats = this_rd_stats;
-    found = 1;
+  if (rd == INT64_MAX) {
+    // We should always find at least one candidate unless ref_best_rd is less
+    // than INT64_MAX (in which case, all the calls to select_tx_size_fix_type
+    // might have failed to find something better)
+    assert(ref_best_rd != INT64_MAX);
+    av1_invalid_rd_stats(rd_stats);
+    return;
   }
 
-  // We should always find at least one candidate unless ref_best_rd is less
-  // than INT64_MAX (in which case, all the calls to select_tx_size_fix_type
-  // might have failed to find something better)
-  assert(IMPLIES(!found, ref_best_rd != INT64_MAX));
-  if (!found) return;
-
   // Save the RD search results into tx_rd_record.
   if (is_mb_rd_hash_enabled) {
     assert(mb_rd_record != NULL);