Merge "Improvement on hybrid transform 4x4 DCT_DCT SSE4.1 optimization" into nextgenv2
diff --git a/vp10/encoder/x86/highbd_fwd_txfm_sse4.c b/vp10/encoder/x86/highbd_fwd_txfm_sse4.c
index 5fa4fc8..16323b3 100644
--- a/vp10/encoder/x86/highbd_fwd_txfm_sse4.c
+++ b/vp10/encoder/x86/highbd_fwd_txfm_sse4.c
@@ -13,15 +13,14 @@
 
 #include "./vp10_rtcd.h"
 #include "./vpx_config.h"
+#include "vp10/common/vp10_fwd_txfm2d_cfg.h"
+#include "vp10/common/vp10_txfm.h"
 #include "vpx_dsp/txfm_common.h"
 #include "vpx_ports/mem.h"
 
 static INLINE void load_buffer_4x4(const int16_t *input, __m128i *in,
-                                   int stride, int flipud, int fliplr) {
-  const __m128i k__nonzero_bias_a = _mm_setr_epi32(0, 1, 1, 1);
-  const __m128i k__nonzero_bias_b = _mm_setr_epi32(1, 0, 0, 0);
-  __m128i mask;
-
+                                   int stride, int flipud, int fliplr,
+                                   int shift) {
   if (!flipud) {
     in[0] = _mm_loadl_epi64((const __m128i *)(input + 0 * stride));
     in[1] = _mm_loadl_epi64((const __m128i *)(input + 1 * stride));
@@ -46,120 +45,73 @@
   in[2] = _mm_cvtepi16_epi32(in[2]);
   in[3] = _mm_cvtepi16_epi32(in[3]);
 
-  in[0] = _mm_slli_epi32(in[0], 4);
-  in[1] = _mm_slli_epi32(in[1], 4);
-  in[2] = _mm_slli_epi32(in[2], 4);
-  in[3] = _mm_slli_epi32(in[3], 4);
-
-  mask = _mm_cmpeq_epi32(in[0], k__nonzero_bias_a);
-  in[0] = _mm_add_epi32(in[0], mask);
-  in[0] = _mm_add_epi32(in[0], k__nonzero_bias_b);
+  in[0] = _mm_slli_epi32(in[0], shift);
+  in[1] = _mm_slli_epi32(in[1], shift);
+  in[2] = _mm_slli_epi32(in[2], shift);
+  in[3] = _mm_slli_epi32(in[3], shift);
 }
 
-static void fdct4x4_sse4_1(__m128i *in) {
-  const __m128i k__cospi_p16_p16 = _mm_set1_epi64x(cospi_16_64);
-  const __m128i k__cospi_m16_m16 = _mm_set1_epi64x(-cospi_16_64);
-  const __m128i k__cospi_p08_p08 = _mm_set1_epi64x(cospi_8_64);
-  const __m128i k__cospi_m08_m08 = _mm_set1_epi64x(-cospi_8_64);
-  const __m128i k__cospi_p24_p24 = _mm_set1_epi64x(cospi_24_64);
-  const __m128i k__DCT_CONST_ROUNDING = _mm_set1_epi64x(DCT_CONST_ROUNDING);
-
-  __m128i s[8];
+// Note: we only use the stage-2 cos bit here;
+// shift[0] is applied in load_buffer_4x4(),
+// shift[1] in txfm_func_col(), and
+// shift[2] in txfm_func_row().
+static void fdct4x4_sse4_1(__m128i *in, int bit) {
+  const int32_t *cospi = cospi_arr[bit - cos_bit_min];
+  const __m128i cospi32 = _mm_set1_epi32(cospi[32]);
+  const __m128i cospi48 = _mm_set1_epi32(cospi[48]);
+  const __m128i cospi16 = _mm_set1_epi32(cospi[16]);
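+  // rounding offset for the (x + (1 << (bit - 1))) >> bit shifts below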
+  const __m128i rnding = _mm_set1_epi32(1 << (bit - 1));
+  __m128i s0, s1, s2, s3;
   __m128i u0, u1, u2, u3;
-  __m128i v0, v1, v2, v3, v4, v5, v6, v7;
+  __m128i v0, v1, v2, v3;
 
-  s[0] = _mm_add_epi32(in[0], in[3]);
-  s[1] = _mm_add_epi32(in[1], in[2]);
-  s[2] = _mm_sub_epi32(in[1], in[2]);
-  s[3] = _mm_sub_epi32(in[0], in[3]);
+  s0 = _mm_add_epi32(in[0], in[3]);
+  s1 = _mm_add_epi32(in[1], in[2]);
+  s2 = _mm_sub_epi32(in[1], in[2]);
+  s3 = _mm_sub_epi32(in[0], in[3]);
 
-  v0 = _mm_cvtepi32_epi64(s[0]);  // s01 s00
-  v1 = _mm_cvtepi32_epi64(_mm_unpackhi_epi64(s[0], s[0]));  // s03 s02
-  v2 = _mm_cvtepi32_epi64(s[1]);  // s11 s10
-  v3 = _mm_cvtepi32_epi64(_mm_unpackhi_epi64(s[1], s[1]));  // s13 s12
-  v4 = _mm_cvtepi32_epi64(s[2]);  // s21 s20
-  v5 = _mm_cvtepi32_epi64(_mm_unpackhi_epi64(s[2], s[2]));  // s23 s22
-  v6 = _mm_cvtepi32_epi64(s[3]);  // s31 s30
-  v7 = _mm_cvtepi32_epi64(_mm_unpackhi_epi64(s[3], s[3]));  // s33 s32
+  // btf_32_sse4_1_type0(cospi32, cospi32, s[01], u[02], bit);
+  u0 = _mm_mullo_epi32(s0, cospi32);
+  u1 = _mm_mullo_epi32(s1, cospi32);
+  u2 = _mm_add_epi32(u0, u1);
+  v0 = _mm_sub_epi32(u0, u1);
 
-  u0 = _mm_mul_epi32(v0, k__cospi_p16_p16);
-  u1 = _mm_mul_epi32(v1, k__cospi_p16_p16);
-  u2 = _mm_mul_epi32(v2, k__cospi_p16_p16);
-  u3 = _mm_mul_epi32(v3, k__cospi_p16_p16);
+  u3 = _mm_add_epi32(u2, rnding);
+  v1 = _mm_add_epi32(v0, rnding);
 
-  s[0] = _mm_add_epi64(u0, u2);  // y10 y00
-  s[1] = _mm_add_epi64(u1, u3);  // y30 y20
+  u0 = _mm_srai_epi32(u3, bit);
+  u2 = _mm_srai_epi32(v1, bit);
 
-  u2 = _mm_mul_epi32(v2, k__cospi_m16_m16);
-  u3 = _mm_mul_epi32(v3, k__cospi_m16_m16);
+  // btf_32_sse4_1_type1(cospi48, cospi16, s[23], u[13], bit);
+  v0 = _mm_mullo_epi32(s2, cospi48);
+  v1 = _mm_mullo_epi32(s3, cospi16);
+  v2 = _mm_add_epi32(v0, v1);
 
-  s[2] = _mm_add_epi64(u0, u2);  // y12 y02
-  s[3] = _mm_add_epi64(u1, u3);  // y32 y22
+  v3 = _mm_add_epi32(v2, rnding);
+  u1 = _mm_srai_epi32(v3, bit);
 
-  u0 = _mm_mul_epi32(v6, k__cospi_p08_p08);
-  u1 = _mm_mul_epi32(v5, k__cospi_p24_p24);
-  u2 = _mm_mul_epi32(v4, k__cospi_p24_p24);
-  u3 = _mm_mul_epi32(v7, k__cospi_p08_p08);
+  v0 = _mm_mullo_epi32(s2, cospi16);
+  v1 = _mm_mullo_epi32(s3, cospi48);
+  v2 = _mm_sub_epi32(v1, v0);
 
-  s[4] = _mm_add_epi64(u0, u2);  // y11 y01
-  s[5] = _mm_add_epi64(u1, u3);  // y31 y21
+  v3 = _mm_add_epi32(v2, rnding);
+  u3 = _mm_srai_epi32(v3, bit);
 
-  u0 = _mm_mul_epi32(v4, k__cospi_m08_m08);
-  u1 = _mm_mul_epi32(v5, k__cospi_m08_m08);
-  u2 = _mm_mul_epi32(v6, k__cospi_p24_p24);
-  u3 = _mm_mul_epi32(v7, k__cospi_p24_p24);
+  // Note: shift[1] and shift[2] are both zero for this config,
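+  // so no extra per-pass rounding/shift is needed here or in
+  // write_buffer_4x4().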
 
-  s[6] = _mm_add_epi64(u0, u2);  // y13 y03
-  s[7] = _mm_add_epi64(u1, u3);  // y33 y23
+  // Transpose 4x4 32-bit
+  v0 = _mm_unpacklo_epi32(u0, u1);
+  v1 = _mm_unpackhi_epi32(u0, u1);
+  v2 = _mm_unpacklo_epi32(u2, u3);
+  v3 = _mm_unpackhi_epi32(u2, u3);
 
-  s[0] = _mm_add_epi64(s[0], k__DCT_CONST_ROUNDING);
-  s[1] = _mm_add_epi64(s[1], k__DCT_CONST_ROUNDING);
-  s[2] = _mm_add_epi64(s[2], k__DCT_CONST_ROUNDING);
-  s[3] = _mm_add_epi64(s[3], k__DCT_CONST_ROUNDING);
-  s[4] = _mm_add_epi64(s[4], k__DCT_CONST_ROUNDING);
-  s[5] = _mm_add_epi64(s[5], k__DCT_CONST_ROUNDING);
-  s[6] = _mm_add_epi64(s[6], k__DCT_CONST_ROUNDING);
-  s[7] = _mm_add_epi64(s[7], k__DCT_CONST_ROUNDING);
-
-  s[0] = _mm_srli_epi64(s[0], DCT_CONST_BITS);
-  s[1] = _mm_srli_epi64(s[1], DCT_CONST_BITS);
-  s[2] = _mm_srli_epi64(s[2], DCT_CONST_BITS);
-  s[3] = _mm_srli_epi64(s[3], DCT_CONST_BITS);
-  s[4] = _mm_srli_epi64(s[4], DCT_CONST_BITS);
-  s[5] = _mm_srli_epi64(s[5], DCT_CONST_BITS);
-  s[6] = _mm_srli_epi64(s[6], DCT_CONST_BITS);
-  s[7] = _mm_srli_epi64(s[7], DCT_CONST_BITS);
-
-  s[0] = _mm_shuffle_epi32(s[0], 0x88);
-  s[1] = _mm_shuffle_epi32(s[1], 0x88);
-  s[2] = _mm_shuffle_epi32(s[2], 0x88);
-  s[3] = _mm_shuffle_epi32(s[3], 0x88);
-  s[4] = _mm_shuffle_epi32(s[4], 0x88);
-  s[5] = _mm_shuffle_epi32(s[5], 0x88);
-  s[6] = _mm_shuffle_epi32(s[6], 0x88);
-  s[7] = _mm_shuffle_epi32(s[7], 0x88);
-
-  v0 = _mm_unpacklo_epi32(s[0], s[4]);
-  v1 = _mm_unpacklo_epi32(s[2], s[6]);
-  v2 = _mm_unpacklo_epi32(s[1], s[5]);
-  v3 = _mm_unpacklo_epi32(s[3], s[7]);
-
-  in[0] = _mm_unpacklo_epi64(v0, v1);
-  in[1] = _mm_unpackhi_epi64(v0, v1);
-  in[2] = _mm_unpacklo_epi64(v2, v3);
-  in[3] = _mm_unpackhi_epi64(v2, v3);
+  in[0] = _mm_unpacklo_epi64(v0, v2);
+  in[1] = _mm_unpackhi_epi64(v0, v2);
+  in[2] = _mm_unpacklo_epi64(v1, v3);
+  in[3] = _mm_unpackhi_epi64(v1, v3);
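+  // in[0..3] now hold the outputs in transposed order, ready for the
+  // second (row) pass.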
 }
 
 static INLINE void write_buffer_4x4(tran_low_t *output, __m128i *res) {
-  const __m128i kOne = _mm_set1_epi32(1);
-  res[0] = _mm_add_epi32(res[0], kOne);
-  res[1] = _mm_add_epi32(res[1], kOne);
-  res[2] = _mm_add_epi32(res[2], kOne);
-  res[3] = _mm_add_epi32(res[3], kOne);
-  res[0] = _mm_srai_epi32(res[0], 2);
-  res[1] = _mm_srai_epi32(res[1], 2);
-  res[2] = _mm_srai_epi32(res[2], 2);
-  res[3] = _mm_srai_epi32(res[3], 2);
   _mm_store_si128((__m128i *)(output + 0 * 4), res[0]);
   _mm_store_si128((__m128i *)(output + 1 * 4), res[1]);
   _mm_store_si128((__m128i *)(output + 2 * 4), res[2]);
@@ -169,12 +121,17 @@
 void vp10_highbd_fht4x4_sse4_1(const int16_t *input, tran_low_t *output,
                                int stride, int tx_type) {
   __m128i in[4];
+  const TXFM_2D_CFG *cfg;
+  int bit;
 
   switch (tx_type) {
     case DCT_DCT:
-      load_buffer_4x4(input, in, stride, 0, 0);
-      fdct4x4_sse4_1(in);
-      fdct4x4_sse4_1(in);
+      cfg = &fwd_txfm_2d_cfg_dct_dct_4;
+      load_buffer_4x4(input, in, stride, 0, 0, cfg->shift[0]);
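+      // both passes use the stage-2 cos bit; the in-kernel transpose
+      // swaps axes between them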
+      bit = cfg->cos_bit_col[2];
+      fdct4x4_sse4_1(in, bit);
+      bit = cfg->cos_bit_row[2];
+      fdct4x4_sse4_1(in, bit);
       write_buffer_4x4(output, in);
       break;
     case ADST_DCT: