Fix method 4 (CONFIG_IST_REDUCE_METHOD4) implementation for AVM-research-v8.0.0
diff --git a/av1/common/idct.c b/av1/common/idct.c
index f45750e..5af71fc 100644
--- a/av1/common/idct.c
+++ b/av1/common/idct.c
@@ -452,10 +452,24 @@
txfm_param->tx_size = tx_size;
// EOB needs to adjusted after inverse IST
if (txfm_param->sec_tx_type) {
+#if CONFIG_IST_REDUCE_METHOD4
+ const int st_size_class = (width == 8 && height == 8) ? 1 : (width >= 8 && height >= 8) ? 2 : 0;
+ txfm_param->eob = eob;
+ if (st_size_class == 0) {
+ txfm_param->eob = IST_4x4_HEIGHT;
+ }
+ else if (st_size_class == 1) {
+ txfm_param->eob = IST_8x8_HEIGHT_RED;
+ }
+ else {
+ txfm_param->eob = IST_8x8_HEIGHT;
+ }
+#else
// txfm_param->eob = av1_get_max_eob(tx_size);
const int sb_size =
(tx_size_wide[tx_size] >= 8 && tx_size_high[tx_size] >= 8) ? 8 : 4;
txfm_param->eob = (sb_size == 4) ? IST_4x4_WIDTH : IST_8x8_WIDTH;
+#endif
} else {
txfm_param->eob = eob;
}
diff --git a/av1/common/x86/highbd_inv_txfm_avx2.c b/av1/common/x86/highbd_inv_txfm_avx2.c
index 1b23c9a..9aafa7e 100644
--- a/av1/common/x86/highbd_inv_txfm_avx2.c
+++ b/av1/common/x86/highbd_inv_txfm_avx2.c
@@ -4487,20 +4487,45 @@
void inv_stxfm_avx2(tran_low_t *src, tran_low_t *dst,
const PREDICTION_MODE mode, const uint8_t stx_idx,
const int size) {
+#if CONFIG_IST_REDUCE_METHOD4
+ const int16_t *kernel = (size == 0) ? ist_4x4_kernel[mode][stx_idx][0]
+ : ist_8x8_kernel[mode][stx_idx][0];
+ const int dimension = (size == 0) ? 16 : 64;
+#else
const int16_t *kernel = (size == 4) ? ist_4x4_kernel[mode][stx_idx][0]
: ist_8x8_kernel[mode][stx_idx][0];
+#endif
assert(stx_idx < 4);
const int rnd_factor = 1 << (7 - 1);
const __m256i round = _mm256_set1_epi32(rnd_factor);
int reduced_width, reduced_height;
+#if CONFIG_IST_REDUCE_METHOD4
+ if (size == 0) {
+ reduced_height = IST_4x4_HEIGHT;
+ reduced_width = IST_4x4_WIDTH;
+ }
+ else if (size == 1) {
+ reduced_height = IST_8x8_HEIGHT_RED;
+ reduced_width = IST_8x8_WIDTH;
+ }
+ else {
+ reduced_height = IST_8x8_HEIGHT;
+ reduced_width = IST_8x8_WIDTH;
+ }
+#else
if (size == 4) {
reduced_height = IST_4x4_HEIGHT;
reduced_width = IST_4x4_WIDTH;
} else {
+#if CONFIG_IST_REDUCE_METHOD3
+ reduced_height = IST_8x8_HEIGHT_RED;
+#else
reduced_height = IST_8x8_HEIGHT;
+#endif
reduced_width = IST_8x8_WIDTH;
}
+#endif
for (int j = 0; j < reduced_height; j++) {
const int16_t *kernel_tmp = kernel;
int *srcPtr = src;
@@ -4515,7 +4540,11 @@
tmp = _mm256_add_epi32(tmp, sum);
_mm256_storeu_si256(tmpBlock, tmp);
}
+#if CONFIG_IST_REDUCE_METHOD4
+ kernel += dimension;
+#else
kernel += (size * size);
+#endif
}
int *out = dst;
__m256i *tmpBlock = (__m256i *)out;
diff --git a/av1/encoder/x86/highbd_fwd_txfm_avx2.c b/av1/encoder/x86/highbd_fwd_txfm_avx2.c
index 97bfbb2..6c94537 100644
--- a/av1/encoder/x86/highbd_fwd_txfm_avx2.c
+++ b/av1/encoder/x86/highbd_fwd_txfm_avx2.c
@@ -3705,14 +3705,25 @@
void fwd_stxfm_avx2(tran_low_t *src, tran_low_t *dst,
const PREDICTION_MODE mode, const uint8_t stx_idx,
const int size) {
+#if CONFIG_IST_REDUCE_METHOD4
+ const int16_t *kernel = (size == 0) ? ist_4x4_kernel[mode][stx_idx][0]
+ : ist_8x8_kernel[mode][stx_idx][0];
+ const int dimension = (size == 0) ? 16 : 64;
+ const int ist_height = (size == 0) ? IST_4x4_HEIGHT : (size == 1) ? IST_8x8_HEIGHT_RED : IST_8x8_HEIGHT ;
+#else
const int16_t *kernel = (size == 4) ? ist_4x4_kernel[mode][stx_idx][0]
: ist_8x8_kernel[mode][stx_idx][0];
+#endif
int *out = dst;
assert(stx_idx < 4);
int shift = 7;
int offset = 1 << (shift - 1);
int *srcPtr = src;
+#if CONFIG_IST_REDUCE_METHOD4
+ if (size == 0) {
+#else
if (size == 4) {
+#endif
assert(IST_4x4_WIDTH == 16);
const __m256i offset_vec = _mm256_set1_epi32(offset);
__m256i kernel_t[16];
@@ -3743,7 +3754,11 @@
const __m256i src_6 = _mm256_loadu_si256((__m256i *)(srcPtr + 48));
// s56 s57 s58 s59 s60 s61 s62 s63
const __m256i src_7 = _mm256_loadu_si256((__m256i *)(srcPtr + 56));
+#if CONFIG_IST_REDUCE_METHOD4
+ for (int j = 0; j < ist_height; j++) {
+#else
for (int j = 0; j < IST_8x8_HEIGHT; j++) {
+#endif
const int16_t *kernel_tmp = kernel;
// k0 k1 k2 k3 k4 k5 k6 k7
const __m256i ker_0 =
@@ -3795,7 +3810,11 @@
_mm_add_epi32(sum_32x2, _mm_srli_si128(sum_32x2, 4));
int coef = _mm_cvtsi128_si32(sum_32x1);
*out++ = (coef + offset) >> shift;
+#if CONFIG_IST_REDUCE_METHOD4
+ kernel += dimension;
+#else
kernel += (size * size);
+#endif
}
}
}