Move the tx pruning flags into MACROBLOCK
So they can be generated at prediction block, and then easily
accessed by transform block.
Change-Id: I376042e8d57e00586d3cf90e237544e705b77e8b
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 0b334cf..4e6b8ee 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -365,6 +365,8 @@
int comp_idx_cost[COMP_INDEX_CONTEXTS][2];
int comp_group_idx_cost[COMP_GROUP_IDX_CONTEXTS][2];
#endif // CONFIG_JNT_COMP
+ // Bit flags for pruning tx type search, tx split, etc.
+ int tx_search_prune[EXT_TX_SET_TYPES];
};
static INLINE int is_rect_tx_allowed_bsize(BLOCK_SIZE bsize) {
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index c6b0c16..dc07334 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -1255,20 +1255,15 @@
return (score > av1_prune_tx_split_thresholds[bidx]);
}
-static int prune_tx_2D(BLOCK_SIZE bsize, const MACROBLOCK *x, int tx_set_type,
- int tx_type_pruning_aggressiveness,
- int use_tx_split_prune) {
- if (bsize >= BLOCK_32X32) return 0;
+static void prune_tx_2D(BLOCK_SIZE bsize, MACROBLOCK *x,
+ TX_TYPE_PRUNE_MODE prune_mode, int use_tx_split_prune) {
+ if (bsize >= BLOCK_32X32) return;
aom_clear_system_state();
const struct macroblock_plane *const p = &x->plane[0];
- const int bidx = AOMMAX(bsize - BLOCK_4X4, 0);
- const float score_thresh =
- av1_prune_2D_adaptive_thresholds[bidx]
- [tx_type_pruning_aggressiveness - 1];
float hfeatures[16], vfeatures[16];
float hscores[4], vscores[4];
float scores_2D[16];
- int tx_type_table_2D[16] = {
+ const int tx_type_table_2D[16] = {
DCT_DCT, DCT_ADST, DCT_FLIPADST, V_DCT,
ADST_DCT, ADST_ADST, ADST_FLIPADST, V_ADST,
FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
@@ -1284,7 +1279,7 @@
get_horver_correlation_full(p->src_diff, bw, bw, bh,
&hfeatures[hfeatures_num - 1],
&vfeatures[vfeatures_num - 1]);
-
+ const int bidx = AOMMAX(bsize - BLOCK_4X4, 0);
const float *fc1_hor = av1_prune_2D_learned_weights_hor[bidx];
const float *b1_hor =
fc1_hor + av1_prune_2D_num_hidden_units_hor[bidx] * hfeatures_num;
@@ -1314,22 +1309,43 @@
score_2D_average /= 16;
score_2D_transform_pow8(scores_2D, (20 - score_2D_average));
- // Always keep the TX type with the highest score, prune all others with
- // score below score_thresh.
- int max_score_i = 0;
- float max_score = 0.0f;
- for (int i = 0; i < 16; i++) {
- if (scores_2D[i] > max_score &&
- av1_ext_tx_used[tx_set_type][tx_type_table_2D[i]]) {
- max_score = scores_2D[i];
- max_score_i = i;
+ // TODO(huisu@google.com): support more tx set types.
+ const int tx_set_types[2] = { EXT_TX_SET_ALL16, EXT_TX_SET_DTT9_IDTX_1DDCT };
+ for (int tx_set_idx = 0; tx_set_idx < 2; ++tx_set_idx) {
+ const int tx_set_type = tx_set_types[tx_set_idx];
+ // Always keep the TX type with the highest score, prune all others with
+ // score below score_thresh.
+ int max_score_i = 0;
+ float max_score = 0.0f;
+ for (int i = 0; i < 16; i++) {
+ if (scores_2D[i] > max_score &&
+ av1_ext_tx_used[tx_set_type][tx_type_table_2D[i]]) {
+ max_score = scores_2D[i];
+ max_score_i = i;
+ }
}
- }
- int prune_bitmask = 0;
- for (int i = 0; i < 16; i++) {
- if (scores_2D[i] < score_thresh && i != max_score_i)
- prune_bitmask |= (1 << tx_type_table_2D[i]);
+ int pruning_aggressiveness = 0;
+ if (prune_mode == PRUNE_2D_ACCURATE) {
+ if (tx_set_type == EXT_TX_SET_ALL16)
+ pruning_aggressiveness = 6;
+ else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
+ pruning_aggressiveness = 4;
+ } else if (prune_mode == PRUNE_2D_FAST) {
+ if (tx_set_type == EXT_TX_SET_ALL16)
+ pruning_aggressiveness = 10;
+ else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
+ pruning_aggressiveness = 7;
+ }
+ const float score_thresh =
+ av1_prune_2D_adaptive_thresholds[bidx][pruning_aggressiveness - 1];
+
+ int prune_bitmask = 0;
+ for (int i = 0; i < 16; i++) {
+ if (scores_2D[i] < score_thresh && i != max_score_i)
+ prune_bitmask |= (1 << tx_type_table_2D[i]);
+ }
+ x->tx_search_prune[tx_set_type] = prune_bitmask;
}
// Also apply TX size pruning if it's turned on. The value
@@ -1342,51 +1358,42 @@
prune_tx_split(bsize, p->src_diff, hfeatures[hfeatures_num - 1],
vfeatures[vfeatures_num - 1]);
}
- prune_bitmask |= (prune_tx_split_flag << TX_TYPES);
- return prune_bitmask;
+ x->tx_search_prune[0] |= (prune_tx_split_flag << TX_TYPES);
}
-static int prune_tx(const AV1_COMP *cpi, BLOCK_SIZE bsize, MACROBLOCK *x,
- const MACROBLOCKD *const xd, int tx_set_type,
- int use_tx_split_prune) {
+static void prune_tx(const AV1_COMP *cpi, BLOCK_SIZE bsize, MACROBLOCK *x,
+ const MACROBLOCKD *const xd, int tx_set_type,
+ int use_tx_split_prune) {
int tx_set = ext_tx_set_index[1][tx_set_type];
assert(tx_set >= 0);
+ av1_zero(x->tx_search_prune);
const int *tx_set_1D = ext_tx_used_inter_1D[tx_set];
-
switch (cpi->sf.tx_type_search.prune_mode) {
- case NO_PRUNE: return 0; break;
+ case NO_PRUNE: return;
case PRUNE_ONE:
- if (!(tx_set_1D[FLIPADST_1D] & tx_set_1D[ADST_1D])) return 0;
- return prune_one_for_sby(cpi, bsize, x, xd);
+ if (!(tx_set_1D[FLIPADST_1D] & tx_set_1D[ADST_1D])) return;
+ x->tx_search_prune[tx_set_type] = prune_one_for_sby(cpi, bsize, x, xd);
break;
case PRUNE_TWO:
if (!(tx_set_1D[FLIPADST_1D] & tx_set_1D[ADST_1D])) {
- if (!(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D])) return 0;
- return prune_two_for_sby(cpi, bsize, x, xd, 0, 1);
+ if (!(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D])) return;
+ x->tx_search_prune[tx_set_type] =
+ prune_two_for_sby(cpi, bsize, x, xd, 0, 1);
}
- if (!(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D]))
- return prune_two_for_sby(cpi, bsize, x, xd, 1, 0);
- return prune_two_for_sby(cpi, bsize, x, xd, 1, 1);
+ if (!(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D])) {
+ x->tx_search_prune[tx_set_type] =
+ prune_two_for_sby(cpi, bsize, x, xd, 1, 0);
+ }
+ x->tx_search_prune[tx_set_type] =
+ prune_two_for_sby(cpi, bsize, x, xd, 1, 1);
break;
case PRUNE_2D_ACCURATE:
- if (tx_set_type == EXT_TX_SET_ALL16)
- return prune_tx_2D(bsize, x, tx_set_type, 6, use_tx_split_prune);
- else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
- return prune_tx_2D(bsize, x, tx_set_type, 4, use_tx_split_prune);
- else
- return 0;
- break;
case PRUNE_2D_FAST:
- if (tx_set_type == EXT_TX_SET_ALL16)
- return prune_tx_2D(bsize, x, tx_set_type, 10, use_tx_split_prune);
- else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
- return prune_tx_2D(bsize, x, tx_set_type, 7, use_tx_split_prune);
- else
- return 0;
+ prune_tx_2D(bsize, x, cpi->sf.tx_type_search.prune_mode,
+ use_tx_split_prune);
break;
+ default: assert(0);
}
- assert(0);
- return 0;
}
static int do_tx_type_search(TX_TYPE tx_type, int prune,
@@ -2319,7 +2326,7 @@
}
static int skip_txfm_search(const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bs,
- TX_TYPE tx_type, TX_SIZE tx_size, int prune) {
+ TX_TYPE tx_type, TX_SIZE tx_size) {
const MACROBLOCKD *const xd = &x->e_mbd;
const MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
const int is_inter = is_inter_block(mbmi);
@@ -2337,7 +2344,8 @@
if (!av1_ext_tx_used[tx_set_type][tx_type]) return 1;
if (is_inter) {
if (cpi->sf.tx_type_search.prune_mode > NO_PRUNE) {
- if (!do_tx_type_search(tx_type, prune, cpi->sf.tx_type_search.prune_mode))
+ if (!do_tx_type_search(tx_type, x->tx_search_prune[tx_set_type],
+ cpi->sf.tx_type_search.prune_mode))
return 1;
}
}
@@ -2371,7 +2379,6 @@
int s0 = x->skip_cost[skip_ctx][0];
int s1 = x->skip_cost[skip_ctx][1];
const int is_inter = is_inter_block(mbmi);
- int prune = 0;
av1_invalid_rd_stats(rd_stats);
mbmi->tx_size = tx_size_from_tx_mode(bs, cm->tx_mode, is_inter);
@@ -2381,7 +2388,7 @@
if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE &&
!x->use_default_inter_tx_type) {
- prune = prune_tx(cpi, bs, x, xd, tx_set_type, 0);
+ prune_tx(cpi, bs, x, xd, tx_set_type, 0);
}
#if CONFIG_FILTER_INTRA
if (skip_invalid_tx_size_for_filter_intra(mbmi, AOM_PLANE_Y, rd_stats)) {
@@ -2399,7 +2406,7 @@
tx_type != get_default_tx_type(0, xd, mbmi->tx_size))
continue;
if (cpi->sf.tx_type_search.prune_mode > NO_PRUNE) {
- if (!do_tx_type_search(tx_type, prune,
+ if (!do_tx_type_search(tx_type, x->tx_search_prune[tx_set_type],
cpi->sf.tx_type_search.prune_mode))
continue;
}
@@ -2440,6 +2447,8 @@
mbmi->tx_size, cpi->sf.use_fast_coef_costing);
}
mbmi->tx_type = best_tx_type;
+ // Reset the pruning flags.
+ av1_zero(x->tx_search_prune);
}
static void choose_smallest_tx_size(const AV1_COMP *const cpi, MACROBLOCK *x,
@@ -2507,10 +2516,9 @@
depth = MAX_TX_DEPTH;
}
- int prune = 0;
if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE &&
!x->use_default_inter_tx_type) {
- prune = prune_tx(cpi, bs, x, xd, EXT_TX_SET_ALL16, 0);
+ prune_tx(cpi, bs, x, xd, EXT_TX_SET_ALL16, 0);
}
last_rd = INT64_MAX;
@@ -2526,7 +2534,7 @@
TX_TYPE tx_type;
for (tx_type = tx_start; tx_type < tx_end; ++tx_type) {
RD_STATS this_rd_stats;
- if (skip_txfm_search(cpi, x, bs, tx_type, n, prune)) continue;
+ if (skip_txfm_search(cpi, x, bs, tx_type, n)) continue;
if (mbmi->ref_mv_idx > 0) x->rd_model = LOW_TXFM_RD;
rd = txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs, tx_type, n);
@@ -2578,6 +2586,8 @@
memcpy(x->blk_skip[0], best_blk_skip, sizeof(best_blk_skip[0]) * n4);
mbmi->min_tx_size = mbmi->tx_size;
+ // Reset the pruning flags.
+ av1_zero(x->tx_search_prune);
}
static void super_block_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
@@ -3820,7 +3830,6 @@
TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
RD_STATS *rd_stats, int64_t ref_best_rd,
int *is_cost_valid, int fast,
- int tx_split_prune_flag,
TX_SIZE_RD_INFO_NODE *rd_info_node) {
MACROBLOCKD *const xd = &x->e_mbd;
MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
@@ -3920,6 +3929,9 @@
#endif
}
+ int tx_split_prune_flag = 0;
+ if (cpi->sf.tx_type_search.prune_mode >= PRUNE_2D_ACCURATE)
+ tx_split_prune_flag = ((x->tx_search_prune[0] >> TX_TYPES) & 1);
if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH && tx_split_prune_flag == 0) {
const TX_SIZE sub_txs = sub_tx_size_map[1][tx_size];
const int bsw = tx_size_wide_unit[sub_txs];
@@ -3947,7 +3959,7 @@
select_tx_block(
cpi, x, offsetr, offsetc, plane, block, sub_txs, depth + 1,
plane_bsize, ta, tl, tx_above, tx_left, &this_rd_stats,
- ref_best_rd - tmp_rd, &this_cost_valid, fast, 0,
+ ref_best_rd - tmp_rd, &this_cost_valid, fast,
(rd_info_node != NULL) ? rd_info_node->children[blk_idx] : NULL);
#if CONFIG_DIST_8X8
@@ -4099,7 +4111,6 @@
static void select_inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
RD_STATS *rd_stats, BLOCK_SIZE bsize,
int64_t ref_best_rd, int fast,
- int tx_split_prune_flag,
TX_SIZE_RD_INFO_NODE *rd_info_tree) {
MACROBLOCKD *const xd = &x->e_mbd;
int is_cost_valid = 1;
@@ -4138,7 +4149,7 @@
select_tx_block(cpi, x, idy, idx, 0, block, max_tx_size, init_depth,
plane_bsize, ctxa, ctxl, tx_above, tx_left,
&pn_rd_stats, ref_best_rd - this_rd, &is_cost_valid,
- fast, tx_split_prune_flag, rd_info_tree);
+ fast, rd_info_tree);
if (!is_cost_valid || pn_rd_stats.rate == INT_MAX) {
av1_invalid_rd_stats(rd_stats);
return;
@@ -4172,7 +4183,6 @@
RD_STATS *rd_stats, BLOCK_SIZE bsize,
int mi_row, int mi_col,
int64_t ref_best_rd, TX_TYPE tx_type,
- int tx_split_prune_flag,
TX_SIZE_RD_INFO_NODE *rd_info_tree) {
const int fast = cpi->sf.tx_size_search_method > USE_FULL_RD;
const AV1_COMMON *const cm = &cpi->common;
@@ -4199,7 +4209,7 @@
mbmi->tx_type = tx_type;
select_inter_block_yrd(cpi, x, rd_stats, bsize, ref_best_rd, fast,
- tx_split_prune_flag, rd_info_tree);
+ rd_info_tree);
if (rd_stats->rate == INT_MAX) return INT64_MAX;
mbmi->min_tx_size = mbmi->inter_tx_size[0][0];
@@ -4800,7 +4810,6 @@
#endif
const int n4 = bsize_to_num_blk(bsize);
int idx, idy;
- int prune = 0;
// Get the tx_size 1 level down
const TX_SIZE min_tx_size =
sub_tx_size_map[1][max_txsize_rect_lookup[1][bsize]];
@@ -4851,16 +4860,12 @@
if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE &&
!x->use_default_inter_tx_type && !xd->lossless[mbmi->segment_id]) {
- prune = prune_tx(cpi, bsize, x, xd, tx_set_type,
- cpi->sf.tx_type_search.use_tx_size_pruning);
+ prune_tx(cpi, bsize, x, xd, tx_set_type,
+ cpi->sf.tx_type_search.use_tx_size_pruning);
}
int found = 0;
- int tx_split_prune_flag = 0;
- if (is_inter && cpi->sf.tx_type_search.prune_mode >= PRUNE_2D_ACCURATE)
- tx_split_prune_flag = ((prune >> TX_TYPES) & 1);
-
for (tx_type = txk_start; tx_type < txk_end; ++tx_type) {
RD_STATS this_rd_stats;
av1_init_rd_stats(&this_rd_stats);
@@ -4868,7 +4873,7 @@
#if !CONFIG_TXK_SEL
if (is_inter) {
if (cpi->sf.tx_type_search.prune_mode > NO_PRUNE) {
- if (!do_tx_type_search(tx_type, prune,
+ if (!do_tx_type_search(tx_type, x->tx_search_prune[tx_set_type],
cpi->sf.tx_type_search.prune_mode))
continue;
}
@@ -4882,7 +4887,7 @@
if (tx_type != DCT_DCT) continue;
rd = select_tx_size_fix_type(cpi, x, &this_rd_stats, bsize, mi_row, mi_col,
- ref_best_rd, tx_type, tx_split_prune_flag,
+ ref_best_rd, tx_type,
found_rd_info ? matched_rd_info : NULL);
#if !CONFIG_TXK_SEL
// If the current tx_type is not included in the tx_set for the smallest
@@ -4928,6 +4933,9 @@
#endif
}
+ // Reset the pruning flags.
+ av1_zero(x->tx_search_prune);
+
// We should always find at least one candidate unless ref_best_rd is less
// than INT64_MAX (in which case, all the calls to select_tx_size_fix_type
// might have failed to find something better)