Use tx_size 1 level down for transform type search
This addresses an inconsistency between the set used
to decode the tx_type in the bitstream and the set used
for the tx_type search. Previously, the set used to
read/write the tx_type was based on the smallest tx_size
in the vartx partitioning, but the search used a set
based on the largest possible tx_size. This patch
changes the tx_type search to use the transform type
set associated with the tx_size 1 recursive level down from
the max square tx_size to make the search more consistent
with the bitstream syntax. If the tx_type is invalid for the
selected tx_size, DCT_DCT is used for that partition instead.
This patch also adds assertions to all exposed transform
functions to ensure that no illegal transform type/size
combinations occur.
This currently causes a 0.1% drop in performance on lowres.
The drop is due to the reduction of the tx_types available
for 32x16 and 16x32 transform sizes. Before this patch,
32x16 and 16x32 transforms were getting assigned a
set of 12 tx_types, some of which we did not intend to
support for these sizes.
Change-Id: I44aca4876b261c345623cd04ad6235bca4532701
diff --git a/av1/common/idct.c b/av1/common/idct.c
index 7bd8302..de6e05b 100644
--- a/av1/common/idct.c
+++ b/av1/common/idct.c
@@ -256,6 +256,9 @@
// apply. Otherwise they return 0
int get_lgt4(const TxfmParam *txfm_param, int is_col,
const tran_high_t **lgtmtx) {
+#if CONFIG_EXT_TX
+ assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
+#endif // CONFIG_EXT_TX
if (is_col && (vtx_tab[txfm_param->tx_type] == ADST_1D ||
vtx_tab[txfm_param->tx_type] == FLIPADST_1D)) {
lgtmtx[0] = txfm_param->is_inter ? &lgt4_170[0][0] : &lgt4_140[0][0];
@@ -271,6 +274,9 @@
int get_lgt8(const TxfmParam *txfm_param, int is_col,
const tran_high_t **lgtmtx) {
+#if CONFIG_EXT_TX
+ assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
+#endif // CONFIG_EXT_TX
if (is_col && (vtx_tab[txfm_param->tx_type] == ADST_1D ||
vtx_tab[txfm_param->tx_type] == FLIPADST_1D)) {
lgtmtx[0] = txfm_param->is_inter ? &lgt8_170[0][0] : &lgt8_150[0][0];
@@ -388,6 +394,9 @@
void get_lgt4_from_pred(const TxfmParam *txfm_param, int is_col,
const tran_high_t **lgtmtx, int ntx) {
+#if CONFIG_EXT_TX
+ assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
+#endif // CONFIG_EXT_TX
PREDICTION_MODE mode = txfm_param->mode;
int stride = txfm_param->stride;
uint8_t *dst = txfm_param->dst;
@@ -469,6 +478,9 @@
void get_lgt8_from_pred(const TxfmParam *txfm_param, int is_col,
const tran_high_t **lgtmtx, int ntx) {
+#if CONFIG_EXT_TX
+ assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
+#endif // CONFIG_EXT_TX
PREDICTION_MODE mode = txfm_param->mode;
int stride = txfm_param->stride;
uint8_t *dst = txfm_param->dst;
@@ -538,6 +550,9 @@
// will just call DCT or ADST
void get_lgt16up_from_pred(const TxfmParam *txfm_param, int is_col,
const tran_high_t **lgtmtx, int ntx) {
+#if CONFIG_EXT_TX
+ assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
+#endif // CONFIG_EXT_TX
int tx_length = is_col ? tx_size_high[txfm_param->tx_size]
: tx_size_wide[txfm_param->tx_size];
assert(tx_length == 16 || tx_length == 32);
@@ -2414,6 +2429,9 @@
// idct
void av1_idct4x4_add(const tran_low_t *input, uint8_t *dest, int stride,
const TxfmParam *txfm_param) {
+#if CONFIG_EXT_TX
+ assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
+#endif // CONFIG_EXT_TX
const int eob = txfm_param->eob;
if (eob > 1)
av1_iht4x4_16_add(input, dest, stride, txfm_param);
@@ -2423,6 +2441,9 @@
void av1_iwht4x4_add(const tran_low_t *input, uint8_t *dest, int stride,
const TxfmParam *txfm_param) {
+#if CONFIG_EXT_TX
+ assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
+#endif // CONFIG_EXT_TX
const int eob = txfm_param->eob;
if (eob > 1)
aom_iwht4x4_16_add(input, dest, stride);
@@ -2897,6 +2918,9 @@
void av1_highbd_inv_txfm_add_4x4(const tran_low_t *input, uint8_t *dest,
int stride, const TxfmParam *txfm_param) {
+#if CONFIG_EXT_TX
+ assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
+#endif // CONFIG_EXT_TX
int eob = txfm_param->eob;
int bd = txfm_param->bd;
int lossless = txfm_param->lossless;
@@ -2942,6 +2966,9 @@
void av1_highbd_inv_txfm_add_4x8(const tran_low_t *input, uint8_t *dest,
int stride, const TxfmParam *txfm_param) {
+#if CONFIG_EXT_TX
+ assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
+#endif // CONFIG_EXT_TX
const int32_t *src = cast_to_int32(input);
av1_inv_txfm2d_add_4x8_c(src, CONVERT_TO_SHORTPTR(dest), stride,
txfm_param->tx_type, txfm_param->bd);
@@ -2949,6 +2976,9 @@
void av1_highbd_inv_txfm_add_8x4(const tran_low_t *input, uint8_t *dest,
int stride, const TxfmParam *txfm_param) {
+#if CONFIG_EXT_TX
+ assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
+#endif // CONFIG_EXT_TX
const int32_t *src = cast_to_int32(input);
av1_inv_txfm2d_add_8x4_c(src, CONVERT_TO_SHORTPTR(dest), stride,
txfm_param->tx_type, txfm_param->bd);
@@ -3158,6 +3188,9 @@
void av1_inv_txfm_add(const tran_low_t *input, uint8_t *dest, int stride,
TxfmParam *txfm_param) {
+#if CONFIG_EXT_TX
+ assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
+#endif // CONFIG_EXT_TX
const TX_SIZE tx_size = txfm_param->tx_size;
#if CONFIG_LGT_FROM_PRED
if (txfm_param->use_lgt) {
@@ -3199,13 +3232,27 @@
}
}
-static void init_txfm_param(const MACROBLOCKD *xd, TX_SIZE tx_size,
- TX_TYPE tx_type, int eob, TxfmParam *txfm_param) {
+static void init_txfm_param(const MACROBLOCKD *xd,
+#if CONFIG_EXT_TX
+ int plane,
+#endif // CONFIG_EXT_TX
+ TX_SIZE tx_size, TX_TYPE tx_type, int eob,
+ TxfmParam *txfm_param) {
txfm_param->tx_type = tx_type;
txfm_param->tx_size = tx_size;
txfm_param->eob = eob;
txfm_param->lossless = xd->lossless[xd->mi[0]->mbmi.segment_id];
txfm_param->bd = xd->bd;
+#if CONFIG_EXT_TX
+ const struct macroblockd_plane *const pd = &xd->plane[plane];
+ const BLOCK_SIZE plane_bsize =
+ get_plane_block_size(xd->mi[0]->mbmi.sb_type, pd);
+ // TODO(sarahparker) This assumes reduced_tx_set_used == 0. I will do a
+ // follow-up refactor to make the actual value of reduced_tx_set_used
+ // available within this function.
+ txfm_param->tx_set_type = get_ext_tx_set_type(
+ txfm_param->tx_size, plane_bsize, is_inter_block(&xd->mi[0]->mbmi), 0);
+#endif // CONFIG_EXT_TX
#if CONFIG_LGT
txfm_param->is_inter = is_inter_block(&xd->mi[0]->mbmi);
#endif
@@ -3234,12 +3281,19 @@
#if CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
uint8_t *mrc_mask,
#endif // CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
+#if CONFIG_EXT_TX
+ int plane,
+#endif // CONFIG_EXT_TX
TX_TYPE tx_type, TX_SIZE tx_size, uint8_t *dst,
int stride, int eob) {
if (!eob) return;
TxfmParam txfm_param;
- init_txfm_param(xd, tx_size, tx_type, eob, &txfm_param);
+ init_txfm_param(xd,
+#if CONFIG_EXT_TX
+ plane,
+#endif // CONFIG_EXT_TX
+ tx_size, tx_type, eob, &txfm_param);
#if CONFIG_LGT || CONFIG_MRC_TX
txfm_param.is_inter = is_inter_block(&xd->mi[0]->mbmi);
#endif // CONFIG_LGT || CONFIG_MRC_TX
@@ -3253,6 +3307,9 @@
txfm_param.mode = mode;
#endif // CONFIG_LGT_FROM_PRED
#endif // CONFIG_LGT_FROM_PRED || CONFIG_MRC_TX
+#if CONFIG_EXT_TX
+ assert(av1_ext_tx_used[txfm_param.tx_set_type][txfm_param.tx_type]);
+#endif // CONFIG_EXT_TX
const int is_hbd = get_bitdepth_data_path_index(xd);
#if CONFIG_TXMG
@@ -3304,12 +3361,18 @@
#if CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
mrc_mask,
#endif // CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
+#if CONFIG_EXT_TX
+ plane,
+#endif // CONFIG_EXT_TX
tx_type, tx_size, dst, dst_stride, eob);
}
void av1_highbd_inv_txfm_add(const tran_low_t *input, uint8_t *dest, int stride,
TxfmParam *txfm_param) {
const TX_SIZE tx_size = txfm_param->tx_size;
+#if CONFIG_EXT_TX
+ assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
+#endif // CONFIG_EXT_TX
switch (tx_size) {
#if CONFIG_TX64X64
case TX_64X64: