Sync inv identity txfm ssse3/avx2 with C version. Due to the change in CL50009 (https://aomedia-review.googlesource.com/c/aom/+/50009), the behaviour of the ssse3 and avx2 versions of the identity txfm no longer matches the C version. Change-Id: I68886f3f37f586cf587b3c3cd31de04eab6b5e4a
diff --git a/av1/common/x86/av1_inv_txfm_avx2.c b/av1/common/x86/av1_inv_txfm_avx2.c index b36819d..899d068 100644 --- a/av1/common/x86/av1_inv_txfm_avx2.c +++ b/av1/common/x86/av1_inv_txfm_avx2.c
@@ -863,184 +863,70 @@ } } -static INLINE void iidentity16_row_16xn_avx2(__m256i *out, const int32_t *input, - int stride, int shift, int height, - int rect) { +static INLINE void iidentity_row_16xn_avx2(__m256i *out, const int32_t *input, + int stride, int shift, int height, + int txw_idx, int rect_type) { const int32_t *input_row = input; - const __m256i mshift = _mm256_set1_epi16(1 << (15 + shift)); - const int16_t scale_fractional = 2 * NewSqrt2 - (2 << NewSqrt2Bits); - const __m256i scale = - _mm256_set1_epi16(scale_fractional << (15 - NewSqrt2Bits)); - if (!rect) { - for (int h = 0; h < height; ++h) { + const __m256i scale = _mm256_set1_epi16(NewSqrt2list[txw_idx]); + const __m256i rounding = _mm256_set1_epi16((1 << (NewSqrt2Bits - 1)) + + (1 << (NewSqrt2Bits - shift - 1))); + const __m256i one = _mm256_set1_epi16(1); + const __m256i scale_rounding = _mm256_unpacklo_epi16(scale, rounding); + if (rect_type != 1 && rect_type != -1) { + for (int i = 0; i < height; ++i) { __m256i src = load_32bit_to_16bit_w16_avx2(input_row); input_row += stride; - __m256i x = _mm256_mulhrs_epi16(src, scale); - __m256i srcx2 = _mm256_adds_epi16(src, src); - x = _mm256_adds_epi16(x, srcx2); - out[h] = _mm256_mulhrs_epi16(x, mshift); + __m256i lo = _mm256_unpacklo_epi16(src, one); + __m256i hi = _mm256_unpackhi_epi16(src, one); + lo = _mm256_madd_epi16(lo, scale_rounding); + hi = _mm256_madd_epi16(hi, scale_rounding); + lo = _mm256_srai_epi32(lo, NewSqrt2Bits - shift); + hi = _mm256_srai_epi32(hi, NewSqrt2Bits - shift); + out[i] = _mm256_packs_epi32(lo, hi); } } else { const __m256i rect_scale = _mm256_set1_epi16(NewInvSqrt2 << (15 - NewSqrt2Bits)); - for (int h = 0; h < height; ++h) { + for (int i = 0; i < height; ++i) { __m256i src = load_32bit_to_16bit_w16_avx2(input_row); - input_row += stride; src = _mm256_mulhrs_epi16(src, rect_scale); - __m256i x = _mm256_mulhrs_epi16(src, scale); - __m256i srcx2 = _mm256_adds_epi16(src, src); - x = _mm256_adds_epi16(x, srcx2); - out[h] = 
_mm256_mulhrs_epi16(x, mshift); + input_row += stride; + __m256i lo = _mm256_unpacklo_epi16(src, one); + __m256i hi = _mm256_unpackhi_epi16(src, one); + lo = _mm256_madd_epi16(lo, scale_rounding); + hi = _mm256_madd_epi16(hi, scale_rounding); + lo = _mm256_srai_epi32(lo, NewSqrt2Bits - shift); + hi = _mm256_srai_epi32(hi, NewSqrt2Bits - shift); + out[i] = _mm256_packs_epi32(lo, hi); } } } -static INLINE void iidentity16_col_16xn_avx2(uint8_t *output, int stride, - __m256i *buf, int shift, - int height) { - const __m256i mshift = _mm256_set1_epi16(1 << (15 + shift)); - const int16_t scale_fractional = 2 * NewSqrt2 - (2 << NewSqrt2Bits); - const __m256i scale = - _mm256_set1_epi16(scale_fractional << (15 - NewSqrt2Bits)); +static INLINE void iidentity_col_16xn_avx2(uint8_t *output, int stride, + __m256i *buf, int shift, int height, + int txh_idx) { + const __m256i scale = _mm256_set1_epi16(NewSqrt2list[txh_idx]); + const __m256i scale_rounding = _mm256_set1_epi16(1 << (NewSqrt2Bits - 1)); + const __m256i shift_rounding = _mm256_set1_epi32(1 << (-shift - 1)); + const __m256i one = _mm256_set1_epi16(1); + const __m256i scale_coeff = _mm256_unpacklo_epi16(scale, scale_rounding); for (int h = 0; h < height; ++h) { - __m256i x = _mm256_mulhrs_epi16(buf[h], scale); - __m256i srcx2 = _mm256_adds_epi16(buf[h], buf[h]); - x = _mm256_adds_epi16(x, srcx2); - x = _mm256_mulhrs_epi16(x, mshift); + __m256i lo = _mm256_unpacklo_epi16(buf[h], one); + __m256i hi = _mm256_unpackhi_epi16(buf[h], one); + lo = _mm256_madd_epi16(lo, scale_coeff); + hi = _mm256_madd_epi16(hi, scale_coeff); + lo = _mm256_srai_epi32(lo, NewSqrt2Bits); + hi = _mm256_srai_epi32(hi, NewSqrt2Bits); + lo = _mm256_add_epi32(lo, shift_rounding); + hi = _mm256_add_epi32(hi, shift_rounding); + lo = _mm256_srai_epi32(lo, -shift); + hi = _mm256_srai_epi32(hi, -shift); + __m256i x = _mm256_packs_epi32(lo, hi); write_recon_w16_avx2(x, output); output += stride; } } -static INLINE void iidentity32_row_16xn_avx2(__m256i 
*out, const int32_t *input, - int stride, int shift, int height, - int rect) { - const int32_t *input_row = input; - const __m256i mshift = _mm256_set1_epi16(1 << (15 + shift)); - if (!rect) { - for (int h = 0; h < height; ++h) { - __m256i x = load_32bit_to_16bit_w16_avx2(input_row); - input_row += stride; - x = _mm256_adds_epi16(x, x); - x = _mm256_adds_epi16(x, x); - out[h] = _mm256_mulhrs_epi16(x, mshift); - } - } else { - const __m256i rect_scale = _mm256_set1_epi16(NewInvSqrt2 * 8); - for (int h = 0; h < height; ++h) { - __m256i x = load_32bit_to_16bit_w16_avx2(input_row); - input_row += stride; - x = _mm256_mulhrs_epi16(x, rect_scale); - x = _mm256_adds_epi16(x, x); - x = _mm256_adds_epi16(x, x); - out[h] = _mm256_mulhrs_epi16(x, mshift); - } - } -} - -static INLINE void iidentity32_col_16xn_avx2(uint8_t *output, int stride, - __m256i *buf, int shift, - int height) { - const __m256i mshift = _mm256_set1_epi16(1 << (15 + shift)); - for (int h = 0; h < height; ++h) { - __m256i x = _mm256_adds_epi16(buf[h], buf[h]); - x = _mm256_adds_epi16(x, x); - x = _mm256_mulhrs_epi16(x, mshift); - write_recon_w16_avx2(x, output); - output += stride; - } -} - -static INLINE void iidentity64_row_16xn_avx2(__m256i *out, const int32_t *input, - int stride, int shift, int height, - int rect) { - const int32_t *input_row = input; - const __m256i mshift = _mm256_set1_epi16(1 << (15 + shift)); - const int16_t scale_fractional = 4 * NewSqrt2 - (5 << NewSqrt2Bits); - const __m256i scale = - _mm256_set1_epi16(scale_fractional << (15 - NewSqrt2Bits)); - if (!rect) { - for (int h = 0; h < height; ++h) { - __m256i src = load_32bit_to_16bit_w16_avx2(input_row); - input_row += stride; - __m256i x = _mm256_mulhrs_epi16(src, scale); - __m256i srcx5 = _mm256_adds_epi16(src, src); - srcx5 = _mm256_adds_epi16(srcx5, srcx5); - srcx5 = _mm256_adds_epi16(srcx5, src); - x = _mm256_adds_epi16(x, srcx5); - out[h] = _mm256_mulhrs_epi16(x, mshift); - } - } else { - const __m256i rect_scale = - 
_mm256_set1_epi16(NewInvSqrt2 << (15 - NewSqrt2Bits)); - for (int h = 0; h < height; ++h) { - __m256i src = load_32bit_to_16bit_w16_avx2(input_row); - input_row += stride; - src = _mm256_mulhrs_epi16(src, rect_scale); - __m256i x = _mm256_mulhrs_epi16(src, scale); - __m256i srcx5 = _mm256_adds_epi16(src, src); - srcx5 = _mm256_adds_epi16(srcx5, srcx5); - srcx5 = _mm256_adds_epi16(srcx5, src); - x = _mm256_adds_epi16(x, srcx5); - out[h] = _mm256_mulhrs_epi16(x, mshift); - } - } -} - -static INLINE void iidentity64_col_16xn_avx2(uint8_t *output, int stride, - __m256i *buf, int shift, - int height) { - const __m256i mshift = _mm256_set1_epi16(1 << (15 + shift)); - const int16_t scale_fractional = 4 * NewSqrt2 - (5 << NewSqrt2Bits); - const __m256i scale = - _mm256_set1_epi16(scale_fractional << (15 - NewSqrt2Bits)); - for (int h = 0; h < height; ++h) { - __m256i x = _mm256_mulhrs_epi16(buf[h], scale); - __m256i srcx5 = _mm256_adds_epi16(buf[h], buf[h]); - srcx5 = _mm256_adds_epi16(srcx5, srcx5); - srcx5 = _mm256_adds_epi16(srcx5, buf[h]); - x = _mm256_adds_epi16(x, srcx5); - x = _mm256_mulhrs_epi16(x, mshift); - write_recon_w16_avx2(x, output); - output += stride; - } -} - -static INLINE void identity_row_16xn_avx2(__m256i *out, const int32_t *input, - int stride, int shift, int height, - int txw_idx, int rect_type) { - int rect = (rect_type != 1 && rect_type != -1) ? 
0 : 1; - switch (txw_idx) { - case 2: - iidentity16_row_16xn_avx2(out, input, stride, shift, height, rect); - break; - case 3: - iidentity32_row_16xn_avx2(out, input, stride, shift, height, rect); - break; - case 4: - iidentity64_row_16xn_avx2(out, input, stride, shift, height, rect); - break; - default: break; - } -} - -static INLINE void identity_col_16xn_avx2(uint8_t *output, int stride, - __m256i *buf, int shift, int height, - int txh_idx) { - switch (txh_idx) { - case 2: - iidentity16_col_16xn_avx2(output, stride, buf, shift, height); - break; - case 3: - iidentity32_col_16xn_avx2(output, stride, buf, shift, height); - break; - case 4: - iidentity64_col_16xn_avx2(output, stride, buf, shift, height); - break; - default: break; - } -} - static INLINE void lowbd_inv_txfm2d_add_idtx_avx2(const int32_t *input, uint8_t *output, int stride, TX_SIZE tx_size) { @@ -1054,9 +940,10 @@ const int rect_type = get_rect_tx_log_ratio(txfm_size_col, txfm_size_row); __m256i buf[32]; for (int i = 0; i < input_stride; i += 16) { - identity_row_16xn_avx2(buf, input + i, input_stride, shift[0], row_max, - txw_idx, rect_type); - identity_col_16xn_avx2(output + i, stride, buf, shift[1], row_max, txh_idx); + iidentity_row_16xn_avx2(buf, input + i, input_stride, shift[0], row_max, + txw_idx, rect_type); + iidentity_col_16xn_avx2(output + i, stride, buf, shift[1], row_max, + txh_idx); } } @@ -1085,8 +972,8 @@ get_flip_cfg(tx_type, &ud_flip, &lr_flip); for (int i = 0; i < txfm_size_col_notzero; i += 16) { __m256i buf0[64]; - identity_row_16xn_avx2(buf0, input + i, input_stride, shift[0], - txfm_size_row_notzero, txw_idx, rect_type); + iidentity_row_16xn_avx2(buf0, input + i, input_stride, shift[0], + txfm_size_row_notzero, txw_idx, rect_type); col_txfm(buf0, buf0, cos_bit_col); __m256i mshift = _mm256_set1_epi16(1 << (15 + shift[1])); int k = ud_flip ? 
(txfm_size_row - 1) : 0; @@ -1149,8 +1036,8 @@ } } for (int j = 0; j < buf_size_w_div16; ++j) { - identity_col_16xn_avx2(output + i * 16 * stride + j * 16, stride, - buf1 + j * 16, shift[1], 16, txh_idx); + iidentity_col_16xn_avx2(output + i * 16 * stride + j * 16, stride, + buf1 + j * 16, shift[1], 16, txh_idx); } } }
diff --git a/av1/common/x86/av1_inv_txfm_ssse3.c b/av1/common/x86/av1_inv_txfm_ssse3.c index 25cac55..b9706e9 100644 --- a/av1/common/x86/av1_inv_txfm_ssse3.c +++ b/av1/common/x86/av1_inv_txfm_ssse3.c
@@ -1563,339 +1563,72 @@ { NULL, NULL, NULL }, }; -static INLINE void iidentity4_row_8xn_ssse3(__m128i *out, const int32_t *input, - int stride, int shift, int height) { +static INLINE void iidentity_row_8xn_ssse3(__m128i *out, const int32_t *input, + int stride, int shift, int height, + int txw_idx, int rect_type) { const int32_t *input_row = input; - const __m128i mshift = _mm_set1_epi16(1 << (15 + shift)); - const int16_t scale_fractional = NewSqrt2 - (1 << NewSqrt2Bits); - const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits)); - for (int h = 0; h < height; ++h) { - __m128i src = load_32bit_to_16bit(input_row); - input_row += stride; - __m128i x = _mm_mulhrs_epi16(src, scale); - x = _mm_adds_epi16(x, src); - out[h] = _mm_mulhrs_epi16(x, mshift); - } -} - -static INLINE void iidentity4_row_rect_8xn_ssse3(__m128i *out, - const int32_t *input, - int stride, int shift, - int height) { - const int32_t *input_row = input; - const __m128i mshift = _mm_set1_epi16(1 << (15 + shift)); - const int16_t scale_fractional = NewSqrt2 - (1 << NewSqrt2Bits); - const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits)); - const __m128i rect_scale = _mm_set1_epi16(NewInvSqrt2 << (15 - NewSqrt2Bits)); - for (int h = 0; h < height; ++h) { - __m128i src = load_32bit_to_16bit(input_row); - input_row += stride; - src = _mm_mulhrs_epi16(src, rect_scale); - __m128i x = _mm_mulhrs_epi16(src, scale); - x = _mm_adds_epi16(x, src); - out[h] = _mm_mulhrs_epi16(x, mshift); - } -} - -static INLINE void iidentity4_col_8xn_ssse3(uint8_t *output, int stride, - __m128i *buf, int shift, - int height) { - const __m128i mshift = _mm_set1_epi16(1 << (15 + shift)); - const int16_t scale_fractional = NewSqrt2 - (1 << NewSqrt2Bits); - const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits)); - const __m128i zero = _mm_setzero_si128(); - for (int h = 0; h < height; ++h) { - __m128i x = _mm_mulhrs_epi16(buf[h], scale); - x = _mm_adds_epi16(x, 
buf[h]); - x = _mm_mulhrs_epi16(x, mshift); - const __m128i pred = _mm_loadl_epi64((__m128i const *)(output)); - x = _mm_adds_epi16(x, _mm_unpacklo_epi8(pred, zero)); - __m128i u = _mm_packus_epi16(x, x); - _mm_storel_epi64((__m128i *)(output), u); - output += stride; - } -} - -static INLINE void iidentity8_row_8xn_ssse3(__m128i *out, const int32_t *input, - int stride, int shift, int height) { - const int32_t *input_row = input; - const __m128i mshift = _mm_set1_epi16(1 << (15 + shift)); - for (int h = 0; h < height; ++h) { - __m128i src0 = _mm_load_si128((__m128i *)(input_row)); - __m128i src1 = _mm_load_si128((__m128i *)(input_row + 4)); - input_row += stride; - __m128i x = _mm_packs_epi32(src0, src1); - x = _mm_adds_epi16(x, x); - out[h] = _mm_mulhrs_epi16(x, mshift); - } -} - -static INLINE void iidentity8_row_rect_8xn_ssse3(__m128i *out, - const int32_t *input, - int stride, int shift, - int height) { - const int32_t *input_row = input; - const __m128i mshift = _mm_set1_epi16(1 << (15 + shift)); - const __m128i rect_scale = _mm_set1_epi16(NewInvSqrt2 * 8); - for (int h = 0; h < height; ++h) { - __m128i src0 = _mm_load_si128((__m128i *)(input_row)); - __m128i src1 = _mm_load_si128((__m128i *)(input_row + 4)); - input_row += stride; - __m128i x = _mm_packs_epi32(src0, src1); - x = _mm_mulhrs_epi16(x, rect_scale); - x = _mm_adds_epi16(x, x); - out[h] = _mm_mulhrs_epi16(x, mshift); - } -} - -static INLINE void iidentity8_col_8xn_ssse3(uint8_t *output, int stride, - __m128i *buf, int shift, - int height) { - const __m128i mshift = _mm_set1_epi16(1 << (15 + shift)); - const __m128i zero = _mm_setzero_si128(); - for (int h = 0; h < height; ++h) { - __m128i x = _mm_adds_epi16(buf[h], buf[h]); - x = _mm_mulhrs_epi16(x, mshift); - const __m128i pred = _mm_loadl_epi64((__m128i const *)(output)); - x = _mm_adds_epi16(x, _mm_unpacklo_epi8(pred, zero)); - __m128i u = _mm_packus_epi16(x, x); - _mm_storel_epi64((__m128i *)(output), u); - output += stride; - } -} - -static 
INLINE void iidentity16_row_8xn_ssse3(__m128i *out, const int32_t *input, - int stride, int shift, - int height) { - const int32_t *input_row = input; - const __m128i mshift = _mm_set1_epi16(1 << (15 + shift)); - const int16_t scale_fractional = 2 * NewSqrt2 - (2 << NewSqrt2Bits); - const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits)); - for (int h = 0; h < height; ++h) { - __m128i src = load_32bit_to_16bit(input_row); - input_row += stride; - __m128i x = _mm_mulhrs_epi16(src, scale); - __m128i srcx2 = _mm_adds_epi16(src, src); - x = _mm_adds_epi16(x, srcx2); - out[h] = _mm_mulhrs_epi16(x, mshift); - } -} - -static INLINE void iidentity16_row_rect_8xn_ssse3(__m128i *out, - const int32_t *input, - int stride, int shift, - int height) { - const int32_t *input_row = input; - const __m128i mshift = _mm_set1_epi16(1 << (15 + shift)); - const int16_t scale_fractional = 2 * NewSqrt2 - (2 << NewSqrt2Bits); - const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits)); - const __m128i rect_scale = _mm_set1_epi16(NewInvSqrt2 << (15 - NewSqrt2Bits)); - for (int h = 0; h < height; ++h) { - __m128i src = load_32bit_to_16bit(input_row); - input_row += stride; - src = _mm_mulhrs_epi16(src, rect_scale); - __m128i x = _mm_mulhrs_epi16(src, scale); - __m128i srcx2 = _mm_adds_epi16(src, src); - x = _mm_adds_epi16(x, srcx2); - out[h] = _mm_mulhrs_epi16(x, mshift); - } -} - -static INLINE void iidentity16_col_8xn_ssse3(uint8_t *output, int stride, - __m128i *buf, int shift, - int height) { - const __m128i mshift = _mm_set1_epi16(1 << (15 + shift)); - const int16_t scale_fractional = 2 * NewSqrt2 - (2 << NewSqrt2Bits); - const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits)); - const __m128i zero = _mm_setzero_si128(); - for (int h = 0; h < height; ++h) { - __m128i x = _mm_mulhrs_epi16(buf[h], scale); - __m128i srcx2 = _mm_adds_epi16(buf[h], buf[h]); - x = _mm_adds_epi16(x, srcx2); - x = _mm_mulhrs_epi16(x, mshift); - 
const __m128i pred = _mm_loadl_epi64((__m128i const *)(output)); - x = _mm_adds_epi16(x, _mm_unpacklo_epi8(pred, zero)); - __m128i u = _mm_packus_epi16(x, x); - _mm_storel_epi64((__m128i *)(output), u); - output += stride; - } -} - -static INLINE void iidentity32_row_8xn_ssse3(__m128i *out, const int32_t *input, - int stride, int shift, - int height) { - const int32_t *input_row = input; - const __m128i mshift = _mm_set1_epi16(1 << (15 + shift)); - for (int h = 0; h < height; ++h) { - __m128i src0 = _mm_load_si128((__m128i *)(input_row)); - __m128i src1 = _mm_load_si128((__m128i *)(input_row + 4)); - input_row += stride; - __m128i x = _mm_packs_epi32(src0, src1); - x = _mm_adds_epi16(x, x); - x = _mm_adds_epi16(x, x); - out[h] = _mm_mulhrs_epi16(x, mshift); - } -} - -static INLINE void iidentity32_row_rect_8xn_ssse3(__m128i *out, - const int32_t *input, - int stride, int shift, - int height) { - const int32_t *input_row = input; - const __m128i mshift = _mm_set1_epi16(1 << (15 + shift)); - const __m128i rect_scale = _mm_set1_epi16(NewInvSqrt2 * 8); - for (int h = 0; h < height; ++h) { - __m128i src0 = _mm_load_si128((__m128i *)(input_row)); - __m128i src1 = _mm_load_si128((__m128i *)(input_row + 4)); - input_row += stride; - __m128i x = _mm_packs_epi32(src0, src1); - x = _mm_mulhrs_epi16(x, rect_scale); - x = _mm_adds_epi16(x, x); - x = _mm_adds_epi16(x, x); - out[h] = _mm_mulhrs_epi16(x, mshift); - } -} - -static INLINE void iidentity32_col_8xn_ssse3(uint8_t *output, int stride, - __m128i *buf, int shift, - int height) { - const __m128i mshift = _mm_set1_epi16(1 << (15 + shift)); - const __m128i zero = _mm_setzero_si128(); - for (int h = 0; h < height; ++h) { - __m128i x = _mm_adds_epi16(buf[h], buf[h]); - x = _mm_adds_epi16(x, x); - x = _mm_mulhrs_epi16(x, mshift); - const __m128i pred = _mm_loadl_epi64((__m128i const *)(output)); - x = _mm_adds_epi16(x, _mm_unpacklo_epi8(pred, zero)); - __m128i u = _mm_packus_epi16(x, x); - _mm_storel_epi64((__m128i *)(output), 
u); - output += stride; - } -} - -static INLINE void iidentity64_row_8xn_ssse3(__m128i *out, const int32_t *input, - int stride, int shift, - int height) { - const int32_t *input_row = input; - const __m128i mshift = _mm_set1_epi16(1 << (15 + shift)); - const int16_t scale_fractional = 4 * NewSqrt2 - (5 << NewSqrt2Bits); - const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits)); - for (int h = 0; h < height; ++h) { - __m128i src = load_32bit_to_16bit(input_row); - input_row += stride; - __m128i x = _mm_mulhrs_epi16(src, scale); - __m128i srcx5 = _mm_adds_epi16(src, src); - srcx5 = _mm_adds_epi16(srcx5, srcx5); - srcx5 = _mm_adds_epi16(srcx5, src); - x = _mm_adds_epi16(x, srcx5); - out[h] = _mm_mulhrs_epi16(x, mshift); - } -} - -static INLINE void iidentity64_row_rect_8xn_ssse3(__m128i *out, - const int32_t *input, - int stride, int shift, - int height) { - const int32_t *input_row = input; - const __m128i mshift = _mm_set1_epi16(1 << (15 + shift)); - const int16_t scale_fractional = 4 * NewSqrt2 - (5 << NewSqrt2Bits); - const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits)); - const __m128i rect_scale = _mm_set1_epi16(NewInvSqrt2 << (15 - NewSqrt2Bits)); - for (int h = 0; h < height; ++h) { - __m128i src = load_32bit_to_16bit(input_row); - input_row += stride; - src = _mm_mulhrs_epi16(src, rect_scale); - __m128i x = _mm_mulhrs_epi16(src, scale); - __m128i srcx5 = _mm_adds_epi16(src, src); - srcx5 = _mm_adds_epi16(srcx5, srcx5); - srcx5 = _mm_adds_epi16(srcx5, src); - x = _mm_adds_epi16(x, srcx5); - out[h] = _mm_mulhrs_epi16(x, mshift); - } -} - -static INLINE void iidentity64_col_8xn_ssse3(uint8_t *output, int stride, - __m128i *buf, int shift, - int height) { - const __m128i mshift = _mm_set1_epi16(1 << (15 + shift)); - const int16_t scale_fractional = 4 * NewSqrt2 - (5 << NewSqrt2Bits); - const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits)); - const __m128i zero = _mm_setzero_si128(); - for 
(int h = 0; h < height; ++h) { - __m128i x = _mm_mulhrs_epi16(buf[h], scale); - __m128i srcx5 = _mm_adds_epi16(buf[h], buf[h]); - srcx5 = _mm_adds_epi16(srcx5, srcx5); - srcx5 = _mm_adds_epi16(srcx5, buf[h]); - x = _mm_adds_epi16(x, srcx5); - x = _mm_mulhrs_epi16(x, mshift); - const __m128i pred = _mm_loadl_epi64((__m128i const *)(output)); - x = _mm_adds_epi16(x, _mm_unpacklo_epi8(pred, zero)); - __m128i u = _mm_packus_epi16(x, x); - _mm_storel_epi64((__m128i *)(output), u); - output += stride; - } -} - -static INLINE void identity_row_8xn_ssse3(__m128i *out, const int32_t *input, - int stride, int shift, int height, - int txw_idx, int rect_type) { + const __m128i scale = _mm_set1_epi16(NewSqrt2list[txw_idx]); + const __m128i rounding = _mm_set1_epi16((1 << (NewSqrt2Bits - 1)) + + (1 << (NewSqrt2Bits - shift - 1))); + const __m128i one = _mm_set1_epi16(1); + const __m128i scale_rounding = _mm_unpacklo_epi16(scale, rounding); if (rect_type != 1 && rect_type != -1) { - switch (txw_idx) { - case 0: - iidentity4_row_8xn_ssse3(out, input, stride, shift, height); - break; - case 1: - iidentity8_row_8xn_ssse3(out, input, stride, shift, height); - break; - case 2: - iidentity16_row_8xn_ssse3(out, input, stride, shift, height); - break; - case 3: - iidentity32_row_8xn_ssse3(out, input, stride, shift, height); - break; - case 4: - iidentity64_row_8xn_ssse3(out, input, stride, shift, height); - break; - default: break; + for (int i = 0; i < height; ++i) { + __m128i src = load_32bit_to_16bit(input_row); + input_row += stride; + __m128i lo = _mm_unpacklo_epi16(src, one); + __m128i hi = _mm_unpackhi_epi16(src, one); + lo = _mm_madd_epi16(lo, scale_rounding); + hi = _mm_madd_epi16(hi, scale_rounding); + lo = _mm_srai_epi32(lo, NewSqrt2Bits - shift); + hi = _mm_srai_epi32(hi, NewSqrt2Bits - shift); + out[i] = _mm_packs_epi32(lo, hi); } } else { - switch (txw_idx) { - case 0: - iidentity4_row_rect_8xn_ssse3(out, input, stride, shift, height); - break; - case 1: - 
iidentity8_row_rect_8xn_ssse3(out, input, stride, shift, height); - break; - case 2: - iidentity16_row_rect_8xn_ssse3(out, input, stride, shift, height); - break; - case 3: - iidentity32_row_rect_8xn_ssse3(out, input, stride, shift, height); - break; - case 4: - iidentity64_row_rect_8xn_ssse3(out, input, stride, shift, height); - break; - default: break; + const __m128i rect_scale = + _mm_set1_epi16(NewInvSqrt2 << (15 - NewSqrt2Bits)); + for (int i = 0; i < height; ++i) { + __m128i src = load_32bit_to_16bit(input_row); + src = _mm_mulhrs_epi16(src, rect_scale); + input_row += stride; + __m128i lo = _mm_unpacklo_epi16(src, one); + __m128i hi = _mm_unpackhi_epi16(src, one); + lo = _mm_madd_epi16(lo, scale_rounding); + hi = _mm_madd_epi16(hi, scale_rounding); + lo = _mm_srai_epi32(lo, NewSqrt2Bits - shift); + hi = _mm_srai_epi32(hi, NewSqrt2Bits - shift); + out[i] = _mm_packs_epi32(lo, hi); } } } -static INLINE void identity_col_8xn_ssse3(uint8_t *output, int stride, - __m128i *buf, int shift, int height, - int txh_idx) { - switch (txh_idx) { - case 0: iidentity4_col_8xn_ssse3(output, stride, buf, shift, height); break; - case 1: iidentity8_col_8xn_ssse3(output, stride, buf, shift, height); break; - case 2: - iidentity16_col_8xn_ssse3(output, stride, buf, shift, height); - break; - case 3: - iidentity32_col_8xn_ssse3(output, stride, buf, shift, height); - break; - case 4: - iidentity64_col_8xn_ssse3(output, stride, buf, shift, height); - break; - default: break; +static INLINE void iidentity_col_8xn_ssse3(uint8_t *output, int stride, + __m128i *buf, int shift, int height, + int txh_idx) { + const __m128i scale = _mm_set1_epi16(NewSqrt2list[txh_idx]); + const __m128i scale_rounding = _mm_set1_epi16(1 << (NewSqrt2Bits - 1)); + const __m128i shift_rounding = _mm_set1_epi32(1 << (-shift - 1)); + const __m128i one = _mm_set1_epi16(1); + const __m128i scale_coeff = _mm_unpacklo_epi16(scale, scale_rounding); + const __m128i zero = _mm_setzero_si128(); + for (int h = 0; h < 
height; ++h) { + __m128i lo = _mm_unpacklo_epi16(buf[h], one); + __m128i hi = _mm_unpackhi_epi16(buf[h], one); + lo = _mm_madd_epi16(lo, scale_coeff); + hi = _mm_madd_epi16(hi, scale_coeff); + lo = _mm_srai_epi32(lo, NewSqrt2Bits); + hi = _mm_srai_epi32(hi, NewSqrt2Bits); + lo = _mm_add_epi32(lo, shift_rounding); + hi = _mm_add_epi32(hi, shift_rounding); + lo = _mm_srai_epi32(lo, -shift); + hi = _mm_srai_epi32(hi, -shift); + __m128i x = _mm_packs_epi32(lo, hi); + + const __m128i pred = _mm_loadl_epi64((__m128i const *)(output)); + x = _mm_adds_epi16(x, _mm_unpacklo_epi8(pred, zero)); + __m128i u = _mm_packus_epi16(x, x); + _mm_storel_epi64((__m128i *)(output), u); + output += stride; } } @@ -1913,10 +1646,10 @@ __m128i buf[32]; for (int i = 0; i<input_stride>> 3; ++i) { - identity_row_8xn_ssse3(buf, input + 8 * i, input_stride, shift[0], row_max, - txw_idx, rect_type); - identity_col_8xn_ssse3(output + 8 * i, stride, buf, shift[1], row_max, - txh_idx); + iidentity_row_8xn_ssse3(buf, input + 8 * i, input_stride, shift[0], row_max, + txw_idx, rect_type); + iidentity_col_8xn_ssse3(output + 8 * i, stride, buf, shift[1], row_max, + txh_idx); } } @@ -2083,8 +1816,8 @@ get_flip_cfg(tx_type, &ud_flip, &lr_flip); for (int i = 0; i < AOMMIN(4, buf_size_w_div8); i++) { __m128i buf0[64]; - identity_row_8xn_ssse3(buf0, input + 8 * i, input_stride, shift[0], - txfm_size_row_notzero, txw_idx, rect_type); + iidentity_row_8xn_ssse3(buf0, input + 8 * i, input_stride, shift[0], + txfm_size_row_notzero, txw_idx, rect_type); col_txfm(buf0, buf0, cos_bit_col); __m128i mshift = _mm_set1_epi16(1 << (15 + shift[1])); int k = ud_flip ? (txfm_size_row - 1) : 0; @@ -2149,8 +1882,8 @@ } for (int j = 0; j < buf_size_w_div8; ++j) { - identity_col_8xn_ssse3(output + i * 8 * stride + j * 8, stride, - buf1 + j * 8, shift[1], 8, txh_idx); + iidentity_col_8xn_ssse3(output + i * 8 * stride + j * 8, stride, + buf1 + j * 8, shift[1], 8, txh_idx); } } }
diff --git a/av1/common/x86/av1_inv_txfm_ssse3.h b/av1/common/x86/av1_inv_txfm_ssse3.h index 96dc0d6..ccdb006 100644 --- a/av1/common/x86/av1_inv_txfm_ssse3.h +++ b/av1/common/x86/av1_inv_txfm_ssse3.h
@@ -53,6 +53,10 @@ out1 = _mm_subs_epi16(_in0, _in1); \ } while (0) +#ifdef __cplusplus +extern "C" { +#endif + static INLINE void round_shift_16bit_ssse3(__m128i *in, int size, int bit) { if (bit < 0) { const __m128i scale = _mm_set1_epi16(1 << (15 + bit)); @@ -66,10 +70,6 @@ } } -#ifdef __cplusplus -extern "C" { -#endif - // 1D itx types typedef enum ATTRIBUTE_PACKED { IDCT_1D, @@ -93,6 +93,10 @@ IIDENTITY_1D, IADST_1D, IIDENTITY_1D, IFLIPADST_1D, }; +// Sqrt2, Sqrt2^2, Sqrt2^3, Sqrt2^4, Sqrt2^5 +static int32_t NewSqrt2list[TX_SIZES] = { 5793, 2 * 4096, 2 * 5793, 4 * 4096, + 4 * 5793 }; + typedef void (*transform_1d_ssse3)(const __m128i *input, __m128i *output, int8_t cos_bit);
diff --git a/test/av1_inv_txfm2d_test.cc b/test/av1_inv_txfm2d_test.cc index 73feb16..c07ce09 100644 --- a/test/av1_inv_txfm2d_test.cc +++ b/test/av1_inv_txfm2d_test.cc
@@ -325,11 +325,13 @@ #endif // HAVE_SSSE3 #if HAVE_AVX2 -#if defined(_MSC_VER) || defined(__AVX2__) -#include "av1/common/x86/av1_inv_txfm_avx2.h" +extern "C" void av1_lowbd_inv_txfm2d_add_avx2(const int32_t *input, + uint8_t *output, int stride, + TX_TYPE tx_type, TX_SIZE tx_size, + int eob); + INSTANTIATE_TEST_CASE_P(AVX2, AV1LbdInvTxfm2d, ::testing::Values(av1_lowbd_inv_txfm2d_add_avx2)); -#endif // (_MSC_VER) || (__AVX2__) #endif // HAVE_AVX2 } // namespace