Optimize highbd 32x16 and 16x32 fwd_txfm

Add sse4_1 variant for highbd 32x16 and 16x32 fwd_txfm.

Refactored code for highbd 32x32 fwd_txfm.

Achieved module-level gains of ~4.1x w.r.t. C code.

When tested on 20 frames of crowd_run_360p_10 at 1 mbps
with the speed=1 preset, observed ~1.1% reduction in encoder time.

Change-Id: Ica7b9ad7b563c3ee6334dd3f506d7d6e9b06ce6e
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 8e9336a..a6d5138 100755
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -214,7 +214,9 @@
   add_proto qw/void av1_fwd_txfm2d_16x8/, "const int16_t *input, int32_t *output, int stride, TX_TYPE tx_type, int bd";
   specialize qw/av1_fwd_txfm2d_16x8 sse4_1/;
   add_proto qw/void av1_fwd_txfm2d_16x32/, "const int16_t *input, int32_t *output, int stride, TX_TYPE tx_type, int bd";
+  specialize qw/av1_fwd_txfm2d_16x32 sse4_1/;
   add_proto qw/void av1_fwd_txfm2d_32x16/, "const int16_t *input, int32_t *output, int stride, TX_TYPE tx_type, int bd";
+  specialize qw/av1_fwd_txfm2d_32x16 sse4_1/;
   add_proto qw/void av1_fwd_txfm2d_4x16/, "const int16_t *input, int32_t *output, int stride, TX_TYPE tx_type, int bd";
   add_proto qw/void av1_fwd_txfm2d_16x4/, "const int16_t *input, int32_t *output, int stride, TX_TYPE tx_type, int bd";
   add_proto qw/void av1_fwd_txfm2d_8x32/, "const int16_t *input, int32_t *output, int stride, TX_TYPE tx_type, int bd";
diff --git a/av1/common/x86/highbd_txfm_utility_sse4.h b/av1/common/x86/highbd_txfm_utility_sse4.h
index 6f24e59..5734810 100644
--- a/av1/common/x86/highbd_txfm_utility_sse4.h
+++ b/av1/common/x86/highbd_txfm_utility_sse4.h
@@ -75,13 +75,20 @@
                 out[63]);
 }
 
-static INLINE void transpose_32x32(const __m128i *input, __m128i *output) {
-  for (int j = 0; j < 8; j++) {
-    for (int i = 0; i < 8; i++) {
-      TRANSPOSE_4X4(input[i * 32 + j + 0], input[i * 32 + j + 8],
-                    input[i * 32 + j + 16], input[i * 32 + j + 24],
-                    output[j * 32 + i + 0], output[j * 32 + i + 8],
-                    output[j * 32 + i + 16], output[j * 32 + i + 24]);
+static INLINE void transpose_8nx8n(const __m128i *input, __m128i *output,
+                                   const int width, const int height) {
+  const int numcol = height >> 2;
+  const int numrow = width >> 2;
+  for (int j = 0; j < numrow; j++) {
+    for (int i = 0; i < numcol; i++) {
+      TRANSPOSE_4X4(input[i * width + j + (numrow * 0)],
+                    input[i * width + j + (numrow * 1)],
+                    input[i * width + j + (numrow * 2)],
+                    input[i * width + j + (numrow * 3)],
+                    output[j * height + i + (numcol * 0)],
+                    output[j * height + i + (numcol * 1)],
+                    output[j * height + i + (numcol * 2)],
+                    output[j * height + i + (numcol * 3)]);
     }
   }
 }
diff --git a/av1/encoder/hybrid_fwd_txfm.c b/av1/encoder/hybrid_fwd_txfm.c
index 67898fd..22d29cf 100644
--- a/av1/encoder/hybrid_fwd_txfm.c
+++ b/av1/encoder/hybrid_fwd_txfm.c
@@ -165,15 +165,33 @@
 static void highbd_fwd_txfm_16x32(const int16_t *src_diff, tran_low_t *coeff,
                                   int diff_stride, TxfmParam *txfm_param) {
   int32_t *dst_coeff = (int32_t *)coeff;
-  av1_fwd_txfm2d_16x32_c(src_diff, dst_coeff, diff_stride, txfm_param->tx_type,
-                         txfm_param->bd);
+  switch (txfm_param->tx_type) {
+    case DCT_DCT:
+      av1_fwd_txfm2d_16x32(src_diff, dst_coeff, diff_stride,
+                           txfm_param->tx_type, txfm_param->bd);
+      break;
+    case IDTX:
+      av1_fwd_txfm2d_16x32_c(src_diff, dst_coeff, diff_stride,
+                             txfm_param->tx_type, txfm_param->bd);
+      break;
+    default: assert(0);
+  }
 }
 
 static void highbd_fwd_txfm_32x16(const int16_t *src_diff, tran_low_t *coeff,
                                   int diff_stride, TxfmParam *txfm_param) {
   int32_t *dst_coeff = (int32_t *)coeff;
-  av1_fwd_txfm2d_32x16_c(src_diff, dst_coeff, diff_stride, txfm_param->tx_type,
-                         txfm_param->bd);
+  switch (txfm_param->tx_type) {
+    case DCT_DCT:
+      av1_fwd_txfm2d_32x16(src_diff, dst_coeff, diff_stride,
+                           txfm_param->tx_type, txfm_param->bd);
+      break;
+    case IDTX:
+      av1_fwd_txfm2d_32x16_c(src_diff, dst_coeff, diff_stride,
+                             txfm_param->tx_type, txfm_param->bd);
+      break;
+    default: assert(0);
+  }
 }
 
 static void highbd_fwd_txfm_16x4(const int16_t *src_diff, tran_low_t *coeff,
diff --git a/av1/encoder/x86/av1_fwd_txfm1d_sse4.c b/av1/encoder/x86/av1_fwd_txfm1d_sse4.c
index 0761554..faa19b7 100644
--- a/av1/encoder/x86/av1_fwd_txfm1d_sse4.c
+++ b/av1/encoder/x86/av1_fwd_txfm1d_sse4.c
@@ -11,45 +11,45 @@
 
 #include "av1/encoder/x86/av1_txfm1d_sse4.h"
 
-void av1_fdct32_new_sse4_1(const __m128i *input, __m128i *output,
-                           int8_t cos_bit) {
+void av1_fdct32_new_sse4_1(__m128i *input, __m128i *output, int cos_bit,
+                           const int stride) {
   __m128i buf0[32];
   __m128i buf1[32];
   const int32_t *cospi;
   // stage 0
   // stage 1
-  buf1[0] = _mm_add_epi32(input[0], input[31]);
-  buf1[31] = _mm_sub_epi32(input[0], input[31]);
-  buf1[1] = _mm_add_epi32(input[1], input[30]);
-  buf1[30] = _mm_sub_epi32(input[1], input[30]);
-  buf1[2] = _mm_add_epi32(input[2], input[29]);
-  buf1[29] = _mm_sub_epi32(input[2], input[29]);
-  buf1[3] = _mm_add_epi32(input[3], input[28]);
-  buf1[28] = _mm_sub_epi32(input[3], input[28]);
-  buf1[4] = _mm_add_epi32(input[4], input[27]);
-  buf1[27] = _mm_sub_epi32(input[4], input[27]);
-  buf1[5] = _mm_add_epi32(input[5], input[26]);
-  buf1[26] = _mm_sub_epi32(input[5], input[26]);
-  buf1[6] = _mm_add_epi32(input[6], input[25]);
-  buf1[25] = _mm_sub_epi32(input[6], input[25]);
-  buf1[7] = _mm_add_epi32(input[7], input[24]);
-  buf1[24] = _mm_sub_epi32(input[7], input[24]);
-  buf1[8] = _mm_add_epi32(input[8], input[23]);
-  buf1[23] = _mm_sub_epi32(input[8], input[23]);
-  buf1[9] = _mm_add_epi32(input[9], input[22]);
-  buf1[22] = _mm_sub_epi32(input[9], input[22]);
-  buf1[10] = _mm_add_epi32(input[10], input[21]);
-  buf1[21] = _mm_sub_epi32(input[10], input[21]);
-  buf1[11] = _mm_add_epi32(input[11], input[20]);
-  buf1[20] = _mm_sub_epi32(input[11], input[20]);
-  buf1[12] = _mm_add_epi32(input[12], input[19]);
-  buf1[19] = _mm_sub_epi32(input[12], input[19]);
-  buf1[13] = _mm_add_epi32(input[13], input[18]);
-  buf1[18] = _mm_sub_epi32(input[13], input[18]);
-  buf1[14] = _mm_add_epi32(input[14], input[17]);
-  buf1[17] = _mm_sub_epi32(input[14], input[17]);
-  buf1[15] = _mm_add_epi32(input[15], input[16]);
-  buf1[16] = _mm_sub_epi32(input[15], input[16]);
+  buf1[0] = _mm_add_epi32(input[0 * stride], input[31 * stride]);
+  buf1[31] = _mm_sub_epi32(input[0 * stride], input[31 * stride]);
+  buf1[1] = _mm_add_epi32(input[1 * stride], input[30 * stride]);
+  buf1[30] = _mm_sub_epi32(input[1 * stride], input[30 * stride]);
+  buf1[2] = _mm_add_epi32(input[2 * stride], input[29 * stride]);
+  buf1[29] = _mm_sub_epi32(input[2 * stride], input[29 * stride]);
+  buf1[3] = _mm_add_epi32(input[3 * stride], input[28 * stride]);
+  buf1[28] = _mm_sub_epi32(input[3 * stride], input[28 * stride]);
+  buf1[4] = _mm_add_epi32(input[4 * stride], input[27 * stride]);
+  buf1[27] = _mm_sub_epi32(input[4 * stride], input[27 * stride]);
+  buf1[5] = _mm_add_epi32(input[5 * stride], input[26 * stride]);
+  buf1[26] = _mm_sub_epi32(input[5 * stride], input[26 * stride]);
+  buf1[6] = _mm_add_epi32(input[6 * stride], input[25 * stride]);
+  buf1[25] = _mm_sub_epi32(input[6 * stride], input[25 * stride]);
+  buf1[7] = _mm_add_epi32(input[7 * stride], input[24 * stride]);
+  buf1[24] = _mm_sub_epi32(input[7 * stride], input[24 * stride]);
+  buf1[8] = _mm_add_epi32(input[8 * stride], input[23 * stride]);
+  buf1[23] = _mm_sub_epi32(input[8 * stride], input[23 * stride]);
+  buf1[9] = _mm_add_epi32(input[9 * stride], input[22 * stride]);
+  buf1[22] = _mm_sub_epi32(input[9 * stride], input[22 * stride]);
+  buf1[10] = _mm_add_epi32(input[10 * stride], input[21 * stride]);
+  buf1[21] = _mm_sub_epi32(input[10 * stride], input[21 * stride]);
+  buf1[11] = _mm_add_epi32(input[11 * stride], input[20 * stride]);
+  buf1[20] = _mm_sub_epi32(input[11 * stride], input[20 * stride]);
+  buf1[12] = _mm_add_epi32(input[12 * stride], input[19 * stride]);
+  buf1[19] = _mm_sub_epi32(input[12 * stride], input[19 * stride]);
+  buf1[13] = _mm_add_epi32(input[13 * stride], input[18 * stride]);
+  buf1[18] = _mm_sub_epi32(input[13 * stride], input[18 * stride]);
+  buf1[14] = _mm_add_epi32(input[14 * stride], input[17 * stride]);
+  buf1[17] = _mm_sub_epi32(input[14 * stride], input[17 * stride]);
+  buf1[15] = _mm_add_epi32(input[15 * stride], input[16 * stride]);
+  buf1[16] = _mm_sub_epi32(input[15 * stride], input[16 * stride]);
 
   // stage 2
   cospi = cospi_arr(cos_bit);
@@ -297,38 +297,38 @@
                       buf0[24], cos_bit);
 
   // stage 9
-  output[0] = buf0[0];
-  output[1] = buf0[16];
-  output[2] = buf0[8];
-  output[3] = buf0[24];
-  output[4] = buf0[4];
-  output[5] = buf0[20];
-  output[6] = buf0[12];
-  output[7] = buf0[28];
-  output[8] = buf0[2];
-  output[9] = buf0[18];
-  output[10] = buf0[10];
-  output[11] = buf0[26];
-  output[12] = buf0[6];
-  output[13] = buf0[22];
-  output[14] = buf0[14];
-  output[15] = buf0[30];
-  output[16] = buf0[1];
-  output[17] = buf0[17];
-  output[18] = buf0[9];
-  output[19] = buf0[25];
-  output[20] = buf0[5];
-  output[21] = buf0[21];
-  output[22] = buf0[13];
-  output[23] = buf0[29];
-  output[24] = buf0[3];
-  output[25] = buf0[19];
-  output[26] = buf0[11];
-  output[27] = buf0[27];
-  output[28] = buf0[7];
-  output[29] = buf0[23];
-  output[30] = buf0[15];
-  output[31] = buf0[31];
+  output[0 * stride] = buf0[0];
+  output[1 * stride] = buf0[16];
+  output[2 * stride] = buf0[8];
+  output[3 * stride] = buf0[24];
+  output[4 * stride] = buf0[4];
+  output[5 * stride] = buf0[20];
+  output[6 * stride] = buf0[12];
+  output[7 * stride] = buf0[28];
+  output[8 * stride] = buf0[2];
+  output[9 * stride] = buf0[18];
+  output[10 * stride] = buf0[10];
+  output[11 * stride] = buf0[26];
+  output[12 * stride] = buf0[6];
+  output[13 * stride] = buf0[22];
+  output[14 * stride] = buf0[14];
+  output[15 * stride] = buf0[30];
+  output[16 * stride] = buf0[1];
+  output[17 * stride] = buf0[17];
+  output[18 * stride] = buf0[9];
+  output[19 * stride] = buf0[25];
+  output[20 * stride] = buf0[5];
+  output[21 * stride] = buf0[21];
+  output[22 * stride] = buf0[13];
+  output[23 * stride] = buf0[29];
+  output[24 * stride] = buf0[3];
+  output[25 * stride] = buf0[19];
+  output[26 * stride] = buf0[11];
+  output[27 * stride] = buf0[27];
+  output[28 * stride] = buf0[7];
+  output[29 * stride] = buf0[23];
+  output[30 * stride] = buf0[15];
+  output[31 * stride] = buf0[31];
 }
 
 void av1_fadst4_new_sse4_1(const __m128i *input, __m128i *output,
@@ -394,9 +394,8 @@
   }
 }
 
-void av1_fdct64_new_sse4_1(const __m128i *input, __m128i *output,
-                           int8_t cos_bit, const int instride,
-                           const int outstride) {
+void av1_fdct64_new_sse4_1(__m128i *input, __m128i *output, int8_t cos_bit,
+                           const int instride, const int outstride) {
   const int32_t *cospi = cospi_arr(cos_bit);
   const __m128i __rounding = _mm_set1_epi32(1 << (cos_bit - 1));
 
diff --git a/av1/encoder/x86/av1_fwd_txfm2d_sse4.c b/av1/encoder/x86/av1_fwd_txfm2d_sse4.c
index 8ec0256..c2903d2 100644
--- a/av1/encoder/x86/av1_fwd_txfm2d_sse4.c
+++ b/av1/encoder/x86/av1_fwd_txfm2d_sse4.c
@@ -29,31 +29,22 @@
   }
 }
 
-typedef void (*TxfmFuncSSE2)(const __m128i *input, __m128i *output,
+typedef void (*TxfmFuncSSE2)(__m128i *input, __m128i *output,
                              const int8_t cos_bit, const int8_t *stage_range);
 
-static void fdct32_new_sse4_1(const __m128i *input, __m128i *output,
+static void fdct32_new_sse4_1(__m128i *input, __m128i *output,
                               const int8_t cos_bit, const int8_t *stage_range) {
   const int txfm_size = 32;
   const int num_per_128 = 4;
-  __m128i buf0[32];
-  __m128i buf1[32];
   int col_num = txfm_size / num_per_128;
   int col;
   (void)stage_range;
   for (col = 0; col < col_num; col++) {
-    int j;
-    for (j = 0; j < 32; ++j) {
-      buf0[j] = input[j * col_num + col];
-    }
-    av1_fdct32_new_sse4_1(buf0, buf1, cos_bit);
-    for (j = 0; j < 32; ++j) {
-      output[j * col_num + col] = buf1[j];
-    }
+    av1_fdct32_new_sse4_1((input + col), (output + col), cos_bit, col_num);
   }
 }
 
-static void fdct64_new_sse4_1(const __m128i *input, __m128i *output,
+static void fdct64_new_sse4_1(__m128i *input, __m128i *output,
                               const int8_t cos_bit, const int8_t *stage_range) {
   const int txfm_size = 64;
   const int num_per_128 = 4;
@@ -142,7 +133,7 @@
 
   txfm2d_size_128 = (col_num >> 1) * (txfm_size >> 1);
   av1_round_shift_array_32_sse4_1(out_128, buf_128, txfm2d_size_128, -shift[2]);
-  transpose_32x32(buf_128, out_128);
+  transpose_8nx8n(buf_128, out_128, 32, 32);
 }
 
 void av1_fwd_txfm2d_32x32_sse4_1(const int16_t *input, int32_t *output,
@@ -317,8 +308,8 @@
       bufA[j] = _mm_cvtepi16_epi32(buf[j]);
       bufB[j] = _mm_cvtepi16_epi32(_mm_unpackhi_epi64(buf[j], buf[j]));
     }
-    av1_fdct32_new_sse4_1(bufA, bufA, cos_bit_row);
-    av1_fdct32_new_sse4_1(bufB, bufB, cos_bit_row);
+    av1_fdct32_new_sse4_1(bufA, bufA, cos_bit_row, 1);
+    av1_fdct32_new_sse4_1(bufB, bufB, cos_bit_row, 1);
     av1_round_shift_rect_array_32_sse4_1(bufA, bufA, 32, -shift[2], NewSqrt2);
     av1_round_shift_rect_array_32_sse4_1(bufB, bufB, 32, -shift[2], NewSqrt2);
 
diff --git a/av1/encoder/x86/av1_txfm1d_sse4.h b/av1/encoder/x86/av1_txfm1d_sse4.h
index 6df2a8b..ccdc327 100644
--- a/av1/encoder/x86/av1_txfm1d_sse4.h
+++ b/av1/encoder/x86/av1_txfm1d_sse4.h
@@ -26,11 +26,10 @@
                           const int8_t cos_bit, const int8_t *stage_range);
 void av1_fdct16_new_sse4_1(const __m128i *input, __m128i *output,
                            const int8_t cos_bit, const int8_t *stage_range);
-void av1_fdct32_new_sse4_1(const __m128i *input, __m128i *output,
-                           int8_t cos_bit);
-void av1_fdct64_new_sse4_1(const __m128i *input, __m128i *output,
-                           int8_t cos_bit, const int instride,
-                           const int outstride);
+void av1_fdct32_new_sse4_1(__m128i *input, __m128i *output, int cos_bit,
+                           const int stride);
+void av1_fdct64_new_sse4_1(__m128i *input, __m128i *output, int8_t cos_bit,
+                           const int instride, const int outstride);
 
 void av1_fadst4_new_sse4_1(const __m128i *input, __m128i *output,
                            const int8_t cos_bit, const int8_t *stage_range);
diff --git a/av1/encoder/x86/highbd_fwd_txfm_sse4.c b/av1/encoder/x86/highbd_fwd_txfm_sse4.c
index 535485a..201565e 100644
--- a/av1/encoder/x86/highbd_fwd_txfm_sse4.c
+++ b/av1/encoder/x86/highbd_fwd_txfm_sse4.c
@@ -986,6 +986,19 @@
   load_buffer_8x8(botL, out + 16, stride, flipud, fliplr, shift);
 }
 
+static INLINE void load_buffer_32x16(const int16_t *input, __m128i *out,
+                                     int stride, int flipud, int fliplr,
+                                     int shift) {
+  const int16_t *in = input;
+  __m128i *output = out;
+  for (int col = 0; col < 16; col++) {
+    in = input + col * stride;
+    output = out + col * 8;
+    load_buffer_4x4(in, output, 4, flipud, fliplr, shift);
+    load_buffer_4x4((in + 16), (output + 4), 4, flipud, fliplr, shift);
+  }
+}
+
 static void fdct16x16_sse4_1(__m128i *in, __m128i *out, int bit,
                              const int col_num) {
   const int32_t *cospi = cospi_arr(bit);
@@ -1719,6 +1732,44 @@
   NULL              // H_FLIPADST
 };
 
+static const fwd_transform_1d_sse4_1 col_highbd_txfm8x32_arr[TX_TYPES] = {
+  av1_fdct32_new_sse4_1,  // DCT_DCT
+  NULL,                   // ADST_DCT
+  NULL,                   // DCT_ADST
+  NULL,                   // ADST_ADST
+  NULL,                   // FLIPADST_DCT
+  NULL,                   // DCT_FLIPADST
+  NULL,                   // FLIPADST_FLIPADST
+  NULL,                   // ADST_FLIPADST
+  NULL,                   // FLIPADST_ADST
+  NULL,                   // IDTX
+  NULL,                   // V_DCT
+  NULL,                   // H_DCT
+  NULL,                   // V_ADST
+  NULL,                   // H_ADST
+  NULL,                   // V_FLIPADST
+  NULL                    // H_FLIPADST
+};
+
+static const fwd_transform_1d_sse4_1 row_highbd_txfm8x32_arr[TX_TYPES] = {
+  fdct16x16_sse4_1,  // DCT_DCT
+  NULL,              // ADST_DCT
+  NULL,              // DCT_ADST
+  NULL,              // ADST_ADST
+  NULL,              // FLIPADST_DCT
+  NULL,              // DCT_FLIPADST
+  NULL,              // FLIPADST_FLIPADST
+  NULL,              // ADST_FLIPADST
+  NULL,              // FLIPADST_ADST
+  NULL,              // IDTX
+  NULL,              // V_DCT
+  NULL,              // H_DCT
+  NULL,              // V_ADST
+  NULL,              // H_ADST
+  NULL,              // V_FLIPADST
+  NULL               // H_FLIPADST
+};
+
 void av1_fwd_txfm2d_16x8_sse4_1(const int16_t *input, int32_t *coeff,
                                 int stride, TX_TYPE tx_type, int bd) {
   __m128i in[32], out[32];
@@ -1781,3 +1832,65 @@
 
   (void)bd;
 }
+
+void av1_fwd_txfm2d_16x32_sse4_1(const int16_t *input, int32_t *coeff,
+                                 int stride, TX_TYPE tx_type, int bd) {
+  assert(DCT_DCT == tx_type);
+  __m128i in[128];
+  __m128i *outcoef128 = (__m128i *)coeff;
+  const int8_t *shift = fwd_txfm_shift_ls[TX_16X32];
+  const int txw_idx = get_txw_idx(TX_16X32);
+  const int txh_idx = get_txh_idx(TX_16X32);
+  const fwd_transform_1d_sse4_1 col_txfm = col_highbd_txfm8x32_arr[tx_type];
+  const fwd_transform_1d_sse4_1 row_txfm = row_highbd_txfm8x32_arr[tx_type];
+  int bitcol = fwd_cos_bit_col[txw_idx][txh_idx];
+  int bitrow = fwd_cos_bit_row[txw_idx][txh_idx];
+
+  // column transform
+  load_buffer_16x16(input, in, stride, 0, 0, shift[0]);
+  load_buffer_16x16(input + 16 * stride, in + 64, stride, 0, 0, shift[0]);
+
+  for (int i = 0; i < 4; i++) {
+    col_txfm((in + i), (in + i), bitcol, 4);
+  }
+  col_txfm_16x16_rounding(&in[0], -shift[1]);
+  col_txfm_16x16_rounding(&in[64], -shift[1]);
+  transpose_8nx8n(in, outcoef128, 16, 32);
+
+  // row transform
+  row_txfm(outcoef128, in, bitrow, 8);
+  transpose_8nx8n(in, outcoef128, 32, 16);
+  av1_round_shift_rect_array_32_sse4_1(outcoef128, outcoef128, 128, -shift[2],
+                                       NewSqrt2);
+  (void)bd;
+}
+
+void av1_fwd_txfm2d_32x16_sse4_1(const int16_t *input, int32_t *coeff,
+                                 int stride, TX_TYPE tx_type, int bd) {
+  assert(DCT_DCT == tx_type);
+  __m128i in[128];
+  __m128i *outcoef128 = (__m128i *)coeff;
+  const int8_t *shift = fwd_txfm_shift_ls[TX_32X16];
+  const int txw_idx = get_txw_idx(TX_32X16);
+  const int txh_idx = get_txh_idx(TX_32X16);
+  const fwd_transform_1d_sse4_1 col_txfm = row_highbd_txfm8x32_arr[tx_type];
+  const fwd_transform_1d_sse4_1 row_txfm = col_highbd_txfm8x32_arr[tx_type];
+  int bitcol = fwd_cos_bit_col[txw_idx][txh_idx];
+  int bitrow = fwd_cos_bit_row[txw_idx][txh_idx];
+
+  // column transform
+  load_buffer_32x16(input, in, stride, 0, 0, shift[0]);
+  col_txfm(in, in, bitcol, 8);
+  col_txfm_16x16_rounding(&in[0], -shift[1]);
+  col_txfm_16x16_rounding(&in[64], -shift[1]);
+  transpose_8nx8n(in, outcoef128, 32, 16);
+
+  // row transform
+  for (int i = 0; i < 4; i++) {
+    row_txfm((outcoef128 + i), (in + i), bitrow, 4);
+  }
+  transpose_8nx8n(in, outcoef128, 16, 32);
+  av1_round_shift_rect_array_32_sse4_1(outcoef128, outcoef128, 128, -shift[2],
+                                       NewSqrt2);
+  (void)bd;
+}