Implement av1_fwd_txfm2d_8x8_sse2

Add an SSE2 implementation of the 8x8 forward 2D transform for the
DCT_DCT, ADST_DCT, DCT_ADST and ADST_ADST types, with a fallback to the
C implementation for all other transform types, plus a match test
against the C reference.

Change-Id: I30beb0b0b0c2c96175566dcd5672ceb5000edada
diff --git a/av1/common/x86/av1_txfm_sse2.h b/av1/common/x86/av1_txfm_sse2.h index 98db0ec..92956a0 100644 --- a/av1/common/x86/av1_txfm_sse2.h +++ b/av1/common/x86/av1_txfm_sse2.h
@@ -14,7 +14,9 @@ #include <emmintrin.h> // SSE2 #include "./aom_config.h" +#include "./av1_rtcd.h" #include "aom/aom_integer.h" +#include "aom_dsp/x86/transpose_sse2.h" #include "av1/common/av1_txfm.h" #ifdef __cplusplus @@ -48,6 +50,79 @@ out1 = _mm_packs_epi32(d0, d1); \ } +static INLINE __m128i load_16bit_to_16bit(const int16_t *a) { + return _mm_load_si128((const __m128i *)a); +} + +static INLINE __m128i load_32bit_to_16bit(const int32_t *a) { + const __m128i a_low = _mm_load_si128((const __m128i *)a); + return _mm_packs_epi32(a_low, *(const __m128i *)(a + 4)); +} + +// Store 8 16 bit values. If the destination is 32 bits then sign extend the +// values by multiplying by 1. +static INLINE void store_16bit_to_32bit(__m128i a, int32_t *b) { + const __m128i one = _mm_set1_epi16(1); + const __m128i a_hi = _mm_mulhi_epi16(a, one); + const __m128i a_lo = _mm_mullo_epi16(a, one); + const __m128i a_1 = _mm_unpacklo_epi16(a_lo, a_hi); + const __m128i a_2 = _mm_unpackhi_epi16(a_lo, a_hi); + _mm_store_si128((__m128i *)(b), a_1); + _mm_store_si128((__m128i *)(b + 4), a_2); +} + +static INLINE void load_buffer_16bit_to_16bit(const int16_t *in, int stride, + __m128i *out, int out_size) { + for (int i = 0; i < out_size; ++i) { + out[i] = load_16bit_to_16bit(in + i * stride); + } +} + +static INLINE void load_buffer_32bit_to_16bit(const int32_t *in, int stride, + __m128i *out, int out_size) { + for (int i = 0; i < out_size; ++i) { + out[i] = load_32bit_to_16bit(in + i * stride); + } +} + +static INLINE void store_buffer_16bit_to_32bit_8x8(const __m128i *in, + int32_t *out) { + for (int i = 0; i < 8; ++i) { + store_16bit_to_32bit(in[i], out + i * 8); + } +} + +static INLINE void store_buffer_16bit_to_16bit_8x8(const __m128i *in, + int16_t *out) { + for (int i = 0; i < 8; ++i) { + _mm_store_si128((__m128i *)(out + i * 8), in[i]); + } +} + +static INLINE void round_shift_16bit(__m128i *in, int size, int bit) { + if (bit < 0) { + bit = -bit; + __m128i rounding = _mm_set1_epi16(1 
<< (bit - 1)); + for (int i = 0; i < size; ++i) { + in[i] = _mm_adds_epi16(in[i], rounding); + in[i] = _mm_srai_epi16(in[i], bit); + } + } else if (bit > 0) { + for (int i = 0; i < size; ++i) { + in[i] = _mm_slli_epi16(in[i], bit); + } + } +} + +void av1_fwd_txfm2d_8x8_sse2(const int16_t *input, int32_t *output, int stride, + TX_TYPE tx_type, int bd); + +typedef void (*transform_1d_sse2)(const __m128i *input, __m128i *output, + int8_t cos_bit); + +typedef struct { + transform_1d_sse2 col, row; // vertical and horizontal +} transform_2d_sse2; #ifdef __cplusplus } #endif // __cplusplus
diff --git a/av1/encoder/x86/av1_fwd_txfm_sse2.c b/av1/encoder/x86/av1_fwd_txfm_sse2.c index d0c9a7f..f8eaf83 100644 --- a/av1/encoder/x86/av1_fwd_txfm_sse2.c +++ b/av1/encoder/x86/av1_fwd_txfm_sse2.c
@@ -10,6 +10,7 @@ */ #include "av1/common/x86/av1_txfm_sse2.h" +#include "av1/encoder/av1_fwd_txfm1d_cfg.h" void fdct4_new_sse2(const __m128i *input, __m128i *output, int8_t cos_bit) { const int32_t *cospi = cospi_arr(cos_bit); @@ -1573,3 +1574,50 @@ output[14] = x8[15]; output[15] = x8[0]; } + +void av1_fwd_txfm2d_8x8_sse2(const int16_t *input, int32_t *output, int stride, + TX_TYPE tx_type, int bd) { + (void)stride; + (void)bd; + __m128i buf[8]; + const int8_t *shift = fwd_txfm_shift_ls[TX_8X8]; + const int txw_idx = get_txw_idx(TX_8X8); + const int txh_idx = get_txh_idx(TX_8X8); + const int cos_bit_col = fwd_cos_bit_col[txw_idx][txh_idx]; + const int cos_bit_row = fwd_cos_bit_row[txw_idx][txh_idx]; + const int buf_size = 8; + static const transform_2d_sse2 txfm_arr[] = { + { fdct8_new_sse2, fdct8_new_sse2 }, // DCT_DCT + { fadst8_new_sse2, fdct8_new_sse2 }, // ADST_DCT + { fdct8_new_sse2, fadst8_new_sse2 }, // DCT_ADST + { fadst8_new_sse2, fadst8_new_sse2 }, // ADST_ADST + { NULL, NULL }, // FLIPADST_DCT + { NULL, NULL }, // DCT_FLIPADST + { NULL, NULL }, // FLIPADST_FLIPADST + { NULL, NULL }, // ADST_FLIPADST + { NULL, NULL }, // FLIPADST_ADST + { NULL, NULL }, // IDTX + { NULL, NULL }, // V_DCT + { NULL, NULL }, // H_DCT + { NULL, NULL }, // V_ADST + { NULL, NULL }, // H_ADST + { NULL, NULL }, // V_FLIPADST + { NULL, NULL }, // H_FLIPADST + }; + + const transform_1d_sse2 col_txfm = txfm_arr[tx_type].col; + const transform_1d_sse2 row_txfm = txfm_arr[tx_type].row; + if (col_txfm != NULL && row_txfm != NULL) { + load_buffer_16bit_to_16bit(input, stride, buf, buf_size); + round_shift_16bit(buf, 8, shift[0]); + col_txfm(buf, buf, cos_bit_col); + round_shift_16bit(buf, 8, shift[1]); + transpose_16bit_8x8(buf, buf); + row_txfm(buf, buf, cos_bit_row); + round_shift_16bit(buf, 8, shift[2]); + transpose_16bit_8x8(buf, buf); + store_buffer_16bit_to_32bit_8x8(buf, output); + } else { + av1_fwd_txfm2d_8x8_c(input, output, stride, tx_type, bd); + } +}
diff --git a/test/av1_fwd_txfm2d_test.cc b/test/av1_fwd_txfm2d_test.cc index 6b019bc..b39600e 100644 --- a/test/av1_fwd_txfm2d_test.cc +++ b/test/av1_fwd_txfm2d_test.cc
@@ -207,4 +207,77 @@ } } +#if HAVE_SSE2 && defined(__SSE2__) +#include "av1/common/x86/av1_txfm_sse2.h" +Fwd_Txfm2d_Func fwd_func_sse2_list[TX_SIZES_ALL][2] = { + { NULL, NULL }, // TX_4X4 + { av1_fwd_txfm2d_8x8_c, + av1_fwd_txfm2d_8x8_sse2 }, // TX_8X8 // 8x8 transform + { NULL, NULL }, // TX_16X16 + { NULL, NULL }, // TX_32X32 +#if CONFIG_TX64X64 + { NULL, NULL }, // TX_64X64 +#endif // CONFIG_TX64X64 + { NULL, NULL }, // TX_4X8 + { NULL, NULL }, // TX_8X4 + { NULL, NULL }, // TX_8X16 + { NULL, NULL }, // TX_16X8 + { NULL, NULL }, // TX_16X32 + { NULL, NULL }, // TX_32X16 +#if CONFIG_TX64X64 + { NULL, NULL }, // TX_32X64 + { NULL, NULL }, // TX_64X32 +#endif // CONFIG_TX64X64 + { NULL, NULL }, // TX_4X16 + { NULL, NULL }, // TX_16X4 + { NULL, NULL }, // TX_8X32 + { NULL, NULL }, // TX_32X8 +#if CONFIG_TX64X64 + { NULL, NULL }, // TX_16X64 + { NULL, NULL }, // TX_64X16 +#endif // CONFIG_TX64X64 +}; + +TEST(av1_fwd_txfm2d_sse2, match) { + const int bd = 8; + for (int tx_size = TX_4X4; tx_size < TX_SIZES_ALL; ++tx_size) { + for (int tx_type = 0; tx_type < TX_TYPES; ++tx_type) { + Fwd_Txfm2d_Func ref_func = fwd_func_sse2_list[tx_size][0]; + Fwd_Txfm2d_Func target_func = fwd_func_sse2_list[tx_size][1]; + if (ref_func != NULL && target_func != NULL) { + int16_t input[64 * 64] = { 0 }; + int32_t output[64 * 64] = { 0 }; + int32_t ref_output[64 * 64] = { 0 }; + int input_stride = 64; + ACMRandom rnd(ACMRandom::DeterministicSeed()); + int rows = tx_size_high[tx_size]; + int cols = tx_size_wide[tx_size]; + for (int cnt = 0; cnt < 500; ++cnt) { + if (cnt == 0) { + for (int r = 0; r < rows; ++r) { + for (int c = 0; c < cols; ++c) { + input[r * input_stride + c] = (1 << bd) - 1; + } + } + } else { + for (int r = 0; r < rows; ++r) { + for (int c = 0; c < cols; ++c) { + input[r * input_stride + c] = rnd.Rand16() % (1 << bd); + } + } + } + ref_func(input, ref_output, input_stride, (TX_TYPE)tx_type, bd); + target_func(input, output, input_stride, (TX_TYPE)tx_type, bd); + for (int 
r = 0; r < rows; ++r) { + for (int c = 0; c < cols; ++c) { + ASSERT_EQ(ref_output[r * cols + c], output[r * cols + c]); + } + } + } + } + } + } +} +#endif // HAVE_SSE2 + } // namespace