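Implement 64x32 and 32x64 transforms

Add the low-bitdepth inverse transforms av1_iht64x32_2048_add_c and
av1_iht32x64_2048_add_c, route the high-bitdepth 64x32/32x64 sizes to
av1_inv_txfm2d_add_{64x32,32x64}_c, and dispatch TX_64X32/TX_32X64 in
both inverse-transform switches (all under CONFIG_TX64X64).

inv_idtx_add_c and highbd_inv_idtx_add_c now take separate width and
height. The high-bitdepth identity shift is now derived from the block
area in the same way as the low-bitdepth one, so 64x64 high-bitdepth
IDTX uses a shift of 1 instead of 2.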

Change-Id: Ifa983d83a509cdfad78f6400df7d60c8f5b4f68c
diff --git a/av1/common/idct.c b/av1/common/idct.c
index 151ce4a..ca1c361 100644
--- a/av1/common/idct.c
+++ b/av1/common/idct.c
@@ -122,15 +122,16 @@
 // Inverse identity transform and add.
 #if CONFIG_EXT_TX
 static void inv_idtx_add_c(const tran_low_t *input, uint8_t *dest, int stride,
-                           int bs, int tx_type) {
+                           int bsx, int bsy, int tx_type) {
   int r, c;
-  const int shift = bs < 32 ? 3 : (bs < 64 ? 2 : 1);
+  const int pels = bsx * bsy;  // shift: 3 up to 256 pels, 2 up to 1024, else 1
+  const int shift = 3 - ((pels > 256) + (pels > 1024));
   if (tx_type == IDTX) {
-    for (r = 0; r < bs; ++r) {
-      for (c = 0; c < bs; ++c)
+    for (r = 0; r < bsy; ++r) {  // bsy rows of bsx pixels
+      for (c = 0; c < bsx; ++c)
         dest[c] = clip_pixel_add(dest[c], input[c] >> shift);
       dest += stride;
-      input += bs;
+      input += bsx;  // coefficients are packed bsx per row
     }
   }
 }
@@ -185,17 +186,19 @@
 #if CONFIG_HIGHBITDEPTH
 #if CONFIG_EXT_TX && CONFIG_TX64X64
 static void highbd_inv_idtx_add_c(const tran_low_t *input, uint8_t *dest8,
-                                  int stride, int bs, int tx_type, int bd) {
+                                  int stride, int bsx, int bsy, int tx_type,
+                                  int bd) {
   int r, c;
-  const int shift = bs < 32 ? 3 : 2;
+  const int pels = bsx * bsy;  // shift: 3 up to 256 pels, 2 up to 1024, else 1
+  const int shift = 3 - ((pels > 256) + (pels > 1024));
   uint16_t *dest = CONVERT_TO_SHORTPTR(dest8);
 
   if (tx_type == IDTX) {
-    for (r = 0; r < bs; ++r) {
-      for (c = 0; c < bs; ++c)
+    for (r = 0; r < bsy; ++r) {  // bsy rows of bsx pixels
+      for (c = 0; c < bsx; ++c)
         dest[c] = highbd_clip_pixel_add(dest[c], input[c] >> shift, bd);
       dest += stride;
-      input += bs;
+      input += bsx;  // coefficients are packed bsx per row
     }
   }
 }
@@ -1521,6 +1524,131 @@
     }
   }
 }
+
+// Inverse 64x32 (width x height) 2-D transform of the 2048 coefficients in
+// input; the result is added to dest (row pitch stride) with clip_pixel_add.
+void av1_iht64x32_2048_add_c(const tran_low_t *input, uint8_t *dest, int stride,
+                             const TxfmParam *txfm_param) {
+  const TX_TYPE tx_type = txfm_param->tx_type;
+#if CONFIG_MRC_TX
+  assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
+#endif  // CONFIG_MRC_TX
+#if CONFIG_DCT_ONLY
+  assert(tx_type == DCT_DCT);
+#endif
+  static const transform_2d IHT_64x32[] = {  // { cols, rows } per tx_type
+    { aom_idct32_c, idct64_row_c },      // DCT_DCT
+    { ihalfright32_c, idct64_row_c },    // ADST_DCT
+    { aom_idct32_c, ihalfright64_c },    // DCT_ADST
+    { ihalfright32_c, ihalfright64_c },  // ADST_ADST
+#if CONFIG_EXT_TX
+    { ihalfright32_c, idct64_row_c },    // FLIPADST_DCT
+    { aom_idct32_c, ihalfright64_c },    // DCT_FLIPADST
+    { ihalfright32_c, ihalfright64_c },  // FLIPADST_FLIPADST
+    { ihalfright32_c, ihalfright64_c },  // ADST_FLIPADST
+    { ihalfright32_c, ihalfright64_c },  // FLIPADST_ADST
+    { iidtx32_c, iidtx64_c },            // IDTX
+    { aom_idct32_c, iidtx64_c },         // V_DCT
+    { iidtx32_c, idct64_row_c },         // H_DCT
+    { ihalfright32_c, iidtx64_c },       // V_ADST
+    { iidtx32_c, ihalfright64_c },       // H_ADST
+    { ihalfright32_c, iidtx64_c },       // V_FLIPADST
+    { iidtx32_c, ihalfright64_c },       // H_FLIPADST
+#endif
+  };
+  const int n = 32, n2 = 64;  // n: rows (height), n2: coefficients per row
+  int i, j;
+  tran_low_t out[64][32], tmp[64][32], outtmp[64];  // out, tmp: [col][row]
+  tran_low_t *outp = &out[0][0];
+  int outstride = n;
+
+  // 1-D row transforms; scale by Sqrt2 and store transposed in tmp.
+  for (i = 0; i < n; ++i) {
+    IHT_64x32[tx_type].rows(input, outtmp);
+    for (j = 0; j < n2; ++j)
+      tmp[j][i] = (tran_low_t)dct_const_round_shift(outtmp[j] * Sqrt2);
+    input += n2;
+  }
+
+  // 1-D column transforms on the 64 transposed vectors (32 entries each).
+  for (i = 0; i < n2; ++i) IHT_64x32[tx_type].cols(tmp[i], out[i]);
+
+#if CONFIG_EXT_TX
+  maybe_flip_strides(&dest, &stride, &outp, &outstride, tx_type, n, n2);
+#endif
+
+  // Round by 5 bits and add to dest; outp holds the block as [col][row].
+  for (i = 0; i < n; ++i) {
+    for (j = 0; j < n2; ++j) {
+      int d = i * stride + j;
+      int s = j * outstride + i;
+      dest[d] = clip_pixel_add(dest[d], ROUND_POWER_OF_TWO(outp[s], 5));
+    }
+  }
+}
+
+// Inverse 32x64 (width x height) 2-D transform of the 2048 coefficients in
+// input; the result is added to dest (row pitch stride) with clip_pixel_add.
+void av1_iht32x64_2048_add_c(const tran_low_t *input, uint8_t *dest, int stride,
+                             const TxfmParam *txfm_param) {
+  const TX_TYPE tx_type = txfm_param->tx_type;
+#if CONFIG_MRC_TX
+  assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
+#endif  // CONFIG_MRC_TX
+#if CONFIG_DCT_ONLY
+  assert(tx_type == DCT_DCT);
+#endif
+  static const transform_2d IHT_32x64[] = {  // { cols, rows } per tx_type
+    { idct64_col_c, aom_idct32_c },      // DCT_DCT
+    { ihalfright64_c, aom_idct32_c },    // ADST_DCT
+    { idct64_col_c, ihalfright32_c },    // DCT_ADST
+    { ihalfright64_c, ihalfright32_c },  // ADST_ADST
+#if CONFIG_EXT_TX
+    { ihalfright64_c, aom_idct32_c },    // FLIPADST_DCT
+    { idct64_col_c, ihalfright32_c },    // DCT_FLIPADST
+    { ihalfright64_c, ihalfright32_c },  // FLIPADST_FLIPADST
+    { ihalfright64_c, ihalfright32_c },  // ADST_FLIPADST
+    { ihalfright64_c, ihalfright32_c },  // FLIPADST_ADST
+    { iidtx64_c, iidtx32_c },            // IDTX
+    { idct64_col_c, iidtx32_c },         // V_DCT
+    { iidtx64_c, aom_idct32_c },         // H_DCT
+    { ihalfright64_c, iidtx32_c },       // V_ADST
+    { iidtx64_c, ihalfright32_c },       // H_ADST
+    { ihalfright64_c, iidtx32_c },       // V_FLIPADST
+    { iidtx64_c, ihalfright32_c },       // H_FLIPADST
+#endif
+  };
+  const int n = 32, n2 = 64;  // n: coefficients per row, n2: rows (height)
+  int i, j;
+  tran_low_t out[32][64], tmp[32][64], outtmp[32];  // out, tmp: [col][row]
+  tran_low_t *outp = &out[0][0];
+  int outstride = n2;
+
+  // 1-D row transforms; scale by Sqrt2 and store transposed in tmp.
+  for (i = 0; i < n2; ++i) {
+    IHT_32x64[tx_type].rows(input, outtmp);
+    for (j = 0; j < n; ++j)
+      tmp[j][i] = (tran_low_t)dct_const_round_shift(outtmp[j] * Sqrt2);
+    input += n;
+  }
+
+  // 1-D column transforms on the 32 transposed vectors (64 entries each).
+  for (i = 0; i < n; ++i) IHT_32x64[tx_type].cols(tmp[i], out[i]);
+
+#if CONFIG_EXT_TX
+  maybe_flip_strides(&dest, &stride, &outp, &outstride, tx_type, n2, n);
+#endif
+
+  // Round by 5 bits and add to dest; outp holds the block as [col][row].
+  for (i = 0; i < n2; ++i) {
+    for (j = 0; j < n; ++j) {
+      int d = i * stride + j;
+      int s = j * outstride + i;
+      dest[d] = clip_pixel_add(dest[d], ROUND_POWER_OF_TWO(outp[s], 5));
+    }
+  }
+}
+
 #endif  // CONFIG_TX64X64
 
 // idct
@@ -1743,7 +1871,7 @@
       // Use C version since DST only exists in C code
       av1_iht4x4_16_add_c(input, dest, stride, txfm_param);
       break;
-    case IDTX: inv_idtx_add_c(input, dest, stride, 4, tx_type); break;
+    case IDTX: inv_idtx_add_c(input, dest, stride, 4, 4, tx_type); break;
 #endif  // CONFIG_EXT_TX
     default: assert(0); break;
   }
@@ -1834,6 +1962,18 @@
   av1_iht32x16_512_add(input, dest, stride, txfm_param);
 }
 
+#if CONFIG_TX64X64
+static void inv_txfm_add_64x32(const tran_low_t *input, uint8_t *dest,
+                               int stride, const TxfmParam *txfm_param) {
+  av1_iht64x32_2048_add(input, dest, stride, txfm_param);  // 64 wide, 32 tall
+}
+
+static void inv_txfm_add_32x64(const tran_low_t *input, uint8_t *dest,
+                               int stride, const TxfmParam *txfm_param) {
+  av1_iht32x64_2048_add(input, dest, stride, txfm_param);  // 32 wide, 64 tall
+}
+#endif  // CONFIG_TX64X64
+
 static void inv_txfm_add_8x8(const tran_low_t *input, uint8_t *dest, int stride,
                              const TxfmParam *txfm_param) {
   const TX_TYPE tx_type = txfm_param->tx_type;
@@ -1875,7 +2015,7 @@
       // Use C version since DST only exists in C code
       av1_iht8x8_64_add_c(input, dest, stride, txfm_param);
       break;
-    case IDTX: inv_idtx_add_c(input, dest, stride, 8, tx_type); break;
+    case IDTX: inv_idtx_add_c(input, dest, stride, 8, 8, tx_type); break;
 #endif  // CONFIG_EXT_TX
     default: assert(0); break;
   }
@@ -1917,7 +2057,7 @@
       av1_iht16x16_256_add(input, dest, stride, txfm_param);
 #endif  // CONFIG_DAALA_DCT16
       break;
-    case IDTX: inv_idtx_add_c(input, dest, stride, 16, tx_type); break;
+    case IDTX: inv_idtx_add_c(input, dest, stride, 16, 16, tx_type); break;
 #endif  // CONFIG_EXT_TX
 #if CONFIG_MRC_TX
     case MRC_DCT: assert(0 && "Invalid tx type for tx size");
@@ -1954,7 +2094,7 @@
     case H_FLIPADST:
       av1_iht32x32_1024_add_c(input, dest, stride, txfm_param);
       break;
-    case IDTX: inv_idtx_add_c(input, dest, stride, 32, tx_type); break;
+    case IDTX: inv_idtx_add_c(input, dest, stride, 32, 32, tx_type); break;
 #endif  // CONFIG_EXT_TX
 #if CONFIG_MRC_TX
     case MRC_DCT: imrc32x32_add_c(input, dest, stride, txfm_param); break;
@@ -1990,7 +2130,7 @@
     case H_FLIPADST:
       av1_iht64x64_4096_add_c(input, dest, stride, txfm_param);
       break;
-    case IDTX: inv_idtx_add_c(input, dest, stride, 64, tx_type); break;
+    case IDTX: inv_idtx_add_c(input, dest, stride, 64, 64, tx_type); break;
 #endif  // CONFIG_EXT_TX
 #if CONFIG_MRC_TX
     case MRC_DCT: assert(0 && "Invalid tx type for tx size");
@@ -2130,6 +2270,22 @@
                              txfm_param->tx_type, txfm_param->bd);
 }
 
+#if CONFIG_TX64X64
+static void highbd_inv_txfm_add_64x32(const tran_low_t *input, uint8_t *dest,
+                                      int stride, const TxfmParam *txfm_param) {
+  // Coefficients are reinterpreted as const int32_t * for the 2-D kernel.
+  av1_inv_txfm2d_add_64x32_c((const int32_t *)input, CONVERT_TO_SHORTPTR(dest),
+                             stride, txfm_param->tx_type, txfm_param->bd);
+}
+
+static void highbd_inv_txfm_add_32x64(const tran_low_t *input, uint8_t *dest,
+                                      int stride, const TxfmParam *txfm_param) {
+  // dest holds high-bitdepth pixels; CONVERT_TO_SHORTPTR yields uint16_t *.
+  av1_inv_txfm2d_add_32x64_c((const int32_t *)input, CONVERT_TO_SHORTPTR(dest),
+                             stride, txfm_param->tx_type, txfm_param->bd);
+}
+#endif  // CONFIG_TX64X64
+
 static void highbd_inv_txfm_add_8x8(const tran_low_t *input, uint8_t *dest,
                                     int stride, const TxfmParam *txfm_param) {
   int bd = txfm_param->bd;
@@ -2280,7 +2436,7 @@
                                  DCT_DCT, bd);
       break;
     case IDTX:
-      highbd_inv_idtx_add_c(input, dest, stride, 64, tx_type, bd);
+      highbd_inv_idtx_add_c(input, dest, stride, 64, 64, tx_type, bd);
       break;
 #endif  // CONFIG_EXT_TX
     default: assert(0); break;
@@ -2304,6 +2460,10 @@
     case TX_16X8: inv_txfm_add_16x8(input, dest, stride, txfm_param); break;
     case TX_16X32: inv_txfm_add_16x32(input, dest, stride, txfm_param); break;
     case TX_32X16: inv_txfm_add_32x16(input, dest, stride, txfm_param); break;
+#if CONFIG_TX64X64
+    case TX_64X32: inv_txfm_add_64x32(input, dest, stride, txfm_param); break;
+    case TX_32X64: inv_txfm_add_32x64(input, dest, stride, txfm_param); break;
+#endif  // CONFIG_TX64X64
     case TX_4X4:
       // this is like av1_short_idct4x4 but has a special case around eob<=1
       // which is significant (not just an optimization) for the lossless
@@ -2474,6 +2634,14 @@
     case TX_32X16:
       highbd_inv_txfm_add_32x16(input, dest, stride, txfm_param);
       break;
+#if CONFIG_TX64X64
+    case TX_64X32:
+      highbd_inv_txfm_add_64x32(input, dest, stride, txfm_param);
+      break;
+    case TX_32X64:
+      highbd_inv_txfm_add_32x64(input, dest, stride, txfm_param);
+      break;
+#endif  // CONFIG_TX64X64
     case TX_4X4:
       // this is like av1_short_idct4x4 but has a special case around eob<=1
       // which is significant (not just an optimization) for the lossless