Optimize highbd inv_txfm modules
Added AVX2 variants for 8x8,8x16,16x8,8x32 and 32x8 txfm blk_sizes.
When tested across multiple test cases, an average reduction of 0.9%
in encoder time was observed for the speed = 1 preset.
Module-level gains improved by a factor of ~1.66
on average w.r.t. the SSE4_1 modules.
Change-Id: I061439c206aa4ae7cb39c2d77594e2dce4946c9b
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index d6b9fef..7049f16 100755
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -116,11 +116,11 @@
add_proto qw/void av1_highbd_inv_txfm_add_4x4/, "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
specialize qw/av1_highbd_inv_txfm_add_4x4 sse4_1/;
add_proto qw/void av1_highbd_inv_txfm_add_8x8/, "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
-specialize qw/av1_highbd_inv_txfm_add_8x8 sse4_1/;
+specialize qw/av1_highbd_inv_txfm_add_8x8 sse4_1 avx2/;
add_proto qw/void av1_highbd_inv_txfm_add_16x8/, "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
-specialize qw/av1_highbd_inv_txfm_add_16x8 sse4_1/;
+specialize qw/av1_highbd_inv_txfm_add_16x8 sse4_1 avx2/;
add_proto qw/void av1_highbd_inv_txfm_add_8x16/, "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
-specialize qw/av1_highbd_inv_txfm_add_8x16 sse4_1/;
+specialize qw/av1_highbd_inv_txfm_add_8x16 sse4_1 avx2/;
add_proto qw/void av1_highbd_inv_txfm_add_16x16/, "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
specialize qw/av1_highbd_inv_txfm_add_16x16 sse4_1 avx2/;
add_proto qw/void av1_highbd_inv_txfm_add_32x32/, "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
@@ -130,9 +130,9 @@
add_proto qw/void av1_highbd_inv_txfm_add_32x16/, "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
specialize qw/av1_highbd_inv_txfm_add_32x16 sse4_1 avx2/;
add_proto qw/void av1_highbd_inv_txfm_add_8x32/, "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
-specialize qw/av1_highbd_inv_txfm_add_8x32 sse4_1/;
+specialize qw/av1_highbd_inv_txfm_add_8x32 sse4_1 avx2/;
add_proto qw/void av1_highbd_inv_txfm_add_32x8/, "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
-specialize qw/av1_highbd_inv_txfm_add_32x8 sse4_1/;
+specialize qw/av1_highbd_inv_txfm_add_32x8 sse4_1 avx2/;
add_proto qw/void av1_highbd_inv_txfm_add_4x8/, "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
specialize qw/av1_highbd_inv_txfm_add_4x8 sse4_1/;
add_proto qw/void av1_highbd_inv_txfm_add_8x4/, "const tran_low_t *dqcoeff, uint8_t *dst, int stride, const TxfmParam *txfm_param";
diff --git a/av1/common/x86/highbd_inv_txfm_avx2.c b/av1/common/x86/highbd_inv_txfm_avx2.c
index abd892e..1d1ae06 100644
--- a/av1/common/x86/highbd_inv_txfm_avx2.c
+++ b/av1/common/x86/highbd_inv_txfm_avx2.c
@@ -73,7 +73,30 @@
_mm256_storeu_si256((__m256i *)(output + i * stride), u);
}
}
+static INLINE __m256i highbd_get_recon_8x8_avx2(const __m256i pred, __m256i res,
+ const int bd) {
+ __m256i x0 = pred;
+ x0 = _mm256_add_epi32(res, x0);
+ x0 = _mm256_packus_epi32(x0, x0);
+ x0 = _mm256_permute4x64_epi64(x0, 0xd8);
+ x0 = highbd_clamp_epi16_avx2(x0, bd);
+ return x0;
+}
+static INLINE void highbd_write_buffer_8xn_avx2(__m256i *in, uint16_t *output,
+ int stride, int flipud,
+ int height, const int bd) {
+ int j = flipud ? (height - 1) : 0;
+ __m128i temp;
+ const int step = flipud ? -1 : 1;
+ for (int i = 0; i < height; ++i, j += step) {
+ temp = _mm_loadu_si128((__m128i const *)(output + i * stride));
+ __m256i v = _mm256_cvtepi16_epi32(temp);
+ __m256i u = highbd_get_recon_8x8_avx2(v, in[j], bd);
+ __m128i u1 = _mm256_castsi256_si128(u);
+ _mm_storeu_si128((__m128i *)(output + i * stride), u1);
+ }
+}
static void neg_shift_avx2(const __m256i in0, const __m256i in1, __m256i *out0,
__m256i *out1, const __m256i *clamp_lo,
const __m256i *clamp_hi, int shift) {
@@ -2481,7 +2504,422 @@
}
}
}
+static void idct8x8_low1_avx2(__m256i *in, __m256i *out, int bit, int do_cols,
+ int bd, int out_shift) {
+ const int32_t *cospi = cospi_arr(bit);
+ const __m256i cospi32 = _mm256_set1_epi32(cospi[32]);
+ const __m256i rnding = _mm256_set1_epi32(1 << (bit - 1));
+ const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8));
+ __m256i x;
+ // stage 0
+ // stage 1
+ // stage 2
+ // stage 3
+ x = _mm256_mullo_epi32(in[0], cospi32);
+ x = _mm256_add_epi32(x, rnding);
+ x = _mm256_srai_epi32(x, bit);
+
+ // stage 4
+ // stage 5
+ if (!do_cols) {
+ const int log_range_out = AOMMAX(16, bd + 6);
+ const __m256i clamp_lo_out = _mm256_set1_epi32(AOMMAX(
+ -(1 << (log_range_out - 1)), -(1 << (log_range - 1 - out_shift))));
+ const __m256i clamp_hi_out = _mm256_set1_epi32(AOMMIN(
+ (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift))));
+
+ __m256i offset = _mm256_set1_epi32((1 << out_shift) >> 1);
+ x = _mm256_add_epi32(x, offset);
+ x = _mm256_sra_epi32(x, _mm_cvtsi32_si128(out_shift));
+ x = _mm256_max_epi32(x, clamp_lo_out);
+ x = _mm256_min_epi32(x, clamp_hi_out);
+ }
+
+ out[0] = x;
+ out[1] = x;
+ out[2] = x;
+ out[3] = x;
+ out[4] = x;
+ out[5] = x;
+ out[6] = x;
+ out[7] = x;
+}
+static void idct8x8_avx2(__m256i *in, __m256i *out, int bit, int do_cols,
+ int bd, int out_shift) {
+ const int32_t *cospi = cospi_arr(bit);
+ const __m256i cospi56 = _mm256_set1_epi32(cospi[56]);
+ const __m256i cospim8 = _mm256_set1_epi32(-cospi[8]);
+ const __m256i cospi24 = _mm256_set1_epi32(cospi[24]);
+ const __m256i cospim40 = _mm256_set1_epi32(-cospi[40]);
+ const __m256i cospi40 = _mm256_set1_epi32(cospi[40]);
+ const __m256i cospi8 = _mm256_set1_epi32(cospi[8]);
+ const __m256i cospi32 = _mm256_set1_epi32(cospi[32]);
+ const __m256i cospi48 = _mm256_set1_epi32(cospi[48]);
+ const __m256i cospim16 = _mm256_set1_epi32(-cospi[16]);
+ const __m256i cospi16 = _mm256_set1_epi32(cospi[16]);
+ const __m256i rnding = _mm256_set1_epi32(1 << (bit - 1));
+ const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8));
+ const __m256i clamp_lo = _mm256_set1_epi32(-(1 << (log_range - 1)));
+ const __m256i clamp_hi = _mm256_set1_epi32((1 << (log_range - 1)) - 1);
+ __m256i u0, u1, u2, u3, u4, u5, u6, u7;
+ __m256i v0, v1, v2, v3, v4, v5, v6, v7;
+ __m256i x, y;
+
+ // stage 0
+ // stage 1
+ // stage 2
+ u0 = in[0];
+ u1 = in[4];
+ u2 = in[2];
+ u3 = in[6];
+
+ x = _mm256_mullo_epi32(in[1], cospi56);
+ y = _mm256_mullo_epi32(in[7], cospim8);
+ u4 = _mm256_add_epi32(x, y);
+ u4 = _mm256_add_epi32(u4, rnding);
+ u4 = _mm256_srai_epi32(u4, bit);
+
+ x = _mm256_mullo_epi32(in[1], cospi8);
+ y = _mm256_mullo_epi32(in[7], cospi56);
+ u7 = _mm256_add_epi32(x, y);
+ u7 = _mm256_add_epi32(u7, rnding);
+ u7 = _mm256_srai_epi32(u7, bit);
+
+ x = _mm256_mullo_epi32(in[5], cospi24);
+ y = _mm256_mullo_epi32(in[3], cospim40);
+ u5 = _mm256_add_epi32(x, y);
+ u5 = _mm256_add_epi32(u5, rnding);
+ u5 = _mm256_srai_epi32(u5, bit);
+
+ x = _mm256_mullo_epi32(in[5], cospi40);
+ y = _mm256_mullo_epi32(in[3], cospi24);
+ u6 = _mm256_add_epi32(x, y);
+ u6 = _mm256_add_epi32(u6, rnding);
+ u6 = _mm256_srai_epi32(u6, bit);
+
+ // stage 3
+ x = _mm256_mullo_epi32(u0, cospi32);
+ y = _mm256_mullo_epi32(u1, cospi32);
+ v0 = _mm256_add_epi32(x, y);
+ v0 = _mm256_add_epi32(v0, rnding);
+ v0 = _mm256_srai_epi32(v0, bit);
+
+ v1 = _mm256_sub_epi32(x, y);
+ v1 = _mm256_add_epi32(v1, rnding);
+ v1 = _mm256_srai_epi32(v1, bit);
+
+ x = _mm256_mullo_epi32(u2, cospi48);
+ y = _mm256_mullo_epi32(u3, cospim16);
+ v2 = _mm256_add_epi32(x, y);
+ v2 = _mm256_add_epi32(v2, rnding);
+ v2 = _mm256_srai_epi32(v2, bit);
+
+ x = _mm256_mullo_epi32(u2, cospi16);
+ y = _mm256_mullo_epi32(u3, cospi48);
+ v3 = _mm256_add_epi32(x, y);
+ v3 = _mm256_add_epi32(v3, rnding);
+ v3 = _mm256_srai_epi32(v3, bit);
+
+ addsub_avx2(u4, u5, &v4, &v5, &clamp_lo, &clamp_hi);
+ addsub_avx2(u7, u6, &v7, &v6, &clamp_lo, &clamp_hi);
+
+ // stage 4
+ addsub_avx2(v0, v3, &u0, &u3, &clamp_lo, &clamp_hi);
+ addsub_avx2(v1, v2, &u1, &u2, &clamp_lo, &clamp_hi);
+ u4 = v4;
+ u7 = v7;
+
+ x = _mm256_mullo_epi32(v5, cospi32);
+ y = _mm256_mullo_epi32(v6, cospi32);
+ u6 = _mm256_add_epi32(y, x);
+ u6 = _mm256_add_epi32(u6, rnding);
+ u6 = _mm256_srai_epi32(u6, bit);
+
+ u5 = _mm256_sub_epi32(y, x);
+ u5 = _mm256_add_epi32(u5, rnding);
+ u5 = _mm256_srai_epi32(u5, bit);
+
+ // stage 5
+ if (do_cols) {
+ addsub_no_clamp_avx2(u0, u7, out + 0, out + 7);
+ addsub_no_clamp_avx2(u1, u6, out + 1, out + 6);
+ addsub_no_clamp_avx2(u2, u5, out + 2, out + 5);
+ addsub_no_clamp_avx2(u3, u4, out + 3, out + 4);
+ } else {
+ const int log_range_out = AOMMAX(16, bd + 6);
+ const __m256i clamp_lo_out = _mm256_set1_epi32(AOMMAX(
+ -(1 << (log_range_out - 1)), -(1 << (log_range - 1 - out_shift))));
+ const __m256i clamp_hi_out = _mm256_set1_epi32(AOMMIN(
+ (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift))));
+ addsub_shift_avx2(u0, u7, out + 0, out + 7, &clamp_lo_out, &clamp_hi_out,
+ out_shift);
+ addsub_shift_avx2(u1, u6, out + 1, out + 6, &clamp_lo_out, &clamp_hi_out,
+ out_shift);
+ addsub_shift_avx2(u2, u5, out + 2, out + 5, &clamp_lo_out, &clamp_hi_out,
+ out_shift);
+ addsub_shift_avx2(u3, u4, out + 3, out + 4, &clamp_lo_out, &clamp_hi_out,
+ out_shift);
+ }
+}
+static void iadst8x8_low1_avx2(__m256i *in, __m256i *out, int bit, int do_cols,
+ int bd, int out_shift) {
+ const int32_t *cospi = cospi_arr(bit);
+ const __m256i cospi4 = _mm256_set1_epi32(cospi[4]);
+ const __m256i cospi60 = _mm256_set1_epi32(cospi[60]);
+ const __m256i cospi16 = _mm256_set1_epi32(cospi[16]);
+ const __m256i cospi48 = _mm256_set1_epi32(cospi[48]);
+ const __m256i cospi32 = _mm256_set1_epi32(cospi[32]);
+ const __m256i rnding = _mm256_set1_epi32(1 << (bit - 1));
+ const __m256i kZero = _mm256_setzero_si256();
+ __m256i u[8], x;
+
+ // stage 0
+ // stage 1
+ // stage 2
+
+ x = _mm256_mullo_epi32(in[0], cospi60);
+ u[0] = _mm256_add_epi32(x, rnding);
+ u[0] = _mm256_srai_epi32(u[0], bit);
+
+ x = _mm256_mullo_epi32(in[0], cospi4);
+ u[1] = _mm256_sub_epi32(kZero, x);
+ u[1] = _mm256_add_epi32(u[1], rnding);
+ u[1] = _mm256_srai_epi32(u[1], bit);
+
+ // stage 3
+ // stage 4
+ __m256i temp1, temp2;
+ temp1 = _mm256_mullo_epi32(u[0], cospi16);
+ x = _mm256_mullo_epi32(u[1], cospi48);
+ temp1 = _mm256_add_epi32(temp1, x);
+ temp1 = _mm256_add_epi32(temp1, rnding);
+ temp1 = _mm256_srai_epi32(temp1, bit);
+ u[4] = temp1;
+
+ temp2 = _mm256_mullo_epi32(u[0], cospi48);
+ x = _mm256_mullo_epi32(u[1], cospi16);
+ u[5] = _mm256_sub_epi32(temp2, x);
+ u[5] = _mm256_add_epi32(u[5], rnding);
+ u[5] = _mm256_srai_epi32(u[5], bit);
+
+ // stage 5
+ // stage 6
+ temp1 = _mm256_mullo_epi32(u[0], cospi32);
+ x = _mm256_mullo_epi32(u[1], cospi32);
+ u[2] = _mm256_add_epi32(temp1, x);
+ u[2] = _mm256_add_epi32(u[2], rnding);
+ u[2] = _mm256_srai_epi32(u[2], bit);
+
+ u[3] = _mm256_sub_epi32(temp1, x);
+ u[3] = _mm256_add_epi32(u[3], rnding);
+ u[3] = _mm256_srai_epi32(u[3], bit);
+
+ temp1 = _mm256_mullo_epi32(u[4], cospi32);
+ x = _mm256_mullo_epi32(u[5], cospi32);
+ u[6] = _mm256_add_epi32(temp1, x);
+ u[6] = _mm256_add_epi32(u[6], rnding);
+ u[6] = _mm256_srai_epi32(u[6], bit);
+
+ u[7] = _mm256_sub_epi32(temp1, x);
+ u[7] = _mm256_add_epi32(u[7], rnding);
+ u[7] = _mm256_srai_epi32(u[7], bit);
+
+ // stage 7
+ if (do_cols) {
+ out[0] = u[0];
+ out[1] = _mm256_sub_epi32(kZero, u[4]);
+ out[2] = u[6];
+ out[3] = _mm256_sub_epi32(kZero, u[2]);
+ out[4] = u[3];
+ out[5] = _mm256_sub_epi32(kZero, u[7]);
+ out[6] = u[5];
+ out[7] = _mm256_sub_epi32(kZero, u[1]);
+ } else {
+ const int log_range_out = AOMMAX(16, bd + 6);
+ const __m256i clamp_lo_out = _mm256_set1_epi32(-(1 << (log_range_out - 1)));
+ const __m256i clamp_hi_out =
+ _mm256_set1_epi32((1 << (log_range_out - 1)) - 1);
+
+ neg_shift_avx2(u[0], u[4], out + 0, out + 1, &clamp_lo_out, &clamp_hi_out,
+ out_shift);
+ neg_shift_avx2(u[6], u[2], out + 2, out + 3, &clamp_lo_out, &clamp_hi_out,
+ out_shift);
+ neg_shift_avx2(u[3], u[7], out + 4, out + 5, &clamp_lo_out, &clamp_hi_out,
+ out_shift);
+ neg_shift_avx2(u[5], u[1], out + 6, out + 7, &clamp_lo_out, &clamp_hi_out,
+ out_shift);
+ }
+}
+
+static void iadst8x8_avx2(__m256i *in, __m256i *out, int bit, int do_cols,
+ int bd, int out_shift) {
+ const int32_t *cospi = cospi_arr(bit);
+ const __m256i cospi4 = _mm256_set1_epi32(cospi[4]);
+ const __m256i cospi60 = _mm256_set1_epi32(cospi[60]);
+ const __m256i cospi20 = _mm256_set1_epi32(cospi[20]);
+ const __m256i cospi44 = _mm256_set1_epi32(cospi[44]);
+ const __m256i cospi36 = _mm256_set1_epi32(cospi[36]);
+ const __m256i cospi28 = _mm256_set1_epi32(cospi[28]);
+ const __m256i cospi52 = _mm256_set1_epi32(cospi[52]);
+ const __m256i cospi12 = _mm256_set1_epi32(cospi[12]);
+ const __m256i cospi16 = _mm256_set1_epi32(cospi[16]);
+ const __m256i cospi48 = _mm256_set1_epi32(cospi[48]);
+ const __m256i cospim48 = _mm256_set1_epi32(-cospi[48]);
+ const __m256i cospi32 = _mm256_set1_epi32(cospi[32]);
+ const __m256i rnding = _mm256_set1_epi32(1 << (bit - 1));
+ const __m256i kZero = _mm256_setzero_si256();
+ const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8));
+ const __m256i clamp_lo = _mm256_set1_epi32(-(1 << (log_range - 1)));
+ const __m256i clamp_hi = _mm256_set1_epi32((1 << (log_range - 1)) - 1);
+ __m256i u[8], v[8], x;
+
+ // stage 0
+ // stage 1
+ // stage 2
+
+ u[0] = _mm256_mullo_epi32(in[7], cospi4);
+ x = _mm256_mullo_epi32(in[0], cospi60);
+ u[0] = _mm256_add_epi32(u[0], x);
+ u[0] = _mm256_add_epi32(u[0], rnding);
+ u[0] = _mm256_srai_epi32(u[0], bit);
+
+ u[1] = _mm256_mullo_epi32(in[7], cospi60);
+ x = _mm256_mullo_epi32(in[0], cospi4);
+ u[1] = _mm256_sub_epi32(u[1], x);
+ u[1] = _mm256_add_epi32(u[1], rnding);
+ u[1] = _mm256_srai_epi32(u[1], bit);
+
+ u[2] = _mm256_mullo_epi32(in[5], cospi20);
+ x = _mm256_mullo_epi32(in[2], cospi44);
+ u[2] = _mm256_add_epi32(u[2], x);
+ u[2] = _mm256_add_epi32(u[2], rnding);
+ u[2] = _mm256_srai_epi32(u[2], bit);
+
+ u[3] = _mm256_mullo_epi32(in[5], cospi44);
+ x = _mm256_mullo_epi32(in[2], cospi20);
+ u[3] = _mm256_sub_epi32(u[3], x);
+ u[3] = _mm256_add_epi32(u[3], rnding);
+ u[3] = _mm256_srai_epi32(u[3], bit);
+
+ u[4] = _mm256_mullo_epi32(in[3], cospi36);
+ x = _mm256_mullo_epi32(in[4], cospi28);
+ u[4] = _mm256_add_epi32(u[4], x);
+ u[4] = _mm256_add_epi32(u[4], rnding);
+ u[4] = _mm256_srai_epi32(u[4], bit);
+
+ u[5] = _mm256_mullo_epi32(in[3], cospi28);
+ x = _mm256_mullo_epi32(in[4], cospi36);
+ u[5] = _mm256_sub_epi32(u[5], x);
+ u[5] = _mm256_add_epi32(u[5], rnding);
+ u[5] = _mm256_srai_epi32(u[5], bit);
+
+ u[6] = _mm256_mullo_epi32(in[1], cospi52);
+ x = _mm256_mullo_epi32(in[6], cospi12);
+ u[6] = _mm256_add_epi32(u[6], x);
+ u[6] = _mm256_add_epi32(u[6], rnding);
+ u[6] = _mm256_srai_epi32(u[6], bit);
+
+ u[7] = _mm256_mullo_epi32(in[1], cospi12);
+ x = _mm256_mullo_epi32(in[6], cospi52);
+ u[7] = _mm256_sub_epi32(u[7], x);
+ u[7] = _mm256_add_epi32(u[7], rnding);
+ u[7] = _mm256_srai_epi32(u[7], bit);
+
+ // stage 3
+ addsub_avx2(u[0], u[4], &v[0], &v[4], &clamp_lo, &clamp_hi);
+ addsub_avx2(u[1], u[5], &v[1], &v[5], &clamp_lo, &clamp_hi);
+ addsub_avx2(u[2], u[6], &v[2], &v[6], &clamp_lo, &clamp_hi);
+ addsub_avx2(u[3], u[7], &v[3], &v[7], &clamp_lo, &clamp_hi);
+
+ // stage 4
+ u[0] = v[0];
+ u[1] = v[1];
+ u[2] = v[2];
+ u[3] = v[3];
+
+ u[4] = _mm256_mullo_epi32(v[4], cospi16);
+ x = _mm256_mullo_epi32(v[5], cospi48);
+ u[4] = _mm256_add_epi32(u[4], x);
+ u[4] = _mm256_add_epi32(u[4], rnding);
+ u[4] = _mm256_srai_epi32(u[4], bit);
+
+ u[5] = _mm256_mullo_epi32(v[4], cospi48);
+ x = _mm256_mullo_epi32(v[5], cospi16);
+ u[5] = _mm256_sub_epi32(u[5], x);
+ u[5] = _mm256_add_epi32(u[5], rnding);
+ u[5] = _mm256_srai_epi32(u[5], bit);
+
+ u[6] = _mm256_mullo_epi32(v[6], cospim48);
+ x = _mm256_mullo_epi32(v[7], cospi16);
+ u[6] = _mm256_add_epi32(u[6], x);
+ u[6] = _mm256_add_epi32(u[6], rnding);
+ u[6] = _mm256_srai_epi32(u[6], bit);
+
+ u[7] = _mm256_mullo_epi32(v[6], cospi16);
+ x = _mm256_mullo_epi32(v[7], cospim48);
+ u[7] = _mm256_sub_epi32(u[7], x);
+ u[7] = _mm256_add_epi32(u[7], rnding);
+ u[7] = _mm256_srai_epi32(u[7], bit);
+
+ // stage 5
+ addsub_avx2(u[0], u[2], &v[0], &v[2], &clamp_lo, &clamp_hi);
+ addsub_avx2(u[1], u[3], &v[1], &v[3], &clamp_lo, &clamp_hi);
+ addsub_avx2(u[4], u[6], &v[4], &v[6], &clamp_lo, &clamp_hi);
+ addsub_avx2(u[5], u[7], &v[5], &v[7], &clamp_lo, &clamp_hi);
+
+ // stage 6
+ u[0] = v[0];
+ u[1] = v[1];
+ u[4] = v[4];
+ u[5] = v[5];
+
+ v[0] = _mm256_mullo_epi32(v[2], cospi32);
+ x = _mm256_mullo_epi32(v[3], cospi32);
+ u[2] = _mm256_add_epi32(v[0], x);
+ u[2] = _mm256_add_epi32(u[2], rnding);
+ u[2] = _mm256_srai_epi32(u[2], bit);
+
+ u[3] = _mm256_sub_epi32(v[0], x);
+ u[3] = _mm256_add_epi32(u[3], rnding);
+ u[3] = _mm256_srai_epi32(u[3], bit);
+
+ v[0] = _mm256_mullo_epi32(v[6], cospi32);
+ x = _mm256_mullo_epi32(v[7], cospi32);
+ u[6] = _mm256_add_epi32(v[0], x);
+ u[6] = _mm256_add_epi32(u[6], rnding);
+ u[6] = _mm256_srai_epi32(u[6], bit);
+
+ u[7] = _mm256_sub_epi32(v[0], x);
+ u[7] = _mm256_add_epi32(u[7], rnding);
+ u[7] = _mm256_srai_epi32(u[7], bit);
+
+ // stage 7
+ if (do_cols) {
+ out[0] = u[0];
+ out[1] = _mm256_sub_epi32(kZero, u[4]);
+ out[2] = u[6];
+ out[3] = _mm256_sub_epi32(kZero, u[2]);
+ out[4] = u[3];
+ out[5] = _mm256_sub_epi32(kZero, u[7]);
+ out[6] = u[5];
+ out[7] = _mm256_sub_epi32(kZero, u[1]);
+ } else {
+ const int log_range_out = AOMMAX(16, bd + 6);
+ const __m256i clamp_lo_out = _mm256_set1_epi32(-(1 << (log_range_out - 1)));
+ const __m256i clamp_hi_out =
+ _mm256_set1_epi32((1 << (log_range_out - 1)) - 1);
+
+ neg_shift_avx2(u[0], u[4], out + 0, out + 1, &clamp_lo_out, &clamp_hi_out,
+ out_shift);
+ neg_shift_avx2(u[6], u[2], out + 2, out + 3, &clamp_lo_out, &clamp_hi_out,
+ out_shift);
+ neg_shift_avx2(u[3], u[7], out + 4, out + 5, &clamp_lo_out, &clamp_hi_out,
+ out_shift);
+ neg_shift_avx2(u[5], u[1], out + 6, out + 7, &clamp_lo_out, &clamp_hi_out,
+ out_shift);
+ }
+}
typedef void (*transform_1d_avx2)(__m256i *in, __m256i *out, int bit,
int do_cols, int bd, int out_shift);
@@ -2493,8 +2931,8 @@
{ NULL, NULL, NULL, NULL },
},
{
- { NULL, NULL, NULL, NULL },
- { NULL, NULL, NULL, NULL },
+ { idct8x8_low1_avx2, idct8x8_avx2, NULL, NULL },
+ { iadst8x8_low1_avx2, iadst8x8_avx2, NULL, NULL },
{ NULL, NULL, NULL, NULL },
},
{
@@ -2580,12 +3018,15 @@
}
// write to buffer
- {
+ if (txfm_size_col >= 16) {
for (int i = 0; i < (txfm_size_col >> 4); i++) {
highbd_write_buffer_16xn_avx2(buf1 + i * txfm_size_row * 2,
output + 16 * i, stride, ud_flip,
txfm_size_row, bd);
}
+ } else if (txfm_size_col == 8) {
+ highbd_write_buffer_8xn_avx2(buf1, output, stride, ud_flip, txfm_size_row,
+ bd);
}
}
@@ -2698,7 +3139,123 @@
default: assert(0);
}
}
+void av1_highbd_inv_txfm_add_8x8_avx2(const tran_low_t *input, uint8_t *dest,
+ int stride, const TxfmParam *txfm_param) {
+ int bd = txfm_param->bd;
+ const TX_TYPE tx_type = txfm_param->tx_type;
+ const int32_t *src = cast_to_int32(input);
+ switch (tx_type) {
+ // Assembly version doesn't support some transform types, so use C version
+ // for those.
+ case V_DCT:
+ case H_DCT:
+ case V_ADST:
+ case H_ADST:
+ case V_FLIPADST:
+ case H_FLIPADST:
+ case IDTX:
+ av1_inv_txfm2d_add_8x8_c(src, CONVERT_TO_SHORTPTR(dest), stride, tx_type,
+ bd);
+ break;
+ default:
+ av1_highbd_inv_txfm2d_add_universe_avx2(input, dest, stride, tx_type,
+ txfm_param->tx_size,
+ txfm_param->eob, bd);
+ break;
+ }
+}
+void av1_highbd_inv_txfm_add_8x32_avx2(const tran_low_t *input, uint8_t *dest,
+ int stride,
+ const TxfmParam *txfm_param) {
+ int bd = txfm_param->bd;
+ const TX_TYPE tx_type = txfm_param->tx_type;
+ const int32_t *src = cast_to_int32(input);
+ switch (tx_type) {
+ case DCT_DCT:
+ av1_highbd_inv_txfm2d_add_universe_avx2(input, dest, stride, tx_type,
+ txfm_param->tx_size,
+ txfm_param->eob, bd);
+ break;
+ case IDTX:
+ av1_inv_txfm2d_add_8x32_c(src, CONVERT_TO_SHORTPTR(dest), stride,
+ txfm_param->tx_type, txfm_param->bd);
+ break;
+ default: assert(0);
+ }
+}
+void av1_highbd_inv_txfm_add_32x8_avx2(const tran_low_t *input, uint8_t *dest,
+ int stride,
+ const TxfmParam *txfm_param) {
+ int bd = txfm_param->bd;
+ const TX_TYPE tx_type = txfm_param->tx_type;
+ const int32_t *src = cast_to_int32(input);
+ switch (tx_type) {
+ case DCT_DCT:
+ av1_highbd_inv_txfm2d_add_universe_avx2(input, dest, stride, tx_type,
+ txfm_param->tx_size,
+ txfm_param->eob, bd);
+ break;
+ case IDTX:
+ av1_inv_txfm2d_add_32x8_c(src, CONVERT_TO_SHORTPTR(dest), stride,
+ txfm_param->tx_type, txfm_param->bd);
+ break;
+ default: assert(0);
+ }
+}
+void av1_highbd_inv_txfm_add_16x8_avx2(const tran_low_t *input, uint8_t *dest,
+ int stride,
+ const TxfmParam *txfm_param) {
+ int bd = txfm_param->bd;
+ const TX_TYPE tx_type = txfm_param->tx_type;
+ const int32_t *src = cast_to_int32(input);
+ switch (tx_type) {
+ // Assembly version doesn't support some transform types, so use C version
+ // for those.
+ case V_DCT:
+ case H_DCT:
+ case V_ADST:
+ case H_ADST:
+ case V_FLIPADST:
+ case H_FLIPADST:
+ case IDTX:
+ av1_inv_txfm2d_add_16x8_c(src, CONVERT_TO_SHORTPTR(dest), stride,
+ txfm_param->tx_type, txfm_param->bd);
+ break;
+ default:
+ av1_highbd_inv_txfm2d_add_universe_avx2(input, dest, stride, tx_type,
+ txfm_param->tx_size,
+ txfm_param->eob, bd);
+ break;
+ }
+}
+
+void av1_highbd_inv_txfm_add_8x16_avx2(const tran_low_t *input, uint8_t *dest,
+ int stride,
+ const TxfmParam *txfm_param) {
+ int bd = txfm_param->bd;
+ const TX_TYPE tx_type = txfm_param->tx_type;
+ const int32_t *src = cast_to_int32(input);
+ switch (tx_type) {
+ // Assembly version doesn't support some transform types, so use C version
+ // for those.
+ case V_DCT:
+ case H_DCT:
+ case V_ADST:
+ case H_ADST:
+ case V_FLIPADST:
+ case H_FLIPADST:
+ case IDTX:
+ av1_inv_txfm2d_add_8x16_c(src, CONVERT_TO_SHORTPTR(dest), stride,
+ txfm_param->tx_type, txfm_param->bd);
+ break;
+ default:
+ av1_highbd_inv_txfm2d_add_universe_avx2(input, dest, stride, tx_type,
+ txfm_param->tx_size,
+ txfm_param->eob, bd);
+ break;
+ }
+}
void av1_highbd_inv_txfm_add_avx2(const tran_low_t *input, uint8_t *dest,
int stride, const TxfmParam *txfm_param) {
assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
@@ -2711,7 +3268,7 @@
av1_highbd_inv_txfm_add_16x16_avx2(input, dest, stride, txfm_param);
break;
case TX_8X8:
- av1_highbd_inv_txfm_add_8x8_sse4_1(input, dest, stride, txfm_param);
+ av1_highbd_inv_txfm_add_8x8_avx2(input, dest, stride, txfm_param);
break;
case TX_4X8:
av1_highbd_inv_txfm_add_4x8_sse4_1(input, dest, stride, txfm_param);
@@ -2720,10 +3277,10 @@
av1_highbd_inv_txfm_add_8x4_sse4_1(input, dest, stride, txfm_param);
break;
case TX_8X16:
- av1_highbd_inv_txfm_add_8x16_sse4_1(input, dest, stride, txfm_param);
+ av1_highbd_inv_txfm_add_8x16_avx2(input, dest, stride, txfm_param);
break;
case TX_16X8:
- av1_highbd_inv_txfm_add_16x8_sse4_1(input, dest, stride, txfm_param);
+ av1_highbd_inv_txfm_add_16x8_avx2(input, dest, stride, txfm_param);
break;
case TX_16X32:
av1_highbd_inv_txfm_add_16x32_avx2(input, dest, stride, txfm_param);
@@ -2741,10 +3298,10 @@
av1_highbd_inv_txfm_add_4x16_sse4_1(input, dest, stride, txfm_param);
break;
case TX_8X32:
- av1_highbd_inv_txfm_add_8x32_sse4_1(input, dest, stride, txfm_param);
+ av1_highbd_inv_txfm_add_8x32_avx2(input, dest, stride, txfm_param);
break;
case TX_32X8:
- av1_highbd_inv_txfm_add_32x8_sse4_1(input, dest, stride, txfm_param);
+ av1_highbd_inv_txfm_add_32x8_avx2(input, dest, stride, txfm_param);
break;
case TX_64X64:
case TX_32X64: