Separate transform_search functions from rdopt.c
Created tx_search.c and tx_search.h to improve
modularity of rdopt.c
tx_search.c : To keep transform related functions
tx_search.h : To keep transform related data
structures, defs and enum.
Change-Id: I9f7d8b0257c75fc245cf64ba242daf51c6c87d93
diff --git a/av1/av1.cmake b/av1/av1.cmake
index 12a3f1d..f599574 100644
--- a/av1/av1.cmake
+++ b/av1/av1.cmake
@@ -218,6 +218,8 @@
"${AOM_ROOT}/av1/encoder/tokenize.h"
"${AOM_ROOT}/av1/encoder/tpl_model.c"
"${AOM_ROOT}/av1/encoder/tpl_model.h"
+ "${AOM_ROOT}/av1/encoder/tx_search.c"
+ "${AOM_ROOT}/av1/encoder/tx_search.h"
"${AOM_ROOT}/av1/encoder/wedge_utils.c"
"${AOM_ROOT}/av1/encoder/var_based_part.c"
"${AOM_ROOT}/av1/encoder/var_based_part.h"
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 3b8f558..5e40d1b 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -65,42 +65,13 @@
#include "av1/encoder/reconinter_enc.h"
#include "av1/encoder/tokenize.h"
#include "av1/encoder/tpl_model.h"
-#include "av1/encoder/tx_prune_model_weights.h"
-
-// Set this macro as 1 to collect data about tx size selection.
-#define COLLECT_TX_SIZE_DATA 0
-
-#if COLLECT_TX_SIZE_DATA
-static const char av1_tx_size_data_output_file[] = "tx_size_data.txt";
-#endif
+#include "av1/encoder/tx_search.h"
typedef struct {
PREDICTION_MODE mode;
MV_REFERENCE_FRAME ref_frame[2];
} MODE_DEFINITION;
-enum {
- FTXS_NONE = 0,
- FTXS_DCT_AND_1D_DCT_ONLY = 1 << 0,
- FTXS_DISABLE_TRELLIS_OPT = 1 << 1,
- FTXS_USE_TRANSFORM_DOMAIN = 1 << 2
-} UENUM1BYTE(FAST_TX_SEARCH_MODE);
-
-struct rdcost_block_args {
- const AV1_COMP *cpi;
- MACROBLOCK *x;
- ENTROPY_CONTEXT t_above[MAX_MIB_SIZE];
- ENTROPY_CONTEXT t_left[MAX_MIB_SIZE];
- RD_STATS rd_stats;
- int64_t this_rd;
- int64_t best_rd;
- int exit_early;
- int incomplete_exit;
- int use_fast_coef_costing;
- FAST_TX_SEARCH_MODE ftxs_mode;
- int skip_trellis;
-};
-
// Structure to store the compound type related stats for best compound type
typedef struct {
INTERINTER_COMPOUND_DATA best_compound_data;
@@ -786,37 +757,6 @@
return is_cfl_allowed(xd);
}
-static int inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
- RD_STATS *rd_stats, BLOCK_SIZE bsize,
- int64_t ref_best_rd, FAST_TX_SEARCH_MODE ftxs_mode);
-
-static unsigned pixel_dist_visible_only(
- const AV1_COMP *const cpi, const MACROBLOCK *x, const uint8_t *src,
- const int src_stride, const uint8_t *dst, const int dst_stride,
- const BLOCK_SIZE tx_bsize, int txb_rows, int txb_cols, int visible_rows,
- int visible_cols) {
- unsigned sse;
-
- if (txb_rows == visible_rows && txb_cols == visible_cols) {
- cpi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse);
- return sse;
- }
-
-#if CONFIG_AV1_HIGHBITDEPTH
- const MACROBLOCKD *xd = &x->e_mbd;
- if (is_cur_buf_hbd(xd)) {
- uint64_t sse64 = aom_highbd_sse_odd_size(src, src_stride, dst, dst_stride,
- visible_cols, visible_rows);
- return (unsigned int)ROUND_POWER_OF_TWO(sse64, (xd->bd - 8) * 2);
- }
-#else
- (void)x;
-#endif
- sse = aom_sse_odd_size(src, src_stride, dst, dst_stride, visible_cols,
- visible_rows);
- return sse;
-}
-
#if CONFIG_DIST_8X8
static uint64_t cdef_dist_8x8_16bit(uint16_t *dst, int dstride, uint16_t *src,
int sstride, int coeff_shift) {
@@ -1214,70 +1154,6 @@
}
#endif // CONFIG_DIST_8X8
-static AOM_INLINE void get_energy_distribution_finer(const int16_t *diff,
- int stride, int bw, int bh,
- float *hordist,
- float *verdist) {
- // First compute downscaled block energy values (esq); downscale factors
- // are defined by w_shift and h_shift.
- unsigned int esq[256];
- const int w_shift = bw <= 8 ? 0 : 1;
- const int h_shift = bh <= 8 ? 0 : 1;
- const int esq_w = bw >> w_shift;
- const int esq_h = bh >> h_shift;
- const int esq_sz = esq_w * esq_h;
- int i, j;
- memset(esq, 0, esq_sz * sizeof(esq[0]));
- if (w_shift) {
- for (i = 0; i < bh; i++) {
- unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w;
- const int16_t *cur_diff_row = diff + i * stride;
- for (j = 0; j < bw; j += 2) {
- cur_esq_row[j >> 1] += (cur_diff_row[j] * cur_diff_row[j] +
- cur_diff_row[j + 1] * cur_diff_row[j + 1]);
- }
- }
- } else {
- for (i = 0; i < bh; i++) {
- unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w;
- const int16_t *cur_diff_row = diff + i * stride;
- for (j = 0; j < bw; j++) {
- cur_esq_row[j] += cur_diff_row[j] * cur_diff_row[j];
- }
- }
- }
-
- uint64_t total = 0;
- for (i = 0; i < esq_sz; i++) total += esq[i];
-
- // Output hordist and verdist arrays are normalized 1D projections of esq
- if (total == 0) {
- float hor_val = 1.0f / esq_w;
- for (j = 0; j < esq_w - 1; j++) hordist[j] = hor_val;
- float ver_val = 1.0f / esq_h;
- for (i = 0; i < esq_h - 1; i++) verdist[i] = ver_val;
- return;
- }
-
- const float e_recip = 1.0f / (float)total;
- memset(hordist, 0, (esq_w - 1) * sizeof(hordist[0]));
- memset(verdist, 0, (esq_h - 1) * sizeof(verdist[0]));
- const unsigned int *cur_esq_row;
- for (i = 0; i < esq_h - 1; i++) {
- cur_esq_row = esq + i * esq_w;
- for (j = 0; j < esq_w - 1; j++) {
- hordist[j] += (float)cur_esq_row[j];
- verdist[i] += (float)cur_esq_row[j];
- }
- verdist[i] += (float)cur_esq_row[j];
- }
- cur_esq_row = esq + i * esq_w;
- for (j = 0; j < esq_w - 1; j++) hordist[j] += (float)cur_esq_row[j];
-
- for (j = 0; j < esq_w - 1; j++) hordist[j] *= e_recip;
- for (i = 0; i < esq_h - 1; i++) verdist[i] *= e_recip;
-}
-
// Similar to get_horver_correlation, but also takes into account first
// row/column, when computing horizontal/vertical correlation.
void av1_get_horver_correlation_full_c(const int16_t *diff, int stride,
@@ -1377,234 +1253,6 @@
}
}
-// These thresholds were calibrated to provide a certain number of TX types
-// pruned by the model on average, i.e. selecting a threshold with index i
-// will lead to pruning i+1 TX types on average
-static const float *prune_2D_adaptive_thresholds[] = {
- // TX_4X4
- (float[]){ 0.00549f, 0.01306f, 0.02039f, 0.02747f, 0.03406f, 0.04065f,
- 0.04724f, 0.05383f, 0.06067f, 0.06799f, 0.07605f, 0.08533f,
- 0.09778f, 0.11780f },
- // TX_8X8
- (float[]){ 0.00037f, 0.00183f, 0.00525f, 0.01038f, 0.01697f, 0.02502f,
- 0.03381f, 0.04333f, 0.05286f, 0.06287f, 0.07434f, 0.08850f,
- 0.10803f, 0.14124f },
- // TX_16X16
- (float[]){ 0.01404f, 0.02000f, 0.04211f, 0.05164f, 0.05798f, 0.06335f,
- 0.06897f, 0.07629f, 0.08875f, 0.11169f },
- // TX_32X32
- NULL,
- // TX_64X64
- NULL,
- // TX_4X8
- (float[]){ 0.00183f, 0.00745f, 0.01428f, 0.02185f, 0.02966f, 0.03723f,
- 0.04456f, 0.05188f, 0.05920f, 0.06702f, 0.07605f, 0.08704f,
- 0.10168f, 0.12585f },
- // TX_8X4
- (float[]){ 0.00085f, 0.00476f, 0.01135f, 0.01892f, 0.02698f, 0.03528f,
- 0.04358f, 0.05164f, 0.05994f, 0.06848f, 0.07849f, 0.09021f,
- 0.10583f, 0.13123f },
- // TX_8X16
- (float[]){ 0.00037f, 0.00232f, 0.00671f, 0.01257f, 0.01965f, 0.02722f,
- 0.03552f, 0.04382f, 0.05237f, 0.06189f, 0.07336f, 0.08728f,
- 0.10730f, 0.14221f },
- // TX_16X8
- (float[]){ 0.00061f, 0.00330f, 0.00818f, 0.01453f, 0.02185f, 0.02966f,
- 0.03772f, 0.04578f, 0.05383f, 0.06262f, 0.07288f, 0.08582f,
- 0.10339f, 0.13464f },
- // TX_16X32
- NULL,
- // TX_32X16
- NULL,
- // TX_32X64
- NULL,
- // TX_64X32
- NULL,
- // TX_4X16
- (float[]){ 0.00232f, 0.00671f, 0.01257f, 0.01941f, 0.02673f, 0.03430f,
- 0.04211f, 0.04968f, 0.05750f, 0.06580f, 0.07507f, 0.08655f,
- 0.10242f, 0.12878f },
- // TX_16X4
- (float[]){ 0.00110f, 0.00525f, 0.01208f, 0.01990f, 0.02795f, 0.03601f,
- 0.04358f, 0.05115f, 0.05896f, 0.06702f, 0.07629f, 0.08752f,
- 0.10217f, 0.12610f },
- // TX_8X32
- NULL,
- // TX_32X8
- NULL,
- // TX_16X64
- NULL,
- // TX_64X16
- NULL,
-};
-
-// Probablities are sorted in descending order.
-static INLINE void sort_probability(float prob[], int txk[], int len) {
- int i, j, k;
-
- for (i = 1; i <= len - 1; ++i) {
- for (j = 0; j < i; ++j) {
- if (prob[j] < prob[i]) {
- float temp;
- int tempi;
-
- temp = prob[i];
- tempi = txk[i];
-
- for (k = i; k > j; k--) {
- prob[k] = prob[k - 1];
- txk[k] = txk[k - 1];
- }
-
- prob[j] = temp;
- txk[j] = tempi;
- break;
- }
- }
- }
-}
-
-static INLINE float get_adaptive_thresholds(TX_SIZE tx_size,
- TxSetType tx_set_type,
- TX_TYPE_PRUNE_MODE prune_mode) {
- const int prune_aggr_table[4][2] = { { 4, 1 }, { 6, 3 }, { 9, 6 }, { 9, 6 } };
- int pruning_aggressiveness = 0;
- if (tx_set_type == EXT_TX_SET_ALL16)
- pruning_aggressiveness =
- prune_aggr_table[prune_mode - PRUNE_2D_ACCURATE][0];
- else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
- pruning_aggressiveness =
- prune_aggr_table[prune_mode - PRUNE_2D_ACCURATE][1];
-
- return prune_2D_adaptive_thresholds[tx_size][pruning_aggressiveness];
-}
-
-static void prune_tx_2D(MACROBLOCK *x, BLOCK_SIZE bsize, TX_SIZE tx_size,
- int blk_row, int blk_col, TxSetType tx_set_type,
- TX_TYPE_PRUNE_MODE prune_mode, int *txk_map,
- uint16_t *allowed_tx_mask) {
- int tx_type_table_2D[16] = {
- DCT_DCT, DCT_ADST, DCT_FLIPADST, V_DCT,
- ADST_DCT, ADST_ADST, ADST_FLIPADST, V_ADST,
- FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
- H_DCT, H_ADST, H_FLIPADST, IDTX
- };
- if (tx_set_type != EXT_TX_SET_ALL16 &&
- tx_set_type != EXT_TX_SET_DTT9_IDTX_1DDCT)
- return;
-#if CONFIG_NN_V2
- NN_CONFIG_V2 *nn_config_hor = av1_tx_type_nnconfig_map_hor[tx_size];
- NN_CONFIG_V2 *nn_config_ver = av1_tx_type_nnconfig_map_ver[tx_size];
-#else
- const NN_CONFIG *nn_config_hor = av1_tx_type_nnconfig_map_hor[tx_size];
- const NN_CONFIG *nn_config_ver = av1_tx_type_nnconfig_map_ver[tx_size];
-#endif
- if (!nn_config_hor || !nn_config_ver) return; // Model not established yet.
-
- aom_clear_system_state();
- float hfeatures[16], vfeatures[16];
- float hscores[4], vscores[4];
- float scores_2D_raw[16];
- float scores_2D[16];
- const int bw = tx_size_wide[tx_size];
- const int bh = tx_size_high[tx_size];
- const int hfeatures_num = bw <= 8 ? bw : bw / 2;
- const int vfeatures_num = bh <= 8 ? bh : bh / 2;
- assert(hfeatures_num <= 16);
- assert(vfeatures_num <= 16);
-
- const struct macroblock_plane *const p = &x->plane[0];
- const int diff_stride = block_size_wide[bsize];
- const int16_t *diff = p->src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
- get_energy_distribution_finer(diff, diff_stride, bw, bh, hfeatures,
- vfeatures);
- av1_get_horver_correlation_full(diff, diff_stride, bw, bh,
- &hfeatures[hfeatures_num - 1],
- &vfeatures[vfeatures_num - 1]);
- aom_clear_system_state();
-#if CONFIG_NN_V2
- av1_nn_predict_v2(hfeatures, nn_config_hor, 0, hscores);
- av1_nn_predict_v2(vfeatures, nn_config_ver, 0, vscores);
-#else
- av1_nn_predict(hfeatures, nn_config_hor, 1, hscores);
- av1_nn_predict(vfeatures, nn_config_ver, 1, vscores);
-#endif
- aom_clear_system_state();
-
- for (int i = 0; i < 4; i++) {
- float *cur_scores_2D = scores_2D_raw + i * 4;
- cur_scores_2D[0] = vscores[i] * hscores[0];
- cur_scores_2D[1] = vscores[i] * hscores[1];
- cur_scores_2D[2] = vscores[i] * hscores[2];
- cur_scores_2D[3] = vscores[i] * hscores[3];
- }
-
- av1_nn_softmax(scores_2D_raw, scores_2D, 16);
-
- const float score_thresh =
- get_adaptive_thresholds(tx_size, tx_set_type, prune_mode);
-
- // Always keep the TX type with the highest score, prune all others with
- // score below score_thresh.
- int max_score_i = 0;
- float max_score = 0.0f;
- uint16_t allow_bitmask = 0;
- float sum_score = 0.0;
- // Calculate sum of allowed tx type score and Populate allow bit mask based
- // on score_thresh and allowed_tx_mask
- for (int tx_idx = 0; tx_idx < TX_TYPES; tx_idx++) {
- int allow_tx_type = *allowed_tx_mask & (1 << tx_type_table_2D[tx_idx]);
- if (scores_2D[tx_idx] > max_score && allow_tx_type) {
- max_score = scores_2D[tx_idx];
- max_score_i = tx_idx;
- }
- if (scores_2D[tx_idx] >= score_thresh && allow_tx_type) {
- // Set allow mask based on score_thresh
- allow_bitmask |= (1 << tx_type_table_2D[tx_idx]);
-
- // Accumulate score of allowed tx type
- sum_score += scores_2D[tx_idx];
- }
- }
- if (!((allow_bitmask >> max_score_i) & 0x01)) {
- // Set allow mask based on tx type with max score
- allow_bitmask |= (1 << tx_type_table_2D[max_score_i]);
- sum_score += scores_2D[max_score_i];
- }
- // Sort tx type probability of all types
- sort_probability(scores_2D, tx_type_table_2D, TX_TYPES);
-
- // Enable more pruning based on tx type probability and number of allowed tx
- // types
- if (prune_mode == PRUNE_2D_AGGRESSIVE) {
- float temp_score = 0.0;
- float score_ratio = 0.0;
- int tx_idx, tx_count = 0;
- const float inv_sum_score = 100 / sum_score;
- // Get allowed tx types based on sorted probability score and tx count
- for (tx_idx = 0; tx_idx < TX_TYPES; tx_idx++) {
- // Skip the tx type which has more than 30% of cumulative
- // probability and allowed tx type count is more than 2
- if (score_ratio > 30.0 && tx_count >= 2) break;
-
- // Calculate cumulative probability of allowed tx types
- if (allow_bitmask & (1 << tx_type_table_2D[tx_idx])) {
- // Calculate cumulative probability
- temp_score += scores_2D[tx_idx];
-
- // Calculate percentage of cumulative probability of allowed tx type
- score_ratio = temp_score * inv_sum_score;
- tx_count++;
- }
- }
- // Set remaining tx types as pruned
- for (; tx_idx < TX_TYPES; tx_idx++)
- allow_bitmask &= ~(1 << tx_type_table_2D[tx_idx]);
- }
- memcpy(txk_map, tx_type_table_2D, sizeof(tx_type_table_2D));
- *allowed_tx_mask = allow_bitmask;
-}
-
static int64_t get_sse(const AV1_COMP *cpi, const MACROBLOCK *x) {
const AV1_COMMON *cm = &cpi->common;
const int num_planes = av1_num_planes(cm);
@@ -1666,76 +1314,6 @@
}
#endif
-// Compute the pixel domain distortion from src and dst on all visible 4x4s in
-// the
-// transform block.
-static unsigned pixel_dist(const AV1_COMP *const cpi, const MACROBLOCK *x,
- int plane, const uint8_t *src, const int src_stride,
- const uint8_t *dst, const int dst_stride,
- int blk_row, int blk_col,
- const BLOCK_SIZE plane_bsize,
- const BLOCK_SIZE tx_bsize) {
- int txb_rows, txb_cols, visible_rows, visible_cols;
- const MACROBLOCKD *xd = &x->e_mbd;
-
- get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize,
- &txb_cols, &txb_rows, &visible_cols, &visible_rows);
- assert(visible_rows > 0);
- assert(visible_cols > 0);
-
-#if CONFIG_DIST_8X8
- if (x->using_dist_8x8 && plane == 0)
- return (unsigned)av1_dist_8x8(cpi, x, src, src_stride, dst, dst_stride,
- tx_bsize, txb_cols, txb_rows, visible_cols,
- visible_rows, x->qindex);
-#endif // CONFIG_DIST_8X8
-
- unsigned sse = pixel_dist_visible_only(cpi, x, src, src_stride, dst,
- dst_stride, tx_bsize, txb_rows,
- txb_cols, visible_rows, visible_cols);
-
- return sse;
-}
-
-// Compute the pixel domain distortion from diff on all visible 4x4s in the
-// transform block.
-static INLINE int64_t pixel_diff_dist(const MACROBLOCK *x, int plane,
- int blk_row, int blk_col,
- const BLOCK_SIZE plane_bsize,
- const BLOCK_SIZE tx_bsize,
- unsigned int *block_mse_q8) {
- int visible_rows, visible_cols;
- const MACROBLOCKD *xd = &x->e_mbd;
- get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize, NULL,
- NULL, &visible_cols, &visible_rows);
- const int diff_stride = block_size_wide[plane_bsize];
- const int16_t *diff = x->plane[plane].src_diff;
-#if CONFIG_DIST_8X8
- int txb_height = block_size_high[tx_bsize];
- int txb_width = block_size_wide[tx_bsize];
- if (x->using_dist_8x8 && plane == 0) {
- const int src_stride = x->plane[plane].src.stride;
- const int src_idx = (blk_row * src_stride + blk_col) << MI_SIZE_LOG2;
- const int diff_idx = (blk_row * diff_stride + blk_col) << MI_SIZE_LOG2;
- const uint8_t *src = &x->plane[plane].src.buf[src_idx];
- return dist_8x8_diff(x, src, src_stride, diff + diff_idx, diff_stride,
- txb_width, txb_height, visible_cols, visible_rows,
- x->qindex);
- }
-#endif
- diff += ((blk_row * diff_stride + blk_col) << MI_SIZE_LOG2);
- uint64_t sse =
- aom_sum_squares_2d_i16(diff, diff_stride, visible_cols, visible_rows);
- if (block_mse_q8 != NULL) {
- if (visible_cols > 0 && visible_rows > 0)
- *block_mse_q8 =
- (unsigned int)((256 * sse) / (visible_cols * visible_rows));
- else
- *block_mse_q8 = UINT_MAX;
- }
- return sse;
-}
-
int av1_count_colors(const uint8_t *src, int stride, int rows, int cols,
int *val_count) {
const int max_pix_val = 1 << 8;
@@ -1775,1497 +1353,6 @@
return n;
}
-static AOM_INLINE void inverse_transform_block_facade(MACROBLOCKD *xd,
- int plane, int block,
- int blk_row, int blk_col,
- int eob,
- int reduced_tx_set) {
- if (!eob) return;
-
- struct macroblockd_plane *const pd = &xd->plane[plane];
- tran_low_t *dqcoeff = pd->dqcoeff + BLOCK_OFFSET(block);
- const PLANE_TYPE plane_type = get_plane_type(plane);
- const TX_SIZE tx_size = av1_get_tx_size(plane, xd);
- const TX_TYPE tx_type = av1_get_tx_type(xd, plane_type, blk_row, blk_col,
- tx_size, reduced_tx_set);
- const int dst_stride = pd->dst.stride;
- uint8_t *dst = &pd->dst.buf[(blk_row * dst_stride + blk_col) << MI_SIZE_LOG2];
- av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, dst,
- dst_stride, eob, reduced_tx_set);
-}
-
-static int find_tx_size_rd_info(TXB_RD_RECORD *cur_record, const uint32_t hash);
-
-static uint32_t get_intra_txb_hash(MACROBLOCK *x, int plane, int blk_row,
- int blk_col, BLOCK_SIZE plane_bsize,
- TX_SIZE tx_size) {
- int16_t tmp_data[64 * 64];
- const int diff_stride = block_size_wide[plane_bsize];
- const int16_t *diff = x->plane[plane].src_diff;
- const int16_t *cur_diff_row = diff + 4 * blk_row * diff_stride + 4 * blk_col;
- const int txb_w = tx_size_wide[tx_size];
- const int txb_h = tx_size_high[tx_size];
- uint8_t *hash_data = (uint8_t *)cur_diff_row;
- if (txb_w != diff_stride) {
- int16_t *cur_hash_row = tmp_data;
- for (int i = 0; i < txb_h; i++) {
- memcpy(cur_hash_row, cur_diff_row, sizeof(*diff) * txb_w);
- cur_hash_row += txb_w;
- cur_diff_row += diff_stride;
- }
- hash_data = (uint8_t *)tmp_data;
- }
- CRC32C *crc = &x->mb_rd_record.crc_calculator;
- const uint32_t hash = av1_get_crc32c_value(crc, hash_data, 2 * txb_w * txb_h);
- return (hash << 5) + tx_size;
-}
-
-static INLINE void dist_block_tx_domain(MACROBLOCK *x, int plane, int block,
- TX_SIZE tx_size, int64_t *out_dist,
- int64_t *out_sse) {
- MACROBLOCKD *const xd = &x->e_mbd;
- const struct macroblock_plane *const p = &x->plane[plane];
- const struct macroblockd_plane *const pd = &xd->plane[plane];
- // Transform domain distortion computation is more efficient as it does
- // not involve an inverse transform, but it is less accurate.
- const int buffer_length = av1_get_max_eob(tx_size);
- int64_t this_sse;
- // TX-domain results need to shift down to Q2/D10 to match pixel
- // domain distortion values which are in Q2^2
- int shift = (MAX_TX_SCALE - av1_get_tx_scale(tx_size)) * 2;
- const int block_offset = BLOCK_OFFSET(block);
- tran_low_t *const coeff = p->coeff + block_offset;
- tran_low_t *const dqcoeff = pd->dqcoeff + block_offset;
-#if CONFIG_AV1_HIGHBITDEPTH
- if (is_cur_buf_hbd(xd))
- *out_dist = av1_highbd_block_error(coeff, dqcoeff, buffer_length, &this_sse,
- xd->bd);
- else
- *out_dist = av1_block_error(coeff, dqcoeff, buffer_length, &this_sse);
-#else
- *out_dist = av1_block_error(coeff, dqcoeff, buffer_length, &this_sse);
-#endif
- *out_dist = RIGHT_SIGNED_SHIFT(*out_dist, shift);
- *out_sse = RIGHT_SIGNED_SHIFT(this_sse, shift);
-}
-
-static INLINE int64_t dist_block_px_domain(const AV1_COMP *cpi, MACROBLOCK *x,
- int plane, BLOCK_SIZE plane_bsize,
- int block, int blk_row, int blk_col,
- TX_SIZE tx_size) {
- MACROBLOCKD *const xd = &x->e_mbd;
- const struct macroblock_plane *const p = &x->plane[plane];
- const struct macroblockd_plane *const pd = &xd->plane[plane];
- const uint16_t eob = p->eobs[block];
- const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
- const int bsw = block_size_wide[tx_bsize];
- const int bsh = block_size_high[tx_bsize];
- const int src_stride = x->plane[plane].src.stride;
- const int dst_stride = xd->plane[plane].dst.stride;
- // Scale the transform block index to pixel unit.
- const int src_idx = (blk_row * src_stride + blk_col) << MI_SIZE_LOG2;
- const int dst_idx = (blk_row * dst_stride + blk_col) << MI_SIZE_LOG2;
- const uint8_t *src = &x->plane[plane].src.buf[src_idx];
- const uint8_t *dst = &xd->plane[plane].dst.buf[dst_idx];
- const tran_low_t *dqcoeff = pd->dqcoeff + BLOCK_OFFSET(block);
-
- assert(cpi != NULL);
- assert(tx_size_wide_log2[0] == tx_size_high_log2[0]);
-
- uint8_t *recon;
- DECLARE_ALIGNED(16, uint16_t, recon16[MAX_TX_SQUARE]);
-
-#if CONFIG_AV1_HIGHBITDEPTH
- if (is_cur_buf_hbd(xd)) {
- recon = CONVERT_TO_BYTEPTR(recon16);
- av1_highbd_convolve_2d_copy_sr(CONVERT_TO_SHORTPTR(dst), dst_stride,
- CONVERT_TO_SHORTPTR(recon), MAX_TX_SIZE, bsw,
- bsh, NULL, NULL, 0, 0, NULL, xd->bd);
- } else {
- recon = (uint8_t *)recon16;
- av1_convolve_2d_copy_sr(dst, dst_stride, recon, MAX_TX_SIZE, bsw, bsh, NULL,
- NULL, 0, 0, NULL);
- }
-#else
- recon = (uint8_t *)recon16;
- av1_convolve_2d_copy_sr(dst, dst_stride, recon, MAX_TX_SIZE, bsw, bsh, NULL,
- NULL, 0, 0, NULL);
-#endif
-
- const PLANE_TYPE plane_type = get_plane_type(plane);
- TX_TYPE tx_type = av1_get_tx_type(xd, plane_type, blk_row, blk_col, tx_size,
- cpi->common.reduced_tx_set_used);
- av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, recon,
- MAX_TX_SIZE, eob,
- cpi->common.reduced_tx_set_used);
-
- return 16 * pixel_dist(cpi, x, plane, src, src_stride, recon, MAX_TX_SIZE,
- blk_row, blk_col, plane_bsize, tx_bsize);
-}
-
-// NOTE: CONFIG_COLLECT_RD_STATS has 3 possible values
-// 0: Do not collect any RD stats
-// 1: Collect RD stats for transform units
-// 2: Collect RD stats for partition units
-#if CONFIG_COLLECT_RD_STATS
-
-static AOM_INLINE void get_energy_distribution_fine(
- const AV1_COMP *cpi, BLOCK_SIZE bsize, const uint8_t *src, int src_stride,
- const uint8_t *dst, int dst_stride, int need_4th, double *hordist,
- double *verdist) {
- const int bw = block_size_wide[bsize];
- const int bh = block_size_high[bsize];
- unsigned int esq[16] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
-
- if (bsize < BLOCK_16X16 || (bsize >= BLOCK_4X16 && bsize <= BLOCK_32X8)) {
- // Special cases: calculate 'esq' values manually, as we don't have 'vf'
- // functions for the 16 (very small) sub-blocks of this block.
- const int w_shift = (bw == 4) ? 0 : (bw == 8) ? 1 : (bw == 16) ? 2 : 3;
- const int h_shift = (bh == 4) ? 0 : (bh == 8) ? 1 : (bh == 16) ? 2 : 3;
- assert(bw <= 32);
- assert(bh <= 32);
- assert(((bw - 1) >> w_shift) + (((bh - 1) >> h_shift) << 2) == 15);
- if (cpi->common.seq_params.use_highbitdepth) {
- const uint16_t *src16 = CONVERT_TO_SHORTPTR(src);
- const uint16_t *dst16 = CONVERT_TO_SHORTPTR(dst);
- for (int i = 0; i < bh; ++i)
- for (int j = 0; j < bw; ++j) {
- const int index = (j >> w_shift) + ((i >> h_shift) << 2);
- esq[index] +=
- (src16[j + i * src_stride] - dst16[j + i * dst_stride]) *
- (src16[j + i * src_stride] - dst16[j + i * dst_stride]);
- }
- } else {
- for (int i = 0; i < bh; ++i)
- for (int j = 0; j < bw; ++j) {
- const int index = (j >> w_shift) + ((i >> h_shift) << 2);
- esq[index] += (src[j + i * src_stride] - dst[j + i * dst_stride]) *
- (src[j + i * src_stride] - dst[j + i * dst_stride]);
- }
- }
- } else { // Calculate 'esq' values using 'vf' functions on the 16 sub-blocks.
- const int f_index =
- (bsize < BLOCK_SIZES) ? bsize - BLOCK_16X16 : bsize - BLOCK_8X16;
- assert(f_index >= 0 && f_index < BLOCK_SIZES_ALL);
- const BLOCK_SIZE subsize = (BLOCK_SIZE)f_index;
- assert(block_size_wide[bsize] == 4 * block_size_wide[subsize]);
- assert(block_size_high[bsize] == 4 * block_size_high[subsize]);
- cpi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[0]);
- cpi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride,
- &esq[1]);
- cpi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride,
- &esq[2]);
- cpi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
- dst_stride, &esq[3]);
- src += bh / 4 * src_stride;
- dst += bh / 4 * dst_stride;
-
- cpi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[4]);
- cpi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride,
- &esq[5]);
- cpi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride,
- &esq[6]);
- cpi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
- dst_stride, &esq[7]);
- src += bh / 4 * src_stride;
- dst += bh / 4 * dst_stride;
-
- cpi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[8]);
- cpi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride,
- &esq[9]);
- cpi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride,
- &esq[10]);
- cpi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
- dst_stride, &esq[11]);
- src += bh / 4 * src_stride;
- dst += bh / 4 * dst_stride;
-
- cpi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[12]);
- cpi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride,
- &esq[13]);
- cpi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride,
- &esq[14]);
- cpi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
- dst_stride, &esq[15]);
- }
-
- double total = (double)esq[0] + esq[1] + esq[2] + esq[3] + esq[4] + esq[5] +
- esq[6] + esq[7] + esq[8] + esq[9] + esq[10] + esq[11] +
- esq[12] + esq[13] + esq[14] + esq[15];
- if (total > 0) {
- const double e_recip = 1.0 / total;
- hordist[0] = ((double)esq[0] + esq[4] + esq[8] + esq[12]) * e_recip;
- hordist[1] = ((double)esq[1] + esq[5] + esq[9] + esq[13]) * e_recip;
- hordist[2] = ((double)esq[2] + esq[6] + esq[10] + esq[14]) * e_recip;
- if (need_4th) {
- hordist[3] = ((double)esq[3] + esq[7] + esq[11] + esq[15]) * e_recip;
- }
- verdist[0] = ((double)esq[0] + esq[1] + esq[2] + esq[3]) * e_recip;
- verdist[1] = ((double)esq[4] + esq[5] + esq[6] + esq[7]) * e_recip;
- verdist[2] = ((double)esq[8] + esq[9] + esq[10] + esq[11]) * e_recip;
- if (need_4th) {
- verdist[3] = ((double)esq[12] + esq[13] + esq[14] + esq[15]) * e_recip;
- }
- } else {
- hordist[0] = verdist[0] = 0.25;
- hordist[1] = verdist[1] = 0.25;
- hordist[2] = verdist[2] = 0.25;
- if (need_4th) {
- hordist[3] = verdist[3] = 0.25;
- }
- }
-}
-
-static double get_sse_norm(const int16_t *diff, int stride, int w, int h) {
- double sum = 0.0;
- for (int j = 0; j < h; ++j) {
- for (int i = 0; i < w; ++i) {
- const int err = diff[j * stride + i];
- sum += err * err;
- }
- }
- assert(w > 0 && h > 0);
- return sum / (w * h);
-}
-
-static double get_sad_norm(const int16_t *diff, int stride, int w, int h) {
- double sum = 0.0;
- for (int j = 0; j < h; ++j) {
- for (int i = 0; i < w; ++i) {
- sum += abs(diff[j * stride + i]);
- }
- }
- assert(w > 0 && h > 0);
- return sum / (w * h);
-}
-
-static AOM_INLINE void get_2x2_normalized_sses_and_sads(
- const AV1_COMP *const cpi, BLOCK_SIZE tx_bsize, const uint8_t *const src,
- int src_stride, const uint8_t *const dst, int dst_stride,
- const int16_t *const src_diff, int diff_stride, double *const sse_norm_arr,
- double *const sad_norm_arr) {
- const BLOCK_SIZE tx_bsize_half =
- get_partition_subsize(tx_bsize, PARTITION_SPLIT);
- if (tx_bsize_half == BLOCK_INVALID) { // manually calculate stats
- const int half_width = block_size_wide[tx_bsize] / 2;
- const int half_height = block_size_high[tx_bsize] / 2;
- for (int row = 0; row < 2; ++row) {
- for (int col = 0; col < 2; ++col) {
- const int16_t *const this_src_diff =
- src_diff + row * half_height * diff_stride + col * half_width;
- if (sse_norm_arr) {
- sse_norm_arr[row * 2 + col] =
- get_sse_norm(this_src_diff, diff_stride, half_width, half_height);
- }
- if (sad_norm_arr) {
- sad_norm_arr[row * 2 + col] =
- get_sad_norm(this_src_diff, diff_stride, half_width, half_height);
- }
- }
- }
- } else { // use function pointers to calculate stats
- const int half_width = block_size_wide[tx_bsize_half];
- const int half_height = block_size_high[tx_bsize_half];
- const int num_samples_half = half_width * half_height;
- for (int row = 0; row < 2; ++row) {
- for (int col = 0; col < 2; ++col) {
- const uint8_t *const this_src =
- src + row * half_height * src_stride + col * half_width;
- const uint8_t *const this_dst =
- dst + row * half_height * dst_stride + col * half_width;
-
- if (sse_norm_arr) {
- unsigned int this_sse;
- cpi->fn_ptr[tx_bsize_half].vf(this_src, src_stride, this_dst,
- dst_stride, &this_sse);
- sse_norm_arr[row * 2 + col] = (double)this_sse / num_samples_half;
- }
-
- if (sad_norm_arr) {
- const unsigned int this_sad = cpi->fn_ptr[tx_bsize_half].sdf(
- this_src, src_stride, this_dst, dst_stride);
- sad_norm_arr[row * 2 + col] = (double)this_sad / num_samples_half;
- }
- }
- }
- }
-}
-
-#if CONFIG_COLLECT_RD_STATS == 1
-static double get_mean(const int16_t *diff, int stride, int w, int h) {
- double sum = 0.0;
- for (int j = 0; j < h; ++j) {
- for (int i = 0; i < w; ++i) {
- sum += diff[j * stride + i];
- }
- }
- assert(w > 0 && h > 0);
- return sum / (w * h);
-}
-static AOM_INLINE void PrintTransformUnitStats(
- const AV1_COMP *const cpi, MACROBLOCK *x, const RD_STATS *const rd_stats,
- int blk_row, int blk_col, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
- TX_TYPE tx_type, int64_t rd) {
- if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return;
-
- // Generate small sample to restrict output size.
- static unsigned int seed = 21743;
- if (lcg_rand16(&seed) % 256 > 0) return;
-
- const char output_file[] = "tu_stats.txt";
- FILE *fout = fopen(output_file, "a");
- if (!fout) return;
-
- const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
- const MACROBLOCKD *const xd = &x->e_mbd;
- const int plane = 0;
- struct macroblock_plane *const p = &x->plane[plane];
- const struct macroblockd_plane *const pd = &xd->plane[plane];
- const int txw = tx_size_wide[tx_size];
- const int txh = tx_size_high[tx_size];
- const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
- const int q_step = p->dequant_QTX[1] >> dequant_shift;
- const int num_samples = txw * txh;
-
- const double rate_norm = (double)rd_stats->rate / num_samples;
- const double dist_norm = (double)rd_stats->dist / num_samples;
-
- fprintf(fout, "%g %g", rate_norm, dist_norm);
-
- const int src_stride = p->src.stride;
- const uint8_t *const src =
- &p->src.buf[(blk_row * src_stride + blk_col) << MI_SIZE_LOG2];
- const int dst_stride = pd->dst.stride;
- const uint8_t *const dst =
- &pd->dst.buf[(blk_row * dst_stride + blk_col) << MI_SIZE_LOG2];
- unsigned int sse;
- cpi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse);
- const double sse_norm = (double)sse / num_samples;
-
- const unsigned int sad =
- cpi->fn_ptr[tx_bsize].sdf(src, src_stride, dst, dst_stride);
- const double sad_norm = (double)sad / num_samples;
-
- fprintf(fout, " %g %g", sse_norm, sad_norm);
-
- const int diff_stride = block_size_wide[plane_bsize];
- const int16_t *const src_diff =
- &p->src_diff[(blk_row * diff_stride + blk_col) << MI_SIZE_LOG2];
-
- double sse_norm_arr[4], sad_norm_arr[4];
- get_2x2_normalized_sses_and_sads(cpi, tx_bsize, src, src_stride, dst,
- dst_stride, src_diff, diff_stride,
- sse_norm_arr, sad_norm_arr);
- for (int i = 0; i < 4; ++i) {
- fprintf(fout, " %g", sse_norm_arr[i]);
- }
- for (int i = 0; i < 4; ++i) {
- fprintf(fout, " %g", sad_norm_arr[i]);
- }
-
- const TX_TYPE_1D tx_type_1d_row = htx_tab[tx_type];
- const TX_TYPE_1D tx_type_1d_col = vtx_tab[tx_type];
-
- fprintf(fout, " %d %d %d %d %d", q_step, tx_size_wide[tx_size],
- tx_size_high[tx_size], tx_type_1d_row, tx_type_1d_col);
-
- int model_rate;
- int64_t model_dist;
- model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, tx_bsize, plane, sse, num_samples,
- &model_rate, &model_dist);
- const double model_rate_norm = (double)model_rate / num_samples;
- const double model_dist_norm = (double)model_dist / num_samples;
- fprintf(fout, " %g %g", model_rate_norm, model_dist_norm);
-
- const double mean = get_mean(src_diff, diff_stride, txw, txh);
- float hor_corr, vert_corr;
- av1_get_horver_correlation_full(src_diff, diff_stride, txw, txh, &hor_corr,
- &vert_corr);
- fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);
-
- double hdist[4] = { 0 }, vdist[4] = { 0 };
- get_energy_distribution_fine(cpi, tx_bsize, src, src_stride, dst, dst_stride,
- 1, hdist, vdist);
- fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
- hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);
-
- fprintf(fout, " %d %" PRId64, x->rdmult, rd);
-
- fprintf(fout, "\n");
- fclose(fout);
-}
-#endif // CONFIG_COLLECT_RD_STATS == 1
-
-#if CONFIG_COLLECT_RD_STATS >= 2
-static double get_highbd_diff_mean(const uint8_t *src8, int src_stride,
- const uint8_t *dst8, int dst_stride, int w,
- int h) {
- const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
- const uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
- double sum = 0.0;
- for (int j = 0; j < h; ++j) {
- for (int i = 0; i < w; ++i) {
- const int diff = src[j * src_stride + i] - dst[j * dst_stride + i];
- sum += diff;
- }
- }
- assert(w > 0 && h > 0);
- return sum / (w * h);
-}
-
-static double get_diff_mean(const uint8_t *src, int src_stride,
- const uint8_t *dst, int dst_stride, int w, int h) {
- double sum = 0.0;
- for (int j = 0; j < h; ++j) {
- for (int i = 0; i < w; ++i) {
- const int diff = src[j * src_stride + i] - dst[j * dst_stride + i];
- sum += diff;
- }
- }
- assert(w > 0 && h > 0);
- return sum / (w * h);
-}
-
-static AOM_INLINE void PrintPredictionUnitStats(const AV1_COMP *const cpi,
- const TileDataEnc *tile_data,
- MACROBLOCK *x,
- const RD_STATS *const rd_stats,
- BLOCK_SIZE plane_bsize) {
- if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return;
-
- if (cpi->sf.inter_sf.inter_mode_rd_model_estimation == 1 &&
- (tile_data == NULL ||
- !tile_data->inter_mode_rd_models[plane_bsize].ready))
- return;
- (void)tile_data;
- // Generate small sample to restrict output size.
- static unsigned int seed = 95014;
-
- if ((lcg_rand16(&seed) % (1 << (14 - num_pels_log2_lookup[plane_bsize]))) !=
- 1)
- return;
-
- const char output_file[] = "pu_stats.txt";
- FILE *fout = fopen(output_file, "a");
- if (!fout) return;
-
- MACROBLOCKD *const xd = &x->e_mbd;
- const int plane = 0;
- struct macroblock_plane *const p = &x->plane[plane];
- struct macroblockd_plane *pd = &xd->plane[plane];
- const int diff_stride = block_size_wide[plane_bsize];
- int bw, bh;
- get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL, &bw,
- &bh);
- const int num_samples = bw * bh;
- const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
- const int q_step = p->dequant_QTX[1] >> dequant_shift;
- const int shift = (xd->bd - 8);
-
- const double rate_norm = (double)rd_stats->rate / num_samples;
- const double dist_norm = (double)rd_stats->dist / num_samples;
- const double rdcost_norm =
- (double)RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) / num_samples;
-
- fprintf(fout, "%g %g %g", rate_norm, dist_norm, rdcost_norm);
-
- const int src_stride = p->src.stride;
- const uint8_t *const src = p->src.buf;
- const int dst_stride = pd->dst.stride;
- const uint8_t *const dst = pd->dst.buf;
- const int16_t *const src_diff = p->src_diff;
-
- int64_t sse = calculate_sse(xd, p, pd, bw, bh);
- const double sse_norm = (double)sse / num_samples;
-
- const unsigned int sad =
- cpi->fn_ptr[plane_bsize].sdf(src, src_stride, dst, dst_stride);
- const double sad_norm =
- (double)sad / (1 << num_pels_log2_lookup[plane_bsize]);
-
- fprintf(fout, " %g %g", sse_norm, sad_norm);
-
- double sse_norm_arr[4], sad_norm_arr[4];
- get_2x2_normalized_sses_and_sads(cpi, plane_bsize, src, src_stride, dst,
- dst_stride, src_diff, diff_stride,
- sse_norm_arr, sad_norm_arr);
- if (shift) {
- for (int k = 0; k < 4; ++k) sse_norm_arr[k] /= (1 << (2 * shift));
- for (int k = 0; k < 4; ++k) sad_norm_arr[k] /= (1 << shift);
- }
- for (int i = 0; i < 4; ++i) {
- fprintf(fout, " %g", sse_norm_arr[i]);
- }
- for (int i = 0; i < 4; ++i) {
- fprintf(fout, " %g", sad_norm_arr[i]);
- }
-
- fprintf(fout, " %d %d %d %d", q_step, x->rdmult, bw, bh);
-
- int model_rate;
- int64_t model_dist;
- model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, plane_bsize, plane, sse, num_samples,
- &model_rate, &model_dist);
- const double model_rdcost_norm =
- (double)RDCOST(x->rdmult, model_rate, model_dist) / num_samples;
- const double model_rate_norm = (double)model_rate / num_samples;
- const double model_dist_norm = (double)model_dist / num_samples;
- fprintf(fout, " %g %g %g", model_rate_norm, model_dist_norm,
- model_rdcost_norm);
-
- double mean;
- if (is_cur_buf_hbd(xd)) {
- mean = get_highbd_diff_mean(p->src.buf, p->src.stride, pd->dst.buf,
- pd->dst.stride, bw, bh);
- } else {
- mean = get_diff_mean(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
- bw, bh);
- }
- mean /= (1 << shift);
- float hor_corr, vert_corr;
- av1_get_horver_correlation_full(src_diff, diff_stride, bw, bh, &hor_corr,
- &vert_corr);
- fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);
-
- double hdist[4] = { 0 }, vdist[4] = { 0 };
- get_energy_distribution_fine(cpi, plane_bsize, src, src_stride, dst,
- dst_stride, 1, hdist, vdist);
- fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
- hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);
-
- if (cpi->sf.inter_sf.inter_mode_rd_model_estimation == 1) {
- assert(tile_data->inter_mode_rd_models[plane_bsize].ready);
- const int64_t overall_sse = get_sse(cpi, x);
- int est_residue_cost = 0;
- int64_t est_dist = 0;
- get_est_rate_dist(tile_data, plane_bsize, overall_sse, &est_residue_cost,
- &est_dist);
- const double est_residue_cost_norm = (double)est_residue_cost / num_samples;
- const double est_dist_norm = (double)est_dist / num_samples;
- const double est_rdcost_norm =
- (double)RDCOST(x->rdmult, est_residue_cost, est_dist) / num_samples;
- fprintf(fout, " %g %g %g", est_residue_cost_norm, est_dist_norm,
- est_rdcost_norm);
- }
-
- fprintf(fout, "\n");
- fclose(fout);
-}
-#endif // CONFIG_COLLECT_RD_STATS >= 2
-#endif // CONFIG_COLLECT_RD_STATS
-
-// pruning thresholds for prune_txk_type and prune_txk_type_separ
-static const int prune_factors[5] = { 200, 200, 120, 80, 40 }; // scale 1000
-static const int mul_factors[5] = { 80, 80, 70, 50, 30 }; // scale 100
-// R-D costs are sorted in ascending order.
-static INLINE void sort_rd(int64_t rds[], int txk[], int len) {
- int i, j, k;
-
- for (i = 1; i <= len - 1; ++i) {
- for (j = 0; j < i; ++j) {
- if (rds[j] > rds[i]) {
- int64_t temprd;
- int tempi;
-
- temprd = rds[i];
- tempi = txk[i];
-
- for (k = i; k > j; k--) {
- rds[k] = rds[k - 1];
- txk[k] = txk[k - 1];
- }
-
- rds[j] = temprd;
- txk[j] = tempi;
- break;
- }
- }
- }
-}
-
-uint16_t prune_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
- int block, TX_SIZE tx_size, int blk_row, int blk_col,
- BLOCK_SIZE plane_bsize, int *txk_map,
- uint16_t allowed_tx_mask, int prune_factor,
- const TXB_CTX *const txb_ctx, int reduced_tx_set_used) {
- const AV1_COMMON *cm = &cpi->common;
- int tx_type;
-
- int64_t rds[TX_TYPES];
-
- int num_cand = 0;
- int last = TX_TYPES - 1;
-
- TxfmParam txfm_param;
- QUANT_PARAM quant_param;
- av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
- av1_setup_quant(cm, tx_size, 1, AV1_XFORM_QUANT_B, &quant_param);
-
- for (int idx = 0; idx < TX_TYPES; idx++) {
- tx_type = idx;
- int rate_cost = 0;
- int64_t dist = 0, sse = 0;
- if (!(allowed_tx_mask & (1 << tx_type))) {
- txk_map[last] = tx_type;
- last--;
- continue;
- }
- txfm_param.tx_type = tx_type;
-
- // do txfm and quantization
- av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
- &quant_param);
- // estimate rate cost
- rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
- txb_ctx, reduced_tx_set_used, 0);
- // tx domain dist
- dist_block_tx_domain(x, plane, block, tx_size, &dist, &sse);
-
- txk_map[num_cand] = tx_type;
- rds[num_cand] = RDCOST(x->rdmult, rate_cost, dist);
- if (rds[num_cand] == 0) rds[num_cand] = 1;
- num_cand++;
- }
-
- if (num_cand == 0) return (uint16_t)0xFFFF;
-
- sort_rd(rds, txk_map, num_cand);
- uint16_t prune = (uint16_t)(~(1 << txk_map[0]));
-
- // 0 < prune_factor <= 1000 controls aggressiveness
- int64_t factor = 0;
- for (int idx = 1; idx < num_cand; idx++) {
- factor = 1000 * (rds[idx] - rds[0]) / rds[0];
- if (factor < (int64_t)prune_factor)
- prune &= ~(1 << txk_map[idx]);
- else
- break;
- }
- return prune;
-}
-
-uint16_t prune_txk_type_separ(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
- int block, TX_SIZE tx_size, int blk_row,
- int blk_col, BLOCK_SIZE plane_bsize, int *txk_map,
- int16_t allowed_tx_mask, int prune_factor,
- const TXB_CTX *const txb_ctx,
- int reduced_tx_set_used, int64_t ref_best_rd,
- int num_sel) {
- const AV1_COMMON *cm = &cpi->common;
-
- int idx;
-
- int64_t rds_v[4];
- int64_t rds_h[4];
- int idx_v[4] = { 0, 1, 2, 3 };
- int idx_h[4] = { 0, 1, 2, 3 };
- int skip_v[4] = { 0 };
- int skip_h[4] = { 0 };
- const int idx_map[16] = {
- DCT_DCT, DCT_ADST, DCT_FLIPADST, V_DCT,
- ADST_DCT, ADST_ADST, ADST_FLIPADST, V_ADST,
- FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
- H_DCT, H_ADST, H_FLIPADST, IDTX
- };
-
- const int sel_pattern_v[16] = {
- 0, 0, 1, 1, 0, 2, 1, 2, 2, 0, 3, 1, 3, 2, 3, 3
- };
- const int sel_pattern_h[16] = {
- 0, 1, 0, 1, 2, 0, 2, 1, 2, 3, 0, 3, 1, 3, 2, 3
- };
-
- QUANT_PARAM quant_param;
- TxfmParam txfm_param;
- av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
- av1_setup_quant(cm, tx_size, 1, AV1_XFORM_QUANT_B, &quant_param);
- int tx_type;
- // to ensure we can try ones even outside of ext_tx_set of current block
- // this function should only be called for size < 16
- assert(txsize_sqr_up_map[tx_size] <= TX_16X16);
- txfm_param.tx_set_type = EXT_TX_SET_ALL16;
-
- int rate_cost = 0;
- int64_t dist = 0, sse = 0;
- // evaluate horizontal with vertical DCT
- for (idx = 0; idx < 4; ++idx) {
- tx_type = idx_map[idx];
- txfm_param.tx_type = tx_type;
-
- av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
- &quant_param);
-
- dist_block_tx_domain(x, plane, block, tx_size, &dist, &sse);
-
- rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
- txb_ctx, reduced_tx_set_used, 0);
-
- rds_h[idx] = RDCOST(x->rdmult, rate_cost, dist);
-
- if ((rds_h[idx] - (rds_h[idx] >> 2)) > ref_best_rd) {
- skip_h[idx] = 1;
- }
- }
- sort_rd(rds_h, idx_h, 4);
- for (idx = 1; idx < 4; idx++) {
- if (rds_h[idx] > rds_h[0] * 1.2) skip_h[idx_h[idx]] = 1;
- }
-
- if (skip_h[idx_h[0]]) return (uint16_t)0xFFFF;
-
- // evaluate vertical with the best horizontal chosen
- rds_v[0] = rds_h[0];
- int start_v = 1, end_v = 4;
- const int *idx_map_v = idx_map + idx_h[0];
-
- for (idx = start_v; idx < end_v; ++idx) {
- tx_type = idx_map_v[idx_v[idx] * 4];
- txfm_param.tx_type = tx_type;
-
- av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
- &quant_param);
-
- dist_block_tx_domain(x, plane, block, tx_size, &dist, &sse);
-
- rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
- txb_ctx, reduced_tx_set_used, 0);
-
- rds_v[idx] = RDCOST(x->rdmult, rate_cost, dist);
-
- if ((rds_v[idx] - (rds_v[idx] >> 2)) > ref_best_rd) {
- skip_v[idx] = 1;
- }
- }
- sort_rd(rds_v, idx_v, 4);
- for (idx = 1; idx < 4; idx++) {
- if (rds_v[idx] > rds_v[0] * 1.2) skip_v[idx_v[idx]] = 1;
- }
-
- // combine rd_h and rd_v to prune tx candidates
- int i_v, i_h;
- int64_t rds[16];
- int num_cand = 0, last = TX_TYPES - 1;
-
- for (int i = 0; i < 16; i++) {
- i_v = sel_pattern_v[i];
- i_h = sel_pattern_h[i];
- tx_type = idx_map[idx_v[i_v] * 4 + idx_h[i_h]];
- if (!(allowed_tx_mask & (1 << tx_type)) || skip_h[idx_h[i_h]] ||
- skip_v[idx_v[i_v]]) {
- txk_map[last] = tx_type;
- last--;
- } else {
- txk_map[num_cand] = tx_type;
- rds[num_cand] = rds_v[i_v] + rds_h[i_h];
- if (rds[num_cand] == 0) rds[num_cand] = 1;
- num_cand++;
- }
- }
- sort_rd(rds, txk_map, num_cand);
-
- uint16_t prune = (uint16_t)(~(1 << txk_map[0]));
- num_sel = AOMMIN(num_sel, num_cand);
-
- for (int i = 1; i < num_sel; i++) {
- int64_t factor = 1800 * (rds[i] - rds[0]) / (rds[0]);
- if (factor < (int64_t)prune_factor)
- prune &= ~(1 << txk_map[i]);
- else
- break;
- }
- return prune;
-}
-
-static INLINE int is_intra_hash_match(
- const AV1_COMP *cpi, MACROBLOCK *x, int plane, int blk_row, int blk_col,
- BLOCK_SIZE plane_bsize, TX_SIZE tx_size, const TXB_CTX *const txb_ctx,
- TXB_RD_INFO **intra_txb_rd_info, int within_border,
- const int tx_type_map_idx, uint16_t *cur_joint_ctx) {
- const AV1_COMMON *cm = &cpi->common;
- MACROBLOCKD *xd = &x->e_mbd;
- MB_MODE_INFO *mbmi = xd->mi[0];
- const int is_inter = is_inter_block(mbmi);
- if (within_border && cpi->sf.tx_sf.use_intra_txb_hash &&
- frame_is_intra_only(cm) && !is_inter && plane == 0 &&
- tx_size_wide[tx_size] == tx_size_high[tx_size]) {
- const uint32_t intra_hash =
- get_intra_txb_hash(x, plane, blk_row, blk_col, plane_bsize, tx_size);
- const int intra_hash_idx =
- find_tx_size_rd_info(&x->txb_rd_record_intra, intra_hash);
- *intra_txb_rd_info = &x->txb_rd_record_intra.tx_rd_info[intra_hash_idx];
- *cur_joint_ctx = (txb_ctx->dc_sign_ctx << 8) + txb_ctx->txb_skip_ctx;
- if ((*intra_txb_rd_info)->entropy_context == *cur_joint_ctx &&
- x->txb_rd_record_intra.tx_rd_info[intra_hash_idx].valid) {
- xd->tx_type_map[tx_type_map_idx] = (*intra_txb_rd_info)->tx_type;
- const TX_TYPE ref_tx_type =
- av1_get_tx_type(xd, get_plane_type(plane), blk_row, blk_col, tx_size,
- cpi->common.reduced_tx_set_used);
- return (ref_tx_type == (*intra_txb_rd_info)->tx_type);
- }
- }
- return 0;
-}
-
-static INLINE void recon_intra(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
- int block, int blk_row, int blk_col,
- BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
- const TXB_CTX *const txb_ctx, int skip_trellis,
- TX_TYPE best_tx_type, TX_TYPE last_tx_type,
- int *rate_cost, uint16_t best_eob) {
- const AV1_COMMON *cm = &cpi->common;
- MACROBLOCKD *xd = &x->e_mbd;
- MB_MODE_INFO *mbmi = xd->mi[0];
- const int is_inter = is_inter_block(mbmi);
- if (!is_inter && best_eob &&
- (blk_row + tx_size_high_unit[tx_size] < mi_size_high[plane_bsize] ||
- blk_col + tx_size_wide_unit[tx_size] < mi_size_wide[plane_bsize])) {
- // intra mode needs decoded result such that the next transform block
- // can use it for prediction.
- // if the last search tx_type is the best tx_type, we don't need to
- // do this again
- if (best_tx_type != last_tx_type) {
- TxfmParam txfm_param_intra;
- QUANT_PARAM quant_param_intra;
- av1_setup_xform(cm, x, tx_size, best_tx_type, &txfm_param_intra);
- av1_setup_quant(cm, tx_size, !skip_trellis,
- skip_trellis
- ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B
- : AV1_XFORM_QUANT_FP)
- : AV1_XFORM_QUANT_FP,
- &quant_param_intra);
- av1_setup_qmatrix(cm, x, plane, tx_size, best_tx_type,
- &quant_param_intra);
- av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize,
- &txfm_param_intra, &quant_param_intra);
- if (quant_param_intra.use_optimize_b) {
- av1_optimize_b(cpi, x, plane, block, tx_size, best_tx_type, txb_ctx,
- cpi->sf.rd_sf.trellis_eob_fast, rate_cost);
- }
- }
-
- inverse_transform_block_facade(xd, plane, block, blk_row, blk_col,
- x->plane[plane].eobs[block],
- cm->reduced_tx_set_used);
-
- // This may happen because of hash collision. The eob stored in the hash
- // table is non-zero, but the real eob is zero. We need to make sure tx_type
- // is DCT_DCT in this case.
- if (plane == 0 && x->plane[plane].eobs[block] == 0 &&
- best_tx_type != DCT_DCT) {
- update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
- }
- }
-}
-
-static void search_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
- int block, int blk_row, int blk_col,
- BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
- const TXB_CTX *const txb_ctx,
- FAST_TX_SEARCH_MODE ftxs_mode,
- int use_fast_coef_costing, int skip_trellis,
- int64_t ref_best_rd, RD_STATS *best_rd_stats) {
- const AV1_COMMON *cm = &cpi->common;
- MACROBLOCKD *xd = &x->e_mbd;
- struct macroblockd_plane *const pd = &xd->plane[plane];
- MB_MODE_INFO *mbmi = xd->mi[0];
- const int is_inter = is_inter_block(mbmi);
- int64_t best_rd = INT64_MAX;
- uint16_t best_eob = 0;
- TX_TYPE best_tx_type = DCT_DCT;
- TX_TYPE last_tx_type = TX_TYPES;
- int rate_cost = 0;
- const int fast_tx_search = ftxs_mode & FTXS_DCT_AND_1D_DCT_ONLY;
- // The buffer used to swap dqcoeff in macroblockd_plane so we can keep dqcoeff
- // of the best tx_type
- DECLARE_ALIGNED(32, tran_low_t, this_dqcoeff[MAX_SB_SQUARE]);
- tran_low_t *orig_dqcoeff = pd->dqcoeff;
- tran_low_t *best_dqcoeff = this_dqcoeff;
- const int tx_type_map_idx =
- plane ? 0 : blk_row * xd->tx_type_map_stride + blk_col;
- int perform_block_coeff_opt = 0;
- av1_invalid_rd_stats(best_rd_stats);
-
- TXB_RD_INFO *intra_txb_rd_info = NULL;
- uint16_t cur_joint_ctx = 0;
- const int mi_row = xd->mi_row;
- const int mi_col = xd->mi_col;
- const int within_border =
- mi_row >= xd->tile.mi_row_start &&
- (mi_row + mi_size_high[plane_bsize] < xd->tile.mi_row_end) &&
- mi_col >= xd->tile.mi_col_start &&
- (mi_col + mi_size_wide[plane_bsize] < xd->tile.mi_col_end);
- skip_trellis |=
- cpi->optimize_seg_arr[mbmi->segment_id] == NO_TRELLIS_OPT ||
- cpi->optimize_seg_arr[mbmi->segment_id] == FINAL_PASS_TRELLIS_OPT;
- if (is_intra_hash_match(cpi, x, plane, blk_row, blk_col, plane_bsize, tx_size,
- txb_ctx, &intra_txb_rd_info, within_border,
- tx_type_map_idx, &cur_joint_ctx)) {
- best_rd_stats->rate = intra_txb_rd_info->rate;
- best_rd_stats->dist = intra_txb_rd_info->dist;
- best_rd_stats->sse = intra_txb_rd_info->sse;
- best_rd_stats->skip = intra_txb_rd_info->eob == 0;
- x->plane[plane].eobs[block] = intra_txb_rd_info->eob;
- x->plane[plane].txb_entropy_ctx[block] = intra_txb_rd_info->txb_entropy_ctx;
- best_eob = intra_txb_rd_info->eob;
- best_tx_type = intra_txb_rd_info->tx_type;
- perform_block_coeff_opt = intra_txb_rd_info->perform_block_coeff_opt;
- skip_trellis |= !perform_block_coeff_opt;
- update_txk_array(xd, blk_row, blk_col, tx_size, best_tx_type);
- recon_intra(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
- txb_ctx, skip_trellis, best_tx_type, last_tx_type, &rate_cost,
- best_eob);
- pd->dqcoeff = orig_dqcoeff;
- return;
- }
-
- // if txk_allowed = TX_TYPES, >1 tx types are allowed, else, if txk_allowed <
- // TX_TYPES, only that specific tx type is allowed.
- TX_TYPE txk_allowed = TX_TYPES;
- int txk_map[TX_TYPES] = {
- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
- };
-
- if ((!is_inter && x->use_default_intra_tx_type) ||
- (is_inter && x->use_default_inter_tx_type)) {
- txk_allowed =
- get_default_tx_type(0, xd, tx_size, cpi->is_screen_content_type);
- } else if (x->rd_model == LOW_TXFM_RD) {
- if (plane == 0) txk_allowed = DCT_DCT;
- }
-
- uint8_t best_txb_ctx = 0;
- const TxSetType tx_set_type =
- av1_get_ext_tx_set_type(tx_size, is_inter, cm->reduced_tx_set_used);
-
- TX_TYPE uv_tx_type = DCT_DCT;
- if (plane) {
- // tx_type of PLANE_TYPE_UV should be the same as PLANE_TYPE_Y
- uv_tx_type = txk_allowed =
- av1_get_tx_type(xd, get_plane_type(plane), blk_row, blk_col, tx_size,
- cm->reduced_tx_set_used);
- }
- PREDICTION_MODE intra_dir =
- mbmi->filter_intra_mode_info.use_filter_intra
- ? fimode_to_intradir[mbmi->filter_intra_mode_info.filter_intra_mode]
- : mbmi->mode;
- uint16_t ext_tx_used_flag =
- cpi->sf.tx_sf.tx_type_search.use_reduced_intra_txset &&
- tx_set_type == EXT_TX_SET_DTT4_IDTX_1DDCT
- ? av1_reduced_intra_tx_used_flag[intra_dir]
- : av1_ext_tx_used_flag[tx_set_type];
- if (xd->lossless[mbmi->segment_id] || txsize_sqr_up_map[tx_size] > TX_32X32 ||
- ext_tx_used_flag == 0x0001 ||
- (is_inter && cpi->oxcf.use_inter_dct_only) ||
- (!is_inter && cpi->oxcf.use_intra_dct_only)) {
- txk_allowed = DCT_DCT;
- }
-
- const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
- int64_t block_sse = 0;
- unsigned int block_mse_q8 = UINT_MAX;
- block_sse = pixel_diff_dist(x, plane, blk_row, blk_col, plane_bsize, tx_bsize,
- &block_mse_q8);
- assert(block_mse_q8 != UINT_MAX);
- if (is_cur_buf_hbd(xd)) {
- block_sse = ROUND_POWER_OF_TWO(block_sse, (xd->bd - 8) * 2);
- block_mse_q8 = ROUND_POWER_OF_TWO(block_mse_q8, (xd->bd - 8) * 2);
- }
- block_sse *= 16;
-
- // Used mse based threshold logic to take decision of R-D of optimization of
- // coeffs. For smaller residuals, coeff optimization would be helpful. For
- // larger residuals, R-D optimization may not be effective.
- // TODO(any): Experiment with variance and mean based thresholds
- perform_block_coeff_opt = (block_mse_q8 <= x->coeff_opt_dist_threshold);
- skip_trellis |= !perform_block_coeff_opt;
-
- if (cpi->oxcf.enable_flip_idtx == 0) {
- for (TX_TYPE tx_type = FLIPADST_DCT; tx_type <= H_FLIPADST; ++tx_type) {
- ext_tx_used_flag &= ~(1 << tx_type);
- }
- }
-
- uint16_t allowed_tx_mask = 0; // 1: allow; 0: skip.
- if (txk_allowed < TX_TYPES) {
- allowed_tx_mask = 1 << txk_allowed;
- allowed_tx_mask &= ext_tx_used_flag;
- } else if (fast_tx_search) {
- allowed_tx_mask = 0x0c01; // V_DCT, H_DCT, DCT_DCT
- allowed_tx_mask &= ext_tx_used_flag;
- } else {
- assert(plane == 0);
- allowed_tx_mask = ext_tx_used_flag;
- int num_allowed = 0;
- const FRAME_UPDATE_TYPE update_type = get_frame_update_type(&cpi->gf_group);
- const int *tx_type_probs = cpi->tx_type_probs[update_type][tx_size];
- int i;
-
- if (cpi->sf.tx_sf.tx_type_search.prune_tx_type_using_stats) {
- const int thresh = cpi->tx_type_probs_thresh[update_type];
- uint16_t prune = 0;
- int max_prob = -1;
- int max_idx = 0;
- for (i = 0; i < TX_TYPES; i++) {
- if (tx_type_probs[i] > max_prob && (allowed_tx_mask & (1 << i))) {
- max_prob = tx_type_probs[i];
- max_idx = i;
- }
- }
-
- for (i = 0; i < TX_TYPES; i++) {
- if (tx_type_probs[i] < thresh && i != max_idx) prune |= (1 << i);
- }
- allowed_tx_mask &= (~prune);
- }
- for (i = 0; i < TX_TYPES; i++) {
- if (allowed_tx_mask & (1 << i)) num_allowed++;
- }
- assert(num_allowed > 0);
-
- if (num_allowed > 2 && cpi->sf.tx_sf.tx_type_search.prune_tx_type_est_rd) {
- int pf = prune_factors[x->prune_mode];
- int mf = mul_factors[x->prune_mode];
- if (num_allowed <= 7) {
- const uint16_t prune = prune_txk_type(
- cpi, x, plane, block, tx_size, blk_row, blk_col, plane_bsize,
- txk_map, allowed_tx_mask, pf, txb_ctx, cm->reduced_tx_set_used);
- allowed_tx_mask &= (~prune);
- } else {
- const int num_sel = (num_allowed * mf + 50) / 100;
- const uint16_t prune = prune_txk_type_separ(
- cpi, x, plane, block, tx_size, blk_row, blk_col, plane_bsize,
- txk_map, allowed_tx_mask, pf, txb_ctx, cm->reduced_tx_set_used,
- ref_best_rd, num_sel);
-
- allowed_tx_mask &= (~prune);
- }
- } else {
- assert(num_allowed > 0);
- int allowed_tx_count = (x->prune_mode == PRUNE_2D_AGGRESSIVE) ? 1 : 5;
- // !fast_tx_search && txk_end != txk_start && plane == 0
- if (x->prune_mode >= PRUNE_2D_ACCURATE && is_inter &&
- num_allowed > allowed_tx_count) {
- prune_tx_2D(x, plane_bsize, tx_size, blk_row, blk_col, tx_set_type,
- x->prune_mode, txk_map, &allowed_tx_mask);
- }
- }
- }
-
- // Need to have at least one transform type allowed.
- if (allowed_tx_mask == 0) {
- txk_allowed = (plane ? uv_tx_type : DCT_DCT);
- allowed_tx_mask = (1 << txk_allowed);
- }
-
- assert(IMPLIES(txk_allowed < TX_TYPES, allowed_tx_mask == 1 << txk_allowed));
-
- // Tranform domain distortion is accurate for higher residuals.
- // TODO(any): Experiment with variance and mean based thresholds
- int use_transform_domain_distortion =
- (x->use_transform_domain_distortion > 0) &&
- (block_mse_q8 >= x->tx_domain_dist_threshold) &&
- // Any 64-pt transforms only preserves half the coefficients.
- // Therefore transform domain distortion is not valid for these
- // transform sizes.
- txsize_sqr_up_map[tx_size] != TX_64X64;
-#if CONFIG_DIST_8X8
- if (x->using_dist_8x8) use_transform_domain_distortion = 0;
-#endif
- int calc_pixel_domain_distortion_final =
- x->use_transform_domain_distortion == 1 &&
- use_transform_domain_distortion && x->rd_model != LOW_TXFM_RD;
- if (calc_pixel_domain_distortion_final &&
- (txk_allowed < TX_TYPES || allowed_tx_mask == 0x0001))
- calc_pixel_domain_distortion_final = use_transform_domain_distortion = 0;
-
- const uint16_t *eobs_ptr = x->plane[plane].eobs;
-
- TxfmParam txfm_param;
- QUANT_PARAM quant_param;
- av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
- av1_setup_quant(cm, tx_size, !skip_trellis,
- skip_trellis ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B
- : AV1_XFORM_QUANT_FP)
- : AV1_XFORM_QUANT_FP,
- &quant_param);
- int use_qm = !(xd->lossless[mbmi->segment_id] || cm->using_qmatrix == 0);
-
- for (int idx = 0; idx < TX_TYPES; ++idx) {
- const TX_TYPE tx_type = (TX_TYPE)txk_map[idx];
- if (!(allowed_tx_mask & (1 << tx_type))) continue;
- txfm_param.tx_type = tx_type;
- if (use_qm) {
- av1_setup_qmatrix(cm, x, plane, tx_size, tx_type, &quant_param);
- }
- if (plane == 0) xd->tx_type_map[tx_type_map_idx] = tx_type;
- RD_STATS this_rd_stats;
- av1_invalid_rd_stats(&this_rd_stats);
-
- av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
- &quant_param);
-
- if (quant_param.use_optimize_b) {
- if (cpi->sf.rd_sf.optimize_b_precheck && best_rd < INT64_MAX &&
- eobs_ptr[block] >= 4) {
- // Calculate distortion quickly in transform domain.
- dist_block_tx_domain(x, plane, block, tx_size, &this_rd_stats.dist,
- &this_rd_stats.sse);
-
- const int64_t best_rd_ = AOMMIN(best_rd, ref_best_rd);
- const int64_t dist_cost_estimate =
- RDCOST(x->rdmult, 0, AOMMIN(this_rd_stats.dist, this_rd_stats.sse));
- if (dist_cost_estimate - (dist_cost_estimate >> 3) > best_rd_) continue;
- }
- av1_optimize_b(cpi, x, plane, block, tx_size, tx_type, txb_ctx,
- cpi->sf.rd_sf.trellis_eob_fast, &rate_cost);
- } else {
- rate_cost =
- av1_cost_coeffs(x, plane, block, tx_size, tx_type, txb_ctx,
- use_fast_coef_costing, cm->reduced_tx_set_used);
- }
-
- // If rd cost based on coeff rate is more than best_rd, skip the calculation
- // of distortion
- int64_t tmp_rd = RDCOST(x->rdmult, rate_cost, 0);
- if (tmp_rd > best_rd) continue;
- if (eobs_ptr[block] == 0) {
- // When eob is 0, pixel domain distortion is more efficient and accurate.
- this_rd_stats.dist = this_rd_stats.sse = block_sse;
- } else if (use_transform_domain_distortion) {
- dist_block_tx_domain(x, plane, block, tx_size, &this_rd_stats.dist,
- &this_rd_stats.sse);
- } else {
- int64_t sse_diff = INT64_MAX;
- // high_energy threshold assumes that every pixel within a txfm block
- // has a residue energy of at least 25% of the maximum, i.e. 128 * 128
- // for 8 bit, then the threshold is scaled based on input bit depth.
- const int64_t high_energy_thresh =
- ((int64_t)128 * 128 * tx_size_2d[tx_size]) << ((xd->bd - 8) * 2);
- const int is_high_energy = (block_sse >= high_energy_thresh);
- if (tx_size == TX_64X64 || is_high_energy) {
- // Because 3 out 4 quadrants of transform coefficients are forced to
- // zero, the inverse transform has a tendency to overflow. sse_diff
- // is effectively the energy of those 3 quadrants, here we use it
- // to decide if we should do pixel domain distortion. If the energy
- // is mostly in first quadrant, then it is unlikely that we have
- // overflow issue in inverse transform.
- dist_block_tx_domain(x, plane, block, tx_size, &this_rd_stats.dist,
- &this_rd_stats.sse);
- sse_diff = block_sse - this_rd_stats.sse;
- }
- if (tx_size != TX_64X64 || !is_high_energy ||
- (sse_diff * 2) < this_rd_stats.sse) {
- const int64_t tx_domain_dist = this_rd_stats.dist;
- this_rd_stats.dist = dist_block_px_domain(
- cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
- // For high energy blocks, occasionally, the pixel domain distortion
- // can be artificially low due to clamping at reconstruction stage
- // even when inverse transform output is hugely different from the
- // actual residue.
- if (is_high_energy && this_rd_stats.dist < tx_domain_dist)
- this_rd_stats.dist = tx_domain_dist;
- } else {
- this_rd_stats.dist += sse_diff;
- }
- this_rd_stats.sse = block_sse;
- }
-
- this_rd_stats.rate = rate_cost;
-
- const int64_t rd =
- RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
-
- if (rd < best_rd) {
- best_rd = rd;
- *best_rd_stats = this_rd_stats;
- best_tx_type = tx_type;
- best_txb_ctx = x->plane[plane].txb_entropy_ctx[block];
- best_eob = x->plane[plane].eobs[block];
- last_tx_type = best_tx_type;
-
- // Swap qcoeff and dqcoeff buffers
- tran_low_t *const tmp_dqcoeff = best_dqcoeff;
- best_dqcoeff = pd->dqcoeff;
- pd->dqcoeff = tmp_dqcoeff;
- }
-
-#if CONFIG_COLLECT_RD_STATS == 1
- if (plane == 0) {
- PrintTransformUnitStats(cpi, x, &this_rd_stats, blk_row, blk_col,
- plane_bsize, tx_size, tx_type, rd);
- }
-#endif // CONFIG_COLLECT_RD_STATS == 1
-
-#if COLLECT_TX_SIZE_DATA
- // Generate small sample to restrict output size.
- static unsigned int seed = 21743;
- if (lcg_rand16(&seed) % 200 == 0) {
- FILE *fp = NULL;
-
- if (within_border) {
- fp = fopen(av1_tx_size_data_output_file, "a");
- }
-
- if (fp) {
- // Transform info and RD
- const int txb_w = tx_size_wide[tx_size];
- const int txb_h = tx_size_high[tx_size];
-
- // Residue signal.
- const int diff_stride = block_size_wide[plane_bsize];
- struct macroblock_plane *const p = &x->plane[plane];
- const int16_t *src_diff =
- &p->src_diff[(blk_row * diff_stride + blk_col) * 4];
-
- for (int r = 0; r < txb_h; ++r) {
- for (int c = 0; c < txb_w; ++c) {
- fprintf(fp, "%d,", src_diff[c]);
- }
- src_diff += diff_stride;
- }
-
- fprintf(fp, "%d,%d,%d,%" PRId64, txb_w, txb_h, tx_type, rd);
- fprintf(fp, "\n");
- fclose(fp);
- }
- }
-#endif // COLLECT_TX_SIZE_DATA
-
- if (cpi->sf.tx_sf.adaptive_txb_search_level) {
- if ((best_rd - (best_rd >> cpi->sf.tx_sf.adaptive_txb_search_level)) >
- ref_best_rd) {
- break;
- }
- }
-
- // Skip transform type search when we found the block has been quantized to
- // all zero and at the same time, it has better rdcost than doing transform.
- if (cpi->sf.tx_sf.tx_type_search.skip_tx_search && !best_eob) break;
- }
-
- assert(best_rd != INT64_MAX);
-
- best_rd_stats->skip = best_eob == 0;
- if (plane == 0) update_txk_array(xd, blk_row, blk_col, tx_size, best_tx_type);
- x->plane[plane].txb_entropy_ctx[block] = best_txb_ctx;
- x->plane[plane].eobs[block] = best_eob;
-
- pd->dqcoeff = best_dqcoeff;
-
- if (calc_pixel_domain_distortion_final && best_eob) {
- best_rd_stats->dist = dist_block_px_domain(
- cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
- best_rd_stats->sse = block_sse;
- }
-
- if (intra_txb_rd_info != NULL) {
- intra_txb_rd_info->valid = 1;
- intra_txb_rd_info->entropy_context = cur_joint_ctx;
- intra_txb_rd_info->rate = best_rd_stats->rate;
- intra_txb_rd_info->dist = best_rd_stats->dist;
- intra_txb_rd_info->sse = best_rd_stats->sse;
- intra_txb_rd_info->eob = best_eob;
- intra_txb_rd_info->txb_entropy_ctx = best_txb_ctx;
- intra_txb_rd_info->perform_block_coeff_opt = perform_block_coeff_opt;
- if (plane == 0) intra_txb_rd_info->tx_type = best_tx_type;
- }
-
- recon_intra(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
- txb_ctx, skip_trellis, best_tx_type, last_tx_type, &rate_cost,
- best_eob);
- pd->dqcoeff = orig_dqcoeff;
-}
-
-static AOM_INLINE void block_rd_txfm(int plane, int block, int blk_row,
- int blk_col, BLOCK_SIZE plane_bsize,
- TX_SIZE tx_size, void *arg) {
- struct rdcost_block_args *args = arg;
- MACROBLOCK *const x = args->x;
- MACROBLOCKD *const xd = &x->e_mbd;
- const int is_inter = is_inter_block(xd->mi[0]);
- const AV1_COMP *cpi = args->cpi;
- ENTROPY_CONTEXT *a = args->t_above + blk_col;
- ENTROPY_CONTEXT *l = args->t_left + blk_row;
- const AV1_COMMON *cm = &cpi->common;
- RD_STATS this_rd_stats;
-
- av1_init_rd_stats(&this_rd_stats);
-
- if (args->exit_early) {
- args->incomplete_exit = 1;
- return;
- }
-
- if (!is_inter) {
- av1_predict_intra_block_facade(cm, xd, plane, blk_col, blk_row, tx_size);
- av1_subtract_txb(x, plane, plane_bsize, blk_col, blk_row, tx_size);
- }
- TXB_CTX txb_ctx;
- get_txb_ctx(plane_bsize, tx_size, plane, a, l, &txb_ctx);
- search_txk_type(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
- &txb_ctx, args->ftxs_mode, args->use_fast_coef_costing,
- args->skip_trellis, args->best_rd - args->this_rd,
- &this_rd_stats);
-
- if (plane == AOM_PLANE_Y && xd->cfl.store_y) {
- assert(!is_inter || plane_bsize < BLOCK_8X8);
- cfl_store_tx(xd, blk_row, blk_col, tx_size, plane_bsize);
- }
-
-#if CONFIG_RD_DEBUG
- av1_update_txb_coeff_cost(&this_rd_stats, plane, tx_size, blk_row, blk_col,
- this_rd_stats.rate);
-#endif // CONFIG_RD_DEBUG
- av1_set_txb_context(x, plane, block, tx_size, a, l);
-
- const int blk_idx =
- blk_row * (block_size_wide[plane_bsize] >> MI_SIZE_LOG2) + blk_col;
-
- if (plane == 0)
- set_blk_skip(x, plane, blk_idx, x->plane[plane].eobs[block] == 0);
- else
- set_blk_skip(x, plane, blk_idx, 0);
-
- int64_t rd;
- if (is_inter) {
- const int64_t rd1 =
- RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
- const int64_t rd2 = RDCOST(x->rdmult, 0, this_rd_stats.sse);
-
- // TODO(jingning): temporarily enabled only for luma component
- rd = AOMMIN(rd1, rd2);
- this_rd_stats.skip &= !x->plane[plane].eobs[block];
- } else {
- // Signal non-skip for Intra blocks
- rd = RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
- this_rd_stats.skip = 0;
- }
-
- av1_merge_rd_stats(&args->rd_stats, &this_rd_stats);
-
- args->this_rd += rd;
-
- if (args->this_rd > args->best_rd) args->exit_early = 1;
-}
-
-static AOM_INLINE void txfm_rd_in_plane(MACROBLOCK *x, const AV1_COMP *cpi,
- RD_STATS *rd_stats, int64_t ref_best_rd,
- int64_t this_rd, int plane,
- BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
- int use_fast_coef_casting,
- FAST_TX_SEARCH_MODE ftxs_mode,
- int skip_trellis) {
- if (!cpi->oxcf.enable_tx64 && txsize_sqr_up_map[tx_size] == TX_64X64) {
- av1_invalid_rd_stats(rd_stats);
- return;
- }
-
- MACROBLOCKD *const xd = &x->e_mbd;
- const struct macroblockd_plane *const pd = &xd->plane[plane];
- struct rdcost_block_args args;
- av1_zero(args);
- args.x = x;
- args.cpi = cpi;
- args.best_rd = ref_best_rd;
- args.use_fast_coef_costing = use_fast_coef_casting;
- args.ftxs_mode = ftxs_mode;
- args.this_rd = this_rd;
- args.skip_trellis = skip_trellis;
- av1_init_rd_stats(&args.rd_stats);
-
- if (plane == 0) xd->mi[0]->tx_size = tx_size;
-
- av1_get_entropy_contexts(plane_bsize, pd, args.t_above, args.t_left);
-
- if (args.this_rd > args.best_rd) {
- args.exit_early = 1;
- }
-
- av1_foreach_transformed_block_in_plane(xd, plane_bsize, plane, block_rd_txfm,
- &args);
-
- MB_MODE_INFO *const mbmi = xd->mi[0];
- const int is_inter = is_inter_block(mbmi);
- const int invalid_rd = is_inter ? args.incomplete_exit : args.exit_early;
-
- if (invalid_rd) {
- av1_invalid_rd_stats(rd_stats);
- } else {
- *rd_stats = args.rd_stats;
- }
-}
-
-static int tx_size_cost(const MACROBLOCK *const x, BLOCK_SIZE bsize,
- TX_SIZE tx_size) {
- assert(bsize == x->e_mbd.mi[0]->sb_type);
- if (x->tx_mode_search_type != TX_MODE_SELECT || !block_signals_txsize(bsize))
- return 0;
-
- const int32_t tx_size_cat = bsize_to_tx_size_cat(bsize);
- const int depth = tx_size_to_depth(tx_size, bsize);
- const MACROBLOCKD *const xd = &x->e_mbd;
- const int tx_size_ctx = get_tx_size_context(xd);
- return x->tx_size_cost[tx_size_cat][tx_size_ctx][depth];
-}
-
-static int64_t txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
- RD_STATS *rd_stats, int64_t ref_best_rd, BLOCK_SIZE bs,
- TX_SIZE tx_size, FAST_TX_SEARCH_MODE ftxs_mode,
- int skip_trellis) {
- MACROBLOCKD *const xd = &x->e_mbd;
- MB_MODE_INFO *const mbmi = xd->mi[0];
- int64_t rd = INT64_MAX;
- const int skip_ctx = av1_get_skip_context(xd);
- int s0, s1;
- const int is_inter = is_inter_block(mbmi);
- const int tx_select = x->tx_mode_search_type == TX_MODE_SELECT &&
- block_signals_txsize(mbmi->sb_type);
- int ctx = txfm_partition_context(
- xd->above_txfm_context, xd->left_txfm_context, mbmi->sb_type, tx_size);
- const int r_tx_size =
- is_inter ? x->txfm_partition_cost[ctx][0] : tx_size_cost(x, bs, tx_size);
-
- assert(IMPLIES(is_rect_tx(tx_size), is_rect_tx_allowed_bsize(bs)));
-
- s0 = x->skip_cost[skip_ctx][0];
- s1 = x->skip_cost[skip_ctx][1];
-
- int64_t skip_rd = INT64_MAX;
- int64_t this_rd = RDCOST(x->rdmult, s0 + r_tx_size * tx_select, 0);
-
- if (is_inter) skip_rd = RDCOST(x->rdmult, s1, 0);
-
- mbmi->tx_size = tx_size;
- txfm_rd_in_plane(
- x, cpi, rd_stats, ref_best_rd, AOMMIN(this_rd, skip_rd), AOM_PLANE_Y, bs,
- tx_size, cpi->sf.rd_sf.use_fast_coef_costing, ftxs_mode, skip_trellis);
- if (rd_stats->rate == INT_MAX) return INT64_MAX;
-
- // rdstats->rate should include all the rate except skip/non-skip cost as the
- // same is accounted in the caller functions after rd evaluation of all
- // planes. However the decisions should be done after considering the
- // skip/non-skip header cost
- if (rd_stats->skip && is_inter) {
- rd = RDCOST(x->rdmult, s1, rd_stats->sse);
- } else {
- // Intra blocks are always signalled as non-skip
- rd = RDCOST(x->rdmult, rd_stats->rate + s0 + r_tx_size * tx_select,
- rd_stats->dist);
- rd_stats->rate += r_tx_size * tx_select;
- }
- if (is_inter && !xd->lossless[xd->mi[0]->segment_id]) {
- int64_t temp_skip_rd = RDCOST(x->rdmult, s1, rd_stats->sse);
- if (temp_skip_rd <= rd) {
- rd = temp_skip_rd;
- rd_stats->rate = 0;
- rd_stats->dist = rd_stats->sse;
- rd_stats->skip = 1;
- }
- }
-
- return rd;
-}
-
static int64_t estimate_yrd_for_sb(const AV1_COMP *const cpi, BLOCK_SIZE bs,
MACROBLOCK *x, int64_t ref_best_rd,
RD_STATS *rd_stats) {
@@ -3292,423 +1379,6 @@
return rd;
}
-static AOM_INLINE void choose_largest_tx_size(const AV1_COMP *const cpi,
- MACROBLOCK *x, RD_STATS *rd_stats,
- int64_t ref_best_rd,
- BLOCK_SIZE bs) {
- MACROBLOCKD *const xd = &x->e_mbd;
- MB_MODE_INFO *const mbmi = xd->mi[0];
- mbmi->tx_size = tx_size_from_tx_mode(bs, x->tx_mode_search_type);
-
- // If tx64 is not enabled, we need to go down to the next available size
- if (!cpi->oxcf.enable_tx64) {
- static const TX_SIZE tx_size_max_32[TX_SIZES_ALL] = {
- TX_4X4, // 4x4 transform
- TX_8X8, // 8x8 transform
- TX_16X16, // 16x16 transform
- TX_32X32, // 32x32 transform
- TX_32X32, // 64x64 transform
- TX_4X8, // 4x8 transform
- TX_8X4, // 8x4 transform
- TX_8X16, // 8x16 transform
- TX_16X8, // 16x8 transform
- TX_16X32, // 16x32 transform
- TX_32X16, // 32x16 transform
- TX_32X32, // 32x64 transform
- TX_32X32, // 64x32 transform
- TX_4X16, // 4x16 transform
- TX_16X4, // 16x4 transform
- TX_8X32, // 8x32 transform
- TX_32X8, // 32x8 transform
- TX_16X32, // 16x64 transform
- TX_32X16, // 64x16 transform
- };
-
- mbmi->tx_size = tx_size_max_32[mbmi->tx_size];
- }
-
- const int skip_ctx = av1_get_skip_context(xd);
- int s0, s1;
-
- s0 = x->skip_cost[skip_ctx][0];
- s1 = x->skip_cost[skip_ctx][1];
-
- int64_t skip_rd = INT64_MAX;
- int64_t this_rd = RDCOST(x->rdmult, s0, 0);
-
- // Skip RDcost is used only for Inter blocks
- if (is_inter_block(xd->mi[0])) skip_rd = RDCOST(x->rdmult, s1, 0);
-
- txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd, AOMMIN(this_rd, skip_rd),
- AOM_PLANE_Y, bs, mbmi->tx_size,
- cpi->sf.rd_sf.use_fast_coef_costing, FTXS_NONE, 0);
-}
-
-static AOM_INLINE void choose_smallest_tx_size(const AV1_COMP *const cpi,
- MACROBLOCK *x,
- RD_STATS *rd_stats,
- int64_t ref_best_rd,
- BLOCK_SIZE bs) {
- MACROBLOCKD *const xd = &x->e_mbd;
- MB_MODE_INFO *const mbmi = xd->mi[0];
-
- mbmi->tx_size = TX_4X4;
- // TODO(any) : Pass this_rd based on skip/non-skip cost
- txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd, 0, 0, bs, mbmi->tx_size,
- cpi->sf.rd_sf.use_fast_coef_costing, FTXS_NONE, 0);
-}
-
-static int get_search_init_depth(int mi_width, int mi_height, int is_inter,
- const SPEED_FEATURES *sf,
- int tx_size_search_method) {
- if (tx_size_search_method == USE_LARGESTALL) return MAX_VARTX_DEPTH;
-
- if (sf->tx_sf.tx_size_search_lgr_block) {
- if (mi_width > mi_size_wide[BLOCK_64X64] ||
- mi_height > mi_size_high[BLOCK_64X64])
- return MAX_VARTX_DEPTH;
- }
-
- if (is_inter) {
- return (mi_height != mi_width)
- ? sf->tx_sf.inter_tx_size_search_init_depth_rect
- : sf->tx_sf.inter_tx_size_search_init_depth_sqr;
- } else {
- return (mi_height != mi_width)
- ? sf->tx_sf.intra_tx_size_search_init_depth_rect
- : sf->tx_sf.intra_tx_size_search_init_depth_sqr;
- }
-}
-
-static AOM_INLINE void choose_tx_size_type_from_rd(const AV1_COMP *const cpi,
- MACROBLOCK *x,
- RD_STATS *rd_stats,
- int64_t ref_best_rd,
- BLOCK_SIZE bs) {
- av1_invalid_rd_stats(rd_stats);
-
- MACROBLOCKD *const xd = &x->e_mbd;
- MB_MODE_INFO *const mbmi = xd->mi[0];
- const TX_SIZE max_rect_tx_size = max_txsize_rect_lookup[bs];
- const int tx_select = x->tx_mode_search_type == TX_MODE_SELECT;
- int start_tx;
- int depth, init_depth;
-
- if (tx_select) {
- start_tx = max_rect_tx_size;
- init_depth = get_search_init_depth(mi_size_wide[bs], mi_size_high[bs],
- is_inter_block(mbmi), &cpi->sf,
- x->tx_size_search_method);
- } else {
- const TX_SIZE chosen_tx_size =
- tx_size_from_tx_mode(bs, x->tx_mode_search_type);
- start_tx = chosen_tx_size;
- init_depth = MAX_TX_DEPTH;
- }
-
- uint8_t best_txk_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
- uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
- TX_SIZE best_tx_size = max_rect_tx_size;
- int64_t best_rd = INT64_MAX;
- const int n4 = bsize_to_num_blk(bs);
- x->rd_model = FULL_TXFM_RD;
- depth = init_depth;
- int64_t rd[MAX_TX_DEPTH + 1] = { INT64_MAX, INT64_MAX, INT64_MAX };
- for (int n = start_tx; depth <= MAX_TX_DEPTH;
- depth++, n = sub_tx_size_map[n]) {
-#if CONFIG_DIST_8X8
- if (x->using_dist_8x8) {
- if (tx_size_wide[n] < 8 || tx_size_high[n] < 8) continue;
- }
-#endif
- if (!cpi->oxcf.enable_tx64 && txsize_sqr_up_map[n] == TX_64X64) continue;
-
- RD_STATS this_rd_stats;
- rd[depth] =
- txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs, n, FTXS_NONE, 0);
-
- if (rd[depth] < best_rd) {
- av1_copy_array(best_blk_skip, x->blk_skip, n4);
- av1_copy_array(best_txk_type_map, xd->tx_type_map, n4);
- best_tx_size = n;
- best_rd = rd[depth];
- *rd_stats = this_rd_stats;
- }
- if (n == TX_4X4) break;
- // If we are searching three depths, prune the smallest size depending
- // on rd results for the first two depths for low contrast blocks.
- if (depth > init_depth && depth != MAX_TX_DEPTH &&
- x->source_variance < 256) {
- if (rd[depth - 1] != INT64_MAX && rd[depth] > rd[depth - 1]) break;
- }
- }
-
- if (rd_stats->rate != INT_MAX) {
- mbmi->tx_size = best_tx_size;
- av1_copy_array(xd->tx_type_map, best_txk_type_map, n4);
- av1_copy_array(x->blk_skip, best_blk_skip, n4);
- }
-}
-
-// origin_threshold * 128 / 100
-static const uint32_t skip_pred_threshold[3][BLOCK_SIZES_ALL] = {
- {
- 64, 64, 64, 70, 60, 60, 68, 68, 68, 68, 68,
- 68, 68, 68, 68, 68, 64, 64, 70, 70, 68, 68,
- },
- {
- 88, 88, 88, 86, 87, 87, 68, 68, 68, 68, 68,
- 68, 68, 68, 68, 68, 88, 88, 86, 86, 68, 68,
- },
- {
- 90, 93, 93, 90, 93, 93, 74, 74, 74, 74, 74,
- 74, 74, 74, 74, 74, 90, 90, 90, 90, 74, 74,
- },
-};
-
-// lookup table for predict_skip_flag
-// int max_tx_size = max_txsize_rect_lookup[bsize];
-// if (tx_size_high[max_tx_size] > 16 || tx_size_wide[max_tx_size] > 16)
-// max_tx_size = AOMMIN(max_txsize_lookup[bsize], TX_16X16);
-static const TX_SIZE max_predict_sf_tx_size[BLOCK_SIZES_ALL] = {
- TX_4X4, TX_4X8, TX_8X4, TX_8X8, TX_8X16, TX_16X8,
- TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16,
- TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_4X16, TX_16X4,
- TX_8X8, TX_8X8, TX_16X16, TX_16X16,
-};
-
-// Uses simple features on top of DCT coefficients to quickly predict
-// whether optimal RD decision is to skip encoding the residual.
-// The sse value is stored in dist.
-static int predict_skip_flag(MACROBLOCK *x, BLOCK_SIZE bsize, int64_t *dist,
- int reduced_tx_set) {
- const int bw = block_size_wide[bsize];
- const int bh = block_size_high[bsize];
- const MACROBLOCKD *xd = &x->e_mbd;
- const int16_t dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd);
-
- *dist = pixel_diff_dist(x, 0, 0, 0, bsize, bsize, NULL);
-
- const int64_t mse = *dist / bw / bh;
- // Normalized quantizer takes the transform upscaling factor (8 for tx size
- // smaller than 32) into account.
- const int16_t normalized_dc_q = dc_q >> 3;
- const int64_t mse_thresh = (int64_t)normalized_dc_q * normalized_dc_q / 8;
- // For faster early skip decision, use dist to compare against threshold so
- // that quality risk is less for the skip=1 decision. Otherwise, use mse
- // since the fwd_txfm coeff checks will take care of quality
- // TODO(any): Use dist to return 0 when predict_skip_level is 1
- int64_t pred_err = (x->predict_skip_level >= 2) ? *dist : mse;
- // Predict not to skip when error is larger than threshold.
- if (pred_err > mse_thresh) return 0;
- // Return as skip otherwise for aggressive early skip
- else if (x->predict_skip_level >= 2)
- return 1;
-
- const int max_tx_size = max_predict_sf_tx_size[bsize];
- const int tx_h = tx_size_high[max_tx_size];
- const int tx_w = tx_size_wide[max_tx_size];
- DECLARE_ALIGNED(32, tran_low_t, coefs[32 * 32]);
- TxfmParam param;
- param.tx_type = DCT_DCT;
- param.tx_size = max_tx_size;
- param.bd = xd->bd;
- param.is_hbd = is_cur_buf_hbd(xd);
- param.lossless = 0;
- param.tx_set_type = av1_get_ext_tx_set_type(
- param.tx_size, is_inter_block(xd->mi[0]), reduced_tx_set);
- const int bd_idx = (xd->bd == 8) ? 0 : ((xd->bd == 10) ? 1 : 2);
- const uint32_t max_qcoef_thresh = skip_pred_threshold[bd_idx][bsize];
- const int16_t *src_diff = x->plane[0].src_diff;
- const int n_coeff = tx_w * tx_h;
- const int16_t ac_q = av1_ac_quant_QTX(x->qindex, 0, xd->bd);
- const uint32_t dc_thresh = max_qcoef_thresh * dc_q;
- const uint32_t ac_thresh = max_qcoef_thresh * ac_q;
- for (int row = 0; row < bh; row += tx_h) {
- for (int col = 0; col < bw; col += tx_w) {
- av1_fwd_txfm(src_diff + col, coefs, bw, ¶m);
- // Operating on TX domain, not pixels; we want the QTX quantizers
- const uint32_t dc_coef = (((uint32_t)abs(coefs[0])) << 7);
- if (dc_coef >= dc_thresh) return 0;
- for (int i = 1; i < n_coeff; ++i) {
- const uint32_t ac_coef = (((uint32_t)abs(coefs[i])) << 7);
- if (ac_coef >= ac_thresh) return 0;
- }
- }
- src_diff += tx_h * bw;
- }
- return 1;
-}
-
-// Used to set proper context for early termination with skip = 1.
-static AOM_INLINE void set_skip_flag(MACROBLOCK *x, RD_STATS *rd_stats,
- int bsize, int64_t dist) {
- MACROBLOCKD *const xd = &x->e_mbd;
- MB_MODE_INFO *const mbmi = xd->mi[0];
- const int n4 = bsize_to_num_blk(bsize);
- const TX_SIZE tx_size = max_txsize_rect_lookup[bsize];
- memset(xd->tx_type_map, DCT_DCT, sizeof(xd->tx_type_map[0]) * n4);
- memset(mbmi->inter_tx_size, tx_size, sizeof(mbmi->inter_tx_size));
- mbmi->tx_size = tx_size;
- for (int i = 0; i < n4; ++i) set_blk_skip(x, 0, i, 1);
- rd_stats->skip = 1;
- if (is_cur_buf_hbd(xd)) dist = ROUND_POWER_OF_TWO(dist, (xd->bd - 8) * 2);
- rd_stats->dist = rd_stats->sse = (dist << 4);
- // Though decision is to make the block as skip based on luma stats,
- // it is possible that block becomes non skip after chroma rd. In addition
- // intermediate non skip costs calculated by caller function will be
- // incorrect, if rate is set as zero (i.e., if zero_blk_rate is not
- // accounted). Hence intermediate rate is populated to code the luma tx blks
- // as skip, the caller function based on final rd decision (i.e., skip vs
- // non-skip) sets the final rate accordingly. Here the rate populated
- // corresponds to coding all the tx blocks with zero_blk_rate (based on max tx
- // size possible) in the current block. Eg: For 128*128 block, rate would be
- // 4 * zero_blk_rate where zero_blk_rate corresponds to coding of one 64x64 tx
- // block as 'all zeros'
- ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
- ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
- av1_get_entropy_contexts(bsize, &xd->plane[0], ctxa, ctxl);
- ENTROPY_CONTEXT *ta = ctxa;
- ENTROPY_CONTEXT *tl = ctxl;
- const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
- TXB_CTX txb_ctx;
- get_txb_ctx(bsize, tx_size, 0, ta, tl, &txb_ctx);
- const int zero_blk_rate = x->coeff_costs[txs_ctx][PLANE_TYPE_Y]
- .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
- rd_stats->rate = zero_blk_rate *
- (block_size_wide[bsize] >> tx_size_wide_log2[tx_size]) *
- (block_size_high[bsize] >> tx_size_high_log2[tx_size]);
-}
-
-static INLINE uint32_t get_block_residue_hash(MACROBLOCK *x, BLOCK_SIZE bsize) {
- const int rows = block_size_high[bsize];
- const int cols = block_size_wide[bsize];
- const int16_t *diff = x->plane[0].src_diff;
- const uint32_t hash = av1_get_crc32c_value(&x->mb_rd_record.crc_calculator,
- (uint8_t *)diff, 2 * rows * cols);
- return (hash << 5) + bsize;
-}
-
-static AOM_INLINE void save_tx_rd_info(int n4, uint32_t hash,
- const MACROBLOCK *const x,
- const RD_STATS *const rd_stats,
- MB_RD_RECORD *tx_rd_record) {
- int index;
- if (tx_rd_record->num < RD_RECORD_BUFFER_LEN) {
- index =
- (tx_rd_record->index_start + tx_rd_record->num) % RD_RECORD_BUFFER_LEN;
- ++tx_rd_record->num;
- } else {
- index = tx_rd_record->index_start;
- tx_rd_record->index_start =
- (tx_rd_record->index_start + 1) % RD_RECORD_BUFFER_LEN;
- }
- MB_RD_INFO *const tx_rd_info = &tx_rd_record->tx_rd_info[index];
- const MACROBLOCKD *const xd = &x->e_mbd;
- const MB_MODE_INFO *const mbmi = xd->mi[0];
- tx_rd_info->hash_value = hash;
- tx_rd_info->tx_size = mbmi->tx_size;
- memcpy(tx_rd_info->blk_skip, x->blk_skip,
- sizeof(tx_rd_info->blk_skip[0]) * n4);
- av1_copy(tx_rd_info->inter_tx_size, mbmi->inter_tx_size);
- av1_copy_array(tx_rd_info->tx_type_map, xd->tx_type_map, n4);
- tx_rd_info->rd_stats = *rd_stats;
-}
-
-static AOM_INLINE void fetch_tx_rd_info(int n4,
- const MB_RD_INFO *const tx_rd_info,
- RD_STATS *const rd_stats,
- MACROBLOCK *const x) {
- MACROBLOCKD *const xd = &x->e_mbd;
- MB_MODE_INFO *const mbmi = xd->mi[0];
- mbmi->tx_size = tx_rd_info->tx_size;
- memcpy(x->blk_skip, tx_rd_info->blk_skip,
- sizeof(tx_rd_info->blk_skip[0]) * n4);
- av1_copy(mbmi->inter_tx_size, tx_rd_info->inter_tx_size);
- av1_copy_array(xd->tx_type_map, tx_rd_info->tx_type_map, n4);
- *rd_stats = tx_rd_info->rd_stats;
-}
-
-static INLINE int32_t find_mb_rd_info(const MB_RD_RECORD *const mb_rd_record,
- const int64_t ref_best_rd,
- const uint32_t hash) {
- int32_t match_index = -1;
- if (ref_best_rd != INT64_MAX) {
- for (int i = 0; i < mb_rd_record->num; ++i) {
- const int index = (mb_rd_record->index_start + i) % RD_RECORD_BUFFER_LEN;
- // If there is a match in the tx_rd_record, fetch the RD decision and
- // terminate early.
- if (mb_rd_record->tx_rd_info[index].hash_value == hash) {
- match_index = index;
- break;
- }
- }
- }
- return match_index;
-}
-
-static AOM_INLINE void super_block_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
- RD_STATS *rd_stats, BLOCK_SIZE bs,
- int64_t ref_best_rd) {
- MACROBLOCKD *xd = &x->e_mbd;
- av1_init_rd_stats(rd_stats);
- int is_inter = is_inter_block(xd->mi[0]);
- assert(bs == xd->mi[0]->sb_type);
-
- const int mi_row = -xd->mb_to_top_edge >> (3 + MI_SIZE_LOG2);
- const int mi_col = -xd->mb_to_left_edge >> (3 + MI_SIZE_LOG2);
-
- uint32_t hash = 0;
- int32_t match_index = -1;
- MB_RD_RECORD *mb_rd_record = NULL;
- const int within_border = mi_row >= xd->tile.mi_row_start &&
- (mi_row + mi_size_high[bs] < xd->tile.mi_row_end) &&
- mi_col >= xd->tile.mi_col_start &&
- (mi_col + mi_size_wide[bs] < xd->tile.mi_col_end);
- const int is_mb_rd_hash_enabled =
- (within_border && cpi->sf.rd_sf.use_mb_rd_hash && is_inter);
- const int n4 = bsize_to_num_blk(bs);
- if (is_mb_rd_hash_enabled) {
- hash = get_block_residue_hash(x, bs);
- mb_rd_record = &x->mb_rd_record;
- match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash);
- if (match_index != -1) {
- MB_RD_INFO *tx_rd_info = &mb_rd_record->tx_rd_info[match_index];
- fetch_tx_rd_info(n4, tx_rd_info, rd_stats, x);
- return;
- }
- }
-
- // If we predict that skip is the optimal RD decision - set the respective
- // context and terminate early.
- int64_t dist;
-
- if (x->predict_skip_level && is_inter &&
- (!xd->lossless[xd->mi[0]->segment_id]) &&
- predict_skip_flag(x, bs, &dist, cpi->common.reduced_tx_set_used)) {
- // Populate rdstats as per skip decision
- set_skip_flag(x, rd_stats, bs, dist);
- // Save the RD search results into tx_rd_record.
- if (is_mb_rd_hash_enabled)
- save_tx_rd_info(n4, hash, x, rd_stats, mb_rd_record);
- return;
- }
-
- if (xd->lossless[xd->mi[0]->segment_id]) {
- choose_smallest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
- } else if (x->tx_size_search_method == USE_LARGESTALL) {
- choose_largest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
- } else {
- choose_tx_size_type_from_rd(cpi, x, rd_stats, ref_best_rd, bs);
- }
-
- // Save the RD search results into tx_rd_record.
- if (is_mb_rd_hash_enabled) {
- assert(mb_rd_record != NULL);
- save_tx_rd_info(n4, hash, x, rd_stats, mb_rd_record);
- }
-}
-
// Return the rate cost for luma prediction mode info. of intra blocks.
static int intra_mode_info_cost_y(const AV1_COMP *cpi, const MACROBLOCK *x,
const MB_MODE_INFO *mbmi, BLOCK_SIZE bsize,
@@ -5016,1054 +2686,6 @@
return best_rd;
}
-// Return value 0: early termination triggered, no valid rd cost available;
-// 1: rd cost values are valid.
-static int super_block_uvrd(const AV1_COMP *const cpi, MACROBLOCK *x,
- RD_STATS *rd_stats, BLOCK_SIZE bsize,
- int64_t ref_best_rd) {
- av1_init_rd_stats(rd_stats);
- int is_cost_valid = 1;
- if (ref_best_rd < 0) is_cost_valid = 0;
- if (x->skip_chroma_rd || !is_cost_valid) return is_cost_valid;
-
- MACROBLOCKD *const xd = &x->e_mbd;
- MB_MODE_INFO *const mbmi = xd->mi[0];
- struct macroblockd_plane *const pd = &xd->plane[AOM_PLANE_U];
- int plane;
- const int is_inter = is_inter_block(mbmi);
- int64_t this_rd = 0, skip_rd = 0;
- const BLOCK_SIZE plane_bsize =
- get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
-
- if (is_inter && is_cost_valid) {
- for (plane = 1; plane < MAX_MB_PLANE; ++plane)
- av1_subtract_plane(x, plane_bsize, plane);
- }
-
- if (is_cost_valid) {
- const TX_SIZE uv_tx_size = av1_get_tx_size(AOM_PLANE_U, xd);
- for (plane = 1; plane < MAX_MB_PLANE; ++plane) {
- RD_STATS pn_rd_stats;
- int64_t chroma_ref_best_rd = ref_best_rd;
- // For inter blocks, refined ref_best_rd is used for early exit
- // For intra blocks, even though current rd crosses ref_best_rd, early
- // exit is not recommended as current rd is used for gating subsequent
- // modes as well (say, for angular modes)
- // TODO(any): Extend the early exit mechanism for intra modes as well
- if (cpi->sf.inter_sf.perform_best_rd_based_gating_for_chroma &&
- is_inter && chroma_ref_best_rd != INT64_MAX)
- chroma_ref_best_rd = ref_best_rd - AOMMIN(this_rd, skip_rd);
- txfm_rd_in_plane(x, cpi, &pn_rd_stats, chroma_ref_best_rd, 0, plane,
- plane_bsize, uv_tx_size,
- cpi->sf.rd_sf.use_fast_coef_costing, FTXS_NONE, 0);
- if (pn_rd_stats.rate == INT_MAX) {
- is_cost_valid = 0;
- break;
- }
- av1_merge_rd_stats(rd_stats, &pn_rd_stats);
- this_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
- skip_rd = RDCOST(x->rdmult, 0, rd_stats->sse);
- if (AOMMIN(this_rd, skip_rd) > ref_best_rd) {
- is_cost_valid = 0;
- break;
- }
- }
- }
-
- if (!is_cost_valid) {
- // reset cost value
- av1_invalid_rd_stats(rd_stats);
- }
-
- return is_cost_valid;
-}
-
-// Pick transform type for a transform block of tx_size.
-static AOM_INLINE void tx_type_rd(const AV1_COMP *cpi, MACROBLOCK *x,
- TX_SIZE tx_size, int blk_row, int blk_col,
- int plane, int block, int plane_bsize,
- TXB_CTX *txb_ctx, RD_STATS *rd_stats,
- FAST_TX_SEARCH_MODE ftxs_mode,
- int64_t ref_rdcost,
- TXB_RD_INFO *rd_info_array) {
- const struct macroblock_plane *const p = &x->plane[plane];
- const uint16_t cur_joint_ctx =
- (txb_ctx->dc_sign_ctx << 8) + txb_ctx->txb_skip_ctx;
- MACROBLOCKD *xd = &x->e_mbd;
- const int tx_type_map_idx =
- plane ? 0 : blk_row * xd->tx_type_map_stride + blk_col;
- // Look up RD and terminate early in case when we've already processed exactly
- // the same residual with exactly the same entropy context.
- if (rd_info_array != NULL && rd_info_array->valid &&
- rd_info_array->entropy_context == cur_joint_ctx) {
- if (plane == 0) xd->tx_type_map[tx_type_map_idx] = rd_info_array->tx_type;
- const TX_TYPE ref_tx_type =
- av1_get_tx_type(&x->e_mbd, get_plane_type(plane), blk_row, blk_col,
- tx_size, cpi->common.reduced_tx_set_used);
- if (ref_tx_type == rd_info_array->tx_type) {
- rd_stats->rate += rd_info_array->rate;
- rd_stats->dist += rd_info_array->dist;
- rd_stats->sse += rd_info_array->sse;
- rd_stats->skip &= rd_info_array->eob == 0;
- p->eobs[block] = rd_info_array->eob;
- p->txb_entropy_ctx[block] = rd_info_array->txb_entropy_ctx;
- return;
- }
- }
-
- RD_STATS this_rd_stats;
- search_txk_type(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
- txb_ctx, ftxs_mode, 0, 0, ref_rdcost, &this_rd_stats);
-
- av1_merge_rd_stats(rd_stats, &this_rd_stats);
-
- // Save RD results for possible reuse in future.
- if (rd_info_array != NULL) {
- rd_info_array->valid = 1;
- rd_info_array->entropy_context = cur_joint_ctx;
- rd_info_array->rate = this_rd_stats.rate;
- rd_info_array->dist = this_rd_stats.dist;
- rd_info_array->sse = this_rd_stats.sse;
- rd_info_array->eob = p->eobs[block];
- rd_info_array->txb_entropy_ctx = p->txb_entropy_ctx[block];
- if (plane == 0) rd_info_array->tx_type = xd->tx_type_map[tx_type_map_idx];
- }
-}
-
-static float get_dev(float mean, double x2_sum, int num) {
- const float e_x2 = (float)(x2_sum / num);
- const float diff = e_x2 - mean * mean;
- const float dev = (diff > 0) ? sqrtf(diff) : 0;
- return dev;
-}
-
-// Feature used by the model to predict tx split: the mean and standard
-// deviation values of the block and sub-blocks.
-static AOM_INLINE void get_mean_dev_features(const int16_t *data, int stride,
- int bw, int bh, float *feature) {
- const int16_t *const data_ptr = &data[0];
- const int subh = (bh >= bw) ? (bh >> 1) : bh;
- const int subw = (bw >= bh) ? (bw >> 1) : bw;
- const int num = bw * bh;
- const int sub_num = subw * subh;
- int feature_idx = 2;
- int total_x_sum = 0;
- int64_t total_x2_sum = 0;
- int blk_idx = 0;
- double mean2_sum = 0.0f;
- float dev_sum = 0.0f;
-
- for (int row = 0; row < bh; row += subh) {
- for (int col = 0; col < bw; col += subw) {
- int x_sum;
- int64_t x2_sum;
- // TODO(any): Write a SIMD version. Clear registers.
- aom_get_blk_sse_sum(data_ptr + row * stride + col, stride, subw, subh,
- &x_sum, &x2_sum);
- total_x_sum += x_sum;
- total_x2_sum += x2_sum;
-
- aom_clear_system_state();
- const float mean = (float)x_sum / sub_num;
- const float dev = get_dev(mean, (double)x2_sum, sub_num);
- feature[feature_idx++] = mean;
- feature[feature_idx++] = dev;
- mean2_sum += (double)(mean * mean);
- dev_sum += dev;
- blk_idx++;
- }
- }
-
- const float lvl0_mean = (float)total_x_sum / num;
- feature[0] = lvl0_mean;
- feature[1] = get_dev(lvl0_mean, (double)total_x2_sum, num);
-
- if (blk_idx > 1) {
- // Deviation of means.
- feature[feature_idx++] = get_dev(lvl0_mean, mean2_sum, blk_idx);
- // Mean of deviations.
- feature[feature_idx++] = dev_sum / blk_idx;
- }
-}
-
-static int ml_predict_tx_split(MACROBLOCK *x, BLOCK_SIZE bsize, int blk_row,
- int blk_col, TX_SIZE tx_size) {
- const NN_CONFIG *nn_config = av1_tx_split_nnconfig_map[tx_size];
- if (!nn_config) return -1;
-
- const int diff_stride = block_size_wide[bsize];
- const int16_t *diff =
- x->plane[0].src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
- const int bw = tx_size_wide[tx_size];
- const int bh = tx_size_high[tx_size];
- aom_clear_system_state();
-
- float features[64] = { 0.0f };
- get_mean_dev_features(diff, diff_stride, bw, bh, features);
-
- float score = 0.0f;
- av1_nn_predict(features, nn_config, 1, &score);
- aom_clear_system_state();
-
- int int_score = (int)(score * 10000);
- return clamp(int_score, -80000, 80000);
-}
-
-typedef struct {
- int64_t rd;
- int txb_entropy_ctx;
- TX_TYPE tx_type;
-} TxCandidateInfo;
-
-static AOM_INLINE void try_tx_block_no_split(
- const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
- TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize,
- const ENTROPY_CONTEXT *ta, const ENTROPY_CONTEXT *tl,
- int txfm_partition_ctx, RD_STATS *rd_stats, int64_t ref_best_rd,
- FAST_TX_SEARCH_MODE ftxs_mode, TXB_RD_INFO_NODE *rd_info_node,
- TxCandidateInfo *no_split) {
- MACROBLOCKD *const xd = &x->e_mbd;
- MB_MODE_INFO *const mbmi = xd->mi[0];
- struct macroblock_plane *const p = &x->plane[0];
- const int bw = mi_size_wide[plane_bsize];
-
- no_split->rd = INT64_MAX;
- no_split->txb_entropy_ctx = 0;
- no_split->tx_type = TX_TYPES;
-
- const ENTROPY_CONTEXT *const pta = ta + blk_col;
- const ENTROPY_CONTEXT *const ptl = tl + blk_row;
-
- const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
- TXB_CTX txb_ctx;
- get_txb_ctx(plane_bsize, tx_size, 0, pta, ptl, &txb_ctx);
- const int zero_blk_rate = x->coeff_costs[txs_ctx][PLANE_TYPE_Y]
- .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
- rd_stats->zero_rate = zero_blk_rate;
- const int index = av1_get_txb_size_index(plane_bsize, blk_row, blk_col);
- mbmi->inter_tx_size[index] = tx_size;
- tx_type_rd(cpi, x, tx_size, blk_row, blk_col, 0, block, plane_bsize, &txb_ctx,
- rd_stats, ftxs_mode, ref_best_rd,
- rd_info_node != NULL ? rd_info_node->rd_info_array : NULL);
- assert(rd_stats->rate < INT_MAX);
-
- if ((RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >=
- RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse) ||
- rd_stats->skip == 1) &&
- !xd->lossless[mbmi->segment_id]) {
-#if CONFIG_RD_DEBUG
- av1_update_txb_coeff_cost(rd_stats, 0, tx_size, blk_row, blk_col,
- zero_blk_rate - rd_stats->rate);
-#endif // CONFIG_RD_DEBUG
- rd_stats->rate = zero_blk_rate;
- rd_stats->dist = rd_stats->sse;
- rd_stats->skip = 1;
- set_blk_skip(x, 0, blk_row * bw + blk_col, 1);
- p->eobs[block] = 0;
- update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
- } else {
- set_blk_skip(x, 0, blk_row * bw + blk_col, 0);
- rd_stats->skip = 0;
- }
-
- if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
- rd_stats->rate += x->txfm_partition_cost[txfm_partition_ctx][0];
-
- no_split->rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
- no_split->txb_entropy_ctx = p->txb_entropy_ctx[block];
- no_split->tx_type =
- xd->tx_type_map[blk_row * xd->tx_type_map_stride + blk_col];
-}
-
-static AOM_INLINE void select_tx_block(
- const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
- TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
- ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
- RD_STATS *rd_stats, int64_t prev_level_rd, int64_t ref_best_rd,
- int *is_cost_valid, FAST_TX_SEARCH_MODE ftxs_mode,
- TXB_RD_INFO_NODE *rd_info_node);
-
-static AOM_INLINE void try_tx_block_split(
- const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
- TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
- ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
- int txfm_partition_ctx, int64_t no_split_rd, int64_t ref_best_rd,
- FAST_TX_SEARCH_MODE ftxs_mode, TXB_RD_INFO_NODE *rd_info_node,
- RD_STATS *split_rd_stats, int64_t *split_rd) {
- assert(tx_size < TX_SIZES_ALL);
- MACROBLOCKD *const xd = &x->e_mbd;
- const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
- const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);
- const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
- const int bsw = tx_size_wide_unit[sub_txs];
- const int bsh = tx_size_high_unit[sub_txs];
- const int sub_step = bsw * bsh;
- const int nblks =
- (tx_size_high_unit[tx_size] / bsh) * (tx_size_wide_unit[tx_size] / bsw);
- assert(nblks > 0);
- int blk_idx = 0;
- int64_t tmp_rd = 0;
- *split_rd = INT64_MAX;
- split_rd_stats->rate = x->txfm_partition_cost[txfm_partition_ctx][1];
-
- for (int r = 0; r < tx_size_high_unit[tx_size]; r += bsh) {
- for (int c = 0; c < tx_size_wide_unit[tx_size]; c += bsw, ++blk_idx) {
- assert(blk_idx < 4);
- const int offsetr = blk_row + r;
- const int offsetc = blk_col + c;
- if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
-
- RD_STATS this_rd_stats;
- int this_cost_valid = 1;
- select_tx_block(
- cpi, x, offsetr, offsetc, block, sub_txs, depth + 1, plane_bsize, ta,
- tl, tx_above, tx_left, &this_rd_stats, no_split_rd / nblks,
- ref_best_rd - tmp_rd, &this_cost_valid, ftxs_mode,
- (rd_info_node != NULL) ? rd_info_node->children[blk_idx] : NULL);
- if (!this_cost_valid) return;
- av1_merge_rd_stats(split_rd_stats, &this_rd_stats);
- tmp_rd = RDCOST(x->rdmult, split_rd_stats->rate, split_rd_stats->dist);
- if (no_split_rd < tmp_rd) return;
- block += sub_step;
- }
- }
-
- *split_rd = tmp_rd;
-}
-
-// Search for the best tx partition/type for a given luma block.
-static AOM_INLINE void select_tx_block(
- const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
- TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
- ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
- RD_STATS *rd_stats, int64_t prev_level_rd, int64_t ref_best_rd,
- int *is_cost_valid, FAST_TX_SEARCH_MODE ftxs_mode,
- TXB_RD_INFO_NODE *rd_info_node) {
- assert(tx_size < TX_SIZES_ALL);
- av1_init_rd_stats(rd_stats);
- if (ref_best_rd < 0) {
- *is_cost_valid = 0;
- return;
- }
-
- MACROBLOCKD *const xd = &x->e_mbd;
- const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
- const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);
- if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;
-
- const int bw = mi_size_wide[plane_bsize];
- MB_MODE_INFO *const mbmi = xd->mi[0];
- const int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
- mbmi->sb_type, tx_size);
- struct macroblock_plane *const p = &x->plane[0];
-
- const int try_no_split =
- cpi->oxcf.enable_tx64 || txsize_sqr_up_map[tx_size] != TX_64X64;
- int try_split = tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH;
-#if CONFIG_DIST_8X8
- if (x->using_dist_8x8)
- try_split &= tx_size_wide[tx_size] >= 16 && tx_size_high[tx_size] >= 16;
-#endif
- TxCandidateInfo no_split = { INT64_MAX, 0, TX_TYPES };
-
- // TX no split
- if (try_no_split) {
- try_tx_block_no_split(cpi, x, blk_row, blk_col, block, tx_size, depth,
- plane_bsize, ta, tl, ctx, rd_stats, ref_best_rd,
- ftxs_mode, rd_info_node, &no_split);
-
- if (cpi->sf.tx_sf.adaptive_txb_search_level &&
- (no_split.rd -
- (no_split.rd >> (1 + cpi->sf.tx_sf.adaptive_txb_search_level))) >
- ref_best_rd) {
- *is_cost_valid = 0;
- return;
- }
-
- if (cpi->sf.tx_sf.txb_split_cap) {
- if (p->eobs[block] == 0) try_split = 0;
- }
-
- if (cpi->sf.tx_sf.adaptive_txb_search_level &&
- (no_split.rd -
- (no_split.rd >> (2 + cpi->sf.tx_sf.adaptive_txb_search_level))) >
- prev_level_rd) {
- try_split = 0;
- }
- }
-
- if (x->e_mbd.bd == 8 && try_split &&
- !(ref_best_rd == INT64_MAX && no_split.rd == INT64_MAX)) {
- const int threshold = cpi->sf.tx_sf.tx_type_search.ml_tx_split_thresh;
- if (threshold >= 0) {
- const int split_score =
- ml_predict_tx_split(x, plane_bsize, blk_row, blk_col, tx_size);
- if (split_score < -threshold) try_split = 0;
- }
- }
-
- // TX split
- int64_t split_rd = INT64_MAX;
- RD_STATS split_rd_stats;
- av1_init_rd_stats(&split_rd_stats);
- if (try_split) {
- try_tx_block_split(cpi, x, blk_row, blk_col, block, tx_size, depth,
- plane_bsize, ta, tl, tx_above, tx_left, ctx, no_split.rd,
- AOMMIN(no_split.rd, ref_best_rd), ftxs_mode,
- rd_info_node, &split_rd_stats, &split_rd);
- }
-
- if (no_split.rd < split_rd) {
- ENTROPY_CONTEXT *pta = ta + blk_col;
- ENTROPY_CONTEXT *ptl = tl + blk_row;
- const TX_SIZE tx_size_selected = tx_size;
- p->txb_entropy_ctx[block] = no_split.txb_entropy_ctx;
- av1_set_txb_context(x, 0, block, tx_size_selected, pta, ptl);
- txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size,
- tx_size);
- for (int idy = 0; idy < tx_size_high_unit[tx_size]; ++idy) {
- for (int idx = 0; idx < tx_size_wide_unit[tx_size]; ++idx) {
- const int index =
- av1_get_txb_size_index(plane_bsize, blk_row + idy, blk_col + idx);
- mbmi->inter_tx_size[index] = tx_size_selected;
- }
- }
- mbmi->tx_size = tx_size_selected;
- update_txk_array(xd, blk_row, blk_col, tx_size, no_split.tx_type);
- set_blk_skip(x, 0, blk_row * bw + blk_col, rd_stats->skip);
- } else {
- *rd_stats = split_rd_stats;
- if (split_rd == INT64_MAX) *is_cost_valid = 0;
- }
-}
-
-static int64_t select_tx_size_and_type(const AV1_COMP *cpi, MACROBLOCK *x,
- RD_STATS *rd_stats, BLOCK_SIZE bsize,
- int64_t ref_best_rd,
- TXB_RD_INFO_NODE *rd_info_tree) {
- MACROBLOCKD *const xd = &x->e_mbd;
- assert(is_inter_block(xd->mi[0]));
- assert(bsize < BLOCK_SIZES_ALL);
-
- // TODO(debargha): enable this as a speed feature where the
- // select_inter_block_yrd() function above will use a simplified search
- // such as not using full optimize, but the inter_block_yrd() function
- // will use more complex search given that the transform partitions have
- // already been decided.
-
- const int fast_tx_search = x->tx_size_search_method > USE_FULL_RD;
- int64_t rd_thresh = ref_best_rd;
- if (fast_tx_search && rd_thresh < INT64_MAX) {
- if (INT64_MAX - rd_thresh > (rd_thresh >> 3)) rd_thresh += (rd_thresh >> 3);
- }
- assert(rd_thresh > 0);
-
- const FAST_TX_SEARCH_MODE ftxs_mode =
- fast_tx_search ? FTXS_DCT_AND_1D_DCT_ONLY : FTXS_NONE;
- const struct macroblockd_plane *const pd = &xd->plane[0];
- const BLOCK_SIZE plane_bsize =
- get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
- assert(plane_bsize < BLOCK_SIZES_ALL);
- const int mi_width = mi_size_wide[plane_bsize];
- const int mi_height = mi_size_high[plane_bsize];
- ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
- ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
- TXFM_CONTEXT tx_above[MAX_MIB_SIZE];
- TXFM_CONTEXT tx_left[MAX_MIB_SIZE];
- av1_get_entropy_contexts(bsize, pd, ctxa, ctxl);
- memcpy(tx_above, xd->above_txfm_context, sizeof(TXFM_CONTEXT) * mi_width);
- memcpy(tx_left, xd->left_txfm_context, sizeof(TXFM_CONTEXT) * mi_height);
-
- const int skip_ctx = av1_get_skip_context(xd);
- const int s0 = x->skip_cost[skip_ctx][0];
- const int s1 = x->skip_cost[skip_ctx][1];
- const int init_depth = get_search_init_depth(mi_width, mi_height, 1, &cpi->sf,
- x->tx_size_search_method);
- const TX_SIZE max_tx_size = max_txsize_rect_lookup[plane_bsize];
- const int bh = tx_size_high_unit[max_tx_size];
- const int bw = tx_size_wide_unit[max_tx_size];
- const int step = bw * bh;
- int64_t skip_rd = RDCOST(x->rdmult, s1, 0);
- int64_t this_rd = RDCOST(x->rdmult, s0, 0);
- int block = 0;
-
- av1_init_rd_stats(rd_stats);
- for (int idy = 0; idy < mi_height; idy += bh) {
- for (int idx = 0; idx < mi_width; idx += bw) {
- const int64_t best_rd_sofar =
- (rd_thresh == INT64_MAX) ? INT64_MAX
- : (rd_thresh - (AOMMIN(skip_rd, this_rd)));
- int is_cost_valid = 1;
- RD_STATS pn_rd_stats;
- select_tx_block(cpi, x, idy, idx, block, max_tx_size, init_depth,
- plane_bsize, ctxa, ctxl, tx_above, tx_left, &pn_rd_stats,
- INT64_MAX, best_rd_sofar, &is_cost_valid, ftxs_mode,
- rd_info_tree);
- if (!is_cost_valid || pn_rd_stats.rate == INT_MAX) {
- av1_invalid_rd_stats(rd_stats);
- return INT64_MAX;
- }
- av1_merge_rd_stats(rd_stats, &pn_rd_stats);
- skip_rd = RDCOST(x->rdmult, s1, rd_stats->sse);
- this_rd = RDCOST(x->rdmult, rd_stats->rate + s0, rd_stats->dist);
- block += step;
- if (rd_info_tree != NULL) rd_info_tree += 1;
- }
- }
-
- if (skip_rd <= this_rd) {
- rd_stats->skip = 1;
- } else {
- rd_stats->skip = 0;
- }
-
- if (rd_stats->rate == INT_MAX) return INT64_MAX;
-
- // If fast_tx_search is true, only DCT and 1D DCT were tested in
- // select_inter_block_yrd() above. Do a better search for tx type with
- // tx sizes already decided.
- if (fast_tx_search) {
- if (!inter_block_yrd(cpi, x, rd_stats, bsize, ref_best_rd, FTXS_NONE))
- return INT64_MAX;
- }
-
- int64_t rd;
- if (rd_stats->skip) {
- rd = RDCOST(x->rdmult, s1, rd_stats->sse);
- } else {
- rd = RDCOST(x->rdmult, rd_stats->rate + s0, rd_stats->dist);
- if (!xd->lossless[xd->mi[0]->segment_id])
- rd = AOMMIN(rd, RDCOST(x->rdmult, s1, rd_stats->sse));
- }
-
- return rd;
-}
-
-// Finds rd cost for a y block, given the transform size partitions
-static AOM_INLINE void tx_block_yrd(
- const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
- TX_SIZE tx_size, BLOCK_SIZE plane_bsize, int depth,
- ENTROPY_CONTEXT *above_ctx, ENTROPY_CONTEXT *left_ctx,
- TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left, int64_t ref_best_rd,
- RD_STATS *rd_stats, FAST_TX_SEARCH_MODE ftxs_mode) {
- MACROBLOCKD *const xd = &x->e_mbd;
- MB_MODE_INFO *const mbmi = xd->mi[0];
- const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
- const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);
-
- assert(tx_size < TX_SIZES_ALL);
-
- if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;
-
- const TX_SIZE plane_tx_size = mbmi->inter_tx_size[av1_get_txb_size_index(
- plane_bsize, blk_row, blk_col)];
-
- int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
- mbmi->sb_type, tx_size);
-
- av1_init_rd_stats(rd_stats);
- if (tx_size == plane_tx_size) {
- ENTROPY_CONTEXT *ta = above_ctx + blk_col;
- ENTROPY_CONTEXT *tl = left_ctx + blk_row;
- const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
- TXB_CTX txb_ctx;
- get_txb_ctx(plane_bsize, tx_size, 0, ta, tl, &txb_ctx);
-
- const int zero_blk_rate = x->coeff_costs[txs_ctx][get_plane_type(0)]
- .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
- rd_stats->zero_rate = zero_blk_rate;
- tx_type_rd(cpi, x, tx_size, blk_row, blk_col, 0, block, plane_bsize,
- &txb_ctx, rd_stats, ftxs_mode, ref_best_rd, NULL);
- const int mi_width = mi_size_wide[plane_bsize];
- if (RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >=
- RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse) ||
- rd_stats->skip == 1) {
- rd_stats->rate = zero_blk_rate;
- rd_stats->dist = rd_stats->sse;
- rd_stats->skip = 1;
- set_blk_skip(x, 0, blk_row * mi_width + blk_col, 1);
- x->plane[0].eobs[block] = 0;
- x->plane[0].txb_entropy_ctx[block] = 0;
- update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
- } else {
- rd_stats->skip = 0;
- set_blk_skip(x, 0, blk_row * mi_width + blk_col, 0);
- }
- if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
- rd_stats->rate += x->txfm_partition_cost[ctx][0];
- av1_set_txb_context(x, 0, block, tx_size, ta, tl);
- txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size,
- tx_size);
- } else {
- const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
- const int bsw = tx_size_wide_unit[sub_txs];
- const int bsh = tx_size_high_unit[sub_txs];
- const int step = bsh * bsw;
- RD_STATS pn_rd_stats;
- int64_t this_rd = 0;
- assert(bsw > 0 && bsh > 0);
-
- for (int row = 0; row < tx_size_high_unit[tx_size]; row += bsh) {
- for (int col = 0; col < tx_size_wide_unit[tx_size]; col += bsw) {
- const int offsetr = blk_row + row;
- const int offsetc = blk_col + col;
-
- if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
-
- av1_init_rd_stats(&pn_rd_stats);
- tx_block_yrd(cpi, x, offsetr, offsetc, block, sub_txs, plane_bsize,
- depth + 1, above_ctx, left_ctx, tx_above, tx_left,
- ref_best_rd - this_rd, &pn_rd_stats, ftxs_mode);
- if (pn_rd_stats.rate == INT_MAX) {
- av1_invalid_rd_stats(rd_stats);
- return;
- }
- av1_merge_rd_stats(rd_stats, &pn_rd_stats);
- this_rd += RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist);
- block += step;
- }
- }
-
- if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
- rd_stats->rate += x->txfm_partition_cost[ctx][1];
- }
-}
-
-// Return value 0: early termination triggered, no valid rd cost available;
-// 1: rd cost values are valid.
-static int inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
- RD_STATS *rd_stats, BLOCK_SIZE bsize,
- int64_t ref_best_rd, FAST_TX_SEARCH_MODE ftxs_mode) {
- MACROBLOCKD *const xd = &x->e_mbd;
- int is_cost_valid = 1;
- int64_t this_rd = 0;
-
- if (ref_best_rd < 0) is_cost_valid = 0;
-
- av1_init_rd_stats(rd_stats);
-
- if (is_cost_valid) {
- const struct macroblockd_plane *const pd = &xd->plane[0];
- const BLOCK_SIZE plane_bsize =
- get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
- const int mi_width = mi_size_wide[plane_bsize];
- const int mi_height = mi_size_high[plane_bsize];
- const TX_SIZE max_tx_size = get_vartx_max_txsize(xd, plane_bsize, 0);
- const int bh = tx_size_high_unit[max_tx_size];
- const int bw = tx_size_wide_unit[max_tx_size];
- const int init_depth = get_search_init_depth(
- mi_width, mi_height, 1, &cpi->sf, x->tx_size_search_method);
- int idx, idy;
- int block = 0;
- int step = tx_size_wide_unit[max_tx_size] * tx_size_high_unit[max_tx_size];
- ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
- ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
- TXFM_CONTEXT tx_above[MAX_MIB_SIZE];
- TXFM_CONTEXT tx_left[MAX_MIB_SIZE];
- RD_STATS pn_rd_stats;
-
- av1_get_entropy_contexts(bsize, pd, ctxa, ctxl);
- memcpy(tx_above, xd->above_txfm_context, sizeof(TXFM_CONTEXT) * mi_width);
- memcpy(tx_left, xd->left_txfm_context, sizeof(TXFM_CONTEXT) * mi_height);
-
- for (idy = 0; idy < mi_height; idy += bh) {
- for (idx = 0; idx < mi_width; idx += bw) {
- av1_init_rd_stats(&pn_rd_stats);
- tx_block_yrd(cpi, x, idy, idx, block, max_tx_size, plane_bsize,
- init_depth, ctxa, ctxl, tx_above, tx_left,
- ref_best_rd - this_rd, &pn_rd_stats, ftxs_mode);
- if (pn_rd_stats.rate == INT_MAX) {
- av1_invalid_rd_stats(rd_stats);
- return 0;
- }
- av1_merge_rd_stats(rd_stats, &pn_rd_stats);
- this_rd +=
- AOMMIN(RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist),
- RDCOST(x->rdmult, pn_rd_stats.zero_rate, pn_rd_stats.sse));
- block += step;
- }
- }
- }
-
- const int skip_ctx = av1_get_skip_context(xd);
- const int s0 = x->skip_cost[skip_ctx][0];
- const int s1 = x->skip_cost[skip_ctx][1];
- int64_t skip_rd = RDCOST(x->rdmult, s1, rd_stats->sse);
- this_rd = RDCOST(x->rdmult, rd_stats->rate + s0, rd_stats->dist);
- if (skip_rd < this_rd) {
- this_rd = skip_rd;
- rd_stats->rate = 0;
- rd_stats->dist = rd_stats->sse;
- rd_stats->skip = 1;
- }
- if (this_rd > ref_best_rd) is_cost_valid = 0;
-
- if (!is_cost_valid) {
- // reset cost value
- av1_invalid_rd_stats(rd_stats);
- }
- return is_cost_valid;
-}
-
-static int find_tx_size_rd_info(TXB_RD_RECORD *cur_record,
- const uint32_t hash) {
- // Linear search through the circular buffer to find matching hash.
- for (int i = cur_record->index_start - 1; i >= 0; i--) {
- if (cur_record->hash_vals[i] == hash) return i;
- }
- for (int i = cur_record->num - 1; i >= cur_record->index_start; i--) {
- if (cur_record->hash_vals[i] == hash) return i;
- }
- int index;
- // If not found - add new RD info into the buffer and return its index
- if (cur_record->num < TX_SIZE_RD_RECORD_BUFFER_LEN) {
- index = (cur_record->index_start + cur_record->num) %
- TX_SIZE_RD_RECORD_BUFFER_LEN;
- cur_record->num++;
- } else {
- index = cur_record->index_start;
- cur_record->index_start =
- (cur_record->index_start + 1) % TX_SIZE_RD_RECORD_BUFFER_LEN;
- }
-
- cur_record->hash_vals[index] = hash;
- av1_zero(cur_record->tx_rd_info[index]);
- return index;
-}
-
-typedef struct {
- int leaf;
- int8_t children[4];
-} RD_RECORD_IDX_NODE;
-
-static const RD_RECORD_IDX_NODE rd_record_tree_8x8[] = {
- { 1, { 0 } },
-};
-
-static const RD_RECORD_IDX_NODE rd_record_tree_8x16[] = {
- { 0, { 1, 2, -1, -1 } },
- { 1, { 0, 0, 0, 0 } },
- { 1, { 0, 0, 0, 0 } },
-};
-
-static const RD_RECORD_IDX_NODE rd_record_tree_16x8[] = {
- { 0, { 1, 2, -1, -1 } },
- { 1, { 0 } },
- { 1, { 0 } },
-};
-
-static const RD_RECORD_IDX_NODE rd_record_tree_16x16[] = {
- { 0, { 1, 2, 3, 4 } }, { 1, { 0 } }, { 1, { 0 } }, { 1, { 0 } }, { 1, { 0 } },
-};
-
-static const RD_RECORD_IDX_NODE rd_record_tree_1_2[] = {
- { 0, { 1, 2, -1, -1 } },
- { 0, { 3, 4, 5, 6 } },
- { 0, { 7, 8, 9, 10 } },
-};
-
-static const RD_RECORD_IDX_NODE rd_record_tree_2_1[] = {
- { 0, { 1, 2, -1, -1 } },
- { 0, { 3, 4, 7, 8 } },
- { 0, { 5, 6, 9, 10 } },
-};
-
-static const RD_RECORD_IDX_NODE rd_record_tree_sqr[] = {
- { 0, { 1, 2, 3, 4 } }, { 0, { 5, 6, 9, 10 } }, { 0, { 7, 8, 11, 12 } },
- { 0, { 13, 14, 17, 18 } }, { 0, { 15, 16, 19, 20 } },
-};
-
-static const RD_RECORD_IDX_NODE rd_record_tree_64x128[] = {
- { 0, { 2, 3, 4, 5 } }, { 0, { 6, 7, 8, 9 } },
- { 0, { 10, 11, 14, 15 } }, { 0, { 12, 13, 16, 17 } },
- { 0, { 18, 19, 22, 23 } }, { 0, { 20, 21, 24, 25 } },
- { 0, { 26, 27, 30, 31 } }, { 0, { 28, 29, 32, 33 } },
- { 0, { 34, 35, 38, 39 } }, { 0, { 36, 37, 40, 41 } },
-};
-
-static const RD_RECORD_IDX_NODE rd_record_tree_128x64[] = {
- { 0, { 2, 3, 6, 7 } }, { 0, { 4, 5, 8, 9 } },
- { 0, { 10, 11, 18, 19 } }, { 0, { 12, 13, 20, 21 } },
- { 0, { 14, 15, 22, 23 } }, { 0, { 16, 17, 24, 25 } },
- { 0, { 26, 27, 34, 35 } }, { 0, { 28, 29, 36, 37 } },
- { 0, { 30, 31, 38, 39 } }, { 0, { 32, 33, 40, 41 } },
-};
-
-static const RD_RECORD_IDX_NODE rd_record_tree_128x128[] = {
- { 0, { 4, 5, 8, 9 } }, { 0, { 6, 7, 10, 11 } },
- { 0, { 12, 13, 16, 17 } }, { 0, { 14, 15, 18, 19 } },
- { 0, { 20, 21, 28, 29 } }, { 0, { 22, 23, 30, 31 } },
- { 0, { 24, 25, 32, 33 } }, { 0, { 26, 27, 34, 35 } },
- { 0, { 36, 37, 44, 45 } }, { 0, { 38, 39, 46, 47 } },
- { 0, { 40, 41, 48, 49 } }, { 0, { 42, 43, 50, 51 } },
- { 0, { 52, 53, 60, 61 } }, { 0, { 54, 55, 62, 63 } },
- { 0, { 56, 57, 64, 65 } }, { 0, { 58, 59, 66, 67 } },
- { 0, { 68, 69, 76, 77 } }, { 0, { 70, 71, 78, 79 } },
- { 0, { 72, 73, 80, 81 } }, { 0, { 74, 75, 82, 83 } },
-};
-
-static const RD_RECORD_IDX_NODE rd_record_tree_1_4[] = {
- { 0, { 1, -1, 2, -1 } },
- { 0, { 3, 4, -1, -1 } },
- { 0, { 5, 6, -1, -1 } },
-};
-
-static const RD_RECORD_IDX_NODE rd_record_tree_4_1[] = {
- { 0, { 1, 2, -1, -1 } },
- { 0, { 3, 4, -1, -1 } },
- { 0, { 5, 6, -1, -1 } },
-};
-
-static const RD_RECORD_IDX_NODE *rd_record_tree[BLOCK_SIZES_ALL] = {
- NULL, // BLOCK_4X4
- NULL, // BLOCK_4X8
- NULL, // BLOCK_8X4
- rd_record_tree_8x8, // BLOCK_8X8
- rd_record_tree_8x16, // BLOCK_8X16
- rd_record_tree_16x8, // BLOCK_16X8
- rd_record_tree_16x16, // BLOCK_16X16
- rd_record_tree_1_2, // BLOCK_16X32
- rd_record_tree_2_1, // BLOCK_32X16
- rd_record_tree_sqr, // BLOCK_32X32
- rd_record_tree_1_2, // BLOCK_32X64
- rd_record_tree_2_1, // BLOCK_64X32
- rd_record_tree_sqr, // BLOCK_64X64
- rd_record_tree_64x128, // BLOCK_64X128
- rd_record_tree_128x64, // BLOCK_128X64
- rd_record_tree_128x128, // BLOCK_128X128
- NULL, // BLOCK_4X16
- NULL, // BLOCK_16X4
- rd_record_tree_1_4, // BLOCK_8X32
- rd_record_tree_4_1, // BLOCK_32X8
- rd_record_tree_1_4, // BLOCK_16X64
- rd_record_tree_4_1, // BLOCK_64X16
-};
-
-static const int rd_record_tree_size[BLOCK_SIZES_ALL] = {
- 0, // BLOCK_4X4
- 0, // BLOCK_4X8
- 0, // BLOCK_8X4
- sizeof(rd_record_tree_8x8) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_8X8
- sizeof(rd_record_tree_8x16) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_8X16
- sizeof(rd_record_tree_16x8) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_16X8
- sizeof(rd_record_tree_16x16) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_16X16
- sizeof(rd_record_tree_1_2) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_16X32
- sizeof(rd_record_tree_2_1) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_32X16
- sizeof(rd_record_tree_sqr) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_32X32
- sizeof(rd_record_tree_1_2) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_32X64
- sizeof(rd_record_tree_2_1) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_64X32
- sizeof(rd_record_tree_sqr) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_64X64
- sizeof(rd_record_tree_64x128) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_64X128
- sizeof(rd_record_tree_128x64) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_128X64
- sizeof(rd_record_tree_128x128) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_128X128
- 0, // BLOCK_4X16
- 0, // BLOCK_16X4
- sizeof(rd_record_tree_1_4) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_8X32
- sizeof(rd_record_tree_4_1) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_32X8
- sizeof(rd_record_tree_1_4) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_16X64
- sizeof(rd_record_tree_4_1) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_64X16
-};
-
-static INLINE void init_rd_record_tree(TXB_RD_INFO_NODE *tree,
- BLOCK_SIZE bsize) {
- const RD_RECORD_IDX_NODE *rd_record = rd_record_tree[bsize];
- const int size = rd_record_tree_size[bsize];
- for (int i = 0; i < size; ++i) {
- if (rd_record[i].leaf) {
- av1_zero(tree[i].children);
- } else {
- for (int j = 0; j < 4; ++j) {
- const int8_t idx = rd_record[i].children[j];
- tree[i].children[j] = idx > 0 ? &tree[idx] : NULL;
- }
- }
- }
-}
-
-// Go through all TX blocks that could be used in TX size search, compute
-// residual hash values for them and find matching RD info that stores previous
-// RD search results for these TX blocks. The idea is to prevent repeated
-// rate/distortion computations that happen because of the combination of
-// partition and TX size search. The resulting RD info records are returned in
-// the form of a quadtree for easier access in actual TX size search.
-static int find_tx_size_rd_records(MACROBLOCK *x, BLOCK_SIZE bsize,
- TXB_RD_INFO_NODE *dst_rd_info) {
- TXB_RD_RECORD *rd_records_table[4] = { x->txb_rd_record_8X8,
- x->txb_rd_record_16X16,
- x->txb_rd_record_32X32,
- x->txb_rd_record_64X64 };
- const TX_SIZE max_square_tx_size = max_txsize_lookup[bsize];
- const int bw = block_size_wide[bsize];
- const int bh = block_size_high[bsize];
-
- // Hashing is performed only for square TX sizes larger than TX_4X4
- if (max_square_tx_size < TX_8X8) return 0;
- const int diff_stride = bw;
- const struct macroblock_plane *const p = &x->plane[0];
- const int16_t *diff = &p->src_diff[0];
- init_rd_record_tree(dst_rd_info, bsize);
- // Coordinates of the top-left corner of current block within the superblock
- // measured in pixels:
- const int mi_row = x->e_mbd.mi_row;
- const int mi_col = x->e_mbd.mi_col;
- const int mi_row_in_sb = (mi_row % MAX_MIB_SIZE) << MI_SIZE_LOG2;
- const int mi_col_in_sb = (mi_col % MAX_MIB_SIZE) << MI_SIZE_LOG2;
- int cur_rd_info_idx = 0;
- int cur_tx_depth = 0;
- TX_SIZE cur_tx_size = max_txsize_rect_lookup[bsize];
- while (cur_tx_depth <= MAX_VARTX_DEPTH) {
- const int cur_tx_bw = tx_size_wide[cur_tx_size];
- const int cur_tx_bh = tx_size_high[cur_tx_size];
- if (cur_tx_bw < 8 || cur_tx_bh < 8) break;
- const TX_SIZE next_tx_size = sub_tx_size_map[cur_tx_size];
- const int tx_size_idx = cur_tx_size - TX_8X8;
- for (int row = 0; row < bh; row += cur_tx_bh) {
- for (int col = 0; col < bw; col += cur_tx_bw) {
- if (cur_tx_bw != cur_tx_bh) {
- // Use dummy nodes for all rectangular transforms within the
- // TX size search tree.
- dst_rd_info[cur_rd_info_idx].rd_info_array = NULL;
- } else {
- // Get spatial location of this TX block within the superblock
- // (measured in cur_tx_bsize units).
- const int row_in_sb = (mi_row_in_sb + row) / cur_tx_bh;
- const int col_in_sb = (mi_col_in_sb + col) / cur_tx_bw;
-
- int16_t hash_data[MAX_SB_SQUARE];
- int16_t *cur_hash_row = hash_data;
- const int16_t *cur_diff_row = diff + row * diff_stride + col;
- for (int i = 0; i < cur_tx_bh; i++) {
- memcpy(cur_hash_row, cur_diff_row, sizeof(*hash_data) * cur_tx_bw);
- cur_hash_row += cur_tx_bw;
- cur_diff_row += diff_stride;
- }
- const int hash = av1_get_crc32c_value(&x->mb_rd_record.crc_calculator,
- (uint8_t *)hash_data,
- 2 * cur_tx_bw * cur_tx_bh);
- // Find corresponding RD info based on the hash value.
- const int record_idx =
- row_in_sb * (MAX_MIB_SIZE >> (tx_size_idx + 1)) + col_in_sb;
- TXB_RD_RECORD *records = &rd_records_table[tx_size_idx][record_idx];
- int idx = find_tx_size_rd_info(records, hash);
- dst_rd_info[cur_rd_info_idx].rd_info_array =
- &records->tx_rd_info[idx];
- }
- ++cur_rd_info_idx;
- }
- }
- cur_tx_size = next_tx_size;
- ++cur_tx_depth;
- }
- return 1;
-}
-
-// Search for best transform size and type for luma inter blocks.
-static AOM_INLINE void pick_tx_size_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
- RD_STATS *rd_stats,
- BLOCK_SIZE bsize,
- int64_t ref_best_rd) {
- const AV1_COMMON *cm = &cpi->common;
- MACROBLOCKD *const xd = &x->e_mbd;
- assert(is_inter_block(xd->mi[0]));
-
- av1_invalid_rd_stats(rd_stats);
-
- if (cpi->sf.tx_sf.model_based_prune_tx_search_level &&
- ref_best_rd != INT64_MAX) {
- int model_rate;
- int64_t model_dist;
- int model_skip;
- model_rd_sb_fn[MODELRD_TYPE_TX_SEARCH_PRUNE](
- cpi, bsize, x, xd, 0, 0, &model_rate, &model_dist, &model_skip, NULL,
- NULL, NULL, NULL);
- const int64_t model_rd = RDCOST(x->rdmult, model_rate, model_dist);
- // If the modeled rd is a lot worse than the best so far, breakout.
- // TODO(debargha, urvang): Improve the model and make the check below
- // tighter.
- assert(cpi->sf.tx_sf.model_based_prune_tx_search_level >= 0 &&
- cpi->sf.tx_sf.model_based_prune_tx_search_level <= 2);
- static const int prune_factor_by8[] = { 3, 5 };
- if (!model_skip &&
- ((model_rd *
- prune_factor_by8[cpi->sf.tx_sf.model_based_prune_tx_search_level -
- 1]) >>
- 3) > ref_best_rd)
- return;
- }
-
- uint32_t hash = 0;
- int32_t match_index = -1;
- MB_RD_RECORD *mb_rd_record = NULL;
- const int mi_row = x->e_mbd.mi_row;
- const int mi_col = x->e_mbd.mi_col;
- const int within_border =
- mi_row >= xd->tile.mi_row_start &&
- (mi_row + mi_size_high[bsize] < xd->tile.mi_row_end) &&
- mi_col >= xd->tile.mi_col_start &&
- (mi_col + mi_size_wide[bsize] < xd->tile.mi_col_end);
- const int is_mb_rd_hash_enabled =
- (within_border && cpi->sf.rd_sf.use_mb_rd_hash);
- const int n4 = bsize_to_num_blk(bsize);
- if (is_mb_rd_hash_enabled) {
- hash = get_block_residue_hash(x, bsize);
- mb_rd_record = &x->mb_rd_record;
- match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash);
- if (match_index != -1) {
- MB_RD_INFO *tx_rd_info = &mb_rd_record->tx_rd_info[match_index];
- fetch_tx_rd_info(n4, tx_rd_info, rd_stats, x);
- return;
- }
- }
-
- // If we predict that skip is the optimal RD decision - set the respective
- // context and terminate early.
- int64_t dist;
- if (x->predict_skip_level &&
- predict_skip_flag(x, bsize, &dist, cm->reduced_tx_set_used)) {
- set_skip_flag(x, rd_stats, bsize, dist);
- // Save the RD search results into tx_rd_record.
- if (is_mb_rd_hash_enabled)
- save_tx_rd_info(n4, hash, x, rd_stats, mb_rd_record);
- return;
- }
-#if CONFIG_SPEED_STATS
- ++x->tx_search_count;
-#endif // CONFIG_SPEED_STATS
-
- // Precompute residual hashes and find existing or add new RD records to
- // store and reuse rate and distortion values to speed up TX size search.
- TXB_RD_INFO_NODE matched_rd_info[4 + 16 + 64];
- int found_rd_info = 0;
- if (ref_best_rd != INT64_MAX && within_border &&
- cpi->sf.tx_sf.use_inter_txb_hash) {
- found_rd_info = find_tx_size_rd_records(x, bsize, matched_rd_info);
- }
-
- int found = 0;
- RD_STATS this_rd_stats;
- av1_init_rd_stats(&this_rd_stats);
- const int64_t rd =
- select_tx_size_and_type(cpi, x, &this_rd_stats, bsize, ref_best_rd,
- found_rd_info ? matched_rd_info : NULL);
-
- if (rd < INT64_MAX) {
- *rd_stats = this_rd_stats;
- found = 1;
- }
-
- // We should always find at least one candidate unless ref_best_rd is less
- // than INT64_MAX (in which case, all the calls to select_tx_size_fix_type
- // might have failed to find something better)
- assert(IMPLIES(!found, ref_best_rd != INT64_MAX));
- if (!found) return;
-
- // Save the RD search results into tx_rd_record.
- if (is_mb_rd_hash_enabled) {
- assert(mb_rd_record != NULL);
- save_tx_rd_info(n4, hash, x, rd_stats, mb_rd_record);
- }
-}
-
static AOM_INLINE void rd_pick_palette_intra_sbuv(
const AV1_COMP *const cpi, MACROBLOCK *x, int dc_mode_cost,
uint8_t *best_palette_color_map, MB_MODE_INFO *const best_mbmi,
@@ -7637,142 +4259,6 @@
}
}
-static int txfm_search(const AV1_COMP *cpi, const TileDataEnc *tile_data,
- MACROBLOCK *x, BLOCK_SIZE bsize, RD_STATS *rd_stats,
- RD_STATS *rd_stats_y, RD_STATS *rd_stats_uv,
- int mode_rate, int64_t ref_best_rd) {
- /*
- * This function combines y and uv planes' transform search processes
- * together, when the prediction is generated. It first does subtraction to
- * obtain the prediction error. Then it calls
- * pick_tx_size_type_yrd/super_block_yrd and super_block_uvrd sequentially and
- * handles the early terminations happening in those functions. At the end, it
- * computes the rd_stats/_y/_uv accordingly.
- */
- const AV1_COMMON *cm = &cpi->common;
- MACROBLOCKD *const xd = &x->e_mbd;
- MB_MODE_INFO *const mbmi = xd->mi[0];
- const int ref_frame_1 = mbmi->ref_frame[1];
- const int64_t mode_rd = RDCOST(x->rdmult, mode_rate, 0);
- const int64_t rd_thresh =
- ref_best_rd == INT64_MAX ? INT64_MAX : ref_best_rd - mode_rd;
- const int skip_ctx = av1_get_skip_context(xd);
- const int skip_flag_cost[2] = { x->skip_cost[skip_ctx][0],
- x->skip_cost[skip_ctx][1] };
- const int64_t min_header_rate =
- mode_rate + AOMMIN(skip_flag_cost[0], skip_flag_cost[1]);
- // Account for minimum skip and non_skip rd.
- // Eventually either one of them will be added to mode_rate
- const int64_t min_header_rd_possible = RDCOST(x->rdmult, min_header_rate, 0);
- (void)tile_data;
-
- if (min_header_rd_possible > ref_best_rd) {
- av1_invalid_rd_stats(rd_stats_y);
- return 0;
- }
-
- av1_init_rd_stats(rd_stats);
- av1_init_rd_stats(rd_stats_y);
- rd_stats->rate = mode_rate;
-
- // cost and distortion
- av1_subtract_plane(x, bsize, 0);
- if (x->tx_mode_search_type == TX_MODE_SELECT &&
- !xd->lossless[mbmi->segment_id]) {
- pick_tx_size_type_yrd(cpi, x, rd_stats_y, bsize, rd_thresh);
-#if CONFIG_COLLECT_RD_STATS == 2
- PrintPredictionUnitStats(cpi, tile_data, x, rd_stats_y, bsize);
-#endif // CONFIG_COLLECT_RD_STATS == 2
- } else {
- super_block_yrd(cpi, x, rd_stats_y, bsize, rd_thresh);
- memset(mbmi->inter_tx_size, mbmi->tx_size, sizeof(mbmi->inter_tx_size));
- for (int i = 0; i < xd->n4_h * xd->n4_w; ++i)
- set_blk_skip(x, 0, i, rd_stats_y->skip);
- }
-
- if (rd_stats_y->rate == INT_MAX) {
- // TODO(angiebird): check if we need this
- // restore_dst_buf(xd, *orig_dst, num_planes);
- mbmi->ref_frame[1] = ref_frame_1;
- return 0;
- }
-
- av1_merge_rd_stats(rd_stats, rd_stats_y);
-
- const int64_t non_skip_rdcosty =
- RDCOST(x->rdmult, rd_stats->rate + skip_flag_cost[0], rd_stats->dist);
- const int64_t skip_rdcosty =
- RDCOST(x->rdmult, mode_rate + skip_flag_cost[1], rd_stats->sse);
- const int64_t min_rdcosty = AOMMIN(non_skip_rdcosty, skip_rdcosty);
- if (min_rdcosty > ref_best_rd) {
- const int64_t tokenonly_rdy =
- AOMMIN(RDCOST(x->rdmult, rd_stats_y->rate, rd_stats_y->dist),
- RDCOST(x->rdmult, 0, rd_stats_y->sse));
- // Invalidate rd_stats_y to skip the rest of the motion modes search
- if (tokenonly_rdy -
- (tokenonly_rdy >> cpi->sf.inter_sf.prune_motion_mode_level) >
- rd_thresh)
- av1_invalid_rd_stats(rd_stats_y);
- mbmi->ref_frame[1] = ref_frame_1;
- return 0;
- }
-
- av1_init_rd_stats(rd_stats_uv);
- const int num_planes = av1_num_planes(cm);
- if (num_planes > 1) {
- int64_t ref_best_chroma_rd = ref_best_rd;
- // Calculate best rd cost possible for chroma
- if (cpi->sf.inter_sf.perform_best_rd_based_gating_for_chroma &&
- (ref_best_chroma_rd != INT64_MAX)) {
- ref_best_chroma_rd =
- (ref_best_chroma_rd - AOMMIN(non_skip_rdcosty, skip_rdcosty));
- }
- const int is_cost_valid_uv =
- super_block_uvrd(cpi, x, rd_stats_uv, bsize, ref_best_chroma_rd);
- if (!is_cost_valid_uv) {
- mbmi->ref_frame[1] = ref_frame_1;
- return 0;
- }
- av1_merge_rd_stats(rd_stats, rd_stats_uv);
- }
-
- if (rd_stats->skip) {
- rd_stats->rate -= rd_stats_uv->rate + rd_stats_y->rate;
- rd_stats_y->rate = 0;
- rd_stats_uv->rate = 0;
- rd_stats->dist = rd_stats->sse;
- rd_stats_y->dist = rd_stats_y->sse;
- rd_stats_uv->dist = rd_stats_uv->sse;
- rd_stats->rate += skip_flag_cost[1];
- mbmi->skip = 1;
- // here mbmi->skip temporarily plays a role as what this_skip2 does
-
- const int64_t tmprd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
- if (tmprd > ref_best_rd) {
- mbmi->ref_frame[1] = ref_frame_1;
- return 0;
- }
- } else if (!xd->lossless[mbmi->segment_id] &&
- (RDCOST(x->rdmult,
- rd_stats_y->rate + rd_stats_uv->rate + skip_flag_cost[0],
- rd_stats->dist) >=
- RDCOST(x->rdmult, skip_flag_cost[1], rd_stats->sse))) {
- rd_stats->rate -= rd_stats_uv->rate + rd_stats_y->rate;
- rd_stats->rate += skip_flag_cost[1];
- rd_stats->dist = rd_stats->sse;
- rd_stats_y->dist = rd_stats_y->sse;
- rd_stats_uv->dist = rd_stats_uv->sse;
- rd_stats_y->rate = 0;
- rd_stats_uv->rate = 0;
- mbmi->skip = 1;
- } else {
- rd_stats->rate += skip_flag_cost[0];
- mbmi->skip = 0;
- }
-
- return 1;
-}
-
static INLINE bool enable_wedge_search(MACROBLOCK *const x,
const AV1_COMP *const cpi) {
// Enable wedge search if source variance and edge strength are above
diff --git a/av1/encoder/tx_search.c b/av1/encoder/tx_search.c
new file mode 100644
index 0000000..01f969f
--- /dev/null
+++ b/av1/encoder/tx_search.c
@@ -0,0 +1,3574 @@
+/*
+ * Copyright (c) 2020, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include "av1/common/cfl.h"
+#include "av1/common/reconintra.h"
+#include "av1/encoder/encodetxb.h"
+#include "av1/encoder/hybrid_fwd_txfm.h"
+#include "av1/common/idct.h"
+#include "av1/encoder/model_rd.h"
+#include "av1/encoder/random.h"
+#include "av1/encoder/rdopt.h"
+#include "av1/encoder/rdopt_utils.h"
+#include "av1/encoder/tx_prune_model_weights.h"
+#include "av1/encoder/tx_search.h"
+
+struct rdcost_block_args {
+ const AV1_COMP *cpi;
+ MACROBLOCK *x;
+ ENTROPY_CONTEXT t_above[MAX_MIB_SIZE];
+ ENTROPY_CONTEXT t_left[MAX_MIB_SIZE];
+ RD_STATS rd_stats;
+ int64_t this_rd;
+ int64_t best_rd;
+ int exit_early;
+ int incomplete_exit;
+ int use_fast_coef_costing;
+ FAST_TX_SEARCH_MODE ftxs_mode;
+ int skip_trellis;
+};
+
+typedef struct {
+ int64_t rd;
+ int txb_entropy_ctx;
+ TX_TYPE tx_type;
+} TxCandidateInfo;
+
+typedef struct {
+ int leaf;
+ int8_t children[4];
+} RD_RECORD_IDX_NODE;
+
+// origin_threshold * 128 / 100
+static const uint32_t skip_pred_threshold[3][BLOCK_SIZES_ALL] = {
+ {
+ 64, 64, 64, 70, 60, 60, 68, 68, 68, 68, 68,
+ 68, 68, 68, 68, 68, 64, 64, 70, 70, 68, 68,
+ },
+ {
+ 88, 88, 88, 86, 87, 87, 68, 68, 68, 68, 68,
+ 68, 68, 68, 68, 68, 88, 88, 86, 86, 68, 68,
+ },
+ {
+ 90, 93, 93, 90, 93, 93, 74, 74, 74, 74, 74,
+ 74, 74, 74, 74, 74, 90, 90, 90, 90, 74, 74,
+ },
+};
+
+// lookup table for predict_skip_flag
+// int max_tx_size = max_txsize_rect_lookup[bsize];
+// if (tx_size_high[max_tx_size] > 16 || tx_size_wide[max_tx_size] > 16)
+// max_tx_size = AOMMIN(max_txsize_lookup[bsize], TX_16X16);
+static const TX_SIZE max_predict_sf_tx_size[BLOCK_SIZES_ALL] = {
+ TX_4X4, TX_4X8, TX_8X4, TX_8X8, TX_8X16, TX_16X8,
+ TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16,
+ TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_4X16, TX_16X4,
+ TX_8X8, TX_8X8, TX_16X16, TX_16X16,
+};
+
+static int find_tx_size_rd_info(TXB_RD_RECORD *cur_record,
+ const uint32_t hash) {
+ // Linear search through the circular buffer to find matching hash.
+ for (int i = cur_record->index_start - 1; i >= 0; i--) {
+ if (cur_record->hash_vals[i] == hash) return i;
+ }
+ for (int i = cur_record->num - 1; i >= cur_record->index_start; i--) {
+ if (cur_record->hash_vals[i] == hash) return i;
+ }
+ int index;
+ // If not found - add new RD info into the buffer and return its index
+ if (cur_record->num < TX_SIZE_RD_RECORD_BUFFER_LEN) {
+ index = (cur_record->index_start + cur_record->num) %
+ TX_SIZE_RD_RECORD_BUFFER_LEN;
+ cur_record->num++;
+ } else {
+ index = cur_record->index_start;
+ cur_record->index_start =
+ (cur_record->index_start + 1) % TX_SIZE_RD_RECORD_BUFFER_LEN;
+ }
+
+ cur_record->hash_vals[index] = hash;
+ av1_zero(cur_record->tx_rd_info[index]);
+ return index;
+}
+
+static const RD_RECORD_IDX_NODE rd_record_tree_8x8[] = {
+ { 1, { 0 } },
+};
+
+static const RD_RECORD_IDX_NODE rd_record_tree_8x16[] = {
+ { 0, { 1, 2, -1, -1 } },
+ { 1, { 0, 0, 0, 0 } },
+ { 1, { 0, 0, 0, 0 } },
+};
+
+static const RD_RECORD_IDX_NODE rd_record_tree_16x8[] = {
+ { 0, { 1, 2, -1, -1 } },
+ { 1, { 0 } },
+ { 1, { 0 } },
+};
+
+static const RD_RECORD_IDX_NODE rd_record_tree_16x16[] = {
+ { 0, { 1, 2, 3, 4 } }, { 1, { 0 } }, { 1, { 0 } }, { 1, { 0 } }, { 1, { 0 } },
+};
+
+static const RD_RECORD_IDX_NODE rd_record_tree_1_2[] = {
+ { 0, { 1, 2, -1, -1 } },
+ { 0, { 3, 4, 5, 6 } },
+ { 0, { 7, 8, 9, 10 } },
+};
+
+static const RD_RECORD_IDX_NODE rd_record_tree_2_1[] = {
+ { 0, { 1, 2, -1, -1 } },
+ { 0, { 3, 4, 7, 8 } },
+ { 0, { 5, 6, 9, 10 } },
+};
+
+static const RD_RECORD_IDX_NODE rd_record_tree_sqr[] = {
+ { 0, { 1, 2, 3, 4 } }, { 0, { 5, 6, 9, 10 } }, { 0, { 7, 8, 11, 12 } },
+ { 0, { 13, 14, 17, 18 } }, { 0, { 15, 16, 19, 20 } },
+};
+
+static const RD_RECORD_IDX_NODE rd_record_tree_64x128[] = {
+ { 0, { 2, 3, 4, 5 } }, { 0, { 6, 7, 8, 9 } },
+ { 0, { 10, 11, 14, 15 } }, { 0, { 12, 13, 16, 17 } },
+ { 0, { 18, 19, 22, 23 } }, { 0, { 20, 21, 24, 25 } },
+ { 0, { 26, 27, 30, 31 } }, { 0, { 28, 29, 32, 33 } },
+ { 0, { 34, 35, 38, 39 } }, { 0, { 36, 37, 40, 41 } },
+};
+
+static const RD_RECORD_IDX_NODE rd_record_tree_128x64[] = {
+ { 0, { 2, 3, 6, 7 } }, { 0, { 4, 5, 8, 9 } },
+ { 0, { 10, 11, 18, 19 } }, { 0, { 12, 13, 20, 21 } },
+ { 0, { 14, 15, 22, 23 } }, { 0, { 16, 17, 24, 25 } },
+ { 0, { 26, 27, 34, 35 } }, { 0, { 28, 29, 36, 37 } },
+ { 0, { 30, 31, 38, 39 } }, { 0, { 32, 33, 40, 41 } },
+};
+
+static const RD_RECORD_IDX_NODE rd_record_tree_128x128[] = {
+ { 0, { 4, 5, 8, 9 } }, { 0, { 6, 7, 10, 11 } },
+ { 0, { 12, 13, 16, 17 } }, { 0, { 14, 15, 18, 19 } },
+ { 0, { 20, 21, 28, 29 } }, { 0, { 22, 23, 30, 31 } },
+ { 0, { 24, 25, 32, 33 } }, { 0, { 26, 27, 34, 35 } },
+ { 0, { 36, 37, 44, 45 } }, { 0, { 38, 39, 46, 47 } },
+ { 0, { 40, 41, 48, 49 } }, { 0, { 42, 43, 50, 51 } },
+ { 0, { 52, 53, 60, 61 } }, { 0, { 54, 55, 62, 63 } },
+ { 0, { 56, 57, 64, 65 } }, { 0, { 58, 59, 66, 67 } },
+ { 0, { 68, 69, 76, 77 } }, { 0, { 70, 71, 78, 79 } },
+ { 0, { 72, 73, 80, 81 } }, { 0, { 74, 75, 82, 83 } },
+};
+
+static const RD_RECORD_IDX_NODE rd_record_tree_1_4[] = {
+ { 0, { 1, -1, 2, -1 } },
+ { 0, { 3, 4, -1, -1 } },
+ { 0, { 5, 6, -1, -1 } },
+};
+
+static const RD_RECORD_IDX_NODE rd_record_tree_4_1[] = {
+ { 0, { 1, 2, -1, -1 } },
+ { 0, { 3, 4, -1, -1 } },
+ { 0, { 5, 6, -1, -1 } },
+};
+
+static const RD_RECORD_IDX_NODE *rd_record_tree[BLOCK_SIZES_ALL] = {
+ NULL, // BLOCK_4X4
+ NULL, // BLOCK_4X8
+ NULL, // BLOCK_8X4
+ rd_record_tree_8x8, // BLOCK_8X8
+ rd_record_tree_8x16, // BLOCK_8X16
+ rd_record_tree_16x8, // BLOCK_16X8
+ rd_record_tree_16x16, // BLOCK_16X16
+ rd_record_tree_1_2, // BLOCK_16X32
+ rd_record_tree_2_1, // BLOCK_32X16
+ rd_record_tree_sqr, // BLOCK_32X32
+ rd_record_tree_1_2, // BLOCK_32X64
+ rd_record_tree_2_1, // BLOCK_64X32
+ rd_record_tree_sqr, // BLOCK_64X64
+ rd_record_tree_64x128, // BLOCK_64X128
+ rd_record_tree_128x64, // BLOCK_128X64
+ rd_record_tree_128x128, // BLOCK_128X128
+ NULL, // BLOCK_4X16
+ NULL, // BLOCK_16X4
+ rd_record_tree_1_4, // BLOCK_8X32
+ rd_record_tree_4_1, // BLOCK_32X8
+ rd_record_tree_1_4, // BLOCK_16X64
+ rd_record_tree_4_1, // BLOCK_64X16
+};
+
+static const int rd_record_tree_size[BLOCK_SIZES_ALL] = {
+ 0, // BLOCK_4X4
+ 0, // BLOCK_4X8
+ 0, // BLOCK_8X4
+ sizeof(rd_record_tree_8x8) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_8X8
+ sizeof(rd_record_tree_8x16) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_8X16
+ sizeof(rd_record_tree_16x8) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_16X8
+ sizeof(rd_record_tree_16x16) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_16X16
+ sizeof(rd_record_tree_1_2) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_16X32
+ sizeof(rd_record_tree_2_1) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_32X16
+ sizeof(rd_record_tree_sqr) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_32X32
+ sizeof(rd_record_tree_1_2) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_32X64
+ sizeof(rd_record_tree_2_1) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_64X32
+ sizeof(rd_record_tree_sqr) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_64X64
+ sizeof(rd_record_tree_64x128) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_64X128
+ sizeof(rd_record_tree_128x64) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_128X64
+ sizeof(rd_record_tree_128x128) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_128X128
+ 0, // BLOCK_4X16
+ 0, // BLOCK_16X4
+ sizeof(rd_record_tree_1_4) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_8X32
+ sizeof(rd_record_tree_4_1) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_32X8
+ sizeof(rd_record_tree_1_4) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_16X64
+ sizeof(rd_record_tree_4_1) / sizeof(RD_RECORD_IDX_NODE), // BLOCK_64X16
+};
+
+static INLINE void init_rd_record_tree(TXB_RD_INFO_NODE *tree,
+ BLOCK_SIZE bsize) {
+ const RD_RECORD_IDX_NODE *rd_record = rd_record_tree[bsize];
+ const int size = rd_record_tree_size[bsize];
+ for (int i = 0; i < size; ++i) {
+ if (rd_record[i].leaf) {
+ av1_zero(tree[i].children);
+ } else {
+ for (int j = 0; j < 4; ++j) {
+ const int8_t idx = rd_record[i].children[j];
+ tree[i].children[j] = idx > 0 ? &tree[idx] : NULL;
+ }
+ }
+ }
+}
+
+// Go through all TX blocks that could be used in TX size search, compute
+// residual hash values for them and find matching RD info that stores previous
+// RD search results for these TX blocks. The idea is to prevent repeated
+// rate/distortion computations that happen because of the combination of
+// partition and TX size search. The resulting RD info records are returned in
+// the form of a quadtree for easier access in actual TX size search.
+static int find_tx_size_rd_records(MACROBLOCK *x, BLOCK_SIZE bsize,
+ TXB_RD_INFO_NODE *dst_rd_info) {
+ TXB_RD_RECORD *rd_records_table[4] = { x->txb_rd_record_8X8,
+ x->txb_rd_record_16X16,
+ x->txb_rd_record_32X32,
+ x->txb_rd_record_64X64 };
+ const TX_SIZE max_square_tx_size = max_txsize_lookup[bsize];
+ const int bw = block_size_wide[bsize];
+ const int bh = block_size_high[bsize];
+
+ // Hashing is performed only for square TX sizes larger than TX_4X4
+ if (max_square_tx_size < TX_8X8) return 0;
+ const int diff_stride = bw;
+ const struct macroblock_plane *const p = &x->plane[0];
+ const int16_t *diff = &p->src_diff[0];
+ init_rd_record_tree(dst_rd_info, bsize);
+ // Coordinates of the top-left corner of current block within the superblock
+ // measured in pixels:
+ const int mi_row = x->e_mbd.mi_row;
+ const int mi_col = x->e_mbd.mi_col;
+ const int mi_row_in_sb = (mi_row % MAX_MIB_SIZE) << MI_SIZE_LOG2;
+ const int mi_col_in_sb = (mi_col % MAX_MIB_SIZE) << MI_SIZE_LOG2;
+ int cur_rd_info_idx = 0;
+ int cur_tx_depth = 0;
+ TX_SIZE cur_tx_size = max_txsize_rect_lookup[bsize];
+ while (cur_tx_depth <= MAX_VARTX_DEPTH) {
+ const int cur_tx_bw = tx_size_wide[cur_tx_size];
+ const int cur_tx_bh = tx_size_high[cur_tx_size];
+ if (cur_tx_bw < 8 || cur_tx_bh < 8) break;
+ const TX_SIZE next_tx_size = sub_tx_size_map[cur_tx_size];
+ const int tx_size_idx = cur_tx_size - TX_8X8;
+ for (int row = 0; row < bh; row += cur_tx_bh) {
+ for (int col = 0; col < bw; col += cur_tx_bw) {
+ if (cur_tx_bw != cur_tx_bh) {
+ // Use dummy nodes for all rectangular transforms within the
+ // TX size search tree.
+ dst_rd_info[cur_rd_info_idx].rd_info_array = NULL;
+ } else {
+ // Get spatial location of this TX block within the superblock
+ // (measured in cur_tx_bsize units).
+ const int row_in_sb = (mi_row_in_sb + row) / cur_tx_bh;
+ const int col_in_sb = (mi_col_in_sb + col) / cur_tx_bw;
+
+ int16_t hash_data[MAX_SB_SQUARE];
+ int16_t *cur_hash_row = hash_data;
+ const int16_t *cur_diff_row = diff + row * diff_stride + col;
+ for (int i = 0; i < cur_tx_bh; i++) {
+ memcpy(cur_hash_row, cur_diff_row, sizeof(*hash_data) * cur_tx_bw);
+ cur_hash_row += cur_tx_bw;
+ cur_diff_row += diff_stride;
+ }
+ const int hash = av1_get_crc32c_value(&x->mb_rd_record.crc_calculator,
+ (uint8_t *)hash_data,
+ 2 * cur_tx_bw * cur_tx_bh);
+ // Find corresponding RD info based on the hash value.
+ const int record_idx =
+ row_in_sb * (MAX_MIB_SIZE >> (tx_size_idx + 1)) + col_in_sb;
+ TXB_RD_RECORD *records = &rd_records_table[tx_size_idx][record_idx];
+ int idx = find_tx_size_rd_info(records, hash);
+ dst_rd_info[cur_rd_info_idx].rd_info_array =
+ &records->tx_rd_info[idx];
+ }
+ ++cur_rd_info_idx;
+ }
+ }
+ cur_tx_size = next_tx_size;
+ ++cur_tx_depth;
+ }
+ return 1;
+}
+
+static INLINE uint32_t get_block_residue_hash(MACROBLOCK *x, BLOCK_SIZE bsize) {
+ const int rows = block_size_high[bsize];
+ const int cols = block_size_wide[bsize];
+ const int16_t *diff = x->plane[0].src_diff;
+ const uint32_t hash = av1_get_crc32c_value(&x->mb_rd_record.crc_calculator,
+ (uint8_t *)diff, 2 * rows * cols);
+ return (hash << 5) + bsize;
+}
+
+static INLINE int32_t find_mb_rd_info(const MB_RD_RECORD *const mb_rd_record,
+ const int64_t ref_best_rd,
+ const uint32_t hash) {
+ int32_t match_index = -1;
+ if (ref_best_rd != INT64_MAX) {
+ for (int i = 0; i < mb_rd_record->num; ++i) {
+ const int index = (mb_rd_record->index_start + i) % RD_RECORD_BUFFER_LEN;
+ // If there is a match in the tx_rd_record, fetch the RD decision and
+ // terminate early.
+ if (mb_rd_record->tx_rd_info[index].hash_value == hash) {
+ match_index = index;
+ break;
+ }
+ }
+ }
+ return match_index;
+}
+
+static AOM_INLINE void fetch_tx_rd_info(int n4,
+ const MB_RD_INFO *const tx_rd_info,
+ RD_STATS *const rd_stats,
+ MACROBLOCK *const x) {
+ MACROBLOCKD *const xd = &x->e_mbd;
+ MB_MODE_INFO *const mbmi = xd->mi[0];
+ mbmi->tx_size = tx_rd_info->tx_size;
+ memcpy(x->blk_skip, tx_rd_info->blk_skip,
+ sizeof(tx_rd_info->blk_skip[0]) * n4);
+ av1_copy(mbmi->inter_tx_size, tx_rd_info->inter_tx_size);
+ av1_copy_array(xd->tx_type_map, tx_rd_info->tx_type_map, n4);
+ *rd_stats = tx_rd_info->rd_stats;
+}
+
+// Compute the pixel domain distortion from diff on all visible 4x4s in the
+// transform block.
+static INLINE int64_t pixel_diff_dist(const MACROBLOCK *x, int plane,
+ int blk_row, int blk_col,
+ const BLOCK_SIZE plane_bsize,
+ const BLOCK_SIZE tx_bsize,
+ unsigned int *block_mse_q8) {
+ int visible_rows, visible_cols;
+ const MACROBLOCKD *xd = &x->e_mbd;
+ get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize, NULL,
+ NULL, &visible_cols, &visible_rows);
+ const int diff_stride = block_size_wide[plane_bsize];
+ const int16_t *diff = x->plane[plane].src_diff;
+#if CONFIG_DIST_8X8
+ int txb_height = block_size_high[tx_bsize];
+ int txb_width = block_size_wide[tx_bsize];
+ if (x->using_dist_8x8 && plane == 0) {
+ const int src_stride = x->plane[plane].src.stride;
+ const int src_idx = (blk_row * src_stride + blk_col) << MI_SIZE_LOG2;
+ const int diff_idx = (blk_row * diff_stride + blk_col) << MI_SIZE_LOG2;
+ const uint8_t *src = &x->plane[plane].src.buf[src_idx];
+ return dist_8x8_diff(x, src, src_stride, diff + diff_idx, diff_stride,
+ txb_width, txb_height, visible_cols, visible_rows,
+ x->qindex);
+ }
+#endif
+ diff += ((blk_row * diff_stride + blk_col) << MI_SIZE_LOG2);
+ uint64_t sse =
+ aom_sum_squares_2d_i16(diff, diff_stride, visible_cols, visible_rows);
+ if (block_mse_q8 != NULL) {
+ if (visible_cols > 0 && visible_rows > 0)
+ *block_mse_q8 =
+ (unsigned int)((256 * sse) / (visible_cols * visible_rows));
+ else
+ *block_mse_q8 = UINT_MAX;
+ }
+ return sse;
+}
+
+// Uses simple features on top of DCT coefficients to quickly predict
+// whether optimal RD decision is to skip encoding the residual.
+// The sse value is stored in dist.
+static int predict_skip_flag(MACROBLOCK *x, BLOCK_SIZE bsize, int64_t *dist,
+ int reduced_tx_set) {
+ const int bw = block_size_wide[bsize];
+ const int bh = block_size_high[bsize];
+ const MACROBLOCKD *xd = &x->e_mbd;
+ const int16_t dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd);
+
+ *dist = pixel_diff_dist(x, 0, 0, 0, bsize, bsize, NULL);
+
+ const int64_t mse = *dist / bw / bh;
+ // Normalized quantizer takes the transform upscaling factor (8 for tx size
+ // smaller than 32) into account.
+ const int16_t normalized_dc_q = dc_q >> 3;
+ const int64_t mse_thresh = (int64_t)normalized_dc_q * normalized_dc_q / 8;
+ // For faster early skip decision, use dist to compare against threshold so
+ // that quality risk is less for the skip=1 decision. Otherwise, use mse
+ // since the fwd_txfm coeff checks will take care of quality
+ // TODO(any): Use dist to return 0 when predict_skip_level is 1
+ int64_t pred_err = (x->predict_skip_level >= 2) ? *dist : mse;
+ // Predict not to skip when error is larger than threshold.
+ if (pred_err > mse_thresh) return 0;
+ // Return as skip otherwise for aggressive early skip
+ else if (x->predict_skip_level >= 2)
+ return 1;
+
+ const int max_tx_size = max_predict_sf_tx_size[bsize];
+ const int tx_h = tx_size_high[max_tx_size];
+ const int tx_w = tx_size_wide[max_tx_size];
+ DECLARE_ALIGNED(32, tran_low_t, coefs[32 * 32]);
+ TxfmParam param;
+ param.tx_type = DCT_DCT;
+ param.tx_size = max_tx_size;
+ param.bd = xd->bd;
+ param.is_hbd = is_cur_buf_hbd(xd);
+ param.lossless = 0;
+ param.tx_set_type = av1_get_ext_tx_set_type(
+ param.tx_size, is_inter_block(xd->mi[0]), reduced_tx_set);
+ const int bd_idx = (xd->bd == 8) ? 0 : ((xd->bd == 10) ? 1 : 2);
+ const uint32_t max_qcoef_thresh = skip_pred_threshold[bd_idx][bsize];
+ const int16_t *src_diff = x->plane[0].src_diff;
+ const int n_coeff = tx_w * tx_h;
+ const int16_t ac_q = av1_ac_quant_QTX(x->qindex, 0, xd->bd);
+ const uint32_t dc_thresh = max_qcoef_thresh * dc_q;
+ const uint32_t ac_thresh = max_qcoef_thresh * ac_q;
+ for (int row = 0; row < bh; row += tx_h) {
+ for (int col = 0; col < bw; col += tx_w) {
+ av1_fwd_txfm(src_diff + col, coefs, bw, ¶m);
+ // Operating on TX domain, not pixels; we want the QTX quantizers
+ const uint32_t dc_coef = (((uint32_t)abs(coefs[0])) << 7);
+ if (dc_coef >= dc_thresh) return 0;
+ for (int i = 1; i < n_coeff; ++i) {
+ const uint32_t ac_coef = (((uint32_t)abs(coefs[i])) << 7);
+ if (ac_coef >= ac_thresh) return 0;
+ }
+ }
+ src_diff += tx_h * bw;
+ }
+ return 1;
+}
+
+// Used to set proper context for early termination with skip = 1.
+static AOM_INLINE void set_skip_flag(MACROBLOCK *x, RD_STATS *rd_stats,
+ int bsize, int64_t dist) {
+ MACROBLOCKD *const xd = &x->e_mbd;
+ MB_MODE_INFO *const mbmi = xd->mi[0];
+ const int n4 = bsize_to_num_blk(bsize);
+ const TX_SIZE tx_size = max_txsize_rect_lookup[bsize];
+ memset(xd->tx_type_map, DCT_DCT, sizeof(xd->tx_type_map[0]) * n4);
+ memset(mbmi->inter_tx_size, tx_size, sizeof(mbmi->inter_tx_size));
+ mbmi->tx_size = tx_size;
+ for (int i = 0; i < n4; ++i) set_blk_skip(x, 0, i, 1);
+ rd_stats->skip = 1;
+ if (is_cur_buf_hbd(xd)) dist = ROUND_POWER_OF_TWO(dist, (xd->bd - 8) * 2);
+ rd_stats->dist = rd_stats->sse = (dist << 4);
+ // Though decision is to make the block as skip based on luma stats,
+ // it is possible that block becomes non skip after chroma rd. In addition
+ // intermediate non skip costs calculated by caller function will be
+ // incorrect, if rate is set as zero (i.e., if zero_blk_rate is not
+ // accounted). Hence intermediate rate is populated to code the luma tx blks
+ // as skip, the caller function based on final rd decision (i.e., skip vs
+ // non-skip) sets the final rate accordingly. Here the rate populated
+ // corresponds to coding all the tx blocks with zero_blk_rate (based on max tx
+ // size possible) in the current block. Eg: For 128*128 block, rate would be
+ // 4 * zero_blk_rate where zero_blk_rate corresponds to coding of one 64x64 tx
+ // block as 'all zeros'
+ ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
+ ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
+ av1_get_entropy_contexts(bsize, &xd->plane[0], ctxa, ctxl);
+ ENTROPY_CONTEXT *ta = ctxa;
+ ENTROPY_CONTEXT *tl = ctxl;
+ const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
+ TXB_CTX txb_ctx;
+ get_txb_ctx(bsize, tx_size, 0, ta, tl, &txb_ctx);
+ const int zero_blk_rate = x->coeff_costs[txs_ctx][PLANE_TYPE_Y]
+ .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
+ rd_stats->rate = zero_blk_rate *
+ (block_size_wide[bsize] >> tx_size_wide_log2[tx_size]) *
+ (block_size_high[bsize] >> tx_size_high_log2[tx_size]);
+}
+
+static AOM_INLINE void save_tx_rd_info(int n4, uint32_t hash,
+ const MACROBLOCK *const x,
+ const RD_STATS *const rd_stats,
+ MB_RD_RECORD *tx_rd_record) {
+ int index;
+ if (tx_rd_record->num < RD_RECORD_BUFFER_LEN) {
+ index =
+ (tx_rd_record->index_start + tx_rd_record->num) % RD_RECORD_BUFFER_LEN;
+ ++tx_rd_record->num;
+ } else {
+ index = tx_rd_record->index_start;
+ tx_rd_record->index_start =
+ (tx_rd_record->index_start + 1) % RD_RECORD_BUFFER_LEN;
+ }
+ MB_RD_INFO *const tx_rd_info = &tx_rd_record->tx_rd_info[index];
+ const MACROBLOCKD *const xd = &x->e_mbd;
+ const MB_MODE_INFO *const mbmi = xd->mi[0];
+ tx_rd_info->hash_value = hash;
+ tx_rd_info->tx_size = mbmi->tx_size;
+ memcpy(tx_rd_info->blk_skip, x->blk_skip,
+ sizeof(tx_rd_info->blk_skip[0]) * n4);
+ av1_copy(tx_rd_info->inter_tx_size, mbmi->inter_tx_size);
+ av1_copy_array(tx_rd_info->tx_type_map, xd->tx_type_map, n4);
+ tx_rd_info->rd_stats = *rd_stats;
+}
+
+static int get_search_init_depth(int mi_width, int mi_height, int is_inter,
+ const SPEED_FEATURES *sf,
+ int tx_size_search_method) {
+ if (tx_size_search_method == USE_LARGESTALL) return MAX_VARTX_DEPTH;
+
+ if (sf->tx_sf.tx_size_search_lgr_block) {
+ if (mi_width > mi_size_wide[BLOCK_64X64] ||
+ mi_height > mi_size_high[BLOCK_64X64])
+ return MAX_VARTX_DEPTH;
+ }
+
+ if (is_inter) {
+ return (mi_height != mi_width)
+ ? sf->tx_sf.inter_tx_size_search_init_depth_rect
+ : sf->tx_sf.inter_tx_size_search_init_depth_sqr;
+ } else {
+ return (mi_height != mi_width)
+ ? sf->tx_sf.intra_tx_size_search_init_depth_rect
+ : sf->tx_sf.intra_tx_size_search_init_depth_sqr;
+ }
+}
+
+static AOM_INLINE void select_tx_block(
+ const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
+ TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
+ ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
+ RD_STATS *rd_stats, int64_t prev_level_rd, int64_t ref_best_rd,
+ int *is_cost_valid, FAST_TX_SEARCH_MODE ftxs_mode,
+ TXB_RD_INFO_NODE *rd_info_node);
+
+// NOTE: CONFIG_COLLECT_RD_STATS has 3 possible values
+// 0: Do not collect any RD stats
+// 1: Collect RD stats for transform units
+// 2: Collect RD stats for partition units
+#if CONFIG_COLLECT_RD_STATS
+
+static AOM_INLINE void get_energy_distribution_fine(
+ const AV1_COMP *cpi, BLOCK_SIZE bsize, const uint8_t *src, int src_stride,
+ const uint8_t *dst, int dst_stride, int need_4th, double *hordist,
+ double *verdist) {
+ const int bw = block_size_wide[bsize];
+ const int bh = block_size_high[bsize];
+ unsigned int esq[16] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
+
+ if (bsize < BLOCK_16X16 || (bsize >= BLOCK_4X16 && bsize <= BLOCK_32X8)) {
+ // Special cases: calculate 'esq' values manually, as we don't have 'vf'
+ // functions for the 16 (very small) sub-blocks of this block.
+ const int w_shift = (bw == 4) ? 0 : (bw == 8) ? 1 : (bw == 16) ? 2 : 3;
+ const int h_shift = (bh == 4) ? 0 : (bh == 8) ? 1 : (bh == 16) ? 2 : 3;
+ assert(bw <= 32);
+ assert(bh <= 32);
+ assert(((bw - 1) >> w_shift) + (((bh - 1) >> h_shift) << 2) == 15);
+ if (cpi->common.seq_params.use_highbitdepth) {
+ const uint16_t *src16 = CONVERT_TO_SHORTPTR(src);
+ const uint16_t *dst16 = CONVERT_TO_SHORTPTR(dst);
+ for (int i = 0; i < bh; ++i)
+ for (int j = 0; j < bw; ++j) {
+ const int index = (j >> w_shift) + ((i >> h_shift) << 2);
+ esq[index] +=
+ (src16[j + i * src_stride] - dst16[j + i * dst_stride]) *
+ (src16[j + i * src_stride] - dst16[j + i * dst_stride]);
+ }
+ } else {
+ for (int i = 0; i < bh; ++i)
+ for (int j = 0; j < bw; ++j) {
+ const int index = (j >> w_shift) + ((i >> h_shift) << 2);
+ esq[index] += (src[j + i * src_stride] - dst[j + i * dst_stride]) *
+ (src[j + i * src_stride] - dst[j + i * dst_stride]);
+ }
+ }
+ } else { // Calculate 'esq' values using 'vf' functions on the 16 sub-blocks.
+ const int f_index =
+ (bsize < BLOCK_SIZES) ? bsize - BLOCK_16X16 : bsize - BLOCK_8X16;
+ assert(f_index >= 0 && f_index < BLOCK_SIZES_ALL);
+ const BLOCK_SIZE subsize = (BLOCK_SIZE)f_index;
+ assert(block_size_wide[bsize] == 4 * block_size_wide[subsize]);
+ assert(block_size_high[bsize] == 4 * block_size_high[subsize]);
+ cpi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[0]);
+ cpi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride,
+ &esq[1]);
+ cpi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride,
+ &esq[2]);
+ cpi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
+ dst_stride, &esq[3]);
+ src += bh / 4 * src_stride;
+ dst += bh / 4 * dst_stride;
+
+ cpi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[4]);
+ cpi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride,
+ &esq[5]);
+ cpi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride,
+ &esq[6]);
+ cpi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
+ dst_stride, &esq[7]);
+ src += bh / 4 * src_stride;
+ dst += bh / 4 * dst_stride;
+
+ cpi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[8]);
+ cpi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride,
+ &esq[9]);
+ cpi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride,
+ &esq[10]);
+ cpi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
+ dst_stride, &esq[11]);
+ src += bh / 4 * src_stride;
+ dst += bh / 4 * dst_stride;
+
+ cpi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[12]);
+ cpi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride,
+ &esq[13]);
+ cpi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride,
+ &esq[14]);
+ cpi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
+ dst_stride, &esq[15]);
+ }
+
+ double total = (double)esq[0] + esq[1] + esq[2] + esq[3] + esq[4] + esq[5] +
+ esq[6] + esq[7] + esq[8] + esq[9] + esq[10] + esq[11] +
+ esq[12] + esq[13] + esq[14] + esq[15];
+ if (total > 0) {
+ const double e_recip = 1.0 / total;
+ hordist[0] = ((double)esq[0] + esq[4] + esq[8] + esq[12]) * e_recip;
+ hordist[1] = ((double)esq[1] + esq[5] + esq[9] + esq[13]) * e_recip;
+ hordist[2] = ((double)esq[2] + esq[6] + esq[10] + esq[14]) * e_recip;
+ if (need_4th) {
+ hordist[3] = ((double)esq[3] + esq[7] + esq[11] + esq[15]) * e_recip;
+ }
+ verdist[0] = ((double)esq[0] + esq[1] + esq[2] + esq[3]) * e_recip;
+ verdist[1] = ((double)esq[4] + esq[5] + esq[6] + esq[7]) * e_recip;
+ verdist[2] = ((double)esq[8] + esq[9] + esq[10] + esq[11]) * e_recip;
+ if (need_4th) {
+ verdist[3] = ((double)esq[12] + esq[13] + esq[14] + esq[15]) * e_recip;
+ }
+ } else {
+ hordist[0] = verdist[0] = 0.25;
+ hordist[1] = verdist[1] = 0.25;
+ hordist[2] = verdist[2] = 0.25;
+ if (need_4th) {
+ hordist[3] = verdist[3] = 0.25;
+ }
+ }
+}
+
+static double get_sse_norm(const int16_t *diff, int stride, int w, int h) {
+ double sum = 0.0;
+ for (int j = 0; j < h; ++j) {
+ for (int i = 0; i < w; ++i) {
+ const int err = diff[j * stride + i];
+ sum += err * err;
+ }
+ }
+ assert(w > 0 && h > 0);
+ return sum / (w * h);
+}
+
+static double get_sad_norm(const int16_t *diff, int stride, int w, int h) {
+ double sum = 0.0;
+ for (int j = 0; j < h; ++j) {
+ for (int i = 0; i < w; ++i) {
+ sum += abs(diff[j * stride + i]);
+ }
+ }
+ assert(w > 0 && h > 0);
+ return sum / (w * h);
+}
+
+static AOM_INLINE void get_2x2_normalized_sses_and_sads(
+ const AV1_COMP *const cpi, BLOCK_SIZE tx_bsize, const uint8_t *const src,
+ int src_stride, const uint8_t *const dst, int dst_stride,
+ const int16_t *const src_diff, int diff_stride, double *const sse_norm_arr,
+ double *const sad_norm_arr) {
+ const BLOCK_SIZE tx_bsize_half =
+ get_partition_subsize(tx_bsize, PARTITION_SPLIT);
+ if (tx_bsize_half == BLOCK_INVALID) { // manually calculate stats
+ const int half_width = block_size_wide[tx_bsize] / 2;
+ const int half_height = block_size_high[tx_bsize] / 2;
+ for (int row = 0; row < 2; ++row) {
+ for (int col = 0; col < 2; ++col) {
+ const int16_t *const this_src_diff =
+ src_diff + row * half_height * diff_stride + col * half_width;
+ if (sse_norm_arr) {
+ sse_norm_arr[row * 2 + col] =
+ get_sse_norm(this_src_diff, diff_stride, half_width, half_height);
+ }
+ if (sad_norm_arr) {
+ sad_norm_arr[row * 2 + col] =
+ get_sad_norm(this_src_diff, diff_stride, half_width, half_height);
+ }
+ }
+ }
+ } else { // use function pointers to calculate stats
+ const int half_width = block_size_wide[tx_bsize_half];
+ const int half_height = block_size_high[tx_bsize_half];
+ const int num_samples_half = half_width * half_height;
+ for (int row = 0; row < 2; ++row) {
+ for (int col = 0; col < 2; ++col) {
+ const uint8_t *const this_src =
+ src + row * half_height * src_stride + col * half_width;
+ const uint8_t *const this_dst =
+ dst + row * half_height * dst_stride + col * half_width;
+
+ if (sse_norm_arr) {
+ unsigned int this_sse;
+ cpi->fn_ptr[tx_bsize_half].vf(this_src, src_stride, this_dst,
+ dst_stride, &this_sse);
+ sse_norm_arr[row * 2 + col] = (double)this_sse / num_samples_half;
+ }
+
+ if (sad_norm_arr) {
+ const unsigned int this_sad = cpi->fn_ptr[tx_bsize_half].sdf(
+ this_src, src_stride, this_dst, dst_stride);
+ sad_norm_arr[row * 2 + col] = (double)this_sad / num_samples_half;
+ }
+ }
+ }
+ }
+}
+
+#if CONFIG_COLLECT_RD_STATS == 1
+static double get_mean(const int16_t *diff, int stride, int w, int h) {
+ double sum = 0.0;
+ for (int j = 0; j < h; ++j) {
+ for (int i = 0; i < w; ++i) {
+ sum += diff[j * stride + i];
+ }
+ }
+ assert(w > 0 && h > 0);
+ return sum / (w * h);
+}
+static AOM_INLINE void PrintTransformUnitStats(
+ const AV1_COMP *const cpi, MACROBLOCK *x, const RD_STATS *const rd_stats,
+ int blk_row, int blk_col, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
+ TX_TYPE tx_type, int64_t rd) {
+ if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return;
+
+ // Generate small sample to restrict output size.
+ static unsigned int seed = 21743;
+ if (lcg_rand16(&seed) % 256 > 0) return;
+
+ const char output_file[] = "tu_stats.txt";
+ FILE *fout = fopen(output_file, "a");
+ if (!fout) return;
+
+ const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
+ const MACROBLOCKD *const xd = &x->e_mbd;
+ const int plane = 0;
+ struct macroblock_plane *const p = &x->plane[plane];
+ const struct macroblockd_plane *const pd = &xd->plane[plane];
+ const int txw = tx_size_wide[tx_size];
+ const int txh = tx_size_high[tx_size];
+ const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
+ const int q_step = p->dequant_QTX[1] >> dequant_shift;
+ const int num_samples = txw * txh;
+
+ const double rate_norm = (double)rd_stats->rate / num_samples;
+ const double dist_norm = (double)rd_stats->dist / num_samples;
+
+ fprintf(fout, "%g %g", rate_norm, dist_norm);
+
+ const int src_stride = p->src.stride;
+ const uint8_t *const src =
+ &p->src.buf[(blk_row * src_stride + blk_col) << MI_SIZE_LOG2];
+ const int dst_stride = pd->dst.stride;
+ const uint8_t *const dst =
+ &pd->dst.buf[(blk_row * dst_stride + blk_col) << MI_SIZE_LOG2];
+ unsigned int sse;
+ cpi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse);
+ const double sse_norm = (double)sse / num_samples;
+
+ const unsigned int sad =
+ cpi->fn_ptr[tx_bsize].sdf(src, src_stride, dst, dst_stride);
+ const double sad_norm = (double)sad / num_samples;
+
+ fprintf(fout, " %g %g", sse_norm, sad_norm);
+
+ const int diff_stride = block_size_wide[plane_bsize];
+ const int16_t *const src_diff =
+ &p->src_diff[(blk_row * diff_stride + blk_col) << MI_SIZE_LOG2];
+
+ double sse_norm_arr[4], sad_norm_arr[4];
+ get_2x2_normalized_sses_and_sads(cpi, tx_bsize, src, src_stride, dst,
+ dst_stride, src_diff, diff_stride,
+ sse_norm_arr, sad_norm_arr);
+ for (int i = 0; i < 4; ++i) {
+ fprintf(fout, " %g", sse_norm_arr[i]);
+ }
+ for (int i = 0; i < 4; ++i) {
+ fprintf(fout, " %g", sad_norm_arr[i]);
+ }
+
+ const TX_TYPE_1D tx_type_1d_row = htx_tab[tx_type];
+ const TX_TYPE_1D tx_type_1d_col = vtx_tab[tx_type];
+
+ fprintf(fout, " %d %d %d %d %d", q_step, tx_size_wide[tx_size],
+ tx_size_high[tx_size], tx_type_1d_row, tx_type_1d_col);
+
+ int model_rate;
+ int64_t model_dist;
+ model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, tx_bsize, plane, sse, num_samples,
+ &model_rate, &model_dist);
+ const double model_rate_norm = (double)model_rate / num_samples;
+ const double model_dist_norm = (double)model_dist / num_samples;
+ fprintf(fout, " %g %g", model_rate_norm, model_dist_norm);
+
+ const double mean = get_mean(src_diff, diff_stride, txw, txh);
+ float hor_corr, vert_corr;
+ av1_get_horver_correlation_full(src_diff, diff_stride, txw, txh, &hor_corr,
+ &vert_corr);
+ fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);
+
+ double hdist[4] = { 0 }, vdist[4] = { 0 };
+ get_energy_distribution_fine(cpi, tx_bsize, src, src_stride, dst, dst_stride,
+ 1, hdist, vdist);
+ fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
+ hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);
+
+ fprintf(fout, " %d %" PRId64, x->rdmult, rd);
+
+ fprintf(fout, "\n");
+ fclose(fout);
+}
+#endif // CONFIG_COLLECT_RD_STATS == 1
+
+#if CONFIG_COLLECT_RD_STATS >= 2
+static int64_t get_sse(const uint8_t *a, int a_stride, const uint8_t *b,
+ int b_stride, int width, int height) {
+ const int dw = width % 16;
+ const int dh = height % 16;
+ int64_t total_sse = 0;
+ unsigned int sse = 0;
+ int sum = 0;
+ int x, y;
+
+ if (dw > 0) {
+ encoder_variance(&a[width - dw], a_stride, &b[width - dw], b_stride, dw,
+ height, &sse, &sum);
+ total_sse += sse;
+ }
+
+ if (dh > 0) {
+ encoder_variance(&a[(height - dh) * a_stride], a_stride,
+ &b[(height - dh) * b_stride], b_stride, width - dw, dh,
+ &sse, &sum);
+ total_sse += sse;
+ }
+
+ for (y = 0; y < height / 16; ++y) {
+ const uint8_t *pa = a;
+ const uint8_t *pb = b;
+ for (x = 0; x < width / 16; ++x) {
+ aom_mse16x16(pa, a_stride, pb, b_stride, &sse);
+ total_sse += sse;
+
+ pa += 16;
+ pb += 16;
+ }
+
+ a += 16 * a_stride;
+ b += 16 * b_stride;
+ }
+
+ return total_sse;
+}
+
+static int get_est_rate_dist(const TileDataEnc *tile_data, BLOCK_SIZE bsize,
+ int64_t sse, int *est_residue_cost,
+ int64_t *est_dist) {
+ aom_clear_system_state();
+ const InterModeRdModel *md = &tile_data->inter_mode_rd_models[bsize];
+ if (md->ready) {
+ if (sse < md->dist_mean) {
+ *est_residue_cost = 0;
+ *est_dist = sse;
+ } else {
+ *est_dist = (int64_t)round(md->dist_mean);
+ const double est_ld = md->a * sse + md->b;
+ // Clamp estimated rate cost by INT_MAX / 2.
+ // TODO(angiebird@google.com): find better solution than clamping.
+ if (fabs(est_ld) < 1e-2) {
+ *est_residue_cost = INT_MAX / 2;
+ } else {
+ double est_residue_cost_dbl = ((sse - md->dist_mean) / est_ld);
+ if (est_residue_cost_dbl < 0) {
+ *est_residue_cost = 0;
+ } else {
+ *est_residue_cost =
+ (int)AOMMIN((int64_t)round(est_residue_cost_dbl), INT_MAX / 2);
+ }
+ }
+ if (*est_residue_cost <= 0) {
+ *est_residue_cost = 0;
+ *est_dist = sse;
+ }
+ }
+ return 1;
+ }
+ return 0;
+}
+
+static double get_highbd_diff_mean(const uint8_t *src8, int src_stride,
+ const uint8_t *dst8, int dst_stride, int w,
+ int h) {
+ const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
+ const uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
+ double sum = 0.0;
+ for (int j = 0; j < h; ++j) {
+ for (int i = 0; i < w; ++i) {
+ const int diff = src[j * src_stride + i] - dst[j * dst_stride + i];
+ sum += diff;
+ }
+ }
+ assert(w > 0 && h > 0);
+ return sum / (w * h);
+}
+
+static double get_diff_mean(const uint8_t *src, int src_stride,
+ const uint8_t *dst, int dst_stride, int w, int h) {
+ double sum = 0.0;
+ for (int j = 0; j < h; ++j) {
+ for (int i = 0; i < w; ++i) {
+ const int diff = src[j * src_stride + i] - dst[j * dst_stride + i];
+ sum += diff;
+ }
+ }
+ assert(w > 0 && h > 0);
+ return sum / (w * h);
+}
+
+static AOM_INLINE void PrintPredictionUnitStats(const AV1_COMP *const cpi,
+ const TileDataEnc *tile_data,
+ MACROBLOCK *x,
+ const RD_STATS *const rd_stats,
+ BLOCK_SIZE plane_bsize) {
+ if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return;
+
+ if (cpi->sf.inter_sf.inter_mode_rd_model_estimation == 1 &&
+ (tile_data == NULL ||
+ !tile_data->inter_mode_rd_models[plane_bsize].ready))
+ return;
+ (void)tile_data;
+ // Generate small sample to restrict output size.
+ static unsigned int seed = 95014;
+
+ if ((lcg_rand16(&seed) % (1 << (14 - num_pels_log2_lookup[plane_bsize]))) !=
+ 1)
+ return;
+
+ const char output_file[] = "pu_stats.txt";
+ FILE *fout = fopen(output_file, "a");
+ if (!fout) return;
+
+ MACROBLOCKD *const xd = &x->e_mbd;
+ const int plane = 0;
+ struct macroblock_plane *const p = &x->plane[plane];
+ struct macroblockd_plane *pd = &xd->plane[plane];
+ const int diff_stride = block_size_wide[plane_bsize];
+ int bw, bh;
+ get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL, &bw,
+ &bh);
+ const int num_samples = bw * bh;
+ const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
+ const int q_step = p->dequant_QTX[1] >> dequant_shift;
+ const int shift = (xd->bd - 8);
+
+ const double rate_norm = (double)rd_stats->rate / num_samples;
+ const double dist_norm = (double)rd_stats->dist / num_samples;
+ const double rdcost_norm =
+ (double)RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) / num_samples;
+
+ fprintf(fout, "%g %g %g", rate_norm, dist_norm, rdcost_norm);
+
+ const int src_stride = p->src.stride;
+ const uint8_t *const src = p->src.buf;
+ const int dst_stride = pd->dst.stride;
+ const uint8_t *const dst = pd->dst.buf;
+ const int16_t *const src_diff = p->src_diff;
+
+ int64_t sse = calculate_sse(xd, p, pd, bw, bh);
+ const double sse_norm = (double)sse / num_samples;
+
+ const unsigned int sad =
+ cpi->fn_ptr[plane_bsize].sdf(src, src_stride, dst, dst_stride);
+ const double sad_norm =
+ (double)sad / (1 << num_pels_log2_lookup[plane_bsize]);
+
+ fprintf(fout, " %g %g", sse_norm, sad_norm);
+
+ double sse_norm_arr[4], sad_norm_arr[4];
+ get_2x2_normalized_sses_and_sads(cpi, plane_bsize, src, src_stride, dst,
+ dst_stride, src_diff, diff_stride,
+ sse_norm_arr, sad_norm_arr);
+ if (shift) {
+ for (int k = 0; k < 4; ++k) sse_norm_arr[k] /= (1 << (2 * shift));
+ for (int k = 0; k < 4; ++k) sad_norm_arr[k] /= (1 << shift);
+ }
+ for (int i = 0; i < 4; ++i) {
+ fprintf(fout, " %g", sse_norm_arr[i]);
+ }
+ for (int i = 0; i < 4; ++i) {
+ fprintf(fout, " %g", sad_norm_arr[i]);
+ }
+
+ fprintf(fout, " %d %d %d %d", q_step, x->rdmult, bw, bh);
+
+ int model_rate;
+ int64_t model_dist;
+ model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, plane_bsize, plane, sse, num_samples,
+ &model_rate, &model_dist);
+ const double model_rdcost_norm =
+ (double)RDCOST(x->rdmult, model_rate, model_dist) / num_samples;
+ const double model_rate_norm = (double)model_rate / num_samples;
+ const double model_dist_norm = (double)model_dist / num_samples;
+ fprintf(fout, " %g %g %g", model_rate_norm, model_dist_norm,
+ model_rdcost_norm);
+
+ double mean;
+ if (is_cur_buf_hbd(xd)) {
+ mean = get_highbd_diff_mean(p->src.buf, p->src.stride, pd->dst.buf,
+ pd->dst.stride, bw, bh);
+ } else {
+ mean = get_diff_mean(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
+ bw, bh);
+ }
+ mean /= (1 << shift);
+ float hor_corr, vert_corr;
+ av1_get_horver_correlation_full(src_diff, diff_stride, bw, bh, &hor_corr,
+ &vert_corr);
+ fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);
+
+ double hdist[4] = { 0 }, vdist[4] = { 0 };
+ get_energy_distribution_fine(cpi, plane_bsize, src, src_stride, dst,
+ dst_stride, 1, hdist, vdist);
+ fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
+ hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);
+
+ if (cpi->sf.inter_sf.inter_mode_rd_model_estimation == 1) {
+ assert(tile_data->inter_mode_rd_models[plane_bsize].ready);
+ const int64_t overall_sse = get_sse(cpi, x);
+ int est_residue_cost = 0;
+ int64_t est_dist = 0;
+ get_est_rate_dist(tile_data, plane_bsize, overall_sse, &est_residue_cost,
+ &est_dist);
+ const double est_residue_cost_norm = (double)est_residue_cost / num_samples;
+ const double est_dist_norm = (double)est_dist / num_samples;
+ const double est_rdcost_norm =
+ (double)RDCOST(x->rdmult, est_residue_cost, est_dist) / num_samples;
+ fprintf(fout, " %g %g %g", est_residue_cost_norm, est_dist_norm,
+ est_rdcost_norm);
+ }
+
+ fprintf(fout, "\n");
+ fclose(fout);
+}
+#endif // CONFIG_COLLECT_RD_STATS >= 2
+#endif // CONFIG_COLLECT_RD_STATS
+
+static AOM_INLINE void inverse_transform_block_facade(MACROBLOCKD *xd,
+ int plane, int block,
+ int blk_row, int blk_col,
+ int eob,
+ int reduced_tx_set) {
+ if (!eob) return;
+
+ struct macroblockd_plane *const pd = &xd->plane[plane];
+ tran_low_t *dqcoeff = pd->dqcoeff + BLOCK_OFFSET(block);
+ const PLANE_TYPE plane_type = get_plane_type(plane);
+ const TX_SIZE tx_size = av1_get_tx_size(plane, xd);
+ const TX_TYPE tx_type = av1_get_tx_type(xd, plane_type, blk_row, blk_col,
+ tx_size, reduced_tx_set);
+ const int dst_stride = pd->dst.stride;
+ uint8_t *dst = &pd->dst.buf[(blk_row * dst_stride + blk_col) << MI_SIZE_LOG2];
+ av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, dst,
+ dst_stride, eob, reduced_tx_set);
+}
+
+static INLINE void recon_intra(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
+ int block, int blk_row, int blk_col,
+ BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
+ const TXB_CTX *const txb_ctx, int skip_trellis,
+ TX_TYPE best_tx_type, TX_TYPE last_tx_type,
+ int *rate_cost, uint16_t best_eob) {
+ const AV1_COMMON *cm = &cpi->common;
+ MACROBLOCKD *xd = &x->e_mbd;
+ MB_MODE_INFO *mbmi = xd->mi[0];
+ const int is_inter = is_inter_block(mbmi);
+ if (!is_inter && best_eob &&
+ (blk_row + tx_size_high_unit[tx_size] < mi_size_high[plane_bsize] ||
+ blk_col + tx_size_wide_unit[tx_size] < mi_size_wide[plane_bsize])) {
+ // intra mode needs decoded result such that the next transform block
+ // can use it for prediction.
+ // if the last search tx_type is the best tx_type, we don't need to
+ // do this again
+ if (best_tx_type != last_tx_type) {
+ TxfmParam txfm_param_intra;
+ QUANT_PARAM quant_param_intra;
+ av1_setup_xform(cm, x, tx_size, best_tx_type, &txfm_param_intra);
+ av1_setup_quant(cm, tx_size, !skip_trellis,
+ skip_trellis
+ ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B
+ : AV1_XFORM_QUANT_FP)
+ : AV1_XFORM_QUANT_FP,
+ &quant_param_intra);
+ av1_setup_qmatrix(cm, x, plane, tx_size, best_tx_type,
+ &quant_param_intra);
+ av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize,
+ &txfm_param_intra, &quant_param_intra);
+ if (quant_param_intra.use_optimize_b) {
+ av1_optimize_b(cpi, x, plane, block, tx_size, best_tx_type, txb_ctx,
+ cpi->sf.rd_sf.trellis_eob_fast, rate_cost);
+ }
+ }
+
+ inverse_transform_block_facade(xd, plane, block, blk_row, blk_col,
+ x->plane[plane].eobs[block],
+ cm->reduced_tx_set_used);
+
+ // This may happen because of hash collision. The eob stored in the hash
+ // table is non-zero, but the real eob is zero. We need to make sure tx_type
+ // is DCT_DCT in this case.
+ if (plane == 0 && x->plane[plane].eobs[block] == 0 &&
+ best_tx_type != DCT_DCT) {
+ update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
+ }
+ }
+}
+
+static unsigned pixel_dist_visible_only(
+ const AV1_COMP *const cpi, const MACROBLOCK *x, const uint8_t *src,
+ const int src_stride, const uint8_t *dst, const int dst_stride,
+ const BLOCK_SIZE tx_bsize, int txb_rows, int txb_cols, int visible_rows,
+ int visible_cols) {
+ unsigned sse;
+
+ if (txb_rows == visible_rows && txb_cols == visible_cols) {
+ cpi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse);
+ return sse;
+ }
+
+#if CONFIG_AV1_HIGHBITDEPTH
+ const MACROBLOCKD *xd = &x->e_mbd;
+ if (is_cur_buf_hbd(xd)) {
+ uint64_t sse64 = aom_highbd_sse_odd_size(src, src_stride, dst, dst_stride,
+ visible_cols, visible_rows);
+ return (unsigned int)ROUND_POWER_OF_TWO(sse64, (xd->bd - 8) * 2);
+ }
+#else
+ (void)x;
+#endif
+ sse = aom_sse_odd_size(src, src_stride, dst, dst_stride, visible_cols,
+ visible_rows);
+ return sse;
+}
+
+// Compute the pixel domain distortion from src and dst on all visible 4x4s in
+// the
+// transform block.
+static unsigned pixel_dist(const AV1_COMP *const cpi, const MACROBLOCK *x,
+ int plane, const uint8_t *src, const int src_stride,
+ const uint8_t *dst, const int dst_stride,
+ int blk_row, int blk_col,
+ const BLOCK_SIZE plane_bsize,
+ const BLOCK_SIZE tx_bsize) {
+ int txb_rows, txb_cols, visible_rows, visible_cols;
+ const MACROBLOCKD *xd = &x->e_mbd;
+
+ get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize,
+ &txb_cols, &txb_rows, &visible_cols, &visible_rows);
+ assert(visible_rows > 0);
+ assert(visible_cols > 0);
+
+#if CONFIG_DIST_8X8
+ if (x->using_dist_8x8 && plane == 0)
+ return (unsigned)av1_dist_8x8(cpi, x, src, src_stride, dst, dst_stride,
+ tx_bsize, txb_cols, txb_rows, visible_cols,
+ visible_rows, x->qindex);
+#endif // CONFIG_DIST_8X8
+
+ unsigned sse = pixel_dist_visible_only(cpi, x, src, src_stride, dst,
+ dst_stride, tx_bsize, txb_rows,
+ txb_cols, visible_rows, visible_cols);
+
+ return sse;
+}
+
+static INLINE int64_t dist_block_px_domain(const AV1_COMP *cpi, MACROBLOCK *x,
+ int plane, BLOCK_SIZE plane_bsize,
+ int block, int blk_row, int blk_col,
+ TX_SIZE tx_size) {
+ MACROBLOCKD *const xd = &x->e_mbd;
+ const struct macroblock_plane *const p = &x->plane[plane];
+ const struct macroblockd_plane *const pd = &xd->plane[plane];
+ const uint16_t eob = p->eobs[block];
+ const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
+ const int bsw = block_size_wide[tx_bsize];
+ const int bsh = block_size_high[tx_bsize];
+ const int src_stride = x->plane[plane].src.stride;
+ const int dst_stride = xd->plane[plane].dst.stride;
+ // Scale the transform block index to pixel unit.
+ const int src_idx = (blk_row * src_stride + blk_col) << MI_SIZE_LOG2;
+ const int dst_idx = (blk_row * dst_stride + blk_col) << MI_SIZE_LOG2;
+ const uint8_t *src = &x->plane[plane].src.buf[src_idx];
+ const uint8_t *dst = &xd->plane[plane].dst.buf[dst_idx];
+ const tran_low_t *dqcoeff = pd->dqcoeff + BLOCK_OFFSET(block);
+
+ assert(cpi != NULL);
+ assert(tx_size_wide_log2[0] == tx_size_high_log2[0]);
+
+ uint8_t *recon;
+ DECLARE_ALIGNED(16, uint16_t, recon16[MAX_TX_SQUARE]);
+
+#if CONFIG_AV1_HIGHBITDEPTH
+ if (is_cur_buf_hbd(xd)) {
+ recon = CONVERT_TO_BYTEPTR(recon16);
+ av1_highbd_convolve_2d_copy_sr(CONVERT_TO_SHORTPTR(dst), dst_stride,
+ CONVERT_TO_SHORTPTR(recon), MAX_TX_SIZE, bsw,
+ bsh, NULL, NULL, 0, 0, NULL, xd->bd);
+ } else {
+ recon = (uint8_t *)recon16;
+ av1_convolve_2d_copy_sr(dst, dst_stride, recon, MAX_TX_SIZE, bsw, bsh, NULL,
+ NULL, 0, 0, NULL);
+ }
+#else
+ recon = (uint8_t *)recon16;
+ av1_convolve_2d_copy_sr(dst, dst_stride, recon, MAX_TX_SIZE, bsw, bsh, NULL,
+ NULL, 0, 0, NULL);
+#endif
+
+ const PLANE_TYPE plane_type = get_plane_type(plane);
+ TX_TYPE tx_type = av1_get_tx_type(xd, plane_type, blk_row, blk_col, tx_size,
+ cpi->common.reduced_tx_set_used);
+ av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, recon,
+ MAX_TX_SIZE, eob,
+ cpi->common.reduced_tx_set_used);
+
+ return 16 * pixel_dist(cpi, x, plane, src, src_stride, recon, MAX_TX_SIZE,
+ blk_row, blk_col, plane_bsize, tx_bsize);
+}
+
+static uint32_t get_intra_txb_hash(MACROBLOCK *x, int plane, int blk_row,
+ int blk_col, BLOCK_SIZE plane_bsize,
+ TX_SIZE tx_size) {
+ int16_t tmp_data[64 * 64];
+ const int diff_stride = block_size_wide[plane_bsize];
+ const int16_t *diff = x->plane[plane].src_diff;
+ const int16_t *cur_diff_row = diff + 4 * blk_row * diff_stride + 4 * blk_col;
+ const int txb_w = tx_size_wide[tx_size];
+ const int txb_h = tx_size_high[tx_size];
+ uint8_t *hash_data = (uint8_t *)cur_diff_row;
+ if (txb_w != diff_stride) {
+ int16_t *cur_hash_row = tmp_data;
+ for (int i = 0; i < txb_h; i++) {
+ memcpy(cur_hash_row, cur_diff_row, sizeof(*diff) * txb_w);
+ cur_hash_row += txb_w;
+ cur_diff_row += diff_stride;
+ }
+ hash_data = (uint8_t *)tmp_data;
+ }
+ CRC32C *crc = &x->mb_rd_record.crc_calculator;
+ const uint32_t hash = av1_get_crc32c_value(crc, hash_data, 2 * txb_w * txb_h);
+ return (hash << 5) + tx_size;
+}
+
+// pruning thresholds for prune_txk_type and prune_txk_type_separ
+static const int prune_factors[5] = { 200, 200, 120, 80, 40 }; // scale 1000
+static const int mul_factors[5] = { 80, 80, 70, 50, 30 }; // scale 100
+
+static INLINE int is_intra_hash_match(
+ const AV1_COMP *cpi, MACROBLOCK *x, int plane, int blk_row, int blk_col,
+ BLOCK_SIZE plane_bsize, TX_SIZE tx_size, const TXB_CTX *const txb_ctx,
+ TXB_RD_INFO **intra_txb_rd_info, int within_border,
+ const int tx_type_map_idx, uint16_t *cur_joint_ctx) {
+ const AV1_COMMON *cm = &cpi->common;
+ MACROBLOCKD *xd = &x->e_mbd;
+ MB_MODE_INFO *mbmi = xd->mi[0];
+ const int is_inter = is_inter_block(mbmi);
+ if (within_border && cpi->sf.tx_sf.use_intra_txb_hash &&
+ frame_is_intra_only(cm) && !is_inter && plane == 0 &&
+ tx_size_wide[tx_size] == tx_size_high[tx_size]) {
+ const uint32_t intra_hash =
+ get_intra_txb_hash(x, plane, blk_row, blk_col, plane_bsize, tx_size);
+ const int intra_hash_idx =
+ find_tx_size_rd_info(&x->txb_rd_record_intra, intra_hash);
+ *intra_txb_rd_info = &x->txb_rd_record_intra.tx_rd_info[intra_hash_idx];
+ *cur_joint_ctx = (txb_ctx->dc_sign_ctx << 8) + txb_ctx->txb_skip_ctx;
+ if ((*intra_txb_rd_info)->entropy_context == *cur_joint_ctx &&
+ x->txb_rd_record_intra.tx_rd_info[intra_hash_idx].valid) {
+ xd->tx_type_map[tx_type_map_idx] = (*intra_txb_rd_info)->tx_type;
+ const TX_TYPE ref_tx_type =
+ av1_get_tx_type(xd, get_plane_type(plane), blk_row, blk_col, tx_size,
+ cpi->common.reduced_tx_set_used);
+ return (ref_tx_type == (*intra_txb_rd_info)->tx_type);
+ }
+ }
+ return 0;
+}
+
+// R-D costs are sorted in ascending order.
+static INLINE void sort_rd(int64_t rds[], int txk[], int len) {
+ int i, j, k;
+
+ for (i = 1; i <= len - 1; ++i) {
+ for (j = 0; j < i; ++j) {
+ if (rds[j] > rds[i]) {
+ int64_t temprd;
+ int tempi;
+
+ temprd = rds[i];
+ tempi = txk[i];
+
+ for (k = i; k > j; k--) {
+ rds[k] = rds[k - 1];
+ txk[k] = txk[k - 1];
+ }
+
+ rds[j] = temprd;
+ txk[j] = tempi;
+ break;
+ }
+ }
+ }
+}
+
+static INLINE void dist_block_tx_domain(MACROBLOCK *x, int plane, int block,
+ TX_SIZE tx_size, int64_t *out_dist,
+ int64_t *out_sse) {
+ MACROBLOCKD *const xd = &x->e_mbd;
+ const struct macroblock_plane *const p = &x->plane[plane];
+ const struct macroblockd_plane *const pd = &xd->plane[plane];
+ // Transform domain distortion computation is more efficient as it does
+ // not involve an inverse transform, but it is less accurate.
+ const int buffer_length = av1_get_max_eob(tx_size);
+ int64_t this_sse;
+ // TX-domain results need to shift down to Q2/D10 to match pixel
+ // domain distortion values which are in Q2^2
+ int shift = (MAX_TX_SCALE - av1_get_tx_scale(tx_size)) * 2;
+ const int block_offset = BLOCK_OFFSET(block);
+ tran_low_t *const coeff = p->coeff + block_offset;
+ tran_low_t *const dqcoeff = pd->dqcoeff + block_offset;
+#if CONFIG_AV1_HIGHBITDEPTH
+ if (is_cur_buf_hbd(xd))
+ *out_dist = av1_highbd_block_error(coeff, dqcoeff, buffer_length, &this_sse,
+ xd->bd);
+ else
+ *out_dist = av1_block_error(coeff, dqcoeff, buffer_length, &this_sse);
+#else
+ *out_dist = av1_block_error(coeff, dqcoeff, buffer_length, &this_sse);
+#endif
+ *out_dist = RIGHT_SIGNED_SHIFT(*out_dist, shift);
+ *out_sse = RIGHT_SIGNED_SHIFT(this_sse, shift);
+}
+
+uint16_t prune_txk_type_separ(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
+ int block, TX_SIZE tx_size, int blk_row,
+ int blk_col, BLOCK_SIZE plane_bsize, int *txk_map,
+ int16_t allowed_tx_mask, int prune_factor,
+ const TXB_CTX *const txb_ctx,
+ int reduced_tx_set_used, int64_t ref_best_rd,
+ int num_sel) {
+ const AV1_COMMON *cm = &cpi->common;
+
+ int idx;
+
+ int64_t rds_v[4];
+ int64_t rds_h[4];
+ int idx_v[4] = { 0, 1, 2, 3 };
+ int idx_h[4] = { 0, 1, 2, 3 };
+ int skip_v[4] = { 0 };
+ int skip_h[4] = { 0 };
+ const int idx_map[16] = {
+ DCT_DCT, DCT_ADST, DCT_FLIPADST, V_DCT,
+ ADST_DCT, ADST_ADST, ADST_FLIPADST, V_ADST,
+ FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
+ H_DCT, H_ADST, H_FLIPADST, IDTX
+ };
+
+ const int sel_pattern_v[16] = {
+ 0, 0, 1, 1, 0, 2, 1, 2, 2, 0, 3, 1, 3, 2, 3, 3
+ };
+ const int sel_pattern_h[16] = {
+ 0, 1, 0, 1, 2, 0, 2, 1, 2, 3, 0, 3, 1, 3, 2, 3
+ };
+
+ QUANT_PARAM quant_param;
+ TxfmParam txfm_param;
+ av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
+ av1_setup_quant(cm, tx_size, 1, AV1_XFORM_QUANT_B, &quant_param);
+ int tx_type;
+ // to ensure we can try ones even outside of ext_tx_set of current block
+ // this function should only be called for size < 16
+ assert(txsize_sqr_up_map[tx_size] <= TX_16X16);
+ txfm_param.tx_set_type = EXT_TX_SET_ALL16;
+
+ int rate_cost = 0;
+ int64_t dist = 0, sse = 0;
+ // evaluate horizontal with vertical DCT
+ for (idx = 0; idx < 4; ++idx) {
+ tx_type = idx_map[idx];
+ txfm_param.tx_type = tx_type;
+
+ av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
+ &quant_param);
+
+ dist_block_tx_domain(x, plane, block, tx_size, &dist, &sse);
+
+ rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
+ txb_ctx, reduced_tx_set_used, 0);
+
+ rds_h[idx] = RDCOST(x->rdmult, rate_cost, dist);
+
+ if ((rds_h[idx] - (rds_h[idx] >> 2)) > ref_best_rd) {
+ skip_h[idx] = 1;
+ }
+ }
+ sort_rd(rds_h, idx_h, 4);
+ for (idx = 1; idx < 4; idx++) {
+ if (rds_h[idx] > rds_h[0] * 1.2) skip_h[idx_h[idx]] = 1;
+ }
+
+ if (skip_h[idx_h[0]]) return (uint16_t)0xFFFF;
+
+ // evaluate vertical with the best horizontal chosen
+ rds_v[0] = rds_h[0];
+ int start_v = 1, end_v = 4;
+ const int *idx_map_v = idx_map + idx_h[0];
+
+ for (idx = start_v; idx < end_v; ++idx) {
+ tx_type = idx_map_v[idx_v[idx] * 4];
+ txfm_param.tx_type = tx_type;
+
+ av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
+ &quant_param);
+
+ dist_block_tx_domain(x, plane, block, tx_size, &dist, &sse);
+
+ rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
+ txb_ctx, reduced_tx_set_used, 0);
+
+ rds_v[idx] = RDCOST(x->rdmult, rate_cost, dist);
+
+ if ((rds_v[idx] - (rds_v[idx] >> 2)) > ref_best_rd) {
+ skip_v[idx] = 1;
+ }
+ }
+ sort_rd(rds_v, idx_v, 4);
+ for (idx = 1; idx < 4; idx++) {
+ if (rds_v[idx] > rds_v[0] * 1.2) skip_v[idx_v[idx]] = 1;
+ }
+
+ // combine rd_h and rd_v to prune tx candidates
+ int i_v, i_h;
+ int64_t rds[16];
+ int num_cand = 0, last = TX_TYPES - 1;
+
+ for (int i = 0; i < 16; i++) {
+ i_v = sel_pattern_v[i];
+ i_h = sel_pattern_h[i];
+ tx_type = idx_map[idx_v[i_v] * 4 + idx_h[i_h]];
+ if (!(allowed_tx_mask & (1 << tx_type)) || skip_h[idx_h[i_h]] ||
+ skip_v[idx_v[i_v]]) {
+ txk_map[last] = tx_type;
+ last--;
+ } else {
+ txk_map[num_cand] = tx_type;
+ rds[num_cand] = rds_v[i_v] + rds_h[i_h];
+ if (rds[num_cand] == 0) rds[num_cand] = 1;
+ num_cand++;
+ }
+ }
+ sort_rd(rds, txk_map, num_cand);
+
+ uint16_t prune = (uint16_t)(~(1 << txk_map[0]));
+ num_sel = AOMMIN(num_sel, num_cand);
+
+ for (int i = 1; i < num_sel; i++) {
+ int64_t factor = 1800 * (rds[i] - rds[0]) / (rds[0]);
+ if (factor < (int64_t)prune_factor)
+ prune &= ~(1 << txk_map[i]);
+ else
+ break;
+ }
+ return prune;
+}
+
+uint16_t prune_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
+ int block, TX_SIZE tx_size, int blk_row, int blk_col,
+ BLOCK_SIZE plane_bsize, int *txk_map,
+ uint16_t allowed_tx_mask, int prune_factor,
+ const TXB_CTX *const txb_ctx, int reduced_tx_set_used) {
+ const AV1_COMMON *cm = &cpi->common;
+ int tx_type;
+
+ int64_t rds[TX_TYPES];
+
+ int num_cand = 0;
+ int last = TX_TYPES - 1;
+
+ TxfmParam txfm_param;
+ QUANT_PARAM quant_param;
+ av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
+ av1_setup_quant(cm, tx_size, 1, AV1_XFORM_QUANT_B, &quant_param);
+
+ for (int idx = 0; idx < TX_TYPES; idx++) {
+ tx_type = idx;
+ int rate_cost = 0;
+ int64_t dist = 0, sse = 0;
+ if (!(allowed_tx_mask & (1 << tx_type))) {
+ txk_map[last] = tx_type;
+ last--;
+ continue;
+ }
+ txfm_param.tx_type = tx_type;
+
+ // do txfm and quantization
+ av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
+ &quant_param);
+ // estimate rate cost
+ rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
+ txb_ctx, reduced_tx_set_used, 0);
+ // tx domain dist
+ dist_block_tx_domain(x, plane, block, tx_size, &dist, &sse);
+
+ txk_map[num_cand] = tx_type;
+ rds[num_cand] = RDCOST(x->rdmult, rate_cost, dist);
+ if (rds[num_cand] == 0) rds[num_cand] = 1;
+ num_cand++;
+ }
+
+ if (num_cand == 0) return (uint16_t)0xFFFF;
+
+ sort_rd(rds, txk_map, num_cand);
+ uint16_t prune = (uint16_t)(~(1 << txk_map[0]));
+
+ // 0 < prune_factor <= 1000 controls aggressiveness
+ int64_t factor = 0;
+ for (int idx = 1; idx < num_cand; idx++) {
+ factor = 1000 * (rds[idx] - rds[0]) / rds[0];
+ if (factor < (int64_t)prune_factor)
+ prune &= ~(1 << txk_map[idx]);
+ else
+ break;
+ }
+ return prune;
+}
+
+// These thresholds were calibrated to provide a certain number of TX types
+// pruned by the model on average, i.e. selecting a threshold with index i
+// will lead to pruning i+1 TX types on average
+static const float *prune_2D_adaptive_thresholds[] = {
+ // TX_4X4
+ (float[]){ 0.00549f, 0.01306f, 0.02039f, 0.02747f, 0.03406f, 0.04065f,
+ 0.04724f, 0.05383f, 0.06067f, 0.06799f, 0.07605f, 0.08533f,
+ 0.09778f, 0.11780f },
+ // TX_8X8
+ (float[]){ 0.00037f, 0.00183f, 0.00525f, 0.01038f, 0.01697f, 0.02502f,
+ 0.03381f, 0.04333f, 0.05286f, 0.06287f, 0.07434f, 0.08850f,
+ 0.10803f, 0.14124f },
+ // TX_16X16
+ (float[]){ 0.01404f, 0.02000f, 0.04211f, 0.05164f, 0.05798f, 0.06335f,
+ 0.06897f, 0.07629f, 0.08875f, 0.11169f },
+ // TX_32X32
+ NULL,
+ // TX_64X64
+ NULL,
+ // TX_4X8
+ (float[]){ 0.00183f, 0.00745f, 0.01428f, 0.02185f, 0.02966f, 0.03723f,
+ 0.04456f, 0.05188f, 0.05920f, 0.06702f, 0.07605f, 0.08704f,
+ 0.10168f, 0.12585f },
+ // TX_8X4
+ (float[]){ 0.00085f, 0.00476f, 0.01135f, 0.01892f, 0.02698f, 0.03528f,
+ 0.04358f, 0.05164f, 0.05994f, 0.06848f, 0.07849f, 0.09021f,
+ 0.10583f, 0.13123f },
+ // TX_8X16
+ (float[]){ 0.00037f, 0.00232f, 0.00671f, 0.01257f, 0.01965f, 0.02722f,
+ 0.03552f, 0.04382f, 0.05237f, 0.06189f, 0.07336f, 0.08728f,
+ 0.10730f, 0.14221f },
+ // TX_16X8
+ (float[]){ 0.00061f, 0.00330f, 0.00818f, 0.01453f, 0.02185f, 0.02966f,
+ 0.03772f, 0.04578f, 0.05383f, 0.06262f, 0.07288f, 0.08582f,
+ 0.10339f, 0.13464f },
+ // TX_16X32
+ NULL,
+ // TX_32X16
+ NULL,
+ // TX_32X64
+ NULL,
+ // TX_64X32
+ NULL,
+ // TX_4X16
+ (float[]){ 0.00232f, 0.00671f, 0.01257f, 0.01941f, 0.02673f, 0.03430f,
+ 0.04211f, 0.04968f, 0.05750f, 0.06580f, 0.07507f, 0.08655f,
+ 0.10242f, 0.12878f },
+ // TX_16X4
+ (float[]){ 0.00110f, 0.00525f, 0.01208f, 0.01990f, 0.02795f, 0.03601f,
+ 0.04358f, 0.05115f, 0.05896f, 0.06702f, 0.07629f, 0.08752f,
+ 0.10217f, 0.12610f },
+ // TX_8X32
+ NULL,
+ // TX_32X8
+ NULL,
+ // TX_16X64
+ NULL,
+ // TX_64X16
+ NULL,
+};
+
+// Probablities are sorted in descending order.
+static INLINE void sort_probability(float prob[], int txk[], int len) {
+ int i, j, k;
+
+ for (i = 1; i <= len - 1; ++i) {
+ for (j = 0; j < i; ++j) {
+ if (prob[j] < prob[i]) {
+ float temp;
+ int tempi;
+
+ temp = prob[i];
+ tempi = txk[i];
+
+ for (k = i; k > j; k--) {
+ prob[k] = prob[k - 1];
+ txk[k] = txk[k - 1];
+ }
+
+ prob[j] = temp;
+ txk[j] = tempi;
+ break;
+ }
+ }
+ }
+}
+
+static INLINE float get_adaptive_thresholds(TX_SIZE tx_size,
+ TxSetType tx_set_type,
+ TX_TYPE_PRUNE_MODE prune_mode) {
+ const int prune_aggr_table[4][2] = { { 4, 1 }, { 6, 3 }, { 9, 6 }, { 9, 6 } };
+ int pruning_aggressiveness = 0;
+ if (tx_set_type == EXT_TX_SET_ALL16)
+ pruning_aggressiveness =
+ prune_aggr_table[prune_mode - PRUNE_2D_ACCURATE][0];
+ else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
+ pruning_aggressiveness =
+ prune_aggr_table[prune_mode - PRUNE_2D_ACCURATE][1];
+
+ return prune_2D_adaptive_thresholds[tx_size][pruning_aggressiveness];
+}
+
+static AOM_INLINE void get_energy_distribution_finer(const int16_t *diff,
+ int stride, int bw, int bh,
+ float *hordist,
+ float *verdist) {
+ // First compute downscaled block energy values (esq); downscale factors
+ // are defined by w_shift and h_shift.
+ unsigned int esq[256];
+ const int w_shift = bw <= 8 ? 0 : 1;
+ const int h_shift = bh <= 8 ? 0 : 1;
+ const int esq_w = bw >> w_shift;
+ const int esq_h = bh >> h_shift;
+ const int esq_sz = esq_w * esq_h;
+ int i, j;
+ memset(esq, 0, esq_sz * sizeof(esq[0]));
+ if (w_shift) {
+ for (i = 0; i < bh; i++) {
+ unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w;
+ const int16_t *cur_diff_row = diff + i * stride;
+ for (j = 0; j < bw; j += 2) {
+ cur_esq_row[j >> 1] += (cur_diff_row[j] * cur_diff_row[j] +
+ cur_diff_row[j + 1] * cur_diff_row[j + 1]);
+ }
+ }
+ } else {
+ for (i = 0; i < bh; i++) {
+ unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w;
+ const int16_t *cur_diff_row = diff + i * stride;
+ for (j = 0; j < bw; j++) {
+ cur_esq_row[j] += cur_diff_row[j] * cur_diff_row[j];
+ }
+ }
+ }
+
+ uint64_t total = 0;
+ for (i = 0; i < esq_sz; i++) total += esq[i];
+
+ // Output hordist and verdist arrays are normalized 1D projections of esq
+ if (total == 0) {
+ float hor_val = 1.0f / esq_w;
+ for (j = 0; j < esq_w - 1; j++) hordist[j] = hor_val;
+ float ver_val = 1.0f / esq_h;
+ for (i = 0; i < esq_h - 1; i++) verdist[i] = ver_val;
+ return;
+ }
+
+ const float e_recip = 1.0f / (float)total;
+ memset(hordist, 0, (esq_w - 1) * sizeof(hordist[0]));
+ memset(verdist, 0, (esq_h - 1) * sizeof(verdist[0]));
+ const unsigned int *cur_esq_row;
+ for (i = 0; i < esq_h - 1; i++) {
+ cur_esq_row = esq + i * esq_w;
+ for (j = 0; j < esq_w - 1; j++) {
+ hordist[j] += (float)cur_esq_row[j];
+ verdist[i] += (float)cur_esq_row[j];
+ }
+ verdist[i] += (float)cur_esq_row[j];
+ }
+ cur_esq_row = esq + i * esq_w;
+ for (j = 0; j < esq_w - 1; j++) hordist[j] += (float)cur_esq_row[j];
+
+ for (j = 0; j < esq_w - 1; j++) hordist[j] *= e_recip;
+ for (i = 0; i < esq_h - 1; i++) verdist[i] *= e_recip;
+}
+
+static void prune_tx_2D(MACROBLOCK *x, BLOCK_SIZE bsize, TX_SIZE tx_size,
+ int blk_row, int blk_col, TxSetType tx_set_type,
+ TX_TYPE_PRUNE_MODE prune_mode, int *txk_map,
+ uint16_t *allowed_tx_mask) {
+ int tx_type_table_2D[16] = {
+ DCT_DCT, DCT_ADST, DCT_FLIPADST, V_DCT,
+ ADST_DCT, ADST_ADST, ADST_FLIPADST, V_ADST,
+ FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
+ H_DCT, H_ADST, H_FLIPADST, IDTX
+ };
+ if (tx_set_type != EXT_TX_SET_ALL16 &&
+ tx_set_type != EXT_TX_SET_DTT9_IDTX_1DDCT)
+ return;
+#if CONFIG_NN_V2
+ NN_CONFIG_V2 *nn_config_hor = av1_tx_type_nnconfig_map_hor[tx_size];
+ NN_CONFIG_V2 *nn_config_ver = av1_tx_type_nnconfig_map_ver[tx_size];
+#else
+ const NN_CONFIG *nn_config_hor = av1_tx_type_nnconfig_map_hor[tx_size];
+ const NN_CONFIG *nn_config_ver = av1_tx_type_nnconfig_map_ver[tx_size];
+#endif
+ if (!nn_config_hor || !nn_config_ver) return; // Model not established yet.
+
+ aom_clear_system_state();
+ float hfeatures[16], vfeatures[16];
+ float hscores[4], vscores[4];
+ float scores_2D_raw[16];
+ float scores_2D[16];
+ const int bw = tx_size_wide[tx_size];
+ const int bh = tx_size_high[tx_size];
+ const int hfeatures_num = bw <= 8 ? bw : bw / 2;
+ const int vfeatures_num = bh <= 8 ? bh : bh / 2;
+ assert(hfeatures_num <= 16);
+ assert(vfeatures_num <= 16);
+
+ const struct macroblock_plane *const p = &x->plane[0];
+ const int diff_stride = block_size_wide[bsize];
+ const int16_t *diff = p->src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
+ get_energy_distribution_finer(diff, diff_stride, bw, bh, hfeatures,
+ vfeatures);
+ av1_get_horver_correlation_full(diff, diff_stride, bw, bh,
+ &hfeatures[hfeatures_num - 1],
+ &vfeatures[vfeatures_num - 1]);
+ aom_clear_system_state();
+#if CONFIG_NN_V2
+ av1_nn_predict_v2(hfeatures, nn_config_hor, 0, hscores);
+ av1_nn_predict_v2(vfeatures, nn_config_ver, 0, vscores);
+#else
+ av1_nn_predict(hfeatures, nn_config_hor, 1, hscores);
+ av1_nn_predict(vfeatures, nn_config_ver, 1, vscores);
+#endif
+ aom_clear_system_state();
+
+ for (int i = 0; i < 4; i++) {
+ float *cur_scores_2D = scores_2D_raw + i * 4;
+ cur_scores_2D[0] = vscores[i] * hscores[0];
+ cur_scores_2D[1] = vscores[i] * hscores[1];
+ cur_scores_2D[2] = vscores[i] * hscores[2];
+ cur_scores_2D[3] = vscores[i] * hscores[3];
+ }
+
+ av1_nn_softmax(scores_2D_raw, scores_2D, 16);
+
+ const float score_thresh =
+ get_adaptive_thresholds(tx_size, tx_set_type, prune_mode);
+
+ // Always keep the TX type with the highest score, prune all others with
+ // score below score_thresh.
+ int max_score_i = 0;
+ float max_score = 0.0f;
+ uint16_t allow_bitmask = 0;
+ float sum_score = 0.0;
+ // Calculate sum of allowed tx type score and Populate allow bit mask based
+ // on score_thresh and allowed_tx_mask
+ for (int tx_idx = 0; tx_idx < TX_TYPES; tx_idx++) {
+ int allow_tx_type = *allowed_tx_mask & (1 << tx_type_table_2D[tx_idx]);
+ if (scores_2D[tx_idx] > max_score && allow_tx_type) {
+ max_score = scores_2D[tx_idx];
+ max_score_i = tx_idx;
+ }
+ if (scores_2D[tx_idx] >= score_thresh && allow_tx_type) {
+ // Set allow mask based on score_thresh
+ allow_bitmask |= (1 << tx_type_table_2D[tx_idx]);
+
+ // Accumulate score of allowed tx type
+ sum_score += scores_2D[tx_idx];
+ }
+ }
+ if (!((allow_bitmask >> max_score_i) & 0x01)) {
+ // Set allow mask based on tx type with max score
+ allow_bitmask |= (1 << tx_type_table_2D[max_score_i]);
+ sum_score += scores_2D[max_score_i];
+ }
+ // Sort tx type probability of all types
+ sort_probability(scores_2D, tx_type_table_2D, TX_TYPES);
+
+ // Enable more pruning based on tx type probability and number of allowed tx
+ // types
+ if (prune_mode == PRUNE_2D_AGGRESSIVE) {
+ float temp_score = 0.0;
+ float score_ratio = 0.0;
+ int tx_idx, tx_count = 0;
+ const float inv_sum_score = 100 / sum_score;
+ // Get allowed tx types based on sorted probability score and tx count
+ for (tx_idx = 0; tx_idx < TX_TYPES; tx_idx++) {
+ // Skip the tx type which has more than 30% of cumulative
+ // probability and allowed tx type count is more than 2
+ if (score_ratio > 30.0 && tx_count >= 2) break;
+
+ // Calculate cumulative probability of allowed tx types
+ if (allow_bitmask & (1 << tx_type_table_2D[tx_idx])) {
+ // Calculate cumulative probability
+ temp_score += scores_2D[tx_idx];
+
+ // Calculate percentage of cumulative probability of allowed tx type
+ score_ratio = temp_score * inv_sum_score;
+ tx_count++;
+ }
+ }
+ // Set remaining tx types as pruned
+ for (; tx_idx < TX_TYPES; tx_idx++)
+ allow_bitmask &= ~(1 << tx_type_table_2D[tx_idx]);
+ }
+ memcpy(txk_map, tx_type_table_2D, sizeof(tx_type_table_2D));
+ *allowed_tx_mask = allow_bitmask;
+}
+
+static float get_dev(float mean, double x2_sum, int num) {
+ const float e_x2 = (float)(x2_sum / num);
+ const float diff = e_x2 - mean * mean;
+ const float dev = (diff > 0) ? sqrtf(diff) : 0;
+ return dev;
+}
+
+// Feature used by the model to predict tx split: the mean and standard
+// deviation values of the block and sub-blocks.
+static AOM_INLINE void get_mean_dev_features(const int16_t *data, int stride,
+ int bw, int bh, float *feature) {
+ const int16_t *const data_ptr = &data[0];
+ const int subh = (bh >= bw) ? (bh >> 1) : bh;
+ const int subw = (bw >= bh) ? (bw >> 1) : bw;
+ const int num = bw * bh;
+ const int sub_num = subw * subh;
+ int feature_idx = 2;
+ int total_x_sum = 0;
+ int64_t total_x2_sum = 0;
+ int blk_idx = 0;
+ double mean2_sum = 0.0f;
+ float dev_sum = 0.0f;
+
+ for (int row = 0; row < bh; row += subh) {
+ for (int col = 0; col < bw; col += subw) {
+ int x_sum;
+ int64_t x2_sum;
+ // TODO(any): Write a SIMD version. Clear registers.
+ aom_get_blk_sse_sum(data_ptr + row * stride + col, stride, subw, subh,
+ &x_sum, &x2_sum);
+ total_x_sum += x_sum;
+ total_x2_sum += x2_sum;
+
+ aom_clear_system_state();
+ const float mean = (float)x_sum / sub_num;
+ const float dev = get_dev(mean, (double)x2_sum, sub_num);
+ feature[feature_idx++] = mean;
+ feature[feature_idx++] = dev;
+ mean2_sum += (double)(mean * mean);
+ dev_sum += dev;
+ blk_idx++;
+ }
+ }
+
+ const float lvl0_mean = (float)total_x_sum / num;
+ feature[0] = lvl0_mean;
+ feature[1] = get_dev(lvl0_mean, (double)total_x2_sum, num);
+
+ if (blk_idx > 1) {
+ // Deviation of means.
+ feature[feature_idx++] = get_dev(lvl0_mean, mean2_sum, blk_idx);
+ // Mean of deviations.
+ feature[feature_idx++] = dev_sum / blk_idx;
+ }
+}
+
+static int ml_predict_tx_split(MACROBLOCK *x, BLOCK_SIZE bsize, int blk_row,
+ int blk_col, TX_SIZE tx_size) {
+ const NN_CONFIG *nn_config = av1_tx_split_nnconfig_map[tx_size];
+ if (!nn_config) return -1;
+
+ const int diff_stride = block_size_wide[bsize];
+ const int16_t *diff =
+ x->plane[0].src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
+ const int bw = tx_size_wide[tx_size];
+ const int bh = tx_size_high[tx_size];
+ aom_clear_system_state();
+
+ float features[64] = { 0.0f };
+ get_mean_dev_features(diff, diff_stride, bw, bh, features);
+
+ float score = 0.0f;
+ av1_nn_predict(features, nn_config, 1, &score);
+ aom_clear_system_state();
+
+ int int_score = (int)(score * 10000);
+ return clamp(int_score, -80000, 80000);
+}
+
+static void search_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
+ int block, int blk_row, int blk_col,
+ BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
+ const TXB_CTX *const txb_ctx,
+ FAST_TX_SEARCH_MODE ftxs_mode,
+ int use_fast_coef_costing, int skip_trellis,
+ int64_t ref_best_rd, RD_STATS *best_rd_stats) {
+ const AV1_COMMON *cm = &cpi->common;
+ MACROBLOCKD *xd = &x->e_mbd;
+ struct macroblockd_plane *const pd = &xd->plane[plane];
+ MB_MODE_INFO *mbmi = xd->mi[0];
+ const int is_inter = is_inter_block(mbmi);
+ int64_t best_rd = INT64_MAX;
+ uint16_t best_eob = 0;
+ TX_TYPE best_tx_type = DCT_DCT;
+ TX_TYPE last_tx_type = TX_TYPES;
+ int rate_cost = 0;
+ const int fast_tx_search = ftxs_mode & FTXS_DCT_AND_1D_DCT_ONLY;
+ // The buffer used to swap dqcoeff in macroblockd_plane so we can keep dqcoeff
+ // of the best tx_type
+ DECLARE_ALIGNED(32, tran_low_t, this_dqcoeff[MAX_SB_SQUARE]);
+ tran_low_t *orig_dqcoeff = pd->dqcoeff;
+ tran_low_t *best_dqcoeff = this_dqcoeff;
+ const int tx_type_map_idx =
+ plane ? 0 : blk_row * xd->tx_type_map_stride + blk_col;
+ int perform_block_coeff_opt = 0;
+ av1_invalid_rd_stats(best_rd_stats);
+
+ TXB_RD_INFO *intra_txb_rd_info = NULL;
+ uint16_t cur_joint_ctx = 0;
+ const int mi_row = xd->mi_row;
+ const int mi_col = xd->mi_col;
+ const int within_border =
+ mi_row >= xd->tile.mi_row_start &&
+ (mi_row + mi_size_high[plane_bsize] < xd->tile.mi_row_end) &&
+ mi_col >= xd->tile.mi_col_start &&
+ (mi_col + mi_size_wide[plane_bsize] < xd->tile.mi_col_end);
+ skip_trellis |=
+ cpi->optimize_seg_arr[mbmi->segment_id] == NO_TRELLIS_OPT ||
+ cpi->optimize_seg_arr[mbmi->segment_id] == FINAL_PASS_TRELLIS_OPT;
+ if (is_intra_hash_match(cpi, x, plane, blk_row, blk_col, plane_bsize, tx_size,
+ txb_ctx, &intra_txb_rd_info, within_border,
+ tx_type_map_idx, &cur_joint_ctx)) {
+ best_rd_stats->rate = intra_txb_rd_info->rate;
+ best_rd_stats->dist = intra_txb_rd_info->dist;
+ best_rd_stats->sse = intra_txb_rd_info->sse;
+ best_rd_stats->skip = intra_txb_rd_info->eob == 0;
+ x->plane[plane].eobs[block] = intra_txb_rd_info->eob;
+ x->plane[plane].txb_entropy_ctx[block] = intra_txb_rd_info->txb_entropy_ctx;
+ best_eob = intra_txb_rd_info->eob;
+ best_tx_type = intra_txb_rd_info->tx_type;
+ perform_block_coeff_opt = intra_txb_rd_info->perform_block_coeff_opt;
+ skip_trellis |= !perform_block_coeff_opt;
+ update_txk_array(xd, blk_row, blk_col, tx_size, best_tx_type);
+ recon_intra(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
+ txb_ctx, skip_trellis, best_tx_type, last_tx_type, &rate_cost,
+ best_eob);
+ pd->dqcoeff = orig_dqcoeff;
+ return;
+ }
+
+ // if txk_allowed = TX_TYPES, >1 tx types are allowed, else, if txk_allowed <
+ // TX_TYPES, only that specific tx type is allowed.
+ TX_TYPE txk_allowed = TX_TYPES;
+ int txk_map[TX_TYPES] = {
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
+ };
+
+ if ((!is_inter && x->use_default_intra_tx_type) ||
+ (is_inter && x->use_default_inter_tx_type)) {
+ txk_allowed =
+ get_default_tx_type(0, xd, tx_size, cpi->is_screen_content_type);
+ } else if (x->rd_model == LOW_TXFM_RD) {
+ if (plane == 0) txk_allowed = DCT_DCT;
+ }
+
+ uint8_t best_txb_ctx = 0;
+ const TxSetType tx_set_type =
+ av1_get_ext_tx_set_type(tx_size, is_inter, cm->reduced_tx_set_used);
+
+ TX_TYPE uv_tx_type = DCT_DCT;
+ if (plane) {
+ // tx_type of PLANE_TYPE_UV should be the same as PLANE_TYPE_Y
+ uv_tx_type = txk_allowed =
+ av1_get_tx_type(xd, get_plane_type(plane), blk_row, blk_col, tx_size,
+ cm->reduced_tx_set_used);
+ }
+ PREDICTION_MODE intra_dir =
+ mbmi->filter_intra_mode_info.use_filter_intra
+ ? fimode_to_intradir[mbmi->filter_intra_mode_info.filter_intra_mode]
+ : mbmi->mode;
+ uint16_t ext_tx_used_flag =
+ cpi->sf.tx_sf.tx_type_search.use_reduced_intra_txset &&
+ tx_set_type == EXT_TX_SET_DTT4_IDTX_1DDCT
+ ? av1_reduced_intra_tx_used_flag[intra_dir]
+ : av1_ext_tx_used_flag[tx_set_type];
+ if (xd->lossless[mbmi->segment_id] || txsize_sqr_up_map[tx_size] > TX_32X32 ||
+ ext_tx_used_flag == 0x0001 ||
+ (is_inter && cpi->oxcf.use_inter_dct_only) ||
+ (!is_inter && cpi->oxcf.use_intra_dct_only)) {
+ txk_allowed = DCT_DCT;
+ }
+
+ const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
+ int64_t block_sse = 0;
+ unsigned int block_mse_q8 = UINT_MAX;
+ block_sse = pixel_diff_dist(x, plane, blk_row, blk_col, plane_bsize, tx_bsize,
+ &block_mse_q8);
+ assert(block_mse_q8 != UINT_MAX);
+ if (is_cur_buf_hbd(xd)) {
+ block_sse = ROUND_POWER_OF_TWO(block_sse, (xd->bd - 8) * 2);
+ block_mse_q8 = ROUND_POWER_OF_TWO(block_mse_q8, (xd->bd - 8) * 2);
+ }
+ block_sse *= 16;
+
+ // Used mse based threshold logic to take decision of R-D of optimization of
+ // coeffs. For smaller residuals, coeff optimization would be helpful. For
+ // larger residuals, R-D optimization may not be effective.
+ // TODO(any): Experiment with variance and mean based thresholds
+ perform_block_coeff_opt = (block_mse_q8 <= x->coeff_opt_dist_threshold);
+ skip_trellis |= !perform_block_coeff_opt;
+
+ if (cpi->oxcf.enable_flip_idtx == 0) {
+ for (TX_TYPE tx_type = FLIPADST_DCT; tx_type <= H_FLIPADST; ++tx_type) {
+ ext_tx_used_flag &= ~(1 << tx_type);
+ }
+ }
+
+ uint16_t allowed_tx_mask = 0; // 1: allow; 0: skip.
+ if (txk_allowed < TX_TYPES) {
+ allowed_tx_mask = 1 << txk_allowed;
+ allowed_tx_mask &= ext_tx_used_flag;
+ } else if (fast_tx_search) {
+ allowed_tx_mask = 0x0c01; // V_DCT, H_DCT, DCT_DCT
+ allowed_tx_mask &= ext_tx_used_flag;
+ } else {
+ assert(plane == 0);
+ allowed_tx_mask = ext_tx_used_flag;
+ int num_allowed = 0;
+ const FRAME_UPDATE_TYPE update_type = get_frame_update_type(&cpi->gf_group);
+ const int *tx_type_probs = cpi->tx_type_probs[update_type][tx_size];
+ int i;
+
+ if (cpi->sf.tx_sf.tx_type_search.prune_tx_type_using_stats) {
+ const int thresh = cpi->tx_type_probs_thresh[update_type];
+ uint16_t prune = 0;
+ int max_prob = -1;
+ int max_idx = 0;
+ for (i = 0; i < TX_TYPES; i++) {
+ if (tx_type_probs[i] > max_prob && (allowed_tx_mask & (1 << i))) {
+ max_prob = tx_type_probs[i];
+ max_idx = i;
+ }
+ }
+
+ for (i = 0; i < TX_TYPES; i++) {
+ if (tx_type_probs[i] < thresh && i != max_idx) prune |= (1 << i);
+ }
+ allowed_tx_mask &= (~prune);
+ }
+ for (i = 0; i < TX_TYPES; i++) {
+ if (allowed_tx_mask & (1 << i)) num_allowed++;
+ }
+ assert(num_allowed > 0);
+
+ if (num_allowed > 2 && cpi->sf.tx_sf.tx_type_search.prune_tx_type_est_rd) {
+ int pf = prune_factors[x->prune_mode];
+ int mf = mul_factors[x->prune_mode];
+ if (num_allowed <= 7) {
+ const uint16_t prune = prune_txk_type(
+ cpi, x, plane, block, tx_size, blk_row, blk_col, plane_bsize,
+ txk_map, allowed_tx_mask, pf, txb_ctx, cm->reduced_tx_set_used);
+ allowed_tx_mask &= (~prune);
+ } else {
+ const int num_sel = (num_allowed * mf + 50) / 100;
+ const uint16_t prune = prune_txk_type_separ(
+ cpi, x, plane, block, tx_size, blk_row, blk_col, plane_bsize,
+ txk_map, allowed_tx_mask, pf, txb_ctx, cm->reduced_tx_set_used,
+ ref_best_rd, num_sel);
+
+ allowed_tx_mask &= (~prune);
+ }
+ } else {
+ assert(num_allowed > 0);
+ int allowed_tx_count = (x->prune_mode == PRUNE_2D_AGGRESSIVE) ? 1 : 5;
+ // !fast_tx_search && txk_end != txk_start && plane == 0
+ if (x->prune_mode >= PRUNE_2D_ACCURATE && is_inter &&
+ num_allowed > allowed_tx_count) {
+ prune_tx_2D(x, plane_bsize, tx_size, blk_row, blk_col, tx_set_type,
+ x->prune_mode, txk_map, &allowed_tx_mask);
+ }
+ }
+ }
+
+ // Need to have at least one transform type allowed.
+ if (allowed_tx_mask == 0) {
+ txk_allowed = (plane ? uv_tx_type : DCT_DCT);
+ allowed_tx_mask = (1 << txk_allowed);
+ }
+
+ assert(IMPLIES(txk_allowed < TX_TYPES, allowed_tx_mask == 1 << txk_allowed));
+
+ // Tranform domain distortion is accurate for higher residuals.
+ // TODO(any): Experiment with variance and mean based thresholds
+ int use_transform_domain_distortion =
+ (x->use_transform_domain_distortion > 0) &&
+ (block_mse_q8 >= x->tx_domain_dist_threshold) &&
+ // Any 64-pt transforms only preserves half the coefficients.
+ // Therefore transform domain distortion is not valid for these
+ // transform sizes.
+ txsize_sqr_up_map[tx_size] != TX_64X64;
+#if CONFIG_DIST_8X8
+ if (x->using_dist_8x8) use_transform_domain_distortion = 0;
+#endif
+ int calc_pixel_domain_distortion_final =
+ x->use_transform_domain_distortion == 1 &&
+ use_transform_domain_distortion && x->rd_model != LOW_TXFM_RD;
+ if (calc_pixel_domain_distortion_final &&
+ (txk_allowed < TX_TYPES || allowed_tx_mask == 0x0001))
+ calc_pixel_domain_distortion_final = use_transform_domain_distortion = 0;
+
+ const uint16_t *eobs_ptr = x->plane[plane].eobs;
+
+ TxfmParam txfm_param;
+ QUANT_PARAM quant_param;
+ av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
+ av1_setup_quant(cm, tx_size, !skip_trellis,
+ skip_trellis ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B
+ : AV1_XFORM_QUANT_FP)
+ : AV1_XFORM_QUANT_FP,
+ &quant_param);
+ int use_qm = !(xd->lossless[mbmi->segment_id] || cm->using_qmatrix == 0);
+
+ for (int idx = 0; idx < TX_TYPES; ++idx) {
+ const TX_TYPE tx_type = (TX_TYPE)txk_map[idx];
+ if (!(allowed_tx_mask & (1 << tx_type))) continue;
+ txfm_param.tx_type = tx_type;
+ if (use_qm) {
+ av1_setup_qmatrix(cm, x, plane, tx_size, tx_type, &quant_param);
+ }
+ if (plane == 0) xd->tx_type_map[tx_type_map_idx] = tx_type;
+ RD_STATS this_rd_stats;
+ av1_invalid_rd_stats(&this_rd_stats);
+
+ av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
+ &quant_param);
+
+ if (quant_param.use_optimize_b) {
+ if (cpi->sf.rd_sf.optimize_b_precheck && best_rd < INT64_MAX &&
+ eobs_ptr[block] >= 4) {
+ // Calculate distortion quickly in transform domain.
+ dist_block_tx_domain(x, plane, block, tx_size, &this_rd_stats.dist,
+ &this_rd_stats.sse);
+
+ const int64_t best_rd_ = AOMMIN(best_rd, ref_best_rd);
+ const int64_t dist_cost_estimate =
+ RDCOST(x->rdmult, 0, AOMMIN(this_rd_stats.dist, this_rd_stats.sse));
+ if (dist_cost_estimate - (dist_cost_estimate >> 3) > best_rd_) continue;
+ }
+ av1_optimize_b(cpi, x, plane, block, tx_size, tx_type, txb_ctx,
+ cpi->sf.rd_sf.trellis_eob_fast, &rate_cost);
+ } else {
+ rate_cost =
+ av1_cost_coeffs(x, plane, block, tx_size, tx_type, txb_ctx,
+ use_fast_coef_costing, cm->reduced_tx_set_used);
+ }
+
+ // If rd cost based on coeff rate is more than best_rd, skip the calculation
+ // of distortion
+ int64_t tmp_rd = RDCOST(x->rdmult, rate_cost, 0);
+ if (tmp_rd > best_rd) continue;
+ if (eobs_ptr[block] == 0) {
+ // When eob is 0, pixel domain distortion is more efficient and accurate.
+ this_rd_stats.dist = this_rd_stats.sse = block_sse;
+ } else if (use_transform_domain_distortion) {
+ dist_block_tx_domain(x, plane, block, tx_size, &this_rd_stats.dist,
+ &this_rd_stats.sse);
+ } else {
+ int64_t sse_diff = INT64_MAX;
+ // high_energy threshold assumes that every pixel within a txfm block
+ // has a residue energy of at least 25% of the maximum, i.e. 128 * 128
+ // for 8 bit, then the threshold is scaled based on input bit depth.
+ const int64_t high_energy_thresh =
+ ((int64_t)128 * 128 * tx_size_2d[tx_size]) << ((xd->bd - 8) * 2);
+ const int is_high_energy = (block_sse >= high_energy_thresh);
+ if (tx_size == TX_64X64 || is_high_energy) {
+ // Because 3 out 4 quadrants of transform coefficients are forced to
+ // zero, the inverse transform has a tendency to overflow. sse_diff
+ // is effectively the energy of those 3 quadrants, here we use it
+ // to decide if we should do pixel domain distortion. If the energy
+ // is mostly in first quadrant, then it is unlikely that we have
+ // overflow issue in inverse transform.
+ dist_block_tx_domain(x, plane, block, tx_size, &this_rd_stats.dist,
+ &this_rd_stats.sse);
+ sse_diff = block_sse - this_rd_stats.sse;
+ }
+ if (tx_size != TX_64X64 || !is_high_energy ||
+ (sse_diff * 2) < this_rd_stats.sse) {
+ const int64_t tx_domain_dist = this_rd_stats.dist;
+ this_rd_stats.dist = dist_block_px_domain(
+ cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
+ // For high energy blocks, occasionally, the pixel domain distortion
+ // can be artificially low due to clamping at reconstruction stage
+ // even when inverse transform output is hugely different from the
+ // actual residue.
+ if (is_high_energy && this_rd_stats.dist < tx_domain_dist)
+ this_rd_stats.dist = tx_domain_dist;
+ } else {
+ this_rd_stats.dist += sse_diff;
+ }
+ this_rd_stats.sse = block_sse;
+ }
+
+ this_rd_stats.rate = rate_cost;
+
+ const int64_t rd =
+ RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
+
+ if (rd < best_rd) {
+ best_rd = rd;
+ *best_rd_stats = this_rd_stats;
+ best_tx_type = tx_type;
+ best_txb_ctx = x->plane[plane].txb_entropy_ctx[block];
+ best_eob = x->plane[plane].eobs[block];
+ last_tx_type = best_tx_type;
+
+ // Swap qcoeff and dqcoeff buffers
+ tran_low_t *const tmp_dqcoeff = best_dqcoeff;
+ best_dqcoeff = pd->dqcoeff;
+ pd->dqcoeff = tmp_dqcoeff;
+ }
+
+#if CONFIG_COLLECT_RD_STATS == 1
+ if (plane == 0) {
+ PrintTransformUnitStats(cpi, x, &this_rd_stats, blk_row, blk_col,
+ plane_bsize, tx_size, tx_type, rd);
+ }
+#endif // CONFIG_COLLECT_RD_STATS == 1
+
+#if COLLECT_TX_SIZE_DATA
+ // Generate small sample to restrict output size.
+ static unsigned int seed = 21743;
+ if (lcg_rand16(&seed) % 200 == 0) {
+ FILE *fp = NULL;
+
+ if (within_border) {
+ fp = fopen(av1_tx_size_data_output_file, "a");
+ }
+
+ if (fp) {
+ // Transform info and RD
+ const int txb_w = tx_size_wide[tx_size];
+ const int txb_h = tx_size_high[tx_size];
+
+ // Residue signal.
+ const int diff_stride = block_size_wide[plane_bsize];
+ struct macroblock_plane *const p = &x->plane[plane];
+ const int16_t *src_diff =
+ &p->src_diff[(blk_row * diff_stride + blk_col) * 4];
+
+ for (int r = 0; r < txb_h; ++r) {
+ for (int c = 0; c < txb_w; ++c) {
+ fprintf(fp, "%d,", src_diff[c]);
+ }
+ src_diff += diff_stride;
+ }
+
+ fprintf(fp, "%d,%d,%d,%" PRId64, txb_w, txb_h, tx_type, rd);
+ fprintf(fp, "\n");
+ fclose(fp);
+ }
+ }
+#endif // COLLECT_TX_SIZE_DATA
+
+ if (cpi->sf.tx_sf.adaptive_txb_search_level) {
+ if ((best_rd - (best_rd >> cpi->sf.tx_sf.adaptive_txb_search_level)) >
+ ref_best_rd) {
+ break;
+ }
+ }
+
+ // Skip transform type search when we found the block has been quantized to
+ // all zero and at the same time, it has better rdcost than doing transform.
+ if (cpi->sf.tx_sf.tx_type_search.skip_tx_search && !best_eob) break;
+ }
+
+ assert(best_rd != INT64_MAX);
+
+ best_rd_stats->skip = best_eob == 0;
+ if (plane == 0) update_txk_array(xd, blk_row, blk_col, tx_size, best_tx_type);
+ x->plane[plane].txb_entropy_ctx[block] = best_txb_ctx;
+ x->plane[plane].eobs[block] = best_eob;
+
+ pd->dqcoeff = best_dqcoeff;
+
+ if (calc_pixel_domain_distortion_final && best_eob) {
+ best_rd_stats->dist = dist_block_px_domain(
+ cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
+ best_rd_stats->sse = block_sse;
+ }
+
+ if (intra_txb_rd_info != NULL) {
+ intra_txb_rd_info->valid = 1;
+ intra_txb_rd_info->entropy_context = cur_joint_ctx;
+ intra_txb_rd_info->rate = best_rd_stats->rate;
+ intra_txb_rd_info->dist = best_rd_stats->dist;
+ intra_txb_rd_info->sse = best_rd_stats->sse;
+ intra_txb_rd_info->eob = best_eob;
+ intra_txb_rd_info->txb_entropy_ctx = best_txb_ctx;
+ intra_txb_rd_info->perform_block_coeff_opt = perform_block_coeff_opt;
+ if (plane == 0) intra_txb_rd_info->tx_type = best_tx_type;
+ }
+
+ recon_intra(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
+ txb_ctx, skip_trellis, best_tx_type, last_tx_type, &rate_cost,
+ best_eob);
+ pd->dqcoeff = orig_dqcoeff;
+}
+
+// Pick transform type for a transform block of tx_size.
+static AOM_INLINE void tx_type_rd(const AV1_COMP *cpi, MACROBLOCK *x,
+ TX_SIZE tx_size, int blk_row, int blk_col,
+ int plane, int block, int plane_bsize,
+ TXB_CTX *txb_ctx, RD_STATS *rd_stats,
+ FAST_TX_SEARCH_MODE ftxs_mode,
+ int64_t ref_rdcost,
+ TXB_RD_INFO *rd_info_array) {
+ const struct macroblock_plane *const p = &x->plane[plane];
+ const uint16_t cur_joint_ctx =
+ (txb_ctx->dc_sign_ctx << 8) + txb_ctx->txb_skip_ctx;
+ MACROBLOCKD *xd = &x->e_mbd;
+ const int tx_type_map_idx =
+ plane ? 0 : blk_row * xd->tx_type_map_stride + blk_col;
+ // Look up RD and terminate early in case when we've already processed exactly
+ // the same residual with exactly the same entropy context.
+ if (rd_info_array != NULL && rd_info_array->valid &&
+ rd_info_array->entropy_context == cur_joint_ctx) {
+ if (plane == 0) xd->tx_type_map[tx_type_map_idx] = rd_info_array->tx_type;
+ const TX_TYPE ref_tx_type =
+ av1_get_tx_type(&x->e_mbd, get_plane_type(plane), blk_row, blk_col,
+ tx_size, cpi->common.reduced_tx_set_used);
+ if (ref_tx_type == rd_info_array->tx_type) {
+ rd_stats->rate += rd_info_array->rate;
+ rd_stats->dist += rd_info_array->dist;
+ rd_stats->sse += rd_info_array->sse;
+ rd_stats->skip &= rd_info_array->eob == 0;
+ p->eobs[block] = rd_info_array->eob;
+ p->txb_entropy_ctx[block] = rd_info_array->txb_entropy_ctx;
+ return;
+ }
+ }
+
+ RD_STATS this_rd_stats;
+ search_txk_type(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
+ txb_ctx, ftxs_mode, 0, 0, ref_rdcost, &this_rd_stats);
+
+ av1_merge_rd_stats(rd_stats, &this_rd_stats);
+
+ // Save RD results for possible reuse in future.
+ if (rd_info_array != NULL) {
+ rd_info_array->valid = 1;
+ rd_info_array->entropy_context = cur_joint_ctx;
+ rd_info_array->rate = this_rd_stats.rate;
+ rd_info_array->dist = this_rd_stats.dist;
+ rd_info_array->sse = this_rd_stats.sse;
+ rd_info_array->eob = p->eobs[block];
+ rd_info_array->txb_entropy_ctx = p->txb_entropy_ctx[block];
+ if (plane == 0) rd_info_array->tx_type = xd->tx_type_map[tx_type_map_idx];
+ }
+}
+
+static AOM_INLINE void try_tx_block_no_split(
+ const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
+ TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize,
+ const ENTROPY_CONTEXT *ta, const ENTROPY_CONTEXT *tl,
+ int txfm_partition_ctx, RD_STATS *rd_stats, int64_t ref_best_rd,
+ FAST_TX_SEARCH_MODE ftxs_mode, TXB_RD_INFO_NODE *rd_info_node,
+ TxCandidateInfo *no_split) {
+ MACROBLOCKD *const xd = &x->e_mbd;
+ MB_MODE_INFO *const mbmi = xd->mi[0];
+ struct macroblock_plane *const p = &x->plane[0];
+ const int bw = mi_size_wide[plane_bsize];
+
+ no_split->rd = INT64_MAX;
+ no_split->txb_entropy_ctx = 0;
+ no_split->tx_type = TX_TYPES;
+
+ const ENTROPY_CONTEXT *const pta = ta + blk_col;
+ const ENTROPY_CONTEXT *const ptl = tl + blk_row;
+
+ const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
+ TXB_CTX txb_ctx;
+ get_txb_ctx(plane_bsize, tx_size, 0, pta, ptl, &txb_ctx);
+ const int zero_blk_rate = x->coeff_costs[txs_ctx][PLANE_TYPE_Y]
+ .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
+ rd_stats->zero_rate = zero_blk_rate;
+ const int index = av1_get_txb_size_index(plane_bsize, blk_row, blk_col);
+ mbmi->inter_tx_size[index] = tx_size;
+ tx_type_rd(cpi, x, tx_size, blk_row, blk_col, 0, block, plane_bsize, &txb_ctx,
+ rd_stats, ftxs_mode, ref_best_rd,
+ rd_info_node != NULL ? rd_info_node->rd_info_array : NULL);
+ assert(rd_stats->rate < INT_MAX);
+
+ if ((RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >=
+ RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse) ||
+ rd_stats->skip == 1) &&
+ !xd->lossless[mbmi->segment_id]) {
+#if CONFIG_RD_DEBUG
+ av1_update_txb_coeff_cost(rd_stats, 0, tx_size, blk_row, blk_col,
+ zero_blk_rate - rd_stats->rate);
+#endif // CONFIG_RD_DEBUG
+ rd_stats->rate = zero_blk_rate;
+ rd_stats->dist = rd_stats->sse;
+ rd_stats->skip = 1;
+ set_blk_skip(x, 0, blk_row * bw + blk_col, 1);
+ p->eobs[block] = 0;
+ update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
+ } else {
+ set_blk_skip(x, 0, blk_row * bw + blk_col, 0);
+ rd_stats->skip = 0;
+ }
+
+ if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
+ rd_stats->rate += x->txfm_partition_cost[txfm_partition_ctx][0];
+
+ no_split->rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
+ no_split->txb_entropy_ctx = p->txb_entropy_ctx[block];
+ no_split->tx_type =
+ xd->tx_type_map[blk_row * xd->tx_type_map_stride + blk_col];
+}
+
+static AOM_INLINE void try_tx_block_split(
+ const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
+ TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
+ ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
+ int txfm_partition_ctx, int64_t no_split_rd, int64_t ref_best_rd,
+ FAST_TX_SEARCH_MODE ftxs_mode, TXB_RD_INFO_NODE *rd_info_node,
+ RD_STATS *split_rd_stats, int64_t *split_rd) {
+ assert(tx_size < TX_SIZES_ALL);
+ MACROBLOCKD *const xd = &x->e_mbd;
+ const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
+ const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);
+ const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
+ const int bsw = tx_size_wide_unit[sub_txs];
+ const int bsh = tx_size_high_unit[sub_txs];
+ const int sub_step = bsw * bsh;
+ const int nblks =
+ (tx_size_high_unit[tx_size] / bsh) * (tx_size_wide_unit[tx_size] / bsw);
+ assert(nblks > 0);
+ int blk_idx = 0;
+ int64_t tmp_rd = 0;
+ *split_rd = INT64_MAX;
+ split_rd_stats->rate = x->txfm_partition_cost[txfm_partition_ctx][1];
+
+ for (int r = 0; r < tx_size_high_unit[tx_size]; r += bsh) {
+ for (int c = 0; c < tx_size_wide_unit[tx_size]; c += bsw, ++blk_idx) {
+ assert(blk_idx < 4);
+ const int offsetr = blk_row + r;
+ const int offsetc = blk_col + c;
+ if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
+
+ RD_STATS this_rd_stats;
+ int this_cost_valid = 1;
+ select_tx_block(
+ cpi, x, offsetr, offsetc, block, sub_txs, depth + 1, plane_bsize, ta,
+ tl, tx_above, tx_left, &this_rd_stats, no_split_rd / nblks,
+ ref_best_rd - tmp_rd, &this_cost_valid, ftxs_mode,
+ (rd_info_node != NULL) ? rd_info_node->children[blk_idx] : NULL);
+ if (!this_cost_valid) return;
+ av1_merge_rd_stats(split_rd_stats, &this_rd_stats);
+ tmp_rd = RDCOST(x->rdmult, split_rd_stats->rate, split_rd_stats->dist);
+ if (no_split_rd < tmp_rd) return;
+ block += sub_step;
+ }
+ }
+
+ *split_rd = tmp_rd;
+}
+
+// Search for the best tx partition/type for a given luma block.
+static AOM_INLINE void select_tx_block(
+ const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
+ TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
+ ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
+ RD_STATS *rd_stats, int64_t prev_level_rd, int64_t ref_best_rd,
+ int *is_cost_valid, FAST_TX_SEARCH_MODE ftxs_mode,
+ TXB_RD_INFO_NODE *rd_info_node) {
+ assert(tx_size < TX_SIZES_ALL);
+ av1_init_rd_stats(rd_stats);
+ if (ref_best_rd < 0) {
+ *is_cost_valid = 0;
+ return;
+ }
+
+ MACROBLOCKD *const xd = &x->e_mbd;
+ const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
+ const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);
+ if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;
+
+ const int bw = mi_size_wide[plane_bsize];
+ MB_MODE_INFO *const mbmi = xd->mi[0];
+ const int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
+ mbmi->sb_type, tx_size);
+ struct macroblock_plane *const p = &x->plane[0];
+
+ const int try_no_split =
+ cpi->oxcf.enable_tx64 || txsize_sqr_up_map[tx_size] != TX_64X64;
+ int try_split = tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH;
+#if CONFIG_DIST_8X8
+ if (x->using_dist_8x8)
+ try_split &= tx_size_wide[tx_size] >= 16 && tx_size_high[tx_size] >= 16;
+#endif
+ TxCandidateInfo no_split = { INT64_MAX, 0, TX_TYPES };
+
+ // TX no split
+ if (try_no_split) {
+ try_tx_block_no_split(cpi, x, blk_row, blk_col, block, tx_size, depth,
+ plane_bsize, ta, tl, ctx, rd_stats, ref_best_rd,
+ ftxs_mode, rd_info_node, &no_split);
+
+ if (cpi->sf.tx_sf.adaptive_txb_search_level &&
+ (no_split.rd -
+ (no_split.rd >> (1 + cpi->sf.tx_sf.adaptive_txb_search_level))) >
+ ref_best_rd) {
+ *is_cost_valid = 0;
+ return;
+ }
+
+ if (cpi->sf.tx_sf.txb_split_cap) {
+ if (p->eobs[block] == 0) try_split = 0;
+ }
+
+ if (cpi->sf.tx_sf.adaptive_txb_search_level &&
+ (no_split.rd -
+ (no_split.rd >> (2 + cpi->sf.tx_sf.adaptive_txb_search_level))) >
+ prev_level_rd) {
+ try_split = 0;
+ }
+ }
+
+ if (x->e_mbd.bd == 8 && try_split &&
+ !(ref_best_rd == INT64_MAX && no_split.rd == INT64_MAX)) {
+ const int threshold = cpi->sf.tx_sf.tx_type_search.ml_tx_split_thresh;
+ if (threshold >= 0) {
+ const int split_score =
+ ml_predict_tx_split(x, plane_bsize, blk_row, blk_col, tx_size);
+ if (split_score < -threshold) try_split = 0;
+ }
+ }
+
+ // TX split
+ int64_t split_rd = INT64_MAX;
+ RD_STATS split_rd_stats;
+ av1_init_rd_stats(&split_rd_stats);
+ if (try_split) {
+ try_tx_block_split(cpi, x, blk_row, blk_col, block, tx_size, depth,
+ plane_bsize, ta, tl, tx_above, tx_left, ctx, no_split.rd,
+ AOMMIN(no_split.rd, ref_best_rd), ftxs_mode,
+ rd_info_node, &split_rd_stats, &split_rd);
+ }
+
+ if (no_split.rd < split_rd) {
+ ENTROPY_CONTEXT *pta = ta + blk_col;
+ ENTROPY_CONTEXT *ptl = tl + blk_row;
+ const TX_SIZE tx_size_selected = tx_size;
+ p->txb_entropy_ctx[block] = no_split.txb_entropy_ctx;
+ av1_set_txb_context(x, 0, block, tx_size_selected, pta, ptl);
+ txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size,
+ tx_size);
+ for (int idy = 0; idy < tx_size_high_unit[tx_size]; ++idy) {
+ for (int idx = 0; idx < tx_size_wide_unit[tx_size]; ++idx) {
+ const int index =
+ av1_get_txb_size_index(plane_bsize, blk_row + idy, blk_col + idx);
+ mbmi->inter_tx_size[index] = tx_size_selected;
+ }
+ }
+ mbmi->tx_size = tx_size_selected;
+ update_txk_array(xd, blk_row, blk_col, tx_size, no_split.tx_type);
+ set_blk_skip(x, 0, blk_row * bw + blk_col, rd_stats->skip);
+ } else {
+ *rd_stats = split_rd_stats;
+ if (split_rd == INT64_MAX) *is_cost_valid = 0;
+ }
+}
+
+static AOM_INLINE void choose_largest_tx_size(const AV1_COMP *const cpi,
+ MACROBLOCK *x, RD_STATS *rd_stats,
+ int64_t ref_best_rd,
+ BLOCK_SIZE bs) {
+ MACROBLOCKD *const xd = &x->e_mbd;
+ MB_MODE_INFO *const mbmi = xd->mi[0];
+ mbmi->tx_size = tx_size_from_tx_mode(bs, x->tx_mode_search_type);
+
+ // If tx64 is not enabled, we need to go down to the next available size
+ if (!cpi->oxcf.enable_tx64) {
+ static const TX_SIZE tx_size_max_32[TX_SIZES_ALL] = {
+ TX_4X4, // 4x4 transform
+ TX_8X8, // 8x8 transform
+ TX_16X16, // 16x16 transform
+ TX_32X32, // 32x32 transform
+ TX_32X32, // 64x64 transform
+ TX_4X8, // 4x8 transform
+ TX_8X4, // 8x4 transform
+ TX_8X16, // 8x16 transform
+ TX_16X8, // 16x8 transform
+ TX_16X32, // 16x32 transform
+ TX_32X16, // 32x16 transform
+ TX_32X32, // 32x64 transform
+ TX_32X32, // 64x32 transform
+ TX_4X16, // 4x16 transform
+ TX_16X4, // 16x4 transform
+ TX_8X32, // 8x32 transform
+ TX_32X8, // 32x8 transform
+ TX_16X32, // 16x64 transform
+ TX_32X16, // 64x16 transform
+ };
+
+ mbmi->tx_size = tx_size_max_32[mbmi->tx_size];
+ }
+
+ const int skip_ctx = av1_get_skip_context(xd);
+ int s0, s1;
+
+ s0 = x->skip_cost[skip_ctx][0];
+ s1 = x->skip_cost[skip_ctx][1];
+
+ int64_t skip_rd = INT64_MAX;
+ int64_t this_rd = RDCOST(x->rdmult, s0, 0);
+
+ // Skip RDcost is used only for Inter blocks
+ if (is_inter_block(xd->mi[0])) skip_rd = RDCOST(x->rdmult, s1, 0);
+
+ txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd, AOMMIN(this_rd, skip_rd),
+ AOM_PLANE_Y, bs, mbmi->tx_size,
+ cpi->sf.rd_sf.use_fast_coef_costing, FTXS_NONE, 0);
+}
+
+static AOM_INLINE void choose_smallest_tx_size(const AV1_COMP *const cpi,
+ MACROBLOCK *x,
+ RD_STATS *rd_stats,
+ int64_t ref_best_rd,
+ BLOCK_SIZE bs) {
+ MACROBLOCKD *const xd = &x->e_mbd;
+ MB_MODE_INFO *const mbmi = xd->mi[0];
+
+ mbmi->tx_size = TX_4X4;
+ // TODO(any) : Pass this_rd based on skip/non-skip cost
+ txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd, 0, 0, bs, mbmi->tx_size,
+ cpi->sf.rd_sf.use_fast_coef_costing, FTXS_NONE, 0);
+}
+
+static AOM_INLINE void choose_tx_size_type_from_rd(const AV1_COMP *const cpi,
+ MACROBLOCK *x,
+ RD_STATS *rd_stats,
+ int64_t ref_best_rd,
+ BLOCK_SIZE bs) {
+ av1_invalid_rd_stats(rd_stats);
+
+ MACROBLOCKD *const xd = &x->e_mbd;
+ MB_MODE_INFO *const mbmi = xd->mi[0];
+ const TX_SIZE max_rect_tx_size = max_txsize_rect_lookup[bs];
+ const int tx_select = x->tx_mode_search_type == TX_MODE_SELECT;
+ int start_tx;
+ int depth, init_depth;
+
+ if (tx_select) {
+ start_tx = max_rect_tx_size;
+ init_depth = get_search_init_depth(mi_size_wide[bs], mi_size_high[bs],
+ is_inter_block(mbmi), &cpi->sf,
+ x->tx_size_search_method);
+ } else {
+ const TX_SIZE chosen_tx_size =
+ tx_size_from_tx_mode(bs, x->tx_mode_search_type);
+ start_tx = chosen_tx_size;
+ init_depth = MAX_TX_DEPTH;
+ }
+
+ uint8_t best_txk_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
+ uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
+ TX_SIZE best_tx_size = max_rect_tx_size;
+ int64_t best_rd = INT64_MAX;
+ const int n4 = bsize_to_num_blk(bs);
+ x->rd_model = FULL_TXFM_RD;
+ depth = init_depth;
+ int64_t rd[MAX_TX_DEPTH + 1] = { INT64_MAX, INT64_MAX, INT64_MAX };
+ for (int n = start_tx; depth <= MAX_TX_DEPTH;
+ depth++, n = sub_tx_size_map[n]) {
+#if CONFIG_DIST_8X8
+ if (x->using_dist_8x8) {
+ if (tx_size_wide[n] < 8 || tx_size_high[n] < 8) continue;
+ }
+#endif
+ if (!cpi->oxcf.enable_tx64 && txsize_sqr_up_map[n] == TX_64X64) continue;
+
+ RD_STATS this_rd_stats;
+ rd[depth] =
+ txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs, n, FTXS_NONE, 0);
+
+ if (rd[depth] < best_rd) {
+ av1_copy_array(best_blk_skip, x->blk_skip, n4);
+ av1_copy_array(best_txk_type_map, xd->tx_type_map, n4);
+ best_tx_size = n;
+ best_rd = rd[depth];
+ *rd_stats = this_rd_stats;
+ }
+ if (n == TX_4X4) break;
+ // If we are searching three depths, prune the smallest size depending
+ // on rd results for the first two depths for low contrast blocks.
+ if (depth > init_depth && depth != MAX_TX_DEPTH &&
+ x->source_variance < 256) {
+ if (rd[depth - 1] != INT64_MAX && rd[depth] > rd[depth - 1]) break;
+ }
+ }
+
+ if (rd_stats->rate != INT_MAX) {
+ mbmi->tx_size = best_tx_size;
+ av1_copy_array(xd->tx_type_map, best_txk_type_map, n4);
+ av1_copy_array(x->blk_skip, best_blk_skip, n4);
+ }
+}
+
+static AOM_INLINE void block_rd_txfm(int plane, int block, int blk_row,
+ int blk_col, BLOCK_SIZE plane_bsize,
+ TX_SIZE tx_size, void *arg) {
+ struct rdcost_block_args *args = arg;
+ MACROBLOCK *const x = args->x;
+ MACROBLOCKD *const xd = &x->e_mbd;
+ const int is_inter = is_inter_block(xd->mi[0]);
+ const AV1_COMP *cpi = args->cpi;
+ ENTROPY_CONTEXT *a = args->t_above + blk_col;
+ ENTROPY_CONTEXT *l = args->t_left + blk_row;
+ const AV1_COMMON *cm = &cpi->common;
+ RD_STATS this_rd_stats;
+
+ av1_init_rd_stats(&this_rd_stats);
+
+ if (args->exit_early) {
+ args->incomplete_exit = 1;
+ return;
+ }
+
+ if (!is_inter) {
+ av1_predict_intra_block_facade(cm, xd, plane, blk_col, blk_row, tx_size);
+ av1_subtract_txb(x, plane, plane_bsize, blk_col, blk_row, tx_size);
+ }
+ TXB_CTX txb_ctx;
+ get_txb_ctx(plane_bsize, tx_size, plane, a, l, &txb_ctx);
+ search_txk_type(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
+ &txb_ctx, args->ftxs_mode, args->use_fast_coef_costing,
+ args->skip_trellis, args->best_rd - args->this_rd,
+ &this_rd_stats);
+
+ if (plane == AOM_PLANE_Y && xd->cfl.store_y) {
+ assert(!is_inter || plane_bsize < BLOCK_8X8);
+ cfl_store_tx(xd, blk_row, blk_col, tx_size, plane_bsize);
+ }
+
+#if CONFIG_RD_DEBUG
+ av1_update_txb_coeff_cost(&this_rd_stats, plane, tx_size, blk_row, blk_col,
+ this_rd_stats.rate);
+#endif // CONFIG_RD_DEBUG
+ av1_set_txb_context(x, plane, block, tx_size, a, l);
+
+ const int blk_idx =
+ blk_row * (block_size_wide[plane_bsize] >> MI_SIZE_LOG2) + blk_col;
+
+ if (plane == 0)
+ set_blk_skip(x, plane, blk_idx, x->plane[plane].eobs[block] == 0);
+ else
+ set_blk_skip(x, plane, blk_idx, 0);
+
+ int64_t rd;
+ if (is_inter) {
+ const int64_t rd1 =
+ RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
+ const int64_t rd2 = RDCOST(x->rdmult, 0, this_rd_stats.sse);
+
+ // TODO(jingning): temporarily enabled only for luma component
+ rd = AOMMIN(rd1, rd2);
+ this_rd_stats.skip &= !x->plane[plane].eobs[block];
+ } else {
+ // Signal non-skip for Intra blocks
+ rd = RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
+ this_rd_stats.skip = 0;
+ }
+
+ av1_merge_rd_stats(&args->rd_stats, &this_rd_stats);
+
+ args->this_rd += rd;
+
+ if (args->this_rd > args->best_rd) args->exit_early = 1;
+}
+
+int64_t txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x, RD_STATS *rd_stats,
+ int64_t ref_best_rd, BLOCK_SIZE bs, TX_SIZE tx_size,
+ FAST_TX_SEARCH_MODE ftxs_mode, int skip_trellis) {
+ MACROBLOCKD *const xd = &x->e_mbd;
+ MB_MODE_INFO *const mbmi = xd->mi[0];
+ int64_t rd = INT64_MAX;
+ const int skip_ctx = av1_get_skip_context(xd);
+ int s0, s1;
+ const int is_inter = is_inter_block(mbmi);
+ const int tx_select = x->tx_mode_search_type == TX_MODE_SELECT &&
+ block_signals_txsize(mbmi->sb_type);
+ int ctx = txfm_partition_context(
+ xd->above_txfm_context, xd->left_txfm_context, mbmi->sb_type, tx_size);
+ const int r_tx_size =
+ is_inter ? x->txfm_partition_cost[ctx][0] : tx_size_cost(x, bs, tx_size);
+
+ assert(IMPLIES(is_rect_tx(tx_size), is_rect_tx_allowed_bsize(bs)));
+
+ s0 = x->skip_cost[skip_ctx][0];
+ s1 = x->skip_cost[skip_ctx][1];
+
+ int64_t skip_rd = INT64_MAX;
+ int64_t this_rd = RDCOST(x->rdmult, s0 + r_tx_size * tx_select, 0);
+
+ if (is_inter) skip_rd = RDCOST(x->rdmult, s1, 0);
+
+ mbmi->tx_size = tx_size;
+ txfm_rd_in_plane(
+ x, cpi, rd_stats, ref_best_rd, AOMMIN(this_rd, skip_rd), AOM_PLANE_Y, bs,
+ tx_size, cpi->sf.rd_sf.use_fast_coef_costing, ftxs_mode, skip_trellis);
+ if (rd_stats->rate == INT_MAX) return INT64_MAX;
+
+ // rdstats->rate should include all the rate except skip/non-skip cost as the
+ // same is accounted in the caller functions after rd evaluation of all
+ // planes. However the decisions should be done after considering the
+ // skip/non-skip header cost
+ if (rd_stats->skip && is_inter) {
+ rd = RDCOST(x->rdmult, s1, rd_stats->sse);
+ } else {
+ // Intra blocks are always signalled as non-skip
+ rd = RDCOST(x->rdmult, rd_stats->rate + s0 + r_tx_size * tx_select,
+ rd_stats->dist);
+ rd_stats->rate += r_tx_size * tx_select;
+ }
+ if (is_inter && !xd->lossless[xd->mi[0]->segment_id]) {
+ int64_t temp_skip_rd = RDCOST(x->rdmult, s1, rd_stats->sse);
+ if (temp_skip_rd <= rd) {
+ rd = temp_skip_rd;
+ rd_stats->rate = 0;
+ rd_stats->dist = rd_stats->sse;
+ rd_stats->skip = 1;
+ }
+ }
+
+ return rd;
+}
+
+// Finds rd cost for a y block, given the transform size partitions
+static AOM_INLINE void tx_block_yrd(
+ const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
+ TX_SIZE tx_size, BLOCK_SIZE plane_bsize, int depth,
+ ENTROPY_CONTEXT *above_ctx, ENTROPY_CONTEXT *left_ctx,
+ TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left, int64_t ref_best_rd,
+ RD_STATS *rd_stats, FAST_TX_SEARCH_MODE ftxs_mode) {
+ MACROBLOCKD *const xd = &x->e_mbd;
+ MB_MODE_INFO *const mbmi = xd->mi[0];
+ const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
+ const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);
+
+ assert(tx_size < TX_SIZES_ALL);
+
+ if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;
+
+ const TX_SIZE plane_tx_size = mbmi->inter_tx_size[av1_get_txb_size_index(
+ plane_bsize, blk_row, blk_col)];
+
+ int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
+ mbmi->sb_type, tx_size);
+
+ av1_init_rd_stats(rd_stats);
+ if (tx_size == plane_tx_size) {
+ ENTROPY_CONTEXT *ta = above_ctx + blk_col;
+ ENTROPY_CONTEXT *tl = left_ctx + blk_row;
+ const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
+ TXB_CTX txb_ctx;
+ get_txb_ctx(plane_bsize, tx_size, 0, ta, tl, &txb_ctx);
+
+ const int zero_blk_rate = x->coeff_costs[txs_ctx][get_plane_type(0)]
+ .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
+ rd_stats->zero_rate = zero_blk_rate;
+ tx_type_rd(cpi, x, tx_size, blk_row, blk_col, 0, block, plane_bsize,
+ &txb_ctx, rd_stats, ftxs_mode, ref_best_rd, NULL);
+ const int mi_width = mi_size_wide[plane_bsize];
+ if (RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >=
+ RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse) ||
+ rd_stats->skip == 1) {
+ rd_stats->rate = zero_blk_rate;
+ rd_stats->dist = rd_stats->sse;
+ rd_stats->skip = 1;
+ set_blk_skip(x, 0, blk_row * mi_width + blk_col, 1);
+ x->plane[0].eobs[block] = 0;
+ x->plane[0].txb_entropy_ctx[block] = 0;
+ update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
+ } else {
+ rd_stats->skip = 0;
+ set_blk_skip(x, 0, blk_row * mi_width + blk_col, 0);
+ }
+ if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
+ rd_stats->rate += x->txfm_partition_cost[ctx][0];
+ av1_set_txb_context(x, 0, block, tx_size, ta, tl);
+ txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size,
+ tx_size);
+ } else {
+ const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
+ const int bsw = tx_size_wide_unit[sub_txs];
+ const int bsh = tx_size_high_unit[sub_txs];
+ const int step = bsh * bsw;
+ RD_STATS pn_rd_stats;
+ int64_t this_rd = 0;
+ assert(bsw > 0 && bsh > 0);
+
+ for (int row = 0; row < tx_size_high_unit[tx_size]; row += bsh) {
+ for (int col = 0; col < tx_size_wide_unit[tx_size]; col += bsw) {
+ const int offsetr = blk_row + row;
+ const int offsetc = blk_col + col;
+
+ if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
+
+ av1_init_rd_stats(&pn_rd_stats);
+ tx_block_yrd(cpi, x, offsetr, offsetc, block, sub_txs, plane_bsize,
+ depth + 1, above_ctx, left_ctx, tx_above, tx_left,
+ ref_best_rd - this_rd, &pn_rd_stats, ftxs_mode);
+ if (pn_rd_stats.rate == INT_MAX) {
+ av1_invalid_rd_stats(rd_stats);
+ return;
+ }
+ av1_merge_rd_stats(rd_stats, &pn_rd_stats);
+ this_rd += RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist);
+ block += step;
+ }
+ }
+
+ if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
+ rd_stats->rate += x->txfm_partition_cost[ctx][1];
+ }
+}
+
+// Return value 0: early termination triggered, no valid rd cost available;
+// 1: rd cost values are valid.
+static int inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
+ RD_STATS *rd_stats, BLOCK_SIZE bsize,
+ int64_t ref_best_rd, FAST_TX_SEARCH_MODE ftxs_mode) {
+ MACROBLOCKD *const xd = &x->e_mbd;
+ int is_cost_valid = 1;
+ int64_t this_rd = 0;
+
+ if (ref_best_rd < 0) is_cost_valid = 0;
+
+ av1_init_rd_stats(rd_stats);
+
+ if (is_cost_valid) {
+ const struct macroblockd_plane *const pd = &xd->plane[0];
+ const BLOCK_SIZE plane_bsize =
+ get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
+ const int mi_width = mi_size_wide[plane_bsize];
+ const int mi_height = mi_size_high[plane_bsize];
+ const TX_SIZE max_tx_size = get_vartx_max_txsize(xd, plane_bsize, 0);
+ const int bh = tx_size_high_unit[max_tx_size];
+ const int bw = tx_size_wide_unit[max_tx_size];
+ const int init_depth = get_search_init_depth(
+ mi_width, mi_height, 1, &cpi->sf, x->tx_size_search_method);
+ int idx, idy;
+ int block = 0;
+ int step = tx_size_wide_unit[max_tx_size] * tx_size_high_unit[max_tx_size];
+ ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
+ ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
+ TXFM_CONTEXT tx_above[MAX_MIB_SIZE];
+ TXFM_CONTEXT tx_left[MAX_MIB_SIZE];
+ RD_STATS pn_rd_stats;
+
+ av1_get_entropy_contexts(bsize, pd, ctxa, ctxl);
+ memcpy(tx_above, xd->above_txfm_context, sizeof(TXFM_CONTEXT) * mi_width);
+ memcpy(tx_left, xd->left_txfm_context, sizeof(TXFM_CONTEXT) * mi_height);
+
+ for (idy = 0; idy < mi_height; idy += bh) {
+ for (idx = 0; idx < mi_width; idx += bw) {
+ av1_init_rd_stats(&pn_rd_stats);
+ tx_block_yrd(cpi, x, idy, idx, block, max_tx_size, plane_bsize,
+ init_depth, ctxa, ctxl, tx_above, tx_left,
+ ref_best_rd - this_rd, &pn_rd_stats, ftxs_mode);
+ if (pn_rd_stats.rate == INT_MAX) {
+ av1_invalid_rd_stats(rd_stats);
+ return 0;
+ }
+ av1_merge_rd_stats(rd_stats, &pn_rd_stats);
+ this_rd +=
+ AOMMIN(RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist),
+ RDCOST(x->rdmult, pn_rd_stats.zero_rate, pn_rd_stats.sse));
+ block += step;
+ }
+ }
+ }
+
+ const int skip_ctx = av1_get_skip_context(xd);
+ const int s0 = x->skip_cost[skip_ctx][0];
+ const int s1 = x->skip_cost[skip_ctx][1];
+ int64_t skip_rd = RDCOST(x->rdmult, s1, rd_stats->sse);
+ this_rd = RDCOST(x->rdmult, rd_stats->rate + s0, rd_stats->dist);
+ if (skip_rd < this_rd) {
+ this_rd = skip_rd;
+ rd_stats->rate = 0;
+ rd_stats->dist = rd_stats->sse;
+ rd_stats->skip = 1;
+ }
+ if (this_rd > ref_best_rd) is_cost_valid = 0;
+
+ if (!is_cost_valid) {
+ // reset cost value
+ av1_invalid_rd_stats(rd_stats);
+ }
+ return is_cost_valid;
+}
+
+static int64_t select_tx_size_and_type(const AV1_COMP *cpi, MACROBLOCK *x,
+ RD_STATS *rd_stats, BLOCK_SIZE bsize,
+ int64_t ref_best_rd,
+ TXB_RD_INFO_NODE *rd_info_tree) {
+ MACROBLOCKD *const xd = &x->e_mbd;
+ assert(is_inter_block(xd->mi[0]));
+ assert(bsize < BLOCK_SIZES_ALL);
+
+ // TODO(debargha): enable this as a speed feature where the
+ // select_inter_block_yrd() function above will use a simplified search
+ // such as not using full optimize, but the inter_block_yrd() function
+ // will use more complex search given that the transform partitions have
+ // already been decided.
+
+ const int fast_tx_search = x->tx_size_search_method > USE_FULL_RD;
+ int64_t rd_thresh = ref_best_rd;
+ if (fast_tx_search && rd_thresh < INT64_MAX) {
+ if (INT64_MAX - rd_thresh > (rd_thresh >> 3)) rd_thresh += (rd_thresh >> 3);
+ }
+ assert(rd_thresh > 0);
+
+ const FAST_TX_SEARCH_MODE ftxs_mode =
+ fast_tx_search ? FTXS_DCT_AND_1D_DCT_ONLY : FTXS_NONE;
+ const struct macroblockd_plane *const pd = &xd->plane[0];
+ const BLOCK_SIZE plane_bsize =
+ get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
+ assert(plane_bsize < BLOCK_SIZES_ALL);
+ const int mi_width = mi_size_wide[plane_bsize];
+ const int mi_height = mi_size_high[plane_bsize];
+ ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
+ ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
+ TXFM_CONTEXT tx_above[MAX_MIB_SIZE];
+ TXFM_CONTEXT tx_left[MAX_MIB_SIZE];
+ av1_get_entropy_contexts(bsize, pd, ctxa, ctxl);
+ memcpy(tx_above, xd->above_txfm_context, sizeof(TXFM_CONTEXT) * mi_width);
+ memcpy(tx_left, xd->left_txfm_context, sizeof(TXFM_CONTEXT) * mi_height);
+
+ const int skip_ctx = av1_get_skip_context(xd);
+ const int s0 = x->skip_cost[skip_ctx][0];
+ const int s1 = x->skip_cost[skip_ctx][1];
+ const int init_depth = get_search_init_depth(mi_width, mi_height, 1, &cpi->sf,
+ x->tx_size_search_method);
+ const TX_SIZE max_tx_size = max_txsize_rect_lookup[plane_bsize];
+ const int bh = tx_size_high_unit[max_tx_size];
+ const int bw = tx_size_wide_unit[max_tx_size];
+ const int step = bw * bh;
+ int64_t skip_rd = RDCOST(x->rdmult, s1, 0);
+ int64_t this_rd = RDCOST(x->rdmult, s0, 0);
+ int block = 0;
+
+ av1_init_rd_stats(rd_stats);
+ for (int idy = 0; idy < mi_height; idy += bh) {
+ for (int idx = 0; idx < mi_width; idx += bw) {
+ const int64_t best_rd_sofar =
+ (rd_thresh == INT64_MAX) ? INT64_MAX
+ : (rd_thresh - (AOMMIN(skip_rd, this_rd)));
+ int is_cost_valid = 1;
+ RD_STATS pn_rd_stats;
+ select_tx_block(cpi, x, idy, idx, block, max_tx_size, init_depth,
+ plane_bsize, ctxa, ctxl, tx_above, tx_left, &pn_rd_stats,
+ INT64_MAX, best_rd_sofar, &is_cost_valid, ftxs_mode,
+ rd_info_tree);
+ if (!is_cost_valid || pn_rd_stats.rate == INT_MAX) {
+ av1_invalid_rd_stats(rd_stats);
+ return INT64_MAX;
+ }
+ av1_merge_rd_stats(rd_stats, &pn_rd_stats);
+ skip_rd = RDCOST(x->rdmult, s1, rd_stats->sse);
+ this_rd = RDCOST(x->rdmult, rd_stats->rate + s0, rd_stats->dist);
+ block += step;
+ if (rd_info_tree != NULL) rd_info_tree += 1;
+ }
+ }
+
+ if (skip_rd <= this_rd) {
+ rd_stats->skip = 1;
+ } else {
+ rd_stats->skip = 0;
+ }
+
+ if (rd_stats->rate == INT_MAX) return INT64_MAX;
+
+ // If fast_tx_search is true, only DCT and 1D DCT were tested in
+ // select_inter_block_yrd() above. Do a better search for tx type with
+ // tx sizes already decided.
+ if (fast_tx_search) {
+ if (!inter_block_yrd(cpi, x, rd_stats, bsize, ref_best_rd, FTXS_NONE))
+ return INT64_MAX;
+ }
+
+ int64_t rd;
+ if (rd_stats->skip) {
+ rd = RDCOST(x->rdmult, s1, rd_stats->sse);
+ } else {
+ rd = RDCOST(x->rdmult, rd_stats->rate + s0, rd_stats->dist);
+ if (!xd->lossless[xd->mi[0]->segment_id])
+ rd = AOMMIN(rd, RDCOST(x->rdmult, s1, rd_stats->sse));
+ }
+
+ return rd;
+}
+
+// Search for best transform size and type for luma inter blocks.
+void pick_tx_size_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
+ RD_STATS *rd_stats, BLOCK_SIZE bsize,
+ int64_t ref_best_rd) {
+ const AV1_COMMON *cm = &cpi->common;
+ MACROBLOCKD *const xd = &x->e_mbd;
+ assert(is_inter_block(xd->mi[0]));
+
+ av1_invalid_rd_stats(rd_stats);
+
+ if (cpi->sf.tx_sf.model_based_prune_tx_search_level &&
+ ref_best_rd != INT64_MAX) {
+ int model_rate;
+ int64_t model_dist;
+ int model_skip;
+ model_rd_sb_fn[MODELRD_TYPE_TX_SEARCH_PRUNE](
+ cpi, bsize, x, xd, 0, 0, &model_rate, &model_dist, &model_skip, NULL,
+ NULL, NULL, NULL);
+ const int64_t model_rd = RDCOST(x->rdmult, model_rate, model_dist);
+ // If the modeled rd is a lot worse than the best so far, breakout.
+ // TODO(debargha, urvang): Improve the model and make the check below
+ // tighter.
+ assert(cpi->sf.tx_sf.model_based_prune_tx_search_level >= 0 &&
+ cpi->sf.tx_sf.model_based_prune_tx_search_level <= 2);
+ static const int prune_factor_by8[] = { 3, 5 };
+ if (!model_skip &&
+ ((model_rd *
+ prune_factor_by8[cpi->sf.tx_sf.model_based_prune_tx_search_level -
+ 1]) >>
+ 3) > ref_best_rd)
+ return;
+ }
+
+ uint32_t hash = 0;
+ int32_t match_index = -1;
+ MB_RD_RECORD *mb_rd_record = NULL;
+ const int mi_row = x->e_mbd.mi_row;
+ const int mi_col = x->e_mbd.mi_col;
+ const int within_border =
+ mi_row >= xd->tile.mi_row_start &&
+ (mi_row + mi_size_high[bsize] < xd->tile.mi_row_end) &&
+ mi_col >= xd->tile.mi_col_start &&
+ (mi_col + mi_size_wide[bsize] < xd->tile.mi_col_end);
+ const int is_mb_rd_hash_enabled =
+ (within_border && cpi->sf.rd_sf.use_mb_rd_hash);
+ const int n4 = bsize_to_num_blk(bsize);
+ if (is_mb_rd_hash_enabled) {
+ hash = get_block_residue_hash(x, bsize);
+ mb_rd_record = &x->mb_rd_record;
+ match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash);
+ if (match_index != -1) {
+ MB_RD_INFO *tx_rd_info = &mb_rd_record->tx_rd_info[match_index];
+ fetch_tx_rd_info(n4, tx_rd_info, rd_stats, x);
+ return;
+ }
+ }
+
+ // If we predict that skip is the optimal RD decision - set the respective
+ // context and terminate early.
+ int64_t dist;
+ if (x->predict_skip_level &&
+ predict_skip_flag(x, bsize, &dist, cm->reduced_tx_set_used)) {
+ set_skip_flag(x, rd_stats, bsize, dist);
+ // Save the RD search results into tx_rd_record.
+ if (is_mb_rd_hash_enabled)
+ save_tx_rd_info(n4, hash, x, rd_stats, mb_rd_record);
+ return;
+ }
+#if CONFIG_SPEED_STATS
+ ++x->tx_search_count;
+#endif // CONFIG_SPEED_STATS
+
+ // Precompute residual hashes and find existing or add new RD records to
+ // store and reuse rate and distortion values to speed up TX size search.
+ TXB_RD_INFO_NODE matched_rd_info[4 + 16 + 64];
+ int found_rd_info = 0;
+ if (ref_best_rd != INT64_MAX && within_border &&
+ cpi->sf.tx_sf.use_inter_txb_hash) {
+ found_rd_info = find_tx_size_rd_records(x, bsize, matched_rd_info);
+ }
+
+ int found = 0;
+ RD_STATS this_rd_stats;
+ av1_init_rd_stats(&this_rd_stats);
+ const int64_t rd =
+ select_tx_size_and_type(cpi, x, &this_rd_stats, bsize, ref_best_rd,
+ found_rd_info ? matched_rd_info : NULL);
+
+ if (rd < INT64_MAX) {
+ *rd_stats = this_rd_stats;
+ found = 1;
+ }
+
+ // We should always find at least one candidate unless ref_best_rd is less
+ // than INT64_MAX (in which case, all the calls to select_tx_size_fix_type
+ // might have failed to find something better)
+ assert(IMPLIES(!found, ref_best_rd != INT64_MAX));
+ if (!found) return;
+
+ // Save the RD search results into tx_rd_record.
+ if (is_mb_rd_hash_enabled) {
+ assert(mb_rd_record != NULL);
+ save_tx_rd_info(n4, hash, x, rd_stats, mb_rd_record);
+ }
+}
+
+void super_block_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
+ RD_STATS *rd_stats, BLOCK_SIZE bs, int64_t ref_best_rd) {
+ MACROBLOCKD *xd = &x->e_mbd;
+ av1_init_rd_stats(rd_stats);
+ int is_inter = is_inter_block(xd->mi[0]);
+ assert(bs == xd->mi[0]->sb_type);
+
+ const int mi_row = -xd->mb_to_top_edge >> (3 + MI_SIZE_LOG2);
+ const int mi_col = -xd->mb_to_left_edge >> (3 + MI_SIZE_LOG2);
+
+ uint32_t hash = 0;
+ int32_t match_index = -1;
+ MB_RD_RECORD *mb_rd_record = NULL;
+ const int within_border = mi_row >= xd->tile.mi_row_start &&
+ (mi_row + mi_size_high[bs] < xd->tile.mi_row_end) &&
+ mi_col >= xd->tile.mi_col_start &&
+ (mi_col + mi_size_wide[bs] < xd->tile.mi_col_end);
+ const int is_mb_rd_hash_enabled =
+ (within_border && cpi->sf.rd_sf.use_mb_rd_hash && is_inter);
+ const int n4 = bsize_to_num_blk(bs);
+ if (is_mb_rd_hash_enabled) {
+ hash = get_block_residue_hash(x, bs);
+ mb_rd_record = &x->mb_rd_record;
+ match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash);
+ if (match_index != -1) {
+ MB_RD_INFO *tx_rd_info = &mb_rd_record->tx_rd_info[match_index];
+ fetch_tx_rd_info(n4, tx_rd_info, rd_stats, x);
+ return;
+ }
+ }
+
+ // If we predict that skip is the optimal RD decision - set the respective
+ // context and terminate early.
+ int64_t dist;
+
+ if (x->predict_skip_level && is_inter &&
+ (!xd->lossless[xd->mi[0]->segment_id]) &&
+ predict_skip_flag(x, bs, &dist, cpi->common.reduced_tx_set_used)) {
+ // Populate rdstats as per skip decision
+ set_skip_flag(x, rd_stats, bs, dist);
+ // Save the RD search results into tx_rd_record.
+ if (is_mb_rd_hash_enabled)
+ save_tx_rd_info(n4, hash, x, rd_stats, mb_rd_record);
+ return;
+ }
+
+ if (xd->lossless[xd->mi[0]->segment_id]) {
+ choose_smallest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
+ } else if (x->tx_size_search_method == USE_LARGESTALL) {
+ choose_largest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
+ } else {
+ choose_tx_size_type_from_rd(cpi, x, rd_stats, ref_best_rd, bs);
+ }
+
+ // Save the RD search results into tx_rd_record.
+ if (is_mb_rd_hash_enabled) {
+ assert(mb_rd_record != NULL);
+ save_tx_rd_info(n4, hash, x, rd_stats, mb_rd_record);
+ }
+}
+
+// Return value 0: early termination triggered, no valid rd cost available;
+// 1: rd cost values are valid.
+int super_block_uvrd(const AV1_COMP *const cpi, MACROBLOCK *x,
+ RD_STATS *rd_stats, BLOCK_SIZE bsize,
+ int64_t ref_best_rd) {
+ av1_init_rd_stats(rd_stats);
+ int is_cost_valid = 1;
+ if (ref_best_rd < 0) is_cost_valid = 0;
+ if (x->skip_chroma_rd || !is_cost_valid) return is_cost_valid;
+
+ MACROBLOCKD *const xd = &x->e_mbd;
+ MB_MODE_INFO *const mbmi = xd->mi[0];
+ struct macroblockd_plane *const pd = &xd->plane[AOM_PLANE_U];
+ int plane;
+ const int is_inter = is_inter_block(mbmi);
+ int64_t this_rd = 0, skip_rd = 0;
+ const BLOCK_SIZE plane_bsize =
+ get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
+
+ if (is_inter && is_cost_valid) {
+ for (plane = 1; plane < MAX_MB_PLANE; ++plane)
+ av1_subtract_plane(x, plane_bsize, plane);
+ }
+
+ if (is_cost_valid) {
+ const TX_SIZE uv_tx_size = av1_get_tx_size(AOM_PLANE_U, xd);
+ for (plane = 1; plane < MAX_MB_PLANE; ++plane) {
+ RD_STATS pn_rd_stats;
+ int64_t chroma_ref_best_rd = ref_best_rd;
+ // For inter blocks, refined ref_best_rd is used for early exit
+ // For intra blocks, even though current rd crosses ref_best_rd, early
+ // exit is not recommended as current rd is used for gating subsequent
+ // modes as well (say, for angular modes)
+ // TODO(any): Extend the early exit mechanism for intra modes as well
+ if (cpi->sf.inter_sf.perform_best_rd_based_gating_for_chroma &&
+ is_inter && chroma_ref_best_rd != INT64_MAX)
+ chroma_ref_best_rd = ref_best_rd - AOMMIN(this_rd, skip_rd);
+ txfm_rd_in_plane(x, cpi, &pn_rd_stats, chroma_ref_best_rd, 0, plane,
+ plane_bsize, uv_tx_size,
+ cpi->sf.rd_sf.use_fast_coef_costing, FTXS_NONE, 0);
+ if (pn_rd_stats.rate == INT_MAX) {
+ is_cost_valid = 0;
+ break;
+ }
+ av1_merge_rd_stats(rd_stats, &pn_rd_stats);
+ this_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
+ skip_rd = RDCOST(x->rdmult, 0, rd_stats->sse);
+ if (AOMMIN(this_rd, skip_rd) > ref_best_rd) {
+ is_cost_valid = 0;
+ break;
+ }
+ }
+ }
+
+ if (!is_cost_valid) {
+ // reset cost value
+ av1_invalid_rd_stats(rd_stats);
+ }
+
+ return is_cost_valid;
+}
+
+void txfm_rd_in_plane(MACROBLOCK *x, const AV1_COMP *cpi, RD_STATS *rd_stats,
+ int64_t ref_best_rd, int64_t this_rd, int plane,
+ BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
+ int use_fast_coef_casting, FAST_TX_SEARCH_MODE ftxs_mode,
+ int skip_trellis) {
+ if (!cpi->oxcf.enable_tx64 && txsize_sqr_up_map[tx_size] == TX_64X64) {
+ av1_invalid_rd_stats(rd_stats);
+ return;
+ }
+
+ MACROBLOCKD *const xd = &x->e_mbd;
+ const struct macroblockd_plane *const pd = &xd->plane[plane];
+ struct rdcost_block_args args;
+ av1_zero(args);
+ args.x = x;
+ args.cpi = cpi;
+ args.best_rd = ref_best_rd;
+ args.use_fast_coef_costing = use_fast_coef_casting;
+ args.ftxs_mode = ftxs_mode;
+ args.this_rd = this_rd;
+ args.skip_trellis = skip_trellis;
+ av1_init_rd_stats(&args.rd_stats);
+
+ if (plane == 0) xd->mi[0]->tx_size = tx_size;
+
+ av1_get_entropy_contexts(plane_bsize, pd, args.t_above, args.t_left);
+
+ if (args.this_rd > args.best_rd) {
+ args.exit_early = 1;
+ }
+
+ av1_foreach_transformed_block_in_plane(xd, plane_bsize, plane, block_rd_txfm,
+ &args);
+
+ MB_MODE_INFO *const mbmi = xd->mi[0];
+ const int is_inter = is_inter_block(mbmi);
+ const int invalid_rd = is_inter ? args.incomplete_exit : args.exit_early;
+
+ if (invalid_rd) {
+ av1_invalid_rd_stats(rd_stats);
+ } else {
+ *rd_stats = args.rd_stats;
+ }
+}
+
+int txfm_search(const AV1_COMP *cpi, const TileDataEnc *tile_data,
+ MACROBLOCK *x, BLOCK_SIZE bsize, RD_STATS *rd_stats,
+ RD_STATS *rd_stats_y, RD_STATS *rd_stats_uv, int mode_rate,
+ int64_t ref_best_rd) {
+ /*
+ * This function combines y and uv planes' transform search processes
+ * together, when the prediction is generated. It first does subtraction to
+ * obtain the prediction error. Then it calls
+ * pick_tx_size_type_yrd/super_block_yrd and super_block_uvrd sequentially and
+ * handles the early terminations happening in those functions. At the end, it
+ * computes the rd_stats/_y/_uv accordingly.
+ */
+ const AV1_COMMON *cm = &cpi->common;
+ MACROBLOCKD *const xd = &x->e_mbd;
+ MB_MODE_INFO *const mbmi = xd->mi[0];
+ const int ref_frame_1 = mbmi->ref_frame[1];
+ const int64_t mode_rd = RDCOST(x->rdmult, mode_rate, 0);
+ const int64_t rd_thresh =
+ ref_best_rd == INT64_MAX ? INT64_MAX : ref_best_rd - mode_rd;
+ const int skip_ctx = av1_get_skip_context(xd);
+ const int skip_flag_cost[2] = { x->skip_cost[skip_ctx][0],
+ x->skip_cost[skip_ctx][1] };
+ const int64_t min_header_rate =
+ mode_rate + AOMMIN(skip_flag_cost[0], skip_flag_cost[1]);
+ // Account for minimum skip and non_skip rd.
+ // Eventually either one of them will be added to mode_rate
+ const int64_t min_header_rd_possible = RDCOST(x->rdmult, min_header_rate, 0);
+ (void)tile_data;
+
+ if (min_header_rd_possible > ref_best_rd) {
+ av1_invalid_rd_stats(rd_stats_y);
+ return 0;
+ }
+
+ av1_init_rd_stats(rd_stats);
+ av1_init_rd_stats(rd_stats_y);
+ rd_stats->rate = mode_rate;
+
+ // cost and distortion
+ av1_subtract_plane(x, bsize, 0);
+ if (x->tx_mode_search_type == TX_MODE_SELECT &&
+ !xd->lossless[mbmi->segment_id]) {
+ pick_tx_size_type_yrd(cpi, x, rd_stats_y, bsize, rd_thresh);
+#if CONFIG_COLLECT_RD_STATS == 2
+ PrintPredictionUnitStats(cpi, tile_data, x, rd_stats_y, bsize);
+#endif // CONFIG_COLLECT_RD_STATS == 2
+ } else {
+ super_block_yrd(cpi, x, rd_stats_y, bsize, rd_thresh);
+ memset(mbmi->inter_tx_size, mbmi->tx_size, sizeof(mbmi->inter_tx_size));
+ for (int i = 0; i < xd->n4_h * xd->n4_w; ++i)
+ set_blk_skip(x, 0, i, rd_stats_y->skip);
+ }
+
+ if (rd_stats_y->rate == INT_MAX) {
+ // TODO(angiebird): check if we need this
+ // restore_dst_buf(xd, *orig_dst, num_planes);
+ mbmi->ref_frame[1] = ref_frame_1;
+ return 0;
+ }
+
+ av1_merge_rd_stats(rd_stats, rd_stats_y);
+
+ const int64_t non_skip_rdcosty =
+ RDCOST(x->rdmult, rd_stats->rate + skip_flag_cost[0], rd_stats->dist);
+ const int64_t skip_rdcosty =
+ RDCOST(x->rdmult, mode_rate + skip_flag_cost[1], rd_stats->sse);
+ const int64_t min_rdcosty = AOMMIN(non_skip_rdcosty, skip_rdcosty);
+ if (min_rdcosty > ref_best_rd) {
+ const int64_t tokenonly_rdy =
+ AOMMIN(RDCOST(x->rdmult, rd_stats_y->rate, rd_stats_y->dist),
+ RDCOST(x->rdmult, 0, rd_stats_y->sse));
+ // Invalidate rd_stats_y to skip the rest of the motion modes search
+ if (tokenonly_rdy -
+ (tokenonly_rdy >> cpi->sf.inter_sf.prune_motion_mode_level) >
+ rd_thresh)
+ av1_invalid_rd_stats(rd_stats_y);
+ mbmi->ref_frame[1] = ref_frame_1;
+ return 0;
+ }
+
+ av1_init_rd_stats(rd_stats_uv);
+ const int num_planes = av1_num_planes(cm);
+ if (num_planes > 1) {
+ int64_t ref_best_chroma_rd = ref_best_rd;
+ // Calculate best rd cost possible for chroma
+ if (cpi->sf.inter_sf.perform_best_rd_based_gating_for_chroma &&
+ (ref_best_chroma_rd != INT64_MAX)) {
+ ref_best_chroma_rd =
+ (ref_best_chroma_rd - AOMMIN(non_skip_rdcosty, skip_rdcosty));
+ }
+ const int is_cost_valid_uv =
+ super_block_uvrd(cpi, x, rd_stats_uv, bsize, ref_best_chroma_rd);
+ if (!is_cost_valid_uv) {
+ mbmi->ref_frame[1] = ref_frame_1;
+ return 0;
+ }
+ av1_merge_rd_stats(rd_stats, rd_stats_uv);
+ }
+
+ if (rd_stats->skip) {
+ rd_stats->rate -= rd_stats_uv->rate + rd_stats_y->rate;
+ rd_stats_y->rate = 0;
+ rd_stats_uv->rate = 0;
+ rd_stats->dist = rd_stats->sse;
+ rd_stats_y->dist = rd_stats_y->sse;
+ rd_stats_uv->dist = rd_stats_uv->sse;
+ rd_stats->rate += skip_flag_cost[1];
+ mbmi->skip = 1;
+ // here mbmi->skip temporarily plays a role as what this_skip2 does
+
+ const int64_t tmprd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
+ if (tmprd > ref_best_rd) {
+ mbmi->ref_frame[1] = ref_frame_1;
+ return 0;
+ }
+ } else if (!xd->lossless[mbmi->segment_id] &&
+ (RDCOST(x->rdmult,
+ rd_stats_y->rate + rd_stats_uv->rate + skip_flag_cost[0],
+ rd_stats->dist) >=
+ RDCOST(x->rdmult, skip_flag_cost[1], rd_stats->sse))) {
+ rd_stats->rate -= rd_stats_uv->rate + rd_stats_y->rate;
+ rd_stats->rate += skip_flag_cost[1];
+ rd_stats->dist = rd_stats->sse;
+ rd_stats_y->dist = rd_stats_y->sse;
+ rd_stats_uv->dist = rd_stats_uv->sse;
+ rd_stats_y->rate = 0;
+ rd_stats_uv->rate = 0;
+ mbmi->skip = 1;
+ } else {
+ rd_stats->rate += skip_flag_cost[0];
+ mbmi->skip = 0;
+ }
+
+ return 1;
+}
diff --git a/av1/encoder/tx_search.h b/av1/encoder/tx_search.h
new file mode 100644
index 0000000..3568f1e
--- /dev/null
+++ b/av1/encoder/tx_search.h
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) 2020, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#ifndef AOM_AV1_ENCODER_TRANSFORM_SEARCH_H_
+#define AOM_AV1_ENCODER_TRANSFORM_SEARCH_H_
+
+#include "av1/common/pred_common.h"
+#include "av1/encoder/encoder.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// Set this macro as 1 to collect data about tx size selection.
+#define COLLECT_TX_SIZE_DATA 0
+
+#if COLLECT_TX_SIZE_DATA
+static const char av1_tx_size_data_output_file[] = "tx_size_data.txt";
+#endif
+
+enum {
+ FTXS_NONE = 0,
+ FTXS_DCT_AND_1D_DCT_ONLY = 1 << 0,
+ FTXS_DISABLE_TRELLIS_OPT = 1 << 1,
+ FTXS_USE_TRANSFORM_DOMAIN = 1 << 2
+} UENUM1BYTE(FAST_TX_SEARCH_MODE);
+
+static AOM_INLINE int tx_size_cost(const MACROBLOCK *const x, BLOCK_SIZE bsize,
+ TX_SIZE tx_size) {
+ assert(bsize == x->e_mbd.mi[0]->sb_type);
+ if (x->tx_mode_search_type != TX_MODE_SELECT || !block_signals_txsize(bsize))
+ return 0;
+
+ const int32_t tx_size_cat = bsize_to_tx_size_cat(bsize);
+ const int depth = tx_size_to_depth(tx_size, bsize);
+ const MACROBLOCKD *const xd = &x->e_mbd;
+ const int tx_size_ctx = get_tx_size_context(xd);
+ return x->tx_size_cost[tx_size_cat][tx_size_ctx][depth];
+}
+
+int64_t txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x, RD_STATS *rd_stats,
+ int64_t ref_best_rd, BLOCK_SIZE bs, TX_SIZE tx_size,
+ FAST_TX_SEARCH_MODE ftxs_mode, int skip_trellis);
+
+void pick_tx_size_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
+ RD_STATS *rd_stats, BLOCK_SIZE bsize,
+ int64_t ref_best_rd);
+
+void super_block_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
+ RD_STATS *rd_stats, BLOCK_SIZE bs, int64_t ref_best_rd);
+
+int super_block_uvrd(const AV1_COMP *const cpi, MACROBLOCK *x,
+ RD_STATS *rd_stats, BLOCK_SIZE bsize, int64_t ref_best_rd);
+
+void txfm_rd_in_plane(MACROBLOCK *x, const AV1_COMP *cpi, RD_STATS *rd_stats,
+ int64_t ref_best_rd, int64_t this_rd, int plane,
+ BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
+ int use_fast_coef_casting, FAST_TX_SEARCH_MODE ftxs_mode,
+ int skip_trellis);
+
+int txfm_search(const AV1_COMP *cpi, const TileDataEnc *tile_data,
+ MACROBLOCK *x, BLOCK_SIZE bsize, RD_STATS *rd_stats,
+ RD_STATS *rd_stats_y, RD_STATS *rd_stats_uv, int mode_rate,
+ int64_t ref_best_rd);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif
+
+#endif // AOM_AV1_ENCODER_TRANSFORM_SEARCH_H_