daala_tx: Add SIMD version of the 4-point DCT

Currently this only requires SSSE3 operations, but we build it as AVX2
to get support for 3-operand instructions. Separate versions for
different instruction sets will be added later.

Change-Id: Ib02c1496832923ecf6dccc1a208dc5ac5559dad2
diff --git a/av1/common/x86/daala_inv_txfm_avx2.c b/av1/common/x86/daala_inv_txfm_avx2.c
index d47d620..1e5f7cf 100644
--- a/av1/common/x86/daala_inv_txfm_avx2.c
+++ b/av1/common/x86/daala_inv_txfm_avx2.c
@@ -9,6 +9,7 @@
  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
  */
 
+#include <tmmintrin.h>
 #include "./av1_rtcd.h"
 #include "./aom_config.h"
 #include "./aom_dsp_rtcd.h"
@@ -18,6 +19,365 @@
 #if CONFIG_DAALA_TX
+
+static INLINE __m128i od_unbiased_rshift1_epi16(__m128i a) {
+  return _mm_srai_epi16(_mm_add_epi16(_mm_srli_epi16(a, 15), a), 1);
+}
+
+static INLINE __m128i od_avg_epi16(__m128i a, __m128i b) {
+  __m128i sign_bit;
+  /*x86 only provides an unsigned PAVGW with a bias (ARM is better here).
+    We emulate a signed one by adding an offset to convert to unsigned and
+    back. We use XOR instead of addition/subtraction because it dispatches
+    better on older processors.*/
+  sign_bit = _mm_set1_epi16(0x8000);
+  return _mm_xor_si128(
+      _mm_avg_epu16(_mm_xor_si128(a, sign_bit), _mm_xor_si128(b, sign_bit)),
+      sign_bit);
+}
+
+static INLINE __m128i od_mulhrs_epi16(__m128i a, int16_t b) {
+  return _mm_mulhrs_epi16(a, _mm_set1_epi16(b));
+}
+
+static INLINE __m128i od_hbd_max_epi16(int bd) {
+  return _mm_set1_epi16((1 << bd) - 1);
+}
+
+static INLINE __m128i od_hbd_clamp_epi16(__m128i a, __m128i max) {
+  return _mm_max_epi16(_mm_setzero_si128(), _mm_min_epi16(a, max));
+}
+
+/* Loads a 4x4 buffer of 32-bit values into four SSE registers. */
+static INLINE void od_load_buffer_4x4_epi32(__m128i *q0, __m128i *q1,
+                                            __m128i *q2, __m128i *q3,
+                                            const tran_low_t *in) {
+  *q0 = _mm_loadu_si128((const __m128i *)in + 0);
+  *q1 = _mm_loadu_si128((const __m128i *)in + 1);
+  *q2 = _mm_loadu_si128((const __m128i *)in + 2);
+  *q3 = _mm_loadu_si128((const __m128i *)in + 3);
+}
+
+/* Loads a 4x4 buffer of 16-bit values into four SSE registers. */
+static INLINE void od_load_buffer_4x4_epi16(__m128i *q0, __m128i *q1,
+                                            __m128i *q2, __m128i *q3,
+                                            const int16_t *in) {
+  *q0 = _mm_loadu_si128((const __m128i *)in + 0);
+  *q1 = _mm_unpackhi_epi64(*q0, *q0);
+  *q2 = _mm_loadu_si128((const __m128i *)in + 1);
+  *q3 = _mm_unpackhi_epi64(*q2, *q2);
+}
+
+/* Loads an 8x4 buffer of 16-bit values into four SSE registers. */
+static INLINE void od_load_buffer_8x4_epi16(__m128i *q0, __m128i *q1,
+                                            __m128i *q2, __m128i *q3,
+                                            const int16_t *in, int in_stride) {
+  *q0 = _mm_loadu_si128((const __m128i *)(in + 0 * in_stride));
+  *q1 = _mm_loadu_si128((const __m128i *)(in + 1 * in_stride));
+  *q2 = _mm_loadu_si128((const __m128i *)(in + 2 * in_stride));
+  *q3 = _mm_loadu_si128((const __m128i *)(in + 3 * in_stride));
+}
+
+/* Stores a 4x4 buffer of 16-bit values from two SSE registers.
+   Each register holds two rows of values. */
+static INLINE void od_store_buffer_4x4_epi16(int16_t *out, __m128i q0,
+                                             __m128i q1) {
+  _mm_storeu_si128((__m128i *)out + 0, q0);
+  _mm_storeu_si128((__m128i *)out + 1, q1);
+}
+
+/* Stores a 4x8 buffer of 16-bit values from four SSE registers.
+   Each register holds two rows of values. */
+static INLINE void od_store_buffer_4x8_epi16(int16_t *out, __m128i q0,
+                                             __m128i q1, __m128i q2,
+                                             __m128i q3) {
+  _mm_storeu_si128((__m128i *)out + 0, q0);
+  _mm_storeu_si128((__m128i *)out + 1, q1);
+  _mm_storeu_si128((__m128i *)out + 2, q2);
+  _mm_storeu_si128((__m128i *)out + 3, q3);
+}
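The two arithmetic helpers above rely on standard 16-bit SIMD tricks. As a scalar sketch (illustration only, not part of the patch; the names unbiased_rshift1 and signed_avg are hypothetical), they compute:

    #include <assert.h>
    #include <stdint.h>

    /* od_unbiased_rshift1_epi16: adding the sign bit before shifting turns
       the arithmetic shift's round-toward-minus-infinity into rounding
       toward zero, i.e. a/2 truncated. */
    static int16_t unbiased_rshift1(int16_t a) {
      return (int16_t)((a + ((uint16_t)a >> 15)) >> 1);
    }

    /* od_avg_epi16: PAVGW computes (x + y + 1) >> 1 on unsigned values;
       XOR with 0x8000 maps signed inputs into unsigned range and back. */
    static int16_t signed_avg(int16_t a, int16_t b) {
      uint16_t ua = (uint16_t)a ^ 0x8000;
      uint16_t ub = (uint16_t)b ^ 0x8000;
      return (int16_t)((uint16_t)((ua + ub + 1) >> 1) ^ 0x8000);
    }

    int main(void) {
      assert(unbiased_rshift1(-3) == -1); /* a plain >> 1 would give -2 */
      assert(signed_avg(-5, 2) == -1);    /* (-5 + 2 + 1) >> 1 */
      return 0;
    }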
+
+/* Loads a 4x4 buffer of 16-bit values, adds a 4x4 block of 16-bit values to
+   them, clamps to the high bit depth max, and stores the sum back. */
+static INLINE void od_add_store_buffer_hbd_4x4_epi16(void *output_pixels,
+                                                     int output_stride,
+                                                     __m128i q0, __m128i q1,
+                                                     __m128i q2, __m128i q3,
+                                                     int bd) {
+  uint16_t *output_pixels16;
+  __m128i p0;
+  __m128i p1;
+  __m128i p2;
+  __m128i p3;
+  __m128i max;
+  __m128i round;
+  int downshift;
+  output_pixels16 = CONVERT_TO_SHORTPTR(output_pixels);
+  max = od_hbd_max_epi16(bd);
+  downshift = TX_COEFF_DEPTH - bd;
+  round = _mm_set1_epi16((1 << downshift) >> 1);
+  p0 = _mm_loadl_epi64((const __m128i *)(output_pixels16 + 0 * output_stride));
+  p1 = _mm_loadl_epi64((const __m128i *)(output_pixels16 + 1 * output_stride));
+  p2 = _mm_loadl_epi64((const __m128i *)(output_pixels16 + 2 * output_stride));
+  p3 = _mm_loadl_epi64((const __m128i *)(output_pixels16 + 3 * output_stride));
+  q0 = _mm_srai_epi16(_mm_add_epi16(q0, round), downshift);
+  q1 = _mm_srai_epi16(_mm_add_epi16(q1, round), downshift);
+  q2 = _mm_srai_epi16(_mm_add_epi16(q2, round), downshift);
+  q3 = _mm_srai_epi16(_mm_add_epi16(q3, round), downshift);
+  p0 = od_hbd_clamp_epi16(_mm_add_epi16(p0, q0), max);
+  p1 = od_hbd_clamp_epi16(_mm_add_epi16(p1, q1), max);
+  p2 = od_hbd_clamp_epi16(_mm_add_epi16(p2, q2), max);
+  p3 = od_hbd_clamp_epi16(_mm_add_epi16(p3, q3), max);
+  _mm_storel_epi64((__m128i *)(output_pixels16 + 0 * output_stride), p0);
+  _mm_storel_epi64((__m128i *)(output_pixels16 + 1 * output_stride), p1);
+  _mm_storel_epi64((__m128i *)(output_pixels16 + 2 * output_stride), p2);
+  _mm_storel_epi64((__m128i *)(output_pixels16 + 3 * output_stride), p3);
+}
+
+/* Loads an 8x4 buffer of 16-bit values, adds an 8x4 block of 16-bit values to
+   them, clamps to the high bit depth max, and stores the sum back. */
+static INLINE void od_add_store_buffer_hbd_8x4_epi16(void *output_pixels,
+                                                     int output_stride,
+                                                     __m128i q0, __m128i q1,
+                                                     __m128i q2, __m128i q3,
+                                                     int bd) {
+  uint16_t *output_pixels16;
+  __m128i p0;
+  __m128i p1;
+  __m128i p2;
+  __m128i p3;
+  __m128i max;
+  __m128i round;
+  int downshift;
+  output_pixels16 = CONVERT_TO_SHORTPTR(output_pixels);
+  max = od_hbd_max_epi16(bd);
+  downshift = TX_COEFF_DEPTH - bd;
+  round = _mm_set1_epi16((1 << downshift) >> 1);
+  p0 = _mm_loadu_si128((const __m128i *)(output_pixels16 + 0 * output_stride));
+  p1 = _mm_loadu_si128((const __m128i *)(output_pixels16 + 1 * output_stride));
+  p2 = _mm_loadu_si128((const __m128i *)(output_pixels16 + 2 * output_stride));
+  p3 = _mm_loadu_si128((const __m128i *)(output_pixels16 + 3 * output_stride));
+  q0 = _mm_srai_epi16(_mm_add_epi16(q0, round), downshift);
+  q1 = _mm_srai_epi16(_mm_add_epi16(q1, round), downshift);
+  q2 = _mm_srai_epi16(_mm_add_epi16(q2, round), downshift);
+  q3 = _mm_srai_epi16(_mm_add_epi16(q3, round), downshift);
+  p0 = od_hbd_clamp_epi16(_mm_add_epi16(p0, q0), max);
+  p1 = od_hbd_clamp_epi16(_mm_add_epi16(p1, q1), max);
+  p2 = od_hbd_clamp_epi16(_mm_add_epi16(p2, q2), max);
+  p3 = od_hbd_clamp_epi16(_mm_add_epi16(p3, q3), max);
+  _mm_storeu_si128((__m128i *)(output_pixels16 + 0 * output_stride), p0);
+  _mm_storeu_si128((__m128i *)(output_pixels16 + 1 * output_stride), p1);
+  _mm_storeu_si128((__m128i *)(output_pixels16 + 2 * output_stride), p2);
+  _mm_storeu_si128((__m128i *)(output_pixels16 + 3 * output_stride), p3);
+}
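Per coefficient, the add-and-store helpers above do a rounded downshift from the transform's internal precision to the output bit depth, then a clamped add. A scalar model (illustration only; MODEL_TX_COEFF_DEPTH and add_residual_hbd are hypothetical names, and 16 stands in for the library's TX_COEFF_DEPTH, whose value is not shown in this patch):

    #include <assert.h>
    #include <stdint.h>

    #define MODEL_TX_COEFF_DEPTH 16 /* placeholder for TX_COEFF_DEPTH */

    static uint16_t add_residual_hbd(uint16_t pixel, int16_t coeff, int bd) {
      int downshift = MODEL_TX_COEFF_DEPTH - bd;
      /* Round to nearest while shifting down to bd precision. */
      int residual = (coeff + ((1 << downshift) >> 1)) >> downshift;
      int sum = pixel + residual;
      int max = (1 << bd) - 1;
      /* Clamp to [0, 2^bd - 1], as od_hbd_clamp_epi16 does. */
      return (uint16_t)(sum < 0 ? 0 : sum > max ? max : sum);
    }

    int main(void) {
      /* bd = 10: downshift = 6, rounding offset = 32. */
      assert(add_residual_hbd(1000, 95, 10) == 1001); /* 95 rounds to 1 */
      assert(add_residual_hbd(1000, 96, 10) == 1002); /* 96 rounds to 2 */
      return 0;
    }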
+
+static INLINE void od_transpose_pack4x4(__m128i *q0, __m128i *q1, __m128i *q2,
+                                        __m128i *q3) {
+  __m128i a;
+  __m128i b;
+  __m128i c;
+  __m128i d;
+  /* Input:
+     q0: q30 q20 q10 q00
+     q1: q31 q21 q11 q01
+     q2: q32 q22 q12 q02
+     q3: q33 q23 q13 q03
+  */
+  /* a: q32 q22 q12 q02 q30 q20 q10 q00 */
+  a = _mm_packs_epi32(*q0, *q2);
+  /* b: q33 q23 q13 q03 q31 q21 q11 q01 */
+  b = _mm_packs_epi32(*q1, *q3);
+  /* c: q31 q30 q21 q20 q11 q10 q01 q00 */
+  c = _mm_unpacklo_epi16(a, b);
+  /* d: q33 q32 q23 q22 q13 q12 q03 q02 */
+  d = _mm_unpackhi_epi16(a, b);
+  /* We don't care about the contents of the high half of each register. */
+  /* q0: q13 q12 q11 q10 [q03 q02 q01 q00] */
+  *q0 = _mm_unpacklo_epi32(c, d);
+  /* q1: q13 q12 q11 q10 [q13 q12 q11 q10] */
+  *q1 = _mm_unpackhi_epi64(*q0, *q0);
+  /* q2: q33 q32 q31 q30 [q23 q22 q21 q20] */
+  *q2 = _mm_unpackhi_epi32(c, d);
+  /* q3: q33 q32 q31 q30 [q33 q32 q31 q30] */
+  *q3 = _mm_unpackhi_epi64(*q2, *q2);
+}
+
+static INLINE void od_transpose4x4(__m128i *q0, __m128i q1, __m128i *q2,
+                                   __m128i q3) {
+  __m128i a;
+  __m128i b;
+  /* Input:
+     q0: ... ... ... ... q30 q20 q10 q00
+     q1: ... ... ... ... q31 q21 q11 q01
+     q2: ... ... ... ... q32 q22 q12 q02
+     q3: ... ... ... ... q33 q23 q13 q03
+  */
+  /* a: q31 q30 q21 q20 q11 q10 q01 q00 */
+  a = _mm_unpacklo_epi16(*q0, q1);
+  /* b: q33 q32 q23 q22 q13 q12 q03 q02 */
+  b = _mm_unpacklo_epi16(*q2, q3);
+  /* q0: q13 q12 q11 q10 | q03 q02 q01 q00 */
+  *q0 = _mm_unpacklo_epi32(a, b);
+  /* q2: q33 q32 q31 q30 | q23 q22 q21 q20 */
+  *q2 = _mm_unpackhi_epi32(a, b);
+}
+
+static INLINE void od_transpose8x4(__m128i *q0, __m128i *q1, __m128i *q2,
+                                   __m128i *q3) {
+  __m128i a;
+  __m128i b;
+  __m128i c;
+  __m128i d;
+  /* Input:
+     q0: q07 q06 q05 q04 q03 q02 q01 q00
+     q1: q17 q16 q15 q14 q13 q12 q11 q10
+     q2: q27 q26 q25 q24 q23 q22 q21 q20
+     q3: q37 q36 q35 q34 q33 q32 q31 q30
+  */
+  /* a: q13 q03 q12 q02 q11 q01 q10 q00 */
+  a = _mm_unpacklo_epi16(*q0, *q1);
+  /* b: q17 q07 q16 q06 q15 q05 q14 q04 */
+  b = _mm_unpackhi_epi16(*q0, *q1);
+  /* c: q33 q23 q32 q22 q31 q21 q30 q20 */
+  c = _mm_unpacklo_epi16(*q2, *q3);
+  /* d: q37 q27 q36 q26 q35 q25 q34 q24 */
+  d = _mm_unpackhi_epi16(*q2, *q3);
+  /* q0: q31 q21 q11 q01 | q30 q20 q10 q00 */
+  *q0 = _mm_unpacklo_epi32(a, c);
+  /* q1: q33 q23 q13 q03 | q32 q22 q12 q02 */
+  *q1 = _mm_unpackhi_epi32(a, c);
+  /* q2: q35 q25 q15 q05 | q34 q24 q14 q04 */
+  *q2 = _mm_unpacklo_epi32(b, d);
+  /* q3: q37 q27 q17 q07 | q36 q26 q16 q06 */
+  *q3 = _mm_unpackhi_epi32(b, d);
+}
+
+static INLINE void od_transpose_pack4x8(__m128i *q0, __m128i *q1, __m128i *q2,
+                                        __m128i *q3, __m128i q4, __m128i q5,
+                                        __m128i q6, __m128i q7) {
+  __m128i a;
+  __m128i b;
+  __m128i c;
+  __m128i d;
+  /* Input:
+     q0: q30 q20 q10 q00
+     q1: q31 q21 q11 q01
+     q2: q32 q22 q12 q02
+     q3: q33 q23 q13 q03
+     q4: q34 q24 q14 q04
+     q5: q35 q25 q15 q05
+     q6: q36 q26 q16 q06
+     q7: q37 q27 q17 q07
+  */
+  /* a: q34 q24 q14 q04 q30 q20 q10 q00 */
+  a = _mm_packs_epi32(*q0, q4);
+  /* b: q35 q25 q15 q05 q31 q21 q11 q01 */
+  b = _mm_packs_epi32(*q1, q5);
+  /* c: q36 q26 q16 q06 q32 q22 q12 q02 */
+  c = _mm_packs_epi32(*q2, q6);
+  /* d: q37 q27 q17 q07 q33 q23 q13 q03 */
+  d = _mm_packs_epi32(*q3, q7);
+  /* a: q13 q12 q11 q10 q03 q02 q01 q00
+     b: q33 q32 q31 q30 q23 q22 q21 q20
+     c: q53 q52 q51 q50 q43 q42 q41 q40
+     d: q73 q72 q71 q70 q63 q62 q61 q60 */
+  od_transpose8x4(&a, &b, &c, &d);
+  /* q0: q07 q06 q05 q04 q03 q02 q01 q00 */
+  *q0 = _mm_unpacklo_epi64(a, c);
+  /* q1: q17 q16 q15 q14 q13 q12 q11 q10 */
+  *q1 = _mm_unpackhi_epi64(a, c);
+  /* q2: q27 q26 q25 q24 q23 q22 q21 q20 */
+  *q2 = _mm_unpacklo_epi64(b, d);
+  /* q3: q37 q36 q35 q34 q33 q32 q31 q30 */
+  *q3 = _mm_unpackhi_epi64(b, d);
+}
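All four transpose variants above implement the same operation with pack/unpack shuffles: output row r holds input column r, and the pack steps additionally narrow 32-bit coefficients to 16 bits with signed saturation. A plain reference (illustration only; transpose4x4_ref is a hypothetical name):

    #include <assert.h>
    #include <stdint.h>

    /* Reference 4x4 transpose: out[r][c] = in[c][r]. */
    static void transpose4x4_ref(int16_t out[16], const int16_t in[16]) {
      int r, c;
      for (r = 0; r < 4; r++) {
        for (c = 0; c < 4; c++) out[4 * r + c] = in[4 * c + r];
      }
    }

    int main(void) {
      int16_t in[16], out[16];
      int i;
      for (i = 0; i < 16; i++) in[i] = (int16_t)i;
      transpose4x4_ref(out, in);
      assert(out[1] == 4);  /* row 0, col 1 came from row 1, col 0 */
      assert(out[7] == 13); /* row 1, col 3 came from row 3, col 1 */
      return 0;
    }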
+
+static INLINE void od_idct2_asym_kernel8_epi16(__m128i *p0, __m128i *p1,
+                                               __m128i *p1h) {
+  *p1 = _mm_sub_epi16(*p0, *p1);
+  *p1h = od_unbiased_rshift1_epi16(*p1);
+  *p0 = _mm_sub_epi16(*p0, *p1h);
+}
+
+static INLINE void od_idst2_asym_kernel8_epi16(__m128i *p0, __m128i *p1) {
+  __m128i t_;
+  __m128i u_;
+  t_ = od_avg_epi16(*p0, *p1);
+  /* 3135/4096 ~= (Cos[Pi/8] - Sin[Pi/8])*Sqrt[2] = 0.7653668647301795 */
+  u_ = od_mulhrs_epi16(*p1, 3135 << 3);
+  /* 15137/16384 ~= (Cos[Pi/8] + Sin[Pi/8])/Sqrt[2] = 0.9238795325112867 */
+  *p1 = od_mulhrs_epi16(*p0, 15137 << 1);
+  /* 8867/8192 ~= Cos[3*Pi/8]*2*Sqrt[2] = 1.082392200292394 */
+  t_ = _mm_add_epi16(t_, od_mulhrs_epi16(t_, (8867 - 8192) << 2));
+  *p0 = _mm_add_epi16(u_, t_);
+  *p1 = _mm_sub_epi16(*p1, od_unbiased_rshift1_epi16(t_));
+}
+
+static INLINE void od_idct4_kernel8_epi16(__m128i *q0, __m128i *q2,
+                                          __m128i *q1, __m128i *q3) {
+  __m128i q1h;
+  od_idst2_asym_kernel8_epi16(q3, q2);
+  od_idct2_asym_kernel8_epi16(q0, q1, &q1h);
+  *q2 = _mm_add_epi16(*q2, q1h);
+  *q1 = _mm_sub_epi16(*q1, *q2);
+  *q0 = _mm_add_epi16(*q0, od_unbiased_rshift1_epi16(*q3));
+  *q3 = _mm_sub_epi16(*q0, *q3);
+}
+
+static void od_row_idct4_avx2(int16_t *out, int rows, const tran_low_t *in) {
+  __m128i q0;
+  __m128i q1;
+  __m128i q2;
+  __m128i q3;
+  if (rows <= 4) {
+    od_load_buffer_4x4_epi32(&q0, &q1, &q2, &q3, in);
+    /*TODO(any): Merge this transpose with coefficient scanning.*/
+    od_transpose_pack4x4(&q0, &q1, &q2, &q3);
+    od_idct4_kernel8_epi16(&q0, &q1, &q2, &q3);
+    od_transpose4x4(&q0, q2, &q1, q3);
+    od_store_buffer_4x4_epi16(out, q0, q1);
+  } else {
+    int r;
+    /* Higher row counts require 32-bit precision. */
+    assert(rows <= 16);
+    for (r = 0; r < rows; r += 8) {
+      __m128i q4;
+      __m128i q5;
+      __m128i q6;
+      __m128i q7;
+      od_load_buffer_4x4_epi32(&q0, &q1, &q2, &q3, in + 4 * r);
+      od_load_buffer_4x4_epi32(&q4, &q5, &q6, &q7, in + 4 * r + 16);
+      /*TODO(any): Merge this transpose with coefficient scanning.*/
+      od_transpose_pack4x8(&q0, &q1, &q2, &q3, q4, q5, q6, q7);
+      od_idct4_kernel8_epi16(&q0, &q1, &q2, &q3);
+      od_transpose8x4(&q0, &q2, &q1, &q3);
+      od_store_buffer_4x8_epi16(out + 4 * r, q0, q2, q1, q3);
+    }
+  }
+}
+
+static void od_col_idct4_add_hbd_avx2(unsigned char *output_pixels,
+                                      int output_stride, int cols,
+                                      const int16_t *in, int bd) {
+  __m128i q0;
+  __m128i q1;
+  __m128i q2;
+  __m128i q3;
+  if (cols <= 4) {
+    od_load_buffer_4x4_epi16(&q0, &q1, &q2, &q3, in);
+    od_idct4_kernel8_epi16(&q0, &q1, &q2, &q3);
+    od_add_store_buffer_hbd_4x4_epi16(output_pixels, output_stride, q0, q2, q1,
+                                      q3, bd);
+  } else {
+    int c;
+    for (c = 0; c < cols; c += 8) {
+      od_load_buffer_8x4_epi16(&q0, &q1, &q2, &q3, in + c, cols);
+      od_idct4_kernel8_epi16(&q0, &q1, &q2, &q3);
+      od_add_store_buffer_hbd_8x4_epi16(output_pixels + c, output_stride, q0,
+                                        q2, q1, q3, bd);
+    }
+  }
+}
+
 typedef void (*daala_row_itx)(int16_t *out, int rows, const tran_low_t *in);
 typedef void (*daala_col_itx_add)(unsigned char *output_pixels,
                                   int output_stride, int cols,
@@ -25,7 +385,7 @@
 static const daala_row_itx TX_ROW_MAP[TX_SIZES][TX_TYPES] = {
   // 4-point transforms
-  { NULL, NULL, NULL, NULL },
+  { od_row_idct4_avx2, NULL, NULL, NULL },
   // 8-point transforms
   { NULL, NULL, NULL, NULL },
   // 16-point transforms
@@ -57,7 +417,7 @@
   // High bit depth output
   {
     // 4-point transforms
-    { NULL, NULL, NULL, NULL },
+    { od_col_idct4_add_hbd_avx2, NULL, NULL, NULL },
     // 8-point transforms
     { NULL, NULL, NULL, NULL },
    // 16-point transforms
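The rotation kernel in od_idst2_asym_kernel8_epi16 expresses its constants in Q15 so PMULHRSW can apply them in one instruction. A scalar model of what _mm_mulhrs_epi16 computes per lane (illustration only; mulhrs is a hypothetical name):

    #include <assert.h>
    #include <stdint.h>

    /* _mm_mulhrs_epi16 per lane: multiply, round, keep the high half,
       i.e. (a * b + 16384) >> 15, which applies the factor b / 32768. */
    static int16_t mulhrs(int16_t a, int16_t b) {
      return (int16_t)((a * b + 16384) >> 15);
    }

    int main(void) {
      /* 3135 << 3 = 25080, and 25080 / 32768 = 0.76538 approximates
         (cos(pi/8) - sin(pi/8)) * sqrt(2) = 0.76537 from the comment
         above, so scaling 4096 by it yields 3135. */
      assert(mulhrs(1 << 12, 3135 << 3) == 3135);
      return 0;
    }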