Implement av1_fwd_txfm2d_8x8_sse2

Change-Id: I30beb0b0b0c2c96175566dcd5672ceb5000edada
diff --git a/av1/common/x86/av1_txfm_sse2.h b/av1/common/x86/av1_txfm_sse2.h
index 98db0ec..92956a0 100644
--- a/av1/common/x86/av1_txfm_sse2.h
+++ b/av1/common/x86/av1_txfm_sse2.h
@@ -14,7 +14,9 @@
 #include <emmintrin.h>  // SSE2
 
 #include "./aom_config.h"
+#include "./av1_rtcd.h"
 #include "aom/aom_integer.h"
+#include "aom_dsp/x86/transpose_sse2.h"
 #include "av1/common/av1_txfm.h"
 
 #ifdef __cplusplus
@@ -48,6 +50,79 @@
     out1 = _mm_packs_epi32(d0, d1);               \
   }
 
+static INLINE __m128i load_16bit_to_16bit(const int16_t *a) {
+  return _mm_load_si128((const __m128i *)a);
+}
+
+static INLINE __m128i load_32bit_to_16bit(const int32_t *a) {
+  const __m128i a_low = _mm_load_si128((const __m128i *)a);
+  return _mm_packs_epi32(a_low, *(const __m128i *)(a + 4));
+}
+
+// Store 8 16 bit values. If the destination is 32 bits then sign extend the
+// values by multiplying by 1.
+static INLINE void store_16bit_to_32bit(__m128i a, int32_t *b) {
+  const __m128i one = _mm_set1_epi16(1);
+  const __m128i a_hi = _mm_mulhi_epi16(a, one);
+  const __m128i a_lo = _mm_mullo_epi16(a, one);
+  const __m128i a_1 = _mm_unpacklo_epi16(a_lo, a_hi);
+  const __m128i a_2 = _mm_unpackhi_epi16(a_lo, a_hi);
+  _mm_store_si128((__m128i *)(b), a_1);
+  _mm_store_si128((__m128i *)(b + 4), a_2);
+}
+
+static INLINE void load_buffer_16bit_to_16bit(const int16_t *in, int stride,
+                                              __m128i *out, int out_size) {
+  for (int i = 0; i < out_size; ++i) {
+    out[i] = load_16bit_to_16bit(in + i * stride);
+  }
+}
+
+static INLINE void load_buffer_32bit_to_16bit(const int32_t *in, int stride,
+                                              __m128i *out, int out_size) {
+  for (int i = 0; i < out_size; ++i) {
+    out[i] = load_32bit_to_16bit(in + i * stride);
+  }
+}
+
+static INLINE void store_buffer_16bit_to_32bit_8x8(const __m128i *in,
+                                                   int32_t *out) {
+  for (int i = 0; i < 8; ++i) {
+    store_16bit_to_32bit(in[i], out + i * 8);
+  }
+}
+
+static INLINE void store_buffer_16bit_to_16bit_8x8(const __m128i *in,
+                                                   int16_t *out) {
+  for (int i = 0; i < 8; ++i) {
+    _mm_store_si128((__m128i *)(out + i * 8), in[i]);
+  }
+}
+
+static INLINE void round_shift_16bit(__m128i *in, int size, int bit) {
+  if (bit < 0) {
+    bit = -bit;
+    __m128i rounding = _mm_set1_epi16(1 << (bit - 1));
+    for (int i = 0; i < size; ++i) {
+      in[i] = _mm_adds_epi16(in[i], rounding);
+      in[i] = _mm_srai_epi16(in[i], bit);
+    }
+  } else if (bit > 0) {
+    for (int i = 0; i < size; ++i) {
+      in[i] = _mm_slli_epi16(in[i], bit);
+    }
+  }
+}
+
+void av1_fwd_txfm2d_8x8_sse2(const int16_t *input, int32_t *output, int stride,
+                             TX_TYPE tx_type, int bd);
+
+typedef void (*transform_1d_sse2)(const __m128i *input, __m128i *output,
+                                  int8_t cos_bit);
+
+typedef struct {
+  transform_1d_sse2 col, row;  // vertical and horizontal
+} transform_2d_sse2;
 #ifdef __cplusplus
 }
 #endif  // __cplusplus
diff --git a/av1/encoder/x86/av1_fwd_txfm_sse2.c b/av1/encoder/x86/av1_fwd_txfm_sse2.c
index d0c9a7f..f8eaf83 100644
--- a/av1/encoder/x86/av1_fwd_txfm_sse2.c
+++ b/av1/encoder/x86/av1_fwd_txfm_sse2.c
@@ -10,6 +10,7 @@
  */
 
 #include "av1/common/x86/av1_txfm_sse2.h"
+#include "av1/encoder/av1_fwd_txfm1d_cfg.h"
 
 void fdct4_new_sse2(const __m128i *input, __m128i *output, int8_t cos_bit) {
   const int32_t *cospi = cospi_arr(cos_bit);
@@ -1573,3 +1574,50 @@
   output[14] = x8[15];
   output[15] = x8[0];
 }
+
+void av1_fwd_txfm2d_8x8_sse2(const int16_t *input, int32_t *output, int stride,
+                             TX_TYPE tx_type, int bd) {
+  (void)stride;
+  (void)bd;
+  __m128i buf[8];
+  const int8_t *shift = fwd_txfm_shift_ls[TX_8X8];
+  const int txw_idx = get_txw_idx(TX_8X8);
+  const int txh_idx = get_txh_idx(TX_8X8);
+  const int cos_bit_col = fwd_cos_bit_col[txw_idx][txh_idx];
+  const int cos_bit_row = fwd_cos_bit_row[txw_idx][txh_idx];
+  const int buf_size = 8;
+  static const transform_2d_sse2 txfm_arr[] = {
+    { fdct8_new_sse2, fdct8_new_sse2 },    // DCT_DCT
+    { fadst8_new_sse2, fdct8_new_sse2 },   // ADST_DCT
+    { fdct8_new_sse2, fadst8_new_sse2 },   // DCT_ADST
+    { fadst8_new_sse2, fadst8_new_sse2 },  // ADST_ADST
+    { NULL, NULL },                        // FLIPADST_DCT
+    { NULL, NULL },                        // DCT_FLIPADST
+    { NULL, NULL },                        // FLIPADST_FLIPADST
+    { NULL, NULL },                        // ADST_FLIPADST
+    { NULL, NULL },                        // FLIPADST_ADST
+    { NULL, NULL },                        // IDTX
+    { NULL, NULL },                        // V_DCT
+    { NULL, NULL },                        // H_DCT
+    { NULL, NULL },                        // V_ADST
+    { NULL, NULL },                        // H_ADST
+    { NULL, NULL },                        // V_FLIPADST
+    { NULL, NULL },                        // H_FLIPADST
+  };
+
+  const transform_1d_sse2 col_txfm = txfm_arr[tx_type].col;
+  const transform_1d_sse2 row_txfm = txfm_arr[tx_type].row;
+  if (col_txfm != NULL && row_txfm != NULL) {
+    load_buffer_16bit_to_16bit(input, stride, buf, buf_size);
+    round_shift_16bit(buf, 8, shift[0]);
+    col_txfm(buf, buf, cos_bit_col);
+    round_shift_16bit(buf, 8, shift[1]);
+    transpose_16bit_8x8(buf, buf);
+    row_txfm(buf, buf, cos_bit_row);
+    round_shift_16bit(buf, 8, shift[2]);
+    transpose_16bit_8x8(buf, buf);
+    store_buffer_16bit_to_32bit_8x8(buf, output);
+  } else {
+    av1_fwd_txfm2d_8x8_c(input, output, stride, tx_type, bd);
+  }
+}
diff --git a/test/av1_fwd_txfm2d_test.cc b/test/av1_fwd_txfm2d_test.cc
index 6b019bc..b39600e 100644
--- a/test/av1_fwd_txfm2d_test.cc
+++ b/test/av1_fwd_txfm2d_test.cc
@@ -207,4 +207,77 @@
   }
 }
 
+#if HAVE_SSE2 && defined(__SSE2__)
+#include "av1/common/x86/av1_txfm_sse2.h"
+Fwd_Txfm2d_Func fwd_func_sse2_list[TX_SIZES_ALL][2] = {
+  { NULL, NULL },  // TX_4X4
+  { av1_fwd_txfm2d_8x8_c,
+    av1_fwd_txfm2d_8x8_sse2 },  // TX_8X8    // 8x8 transform
+  { NULL, NULL },               // TX_16X16
+  { NULL, NULL },               // TX_32X32
+#if CONFIG_TX64X64
+  { NULL, NULL },  // TX_64X64
+#endif             // CONFIG_TX64X64
+  { NULL, NULL },  // TX_4X8
+  { NULL, NULL },  // TX_8X4
+  { NULL, NULL },  // TX_8X16
+  { NULL, NULL },  // TX_16X8
+  { NULL, NULL },  // TX_16X32
+  { NULL, NULL },  // TX_32X16
+#if CONFIG_TX64X64
+  { NULL, NULL },  // TX_32X64
+  { NULL, NULL },  // TX_64X32
+#endif             // CONFIG_TX64X64
+  { NULL, NULL },  // TX_4X16
+  { NULL, NULL },  // TX_16X4
+  { NULL, NULL },  // TX_8X32
+  { NULL, NULL },  // TX_32X8
+#if CONFIG_TX64X64
+  { NULL, NULL },  // TX_16X64
+  { NULL, NULL },  // TX_64X16
+#endif             // CONFIG_TX64X64
+};
+
+TEST(av1_fwd_txfm2d_sse2, match) {
+  const int bd = 8;
+  for (int tx_size = TX_4X4; tx_size < TX_SIZES_ALL; ++tx_size) {
+    for (int tx_type = 0; tx_type < TX_TYPES; ++tx_type) {
+      Fwd_Txfm2d_Func ref_func = fwd_func_sse2_list[tx_size][0];
+      Fwd_Txfm2d_Func target_func = fwd_func_sse2_list[tx_size][1];
+      if (ref_func != NULL && target_func != NULL) {
+        int16_t input[64 * 64] = { 0 };
+        int32_t output[64 * 64] = { 0 };
+        int32_t ref_output[64 * 64] = { 0 };
+        int input_stride = 64;
+        ACMRandom rnd(ACMRandom::DeterministicSeed());
+        int rows = tx_size_high[tx_size];
+        int cols = tx_size_wide[tx_size];
+        for (int cnt = 0; cnt < 500; ++cnt) {
+          if (cnt == 0) {
+            for (int r = 0; r < rows; ++r) {
+              for (int c = 0; c < cols; ++c) {
+                input[r * input_stride + c] = (1 << bd) - 1;
+              }
+            }
+          } else {
+            for (int r = 0; r < rows; ++r) {
+              for (int c = 0; c < cols; ++c) {
+                input[r * input_stride + c] = rnd.Rand16() % (1 << bd);
+              }
+            }
+          }
+          ref_func(input, ref_output, input_stride, (TX_TYPE)tx_type, bd);
+          target_func(input, output, input_stride, (TX_TYPE)tx_type, bd);
+          for (int r = 0; r < rows; ++r) {
+            for (int c = 0; c < cols; ++c) {
+              ASSERT_EQ(ref_output[r * cols + c], output[r * cols + c]);
+            }
+          }
+        }
+      }
+    }
+  }
+}
+#endif  // HAVE_SSE2
+
 }  // namespace