Sync inv identity txfm ssse3/avx2 with C version
Due to the change in CL50009.
https://aomedia-review.googlesource.com/c/aom/+/50009
The behaviour of the identity txfm ssse3 and avx2 versions
does not match the C version.
Change-Id: I68886f3f37f586cf587b3c3cd31de04eab6b5e4a
diff --git a/av1/common/x86/av1_inv_txfm_avx2.c b/av1/common/x86/av1_inv_txfm_avx2.c
index b36819d..899d068 100644
--- a/av1/common/x86/av1_inv_txfm_avx2.c
+++ b/av1/common/x86/av1_inv_txfm_avx2.c
@@ -863,184 +863,70 @@
}
}
-static INLINE void iidentity16_row_16xn_avx2(__m256i *out, const int32_t *input,
- int stride, int shift, int height,
- int rect) {
+static INLINE void iidentity_row_16xn_avx2(__m256i *out, const int32_t *input,
+ int stride, int shift, int height,
+ int txw_idx, int rect_type) {
const int32_t *input_row = input;
- const __m256i mshift = _mm256_set1_epi16(1 << (15 + shift));
- const int16_t scale_fractional = 2 * NewSqrt2 - (2 << NewSqrt2Bits);
- const __m256i scale =
- _mm256_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
- if (!rect) {
- for (int h = 0; h < height; ++h) {
+ const __m256i scale = _mm256_set1_epi16(NewSqrt2list[txw_idx]);
+ const __m256i rounding = _mm256_set1_epi16((1 << (NewSqrt2Bits - 1)) +
+ (1 << (NewSqrt2Bits - shift - 1)));
+ const __m256i one = _mm256_set1_epi16(1);
+ const __m256i scale_rounding = _mm256_unpacklo_epi16(scale, rounding);
+ if (rect_type != 1 && rect_type != -1) {
+ for (int i = 0; i < height; ++i) {
__m256i src = load_32bit_to_16bit_w16_avx2(input_row);
input_row += stride;
- __m256i x = _mm256_mulhrs_epi16(src, scale);
- __m256i srcx2 = _mm256_adds_epi16(src, src);
- x = _mm256_adds_epi16(x, srcx2);
- out[h] = _mm256_mulhrs_epi16(x, mshift);
+ __m256i lo = _mm256_unpacklo_epi16(src, one);
+ __m256i hi = _mm256_unpackhi_epi16(src, one);
+ lo = _mm256_madd_epi16(lo, scale_rounding);
+ hi = _mm256_madd_epi16(hi, scale_rounding);
+ lo = _mm256_srai_epi32(lo, NewSqrt2Bits - shift);
+ hi = _mm256_srai_epi32(hi, NewSqrt2Bits - shift);
+ out[i] = _mm256_packs_epi32(lo, hi);
}
} else {
const __m256i rect_scale =
_mm256_set1_epi16(NewInvSqrt2 << (15 - NewSqrt2Bits));
- for (int h = 0; h < height; ++h) {
+ for (int i = 0; i < height; ++i) {
__m256i src = load_32bit_to_16bit_w16_avx2(input_row);
- input_row += stride;
src = _mm256_mulhrs_epi16(src, rect_scale);
- __m256i x = _mm256_mulhrs_epi16(src, scale);
- __m256i srcx2 = _mm256_adds_epi16(src, src);
- x = _mm256_adds_epi16(x, srcx2);
- out[h] = _mm256_mulhrs_epi16(x, mshift);
+ input_row += stride;
+ __m256i lo = _mm256_unpacklo_epi16(src, one);
+ __m256i hi = _mm256_unpackhi_epi16(src, one);
+ lo = _mm256_madd_epi16(lo, scale_rounding);
+ hi = _mm256_madd_epi16(hi, scale_rounding);
+ lo = _mm256_srai_epi32(lo, NewSqrt2Bits - shift);
+ hi = _mm256_srai_epi32(hi, NewSqrt2Bits - shift);
+ out[i] = _mm256_packs_epi32(lo, hi);
}
}
}
-static INLINE void iidentity16_col_16xn_avx2(uint8_t *output, int stride,
- __m256i *buf, int shift,
- int height) {
- const __m256i mshift = _mm256_set1_epi16(1 << (15 + shift));
- const int16_t scale_fractional = 2 * NewSqrt2 - (2 << NewSqrt2Bits);
- const __m256i scale =
- _mm256_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
+static INLINE void iidentity_col_16xn_avx2(uint8_t *output, int stride,
+ __m256i *buf, int shift, int height,
+ int txh_idx) {
+ const __m256i scale = _mm256_set1_epi16(NewSqrt2list[txh_idx]);
+ const __m256i scale_rounding = _mm256_set1_epi16(1 << (NewSqrt2Bits - 1));
+ const __m256i shift_rounding = _mm256_set1_epi32(1 << (-shift - 1));
+ const __m256i one = _mm256_set1_epi16(1);
+ const __m256i scale_coeff = _mm256_unpacklo_epi16(scale, scale_rounding);
for (int h = 0; h < height; ++h) {
- __m256i x = _mm256_mulhrs_epi16(buf[h], scale);
- __m256i srcx2 = _mm256_adds_epi16(buf[h], buf[h]);
- x = _mm256_adds_epi16(x, srcx2);
- x = _mm256_mulhrs_epi16(x, mshift);
+ __m256i lo = _mm256_unpacklo_epi16(buf[h], one);
+ __m256i hi = _mm256_unpackhi_epi16(buf[h], one);
+ lo = _mm256_madd_epi16(lo, scale_coeff);
+ hi = _mm256_madd_epi16(hi, scale_coeff);
+ lo = _mm256_srai_epi32(lo, NewSqrt2Bits);
+ hi = _mm256_srai_epi32(hi, NewSqrt2Bits);
+ lo = _mm256_add_epi32(lo, shift_rounding);
+ hi = _mm256_add_epi32(hi, shift_rounding);
+ lo = _mm256_srai_epi32(lo, -shift);
+ hi = _mm256_srai_epi32(hi, -shift);
+ __m256i x = _mm256_packs_epi32(lo, hi);
write_recon_w16_avx2(x, output);
output += stride;
}
}
-static INLINE void iidentity32_row_16xn_avx2(__m256i *out, const int32_t *input,
- int stride, int shift, int height,
- int rect) {
- const int32_t *input_row = input;
- const __m256i mshift = _mm256_set1_epi16(1 << (15 + shift));
- if (!rect) {
- for (int h = 0; h < height; ++h) {
- __m256i x = load_32bit_to_16bit_w16_avx2(input_row);
- input_row += stride;
- x = _mm256_adds_epi16(x, x);
- x = _mm256_adds_epi16(x, x);
- out[h] = _mm256_mulhrs_epi16(x, mshift);
- }
- } else {
- const __m256i rect_scale = _mm256_set1_epi16(NewInvSqrt2 * 8);
- for (int h = 0; h < height; ++h) {
- __m256i x = load_32bit_to_16bit_w16_avx2(input_row);
- input_row += stride;
- x = _mm256_mulhrs_epi16(x, rect_scale);
- x = _mm256_adds_epi16(x, x);
- x = _mm256_adds_epi16(x, x);
- out[h] = _mm256_mulhrs_epi16(x, mshift);
- }
- }
-}
-
-static INLINE void iidentity32_col_16xn_avx2(uint8_t *output, int stride,
- __m256i *buf, int shift,
- int height) {
- const __m256i mshift = _mm256_set1_epi16(1 << (15 + shift));
- for (int h = 0; h < height; ++h) {
- __m256i x = _mm256_adds_epi16(buf[h], buf[h]);
- x = _mm256_adds_epi16(x, x);
- x = _mm256_mulhrs_epi16(x, mshift);
- write_recon_w16_avx2(x, output);
- output += stride;
- }
-}
-
-static INLINE void iidentity64_row_16xn_avx2(__m256i *out, const int32_t *input,
- int stride, int shift, int height,
- int rect) {
- const int32_t *input_row = input;
- const __m256i mshift = _mm256_set1_epi16(1 << (15 + shift));
- const int16_t scale_fractional = 4 * NewSqrt2 - (5 << NewSqrt2Bits);
- const __m256i scale =
- _mm256_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
- if (!rect) {
- for (int h = 0; h < height; ++h) {
- __m256i src = load_32bit_to_16bit_w16_avx2(input_row);
- input_row += stride;
- __m256i x = _mm256_mulhrs_epi16(src, scale);
- __m256i srcx5 = _mm256_adds_epi16(src, src);
- srcx5 = _mm256_adds_epi16(srcx5, srcx5);
- srcx5 = _mm256_adds_epi16(srcx5, src);
- x = _mm256_adds_epi16(x, srcx5);
- out[h] = _mm256_mulhrs_epi16(x, mshift);
- }
- } else {
- const __m256i rect_scale =
- _mm256_set1_epi16(NewInvSqrt2 << (15 - NewSqrt2Bits));
- for (int h = 0; h < height; ++h) {
- __m256i src = load_32bit_to_16bit_w16_avx2(input_row);
- input_row += stride;
- src = _mm256_mulhrs_epi16(src, rect_scale);
- __m256i x = _mm256_mulhrs_epi16(src, scale);
- __m256i srcx5 = _mm256_adds_epi16(src, src);
- srcx5 = _mm256_adds_epi16(srcx5, srcx5);
- srcx5 = _mm256_adds_epi16(srcx5, src);
- x = _mm256_adds_epi16(x, srcx5);
- out[h] = _mm256_mulhrs_epi16(x, mshift);
- }
- }
-}
-
-static INLINE void iidentity64_col_16xn_avx2(uint8_t *output, int stride,
- __m256i *buf, int shift,
- int height) {
- const __m256i mshift = _mm256_set1_epi16(1 << (15 + shift));
- const int16_t scale_fractional = 4 * NewSqrt2 - (5 << NewSqrt2Bits);
- const __m256i scale =
- _mm256_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
- for (int h = 0; h < height; ++h) {
- __m256i x = _mm256_mulhrs_epi16(buf[h], scale);
- __m256i srcx5 = _mm256_adds_epi16(buf[h], buf[h]);
- srcx5 = _mm256_adds_epi16(srcx5, srcx5);
- srcx5 = _mm256_adds_epi16(srcx5, buf[h]);
- x = _mm256_adds_epi16(x, srcx5);
- x = _mm256_mulhrs_epi16(x, mshift);
- write_recon_w16_avx2(x, output);
- output += stride;
- }
-}
-
-static INLINE void identity_row_16xn_avx2(__m256i *out, const int32_t *input,
- int stride, int shift, int height,
- int txw_idx, int rect_type) {
- int rect = (rect_type != 1 && rect_type != -1) ? 0 : 1;
- switch (txw_idx) {
- case 2:
- iidentity16_row_16xn_avx2(out, input, stride, shift, height, rect);
- break;
- case 3:
- iidentity32_row_16xn_avx2(out, input, stride, shift, height, rect);
- break;
- case 4:
- iidentity64_row_16xn_avx2(out, input, stride, shift, height, rect);
- break;
- default: break;
- }
-}
-
-static INLINE void identity_col_16xn_avx2(uint8_t *output, int stride,
- __m256i *buf, int shift, int height,
- int txh_idx) {
- switch (txh_idx) {
- case 2:
- iidentity16_col_16xn_avx2(output, stride, buf, shift, height);
- break;
- case 3:
- iidentity32_col_16xn_avx2(output, stride, buf, shift, height);
- break;
- case 4:
- iidentity64_col_16xn_avx2(output, stride, buf, shift, height);
- break;
- default: break;
- }
-}
-
static INLINE void lowbd_inv_txfm2d_add_idtx_avx2(const int32_t *input,
uint8_t *output, int stride,
TX_SIZE tx_size) {
@@ -1054,9 +940,10 @@
const int rect_type = get_rect_tx_log_ratio(txfm_size_col, txfm_size_row);
__m256i buf[32];
for (int i = 0; i < input_stride; i += 16) {
- identity_row_16xn_avx2(buf, input + i, input_stride, shift[0], row_max,
- txw_idx, rect_type);
- identity_col_16xn_avx2(output + i, stride, buf, shift[1], row_max, txh_idx);
+ iidentity_row_16xn_avx2(buf, input + i, input_stride, shift[0], row_max,
+ txw_idx, rect_type);
+ iidentity_col_16xn_avx2(output + i, stride, buf, shift[1], row_max,
+ txh_idx);
}
}
@@ -1085,8 +972,8 @@
get_flip_cfg(tx_type, &ud_flip, &lr_flip);
for (int i = 0; i < txfm_size_col_notzero; i += 16) {
__m256i buf0[64];
- identity_row_16xn_avx2(buf0, input + i, input_stride, shift[0],
- txfm_size_row_notzero, txw_idx, rect_type);
+ iidentity_row_16xn_avx2(buf0, input + i, input_stride, shift[0],
+ txfm_size_row_notzero, txw_idx, rect_type);
col_txfm(buf0, buf0, cos_bit_col);
__m256i mshift = _mm256_set1_epi16(1 << (15 + shift[1]));
int k = ud_flip ? (txfm_size_row - 1) : 0;
@@ -1149,8 +1036,8 @@
}
}
for (int j = 0; j < buf_size_w_div16; ++j) {
- identity_col_16xn_avx2(output + i * 16 * stride + j * 16, stride,
- buf1 + j * 16, shift[1], 16, txh_idx);
+ iidentity_col_16xn_avx2(output + i * 16 * stride + j * 16, stride,
+ buf1 + j * 16, shift[1], 16, txh_idx);
}
}
}
diff --git a/av1/common/x86/av1_inv_txfm_ssse3.c b/av1/common/x86/av1_inv_txfm_ssse3.c
index 25cac55..b9706e9 100644
--- a/av1/common/x86/av1_inv_txfm_ssse3.c
+++ b/av1/common/x86/av1_inv_txfm_ssse3.c
@@ -1563,339 +1563,72 @@
{ NULL, NULL, NULL },
};
-static INLINE void iidentity4_row_8xn_ssse3(__m128i *out, const int32_t *input,
- int stride, int shift, int height) {
+static INLINE void iidentity_row_8xn_ssse3(__m128i *out, const int32_t *input,
+ int stride, int shift, int height,
+ int txw_idx, int rect_type) {
const int32_t *input_row = input;
- const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
- const int16_t scale_fractional = NewSqrt2 - (1 << NewSqrt2Bits);
- const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
- for (int h = 0; h < height; ++h) {
- __m128i src = load_32bit_to_16bit(input_row);
- input_row += stride;
- __m128i x = _mm_mulhrs_epi16(src, scale);
- x = _mm_adds_epi16(x, src);
- out[h] = _mm_mulhrs_epi16(x, mshift);
- }
-}
-
-static INLINE void iidentity4_row_rect_8xn_ssse3(__m128i *out,
- const int32_t *input,
- int stride, int shift,
- int height) {
- const int32_t *input_row = input;
- const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
- const int16_t scale_fractional = NewSqrt2 - (1 << NewSqrt2Bits);
- const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
- const __m128i rect_scale = _mm_set1_epi16(NewInvSqrt2 << (15 - NewSqrt2Bits));
- for (int h = 0; h < height; ++h) {
- __m128i src = load_32bit_to_16bit(input_row);
- input_row += stride;
- src = _mm_mulhrs_epi16(src, rect_scale);
- __m128i x = _mm_mulhrs_epi16(src, scale);
- x = _mm_adds_epi16(x, src);
- out[h] = _mm_mulhrs_epi16(x, mshift);
- }
-}
-
-static INLINE void iidentity4_col_8xn_ssse3(uint8_t *output, int stride,
- __m128i *buf, int shift,
- int height) {
- const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
- const int16_t scale_fractional = NewSqrt2 - (1 << NewSqrt2Bits);
- const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
- const __m128i zero = _mm_setzero_si128();
- for (int h = 0; h < height; ++h) {
- __m128i x = _mm_mulhrs_epi16(buf[h], scale);
- x = _mm_adds_epi16(x, buf[h]);
- x = _mm_mulhrs_epi16(x, mshift);
- const __m128i pred = _mm_loadl_epi64((__m128i const *)(output));
- x = _mm_adds_epi16(x, _mm_unpacklo_epi8(pred, zero));
- __m128i u = _mm_packus_epi16(x, x);
- _mm_storel_epi64((__m128i *)(output), u);
- output += stride;
- }
-}
-
-static INLINE void iidentity8_row_8xn_ssse3(__m128i *out, const int32_t *input,
- int stride, int shift, int height) {
- const int32_t *input_row = input;
- const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
- for (int h = 0; h < height; ++h) {
- __m128i src0 = _mm_load_si128((__m128i *)(input_row));
- __m128i src1 = _mm_load_si128((__m128i *)(input_row + 4));
- input_row += stride;
- __m128i x = _mm_packs_epi32(src0, src1);
- x = _mm_adds_epi16(x, x);
- out[h] = _mm_mulhrs_epi16(x, mshift);
- }
-}
-
-static INLINE void iidentity8_row_rect_8xn_ssse3(__m128i *out,
- const int32_t *input,
- int stride, int shift,
- int height) {
- const int32_t *input_row = input;
- const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
- const __m128i rect_scale = _mm_set1_epi16(NewInvSqrt2 * 8);
- for (int h = 0; h < height; ++h) {
- __m128i src0 = _mm_load_si128((__m128i *)(input_row));
- __m128i src1 = _mm_load_si128((__m128i *)(input_row + 4));
- input_row += stride;
- __m128i x = _mm_packs_epi32(src0, src1);
- x = _mm_mulhrs_epi16(x, rect_scale);
- x = _mm_adds_epi16(x, x);
- out[h] = _mm_mulhrs_epi16(x, mshift);
- }
-}
-
-static INLINE void iidentity8_col_8xn_ssse3(uint8_t *output, int stride,
- __m128i *buf, int shift,
- int height) {
- const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
- const __m128i zero = _mm_setzero_si128();
- for (int h = 0; h < height; ++h) {
- __m128i x = _mm_adds_epi16(buf[h], buf[h]);
- x = _mm_mulhrs_epi16(x, mshift);
- const __m128i pred = _mm_loadl_epi64((__m128i const *)(output));
- x = _mm_adds_epi16(x, _mm_unpacklo_epi8(pred, zero));
- __m128i u = _mm_packus_epi16(x, x);
- _mm_storel_epi64((__m128i *)(output), u);
- output += stride;
- }
-}
-
-static INLINE void iidentity16_row_8xn_ssse3(__m128i *out, const int32_t *input,
- int stride, int shift,
- int height) {
- const int32_t *input_row = input;
- const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
- const int16_t scale_fractional = 2 * NewSqrt2 - (2 << NewSqrt2Bits);
- const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
- for (int h = 0; h < height; ++h) {
- __m128i src = load_32bit_to_16bit(input_row);
- input_row += stride;
- __m128i x = _mm_mulhrs_epi16(src, scale);
- __m128i srcx2 = _mm_adds_epi16(src, src);
- x = _mm_adds_epi16(x, srcx2);
- out[h] = _mm_mulhrs_epi16(x, mshift);
- }
-}
-
-static INLINE void iidentity16_row_rect_8xn_ssse3(__m128i *out,
- const int32_t *input,
- int stride, int shift,
- int height) {
- const int32_t *input_row = input;
- const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
- const int16_t scale_fractional = 2 * NewSqrt2 - (2 << NewSqrt2Bits);
- const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
- const __m128i rect_scale = _mm_set1_epi16(NewInvSqrt2 << (15 - NewSqrt2Bits));
- for (int h = 0; h < height; ++h) {
- __m128i src = load_32bit_to_16bit(input_row);
- input_row += stride;
- src = _mm_mulhrs_epi16(src, rect_scale);
- __m128i x = _mm_mulhrs_epi16(src, scale);
- __m128i srcx2 = _mm_adds_epi16(src, src);
- x = _mm_adds_epi16(x, srcx2);
- out[h] = _mm_mulhrs_epi16(x, mshift);
- }
-}
-
-static INLINE void iidentity16_col_8xn_ssse3(uint8_t *output, int stride,
- __m128i *buf, int shift,
- int height) {
- const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
- const int16_t scale_fractional = 2 * NewSqrt2 - (2 << NewSqrt2Bits);
- const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
- const __m128i zero = _mm_setzero_si128();
- for (int h = 0; h < height; ++h) {
- __m128i x = _mm_mulhrs_epi16(buf[h], scale);
- __m128i srcx2 = _mm_adds_epi16(buf[h], buf[h]);
- x = _mm_adds_epi16(x, srcx2);
- x = _mm_mulhrs_epi16(x, mshift);
- const __m128i pred = _mm_loadl_epi64((__m128i const *)(output));
- x = _mm_adds_epi16(x, _mm_unpacklo_epi8(pred, zero));
- __m128i u = _mm_packus_epi16(x, x);
- _mm_storel_epi64((__m128i *)(output), u);
- output += stride;
- }
-}
-
-static INLINE void iidentity32_row_8xn_ssse3(__m128i *out, const int32_t *input,
- int stride, int shift,
- int height) {
- const int32_t *input_row = input;
- const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
- for (int h = 0; h < height; ++h) {
- __m128i src0 = _mm_load_si128((__m128i *)(input_row));
- __m128i src1 = _mm_load_si128((__m128i *)(input_row + 4));
- input_row += stride;
- __m128i x = _mm_packs_epi32(src0, src1);
- x = _mm_adds_epi16(x, x);
- x = _mm_adds_epi16(x, x);
- out[h] = _mm_mulhrs_epi16(x, mshift);
- }
-}
-
-static INLINE void iidentity32_row_rect_8xn_ssse3(__m128i *out,
- const int32_t *input,
- int stride, int shift,
- int height) {
- const int32_t *input_row = input;
- const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
- const __m128i rect_scale = _mm_set1_epi16(NewInvSqrt2 * 8);
- for (int h = 0; h < height; ++h) {
- __m128i src0 = _mm_load_si128((__m128i *)(input_row));
- __m128i src1 = _mm_load_si128((__m128i *)(input_row + 4));
- input_row += stride;
- __m128i x = _mm_packs_epi32(src0, src1);
- x = _mm_mulhrs_epi16(x, rect_scale);
- x = _mm_adds_epi16(x, x);
- x = _mm_adds_epi16(x, x);
- out[h] = _mm_mulhrs_epi16(x, mshift);
- }
-}
-
-static INLINE void iidentity32_col_8xn_ssse3(uint8_t *output, int stride,
- __m128i *buf, int shift,
- int height) {
- const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
- const __m128i zero = _mm_setzero_si128();
- for (int h = 0; h < height; ++h) {
- __m128i x = _mm_adds_epi16(buf[h], buf[h]);
- x = _mm_adds_epi16(x, x);
- x = _mm_mulhrs_epi16(x, mshift);
- const __m128i pred = _mm_loadl_epi64((__m128i const *)(output));
- x = _mm_adds_epi16(x, _mm_unpacklo_epi8(pred, zero));
- __m128i u = _mm_packus_epi16(x, x);
- _mm_storel_epi64((__m128i *)(output), u);
- output += stride;
- }
-}
-
-static INLINE void iidentity64_row_8xn_ssse3(__m128i *out, const int32_t *input,
- int stride, int shift,
- int height) {
- const int32_t *input_row = input;
- const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
- const int16_t scale_fractional = 4 * NewSqrt2 - (5 << NewSqrt2Bits);
- const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
- for (int h = 0; h < height; ++h) {
- __m128i src = load_32bit_to_16bit(input_row);
- input_row += stride;
- __m128i x = _mm_mulhrs_epi16(src, scale);
- __m128i srcx5 = _mm_adds_epi16(src, src);
- srcx5 = _mm_adds_epi16(srcx5, srcx5);
- srcx5 = _mm_adds_epi16(srcx5, src);
- x = _mm_adds_epi16(x, srcx5);
- out[h] = _mm_mulhrs_epi16(x, mshift);
- }
-}
-
-static INLINE void iidentity64_row_rect_8xn_ssse3(__m128i *out,
- const int32_t *input,
- int stride, int shift,
- int height) {
- const int32_t *input_row = input;
- const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
- const int16_t scale_fractional = 4 * NewSqrt2 - (5 << NewSqrt2Bits);
- const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
- const __m128i rect_scale = _mm_set1_epi16(NewInvSqrt2 << (15 - NewSqrt2Bits));
- for (int h = 0; h < height; ++h) {
- __m128i src = load_32bit_to_16bit(input_row);
- input_row += stride;
- src = _mm_mulhrs_epi16(src, rect_scale);
- __m128i x = _mm_mulhrs_epi16(src, scale);
- __m128i srcx5 = _mm_adds_epi16(src, src);
- srcx5 = _mm_adds_epi16(srcx5, srcx5);
- srcx5 = _mm_adds_epi16(srcx5, src);
- x = _mm_adds_epi16(x, srcx5);
- out[h] = _mm_mulhrs_epi16(x, mshift);
- }
-}
-
-static INLINE void iidentity64_col_8xn_ssse3(uint8_t *output, int stride,
- __m128i *buf, int shift,
- int height) {
- const __m128i mshift = _mm_set1_epi16(1 << (15 + shift));
- const int16_t scale_fractional = 4 * NewSqrt2 - (5 << NewSqrt2Bits);
- const __m128i scale = _mm_set1_epi16(scale_fractional << (15 - NewSqrt2Bits));
- const __m128i zero = _mm_setzero_si128();
- for (int h = 0; h < height; ++h) {
- __m128i x = _mm_mulhrs_epi16(buf[h], scale);
- __m128i srcx5 = _mm_adds_epi16(buf[h], buf[h]);
- srcx5 = _mm_adds_epi16(srcx5, srcx5);
- srcx5 = _mm_adds_epi16(srcx5, buf[h]);
- x = _mm_adds_epi16(x, srcx5);
- x = _mm_mulhrs_epi16(x, mshift);
- const __m128i pred = _mm_loadl_epi64((__m128i const *)(output));
- x = _mm_adds_epi16(x, _mm_unpacklo_epi8(pred, zero));
- __m128i u = _mm_packus_epi16(x, x);
- _mm_storel_epi64((__m128i *)(output), u);
- output += stride;
- }
-}
-
-static INLINE void identity_row_8xn_ssse3(__m128i *out, const int32_t *input,
- int stride, int shift, int height,
- int txw_idx, int rect_type) {
+ const __m128i scale = _mm_set1_epi16(NewSqrt2list[txw_idx]);
+ const __m128i rounding = _mm_set1_epi16((1 << (NewSqrt2Bits - 1)) +
+ (1 << (NewSqrt2Bits - shift - 1)));
+ const __m128i one = _mm_set1_epi16(1);
+ const __m128i scale_rounding = _mm_unpacklo_epi16(scale, rounding);
if (rect_type != 1 && rect_type != -1) {
- switch (txw_idx) {
- case 0:
- iidentity4_row_8xn_ssse3(out, input, stride, shift, height);
- break;
- case 1:
- iidentity8_row_8xn_ssse3(out, input, stride, shift, height);
- break;
- case 2:
- iidentity16_row_8xn_ssse3(out, input, stride, shift, height);
- break;
- case 3:
- iidentity32_row_8xn_ssse3(out, input, stride, shift, height);
- break;
- case 4:
- iidentity64_row_8xn_ssse3(out, input, stride, shift, height);
- break;
- default: break;
+ for (int i = 0; i < height; ++i) {
+ __m128i src = load_32bit_to_16bit(input_row);
+ input_row += stride;
+ __m128i lo = _mm_unpacklo_epi16(src, one);
+ __m128i hi = _mm_unpackhi_epi16(src, one);
+ lo = _mm_madd_epi16(lo, scale_rounding);
+ hi = _mm_madd_epi16(hi, scale_rounding);
+ lo = _mm_srai_epi32(lo, NewSqrt2Bits - shift);
+ hi = _mm_srai_epi32(hi, NewSqrt2Bits - shift);
+ out[i] = _mm_packs_epi32(lo, hi);
}
} else {
- switch (txw_idx) {
- case 0:
- iidentity4_row_rect_8xn_ssse3(out, input, stride, shift, height);
- break;
- case 1:
- iidentity8_row_rect_8xn_ssse3(out, input, stride, shift, height);
- break;
- case 2:
- iidentity16_row_rect_8xn_ssse3(out, input, stride, shift, height);
- break;
- case 3:
- iidentity32_row_rect_8xn_ssse3(out, input, stride, shift, height);
- break;
- case 4:
- iidentity64_row_rect_8xn_ssse3(out, input, stride, shift, height);
- break;
- default: break;
+ const __m128i rect_scale =
+ _mm_set1_epi16(NewInvSqrt2 << (15 - NewSqrt2Bits));
+ for (int i = 0; i < height; ++i) {
+ __m128i src = load_32bit_to_16bit(input_row);
+ src = _mm_mulhrs_epi16(src, rect_scale);
+ input_row += stride;
+ __m128i lo = _mm_unpacklo_epi16(src, one);
+ __m128i hi = _mm_unpackhi_epi16(src, one);
+ lo = _mm_madd_epi16(lo, scale_rounding);
+ hi = _mm_madd_epi16(hi, scale_rounding);
+ lo = _mm_srai_epi32(lo, NewSqrt2Bits - shift);
+ hi = _mm_srai_epi32(hi, NewSqrt2Bits - shift);
+ out[i] = _mm_packs_epi32(lo, hi);
}
}
}
-static INLINE void identity_col_8xn_ssse3(uint8_t *output, int stride,
- __m128i *buf, int shift, int height,
- int txh_idx) {
- switch (txh_idx) {
- case 0: iidentity4_col_8xn_ssse3(output, stride, buf, shift, height); break;
- case 1: iidentity8_col_8xn_ssse3(output, stride, buf, shift, height); break;
- case 2:
- iidentity16_col_8xn_ssse3(output, stride, buf, shift, height);
- break;
- case 3:
- iidentity32_col_8xn_ssse3(output, stride, buf, shift, height);
- break;
- case 4:
- iidentity64_col_8xn_ssse3(output, stride, buf, shift, height);
- break;
- default: break;
+static INLINE void iidentity_col_8xn_ssse3(uint8_t *output, int stride,
+ __m128i *buf, int shift, int height,
+ int txh_idx) {
+ const __m128i scale = _mm_set1_epi16(NewSqrt2list[txh_idx]);
+ const __m128i scale_rounding = _mm_set1_epi16(1 << (NewSqrt2Bits - 1));
+ const __m128i shift_rounding = _mm_set1_epi32(1 << (-shift - 1));
+ const __m128i one = _mm_set1_epi16(1);
+ const __m128i scale_coeff = _mm_unpacklo_epi16(scale, scale_rounding);
+ const __m128i zero = _mm_setzero_si128();
+ for (int h = 0; h < height; ++h) {
+ __m128i lo = _mm_unpacklo_epi16(buf[h], one);
+ __m128i hi = _mm_unpackhi_epi16(buf[h], one);
+ lo = _mm_madd_epi16(lo, scale_coeff);
+ hi = _mm_madd_epi16(hi, scale_coeff);
+ lo = _mm_srai_epi32(lo, NewSqrt2Bits);
+ hi = _mm_srai_epi32(hi, NewSqrt2Bits);
+ lo = _mm_add_epi32(lo, shift_rounding);
+ hi = _mm_add_epi32(hi, shift_rounding);
+ lo = _mm_srai_epi32(lo, -shift);
+ hi = _mm_srai_epi32(hi, -shift);
+ __m128i x = _mm_packs_epi32(lo, hi);
+
+ const __m128i pred = _mm_loadl_epi64((__m128i const *)(output));
+ x = _mm_adds_epi16(x, _mm_unpacklo_epi8(pred, zero));
+ __m128i u = _mm_packus_epi16(x, x);
+ _mm_storel_epi64((__m128i *)(output), u);
+ output += stride;
}
}
@@ -1913,10 +1646,10 @@
__m128i buf[32];
  for (int i = 0; i < input_stride >> 3; ++i) {
- identity_row_8xn_ssse3(buf, input + 8 * i, input_stride, shift[0], row_max,
- txw_idx, rect_type);
- identity_col_8xn_ssse3(output + 8 * i, stride, buf, shift[1], row_max,
- txh_idx);
+ iidentity_row_8xn_ssse3(buf, input + 8 * i, input_stride, shift[0], row_max,
+ txw_idx, rect_type);
+ iidentity_col_8xn_ssse3(output + 8 * i, stride, buf, shift[1], row_max,
+ txh_idx);
}
}
@@ -2083,8 +1816,8 @@
get_flip_cfg(tx_type, &ud_flip, &lr_flip);
for (int i = 0; i < AOMMIN(4, buf_size_w_div8); i++) {
__m128i buf0[64];
- identity_row_8xn_ssse3(buf0, input + 8 * i, input_stride, shift[0],
- txfm_size_row_notzero, txw_idx, rect_type);
+ iidentity_row_8xn_ssse3(buf0, input + 8 * i, input_stride, shift[0],
+ txfm_size_row_notzero, txw_idx, rect_type);
col_txfm(buf0, buf0, cos_bit_col);
__m128i mshift = _mm_set1_epi16(1 << (15 + shift[1]));
int k = ud_flip ? (txfm_size_row - 1) : 0;
@@ -2149,8 +1882,8 @@
}
for (int j = 0; j < buf_size_w_div8; ++j) {
- identity_col_8xn_ssse3(output + i * 8 * stride + j * 8, stride,
- buf1 + j * 8, shift[1], 8, txh_idx);
+ iidentity_col_8xn_ssse3(output + i * 8 * stride + j * 8, stride,
+ buf1 + j * 8, shift[1], 8, txh_idx);
}
}
}
diff --git a/av1/common/x86/av1_inv_txfm_ssse3.h b/av1/common/x86/av1_inv_txfm_ssse3.h
index 96dc0d6..ccdb006 100644
--- a/av1/common/x86/av1_inv_txfm_ssse3.h
+++ b/av1/common/x86/av1_inv_txfm_ssse3.h
@@ -53,6 +53,10 @@
out1 = _mm_subs_epi16(_in0, _in1); \
} while (0)
+#ifdef __cplusplus
+extern "C" {
+#endif
+
static INLINE void round_shift_16bit_ssse3(__m128i *in, int size, int bit) {
if (bit < 0) {
const __m128i scale = _mm_set1_epi16(1 << (15 + bit));
@@ -66,10 +70,6 @@
}
}
-#ifdef __cplusplus
-extern "C" {
-#endif
-
// 1D itx types
typedef enum ATTRIBUTE_PACKED {
IDCT_1D,
@@ -93,6 +93,10 @@
IIDENTITY_1D, IADST_1D, IIDENTITY_1D, IFLIPADST_1D,
};
+// Sqrt2, Sqrt2^2, Sqrt2^3, Sqrt2^4, Sqrt2^5
+static int32_t NewSqrt2list[TX_SIZES] = { 5793, 2 * 4096, 2 * 5793, 4 * 4096,
+ 4 * 5793 };
+
typedef void (*transform_1d_ssse3)(const __m128i *input, __m128i *output,
int8_t cos_bit);
diff --git a/test/av1_inv_txfm2d_test.cc b/test/av1_inv_txfm2d_test.cc
index 73feb16..c07ce09 100644
--- a/test/av1_inv_txfm2d_test.cc
+++ b/test/av1_inv_txfm2d_test.cc
@@ -325,11 +325,13 @@
#endif // HAVE_SSSE3
#if HAVE_AVX2
-#if defined(_MSC_VER) || defined(__AVX2__)
-#include "av1/common/x86/av1_inv_txfm_avx2.h"
+extern "C" void av1_lowbd_inv_txfm2d_add_avx2(const int32_t *input,
+ uint8_t *output, int stride,
+ TX_TYPE tx_type, TX_SIZE tx_size,
+ int eob);
+
INSTANTIATE_TEST_CASE_P(AVX2, AV1LbdInvTxfm2d,
::testing::Values(av1_lowbd_inv_txfm2d_add_avx2));
-#endif // (_MSC_VER) || (__AVX2__)
#endif // HAVE_AVX2
} // namespace