Implement av1_lowbd_fwd_txfm2d_4x8_sse2

Add an SSE2 path for the forward 4x8 transform: an 8-lane 4-point ADST
(fadst8x4_new_sse2) and identity (fidentity8x4_new_sse2) for the row
pass, a width-4 rectangular store helper, and the 4x8 transform table.
Generalize store_rect_buffer_16bit_to_32bit_8x8() into _w4/_w8 variants
that take an explicit output size, and update the existing callers.
Enable the TX_4X8 case in av1_fwd_txfm2d_test.

Change-Id: Ife016b5adeeb26071fa20ce6f66b7b52074b36d6
diff --git a/av1/common/x86/av1_txfm_sse2.h b/av1/common/x86/av1_txfm_sse2.h
index 10415dd..ae3578f 100644
--- a/av1/common/x86/av1_txfm_sse2.h
+++ b/av1/common/x86/av1_txfm_sse2.h
@@ -87,6 +87,17 @@
   return _mm_srai_epi32(b, NewSqrt2Bits);
 }
 
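+// Scale the 4 low 16-bit coefficients of |a| by NewSqrt2 and round-shift the
+// products by NewSqrt2Bits (i.e., multiply by ~sqrt(2) for the rectangular
+// transform sizes), then store them as 32-bit values at |b|.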
+static INLINE void store_rect_16bit_to_32bit_w4(const __m128i a,
+                                                int32_t *const b) {
+  const __m128i one = _mm_set1_epi16(1);
+  const __m128i a_lo = _mm_unpacklo_epi16(a, one);
+  const __m128i b_lo = scale_round_sse2(a_lo, NewSqrt2);
+  _mm_store_si128((__m128i *)b, b_lo);
+}
+
 static INLINE void store_rect_16bit_to_32bit(const __m128i a,
                                              int32_t *const b) {
   const __m128i one = _mm_set1_epi16(1);
@@ -170,10 +181,20 @@
   }
 }
 
-static INLINE void store_rect_buffer_16bit_to_32bit_8x8(const __m128i *const in,
-                                                        int32_t *const out,
-                                                        const int stride) {
-  for (int i = 0; i < 8; ++i) {
+static INLINE void store_rect_buffer_16bit_to_32bit_w4(const __m128i *const in,
+                                                       int32_t *const out,
+                                                       const int stride,
+                                                       const int out_size) {
+  for (int i = 0; i < out_size; ++i) {
+    store_rect_16bit_to_32bit_w4(in[i], out + i * stride);
+  }
+}
+
+static INLINE void store_rect_buffer_16bit_to_32bit_w8(const __m128i *const in,
+                                                       int32_t *const out,
+                                                       const int stride,
+                                                       const int out_size) {
+  for (int i = 0; i < out_size; ++i) {
     store_rect_16bit_to_32bit(in[i], out + i * stride);
   }
 }
@@ -210,6 +231,9 @@
 void av1_lowbd_fwd_txfm2d_4x4_sse2(const int16_t *input, int32_t *output,
                                    int stride, TX_TYPE tx_type, int bd);
 
+void av1_lowbd_fwd_txfm2d_4x8_sse2(const int16_t *input, int32_t *output,
+                                   int stride, TX_TYPE tx_type, int bd);
+
 void av1_lowbd_fwd_txfm2d_8x8_sse2(const int16_t *input, int32_t *output,
                                    int stride, TX_TYPE tx_type, int bd);
 
diff --git a/av1/encoder/x86/av1_fwd_txfm_sse2.c b/av1/encoder/x86/av1_fwd_txfm_sse2.c
index cbe62ae..7afe8c0 100644
--- a/av1/encoder/x86/av1_fwd_txfm_sse2.c
+++ b/av1/encoder/x86/av1_fwd_txfm_sse2.c
@@ -1354,6 +1354,90 @@
   output[3] = _mm_srli_si128(output[1], 8);
 }
 
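+// 4-point ADST over 8 lanes (one length-4 transform per 16-bit lane), used
+// here as the row transform of the 4x8 block. Pairs of sinpi[] products are
+// formed with 16-bit madds on the interleaved lo/hi halves of the inputs;
+// the s0..s6 terms in the comments are the 4-point ADST intermediates.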
+static void fadst8x4_new_sse2(const __m128i *input, __m128i *output,
+                              int8_t cos_bit) {
+  const int32_t *sinpi = sinpi_arr(cos_bit);
+  const __m128i sinpi_p01_p02 = pair_set_epi16(sinpi[1], sinpi[2]);
+  const __m128i sinpi_p04_m01 = pair_set_epi16(sinpi[4], -sinpi[1]);
+  const __m128i sinpi_p03_p04 = pair_set_epi16(sinpi[3], sinpi[4]);
+  const __m128i sinpi_m03_p02 = pair_set_epi16(-sinpi[3], sinpi[2]);
+  const __m128i sinpi_p03_p03 = _mm_set1_epi16((int16_t)sinpi[3]);
+  const __m128i __zero = _mm_setzero_si128();
+  const __m128i __rounding = _mm_set1_epi32(1 << (cos_bit - 1));
+  const __m128i in7 = _mm_add_epi16(input[0], input[1]);  // x0 + x1
+  __m128i u_lo[8], u_hi[8], v_lo[8], v_hi[8];
+
+  u_lo[0] = _mm_unpacklo_epi16(input[0], input[1]);
+  u_hi[0] = _mm_unpackhi_epi16(input[0], input[1]);
+  u_lo[1] = _mm_unpacklo_epi16(input[2], input[3]);
+  u_hi[1] = _mm_unpackhi_epi16(input[2], input[3]);
+  u_lo[2] = _mm_unpacklo_epi16(in7, __zero);
+  u_hi[2] = _mm_unpackhi_epi16(in7, __zero);
+  u_lo[3] = _mm_unpacklo_epi16(input[2], __zero);
+  u_hi[3] = _mm_unpackhi_epi16(input[2], __zero);
+  u_lo[4] = _mm_unpacklo_epi16(input[3], __zero);
+  u_hi[4] = _mm_unpackhi_epi16(input[3], __zero);
+
+  v_lo[0] = _mm_madd_epi16(u_lo[0], sinpi_p01_p02);  // s0 + s2
+  v_hi[0] = _mm_madd_epi16(u_hi[0], sinpi_p01_p02);  // s0 + s2
+  v_lo[1] = _mm_madd_epi16(u_lo[1], sinpi_p03_p04);  // s4 + s5
+  v_hi[1] = _mm_madd_epi16(u_hi[1], sinpi_p03_p04);  // s4 + s5
+  v_lo[2] = _mm_madd_epi16(u_lo[2], sinpi_p03_p03);  // sinpi[3] * (x0 + x1)
+  v_hi[2] = _mm_madd_epi16(u_hi[2], sinpi_p03_p03);  // sinpi[3] * (x0 + x1)
+  v_lo[3] = _mm_madd_epi16(u_lo[0], sinpi_p04_m01);  // s1 - s3
+  v_hi[3] = _mm_madd_epi16(u_hi[0], sinpi_p04_m01);  // s1 - s3
+  v_lo[4] = _mm_madd_epi16(u_lo[1], sinpi_m03_p02);  // -s4 + s6
+  v_hi[4] = _mm_madd_epi16(u_hi[1], sinpi_m03_p02);  // -s4 + s6
+  v_lo[5] = _mm_madd_epi16(u_lo[3], sinpi_p03_p03);  // s4
+  v_hi[5] = _mm_madd_epi16(u_hi[3], sinpi_p03_p03);  // s4
+  v_lo[6] = _mm_madd_epi16(u_lo[4], sinpi_p03_p03);  // sinpi[3] * x3
+  v_hi[6] = _mm_madd_epi16(u_hi[4], sinpi_p03_p03);  // sinpi[3] * x3
+
+  u_lo[0] = _mm_add_epi32(v_lo[0], v_lo[1]);
+  u_hi[0] = _mm_add_epi32(v_hi[0], v_hi[1]);
+  u_lo[1] = _mm_sub_epi32(v_lo[2], v_lo[6]);
+  u_hi[1] = _mm_sub_epi32(v_hi[2], v_hi[6]);
+  u_lo[2] = _mm_add_epi32(v_lo[3], v_lo[4]);
+  u_hi[2] = _mm_add_epi32(v_hi[3], v_hi[4]);
+  u_lo[3] = _mm_sub_epi32(u_lo[2], u_lo[0]);
+  u_hi[3] = _mm_sub_epi32(u_hi[2], u_hi[0]);
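+  // u[6] = u[2] - u[0] + 3 * s4; 3 * s4 is computed as (s4 << 2) - s4
+  // because SSE2 has no packed 32-bit low multiply.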
+  u_lo[4] = _mm_slli_epi32(v_lo[5], 2);
+  u_hi[4] = _mm_slli_epi32(v_hi[5], 2);
+  u_lo[5] = _mm_sub_epi32(u_lo[4], v_lo[5]);
+  u_hi[5] = _mm_sub_epi32(u_hi[4], v_hi[5]);
+  u_lo[6] = _mm_add_epi32(u_lo[3], u_lo[5]);
+  u_hi[6] = _mm_add_epi32(u_hi[3], u_hi[5]);
+
+  v_lo[0] = _mm_add_epi32(u_lo[0], __rounding);
+  v_hi[0] = _mm_add_epi32(u_hi[0], __rounding);
+  v_lo[1] = _mm_add_epi32(u_lo[1], __rounding);
+  v_hi[1] = _mm_add_epi32(u_hi[1], __rounding);
+  v_lo[2] = _mm_add_epi32(u_lo[2], __rounding);
+  v_hi[2] = _mm_add_epi32(u_hi[2], __rounding);
+  v_lo[3] = _mm_add_epi32(u_lo[6], __rounding);
+  v_hi[3] = _mm_add_epi32(u_hi[6], __rounding);
+
+  u_lo[0] = _mm_srai_epi32(v_lo[0], cos_bit);
+  u_hi[0] = _mm_srai_epi32(v_hi[0], cos_bit);
+  u_lo[1] = _mm_srai_epi32(v_lo[1], cos_bit);
+  u_hi[1] = _mm_srai_epi32(v_hi[1], cos_bit);
+  u_lo[2] = _mm_srai_epi32(v_lo[2], cos_bit);
+  u_hi[2] = _mm_srai_epi32(v_hi[2], cos_bit);
+  u_lo[3] = _mm_srai_epi32(v_lo[3], cos_bit);
+  u_hi[3] = _mm_srai_epi32(v_hi[3], cos_bit);
+
+  output[0] = _mm_packs_epi32(u_lo[0], u_hi[0]);
+  output[1] = _mm_packs_epi32(u_lo[1], u_hi[1]);
+  output[2] = _mm_packs_epi32(u_lo[2], u_hi[2]);
+  output[3] = _mm_packs_epi32(u_lo[3], u_hi[3]);
+}
+
 void fadst8_new_sse2(const __m128i *input, __m128i *output, int8_t cos_bit) {
   const int32_t *cospi = cospi_arr(cos_bit);
   const __m128i __zero = _mm_setzero_si128();
@@ -1639,6 +1723,23 @@
   }
 }
 
+static INLINE void fidentity8x4_new_sse2(const __m128i *const input,
+                                         __m128i *const output,
+                                         const int8_t cos_bit) {
+  (void)cos_bit;
+  const __m128i one = _mm_set1_epi16(1);
+
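+  // 4-point identity transform over all 8 lanes at once: scale the input
+  // by NewSqrt2 / 2^NewSqrt2Bits (~sqrt(2)).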
+  for (int i = 0; i < 4; ++i) {
+    const __m128i a_lo = _mm_unpacklo_epi16(input[i], one);
+    const __m128i a_hi = _mm_unpackhi_epi16(input[i], one);
+    const __m128i b_lo = scale_round_sse2(a_lo, NewSqrt2);
+    const __m128i b_hi = scale_round_sse2(a_hi, NewSqrt2);
+    output[i] = _mm_packs_epi32(b_lo, b_hi);
+  }
+}
+
 static INLINE void fidentity8_new_sse2(const __m128i *input, __m128i *output,
                                        int8_t cos_bit) {
   (void)cos_bit;
@@ -1694,6 +1795,25 @@
   { fidentity4_new_sse2, fadst4_new_sse2 },      // H_FLIPADST
 };
 
+static const transform_2d_sse2 txfm4x8_arr[16] = {
+  { fdct8_new_sse2, fdct4_new_sse2 },              // DCT_DCT
+  { fadst8_new_sse2, fdct4_new_sse2 },             // ADST_DCT
+  { fdct8_new_sse2, fadst8x4_new_sse2 },           // DCT_ADST
+  { fadst8_new_sse2, fadst8x4_new_sse2 },          // ADST_ADST
+  { fadst8_new_sse2, fdct4_new_sse2 },             // FLIPADST_DCT
+  { fdct8_new_sse2, fadst8x4_new_sse2 },           // DCT_FLIPADST
+  { fadst8_new_sse2, fadst8x4_new_sse2 },          // FLIPADST_FLIPADST
+  { fadst8_new_sse2, fadst8x4_new_sse2 },          // ADST_FLIPADST
+  { fadst8_new_sse2, fadst8x4_new_sse2 },          // FLIPADST_ADST
+  { fidentity8_new_sse2, fidentity8x4_new_sse2 },  // IDTX
+  { fdct8_new_sse2, fidentity8x4_new_sse2 },       // V_DCT
+  { fidentity8_new_sse2, fdct4_new_sse2 },         // H_DCT
+  { fadst8_new_sse2, fidentity8x4_new_sse2 },      // V_ADST
+  { fidentity8_new_sse2, fadst8x4_new_sse2 },      // H_ADST
+  { fadst8_new_sse2, fidentity8x4_new_sse2 },      // V_FLIPADST
+  { fidentity8_new_sse2, fadst8x4_new_sse2 },      // H_FLIPADST
+};
+
 static const transform_2d_sse2 txfm8_arr[] = {
   { fdct8_new_sse2, fdct8_new_sse2 },            // DCT_DCT
   { fadst8_new_sse2, fdct8_new_sse2 },           // ADST_DCT
@@ -1789,6 +1909,47 @@
   store_buffer_16bit_to_32bit_w4(buf, output, width, height);
 }
 
+void av1_lowbd_fwd_txfm2d_4x8_sse2(const int16_t *input, int32_t *output,
+                                   int stride, TX_TYPE tx_type, int bd) {
+  (void)bd;
+  __m128i buf0[8], buf1[8], *buf;
+  const int8_t *shift = fwd_txfm_shift_ls[TX_4X8];
+  const int txw_idx = get_txw_idx(TX_4X8);
+  const int txh_idx = get_txh_idx(TX_4X8);
+  const int cos_bit_col = fwd_cos_bit_col[txw_idx][txh_idx];
+  const int cos_bit_row = fwd_cos_bit_row[txw_idx][txh_idx];
+  const int width = 4;
+  const int height = 8;
+  const transform_1d_sse2 col_txfm = txfm4x8_arr[tx_type].col;
+  const transform_1d_sse2 row_txfm = txfm4x8_arr[tx_type].row;
+  int ud_flip, lr_flip;
+
+  get_flip_cfg(tx_type, &ud_flip, &lr_flip);
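+  // Column pass: length-8 transform down each of the 4 input columns.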
+  if (ud_flip) {
+    load_buffer_16bit_to_16bit_w4_flip(input, stride, buf0, height);
+  } else {
+    load_buffer_16bit_to_16bit_w4(input, stride, buf0, height);
+  }
+  round_shift_16bit(buf0, height, shift[0]);
+  col_txfm(buf0, buf0, cos_bit_col);
+  round_shift_16bit(buf0, height, shift[1]);
+  transpose_16bit_4x8(buf0, buf1);
+
+  if (lr_flip) {
+    buf = buf0;
+    flip_buf_sse2(buf1, buf, width);
+  } else {
+    buf = buf1;
+  }
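+  // Row pass: length-4 transform on each of the 8 rows; the rect store
+  // below applies the extra sqrt(2) scaling for the rectangular size.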
+  row_txfm(buf, buf, cos_bit_row);
+  round_shift_16bit(buf, width, shift[2]);
+  transpose_16bit_8x4(buf, buf);
+  store_rect_buffer_16bit_to_32bit_w4(buf, output, width, height);
+}
+
 void av1_lowbd_fwd_txfm2d_8x8_sse2(const int16_t *input, int32_t *output,
                                    int stride, TX_TYPE tx_type, int bd) {
   (void)bd;
@@ -1864,7 +2025,7 @@
     row_txfm(buf, buf, cos_bit_row);
     round_shift_16bit(buf, width, shift[2]);
     transpose_16bit_8x8(buf, buf);
-    store_rect_buffer_16bit_to_32bit_8x8(buf, output + 8 * width * i, width);
+    store_rect_buffer_16bit_to_32bit_w8(buf, output + 8 * width * i, width, 8);
   }
 }
 
@@ -1950,9 +2111,9 @@
   row_txfm(buf, buf, cos_bit_row);
   round_shift_16bit(buf, width, shift[2]);
   transpose_16bit_8x8(buf, buf);
-  store_rect_buffer_16bit_to_32bit_8x8(buf, output, width);
+  store_rect_buffer_16bit_to_32bit_w8(buf, output, width, height);
   transpose_16bit_8x8(buf + 8, buf + 8);
-  store_rect_buffer_16bit_to_32bit_8x8(buf + 8, output + 8, width);
+  store_rect_buffer_16bit_to_32bit_w8(buf + 8, output + 8, width, height);
 }
 
 void av1_lowbd_fwd_txfm2d_16x16_sse2(const int16_t *input, int32_t *output,
@@ -2045,10 +2206,11 @@
       row_txfm(buf, buf, cos_bit_row);
       round_shift_16bit(buf, width, shift[2]);
       transpose_16bit_8x8(buf, buf);
-      store_rect_buffer_16bit_to_32bit_8x8(buf, output + 8 * width * i, width);
+      store_rect_buffer_16bit_to_32bit_w8(buf, output + 8 * width * i, width,
+                                          8);
       transpose_16bit_8x8(buf + 8, buf + 8);
-      store_rect_buffer_16bit_to_32bit_8x8(buf + 8, output + 8 * width * i + 8,
-                                           width);
+      store_rect_buffer_16bit_to_32bit_w8(buf + 8, output + 8 * width * i + 8,
+                                          width, 8);
     }
   } else {
     av1_fwd_txfm2d_16x32_c(input, output, stride, tx_type, bd);
@@ -2154,16 +2316,17 @@
       row_txfm(buf, buf, cos_bit_row);
       round_shift_16bit(buf, width, shift[2]);
       transpose_16bit_8x8(buf, buf);
-      store_rect_buffer_16bit_to_32bit_8x8(buf, output + 8 * width * i, width);
+      store_rect_buffer_16bit_to_32bit_w8(buf, output + 8 * width * i, width,
+                                          8);
       transpose_16bit_8x8(buf + 8, buf + 8);
-      store_rect_buffer_16bit_to_32bit_8x8(buf + 8, output + 8 * width * i + 8,
-                                           width);
+      store_rect_buffer_16bit_to_32bit_w8(buf + 8, output + 8 * width * i + 8,
+                                          width, 8);
       transpose_16bit_8x8(buf + 16, buf + 16);
-      store_rect_buffer_16bit_to_32bit_8x8(buf + 16,
-                                           output + 8 * width * i + 16, width);
+      store_rect_buffer_16bit_to_32bit_w8(buf + 16, output + 8 * width * i + 16,
+                                          width, 8);
       transpose_16bit_8x8(buf + 24, buf + 24);
-      store_rect_buffer_16bit_to_32bit_8x8(buf + 24,
-                                           output + 8 * width * i + 24, width);
+      store_rect_buffer_16bit_to_32bit_w8(buf + 24, output + 8 * width * i + 24,
+                                          width, 8);
     }
   } else {
     av1_fwd_txfm2d_32x16_c(input, output, stride, tx_type, bd);
diff --git a/test/av1_fwd_txfm2d_test.cc b/test/av1_fwd_txfm2d_test.cc
index 7c8c0f6..0068daa 100644
--- a/test/av1_fwd_txfm2d_test.cc
+++ b/test/av1_fwd_txfm2d_test.cc
@@ -218,7 +218,7 @@
 #if CONFIG_TX64X64
   { NULL, NULL },                                             // TX_64X64
 #endif                                                        // CONFIG_TX64X64
-  { NULL, NULL },                                             // TX_4X8
+  { av1_fwd_txfm2d_4x8_c, av1_lowbd_fwd_txfm2d_4x8_sse2 },    // TX_4X8
   { NULL, NULL },                                             // TX_8X4
   { av1_fwd_txfm2d_8x16_c, av1_lowbd_fwd_txfm2d_8x16_sse2 },  // TX_8X16
   { av1_fwd_txfm2d_16x8_c, av1_lowbd_fwd_txfm2d_16x8_sse2 },  // TX_16X8