Add av1_lowbd_inv_txfm2d_add_{16,32}_sse2

Implement av1_lowbd_inv_txfm2d_add_32x32_sse2
Implement av1_lowbd_inv_txfm2d_add_16x32_sse2
Implement av1_lowbd_inv_txfm2d_add_32x16_sse2

Change-Id: I1b5dc29d0cf75d5d43f4869b729f480f03534ea9
diff --git a/av1/common/x86/av1_inv_txfm_sse2.c b/av1/common/x86/av1_inv_txfm_sse2.c
index bb8aa1b..772fc06 100644
--- a/av1/common/x86/av1_inv_txfm_sse2.c
+++ b/av1/common/x86/av1_inv_txfm_sse2.c
@@ -1624,6 +1624,16 @@
     output[i] = _mm_packs_epi32(c_lo, c_hi);
   }
 }
+
+static void iidentity32_new_sse2(const __m128i *input, __m128i *output,
+                                 int8_t cos_bit) {
+  (void)cos_bit;
+  for (int i = 0; i < 32; ++i) {
+    output[i] = _mm_adds_epi16(input[i], input[i]);
+    output[i] = _mm_adds_epi16(output[i], output[i]);
+  }
+}
+
 static INLINE __m128i lowbd_get_recon_8x8_sse2(const __m128i pred,
                                                __m128i res) {
   const __m128i zero = _mm_setzero_si128();
@@ -1651,6 +1661,10 @@
   idct16_new_sse2, iadst16_new_sse2, iadst16_new_sse2, iidentity16_new_sse2,
 };
 
+static const transform_1d_sse2 lowbd_txfm32_1d_arr[TX_TYPES_1D] = {
+  idct32_new_sse2, NULL, NULL, iidentity32_new_sse2,
+};
+
 void av1_lowbd_inv_txfm2d_add_8x8_sse2(const int32_t *input, uint8_t *output,
                                        int stride, TX_TYPE tx_type, int bd) {
   (void)bd;
@@ -1658,12 +1672,12 @@
   const int8_t *shift = inv_txfm_shift_ls[TX_8X8];
   const int txw_idx = get_txw_idx(TX_8X8);
   const int txh_idx = get_txh_idx(TX_8X8);
-  const int cos_bit_col = inv_cos_bit_col[txw_idx][txh_idx];
   const int cos_bit_row = inv_cos_bit_row[txw_idx][txh_idx];
+  const int cos_bit_col = inv_cos_bit_col[txw_idx][txh_idx];
   const int buf_size = 8;
 
-  const transform_1d_sse2 col_txfm = lowbd_txfm8_1d_arr[vtx_tab[tx_type]];
   const transform_1d_sse2 row_txfm = lowbd_txfm8_1d_arr[htx_tab[tx_type]];
+  const transform_1d_sse2 col_txfm = lowbd_txfm8_1d_arr[vtx_tab[tx_type]];
 
   int ud_flip, lr_flip;
   get_flip_cfg(tx_type, &ud_flip, &lr_flip);
@@ -1712,36 +1726,38 @@
   const int8_t *shift = inv_txfm_shift_ls[tx_size];
   const int txw_idx = get_txw_idx(tx_size);
   const int txh_idx = get_txh_idx(tx_size);
-  const int cos_bit_col = inv_cos_bit_col[txw_idx][txh_idx];
   const int cos_bit_row = inv_cos_bit_row[txw_idx][txh_idx];
+  const int cos_bit_col = inv_cos_bit_col[txw_idx][txh_idx];
   const int buf_size_w = tx_size_wide[tx_size];
   const int buf_size_h = tx_size_high[tx_size];
 
-  const transform_1d_sse2 col_txfm = lowbd_txfm16_1d_arr[vtx_tab[tx_type]];
   const transform_1d_sse2 row_txfm = lowbd_txfm16_1d_arr[htx_tab[tx_type]];
+  const transform_1d_sse2 col_txfm = lowbd_txfm16_1d_arr[vtx_tab[tx_type]];
 
   int ud_flip, lr_flip;
   get_flip_cfg(tx_type, &ud_flip, &lr_flip);
   // i=0 do up 16x8,i=1 do down 16x8
   for (int i = 0; i < 2; i++) {
     __m128i buf0[16];
-    // left 8x8
-    load_buffer_32bit_to_16bit(input + i * buf_size_w * 8, buf_size_w, buf0, 8);
-    transpose_16bit_8x8(buf0, buf0);
-    // right 8x8
-    load_buffer_32bit_to_16bit(input + i * buf_size_w * 8 + 8, buf_size_w,
-                               buf0 + 8, 8);
-    transpose_16bit_8x8(buf0 + 8, buf0 + 8);
+    const int32_t *input_row = input + i * buf_size_w * 8;
+    for (int j = 0; j < 2; ++j) {
+      __m128i *buf0_cur = buf0 + 8 * j;
+      load_buffer_32bit_to_16bit(input_row + j * 8, buf_size_w, buf0_cur, 8);
+      transpose_16bit_8x8(buf0_cur, buf0_cur);
+    }
+
     row_txfm(buf0, buf0, cos_bit_row);
     round_shift_16bit(buf0, buf_size_w, shift[0]);
     if (lr_flip) {
-      __m128i temp[16];
-      flip_buf_sse2(buf0, temp, buf_size_w);
-      transpose_16bit_8x8(temp, buf1 + i * 8);
-      transpose_16bit_8x8(temp + 8, buf1 + i * 8 + buf_size_w);
+      for (int j = 0; j < 2; ++j) {
+        __m128i temp[8];
+        flip_buf_sse2(buf0 + 8 * j, temp, 8);
+        transpose_16bit_8x8(temp, buf1 + i * 8 + (1 - j) * buf_size_w);
+      }
     } else {
-      transpose_16bit_8x8(buf0, buf1 + i * 8);
-      transpose_16bit_8x8(buf0 + 8, buf1 + i * 8 + buf_size_w);
+      for (int j = 0; j < 2; ++j) {
+        transpose_16bit_8x8(buf0 + j * 8, buf1 + i * 8 + j * buf_size_h);
+      }
     }
   }
   for (int i = 0; i < 2; i++) {
@@ -1775,22 +1791,23 @@
   const int8_t *shift = inv_txfm_shift_ls[tx_size];
   const int txw_idx = get_txw_idx(tx_size);
   const int txh_idx = get_txh_idx(tx_size);
-  const int cos_bit_col = inv_cos_bit_col[txw_idx][txh_idx];
   const int cos_bit_row = inv_cos_bit_row[txw_idx][txh_idx];
+  const int cos_bit_col = inv_cos_bit_col[txw_idx][txh_idx];
   const int buf_size_w = tx_size_wide[tx_size];
   const int buf_size_h = tx_size_high[tx_size];
 
-  const transform_1d_sse2 col_txfm = lowbd_txfm8_1d_arr[vtx_tab[tx_type]];
   const transform_1d_sse2 row_txfm = lowbd_txfm16_1d_arr[htx_tab[tx_type]];
+  const transform_1d_sse2 col_txfm = lowbd_txfm8_1d_arr[vtx_tab[tx_type]];
 
   int ud_flip, lr_flip;
   get_flip_cfg(tx_type, &ud_flip, &lr_flip);
-  // left 8x8
-  load_buffer_32bit_to_16bit(input, buf_size_w, buf0, 8);
-  transpose_16bit_8x8(buf0, buf0);
-  // right 8x8
-  load_buffer_32bit_to_16bit(input + 8, buf_size_w, buf0 + 8, 8);
-  transpose_16bit_8x8(buf0 + 8, buf0 + 8);
+  const int32_t *input_row = input;
+  for (int j = 0; j < 2; ++j) {
+    __m128i *buf0_cur = buf0 + 8 * j;
+    load_buffer_32bit_to_16bit(input_row + j * 8, buf_size_w, buf0_cur, 8);
+    transpose_16bit_8x8(buf0_cur, buf0_cur);
+  }
+
   round_shift_sse2(buf0, buf0, buf_size_w);  // rect special code
   row_txfm(buf0, buf0, cos_bit_row);
   round_shift_16bit(buf0, buf_size_w, shift[0]);
@@ -1819,13 +1836,13 @@
   const int8_t *shift = inv_txfm_shift_ls[tx_size];
   const int txw_idx = get_txw_idx(tx_size);
   const int txh_idx = get_txh_idx(tx_size);
-  const int cos_bit_col = inv_cos_bit_col[txw_idx][txh_idx];
   const int cos_bit_row = inv_cos_bit_row[txw_idx][txh_idx];
+  const int cos_bit_col = inv_cos_bit_col[txw_idx][txh_idx];
   const int buf_size_w = tx_size_wide[tx_size];
   const int buf_size_h = tx_size_high[tx_size];
 
-  const transform_1d_sse2 col_txfm = lowbd_txfm16_1d_arr[vtx_tab[tx_type]];
   const transform_1d_sse2 row_txfm = lowbd_txfm8_1d_arr[htx_tab[tx_type]];
+  const transform_1d_sse2 col_txfm = lowbd_txfm16_1d_arr[vtx_tab[tx_type]];
 
   int ud_flip, lr_flip;
   get_flip_cfg(tx_type, &ud_flip, &lr_flip);
@@ -1850,6 +1867,167 @@
   lowbd_write_buffer_8xn_sse2(buf0, output, stride, ud_flip, buf_size_h);
 }
 
+void av1_lowbd_inv_txfm2d_add_32x32_sse2(const int32_t *input, uint8_t *output,
+                                         int stride, TX_TYPE tx_type, int bd) {
+  (void)bd;
+  __m128i buf1[32 * 4];
+  const TX_SIZE tx_size = TX_32X32;
+  const int8_t *shift = inv_txfm_shift_ls[tx_size];
+  const int txw_idx = get_txw_idx(tx_size);
+  const int txh_idx = get_txh_idx(tx_size);
+  const int cos_bit_col = inv_cos_bit_col[txw_idx][txh_idx];
+  const int cos_bit_row = inv_cos_bit_row[txw_idx][txh_idx];
+  const int buf_size_w = tx_size_wide[tx_size];
+  const int buf_size_h = tx_size_high[tx_size];
+
+  const transform_1d_sse2 row_txfm = lowbd_txfm32_1d_arr[htx_tab[tx_type]];
+  const transform_1d_sse2 col_txfm = lowbd_txfm32_1d_arr[vtx_tab[tx_type]];
+
+  assert(col_txfm != NULL);
+  assert(row_txfm != NULL);
+  int ud_flip, lr_flip;
+  get_flip_cfg(tx_type, &ud_flip, &lr_flip);
+  for (int i = 0; i < 4; i++) {
+    __m128i buf0[32];
+    const int32_t *input_row = input + i * buf_size_w * 8;
+    for (int j = 0; j < 4; ++j) {
+      __m128i *buf0_cur = buf0 + j * 8;
+      load_buffer_32bit_to_16bit(input_row + j * 8, buf_size_w, buf0_cur, 8);
+      transpose_16bit_8x8(buf0_cur, buf0_cur);
+    }
+
+    row_txfm(buf0, buf0, cos_bit_row);
+    round_shift_16bit(buf0, buf_size_w, shift[0]);
+    __m128i *buf1_cur = buf1 + i * 8;
+    if (lr_flip) {
+      for (int j = 0; j < 4; ++j) {
+        __m128i temp[8];
+        flip_buf_sse2(buf0 + 8 * j, temp, 8);
+        transpose_16bit_8x8(temp, buf1_cur + buf_size_w * (3 - j));
+      }
+    } else {
+      for (int j = 0; j < 4; ++j) {
+        transpose_16bit_8x8(buf0 + 8 * j, buf1_cur + buf_size_h * j);
+      }
+    }
+  }
+  for (int i = 0; i < 4; i++) {
+    col_txfm(buf1 + i * buf_size_h, buf1 + i * buf_size_h, cos_bit_col);
+    round_shift_16bit(buf1 + i * buf_size_h, buf_size_h, shift[1]);
+  }
+  lowbd_write_buffer_16xn_sse2(buf1, output, stride, ud_flip, buf_size_h);
+  lowbd_write_buffer_16xn_sse2(buf1 + 64, output + 16, stride, ud_flip,
+                               buf_size_h);
+}
+
+void av1_lowbd_inv_txfm2d_add_32x16_sse2(const int32_t *input, uint8_t *output,
+                                         int stride, TX_TYPE tx_type, int bd) {
+  (void)bd;
+  __m128i buf1[32 * 2];
+  const TX_SIZE tx_size = TX_32X16;
+  const int8_t *shift = inv_txfm_shift_ls[tx_size];
+  const int txw_idx = get_txw_idx(tx_size);
+  const int txh_idx = get_txh_idx(tx_size);
+  const int cos_bit_col = inv_cos_bit_col[txw_idx][txh_idx];
+  const int cos_bit_row = inv_cos_bit_row[txw_idx][txh_idx];
+  const int buf_size_w = tx_size_wide[tx_size];
+  const int buf_size_h = tx_size_high[tx_size];
+  const int buf_size_w_div8 = buf_size_w >> 3;
+  const int buf_size_h_div8 = buf_size_h >> 3;
+
+  const transform_1d_sse2 row_txfm = lowbd_txfm32_1d_arr[htx_tab[tx_type]];
+  const transform_1d_sse2 col_txfm = lowbd_txfm16_1d_arr[vtx_tab[tx_type]];
+
+  int ud_flip, lr_flip;
+  get_flip_cfg(tx_type, &ud_flip, &lr_flip);
+  for (int i = 0; i < buf_size_h_div8; i++) {
+    __m128i buf0[32];
+    const int32_t *input_row = input + i * buf_size_w * 8;
+    for (int j = 0; j < buf_size_w_div8; ++j) {
+      __m128i *buf0_cur = buf0 + j * 8;
+      load_buffer_32bit_to_16bit(input_row + j * 8, buf_size_w, buf0_cur, 8);
+      transpose_16bit_8x8(buf0_cur, buf0_cur);
+    }
+    round_shift_sse2(buf0, buf0, buf_size_w);  // rect special code
+    row_txfm(buf0, buf0, cos_bit_row);
+    round_shift_16bit(buf0, buf_size_w, shift[0]);
+    __m128i *buf1_cur = buf1 + i * 8;
+    if (lr_flip) {
+      for (int j = 0; j < buf_size_w_div8; ++j) {
+        __m128i temp[8];
+        flip_buf_sse2(buf0 + 8 * j, temp, 8);
+        transpose_16bit_8x8(temp,
+                            buf1_cur + buf_size_h * (buf_size_w_div8 - 1 - j));
+      }
+    } else {
+      for (int j = 0; j < buf_size_w_div8; ++j) {
+        transpose_16bit_8x8(buf0 + 8 * j, buf1_cur + buf_size_h * j);
+      }
+    }
+  }
+  for (int i = 0; i < buf_size_w_div8; i++) {
+    col_txfm(buf1 + i * buf_size_h, buf1 + i * buf_size_h, cos_bit_col);
+    round_shift_16bit(buf1 + i * buf_size_h, buf_size_h, shift[1]);
+  }
+  lowbd_write_buffer_16xn_sse2(buf1, output, stride, ud_flip, buf_size_h);
+  lowbd_write_buffer_16xn_sse2(buf1 + 32, output + 16, stride, ud_flip,
+                               buf_size_h);
+}
+
+void av1_lowbd_inv_txfm2d_add_16x32_sse2(const int32_t *input, uint8_t *output,
+                                         int stride, TX_TYPE tx_type, int bd) {
+  (void)bd;
+  __m128i buf1[32 * 4];
+  const TX_SIZE tx_size = TX_16X32;
+  const int8_t *shift = inv_txfm_shift_ls[tx_size];
+  const int txw_idx = get_txw_idx(tx_size);
+  const int txh_idx = get_txh_idx(tx_size);
+  const int cos_bit_col = inv_cos_bit_col[txw_idx][txh_idx];
+  const int cos_bit_row = inv_cos_bit_row[txw_idx][txh_idx];
+  const int buf_size_w = tx_size_wide[tx_size];
+  const int buf_size_h = tx_size_high[tx_size];
+  const int buf_size_w_div8 = buf_size_w >> 3;
+  const int buf_size_h_div8 = buf_size_h >> 3;
+
+  const transform_1d_sse2 row_txfm = lowbd_txfm16_1d_arr[htx_tab[tx_type]];
+  const transform_1d_sse2 col_txfm = lowbd_txfm32_1d_arr[vtx_tab[tx_type]];
+
+  assert(col_txfm != NULL);
+  assert(row_txfm != NULL);
+  int ud_flip, lr_flip;
+  get_flip_cfg(tx_type, &ud_flip, &lr_flip);
+  for (int i = 0; i < buf_size_h_div8; i++) {
+    __m128i buf0[16];  // buffer __m128i with count of buf_size_w
+    const int32_t *input_row = input + i * buf_size_w * 8;
+    for (int j = 0; j < buf_size_w_div8; ++j) {
+      __m128i *buf0_cur = buf0 + j * 8;
+      load_buffer_32bit_to_16bit(input_row + j * 8, buf_size_w, buf0_cur, 8);
+      transpose_16bit_8x8(buf0_cur, buf0_cur);
+    }
+    round_shift_sse2(buf0, buf0, buf_size_w);  // rect special code
+    row_txfm(buf0, buf0, cos_bit_row);
+    round_shift_16bit(buf0, buf_size_w, shift[0]);
+    __m128i *buf1_cur = buf1 + i * 8;
+    if (lr_flip) {
+      for (int j = 0; j < buf_size_w_div8; ++j) {
+        __m128i temp[8];
+        flip_buf_sse2(buf0 + 8 * j, temp, 8);
+        transpose_16bit_8x8(temp,
+                            buf1_cur + buf_size_h * (buf_size_w_div8 - 1 - j));
+      }
+    } else {
+      for (int j = 0; j < buf_size_w_div8; ++j) {
+        transpose_16bit_8x8(buf0 + 8 * j, buf1_cur + buf_size_h * j);
+      }
+    }
+  }
+  for (int i = 0; i < buf_size_w_div8; i++) {
+    col_txfm(buf1 + i * buf_size_h, buf1 + i * buf_size_h, cos_bit_col);
+    round_shift_16bit(buf1 + i * buf_size_h, buf_size_h, shift[1]);
+  }
+  lowbd_write_buffer_16xn_sse2(buf1, output, stride, ud_flip, buf_size_h);
+}
+
 typedef void (*inv_txfm_func)(const int32_t *input, uint8_t *output, int stride,
                               TX_TYPE tx_type, int bd);
 
diff --git a/av1/common/x86/av1_txfm_sse2.h b/av1/common/x86/av1_txfm_sse2.h
index f11f962..8fba6cf 100644
--- a/av1/common/x86/av1_txfm_sse2.h
+++ b/av1/common/x86/av1_txfm_sse2.h
@@ -205,6 +205,15 @@
 
 void av1_lowbd_inv_txfm2d_add_8x16_sse2(const int32_t *input, uint8_t *output,
                                         int stride, TX_TYPE tx_type, int bd);
+
+void av1_lowbd_inv_txfm2d_add_32x32_sse2(const int32_t *input, uint8_t *output,
+                                         int stride, TX_TYPE tx_type, int bd);
+
+void av1_lowbd_inv_txfm2d_add_32x16_sse2(const int32_t *input, uint8_t *output,
+                                         int stride, TX_TYPE tx_type, int bd);
+
+void av1_lowbd_inv_txfm2d_add_16x32_sse2(const int32_t *input, uint8_t *output,
+                                         int stride, TX_TYPE tx_type, int bd);
 #ifdef __cplusplus
 }
 #endif  // __cplusplus
diff --git a/test/av1_inv_txfm2d_test.cc b/test/av1_inv_txfm2d_test.cc
index 75104b0..cfd77fb 100644
--- a/test/av1_inv_txfm2d_test.cc
+++ b/test/av1_inv_txfm2d_test.cc
@@ -229,7 +229,18 @@
     ref_func_ = libaom_test::inv_txfm_func_ls[tx_size_];
     target_func_ = target_list[tx_size_];
   }
-
+  int ValidTypeSize(TX_TYPE tx_type) {
+    int rows = tx_size_wide[tx_size_];
+    int cols = tx_size_high[tx_size_];
+    TX_TYPE_1D vtype = vtx_tab[tx_type];
+    TX_TYPE_1D htype = htx_tab[tx_type];
+    if (rows == 32 && (htype == ADST_1D || htype == FLIPADST_1D)) {
+      return 0;
+    } else if (cols == 32 && (vtype == ADST_1D || vtype == FLIPADST_1D)) {
+      return 0;
+    }
+    return 1;
+  }
   void RunAV1InvTxfm2dTest(TX_TYPE tx_type, int run_times);
 
  private:
@@ -253,6 +264,7 @@
   int stride = BLK_WIDTH;
   int rows = tx_size_high[tx_size_];
   int cols = tx_size_wide[tx_size_];
+
   ACMRandom rnd(ACMRandom::DeterministicSeed());
   int randTimes = run_times == 1 ? 500 : 2;
   for (int cnt = 0; cnt < randTimes; ++cnt) {
@@ -265,6 +277,7 @@
       }
     }
     fwd_func_(input, inv_input, stride, tx_type, bd);
+
     aom_usec_timer timer;
     aom_usec_timer_start(&timer);
     for (int i = 0; i < run_times; ++i) {
@@ -296,12 +309,17 @@
 
 TEST_P(AV1LbdInvTxfm2d, match) {
   for (int i = 0; i < (int)TX_TYPES; ++i) {
-    RunAV1InvTxfm2dTest((TX_TYPE)i, 1);
+    if (ValidTypeSize((TX_TYPE)(i))) {
+      RunAV1InvTxfm2dTest((TX_TYPE)i, 1);
+    }
   }
 }
+
 TEST_P(AV1LbdInvTxfm2d, DISABLED_Speed) {
   for (int i = 0; i < (int)TX_TYPES; ++i) {
-    RunAV1InvTxfm2dTest((TX_TYPE)i, 10000000);
+    if (ValidTypeSize((TX_TYPE)(i))) {
+      RunAV1InvTxfm2dTest((TX_TYPE)i, 1000000);
+    }
   }
 }
 
@@ -312,16 +330,16 @@
   NULL,                                 // TX_4X4
   av1_lowbd_inv_txfm2d_add_8x8_sse2,    // TX_8X8
   av1_lowbd_inv_txfm2d_add_16x16_sse2,  // TX_16X16
-  NULL,                                 // TX_32X32
+  av1_lowbd_inv_txfm2d_add_32x32_sse2,  // TX_32X32
 #if CONFIG_TX64X64
-  NULL,                                // TX_64X64
-#endif                                 // CONFIG_TX64X64
-  NULL,                                // TX_4X8
-  NULL,                                // TX_8X4
-  av1_lowbd_inv_txfm2d_add_8x16_sse2,  // TX_8X16
-  av1_lowbd_inv_txfm2d_add_16x8_sse2,  // TX_16X8
-  NULL,                                // TX_16X32
-  NULL,                                // TX_32X16
+  NULL,                                 // TX_64X64
+#endif                                  // CONFIG_TX64X64
+  NULL,                                 // TX_4X8
+  NULL,                                 // TX_8X4
+  av1_lowbd_inv_txfm2d_add_8x16_sse2,   // TX_8X16
+  av1_lowbd_inv_txfm2d_add_16x8_sse2,   // TX_16X8
+  av1_lowbd_inv_txfm2d_add_16x32_sse2,  // TX_16X32
+  av1_lowbd_inv_txfm2d_add_32x16_sse2,  // TX_32X16
 #if CONFIG_TX64X64
   NULL,  // TX_32X64
   NULL,  // TX_64X32