[CFL] Constant Propagate Prediction Height These changes apply to both low bit depth (LBD) and high bit depth (HBD) predictions. Remove the duplicate tx_size parameter required for CfL prediction, which was only used to look up the height. The height is now passed directly to size-specific wrapper functions, where it is a compile-time constant. Move the NULL prediction functions into cfl.c to unclutter cfl.h. I wanted to reuse the CFL_TYPE macros used by subsampling, but that did not work out; I left the changes made to those macros in. Small speedups for all block sizes, except 32x32. For 32x32, the compiler unrolls all the loops, which appears to be slightly slower than the previous code. SSSE3/CFLPredictTest 4x4: C time = 408 us, SIMD time = 126 us (~3.2x) 8x8: C time = 1367 us, SIMD time = 185 us (~7.4x) 16x16: C time = 5335 us, SIMD time = 648 us (~8.2x) 32x32: C time = 25477 us, SIMD time = 2301 us (~11x) AVX2/CFLPredictTest 4x4: C time = 408 us, SIMD time = 122 us (~3.3x) 8x8: C time = 1411 us, SIMD time = 205 us (~6.9x) 16x16: C time = 5542 us, SIMD time = 598 us (~9.3x) 32x32: C time = 25927 us, SIMD time = 1803 us (~14x) SSSE3/CFLPredictHBDTest 4x4: C time = 561 us, SIMD time = 152 us (~3.7x) 8x8: C time = 2067 us, SIMD time = 262 us (~7.9x) 16x16: C time = 10399 us, SIMD time = 863 us (~12x) 32x32: C time = 41489 us, SIMD time = 3403 us (~12x) AVX2/CFLPredictHBDTest 4x4: C time = 576 us, SIMD time = 156 us (~3.7x) 8x8: C time = 2077 us, SIMD time = 268 us (~7.8x) 16x16: C time = 10324 us, SIMD time = 584 us (~18x) 32x32: C time = 41976 us, SIMD time = 2119 us (~20x) Change-Id: Ib4a38e7d2cd37593d94ea58613d1c32232c24421
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl index fe91887..d0dffdd 100755 --- a/av1/common/av1_rtcd_defs.pl +++ b/av1/common/av1_rtcd_defs.pl
@@ -46,12 +46,10 @@ typedef void (*cfl_subtract_average_fn)(int16_t *pred_buf_q3); typedef void (*cfl_predict_lbd_fn)(const int16_t *pred_buf_q3, uint8_t *dst, - int dst_stride, TX_SIZE tx_size, - int alpha_q3); + int dst_stride, int alpha_q3); typedef void (*cfl_predict_hbd_fn)(const int16_t *pred_buf_q3, uint16_t *dst, - int dst_stride, TX_SIZE tx_size, - int alpha_q3, int bd); + int dst_stride, int alpha_q3, int bd); EOF } forward_decls qw/av1_common_forward_decls/;
diff --git a/av1/common/cfl.c b/av1/common/cfl.c index a0b4e2b..6ca9705 100644 --- a/av1/common/cfl.c +++ b/av1/common/cfl.c
@@ -155,11 +155,9 @@ return (alpha_sign == CFL_SIGN_POS) ? abs_alpha_q3 + 1 : -abs_alpha_q3 - 1; } -static void cfl_build_prediction_lbd(const int16_t *pred_buf_q3, uint8_t *dst, - int dst_stride, TX_SIZE tx_size, - int alpha_q3) { - const int height = tx_size_high[tx_size]; - const int width = tx_size_wide[tx_size]; +static INLINE void cfl_predict_lbd_c(const int16_t *pred_buf_q3, uint8_t *dst, + int dst_stride, int alpha_q3, int width, + int height) { for (int j = 0; j < height; j++) { for (int i = 0; i < width; i++) { dst[i] = @@ -170,11 +168,21 @@ } } -static void cfl_build_prediction_hbd(const int16_t *pred_buf_q3, uint16_t *dst, - int dst_stride, TX_SIZE tx_size, - int alpha_q3, int bit_depth) { - const int height = tx_size_high[tx_size]; - const int width = tx_size_wide[tx_size]; +// Null function used for invalid tx_sizes +void cfl_predict_lbd_null(const int16_t *pred_buf_q3, uint8_t *dst, + int dst_stride, int alpha_q3) { + (void)pred_buf_q3; + (void)dst; + (void)dst_stride; + (void)alpha_q3; + assert(0); +} + +CFL_PREDICT_FN(c, lbd) + +void cfl_predict_hbd_c(const int16_t *pred_buf_q3, uint16_t *dst, + int dst_stride, int alpha_q3, int bit_depth, int width, + int height) { for (int j = 0; j < height; j++) { for (int i = 0; i < width; i++) { dst[i] = clip_pixel_highbd( @@ -185,6 +193,19 @@ } } +// Null function used for invalid tx_sizes +void cfl_predict_hbd_null(const int16_t *pred_buf_q3, uint16_t *dst, + int dst_stride, int alpha_q3, int bd) { + (void)pred_buf_q3; + (void)dst; + (void)dst_stride; + (void)alpha_q3; + (void)bd; + assert(0); +} + +CFL_PREDICT_FN(c, hbd) + static void cfl_compute_parameters(MACROBLOCKD *const xd, TX_SIZE tx_size) { CFL_CTX *const cfl = &xd->cfl; // Do not call cfl_compute_parameters multiple time on the same values. 
@@ -195,16 +216,6 @@ cfl->are_parameters_computed = 1; } -cfl_predict_lbd_fn get_predict_lbd_fn_c(TX_SIZE tx_size) { - (void)tx_size; - return cfl_build_prediction_lbd; -} - -cfl_predict_hbd_fn get_predict_hbd_fn_c(TX_SIZE tx_size) { - (void)tx_size; - return cfl_build_prediction_hbd; -} - void cfl_predict_block(MACROBLOCKD *const xd, uint8_t *dst, int dst_stride, TX_SIZE tx_size, int plane) { CFL_CTX *const cfl = &xd->cfl; @@ -219,12 +230,11 @@ CFL_BUF_SQUARE); if (get_bitdepth_data_path_index(xd)) { uint16_t *dst_16 = CONVERT_TO_SHORTPTR(dst); - get_predict_hbd_fn(tx_size)(cfl->pred_buf_q3, dst_16, dst_stride, tx_size, - alpha_q3, xd->bd); + get_predict_hbd_fn(tx_size)(cfl->pred_buf_q3, dst_16, dst_stride, alpha_q3, + xd->bd); return; } - get_predict_lbd_fn(tx_size)(cfl->pred_buf_q3, dst, dst_stride, tx_size, - alpha_q3); + get_predict_lbd_fn(tx_size)(cfl->pred_buf_q3, dst, dst_stride, alpha_q3); } // Null function used for invalid tx_sizes
diff --git a/av1/common/cfl.h b/av1/common/cfl.h index a76a27c..bb8d874 100644 --- a/av1/common/cfl.h +++ b/av1/common/cfl.h
@@ -54,19 +54,18 @@ int16_t *output_q3); // Allows the CFL_SUBSAMPLE function to switch types depending on the bitdepth. -#define CFL_SUBSAMPLE_INPUT_TYPE_lbd_ const uint8_t *input -#define CFL_SUBSAMPLE_INPUT_TYPE_hbd_ const uint16_t *input +#define CFL_lbd_TYPE uint8_t *cfl_type +#define CFL_hbd_TYPE uint16_t *cfl_type // Declare a size-specific wrapper for the size-generic function. The compiler // will inline the size generic function in here, the advantage is that the size // will be constant allowing for loop unrolling and other constant propagated // goodness. -#define CFL_SUBSAMPLE(arch, sub, bd, width, height) \ - void subsample_##bd##_##sub##_##width##x##height##_##arch( \ - CFL_SUBSAMPLE_INPUT_TYPE_##bd##_, int input_stride, \ - int16_t *output_q3) { \ - cfl_luma_subsampling_##sub##_##bd##_##arch(input, input_stride, output_q3, \ - width, height); \ +#define CFL_SUBSAMPLE(arch, sub, bd, width, height) \ + void subsample_##bd##_##sub##_##width##x##height##_##arch( \ + const CFL_##bd##_TYPE, int input_stride, int16_t *output_q3) { \ + cfl_luma_subsampling_##sub##_##bd##_##arch(cfl_type, input_stride, \ + output_q3, width, height); \ } // Declare size-specific wrappers for all valid CfL sizes. 
@@ -180,18 +179,75 @@ return sub_avg[tx_size % TX_SIZES_ALL]; \ } -#define CFL_PREDICT_LBD_X(width, arch) \ - void cfl_predict_lbd_##width##_##arch(const int16_t *pred_buf_q3, \ - uint8_t *dst, int dst_stride, \ - TX_SIZE tx_size, int alpha_q3) { \ - cfl_predict_lbd_x(pred_buf_q3, dst, dst_stride, tx_size, alpha_q3, width); \ +#define CFL_PREDICT_lbd(arch, width, height) \ + void predict_lbd_##width##x##height##_##arch(const int16_t *pred_buf_q3, \ + uint8_t *dst, int dst_stride, \ + int alpha_q3) { \ + cfl_predict_lbd_##arch(pred_buf_q3, dst, dst_stride, alpha_q3, width, \ + height); \ } -#define CFL_PREDICT_HBD_X(width, arch) \ - void cfl_predict_hbd_##width##_##arch( \ - const int16_t *pred_buf_q3, uint16_t *dst, int dst_stride, \ - TX_SIZE tx_size, int alpha_q3, int bd) { \ - cfl_predict_hbd_x(pred_buf_q3, dst, dst_stride, tx_size, alpha_q3, bd, \ - width); \ +#define CFL_PREDICT_hbd(arch, width, height) \ + void predict_hbd_##width##x##height##_##arch(const int16_t *pred_buf_q3, \ + uint16_t *dst, int dst_stride, \ + int alpha_q3, int bd) { \ + cfl_predict_hbd_##arch(pred_buf_q3, dst, dst_stride, alpha_q3, bd, width, \ + height); \ } + +// This wrapper exists because clang format does not like calling macros with +// lowercase letters. 
+#define CFL_PREDICT_X(arch, width, height, bd) \ + CFL_PREDICT_##bd(arch, width, height) + +// Null function used for invalid tx_sizes +void cfl_predict_lbd_null(const int16_t *pred_buf_q3, uint8_t *dst, + int dst_stride, int alpha_q3); + +// Null function used for invalid tx_sizes +void cfl_predict_hbd_null(const int16_t *pred_buf_q3, uint16_t *dst, + int dst_stride, int alpha_q3, int bd); + +#define CFL_PREDICT_FN(arch, bd) \ + CFL_PREDICT_X(arch, 4, 4, bd) \ + CFL_PREDICT_X(arch, 4, 8, bd) \ + CFL_PREDICT_X(arch, 4, 16, bd) \ + CFL_PREDICT_X(arch, 8, 4, bd) \ + CFL_PREDICT_X(arch, 8, 8, bd) \ + CFL_PREDICT_X(arch, 8, 16, bd) \ + CFL_PREDICT_X(arch, 8, 32, bd) \ + CFL_PREDICT_X(arch, 16, 4, bd) \ + CFL_PREDICT_X(arch, 16, 8, bd) \ + CFL_PREDICT_X(arch, 16, 16, bd) \ + CFL_PREDICT_X(arch, 16, 32, bd) \ + CFL_PREDICT_X(arch, 32, 8, bd) \ + CFL_PREDICT_X(arch, 32, 16, bd) \ + CFL_PREDICT_X(arch, 32, 32, bd) \ + cfl_predict_##bd##_fn get_predict_##bd##_fn_##arch(TX_SIZE tx_size) { \ + static const cfl_predict_##bd##_fn pred[TX_SIZES_ALL] = { \ + predict_##bd##_4x4_##arch, /* 4x4 */ \ + predict_##bd##_8x8_##arch, /* 8x8 */ \ + predict_##bd##_16x16_##arch, /* 16x16 */ \ + predict_##bd##_32x32_##arch, /* 32x32 */ \ + cfl_predict_##bd##_null, /* 64x64 (invalid CFL size) */ \ + predict_##bd##_4x8_##arch, /* 4x8 */ \ + predict_##bd##_8x4_##arch, /* 8x4 */ \ + predict_##bd##_8x16_##arch, /* 8x16 */ \ + predict_##bd##_16x8_##arch, /* 16x8 */ \ + predict_##bd##_16x32_##arch, /* 16x32 */ \ + predict_##bd##_32x16_##arch, /* 32x16 */ \ + cfl_predict_##bd##_null, /* 32x64 (invalid CFL size) */ \ + cfl_predict_##bd##_null, /* 64x32 (invalid CFL size) */ \ + predict_##bd##_4x16_##arch, /* 4x16 */ \ + predict_##bd##_16x4_##arch, /* 16x4 */ \ + predict_##bd##_8x32_##arch, /* 8x32 */ \ + predict_##bd##_32x8_##arch, /* 32x8 */ \ + cfl_predict_##bd##_null, /* 16x64 (invalid CFL size) */ \ + cfl_predict_##bd##_null, /* 64x16 (invalid CFL size) */ \ + }; \ + /* Modulo TX_SIZES_ALL to 
ensure that an attacker won't be able to */ \ + /* index the function pointer array out of bounds. */ \ + return pred[tx_size % TX_SIZES_ALL]; \ + } + #endif // AV1_COMMON_CFL_H_
diff --git a/av1/common/x86/cfl_avx2.c b/av1/common/x86/cfl_avx2.c index 775d3ff..5d3a141 100644 --- a/av1/common/x86/cfl_avx2.c +++ b/av1/common/x86/cfl_avx2.c
@@ -90,14 +90,15 @@ return _mm256_add_epi16(scaled_luma_q0, dc_q0); } -static INLINE void cfl_predict_lbd_32_avx2(const int16_t *pred_buf_q3, - uint8_t *dst, int dst_stride, - TX_SIZE tx_size, int alpha_q3) { +static INLINE void cfl_predict_lbd_avx2(const int16_t *pred_buf_q3, + uint8_t *dst, int dst_stride, + int alpha_q3, int width, int height) { + (void)width; const __m256i alpha_sign = _mm256_set1_epi16(alpha_q3); const __m256i alpha_q12 = _mm256_slli_epi16(_mm256_abs_epi16(alpha_sign), 9); const __m256i dc_q0 = _mm256_set1_epi16(*dst); __m256i *row = (__m256i *)pred_buf_q3; - const __m256i *row_end = row + tx_size_high[tx_size] * CFL_BUF_LINE_I256; + const __m256i *row_end = row + height * CFL_BUF_LINE_I256; do { __m256i res = predict_unclipped(row, alpha_q12, alpha_sign, dc_q0); @@ -109,6 +110,37 @@ } while ((row += CFL_BUF_LINE_I256) < row_end); } +CFL_PREDICT_X(avx2, 32, 8, lbd); +CFL_PREDICT_X(avx2, 32, 16, lbd); +CFL_PREDICT_X(avx2, 32, 32, lbd); + +cfl_predict_lbd_fn get_predict_lbd_fn_avx2(TX_SIZE tx_size) { + static const cfl_predict_lbd_fn pred[TX_SIZES_ALL] = { + predict_lbd_4x4_ssse3, /* 4x4 */ + predict_lbd_8x8_ssse3, /* 8x8 */ + predict_lbd_16x16_ssse3, /* 16x16 */ + predict_lbd_32x32_avx2, /* 32x32 */ + cfl_predict_lbd_null, /* 64x64 (invalid CFL size) */ + predict_lbd_4x8_ssse3, /* 4x8 */ + predict_lbd_8x4_ssse3, /* 8x4 */ + predict_lbd_8x16_ssse3, /* 8x16 */ + predict_lbd_16x8_ssse3, /* 16x8 */ + predict_lbd_16x32_ssse3, /* 16x32 */ + predict_lbd_32x16_avx2, /* 32x16 */ + cfl_predict_lbd_null, /* 32x64 (invalid CFL size) */ + cfl_predict_lbd_null, /* 64x32 (invalid CFL size) */ + predict_lbd_4x16_ssse3, /* 4x16 */ + predict_lbd_16x4_ssse3, /* 16x4 */ + predict_lbd_8x32_ssse3, /* 8x32 */ + predict_lbd_32x8_avx2, /* 32x8 */ + cfl_predict_lbd_null, /* 16x64 (invalid CFL size) */ + cfl_predict_lbd_null, /* 64x16 (invalid CFL size) */ + }; + /* Modulo TX_SIZES_ALL to ensure that an attacker won't be able to + */ /* index the function pointer array 
out of bounds. */ + return pred[tx_size % TX_SIZES_ALL]; +} + static __m256i highbd_max_epi16(int bd) { const __m256i neg_one = _mm256_set1_epi16(-1); // (1 << bd) - 1 => -(-1 << bd) -1 => -1 - (-1 << bd) => -1 ^ (-1 << bd) @@ -127,9 +159,10 @@ highbd_clamp_epi16(res, _mm256_setzero_si256(), max)); } -static INLINE void cfl_predict_hbd_x(const int16_t *pred_buf_q3, uint16_t *dst, - int dst_stride, TX_SIZE tx_size, - int alpha_q3, int bd, int width) { +static INLINE void cfl_predict_hbd_avx2(const int16_t *pred_buf_q3, + uint16_t *dst, int dst_stride, + int alpha_q3, int bd, int width, + int height) { // Use SSSE3 version for smaller widths assert(width == 16 || width == 32); const __m256i alpha_sign = _mm256_set1_epi16(alpha_q3); @@ -138,7 +171,7 @@ const __m256i max = highbd_max_epi16(bd); __m256i *row = (__m256i *)pred_buf_q3; - const __m256i *row_end = row + tx_size_high[tx_size] * CFL_BUF_LINE_I256; + const __m256i *row_end = row + height * CFL_BUF_LINE_I256; do { cfl_predict_hbd((__m256i *)dst, row, alpha_q12, alpha_sign, dc_q0, max); if (width == 32) { @@ -149,24 +182,39 @@ } while ((row += CFL_BUF_LINE_I256) < row_end); } -CFL_PREDICT_HBD_X(16, avx2) -CFL_PREDICT_HBD_X(32, avx2) - -cfl_predict_lbd_fn get_predict_lbd_fn_avx2(TX_SIZE tx_size) { - // Sizes 4, 8 and 16 reuse the SSSE3 version - static const cfl_predict_lbd_fn predict_lbd[4] = { cfl_predict_lbd_4_ssse3, - cfl_predict_lbd_8_ssse3, - cfl_predict_lbd_16_ssse3, - cfl_predict_lbd_32_avx2 }; - return predict_lbd[(tx_size_wide_log2[tx_size] - tx_size_wide_log2[0]) & 3]; -} +CFL_PREDICT_X(avx2, 16, 4, hbd) +CFL_PREDICT_X(avx2, 16, 8, hbd) +CFL_PREDICT_X(avx2, 16, 16, hbd) +CFL_PREDICT_X(avx2, 16, 32, hbd) +CFL_PREDICT_X(avx2, 32, 8, hbd) +CFL_PREDICT_X(avx2, 32, 16, hbd) +CFL_PREDICT_X(avx2, 32, 32, hbd) cfl_predict_hbd_fn get_predict_hbd_fn_avx2(TX_SIZE tx_size) { - static const cfl_predict_hbd_fn predict_hbd[4] = { cfl_predict_hbd_4_ssse3, - cfl_predict_hbd_8_ssse3, - cfl_predict_hbd_16_avx2, - 
cfl_predict_hbd_32_avx2 }; - return predict_hbd[(tx_size_wide_log2[tx_size] - tx_size_wide_log2[0]) & 3]; + static const cfl_predict_hbd_fn pred[TX_SIZES_ALL] = { + predict_hbd_4x4_ssse3, /* 4x4 */ + predict_hbd_8x8_ssse3, /* 8x8 */ + predict_hbd_16x16_avx2, /* 16x16 */ + predict_hbd_32x32_avx2, /* 32x32 */ + cfl_predict_hbd_null, /* 64x64 (invalid CFL size) */ + predict_hbd_4x8_ssse3, /* 4x8 */ + predict_hbd_8x4_ssse3, /* 8x4 */ + predict_hbd_8x16_ssse3, /* 8x16 */ + predict_hbd_16x8_avx2, /* 16x8 */ + predict_hbd_16x32_avx2, /* 16x32 */ + predict_hbd_32x16_avx2, /* 32x16 */ + cfl_predict_hbd_null, /* 32x64 (invalid CFL size) */ + cfl_predict_hbd_null, /* 64x32 (invalid CFL size) */ + predict_hbd_4x16_ssse3, /* 4x16 */ + predict_hbd_16x4_avx2, /* 16x4 */ + predict_hbd_8x32_ssse3, /* 8x32 */ + predict_hbd_32x8_avx2, /* 32x8 */ + cfl_predict_hbd_null, /* 16x64 (invalid CFL size) */ + cfl_predict_hbd_null, /* 64x16 (invalid CFL size) */ + }; + /* Modulo TX_SIZES_ALL to ensure that an attacker won't be able to + */ /* index the function pointer array out of bounds. */ + return pred[tx_size % TX_SIZES_ALL]; } // Returns a vector where all the (32-bits) elements are the sum of all the
diff --git a/av1/common/x86/cfl_simd.h b/av1/common/x86/cfl_simd.h index 3e75cb4..058d170 100644 --- a/av1/common/x86/cfl_simd.h +++ b/av1/common/x86/cfl_simd.h
@@ -11,28 +11,6 @@ #include "av1/common/blockd.h" -// SSSE3 version is optimal for with == 4, we reuse it in AVX2 -void cfl_predict_lbd_4_ssse3(const int16_t *pred_buf_q3, uint8_t *dst, - int dst_stride, TX_SIZE tx_size, int alpha_q3); - -// SSSE3 version is optimal for with == 8, we reuse it in AVX2 -void cfl_predict_lbd_8_ssse3(const int16_t *pred_buf_q3, uint8_t *dst, - int dst_stride, TX_SIZE tx_size, int alpha_q3); - -// SSSE3 version is optimal for with == 16, we reuse it in AVX2 -void cfl_predict_lbd_16_ssse3(const int16_t *pred_buf_q3, uint8_t *dst, - int dst_stride, TX_SIZE tx_size, int alpha_q3); - -// SSSE3 version is optimal for with == 4, we reuse it in AVX2 -void cfl_predict_hbd_4_ssse3(const int16_t *pred_buf_q3, uint16_t *dst, - int dst_stride, TX_SIZE tx_size, int alpha_q3, - int bd); - -// SSSE3 version is optimal for with == 8, we reuse it in AVX2 -void cfl_predict_hbd_8_ssse3(const int16_t *pred_buf_q3, uint16_t *dst, - int dst_stride, TX_SIZE tx_size, int alpha_q3, - int bd); - // SSSE3 version is optimal for with == 4, we reuse them in AVX2 void subsample_lbd_420_4x4_ssse3(const uint8_t *input, int input_stride, int16_t *output_q3); @@ -71,3 +49,53 @@ void subtract_average_8x8_sse2(int16_t *pred_buf_q3); void subtract_average_8x16_sse2(int16_t *pred_buf_q3); void subtract_average_8x32_sse2(int16_t *pred_buf_q3); + +void predict_lbd_4x4_ssse3(const int16_t *pred_buf_q3, uint8_t *dst, + int dst_stride, int alpha_q3); +void predict_lbd_4x8_ssse3(const int16_t *pred_buf_q3, uint8_t *dst, + int dst_stride, int alpha_q3); +void predict_lbd_4x16_ssse3(const int16_t *pred_buf_q3, uint8_t *dst, + int dst_stride, int alpha_q3); + +void predict_lbd_8x4_ssse3(const int16_t *pred_buf_q3, uint8_t *dst, + int dst_stride, int alpha_q3); +void predict_lbd_8x8_ssse3(const int16_t *pred_buf_q3, uint8_t *dst, + int dst_stride, int alpha_q3); +void predict_lbd_8x16_ssse3(const int16_t *pred_buf_q3, uint8_t *dst, + int dst_stride, int alpha_q3); +void 
predict_lbd_8x32_ssse3(const int16_t *pred_buf_q3, uint8_t *dst, + int dst_stride, int alpha_q3); + +void predict_lbd_16x4_ssse3(const int16_t *pred_buf_q3, uint8_t *dst, + int dst_stride, int alpha_q3); +void predict_lbd_16x8_ssse3(const int16_t *pred_buf_q3, uint8_t *dst, + int dst_stride, int alpha_q3); +void predict_lbd_16x16_ssse3(const int16_t *pred_buf_q3, uint8_t *dst, + int dst_stride, int alpha_q3); +void predict_lbd_16x32_ssse3(const int16_t *pred_buf_q3, uint8_t *dst, + int dst_stride, int alpha_q3); + +void predict_hbd_4x4_ssse3(const int16_t *pred_buf_q3, uint16_t *dst, + int dst_stride, int alpha_q3, int bd); +void predict_hbd_4x8_ssse3(const int16_t *pred_buf_q3, uint16_t *dst, + int dst_stride, int alpha_q3, int bd); +void predict_hbd_4x16_ssse3(const int16_t *pred_buf_q3, uint16_t *dst, + int dst_stride, int alpha_q3, int bd); + +void predict_hbd_8x4_ssse3(const int16_t *pred_buf_q3, uint16_t *dst, + int dst_stride, int alpha_q3, int bd); +void predict_hbd_8x8_ssse3(const int16_t *pred_buf_q3, uint16_t *dst, + int dst_stride, int alpha_q3, int bd); +void predict_hbd_8x16_ssse3(const int16_t *pred_buf_q3, uint16_t *dst, + int dst_stride, int alpha_q3, int bd); +void predict_hbd_8x32_ssse3(const int16_t *pred_buf_q3, uint16_t *dst, + int dst_stride, int alpha_q3, int bd); + +void predict_hbd_16x4_ssse3(const int16_t *pred_buf_q3, uint16_t *dst, + int dst_stride, int alpha_q3, int bd); +void predict_hbd_16x8_ssse3(const int16_t *pred_buf_q3, uint16_t *dst, + int dst_stride, int alpha_q3, int bd); +void predict_hbd_16x16_ssse3(const int16_t *pred_buf_q3, uint16_t *dst, + int dst_stride, int alpha_q3, int bd); +void predict_hbd_16x32_ssse3(const int16_t *pred_buf_q3, uint16_t *dst, + int dst_stride, int alpha_q3, int bd);
diff --git a/av1/common/x86/cfl_ssse3.c b/av1/common/x86/cfl_ssse3.c index f7b6c6f..735b806 100644 --- a/av1/common/x86/cfl_ssse3.c +++ b/av1/common/x86/cfl_ssse3.c
@@ -89,16 +89,16 @@ return _mm_add_epi16(scaled_luma_q0, dc_q0); } -static INLINE void cfl_predict_lbd_x(const int16_t *pred_buf_q3, uint8_t *dst, - int dst_stride, TX_SIZE tx_size, - int alpha_q3, int width) { - uint8_t *row_end = dst + tx_size_high[tx_size] * dst_stride; +static INLINE void cfl_predict_lbd_ssse3(const int16_t *pred_buf_q3, + uint8_t *dst, int dst_stride, + int alpha_q3, int width, int height) { const __m128i alpha_sign = _mm_set1_epi16(alpha_q3); const __m128i alpha_q12 = _mm_slli_epi16(_mm_abs_epi16(alpha_sign), 9); const __m128i dc_q0 = _mm_set1_epi16(*dst); + __m128i *row = (__m128i *)pred_buf_q3; + const __m128i *row_end = row + height * CFL_BUF_LINE_I128; do { - __m128i res = predict_unclipped((__m128i *)(pred_buf_q3), alpha_q12, - alpha_sign, dc_q0); + __m128i res = predict_unclipped(row, alpha_q12, alpha_sign, dc_q0); if (width < 16) { res = _mm_packus_epi16(res, res); if (width == 4) @@ -106,24 +106,22 @@ else _mm_storel_epi64((__m128i *)dst, res); } else { - __m128i next = predict_unclipped((__m128i *)(pred_buf_q3 + 8), alpha_q12, - alpha_sign, dc_q0); + __m128i next = predict_unclipped(row + 1, alpha_q12, alpha_sign, dc_q0); res = _mm_packus_epi16(res, next); _mm_storeu_si128((__m128i *)dst, res); if (width == 32) { - res = predict_unclipped((__m128i *)(pred_buf_q3 + 16), alpha_q12, - alpha_sign, dc_q0); - next = predict_unclipped((__m128i *)(pred_buf_q3 + 24), alpha_q12, - alpha_sign, dc_q0); + res = predict_unclipped(row + 2, alpha_q12, alpha_sign, dc_q0); + next = predict_unclipped(row + 3, alpha_q12, alpha_sign, dc_q0); res = _mm_packus_epi16(res, next); _mm_storeu_si128((__m128i *)(dst + 16), res); } } dst += dst_stride; - pred_buf_q3 += CFL_BUF_LINE; - } while (dst < row_end); + } while ((row += CFL_BUF_LINE_I128) < row_end); } +CFL_PREDICT_FN(ssse3, lbd) + static INLINE __m128i highbd_max_epi16(int bd) { const __m128i neg_one = _mm_set1_epi16(-1); // (1 << bd) - 1 => -(-1 << bd) -1 => -1 - (-1 << bd) => -1 ^ (-1 << bd) @@ -141,60 
+139,35 @@ _mm_storeu_si128(dst, highbd_clamp_epi16(res, _mm_setzero_si128(), max)); } -static INLINE void cfl_predict_hbd_x(const int16_t *pred_buf_q3, uint16_t *dst, - int dst_stride, TX_SIZE tx_size, - int alpha_q3, int bd, int width) { - uint16_t *row_end = dst + tx_size_high[tx_size] * dst_stride; +static INLINE void cfl_predict_hbd_ssse3(const int16_t *pred_buf_q3, + uint16_t *dst, int dst_stride, + int alpha_q3, int bd, int width, + int height) { const __m128i alpha_sign = _mm_set1_epi16(alpha_q3); const __m128i alpha_q12 = _mm_slli_epi16(_mm_abs_epi16(alpha_sign), 9); const __m128i dc_q0 = _mm_set1_epi16(*dst); const __m128i max = highbd_max_epi16(bd); + __m128i *row = (__m128i *)pred_buf_q3; + const __m128i *row_end = row + height * CFL_BUF_LINE_I128; do { if (width == 4) { - __m128i res = predict_unclipped((__m128i *)(pred_buf_q3), alpha_q12, - alpha_sign, dc_q0); + __m128i res = predict_unclipped(row, alpha_q12, alpha_sign, dc_q0); _mm_storel_epi64((__m128i *)dst, highbd_clamp_epi16(res, _mm_setzero_si128(), max)); } else { - cfl_predict_hbd((__m128i *)dst, (__m128i *)pred_buf_q3, alpha_q12, - alpha_sign, dc_q0, max); + cfl_predict_hbd((__m128i *)dst, row, alpha_q12, alpha_sign, dc_q0, max); } if (width >= 16) - cfl_predict_hbd((__m128i *)(dst + 8), (__m128i *)(pred_buf_q3 + 8), - alpha_q12, alpha_sign, dc_q0, max); + cfl_predict_hbd((__m128i *)(dst + 8), row + 1, alpha_q12, alpha_sign, + dc_q0, max); if (width == 32) { - cfl_predict_hbd((__m128i *)(dst + 16), (__m128i *)(pred_buf_q3 + 16), - alpha_q12, alpha_sign, dc_q0, max); - cfl_predict_hbd((__m128i *)(dst + 24), (__m128i *)(pred_buf_q3 + 24), - alpha_q12, alpha_sign, dc_q0, max); + cfl_predict_hbd((__m128i *)(dst + 16), row + 2, alpha_q12, alpha_sign, + dc_q0, max); + cfl_predict_hbd((__m128i *)(dst + 24), row + 3, alpha_q12, alpha_sign, + dc_q0, max); } dst += dst_stride; - pred_buf_q3 += CFL_BUF_LINE; - } while (dst < row_end); + } while ((row += CFL_BUF_LINE_I128) < row_end); } 
-CFL_PREDICT_LBD_X(4, ssse3) -CFL_PREDICT_LBD_X(8, ssse3) -CFL_PREDICT_LBD_X(16, ssse3) -CFL_PREDICT_LBD_X(32, ssse3) - -CFL_PREDICT_HBD_X(4, ssse3) -CFL_PREDICT_HBD_X(8, ssse3) -CFL_PREDICT_HBD_X(16, ssse3) -CFL_PREDICT_HBD_X(32, ssse3) - -cfl_predict_lbd_fn get_predict_lbd_fn_ssse3(TX_SIZE tx_size) { - static const cfl_predict_lbd_fn predict_lbd[4] = { cfl_predict_lbd_4_ssse3, - cfl_predict_lbd_8_ssse3, - cfl_predict_lbd_16_ssse3, - cfl_predict_lbd_32_ssse3 }; - return predict_lbd[(tx_size_wide_log2[tx_size] - tx_size_wide_log2[0]) & 3]; -} - -cfl_predict_hbd_fn get_predict_hbd_fn_ssse3(TX_SIZE tx_size) { - static const cfl_predict_hbd_fn predict_hbd[4] = { cfl_predict_hbd_4_ssse3, - cfl_predict_hbd_8_ssse3, - cfl_predict_hbd_16_ssse3, - cfl_predict_hbd_32_ssse3 }; - return predict_hbd[(tx_size_wide_log2[tx_size] - tx_size_wide_log2[0]) & 3]; -} +CFL_PREDICT_FN(ssse3, hbd)
diff --git a/test/cfl_test.cc b/test/cfl_test.cc index 0274117..5dfd28e 100644 --- a/test/cfl_test.cc +++ b/test/cfl_test.cc
@@ -267,10 +267,9 @@ TEST_P(CFLPredictTest, PredictTest) { for (int it = 0; it < NUM_ITERATIONS; it++) { init(8); - fun_under_test(tx_size)(sub_luma_pels, chroma_pels, CFL_BUF_LINE, tx_size, - alpha_q3); + fun_under_test(tx_size)(sub_luma_pels, chroma_pels, CFL_BUF_LINE, alpha_q3); get_predict_lbd_fn_c(tx_size)(sub_luma_pels_ref, chroma_pels_ref, - CFL_BUF_LINE, tx_size, alpha_q3); + CFL_BUF_LINE, alpha_q3); assert_eq<uint8_t>(chroma_pels, chroma_pels_ref, width, height); } @@ -285,8 +284,7 @@ aom_usec_timer_start(&ref_timer); for (int k = 0; k < NUM_ITERATIONS_SPEED; k++) { - predict_impl(sub_luma_pels_ref, chroma_pels_ref, CFL_BUF_LINE, tx_size, - alpha_q3); + predict_impl(sub_luma_pels_ref, chroma_pels_ref, CFL_BUF_LINE, alpha_q3); } aom_usec_timer_mark(&ref_timer); int ref_elapsed_time = (int)aom_usec_timer_elapsed(&ref_timer); @@ -294,7 +292,7 @@ predict_impl = fun_under_test(tx_size); aom_usec_timer_start(&timer); for (int k = 0; k < NUM_ITERATIONS_SPEED; k++) { - predict_impl(sub_luma_pels, chroma_pels, CFL_BUF_LINE, tx_size, alpha_q3); + predict_impl(sub_luma_pels, chroma_pels, CFL_BUF_LINE, alpha_q3); } aom_usec_timer_mark(&timer); int elapsed_time = (int)aom_usec_timer_elapsed(&timer); @@ -307,10 +305,10 @@ int bd = 12; for (int it = 0; it < NUM_ITERATIONS; it++) { init(bd); - fun_under_test(tx_size)(sub_luma_pels, chroma_pels, CFL_BUF_LINE, tx_size, - alpha_q3, bd); + fun_under_test(tx_size)(sub_luma_pels, chroma_pels, CFL_BUF_LINE, alpha_q3, + bd); get_predict_hbd_fn_c(tx_size)(sub_luma_pels_ref, chroma_pels_ref, - CFL_BUF_LINE, tx_size, alpha_q3, bd); + CFL_BUF_LINE, alpha_q3, bd); assert_eq<uint16_t>(chroma_pels, chroma_pels_ref, width, height); } @@ -325,8 +323,8 @@ aom_usec_timer_start(&ref_timer); for (int k = 0; k < NUM_ITERATIONS_SPEED; k++) { - predict_impl(sub_luma_pels_ref, chroma_pels_ref, CFL_BUF_LINE, tx_size, - alpha_q3, bd); + predict_impl(sub_luma_pels_ref, chroma_pels_ref, CFL_BUF_LINE, alpha_q3, + bd); } 
aom_usec_timer_mark(&ref_timer); int ref_elapsed_time = (int)aom_usec_timer_elapsed(&ref_timer); @@ -334,8 +332,7 @@ predict_impl = fun_under_test(tx_size); aom_usec_timer_start(&timer); for (int k = 0; k < NUM_ITERATIONS_SPEED; k++) { - predict_impl(sub_luma_pels, chroma_pels, CFL_BUF_LINE, tx_size, alpha_q3, - bd); + predict_impl(sub_luma_pels, chroma_pels, CFL_BUF_LINE, alpha_q3, bd); } aom_usec_timer_mark(&timer); int elapsed_time = (int)aom_usec_timer_elapsed(&timer);