Move prune_tx_2D() to transform block level
With txk-sel merged, transform type pruning can be moved to transform
block level. This way the logic is more clear, and we can potentially
improve the accuracy of the model, and extend it to intra blocks.
Compression and encoding speed is roughly the same.
Change-Id: Id6d6a2639a392ebd7e72a67a70a5482ba67da1bc
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 26a827b..71cdcf8 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -330,6 +330,7 @@
// Bit flags for pruning tx type search, tx split, etc.
int tx_search_prune[EXT_TX_SET_TYPES];
int must_find_valid_partition;
+ int tx_split_prune_flag; // Flag to skip tx split RD search.
};
static INLINE int is_rect_tx_allowed_bsize(BLOCK_SIZE bsize) {
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index dd0eb05..df60836 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -1212,16 +1212,18 @@
return dst_score;
}
-static int prune_tx_split(BLOCK_SIZE bsize, const int16_t *diff, float hcorr,
- float vcorr) {
- if (bsize <= BLOCK_4X4 || bsize > BLOCK_16X16) return 0;
-
- float features[17];
+static void prune_tx_split(BLOCK_SIZE bsize, MACROBLOCK *x) {
+ if (bsize <= BLOCK_4X4 || bsize > BLOCK_16X16) return;
+ aom_clear_system_state();
+ const struct macroblock_plane *const p = &x->plane[0];
const int bw = block_size_wide[bsize], bh = block_size_high[bsize];
+ float features[17];
const int feature_num = (bw / 4) * (bh / 4) + 1;
assert(feature_num <= 17);
- get_2D_energy_distribution(diff, bw, bw, bh, features);
+ float hcorr, vcorr;
+ get_horver_correlation_full(p->src_diff, bw, bw, bh, &hcorr, &vcorr);
+ get_2D_energy_distribution(p->src_diff, bw, bw, bh, features);
features[feature_num - 2] = hcorr;
features[feature_num - 1] = vcorr;
@@ -1230,54 +1232,67 @@
const float *b1 =
fc1 + av1_prune_tx_split_num_hidden_units[bidx] * feature_num;
const float *fc2 = b1 + av1_prune_tx_split_num_hidden_units[bidx];
- float b2 = *(fc2 + av1_prune_tx_split_num_hidden_units[bidx]);
- float score =
+ const float b2 = *(fc2 + av1_prune_tx_split_num_hidden_units[bidx]);
+ const float score =
compute_tx_split_prune_score(features, feature_num, fc1, b1, fc2, b2,
av1_prune_tx_split_num_hidden_units[bidx]);
-
- return (score > av1_prune_tx_split_thresholds[bidx]);
+ x->tx_split_prune_flag = score > av1_prune_tx_split_thresholds[bidx];
}
-static void prune_tx_2D(BLOCK_SIZE bsize, MACROBLOCK *x,
- TX_TYPE_PRUNE_MODE prune_mode, int use_tx_split_prune) {
- if (bsize >= BLOCK_32X32) return;
- aom_clear_system_state();
- const struct macroblock_plane *const p = &x->plane[0];
- float hfeatures[16], vfeatures[16];
- float hscores[4], vscores[4];
- float scores_2D[16];
- const int tx_type_table_2D[16] = {
+static int prune_tx_2D(MACROBLOCK *x, BLOCK_SIZE bsize, TX_SIZE tx_size,
+ int blk_row, int blk_col, TxSetType tx_set_type,
+ TX_TYPE_PRUNE_MODE prune_mode) {
+ static const int tx_type_table_2D[16] = {
DCT_DCT, DCT_ADST, DCT_FLIPADST, V_DCT,
ADST_DCT, ADST_ADST, ADST_FLIPADST, V_ADST,
FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
H_DCT, H_ADST, H_FLIPADST, IDTX
};
- const int bw = block_size_wide[bsize], bh = block_size_high[bsize];
+ static const int model_idx_map[TX_SIZES_ALL] = {
+ 0, 3, 6, -1, -1, 1, 2, 4, 5, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1,
+ };
+ if (tx_set_type != EXT_TX_SET_ALL16 &&
+ tx_set_type != EXT_TX_SET_DTT9_IDTX_1DDCT)
+ return 0;
+ const int model_idx = model_idx_map[tx_size];
+ if (model_idx < 0) return 0; // Model not established yet.
+
+ aom_clear_system_state();
+ float hfeatures[16], vfeatures[16];
+ float hscores[4], vscores[4];
+ float scores_2D[16];
+ const int bw = tx_size_wide[tx_size];
+ const int bh = tx_size_high[tx_size];
const int hfeatures_num = bw <= 8 ? bw : bw / 2;
const int vfeatures_num = bh <= 8 ? bh : bh / 2;
assert(hfeatures_num <= 16);
assert(vfeatures_num <= 16);
- get_energy_distribution_finer(p->src_diff, bw, bw, bh, hfeatures, vfeatures);
- get_horver_correlation_full(p->src_diff, bw, bw, bh,
+ const struct macroblock_plane *const p = &x->plane[0];
+ const int diff_stride = block_size_wide[bsize];
+ const int16_t *diff = p->src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
+ get_energy_distribution_finer(diff, diff_stride, bw, bh, hfeatures,
+ vfeatures);
+ get_horver_correlation_full(diff, diff_stride, bw, bh,
&hfeatures[hfeatures_num - 1],
&vfeatures[vfeatures_num - 1]);
- const int bidx = AOMMAX(bsize - BLOCK_4X4, 0);
- const float *fc1_hor = av1_prune_2D_learned_weights_hor[bidx];
+ const float *fc1_hor = av1_prune_2D_learned_weights_hor[model_idx];
const float *b1_hor =
- fc1_hor + av1_prune_2D_num_hidden_units_hor[bidx] * hfeatures_num;
- const float *fc2_hor = b1_hor + av1_prune_2D_num_hidden_units_hor[bidx];
- const float *b2_hor = fc2_hor + av1_prune_2D_num_hidden_units_hor[bidx] * 4;
+ fc1_hor + av1_prune_2D_num_hidden_units_hor[model_idx] * hfeatures_num;
+ const float *fc2_hor = b1_hor + av1_prune_2D_num_hidden_units_hor[model_idx];
+ const float *b2_hor =
+ fc2_hor + av1_prune_2D_num_hidden_units_hor[model_idx] * 4;
compute_1D_scores(hfeatures, hfeatures_num, fc1_hor, b1_hor, fc2_hor, b2_hor,
- av1_prune_2D_num_hidden_units_hor[bidx], hscores);
+ av1_prune_2D_num_hidden_units_hor[model_idx], hscores);
- const float *fc1_ver = av1_prune_2D_learned_weights_ver[bidx];
+ const float *fc1_ver = av1_prune_2D_learned_weights_ver[model_idx];
const float *b1_ver =
- fc1_ver + av1_prune_2D_num_hidden_units_ver[bidx] * vfeatures_num;
- const float *fc2_ver = b1_ver + av1_prune_2D_num_hidden_units_ver[bidx];
- const float *b2_ver = fc2_ver + av1_prune_2D_num_hidden_units_ver[bidx] * 4;
+ fc1_ver + av1_prune_2D_num_hidden_units_ver[model_idx] * vfeatures_num;
+ const float *fc2_ver = b1_ver + av1_prune_2D_num_hidden_units_ver[model_idx];
+ const float *b2_ver =
+ fc2_ver + av1_prune_2D_num_hidden_units_ver[model_idx] * 4;
compute_1D_scores(vfeatures, vfeatures_num, fc1_ver, b1_ver, fc2_ver, b2_ver,
- av1_prune_2D_num_hidden_units_ver[bidx], vscores);
+ av1_prune_2D_num_hidden_units_ver[model_idx], vscores);
float score_2D_average = 0.0f;
for (int i = 0; i < 4; i++) {
@@ -1292,62 +1307,46 @@
score_2D_average /= 16;
score_2D_transform_pow8(scores_2D, (20 - score_2D_average));
- // TODO(huisu@google.com): support more tx set types.
- const int tx_set_types[2] = { EXT_TX_SET_ALL16, EXT_TX_SET_DTT9_IDTX_1DDCT };
- for (int tx_set_idx = 0; tx_set_idx < 2; ++tx_set_idx) {
- const int tx_set_type = tx_set_types[tx_set_idx];
- // Always keep the TX type with the highest score, prune all others with
- // score below score_thresh.
- int max_score_i = 0;
- float max_score = 0.0f;
- for (int i = 0; i < 16; i++) {
- if (scores_2D[i] > max_score &&
- av1_ext_tx_used[tx_set_type][tx_type_table_2D[i]]) {
- max_score = scores_2D[i];
- max_score_i = i;
- }
+ // Always keep the TX type with the highest score, prune all others with
+ // score below score_thresh.
+ int max_score_i = 0;
+ float max_score = 0.0f;
+ for (int i = 0; i < 16; i++) {
+ if (scores_2D[i] > max_score &&
+ av1_ext_tx_used[tx_set_type][tx_type_table_2D[i]]) {
+ max_score = scores_2D[i];
+ max_score_i = i;
}
-
- int pruning_aggressiveness = 0;
- if (prune_mode == PRUNE_2D_ACCURATE) {
- if (tx_set_type == EXT_TX_SET_ALL16)
- pruning_aggressiveness = 6;
- else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
- pruning_aggressiveness = 4;
- } else if (prune_mode == PRUNE_2D_FAST) {
- if (tx_set_type == EXT_TX_SET_ALL16)
- pruning_aggressiveness = 10;
- else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
- pruning_aggressiveness = 7;
- }
- const float score_thresh =
- av1_prune_2D_adaptive_thresholds[bidx][pruning_aggressiveness - 1];
-
- int prune_bitmask = 0;
- for (int i = 0; i < 16; i++) {
- if (scores_2D[i] < score_thresh && i != max_score_i)
- prune_bitmask |= (1 << tx_type_table_2D[i]);
- }
- x->tx_search_prune[tx_set_type] = prune_bitmask;
}
- // Also apply TX size pruning if it's turned on. The value
- // of prune_tx_split_flag indicates whether we should do
- // full TX size search (flag=0) or use the largest available
- // TX size without performing any further search (flag=1).
- int prune_tx_split_flag = 0;
- if (use_tx_split_prune) {
- prune_tx_split_flag =
- prune_tx_split(bsize, p->src_diff, hfeatures[hfeatures_num - 1],
- vfeatures[vfeatures_num - 1]);
+ int pruning_aggressiveness = 0;
+ if (prune_mode == PRUNE_2D_ACCURATE) {
+ if (tx_set_type == EXT_TX_SET_ALL16)
+ pruning_aggressiveness = 6;
+ else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
+ pruning_aggressiveness = 4;
+ } else if (prune_mode == PRUNE_2D_FAST) {
+ if (tx_set_type == EXT_TX_SET_ALL16)
+ pruning_aggressiveness = 10;
+ else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
+ pruning_aggressiveness = 7;
}
- x->tx_search_prune[0] |= (prune_tx_split_flag << TX_TYPES);
+ const float score_thresh =
+ av1_prune_2D_adaptive_thresholds[model_idx][pruning_aggressiveness - 1];
+
+ int prune_bitmask = 0;
+ for (int i = 0; i < 16; i++) {
+ if (scores_2D[i] < score_thresh && i != max_score_i)
+ prune_bitmask |= (1 << tx_type_table_2D[i]);
+ }
+ return prune_bitmask;
}
static void prune_tx(const AV1_COMP *cpi, BLOCK_SIZE bsize, MACROBLOCK *x,
const MACROBLOCKD *const xd, int tx_set_type,
int use_tx_split_prune) {
av1_zero(x->tx_search_prune);
+ x->tx_split_prune_flag = 0;
const MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
if (!is_inter_block(mbmi) || cpi->sf.tx_type_search.prune_mode == NO_PRUNE ||
x->use_default_inter_tx_type || xd->lossless[mbmi->segment_id] ||
@@ -1377,8 +1376,7 @@
break;
case PRUNE_2D_ACCURATE:
case PRUNE_2D_FAST:
- prune_tx_2D(bsize, x, cpi->sf.tx_type_search.prune_mode,
- use_tx_split_prune);
+ if (use_tx_split_prune) prune_tx_split(bsize, x);
break;
default: assert(0);
}
@@ -1905,6 +1903,8 @@
}
int rate_cost = 0;
+ const int txk_type_idx =
+ av1_get_txk_type_index(plane_bsize, blk_row, blk_col);
TX_TYPE txk_start = DCT_DCT;
TX_TYPE txk_end = TX_TYPES - 1;
if (!(!is_inter && x->use_default_intra_tx_type) &&
@@ -1917,31 +1917,40 @@
const TxSetType tx_set_type = get_ext_tx_set_type(
tx_size, plane_bsize, is_inter, cm->reduced_tx_set_used);
int prune = 0;
- if (is_inter && plane == 0 && !fast_tx_search &&
- cpi->sf.tx_type_search.prune_mode > NO_PRUNE)
- prune = x->tx_search_prune[tx_set_type];
- for (TX_TYPE tx_type = txk_start; tx_type <= txk_end; ++tx_type) {
- if (txk_end != DCT_DCT) {
- if (is_inter && plane == 0 &&
- cpi->sf.tx_type_search.prune_mode > NO_PRUNE) {
- if (!do_tx_type_search(tx_type, prune,
- cpi->sf.tx_type_search.prune_mode))
- continue;
+ const int do_prune = plane == 0 && !fast_tx_search && txk_end != DCT_DCT &&
+ !(!is_inter && x->use_default_intra_tx_type) &&
+ !(is_inter && x->use_default_inter_tx_type) &&
+ cpi->sf.tx_type_search.prune_mode > NO_PRUNE;
+ if (do_prune) {
+ if (cpi->sf.tx_type_search.prune_mode >= PRUNE_2D_ACCURATE) {
+ if (is_inter) {
+ prune = prune_tx_2D(x, plane_bsize, tx_size, blk_row, blk_col,
+ tx_set_type, cpi->sf.tx_type_search.prune_mode);
}
+ } else {
+ if (is_inter) prune = x->tx_search_prune[tx_set_type];
+ }
+ }
+
+ int allowed_tx_mask[TX_TYPES] = { 0 }; // 1: allow; 0: skip.
+ int allowed_tx_num = 0;
+ for (TX_TYPE tx_type = txk_start; tx_type <= txk_end; ++tx_type) {
+ allowed_tx_mask[tx_type] = 1;
+ if (do_prune) {
+ if (!do_tx_type_search(tx_type, prune, cpi->sf.tx_type_search.prune_mode))
+ allowed_tx_mask[tx_type] = 0;
}
if (fast_tx_search && tx_type != DCT_DCT && tx_type != H_DCT &&
tx_type != V_DCT)
- continue;
+ allowed_tx_mask[tx_type] = 0;
if (plane == 0) {
if (!is_inter && x->use_default_intra_tx_type &&
tx_type != get_default_tx_type(0, xd, tx_size))
- continue;
+ allowed_tx_mask[tx_type] = 0;
if (is_inter && x->use_default_inter_tx_type &&
tx_type != get_default_tx_type(0, xd, tx_size))
- continue;
- const int txk_type_idx =
- av1_get_txk_type_index(plane_bsize, blk_row, blk_col);
- mbmi->txk_type[txk_type_idx] = tx_type;
+ allowed_tx_mask[tx_type] = 0;
+ if (plane == 0) mbmi->txk_type[txk_type_idx] = tx_type;
}
const TX_TYPE ref_tx_type =
av1_get_tx_type(get_plane_type(plane), xd, blk_row, blk_col, tx_size,
@@ -1949,9 +1958,16 @@
if (tx_type != ref_tx_type) {
// use av1_get_tx_type() to check if the tx_type is valid for the current
// mode if it's not, we skip it here.
- continue;
+ allowed_tx_mask[tx_type] = 0;
}
+ allowed_tx_num += allowed_tx_mask[tx_type];
+ }
+ // Need to have at least one transform type allowed.
+ if (allowed_tx_num == 0) allowed_tx_mask[DCT_DCT] = 1;
+ for (TX_TYPE tx_type = txk_start; tx_type <= txk_end; ++tx_type) {
+ if (!allowed_tx_mask[tx_type]) continue;
+ if (plane == 0) mbmi->txk_type[txk_type_idx] = tx_type;
const SCAN_ORDER *scan_order = get_scan(tx_size, tx_type);
RD_STATS this_rd_stats;
av1_invalid_rd_stats(&this_rd_stats);
@@ -1993,13 +2009,13 @@
if (cpi->sf.tx_type_search.skip_tx_search && !best_eob) break;
}
- if (best_eob == 0) best_tx_type = DCT_DCT;
+ assert(best_rd != INT64_MAX);
+ if (best_eob == 0) best_tx_type = DCT_DCT;
if (plane == 0) {
update_txk_array(mbmi->txk_type, plane_bsize, blk_row, blk_col, tx_size,
best_tx_type);
}
-
x->plane[plane].txb_entropy_ctx[block] = best_txb_ctx;
x->plane[plane].eobs[block] = best_eob;
@@ -2393,6 +2409,7 @@
mbmi->tx_size, cpi->sf.use_fast_coef_costing);
// Reset the pruning flags.
av1_zero(x->tx_search_prune);
+ x->tx_split_prune_flag = 0;
}
static void choose_smallest_tx_size(const AV1_COMP *const cpi, MACROBLOCK *x,
@@ -2498,6 +2515,7 @@
// Reset the pruning flags.
av1_zero(x->tx_search_prune);
+ x->tx_split_prune_flag = 0;
}
static void super_block_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
@@ -3579,7 +3597,7 @@
int tx_split_prune_flag = 0;
if (cpi->sf.tx_type_search.prune_mode >= PRUNE_2D_ACCURATE)
- tx_split_prune_flag = ((x->tx_search_prune[0] >> TX_TYPES) & 1);
+ tx_split_prune_flag = x->tx_split_prune_flag;
if (cpi->sf.txb_split_cap)
if (p->eobs[block] == 0) tx_split_prune_flag = 1;
@@ -4424,6 +4442,7 @@
// Reset the pruning flags.
av1_zero(x->tx_search_prune);
+ x->tx_split_prune_flag = 0;
// We should always find at least one candidate unless ref_best_rd is less
// than INT64_MAX (in which case, all the calls to select_tx_size_fix_type