Add eob logic for highbd 64x64 inv_txfm

When tested for 10 frames of crowd_run_1080p_10
for speed=1 preset, observed 0.15% reduction.

Achieved module level time reduction by ~13%.

Code clean-up of functions which are unused.

Change-Id: I83ee2f27270f0b9619f8ecadffd778b5bdc82b01
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 988a172..ccb4f8c 100755
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -119,8 +119,6 @@
 specialize qw/av1_highbd_inv_txfm_add_16x16 sse4_1/;
 add_proto qw/void av1_highbd_inv_txfm_add_32x32/,  "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
 specialize qw/av1_highbd_inv_txfm_add_32x32 avx2/;
-add_proto qw/void av1_highbd_inv_txfm_add_64x64/,  "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
-specialize qw/av1_highbd_inv_txfm_add_64x64 sse4_1/;
 
 add_proto qw/void av1_highbd_iwht4x4_1_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride, int bd";
 add_proto qw/void av1_highbd_iwht4x4_16_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride, int bd";
@@ -146,8 +144,6 @@
 add_proto qw/void av1_inv_txfm2d_add_16x64/, "const int32_t *input, uint16_t *output, int stride, TX_TYPE tx_type, int bd";
 add_proto qw/void av1_inv_txfm2d_add_64x16/, "const int32_t *input, uint16_t *output, int stride, TX_TYPE tx_type, int bd";
 
-specialize qw/av1_inv_txfm2d_add_64x64 sse4_1/;
-
 add_proto qw/void av1_inv_txfm2d_add_4x16/, "const int32_t *input, uint16_t *output, int stride, TX_TYPE tx_type, int bd";
 add_proto qw/void av1_inv_txfm2d_add_16x4/, "const int32_t *input, uint16_t *output, int stride, TX_TYPE tx_type, int bd";
 add_proto qw/void av1_inv_txfm2d_add_8x32/, "const int32_t *input, uint16_t *output, int stride, TX_TYPE tx_type, int bd";
diff --git a/av1/common/x86/av1_inv_txfm_avx2.c b/av1/common/x86/av1_inv_txfm_avx2.c
index ae331b4..5db2ccf 100644
--- a/av1/common/x86/av1_inv_txfm_avx2.c
+++ b/av1/common/x86/av1_inv_txfm_avx2.c
@@ -18,6 +18,12 @@
 #include "av1/common/x86/av1_inv_txfm_avx2.h"
 #include "av1/common/x86/av1_inv_txfm_ssse3.h"
 
+// TODO(venkatsanampudi@ittiam.com): move this to header file
+
+// Sqrt2, Sqrt2^2, Sqrt2^3, Sqrt2^4, Sqrt2^5
+static int32_t NewSqrt2list[TX_SIZES] = { 5793, 2 * 4096, 2 * 5793, 4 * 4096,
+                                          4 * 5793 };
+
 static INLINE void idct16_stage5_avx2(__m256i *x1, const int32_t *cospi,
                                       const __m256i _r, int8_t cos_bit) {
   const __m256i cospi_m32_p32 = pair_set_w16_epi16(-cospi[32], cospi[32]);
diff --git a/av1/common/x86/av1_inv_txfm_ssse3.c b/av1/common/x86/av1_inv_txfm_ssse3.c
index dd7cee2..995bc3d 100644
--- a/av1/common/x86/av1_inv_txfm_ssse3.c
+++ b/av1/common/x86/av1_inv_txfm_ssse3.c
@@ -16,6 +16,12 @@
 #include "av1/common/x86/av1_inv_txfm_ssse3.h"
 #include "av1/common/x86/av1_txfm_sse2.h"
 
+// TODO(venkatsanampudi@ittiam.com): move this to header file
+
+// Sqrt2, Sqrt2^2, Sqrt2^3, Sqrt2^4, Sqrt2^5
+static int32_t NewSqrt2list[TX_SIZES] = { 5793, 2 * 4096, 2 * 5793, 4 * 4096,
+                                          4 * 5793 };
+
 // TODO(binpengsmail@gmail.com): replace some for loop with do {} while
 
 static void idct4_new_sse2(const __m128i *input, __m128i *output,
diff --git a/av1/common/x86/av1_inv_txfm_ssse3.h b/av1/common/x86/av1_inv_txfm_ssse3.h
index dc9be25..0c5658c 100644
--- a/av1/common/x86/av1_inv_txfm_ssse3.h
+++ b/av1/common/x86/av1_inv_txfm_ssse3.h
@@ -94,10 +94,6 @@
   IIDENTITY_1D, IADST_1D,     IIDENTITY_1D, IFLIPADST_1D,
 };
 
-// Sqrt2, Sqrt2^2, Sqrt2^3, Sqrt2^4, Sqrt2^5
-static int32_t NewSqrt2list[TX_SIZES] = { 5793, 2 * 4096, 2 * 5793, 4 * 4096,
-                                          4 * 5793 };
-
 DECLARE_ALIGNED(16, static const int16_t, av1_eob_to_eobxy_8x8_default[8]) = {
   0x0707, 0x0707, 0x0707, 0x0707, 0x0707, 0x0707, 0x0707, 0x0707,
 };
diff --git a/av1/common/x86/highbd_inv_txfm_avx2.c b/av1/common/x86/highbd_inv_txfm_avx2.c
index 49ba404..36232cc 100644
--- a/av1/common/x86/highbd_inv_txfm_avx2.c
+++ b/av1/common/x86/highbd_inv_txfm_avx2.c
@@ -16,6 +16,7 @@
 
 #include "av1/common/av1_inv_txfm1d_cfg.h"
 #include "av1/common/idct.h"
+#include "av1/common/x86/highbd_txfm_utility_sse4.h"
 
 // Note:
 //  Total 32x4 registers to represent 32x32 block coefficients.
@@ -723,9 +724,6 @@
     case TX_32X16:
       av1_highbd_inv_txfm_add_32x16(input, dest, stride, txfm_param);
       break;
-    case TX_64X64:
-      av1_highbd_inv_txfm_add_64x64_sse4_1(input, dest, stride, txfm_param);
-      break;
     case TX_32X64:
       av1_highbd_inv_txfm_add_32x64(input, dest, stride, txfm_param);
       break;
@@ -753,6 +751,11 @@
     case TX_32X8:
       av1_highbd_inv_txfm_add_32x8(input, dest, stride, txfm_param);
       break;
+    case TX_64X64:
+      av1_highbd_inv_txfm2d_add_universe_sse4_1(
+          input, dest, stride, txfm_param->tx_type, txfm_param->tx_size,
+          txfm_param->eob, txfm_param->bd);
+      break;
     default: assert(0 && "Invalid transform size"); break;
   }
 }
diff --git a/av1/common/x86/highbd_inv_txfm_sse4.c b/av1/common/x86/highbd_inv_txfm_sse4.c
index 59c1aec..a7057f4 100644
--- a/av1/common/x86/highbd_inv_txfm_sse4.c
+++ b/av1/common/x86/highbd_inv_txfm_sse4.c
@@ -15,8 +15,59 @@
 #include "config/av1_rtcd.h"
 
 #include "av1/common/av1_inv_txfm1d_cfg.h"
-#include "av1/common/x86/highbd_txfm_utility_sse4.h"
 #include "av1/common/idct.h"
+#include "av1/common/x86/av1_inv_txfm_ssse3.h"
+#include "av1/common/x86/av1_txfm_sse4.h"
+#include "av1/common/x86/highbd_txfm_utility_sse4.h"
+
+static INLINE __m128i highbd_clamp_epi16(__m128i u, int bd) {
+  const __m128i zero = _mm_setzero_si128();
+  const __m128i one = _mm_set1_epi16(1);
+  const __m128i max = _mm_sub_epi16(_mm_slli_epi16(one, bd), one);
+  __m128i clamped, mask;
+
+  mask = _mm_cmpgt_epi16(u, max);
+  clamped = _mm_andnot_si128(mask, u);
+  mask = _mm_and_si128(mask, max);
+  clamped = _mm_or_si128(mask, clamped);
+  mask = _mm_cmpgt_epi16(clamped, zero);
+  clamped = _mm_and_si128(clamped, mask);
+
+  return clamped;
+}
+
+static INLINE __m128i highbd_get_recon_8x8_sse4_1(const __m128i pred,
+                                                  __m128i res0, __m128i res1,
+                                                  const int bd) {
+  __m128i x0 = _mm_cvtepi16_epi32(pred);
+  __m128i x1 = _mm_cvtepi16_epi32(_mm_srli_si128(pred, 8));
+
+  x0 = _mm_add_epi32(res0, x0);
+  x1 = _mm_add_epi32(res1, x1);
+  x0 = _mm_packus_epi32(x0, x1);
+  x0 = highbd_clamp_epi16(x0, bd);
+  return x0;
+}
+
+static INLINE void highbd_write_buffer_8xn_sse4_1(__m128i *in, uint16_t *output,
+                                                  int stride, int flipud,
+                                                  int height, const int bd) {
+  int j = flipud ? (height - 1) : 0;
+  const int step = flipud ? -1 : 1;
+  for (int i = 0; i < height; ++i, j += step) {
+    __m128i v = _mm_loadu_si128((__m128i const *)(output + i * stride));
+    __m128i u = highbd_get_recon_8x8_sse4_1(v, in[j], in[j + height], bd);
+
+    _mm_storeu_si128((__m128i *)(output + i * stride), u);
+  }
+}
+
+static INLINE void load_buffer_32bit_input(const int32_t *in, int stride,
+                                           __m128i *out, int out_size) {
+  for (int i = 0; i < out_size; ++i) {
+    out[i] = _mm_loadu_si128((const __m128i *)(in + i * stride));
+  }
+}
 
 static INLINE void load_buffer_4x4(const int32_t *coeff, __m128i *in) {
   in[0] = _mm_load_si128((const __m128i *)(coeff + 0));
@@ -238,22 +289,6 @@
   in[3] = _mm_srai_epi32(in[3], shift);
 }
 
-static INLINE __m128i highbd_clamp_epi16(__m128i u, int bd) {
-  const __m128i zero = _mm_setzero_si128();
-  const __m128i one = _mm_set1_epi16(1);
-  const __m128i max = _mm_sub_epi16(_mm_slli_epi16(one, bd), one);
-  __m128i clamped, mask;
-
-  mask = _mm_cmpgt_epi16(u, max);
-  clamped = _mm_andnot_si128(mask, u);
-  mask = _mm_and_si128(mask, max);
-  clamped = _mm_or_si128(mask, clamped);
-  mask = _mm_cmpgt_epi16(clamped, zero);
-  clamped = _mm_and_si128(clamped, mask);
-
-  return clamped;
-}
-
 static void write_buffer_4x4(__m128i *in, uint16_t *output, int stride,
                              int fliplr, int flipud, int shift, int bd) {
   const __m128i zero = _mm_setzero_si128();
@@ -1758,132 +1793,6 @@
   }
 }
 
-static void load_buffer_64x64_lower_32x32(const int32_t *coeff, __m128i *in) {
-  int i, j;
-
-  __m128i zero = _mm_setzero_si128();
-
-  for (i = 0; i < 32; ++i) {
-    for (j = 0; j < 8; ++j) {
-      in[16 * i + j] =
-          _mm_loadu_si128((const __m128i *)(coeff + 32 * i + 4 * j));
-      in[16 * i + j + 8] = zero;
-    }
-  }
-
-  for (i = 0; i < 512; ++i) in[512 + i] = zero;
-}
-
-static void transpose_64x64(__m128i *in, __m128i *out, int do_cols) {
-  int i, j;
-  for (i = 0; i < (do_cols ? 16 : 8); ++i) {
-    for (j = 0; j < 8; ++j) {
-      TRANSPOSE_4X4(in[(4 * i + 0) * 16 + j], in[(4 * i + 1) * 16 + j],
-                    in[(4 * i + 2) * 16 + j], in[(4 * i + 3) * 16 + j],
-                    out[(4 * j + 0) * 16 + i], out[(4 * j + 1) * 16 + i],
-                    out[(4 * j + 2) * 16 + i], out[(4 * j + 3) * 16 + i]);
-    }
-  }
-}
-
-static void assign_16x16_input_from_32x32(const __m128i *in, __m128i *in16x16,
-                                          int col) {
-  int i;
-  for (i = 0; i < 16 * 16 / 4; i += 4) {
-    in16x16[i] = in[col];
-    in16x16[i + 1] = in[col + 1];
-    in16x16[i + 2] = in[col + 2];
-    in16x16[i + 3] = in[col + 3];
-    col += 8;
-  }
-}
-
-static void write_buffer_32x32(__m128i *in, uint16_t *output, int stride,
-                               int fliplr, int flipud, int shift, int bd) {
-  __m128i in16x16[16 * 16 / 4];
-  uint16_t *leftUp = &output[0];
-  uint16_t *rightUp = &output[16];
-  uint16_t *leftDown = &output[16 * stride];
-  uint16_t *rightDown = &output[16 * stride + 16];
-
-  if (fliplr) {
-    swap_addr(&leftUp, &rightUp);
-    swap_addr(&leftDown, &rightDown);
-  }
-
-  if (flipud) {
-    swap_addr(&leftUp, &leftDown);
-    swap_addr(&rightUp, &rightDown);
-  }
-
-  // Left-up quarter
-  assign_16x16_input_from_32x32(in, in16x16, 0);
-  write_buffer_16x16(in16x16, leftUp, stride, fliplr, flipud, shift, bd);
-
-  // Right-up quarter
-  assign_16x16_input_from_32x32(in, in16x16, 32 / 2 / 4);
-  write_buffer_16x16(in16x16, rightUp, stride, fliplr, flipud, shift, bd);
-
-  // Left-down quarter
-  assign_16x16_input_from_32x32(in, in16x16, 32 * 32 / 2 / 4);
-  write_buffer_16x16(in16x16, leftDown, stride, fliplr, flipud, shift, bd);
-
-  // Right-down quarter
-  assign_16x16_input_from_32x32(in, in16x16, 32 * 32 / 2 / 4 + 32 / 2 / 4);
-  write_buffer_16x16(in16x16, rightDown, stride, fliplr, flipud, shift, bd);
-}
-
-static void assign_32x32_input_from_64x64(const __m128i *in, __m128i *in32x32,
-                                          int col) {
-  int i;
-  for (i = 0; i < 32 * 32 / 4; i += 8) {
-    in32x32[i] = in[col];
-    in32x32[i + 1] = in[col + 1];
-    in32x32[i + 2] = in[col + 2];
-    in32x32[i + 3] = in[col + 3];
-    in32x32[i + 4] = in[col + 4];
-    in32x32[i + 5] = in[col + 5];
-    in32x32[i + 6] = in[col + 6];
-    in32x32[i + 7] = in[col + 7];
-    col += 16;
-  }
-}
-
-static void write_buffer_64x64(__m128i *in, uint16_t *output, int stride,
-                               int fliplr, int flipud, int shift, int bd) {
-  __m128i in32x32[32 * 32 / 4];
-  uint16_t *leftUp = &output[0];
-  uint16_t *rightUp = &output[32];
-  uint16_t *leftDown = &output[32 * stride];
-  uint16_t *rightDown = &output[32 * stride + 32];
-
-  if (fliplr) {
-    swap_addr(&leftUp, &rightUp);
-    swap_addr(&leftDown, &rightDown);
-  }
-
-  if (flipud) {
-    swap_addr(&leftUp, &leftDown);
-    swap_addr(&rightUp, &rightDown);
-  }
-
-  // Left-up quarter
-  assign_32x32_input_from_64x64(in, in32x32, 0);
-  write_buffer_32x32(in32x32, leftUp, stride, fliplr, flipud, shift, bd);
-
-  // Right-up quarter
-  assign_32x32_input_from_64x64(in, in32x32, 64 / 2 / 4);
-  write_buffer_32x32(in32x32, rightUp, stride, fliplr, flipud, shift, bd);
-
-  // Left-down quarter
-  assign_32x32_input_from_64x64(in, in32x32, 64 * 64 / 2 / 4);
-  write_buffer_32x32(in32x32, leftDown, stride, fliplr, flipud, shift, bd);
-
-  // Right-down quarter
-  assign_32x32_input_from_64x64(in, in32x32, 64 * 64 / 2 / 4 + 64 / 2 / 4);
-  write_buffer_32x32(in32x32, rightDown, stride, fliplr, flipud, shift, bd);
-}
-
 static void idct64x64_sse4_1(__m128i *in, __m128i *out, int bit, int do_cols,
                              int bd, int out_shift) {
   int i, j;
@@ -1892,7 +1801,6 @@
   const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8));
   const __m128i clamp_lo = _mm_set1_epi32(-(1 << (log_range - 1)));
   const __m128i clamp_hi = _mm_set1_epi32((1 << (log_range - 1)) - 1);
-  int col;
 
   const __m128i cospi1 = _mm_set1_epi32(cospi[1]);
   const __m128i cospi2 = _mm_set1_epi32(cospi[2]);
@@ -1974,46 +1882,46 @@
   const __m128i cospim60 = _mm_set1_epi32(-cospi[60]);
   const __m128i cospim61 = _mm_set1_epi32(-cospi[61]);
 
-  for (col = 0; col < (do_cols ? 64 / 4 : 32 / 4); ++col) {
+  {
     __m128i u[64], v[64];
 
     // stage 1
-    u[32] = in[1 * 16 + col];
-    u[34] = in[17 * 16 + col];
-    u[36] = in[9 * 16 + col];
-    u[38] = in[25 * 16 + col];
-    u[40] = in[5 * 16 + col];
-    u[42] = in[21 * 16 + col];
-    u[44] = in[13 * 16 + col];
-    u[46] = in[29 * 16 + col];
-    u[48] = in[3 * 16 + col];
-    u[50] = in[19 * 16 + col];
-    u[52] = in[11 * 16 + col];
-    u[54] = in[27 * 16 + col];
-    u[56] = in[7 * 16 + col];
-    u[58] = in[23 * 16 + col];
-    u[60] = in[15 * 16 + col];
-    u[62] = in[31 * 16 + col];
+    u[32] = in[1];
+    u[34] = in[17];
+    u[36] = in[9];
+    u[38] = in[25];
+    u[40] = in[5];
+    u[42] = in[21];
+    u[44] = in[13];
+    u[46] = in[29];
+    u[48] = in[3];
+    u[50] = in[19];
+    u[52] = in[11];
+    u[54] = in[27];
+    u[56] = in[7];
+    u[58] = in[23];
+    u[60] = in[15];
+    u[62] = in[31];
 
-    v[16] = in[2 * 16 + col];
-    v[18] = in[18 * 16 + col];
-    v[20] = in[10 * 16 + col];
-    v[22] = in[26 * 16 + col];
-    v[24] = in[6 * 16 + col];
-    v[26] = in[22 * 16 + col];
-    v[28] = in[14 * 16 + col];
-    v[30] = in[30 * 16 + col];
+    v[16] = in[2];
+    v[18] = in[18];
+    v[20] = in[10];
+    v[22] = in[26];
+    v[24] = in[6];
+    v[26] = in[22];
+    v[28] = in[14];
+    v[30] = in[30];
 
-    u[8] = in[4 * 16 + col];
-    u[10] = in[20 * 16 + col];
-    u[12] = in[12 * 16 + col];
-    u[14] = in[28 * 16 + col];
+    u[8] = in[4];
+    u[10] = in[20];
+    u[12] = in[12];
+    u[14] = in[28];
 
-    v[4] = in[8 * 16 + col];
-    v[6] = in[24 * 16 + col];
+    v[4] = in[8];
+    v[6] = in[24];
 
-    u[0] = in[0 * 16 + col];
-    u[2] = in[16 * 16 + col];
+    u[0] = in[0];
+    u[2] = in[16];
 
     // stage 2
     v[32] = half_btf_0_sse4_1(&cospi63, &u[32], &rnding, bit);
@@ -2346,8 +2254,7 @@
     // stage 11
     if (do_cols) {
       for (i = 0; i < 32; i++) {
-        addsub_no_clamp_sse4_1(v[i], v[63 - i], &out[16 * (i) + col],
-                               &out[16 * (63 - i) + col]);
+        addsub_no_clamp_sse4_1(v[i], v[63 - i], &out[(i)], &out[(63 - i)]);
       }
     } else {
       const int log_range_out = AOMMAX(16, bd + 6);
@@ -2357,49 +2264,13 @@
           (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift))));
 
       for (i = 0; i < 32; i++) {
-        addsub_shift_sse4_1(v[i], v[63 - i], &out[16 * (i) + col],
-                            &out[16 * (63 - i) + col], &clamp_lo_out,
-                            &clamp_hi_out, out_shift);
+        addsub_shift_sse4_1(v[i], v[63 - i], &out[(i)], &out[(63 - i)],
+                            &clamp_lo_out, &clamp_hi_out, out_shift);
       }
     }
   }
 }
 
-void av1_inv_txfm2d_add_64x64_sse4_1(const int32_t *coeff, uint16_t *output,
-                                     int stride, TX_TYPE tx_type, int bd) {
-  __m128i in[64 * 64 / 4], out[64 * 64 / 4];
-  const int8_t *shift = inv_txfm_shift_ls[TX_64X64];
-  const int txw_idx = tx_size_wide_log2[TX_64X64] - tx_size_wide_log2[0];
-  const int txh_idx = tx_size_high_log2[TX_64X64] - tx_size_high_log2[0];
-
-  switch (tx_type) {
-    case DCT_DCT:
-      load_buffer_64x64_lower_32x32(coeff, in);
-      transpose_64x64(in, out, 0);
-      idct64x64_sse4_1(out, in, inv_cos_bit_row[txw_idx][txh_idx], 0, bd,
-                       -shift[0]);
-      transpose_64x64(in, out, 1);
-      idct64x64_sse4_1(out, in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd, 0);
-      write_buffer_64x64(in, output, stride, 0, 0, -shift[1], bd);
-      break;
-
-    default:
-      av1_inv_txfm2d_add_64x64_c(coeff, output, stride, tx_type, bd);
-      break;
-  }
-}
-
-void av1_highbd_inv_txfm_add_64x64_sse4_1(const tran_low_t *input,
-                                          uint8_t *dest, int stride,
-                                          const TxfmParam *txfm_param) {
-  const int bd = txfm_param->bd;
-  const TX_TYPE tx_type = txfm_param->tx_type;
-  const int32_t *src = cast_to_int32(input);
-  assert(tx_type == DCT_DCT);
-  av1_inv_txfm2d_add_64x64_sse4_1(src, CONVERT_TO_SHORTPTR(dest), stride,
-                                  tx_type, bd);
-}
-
 void av1_highbd_inv_txfm_add_8x8_sse4_1(const tran_low_t *input, uint8_t *dest,
                                         int stride,
                                         const TxfmParam *txfm_param) {
@@ -2486,6 +2357,115 @@
   }
 }
 
+static const transform_1d_sse4_1
+    highbd_txfm_all_1d_zeros_w8_arr[TX_SIZES][ITX_TYPES_1D][4] = {
+      {
+          { NULL, NULL, NULL, NULL },
+          { NULL, NULL, NULL, NULL },
+          { NULL, NULL, NULL, NULL },
+      },
+      { { NULL, NULL, NULL, NULL },
+        { NULL, NULL, NULL, NULL },
+        { NULL, NULL, NULL, NULL } },
+      {
+          { NULL, NULL, NULL, NULL },
+          { NULL, NULL, NULL, NULL },
+          { NULL, NULL, NULL, NULL },
+      },
+      { { NULL, NULL, NULL, NULL },
+        { NULL, NULL, NULL, NULL },
+        { NULL, NULL, NULL, NULL } },
+      { { idct64x64_sse4_1, idct64x64_sse4_1, idct64x64_sse4_1,
+          idct64x64_sse4_1 },
+        { NULL, NULL, NULL, NULL },
+        { NULL, NULL, NULL, NULL } }
+    };
+
+static void highbd_inv_txfm2d_add_no_identity_sse41(const int32_t *input,
+                                                    uint16_t *output,
+                                                    int stride, TX_TYPE tx_type,
+                                                    TX_SIZE tx_size, int eob,
+                                                    const int bd) {
+  __m128i buf1[64 * 16] = { { 0 } };
+  int eobx, eoby;
+  get_eobx_eoby_scan_default(&eobx, &eoby, tx_size, eob);
+  const int8_t *shift = inv_txfm_shift_ls[tx_size];
+  const int txw_idx = get_txw_idx(tx_size);
+  const int txh_idx = get_txh_idx(tx_size);
+  const int txfm_size_col = tx_size_wide[tx_size];
+  const int txfm_size_row = tx_size_high[tx_size];
+  const int buf_size_w_div8 = txfm_size_col >> 2;
+  const int buf_size_nonzero_w_div8 = (AOMMIN(32, txfm_size_col)) >> 2;
+  const int buf_size_nonzero_h_div8 = (eoby + 8) >> 3;
+  const int input_stride = AOMMIN(32, txfm_size_col);
+
+  const int fun_idx_x = lowbd_txfm_all_1d_zeros_idx[eobx];
+  const int fun_idx_y = lowbd_txfm_all_1d_zeros_idx[eoby];
+  const transform_1d_sse4_1 row_txfm =
+      highbd_txfm_all_1d_zeros_w8_arr[txw_idx][hitx_1d_tab[tx_type]][fun_idx_x];
+  const transform_1d_sse4_1 col_txfm =
+      highbd_txfm_all_1d_zeros_w8_arr[txh_idx][vitx_1d_tab[tx_type]][fun_idx_y];
+
+  assert(col_txfm != NULL);
+  assert(row_txfm != NULL);
+  int ud_flip, lr_flip;
+  get_flip_cfg(tx_type, &ud_flip, &lr_flip);
+
+  // 1st stage: column transform
+  for (int i = 0; i < buf_size_nonzero_h_div8 << 1; i++) {
+    __m128i buf0[64];
+    const int32_t *input_row = input + i * input_stride * 4;
+    for (int j = 0; j < buf_size_nonzero_w_div8; ++j) {
+      __m128i *buf0_cur = buf0 + j * 4;
+      load_buffer_32bit_input(input_row + j * 4, input_stride, buf0_cur, 4);
+
+      TRANSPOSE_4X4(buf0_cur[0], buf0_cur[1], buf0_cur[2], buf0_cur[3],
+                    buf0_cur[0], buf0_cur[1], buf0_cur[2], buf0_cur[3]);
+    }
+    row_txfm(buf0, buf0, inv_cos_bit_row[txw_idx][txh_idx], 0, bd, -shift[0]);
+
+    __m128i *_buf1 = buf1 + i * 4;
+    for (int j = 0; j < buf_size_w_div8; ++j) {
+      TRANSPOSE_4X4(buf0[j * 4 + 0], buf0[j * 4 + 1], buf0[j * 4 + 2],
+                    buf0[j * 4 + 3], _buf1[j * txfm_size_row + 0],
+                    _buf1[j * txfm_size_row + 1], _buf1[j * txfm_size_row + 2],
+                    _buf1[j * txfm_size_row + 3]);
+    }
+  }
+  // 2nd stage: column transform
+  for (int i = 0; i < buf_size_w_div8; i++) {
+    col_txfm(buf1 + i * txfm_size_row, buf1 + i * txfm_size_row,
+             inv_cos_bit_col[txw_idx][txh_idx], 1, bd, 0);
+
+    av1_round_shift_array_32_sse4_1(buf1 + i * txfm_size_row,
+                                    buf1 + i * txfm_size_row, txfm_size_row,
+                                    -shift[1]);
+  }
+
+  // write to buffer
+  {
+    for (int i = 0; i < (txfm_size_col >> 3); i++) {
+      highbd_write_buffer_8xn_sse4_1(buf1 + i * txfm_size_row * 2,
+                                     output + 8 * i, stride, ud_flip,
+                                     txfm_size_row, bd);
+    }
+  }
+}
+
+void av1_highbd_inv_txfm2d_add_universe_sse4_1(const int32_t *input,
+                                               uint8_t *output, int stride,
+                                               TX_TYPE tx_type, TX_SIZE tx_size,
+                                               int eob, const int bd) {
+  switch (tx_type) {
+    case DCT_DCT:
+      highbd_inv_txfm2d_add_no_identity_sse41(
+          input, CONVERT_TO_SHORTPTR(output), stride, tx_type, tx_size, eob,
+          bd);
+      break;
+    default: assert(0); break;
+  }
+}
+
 void av1_highbd_inv_txfm_add_sse4_1(const tran_low_t *input, uint8_t *dest,
                                     int stride, const TxfmParam *txfm_param) {
   assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
@@ -2518,9 +2498,6 @@
     case TX_32X16:
       av1_highbd_inv_txfm_add_32x16(input, dest, stride, txfm_param);
       break;
-    case TX_64X64:
-      av1_highbd_inv_txfm_add_64x64_sse4_1(input, dest, stride, txfm_param);
-      break;
     case TX_32X64:
       av1_highbd_inv_txfm_add_32x64(input, dest, stride, txfm_param);
       break;
@@ -2548,6 +2525,11 @@
     case TX_32X8:
       av1_highbd_inv_txfm_add_32x8(input, dest, stride, txfm_param);
       break;
+    case TX_64X64:
+      av1_highbd_inv_txfm2d_add_universe_sse4_1(
+          input, dest, stride, txfm_param->tx_type, txfm_param->tx_size,
+          txfm_param->eob, txfm_param->bd);
+      break;
     default: assert(0 && "Invalid transform size"); break;
   }
 }
diff --git a/av1/common/x86/highbd_txfm_utility_sse4.h b/av1/common/x86/highbd_txfm_utility_sse4.h
index b29bd1d..cfea022 100644
--- a/av1/common/x86/highbd_txfm_utility_sse4.h
+++ b/av1/common/x86/highbd_txfm_utility_sse4.h
@@ -100,4 +100,12 @@
   return x;
 }
 
+typedef void (*transform_1d_sse4_1)(__m128i *in, __m128i *out, int bit,
+                                    int do_cols, int bd, int out_shift);
+
+void av1_highbd_inv_txfm2d_add_universe_sse4_1(const int32_t *input,
+                                               uint8_t *output, int stride,
+                                               TX_TYPE tx_type, TX_SIZE tx_size,
+                                               int eob, const int bd);
+
 #endif  // _HIGHBD_TXFM_UTILITY_SSE4_H
diff --git a/test/av1_highbd_iht_test.cc b/test/av1_highbd_iht_test.cc
index 14a5a95..7382ff9 100644
--- a/test/av1_highbd_iht_test.cc
+++ b/test/av1_highbd_iht_test.cc
@@ -150,9 +150,6 @@
 #define PARAM_LIST_16X16                                     \
   &av1_fwd_txfm2d_16x16_c, &av1_inv_txfm2d_add_16x16_sse4_1, \
       &av1_inv_txfm2d_add_16x16_c, 256
-#define PARAM_LIST_64X64                                     \
-  &av1_fwd_txfm2d_64x64_c, &av1_inv_txfm2d_add_64x64_sse4_1, \
-      &av1_inv_txfm2d_add_64x64_c, 4096
 
 const IHbdHtParam kArrayIhtParam[] = {
   // 16x16
@@ -212,8 +209,6 @@
   make_tuple(PARAM_LIST_4X4, ADST_FLIPADST, 12),
   make_tuple(PARAM_LIST_4X4, FLIPADST_ADST, 10),
   make_tuple(PARAM_LIST_4X4, FLIPADST_ADST, 12),
-  make_tuple(PARAM_LIST_64X64, DCT_DCT, 10),
-  make_tuple(PARAM_LIST_64X64, DCT_DCT, 12),
 };
 
 INSTANTIATE_TEST_CASE_P(SSE4_1, AV1HighbdInvHTNxN,