Adding 64x64 forward and inverse transforms
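
Adds 64x64 forward and inverse hybrid transforms (av1_fht64x64,
av1_iht64x64_4096_add, and their high-bitdepth counterparts) behind
CONFIG_TX64X64, registers the new prototypes in av1_rtcd_defs.pl, and
adjusts the forward shift and identity-transform scaling for 64x64
blocks. Also fixes the 32x32 inverse transform table to pair
ihalfright32_c with iidtx32_c (instead of iidtx16_c) for the V/H ADST
and FLIPADST entries.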

Change-Id: I213f3111fc0656aecd1303a8b871ecded2b92bc2
diff --git a/av1/common/av1_fwd_txfm2d_cfg.h b/av1/common/av1_fwd_txfm2d_cfg.h
index 49d324d..5a7c218 100644
--- a/av1/common/av1_fwd_txfm2d_cfg.h
+++ b/av1/common/av1_fwd_txfm2d_cfg.h
@@ -109,7 +109,7 @@
 };  // .txfm_type_row
 
 //  ---------------- config fwd_dct_dct_64 ----------------
-static const int8_t fwd_shift_dct_dct_64[3] = { 2, -2, -2 };
+static const int8_t fwd_shift_dct_dct_64[3] = { 0, -2, -2 };
 static const int8_t fwd_stage_range_col_dct_dct_64[12] = {
   13, 14, 15, 16, 17, 18, 19, 19, 19, 19, 19, 19
 };
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index e52dd04..af98f79 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -384,6 +384,11 @@
 add_proto qw/void av1_fht32x32/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
 specialize qw/av1_fht32x32 avx2/;
 
+if (aom_config("CONFIG_TX64X64") eq "yes") {
+  add_proto qw/void av1_fht64x64/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
+  specialize qw/av1_fht64x64/;
+}
+
 if (aom_config("CONFIG_EXT_TX") eq "yes") {
   add_proto qw/void av1_fht4x8/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
   specialize qw/av1_fht4x8 sse2/;
@@ -526,6 +531,11 @@
   add_proto qw/void av1_highbd_fht32x32/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
   specialize qw/av1_highbd_fht32x32/;
 
+  if (aom_config("CONFIG_TX64X64") eq "yes") {
+    add_proto qw/void av1_highbd_fht64x64/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
+    specialize qw/av1_highbd_fht64x64/;
+  }
+
   add_proto qw/void av1_highbd_fwht4x4/, "const int16_t *input, tran_low_t *output, int stride";
   specialize qw/av1_highbd_fwht4x4/;
 
diff --git a/av1/common/idct.c b/av1/common/idct.c
index cc20858..b5e3742 100644
--- a/av1/common/idct.c
+++ b/av1/common/idct.c
@@ -75,6 +75,47 @@
   // Note overall scaling factor is 4 times orthogonal
 }
 
+#if CONFIG_TX64X64
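+// 1-D 64-point inverse DCT helpers: wrap av1_idct64_new with the
+// column/row cos-bit and stage-range tables for DCT_DCT_64.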
+static void idct64_col_c(const tran_low_t *input, tran_low_t *output) {
+  int32_t in[64], out[64];
+  int i;
+  for (i = 0; i < 64; ++i) in[i] = (int32_t)input[i];
+  av1_idct64_new(in, out, inv_cos_bit_col_dct_dct_64,
+                 inv_stage_range_col_dct_dct_64);
+  for (i = 0; i < 64; ++i) output[i] = (tran_low_t)out[i];
+}
+
+static void idct64_row_c(const tran_low_t *input, tran_low_t *output) {
+  int32_t in[64], out[64];
+  int i;
+  for (i = 0; i < 64; ++i) in[i] = (int32_t)input[i];
+  av1_idct64_new(in, out, inv_cos_bit_row_dct_dct_64,
+                 inv_stage_range_row_dct_dct_64);
+  for (i = 0; i < 64; ++i) output[i] = (tran_low_t)out[i];
+}
+
+static void iidtx64_c(const tran_low_t *input, tran_low_t *output) {
+  int i;
+  for (i = 0; i < 64; ++i)
+    output[i] = (tran_low_t)dct_const_round_shift(input[i] * 4 * Sqrt2);
+}
+
+// For use in lieu of ADST
+static void ihalfright64_c(const tran_low_t *input, tran_low_t *output) {
+  int i;
+  tran_low_t inputhalf[32];
+  // Multiply input by sqrt(2)
+  for (i = 0; i < 32; ++i) {
+    inputhalf[i] = (tran_low_t)dct_const_round_shift(input[i] * Sqrt2);
+  }
+  for (i = 0; i < 32; ++i) {
+    output[i] = (tran_low_t)dct_const_round_shift(input[32 + i] * 4 * Sqrt2);
+  }
+  aom_idct32_c(inputhalf, output + 32);
+  // Note overall scaling factor is 4 * sqrt(2) times orthogonal
+}
+#endif  // CONFIG_TX64X64
+
 #if CONFIG_AOM_HIGHBITDEPTH
 #if CONFIG_EXT_TX
 static void highbd_iidtx4_c(const tran_low_t *input, tran_low_t *output,
@@ -122,6 +163,56 @@
   aom_highbd_idct16_c(inputhalf, output + 16, bd);
   // Note overall scaling factor is 4 times orthogonal
 }
+
+#if CONFIG_TX64X64
+static void highbd_iidtx64_c(const tran_low_t *input, tran_low_t *output,
+                             int bd) {
+  int i;
+  for (i = 0; i < 64; ++i)
+    output[i] =
+        HIGHBD_WRAPLOW(highbd_dct_const_round_shift(input[i] * 4 * Sqrt2), bd);
+}
+
+// For use in lieu of ADST
+static void highbd_ihalfright64_c(const tran_low_t *input, tran_low_t *output,
+                                  int bd) {
+  int i;
+  tran_low_t inputhalf[32];
+  // Multiply input by sqrt(2)
+  for (i = 0; i < 32; ++i) {
+    inputhalf[i] =
+        HIGHBD_WRAPLOW(highbd_dct_const_round_shift(input[i] * Sqrt2), bd);
+  }
+  for (i = 0; i < 32; ++i) {
+    output[i] = HIGHBD_WRAPLOW(
+        highbd_dct_const_round_shift(input[32 + i] * 4 * Sqrt2), bd);
+  }
+  aom_highbd_idct32_c(inputhalf, output + 32, bd);
+  // Note overall scaling factor is 4 * sqrt(2) times orthogonal
+}
+
+static void highbd_idct64_col_c(const tran_low_t *input, tran_low_t *output,
+                                int bd) {
+  int32_t in[64], out[64];
+  int i;
+  (void)bd;
+  for (i = 0; i < 64; ++i) in[i] = (int32_t)input[i];
+  av1_idct64_new(in, out, inv_cos_bit_col_dct_dct_64,
+                 inv_stage_range_col_dct_dct_64);
+  for (i = 0; i < 64; ++i) output[i] = (tran_low_t)out[i];
+}
+
+static void highbd_idct64_row_c(const tran_low_t *input, tran_low_t *output,
+                                int bd) {
+  int32_t in[64], out[64];
+  int i;
+  (void)bd;
+  for (i = 0; i < 64; ++i) in[i] = (int32_t)input[i];
+  av1_idct64_new(in, out, inv_cos_bit_row_dct_dct_64,
+                 inv_stage_range_row_dct_dct_64);
+  for (i = 0; i < 64; ++i) output[i] = (tran_low_t)out[i];
+}
+#endif  // CONFIG_TX64X64
 #endif  // CONFIG_EXT_TX
 #endif  // CONFIG_AOM_HIGHBITDEPTH
 
@@ -793,10 +884,10 @@
     { iidtx32_c, iidtx32_c },            // IDTX
     { aom_idct32_c, iidtx32_c },         // V_DCT
     { iidtx32_c, aom_idct32_c },         // H_DCT
-    { ihalfright32_c, iidtx16_c },       // V_ADST
-    { iidtx16_c, ihalfright32_c },       // H_ADST
-    { ihalfright32_c, iidtx16_c },       // V_FLIPADST
-    { iidtx16_c, ihalfright32_c },       // H_FLIPADST
+    { ihalfright32_c, iidtx32_c },       // V_ADST
+    { iidtx32_c, ihalfright32_c },       // H_ADST
+    { ihalfright32_c, iidtx32_c },       // V_FLIPADST
+    { iidtx32_c, ihalfright32_c },       // H_FLIPADST
   };
 
   int i, j;
@@ -836,6 +927,68 @@
     }
   }
 }
+
+#if CONFIG_TX64X64
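+// 64x64 hybrid inverse transform and add: row pass (with a 1-bit rounding
+// shift), transpose, column pass, then a 5-bit rounding shift when summing
+// into the destination.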
+void av1_iht64x64_4096_add_c(const tran_low_t *input, uint8_t *dest, int stride,
+                             int tx_type) {
+  static const transform_2d IHT_64[] = {
+    { idct64_col_c, idct64_row_c },      // DCT_DCT
+    { ihalfright64_c, idct64_row_c },    // ADST_DCT
+    { idct64_col_c, ihalfright64_c },    // DCT_ADST
+    { ihalfright64_c, ihalfright64_c },  // ADST_ADST
+    { ihalfright64_c, idct64_row_c },    // FLIPADST_DCT
+    { idct64_col_c, ihalfright64_c },    // DCT_FLIPADST
+    { ihalfright64_c, ihalfright64_c },  // FLIPADST_FLIPADST
+    { ihalfright64_c, ihalfright64_c },  // ADST_FLIPADST
+    { ihalfright64_c, ihalfright64_c },  // FLIPADST_ADST
+    { iidtx64_c, iidtx64_c },            // IDTX
+    { idct64_col_c, iidtx64_c },         // V_DCT
+    { iidtx64_c, idct64_row_c },         // H_DCT
+    { ihalfright64_c, iidtx64_c },       // V_ADST
+    { iidtx64_c, ihalfright64_c },       // H_ADST
+    { ihalfright64_c, iidtx64_c },       // V_FLIPADST
+    { iidtx64_c, ihalfright64_c },       // H_FLIPADST
+  };
+
+  int i, j;
+  tran_low_t tmp;
+  tran_low_t out[64][64];
+  tran_low_t *outp = &out[0][0];
+  int outstride = 64;
+
+  // inverse transform row vectors
+  for (i = 0; i < 64; ++i) {
+    IHT_64[tx_type].rows(input, out[i]);
+    for (j = 0; j < 64; ++j) out[i][j] = ROUND_POWER_OF_TWO(out[i][j], 1);
+    input += 64;
+  }
+
+  // transpose
+  for (i = 1; i < 64; i++) {
+    for (j = 0; j < i; j++) {
+      tmp = out[i][j];
+      out[i][j] = out[j][i];
+      out[j][i] = tmp;
+    }
+  }
+
+  // inverse transform column vectors
+  for (i = 0; i < 64; ++i) {
+    IHT_64[tx_type].cols(out[i], out[i]);
+  }
+
+  maybe_flip_strides(&dest, &stride, &outp, &outstride, tx_type, 64, 64);
+
+  // Sum with the destination
+  for (i = 0; i < 64; ++i) {
+    for (j = 0; j < 64; ++j) {
+      int d = i * stride + j;
+      int s = j * outstride + i;
+      dest[d] = clip_pixel_add(dest[d], ROUND_POWER_OF_TWO(outp[s], 5));
+    }
+  }
+}
+#endif  // CONFIG_TX64X64
 #endif  // CONFIG_EXT_TX
 
 // idct
@@ -1658,6 +1811,71 @@
     }
   }
 }
+
+#if CONFIG_TX64X64
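+// High-bitdepth counterpart of av1_iht64x64_4096_add_c.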
+void av1_highbd_iht64x64_4096_add_c(const tran_low_t *input, uint8_t *dest8,
+                                    int stride, int tx_type, int bd) {
+  static const highbd_transform_2d HIGH_IHT_64[] = {
+    { highbd_idct64_col_c, highbd_idct64_row_c },      // DCT_DCT
+    { highbd_ihalfright64_c, highbd_idct64_row_c },    // ADST_DCT
+    { highbd_idct64_col_c, highbd_ihalfright64_c },    // DCT_ADST
+    { highbd_ihalfright64_c, highbd_ihalfright64_c },  // ADST_ADST
+    { highbd_ihalfright64_c, highbd_idct64_row_c },    // FLIPADST_DCT
+    { highbd_idct64_col_c, highbd_ihalfright64_c },    // DCT_FLIPADST
+    { highbd_ihalfright64_c, highbd_ihalfright64_c },  // FLIPADST_FLIPADST
+    { highbd_ihalfright64_c, highbd_ihalfright64_c },  // ADST_FLIPADST
+    { highbd_ihalfright64_c, highbd_ihalfright64_c },  // FLIPADST_ADST
+    { highbd_iidtx64_c, highbd_iidtx64_c },            // IDTX
+    { highbd_idct64_col_c, highbd_iidtx64_c },         // V_DCT
+    { highbd_iidtx64_c, highbd_idct64_row_c },         // H_DCT
+    { highbd_ihalfright64_c, highbd_iidtx64_c },       // V_ADST
+    { highbd_iidtx64_c, highbd_ihalfright64_c },       // H_ADST
+    { highbd_ihalfright64_c, highbd_iidtx64_c },       // V_FLIPADST
+    { highbd_iidtx64_c, highbd_ihalfright64_c },       // H_FLIPADST
+  };
+
+  uint16_t *dest = CONVERT_TO_SHORTPTR(dest8);
+
+  int i, j;
+  tran_low_t tmp;
+  tran_low_t out[64][64];
+  tran_low_t *outp = &out[0][0];
+  int outstride = 64;
+
+  // inverse transform row vectors
+  for (i = 0; i < 64; ++i) {
+    HIGH_IHT_64[tx_type].rows(input, out[i], bd);
+    for (j = 0; j < 64; ++j) out[i][j] = ROUND_POWER_OF_TWO(out[i][j], 1);
+    input += 64;
+  }
+
+  // transpose
+  for (i = 1; i < 64; i++) {
+    for (j = 0; j < i; j++) {
+      tmp = out[i][j];
+      out[i][j] = out[j][i];
+      out[j][i] = tmp;
+    }
+  }
+
+  // inverse transform column vectors
+  for (i = 0; i < 64; ++i) {
+    HIGH_IHT_64[tx_type].cols(out[i], out[i], bd);
+  }
+
+  maybe_flip_strides16(&dest, &stride, &outp, &outstride, tx_type, 64, 64);
+
+  // Sum with the destination
+  for (i = 0; i < 64; ++i) {
+    for (j = 0; j < 64; ++j) {
+      int d = i * stride + j;
+      int s = j * outstride + i;
+      dest[d] =
+          highbd_clip_pixel_add(dest[d], ROUND_POWER_OF_TWO(outp[s], 5), bd);
+    }
+  }
+}
+#endif  // CONFIG_TX64X64
 #endif  // CONFIG_EXT_TX
 
 // idct
diff --git a/av1/encoder/dct.c b/av1/encoder/dct.c
index 221e3cd..dd4031f 100644
--- a/av1/encoder/dct.c
+++ b/av1/encoder/dct.c
@@ -18,6 +18,8 @@
 #include "aom_dsp/fwd_txfm.h"
 #include "aom_ports/mem.h"
 #include "av1/common/blockd.h"
+#include "av1/common/av1_fwd_txfm1d.h"
+#include "av1/common/av1_fwd_txfm2d_cfg.h"
 #include "av1/common/idct.h"
 
 static INLINE void range_check(const tran_low_t *input, const int size,
@@ -1874,12 +1876,103 @@
   }
 }
 
+#if CONFIG_TX64X64
+#if CONFIG_EXT_TX
+static void fidtx64(const tran_low_t *input, tran_low_t *output) {
+  int i;
+  for (i = 0; i < 64; ++i)
+    output[i] = (tran_low_t)fdct_round_shift(input[i] * 4 * Sqrt2);
+}
+
+// For use in lieu of ADST
+static void fhalfright64(const tran_low_t *input, tran_low_t *output) {
+  int i;
+  tran_low_t inputhalf[32];
+  for (i = 0; i < 32; ++i) {
+    output[32 + i] = (tran_low_t)fdct_round_shift(input[i] * 4 * Sqrt2);
+  }
+  // Multiply input by sqrt(2)
+  for (i = 0; i < 32; ++i) {
+    inputhalf[i] = (tran_low_t)fdct_round_shift(input[i + 32] * Sqrt2);
+  }
+  fdct32(inputhalf, output);
+  // Note overall scaling factor is 2 times unitary
+}
+#endif  // CONFIG_EXT_TX
+
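+// 1-D 64-point forward DCT helpers: wrap av1_fdct64_new with the
+// column/row cos-bit and stage-range tables for DCT_DCT_64.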
+static void fdct64_col(const tran_low_t *input, tran_low_t *output) {
+  int32_t in[64], out[64];
+  int i;
+  for (i = 0; i < 64; ++i) in[i] = (int32_t)input[i];
+  av1_fdct64_new(in, out, fwd_cos_bit_col_dct_dct_64,
+                 fwd_stage_range_col_dct_dct_64);
+  for (i = 0; i < 64; ++i) output[i] = (tran_low_t)out[i];
+}
+
+static void fdct64_row(const tran_low_t *input, tran_low_t *output) {
+  int32_t in[64], out[64];
+  int i;
+  for (i = 0; i < 64; ++i) in[i] = (int32_t)input[i];
+  av1_fdct64_new(in, out, fwd_cos_bit_row_dct_dct_64,
+                 fwd_stage_range_row_dct_dct_64);
+  for (i = 0; i < 64; ++i) output[i] = (tran_low_t)out[i];
+}
+
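+// 64x64 hybrid forward transform: a 1-D column pass followed by a 1-D row
+// pass, each applying a 2-bit rounding right shift to its output.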
+void av1_fht64x64_c(const int16_t *input, tran_low_t *output, int stride,
+                    int tx_type) {
+  static const transform_2d FHT[] = {
+    { fdct64_col, fdct64_row },  // DCT_DCT
+#if CONFIG_EXT_TX
+    { fhalfright64, fdct64_row },    // ADST_DCT
+    { fdct64_col, fhalfright64 },    // DCT_ADST
+    { fhalfright64, fhalfright64 },  // ADST_ADST
+    { fhalfright64, fdct64_row },    // FLIPADST_DCT
+    { fdct64_col, fhalfright64 },    // DCT_FLIPADST
+    { fhalfright64, fhalfright64 },  // FLIPADST_FLIPADST
+    { fhalfright64, fhalfright64 },  // ADST_FLIPADST
+    { fhalfright64, fhalfright64 },  // FLIPADST_ADST
+    { fidtx64, fidtx64 },            // IDTX
+    { fdct64_col, fidtx64 },         // V_DCT
+    { fidtx64, fdct64_row },         // H_DCT
+    { fhalfright64, fidtx64 },       // V_ADST
+    { fidtx64, fhalfright64 },       // H_ADST
+    { fhalfright64, fidtx64 },       // V_FLIPADST
+    { fidtx64, fhalfright64 },       // H_FLIPADST
+#endif
+  };
+  const transform_2d ht = FHT[tx_type];
+  tran_low_t out[4096];
+  int i, j;
+  tran_low_t temp_in[64], temp_out[64];
+#if CONFIG_EXT_TX
+  int16_t flipped_input[64 * 64];
+  maybe_flip_input(&input, &stride, 64, 64, flipped_input, tx_type);
+#endif
+  // Columns
+  for (i = 0; i < 64; ++i) {
+    for (j = 0; j < 64; ++j) temp_in[j] = input[j * stride + i];
+    ht.cols(temp_in, temp_out);
+    for (j = 0; j < 64; ++j)
+      out[j * 64 + i] = (temp_out[j] + 1 + (temp_out[j] > 0)) >> 2;
+  }
+
+  // Rows
+  for (i = 0; i < 64; ++i) {
+    for (j = 0; j < 64; ++j) temp_in[j] = out[j + i * 64];
+    ht.rows(temp_in, temp_out);
+    for (j = 0; j < 64; ++j)
+      output[j + i * 64] =
+          (tran_low_t)((temp_out[j] + 1 + (temp_out[j] < 0)) >> 2);
+  }
+}
+#endif  // CONFIG_TX64X64
+
 #if CONFIG_EXT_TX
 // Forward identity transform.
 void av1_fwd_idtx_c(const int16_t *src_diff, tran_low_t *coeff, int stride,
                     int bs, int tx_type) {
   int r, c;
-  const int shift = bs < 32 ? 3 : 2;
+  const int shift = bs < 32 ? 3 : (bs < 64 ? 2 : 1);
   if (tx_type == IDTX) {
     for (r = 0; r < bs; ++r) {
       for (c = 0; c < bs; ++c) coeff[c] = src_diff[c] * (1 << shift);
@@ -1894,5 +1987,12 @@
                            int tx_type) {
   av1_fht32x32_c(input, output, stride, tx_type);
 }
+
+#if CONFIG_TX64X64
+void av1_highbd_fht64x64_c(const int16_t *input, tran_low_t *output, int stride,
+                           int tx_type) {
+  av1_fht64x64_c(input, output, stride, tx_type);
+}
+#endif  // CONFIG_TX64X64
 #endif  // CONFIG_AOM_HIGHBITDEPTH
 #endif  // CONFIG_EXT_TX