Implement 64x32 and 32x64 transforms
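
Add the C implementations of the 64x32 and 32x64 inverse hybrid
transforms and hook them into the inverse-transform dispatch for both
the low and high bit-depth paths.

The identity-transform add helpers are generalized from square blocks
(a single bs argument) to rectangular ones (bsx, bsy), with the
downshift rewritten as 3 - ((pels > 256) + (pels > 1024)):

  pels <=  256 -> shift 3  (4x4 .. 16x16)
  pels <= 1024 -> shift 2  (16x32, 32x16, 32x32)
  pels >  1024 -> shift 1  (32x64, 64x32, 64x64)

This matches the previous per-size shifts for the square sizes.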
Change-Id: Ifa983d83a509cdfad78f6400df7d60c8f5b4f68c
diff --git a/av1/common/idct.c b/av1/common/idct.c
index 151ce4a..ca1c361 100644
--- a/av1/common/idct.c
+++ b/av1/common/idct.c
@@ -122,15 +122,16 @@
// Inverse identity transform and add.
#if CONFIG_EXT_TX
static void inv_idtx_add_c(const tran_low_t *input, uint8_t *dest, int stride,
- int bs, int tx_type) {
+ int bsx, int bsy, int tx_type) {
int r, c;
- const int shift = bs < 32 ? 3 : (bs < 64 ? 2 : 1);
+ const int pels = bsx * bsy;
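+ // Shift by 3 for blocks up to 256 pels (16x16), by 2 up to 1024 pels
+ // (32x32), and by 1 above that, reproducing the old square-size
+ // thresholds while also covering rectangular blocks.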
+ const int shift = 3 - ((pels > 256) + (pels > 1024));
if (tx_type == IDTX) {
- for (r = 0; r < bs; ++r) {
- for (c = 0; c < bs; ++c)
+ for (r = 0; r < bsy; ++r) {
+ for (c = 0; c < bsx; ++c)
dest[c] = clip_pixel_add(dest[c], input[c] >> shift);
dest += stride;
- input += bs;
+ input += bsx;
}
}
}
@@ -185,17 +186,19 @@
#if CONFIG_HIGHBITDEPTH
#if CONFIG_EXT_TX && CONFIG_TX64X64
static void highbd_inv_idtx_add_c(const tran_low_t *input, uint8_t *dest8,
- int stride, int bs, int tx_type, int bd) {
+ int stride, int bsx, int bsy, int tx_type,
+ int bd) {
int r, c;
- const int shift = bs < 32 ? 3 : 2;
+ const int pels = bsx * bsy;
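+ // Same shift mapping as inv_idtx_add_c. Note this changes the 64x64
+ // shift from 2 to 1, aligning the high-bitdepth path with the
+ // low-bitdepth one.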
+ const int shift = 3 - ((pels > 256) + (pels > 1024));
uint16_t *dest = CONVERT_TO_SHORTPTR(dest8);
if (tx_type == IDTX) {
- for (r = 0; r < bs; ++r) {
- for (c = 0; c < bs; ++c)
+ for (r = 0; r < bsy; ++r) {
+ for (c = 0; c < bsx; ++c)
dest[c] = highbd_clip_pixel_add(dest[c], input[c] >> shift, bd);
dest += stride;
- input += bs;
+ input += bsx;
}
}
}
@@ -1521,6 +1524,131 @@
}
}
}
+
+void av1_iht64x32_2048_add_c(const tran_low_t *input, uint8_t *dest, int stride,
+ const TxfmParam *txfm_param) {
+ const TX_TYPE tx_type = txfm_param->tx_type;
+#if CONFIG_MRC_TX
+ assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
+#endif // CONFIG_MRC_TX
+#if CONFIG_DCT_ONLY
+ assert(tx_type == DCT_DCT);
+#endif
+ static const transform_2d IHT_64x32[] = {
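+ // transform_2d is { cols, rows }: 32-point column (vertical) transforms
+ // paired with 64-point row (horizontal) transforms.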
+ { aom_idct32_c, idct64_row_c }, // DCT_DCT
+ { ihalfright32_c, idct64_row_c }, // ADST_DCT
+ { aom_idct32_c, ihalfright64_c }, // DCT_ADST
+ { ihalfright32_c, ihalfright64_c }, // ADST_ADST
+#if CONFIG_EXT_TX
+ { ihalfright32_c, idct64_row_c }, // FLIPADST_DCT
+ { aom_idct32_c, ihalfright64_c }, // DCT_FLIPADST
+ { ihalfright32_c, ihalfright64_c }, // FLIPADST_FLIPADST
+ { ihalfright32_c, ihalfright64_c }, // ADST_FLIPADST
+ { ihalfright32_c, ihalfright64_c }, // FLIPADST_ADST
+ { iidtx32_c, iidtx64_c }, // IDTX
+ { aom_idct32_c, iidtx64_c }, // V_DCT
+ { iidtx32_c, idct64_row_c }, // H_DCT
+ { ihalfright32_c, iidtx64_c }, // V_ADST
+ { iidtx32_c, ihalfright64_c }, // H_ADST
+ { ihalfright32_c, iidtx64_c }, // V_FLIPADST
+ { iidtx32_c, ihalfright64_c }, // H_FLIPADST
+#endif
+ };
+ const int n = 32;
+ const int n2 = 64;
+
+ int i, j;
+ tran_low_t out[64][32], tmp[64][32], outtmp[64];
+ tran_low_t *outp = &out[0][0];
+ int outstride = n;
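+ // out[] and tmp[] hold the block transposed: element [j][i] is row i of
+ // column j, so each 32-entry tmp[j]/out[j] is one column vector.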
+
+ // inverse transform row vectors and transpose
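+ // The extra Sqrt2 factor (applied in Q14 via dct_const_round_shift)
+ // compensates for the 2:1 aspect ratio, matching the scaling used by
+ // the other rectangular transforms such as av1_iht32x16_512_add_c.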
+ for (i = 0; i < n; ++i) {
+ IHT_64x32[tx_type].rows(input, outtmp);
+ for (j = 0; j < n2; ++j)
+ tmp[j][i] = (tran_low_t)dct_const_round_shift(outtmp[j] * Sqrt2);
+ input += n2;
+ }
+
+ // inverse transform column vectors
+ for (i = 0; i < n2; ++i) IHT_64x32[tx_type].cols(tmp[i], out[i]);
+
+#if CONFIG_EXT_TX
+ maybe_flip_strides(&dest, &stride, &outp, &outstride, tx_type, n, n2);
+#endif
+
+ // Sum with the destination
+ for (i = 0; i < n; ++i) {
+ for (j = 0; j < n2; ++j) {
+ int d = i * stride + j;
+ int s = j * outstride + i;
+ dest[d] = clip_pixel_add(dest[d], ROUND_POWER_OF_TWO(outp[s], 5));
+ }
+ }
+}
+
+void av1_iht32x64_2048_add_c(const tran_low_t *input, uint8_t *dest, int stride,
+ const TxfmParam *txfm_param) {
+ const TX_TYPE tx_type = txfm_param->tx_type;
+#if CONFIG_MRC_TX
+ assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
+#endif // CONFIG_MRC_TX
+#if CONFIG_DCT_ONLY
+ assert(tx_type == DCT_DCT);
+#endif
+ static const transform_2d IHT_32x64[] = {
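+ // { cols, rows }: 64-point column (vertical) transforms paired with
+ // 32-point row (horizontal) transforms.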
+ { idct64_col_c, aom_idct32_c }, // DCT_DCT
+ { ihalfright64_c, aom_idct32_c }, // ADST_DCT
+ { idct64_col_c, ihalfright32_c }, // DCT_ADST
+ { ihalfright64_c, ihalfright32_c }, // ADST_ADST
+#if CONFIG_EXT_TX
+ { ihalfright64_c, aom_idct32_c }, // FLIPADST_DCT
+ { idct64_col_c, ihalfright32_c }, // DCT_FLIPADST
+ { ihalfright64_c, ihalfright32_c }, // FLIPADST_FLIPADST
+ { ihalfright64_c, ihalfright32_c }, // ADST_FLIPADST
+ { ihalfright64_c, ihalfright32_c }, // FLIPADST_ADST
+ { iidtx64_c, iidtx32_c }, // IDTX
+ { idct64_col_c, iidtx32_c }, // V_DCT
+ { iidtx64_c, aom_idct32_c }, // H_DCT
+ { ihalfright64_c, iidtx32_c }, // V_ADST
+ { iidtx64_c, ihalfright32_c }, // H_ADST
+ { ihalfright64_c, iidtx32_c }, // V_FLIPADST
+ { iidtx64_c, ihalfright32_c }, // H_FLIPADST
+#endif
+ };
+
+ const int n = 32;
+ const int n2 = 64;
+ int i, j;
+ tran_low_t out[32][64], tmp[32][64], outtmp[32];
+ tran_low_t *outp = &out[0][0];
+ int outstride = n2;
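+ // Same scheme as av1_iht64x32_2048_add_c with the dimensions swapped:
+ // the 64 rows of width 32 are transformed and transposed into tmp[],
+ // whose 32 columns of height 64 then receive the column transform.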
+
+ // inverse transform row vectors and transpose
+ for (i = 0; i < n2; ++i) {
+ IHT_32x64[tx_type].rows(input, outtmp);
+ for (j = 0; j < n; ++j)
+ tmp[j][i] = (tran_low_t)dct_const_round_shift(outtmp[j] * Sqrt2);
+ input += n;
+ }
+
+ // inverse transform column vectors
+ for (i = 0; i < n; ++i) IHT_32x64[tx_type].cols(tmp[i], out[i]);
+
+#if CONFIG_EXT_TX
+ maybe_flip_strides(&dest, &stride, &outp, &outstride, tx_type, n2, n);
+#endif
+
+ // Sum with the destination
+ for (i = 0; i < n2; ++i) {
+ for (j = 0; j < n; ++j) {
+ int d = i * stride + j;
+ int s = j * outstride + i;
+ dest[d] = clip_pixel_add(dest[d], ROUND_POWER_OF_TWO(outp[s], 5));
+ }
+ }
+}
+
#endif // CONFIG_TX64X64

// idct
@@ -1743,7 +1871,7 @@
// Use C version since DST only exists in C code
av1_iht4x4_16_add_c(input, dest, stride, txfm_param);
break;
- case IDTX: inv_idtx_add_c(input, dest, stride, 4, tx_type); break;
+ case IDTX: inv_idtx_add_c(input, dest, stride, 4, 4, tx_type); break;
#endif // CONFIG_EXT_TX
default: assert(0); break;
}
@@ -1834,6 +1962,18 @@
av1_iht32x16_512_add(input, dest, stride, txfm_param);
}

+#if CONFIG_TX64X64
+static void inv_txfm_add_32x64(const tran_low_t *input, uint8_t *dest,
+ int stride, const TxfmParam *txfm_param) {
+ av1_iht32x64_2048_add(input, dest, stride, txfm_param);
+}
+
+static void inv_txfm_add_64x32(const tran_low_t *input, uint8_t *dest,
+ int stride, const TxfmParam *txfm_param) {
+ av1_iht64x32_2048_add(input, dest, stride, txfm_param);
+}
+#endif // CONFIG_TX64X64
+
static void inv_txfm_add_8x8(const tran_low_t *input, uint8_t *dest, int stride,
const TxfmParam *txfm_param) {
const TX_TYPE tx_type = txfm_param->tx_type;
@@ -1875,7 +2015,7 @@
// Use C version since DST only exists in C code
av1_iht8x8_64_add_c(input, dest, stride, txfm_param);
break;
- case IDTX: inv_idtx_add_c(input, dest, stride, 8, tx_type); break;
+ case IDTX: inv_idtx_add_c(input, dest, stride, 8, 8, tx_type); break;
#endif // CONFIG_EXT_TX
default: assert(0); break;
}
@@ -1917,7 +2057,7 @@
av1_iht16x16_256_add(input, dest, stride, txfm_param);
#endif // CONFIG_DAALA_DCT16
break;
- case IDTX: inv_idtx_add_c(input, dest, stride, 16, tx_type); break;
+ case IDTX: inv_idtx_add_c(input, dest, stride, 16, 16, tx_type); break;
#endif // CONFIG_EXT_TX
#if CONFIG_MRC_TX
case MRC_DCT: assert(0 && "Invalid tx type for tx size");
@@ -1954,7 +2094,7 @@
case H_FLIPADST:
av1_iht32x32_1024_add_c(input, dest, stride, txfm_param);
break;
- case IDTX: inv_idtx_add_c(input, dest, stride, 32, tx_type); break;
+ case IDTX: inv_idtx_add_c(input, dest, stride, 32, 32, tx_type); break;
#endif // CONFIG_EXT_TX
#if CONFIG_MRC_TX
case MRC_DCT: imrc32x32_add_c(input, dest, stride, txfm_param); break;
@@ -1990,7 +2130,7 @@
case H_FLIPADST:
av1_iht64x64_4096_add_c(input, dest, stride, txfm_param);
break;
- case IDTX: inv_idtx_add_c(input, dest, stride, 64, tx_type); break;
+ case IDTX: inv_idtx_add_c(input, dest, stride, 64, 64, tx_type); break;
#endif // CONFIG_EXT_TX
#if CONFIG_MRC_TX
case MRC_DCT: assert(0 && "Invalid tx type for tx size");
@@ -2130,6 +2270,22 @@
txfm_param->tx_type, txfm_param->bd);
}

+#if CONFIG_TX64X64
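+// These wrappers call the av1_inv_txfm2d_add_*_c functions directly,
+// presumably because only the C implementations exist for the new sizes.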
+static void highbd_inv_txfm_add_32x64(const tran_low_t *input, uint8_t *dest,
+ int stride, const TxfmParam *txfm_param) {
+ const int32_t *src = (const int32_t *)input;
+ av1_inv_txfm2d_add_32x64_c(src, CONVERT_TO_SHORTPTR(dest), stride,
+ txfm_param->tx_type, txfm_param->bd);
+}
+
+static void highbd_inv_txfm_add_64x32(const tran_low_t *input, uint8_t *dest,
+ int stride, const TxfmParam *txfm_param) {
+ const int32_t *src = (const int32_t *)input;
+ av1_inv_txfm2d_add_64x32_c(src, CONVERT_TO_SHORTPTR(dest), stride,
+ txfm_param->tx_type, txfm_param->bd);
+}
+#endif // CONFIG_TX64X64
+
static void highbd_inv_txfm_add_8x8(const tran_low_t *input, uint8_t *dest,
int stride, const TxfmParam *txfm_param) {
int bd = txfm_param->bd;
@@ -2280,7 +2436,7 @@
DCT_DCT, bd);
break;
case IDTX:
- highbd_inv_idtx_add_c(input, dest, stride, 64, tx_type, bd);
+ highbd_inv_idtx_add_c(input, dest, stride, 64, 64, tx_type, bd);
break;
#endif // CONFIG_EXT_TX
default: assert(0); break;
@@ -2304,6 +2460,10 @@
case TX_16X8: inv_txfm_add_16x8(input, dest, stride, txfm_param); break;
case TX_16X32: inv_txfm_add_16x32(input, dest, stride, txfm_param); break;
case TX_32X16: inv_txfm_add_32x16(input, dest, stride, txfm_param); break;
+#if CONFIG_TX64X64
+ case TX_64X32: inv_txfm_add_64x32(input, dest, stride, txfm_param); break;
+ case TX_32X64: inv_txfm_add_32x64(input, dest, stride, txfm_param); break;
+#endif // CONFIG_TX64X64
case TX_4X4:
// this is like av1_short_idct4x4 but has a special case around eob<=1
// which is significant (not just an optimization) for the lossless
@@ -2474,6 +2634,14 @@
case TX_32X16:
highbd_inv_txfm_add_32x16(input, dest, stride, txfm_param);
break;
+#if CONFIG_TX64X64
+ case TX_64X32:
+ highbd_inv_txfm_add_64x32(input, dest, stride, txfm_param);
+ break;
+ case TX_32X64:
+ highbd_inv_txfm_add_32x64(input, dest, stride, txfm_param);
+ break;
+#endif // CONFIG_TX64X64
case TX_4X4:
// this is like av1_short_idct4x4 but has a special case around eob<=1
// which is significant (not just an optimization) for the lossless