Implement av1_lowbd_fwd_txfm2d_4x4_sse2
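
Add an SSE2 version of the 4x4 low bitdepth forward transform:
- fadst4_new_sse2 and fidentity4_new_sse2 1D kernels, plus a txfm4_arr
  table mapping each TX_TYPE to its column/row transform pair
- load/store helpers for 4-wide buffers and a shared scale_round_sse2
  helper, reused by store_rect_16bit_to_32bit and fidentity16_new_sse2
- hook the new function into fwd_txfm2d_func_ls and the SSE2 test list;
  lossless 4x4 blocks keep using the av1_lowbd_fwd_txfm_c fallback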

Change-Id: I297b1572e89d5668529855f2436b47af95ea2f80
diff --git a/av1/common/x86/av1_txfm_sse2.h b/av1/common/x86/av1_txfm_sse2.h
index efbcae7..10415dd 100644
--- a/av1/common/x86/av1_txfm_sse2.h
+++ b/av1/common/x86/av1_txfm_sse2.h
@@ -64,6 +64,13 @@
   return _mm_packs_epi32(a_low, a_low);
 }
 
+// Store 4 16-bit values. Sign extend the values.
+static INLINE void store_16bit_to_32bit_w4(const __m128i a, int32_t *const b) {
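+  // Unpacking a with itself duplicates each 16-bit lane; the arithmetic
+  // shift right by 16 then leaves each value sign extended to 32 bits.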
+  const __m128i a_lo = _mm_unpacklo_epi16(a, a);
+  const __m128i a_1 = _mm_srai_epi32(a_lo, 16);
+  _mm_store_si128((__m128i *)b, a_1);
+}
+
 // Store 8 16-bit values. Sign extend the values.
 static INLINE void store_16bit_to_32bit(__m128i a, int32_t *b) {
   const __m128i a_lo = _mm_unpacklo_epi16(a, a);
@@ -74,20 +81,39 @@
   _mm_store_si128((__m128i *)(b + 4), a_2);
 }
 
-static INLINE void store_rect_16bit_to_32bit(__m128i a, int32_t *b) {
-  const __m128i sqrt2_coef = _mm_set1_epi16(NewSqrt2);
-  const __m128i rounding = _mm_set1_epi32(1 << (NewSqrt2Bits - 1));
-  __m128i a_lo, a_hi;
-  a_lo = _mm_unpacklo_epi16(a, _mm_setzero_si128());
-  a_hi = _mm_unpackhi_epi16(a, _mm_setzero_si128());
-  a_lo = _mm_madd_epi16(a_lo, sqrt2_coef);
-  a_hi = _mm_madd_epi16(a_hi, sqrt2_coef);
-  a_lo = _mm_add_epi32(a_lo, rounding);
-  a_hi = _mm_add_epi32(a_hi, rounding);
-  a_lo = _mm_srai_epi32(a_lo, NewSqrt2Bits);
-  a_hi = _mm_srai_epi32(a_hi, NewSqrt2Bits);
-  _mm_store_si128((__m128i *)b, a_lo);
-  _mm_store_si128((__m128i *)(b + 4), a_hi);
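+// The input holds (value, 1) 16-bit pairs, so _mm_madd_epi16 against
+// (scale, 1 << (NewSqrt2Bits - 1)) yields value * scale + rounding in each
+// 32-bit lane; the arithmetic shift completes the rounded scaling.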
+static INLINE __m128i scale_round_sse2(const __m128i a, const int scale) {
+  const __m128i scale_rounding = pair_set_epi16(scale, 1 << (NewSqrt2Bits - 1));
+  const __m128i b = _mm_madd_epi16(a, scale_rounding);
+  return _mm_srai_epi32(b, NewSqrt2Bits);
+}
+
+static INLINE void store_rect_16bit_to_32bit(const __m128i a,
+                                             int32_t *const b) {
+  const __m128i one = _mm_set1_epi16(1);
+  const __m128i a_lo = _mm_unpacklo_epi16(a, one);
+  const __m128i a_hi = _mm_unpackhi_epi16(a, one);
+  const __m128i b_lo = scale_round_sse2(a_lo, NewSqrt2);
+  const __m128i b_hi = scale_round_sse2(a_hi, NewSqrt2);
+  _mm_store_si128((__m128i *)b, b_lo);
+  _mm_store_si128((__m128i *)(b + 4), b_hi);
+}
+
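+// Load 4 16-bit values per row; _mm_loadl_epi64 reads only the low 64 bits
+// and zeroes the upper half of the register.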
+static INLINE void load_buffer_16bit_to_16bit_w4(const int16_t *const in,
+                                                 const int stride,
+                                                 __m128i *const out,
+                                                 const int out_size) {
+  for (int i = 0; i < out_size; ++i) {
+    out[i] = _mm_loadl_epi64((const __m128i *)(in + i * stride));
+  }
+}
+
+static INLINE void load_buffer_16bit_to_16bit_w4_flip(const int16_t *const in,
+                                                      const int stride,
+                                                      __m128i *const out,
+                                                      const int out_size) {
+  for (int i = 0; i < out_size; ++i) {
+    out[out_size - i - 1] = _mm_loadl_epi64((const __m128i *)(in + i * stride));
+  }
 }
 
 static INLINE void load_buffer_16bit_to_16bit(const int16_t *in, int stride,
@@ -127,6 +153,15 @@
   }
 }
 
+static INLINE void store_buffer_16bit_to_32bit_w4(const __m128i *const in,
+                                                  int32_t *const out,
+                                                  const int stride,
+                                                  const int out_size) {
+  for (int i = 0; i < out_size; ++i) {
+    store_16bit_to_32bit_w4(in[i], out + i * stride);
+  }
+}
+
 static INLINE void store_buffer_16bit_to_32bit_8x8(const __m128i *const in,
                                                    int32_t *const out,
                                                    const int stride) {
@@ -172,6 +207,9 @@
   }
 }
 
+void av1_lowbd_fwd_txfm2d_4x4_sse2(const int16_t *input, int32_t *output,
+                                   int stride, TX_TYPE tx_type, int bd);
+
 void av1_lowbd_fwd_txfm2d_8x8_sse2(const int16_t *input, int32_t *output,
                                    int stride, TX_TYPE tx_type, int bd);
 
diff --git a/av1/encoder/x86/av1_fwd_txfm_sse2.c b/av1/encoder/x86/av1_fwd_txfm_sse2.c
index 9100f1e..cbe62ae 100644
--- a/av1/encoder/x86/av1_fwd_txfm_sse2.c
+++ b/av1/encoder/x86/av1_fwd_txfm_sse2.c
@@ -1303,6 +1303,57 @@
   output[63] = x10[63];
 }
 
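+// 4-point forward ADST. The inputs are interleaved into 16-bit pairs so each
+// _mm_madd_epi16 produces a sum of two sinpi products per 32-bit lane; only
+// the low four 16-bit lanes of the input and output registers carry data.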
+static void fadst4_new_sse2(const __m128i *input, __m128i *output,
+                            int8_t cos_bit) {
+  const int32_t *sinpi = sinpi_arr(cos_bit);
+  const __m128i sinpi_p01_p02 = pair_set_epi16(sinpi[1], sinpi[2]);
+  const __m128i sinpi_p04_m01 = pair_set_epi16(sinpi[4], -sinpi[1]);
+  const __m128i sinpi_p03_p04 = pair_set_epi16(sinpi[3], sinpi[4]);
+  const __m128i sinpi_m03_p02 = pair_set_epi16(-sinpi[3], sinpi[2]);
+  const __m128i sinpi_p03_p03 = _mm_set1_epi16((int16_t)sinpi[3]);
+  const __m128i __zero = _mm_setzero_si128();
+  const __m128i __rounding = _mm_set1_epi32(1 << (cos_bit - 1));
+  const __m128i in7 = _mm_add_epi16(input[0], input[1]);
+  __m128i u[8], v[8];
+
+  u[0] = _mm_unpacklo_epi16(input[0], input[1]);
+  u[1] = _mm_unpacklo_epi16(input[2], input[3]);
+  u[2] = _mm_unpacklo_epi16(in7, __zero);
+  u[3] = _mm_unpacklo_epi16(input[2], __zero);
+  u[4] = _mm_unpacklo_epi16(input[3], __zero);
+
+  v[0] = _mm_madd_epi16(u[0], sinpi_p01_p02);  // s0 + s2
+  v[1] = _mm_madd_epi16(u[1], sinpi_p03_p04);  // s4 + s5
+  v[2] = _mm_madd_epi16(u[2], sinpi_p03_p03);  // x1
+  v[3] = _mm_madd_epi16(u[0], sinpi_p04_m01);  // s1 - s3
+  v[4] = _mm_madd_epi16(u[1], sinpi_m03_p02);  // -s4 + s6
+  v[5] = _mm_madd_epi16(u[3], sinpi_p03_p03);  // s4
+  v[6] = _mm_madd_epi16(u[4], sinpi_p03_p03);
+
+  u[0] = _mm_add_epi32(v[0], v[1]);
+  u[1] = _mm_sub_epi32(v[2], v[6]);
+  u[2] = _mm_add_epi32(v[3], v[4]);
+  u[3] = _mm_sub_epi32(u[2], u[0]);
+  u[4] = _mm_slli_epi32(v[5], 2);
+  u[5] = _mm_sub_epi32(u[4], v[5]);
+  u[6] = _mm_add_epi32(u[3], u[5]);
+
+  v[0] = _mm_add_epi32(u[0], __rounding);
+  v[1] = _mm_add_epi32(u[1], __rounding);
+  v[2] = _mm_add_epi32(u[2], __rounding);
+  v[3] = _mm_add_epi32(u[6], __rounding);
+
+  u[0] = _mm_srai_epi32(v[0], cos_bit);
+  u[1] = _mm_srai_epi32(v[1], cos_bit);
+  u[2] = _mm_srai_epi32(v[2], cos_bit);
+  u[3] = _mm_srai_epi32(v[3], cos_bit);
+
+  output[0] = _mm_packs_epi32(u[0], u[2]);
+  output[1] = _mm_packs_epi32(u[1], u[3]);
+  output[2] = _mm_srli_si128(output[0], 8);
+  output[3] = _mm_srli_si128(output[1], 8);
+}
+
 void fadst8_new_sse2(const __m128i *input, __m128i *output, int8_t cos_bit) {
   const int32_t *cospi = cospi_arr(cos_bit);
   const __m128i __zero = _mm_setzero_si128();
@@ -1575,6 +1626,19 @@
   output[15] = x8[0];
 }
 
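+// 4-point identity transform. The AV1 1D identity for size 4 scales by
+// sqrt(2), applied here as NewSqrt2 / (1 << NewSqrt2Bits) with rounding.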
+static INLINE void fidentity4_new_sse2(const __m128i *const input,
+                                       __m128i *const output,
+                                       const int8_t cos_bit) {
+  (void)cos_bit;
+  const __m128i one = _mm_set1_epi16(1);
+
+  for (int i = 0; i < 4; ++i) {
+    const __m128i a = _mm_unpacklo_epi16(input[i], one);
+    const __m128i b = scale_round_sse2(a, NewSqrt2);
+    output[i] = _mm_packs_epi32(b, b);
+  }
+}
+
 static INLINE void fidentity8_new_sse2(const __m128i *input, __m128i *output,
                                        int8_t cos_bit) {
   (void)cos_bit;
@@ -1592,18 +1656,14 @@
 static INLINE void fidentity16_new_sse2(const __m128i *input, __m128i *output,
                                         int8_t cos_bit) {
   (void)cos_bit;
-  const __m128i scale = _mm_set1_epi16(2 * NewSqrt2);
-  const __m128i rounding = _mm_set1_epi16(1 << (NewSqrt2Bits - 1));
   const __m128i one = _mm_set1_epi16(1);
-  const __m128i scale_rounding = _mm_unpacklo_epi16(scale, rounding);
+
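+  // The 16-point identity scales by 2 * sqrt(2), hence 2 * NewSqrt2 below.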
   for (int i = 0; i < 16; ++i) {
-    __m128i a_lo = _mm_unpacklo_epi16(input[i], one);
-    __m128i a_hi = _mm_unpackhi_epi16(input[i], one);
-    __m128i b_lo = _mm_madd_epi16(a_lo, scale_rounding);
-    __m128i b_hi = _mm_madd_epi16(a_hi, scale_rounding);
-    __m128i c_lo = _mm_srai_epi32(b_lo, NewSqrt2Bits);
-    __m128i c_hi = _mm_srai_epi32(b_hi, NewSqrt2Bits);
-    output[i] = _mm_packs_epi32(c_lo, c_hi);
+    const __m128i a_lo = _mm_unpacklo_epi16(input[i], one);
+    const __m128i a_hi = _mm_unpackhi_epi16(input[i], one);
+    const __m128i b_lo = scale_round_sse2(a_lo, 2 * NewSqrt2);
+    const __m128i b_hi = scale_round_sse2(a_hi, 2 * NewSqrt2);
+    output[i] = _mm_packs_epi32(b_lo, b_hi);
   }
 }
 
@@ -1615,6 +1675,25 @@
   }
 }
 
+static const transform_2d_sse2 txfm4_arr[] = {
+  { fdct4_new_sse2, fdct4_new_sse2 },            // DCT_DCT
+  { fadst4_new_sse2, fdct4_new_sse2 },           // ADST_DCT
+  { fdct4_new_sse2, fadst4_new_sse2 },           // DCT_ADST
+  { fadst4_new_sse2, fadst4_new_sse2 },          // ADST_ADST
+  { fadst4_new_sse2, fdct4_new_sse2 },           // FLIPADST_DCT
+  { fdct4_new_sse2, fadst4_new_sse2 },           // DCT_FLIPADST
+  { fadst4_new_sse2, fadst4_new_sse2 },          // FLIPADST_FLIPADST
+  { fadst4_new_sse2, fadst4_new_sse2 },          // ADST_FLIPADST
+  { fadst4_new_sse2, fadst4_new_sse2 },          // FLIPADST_ADST
+  { fidentity4_new_sse2, fidentity4_new_sse2 },  // IDTX
+  { fdct4_new_sse2, fidentity4_new_sse2 },       // V_DCT
+  { fidentity4_new_sse2, fdct4_new_sse2 },       // H_DCT
+  { fadst4_new_sse2, fidentity4_new_sse2 },      // V_ADST
+  { fidentity4_new_sse2, fadst4_new_sse2 },      // H_ADST
+  { fadst4_new_sse2, fidentity4_new_sse2 },      // V_FLIPADST
+  { fidentity4_new_sse2, fadst4_new_sse2 },      // H_FLIPADST
+};
+
 static const transform_2d_sse2 txfm8_arr[] = {
   { fdct8_new_sse2, fdct8_new_sse2 },            // DCT_DCT
   { fadst8_new_sse2, fdct8_new_sse2 },           // ADST_DCT
@@ -1672,40 +1751,79 @@
   { NULL, NULL },                                  // H_FLIPADST
 };
 
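+// 4x4 forward transform: load the rows (optionally flipped vertically),
+// shift, column transform, shift, transpose, optionally flip horizontally,
+// row transform, shift, transpose again, then store 32-bit coefficients.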
+void av1_lowbd_fwd_txfm2d_4x4_sse2(const int16_t *input, int32_t *output,
+                                   int stride, TX_TYPE tx_type, int bd) {
+  (void)bd;
+  __m128i buf0[4], buf1[4], *buf;
+  const int8_t *shift = fwd_txfm_shift_ls[TX_4X4];
+  const int txw_idx = get_txw_idx(TX_4X4);
+  const int txh_idx = get_txh_idx(TX_4X4);
+  const int cos_bit_col = fwd_cos_bit_col[txw_idx][txh_idx];
+  const int cos_bit_row = fwd_cos_bit_row[txw_idx][txh_idx];
+  const int width = 4;
+  const int height = 4;
+  const transform_1d_sse2 col_txfm = txfm4_arr[tx_type].col;
+  const transform_1d_sse2 row_txfm = txfm4_arr[tx_type].row;
+  int ud_flip, lr_flip;
+
+  get_flip_cfg(tx_type, &ud_flip, &lr_flip);
+  if (ud_flip) {
+    load_buffer_16bit_to_16bit_w4_flip(input, stride, buf0, height);
+  } else {
+    load_buffer_16bit_to_16bit_w4(input, stride, buf0, height);
+  }
+  round_shift_16bit(buf0, height, shift[0]);
+  col_txfm(buf0, buf0, cos_bit_col);
+  round_shift_16bit(buf0, height, shift[1]);
+  transpose_16bit_4x4(buf0, buf1);
+
+  if (lr_flip) {
+    buf = buf0;
+    flip_buf_sse2(buf1, buf, width);
+  } else {
+    buf = buf1;
+  }
+  row_txfm(buf, buf, cos_bit_row);
+  round_shift_16bit(buf, width, shift[2]);
+  transpose_16bit_4x4(buf, buf);
+  store_buffer_16bit_to_32bit_w4(buf, output, width, height);
+}
+
 void av1_lowbd_fwd_txfm2d_8x8_sse2(const int16_t *input, int32_t *output,
                                    int stride, TX_TYPE tx_type, int bd) {
-  (void)stride;
   (void)bd;
-  __m128i buf[8];
+  __m128i buf0[8], buf1[8], *buf;
   const int8_t *shift = fwd_txfm_shift_ls[TX_8X8];
   const int txw_idx = get_txw_idx(TX_8X8);
   const int txh_idx = get_txh_idx(TX_8X8);
   const int cos_bit_col = fwd_cos_bit_col[txw_idx][txh_idx];
   const int cos_bit_row = fwd_cos_bit_row[txw_idx][txh_idx];
-  const int buf_size = 8;
-
+  const int width = 8;
+  const int height = 8;
   const transform_1d_sse2 col_txfm = txfm8_arr[tx_type].col;
   const transform_1d_sse2 row_txfm = txfm8_arr[tx_type].row;
   int ud_flip, lr_flip;
+
   get_flip_cfg(tx_type, &ud_flip, &lr_flip);
   if (ud_flip)
-    load_buffer_16bit_to_16bit_flip(input, stride, buf, buf_size);
+    load_buffer_16bit_to_16bit_flip(input, stride, buf0, height);
   else
-    load_buffer_16bit_to_16bit(input, stride, buf, buf_size);
-  round_shift_16bit(buf, 8, shift[0]);
-  col_txfm(buf, buf, cos_bit_col);
-  round_shift_16bit(buf, 8, shift[1]);
+    load_buffer_16bit_to_16bit(input, stride, buf0, height);
+  round_shift_16bit(buf0, height, shift[0]);
+  col_txfm(buf0, buf0, cos_bit_col);
+  round_shift_16bit(buf0, height, shift[1]);
+  transpose_16bit_8x8(buf0, buf1);
+
   if (lr_flip) {
-    __m128i tmp[8];
-    transpose_16bit_8x8(buf, tmp);
-    flip_buf_sse2(tmp, buf, 8);
+    buf = buf0;
+    flip_buf_sse2(buf1, buf, width);
   } else {
-    transpose_16bit_8x8(buf, buf);
+    buf = buf1;
   }
   row_txfm(buf, buf, cos_bit_row);
-  round_shift_16bit(buf, 8, shift[2]);
+  round_shift_16bit(buf, width, shift[2]);
   transpose_16bit_8x8(buf, buf);
-  store_buffer_16bit_to_32bit_8x8(buf, output, buf_size);
+  store_buffer_16bit_to_32bit_8x8(buf, output, width);
 }
 
 void av1_lowbd_fwd_txfm2d_8x16_sse2(const int16_t *input, int32_t *output,
@@ -1721,10 +1839,9 @@
   const int height = 16;
   const transform_1d_sse2 col_txfm = txfm16_arr[tx_type].col;
   const transform_1d_sse2 row_txfm = txfm8_arr[tx_type].row;
-
   int ud_flip, lr_flip;
-  get_flip_cfg(tx_type, &ud_flip, &lr_flip);
 
+  get_flip_cfg(tx_type, &ud_flip, &lr_flip);
   if (ud_flip) {
     load_buffer_16bit_to_16bit_flip(input, stride, buf0, height);
   } else {
@@ -1764,10 +1881,9 @@
   const int height = 32;
   const transform_1d_sse2 col_txfm = txfm32_arr[tx_type].col;
   const transform_1d_sse2 row_txfm = txfm8_arr[tx_type].row;
-
   int ud_flip, lr_flip;
-  get_flip_cfg(tx_type, &ud_flip, &lr_flip);
 
+  get_flip_cfg(tx_type, &ud_flip, &lr_flip);
   if (ud_flip) {
     load_buffer_16bit_to_16bit_flip(input, stride, buf0, height);
   } else {
@@ -1809,11 +1925,10 @@
   const int height = 8;
   const transform_1d_sse2 col_txfm = txfm8_arr[tx_type].col;
   const transform_1d_sse2 row_txfm = txfm16_arr[tx_type].row;
-
   __m128i *buf;
   int ud_flip, lr_flip;
-  get_flip_cfg(tx_type, &ud_flip, &lr_flip);
 
+  get_flip_cfg(tx_type, &ud_flip, &lr_flip);
   for (int i = 0; i < 2; i++) {
     if (ud_flip) {
       load_buffer_16bit_to_16bit_flip(input + 8 * i, stride, buf0, height);
@@ -1853,10 +1968,9 @@
   const int height = 16;
   const transform_1d_sse2 col_txfm = txfm16_arr[tx_type].col;
   const transform_1d_sse2 row_txfm = txfm16_arr[tx_type].row;
-
   int ud_flip, lr_flip;
-  get_flip_cfg(tx_type, &ud_flip, &lr_flip);
 
+  get_flip_cfg(tx_type, &ud_flip, &lr_flip);
   for (int i = 0; i < 2; i++) {
     if (ud_flip) {
       load_buffer_16bit_to_16bit_flip(input + 8 * i, stride, buf0, height);
@@ -2120,7 +2234,7 @@
                                   int stride, TX_TYPE tx_type, int bd);
 
 FwdTxfm2dFuncSSE2 fwd_txfm2d_func_ls[TX_SIZES_ALL] = {
-  NULL,                             // 4x4 transform
+  av1_lowbd_fwd_txfm2d_4x4_sse2,    // 4x4 transform
   av1_lowbd_fwd_txfm2d_8x8_sse2,    // 8x8 transform
   av1_lowbd_fwd_txfm2d_16x16_sse2,  // 16x16 transform
   av1_lowbd_fwd_txfm2d_32x32_sse2,  // 32x32 transform
@@ -2151,9 +2265,10 @@
                              int diff_stride, TxfmParam *txfm_param) {
   FwdTxfm2dFuncSSE2 fwd_txfm2d_func = fwd_txfm2d_func_ls[txfm_param->tx_size];
 
-  if (fwd_txfm2d_func)
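+  // Lossless 4x4 blocks use the Walsh-Hadamard transform, which the SSE2
+  // 4x4 path does not implement, so they also take the C fallback.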
+  if ((fwd_txfm2d_func == NULL) ||
+      (txfm_param->lossless && txfm_param->tx_size == TX_4X4))
+    av1_lowbd_fwd_txfm_c(src_diff, coeff, diff_stride, txfm_param);
+  else
     fwd_txfm2d_func(src_diff, coeff, diff_stride, txfm_param->tx_type,
                     txfm_param->bd);
-  else
-    av1_lowbd_fwd_txfm_c(src_diff, coeff, diff_stride, txfm_param);
 }
diff --git a/test/av1_fwd_txfm2d_test.cc b/test/av1_fwd_txfm2d_test.cc
index 2322616..7c8c0f6 100644
--- a/test/av1_fwd_txfm2d_test.cc
+++ b/test/av1_fwd_txfm2d_test.cc
@@ -211,7 +211,7 @@
 #if HAVE_SSE2 && defined(__SSE2__)
 #include "av1/common/x86/av1_txfm_sse2.h"
 FwdTxfm2dFunc fwd_func_sse2_list[TX_SIZES_ALL][2] = {
-  { NULL, NULL },                                               // TX_4X4
+  { av1_fwd_txfm2d_4x4_c, av1_lowbd_fwd_txfm2d_4x4_sse2 },      // TX_4X4
   { av1_fwd_txfm2d_8x8_c, av1_lowbd_fwd_txfm2d_8x8_sse2 },      // TX_8X8
   { av1_fwd_txfm2d_16x16_c, av1_lowbd_fwd_txfm2d_16x16_sse2 },  // TX_16X16
   { av1_fwd_txfm2d_32x32_c, av1_lowbd_fwd_txfm2d_32x32_sse2 },  // TX_32X32