Minor refactor to match the 4x4 forward transform.
Change-Id: Ib5337dfa78b73059ad169ca98a07119aa991864b
diff --git a/av1/common/idct.c b/av1/common/idct.c
index 8f4e58a..49a91fb 100644
--- a/av1/common/idct.c
+++ b/av1/common/idct.c
@@ -249,6 +249,10 @@
void av1_iht4x4_16_add_c(const tran_low_t *input, uint8_t *dest, int stride,
const INV_TXFM_PARAM *param) {
int tx_type = param->tx_type;
+ if (tx_type == DCT_DCT) {
+ aom_idct4x4_16_add(input, dest, stride);
+ return;
+ }
static const transform_2d IHT_4[] = {
{ aom_idct4_c, aom_idct4_c }, // DCT_DCT = 0
{ aom_iadst4_c, aom_idct4_c }, // ADST_DCT = 1
@@ -1303,15 +1307,17 @@
// idct
void av1_idct4x4_add(const tran_low_t *input, uint8_t *dest, int stride,
- int eob) {
+ const INV_TXFM_PARAM *param) {
+ const int eob = param->eob;
if (eob > 1)
- aom_idct4x4_16_add(input, dest, stride);
+ av1_iht4x4_16_add(input, dest, stride, param);
else
aom_idct4x4_1_add(input, dest, stride);
}
void av1_iwht4x4_add(const tran_low_t *input, uint8_t *dest, int stride,
- int eob) {
+ const INV_TXFM_PARAM *param) {
+ const int eob = param->eob;
if (eob > 1)
aom_iwht4x4_16_add(input, dest, stride);
else
@@ -1427,15 +1433,14 @@
static void inv_txfm_add_4x4(const tran_low_t *input, uint8_t *dest, int stride,
const INV_TXFM_PARAM *param) {
const TX_TYPE tx_type = param->tx_type;
- const int eob = param->eob;
if (param->lossless) {
assert(tx_type == DCT_DCT);
- av1_iwht4x4_add(input, dest, stride, eob);
+ av1_iwht4x4_add(input, dest, stride, param);
return;
}
switch (tx_type) {
- case DCT_DCT: av1_idct4x4_add(input, dest, stride, eob); break;
+ case DCT_DCT: av1_idct4x4_add(input, dest, stride, param); break;
case ADST_DCT:
case DCT_ADST:
case ADST_ADST:
diff --git a/av1/common/idct.h b/av1/common/idct.h
index 3c2fa6a..0591d4a 100644
--- a/av1/common/idct.h
+++ b/av1/common/idct.h
@@ -76,9 +76,9 @@
int av1_get_tx_scale(const TX_SIZE tx_size);
void av1_iwht4x4_add(const tran_low_t *input, uint8_t *dest, int stride,
- int eob);
+ const INV_TXFM_PARAM *inv_txfm_param);
void av1_idct4x4_add(const tran_low_t *input, uint8_t *dest, int stride,
- int eob);
+ const INV_TXFM_PARAM *inv_txfm_param);
void av1_inv_txfm_add(const tran_low_t *input, uint8_t *dest, int stride,
INV_TXFM_PARAM *inv_txfm_param);
diff --git a/av1/encoder/encodemb.c b/av1/encoder/encodemb.c
index a9e2f2a..b6301aa 100644
--- a/av1/encoder/encodemb.c
+++ b/av1/encoder/encodemb.c
@@ -822,6 +822,7 @@
struct macroblock_plane *const p = &x->plane[plane];
struct macroblockd_plane *const pd = &xd->plane[plane];
tran_low_t *const dqcoeff = BLOCK_OFFSET(pd->dqcoeff, block);
+ INV_TXFM_PARAM inv_txfm_param;
uint8_t *dst;
int ctx = 0;
dst = &pd->dst
@@ -857,12 +858,11 @@
#endif // CONFIG_HIGHBITDEPTH
}
#endif // !CONFIG_PVQ
-#if CONFIG_HIGHBITDEPTH
- INV_TXFM_PARAM inv_txfm_param;
inv_txfm_param.bd = xd->bd;
inv_txfm_param.tx_type = DCT_DCT;
inv_txfm_param.eob = p->eobs[block];
inv_txfm_param.lossless = xd->lossless[xd->mi[0]->mbmi.segment_id];
+#if CONFIG_HIGHBITDEPTH
if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
av1_highbd_inv_txfm_add_4x4(dqcoeff, dst, pd->dst.stride,
&inv_txfm_param);
@@ -870,9 +870,9 @@
}
#endif // CONFIG_HIGHBITDEPTH
if (xd->lossless[xd->mi[0]->mbmi.segment_id]) {
- av1_iwht4x4_add(dqcoeff, dst, pd->dst.stride, p->eobs[block]);
+ av1_iwht4x4_add(dqcoeff, dst, pd->dst.stride, &inv_txfm_param);
} else {
- av1_idct4x4_add(dqcoeff, dst, pd->dst.stride, p->eobs[block]);
+ av1_idct4x4_add(dqcoeff, dst, pd->dst.stride, &inv_txfm_param);
}
}
}