Add inv txfm2d 64 sse2

Implement av1_lowbd_inv_txfm2d_add_64x64_sse2
Implement av1_lowbd_inv_txfm2d_add_32x64_sse2
Implement av1_lowbd_inv_txfm2d_add_64x32_sse2
Implement av1_lowbd_inv_txfm2d_add_16x64_sse2
Implement av1_lowbd_inv_txfm2d_add_64x16_sse2
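
All of these reuse lowbd_inv_txfm2d_add_internal_sse2: for transforms
with a 64-point dimension the input carries nonzero coefficients only
in its top-left 32x32 (or 32xN / Nx32) block, so each wrapper copies
that block into a zero-padded buffer of the full transform size and
dispatches to the shared path.

iidentity64_new_sse2 is the missing IDTX kernel: it scales every
coefficient by 4 * sqrt(2) in fixed point. For reference, a scalar
model of what one SIMD lane computes (illustration only; this helper is
not part of the change), assuming NewSqrt2 = 5793 and NewSqrt2Bits = 12
as defined in av1/common/av1_txfm.h:

  #include <stdint.h>

  static int16_t iidentity64_scalar(int16_t x) {
    const int32_t kNewSqrt2 = 5793;  // round(sqrt(2) * 2^12)
    const int kNewSqrt2Bits = 12;
    // x * 4*sqrt(2) with rounding, matching the madd/srai sequence.
    int32_t v =
        (x * 4 * kNewSqrt2 + (1 << (kNewSqrt2Bits - 1))) >> kNewSqrt2Bits;
    // Saturate to the int16 range, like _mm_packs_epi32.
    if (v > INT16_MAX) v = INT16_MAX;
    if (v < INT16_MIN) v = INT16_MIN;
    return (int16_t)v;
  }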

Change-Id: I1b27618f153583cc787e7bf6ef1616e7c6932990
diff --git a/av1/common/x86/av1_inv_txfm_sse2.c b/av1/common/x86/av1_inv_txfm_sse2.c
index 5e94195..50e0c4b 100644
--- a/av1/common/x86/av1_inv_txfm_sse2.c
+++ b/av1/common/x86/av1_inv_txfm_sse2.c
@@ -1634,6 +1634,28 @@
   }
 }
 
+#if CONFIG_TX64X64
+static void iidentity64_new_sse2(const __m128i *input, __m128i *output,
+                                 int8_t cos_bit) {
+  (void)cos_bit;
+  const __m128i scale = _mm_set1_epi16(4 * NewSqrt2);
+  const __m128i rounding = _mm_set1_epi16(1 << (NewSqrt2Bits - 1));
+  const __m128i one = _mm_set1_epi16(1);
+  const __m128i scale_rounding = _mm_unpacklo_epi16(scale, rounding);
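+  // Interleaving each coefficient with the constant 1 lets one
+  // _mm_madd_epi16 compute x * scale + rounding per 32-bit lane.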
+  for (int i = 0; i < 64; ++i) {
+    __m128i a_lo = _mm_unpacklo_epi16(input[i], one);
+    __m128i a_hi = _mm_unpackhi_epi16(input[i], one);
+    __m128i b_lo = _mm_madd_epi16(a_lo, scale_rounding);
+    __m128i b_hi = _mm_madd_epi16(a_hi, scale_rounding);
+    __m128i c_lo = _mm_srai_epi32(b_lo, NewSqrt2Bits);
+    __m128i c_hi = _mm_srai_epi32(b_hi, NewSqrt2Bits);
+    output[i] = _mm_packs_epi32(c_lo, c_hi);
+  }
+}
+#endif
+
 static INLINE __m128i lowbd_get_recon_8x8_sse2(const __m128i pred,
                                                __m128i res) {
   const __m128i zero = _mm_setzero_si128();
@@ -1659,7 +1681,7 @@
   { idct16_new_sse2, iadst16_new_sse2, iadst16_new_sse2, iidentity16_new_sse2 },
   { idct32_new_sse2, NULL, NULL, iidentity32_new_sse2 },
 #if CONFIG_TX64X64
-  { idct64_new_sse2, NULL, NULL, NULL },
+  { idct64_new_sse2, NULL, NULL, iidentity64_new_sse2 },
 #endif
 };
 
@@ -1816,6 +1838,28 @@
   lowbd_inv_txfm2d_add_internal_sse2(input, output, stride, tx_type, TX_32X32);
 }
 
+#if CONFIG_TX64X64
+void av1_lowbd_inv_txfm2d_add_64x64_sse2(const int32_t *input, uint8_t *output,
+                                         int stride, TX_TYPE tx_type, int bd) {
+  (void)bd;
+  // TODO(binpengsmail@gmail.com):
+  // Potential optimization: take advantage of the zeros outside the
+  // top-left block, here and for every other TX_SIZE with a dimension of 64.
+
+  // Remap the 32x32 input into a modified 64x64 input by:
+  // - copying the input values into the top-left 32x32 locations,
+  // - setting all remaining locations to 0.
+  DECLARE_ALIGNED(32, int32_t, mod_input[64 * 64]);
+  for (int row = 0; row < 32; ++row) {
+    memcpy(mod_input + row * 64, input + row * 32, 32 * sizeof(*mod_input));
+    memset(mod_input + row * 64 + 32, 0, 32 * sizeof(*mod_input));
+  }
+  memset(mod_input + 32 * 64, 0, 32 * 64 * sizeof(*mod_input));
+  lowbd_inv_txfm2d_add_internal_sse2(mod_input, output, stride, tx_type,
+                                     TX_64X64);
+}
+#endif
+
 void av1_lowbd_inv_txfm2d_add_8x16_sse2(const int32_t *input, uint8_t *output,
                                         int stride, TX_TYPE tx_type, int bd) {
   (void)bd;
@@ -1840,6 +1884,36 @@
   lowbd_inv_txfm2d_add_internal_sse2(input, output, stride, tx_type, TX_32X16);
 }
 
+#if CONFIG_TX64X64
+void av1_lowbd_inv_txfm2d_add_32x64_sse2(const int32_t *input, uint8_t *output,
+                                         int stride, TX_TYPE tx_type, int bd) {
+  (void)bd;
+  // Remap the 32x32 input into a modified 32x64 input by:
+  // - copying the input values into the top-left 32x32 locations,
+  // - setting all remaining locations to 0.
+  DECLARE_ALIGNED(32, int32_t, mod_input[32 * 64]);
+  memcpy(mod_input, input, 32 * 32 * sizeof(*mod_input));
+  memset(mod_input + 32 * 32, 0, 32 * 32 * sizeof(*mod_input));
+  lowbd_inv_txfm2d_add_internal_sse2(mod_input, output, stride, tx_type,
+                                     TX_32X64);
+}
+
+void av1_lowbd_inv_txfm2d_add_64x32_sse2(const int32_t *input, uint8_t *output,
+                                         int stride, TX_TYPE tx_type, int bd) {
+  (void)bd;
+  // Remap the 32x32 input into a modified 64x32 input by:
+  // - copying the input values into the top-left 32x32 locations,
+  // - setting all remaining locations to 0.
+  DECLARE_ALIGNED(32, int32_t, mod_input[64 * 32]);
+  for (int row = 0; row < 32; ++row) {
+    memcpy(mod_input + row * 64, input + row * 32, 32 * sizeof(*mod_input));
+    memset(mod_input + row * 64 + 32, 0, 32 * sizeof(*mod_input));
+  }
+  lowbd_inv_txfm2d_add_internal_sse2(mod_input, output, stride, tx_type,
+                                     TX_64X32);
+}
+#endif
+
 void av1_lowbd_inv_txfm2d_add_8x32_sse2(const int32_t *input, uint8_t *output,
                                         int stride, TX_TYPE tx_type, int bd) {
   (void)bd;
@@ -1851,6 +1925,36 @@
   (void)bd;
   lowbd_inv_txfm2d_add_internal_sse2(input, output, stride, tx_type, TX_32X8);
 }
+
+#if CONFIG_TX64X64
+void av1_lowbd_inv_txfm2d_add_16x64_sse2(const int32_t *input, uint8_t *output,
+                                         int stride, TX_TYPE tx_type, int bd) {
+  (void)bd;
+  // Remap the 16x32 input into a modified 16x64 input by:
+  // - copying the input values into the top-left 16x32 locations,
+  // - setting all remaining locations to 0.
+  DECLARE_ALIGNED(32, int32_t, mod_input[16 * 64]);
+  memcpy(mod_input, input, 16 * 32 * sizeof(*mod_input));
+  memset(mod_input + 16 * 32, 0, 16 * 32 * sizeof(*mod_input));
+  lowbd_inv_txfm2d_add_internal_sse2(mod_input, output, stride, tx_type,
+                                     TX_16X64);
+}
+
+void av1_lowbd_inv_txfm2d_add_64x16_sse2(const int32_t *input, uint8_t *output,
+                                         int stride, TX_TYPE tx_type, int bd) {
+  (void)bd;
+  // Remap the 32x16 input into a modified 64x16 input by:
+  // - copying the input values into the top-left 32x16 locations,
+  // - setting all remaining locations to 0.
+  DECLARE_ALIGNED(32, int32_t, mod_input[64 * 16]);
+  for (int row = 0; row < 16; ++row) {
+    memcpy(mod_input + row * 64, input + row * 32, 32 * sizeof(*mod_input));
+    memset(mod_input + row * 64 + 32, 0, 32 * sizeof(*mod_input));
+  }
+  lowbd_inv_txfm2d_add_internal_sse2(mod_input, output, stride, tx_type,
+                                     TX_64X16);
+}
+#endif
 
 typedef void (*inv_txfm_func)(const int32_t *input, uint8_t *output, int stride,
                               TX_TYPE tx_type, int bd);
@@ -1861,7 +1965,7 @@
   av1_lowbd_inv_txfm2d_add_16x16_sse2,  // 16x16
   av1_lowbd_inv_txfm2d_add_32x32_sse2,  // 32x32
 #if CONFIG_TX64X64
-  NULL,                                 // 64x64
+  av1_lowbd_inv_txfm2d_add_64x64_sse2,  // 64x64
 #endif                                  // CONFIG_TX64X64
   NULL,                                 // 4x8
   NULL,                                 // 8x4
@@ -1870,17 +1974,17 @@
   av1_lowbd_inv_txfm2d_add_16x32_sse2,  // 16x32
   av1_lowbd_inv_txfm2d_add_32x16_sse2,  // 32x16
 #if CONFIG_TX64X64
-  NULL,                                // 32x64
-  NULL,                                // 64x32
-#endif                                 // CONFIG_TX64X64
-  NULL,                                // 4x16
-  NULL,                                // 16x4
-  av1_lowbd_inv_txfm2d_add_8x32_sse2,  // 8x32
-  av1_lowbd_inv_txfm2d_add_32x8_sse2,  // 32x8
+  av1_lowbd_inv_txfm2d_add_32x64_sse2,  // 32x64
+  av1_lowbd_inv_txfm2d_add_64x32_sse2,  // 64x32
+#endif                                  // CONFIG_TX64X64
+  NULL,                                 // 4x16
+  NULL,                                 // 16x4
+  av1_lowbd_inv_txfm2d_add_8x32_sse2,   // 8x32
+  av1_lowbd_inv_txfm2d_add_32x8_sse2,   // 32x8
 #if CONFIG_TX64X64
-  NULL,  // 16x64
-  NULL,  // 64x16
-#endif   // CONFIG_TX64X64
+  av1_lowbd_inv_txfm2d_add_16x64_sse2,  // 16x64
+  av1_lowbd_inv_txfm2d_add_64x16_sse2,  // 64x16
+#endif                                  // CONFIG_TX64X64
 };
 
 void av1_inv_txfm_add_sse2(const tran_low_t *dqcoeff, uint8_t *dst, int stride,
diff --git a/av1/common/x86/av1_txfm_sse2.h b/av1/common/x86/av1_txfm_sse2.h
index a10346b..a924003 100644
--- a/av1/common/x86/av1_txfm_sse2.h
+++ b/av1/common/x86/av1_txfm_sse2.h
@@ -203,6 +203,11 @@
 void av1_lowbd_inv_txfm2d_add_32x32_sse2(const int32_t *input, uint8_t *output,
                                          int stride, TX_TYPE tx_type, int bd);
 
+#if CONFIG_TX64X64
+void av1_lowbd_inv_txfm2d_add_64x64_sse2(const int32_t *input, uint8_t *output,
+                                         int stride, TX_TYPE tx_type, int bd);
+#endif
+
 void av1_lowbd_inv_txfm2d_add_8x16_sse2(const int32_t *input, uint8_t *output,
                                         int stride, TX_TYPE tx_type, int bd);
 
@@ -215,12 +220,28 @@
 void av1_lowbd_inv_txfm2d_add_32x16_sse2(const int32_t *input, uint8_t *output,
                                          int stride, TX_TYPE tx_type, int bd);
 
+#if CONFIG_TX64X64
+void av1_lowbd_inv_txfm2d_add_32x64_sse2(const int32_t *input, uint8_t *output,
+                                         int stride, TX_TYPE tx_type, int bd);
+
+void av1_lowbd_inv_txfm2d_add_64x32_sse2(const int32_t *input, uint8_t *output,
+                                         int stride, TX_TYPE tx_type, int bd);
+#endif
+
 void av1_lowbd_inv_txfm2d_add_8x32_sse2(const int32_t *input, uint8_t *output,
                                         int stride, TX_TYPE tx_type, int bd);
 
 void av1_lowbd_inv_txfm2d_add_32x8_sse2(const int32_t *input, uint8_t *output,
                                         int stride, TX_TYPE tx_type, int bd);
 
+#if CONFIG_TX64X64
+void av1_lowbd_inv_txfm2d_add_16x64_sse2(const int32_t *input, uint8_t *output,
+                                         int stride, TX_TYPE tx_type, int bd);
+
+void av1_lowbd_inv_txfm2d_add_64x16_sse2(const int32_t *input, uint8_t *output,
+                                         int stride, TX_TYPE tx_type, int bd);
+#endif
+
 #ifdef __cplusplus
 }
 #endif  // __cplusplus
diff --git a/test/av1_inv_txfm2d_test.cc b/test/av1_inv_txfm2d_test.cc
index e058222..5d56cdc 100644
--- a/test/av1_inv_txfm2d_test.cc
+++ b/test/av1_inv_txfm2d_test.cc
@@ -235,9 +235,10 @@
     const int cols = tx_size_high[tx_size_];
     const TX_TYPE_1D vtype = vtx_tab[tx_type];
     const TX_TYPE_1D htype = htx_tab[tx_type];
-    if (rows == 32 && (htype == ADST_1D || htype == FLIPADST_1D)) {
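+    // ADST/FLIPADST is not supported for 1-D sizes of 32 or 64.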
+    if (rows >= 32 && (htype == ADST_1D || htype == FLIPADST_1D)) {
       return false;
-    } else if (cols == 32 && (vtype == ADST_1D || vtype == FLIPADST_1D)) {
+    } else if (cols >= 32 && (vtype == ADST_1D || vtype == FLIPADST_1D)) {
       return false;
     }
     return true;
@@ -334,7 +335,7 @@
   av1_lowbd_inv_txfm2d_add_16x16_sse2,  // TX_16X16
   av1_lowbd_inv_txfm2d_add_32x32_sse2,  // TX_32X32
 #if CONFIG_TX64X64
-  NULL,                                 // TX_64X64
+  av1_lowbd_inv_txfm2d_add_64x64_sse2,  // TX_64X64
 #endif                                  // CONFIG_TX64X64
   NULL,                                 // TX_4X8
   NULL,                                 // TX_8X4
@@ -343,17 +344,17 @@
   av1_lowbd_inv_txfm2d_add_16x32_sse2,  // TX_16X32
   av1_lowbd_inv_txfm2d_add_32x16_sse2,  // TX_32X16
 #if CONFIG_TX64X64
-  NULL,                                // TX_32X64
-  NULL,                                // TX_64X32
-#endif                                 // CONFIG_TX64X64
-  NULL,                                // TX_4X16
-  NULL,                                // TX_16X4
-  av1_lowbd_inv_txfm2d_add_8x32_sse2,  // 8x32
-  av1_lowbd_inv_txfm2d_add_32x8_sse2,  // 32x8
+  av1_lowbd_inv_txfm2d_add_32x64_sse2,  // TX_32X64
+  av1_lowbd_inv_txfm2d_add_64x32_sse2,  // TX_64X32
+#endif                                  // CONFIG_TX64X64
+  NULL,                                 // TX_4X16
+  NULL,                                 // TX_16X4
+  av1_lowbd_inv_txfm2d_add_8x32_sse2,   // TX_8X32
+  av1_lowbd_inv_txfm2d_add_32x8_sse2,   // TX_32X8
 #if CONFIG_TX64X64
-  NULL,  // TX_16X64
-  NULL,  // TX_64X16
-#endif   // CONFIG_TX64X64
+  av1_lowbd_inv_txfm2d_add_16x64_sse2,  // TX_16X64
+  av1_lowbd_inv_txfm2d_add_64x16_sse2,  // TX_64X16
+#endif                                  // CONFIG_TX64X64
 };
 INSTANTIATE_TEST_CASE_P(SSE2, AV1LbdInvTxfm2d,
                         Combine(Values(kLbdInvFuncSSE2List),