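Add comments to av1_pick_tx_size_type_yrd()
Also factor the model-based pruning check out into model_based_tx_search_prune(), write the search results directly into rd_stats instead of copying from a temporary RD_STATS, and narrow the scope of match_index.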
BUG=aomedia:2617
Change-Id: Ib6ff03c3247d938396e1532065b34b93d8d4f8ee
diff --git a/av1/encoder/tx_search.c b/av1/encoder/tx_search.c
index 4a1e985..6967248 100644
--- a/av1/encoder/tx_search.c
+++ b/av1/encoder/tx_search.c
@@ -3189,41 +3189,50 @@
return rd;
}
-// Search for best transform size and type for luma inter blocks.
+// Return 1 to terminate the transform search early and 0 to continue. Prunes
+// when the model-estimated RD cost, scaled by factor / 8, exceeds ref_best_rd.
+static AOM_INLINE int model_based_tx_search_prune(const AV1_COMP *cpi,
+ MACROBLOCK *x,
+ BLOCK_SIZE bsize,
+ int64_t ref_best_rd) {
+ const int level = cpi->sf.tx_sf.model_based_prune_tx_search_level;
+ assert(level >= 0 && level <= 2);  // NOTE(review): level 0 passes but would index [-1] below.
+ int model_rate;
+ int64_t model_dist;
+ int model_skip;
+ MACROBLOCKD *const xd = &x->e_mbd;
+ model_rd_sb_fn[MODELRD_TYPE_TX_SEARCH_PRUNE](
+ cpi, bsize, x, xd, 0, 0, &model_rate, &model_dist, &model_skip, NULL,
+ NULL, NULL, NULL);
+ if (model_skip) return 0;  // Do not prune when model_skip is set.
+ const int64_t model_rd = RDCOST(x->rdmult, model_rate, model_dist);
+ // TODO(debargha, urvang): Improve the model and make the check below
+ // tighter.
+ static const int prune_factor_by8[] = { 3, 5 };  // Indexed by level - 1.
+ const int factor = prune_factor_by8[level - 1];
+ return ((model_rd * factor) >> 3) > ref_best_rd;  // >> 3 divides by 8.
+}
+
+// Search for best transform size and type for luma inter blocks. The best
+// transform size and type, if found, will be saved in the MB_MODE_INFO
+// structure, and the corresponding RD stats will be saved in rd_stats.
void av1_pick_tx_size_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
RD_STATS *rd_stats, BLOCK_SIZE bsize,
int64_t ref_best_rd) {
- const AV1_COMMON *cm = &cpi->common;
MACROBLOCKD *const xd = &x->e_mbd;
assert(is_inter_block(xd->mi[0]));
av1_invalid_rd_stats(rd_stats);
+ // If modeled RD cost is a lot worse than the best so far, terminate early.
if (cpi->sf.tx_sf.model_based_prune_tx_search_level &&
ref_best_rd != INT64_MAX) {
- int model_rate;
- int64_t model_dist;
- int model_skip;
- model_rd_sb_fn[MODELRD_TYPE_TX_SEARCH_PRUNE](
- cpi, bsize, x, xd, 0, 0, &model_rate, &model_dist, &model_skip, NULL,
- NULL, NULL, NULL);
- const int64_t model_rd = RDCOST(x->rdmult, model_rate, model_dist);
- // If the modeled rd is a lot worse than the best so far, breakout.
- // TODO(debargha, urvang): Improve the model and make the check below
- // tighter.
- assert(cpi->sf.tx_sf.model_based_prune_tx_search_level >= 0 &&
- cpi->sf.tx_sf.model_based_prune_tx_search_level <= 2);
- static const int prune_factor_by8[] = { 3, 5 };
- if (!model_skip &&
- ((model_rd *
- prune_factor_by8[cpi->sf.tx_sf.model_based_prune_tx_search_level -
- 1]) >>
- 3) > ref_best_rd)
- return;
+ if (model_based_tx_search_prune(cpi, x, bsize, ref_best_rd)) return;
}
+ // Hashing based speed feature. If the hash of the prediction residue block is
+ // found in the hash table, use previous search results and terminate early.
uint32_t hash = 0;
- int32_t match_index = -1;
MB_RD_RECORD *mb_rd_record = NULL;
const int mi_row = x->e_mbd.mi_row;
const int mi_col = x->e_mbd.mi_col;
@@ -3238,7 +3247,7 @@
if (is_mb_rd_hash_enabled) {
hash = get_block_residue_hash(x, bsize);
mb_rd_record = &x->mb_rd_record;
- match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash);
+ const int match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash);
if (match_index != -1) {
MB_RD_INFO *tx_rd_info = &mb_rd_record->tx_rd_info[match_index];
fetch_tx_rd_info(n4, tx_rd_info, rd_stats, x);
@@ -3250,7 +3259,8 @@
// context and terminate early.
int64_t dist;
if (x->predict_skip_level &&
- predict_skip_flag(x, bsize, &dist, cm->features.reduced_tx_set_used)) {
+ predict_skip_flag(x, bsize, &dist,
+ cpi->common.features.reduced_tx_set_used)) {
set_skip_flag(x, rd_stats, bsize, dist);
// Save the RD search results into tx_rd_record.
if (is_mb_rd_hash_enabled)
@@ -3261,8 +3271,9 @@
++x->tx_search_count;
#endif // CONFIG_SPEED_STATS
- // Precompute residual hashes and find existing or add new RD records to
- // store and reuse rate and distortion values to speed up TX size search.
+ // Pre-compute residue hashes (transform block level) and find existing or
+ // add new RD records to store and reuse rate and distortion values to speed
+ // up TX size/type search.
TXB_RD_INFO_NODE matched_rd_info[4 + 16 + 64];
int found_rd_info = 0;
if (ref_best_rd != INT64_MAX && within_border &&
@@ -3270,24 +3281,19 @@
found_rd_info = find_tx_size_rd_records(x, bsize, matched_rd_info);
}
- int found = 0;
- RD_STATS this_rd_stats;
- av1_init_rd_stats(&this_rd_stats);
const int64_t rd =
- select_tx_size_and_type(cpi, x, &this_rd_stats, bsize, ref_best_rd,
+ select_tx_size_and_type(cpi, x, rd_stats, bsize, ref_best_rd,
found_rd_info ? matched_rd_info : NULL);
- if (rd < INT64_MAX) {
- *rd_stats = this_rd_stats;
- found = 1;
+ if (rd == INT64_MAX) {
+ // We should always find at least one candidate unless ref_best_rd is less
+ // than INT64_MAX (in which case, all the calls to select_tx_size_fix_type
+ // might have failed to find something better)
+ assert(ref_best_rd != INT64_MAX);
+ av1_invalid_rd_stats(rd_stats);
+ return;
}
- // We should always find at least one candidate unless ref_best_rd is less
- // than INT64_MAX (in which case, all the calls to select_tx_size_fix_type
- // might have failed to find something better)
- assert(IMPLIES(!found, ref_best_rd != INT64_MAX));
- if (!found) return;
-
// Save the RD search results into tx_rd_record.
if (is_mb_rd_hash_enabled) {
assert(mb_rd_record != NULL);