Move the tx pruning flags into MACROBLOCK

So they can be generated at prediction block, and then easily
accessed by transform block.

Change-Id: I376042e8d57e00586d3cf90e237544e705b77e8b
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 0b334cf..4e6b8ee 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -365,6 +365,8 @@
   int comp_idx_cost[COMP_INDEX_CONTEXTS][2];
   int comp_group_idx_cost[COMP_GROUP_IDX_CONTEXTS][2];
 #endif  // CONFIG_JNT_COMP
+  // Bit flags for pruning tx type search, tx split, etc.
+  int tx_search_prune[EXT_TX_SET_TYPES];
 };
 
 static INLINE int is_rect_tx_allowed_bsize(BLOCK_SIZE bsize) {
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index c6b0c16..dc07334 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -1255,20 +1255,15 @@
   return (score > av1_prune_tx_split_thresholds[bidx]);
 }
 
-static int prune_tx_2D(BLOCK_SIZE bsize, const MACROBLOCK *x, int tx_set_type,
-                       int tx_type_pruning_aggressiveness,
-                       int use_tx_split_prune) {
-  if (bsize >= BLOCK_32X32) return 0;
+static void prune_tx_2D(BLOCK_SIZE bsize, MACROBLOCK *x,
+                        TX_TYPE_PRUNE_MODE prune_mode, int use_tx_split_prune) {
+  if (bsize >= BLOCK_32X32) return;
   aom_clear_system_state();
   const struct macroblock_plane *const p = &x->plane[0];
-  const int bidx = AOMMAX(bsize - BLOCK_4X4, 0);
-  const float score_thresh =
-      av1_prune_2D_adaptive_thresholds[bidx]
-                                      [tx_type_pruning_aggressiveness - 1];
   float hfeatures[16], vfeatures[16];
   float hscores[4], vscores[4];
   float scores_2D[16];
-  int tx_type_table_2D[16] = {
+  const int tx_type_table_2D[16] = {
     DCT_DCT,      DCT_ADST,      DCT_FLIPADST,      V_DCT,
     ADST_DCT,     ADST_ADST,     ADST_FLIPADST,     V_ADST,
     FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
@@ -1284,7 +1279,7 @@
   get_horver_correlation_full(p->src_diff, bw, bw, bh,
                               &hfeatures[hfeatures_num - 1],
                               &vfeatures[vfeatures_num - 1]);
-
+  const int bidx = AOMMAX(bsize - BLOCK_4X4, 0);
   const float *fc1_hor = av1_prune_2D_learned_weights_hor[bidx];
   const float *b1_hor =
       fc1_hor + av1_prune_2D_num_hidden_units_hor[bidx] * hfeatures_num;
@@ -1314,22 +1309,43 @@
   score_2D_average /= 16;
   score_2D_transform_pow8(scores_2D, (20 - score_2D_average));
 
-  // Always keep the TX type with the highest score, prune all others with
-  // score below score_thresh.
-  int max_score_i = 0;
-  float max_score = 0.0f;
-  for (int i = 0; i < 16; i++) {
-    if (scores_2D[i] > max_score &&
-        av1_ext_tx_used[tx_set_type][tx_type_table_2D[i]]) {
-      max_score = scores_2D[i];
-      max_score_i = i;
+  // TODO(huisu@google.com): support more tx set types.
+  const int tx_set_types[2] = { EXT_TX_SET_ALL16, EXT_TX_SET_DTT9_IDTX_1DDCT };
+  for (int tx_set_idx = 0; tx_set_idx < 2; ++tx_set_idx) {
+    const int tx_set_type = tx_set_types[tx_set_idx];
+    // Always keep the TX type with the highest score, prune all others with
+    // score below score_thresh.
+    int max_score_i = 0;
+    float max_score = 0.0f;
+    for (int i = 0; i < 16; i++) {
+      if (scores_2D[i] > max_score &&
+          av1_ext_tx_used[tx_set_type][tx_type_table_2D[i]]) {
+        max_score = scores_2D[i];
+        max_score_i = i;
+      }
     }
-  }
 
-  int prune_bitmask = 0;
-  for (int i = 0; i < 16; i++) {
-    if (scores_2D[i] < score_thresh && i != max_score_i)
-      prune_bitmask |= (1 << tx_type_table_2D[i]);
+    int pruning_aggressiveness = 0;
+    if (prune_mode == PRUNE_2D_ACCURATE) {
+      if (tx_set_type == EXT_TX_SET_ALL16)
+        pruning_aggressiveness = 6;
+      else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
+        pruning_aggressiveness = 4;
+    } else if (prune_mode == PRUNE_2D_FAST) {
+      if (tx_set_type == EXT_TX_SET_ALL16)
+        pruning_aggressiveness = 10;
+      else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
+        pruning_aggressiveness = 7;
+    }
+    const float score_thresh =
+        av1_prune_2D_adaptive_thresholds[bidx][pruning_aggressiveness - 1];
+
+    int prune_bitmask = 0;
+    for (int i = 0; i < 16; i++) {
+      if (scores_2D[i] < score_thresh && i != max_score_i)
+        prune_bitmask |= (1 << tx_type_table_2D[i]);
+    }
+    x->tx_search_prune[tx_set_type] = prune_bitmask;
   }
 
   // Also apply TX size pruning if it's turned on. The value
@@ -1342,51 +1358,42 @@
         prune_tx_split(bsize, p->src_diff, hfeatures[hfeatures_num - 1],
                        vfeatures[vfeatures_num - 1]);
   }
-  prune_bitmask |= (prune_tx_split_flag << TX_TYPES);
-  return prune_bitmask;
+  x->tx_search_prune[0] |= (prune_tx_split_flag << TX_TYPES);
 }
 
-static int prune_tx(const AV1_COMP *cpi, BLOCK_SIZE bsize, MACROBLOCK *x,
-                    const MACROBLOCKD *const xd, int tx_set_type,
-                    int use_tx_split_prune) {
+static void prune_tx(const AV1_COMP *cpi, BLOCK_SIZE bsize, MACROBLOCK *x,
+                     const MACROBLOCKD *const xd, int tx_set_type,
+                     int use_tx_split_prune) {
   int tx_set = ext_tx_set_index[1][tx_set_type];
   assert(tx_set >= 0);
+  av1_zero(x->tx_search_prune);
   const int *tx_set_1D = ext_tx_used_inter_1D[tx_set];
-
   switch (cpi->sf.tx_type_search.prune_mode) {
-    case NO_PRUNE: return 0; break;
+    case NO_PRUNE: return;
     case PRUNE_ONE:
-      if (!(tx_set_1D[FLIPADST_1D] & tx_set_1D[ADST_1D])) return 0;
-      return prune_one_for_sby(cpi, bsize, x, xd);
+      if (!(tx_set_1D[FLIPADST_1D] & tx_set_1D[ADST_1D])) return;
+      x->tx_search_prune[tx_set_type] = prune_one_for_sby(cpi, bsize, x, xd);
       break;
     case PRUNE_TWO:
       if (!(tx_set_1D[FLIPADST_1D] & tx_set_1D[ADST_1D])) {
-        if (!(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D])) return 0;
-        return prune_two_for_sby(cpi, bsize, x, xd, 0, 1);
+        if (!(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D])) return;
+        x->tx_search_prune[tx_set_type] =
+            prune_two_for_sby(cpi, bsize, x, xd, 0, 1);
       }
-      if (!(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D]))
-        return prune_two_for_sby(cpi, bsize, x, xd, 1, 0);
-      return prune_two_for_sby(cpi, bsize, x, xd, 1, 1);
+      if (!(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D])) {
+        x->tx_search_prune[tx_set_type] =
+            prune_two_for_sby(cpi, bsize, x, xd, 1, 0);
+      }
+      x->tx_search_prune[tx_set_type] =
+          prune_two_for_sby(cpi, bsize, x, xd, 1, 1);
       break;
     case PRUNE_2D_ACCURATE:
-      if (tx_set_type == EXT_TX_SET_ALL16)
-        return prune_tx_2D(bsize, x, tx_set_type, 6, use_tx_split_prune);
-      else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
-        return prune_tx_2D(bsize, x, tx_set_type, 4, use_tx_split_prune);
-      else
-        return 0;
-      break;
     case PRUNE_2D_FAST:
-      if (tx_set_type == EXT_TX_SET_ALL16)
-        return prune_tx_2D(bsize, x, tx_set_type, 10, use_tx_split_prune);
-      else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
-        return prune_tx_2D(bsize, x, tx_set_type, 7, use_tx_split_prune);
-      else
-        return 0;
+      prune_tx_2D(bsize, x, cpi->sf.tx_type_search.prune_mode,
+                  use_tx_split_prune);
       break;
+    default: assert(0);
   }
-  assert(0);
-  return 0;
 }
 
 static int do_tx_type_search(TX_TYPE tx_type, int prune,
@@ -2319,7 +2326,7 @@
 }
 
 static int skip_txfm_search(const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bs,
-                            TX_TYPE tx_type, TX_SIZE tx_size, int prune) {
+                            TX_TYPE tx_type, TX_SIZE tx_size) {
   const MACROBLOCKD *const xd = &x->e_mbd;
   const MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
   const int is_inter = is_inter_block(mbmi);
@@ -2337,7 +2344,8 @@
   if (!av1_ext_tx_used[tx_set_type][tx_type]) return 1;
   if (is_inter) {
     if (cpi->sf.tx_type_search.prune_mode > NO_PRUNE) {
-      if (!do_tx_type_search(tx_type, prune, cpi->sf.tx_type_search.prune_mode))
+      if (!do_tx_type_search(tx_type, x->tx_search_prune[tx_set_type],
+                             cpi->sf.tx_type_search.prune_mode))
         return 1;
     }
   }
@@ -2371,7 +2379,6 @@
   int s0 = x->skip_cost[skip_ctx][0];
   int s1 = x->skip_cost[skip_ctx][1];
   const int is_inter = is_inter_block(mbmi);
-  int prune = 0;
   av1_invalid_rd_stats(rd_stats);
 
   mbmi->tx_size = tx_size_from_tx_mode(bs, cm->tx_mode, is_inter);
@@ -2381,7 +2388,7 @@
 
   if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE &&
       !x->use_default_inter_tx_type) {
-    prune = prune_tx(cpi, bs, x, xd, tx_set_type, 0);
+    prune_tx(cpi, bs, x, xd, tx_set_type, 0);
   }
 #if CONFIG_FILTER_INTRA
   if (skip_invalid_tx_size_for_filter_intra(mbmi, AOM_PLANE_Y, rd_stats)) {
@@ -2399,7 +2406,7 @@
             tx_type != get_default_tx_type(0, xd, mbmi->tx_size))
           continue;
         if (cpi->sf.tx_type_search.prune_mode > NO_PRUNE) {
-          if (!do_tx_type_search(tx_type, prune,
+          if (!do_tx_type_search(tx_type, x->tx_search_prune[tx_set_type],
                                  cpi->sf.tx_type_search.prune_mode))
             continue;
         }
@@ -2440,6 +2447,8 @@
                      mbmi->tx_size, cpi->sf.use_fast_coef_costing);
   }
   mbmi->tx_type = best_tx_type;
+  // Reset the pruning flags.
+  av1_zero(x->tx_search_prune);
 }
 
 static void choose_smallest_tx_size(const AV1_COMP *const cpi, MACROBLOCK *x,
@@ -2507,10 +2516,9 @@
     depth = MAX_TX_DEPTH;
   }
 
-  int prune = 0;
   if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE &&
       !x->use_default_inter_tx_type) {
-    prune = prune_tx(cpi, bs, x, xd, EXT_TX_SET_ALL16, 0);
+    prune_tx(cpi, bs, x, xd, EXT_TX_SET_ALL16, 0);
   }
 
   last_rd = INT64_MAX;
@@ -2526,7 +2534,7 @@
     TX_TYPE tx_type;
     for (tx_type = tx_start; tx_type < tx_end; ++tx_type) {
       RD_STATS this_rd_stats;
-      if (skip_txfm_search(cpi, x, bs, tx_type, n, prune)) continue;
+      if (skip_txfm_search(cpi, x, bs, tx_type, n)) continue;
 
       if (mbmi->ref_mv_idx > 0) x->rd_model = LOW_TXFM_RD;
       rd = txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs, tx_type, n);
@@ -2578,6 +2586,8 @@
   memcpy(x->blk_skip[0], best_blk_skip, sizeof(best_blk_skip[0]) * n4);
 
   mbmi->min_tx_size = mbmi->tx_size;
+  // Reset the pruning flags.
+  av1_zero(x->tx_search_prune);
 }
 
 static void super_block_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
@@ -3820,7 +3830,6 @@
                             TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
                             RD_STATS *rd_stats, int64_t ref_best_rd,
                             int *is_cost_valid, int fast,
-                            int tx_split_prune_flag,
                             TX_SIZE_RD_INFO_NODE *rd_info_node) {
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
@@ -3920,6 +3929,9 @@
 #endif
   }
 
+  int tx_split_prune_flag = 0;
+  if (cpi->sf.tx_type_search.prune_mode >= PRUNE_2D_ACCURATE)
+    tx_split_prune_flag = ((x->tx_search_prune[0] >> TX_TYPES) & 1);
   if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH && tx_split_prune_flag == 0) {
     const TX_SIZE sub_txs = sub_tx_size_map[1][tx_size];
     const int bsw = tx_size_wide_unit[sub_txs];
@@ -3947,7 +3959,7 @@
         select_tx_block(
             cpi, x, offsetr, offsetc, plane, block, sub_txs, depth + 1,
             plane_bsize, ta, tl, tx_above, tx_left, &this_rd_stats,
-            ref_best_rd - tmp_rd, &this_cost_valid, fast, 0,
+            ref_best_rd - tmp_rd, &this_cost_valid, fast,
             (rd_info_node != NULL) ? rd_info_node->children[blk_idx] : NULL);
 
 #if CONFIG_DIST_8X8
@@ -4099,7 +4111,6 @@
 static void select_inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
                                    RD_STATS *rd_stats, BLOCK_SIZE bsize,
                                    int64_t ref_best_rd, int fast,
-                                   int tx_split_prune_flag,
                                    TX_SIZE_RD_INFO_NODE *rd_info_tree) {
   MACROBLOCKD *const xd = &x->e_mbd;
   int is_cost_valid = 1;
@@ -4138,7 +4149,7 @@
         select_tx_block(cpi, x, idy, idx, 0, block, max_tx_size, init_depth,
                         plane_bsize, ctxa, ctxl, tx_above, tx_left,
                         &pn_rd_stats, ref_best_rd - this_rd, &is_cost_valid,
-                        fast, tx_split_prune_flag, rd_info_tree);
+                        fast, rd_info_tree);
         if (!is_cost_valid || pn_rd_stats.rate == INT_MAX) {
           av1_invalid_rd_stats(rd_stats);
           return;
@@ -4172,7 +4183,6 @@
                                        RD_STATS *rd_stats, BLOCK_SIZE bsize,
                                        int mi_row, int mi_col,
                                        int64_t ref_best_rd, TX_TYPE tx_type,
-                                       int tx_split_prune_flag,
                                        TX_SIZE_RD_INFO_NODE *rd_info_tree) {
   const int fast = cpi->sf.tx_size_search_method > USE_FULL_RD;
   const AV1_COMMON *const cm = &cpi->common;
@@ -4199,7 +4209,7 @@
 
   mbmi->tx_type = tx_type;
   select_inter_block_yrd(cpi, x, rd_stats, bsize, ref_best_rd, fast,
-                         tx_split_prune_flag, rd_info_tree);
+                         rd_info_tree);
   if (rd_stats->rate == INT_MAX) return INT64_MAX;
 
   mbmi->min_tx_size = mbmi->inter_tx_size[0][0];
@@ -4800,7 +4810,6 @@
 #endif
   const int n4 = bsize_to_num_blk(bsize);
   int idx, idy;
-  int prune = 0;
   // Get the tx_size 1 level down
   const TX_SIZE min_tx_size =
       sub_tx_size_map[1][max_txsize_rect_lookup[1][bsize]];
@@ -4851,16 +4860,12 @@
 
   if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE &&
       !x->use_default_inter_tx_type && !xd->lossless[mbmi->segment_id]) {
-    prune = prune_tx(cpi, bsize, x, xd, tx_set_type,
-                     cpi->sf.tx_type_search.use_tx_size_pruning);
+    prune_tx(cpi, bsize, x, xd, tx_set_type,
+             cpi->sf.tx_type_search.use_tx_size_pruning);
   }
 
   int found = 0;
 
-  int tx_split_prune_flag = 0;
-  if (is_inter && cpi->sf.tx_type_search.prune_mode >= PRUNE_2D_ACCURATE)
-    tx_split_prune_flag = ((prune >> TX_TYPES) & 1);
-
   for (tx_type = txk_start; tx_type < txk_end; ++tx_type) {
     RD_STATS this_rd_stats;
     av1_init_rd_stats(&this_rd_stats);
@@ -4868,7 +4873,7 @@
 #if !CONFIG_TXK_SEL
     if (is_inter) {
       if (cpi->sf.tx_type_search.prune_mode > NO_PRUNE) {
-        if (!do_tx_type_search(tx_type, prune,
+        if (!do_tx_type_search(tx_type, x->tx_search_prune[tx_set_type],
                                cpi->sf.tx_type_search.prune_mode))
           continue;
       }
@@ -4882,7 +4887,7 @@
       if (tx_type != DCT_DCT) continue;
 
     rd = select_tx_size_fix_type(cpi, x, &this_rd_stats, bsize, mi_row, mi_col,
-                                 ref_best_rd, tx_type, tx_split_prune_flag,
+                                 ref_best_rd, tx_type,
                                  found_rd_info ? matched_rd_info : NULL);
 #if !CONFIG_TXK_SEL
     // If the current tx_type is not included in the tx_set for the smallest
@@ -4928,6 +4933,9 @@
 #endif
   }
 
+  // Reset the pruning flags.
+  av1_zero(x->tx_search_prune);
+
   // We should always find at least one candidate unless ref_best_rd is less
   // than INT64_MAX (in which case, all the calls to select_tx_size_fix_type
   // might have failed to find something better)