[CFL] Faster AVX2 Average Subtract
Based on the observation that for small blocks AVX2 does not outperform
SSE2, we call the SSE2 code for block widths 4 and 8.
For widths 16 and 32, the AVX2 version is optimized by:
* Summing two rows per iteration of the summing loop (a scalar sketch of
  the operation follows this list);
* Operating on full 256-bit registers in the summing loop;
* Using more accumulators to reduce dependencies between operations;
* Leveraging chained hadd calls in the fill function.
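
For reference, a scalar sketch of the operation being vectorized
(illustrative only; the prediction buffer uses the CFL_BUF_LINE stride,
round_offset is num_pel / 2 and num_pel_log2 is log2(width * height)):

  /* Scalar sketch, not the exact C reference used by the test. */
  int32_t sum = 0;
  for (int j = 0; j < height; j++)
    for (int i = 0; i < width; i++)
      sum += pred_buf_q3[j * CFL_BUF_LINE + i];
  const int16_t avg = (sum + round_offset) >> num_pel_log2;
  for (int j = 0; j < height; j++)
    for (int i = 0; i < width; i++)
      pred_buf_q3[j * CFL_BUF_LINE + i] -= avg;
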
AVX2/CFLSubAvgTest
4x4: C time = 384 us, SIMD time = 153 us (~2.5x)
8x8: C time = 805 us, SIMD time = 229 us (~3.5x)
16x16: C time = 2757 us, SIMD time = 775 us (~3.6x)
32x32: C time = 10035 us, SIMD time = 2524 us (~4x)
Change-Id: I683994026c1f1626828e90949cd0bd911b46ed5e
diff --git a/av1/common/cfl.h b/av1/common/cfl.h
index 3c1494b..a76a27c 100644
--- a/av1/common/cfl.h
+++ b/av1/common/cfl.h
@@ -127,12 +127,17 @@
assert(0);
}
-#define CFL_SUB_AVG_X(arch, width, height, round_offset, num_pel_log2) \
- static void subtract_average_##width##x##height##_x(int16_t *pred_buf_q3) { \
- subtract_average_##arch(pred_buf_q3, width, height, round_offset, \
- num_pel_log2); \
+// Declare a size-specific wrapper for the size-generic function. The compiler
+// will inline the size-generic function here; because the size arguments are
+// compile-time constants, this enables loop unrolling and other
+// constant-propagation optimizations.
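+// For example, CFL_SUB_AVG_X(avx2, 16, 16, 128, 8) expands to:
+//   void subtract_average_16x16_avx2(int16_t *pred_buf_q3) {
+//     subtract_average_avx2(pred_buf_q3, 16, 16, 128, 8);
+//   }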
+#define CFL_SUB_AVG_X(arch, width, height, round_offset, num_pel_log2) \
+ void subtract_average_##width##x##height##_##arch(int16_t *pred_buf_q3) { \
+ subtract_average_##arch(pred_buf_q3, width, height, round_offset, \
+ num_pel_log2); \
}
+// Declare size-specific wrappers for all valid CfL sizes.
#define CFL_SUB_AVG_FN(arch) \
CFL_SUB_AVG_X(arch, 4, 4, 8, 4) \
CFL_SUB_AVG_X(arch, 4, 8, 16, 5) \
@@ -150,25 +155,25 @@
CFL_SUB_AVG_X(arch, 32, 32, 512, 10) \
cfl_subtract_average_fn get_subtract_average_fn_##arch(TX_SIZE tx_size) { \
static const cfl_subtract_average_fn sub_avg[TX_SIZES_ALL] = { \
- subtract_average_4x4_x, /* 4x4 */ \
- subtract_average_8x8_x, /* 8x8 */ \
- subtract_average_16x16_x, /* 16x16 */ \
- subtract_average_32x32_x, /* 32x32 */ \
- cfl_subtract_average_null, /* 64x64 (invalid CFL size) */ \
- subtract_average_4x8_x, /* 4x8 */ \
- subtract_average_8x4_x, /* 8x4 */ \
- subtract_average_8x16_x, /* 8x16 */ \
- subtract_average_16x8_x, /* 16x8 */ \
- subtract_average_16x32_x, /* 16x32 */ \
- subtract_average_32x16_x, /* 32x16 */ \
- cfl_subtract_average_null, /* 32x64 (invalid CFL size) */ \
- cfl_subtract_average_null, /* 64x32 (invalid CFL size) */ \
- subtract_average_4x16_x, /* 4x16 (invalid CFL size) */ \
- subtract_average_16x4_x, /* 16x4 (invalid CFL size) */ \
- subtract_average_8x32_x, /* 8x32 (invalid CFL size) */ \
- subtract_average_32x8_x, /* 32x8 (invalid CFL size) */ \
- cfl_subtract_average_null, /* 16x64 (invalid CFL size) */ \
- cfl_subtract_average_null, /* 64x16 (invalid CFL size) */ \
+ subtract_average_4x4_##arch, /* 4x4 */ \
+ subtract_average_8x8_##arch, /* 8x8 */ \
+ subtract_average_16x16_##arch, /* 16x16 */ \
+ subtract_average_32x32_##arch, /* 32x32 */ \
+ cfl_subtract_average_null, /* 64x64 (invalid CFL size) */ \
+ subtract_average_4x8_##arch, /* 4x8 */ \
+ subtract_average_8x4_##arch, /* 8x4 */ \
+ subtract_average_8x16_##arch, /* 8x16 */ \
+ subtract_average_16x8_##arch, /* 16x8 */ \
+ subtract_average_16x32_##arch, /* 16x32 */ \
+ subtract_average_32x16_##arch, /* 32x16 */ \
+ cfl_subtract_average_null, /* 32x64 (invalid CFL size) */ \
+ cfl_subtract_average_null, /* 64x32 (invalid CFL size) */ \
+      subtract_average_4x16_##arch,  /* 4x16 */                             \
+      subtract_average_16x4_##arch,  /* 16x4 */                             \
+      subtract_average_8x32_##arch,  /* 8x32 */                             \
+      subtract_average_32x8_##arch,  /* 32x8 */                             \
+ cfl_subtract_average_null, /* 16x64 (invalid CFL size) */ \
+ cfl_subtract_average_null, /* 64x16 (invalid CFL size) */ \
}; \
/* Modulo TX_SIZES_ALL to ensure that an attacker won't be able to */ \
/* index the function pointer array out of bounds. */ \
diff --git a/av1/common/x86/cfl_avx2.c b/av1/common/x86/cfl_avx2.c
index d7ecad1..775d3ff 100644
--- a/av1/common/x86/cfl_avx2.c
+++ b/av1/common/x86/cfl_avx2.c
@@ -169,64 +169,70 @@
return predict_hbd[(tx_size_wide_log2[tx_size] - tx_size_wide_log2[0]) & 3];
}
-static INLINE __m256i fill_sum_epi32(__m256i l0) {
- l0 = _mm256_add_epi32(l0, _mm256_shuffle_epi32(l0, _MM_SHUFFLE(1, 0, 3, 2)));
- return _mm256_add_epi32(l0,
- _mm256_shuffle_epi32(l0, _MM_SHUFFLE(2, 3, 0, 1)));
+// Returns a vector in which every 32-bit element is the sum of all eight
+// 32-bit lanes of a.
+static INLINE __m256i fill_sum_epi32(__m256i a) {
+ // Given that a == [A, B, C, D, E, F, G, H]
+ a = _mm256_hadd_epi32(a, a);
+ // Given that A' == A + B, C' == C + D, E' == E + F, G' == G + H
+ // a == [A', C', A', C', E', G', E', G']
+ a = _mm256_permute4x64_epi64(a, _MM_SHUFFLE(3, 1, 2, 0));
+ // a == [A', C', E', G', A', C', E', G']
+ a = _mm256_hadd_epi32(a, a);
+ // Given that A'' == A' + C' and E'' == E' + G'
+ // a == [A'', E'', A'', E'', A'', E'', A'', E'']
+ return _mm256_hadd_epi32(a, a);
+ // Given that A''' == A'' + E''
+ // a == [A''', A''', A''', A''', A''', A''', A''', A''']
+}
+
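+// Zero-extends the 16-bit elements of a to 32 bits and, within each 128-bit
+// lane, adds the low four elements to the high four, producing eight 32-bit
+// partial sums.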
+static INLINE __m256i _mm256_addl_epi16(__m256i a) {
+ return _mm256_add_epi32(_mm256_unpacklo_epi16(a, _mm256_setzero_si256()),
+ _mm256_unpackhi_epi16(a, _mm256_setzero_si256()));
}
static INLINE void subtract_average_avx2(int16_t *pred_buf, int width,
int height, int round_offset,
int num_pel_log2) {
- const __m256i zeros = _mm256_setzero_si256();
+ // Use SSE2 version for smaller widths
+ assert(width == 16 || width == 32);
__m256i *row = (__m256i *)pred_buf;
const __m256i *const end = row + height * CFL_BUF_LINE_I256;
- const int step = CFL_BUF_LINE_I256 * (1 + (width == 8) + 3 * (width == 4));
- union {
- __m256i v;
- int32_t i32[8];
- } sum;
- sum.v = zeros;
+  // To maximize usage of the AVX2 registers, we sum two rows per loop
+  // iteration.
+ const int step = 2 * CFL_BUF_LINE_I256;
+ __m256i sum = _mm256_setzero_si256();
+
+ // For width 32, we use a second sum accumulator to reduce accumulator
+ // dependencies in the loop.
+ __m256i sum2;
+ if (width == 32) sum2 = _mm256_setzero_si256();
do {
- if (width == 4) {
- __m256i l0 = _mm256_loadu_si256(row);
- __m256i l1 = _mm256_loadu_si256(row + CFL_BUF_LINE_I256);
- __m256i l2 = _mm256_loadu_si256(row + 2 * CFL_BUF_LINE_I256);
- __m256i l3 = _mm256_loadu_si256(row + 3 * CFL_BUF_LINE_I256);
-
- __m256i t0 = _mm256_add_epi16(l0, l1);
- __m256i t1 = _mm256_add_epi16(l2, l3);
-
- sum.v = _mm256_add_epi32(
- sum.v, _mm256_add_epi32(_mm256_unpacklo_epi16(t0, zeros),
- _mm256_unpacklo_epi16(t1, zeros)));
- } else {
- __m256i l0;
- if (width == 8) {
- l0 = _mm256_add_epi16(_mm256_loadu_si256(row),
- _mm256_loadu_si256(row + CFL_BUF_LINE_I256));
- } else {
- l0 = _mm256_loadu_si256(row);
- l0 = _mm256_add_epi16(l0, _mm256_permute2x128_si256(l0, l0, 1));
- }
- sum.v = _mm256_add_epi32(
- sum.v, _mm256_add_epi32(_mm256_unpacklo_epi16(l0, zeros),
- _mm256_unpackhi_epi16(l0, zeros)));
- if (width == 32) {
- l0 = _mm256_loadu_si256(row + 1);
- l0 = _mm256_add_epi16(l0, _mm256_permute2x128_si256(l0, l0, 1));
- sum.v = _mm256_add_epi32(
- sum.v, _mm256_add_epi32(_mm256_unpacklo_epi16(l0, zeros),
- _mm256_unpackhi_epi16(l0, zeros)));
- }
+ // Add top row to the bottom row
+ __m256i l0 = _mm256_add_epi16(_mm256_loadu_si256(row),
+ _mm256_loadu_si256(row + CFL_BUF_LINE_I256));
+ sum = _mm256_add_epi32(sum, _mm256_addl_epi16(l0));
+    if (width == 32) { /* Don't worry, this branch gets optimized out. */
+ // Add the second part of the top row to the second part of the bottom row
+ __m256i l1 =
+ _mm256_add_epi16(_mm256_loadu_si256(row + 1),
+ _mm256_loadu_si256(row + 1 + CFL_BUF_LINE_I256));
+      // Accumulate the sum of the second part in the second accumulator
+ sum2 = _mm256_add_epi32(sum2, _mm256_addl_epi16(l1));
}
} while ((row += step) < end);
+  // Combine both sum accumulators
+ if (width == 32) sum = _mm256_add_epi32(sum, sum2);
- sum.v = fill_sum_epi32(sum.v);
+  // Fill every 32-bit lane with the total sum
+ __m256i fill = fill_sum_epi32(sum);
- __m256i avg_epi16 =
- _mm256_set1_epi16((sum.i32[0] + round_offset) >> num_pel_log2);
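+  // Add the rounding offset and divide by the number of pixels to get the
+  // average in every 32-bit lane, then pack it down to 16 bits so it can be
+  // subtracted from the int16 prediction buffer.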
+ __m256i avg_epi16 = _mm256_srli_epi32(
+ _mm256_add_epi32(fill, _mm256_set1_epi32(round_offset)), num_pel_log2);
+ avg_epi16 = _mm256_packs_epi32(avg_epi16, avg_epi16);
+ // Store and subtract loop
row = (__m256i *)pred_buf;
do {
_mm256_storeu_si256(row,
@@ -238,4 +244,40 @@
} while ((row += CFL_BUF_LINE_I256) < end);
}
-CFL_SUB_AVG_FN(avx2)
+// Declare wrappers for AVX2 sizes
+CFL_SUB_AVG_X(avx2, 16, 4, 32, 6)
+CFL_SUB_AVG_X(avx2, 16, 8, 64, 7)
+CFL_SUB_AVG_X(avx2, 16, 16, 128, 8)
+CFL_SUB_AVG_X(avx2, 16, 32, 256, 9)
+CFL_SUB_AVG_X(avx2, 32, 8, 128, 8)
+CFL_SUB_AVG_X(avx2, 32, 16, 256, 9)
+CFL_SUB_AVG_X(avx2, 32, 32, 512, 10)
+
+// Based on the observation that for small blocks AVX2 does not outperform
+// SSE2, we call the SSE2 code for block widths 4 and 8.
+cfl_subtract_average_fn get_subtract_average_fn_avx2(TX_SIZE tx_size) {
+ static const cfl_subtract_average_fn sub_avg[TX_SIZES_ALL] = {
+ subtract_average_4x4_sse2, /* 4x4 */
+ subtract_average_8x8_sse2, /* 8x8 */
+ subtract_average_16x16_avx2, /* 16x16 */
+ subtract_average_32x32_avx2, /* 32x32 */
+ cfl_subtract_average_null, /* 64x64 (invalid CFL size) */
+ subtract_average_4x8_sse2, /* 4x8 */
+ subtract_average_8x4_sse2, /* 8x4 */
+ subtract_average_8x16_sse2, /* 8x16 */
+ subtract_average_16x8_avx2, /* 16x8 */
+ subtract_average_16x32_avx2, /* 16x32 */
+ subtract_average_32x16_avx2, /* 32x16 */
+ cfl_subtract_average_null, /* 32x64 (invalid CFL size) */
+ cfl_subtract_average_null, /* 64x32 (invalid CFL size) */
+ subtract_average_4x16_sse2, /* 4x16 */
+ subtract_average_16x4_avx2, /* 16x4 */
+ subtract_average_8x32_sse2, /* 8x32 */
+ subtract_average_32x8_avx2, /* 32x8 */
+ cfl_subtract_average_null, /* 16x64 (invalid CFL size) */
+ cfl_subtract_average_null, /* 64x16 (invalid CFL size) */
+ };
+ // Modulo TX_SIZES_ALL to ensure that an attacker won't be able to
+ // index the function pointer array out of bounds.
+ return sub_avg[tx_size % TX_SIZES_ALL];
+}
diff --git a/av1/common/x86/cfl_simd.h b/av1/common/x86/cfl_simd.h
index 17aaf15..3e75cb4 100644
--- a/av1/common/x86/cfl_simd.h
+++ b/av1/common/x86/cfl_simd.h
@@ -60,3 +60,14 @@
int16_t *output_q3);
void subsample_lbd_420_16x32_ssse3(const uint8_t *input, int input_stride,
int16_t *output_q3);
+
+// The SSE2 versions are optimal for width == 4, so we reuse them in AVX2
+void subtract_average_4x4_sse2(int16_t *pred_buf_q3);
+void subtract_average_4x8_sse2(int16_t *pred_buf_q3);
+void subtract_average_4x16_sse2(int16_t *pred_buf_q3);
+
+// The SSE2 versions are optimal for width == 8, so we reuse them in AVX2
+void subtract_average_8x4_sse2(int16_t *pred_buf_q3);
+void subtract_average_8x8_sse2(int16_t *pred_buf_q3);
+void subtract_average_8x16_sse2(int16_t *pred_buf_q3);
+void subtract_average_8x32_sse2(int16_t *pred_buf_q3);